1use chrono::{DateTime, Duration as ChronoDuration, Utc};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::path::PathBuf;
16use std::time::Duration;
17use thiserror::Error;
18
19use crate::conversation::message::Message;
20
21pub type StateManagerResult<T> = Result<T, StateManagerError>;
23
24#[derive(Debug, Error)]
26pub enum StateManagerError {
27 #[error("State not found: {0}")]
29 NotFound(String),
30
31 #[error("Checkpoint not found: {0}")]
33 CheckpointNotFound(String),
34
35 #[error("IO error: {0}")]
37 Io(#[from] std::io::Error),
38
39 #[error("Serialization error: {0}")]
41 Serialization(String),
42
43 #[error("Invalid state: {0}")]
45 InvalidState(String),
46}
47
48impl From<serde_json::Error> for StateManagerError {
49 fn from(err: serde_json::Error) -> Self {
50 StateManagerError::Serialization(err.to_string())
51 }
52}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
56#[serde(rename_all = "snake_case")]
57pub enum AgentStateStatus {
58 #[default]
60 Running,
61 Paused,
63 Completed,
65 Failed,
67 Cancelled,
69}
70
71impl AgentStateStatus {
72 pub fn is_resumable(&self) -> bool {
74 matches!(self, Self::Running | Self::Paused | Self::Failed)
75 }
76
77 pub fn is_terminal(&self) -> bool {
79 matches!(self, Self::Completed | Self::Cancelled)
80 }
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
85#[serde(rename_all = "camelCase")]
86pub struct ToolCallRecord {
87 pub id: String,
89 pub tool_name: String,
91 pub input: serde_json::Value,
93 pub output: Option<serde_json::Value>,
95 pub success: Option<bool>,
97 pub error: Option<String>,
99 pub timestamp: DateTime<Utc>,
101}
102
103impl ToolCallRecord {
104 pub fn new(tool_name: impl Into<String>, input: serde_json::Value) -> Self {
106 Self {
107 id: uuid::Uuid::new_v4().to_string(),
108 tool_name: tool_name.into(),
109 input,
110 output: None,
111 success: None,
112 error: None,
113 timestamp: Utc::now(),
114 }
115 }
116
117 pub fn complete_success(&mut self, output: serde_json::Value) {
119 self.output = Some(output);
120 self.success = Some(true);
121 }
122
123 pub fn complete_failure(&mut self, error: impl Into<String>) {
125 self.success = Some(false);
126 self.error = Some(error.into());
127 }
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
132#[serde(rename_all = "camelCase")]
133pub struct Checkpoint {
134 pub id: String,
136 pub agent_id: String,
138 pub name: Option<String>,
140 pub step: usize,
142 pub messages: Vec<Message>,
144 pub tool_calls: Vec<ToolCallRecord>,
146 pub results: Vec<serde_json::Value>,
148 pub metadata: HashMap<String, serde_json::Value>,
150 pub created_at: DateTime<Utc>,
152}
153
154impl Checkpoint {
155 pub fn new(agent_id: impl Into<String>, step: usize) -> Self {
157 Self {
158 id: uuid::Uuid::new_v4().to_string(),
159 agent_id: agent_id.into(),
160 name: None,
161 step,
162 messages: Vec::new(),
163 tool_calls: Vec::new(),
164 results: Vec::new(),
165 metadata: HashMap::new(),
166 created_at: Utc::now(),
167 }
168 }
169
170 pub fn with_name(mut self, name: impl Into<String>) -> Self {
172 self.name = Some(name.into());
173 self
174 }
175
176 pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
178 self.messages = messages;
179 self
180 }
181
182 pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCallRecord>) -> Self {
184 self.tool_calls = tool_calls;
185 self
186 }
187
188 pub fn with_results(mut self, results: Vec<serde_json::Value>) -> Self {
190 self.results = results;
191 self
192 }
193
194 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
196 self.metadata.insert(key.into(), value);
197 self
198 }
199}
200
201#[derive(Debug, Clone, Serialize, Deserialize)]
203#[serde(rename_all = "camelCase")]
204pub struct AgentState {
205 pub id: String,
207 pub agent_type: String,
209 pub status: AgentStateStatus,
211 pub created_at: DateTime<Utc>,
213 pub updated_at: DateTime<Utc>,
215 pub prompt: String,
217 pub messages: Vec<Message>,
219 pub tool_calls: Vec<ToolCallRecord>,
221 pub results: Vec<serde_json::Value>,
223 pub checkpoint: Option<Checkpoint>,
225 pub checkpoints: Vec<Checkpoint>,
227 pub current_step: usize,
229 pub total_steps: Option<usize>,
231 pub error_count: usize,
233 pub retry_count: usize,
235 pub max_retries: usize,
237 pub metadata: HashMap<String, serde_json::Value>,
239}
240
241impl AgentState {
242 pub fn new(
244 id: impl Into<String>,
245 agent_type: impl Into<String>,
246 prompt: impl Into<String>,
247 ) -> Self {
248 let now = Utc::now();
249 Self {
250 id: id.into(),
251 agent_type: agent_type.into(),
252 status: AgentStateStatus::Running,
253 created_at: now,
254 updated_at: now,
255 prompt: prompt.into(),
256 messages: Vec::new(),
257 tool_calls: Vec::new(),
258 results: Vec::new(),
259 checkpoint: None,
260 checkpoints: Vec::new(),
261 current_step: 0,
262 total_steps: None,
263 error_count: 0,
264 retry_count: 0,
265 max_retries: 3,
266 metadata: HashMap::new(),
267 }
268 }
269
270 pub fn with_status(mut self, status: AgentStateStatus) -> Self {
272 self.status = status;
273 self.updated_at = Utc::now();
274 self
275 }
276
277 pub fn with_max_retries(mut self, max_retries: usize) -> Self {
279 self.max_retries = max_retries;
280 self
281 }
282
283 pub fn with_total_steps(mut self, total: usize) -> Self {
285 self.total_steps = Some(total);
286 self
287 }
288
289 pub fn add_message(&mut self, message: Message) {
291 self.messages.push(message);
292 self.updated_at = Utc::now();
293 }
294
295 pub fn add_tool_call(&mut self, tool_call: ToolCallRecord) {
297 self.tool_calls.push(tool_call);
298 self.updated_at = Utc::now();
299 }
300
301 pub fn add_result(&mut self, result: serde_json::Value) {
303 self.results.push(result);
304 self.updated_at = Utc::now();
305 }
306
307 pub fn increment_step(&mut self) {
309 self.current_step += 1;
310 self.updated_at = Utc::now();
311 }
312
313 pub fn record_error(&mut self) {
315 self.error_count += 1;
316 self.updated_at = Utc::now();
317 }
318
319 pub fn record_retry(&mut self) {
321 self.retry_count += 1;
322 self.updated_at = Utc::now();
323 }
324
325 pub fn reset_errors(&mut self) {
327 self.error_count = 0;
328 self.retry_count = 0;
329 self.updated_at = Utc::now();
330 }
331
332 pub fn set_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
334 self.metadata.insert(key.into(), value);
335 self.updated_at = Utc::now();
336 }
337
338 pub fn create_checkpoint(&mut self, name: Option<&str>) -> Checkpoint {
340 let mut checkpoint = Checkpoint::new(&self.id, self.current_step)
341 .with_messages(self.messages.clone())
342 .with_tool_calls(self.tool_calls.clone())
343 .with_results(self.results.clone());
344
345 if let Some(n) = name {
346 checkpoint = checkpoint.with_name(n);
347 }
348
349 for (k, v) in &self.metadata {
350 checkpoint = checkpoint.with_metadata(k.clone(), v.clone());
351 }
352
353 self.checkpoint = Some(checkpoint.clone());
354 self.checkpoints.push(checkpoint.clone());
355 self.updated_at = Utc::now();
356
357 checkpoint
358 }
359
360 pub fn restore_from_checkpoint(&mut self, checkpoint: &Checkpoint) {
362 self.current_step = checkpoint.step;
363 self.messages = checkpoint.messages.clone();
364 self.tool_calls = checkpoint.tool_calls.clone();
365 self.results = checkpoint.results.clone();
366 self.metadata = checkpoint.metadata.clone();
367 self.checkpoint = Some(checkpoint.clone());
368 self.updated_at = Utc::now();
369 }
370
371 pub fn can_resume(&self) -> bool {
373 self.status.is_resumable()
374 }
375
376 pub fn latest_checkpoint(&self) -> Option<&Checkpoint> {
378 self.checkpoints.last()
379 }
380
381 pub fn age(&self) -> ChronoDuration {
383 Utc::now().signed_duration_since(self.created_at)
384 }
385
386 pub fn is_expired(&self, max_age: Duration) -> bool {
388 let age = self.age();
389 if let Ok(max_age_chrono) = ChronoDuration::from_std(max_age) {
390 age > max_age_chrono
391 } else {
392 false
393 }
394 }
395}
396
397impl PartialEq for AgentState {
398 fn eq(&self, other: &Self) -> bool {
399 self.id == other.id
400 }
401}
402
403impl Eq for AgentState {}
404
405#[derive(Debug, Clone, Default)]
407pub struct StateFilter {
408 pub agent_type: Option<String>,
410 pub status: Option<AgentStateStatus>,
412 pub created_after: Option<DateTime<Utc>>,
414 pub created_before: Option<DateTime<Utc>>,
416 pub has_checkpoints: Option<bool>,
418 pub limit: Option<usize>,
420}
421
422impl StateFilter {
423 pub fn new() -> Self {
425 Self::default()
426 }
427
428 pub fn with_agent_type(mut self, agent_type: impl Into<String>) -> Self {
430 self.agent_type = Some(agent_type.into());
431 self
432 }
433
434 pub fn with_status(mut self, status: AgentStateStatus) -> Self {
436 self.status = Some(status);
437 self
438 }
439
440 pub fn created_between(mut self, after: DateTime<Utc>, before: DateTime<Utc>) -> Self {
442 self.created_after = Some(after);
443 self.created_before = Some(before);
444 self
445 }
446
447 pub fn with_checkpoints(mut self, has: bool) -> Self {
449 self.has_checkpoints = Some(has);
450 self
451 }
452
453 pub fn with_limit(mut self, limit: usize) -> Self {
455 self.limit = Some(limit);
456 self
457 }
458
459 pub fn matches(&self, state: &AgentState) -> bool {
461 if let Some(ref agent_type) = self.agent_type {
462 if &state.agent_type != agent_type {
463 return false;
464 }
465 }
466
467 if let Some(status) = self.status {
468 if state.status != status {
469 return false;
470 }
471 }
472
473 if let Some(after) = self.created_after {
474 if state.created_at < after {
475 return false;
476 }
477 }
478
479 if let Some(before) = self.created_before {
480 if state.created_at > before {
481 return false;
482 }
483 }
484
485 if let Some(has_checkpoints) = self.has_checkpoints {
486 let has = !state.checkpoints.is_empty();
487 if has != has_checkpoints {
488 return false;
489 }
490 }
491
492 true
493 }
494}
495
496#[derive(Debug)]
498pub struct AgentStateManager {
499 storage_dir: PathBuf,
501}
502
503impl Default for AgentStateManager {
504 fn default() -> Self {
505 Self::new(None)
506 }
507}
508
509impl AgentStateManager {
510 pub fn new(storage_dir: Option<PathBuf>) -> Self {
512 let storage_dir = storage_dir.unwrap_or_else(|| PathBuf::from(".aster/states"));
513 Self { storage_dir }
514 }
515
516 pub fn storage_dir(&self) -> &PathBuf {
518 &self.storage_dir
519 }
520
521 pub fn set_storage_dir(&mut self, dir: PathBuf) {
523 self.storage_dir = dir;
524 }
525
526 fn state_file_path(&self, id: &str) -> PathBuf {
528 self.storage_dir.join(format!("{}.json", id))
529 }
530
531 fn checkpoints_dir(&self, agent_id: &str) -> PathBuf {
533 self.storage_dir.join("checkpoints").join(agent_id)
534 }
535
536 fn checkpoint_file_path(&self, agent_id: &str, checkpoint_id: &str) -> PathBuf {
538 self.checkpoints_dir(agent_id)
539 .join(format!("{}.json", checkpoint_id))
540 }
541
542 pub async fn save_state(&self, state: &AgentState) -> StateManagerResult<()> {
544 tokio::fs::create_dir_all(&self.storage_dir).await?;
546
547 let file_path = self.state_file_path(&state.id);
548 let json = serde_json::to_string_pretty(state)?;
549 tokio::fs::write(file_path, json).await?;
550
551 Ok(())
552 }
553
554 pub async fn load_state(&self, id: &str) -> StateManagerResult<Option<AgentState>> {
556 let file_path = self.state_file_path(id);
557
558 if !file_path.exists() {
559 return Ok(None);
560 }
561
562 let json = tokio::fs::read_to_string(&file_path).await?;
563 let state: AgentState = serde_json::from_str(&json)?;
564
565 Ok(Some(state))
566 }
567
568 pub async fn list_states(
570 &self,
571 filter: Option<StateFilter>,
572 ) -> StateManagerResult<Vec<AgentState>> {
573 if !self.storage_dir.exists() {
574 return Ok(Vec::new());
575 }
576
577 let mut states = Vec::new();
578 let mut entries = tokio::fs::read_dir(&self.storage_dir).await?;
579
580 while let Some(entry) = entries.next_entry().await? {
581 let path = entry.path();
582
583 if path.is_dir() || path.extension().is_none_or(|ext| ext != "json") {
585 continue;
586 }
587
588 if let Ok(json) = tokio::fs::read_to_string(&path).await {
590 if let Ok(state) = serde_json::from_str::<AgentState>(&json) {
591 if let Some(ref f) = filter {
593 if f.matches(&state) {
594 states.push(state);
595 }
596 } else {
597 states.push(state);
598 }
599 }
600 }
601 }
602
603 states.sort_by(|a, b| b.created_at.cmp(&a.created_at));
605
606 if let Some(ref f) = filter {
608 if let Some(limit) = f.limit {
609 states.truncate(limit);
610 }
611 }
612
613 Ok(states)
614 }
615
616 pub async fn delete_state(&self, id: &str) -> StateManagerResult<bool> {
618 let file_path = self.state_file_path(id);
619
620 if !file_path.exists() {
621 return Ok(false);
622 }
623
624 tokio::fs::remove_file(&file_path).await?;
625
626 let checkpoints_dir = self.checkpoints_dir(id);
628 if checkpoints_dir.exists() {
629 tokio::fs::remove_dir_all(&checkpoints_dir).await?;
630 }
631
632 Ok(true)
633 }
634
635 pub async fn cleanup_expired(&self, max_age: Duration) -> StateManagerResult<usize> {
637 if !self.storage_dir.exists() {
638 return Ok(0);
639 }
640
641 let mut cleaned = 0;
642 let mut entries = tokio::fs::read_dir(&self.storage_dir).await?;
643
644 while let Some(entry) = entries.next_entry().await? {
645 let path = entry.path();
646
647 if path.is_dir() || path.extension().is_none_or(|ext| ext != "json") {
649 continue;
650 }
651
652 if let Ok(json) = tokio::fs::read_to_string(&path).await {
654 if let Ok(state) = serde_json::from_str::<AgentState>(&json) {
655 if state.is_expired(max_age) {
656 if tokio::fs::remove_file(&path).await.is_ok() {
658 cleaned += 1;
659
660 let checkpoints_dir = self.checkpoints_dir(&state.id);
662 let _ = tokio::fs::remove_dir_all(&checkpoints_dir).await;
663 }
664 }
665 }
666 }
667 }
668
669 Ok(cleaned)
670 }
671
672 pub async fn save_checkpoint(&self, checkpoint: &Checkpoint) -> StateManagerResult<()> {
674 let checkpoints_dir = self.checkpoints_dir(&checkpoint.agent_id);
675 tokio::fs::create_dir_all(&checkpoints_dir).await?;
676
677 let file_path = self.checkpoint_file_path(&checkpoint.agent_id, &checkpoint.id);
678 let json = serde_json::to_string_pretty(checkpoint)?;
679 tokio::fs::write(file_path, json).await?;
680
681 Ok(())
682 }
683
684 pub async fn load_checkpoint(
686 &self,
687 agent_id: &str,
688 checkpoint_id: &str,
689 ) -> StateManagerResult<Option<Checkpoint>> {
690 let file_path = self.checkpoint_file_path(agent_id, checkpoint_id);
691
692 if !file_path.exists() {
693 return Ok(None);
694 }
695
696 let json = tokio::fs::read_to_string(&file_path).await?;
697 let checkpoint: Checkpoint = serde_json::from_str(&json)?;
698
699 Ok(Some(checkpoint))
700 }
701
702 pub async fn list_checkpoints(&self, agent_id: &str) -> StateManagerResult<Vec<Checkpoint>> {
704 let checkpoints_dir = self.checkpoints_dir(agent_id);
705
706 if !checkpoints_dir.exists() {
707 return Ok(Vec::new());
708 }
709
710 let mut checkpoints = Vec::new();
711 let mut entries = tokio::fs::read_dir(&checkpoints_dir).await?;
712
713 while let Some(entry) = entries.next_entry().await? {
714 let path = entry.path();
715
716 if path.extension().is_none_or(|ext| ext != "json") {
718 continue;
719 }
720
721 if let Ok(json) = tokio::fs::read_to_string(&path).await {
722 if let Ok(checkpoint) = serde_json::from_str::<Checkpoint>(&json) {
723 checkpoints.push(checkpoint);
724 }
725 }
726 }
727
728 checkpoints.sort_by_key(|c| c.step);
730
731 Ok(checkpoints)
732 }
733
734 pub async fn delete_checkpoint(
736 &self,
737 agent_id: &str,
738 checkpoint_id: &str,
739 ) -> StateManagerResult<bool> {
740 let file_path = self.checkpoint_file_path(agent_id, checkpoint_id);
741
742 if !file_path.exists() {
743 return Ok(false);
744 }
745
746 tokio::fs::remove_file(&file_path).await?;
747 Ok(true)
748 }
749
750 pub async fn state_exists(&self, id: &str) -> bool {
752 self.state_file_path(id).exists()
753 }
754
755 pub async fn state_count(&self) -> StateManagerResult<usize> {
757 if !self.storage_dir.exists() {
758 return Ok(0);
759 }
760
761 let mut count = 0;
762 let mut entries = tokio::fs::read_dir(&self.storage_dir).await?;
763
764 while let Some(entry) = entries.next_entry().await? {
765 let path = entry.path();
766 if !path.is_dir() && path.extension().is_some_and(|ext| ext == "json") {
767 count += 1;
768 }
769 }
770
771 Ok(count)
772 }
773}
774
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a fresh `test_agent` state with the given id and a fixed prompt.
    fn create_test_state(id: &str) -> AgentState {
        AgentState::new(id, "test_agent", "Test prompt")
    }

    // A new state starts Running, with zeroed counters and empty history.
    #[test]
    fn test_agent_state_creation() {
        let state = AgentState::new("agent-1", "test_agent", "Test prompt");

        assert_eq!(state.id, "agent-1");
        assert_eq!(state.agent_type, "test_agent");
        assert_eq!(state.prompt, "Test prompt");
        assert_eq!(state.status, AgentStateStatus::Running);
        assert_eq!(state.current_step, 0);
        assert_eq!(state.error_count, 0);
        assert!(state.messages.is_empty());
        assert!(state.checkpoints.is_empty());
    }

    // Running, Paused and Failed are resumable; Completed and Cancelled are not.
    #[test]
    fn test_agent_state_status_resumable() {
        assert!(AgentStateStatus::Running.is_resumable());
        assert!(AgentStateStatus::Paused.is_resumable());
        assert!(AgentStateStatus::Failed.is_resumable());
        assert!(!AgentStateStatus::Completed.is_resumable());
        assert!(!AgentStateStatus::Cancelled.is_resumable());
    }

    // Only Completed and Cancelled are terminal (Failed is not).
    #[test]
    fn test_agent_state_status_terminal() {
        assert!(!AgentStateStatus::Running.is_terminal());
        assert!(!AgentStateStatus::Paused.is_terminal());
        assert!(!AgentStateStatus::Failed.is_terminal());
        assert!(AgentStateStatus::Completed.is_terminal());
        assert!(AgentStateStatus::Cancelled.is_terminal());
    }

    // Each increment_step call advances current_step by exactly one.
    #[test]
    fn test_agent_state_increment_step() {
        let mut state = create_test_state("agent-1");
        assert_eq!(state.current_step, 0);

        state.increment_step();
        assert_eq!(state.current_step, 1);

        state.increment_step();
        assert_eq!(state.current_step, 2);
    }

    // record_error / record_retry bump their counters; reset_errors zeroes both.
    #[test]
    fn test_agent_state_error_tracking() {
        let mut state = create_test_state("agent-1");
        assert_eq!(state.error_count, 0);
        assert_eq!(state.retry_count, 0);

        state.record_error();
        assert_eq!(state.error_count, 1);

        state.record_retry();
        assert_eq!(state.retry_count, 1);

        state.reset_errors();
        assert_eq!(state.error_count, 0);
        assert_eq!(state.retry_count, 0);
    }

    // Checkpoint::new assigns a non-empty id; with_name sets the label.
    #[test]
    fn test_checkpoint_creation() {
        let checkpoint = Checkpoint::new("agent-1", 5).with_name("test_checkpoint");

        assert!(!checkpoint.id.is_empty());
        assert_eq!(checkpoint.agent_id, "agent-1");
        assert_eq!(checkpoint.step, 5);
        assert_eq!(checkpoint.name, Some("test_checkpoint".to_string()));
    }

    // create_checkpoint snapshots the step, records the checkpoint as current,
    // and appends it to the history.
    #[test]
    fn test_agent_state_create_checkpoint() {
        let mut state = create_test_state("agent-1");
        state.current_step = 3;
        state.set_metadata("key", serde_json::json!("value"));

        let checkpoint = state.create_checkpoint(Some("checkpoint-1"));

        assert_eq!(checkpoint.agent_id, "agent-1");
        assert_eq!(checkpoint.step, 3);
        assert_eq!(checkpoint.name, Some("checkpoint-1".to_string()));
        assert!(state.checkpoint.is_some());
        assert_eq!(state.checkpoints.len(), 1);
    }

    // Restoring undoes progress made after the checkpoint was taken.
    #[test]
    fn test_agent_state_restore_from_checkpoint() {
        let mut state = create_test_state("agent-1");
        state.current_step = 5;
        state.add_result(serde_json::json!({"result": 1}));

        let checkpoint = state.create_checkpoint(Some("cp-1"));

        // Advance past the checkpoint so the restore has something to roll back.
        state.current_step = 10;
        state.add_result(serde_json::json!({"result": 2}));

        state.restore_from_checkpoint(&checkpoint);

        // Step and results match the snapshot again.
        assert_eq!(state.current_step, 5);
        assert_eq!(state.results.len(), 1);
    }

    // A new record is pending; complete_success sets success and output.
    #[test]
    fn test_tool_call_record() {
        let mut record = ToolCallRecord::new("test_tool", serde_json::json!({"arg": "value"}));

        assert!(!record.id.is_empty());
        assert_eq!(record.tool_name, "test_tool");
        assert!(record.success.is_none());

        record.complete_success(serde_json::json!({"output": "result"}));
        assert_eq!(record.success, Some(true));
        assert!(record.output.is_some());
    }

    // complete_failure sets success to false and stores the error message.
    #[test]
    fn test_tool_call_record_failure() {
        let mut record = ToolCallRecord::new("test_tool", serde_json::json!({}));
        record.complete_failure("Test error");

        assert_eq!(record.success, Some(false));
        assert_eq!(record.error, Some("Test error".to_string()));
    }

    // An empty filter matches everything; agent_type and status must match exactly.
    #[test]
    fn test_state_filter_matches() {
        let state = AgentState::new("agent-1", "test_agent", "prompt")
            .with_status(AgentStateStatus::Running);

        let filter = StateFilter::new();
        assert!(filter.matches(&state));

        let filter = StateFilter::new().with_agent_type("test_agent");
        assert!(filter.matches(&state));

        let filter = StateFilter::new().with_agent_type("other_agent");
        assert!(!filter.matches(&state));

        let filter = StateFilter::new().with_status(AgentStateStatus::Running);
        assert!(filter.matches(&state));

        let filter = StateFilter::new().with_status(AgentStateStatus::Completed);
        assert!(!filter.matches(&state));
    }

    // has_checkpoints tracks whether the in-memory checkpoint list is empty.
    #[test]
    fn test_state_filter_checkpoints() {
        let mut state = create_test_state("agent-1");

        let filter = StateFilter::new().with_checkpoints(false);
        assert!(filter.matches(&state));

        let filter = StateFilter::new().with_checkpoints(true);
        assert!(!filter.matches(&state));

        state.create_checkpoint(None);

        let filter = StateFilter::new().with_checkpoints(true);
        assert!(filter.matches(&state));
    }

    // A saved state round-trips through load_state.
    #[tokio::test]
    async fn test_state_manager_save_load() {
        let temp_dir = TempDir::new().unwrap();
        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));

        let state = create_test_state("agent-1");
        manager.save_state(&state).await.unwrap();

        let loaded = manager.load_state("agent-1").await.unwrap();
        assert!(loaded.is_some());
        let loaded = loaded.unwrap();
        assert_eq!(loaded.id, "agent-1");
        assert_eq!(loaded.agent_type, "test_agent");
    }

    // Loading an unknown id yields Ok(None), not an error.
    #[tokio::test]
    async fn test_state_manager_load_nonexistent() {
        let temp_dir = TempDir::new().unwrap();
        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));

        let loaded = manager.load_state("nonexistent").await.unwrap();
        assert!(loaded.is_none());
    }

    // delete_state returns true the first time and false once the state is gone.
    #[tokio::test]
    async fn test_state_manager_delete() {
        let temp_dir = TempDir::new().unwrap();
        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));

        let state = create_test_state("agent-1");
        manager.save_state(&state).await.unwrap();

        let deleted = manager.delete_state("agent-1").await.unwrap();
        assert!(deleted);

        let loaded = manager.load_state("agent-1").await.unwrap();
        assert!(loaded.is_none());

        // A second delete finds nothing to remove.
        let deleted = manager.delete_state("agent-1").await.unwrap();
        assert!(!deleted);
    }

    // Without a filter, every saved state is listed.
    #[tokio::test]
    async fn test_state_manager_list_states() {
        let temp_dir = TempDir::new().unwrap();
        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));

        for i in 1..=3 {
            let state = AgentState::new(format!("agent-{}", i), "test_agent", "prompt");
            manager.save_state(&state).await.unwrap();
        }

        let states = manager.list_states(None).await.unwrap();
        assert_eq!(states.len(), 3);
    }

    // An agent_type filter keeps only the matching states.
    #[tokio::test]
    async fn test_state_manager_list_with_filter() {
        let temp_dir = TempDir::new().unwrap();
        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));

        let state1 = AgentState::new("agent-1", "type_a", "prompt");
        let state2 = AgentState::new("agent-2", "type_b", "prompt");
        let state3 = AgentState::new("agent-3", "type_a", "prompt");

        manager.save_state(&state1).await.unwrap();
        manager.save_state(&state2).await.unwrap();
        manager.save_state(&state3).await.unwrap();

        let filter = StateFilter::new().with_agent_type("type_a");
        let states = manager.list_states(Some(filter)).await.unwrap();
        assert_eq!(states.len(), 2);
    }

    // The filter's limit caps the number of listed states.
    #[tokio::test]
    async fn test_state_manager_list_with_limit() {
        let temp_dir = TempDir::new().unwrap();
        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));

        for i in 1..=5 {
            let state = AgentState::new(format!("agent-{}", i), "test", "prompt");
            manager.save_state(&state).await.unwrap();
        }

        let filter = StateFilter::new().with_limit(2);
        let states = manager.list_states(Some(filter)).await.unwrap();
        assert_eq!(states.len(), 2);
    }

    // A saved checkpoint round-trips through load_checkpoint.
    #[tokio::test]
    async fn test_checkpoint_save_load() {
        let temp_dir = TempDir::new().unwrap();
        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));

        let checkpoint = Checkpoint::new("agent-1", 5)
            .with_name("test_checkpoint")
            .with_results(vec![serde_json::json!({"result": 1})]);

        manager.save_checkpoint(&checkpoint).await.unwrap();

        let loaded = manager
            .load_checkpoint("agent-1", &checkpoint.id)
            .await
            .unwrap();
        assert!(loaded.is_some());
        let loaded = loaded.unwrap();
        assert_eq!(loaded.step, 5);
        assert_eq!(loaded.name, Some("test_checkpoint".to_string()));
    }

    // Checkpoints are saved out of order; the listing must come back sorted by step.
    #[tokio::test]
    async fn test_list_checkpoints() {
        let temp_dir = TempDir::new().unwrap();
        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));

        for step in [1, 3, 2] {
            let checkpoint = Checkpoint::new("agent-1", step);
            manager.save_checkpoint(&checkpoint).await.unwrap();
        }

        let checkpoints = manager.list_checkpoints("agent-1").await.unwrap();
        assert_eq!(checkpoints.len(), 3);
        assert_eq!(checkpoints[0].step, 1);
        assert_eq!(checkpoints[1].step, 2);
        assert_eq!(checkpoints[2].step, 3);
    }

    // state_count is 0 for a fresh directory and counts one file per saved state.
    #[tokio::test]
    async fn test_state_count() {
        let temp_dir = TempDir::new().unwrap();
        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));

        assert_eq!(manager.state_count().await.unwrap(), 0);

        for i in 1..=3 {
            let state = create_test_state(&format!("agent-{}", i));
            manager.save_state(&state).await.unwrap();
        }

        assert_eq!(manager.state_count().await.unwrap(), 3);
    }

    // state_exists flips from false to true once the state is saved.
    #[tokio::test]
    async fn test_state_exists() {
        let temp_dir = TempDir::new().unwrap();
        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));

        assert!(!manager.state_exists("agent-1").await);

        let state = create_test_state("agent-1");
        manager.save_state(&state).await.unwrap();

        assert!(manager.state_exists("agent-1").await);
    }
}