use crate::Provider;
use crate::peer_meta::PeerMeta;
use crate::service::{AppendSystemContextRequest, MobToolAuthorityContext};
use crate::time_compat::SystemTime;
use crate::types::{ContentInput, Message, SessionId, ToolDef, ToolResult, Usage};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::HashMap;
use std::sync::Arc;
/// Schema version stamped on sessions created by `Session::new`, `Session::fork` and
/// `Session::fork_at`; also the fallback when serialized input has no `version` field.
pub const SESSION_VERSION: u32 = 1;
/// A conversation session: message history plus timestamps, free-form metadata and
/// accumulated token usage.
///
/// Serialization goes through the hand-written `Serialize`/`Deserialize` impls, which
/// use `SessionSerde` as the wire layout.
#[derive(Debug, Clone)]
pub struct Session {
/// Schema version of this session (see `SESSION_VERSION`).
version: u32,
/// Identifier of this session.
id: SessionId,
/// Message history behind an `Arc`: `fork` clones the `Arc`, and mutating methods go
/// through `Arc::make_mut`.
pub(crate) messages: Arc<Vec<Message>>,
/// Time the session value was created (or forked).
created_at: SystemTime,
/// Time of the last mutation made through the timestamp-updating methods.
updated_at: SystemTime,
/// Free-form JSON metadata; the typed state objects in this file are stored here
/// under their `SESSION_*_KEY` constants.
metadata: serde_json::Map<String, serde_json::Value>,
/// Usage accumulated via `record_usage`.
usage: Usage,
}
/// Owned wire layout of `Session`, used by its `Serialize`/`Deserialize` impls.
///
/// On input, `version` falls back to `default_version()` and `metadata`/`usage` fall
/// back to their `Default` values when the field is absent.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
struct SessionSerde {
#[serde(default = "default_version")]
version: u32,
id: SessionId,
// Plain `Vec` here; `Session` wraps it in an `Arc` after deserialization.
messages: Vec<Message>,
created_at: SystemTime,
updated_at: SystemTime,
#[serde(default)]
metadata: serde_json::Map<String, serde_json::Value>,
#[serde(default)]
usage: Usage,
}
impl Serialize for Session {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let serde_repr = SessionSerde {
version: self.version,
id: self.id.clone(),
messages: (*self.messages).clone(),
created_at: self.created_at,
updated_at: self.updated_at,
metadata: self.metadata.clone(),
usage: self.usage.clone(),
};
serde_repr.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for Session {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let serde_repr = SessionSerde::deserialize(deserializer)?;
Ok(Session {
version: serde_repr.version,
id: serde_repr.id,
messages: Arc::new(serde_repr.messages),
created_at: serde_repr.created_at,
updated_at: serde_repr.updated_at,
metadata: serde_repr.metadata,
usage: serde_repr.usage,
})
}
}
/// Serde default for `SessionSerde::version`: input without a `version` field is
/// treated as `SESSION_VERSION`.
fn default_version() -> u32 {
SESSION_VERSION
}
/// Metadata key under which `Session::set_system_context_state` stores its state.
pub const SESSION_SYSTEM_CONTEXT_STATE_KEY: &str = "session_system_context_state";
/// Metadata key under which `Session::set_deferred_turn_state` stores its state.
pub const SESSION_DEFERRED_TURN_STATE_KEY: &str = "session_deferred_turn_state";
/// Metadata key under which `Session::set_build_state` stores its state.
pub const SESSION_BUILD_STATE_KEY: &str = "session_build_state";
/// Separator placed between the existing system prompt and appended context blocks,
/// and between consecutive context blocks.
pub const SYSTEM_CONTEXT_SEPARATOR: &str = "\n\n---\n\n";
/// Bookkeeping for runtime system-context appends: entries that have been staged, plus
/// a record of the idempotency keys seen so far.
///
/// Both collections are omitted from serialized output when empty.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub struct SessionSystemContextState {
/// Appends pushed by `stage_append`; cleared by `mark_pending_applied`.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub pending: Vec<PendingSystemContextAppend>,
/// Idempotency key mapped to what was staged under it and its current state.
#[serde(default, skip_serializing_if = "std::collections::BTreeMap::is_empty")]
pub seen: std::collections::BTreeMap<String, SeenSystemContextKey>,
}
/// One staged system-context append.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub struct PendingSystemContextAppend {
/// Context text; `stage_append` stores it with surrounding whitespace trimmed.
pub text: String,
/// Optional origin label; `render_system_context_block` emits it as a `source:` line.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub source: Option<String>,
/// Optional key used by `stage_append` to detect duplicate or conflicting requests.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub idempotency_key: Option<String>,
/// Timestamp passed to `stage_append` when the entry was staged.
pub accepted_at: SystemTime,
}
/// Deferred-turn bookkeeping: the phase of the first turn, plus an initial prompt and
/// tool results that have been staged for later consumption.
///
/// Every field is omitted from serialized output when it holds its empty/default value.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
#[serde(rename_all = "snake_case")]
pub struct SessionDeferredTurnState {
/// Phase of the deferred first turn; not serialized while `Inactive`.
#[serde(default, skip_serializing_if = "DeferredFirstTurnPhase::is_inactive")]
pub first_turn_phase: DeferredFirstTurnPhase,
/// Prompt set by `stage_initial_prompt` and removed by `take_initial_prompt`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub pending_initial_prompt: Option<PendingDeferredPrompt>,
/// Batches pushed by `stage_tool_results` and removed by `take_tool_results`.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub pending_tool_results: Vec<PendingToolResultsMessage>,
}
/// Phase of the deferred first turn tracked by `SessionDeferredTurnState`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DeferredFirstTurnPhase {
/// Default phase; a state in this phase is skipped during serialization.
#[default]
Inactive,
/// Set by `mark_initial_turn_pending` and `restore_initial_turn_pending`.
Pending,
/// Set by `mark_initial_turn_started` when the phase was `Pending`.
Consumed,
}
impl DeferredFirstTurnPhase {
    /// Returns `true` only for [`DeferredFirstTurnPhase::Inactive`].
    ///
    /// Takes `&self` because it is used as a serde `skip_serializing_if` predicate.
    pub fn is_inactive(&self) -> bool {
        *self == Self::Inactive
    }
}
/// Serde `skip_serializing_if` predicate: `true` when `value` equals the default
/// `HookRunOverrides`.
fn is_default_hook_run_overrides(value: &crate::HookRunOverrides) -> bool {
    let baseline = crate::HookRunOverrides::default();
    *value == baseline
}
/// Serde `skip_serializing_if` predicate: `true` when `value` equals the default
/// `CallTimeoutOverride`.
fn is_default_call_timeout_override(value: &crate::CallTimeoutOverride) -> bool {
    let baseline = crate::CallTimeoutOverride::default();
    *value == baseline
}
/// Build configuration persisted in session metadata under `SESSION_BUILD_STATE_KEY`
/// (see `Session::set_build_state` / `Session::build_state`).
///
/// Every field is optional on the wire: it defaults when absent and is skipped during
/// serialization when `None`, empty, or equal to its type's default.
///
/// NOTE(review): the meaning of individual fields is not established in this file
/// beyond their names and types; confirm against the code that builds sessions.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub struct SessionBuildState {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub system_prompt: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub output_schema: Option<crate::OutputSchema>,
// Skipped when equal to `HookRunOverrides::default()`.
#[serde(default, skip_serializing_if = "is_default_hook_run_overrides")]
pub hooks_override: crate::HookRunOverrides,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub budget_limits: Option<crate::BudgetLimits>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub recoverable_tool_defs: Vec<ToolDef>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub silent_comms_intents: Vec<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_inline_peer_notifications: Option<i32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub app_context: Option<serde_json::Value>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub additional_instructions: Option<Vec<String>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub shell_env: Option<HashMap<String, String>>,
// Read and written through `Session::mob_tool_authority_context` and
// `Session::set_mob_tool_authority_context`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub mob_tool_authority_context: Option<MobToolAuthorityContext>,
// Skipped when equal to `CallTimeoutOverride::default()`.
#[serde(default, skip_serializing_if = "is_default_call_timeout_override")]
pub call_timeout_override: crate::CallTimeoutOverride,
}
/// An initial prompt staged by `SessionDeferredTurnState::stage_initial_prompt`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub struct PendingDeferredPrompt {
/// The staged prompt content.
pub prompt: ContentInput,
/// Timestamp passed in when the prompt was staged.
pub accepted_at: SystemTime,
}
/// One batch of tool results staged by `SessionDeferredTurnState::stage_tool_results`.
///
/// `PartialEq` is implemented by hand rather than derived.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct PendingToolResultsMessage {
/// The staged results, in the order they were supplied.
pub results: Vec<ToolResult>,
/// Timestamp passed in when the batch was staged.
pub accepted_at: SystemTime,
}
impl PartialEq for PendingToolResultsMessage {
    /// Two batches are equal when their timestamps match and their results serialize
    /// to the same JSON value.
    ///
    /// A serialization failure is mapped to `None`, so two batches that both fail to
    /// serialize compare equal on the results part.
    fn eq(&self, other: &Self) -> bool {
        if self.accepted_at != other.accepted_at {
            return false;
        }
        let lhs_json = serde_json::to_value(&self.results).ok();
        let rhs_json = serde_json::to_value(&other.results).ok();
        lhs_json == rhs_json
    }
}
/// What was staged under a given idempotency key, kept so that a repeated request can
/// be classified as a duplicate or a conflict.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub struct SeenSystemContextKey {
/// Trimmed text that was staged under the key.
pub text: String,
/// Source that was staged under the key, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub source: Option<String>,
/// Whether the append is still pending or has been marked applied.
pub state: SeenSystemContextState,
}
/// Lifecycle of an idempotency key recorded in `SessionSystemContextState::seen`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SeenSystemContextState {
/// Recorded by `stage_append`; the append is still in the pending list.
Pending,
/// Set by `mark_pending_applied`.
Applied,
}
impl SessionSystemContextState {
    /// Stages a system-context append.
    ///
    /// The request text is trimmed first. When the request carries an idempotency key
    /// that has been seen before, nothing is staged: the call reports `Duplicate` if
    /// the trimmed text and source match the earlier request, and a `Conflict` error
    /// otherwise.
    ///
    /// # Errors
    ///
    /// * [`SystemContextStageError::InvalidRequest`] when the trimmed text is empty.
    /// * [`SystemContextStageError::Conflict`] when the key was already used with a
    ///   different text or source.
    pub fn stage_append(
        &mut self,
        req: &AppendSystemContextRequest,
        accepted_at: SystemTime,
    ) -> Result<crate::service::AppendSystemContextStatus, SystemContextStageError> {
        let trimmed = req.text.trim();
        if trimmed.is_empty() {
            return Err(SystemContextStageError::InvalidRequest(
                "system context text must not be empty".to_string(),
            ));
        }

        // A key that was seen before short-circuits: duplicate or conflict.
        if let Some(key) = req.idempotency_key.as_ref()
            && let Some(existing) = self.seen.get(key)
        {
            let same_text = existing.text == trimmed;
            let same_source = existing.source.as_deref() == req.source.as_deref();
            return if same_text && same_source {
                Ok(crate::service::AppendSystemContextStatus::Duplicate)
            } else {
                Err(SystemContextStageError::Conflict {
                    key: key.clone(),
                    existing_text: existing.text.clone(),
                    existing_source: existing.source.clone(),
                })
            };
        }

        let staged_text = trimmed.to_string();

        // Remember the key before queueing so a later repeat is recognized.
        if let Some(key) = req.idempotency_key.as_ref() {
            let record = SeenSystemContextKey {
                text: staged_text.clone(),
                source: req.source.clone(),
                state: SeenSystemContextState::Pending,
            };
            self.seen.insert(key.clone(), record);
        }

        self.pending.push(PendingSystemContextAppend {
            text: staged_text,
            source: req.source.clone(),
            idempotency_key: req.idempotency_key.clone(),
            accepted_at,
        });
        Ok(crate::service::AppendSystemContextStatus::Staged)
    }

