1use std::collections::HashMap;
2
3use mem7_core::{
4 AddOptions, AddResult, ChatMessage, MemoryAction, MemoryActionResult, MemoryFilter,
5 new_memory_id,
6};
7use mem7_datetime::now_iso;
8use mem7_error::Result;
9use mem7_vector::VectorSearchResult;
10use tracing::{debug, info, instrument, warn};
11use uuid::Uuid;
12
13use crate::constants::*;
14use crate::decay;
15use crate::engine::MemoryEngine;
16use crate::payload::{
17 build_memory_payload, build_raw_memory_payload, build_update_payload, payload_to_event_metadata,
18};
19use crate::pipeline;
20use crate::prompts::VISION_DESCRIBE_PROMPT;
21use crate::require_scope;
22
23impl MemoryEngine {
24 #[instrument(skip(self, messages, metadata), fields(msg_count = messages.len()))]
35 pub async fn add(
36 &self,
37 messages: &[ChatMessage],
38 user_id: Option<&str>,
39 agent_id: Option<&str>,
40 run_id: Option<&str>,
41 metadata: Option<&serde_json::Value>,
42 infer: bool,
43 ) -> Result<AddResult> {
44 let opts = AddOptions {
45 user_id,
46 agent_id,
47 run_id,
48 metadata,
49 infer,
50 };
51 self.add_with_options(messages, &opts).await
52 }
53
54 pub async fn add_with_options(
56 &self,
57 messages: &[ChatMessage],
58 opts: &AddOptions<'_>,
59 ) -> Result<AddResult> {
60 require_scope("add", opts.user_id, opts.agent_id, opts.run_id)?;
61 if opts.infer {
62 self.add_with_inference(
63 messages,
64 opts.user_id,
65 opts.agent_id,
66 opts.run_id,
67 opts.metadata,
68 )
69 .await
70 } else {
71 self.add_raw(
72 messages,
73 opts.user_id,
74 opts.agent_id,
75 opts.run_id,
76 opts.metadata,
77 )
78 .await
79 }
80 }
81
82 async fn add_raw(
85 &self,
86 messages: &[ChatMessage],
87 user_id: Option<&str>,
88 agent_id: Option<&str>,
89 run_id: Option<&str>,
90 metadata: Option<&serde_json::Value>,
91 ) -> Result<AddResult> {
92 let non_system: Vec<&ChatMessage> =
93 messages.iter().filter(|m| m.role != "system").collect();
94
95 if non_system.is_empty() {
96 return Ok(AddResult {
97 results: Vec::new(),
98 relations: Vec::new(),
99 });
100 }
101
102 let owned: Vec<String> = non_system.iter().map(|m| m.content.clone()).collect();
103 let embeddings = self.embedder.embed(&owned).await?;
104
105 let now = now_iso();
106 let mut results = Vec::new();
107
108 for (msg, vec) in non_system.iter().zip(embeddings) {
109 let memory_id = new_memory_id();
110
111 let payload = build_raw_memory_payload(
112 &msg.content,
113 &msg.role,
114 user_id,
115 agent_id,
116 run_id,
117 metadata,
118 &now,
119 );
120 let audit = payload_to_event_metadata(&payload);
121
122 self.vector_index.insert(memory_id, &vec, payload).await?;
123
124 self.history
125 .add_event(
126 memory_id,
127 None,
128 Some(&msg.content),
129 MemoryAction::Add,
130 audit,
131 )
132 .await?;
133
134 results.push(MemoryActionResult {
135 id: memory_id,
136 action: MemoryAction::Add,
137 old_value: None,
138 new_value: Some(msg.content.clone()),
139 });
140 }
141
142 let relations = if let Some(gp) = &self.graph_pipeline {
143 let conversation = non_system
144 .iter()
145 .map(|m| m.content.as_str())
146 .collect::<Vec<_>>()
147 .join("\n");
148 let filter = MemoryFilter::from_session(user_id, agent_id, run_id);
149 gp.add(&conversation, &filter).await.unwrap_or_else(|e| {
150 warn!(error = %e, "graph extraction failed during raw add");
151 Vec::new()
152 })
153 } else {
154 Vec::new()
155 };
156
157 info!(count = results.len(), infer = false, "raw memories stored");
158 Ok(AddResult { results, relations })
159 }
160
161 async fn add_with_inference(
163 &self,
164 messages: &[ChatMessage],
165 user_id: Option<&str>,
166 agent_id: Option<&str>,
167 run_id: Option<&str>,
168 metadata: Option<&serde_json::Value>,
169 ) -> Result<AddResult> {
170 let messages = if self.config.llm.enable_vision {
171 self.describe_images(messages).await?
172 } else {
173 messages.to_vec()
174 };
175
176 let conversation = messages
177 .iter()
178 .map(|m| format!("{}: {}", m.role, m.content))
179 .collect::<Vec<_>>()
180 .join("\n");
181
182 let filter = MemoryFilter::from_session(user_id, agent_id, run_id);
183
184 let graph_future = async {
185 match &self.graph_pipeline {
186 Some(gp) => gp.add(&conversation, &filter).await,
187 None => Ok(Vec::new()),
188 }
189 };
190
191 let vector_future =
192 self.add_vector_with_inference(&messages, user_id, agent_id, run_id, metadata, &filter);
193
194 let (vector_result, graph_result) = tokio::join!(vector_future, graph_future);
195
196 let (results, _) = vector_result?;
197 let relations = graph_result.unwrap_or_else(|e| {
198 warn!(error = %e, "graph extraction failed");
199 Vec::new()
200 });
201
202 info!(
203 count = results.len(),
204 relations = relations.len(),
205 "memory operations completed"
206 );
207 Ok(AddResult { results, relations })
208 }
209
210 async fn add_vector_with_inference(
212 &self,
213 messages: &[ChatMessage],
214 user_id: Option<&str>,
215 agent_id: Option<&str>,
216 run_id: Option<&str>,
217 metadata: Option<&serde_json::Value>,
218 filter: &MemoryFilter,
219 ) -> Result<(Vec<MemoryActionResult>, ())> {
220 let facts = pipeline::extract_facts(
221 self.llm.as_ref(),
222 messages,
223 agent_id,
224 self.config.custom_fact_extraction_prompt.as_deref(),
225 )
226 .await?;
227
228 if facts.is_empty() {
229 return Ok((Vec::new(), ()));
230 }
231
232 debug!(count = facts.len(), "extracted facts");
233
234 let fact_texts: Vec<String> = facts.iter().map(|f| f.text.clone()).collect();
235 let embeddings = self.embedder.embed(&fact_texts).await?;
236
237 let mut all_retrieved: Vec<(Uuid, String, f32)> = Vec::new();
238
239 let decay_cfg = self.config.decay.as_ref().filter(|d| d.enabled);
240
241 for embedding in &embeddings {
242 let results = self
243 .vector_index
244 .search(embedding, DEDUP_CANDIDATE_LIMIT, Some(filter))
245 .await?;
246 for VectorSearchResult { id, score, payload } in results {
247 if let Some(text) = payload.get("text").and_then(|v| v.as_str()) {
248 let effective_score = match decay_cfg {
249 Some(cfg) => {
250 let age = decay::age_from_payload(&payload);
251 let ac = decay::access_count_from_payload(&payload);
252 decay::apply_decay(score, age, ac, cfg)
253 }
254 None => score,
255 };
256 all_retrieved.push((id, text.to_string(), effective_score));
257 }
258 }
259 }
260
261 let (update_resp, id_mapping) = pipeline::decide_memory_updates(
262 self.llm.as_ref(),
263 &facts,
264 all_retrieved,
265 self.config.custom_update_memory_prompt.as_deref(),
266 )
267 .await?;
268
269 let fact_type_map: HashMap<&str, &str> = facts
270 .iter()
271 .map(|f| (f.text.as_str(), f.memory_type.as_str()))
272 .collect();
273
274 let now = now_iso();
275 let mut results = Vec::new();
276
277 for decision in &update_resp.memory {
278 match decision.event {
279 MemoryAction::Add => {
280 let memory_id = new_memory_id();
281 let text = &decision.text;
282
283 let vecs = self.embedder.embed(std::slice::from_ref(text)).await?;
284 let vec = vecs.into_iter().next().unwrap_or_default();
285
286 let mt = fact_type_map.get(text.as_str()).copied();
287 let payload =
288 build_memory_payload(text, user_id, agent_id, run_id, metadata, &now, mt);
289 let audit = payload_to_event_metadata(&payload);
290
291 self.vector_index.insert(memory_id, &vec, payload).await?;
292
293 self.history
294 .add_event(memory_id, None, Some(text), MemoryAction::Add, audit)
295 .await?;
296
297 results.push(MemoryActionResult {
298 id: memory_id,
299 action: MemoryAction::Add,
300 old_value: None,
301 new_value: Some(text.clone()),
302 });
303 }
304 MemoryAction::Update => {
305 if let Some(real_id) = id_mapping.resolve(&decision.id) {
306 let text = &decision.text;
307 let old_text = decision.old_memory.as_deref();
308
309 let vecs = self.embedder.embed(std::slice::from_ref(text)).await?;
310 let vec = vecs.into_iter().next().unwrap_or_default();
311
312 let existing_entry = self.vector_index.get(&real_id).await.ok().flatten();
313 let prev_ac = existing_entry
314 .as_ref()
315 .map(|(_, p)| decay::access_count_from_payload(p))
316 .unwrap_or(0);
317 let existing_mt = existing_entry
318 .as_ref()
319 .and_then(|(_, p)| p.get("memory_type").and_then(|v| v.as_str()));
320 let existing_created_at = existing_entry
321 .as_ref()
322 .and_then(|(_, p)| p.get("created_at").and_then(|v| v.as_str()));
323 let mt = existing_mt.or_else(|| fact_type_map.get(text.as_str()).copied());
324
325 let payload = build_update_payload(
326 text,
327 user_id,
328 agent_id,
329 run_id,
330 metadata,
331 existing_created_at,
332 &now,
333 prev_ac + 1,
334 mt,
335 );
336 let audit = payload_to_event_metadata(&payload);
337
338 self.vector_index
339 .update(&real_id, Some(&vec), Some(payload))
340 .await?;
341
342 self.history
343 .add_event(real_id, old_text, Some(text), MemoryAction::Update, audit)
344 .await?;
345
346 results.push(MemoryActionResult {
347 id: real_id,
348 action: MemoryAction::Update,
349 old_value: old_text.map(String::from),
350 new_value: Some(text.clone()),
351 });
352 }
353 }
354 MemoryAction::Delete => {
355 if let Some(real_id) = id_mapping.resolve(&decision.id) {
356 let old_text = decision.old_memory.as_deref().or(Some(&decision.text));
357 let audit = self
358 .vector_index
359 .get(&real_id)
360 .await
361 .ok()
362 .flatten()
363 .map(|(_, payload)| {
364 let mut metadata = payload_to_event_metadata(&payload);
365 metadata.is_deleted = true;
366 metadata
367 })
368 .unwrap_or_else(|| mem7_core::MemoryEventMetadata {
369 is_deleted: true,
370 ..Default::default()
371 });
372
373 self.vector_index.delete(&real_id).await?;
374
375 self.history
376 .add_event(real_id, old_text, None, MemoryAction::Delete, audit)
377 .await?;
378
379 results.push(MemoryActionResult {
380 id: real_id,
381 action: MemoryAction::Delete,
382 old_value: old_text.map(String::from),
383 new_value: None,
384 });
385 }
386 }
387 MemoryAction::None => {
388 if let Some(real_id) = id_mapping.resolve(&decision.id) {
389 let needs_update = agent_id.is_some() || run_id.is_some();
390 if needs_update
391 && let Ok(Some(entry)) = self.vector_index.get(&real_id).await
392 {
393 let mut payload = entry.1;
394 let mut changed = false;
395 if let Some(aid) = agent_id {
396 let cur = payload.get("agent_id").and_then(|v| v.as_str());
397 if cur != Some(aid) {
398 payload["agent_id"] =
399 serde_json::Value::String(aid.to_string());
400 changed = true;
401 }
402 }
403 if let Some(rid) = run_id {
404 let cur = payload.get("run_id").and_then(|v| v.as_str());
405 if cur != Some(rid) {
406 payload["run_id"] = serde_json::Value::String(rid.to_string());
407 changed = true;
408 }
409 }
410 if changed {
411 payload["updated_at"] = serde_json::Value::String(now.clone());
412 if let Err(e) = self
413 .vector_index
414 .update(&real_id, None, Some(payload))
415 .await
416 {
417 warn!(id = %real_id, "failed to update session IDs: {e}");
418 } else {
419 debug!(
420 id = %real_id,
421 "updated session IDs on NONE action"
422 );
423 }
424 }
425 }
426 }
427 }
428 }
429 }
430
431 Ok((results, ()))
432 }
433
434 pub(crate) async fn describe_images(
437 &self,
438 messages: &[ChatMessage],
439 ) -> Result<Vec<ChatMessage>> {
440 let mut out = Vec::with_capacity(messages.len());
441 for msg in messages {
442 if msg.images.is_empty() {
443 out.push(msg.clone());
444 continue;
445 }
446
447 let llm_msg = mem7_llm::LlmMessage::user_with_images(
448 VISION_DESCRIBE_PROMPT.to_string(),
449 msg.images.clone(),
450 );
451 match self.llm.chat_completion(&[llm_msg], None).await {
452 Ok(resp) => {
453 let mut enriched = msg.clone();
454 if enriched.content.is_empty() {
455 enriched.content = resp.content;
456 } else {
457 enriched.content = format!(
458 "{}\n[Image description: {}]",
459 enriched.content, resp.content
460 );
461 }
462 enriched.images.clear();
463 out.push(enriched);
464 }
465 Err(e) => {
466 warn!(error = %e, "vision description failed, using original text");
467 out.push(msg.clone());
468 }
469 }
470 }
471 Ok(out)
472 }
473}