mem0_rust/memory/
manager.rs

1//! Core Memory manager.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use tracing::{debug, info, warn};
6use uuid::Uuid;
7use chrono::Utc;
8
9use crate::config::MemoryConfig;
10use crate::embeddings::{create_embedder, Embedder};
11use crate::errors::{LLMError, MemoryError};
12use crate::history::HistoryManager;
13use crate::llms::{create_llm, generate_json, GenerateOptions, LLM};
14use crate::models::{
15    AddOptions, AddResult, EventType, Filters, GetAllOptions, HistoryEntry, MemoryEvent,
16    MemoryRecord, Message, Messages, Payload, ResetOptions, Role, ScoredMemory, SearchOptions,
17    SearchResult,
18};
19use crate::vector_stores::{create_vector_store, VectorStore};
20use crate::rerankers::{create_reranker, Reranker};
21
22use super::prompts::{
23    format_fact_extraction_input, format_memory_update_input, FACT_EXTRACTION_PROMPT,
24    MEMORY_UPDATE_PROMPT,
25};
26
27/// Main Memory interface
28pub struct Memory {
29    embedder: Arc<dyn Embedder>,
30    vector_store: Arc<dyn VectorStore>,
31    llm: Option<Arc<dyn LLM>>,
32    history: Option<Arc<HistoryManager>>,
33    reranker: Option<Arc<dyn Reranker>>,
34    #[allow(dead_code)]
35    config: MemoryConfig,
36}
37
38impl Memory {
39    /// Create a new Memory instance
40    pub async fn new(config: MemoryConfig) -> Result<Self, MemoryError> {
41        let embedder = create_embedder(&config.embedder)?;
42        let dimensions = embedder.dimensions();
43
44        let vector_store =
45            create_vector_store(&config.vector_store, &config.collection_name, dimensions).await?;
46
47        let llm = if let Some(llm_config) = &config.llm {
48            Some(create_llm(llm_config)?)
49        } else {
50            None
51        };
52
53        let history = if let Some(path) = &config.history_db_path {
54            Some(Arc::new(HistoryManager::new(path)?))
55        } else {
56            None
57        };
58
59        let reranker = if let Some(reranker_config) = &config.reranker {
60            Some(create_reranker(reranker_config)?)
61        } else {
62            None
63        };
64
65        info!(
66            "Initialized Memory with {} embedder, {} dimensions",
67            embedder.model_name(),
68            dimensions
69        );
70
71        Ok(Self {
72            embedder,
73            vector_store,
74            llm,
75            history,
76            reranker,
77            config,
78        })
79    }
80
81    /// Add memories from messages
82    pub async fn add(
83        &self,
84        messages: impl Into<Messages>,
85        options: AddOptions,
86    ) -> Result<AddResult, MemoryError> {
87        let messages = messages.into().into_messages();
88        // Validate scoping
89        if options.user_id.is_none() && options.agent_id.is_none() && options.run_id.is_none() {
90            return Err(MemoryError::InvalidInput(
91                "At least one of user_id, agent_id, or run_id is required".to_string(),
92            ));
93        }
94
95        let results = if options.infer && self.llm.is_some() {
96            // Use LLM for fact extraction
97            self.add_with_inference(&messages, &options).await?
98        } else {
99            // Add messages directly without inference
100            self.add_raw(&messages, &options).await?
101        };
102
103        Ok(AddResult { results })
104    }
105
106    /// Add messages directly without LLM inference
107    async fn add_raw(
108        &self,
109        messages: &[Message],
110        options: &AddOptions,
111    ) -> Result<Vec<MemoryEvent>, MemoryError> {
112        let mut results = Vec::new();
113
114        for msg in messages {
115            if msg.role == Role::System {
116                continue;
117            }
118
119            let record = MemoryRecord::with_scoping(
120                msg.content.clone(),
121                options
122                    .metadata
123                    .as_ref()
124                    .map(|m| serde_json::to_value(m).unwrap_or_default())
125                    .unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
126                options.user_id.clone(),
127                options.agent_id.clone(),
128                options.run_id.clone(),
129            );
130
131            let embedding = self.embedder.embed(&record.content).await?;
132            let payload = Payload::from(&record);
133
134            self.vector_store
135                .insert(&record.id.to_string(), embedding, payload)
136                .await?;
137
138            if let Some(history) = &self.history {
139                let _ = history.add_history(
140                    record.id,
141                    None,
142                    record.content.clone(),
143                    EventType::Add,
144                    record.created_at,
145                    record.user_id.clone(),
146                    record.agent_id.clone(),
147                    record.run_id.clone(),
148                );
149            }
150
151            results.push(MemoryEvent {
152                id: record.id,
153                memory: record.content,
154                event: EventType::Add,
155            });
156        }
157
158        Ok(results)
159    }
160
161    /// Add messages with LLM inference
162    async fn add_with_inference(
163        &self,
164        messages: &[Message],
165        options: &AddOptions,
166    ) -> Result<Vec<MemoryEvent>, MemoryError> {
167        let llm = self.llm.as_ref().ok_or(LLMError::NotConfigured)?;
168
169        // Format messages for extraction
170        let messages_text = messages
171            .iter()
172            .map(|m| format!("{:?}: {}", m.role, m.content))
173            .collect::<Vec<_>>()
174            .join("\n");
175
176        // Extract facts
177        let extraction_messages = vec![
178            Message::system(FACT_EXTRACTION_PROMPT),
179            Message::user(format_fact_extraction_input(&messages_text)),
180        ];
181
182        #[derive(serde::Deserialize)]
183        struct FactsResponse {
184            facts: Vec<String>,
185        }
186
187        let facts: FactsResponse = generate_json(
188            llm.as_ref(),
189            &extraction_messages,
190            GenerateOptions::default(),
191        )
192        .await?;
193
194        if facts.facts.is_empty() {
195            debug!("No facts extracted from messages");
196            return Ok(Vec::new());
197        }
198
199        info!("Extracted {} facts", facts.facts.len());
200
201        // Search for existing related memories
202        let mut existing_memories: Vec<(String, String)> = Vec::new(); // (Index, Content)
203        let mut memory_map: HashMap<String, String> = HashMap::new(); // Index -> RealID
204
205        let search_filters = Filters {
206            conditions: vec![],
207            logic: crate::models::FilterLogic::And,
208        };
209
210        for fact in &facts.facts {
211            let embedding = self.embedder.embed(fact).await?;
212
213            let similar = self
214                .vector_store
215                .search(&embedding, 5, Some(&search_filters))
216                .await?;
217
218            for result in similar {
219                // Check if we already have this memory in our list (dedupe by real ID)
220                let real_id = result.id.clone();
221                if !memory_map.values().any(|rid| rid == &real_id) {
222                     let index = memory_map.len().to_string();
223                     memory_map.insert(index.clone(), real_id);
224                     existing_memories.push((index, result.payload.data));
225                }
226            }
227        }
228
229        // Determine memory actions
230        let update_messages = vec![
231            Message::system(MEMORY_UPDATE_PROMPT),
232            Message::user(format_memory_update_input(&existing_memories, &facts.facts)),
233        ];
234
235        #[derive(serde::Deserialize)]
236        struct MemoryAction {
237            event: String,
238            text: Option<String>,
239            id: Option<String>,
240        }
241
242        #[derive(serde::Deserialize)]
243        struct MemoryActionsResponse {
244            memory: Vec<MemoryAction>,
245        }
246
247        let actions: MemoryActionsResponse = generate_json(
248            llm.as_ref(),
249            &update_messages,
250            GenerateOptions::default(),
251        )
252        .await?;
253
254        let mut results = Vec::new();
255
256        for action in actions.memory {
257            match action.event.to_uppercase().as_str() {
258                "ADD" => {
259                    if let Some(text) = action.text {
260                        let record = MemoryRecord::with_scoping(
261                            &text,
262                            options
263                                .metadata
264                                .as_ref()
265                                .map(|m| serde_json::to_value(m).unwrap_or_default())
266                                .unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
267                            options.user_id.clone(),
268                            options.agent_id.clone(),
269                            options.run_id.clone(),
270                        );
271
272                        let embedding = self.embedder.embed(&text).await?;
273                        let payload = Payload::from(&record);
274
275                        self.vector_store
276                            .insert(&record.id.to_string(), embedding, payload)
277                            .await?;
278
279                        if let Some(history) = &self.history {
280                            let _ = history.add_history(
281                                record.id,
282                                None,
283                                record.content.clone(),
284                                EventType::Add,
285                                record.created_at,
286                                record.user_id.clone(),
287                                record.agent_id.clone(),
288                                record.run_id.clone(),
289                            );
290                        }
291
292                        results.push(MemoryEvent {
293                            id: record.id,
294                            memory: text,
295                            event: EventType::Add,
296                        });
297                    }
298                }
299                "UPDATE" => {
300                    if let (Some(index_id), Some(text)) = (action.id, action.text) {
301                        if let Some(real_id) = memory_map.get(&index_id) {
302                            debug!("Updating memory {} (index {}) with: {}", real_id, index_id, text);
303                            
304                            // Perform update
305                            match self.update(real_id, &text).await {
306                                Ok(record) => {
307                                    results.push(MemoryEvent {
308                                        id: record.id,
309                                        memory: text,
310                                        event: EventType::Update,
311                                    });
312                                },
313                                Err(e) => {
314                                    warn!("Failed to update memory {}: {}", real_id, e);
315                                }
316                            }
317                        } else {
318                            warn!("LLM tried to update unknown memory index: {}", index_id);
319                        }
320                    }
321                }
322                "DELETE" => {
323                    if let Some(index_id) = action.id {
324                        if let Some(real_id) = memory_map.get(&index_id) {
325                            debug!("Deleting memory {} (index {})", real_id, index_id);
326                            
327                            // Perform delete
328                            match self.delete(real_id).await {
329                                Ok(_) => {
330                                     // ID is needed for event, but delete returns void.
331                                     // We can use Uuid::parse_str(real_id)
332                                     if let Ok(uuid) = Uuid::parse_str(real_id) {
333                                         results.push(MemoryEvent {
334                                            id: uuid,
335                                            memory: String::new(), // Deleted
336                                            event: EventType::Delete,
337                                        });
338                                     }
339                                },
340                                Err(e) => {
341                                    warn!("Failed to delete memory {}: {}", real_id, e);
342                                }
343                            }
344                        } else {
345                            warn!("LLM tried to delete unknown memory index: {}", index_id);
346                        }
347                    }
348                }
349                "NOOP" => {
350                    debug!("No action needed");
351                }
352                _ => {
353                    warn!("Unknown memory action: {}", action.event);
354                }
355            }
356        }
357
358        Ok(results)
359    }
360
361    /// Search for memories
362    pub async fn search(
363        &self,
364        query: &str,
365        options: SearchOptions,
366    ) -> Result<SearchResult, MemoryError> {
367        let embedding = self.embedder.embed(query).await?;
368        let limit = options.limit.unwrap_or(10);
369        let threshold = options.threshold.unwrap_or(0.0);
370
371        // Fetch more candidates if reranking is enabled
372        let search_limit = if options.rerank { limit * 10 } else { limit * 2 };
373
374        let results = self
375            .vector_store
376            .search(&embedding, search_limit, options.filters.as_ref())
377            .await?;
378
379        let mut scored: Vec<ScoredMemory> = results
380            .into_iter()
381            .map(|r| r.to_scored_memory())
382            .collect();
383
384        // Apply scoping filters
385        scored.retain(|m| {
386            if let Some(ref user_id) = options.user_id {
387                if m.record.user_id.as_ref() != Some(user_id) {
388                    return false;
389                }
390            }
391            if let Some(ref agent_id) = options.agent_id {
392                if m.record.agent_id.as_ref() != Some(agent_id) {
393                    return false;
394                }
395            }
396            if let Some(ref run_id) = options.run_id {
397                if m.record.run_id.as_ref() != Some(run_id) {
398                    return false;
399                }
400            }
401            true
402        });
403
404        // Filter by threshold before reranking (optional, but saves rerank quota)
405        scored.retain(|m| m.score >= threshold);
406
407        // Reranking
408        if options.rerank {
409            if let Some(reranker) = &self.reranker {
410                scored = reranker.rerank(query, scored).await?;
411            } else {
412                 warn!("Reranking requested but no reranker configured");
413            }
414        }
415        
416        // Final sort and limit
417        scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
418        scored.truncate(limit);
419
420        Ok(SearchResult { results: scored })
421    }
422
423    /// Get a memory by ID
424    pub async fn get(&self, id: &str) -> Result<Option<MemoryRecord>, MemoryError> {
425        let result = self.vector_store.get(id).await?;
426        Ok(result.map(|r| r.to_memory_record()))
427    }
428
429    /// Get all memories
430    pub async fn get_all(&self, options: GetAllOptions) -> Result<Vec<MemoryRecord>, MemoryError> {
431        let limit = options.limit.unwrap_or(100);
432        let results = self.vector_store.list(None, limit).await?;
433
434        let mut records: Vec<MemoryRecord> =
435            results.into_iter().map(|r| r.to_memory_record()).collect();
436
437        // Apply scoping filters
438        records.retain(|m| {
439            if let Some(ref user_id) = options.user_id {
440                if m.user_id.as_ref() != Some(user_id) {
441                    return false;
442                }
443            }
444            if let Some(ref agent_id) = options.agent_id {
445                if m.agent_id.as_ref() != Some(agent_id) {
446                    return false;
447                }
448            }
449            if let Some(ref run_id) = options.run_id {
450                if m.run_id.as_ref() != Some(run_id) {
451                    return false;
452                }
453            }
454            true
455        });
456
457        Ok(records)
458    }
459
460    /// Update a memory
461    pub async fn update(&self, id: &str, content: &str) -> Result<MemoryRecord, MemoryError> {
462        // Get existing record
463        let existing = self
464            .vector_store
465            .get(id)
466            .await?
467            .ok_or_else(|| MemoryError::NotFound(id.to_string()))?;
468
469        let mut record = existing.to_memory_record();
470        let previous_content = record.content.clone();
471        record.update_content(content);
472
473        let embedding = self.embedder.embed(content).await?;
474        let payload = Payload::from(&record);
475
476        self.vector_store
477            .update(id, Some(embedding), payload)
478            .await?;
479
480        if let Some(history) = &self.history {
481            let _ = history.add_history(
482                record.id,
483                Some(previous_content),
484                record.content.clone(),
485                EventType::Update,
486                Utc::now(),
487                record.user_id.clone(),
488                record.agent_id.clone(),
489                record.run_id.clone(),
490            );
491        }
492
493        Ok(record)
494    }
495
496    /// Delete a memory
497    pub async fn delete(&self, id: &str) -> Result<(), MemoryError> {
498        // Get record first for history
499        let record = self.get(id).await?;
500        
501        self.vector_store.delete(id).await?;
502
503        if let Some(record) = record {
504            if let Some(history) = &self.history {
505                let _ = history.add_history(
506                    record.id,
507                    Some(record.content),
508                    "DELETED".to_string(),
509                    EventType::Delete,
510                    Utc::now(),
511                    record.user_id,
512                    record.agent_id,
513                    record.run_id,
514                );
515            }
516        }
517        
518        Ok(())
519    }
520
521    /// Get memory history
522    pub async fn history(&self, id: &str) -> Result<Vec<HistoryEntry>, MemoryError> {
523        if let Some(history) = &self.history {
524            let memory_id = Uuid::parse_str(id).map_err(|e| MemoryError::InvalidInput(e.to_string()))?;
525            history.get_history(memory_id)
526        } else {
527            Ok(Vec::new())
528        }
529    }
530
531    /// Reset all memories
532    pub async fn reset(&self, options: ResetOptions) -> Result<(), MemoryError> {
533        // Build filters based on options
534        let filters = if options.user_id.is_some() || options.agent_id.is_some() {
535            // TODO: Build proper filters
536            None
537        } else {
538            None
539        };
540
541        self.vector_store.delete_all(filters.as_ref()).await?;
542        
543        if let Some(history) = &self.history {
544            // If global reset, clear history too
545            if filters.is_none() {
546                history.reset()?;
547            }
548        }
549        
550        Ok(())
551    }
552}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Memory` can be built from the default configuration.
    #[tokio::test]
    async fn test_memory_creation() {
        let created = Memory::new(MemoryConfig::default()).await;
        assert!(created.is_ok());
    }

    /// Adding one message without inference yields exactly one event.
    #[tokio::test]
    async fn test_add_raw() {
        let memory = Memory::new(MemoryConfig::default()).await.unwrap();

        let add_options = AddOptions {
            user_id: Some("test_user".to_string()),
            infer: false,
            ..Default::default()
        };
        let outcome = memory.add("Test memory content", add_options).await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().results.len(), 1);
    }

    /// A memory added for a user is found by a search scoped to that user.
    #[tokio::test]
    async fn test_search() {
        let memory = Memory::new(MemoryConfig::default()).await.unwrap();

        // Store a memory for the user.
        let add_options = AddOptions {
            user_id: Some("test_user".to_string()),
            infer: false,
            ..Default::default()
        };
        memory
            .add("I love programming in Rust", add_options)
            .await
            .unwrap();

        // Query for it within the same user scope.
        let search_options = SearchOptions {
            user_id: Some("test_user".to_string()),
            ..Default::default()
        };
        let found = memory
            .search("Rust programming", search_options)
            .await
            .unwrap();

        assert!(!found.results.is_empty());
    }
}