1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use uuid::Uuid;
4
5use punch_types::{FighterId, Message, PunchError, PunchResult, Role};
6use tracing::debug;
7
8use crate::MemorySubstrate;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[serde(transparent)]
13pub struct BoutId(pub Uuid);
14
15impl BoutId {
16 pub fn new() -> Self {
17 Self(Uuid::new_v4())
18 }
19}
20
21impl Default for BoutId {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl std::fmt::Display for BoutId {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 write!(f, "{}", self.0)
30 }
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct BoutSummary {
36 pub id: BoutId,
37 pub fighter_id: FighterId,
38 pub title: Option<String>,
39 pub message_count: u64,
40 pub created_at: String,
41 pub updated_at: String,
42}
43
44impl MemorySubstrate {
45 pub async fn create_bout(&self, fighter_id: &FighterId) -> PunchResult<BoutId> {
47 let bout_id = BoutId::new();
48 let bout_str = bout_id.to_string();
49 let fighter_str = fighter_id.to_string();
50
51 let conn = self.conn.lock().await;
52 conn.execute(
53 "INSERT INTO bouts (id, fighter_id) VALUES (?1, ?2)",
54 rusqlite::params![bout_str, fighter_str],
55 )
56 .map_err(|e| PunchError::Bout(format!("failed to create bout: {e}")))?;
57
58 debug!(bout_id = %bout_id, fighter_id = %fighter_id, "bout created");
59 Ok(bout_id)
60 }
61
62 pub async fn save_message(&self, bout_id: &BoutId, message: &Message) -> PunchResult<()> {
64 let bout_str = bout_id.to_string();
65 let role_str = message.role.to_string();
66
67 let metadata = if message.tool_calls.is_empty() && message.tool_results.is_empty() {
69 None
70 } else {
71 Some(serde_json::json!({
72 "tool_calls": message.tool_calls,
73 "tool_results": message.tool_results,
74 }))
75 };
76 let metadata_str = metadata.map(|m| m.to_string());
77 let ts = message.timestamp.format("%Y-%m-%dT%H:%M:%SZ").to_string();
78
79 let conn = self.conn.lock().await;
80 conn.execute(
81 "INSERT INTO messages (bout_id, role, content, metadata, created_at) VALUES (?1, ?2, ?3, ?4, ?5)",
82 rusqlite::params![bout_str, role_str, message.content, metadata_str, ts],
83 )
84 .map_err(|e| PunchError::Bout(format!("failed to save message: {e}")))?;
85
86 conn.execute(
88 "UPDATE bouts SET updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now') WHERE id = ?1",
89 [&bout_str],
90 )
91 .map_err(|e| PunchError::Bout(format!("failed to touch bout: {e}")))?;
92
93 Ok(())
94 }
95
96 pub async fn load_messages(&self, bout_id: &BoutId) -> PunchResult<Vec<Message>> {
98 let bout_str = bout_id.to_string();
99 let conn = self.conn.lock().await;
100
101 let mut stmt = conn
102 .prepare(
103 "SELECT role, content, metadata, created_at FROM messages WHERE bout_id = ?1 ORDER BY id",
104 )
105 .map_err(|e| PunchError::Bout(format!("failed to prepare message query: {e}")))?;
106
107 let rows = stmt
108 .query_map([&bout_str], |row| {
109 let role_str: String = row.get(0)?;
110 let content: String = row.get(1)?;
111 let metadata: Option<String> = row.get(2)?;
112 let created_at: String = row.get(3)?;
113 Ok((role_str, content, metadata, created_at))
114 })
115 .map_err(|e| PunchError::Bout(format!("failed to query messages: {e}")))?;
116
117 let mut messages = Vec::new();
118 for row in rows {
119 let (role_str, content, metadata, created_at) =
120 row.map_err(|e| PunchError::Bout(format!("failed to read message row: {e}")))?;
121
122 let role = parse_role(&role_str)?;
123 let timestamp = parse_timestamp(&created_at)?;
124
125 let (tool_calls, tool_results) = match metadata {
126 Some(json) => {
127 let v: serde_json::Value = serde_json::from_str(&json)
128 .map_err(|e| PunchError::Bout(format!("corrupt message metadata: {e}")))?;
129 let tc = serde_json::from_value(
130 v.get("tool_calls")
131 .cloned()
132 .unwrap_or(serde_json::Value::Array(vec![])),
133 )
134 .unwrap_or_default();
135 let tr = serde_json::from_value(
136 v.get("tool_results")
137 .cloned()
138 .unwrap_or(serde_json::Value::Array(vec![])),
139 )
140 .unwrap_or_default();
141 (tc, tr)
142 }
143 None => (Vec::new(), Vec::new()),
144 };
145
146 messages.push(Message {
147 role,
148 content,
149 tool_calls,
150 tool_results,
151 timestamp,
152 content_parts: Vec::new(),
153 });
154 }
155
156 Ok(messages)
157 }
158
159 pub async fn list_bouts(&self, fighter_id: &FighterId) -> PunchResult<Vec<BoutSummary>> {
161 let fighter_str = fighter_id.to_string();
162 let conn = self.conn.lock().await;
163
164 let mut stmt = conn
165 .prepare(
166 "SELECT b.id, b.title, b.created_at, b.updated_at,
167 (SELECT COUNT(*) FROM messages m WHERE m.bout_id = b.id)
168 FROM bouts b
169 WHERE b.fighter_id = ?1
170 ORDER BY b.updated_at DESC",
171 )
172 .map_err(|e| PunchError::Bout(format!("failed to list bouts: {e}")))?;
173
174 let rows = stmt
175 .query_map([&fighter_str], |row| {
176 let id: String = row.get(0)?;
177 let title: Option<String> = row.get(1)?;
178 let created_at: String = row.get(2)?;
179 let updated_at: String = row.get(3)?;
180 let message_count: u64 = row.get(4)?;
181 Ok((id, title, created_at, updated_at, message_count))
182 })
183 .map_err(|e| PunchError::Bout(format!("failed to list bouts: {e}")))?;
184
185 let mut summaries = Vec::new();
186 for row in rows {
187 let (id, title, created_at, updated_at, message_count) =
188 row.map_err(|e| PunchError::Bout(format!("failed to read bout row: {e}")))?;
189
190 let bout_id = BoutId(
191 Uuid::parse_str(&id)
192 .map_err(|e| PunchError::Bout(format!("invalid bout id: {e}")))?,
193 );
194
195 summaries.push(BoutSummary {
196 id: bout_id,
197 fighter_id: *fighter_id,
198 title,
199 message_count,
200 created_at,
201 updated_at,
202 });
203 }
204
205 Ok(summaries)
206 }
207
208 pub async fn latest_bout_for_fighter(
214 &self,
215 fighter_id: &FighterId,
216 ) -> PunchResult<Option<BoutId>> {
217 let fighter_str = fighter_id.to_string();
218 let conn = self.conn.lock().await;
219
220 let result: Option<String> = conn
221 .query_row(
222 "SELECT id FROM bouts WHERE fighter_id = ?1 ORDER BY updated_at DESC LIMIT 1",
223 [&fighter_str],
224 |row| row.get(0),
225 )
226 .ok();
227
228 match result {
229 Some(id_str) => {
230 let uuid = Uuid::parse_str(&id_str)
231 .map_err(|e| PunchError::Bout(format!("invalid bout id: {e}")))?;
232 debug!(bout_id = %id_str, fighter_id = %fighter_id, "restored latest bout from database");
233 Ok(Some(BoutId(uuid)))
234 }
235 None => Ok(None),
236 }
237 }
238
239 pub async fn delete_bout(&self, bout_id: &BoutId) -> PunchResult<()> {
241 let bout_str = bout_id.to_string();
242 let conn = self.conn.lock().await;
243
244 conn.execute("DELETE FROM bouts WHERE id = ?1", [&bout_str])
245 .map_err(|e| PunchError::Bout(format!("failed to delete bout: {e}")))?;
246
247 debug!(bout_id = %bout_id, "bout deleted");
248 Ok(())
249 }
250}
251
252fn parse_role(s: &str) -> PunchResult<Role> {
253 match s {
254 "user" => Ok(Role::User),
255 "assistant" => Ok(Role::Assistant),
256 "system" => Ok(Role::System),
257 "tool" => Ok(Role::Tool),
258 other => Err(PunchError::Bout(format!("unknown role: {other}"))),
259 }
260}
261
262fn parse_timestamp(s: &str) -> PunchResult<DateTime<Utc>> {
263 DateTime::parse_from_rfc3339(s)
264 .map(|dt| dt.with_timezone(&Utc))
265 .or_else(|_| {
266 chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%SZ").map(|ndt| ndt.and_utc())
267 })
268 .map_err(|e| PunchError::Bout(format!("invalid timestamp '{s}': {e}")))
269}
270
#[cfg(test)]
mod tests {
    use punch_types::{
        FighterId, FighterManifest, FighterStatus, Message, ModelConfig, Provider, Role,
        WeightClass,
    };

    use super::BoutId;
    use crate::MemorySubstrate;

    /// Minimal manifest used to register fighters in these tests.
    fn test_manifest() -> FighterManifest {
        FighterManifest {
            name: "Test Fighter".into(),
            description: "A test fighter".into(),
            model: ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-sonnet-4-20250514".into(),
                api_key_env: None,
                base_url: None,
                max_tokens: Some(4096),
                temperature: Some(0.7),
            },
            system_prompt: "You are a test fighter.".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Middleweight,
            tenant_id: None,
        }
    }

    /// Register a new idle fighter in `substrate` and return its id.
    async fn register_fighter(substrate: &MemorySubstrate) -> FighterId {
        let fighter_id = FighterId::new();
        substrate
            .save_fighter(&fighter_id, &test_manifest(), FighterStatus::Idle)
            .await
            .unwrap();
        fighter_id
    }

    #[tokio::test]
    async fn test_create_bout_and_messages() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = register_fighter(&substrate).await;

        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        let greeting = Message::new(Role::User, "Hello, fighter!");
        substrate.save_message(&bout_id, &greeting).await.unwrap();

        let loaded = substrate.load_messages(&bout_id).await.unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].content, "Hello, fighter!");
        assert_eq!(loaded[0].role, Role::User);
    }

    #[tokio::test]
    async fn test_multiple_messages_in_bout() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = register_fighter(&substrate).await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        let script = [
            (Role::User, "Hello"),
            (Role::Assistant, "Hi there"),
            (Role::User, "How are you?"),
        ];
        for (role, text) in script {
            substrate
                .save_message(&bout_id, &Message::new(role, text))
                .await
                .unwrap();
        }

        let loaded = substrate.load_messages(&bout_id).await.unwrap();
        assert_eq!(loaded.len(), 3);
        assert_eq!(loaded[0].role, Role::User);
        assert_eq!(loaded[1].role, Role::Assistant);
        assert_eq!(loaded[2].content, "How are you?");
    }

    #[tokio::test]
    async fn test_load_messages_empty_bout() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = register_fighter(&substrate).await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        let loaded = substrate.load_messages(&bout_id).await.unwrap();
        assert!(loaded.is_empty());
    }

    #[tokio::test]
    async fn test_multiple_bouts_for_fighter() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = register_fighter(&substrate).await;

        for _ in 0..3 {
            substrate.create_bout(&fighter_id).await.unwrap();
        }

        let bouts = substrate.list_bouts(&fighter_id).await.unwrap();
        assert_eq!(bouts.len(), 3);
    }

    #[tokio::test]
    async fn test_bout_summary_message_count() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = register_fighter(&substrate).await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        for (role, text) in [(Role::User, "a"), (Role::Assistant, "b")] {
            substrate
                .save_message(&bout_id, &Message::new(role, text))
                .await
                .unwrap();
        }

        let bouts = substrate.list_bouts(&fighter_id).await.unwrap();
        assert_eq!(bouts[0].message_count, 2);
    }

    #[tokio::test]
    async fn test_bout_id_display() {
        let rendered = BoutId::new().to_string();
        assert!(!rendered.is_empty());
        assert!(uuid::Uuid::parse_str(&rendered).is_ok());
    }

    #[tokio::test]
    async fn test_delete_bout_cascades_messages() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = register_fighter(&substrate).await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();
        substrate
            .save_message(&bout_id, &Message::new(Role::User, "msg"))
            .await
            .unwrap();

        substrate.delete_bout(&bout_id).await.unwrap();

        // NOTE(review): despite the name, only the bout row's removal is checked;
        // the messages table is not inspected — confirm cascade semantics.
        let bouts = substrate.list_bouts(&fighter_id).await.unwrap();
        assert!(bouts.is_empty());
    }

    #[tokio::test]
    async fn test_list_and_delete_bouts() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = register_fighter(&substrate).await;

        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();
        assert_eq!(substrate.list_bouts(&fighter_id).await.unwrap().len(), 1);

        substrate.delete_bout(&bout_id).await.unwrap();
        assert!(substrate.list_bouts(&fighter_id).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_latest_bout_for_fighter_none_when_empty() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = register_fighter(&substrate).await;

        let latest = substrate
            .latest_bout_for_fighter(&fighter_id)
            .await
            .unwrap();
        assert!(latest.is_none());
    }

    #[tokio::test]
    async fn test_latest_bout_for_fighter_returns_a_bout() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = register_fighter(&substrate).await;

        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        let latest = substrate
            .latest_bout_for_fighter(&fighter_id)
            .await
            .unwrap();
        assert_eq!(latest, Some(bout_id));
    }

    #[tokio::test]
    async fn test_latest_bout_for_fighter_ignores_other_fighters() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_a = register_fighter(&substrate).await;
        let fighter_b = register_fighter(&substrate).await;

        let bout_a = substrate.create_bout(&fighter_a).await.unwrap();
        let _bout_b = substrate.create_bout(&fighter_b).await.unwrap();

        let latest = substrate.latest_bout_for_fighter(&fighter_a).await.unwrap();
        assert_eq!(latest, Some(bout_a));
    }
}