Skip to main content

pi/
session_sqlite.rs

1use crate::agent_cx::AgentCx;
2use crate::error::{Error, Result};
3use crate::session::{SessionEntry, SessionHeader};
4use crate::session_metrics;
5use asupersync::Outcome;
6use asupersync::database::{SqliteConnection, SqliteError, SqliteRow, SqliteValue};
7use std::path::Path;
8
/// Connection settings and schema applied by [`save_session`] before it writes.
///
/// Switches the database to WAL journaling with `synchronous = NORMAL`, enables
/// foreign-key enforcement, and creates the three session tables if absent:
///
/// - `pi_session_header`: the serialized [`SessionHeader`], keyed by session id.
/// - `pi_session_entries`: one serialized [`SessionEntry`] per row, ordered by `seq`.
/// - `pi_session_meta`: key/value summary rows (`message_count`, `name`) that let
///   [`load_session_meta`] skip decoding every entry.
///
/// The tables use `CREATE TABLE IF NOT EXISTS`, so the batch can be re-applied to
/// an existing database.
const INIT_SQL: &str = r"
PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;
PRAGMA foreign_keys = ON;

CREATE TABLE IF NOT EXISTS pi_session_header (
  id TEXT PRIMARY KEY,
  json TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS pi_session_entries (
  seq INTEGER PRIMARY KEY,
  json TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS pi_session_meta (
  key TEXT PRIMARY KEY,
  value TEXT NOT NULL
);
";
29
/// Lightweight summary of a SQLite-backed session, as returned by
/// [`load_session_meta`].
#[derive(Debug, Clone)]
pub struct SqliteSessionMeta {
    /// Session header decoded from the `pi_session_header` table.
    pub header: SessionHeader,
    /// Number of `SessionEntry::Message` entries. It is read from the
    /// `message_count` meta row when that row is present and parses as `u64`;
    /// otherwise it is counted by decoding every stored entry.
    pub message_count: u64,
    /// Session name from the `name` meta row. When that row is absent and the
    /// entries had to be scanned (see `message_count`), the name from the most
    /// recent `SessionInfo` entry that carries one is used instead.
    pub name: Option<String>,
}
36
37fn map_outcome<T>(outcome: Outcome<T, SqliteError>) -> Result<T> {
38    match outcome {
39        Outcome::Ok(value) => Ok(value),
40        Outcome::Err(err) => Err(Error::session(format!("SQLite session error: {err}"))),
41        Outcome::Cancelled(_) => Err(Error::Aborted),
42        Outcome::Panicked(payload) => Err(Error::session(format!(
43            "SQLite session operation panicked: {payload:?}"
44        ))),
45    }
46}
47
48fn row_get_str<'a>(row: &'a SqliteRow, column: &str) -> Result<&'a str> {
49    row.get_str(column)
50        .map_err(|err| Error::session(format!("SQLite row read failed: {err}")))
51}
52
53fn compute_message_count_and_name(entries: &[SessionEntry]) -> (u64, Option<String>) {
54    let mut message_count = 0u64;
55    let mut name = None;
56
57    for entry in entries {
58        match entry {
59            SessionEntry::Message(_) => message_count += 1,
60            SessionEntry::SessionInfo(info) => {
61                if info.name.is_some() {
62                    name.clone_from(&info.name);
63                }
64            }
65            _ => {}
66        }
67    }
68
69    (message_count, name)
70}
71
72pub async fn load_session(path: &Path) -> Result<(SessionHeader, Vec<SessionEntry>)> {
73    let metrics = session_metrics::global();
74    let _timer = metrics.start_timer(&metrics.sqlite_load);
75
76    if !path.exists() {
77        return Err(Error::SessionNotFound {
78            path: path.display().to_string(),
79        });
80    }
81
82    let cx = AgentCx::for_request();
83    let conn = map_outcome(SqliteConnection::open(cx.cx(), path).await)?;
84
85    let header_rows = map_outcome(
86        conn.query(cx.cx(), "SELECT json FROM pi_session_header LIMIT 1", &[])
87            .await,
88    )?;
89    let header_row = header_rows
90        .first()
91        .ok_or_else(|| Error::session("SQLite session missing header row"))?;
92    let header_json = row_get_str(header_row, "json")?;
93    let header: SessionHeader = serde_json::from_str(header_json)?;
94
95    let entry_rows = map_outcome(
96        conn.query(
97            cx.cx(),
98            "SELECT json FROM pi_session_entries ORDER BY seq ASC",
99            &[],
100        )
101        .await,
102    )?;
103
104    let mut entries = Vec::with_capacity(entry_rows.len());
105    for row in entry_rows {
106        let json = row_get_str(&row, "json")?;
107        let entry: SessionEntry = serde_json::from_str(json)?;
108        entries.push(entry);
109    }
110
111    Ok((header, entries))
112}
113
114pub async fn load_session_meta(path: &Path) -> Result<SqliteSessionMeta> {
115    let metrics = session_metrics::global();
116    let _timer = metrics.start_timer(&metrics.sqlite_load_meta);
117
118    if !path.exists() {
119        return Err(Error::SessionNotFound {
120            path: path.display().to_string(),
121        });
122    }
123
124    let cx = AgentCx::for_request();
125    let conn = map_outcome(SqliteConnection::open(cx.cx(), path).await)?;
126
127    let header_rows = map_outcome(
128        conn.query(cx.cx(), "SELECT json FROM pi_session_header LIMIT 1", &[])
129            .await,
130    )?;
131    let header_row = header_rows
132        .first()
133        .ok_or_else(|| Error::session("SQLite session missing header row"))?;
134    let header_json = row_get_str(header_row, "json")?;
135    let header: SessionHeader = serde_json::from_str(header_json)?;
136
137    let meta_rows = map_outcome(
138        conn.query(
139            cx.cx(),
140            "SELECT key,value FROM pi_session_meta WHERE key IN ('message_count','name')",
141            &[],
142        )
143        .await,
144    )?;
145
146    let mut message_count: Option<u64> = None;
147    let mut name: Option<String> = None;
148    for row in meta_rows {
149        let key = row_get_str(&row, "key")?;
150        let value = row_get_str(&row, "value")?;
151        match key {
152            "message_count" => message_count = value.parse::<u64>().ok(),
153            "name" => name = Some(value.to_string()),
154            _ => {}
155        }
156    }
157
158    let message_count = if let Some(message_count) = message_count {
159        message_count
160    } else {
161        let entry_rows = map_outcome(
162            conn.query(
163                cx.cx(),
164                "SELECT json FROM pi_session_entries ORDER BY seq ASC",
165                &[],
166            )
167            .await,
168        )?;
169
170        let mut entries = Vec::with_capacity(entry_rows.len());
171        for row in entry_rows {
172            let json = row_get_str(&row, "json")?;
173            let entry: SessionEntry = serde_json::from_str(json)?;
174            entries.push(entry);
175        }
176
177        let (message_count, fallback_name) = compute_message_count_and_name(&entries);
178        if name.is_none() {
179            name = fallback_name;
180        }
181        message_count
182    };
183    Ok(SqliteSessionMeta {
184        header,
185        message_count,
186        name,
187    })
188}
189
#[cfg(test)]
// `save_session` and `append_entries` are defined after this module, which
// clippy's `items_after_test_module` lint would otherwise flag.
#[allow(clippy::items_after_test_module)]
mod tests {
    use super::*;
    use crate::model::UserContent;
    use crate::session::{
        BranchSummaryEntry, CompactionEntry, CustomEntry, EntryBase, LabelEntry, MessageEntry,
        ModelChangeEntry, SessionInfoEntry, SessionMessage, ThinkingLevelChangeEntry,
    };
    use asupersync::types::{CancelKind, PanicPayload};

