use std::sync::{Arc, Mutex};
use chrono::{DateTime, Utc};
use zero_commands::{
ReplayEvent, ReplayKind, SessionError as CmdSessionError, SessionSource, SessionSummary,
};
use zero_session::{EventKind as SessionKind, SessionError, SessionRow, Store, StoredEvent};
use crate::app::log::{EntryKind, LogEntry};
/// Identity of the session that log entries are currently written into.
///
/// One instance is shared (behind `Arc<Mutex<_>>`) by a [`SessionSink`] and
/// every [`SessionAdapter`] created from it, so a fork performed through an
/// adapter redirects the sink's subsequent writes.
#[derive(Clone, Debug, Default)]
struct ActiveSession {
    /// Store row id of the active session; `None` means nothing is recorded.
    row_id: Option<i64>,
    /// ULID string of the active session, paired with `row_id`.
    ulid: Option<String>,
}
/// Write side of session persistence: appends log entries to whichever
/// session is currently active.
///
/// Cloning is cheap; clones share both the store and the active-session state.
#[derive(Clone)]
pub struct SessionSink {
    /// Backing session database.
    store: Arc<Store>,
    /// Active session identity, shared with adapters produced by `adapter()`.
    active: Arc<Mutex<ActiveSession>>,
}
impl std::fmt::Debug for SessionSink {
    /// Shows the active session's row id and ULID; the store is elided.
    ///
    /// A poisoned mutex is recovered instead of unwrapped: `Debug` output may be
    /// produced while already reporting a failure, and panicking there can
    /// escalate to an abort.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let active = self
            .active
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        f.debug_struct("SessionSink")
            .field("row_id", &active.row_id)
            .field("ulid", &active.ulid)
            .finish_non_exhaustive()
    }
}
impl SessionSink {
#[must_use]
pub fn new(store: Arc<Store>, session_id: i64, ulid: String) -> Self {
Self {
store,
active: Arc::new(Mutex::new(ActiveSession {
row_id: Some(session_id),
ulid: Some(ulid),
})),
}
}
#[must_use]
pub fn adapter(&self) -> SessionAdapter {
SessionAdapter {
store: Arc::clone(&self.store),
active: Arc::clone(&self.active),
}
}
pub fn record(&self, entry: &LogEntry) {
let Some(session_id) = self.active.lock().unwrap().row_id else {
return;
};
let kind = to_session_kind(entry.kind);
if let Err(e) = self.store.append(session_id, kind, &entry.text) {
tracing::warn!(err = %e, "session append failed");
}
}
pub fn end(&self) {
if let Some(session_id) = self.active.lock().unwrap().row_id
&& let Err(e) = self.store.end_session(session_id)
{
tracing::warn!(err = %e, "session end failed");
}
}
#[must_use]
pub fn store(&self) -> &Store {
&self.store
}
#[must_use]
pub fn session_id(&self) -> Option<i64> {
self.active.lock().unwrap().row_id
}
#[must_use]
pub fn ulid(&self) -> Option<String> {
self.active.lock().unwrap().ulid.clone()
}
}
/// Read/query side of session persistence, implementing the command layer's
/// `SessionSource` trait (listing, lookup, labels, forking).
///
/// Obtained from [`SessionSink::adapter`]; it shares the sink's active-session
/// state, so forks performed here are seen by the sink.
#[derive(Clone)]
pub struct SessionAdapter {
    /// Backing session database (shared with the originating sink).
    store: Arc<Store>,
    /// Active session identity (shared with the originating sink).
    active: Arc<Mutex<ActiveSession>>,
}
impl std::fmt::Debug for SessionAdapter {
    /// Shows the active session's ULID; the store is elided.
    ///
    /// Recovers from a poisoned mutex instead of panicking, because `Debug`
    /// output may be produced while already handling a failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let active = self
            .active
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        f.debug_struct("SessionAdapter")
            .field("active_ulid", &active.ulid)
            .finish_non_exhaustive()
    }
}
impl SessionAdapter {
    /// Resolves a user-supplied session reference to a stored session row.
    ///
    /// Surrounding whitespace is ignored. The trimmed `needle` is then tried,
    /// in order, as:
    /// 1. an exact session ULID;
    /// 2. a ULID prefix, when it is at least `MIN_PREFIX_LEN` bytes long. Only
    ///    the first `PREFIX_SCAN_LIMIT` rows from `Store::list_sessions` are
    ///    scanned, and the first matching row in that order wins, so an
    ///    ambiguous prefix silently picks one session;
    /// 3. a label saved via `save_label`.
    ///
    /// Returns `Ok(None)` when nothing matches, including for an empty or
    /// all-whitespace needle.
    ///
    /// # Errors
    /// Propagates any [`SessionError`] raised by the store.
    fn resolve_needle(&self, needle: &str) -> Result<Option<SessionRow>, SessionError> {
        const MIN_PREFIX_LEN: usize = 6;
        const PREFIX_SCAN_LIMIT: u32 = 1000;

        // `save_label` stores labels trimmed; trim here as well so a label
        // typed with stray whitespace still resolves.
        let needle = needle.trim();
        if needle.is_empty() {
            return Ok(None);
        }

        if let Some(row) = self.store.get_session_by_ulid(needle)? {
            return Ok(Some(row));
        }
        if needle.len() >= MIN_PREFIX_LEN {
            let rows = self.store.list_sessions(PREFIX_SCAN_LIMIT)?;
            if let Some(hit) = rows.into_iter().find(|r| r.ulid.starts_with(needle)) {
                return Ok(Some(hit));
            }
        }
        let key = label_key(needle);
        if let Some(ulid) = self.store.get_milestone(&key)?
            && let Some(row) = self.store.get_session_by_ulid(&ulid)?
        {
            return Ok(Some(row));
        }
        Ok(None)
    }
}
impl SessionSource for SessionAdapter {
    /// ULID of the session currently being recorded, if any.
    fn current_ulid(&self) -> Option<String> {
        // Recover from poisoning: the guarded data is two plain `Option`s.
        self.active
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .ulid
            .clone()
    }

