pub mod compact;
pub mod rehydrate;
pub mod storage;
use std::collections::HashMap;
use compact_str::CompactString;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Author of a [`SessionMessage`].
///
/// Serialized in `snake_case`: `"user"`, `"assistant"`, `"system"`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
    User,
    Assistant,
    System,
}
/// Outcome recorded for a tool call attached to a message.
///
/// Internally tagged on `"kind"` with `snake_case` variant names, e.g.
/// `{"kind":"completed","result":"..."}`, `{"kind":"interrupted"}`,
/// `{"kind":"failed","error":"..."}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ToolCallState {
    /// The call finished; `result` is its textual output.
    Completed { result: String },
    /// The call did not finish and has no result text.
    Interrupted,
    /// The call failed; `error` is the error text.
    Failed { error: String },
}
/// One tool invocation stored alongside a [`SessionMessage`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallEntry {
    /// Identifier of the call (presumably provider-assigned; confirm at the call site).
    pub id: String,
    /// Name of the invoked tool.
    pub name: String,
    /// Arguments as arbitrary JSON; token estimation prices their serialized form.
    pub args: serde_json::Value,
    /// How the call ended.
    pub state: ToolCallState,
}
/// A single message in a [`Session`], together with any tool calls it carried.
///
/// `id`, `timestamp` and `tool_calls` are optional on disk: a missing `id` gets a fresh
/// random id, a missing `timestamp` becomes `0`, and missing `tool_calls` become empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionMessage {
    /// Who produced the message.
    pub role: MessageRole,
    /// Message text.
    pub content: CompactString,
    /// Cached result of `Session::estimate_message_tokens` for this message.
    pub estimated_tokens: u64,
    /// Stable id; also the key used in `Session::message_store` and `SessionTree::entries`.
    #[serde(default = "new_message_id")]
    pub id: CompactString,
    /// Creation time in seconds since the Unix epoch (`chrono::Utc::now().timestamp()`).
    #[serde(default)]
    pub timestamp: i64,
    /// Tool calls attached to this message.
    #[serde(default)]
    pub tool_calls: Vec<ToolCallEntry>,
}
/// Generates a fresh message id: a random UUID v4 rendered in its canonical
/// hyphenated, lowercase form.
///
/// Also used as the serde default for [`SessionMessage::id`], so messages loaded from
/// files without an id still receive a unique one.
pub(crate) fn new_message_id() -> CompactString {
    let uuid = Uuid::new_v4();
    CompactString::from(uuid.to_string())
}
/// Record of one context compaction applied to a [`Session`].
///
/// Produced and interpreted by the `compact` module; the per-field notes below are
/// inferred from field names and should be confirmed against `compact`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Compaction {
    /// Summary text standing in for the summarized messages.
    pub summary: CompactString,
    /// Presumably the index into `Session::messages` of the first message kept verbatim.
    pub first_kept_index: usize,
    /// Presumably how many messages the summary replaced.
    pub summarized_count: usize,
    /// Estimated tokens saved by this compaction (as reported by the caller).
    pub token_savings: u64,
    /// Creation time (format set by `compact`; presumably RFC 3339 like other timestamps).
    pub created_at: CompactString,
}
/// A node in the session's branch tree; one per message id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeNode {
    /// Id of the message this node represents (key into `Session::message_store`).
    pub id: CompactString,
    /// Id of the preceding message, or `None` for a root.
    pub parent: Option<CompactString>,
    /// Copied from the message's timestamp (seconds since the Unix epoch).
    pub timestamp: i64,
    /// Optional user-visible label set via `Session::set_label`.
    #[serde(default)]
    pub label: Option<String>,
}
/// Parent-linked tree of every message on every branch of a session.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionTree {
    /// All nodes, keyed by message id.
    #[serde(default)]
    pub entries: HashMap<CompactString, TreeNode>,
    /// Tip of the active branch; `None` when the active branch is empty.
    #[serde(default)]
    pub leaf_id: Option<CompactString>,
}
/// Summary of a conversation branch.
///
/// NOTE(review): not constructed in this file; the field notes below are inferred from
/// names and should be confirmed against the producer.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct BranchSummary {
    /// Presumably the id of the branch's first message.
    pub root_id: CompactString,
    /// Presumably the id of the message the branch diverges from.
    pub parent_id: CompactString,
    /// Number of messages on the branch.
    pub message_count: usize,
    /// Short human-readable preview of the branch.
    pub preview: String,
    /// Creation time (string form; format set by the producer).
    pub created_at: String,
}
/// A session-scoped permission allow-list entry pairing a tool with a pattern.
///
/// NOTE(review): how `pattern` is matched (glob, regex, prefix, ...) is defined outside
/// this file; confirm before relying on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PermissionAllowEntry {
    /// Tool name the entry applies to.
    pub tool: String,
    /// Pattern the tool's input is matched against.
    pub pattern: String,
}
/// Opaque plugin-owned record appended to a session via `Session::append_plugin_entry`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginEntry {
    /// Plugin-defined type tag.
    pub custom_type: String,
    /// Plugin-defined payload (opaque to the session).
    pub data: String,
    /// Display flag supplied by the plugin (presumably whether to show the entry in the UI).
    pub display: bool,
    /// Append time in seconds since the Unix epoch.
    pub timestamp: i64,
    /// Sequence number taken from `Session::next_entry_seq` at append time.
    pub seq: u64,
}
/// Schema version stamped on sessions created by [`Session::new`].
///
/// Files without a `schema_version` field deserialize it as `0` (`#[serde(default)]`).
pub const SCHEMA_VERSION: u32 = 2;
/// A conversation session: the active branch of messages, the full branch tree, token and
/// cost accounting, and assorted per-session state.
///
/// Fields marked `#[serde(default)]` may be absent in older files;
/// `Session::ensure_back_compat_initialized` rebuilds the message store, tree and plugin
/// sequence counter for such sessions. Fields marked `#[serde(skip)]` are runtime-only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Schema version the session carries; `0` when absent from the file.
    #[serde(default)]
    pub schema_version: u32,
    /// Session id; `Session::new` uses a random UUID v4.
    pub id: CompactString,
    /// Optional originating session id; `Session::effective_origin` falls back to `id`.
    #[serde(default)]
    pub origin_id: Option<CompactString>,
    /// Display name; empty for new sessions, `parent:<id>` after `reset_to_new(Some(id))`.
    pub name: CompactString,
    /// Messages on the active branch, root first.
    pub messages: Vec<SessionMessage>,
    /// Compactions applied to this session (interpreted by the `compact` module).
    pub compactions: Vec<Compaction>,
    /// Creation time, RFC 3339.
    pub created_at: CompactString,
    /// Last modification time, RFC 3339.
    pub updated_at: CompactString,
    /// Token total maintained outside this file (presumably provider-reported).
    pub total_tokens: u64,
    /// Cost total maintained outside this file.
    pub total_cost: f64,
    /// Sum of `estimated_tokens` over `messages`.
    pub total_estimated_tokens: u64,
    /// Input tokens accumulated by `Session::record_token_usage`.
    #[serde(default)]
    pub cumulative_input_tokens: u64,
    /// Cached input tokens accumulated by `Session::record_token_usage`.
    #[serde(default)]
    pub cumulative_cached_input_tokens: u64,
    /// Cache-creation tokens accumulated by `Session::record_token_usage`.
    #[serde(default)]
    pub cumulative_cache_creation_tokens: u64,
    /// Model context window in tokens; consulted by `Session::needs_compaction`.
    pub context_window: u64,
    /// Model name as passed to `Session::new`.
    pub model: CompactString,
    /// Provider name as passed to `Session::new`.
    pub provider: CompactString,
    /// Working directory at creation (lossy UTF-8; empty if it could not be read).
    pub working_dir: CompactString,
    /// Session-scoped permission allow-list.
    #[serde(default)]
    pub permission_allowlist: Vec<PermissionAllowEntry>,
    /// Plugin entries appended via `Session::append_plugin_entry`.
    #[serde(default)]
    pub extra_entries: Vec<PluginEntry>,
    /// Sequence number the next plugin entry will receive.
    #[serde(default)]
    pub next_entry_seq: u64,
    /// Branch tree covering every message on every branch.
    #[serde(default)]
    pub tree: SessionTree,
    /// Message content for every tree node, keyed by message id.
    #[serde(default)]
    pub message_store: HashMap<CompactString, SessionMessage>,
    /// Branch summaries (produced outside this file).
    #[serde(default)]
    pub branch_summaries: Vec<BranchSummary>,
    /// Name of the currently selected prompt, if any (managed outside this file).
    #[serde(default)]
    pub current_prompt_name: Option<String>,
    /// Todo items associated with this session.
    #[serde(default)]
    pub todo_list: Vec<crate::agent::tools::todo::TodoItem>,
    /// Presumably paths of files modified during the session; maintained elsewhere.
    #[serde(default)]
    pub modified_files: Vec<String>,
    /// Runtime-only; presumably the backing file's mtime when loaded (set elsewhere).
    #[serde(skip)]
    pub loaded_mtime: Option<std::time::SystemTime>,
    /// Runtime-only; presumably the newer schema version seen on load (set elsewhere).
    #[serde(skip)]
    pub loaded_from_newer_version: Option<u64>,
}
impl Session {
    /// Rough token estimate for `text`: one token per four bytes, never less than one.
    ///
    /// Counts UTF-8 bytes (`str::len`), not characters, and does not consult a real
    /// tokenizer. An empty string still estimates to `1`.
    pub fn estimate_tokens(text: &str) -> u64 {
        (text.len() as u64 / 4).max(1)
    }

