Skip to main content

mem7_store/
engine.rs

1use std::sync::Arc;
2
3use mem7_config::MemoryEngineConfig;
4use mem7_core::{
5    AddResult, ChatMessage, MemoryAction, MemoryActionResult, MemoryEvent, MemoryFilter,
6    MemoryItem, SearchResult, new_memory_id,
7};
8use mem7_embedding::EmbeddingClient;
9use mem7_error::{Mem7Error, Result};
10use mem7_history::SqliteHistory;
11use mem7_llm::LlmClient;
12use mem7_vector::{VectorIndex, VectorSearchResult};
13use tracing::{debug, info};
14use uuid::Uuid;
15
16use crate::pipeline;
17
18/// The core memory engine. Orchestrates the full add/search/get/update/delete/history pipeline.
19pub struct MemoryEngine {
20    llm: Arc<dyn LlmClient>,
21    embedder: Arc<dyn EmbeddingClient>,
22    vector_index: Arc<dyn VectorIndex>,
23    history: Arc<SqliteHistory>,
24    config: MemoryEngineConfig,
25}
26
27impl MemoryEngine {
28    pub async fn new(config: MemoryEngineConfig) -> Result<Self> {
29        let llm = mem7_llm::create_llm(&config.llm)?;
30        let embedder = mem7_embedding::create_embedding(&config.embedding)?;
31        let vector_index = mem7_vector::create_vector_index(&config.vector)?;
32        let history = Arc::new(SqliteHistory::new(&config.history.db_path).await?);
33
34        info!("MemoryEngine initialized");
35
36        Ok(Self {
37            llm,
38            embedder,
39            vector_index,
40            history,
41            config,
42        })
43    }
44
45    /// Add memories from a conversation. Extracts facts, deduplicates, and stores.
46    pub async fn add(
47        &self,
48        messages: &[ChatMessage],
49        user_id: Option<&str>,
50        agent_id: Option<&str>,
51        run_id: Option<&str>,
52    ) -> Result<AddResult> {
53        let facts = pipeline::extract_facts(
54            self.llm.as_ref(),
55            messages,
56            self.config.custom_fact_extraction_prompt.as_deref(),
57        )
58        .await?;
59
60        if facts.is_empty() {
61            return Ok(AddResult {
62                results: Vec::new(),
63            });
64        }
65
66        debug!(count = facts.len(), "extracted facts");
67
68        let fact_texts: Vec<String> = facts.iter().map(|f| f.text.clone()).collect();
69        let embeddings = self.embedder.embed(&fact_texts).await?;
70
71        let filter = MemoryFilter {
72            user_id: user_id.map(String::from),
73            agent_id: agent_id.map(String::from),
74            run_id: run_id.map(String::from),
75        };
76        let mut all_retrieved: Vec<(Uuid, String, f32)> = Vec::new();
77
78        for embedding in &embeddings {
79            let results = self
80                .vector_index
81                .search(embedding, 5, Some(&filter))
82                .await?;
83            for VectorSearchResult { id, score, payload } in results {
84                if let Some(text) = payload.get("text").and_then(|v| v.as_str()) {
85                    all_retrieved.push((id, text.to_string(), score));
86                }
87            }
88        }
89
90        let (update_resp, id_mapping) = pipeline::decide_memory_updates(
91            self.llm.as_ref(),
92            &facts,
93            all_retrieved,
94            self.config.custom_update_memory_prompt.as_deref(),
95        )
96        .await?;
97
98        let now = chrono_now();
99        let mut results = Vec::new();
100
101        for decision in &update_resp.memory {
102            match decision.event {
103                MemoryAction::Add => {
104                    let memory_id = new_memory_id();
105                    let text = &decision.text;
106
107                    let vecs = self.embedder.embed(std::slice::from_ref(text)).await?;
108                    let vec = vecs.into_iter().next().unwrap_or_default();
109
110                    let payload = serde_json::json!({
111                        "text": text,
112                        "user_id": user_id,
113                        "agent_id": agent_id,
114                        "run_id": run_id,
115                        "created_at": now,
116                        "updated_at": now,
117                    });
118
119                    self.vector_index.insert(memory_id, &vec, payload).await?;
120
121                    self.history
122                        .add_event(memory_id, None, Some(text), MemoryAction::Add)
123                        .await?;
124
125                    results.push(MemoryActionResult {
126                        id: memory_id,
127                        action: MemoryAction::Add,
128                        old_value: None,
129                        new_value: Some(text.clone()),
130                    });
131                }
132                MemoryAction::Update => {
133                    if let Some(real_id) = id_mapping.resolve(&decision.id) {
134                        let text = &decision.text;
135                        let old_text = decision.old_memory.as_deref();
136
137                        let vecs = self.embedder.embed(std::slice::from_ref(text)).await?;
138                        let vec = vecs.into_iter().next().unwrap_or_default();
139
140                        let payload = serde_json::json!({
141                            "text": text,
142                            "user_id": user_id,
143                            "agent_id": agent_id,
144                            "run_id": run_id,
145                            "updated_at": now,
146                        });
147
148                        self.vector_index
149                            .update(&real_id, Some(&vec), Some(payload))
150                            .await?;
151
152                        self.history
153                            .add_event(real_id, old_text, Some(text), MemoryAction::Update)
154                            .await?;
155
156                        results.push(MemoryActionResult {
157                            id: real_id,
158                            action: MemoryAction::Update,
159                            old_value: old_text.map(String::from),
160                            new_value: Some(text.clone()),
161                        });
162                    }
163                }
164                MemoryAction::Delete => {
165                    if let Some(real_id) = id_mapping.resolve(&decision.id) {
166                        let old_text = decision.old_memory.as_deref().or(Some(&decision.text));
167
168                        self.vector_index.delete(&real_id).await?;
169
170                        self.history
171                            .add_event(real_id, old_text, None, MemoryAction::Delete)
172                            .await?;
173
174                        results.push(MemoryActionResult {
175                            id: real_id,
176                            action: MemoryAction::Delete,
177                            old_value: old_text.map(String::from),
178                            new_value: None,
179                        });
180                    }
181                }
182                MemoryAction::None => {}
183            }
184        }
185
186        info!(count = results.len(), "memory operations completed");
187        Ok(AddResult { results })
188    }
189
190    /// Search memories by semantic similarity.
191    pub async fn search(
192        &self,
193        query: &str,
194        user_id: Option<&str>,
195        agent_id: Option<&str>,
196        run_id: Option<&str>,
197        limit: usize,
198    ) -> Result<SearchResult> {
199        let vecs = self.embedder.embed(&[query.to_string()]).await?;
200        let query_vec = vecs.into_iter().next().unwrap_or_default();
201
202        let filter = MemoryFilter {
203            user_id: user_id.map(String::from),
204            agent_id: agent_id.map(String::from),
205            run_id: run_id.map(String::from),
206        };
207
208        let results = self
209            .vector_index
210            .search(&query_vec, limit, Some(&filter))
211            .await?;
212
213        let memories = results
214            .into_iter()
215            .map(|r| payload_to_memory_item(r.id, &r.payload, Some(r.score)))
216            .collect();
217
218        Ok(SearchResult { memories })
219    }
220
221    /// Get a single memory by ID.
222    pub async fn get(&self, memory_id: Uuid) -> Result<MemoryItem> {
223        let entry = self
224            .vector_index
225            .get(&memory_id)
226            .await?
227            .ok_or_else(|| Mem7Error::NotFound(format!("memory {memory_id}")))?;
228
229        Ok(payload_to_memory_item(memory_id, &entry.1, None))
230    }
231
232    /// List all memories matching the given filters.
233    pub async fn get_all(
234        &self,
235        user_id: Option<&str>,
236        agent_id: Option<&str>,
237        run_id: Option<&str>,
238    ) -> Result<Vec<MemoryItem>> {
239        let filter = MemoryFilter {
240            user_id: user_id.map(String::from),
241            agent_id: agent_id.map(String::from),
242            run_id: run_id.map(String::from),
243        };
244
245        let entries = self.vector_index.list(Some(&filter), None).await?;
246
247        Ok(entries
248            .into_iter()
249            .map(|(id, payload)| payload_to_memory_item(id, &payload, None))
250            .collect())
251    }
252
253    /// Update a memory's text directly.
254    pub async fn update(&self, memory_id: Uuid, new_text: &str) -> Result<()> {
255        let entry = self
256            .vector_index
257            .get(&memory_id)
258            .await?
259            .ok_or_else(|| Mem7Error::NotFound(format!("memory {memory_id}")))?;
260
261        let old_text = entry
262            .1
263            .get("text")
264            .and_then(|v| v.as_str())
265            .map(String::from);
266
267        let vecs = self.embedder.embed(&[new_text.to_string()]).await?;
268        let vec = vecs.into_iter().next().unwrap_or_default();
269
270        let mut payload = entry.1.clone();
271        payload["text"] = serde_json::Value::String(new_text.to_string());
272        payload["updated_at"] = serde_json::Value::String(chrono_now());
273
274        self.vector_index
275            .update(&memory_id, Some(&vec), Some(payload))
276            .await?;
277
278        self.history
279            .add_event(
280                memory_id,
281                old_text.as_deref(),
282                Some(new_text),
283                MemoryAction::Update,
284            )
285            .await?;
286
287        Ok(())
288    }
289
290    /// Delete a memory by ID.
291    pub async fn delete(&self, memory_id: Uuid) -> Result<()> {
292        let entry = self.vector_index.get(&memory_id).await?;
293        let old_text = entry
294            .as_ref()
295            .and_then(|(_, p)| p.get("text").and_then(|v| v.as_str()))
296            .map(String::from);
297
298        self.vector_index.delete(&memory_id).await?;
299
300        self.history
301            .add_event(memory_id, old_text.as_deref(), None, MemoryAction::Delete)
302            .await?;
303
304        Ok(())
305    }
306
307    /// Delete all memories matching the given filters.
308    pub async fn delete_all(
309        &self,
310        user_id: Option<&str>,
311        agent_id: Option<&str>,
312        run_id: Option<&str>,
313    ) -> Result<()> {
314        let filter = MemoryFilter {
315            user_id: user_id.map(String::from),
316            agent_id: agent_id.map(String::from),
317            run_id: run_id.map(String::from),
318        };
319
320        let entries = self.vector_index.list(Some(&filter), None).await?;
321        for (id, _) in entries {
322            self.vector_index.delete(&id).await?;
323        }
324
325        Ok(())
326    }
327
328    /// Get the change history for a memory.
329    pub async fn history(&self, memory_id: Uuid) -> Result<Vec<MemoryEvent>> {
330        self.history.get_history(memory_id).await
331    }
332
333    /// Reset all data (vector index + history).
334    pub async fn reset(&self) -> Result<()> {
335        self.vector_index.reset().await?;
336        self.history.reset().await?;
337        info!("MemoryEngine reset");
338        Ok(())
339    }
340}
341
342fn payload_to_memory_item(id: Uuid, payload: &serde_json::Value, score: Option<f32>) -> MemoryItem {
343    MemoryItem {
344        id,
345        text: payload
346            .get("text")
347            .and_then(|v| v.as_str())
348            .unwrap_or("")
349            .to_string(),
350        user_id: payload
351            .get("user_id")
352            .and_then(|v| v.as_str())
353            .map(String::from),
354        agent_id: payload
355            .get("agent_id")
356            .and_then(|v| v.as_str())
357            .map(String::from),
358        run_id: payload
359            .get("run_id")
360            .and_then(|v| v.as_str())
361            .map(String::from),
362        metadata: payload
363            .get("metadata")
364            .cloned()
365            .unwrap_or(serde_json::Value::Null),
366        created_at: payload
367            .get("created_at")
368            .and_then(|v| v.as_str())
369            .unwrap_or("")
370            .to_string(),
371        updated_at: payload
372            .get("updated_at")
373            .and_then(|v| v.as_str())
374            .unwrap_or("")
375            .to_string(),
376        score,
377    }
378}
379
380fn chrono_now() -> String {
381    let d = std::time::SystemTime::now()
382        .duration_since(std::time::UNIX_EPOCH)
383        .unwrap_or_default();
384    let secs = d.as_secs();
385    let days = secs / 86400;
386    let time_secs = secs % 86400;
387    let hours = time_secs / 3600;
388    let minutes = (time_secs % 3600) / 60;
389    let seconds = time_secs % 60;
390    let (year, month, day) = days_to_ymd(days);
391    format!("{year:04}-{month:02}-{day:02}T{hours:02}:{minutes:02}:{seconds:02}Z")
392}
393
/// Converts a number of days since 1970-01-01 into a `(year, month, day)` date in the
/// Gregorian calendar. `month` is 1-12 and `day` is 1-31.
///
/// Internally the year is treated as starting on 1 March, which puts the leap day at the
/// very end of the year and keeps the month arithmetic uniform.
fn days_to_ymd(days_since_epoch: u64) -> (u64, u64, u64) {
    // Length of one 400-year Gregorian cycle in days.
    const DAYS_PER_ERA: u64 = 146_097;
    // Day offset that moves the origin from 1970-01-01 back to 0000-03-01.
    const EPOCH_SHIFT: u64 = 719_468;

    let shifted = days_since_epoch + EPOCH_SHIFT;
    let era = shifted / DAYS_PER_ERA;
    let day_of_era = shifted % DAYS_PER_ERA;

    // Remove the leap days accumulated inside the era before dividing by 365.
    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);

    // Month index counted from March (0 = March, ..., 11 = February).
    let march_based_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * march_based_month + 2) / 5 + 1;

    // January and February belong to the following calendar year.
    let (month, year_carry) = if march_based_month < 10 {
        (march_based_month + 3, 0)
    } else {
        (march_based_month - 9, 1)
    };

    (year_of_era + era * 400 + year_carry, month, day)
}