use std::path::Path;
use std::sync::Mutex;
use chrono::{DateTime, Utc};
use rusqlite::{Connection, OptionalExtension, params};
use crate::SessionError;
use crate::event::{EventKind, SessionRow, StoredEvent};
// Wrapper module around the code generated by `refinery::embed_migrations!` for the
// `./migrations` directory. The rest of this file reaches the generated runner as
// `migrations::migrations::runner()`.
mod migrations {
refinery::embed_migrations!("./migrations");
}
/// Handle to the session database.
///
/// Wraps a single `rusqlite` [`Connection`] in a [`Mutex`]; every method of `Store`
/// locks that mutex before running SQL.
pub struct Store {
// The one and only connection, locked for the duration of each method call.
conn: Mutex<Connection>,
}
/// Opaque `Debug` output: only the type name is printed and the struct is marked
/// non-exhaustive, so the wrapped connection is never formatted.
impl std::fmt::Debug for Store {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Store");
        builder.finish_non_exhaustive()
    }
}
impl Store {
/// Opens the database at `path` and returns a ready-to-use `Store`.
///
/// The parent directory of `path` (when there is one) is created first. The new
/// connection is then passed through `configure` and the embedded migrations are run
/// against it.
///
/// # Errors
/// Directory-creation, connection and pragma failures are propagated with `?`.
/// A migration failure is returned as `SessionError::Migration` carrying the error's
/// string form.
pub fn open(path: impl AsRef<Path>) -> Result<Self, SessionError> {
    let db_path = path.as_ref();
    if let Some(dir) = db_path.parent() {
        std::fs::create_dir_all(dir)?;
    }

    let mut connection = Connection::open(db_path)?;
    Self::configure(&connection)?;

    let runner = migrations::migrations::runner();
    if let Err(err) = runner.run(&mut connection) {
        return Err(SessionError::Migration(err.to_string()));
    }

    let conn = Mutex::new(connection);
    Ok(Self { conn })
}
/// Creates a `Store` backed by an in-memory database.
///
/// Performs the same setup as `open` minus the filesystem step: the connection is
/// passed through `configure` and the embedded migrations are run against it.
///
/// # Errors
/// Connection and pragma failures are propagated with `?`. A migration failure is
/// returned as `SessionError::Migration` carrying the error's string form.
pub fn open_in_memory() -> Result<Self, SessionError> {
    let mut connection = Connection::open_in_memory()?;
    Self::configure(&connection)?;

    let runner = migrations::migrations::runner();
    if let Err(err) = runner.run(&mut connection) {
        return Err(SessionError::Migration(err.to_string()));
    }

    let conn = Mutex::new(connection);
    Ok(Self { conn })
}
/// Applies the connection-level pragmas used by every `Store`.
///
/// Sets, in this order: `journal_mode = WAL`, `synchronous = NORMAL`,
/// `foreign_keys = ON`. Stops at, and returns, the first pragma that fails.
fn configure(conn: &Connection) -> Result<(), SessionError> {
    // Order matters only in that it mirrors the sequence the pragmas are issued in.
    const PRAGMAS: [(&str, &str); 3] = [
        ("journal_mode", "WAL"),
        ("synchronous", "NORMAL"),
        ("foreign_keys", "ON"),
    ];
    for (name, value) in PRAGMAS {
        conn.pragma_update(None, name, value)?;
    }
    Ok(())
}
/// Inserts a new row into `sessions` and returns its row id.
///
/// `started_at` is stamped with the current UTC time in RFC 3339 form; the remaining
/// columns are taken verbatim from the arguments (`None` is bound for absent values).
///
/// # Errors
/// Returns the SQLite error if the insert fails.
///
/// # Panics
/// Panics if the store mutex is poisoned.
pub fn start_session(
    &self,
    ulid: &str,
    engine_base_url: Option<&str>,
    cli_version: &str,
    parent_ulid: Option<&str>,
) -> Result<i64, SessionError> {
    let db = self.conn.lock().expect("store mutex poisoned");
    let started_at = Utc::now().to_rfc3339();
    let sql = "INSERT INTO sessions (ulid, started_at, engine_base_url, cli_version, parent_ulid)
VALUES (?1, ?2, ?3, ?4, ?5)";
    db.execute(
        sql,
        params![ulid, started_at, engine_base_url, cli_version, parent_ulid],
    )?;
    let row_id = db.last_insert_rowid();
    Ok(row_id)
}
/// Marks the session `session_id` as ended at the current UTC time.
///
/// The update uses `COALESCE(ended_at, ?1)`, so a session that already has an
/// `ended_at` keeps its existing value. The number of affected rows is not inspected:
/// an id that matches no row still yields `Ok(())`.
///
/// # Errors
/// Returns the SQLite error if the update fails.
///
/// # Panics
/// Panics if the store mutex is poisoned.
pub fn end_session(&self, session_id: i64) -> Result<(), SessionError> {
    let db = self.conn.lock().expect("store mutex poisoned");
    let ended_at = Utc::now().to_rfc3339();
    let sql = "UPDATE sessions SET ended_at = COALESCE(ended_at, ?1) WHERE id = ?2";
    db.execute(sql, params![ended_at, session_id])?;
    Ok(())
}
/// Appends one event to the session `session_id` and returns the `seq` it was given.
///
/// The sequence number is one more than the highest `seq` currently stored for the
/// session (1 for the first event). The event is stamped with the current UTC time in
/// RFC 3339 form and stores `kind.as_str()` as its kind.
///
/// The lookup and the insert are two separate statements; the mutex guard held for
/// the whole call serializes them against other users of this `Store`.
///
/// # Errors
/// Returns the SQLite error if either statement fails.
///
/// # Panics
/// Panics if the store mutex is poisoned.
pub fn append(
    &self,
    session_id: i64,
    kind: EventKind,
    text: &str,
) -> Result<i64, SessionError> {
    let db = self.conn.lock().expect("store mutex poisoned");

    let next_seq_sql = "SELECT COALESCE(MAX(seq), 0) + 1 FROM events WHERE session_id = ?1";
    let next_seq: i64 = db.query_row(next_seq_sql, params![session_id], |row| row.get(0))?;

    let at = Utc::now().to_rfc3339();
    let insert_sql =
        "INSERT INTO events (session_id, seq, at, kind, text) VALUES (?1, ?2, ?3, ?4, ?5)";
    db.execute(
        insert_sql,
        params![session_id, next_seq, at, kind.as_str(), text],
    )?;

    Ok(next_seq)
}
/// Returns the most recent `limit` events of the session `session_id`, oldest first.
///
/// The query picks the rows with the highest `seq` (`ORDER BY seq DESC LIMIT ?2`) and
/// the result is reversed afterwards, so callers get the tail of the log in ascending
/// `seq` order. A stored kind that `EventKind::parse_str` does not recognise is mapped
/// to `EventKind::System`; timestamps are decoded with `parse_rfc3339`.
///
/// A `limit` of 0 returns an empty vector straight away without locking the
/// connection, mirroring the short-circuit in `list_sessions`.
///
/// # Errors
/// Returns the SQLite error if preparing, running or decoding the query fails.
///
/// # Panics
/// Panics if the store mutex is poisoned.
pub fn list_events(
    &self,
    session_id: i64,
    limit: u32,
) -> Result<Vec<StoredEvent>, SessionError> {
    // Nothing can match `LIMIT 0`; skip the lock and the query entirely.
    if limit == 0 {
        return Ok(Vec::new());
    }
    let conn = self.conn.lock().expect("store mutex poisoned");
    let mut stmt = conn.prepare(
        "SELECT id, session_id, seq, at, kind, text
FROM events
WHERE session_id = ?1
ORDER BY seq DESC
LIMIT ?2",
    )?;
    let rows = stmt.query_map(params![session_id, limit], |r| {
        let at_str: String = r.get(3)?;
        let kind_str: String = r.get(4)?;
        Ok(StoredEvent {
            id: r.get(0)?,
            session_id: r.get(1)?,
            seq: r.get(2)?,
            at: parse_rfc3339(&at_str),
            // Unknown kinds degrade to `System` instead of failing the whole listing.
            kind: EventKind::parse_str(&kind_str).unwrap_or(EventKind::System),
            text: r.get(5)?,
        })
    })?;
    let mut out: Vec<_> = rows.collect::<Result<_, _>>()?;
    // Rows arrive newest-first; flip them so the caller sees chronological order.
    out.reverse();
    Ok(out)
}
/// Returns up to `limit` sessions, newest first.
///
/// Sessions are ordered by `started_at` descending. Rows that share the same
/// `started_at` are further ordered by `id` descending so the result is deterministic
/// (this assumes `id` grows with insertion order — it is the value returned by
/// `last_insert_rowid()` in `start_session`).
///
/// A `limit` of 0 returns an empty vector without touching the database.
///
/// # Errors
/// Returns the SQLite error if preparing, running or decoding the query fails.
///
/// # Panics
/// Panics if the store mutex is poisoned.
pub fn list_sessions(&self, limit: u32) -> Result<Vec<SessionRow>, SessionError> {
    if limit == 0 {
        return Ok(Vec::new());
    }
    let conn = self.conn.lock().expect("store mutex poisoned");
    let mut stmt = conn.prepare(
        "SELECT id, ulid, started_at, ended_at, engine_base_url, cli_version, parent_ulid
FROM sessions
ORDER BY started_at DESC, id DESC
LIMIT ?1",
    )?;
    let rows = stmt.query_map(params![limit], parse_session_row)?;
    rows.collect::<Result<_, _>>().map_err(Into::into)
}
/// Looks up the session whose `ulid` column equals `ulid`.
///
/// Returns `Ok(None)` when no row matches.
///
/// # Errors
/// Returns the SQLite error if the query or row decoding fails.
///
/// # Panics
/// Panics if the store mutex is poisoned.
pub fn get_session_by_ulid(&self, ulid: &str) -> Result<Option<SessionRow>, SessionError> {
    let db = self.conn.lock().expect("store mutex poisoned");
    let sql = "SELECT id, ulid, started_at, ended_at, engine_base_url, cli_version, parent_ulid
FROM sessions
WHERE ulid = ?1";
    let found = db
        .query_row(sql, params![ulid], parse_session_row)
        .optional()?;
    Ok(found)
}
/// Counts the events stored for the session `session_id`.
///
/// # Errors
/// Returns the SQLite error if the query fails.
///
/// # Panics
/// Panics if the store mutex is poisoned.
pub fn count_events(&self, session_id: i64) -> Result<i64, SessionError> {
    let db = self.conn.lock().expect("store mutex poisoned");
    let sql = "SELECT COUNT(*) FROM events WHERE session_id = ?1";
    let total: i64 = db.query_row(sql, params![session_id], |row| row.get(0))?;
    Ok(total)
}
/// Returns the most recently started session, or `Ok(None)` when there are none.
///
/// "Most recent" means the greatest `started_at`. When several rows share that
/// timestamp the one with the greatest `id` wins, so the answer is deterministic
/// (this assumes `id` grows with insertion order — it is the value returned by
/// `last_insert_rowid()` in `start_session`).
///
/// # Errors
/// Returns the SQLite error if the query or row decoding fails.
///
/// # Panics
/// Panics if the store mutex is poisoned.
pub fn last_session(&self) -> Result<Option<SessionRow>, SessionError> {
    let conn = self.conn.lock().expect("store mutex poisoned");
    conn.query_row(
        "SELECT id, ulid, started_at, ended_at, engine_base_url, cli_version, parent_ulid
FROM sessions
ORDER BY started_at DESC, id DESC
LIMIT 1",
        [],
        parse_session_row,
    )
    .optional()
    .map_err(Into::into)
}
/// Stores `value` under `key` in the `milestones` table.
///
/// This is an upsert: when `key` already exists, its `value` and `at` columns are
/// overwritten. `at` is the current UTC time in RFC 3339 form.
///
/// # Errors
/// Returns the SQLite error if the statement fails.
///
/// # Panics
/// Panics if the store mutex is poisoned.
pub fn set_milestone(&self, key: &str, value: &str) -> Result<(), SessionError> {
    let db = self.conn.lock().expect("store mutex poisoned");
    let at = Utc::now().to_rfc3339();
    let sql = "INSERT INTO milestones (key, value, at) VALUES (?1, ?2, ?3)
ON CONFLICT(key) DO UPDATE SET value = excluded.value, at = excluded.at";
    db.execute(sql, params![key, value, at])?;
    Ok(())
}
/// Reads the value stored under `key` in the `milestones` table.
///
/// Returns `Ok(None)` when the key has never been set.
///
/// # Errors
/// Returns the SQLite error if the query fails.
///
/// # Panics
/// Panics if the store mutex is poisoned.
pub fn get_milestone(&self, key: &str) -> Result<Option<String>, SessionError> {
    let db = self.conn.lock().expect("store mutex poisoned");
    let sql = "SELECT value FROM milestones WHERE key = ?1";
    let stored = db
        .query_row(sql, params![key], |row| row.get::<_, String>(0))
        .optional()?;
    Ok(stored)
}
}
/// Decodes one row of the `sessions` SELECT used throughout this file into a `SessionRow`.
///
/// Expected column order: `id, ulid, started_at, ended_at, engine_base_url,
/// cli_version, parent_ulid`. The two timestamp columns are read as text and decoded
/// with `parse_rfc3339`; a NULL `ended_at` stays `None`.
fn parse_session_row(r: &rusqlite::Row<'_>) -> rusqlite::Result<SessionRow> {
    // Columns are read in the same order as before so the first failing column
    // (and therefore the reported error) is unchanged.
    let started_raw: String = r.get(2)?;
    let ended_raw: Option<String> = r.get(3)?;
    let id = r.get(0)?;
    let ulid = r.get(1)?;
    let started_at = parse_rfc3339(&started_raw);
    let ended_at = match ended_raw.as_deref() {
        Some(text) => Some(parse_rfc3339(text)),
        None => None,
    };
    let engine_base_url = r.get(4)?;
    let cli_version = r.get(5)?;
    let parent_ulid = r.get(6)?;
    Ok(SessionRow {
        id,
        ulid,
        started_at,
        ended_at,
        engine_base_url,
        cli_version,
        parent_ulid,
    })
}
/// Parses an RFC 3339 timestamp and converts it to UTC.
///
/// This never fails: text that does not parse yields the current time instead.
fn parse_rfc3339(s: &str) -> DateTime<Utc> {
    match DateTime::parse_from_rfc3339(s) {
        Ok(parsed) => parsed.with_timezone(&Utc),
        // Best-effort fallback for unreadable stored values.
        Err(_) => Utc::now(),
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Two independent in-memory databases never see each other's schema, so opening two
// stores does not exercise a re-run. Run the migrations a second time against the
// SAME connection instead: that must succeed on an already-migrated schema.
#[test]
fn migrations_are_idempotent() {
    let store = Store::open_in_memory().unwrap();
    let mut conn = store.conn.lock().unwrap();
    migrations::migrations::runner()
        .run(&mut *conn)
        .expect("re-running migrations on an already-migrated database must succeed");
}
// Appending three events yields seq 1..=3 and `list_events` returns them in order.
#[test]
fn append_and_list_round_trip() {
    let store = Store::open_in_memory().unwrap();
    let session = store
        .start_session("01HTEST", Some("http://x"), "0.3.0", None)
        .unwrap();

    let first = store.append(session, EventKind::Prompt, "> hello").unwrap();
    let second = store.append(session, EventKind::System, "welcome").unwrap();
    let third = store.append(session, EventKind::Command, "status ok").unwrap();
    assert_eq!((first, second, third), (1, 2, 3));

    let listed = store.list_events(session, 100).unwrap();
    assert_eq!(listed.len(), 3);
    assert_eq!(listed.first().unwrap().kind, EventKind::Prompt);
    assert_eq!(listed.last().unwrap().text, "status ok");
}
// With more events than `limit`, `list_events` returns the newest ones, oldest first.
#[test]
fn list_events_returns_tail_in_order() {
    let store = Store::open_in_memory().unwrap();
    let session = store.start_session("01HTAIL", None, "0.3.0", None).unwrap();
    (0..10).for_each(|i| {
        let line = format!("line {i}");
        store.append(session, EventKind::System, &line).unwrap();
    });

    let newest = store.list_events(session, 3).unwrap();
    assert_eq!(newest.len(), 3);
    assert_eq!(newest.first().unwrap().text, "line 7");
    assert_eq!(newest.last().unwrap().text, "line 9");
}
// `last_session` returns the session that was started later.
#[test]
fn last_session_is_most_recent() {
    let store = Store::open_in_memory().unwrap();
    store.start_session("01HA", None, "0.3.0", None).unwrap();
    // Keep the two `started_at` stamps apart.
    let gap = std::time::Duration::from_millis(2);
    std::thread::sleep(gap);
    store.start_session("01HB", None, "0.3.0", None).unwrap();

    let newest = store.last_session().unwrap().unwrap();
    assert_eq!(newest.ulid, "01HB");
}
// Ending a session twice must keep the first `ended_at`.
#[test]
fn end_session_is_idempotent() {
    let s = Store::open_in_memory().unwrap();
    let sid = s.start_session("01HEND", None, "0.3.0", None).unwrap();
    s.end_session(sid).unwrap();
    let first_ended = s.last_session().unwrap().unwrap().ended_at;
    // Without this the comparison below would also pass if `ended_at` were never
    // written at all (`None == None`).
    assert!(
        first_ended.is_some(),
        "end_session must record ended_at on the first call"
    );
    s.end_session(sid).unwrap();
    let second_ended = s.last_session().unwrap().unwrap().ended_at;
    assert_eq!(first_ended, second_ended);
}
// A milestone is absent until set, and a second `set_milestone` overwrites the first.
#[test]
fn milestones_upsert_and_read() {
    let store = Store::open_in_memory().unwrap();
    let key = "welcome_shown";

    let before = store.get_milestone(key).unwrap();
    assert_eq!(before, None);

    store.set_milestone(key, "true").unwrap();
    let after_insert = store.get_milestone(key).unwrap();
    assert_eq!(after_insert.as_deref(), Some("true"));

    store.set_milestone(key, "skipped").unwrap();
    let after_update = store.get_milestone(key).unwrap();
    assert_eq!(after_update.as_deref(), Some("skipped"));
}
// Bypasses `append` and inserts a row with a kind the schema should refuse.
#[test]
fn unknown_event_kind_is_rejected_by_schema() {
    let store = Store::open_in_memory().unwrap();
    let session = store.start_session("01HBAD", None, "0.3.0", None).unwrap();

    let db = store.conn.lock().unwrap();
    let at = Utc::now().to_rfc3339();
    let sql = "INSERT INTO events (session_id, seq, at, kind, text) VALUES (?1, 1, ?2, ?3, ?4)";
    let outcome = db.execute(sql, params![session, at, "bogus", "x"]);

    assert!(outcome.is_err(), "CHECK constraint should reject unknown kind");
}
// `list_sessions` orders newest-first, truncates to `limit`, and handles `limit == 0`.
#[test]
fn list_sessions_honors_limit_and_is_newest_first() {
    let store = Store::open_in_memory().unwrap();
    for (index, ulid) in ["01HA", "01HB", "01HC"].into_iter().enumerate() {
        // Separate consecutive `started_at` stamps.
        if index > 0 {
            std::thread::sleep(std::time::Duration::from_millis(2));
        }
        store.start_session(ulid, None, "0.3.0", None).unwrap();
    }

    let everything = store.list_sessions(10).unwrap();
    assert_eq!(everything.len(), 3);
    let ulids: Vec<&str> = everything.iter().map(|row| row.ulid.as_str()).collect();
    assert_eq!(
        ulids,
        vec!["01HC", "01HB", "01HA"],
        "list_sessions must be newest-first",
    );

    let newest_only = store.list_sessions(1).unwrap();
    assert_eq!(newest_only.len(), 1);
    assert_eq!(newest_only[0].ulid, "01HC");

    let none = store.list_sessions(0).unwrap();
    assert!(none.is_empty());
}
// A stored ulid is found; an unknown ulid yields `None` rather than an error.
#[test]
fn get_session_by_ulid_round_trips_and_misses_cleanly() {
    let store = Store::open_in_memory().unwrap();
    store
        .start_session("01HFOUND", Some("http://e"), "0.3.0", None)
        .unwrap();

    let found = store.get_session_by_ulid("01HFOUND").unwrap();
    assert!(found.is_some());
    let row = found.unwrap();
    assert_eq!(row.ulid, "01HFOUND");

    let missing = store.get_session_by_ulid("01HMISSING").unwrap();
    assert!(missing.is_none());
}
// `count_events` starts at zero and tracks the number of appended events.
#[test]
fn count_events_matches_append_count() {
    let store = Store::open_in_memory().unwrap();
    let session = store.start_session("01HCNT", None, "0.3.0", None).unwrap();

    let before = store.count_events(session).unwrap();
    assert_eq!(before, 0);

    (0..5).for_each(|i| {
        let line = format!("line {i}");
        store.append(session, EventKind::System, &line).unwrap();
    });

    let after = store.count_events(session).unwrap();
    assert_eq!(after, 5);
}
// A session started with a `parent_ulid` stores that link.
#[test]
fn parent_ulid_records_fork_link() {
    let s = Store::open_in_memory().unwrap();
    let _parent = s.start_session("01HP", None, "0.3.0", None).unwrap();
    let _child = s
        .start_session("01HC", None, "0.3.0", Some("01HP"))
        .unwrap();
    // Fetch the child by ulid instead of via `last_session`: parent and child are
    // created back to back, so picking "the latest" would depend on their
    // timestamps differing.
    let child = s
        .get_session_by_ulid("01HC")
        .unwrap()
        .expect("child session should exist");
    assert_eq!(child.parent_ulid.as_deref(), Some("01HP"));
}
}