1use std::sync::Arc;
2
3use mem7_config::MemoryEngineConfig;
4use mem7_core::{
5 AddResult, ChatMessage, MemoryAction, MemoryActionResult, MemoryEvent, MemoryFilter,
6 MemoryItem, SearchResult, new_memory_id,
7};
8use mem7_embedding::EmbeddingClient;
9use mem7_error::{Mem7Error, Result};
10use mem7_history::SqliteHistory;
11use mem7_llm::LlmClient;
12use mem7_vector::{VectorIndex, VectorSearchResult};
13use tracing::{debug, info};
14use uuid::Uuid;
15
16use crate::pipeline;
17
18pub struct MemoryEngine {
20 llm: Arc<dyn LlmClient>,
21 embedder: Arc<dyn EmbeddingClient>,
22 vector_index: Arc<dyn VectorIndex>,
23 history: Arc<SqliteHistory>,
24 config: MemoryEngineConfig,
25}
26
27impl MemoryEngine {
28 pub async fn new(config: MemoryEngineConfig) -> Result<Self> {
29 let llm = mem7_llm::create_llm(&config.llm)?;
30 let embedder = mem7_embedding::create_embedding(&config.embedding)?;
31 let vector_index = mem7_vector::create_vector_index(&config.vector)?;
32 let history = Arc::new(SqliteHistory::new(&config.history.db_path).await?);
33
34 info!("MemoryEngine initialized");
35
36 Ok(Self {
37 llm,
38 embedder,
39 vector_index,
40 history,
41 config,
42 })
43 }
44
45 pub async fn add(
47 &self,
48 messages: &[ChatMessage],
49 user_id: Option<&str>,
50 agent_id: Option<&str>,
51 run_id: Option<&str>,
52 ) -> Result<AddResult> {
53 let facts = pipeline::extract_facts(
54 self.llm.as_ref(),
55 messages,
56 self.config.custom_fact_extraction_prompt.as_deref(),
57 )
58 .await?;
59
60 if facts.is_empty() {
61 return Ok(AddResult {
62 results: Vec::new(),
63 });
64 }
65
66 debug!(count = facts.len(), "extracted facts");
67
68 let fact_texts: Vec<String> = facts.iter().map(|f| f.text.clone()).collect();
69 let embeddings = self.embedder.embed(&fact_texts).await?;
70
71 let filter = MemoryFilter {
72 user_id: user_id.map(String::from),
73 agent_id: agent_id.map(String::from),
74 run_id: run_id.map(String::from),
75 };
76 let mut all_retrieved: Vec<(Uuid, String, f32)> = Vec::new();
77
78 for embedding in &embeddings {
79 let results = self
80 .vector_index
81 .search(embedding, 5, Some(&filter))
82 .await?;
83 for VectorSearchResult { id, score, payload } in results {
84 if let Some(text) = payload.get("text").and_then(|v| v.as_str()) {
85 all_retrieved.push((id, text.to_string(), score));
86 }
87 }
88 }
89
90 let (update_resp, id_mapping) = pipeline::decide_memory_updates(
91 self.llm.as_ref(),
92 &facts,
93 all_retrieved,
94 self.config.custom_update_memory_prompt.as_deref(),
95 )
96 .await?;
97
98 let now = chrono_now();
99 let mut results = Vec::new();
100
101 for decision in &update_resp.memory {
102 match decision.event {
103 MemoryAction::Add => {
104 let memory_id = new_memory_id();
105 let text = &decision.text;
106
107 let vecs = self.embedder.embed(std::slice::from_ref(text)).await?;
108 let vec = vecs.into_iter().next().unwrap_or_default();
109
110 let payload = serde_json::json!({
111 "text": text,
112 "user_id": user_id,
113 "agent_id": agent_id,
114 "run_id": run_id,
115 "created_at": now,
116 "updated_at": now,
117 });
118
119 self.vector_index.insert(memory_id, &vec, payload).await?;
120
121 self.history
122 .add_event(memory_id, None, Some(text), MemoryAction::Add)
123 .await?;
124
125 results.push(MemoryActionResult {
126 id: memory_id,
127 action: MemoryAction::Add,
128 old_value: None,
129 new_value: Some(text.clone()),
130 });
131 }
132 MemoryAction::Update => {
133 if let Some(real_id) = id_mapping.resolve(&decision.id) {
134 let text = &decision.text;
135 let old_text = decision.old_memory.as_deref();
136
137 let vecs = self.embedder.embed(std::slice::from_ref(text)).await?;
138 let vec = vecs.into_iter().next().unwrap_or_default();
139
140 let payload = serde_json::json!({
141 "text": text,
142 "user_id": user_id,
143 "agent_id": agent_id,
144 "run_id": run_id,
145 "updated_at": now,
146 });
147
148 self.vector_index
149 .update(&real_id, Some(&vec), Some(payload))
150 .await?;
151
152 self.history
153 .add_event(real_id, old_text, Some(text), MemoryAction::Update)
154 .await?;
155
156 results.push(MemoryActionResult {
157 id: real_id,
158 action: MemoryAction::Update,
159 old_value: old_text.map(String::from),
160 new_value: Some(text.clone()),
161 });
162 }
163 }
164 MemoryAction::Delete => {
165 if let Some(real_id) = id_mapping.resolve(&decision.id) {
166 let old_text = decision.old_memory.as_deref().or(Some(&decision.text));
167
168 self.vector_index.delete(&real_id).await?;
169
170 self.history
171 .add_event(real_id, old_text, None, MemoryAction::Delete)
172 .await?;
173
174 results.push(MemoryActionResult {
175 id: real_id,
176 action: MemoryAction::Delete,
177 old_value: old_text.map(String::from),
178 new_value: None,
179 });
180 }
181 }
182 MemoryAction::None => {}
183 }
184 }
185
186 info!(count = results.len(), "memory operations completed");
187 Ok(AddResult { results })
188 }
189
190 pub async fn search(
192 &self,
193 query: &str,
194 user_id: Option<&str>,
195 agent_id: Option<&str>,
196 run_id: Option<&str>,
197 limit: usize,
198 ) -> Result<SearchResult> {
199 let vecs = self.embedder.embed(&[query.to_string()]).await?;
200 let query_vec = vecs.into_iter().next().unwrap_or_default();
201
202 let filter = MemoryFilter {
203 user_id: user_id.map(String::from),
204 agent_id: agent_id.map(String::from),
205 run_id: run_id.map(String::from),
206 };
207
208 let results = self
209 .vector_index
210 .search(&query_vec, limit, Some(&filter))
211 .await?;
212
213 let memories = results
214 .into_iter()
215 .map(|r| payload_to_memory_item(r.id, &r.payload, Some(r.score)))
216 .collect();
217
218 Ok(SearchResult { memories })
219 }
220
221 pub async fn get(&self, memory_id: Uuid) -> Result<MemoryItem> {
223 let entry = self
224 .vector_index
225 .get(&memory_id)
226 .await?
227 .ok_or_else(|| Mem7Error::NotFound(format!("memory {memory_id}")))?;
228
229 Ok(payload_to_memory_item(memory_id, &entry.1, None))
230 }
231
232 pub async fn get_all(
234 &self,
235 user_id: Option<&str>,
236 agent_id: Option<&str>,
237 run_id: Option<&str>,
238 ) -> Result<Vec<MemoryItem>> {
239 let filter = MemoryFilter {
240 user_id: user_id.map(String::from),
241 agent_id: agent_id.map(String::from),
242 run_id: run_id.map(String::from),
243 };
244
245 let entries = self.vector_index.list(Some(&filter), None).await?;
246
247 Ok(entries
248 .into_iter()
249 .map(|(id, payload)| payload_to_memory_item(id, &payload, None))
250 .collect())
251 }
252
253 pub async fn update(&self, memory_id: Uuid, new_text: &str) -> Result<()> {
255 let entry = self
256 .vector_index
257 .get(&memory_id)
258 .await?
259 .ok_or_else(|| Mem7Error::NotFound(format!("memory {memory_id}")))?;
260
261 let old_text = entry
262 .1
263 .get("text")
264 .and_then(|v| v.as_str())
265 .map(String::from);
266
267 let vecs = self.embedder.embed(&[new_text.to_string()]).await?;
268 let vec = vecs.into_iter().next().unwrap_or_default();
269
270 let mut payload = entry.1.clone();
271 payload["text"] = serde_json::Value::String(new_text.to_string());
272 payload["updated_at"] = serde_json::Value::String(chrono_now());
273
274 self.vector_index
275 .update(&memory_id, Some(&vec), Some(payload))
276 .await?;
277
278 self.history
279 .add_event(
280 memory_id,
281 old_text.as_deref(),
282 Some(new_text),
283 MemoryAction::Update,
284 )
285 .await?;
286
287 Ok(())
288 }
289
290 pub async fn delete(&self, memory_id: Uuid) -> Result<()> {
292 let entry = self.vector_index.get(&memory_id).await?;
293 let old_text = entry
294 .as_ref()
295 .and_then(|(_, p)| p.get("text").and_then(|v| v.as_str()))
296 .map(String::from);
297
298 self.vector_index.delete(&memory_id).await?;
299
300 self.history
301 .add_event(memory_id, old_text.as_deref(), None, MemoryAction::Delete)
302 .await?;
303
304 Ok(())
305 }
306
307 pub async fn delete_all(
309 &self,
310 user_id: Option<&str>,
311 agent_id: Option<&str>,
312 run_id: Option<&str>,
313 ) -> Result<()> {
314 let filter = MemoryFilter {
315 user_id: user_id.map(String::from),
316 agent_id: agent_id.map(String::from),
317 run_id: run_id.map(String::from),
318 };
319
320 let entries = self.vector_index.list(Some(&filter), None).await?;
321 for (id, _) in entries {
322 self.vector_index.delete(&id).await?;
323 }
324
325 Ok(())
326 }
327
328 pub async fn history(&self, memory_id: Uuid) -> Result<Vec<MemoryEvent>> {
330 self.history.get_history(memory_id).await
331 }
332
333 pub async fn reset(&self) -> Result<()> {
335 self.vector_index.reset().await?;
336 self.history.reset().await?;
337 info!("MemoryEngine reset");
338 Ok(())
339 }
340}
341
342fn payload_to_memory_item(id: Uuid, payload: &serde_json::Value, score: Option<f32>) -> MemoryItem {
343 MemoryItem {
344 id,
345 text: payload
346 .get("text")
347 .and_then(|v| v.as_str())
348 .unwrap_or("")
349 .to_string(),
350 user_id: payload
351 .get("user_id")
352 .and_then(|v| v.as_str())
353 .map(String::from),
354 agent_id: payload
355 .get("agent_id")
356 .and_then(|v| v.as_str())
357 .map(String::from),
358 run_id: payload
359 .get("run_id")
360 .and_then(|v| v.as_str())
361 .map(String::from),
362 metadata: payload
363 .get("metadata")
364 .cloned()
365 .unwrap_or(serde_json::Value::Null),
366 created_at: payload
367 .get("created_at")
368 .and_then(|v| v.as_str())
369 .unwrap_or("")
370 .to_string(),
371 updated_at: payload
372 .get("updated_at")
373 .and_then(|v| v.as_str())
374 .unwrap_or("")
375 .to_string(),
376 score,
377 }
378}
379
380fn chrono_now() -> String {
381 let d = std::time::SystemTime::now()
382 .duration_since(std::time::UNIX_EPOCH)
383 .unwrap_or_default();
384 let secs = d.as_secs();
385 let days = secs / 86400;
386 let time_secs = secs % 86400;
387 let hours = time_secs / 3600;
388 let minutes = (time_secs % 3600) / 60;
389 let seconds = time_secs % 60;
390 let (year, month, day) = days_to_ymd(days);
391 format!("{year:04}-{month:02}-{day:02}T{hours:02}:{minutes:02}:{seconds:02}Z")
392}
393
/// Converts a number of days since 1970-01-01 into a proleptic-Gregorian
/// `(year, month, day)` triple; `month` and `day` are 1-based.
///
/// This is Howard Hinnant's `civil_from_days` algorithm restricted to
/// non-negative day counts. It counts from a calendar shifted to begin on
/// March 1st of year 0. That puts the leap day at the very end of each
/// shifted year, and makes every 400-year era exactly 146 097 days long.
fn days_to_ymd(days_since_epoch: u64) -> (u64, u64, u64) {
    // Length of one 400-year Gregorian era, in days.
    const DAYS_PER_ERA: u64 = 146_097;
    // Days from 0000-03-01 to 1970-01-01.
    const EPOCH_SHIFT: u64 = 719_468;

    let shifted = days_since_epoch + EPOCH_SHIFT;
    let era = shifted / DAYS_PER_ERA;
    let day_of_era = shifted % DAYS_PER_ERA; // 0..=146_096

    // Corrects for the 4/100/400-year leap rules before dividing by 365.
    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36524 - day_of_era / 146_096) / 365; // 0..=399
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100); // 0..=365, March-based
    let month_index = (5 * day_of_year + 2) / 153; // 0..=11, 0 = March

    let day = day_of_year - (153 * month_index + 2) / 5 + 1;
    let month = if month_index < 10 {
        month_index + 3
    } else {
        month_index - 9
    };
    // January and February belong to the following civil year.
    let year = year_of_era + era * 400 + u64::from(month <= 2);

    (year, month, day)
}