    // -- fixtures --

    /// Entry metadata shared by every fixture entry.
    fn dummy_base() -> EntryBase {
        EntryBase {
            id: Some("test-id".to_string()),
            parent_id: None,
            timestamp: "2026-01-01T00:00:00.000Z".to_string(),
        }
    }

    /// A user message entry whose text content is `"hello"`.
    fn message_entry() -> SessionEntry {
        let message = SessionMessage::User {
            content: UserContent::Text("hello".to_string()),
            timestamp: None,
        };
        SessionEntry::Message(MessageEntry {
            base: dummy_base(),
            message,
        })
    }

    /// A session-info entry carrying the given optional name.
    fn session_info_entry(name: Option<String>) -> SessionEntry {
        SessionEntry::SessionInfo(SessionInfoEntry {
            base: dummy_base(),
            name,
        })
    }

    /// A model-change entry for `provider` / `model_id`.
    fn model_change_entry(provider: &str, model_id: &str) -> SessionEntry {
        SessionEntry::ModelChange(ModelChangeEntry {
            base: dummy_base(),
            provider: provider.to_string(),
            model_id: model_id.to_string(),
        })
    }

    /// A label entry pointing at `target_id`.
    fn label_entry(target_id: &str, label: Option<&str>) -> SessionEntry {
        SessionEntry::Label(LabelEntry {
            base: dummy_base(),
            target_id: target_id.to_string(),
            label: label.map(str::to_string),
        })
    }

