Skip to main content

limit_agent/
state.rs

1use crate::error::AgentError;
2use bincode::{deserialize, serialize};
3use limit_llm::types::Message;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fs;
7use std::path::PathBuf;
8use tracing::instrument;
9
/// Directory holding the persisted agent state. It is a relative path, so it is
/// resolved against the process's current working directory.
const STATE_DIR: &str = ".limit";
/// File name of the bincode-encoded state stored inside `STATE_DIR`.
const STATE_FILE: &str = "agent-state.bin";
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Decision {
15    pub timestamp: u64,
16    pub action: String,
17    pub reason: String,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Todo {
22    pub id: String,
23    pub content: String,
24    pub status: TodoStatus,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
28pub enum TodoStatus {
29    Pending,
30    InProgress,
31    Done,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct AgentState {
36    pub messages: Vec<Message>,
37    pub tool_results: HashMap<String, serde_json::Value>,
38    pub decisions: Vec<Decision>,
39    pub todos: Vec<Todo>,
40    pub iteration: u32,
41}
42
43impl Default for AgentState {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl AgentState {
50    pub fn new() -> Self {
51        Self {
52            messages: Vec::new(),
53            tool_results: HashMap::new(),
54            decisions: Vec::new(),
55            todos: Vec::new(),
56            iteration: 0,
57        }
58    }
59
60    fn state_path() -> PathBuf {
61        PathBuf::from(STATE_DIR).join(STATE_FILE)
62    }
63
64    #[instrument(skip(self))]
65    pub fn save_state(&self) -> Result<(), AgentError> {
66        let path = Self::state_path();
67
68        if let Some(parent) = path.parent() {
69            fs::create_dir_all(parent)?;
70        }
71
72        let encoded = serialize(self)
73            .map_err(|e| AgentError::BincodeError(format!("Serialization failed: {:?}", e)))?;
74        fs::write(&path, encoded)?;
75
76        Ok(())
77    }
78
79    #[instrument]
80    pub fn load_state() -> Result<Self, AgentError> {
81        let path = Self::state_path();
82
83        if !path.exists() {
84            return Ok(Self::new());
85        }
86
87        let encoded = fs::read(&path)?;
88        let _state: AgentState = deserialize(&encoded)
89            .map_err(|e| AgentError::BincodeError(format!("Deserialization failed: {:?}", e)))?;
90
91        Ok(_state)
92    }
93
94    #[instrument(skip(self))]
95    pub fn check_loop_detection(
96        &mut self,
97        tool_name: &str,
98        _args: &serde_json::Value,
99    ) -> Result<(), AgentError> {
100        Ok(())
101    }
102
103    pub fn iteration(&self) -> u32 {
104        self.iteration
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal user message fixture.
    // Not called by any test in this module at the moment; the `allow` keeps the
    // unused fixture from producing a warning.
    #[allow(dead_code)]
    fn create_test_message() -> Message {
        Message {
            role: limit_llm::types::Role::User,
            content: Some("test message".to_string()),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    /// Builds a small JSON object to pass as tool arguments.
    fn create_test_tool_args() -> serde_json::Value {
        serde_json::json!({"arg": "value"})
    }

    #[test]
    fn test_agent_state_default() {
        let state = AgentState::default();
        assert_eq!(state.iteration, 0);
        assert!(state.messages.is_empty());
        assert!(state.decisions.is_empty());
        assert!(state.todos.is_empty());
        assert!(state.tool_results.is_empty());
    }

    #[test]
    fn test_agent_state_new() {
        let state = AgentState::new();
        assert_eq!(state.iteration, 0);
        assert!(state.messages.is_empty());
    }

    #[test]
    fn test_loop_detection() {
        let mut state = AgentState::new();
        let args = create_test_tool_args();

        // Loop detection now always returns Ok
        for _ in 0..10 {
            state.check_loop_detection("test_tool", &args).unwrap();
        }
    }

    #[test]
    fn test_save_and_load_state() -> Result<(), AgentError> {
        // The state path is fixed relative to the working directory, so start from
        // a clean slate. Removing the directory also removes any stale state file.
        let _ = fs::remove_dir_all(STATE_DIR);

        let mut state = AgentState::new();
        state.iteration = 3;
        state.decisions.push(Decision {
            timestamp: 1234567890,
            action: "test_action".to_string(),
            reason: "test reason".to_string(),
        });
        state.todos.push(Todo {
            id: "todo_1".to_string(),
            content: "test todo".to_string(),
            status: TodoStatus::Pending,
        });

        state.save_state()?;
        let loaded = AgentState::load_state();

        // Clean up before asserting so a failing assertion leaves no artifacts behind.
        let _ = fs::remove_dir_all(STATE_DIR);
        let loaded_state = loaded?;

        // Check the actual contents, not just the element counts.
        assert_eq!(loaded_state.iteration, 3);
        assert!(loaded_state.messages.is_empty());
        assert!(loaded_state.tool_results.is_empty());

        assert_eq!(loaded_state.decisions.len(), 1);
        let decision = &loaded_state.decisions[0];
        assert_eq!(decision.timestamp, 1234567890);
        assert_eq!(decision.action, "test_action");
        assert_eq!(decision.reason, "test reason");

        assert_eq!(loaded_state.todos.len(), 1);
        let todo = &loaded_state.todos[0];
        assert_eq!(todo.id, "todo_1");
        assert_eq!(todo.content, "test todo");
        assert_eq!(todo.status, TodoStatus::Pending);

        Ok(())
    }
}