1use chrono::{DateTime, Utc};
7use tracing::{debug, info};
8
9use punch_types::{PunchError, PunchResult};
10
11use crate::MemorySubstrate;
12
13impl MemorySubstrate {
14 pub async fn cleanup_old_messages(&self, cutoff: DateTime<Utc>) -> PunchResult<usize> {
18 let cutoff_str = cutoff.format("%Y-%m-%dT%H:%M:%SZ").to_string();
19 let conn = self.conn.lock().await;
20
21 let count = conn
22 .execute(
23 "DELETE FROM messages WHERE created_at < ?1",
24 rusqlite::params![cutoff_str],
25 )
26 .map_err(|e| PunchError::Memory(format!("failed to cleanup old messages: {e}")))?;
27
28 info!(deleted = count, cutoff = %cutoff_str, "cleaned up old messages");
29 Ok(count)
30 }
31
32 pub async fn compact_memories(&self, max_per_fighter: usize) -> PunchResult<usize> {
37 let conn = self.conn.lock().await;
38
39 let mut stmt = conn
41 .prepare(
42 "SELECT fighter_id, COUNT(*) as cnt FROM memories \
43 GROUP BY fighter_id HAVING cnt > ?1",
44 )
45 .map_err(|e| PunchError::Memory(format!("failed to query memory counts: {e}")))?;
46
47 let fighters: Vec<(String, usize)> = stmt
48 .query_map(rusqlite::params![max_per_fighter as i64], |row| {
49 Ok((row.get::<_, String>(0)?, row.get::<_, usize>(1)?))
50 })
51 .map_err(|e| PunchError::Memory(format!("failed to list fighter memories: {e}")))?
52 .filter_map(|r| r.ok())
53 .collect();
54
55 let mut total_removed = 0;
56
57 for (fighter_id, count) in &fighters {
58 let excess = count - max_per_fighter;
59 if excess > 0 {
60 let deleted = conn
62 .execute(
63 "DELETE FROM memories WHERE rowid IN (\
64 SELECT rowid FROM memories \
65 WHERE fighter_id = ?1 \
66 ORDER BY confidence ASC \
67 LIMIT ?2\
68 )",
69 rusqlite::params![fighter_id, excess],
70 )
71 .map_err(|e| PunchError::Memory(format!("failed to compact memories: {e}")))?;
72 total_removed += deleted;
73 debug!(
74 fighter_id = %fighter_id,
75 removed = deleted,
76 "compacted memories for fighter"
77 );
78 }
79 }
80
81 info!(total_removed, "memory compaction complete");
82 Ok(total_removed)
83 }
84
85 pub async fn vacuum(&self) -> PunchResult<()> {
87 let conn = self.conn.lock().await;
88 conn.execute_batch("VACUUM")
89 .map_err(|e| PunchError::Memory(format!("vacuum failed: {e}")))?;
90 info!("database vacuumed");
91 Ok(())
92 }
93
94 pub async fn count_bouts_in_period(
96 &self,
97 start: DateTime<Utc>,
98 end: DateTime<Utc>,
99 ) -> PunchResult<usize> {
100 let start_str = start.format("%Y-%m-%dT%H:%M:%SZ").to_string();
101 let end_str = end.format("%Y-%m-%dT%H:%M:%SZ").to_string();
102 let conn = self.conn.lock().await;
103
104 let count: i64 = conn
105 .query_row(
106 "SELECT COUNT(*) FROM bouts WHERE created_at >= ?1 AND created_at <= ?2",
107 rusqlite::params![start_str, end_str],
108 |row| row.get(0),
109 )
110 .map_err(|e| PunchError::Memory(format!("failed to count bouts: {e}")))?;
111
112 Ok(count as usize)
113 }
114
115 pub async fn count_messages_in_period(
117 &self,
118 start: DateTime<Utc>,
119 end: DateTime<Utc>,
120 ) -> PunchResult<usize> {
121 let start_str = start.format("%Y-%m-%dT%H:%M:%SZ").to_string();
122 let end_str = end.format("%Y-%m-%dT%H:%M:%SZ").to_string();
123 let conn = self.conn.lock().await;
124
125 let count: i64 = conn
126 .query_row(
127 "SELECT COUNT(*) FROM messages WHERE created_at >= ?1 AND created_at <= ?2",
128 rusqlite::params![start_str, end_str],
129 |row| row.get(0),
130 )
131 .map_err(|e| PunchError::Memory(format!("failed to count messages: {e}")))?;
132
133 Ok(count as usize)
134 }
135}
136
#[cfg(test)]
mod tests {
    use punch_types::{
        FighterId, FighterManifest, FighterStatus, Message, ModelConfig, Provider, Role,
        WeightClass,
    };

