// punch_memory/bouts.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use uuid::Uuid;
4
5use punch_types::{FighterId, Message, PunchError, PunchResult, Role};
6use tracing::debug;
7
8use crate::MemorySubstrate;
9
10/// Unique identifier for a Bout (session / conversation).
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[serde(transparent)]
13pub struct BoutId(pub Uuid);
14
15impl BoutId {
16    pub fn new() -> Self {
17        Self(Uuid::new_v4())
18    }
19}
20
21impl Default for BoutId {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl std::fmt::Display for BoutId {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        write!(f, "{}", self.0)
30    }
31}
32
33/// Lightweight summary of a bout for listing purposes.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct BoutSummary {
36    pub id: BoutId,
37    pub fighter_id: FighterId,
38    pub title: Option<String>,
39    pub message_count: u64,
40    pub created_at: String,
41    pub updated_at: String,
42}
43
44impl MemorySubstrate {
45    /// Create a new bout for the given fighter and return its ID.
46    pub async fn create_bout(&self, fighter_id: &FighterId) -> PunchResult<BoutId> {
47        let bout_id = BoutId::new();
48        let bout_str = bout_id.to_string();
49        let fighter_str = fighter_id.to_string();
50
51        let conn = self.conn.lock().await;
52        conn.execute(
53            "INSERT INTO bouts (id, fighter_id) VALUES (?1, ?2)",
54            rusqlite::params![bout_str, fighter_str],
55        )
56        .map_err(|e| PunchError::Bout(format!("failed to create bout: {e}")))?;
57
58        debug!(bout_id = %bout_id, fighter_id = %fighter_id, "bout created");
59        Ok(bout_id)
60    }
61
62    /// Append a message to an existing bout.
63    pub async fn save_message(&self, bout_id: &BoutId, message: &Message) -> PunchResult<()> {
64        let bout_str = bout_id.to_string();
65        let role_str = message.role.to_string();
66
67        // Pack tool_calls and tool_results into a metadata JSON blob.
68        let metadata = if message.tool_calls.is_empty() && message.tool_results.is_empty() {
69            None
70        } else {
71            Some(serde_json::json!({
72                "tool_calls": message.tool_calls,
73                "tool_results": message.tool_results,
74            }))
75        };
76        let metadata_str = metadata.map(|m| m.to_string());
77        let ts = message.timestamp.format("%Y-%m-%dT%H:%M:%SZ").to_string();
78
79        let conn = self.conn.lock().await;
80        conn.execute(
81            "INSERT INTO messages (bout_id, role, content, metadata, created_at) VALUES (?1, ?2, ?3, ?4, ?5)",
82            rusqlite::params![bout_str, role_str, message.content, metadata_str, ts],
83        )
84        .map_err(|e| PunchError::Bout(format!("failed to save message: {e}")))?;
85
86        // Touch the bout's updated_at timestamp.
87        conn.execute(
88            "UPDATE bouts SET updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now') WHERE id = ?1",
89            [&bout_str],
90        )
91        .map_err(|e| PunchError::Bout(format!("failed to touch bout: {e}")))?;
92
93        Ok(())
94    }
95
96    /// Load all messages for a bout in chronological order.
97    pub async fn load_messages(&self, bout_id: &BoutId) -> PunchResult<Vec<Message>> {
98        let bout_str = bout_id.to_string();
99        let conn = self.conn.lock().await;
100
101        let mut stmt = conn
102            .prepare(
103                "SELECT role, content, metadata, created_at FROM messages WHERE bout_id = ?1 ORDER BY id",
104            )
105            .map_err(|e| PunchError::Bout(format!("failed to prepare message query: {e}")))?;
106
107        let rows = stmt
108            .query_map([&bout_str], |row| {
109                let role_str: String = row.get(0)?;
110                let content: String = row.get(1)?;
111                let metadata: Option<String> = row.get(2)?;
112                let created_at: String = row.get(3)?;
113                Ok((role_str, content, metadata, created_at))
114            })
115            .map_err(|e| PunchError::Bout(format!("failed to query messages: {e}")))?;
116
117        let mut messages = Vec::new();
118        for row in rows {
119            let (role_str, content, metadata, created_at) =
120                row.map_err(|e| PunchError::Bout(format!("failed to read message row: {e}")))?;
121
122            let role = parse_role(&role_str)?;
123            let timestamp = parse_timestamp(&created_at)?;
124
125            let (tool_calls, tool_results) = match metadata {
126                Some(json) => {
127                    let v: serde_json::Value = serde_json::from_str(&json)
128                        .map_err(|e| PunchError::Bout(format!("corrupt message metadata: {e}")))?;
129                    let tc = serde_json::from_value(
130                        v.get("tool_calls")
131                            .cloned()
132                            .unwrap_or(serde_json::Value::Array(vec![])),
133                    )
134                    .unwrap_or_default();
135                    let tr = serde_json::from_value(
136                        v.get("tool_results")
137                            .cloned()
138                            .unwrap_or(serde_json::Value::Array(vec![])),
139                    )
140                    .unwrap_or_default();
141                    (tc, tr)
142                }
143                None => (Vec::new(), Vec::new()),
144            };
145
146            messages.push(Message {
147                role,
148                content,
149                tool_calls,
150                tool_results,
151                timestamp,
152            });
153        }
154
155        Ok(messages)
156    }
157
158    /// List all bouts for a fighter, most recent first.
159    pub async fn list_bouts(&self, fighter_id: &FighterId) -> PunchResult<Vec<BoutSummary>> {
160        let fighter_str = fighter_id.to_string();
161        let conn = self.conn.lock().await;
162
163        let mut stmt = conn
164            .prepare(
165                "SELECT b.id, b.title, b.created_at, b.updated_at,
166                        (SELECT COUNT(*) FROM messages m WHERE m.bout_id = b.id)
167                 FROM bouts b
168                 WHERE b.fighter_id = ?1
169                 ORDER BY b.updated_at DESC",
170            )
171            .map_err(|e| PunchError::Bout(format!("failed to list bouts: {e}")))?;
172
173        let rows = stmt
174            .query_map([&fighter_str], |row| {
175                let id: String = row.get(0)?;
176                let title: Option<String> = row.get(1)?;
177                let created_at: String = row.get(2)?;
178                let updated_at: String = row.get(3)?;
179                let message_count: u64 = row.get(4)?;
180                Ok((id, title, created_at, updated_at, message_count))
181            })
182            .map_err(|e| PunchError::Bout(format!("failed to list bouts: {e}")))?;
183
184        let mut summaries = Vec::new();
185        for row in rows {
186            let (id, title, created_at, updated_at, message_count) =
187                row.map_err(|e| PunchError::Bout(format!("failed to read bout row: {e}")))?;
188
189            let bout_id = BoutId(
190                Uuid::parse_str(&id)
191                    .map_err(|e| PunchError::Bout(format!("invalid bout id: {e}")))?,
192            );
193
194            summaries.push(BoutSummary {
195                id: bout_id,
196                fighter_id: *fighter_id,
197                title,
198                message_count,
199                created_at,
200                updated_at,
201            });
202        }
203
204        Ok(summaries)
205    }
206
207    /// Delete a bout and all its messages (cascading).
208    pub async fn delete_bout(&self, bout_id: &BoutId) -> PunchResult<()> {
209        let bout_str = bout_id.to_string();
210        let conn = self.conn.lock().await;
211
212        conn.execute("DELETE FROM bouts WHERE id = ?1", [&bout_str])
213            .map_err(|e| PunchError::Bout(format!("failed to delete bout: {e}")))?;
214
215        debug!(bout_id = %bout_id, "bout deleted");
216        Ok(())
217    }
218}
219
220fn parse_role(s: &str) -> PunchResult<Role> {
221    match s {
222        "user" => Ok(Role::User),
223        "assistant" => Ok(Role::Assistant),
224        "system" => Ok(Role::System),
225        "tool" => Ok(Role::Tool),
226        other => Err(PunchError::Bout(format!("unknown role: {other}"))),
227    }
228}
229
230fn parse_timestamp(s: &str) -> PunchResult<DateTime<Utc>> {
231    DateTime::parse_from_rfc3339(s)
232        .map(|dt| dt.with_timezone(&Utc))
233        .or_else(|_| {
234            chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%SZ").map(|ndt| ndt.and_utc())
235        })
236        .map_err(|e| PunchError::Bout(format!("invalid timestamp '{s}': {e}")))
237}
238
#[cfg(test)]
mod tests {
    use punch_types::{
        FighterId, FighterManifest, FighterStatus, Message, ModelConfig, Provider, Role,
        WeightClass,
    };