    /// Empties the pending list, flipping the `seen` entry of every pending append
    /// that carries an idempotency key to `Applied`.
    pub fn mark_pending_applied(&mut self) {
        for applied in self.pending.drain(..) {
            let Some(key) = applied.idempotency_key else {
                continue;
            };
            if let Some(record) = self.seen.get_mut(&key) {
                record.state = SeenSystemContextState::Applied;
            }
        }
    }
}
impl SessionDeferredTurnState {
    /// Sets the first-turn phase to `Pending`.
    pub fn mark_initial_turn_pending(&mut self) {
        self.first_turn_phase = DeferredFirstTurnPhase::Pending;
    }

    /// Moves the phase from `Pending` to `Consumed`.
    ///
    /// Returns `true` when the transition happened; any other phase is left untouched
    /// and yields `false`.
    pub fn mark_initial_turn_started(&mut self) -> bool {
        match self.first_turn_phase {
            DeferredFirstTurnPhase::Pending => {
                self.first_turn_phase = DeferredFirstTurnPhase::Consumed;
                true
            }
            DeferredFirstTurnPhase::Inactive | DeferredFirstTurnPhase::Consumed => false,
        }
    }

    /// Puts the phase back to `Pending`, regardless of its current value.
    pub fn restore_initial_turn_pending(&mut self) {
        self.mark_initial_turn_pending();
    }

