Skip to main content

punch_memory/
memories.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3
4use punch_types::{FighterId, PunchError, PunchResult};
5use tracing::debug;
6
7use crate::MemorySubstrate;
8
/// A single memory entry stored for a fighter.
///
/// Rows are read back from the `memories` table by
/// `MemorySubstrate::recall_memories`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Lookup key; unique per fighter (storing the same key again overwrites it).
    pub key: String,
    /// The remembered value.
    pub value: String,
    /// Confidence score. Recall orders by it (highest first), decay multiplies
    /// it down, and entries that decay below 0.01 are deleted.
    pub confidence: f64,
    /// When the key was first stored for this fighter; overwrites keep it.
    pub created_at: DateTime<Utc>,
    /// Last time the entry was refreshed by `store_memory` or `recall_memories`.
    pub accessed_at: DateTime<Utc>,
}
18
19impl MemorySubstrate {
20    /// Store (or overwrite) a key-value memory for a fighter.
21    pub async fn store_memory(
22        &self,
23        fighter_id: &FighterId,
24        key: &str,
25        value: &str,
26        confidence: f64,
27    ) -> PunchResult<()> {
28        let fighter_str = fighter_id.to_string();
29        let now = Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
30
31        let conn = self.conn.lock().await;
32        conn.execute(
33            "INSERT INTO memories (fighter_id, key, value, confidence, created_at, accessed_at)
34             VALUES (?1, ?2, ?3, ?4, ?5, ?5)
35             ON CONFLICT(fighter_id, key) DO UPDATE SET
36                value = excluded.value,
37                confidence = excluded.confidence,
38                accessed_at = excluded.accessed_at",
39            rusqlite::params![fighter_str, key, value, confidence, now],
40        )
41        .map_err(|e| PunchError::Memory(format!("failed to store memory: {e}")))?;
42
43        debug!(fighter_id = %fighter_id, key = key, "memory stored");
44        Ok(())
45    }
46
47    /// Recall memories matching a query substring, ordered by confidence descending.
48    pub async fn recall_memories(
49        &self,
50        fighter_id: &FighterId,
51        query: &str,
52        limit: u32,
53    ) -> PunchResult<Vec<MemoryEntry>> {
54        let fighter_str = fighter_id.to_string();
55        let pattern = format!("%{query}%");
56
57        let conn = self.conn.lock().await;
58
59        // Update accessed_at for matched rows.
60        conn.execute(
61            "UPDATE memories SET accessed_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now')
62             WHERE fighter_id = ?1 AND (key LIKE ?2 OR value LIKE ?2)",
63            rusqlite::params![fighter_str, pattern],
64        )
65        .map_err(|e| PunchError::Memory(format!("failed to touch memory: {e}")))?;
66
67        let mut stmt = conn
68            .prepare(
69                "SELECT key, value, confidence, created_at, accessed_at
70                 FROM memories
71                 WHERE fighter_id = ?1 AND (key LIKE ?2 OR value LIKE ?2)
72                 ORDER BY confidence DESC
73                 LIMIT ?3",
74            )
75            .map_err(|e| PunchError::Memory(format!("failed to recall memories: {e}")))?;
76
77        let rows = stmt
78            .query_map(rusqlite::params![fighter_str, pattern, limit], |row| {
79                let key: String = row.get(0)?;
80                let value: String = row.get(1)?;
81                let confidence: f64 = row.get(2)?;
82                let created_at: String = row.get(3)?;
83                let accessed_at: String = row.get(4)?;
84                Ok((key, value, confidence, created_at, accessed_at))
85            })
86            .map_err(|e| PunchError::Memory(format!("failed to recall memories: {e}")))?;
87
88        let mut entries = Vec::new();
89        for row in rows {
90            let (key, value, confidence, created_at, accessed_at) =
91                row.map_err(|e| PunchError::Memory(format!("failed to read memory row: {e}")))?;
92
93            entries.push(MemoryEntry {
94                key,
95                value,
96                confidence,
97                created_at: parse_ts(&created_at)?,
98                accessed_at: parse_ts(&accessed_at)?,
99            });
100        }
101
102        Ok(entries)
103    }
104
105    /// Decay all memory confidences for a fighter by a multiplicative rate.
106    ///
107    /// Each memory's confidence becomes `confidence * (1.0 - rate)`. Memories
108    /// that decay below a threshold (0.01) are automatically deleted.
109    pub async fn decay_memories(&self, fighter_id: &FighterId, rate: f64) -> PunchResult<()> {
110        let fighter_str = fighter_id.to_string();
111        let factor = 1.0 - rate;
112
113        let conn = self.conn.lock().await;
114
115        conn.execute(
116            "UPDATE memories SET confidence = confidence * ?1 WHERE fighter_id = ?2",
117            rusqlite::params![factor, fighter_str],
118        )
119        .map_err(|e| PunchError::Memory(format!("failed to decay memories: {e}")))?;
120
121        // Prune near-zero memories.
122        conn.execute(
123            "DELETE FROM memories WHERE fighter_id = ?1 AND confidence < 0.01",
124            [&fighter_str],
125        )
126        .map_err(|e| PunchError::Memory(format!("failed to prune memories: {e}")))?;
127
128        debug!(fighter_id = %fighter_id, rate = rate, "memories decayed");
129        Ok(())
130    }
131
132    /// Delete a specific memory by key.
133    pub async fn delete_memory(&self, fighter_id: &FighterId, key: &str) -> PunchResult<()> {
134        let fighter_str = fighter_id.to_string();
135        let conn = self.conn.lock().await;
136
137        conn.execute(
138            "DELETE FROM memories WHERE fighter_id = ?1 AND key = ?2",
139            rusqlite::params![fighter_str, key],
140        )
141        .map_err(|e| PunchError::Memory(format!("failed to delete memory: {e}")))?;
142
143        debug!(fighter_id = %fighter_id, key = key, "memory deleted");
144        Ok(())
145    }
146}
147
148fn parse_ts(s: &str) -> PunchResult<DateTime<Utc>> {
149    DateTime::parse_from_rfc3339(s)
150        .map(|dt| dt.with_timezone(&Utc))
151        .or_else(|_| {
152            chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%SZ").map(|ndt| ndt.and_utc())
153        })
154        .map_err(|e| PunchError::Memory(format!("invalid timestamp '{s}': {e}")))
155}
156
#[cfg(test)]
mod tests {
    use punch_types::{
        FighterId, FighterManifest, FighterStatus, ModelConfig, Provider, WeightClass,
    };

