Skip to main content

punch_memory/
memories.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3
4use punch_types::{FighterId, PunchError, PunchResult};
5use tracing::debug;
6
7use crate::MemorySubstrate;
8
/// A single memory entry stored for a fighter.
///
/// This is the shape in which rows of the `memories` table are handed back
/// by `MemorySubstrate::recall_memories`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Name the memory is stored under; storing the same key again for the
    /// same fighter overwrites the existing entry.
    pub key: String,
    /// Text payload of the memory.
    pub value: String,
    /// Score used to rank recalled memories (highest first); it is scaled
    /// down by `MemorySubstrate::decay_memories`.
    pub confidence: f64,
    /// When the entry was first stored. Overwriting the entry keeps this value.
    pub created_at: DateTime<Utc>,
    /// When the entry was last stored or touched by a recall.
    pub accessed_at: DateTime<Utc>,
}
18
19impl MemorySubstrate {
20    /// Store (or overwrite) a key-value memory for a fighter.
21    pub async fn store_memory(
22        &self,
23        fighter_id: &FighterId,
24        key: &str,
25        value: &str,
26        confidence: f64,
27    ) -> PunchResult<()> {
28        let fighter_str = fighter_id.to_string();
29        let now = Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
30
31        let conn = self.conn.lock().await;
32        conn.execute(
33            "INSERT INTO memories (fighter_id, key, value, confidence, created_at, accessed_at)
34             VALUES (?1, ?2, ?3, ?4, ?5, ?5)
35             ON CONFLICT(fighter_id, key) DO UPDATE SET
36                value = excluded.value,
37                confidence = excluded.confidence,
38                accessed_at = excluded.accessed_at",
39            rusqlite::params![fighter_str, key, value, confidence, now],
40        )
41        .map_err(|e| PunchError::Memory(format!("failed to store memory: {e}")))?;
42
43        debug!(fighter_id = %fighter_id, key = key, "memory stored");
44        Ok(())
45    }
46
47    /// Recall memories matching a query substring, ordered by confidence descending.
48    pub async fn recall_memories(
49        &self,
50        fighter_id: &FighterId,
51        query: &str,
52        limit: u32,
53    ) -> PunchResult<Vec<MemoryEntry>> {
54        let fighter_str = fighter_id.to_string();
55        let pattern = format!("%{query}%");
56
57        let conn = self.conn.lock().await;
58
59        // Update accessed_at for matched rows.
60        conn.execute(
61            "UPDATE memories SET accessed_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now')
62             WHERE fighter_id = ?1 AND (key LIKE ?2 OR value LIKE ?2)",
63            rusqlite::params![fighter_str, pattern],
64        )
65        .map_err(|e| PunchError::Memory(format!("failed to touch memory: {e}")))?;
66
67        let mut stmt = conn
68            .prepare(
69                "SELECT key, value, confidence, created_at, accessed_at
70                 FROM memories
71                 WHERE fighter_id = ?1 AND (key LIKE ?2 OR value LIKE ?2)
72                 ORDER BY confidence DESC
73                 LIMIT ?3",
74            )
75            .map_err(|e| PunchError::Memory(format!("failed to recall memories: {e}")))?;
76
77        let rows = stmt
78            .query_map(rusqlite::params![fighter_str, pattern, limit], |row| {
79                let key: String = row.get(0)?;
80                let value: String = row.get(1)?;
81                let confidence: f64 = row.get(2)?;
82                let created_at: String = row.get(3)?;
83                let accessed_at: String = row.get(4)?;
84                Ok((key, value, confidence, created_at, accessed_at))
85            })
86            .map_err(|e| PunchError::Memory(format!("failed to recall memories: {e}")))?;
87
88        let mut entries = Vec::new();
89        for row in rows {
90            let (key, value, confidence, created_at, accessed_at) =
91                row.map_err(|e| PunchError::Memory(format!("failed to read memory row: {e}")))?;
92
93            entries.push(MemoryEntry {
94                key,
95                value,
96                confidence,
97                created_at: parse_ts(&created_at)?,
98                accessed_at: parse_ts(&accessed_at)?,
99            });
100        }
101
102        Ok(entries)
103    }
104
105    /// Decay all memory confidences for a fighter by a multiplicative rate.
106    ///
107    /// Each memory's confidence becomes `confidence * (1.0 - rate)`. Memories
108    /// that decay below a threshold (0.01) are automatically deleted.
109    pub async fn decay_memories(&self, fighter_id: &FighterId, rate: f64) -> PunchResult<()> {
110        let fighter_str = fighter_id.to_string();
111        let factor = 1.0 - rate;
112
113        let conn = self.conn.lock().await;
114
115        conn.execute(
116            "UPDATE memories SET confidence = confidence * ?1 WHERE fighter_id = ?2",
117            rusqlite::params![factor, fighter_str],
118        )
119        .map_err(|e| PunchError::Memory(format!("failed to decay memories: {e}")))?;
120
121        // Prune near-zero memories.
122        conn.execute(
123            "DELETE FROM memories WHERE fighter_id = ?1 AND confidence < 0.01",
124            [&fighter_str],
125        )
126        .map_err(|e| PunchError::Memory(format!("failed to prune memories: {e}")))?;
127
128        debug!(fighter_id = %fighter_id, rate = rate, "memories decayed");
129        Ok(())
130    }
131
132    /// Delete a specific memory by key.
133    pub async fn delete_memory(&self, fighter_id: &FighterId, key: &str) -> PunchResult<()> {
134        let fighter_str = fighter_id.to_string();
135        let conn = self.conn.lock().await;
136
137        conn.execute(
138            "DELETE FROM memories WHERE fighter_id = ?1 AND key = ?2",
139            rusqlite::params![fighter_str, key],
140        )
141        .map_err(|e| PunchError::Memory(format!("failed to delete memory: {e}")))?;
142
143        debug!(fighter_id = %fighter_id, key = key, "memory deleted");
144        Ok(())
145    }
146}
147
148fn parse_ts(s: &str) -> PunchResult<DateTime<Utc>> {
149    DateTime::parse_from_rfc3339(s)
150        .map(|dt| dt.with_timezone(&Utc))
151        .or_else(|_| {
152            chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%SZ").map(|ndt| ndt.and_utc())
153        })
154        .map_err(|e| PunchError::Memory(format!("invalid timestamp '{s}': {e}")))
155}
156
#[cfg(test)]
mod tests {
    use punch_types::{FighterManifest, FighterStatus, ModelConfig, Provider, WeightClass};

