use crate::agent::{AgentMessage, AgentTool};
use crate::thinking::ThinkingLevel;
use crate::types::Model;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
/// Agent state stored behind interior mutability.
///
/// Every field is either a `parking_lot::RwLock` or an atomic, so all of the
/// mutators on this type take `&self` rather than `&mut self`.
#[derive(Debug)]
pub struct AgentState {
/// System prompt text; empty by default.
pub system_prompt: RwLock<String>,
/// Tools registered for the agent; replaced wholesale by `set_tools`.
pub tools: RwLock<Vec<AgentTool>>,
/// Message history in insertion order (new messages are pushed to the end).
/// `add_message` and `set_max_messages` drop the oldest (front) entries to
/// respect `max_messages`.
pub messages: RwLock<Vec<AgentMessage>>,
/// Streaming flag, read via `is_streaming()` and written via `set_streaming()`.
pub is_streaming: AtomicBool,
/// Optional in-flight message. NOTE(review): presumably the message currently
/// being streamed — confirm with the code that writes it.
pub stream_message: RwLock<Option<AgentMessage>>,
/// Set of pending identifiers. NOTE(review): presumably tool-call IDs — verify
/// against callers.
pub pending_tool_calls: RwLock<HashSet<String>>,
/// Optional error text; cleared by `reset`.
pub error: RwLock<Option<String>>,
/// Upper bound on `messages.len()`; `0` disables the cap.
pub max_messages: AtomicUsize,
}
impl AgentState {
    /// Creates an empty state.
    ///
    /// The new state has no system prompt, tools, messages, stream message,
    /// pending tool calls, or error. It is not streaming and has no message cap
    /// (`max_messages == 0`).
    pub fn new() -> Self {
        Self {
            system_prompt: RwLock::new(String::new()),
            tools: RwLock::new(Vec::new()),
            messages: RwLock::new(Vec::new()),
            is_streaming: AtomicBool::new(false),
            stream_message: RwLock::new(None),
            pending_tool_calls: RwLock::new(HashSet::new()),
            error: RwLock::new(None),
            max_messages: AtomicUsize::new(0),
        }
    }

    /// Drops the oldest (front) entries of `msgs` so that at most `max` remain.
    ///
    /// A `max` of `0` means "no cap" and leaves `msgs` untouched.
    fn trim_to_max(msgs: &mut Vec<AgentMessage>, max: usize) {
        if max > 0 && msgs.len() > max {
            let excess = msgs.len() - max;
            msgs.drain(..excess);
        }
    }

    /// Replaces the system prompt.
    pub fn set_system_prompt(&self, prompt: impl Into<String>) {
        *self.system_prompt.write() = prompt.into();
    }

    /// Replaces the whole tool list.
    pub fn set_tools(&self, tools: Vec<AgentTool>) {
        *self.tools.write() = tools;
    }

    /// Appends `message` to the history, then evicts the oldest messages if the
    /// history now exceeds `max_messages` (when a cap is set).
    pub fn add_message(&self, message: AgentMessage) {
        let mut msgs = self.messages.write();
        msgs.push(message);
        // The cap is read while holding the write lock, so the trim applies to
        // exactly the vector we just pushed to.
        Self::trim_to_max(&mut msgs, self.max_messages.load(Ordering::SeqCst));
    }

    /// Sets the history cap (`0` = unlimited) and immediately trims the current
    /// history down to it.
    pub fn set_max_messages(&self, max: usize) {
        self.max_messages.store(max, Ordering::SeqCst);
        // Skip taking the write lock entirely when the cap is disabled.
        if max > 0 {
            Self::trim_to_max(&mut self.messages.write(), max);
        }
    }

    /// Returns the current history cap (`0` = unlimited).
    pub fn get_max_messages(&self) -> usize {
        self.max_messages.load(Ordering::SeqCst)
    }

    /// Replaces the whole history with `messages`.
    ///
    /// The `max_messages` cap is enforced here as well, keeping only the newest
    /// (last) `max_messages` entries. This matches `add_message` and
    /// `set_max_messages`, so the history can never exceed the configured cap.
    pub fn replace_messages(&self, messages: Vec<AgentMessage>) {
        let mut msgs = self.messages.write();
        *msgs = messages;
        Self::trim_to_max(&mut msgs, self.max_messages.load(Ordering::SeqCst));
    }

    /// Removes every message from the history.
    pub fn clear_messages(&self) {
        self.messages.write().clear();
    }

    /// Restores every field to its `new()` value except `max_messages`.
    ///
    /// The history cap is configuration rather than per-run state, so it is
    /// intentionally left as-is.
    pub fn reset(&self) {
        *self.system_prompt.write() = String::new();
        *self.tools.write() = Vec::new();
        self.messages.write().clear();
        self.is_streaming.store(false, Ordering::SeqCst);
        *self.stream_message.write() = None;
        self.pending_tool_calls.write().clear();
        *self.error.write() = None;
    }

    /// Returns the streaming flag.
    pub fn is_streaming(&self) -> bool {
        self.is_streaming.load(Ordering::SeqCst)
    }

    /// Sets the streaming flag.
    pub fn set_streaming(&self, value: bool) {
        self.is_streaming.store(value, Ordering::SeqCst);
    }

    /// Returns the number of messages currently in the history.
    pub fn message_count(&self) -> usize {
        self.messages.read().len()
    }
}
impl Default for AgentState {
fn default() -> Self {
Self::new()
}
}
impl Clone for AgentState {
fn clone(&self) -> Self {
Self {
system_prompt: RwLock::new(self.system_prompt.read().clone()),
tools: RwLock::new(self.tools.read().clone()),
messages: RwLock::new(self.messages.read().clone()),
is_streaming: AtomicBool::new(self.is_streaming.load(Ordering::SeqCst)),
stream_message: RwLock::new(self.stream_message.read().clone()),
pending_tool_calls: RwLock::new(self.pending_tool_calls.read().clone()),
error: RwLock::new(self.error.read().clone()),
max_messages: AtomicUsize::new(self.max_messages.load(Ordering::SeqCst)),
}
}
}
/// Plain-data, serde-serializable view of agent state (no locks or atomics).
///
/// Unlike `AgentState`, this type carries `model` and `thinking_level` and has
/// no `tools` field. NOTE(review): the code that builds a snapshot is not in
/// this file — confirm where `model`/`thinking_level` come from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentStateSnapshot {
/// System prompt text.
pub system_prompt: String,
/// Model descriptor (project type `crate::types::Model`).
pub model: Model,
/// Thinking level (project type `crate::thinking::ThinkingLevel`).
pub thinking_level: ThinkingLevel,
/// Message history.
pub messages: Vec<AgentMessage>,
/// Streaming flag.
pub is_streaming: bool,
/// Optional in-flight message.
pub stream_message: Option<AgentMessage>,
/// Pending identifiers. NOTE(review): presumably tool-call IDs — verify.
pub pending_tool_calls: HashSet<String>,
/// Optional error text.
pub error: Option<String>,
/// Message count. NOTE(review): presumably `messages.len()` at snapshot time — confirm.
pub message_count: usize,
/// History cap. NOTE(review): presumably copied from `AgentState::max_messages`,
/// where `0` means uncapped — confirm.
pub max_messages: usize,
}