matrixcode_core/agent/core/
state.rs1use std::collections::{HashMap, HashSet};
16use std::sync::atomic::{AtomicU64, Ordering};
17
18use crate::providers::{ContentBlock, Message, MessageContent, Role, Usage};
19use crate::tools::ReadHistoryTracker;
20
21pub struct AgentState {
26 messages: Vec<Message>,
28
29 total_input_tokens: AtomicU64,
31
32 total_output_tokens: AtomicU64,
34
35 last_input_tokens: AtomicU64,
37
38 previewed_tool_inputs: HashSet<String>,
41
42 todo_reminder_count: HashMap<String, usize>,
45
46 read_history: ReadHistoryTracker,
49
50 pending_inputs: Vec<String>,
52}
53
54impl AgentState {
55 pub fn new() -> Self {
57 Self {
58 messages: Vec::new(),
59 total_input_tokens: AtomicU64::new(0),
60 total_output_tokens: AtomicU64::new(0),
61 last_input_tokens: AtomicU64::new(0),
62 previewed_tool_inputs: HashSet::new(),
63 todo_reminder_count: HashMap::new(),
64 read_history: ReadHistoryTracker::new(),
65 pending_inputs: Vec::new(),
66 }
67 }
68
69 pub fn add_message(&mut self, message: Message) {
71 self.messages.push(message);
72 }
73
74 pub fn messages(&self) -> &Vec<Message> {
76 &self.messages
77 }
78
79 pub fn messages_mut(&mut self) -> &mut Vec<Message> {
81 &mut self.messages
82 }
83
84 pub fn set_messages(&mut self, messages: Vec<Message>) {
89 let cleaned = Self::clean_orphaned_messages(messages);
90 self.messages = cleaned;
91 }
92
93 fn clean_orphaned_messages(messages: Vec<Message>) -> Vec<Message> {
98 if messages.is_empty() {
99 return messages;
100 }
101
102 let mut tool_use_ids: HashSet<String> = HashSet::new();
104 for msg in &messages {
105 if let MessageContent::Blocks(blocks) = &msg.content {
106 for block in blocks {
107 if let ContentBlock::ToolUse { id, .. } = block {
108 tool_use_ids.insert(id.clone());
109 }
110 }
111 }
112 }
113
114 let mut tool_result_ids: HashSet<String> = HashSet::new();
116 for msg in &messages {
117 if msg.role == Role::Tool {
118 if let MessageContent::Blocks(blocks) = &msg.content {
119 for block in blocks {
120 if let ContentBlock::ToolResult { tool_use_id, .. } = block {
121 tool_result_ids.insert(tool_use_id.clone());
122 }
123 }
124 }
125 }
126 }
127
128 let orphaned_tool_use_ids: HashSet<&str> = tool_use_ids
130 .iter()
131 .filter(|id| !tool_result_ids.contains(*id))
132 .map(|s| s.as_str())
133 .collect();
134
135 let orphaned_tool_result_ids: HashSet<&str> = tool_result_ids
136 .iter()
137 .filter(|id| !tool_use_ids.contains(*id))
138 .map(|s| s.as_str())
139 .collect();
140
141 if orphaned_tool_use_ids.is_empty() && orphaned_tool_result_ids.is_empty() {
143 return messages;
144 }
145
146 log::warn!(
147 "Cleaning orphaned messages: {} tool_uses without results, {} tool_results without uses",
148 orphaned_tool_use_ids.len(),
149 orphaned_tool_result_ids.len()
150 );
151
152 let original_len = messages.len();
154 let mut cleaned = Vec::with_capacity(messages.len());
155 for msg in messages {
156 if msg.role == Role::Tool {
158 if let MessageContent::Blocks(blocks) = &msg.content {
159 let has_orphaned_result = blocks.iter().any(|b| {
160 if let ContentBlock::ToolResult { tool_use_id, .. } = b {
161 orphaned_tool_result_ids.contains(tool_use_id.as_str())
162 } else {
163 false
164 }
165 });
166 if has_orphaned_result {
167 log::info!("Removing orphaned tool result message");
168 continue;
169 }
170 }
171 }
172
173 if let MessageContent::Blocks(blocks) = msg.content {
175 let filtered_blocks: Vec<ContentBlock> = blocks
176 .into_iter()
177 .filter(|b| {
178 if let ContentBlock::ToolUse { id, .. } = b {
179 if orphaned_tool_use_ids.contains(id.as_str()) {
180 log::info!("Removing orphaned tool_use block: {}", id);
181 return false;
182 }
183 }
184 true
185 })
186 .collect();
187
188 if !filtered_blocks.is_empty() {
190 cleaned.push(Message {
191 role: msg.role,
192 content: MessageContent::Blocks(filtered_blocks),
193 });
194 }
195 } else {
196 cleaned.push(msg);
197 }
198 }
199
200 log::info!(
201 "Message cleaning complete: {} messages -> {} messages",
202 original_len,
203 cleaned.len()
204 );
205
206 cleaned
207 }
208
209 pub fn track_usage(&self, usage: &Usage) {
211 self.total_input_tokens.fetch_add(usage.input_tokens as u64, Ordering::Relaxed);
212 self.total_output_tokens.fetch_add(usage.output_tokens as u64, Ordering::Relaxed);
213 self.last_input_tokens.store(usage.input_tokens as u64, Ordering::Relaxed);
214 }
215
216 pub fn total_input_tokens(&self) -> u64 {
218 self.total_input_tokens.load(Ordering::Relaxed)
219 }
220
221 pub fn total_output_tokens(&self) -> u64 {
223 self.total_output_tokens.load(Ordering::Relaxed)
224 }
225
226 pub fn last_input_tokens(&self) -> u64 {
228 self.last_input_tokens.load(Ordering::Relaxed)
229 }
230
231 pub fn set_total_input_tokens(&self, value: u64) {
233 self.total_input_tokens.store(value, Ordering::Relaxed);
234 }
235
236 pub fn set_total_output_tokens(&self, value: u64) {
238 self.total_output_tokens.store(value, Ordering::Relaxed);
239 }
240
241 pub fn set_last_input_tokens(&self, value: u64) {
243 self.last_input_tokens.store(value, Ordering::Relaxed);
244 }
245
246 pub fn mark_tool_input_previewed(&mut self, tool_id: String) {
248 self.previewed_tool_inputs.insert(tool_id);
249 }
250
251 pub fn was_tool_input_previewed(&self, tool_id: &str) -> bool {
253 self.previewed_tool_inputs.contains(tool_id)
254 }
255
256 pub fn remove_previewed_tool_input(&mut self, tool_id: &str) -> bool {
258 self.previewed_tool_inputs.remove(tool_id)
259 }
260
261 pub fn increment_todo_reminder(&mut self, todo_hash: String) -> usize {
264 let count = self.todo_reminder_count.get(&todo_hash).copied().unwrap_or(0) + 1;
265 self.todo_reminder_count.insert(todo_hash, count);
266 count
267 }
268
269 pub fn todo_reminder_count(&self, todo_hash: &str) -> usize {
271 self.todo_reminder_count.get(todo_hash).copied().unwrap_or(0)
272 }
273
274 pub fn todo_reminder_count_map(&self) -> &std::collections::HashMap<String, usize> {
276 &self.todo_reminder_count
277 }
278
279 pub fn todo_reminder_count_map_mut(&mut self) -> &mut std::collections::HashMap<String, usize> {
281 &mut self.todo_reminder_count
282 }
283
284 pub fn is_todo_reminder_limit_reached(&self, todo_hash: &str, max_reminders: usize) -> bool {
286 self.todo_reminder_count(todo_hash) >= max_reminders
287 }
288
289 pub fn read_history(&self) -> &ReadHistoryTracker {
291 &self.read_history
292 }
293
294 pub fn read_history_mut(&mut self) -> &mut ReadHistoryTracker {
296 &mut self.read_history
297 }
298
299 pub fn add_pending_input(&mut self, input: String) {
301 self.pending_inputs.push(input);
302 }
303
304 pub fn has_pending_inputs(&self) -> bool {
306 !self.pending_inputs.is_empty()
307 }
308
309 pub fn pending_inputs_vec(&self) -> &Vec<String> {
311 &self.pending_inputs
312 }
313
314 pub fn pending_inputs_vec_mut(&mut self) -> &mut Vec<String> {
316 &mut self.pending_inputs
317 }
318
319 pub fn take_pending_inputs(&mut self) -> Vec<String> {
321 std::mem::take(&mut self.pending_inputs)
322 }
323
324 pub fn pending_input_count(&self) -> usize {
326 self.pending_inputs.len()
327 }
328
329 pub fn message_count(&self) -> usize {
331 self.messages.len()
332 }
333
334 pub fn clear(&mut self) {
336 self.messages.clear();
337 self.total_input_tokens.store(0, Ordering::Relaxed);
338 self.total_output_tokens.store(0, Ordering::Relaxed);
339 self.last_input_tokens.store(0, Ordering::Relaxed);
340 self.previewed_tool_inputs.clear();
341 self.todo_reminder_count.clear();
342 self.read_history = ReadHistoryTracker::new();
343 self.pending_inputs.clear();
344 }
345}
346
347impl Default for AgentState {
348 fn default() -> Self {
349 Self::new()
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{MessageContent, Role};

    /// Builds a plain-text user message.
    fn create_test_message(text: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_state_new_is_empty() {
        let state = AgentState::new();

        assert_eq!(state.message_count(), 0);
        assert_eq!(state.total_input_tokens(), 0);
        assert_eq!(state.total_output_tokens(), 0);
        assert_eq!(state.last_input_tokens(), 0);
        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);
        assert!(state.todo_reminder_count_map().is_empty());
    }

    #[test]
    fn test_state_add_message() {
        let mut state = AgentState::new();

        state.add_message(create_test_message("Hello"));
        state.add_message(create_test_message("World"));

        assert_eq!(state.message_count(), 2);
        assert_eq!(state.messages().len(), 2);
        assert_eq!(state.messages()[0].content, MessageContent::Text("Hello".to_string()));
        assert_eq!(state.messages()[1].content, MessageContent::Text("World".to_string()));
    }

    #[test]
    fn test_state_track_usage() {
        let state = AgentState::new();
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        };

        state.track_usage(&usage);

        assert_eq!(state.total_input_tokens(), 100);
        assert_eq!(state.total_output_tokens(), 50);
        assert_eq!(state.last_input_tokens(), 100);

        state.track_usage(&usage);
        assert_eq!(state.total_input_tokens(), 200);
        assert_eq!(state.total_output_tokens(), 100);
        assert_eq!(state.last_input_tokens(), 100);

        // A report with different numbers pins "last" (not first/max) semantics.
        state.track_usage(&Usage {
            input_tokens: 30,
            output_tokens: 5,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        });
        assert_eq!(state.total_input_tokens(), 230);
        assert_eq!(state.total_output_tokens(), 105);
        assert_eq!(state.last_input_tokens(), 30);
    }

    #[test]
    fn test_state_token_setters_overwrite() {
        let state = AgentState::new();

        state.set_total_input_tokens(10);
        state.set_total_output_tokens(20);
        state.set_last_input_tokens(30);

        assert_eq!(state.total_input_tokens(), 10);
        assert_eq!(state.total_output_tokens(), 20);
        assert_eq!(state.last_input_tokens(), 30);
    }

    #[test]
    fn test_state_previewed_tool_inputs() {
        let mut state = AgentState::new();

        assert!(!state.was_tool_input_previewed("tool_1"));

        state.mark_tool_input_previewed("tool_1".to_string());
        assert!(state.was_tool_input_previewed("tool_1"));
        assert!(!state.was_tool_input_previewed("tool_2"));

        let removed = state.remove_previewed_tool_input("tool_1");
        assert!(removed, "should return true when removing existing item");
        assert!(!state.was_tool_input_previewed("tool_1"));

        let removed = state.remove_previewed_tool_input("tool_2");
        assert!(!removed, "should return false when removing non-existent item");
    }

    #[test]
    fn test_state_todo_reminders() {
        let mut state = AgentState::new();
        let todo_hash = "hash_123".to_string();

        assert_eq!(state.todo_reminder_count(&todo_hash), 0);
        assert!(!state.is_todo_reminder_limit_reached(&todo_hash, 2));

        let count = state.increment_todo_reminder(todo_hash.clone());
        assert_eq!(count, 1);
        assert_eq!(state.todo_reminder_count(&todo_hash), 1);
        assert!(!state.is_todo_reminder_limit_reached(&todo_hash, 2));

        let count = state.increment_todo_reminder(todo_hash.clone());
        assert_eq!(count, 2);
        assert!(state.is_todo_reminder_limit_reached(&todo_hash, 2));

        let count = state.increment_todo_reminder(todo_hash.clone());
        assert_eq!(count, 3);
        assert!(state.is_todo_reminder_limit_reached(&todo_hash, 2));
    }

    #[test]
    fn test_state_todo_reminders_are_per_hash() {
        let mut state = AgentState::new();

        state.increment_todo_reminder("a".to_string());
        state.increment_todo_reminder("a".to_string());
        let count_b = state.increment_todo_reminder("b".to_string());

        assert_eq!(count_b, 1);
        assert_eq!(state.todo_reminder_count("a"), 2);
        assert_eq!(state.todo_reminder_count("b"), 1);
        assert_eq!(state.todo_reminder_count_map().len(), 2);
    }

    #[test]
    fn test_state_pending_inputs() {
        let mut state = AgentState::new();

        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);

        state.add_pending_input("input 1".to_string());
        state.add_pending_input("input 2".to_string());

        assert!(state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 2);
        assert_eq!(
            state.pending_inputs_vec(),
            &vec!["input 1".to_string(), "input 2".to_string()]
        );

        let inputs = state.take_pending_inputs();
        assert_eq!(inputs.len(), 2);
        assert_eq!(inputs[0], "input 1");
        assert_eq!(inputs[1], "input 2");

        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);
    }

    #[test]
    fn test_state_set_messages() {
        let mut state = AgentState::new();
        state.add_message(create_test_message("Old message"));

        let new_messages = vec![
            create_test_message("New 1"),
            create_test_message("New 2"),
        ];
        state.set_messages(new_messages);

        // Text-only messages contain no tool blocks, so cleaning keeps them all.
        assert_eq!(state.message_count(), 2);
        assert_eq!(state.messages()[0].content, MessageContent::Text("New 1".to_string()));
        assert_eq!(state.messages()[1].content, MessageContent::Text("New 2".to_string()));

        state.set_messages(Vec::new());
        assert_eq!(state.message_count(), 0);
    }

    #[test]
    fn test_state_clear() {
        let mut state = AgentState::new();

        state.add_message(create_test_message("Test"));
        state.track_usage(&Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        });
        state.add_pending_input("pending".to_string());
        state.mark_tool_input_previewed("tool_1".to_string());
        state.increment_todo_reminder("hash".to_string());

        state.clear();

        assert_eq!(state.message_count(), 0);
        assert_eq!(state.total_input_tokens(), 0);
        assert_eq!(state.total_output_tokens(), 0);
        assert_eq!(state.last_input_tokens(), 0);
        assert!(!state.has_pending_inputs());
        assert!(!state.was_tool_input_previewed("tool_1"));
        assert_eq!(state.todo_reminder_count("hash"), 0);
        assert!(state.todo_reminder_count_map().is_empty());
    }
}