use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::value::Value;
/// Identifier of a persisted agent; `FileBackend` also uses it as the file stem `<id>.json`.
pub type AgentId = String;
/// Identifier of a conversation; `Conversation::new` assigns a random UUID v4 string.
pub type ConversationId = String;
/// Identifier of a checkpoint; checkpoints created by `PersistenceManager` get a random UUID v4 string.
pub type CheckpointId = String;
/// Persisted snapshot of one agent's identity, configuration and custom state.
///
/// Stored as its own record by each [`PersistenceBackend`] and embedded whole in
/// every [`Checkpoint`]. Field order determines the serialized JSON layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentState {
    /// Key under which this state is stored.
    pub agent_id: AgentId,
    /// Human-readable agent name.
    pub name: String,
    /// Free-form agent kind label (the tests use a model name such as `"gpt-4"`).
    pub agent_type: String,
    /// Optional system prompt; `None` for agents built with `AgentState::new`.
    pub system_prompt: Option<String>,
    /// Capability labels; not interpreted anywhere in this module.
    pub capabilities: Vec<String>,
    /// Configuration values keyed by name.
    pub config: HashMap<String, Value>,
    /// Arbitrary user state, accessed via `get_state` / `set_state`.
    pub custom_state: HashMap<String, Value>,
    /// Associated conversation ids; not kept in sync automatically by this module.
    pub active_conversations: Vec<ConversationId>,
    /// Creation time, in whole seconds since the Unix epoch.
    pub created_at: u64,
    /// Last modification time, in whole seconds since the Unix epoch; refreshed by `touch`.
    pub modified_at: u64,
    /// Version number; starts at 1 and is never incremented in this module.
    pub version: u32,
    /// Free-form string metadata.
    pub metadata: HashMap<String, String>,
}
impl AgentState {
    /// Builds a fresh agent record at version 1 with every collection empty.
    ///
    /// `created_at` and `modified_at` receive the same current timestamp.
    pub fn new(agent_id: AgentId, name: &str, agent_type: &str) -> Self {
        let timestamp = current_timestamp();
        Self {
            agent_id,
            name: String::from(name),
            agent_type: String::from(agent_type),
            system_prompt: None,
            capabilities: Vec::new(),
            config: HashMap::new(),
            custom_state: HashMap::new(),
            active_conversations: Vec::new(),
            created_at: timestamp,
            modified_at: timestamp,
            version: 1,
            metadata: HashMap::new(),
        }
    }

    /// Stamps `modified_at` with the current time (seconds since the Unix epoch).
    pub fn touch(&mut self) {
        let now = current_timestamp();
        self.modified_at = now;
    }

    /// Inserts or overwrites one custom-state entry, then refreshes `modified_at`.
    pub fn set_state(&mut self, key: &str, value: Value) {
        let owned_key = key.to_owned();
        self.custom_state.insert(owned_key, value);
        self.touch();
    }

    /// Looks up a custom-state entry by key.
    pub fn get_state(&self, key: &str) -> Option<&Value> {
        let state = &self.custom_state;
        state.get(key)
    }
}
/// A conversation owned by one agent: an ordered message list plus bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Conversation {
    /// Storage key; `Conversation::new` assigns a random UUID v4 string.
    pub id: ConversationId,
    /// Owning agent; backends filter on this in `list_conversations`.
    pub agent_id: AgentId,
    /// Optional title; `auto_title` fills it from the first user message when unset.
    pub title: Option<String>,
    /// Messages in the order they were appended.
    pub messages: Vec<Message>,
    /// Free-form string metadata.
    pub metadata: HashMap<String, String>,
    /// Creation time, in whole seconds since the Unix epoch.
    pub created_at: u64,
    /// Time of the last `add_message` call (seconds since the Unix epoch); equals `created_at` until then.
    pub last_message_at: u64,
    /// Archive flag; `false` on creation and not otherwise read in this module.
    pub archived: bool,
    /// Free-form tags; not interpreted in this module.
    pub tags: Vec<String>,
}
impl Conversation {
    /// Upper bound, in bytes, on a title produced by [`Conversation::auto_title`].
    const MAX_TITLE_BYTES: usize = 50;
    /// Byte budget for the kept prefix of a truncated title; leaves room for `"..."`.
    const TRUNCATED_PREFIX_BYTES: usize = 47;

    /// Starts an empty, unarchived conversation for `agent_id` with a random UUID v4 id.
    ///
    /// `created_at` and `last_message_at` share the current timestamp.
    pub fn new(agent_id: AgentId) -> Self {
        let now = current_timestamp();
        Self {
            id: Uuid::new_v4().to_string(),
            agent_id,
            title: None,
            messages: Vec::new(),
            metadata: HashMap::new(),
            created_at: now,
            last_message_at: now,
            archived: false,
            tags: Vec::new(),
        }
    }

    /// Appends `message` and sets `last_message_at` to the current time
    /// (not to the message's own `timestamp`).
    pub fn add_message(&mut self, message: Message) {
        self.last_message_at = current_timestamp();
        self.messages.push(message);
    }

    /// Number of messages in the conversation.
    pub fn message_count(&self) -> usize {
        self.messages.len()
    }