    /// Estimated token cost of `msg`: its content plus every attached tool call.
    ///
    /// Each tool call is priced as its serialized `args` + its `name` + a fixed 16, plus
    /// the result text (completed), the error text (failed), or a flat 8 (interrupted).
    /// All additions saturate at `u64::MAX` instead of overflowing.
    pub fn estimate_message_tokens(msg: &SessionMessage) -> u64 {
        msg.tool_calls
            .iter()
            .fold(Self::estimate_tokens(&msg.content), |total, tc| {
                total.saturating_add(Self::estimate_tool_call_tokens(tc))
            })
    }

    /// Estimated token cost of a single tool call (arguments, name, fixed overhead, outcome).
    fn estimate_tool_call_tokens(tc: &ToolCallEntry) -> u64 {
        let outcome = match &tc.state {
            ToolCallState::Completed { result } => Self::estimate_tokens(result),
            ToolCallState::Failed { error } => Self::estimate_tokens(error),
            ToolCallState::Interrupted => 8,
        };
        Self::estimate_tokens(&tc.args.to_string())
            .saturating_add(Self::estimate_tokens(&tc.name))
            .saturating_add(16)
            .saturating_add(outcome)
    }

    /// Recomputes `estimated_tokens` for every message and `total_estimated_tokens`.
    ///
    /// Active-branch messages are re-estimated first. Store entries whose id appears on
    /// the active branch copy that estimate (first occurrence wins); all other store
    /// entries are re-estimated from their own content.
    pub fn recompute_all_estimates(&mut self) {
        for msg in self.messages.iter_mut() {
            msg.estimated_tokens = Self::estimate_message_tokens(msg);
        }
        // Index the active branch once so syncing the store is O(n + m) rather than a
        // linear scan of `messages` per store entry. `or_insert` keeps the first
        // occurrence of a duplicated id, matching the previous `find` semantics.
        let mut canonical: HashMap<&CompactString, u64> =
            HashMap::with_capacity(self.messages.len());
        for m in &self.messages {
            canonical.entry(&m.id).or_insert(m.estimated_tokens);
        }
        for (id, m) in self.message_store.iter_mut() {
            m.estimated_tokens = match canonical.get(id) {
                Some(&tokens) => tokens,
                None => Self::estimate_message_tokens(m),
            };
        }
        self.total_estimated_tokens = self.messages.iter().map(|m| m.estimated_tokens).sum();
    }

