use crate::Provider;
use crate::types::{Message, SessionId, Usage};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::sync::Arc;
use std::time::SystemTime;
/// Serialization format version stamped on every newly created or forked
/// [`Session`], and assumed for serialized sessions that omit `version`.
pub const SESSION_VERSION: u32 = 1;
/// A conversation: an ordered message history plus timestamps, free-form
/// JSON metadata, and accumulated token usage.
///
/// `Serialize`/`Deserialize` are implemented by hand (not derived) so the
/// `Arc`-shared message list is converted to and from a plain list.
#[derive(Debug, Clone)]
pub struct Session {
    /// Format version; new sessions and forks use [`SESSION_VERSION`].
    version: u32,
    /// Session identifier; `fork`/`fork_at` generate a fresh one.
    id: SessionId,
    /// Message history. Kept behind an `Arc` so `fork` can share it; the
    /// mutating methods go through `Arc::make_mut`, which clones the vector
    /// only if it is currently shared.
    pub(crate) messages: Arc<Vec<Message>>,
    /// Time the session (or fork) was created.
    created_at: SystemTime,
    /// Time of the last change made through a timestamp-updating method
    /// (e.g. `push`, `record_usage`, `set_metadata`, `touch`).
    updated_at: SystemTime,
    /// Free-form JSON metadata; the typed [`SessionMetadata`] is stored
    /// under [`SESSION_METADATA_KEY`].
    metadata: serde_json::Map<String, serde_json::Value>,
    /// Token usage accumulated via `record_usage`.
    usage: Usage,
}
/// Flat, owned wire representation of a [`Session`].
///
/// A missing `version` defaults to [`SESSION_VERSION`]; missing `metadata`
/// and `usage` fall back to their `Default` values, so payloads that lack
/// these fields still deserialize. `rename_all = "snake_case"` is a no-op for
/// the current field names, which are already snake_case.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
struct SessionSerde {
    #[serde(default = "default_version")]
    version: u32,
    id: SessionId,
    messages: Vec<Message>,
    created_at: SystemTime,
    updated_at: SystemTime,
    #[serde(default)]
    metadata: serde_json::Map<String, serde_json::Value>,
    #[serde(default)]
    usage: Usage,
}
impl Serialize for Session {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let serde_repr = SessionSerde {
version: self.version,
id: self.id.clone(),
messages: (*self.messages).clone(),
created_at: self.created_at,
updated_at: self.updated_at,
metadata: self.metadata.clone(),
usage: self.usage.clone(),
};
serde_repr.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for Session {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let serde_repr = SessionSerde::deserialize(deserializer)?;
Ok(Session {
version: serde_repr.version,
id: serde_repr.id,
messages: Arc::new(serde_repr.messages),
created_at: serde_repr.created_at,
updated_at: serde_repr.updated_at,
metadata: serde_repr.metadata,
usage: serde_repr.usage,
})
}
}
/// Serde default for `SessionSerde::version`: input that omits the field is
/// treated as the current [`SESSION_VERSION`].
const fn default_version() -> u32 {
    SESSION_VERSION
}
impl Session {
pub fn new() -> Self {
let now = SystemTime::now();
Self {
version: SESSION_VERSION,
id: SessionId::new(),
messages: Arc::new(Vec::new()),
created_at: now,
updated_at: now,
metadata: serde_json::Map::new(),
usage: Usage::default(),
}
}
pub fn with_id(id: SessionId) -> Self {
let mut session = Self::new();
session.id = id;
session
}
pub fn id(&self) -> &SessionId {
&self.id
}
pub fn version(&self) -> u32 {
self.version
}
pub fn messages(&self) -> &[Message] {
&self.messages
}
pub fn messages_mut(&mut self) -> &mut Vec<Message> {
Arc::make_mut(&mut self.messages)
}
pub fn created_at(&self) -> SystemTime {
self.created_at
}
pub fn updated_at(&self) -> SystemTime {
self.updated_at
}
pub fn push(&mut self, message: Message) {
Arc::make_mut(&mut self.messages).push(message);
self.updated_at = SystemTime::now();
}
pub fn push_batch(&mut self, messages: Vec<Message>) {
if messages.is_empty() {
return;
}
let inner = Arc::make_mut(&mut self.messages);
inner.extend(messages);
self.updated_at = SystemTime::now();
}
pub fn touch(&mut self) {
self.updated_at = SystemTime::now();
}
pub fn last_n(&self, n: usize) -> &[Message] {
let start = self.messages.len().saturating_sub(n);
&self.messages[start..]
}
pub fn total_tokens(&self) -> u64 {
self.usage.total_tokens()
}
pub fn total_usage(&self) -> Usage {
self.usage.clone()
}
pub fn record_usage(&mut self, turn_usage: Usage) {
self.usage.add(&turn_usage);
self.updated_at = SystemTime::now();
}
pub fn set_system_prompt(&mut self, prompt: String) {
use crate::types::SystemMessage;
let inner = Arc::make_mut(&mut self.messages);
if let Some(Message::System(_)) = inner.first() {
inner[0] = Message::System(SystemMessage { content: prompt });
} else {
inner.insert(0, Message::System(SystemMessage { content: prompt }));
}
self.updated_at = SystemTime::now();
}
pub fn last_assistant_text(&self) -> Option<String> {
self.messages.iter().rev().find_map(|m| match m {
Message::BlockAssistant(a) => {
let mut buf = String::new();
for block in &a.blocks {
if let crate::types::AssistantBlock::Text { text, .. } = block {
buf.push_str(text);
}
}
if buf.is_empty() { None } else { Some(buf) }
}
Message::Assistant(a) if !a.content.is_empty() => Some(a.content.clone()),
_ => None,
})
}
pub fn tool_call_count(&self) -> usize {
self.messages
.iter()
.filter_map(|m| match m {
Message::BlockAssistant(a) => Some(
a.blocks
.iter()
.filter(|b| matches!(b, crate::types::AssistantBlock::ToolUse { .. }))
.count(),
),
Message::Assistant(a) => Some(a.tool_calls.len()),
_ => None,
})
.sum()
}
pub fn metadata(&self) -> &serde_json::Map<String, serde_json::Value> {
&self.metadata
}
pub fn set_metadata(&mut self, key: &str, value: serde_json::Value) {
self.metadata.insert(key.to_string(), value);
self.updated_at = SystemTime::now();
}
pub fn set_session_metadata(
&mut self,
metadata: SessionMetadata,
) -> Result<(), serde_json::Error> {
let value = serde_json::to_value(metadata)?;
self.set_metadata(SESSION_METADATA_KEY, value);
Ok(())
}
pub fn session_metadata(&self) -> Option<SessionMetadata> {
self.metadata
.get(SESSION_METADATA_KEY)
.and_then(|value| serde_json::from_value(value.clone()).ok())
}
pub fn fork_at(&self, index: usize) -> Self {
let now = SystemTime::now();
let truncated = self.messages[..index.min(self.messages.len())].to_vec();
Self {
version: SESSION_VERSION,
id: SessionId::new(),
messages: Arc::new(truncated),
created_at: now,
updated_at: now,
metadata: self.metadata.clone(),
usage: self.usage.clone(),
}
}
pub fn fork(&self) -> Self {
let now = SystemTime::now();
Self {
version: SESSION_VERSION,
id: SessionId::new(),
messages: Arc::clone(&self.messages),
created_at: now,
updated_at: now,
metadata: self.metadata.clone(),
usage: self.usage.clone(),
}
}
}
impl Default for Session {
fn default() -> Self {
Self::new()
}
}
/// Lightweight summary of a [`Session`] (no message bodies), built via
/// `From<&Session>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct SessionMeta {
    /// Id of the summarized session.
    pub id: SessionId,
    /// Creation time of the summarized session.
    pub created_at: SystemTime,
    /// Last-update time of the summarized session.
    pub updated_at: SystemTime,
    /// Number of messages in the session's history.
    pub message_count: usize,
    /// Value of `Session::total_tokens` at summary time.
    pub total_tokens: u64,
    /// Copy of the session's free-form metadata; empty if absent on input.
    #[serde(default)]
    pub metadata: serde_json::Map<String, serde_json::Value>,
}
/// Typed configuration stored in a session's metadata map under
/// [`SESSION_METADATA_KEY`] (see `Session::set_session_metadata` /
/// `Session::session_metadata`).
///
/// NOTE(review): only `comms_name` (an `Option`) may be missing on input;
/// serde rejects stored values that lack any other field. Because
/// `Session::session_metadata` maps parse errors to `None`, such values are
/// silently ignored. Consider `#[serde(default)]` if older payloads must load.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct SessionMetadata {
    /// Model identifier (presumably the LLM model name — confirm).
    pub model: String,
    /// Token limit setting (assumed per-response output cap — confirm).
    pub max_tokens: u32,
    /// Provider the session was configured with.
    pub provider: Provider,
    /// Which tool families were enabled; see [`SessionTooling`].
    pub tooling: SessionTooling,
    /// Host-mode flag (semantics defined by callers — confirm).
    pub host_mode: bool,
    /// Optional comms name; `None` when absent on input.
    pub comms_name: Option<String>,
}
/// Metadata-map key under which [`SessionMetadata`] is stored as JSON.
pub const SESSION_METADATA_KEY: &str = "session_metadata";
/// Flags recording which tool families a session was configured with.
///
/// `Default` yields all flags `false` and no active skills. On input, the
/// four boolean flags are required; `active_skills` may be absent, and it is
/// omitted from output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub struct SessionTooling {
    /// Built-in tools enabled.
    pub builtins: bool,
    /// Shell tool enabled.
    pub shell: bool,
    /// Comms tools enabled.
    pub comms: bool,
    /// Sub-agent tools enabled.
    pub subagents: bool,
    /// Skills active for the session, if any were recorded.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub active_skills: Option<Vec<crate::skills::SkillId>>,
}
impl From<&Session> for SessionMeta {
fn from(session: &Session) -> Self {
Self {
id: session.id.clone(),
created_at: session.created_at,
updated_at: session.updated_at,
message_count: session.messages.len(),
total_tokens: session.total_tokens(),
metadata: session.metadata.clone(),
}
}
}
#[cfg(test)]
// Panicking is the intended failure mode in tests, so `unwrap`/`expect` are
// allowed here to keep assertions short.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::types::{AssistantMessage, StopReason, SystemMessage, UserMessage};
    use std::sync::Arc;

    /// Builds a user message with the given text.
    fn user(text: &str) -> Message {
        Message::User(UserMessage {
            content: text.to_string(),
        })
    }

    /// Builds a plain assistant message with no tool calls.
    fn assistant(text: &str) -> Message {
        Message::Assistant(AssistantMessage {
            content: text.to_string(),
            tool_calls: vec![],
            stop_reason: StopReason::EndTurn,
            usage: Usage::default(),
        })
    }

    /// Extracts the text of a user message, panicking on any other variant.
    fn user_text(message: &Message) -> &str {
        match message {
            Message::User(u) => &u.content,
            _ => panic!("expected a user message"),
        }
    }

    #[test]
    fn test_session_new() {
        let session = Session::new();
        assert_eq!(session.version(), SESSION_VERSION);
        assert!(session.messages().is_empty());
        assert!(session.created_at() <= session.updated_at());
    }

    #[test]
    fn test_with_id_uses_given_id() {
        let template = Session::new();
        let session = Session::with_id(template.id().clone());
        assert_eq!(session.id(), template.id());
        assert!(session.messages().is_empty());
    }

    #[test]
    fn test_fork_shares_arc_no_clone() {
        let mut session = Session::new();
        for i in 0..100 {
            session.push(user(&format!("Message {}", i)));
        }
        let forked = session.fork();
        assert!(Arc::ptr_eq(&session.messages, &forked.messages));
        assert_eq!(forked.messages().len(), 100);
    }

    #[test]
    fn test_fork_at_truncates_without_touching_parent() {
        let mut session = Session::new();
        for i in 0..100 {
            session.push(user(&format!("Message {}", i)));
        }
        let forked = session.fork_at(50);
        assert_eq!(forked.messages().len(), 50);
        assert_eq!(user_text(&forked.messages()[0]), "Message 0");
        assert_eq!(user_text(&forked.messages()[49]), "Message 49");
        assert_eq!(session.messages().len(), 100);
        assert_ne!(forked.id(), session.id());
    }

    #[test]
    fn test_fork_at_past_end_keeps_everything() {
        let mut session = Session::new();
        session.push(user("only"));
        let forked = session.fork_at(10);
        assert_eq!(forked.messages().len(), 1);
        assert_eq!(user_text(&forked.messages()[0]), "only");
    }

    #[test]
    fn test_push_cow_behavior() {
        let mut session = Session::new();
        session.push(user("First"));
        let forked = session.fork();
        assert!(Arc::ptr_eq(&session.messages, &forked.messages));
        session.push(user("Second"));
        assert!(!Arc::ptr_eq(&session.messages, &forked.messages));
        assert_eq!(session.messages().len(), 2);
        assert_eq!(forked.messages().len(), 1);
    }

    #[test]
    fn test_push_batch_appends_in_order() {
        let mut session = Session::new();
        let initial_updated = session.updated_at();
        session.push_batch(vec![user("First"), user("Second"), user("Third")]);
        assert_eq!(session.messages().len(), 3);
        assert_eq!(user_text(&session.messages()[0]), "First");
        assert_eq!(user_text(&session.messages()[2]), "Third");
        assert!(session.updated_at() >= initial_updated);
    }

    #[test]
    fn test_push_batch_empty_is_noop() {
        let mut session = Session::new();
        let initial_updated = session.updated_at();
        session.push_batch(Vec::new());
        assert!(session.messages().is_empty());
        assert_eq!(session.updated_at(), initial_updated);
    }

    #[test]
    fn test_touch_updates_timestamp() {
        let mut session = Session::new();
        let initial = session.updated_at();
        std::thread::sleep(std::time::Duration::from_millis(10));
        session.touch();
        assert!(session.updated_at() > initial);
    }

    #[test]
    fn test_session_push() {
        let mut session = Session::new();
        let initial_updated = session.updated_at();
        std::thread::sleep(std::time::Duration::from_millis(10));
        session.push(user("Hello"));
        assert_eq!(session.messages().len(), 1);
        assert!(session.updated_at() > initial_updated);
    }

    #[test]
    fn test_last_n_clamps() {
        let mut session = Session::new();
        session.push_batch(vec![user("a"), user("b"), user("c")]);
        assert!(session.last_n(0).is_empty());
        let tail = session.last_n(2);
        assert_eq!(tail.len(), 2);
        assert_eq!(user_text(&tail[0]), "b");
        assert_eq!(session.last_n(10).len(), 3);
    }

    #[test]
    fn test_set_system_prompt_inserts_then_replaces() {
        let mut session = Session::new();
        session.push(user("Hello"));
        session.set_system_prompt("first".to_string());
        assert_eq!(session.messages().len(), 2);
        session.set_system_prompt("second".to_string());
        assert_eq!(session.messages().len(), 2);
        match &session.messages()[0] {
            Message::System(s) => assert_eq!(s.content, "second"),
            _ => panic!("expected a system message first"),
        }
        assert_eq!(user_text(&session.messages()[1]), "Hello");
    }

    #[test]
    fn test_last_assistant_text_skips_empty_and_non_assistant() {
        let mut session = Session::new();
        assert_eq!(session.last_assistant_text(), None);
        session.push(user("Hello"));
        session.push(assistant("Hi!"));
        session.push(assistant(""));
        session.push(user("Bye"));
        assert_eq!(session.last_assistant_text().as_deref(), Some("Hi!"));
        assert_eq!(session.tool_call_count(), 0);
    }

    #[test]
    fn test_session_fork() {
        let mut session = Session::new();
        session.push(Message::System(SystemMessage {
            content: "System prompt".to_string(),
        }));
        session.push(user("Hello"));
        session.push(assistant("Hi!"));
        let forked = session.fork_at(2);
        assert_eq!(forked.messages().len(), 2);
        assert_ne!(forked.id(), session.id());
        let full_fork = session.fork();
        assert_eq!(full_fork.messages().len(), 3);
    }

    #[test]
    fn test_session_metadata() {
        let mut session = Session::new();
        session.set_metadata("key", serde_json::json!("value"));
        assert_eq!(session.metadata().get("key").unwrap(), "value");
        session.set_metadata("key", serde_json::json!("other"));
        assert_eq!(session.metadata().get("key").unwrap(), "other");
        assert_eq!(session.metadata().len(), 1);
    }

    #[test]
    fn test_session_serialization() {
        let mut session = Session::new();
        session.push(user("Test"));
        session.set_metadata("key", serde_json::json!(1));
        let json = serde_json::to_string(&session).unwrap();
        let parsed: Session = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id(), session.id());
        assert_eq!(parsed.messages().len(), 1);
        assert_eq!(user_text(&parsed.messages()[0]), "Test");
        assert_eq!(parsed.version(), SESSION_VERSION);
        assert_eq!(parsed.created_at(), session.created_at());
        assert_eq!(parsed.updated_at(), session.updated_at());
        assert_eq!(parsed.metadata(), session.metadata());
    }

    #[test]
    fn test_deserialize_defaults_missing_version() {
        let session = Session::new();
        let mut value = serde_json::to_value(&session).unwrap();
        value.as_object_mut().unwrap().remove("version");
        let parsed: Session = serde_json::from_value(value).unwrap();
        assert_eq!(parsed.version(), SESSION_VERSION);
        assert_eq!(parsed.id(), session.id());
    }

    #[test]
    fn test_session_meta_from_session() {
        let mut session = Session::new();
        session.push(user("Hello"));
        session.push(Message::Assistant(AssistantMessage {
            content: "Hi!".to_string(),
            tool_calls: vec![],
            stop_reason: StopReason::EndTurn,
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
                cache_creation_tokens: None,
                cache_read_tokens: None,
            },
        }));
        session.record_usage(Usage {
            input_tokens: 10,
            output_tokens: 5,
            cache_creation_tokens: None,
            cache_read_tokens: None,
        });
        let meta = SessionMeta::from(&session);
        assert_eq!(meta.id, *session.id());
        assert_eq!(meta.message_count, 2);
        assert_eq!(meta.total_tokens, 15);
    }
}