    /// Lists up to `limit` sessions (in `Store::list_sessions` order), each
    /// with its event count. Issues one count query per row.
    fn list(&self, limit: u32) -> Result<Vec<SessionSummary>, CmdSessionError> {
        let rows = self.store.list_sessions(limit).map_err(io_err)?;
        let mut out = Vec::with_capacity(rows.len());
        for row in rows {
            let n_events = self.store.count_events(row.id).map_err(io_err)?;
            out.push(row_to_summary(row, n_events));
        }
        Ok(out)
    }

    /// Finds a session by exact ULID, ULID prefix, or saved label.
    ///
    /// Returns `CmdSessionError::NotFound` when nothing matches.
    fn find(&self, needle: &str) -> Result<SessionSummary, CmdSessionError> {
        let row = self
            .resolve_needle(needle)
            .map_err(io_err)?
            .ok_or(CmdSessionError::NotFound)?;
        let n_events = self.store.count_events(row.id).map_err(io_err)?;
        Ok(row_to_summary(row, n_events))
    }

    /// Returns up to `limit` events of the session with exact ULID `ulid`.
    fn list_events(&self, ulid: &str, limit: u32) -> Result<Vec<ReplayEvent>, CmdSessionError> {
        let row = self
            .store
            .get_session_by_ulid(ulid)
            .map_err(io_err)?
            .ok_or(CmdSessionError::NotFound)?;
        let events = self.store.list_events(row.id, limit).map_err(io_err)?;
        Ok(events.into_iter().map(stored_to_replay).collect())
    }

    /// Stores `label` (trimmed) as a name for session `ulid`.
    ///
    /// Rejects an empty or all-whitespace label with `CmdSessionError::Io`,
    /// and an unknown `ulid` with `CmdSessionError::NotFound`.
    fn save_label(&self, ulid: &str, label: &str) -> Result<(), CmdSessionError> {
        let trimmed = label.trim();
        if trimmed.is_empty() {
            return Err(CmdSessionError::Io("empty label".into()));
        }
        // A label pointing at a missing session could never be resolved:
        // `find` re-fetches the labelled ULID and would report NotFound.
        if self
            .store
            .get_session_by_ulid(ulid)
            .map_err(io_err)?
            .is_none()
        {
            return Err(CmdSessionError::NotFound);
        }
        self.store
            .set_milestone(&label_key(trimmed), ulid)
            .map_err(io_err)
    }

