1use crate::error::AgentError;
2use bincode::{deserialize, serialize};
3use limit_llm::types::Message;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fs;
7use std::path::PathBuf;
8use tracing::instrument;
9
/// Directory holding persisted agent state; a relative path, so it is
/// resolved against the process's current working directory.
const STATE_DIR: &str = ".limit";
/// File name (inside `STATE_DIR`) of the bincode-encoded agent state.
const STATE_FILE: &str = "agent-state.bin";
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Decision {
15 pub timestamp: u64,
16 pub action: String,
17 pub reason: String,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Todo {
22 pub id: String,
23 pub content: String,
24 pub status: TodoStatus,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
28pub enum TodoStatus {
29 Pending,
30 InProgress,
31 Done,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct AgentState {
36 pub messages: Vec<Message>,
37 pub tool_results: HashMap<String, serde_json::Value>,
38 pub decisions: Vec<Decision>,
39 pub todos: Vec<Todo>,
40 pub iteration: u32,
41}
42
43impl Default for AgentState {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl AgentState {
50 pub fn new() -> Self {
51 Self {
52 messages: Vec::new(),
53 tool_results: HashMap::new(),
54 decisions: Vec::new(),
55 todos: Vec::new(),
56 iteration: 0,
57 }
58 }
59
60 fn state_path() -> PathBuf {
61 PathBuf::from(STATE_DIR).join(STATE_FILE)
62 }
63
64 #[instrument(skip(self))]
65 pub fn save_state(&self) -> Result<(), AgentError> {
66 let path = Self::state_path();
67
68 if let Some(parent) = path.parent() {
69 fs::create_dir_all(parent)?;
70 }
71
72 let encoded = serialize(self)
73 .map_err(|e| AgentError::BincodeError(format!("Serialization failed: {:?}", e)))?;
74 fs::write(&path, encoded)?;
75
76 Ok(())
77 }
78
79 #[instrument]
80 pub fn load_state() -> Result<Self, AgentError> {
81 let path = Self::state_path();
82
83 if !path.exists() {
84 return Ok(Self::new());
85 }
86
87 let encoded = fs::read(&path)?;
88 let _state: AgentState = deserialize(&encoded)
89 .map_err(|e| AgentError::BincodeError(format!("Deserialization failed: {:?}", e)))?;
90
91 Ok(_state)
92 }
93
94 #[instrument(skip(self))]
95 pub fn check_loop_detection(
96 &mut self,
97 tool_name: &str,
98 _args: &serde_json::Value,
99 ) -> Result<(), AgentError> {
100 Ok(())
101 }
102
103 pub fn iteration(&self) -> u32 {
104 self.iteration
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture for message-related tests; not used by any test yet, hence the
    // `dead_code` allowance.
    #[allow(dead_code)]
    fn create_test_message() -> Message {
        Message {
            role: limit_llm::types::Role::User,
            content: Some("test message".to_string()),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    /// Arbitrary JSON arguments for tool-call tests.
    fn create_test_tool_args() -> serde_json::Value {
        serde_json::json!({"arg": "value"})
    }

    #[test]
    fn test_agent_state_default() {
        let state = AgentState::default();
        assert_eq!(state.iteration, 0);
        assert!(state.messages.is_empty());
        assert!(state.decisions.is_empty());
        assert!(state.todos.is_empty());
        assert!(state.tool_results.is_empty());
    }

    #[test]
    fn test_agent_state_new() {
        let state = AgentState::new();
        assert_eq!(state.iteration, 0);
        assert_eq!(state.iteration(), 0);
        assert!(state.messages.is_empty());
    }

    #[test]
    fn test_loop_detection() {
        let mut state = AgentState::new();
        let args = create_test_tool_args();

        // Repeated identical calls are currently always accepted.
        for _ in 0..10 {
            state.check_loop_detection("test_tool", &args).unwrap();
        }
    }

    #[test]
    fn test_save_and_load_state() -> Result<(), AgentError> {
        // State lives at a fixed path relative to the working directory, so
        // this test necessarily reuses it. Remove only the state file itself,
        // never the whole `.limit` directory, which may hold unrelated files.
        let path = AgentState::state_path();
        let _ = fs::remove_file(&path);

        let mut state = AgentState::new();

        state.decisions.push(Decision {
            timestamp: 1234567890,
            action: "test_action".to_string(),
            reason: "test reason".to_string(),
        });
        state.todos.push(Todo {
            id: "todo_1".to_string(),
            content: "test todo".to_string(),
            status: TodoStatus::Pending,
        });

        state.save_state()?;
        let loaded_state = AgentState::load_state()?;

        // Clean up before asserting so a failed assertion leaves no file behind.
        let _ = fs::remove_file(&path);

        assert_eq!(loaded_state.iteration, 0);

        assert_eq!(loaded_state.decisions.len(), 1);
        let decision = &loaded_state.decisions[0];
        assert_eq!(decision.timestamp, 1234567890);
        assert_eq!(decision.action, "test_action");
        assert_eq!(decision.reason, "test reason");

        assert_eq!(loaded_state.todos.len(), 1);
        let todo = &loaded_state.todos[0];
        assert_eq!(todo.id, "todo_1");
        assert_eq!(todo.content, "test todo");
        assert_eq!(todo.status, TodoStatus::Pending);

        Ok(())
    }
}