Skip to main content

matrixcode_core/agent/core/
state.rs

//! Agent state management.
//!
//! This module manages the runtime state of the Agent, including:
//! - Message history
//! - Token usage tracking
//! - Tool inputs previewed during streaming
//! - Pending inputs
//! - Todo reminder counts
//! - Read history
//!
//! Extracting this state into a dedicated struct provides:
//! - Clear separation between state and configuration
//! - Easier testing of state transitions
//! - Better encapsulation of mutable state

15use std::collections::{HashMap, HashSet};
16use std::sync::atomic::{AtomicU64, Ordering};
17
18use crate::providers::{ContentBlock, Message, MessageContent, Role, Usage};
19use crate::tools::ReadHistoryTracker;
20
21/// Agent runtime state.
22///
23/// Manages all mutable state during agent execution.
24/// All fields are private to enforce encapsulation.
25pub struct AgentState {
26    /// Message history (conversation with LLM).
27    messages: Vec<Message>,
28
29    /// Total input tokens consumed (lifetime).
30    total_input_tokens: AtomicU64,
31
32    /// Total output tokens generated (lifetime).
33    total_output_tokens: AtomicU64,
34
35    /// Last input tokens (for compression tracking).
36    last_input_tokens: AtomicU64,
37
38    /// Tool input IDs that were previewed during streaming.
39    /// Prevents duplicate emission of ToolUseStart events.
40    previewed_tool_inputs: HashSet<String>,
41
42    /// Todo reminder counts per todo content hash.
43    /// Prevents infinite reminder loops.
44    todo_reminder_count: HashMap<String, usize>,
45
46    /// Files read in this session.
47    /// Enforces "read before edit/write" rule.
48    read_history: ReadHistoryTracker,
49
50    /// Pending user inputs queued for next iteration.
51    pending_inputs: Vec<String>,
52}
53
54impl AgentState {
55    /// Create a new empty state.
56    pub fn new() -> Self {
57        Self {
58            messages: Vec::new(),
59            total_input_tokens: AtomicU64::new(0),
60            total_output_tokens: AtomicU64::new(0),
61            last_input_tokens: AtomicU64::new(0),
62            previewed_tool_inputs: HashSet::new(),
63            todo_reminder_count: HashMap::new(),
64            read_history: ReadHistoryTracker::new(),
65            pending_inputs: Vec::new(),
66        }
67    }
68
69    /// Add a message to history.
70    pub fn add_message(&mut self, message: Message) {
71        self.messages.push(message);
72    }
73
74    /// Get reference to message history.
75    pub fn messages(&self) -> &Vec<Message> {
76        &self.messages
77    }
78
79    /// Get mutable reference to message history.
80    pub fn messages_mut(&mut self) -> &mut Vec<Message> {
81        &mut self.messages
82    }
83
84    /// Replace message history (used in compression).
85    ///
86    /// This method validates and cleans orphaned tool results/tool uses
87    /// before setting the message history to prevent API errors.
88    pub fn set_messages(&mut self, messages: Vec<Message>) {
89        let cleaned = Self::clean_orphaned_messages(messages);
90        self.messages = cleaned;
91    }
92
93    /// Clean orphaned tool results and tool uses from messages.
94    ///
95    /// Orphaned tool result: a Tool message whose tool_use_id has no corresponding ToolUse block.
96    /// Orphaned tool use: a ToolUse block whose id has no corresponding ToolResult.
97    fn clean_orphaned_messages(messages: Vec<Message>) -> Vec<Message> {
98        if messages.is_empty() {
99            return messages;
100        }
101
102        // Collect all tool_use_ids from ToolUse blocks
103        let mut tool_use_ids: HashSet<String> = HashSet::new();
104        for msg in &messages {
105            if let MessageContent::Blocks(blocks) = &msg.content {
106                for block in blocks {
107                    if let ContentBlock::ToolUse { id, .. } = block {
108                        tool_use_ids.insert(id.clone());
109                    }
110                }
111            }
112        }
113
114        // Collect all tool_use_ids from ToolResult blocks
115        let mut tool_result_ids: HashSet<String> = HashSet::new();
116        for msg in &messages {
117            if msg.role == Role::Tool {
118                if let MessageContent::Blocks(blocks) = &msg.content {
119                    for block in blocks {
120                        if let ContentBlock::ToolResult { tool_use_id, .. } = block {
121                            tool_result_ids.insert(tool_use_id.clone());
122                        }
123                    }
124                }
125            }
126        }
127
128        // Find orphaned ids (no matching pair)
129        let orphaned_tool_use_ids: HashSet<&str> = tool_use_ids
130            .iter()
131            .filter(|id| !tool_result_ids.contains(*id))
132            .map(|s| s.as_str())
133            .collect();
134
135        let orphaned_tool_result_ids: HashSet<&str> = tool_result_ids
136            .iter()
137            .filter(|id| !tool_use_ids.contains(*id))
138            .map(|s| s.as_str())
139            .collect();
140
141        // If no orphans, return as-is
142        if orphaned_tool_use_ids.is_empty() && orphaned_tool_result_ids.is_empty() {
143            return messages;
144        }
145
146        log::warn!(
147            "Cleaning orphaned messages: {} tool_uses without results, {} tool_results without uses",
148            orphaned_tool_use_ids.len(),
149            orphaned_tool_result_ids.len()
150        );
151
152        // Clean messages
153        let original_len = messages.len();
154        let mut cleaned = Vec::with_capacity(messages.len());
155        for msg in messages {
156            // Skip entire Tool messages that are orphaned tool results
157            if msg.role == Role::Tool {
158                if let MessageContent::Blocks(blocks) = &msg.content {
159                    let has_orphaned_result = blocks.iter().any(|b| {
160                        if let ContentBlock::ToolResult { tool_use_id, .. } = b {
161                            orphaned_tool_result_ids.contains(tool_use_id.as_str())
162                        } else {
163                            false
164                        }
165                    });
166                    if has_orphaned_result {
167                        log::info!("Removing orphaned tool result message");
168                        continue;
169                    }
170                }
171            }
172
173            // For assistant messages, filter out orphaned tool_use blocks
174            if let MessageContent::Blocks(blocks) = msg.content {
175                let filtered_blocks: Vec<ContentBlock> = blocks
176                    .into_iter()
177                    .filter(|b| {
178                        if let ContentBlock::ToolUse { id, .. } = b {
179                            if orphaned_tool_use_ids.contains(id.as_str()) {
180                                log::info!("Removing orphaned tool_use block: {}", id);
181                                return false;
182                            }
183                        }
184                        true
185                    })
186                    .collect();
187
188                // Only add message if it has remaining content
189                if !filtered_blocks.is_empty() {
190                    cleaned.push(Message {
191                        role: msg.role,
192                        content: MessageContent::Blocks(filtered_blocks),
193                    });
194                }
195            } else {
196                cleaned.push(msg);
197            }
198        }
199
200        log::info!(
201            "Message cleaning complete: {} messages -> {} messages",
202            original_len,
203            cleaned.len()
204        );
205
206        cleaned
207    }
208
209    /// Track token usage from API response.
210    pub fn track_usage(&self, usage: &Usage) {
211        self.total_input_tokens.fetch_add(usage.input_tokens as u64, Ordering::Relaxed);
212        self.total_output_tokens.fetch_add(usage.output_tokens as u64, Ordering::Relaxed);
213        self.last_input_tokens.store(usage.input_tokens as u64, Ordering::Relaxed);
214    }
215
216    /// Get total input tokens consumed.
217    pub fn total_input_tokens(&self) -> u64 {
218        self.total_input_tokens.load(Ordering::Relaxed)
219    }
220
221    /// Get total output tokens generated.
222    pub fn total_output_tokens(&self) -> u64 {
223        self.total_output_tokens.load(Ordering::Relaxed)
224    }
225
226    /// Get last input tokens (for compression decisions).
227    pub fn last_input_tokens(&self) -> u64 {
228        self.last_input_tokens.load(Ordering::Relaxed)
229    }
230
231    /// Set total input tokens (used after compression).
232    pub fn set_total_input_tokens(&self, value: u64) {
233        self.total_input_tokens.store(value, Ordering::Relaxed);
234    }
235
236    /// Set total output tokens.
237    pub fn set_total_output_tokens(&self, value: u64) {
238        self.total_output_tokens.store(value, Ordering::Relaxed);
239    }
240
241    /// Set last input tokens (used after compression).
242    pub fn set_last_input_tokens(&self, value: u64) {
243        self.last_input_tokens.store(value, Ordering::Relaxed);
244    }
245
246    /// Mark a tool input as previewed during streaming.
247    pub fn mark_tool_input_previewed(&mut self, tool_id: String) {
248        self.previewed_tool_inputs.insert(tool_id);
249    }
250
251    /// Check if a tool input was already previewed.
252    pub fn was_tool_input_previewed(&self, tool_id: &str) -> bool {
253        self.previewed_tool_inputs.contains(tool_id)
254    }
255
256    /// Remove a tool input from previewed set (after processing).
257    pub fn remove_previewed_tool_input(&mut self, tool_id: &str) -> bool {
258        self.previewed_tool_inputs.remove(tool_id)
259    }
260
261    /// Increment todo reminder count for a todo item.
262    /// Returns the new count.
263    pub fn increment_todo_reminder(&mut self, todo_hash: String) -> usize {
264        let count = self.todo_reminder_count.get(&todo_hash).copied().unwrap_or(0) + 1;
265        self.todo_reminder_count.insert(todo_hash, count);
266        count
267    }
268
269    /// Get todo reminder count for a todo item.
270    pub fn todo_reminder_count(&self, todo_hash: &str) -> usize {
271        self.todo_reminder_count.get(todo_hash).copied().unwrap_or(0)
272    }
273
274    /// Get reference to the entire todo reminder count map.
275    pub fn todo_reminder_count_map(&self) -> &std::collections::HashMap<String, usize> {
276        &self.todo_reminder_count
277    }
278
279    /// Get mutable reference to the entire todo reminder count map.
280    pub fn todo_reminder_count_map_mut(&mut self) -> &mut std::collections::HashMap<String, usize> {
281        &mut self.todo_reminder_count
282    }
283
284    /// Check if todo reminder limit reached.
285    pub fn is_todo_reminder_limit_reached(&self, todo_hash: &str, max_reminders: usize) -> bool {
286        self.todo_reminder_count(todo_hash) >= max_reminders
287    }
288
289    /// Get reference to read history tracker.
290    pub fn read_history(&self) -> &ReadHistoryTracker {
291        &self.read_history
292    }
293
294    /// Get mutable reference to read history tracker.
295    pub fn read_history_mut(&mut self) -> &mut ReadHistoryTracker {
296        &mut self.read_history
297    }
298
299    /// Add a pending input to queue.
300    pub fn add_pending_input(&mut self, input: String) {
301        self.pending_inputs.push(input);
302    }
303
304    /// Check if there are pending inputs.
305    pub fn has_pending_inputs(&self) -> bool {
306        !self.pending_inputs.is_empty()
307    }
308
309    /// Get reference to pending inputs vector.
310    pub fn pending_inputs_vec(&self) -> &Vec<String> {
311        &self.pending_inputs
312    }
313
314    /// Get mutable reference to pending inputs vector.
315    pub fn pending_inputs_vec_mut(&mut self) -> &mut Vec<String> {
316        &mut self.pending_inputs
317    }
318
319    /// Take all pending inputs (drains the queue).
320    pub fn take_pending_inputs(&mut self) -> Vec<String> {
321        std::mem::take(&mut self.pending_inputs)
322    }
323
324    /// Get count of pending inputs.
325    pub fn pending_input_count(&self) -> usize {
326        self.pending_inputs.len()
327    }
328
329    /// Get message count.
330    pub fn message_count(&self) -> usize {
331        self.messages.len()
332    }
333
334    /// Clear all state (reset to initial state).
335    pub fn clear(&mut self) {
336        self.messages.clear();
337        self.total_input_tokens.store(0, Ordering::Relaxed);
338        self.total_output_tokens.store(0, Ordering::Relaxed);
339        self.last_input_tokens.store(0, Ordering::Relaxed);
340        self.previewed_tool_inputs.clear();
341        self.todo_reminder_count.clear();
342        self.read_history = ReadHistoryTracker::new();
343        self.pending_inputs.clear();
344    }
345}
346
347impl Default for AgentState {
348    fn default() -> Self {
349        Self::new()
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{MessageContent, Role};

    /// Build a plain-text user message.
    fn create_test_message(text: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(text.to_string()),
        }
    }

    /// Usage record with 100 input / 50 output tokens and no cache activity.
    fn sample_usage() -> Usage {
        Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        }
    }

    #[test]
    fn test_state_new_is_empty() {
        let state = AgentState::new();

        assert_eq!(state.message_count(), 0);
        assert_eq!(state.total_input_tokens(), 0);
        assert_eq!(state.total_output_tokens(), 0);
        assert_eq!(state.last_input_tokens(), 0);
        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);
    }

    #[test]
    fn test_state_default_is_empty() {
        let state = AgentState::default();

        assert_eq!(state.message_count(), 0);
        assert_eq!(state.total_input_tokens(), 0);
        assert_eq!(state.total_output_tokens(), 0);
        assert_eq!(state.last_input_tokens(), 0);
        assert!(!state.has_pending_inputs());
        assert!(state.todo_reminder_count_map().is_empty());
    }

    #[test]
    fn test_state_add_message() {
        let mut state = AgentState::new();

        state.add_message(create_test_message("Hello"));
        state.add_message(create_test_message("World"));

        assert_eq!(state.message_count(), 2);
        assert_eq!(state.messages().len(), 2);
    }

    #[test]
    fn test_state_messages_mut() {
        let mut state = AgentState::new();
        state.add_message(create_test_message("a"));
        state.add_message(create_test_message("b"));

        state.messages_mut().truncate(1);

        assert_eq!(state.message_count(), 1);
        assert_eq!(state.messages()[0].content, MessageContent::Text("a".to_string()));
    }

    #[test]
    fn test_state_track_usage() {
        let state = AgentState::new();
        let usage = sample_usage();

        state.track_usage(&usage);

        assert_eq!(state.total_input_tokens(), 100);
        assert_eq!(state.total_output_tokens(), 50);
        assert_eq!(state.last_input_tokens(), 100);

        // Totals accumulate; the "last" value is overwritten.
        state.track_usage(&usage);
        assert_eq!(state.total_input_tokens(), 200);
        assert_eq!(state.total_output_tokens(), 100);
        assert_eq!(state.last_input_tokens(), 100);
    }

    #[test]
    fn test_state_token_setters() {
        let state = AgentState::new();
        state.track_usage(&sample_usage());

        state.set_total_input_tokens(7);
        state.set_total_output_tokens(8);
        state.set_last_input_tokens(9);

        assert_eq!(state.total_input_tokens(), 7);
        assert_eq!(state.total_output_tokens(), 8);
        assert_eq!(state.last_input_tokens(), 9);
    }

    #[test]
    fn test_state_previewed_tool_inputs() {
        let mut state = AgentState::new();

        // Initially not previewed.
        assert!(!state.was_tool_input_previewed("tool_1"));

        // Mark as previewed.
        state.mark_tool_input_previewed("tool_1".to_string());
        assert!(state.was_tool_input_previewed("tool_1"));
        assert!(!state.was_tool_input_previewed("tool_2"));

        // Remove previewed.
        let removed = state.remove_previewed_tool_input("tool_1");
        assert!(removed, "should return true when removing existing item");
        assert!(!state.was_tool_input_previewed("tool_1"));

        // Remove non-existent.
        let removed = state.remove_previewed_tool_input("tool_2");
        assert!(!removed, "should return false when removing non-existent item");
    }

    #[test]
    fn test_state_todo_reminders() {
        let mut state = AgentState::new();
        let todo_hash = "hash_123".to_string();

        // Initially 0.
        assert_eq!(state.todo_reminder_count(&todo_hash), 0);
        assert!(!state.is_todo_reminder_limit_reached(&todo_hash, 2));

        // Increment.
        let count = state.increment_todo_reminder(todo_hash.clone());
        assert_eq!(count, 1);
        assert_eq!(state.todo_reminder_count(&todo_hash), 1);
        assert!(!state.is_todo_reminder_limit_reached(&todo_hash, 2));

        // Increment again.
        let count = state.increment_todo_reminder(todo_hash.clone());
        assert_eq!(count, 2);
        assert!(state.is_todo_reminder_limit_reached(&todo_hash, 2));

        // Increment beyond limit.
        let count = state.increment_todo_reminder(todo_hash.clone());
        assert_eq!(count, 3);
        assert!(state.is_todo_reminder_limit_reached(&todo_hash, 2));
    }

    #[test]
    fn test_state_todo_reminder_map_access() {
        let mut state = AgentState::new();
        state.increment_todo_reminder("a".to_string());
        state.increment_todo_reminder("b".to_string());
        state.increment_todo_reminder("b".to_string());

        assert_eq!(state.todo_reminder_count_map().len(), 2);
        assert_eq!(state.todo_reminder_count_map().get("b"), Some(&2));

        state.todo_reminder_count_map_mut().remove("b");
        assert_eq!(state.todo_reminder_count("b"), 0);
        assert_eq!(state.todo_reminder_count("a"), 1);
    }

    #[test]
    fn test_state_pending_inputs() {
        let mut state = AgentState::new();

        // Initially empty.
        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);

        // Add inputs.
        state.add_pending_input("input 1".to_string());
        state.add_pending_input("input 2".to_string());

        assert!(state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 2);

        // Take inputs.
        let inputs = state.take_pending_inputs();
        assert_eq!(inputs.len(), 2);
        assert_eq!(inputs[0], "input 1");
        assert_eq!(inputs[1], "input 2");

        // Queue drained.
        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);
    }

    #[test]
    fn test_state_pending_inputs_vec_access() {
        let mut state = AgentState::new();

        state.pending_inputs_vec_mut().push("direct".to_string());

        assert_eq!(state.pending_inputs_vec(), &vec!["direct".to_string()]);
        assert_eq!(state.pending_input_count(), 1);
        assert!(state.has_pending_inputs());
    }

    #[test]
    fn test_state_set_messages() {
        let mut state = AgentState::new();
        state.add_message(create_test_message("Old message"));

        // Replace messages.
        let new_messages = vec![
            create_test_message("New 1"),
            create_test_message("New 2"),
        ];
        state.set_messages(new_messages);

        assert_eq!(state.message_count(), 2);
        assert_eq!(state.messages()[0].content, MessageContent::Text("New 1".to_string()));
    }

    #[test]
    fn test_state_set_messages_empty_history() {
        let mut state = AgentState::new();
        state.add_message(create_test_message("Old message"));

        state.set_messages(Vec::new());

        assert_eq!(state.message_count(), 0);
    }

    #[test]
    fn test_state_clear() {
        let mut state = AgentState::new();

        // Populate every kind of state.
        state.add_message(create_test_message("Test"));
        state.track_usage(&sample_usage());
        state.add_pending_input("pending".to_string());
        state.mark_tool_input_previewed("tool_1".to_string());
        state.increment_todo_reminder("hash".to_string());

        state.clear();

        // Everything is back to the initial values.
        assert_eq!(state.message_count(), 0);
        assert_eq!(state.total_input_tokens(), 0);
        assert_eq!(state.total_output_tokens(), 0);
        assert_eq!(state.last_input_tokens(), 0);
        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);
        assert!(!state.was_tool_input_previewed("tool_1"));
        assert_eq!(state.todo_reminder_count("hash"), 0);
    }
}