    /// Starts a new session whose parent is the active one and makes it active
    /// (for this adapter and its sink). Returns the new ULID, or `None` when no
    /// session is active.
    fn fork_from_current(&self) -> Result<Option<String>, CmdSessionError> {
        // Hold the lock for the whole fork: releasing it between reading the
        // parent and publishing the child would let two concurrent forks share
        // a parent, with one child silently overwritten.
        let mut active = self
            .active
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let Some(parent_ulid) = active.ulid.clone() else {
            return Ok(None);
        };
        let new_ulid = new_ulid();
        let new_row_id = self
            .store
            .start_session(
                &new_ulid,
                None,
                env!("CARGO_PKG_VERSION"),
                Some(&parent_ulid),
            )
            .map_err(io_err)?;
        active.row_id = Some(new_row_id);
        active.ulid = Some(new_ulid.clone());
        Ok(Some(new_ulid))
    }
}
/// Maps a persisted event kind to the log-entry kind used to redisplay it.
///
/// `EntryKind` has no mode-change variant, so `ModeChange` events are shown
/// as `System` entries.
#[must_use]
pub fn to_entry_kind(k: SessionKind) -> EntryKind {
    match k {
        SessionKind::Alert => EntryKind::Alert,
        SessionKind::Warn => EntryKind::Warn,
        SessionKind::Command => EntryKind::Command,
        SessionKind::ModeChange | SessionKind::System => EntryKind::System,
        SessionKind::Prompt => EntryKind::Prompt,
    }
}
/// Maps a live log-entry kind to the kind persisted in the session store.
///
/// One-to-one; never produces `SessionKind::ModeChange`.
fn to_session_kind(k: EntryKind) -> SessionKind {
    match k {
        EntryKind::Alert => SessionKind::Alert,
        EntryKind::Warn => SessionKind::Warn,
        EntryKind::Command => SessionKind::Command,
        EntryKind::System => SessionKind::System,
        EntryKind::Prompt => SessionKind::Prompt,
    }
}
/// Rebuilds log entries from stored events, preserving order and each event's
/// original timestamp.
#[must_use]
pub fn replay(events: &[StoredEvent]) -> Vec<LogEntry> {
    let mut entries = Vec::with_capacity(events.len());
    for event in events {
        let kind = to_entry_kind(event.kind);
        entries.push(LogEntry::new(kind, &event.text).at(event.at));
    }
    entries
}
/// Builds the one-line banner shown when resuming `row`, for example
/// `resuming: 2024-01-02 03:04 UTC · ended · 12 prior event(s)`.
///
/// A session without an end timestamp is reported as "interrupted".
#[must_use]
pub fn summarize(row: &SessionRow, n_events: usize) -> String {
    let status = row.ended_at.as_ref().map_or("interrupted", |_| "ended");
    let ts = row.started_at.format("%Y-%m-%d %H:%M UTC");
    format!("resuming: {ts} · {status} · {n_events} prior event(s)")
}
/// Converts a store row plus its event count into the command layer's
/// summary, with timestamps expressed as Unix epoch milliseconds.
fn row_to_summary(row: SessionRow, n_events: i64) -> SessionSummary {
    let started_at_ms = row.started_at.timestamp_millis();
    let ended_at_ms = row.ended_at.map(|ended| ended.timestamp_millis());
    SessionSummary {
        ulid: row.ulid,
        started_at_ms,
        ended_at_ms,
        engine_base_url: row.engine_base_url,
        cli_version: row.cli_version,
        parent_ulid: row.parent_ulid,
        n_events,
    }
}
/// Converts a stored event into the command layer's replay event, with the
/// timestamp as Unix epoch milliseconds.
fn stored_to_replay(event: StoredEvent) -> ReplayEvent {
    let kind = stored_kind_to_replay(event.kind);
    let at_ms = event.at.timestamp_millis();
    ReplayEvent {
        kind,
        at_ms,
        text: event.text,
    }
}
/// Maps a persisted event kind to the command layer's replay kind.
///
/// `ModeChange` events are folded into `ReplayKind::System`.
fn stored_kind_to_replay(k: SessionKind) -> ReplayKind {
    match k {
        SessionKind::Alert => ReplayKind::Alert,
        SessionKind::Warn => ReplayKind::Warn,
        SessionKind::Command => ReplayKind::Command,
        SessionKind::ModeChange | SessionKind::System => ReplayKind::System,
        SessionKind::Prompt => ReplayKind::Prompt,
    }
}
/// Converts a store-level [`SessionError`] into `CmdSessionError::Io`, keeping
/// only its `Display` text.
// By-value parameter is deliberate: it lets the function be passed directly
// as `map_err(io_err)`.
#[allow(clippy::needless_pass_by_value)]
fn io_err(e: SessionError) -> CmdSessionError {
    let message = e.to_string();
    CmdSessionError::Io(message)
}
/// Builds the milestone key under which a session label is stored.
fn label_key(label: &str) -> String {
    const PREFIX: &str = "session.label.";
    let mut key = String::with_capacity(PREFIX.len() + label.len());
    key.push_str(PREFIX);
    key.push_str(label);
    key
}
/// Generates a new, time-ordered session identifier.
///
/// Format: 13 lowercase hex digits of Unix epoch milliseconds (zero-padded),
/// followed by 6 pseudo-random hex digits, for 19 characters in total. Despite
/// the name, this is not a spec ULID (which is 26 Crockford base32 chars). A
/// clock set before the epoch yields a zero timestamp.
fn new_ulid() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    let ms = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis(),
        Err(_) => 0,
    };
    let rand = fastrand_hex(6);
    format!("{ms:013x}{rand}")
}
/// Returns `n` pseudo-random lowercase hexadecimal digits.
///
/// Uses a 64-bit LCG (Knuth's MMIX constants) and emits the top four bits of
/// each step. The seed mixes the wall clock with a process-wide call counter,
/// so two calls that observe the same clock reading still get distinct seeds.
/// This matters on platforms whose `SystemTime` resolution is coarser than a
/// nanosecond. Not cryptographically secure; it only needs to make identifier
/// collisions unlikely.
fn fastrand_hex(n: usize) -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Fallback seed and counter spreader (odd, roughly 2^64 / golden ratio).
    const GOLDEN: u64 = 0x9E37_79B9_7F4A_7C15;
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    static CALLS: AtomicU64 = AtomicU64::new(0);

    let clock = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(GOLDEN, |d| u64::try_from(d.as_nanos()).unwrap_or(GOLDEN));
    // Relaxed suffices: only the atomicity of the increment matters here, not
    // ordering with respect to other memory.
    let call = CALLS.fetch_add(1, Ordering::Relaxed);
    // Multiplying by an odd constant is a bijection mod 2^64, so distinct call
    // numbers give distinct seeds for an identical clock reading.
    let mut state = clock ^ call.wrapping_mul(GOLDEN);
    (0..n)
        .map(|_| {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            // `state >> 60` is always in 0..=15, so neither step can fail.
            let nibble = u8::try_from(state >> 60).unwrap_or(0);
            char::from(HEX_DIGITS[usize::from(nibble)])
        })
        .collect()
}
/// Not called in this module. It exists only so that `DateTime` and `Utc`,
/// which nothing else here names, are referenced and their `use` does not
/// trigger an unused-import warning.
#[allow(dead_code)]
fn _nudge(_unused: DateTime<Utc>) {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a fresh in-memory store, starts a root session `ulid` in it and
    /// wraps that session in a sink. The store handle is returned alongside
    /// so tests can inspect rows directly.
    fn sink_for(ulid: &str) -> (Arc<Store>, SessionSink) {
        let store = Arc::new(Store::open_in_memory().unwrap());
        let row_id = store.start_session(ulid, None, "0.3.0", None).unwrap();
        let sink = SessionSink::new(Arc::clone(&store), row_id, ulid.to_owned());
        (store, sink)
    }

    #[test]
    fn adapter_current_ulid_tracks_active_session() {
        let (_store, sink) = sink_for("01HX");
        let adapter = sink.adapter();
        assert_eq!(adapter.current_ulid().as_deref(), Some("01HX"));
    }

    #[test]
    fn adapter_fork_swaps_active_ulid_and_links_parent() {
        let (store, sink) = sink_for("01HPARENT");
        let adapter = sink.adapter();
        let child_ulid = adapter
            .fork_from_current()
            .unwrap()
            .expect("fork produced ulid");
        // Adapter and sink share one active-session cell.
        assert_eq!(adapter.current_ulid(), Some(child_ulid.clone()));
        assert_eq!(
            sink.active.lock().unwrap().ulid.as_deref(),
            Some(child_ulid.as_str()),
            "sink must see the fork under it",
        );
        let child_row = store.get_session_by_ulid(&child_ulid).unwrap().unwrap();
        assert_eq!(child_row.parent_ulid.as_deref(), Some("01HPARENT"));
    }

    #[test]
    fn adapter_save_label_then_find_by_label() {
        let (_store, sink) = sink_for("01HLBL");
        let adapter = sink.adapter();
        adapter.save_label("01HLBL", "pre-cpi").unwrap();
        assert_eq!(adapter.find("pre-cpi").unwrap().ulid, "01HLBL");
    }

    #[test]
    fn adapter_find_missing_returns_not_found() {
        let (_store, sink) = sink_for("01HX");
        let err = sink.adapter().find("nope").unwrap_err();
        assert!(matches!(err, CmdSessionError::NotFound));
    }

    #[test]
    fn adapter_save_rejects_empty_label() {
        let (_store, sink) = sink_for("01HE");
        let err = sink.adapter().save_label("01HE", " ").unwrap_err();
        assert!(matches!(err, CmdSessionError::Io(_)));
    }
}