use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use chrono::Utc;
use crate::domain::Message;
/// Everything tracked for a single agent: identity, last-activity time,
/// conversation history, memories, and free-form extension data.
///
/// The type is cloneable and derives serde `Serialize`/`Deserialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentState {
    /// Identifier of the agent this state belongs to; `StateManager` uses it
    /// as the map key when the state is stored.
    pub id: String,
    /// Unix timestamp (seconds) taken from `Utc::now()` when the state was
    /// created or last touched via `update_last_active`.
    pub last_active: i64,
    /// Messages in insertion order (oldest first); trimmed by
    /// `add_conversation` once it grows past its cap.
    pub conversation_history: Vec<Message>,
    /// Memories in insertion order (oldest first); trimmed by `add_memory`
    /// once it grows past its cap.
    pub memories: Vec<Memory>,
    /// Free-form summary text; starts out empty.
    /// NOTE(review): nothing in this file writes to it — presumably it is
    /// maintained by code elsewhere; confirm.
    pub long_term_summary: String,
    /// JSON values keyed by string; starts out empty and is not interpreted
    /// anywhere in this file.
    pub custom_state: HashMap<String, serde_json::Value>,
}
impl AgentState {
    /// Number of memories above which the oldest entries are evicted.
    const MEMORY_CAP: usize = 100;
    /// How many of the oldest memories are dropped per eviction.
    const MEMORY_EVICT_BATCH: usize = 10;
    /// Number of messages above which the oldest entries are evicted.
    const HISTORY_CAP: usize = 50;
    /// How many of the oldest messages are dropped per eviction.
    const HISTORY_EVICT_BATCH: usize = 5;

    /// Builds an empty state for the agent `id`, stamped with the current
    /// UTC time as its `last_active` value.
    pub fn new(id: String) -> Self {
        let created_at = Utc::now().timestamp();
        Self {
            id,
            last_active: created_at,
            conversation_history: Vec::new(),
            memories: Vec::new(),
            long_term_summary: String::new(),
            custom_state: HashMap::new(),
        }
    }

    /// Appends `memory`; once the list holds more than 100 entries, the 10
    /// oldest (front of the list) are discarded.
    pub fn add_memory(&mut self, memory: Memory) {
        self.memories.push(memory);
        if self.memories.len() <= Self::MEMORY_CAP {
            return;
        }
        // Evict in batches from the front so trimming does not run on
        // every single insertion.
        self.memories.drain(..Self::MEMORY_EVICT_BATCH);
    }

    /// Appends `message`; once the history holds more than 50 entries, the 5
    /// oldest (front of the list) are discarded.
    pub fn add_conversation(&mut self, message: Message) {
        self.conversation_history.push(message);
        if self.conversation_history.len() <= Self::HISTORY_CAP {
            return;
        }
        self.conversation_history
            .drain(..Self::HISTORY_EVICT_BATCH);
    }

