1use chrono::{DateTime, Utc};
7use tracing::{debug, info};
8
9use punch_types::{PunchError, PunchResult};
10
11use crate::MemorySubstrate;
12
13impl MemorySubstrate {
14 pub async fn cleanup_old_messages(&self, cutoff: DateTime<Utc>) -> PunchResult<usize> {
18 let cutoff_str = cutoff.format("%Y-%m-%dT%H:%M:%SZ").to_string();
19 let conn = self.conn.lock().await;
20
21 let count = conn
22 .execute(
23 "DELETE FROM messages WHERE created_at < ?1",
24 rusqlite::params![cutoff_str],
25 )
26 .map_err(|e| PunchError::Memory(format!("failed to cleanup old messages: {e}")))?;
27
28 info!(deleted = count, cutoff = %cutoff_str, "cleaned up old messages");
29 Ok(count)
30 }
31
32 pub async fn compact_memories(&self, max_per_fighter: usize) -> PunchResult<usize> {
37 let conn = self.conn.lock().await;
38
39 let mut stmt = conn
41 .prepare(
42 "SELECT fighter_id, COUNT(*) as cnt FROM memories \
43 GROUP BY fighter_id HAVING cnt > ?1",
44 )
45 .map_err(|e| PunchError::Memory(format!("failed to query memory counts: {e}")))?;
46
47 let fighters: Vec<(String, usize)> = stmt
48 .query_map(rusqlite::params![max_per_fighter as i64], |row| {
49 Ok((row.get::<_, String>(0)?, row.get::<_, usize>(1)?))
50 })
51 .map_err(|e| PunchError::Memory(format!("failed to list fighter memories: {e}")))?
52 .filter_map(|r| r.ok())
53 .collect();
54
55 let mut total_removed = 0;
56
57 for (fighter_id, count) in &fighters {
58 let excess = count - max_per_fighter;
59 if excess > 0 {
60 let deleted = conn
62 .execute(
63 "DELETE FROM memories WHERE rowid IN (\
64 SELECT rowid FROM memories \
65 WHERE fighter_id = ?1 \
66 ORDER BY confidence ASC \
67 LIMIT ?2\
68 )",
69 rusqlite::params![fighter_id, excess],
70 )
71 .map_err(|e| {
72 PunchError::Memory(format!("failed to compact memories: {e}"))
73 })?;
74 total_removed += deleted;
75 debug!(
76 fighter_id = %fighter_id,
77 removed = deleted,
78 "compacted memories for fighter"
79 );
80 }
81 }
82
83 info!(total_removed, "memory compaction complete");
84 Ok(total_removed)
85 }
86
87 pub async fn vacuum(&self) -> PunchResult<()> {
89 let conn = self.conn.lock().await;
90 conn.execute_batch("VACUUM")
91 .map_err(|e| PunchError::Memory(format!("vacuum failed: {e}")))?;
92 info!("database vacuumed");
93 Ok(())
94 }
95
96 pub async fn count_bouts_in_period(
98 &self,
99 start: DateTime<Utc>,
100 end: DateTime<Utc>,
101 ) -> PunchResult<usize> {
102 let start_str = start.format("%Y-%m-%dT%H:%M:%SZ").to_string();
103 let end_str = end.format("%Y-%m-%dT%H:%M:%SZ").to_string();
104 let conn = self.conn.lock().await;
105
106 let count: i64 = conn
107 .query_row(
108 "SELECT COUNT(*) FROM bouts WHERE created_at >= ?1 AND created_at <= ?2",
109 rusqlite::params![start_str, end_str],
110 |row| row.get(0),
111 )
112 .map_err(|e| PunchError::Memory(format!("failed to count bouts: {e}")))?;
113
114 Ok(count as usize)
115 }
116
117 pub async fn count_messages_in_period(
119 &self,
120 start: DateTime<Utc>,
121 end: DateTime<Utc>,
122 ) -> PunchResult<usize> {
123 let start_str = start.format("%Y-%m-%dT%H:%M:%SZ").to_string();
124 let end_str = end.format("%Y-%m-%dT%H:%M:%SZ").to_string();
125 let conn = self.conn.lock().await;
126
127 let count: i64 = conn
128 .query_row(
129 "SELECT COUNT(*) FROM messages WHERE created_at >= ?1 AND created_at <= ?2",
130 rusqlite::params![start_str, end_str],
131 |row| row.get(0),
132 )
133 .map_err(|e| PunchError::Memory(format!("failed to count messages: {e}")))?;
134
135 Ok(count as usize)
136 }
137}
138
#[cfg(test)]
mod tests {
    use punch_types::{
        FighterId, FighterManifest, FighterStatus, Message, ModelConfig, Provider, Role,
        WeightClass,
    };

    use crate::MemorySubstrate;

