1use crate::agent_cx::AgentCx;
2use crate::error::{Error, Result};
3use crate::session::{SessionEntry, SessionHeader};
4use crate::session_metrics;
5use asupersync::Outcome;
6use asupersync::database::{SqliteConnection, SqliteError, SqliteRow, SqliteValue};
7use std::path::Path;
8
/// Schema and connection pragmas applied by [`save_session`] (through `execute_batch`)
/// at the start of every full save.
///
/// Every table is created with `IF NOT EXISTS`, so running this against an existing
/// session file leaves its tables in place:
/// - `pi_session_header`: the session header serialized as JSON, keyed by the header `id`.
/// - `pi_session_entries`: one JSON-serialized entry per row; `seq` defines entry order.
/// - `pi_session_meta`: key/value cache (`message_count`, `name`) read by
///   [`load_session_meta`] so it can usually avoid deserializing every entry.
///
/// NOTE(review): `append_entries` re-issues only `PRAGMA journal_mode = WAL` on its own
/// connection, not the other pragmas here — confirm that difference is intended.
const INIT_SQL: &str = r"
PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;
PRAGMA foreign_keys = ON;

CREATE TABLE IF NOT EXISTS pi_session_header (
    id TEXT PRIMARY KEY,
    json TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS pi_session_entries (
    seq INTEGER PRIMARY KEY,
    json TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS pi_session_meta (
    key TEXT PRIMARY KEY,
    value TEXT NOT NULL
);
";
29
30#[derive(Debug, Clone)]
31pub struct SqliteSessionMeta {
32 pub header: SessionHeader,
33 pub message_count: u64,
34 pub name: Option<String>,
35}
36
37fn map_outcome<T>(outcome: Outcome<T, SqliteError>) -> Result<T> {
38 match outcome {
39 Outcome::Ok(value) => Ok(value),
40 Outcome::Err(err) => Err(Error::session(format!("SQLite session error: {err}"))),
41 Outcome::Cancelled(_) => Err(Error::Aborted),
42 Outcome::Panicked(payload) => Err(Error::session(format!(
43 "SQLite session operation panicked: {payload:?}"
44 ))),
45 }
46}
47
48fn row_get_str<'a>(row: &'a SqliteRow, column: &str) -> Result<&'a str> {
49 row.get_str(column)
50 .map_err(|err| Error::session(format!("SQLite row read failed: {err}")))
51}
52
53fn compute_message_count_and_name(entries: &[SessionEntry]) -> (u64, Option<String>) {
54 let mut message_count = 0u64;
55 let mut name = None;
56
57 for entry in entries {
58 match entry {
59 SessionEntry::Message(_) => message_count += 1,
60 SessionEntry::SessionInfo(info) => {
61 if info.name.is_some() {
62 name.clone_from(&info.name);
63 }
64 }
65 _ => {}
66 }
67 }
68
69 (message_count, name)
70}
71
72pub async fn load_session(path: &Path) -> Result<(SessionHeader, Vec<SessionEntry>)> {
73 let metrics = session_metrics::global();
74 let _timer = metrics.start_timer(&metrics.sqlite_load);
75
76 if !path.exists() {
77 return Err(Error::SessionNotFound {
78 path: path.display().to_string(),
79 });
80 }
81
82 let cx = AgentCx::for_request();
83 let conn = map_outcome(SqliteConnection::open(cx.cx(), path).await)?;
84
85 let header_rows = map_outcome(
86 conn.query(cx.cx(), "SELECT json FROM pi_session_header LIMIT 1", &[])
87 .await,
88 )?;
89 let header_row = header_rows
90 .first()
91 .ok_or_else(|| Error::session("SQLite session missing header row"))?;
92 let header_json = row_get_str(header_row, "json")?;
93 let header: SessionHeader = serde_json::from_str(header_json)?;
94
95 let entry_rows = map_outcome(
96 conn.query(
97 cx.cx(),
98 "SELECT json FROM pi_session_entries ORDER BY seq ASC",
99 &[],
100 )
101 .await,
102 )?;
103
104 let mut entries = Vec::with_capacity(entry_rows.len());
105 for row in entry_rows {
106 let json = row_get_str(&row, "json")?;
107 let entry: SessionEntry = serde_json::from_str(json)?;
108 entries.push(entry);
109 }
110
111 Ok((header, entries))
112}
113
114pub async fn load_session_meta(path: &Path) -> Result<SqliteSessionMeta> {
115 let metrics = session_metrics::global();
116 let _timer = metrics.start_timer(&metrics.sqlite_load_meta);
117
118 if !path.exists() {
119 return Err(Error::SessionNotFound {
120 path: path.display().to_string(),
121 });
122 }
123
124 let cx = AgentCx::for_request();
125 let conn = map_outcome(SqliteConnection::open(cx.cx(), path).await)?;
126
127 let header_rows = map_outcome(
128 conn.query(cx.cx(), "SELECT json FROM pi_session_header LIMIT 1", &[])
129 .await,
130 )?;
131 let header_row = header_rows
132 .first()
133 .ok_or_else(|| Error::session("SQLite session missing header row"))?;
134 let header_json = row_get_str(header_row, "json")?;
135 let header: SessionHeader = serde_json::from_str(header_json)?;
136
137 let meta_rows = map_outcome(
138 conn.query(
139 cx.cx(),
140 "SELECT key,value FROM pi_session_meta WHERE key IN ('message_count','name')",
141 &[],
142 )
143 .await,
144 )?;
145
146 let mut message_count: Option<u64> = None;
147 let mut name: Option<String> = None;
148 for row in meta_rows {
149 let key = row_get_str(&row, "key")?;
150 let value = row_get_str(&row, "value")?;
151 match key {
152 "message_count" => message_count = value.parse::<u64>().ok(),
153 "name" => name = Some(value.to_string()),
154 _ => {}
155 }
156 }
157
158 let message_count = if let Some(message_count) = message_count {
159 message_count
160 } else {
161 let entry_rows = map_outcome(
162 conn.query(
163 cx.cx(),
164 "SELECT json FROM pi_session_entries ORDER BY seq ASC",
165 &[],
166 )
167 .await,
168 )?;
169
170 let mut entries = Vec::with_capacity(entry_rows.len());
171 for row in entry_rows {
172 let json = row_get_str(&row, "json")?;
173 let entry: SessionEntry = serde_json::from_str(json)?;
174 entries.push(entry);
175 }
176
177 let (message_count, fallback_name) = compute_message_count_and_name(&entries);
178 if name.is_none() {
179 name = fallback_name;
180 }
181 message_count
182 };
183 Ok(SqliteSessionMeta {
184 header,
185 message_count,
186 name,
187 })
188}
189
#[cfg(test)]
// `save_session` and `append_entries` are defined after this module.
#[allow(clippy::items_after_test_module)]
mod tests {
    //! Unit tests for the pure helpers of this module: metadata derivation and
    //! `Outcome` mapping. None of them open a database.

