use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU64, Ordering};
use crate::providers::{ContentBlock, Message, MessageContent, Role, Usage};
use crate::tools::ReadHistoryTracker;
/// Occurrence count at which [`ToolErrorEntry::is_limit_reached`] starts returning `true`.
pub const MAX_SAME_ERROR_COUNT: usize = 3;

/// One distinct (tool, error message) pair together with how often it has been seen.
#[derive(Debug, Clone)]
pub struct ToolErrorEntry {
    /// Name of the tool that produced the error.
    pub tool_name: String,
    /// The error message, truncated to at most 100 characters; used as the comparison key.
    pub error_key: String,
    /// Number of recorded occurrences (a fresh entry starts at 1).
    pub count: usize,
    /// Time of creation or of the most recent [`increment`](Self::increment).
    pub last_occurrence: std::time::Instant,
}

impl ToolErrorEntry {
    /// Maximum number of characters (not bytes) of an error message kept as the key.
    const ERROR_KEY_MAX_CHARS: usize = 100;

    /// Returns the leading part of `error_msg` that is used as the comparison key.
    ///
    /// The result borrows from `error_msg`, so comparing keys needs no allocation. The cut is
    /// made on a character boundary, which keeps multi-byte UTF-8 text intact.
    fn key_for(error_msg: &str) -> &str {
        // The byte offset of the 101st character is exactly where the first 100 end.
        match error_msg.char_indices().nth(Self::ERROR_KEY_MAX_CHARS) {
            Some((cut, _)) => &error_msg[..cut],
            None => error_msg,
        }
    }

    /// Creates an entry for the first occurrence of `error_msg` raised by `tool_name`.
    ///
    /// The stored key is `error_msg` truncated to 100 characters; `count` starts at 1 and
    /// `last_occurrence` is set to the current instant.
    pub fn new(tool_name: &str, error_msg: &str) -> Self {
        Self {
            tool_name: tool_name.to_string(),
            error_key: Self::key_for(error_msg).to_string(),
            count: 1,
            last_occurrence: std::time::Instant::now(),
        }
    }

    /// Returns `true` when `tool_name` equals the stored tool name and `error_msg`, truncated
    /// the same way as in [`new`](Self::new), equals the stored key.
    ///
    /// Messages that differ only after their first 100 characters are therefore considered
    /// the same error.
    pub fn matches(&self, tool_name: &str, error_msg: &str) -> bool {
        self.tool_name == tool_name && self.error_key == Self::key_for(error_msg)
    }

    /// Records one more occurrence and refreshes `last_occurrence`.
    pub fn increment(&mut self) {
        self.count += 1;
        self.last_occurrence = std::time::Instant::now();
    }

    /// Returns `true` once the entry has been seen at least [`MAX_SAME_ERROR_COUNT`] times.
    pub fn is_limit_reached(&self) -> bool {
        self.count >= MAX_SAME_ERROR_COUNT
    }
}
/// Mutable state kept for one agent conversation: the message list, token counters and
/// several small bookkeeping collections.
pub struct AgentState {
// Conversation messages in insertion order.
messages: Vec<Message>,
// Sum of `input_tokens` over every `track_usage` call (or the value last set explicitly).
total_input_tokens: AtomicU64,
// Sum of `output_tokens` over every `track_usage` call (or the value last set explicitly).
total_output_tokens: AtomicU64,
// `input_tokens` of the most recent `track_usage` call (or the value last set explicitly).
last_input_tokens: AtomicU64,
// Tool ids that have been marked as previewed.
previewed_tool_inputs: HashSet<String>,
// Per-todo-hash counter of reminders issued so far.
todo_reminder_count: HashMap<String, usize>,
// Read-history tracker from `crate::tools`; replaced with a fresh one by `clear`.
read_history: ReadHistoryTracker,
// Inputs queued via `add_pending_input` until `take_pending_inputs` drains them.
pending_inputs: Vec<String>,
// One entry per distinct (tool, error key) pair recorded via `record_tool_error`.
error_history: Vec<ToolErrorEntry>,
}
impl AgentState {
    /// Creates an empty state: no messages, zeroed token counters and empty bookkeeping
    /// collections.
    pub fn new() -> Self {
        Self {
            messages: Vec::new(),
            total_input_tokens: AtomicU64::new(0),
            total_output_tokens: AtomicU64::new(0),
            last_input_tokens: AtomicU64::new(0),
            previewed_tool_inputs: HashSet::new(),
            todo_reminder_count: HashMap::new(),
            read_history: ReadHistoryTracker::new(),
            pending_inputs: Vec::new(),
            error_history: Vec::new(),
        }
    }

