1use crate::error::AgentError;
2use bincode::{deserialize, serialize};
3use limit_llm::types::Message;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fs;
7use std::path::PathBuf;
8use tracing::instrument;
9
/// Directory, relative to the process's current working directory, that holds
/// persisted agent state.
const STATE_DIR: &str = ".limit";
/// File name of the bincode-encoded [`AgentState`] inside [`STATE_DIR`].
const STATE_FILE: &str = "agent-state.bin";
12
/// An entry in [`AgentState::decisions`]: a timestamp, an action string and the
/// reason string recorded for it.
///
/// Persisted with bincode as part of [`AgentState`]. Bincode encodes struct fields
/// positionally, so reordering or retyping these fields changes the on-disk format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Decision {
    /// NOTE(review): presumably a Unix timestamp; the unit (seconds vs. milliseconds)
    /// is not established in this file — confirm against writers.
    pub timestamp: u64,
    pub action: String,
    pub reason: String,
}
19
/// An entry in [`AgentState::todos`]: an identifier, free-text content and its
/// [`TodoStatus`].
///
/// Persisted with bincode as part of [`AgentState`]; field order is part of the
/// on-disk format, so do not reorder or retype fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Todo {
    pub id: String,
    pub content: String,
    pub status: TodoStatus,
}
26
/// Progress state of a [`Todo`].
///
/// Bincode encodes enum variants by their index, so variant order is part of the
/// persisted format: append new variants at the end rather than inserting or
/// reordering them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TodoStatus {
    Pending,
    InProgress,
    Done,
}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct AgentState {
36 pub messages: Vec<Message>,
37 pub tool_results: HashMap<String, serde_json::Value>,
38 pub decisions: Vec<Decision>,
39 pub todos: Vec<Todo>,
40 pub iteration: u32,
41}
42
43impl Default for AgentState {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl AgentState {
50 pub fn new() -> Self {
51 Self {
52 messages: Vec::new(),
53 tool_results: HashMap::new(),
54 decisions: Vec::new(),
55 todos: Vec::new(),
56 iteration: 0,
57 }
58 }
59
60 fn state_path() -> PathBuf {
61 PathBuf::from(STATE_DIR).join(STATE_FILE)
62 }
63
64 #[instrument(skip(self))]
65 pub fn save_state(&self) -> Result<(), AgentError> {
66 let path = Self::state_path();
67
68 if let Some(parent) = path.parent() {
69 fs::create_dir_all(parent)?;
70 }
71
72 let encoded = serialize(self)
73 .map_err(|e| AgentError::BincodeError(format!("Serialization failed: {:?}", e)))?;
74 fs::write(&path, encoded)?;
75
76 Ok(())
77 }
78
79 #[instrument]
80 pub fn load_state() -> Result<Self, AgentError> {
81 let path = Self::state_path();
82
83 if !path.exists() {
84 return Ok(Self::new());
85 }
86
87 let encoded = fs::read(&path)?;
88 let _state: AgentState = deserialize(&encoded)
89 .map_err(|e| AgentError::BincodeError(format!("Deserialization failed: {:?}", e)))?;
90
91 Ok(_state)
92 }
93
94 #[instrument(skip(self))]
95 pub fn check_loop_detection(
96 &mut self,
97 tool_name: &str,
98 _args: &serde_json::Value,
99 ) -> Result<(), AgentError> {
100 Ok(())
101 }
102
103 pub fn iteration(&self) -> u32 {
104 self.iteration
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture for message-based tests; no test uses it yet, hence the allow.
    #[allow(dead_code)]
    fn create_test_message() -> Message {
        Message {
            role: limit_llm::types::Role::User,
            content: Some("test message".to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }
    }

    fn create_test_tool_args() -> serde_json::Value {
        serde_json::json!({"arg": "value"})
    }

    #[test]
    fn test_agent_state_default() {
        let state = AgentState::default();
        assert_eq!(state.iteration, 0);
        assert!(state.messages.is_empty());
        assert!(state.decisions.is_empty());
        assert!(state.todos.is_empty());
        assert!(state.tool_results.is_empty());
    }

    #[test]
    fn test_agent_state_new() {
        let state = AgentState::new();
        assert_eq!(state.iteration, 0);
        assert!(state.messages.is_empty());
    }

    #[test]
    fn test_loop_detection() {
        let mut state = AgentState::new();
        let args = create_test_tool_args();

        // `check_loop_detection` is currently a no-op, so repeated identical calls
        // must all succeed.
        for _ in 0..10 {
            state.check_loop_detection("test_tool", &args).unwrap();
        }
    }

    #[test]
    fn test_save_and_load_state() -> Result<(), AgentError> {
        // This round-trips through the real state path (relative to the current
        // working directory). Remember any existing state file and restore it
        // afterwards, rather than deleting the whole state directory.
        let path = AgentState::state_path();
        let previous = fs::read(&path).ok();

        let mut state = AgentState::new();
        state.decisions.push(Decision {
            timestamp: 1234567890,
            action: "test_action".to_string(),
            reason: "test reason".to_string(),
        });
        state.todos.push(Todo {
            id: "todo_1".to_string(),
            content: "test todo".to_string(),
            status: TodoStatus::Pending,
        });

        let round_trip = state.save_state().and_then(|()| AgentState::load_state());

        // Restore before asserting so a failing assertion cannot leave test data behind.
        match previous {
            Some(bytes) => fs::write(&path, bytes)?,
            None => {
                let _ = fs::remove_file(&path);
            }
        }

        let loaded_state = round_trip?;
        assert_eq!(loaded_state.iteration, 0);
        assert_eq!(loaded_state.decisions.len(), 1);
        assert_eq!(loaded_state.decisions[0].timestamp, 1234567890);
        assert_eq!(loaded_state.decisions[0].action, "test_action");
        assert_eq!(loaded_state.decisions[0].reason, "test reason");
        assert_eq!(loaded_state.todos.len(), 1);
        assert_eq!(loaded_state.todos[0].id, "todo_1");
        assert_eq!(loaded_state.todos[0].content, "test todo");
        assert_eq!(loaded_state.todos[0].status, TodoStatus::Pending);

        Ok(())
    }
}