// punch_memory/maintenance.rs

//! Maintenance operations for the memory substrate.
//!
//! Provides cleanup, compaction, vacuum, and statistical query operations
//! used by the gorilla execution engine (Data Sweeper, Report Generator, etc.).

6use chrono::{DateTime, Utc};
7use tracing::{debug, info};
8
9use punch_types::{PunchError, PunchResult};
10
11use crate::MemorySubstrate;
12
13impl MemorySubstrate {
14    /// Delete bout messages older than the given cutoff date.
15    ///
16    /// Returns the number of messages deleted.
17    pub async fn cleanup_old_messages(&self, cutoff: DateTime<Utc>) -> PunchResult<usize> {
18        let cutoff_str = cutoff.format("%Y-%m-%dT%H:%M:%SZ").to_string();
19        let conn = self.conn.lock().await;
20
21        let count = conn
22            .execute(
23                "DELETE FROM messages WHERE created_at < ?1",
24                rusqlite::params![cutoff_str],
25            )
26            .map_err(|e| PunchError::Memory(format!("failed to cleanup old messages: {e}")))?;
27
28        info!(deleted = count, cutoff = %cutoff_str, "cleaned up old messages");
29        Ok(count)
30    }
31
32    /// Compact memory entries by removing low-confidence entries when a fighter
33    /// exceeds the maximum number of memories.
34    ///
35    /// Returns the number of entries removed.
36    pub async fn compact_memories(&self, max_per_fighter: usize) -> PunchResult<usize> {
37        let conn = self.conn.lock().await;
38
39        // Find fighters that exceed the limit.
40        let mut stmt = conn
41            .prepare(
42                "SELECT fighter_id, COUNT(*) as cnt FROM memories \
43                 GROUP BY fighter_id HAVING cnt > ?1",
44            )
45            .map_err(|e| PunchError::Memory(format!("failed to query memory counts: {e}")))?;
46
47        let fighters: Vec<(String, usize)> = stmt
48            .query_map(rusqlite::params![max_per_fighter as i64], |row| {
49                Ok((row.get::<_, String>(0)?, row.get::<_, usize>(1)?))
50            })
51            .map_err(|e| PunchError::Memory(format!("failed to list fighter memories: {e}")))?
52            .filter_map(|r| r.ok())
53            .collect();
54
55        let mut total_removed = 0;
56
57        for (fighter_id, count) in &fighters {
58            let excess = count - max_per_fighter;
59            if excess > 0 {
60                // Delete the lowest-confidence entries for this fighter.
61                let deleted = conn
62                    .execute(
63                        "DELETE FROM memories WHERE rowid IN (\
64                             SELECT rowid FROM memories \
65                             WHERE fighter_id = ?1 \
66                             ORDER BY confidence ASC \
67                             LIMIT ?2\
68                         )",
69                        rusqlite::params![fighter_id, excess],
70                    )
71                    .map_err(|e| {
72                        PunchError::Memory(format!("failed to compact memories: {e}"))
73                    })?;
74                total_removed += deleted;
75                debug!(
76                    fighter_id = %fighter_id,
77                    removed = deleted,
78                    "compacted memories for fighter"
79                );
80            }
81        }
82
83        info!(total_removed, "memory compaction complete");
84        Ok(total_removed)
85    }
86
87    /// Run SQLite VACUUM to reclaim disk space.
88    pub async fn vacuum(&self) -> PunchResult<()> {
89        let conn = self.conn.lock().await;
90        conn.execute_batch("VACUUM")
91            .map_err(|e| PunchError::Memory(format!("vacuum failed: {e}")))?;
92        info!("database vacuumed");
93        Ok(())
94    }
95
96    /// Count bouts created within a time period.
97    pub async fn count_bouts_in_period(
98        &self,
99        start: DateTime<Utc>,
100        end: DateTime<Utc>,
101    ) -> PunchResult<usize> {
102        let start_str = start.format("%Y-%m-%dT%H:%M:%SZ").to_string();
103        let end_str = end.format("%Y-%m-%dT%H:%M:%SZ").to_string();
104        let conn = self.conn.lock().await;
105
106        let count: i64 = conn
107            .query_row(
108                "SELECT COUNT(*) FROM bouts WHERE created_at >= ?1 AND created_at <= ?2",
109                rusqlite::params![start_str, end_str],
110                |row| row.get(0),
111            )
112            .map_err(|e| PunchError::Memory(format!("failed to count bouts: {e}")))?;
113
114        Ok(count as usize)
115    }
116
117    /// Count messages created within a time period.
118    pub async fn count_messages_in_period(
119        &self,
120        start: DateTime<Utc>,
121        end: DateTime<Utc>,
122    ) -> PunchResult<usize> {
123        let start_str = start.format("%Y-%m-%dT%H:%M:%SZ").to_string();
124        let end_str = end.format("%Y-%m-%dT%H:%M:%SZ").to_string();
125        let conn = self.conn.lock().await;
126
127        let count: i64 = conn
128            .query_row(
129                "SELECT COUNT(*) FROM messages WHERE created_at >= ?1 AND created_at <= ?2",
130                rusqlite::params![start_str, end_str],
131                |row| row.get(0),
132            )
133            .map_err(|e| PunchError::Memory(format!("failed to count messages: {e}")))?;
134
135        Ok(count as usize)
136    }
137}
138
#[cfg(test)]
mod tests {
    use punch_types::{
        FighterId, FighterManifest, FighterStatus, Message, ModelConfig, Provider, Role,
        WeightClass,
    };