    use crate::MemorySubstrate;

    /// Minimal manifest for the fighter that owns the memories under test.
    fn test_manifest() -> FighterManifest {
        FighterManifest {
            name: "Mem Fighter".into(),
            description: "memory test".into(),
            model: ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-sonnet-4-20250514".into(),
                api_key_env: None,
                base_url: None,
                max_tokens: Some(4096),
                temperature: Some(0.7),
            },
            system_prompt: "test".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Featherweight,
            tenant_id: None,
        }
    }

    /// Builds a fresh in-memory substrate with one idle fighter saved in it,
    /// returning both so each test can store memories for that fighter.
    async fn setup() -> (MemorySubstrate, FighterId) {
        let substrate = MemorySubstrate::in_memory().expect("in-memory substrate");
        let fighter = FighterId::new();
        substrate
            .save_fighter(&fighter, &test_manifest(), FighterStatus::Idle)
            .await
            .expect("save test fighter");
        (substrate, fighter)
    }

    #[tokio::test]
    async fn test_store_and_recall() {
        let (substrate, fighter) = setup().await;

        for (key, value, confidence) in [("user_name", "Alice", 0.9), ("user_lang", "Rust", 0.8)] {
            substrate
                .store_memory(&fighter, key, value, confidence)
                .await
                .unwrap();
        }

        let recalled = substrate.recall_memories(&fighter, "user", 10).await.unwrap();
        assert_eq!(recalled.len(), 2);
        // The higher-confidence memory must come back first.
        assert_eq!(recalled[0].key, "user_name");
    }

    #[tokio::test]
    async fn test_decay_memories() {
        let (substrate, fighter) = setup().await;

        substrate
            .store_memory(&fighter, "fact", "sky is blue", 0.05)
            .await
            .unwrap();

        // 0.05 * (1 - 0.9) = 0.005, which is under the prune threshold.
        substrate.decay_memories(&fighter, 0.9).await.unwrap();

        let recalled = substrate.recall_memories(&fighter, "fact", 10).await.unwrap();
        assert!(recalled.is_empty());
    }

    #[tokio::test]
    async fn test_store_memory_overwrites() {
        let (substrate, fighter) = setup().await;

        for (value, confidence) in [("old_value", 0.5), ("new_value", 0.9)] {
            substrate
                .store_memory(&fighter, "key", value, confidence)
                .await
                .unwrap();
        }

        let recalled = substrate.recall_memories(&fighter, "key", 10).await.unwrap();
        assert_eq!(recalled.len(), 1);
        let only = &recalled[0];
        assert_eq!(only.value, "new_value");
        assert!((only.confidence - 0.9).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn test_recall_empty() {
        let (substrate, fighter) = setup().await;

        let recalled = substrate
            .recall_memories(&fighter, "nothing", 10)
            .await
            .unwrap();
        assert!(recalled.is_empty());
    }

    #[tokio::test]
    async fn test_recall_limit() {
        let (substrate, fighter) = setup().await;

        for idx in 0..10 {
            let key = format!("item_{idx}");
            let value = format!("val_{idx}");
            substrate
                .store_memory(&fighter, &key, &value, 0.5)
                .await
                .unwrap();
        }

        let recalled = substrate.recall_memories(&fighter, "item", 3).await.unwrap();
        assert_eq!(recalled.len(), 3);
    }

    #[tokio::test]
    async fn test_recall_ordered_by_confidence() {
        let (substrate, fighter) = setup().await;

        // Inserted deliberately out of confidence order.
        for (key, confidence) in [("low_prio", 0.2), ("high_prio", 0.9), ("mid_prio", 0.5)] {
            substrate
                .store_memory(&fighter, key, "data", confidence)
                .await
                .unwrap();
        }

        let recalled = substrate.recall_memories(&fighter, "prio", 10).await.unwrap();
        assert_eq!(recalled.len(), 3);
        assert_eq!(recalled[0].key, "high_prio");
        assert_eq!(recalled[2].key, "low_prio");
    }

    #[tokio::test]
    async fn test_decay_preserves_high_confidence() {
        let (substrate, fighter) = setup().await;

        substrate
            .store_memory(&fighter, "strong", "data", 1.0)
            .await
            .unwrap();

        // A light decay leaves a 1.0-confidence memory well above the threshold.
        substrate.decay_memories(&fighter, 0.1).await.unwrap();

        let recalled = substrate.recall_memories(&fighter, "strong", 10).await.unwrap();
        assert_eq!(recalled.len(), 1);
        assert!(recalled[0].confidence > 0.5);
    }

    #[tokio::test]
    async fn test_delete_nonexistent_memory() {
        let (substrate, fighter) = setup().await;

        // Deleting a key that was never stored must succeed.
        substrate.delete_memory(&fighter, "nonexistent").await.unwrap();
    }

    #[tokio::test]
    async fn test_delete_memory() {
        let (substrate, fighter) = setup().await;

        substrate
            .store_memory(&fighter, "temp", "data", 1.0)
            .await
            .unwrap();
        substrate.delete_memory(&fighter, "temp").await.unwrap();

        let recalled = substrate.recall_memories(&fighter, "temp", 10).await.unwrap();
        assert!(recalled.is_empty());
    }
}