    /// Returns `true` while the phase is `Pending`.
    pub fn allows_initial_turn_overrides(&self) -> bool {
        self.first_turn_phase == DeferredFirstTurnPhase::Pending
    }

    /// Stages `prompt` as the pending initial prompt, replacing any earlier one.
    ///
    /// A prompt with no images and only whitespace text clears the pending prompt
    /// instead of being stored.
    pub fn stage_initial_prompt(&mut self, prompt: ContentInput, accepted_at: SystemTime) {
        let is_blank = !prompt.has_images() && prompt.text_content().trim().is_empty();
        self.pending_initial_prompt = if is_blank {
            None
        } else {
            Some(PendingDeferredPrompt {
                prompt,
                accepted_at,
            })
        };
    }

    /// Queues one batch of tool results and returns how many results it contained.
    ///
    /// An empty batch is not queued and yields `0`.
    pub fn stage_tool_results(
        &mut self,
        results: Vec<ToolResult>,
        accepted_at: SystemTime,
    ) -> usize {
        let accepted = results.len();
        if accepted > 0 {
            self.pending_tool_results.push(PendingToolResultsMessage {
                results,
                accepted_at,
            });
        }
        accepted
    }

    /// Removes and returns the pending initial prompt, if one is staged.
    pub fn take_initial_prompt(&mut self) -> Option<ContentInput> {
        let staged = self.pending_initial_prompt.take()?;
        Some(staged.prompt)
    }