    /// Creates an empty session for `provider`/`model` with the given context window.
    ///
    /// The id is a random UUID v4, `created_at`/`updated_at` are set to now (RFC 3339),
    /// and `working_dir` is the process's current directory (lossy UTF-8; empty if it
    /// cannot be read).
    pub fn new(provider: &str, model: &str, context_window: u64) -> Self {
        let now = CompactString::new(chrono::Utc::now().to_rfc3339());
        Session {
            schema_version: SCHEMA_VERSION,
            id: CompactString::new(Uuid::new_v4().to_string()),
            origin_id: None,
            name: CompactString::new(""),
            messages: Vec::new(),
            compactions: Vec::new(),
            created_at: now.clone(),
            updated_at: now,
            total_tokens: 0,
            total_cost: 0.0,
            total_estimated_tokens: 0,
            cumulative_input_tokens: 0,
            cumulative_cached_input_tokens: 0,
            cumulative_cache_creation_tokens: 0,
            context_window,
            model: CompactString::new(model),
            provider: CompactString::new(provider),
            working_dir: std::env::current_dir()
                .map(|p| CompactString::new(p.to_string_lossy()))
                .unwrap_or_default(),
            permission_allowlist: Vec::new(),
            extra_entries: Vec::new(),
            next_entry_seq: 0,
            tree: SessionTree::default(),
            message_store: HashMap::new(),
            branch_summaries: Vec::new(),
            current_prompt_name: None,
            todo_list: Vec::new(),
            modified_files: Vec::new(),
            loaded_mtime: None,
            loaded_from_newer_version: None,
        }
    }

