1use crate::error::AgentError;
2use limit_llm::types::Message;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fs;
6use std::path::PathBuf;
7use tracing::instrument;
8
/// Directory holding persisted agent state. The path is relative, so it resolves
/// against the process's current working directory.
const STATE_DIR: &str = ".limit";
/// Name of the JSON state file inside [`STATE_DIR`].
const STATE_FILE: &str = "agent-state.json";
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Decision {
14 pub timestamp: u64,
15 pub action: String,
16 pub reason: String,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct Todo {
21 pub id: String,
22 pub content: String,
23 pub status: TodoStatus,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27pub enum TodoStatus {
28 Pending,
29 InProgress,
30 Done,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct AgentState {
35 pub messages: Vec<Message>,
36 pub tool_results: HashMap<String, serde_json::Value>,
37 pub decisions: Vec<Decision>,
38 pub todos: Vec<Todo>,
39 pub iteration: u32,
40}
41
42impl Default for AgentState {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl AgentState {
49 pub fn new() -> Self {
50 Self {
51 messages: Vec::new(),
52 tool_results: HashMap::new(),
53 decisions: Vec::new(),
54 todos: Vec::new(),
55 iteration: 0,
56 }
57 }
58
59 fn state_path() -> PathBuf {
60 PathBuf::from(STATE_DIR).join(STATE_FILE)
61 }
62
63 #[instrument(skip(self))]
64 pub fn save_state(&self) -> Result<(), AgentError> {
65 let path = Self::state_path();
66
67 if let Some(parent) = path.parent() {
68 fs::create_dir_all(parent)?;
69 }
70
71 let encoded = serde_json::to_string_pretty(self).map_err(|e| {
72 AgentError::SerializationError(format!("Serialization failed: {:?}", e))
73 })?;
74 fs::write(&path, encoded)?;
75
76 Ok(())
77 }
78
79 #[instrument]
80 pub fn load_state() -> Result<Self, AgentError> {
81 let path = Self::state_path();
82
83 if !path.exists() {
84 return Ok(Self::new());
85 }
86
87 let encoded = fs::read_to_string(&path)?;
88 let _state: AgentState = serde_json::from_str(&encoded).map_err(|e| {
89 AgentError::SerializationError(format!("Deserialization failed: {:?}", e))
90 })?;
91
92 Ok(_state)
93 }
94
95 #[instrument(skip(self))]
96 pub fn check_loop_detection(
97 &mut self,
98 tool_name: &str,
99 _args: &serde_json::Value,
100 ) -> Result<(), AgentError> {
101 Ok(())
102 }
103
104 pub fn iteration(&self) -> u32 {
105 self.iteration
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal user message for persistence round-trip tests.
    fn create_test_message() -> Message {
        Message {
            role: limit_llm::types::Role::User,
            content: Some("test message".to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }
    }

    fn create_test_tool_args() -> serde_json::Value {
        serde_json::json!({"arg": "value"})
    }

    /// Snapshots the real state file and restores it on drop, even if an assertion
    /// panics first.
    ///
    /// The state path is relative to the working directory, so these tests operate on
    /// the same file a developer's local runs use. This guard keeps a test run from
    /// destroying that state.
    struct RestoreStateFile {
        path: PathBuf,
        original: Option<Vec<u8>>,
        /// State directory that the test itself caused to be created, if any.
        created_dir: Option<PathBuf>,
    }

    impl RestoreStateFile {
        fn capture(path: PathBuf) -> Self {
            let original = fs::read(&path).ok();
            let created_dir = path
                .parent()
                .filter(|dir| !dir.as_os_str().is_empty() && !dir.exists())
                .map(PathBuf::from);
            Self {
                path,
                original,
                created_dir,
            }
        }
    }

    impl Drop for RestoreStateFile {
        fn drop(&mut self) {
            // Best-effort: a cleanup failure must not mask the test's own result.
            let _ = match &self.original {
                Some(bytes) => fs::write(&self.path, bytes),
                None => fs::remove_file(&self.path),
            };
            if let Some(dir) = &self.created_dir {
                // `remove_dir` only succeeds on an empty directory, so nothing else is lost.
                let _ = fs::remove_dir(dir);
            }
        }
    }

    #[test]
    fn test_agent_state_default() {
        let state = AgentState::default();
        assert_eq!(state.iteration, 0);
        assert!(state.messages.is_empty());
        assert!(state.decisions.is_empty());
        assert!(state.todos.is_empty());
        assert!(state.tool_results.is_empty());
    }

    #[test]
    fn test_agent_state_new() {
        let state = AgentState::new();
        assert_eq!(state.iteration, 0);
        assert!(state.messages.is_empty());
    }

    /// Repeated identical calls must not error (the check is currently a stub).
    #[test]
    fn test_loop_detection() {
        let mut state = AgentState::new();
        let args = create_test_tool_args();

        for _ in 0..10 {
            state.check_loop_detection("test_tool", &args).unwrap();
        }
    }

    #[test]
    fn test_save_and_load_state() -> Result<(), AgentError> {
        let _restore = RestoreStateFile::capture(AgentState::state_path());

        let mut state = AgentState::new();
        state.iteration = 3;
        state.messages.push(create_test_message());
        state
            .tool_results
            .insert("call_1".to_string(), serde_json::json!({"ok": true}));
        state.decisions.push(Decision {
            timestamp: 1234567890,
            action: "test_action".to_string(),
            reason: "test reason".to_string(),
        });
        state.todos.push(Todo {
            id: "todo_1".to_string(),
            content: "test todo".to_string(),
            status: TodoStatus::Pending,
        });

        state.save_state()?;

        let loaded_state = AgentState::load_state()?;

        assert_eq!(loaded_state.iteration(), 3);

        assert_eq!(loaded_state.messages.len(), 1);
        assert_eq!(
            loaded_state.messages[0].content.as_deref(),
            Some("test message")
        );

        assert_eq!(
            loaded_state.tool_results.get("call_1"),
            Some(&serde_json::json!({"ok": true}))
        );

        assert_eq!(loaded_state.decisions.len(), 1);
        assert_eq!(loaded_state.decisions[0].timestamp, 1234567890);
        assert_eq!(loaded_state.decisions[0].action, "test_action");
        assert_eq!(loaded_state.decisions[0].reason, "test reason");

        assert_eq!(loaded_state.todos.len(), 1);
        assert_eq!(loaded_state.todos[0].id, "todo_1");
        assert_eq!(loaded_state.todos[0].content, "test todo");
        assert_eq!(loaded_state.todos[0].status, TodoStatus::Pending);

        Ok(())
    }
}