Skip to main content

punch_memory/
bouts.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use uuid::Uuid;
4
5use punch_types::{FighterId, Message, PunchError, PunchResult, Role};
6use tracing::debug;
7
8use crate::MemorySubstrate;
9
10/// Unique identifier for a Bout (session / conversation).
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[serde(transparent)]
13pub struct BoutId(pub Uuid);
14
15impl BoutId {
16    pub fn new() -> Self {
17        Self(Uuid::new_v4())
18    }
19}
20
21impl Default for BoutId {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl std::fmt::Display for BoutId {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        write!(f, "{}", self.0)
30    }
31}
32
33/// Lightweight summary of a bout for listing purposes.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct BoutSummary {
36    pub id: BoutId,
37    pub fighter_id: FighterId,
38    pub title: Option<String>,
39    pub message_count: u64,
40    pub created_at: String,
41    pub updated_at: String,
42}
43
44impl MemorySubstrate {
45    /// Create a new bout for the given fighter and return its ID.
46    pub async fn create_bout(&self, fighter_id: &FighterId) -> PunchResult<BoutId> {
47        let bout_id = BoutId::new();
48        let bout_str = bout_id.to_string();
49        let fighter_str = fighter_id.to_string();
50
51        let conn = self.conn.lock().await;
52        conn.execute(
53            "INSERT INTO bouts (id, fighter_id) VALUES (?1, ?2)",
54            rusqlite::params![bout_str, fighter_str],
55        )
56        .map_err(|e| PunchError::Bout(format!("failed to create bout: {e}")))?;
57
58        debug!(bout_id = %bout_id, fighter_id = %fighter_id, "bout created");
59        Ok(bout_id)
60    }
61
62    /// Append a message to an existing bout.
63    pub async fn save_message(&self, bout_id: &BoutId, message: &Message) -> PunchResult<()> {
64        let bout_str = bout_id.to_string();
65        let role_str = message.role.to_string();
66
67        // Pack tool_calls and tool_results into a metadata JSON blob.
68        let metadata = if message.tool_calls.is_empty() && message.tool_results.is_empty() {
69            None
70        } else {
71            Some(serde_json::json!({
72                "tool_calls": message.tool_calls,
73                "tool_results": message.tool_results,
74            }))
75        };
76        let metadata_str = metadata.map(|m| m.to_string());
77        let ts = message.timestamp.format("%Y-%m-%dT%H:%M:%SZ").to_string();
78
79        let conn = self.conn.lock().await;
80        conn.execute(
81            "INSERT INTO messages (bout_id, role, content, metadata, created_at) VALUES (?1, ?2, ?3, ?4, ?5)",
82            rusqlite::params![bout_str, role_str, message.content, metadata_str, ts],
83        )
84        .map_err(|e| PunchError::Bout(format!("failed to save message: {e}")))?;
85
86        // Touch the bout's updated_at timestamp.
87        conn.execute(
88            "UPDATE bouts SET updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now') WHERE id = ?1",
89            [&bout_str],
90        )
91        .map_err(|e| PunchError::Bout(format!("failed to touch bout: {e}")))?;
92
93        Ok(())
94    }
95
96    /// Load all messages for a bout in chronological order.
97    pub async fn load_messages(&self, bout_id: &BoutId) -> PunchResult<Vec<Message>> {
98        let bout_str = bout_id.to_string();
99        let conn = self.conn.lock().await;
100
101        let mut stmt = conn
102            .prepare(
103                "SELECT role, content, metadata, created_at FROM messages WHERE bout_id = ?1 ORDER BY id",
104            )
105            .map_err(|e| PunchError::Bout(format!("failed to prepare message query: {e}")))?;
106
107        let rows = stmt
108            .query_map([&bout_str], |row| {
109                let role_str: String = row.get(0)?;
110                let content: String = row.get(1)?;
111                let metadata: Option<String> = row.get(2)?;
112                let created_at: String = row.get(3)?;
113                Ok((role_str, content, metadata, created_at))
114            })
115            .map_err(|e| PunchError::Bout(format!("failed to query messages: {e}")))?;
116
117        let mut messages = Vec::new();
118        for row in rows {
119            let (role_str, content, metadata, created_at) =
120                row.map_err(|e| PunchError::Bout(format!("failed to read message row: {e}")))?;
121
122            let role = parse_role(&role_str)?;
123            let timestamp = parse_timestamp(&created_at)?;
124
125            let (tool_calls, tool_results) = match metadata {
126                Some(json) => {
127                    let v: serde_json::Value = serde_json::from_str(&json)
128                        .map_err(|e| PunchError::Bout(format!("corrupt message metadata: {e}")))?;
129                    let tc = serde_json::from_value(
130                        v.get("tool_calls")
131                            .cloned()
132                            .unwrap_or(serde_json::Value::Array(vec![])),
133                    )
134                    .unwrap_or_default();
135                    let tr = serde_json::from_value(
136                        v.get("tool_results")
137                            .cloned()
138                            .unwrap_or(serde_json::Value::Array(vec![])),
139                    )
140                    .unwrap_or_default();
141                    (tc, tr)
142                }
143                None => (Vec::new(), Vec::new()),
144            };
145
146            messages.push(Message {
147                role,
148                content,
149                tool_calls,
150                tool_results,
151                timestamp,
152                content_parts: Vec::new(),
153            });
154        }
155
156        Ok(messages)
157    }
158
159    /// List all bouts for a fighter, most recent first.
160    pub async fn list_bouts(&self, fighter_id: &FighterId) -> PunchResult<Vec<BoutSummary>> {
161        let fighter_str = fighter_id.to_string();
162        let conn = self.conn.lock().await;
163
164        let mut stmt = conn
165            .prepare(
166                "SELECT b.id, b.title, b.created_at, b.updated_at,
167                        (SELECT COUNT(*) FROM messages m WHERE m.bout_id = b.id)
168                 FROM bouts b
169                 WHERE b.fighter_id = ?1
170                 ORDER BY b.updated_at DESC",
171            )
172            .map_err(|e| PunchError::Bout(format!("failed to list bouts: {e}")))?;
173
174        let rows = stmt
175            .query_map([&fighter_str], |row| {
176                let id: String = row.get(0)?;
177                let title: Option<String> = row.get(1)?;
178                let created_at: String = row.get(2)?;
179                let updated_at: String = row.get(3)?;
180                let message_count: u64 = row.get(4)?;
181                Ok((id, title, created_at, updated_at, message_count))
182            })
183            .map_err(|e| PunchError::Bout(format!("failed to list bouts: {e}")))?;
184
185        let mut summaries = Vec::new();
186        for row in rows {
187            let (id, title, created_at, updated_at, message_count) =
188                row.map_err(|e| PunchError::Bout(format!("failed to read bout row: {e}")))?;
189
190            let bout_id = BoutId(
191                Uuid::parse_str(&id)
192                    .map_err(|e| PunchError::Bout(format!("invalid bout id: {e}")))?,
193            );
194
195            summaries.push(BoutSummary {
196                id: bout_id,
197                fighter_id: *fighter_id,
198                title,
199                message_count,
200                created_at,
201                updated_at,
202            });
203        }
204
205        Ok(summaries)
206    }
207
208    /// Return the most recent bout for a fighter, if any exists.
209    ///
210    /// This is used to restore conversation continuity across daemon restarts:
211    /// when a user messages a fighter whose `current_bout` was lost from memory,
212    /// we look up the latest bout from the database instead of creating a new one.
213    pub async fn latest_bout_for_fighter(
214        &self,
215        fighter_id: &FighterId,
216    ) -> PunchResult<Option<BoutId>> {
217        let fighter_str = fighter_id.to_string();
218        let conn = self.conn.lock().await;
219
220        let result: Option<String> = conn
221            .query_row(
222                "SELECT id FROM bouts WHERE fighter_id = ?1 ORDER BY updated_at DESC LIMIT 1",
223                [&fighter_str],
224                |row| row.get(0),
225            )
226            .ok();
227
228        match result {
229            Some(id_str) => {
230                let uuid = Uuid::parse_str(&id_str)
231                    .map_err(|e| PunchError::Bout(format!("invalid bout id: {e}")))?;
232                debug!(bout_id = %id_str, fighter_id = %fighter_id, "restored latest bout from database");
233                Ok(Some(BoutId(uuid)))
234            }
235            None => Ok(None),
236        }
237    }
238
239    /// Delete a bout and all its messages (cascading).
240    pub async fn delete_bout(&self, bout_id: &BoutId) -> PunchResult<()> {
241        let bout_str = bout_id.to_string();
242        let conn = self.conn.lock().await;
243
244        conn.execute("DELETE FROM bouts WHERE id = ?1", [&bout_str])
245            .map_err(|e| PunchError::Bout(format!("failed to delete bout: {e}")))?;
246
247        debug!(bout_id = %bout_id, "bout deleted");
248        Ok(())
249    }
250}
251
252fn parse_role(s: &str) -> PunchResult<Role> {
253    match s {
254        "user" => Ok(Role::User),
255        "assistant" => Ok(Role::Assistant),
256        "system" => Ok(Role::System),
257        "tool" => Ok(Role::Tool),
258        other => Err(PunchError::Bout(format!("unknown role: {other}"))),
259    }
260}
261
262fn parse_timestamp(s: &str) -> PunchResult<DateTime<Utc>> {
263    DateTime::parse_from_rfc3339(s)
264        .map(|dt| dt.with_timezone(&Utc))
265        .or_else(|_| {
266            chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%SZ").map(|ndt| ndt.and_utc())
267        })
268        .map_err(|e| PunchError::Bout(format!("invalid timestamp '{s}': {e}")))
269}
270
#[cfg(test)]
mod tests {
    use punch_types::{
        FighterManifest, FighterStatus, Message, ModelConfig, Provider, Role, WeightClass,
    };