    use super::*;
    use crate::model::UserContent;
    use crate::session::{
        BranchSummaryEntry, CompactionEntry, CustomEntry, EntryBase, LabelEntry, MessageEntry,
        ModelChangeEntry, SessionInfoEntry, SessionMessage, ThinkingLevelChangeEntry,
    };

    /// Entry metadata shared by every fixture entry.
    fn base() -> EntryBase {
        EntryBase {
            id: Some("test-id".to_string()),
            parent_id: None,
            timestamp: "2026-01-01T00:00:00.000Z".to_string(),
        }
    }

    /// A user message entry; each one counts toward the message total.
    fn user_message() -> SessionEntry {
        SessionEntry::Message(MessageEntry {
            base: base(),
            message: SessionMessage::User {
                content: UserContent::Text("hello".to_string()),
                timestamp: None,
            },
        })
    }

    /// A `SessionInfo` entry with an optional name.
    fn session_info(name: Option<&str>) -> SessionEntry {
        SessionEntry::SessionInfo(SessionInfoEntry {
            base: base(),
            name: name.map(str::to_string),
        })
    }

    /// Checks both values derived by `compute_message_count_and_name`.
    fn assert_counts(entries: &[SessionEntry], expected_count: u64, expected_name: Option<&str>) {
        let (count, name) = compute_message_count_and_name(entries);
        assert_eq!(count, expected_count);
        assert_eq!(name.as_deref(), expected_name);
    }

