use crate::types::Message;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Memory owned by an agent, split into three independent key/value or
/// message stores.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct AgentMemory {
    /// Message history in insertion order (appended by `add_message`).
    pub short_term: Vec<Message>,
    /// Named values kept until explicitly removed with `forget`.
    pub long_term: HashMap<String, serde_json::Value>,
    /// Scratch values, dropped wholesale by `clear_working`.
    pub working: HashMap<String, serde_json::Value>,
}
impl AgentMemory {
    /// Returns an empty memory; identical to `AgentMemory::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Appends `message` to the end of the short-term history.
    pub fn add_message(&mut self, message: Message) {
        self.short_term.push(message)
    }

    /// Borrows the short-term history in insertion order.
    pub fn messages(&self) -> &[Message] {
        self.short_term.as_slice()
    }

    /// Drops every short-term message; the long-term and working stores are
    /// left untouched.
    pub fn clear_short_term(&mut self) {
        self.short_term.clear()
    }

    /// Stores `value` under `key` in long-term memory, overwriting any
    /// existing entry for that key.
    pub fn remember(&mut self, key: impl Into<String>, value: serde_json::Value) {
        let owned_key: String = key.into();
        self.long_term.insert(owned_key, value);
    }

    /// Looks up `key` in long-term memory.
    pub fn recall(&self, key: &str) -> Option<&serde_json::Value> {
        self.long_term.get(key)
    }

    /// Removes `key` from long-term memory, returning the value it held (if any).
    pub fn forget(&mut self, key: &str) -> Option<serde_json::Value> {
        self.long_term.remove(key)
    }

    /// Stores `value` under `key` in working memory, overwriting any existing
    /// entry for that key.
    pub fn set_working(&mut self, key: impl Into<String>, value: serde_json::Value) {
        let owned_key: String = key.into();
        self.working.insert(owned_key, value);
    }

    /// Looks up `key` in working memory.
    pub fn get_working(&self, key: &str) -> Option<&serde_json::Value> {
        self.working.get(key)
    }

    /// Removes every entry from working memory.
    pub fn clear_working(&mut self) {
        self.working.clear()
    }

    /// Number of messages currently held in the short-term history.
    pub fn message_count(&self) -> usize {
        self.messages().len()
    }
}
/// Lifecycle state of an agent.
///
/// `Idle` is the default and is the state `AgentContext::reset` returns to.
/// `Completed`, `Error` and `Stopped` are treated as terminal by
/// `AgentContext::can_continue`; the remaining states are not.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
pub enum AgentState {
    /// Not running; the default state.
    #[default]
    Idle,
    /// Displayed as `thinking`.
    Thinking,
    /// Displayed as `executing_tool`.
    ExecutingTool,
    /// Displayed as `waiting_for_human`; not terminal.
    WaitingForHuman,
    /// Terminal: the run finished.
    Completed,
    /// Terminal: the run failed.
    Error,
    /// Terminal: the run was stopped.
    Stopped,
}
/// Renders the state as its lowercase snake_case name (e.g. `executing_tool`).
///
/// Formatting flags are honoured the same way they are for `str`, so
/// `format!("{:>8}", AgentState::Idle)` yields `"    idle"`.
impl std::fmt::Display for AgentState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let s = match self {
            AgentState::Idle => "idle",
            AgentState::Thinking => "thinking",
            AgentState::ExecutingTool => "executing_tool",
            AgentState::WaitingForHuman => "waiting_for_human",
            AgentState::Completed => "completed",
            AgentState::Error => "error",
            AgentState::Stopped => "stopped",
        };
        // `pad` applies the caller's width/alignment/precision flags;
        // `write!(f, "{}", s)` would silently ignore them.
        f.pad(s)
    }
}
/// Per-agent execution state: identity, lifecycle state, memory, step budget
/// and run configuration.
#[derive(Clone, Debug)]
pub struct AgentContext {
    /// Identifier of the agent this context belongs to.
    pub agent_id: String,
    /// Current lifecycle state.
    pub state: AgentState,
    /// Short-term history, long-term values and working scratch space.
    pub memory: AgentMemory,
    /// Number of steps taken so far (advanced by `increment_step`).
    pub current_step: usize,
    /// Step budget; `can_continue` is `false` once `current_step` reaches it.
    pub max_steps: usize,
    /// Optional system prompt, placed first by `get_messages`.
    pub system_prompt: Option<String>,
    /// Arbitrary caller-supplied key/value metadata.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Optional correlation id; NOTE(review): presumably ties related
    /// runs/log lines together — confirm against callers.
    pub correlation_id: Option<String>,
    /// When `true`, `reset` keeps the short-term message history.
    pub preserve_history: bool,
}
impl AgentContext {
    /// Creates a context for `agent_id` with default settings.
    ///
    /// The context starts in [`AgentState::Idle`] with empty memory, a step
    /// counter of 0, a step budget of 10, no system prompt, no metadata, no
    /// correlation id, and `preserve_history` disabled.
    pub fn new(agent_id: impl Into<String>) -> Self {
        Self {
            agent_id: agent_id.into(),
            state: AgentState::Idle,
            memory: AgentMemory::new(),
            current_step: 0,
            max_steps: 10,
            system_prompt: None,
            metadata: HashMap::new(),
            correlation_id: None,
            preserve_history: false,
        }
    }