    /// Adds the given token counts to the cumulative usage counters (saturating).
    pub fn record_token_usage(
        &mut self,
        input_tokens: u64,
        cached_input_tokens: u64,
        cache_creation_input_tokens: u64,
    ) {
        self.cumulative_input_tokens = self.cumulative_input_tokens.saturating_add(input_tokens);
        self.cumulative_cached_input_tokens = self
            .cumulative_cached_input_tokens
            .saturating_add(cached_input_tokens);
        self.cumulative_cache_creation_tokens = self
            .cumulative_cache_creation_tokens
            .saturating_add(cache_creation_input_tokens);
    }

    /// Ratio of cumulative cached input tokens to cumulative input tokens.
    ///
    /// Returns `None` while no input tokens have been recorded.
    pub fn cache_hit_ratio(&self) -> Option<f64> {
        if self.cumulative_input_tokens == 0 {
            return None;
        }
        Some(self.cumulative_cached_input_tokens as f64 / self.cumulative_input_tokens as f64)
    }

    /// Fills `message_store` from `messages` when the store is empty.
    ///
    /// A non-empty store is left untouched, even if it is missing some messages.
    pub fn ensure_message_store_initialized(&mut self) {
        if !self.message_store.is_empty() {
            return;
        }
        for msg in &self.messages {
            self.message_store.insert(msg.id.clone(), msg.clone());
        }
    }

    /// Builds a linear tree from `messages` when the tree is empty and messages exist.
    ///
    /// Each message's parent is the one before it. The first message is a root, and the
    /// last message becomes the leaf.
    pub fn ensure_tree_initialized(&mut self) {
        if !self.tree.entries.is_empty() || self.messages.is_empty() {
            return;
        }
        let mut prev: Option<CompactString> = None;
        for msg in &self.messages {
            let node = TreeNode {
                id: msg.id.clone(),
                parent: prev.clone(),
                timestamp: msg.timestamp,
                label: None,
            };
            prev = Some(msg.id.clone());
            self.tree.entries.insert(msg.id.clone(), node);
        }
        self.tree.leaf_id = prev;
    }

    /// Runs every back-compat initializer: message store, tree, and plugin sequence counter.
    ///
    /// Each initializer is idempotent, so this is safe to call before any mutation.
    pub fn ensure_back_compat_initialized(&mut self) {
        self.ensure_message_store_initialized();
        self.ensure_tree_initialized();
        self.ensure_next_entry_seq_initialized();
    }