    #[test]
    fn compute_counts_empty() {
        assert_counts(&[], 0, None);
    }

    #[test]
    fn compute_counts_messages_only() {
        let entries = [user_message(), user_message(), user_message()];
        assert_counts(&entries, 3, None);
    }

    #[test]
    fn compute_counts_session_info_with_name() {
        let entries = [
            user_message(),
            session_info(Some("My Session")),
            user_message(),
        ];
        assert_counts(&entries, 2, Some("My Session"));
    }

    #[test]
    fn compute_counts_session_info_none_name_ignored() {
        let entries = [session_info(Some("First")), session_info(None), user_message()];
        assert_counts(&entries, 1, Some("First"));
    }

    #[test]
    fn compute_counts_latest_name_wins() {
        let entries = [session_info(Some("First")), session_info(Some("Second"))];
        let (_, name) = compute_message_count_and_name(&entries);
        assert_eq!(name.as_deref(), Some("Second"));
    }

    #[test]
    fn compute_counts_ignores_model_change_entries() {
        let model_change = SessionEntry::ModelChange(ModelChangeEntry {
            base: base(),
            provider: "anthropic".to_string(),
            model_id: "claude-sonnet-4-5".to_string(),
        });
        assert_counts(&[user_message(), model_change, user_message()], 2, None);
    }

    #[test]
    fn compute_counts_ignores_label_entries() {
        let label = SessionEntry::Label(LabelEntry {
            base: base(),
            target_id: "some-id".to_string(),
            label: Some("important".to_string()),
        });
        assert_counts(&[user_message(), label], 1, None);
    }

    #[test]
    fn compute_counts_ignores_custom_entries() {
        let custom = SessionEntry::Custom(CustomEntry {
            base: base(),
            custom_type: "my_custom".to_string(),
            data: Some(serde_json::json!({"key": "value"})),
        });
        assert_counts(&[custom, user_message()], 1, None);
    }

    #[test]
    fn compute_counts_ignores_compaction_entries() {
        let compaction = SessionEntry::Compaction(CompactionEntry {
            base: base(),
            summary: "summary text".to_string(),
            first_kept_entry_id: "e1".to_string(),
            tokens_before: 500,
            details: None,
            from_hook: None,
        });
        let entries = [user_message(), compaction, user_message(), user_message()];
        assert_counts(&entries, 3, None);
    }

    #[test]
    fn compute_counts_mixed_entry_types() {
        let entries = [
            user_message(),
            SessionEntry::ModelChange(ModelChangeEntry {
                base: base(),
                provider: "openai".to_string(),
                model_id: "gpt-4".to_string(),
            }),
            session_info(Some("Named")),
            SessionEntry::Label(LabelEntry {
                base: base(),
                target_id: "t1".to_string(),
                label: None,
            }),
            user_message(),
            SessionEntry::Compaction(CompactionEntry {
                base: base(),
                summary: "s".to_string(),
                first_kept_entry_id: "e1".to_string(),
                tokens_before: 100,
                details: None,
                from_hook: None,
            }),
            SessionEntry::Custom(CustomEntry {
                base: base(),
                custom_type: "ct".to_string(),
                data: None,
            }),
            user_message(),
        ];
        assert_counts(&entries, 3, Some("Named"));
    }

    #[test]
    fn map_outcome_ok() {
        let outcome: Outcome<i32, SqliteError> = Outcome::Ok(42);
        assert_eq!(map_outcome(outcome).unwrap(), 42);
    }

    #[test]
    fn map_outcome_err() {
        let outcome: Outcome<i32, SqliteError> = Outcome::Err(SqliteError::ConnectionClosed);
        match map_outcome(outcome).unwrap_err() {
            Error::Session(message) => assert!(message.contains("SQLite session error")),
            other => panic!("expected Session error, got {other:?}"),
        }
    }

    #[test]
    fn map_outcome_cancelled() {
        use asupersync::types::CancelKind;
        let outcome: Outcome<i32, SqliteError> =
            Outcome::Cancelled(asupersync::CancelReason::new(CancelKind::User));
        assert!(matches!(map_outcome(outcome).unwrap_err(), Error::Aborted));
    }

