use jiff::{Timestamp, ToSpan};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use ulid::Ulid;
/// Identifier for a session, wrapping a [`Ulid`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SessionId(pub Ulid);

impl SessionId {
    /// Mints a new identifier from a freshly generated ULID.
    pub fn new() -> Self {
        Self::from_ulid(Ulid::new())
    }

    /// Wraps an already-existing ULID.
    pub fn from_ulid(ulid: Ulid) -> Self {
        SessionId(ulid)
    }

    /// Borrows the wrapped ULID.
    pub fn as_ulid(&self) -> &Ulid {
        let SessionId(inner) = self;
        inner
    }

    /// Parses a ULID string (such as this type's `Display` output) into an id.
    ///
    /// # Errors
    ///
    /// Propagates the [`ulid::DecodeError`] when `s` is not a valid ULID string.
    pub fn from_string(s: &str) -> Result<Self, ulid::DecodeError> {
        Ulid::from_string(s).map(SessionId)
    }
}

impl Default for SessionId {
    /// Same as [`SessionId::new`]: each default is a newly generated id.
    fn default() -> Self {
        SessionId::new()
    }
}

impl std::fmt::Display for SessionId {
    /// Renders the wrapped ULID's `Display` form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let SessionId(ulid) = self;
        write!(f, "{}", ulid)
    }
}
// NOTE(review): serde's derive encodes unit variants by index in index-based
// formats (e.g. bincode), so append new variants rather than reordering them.
/// Lifecycle state of a session.
///
/// `Initializing` is the `Default`. It is also the state `SessionInfo::new`
/// assigns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum SessionState {
/// Initial state; the `Default` variant.
#[default]
Initializing,
Ready,
Executing,
Paused,
Terminated,
/// NOTE(review): `SessionInfo::is_expired` checks the clock, not this state;
/// presumably something else transitions sessions here — confirm.
Expired,
}
/// A session's identity, lifecycle state, timing, budget and attached resources.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
/// Identifier of this session.
pub id: SessionId,
/// Current lifecycle state.
pub state: SessionState,
/// When the session was created.
pub created_at: Timestamp,
/// Instant after which `is_expired` reports `true`.
pub expires_at: Timestamp,
/// Number of extensions granted so far; `can_extend` compares it against a caller-supplied cap.
pub extension_count: u32,
/// NOTE(review): appears to overlap with `budget_status.current_recursion_depth`;
/// confirm which of the two is authoritative.
pub recursion_depth: u32,
/// Token, time and recursion-depth limits together with their consumption.
pub budget_status: BudgetStatus,
/// Identifier of the backing VM instance, if any (presumably set once one is provisioned — confirm).
pub vm_instance_id: Option<String>,
/// Free-form string key/value variables associated with the session.
pub context_variables: HashMap<String, String>,
/// Id of the current snapshot, if any — NOTE(review): confirm against snapshot/rollback handling.
pub current_snapshot_id: Option<String>,
/// Count of snapshots for this session — TODO confirm where it is incremented.
pub snapshot_count: u32,
}
impl SessionInfo {
    /// Creates a session in [`SessionState::Initializing`] with a default
    /// [`BudgetStatus`], expiring `duration_secs` seconds from now.
    ///
    /// A duration that would land past the latest representable instant is
    /// clamped. `expires_at` then saturates at `Timestamp::MAX` (a session that
    /// never expires) instead of panicking or wrapping to an instant in the past.
    pub fn new(id: SessionId, duration_secs: u64) -> Self {
        let now = Timestamp::now();
        // Whole seconds left before `Timestamp::MAX`. The whole timestamp range
        // fits inside the seconds range a `Span` accepts, so clamping to this
        // keeps `.seconds()` from panicking and the sum within range.
        let headroom = Timestamp::MAX.as_second().saturating_sub(now.as_second());
        // `duration_secs as i64` would turn values above `i64::MAX` into
        // negative durations, i.e. sessions that are already expired.
        let secs = i64::try_from(duration_secs)
            .unwrap_or(i64::MAX)
            .min(headroom);
        let expires_at = now
            .checked_add(secs.seconds())
            .unwrap_or(Timestamp::MAX);
        Self {
            id,
            state: SessionState::Initializing,
            created_at: now,
            expires_at,
            extension_count: 0,
            recursion_depth: 0,
            budget_status: BudgetStatus::default(),
            vm_instance_id: None,
            context_variables: HashMap::new(),
            current_snapshot_id: None,
            snapshot_count: 0,
        }
    }

