//! mem7_store/add.rs — `MemoryEngine` entry points for adding memories,
//! covering both the raw-import path and the LLM-inference pipeline.

1use std::collections::HashMap;
2
3use mem7_core::{
4    AddOptions, AddResult, ChatMessage, MemoryAction, MemoryActionResult, MemoryFilter,
5    new_memory_id,
6};
7use mem7_datetime::now_iso;
8use mem7_error::Result;
9use mem7_vector::VectorSearchResult;
10use tracing::{debug, info, instrument, warn};
11use uuid::Uuid;
12
13use crate::constants::*;
14use crate::decay;
15use crate::engine::MemoryEngine;
16use crate::payload::{build_memory_payload, build_raw_memory_payload, build_update_payload};
17use crate::pipeline;
18use crate::prompts::VISION_DESCRIBE_PROMPT;
19
20impl MemoryEngine {
21    /// Add memories from a conversation.
22    ///
23    /// When `infer` is `true` (the default), the LLM extracts facts from the
24    /// conversation, deduplicates them against existing memories, and decides
25    /// whether to add, update, or delete.
26    ///
27    /// When `infer` is `false`, each message's content is stored directly as a
28    /// new memory without any LLM processing — useful for importing raw text.
29    ///
30    /// `metadata` is an optional JSON object stored under `payload.metadata`.
31    #[instrument(skip(self, messages, metadata), fields(msg_count = messages.len()))]
32    pub async fn add(
33        &self,
34        messages: &[ChatMessage],
35        user_id: Option<&str>,
36        agent_id: Option<&str>,
37        run_id: Option<&str>,
38        metadata: Option<&serde_json::Value>,
39        infer: bool,
40    ) -> Result<AddResult> {
41        let opts = AddOptions {
42            user_id,
43            agent_id,
44            run_id,
45            metadata,
46            infer,
47        };
48        self.add_with_options(messages, &opts).await
49    }
50
51    /// Add memories using structured options.
52    pub async fn add_with_options(
53        &self,
54        messages: &[ChatMessage],
55        opts: &AddOptions<'_>,
56    ) -> Result<AddResult> {
57        if opts.infer {
58            self.add_with_inference(
59                messages,
60                opts.user_id,
61                opts.agent_id,
62                opts.run_id,
63                opts.metadata,
64            )
65            .await
66        } else {
67            self.add_raw(
68                messages,
69                opts.user_id,
70                opts.agent_id,
71                opts.run_id,
72                opts.metadata,
73            )
74            .await
75        }
76    }
77
78    /// Store raw message texts directly without LLM inference or deduplication.
79    /// System messages are skipped. Each message's `role` is stored in the payload.
80    async fn add_raw(
81        &self,
82        messages: &[ChatMessage],
83        user_id: Option<&str>,
84        agent_id: Option<&str>,
85        run_id: Option<&str>,
86        metadata: Option<&serde_json::Value>,
87    ) -> Result<AddResult> {
88        let non_system: Vec<&ChatMessage> =
89            messages.iter().filter(|m| m.role != "system").collect();
90
91        if non_system.is_empty() {
92            return Ok(AddResult {
93                results: Vec::new(),
94                relations: Vec::new(),
95            });
96        }
97
98        let owned: Vec<String> = non_system.iter().map(|m| m.content.clone()).collect();
99        let embeddings = self.embedder.embed(&owned).await?;
100
101        let now = now_iso();
102        let mut results = Vec::new();
103
104        for (msg, vec) in non_system.iter().zip(embeddings) {
105            let memory_id = new_memory_id();
106
107            let payload = build_raw_memory_payload(
108                &msg.content,
109                &msg.role,
110                user_id,
111                agent_id,
112                run_id,
113                metadata,
114                &now,
115            );
116
117            self.vector_index.insert(memory_id, &vec, payload).await?;
118
119            self.history
120                .add_event(memory_id, None, Some(&msg.content), MemoryAction::Add)
121                .await?;
122
123            results.push(MemoryActionResult {
124                id: memory_id,
125                action: MemoryAction::Add,
126                old_value: None,
127                new_value: Some(msg.content.clone()),
128            });
129        }
130
131        let relations = if let Some(gp) = &self.graph_pipeline {
132            let conversation = non_system
133                .iter()
134                .map(|m| m.content.as_str())
135                .collect::<Vec<_>>()
136                .join("\n");
137            let filter = MemoryFilter::from_session(user_id, agent_id, run_id);
138            gp.add(&conversation, &filter).await.unwrap_or_else(|e| {
139                warn!(error = %e, "graph extraction failed during raw add");
140                Vec::new()
141            })
142        } else {
143            Vec::new()
144        };
145
146        info!(count = results.len(), infer = false, "raw memories stored");
147        Ok(AddResult { results, relations })
148    }
149
150    /// Full LLM-powered pipeline: extract facts, deduplicate, store.
151    async fn add_with_inference(
152        &self,
153        messages: &[ChatMessage],
154        user_id: Option<&str>,
155        agent_id: Option<&str>,
156        run_id: Option<&str>,
157        metadata: Option<&serde_json::Value>,
158    ) -> Result<AddResult> {
159        let messages = if self.config.llm.enable_vision {
160            self.describe_images(messages).await?
161        } else {
162            messages.to_vec()
163        };
164
165        let conversation = messages
166            .iter()
167            .map(|m| format!("{}: {}", m.role, m.content))
168            .collect::<Vec<_>>()
169            .join("\n");
170
171        let filter = MemoryFilter::from_session(user_id, agent_id, run_id);
172
173        let graph_future = async {
174            match &self.graph_pipeline {
175                Some(gp) => gp.add(&conversation, &filter).await,
176                None => Ok(Vec::new()),
177            }
178        };
179
180        let vector_future =
181            self.add_vector_with_inference(&messages, user_id, agent_id, run_id, metadata, &filter);
182
183        let (vector_result, graph_result) = tokio::join!(vector_future, graph_future);
184
185        let (results, _) = vector_result?;
186        let relations = graph_result.unwrap_or_else(|e| {
187            warn!(error = %e, "graph extraction failed");
188            Vec::new()
189        });
190
191        info!(
192            count = results.len(),
193            relations = relations.len(),
194            "memory operations completed"
195        );
196        Ok(AddResult { results, relations })
197    }
198
    /// Vector-only inference pipeline, factored out for concurrent execution.
    ///
    /// Stages:
    /// 1. LLM fact extraction from `messages`.
    /// 2. Embed each fact and retrieve dedup candidates from the vector index,
    ///    with scores optionally adjusted by the decay model.
    /// 3. LLM decides an action (ADD / UPDATE / DELETE / NONE) per fact.
    /// 4. Apply each decision to the vector index and the history log.
    ///
    /// The `()` in the return tuple is a placeholder so the shape matches what
    /// the `tokio::join!` caller destructures.
    async fn add_vector_with_inference(
        &self,
        messages: &[ChatMessage],
        user_id: Option<&str>,
        agent_id: Option<&str>,
        run_id: Option<&str>,
        metadata: Option<&serde_json::Value>,
        filter: &MemoryFilter,
    ) -> Result<(Vec<MemoryActionResult>, ())> {
        let facts = pipeline::extract_facts(
            self.llm.as_ref(),
            messages,
            agent_id,
            self.config.custom_fact_extraction_prompt.as_deref(),
        )
        .await?;

        // Nothing extracted — no memories to reconcile.
        if facts.is_empty() {
            return Ok((Vec::new(), ()));
        }

        debug!(count = facts.len(), "extracted facts");

        let fact_texts: Vec<String> = facts.iter().map(|f| f.text.clone()).collect();
        let embeddings = self.embedder.embed(&fact_texts).await?;

        // Dedup candidates pooled across all facts: (memory id, text, score).
        let mut all_retrieved: Vec<(Uuid, String, f32)> = Vec::new();

        // Decay only applies when configured AND the config flag is enabled.
        let decay_cfg = self.config.decay.as_ref().filter(|d| d.enabled);

        for embedding in &embeddings {
            let results = self
                .vector_index
                .search(embedding, DEDUP_CANDIDATE_LIMIT, Some(filter))
                .await?;
            for VectorSearchResult { id, score, payload } in results {
                // Candidates without a "text" payload field are unusable for
                // dedup and are skipped.
                if let Some(text) = payload.get("text").and_then(|v| v.as_str()) {
                    let effective_score = match decay_cfg {
                        Some(cfg) => {
                            let age = decay::age_from_payload(&payload);
                            let ac = decay::access_count_from_payload(&payload);
                            decay::apply_decay(score, age, ac, cfg)
                        }
                        None => score,
                    };
                    all_retrieved.push((id, text.to_string(), effective_score));
                }
            }
        }

        // `id_mapping` translates the ids appearing in the LLM's decisions
        // back to stored memory UUIDs.
        let (update_resp, id_mapping) = pipeline::decide_memory_updates(
            self.llm.as_ref(),
            &facts,
            all_retrieved,
            self.config.custom_update_memory_prompt.as_deref(),
        )
        .await?;

        // Lookup of a fact's memory type by its exact extracted text.
        let fact_type_map: HashMap<&str, &str> = facts
            .iter()
            .map(|f| (f.text.as_str(), f.memory_type.as_str()))
            .collect();

        let now = now_iso();
        let mut results = Vec::new();

        for decision in &update_resp.memory {
            match decision.event {
                MemoryAction::Add => {
                    let memory_id = new_memory_id();
                    let text = &decision.text;

                    // The decided text may differ from any extracted fact, so
                    // it is re-embedded rather than reusing a fact embedding.
                    let vecs = self.embedder.embed(std::slice::from_ref(text)).await?;
                    let vec = vecs.into_iter().next().unwrap_or_default();

                    // None when the decided text matches no extracted fact.
                    let mt = fact_type_map.get(text.as_str()).copied();
                    let payload =
                        build_memory_payload(text, user_id, agent_id, run_id, metadata, &now, mt);

                    self.vector_index.insert(memory_id, &vec, payload).await?;

                    self.history
                        .add_event(memory_id, None, Some(text), MemoryAction::Add)
                        .await?;

                    results.push(MemoryActionResult {
                        id: memory_id,
                        action: MemoryAction::Add,
                        old_value: None,
                        new_value: Some(text.clone()),
                    });
                }
                MemoryAction::Update => {
                    // Decisions whose id does not resolve to a stored memory
                    // are silently skipped.
                    if let Some(real_id) = id_mapping.resolve(&decision.id) {
                        let text = &decision.text;
                        let old_text = decision.old_memory.as_deref();

                        let vecs = self.embedder.embed(std::slice::from_ref(text)).await?;
                        let vec = vecs.into_iter().next().unwrap_or_default();

                        // Carry over the existing entry's access count and
                        // memory type; a failed lookup is treated as "no
                        // previous state" (count 0, no type).
                        let existing_entry = self.vector_index.get(&real_id).await.ok().flatten();
                        let prev_ac = existing_entry
                            .as_ref()
                            .map(|(_, p)| decay::access_count_from_payload(p))
                            .unwrap_or(0);
                        let existing_mt = existing_entry
                            .as_ref()
                            .and_then(|(_, p)| p.get("memory_type").and_then(|v| v.as_str()));
                        // Fall back to the type inferred for the new text.
                        let mt = existing_mt.or_else(|| fact_type_map.get(text.as_str()).copied());

                        let payload = build_update_payload(
                            text,
                            user_id,
                            agent_id,
                            run_id,
                            metadata,
                            &now,
                            prev_ac + 1,
                            mt,
                        );

                        self.vector_index
                            .update(&real_id, Some(&vec), Some(payload))
                            .await?;

                        self.history
                            .add_event(real_id, old_text, Some(text), MemoryAction::Update)
                            .await?;

                        results.push(MemoryActionResult {
                            id: real_id,
                            action: MemoryAction::Update,
                            old_value: old_text.map(String::from),
                            new_value: Some(text.clone()),
                        });
                    }
                }
                MemoryAction::Delete => {
                    if let Some(real_id) = id_mapping.resolve(&decision.id) {
                        // Prefer the recorded old memory; fall back to the
                        // decision text so history always gets a value.
                        let old_text = decision.old_memory.as_deref().or(Some(&decision.text));

                        self.vector_index.delete(&real_id).await?;

                        self.history
                            .add_event(real_id, old_text, None, MemoryAction::Delete)
                            .await?;

                        results.push(MemoryActionResult {
                            id: real_id,
                            action: MemoryAction::Delete,
                            old_value: old_text.map(String::from),
                            new_value: None,
                        });
                    }
                }
                MemoryAction::None => {
                    // No content change, but the memory may be referenced from
                    // a new session: stamp the current agent/run ids onto the
                    // stored payload when they differ. Best-effort — a failed
                    // update is logged, not propagated.
                    if let Some(real_id) = id_mapping.resolve(&decision.id) {
                        let needs_update = agent_id.is_some() || run_id.is_some();
                        if needs_update
                            && let Ok(Some(entry)) = self.vector_index.get(&real_id).await
                        {
                            let mut payload = entry.1;
                            let mut changed = false;
                            if let Some(aid) = agent_id {
                                let cur = payload.get("agent_id").and_then(|v| v.as_str());
                                if cur != Some(aid) {
                                    payload["agent_id"] =
                                        serde_json::Value::String(aid.to_string());
                                    changed = true;
                                }
                            }
                            if let Some(rid) = run_id {
                                let cur = payload.get("run_id").and_then(|v| v.as_str());
                                if cur != Some(rid) {
                                    payload["run_id"] = serde_json::Value::String(rid.to_string());
                                    changed = true;
                                }
                            }
                            if changed {
                                // Only touch `updated_at` when something changed.
                                payload["updated_at"] = serde_json::Value::String(now.clone());
                                if let Err(e) = self
                                    .vector_index
                                    .update(&real_id, None, Some(payload))
                                    .await
                                {
                                    warn!(id = %real_id, "failed to update session IDs: {e}");
                                } else {
                                    debug!(
                                        id = %real_id,
                                        "updated session IDs on NONE action"
                                    );
                                }
                            }
                        }
                    }
                }
            }
        }

        Ok((results, ()))
    }
401
402    /// When `enable_vision` is set, send each message's images to the LLM and
403    /// append the resulting description to the message's text content.
404    pub(crate) async fn describe_images(
405        &self,
406        messages: &[ChatMessage],
407    ) -> Result<Vec<ChatMessage>> {
408        let mut out = Vec::with_capacity(messages.len());
409        for msg in messages {
410            if msg.images.is_empty() {
411                out.push(msg.clone());
412                continue;
413            }
414
415            let llm_msg = mem7_llm::LlmMessage::user_with_images(
416                VISION_DESCRIBE_PROMPT.to_string(),
417                msg.images.clone(),
418            );
419            match self.llm.chat_completion(&[llm_msg], None).await {
420                Ok(resp) => {
421                    let mut enriched = msg.clone();
422                    if enriched.content.is_empty() {
423                        enriched.content = resp.content;
424                    } else {
425                        enriched.content = format!(
426                            "{}\n[Image description: {}]",
427                            enriched.content, resp.content
428                        );
429                    }
430                    enriched.images.clear();
431                    out.push(enriched);
432                }
433                Err(e) => {
434                    warn!(error = %e, "vision description failed, using original text");
435                    out.push(msg.clone());
436                }
437            }
438        }
439        Ok(out)
440    }
441}