    /// Appends `message` to the end of the conversation. No orphan cleaning is performed.
    pub fn add_message(&mut self, message: Message) {
        self.messages.push(message);
    }

    /// Returns the conversation messages in insertion order.
    pub fn messages(&self) -> &Vec<Message> {
        &self.messages
    }

    /// Returns mutable access to the message list. Edits made through this reference are not
    /// checked for orphaned tool blocks.
    pub fn messages_mut(&mut self) -> &mut Vec<Message> {
        &mut self.messages
    }

    /// Replaces the whole conversation with `messages` after removing orphaned tool blocks
    /// (see `clean_orphaned_messages`).
    pub fn set_messages(&mut self, messages: Vec<Message>) {
        self.messages = Self::clean_orphaned_messages(messages);
    }

    /// Removes tool blocks whose counterpart is missing and returns the remaining messages.
    ///
    /// * A `ToolUse` block (in a message of any role) is orphaned when no `ToolResult` inside
    ///   a `Role::Tool` message carries its id.
    /// * A `ToolResult` block inside a `Role::Tool` message is orphaned when no `ToolUse`
    ///   block carries its `tool_use_id`.
    ///
    /// Only the orphaned blocks themselves are removed. A `Role::Tool` message that holds
    /// both valid and orphaned results keeps its valid ones; dropping the whole message
    /// would leave the matching `ToolUse` blocks unanswered. A block message that ends up
    /// with no blocks is dropped. Messages whose content is not `Blocks` pass through
    /// untouched, and when nothing is orphaned the input is returned as is.
    ///
    /// NOTE(review): `ToolResult` blocks in messages whose role is not `Role::Tool` are
    /// neither counted nor cleaned — confirm that tool results always use `Role::Tool`.
    fn clean_orphaned_messages(messages: Vec<Message>) -> Vec<Message> {
        // One pass collects both id sets (an empty input simply yields two empty sets).
        let mut tool_use_ids: HashSet<String> = HashSet::new();
        let mut tool_result_ids: HashSet<String> = HashSet::new();
        for msg in &messages {
            if let MessageContent::Blocks(blocks) = &msg.content {
                let is_tool_msg = msg.role == Role::Tool;
                for block in blocks {
                    match block {
                        ContentBlock::ToolUse { id, .. } => {
                            tool_use_ids.insert(id.clone());
                        }
                        ContentBlock::ToolResult { tool_use_id, .. } if is_tool_msg => {
                            tool_result_ids.insert(tool_use_id.clone());
                        }
                        _ => {}
                    }
                }
            }
        }

        let orphaned_tool_use_ids: HashSet<&str> = tool_use_ids
            .difference(&tool_result_ids)
            .map(String::as_str)
            .collect();
        let orphaned_tool_result_ids: HashSet<&str> = tool_result_ids
            .difference(&tool_use_ids)
            .map(String::as_str)
            .collect();

        if orphaned_tool_use_ids.is_empty() && orphaned_tool_result_ids.is_empty() {
            return messages;
        }

        log::warn!(
            "Cleaning orphaned messages: {} tool_uses without results, {} tool_results without uses",
            orphaned_tool_use_ids.len(),
            orphaned_tool_result_ids.len()
        );

        let original_len = messages.len();
        let mut cleaned = Vec::with_capacity(original_len);
        for msg in messages {
            // Evaluated before `msg.content` is moved out below.
            let is_tool_msg = msg.role == Role::Tool;
            if let MessageContent::Blocks(blocks) = msg.content {
                let kept: Vec<ContentBlock> = blocks
                    .into_iter()
                    .filter(|block| match block {
                        ContentBlock::ToolUse { id, .. }
                            if orphaned_tool_use_ids.contains(id.as_str()) =>
                        {
                            log::info!("Removing orphaned tool_use block: {}", id);
                            false
                        }
                        ContentBlock::ToolResult { tool_use_id, .. }
                            if is_tool_msg
                                && orphaned_tool_result_ids.contains(tool_use_id.as_str()) =>
                        {
                            log::info!("Removing orphaned tool_result block: {}", tool_use_id);
                            false
                        }
                        _ => true,
                    })
                    .collect();
                // A message left without any block carries no content and is dropped.
                if !kept.is_empty() {
                    cleaned.push(Message {
                        role: msg.role,
                        content: MessageContent::Blocks(kept),
                    });
                }
            } else {
                cleaned.push(msg);
            }
        }

        log::info!(
            "Message cleaning complete: {} messages -> {} messages",
            original_len,
            cleaned.len()
        );
        cleaned
    }

