Skip to main content

rs_agent/memory/
mod.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use uuid::Uuid;
5
6use crate::error::Result;
7
8// Memory backend implementations
9#[cfg(feature = "postgres")]
10pub mod postgres;
11
12#[cfg(feature = "qdrant")]
13pub mod qdrant;
14
15#[cfg(feature = "mongodb")]
16pub mod mongodb;
17
18// Re-export backends
19#[cfg(feature = "postgres")]
20pub use postgres::PostgresStore;
21
22#[cfg(feature = "qdrant")]
23pub use qdrant::QdrantStore;
24
25#[cfg(feature = "mongodb")]
26pub use mongodb::MongoStore;
27
28/// Memory record storing a piece of information
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct MemoryRecord {
31    pub id: Uuid,
32    pub session_id: String,
33    pub role: String,
34    pub content: String,
35    pub importance: f32,
36    pub timestamp: DateTime<Utc>,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub metadata: Option<HashMap<String, String>>,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub embedding: Option<Vec<f32>>,
41}
42
43/// Memory store trait for different backends
44#[async_trait::async_trait]
45pub trait MemoryStore: Send + Sync {
46    /// Stores a memory record
47    async fn store(&self, record: MemoryRecord) -> Result<()>;
48
49    /// Retrieves memories for a session
50    async fn retrieve(&self, session_id: &str, limit: usize) -> Result<Vec<MemoryRecord>>;
51
52    /// Searches for similar memories using embeddings
53    async fn search(
54        &self,
55        session_id: &str,
56        query_embedding: Vec<f32>,
57        limit: usize,
58    ) -> Result<Vec<MemoryRecord>>;
59
60    /// Flushes all pending writes
61    async fn flush(&self) -> Result<()>;
62}
63
64/// In-memory store implementation
65pub struct InMemoryStore {
66    records: parking_lot::RwLock<Vec<MemoryRecord>>,
67}
68
69impl InMemoryStore {
70    pub fn new() -> Self {
71        Self {
72            records: parking_lot::RwLock::new(Vec::new()),
73        }
74    }
75}
76
77impl Default for InMemoryStore {
78    fn default() -> Self {
79        Self::new()
80    }
81}
82
83#[async_trait::async_trait]
84impl MemoryStore for InMemoryStore {
85    async fn store(&self, record: MemoryRecord) -> Result<()> {
86        let mut records = self.records.write();
87        records.push(record);
88        Ok(())
89    }
90
91    async fn retrieve(&self, session_id: &str, limit: usize) -> Result<Vec<MemoryRecord>> {
92        let records = self.records.read();
93        let filtered: Vec<MemoryRecord> = records
94            .iter()
95            .filter(|r| r.session_id == session_id)
96            .rev()
97            .take(limit)
98            .cloned()
99            .collect();
100        Ok(filtered)
101    }
102
103    async fn search(
104        &self,
105        session_id: &str,
106        query_embedding: Vec<f32>,
107        limit: usize,
108    ) -> Result<Vec<MemoryRecord>> {
109        let records = self.records.read();
110        let mut scored: Vec<(f32, MemoryRecord)> = records
111            .iter()
112            .filter(|r| r.session_id == session_id && r.embedding.is_some())
113            .map(|r| {
114                let embedding = r.embedding.as_ref().unwrap();
115                let similarity = cosine_similarity(&query_embedding, embedding);
116                (similarity, r.clone())
117            })
118            .collect();
119
120        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
121        Ok(scored.into_iter().take(limit).map(|(_, r)| r).collect())
122    }
123
124    async fn flush(&self) -> Result<()> {
125        Ok(())
126    }
127}
128
/// Computes the cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices differ in length or when either one has zero
/// magnitude (which includes the empty slice).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Euclidean length of a vector.
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    if a.len() != b.len() {
        return 0.0;
    }

    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot_product = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>();
    dot_product / (norm_a * norm_b)
}
145
146/// Maximal Marginal Relevance (MMR) for diverse retrieval
147///
148/// Balances relevance to query with diversity in results.
149/// Lambda controls the trade-off: 1.0 = pure relevance, 0.0 = pure diversity
150pub fn mmr_rerank(
151    query_embedding: &[f32],
152    candidates: Vec<MemoryRecord>,
153    k: usize,
154    lambda: f32,
155) -> Vec<MemoryRecord> {
156    if candidates.is_empty() {
157        return Vec::new();
158    }
159
160    let k = k.min(candidates.len());
161    let mut selected = Vec::with_capacity(k);
162    let mut remaining = candidates;
163
164    // Select first item with highest similarity to query
165    if let Some((idx, _)) = remaining
166        .iter()
167        .enumerate()
168        .filter_map(|(i, r)| {
169            r.embedding
170                .as_ref()
171                .map(|emb| (i, cosine_similarity(query_embedding, emb)))
172        })
173        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
174    {
175        selected.push(remaining.swap_remove(idx));
176    }
177
178    // Iteratively select items that maximize MMR score
179    while selected.len() < k && !remaining.is_empty() {
180        let next_idx = remaining
181            .iter()
182            .enumerate()
183            .filter_map(|(i, r)| {
184                let emb = r.embedding.as_ref()?;
185
186                // Relevance: similarity to query
187                let relevance = cosine_similarity(query_embedding, emb);
188
189                // Diversity: max similarity to already selected items
190                let max_sim_selected = selected
191                    .iter()
192                    .filter_map(|s| s.embedding.as_ref())
193                    .map(|s_emb| cosine_similarity(emb, s_emb))
194                    .fold(f32::NEG_INFINITY, f32::max);
195
196                // MMR score: λ * relevance - (1-λ) * max_similarity_to_selected
197                let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim_selected;
198
199                Some((i, mmr_score))
200            })
201            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
202            .map(|(i, _)| i);
203
204        if let Some(idx) = next_idx {
205            selected.push(remaining.swap_remove(idx));
206        } else {
207            break;
208        }
209    }
210
211    selected
212}
213
214/// Session memory manages short-term and long-term memory for a session
215pub struct SessionMemory {
216    store: Box<dyn MemoryStore>,
217    // Short-term cache of recent messages
218    short_term: parking_lot::RwLock<HashMap<String, Vec<MemoryRecord>>>,
219    context_window: usize,
220}
221
222impl SessionMemory {
223    /// Creates a new session memory with the given store
224    pub fn new(store: Box<dyn MemoryStore>, context_window: usize) -> Self {
225        Self {
226            store,
227            short_term: parking_lot::RwLock::new(HashMap::new()),
228            context_window,
229        }
230    }
231
232    /// Stores a memory record
233    pub async fn store(&self, record: MemoryRecord) -> Result<()> {
234        let session_id = record.session_id.clone();
235
236        // Add to short-term cache
237        {
238            let mut short_term = self.short_term.write();
239            let session_records = short_term.entry(session_id).or_insert_with(Vec::new);
240            session_records.push(record.clone());
241
242            // Trim to context window
243            if session_records.len() > self.context_window {
244                session_records.drain(0..session_records.len() - self.context_window);
245            }
246        }
247
248        // Store in long-term
249        self.store.store(record).await
250    }
251
252    /// Retrieves recent memories from short-term cache
253    pub async fn retrieve_recent(&self, session_id: &str) -> Result<Vec<MemoryRecord>> {
254        let short_term = self.short_term.read();
255        Ok(short_term.get(session_id).cloned().unwrap_or_default())
256    }
257
258    /// Searches for relevant memories
259    pub async fn search(
260        &self,
261        session_id: &str,
262        query_embedding: Vec<f32>,
263        limit: usize,
264    ) -> Result<Vec<MemoryRecord>> {
265        self.store.search(session_id, query_embedding, limit).await
266    }
267
268    /// Flushes all pending writes
269    pub async fn flush(&self) -> Result<()> {
270        self.store.flush().await
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a record for session `"test"` written by `"user"`, with no metadata
    /// and no embedding.
    fn sample_record(content: &str, importance: f32) -> MemoryRecord {
        MemoryRecord {
            id: Uuid::new_v4(),
            session_id: "test".to_string(),
            role: "user".to_string(),
            content: content.to_string(),
            importance,
            timestamp: Utc::now(),
            metadata: None,
            embedding: None,
        }
    }

    // A stored record can be read back from the same session.
    #[tokio::test]
    async fn test_in_memory_store() {
        let store = InMemoryStore::new();
        store.store(sample_record("Hello", 0.8)).await.unwrap();

        let retrieved = store.retrieve("test", 10).await.unwrap();
        assert_eq!(retrieved.len(), 1);
        assert_eq!(retrieved[0].content, "Hello");
    }

    // A record stored through SessionMemory shows up in the short-term cache.
    #[tokio::test]
    async fn test_session_memory() {
        let memory = SessionMemory::new(Box::new(InMemoryStore::new()), 5);
        memory
            .store(sample_record("Test message", 0.9))
            .await
            .unwrap();

        let recent = memory.retrieve_recent("test").await.unwrap();
        assert_eq!(recent.len(), 1);
    }
}