    /// Derives a title from the first user message, if no title is set yet.
    ///
    /// Content longer than 50 bytes is cut to at most 47 bytes and suffixed with
    /// `"..."`, so the title never exceeds 50 bytes. The cut point is moved back to
    /// the nearest UTF-8 character boundary; slicing inside a multi-byte character
    /// would panic. Does nothing when the title is already set or there is no user message.
    pub fn auto_title(&mut self) {
        if self.title.is_some() {
            return;
        }
        let first_user_msg = match self.messages.iter().find(|m| m.role == MessageRole::User) {
            Some(msg) => msg,
            None => return,
        };
        let content = &first_user_msg.content;
        let title = if content.len() > Self::MAX_TITLE_BYTES {
            let mut end = Self::TRUNCATED_PREFIX_BYTES;
            // Index 0 is always a boundary, so this loop terminates.
            while !content.is_char_boundary(end) {
                end -= 1;
            }
            format!("{}...", &content[..end])
        } else {
            content.clone()
        };
        self.title = Some(title);
    }
}
/// A single message within a [`Conversation`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Message id; the constructors assign a random UUID v4 string.
    pub id: String,
    /// Author role of the message.
    pub role: MessageRole,
    /// Message text; this is what `search_conversations` matches against.
    pub content: String,
    /// Creation time, in whole seconds since the Unix epoch.
    pub timestamp: u64,
    /// Optional token count (presumably of `content` — confirm); constructors leave it `None`.
    pub tokens: Option<u32>,
    /// Attached items; empty from the constructors.
    pub attachments: Vec<Attachment>,
    /// Tool invocations carried by this message; empty from the constructors.
    pub tool_calls: Vec<ToolCall>,
    /// Optional parent message id (presumably for threading — confirm); `None` from the constructors.
    pub parent_id: Option<String>,
    /// Free-form metadata; `Message::tool_result` stores the call id under `"tool_call_id"`.
    pub metadata: HashMap<String, String>,
}
impl Message {
pub fn user(content: &str) -> Self {
Self {
id: Uuid::new_v4().to_string(),
role: MessageRole::User,
content: content.to_string(),
timestamp: current_timestamp(),
tokens: None,
attachments: Vec::new(),
tool_calls: Vec::new(),
parent_id: None,
metadata: HashMap::new(),
}
}
pub fn assistant(content: &str) -> Self {
Self {
id: Uuid::new_v4().to_string(),
role: MessageRole::Assistant,
content: content.to_string(),
timestamp: current_timestamp(),
tokens: None,
attachments: Vec::new(),
tool_calls: Vec::new(),
parent_id: None,
metadata: HashMap::new(),
}
}
pub fn system(content: &str) -> Self {
Self {
id: Uuid::new_v4().to_string(),
role: MessageRole::System,
content: content.to_string(),
timestamp: current_timestamp(),
tokens: None,
attachments: Vec::new(),
tool_calls: Vec::new(),
parent_id: None,
metadata: HashMap::new(),
}
}
pub fn tool_result(tool_call_id: &str, content: &str) -> Self {
let mut metadata = HashMap::new();
metadata.insert("tool_call_id".to_string(), tool_call_id.to_string());
Self {
id: Uuid::new_v4().to_string(),
role: MessageRole::Tool,
content: content.to_string(),
timestamp: current_timestamp(),
tokens: None,
attachments: Vec::new(),
tool_calls: Vec::new(),
parent_id: None,
metadata,
}
}
}
/// Author role of a [`Message`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MessageRole {
    /// System message, built by `Message::system`.
    System,
    /// User message; `Conversation::auto_title` uses the first one.
    User,
    /// Assistant reply, built by `Message::assistant`.
    Assistant,
    /// Tool output, built by `Message::tool_result`.
    Tool,
}
/// An item attached to a [`Message`]. Not interpreted anywhere in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Attachment {
    /// Attachment id.
    pub id: String,
    /// Broad category of the attachment.
    pub attachment_type: AttachmentType,
    /// Display or file name.
    pub name: String,
    /// MIME type string, e.g. `image/png` (format not validated here).
    pub mime_type: String,
    /// Optional inline content (presumably text or an encoded payload — confirm).
    pub content: Option<String>,
    /// Optional location of the content (presumably used instead of `content` — confirm).
    pub url: Option<String>,
    /// Optional size (presumably in bytes — confirm).
    pub size: Option<u64>,
}
/// Category of an [`Attachment`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AttachmentType {
    /// Image content.
    Image,
    /// Audio content.
    Audio,
    /// Video content.
    Video,
    /// Generic file.
    File,
    /// Source code.
    Code,
}
/// A tool invocation recorded on a [`Message`]. Not executed by this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Call id (presumably matched by a later `Message::tool_result` — confirm).
    pub id: String,
    /// Name of the tool being invoked.
    pub name: String,
    /// Named arguments passed to the tool.
    pub arguments: HashMap<String, Value>,
    /// Tool output, if recorded.
    pub result: Option<String>,
    /// Lifecycle state of the call.
    pub status: ToolCallStatus,
}
/// Lifecycle state of a [`ToolCall`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ToolCallStatus {
    /// Not started yet.
    Pending,
    /// Currently executing.
    Running,
    /// Finished successfully.
    Completed,
    /// Finished with an error.
    Failed,
}
/// Full snapshot of an agent's state and conversations, used for recovery.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Checkpoint id; `PersistenceManager` assigns a random UUID v4 string.
    pub id: CheckpointId,
    /// Agent this checkpoint belongs to.
    pub agent_id: AgentId,
    /// Human-readable label.
    pub name: String,
    /// `None` for checkpoints from `create_checkpoint`; `"Automatic checkpoint"` for those from the auto task.
    pub description: Option<String>,
    /// Agent state captured at checkpoint time.
    pub agent_state: AgentState,
    /// The agent's conversations captured at checkpoint time.
    pub conversations: Vec<Conversation>,
    /// How the checkpoint was triggered.
    pub checkpoint_type: CheckpointType,
    /// Creation time, in whole seconds since the Unix epoch; used to order pruning.
    pub created_at: u64,
    /// Length in bytes of the JSON serialization of `(agent_state, conversations)` that was hashed.
    pub size_bytes: u64,
    /// Lower-hex MD5 digest of that JSON; verified by `restore_checkpoint`.
    pub checksum: String,
}
/// Why a [`Checkpoint`] was created.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CheckpointType {
    /// Created by the periodic auto-checkpoint task.
    Automatic,
    /// Created explicitly by a caller.
    Manual,
    /// Presumably taken before a risky operation — not produced in this module; confirm.
    PreOperation,
    /// Presumably taken during recovery — not produced in this module; confirm.
    Recovery,
}
/// Storage abstraction for agents, conversations and checkpoints.
///
/// Implemented in this module by [`FileBackend`] (JSON files on disk) and
/// [`MemoryBackend`] (in-process maps). In both, `load_*` returns `Ok(None)`
/// for an unknown id, `save_*` overwrites an existing record with the same id,
/// and `delete_*` on an unknown id succeeds without doing anything.
#[async_trait::async_trait]
pub trait PersistenceBackend: Send + Sync {
    /// Prepares the storage (e.g. creates directories); a no-op for `MemoryBackend`.
    async fn initialize(&self) -> Result<()>;
    /// Inserts or overwrites the state stored under `state.agent_id`.
    async fn save_agent_state(&self, state: &AgentState) -> Result<()>;
    /// Loads an agent's state, or `Ok(None)` if none is stored.
    async fn load_agent_state(&self, agent_id: &AgentId) -> Result<Option<AgentState>>;
    /// Removes only the agent's state record; conversations and checkpoints are
    /// left alone (`PersistenceManager::delete_agent` performs the cascade).
    async fn delete_agent_state(&self, agent_id: &AgentId) -> Result<()>;
    /// Ids of all stored agents, in unspecified order.
    async fn list_agents(&self) -> Result<Vec<AgentId>>;
    /// Inserts or overwrites the conversation stored under `conversation.id`.
    async fn save_conversation(&self, conversation: &Conversation) -> Result<()>;
    /// Loads a conversation, or `Ok(None)` if none is stored.
    async fn load_conversation(
        &self,
        conversation_id: &ConversationId,
    ) -> Result<Option<Conversation>>;
    /// Removes a conversation.
    async fn delete_conversation(&self, conversation_id: &ConversationId) -> Result<()>;
    /// Ids of conversations whose `agent_id` equals `agent_id`, in unspecified order.
    async fn list_conversations(&self, agent_id: &AgentId) -> Result<Vec<ConversationId>>;
    /// Inserts or overwrites the checkpoint stored under `checkpoint.id`.
    async fn save_checkpoint(&self, checkpoint: &Checkpoint) -> Result<()>;
    /// Loads a checkpoint, or `Ok(None)` if none is stored.
    async fn load_checkpoint(&self, checkpoint_id: &CheckpointId) -> Result<Option<Checkpoint>>;
    /// Removes a checkpoint.
    async fn delete_checkpoint(&self, checkpoint_id: &CheckpointId) -> Result<()>;
    /// Ids of checkpoints whose `agent_id` equals `agent_id`, in unspecified order.
    async fn list_checkpoints(&self, agent_id: &AgentId) -> Result<Vec<CheckpointId>>;
    /// Up to `limit` ids of the agent's conversations in which some message's
    /// content contains `query`; both implementations lowercase both sides first.
    async fn search_conversations(
        &self,
        agent_id: &AgentId,
        query: &str,
        limit: usize,
    ) -> Result<Vec<ConversationId>>;
    /// Aggregate counts plus a size figure (file sizes for `FileBackend`,
    /// a rough per-item estimate for `MemoryBackend`).
    async fn get_stats(&self) -> Result<StorageStats>;
}
/// Aggregate storage figures returned by [`PersistenceBackend::get_stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageStats {
    /// Number of stored agent records.
    pub agent_count: usize,
    /// Number of conversations counted by the backend.
    pub conversation_count: usize,
    /// Number of checkpoints counted by the backend.
    pub checkpoint_count: usize,
    /// Sum of message counts over the counted conversations.
    pub total_messages: usize,
    /// On-disk bytes for `FileBackend`; a heuristic estimate for `MemoryBackend`.
    pub total_size_bytes: u64,
}
/// [`PersistenceBackend`] that keeps one JSON file per record under `base_path`,
/// split into `agents/`, `conversations/` and `checkpoints/` subdirectories.
pub struct FileBackend {
    base_path: PathBuf,
}
impl FileBackend {
    /// Creates a backend rooted at `base_path`. Only the path is recorded;
    /// no filesystem access happens here.
    pub fn new(base_path: impl AsRef<Path>) -> Self {
        let base_path = PathBuf::from(base_path.as_ref());
        Self { base_path }
    }