    use crate::MemorySubstrate;

    /// Minimal manifest for tests that need a registered fighter.
    fn test_manifest() -> FighterManifest {
        let model = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".into(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };
        FighterManifest {
            name: "Test".into(),
            description: "test".into(),
            model,
            system_prompt: "test".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Middleweight,
            tenant_id: None,
        }
    }

    /// Fresh in-memory substrate with one idle fighter already saved.
    async fn substrate_with_fighter() -> (MemorySubstrate, FighterId) {
        let db = MemorySubstrate::in_memory().unwrap();
        let fighter = FighterId::new();
        db.save_fighter(&fighter, &test_manifest(), FighterStatus::Idle)
            .await
            .unwrap();
        (db, fighter)
    }

    #[tokio::test]
    async fn test_cleanup_old_messages() {
        let (db, fighter) = substrate_with_fighter().await;
        let bout = db.create_bout(&fighter).await.unwrap();
        db.save_message(&bout, &Message::new(Role::User, "old msg"))
            .await
            .unwrap();

        // A cutoff an hour in the future makes the message just written "old".
        let hour_ahead = chrono::Utc::now() + chrono::Duration::hours(1);
        let removed = db.cleanup_old_messages(hour_ahead).await.unwrap();
        assert!(removed >= 1);
    }

    #[tokio::test]
    async fn test_cleanup_old_messages_none_deleted() {
        let (db, fighter) = substrate_with_fighter().await;
        let bout = db.create_bout(&fighter).await.unwrap();
        db.save_message(&bout, &Message::new(Role::User, "recent msg"))
            .await
            .unwrap();

        // A cutoff an hour in the past must leave the fresh message alone.
        let hour_ago = chrono::Utc::now() - chrono::Duration::hours(1);
        let removed = db.cleanup_old_messages(hour_ago).await.unwrap();
        assert_eq!(removed, 0);
    }

    #[tokio::test]
    async fn test_compact_memories() {
        let db = MemorySubstrate::in_memory().unwrap();
        let fighter = FighterId::new();

        // Five memories with confidences 0.0, 0.2, 0.4, 0.6, 0.8.
        for i in 0_i32..5 {
            let confidence = f64::from(i) * 0.2;
            db.store_memory(
                &fighter,
                &format!("key_{i}"),
                &format!("value_{i}"),
                confidence,
            )
            .await
            .unwrap();
        }

        // Keeping three of five means two are trimmed.
        assert_eq!(db.compact_memories(3).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_compact_memories_no_excess() {
        let db = MemorySubstrate::in_memory().unwrap();
        let fighter = FighterId::new();

        db.store_memory(&fighter, "key", "value", 0.9)
            .await
            .unwrap();

        assert_eq!(db.compact_memories(10).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_vacuum() {
        let db = MemorySubstrate::in_memory().unwrap();
        db.vacuum().await.unwrap();
    }

    #[tokio::test]
    async fn test_count_bouts_in_period() {
        let (db, fighter) = substrate_with_fighter().await;

        let window_start = chrono::Utc::now() - chrono::Duration::seconds(1);
        for _ in 0..2 {
            db.create_bout(&fighter).await.unwrap();
        }
        let window_end = chrono::Utc::now() + chrono::Duration::seconds(1);

        let bouts = db
            .count_bouts_in_period(window_start, window_end)
            .await
            .unwrap();
        assert_eq!(bouts, 2);
    }

    #[tokio::test]
    async fn test_count_messages_in_period() {
        let (db, fighter) = substrate_with_fighter().await;
        let bout = db.create_bout(&fighter).await.unwrap();

        let window_start = chrono::Utc::now() - chrono::Duration::seconds(1);
        for (role, text) in [(Role::User, "a"), (Role::Assistant, "b")] {
            db.save_message(&bout, &Message::new(role, text))
                .await
                .unwrap();
        }
        let window_end = chrono::Utc::now() + chrono::Duration::seconds(1);

        let messages = db
            .count_messages_in_period(window_start, window_end)
            .await
            .unwrap();
        assert_eq!(messages, 2);
    }

    #[tokio::test]
    async fn test_count_bouts_empty_period() {
        let db = MemorySubstrate::in_memory().unwrap();
        // A one-day window a year in the past contains nothing in a fresh DB.
        let year_ago = chrono::Utc::now() - chrono::Duration::days(365);
        let year_ago_plus_day = chrono::Utc::now() - chrono::Duration::days(364);
        let bouts = db
            .count_bouts_in_period(year_ago, year_ago_plus_day)
            .await
            .unwrap();
        assert_eq!(bouts, 0);
    }
}