    /// Returns `true` once the current time is strictly after `expires_at`.
    pub fn is_expired(&self) -> bool {
        Timestamp::now() > self.expires_at
    }

    /// Returns `true` while fewer than `max_extensions` extensions have been used.
    pub fn can_extend(&self, max_extensions: u32) -> bool {
        self.extension_count < max_extensions
    }
}
/// Resource limits for a session alongside how much of each has been consumed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetStatus {
/// Maximum tokens allowed; exhausted once `tokens_used` reaches it.
pub token_budget: u64,
/// Tokens consumed so far.
pub tokens_used: u64,
/// Time allowance (milliseconds, per the field name); exhausted once `time_used_ms` reaches it.
pub time_budget_ms: u64,
/// Time consumed so far (milliseconds, per the field name).
pub time_used_ms: u64,
/// Deepest recursion level allowed; exhausted only when `current_recursion_depth` exceeds it.
pub max_recursion_depth: u32,
/// Current recursion level.
pub current_recursion_depth: u32,
}
impl Default for BudgetStatus {
fn default() -> Self {
Self {
token_budget: crate::DEFAULT_TOKEN_BUDGET,
tokens_used: 0,
time_budget_ms: crate::DEFAULT_TIME_BUDGET_MS,
time_used_ms: 0,
max_recursion_depth: crate::DEFAULT_MAX_RECURSION_DEPTH,
current_recursion_depth: 0,
}
}
}
impl BudgetStatus {
    /// `true` when no tokens remain, i.e. `tokens_used >= token_budget`.
    pub fn tokens_exhausted(&self) -> bool {
        self.tokens_remaining() == 0
    }

    /// `true` when no time remains, i.e. `time_used_ms >= time_budget_ms`.
    pub fn time_exhausted(&self) -> bool {
        self.time_remaining_ms() == 0
    }

    /// `true` only once the current depth has gone *past* the maximum; sitting
    /// exactly at `max_recursion_depth` still counts as within budget.
    pub fn depth_exhausted(&self) -> bool {
        self.max_recursion_depth < self.current_recursion_depth
    }

    /// `true` if any of the token, time or depth limits is exhausted.
    pub fn is_exhausted(&self) -> bool {
        [
            self.tokens_exhausted(),
            self.time_exhausted(),
            self.depth_exhausted(),
        ]
        .iter()
        .any(|&hit| hit)
    }

    /// Tokens left before the budget is hit; `0` once it is used up or overrun.
    pub fn tokens_remaining(&self) -> u64 {
        let Self { token_budget, tokens_used, .. } = *self;
        token_budget.saturating_sub(tokens_used)
    }

