1use std::pin::Pin;
2use std::sync::atomic::Ordering;
3
4use futures::Stream;
5
6use crate::checkpoint::{Checkpoint, CheckpointStore};
7use crate::error::AgentError;
8use crate::loop_::AgentEvent;
9
10use super::Agent;
11use super::queueing::drain_messages_from_queue;
12
13fn invalid_state_snapshot(error: &serde_json::Error) -> std::io::Error {
14 std::io::Error::new(
15 std::io::ErrorKind::InvalidData,
16 format!("corrupted session state snapshot: {error}"),
17 )
18}
19
20fn restore_session_state(
21 snapshot: Option<&serde_json::Value>,
22) -> Result<crate::SessionState, std::io::Error> {
23 snapshot.map_or_else(
24 || Ok(crate::SessionState::new()),
25 |state_val| {
26 crate::SessionState::restore_from_snapshot(state_val.clone())
27 .map_err(|e| invalid_state_snapshot(&e))
28 },
29 )
30}
31
32impl Agent {
33 fn rebind_stream_fn_for_current_model(&mut self) {
36 if let Some((_, stream_fn)) = self.model_stream_fns.iter().find(|(m, _)| {
37 m.provider == self.state.model.provider && m.model_id == self.state.model.model_id
38 }) {
39 self.stream_fn = std::sync::Arc::clone(stream_fn);
40 }
41 }
42
43 pub async fn save_checkpoint(
50 &self,
51 id: impl Into<String>,
52 ) -> Result<Checkpoint, std::io::Error> {
53 let mut checkpoint = Checkpoint::new(
54 id,
55 &self.state.system_prompt,
56 &self.state.model.provider,
57 &self.state.model.model_id,
58 &self.state.messages,
59 );
60
61 {
62 let s = self
63 .session_state
64 .read()
65 .unwrap_or_else(std::sync::PoisonError::into_inner);
66 if !s.is_empty() {
67 checkpoint.state = Some(s.snapshot());
68 }
69 }
70
71 if let Some(ref store) = self.checkpoint_store {
72 store.save_checkpoint(checkpoint.clone()).await?;
73 }
74
75 Ok(checkpoint)
76 }
77
78 fn ensure_idle_for_checkpoint_restore(&mut self) -> Result<(), std::io::Error> {
79 self.check_not_running().map_err(|_| {
80 std::io::Error::new(
81 std::io::ErrorKind::WouldBlock,
82 "cannot restore checkpoint while agent is running",
83 )
84 })
85 }
86
87 pub fn restore_from_checkpoint(
101 &mut self,
102 checkpoint: &Checkpoint,
103 ) -> Result<(), std::io::Error> {
104 self.ensure_idle_for_checkpoint_restore()?;
105 let restored_messages =
106 checkpoint.restore_messages(self.custom_message_registry.as_deref());
107 let restored_state = restore_session_state(checkpoint.state.as_ref())?;
108
109 self.clear_transient_runtime_state();
110 self.state.messages = restored_messages;
111 self.state
112 .system_prompt
113 .clone_from(&checkpoint.system_prompt);
114 self.state.model.provider.clone_from(&checkpoint.provider);
115 self.state.model.model_id.clone_from(&checkpoint.model_id);
116 self.rebind_stream_fn_for_current_model();
117 *self
118 .session_state
119 .write()
120 .unwrap_or_else(std::sync::PoisonError::into_inner) = restored_state;
121
122 Ok(())
123 }
124
125 pub async fn load_and_restore_checkpoint(
131 &mut self,
132 id: &str,
133 ) -> Result<Option<Checkpoint>, std::io::Error> {
134 self.ensure_idle_for_checkpoint_restore()?;
135 let store = self
136 .checkpoint_store
137 .as_ref()
138 .ok_or_else(|| std::io::Error::other("no checkpoint store configured"))?;
139
140 let maybe = store.load_checkpoint(id).await?;
141 if let Some(ref checkpoint) = maybe {
142 self.restore_from_checkpoint(checkpoint)?;
143 }
144 Ok(maybe)
145 }
146
147 #[must_use]
149 pub fn checkpoint_store(&self) -> Option<&dyn CheckpointStore> {
150 self.checkpoint_store.as_deref()
151 }
152
153 pub fn pause(&mut self) -> Option<crate::checkpoint::LoopCheckpoint> {
167 if !self.loop_active.load(Ordering::Acquire) {
168 return None;
169 }
170
171 if let Some(ref token) = self.abort_controller {
172 tracing::info!("pausing agent loop");
173 token.cancel();
174 }
175
176 let mut pending_messages = self.pending_message_snapshot.snapshot();
177 pending_messages.extend(drain_messages_from_queue(&self.follow_up_queue));
178
179 let loop_ctx = self.loop_context_snapshot.snapshot();
186 let checkpoint_messages: &[crate::types::AgentMessage] = if let Some(ref ctx) = loop_ctx {
187 ctx.as_slice()
188 } else {
189 self.in_flight_messages
190 .as_deref()
191 .unwrap_or(&self.state.messages)
192 };
193
194 let mut checkpoint = crate::checkpoint::LoopCheckpoint::new(
195 &self.state.system_prompt,
196 &self.state.model.provider,
197 &self.state.model.model_id,
198 checkpoint_messages,
199 )
200 .with_pending_message_batch(&pending_messages)
201 .with_pending_steering_message_batch(&drain_messages_from_queue(&self.steering_queue));
202
203 let s = self
204 .session_state
205 .read()
206 .unwrap_or_else(std::sync::PoisonError::into_inner);
207 if !s.is_empty() {
208 checkpoint.state = Some(s.snapshot());
209 }
210 drop(s);
211
212 Some(checkpoint)
218 }
219
220 pub async fn resume(
222 &mut self,
223 checkpoint: &crate::checkpoint::LoopCheckpoint,
224 ) -> Result<crate::types::AgentResult, AgentError> {
225 self.check_not_running()?;
226 self.restore_from_loop_checkpoint(checkpoint)?;
227 self.continue_async().await
228 }
229
230 pub fn resume_stream(
232 &mut self,
233 checkpoint: &crate::checkpoint::LoopCheckpoint,
234 ) -> Result<Pin<Box<dyn Stream<Item = AgentEvent> + Send>>, AgentError> {
235 self.check_not_running()?;
236 self.restore_from_loop_checkpoint(checkpoint)?;
237 self.continue_stream()
238 }
239
240 fn restore_from_loop_checkpoint(
241 &mut self,
242 checkpoint: &crate::checkpoint::LoopCheckpoint,
243 ) -> Result<(), AgentError> {
244 let restored_messages =
245 checkpoint.restore_messages(self.custom_message_registry.as_deref());
246 if restored_messages.is_empty() {
247 return Err(AgentError::NoMessages);
248 }
249 let restored_state =
250 restore_session_state(checkpoint.state.as_ref()).map_err(AgentError::stream)?;
251
252 self.clear_transient_runtime_state();
253 self.state.messages = restored_messages;
254 self.state
255 .system_prompt
256 .clone_from(&checkpoint.system_prompt);
257 self.state.model.provider.clone_from(&checkpoint.provider);
258 self.state.model.model_id.clone_from(&checkpoint.model_id);
259 self.rebind_stream_fn_for_current_model();
260 {
261 let mut s = self
262 .session_state
263 .write()
264 .unwrap_or_else(std::sync::PoisonError::into_inner);
265 *s = restored_state;
266 }
267
268 self.clear_queues();
271
272 for msg in checkpoint.restore_pending_messages(self.custom_message_registry.as_deref()) {
273 self.follow_up(msg);
274 }
275 for msg in
276 checkpoint.restore_pending_steering_messages(self.custom_message_registry.as_deref())
277 {
278 self.steer(msg);
279 }
280
281 tracing::info!(
282 messages = self.state.messages.len(),
283 "resuming agent loop from checkpoint"
284 );
285
286 Ok(())
287 }
288}
289
290#[cfg(all(test, feature = "testkit"))]
291mod tests {
292 use std::collections::HashMap;
293 use std::sync::Arc;
294 use std::sync::Mutex;
295
296 use tokio_util::sync::CancellationToken;
297
298 use crate::agent::Agent;
299 use crate::agent_options::AgentOptions;
300 use crate::checkpoint::{CheckpointFuture, CheckpointStore, LoopCheckpoint};
301 use crate::testing::SimpleMockStreamFn;
302 use crate::types::{
303 AgentMessage, CustomMessage, CustomMessageRegistry, LlmMessage, ModelSpec, UserMessage,
304 };
305 use crate::{AgentError, Checkpoint};
306
307 #[derive(Debug, Clone, PartialEq)]
308 struct Tagged {
309 value: String,
310 }
311
312 impl CustomMessage for Tagged {
313 fn as_any(&self) -> &dyn std::any::Any {
314 self
315 }
316 fn type_name(&self) -> Option<&str> {
317 Some("Tagged")
318 }
319 fn to_json(&self) -> Option<serde_json::Value> {
320 Some(serde_json::json!({ "value": self.value }))
321 }
322 fn clone_box(&self) -> Option<Box<dyn CustomMessage>> {
323 Some(Box::new(self.clone()))
324 }
325 }
326
327 fn tagged_registry() -> CustomMessageRegistry {
328 let mut reg = CustomMessageRegistry::new();
329 reg.register(
330 "Tagged",
331 Box::new(|val: serde_json::Value| {
332 let value = val
333 .get("value")
334 .and_then(|v| v.as_str())
335 .ok_or_else(|| "missing value".to_string())?;
336 Ok(Box::new(Tagged {
337 value: value.to_string(),
338 }) as Box<dyn CustomMessage>)
339 }),
340 );
341 reg
342 }
343
344 fn make_agent(registry: Option<CustomMessageRegistry>) -> Agent {
345 let stream_fn = Arc::new(SimpleMockStreamFn::from_text("ok"));
346 let mut opts =
347 AgentOptions::new_simple("system", ModelSpec::new("mock", "mock-model"), stream_fn);
348 if let Some(reg) = registry {
349 opts = opts.with_custom_message_registry(reg);
350 }
351 Agent::new(opts)
352 }
353
354 fn user_msg(text: &str) -> AgentMessage {
355 AgentMessage::Llm(LlmMessage::User(UserMessage {
356 content: vec![crate::types::ContentBlock::Text {
357 text: text.to_string(),
358 }],
359 timestamp: 0,
360 cache_hint: None,
361 }))
362 }
363
364 fn seed_transient_runtime_state(agent: &mut Agent) {
365 agent.state.is_running = true;
366 agent.state.stream_message = Some(user_msg("streaming"));
367 agent
368 .state
369 .pending_tool_calls
370 .insert("tool-call-1".to_string());
371 agent.state.error = Some("stale error".to_string());
372 agent.abort_controller = Some(CancellationToken::new());
373 agent.in_flight_llm_messages = Some(vec![user_msg("in-flight-llm")]);
374 agent.in_flight_messages = Some(vec![user_msg("in-flight-checkpoint")]);
375 }
376
377 #[derive(Default)]
378 struct TestCheckpointStore {
379 data: Mutex<HashMap<String, String>>,
380 }
381
382 impl CheckpointStore for TestCheckpointStore {
383 fn save_checkpoint(&self, checkpoint: Checkpoint) -> CheckpointFuture<'_, ()> {
384 let json = serde_json::to_string(&checkpoint).unwrap();
385 let id = checkpoint.id;
386 Box::pin(async move {
387 self.data
388 .lock()
389 .unwrap_or_else(std::sync::PoisonError::into_inner)
390 .insert(id, json);
391 Ok(())
392 })
393 }
394
395 fn load_checkpoint(&self, id: &str) -> CheckpointFuture<'_, Option<Checkpoint>> {
396 let id = id.to_string();
397 Box::pin(async move {
398 self.data
399 .lock()
400 .unwrap_or_else(std::sync::PoisonError::into_inner)
401 .get(&id)
402 .map(|json| serde_json::from_str(json).map_err(std::io::Error::other))
403 .transpose()
404 })
405 }
406
407 fn list_checkpoints(&self) -> CheckpointFuture<'_, Vec<String>> {
408 Box::pin(async move {
409 Ok(self
410 .data
411 .lock()
412 .unwrap_or_else(std::sync::PoisonError::into_inner)
413 .keys()
414 .cloned()
415 .collect())
416 })
417 }
418
419 fn delete_checkpoint(&self, id: &str) -> CheckpointFuture<'_, ()> {
420 let id = id.to_string();
421 Box::pin(async move {
422 self.data
423 .lock()
424 .unwrap_or_else(std::sync::PoisonError::into_inner)
425 .remove(&id);
426 Ok(())
427 })
428 }
429 }
430
431 #[tokio::test]
432 async fn restore_from_checkpoint_rehydrates_custom_messages_via_registry() {
433 let mut source = make_agent(None);
434 source
435 .state
436 .messages
437 .push(AgentMessage::Llm(LlmMessage::User(UserMessage {
438 content: vec![crate::types::ContentBlock::Text {
439 text: "hi".to_string(),
440 }],
441 timestamp: 0,
442 cache_hint: None,
443 })));
444 source
445 .state
446 .messages
447 .push(AgentMessage::Custom(Box::new(Tagged {
448 value: "preserved".to_string(),
449 })));
450
451 let checkpoint = source.save_checkpoint("cp-1").await.unwrap();
452 let json = serde_json::to_string(&checkpoint).unwrap();
453 let loaded: crate::checkpoint::Checkpoint = serde_json::from_str(&json).unwrap();
454 assert_eq!(loaded.custom_messages.len(), 1);
455
456 let mut no_reg = make_agent(None);
458 no_reg.restore_from_checkpoint(&loaded).unwrap();
459 assert_eq!(no_reg.state.messages.len(), 1);
460
461 let mut with_reg = make_agent(Some(tagged_registry()));
464 with_reg.restore_from_checkpoint(&loaded).unwrap();
465 assert_eq!(with_reg.state.messages.len(), 2);
466 let restored = with_reg.state.messages[1]
467 .downcast_ref::<Tagged>()
468 .expect("custom message should be restored via registry");
469 assert_eq!(restored.value, "preserved");
470 }
471
472 #[tokio::test]
473 async fn pause_captures_both_steering_and_follow_up_queues() {
474 use crate::types::ContentBlock;
475
476 let mut agent = make_agent(None);
477 agent
479 .state
480 .messages
481 .push(AgentMessage::Llm(LlmMessage::User(UserMessage {
482 content: vec![ContentBlock::Text {
483 text: "hi".to_string(),
484 }],
485 timestamp: 0,
486 cache_hint: None,
487 })));
488
489 agent.steer(AgentMessage::Llm(LlmMessage::User(UserMessage {
491 content: vec![ContentBlock::Text {
492 text: "steering-msg".to_string(),
493 }],
494 timestamp: 1,
495 cache_hint: None,
496 })));
497 agent.follow_up(AgentMessage::Llm(LlmMessage::User(UserMessage {
498 content: vec![ContentBlock::Text {
499 text: "followup-msg".to_string(),
500 }],
501 timestamp: 2,
502 cache_hint: None,
503 })));
504
505 agent
507 .loop_active
508 .store(true, std::sync::atomic::Ordering::Release);
509
510 let checkpoint = agent.pause().expect("agent should be running");
511
512 assert_eq!(
514 checkpoint.pending_messages.len(),
515 1,
516 "follow-up queue should be captured"
517 );
518 assert_eq!(
519 checkpoint.pending_steering_messages.len(),
520 1,
521 "steering queue should be captured"
522 );
523
524 match &checkpoint.pending_messages[0] {
526 LlmMessage::User(u) => match &u.content[0] {
527 ContentBlock::Text { text } => assert_eq!(text, "followup-msg"),
528 _ => panic!("expected text content"),
529 },
530 _ => panic!("expected user message"),
531 }
532 match &checkpoint.pending_steering_messages[0] {
533 LlmMessage::User(u) => match &u.content[0] {
534 ContentBlock::Text { text } => assert_eq!(text, "steering-msg"),
535 _ => panic!("expected text content"),
536 },
537 _ => panic!("expected user message"),
538 }
539
540 assert!(
542 !agent.has_pending_messages(),
543 "queues should be empty after pause drains them"
544 );
545 }
546
547 #[tokio::test]
548 async fn restore_from_loop_checkpoint_routes_steering_to_steering_queue() {
549 use crate::checkpoint::LoopCheckpoint;
550 use crate::types::ContentBlock;
551
552 let messages = vec![AgentMessage::Llm(LlmMessage::User(UserMessage {
553 content: vec![ContentBlock::Text {
554 text: "hi".to_string(),
555 }],
556 timestamp: 0,
557 cache_hint: None,
558 }))];
559
560 let cp = LoopCheckpoint::new("system", "mock", "mock-model", &messages)
561 .with_pending_messages(vec![LlmMessage::User(UserMessage {
562 content: vec![ContentBlock::Text {
563 text: "followup".to_string(),
564 }],
565 timestamp: 1,
566 cache_hint: None,
567 })])
568 .with_pending_steering_messages(vec![LlmMessage::User(UserMessage {
569 content: vec![ContentBlock::Text {
570 text: "steering".to_string(),
571 }],
572 timestamp: 2,
573 cache_hint: None,
574 })]);
575
576 let mut agent = make_agent(None);
577 agent.restore_from_loop_checkpoint(&cp).unwrap();
578
579 let steering = agent.steering_queue.lock().unwrap();
581 let follow_up = agent.follow_up_queue.lock().unwrap();
582
583 assert_eq!(steering.len(), 1, "steering queue should have 1 message");
584 assert_eq!(follow_up.len(), 1, "follow-up queue should have 1 message");
585
586 match &steering[0] {
587 AgentMessage::Llm(LlmMessage::User(u)) => match &u.content[0] {
588 ContentBlock::Text { text } => assert_eq!(text, "steering"),
589 _ => panic!("expected text"),
590 },
591 _ => panic!("expected user message in steering queue"),
592 }
593 match &follow_up[0] {
594 AgentMessage::Llm(LlmMessage::User(u)) => match &u.content[0] {
595 ContentBlock::Text { text } => assert_eq!(text, "followup"),
596 _ => panic!("expected text"),
597 },
598 _ => panic!("expected user message in follow-up queue"),
599 }
600 }
601
602 #[tokio::test]
607 async fn pause_drains_queues_so_resume_does_not_duplicate() {
608 use crate::types::ContentBlock;
609
610 let mut agent = make_agent(None);
611 agent
612 .state
613 .messages
614 .push(AgentMessage::Llm(LlmMessage::User(UserMessage {
615 content: vec![ContentBlock::Text {
616 text: "hi".to_string(),
617 }],
618 timestamp: 0,
619 cache_hint: None,
620 })));
621
622 agent.steer(AgentMessage::Llm(LlmMessage::User(UserMessage {
624 content: vec![ContentBlock::Text {
625 text: "steering-1".to_string(),
626 }],
627 timestamp: 1,
628 cache_hint: None,
629 })));
630 agent.follow_up(AgentMessage::Llm(LlmMessage::User(UserMessage {
631 content: vec![ContentBlock::Text {
632 text: "followup-1".to_string(),
633 }],
634 timestamp: 2,
635 cache_hint: None,
636 })));
637
638 agent
640 .loop_active
641 .store(true, std::sync::atomic::Ordering::Release);
642
643 let checkpoint = agent.pause().expect("agent should be running");
644
645 assert!(
647 !agent.has_pending_messages(),
648 "queues should be drained after pause"
649 );
650
651 agent
653 .loop_active
654 .store(false, std::sync::atomic::Ordering::Release);
655 agent.restore_from_loop_checkpoint(&checkpoint).unwrap();
656
657 let steering = agent.steering_queue.lock().unwrap();
658 let follow_up = agent.follow_up_queue.lock().unwrap();
659
660 assert_eq!(
661 steering.len(),
662 1,
663 "steering queue should have exactly 1 message, not duplicated"
664 );
665 assert_eq!(
666 follow_up.len(),
667 1,
668 "follow-up queue should have exactly 1 message, not duplicated"
669 );
670 }
671
672 #[tokio::test]
673 async fn pause_and_resume_preserves_serializable_custom_pending_messages() {
674 use crate::types::ContentBlock;
675
676 let mut agent = make_agent(Some(tagged_registry()));
677 agent
678 .state
679 .messages
680 .push(AgentMessage::Llm(LlmMessage::User(UserMessage {
681 content: vec![ContentBlock::Text {
682 text: "hi".to_string(),
683 }],
684 timestamp: 0,
685 cache_hint: None,
686 })));
687
688 agent.follow_up(AgentMessage::Llm(LlmMessage::User(UserMessage {
689 content: vec![ContentBlock::Text {
690 text: "followup-1".to_string(),
691 }],
692 timestamp: 1,
693 cache_hint: None,
694 })));
695 agent.follow_up(AgentMessage::Custom(Box::new(Tagged {
696 value: "followup-custom".to_string(),
697 })));
698 agent.steer(AgentMessage::Custom(Box::new(Tagged {
699 value: "steering-custom".to_string(),
700 })));
701 agent.steer(AgentMessage::Llm(LlmMessage::User(UserMessage {
702 content: vec![ContentBlock::Text {
703 text: "steering-1".to_string(),
704 }],
705 timestamp: 2,
706 cache_hint: None,
707 })));
708
709 agent
710 .loop_active
711 .store(true, std::sync::atomic::Ordering::Release);
712
713 let checkpoint = agent.pause().expect("agent should be running");
714 assert!(
715 !agent.has_pending_messages(),
716 "queues should be drained after pause"
717 );
718
719 let json = serde_json::to_string(&checkpoint).unwrap();
720 let loaded: LoopCheckpoint = serde_json::from_str(&json).unwrap();
721
722 agent
723 .loop_active
724 .store(false, std::sync::atomic::Ordering::Release);
725 agent.restore_from_loop_checkpoint(&loaded).unwrap();
726
727 let steering = agent.steering_queue.lock().unwrap();
728 let follow_up = agent.follow_up_queue.lock().unwrap();
729
730 assert_eq!(
731 follow_up.len(),
732 2,
733 "follow-up queue should keep mixed messages"
734 );
735 assert_eq!(
736 steering.len(),
737 2,
738 "steering queue should keep mixed messages"
739 );
740
741 match &follow_up[0] {
742 AgentMessage::Llm(LlmMessage::User(u)) => match &u.content[0] {
743 ContentBlock::Text { text } => assert_eq!(text, "followup-1"),
744 _ => panic!("expected text content"),
745 },
746 _ => panic!("expected llm follow-up message"),
747 }
748 let follow_up_custom = follow_up[1]
749 .downcast_ref::<Tagged>()
750 .expect("custom follow-up should be restored");
751 assert_eq!(follow_up_custom.value, "followup-custom");
752
753 let steering_custom = steering[0]
754 .downcast_ref::<Tagged>()
755 .expect("custom steering should be restored");
756 assert_eq!(steering_custom.value, "steering-custom");
757 match &steering[1] {
758 AgentMessage::Llm(LlmMessage::User(u)) => match &u.content[0] {
759 ContentBlock::Text { text } => assert_eq!(text, "steering-1"),
760 _ => panic!("expected text content"),
761 },
762 _ => panic!("expected llm steering message"),
763 }
764 }
765
766 #[tokio::test]
767 async fn pause_captures_messages_already_moved_into_loop_local_pending_state() {
768 let mut agent = make_agent(Some(tagged_registry()));
769 agent.state.messages.push(user_msg("hi"));
770 agent.pending_message_snapshot.replace(&[
771 AgentMessage::Llm(LlmMessage::User(UserMessage {
772 content: vec![crate::types::ContentBlock::Text {
773 text: "polled-follow-up".to_string(),
774 }],
775 timestamp: 1,
776 cache_hint: None,
777 })),
778 AgentMessage::Custom(Box::new(Tagged {
779 value: "polled-custom".to_string(),
780 })),
781 ]);
782
783 agent
784 .loop_active
785 .store(true, std::sync::atomic::Ordering::Release);
786
787 let checkpoint = agent.pause().expect("agent should be running");
788 let pending = checkpoint.restore_pending_messages(agent.custom_message_registry.as_deref());
789
790 assert_eq!(
791 pending.len(),
792 2,
793 "pause should include loop-local pending messages even when the shared queue is already empty"
794 );
795 match &pending[0] {
796 AgentMessage::Llm(LlmMessage::User(user)) => match &user.content[0] {
797 crate::types::ContentBlock::Text { text } => {
798 assert_eq!(text, "polled-follow-up");
799 }
800 other => panic!("expected text content, got {other:?}"),
801 },
802 other => panic!("expected user message, got {other:?}"),
803 }
804 let restored_custom = pending[1]
805 .downcast_ref::<Tagged>()
806 .expect("custom pending message should be preserved");
807 assert_eq!(restored_custom.value, "polled-custom");
808 }
809
810 #[tokio::test]
811 async fn pause_preserves_in_flight_custom_messages_during_streamed_runs() {
812 use futures::future::pending;
813
814 struct PendingStreamFn;
815
816 impl crate::stream::StreamFn for PendingStreamFn {
817 fn stream<'a>(
818 &'a self,
819 _model: &'a crate::ModelSpec,
820 _context: &'a crate::AgentContext,
821 _options: &'a crate::StreamOptions,
822 _cancellation_token: tokio_util::sync::CancellationToken,
823 ) -> std::pin::Pin<
824 Box<dyn futures::Stream<Item = crate::AssistantMessageEvent> + Send + 'a>,
825 > {
826 Box::pin(futures::stream::once(async {
827 pending::<()>().await;
828 crate::AssistantMessageEvent::error("unreachable")
829 }))
830 }
831 }
832
833 let stream_fn = Arc::new(PendingStreamFn);
834 let opts =
835 AgentOptions::new_simple("system", ModelSpec::new("mock", "mock-model"), stream_fn)
836 .with_custom_message_registry(tagged_registry());
837 let mut agent = Agent::new(opts);
838 agent
839 .state
840 .messages
841 .push(AgentMessage::Custom(Box::new(Tagged {
842 value: "history-custom".to_string(),
843 })));
844
845 let _stream = agent.prompt_stream(vec![user_msg("start")]).unwrap();
846 let checkpoint = agent.pause().expect("agent should be running");
847 let restored = checkpoint.restore_messages(agent.custom_message_registry.as_deref());
848
849 assert_eq!(
850 restored.len(),
851 2,
852 "pause should keep custom history in checkpoint"
853 );
854
855 let restored_custom = restored[0]
856 .downcast_ref::<Tagged>()
857 .expect("custom history should be restored from the paused checkpoint");
858 assert_eq!(restored_custom.value, "history-custom");
859
860 match &restored[1] {
861 AgentMessage::Llm(LlmMessage::User(user)) => match &user.content[0] {
862 crate::types::ContentBlock::Text { text } => assert_eq!(text, "start"),
863 other => panic!("expected text content, got {other:?}"),
864 },
865 other => panic!("expected user message, got {other:?}"),
866 }
867 }
868
869 #[tokio::test]
870 async fn restore_from_checkpoint_rebinds_stream_fn_for_matching_model() {
871 use crate::stream::StreamFn;
872 use crate::types::ContentBlock;
873
874 let model_a = ModelSpec::new("provider-a", "model-a");
875 let model_b = ModelSpec::new("provider-b", "model-b");
876 let stream_a = Arc::new(SimpleMockStreamFn::from_text("from-a"));
877 let stream_b = Arc::new(SimpleMockStreamFn::from_text("from-b"));
878
879 let opts = AgentOptions::new_simple("system", model_a.clone(), stream_a.clone())
881 .with_available_models(vec![(model_b.clone(), stream_b.clone())]);
882 let mut agent = Agent::new(opts);
883
884 assert!(
886 Arc::ptr_eq(&agent.stream_fn, &(stream_a.clone() as Arc<dyn StreamFn>)),
887 "initial stream_fn should be stream_a"
888 );
889
890 let source_opts = AgentOptions::new_simple("system", model_b.clone(), stream_b.clone());
892 let mut source = Agent::new(source_opts);
893 source
894 .state
895 .messages
896 .push(AgentMessage::Llm(LlmMessage::User(UserMessage {
897 content: vec![ContentBlock::Text {
898 text: "hello".to_string(),
899 }],
900 timestamp: 0,
901 cache_hint: None,
902 })));
903 let checkpoint = source.save_checkpoint("cp-rebind").await.unwrap();
904
905 agent.restore_from_checkpoint(&checkpoint).unwrap();
907
908 assert_eq!(agent.state.model.provider, "provider-b");
910 assert_eq!(agent.state.model.model_id, "model-b");
911
912 assert!(
914 Arc::ptr_eq(&agent.stream_fn, &(stream_b.clone() as Arc<dyn StreamFn>)),
915 "stream_fn should be rebound to stream_b after checkpoint restore"
916 );
917 }
918
919 #[tokio::test]
920 async fn restore_from_checkpoint_clears_transient_runtime_state() {
921 let mut source = make_agent(None);
922 source.state.messages.push(user_msg("restored"));
923 let checkpoint = source.save_checkpoint("cp-clear-runtime").await.unwrap();
924
925 let mut agent = make_agent(None);
926 seed_transient_runtime_state(&mut agent);
927
928 agent.restore_from_checkpoint(&checkpoint).unwrap();
929
930 assert!(!agent.state.is_running);
931 assert!(agent.state.stream_message.is_none());
932 assert!(agent.state.pending_tool_calls.is_empty());
933 assert!(agent.state.error.is_none());
934 assert!(agent.abort_controller.is_none());
935 assert!(agent.in_flight_llm_messages.is_none());
936 assert!(agent.in_flight_messages.is_none());
937 }
938
939 #[tokio::test]
940 async fn restore_from_checkpoint_rejects_restore_while_running() {
941 let mut source = make_agent(None);
942 source.state.messages.push(user_msg("restored"));
943 let checkpoint = source.save_checkpoint("cp-running-guard").await.unwrap();
944
945 let mut agent = make_agent(None);
946 let stream = agent.prompt_stream(vec![user_msg("hi")]).unwrap();
947
948 let err = agent.restore_from_checkpoint(&checkpoint).unwrap_err();
949 assert_eq!(err.kind(), std::io::ErrorKind::WouldBlock);
950 assert!(
951 err.to_string()
952 .contains("cannot restore checkpoint while agent is running")
953 );
954 assert!(agent.is_running());
955
956 drop(stream);
957 agent.wait_for_idle().await;
958 }
959
960 #[tokio::test]
961 async fn restore_from_loop_checkpoint_rebinds_stream_fn_for_matching_model() {
962 use crate::checkpoint::LoopCheckpoint;
963 use crate::stream::StreamFn;
964 use crate::types::ContentBlock;
965
966 let model_a = ModelSpec::new("provider-a", "model-a");
967 let model_b = ModelSpec::new("provider-b", "model-b");
968 let stream_a = Arc::new(SimpleMockStreamFn::from_text("from-a"));
969 let stream_b = Arc::new(SimpleMockStreamFn::from_text("from-b"));
970
971 let opts = AgentOptions::new_simple("system", model_a.clone(), stream_a.clone())
972 .with_available_models(vec![(model_b.clone(), stream_b.clone())]);
973 let mut agent = Agent::new(opts);
974
975 assert!(
976 Arc::ptr_eq(&agent.stream_fn, &(stream_a.clone() as Arc<dyn StreamFn>)),
977 "initial stream_fn should be stream_a"
978 );
979
980 let messages = vec![AgentMessage::Llm(LlmMessage::User(UserMessage {
982 content: vec![ContentBlock::Text {
983 text: "hello".to_string(),
984 }],
985 timestamp: 0,
986 cache_hint: None,
987 }))];
988 let cp = LoopCheckpoint::new("system", "provider-b", "model-b", &messages);
989
990 agent.restore_from_loop_checkpoint(&cp).unwrap();
991
992 assert_eq!(agent.state.model.provider, "provider-b");
993 assert_eq!(agent.state.model.model_id, "model-b");
994 assert!(
995 Arc::ptr_eq(&agent.stream_fn, &(stream_b.clone() as Arc<dyn StreamFn>)),
996 "stream_fn should be rebound to stream_b after loop checkpoint restore"
997 );
998 }
999
1000 #[tokio::test]
1001 async fn restore_from_loop_checkpoint_clears_transient_runtime_state() {
1002 let checkpoint = LoopCheckpoint::new("system", "mock", "mock-model", &[user_msg("hi")]);
1003 let mut agent = make_agent(None);
1004 seed_transient_runtime_state(&mut agent);
1005
1006 agent.restore_from_loop_checkpoint(&checkpoint).unwrap();
1007
1008 assert!(!agent.state.is_running);
1009 assert!(agent.state.stream_message.is_none());
1010 assert!(agent.state.pending_tool_calls.is_empty());
1011 assert!(agent.state.error.is_none());
1012 assert!(agent.abort_controller.is_none());
1013 assert!(agent.in_flight_llm_messages.is_none());
1014 assert!(agent.in_flight_messages.is_none());
1015 }
1016
1017 #[tokio::test]
1018 async fn loop_checkpoint_resume_rehydrates_custom_messages_via_registry() {
1019 let messages = vec![
1020 AgentMessage::Llm(LlmMessage::User(UserMessage {
1021 content: vec![crate::types::ContentBlock::Text {
1022 text: "hi".to_string(),
1023 }],
1024 timestamp: 0,
1025 cache_hint: None,
1026 })),
1027 AgentMessage::Custom(Box::new(Tagged {
1028 value: "resumed".to_string(),
1029 })),
1030 ];
1031 let cp = LoopCheckpoint::new("system", "mock", "mock-model", &messages);
1032 let json = serde_json::to_string(&cp).unwrap();
1033 let loaded: LoopCheckpoint = serde_json::from_str(&json).unwrap();
1034
1035 let mut agent = make_agent(Some(tagged_registry()));
1036 agent.restore_from_loop_checkpoint(&loaded).unwrap();
1037 assert_eq!(agent.state.messages.len(), 2);
1038 let restored = agent.state.messages[1]
1039 .downcast_ref::<Tagged>()
1040 .expect("custom message should be restored via registry");
1041 assert_eq!(restored.value, "resumed");
1042 }
1043
1044 #[tokio::test]
1045 async fn load_and_restore_checkpoint_rejects_corrupt_state_snapshot() {
1046 let store = TestCheckpointStore::default();
1047 let checkpoint = Checkpoint::new(
1048 "bad-state",
1049 "system",
1050 "mock",
1051 "mock-model",
1052 &[user_msg("hi")],
1053 )
1054 .with_state(serde_json::json!(["bad"]));
1055 store.save_checkpoint(checkpoint).await.unwrap();
1056
1057 let stream_fn = Arc::new(SimpleMockStreamFn::from_text("ok"));
1058 let agent_options =
1059 AgentOptions::new_simple("system", ModelSpec::new("mock", "mock-model"), stream_fn)
1060 .with_checkpoint_store(store);
1061 let mut agent = Agent::new(agent_options);
1062
1063 let err = agent
1064 .load_and_restore_checkpoint("bad-state")
1065 .await
1066 .unwrap_err();
1067 assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
1068 assert!(err.to_string().contains("corrupted session state snapshot"));
1069 }
1070
1071 #[tokio::test]
1072 async fn load_and_restore_checkpoint_rejects_restore_while_running() {
1073 let store = TestCheckpointStore::default();
1074 let checkpoint = Checkpoint::new(
1075 "running-guard",
1076 "system",
1077 "mock",
1078 "mock-model",
1079 &[user_msg("hi")],
1080 );
1081 store.save_checkpoint(checkpoint).await.unwrap();
1082
1083 let stream_fn = Arc::new(SimpleMockStreamFn::from_text("ok"));
1084 let agent_options =
1085 AgentOptions::new_simple("system", ModelSpec::new("mock", "mock-model"), stream_fn)
1086 .with_checkpoint_store(store);
1087 let mut agent = Agent::new(agent_options);
1088 let stream = agent.prompt_stream(vec![user_msg("start")]).unwrap();
1089
1090 let err = agent
1091 .load_and_restore_checkpoint("running-guard")
1092 .await
1093 .unwrap_err();
1094 assert_eq!(err.kind(), std::io::ErrorKind::WouldBlock);
1095 assert!(
1096 err.to_string()
1097 .contains("cannot restore checkpoint while agent is running")
1098 );
1099 assert!(agent.is_running());
1100
1101 drop(stream);
1102 agent.wait_for_idle().await;
1103 }
1104
1105 #[tokio::test]
1106 async fn resume_rejects_corrupt_loop_checkpoint_state_snapshot() {
1107 let checkpoint = LoopCheckpoint::new("system", "mock", "mock-model", &[user_msg("hi")])
1108 .with_state(serde_json::json!(["bad"]));
1109 let mut agent = make_agent(None);
1110
1111 let err = agent.resume(&checkpoint).await.unwrap_err();
1112 match err {
1113 AgentError::StreamError { source } => {
1114 let io = source
1115 .downcast_ref::<std::io::Error>()
1116 .expect("expected io::Error source");
1117 assert_eq!(io.kind(), std::io::ErrorKind::InvalidData);
1118 assert!(io.to_string().contains("corrupted session state snapshot"));
1119 }
1120 other => panic!("expected StreamError, got {other:?}"),
1121 }
1122 }
1123
1124 #[tokio::test]
1125 async fn restore_from_checkpoint_keeps_live_state_when_snapshot_is_corrupt() {
1126 let checkpoint = Checkpoint::new(
1127 "bad-state",
1128 "restored-system",
1129 "restored",
1130 "restored-model",
1131 &[user_msg("restored")],
1132 )
1133 .with_state(serde_json::json!(["bad"]));
1134 let mut agent = make_agent(None);
1135 agent.state.messages.push(user_msg("existing"));
1136 agent.state.system_prompt = "live-system".to_string();
1137 agent.state.model = ModelSpec::new("live-provider", "live-model");
1138 {
1139 let mut state = agent
1140 .session_state()
1141 .write()
1142 .unwrap_or_else(std::sync::PoisonError::into_inner);
1143 state.set("live", 7_i64).unwrap();
1144 }
1145
1146 let err = agent.restore_from_checkpoint(&checkpoint).unwrap_err();
1147 assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
1148
1149 assert_eq!(agent.state.messages.len(), 1);
1150 match &agent.state.messages[0] {
1151 AgentMessage::Llm(LlmMessage::User(user)) => match &user.content[0] {
1152 crate::types::ContentBlock::Text { text } => assert_eq!(text, "existing"),
1153 other => panic!("expected text content, got {other:?}"),
1154 },
1155 other => panic!("expected user message, got {other:?}"),
1156 }
1157 assert_eq!(agent.state.system_prompt, "live-system");
1158 assert_eq!(agent.state.model.provider, "live-provider");
1159 assert_eq!(agent.state.model.model_id, "live-model");
1160
1161 let state = agent
1162 .session_state()
1163 .read()
1164 .unwrap_or_else(std::sync::PoisonError::into_inner);
1165 assert_eq!(state.get::<i64>("live"), Some(7));
1166 }
1167
1168 #[tokio::test]
1169 async fn restore_from_loop_checkpoint_keeps_live_state_when_snapshot_is_corrupt() {
1170 let checkpoint = LoopCheckpoint::new(
1171 "restored-system",
1172 "restored",
1173 "restored-model",
1174 &[user_msg("restored")],
1175 )
1176 .with_state(serde_json::json!(["bad"]));
1177 let mut agent = make_agent(None);
1178 agent.state.messages.push(user_msg("existing"));
1179 agent.state.system_prompt = "live-system".to_string();
1180 agent.state.model = ModelSpec::new("live-provider", "live-model");
1181 agent.follow_up(user_msg("live-follow-up"));
1182 agent.steer(user_msg("live-steering"));
1183 {
1184 let mut state = agent
1185 .session_state()
1186 .write()
1187 .unwrap_or_else(std::sync::PoisonError::into_inner);
1188 state.set("live", 9_i64).unwrap();
1189 }
1190
1191 let err = agent.resume(&checkpoint).await.unwrap_err();
1192 match err {
1193 AgentError::StreamError { source } => {
1194 let io = source
1195 .downcast_ref::<std::io::Error>()
1196 .expect("expected io::Error source");
1197 assert_eq!(io.kind(), std::io::ErrorKind::InvalidData);
1198 }
1199 other => panic!("expected StreamError, got {other:?}"),
1200 }
1201
1202 assert_eq!(agent.state.messages.len(), 1);
1203 match &agent.state.messages[0] {
1204 AgentMessage::Llm(LlmMessage::User(user)) => match &user.content[0] {
1205 crate::types::ContentBlock::Text { text } => assert_eq!(text, "existing"),
1206 other => panic!("expected text content, got {other:?}"),
1207 },
1208 other => panic!("expected user message, got {other:?}"),
1209 }
1210 assert_eq!(agent.state.system_prompt, "live-system");
1211 assert_eq!(agent.state.model.provider, "live-provider");
1212 assert_eq!(agent.state.model.model_id, "live-model");
1213
1214 let state = agent
1215 .session_state()
1216 .read()
1217 .unwrap_or_else(std::sync::PoisonError::into_inner);
1218 assert_eq!(state.get::<i64>("live"), Some(9));
1219 drop(state);
1220
1221 let follow_up = agent
1222 .follow_up_queue
1223 .lock()
1224 .unwrap_or_else(std::sync::PoisonError::into_inner);
1225 let steering = agent
1226 .steering_queue
1227 .lock()
1228 .unwrap_or_else(std::sync::PoisonError::into_inner);
1229 assert_eq!(
1230 follow_up.len(),
1231 1,
1232 "failed restore should not clear follow-up queue"
1233 );
1234 assert_eq!(
1235 steering.len(),
1236 1,
1237 "failed restore should not clear steering queue"
1238 );
1239 }
1240
1241 #[tokio::test]
1242 async fn restore_from_checkpoint_clears_session_state_when_snapshot_missing() {
1243 let mut source = make_agent(None);
1244 source
1245 .state
1246 .messages
1247 .push(AgentMessage::Llm(LlmMessage::User(UserMessage {
1248 content: vec![crate::types::ContentBlock::Text {
1249 text: "hi".to_string(),
1250 }],
1251 timestamp: 0,
1252 cache_hint: None,
1253 })));
1254
1255 let mut checkpoint = source.save_checkpoint("cp-empty-state").await.unwrap();
1256 checkpoint.state = None;
1257
1258 let mut agent = make_agent(None);
1259 {
1260 let mut state = agent
1261 .session_state()
1262 .write()
1263 .unwrap_or_else(std::sync::PoisonError::into_inner);
1264 state.set("stale", 42_i64).unwrap();
1265 }
1266
1267 agent.restore_from_checkpoint(&checkpoint).unwrap();
1268
1269 let state = agent
1270 .session_state()
1271 .read()
1272 .unwrap_or_else(std::sync::PoisonError::into_inner);
1273 assert!(
1274 state.is_empty(),
1275 "missing snapshot should clear stale state"
1276 );
1277 }
1278
1279 #[tokio::test]
1280 async fn restore_from_loop_checkpoint_clears_session_state_when_snapshot_missing() {
1281 use crate::checkpoint::LoopCheckpoint;
1282
1283 let messages = vec![AgentMessage::Llm(LlmMessage::User(UserMessage {
1284 content: vec![crate::types::ContentBlock::Text {
1285 text: "hi".to_string(),
1286 }],
1287 timestamp: 0,
1288 cache_hint: None,
1289 }))];
1290 let mut checkpoint = LoopCheckpoint::new("system", "mock", "mock-model", &messages);
1291 checkpoint.state = None;
1292
1293 let mut agent = make_agent(None);
1294 {
1295 let mut state = agent
1296 .session_state()
1297 .write()
1298 .unwrap_or_else(std::sync::PoisonError::into_inner);
1299 state.set("stale", 99_i64).unwrap();
1300 }
1301
1302 agent.restore_from_loop_checkpoint(&checkpoint).unwrap();
1303
1304 let state = agent
1305 .session_state()
1306 .read()
1307 .unwrap_or_else(std::sync::PoisonError::into_inner);
1308 assert!(
1309 state.is_empty(),
1310 "missing snapshot should clear stale state"
1311 );
1312 }
1313
1314 #[tokio::test]
1321 async fn pause_captures_messages_drained_from_pending_into_loop_context() {
1322 let mut agent = make_agent(None);
1323 agent.in_flight_messages = Some(vec![user_msg("original")]);
1329 agent.pending_message_snapshot.clear();
1331 agent
1334 .loop_context_snapshot
1335 .replace(&[user_msg("original"), user_msg("consumed-pending")]);
1336
1337 agent
1338 .loop_active
1339 .store(true, std::sync::atomic::Ordering::Release);
1340
1341 let checkpoint = agent.pause().expect("agent should be paused");
1342 let restored = checkpoint.restore_messages(agent.custom_message_registry.as_deref());
1343
1344 assert_eq!(
1345 restored.len(),
1346 2,
1347 "pause snapshot must include messages already consumed from the pending queue \
1348 into loop context, not just in_flight_messages"
1349 );
1350 match &restored[0] {
1351 AgentMessage::Llm(LlmMessage::User(u)) => match &u.content[0] {
1352 crate::types::ContentBlock::Text { text } => {
1353 assert_eq!(text, "original");
1354 }
1355 other => panic!("expected text content, got {other:?}"),
1356 },
1357 other => panic!("expected user message, got {other:?}"),
1358 }
1359 match &restored[1] {
1360 AgentMessage::Llm(LlmMessage::User(u)) => match &u.content[0] {
1361 crate::types::ContentBlock::Text { text } => {
1362 assert_eq!(text, "consumed-pending");
1363 }
1364 other => panic!("expected text content, got {other:?}"),
1365 },
1366 other => panic!("expected user message, got {other:?}"),
1367 }
1368 }
1369
1370 #[tokio::test]
1373 async fn pause_falls_back_to_in_flight_messages_when_context_snapshot_absent() {
1374 let mut agent = make_agent(None);
1375
1376 agent.in_flight_messages = Some(vec![user_msg("in-flight")]);
1378 agent
1382 .loop_active
1383 .store(true, std::sync::atomic::Ordering::Release);
1384
1385 let checkpoint = agent.pause().expect("agent should be paused");
1386 let restored = checkpoint.restore_messages(agent.custom_message_registry.as_deref());
1387
1388 assert_eq!(
1389 restored.len(),
1390 1,
1391 "pause must fall back to in_flight_messages when loop_context_snapshot is absent"
1392 );
1393 match &restored[0] {
1394 AgentMessage::Llm(LlmMessage::User(u)) => match &u.content[0] {
1395 crate::types::ContentBlock::Text { text } => {
1396 assert_eq!(text, "in-flight");
1397 }
1398 other => panic!("expected text content, got {other:?}"),
1399 },
1400 other => panic!("expected user message, got {other:?}"),
1401 }
1402 }
1403}