    /// Removes and returns every queued tool-result batch, leaving the queue empty.
    pub fn take_tool_results(&mut self) -> Vec<PendingToolResultsMessage> {
        let queue = &mut self.pending_tool_results;
        std::mem::take(queue)
    }

    /// Returns `true` when at least one tool-result batch is queued.
    pub fn has_pending_tool_results(&self) -> bool {
        let queue_is_empty = self.pending_tool_results.is_empty();
        !queue_is_empty
    }
}
/// Reasons `SessionSystemContextState::stage_append` can reject a request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SystemContextStageError {
/// The request is malformed; currently produced when the trimmed text is empty.
InvalidRequest(String),
/// The idempotency key was already used with a different text or source.
Conflict {
/// The idempotency key that collided.
key: String,
/// Text previously staged under `key`.
existing_text: String,
/// Source previously staged under `key`, if any.
existing_source: Option<String>,
},
}
/// Renders one append as a text block: a `[Runtime System Context]` header, an
/// optional `source: ...` line, a blank line, then the append's text.
fn render_system_context_block(append: &PendingSystemContextAppend) -> String {
    let header = match append.source.as_deref() {
        Some(source) => format!("[Runtime System Context]\nsource: {source}"),
        None => String::from("[Runtime System Context]"),
    };
    format!("{header}\n\n{}", append.text)
}
impl Session {
pub fn new() -> Self {
let now = SystemTime::now();
Self {
version: SESSION_VERSION,
id: SessionId::new(),
messages: Arc::new(Vec::new()),
created_at: now,
updated_at: now,
metadata: serde_json::Map::new(),
usage: Usage::default(),
}
}
pub fn with_id(id: SessionId) -> Self {
let mut session = Self::new();
session.id = id;
session
}
pub fn id(&self) -> &SessionId {
&self.id
}
pub fn version(&self) -> u32 {
self.version
}
pub fn messages(&self) -> &[Message] {
&self.messages
}
pub fn messages_mut(&mut self) -> &mut Vec<Message> {
Arc::make_mut(&mut self.messages)
}
pub fn created_at(&self) -> SystemTime {
self.created_at
}
pub fn updated_at(&self) -> SystemTime {
self.updated_at
}
pub fn push(&mut self, message: Message) {
Arc::make_mut(&mut self.messages).push(message);
self.updated_at = SystemTime::now();
}
pub fn push_batch(&mut self, messages: Vec<Message>) {
if messages.is_empty() {
return;
}
let inner = Arc::make_mut(&mut self.messages);
inner.extend(messages);
self.updated_at = SystemTime::now();
}
pub fn touch(&mut self) {
self.updated_at = SystemTime::now();
}
pub fn last_n(&self, n: usize) -> &[Message] {
let start = self.messages.len().saturating_sub(n);
&self.messages[start..]
}
pub fn total_tokens(&self) -> u64 {
self.usage.total_tokens()
}
pub fn total_usage(&self) -> Usage {
self.usage.clone()
}
pub fn record_usage(&mut self, turn_usage: Usage) {
self.usage.add(&turn_usage);
self.updated_at = SystemTime::now();
}
pub fn set_system_prompt(&mut self, prompt: String) {
use crate::types::SystemMessage;
let inner = Arc::make_mut(&mut self.messages);
if let Some(Message::System(_)) = inner.first() {
inner[0] = Message::System(SystemMessage { content: prompt });
} else {
inner.insert(0, Message::System(SystemMessage { content: prompt }));
}
self.updated_at = SystemTime::now();
}
pub fn append_system_context_blocks(&mut self, appends: &[PendingSystemContextAppend]) {
if appends.is_empty() {
return;
}
let rendered = appends
.iter()
.map(render_system_context_block)
.collect::<Vec<_>>()
.join(SYSTEM_CONTEXT_SEPARATOR);
let next = match self.messages.first() {
Some(Message::System(sys)) if !sys.content.is_empty() => {
format!("{}{}{}", sys.content, SYSTEM_CONTEXT_SEPARATOR, rendered)
}
_ => rendered,
};
self.set_system_prompt(next);
}
pub fn last_assistant_text(&self) -> Option<String> {
self.messages.iter().rev().find_map(|m| match m {
Message::BlockAssistant(a) => {
let mut buf = String::new();
for block in &a.blocks {
if let crate::types::AssistantBlock::Text { text, .. } = block {
buf.push_str(text);
}
}
if buf.is_empty() { None } else { Some(buf) }
}
Message::Assistant(a) if !a.content.is_empty() => Some(a.content.clone()),
_ => None,
})
}
pub fn tool_call_count(&self) -> usize {
self.messages
.iter()
.filter_map(|m| match m {
Message::BlockAssistant(a) => Some(
a.blocks
.iter()
.filter(|b| matches!(b, crate::types::AssistantBlock::ToolUse { .. }))
.count(),
),
Message::Assistant(a) => Some(a.tool_calls.len()),
_ => None,
})
.sum()
}
pub fn metadata(&self) -> &serde_json::Map<String, serde_json::Value> {
&self.metadata
}
pub fn set_metadata(&mut self, key: &str, value: serde_json::Value) {
self.metadata.insert(key.to_string(), value);
self.updated_at = SystemTime::now();
}
pub fn remove_metadata(&mut self, key: &str) {
self.metadata.remove(key);
self.updated_at = SystemTime::now();
}
pub fn set_session_metadata(
&mut self,
metadata: SessionMetadata,
) -> Result<(), serde_json::Error> {
let value = serde_json::to_value(metadata)?;
self.set_metadata(SESSION_METADATA_KEY, value);
Ok(())
}
pub fn session_metadata(&self) -> Option<SessionMetadata> {
self.metadata
.get(SESSION_METADATA_KEY)
.and_then(|value| serde_json::from_value(value.clone()).ok())
}
pub fn set_system_context_state(
&mut self,
state: SessionSystemContextState,
) -> Result<(), serde_json::Error> {
let value = serde_json::to_value(state)?;
self.set_metadata(SESSION_SYSTEM_CONTEXT_STATE_KEY, value);
Ok(())
}
pub fn system_context_state(&self) -> Option<SessionSystemContextState> {
self.metadata
.get(SESSION_SYSTEM_CONTEXT_STATE_KEY)
.and_then(|value| serde_json::from_value(value.clone()).ok())
}
pub fn set_deferred_turn_state(
&mut self,
state: SessionDeferredTurnState,
) -> Result<(), serde_json::Error> {
let value = serde_json::to_value(state)?;
self.set_metadata(SESSION_DEFERRED_TURN_STATE_KEY, value);
Ok(())
}
pub fn deferred_turn_state(&self) -> Option<SessionDeferredTurnState> {
self.metadata
.get(SESSION_DEFERRED_TURN_STATE_KEY)
.and_then(|value| serde_json::from_value(value.clone()).ok())
}
pub fn set_build_state(&mut self, state: SessionBuildState) -> Result<(), serde_json::Error> {
let value = serde_json::to_value(state)?;
self.set_metadata(SESSION_BUILD_STATE_KEY, value);
Ok(())
}
pub fn build_state(&self) -> Option<SessionBuildState> {
self.metadata
.get(SESSION_BUILD_STATE_KEY)
.and_then(|value| serde_json::from_value(value.clone()).ok())
}
pub fn set_mob_tool_authority_context(
&mut self,
authority_context: Option<MobToolAuthorityContext>,
) -> Result<(), serde_json::Error> {
let mut build_state = self.build_state().unwrap_or_default();
build_state.mob_tool_authority_context = authority_context;
self.set_build_state(build_state)
}
pub fn mob_tool_authority_context(&self) -> Option<MobToolAuthorityContext> {
self.build_state()
.and_then(|state| state.mob_tool_authority_context)
}
pub fn fork_at(&self, index: usize) -> Self {
let now = SystemTime::now();
let truncated = self.messages[..index.min(self.messages.len())].to_vec();
Self {
version: SESSION_VERSION,
id: SessionId::new(),
messages: Arc::new(truncated),
created_at: now,
updated_at: now,
metadata: self.metadata.clone(),
usage: self.usage.clone(),
}
}
pub fn fork(&self) -> Self {
let now = SystemTime::now();
Self {
version: SESSION_VERSION,
id: SessionId::new(),
messages: Arc::clone(&self.messages),
created_at: now,
updated_at: now,
metadata: self.metadata.clone(),
usage: self.usage.clone(),
}
}
}
impl Default for Session {
fn default() -> Self {
Self::new()
}
}
/// Summary of a `Session` without the message bodies; built by `From<&Session>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct SessionMeta {
/// Id of the summarized session.
pub id: SessionId,
/// Creation timestamp copied from the session.
pub created_at: SystemTime,
/// Last-update timestamp copied from the session.
pub updated_at: SystemTime,
/// Number of messages in the session's history.
pub message_count: usize,
/// Value of `Session::total_tokens` at the time of conversion.
pub total_tokens: u64,
/// Copy of the session's metadata map; defaults to empty when absent on input.
#[serde(default)]
pub metadata: serde_json::Map<String, serde_json::Value>,
}
/// Typed session metadata stored in the session's metadata map under
/// `SESSION_METADATA_KEY` (see `Session::set_session_metadata`).
///
/// NOTE(review): field meanings beyond their names and types are not established in
/// this file; confirm against the code that populates them.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct SessionMetadata {
pub model: String,
pub max_tokens: u32,
// Falls back to `default_structured_output_retries()` when absent on input.
#[serde(default = "default_structured_output_retries")]
pub structured_output_retries: u32,
pub provider: Provider,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub provider_params: Option<serde_json::Value>,
pub tooling: SessionTooling,
// Defaults to `false` when absent on input.
#[serde(default)]
pub keep_alive: bool,
// Unlike the optional fields below, this one has no `skip_serializing_if`, so it is
// always written out, including when it is `None`.
pub comms_name: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub peer_meta: Option<PeerMeta>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub realm_id: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub instance_id: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub backend: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub config_generation: Option<u64>,
}
/// Serde default for `SessionMetadata::structured_output_retries`.
fn default_structured_output_retries() -> u32 {
    const DEFAULT_RETRIES: u32 = 2;
    DEFAULT_RETRIES
}
/// The model/provider portion of `SessionMetadata`, extracted by
/// `SessionMetadata::llm_identity` and written back by
/// `SessionMetadata::apply_llm_identity`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub struct SessionLlmIdentity {
/// Model name.
pub model: String,
/// Provider of the model.
pub provider: Provider,
/// Optional provider-specific parameters; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub provider_params: Option<serde_json::Value>,
}
impl SessionMetadata {
    /// Returns a copy of the model, provider and provider parameters as a
    /// [`SessionLlmIdentity`].
    pub fn llm_identity(&self) -> SessionLlmIdentity {
        let Self {
            model,
            provider,
            provider_params,
            ..
        } = self;
        SessionLlmIdentity {
            model: model.clone(),
            provider: *provider,
            provider_params: provider_params.clone(),
        }
    }

