use anyhow::Result;
use rusqlite::Connection;
use crate::state::session::{self, SessionMode, SessionOwnerKind, SessionStateFile};
/// Outcome of a compare-and-swap session write ([`write_if_revision_matches`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ExclusiveWriteResult {
    /// The row was written.
    Applied {
        /// Revision persisted by this write.
        revision: u64,
    },
    /// The stored revision did not match the caller's expectation; nothing was written.
    RevisionConflict {
        /// Revision found in storage when the conflict was detected (`0` when no row exists).
        current_revision: u64,
    },
}
pub(crate) fn read(conn: &Connection) -> Result<Option<SessionStateFile>> {
let mut stmt = conn.prepare(
"SELECT schema_version, started_at_epoch_s, last_started_at_epoch_s,
start_count, session_id, mode, owner_kind, owner_id,
supervisor_id, lease_ttl_secs, last_heartbeat_at_epoch_s, revision
FROM session WHERE id = 1",
)?;
let result = stmt.query_row([], |row| {
let mode: String = row.get(5)?;
let mode = SessionMode::from_str(&mode).ok_or_else(|| {
rusqlite::Error::FromSqlConversionFailure(
5,
rusqlite::types::Type::Text,
Box::new(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("invalid session mode `{mode}`"),
)),
)
})?;
let owner_kind = match row.get::<_, Option<String>>(6)? {
Some(owner_kind) => SessionOwnerKind::from_str(&owner_kind).ok_or_else(|| {
rusqlite::Error::FromSqlConversionFailure(
6,
rusqlite::types::Type::Text,
Box::new(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("invalid session owner_kind `{owner_kind}`"),
)),
)
})?,
None => SessionOwnerKind::Interactive,
};
let session_id: Option<String> = row.get(4)?;
let owner_id = match row.get::<_, Option<String>>(7)? {
Some(owner_id) => Some(owner_id),
None if session_id.is_some() && owner_kind == SessionOwnerKind::Interactive => {
Some("interactive".to_owned())
}
None => None,
};
Ok(SessionStateFile {
schema_version: row.get(0)?,
started_at_epoch_s: row.get(1)?,
last_started_at_epoch_s: row.get(2)?,
start_count: row.get(3)?,
session_id,
mode,
owner_kind,
owner_id,
supervisor_id: row.get(8)?,
lease_ttl_secs: row.get(9)?,
last_heartbeat_at_epoch_s: row.get(10)?,
revision: row.get(11)?,
})
});
match result {
Ok(state) => Ok(Some(state)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(e.into()),
}
}
pub(crate) fn write(conn: &Connection, state: &SessionStateFile) -> Result<()> {
let owner_id = match state.owner_kind {
SessionOwnerKind::Interactive => state
.owner_id
.clone()
.or_else(|| state.session_id.as_ref().map(|_| "interactive".to_owned())),
SessionOwnerKind::RuntimeSupervisor | SessionOwnerKind::RuntimeWorker => Some(
state
.owner_id
.clone()
.ok_or_else(|| anyhow::anyhow!("autonomous session rows require owner_id"))?,
),
};
if state.owner_kind.lifecycle() == session::SessionLifecycle::Autonomous
&& (state.lease_ttl_secs.is_none() || state.last_heartbeat_at_epoch_s.is_none())
{
return Err(anyhow::anyhow!(
"autonomous session rows require lease_ttl_secs and last_heartbeat_at_epoch_s"
));
}
conn.execute(
"INSERT OR REPLACE INTO session
(id, schema_version, started_at_epoch_s, last_started_at_epoch_s,
start_count, session_id, mode, owner_kind, owner_id, supervisor_id,
lease_ttl_secs, last_heartbeat_at_epoch_s, revision)
VALUES (1, ?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
rusqlite::params![
state.schema_version,
state.started_at_epoch_s,
state.last_started_at_epoch_s,
state.start_count,
state.session_id,
state.mode.as_str(),
state.owner_kind.as_str(),
owner_id,
state.supervisor_id,
state.lease_ttl_secs,
state.last_heartbeat_at_epoch_s,
state.revision,
],
)?;
Ok(())
}
pub(crate) fn delete(conn: &Connection) -> Result<()> {
conn.execute("DELETE FROM session WHERE id = 1", [])?;
Ok(())
}
pub(crate) fn write_if_revision_matches(
conn: &Connection,
state: &SessionStateFile,
expected_revision: u64,
) -> Result<ExclusiveWriteResult> {
let owner_id = match state.owner_kind {
SessionOwnerKind::Interactive => state
.owner_id
.clone()
.or_else(|| state.session_id.as_ref().map(|_| "interactive".to_owned())),
SessionOwnerKind::RuntimeSupervisor | SessionOwnerKind::RuntimeWorker => Some(
state
.owner_id
.clone()
.ok_or_else(|| anyhow::anyhow!("autonomous session rows require owner_id"))?,
),
};
if state.owner_kind.lifecycle() == session::SessionLifecycle::Autonomous
&& (state.lease_ttl_secs.is_none() || state.last_heartbeat_at_epoch_s.is_none())
{
return Err(anyhow::anyhow!(
"autonomous session rows require lease_ttl_secs and last_heartbeat_at_epoch_s"
));
}
let expected_next_revision = expected_revision.saturating_add(1).max(1);
if state.revision != expected_next_revision {
return Err(anyhow::anyhow!(
"session CAS write rejected: state.revision {} does not equal expected_revision + 1 ({}); callers must derive next.revision = next_revision(&fresh) inside the same transaction as the read",
state.revision,
expected_next_revision
));
}
let row_exists = conn
.query_row("SELECT 1 FROM session WHERE id = 1", [], |_| Ok(()))
.map(|()| true)
.or_else(|e| match e {
rusqlite::Error::QueryReturnedNoRows => Ok(false),
other => Err(other),
})?;
if !row_exists {
if expected_revision != 0 {
return Ok(ExclusiveWriteResult::RevisionConflict {
current_revision: 0,
});
}
let inserted = conn.execute(
"INSERT INTO session
(id, schema_version, started_at_epoch_s, last_started_at_epoch_s,
start_count, session_id, mode, owner_kind, owner_id, supervisor_id,
lease_ttl_secs, last_heartbeat_at_epoch_s, revision)
VALUES (1, ?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
rusqlite::params![
state.schema_version,
state.started_at_epoch_s,
state.last_started_at_epoch_s,
state.start_count,
state.session_id,
state.mode.as_str(),
state.owner_kind.as_str(),
owner_id,
state.supervisor_id,
state.lease_ttl_secs,
state.last_heartbeat_at_epoch_s,
state.revision,
],
)?;
return if inserted == 1 {
Ok(ExclusiveWriteResult::Applied {
revision: state.revision,
})
} else {
Ok(ExclusiveWriteResult::RevisionConflict {
current_revision: current_revision(conn)?,
})
};
}
let updated = conn.execute(
"UPDATE session
SET schema_version = ?1,
started_at_epoch_s = ?2,
last_started_at_epoch_s = ?3,
start_count = ?4,
session_id = ?5,
mode = ?6,
owner_kind = ?7,
owner_id = ?8,
supervisor_id = ?9,
lease_ttl_secs = ?10,
last_heartbeat_at_epoch_s = ?11,
revision = ?12
WHERE id = 1 AND revision = ?13",
rusqlite::params![
state.schema_version,
state.started_at_epoch_s,
state.last_started_at_epoch_s,
state.start_count,
state.session_id,
state.mode.as_str(),
state.owner_kind.as_str(),
owner_id,
state.supervisor_id,
state.lease_ttl_secs,
state.last_heartbeat_at_epoch_s,
state.revision,
expected_revision,
],
)?;
if updated == 1 {
Ok(ExclusiveWriteResult::Applied {
revision: state.revision,
})
} else {
Ok(ExclusiveWriteResult::RevisionConflict {
current_revision: current_revision(conn)?,
})
}
}
fn current_revision(conn: &Connection) -> Result<u64> {
let mut stmt = conn.prepare("SELECT revision FROM session WHERE id = 1")?;
match stmt.query_row([], |row| row.get::<_, u64>(0)) {
Ok(revision) => Ok(revision),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(0),
Err(e) => Err(e.into()),
}
}
/// Returns the current session id, or `None` in any of these cases:
/// * there is no session row;
/// * the row is stale at `now_epoch_s` according to [`session::is_stale`];
/// * the row carries no `session_id`.
pub(crate) fn load_session_id(conn: &Connection, now_epoch_s: u64) -> Result<Option<String>> {
    let live_state = read(conn)?.filter(|state| !session::is_stale(state, now_epoch_s));
    Ok(live_state.and_then(|state| state.session_id))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::schema;
    use crate::state::session::{SessionMode, SessionOwnerKind};

    /// Fresh in-memory database with the project schema applied.
    fn test_conn() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        schema::initialize(&conn, None).unwrap();
        conn
    }

    /// Interactive-owner state fixture. `owner_id` and `revision` follow
    /// `session_id`: `Some` yields owner `"interactive"` and revision 1; `None`
    /// yields no owner and revision 0.
    fn interactive_state(
        schema_version: u32,
        started_at_epoch_s: u64,
        last_started_at_epoch_s: u64,
        start_count: u32,
        session_id: Option<&str>,
        mode: SessionMode,
    ) -> SessionStateFile {
        SessionStateFile {
            schema_version,
            started_at_epoch_s,
            last_started_at_epoch_s,
            start_count,
            session_id: session_id.map(str::to_owned),
            mode,
            owner_kind: SessionOwnerKind::Interactive,
            owner_id: session_id.map(|_| "interactive".to_owned()),
            supervisor_id: None,
            lease_ttl_secs: None,
            last_heartbeat_at_epoch_s: None,
            revision: u64::from(session_id.is_some()),
        }
    }

    #[test]
    fn read_returns_none_when_empty() {
        let conn = test_conn();
        assert!(read(&conn).unwrap().is_none());
    }

    #[test]
    fn write_and_read_roundtrip() {
        let conn = test_conn();
        let state = interactive_state(
            3,
            1_000_000,
            1_000_050,
            2,
            Some("ses_01ABC"),
            SessionMode::Research,
        );
        write(&conn, &state).unwrap();
        let loaded = read(&conn).unwrap().expect("should exist");
        assert_eq!(loaded.schema_version, 3);
        assert_eq!(loaded.started_at_epoch_s, 1_000_000);
        assert_eq!(loaded.last_started_at_epoch_s, 1_000_050);
        assert_eq!(loaded.start_count, 2);
        assert_eq!(loaded.session_id.as_deref(), Some("ses_01ABC"));
        assert_eq!(loaded.mode, SessionMode::Research);
    }

    #[test]
    fn write_without_session_id() {
        let conn = test_conn();
        let state = interactive_state(3, 1_000_000, 1_000_000, 1, None, SessionMode::General);
        write(&conn, &state).unwrap();
        let loaded = read(&conn).unwrap().expect("should exist");
        assert!(loaded.session_id.is_none());
    }

    #[test]
    fn write_fills_interactive_owner_id_from_session_id() {
        let conn = test_conn();
        let mut state = interactive_state(
            3,
            1_000_000,
            1_000_000,
            1,
            Some("ses_OWNER"),
            SessionMode::General,
        );
        state.owner_id = None;
        write(&conn, &state).unwrap();
        let loaded = read(&conn).unwrap().expect("should exist");
        assert!(loaded.owner_kind == SessionOwnerKind::Interactive);
        assert_eq!(loaded.owner_id.as_deref(), Some("interactive"));
    }

    #[test]
    fn write_rejects_runtime_owner_without_owner_id() {
        let conn = test_conn();
        let mut state = interactive_state(
            3,
            1_000_000,
            1_000_000,
            1,
            Some("ses_RUNTIME"),
            SessionMode::General,
        );
        state.owner_kind = SessionOwnerKind::RuntimeWorker;
        state.owner_id = None;
        assert!(write(&conn, &state).is_err());
        assert!(read(&conn).unwrap().is_none());
    }

    #[test]
    fn delete_removes_session() {
        let conn = test_conn();
        let state = interactive_state(
            3,
            1_000_000,
            1_000_000,
            1,
            Some("ses_DEL"),
            SessionMode::General,
        );
        write(&conn, &state).unwrap();
        delete(&conn).unwrap();
        assert!(read(&conn).unwrap().is_none());
    }

    #[test]
    fn load_session_id_returns_none_for_stale() {
        use crate::state::session::STALE_AFTER_SECS;
        let conn = test_conn();
        let now = 1_000_000u64;
        let state = interactive_state(
            3,
            now - STALE_AFTER_SECS - 200,
            now - STALE_AFTER_SECS - 100,
            1,
            Some("ses_STALE"),
            SessionMode::General,
        );
        write(&conn, &state).unwrap();
        assert!(load_session_id(&conn, now).unwrap().is_none());
    }

    #[test]
    fn load_session_id_returns_id_for_active() {
        let conn = test_conn();
        let now = 1_000_000u64;
        let state = interactive_state(
            3,
            now - 100,
            now - 50,
            1,
            Some("ses_ACTIVE"),
            SessionMode::General,
        );
        write(&conn, &state).unwrap();
        assert_eq!(
            load_session_id(&conn, now).unwrap().as_deref(),
            Some("ses_ACTIVE")
        );
    }

    #[test]
    fn write_if_revision_matches_inserts_first_row() {
        let conn = test_conn();
        let state = interactive_state(
            3,
            1_000_000,
            1_000_000,
            1,
            Some("ses_FIRST"),
            SessionMode::General,
        );
        assert_eq!(state.revision, 1, "fixture invariant: first session row is revision 1");
        let result = write_if_revision_matches(&conn, &state, 0).unwrap();
        assert_eq!(result, ExclusiveWriteResult::Applied { revision: 1 });
        let loaded = read(&conn).unwrap().expect("row must exist");
        assert_eq!(loaded.revision, 1);
        assert_eq!(loaded.session_id.as_deref(), Some("ses_FIRST"));
    }

    #[test]
    fn write_if_revision_matches_rejects_non_successor_revision() {
        let conn = test_conn();
        let mut state = interactive_state(
            3,
            1_000_000,
            1_000_000,
            1,
            Some("ses_SKIP"),
            SessionMode::General,
        );
        state.revision = 3;
        assert!(write_if_revision_matches(&conn, &state, 1).is_err());
        assert!(read(&conn).unwrap().is_none());
    }

    #[test]
    fn write_if_revision_matches_missing_row_with_positive_expected_is_conflict() {
        let conn = test_conn();
        let mut state = interactive_state(
            3,
            1_000_000,
            1_000_050,
            2,
            Some("ses_GHOST"),
            SessionMode::General,
        );
        state.revision = 5;
        let result = write_if_revision_matches(&conn, &state, 4).unwrap();
        assert!(
            matches!(
                result,
                ExclusiveWriteResult::RevisionConflict {
                    current_revision: 0
                }
            ),
            "missing row + expected > 0 must be RevisionConflict{{0}}, got {result:?}"
        );
        assert!(read(&conn).unwrap().is_none());
    }

    #[test]
    fn write_if_revision_matches_updates_legacy_migrated_row_at_revision_zero() {
        let conn = test_conn();
        let legacy = interactive_state(3, 900_000, 900_000, 3, None, SessionMode::General);
        assert_eq!(
            legacy.revision, 0,
            "fixture invariant: migrated-without-session_id rows start at revision 0"
        );
        write(&conn, &legacy).unwrap();
        let mut successor = interactive_state(
            3,
            900_000,
            1_000_000,
            4,
            Some("ses_POST_MIGRATION"),
            SessionMode::General,
        );
        successor.revision = 1;
        let result = write_if_revision_matches(&conn, &successor, 0).unwrap();
        assert!(
            matches!(result, ExclusiveWriteResult::Applied { revision: 1 }),
            "row-exists-at-revision-0 + expected=0 must UPDATE-with-CAS, got {result:?}"
        );
        let reloaded = read(&conn).unwrap().expect("row must still exist");
        assert_eq!(reloaded.revision, 1);
        assert_eq!(reloaded.session_id.as_deref(), Some("ses_POST_MIGRATION"));
        assert_eq!(reloaded.start_count, 4);
    }

    #[test]
    fn write_if_revision_matches_advances_existing_row() {
        let conn = test_conn();
        let mut seeded = interactive_state(
            3,
            1_000_000,
            1_000_050,
            2,
            Some("ses_SEEDED"),
            SessionMode::General,
        );
        seeded.revision = 6;
        write(&conn, &seeded).unwrap();
        let mut next = seeded.clone();
        next.revision = 7;
        next.start_count = 3;
        let result = write_if_revision_matches(&conn, &next, 6).unwrap();
        assert_eq!(result, ExclusiveWriteResult::Applied { revision: 7 });
        let reloaded = read(&conn).unwrap().expect("row must still exist");
        assert_eq!(reloaded.revision, 7);
        assert_eq!(reloaded.start_count, 3);
    }

    #[test]
    fn write_if_revision_matches_mismatched_revision_is_conflict() {
        let conn = test_conn();
        let mut seeded = interactive_state(
            3,
            1_000_000,
            1_000_050,
            2,
            Some("ses_SEEDED"),
            SessionMode::General,
        );
        seeded.revision = 6;
        write(&conn, &seeded).unwrap();
        let mut loser = seeded.clone();
        loser.revision = 6;
        let result = write_if_revision_matches(&conn, &loser, 5).unwrap();
        assert!(
            matches!(
                result,
                ExclusiveWriteResult::RevisionConflict {
                    current_revision: 6
                }
            ),
            "mismatched revision must be RevisionConflict{{current=6}}, got {result:?}"
        );
        let reloaded = read(&conn).unwrap().expect("row must still exist");
        assert_eq!(reloaded.revision, 6);
    }
}