1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3
4use punch_types::{FighterId, PunchError, PunchResult};
5use tracing::debug;
6
7use crate::MemorySubstrate;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct MemoryEntry {
12 pub key: String,
13 pub value: String,
14 pub confidence: f64,
15 pub created_at: DateTime<Utc>,
16 pub accessed_at: DateTime<Utc>,
17}
18
19impl MemorySubstrate {
20 pub async fn store_memory(
22 &self,
23 fighter_id: &FighterId,
24 key: &str,
25 value: &str,
26 confidence: f64,
27 ) -> PunchResult<()> {
28 let fighter_str = fighter_id.to_string();
29 let now = Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
30
31 let conn = self.conn.lock().await;
32 conn.execute(
33 "INSERT INTO memories (fighter_id, key, value, confidence, created_at, accessed_at)
34 VALUES (?1, ?2, ?3, ?4, ?5, ?5)
35 ON CONFLICT(fighter_id, key) DO UPDATE SET
36 value = excluded.value,
37 confidence = excluded.confidence,
38 accessed_at = excluded.accessed_at",
39 rusqlite::params![fighter_str, key, value, confidence, now],
40 )
41 .map_err(|e| PunchError::Memory(format!("failed to store memory: {e}")))?;
42
43 debug!(fighter_id = %fighter_id, key = key, "memory stored");
44 Ok(())
45 }
46
47 pub async fn recall_memories(
49 &self,
50 fighter_id: &FighterId,
51 query: &str,
52 limit: u32,
53 ) -> PunchResult<Vec<MemoryEntry>> {
54 let fighter_str = fighter_id.to_string();
55 let pattern = format!("%{query}%");
56
57 let conn = self.conn.lock().await;
58
59 conn.execute(
61 "UPDATE memories SET accessed_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now')
62 WHERE fighter_id = ?1 AND (key LIKE ?2 OR value LIKE ?2)",
63 rusqlite::params![fighter_str, pattern],
64 )
65 .map_err(|e| PunchError::Memory(format!("failed to touch memory: {e}")))?;
66
67 let mut stmt = conn
68 .prepare(
69 "SELECT key, value, confidence, created_at, accessed_at
70 FROM memories
71 WHERE fighter_id = ?1 AND (key LIKE ?2 OR value LIKE ?2)
72 ORDER BY confidence DESC
73 LIMIT ?3",
74 )
75 .map_err(|e| PunchError::Memory(format!("failed to recall memories: {e}")))?;
76
77 let rows = stmt
78 .query_map(rusqlite::params![fighter_str, pattern, limit], |row| {
79 let key: String = row.get(0)?;
80 let value: String = row.get(1)?;
81 let confidence: f64 = row.get(2)?;
82 let created_at: String = row.get(3)?;
83 let accessed_at: String = row.get(4)?;
84 Ok((key, value, confidence, created_at, accessed_at))
85 })
86 .map_err(|e| PunchError::Memory(format!("failed to recall memories: {e}")))?;
87
88 let mut entries = Vec::new();
89 for row in rows {
90 let (key, value, confidence, created_at, accessed_at) =
91 row.map_err(|e| PunchError::Memory(format!("failed to read memory row: {e}")))?;
92
93 entries.push(MemoryEntry {
94 key,
95 value,
96 confidence,
97 created_at: parse_ts(&created_at)?,
98 accessed_at: parse_ts(&accessed_at)?,
99 });
100 }
101
102 Ok(entries)
103 }
104
105 pub async fn decay_memories(&self, fighter_id: &FighterId, rate: f64) -> PunchResult<()> {
110 let fighter_str = fighter_id.to_string();
111 let factor = 1.0 - rate;
112
113 let conn = self.conn.lock().await;
114
115 conn.execute(
116 "UPDATE memories SET confidence = confidence * ?1 WHERE fighter_id = ?2",
117 rusqlite::params![factor, fighter_str],
118 )
119 .map_err(|e| PunchError::Memory(format!("failed to decay memories: {e}")))?;
120
121 conn.execute(
123 "DELETE FROM memories WHERE fighter_id = ?1 AND confidence < 0.01",
124 [&fighter_str],
125 )
126 .map_err(|e| PunchError::Memory(format!("failed to prune memories: {e}")))?;
127
128 debug!(fighter_id = %fighter_id, rate = rate, "memories decayed");
129 Ok(())
130 }
131
132 pub async fn delete_memory(&self, fighter_id: &FighterId, key: &str) -> PunchResult<()> {
134 let fighter_str = fighter_id.to_string();
135 let conn = self.conn.lock().await;
136
137 conn.execute(
138 "DELETE FROM memories WHERE fighter_id = ?1 AND key = ?2",
139 rusqlite::params![fighter_str, key],
140 )
141 .map_err(|e| PunchError::Memory(format!("failed to delete memory: {e}")))?;
142
143 debug!(fighter_id = %fighter_id, key = key, "memory deleted");
144 Ok(())
145 }
146}
147
148fn parse_ts(s: &str) -> PunchResult<DateTime<Utc>> {
149 DateTime::parse_from_rfc3339(s)
150 .map(|dt| dt.with_timezone(&Utc))
151 .or_else(|_| {
152 chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%SZ").map(|ndt| ndt.and_utc())
153 })
154 .map_err(|e| PunchError::Memory(format!("invalid timestamp '{s}': {e}")))
155}
156
#[cfg(test)]
mod tests {
    use punch_types::{
        FighterId, FighterManifest, FighterStatus, ModelConfig, Provider, WeightClass,
    };