    /// Overwrites the model, provider and provider parameters with those in
    /// `identity`; all other fields are left unchanged.
    pub fn apply_llm_identity(&mut self, identity: &SessionLlmIdentity) {
        let SessionLlmIdentity {
            model,
            provider,
            provider_params,
        } = identity;
        self.model.clone_from(model);
        self.provider = *provider;
        self.provider_params.clone_from(provider_params);
    }
}
/// Metadata key under which `Session::set_session_metadata` stores `SessionMetadata`.
pub const SESSION_METADATA_KEY: &str = "session_metadata";
/// Three-state override for a tool category.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolCategoryOverride {
/// No override (the default): `resolve` returns the runtime default.
#[default]
Inherit,
/// Force the category on: `resolve` returns `true`.
Enable,
/// Force the category off: `resolve` returns `false`.
Disable,
}
impl ToolCategoryOverride {
    /// Resolves the override against `runtime_default`: an explicit `Enable` or
    /// `Disable` wins, and `Inherit` falls back to the default.
    #[must_use]
    pub fn resolve(self, runtime_default: bool) -> bool {
        self.to_override().unwrap_or(runtime_default)
    }

    /// Converts to an optional boolean: `Some(true)` for `Enable`, `Some(false)` for
    /// `Disable`, and `None` for `Inherit`.
    #[must_use]
    pub fn to_override(self) -> Option<bool> {
        match self {
            Self::Inherit => None,
            Self::Enable => Some(true),
            Self::Disable => Some(false),
        }
    }

