use std::collections::HashMap;
use serde_json::Value;
use crate::schemas::messages::Message;
/// State carried by an agent: its message list plus auxiliary bookkeeping.
///
/// Every field is public, so callers may read or mutate them directly; the
/// inherent methods on this type are convenience helpers over the same data.
#[derive(Clone, Debug)]
pub struct AgentState {
/// Messages held by this state.
pub messages: Vec<Message>,
/// Arbitrary JSON values keyed by name (see `get_field`, `set_field`, `remove_field`).
pub custom_fields: HashMap<String, Value>,
/// Optional JSON value. NOTE(review): nothing in this file reads or writes it;
/// presumably a structured output produced by the agent — confirm against callers.
pub structured_response: Option<Value>,
/// Name of the currently active agent, or `None` when no agent is set.
pub active_agent: Option<String>,
/// Names of loaded skills. `add_skill` skips names already present, but since
/// the field is public, uniqueness is not enforced by the type itself.
pub loaded_skills: Vec<String>,
/// `(input, selected_agent)` pairs in the order they were recorded by
/// `add_routing_history`.
pub routing_history: Vec<(String, String)>,
}
impl Default for AgentState {
fn default() -> Self {
Self {
messages: Vec::new(),
custom_fields: HashMap::new(),
structured_response: None,
active_agent: None,
loaded_skills: Vec::new(),
routing_history: Vec::new(),
}
}
}
impl AgentState {
pub fn new() -> Self {
Self::default()
}
pub fn with_messages(messages: Vec<Message>) -> Self {
Self {
messages,
custom_fields: HashMap::new(),
structured_response: None,
active_agent: None,
loaded_skills: Vec::new(),
routing_history: Vec::new(),
}
}
pub fn get_field(&self, key: &str) -> Option<&Value> {
self.custom_fields.get(key)
}
pub fn set_field(&mut self, key: String, value: Value) {
self.custom_fields.insert(key, value);
}
pub fn remove_field(&mut self, key: &str) -> Option<Value> {
self.custom_fields.remove(key)
}
pub fn get_active_agent(&self) -> Option<&String> {
self.active_agent.as_ref()
}
pub fn set_active_agent(&mut self, agent_name: String) {
self.active_agent = Some(agent_name);
}
pub fn clear_active_agent(&mut self) {
self.active_agent = None;
}
pub fn has_skill(&self, skill_name: &str) -> bool {
self.loaded_skills.contains(&skill_name.to_string())
}
pub fn add_skill(&mut self, skill_name: String) {
if !self.loaded_skills.contains(&skill_name) {
self.loaded_skills.push(skill_name);
}
}
pub fn remove_skill(&mut self, skill_name: &str) {
self.loaded_skills.retain(|s| s != skill_name);
}
pub fn add_routing_history(&mut self, input: String, selected_agent: String) {
self.routing_history.push((input, selected_agent));
}
}
/// Commands describing a requested change to agent state or its messages.
///
/// NOTE(review): this file only declares the variants; the code that applies
/// them is not visible here, so the per-variant meanings below are inferred
/// from the names and payload types — confirm against the handler.
#[derive(Debug, Clone)]
pub enum Command {
/// Carries a map from field name to JSON value (presumably merged into the
/// state's custom fields — TODO confirm).
UpdateState { fields: HashMap<String, Value> },
/// Carries a list of string ids (presumably ids of the messages to remove —
/// TODO confirm).
RemoveMessages { ids: Vec<String> },
/// No payload; presumably removes all messages — TODO confirm.
ClearMessages,
/// No payload; presumably resets the state — TODO confirm.
ClearState,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Custom fields can be stored, read back, and removed.
    #[test]
    fn test_agent_state() {
        let mut state = AgentState::new();
        state.set_field("test_key".to_string(), serde_json::json!("test_value"));
        assert_eq!(
            state.get_field("test_key"),
            Some(&serde_json::json!("test_value"))
        );

        let removed = state.remove_field("test_key");
        assert_eq!(removed, Some(serde_json::json!("test_value")));
        assert_eq!(state.get_field("test_key"), None);

        // Removing a key that is no longer present reports `None`.
        assert_eq!(state.remove_field("test_key"), None);
    }

    /// `with_messages` keeps the given messages and defaults everything else.
    #[test]
    fn test_agent_state_with_messages() {
        let messages = vec![
            Message::new_human_message("Hello"),
            Message::new_ai_message("Hi there!"),
        ];

        // `messages` is not used again, so it is moved rather than cloned.
        let state = AgentState::with_messages(messages);
        assert_eq!(state.messages.len(), 2);
        assert_eq!(state.custom_fields.len(), 0);
        assert!(state.structured_response.is_none());
        assert!(state.active_agent.is_none());
        assert!(state.loaded_skills.is_empty());
        assert!(state.routing_history.is_empty());
    }

    /// An `UpdateState` command carries the fields it was built with.
    #[test]
    fn test_command_creation() {
        let mut fields = HashMap::new();
        fields.insert("key1".to_string(), serde_json::json!("value1"));

        let command = Command::UpdateState { fields };
        match command {
            Command::UpdateState { fields } => {
                assert_eq!(fields.len(), 1);
                assert_eq!(fields.get("key1"), Some(&serde_json::json!("value1")));
            }
            _ => panic!("Wrong command type"),
        }
    }

    /// The active agent can be set, read, and cleared.
    #[test]
    fn test_active_agent() {
        let mut state = AgentState::new();
        assert_eq!(state.get_active_agent(), None);

        state.set_active_agent("test_agent".to_string());
        assert_eq!(state.get_active_agent(), Some(&"test_agent".to_string()));

        state.clear_active_agent();
        assert_eq!(state.get_active_agent(), None);
    }

    /// Skills are added without duplicates and can be removed individually.
    #[test]
    fn test_skills() {
        let mut state = AgentState::new();
        assert!(!state.has_skill("rust"));

        state.add_skill("rust".to_string());
        assert!(state.has_skill("rust"));

        // Adding the same skill twice must not create a duplicate entry.
        state.add_skill("rust".to_string());
        assert_eq!(state.loaded_skills.len(), 1);

        state.add_skill("python".to_string());
        assert_eq!(state.loaded_skills.len(), 2);

        state.remove_skill("rust");
        assert!(!state.has_skill("rust"));
        assert!(state.has_skill("python"));

        // Removing a skill that is not loaded leaves the list untouched.
        state.remove_skill("rust");
        assert_eq!(state.loaded_skills.len(), 1);
    }

    /// Routing decisions are recorded in insertion order.
    #[test]
    fn test_routing_history() {
        let mut state = AgentState::new();
        assert_eq!(state.routing_history.len(), 0);

        state.add_routing_history("test input".to_string(), "agent1".to_string());
        assert_eq!(state.routing_history.len(), 1);
        assert_eq!(state.routing_history[0].0, "test input");
        assert_eq!(state.routing_history[0].1, "agent1");

        state.add_routing_history("second input".to_string(), "agent2".to_string());
        assert_eq!(state.routing_history.len(), 2);
        assert_eq!(state.routing_history[1].0, "second input");
        assert_eq!(state.routing_history[1].1, "agent2");
    }
}