    /// Sets the step budget checked by [`can_continue`](Self::can_continue).
    #[must_use]
    pub fn with_max_steps(mut self, max_steps: usize) -> Self {
        self.max_steps = max_steps;
        self
    }

    /// Sets the system prompt that [`get_messages`](Self::get_messages) places
    /// first.
    #[must_use]
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Sets the correlation id.
    #[must_use]
    pub fn with_correlation_id(mut self, id: impl Into<String>) -> Self {
        self.correlation_id = Some(id.into());
        self
    }

    /// Inserts `value` under `key` in the metadata map, replacing any previous
    /// value for that key.
    #[must_use]
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.metadata.insert(key.into(), value);
        self
    }

    /// Builder form of [`set_preserve_history`](Self::set_preserve_history).
    #[must_use]
    pub fn with_preserve_history(mut self, preserve: bool) -> Self {
        self.preserve_history = preserve;
        self
    }

    /// Controls whether [`reset`](Self::reset) keeps the short-term message
    /// history.
    pub fn set_preserve_history(&mut self, preserve: bool) {
        self.preserve_history = preserve;
    }

    /// Returns `true` while the step budget is not exhausted and the state is
    /// not terminal (`Completed`, `Error` or `Stopped`).
    pub fn can_continue(&self) -> bool {
        self.current_step < self.max_steps
            && !matches!(
                self.state,
                AgentState::Completed | AgentState::Error | AgentState::Stopped
            )
    }

    /// Advances the step counter by one.
    pub fn increment_step(&mut self) {
        self.current_step += 1;
    }

    /// Builds the full message list: the system prompt (if set) followed by
    /// clones of the short-term history in insertion order.
    pub fn get_messages(&self) -> Vec<Message> {
        let history = self.memory.messages();
        // Allocate once: room for the optional system message plus history.
        let mut messages = Vec::with_capacity(history.len() + 1);
        if let Some(system) = self.system_prompt.as_ref() {
            messages.push(Message::system(system));
        }
        messages.extend_from_slice(history);
        messages
    }

    /// Returns the context to its initial run state.
    ///
    /// Sets the state to [`AgentState::Idle`], zeroes the step counter and
    /// clears working memory. Short-term history is cleared unless
    /// `preserve_history` is set. Long-term memory, metadata, the system
    /// prompt, the correlation id and `max_steps` are left unchanged.
    pub fn reset(&mut self) {
        self.state = AgentState::Idle;
        self.current_step = 0;
        if !self.preserve_history {
            self.memory.clear_short_term();
        }
        self.memory.clear_working();
    }
}
impl Default for AgentContext {
    /// Creates a context with a freshly generated `agent_…` id and otherwise
    /// the same settings as [`AgentContext::new`].
    fn default() -> Self {
        // Delegate to `new` so the default field values are defined in exactly
        // one place and cannot drift apart.
        Self::new(generate_agent_id())
    }
}
/// Generates a process-unique agent id of the form
/// `agent_<millis-hex>_<counter-hex>`.
///
/// `<millis-hex>` is the number of milliseconds since the Unix epoch, or 0 if
/// the system clock reports a time before the epoch. `<counter-hex>` is a
/// process-wide sequence number, zero-padded to at least four hex digits.
/// Because the counter is emitted in full, no two calls in the same process
/// return the same id, even when they fall in the same millisecond or the
/// clock fallback yields a constant timestamp.
fn generate_agent_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let ts = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_millis() as u64)
        .unwrap_or(0);
    // Relaxed suffices: only the atomicity of the increment matters here.
    let count = COUNTER.fetch_add(1, Ordering::Relaxed);
    // Do not mask the counter: truncating it to 16 bits would let ids repeat
    // after 65,536 calls that share a timestamp.
    format!("agent_{:x}_{:04x}", ts, count)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_short_term() {
        let mut memory = AgentMemory::new();
        assert_eq!(memory.message_count(), 0);
        memory.add_message(Message::user("Hello"));
        memory.add_message(Message::assistant("Hi there!"));
        assert_eq!(memory.message_count(), 2);
        assert_eq!(memory.messages()[0].content, Some("Hello".to_string()));
        memory.clear_short_term();
        assert_eq!(memory.message_count(), 0);
    }

    #[test]
    fn test_memory_long_term() {
        let mut memory = AgentMemory::new();
        memory.remember("user_name", serde_json::json!("Alice"));
        assert_eq!(
            memory.recall("user_name"),
            Some(&serde_json::json!("Alice"))
        );
        // `forget` hands back the removed value, then nothing on a second call.
        assert_eq!(memory.forget("user_name"), Some(serde_json::json!("Alice")));
        assert_eq!(memory.recall("user_name"), None);
        assert_eq!(memory.forget("user_name"), None);
    }

    #[test]
    fn test_memory_working() {
        let mut memory = AgentMemory::new();
        assert_eq!(memory.get_working("step"), None);
        memory.set_working("step", serde_json::json!(1));
        memory.set_working("step", serde_json::json!(2));
        assert_eq!(memory.get_working("step"), Some(&serde_json::json!(2)));
        memory.clear_working();
        assert_eq!(memory.get_working("step"), None);
    }

    #[test]
    fn test_context_creation() {
        let ctx = AgentContext::new("test-agent")
            .with_max_steps(5)
            .with_system_prompt("You are helpful.");
        assert_eq!(ctx.agent_id, "test-agent");
        assert_eq!(ctx.max_steps, 5);
        assert_eq!(ctx.system_prompt, Some("You are helpful.".to_string()));
        assert!(ctx.can_continue());
    }

    #[test]
    fn test_context_default() {
        let a = AgentContext::default();
        let b = AgentContext::default();
        assert!(a.agent_id.starts_with("agent_"));
        assert_ne!(a.agent_id, b.agent_id);
        assert_eq!(a.state, AgentState::Idle);
        assert_eq!(a.current_step, 0);
        assert_eq!(a.max_steps, 10);
        assert!(a.system_prompt.is_none());
        assert!(a.correlation_id.is_none());
        assert!(!a.preserve_history);
    }

    #[test]
    fn test_context_can_continue() {
        let mut ctx = AgentContext::new("test").with_max_steps(2);
        assert!(ctx.can_continue());
        ctx.increment_step();
        assert!(ctx.can_continue());
        ctx.increment_step();
        assert!(!ctx.can_continue());
        ctx.current_step = 0;
        ctx.state = AgentState::Completed;
        assert!(!ctx.can_continue());
    }

    #[test]
    fn test_context_terminal_and_non_terminal_states() {
        let mut ctx = AgentContext::new("test");
        for &state in &[AgentState::Completed, AgentState::Error, AgentState::Stopped] {
            ctx.state = state;
            assert!(!ctx.can_continue(), "{} should be terminal", state);
        }
        for &state in &[
            AgentState::Idle,
            AgentState::Thinking,
            AgentState::ExecutingTool,
            AgentState::WaitingForHuman,
        ] {
            ctx.state = state;
            assert!(ctx.can_continue(), "{} should not be terminal", state);
        }
    }

    #[test]
    fn test_context_zero_max_steps() {
        let ctx = AgentContext::new("test").with_max_steps(0);
        assert!(!ctx.can_continue());
    }

    #[test]
    fn test_context_get_messages() {
        let mut ctx = AgentContext::new("test").with_system_prompt("System prompt");
        ctx.memory.add_message(Message::user("Hello"));
        ctx.memory.add_message(Message::assistant("Hi!"));
        let messages = ctx.get_messages();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].role, "system");
        assert_eq!(messages[1].role, "user");
        assert_eq!(messages[2].role, "assistant");
    }

    #[test]
    fn test_context_get_messages_without_system_prompt() {
        let mut ctx = AgentContext::new("test");
        assert!(ctx.get_messages().is_empty());
        ctx.memory.add_message(Message::user("Hello"));
        let messages = ctx.get_messages();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "user");
    }

    #[test]
    fn test_agent_state_display() {
        let cases = [
            (AgentState::Idle, "idle"),
            (AgentState::Thinking, "thinking"),
            (AgentState::ExecutingTool, "executing_tool"),
            (AgentState::WaitingForHuman, "waiting_for_human"),
            (AgentState::Completed, "completed"),
            (AgentState::Error, "error"),
            (AgentState::Stopped, "stopped"),
        ];
        for (state, expected) in cases.iter() {
            assert_eq!(state.to_string(), *expected);
        }
    }

    #[test]
    fn test_context_reset_clears_history_by_default() {
        let mut ctx = AgentContext::new("test");
        ctx.memory.add_message(Message::user("Hello"));
        ctx.memory.add_message(Message::assistant("Hi!"));
        ctx.current_step = 3;
        ctx.state = AgentState::Completed;
        ctx.reset();
        assert_eq!(ctx.memory.message_count(), 0);
        assert_eq!(ctx.current_step, 0);
        assert_eq!(ctx.state, AgentState::Idle);
    }

    #[test]
    fn test_context_reset_preserves_history_when_enabled() {
        let mut ctx = AgentContext::new("test").with_preserve_history(true);
        ctx.memory.add_message(Message::user("Hello"));
        ctx.memory.add_message(Message::assistant("Hi!"));
        ctx.current_step = 3;
        ctx.state = AgentState::Completed;
        ctx.reset();
        assert_eq!(ctx.memory.message_count(), 2);
        assert_eq!(ctx.current_step, 0);
        assert_eq!(ctx.state, AgentState::Idle);
    }

    #[test]
    fn test_context_reset_clears_working_but_keeps_long_term_and_config() {
        let mut ctx = AgentContext::new("test")
            .with_max_steps(7)
            .with_system_prompt("sys");
        ctx.memory.remember("fact", serde_json::json!(42));
        ctx.memory.set_working("scratch", serde_json::json!("tmp"));
        ctx.reset();
        assert_eq!(ctx.memory.get_working("scratch"), None);
        assert_eq!(ctx.memory.recall("fact"), Some(&serde_json::json!(42)));
        assert_eq!(ctx.max_steps, 7);
        assert_eq!(ctx.system_prompt.as_deref(), Some("sys"));
    }

    #[test]
    fn test_set_preserve_history() {
        let mut ctx = AgentContext::new("test");
        assert!(!ctx.preserve_history);
        ctx.set_preserve_history(true);
        assert!(ctx.preserve_history);
        ctx.memory.add_message(Message::user("Hello"));
        ctx.reset();
        assert_eq!(ctx.memory.message_count(), 1);
        ctx.set_preserve_history(false);
        ctx.reset();
        assert_eq!(ctx.memory.message_count(), 0);
    }
}