    /// Raises `next_entry_seq` so it is at least `max(seq) + 1` and at least
    /// `extra_entries.len()`.
    ///
    /// This guards against reusing a sequence number when the counter was absent on load
    /// and defaulted to `0`. It never lowers the counter.
    fn ensure_next_entry_seq_initialized(&mut self) {
        if self.extra_entries.is_empty() {
            return;
        }
        let max_seq = self.extra_entries.iter().map(|e| e.seq).max().unwrap_or(0);
        let needed = max_seq
            .saturating_add(1)
            .max(self.extra_entries.len() as u64);
        if self.next_entry_seq < needed {
            self.next_entry_seq = needed;
        }
    }

    /// Appends a plugin entry stamped with the current time and the next sequence number,
    /// and returns a reference to it.
    #[cfg_attr(not(feature = "plugin"), allow(dead_code))]
    pub fn append_plugin_entry(
        &mut self,
        custom_type: impl Into<String>,
        data: impl Into<String>,
        display: bool,
    ) -> &PluginEntry {
        // A session loaded with entries but a defaulted (0) counter would otherwise hand
        // out a `seq` that an existing entry already uses.
        self.ensure_next_entry_seq_initialized();
        let entry = PluginEntry {
            custom_type: custom_type.into(),
            data: data.into(),
            display,
            timestamp: chrono::Utc::now().timestamp(),
            seq: self.next_entry_seq,
        };
        self.next_entry_seq = self.next_entry_seq.saturating_add(1);
        self.extra_entries.push(entry);
        self.extra_entries.last().expect("just pushed")
    }

    /// Appends a message without tool calls; see [`Session::add_message_with_tool_calls`].
    pub fn add_message(&mut self, role: MessageRole, content: &str) {
        self.add_message_with_tool_calls(role, content, Vec::new());
    }

    /// Returns `origin_id` if set, otherwise this session's own `id`.
    pub fn effective_origin(&self) -> &str {
        self.origin_id.as_deref().unwrap_or(&self.id)
    }

    /// Content of the first `User` message on the active branch, if any.
    pub fn first_user_prompt(&self) -> Option<&str> {
        self.messages
            .iter()
            .find(|m| m.role == MessageRole::User)
            .map(|m| m.content.as_str())
    }

    /// Appends a message (with tool calls) to the active branch.
    ///
    /// The message gets a fresh id, is stored in `message_store`, and is linked into the
    /// tree as a child of the current leaf. It then becomes the new leaf.
    /// `total_estimated_tokens` and `updated_at` are updated to match.
    pub fn add_message_with_tool_calls(
        &mut self,
        role: MessageRole,
        content: &str,
        tool_calls: Vec<ToolCallEntry>,
    ) {
        self.ensure_back_compat_initialized();
        let id = new_message_id();
        let timestamp = chrono::Utc::now().timestamp();
        let parent = self.tree.leaf_id.clone();
        let mut msg = SessionMessage {
            role,
            content: CompactString::new(content),
            estimated_tokens: 0,
            id: id.clone(),
            timestamp,
            tool_calls,
        };
        // Single source of truth for pricing, shared with `recompute_all_estimates`.
        let tokens = Self::estimate_message_tokens(&msg);
        msg.estimated_tokens = tokens;
        self.messages.push(msg.clone());
        self.message_store.insert(id.clone(), msg);
        self.tree.entries.insert(
            id.clone(),
            TreeNode {
                id: id.clone(),
                parent,
                timestamp,
                label: None,
            },
        );
        self.tree.leaf_id = Some(id);
        self.total_estimated_tokens = self.total_estimated_tokens.saturating_add(tokens);
        self.updated_at = CompactString::new(chrono::Utc::now().to_rfc3339());
    }

    /// Removes and returns the last message on the active branch.
    ///
    /// Returns `None` when there are no messages, or when the only remaining message is a
    /// `System` message (which is kept). The leaf moves to the popped node's parent. If the
    /// node is not in the tree, the new last message becomes the leaf instead. The node
    /// and its stored content are deleted only when no other node names it as parent, so
    /// sibling branches remain reachable.
    pub fn pop_last_message(&mut self) -> Option<SessionMessage> {
        self.ensure_back_compat_initialized();
        if let Some(last) = self.messages.last()
            && self.messages.len() == 1
            && last.role == MessageRole::System
        {
            return None;
        }
        let msg = self.messages.pop()?;
        let parent = match self.tree.entries.get(&msg.id) {
            Some(node) => node.parent.clone(),
            None => self.messages.last().map(|m| m.id.clone()),
        };
        self.tree.leaf_id = parent;
        let still_referenced = self
            .tree
            .entries
            .values()
            .any(|n| n.parent.as_ref() == Some(&msg.id));
        if !still_referenced {
            self.tree.entries.remove(&msg.id);
            self.message_store.remove(&msg.id);
        }
        self.total_estimated_tokens = self
            .total_estimated_tokens
            .saturating_sub(msg.estimated_tokens);
        self.updated_at = CompactString::new(chrono::Utc::now().to_rfc3339());
        Some(msg)
    }