    /// Adds `usage.input_tokens` / `usage.output_tokens` to the running totals and stores
    /// `usage.input_tokens` as the last-input value. All atomics use `Ordering::Relaxed`.
    pub fn track_usage(&self, usage: &Usage) {
        self.total_input_tokens
            .fetch_add(usage.input_tokens as u64, Ordering::Relaxed);
        self.total_output_tokens
            .fetch_add(usage.output_tokens as u64, Ordering::Relaxed);
        self.last_input_tokens
            .store(usage.input_tokens as u64, Ordering::Relaxed);
    }

    /// Returns the accumulated input-token total.
    pub fn total_input_tokens(&self) -> u64 {
        self.total_input_tokens.load(Ordering::Relaxed)
    }

    /// Returns the accumulated output-token total.
    pub fn total_output_tokens(&self) -> u64 {
        self.total_output_tokens.load(Ordering::Relaxed)
    }

    /// Returns the input-token count of the most recent `track_usage` call (or the value last
    /// set via `set_last_input_tokens`).
    pub fn last_input_tokens(&self) -> u64 {
        self.last_input_tokens.load(Ordering::Relaxed)
    }

    /// Overwrites the accumulated input-token total with `value`.
    pub fn set_total_input_tokens(&self, value: u64) {
        self.total_input_tokens.store(value, Ordering::Relaxed);
    }

    /// Overwrites the accumulated output-token total with `value`.
    pub fn set_total_output_tokens(&self, value: u64) {
        self.total_output_tokens.store(value, Ordering::Relaxed);
    }

    /// Overwrites the last-input token count with `value`.
    pub fn set_last_input_tokens(&self, value: u64) {
        self.last_input_tokens.store(value, Ordering::Relaxed);
    }

    /// Marks `tool_id` as previewed. Marking the same id again has no further effect.
    pub fn mark_tool_input_previewed(&mut self, tool_id: String) {
        self.previewed_tool_inputs.insert(tool_id);
    }

    /// Returns `true` if `tool_id` is currently marked as previewed.
    pub fn was_tool_input_previewed(&self, tool_id: &str) -> bool {
        self.previewed_tool_inputs.contains(tool_id)
    }

    /// Removes the previewed mark for `tool_id`; returns `true` if the mark was present.
    pub fn remove_previewed_tool_input(&mut self, tool_id: &str) -> bool {
        self.previewed_tool_inputs.remove(tool_id)
    }

    /// Increments the reminder counter for `todo_hash` (starting from 0 when unseen) and
    /// returns the new value.
    pub fn increment_todo_reminder(&mut self, todo_hash: String) -> usize {
        // The entry API performs a single lookup instead of a `get` followed by an `insert`.
        let count = self.todo_reminder_count.entry(todo_hash).or_insert(0);
        *count += 1;
        *count
    }

    /// Returns the reminder counter for `todo_hash`, or 0 when none has been recorded.
    pub fn todo_reminder_count(&self, todo_hash: &str) -> usize {
        self.todo_reminder_count.get(todo_hash).copied().unwrap_or(0)
    }

    /// Returns the whole todo-hash -> reminder-count map.
    pub fn todo_reminder_count_map(&self) -> &HashMap<String, usize> {
        &self.todo_reminder_count
    }

    /// Returns mutable access to the todo-hash -> reminder-count map.
    pub fn todo_reminder_count_map_mut(&mut self) -> &mut HashMap<String, usize> {
        &mut self.todo_reminder_count
    }

    /// Returns `true` when the reminder counter for `todo_hash` is at least `max_reminders`.
    pub fn is_todo_reminder_limit_reached(&self, todo_hash: &str, max_reminders: usize) -> bool {
        self.todo_reminder_count(todo_hash) >= max_reminders
    }

    /// Returns the read-history tracker.
    pub fn read_history(&self) -> &ReadHistoryTracker {
        &self.read_history
    }

    /// Returns mutable access to the read-history tracker.
    pub fn read_history_mut(&mut self) -> &mut ReadHistoryTracker {
        &mut self.read_history
    }

    /// Queues `input` behind any inputs that are already pending.
    pub fn add_pending_input(&mut self, input: String) {
        self.pending_inputs.push(input);
    }

    /// Returns `true` when at least one input is pending.
    pub fn has_pending_inputs(&self) -> bool {
        !self.pending_inputs.is_empty()
    }

    /// Returns the pending inputs in the order they were queued.
    pub fn pending_inputs_vec(&self) -> &Vec<String> {
        &self.pending_inputs
    }

    /// Returns mutable access to the pending-input queue.
    pub fn pending_inputs_vec_mut(&mut self) -> &mut Vec<String> {
        &mut self.pending_inputs
    }