    /// A custom entry of type `custom_type` with optional JSON payload.
    fn custom_entry(custom_type: &str, data: Option<serde_json::Value>) -> SessionEntry {
        SessionEntry::Custom(CustomEntry {
            base: dummy_base(),
            custom_type: custom_type.to_string(),
            data,
        })
    }

    // -- compute_message_count_and_name --

    #[test]
    fn compute_counts_empty() {
        assert_eq!(compute_message_count_and_name(&[]), (0, None));
    }

    #[test]
    fn compute_counts_messages_only() {
        let entries: Vec<SessionEntry> = std::iter::repeat_with(message_entry).take(3).collect();
        assert_eq!(compute_message_count_and_name(&entries), (3, None));
    }

    #[test]
    fn compute_counts_session_info_with_name() {
        let entries = [
            message_entry(),
            session_info_entry(Some("My Session".to_string())),
            message_entry(),
        ];
        let (messages, found_name) = compute_message_count_and_name(&entries);
        assert_eq!(messages, 2);
        assert_eq!(found_name.as_deref(), Some("My Session"));
    }

    #[test]
    fn compute_counts_session_info_none_name_ignored() {
        let entries = [
            session_info_entry(Some("First".to_string())),
            session_info_entry(None),
            message_entry(),
        ];
        let (messages, found_name) = compute_message_count_and_name(&entries);
        assert_eq!(messages, 1);
        // A later SessionInfo without a name must not clear the earlier one.
        assert_eq!(found_name.as_deref(), Some("First"));
    }

    #[test]
    fn compute_counts_latest_name_wins() {
        let entries = [
            session_info_entry(Some("First".to_string())),
            session_info_entry(Some("Second".to_string())),
        ];
        let (_, found_name) = compute_message_count_and_name(&entries);
        assert_eq!(found_name.as_deref(), Some("Second"));
    }

    // -- Non-message / non-session-info entries are ignored --

    #[test]
    fn compute_counts_ignores_model_change_entries() {
        let entries = [
            message_entry(),
            model_change_entry("anthropic", "claude-sonnet-4-5"),
            message_entry(),
        ];
        assert_eq!(compute_message_count_and_name(&entries), (2, None));
    }

    #[test]
    fn compute_counts_ignores_label_entries() {
        let entries = [message_entry(), label_entry("some-id", Some("important"))];
        assert_eq!(compute_message_count_and_name(&entries), (1, None));
    }

