1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use uuid::Uuid;
5
6use crate::error::Result;
7
/// Store backend in the `postgres` submodule (presumably PostgreSQL-backed);
/// compiled only when the `postgres` feature is enabled.
#[cfg(feature = "postgres")]
pub mod postgres;

/// Store backend in the `qdrant` submodule (presumably Qdrant-backed);
/// compiled only when the `qdrant` feature is enabled.
#[cfg(feature = "qdrant")]
pub mod qdrant;

/// Store backend in the `mongodb` submodule (presumably MongoDB-backed);
/// compiled only when the `mongodb` feature is enabled.
#[cfg(feature = "mongodb")]
pub mod mongodb;

// Re-export each feature-gated backend type at this module's root so callers
// do not have to name the submodule; each re-export shares its module's feature gate.
#[cfg(feature = "postgres")]
pub use postgres::PostgresStore;

#[cfg(feature = "qdrant")]
pub use qdrant::QdrantStore;

#[cfg(feature = "mongodb")]
pub use mongodb::MongoStore;
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct MemoryRecord {
31 pub id: Uuid,
32 pub session_id: String,
33 pub role: String,
34 pub content: String,
35 pub importance: f32,
36 pub timestamp: DateTime<Utc>,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub metadata: Option<HashMap<String, String>>,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 pub embedding: Option<Vec<f32>>,
41}
42
43#[async_trait::async_trait]
45pub trait MemoryStore: Send + Sync {
46 async fn store(&self, record: MemoryRecord) -> Result<()>;
48
49 async fn retrieve(&self, session_id: &str, limit: usize) -> Result<Vec<MemoryRecord>>;
51
52 async fn search(
54 &self,
55 session_id: &str,
56 query_embedding: Vec<f32>,
57 limit: usize,
58 ) -> Result<Vec<MemoryRecord>>;
59
60 async fn flush(&self) -> Result<()>;
62}
63
64pub struct InMemoryStore {
66 records: parking_lot::RwLock<Vec<MemoryRecord>>,
67}
68
69impl InMemoryStore {
70 pub fn new() -> Self {
71 Self {
72 records: parking_lot::RwLock::new(Vec::new()),
73 }
74 }
75}
76
77impl Default for InMemoryStore {
78 fn default() -> Self {
79 Self::new()
80 }
81}
82
83#[async_trait::async_trait]
84impl MemoryStore for InMemoryStore {
85 async fn store(&self, record: MemoryRecord) -> Result<()> {
86 let mut records = self.records.write();
87 records.push(record);
88 Ok(())
89 }
90
91 async fn retrieve(&self, session_id: &str, limit: usize) -> Result<Vec<MemoryRecord>> {
92 let records = self.records.read();
93 let filtered: Vec<MemoryRecord> = records
94 .iter()
95 .filter(|r| r.session_id == session_id)
96 .rev()
97 .take(limit)
98 .cloned()
99 .collect();
100 Ok(filtered)
101 }
102
103 async fn search(
104 &self,
105 session_id: &str,
106 query_embedding: Vec<f32>,
107 limit: usize,
108 ) -> Result<Vec<MemoryRecord>> {
109 let records = self.records.read();
110 let mut scored: Vec<(f32, MemoryRecord)> = records
111 .iter()
112 .filter(|r| r.session_id == session_id && r.embedding.is_some())
113 .map(|r| {
114 let embedding = r.embedding.as_ref().unwrap();
115 let similarity = cosine_similarity(&query_embedding, embedding);
116 (similarity, r.clone())
117 })
118 .collect();
119
120 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
121 Ok(scored.into_iter().take(limit).map(|(_, r)| r).collect())
122 }
123
124 async fn flush(&self) -> Result<()> {
125 Ok(())
126 }
127}
128
/// Cosine similarity between `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length or either has zero magnitude
/// (which includes two empty slices). NaN components propagate to a NaN result.
///
/// Sums are accumulated in `f64` in a single pass. In `f32`, squaring components
/// larger than about 1.8e19 overflows to infinity and turns the ratio into NaN, and
/// very small components underflow to a spurious zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    let similarity = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    // Rounding can push the ratio marginally outside [-1, 1]; clamping keeps the
    // documented range (NaN passes through `clamp` unchanged).
    similarity.clamp(-1.0, 1.0) as f32
}
145
146pub fn mmr_rerank(
151 query_embedding: &[f32],
152 candidates: Vec<MemoryRecord>,
153 k: usize,
154 lambda: f32,
155) -> Vec<MemoryRecord> {
156 if candidates.is_empty() {
157 return Vec::new();
158 }
159
160 let k = k.min(candidates.len());
161 let mut selected = Vec::with_capacity(k);
162 let mut remaining = candidates;
163
164 if let Some((idx, _)) = remaining
166 .iter()
167 .enumerate()
168 .filter_map(|(i, r)| {
169 r.embedding
170 .as_ref()
171 .map(|emb| (i, cosine_similarity(query_embedding, emb)))
172 })
173 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
174 {
175 selected.push(remaining.swap_remove(idx));
176 }
177
178 while selected.len() < k && !remaining.is_empty() {
180 let next_idx = remaining
181 .iter()
182 .enumerate()
183 .filter_map(|(i, r)| {
184 let emb = r.embedding.as_ref()?;
185
186 let relevance = cosine_similarity(query_embedding, emb);
188
189 let max_sim_selected = selected
191 .iter()
192 .filter_map(|s| s.embedding.as_ref())
193 .map(|s_emb| cosine_similarity(emb, s_emb))
194 .fold(f32::NEG_INFINITY, f32::max);
195
196 let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim_selected;
198
199 Some((i, mmr_score))
200 })
201 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
202 .map(|(i, _)| i);
203
204 if let Some(idx) = next_idx {
205 selected.push(remaining.swap_remove(idx));
206 } else {
207 break;
208 }
209 }
210
211 selected
212}
213
214pub struct SessionMemory {
216 store: Box<dyn MemoryStore>,
217 short_term: parking_lot::RwLock<HashMap<String, Vec<MemoryRecord>>>,
219 context_window: usize,
220}
221
222impl SessionMemory {
223 pub fn new(store: Box<dyn MemoryStore>, context_window: usize) -> Self {
225 Self {
226 store,
227 short_term: parking_lot::RwLock::new(HashMap::new()),
228 context_window,
229 }
230 }
231
232 pub async fn store(&self, record: MemoryRecord) -> Result<()> {
234 let session_id = record.session_id.clone();
235
236 {
238 let mut short_term = self.short_term.write();
239 let session_records = short_term.entry(session_id).or_insert_with(Vec::new);
240 session_records.push(record.clone());
241
242 if session_records.len() > self.context_window {
244 session_records.drain(0..session_records.len() - self.context_window);
245 }
246 }
247
248 self.store.store(record).await
250 }
251
252 pub async fn retrieve_recent(&self, session_id: &str) -> Result<Vec<MemoryRecord>> {
254 let short_term = self.short_term.read();
255 Ok(short_term.get(session_id).cloned().unwrap_or_default())
256 }
257
258 pub async fn search(
260 &self,
261 session_id: &str,
262 query_embedding: Vec<f32>,
263 limit: usize,
264 ) -> Result<Vec<MemoryRecord>> {
265 self.store.search(session_id, query_embedding, limit).await
266 }
267
268 pub async fn flush(&self) -> Result<()> {
270 self.store.flush().await
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an embedding-less `"user"` record in session `"test"`.
    fn sample_record(content: &str, importance: f32) -> MemoryRecord {
        MemoryRecord {
            id: Uuid::new_v4(),
            session_id: "test".to_string(),
            role: "user".to_string(),
            content: content.to_string(),
            importance,
            timestamp: Utc::now(),
            metadata: None,
            embedding: None,
        }
    }

    #[tokio::test]
    async fn test_in_memory_store() {
        let store = InMemoryStore::new();
        store.store(sample_record("Hello", 0.8)).await.unwrap();

        let retrieved = store.retrieve("test", 10).await.unwrap();
        assert_eq!(retrieved.len(), 1);
        assert_eq!(retrieved[0].content, "Hello");
    }

    #[tokio::test]
    async fn test_session_memory() {
        let memory = SessionMemory::new(Box::new(InMemoryStore::new()), 5);
        memory.store(sample_record("Test message", 0.9)).await.unwrap();

        let recent = memory.retrieve_recent("test").await.unwrap();
        assert_eq!(recent.len(), 1);
    }
}