Skip to main content

matrixcode_core/agent/core/
state.rs

//! Agent state management.
//!
//! This module manages the runtime state of the Agent, including:
//! - Message history
//! - Token usage tracking
//! - Tool inputs already previewed during streaming
//! - Todo reminder counts
//! - Read history (files read in this session)
//! - Pending user inputs
//!
//! Extracting this state into a dedicated struct provides:
//! - Clear separation between state and configuration
//! - Easier testing of state transitions
//! - Better encapsulation of mutable state
14
15use std::collections::{HashMap, HashSet};
16use std::sync::atomic::{AtomicU64, Ordering};
17
18use crate::providers::{Message, Usage};
19use crate::tools::ReadHistoryTracker;
20
21/// Agent runtime state.
22///
23/// Manages all mutable state during agent execution.
24/// All fields are private to enforce encapsulation.
25pub struct AgentState {
26    /// Message history (conversation with LLM).
27    messages: Vec<Message>,
28
29    /// Total input tokens consumed (lifetime).
30    total_input_tokens: AtomicU64,
31
32    /// Total output tokens generated (lifetime).
33    total_output_tokens: AtomicU64,
34
35    /// Last input tokens (for compression tracking).
36    last_input_tokens: AtomicU64,
37
38    /// Tool input IDs that were previewed during streaming.
39    /// Prevents duplicate emission of ToolUseStart events.
40    previewed_tool_inputs: HashSet<String>,
41
42    /// Todo reminder counts per todo content hash.
43    /// Prevents infinite reminder loops.
44    todo_reminder_count: HashMap<String, usize>,
45
46    /// Files read in this session.
47    /// Enforces "read before edit/write" rule.
48    read_history: ReadHistoryTracker,
49
50    /// Pending user inputs queued for next iteration.
51    pending_inputs: Vec<String>,
52}
53
54impl AgentState {
55    /// Create a new empty state.
56    pub fn new() -> Self {
57        Self {
58            messages: Vec::new(),
59            total_input_tokens: AtomicU64::new(0),
60            total_output_tokens: AtomicU64::new(0),
61            last_input_tokens: AtomicU64::new(0),
62            previewed_tool_inputs: HashSet::new(),
63            todo_reminder_count: HashMap::new(),
64            read_history: ReadHistoryTracker::new(),
65            pending_inputs: Vec::new(),
66        }
67    }
68
69    /// Add a message to history.
70    pub fn add_message(&mut self, message: Message) {
71        self.messages.push(message);
72    }
73
74    /// Get reference to message history.
75    pub fn messages(&self) -> &Vec<Message> {
76        &self.messages
77    }
78
79    /// Get mutable reference to message history.
80    pub fn messages_mut(&mut self) -> &mut Vec<Message> {
81        &mut self.messages
82    }
83
84    /// Replace message history (used in compression).
85    pub fn set_messages(&mut self, messages: Vec<Message>) {
86        self.messages = messages;
87    }
88
89    /// Track token usage from API response.
90    pub fn track_usage(&self, usage: &Usage) {
91        self.total_input_tokens.fetch_add(usage.input_tokens as u64, Ordering::Relaxed);
92        self.total_output_tokens.fetch_add(usage.output_tokens as u64, Ordering::Relaxed);
93        self.last_input_tokens.store(usage.input_tokens as u64, Ordering::Relaxed);
94    }
95
96    /// Get total input tokens consumed.
97    pub fn total_input_tokens(&self) -> u64 {
98        self.total_input_tokens.load(Ordering::Relaxed)
99    }
100
101    /// Get total output tokens generated.
102    pub fn total_output_tokens(&self) -> u64 {
103        self.total_output_tokens.load(Ordering::Relaxed)
104    }
105
106    /// Get last input tokens (for compression decisions).
107    pub fn last_input_tokens(&self) -> u64 {
108        self.last_input_tokens.load(Ordering::Relaxed)
109    }
110
111    /// Set total input tokens (used after compression).
112    pub fn set_total_input_tokens(&self, value: u64) {
113        self.total_input_tokens.store(value, Ordering::Relaxed);
114    }
115
116    /// Set total output tokens.
117    pub fn set_total_output_tokens(&self, value: u64) {
118        self.total_output_tokens.store(value, Ordering::Relaxed);
119    }
120
121    /// Set last input tokens (used after compression).
122    pub fn set_last_input_tokens(&self, value: u64) {
123        self.last_input_tokens.store(value, Ordering::Relaxed);
124    }
125
126    /// Mark a tool input as previewed during streaming.
127    pub fn mark_tool_input_previewed(&mut self, tool_id: String) {
128        self.previewed_tool_inputs.insert(tool_id);
129    }
130
131    /// Check if a tool input was already previewed.
132    pub fn was_tool_input_previewed(&self, tool_id: &str) -> bool {
133        self.previewed_tool_inputs.contains(tool_id)
134    }
135
136    /// Remove a tool input from previewed set (after processing).
137    pub fn remove_previewed_tool_input(&mut self, tool_id: &str) -> bool {
138        self.previewed_tool_inputs.remove(tool_id)
139    }
140
141    /// Increment todo reminder count for a todo item.
142    /// Returns the new count.
143    pub fn increment_todo_reminder(&mut self, todo_hash: String) -> usize {
144        let count = self.todo_reminder_count.get(&todo_hash).copied().unwrap_or(0) + 1;
145        self.todo_reminder_count.insert(todo_hash, count);
146        count
147    }
148
149    /// Get todo reminder count for a todo item.
150    pub fn todo_reminder_count(&self, todo_hash: &str) -> usize {
151        self.todo_reminder_count.get(todo_hash).copied().unwrap_or(0)
152    }
153
154    /// Get reference to the entire todo reminder count map.
155    pub fn todo_reminder_count_map(&self) -> &std::collections::HashMap<String, usize> {
156        &self.todo_reminder_count
157    }
158
159    /// Get mutable reference to the entire todo reminder count map.
160    pub fn todo_reminder_count_map_mut(&mut self) -> &mut std::collections::HashMap<String, usize> {
161        &mut self.todo_reminder_count
162    }
163
164    /// Check if todo reminder limit reached.
165    pub fn is_todo_reminder_limit_reached(&self, todo_hash: &str, max_reminders: usize) -> bool {
166        self.todo_reminder_count(todo_hash) >= max_reminders
167    }
168
169    /// Get reference to read history tracker.
170    pub fn read_history(&self) -> &ReadHistoryTracker {
171        &self.read_history
172    }
173
174    /// Get mutable reference to read history tracker.
175    pub fn read_history_mut(&mut self) -> &mut ReadHistoryTracker {
176        &mut self.read_history
177    }
178
179    /// Add a pending input to queue.
180    pub fn add_pending_input(&mut self, input: String) {
181        self.pending_inputs.push(input);
182    }
183
184    /// Check if there are pending inputs.
185    pub fn has_pending_inputs(&self) -> bool {
186        !self.pending_inputs.is_empty()
187    }
188
189    /// Get reference to pending inputs vector.
190    pub fn pending_inputs_vec(&self) -> &Vec<String> {
191        &self.pending_inputs
192    }
193
194    /// Get mutable reference to pending inputs vector.
195    pub fn pending_inputs_vec_mut(&mut self) -> &mut Vec<String> {
196        &mut self.pending_inputs
197    }
198
199    /// Take all pending inputs (drains the queue).
200    pub fn take_pending_inputs(&mut self) -> Vec<String> {
201        std::mem::take(&mut self.pending_inputs)
202    }
203
204    /// Get count of pending inputs.
205    pub fn pending_input_count(&self) -> usize {
206        self.pending_inputs.len()
207    }
208
209    /// Get message count.
210    pub fn message_count(&self) -> usize {
211        self.messages.len()
212    }
213
214    /// Clear all state (reset to initial state).
215    pub fn clear(&mut self) {
216        self.messages.clear();
217        self.total_input_tokens.store(0, Ordering::Relaxed);
218        self.total_output_tokens.store(0, Ordering::Relaxed);
219        self.last_input_tokens.store(0, Ordering::Relaxed);
220        self.previewed_tool_inputs.clear();
221        self.todo_reminder_count.clear();
222        self.read_history = ReadHistoryTracker::new();
223        self.pending_inputs.clear();
224    }
225}
226
227impl Default for AgentState {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{MessageContent, Role};

    /// Build a user-role text message for tests.
    fn create_test_message(text: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_state_new_is_empty() {
        let state = AgentState::new();

        assert_eq!(state.message_count(), 0);
        assert_eq!(state.total_input_tokens(), 0);
        assert_eq!(state.total_output_tokens(), 0);
        assert_eq!(state.last_input_tokens(), 0);
        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);
        assert!(state.todo_reminder_count_map().is_empty());
    }

    #[test]
    fn test_state_add_message() {
        let mut state = AgentState::new();

        state.add_message(create_test_message("Hello"));
        state.add_message(create_test_message("World"));

        assert_eq!(state.message_count(), 2);
        assert_eq!(state.messages().len(), 2);
    }

    #[test]
    fn test_state_track_usage() {
        let state = AgentState::new();
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        };

        state.track_usage(&usage);

        assert_eq!(state.total_input_tokens(), 100);
        assert_eq!(state.total_output_tokens(), 50);
        assert_eq!(state.last_input_tokens(), 100);

        // Track again (totals accumulate, "last" is overwritten).
        state.track_usage(&usage);
        assert_eq!(state.total_input_tokens(), 200);
        assert_eq!(state.total_output_tokens(), 100);
        assert_eq!(state.last_input_tokens(), 100);
    }

    #[test]
    fn test_state_token_setters() {
        let state = AgentState::new();

        state.set_total_input_tokens(10);
        state.set_total_output_tokens(20);
        state.set_last_input_tokens(5);

        assert_eq!(state.total_input_tokens(), 10);
        assert_eq!(state.total_output_tokens(), 20);
        assert_eq!(state.last_input_tokens(), 5);
    }

    #[test]
    fn test_state_previewed_tool_inputs() {
        let mut state = AgentState::new();

        // Initially not previewed
        assert!(!state.was_tool_input_previewed("tool_1"));

        // Mark as previewed
        state.mark_tool_input_previewed("tool_1".to_string());
        assert!(state.was_tool_input_previewed("tool_1"));
        assert!(!state.was_tool_input_previewed("tool_2"));

        // Remove previewed
        let removed = state.remove_previewed_tool_input("tool_1");
        assert!(removed, "should return true when removing existing item");
        assert!(!state.was_tool_input_previewed("tool_1"));

        // Remove non-existent
        let removed = state.remove_previewed_tool_input("tool_2");
        assert!(!removed, "should return false when removing non-existent item");
    }

    #[test]
    fn test_state_todo_reminders() {
        let mut state = AgentState::new();
        let todo_hash = "hash_123".to_string();

        // Initially 0
        assert_eq!(state.todo_reminder_count(&todo_hash), 0);
        assert!(!state.is_todo_reminder_limit_reached(&todo_hash, 2));

        // Increment
        let count = state.increment_todo_reminder(todo_hash.clone());
        assert_eq!(count, 1);
        assert_eq!(state.todo_reminder_count(&todo_hash), 1);
        assert!(!state.is_todo_reminder_limit_reached(&todo_hash, 2));

        // Increment again
        let count = state.increment_todo_reminder(todo_hash.clone());
        assert_eq!(count, 2);
        assert!(state.is_todo_reminder_limit_reached(&todo_hash, 2));

        // Increment beyond limit
        let count = state.increment_todo_reminder(todo_hash.clone());
        assert_eq!(count, 3);
        assert!(state.is_todo_reminder_limit_reached(&todo_hash, 2));
    }

    #[test]
    fn test_state_todo_reminders_are_per_hash() {
        let mut state = AgentState::new();

        assert_eq!(state.increment_todo_reminder("a".to_string()), 1);
        assert_eq!(state.increment_todo_reminder("a".to_string()), 2);
        assert_eq!(state.increment_todo_reminder("b".to_string()), 1);

        assert_eq!(state.todo_reminder_count("a"), 2);
        assert_eq!(state.todo_reminder_count("b"), 1);
        assert_eq!(state.todo_reminder_count_map().len(), 2);
    }

    #[test]
    fn test_state_pending_inputs() {
        let mut state = AgentState::new();

        // Initially empty
        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);

        // Add inputs
        state.add_pending_input("input 1".to_string());
        state.add_pending_input("input 2".to_string());

        assert!(state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 2);

        // Take inputs
        let inputs = state.take_pending_inputs();
        assert_eq!(inputs.len(), 2);
        assert_eq!(inputs[0], "input 1");
        assert_eq!(inputs[1], "input 2");

        // Queue drained
        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);
    }

    #[test]
    fn test_state_set_messages() {
        let mut state = AgentState::new();
        state.add_message(create_test_message("Old message"));

        // Replace messages
        let new_messages = vec![
            create_test_message("New 1"),
            create_test_message("New 2"),
        ];
        state.set_messages(new_messages);

        assert_eq!(state.message_count(), 2);
        assert_eq!(state.messages()[0].content, MessageContent::Text("New 1".to_string()));
    }

    #[test]
    fn test_state_clear() {
        let mut state = AgentState::new();

        // Populate every kind of state that `clear` is responsible for.
        state.add_message(create_test_message("Test"));
        state.track_usage(&Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        });
        state.add_pending_input("pending".to_string());
        state.mark_tool_input_previewed("tool_1".to_string());
        state.increment_todo_reminder("hash_123".to_string());

        // Clear
        state.clear();

        // Verify all cleared
        assert_eq!(state.message_count(), 0);
        assert_eq!(state.total_input_tokens(), 0);
        assert_eq!(state.total_output_tokens(), 0);
        assert_eq!(state.last_input_tokens(), 0);
        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);
        assert!(!state.was_tool_input_previewed("tool_1"));
        assert_eq!(state.todo_reminder_count("hash_123"), 0);
        assert!(state.todo_reminder_count_map().is_empty());
    }
}