Skip to main content

mem7_dedup/
scorer.rs

1use std::collections::HashMap;
2
3use mem7_core::MemoryAction;
4use mem7_error::{Mem7Error, Result};
5use serde::Deserialize;
6use uuid::Uuid;
7
8/// A single memory decision from the LLM's response.
9#[derive(Debug, Clone, Deserialize)]
10pub struct MemoryDecision {
11    pub id: String,
12    pub text: String,
13    pub event: MemoryAction,
14    pub old_memory: Option<String>,
15}
16
17/// Parsed LLM memory-update response.
18#[derive(Debug, Clone, Deserialize)]
19pub struct MemoryUpdateResponse {
20    pub memory: Vec<MemoryDecision>,
21}
22
23/// Maps integer IDs (used by LLM) back to real UUIDs.
24/// The LLM receives integer IDs to avoid hallucinating UUIDs.
25pub struct IdMapping {
26    int_to_uuid: HashMap<String, Uuid>,
27}
28
29impl IdMapping {
30    pub fn new() -> Self {
31        Self {
32            int_to_uuid: HashMap::new(),
33        }
34    }
35
36    pub fn add(&mut self, int_id: String, uuid: Uuid) {
37        self.int_to_uuid.insert(int_id, uuid);
38    }
39
40    pub fn resolve(&self, int_id: &str) -> Option<Uuid> {
41        self.int_to_uuid.get(int_id).copied()
42    }
43}
44
45impl Default for IdMapping {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51/// Build the "existing memory" dict that gets sent to the LLM for comparison.
52/// Maps each existing memory to an integer ID for stability.
53pub fn build_existing_memory_dict(existing: &[(Uuid, String)]) -> (serde_json::Value, IdMapping) {
54    let mut mapping = IdMapping::new();
55    let mut entries = Vec::new();
56
57    for (idx, (uuid, text)) in existing.iter().enumerate() {
58        let int_id = idx.to_string();
59        mapping.add(int_id.clone(), *uuid);
60        entries.push(serde_json::json!({
61            "id": int_id,
62            "text": text,
63        }));
64    }
65
66    (serde_json::Value::Array(entries), mapping)
67}
68
69/// Parse the LLM's JSON response for memory update decisions.
70pub fn parse_memory_update_response(json_str: &str) -> Result<MemoryUpdateResponse> {
71    // Try parsing directly
72    if let Ok(resp) = serde_json::from_str::<MemoryUpdateResponse>(json_str) {
73        return Ok(resp);
74    }
75
76    // Try extracting JSON from markdown code blocks
77    let trimmed = json_str.trim();
78    let cleaned = if trimmed.starts_with("```json") {
79        trimmed
80            .trim_start_matches("```json")
81            .trim_end_matches("```")
82            .trim()
83    } else if trimmed.starts_with("```") {
84        trimmed
85            .trim_start_matches("```")
86            .trim_end_matches("```")
87            .trim()
88    } else {
89        trimmed
90    };
91
92    serde_json::from_str(cleaned).map_err(|e| {
93        Mem7Error::Serialization(format!(
94            "Failed to parse memory update response: {e}\nRaw: {json_str}"
95        ))
96    })
97}
98
99/// Deduplicate a list of existing memories retrieved for multiple facts.
100/// Returns a deduplicated list of (uuid, text) pairs.
101pub fn deduplicate_memories(memories: Vec<(Uuid, String, f32)>) -> Vec<(Uuid, String)> {
102    let mut seen = HashMap::new();
103    for (uuid, text, score) in memories {
104        seen.entry(uuid)
105            .and_modify(|(existing_text, existing_score): &mut (String, f32)| {
106                if score > *existing_score {
107                    *existing_text = text.clone();
108                    *existing_score = score;
109                }
110            })
111            .or_insert((text, score));
112    }
113    seen.into_iter()
114        .map(|(uuid, (text, _))| (uuid, text))
115        .collect()
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare JSON object is parsed and its fields are populated.
    #[test]
    fn parse_valid_response() {
        let payload = r#"{"memory": [{"id": "0", "text": "Loves pizza", "event": "ADD"}]}"#;

        let parsed = parse_memory_update_response(payload).unwrap();

        let decisions = &parsed.memory;
        assert_eq!(decisions.len(), 1);
        let first = &decisions[0];
        assert_eq!(first.text, "Loves pizza");
        assert_eq!(first.event, MemoryAction::Add);
    }

    /// A reply wrapped in a ```json fenced block is unwrapped before parsing.
    #[test]
    fn parse_code_block_response() {
        let fenced = "```json\n{\"memory\": [{\"id\": \"0\", \"text\": \"test\", \"event\": \"NONE\"}]}\n```";

        let parsed = parse_memory_update_response(fenced).unwrap();

        assert_eq!(parsed.memory.len(), 1);
    }

    /// Two entries sharing a UUID collapse to the one with the higher score.
    #[test]
    fn dedup_keeps_highest_score() {
        let shared_id = Uuid::now_v7();

        let deduped = deduplicate_memories(vec![
            (shared_id, "low score".into(), 0.5),
            (shared_id, "high score".into(), 0.9),
        ]);

        assert_eq!(deduped.len(), 1);
        assert_eq!(deduped[0].1, "high score");
    }
}