Skip to main content

matrixcode_core/agent/core/
state.rs

1//! Agent state management.
2//!
3//! This module manages the runtime state of the Agent, including:
4//! - Message history
5//! - Token usage tracking
6//! - Pending inputs
7//! - Todo reminders
8//! - Read history
9//! - Tool error history (for repeat error detection)
10//!
11//! By extracting state into a dedicated struct, we enable:
12//! - Clear separation between state and configuration
13//! - Easier testing of state transitions
14//! - Better encapsulation of mutable state
15
16use std::collections::{HashMap, HashSet};
17use std::sync::atomic::{AtomicU64, Ordering};
18
19use crate::providers::{ContentBlock, Message, MessageContent, Role, Usage};
20use crate::tools::ReadHistoryTracker;
21
/// Maximum number of times the same tool error may occur before
/// [`ToolErrorEntry::is_limit_reached`] reports that intervention is needed.
pub const MAX_SAME_ERROR_COUNT: usize = 3;

/// A tool error entry for tracking repeated errors.
///
/// Two errors are considered "the same" when they come from the same tool and
/// their messages agree on the first 100 characters (see [`ToolErrorEntry::matches`]).
#[derive(Debug, Clone)]
pub struct ToolErrorEntry {
    /// Tool name that failed.
    pub tool_name: String,
    /// Error message truncated to its first 100 characters (not bytes), used for comparison.
    pub error_key: String,
    /// Number of times this error occurred.
    pub count: usize,
    /// Time of the most recent occurrence; set by `new` and refreshed by `increment`
    /// (presumably intended for aging out stale entries).
    pub last_occurrence: std::time::Instant,
}

impl ToolErrorEntry {
    /// Maximum number of characters of an error message that take part in comparison.
    const ERROR_KEY_MAX_CHARS: usize = 100;

    /// Build the comparison key for an error message: its first
    /// `ERROR_KEY_MAX_CHARS` characters.
    ///
    /// Truncation is by `char`, so multi-byte UTF-8 sequences are never split.
    /// `take` yields the whole string when it is already short enough, so no
    /// separate length check (and second pass over the string) is needed.
    fn error_key_for(error_msg: &str) -> String {
        error_msg.chars().take(Self::ERROR_KEY_MAX_CHARS).collect()
    }

    /// Create a new entry with `count == 1` and `last_occurrence` set to now.
    pub fn new(tool_name: &str, error_msg: &str) -> Self {
        Self {
            tool_name: tool_name.to_string(),
            error_key: Self::error_key_for(error_msg),
            count: 1,
            last_occurrence: std::time::Instant::now(),
        }
    }

    /// Return `true` if `tool_name` equals this entry's tool and the first
    /// 100 characters of `error_msg` equal the stored key.
    pub fn matches(&self, tool_name: &str, error_msg: &str) -> bool {
        // Compare char-by-char against the truncated message instead of
        // allocating a fresh key String on every lookup.
        self.tool_name == tool_name
            && self
                .error_key
                .chars()
                .eq(error_msg.chars().take(Self::ERROR_KEY_MAX_CHARS))
    }

    /// Record one more occurrence: bump `count` and refresh `last_occurrence`.
    pub fn increment(&mut self) {
        self.count += 1;
        self.last_occurrence = std::time::Instant::now();
    }