    use crate::MemorySubstrate;

    /// Minimal manifest shared by every test in this module.
    fn test_manifest() -> FighterManifest {
        let model = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".into(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };

        FighterManifest {
            name: "Mem Fighter".into(),
            description: "memory test".into(),
            model,
            system_prompt: "test".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Featherweight,
            tenant_id: None,
        }
    }

    /// Creates a fresh in-memory substrate with one saved, idle fighter.
    async fn setup() -> (MemorySubstrate, FighterId) {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fid = FighterId::new();
        substrate
            .save_fighter(&fid, &test_manifest(), FighterStatus::Idle)
            .await
            .unwrap();
        (substrate, fid)
    }

    #[tokio::test]
    async fn test_store_and_recall() {
        let (substrate, fid) = setup().await;

        for (key, value, confidence) in [("user_name", "Alice", 0.9), ("user_lang", "Rust", 0.8)] {
            substrate
                .store_memory(&fid, key, value, confidence)
                .await
                .unwrap();
        }

        let results = substrate.recall_memories(&fid, "user", 10).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].key, "user_name");
    }

    #[tokio::test]
    async fn test_decay_memories() {
        let (substrate, fid) = setup().await;

        substrate
            .store_memory(&fid, "fact", "sky is blue", 0.05)
            .await
            .unwrap();
        substrate.decay_memories(&fid, 0.9).await.unwrap();

        let results = substrate.recall_memories(&fid, "fact", 10).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_store_memory_overwrites() {
        let (substrate, fid) = setup().await;

        for (value, confidence) in [("old_value", 0.5), ("new_value", 0.9)] {
            substrate
                .store_memory(&fid, "key", value, confidence)
                .await
                .unwrap();
        }

        let results = substrate.recall_memories(&fid, "key", 10).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].value, "new_value");
        assert!((results[0].confidence - 0.9).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn test_recall_empty() {
        let (substrate, fid) = setup().await;

        let results = substrate.recall_memories(&fid, "nothing", 10).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_recall_limit() {
        let (substrate, fid) = setup().await;

        for i in 0..10 {
            let key = format!("item_{i}");
            let value = format!("val_{i}");
            substrate
                .store_memory(&fid, &key, &value, 0.5)
                .await
                .unwrap();
        }

        let results = substrate.recall_memories(&fid, "item", 3).await.unwrap();
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn test_recall_ordered_by_confidence() {
        let (substrate, fid) = setup().await;

        for (key, confidence) in [("low_prio", 0.2), ("high_prio", 0.9), ("mid_prio", 0.5)] {
            substrate
                .store_memory(&fid, key, "data", confidence)
                .await
                .unwrap();
        }

        let results = substrate.recall_memories(&fid, "prio", 10).await.unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].key, "high_prio");
        assert_eq!(results[2].key, "low_prio");
    }

    #[tokio::test]
    async fn test_decay_preserves_high_confidence() {
        let (substrate, fid) = setup().await;

        substrate
            .store_memory(&fid, "strong", "data", 1.0)
            .await
            .unwrap();
        substrate.decay_memories(&fid, 0.1).await.unwrap();

        let results = substrate.recall_memories(&fid, "strong", 10).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].confidence > 0.5);
    }

    #[tokio::test]
    async fn test_delete_nonexistent_memory() {
        let (substrate, fid) = setup().await;

        // Deleting a key that was never stored must still succeed.
        substrate.delete_memory(&fid, "nonexistent").await.unwrap();
    }

    #[tokio::test]
    async fn test_delete_memory() {
        let (substrate, fid) = setup().await;

        substrate
            .store_memory(&fid, "temp", "data", 1.0)
            .await
            .unwrap();
        substrate.delete_memory(&fid, "temp").await.unwrap();

        let results = substrate.recall_memories(&fid, "temp", 10).await.unwrap();
        assert!(results.is_empty());
    }
}