    #[test]
    fn map_outcome_panicked() {
        use asupersync::types::PanicPayload;
        let outcome: Outcome<i32, SqliteError> = Outcome::Panicked(PanicPayload::new("test panic"));
        match map_outcome(outcome).unwrap_err() {
            Error::Session(message) => assert!(message.contains("panicked")),
            other => panic!("expected Session error, got {other:?}"),
        }
    }

    #[test]
    fn sqlite_session_meta_fields() {
        let meta = SqliteSessionMeta {
            header: SessionHeader {
                id: "test-session".to_string(),
                ..SessionHeader::default()
            },
            message_count: 42,
            name: Some("My Session".to_string()),
        };
        assert_eq!(meta.header.id, "test-session");
        assert_eq!(meta.message_count, 42);
        assert_eq!(meta.name.as_deref(), Some("My Session"));
    }

    #[test]
    fn sqlite_session_meta_no_name() {
        let meta = SqliteSessionMeta {
            header: SessionHeader::default(),
            message_count: 0,
            name: None,
        };
        assert_eq!(meta.message_count, 0);
        assert!(meta.name.is_none());
    }

    #[test]
    fn compute_counts_large_message_set() {
        let entries: Vec<SessionEntry> = std::iter::repeat_with(user_message).take(1000).collect();
        assert_counts(&entries, 1000, None);
    }

    #[test]
    fn compute_counts_name_set_early_persists() {
        let entries = [
            session_info(Some("Early Name")),
            user_message(),
            user_message(),
            user_message(),
        ];
        assert_counts(&entries, 3, Some("Early Name"));
    }

    #[test]
    fn compute_counts_ignores_branch_summary() {
        let branch_summary = SessionEntry::BranchSummary(BranchSummaryEntry {
            base: base(),
            from_id: "parent-id".to_string(),
            summary: "branch summary".to_string(),
            details: None,
            from_hook: None,
        });
        assert_counts(&[user_message(), branch_summary], 1, None);
    }