    /// Sets `last_active` to the current UTC Unix timestamp (seconds).
    pub fn update_last_active(&mut self) {
        let now = Utc::now().timestamp();
        self.last_active = now;
    }
}
/// A single remembered item attached to an agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Memory {
    /// Unique identifier; `Memory::new` fills it with a random v4 UUID string.
    pub id: String,
    /// The remembered text itself.
    pub content: String,
    /// Ranking weight: `StateManager::get_memories` returns higher values
    /// first.
    pub importance: u8,
    /// Unix timestamp (seconds) captured when the memory was constructed by
    /// `Memory::new`; used as the tie-breaker (newest first) when ranking.
    pub timestamp: i64,
    /// Caller-supplied context string; stored as-is and not interpreted in
    /// this file.
    pub context: String,
    /// Category tag for this memory.
    pub memory_type: MemoryType,
}
/// Category tag stored on each [`Memory`].
///
/// NOTE(review): this file only stores the tag and never branches on it; the
/// precise meaning of each variant is defined by callers — confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MemoryType {
    Fact,
    Event,
    Relationship,
    Task,
    /// Catch-all for memories that fit none of the other categories
    /// (presumably — verify against callers).
    Other,
}
impl Memory {
    /// Creates a memory from the supplied fields, assigning it a freshly
    /// generated v4 UUID as `id` and the current UTC Unix time (seconds) as
    /// `timestamp`.
    pub fn new(content: String, importance: u8, context: String, memory_type: MemoryType) -> Self {
        let id = uuid::Uuid::new_v4().to_string();
        let timestamp = Utc::now().timestamp();
        Self {
            id,
            content,
            importance,
            timestamp,
            context,
            memory_type,
        }
    }
}
/// In-memory registry of [`AgentState`] values keyed by agent id.
///
/// The map sits behind an `Arc<RwLock<..>>` (tokio's async `RwLock`), and all
/// access goes through the async methods on this type.
pub struct StateManager {
    /// Agent id -> state for that agent.
    states: Arc<RwLock<HashMap<String, AgentState>>>,
}
impl StateManager {
    /// Creates a manager that tracks no agents yet.
    pub fn new() -> Self {
        Self {
            states: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns a snapshot (clone) of the state stored for `agent_id`,
    /// inserting a fresh [`AgentState`] first when none exists.
    ///
    /// The returned value is detached from the manager: mutating it has no
    /// effect until it is handed back through [`StateManager::update_state`].
    pub async fn get_or_create_state(&self, agent_id: &str) -> AgentState {
        let mut states = self.states.write().await;
        Self::state_entry(&mut states, agent_id).clone()
    }

    /// Stores `state` under `state.id`, replacing any existing entry wholesale.
    ///
    /// Because the whole value is replaced, writing back a snapshot obtained
    /// earlier discards any changes made to that agent in the meantime.
    /// Prefer [`StateManager::add_memory`] / [`StateManager::add_conversation`]
    /// for incremental changes.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the existing interface.
    pub async fn update_state(&self, state: AgentState) -> Result<()> {
        let mut states = self.states.write().await;
        states.insert(state.id.clone(), state);
        Ok(())
    }

    /// Appends `memory` to the agent's state (creating the state if needed)
    /// and refreshes its `last_active` timestamp.
    ///
    /// The read-modify-write runs under a single write lock. Previously the
    /// state was cloned, modified, and written back in two separate lock
    /// acquisitions, so concurrent calls for the same agent could overwrite
    /// each other's additions.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the existing interface.
    pub async fn add_memory(&self, agent_id: &str, memory: Memory) -> Result<()> {
        let mut states = self.states.write().await;
        let state = Self::state_entry(&mut states, agent_id);
        state.add_memory(memory);
        state.update_last_active();
        Ok(())
    }

    /// Appends `message` to the agent's conversation history (creating the
    /// state if needed) and refreshes its `last_active` timestamp.
    ///
    /// Like [`StateManager::add_memory`], the update is applied in place
    /// under a single write lock so concurrent additions are not lost.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the existing interface.
    pub async fn add_conversation(&self, agent_id: &str, message: Message) -> Result<()> {
        let mut states = self.states.write().await;
        let state = Self::state_entry(&mut states, agent_id);
        state.add_conversation(message);
        state.update_last_active();
        Ok(())
    }

    /// Returns the agent's memories ordered by descending importance, with
    /// ties broken by newest timestamp first, cut to at most `limit` entries
    /// when a limit is given.
    ///
    /// As before, an unknown `agent_id` gets an empty state created for it.
    pub async fn get_memories(&self, agent_id: &str, limit: Option<usize>) -> Vec<Memory> {
        // Copy only the memories (not the whole state) and release the lock
        // before sorting so other tasks are not blocked on the sort.
        let mut states = self.states.write().await;
        let mut memories = Self::state_entry(&mut states, agent_id).memories.clone();
        drop(states);

        // Stable sort: entries equal on both keys keep their insertion order.
        memories.sort_by(|a, b| {
            b.importance
                .cmp(&a.importance)
                .then_with(|| b.timestamp.cmp(&a.timestamp))
        });
        if let Some(limit) = limit {
            memories.truncate(limit);
        }
        memories
    }

    /// Returns the agent's conversation history in chronological order
    /// (oldest first). With `Some(limit)`, only the `limit` most recent
    /// messages are returned; with `None`, the full history.
    ///
    /// The history vector is stored oldest-first, so the limit is applied to
    /// its tail. The previous implementation truncated from the front and
    /// therefore returned the *oldest* messages.
    ///
    /// As before, an unknown `agent_id` gets an empty state created for it.
    pub async fn get_conversation_history(
        &self,
        agent_id: &str,
        limit: Option<usize>,
    ) -> Vec<Message> {
        let mut states = self.states.write().await;
        let history = &Self::state_entry(&mut states, agent_id).conversation_history;
        // `saturating_sub` keeps the start at 0 when `limit` exceeds the length.
        let start = limit.map_or(0, |limit| history.len().saturating_sub(limit));
        history[start..].to_vec()
    }

    /// Returns the ids of all agents that currently have a stored state.
    /// The order is the map's iteration order and therefore unspecified.
    pub async fn get_all_agent_ids(&self) -> Vec<String> {
        let states = self.states.read().await;
        states.keys().cloned().collect()
    }

    /// Looks up the state for `agent_id` in an already-locked map, inserting
    /// a new empty state when absent, and returns a mutable reference to it.
    fn state_entry<'a>(
        states: &'a mut HashMap<String, AgentState>,
        agent_id: &str,
    ) -> &'a mut AgentState {
        states
            .entry(agent_id.to_string())
            .or_insert_with(|| AgentState::new(agent_id.to_string()))
    }
}
impl Default for StateManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::{Message, MessageTarget};

    /// Walks one agent through state creation, a memory insertion, and a
    /// conversation insertion, checking what the manager returns each time.
    #[tokio::test]
    async fn test_state_manager() {
        const AGENT: &str = "test_agent";
        let manager = StateManager::new();

        // A previously unseen agent gets a fresh state carrying its id.
        let initial = manager.get_or_create_state(AGENT).await;
        assert_eq!(initial.id, "test_agent");

        // Store one memory and read it back.
        let fact = Memory::new(
            "This is a test memory".to_string(),
            8,
            "test context".to_string(),
            MemoryType::Fact,
        );
        manager.add_memory(AGENT, fact).await.unwrap();

        let stored_memories = manager.get_memories(AGENT, Some(10)).await;
        assert_eq!(stored_memories.len(), 1);
        assert_eq!(stored_memories[0].content, "This is a test memory");

        // Store one message and read it back.
        let greeting = Message {
            id: "msg1".to_string(),
            from: AGENT.to_string(),
            to: MessageTarget::Direct("other_agent".to_string()),
            content: "Hello, world!".to_string(),
            timestamp: Utc::now().timestamp(),
            reply_to: None,
            mentions: vec![],
        };
        manager.add_conversation(AGENT, greeting).await.unwrap();

        let stored_history = manager.get_conversation_history(AGENT, Some(10)).await;
        assert_eq!(stored_history.len(), 1);
        assert_eq!(stored_history[0].content, "Hello, world!");
    }
}