matrixcode_core/agent/core/
state.rs1use std::collections::{HashMap, HashSet};
16use std::sync::atomic::{AtomicU64, Ordering};
17
18use crate::providers::{Message, Usage};
19use crate::tools::ReadHistoryTracker;
20
21pub struct AgentState {
26 messages: Vec<Message>,
28
29 total_input_tokens: AtomicU64,
31
32 total_output_tokens: AtomicU64,
34
35 last_input_tokens: AtomicU64,
37
38 previewed_tool_inputs: HashSet<String>,
41
42 todo_reminder_count: HashMap<String, usize>,
45
46 read_history: ReadHistoryTracker,
49
50 pending_inputs: Vec<String>,
52}
53
54impl AgentState {
55 pub fn new() -> Self {
57 Self {
58 messages: Vec::new(),
59 total_input_tokens: AtomicU64::new(0),
60 total_output_tokens: AtomicU64::new(0),
61 last_input_tokens: AtomicU64::new(0),
62 previewed_tool_inputs: HashSet::new(),
63 todo_reminder_count: HashMap::new(),
64 read_history: ReadHistoryTracker::new(),
65 pending_inputs: Vec::new(),
66 }
67 }
68
69 pub fn add_message(&mut self, message: Message) {
71 self.messages.push(message);
72 }
73
74 pub fn messages(&self) -> &Vec<Message> {
76 &self.messages
77 }
78
79 pub fn messages_mut(&mut self) -> &mut Vec<Message> {
81 &mut self.messages
82 }
83
84 pub fn set_messages(&mut self, messages: Vec<Message>) {
86 self.messages = messages;
87 }
88
89 pub fn track_usage(&self, usage: &Usage) {
91 self.total_input_tokens.fetch_add(usage.input_tokens as u64, Ordering::Relaxed);
92 self.total_output_tokens.fetch_add(usage.output_tokens as u64, Ordering::Relaxed);
93 self.last_input_tokens.store(usage.input_tokens as u64, Ordering::Relaxed);
94 }
95
96 pub fn total_input_tokens(&self) -> u64 {
98 self.total_input_tokens.load(Ordering::Relaxed)
99 }
100
101 pub fn total_output_tokens(&self) -> u64 {
103 self.total_output_tokens.load(Ordering::Relaxed)
104 }
105
106 pub fn last_input_tokens(&self) -> u64 {
108 self.last_input_tokens.load(Ordering::Relaxed)
109 }
110
111 pub fn set_total_input_tokens(&self, value: u64) {
113 self.total_input_tokens.store(value, Ordering::Relaxed);
114 }
115
116 pub fn set_total_output_tokens(&self, value: u64) {
118 self.total_output_tokens.store(value, Ordering::Relaxed);
119 }
120
121 pub fn set_last_input_tokens(&self, value: u64) {
123 self.last_input_tokens.store(value, Ordering::Relaxed);
124 }
125
126 pub fn mark_tool_input_previewed(&mut self, tool_id: String) {
128 self.previewed_tool_inputs.insert(tool_id);
129 }
130
131 pub fn was_tool_input_previewed(&self, tool_id: &str) -> bool {
133 self.previewed_tool_inputs.contains(tool_id)
134 }
135
136 pub fn remove_previewed_tool_input(&mut self, tool_id: &str) -> bool {
138 self.previewed_tool_inputs.remove(tool_id)
139 }
140
141 pub fn increment_todo_reminder(&mut self, todo_hash: String) -> usize {
144 let count = self.todo_reminder_count.get(&todo_hash).copied().unwrap_or(0) + 1;
145 self.todo_reminder_count.insert(todo_hash, count);
146 count
147 }
148
149 pub fn todo_reminder_count(&self, todo_hash: &str) -> usize {
151 self.todo_reminder_count.get(todo_hash).copied().unwrap_or(0)
152 }
153
154 pub fn todo_reminder_count_map(&self) -> &std::collections::HashMap<String, usize> {
156 &self.todo_reminder_count
157 }
158
159 pub fn todo_reminder_count_map_mut(&mut self) -> &mut std::collections::HashMap<String, usize> {
161 &mut self.todo_reminder_count
162 }
163
164 pub fn is_todo_reminder_limit_reached(&self, todo_hash: &str, max_reminders: usize) -> bool {
166 self.todo_reminder_count(todo_hash) >= max_reminders
167 }
168
169 pub fn read_history(&self) -> &ReadHistoryTracker {
171 &self.read_history
172 }
173
174 pub fn read_history_mut(&mut self) -> &mut ReadHistoryTracker {
176 &mut self.read_history
177 }
178
179 pub fn add_pending_input(&mut self, input: String) {
181 self.pending_inputs.push(input);
182 }
183
184 pub fn has_pending_inputs(&self) -> bool {
186 !self.pending_inputs.is_empty()
187 }
188
189 pub fn pending_inputs_vec(&self) -> &Vec<String> {
191 &self.pending_inputs
192 }
193
194 pub fn pending_inputs_vec_mut(&mut self) -> &mut Vec<String> {
196 &mut self.pending_inputs
197 }
198
199 pub fn take_pending_inputs(&mut self) -> Vec<String> {
201 std::mem::take(&mut self.pending_inputs)
202 }
203
204 pub fn pending_input_count(&self) -> usize {
206 self.pending_inputs.len()
207 }
208
209 pub fn message_count(&self) -> usize {
211 self.messages.len()
212 }
213
214 pub fn clear(&mut self) {
216 self.messages.clear();
217 self.total_input_tokens.store(0, Ordering::Relaxed);
218 self.total_output_tokens.store(0, Ordering::Relaxed);
219 self.last_input_tokens.store(0, Ordering::Relaxed);
220 self.previewed_tool_inputs.clear();
221 self.todo_reminder_count.clear();
222 self.read_history = ReadHistoryTracker::new();
223 self.pending_inputs.clear();
224 }
225}
226
227impl Default for AgentState {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{MessageContent, Role};

    /// Builds a user message carrying plain `text`.
    fn create_test_message(text: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_state_new_is_empty() {
        let state = AgentState::new();

        assert_eq!(state.message_count(), 0);
        assert_eq!(state.total_input_tokens(), 0);
        assert_eq!(state.total_output_tokens(), 0);
        assert_eq!(state.last_input_tokens(), 0);
        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);
        assert!(state.pending_inputs_vec().is_empty());
        assert!(state.todo_reminder_count_map().is_empty());
    }

    #[test]
    fn test_state_default_matches_new() {
        let state = AgentState::default();

        assert_eq!(state.message_count(), 0);
        assert_eq!(state.total_input_tokens(), 0);
        assert_eq!(state.total_output_tokens(), 0);
        assert_eq!(state.last_input_tokens(), 0);
        assert!(!state.has_pending_inputs());
    }

    #[test]
    fn test_state_add_message() {
        let mut state = AgentState::new();

        state.add_message(create_test_message("Hello"));
        state.add_message(create_test_message("World"));

        assert_eq!(state.message_count(), 2);
        assert_eq!(state.messages().len(), 2);
        // Messages are kept in insertion order.
        assert_eq!(state.messages()[0].content, MessageContent::Text("Hello".to_string()));
        assert_eq!(state.messages()[1].content, MessageContent::Text("World".to_string()));
    }

    #[test]
    fn test_state_messages_mut() {
        let mut state = AgentState::new();

        state.messages_mut().push(create_test_message("via mut"));

        assert_eq!(state.message_count(), 1);
        assert_eq!(state.messages()[0].content, MessageContent::Text("via mut".to_string()));
    }

    #[test]
    fn test_state_track_usage() {
        let state = AgentState::new();
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        };

        state.track_usage(&usage);

        assert_eq!(state.total_input_tokens(), 100);
        assert_eq!(state.total_output_tokens(), 50);
        assert_eq!(state.last_input_tokens(), 100);

        state.track_usage(&usage);

        assert_eq!(state.total_input_tokens(), 200);
        assert_eq!(state.total_output_tokens(), 100);
        assert_eq!(state.last_input_tokens(), 100);
    }

    #[test]
    fn test_state_token_setters() {
        let state = AgentState::new();

        state.set_total_input_tokens(10);
        state.set_total_output_tokens(20);
        state.set_last_input_tokens(30);

        assert_eq!(state.total_input_tokens(), 10);
        assert_eq!(state.total_output_tokens(), 20);
        assert_eq!(state.last_input_tokens(), 30);
    }

    #[test]
    fn test_state_previewed_tool_inputs() {
        let mut state = AgentState::new();

        assert!(!state.was_tool_input_previewed("tool_1"));

        state.mark_tool_input_previewed("tool_1".to_string());

        assert!(state.was_tool_input_previewed("tool_1"));
        assert!(!state.was_tool_input_previewed("tool_2"));

        let removed = state.remove_previewed_tool_input("tool_1");

        assert!(removed, "should return true when removing existing item");
        assert!(!state.was_tool_input_previewed("tool_1"));

        let removed = state.remove_previewed_tool_input("tool_2");

        assert!(!removed, "should return false when removing non-existent item");
    }

    #[test]
    fn test_state_todo_reminders() {
        let mut state = AgentState::new();
        let todo_hash = "hash_123".to_string();

        assert_eq!(state.todo_reminder_count(&todo_hash), 0);
        assert!(!state.is_todo_reminder_limit_reached(&todo_hash, 2));

        let count = state.increment_todo_reminder(todo_hash.clone());

        assert_eq!(count, 1);
        assert_eq!(state.todo_reminder_count(&todo_hash), 1);
        assert!(!state.is_todo_reminder_limit_reached(&todo_hash, 2));

        let count = state.increment_todo_reminder(todo_hash.clone());

        assert_eq!(count, 2);
        assert!(state.is_todo_reminder_limit_reached(&todo_hash, 2));

        let count = state.increment_todo_reminder(todo_hash.clone());

        assert_eq!(count, 3);
        assert!(state.is_todo_reminder_limit_reached(&todo_hash, 2));

        // Counts are tracked per hash.
        assert_eq!(state.todo_reminder_count("other_hash"), 0);
        assert_eq!(state.increment_todo_reminder("other_hash".to_string()), 1);
        assert_eq!(state.todo_reminder_count(&todo_hash), 3);
        assert_eq!(state.todo_reminder_count_map().len(), 2);
    }

    #[test]
    fn test_state_pending_inputs() {
        let mut state = AgentState::new();

        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);

        state.add_pending_input("input 1".to_string());
        state.add_pending_input("input 2".to_string());

        assert!(state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 2);
        assert_eq!(state.pending_inputs_vec()[0], "input 1");

        let inputs = state.take_pending_inputs();

        assert_eq!(inputs.len(), 2);
        assert_eq!(inputs[0], "input 1");
        assert_eq!(inputs[1], "input 2");

        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);

        state.pending_inputs_vec_mut().push("input 3".to_string());
        assert_eq!(state.pending_input_count(), 1);
    }

    #[test]
    fn test_state_set_messages() {
        let mut state = AgentState::new();
        state.add_message(create_test_message("Old message"));

        let new_messages = vec![
            create_test_message("New 1"),
            create_test_message("New 2"),
        ];
        state.set_messages(new_messages);

        assert_eq!(state.message_count(), 2);
        assert_eq!(state.messages()[0].content, MessageContent::Text("New 1".to_string()));
    }

    #[test]
    fn test_state_clear() {
        let mut state = AgentState::new();

        state.add_message(create_test_message("Test"));
        state.track_usage(&Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        });
        state.add_pending_input("pending".to_string());
        state.mark_tool_input_previewed("tool_1".to_string());
        state.increment_todo_reminder("hash_123".to_string());

        state.clear();

        assert_eq!(state.message_count(), 0);
        assert_eq!(state.total_input_tokens(), 0);
        assert_eq!(state.total_output_tokens(), 0);
        assert_eq!(state.last_input_tokens(), 0);
        assert!(!state.has_pending_inputs());
        assert!(!state.was_tool_input_previewed("tool_1"));
        assert_eq!(state.todo_reminder_count("hash_123"), 0);
        assert!(state.todo_reminder_count_map().is_empty());
    }
}