1use crate::hitl::AgentInterrupt;
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4
/// Snapshot of an agent's working state: its todo list, stored files, free-form
/// scratchpad values, and any interrupts that have not been cleared yet.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct AgentStateSnapshot {
    /// Current todo list.
    pub todos: Vec<TodoItem>,
    /// File contents keyed by file name (e.g. `"file1.txt"` -> its text).
    /// NOTE(review): looks like an in-memory file store rather than disk — confirm with callers.
    pub files: BTreeMap<String, String>,
    /// Arbitrary JSON values keyed by name.
    pub scratchpad: BTreeMap<String, serde_json::Value>,

    /// Interrupts that have been raised but not yet cleared.
    ///
    /// Omitted from serialized output when empty, and defaults to an empty list when the
    /// field is absent during deserialization, so snapshots written without it still load.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub pending_interrupts: Vec<AgentInterrupt>,
}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct TodoItem {
19 pub content: String,
20 pub status: TodoStatus,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(rename_all = "snake_case")]
25pub enum TodoStatus {
26 Pending,
27 InProgress,
28 Completed,
29}
30
31impl TodoItem {
32 pub fn pending(content: impl Into<String>) -> Self {
33 Self {
34 content: content.into(),
35 status: TodoStatus::Pending,
36 }
37 }
38}
39
40impl AgentStateSnapshot {
41 pub fn add_interrupt(&mut self, interrupt: AgentInterrupt) {
43 self.pending_interrupts.push(interrupt);
44 }
45
46 pub fn clear_interrupts(&mut self) {
48 self.pending_interrupts.clear();
49 }
50
51 pub fn has_pending_interrupts(&self) -> bool {
53 !self.pending_interrupts.is_empty()
54 }
55
56 pub fn merge(&mut self, other: AgentStateSnapshot) {
58 self.files.extend(other.files);
60
61 if !other.todos.is_empty() {
63 self.todos = other.todos;
64 }
65
66 self.scratchpad.extend(other.scratchpad);
68
69 if !other.pending_interrupts.is_empty() {
71 self.pending_interrupts = other.pending_interrupts;
72 }
73 }
74
75 pub fn reduce_files(
78 left: Option<BTreeMap<String, String>>,
79 right: Option<BTreeMap<String, String>>,
80 ) -> Option<BTreeMap<String, String>> {
81 match (left, right) {
82 (None, None) => None,
83 (Some(l), None) => Some(l),
84 (None, Some(r)) => Some(r),
85 (Some(mut l), Some(r)) => {
86 l.extend(r); Some(l)
88 }
89 }
90 }
91
92 pub fn with_merged_files(&self, new_files: Option<BTreeMap<String, String>>) -> Self {
94 let mut result = self.clone();
95 if let Some(files) = new_files {
96 result.files.extend(files);
97 }
98 result
99 }
100
101 pub fn with_updated_todos(&self, new_todos: Vec<TodoItem>) -> Self {
102 if new_todos.is_empty() {
103 self.clone()
104 } else {
105 let mut result = self.clone();
106 result.todos = new_todos;
107 result
108 }
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hitl::{AgentInterrupt, HitlInterrupt};
    use serde_json::json;

    /// Builds a human-in-the-loop interrupt with empty arguments and no note.
    fn hitl_interrupt(tool: &'static str, call_id: &'static str) -> AgentInterrupt {
        AgentInterrupt::HumanInLoop(HitlInterrupt::new(tool, json!({}), call_id, None))
    }

    #[test]
    fn test_file_reducer_both_none() {
        let result = AgentStateSnapshot::reduce_files(None, None);
        assert!(result.is_none());
    }

    #[test]
    fn test_file_reducer_left_some_right_none() {
        let mut left = BTreeMap::new();
        left.insert("file1.txt".to_string(), "content1".to_string());

        let result = AgentStateSnapshot::reduce_files(Some(left.clone()), None);
        assert_eq!(result, Some(left));
    }

    #[test]
    fn test_file_reducer_left_none_right_some() {
        let mut right = BTreeMap::new();
        right.insert("file2.txt".to_string(), "content2".to_string());

        let result = AgentStateSnapshot::reduce_files(None, Some(right.clone()));
        assert_eq!(result, Some(right));
    }

    #[test]
    fn test_file_reducer_both_some_merges() {
        let mut left = BTreeMap::new();
        left.insert("file1.txt".to_string(), "content1".to_string());
        left.insert("shared.txt".to_string(), "old_content".to_string());

        let mut right = BTreeMap::new();
        right.insert("file2.txt".to_string(), "content2".to_string());
        right.insert("shared.txt".to_string(), "new_content".to_string());

        let result = AgentStateSnapshot::reduce_files(Some(left), Some(right)).unwrap();

        assert_eq!(result.get("file1.txt").unwrap(), "content1");
        assert_eq!(result.get("file2.txt").unwrap(), "content2");
        assert_eq!(result.get("shared.txt").unwrap(), "new_content");
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_merge_combines_states() {
        let mut state1 = AgentStateSnapshot::default();
        state1
            .files
            .insert("file1.txt".to_string(), "content1".to_string());
        state1.todos.push(TodoItem::pending("task1"));
        state1
            .scratchpad
            .insert("key1".to_string(), json!("value1"));

        let mut state2 = AgentStateSnapshot::default();
        state2
            .files
            .insert("file2.txt".to_string(), "content2".to_string());
        state2.todos.push(TodoItem::pending("task2"));
        state2
            .scratchpad
            .insert("key2".to_string(), json!("value2"));

        let mut merged = state1.clone();
        merged.merge(state2);

        // Maps are unioned.
        assert_eq!(merged.files.len(), 2);
        assert_eq!(merged.files.get("file1.txt").unwrap(), "content1");
        assert_eq!(merged.files.get("file2.txt").unwrap(), "content2");

        // A non-empty incoming todo list replaces the existing one.
        assert_eq!(merged.todos.len(), 1);
        assert_eq!(merged.todos[0].content, "task2");

        assert_eq!(merged.scratchpad.len(), 2);
        assert_eq!(merged.scratchpad.get("key1").unwrap(), "value1");
        assert_eq!(merged.scratchpad.get("key2").unwrap(), "value2");
    }

    #[test]
    fn test_merge_overwrites_colliding_keys() {
        let mut state1 = AgentStateSnapshot::default();
        state1.files.insert("shared.txt".to_string(), "old".to_string());
        state1.scratchpad.insert("key".to_string(), json!("old"));

        let mut state2 = AgentStateSnapshot::default();
        state2.files.insert("shared.txt".to_string(), "new".to_string());
        state2.scratchpad.insert("key".to_string(), json!("new"));

        state1.merge(state2);

        assert_eq!(state1.files.len(), 1);
        assert_eq!(state1.files.get("shared.txt").unwrap(), "new");
        assert_eq!(state1.scratchpad.len(), 1);
        assert_eq!(state1.scratchpad.get("key").unwrap(), "new");
    }

    #[test]
    fn test_merge_empty_todos_preserves_existing() {
        let mut state1 = AgentStateSnapshot::default();
        state1.todos.push(TodoItem::pending("task1"));

        let state2 = AgentStateSnapshot::default();
        let mut merged = state1.clone();
        merged.merge(state2);

        assert_eq!(merged.todos.len(), 1);
        assert_eq!(merged.todos[0].content, "task1");
        assert!(matches!(merged.todos[0].status, TodoStatus::Pending));
    }

    #[test]
    fn test_with_merged_files() {
        let mut state = AgentStateSnapshot::default();
        state
            .files
            .insert("existing.txt".to_string(), "existing".to_string());

        let mut new_files = BTreeMap::new();
        new_files.insert("new.txt".to_string(), "new_content".to_string());
        new_files.insert("existing.txt".to_string(), "updated".to_string());

        let result = state.with_merged_files(Some(new_files));

        assert_eq!(result.files.len(), 2);
        assert_eq!(result.files.get("existing.txt").unwrap(), "updated");
        assert_eq!(result.files.get("new.txt").unwrap(), "new_content");
        // The receiver is left untouched.
        assert_eq!(state.files.get("existing.txt").unwrap(), "existing");
    }

    #[test]
    fn test_with_merged_files_none_keeps_files() {
        let mut state = AgentStateSnapshot::default();
        state
            .files
            .insert("existing.txt".to_string(), "existing".to_string());

        let result = state.with_merged_files(None);

        assert_eq!(result.files, state.files);
    }

    #[test]
    fn test_with_updated_todos() {
        let mut state = AgentStateSnapshot::default();
        state.todos.push(TodoItem::pending("old_task"));

        let new_todos = vec![
            TodoItem::pending("new_task1"),
            TodoItem::pending("new_task2"),
        ];

        let result = state.with_updated_todos(new_todos);

        assert_eq!(result.todos.len(), 2);
        assert_eq!(result.todos[0].content, "new_task1");
        assert_eq!(result.todos[1].content, "new_task2");
        // The receiver is left untouched.
        assert_eq!(state.todos.len(), 1);
        assert_eq!(state.todos[0].content, "old_task");
    }

    #[test]
    fn test_with_updated_todos_empty_preserves_existing() {
        let mut state = AgentStateSnapshot::default();
        state.todos.push(TodoItem::pending("existing_task"));

        let result = state.with_updated_todos(vec![]);

        assert_eq!(result.todos.len(), 1);
        assert_eq!(result.todos[0].content, "existing_task");
    }

    #[test]
    fn test_todo_status_serializes_snake_case() {
        assert_eq!(
            serde_json::to_string(&TodoStatus::Pending).unwrap(),
            "\"pending\""
        );
        assert_eq!(
            serde_json::to_string(&TodoStatus::InProgress).unwrap(),
            "\"in_progress\""
        );
        assert_eq!(
            serde_json::to_string(&TodoStatus::Completed).unwrap(),
            "\"completed\""
        );
    }

    #[test]
    fn test_add_interrupt() {
        let mut state = AgentStateSnapshot::default();
        assert!(!state.has_pending_interrupts());

        let interrupt = AgentInterrupt::HumanInLoop(HitlInterrupt::new(
            "test_tool",
            json!({"arg": "value"}),
            "call_123",
            Some("Test note".to_string()),
        ));

        state.add_interrupt(interrupt);

        assert!(state.has_pending_interrupts());
        assert_eq!(state.pending_interrupts.len(), 1);
    }

    #[test]
    fn test_clear_interrupts() {
        let mut state = AgentStateSnapshot::default();

        state.add_interrupt(hitl_interrupt("test_tool", "call_123"));
        assert!(state.has_pending_interrupts());

        state.clear_interrupts();
        assert!(!state.has_pending_interrupts());
        assert_eq!(state.pending_interrupts.len(), 0);
    }

    #[test]
    fn test_multiple_interrupts() {
        let mut state = AgentStateSnapshot::default();

        state.add_interrupt(hitl_interrupt("tool1", "call_1"));
        state.add_interrupt(hitl_interrupt("tool2", "call_2"));

        assert_eq!(state.pending_interrupts.len(), 2);
    }

    #[test]
    fn test_merge_with_interrupts() {
        let mut state1 = AgentStateSnapshot::default();
        state1.add_interrupt(hitl_interrupt("tool1", "call_1"));

        let mut state2 = AgentStateSnapshot::default();
        state2.add_interrupt(hitl_interrupt("tool2", "call_2"));

        state1.merge(state2);

        // `merge` replaces a non-empty interrupt list rather than appending to it,
        // and the survivor must be the incoming interrupt, not the original one.
        assert_eq!(state1.pending_interrupts.len(), 1);
        let json = serde_json::to_string(&state1.pending_interrupts).unwrap();
        assert!(json.contains("tool2"));
        assert!(!json.contains("tool1"));
    }

    #[test]
    fn test_merge_empty_interrupts_preserves_existing() {
        let mut state1 = AgentStateSnapshot::default();
        state1.add_interrupt(hitl_interrupt("tool1", "call_1"));

        let state2 = AgentStateSnapshot::default();
        state1.merge(state2);

        assert_eq!(state1.pending_interrupts.len(), 1);
    }

    #[test]
    fn test_state_serialization_with_interrupts() {
        let mut state = AgentStateSnapshot::default();
        let interrupt = AgentInterrupt::HumanInLoop(HitlInterrupt::new(
            "test_tool",
            json!({"arg": "value"}),
            "call_123",
            Some("Test note".to_string()),
        ));
        state.add_interrupt(interrupt);

        let json = serde_json::to_string(&state).unwrap();
        assert!(json.contains("pending_interrupts"));
        assert!(json.contains("test_tool"));

        let deserialized: AgentStateSnapshot = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.pending_interrupts.len(), 1);
    }

    #[test]
    fn test_state_serialization_without_interrupts() {
        let state = AgentStateSnapshot::default();

        let json = serde_json::to_string(&state).unwrap();

        assert!(!json.contains("pending_interrupts"));
    }

    #[test]
    fn test_state_deserializes_without_interrupts_field() {
        // Snapshots serialized with no pending interrupts omit the field entirely,
        // so they must load back with an empty list.
        let state: AgentStateSnapshot =
            serde_json::from_str(r#"{"todos":[],"files":{},"scratchpad":{}}"#).unwrap();

        assert!(!state.has_pending_interrupts());
    }
}