//! `mem7_store/add.rs` — memory-add operations for [`MemoryEngine`].
use std::collections::HashMap;

use mem7_core::{
    AddOptions, AddResult, ChatMessage, MemoryAction, MemoryActionResult, MemoryFilter,
    new_memory_id,
};
use mem7_datetime::now_iso;
use mem7_error::Result;
use mem7_vector::VectorSearchResult;
use tracing::{debug, info, instrument, warn};
use uuid::Uuid;

use crate::constants::*;
use crate::decay;
use crate::engine::MemoryEngine;
use crate::payload::{
    build_memory_payload, build_raw_memory_payload, build_update_payload, payload_to_event_metadata,
};
use crate::pipeline;
use crate::prompts::VISION_DESCRIBE_PROMPT;
use crate::require_scope;
23impl MemoryEngine {
24    /// Add memories from a conversation.
25    ///
26    /// When `infer` is `true` (the default), the LLM extracts facts from the
27    /// conversation, deduplicates them against existing memories, and decides
28    /// whether to add, update, or delete.
29    ///
30    /// When `infer` is `false`, each message's content is stored directly as a
31    /// new memory without any LLM processing — useful for importing raw text.
32    ///
33    /// `metadata` is an optional JSON object stored under `payload.metadata`.
34    #[instrument(skip(self, messages, metadata), fields(msg_count = messages.len()))]
35    pub async fn add(
36        &self,
37        messages: &[ChatMessage],
38        user_id: Option<&str>,
39        agent_id: Option<&str>,
40        run_id: Option<&str>,
41        metadata: Option<&serde_json::Value>,
42        infer: bool,
43    ) -> Result<AddResult> {
44        let opts = AddOptions {
45            user_id,
46            agent_id,
47            run_id,
48            metadata,
49            infer,
50        };
51        self.add_with_options(messages, &opts).await
52    }
53
54    /// Add memories using structured options.
55    pub async fn add_with_options(
56        &self,
57        messages: &[ChatMessage],
58        opts: &AddOptions<'_>,
59    ) -> Result<AddResult> {
60        require_scope("add", opts.user_id, opts.agent_id, opts.run_id)?;
61        if opts.infer {
62            self.add_with_inference(
63                messages,
64                opts.user_id,
65                opts.agent_id,
66                opts.run_id,
67                opts.metadata,
68            )
69            .await
70        } else {
71            self.add_raw(
72                messages,
73                opts.user_id,
74                opts.agent_id,
75                opts.run_id,
76                opts.metadata,
77            )
78            .await
79        }
80    }
81
82    /// Store raw message texts directly without LLM inference or deduplication.
83    /// System messages are skipped. Each message's `role` is stored in the payload.
84    async fn add_raw(
85        &self,
86        messages: &[ChatMessage],
87        user_id: Option<&str>,
88        agent_id: Option<&str>,
89        run_id: Option<&str>,
90        metadata: Option<&serde_json::Value>,
91    ) -> Result<AddResult> {
92        let non_system: Vec<&ChatMessage> =
93            messages.iter().filter(|m| m.role != "system").collect();
94
95        if non_system.is_empty() {
96            return Ok(AddResult {
97                results: Vec::new(),
98                relations: Vec::new(),
99            });
100        }
101
102        let owned: Vec<String> = non_system.iter().map(|m| m.content.clone()).collect();
103        let embeddings = self.embedder.embed(&owned).await?;
104
105        let now = now_iso();
106        let mut results = Vec::new();
107
108        for (msg, vec) in non_system.iter().zip(embeddings) {
109            let memory_id = new_memory_id();
110
111            let payload = build_raw_memory_payload(
112                &msg.content,
113                &msg.role,
114                user_id,
115                agent_id,
116                run_id,
117                metadata,
118                &now,
119            );
120            let audit = payload_to_event_metadata(&payload);
121
122            self.vector_index.insert(memory_id, &vec, payload).await?;
123
124            self.history
125                .add_event(
126                    memory_id,
127                    None,
128                    Some(&msg.content),
129                    MemoryAction::Add,
130                    audit,
131                )
132                .await?;
133
134            results.push(MemoryActionResult {
135                id: memory_id,
136                action: MemoryAction::Add,
137                old_value: None,
138                new_value: Some(msg.content.clone()),
139            });
140        }
141
142        let relations = if let Some(gp) = &self.graph_pipeline {
143            let conversation = non_system
144                .iter()
145                .map(|m| m.content.as_str())
146                .collect::<Vec<_>>()
147                .join("\n");
148            let filter = MemoryFilter::from_session(user_id, agent_id, run_id);
149            gp.add(&conversation, &filter).await.unwrap_or_else(|e| {
150                warn!(error = %e, "graph extraction failed during raw add");
151                Vec::new()
152            })
153        } else {
154            Vec::new()
155        };
156
157        info!(count = results.len(), infer = false, "raw memories stored");
158        Ok(AddResult { results, relations })
159    }
160
161    /// Full LLM-powered pipeline: extract facts, deduplicate, store.
162    async fn add_with_inference(
163        &self,
164        messages: &[ChatMessage],
165        user_id: Option<&str>,
166        agent_id: Option<&str>,
167        run_id: Option<&str>,
168        metadata: Option<&serde_json::Value>,
169    ) -> Result<AddResult> {
170        let messages = if self.config.llm.enable_vision {
171            self.describe_images(messages).await?
172        } else {
173            messages.to_vec()
174        };
175
176        let conversation = messages
177            .iter()
178            .map(|m| format!("{}: {}", m.role, m.content))
179            .collect::<Vec<_>>()
180            .join("\n");
181
182        let filter = MemoryFilter::from_session(user_id, agent_id, run_id);
183
184        let graph_future = async {
185            match &self.graph_pipeline {
186                Some(gp) => gp.add(&conversation, &filter).await,
187                None => Ok(Vec::new()),
188            }
189        };
190
191        let vector_future =
192            self.add_vector_with_inference(&messages, user_id, agent_id, run_id, metadata, &filter);
193
194        let (vector_result, graph_result) = tokio::join!(vector_future, graph_future);
195
196        let (results, _) = vector_result?;
197        let relations = graph_result.unwrap_or_else(|e| {
198            warn!(error = %e, "graph extraction failed");
199            Vec::new()
200        });
201
202        info!(
203            count = results.len(),
204            relations = relations.len(),
205            "memory operations completed"
206        );
207        Ok(AddResult { results, relations })
208    }
209
210    /// Vector-only inference pipeline, factored out for concurrent execution.
211    async fn add_vector_with_inference(
212        &self,
213        messages: &[ChatMessage],
214        user_id: Option<&str>,
215        agent_id: Option<&str>,
216        run_id: Option<&str>,
217        metadata: Option<&serde_json::Value>,
218        filter: &MemoryFilter,
219    ) -> Result<(Vec<MemoryActionResult>, ())> {
220        let facts = pipeline::extract_facts(
221            self.llm.as_ref(),
222            messages,
223            agent_id,
224            self.config.custom_fact_extraction_prompt.as_deref(),
225        )
226        .await?;
227
228        if facts.is_empty() {
229            return Ok((Vec::new(), ()));
230        }
231
232        debug!(count = facts.len(), "extracted facts");
233
234        let fact_texts: Vec<String> = facts.iter().map(|f| f.text.clone()).collect();
235        let embeddings = self.embedder.embed(&fact_texts).await?;
236
237        let mut all_retrieved: Vec<(Uuid, String, f32)> = Vec::new();
238
239        let decay_cfg = self.config.decay.as_ref().filter(|d| d.enabled);
240
241        for embedding in &embeddings {
242            let results = self
243                .vector_index
244                .search(embedding, DEDUP_CANDIDATE_LIMIT, Some(filter))
245                .await?;
246            for VectorSearchResult { id, score, payload } in results {
247                if let Some(text) = payload.get("text").and_then(|v| v.as_str()) {
248                    let effective_score = match decay_cfg {
249                        Some(cfg) => {
250                            let age = decay::age_from_payload(&payload);
251                            let ac = decay::access_count_from_payload(&payload);
252                            decay::apply_decay(score, age, ac, cfg)
253                        }
254                        None => score,
255                    };
256                    all_retrieved.push((id, text.to_string(), effective_score));
257                }
258            }
259        }
260
261        let (update_resp, id_mapping) = pipeline::decide_memory_updates(
262            self.llm.as_ref(),
263            &facts,
264            all_retrieved,
265            self.config.custom_update_memory_prompt.as_deref(),
266        )
267        .await?;
268
269        let fact_type_map: HashMap<&str, &str> = facts
270            .iter()
271            .map(|f| (f.text.as_str(), f.memory_type.as_str()))
272            .collect();
273
274        let now = now_iso();
275        let mut results = Vec::new();
276
277        for decision in &update_resp.memory {
278            match decision.event {
279                MemoryAction::Add => {
280                    let memory_id = new_memory_id();
281                    let text = &decision.text;
282
283                    let vecs = self.embedder.embed(std::slice::from_ref(text)).await?;
284                    let vec = vecs.into_iter().next().unwrap_or_default();
285
286                    let mt = fact_type_map.get(text.as_str()).copied();
287                    let payload =
288                        build_memory_payload(text, user_id, agent_id, run_id, metadata, &now, mt);
289                    let audit = payload_to_event_metadata(&payload);
290
291                    self.vector_index.insert(memory_id, &vec, payload).await?;
292
293                    self.history
294                        .add_event(memory_id, None, Some(text), MemoryAction::Add, audit)
295                        .await?;
296
297                    results.push(MemoryActionResult {
298                        id: memory_id,
299                        action: MemoryAction::Add,
300                        old_value: None,
301                        new_value: Some(text.clone()),
302                    });
303                }
304                MemoryAction::Update => {
305                    if let Some(real_id) = id_mapping.resolve(&decision.id) {
306                        let text = &decision.text;
307                        let old_text = decision.old_memory.as_deref();
308
309                        let vecs = self.embedder.embed(std::slice::from_ref(text)).await?;
310                        let vec = vecs.into_iter().next().unwrap_or_default();
311
312                        let existing_entry = self.vector_index.get(&real_id).await.ok().flatten();
313                        let prev_ac = existing_entry
314                            .as_ref()
315                            .map(|(_, p)| decay::access_count_from_payload(p))
316                            .unwrap_or(0);
317                        let existing_mt = existing_entry
318                            .as_ref()
319                            .and_then(|(_, p)| p.get("memory_type").and_then(|v| v.as_str()));
320                        let existing_created_at = existing_entry
321                            .as_ref()
322                            .and_then(|(_, p)| p.get("created_at").and_then(|v| v.as_str()));
323                        let mt = existing_mt.or_else(|| fact_type_map.get(text.as_str()).copied());
324
325                        let payload = build_update_payload(
326                            text,
327                            user_id,
328                            agent_id,
329                            run_id,
330                            metadata,
331                            existing_created_at,
332                            &now,
333                            prev_ac + 1,
334                            mt,
335                        );
336                        let audit = payload_to_event_metadata(&payload);
337
338                        self.vector_index
339                            .update(&real_id, Some(&vec), Some(payload))
340                            .await?;
341
342                        self.history
343                            .add_event(real_id, old_text, Some(text), MemoryAction::Update, audit)
344                            .await?;
345
346                        results.push(MemoryActionResult {
347                            id: real_id,
348                            action: MemoryAction::Update,
349                            old_value: old_text.map(String::from),
350                            new_value: Some(text.clone()),
351                        });
352                    }
353                }
354                MemoryAction::Delete => {
355                    if let Some(real_id) = id_mapping.resolve(&decision.id) {
356                        let old_text = decision.old_memory.as_deref().or(Some(&decision.text));
357                        let audit = self
358                            .vector_index
359                            .get(&real_id)
360                            .await
361                            .ok()
362                            .flatten()
363                            .map(|(_, payload)| {
364                                let mut metadata = payload_to_event_metadata(&payload);
365                                metadata.is_deleted = true;
366                                metadata
367                            })
368                            .unwrap_or_else(|| mem7_core::MemoryEventMetadata {
369                                is_deleted: true,
370                                ..Default::default()
371                            });
372
373                        self.vector_index.delete(&real_id).await?;
374
375                        self.history
376                            .add_event(real_id, old_text, None, MemoryAction::Delete, audit)
377                            .await?;
378
379                        results.push(MemoryActionResult {
380                            id: real_id,
381                            action: MemoryAction::Delete,
382                            old_value: old_text.map(String::from),
383                            new_value: None,
384                        });
385                    }
386                }
387                MemoryAction::None => {
388                    if let Some(real_id) = id_mapping.resolve(&decision.id) {
389                        let needs_update = agent_id.is_some() || run_id.is_some();
390                        if needs_update
391                            && let Ok(Some(entry)) = self.vector_index.get(&real_id).await
392                        {
393                            let mut payload = entry.1;
394                            let mut changed = false;
395                            if let Some(aid) = agent_id {
396                                let cur = payload.get("agent_id").and_then(|v| v.as_str());
397                                if cur != Some(aid) {
398                                    payload["agent_id"] =
399                                        serde_json::Value::String(aid.to_string());
400                                    changed = true;
401                                }
402                            }
403                            if let Some(rid) = run_id {
404                                let cur = payload.get("run_id").and_then(|v| v.as_str());
405                                if cur != Some(rid) {
406                                    payload["run_id"] = serde_json::Value::String(rid.to_string());
407                                    changed = true;
408                                }
409                            }
410                            if changed {
411                                payload["updated_at"] = serde_json::Value::String(now.clone());
412                                if let Err(e) = self
413                                    .vector_index
414                                    .update(&real_id, None, Some(payload))
415                                    .await
416                                {
417                                    warn!(id = %real_id, "failed to update session IDs: {e}");
418                                } else {
419                                    debug!(
420                                        id = %real_id,
421                                        "updated session IDs on NONE action"
422                                    );
423                                }
424                            }
425                        }
426                    }
427                }
428            }
429        }
430
431        Ok((results, ()))
432    }
433
434    /// When `enable_vision` is set, send each message's images to the LLM and
435    /// append the resulting description to the message's text content.
436    pub(crate) async fn describe_images(
437        &self,
438        messages: &[ChatMessage],
439    ) -> Result<Vec<ChatMessage>> {
440        let mut out = Vec::with_capacity(messages.len());
441        for msg in messages {
442            if msg.images.is_empty() {
443                out.push(msg.clone());
444                continue;
445            }
446
447            let llm_msg = mem7_llm::LlmMessage::user_with_images(
448                VISION_DESCRIBE_PROMPT.to_string(),
449                msg.images.clone(),
450            );
451            match self.llm.chat_completion(&[llm_msg], None).await {
452                Ok(resp) => {
453                    let mut enriched = msg.clone();
454                    if enriched.content.is_empty() {
455                        enriched.content = resp.content;
456                    } else {
457                        enriched.content = format!(
458                            "{}\n[Image description: {}]",
459                            enriched.content, resp.content
460                        );
461                    }
462                    enriched.images.clear();
463                    out.push(enriched);
464                }
465                Err(e) => {
466                    warn!(error = %e, "vision description failed, using original text");
467                    out.push(msg.clone());
468                }
469            }
470        }
471        Ok(out)
472    }
473}