    /// `base_path/<name>`.
    fn subdir(&self, name: &str) -> PathBuf {
        self.base_path.join(name)
    }

    /// `<dir>/<id>.json`.
    fn record_file(dir: PathBuf, id: &str) -> PathBuf {
        let mut file = dir;
        file.push(format!("{}.json", id));
        file
    }

    fn agents_path(&self) -> PathBuf {
        self.subdir("agents")
    }

    fn conversations_path(&self) -> PathBuf {
        self.subdir("conversations")
    }

    fn checkpoints_path(&self) -> PathBuf {
        self.subdir("checkpoints")
    }

    fn agent_file(&self, agent_id: &str) -> PathBuf {
        Self::record_file(self.agents_path(), agent_id)
    }

    fn conversation_file(&self, conversation_id: &str) -> PathBuf {
        Self::record_file(self.conversations_path(), conversation_id)
    }

    fn checkpoint_file(&self, checkpoint_id: &str) -> PathBuf {
        Self::record_file(self.checkpoints_path(), checkpoint_id)
    }
}
#[async_trait::async_trait]
impl PersistenceBackend for FileBackend {
async fn initialize(&self) -> Result<()> {
tokio::fs::create_dir_all(self.agents_path()).await?;
tokio::fs::create_dir_all(self.conversations_path()).await?;
tokio::fs::create_dir_all(self.checkpoints_path()).await?;
Ok(())
}
async fn save_agent_state(&self, state: &AgentState) -> Result<()> {
let path = self.agent_file(&state.agent_id);
let json = serde_json::to_string_pretty(state)?;
tokio::fs::write(path, json).await?;
Ok(())
}
async fn load_agent_state(&self, agent_id: &AgentId) -> Result<Option<AgentState>> {
let path = self.agent_file(agent_id);
if path.exists() {
let json = tokio::fs::read_to_string(path).await?;
let state: AgentState = serde_json::from_str(&json)?;
Ok(Some(state))
} else {
Ok(None)
}
}
async fn delete_agent_state(&self, agent_id: &AgentId) -> Result<()> {
let path = self.agent_file(agent_id);
if path.exists() {
tokio::fs::remove_file(path).await?;
}
Ok(())
}
async fn list_agents(&self) -> Result<Vec<AgentId>> {
let mut agents = Vec::new();
let path = self.agents_path();
if path.exists() {
let mut dir = tokio::fs::read_dir(path).await?;
while let Some(entry) = dir.next_entry().await? {
if let Some(name) = entry.file_name().to_str() {
if name.ends_with(".json") {
agents.push(name.trim_end_matches(".json").to_string());
}
}
}
}
Ok(agents)
}
async fn save_conversation(&self, conversation: &Conversation) -> Result<()> {
let path = self.conversation_file(&conversation.id);
let json = serde_json::to_string_pretty(conversation)?;
tokio::fs::write(path, json).await?;
Ok(())
}
async fn load_conversation(
&self,
conversation_id: &ConversationId,
) -> Result<Option<Conversation>> {
let path = self.conversation_file(conversation_id);
if path.exists() {
let json = tokio::fs::read_to_string(path).await?;
let conv: Conversation = serde_json::from_str(&json)?;
Ok(Some(conv))
} else {
Ok(None)
}
}
async fn delete_conversation(&self, conversation_id: &ConversationId) -> Result<()> {
let path = self.conversation_file(conversation_id);
if path.exists() {
tokio::fs::remove_file(path).await?;
}
Ok(())
}
async fn list_conversations(&self, agent_id: &AgentId) -> Result<Vec<ConversationId>> {
let mut conversations = Vec::new();
let path = self.conversations_path();
if path.exists() {
let mut dir = tokio::fs::read_dir(path).await?;
while let Some(entry) = dir.next_entry().await? {
if let Some(name) = entry.file_name().to_str() {
if name.ends_with(".json") {
let conv_id = name.trim_end_matches(".json").to_string();
if let Ok(Some(conv)) = self.load_conversation(&conv_id).await {
if conv.agent_id == *agent_id {
conversations.push(conv_id);
}
}
}
}
}
}
Ok(conversations)
}
async fn save_checkpoint(&self, checkpoint: &Checkpoint) -> Result<()> {
let path = self.checkpoint_file(&checkpoint.id);
let json = serde_json::to_string_pretty(checkpoint)?;
tokio::fs::write(path, json).await?;
Ok(())
}
async fn load_checkpoint(&self, checkpoint_id: &CheckpointId) -> Result<Option<Checkpoint>> {
let path = self.checkpoint_file(checkpoint_id);
if path.exists() {
let json = tokio::fs::read_to_string(path).await?;
let checkpoint: Checkpoint = serde_json::from_str(&json)?;
Ok(Some(checkpoint))
} else {
Ok(None)
}
}
async fn delete_checkpoint(&self, checkpoint_id: &CheckpointId) -> Result<()> {
let path = self.checkpoint_file(checkpoint_id);
if path.exists() {
tokio::fs::remove_file(path).await?;
}
Ok(())
}
async fn list_checkpoints(&self, agent_id: &AgentId) -> Result<Vec<CheckpointId>> {
let mut checkpoints = Vec::new();
let path = self.checkpoints_path();
if path.exists() {
let mut dir = tokio::fs::read_dir(path).await?;
while let Some(entry) = dir.next_entry().await? {
if let Some(name) = entry.file_name().to_str() {
if name.ends_with(".json") {
let cp_id = name.trim_end_matches(".json").to_string();
if let Ok(Some(cp)) = self.load_checkpoint(&cp_id).await {
if cp.agent_id == *agent_id {
checkpoints.push(cp_id);
}
}
}
}
}
}
Ok(checkpoints)
}
async fn search_conversations(
&self,
agent_id: &AgentId,
query: &str,
limit: usize,
) -> Result<Vec<ConversationId>> {
let mut results = Vec::new();
let query_lower = query.to_lowercase();
let conversations = self.list_conversations(agent_id).await?;
for conv_id in conversations {
if results.len() >= limit {
break;
}
if let Ok(Some(conv)) = self.load_conversation(&conv_id).await {
let matches = conv
.messages
.iter()
.any(|m| m.content.to_lowercase().contains(&query_lower));
if matches {
results.push(conv_id);
}
}
}
Ok(results)
}
async fn get_stats(&self) -> Result<StorageStats> {
let agents = self.list_agents().await?;
let mut conversation_count = 0;
let mut checkpoint_count = 0;
let mut total_messages = 0;
let mut total_size = 0u64;
for agent_id in &agents {
let conversations = self.list_conversations(agent_id).await?;
conversation_count += conversations.len();
for conv_id in conversations {
if let Ok(Some(conv)) = self.load_conversation(&conv_id).await {
total_messages += conv.messages.len();
}
let path = self.conversation_file(&conv_id);
if let Ok(meta) = tokio::fs::metadata(&path).await {
total_size += meta.len();
}
}
let checkpoints = self.list_checkpoints(agent_id).await?;
checkpoint_count += checkpoints.len();
for cp_id in checkpoints {
let path = self.checkpoint_file(&cp_id);
if let Ok(meta) = tokio::fs::metadata(&path).await {
total_size += meta.len();
}
}
let agent_path = self.agent_file(agent_id);
if let Ok(meta) = tokio::fs::metadata(&agent_path).await {
total_size += meta.len();
}
}
Ok(StorageStats {
agent_count: agents.len(),
conversation_count,
checkpoint_count,
total_messages,
total_size_bytes: total_size,
})
}
}
/// In-process [`PersistenceBackend`] that keeps every record in shared hash maps.
///
/// Data lives only in memory; nothing is written to disk.
pub struct MemoryBackend {
    agents: Arc<RwLock<HashMap<AgentId, AgentState>>>,
    conversations: Arc<RwLock<HashMap<ConversationId, Conversation>>>,
    checkpoints: Arc<RwLock<HashMap<CheckpointId, Checkpoint>>>,
}
impl MemoryBackend {
    /// Creates a backend whose three stores are all empty.
    pub fn new() -> Self {
        fn empty_store<K, V>() -> Arc<RwLock<HashMap<K, V>>> {
            Arc::new(RwLock::new(HashMap::new()))
        }
        Self {
            agents: empty_store(),
            conversations: empty_store(),
            checkpoints: empty_store(),
        }
    }
}
impl Default for MemoryBackend {
    /// Same as [`MemoryBackend::new`].
    fn default() -> Self {
        MemoryBackend::new()
    }
}
#[async_trait::async_trait]
impl PersistenceBackend for MemoryBackend {
    /// Nothing to prepare: the maps already exist.
    async fn initialize(&self) -> Result<()> {
        Ok(())
    }