    use crate::MemorySubstrate;

    /// Minimal manifest used to register the fighter that owns the test memories.
    fn test_manifest() -> FighterManifest {
        let model = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".into(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };

        FighterManifest {
            name: "Mem Fighter".into(),
            description: "memory test".into(),
            model,
            system_prompt: "test".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Featherweight,
            tenant_id: None,
        }
    }

    /// Fresh in-memory substrate with a single idle fighter already saved.
    async fn substrate_with_fighter() -> (MemorySubstrate, punch_types::FighterId) {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fid = punch_types::FighterId::new();
        substrate
            .save_fighter(&fid, &test_manifest(), FighterStatus::Idle)
            .await
            .unwrap();
        (substrate, fid)
    }

    #[tokio::test]
    async fn test_store_and_recall() {
        let (substrate, fid) = substrate_with_fighter().await;

        for (key, value, confidence) in [("user_name", "Alice", 0.9), ("user_lang", "Rust", 0.8)] {
            substrate
                .store_memory(&fid, key, value, confidence)
                .await
                .unwrap();
        }

        let recalled = substrate.recall_memories(&fid, "user", 10).await.unwrap();
        assert_eq!(recalled.len(), 2);
        // Highest confidence first.
        assert_eq!(recalled[0].key, "user_name");
    }