    /// Moves all pending inputs out, leaving the queue empty.
    pub fn take_pending_inputs(&mut self) -> Vec<String> {
        std::mem::take(&mut self.pending_inputs)
    }

    /// Returns the number of pending inputs.
    pub fn pending_input_count(&self) -> usize {
        self.pending_inputs.len()
    }

    /// Returns the number of messages in the conversation.
    pub fn message_count(&self) -> usize {
        self.messages.len()
    }

    /// Records one occurrence of `error_msg` for `tool_name` and returns how many times this
    /// (tool, error) pair has now been seen.
    ///
    /// The first entry that `ToolErrorEntry::matches` the pair is incremented; if there is
    /// none, a new entry with count 1 is appended.
    pub fn record_tool_error(&mut self, tool_name: &str, error_msg: &str) -> usize {
        if let Some(entry) = self
            .error_history
            .iter_mut()
            .find(|entry| entry.matches(tool_name, error_msg))
        {
            entry.increment();
            return entry.count;
        }

        let entry = ToolErrorEntry::new(tool_name, error_msg);
        let count = entry.count;
        self.error_history.push(entry);
        count
    }

    /// Returns the first entry that matches the pair and has reached its repetition limit,
    /// or `None` when there is no such entry.
    pub fn check_error_limit(&self, tool_name: &str, error_msg: &str) -> Option<&ToolErrorEntry> {
        self.error_history
            .iter()
            .find(|entry| entry.matches(tool_name, error_msg) && entry.is_limit_reached())
    }

    /// Returns how often the (tool, error) pair has been recorded; 0 when it never was.
    pub fn error_count(&self, tool_name: &str, error_msg: &str) -> usize {
        self.error_history
            .iter()
            .find(|entry| entry.matches(tool_name, error_msg))
            .map_or(0, |entry| entry.count)
    }

    /// Forgets every recorded tool error.
    pub fn clear_error_history(&mut self) {
        self.error_history.clear();
    }

    /// Returns the number of distinct (tool, error) pairs recorded.
    pub fn unique_error_count(&self) -> usize {
        self.error_history.len()
    }

    /// Returns the number of distinct (tool, error) pairs recorded more than once.
    pub fn repeated_error_count(&self) -> usize {
        self.error_history
            .iter()
            .filter(|entry| entry.count > 1)
            .count()
    }

    /// Resets everything to the freshly constructed state: messages, token counters,
    /// previewed ids, todo reminders, read history, pending inputs and error history.
    pub fn clear(&mut self) {
        self.messages.clear();
        self.total_input_tokens.store(0, Ordering::Relaxed);
        self.total_output_tokens.store(0, Ordering::Relaxed);
        self.last_input_tokens.store(0, Ordering::Relaxed);
        self.previewed_tool_inputs.clear();
        self.todo_reminder_count.clear();
        self.read_history = ReadHistoryTracker::new();
        self.pending_inputs.clear();
        self.error_history.clear();
    }
}
impl Default for AgentState {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{MessageContent, Role};

    /// Builds a plain-text user message for use in the tests below.
    fn create_test_message(text: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_state_new_is_empty() {
        let state = AgentState::new();
        assert_eq!(state.message_count(), 0);
        assert_eq!(state.total_input_tokens(), 0);
        assert_eq!(state.total_output_tokens(), 0);
        assert_eq!(state.last_input_tokens(), 0);
        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);
    }

    #[test]
    fn test_state_add_message() {
        let mut state = AgentState::new();
        state.add_message(create_test_message("Hello"));
        state.add_message(create_test_message("World"));
        assert_eq!(state.message_count(), 2);
        assert_eq!(state.messages().len(), 2);
    }