    async fn save_agent_state(&self, state: &AgentState) -> Result<()> {
        let mut agents = self.agents.write().await;
        agents.insert(state.agent_id.clone(), state.clone());
        Ok(())
    }

    async fn load_agent_state(&self, agent_id: &AgentId) -> Result<Option<AgentState>> {
        let agents = self.agents.read().await;
        let found = agents.get(agent_id).cloned();
        Ok(found)
    }

    async fn delete_agent_state(&self, agent_id: &AgentId) -> Result<()> {
        let mut agents = self.agents.write().await;
        agents.remove(agent_id);
        Ok(())
    }

    async fn list_agents(&self) -> Result<Vec<AgentId>> {
        let agents = self.agents.read().await;
        let ids: Vec<AgentId> = agents.keys().cloned().collect();
        Ok(ids)
    }

    async fn save_conversation(&self, conversation: &Conversation) -> Result<()> {
        let mut conversations = self.conversations.write().await;
        conversations.insert(conversation.id.clone(), conversation.clone());
        Ok(())
    }

    async fn load_conversation(
        &self,
        conversation_id: &ConversationId,
    ) -> Result<Option<Conversation>> {
        let conversations = self.conversations.read().await;
        let found = conversations.get(conversation_id).cloned();
        Ok(found)
    }

