matrixcode_core/agent/core/
state.rs1use std::collections::{HashMap, HashSet};
17use std::sync::atomic::{AtomicU64, Ordering};
18
19use crate::providers::{ContentBlock, Message, MessageContent, Role, Usage};
20use crate::tools::ReadHistoryTracker;
21
/// Number of occurrences of the same tool error at which
/// [`ToolErrorEntry::is_limit_reached`] starts returning `true`.
pub const MAX_SAME_ERROR_COUNT: usize = 3;

/// One distinct (tool, error message) pair together with how often it occurred.
#[derive(Debug, Clone)]
pub struct ToolErrorEntry {
    /// Name of the tool that produced the error.
    pub tool_name: String,
    /// Comparison key: the error message truncated to its first 100 characters.
    pub error_key: String,
    /// Number of times this error has been recorded (starts at 1).
    pub count: usize,
    /// Time of the most recent occurrence.
    pub last_occurrence: std::time::Instant,
}

impl ToolErrorEntry {
    /// Maximum number of characters of an error message used as its key.
    const KEY_MAX_CHARS: usize = 100;

    /// Returns the prefix of `error_msg` used to identify an error.
    ///
    /// Truncation is counted in characters, not bytes, so the cut always lands
    /// on a UTF-8 boundary and `new` and `matches` derive identical keys.
    fn key_for(error_msg: &str) -> &str {
        match error_msg.char_indices().nth(Self::KEY_MAX_CHARS) {
            // Byte offset of the first character that must be dropped.
            Some((cut, _)) => &error_msg[..cut],
            // `KEY_MAX_CHARS` characters or fewer: the whole message is the key.
            None => error_msg,
        }
    }

    /// Creates an entry for the first occurrence of `error_msg` from `tool_name`.
    ///
    /// The count starts at 1 and `last_occurrence` is set to the current time.
    pub fn new(tool_name: &str, error_msg: &str) -> Self {
        Self {
            tool_name: tool_name.to_string(),
            error_key: Self::key_for(error_msg).to_string(),
            count: 1,
            last_occurrence: std::time::Instant::now(),
        }
    }

    /// Returns `true` when `tool_name` equals this entry's tool and
    /// `error_msg` truncates to the same key as the stored one.
    pub fn matches(&self, tool_name: &str, error_msg: &str) -> bool {
        self.tool_name == tool_name && self.error_key == Self::key_for(error_msg)
    }

    /// Records one more occurrence and refreshes `last_occurrence`.
    pub fn increment(&mut self) {
        self.count += 1;
        self.last_occurrence = std::time::Instant::now();
    }