    #[test]
    fn test_state_track_usage() {
        let state = AgentState::new();
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        };
        state.track_usage(&usage);
        assert_eq!(state.total_input_tokens(), 100);
        assert_eq!(state.total_output_tokens(), 50);
        assert_eq!(state.last_input_tokens(), 100);
        // Totals accumulate; the last-input value is overwritten, not summed.
        state.track_usage(&usage);
        assert_eq!(state.total_input_tokens(), 200);
        assert_eq!(state.total_output_tokens(), 100);
        assert_eq!(state.last_input_tokens(), 100);
    }

    #[test]
    fn test_state_previewed_tool_inputs() {
        let mut state = AgentState::new();
        assert!(!state.was_tool_input_previewed("tool_1"));
        state.mark_tool_input_previewed("tool_1".to_string());
        assert!(state.was_tool_input_previewed("tool_1"));
        assert!(!state.was_tool_input_previewed("tool_2"));
        let removed = state.remove_previewed_tool_input("tool_1");
        assert!(removed, "should return true when removing existing item");
        assert!(!state.was_tool_input_previewed("tool_1"));
        let removed = state.remove_previewed_tool_input("tool_2");
        assert!(!removed, "should return false when removing non-existent item");
    }

    #[test]
    fn test_state_todo_reminders() {
        let mut state = AgentState::new();
        let todo_hash = "hash_123".to_string();
        assert_eq!(state.todo_reminder_count(&todo_hash), 0);
        assert!(!state.is_todo_reminder_limit_reached(&todo_hash, 2));
        let count = state.increment_todo_reminder(todo_hash.clone());
        assert_eq!(count, 1);
        assert_eq!(state.todo_reminder_count(&todo_hash), 1);
        assert!(!state.is_todo_reminder_limit_reached(&todo_hash, 2));
        let count = state.increment_todo_reminder(todo_hash.clone());
        assert_eq!(count, 2);
        assert!(state.is_todo_reminder_limit_reached(&todo_hash, 2));
        let count = state.increment_todo_reminder(todo_hash.clone());
        assert_eq!(count, 3);
        assert!(state.is_todo_reminder_limit_reached(&todo_hash, 2));
    }

    #[test]
    fn test_state_pending_inputs() {
        let mut state = AgentState::new();
        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);
        state.add_pending_input("input 1".to_string());
        state.add_pending_input("input 2".to_string());
        assert!(state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 2);
        let inputs = state.take_pending_inputs();
        assert_eq!(inputs.len(), 2);
        assert_eq!(inputs[0], "input 1");
        assert_eq!(inputs[1], "input 2");
        assert!(!state.has_pending_inputs());
        assert_eq!(state.pending_input_count(), 0);
    }

    #[test]
    fn test_state_set_messages() {
        let mut state = AgentState::new();
        state.add_message(create_test_message("Old message"));
        let new_messages = vec![
            create_test_message("New 1"),
            create_test_message("New 2"),
        ];
        state.set_messages(new_messages);
        assert_eq!(state.message_count(), 2);
        assert_eq!(state.messages()[0].content, MessageContent::Text("New 1".to_string()));
    }

    #[test]
    fn test_state_tool_error_tracking() {
        let mut state = AgentState::new();
        assert_eq!(state.error_count("bash", "boom"), 0);
        assert_eq!(state.unique_error_count(), 0);
        assert_eq!(state.repeated_error_count(), 0);
        assert!(state.check_error_limit("bash", "boom").is_none());

        assert_eq!(state.record_tool_error("bash", "boom"), 1);
        assert_eq!(state.record_tool_error("bash", "boom"), 2);
        assert_eq!(state.error_count("bash", "boom"), 2);
        assert!(state.check_error_limit("bash", "boom").is_none());

        // The same message from another tool is tracked as a separate entry.
        assert_eq!(state.record_tool_error("edit", "boom"), 1);
        assert_eq!(state.unique_error_count(), 2);
        assert_eq!(state.repeated_error_count(), 1);

        assert_eq!(state.record_tool_error("bash", "boom"), MAX_SAME_ERROR_COUNT);
        let limited = state
            .check_error_limit("bash", "boom")
            .expect("limit should be reached after MAX_SAME_ERROR_COUNT occurrences");
        assert_eq!(limited.tool_name, "bash");
        assert_eq!(limited.count, MAX_SAME_ERROR_COUNT);
        assert!(state.check_error_limit("edit", "boom").is_none());

        state.clear_error_history();
        assert_eq!(state.unique_error_count(), 0);
        assert_eq!(state.error_count("bash", "boom"), 0);
        assert!(state.check_error_limit("bash", "boom").is_none());
    }

    #[test]
    fn test_state_clear() {
        let mut state = AgentState::new();
        state.add_message(create_test_message("Test"));
        state.track_usage(&Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        });
        state.add_pending_input("pending".to_string());
        state.mark_tool_input_previewed("tool_1".to_string());
        state.increment_todo_reminder("hash_123".to_string());
        state.record_tool_error("bash", "boom");
        state.clear();
        assert_eq!(state.message_count(), 0);
        assert_eq!(state.total_input_tokens(), 0);
        assert_eq!(state.total_output_tokens(), 0);
        assert_eq!(state.last_input_tokens(), 0);
        assert!(!state.has_pending_inputs());
        assert!(!state.was_tool_input_previewed("tool_1"));
        assert_eq!(state.todo_reminder_count("hash_123"), 0);
        assert_eq!(state.unique_error_count(), 0);
    }
}