1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use uuid::Uuid;
5
6use crate::error::Result;
7
8#[cfg(feature = "fastembed")]
9use fastembed::{InitOptions, TextEmbedding};
10use tokio::sync::OnceCell;
11
12#[cfg(feature = "postgres")]
14pub mod postgres;
15
16#[cfg(feature = "qdrant")]
17pub mod qdrant;
18
19#[cfg(feature = "mongodb")]
20pub mod mongodb;
21
22#[cfg(feature = "postgres")]
24pub use postgres::PostgresStore;
25
26#[cfg(feature = "qdrant")]
27pub use qdrant::QdrantStore;
28
29#[cfg(feature = "mongodb")]
30pub use mongodb::MongoStore;
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct MemoryRecord {
35 pub id: Uuid,
36 pub session_id: String,
37 pub role: String,
38 pub content: String,
39 pub importance: f32,
40 pub timestamp: DateTime<Utc>,
41 #[serde(skip_serializing_if = "Option::is_none")]
42 pub metadata: Option<HashMap<String, String>>,
43 #[serde(skip_serializing_if = "Option::is_none")]
44 pub embedding: Option<Vec<f32>>,
45}
46
47#[async_trait::async_trait]
49pub trait MemoryStore: Send + Sync {
50 async fn store(&self, record: MemoryRecord) -> Result<()>;
52
53 async fn retrieve(&self, session_id: &str, limit: usize) -> Result<Vec<MemoryRecord>>;
55
56 async fn search(
58 &self,
59 session_id: &str,
60 query_embedding: Vec<f32>,
61 limit: usize,
62 ) -> Result<Vec<MemoryRecord>>;
63
64 async fn embed(&self, text: &str) -> Result<Vec<f32>>;
66
67 async fn flush(&self) -> Result<()>;
69}
70
71pub struct InMemoryStore {
73 records: parking_lot::RwLock<Vec<MemoryRecord>>,
74 #[cfg(feature = "fastembed")]
75 embedder: OnceCell<TextEmbedding>,
76}
77
78impl InMemoryStore {
79 pub fn new() -> Self {
80 Self {
81 records: parking_lot::RwLock::new(Vec::new()),
82 #[cfg(feature = "fastembed")]
83 embedder: OnceCell::new(),
84 }
85 }
86}
87
88impl Default for InMemoryStore {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94#[async_trait::async_trait]
95impl MemoryStore for InMemoryStore {
96 async fn store(&self, record: MemoryRecord) -> Result<()> {
97 let mut records = self.records.write();
98 records.push(record);
99 Ok(())
100 }
101
102 async fn retrieve(&self, session_id: &str, limit: usize) -> Result<Vec<MemoryRecord>> {
103 let records = self.records.read();
104 let filtered: Vec<MemoryRecord> = records
105 .iter()
106 .filter(|r| r.session_id == session_id)
107 .rev()
108 .take(limit)
109 .cloned()
110 .collect();
111 Ok(filtered)
112 }
113
114 async fn search(
115 &self,
116 session_id: &str,
117 query_embedding: Vec<f32>,
118 limit: usize,
119 ) -> Result<Vec<MemoryRecord>> {
120 let records = self.records.read();
121 let mut scored: Vec<(f32, MemoryRecord)> = records
122 .iter()
123 .filter(|r| r.session_id == session_id && r.embedding.is_some())
124 .map(|r| {
125 let embedding = r.embedding.as_ref().unwrap();
126 let similarity = cosine_similarity(&query_embedding, embedding);
127 (similarity, r.clone())
128 })
129 .collect();
130
131 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
132 Ok(scored.into_iter().take(limit).map(|(_, r)| r).collect())
133 }
134
135 async fn flush(&self) -> Result<()> {
136 Ok(())
137 }
138
139 async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
140 #[cfg(feature = "fastembed")]
141 {
142 let embedder = self
143 .embedder
144 .get_or_try_init(|| async {
145 TextEmbedding::try_new(InitOptions::default())
146 .map_err(|e| crate::error::AgentError::MemoryError(e.to_string()))
147 })
148 .await?;
149
150 let embeddings = embedder
151 .embed(vec![_text], None)
152 .map_err(|e| crate::error::AgentError::MemoryError(e.to_string()))?;
153
154 Ok(embeddings[0].clone())
155 }
156
157 #[cfg(not(feature = "fastembed"))]
158 Ok(vec![])
159 }
160}
161
/// Cosine similarity between two vectors.
///
/// Returns `0.0` when the slices differ in length, when either vector has
/// zero magnitude, or when the computation does not produce a finite number
/// (NaN/infinite components, or overflow of the intermediate sums). Callers
/// order results by this value, so it must never be NaN.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let similarity = dot / (mag_a * mag_b);
    // NaN inputs or overflow (inf / inf) would otherwise leak an unorderable value.
    if similarity.is_finite() {
        similarity
    } else {
        0.0
    }
}
178
179pub fn mmr_rerank_records(
180 query_embedding: &[f32],
181 candidates: Vec<MemoryRecord>,
182 k: usize,
183 lambda: f32,
184) -> Vec<MemoryRecord> {
185 if candidates.is_empty() {
186 return Vec::new();
187 }
188
189 let k = k.min(candidates.len());
190 let mut selected_indices = Vec::with_capacity(k);
191 let mut remaining_indices: Vec<usize> = (0..candidates.len()).collect();
192
193 if let Some((idx, _)) = remaining_indices
195 .iter()
196 .enumerate()
197 .filter_map(|(i, &r_idx)| {
198 candidates[r_idx].embedding
199 .as_ref()
200 .map(|emb| (i, cosine_similarity(query_embedding, emb)))
201 })
202 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
203 {
204 let selected_idx = remaining_indices.remove(idx);
205 selected_indices.push(selected_idx);
206 }
207
208 while selected_indices.len() < k && !remaining_indices.is_empty() {
210 let next_idx = remaining_indices
211 .iter()
212 .enumerate()
213 .filter_map(|(i, &r_idx)| {
214 let emb = candidates[r_idx].embedding.as_ref()?;
215
216 let relevance = cosine_similarity(query_embedding, emb);
218
219 let max_sim_selected = selected_indices
221 .iter()
222 .filter_map(|&s_idx| candidates[s_idx].embedding.as_ref())
223 .map(|s_emb| cosine_similarity(emb, s_emb))
224 .fold(f32::NEG_INFINITY, f32::max);
225
226 let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim_selected;
228
229 Some((i, mmr_score))
230 })
231 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
232 .map(|(i, _)| i);
233
234 if let Some(idx) = next_idx {
235 let selected_idx = remaining_indices.remove(idx);
236 selected_indices.push(selected_idx);
237 } else {
238 break;
239 }
240 }
241
242 selected_indices.into_iter().map(|i| candidates[i].clone()).collect()
243}
244
245pub fn mmr_rerank(
250 query_embedding: &[f32],
251 candidates: Vec<MemoryRecord>,
252 k: usize,
253 lambda: f32,
254) -> Vec<MemoryRecord> {
255 if candidates.is_empty() {
256 return Vec::new();
257 }
258
259 let k = k.min(candidates.len());
260 let mut selected = Vec::with_capacity(k);
261 let mut remaining = candidates;
262
263 if let Some((idx, _)) = remaining
265 .iter()
266 .enumerate()
267 .filter_map(|(i, r)| {
268 r.embedding
269 .as_ref()
270 .map(|emb| (i, cosine_similarity(query_embedding, emb)))
271 })
272 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
273 {
274 selected.push(remaining.swap_remove(idx));
275 }
276
277 while selected.len() < k && !remaining.is_empty() {
279 let next_idx = remaining
280 .iter()
281 .enumerate()
282 .filter_map(|(i, r)| {
283 let emb = r.embedding.as_ref()?;
284
285 let relevance = cosine_similarity(query_embedding, emb);
287
288 let max_sim_selected = selected
290 .iter()
291 .filter_map(|s| s.embedding.as_ref())
292 .map(|s_emb| cosine_similarity(emb, s_emb))
293 .fold(f32::NEG_INFINITY, f32::max);
294
295 let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim_selected;
297
298 Some((i, mmr_score))
299 })
300 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
301 .map(|(i, _)| i);
302
303 if let Some(idx) = next_idx {
304 selected.push(remaining.swap_remove(idx));
305 } else {
306 break;
307 }
308 }
309
310 selected
311}
312
313pub struct SessionMemory {
315 store: Box<dyn MemoryStore>,
316 short_term: parking_lot::RwLock<HashMap<String, Vec<MemoryRecord>>>,
318 context_window: usize,
319}
320
321impl SessionMemory {
322 pub fn new(store: Box<dyn MemoryStore>, context_window: usize) -> Self {
324 Self {
325 store,
326 short_term: parking_lot::RwLock::new(HashMap::new()),
327 context_window,
328 }
329 }
330
331 pub async fn store(&self, record: MemoryRecord) -> Result<()> {
333 let session_id = record.session_id.clone();
334
335 {
337 let mut short_term = self.short_term.write();
338 let session_records = short_term.entry(session_id).or_insert_with(Vec::new);
339 session_records.push(record.clone());
340
341 if session_records.len() > self.context_window {
343 session_records.drain(0..session_records.len() - self.context_window);
344 }
345 }
346
347 let mut record = record;
349 if record.embedding.is_none() && !record.content.is_empty() {
350 if let Ok(embedding) = self.store.embed(&record.content).await {
351 if !embedding.is_empty() {
352 record.embedding = Some(embedding);
353 }
354 }
355 }
356
357 self.store.store(record).await
359 }
360
361 pub async fn retrieve_recent(&self, session_id: &str) -> Result<Vec<MemoryRecord>> {
363 let short_term = self.short_term.read();
364 Ok(short_term.get(session_id).cloned().unwrap_or_default())
365 }
366
367 pub async fn search(
368 &self,
369 session_id: &str,
370 query: &str,
371 limit: usize,
372 ) -> Result<Vec<MemoryRecord>> {
373 let query_embedding = self.store.embed(query).await?;
374 if query_embedding.is_empty() {
375 return Ok(Vec::new());
376 }
377 self.store.search(session_id, query_embedding, limit).await
378 }
379
380 pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
382 self.store.embed(text).await
383 }
384
385 pub async fn flush(&self) -> Result<()> {
387 self.store.flush().await
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a record for session `"test"` with role `"user"`, a fresh id,
    /// the current time, and neither metadata nor an embedding.
    fn user_record(content: &str, importance: f32) -> MemoryRecord {
        MemoryRecord {
            id: Uuid::new_v4(),
            session_id: "test".to_string(),
            role: "user".to_string(),
            content: content.to_string(),
            importance,
            timestamp: Utc::now(),
            metadata: None,
            embedding: None,
        }
    }

    /// A stored record can be read back from the in-memory store.
    #[tokio::test]
    async fn test_in_memory_store() {
        let store = InMemoryStore::new();

        store.store(user_record("Hello", 0.8)).await.unwrap();

        let retrieved = store.retrieve("test", 10).await.unwrap();
        assert_eq!(retrieved.len(), 1);
        assert_eq!(retrieved[0].content, "Hello");
    }

    /// A record stored through `SessionMemory` shows up in the recent cache.
    #[tokio::test]
    async fn test_session_memory() {
        let memory = SessionMemory::new(Box::new(InMemoryStore::new()), 5);

        memory.store(user_record("Test message", 0.9)).await.unwrap();

        let recent = memory.retrieve_recent("test").await.unwrap();
        assert_eq!(recent.len(), 1);
    }
}