    #[test]
    fn compute_counts_ignores_custom_entries() {
        let entries = [
            custom_entry("my_custom", Some(serde_json::json!({"key": "value"}))),
            message_entry(),
        ];
        assert_eq!(compute_message_count_and_name(&entries), (1, None));
    }

    #[test]
    fn compute_counts_ignores_compaction_entries() {
        let compaction = SessionEntry::Compaction(CompactionEntry {
            base: dummy_base(),
            summary: "summary text".to_string(),
            first_kept_entry_id: "e1".to_string(),
            tokens_before: 500,
            details: None,
            from_hook: None,
        });
        let entries = [
            message_entry(),
            compaction,
            message_entry(),
            message_entry(),
        ];
        assert_eq!(compute_message_count_and_name(&entries), (3, None));
    }

    #[test]
    fn compute_counts_mixed_entry_types() {
        let compaction = SessionEntry::Compaction(CompactionEntry {
            base: dummy_base(),
            summary: "s".to_string(),
            first_kept_entry_id: "e1".to_string(),
            tokens_before: 100,
            details: None,
            from_hook: None,
        });
        let entries = [
            message_entry(),
            model_change_entry("openai", "gpt-4"),
            session_info_entry(Some("Named".to_string())),
            label_entry("t1", None),
            message_entry(),
            compaction,
            custom_entry("ct", None),
            message_entry(),
        ];
        let (messages, found_name) = compute_message_count_and_name(&entries);
        assert_eq!(messages, 3);
        assert_eq!(found_name.as_deref(), Some("Named"));
    }

    // -- map_outcome tests --

    #[test]
    fn map_outcome_ok() {
        let outcome: Outcome<i32, SqliteError> = Outcome::Ok(42);
        assert!(matches!(map_outcome(outcome), Ok(42)));
    }

    #[test]
    fn map_outcome_err() {
        let outcome: Outcome<i32, SqliteError> = Outcome::Err(SqliteError::ConnectionClosed);
        match map_outcome(outcome) {
            Err(Error::Session(message)) => assert!(message.contains("SQLite session error")),
            other => panic!("expected Session error, got {other:?}"),
        }
    }

    #[test]
    fn map_outcome_cancelled() {
        let reason = asupersync::CancelReason::new(CancelKind::User);
        let outcome: Outcome<i32, SqliteError> = Outcome::Cancelled(reason);
        assert!(matches!(map_outcome(outcome), Err(Error::Aborted)));
    }

    #[test]
    fn map_outcome_panicked() {
        let payload = PanicPayload::new("test panic");
        let outcome: Outcome<i32, SqliteError> = Outcome::Panicked(payload);
        match map_outcome(outcome) {
            Err(Error::Session(message)) => assert!(message.contains("panicked")),
            other => panic!("expected Session error, got {other:?}"),
        }
    }

    // -- SqliteSessionMeta struct --

    #[test]
    fn sqlite_session_meta_fields() {
        let header = SessionHeader {
            id: "test-session".to_string(),
            ..SessionHeader::default()
        };
        let meta = SqliteSessionMeta {
            header,
            message_count: 42,
            name: Some("My Session".to_string()),
        };
        assert_eq!(meta.header.id, "test-session");
        assert_eq!(meta.message_count, 42);
        assert_eq!(meta.name.as_deref(), Some("My Session"));
    }

    #[test]
    fn sqlite_session_meta_no_name() {
        let meta = SqliteSessionMeta {
            header: SessionHeader::default(),
            message_count: 0,
            name: None,
        };
        assert_eq!(meta.message_count, 0);
        assert_eq!(meta.name, None);
    }

    // -- compute_message_count_and_name: large input --

    #[test]
    fn compute_counts_large_message_set() {
        let entries: Vec<SessionEntry> =
            std::iter::repeat_with(message_entry).take(1000).collect();
        assert_eq!(compute_message_count_and_name(&entries), (1000, None));
    }

    // -- compute_message_count_and_name: name then messages only --

