use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
use crate::error::Result;
#[cfg(feature = "fastembed")]
use fastembed::{InitOptions, TextEmbedding};
use tokio::sync::OnceCell;
#[cfg(feature = "postgres")]
pub mod postgres;
#[cfg(feature = "qdrant")]
pub mod qdrant;
#[cfg(feature = "mongodb")]
pub mod mongodb;
#[cfg(feature = "postgres")]
pub use postgres::PostgresStore;
#[cfg(feature = "qdrant")]
pub use qdrant::QdrantStore;
#[cfg(feature = "mongodb")]
pub use mongodb::MongoStore;
/// A single conversation memory entry handled by the stores in this module.
///
/// The type derives serde `Serialize`/`Deserialize`; the two optional fields
/// are left out of serialized output when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryRecord {
/// Identifier of this record (a UUID).
pub id: Uuid,
/// Session key; `InMemoryStore` filters on it by exact string equality.
pub session_id: String,
/// Author of the content (the tests in this file use `"user"`).
pub role: String,
/// Text payload; `SessionMemory::store` embeds it when no embedding is set.
pub content: String,
/// Importance weight.
/// NOTE(review): not read anywhere in this file; valid range and consumers unconfirmed.
pub importance: f32,
/// UTC timestamp attached to the record.
/// NOTE(review): presumably the creation time; confirm against callers.
pub timestamp: DateTime<Utc>,
/// Optional string key/value annotations; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, String>>,
/// Optional embedding vector used for cosine-similarity ranking; omitted from
/// serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub embedding: Option<Vec<f32>>,
}
/// Asynchronous storage backend for [`MemoryRecord`]s.
///
/// Implementors must be `Send + Sync`. The notes on each method describe what
/// `InMemoryStore` in this file does; other backends are not visible here and
/// may differ.
#[async_trait::async_trait]
pub trait MemoryStore: Send + Sync {
/// Stores one record.
async fn store(&self, record: MemoryRecord) -> Result<()>;
/// Returns at most `limit` records belonging to `session_id`.
/// `InMemoryStore` returns the most recently stored records first.
async fn retrieve(&self, session_id: &str, limit: usize) -> Result<Vec<MemoryRecord>>;
/// Returns at most `limit` records of `session_id` ranked against
/// `query_embedding`. `InMemoryStore` ranks by descending cosine similarity
/// and skips records that have no embedding.
async fn search(
&self,
session_id: &str,
query_embedding: Vec<f32>,
limit: usize,
) -> Result<Vec<MemoryRecord>>;
/// Produces an embedding vector for `text`. `InMemoryStore` returns an empty
/// vector when the `fastembed` feature is disabled.
async fn embed(&self, text: &str) -> Result<Vec<f32>>;
/// Flushes pending state. This is a no-op for `InMemoryStore`.
async fn flush(&self) -> Result<()>;
}
/// In-process [`MemoryStore`] that keeps every record in a `Vec` guarded by a
/// `parking_lot::RwLock`.
pub struct InMemoryStore {
/// All stored records in insertion order (new records are pushed to the end).
records: parking_lot::RwLock<Vec<MemoryRecord>>,
/// Embedding model, initialised on the first `embed` call. The field only
/// exists when the `fastembed` feature is enabled.
#[cfg(feature = "fastembed")]
embedder: OnceCell<TextEmbedding>,
}
impl InMemoryStore {
    /// Creates a store with no records and, when the `fastembed` feature is
    /// enabled, an embedder cell that has not been initialised yet.
    pub fn new() -> Self {
        let records = parking_lot::RwLock::new(Vec::new());
        Self {
            records,
            #[cfg(feature = "fastembed")]
            embedder: OnceCell::new(),
        }
    }
}
impl Default for InMemoryStore {
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl MemoryStore for InMemoryStore {
    /// Appends `record` to the end of the in-process record list.
    async fn store(&self, record: MemoryRecord) -> Result<()> {
        self.records.write().push(record);
        Ok(())
    }

    /// Returns up to `limit` records whose `session_id` matches, most recently
    /// stored first.
    async fn retrieve(&self, session_id: &str, limit: usize) -> Result<Vec<MemoryRecord>> {
        let records = self.records.read();
        let newest_first: Vec<MemoryRecord> = records
            .iter()
            .filter(|r| r.session_id == session_id)
            .rev()
            .take(limit)
            .cloned()
            .collect();
        Ok(newest_first)
    }

    /// Returns up to `limit` records of `session_id` ordered by descending
    /// cosine similarity to `query_embedding`.
    ///
    /// Records without an embedding are skipped. A NaN similarity (for example
    /// from a NaN component in a stored vector) is ranked last instead of
    /// panicking. Records with equal similarity keep their insertion order.
    async fn search(
        &self,
        session_id: &str,
        query_embedding: Vec<f32>,
        limit: usize,
    ) -> Result<Vec<MemoryRecord>> {
        let records = self.records.read();

        // Score borrowed records first; only the survivors are cloned below.
        let mut scored: Vec<(f32, &MemoryRecord)> = records
            .iter()
            .filter(|r| r.session_id == session_id)
            .filter_map(|r| {
                let embedding = r.embedding.as_deref()?;
                let similarity = cosine_similarity(&query_embedding, embedding);
                // NaN has no ordering; map it below every real similarity.
                let similarity = if similarity.is_nan() {
                    f32::NEG_INFINITY
                } else {
                    similarity
                };
                Some((similarity, r))
            })
            .collect();

        // NaN was removed above, so the `Equal` fallback is never taken; it
        // only exists to avoid an `unwrap` in the comparator.
        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

        Ok(scored
            .into_iter()
            .take(limit)
            .map(|(_, r)| r.clone())
            .collect())
    }

    /// No-op: nothing is buffered by this store.
    async fn flush(&self) -> Result<()> {
        Ok(())
    }

    /// Embeds `_text` with the lazily initialised `fastembed` model.
    ///
    /// Without the `fastembed` feature this returns an empty vector.
    ///
    /// # Errors
    /// With `fastembed` enabled, returns `AgentError::MemoryError` when the
    /// model cannot be created, when inference fails, or when the model
    /// returns no vector for the input.
    async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
        #[cfg(feature = "fastembed")]
        {
            // NOTE(review): model creation and inference are plain synchronous
            // calls made from this async fn; confirm that is acceptable for
            // the runtime in use.
            let embedder = self
                .embedder
                .get_or_try_init(|| async {
                    TextEmbedding::try_new(InitOptions::default())
                        .map_err(|e| crate::error::AgentError::MemoryError(e.to_string()))
                })
                .await?;
            let embeddings = embedder
                .embed(vec![_text], None)
                .map_err(|e| crate::error::AgentError::MemoryError(e.to_string()))?;
            // One input was submitted, so take the first vector; report an
            // empty result as an error rather than panicking on `[0]`.
            let first = embeddings.into_iter().next().ok_or_else(|| {
                crate::error::AgentError::MemoryError(
                    "embedding model returned no vectors".to_string(),
                )
            })?;
            Ok(first)
        }
        #[cfg(not(feature = "fastembed"))]
        Ok(vec![])
    }
}
/// Cosine similarity between `a` and `b`.
///
/// Returns `0.0` when the slices differ in length or when either has zero
/// magnitude (which includes two empty slices).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Euclidean norm of a slice.
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}
pub fn mmr_rerank_records(
query_embedding: &[f32],
candidates: Vec<MemoryRecord>,
k: usize,
lambda: f32,
) -> Vec<MemoryRecord> {
if candidates.is_empty() {
return Vec::new();
}
let k = k.min(candidates.len());
let mut selected_indices = Vec::with_capacity(k);
let mut remaining_indices: Vec<usize> = (0..candidates.len()).collect();
if let Some((idx, _)) = remaining_indices
.iter()
.enumerate()
.filter_map(|(i, &r_idx)| {
candidates[r_idx].embedding
.as_ref()
.map(|emb| (i, cosine_similarity(query_embedding, emb)))
})
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
{
let selected_idx = remaining_indices.remove(idx);
selected_indices.push(selected_idx);
}
while selected_indices.len() < k && !remaining_indices.is_empty() {
let next_idx = remaining_indices
.iter()
.enumerate()
.filter_map(|(i, &r_idx)| {
let emb = candidates[r_idx].embedding.as_ref()?;
let relevance = cosine_similarity(query_embedding, emb);
let max_sim_selected = selected_indices
.iter()
.filter_map(|&s_idx| candidates[s_idx].embedding.as_ref())
.map(|s_emb| cosine_similarity(emb, s_emb))
.fold(f32::NEG_INFINITY, f32::max);
let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim_selected;
Some((i, mmr_score))
})
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.map(|(i, _)| i);
if let Some(idx) = next_idx {
let selected_idx = remaining_indices.remove(idx);
selected_indices.push(selected_idx);
} else {
break;
}
}
selected_indices.into_iter().map(|i| candidates[i].clone()).collect()
}
/// Re-ranks `candidates` with Maximal Marginal Relevance, moving the chosen
/// records out of the input, and returns at most `k` of them in selection
/// order.
///
/// The first pick is the candidate most similar to `query_embedding`. Each
/// later pick maximises
/// `lambda * relevance - (1.0 - lambda) * max_similarity_to_already_selected`.
///
/// * Candidates without an embedding are never selected, so fewer than `k`
///   records may be returned.
/// * `k == 0` yields an empty vector.
/// * NaN scores (from NaN embeddings or a NaN `lambda`) rank lowest instead of
///   causing a panic.
/// * Picked records are removed with `swap_remove`, which reorders the
///   remaining pool; when scores tie, the winner is the last tied entry in
///   that reordered pool.
pub fn mmr_rerank(
    query_embedding: &[f32],
    candidates: Vec<MemoryRecord>,
    k: usize,
    lambda: f32,
) -> Vec<MemoryRecord> {
    let k = k.min(candidates.len());
    if k == 0 {
        // Covers both an empty candidate list and an explicit request for zero.
        return Vec::new();
    }

    // NaN has no ordering; map it below every real score so comparisons are total.
    let rankable = |score: f32| if score.is_nan() { f32::NEG_INFINITY } else { score };

    let mut selected: Vec<MemoryRecord> = Vec::with_capacity(k);
    let mut remaining = candidates;

    // Seed with the candidate most similar to the query.
    let seed = remaining
        .iter()
        .enumerate()
        .filter_map(|(pos, record)| {
            let emb = record.embedding.as_deref()?;
            Some((pos, rankable(cosine_similarity(query_embedding, emb))))
        })
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(pos, _)| pos);
    if let Some(pos) = seed {
        selected.push(remaining.swap_remove(pos));
    }

    while selected.len() < k && !remaining.is_empty() {
        let next = remaining
            .iter()
            .enumerate()
            .filter_map(|(pos, record)| {
                let emb = record.embedding.as_deref()?;
                let relevance = cosine_similarity(query_embedding, emb);
                // Redundancy: similarity to the closest already-selected record.
                let max_sim_selected = selected
                    .iter()
                    .filter_map(|chosen| chosen.embedding.as_deref())
                    .map(|chosen_emb| cosine_similarity(emb, chosen_emb))
                    .fold(f32::NEG_INFINITY, f32::max);
                let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim_selected;
                Some((pos, rankable(mmr_score)))
            })
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(pos, _)| pos);

        match next {
            Some(pos) => selected.push(remaining.swap_remove(pos)),
            // Nothing left that carries an embedding.
            None => break,
        }
    }

    selected
}
/// Session-scoped memory: a bounded in-process cache of recent records per
/// session, layered over a [`MemoryStore`] backend.
pub struct SessionMemory {
/// Backend that receives every stored record and serves `embed`, `search` and `flush`.
store: Box<dyn MemoryStore>,
/// Most recent records per session id, oldest first, capped at `context_window` entries.
short_term: parking_lot::RwLock<HashMap<String, Vec<MemoryRecord>>>,
/// Maximum number of records kept per session in `short_term`.
context_window: usize,
}
impl SessionMemory {
    /// Wraps `store` and keeps at most `context_window` recent records per
    /// session in the short-term cache.
    pub fn new(store: Box<dyn MemoryStore>, context_window: usize) -> Self {
        let short_term = parking_lot::RwLock::new(HashMap::new());
        Self {
            store,
            short_term,
            context_window,
        }
    }