    #[test]
    fn compute_counts_ignores_thinking_level_change() {
        let thinking = SessionEntry::ThinkingLevelChange(ThinkingLevelChangeEntry {
            base: base(),
            thinking_level: "high".to_string(),
        });
        assert_counts(&[thinking, user_message()], 1, None);
    }
}
516
517pub async fn save_session(
518 path: &Path,
519 header: &SessionHeader,
520 entries: &[SessionEntry],
521) -> Result<()> {
522 let metrics = session_metrics::global();
523 let _save_timer = metrics.start_timer(&metrics.sqlite_save);
524
525 if let Some(parent) = path.parent() {
526 asupersync::fs::create_dir_all(parent).await?;
527 }
528
529 let cx = AgentCx::for_request();
530 let conn = map_outcome(SqliteConnection::open(cx.cx(), path).await)?;
531 map_outcome(conn.execute_batch(cx.cx(), INIT_SQL).await)?;
532
533 let tx = map_outcome(conn.begin_immediate(cx.cx()).await)?;
534
535 map_outcome(
536 tx.execute(cx.cx(), "DELETE FROM pi_session_entries", &[])
537 .await,
538 )?;
539 map_outcome(
540 tx.execute(cx.cx(), "DELETE FROM pi_session_header", &[])
541 .await,
542 )?;
543 map_outcome(
544 tx.execute(cx.cx(), "DELETE FROM pi_session_meta", &[])
545 .await,
546 )?;
547
548 let serialize_timer = metrics.start_timer(&metrics.sqlite_serialize);
550 let header_json = serde_json::to_string(header)?;
551 let mut total_json_bytes = header_json.len() as u64;
552
553 let mut entry_jsons = Vec::with_capacity(entries.len());
554 for entry in entries {
555 let json = serde_json::to_string(entry)?;
556 total_json_bytes += json.len() as u64;
557 entry_jsons.push(json);
558 }
559 serialize_timer.finish();
560 metrics.record_bytes(&metrics.sqlite_bytes, total_json_bytes);
561
562 map_outcome(
563 tx.execute(
564 cx.cx(),
565 "INSERT INTO pi_session_header (id,json) VALUES (?1,?2)",
566 &[
567 SqliteValue::Text(header.id.clone()),
568 SqliteValue::Text(header_json),
569 ],
570 )
571 .await,
572 )?;
573
574 for (idx, json) in entry_jsons.into_iter().enumerate() {
575 map_outcome(
576 tx.execute(
577 cx.cx(),
578 "INSERT INTO pi_session_entries (seq,json) VALUES (?1,?2)",
579 &[
580 SqliteValue::Integer(i64::try_from(idx + 1).unwrap_or(i64::MAX)),
581 SqliteValue::Text(json),
582 ],
583 )
584 .await,
585 )?;
586 }
587
588 let (message_count, name) = compute_message_count_and_name(entries);
589 map_outcome(
590 tx.execute(
591 cx.cx(),
592 "INSERT INTO pi_session_meta (key,value) VALUES (?1,?2)",
593 &[
594 SqliteValue::Text("message_count".to_string()),
595 SqliteValue::Text(message_count.to_string()),
596 ],
597 )
598 .await,
599 )?;
600 if let Some(name) = name {
601 map_outcome(
602 tx.execute(
603 cx.cx(),
604 "INSERT INTO pi_session_meta (key,value) VALUES (?1,?2)",
605 &[
606 SqliteValue::Text("name".to_string()),
607 SqliteValue::Text(name),
608 ],
609 )
610 .await,
611 )?;
612 }
613
614 map_outcome(tx.commit(cx.cx()).await)?;
615 Ok(())
616}
617
618pub async fn append_entries(
627 path: &Path,
628 new_entries: &[SessionEntry],
629 start_seq: usize,
630 message_count: u64,
631 session_name: Option<&str>,
632) -> Result<()> {
633 let metrics = session_metrics::global();
634 let _timer = metrics.start_timer(&metrics.sqlite_append);
635
636 let cx = AgentCx::for_request();
637 let conn = map_outcome(SqliteConnection::open(cx.cx(), path).await)?;
638
639 map_outcome(
641 conn.execute_batch(cx.cx(), "PRAGMA journal_mode = WAL")
642 .await,
643 )?;
644
645 let tx = map_outcome(conn.begin_immediate(cx.cx()).await)?;
646
647 let serialize_timer = metrics.start_timer(&metrics.sqlite_serialize);
649 let mut total_json_bytes = 0u64;
650 let mut entry_jsons = Vec::with_capacity(new_entries.len());
651 for entry in new_entries {
652 let json = serde_json::to_string(entry)?;
653 total_json_bytes += json.len() as u64;
654 entry_jsons.push(json);
655 }
656 serialize_timer.finish();
657 metrics.record_bytes(&metrics.sqlite_bytes, total_json_bytes);
658
659 for (i, json) in entry_jsons.into_iter().enumerate() {
660 let seq = start_seq + i + 1; map_outcome(
662 tx.execute(
663 cx.cx(),
664 "INSERT INTO pi_session_entries (seq,json) VALUES (?1,?2)",
665 &[
666 SqliteValue::Integer(i64::try_from(seq).unwrap_or(i64::MAX)),
667 SqliteValue::Text(json),
668 ],
669 )
670 .await,
671 )?;
672 }
673
674 map_outcome(
676 tx.execute(
677 cx.cx(),
678 "INSERT OR REPLACE INTO pi_session_meta (key,value) VALUES (?1,?2)",
679 &[
680 SqliteValue::Text("message_count".to_string()),
681 SqliteValue::Text(message_count.to_string()),
682 ],
683 )
684 .await,
685 )?;
686 if let Some(name) = session_name {
687 map_outcome(
688 tx.execute(
689 cx.cx(),
690 "INSERT OR REPLACE INTO pi_session_meta (key,value) VALUES (?1,?2)",
691 &[
692 SqliteValue::Text("name".to_string()),
693 SqliteValue::Text(name.to_string()),
694 ],
695 )
696 .await,
697 )?;
698 }
699
700 map_outcome(tx.commit(cx.cx()).await)?;
701 Ok(())
702}