    use crate::MemorySubstrate;

    /// Build a minimal manifest, sufficient to register a fighter.
    fn test_manifest() -> FighterManifest {
        let model = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".into(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };
        FighterManifest {
            name: "Test".into(),
            description: "test".into(),
            model,
            system_prompt: "test".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Middleweight,
            tenant_id: None,
        }
    }

    /// Open a fresh in-memory substrate and register a single idle fighter in it.
    async fn substrate_with_fighter() -> (MemorySubstrate, FighterId) {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = FighterId::new();
        substrate
            .save_fighter(&fighter_id, &test_manifest(), FighterStatus::Idle)
            .await
            .unwrap();
        (substrate, fighter_id)
    }

    /// Open a substrate holding one fighter, one bout, and one user message `text`.
    async fn substrate_with_one_message(text: &'static str) -> MemorySubstrate {
        let (substrate, fighter_id) = substrate_with_fighter().await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();
        substrate
            .save_message(&bout_id, &Message::new(Role::User, text))
            .await
            .unwrap();
        substrate
    }

    #[tokio::test]
    async fn test_cleanup_old_messages() {
        let substrate = substrate_with_one_message("old msg").await;

        // A cutoff one hour ahead is later than every stored message.
        let cutoff = chrono::Utc::now() + chrono::Duration::hours(1);
        let deleted = substrate.cleanup_old_messages(cutoff).await.unwrap();
        assert!(deleted >= 1);
    }

    #[tokio::test]
    async fn test_cleanup_old_messages_none_deleted() {
        let substrate = substrate_with_one_message("recent msg").await;

        // A cutoff one hour in the past predates the message just written.
        let cutoff = chrono::Utc::now() - chrono::Duration::hours(1);
        assert_eq!(substrate.cleanup_old_messages(cutoff).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_compact_memories() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = FighterId::new();

        // Five memories with confidences 0.0, 0.2, 0.4, 0.6 and 0.8.
        for i in 0..5_i32 {
            let confidence = f64::from(i) * 0.2;
            substrate
                .store_memory(&fighter_id, &format!("key_{i}"), &format!("value_{i}"), confidence)
                .await
                .unwrap();
        }

        // A cap of three leaves two surplus entries to evict.
        assert_eq!(substrate.compact_memories(3).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_compact_memories_no_excess() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = FighterId::new();
        substrate
            .store_memory(&fighter_id, "key", "value", 0.9)
            .await
            .unwrap();

        // One memory is well under a cap of ten, so nothing is removed.
        assert_eq!(substrate.compact_memories(10).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_vacuum() {
        // VACUUM must also succeed against an in-memory database.
        let substrate = MemorySubstrate::in_memory().unwrap();
        let outcome = substrate.vacuum().await;
        outcome.unwrap();
    }

    #[tokio::test]
    async fn test_count_bouts_in_period() {
        let (substrate, fighter_id) = substrate_with_fighter().await;

        // Bracket both bouts with a one-second margin on either side.
        let before = chrono::Utc::now() - chrono::Duration::seconds(1);
        for _ in 0..2 {
            substrate.create_bout(&fighter_id).await.unwrap();
        }
        let after = chrono::Utc::now() + chrono::Duration::seconds(1);

        assert_eq!(substrate.count_bouts_in_period(before, after).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_count_messages_in_period() {
        let (substrate, fighter_id) = substrate_with_fighter().await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        let before = chrono::Utc::now() - chrono::Duration::seconds(1);
        for (role, text) in [(Role::User, "a"), (Role::Assistant, "b")] {
            substrate
                .save_message(&bout_id, &Message::new(role, text))
                .await
                .unwrap();
        }
        let after = chrono::Utc::now() + chrono::Duration::seconds(1);

        let in_window = substrate
            .count_messages_in_period(before, after)
            .await
            .unwrap();
        assert_eq!(in_window, 2);
    }

    #[tokio::test]
    async fn test_count_bouts_empty_period() {
        let substrate = MemorySubstrate::in_memory().unwrap();

        // A one-day window a year ago cannot contain anything in a fresh database.
        let now = chrono::Utc::now();
        let start = now - chrono::Duration::days(365);
        let end = now - chrono::Duration::days(364);

        assert_eq!(substrate.count_bouts_in_period(start, end).await.unwrap(), 0);
    }
}