    /// Caches `record` for its session, then forwards it to the backend.
    ///
    /// The cached copy is taken before any embedding is attached. If the
    /// record has no embedding and non-empty content, the backend is asked to
    /// embed it; an embedding error or an empty vector is ignored and the
    /// record is forwarded without an embedding.
    ///
    /// # Errors
    /// Returns whatever the backend's `store` call returns.
    pub async fn store(&self, record: MemoryRecord) -> Result<()> {
        {
            // Scope the write guard so it is released before any `.await`.
            let mut cache = self.short_term.write();
            let window = cache.entry(record.session_id.clone()).or_default();
            window.push(record.clone());

            // Drop the oldest entries once the window is over capacity.
            let overflow = window.len().saturating_sub(self.context_window);
            if overflow > 0 {
                window.drain(..overflow);
            }
        }

        let mut outgoing = record;
        let needs_embedding = outgoing.embedding.is_none() && !outgoing.content.is_empty();
        if needs_embedding {
            match self.store.embed(&outgoing.content).await {
                Ok(vector) if !vector.is_empty() => outgoing.embedding = Some(vector),
                // Best effort: errors and empty vectors leave the record as is.
                _ => {}
            }
        }

        self.store.store(outgoing).await
    }

    /// Returns a copy of the cached records for `session_id`, oldest first,
    /// or an empty vector when the session has no cached records.
    pub async fn retrieve_recent(&self, session_id: &str) -> Result<Vec<MemoryRecord>> {
        let cache = self.short_term.read();
        let recent = match cache.get(session_id) {
            Some(window) => window.clone(),
            None => Vec::new(),
        };
        Ok(recent)
    }

