Skip to main content

zero_session/
store.rs

1//! The `Store` handle — owns the SQLite connection, runs
2//! migrations, and exposes append/list/milestone ops.
3//!
4//! Writes are synchronous on the caller's thread. With WAL
5//! journalling and our <1 kB event rows, a commit lands in
6//! microseconds on modern hardware; a thread-boundary would cost
7//! more than the write itself. If that becomes untrue under
8//! plugin-heavy loads, see the commented-out `spawn_blocking`
9//! skeleton at the bottom of this file.
10
11use std::path::Path;
12use std::sync::Mutex;
13
14use chrono::{DateTime, Utc};
15use rusqlite::{Connection, OptionalExtension, params};
16
17use crate::SessionError;
18use crate::event::{EventKind, SessionRow, StoredEvent};
19
mod migrations {
    // Embed the SQL migration files at compile time. The macro generates
    // the inner `migrations` module whose `runner()` the `Store`
    // constructors execute against a fresh connection.
    // NOTE(review): the path passed below is `./migrations`; it is
    // presumably resolved relative to the crate root rather than to this
    // source file — confirm against the refinery docs.
    refinery::embed_migrations!("./migrations");
}
24
/// Session-store handle.
///
/// Wraps a single SQLite [`Connection`] behind a [`Mutex`]; every
/// operation on the store locks that mutex for the duration of the call.
pub struct Store {
    // The one connection all reads and writes go through.
    conn: Mutex<Connection>,
}
29
30impl std::fmt::Debug for Store {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("Store").finish_non_exhaustive()
33    }
34}
35
36impl Store {
37    /// Open or create a database at `path`, running migrations.
38    ///
39    /// # Errors
40    /// Returns a `SessionError` if the parent directory cannot be
41    /// created, the connection cannot be opened, or migrations fail.
42    pub fn open(path: impl AsRef<Path>) -> Result<Self, SessionError> {
43        let path = path.as_ref();
44        if let Some(parent) = path.parent() {
45            std::fs::create_dir_all(parent)?;
46        }
47        let mut conn = Connection::open(path)?;
48        Self::configure(&conn)?;
49        migrations::migrations::runner()
50            .run(&mut conn)
51            .map_err(|e| SessionError::Migration(e.to_string()))?;
52        Ok(Self {
53            conn: Mutex::new(conn),
54        })
55    }
56
57    /// In-memory store for tests.
58    ///
59    /// # Errors
60    /// Returns a `SessionError` if the in-memory connection or
61    /// migrations fail.
62    pub fn open_in_memory() -> Result<Self, SessionError> {
63        let mut conn = Connection::open_in_memory()?;
64        Self::configure(&conn)?;
65        migrations::migrations::runner()
66            .run(&mut conn)
67            .map_err(|e| SessionError::Migration(e.to_string()))?;
68        Ok(Self {
69            conn: Mutex::new(conn),
70        })
71    }
72
73    fn configure(conn: &Connection) -> Result<(), SessionError> {
74        // WAL survives a hard kill mid-write; `synchronous=NORMAL`
75        // keeps fsync'ing for durability without the latency of FULL.
76        // `foreign_keys=ON` makes ON DELETE CASCADE actually cascade.
77        conn.pragma_update(None, "journal_mode", "WAL")?;
78        conn.pragma_update(None, "synchronous", "NORMAL")?;
79        conn.pragma_update(None, "foreign_keys", "ON")?;
80        Ok(())
81    }
82
83    /// Start a new session. `ulid` should be freshly generated by
84    /// the caller (keeps this crate ulid-free) — any short unique
85    /// string works. Returns the row id for subsequent `append`s.
86    ///
87    /// # Errors
88    /// Returns a `SessionError::Sql` on insert failure (e.g. ulid
89    /// collision).
90    pub fn start_session(
91        &self,
92        ulid: &str,
93        engine_base_url: Option<&str>,
94        cli_version: &str,
95        parent_ulid: Option<&str>,
96    ) -> Result<i64, SessionError> {
97        let conn = self.conn.lock().expect("store mutex poisoned");
98        let now = Utc::now().to_rfc3339();
99        conn.execute(
100            "INSERT INTO sessions (ulid, started_at, engine_base_url, cli_version, parent_ulid)
101             VALUES (?1, ?2, ?3, ?4, ?5)",
102            params![ulid, now, engine_base_url, cli_version, parent_ulid],
103        )?;
104        Ok(conn.last_insert_rowid())
105    }
106
107    /// Mark a session as ended. Idempotent — re-ending is a no-op.
108    ///
109    /// # Errors
110    /// Propagates underlying SQL failures.
111    pub fn end_session(&self, session_id: i64) -> Result<(), SessionError> {
112        let conn = self.conn.lock().expect("store mutex poisoned");
113        let now = Utc::now().to_rfc3339();
114        conn.execute(
115            "UPDATE sessions SET ended_at = COALESCE(ended_at, ?1) WHERE id = ?2",
116            params![now, session_id],
117        )?;
118        Ok(())
119    }
120
121    /// Append an event. The `seq` is allocated here so callers
122    /// don't race on numbering. Returns the new `seq`.
123    ///
124    /// # Errors
125    /// Propagates underlying SQL failures.
126    pub fn append(
127        &self,
128        session_id: i64,
129        kind: EventKind,
130        text: &str,
131    ) -> Result<i64, SessionError> {
132        let conn = self.conn.lock().expect("store mutex poisoned");
133        let seq: i64 = conn.query_row(
134            "SELECT COALESCE(MAX(seq), 0) + 1 FROM events WHERE session_id = ?1",
135            params![session_id],
136            |r| r.get(0),
137        )?;
138        let now = Utc::now().to_rfc3339();
139        conn.execute(
140            "INSERT INTO events (session_id, seq, at, kind, text) VALUES (?1, ?2, ?3, ?4, ?5)",
141            params![session_id, seq, now, kind.as_str(), text],
142        )?;
143        Ok(seq)
144    }
145
146    /// List events in replay order, newest last. `limit` caps the
147    /// returned set from the tail (most recent `limit` entries).
148    ///
149    /// # Errors
150    /// Propagates underlying SQL failures.
151    pub fn list_events(
152        &self,
153        session_id: i64,
154        limit: u32,
155    ) -> Result<Vec<StoredEvent>, SessionError> {
156        let conn = self.conn.lock().expect("store mutex poisoned");
157        let mut stmt = conn.prepare(
158            "SELECT id, session_id, seq, at, kind, text
159             FROM events
160             WHERE session_id = ?1
161             ORDER BY seq DESC
162             LIMIT ?2",
163        )?;
164        let rows = stmt.query_map(params![session_id, limit], |r| {
165            let at_str: String = r.get(3)?;
166            let kind_str: String = r.get(4)?;
167            Ok(StoredEvent {
168                id: r.get(0)?,
169                session_id: r.get(1)?,
170                seq: r.get(2)?,
171                at: parse_rfc3339(&at_str),
172                kind: EventKind::parse_str(&kind_str).unwrap_or(EventKind::System),
173                text: r.get(5)?,
174            })
175        })?;
176        // We asked DESC to cap from the tail; flip back to
177        // ascending so the caller can append straight into the log.
178        let mut out: Vec<_> = rows.collect::<Result<_, _>>()?;
179        out.reverse();
180        Ok(out)
181    }
182
183    /// Fetch the N most recent sessions (by `started_at` desc).
184    /// `limit == 0` returns an empty vec without hitting the DB.
185    ///
186    /// Used by `/sessions` to paint a navigable list; the `cli_version`
187    /// and `engine_base_url` columns are pulled through so the listing
188    /// can show mismatches (e.g. an old session from a different
189    /// engine base URL) without a follow-up round-trip.
190    ///
191    /// # Errors
192    /// Propagates underlying SQL failures.
193    pub fn list_sessions(&self, limit: u32) -> Result<Vec<SessionRow>, SessionError> {
194        if limit == 0 {
195            return Ok(Vec::new());
196        }
197        let conn = self.conn.lock().expect("store mutex poisoned");
198        let mut stmt = conn.prepare(
199            "SELECT id, ulid, started_at, ended_at, engine_base_url, cli_version, parent_ulid
200             FROM sessions
201             ORDER BY started_at DESC
202             LIMIT ?1",
203        )?;
204        let rows = stmt.query_map(params![limit], parse_session_row)?;
205        rows.collect::<Result<_, _>>().map_err(Into::into)
206    }
207
208    /// Look up a session by its `ulid`. Returns `None` when no row
209    /// matches; `Err` only on SQL failure.
210    ///
211    /// # Errors
212    /// Propagates underlying SQL failures.
213    pub fn get_session_by_ulid(&self, ulid: &str) -> Result<Option<SessionRow>, SessionError> {
214        let conn = self.conn.lock().expect("store mutex poisoned");
215        conn.query_row(
216            "SELECT id, ulid, started_at, ended_at, engine_base_url, cli_version, parent_ulid
217             FROM sessions
218             WHERE ulid = ?1",
219            params![ulid],
220            parse_session_row,
221        )
222        .optional()
223        .map_err(Into::into)
224    }
225
226    /// Cheap `COUNT(*)` for a session's events. Handy for the
227    /// `/sessions` list so it can render `42 event(s)` without
228    /// pulling every row.
229    ///
230    /// # Errors
231    /// Propagates underlying SQL failures.
232    pub fn count_events(&self, session_id: i64) -> Result<i64, SessionError> {
233        let conn = self.conn.lock().expect("store mutex poisoned");
234        conn.query_row(
235            "SELECT COUNT(*) FROM events WHERE session_id = ?1",
236            params![session_id],
237            |r| r.get(0),
238        )
239        .map_err(Into::into)
240    }
241
242    /// Fetch the most recent session (by `started_at`).
243    ///
244    /// # Errors
245    /// Propagates underlying SQL failures.
246    pub fn last_session(&self) -> Result<Option<SessionRow>, SessionError> {
247        let conn = self.conn.lock().expect("store mutex poisoned");
248        conn.query_row(
249            "SELECT id, ulid, started_at, ended_at, engine_base_url, cli_version, parent_ulid
250             FROM sessions
251             ORDER BY started_at DESC
252             LIMIT 1",
253            [],
254            parse_session_row,
255        )
256        .optional()
257        .map_err(Into::into)
258    }
259
260    /// Set (upsert) a milestone flag.
261    ///
262    /// # Errors
263    /// Propagates underlying SQL failures.
264    pub fn set_milestone(&self, key: &str, value: &str) -> Result<(), SessionError> {
265        let conn = self.conn.lock().expect("store mutex poisoned");
266        let now = Utc::now().to_rfc3339();
267        conn.execute(
268            "INSERT INTO milestones (key, value, at) VALUES (?1, ?2, ?3)
269             ON CONFLICT(key) DO UPDATE SET value = excluded.value, at = excluded.at",
270            params![key, value, now],
271        )?;
272        Ok(())
273    }
274
275    /// Read a milestone. Returns `None` if it has never been set.
276    ///
277    /// # Errors
278    /// Propagates underlying SQL failures.
279    pub fn get_milestone(&self, key: &str) -> Result<Option<String>, SessionError> {
280        let conn = self.conn.lock().expect("store mutex poisoned");
281        conn.query_row(
282            "SELECT value FROM milestones WHERE key = ?1",
283            params![key],
284            |r| r.get::<_, String>(0),
285        )
286        .optional()
287        .map_err(Into::into)
288    }
289}
290
291fn parse_session_row(r: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRow> {
292    let started: String = r.get(2)?;
293    let ended: Option<String> = r.get(3)?;
294    Ok(SessionRow {
295        id: r.get(0)?,
296        ulid: r.get(1)?,
297        started_at: parse_rfc3339(&started),
298        ended_at: ended.as_deref().map(parse_rfc3339),
299        engine_base_url: r.get(4)?,
300        cli_version: r.get(5)?,
301        parent_ulid: r.get(6)?,
302    })
303}
304
305fn parse_rfc3339(s: &str) -> DateTime<Utc> {
306    DateTime::parse_from_rfc3339(s).map_or_else(|_| Utc::now(), |dt| dt.with_timezone(&Utc))
307}
308
309#[cfg(test)]
310mod tests {
311    use super::*;
312
313    #[test]
314    fn migrations_are_idempotent() {
315        // Open twice — second open should replay migrations against
316        // the already-migrated DB without errors.
317        let s1 = Store::open_in_memory().unwrap();
318        drop(s1);
319        let _s2 = Store::open_in_memory().unwrap();
320    }
321
322    #[test]
323    fn append_and_list_round_trip() {
324        let s = Store::open_in_memory().unwrap();
325        let sid = s
326            .start_session("01HTEST", Some("http://x"), "0.3.0", None)
327            .unwrap();
328        let seq1 = s.append(sid, EventKind::Prompt, "> hello").unwrap();
329        let seq2 = s.append(sid, EventKind::System, "welcome").unwrap();
330        let seq3 = s.append(sid, EventKind::Command, "status ok").unwrap();
331        assert_eq!((seq1, seq2, seq3), (1, 2, 3));
332
333        let events = s.list_events(sid, 100).unwrap();
334        assert_eq!(events.len(), 3);
335        assert_eq!(events[0].kind, EventKind::Prompt);
336        assert_eq!(events[2].text, "status ok");
337    }
338
339    #[test]
340    fn list_events_returns_tail_in_order() {
341        let s = Store::open_in_memory().unwrap();
342        let sid = s.start_session("01HTAIL", None, "0.3.0", None).unwrap();
343        for i in 0..10 {
344            s.append(sid, EventKind::System, &format!("line {i}"))
345                .unwrap();
346        }
347        let tail = s.list_events(sid, 3).unwrap();
348        assert_eq!(tail.len(), 3);
349        assert_eq!(tail[0].text, "line 7");
350        assert_eq!(tail[2].text, "line 9");
351    }
352
353    #[test]
354    fn last_session_is_most_recent() {
355        let s = Store::open_in_memory().unwrap();
356        s.start_session("01HA", None, "0.3.0", None).unwrap();
357        // Sleep one ms so started_at differs. Chrono's RFC-3339
358        // output has ms granularity; same-ms rows would sort by id.
359        std::thread::sleep(std::time::Duration::from_millis(2));
360        s.start_session("01HB", None, "0.3.0", None).unwrap();
361
362        let last = s.last_session().unwrap().unwrap();
363        assert_eq!(last.ulid, "01HB");
364    }
365
366    #[test]
367    fn end_session_is_idempotent() {
368        let s = Store::open_in_memory().unwrap();
369        let sid = s.start_session("01HEND", None, "0.3.0", None).unwrap();
370        s.end_session(sid).unwrap();
371        let first_ended = s.last_session().unwrap().unwrap().ended_at;
372        s.end_session(sid).unwrap();
373        let second_ended = s.last_session().unwrap().unwrap().ended_at;
374        assert_eq!(first_ended, second_ended);
375    }
376
377    #[test]
378    fn milestones_upsert_and_read() {
379        let s = Store::open_in_memory().unwrap();
380        assert_eq!(s.get_milestone("welcome_shown").unwrap(), None);
381        s.set_milestone("welcome_shown", "true").unwrap();
382        assert_eq!(
383            s.get_milestone("welcome_shown").unwrap().as_deref(),
384            Some("true")
385        );
386        // Overwrite.
387        s.set_milestone("welcome_shown", "skipped").unwrap();
388        assert_eq!(
389            s.get_milestone("welcome_shown").unwrap().as_deref(),
390            Some("skipped")
391        );
392    }
393
394    #[test]
395    fn unknown_event_kind_is_rejected_by_schema() {
396        let s = Store::open_in_memory().unwrap();
397        let sid = s.start_session("01HBAD", None, "0.3.0", None).unwrap();
398        let conn = s.conn.lock().unwrap();
399        let res = conn.execute(
400            "INSERT INTO events (session_id, seq, at, kind, text) VALUES (?1, 1, ?2, ?3, ?4)",
401            params![sid, Utc::now().to_rfc3339(), "bogus", "x"],
402        );
403        assert!(res.is_err(), "CHECK constraint should reject unknown kind");
404    }
405
406    #[test]
407    fn list_sessions_honors_limit_and_is_newest_first() {
408        let s = Store::open_in_memory().unwrap();
409        s.start_session("01HA", None, "0.3.0", None).unwrap();
410        std::thread::sleep(std::time::Duration::from_millis(2));
411        s.start_session("01HB", None, "0.3.0", None).unwrap();
412        std::thread::sleep(std::time::Duration::from_millis(2));
413        s.start_session("01HC", None, "0.3.0", None).unwrap();
414
415        let all = s.list_sessions(10).unwrap();
416        assert_eq!(all.len(), 3);
417        assert_eq!(
418            all.iter().map(|r| r.ulid.as_str()).collect::<Vec<_>>(),
419            vec!["01HC", "01HB", "01HA"],
420            "list_sessions must be newest-first",
421        );
422
423        let top = s.list_sessions(1).unwrap();
424        assert_eq!(top.len(), 1);
425        assert_eq!(top[0].ulid, "01HC");
426
427        // Zero limit is a no-op — must not hit the DB or produce rows.
428        assert!(s.list_sessions(0).unwrap().is_empty());
429    }
430
431    #[test]
432    fn get_session_by_ulid_round_trips_and_misses_cleanly() {
433        let s = Store::open_in_memory().unwrap();
434        s.start_session("01HFOUND", Some("http://e"), "0.3.0", None)
435            .unwrap();
436        let hit = s.get_session_by_ulid("01HFOUND").unwrap();
437        assert!(hit.is_some());
438        assert_eq!(hit.unwrap().ulid, "01HFOUND");
439        // Missing ulid returns Ok(None), not an error — callers should
440        // format a friendly "no such session" line, not crash.
441        assert!(s.get_session_by_ulid("01HMISSING").unwrap().is_none());
442    }
443
444    #[test]
445    fn count_events_matches_append_count() {
446        let s = Store::open_in_memory().unwrap();
447        let sid = s.start_session("01HCNT", None, "0.3.0", None).unwrap();
448        assert_eq!(s.count_events(sid).unwrap(), 0);
449        for i in 0..5 {
450            s.append(sid, EventKind::System, &format!("line {i}"))
451                .unwrap();
452        }
453        assert_eq!(s.count_events(sid).unwrap(), 5);
454    }
455
456    #[test]
457    fn parent_ulid_records_fork_link() {
458        let s = Store::open_in_memory().unwrap();
459        let _parent = s.start_session("01HP", None, "0.3.0", None).unwrap();
460        let _child = s
461            .start_session("01HC", None, "0.3.0", Some("01HP"))
462            .unwrap();
463        let last = s.last_session().unwrap().unwrap();
464        assert_eq!(last.parent_ulid.as_deref(), Some("01HP"));
465    }
466}