    async fn delete_conversation(&self, conversation_id: &ConversationId) -> Result<()> {
        let mut conversations = self.conversations.write().await;
        conversations.remove(conversation_id);
        Ok(())
    }

    async fn list_conversations(&self, agent_id: &AgentId) -> Result<Vec<ConversationId>> {
        let conversations = self.conversations.read().await;
        let mut ids = Vec::new();
        for conv in conversations.values() {
            if conv.agent_id == *agent_id {
                ids.push(conv.id.clone());
            }
        }
        Ok(ids)
    }

    async fn save_checkpoint(&self, checkpoint: &Checkpoint) -> Result<()> {
        let mut checkpoints = self.checkpoints.write().await;
        checkpoints.insert(checkpoint.id.clone(), checkpoint.clone());
        Ok(())
    }

    async fn load_checkpoint(&self, checkpoint_id: &CheckpointId) -> Result<Option<Checkpoint>> {
        let checkpoints = self.checkpoints.read().await;
        let found = checkpoints.get(checkpoint_id).cloned();
        Ok(found)
    }

    async fn delete_checkpoint(&self, checkpoint_id: &CheckpointId) -> Result<()> {
        let mut checkpoints = self.checkpoints.write().await;
        checkpoints.remove(checkpoint_id);
        Ok(())
    }

    async fn list_checkpoints(&self, agent_id: &AgentId) -> Result<Vec<CheckpointId>> {
        let checkpoints = self.checkpoints.read().await;
        let mut ids = Vec::new();
        for cp in checkpoints.values() {
            if cp.agent_id == *agent_id {
                ids.push(cp.id.clone());
            }
        }
        Ok(ids)
    }

    /// Case-insensitive substring search over message content. Results follow the
    /// map's iteration order, so which matches survive the `limit` cut is unspecified.
    async fn search_conversations(
        &self,
        agent_id: &AgentId,
        query: &str,
        limit: usize,
    ) -> Result<Vec<ConversationId>> {
        let needle = query.to_lowercase();
        let conversations = self.conversations.read().await;
        let mut hits = Vec::new();
        for conv in conversations.values() {
            if hits.len() >= limit {
                break;
            }
            if conv.agent_id != *agent_id {
                continue;
            }
            let found = conv
                .messages
                .iter()
                .any(|m| m.content.to_lowercase().contains(&needle));
            if found {
                hits.push(conv.id.clone());
            }
        }
        Ok(hits)
    }