    #[tokio::test]
    async fn test_decay_memories() {
        let (substrate, fid) = substrate_with_fighter().await;

        substrate
            .store_memory(&fid, "fact", "sky is blue", 0.05)
            .await
            .unwrap();

        // A heavy decay should prune a low-confidence memory.
        substrate.decay_memories(&fid, 0.9).await.unwrap();

        let recalled = substrate.recall_memories(&fid, "fact", 10).await.unwrap();
        assert!(recalled.is_empty());
    }

    #[tokio::test]
    async fn test_store_memory_overwrites() {
        let (substrate, fid) = substrate_with_fighter().await;

        // Second write to the same key must replace the first.
        for (value, confidence) in [("old_value", 0.5), ("new_value", 0.9)] {
            substrate
                .store_memory(&fid, "key", value, confidence)
                .await
                .unwrap();
        }

        let recalled = substrate.recall_memories(&fid, "key", 10).await.unwrap();
        assert_eq!(recalled.len(), 1);

        let entry = &recalled[0];
        assert_eq!(entry.value, "new_value");
        assert!((entry.confidence - 0.9).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn test_recall_empty() {
        let (substrate, fid) = substrate_with_fighter().await;

        let recalled = substrate.recall_memories(&fid, "nothing", 10).await.unwrap();
        assert!(recalled.is_empty());
    }

    #[tokio::test]
    async fn test_recall_limit() {
        let (substrate, fid) = substrate_with_fighter().await;

        for i in 0..10 {
            let key = format!("item_{i}");
            let value = format!("val_{i}");
            substrate
                .store_memory(&fid, &key, &value, 0.5)
                .await
                .unwrap();
        }

        let recalled = substrate.recall_memories(&fid, "item", 3).await.unwrap();
        assert_eq!(recalled.len(), 3);
    }

    #[tokio::test]
    async fn test_recall_ordered_by_confidence() {
        let (substrate, fid) = substrate_with_fighter().await;

        // Inserted out of order on purpose.
        for (key, confidence) in [("low_prio", 0.2), ("high_prio", 0.9), ("mid_prio", 0.5)] {
            substrate
                .store_memory(&fid, key, "data", confidence)
                .await
                .unwrap();
        }

        let recalled = substrate.recall_memories(&fid, "prio", 10).await.unwrap();
        let keys: Vec<&str> = recalled.iter().map(|entry| entry.key.as_str()).collect();
        assert_eq!(keys.len(), 3);
        assert_eq!(keys[0], "high_prio");
        assert_eq!(keys[2], "low_prio");
    }

    #[tokio::test]
    async fn test_decay_preserves_high_confidence() {
        let (substrate, fid) = substrate_with_fighter().await;

        substrate
            .store_memory(&fid, "strong", "data", 1.0)
            .await
            .unwrap();

        // Light decay should not prune a 1.0 confidence memory
        substrate.decay_memories(&fid, 0.1).await.unwrap();

        let recalled = substrate.recall_memories(&fid, "strong", 10).await.unwrap();
        assert_eq!(recalled.len(), 1);
        assert!(recalled[0].confidence > 0.5);
    }

    #[tokio::test]
    async fn test_delete_nonexistent_memory() {
        let (substrate, fid) = substrate_with_fighter().await;

        // Should not error
        substrate.delete_memory(&fid, "nonexistent").await.unwrap();
    }

    #[tokio::test]
    async fn test_delete_memory() {
        let (substrate, fid) = substrate_with_fighter().await;

        substrate
            .store_memory(&fid, "temp", "data", 1.0)
            .await
            .unwrap();
        substrate.delete_memory(&fid, "temp").await.unwrap();

        let recalled = substrate.recall_memories(&fid, "temp", 10).await.unwrap();
        assert!(recalled.is_empty());
    }
}
347}