1use std::collections::HashMap;
2
3use mem7_core::MemoryAction;
4use mem7_error::{Mem7Error, Result};
5use serde::Deserialize;
6use uuid::Uuid;
7
8#[derive(Debug, Clone, Deserialize)]
10pub struct MemoryDecision {
11 pub id: String,
12 pub text: String,
13 pub event: MemoryAction,
14 pub old_memory: Option<String>,
15}
16
17#[derive(Debug, Clone, Deserialize)]
19pub struct MemoryUpdateResponse {
20 pub memory: Vec<MemoryDecision>,
21}
22
23pub struct IdMapping {
26 int_to_uuid: HashMap<String, Uuid>,
27}
28
29impl IdMapping {
30 pub fn new() -> Self {
31 Self {
32 int_to_uuid: HashMap::new(),
33 }
34 }
35
36 pub fn add(&mut self, int_id: String, uuid: Uuid) {
37 self.int_to_uuid.insert(int_id, uuid);
38 }
39
40 pub fn resolve(&self, int_id: &str) -> Option<Uuid> {
41 self.int_to_uuid.get(int_id).copied()
42 }
43}
44
45impl Default for IdMapping {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51pub fn build_existing_memory_dict(existing: &[(Uuid, String)]) -> (serde_json::Value, IdMapping) {
54 let mut mapping = IdMapping::new();
55 let mut entries = Vec::new();
56
57 for (idx, (uuid, text)) in existing.iter().enumerate() {
58 let int_id = idx.to_string();
59 mapping.add(int_id.clone(), *uuid);
60 entries.push(serde_json::json!({
61 "id": int_id,
62 "text": text,
63 }));
64 }
65
66 (serde_json::Value::Array(entries), mapping)
67}
68
69pub fn parse_memory_update_response(json_str: &str) -> Result<MemoryUpdateResponse> {
71 if let Ok(resp) = serde_json::from_str::<MemoryUpdateResponse>(json_str) {
73 return Ok(resp);
74 }
75
76 let trimmed = json_str.trim();
78 let cleaned = if trimmed.starts_with("```json") {
79 trimmed
80 .trim_start_matches("```json")
81 .trim_end_matches("```")
82 .trim()
83 } else if trimmed.starts_with("```") {
84 trimmed
85 .trim_start_matches("```")
86 .trim_end_matches("```")
87 .trim()
88 } else {
89 trimmed
90 };
91
92 serde_json::from_str(cleaned).map_err(|e| {
93 Mem7Error::Serialization(format!(
94 "Failed to parse memory update response: {e}\nRaw: {json_str}"
95 ))
96 })
97}
98
99pub fn deduplicate_memories(memories: Vec<(Uuid, String, f32)>) -> Vec<(Uuid, String)> {
102 let mut seen = HashMap::new();
103 for (uuid, text, score) in memories {
104 seen.entry(uuid)
105 .and_modify(|(existing_text, existing_score): &mut (String, f32)| {
106 if score > *existing_score {
107 *existing_text = text.clone();
108 *existing_score = score;
109 }
110 })
111 .or_insert((text, score));
112 }
113 seen.into_iter()
114 .map(|(uuid, (text, _))| (uuid, text))
115 .collect()
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_valid_response() {
        let raw = r#"{"memory": [{"id": "0", "text": "Loves pizza", "event": "ADD"}]}"#;
        let parsed = parse_memory_update_response(raw).expect("plain JSON should parse");
        assert_eq!(parsed.memory.len(), 1);

        let decision = &parsed.memory[0];
        assert_eq!(decision.text, "Loves pizza");
        assert_eq!(decision.event, MemoryAction::Add);
    }

    #[test]
    fn parse_code_block_response() {
        let fenced = concat!(
            "```json\n",
            "{\"memory\": [{\"id\": \"0\", \"text\": \"test\", \"event\": \"NONE\"}]}\n",
            "```"
        );
        let parsed = parse_memory_update_response(fenced).expect("fenced JSON should parse");
        assert_eq!(parsed.memory.len(), 1);
    }

    #[test]
    fn dedup_keeps_highest_score() {
        let shared = Uuid::now_v7();
        let hits = vec![
            (shared, String::from("low score"), 0.5),
            (shared, String::from("high score"), 0.9),
        ];

        let deduped = deduplicate_memories(hits);
        assert_eq!(deduped.len(), 1);

        let (_, text) = &deduped[0];
        assert_eq!(text.as_str(), "high score");
    }
}