Skip to main content

rs_agent/memory/
mod.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use uuid::Uuid;
5
6use crate::error::Result;
7
8#[cfg(feature = "fastembed")]
9use fastembed::{InitOptions, TextEmbedding};
10use tokio::sync::OnceCell;
11
12// Memory backend implementations
13#[cfg(feature = "postgres")]
14pub mod postgres;
15
16#[cfg(feature = "qdrant")]
17pub mod qdrant;
18
19#[cfg(feature = "mongodb")]
20pub mod mongodb;
21
22// Re-export backends
23#[cfg(feature = "postgres")]
24pub use postgres::PostgresStore;
25
26#[cfg(feature = "qdrant")]
27pub use qdrant::QdrantStore;
28
29#[cfg(feature = "mongodb")]
30pub use mongodb::MongoStore;
31
32/// Memory record storing a piece of information
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct MemoryRecord {
35    pub id: Uuid,
36    pub session_id: String,
37    pub role: String,
38    pub content: String,
39    pub importance: f32,
40    pub timestamp: DateTime<Utc>,
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub metadata: Option<HashMap<String, String>>,
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub embedding: Option<Vec<f32>>,
45}
46
47/// Memory store trait for different backends
48#[async_trait::async_trait]
49pub trait MemoryStore: Send + Sync {
50    /// Stores a memory record
51    async fn store(&self, record: MemoryRecord) -> Result<()>;
52
53    /// Retrieves memories for a session
54    async fn retrieve(&self, session_id: &str, limit: usize) -> Result<Vec<MemoryRecord>>;
55
56    /// Searches for similar memories using embeddings
57    async fn search(
58        &self,
59        session_id: &str,
60        query_embedding: Vec<f32>,
61        limit: usize,
62    ) -> Result<Vec<MemoryRecord>>;
63
64    /// Embeds text using the store's embedding model
65    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
66
67    /// Flushes all pending writes
68    async fn flush(&self) -> Result<()>;
69}
70
71/// In-memory store implementation
72pub struct InMemoryStore {
73    records: parking_lot::RwLock<Vec<MemoryRecord>>,
74    #[cfg(feature = "fastembed")]
75    embedder: OnceCell<TextEmbedding>,
76}
77
78impl InMemoryStore {
79    pub fn new() -> Self {
80        Self {
81            records: parking_lot::RwLock::new(Vec::new()),
82            #[cfg(feature = "fastembed")]
83            embedder: OnceCell::new(),
84        }
85    }
86}
87
88impl Default for InMemoryStore {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
94#[async_trait::async_trait]
95impl MemoryStore for InMemoryStore {
96    async fn store(&self, record: MemoryRecord) -> Result<()> {
97        let mut records = self.records.write();
98        records.push(record);
99        Ok(())
100    }
101
102    async fn retrieve(&self, session_id: &str, limit: usize) -> Result<Vec<MemoryRecord>> {
103        let records = self.records.read();
104        let filtered: Vec<MemoryRecord> = records
105            .iter()
106            .filter(|r| r.session_id == session_id)
107            .rev()
108            .take(limit)
109            .cloned()
110            .collect();
111        Ok(filtered)
112    }
113
114    async fn search(
115        &self,
116        session_id: &str,
117        query_embedding: Vec<f32>,
118        limit: usize,
119    ) -> Result<Vec<MemoryRecord>> {
120        let records = self.records.read();
121        let mut scored: Vec<(f32, MemoryRecord)> = records
122            .iter()
123            .filter(|r| r.session_id == session_id && r.embedding.is_some())
124            .map(|r| {
125                let embedding = r.embedding.as_ref().unwrap();
126                let similarity = cosine_similarity(&query_embedding, embedding);
127                (similarity, r.clone())
128            })
129            .collect();
130
131        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
132        Ok(scored.into_iter().take(limit).map(|(_, r)| r).collect())
133    }
134
135    async fn flush(&self) -> Result<()> {
136        Ok(())
137    }
138
139    async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
140        #[cfg(feature = "fastembed")]
141        {
142            let embedder = self
143                .embedder
144                .get_or_try_init(|| async {
145                    TextEmbedding::try_new(InitOptions::default())
146                        .map_err(|e| crate::error::AgentError::MemoryError(e.to_string()))
147                })
148                .await?;
149
150            let embeddings = embedder
151                .embed(vec![_text], None)
152                .map_err(|e| crate::error::AgentError::MemoryError(e.to_string()))?;
153
154            Ok(embeddings[0].clone())
155        }
156
157        #[cfg(not(feature = "fastembed"))]
158        Ok(vec![])
159    }
160}
161
/// Cosine of the angle between `a` and `b`.
///
/// Yields `0.0` when the slices have different lengths or when either of
/// them has zero magnitude (which includes two empty slices).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Euclidean norm of a slice.
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}
178
179pub fn mmr_rerank_records(
180    query_embedding: &[f32],
181    candidates: Vec<MemoryRecord>,
182    k: usize,
183    lambda: f32,
184) -> Vec<MemoryRecord> {
185    if candidates.is_empty() {
186        return Vec::new();
187    }
188
189    let k = k.min(candidates.len());
190    let mut selected_indices = Vec::with_capacity(k);
191    let mut remaining_indices: Vec<usize> = (0..candidates.len()).collect();
192
193    // Select first item with highest similarity to query
194    if let Some((idx, _)) = remaining_indices
195        .iter()
196        .enumerate()
197        .filter_map(|(i, &r_idx)| {
198             candidates[r_idx].embedding
199                .as_ref()
200                .map(|emb| (i, cosine_similarity(query_embedding, emb)))
201        })
202        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
203    {
204        let selected_idx = remaining_indices.remove(idx);
205        selected_indices.push(selected_idx);
206    }
207
208    // Iteratively select items that maximize MMR score
209    while selected_indices.len() < k && !remaining_indices.is_empty() {
210        let next_idx = remaining_indices
211            .iter()
212            .enumerate()
213            .filter_map(|(i, &r_idx)| {
214                let emb = candidates[r_idx].embedding.as_ref()?;
215
216                // Relevance: similarity to query
217                let relevance = cosine_similarity(query_embedding, emb);
218
219                // Diversity: max similarity to already selected items
220                let max_sim_selected = selected_indices
221                    .iter()
222                    .filter_map(|&s_idx| candidates[s_idx].embedding.as_ref())
223                    .map(|s_emb| cosine_similarity(emb, s_emb))
224                    .fold(f32::NEG_INFINITY, f32::max);
225
226                // MMR score: λ * relevance - (1-λ) * max_similarity_to_selected
227                let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim_selected;
228
229                Some((i, mmr_score))
230            })
231            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
232            .map(|(i, _)| i);
233
234        if let Some(idx) = next_idx {
235            let selected_idx = remaining_indices.remove(idx);
236            selected_indices.push(selected_idx);
237        } else {
238            break;
239        }
240    }
241
242    selected_indices.into_iter().map(|i| candidates[i].clone()).collect()
243}
244
245/// Maximal Marginal Relevance (MMR) for diverse retrieval
246///
247/// Balances relevance to query with diversity in results.
248/// Lambda controls the trade-off: 1.0 = pure relevance, 0.0 = pure diversity
249pub fn mmr_rerank(
250    query_embedding: &[f32],
251    candidates: Vec<MemoryRecord>,
252    k: usize,
253    lambda: f32,
254) -> Vec<MemoryRecord> {
255    if candidates.is_empty() {
256        return Vec::new();
257    }
258
259    let k = k.min(candidates.len());
260    let mut selected = Vec::with_capacity(k);
261    let mut remaining = candidates;
262
263    // Select first item with highest similarity to query
264    if let Some((idx, _)) = remaining
265        .iter()
266        .enumerate()
267        .filter_map(|(i, r)| {
268            r.embedding
269                .as_ref()
270                .map(|emb| (i, cosine_similarity(query_embedding, emb)))
271        })
272        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
273    {
274        selected.push(remaining.swap_remove(idx));
275    }
276
277    // Iteratively select items that maximize MMR score
278    while selected.len() < k && !remaining.is_empty() {
279        let next_idx = remaining
280            .iter()
281            .enumerate()
282            .filter_map(|(i, r)| {
283                let emb = r.embedding.as_ref()?;
284
285                // Relevance: similarity to query
286                let relevance = cosine_similarity(query_embedding, emb);
287
288                // Diversity: max similarity to already selected items
289                let max_sim_selected = selected
290                    .iter()
291                    .filter_map(|s| s.embedding.as_ref())
292                    .map(|s_emb| cosine_similarity(emb, s_emb))
293                    .fold(f32::NEG_INFINITY, f32::max);
294
295                // MMR score: λ * relevance - (1-λ) * max_similarity_to_selected
296                let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim_selected;
297
298                Some((i, mmr_score))
299            })
300            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
301            .map(|(i, _)| i);
302
303        if let Some(idx) = next_idx {
304            selected.push(remaining.swap_remove(idx));
305        } else {
306            break;
307        }
308    }
309
310    selected
311}
312
313/// Session memory manages short-term and long-term memory for a session
314pub struct SessionMemory {
315    store: Box<dyn MemoryStore>,
316    // Short-term cache of recent messages
317    short_term: parking_lot::RwLock<HashMap<String, Vec<MemoryRecord>>>,
318    context_window: usize,
319}
320
321impl SessionMemory {
322    /// Creates a new session memory with the given store
323    pub fn new(store: Box<dyn MemoryStore>, context_window: usize) -> Self {
324        Self {
325            store,
326            short_term: parking_lot::RwLock::new(HashMap::new()),
327            context_window,
328        }
329    }
330
331    /// Stores a memory record
332    pub async fn store(&self, record: MemoryRecord) -> Result<()> {
333        let session_id = record.session_id.clone();
334
335        // Add to short-term cache
336        {
337            let mut short_term = self.short_term.write();
338            let session_records = short_term.entry(session_id).or_insert_with(Vec::new);
339            session_records.push(record.clone());
340
341            // Trim to context window
342            if session_records.len() > self.context_window {
343                session_records.drain(0..session_records.len() - self.context_window);
344            }
345        }
346
347        // Generate embedding if not present
348        let mut record = record;
349        if record.embedding.is_none() && !record.content.is_empty() {
350            if let Ok(embedding) = self.store.embed(&record.content).await {
351                if !embedding.is_empty() {
352                    record.embedding = Some(embedding);
353                }
354            }
355        }
356
357        // Store in long-term
358        self.store.store(record).await
359    }
360
361    /// Retrieves recent memories from short-term cache
362    pub async fn retrieve_recent(&self, session_id: &str) -> Result<Vec<MemoryRecord>> {
363        let short_term = self.short_term.read();
364        Ok(short_term.get(session_id).cloned().unwrap_or_default())
365    }
366
367    pub async fn search(
368        &self,
369        session_id: &str,
370        query: &str,
371        limit: usize,
372    ) -> Result<Vec<MemoryRecord>> {
373        let query_embedding = self.store.embed(query).await?;
374        if query_embedding.is_empty() {
375             return Ok(Vec::new());
376        }
377        self.store.search(session_id, query_embedding, limit).await
378    }
379
380    /// Embeds text manually (exposed for testing/utils)
381    pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
382        self.store.embed(text).await
383    }
384
385    /// Flushes all pending writes
386    pub async fn flush(&self) -> Result<()> {
387        self.store.flush().await
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `"user"` record for session `"test"` with no metadata and no embedding.
    fn sample_record(content: &str, importance: f32) -> MemoryRecord {
        MemoryRecord {
            id: Uuid::new_v4(),
            session_id: "test".to_string(),
            role: "user".to_string(),
            content: content.to_string(),
            importance,
            timestamp: Utc::now(),
            metadata: None,
            embedding: None,
        }
    }

    /// A stored record can be read back through `retrieve`.
    #[tokio::test]
    async fn test_in_memory_store() {
        let store = InMemoryStore::new();
        store.store(sample_record("Hello", 0.8)).await.unwrap();

        let retrieved = store.retrieve("test", 10).await.unwrap();
        assert_eq!(retrieved.len(), 1);
        assert_eq!(retrieved[0].content, "Hello");
    }

    /// A record stored through `SessionMemory` shows up in the short-term cache.
    #[tokio::test]
    async fn test_session_memory() {
        // `InMemoryStore` only produces embeddings when the `fastembed`
        // feature is enabled; otherwise it just stores records.
        let memory = SessionMemory::new(Box::new(InMemoryStore::new()), 5);

        memory
            .store(sample_record("Test message", 0.9))
            .await
            .unwrap();

        let recent = memory.retrieve_recent("test").await.unwrap();
        assert_eq!(recent.len(), 1);
    }
}