    /// Makes `new_leaf_id` the active leaf and rebuilds `messages` from the root down to it.
    ///
    /// # Errors
    ///
    /// Returns an error if the id is unknown, if walking parent links finds a cycle or
    /// exceeds 10 000 hops, if a parent link points at a missing node, or if any node on
    /// the path has no stored content. On error, `messages`, the totals and the leaf are
    /// left unchanged (back-compat initialization may still have run).
    pub fn switch_to_leaf(&mut self, new_leaf_id: &CompactString) -> Result<(), String> {
        self.ensure_back_compat_initialized();
        if !self.tree.entries.contains_key(new_leaf_id) {
            return Err(format!("unknown entry id: {}", new_leaf_id));
        }
        let mut chain: Vec<CompactString> = Vec::new();
        let mut cursor: Option<CompactString> = Some(new_leaf_id.clone());
        let mut visited = std::collections::HashSet::new();
        let mut hops = 0usize;
        const MAX_HOPS: usize = 10_000;
        while let Some(id) = cursor {
            if hops >= MAX_HOPS {
                return Err(format!(
                    "cycle or excessive depth in session tree (>{} hops from leaf {})",
                    MAX_HOPS, new_leaf_id
                ));
            }
            hops += 1;
            if !visited.insert(id.clone()) {
                return Err(format!("cycle detected in session tree at node {}", id));
            }
            let node = self
                .tree
                .entries
                .get(&id)
                .ok_or_else(|| format!("broken parent link to missing node {}", id))?;
            cursor = node.parent.clone();
            chain.push(id);
        }
        chain.reverse();
        // Resolve all content (root first) before mutating anything, so a missing entry
        // leaves the active branch untouched.
        let new_messages = chain
            .iter()
            .map(|id| {
                self.message_store
                    .get(id)
                    .cloned()
                    .ok_or_else(|| format!("missing content for node {}", id))
            })
            .collect::<Result<Vec<SessionMessage>, String>>()?;
        let new_total: u64 = new_messages.iter().map(|m| m.estimated_tokens).sum();
        self.messages = new_messages;
        self.total_estimated_tokens = new_total;
        self.tree.leaf_id = Some(new_leaf_id.clone());
        self.updated_at = CompactString::new(chrono::Utc::now().to_rfc3339());
        Ok(())
    }

    /// Rewinds the active branch to just before `entry_id` and returns a clone of that
    /// entry's message.
    ///
    /// If the entry has a parent, the parent becomes the leaf (via
    /// [`Session::switch_to_leaf`]). If it is a root, the active branch is emptied. In both
    /// cases the tree and `message_store` keep the entry.
    ///
    /// # Errors
    ///
    /// Unknown id, missing stored content, or any error from [`Session::switch_to_leaf`].
    pub fn fork_at(&mut self, entry_id: &CompactString) -> Result<SessionMessage, String> {
        self.ensure_back_compat_initialized();
        let node = self
            .tree
            .entries
            .get(entry_id)
            .ok_or_else(|| format!("unknown entry id: {}", entry_id))?;
        let parent = node.parent.clone();
        let original = self
            .message_store
            .get(entry_id)
            .cloned()
            .ok_or_else(|| format!("missing content for entry {}", entry_id))?;
        match parent {
            Some(parent_id) => {
                self.switch_to_leaf(&parent_id)?;
            }
            None => {
                self.messages.clear();
                self.total_estimated_tokens = 0;
                self.tree.leaf_id = None;
                self.updated_at = CompactString::new(chrono::Utc::now().to_rfc3339());
            }
        }
        Ok(original)
    }