    /// Minimal manifest used to register a fighter before creating bouts for it.
    fn test_manifest() -> FighterManifest {
        FighterManifest {
            name: "Test".into(),
            description: "test".into(),
            model: ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-sonnet-4-20250514".into(),
                api_key_env: None,
                base_url: None,
                max_tokens: Some(4096),
                temperature: Some(0.7),
            },
            system_prompt: "test".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Middleweight,
            tenant_id: None,
        }
    }

    #[tokio::test]
    async fn test_cleanup_old_messages() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = FighterId::new();
        substrate
            .save_fighter(&fighter_id, &test_manifest(), FighterStatus::Idle)
            .await
            .unwrap();
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        substrate
            .save_message(&bout_id, &Message::new(Role::User, "old msg"))
            .await
            .unwrap();

        // A cutoff in the future makes the single stored message "old".
        let cutoff = chrono::Utc::now() + chrono::Duration::hours(1);
        let deleted = substrate.cleanup_old_messages(cutoff).await.unwrap();
        // Exactly one message was saved, so exactly one must be removed.
        assert_eq!(deleted, 1);
    }

    #[tokio::test]
    async fn test_cleanup_old_messages_none_deleted() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = FighterId::new();
        substrate
            .save_fighter(&fighter_id, &test_manifest(), FighterStatus::Idle)
            .await
            .unwrap();
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        substrate
            .save_message(&bout_id, &Message::new(Role::User, "recent msg"))
            .await
            .unwrap();

        // A cutoff in the past must leave the freshly saved message alone.
        let cutoff = chrono::Utc::now() - chrono::Duration::hours(1);
        let deleted = substrate.cleanup_old_messages(cutoff).await.unwrap();
        assert_eq!(deleted, 0);
    }

    #[tokio::test]
    async fn test_compact_memories() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = FighterId::new();

        // Five memories with increasing confidence: 0.0, 0.2, 0.4, 0.6, 0.8.
        for i in 0..5 {
            substrate
                .store_memory(
                    &fighter_id,
                    &format!("key_{}", i),
                    &format!("value_{}", i),
                    (i as f64) * 0.2,
                )
                .await
                .unwrap();
        }

        // Capping at 3 removes the 2 surplus rows.
        let removed = substrate.compact_memories(3).await.unwrap();
        assert_eq!(removed, 2);

        // The fighter is now exactly at the limit, so a second pass is a no-op.
        let removed_again = substrate.compact_memories(3).await.unwrap();
        assert_eq!(removed_again, 0);
    }

    #[tokio::test]
    async fn test_compact_memories_only_trims_over_limit_fighters() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let over_limit = FighterId::new();
        let under_limit = FighterId::new();

        for i in 0..5 {
            substrate
                .store_memory(
                    &over_limit,
                    &format!("key_{}", i),
                    &format!("value_{}", i),
                    (i as f64) * 0.2,
                )
                .await
                .unwrap();
        }
        for i in 0..2 {
            substrate
                .store_memory(
                    &under_limit,
                    &format!("key_{}", i),
                    &format!("value_{}", i),
                    0.5,
                )
                .await
                .unwrap();
        }

        // Only the fighter with 5 memories exceeds the cap of 3.
        let removed = substrate.compact_memories(3).await.unwrap();
        assert_eq!(removed, 2);
    }

    #[tokio::test]
    async fn test_compact_memories_no_excess() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = FighterId::new();

        substrate
            .store_memory(&fighter_id, "key", "value", 0.9)
            .await
            .unwrap();

        let removed = substrate.compact_memories(10).await.unwrap();
        assert_eq!(removed, 0);
    }

    #[tokio::test]
    async fn test_vacuum() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        // Success on a fresh database is the whole assertion.
        substrate.vacuum().await.unwrap();
    }

    #[tokio::test]
    async fn test_count_bouts_in_period() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = FighterId::new();
        substrate
            .save_fighter(&fighter_id, &test_manifest(), FighterStatus::Idle)
            .await
            .unwrap();

        let before = chrono::Utc::now() - chrono::Duration::seconds(1);
        substrate.create_bout(&fighter_id).await.unwrap();
        substrate.create_bout(&fighter_id).await.unwrap();
        let after = chrono::Utc::now() + chrono::Duration::seconds(1);

        let count = substrate.count_bouts_in_period(before, after).await.unwrap();
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn test_count_messages_in_period() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = FighterId::new();
        substrate
            .save_fighter(&fighter_id, &test_manifest(), FighterStatus::Idle)
            .await
            .unwrap();
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        let before = chrono::Utc::now() - chrono::Duration::seconds(1);
        substrate
            .save_message(&bout_id, &Message::new(Role::User, "a"))
            .await
            .unwrap();
        substrate
            .save_message(&bout_id, &Message::new(Role::Assistant, "b"))
            .await
            .unwrap();
        let after = chrono::Utc::now() + chrono::Duration::seconds(1);

        let count = substrate
            .count_messages_in_period(before, after)
            .await
            .unwrap();
        assert_eq!(count, 2);
    }

    #[tokio::test]
    async fn test_count_bouts_empty_period() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let start = chrono::Utc::now() - chrono::Duration::days(365);
        let end = chrono::Utc::now() - chrono::Duration::days(364);
        let count = substrate.count_bouts_in_period(start, end).await.unwrap();
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn test_count_messages_empty_period() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let start = chrono::Utc::now() - chrono::Duration::days(365);
        let end = chrono::Utc::now() - chrono::Duration::days(364);
        let count = substrate
            .count_messages_in_period(start, end)
            .await
            .unwrap();
        assert_eq!(count, 0);
    }
}