    use crate::MemorySubstrate;

    /// Minimal manifest shared by every fighter registered in these tests.
    fn test_manifest() -> FighterManifest {
        let model = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".into(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };
        FighterManifest {
            name: "Test Fighter".into(),
            description: "A test fighter".into(),
            model,
            system_prompt: "You are a test fighter.".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Middleweight,
            tenant_id: None,
        }
    }

    /// Register a new idle fighter in `substrate` and return its id.
    async fn register_fighter(substrate: &MemorySubstrate) -> punch_types::FighterId {
        let fighter_id = punch_types::FighterId::new();
        substrate
            .save_fighter(&fighter_id, &test_manifest(), FighterStatus::Idle)
            .await
            .unwrap();
        fighter_id
    }

    /// Fresh in-memory substrate holding exactly one registered fighter.
    async fn substrate_with_fighter() -> (MemorySubstrate, punch_types::FighterId) {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = register_fighter(&substrate).await;
        (substrate, fighter_id)
    }

    /// Append one plain-text message with `role` to the given bout.
    async fn say(substrate: &MemorySubstrate, bout_id: &super::BoutId, role: Role, text: &str) {
        substrate
            .save_message(bout_id, &Message::new(role, text))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_create_bout_and_messages() {
        let (substrate, fighter_id) = substrate_with_fighter().await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        say(&substrate, &bout_id, Role::User, "Hello, fighter!").await;

        let loaded = substrate.load_messages(&bout_id).await.unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].content, "Hello, fighter!");
        assert_eq!(loaded[0].role, Role::User);
    }

    #[tokio::test]
    async fn test_multiple_messages_in_bout() {
        let (substrate, fighter_id) = substrate_with_fighter().await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        say(&substrate, &bout_id, Role::User, "Hello").await;
        say(&substrate, &bout_id, Role::Assistant, "Hi there").await;
        say(&substrate, &bout_id, Role::User, "How are you?").await;

        let loaded = substrate.load_messages(&bout_id).await.unwrap();
        assert_eq!(loaded.len(), 3);
        assert_eq!(loaded[0].role, Role::User);
        assert_eq!(loaded[1].role, Role::Assistant);
        assert_eq!(loaded[2].content, "How are you?");
    }

    #[tokio::test]
    async fn test_load_messages_empty_bout() {
        let (substrate, fighter_id) = substrate_with_fighter().await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        let loaded = substrate.load_messages(&bout_id).await.unwrap();
        assert!(loaded.is_empty());
    }

    #[tokio::test]
    async fn test_multiple_bouts_for_fighter() {
        let (substrate, fighter_id) = substrate_with_fighter().await;

        for _ in 0..3 {
            substrate.create_bout(&fighter_id).await.unwrap();
        }

        let listed = substrate.list_bouts(&fighter_id).await.unwrap();
        assert_eq!(listed.len(), 3);
    }

    #[tokio::test]
    async fn test_bout_summary_message_count() {
        let (substrate, fighter_id) = substrate_with_fighter().await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        say(&substrate, &bout_id, Role::User, "a").await;
        say(&substrate, &bout_id, Role::Assistant, "b").await;

        let listed = substrate.list_bouts(&fighter_id).await.unwrap();
        assert_eq!(listed[0].message_count, 2);
    }

    #[tokio::test]
    async fn test_bout_id_display() {
        let rendered = super::BoutId::new().to_string();
        assert!(!rendered.is_empty());
        // The rendering must round-trip as a UUID.
        assert!(uuid::Uuid::parse_str(&rendered).is_ok());
    }

    #[tokio::test]
    async fn test_delete_bout_cascades_messages() {
        let (substrate, fighter_id) = substrate_with_fighter().await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();
        say(&substrate, &bout_id, Role::User, "msg").await;

        substrate.delete_bout(&bout_id).await.unwrap();

        let listed = substrate.list_bouts(&fighter_id).await.unwrap();
        assert!(listed.is_empty());
    }

    #[tokio::test]
    async fn test_list_and_delete_bouts() {
        let (substrate, fighter_id) = substrate_with_fighter().await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        let before = substrate.list_bouts(&fighter_id).await.unwrap();
        assert_eq!(before.len(), 1);

        substrate.delete_bout(&bout_id).await.unwrap();

        let after = substrate.list_bouts(&fighter_id).await.unwrap();
        assert!(after.is_empty());
    }

    #[tokio::test]
    async fn test_latest_bout_for_fighter_none_when_empty() {
        let (substrate, fighter_id) = substrate_with_fighter().await;

        // A fighter with no bouts has nothing to restore.
        let latest = substrate
            .latest_bout_for_fighter(&fighter_id)
            .await
            .unwrap();
        assert!(latest.is_none());
    }

    #[tokio::test]
    async fn test_latest_bout_for_fighter_returns_a_bout() {
        let (substrate, fighter_id) = substrate_with_fighter().await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        let latest = substrate
            .latest_bout_for_fighter(&fighter_id)
            .await
            .unwrap();
        assert_eq!(latest, Some(bout_id));
    }

    #[tokio::test]
    async fn test_latest_bout_for_fighter_ignores_other_fighters() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_a = register_fighter(&substrate).await;
        let fighter_b = register_fighter(&substrate).await;

        let bout_a = substrate.create_bout(&fighter_a).await.unwrap();
        let _bout_b = substrate.create_bout(&fighter_b).await.unwrap();

        let latest = substrate.latest_bout_for_fighter(&fighter_a).await.unwrap();
        assert_eq!(latest, Some(bout_a));
    }
}