    /// Time left before the budget is hit; `0` once it is used up or overrun.
    pub fn time_remaining_ms(&self) -> u64 {
        let Self { time_budget_ms, time_used_ms, .. } = *self;
        time_budget_ms.saturating_sub(time_used_ms)
    }
}
// NOTE(review): variant order is part of the encoding for index-based serde
// formats; append new variants rather than reordering.
/// A single action a session can be asked to perform.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Command {
/// Execute a [`BashCommand`].
Run(BashCommand),
/// Execute a [`PythonCode`] snippet.
Code(PythonCode),
/// String payload — NOTE(review): presumably the final answer text; confirm.
Final(String),
/// String payload — NOTE(review): presumably names a variable holding the final answer; confirm.
FinalVar(String),
/// A single [`LlmQuery`].
QueryLlm(LlmQuery),
/// Several [`LlmQuery`]s submitted together.
QueryLlmBatched(Vec<LlmQuery>),
/// Snapshot request; whether the string is a name or an id is not visible here — TODO confirm.
Snapshot(String),
/// Rollback request; presumably identifies the snapshot to restore — TODO confirm.
Rollback(String),
}
/// A command string to execute, with optional timeout and working-directory overrides.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BashCommand {
/// The command text.
pub command: String,
/// Optional timeout (milliseconds, per the field name); `None` presumably means the executor's default — confirm.
pub timeout_ms: Option<u64>,
/// Optional directory to run in; `None` presumably means the executor's default — confirm.
pub working_dir: Option<String>,
}
impl BashCommand {
pub fn new(command: impl Into<String>) -> Self {
Self {
command: command.into(),
timeout_ms: None,
working_dir: None,
}
}
}
/// A Python source snippet to execute, with an optional timeout override.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PythonCode {
/// The source code.
pub code: String,
/// Optional timeout (milliseconds, per the field name); `None` presumably means the executor's default — confirm.
pub timeout_ms: Option<u64>,
}
impl PythonCode {
pub fn new(code: impl Into<String>) -> Self {
Self {
code: code.into(),
timeout_ms: None,
}
}
}
/// A prompt to send to a language model, with optional per-query overrides.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmQuery {
/// The prompt text.
pub prompt: String,
/// Model override; `None` presumably selects a default model — confirm.
pub model: Option<String>,
/// Sampling temperature override, if any.
pub temperature: Option<f32>,
/// Presumably a cap on generated tokens — confirm against the client.
pub max_tokens: Option<u32>,
}
impl LlmQuery {
pub fn new(prompt: impl Into<String>) -> Self {
Self {
prompt: prompt.into(),
model: None,
temperature: None,
max_tokens: None,
}
}
}
/// Bookkeeping for one query within a session, linking sub-queries to their parent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryMetadata {
/// Identifier of this query.
pub query_id: Ulid,
/// Query that spawned this one; `None` for a top-level query.
pub parent_query_id: Option<Ulid>,
/// Session the query belongs to.
pub session_id: SessionId,
/// Iteration counter; starts at 0 for both top-level and child queries.
pub iteration: u32,
/// Nesting level: 0 for a top-level query, one more than the parent for a child.
pub depth: u32,
/// When this metadata was created.
pub started_at: Timestamp,
/// Set by `complete`; `None` until then.
pub completed_at: Option<Timestamp>,
}
impl QueryMetadata {
    /// Starts metadata for a top-level query in `session_id`. It gets a fresh
    /// query id, no parent, iteration and depth of 0, a start time of now, and
    /// no completion time.
    pub fn new(session_id: SessionId) -> Self {
        Self {
            query_id: Ulid::new(),
            parent_query_id: None,
            session_id,
            iteration: 0,
            depth: 0,
            started_at: Timestamp::now(),
            completed_at: None,
        }
    }

    /// Derives metadata for a sub-query spawned by `self`. The child keeps the
    /// same session, records `self.query_id` as its parent, resets iteration
    /// to 0, and sits one level deeper.
    ///
    /// Depth saturates at `u32::MAX`. A plain `+ 1` would panic in debug builds
    /// and wrap to 0 in release, making a runaway chain of sub-queries look
    /// like a fresh top-level query to depth-based limits.
    pub fn child(&self) -> Self {
        Self {
            query_id: Ulid::new(),
            parent_query_id: Some(self.query_id),
            session_id: self.session_id,
            iteration: 0,
            depth: self.depth.saturating_add(1),
            started_at: Timestamp::now(),
            completed_at: None,
        }
    }

    /// Records the current time as the completion instant, replacing any
    /// earlier value.
    pub fn complete(&mut self) {
        self.completed_at = Some(Timestamp::now());
    }
}
/// Ordered log of executed commands, oldest first (`push` appends).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CommandHistory {
/// Entries in insertion order.
pub entries: Vec<CommandHistoryEntry>,
}
impl CommandHistory {
    /// Creates an empty history.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `entry` as the most recent command.
    pub fn push(&mut self, entry: CommandHistoryEntry) {
        self.entries.push(entry);
    }