    /// Returns `true` once the error has occurred at least
    /// [`MAX_SAME_ERROR_COUNT`] times.
    pub fn is_limit_reached(&self) -> bool {
        self.count >= MAX_SAME_ERROR_COUNT
    }
}
79
80pub struct AgentState {
85 messages: Vec<Message>,
87
88 total_input_tokens: AtomicU64,
90
91 total_output_tokens: AtomicU64,
93
94 last_input_tokens: AtomicU64,
96
97 previewed_tool_inputs: HashSet<String>,
100
101 todo_reminder_count: HashMap<String, usize>,
104
105 read_history: ReadHistoryTracker,
108
109 pending_inputs: Vec<String>,
111
112 error_history: Vec<ToolErrorEntry>,
115}
116
117impl AgentState {
118 pub fn new() -> Self {
120 Self {
121 messages: Vec::new(),
122 total_input_tokens: AtomicU64::new(0),
123 total_output_tokens: AtomicU64::new(0),
124 last_input_tokens: AtomicU64::new(0),
125 previewed_tool_inputs: HashSet::new(),
126 todo_reminder_count: HashMap::new(),
127 read_history: ReadHistoryTracker::new(),
128 pending_inputs: Vec::new(),
129 error_history: Vec::new(),
130 }
131 }
132
133 pub fn add_message(&mut self, message: Message) {
135 self.messages.push(message);
136 }
137
138 pub fn messages(&self) -> &Vec<Message> {
140 &self.messages
141 }
142
143 pub fn messages_mut(&mut self) -> &mut Vec<Message> {
145 &mut self.messages
146 }
147
148 pub fn set_messages(&mut self, messages: Vec<Message>) {
153 let cleaned = Self::clean_orphaned_messages(messages);
154 self.messages = cleaned;
155 }
156
157 fn clean_orphaned_messages(messages: Vec<Message>) -> Vec<Message> {
162 if messages.is_empty() {
163 return messages;
164 }
165
166 let mut tool_use_ids: HashSet<String> = HashSet::new();
168 for msg in &messages {
169 if let MessageContent::Blocks(blocks) = &msg.content {
170 for block in blocks {
171 if let ContentBlock::ToolUse { id, .. } = block {
172 tool_use_ids.insert(id.clone());
173 }
174 }
175 }
176 }
177
178 let mut tool_result_ids: HashSet<String> = HashSet::new();
180 for msg in &messages {
181 if msg.role == Role::Tool {
182 if let MessageContent::Blocks(blocks) = &msg.content {
183 for block in blocks {
184 if let ContentBlock::ToolResult { tool_use_id, .. } = block {
185 tool_result_ids.insert(tool_use_id.clone());
186 }
187 }
188 }
189 }
190 }
191
192 let orphaned_tool_use_ids: HashSet<&str> = tool_use_ids
194 .iter()
195 .filter(|id| !tool_result_ids.contains(*id))
196 .map(|s| s.as_str())
197 .collect();
198
199 let orphaned_tool_result_ids: HashSet<&str> = tool_result_ids
200 .iter()
201 .filter(|id| !tool_use_ids.contains(*id))
202 .map(|s| s.as_str())
203 .collect();
204
205 if orphaned_tool_use_ids.is_empty() && orphaned_tool_result_ids.is_empty() {
207 return messages;
208 }
209
210 log::warn!(
211 "Cleaning orphaned messages: {} tool_uses without results, {} tool_results without uses",
212 orphaned_tool_use_ids.len(),
213 orphaned_tool_result_ids.len()
214 );
215
216 let original_len = messages.len();
218 let mut cleaned = Vec::with_capacity(messages.len());
219 for msg in messages {
220 if msg.role == Role::Tool {
222 if let MessageContent::Blocks(blocks) = &msg.content {
223 let has_orphaned_result = blocks.iter().any(|b| {
224 if let ContentBlock::ToolResult { tool_use_id, .. } = b {
225 orphaned_tool_result_ids.contains(tool_use_id.as_str())
226 } else {
227 false
228 }
229 });
230 if has_orphaned_result {
231 log::info!("Removing orphaned tool result message");
232 continue;
233 }
234 }
235 }
236
237 if let MessageContent::Blocks(blocks) = msg.content {
239 let filtered_blocks: Vec<ContentBlock> = blocks
240 .into_iter()
241 .filter(|b| {
242 if let ContentBlock::ToolUse { id, .. } = b {
243 if orphaned_tool_use_ids.contains(id.as_str()) {
244 log::info!("Removing orphaned tool_use block: {}", id);
245 return false;
246 }
247 }
248 true
249 })
250 .collect();
251
252 if !filtered_blocks.is_empty() {
254 cleaned.push(Message {
255 role: msg.role,
256 content: MessageContent::Blocks(filtered_blocks),
257 });
258 }
259 } else {
260 cleaned.push(msg);
261 }
262 }
263
264 log::info!(
265 "Message cleaning complete: {} messages -> {} messages",
266 original_len,
267 cleaned.len()
268 );
269
270 cleaned
271 }
272
273 pub fn track_usage(&self, usage: &Usage) {
275 self.total_input_tokens.fetch_add(usage.input_tokens as u64, Ordering::Relaxed);
276 self.total_output_tokens.fetch_add(usage.output_tokens as u64, Ordering::Relaxed);
277 self.last_input_tokens.store(usage.input_tokens as u64, Ordering::Relaxed);
278 }
279
280 pub fn total_input_tokens(&self) -> u64 {
282 self.total_input_tokens.load(Ordering::Relaxed)
283 }
284
285 pub fn total_output_tokens(&self) -> u64 {
287 self.total_output_tokens.load(Ordering::Relaxed)
288 }
289
290 pub fn last_input_tokens(&self) -> u64 {
292 self.last_input_tokens.load(Ordering::Relaxed)
293 }
294
295 pub fn set_total_input_tokens(&self, value: u64) {
297 self.total_input_tokens.store(value, Ordering::Relaxed);
298 }
299
300 pub fn set_total_output_tokens(&self, value: u64) {
302 self.total_output_tokens.store(value, Ordering::Relaxed);
303 }
304
305 pub fn set_last_input_tokens(&self, value: u64) {
307 self.last_input_tokens.store(value, Ordering::Relaxed);
308 }
309
310 pub fn mark_tool_input_previewed(&mut self, tool_id: String) {
312 self.previewed_tool_inputs.insert(tool_id);
313 }
314
315 pub fn was_tool_input_previewed(&self, tool_id: &str) -> bool {
317 self.previewed_tool_inputs.contains(tool_id)
318 }
319
320 pub fn remove_previewed_tool_input(&mut self, tool_id: &str) -> bool {
322 self.previewed_tool_inputs.remove(tool_id)
323 }
324
325 pub fn increment_todo_reminder(&mut self, todo_hash: String) -> usize {
328 let count = self.todo_reminder_count.get(&todo_hash).copied().unwrap_or(0) + 1;
329 self.todo_reminder_count.insert(todo_hash, count);
330 count
331 }
332
333 pub fn todo_reminder_count(&self, todo_hash: &str) -> usize {
335 self.todo_reminder_count.get(todo_hash).copied().unwrap_or(0)
336 }
337
338 pub fn todo_reminder_count_map(&self) -> &std::collections::HashMap<String, usize> {
340 &self.todo_reminder_count
341 }
342
343 pub fn todo_reminder_count_map_mut(&mut self) -> &mut std::collections::HashMap<String, usize> {
345 &mut self.todo_reminder_count
346 }
347
348 pub fn is_todo_reminder_limit_reached(&self, todo_hash: &str, max_reminders: usize) -> bool {
350 self.todo_reminder_count(todo_hash) >= max_reminders
351 }
352
353 pub fn read_history(&self) -> &ReadHistoryTracker {
355 &self.read_history
356 }
357
358 pub fn read_history_mut(&mut self) -> &mut ReadHistoryTracker {
360 &mut self.read_history
361 }
362
363 pub fn add_pending_input(&mut self, input: String) {
365 self.pending_inputs.push(input);
366 }
367
368 pub fn has_pending_inputs(&self) -> bool {
370 !self.pending_inputs.is_empty()
371 }
372
373 pub fn pending_inputs_vec(&self) -> &Vec<String> {
375 &self.pending_inputs
376 }
377
378 pub fn pending_inputs_vec_mut(&mut self) -> &mut Vec<String> {
380 &mut self.pending_inputs
381 }
382
383 pub fn take_pending_inputs(&mut self) -> Vec<String> {
385 std::mem::take(&mut self.pending_inputs)
386 }
387
388 pub fn pending_input_count(&self) -> usize {
390 self.pending_inputs.len()
391 }
392
393 pub fn message_count(&self) -> usize {
395 self.messages.len()
396 }
397
398 pub fn record_tool_error(&mut self, tool_name: &str, error_msg: &str) -> usize {
405 for entry in &mut self.error_history {
407 if entry.matches(tool_name, error_msg) {
408 entry.increment();
409 return entry.count;
410 }
411 }
412
413 let new_entry = ToolErrorEntry::new(tool_name, error_msg);
415 let count = new_entry.count;
416 self.error_history.push(new_entry);
417 count
418 }
419
420 pub fn check_error_limit(&self, tool_name: &str, error_msg: &str) -> Option<&ToolErrorEntry> {
423 self.error_history.iter().find(|e| {
424 e.matches(tool_name, error_msg) && e.is_limit_reached()
425 })
426 }
427
428 pub fn error_count(&self, tool_name: &str, error_msg: &str) -> usize {
430 self.error_history.iter()
431 .find(|e| e.matches(tool_name, error_msg))
432 .map(|e| e.count)
433 .unwrap_or(0)
434 }
435
436 pub fn clear_error_history(&mut self) {
438 self.error_history.clear();
439 }
440
441 pub fn unique_error_count(&self) -> usize {
443 self.error_history.len()
444 }
445
446 pub fn repeated_error_count(&self) -> usize {
448 self.error_history.iter().filter(|e| e.count > 1).count()
449 }
450
451 pub fn clear(&mut self) {
453 self.messages.clear();
454 self.total_input_tokens.store(0, Ordering::Relaxed);
455 self.total_output_tokens.store(0, Ordering::Relaxed);
456 self.last_input_tokens.store(0, Ordering::Relaxed);
457 self.previewed_tool_inputs.clear();
458 self.todo_reminder_count.clear();
459 self.read_history = ReadHistoryTracker::new();
460 self.pending_inputs.clear();
461 self.error_history.clear();
462 }
463}
464
465impl Default for AgentState {
466 fn default() -> Self {
467 Self::new()
468 }
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{MessageContent, Role};

    /// Builds a plain-text user message.
    fn create_test_message(text: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(text.to_string()),
        }
    }

    /// Usage report with 100 input tokens, 50 output tokens and no cache traffic.
    fn sample_usage() -> Usage {
        Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        }
    }

    #[test]
    fn test_state_new_is_empty() {
        let state = AgentState::new();

        assert_eq!(state.message_count(), 0);
        assert_eq!(state.total_input_tokens(), 0);
        assert_eq!(state.total_output_tokens(), 0);
        assert_eq!(state.last_input_tokens(), 0);
        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);
    }

    #[test]
    fn test_state_default_is_empty() {
        let state = AgentState::default();

        assert_eq!(state.message_count(), 0);
        assert_eq!(state.total_input_tokens(), 0);
        assert_eq!(state.unique_error_count(), 0);
        assert!(!state.has_pending_inputs());
    }

    #[test]
    fn test_state_add_message() {
        let mut state = AgentState::new();

        state.add_message(create_test_message("Hello"));
        state.add_message(create_test_message("World"));

        assert_eq!(state.message_count(), 2);
        assert_eq!(state.messages().len(), 2);
    }

    #[test]
    fn test_state_track_usage() {
        let state = AgentState::new();
        let usage = sample_usage();

        state.track_usage(&usage);

        assert_eq!(state.total_input_tokens(), 100);
        assert_eq!(state.total_output_tokens(), 50);
        assert_eq!(state.last_input_tokens(), 100);

        // Totals accumulate; the "last" value is overwritten, not summed.
        state.track_usage(&usage);
        assert_eq!(state.total_input_tokens(), 200);
        assert_eq!(state.total_output_tokens(), 100);
        assert_eq!(state.last_input_tokens(), 100);
    }

    #[test]
    fn test_state_set_token_counters() {
        let state = AgentState::new();

        state.set_total_input_tokens(10);
        state.set_total_output_tokens(20);
        state.set_last_input_tokens(30);

        assert_eq!(state.total_input_tokens(), 10);
        assert_eq!(state.total_output_tokens(), 20);
        assert_eq!(state.last_input_tokens(), 30);
    }

    #[test]
    fn test_state_previewed_tool_inputs() {
        let mut state = AgentState::new();

        assert!(!state.was_tool_input_previewed("tool_1"));

        state.mark_tool_input_previewed("tool_1".to_string());
        assert!(state.was_tool_input_previewed("tool_1"));
        assert!(!state.was_tool_input_previewed("tool_2"));

        let removed = state.remove_previewed_tool_input("tool_1");
        assert!(removed, "should return true when removing existing item");
        assert!(!state.was_tool_input_previewed("tool_1"));

        let removed = state.remove_previewed_tool_input("tool_2");
        assert!(!removed, "should return false when removing non-existent item");
    }

    #[test]
    fn test_state_todo_reminders() {
        let mut state = AgentState::new();
        let todo_hash = "hash_123".to_string();

        assert_eq!(state.todo_reminder_count(&todo_hash), 0);
        assert!(!state.is_todo_reminder_limit_reached(&todo_hash, 2));

        let count = state.increment_todo_reminder(todo_hash.clone());
        assert_eq!(count, 1);
        assert_eq!(state.todo_reminder_count(&todo_hash), 1);
        assert!(!state.is_todo_reminder_limit_reached(&todo_hash, 2));

        let count = state.increment_todo_reminder(todo_hash.clone());
        assert_eq!(count, 2);
        assert!(state.is_todo_reminder_limit_reached(&todo_hash, 2));

        let count = state.increment_todo_reminder(todo_hash.clone());
        assert_eq!(count, 3);
        assert!(state.is_todo_reminder_limit_reached(&todo_hash, 2));
    }

    #[test]
    fn test_state_pending_inputs() {
        let mut state = AgentState::new();

        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);

        state.add_pending_input("input 1".to_string());
        state.add_pending_input("input 2".to_string());

        assert!(state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 2);

        let inputs = state.take_pending_inputs();
        assert_eq!(inputs.len(), 2);
        assert_eq!(inputs[0], "input 1");
        assert_eq!(inputs[1], "input 2");

        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);
    }

    #[test]
    fn test_state_set_messages() {
        let mut state = AgentState::new();
        state.add_message(create_test_message("Old message"));

        let new_messages = vec![
            create_test_message("New 1"),
            create_test_message("New 2"),
        ];
        state.set_messages(new_messages);

        assert_eq!(state.message_count(), 2);
        assert_eq!(state.messages()[0].content, MessageContent::Text("New 1".to_string()));
    }

    #[test]
    fn test_state_error_tracking() {
        let mut state = AgentState::new();

        assert_eq!(state.error_count("bash", "boom"), 0);
        assert!(state.check_error_limit("bash", "boom").is_none());

        assert_eq!(state.record_tool_error("bash", "boom"), 1);
        assert_eq!(state.record_tool_error("bash", "boom"), 2);
        assert!(state.check_error_limit("bash", "boom").is_none());

        // A different tool or a different message is tracked separately.
        assert_eq!(state.record_tool_error("read", "boom"), 1);
        assert_eq!(state.record_tool_error("bash", "other"), 1);
        assert_eq!(state.unique_error_count(), 3);
        assert_eq!(state.repeated_error_count(), 1);

        assert_eq!(state.record_tool_error("bash", "boom"), MAX_SAME_ERROR_COUNT);
        let entry = state
            .check_error_limit("bash", "boom")
            .expect("limit should be reached after MAX_SAME_ERROR_COUNT occurrences");
        assert_eq!(entry.count, MAX_SAME_ERROR_COUNT);
        assert_eq!(state.error_count("bash", "boom"), MAX_SAME_ERROR_COUNT);
        assert!(state.check_error_limit("read", "boom").is_none());

        state.clear_error_history();
        assert_eq!(state.unique_error_count(), 0);
        assert_eq!(state.error_count("bash", "boom"), 0);
    }

    #[test]
    fn test_state_long_error_messages_share_a_key() {
        let mut state = AgentState::new();
        let prefix = "x".repeat(100);

        // Only the first 100 characters identify an error.
        assert_eq!(state.record_tool_error("bash", &format!("{}-first", prefix)), 1);
        assert_eq!(state.record_tool_error("bash", &format!("{}-second", prefix)), 2);
        assert_eq!(state.unique_error_count(), 1);
    }

    #[test]
    fn test_state_clear() {
        let mut state = AgentState::new();

        state.add_message(create_test_message("Test"));
        state.track_usage(&sample_usage());
        state.add_pending_input("pending".to_string());
        state.mark_tool_input_previewed("tool_1".to_string());
        state.increment_todo_reminder("hash_123".to_string());
        state.record_tool_error("bash", "boom");

        state.clear();

        assert_eq!(state.message_count(), 0);
        assert_eq!(state.total_input_tokens(), 0);
        assert_eq!(state.total_output_tokens(), 0);
        assert_eq!(state.last_input_tokens(), 0);
        assert!(!state.has_pending_inputs());
        assert!(!state.was_tool_input_previewed("tool_1"));
        assert_eq!(state.todo_reminder_count("hash_123"), 0);
        assert_eq!(state.unique_error_count(), 0);
    }
}