    /// Embeds `query` through the backend and searches `session_id` with the
    /// resulting vector.
    ///
    /// Returns an empty vector without searching when the backend produces an
    /// empty embedding.
    ///
    /// # Errors
    /// Propagates errors from the backend's `embed` and `search`.
    pub async fn search(
        &self,
        session_id: &str,
        query: &str,
        limit: usize,
    ) -> Result<Vec<MemoryRecord>> {
        let query_vector = self.store.embed(query).await?;
        if query_vector.is_empty() {
            Ok(Vec::new())
        } else {
            self.store.search(session_id, query_vector, limit).await
        }
    }

    /// Delegates to the backend's `embed`.
    pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        self.store.embed(text).await
    }

    /// Delegates to the backend's `flush`.
    pub async fn flush(&self) -> Result<()> {
        self.store.flush().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a record for session `"test"` with role `"user"`, no metadata
    /// and no embedding.
    fn sample_record(content: &str, importance: f32) -> MemoryRecord {
        MemoryRecord {
            id: Uuid::new_v4(),
            session_id: "test".to_string(),
            role: "user".to_string(),
            content: content.to_string(),
            importance,
            timestamp: Utc::now(),
            metadata: None,
            embedding: None,
        }
    }

    /// A stored record can be read back through `retrieve`.
    #[tokio::test]
    async fn test_in_memory_store() {
        let store = InMemoryStore::new();
        store.store(sample_record("Hello", 0.8)).await.unwrap();

        let retrieved = store.retrieve("test", 10).await.unwrap();
        assert_eq!(retrieved.len(), 1);
        assert_eq!(retrieved[0].content, "Hello");
    }

    /// A record stored through `SessionMemory` appears in the short-term cache.
    #[tokio::test]
    async fn test_session_memory() {
        let memory = SessionMemory::new(Box::new(InMemoryStore::new()), 5);
        memory
            .store(sample_record("Test message", 0.9))
            .await
            .unwrap();

        let recent = memory.retrieve_recent("test").await.unwrap();
        assert_eq!(recent.len(), 1);
    }
}