Skip to main content

aster/agents/resume/
state_manager.rs

//! Agent State Manager
//!
//! Manages agent state persistence, including save/load,
//! checkpoint management, and state cleanup.
//!
//! This module provides:
//! - Agent state persistence to disk (one JSON file per agent)
//! - State loading and listing, with optional filtering
//! - Checkpoint creation and management
//! - Age-based cleanup of expired states
11
12use chrono::{DateTime, Duration as ChronoDuration, Utc};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::path::PathBuf;
16use std::time::Duration;
17use thiserror::Error;
18
19use crate::conversation::message::Message;
20
21/// Result type alias for state manager operations
22pub type StateManagerResult<T> = Result<T, StateManagerError>;
23
24/// Error types for state manager operations
25#[derive(Debug, Error)]
26pub enum StateManagerError {
27    /// State not found
28    #[error("State not found: {0}")]
29    NotFound(String),
30
31    /// Checkpoint not found
32    #[error("Checkpoint not found: {0}")]
33    CheckpointNotFound(String),
34
35    /// I/O error
36    #[error("IO error: {0}")]
37    Io(#[from] std::io::Error),
38
39    /// Serialization error
40    #[error("Serialization error: {0}")]
41    Serialization(String),
42
43    /// Invalid state
44    #[error("Invalid state: {0}")]
45    InvalidState(String),
46}
47
48impl From<serde_json::Error> for StateManagerError {
49    fn from(err: serde_json::Error) -> Self {
50        StateManagerError::Serialization(err.to_string())
51    }
52}
53
54/// Agent state status
55#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
56#[serde(rename_all = "snake_case")]
57pub enum AgentStateStatus {
58    /// Agent is currently running
59    #[default]
60    Running,
61    /// Agent is paused
62    Paused,
63    /// Agent completed successfully
64    Completed,
65    /// Agent failed with an error
66    Failed,
67    /// Agent was cancelled
68    Cancelled,
69}
70
71impl AgentStateStatus {
72    /// Check if the state is resumable
73    pub fn is_resumable(&self) -> bool {
74        matches!(self, Self::Running | Self::Paused | Self::Failed)
75    }
76
77    /// Check if the state is terminal (completed, cancelled)
78    pub fn is_terminal(&self) -> bool {
79        matches!(self, Self::Completed | Self::Cancelled)
80    }
81}
82
83/// Tool call record for state persistence
84#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
85#[serde(rename_all = "camelCase")]
86pub struct ToolCallRecord {
87    /// Tool call ID
88    pub id: String,
89    /// Tool name
90    pub tool_name: String,
91    /// Input parameters
92    pub input: serde_json::Value,
93    /// Output result (if completed)
94    pub output: Option<serde_json::Value>,
95    /// Whether the call succeeded
96    pub success: Option<bool>,
97    /// Error message if failed
98    pub error: Option<String>,
99    /// Timestamp
100    pub timestamp: DateTime<Utc>,
101}
102
103impl ToolCallRecord {
104    /// Create a new tool call record
105    pub fn new(tool_name: impl Into<String>, input: serde_json::Value) -> Self {
106        Self {
107            id: uuid::Uuid::new_v4().to_string(),
108            tool_name: tool_name.into(),
109            input,
110            output: None,
111            success: None,
112            error: None,
113            timestamp: Utc::now(),
114        }
115    }
116
117    /// Complete the tool call with success
118    pub fn complete_success(&mut self, output: serde_json::Value) {
119        self.output = Some(output);
120        self.success = Some(true);
121    }
122
123    /// Complete the tool call with failure
124    pub fn complete_failure(&mut self, error: impl Into<String>) {
125        self.success = Some(false);
126        self.error = Some(error.into());
127    }
128}
129
130/// Checkpoint for agent state recovery
131#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
132#[serde(rename_all = "camelCase")]
133pub struct Checkpoint {
134    /// Unique checkpoint identifier
135    pub id: String,
136    /// Agent ID this checkpoint belongs to
137    pub agent_id: String,
138    /// Checkpoint name/label
139    pub name: Option<String>,
140    /// Step number at checkpoint
141    pub step: usize,
142    /// Messages at checkpoint
143    pub messages: Vec<Message>,
144    /// Tool calls at checkpoint
145    pub tool_calls: Vec<ToolCallRecord>,
146    /// Results at checkpoint
147    pub results: Vec<serde_json::Value>,
148    /// Metadata at checkpoint
149    pub metadata: HashMap<String, serde_json::Value>,
150    /// Creation timestamp
151    pub created_at: DateTime<Utc>,
152}
153
154impl Checkpoint {
155    /// Create a new checkpoint
156    pub fn new(agent_id: impl Into<String>, step: usize) -> Self {
157        Self {
158            id: uuid::Uuid::new_v4().to_string(),
159            agent_id: agent_id.into(),
160            name: None,
161            step,
162            messages: Vec::new(),
163            tool_calls: Vec::new(),
164            results: Vec::new(),
165            metadata: HashMap::new(),
166            created_at: Utc::now(),
167        }
168    }
169
170    /// Set checkpoint name
171    pub fn with_name(mut self, name: impl Into<String>) -> Self {
172        self.name = Some(name.into());
173        self
174    }
175
176    /// Set messages
177    pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
178        self.messages = messages;
179        self
180    }
181
182    /// Set tool calls
183    pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCallRecord>) -> Self {
184        self.tool_calls = tool_calls;
185        self
186    }
187
188    /// Set results
189    pub fn with_results(mut self, results: Vec<serde_json::Value>) -> Self {
190        self.results = results;
191        self
192    }
193
194    /// Add metadata
195    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
196        self.metadata.insert(key.into(), value);
197        self
198    }
199}
200
201/// Agent state for persistence and recovery
202#[derive(Debug, Clone, Serialize, Deserialize)]
203#[serde(rename_all = "camelCase")]
204pub struct AgentState {
205    /// Unique state identifier (same as agent ID)
206    pub id: String,
207    /// Agent type
208    pub agent_type: String,
209    /// Current status
210    pub status: AgentStateStatus,
211    /// Creation timestamp
212    pub created_at: DateTime<Utc>,
213    /// Last update timestamp
214    pub updated_at: DateTime<Utc>,
215    /// Original prompt
216    pub prompt: String,
217    /// Conversation messages
218    pub messages: Vec<Message>,
219    /// Tool call records
220    pub tool_calls: Vec<ToolCallRecord>,
221    /// Results collected
222    pub results: Vec<serde_json::Value>,
223    /// Current checkpoint (if any)
224    pub checkpoint: Option<Checkpoint>,
225    /// All checkpoints
226    pub checkpoints: Vec<Checkpoint>,
227    /// Current step number
228    pub current_step: usize,
229    /// Total steps (if known)
230    pub total_steps: Option<usize>,
231    /// Error count
232    pub error_count: usize,
233    /// Retry count
234    pub retry_count: usize,
235    /// Maximum retries allowed
236    pub max_retries: usize,
237    /// Custom metadata
238    pub metadata: HashMap<String, serde_json::Value>,
239}
240
241impl AgentState {
242    /// Create a new agent state
243    pub fn new(
244        id: impl Into<String>,
245        agent_type: impl Into<String>,
246        prompt: impl Into<String>,
247    ) -> Self {
248        let now = Utc::now();
249        Self {
250            id: id.into(),
251            agent_type: agent_type.into(),
252            status: AgentStateStatus::Running,
253            created_at: now,
254            updated_at: now,
255            prompt: prompt.into(),
256            messages: Vec::new(),
257            tool_calls: Vec::new(),
258            results: Vec::new(),
259            checkpoint: None,
260            checkpoints: Vec::new(),
261            current_step: 0,
262            total_steps: None,
263            error_count: 0,
264            retry_count: 0,
265            max_retries: 3,
266            metadata: HashMap::new(),
267        }
268    }
269
270    /// Set status
271    pub fn with_status(mut self, status: AgentStateStatus) -> Self {
272        self.status = status;
273        self.updated_at = Utc::now();
274        self
275    }
276
277    /// Set max retries
278    pub fn with_max_retries(mut self, max_retries: usize) -> Self {
279        self.max_retries = max_retries;
280        self
281    }
282
283    /// Set total steps
284    pub fn with_total_steps(mut self, total: usize) -> Self {
285        self.total_steps = Some(total);
286        self
287    }
288
289    /// Add a message
290    pub fn add_message(&mut self, message: Message) {
291        self.messages.push(message);
292        self.updated_at = Utc::now();
293    }
294
295    /// Add a tool call
296    pub fn add_tool_call(&mut self, tool_call: ToolCallRecord) {
297        self.tool_calls.push(tool_call);
298        self.updated_at = Utc::now();
299    }
300
301    /// Add a result
302    pub fn add_result(&mut self, result: serde_json::Value) {
303        self.results.push(result);
304        self.updated_at = Utc::now();
305    }
306
307    /// Increment step
308    pub fn increment_step(&mut self) {
309        self.current_step += 1;
310        self.updated_at = Utc::now();
311    }
312
313    /// Record an error
314    pub fn record_error(&mut self) {
315        self.error_count += 1;
316        self.updated_at = Utc::now();
317    }
318
319    /// Record a retry
320    pub fn record_retry(&mut self) {
321        self.retry_count += 1;
322        self.updated_at = Utc::now();
323    }
324
325    /// Reset error state
326    pub fn reset_errors(&mut self) {
327        self.error_count = 0;
328        self.retry_count = 0;
329        self.updated_at = Utc::now();
330    }
331
332    /// Set metadata
333    pub fn set_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
334        self.metadata.insert(key.into(), value);
335        self.updated_at = Utc::now();
336    }
337
338    /// Create a checkpoint from current state
339    pub fn create_checkpoint(&mut self, name: Option<&str>) -> Checkpoint {
340        let mut checkpoint = Checkpoint::new(&self.id, self.current_step)
341            .with_messages(self.messages.clone())
342            .with_tool_calls(self.tool_calls.clone())
343            .with_results(self.results.clone());
344
345        if let Some(n) = name {
346            checkpoint = checkpoint.with_name(n);
347        }
348
349        for (k, v) in &self.metadata {
350            checkpoint = checkpoint.with_metadata(k.clone(), v.clone());
351        }
352
353        self.checkpoint = Some(checkpoint.clone());
354        self.checkpoints.push(checkpoint.clone());
355        self.updated_at = Utc::now();
356
357        checkpoint
358    }
359
360    /// Restore from a checkpoint
361    pub fn restore_from_checkpoint(&mut self, checkpoint: &Checkpoint) {
362        self.current_step = checkpoint.step;
363        self.messages = checkpoint.messages.clone();
364        self.tool_calls = checkpoint.tool_calls.clone();
365        self.results = checkpoint.results.clone();
366        self.metadata = checkpoint.metadata.clone();
367        self.checkpoint = Some(checkpoint.clone());
368        self.updated_at = Utc::now();
369    }
370
371    /// Check if can be resumed
372    pub fn can_resume(&self) -> bool {
373        self.status.is_resumable()
374    }
375
376    /// Get the latest checkpoint
377    pub fn latest_checkpoint(&self) -> Option<&Checkpoint> {
378        self.checkpoints.last()
379    }
380
381    /// Get age of the state
382    pub fn age(&self) -> ChronoDuration {
383        Utc::now().signed_duration_since(self.created_at)
384    }
385
386    /// Check if state is expired based on max age
387    pub fn is_expired(&self, max_age: Duration) -> bool {
388        let age = self.age();
389        if let Ok(max_age_chrono) = ChronoDuration::from_std(max_age) {
390            age > max_age_chrono
391        } else {
392            false
393        }
394    }
395}
396
397impl PartialEq for AgentState {
398    fn eq(&self, other: &Self) -> bool {
399        self.id == other.id
400    }
401}
402
403impl Eq for AgentState {}
404
405/// Filter for listing agent states
406#[derive(Debug, Clone, Default)]
407pub struct StateFilter {
408    /// Filter by agent type
409    pub agent_type: Option<String>,
410    /// Filter by status
411    pub status: Option<AgentStateStatus>,
412    /// Filter by minimum creation time
413    pub created_after: Option<DateTime<Utc>>,
414    /// Filter by maximum creation time
415    pub created_before: Option<DateTime<Utc>>,
416    /// Filter by having checkpoints
417    pub has_checkpoints: Option<bool>,
418    /// Maximum number of results
419    pub limit: Option<usize>,
420}
421
422impl StateFilter {
423    /// Create a new filter
424    pub fn new() -> Self {
425        Self::default()
426    }
427
428    /// Filter by agent type
429    pub fn with_agent_type(mut self, agent_type: impl Into<String>) -> Self {
430        self.agent_type = Some(agent_type.into());
431        self
432    }
433
434    /// Filter by status
435    pub fn with_status(mut self, status: AgentStateStatus) -> Self {
436        self.status = Some(status);
437        self
438    }
439
440    /// Filter by creation time range
441    pub fn created_between(mut self, after: DateTime<Utc>, before: DateTime<Utc>) -> Self {
442        self.created_after = Some(after);
443        self.created_before = Some(before);
444        self
445    }
446
447    /// Filter by having checkpoints
448    pub fn with_checkpoints(mut self, has: bool) -> Self {
449        self.has_checkpoints = Some(has);
450        self
451    }
452
453    /// Limit results
454    pub fn with_limit(mut self, limit: usize) -> Self {
455        self.limit = Some(limit);
456        self
457    }
458
459    /// Check if a state matches this filter
460    pub fn matches(&self, state: &AgentState) -> bool {
461        if let Some(ref agent_type) = self.agent_type {
462            if &state.agent_type != agent_type {
463                return false;
464            }
465        }
466
467        if let Some(status) = self.status {
468            if state.status != status {
469                return false;
470            }
471        }
472
473        if let Some(after) = self.created_after {
474            if state.created_at < after {
475                return false;
476            }
477        }
478
479        if let Some(before) = self.created_before {
480            if state.created_at > before {
481                return false;
482            }
483        }
484
485        if let Some(has_checkpoints) = self.has_checkpoints {
486            let has = !state.checkpoints.is_empty();
487            if has != has_checkpoints {
488                return false;
489            }
490        }
491
492        true
493    }
494}
495
/// Stores [`AgentState`]s and [`Checkpoint`]s as JSON files under one root
/// directory.
#[derive(Debug)]
pub struct AgentStateManager {
    /// Root directory. It holds `<id>.json` state files plus a `checkpoints/`
    /// subdirectory with one folder per agent.
    storage_dir: PathBuf,
}
502
503impl Default for AgentStateManager {
504    fn default() -> Self {
505        Self::new(None)
506    }
507}
508
509impl AgentStateManager {
510    /// Create a new AgentStateManager
511    pub fn new(storage_dir: Option<PathBuf>) -> Self {
512        let storage_dir = storage_dir.unwrap_or_else(|| PathBuf::from(".aster/states"));
513        Self { storage_dir }
514    }
515
516    /// Get the storage directory
517    pub fn storage_dir(&self) -> &PathBuf {
518        &self.storage_dir
519    }
520
521    /// Set the storage directory
522    pub fn set_storage_dir(&mut self, dir: PathBuf) {
523        self.storage_dir = dir;
524    }
525
526    /// Get the file path for a state
527    fn state_file_path(&self, id: &str) -> PathBuf {
528        self.storage_dir.join(format!("{}.json", id))
529    }
530
531    /// Get the checkpoints directory for an agent
532    fn checkpoints_dir(&self, agent_id: &str) -> PathBuf {
533        self.storage_dir.join("checkpoints").join(agent_id)
534    }
535
536    /// Get the file path for a checkpoint
537    fn checkpoint_file_path(&self, agent_id: &str, checkpoint_id: &str) -> PathBuf {
538        self.checkpoints_dir(agent_id)
539            .join(format!("{}.json", checkpoint_id))
540    }
541
542    /// Save agent state to disk
543    pub async fn save_state(&self, state: &AgentState) -> StateManagerResult<()> {
544        // Create storage directory if it doesn't exist
545        tokio::fs::create_dir_all(&self.storage_dir).await?;
546
547        let file_path = self.state_file_path(&state.id);
548        let json = serde_json::to_string_pretty(state)?;
549        tokio::fs::write(file_path, json).await?;
550
551        Ok(())
552    }
553
554    /// Load agent state from disk
555    pub async fn load_state(&self, id: &str) -> StateManagerResult<Option<AgentState>> {
556        let file_path = self.state_file_path(id);
557
558        if !file_path.exists() {
559            return Ok(None);
560        }
561
562        let json = tokio::fs::read_to_string(&file_path).await?;
563        let state: AgentState = serde_json::from_str(&json)?;
564
565        Ok(Some(state))
566    }
567
568    /// List all saved agent states with optional filtering
569    pub async fn list_states(
570        &self,
571        filter: Option<StateFilter>,
572    ) -> StateManagerResult<Vec<AgentState>> {
573        if !self.storage_dir.exists() {
574            return Ok(Vec::new());
575        }
576
577        let mut states = Vec::new();
578        let mut entries = tokio::fs::read_dir(&self.storage_dir).await?;
579
580        while let Some(entry) = entries.next_entry().await? {
581            let path = entry.path();
582
583            // Skip directories and non-JSON files
584            if path.is_dir() || path.extension().is_none_or(|ext| ext != "json") {
585                continue;
586            }
587
588            // Try to load the state
589            if let Ok(json) = tokio::fs::read_to_string(&path).await {
590                if let Ok(state) = serde_json::from_str::<AgentState>(&json) {
591                    // Apply filter if provided
592                    if let Some(ref f) = filter {
593                        if f.matches(&state) {
594                            states.push(state);
595                        }
596                    } else {
597                        states.push(state);
598                    }
599                }
600            }
601        }
602
603        // Sort by creation time (newest first)
604        states.sort_by(|a, b| b.created_at.cmp(&a.created_at));
605
606        // Apply limit if specified
607        if let Some(ref f) = filter {
608            if let Some(limit) = f.limit {
609                states.truncate(limit);
610            }
611        }
612
613        Ok(states)
614    }
615
616    /// Delete agent state from disk
617    pub async fn delete_state(&self, id: &str) -> StateManagerResult<bool> {
618        let file_path = self.state_file_path(id);
619
620        if !file_path.exists() {
621            return Ok(false);
622        }
623
624        tokio::fs::remove_file(&file_path).await?;
625
626        // Also delete checkpoints directory if it exists
627        let checkpoints_dir = self.checkpoints_dir(id);
628        if checkpoints_dir.exists() {
629            tokio::fs::remove_dir_all(&checkpoints_dir).await?;
630        }
631
632        Ok(true)
633    }
634
635    /// Cleanup expired states based on max age
636    pub async fn cleanup_expired(&self, max_age: Duration) -> StateManagerResult<usize> {
637        if !self.storage_dir.exists() {
638            return Ok(0);
639        }
640
641        let mut cleaned = 0;
642        let mut entries = tokio::fs::read_dir(&self.storage_dir).await?;
643
644        while let Some(entry) = entries.next_entry().await? {
645            let path = entry.path();
646
647            // Skip directories and non-JSON files
648            if path.is_dir() || path.extension().is_none_or(|ext| ext != "json") {
649                continue;
650            }
651
652            // Try to load and check if expired
653            if let Ok(json) = tokio::fs::read_to_string(&path).await {
654                if let Ok(state) = serde_json::from_str::<AgentState>(&json) {
655                    if state.is_expired(max_age) {
656                        // Delete the state file
657                        if tokio::fs::remove_file(&path).await.is_ok() {
658                            cleaned += 1;
659
660                            // Also delete checkpoints directory
661                            let checkpoints_dir = self.checkpoints_dir(&state.id);
662                            let _ = tokio::fs::remove_dir_all(&checkpoints_dir).await;
663                        }
664                    }
665                }
666            }
667        }
668
669        Ok(cleaned)
670    }
671
672    /// Save a checkpoint to disk
673    pub async fn save_checkpoint(&self, checkpoint: &Checkpoint) -> StateManagerResult<()> {
674        let checkpoints_dir = self.checkpoints_dir(&checkpoint.agent_id);
675        tokio::fs::create_dir_all(&checkpoints_dir).await?;
676
677        let file_path = self.checkpoint_file_path(&checkpoint.agent_id, &checkpoint.id);
678        let json = serde_json::to_string_pretty(checkpoint)?;
679        tokio::fs::write(file_path, json).await?;
680
681        Ok(())
682    }
683
684    /// Load a checkpoint from disk
685    pub async fn load_checkpoint(
686        &self,
687        agent_id: &str,
688        checkpoint_id: &str,
689    ) -> StateManagerResult<Option<Checkpoint>> {
690        let file_path = self.checkpoint_file_path(agent_id, checkpoint_id);
691
692        if !file_path.exists() {
693            return Ok(None);
694        }
695
696        let json = tokio::fs::read_to_string(&file_path).await?;
697        let checkpoint: Checkpoint = serde_json::from_str(&json)?;
698
699        Ok(Some(checkpoint))
700    }
701
702    /// List all checkpoints for an agent
703    pub async fn list_checkpoints(&self, agent_id: &str) -> StateManagerResult<Vec<Checkpoint>> {
704        let checkpoints_dir = self.checkpoints_dir(agent_id);
705
706        if !checkpoints_dir.exists() {
707            return Ok(Vec::new());
708        }
709
710        let mut checkpoints = Vec::new();
711        let mut entries = tokio::fs::read_dir(&checkpoints_dir).await?;
712
713        while let Some(entry) = entries.next_entry().await? {
714            let path = entry.path();
715
716            // Skip non-JSON files
717            if path.extension().is_none_or(|ext| ext != "json") {
718                continue;
719            }
720
721            if let Ok(json) = tokio::fs::read_to_string(&path).await {
722                if let Ok(checkpoint) = serde_json::from_str::<Checkpoint>(&json) {
723                    checkpoints.push(checkpoint);
724                }
725            }
726        }
727
728        // Sort by step number
729        checkpoints.sort_by_key(|c| c.step);
730
731        Ok(checkpoints)
732    }
733
734    /// Delete a checkpoint
735    pub async fn delete_checkpoint(
736        &self,
737        agent_id: &str,
738        checkpoint_id: &str,
739    ) -> StateManagerResult<bool> {
740        let file_path = self.checkpoint_file_path(agent_id, checkpoint_id);
741
742        if !file_path.exists() {
743            return Ok(false);
744        }
745
746        tokio::fs::remove_file(&file_path).await?;
747        Ok(true)
748    }
749
750    /// Check if a state exists
751    pub async fn state_exists(&self, id: &str) -> bool {
752        self.state_file_path(id).exists()
753    }
754
755    /// Get the count of saved states
756    pub async fn state_count(&self) -> StateManagerResult<usize> {
757        if !self.storage_dir.exists() {
758            return Ok(0);
759        }
760
761        let mut count = 0;
762        let mut entries = tokio::fs::read_dir(&self.storage_dir).await?;
763
764        while let Some(entry) = entries.next_entry().await? {
765            let path = entry.path();
766            if !path.is_dir() && path.extension().is_some_and(|ext| ext == "json") {
767                count += 1;
768            }
769        }
770
771        Ok(count)
772    }
773}
774
775#[cfg(test)]
776mod tests {
777    use super::*;
778    use tempfile::TempDir;
779
780    fn create_test_state(id: &str) -> AgentState {
781        AgentState::new(id, "test_agent", "Test prompt")
782    }
783
784    #[test]
785    fn test_agent_state_creation() {
786        let state = AgentState::new("agent-1", "test_agent", "Test prompt");
787
788        assert_eq!(state.id, "agent-1");
789        assert_eq!(state.agent_type, "test_agent");
790        assert_eq!(state.prompt, "Test prompt");
791        assert_eq!(state.status, AgentStateStatus::Running);
792        assert_eq!(state.current_step, 0);
793        assert_eq!(state.error_count, 0);
794        assert!(state.messages.is_empty());
795        assert!(state.checkpoints.is_empty());
796    }
797
798    #[test]
799    fn test_agent_state_status_resumable() {
800        assert!(AgentStateStatus::Running.is_resumable());
801        assert!(AgentStateStatus::Paused.is_resumable());
802        assert!(AgentStateStatus::Failed.is_resumable());
803        assert!(!AgentStateStatus::Completed.is_resumable());
804        assert!(!AgentStateStatus::Cancelled.is_resumable());
805    }
806
807    #[test]
808    fn test_agent_state_status_terminal() {
809        assert!(!AgentStateStatus::Running.is_terminal());
810        assert!(!AgentStateStatus::Paused.is_terminal());
811        assert!(!AgentStateStatus::Failed.is_terminal());
812        assert!(AgentStateStatus::Completed.is_terminal());
813        assert!(AgentStateStatus::Cancelled.is_terminal());
814    }
815
816    #[test]
817    fn test_agent_state_increment_step() {
818        let mut state = create_test_state("agent-1");
819        assert_eq!(state.current_step, 0);
820
821        state.increment_step();
822        assert_eq!(state.current_step, 1);
823
824        state.increment_step();
825        assert_eq!(state.current_step, 2);
826    }
827
828    #[test]
829    fn test_agent_state_error_tracking() {
830        let mut state = create_test_state("agent-1");
831        assert_eq!(state.error_count, 0);
832        assert_eq!(state.retry_count, 0);
833
834        state.record_error();
835        assert_eq!(state.error_count, 1);
836
837        state.record_retry();
838        assert_eq!(state.retry_count, 1);
839
840        state.reset_errors();
841        assert_eq!(state.error_count, 0);
842        assert_eq!(state.retry_count, 0);
843    }
844
845    #[test]
846    fn test_checkpoint_creation() {
847        let checkpoint = Checkpoint::new("agent-1", 5).with_name("test_checkpoint");
848
849        assert!(!checkpoint.id.is_empty());
850        assert_eq!(checkpoint.agent_id, "agent-1");
851        assert_eq!(checkpoint.step, 5);
852        assert_eq!(checkpoint.name, Some("test_checkpoint".to_string()));
853    }
854
855    #[test]
856    fn test_agent_state_create_checkpoint() {
857        let mut state = create_test_state("agent-1");
858        state.current_step = 3;
859        state.set_metadata("key", serde_json::json!("value"));
860
861        let checkpoint = state.create_checkpoint(Some("checkpoint-1"));
862
863        assert_eq!(checkpoint.agent_id, "agent-1");
864        assert_eq!(checkpoint.step, 3);
865        assert_eq!(checkpoint.name, Some("checkpoint-1".to_string()));
866        assert!(state.checkpoint.is_some());
867        assert_eq!(state.checkpoints.len(), 1);
868    }
869
870    #[test]
871    fn test_agent_state_restore_from_checkpoint() {
872        let mut state = create_test_state("agent-1");
873        state.current_step = 5;
874        state.add_result(serde_json::json!({"result": 1}));
875
876        let checkpoint = state.create_checkpoint(Some("cp-1"));
877
878        // Modify state
879        state.current_step = 10;
880        state.add_result(serde_json::json!({"result": 2}));
881
882        // Restore
883        state.restore_from_checkpoint(&checkpoint);
884
885        assert_eq!(state.current_step, 5);
886        assert_eq!(state.results.len(), 1);
887    }
888
889    #[test]
890    fn test_tool_call_record() {
891        let mut record = ToolCallRecord::new("test_tool", serde_json::json!({"arg": "value"}));
892
893        assert!(!record.id.is_empty());
894        assert_eq!(record.tool_name, "test_tool");
895        assert!(record.success.is_none());
896
897        record.complete_success(serde_json::json!({"output": "result"}));
898        assert_eq!(record.success, Some(true));
899        assert!(record.output.is_some());
900    }
901
902    #[test]
903    fn test_tool_call_record_failure() {
904        let mut record = ToolCallRecord::new("test_tool", serde_json::json!({}));
905        record.complete_failure("Test error");
906
907        assert_eq!(record.success, Some(false));
908        assert_eq!(record.error, Some("Test error".to_string()));
909    }
910
911    #[test]
912    fn test_state_filter_matches() {
913        let state = AgentState::new("agent-1", "test_agent", "prompt")
914            .with_status(AgentStateStatus::Running);
915
916        // Empty filter matches all
917        let filter = StateFilter::new();
918        assert!(filter.matches(&state));
919
920        // Type filter
921        let filter = StateFilter::new().with_agent_type("test_agent");
922        assert!(filter.matches(&state));
923
924        let filter = StateFilter::new().with_agent_type("other_agent");
925        assert!(!filter.matches(&state));
926
927        // Status filter
928        let filter = StateFilter::new().with_status(AgentStateStatus::Running);
929        assert!(filter.matches(&state));
930
931        let filter = StateFilter::new().with_status(AgentStateStatus::Completed);
932        assert!(!filter.matches(&state));
933    }
934
935    #[test]
936    fn test_state_filter_checkpoints() {
937        let mut state = create_test_state("agent-1");
938
939        let filter = StateFilter::new().with_checkpoints(false);
940        assert!(filter.matches(&state));
941
942        let filter = StateFilter::new().with_checkpoints(true);
943        assert!(!filter.matches(&state));
944
945        state.create_checkpoint(None);
946
947        let filter = StateFilter::new().with_checkpoints(true);
948        assert!(filter.matches(&state));
949    }
950
951    #[tokio::test]
952    async fn test_state_manager_save_load() {
953        let temp_dir = TempDir::new().unwrap();
954        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));
955
956        let state = create_test_state("agent-1");
957        manager.save_state(&state).await.unwrap();
958
959        let loaded = manager.load_state("agent-1").await.unwrap();
960        assert!(loaded.is_some());
961        let loaded = loaded.unwrap();
962        assert_eq!(loaded.id, "agent-1");
963        assert_eq!(loaded.agent_type, "test_agent");
964    }
965
966    #[tokio::test]
967    async fn test_state_manager_load_nonexistent() {
968        let temp_dir = TempDir::new().unwrap();
969        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));
970
971        let loaded = manager.load_state("nonexistent").await.unwrap();
972        assert!(loaded.is_none());
973    }
974
975    #[tokio::test]
976    async fn test_state_manager_delete() {
977        let temp_dir = TempDir::new().unwrap();
978        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));
979
980        let state = create_test_state("agent-1");
981        manager.save_state(&state).await.unwrap();
982
983        let deleted = manager.delete_state("agent-1").await.unwrap();
984        assert!(deleted);
985
986        let loaded = manager.load_state("agent-1").await.unwrap();
987        assert!(loaded.is_none());
988
989        // Delete nonexistent
990        let deleted = manager.delete_state("agent-1").await.unwrap();
991        assert!(!deleted);
992    }
993
994    #[tokio::test]
995    async fn test_state_manager_list_states() {
996        let temp_dir = TempDir::new().unwrap();
997        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));
998
999        // Save multiple states
1000        for i in 1..=3 {
1001            let state = AgentState::new(format!("agent-{}", i), "test_agent", "prompt");
1002            manager.save_state(&state).await.unwrap();
1003        }
1004
1005        let states = manager.list_states(None).await.unwrap();
1006        assert_eq!(states.len(), 3);
1007    }
1008
1009    #[tokio::test]
1010    async fn test_state_manager_list_with_filter() {
1011        let temp_dir = TempDir::new().unwrap();
1012        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));
1013
1014        // Save states with different types
1015        let state1 = AgentState::new("agent-1", "type_a", "prompt");
1016        let state2 = AgentState::new("agent-2", "type_b", "prompt");
1017        let state3 = AgentState::new("agent-3", "type_a", "prompt");
1018
1019        manager.save_state(&state1).await.unwrap();
1020        manager.save_state(&state2).await.unwrap();
1021        manager.save_state(&state3).await.unwrap();
1022
1023        let filter = StateFilter::new().with_agent_type("type_a");
1024        let states = manager.list_states(Some(filter)).await.unwrap();
1025        assert_eq!(states.len(), 2);
1026    }
1027
1028    #[tokio::test]
1029    async fn test_state_manager_list_with_limit() {
1030        let temp_dir = TempDir::new().unwrap();
1031        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));
1032
1033        for i in 1..=5 {
1034            let state = AgentState::new(format!("agent-{}", i), "test", "prompt");
1035            manager.save_state(&state).await.unwrap();
1036        }
1037
1038        let filter = StateFilter::new().with_limit(2);
1039        let states = manager.list_states(Some(filter)).await.unwrap();
1040        assert_eq!(states.len(), 2);
1041    }
1042
1043    #[tokio::test]
1044    async fn test_checkpoint_save_load() {
1045        let temp_dir = TempDir::new().unwrap();
1046        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));
1047
1048        let checkpoint = Checkpoint::new("agent-1", 5)
1049            .with_name("test_checkpoint")
1050            .with_results(vec![serde_json::json!({"result": 1})]);
1051
1052        manager.save_checkpoint(&checkpoint).await.unwrap();
1053
1054        let loaded = manager
1055            .load_checkpoint("agent-1", &checkpoint.id)
1056            .await
1057            .unwrap();
1058        assert!(loaded.is_some());
1059        let loaded = loaded.unwrap();
1060        assert_eq!(loaded.step, 5);
1061        assert_eq!(loaded.name, Some("test_checkpoint".to_string()));
1062    }
1063
1064    #[tokio::test]
1065    async fn test_list_checkpoints() {
1066        let temp_dir = TempDir::new().unwrap();
1067        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));
1068
1069        // Save multiple checkpoints
1070        for step in [1, 3, 2] {
1071            let checkpoint = Checkpoint::new("agent-1", step);
1072            manager.save_checkpoint(&checkpoint).await.unwrap();
1073        }
1074
1075        let checkpoints = manager.list_checkpoints("agent-1").await.unwrap();
1076        assert_eq!(checkpoints.len(), 3);
1077        // Should be sorted by step
1078        assert_eq!(checkpoints[0].step, 1);
1079        assert_eq!(checkpoints[1].step, 2);
1080        assert_eq!(checkpoints[2].step, 3);
1081    }
1082
1083    #[tokio::test]
1084    async fn test_state_count() {
1085        let temp_dir = TempDir::new().unwrap();
1086        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));
1087
1088        assert_eq!(manager.state_count().await.unwrap(), 0);
1089
1090        for i in 1..=3 {
1091            let state = create_test_state(&format!("agent-{}", i));
1092            manager.save_state(&state).await.unwrap();
1093        }
1094
1095        assert_eq!(manager.state_count().await.unwrap(), 3);
1096    }
1097
1098    #[tokio::test]
1099    async fn test_state_exists() {
1100        let temp_dir = TempDir::new().unwrap();
1101        let manager = AgentStateManager::new(Some(temp_dir.path().to_path_buf()));
1102
1103        assert!(!manager.state_exists("agent-1").await);
1104
1105        let state = create_test_state("agent-1");
1106        manager.save_state(&state).await.unwrap();
1107
1108        assert!(manager.state_exists("agent-1").await);
1109    }
1110}