Skip to main content

mem7_store/
engine.rs

1use std::sync::Arc;
2
3use mem7_config::MemoryEngineConfig;
4use mem7_core::{
5    AddResult, ChatMessage, MemoryAction, MemoryActionResult, MemoryEvent, MemoryFilter,
6    MemoryItem, SearchResult, new_memory_id,
7};
8use mem7_embedding::{EmbeddingClient, OpenAICompatibleEmbedding};
9use mem7_error::{Mem7Error, Result};
10use mem7_history::SqliteHistory;
11use mem7_llm::{LlmClient, OpenAICompatibleLlm};
12use mem7_vector::{DistanceMetric, FlatIndex, UpstashVectorIndex, VectorIndex, VectorSearchResult};
13use tracing::{debug, info};
14use uuid::Uuid;
15
16use crate::pipeline;
17
18/// The core memory engine. Orchestrates the full add/search/get/update/delete/history pipeline.
19pub struct MemoryEngine {
20    llm: Arc<dyn LlmClient>,
21    embedder: Arc<dyn EmbeddingClient>,
22    vector_index: Arc<dyn VectorIndex>,
23    history: Arc<SqliteHistory>,
24    config: MemoryEngineConfig,
25}
26
27impl MemoryEngine {
28    pub async fn new(config: MemoryEngineConfig) -> Result<Self> {
29        let llm = Arc::new(OpenAICompatibleLlm::new(config.llm.clone())) as Arc<dyn LlmClient>;
30        let embedder = Arc::new(OpenAICompatibleEmbedding::new(config.embedding.clone()))
31            as Arc<dyn EmbeddingClient>;
32
33        let vector_index: Arc<dyn VectorIndex> = match config.vector.provider.as_str() {
34            "upstash" => {
35                let url = config
36                    .vector
37                    .upstash_url
38                    .as_deref()
39                    .ok_or_else(|| Mem7Error::Config("upstash_url is required".into()))?;
40                let token = config
41                    .vector
42                    .upstash_token
43                    .as_deref()
44                    .ok_or_else(|| Mem7Error::Config("upstash_token is required".into()))?;
45                info!(namespace = %config.vector.collection_name, "using Upstash Vector");
46                Arc::new(UpstashVectorIndex::new(
47                    url,
48                    token,
49                    &config.vector.collection_name,
50                ))
51            }
52            _ => {
53                info!("using in-memory FlatIndex");
54                Arc::new(FlatIndex::new(DistanceMetric::Cosine))
55            }
56        };
57
58        let history = Arc::new(SqliteHistory::new(&config.history.db_path).await?);
59
60        info!("MemoryEngine initialized");
61
62        Ok(Self {
63            llm,
64            embedder,
65            vector_index,
66            history,
67            config,
68        })
69    }
70
71    /// Add memories from a conversation. Extracts facts, deduplicates, and stores.
72    pub async fn add(
73        &self,
74        messages: &[ChatMessage],
75        user_id: Option<&str>,
76        agent_id: Option<&str>,
77        run_id: Option<&str>,
78    ) -> Result<AddResult> {
79        let facts = pipeline::extract_facts(
80            self.llm.as_ref(),
81            messages,
82            self.config.custom_fact_extraction_prompt.as_deref(),
83        )
84        .await?;
85
86        if facts.is_empty() {
87            return Ok(AddResult {
88                results: Vec::new(),
89            });
90        }
91
92        debug!(count = facts.len(), "extracted facts");
93
94        let fact_texts: Vec<String> = facts.iter().map(|f| f.text.clone()).collect();
95        let embeddings = self.embedder.embed(&fact_texts).await?;
96
97        let filter = MemoryFilter {
98            user_id: user_id.map(String::from),
99            agent_id: agent_id.map(String::from),
100            run_id: run_id.map(String::from),
101        };
102        let mut all_retrieved: Vec<(Uuid, String, f32)> = Vec::new();
103
104        for embedding in &embeddings {
105            let results = self
106                .vector_index
107                .search(embedding, 5, Some(&filter))
108                .await?;
109            for VectorSearchResult { id, score, payload } in results {
110                if let Some(text) = payload.get("text").and_then(|v| v.as_str()) {
111                    all_retrieved.push((id, text.to_string(), score));
112                }
113            }
114        }
115
116        let (update_resp, id_mapping) = pipeline::decide_memory_updates(
117            self.llm.as_ref(),
118            &facts,
119            all_retrieved,
120            self.config.custom_update_memory_prompt.as_deref(),
121        )
122        .await?;
123
124        let now = chrono_now();
125        let mut results = Vec::new();
126
127        for decision in &update_resp.memory {
128            match decision.event {
129                MemoryAction::Add => {
130                    let memory_id = new_memory_id();
131                    let text = &decision.text;
132
133                    let vecs = self.embedder.embed(std::slice::from_ref(text)).await?;
134                    let vec = vecs.into_iter().next().unwrap_or_default();
135
136                    let payload = serde_json::json!({
137                        "text": text,
138                        "user_id": user_id,
139                        "agent_id": agent_id,
140                        "run_id": run_id,
141                        "created_at": now,
142                        "updated_at": now,
143                    });
144
145                    self.vector_index.insert(memory_id, &vec, payload).await?;
146
147                    self.history
148                        .add_event(memory_id, None, Some(text), MemoryAction::Add)
149                        .await?;
150
151                    results.push(MemoryActionResult {
152                        id: memory_id,
153                        action: MemoryAction::Add,
154                        old_value: None,
155                        new_value: Some(text.clone()),
156                    });
157                }
158                MemoryAction::Update => {
159                    if let Some(real_id) = id_mapping.resolve(&decision.id) {
160                        let text = &decision.text;
161                        let old_text = decision.old_memory.as_deref();
162
163                        let vecs = self.embedder.embed(std::slice::from_ref(text)).await?;
164                        let vec = vecs.into_iter().next().unwrap_or_default();
165
166                        let payload = serde_json::json!({
167                            "text": text,
168                            "user_id": user_id,
169                            "agent_id": agent_id,
170                            "run_id": run_id,
171                            "updated_at": now,
172                        });
173
174                        self.vector_index
175                            .update(&real_id, Some(&vec), Some(payload))
176                            .await?;
177
178                        self.history
179                            .add_event(real_id, old_text, Some(text), MemoryAction::Update)
180                            .await?;
181
182                        results.push(MemoryActionResult {
183                            id: real_id,
184                            action: MemoryAction::Update,
185                            old_value: old_text.map(String::from),
186                            new_value: Some(text.clone()),
187                        });
188                    }
189                }
190                MemoryAction::Delete => {
191                    if let Some(real_id) = id_mapping.resolve(&decision.id) {
192                        let old_text = decision.old_memory.as_deref().or(Some(&decision.text));
193
194                        self.vector_index.delete(&real_id).await?;
195
196                        self.history
197                            .add_event(real_id, old_text, None, MemoryAction::Delete)
198                            .await?;
199
200                        results.push(MemoryActionResult {
201                            id: real_id,
202                            action: MemoryAction::Delete,
203                            old_value: old_text.map(String::from),
204                            new_value: None,
205                        });
206                    }
207                }
208                MemoryAction::None => {}
209            }
210        }
211
212        info!(count = results.len(), "memory operations completed");
213        Ok(AddResult { results })
214    }
215
216    /// Search memories by semantic similarity.
217    pub async fn search(
218        &self,
219        query: &str,
220        user_id: Option<&str>,
221        agent_id: Option<&str>,
222        run_id: Option<&str>,
223        limit: usize,
224    ) -> Result<SearchResult> {
225        let vecs = self.embedder.embed(&[query.to_string()]).await?;
226        let query_vec = vecs.into_iter().next().unwrap_or_default();
227
228        let filter = MemoryFilter {
229            user_id: user_id.map(String::from),
230            agent_id: agent_id.map(String::from),
231            run_id: run_id.map(String::from),
232        };
233
234        let results = self
235            .vector_index
236            .search(&query_vec, limit, Some(&filter))
237            .await?;
238
239        let memories = results
240            .into_iter()
241            .map(|r| payload_to_memory_item(r.id, &r.payload, Some(r.score)))
242            .collect();
243
244        Ok(SearchResult { memories })
245    }
246
247    /// Get a single memory by ID.
248    pub async fn get(&self, memory_id: Uuid) -> Result<MemoryItem> {
249        let entry = self
250            .vector_index
251            .get(&memory_id)
252            .await?
253            .ok_or_else(|| Mem7Error::NotFound(format!("memory {memory_id}")))?;
254
255        Ok(payload_to_memory_item(memory_id, &entry.1, None))
256    }
257
258    /// List all memories matching the given filters.
259    pub async fn get_all(
260        &self,
261        user_id: Option<&str>,
262        agent_id: Option<&str>,
263        run_id: Option<&str>,
264    ) -> Result<Vec<MemoryItem>> {
265        let filter = MemoryFilter {
266            user_id: user_id.map(String::from),
267            agent_id: agent_id.map(String::from),
268            run_id: run_id.map(String::from),
269        };
270
271        let entries = self.vector_index.list(Some(&filter), None).await?;
272
273        Ok(entries
274            .into_iter()
275            .map(|(id, payload)| payload_to_memory_item(id, &payload, None))
276            .collect())
277    }
278
279    /// Update a memory's text directly.
280    pub async fn update(&self, memory_id: Uuid, new_text: &str) -> Result<()> {
281        let entry = self
282            .vector_index
283            .get(&memory_id)
284            .await?
285            .ok_or_else(|| Mem7Error::NotFound(format!("memory {memory_id}")))?;
286
287        let old_text = entry
288            .1
289            .get("text")
290            .and_then(|v| v.as_str())
291            .map(String::from);
292
293        let vecs = self.embedder.embed(&[new_text.to_string()]).await?;
294        let vec = vecs.into_iter().next().unwrap_or_default();
295
296        let mut payload = entry.1.clone();
297        payload["text"] = serde_json::Value::String(new_text.to_string());
298        payload["updated_at"] = serde_json::Value::String(chrono_now());
299
300        self.vector_index
301            .update(&memory_id, Some(&vec), Some(payload))
302            .await?;
303
304        self.history
305            .add_event(
306                memory_id,
307                old_text.as_deref(),
308                Some(new_text),
309                MemoryAction::Update,
310            )
311            .await?;
312
313        Ok(())
314    }
315
316    /// Delete a memory by ID.
317    pub async fn delete(&self, memory_id: Uuid) -> Result<()> {
318        let entry = self.vector_index.get(&memory_id).await?;
319        let old_text = entry
320            .as_ref()
321            .and_then(|(_, p)| p.get("text").and_then(|v| v.as_str()))
322            .map(String::from);
323
324        self.vector_index.delete(&memory_id).await?;
325
326        self.history
327            .add_event(memory_id, old_text.as_deref(), None, MemoryAction::Delete)
328            .await?;
329
330        Ok(())
331    }
332
333    /// Delete all memories matching the given filters.
334    pub async fn delete_all(
335        &self,
336        user_id: Option<&str>,
337        agent_id: Option<&str>,
338        run_id: Option<&str>,
339    ) -> Result<()> {
340        let filter = MemoryFilter {
341            user_id: user_id.map(String::from),
342            agent_id: agent_id.map(String::from),
343            run_id: run_id.map(String::from),
344        };
345
346        let entries = self.vector_index.list(Some(&filter), None).await?;
347        for (id, _) in entries {
348            self.vector_index.delete(&id).await?;
349        }
350
351        Ok(())
352    }
353
354    /// Get the change history for a memory.
355    pub async fn history(&self, memory_id: Uuid) -> Result<Vec<MemoryEvent>> {
356        self.history.get_history(memory_id).await
357    }
358
359    /// Reset all data (vector index + history).
360    pub async fn reset(&self) -> Result<()> {
361        self.vector_index.reset().await?;
362        self.history.reset().await?;
363        info!("MemoryEngine reset");
364        Ok(())
365    }
366}
367
368fn payload_to_memory_item(id: Uuid, payload: &serde_json::Value, score: Option<f32>) -> MemoryItem {
369    MemoryItem {
370        id,
371        text: payload
372            .get("text")
373            .and_then(|v| v.as_str())
374            .unwrap_or("")
375            .to_string(),
376        user_id: payload
377            .get("user_id")
378            .and_then(|v| v.as_str())
379            .map(String::from),
380        agent_id: payload
381            .get("agent_id")
382            .and_then(|v| v.as_str())
383            .map(String::from),
384        run_id: payload
385            .get("run_id")
386            .and_then(|v| v.as_str())
387            .map(String::from),
388        metadata: payload
389            .get("metadata")
390            .cloned()
391            .unwrap_or(serde_json::Value::Null),
392        created_at: payload
393            .get("created_at")
394            .and_then(|v| v.as_str())
395            .unwrap_or("")
396            .to_string(),
397        updated_at: payload
398            .get("updated_at")
399            .and_then(|v| v.as_str())
400            .unwrap_or("")
401            .to_string(),
402        score,
403    }
404}
405
406fn chrono_now() -> String {
407    let d = std::time::SystemTime::now()
408        .duration_since(std::time::UNIX_EPOCH)
409        .unwrap_or_default();
410    let secs = d.as_secs();
411    let days = secs / 86400;
412    let time_secs = secs % 86400;
413    let hours = time_secs / 3600;
414    let minutes = (time_secs % 3600) / 60;
415    let seconds = time_secs % 60;
416    let (year, month, day) = days_to_ymd(days);
417    format!("{year:04}-{month:02}-{day:02}T{hours:02}:{minutes:02}:{seconds:02}Z")
418}
419
420fn days_to_ymd(days_since_epoch: u64) -> (u64, u64, u64) {
421    let z = days_since_epoch + 719468;
422    let era = z / 146097;
423    let doe = z - era * 146097;
424    let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365;
425    let y = yoe + era * 400;
426    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
427    let mp = (5 * doy + 2) / 153;
428    let d = doy - (153 * mp + 2) / 5 + 1;
429    let m = if mp < 10 { mp + 3 } else { mp - 9 };
430    let y = if m <= 2 { y + 1 } else { y };
431    (y, m, d)
432}