    /// Return `true` once this error has occurred at least
    /// [`MAX_SAME_ERROR_COUNT`] times.
    pub fn is_limit_reached(&self) -> bool {
        self.count >= MAX_SAME_ERROR_COUNT
    }
}
74    /// Check if error limit reached
75    pub fn is_limit_reached(&self) -> bool {
76        self.count >= MAX_SAME_ERROR_COUNT
77    }
78}
79
80/// Agent runtime state.
81///
82/// Manages all mutable state during agent execution.
83/// All fields are private to enforce encapsulation.
84pub struct AgentState {
85    /// Message history (conversation with LLM).
86    messages: Vec<Message>,
87
88    /// Total input tokens consumed (lifetime).
89    total_input_tokens: AtomicU64,
90
91    /// Total output tokens generated (lifetime).
92    total_output_tokens: AtomicU64,
93
94    /// Last input tokens (for compression tracking).
95    last_input_tokens: AtomicU64,
96
97    /// Tool input IDs that were previewed during streaming.
98    /// Prevents duplicate emission of ToolUseStart events.
99    previewed_tool_inputs: HashSet<String>,
100
101    /// Todo reminder counts per todo content hash.
102    /// Prevents infinite reminder loops.
103    todo_reminder_count: HashMap<String, usize>,
104
105    /// Files read in this session.
106    /// Enforces "read before edit/write" rule.
107    read_history: ReadHistoryTracker,
108
109    /// Pending user inputs queued for next iteration.
110    pending_inputs: Vec<String>,
111
112    /// Tool error history for detecting repeated errors.
113    /// When same error repeats, provide enhanced guidance.
114    error_history: Vec<ToolErrorEntry>,
115}
116
117impl AgentState {
118    /// Create a new empty state.
119    pub fn new() -> Self {
120        Self {
121            messages: Vec::new(),
122            total_input_tokens: AtomicU64::new(0),
123            total_output_tokens: AtomicU64::new(0),
124            last_input_tokens: AtomicU64::new(0),
125            previewed_tool_inputs: HashSet::new(),
126            todo_reminder_count: HashMap::new(),
127            read_history: ReadHistoryTracker::new(),
128            pending_inputs: Vec::new(),
129            error_history: Vec::new(),
130        }
131    }
132
133    /// Add a message to history.
134    pub fn add_message(&mut self, message: Message) {
135        self.messages.push(message);
136    }
137
138    /// Get reference to message history.
139    pub fn messages(&self) -> &Vec<Message> {
140        &self.messages
141    }
142
143    /// Get mutable reference to message history.
144    pub fn messages_mut(&mut self) -> &mut Vec<Message> {
145        &mut self.messages
146    }
147
148    /// Replace message history (used in compression).
149    ///
150    /// This method validates and cleans orphaned tool results/tool uses
151    /// before setting the message history to prevent API errors.
152    pub fn set_messages(&mut self, messages: Vec<Message>) {
153        let cleaned = Self::clean_orphaned_messages(messages);
154        self.messages = cleaned;
155    }
156
157    /// Clean orphaned tool results and tool uses from messages.
158    ///
159    /// Orphaned tool result: a Tool message whose tool_use_id has no corresponding ToolUse block.
160    /// Orphaned tool use: a ToolUse block whose id has no corresponding ToolResult.
161    fn clean_orphaned_messages(messages: Vec<Message>) -> Vec<Message> {
162        if messages.is_empty() {
163            return messages;
164        }
165
166        // Collect all tool_use_ids from ToolUse blocks
167        let mut tool_use_ids: HashSet<String> = HashSet::new();
168        for msg in &messages {
169            if let MessageContent::Blocks(blocks) = &msg.content {
170                for block in blocks {
171                    if let ContentBlock::ToolUse { id, .. } = block {
172                        tool_use_ids.insert(id.clone());
173                    }
174                }
175            }
176        }
177
178        // Collect all tool_use_ids from ToolResult blocks
179        let mut tool_result_ids: HashSet<String> = HashSet::new();
180        for msg in &messages {
181            if msg.role == Role::Tool {
182                if let MessageContent::Blocks(blocks) = &msg.content {
183                    for block in blocks {
184                        if let ContentBlock::ToolResult { tool_use_id, .. } = block {
185                            tool_result_ids.insert(tool_use_id.clone());
186                        }
187                    }
188                }
189            }
190        }
191
192        // Find orphaned ids (no matching pair)
193        let orphaned_tool_use_ids: HashSet<&str> = tool_use_ids
194            .iter()
195            .filter(|id| !tool_result_ids.contains(*id))
196            .map(|s| s.as_str())
197            .collect();
198
199        let orphaned_tool_result_ids: HashSet<&str> = tool_result_ids
200            .iter()
201            .filter(|id| !tool_use_ids.contains(*id))
202            .map(|s| s.as_str())
203            .collect();
204
205        // If no orphans, return as-is
206        if orphaned_tool_use_ids.is_empty() && orphaned_tool_result_ids.is_empty() {
207            return messages;
208        }
209
210        log::warn!(
211            "Cleaning orphaned messages: {} tool_uses without results, {} tool_results without uses",
212            orphaned_tool_use_ids.len(),
213            orphaned_tool_result_ids.len()
214        );
215
216        // Clean messages
217        let original_len = messages.len();
218        let mut cleaned = Vec::with_capacity(messages.len());
219        for msg in messages {
220            // Skip entire Tool messages that are orphaned tool results
221            if msg.role == Role::Tool {
222                if let MessageContent::Blocks(blocks) = &msg.content {
223                    let has_orphaned_result = blocks.iter().any(|b| {
224                        if let ContentBlock::ToolResult { tool_use_id, .. } = b {
225                            orphaned_tool_result_ids.contains(tool_use_id.as_str())
226                        } else {
227                            false
228                        }
229                    });
230                    if has_orphaned_result {
231                        log::info!("Removing orphaned tool result message");
232                        continue;
233                    }
234                }
235            }
236
237            // For assistant messages, filter out orphaned tool_use blocks
238            if let MessageContent::Blocks(blocks) = msg.content {
239                let filtered_blocks: Vec<ContentBlock> = blocks
240                    .into_iter()
241                    .filter(|b| {
242                        if let ContentBlock::ToolUse { id, .. } = b {
243                            if orphaned_tool_use_ids.contains(id.as_str()) {
244                                log::info!("Removing orphaned tool_use block: {}", id);
245                                return false;
246                            }
247                        }
248                        true
249                    })
250                    .collect();
251
252                // Only add message if it has remaining content
253                if !filtered_blocks.is_empty() {
254                    cleaned.push(Message {
255                        role: msg.role,
256                        content: MessageContent::Blocks(filtered_blocks),
257                    });
258                }
259            } else {
260                cleaned.push(msg);
261            }
262        }
263
264        log::info!(
265            "Message cleaning complete: {} messages -> {} messages",
266            original_len,
267            cleaned.len()
268        );
269
270        cleaned
271    }
272
273    /// Track token usage from API response.
274    pub fn track_usage(&self, usage: &Usage) {
275        self.total_input_tokens.fetch_add(usage.input_tokens as u64, Ordering::Relaxed);
276        self.total_output_tokens.fetch_add(usage.output_tokens as u64, Ordering::Relaxed);
277        self.last_input_tokens.store(usage.input_tokens as u64, Ordering::Relaxed);
278    }
279
280    /// Get total input tokens consumed.
281    pub fn total_input_tokens(&self) -> u64 {
282        self.total_input_tokens.load(Ordering::Relaxed)
283    }
284
285    /// Get total output tokens generated.
286    pub fn total_output_tokens(&self) -> u64 {
287        self.total_output_tokens.load(Ordering::Relaxed)
288    }
289
290    /// Get last input tokens (for compression decisions).
291    pub fn last_input_tokens(&self) -> u64 {
292        self.last_input_tokens.load(Ordering::Relaxed)
293    }
294
295    /// Set total input tokens (used after compression).
296    pub fn set_total_input_tokens(&self, value: u64) {
297        self.total_input_tokens.store(value, Ordering::Relaxed);
298    }
299
300    /// Set total output tokens.
301    pub fn set_total_output_tokens(&self, value: u64) {
302        self.total_output_tokens.store(value, Ordering::Relaxed);
303    }
304
305    /// Set last input tokens (used after compression).
306    pub fn set_last_input_tokens(&self, value: u64) {
307        self.last_input_tokens.store(value, Ordering::Relaxed);
308    }
309
310    /// Mark a tool input as previewed during streaming.
311    pub fn mark_tool_input_previewed(&mut self, tool_id: String) {
312        self.previewed_tool_inputs.insert(tool_id);
313    }
314
315    /// Check if a tool input was already previewed.
316    pub fn was_tool_input_previewed(&self, tool_id: &str) -> bool {
317        self.previewed_tool_inputs.contains(tool_id)
318    }
319
320    /// Remove a tool input from previewed set (after processing).
321    pub fn remove_previewed_tool_input(&mut self, tool_id: &str) -> bool {
322        self.previewed_tool_inputs.remove(tool_id)
323    }
324
325    /// Increment todo reminder count for a todo item.
326    /// Returns the new count.
327    pub fn increment_todo_reminder(&mut self, todo_hash: String) -> usize {
328        let count = self.todo_reminder_count.get(&todo_hash).copied().unwrap_or(0) + 1;
329        self.todo_reminder_count.insert(todo_hash, count);
330        count
331    }
332
333    /// Get todo reminder count for a todo item.
334    pub fn todo_reminder_count(&self, todo_hash: &str) -> usize {
335        self.todo_reminder_count.get(todo_hash).copied().unwrap_or(0)
336    }
337
338    /// Get reference to the entire todo reminder count map.
339    pub fn todo_reminder_count_map(&self) -> &std::collections::HashMap<String, usize> {
340        &self.todo_reminder_count
341    }
342
343    /// Get mutable reference to the entire todo reminder count map.
344    pub fn todo_reminder_count_map_mut(&mut self) -> &mut std::collections::HashMap<String, usize> {
345        &mut self.todo_reminder_count
346    }
347
348    /// Check if todo reminder limit reached.
349    pub fn is_todo_reminder_limit_reached(&self, todo_hash: &str, max_reminders: usize) -> bool {
350        self.todo_reminder_count(todo_hash) >= max_reminders
351    }
352
353    /// Get reference to read history tracker.
354    pub fn read_history(&self) -> &ReadHistoryTracker {
355        &self.read_history
356    }
357
358    /// Get mutable reference to read history tracker.
359    pub fn read_history_mut(&mut self) -> &mut ReadHistoryTracker {
360        &mut self.read_history
361    }
362
363    /// Add a pending input to queue.
364    pub fn add_pending_input(&mut self, input: String) {
365        self.pending_inputs.push(input);
366    }
367
368    /// Check if there are pending inputs.
369    pub fn has_pending_inputs(&self) -> bool {
370        !self.pending_inputs.is_empty()
371    }
372
373    /// Get reference to pending inputs vector.
374    pub fn pending_inputs_vec(&self) -> &Vec<String> {
375        &self.pending_inputs
376    }
377
378    /// Get mutable reference to pending inputs vector.
379    pub fn pending_inputs_vec_mut(&mut self) -> &mut Vec<String> {
380        &mut self.pending_inputs
381    }
382
383    /// Take all pending inputs (drains the queue).
384    pub fn take_pending_inputs(&mut self) -> Vec<String> {
385        std::mem::take(&mut self.pending_inputs)
386    }
387
388    /// Get count of pending inputs.
389    pub fn pending_input_count(&self) -> usize {
390        self.pending_inputs.len()
391    }
392
393    /// Get message count.
394    pub fn message_count(&self) -> usize {
395        self.messages.len()
396    }
397
398    // ========================================================================
399    // Error History Management
400    // ========================================================================
401
402    /// Record a tool error and check for repeats.
403    /// Returns the count of how many times this error has occurred.
404    pub fn record_tool_error(&mut self, tool_name: &str, error_msg: &str) -> usize {
405        // Find existing entry for this error
406        for entry in &mut self.error_history {
407            if entry.matches(tool_name, error_msg) {
408                entry.increment();
409                return entry.count;
410            }
411        }
412
413        // No existing entry, create new one
414        let new_entry = ToolErrorEntry::new(tool_name, error_msg);
415        let count = new_entry.count;
416        self.error_history.push(new_entry);
417        count
418    }
419
420    /// Check if a tool has reached error limit.
421    /// Returns the entry if limit is reached.
422    pub fn check_error_limit(&self, tool_name: &str, error_msg: &str) -> Option<&ToolErrorEntry> {
423        self.error_history.iter().find(|e| {
424            e.matches(tool_name, error_msg) && e.is_limit_reached()
425        })
426    }
427
428    /// Get the count of a specific error.
429    pub fn error_count(&self, tool_name: &str, error_msg: &str) -> usize {
430        self.error_history.iter()
431            .find(|e| e.matches(tool_name, error_msg))
432            .map(|e| e.count)
433            .unwrap_or(0)
434    }
435
436    /// Clear error history (e.g., after successful correction).
437    pub fn clear_error_history(&mut self) {
438        self.error_history.clear();
439    }
440
441    /// Get total unique error count.
442    pub fn unique_error_count(&self) -> usize {
443        self.error_history.len()
444    }
445
446    /// Get total repeated error count (errors that occurred more than once).
447    pub fn repeated_error_count(&self) -> usize {
448        self.error_history.iter().filter(|e| e.count > 1).count()
449    }
450
451    /// Clear all state (reset to initial state).
452    pub fn clear(&mut self) {
453        self.messages.clear();
454        self.total_input_tokens.store(0, Ordering::Relaxed);
455        self.total_output_tokens.store(0, Ordering::Relaxed);
456        self.last_input_tokens.store(0, Ordering::Relaxed);
457        self.previewed_tool_inputs.clear();
458        self.todo_reminder_count.clear();
459        self.read_history = ReadHistoryTracker::new();
460        self.pending_inputs.clear();
461        self.error_history.clear();
462    }
463}
464
465impl Default for AgentState {
466    fn default() -> Self {
467        Self::new()
468    }
469}
470
471#[cfg(test)]
472mod tests {
473    use super::*;
474    use crate::providers::{MessageContent, Role};
475
476    fn create_test_message(text: &str) -> Message {
477        Message {
478            role: Role::User,
479            content: MessageContent::Text(text.to_string()),
480        }
481    }
482
483    #[test]
484    fn test_state_new_is_empty() {
485        let state = AgentState::new();
486
487        assert_eq!(state.message_count(), 0);
488        assert_eq!(state.total_input_tokens(), 0);
489        assert_eq!(state.total_output_tokens(), 0);
490        assert_eq!(state.last_input_tokens(), 0);
491        assert!(!state.has_pending_inputs());
492        assert_eq!(state.pending_input_count(), 0);
493    }
494
495    #[test]
496    fn test_state_add_message() {
497        let mut state = AgentState::new();
498
499        state.add_message(create_test_message("Hello"));
500        state.add_message(create_test_message("World"));
501
502        assert_eq!(state.message_count(), 2);
503        assert_eq!(state.messages().len(), 2);
504    }
505
506    #[test]
507    fn test_state_track_usage() {
508        let state = AgentState::new();
509        let usage = Usage {
510            input_tokens: 100,
511            output_tokens: 50,
512            cache_creation_input_tokens: 0,
513            cache_read_input_tokens: 0,
514        };
515
516        state.track_usage(&usage);
517
518        assert_eq!(state.total_input_tokens(), 100);
519        assert_eq!(state.total_output_tokens(), 50);
520        assert_eq!(state.last_input_tokens(), 100);
521
522        // Track again (should accumulate)
523        state.track_usage(&usage);
524        assert_eq!(state.total_input_tokens(), 200);
525        assert_eq!(state.total_output_tokens(), 100);
526        assert_eq!(state.last_input_tokens(), 100);
527    }
528
529    #[test]
530    fn test_state_previewed_tool_inputs() {
531        let mut state = AgentState::new();
532
533        // Initially not previewed
534        assert!(!state.was_tool_input_previewed("tool_1"));
535
536        // Mark as previewed
537        state.mark_tool_input_previewed("tool_1".to_string());
538        assert!(state.was_tool_input_previewed("tool_1"));
539        assert!(!state.was_tool_input_previewed("tool_2"));
540
541        // Remove previewed
542        let removed = state.remove_previewed_tool_input("tool_1");
543        assert!(removed, "should return true when removing existing item");
544        assert!(!state.was_tool_input_previewed("tool_1"));
545
546        // Remove non-existent
547        let removed = state.remove_previewed_tool_input("tool_2");
548        assert!(!removed, "should return false when removing non-existent item");
549    }
550
551    #[test]
552    fn test_state_todo_reminders() {
553        let mut state = AgentState::new();
554        let todo_hash = "hash_123".to_string();
555
556        // Initially 0
557        assert_eq!(state.todo_reminder_count(&todo_hash), 0);
558        assert!(!state.is_todo_reminder_limit_reached(&todo_hash, 2));
559
560        // Increment
561        let count = state.increment_todo_reminder(todo_hash.clone());
562        assert_eq!(count, 1);
563        assert_eq!(state.todo_reminder_count(&todo_hash), 1);
564        assert!(!state.is_todo_reminder_limit_reached(&todo_hash, 2));
565
566        // Increment again
567        let count = state.increment_todo_reminder(todo_hash.clone());
568        assert_eq!(count, 2);
569        assert!(state.is_todo_reminder_limit_reached(&todo_hash, 2));
570
571        // Increment beyond limit
572        let count = state.increment_todo_reminder(todo_hash.clone());
573        assert_eq!(count, 3);
574        assert!(state.is_todo_reminder_limit_reached(&todo_hash, 2));
575    }
576
577    #[test]
578    fn test_state_pending_inputs() {
579        let mut state = AgentState::new();
580
581        // Initially empty
582        assert!(!state.has_pending_inputs());
583        assert_eq!(state.pending_input_count(), 0);
584
585        // Add inputs
586        state.add_pending_input("input 1".to_string());
587        state.add_pending_input("input 2".to_string());
588
589        assert!(state.has_pending_inputs());
590        assert_eq!(state.pending_input_count(), 2);
591
592        // Take inputs
593        let inputs = state.take_pending_inputs();
594        assert_eq!(inputs.len(), 2);
595        assert_eq!(inputs[0], "input 1");
596        assert_eq!(inputs[1], "input 2");
597
598        // Queue drained
599        assert!(!state.has_pending_inputs());
600        assert_eq!(state.pending_input_count(), 0);
601    }
602
603    #[test]
604    fn test_state_set_messages() {
605        let mut state = AgentState::new();
606        state.add_message(create_test_message("Old message"));
607
608        // Replace messages
609        let new_messages = vec![
610            create_test_message("New 1"),
611            create_test_message("New 2"),
612        ];
613        state.set_messages(new_messages);
614
615        assert_eq!(state.message_count(), 2);
616        assert_eq!(state.messages()[0].content, MessageContent::Text("New 1".to_string()));
617    }
618
619    #[test]
620    fn test_state_clear() {
621        let mut state = AgentState::new();
622
623        // Add some state
624        state.add_message(create_test_message("Test"));
625        state.track_usage(&Usage {
626            input_tokens: 100,
627            output_tokens: 50,
628            cache_creation_input_tokens: 0,
629            cache_read_input_tokens: 0,
630        });
631        state.add_pending_input("pending".to_string());
632        state.mark_tool_input_previewed("tool_1".to_string());
633
634        // Clear
635        state.clear();
636
637        // Verify all cleared
638        assert_eq!(state.message_count(), 0);
639        assert_eq!(state.total_input_tokens(), 0);
640        assert_eq!(state.total_output_tokens(), 0);
641        assert!(!state.has_pending_inputs());
642        assert!(!state.was_tool_input_previewed("tool_1"));
643    }
644}