    use crate::MemorySubstrate;

    /// Minimal manifest used to register a fighter before creating bouts.
    fn test_manifest() -> FighterManifest {
        FighterManifest {
            name: "Test Fighter".into(),
            description: "A test fighter".into(),
            model: ModelConfig {
                provider: Provider::Anthropic,
                model: "claude-sonnet-4-20250514".into(),
                api_key_env: None,
                base_url: None,
                max_tokens: Some(4096),
                temperature: Some(0.7),
            },
            system_prompt: "You are a test fighter.".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Middleweight,
            tenant_id: None,
        }
    }

    /// Open an in-memory substrate and register a single idle fighter in it.
    async fn substrate_with_fighter() -> (MemorySubstrate, FighterId) {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter_id = FighterId::new();
        substrate
            .save_fighter(&fighter_id, &test_manifest(), FighterStatus::Idle)
            .await
            .unwrap();
        (substrate, fighter_id)
    }

    #[tokio::test]
    async fn test_create_bout_and_messages() {
        let (substrate, fighter_id) = substrate_with_fighter().await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        let msg = Message::new(Role::User, "Hello, fighter!");
        substrate.save_message(&bout_id, &msg).await.unwrap();

        let messages = substrate.load_messages(&bout_id).await.unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content, "Hello, fighter!");
        assert_eq!(messages[0].role, Role::User);
    }

    #[tokio::test]
    async fn test_multiple_messages_in_bout() {
        let (substrate, fighter_id) = substrate_with_fighter().await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        substrate.save_message(&bout_id, &Message::new(Role::User, "Hello")).await.unwrap();
        substrate.save_message(&bout_id, &Message::new(Role::Assistant, "Hi there")).await.unwrap();
        substrate.save_message(&bout_id, &Message::new(Role::User, "How are you?")).await.unwrap();

        let messages = substrate.load_messages(&bout_id).await.unwrap();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].role, Role::User);
        assert_eq!(messages[1].role, Role::Assistant);
        assert_eq!(messages[2].content, "How are you?");
    }

    #[tokio::test]
    async fn test_load_messages_empty_bout() {
        let (substrate, fighter_id) = substrate_with_fighter().await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        let messages = substrate.load_messages(&bout_id).await.unwrap();
        assert!(messages.is_empty());
    }

    #[tokio::test]
    async fn test_multiple_bouts_for_fighter() {
        let (substrate, fighter_id) = substrate_with_fighter().await;

        substrate.create_bout(&fighter_id).await.unwrap();
        substrate.create_bout(&fighter_id).await.unwrap();
        substrate.create_bout(&fighter_id).await.unwrap();

        let bouts = substrate.list_bouts(&fighter_id).await.unwrap();
        assert_eq!(bouts.len(), 3);
    }

    #[tokio::test]
    async fn test_bout_summary_message_count() {
        let (substrate, fighter_id) = substrate_with_fighter().await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        substrate.save_message(&bout_id, &Message::new(Role::User, "a")).await.unwrap();
        substrate.save_message(&bout_id, &Message::new(Role::Assistant, "b")).await.unwrap();

        let bouts = substrate.list_bouts(&fighter_id).await.unwrap();
        assert_eq!(bouts[0].message_count, 2);
    }

    // Purely synchronous, so no async runtime is needed for this test.
    #[test]
    fn test_bout_id_display() {
        let bout_id = super::BoutId::new();
        let s = bout_id.to_string();
        assert!(!s.is_empty());
        // Should be a valid UUID string
        assert!(uuid::Uuid::parse_str(&s).is_ok());
    }

    #[tokio::test]
    async fn test_delete_bout_cascades_messages() {
        let (substrate, fighter_id) = substrate_with_fighter().await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();
        substrate.save_message(&bout_id, &Message::new(Role::User, "msg")).await.unwrap();

        substrate.delete_bout(&bout_id).await.unwrap();

        let bouts = substrate.list_bouts(&fighter_id).await.unwrap();
        assert!(bouts.is_empty());

        // The point of this test: the bout's messages must be gone too, not
        // just the bout row itself.
        let orphans = substrate.load_messages(&bout_id).await.unwrap();
        assert!(
            orphans.is_empty(),
            "messages should be removed together with their bout"
        );
    }

    #[tokio::test]
    async fn test_list_and_delete_bouts() {
        let (substrate, fighter_id) = substrate_with_fighter().await;
        let bout_id = substrate.create_bout(&fighter_id).await.unwrap();

        let bouts = substrate.list_bouts(&fighter_id).await.unwrap();
        assert_eq!(bouts.len(), 1);

        substrate.delete_bout(&bout_id).await.unwrap();

        let bouts = substrate.list_bouts(&fighter_id).await.unwrap();
        assert!(bouts.is_empty());
    }
}