    /// Exact counts with a heuristic size: fixed per-item byte guesses, not a real
    /// serialization.
    async fn get_stats(&self) -> Result<StorageStats> {
        let agents = self.agents.read().await;
        let conversations = self.conversations.read().await;
        let checkpoints = self.checkpoints.read().await;
        let agent_count = agents.len();
        let conversation_count = conversations.len();
        let checkpoint_count = checkpoints.len();
        let total_messages = conversations
            .values()
            .fold(0usize, |acc, conv| acc + conv.messages.len());
        let estimated = agent_count * 1000
            + conversation_count * 500
            + total_messages * 200
            + checkpoint_count * 5000;
        Ok(StorageStats {
            agent_count,
            conversation_count,
            checkpoint_count,
            total_messages,
            total_size_bytes: estimated as u64,
        })
    }
}
/// Checkpointing settings used by [`PersistenceManager`].
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// When `false`, `start_auto_checkpoint` returns without spawning a task.
    pub enabled: bool,
    /// Seconds between automatic checkpoint rounds.
    pub interval_secs: u64,
    /// How many of an agent's most recent checkpoints survive pruning.
    pub max_checkpoints: usize,
    /// Not read anywhere in this module; presumably meant to trigger
    /// `PreOperation` checkpoints — confirm.
    pub checkpoint_before_operations: bool,
}
impl Default for CheckpointConfig {
    /// Enabled, every five minutes, keeping 10 checkpoints, with pre-operation checkpoints on.
    fn default() -> Self {
        let five_minutes_secs = 5 * 60;
        CheckpointConfig {
            enabled: true,
            interval_secs: five_minutes_secs,
            max_checkpoints: 10,
            checkpoint_before_operations: true,
        }
    }
}
/// High-level facade over a [`PersistenceBackend`].
///
/// Adds cascading agent deletion, message appends, checkpoint creation, restore
/// and pruning, and an optional background auto-checkpoint task.
pub struct PersistenceManager {
    /// Storage implementation; an `Arc` clone is handed to the auto-checkpoint task.
    backend: Arc<dyn PersistenceBackend>,
    /// Checkpoint settings, read when checkpointing or when the auto task starts.
    checkpoint_config: CheckpointConfig,
    /// Flag polled by the auto-checkpoint task; set on start, cleared by `stop_auto_checkpoint`.
    running: Arc<RwLock<bool>>,
}
impl PersistenceManager {
pub fn new(backend: Arc<dyn PersistenceBackend>) -> Self {
Self {
backend,
checkpoint_config: CheckpointConfig::default(),
running: Arc::new(RwLock::new(false)),
}
}
pub fn with_file_backend(path: impl AsRef<Path>) -> Self {
Self::new(Arc::new(FileBackend::new(path)))
}
pub fn with_memory_backend() -> Self {
Self::new(Arc::new(MemoryBackend::new()))
}
pub fn set_checkpoint_config(&mut self, config: CheckpointConfig) {
self.checkpoint_config = config;
}
pub async fn initialize(&self) -> Result<()> {
self.backend.initialize().await
}
pub async fn save_agent(&self, state: &AgentState) -> Result<()> {
self.backend.save_agent_state(state).await
}
pub async fn load_agent(&self, agent_id: &AgentId) -> Result<Option<AgentState>> {
self.backend.load_agent_state(agent_id).await
}
pub async fn delete_agent(&self, agent_id: &AgentId) -> Result<()> {
let conversations = self.backend.list_conversations(agent_id).await?;
for conv_id in conversations {
self.backend.delete_conversation(&conv_id).await?;
}
let checkpoints = self.backend.list_checkpoints(agent_id).await?;
for cp_id in checkpoints {
self.backend.delete_checkpoint(&cp_id).await?;
}
self.backend.delete_agent_state(agent_id).await
}
pub async fn list_agents(&self) -> Result<Vec<AgentId>> {
self.backend.list_agents().await
}
pub async fn save_conversation(&self, conversation: &Conversation) -> Result<()> {
self.backend.save_conversation(conversation).await
}
pub async fn load_conversation(
&self,
conversation_id: &ConversationId,
) -> Result<Option<Conversation>> {
self.backend.load_conversation(conversation_id).await
}
pub async fn delete_conversation(&self, conversation_id: &ConversationId) -> Result<()> {
self.backend.delete_conversation(conversation_id).await
}
pub async fn list_conversations(&self, agent_id: &AgentId) -> Result<Vec<ConversationId>> {
self.backend.list_conversations(agent_id).await
}
pub async fn search_conversations(
&self,
agent_id: &AgentId,
query: &str,
limit: usize,
) -> Result<Vec<ConversationId>> {
self.backend
.search_conversations(agent_id, query, limit)
.await
}
pub async fn add_message(
&self,
conversation_id: &ConversationId,
message: Message,
) -> Result<()> {
let mut conversation = self
.backend
.load_conversation(conversation_id)
.await?
.ok_or_else(|| anyhow!("Conversation not found"))?;
conversation.add_message(message);
self.backend.save_conversation(&conversation).await
}
pub async fn create_checkpoint(
&self,
agent_id: &AgentId,
name: &str,
checkpoint_type: CheckpointType,
) -> Result<CheckpointId> {
let agent_state = self
.backend
.load_agent_state(agent_id)
.await?
.ok_or_else(|| anyhow!("Agent not found"))?;
let conversation_ids = self.backend.list_conversations(agent_id).await?;
let mut conversations = Vec::new();
for conv_id in conversation_ids {
if let Some(conv) = self.backend.load_conversation(&conv_id).await? {
conversations.push(conv);
}
}
let checkpoint_id = Uuid::new_v4().to_string();
let json = serde_json::to_string(&(&agent_state, &conversations))?;
let checksum = format!("{:x}", md5::compute(&json));
let size_bytes = json.len() as u64;
let checkpoint = Checkpoint {
id: checkpoint_id.clone(),
agent_id: agent_id.clone(),
name: name.to_string(),
description: None,
agent_state,
conversations,
checkpoint_type,
created_at: current_timestamp(),
size_bytes,
checksum,
};
self.backend.save_checkpoint(&checkpoint).await?;
self.cleanup_old_checkpoints(agent_id).await?;
Ok(checkpoint_id)
}
pub async fn restore_checkpoint(&self, checkpoint_id: &CheckpointId) -> Result<AgentId> {
let checkpoint = self
.backend
.load_checkpoint(checkpoint_id)
.await?
.ok_or_else(|| anyhow!("Checkpoint not found"))?;
let json = serde_json::to_string(&(&checkpoint.agent_state, &checkpoint.conversations))?;
let computed_checksum = format!("{:x}", md5::compute(&json));
if computed_checksum != checkpoint.checksum {
return Err(anyhow!("Checkpoint integrity check failed"));
}
self.backend
.save_agent_state(&checkpoint.agent_state)
.await?;
for conversation in &checkpoint.conversations {
self.backend.save_conversation(conversation).await?;
}
Ok(checkpoint.agent_id)
}
pub async fn list_checkpoints(&self, agent_id: &AgentId) -> Result<Vec<CheckpointId>> {
self.backend.list_checkpoints(agent_id).await
}
pub async fn delete_checkpoint(&self, checkpoint_id: &CheckpointId) -> Result<()> {
self.backend.delete_checkpoint(checkpoint_id).await
}
async fn cleanup_old_checkpoints(&self, agent_id: &AgentId) -> Result<()> {
let checkpoint_ids = self.backend.list_checkpoints(agent_id).await?;
if checkpoint_ids.len() <= self.checkpoint_config.max_checkpoints {
return Ok(());
}
let mut checkpoints: Vec<Checkpoint> = Vec::new();
for cp_id in checkpoint_ids {
if let Some(cp) = self.backend.load_checkpoint(&cp_id).await? {
checkpoints.push(cp);
}
}
checkpoints.sort_by(|a, b| b.created_at.cmp(&a.created_at));
for checkpoint in checkpoints
.iter()
.skip(self.checkpoint_config.max_checkpoints)
{
self.backend.delete_checkpoint(&checkpoint.id).await?;
}
Ok(())
}
pub async fn start_auto_checkpoint(&self, agent_ids: Vec<AgentId>) {
if !self.checkpoint_config.enabled {
return;
}
*self.running.write().await = true;
let backend = Arc::clone(&self.backend);
let running = Arc::clone(&self.running);
let interval = Duration::from_secs(self.checkpoint_config.interval_secs);
let _max_checkpoints = self.checkpoint_config.max_checkpoints;
tokio::spawn(async move {
let mut ticker = tokio::time::interval(interval);
while *running.read().await {
ticker.tick().await;
for agent_id in &agent_ids {
if let Ok(Some(agent_state)) = backend.load_agent_state(agent_id).await {
let conversation_ids = backend
.list_conversations(agent_id)
.await
.unwrap_or_default();
let mut conversations = Vec::new();
for conv_id in conversation_ids {
if let Ok(Some(conv)) = backend.load_conversation(&conv_id).await {
conversations.push(conv);
}
}
let checkpoint_id = Uuid::new_v4().to_string();
let json = serde_json::to_string(&(&agent_state, &conversations))
.unwrap_or_default();
let checksum = format!("{:x}", md5::compute(&json));
let size_bytes = json.len() as u64;
let checkpoint = Checkpoint {
id: checkpoint_id,
agent_id: agent_id.clone(),
name: format!(
"Auto checkpoint {}",
chrono::Utc::now().format("%Y-%m-%d %H:%M:%S")
),
description: Some("Automatic checkpoint".to_string()),
agent_state,
conversations,
checkpoint_type: CheckpointType::Automatic,
created_at: current_timestamp(),
size_bytes,
checksum,
};
let _ = backend.save_checkpoint(&checkpoint).await;
}
}
}
});
}
pub async fn stop_auto_checkpoint(&self) {
*self.running.write().await = false;
}
pub async fn get_stats(&self) -> Result<StorageStats> {
self.backend.get_stats().await
}
}
/// Names and one-line descriptions of the persistence builtins, in a fixed order.
pub fn persistence_builtins() -> Vec<(&'static str, &'static str)> {
    const BUILTINS: [(&str, &str); 14] = [
        ("agent_save", "Save agent state to persistent storage"),
        ("agent_load", "Load agent state from persistent storage"),
        ("agent_delete", "Delete agent and all associated data"),
        ("agent_list", "List all persisted agents"),
        ("conversation_save", "Save a conversation"),
        ("conversation_load", "Load a conversation"),
        ("conversation_delete", "Delete a conversation"),
        ("conversation_list", "List conversations for an agent"),
        ("conversation_search", "Search conversations by content"),
        ("checkpoint_create", "Create a checkpoint for recovery"),
        ("checkpoint_restore", "Restore from a checkpoint"),
        ("checkpoint_list", "List checkpoints for an agent"),
        ("checkpoint_delete", "Delete a checkpoint"),
        ("persistence_stats", "Get storage statistics"),
    ];
    BUILTINS.to_vec()
}
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns 0 instead of panicking when the system clock reports a time before
/// the epoch (e.g. a misconfigured host). Persistence should not abort over that.
fn current_timestamp() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0)
}
// Unit tests for the data types, the in-memory backend and the manager.
// No test here touches the filesystem: every backend used is a `MemoryBackend`.
#[cfg(test)]
mod tests {
    use super::*;
    // The constructor fills the identity fields and starts at version 1.
    #[test]
    fn test_agent_state_creation() {
        let state = AgentState::new("agent-1".to_string(), "Test Agent", "gpt-4");
        assert_eq!(state.agent_id, "agent-1");
        assert_eq!(state.name, "Test Agent");
        assert_eq!(state.agent_type, "gpt-4");
        assert_eq!(state.version, 1);
    }
    // set_state/get_state round-trip, and a missing key yields None.
    #[test]
    fn test_agent_state_custom_state() {
        let mut state = AgentState::new("agent-1".to_string(), "Test", "gpt-4");
        state.set_state("counter", Value::Int(42));
        assert_eq!(state.get_state("counter"), Some(&Value::Int(42)));
        assert_eq!(state.get_state("nonexistent"), None);
    }
    // A new conversation gets a non-empty id, no messages and is not archived.
    #[test]
    fn test_conversation_creation() {
        let conv = Conversation::new("agent-1".to_string());
        assert!(!conv.id.is_empty());
        assert_eq!(conv.agent_id, "agent-1");
        assert!(conv.messages.is_empty());
        assert!(!conv.archived);
    }
    #[test]
    fn test_conversation_add_message() {
        let mut conv = Conversation::new("agent-1".to_string());
        conv.add_message(Message::user("Hello"));
        conv.add_message(Message::assistant("Hi there!"));
        assert_eq!(conv.message_count(), 2);
    }
    // A short first user message becomes the title verbatim.
    #[test]
    fn test_conversation_auto_title() {
        let mut conv = Conversation::new("agent-1".to_string());
        conv.add_message(Message::user("What is the weather like today?"));
        conv.auto_title();
        assert_eq!(
            conv.title,
            Some("What is the weather like today?".to_string())
        );
    }
    // A long (ASCII) first user message is truncated to at most 50 bytes ending in "...".
    #[test]
    fn test_conversation_auto_title_truncation() {
        let mut conv = Conversation::new("agent-1".to_string());
        conv.add_message(Message::user(
            "This is a very long message that should be truncated when used as a title",
        ));
        conv.auto_title();
        assert!(conv.title.as_ref().unwrap().len() <= 50);
        assert!(conv.title.as_ref().unwrap().ends_with("..."));
    }
    #[test]
    fn test_message_user() {
        let msg = Message::user("Hello");
        assert_eq!(msg.role, MessageRole::User);
        assert_eq!(msg.content, "Hello");
    }
    #[test]
    fn test_message_assistant() {
        let msg = Message::assistant("Hi there!");
        assert_eq!(msg.role, MessageRole::Assistant);
        assert_eq!(msg.content, "Hi there!");
    }
    #[test]
    fn test_message_system() {
        let msg = Message::system("You are a helpful assistant.");
        assert_eq!(msg.role, MessageRole::System);
    }
    // tool_result records the call id under the "tool_call_id" metadata key.
    #[test]
    fn test_message_tool_result() {
        let msg = Message::tool_result("call-123", "Result: 42");
        assert_eq!(msg.role, MessageRole::Tool);
        assert_eq!(
            msg.metadata.get("tool_call_id"),
            Some(&"call-123".to_string())
        );
    }
    #[test]
    fn test_checkpoint_config_default() {
        let config = CheckpointConfig::default();
        assert!(config.enabled);
        assert_eq!(config.interval_secs, 300);
        assert_eq!(config.max_checkpoints, 10);
    }
    // Save, load, list and delete of an agent state in the memory backend.
    #[tokio::test]
    async fn test_memory_backend_agent_operations() {
        let backend = MemoryBackend::new();
        backend.initialize().await.unwrap();
        let state = AgentState::new("agent-1".to_string(), "Test", "gpt-4");
        backend.save_agent_state(&state).await.unwrap();
        let loaded = backend
            .load_agent_state(&"agent-1".to_string())
            .await
            .unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().name, "Test");
        let agents = backend.list_agents().await.unwrap();
        assert_eq!(agents.len(), 1);
        backend
            .delete_agent_state(&"agent-1".to_string())
            .await
            .unwrap();
        let loaded = backend
            .load_agent_state(&"agent-1".to_string())
            .await
            .unwrap();
        assert!(loaded.is_none());
    }
    // Conversation round-trip plus per-agent listing in the memory backend.
    #[tokio::test]
    async fn test_memory_backend_conversation_operations() {
        let backend = MemoryBackend::new();
        backend.initialize().await.unwrap();
        let mut conv = Conversation::new("agent-1".to_string());
        conv.add_message(Message::user("Hello"));
        let conv_id = conv.id.clone();
        backend.save_conversation(&conv).await.unwrap();
        let loaded = backend.load_conversation(&conv_id).await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().messages.len(), 1);
        let convs = backend
            .list_conversations(&"agent-1".to_string())
            .await
            .unwrap();
        assert_eq!(convs.len(), 1);
    }
    // Search returns only the conversation whose message contains the query.
    #[tokio::test]
    async fn test_memory_backend_search() {
        let backend = MemoryBackend::new();
        backend.initialize().await.unwrap();
        let mut conv1 = Conversation::new("agent-1".to_string());
        conv1.add_message(Message::user("Hello world"));
        backend.save_conversation(&conv1).await.unwrap();
        let mut conv2 = Conversation::new("agent-1".to_string());
        conv2.add_message(Message::user("Goodbye universe"));
        backend.save_conversation(&conv2).await.unwrap();
        let results = backend
            .search_conversations(&"agent-1".to_string(), "world", 10)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0], conv1.id);
    }
    #[tokio::test]
    async fn test_persistence_manager_basic() {
        let manager = PersistenceManager::with_memory_backend();
        manager.initialize().await.unwrap();
        let state = AgentState::new("agent-1".to_string(), "Test", "gpt-4");
        manager.save_agent(&state).await.unwrap();
        let loaded = manager.load_agent(&"agent-1".to_string()).await.unwrap();
        assert!(loaded.is_some());
    }
    // Creating a checkpoint returns a non-empty id and it shows up in the listing.
    #[tokio::test]
    async fn test_persistence_manager_checkpoint() {
        let manager = PersistenceManager::with_memory_backend();
        manager.initialize().await.unwrap();
        let state = AgentState::new("agent-1".to_string(), "Test", "gpt-4");
        manager.save_agent(&state).await.unwrap();
        let checkpoint_id = manager
            .create_checkpoint(
                &"agent-1".to_string(),
                "Test checkpoint",
                CheckpointType::Manual,
            )
            .await
            .unwrap();
        assert!(!checkpoint_id.is_empty());
        let checkpoints = manager
            .list_checkpoints(&"agent-1".to_string())
            .await
            .unwrap();
        assert_eq!(checkpoints.len(), 1);
    }
    #[tokio::test]
    async fn test_persistence_manager_stats() {
        let manager = PersistenceManager::with_memory_backend();
        manager.initialize().await.unwrap();
        let state = AgentState::new("agent-1".to_string(), "Test", "gpt-4");
        manager.save_agent(&state).await.unwrap();
        let stats = manager.get_stats().await.unwrap();
        assert_eq!(stats.agent_count, 1);
    }
    #[test]
    fn test_persistence_builtins() {
        let builtins = persistence_builtins();
        assert!(builtins.len() >= 10);
        assert!(builtins.iter().any(|(name, _)| *name == "agent_save"));
        assert!(builtins
            .iter()
            .any(|(name, _)| *name == "checkpoint_create"));
    }
    // The remaining tests pin the derived PartialEq impls of the enums.
    #[test]
    fn test_message_role_equality() {
        assert_eq!(MessageRole::User, MessageRole::User);
        assert_ne!(MessageRole::User, MessageRole::Assistant);
    }
    #[test]
    fn test_checkpoint_type_equality() {
        assert_eq!(CheckpointType::Manual, CheckpointType::Manual);
        assert_ne!(CheckpointType::Automatic, CheckpointType::Manual);
    }
    #[test]
    fn test_tool_call_status_equality() {
        assert_eq!(ToolCallStatus::Pending, ToolCallStatus::Pending);
        assert_ne!(ToolCallStatus::Running, ToolCallStatus::Completed);
    }
    #[test]
    fn test_attachment_type_equality() {
        assert_eq!(AttachmentType::Image, AttachmentType::Image);
        assert_ne!(AttachmentType::Audio, AttachmentType::Video);
    }
}