Skip to main content

limit_agent/
state.rs

1use crate::error::AgentError;
2use limit_llm::types::Message;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fs;
6use std::path::PathBuf;
7use tracing::instrument;
8
/// Directory (a relative path, so resolved against the current working
/// directory) that holds the persisted agent state.
const STATE_DIR: &str = ".limit";
/// Name of the JSON state snapshot stored inside [`STATE_DIR`].
const STATE_FILE: &str = "agent-state.json";
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Decision {
14    pub timestamp: u64,
15    pub action: String,
16    pub reason: String,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct Todo {
21    pub id: String,
22    pub content: String,
23    pub status: TodoStatus,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27pub enum TodoStatus {
28    Pending,
29    InProgress,
30    Done,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct AgentState {
35    pub messages: Vec<Message>,
36    pub tool_results: HashMap<String, serde_json::Value>,
37    pub decisions: Vec<Decision>,
38    pub todos: Vec<Todo>,
39    pub iteration: u32,
40}
41
42impl Default for AgentState {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl AgentState {
49    pub fn new() -> Self {
50        Self {
51            messages: Vec::new(),
52            tool_results: HashMap::new(),
53            decisions: Vec::new(),
54            todos: Vec::new(),
55            iteration: 0,
56        }
57    }
58
59    fn state_path() -> PathBuf {
60        PathBuf::from(STATE_DIR).join(STATE_FILE)
61    }
62
63    #[instrument(skip(self))]
64    pub fn save_state(&self) -> Result<(), AgentError> {
65        let path = Self::state_path();
66
67        if let Some(parent) = path.parent() {
68            fs::create_dir_all(parent)?;
69        }
70
71        let encoded = serde_json::to_string_pretty(self).map_err(|e| {
72            AgentError::SerializationError(format!("Serialization failed: {:?}", e))
73        })?;
74        fs::write(&path, encoded)?;
75
76        Ok(())
77    }
78
79    #[instrument]
80    pub fn load_state() -> Result<Self, AgentError> {
81        let path = Self::state_path();
82
83        if !path.exists() {
84            return Ok(Self::new());
85        }
86
87        let encoded = fs::read_to_string(&path)?;
88        let _state: AgentState = serde_json::from_str(&encoded).map_err(|e| {
89            AgentError::SerializationError(format!("Deserialization failed: {:?}", e))
90        })?;
91
92        Ok(_state)
93    }
94
95    #[instrument(skip(self))]
96    pub fn check_loop_detection(
97        &mut self,
98        tool_name: &str,
99        _args: &serde_json::Value,
100    ) -> Result<(), AgentError> {
101        Ok(())
102    }
103
104    pub fn iteration(&self) -> u32 {
105        self.iteration
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    // Not called by any current test; kept as a ready-made fixture for
    // message-related tests, hence the dead_code allowance.
    #[allow(dead_code)]
    fn create_test_message() -> Message {
        Message {
            role: limit_llm::types::Role::User,
            content: Some("test message".to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }
    }

    fn create_test_tool_args() -> serde_json::Value {
        serde_json::json!({"arg": "value"})
    }

    /// Puts the on-disk state file back the way it was before a test touched it.
    ///
    /// On drop, a pre-existing file is restored byte-for-byte, or the test's file
    /// is removed if there was none. This also runs when the test panics or
    /// returns early via `?`. The parent directory is removed only if the guard
    /// saw it missing, and only if it is empty again.
    struct StateFileGuard {
        path: PathBuf,
        original: Option<Vec<u8>>,
        remove_parent: bool,
    }

    impl StateFileGuard {
        fn new(path: PathBuf) -> Self {
            let original = match fs::read(&path) {
                Ok(bytes) => Some(bytes),
                Err(e) if e.kind() == std::io::ErrorKind::NotFound => None,
                // Refuse to run rather than risk clobbering a file we could not back up.
                Err(e) => panic!("cannot back up {}: {:?}", path.display(), e),
            };
            let remove_parent = matches!(path.parent(), Some(dir) if !dir.exists());
            Self {
                path,
                original,
                remove_parent,
            }
        }
    }

    impl Drop for StateFileGuard {
        fn drop(&mut self) {
            // Best effort: cleanup failures must not mask the test's own outcome.
            let _ = match &self.original {
                Some(bytes) => fs::write(&self.path, bytes),
                None => fs::remove_file(&self.path),
            };
            if self.remove_parent {
                if let Some(dir) = self.path.parent() {
                    // `remove_dir` fails on a non-empty directory, so nothing the
                    // test did not create can be deleted here.
                    let _ = fs::remove_dir(dir);
                }
            }
        }
    }

    #[test]
    fn test_agent_state_default() {
        let state = AgentState::default();
        assert_eq!(state.iteration, 0);
        assert!(state.messages.is_empty());
        assert!(state.decisions.is_empty());
        assert!(state.todos.is_empty());
        assert!(state.tool_results.is_empty());
    }

    #[test]
    fn test_agent_state_new() {
        let state = AgentState::new();
        assert_eq!(state.iteration, 0);
        assert!(state.messages.is_empty());
    }

    #[test]
    fn test_loop_detection() {
        let mut state = AgentState::new();
        let args = create_test_tool_args();

        // Loop detection now always returns Ok
        for _ in 0..10 {
            state.check_loop_detection("test_tool", &args).unwrap();
        }
    }

    #[test]
    fn test_save_and_load_state() -> Result<(), AgentError> {
        let path = AgentState::state_path();
        // The state file lives in the real working directory. Back up whatever
        // is there and restore it when the test ends, instead of deleting the
        // whole state directory.
        let _guard = StateFileGuard::new(path.clone());
        let _ = fs::remove_file(&path);

        let mut state = AgentState::new();

        state.decisions.push(Decision {
            timestamp: 1234567890,
            action: "test_action".to_string(),
            reason: "test reason".to_string(),
        });
        state.todos.push(Todo {
            id: "todo_1".to_string(),
            content: "test todo".to_string(),
            status: TodoStatus::Pending,
        });

        state.save_state()?;

        let loaded_state = AgentState::load_state()?;

        assert_eq!(loaded_state.decisions.len(), 1);
        assert_eq!(loaded_state.decisions[0].timestamp, 1234567890);
        assert_eq!(loaded_state.decisions[0].action, "test_action");
        assert_eq!(loaded_state.decisions[0].reason, "test reason");
        assert_eq!(loaded_state.todos.len(), 1);
        assert_eq!(loaded_state.todos[0].id, "todo_1");
        assert_eq!(loaded_state.todos[0].content, "test todo");
        assert_eq!(loaded_state.todos[0].status, TodoStatus::Pending);

        Ok(())
    }
}