    /// Builds an explicit override from an effective on/off value; never yields
    /// `Inherit`.
    #[must_use]
    pub fn from_effective(enabled: bool) -> Self {
        match enabled {
            true => Self::Enable,
            false => Self::Disable,
        }
    }

    /// Inverse of [`Self::to_override`]: `None` becomes `Inherit`, `Some(flag)` becomes
    /// the matching explicit override.
    #[must_use]
    pub fn from_override(value: Option<bool>) -> Self {
        match value {
            None => Self::Inherit,
            Some(enabled) => Self::from_effective(enabled),
        }
    }
}
/// Deserializes a [`ToolCategoryOverride`] from either a boolean or one of the strings
/// `"inherit"`, `"enable"`, `"disable"`.
///
/// Booleans map `true` to `Enable` and `false` to `Inherit` (not `Disable`).
/// NOTE(review): presumably a legacy `false` meant "not explicitly enabled" rather
/// than "explicitly disabled"; confirm before changing this mapping.
///
/// # Errors
///
/// Returns an unknown-variant error for any other string; other input types are
/// rejected by the visitor's default handlers.
fn deserialize_tool_category_compat<'de, D>(
    deserializer: D,
) -> Result<ToolCategoryOverride, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::de;

    /// Accepted string spellings, reported in the unknown-variant error.
    const VARIANTS: &[&str] = &["inherit", "enable", "disable"];

    struct ToolCategoryVisitor;

    impl de::Visitor<'_> for ToolCategoryVisitor {
        type Value = ToolCategoryOverride;

        fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            formatter.write_str("a boolean or one of \"inherit\", \"enable\", \"disable\"")
        }

        fn visit_bool<E: de::Error>(self, v: bool) -> Result<Self::Value, E> {
            let mapped = match v {
                true => ToolCategoryOverride::Enable,
                false => ToolCategoryOverride::Inherit,
            };
            Ok(mapped)
        }

        fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
            let parsed = match v {
                "inherit" => ToolCategoryOverride::Inherit,
                "enable" => ToolCategoryOverride::Enable,
                "disable" => ToolCategoryOverride::Disable,
                _ => return Err(de::Error::unknown_variant(v, VARIANTS)),
            };
            Ok(parsed)
        }
    }

    deserializer.deserialize_any(ToolCategoryVisitor)
}
/// Per-category tool overrides for a session.
///
/// Each category field defaults to `Inherit` when absent and is read through
/// `deserialize_tool_category_compat`, which accepts either a boolean or one of the
/// strings `"inherit"`, `"enable"`, `"disable"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub struct SessionTooling {
#[serde(default, deserialize_with = "deserialize_tool_category_compat")]
pub builtins: ToolCategoryOverride,
#[serde(default, deserialize_with = "deserialize_tool_category_compat")]
pub shell: ToolCategoryOverride,
#[serde(default, deserialize_with = "deserialize_tool_category_compat")]
pub comms: ToolCategoryOverride,
#[serde(default, deserialize_with = "deserialize_tool_category_compat")]
pub mob: ToolCategoryOverride,
#[serde(default, deserialize_with = "deserialize_tool_category_compat")]
pub memory: ToolCategoryOverride,
// Optional list of skill ids; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub active_skills: Option<Vec<crate::skills::SkillId>>,
}
impl From<&Session> for SessionMeta {
fn from(session: &Session) -> Self {
Self {
id: session.id.clone(),
created_at: session.created_at,
updated_at: session.updated_at,
message_count: session.messages.len(),
total_tokens: session.total_tokens(),
metadata: session.metadata.clone(),
}
}
}
// Unit tests for `Session`: construction, forking, timestamps, metadata and serde.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
use crate::types::{AssistantMessage, StopReason, SystemMessage, UserMessage};
use std::sync::Arc;
// A new session has the current schema version, no messages, and ordered timestamps.
#[test]
fn test_session_new() {
let session = Session::new();
assert_eq!(session.version(), SESSION_VERSION);
assert!(session.messages().is_empty());
assert!(session.created_at() <= session.updated_at());
}
// `fork` shares the message list: both sessions point at the same `Arc` allocation.
#[test]
fn test_fork_shares_arc_no_clone() {
let mut session = Session::new();
for i in 0..100 {
session.push(Message::User(UserMessage::text(format!("Message {i}"))));
}
let forked = session.fork();
assert!(Arc::ptr_eq(&session.messages, &forked.messages));
assert_eq!(forked.messages().len(), 100);
}
// `fork_at` keeps only the requested prefix and leaves the original untouched.
// Despite the test name, only the lengths are checked here, not `Arc` sharing.
#[test]
fn test_fork_at_shares_arc_prefix() {
let mut session = Session::new();
for i in 0..100 {
session.push(Message::User(UserMessage::text(format!("Message {i}"))));
}
let forked = session.fork_at(50);
assert_eq!(forked.messages().len(), 50);
assert_eq!(session.messages().len(), 100);
}
// Pushing to a session that shares its list with a fork detaches it from the fork.
#[test]
fn test_push_cow_behavior() {
let mut session = Session::new();
session.push(Message::User(UserMessage::text("First".to_string())));
let forked = session.fork();
assert!(Arc::ptr_eq(&session.messages, &forked.messages));
session.push(Message::User(UserMessage::text("Second".to_string())));
assert!(!Arc::ptr_eq(&session.messages, &forked.messages));
assert_eq!(session.messages().len(), 2);
assert_eq!(forked.messages().len(), 1);
}
// `push_batch` appends every message and does not move `updated_at` backwards.
#[test]
fn test_push_batch_single_timestamp() {
let mut session = Session::new();
let initial_updated = session.updated_at();
session.push_batch(vec![
Message::User(UserMessage::text("First".to_string())),
Message::User(UserMessage::text("Second".to_string())),
Message::User(UserMessage::text("Third".to_string())),
]);
assert_eq!(session.messages().len(), 3);
assert!(session.updated_at() >= initial_updated);
}
// `touch` advances `updated_at`; the sleep makes the clock difference observable.
#[test]
fn test_touch_updates_timestamp() {
let mut session = Session::new();
let initial = session.updated_at();
std::thread::sleep(std::time::Duration::from_millis(10));
session.touch();
assert!(session.updated_at() > initial);
}
// `push` appends the message and advances `updated_at`.
#[test]
fn test_session_push() {
let mut session = Session::new();
let initial_updated = session.updated_at();
std::thread::sleep(std::time::Duration::from_millis(10));
session.push(Message::User(UserMessage::text("Hello".to_string())));
assert_eq!(session.messages().len(), 1);
assert!(session.updated_at() > initial_updated);
}
// `fork_at` truncates and assigns a new id; `fork` keeps the full history.
#[test]
fn test_session_fork() {
let mut session = Session::new();
session.push(Message::System(SystemMessage {
content: "System prompt".to_string(),
}));
session.push(Message::User(UserMessage::text("Hello".to_string())));
session.push(Message::Assistant(AssistantMessage {
content: "Hi!".to_string(),
tool_calls: vec![],
stop_reason: StopReason::EndTurn,
usage: Usage::default(),
}));
let forked = session.fork_at(2);
assert_eq!(forked.messages().len(), 2);
assert_ne!(forked.id(), session.id());
let full_fork = session.fork();
assert_eq!(full_fork.messages().len(), 3);
}
// Raw metadata entries can be written and read back.
#[test]
fn test_session_metadata() {
let mut session = Session::new();
session.set_metadata("key", serde_json::json!("value"));
assert_eq!(session.metadata().get("key").unwrap(), "value");
}
// The mob tool authority context round-trips through the build state and can be cleared.
#[test]
fn test_session_mob_tool_authority_context_roundtrip() {
let mut session = Session::new();
let authority = MobToolAuthorityContext::new(
crate::service::OpaquePrincipalToken::new("opaque-principal"),
false,
)
.with_managed_mob_scope(["mob-a"])
.with_audit_invocation_id("audit-1");
session
.set_mob_tool_authority_context(Some(authority.clone()))
.expect("authority should serialize");
assert_eq!(session.mob_tool_authority_context(), Some(authority));
session
.set_mob_tool_authority_context(None)
.expect("authority should clear");
assert!(session.mob_tool_authority_context().is_none());
}
// A session survives a JSON round trip with its id, messages and version intact.
#[test]
fn test_session_serialization() {
let mut session = Session::new();
session.push(Message::User(UserMessage::text("Test".to_string())));
let json = serde_json::to_string(&session).unwrap();
let parsed: Session = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.id(), session.id());
assert_eq!(parsed.messages().len(), 1);
assert_eq!(parsed.version(), SESSION_VERSION);
}
// `SessionMeta::from` reports the message count and the usage recorded via
// `record_usage`; the usage embedded in the assistant message is not counted.
#[test]
fn test_session_meta_from_session() {
let mut session = Session::new();
session.push(Message::User(UserMessage::text("Hello".to_string())));
session.push(Message::Assistant(AssistantMessage {
content: "Hi!".to_string(),
tool_calls: vec![],
stop_reason: StopReason::EndTurn,
usage: Usage {
input_tokens: 10,
output_tokens: 5,
cache_creation_tokens: None,
cache_read_tokens: None,
},
}));
session.record_usage(Usage {
input_tokens: 10,
output_tokens: 5,
cache_creation_tokens: None,
cache_read_tokens: None,
});
let meta = SessionMeta::from(&session);
assert_eq!(meta.id, *session.id());
assert_eq!(meta.message_count, 2);
assert_eq!(meta.total_tokens, 15);
}
}