1use std::sync::Arc;
2
3use mem7_config::MemoryEngineConfig;
4use mem7_core::{
5 AddResult, ChatMessage, MemoryAction, MemoryActionResult, MemoryEvent, MemoryFilter,
6 MemoryItem, SearchResult, new_memory_id,
7};
8use mem7_embedding::{EmbeddingClient, OpenAICompatibleEmbedding};
9use mem7_error::{Mem7Error, Result};
10use mem7_history::SqliteHistory;
11use mem7_llm::{LlmClient, OpenAICompatibleLlm};
12use mem7_vector::{DistanceMetric, FlatIndex, UpstashVectorIndex, VectorIndex, VectorSearchResult};
13use tracing::{debug, info};
14use uuid::Uuid;
15
16use crate::pipeline;
17
/// Memory engine that extracts facts from conversations with an LLM, stores
/// them as embeddings in a vector index, and records changes in a
/// SQLite-backed history store.
pub struct MemoryEngine {
    /// LLM used for fact extraction and for deciding how memories change.
    llm: Arc<dyn LlmClient>,
    /// Produces embeddings for facts, memory texts and search queries.
    embedder: Arc<dyn EmbeddingClient>,
    /// Stores memory embeddings together with their JSON payloads
    /// (`text`, scope ids, timestamps).
    vector_index: Arc<dyn VectorIndex>,
    /// Per-memory log of add/update/delete events.
    history: Arc<SqliteHistory>,
    /// Configuration the engine was built from (custom prompts, providers).
    config: MemoryEngineConfig,
}
26
27impl MemoryEngine {
28 pub async fn new(config: MemoryEngineConfig) -> Result<Self> {
29 let llm = Arc::new(OpenAICompatibleLlm::new(config.llm.clone())) as Arc<dyn LlmClient>;
30 let embedder = Arc::new(OpenAICompatibleEmbedding::new(config.embedding.clone()))
31 as Arc<dyn EmbeddingClient>;
32
33 let vector_index: Arc<dyn VectorIndex> = match config.vector.provider.as_str() {
34 "upstash" => {
35 let url = config
36 .vector
37 .upstash_url
38 .as_deref()
39 .ok_or_else(|| Mem7Error::Config("upstash_url is required".into()))?;
40 let token = config
41 .vector
42 .upstash_token
43 .as_deref()
44 .ok_or_else(|| Mem7Error::Config("upstash_token is required".into()))?;
45 info!(namespace = %config.vector.collection_name, "using Upstash Vector");
46 Arc::new(UpstashVectorIndex::new(
47 url,
48 token,
49 &config.vector.collection_name,
50 ))
51 }
52 _ => {
53 info!("using in-memory FlatIndex");
54 Arc::new(FlatIndex::new(DistanceMetric::Cosine))
55 }
56 };
57
58 let history = Arc::new(SqliteHistory::new(&config.history.db_path).await?);
59
60 info!("MemoryEngine initialized");
61
62 Ok(Self {
63 llm,
64 embedder,
65 vector_index,
66 history,
67 config,
68 })
69 }
70
71 pub async fn add(
73 &self,
74 messages: &[ChatMessage],
75 user_id: Option<&str>,
76 agent_id: Option<&str>,
77 run_id: Option<&str>,
78 ) -> Result<AddResult> {
79 let facts = pipeline::extract_facts(
80 self.llm.as_ref(),
81 messages,
82 self.config.custom_fact_extraction_prompt.as_deref(),
83 )
84 .await?;
85
86 if facts.is_empty() {
87 return Ok(AddResult {
88 results: Vec::new(),
89 });
90 }
91
92 debug!(count = facts.len(), "extracted facts");
93
94 let fact_texts: Vec<String> = facts.iter().map(|f| f.text.clone()).collect();
95 let embeddings = self.embedder.embed(&fact_texts).await?;
96
97 let filter = MemoryFilter {
98 user_id: user_id.map(String::from),
99 agent_id: agent_id.map(String::from),
100 run_id: run_id.map(String::from),
101 };
102 let mut all_retrieved: Vec<(Uuid, String, f32)> = Vec::new();
103
104 for embedding in &embeddings {
105 let results = self
106 .vector_index
107 .search(embedding, 5, Some(&filter))
108 .await?;
109 for VectorSearchResult { id, score, payload } in results {
110 if let Some(text) = payload.get("text").and_then(|v| v.as_str()) {
111 all_retrieved.push((id, text.to_string(), score));
112 }
113 }
114 }
115
116 let (update_resp, id_mapping) = pipeline::decide_memory_updates(
117 self.llm.as_ref(),
118 &facts,
119 all_retrieved,
120 self.config.custom_update_memory_prompt.as_deref(),
121 )
122 .await?;
123
124 let now = chrono_now();
125 let mut results = Vec::new();
126
127 for decision in &update_resp.memory {
128 match decision.event {
129 MemoryAction::Add => {
130 let memory_id = new_memory_id();
131 let text = &decision.text;
132
133 let vecs = self.embedder.embed(std::slice::from_ref(text)).await?;
134 let vec = vecs.into_iter().next().unwrap_or_default();
135
136 let payload = serde_json::json!({
137 "text": text,
138 "user_id": user_id,
139 "agent_id": agent_id,
140 "run_id": run_id,
141 "created_at": now,
142 "updated_at": now,
143 });
144
145 self.vector_index.insert(memory_id, &vec, payload).await?;
146
147 self.history
148 .add_event(memory_id, None, Some(text), MemoryAction::Add)
149 .await?;
150
151 results.push(MemoryActionResult {
152 id: memory_id,
153 action: MemoryAction::Add,
154 old_value: None,
155 new_value: Some(text.clone()),
156 });
157 }
158 MemoryAction::Update => {
159 if let Some(real_id) = id_mapping.resolve(&decision.id) {
160 let text = &decision.text;
161 let old_text = decision.old_memory.as_deref();
162
163 let vecs = self.embedder.embed(std::slice::from_ref(text)).await?;
164 let vec = vecs.into_iter().next().unwrap_or_default();
165
166 let payload = serde_json::json!({
167 "text": text,
168 "user_id": user_id,
169 "agent_id": agent_id,
170 "run_id": run_id,
171 "updated_at": now,
172 });
173
174 self.vector_index
175 .update(&real_id, Some(&vec), Some(payload))
176 .await?;
177
178 self.history
179 .add_event(real_id, old_text, Some(text), MemoryAction::Update)
180 .await?;
181
182 results.push(MemoryActionResult {
183 id: real_id,
184 action: MemoryAction::Update,
185 old_value: old_text.map(String::from),
186 new_value: Some(text.clone()),
187 });
188 }
189 }
190 MemoryAction::Delete => {
191 if let Some(real_id) = id_mapping.resolve(&decision.id) {
192 let old_text = decision.old_memory.as_deref().or(Some(&decision.text));
193
194 self.vector_index.delete(&real_id).await?;
195
196 self.history
197 .add_event(real_id, old_text, None, MemoryAction::Delete)
198 .await?;
199
200 results.push(MemoryActionResult {
201 id: real_id,
202 action: MemoryAction::Delete,
203 old_value: old_text.map(String::from),
204 new_value: None,
205 });
206 }
207 }
208 MemoryAction::None => {}
209 }
210 }
211
212 info!(count = results.len(), "memory operations completed");
213 Ok(AddResult { results })
214 }
215
216 pub async fn search(
218 &self,
219 query: &str,
220 user_id: Option<&str>,
221 agent_id: Option<&str>,
222 run_id: Option<&str>,
223 limit: usize,
224 ) -> Result<SearchResult> {
225 let vecs = self.embedder.embed(&[query.to_string()]).await?;
226 let query_vec = vecs.into_iter().next().unwrap_or_default();
227
228 let filter = MemoryFilter {
229 user_id: user_id.map(String::from),
230 agent_id: agent_id.map(String::from),
231 run_id: run_id.map(String::from),
232 };
233
234 let results = self
235 .vector_index
236 .search(&query_vec, limit, Some(&filter))
237 .await?;
238
239 let memories = results
240 .into_iter()
241 .map(|r| payload_to_memory_item(r.id, &r.payload, Some(r.score)))
242 .collect();
243
244 Ok(SearchResult { memories })
245 }
246
247 pub async fn get(&self, memory_id: Uuid) -> Result<MemoryItem> {
249 let entry = self
250 .vector_index
251 .get(&memory_id)
252 .await?
253 .ok_or_else(|| Mem7Error::NotFound(format!("memory {memory_id}")))?;
254
255 Ok(payload_to_memory_item(memory_id, &entry.1, None))
256 }
257
258 pub async fn get_all(
260 &self,
261 user_id: Option<&str>,
262 agent_id: Option<&str>,
263 run_id: Option<&str>,
264 ) -> Result<Vec<MemoryItem>> {
265 let filter = MemoryFilter {
266 user_id: user_id.map(String::from),
267 agent_id: agent_id.map(String::from),
268 run_id: run_id.map(String::from),
269 };
270
271 let entries = self.vector_index.list(Some(&filter), None).await?;
272
273 Ok(entries
274 .into_iter()
275 .map(|(id, payload)| payload_to_memory_item(id, &payload, None))
276 .collect())
277 }
278
279 pub async fn update(&self, memory_id: Uuid, new_text: &str) -> Result<()> {
281 let entry = self
282 .vector_index
283 .get(&memory_id)
284 .await?
285 .ok_or_else(|| Mem7Error::NotFound(format!("memory {memory_id}")))?;
286
287 let old_text = entry
288 .1
289 .get("text")
290 .and_then(|v| v.as_str())
291 .map(String::from);
292
293 let vecs = self.embedder.embed(&[new_text.to_string()]).await?;
294 let vec = vecs.into_iter().next().unwrap_or_default();
295
296 let mut payload = entry.1.clone();
297 payload["text"] = serde_json::Value::String(new_text.to_string());
298 payload["updated_at"] = serde_json::Value::String(chrono_now());
299
300 self.vector_index
301 .update(&memory_id, Some(&vec), Some(payload))
302 .await?;
303
304 self.history
305 .add_event(
306 memory_id,
307 old_text.as_deref(),
308 Some(new_text),
309 MemoryAction::Update,
310 )
311 .await?;
312
313 Ok(())
314 }
315
316 pub async fn delete(&self, memory_id: Uuid) -> Result<()> {
318 let entry = self.vector_index.get(&memory_id).await?;
319 let old_text = entry
320 .as_ref()
321 .and_then(|(_, p)| p.get("text").and_then(|v| v.as_str()))
322 .map(String::from);
323
324 self.vector_index.delete(&memory_id).await?;
325
326 self.history
327 .add_event(memory_id, old_text.as_deref(), None, MemoryAction::Delete)
328 .await?;
329
330 Ok(())
331 }
332
333 pub async fn delete_all(
335 &self,
336 user_id: Option<&str>,
337 agent_id: Option<&str>,
338 run_id: Option<&str>,
339 ) -> Result<()> {
340 let filter = MemoryFilter {
341 user_id: user_id.map(String::from),
342 agent_id: agent_id.map(String::from),
343 run_id: run_id.map(String::from),
344 };
345
346 let entries = self.vector_index.list(Some(&filter), None).await?;
347 for (id, _) in entries {
348 self.vector_index.delete(&id).await?;
349 }
350
351 Ok(())
352 }
353
354 pub async fn history(&self, memory_id: Uuid) -> Result<Vec<MemoryEvent>> {
356 self.history.get_history(memory_id).await
357 }
358
359 pub async fn reset(&self) -> Result<()> {
361 self.vector_index.reset().await?;
362 self.history.reset().await?;
363 info!("MemoryEngine reset");
364 Ok(())
365 }
366}
367
368fn payload_to_memory_item(id: Uuid, payload: &serde_json::Value, score: Option<f32>) -> MemoryItem {
369 MemoryItem {
370 id,
371 text: payload
372 .get("text")
373 .and_then(|v| v.as_str())
374 .unwrap_or("")
375 .to_string(),
376 user_id: payload
377 .get("user_id")
378 .and_then(|v| v.as_str())
379 .map(String::from),
380 agent_id: payload
381 .get("agent_id")
382 .and_then(|v| v.as_str())
383 .map(String::from),
384 run_id: payload
385 .get("run_id")
386 .and_then(|v| v.as_str())
387 .map(String::from),
388 metadata: payload
389 .get("metadata")
390 .cloned()
391 .unwrap_or(serde_json::Value::Null),
392 created_at: payload
393 .get("created_at")
394 .and_then(|v| v.as_str())
395 .unwrap_or("")
396 .to_string(),
397 updated_at: payload
398 .get("updated_at")
399 .and_then(|v| v.as_str())
400 .unwrap_or("")
401 .to_string(),
402 score,
403 }
404}
405
406fn chrono_now() -> String {
407 let d = std::time::SystemTime::now()
408 .duration_since(std::time::UNIX_EPOCH)
409 .unwrap_or_default();
410 let secs = d.as_secs();
411 let days = secs / 86400;
412 let time_secs = secs % 86400;
413 let hours = time_secs / 3600;
414 let minutes = (time_secs % 3600) / 60;
415 let seconds = time_secs % 60;
416 let (year, month, day) = days_to_ymd(days);
417 format!("{year:04}-{month:02}-{day:02}T{hours:02}:{minutes:02}:{seconds:02}Z")
418}
419
/// Converts a count of days since 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)` triple. Month and day are 1-based.
///
/// This is Howard Hinnant's `civil_from_days` algorithm, restricted to
/// non-negative inputs. The arithmetic runs in a calendar whose year starts on
/// March 1, so the leap day falls at the very end of each year.
fn days_to_ymd(days_since_epoch: u64) -> (u64, u64, u64) {
    // Days from 0000-03-01 to 1970-01-01; shifts the origin to a March-based year.
    const EPOCH_SHIFT: u64 = 719_468;
    // Days in one 400-year Gregorian cycle ("era").
    const DAYS_PER_ERA: u64 = 146_097;

    let shifted = days_since_epoch + EPOCH_SHIFT;
    let era = shifted / DAYS_PER_ERA;
    let day_of_era = shifted % DAYS_PER_ERA;

    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36524 - day_of_era / 146096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);

    // Month index counted from March: 0 = March, ..., 11 = February.
    let march_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * march_month + 2) / 5 + 1;

    // January and February belong to the following civil year.
    let (month, year_carry) = if march_month < 10 {
        (march_month + 3, 0)
    } else {
        (march_month - 9, 1)
    };

    (era * 400 + year_of_era + year_carry, month, day)
}