use crate::agent::error::AgentResult;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[async_trait]
pub trait Memory: Send + Sync {
async fn store(&mut self, key: &str, value: MemoryValue) -> AgentResult<()>;
async fn retrieve(&self, key: &str) -> AgentResult<Option<MemoryValue>>;
async fn remove(&mut self, key: &str) -> AgentResult<bool>;
async fn contains(&self, key: &str) -> AgentResult<bool> {
Ok(self.retrieve(key).await?.is_some())
}
async fn search(&self, query: &str, limit: usize) -> AgentResult<Vec<MemoryItem>>;
async fn clear(&mut self) -> AgentResult<()>;
async fn get_history(&self, session_id: &str) -> AgentResult<Vec<Message>>;
async fn add_to_history(&mut self, session_id: &str, message: Message) -> AgentResult<()>;
async fn clear_history(&mut self, session_id: &str) -> AgentResult<()>;
async fn stats(&self) -> AgentResult<MemoryStats> {
Ok(MemoryStats::default())
}
fn memory_type(&self) -> &str {
"memory"
}
}
/// A value that can be held in a [`Memory`] store.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MemoryValue {
    /// Plain text.
    Text(String),
    /// A bare embedding vector.
    Embedding(Vec<f32>),
    /// Arbitrary JSON data.
    Structured(serde_json::Value),
    /// Raw bytes.
    Binary(Vec<u8>),
    /// Text paired with its embedding vector.
    TextWithEmbedding { text: String, embedding: Vec<f32> },
}

impl MemoryValue {
    /// Creates a [`MemoryValue::Text`].
    pub fn text(s: impl Into<String>) -> Self {
        Self::Text(s.into())
    }

    /// Creates a [`MemoryValue::Embedding`].
    pub fn embedding(e: Vec<f32>) -> Self {
        Self::Embedding(e)
    }

    /// Creates a [`MemoryValue::Structured`].
    pub fn structured(v: serde_json::Value) -> Self {
        Self::Structured(v)
    }

    /// Creates a [`MemoryValue::Binary`].
    pub fn binary(b: Vec<u8>) -> Self {
        Self::Binary(b)
    }

    /// Creates a [`MemoryValue::TextWithEmbedding`].
    pub fn text_with_embedding(text: impl Into<String>, embedding: Vec<f32>) -> Self {
        Self::TextWithEmbedding {
            text: text.into(),
            embedding,
        }
    }

    /// Returns the text for `Text` and `TextWithEmbedding`, otherwise `None`.
    pub fn as_text(&self) -> Option<&str> {
        // Exhaustive on purpose: a new variant must be classified here explicitly.
        match self {
            Self::Text(s) | Self::TextWithEmbedding { text: s, .. } => Some(s.as_str()),
            Self::Embedding(_) | Self::Structured(_) | Self::Binary(_) => None,
        }
    }

    /// Returns the vector for `Embedding` and `TextWithEmbedding`, otherwise `None`.
    pub fn as_embedding(&self) -> Option<&[f32]> {
        match self {
            Self::Embedding(e) | Self::TextWithEmbedding { embedding: e, .. } => {
                Some(e.as_slice())
            }
            Self::Text(_) | Self::Structured(_) | Self::Binary(_) => None,
        }
    }

    /// Returns the JSON value for `Structured`, otherwise `None`.
    pub fn as_structured(&self) -> Option<&serde_json::Value> {
        match self {
            Self::Structured(v) => Some(v),
            Self::Text(_)
            | Self::Embedding(_)
            | Self::Binary(_)
            | Self::TextWithEmbedding { .. } => None,
        }
    }

    /// Returns the bytes for `Binary`, otherwise `None`.
    pub fn as_binary(&self) -> Option<&[u8]> {
        match self {
            Self::Binary(b) => Some(b.as_slice()),
            Self::Text(_)
            | Self::Embedding(_)
            | Self::Structured(_)
            | Self::TextWithEmbedding { .. } => None,
        }
    }
}

/// Wraps an owned string as [`MemoryValue::Text`].
impl From<String> for MemoryValue {
    fn from(s: String) -> Self {
        Self::Text(s)
    }
}

/// Copies a string slice into [`MemoryValue::Text`].
impl From<&str> for MemoryValue {
    fn from(s: &str) -> Self {
        Self::Text(s.to_string())
    }
}

/// Wraps a JSON value as [`MemoryValue::Structured`].
impl From<serde_json::Value> for MemoryValue {
    fn from(v: serde_json::Value) -> Self {
        Self::Structured(v)
    }
}
/// A stored value plus retrieval metadata, as returned by [`Memory::search`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryItem {
    /// Key the value is stored under.
    pub key: String,
    /// The stored value.
    pub value: MemoryValue,
    /// Relevance score. [`MemoryItem::new`] sets `1.0`; [`MemoryItem::with_score`]
    /// keeps it within `[0.0, 1.0]`.
    pub score: f32,
    /// Free-form string metadata.
    pub metadata: HashMap<String, String>,
    /// Creation time in milliseconds since the Unix epoch.
    pub created_at: u64,
    /// Last-access time in milliseconds since the Unix epoch; `new` sets it equal to `created_at`.
    pub last_accessed: u64,
}

