Skip to main content

limit_agent/
state.rs

1use crate::error::AgentError;
2use bincode::{deserialize, serialize};
3use limit_llm::types::Message;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fs;
7use std::path::PathBuf;
8use tracing::instrument;
9
/// Directory component of the persisted state path; `state_path` builds a
/// relative path from it, so it resolves against the process's working directory.
const STATE_DIR: &str = ".limit";
/// File name, inside `STATE_DIR`, of the bincode-encoded agent state.
const STATE_FILE: &str = "agent-state.bin";
12
/// A recorded agent decision: what was done, why, and when.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Decision {
    /// When the decision was made.
    // NOTE(review): unit/epoch is not established in this file (the test value
    // looks like Unix seconds) — confirm against the code that fills this in.
    pub timestamp: u64,
    /// The action that was chosen.
    pub action: String,
    /// The justification recorded for the action.
    pub reason: String,
}
19
/// A single to-do entry tracked in the agent state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Todo {
    /// Identifier of the entry (presumably unique within the list — confirm with callers).
    pub id: String,
    /// Free-form description of the work item.
    pub content: String,
    /// Current progress of the entry.
    pub status: TodoStatus,
}
26
/// Progress marker for a [`Todo`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TodoStatus {
    /// Not started yet.
    Pending,
    /// Currently being worked on.
    InProgress,
    /// Finished.
    Done,
}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct AgentState {
36    pub messages: Vec<Message>,
37    pub tool_results: HashMap<String, serde_json::Value>,
38    pub decisions: Vec<Decision>,
39    pub todos: Vec<Todo>,
40    pub iteration: u32,
41}
42
43impl Default for AgentState {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl AgentState {
50    pub fn new() -> Self {
51        Self {
52            messages: Vec::new(),
53            tool_results: HashMap::new(),
54            decisions: Vec::new(),
55            todos: Vec::new(),
56            iteration: 0,
57        }
58    }
59
60    fn state_path() -> PathBuf {
61        PathBuf::from(STATE_DIR).join(STATE_FILE)
62    }
63
64    #[instrument(skip(self))]
65    pub fn save_state(&self) -> Result<(), AgentError> {
66        let path = Self::state_path();
67
68        if let Some(parent) = path.parent() {
69            fs::create_dir_all(parent)?;
70        }
71
72        let encoded = serialize(self)
73            .map_err(|e| AgentError::BincodeError(format!("Serialization failed: {:?}", e)))?;
74        fs::write(&path, encoded)?;
75
76        Ok(())
77    }
78
79    #[instrument]
80    pub fn load_state() -> Result<Self, AgentError> {
81        let path = Self::state_path();
82
83        if !path.exists() {
84            return Ok(Self::new());
85        }
86
87        let encoded = fs::read(&path)?;
88        let _state: AgentState = deserialize(&encoded)
89            .map_err(|e| AgentError::BincodeError(format!("Deserialization failed: {:?}", e)))?;
90
91        Ok(_state)
92    }
93
94    #[instrument(skip(self))]
95    pub fn check_loop_detection(
96        &mut self,
97        tool_name: &str,
98        _args: &serde_json::Value,
99    ) -> Result<(), AgentError> {
100        Ok(())
101    }
102
103    pub fn iteration(&self) -> u32 {
104        self.iteration
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    // Not used by any test yet; kept as a ready-made fixture for message-related tests.
    #[allow(dead_code)]
    fn create_test_message() -> Message {
        Message {
            role: limit_llm::types::Role::User,
            content: Some("test message".to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }
    }

    /// Minimal JSON arguments for exercising tool-related APIs.
    fn create_test_tool_args() -> serde_json::Value {
        serde_json::json!({"arg": "value"})
    }

    #[test]
    fn test_agent_state_default() {
        let state = AgentState::default();
        assert_eq!(state.iteration, 0);
        assert!(state.messages.is_empty());
        assert!(state.decisions.is_empty());
        assert!(state.todos.is_empty());
        assert!(state.tool_results.is_empty());
    }

    #[test]
    fn test_agent_state_new() {
        let state = AgentState::new();
        assert_eq!(state.iteration, 0);
        assert!(state.messages.is_empty());
    }

    #[test]
    fn test_loop_detection() {
        let mut state = AgentState::new();
        let args = create_test_tool_args();

        // Loop detection now always returns Ok
        for _ in 0..10 {
            state.check_loop_detection("test_tool", &args).unwrap();
        }
    }

    #[test]
    fn test_save_and_load_state() -> Result<(), AgentError> {
        // Start from a clean slate so state left behind by an earlier run cannot
        // leak into this test. Removing the directory also removes the state file.
        let _ = fs::remove_dir_all(STATE_DIR);

        let mut state = AgentState::new();

        state.decisions.push(Decision {
            timestamp: 1234567890,
            action: "test_action".to_string(),
            reason: "test reason".to_string(),
        });
        state.todos.push(Todo {
            id: "todo_1".to_string(),
            content: "test todo".to_string(),
            status: TodoStatus::Pending,
        });

        let round_trip = state.save_state().and_then(|()| AgentState::load_state());

        // Clean up before asserting so a failure does not leave files behind.
        let _ = fs::remove_dir_all(STATE_DIR);

        let loaded_state = round_trip?;

        assert_eq!(loaded_state.iteration, 0);
        assert!(loaded_state.messages.is_empty());
        assert!(loaded_state.tool_results.is_empty());

        assert_eq!(loaded_state.decisions.len(), 1);
        assert_eq!(loaded_state.decisions[0].timestamp, 1234567890);
        assert_eq!(loaded_state.decisions[0].action, "test_action");
        assert_eq!(loaded_state.decisions[0].reason, "test reason");

        assert_eq!(loaded_state.todos.len(), 1);
        assert_eq!(loaded_state.todos[0].id, "todo_1");
        assert_eq!(loaded_state.todos[0].content, "test todo");
        assert_eq!(loaded_state.todos[0].status, TodoStatus::Pending);

        Ok(())
    }
}