    /// Makes `entry_id` the active leaf, keeping it on the branch; equivalent to
    /// [`Session::switch_to_leaf`].
    pub fn clone_at(&mut self, entry_id: &CompactString) -> Result<(), String> {
        self.switch_to_leaf(entry_id)
    }

    /// Sets (`Some`) or clears (`None`) the label of tree node `entry_id`.
    ///
    /// # Errors
    ///
    /// Returns an error if `entry_id` is not in the tree.
    #[cfg_attr(not(feature = "plugin"), allow(dead_code))]
    pub fn set_label(
        &mut self,
        entry_id: &CompactString,
        label: Option<String>,
    ) -> Result<(), String> {
        self.ensure_back_compat_initialized();
        let node = self
            .tree
            .entries
            .get_mut(entry_id)
            .ok_or_else(|| format!("unknown entry id: {}", entry_id))?;
        node.label = label;
        self.updated_at = CompactString::new(chrono::Utc::now().to_rfc3339());
        Ok(())
    }

    /// Turns this session into a brand-new, empty one in place.
    ///
    /// Assigns a new random id and fresh timestamps, and names the session
    /// `parent:<parent_session>` when a parent is given (empty otherwise). Clears all
    /// conversation state, accounting, and runtime load metadata. `origin_id`, `model`,
    /// `provider`, `context_window` and `working_dir` are kept.
    #[cfg_attr(not(feature = "plugin"), allow(dead_code))]
    pub fn reset_to_new(&mut self, parent_session: Option<&str>) {
        let now = CompactString::new(chrono::Utc::now().to_rfc3339());
        // A fresh session is written with the current schema, as in `Session::new`.
        self.schema_version = SCHEMA_VERSION;
        self.id = CompactString::new(Uuid::new_v4().to_string());
        self.name = match parent_session {
            Some(parent) => CompactString::new(format!("parent:{}", parent)),
            None => CompactString::new(""),
        };
        self.messages.clear();
        self.compactions.clear();
        self.extra_entries.clear();
        self.next_entry_seq = 0;
        self.message_store.clear();
        self.tree = SessionTree::default();
        self.total_tokens = 0;
        self.total_cost = 0.0;
        self.total_estimated_tokens = 0;
        // Usage counters belong to the old conversation just like `total_tokens`; keeping
        // them would make `cache_hit_ratio` report the previous session's statistics.
        self.cumulative_input_tokens = 0;
        self.cumulative_cached_input_tokens = 0;
        self.cumulative_cache_creation_tokens = 0;
        self.created_at = now.clone();
        self.updated_at = now;
        self.permission_allowlist.clear();
        self.branch_summaries.clear();
        self.loaded_mtime = None;
        self.current_prompt_name = None;
        self.todo_list.clear();
        self.modified_files.clear();
        self.loaded_from_newer_version = None;
    }

    /// Delegates to `compact::needs_compaction` with this session's estimated token total
    /// and context window.
    pub fn needs_compaction(&self, reserve_tokens: u64) -> bool {
        compact::needs_compaction(
            self.total_estimated_tokens,
            self.context_window,
            reserve_tokens,
        )
    }

    /// Delegates to `compact::compacted_context` for this session's compactions and message
    /// count (presumably the active summary and the first kept index; see `compact`).
    pub fn compacted_context(&self) -> (Option<&str>, usize) {
        compact::compacted_context(&self.compactions, self.messages.len())
    }

    /// Test-only wrapper around `compact::compress`.
    #[cfg(test)]
    pub fn compress(&mut self, summary: String, first_kept_index: usize, token_savings: u64) {
        compact::compress(self, summary, first_kept_index, token_savings);
    }

    /// Delegates to `compact::compress_reporting` and returns its count (meaning defined
    /// by `compact`).
    pub fn compress_reporting(
        &mut self,
        summary: String,
        first_kept_index: usize,
        token_savings: u64,
    ) -> usize {
        compact::compress_reporting(self, summary, first_kept_index, token_savings)
    }
}
#[cfg(test)]
#[path = "mod_tests.rs"]
mod tests;