impl MemoryItem {
    /// Creates an item with score `1.0`, empty metadata, and both timestamps set to now.
    pub fn new(key: impl Into<String>, value: MemoryValue) -> Self {
        let now = Self::now_millis();
        Self {
            key: key.into(),
            value,
            score: 1.0,
            metadata: HashMap::new(),
            created_at: now,
            last_accessed: now,
        }
    }

    /// Sets the score, clamped to `[0.0, 1.0]`. A NaN score is stored as `0.0`.
    pub fn with_score(mut self, score: f32) -> Self {
        // `f32::clamp` returns NaN unchanged, which would break the [0, 1] invariant
        // and poison any score comparison, so NaN is mapped to the lowest score.
        self.score = if score.is_nan() {
            0.0
        } else {
            score.clamp(0.0, 1.0)
        };
        self
    }

    /// Inserts (or overwrites) a metadata entry.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Returns the current wall-clock time in milliseconds since the Unix epoch.
    ///
    /// Yields `0` if the clock reads before the epoch and saturates at `u64::MAX`
    /// instead of silently truncating the `u128` millisecond count.
    fn now_millis() -> u64 {
        let elapsed = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default();
        u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
    }
}
/// One entry in a session's conversation history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who produced the message.
    pub role: MessageRole,
    /// Message body.
    pub content: String,
    /// Creation time in milliseconds since the Unix epoch, set by [`Message::new`].
    pub timestamp: u64,
    /// Extra JSON data. [`Message::tool`] records the tool name under `"tool_name"`.
    pub metadata: HashMap<String, serde_json::Value>,
}

impl Message {
    /// Builds a message with the given role and content, stamped with the current time.
    ///
    /// A clock reading earlier than the Unix epoch produces a timestamp of `0`.
    pub fn new(role: MessageRole, content: impl Into<String>) -> Self {
        // `SystemTime::elapsed` on the epoch is `now().duration_since(UNIX_EPOCH)`.
        let since_epoch = std::time::UNIX_EPOCH.elapsed().unwrap_or_default();
        let metadata = HashMap::new();
        Self {
            role,
            content: content.into(),
            timestamp: since_epoch.as_millis() as u64,
            metadata,
        }
    }

    /// Shorthand for a [`MessageRole::System`] message.
    pub fn system(content: impl Into<String>) -> Self {
        Self::new(MessageRole::System, content)
    }

    /// Shorthand for a [`MessageRole::User`] message.
    pub fn user(content: impl Into<String>) -> Self {
        Self::new(MessageRole::User, content)
    }

    /// Shorthand for a [`MessageRole::Assistant`] message.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::new(MessageRole::Assistant, content)
    }

    /// Builds a [`MessageRole::Tool`] message and stores `tool_name` as a JSON string
    /// under the `"tool_name"` metadata key.
    pub fn tool(tool_name: impl Into<String>, content: impl Into<String>) -> Self {
        Self::new(MessageRole::Tool, content)
            .with_metadata("tool_name", serde_json::Value::String(tool_name.into()))
    }

    /// Inserts (or overwrites) a metadata entry.
    pub fn with_metadata(self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let mut msg = self;
        msg.metadata.insert(key.into(), value);
        msg
    }
}
/// The author of a [`Message`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MessageRole {
    System,
    User,
    Assistant,
    Tool,
}

impl MessageRole {
    /// Lowercase name of the role; this is the text produced by its `Display` impl.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::System => "system",
            Self::User => "user",
            Self::Assistant => "assistant",
            Self::Tool => "tool",
        }
    }
}

impl std::fmt::Display for MessageRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Formatter::pad` honours width/fill/alignment/precision (e.g. `{:>10}`),
        // which `write!(f, "...")` silently ignores.
        f.pad(self.as_str())
    }
}
/// Usage counters returned by [`Memory::stats`]. `Default` gives all zeros, which is
/// what the trait's default `stats` implementation reports.
///
/// Field order is kept as-is because it determines serialization order.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct MemoryStats {
    /// Count of stored items, as reported by the implementation.
    pub total_items: usize,
    /// Count of sessions with history, as reported by the implementation.
    pub total_sessions: usize,
    /// Count of history messages, as reported by the implementation.
    pub total_messages: usize,
    /// Memory footprint in bytes; presumably an implementation estimate (TODO confirm).
    pub memory_bytes: usize,
}