    #[test]
    fn compute_counts_name_set_early_persists() {
        let entries = [
            session_info_entry(Some("Early Name".to_string())),
            message_entry(),
            message_entry(),
            message_entry(),
        ];
        let (messages, found_name) = compute_message_count_and_name(&entries);
        assert_eq!(messages, 3);
        assert_eq!(found_name.as_deref(), Some("Early Name"));
    }

    // -- compute_message_count_and_name: branch summary entry --

    #[test]
    fn compute_counts_ignores_branch_summary() {
        let branch_summary = SessionEntry::BranchSummary(BranchSummaryEntry {
            base: dummy_base(),
            from_id: "parent-id".to_string(),
            summary: "branch summary".to_string(),
            details: None,
            from_hook: None,
        });
        let entries = [message_entry(), branch_summary];
        assert_eq!(compute_message_count_and_name(&entries), (1, None));
    }

    // -- compute_message_count_and_name: thinking level change --

    #[test]
    fn compute_counts_ignores_thinking_level_change() {
        let thinking = SessionEntry::ThinkingLevelChange(ThinkingLevelChangeEntry {
            base: dummy_base(),
            thinking_level: "high".to_string(),
        });
        let entries = [thinking, message_entry()];
        assert_eq!(compute_message_count_and_name(&entries), (1, None));
    }
}
516
517pub async fn save_session(
518    path: &Path,
519    header: &SessionHeader,
520    entries: &[SessionEntry],
521) -> Result<()> {
522    let metrics = session_metrics::global();
523    let _save_timer = metrics.start_timer(&metrics.sqlite_save);
524
525    if let Some(parent) = path.parent() {
526        asupersync::fs::create_dir_all(parent).await?;
527    }
528
529    let cx = AgentCx::for_request();
530    let conn = map_outcome(SqliteConnection::open(cx.cx(), path).await)?;
531    map_outcome(conn.execute_batch(cx.cx(), INIT_SQL).await)?;
532
533    let tx = map_outcome(conn.begin_immediate(cx.cx()).await)?;
534
535    map_outcome(
536        tx.execute(cx.cx(), "DELETE FROM pi_session_entries", &[])
537            .await,
538    )?;
539    map_outcome(
540        tx.execute(cx.cx(), "DELETE FROM pi_session_header", &[])
541            .await,
542    )?;
543    map_outcome(
544        tx.execute(cx.cx(), "DELETE FROM pi_session_meta", &[])
545            .await,
546    )?;
547
548    // Serialize header + entries and track serialization time + bytes.
549    let serialize_timer = metrics.start_timer(&metrics.sqlite_serialize);
550    let header_json = serde_json::to_string(header)?;
551    let mut total_json_bytes = header_json.len() as u64;
552
553    let mut entry_jsons = Vec::with_capacity(entries.len());
554    for entry in entries {
555        let json = serde_json::to_string(entry)?;
556        total_json_bytes += json.len() as u64;
557        entry_jsons.push(json);
558    }
559    serialize_timer.finish();
560    metrics.record_bytes(&metrics.sqlite_bytes, total_json_bytes);
561
562    map_outcome(
563        tx.execute(
564            cx.cx(),
565            "INSERT INTO pi_session_header (id,json) VALUES (?1,?2)",
566            &[
567                SqliteValue::Text(header.id.clone()),
568                SqliteValue::Text(header_json),
569            ],
570        )
571        .await,
572    )?;
573
574    for (idx, json) in entry_jsons.into_iter().enumerate() {
575        map_outcome(
576            tx.execute(
577                cx.cx(),
578                "INSERT INTO pi_session_entries (seq,json) VALUES (?1,?2)",
579                &[
580                    SqliteValue::Integer(i64::try_from(idx + 1).unwrap_or(i64::MAX)),
581                    SqliteValue::Text(json),
582                ],
583            )
584            .await,
585        )?;
586    }
587
588    let (message_count, name) = compute_message_count_and_name(entries);
589    map_outcome(
590        tx.execute(
591            cx.cx(),
592            "INSERT INTO pi_session_meta (key,value) VALUES (?1,?2)",
593            &[
594                SqliteValue::Text("message_count".to_string()),
595                SqliteValue::Text(message_count.to_string()),
596            ],
597        )
598        .await,
599    )?;
600    if let Some(name) = name {
601        map_outcome(
602            tx.execute(
603                cx.cx(),
604                "INSERT INTO pi_session_meta (key,value) VALUES (?1,?2)",
605                &[
606                    SqliteValue::Text("name".to_string()),
607                    SqliteValue::Text(name),
608                ],
609            )
610            .await,
611        )?;
612    }
613
614    map_outcome(tx.commit(cx.cx()).await)?;
615    Ok(())
616}
617
618/// Incrementally append new entries to an existing SQLite session database.
619///
620/// Only the entries in `new_entries` (starting at 1-based sequence `start_seq`)
621/// are inserted. The header row is left unchanged, while the `message_count`
622/// and `name` meta rows are upserted to reflect the current totals.
623///
624/// This avoids the DELETE+reinsert cost of [`save_session`] for the common
625/// case where a few entries are appended between saves.
626pub async fn append_entries(
627    path: &Path,
628    new_entries: &[SessionEntry],
629    start_seq: usize,
630    message_count: u64,
631    session_name: Option<&str>,
632) -> Result<()> {
633    let metrics = session_metrics::global();
634    let _timer = metrics.start_timer(&metrics.sqlite_append);
635
636    let cx = AgentCx::for_request();
637    let conn = map_outcome(SqliteConnection::open(cx.cx(), path).await)?;
638
639    // Ensure WAL mode is active (no-op if already set).
640    map_outcome(
641        conn.execute_batch(cx.cx(), "PRAGMA journal_mode = WAL")
642            .await,
643    )?;
644
645    let tx = map_outcome(conn.begin_immediate(cx.cx()).await)?;
646
647    // Serialize and insert only the new entries.
648    let serialize_timer = metrics.start_timer(&metrics.sqlite_serialize);
649    let mut total_json_bytes = 0u64;
650    let mut entry_jsons = Vec::with_capacity(new_entries.len());
651    for entry in new_entries {
652        let json = serde_json::to_string(entry)?;
653        total_json_bytes += json.len() as u64;
654        entry_jsons.push(json);
655    }
656    serialize_timer.finish();
657    metrics.record_bytes(&metrics.sqlite_bytes, total_json_bytes);
658
659    for (i, json) in entry_jsons.into_iter().enumerate() {
660        let seq = start_seq + i + 1; // 1-based
661        map_outcome(
662            tx.execute(
663                cx.cx(),
664                "INSERT INTO pi_session_entries (seq,json) VALUES (?1,?2)",
665                &[
666                    SqliteValue::Integer(i64::try_from(seq).unwrap_or(i64::MAX)),
667                    SqliteValue::Text(json),
668                ],
669            )
670            .await,
671        )?;
672    }
673
674    // Upsert meta counters (INSERT OR REPLACE).
675    map_outcome(
676        tx.execute(
677            cx.cx(),
678            "INSERT OR REPLACE INTO pi_session_meta (key,value) VALUES (?1,?2)",
679            &[
680                SqliteValue::Text("message_count".to_string()),
681                SqliteValue::Text(message_count.to_string()),
682            ],
683        )
684        .await,
685    )?;
686    if let Some(name) = session_name {
687        map_outcome(
688            tx.execute(
689                cx.cx(),
690                "INSERT OR REPLACE INTO pi_session_meta (key,value) VALUES (?1,?2)",
691                &[
692                    SqliteValue::Text("name".to_string()),
693                    SqliteValue::Text(name.to_string()),
694                ],
695            )
696            .await,
697        )?;
698    }
699
700    map_outcome(tx.commit(cx.cx()).await)?;
701    Ok(())
702}