    /// Index of the most recent entry whose `success` flag is set, or `None`
    /// if no recorded command has succeeded.
    pub fn last_successful_index(&self) -> Option<usize> {
        self.entries.iter().rposition(|e| e.success)
    }

    /// Entries recorded at or after `checkpoint_index`, oldest first.
    ///
    /// An index at or past the end yields an empty slice instead of panicking.
    /// This covers a stale checkpoint, e.g. one recorded before the public
    /// `entries` vector was truncated.
    pub fn since_checkpoint(&self, checkpoint_index: usize) -> &[CommandHistoryEntry] {
        self.entries.get(checkpoint_index..).unwrap_or(&[])
    }
}
/// Record of one executed [`Command`] and its outcome.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommandHistoryEntry {
/// The command that was executed.
pub command: Command,
/// Whether the command succeeded; `CommandHistory::last_successful_index` searches on this flag.
pub success: bool,
/// Text written to standard output (presumably captured by the executor — confirm).
pub stdout: String,
/// Text written to standard error (presumably captured by the executor — confirm).
pub stderr: String,
/// Exit code, if any — NOTE(review): presumably `None` for commands that have no process exit status; confirm.
pub exit_code: Option<i32>,
/// Execution duration (milliseconds, per the field name).
pub execution_time_ms: u64,
/// When the command was executed.
pub executed_at: Timestamp,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_id_creation() {
        let id1 = SessionId::new();
        let id2 = SessionId::new();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_session_id_string_round_trip() {
        let id = SessionId::new();
        let parsed = SessionId::from_string(&id.to_string())
            .expect("Display output of a SessionId should parse back");
        assert_eq!(parsed, id);
        assert!(SessionId::from_string("not-a-ulid").is_err());
    }

    #[test]
    fn test_budget_status_exhaustion() {
        let mut budget = BudgetStatus::default();
        assert!(!budget.is_exhausted());
        budget.tokens_used = budget.token_budget;
        assert!(budget.tokens_exhausted());
        assert!(budget.is_exhausted());
        assert_eq!(budget.tokens_remaining(), 0);
    }

    #[test]
    fn test_session_info_expiry() {
        let id = SessionId::new();
        // A generous duration keeps the "not yet expired" check from being
        // timing-sensitive.
        let info = SessionInfo::new(id, 60);
        assert!(!info.is_expired());
        assert_eq!(info.state, SessionState::Initializing);
        assert_eq!(
            info.expires_at.as_second() - info.created_at.as_second(),
            60
        );

        // A zero-length session expires exactly at its creation instant.
        let expired_info = SessionInfo::new(id, 0);
        assert_eq!(expired_info.expires_at, expired_info.created_at);
        assert!(!expired_info.can_extend(0));
        assert!(expired_info.can_extend(1));
    }

    #[test]
    fn test_query_metadata_child() {
        let session_id = SessionId::new();
        let parent = QueryMetadata::new(session_id);
        assert!(parent.parent_query_id.is_none());
        let child = parent.child();
        assert_eq!(child.parent_query_id, Some(parent.query_id));
        assert_eq!(child.depth, parent.depth + 1);
        assert_eq!(child.session_id, parent.session_id);
    }

    #[test]
    fn test_command_history() {
        let mut history = CommandHistory::new();
        assert!(history.last_successful_index().is_none());
        history.push(CommandHistoryEntry {
            command: Command::Code(PythonCode::new("x = 1")),
            success: true,
            stdout: "".to_string(),
            stderr: "".to_string(),
            exit_code: Some(0),
            execution_time_ms: 100,
            executed_at: Timestamp::now(),
        });
        assert_eq!(history.last_successful_index(), Some(0));
        assert_eq!(history.since_checkpoint(0).len(), 1);
        assert!(history.since_checkpoint(1).is_empty());
    }
}