pub(crate) mod compaction;
pub mod manager;
pub use manager::SessionManager;
#[cfg(test)]
#[path = "tests.rs"]
mod tests_file;
use crate::agent::AgentEvent;
use crate::hitl::{ConfirmationManager, ConfirmationPolicy, ConfirmationProvider};
use crate::llm::{LlmClient, Message, TokenUsage, ToolDefinition};
use crate::permissions::{PermissionChecker, PermissionDecision, PermissionPolicy};
use crate::planning::Task;
use crate::prompts::PlanningMode;
use crate::queue::{ExternalTaskResult, LaneHandlerConfig, SessionQueueConfig};
use crate::session_lane_queue::SessionLaneQueue;
use crate::store::{LlmConfigData, SessionData};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
/// Lifecycle state of a [`Session`].
///
/// Discriminants are pinned explicitly (0-4). `Unknown` is the `Default`
/// value; `Session::new` starts sessions in `Active`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum SessionState {
/// Default variant (discriminant 0).
#[default]
Unknown = 0,
/// Initial state set by `Session::new`; also the target of a successful `Session::resume`.
Active = 1,
/// Target of a successful `Session::pause`.
Paused = 2,
/// Set by `Session::set_completed`.
Completed = 3,
/// Set by `Session::set_error`.
Error = 4,
}
/// Snapshot of how much of the context window a session is using.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextUsage {
/// Tokens currently counted against the window. `Session::update_usage`
/// sets this to the prompt-token count of the most recent usage report.
pub used_tokens: usize,
/// Denominator used for `percent`; 200_000 in `ContextUsage::default()`.
pub max_tokens: usize,
/// `used_tokens / max_tokens` as computed by `Session::update_usage`
/// (a ratio, not multiplied by 100).
pub percent: f32,
/// `Session::add_message` sets this to the number of stored messages.
pub turns: usize,
}
impl Default for ContextUsage {
fn default() -> Self {
Self {
used_tokens: 0,
max_tokens: 200_000,
percent: 0.0,
turns: 0,
}
}
}
/// Default value for `SessionConfig::auto_compact_threshold` (0.80).
pub const DEFAULT_AUTO_COMPACT_THRESHOLD: f32 = 0.80;
/// Serde default provider for `SessionConfig::auto_compact_threshold`;
/// returns [`DEFAULT_AUTO_COMPACT_THRESHOLD`].
fn default_auto_compact_threshold() -> f32 {
DEFAULT_AUTO_COMPACT_THRESHOLD
}
/// Per-session configuration supplied to `Session::new`.
///
/// Serializable with serde; optional fields that are `None` are omitted from
/// the serialized form, and `hook_engine` is never serialized or deserialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionConfig {
/// Session name.
pub name: String,
/// Workspace for the session, stored as a string
/// (presumably a filesystem path; confirm against callers).
pub workspace: String,
/// Optional system prompt; exposed through `Session::system`.
pub system_prompt: Option<String>,
/// Maximum context length (presumably in tokens; the consumer of this
/// field is not visible in this file).
pub max_context_length: u32,
/// Auto-compaction flag; the logic acting on it is not visible in this file.
pub auto_compact: bool,
/// Threshold associated with auto-compaction. Falls back to
/// `DEFAULT_AUTO_COMPACT_THRESHOLD` when absent from serialized input.
#[serde(default = "default_auto_compact_threshold")]
pub auto_compact_threshold: f32,
/// Storage backend selection; uses the backend's `Default` when absent
/// from serialized input.
#[serde(default)]
pub storage_type: crate::config::StorageBackend,
/// Command-queue configuration. `Session::new` uses
/// `SessionQueueConfig::default()` when this is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub queue_config: Option<SessionQueueConfig>,
/// Confirmation policy. `Session::new` falls back to
/// `ConfirmationPolicy::enabled()` when this is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub confirmation_policy: Option<ConfirmationPolicy>,
/// Permission policy. `Session::new` uses `PermissionPolicy::default()`
/// when this is `None`; `Session::set_permission_policy` overwrites it.
#[serde(skip_serializing_if = "Option::is_none")]
pub permission_policy: Option<PermissionPolicy>,
/// ID of the parent session, if any; copied into `Session::parent_id`
/// by `Session::new`.
#[serde(skip_serializing_if = "Option::is_none")]
pub parent_id: Option<String>,
/// Security configuration. `Session::new` installs a
/// `NoOpSecurityProvider` only when this is present and `enabled`.
#[serde(skip_serializing_if = "Option::is_none")]
pub security_config: Option<crate::security::SecurityConfig>,
/// Optional hook executor. Skipped by serde in both directions.
#[serde(skip)]
pub hook_engine: Option<std::sync::Arc<dyn crate::hooks::HookExecutor>>,
/// Planning mode; uses `PlanningMode::default()` when absent from
/// serialized input.
#[serde(default)]
pub planning_mode: PlanningMode,
/// Goal-tracking flag; `false` when absent from serialized input.
#[serde(default)]
pub goal_tracking: bool,
}
impl Default for SessionConfig {
fn default() -> Self {
Self {
name: String::new(),
workspace: String::new(),
system_prompt: None,
max_context_length: 0,
auto_compact: false,
auto_compact_threshold: DEFAULT_AUTO_COMPACT_THRESHOLD,
storage_type: crate::config::StorageBackend::default(),
queue_config: None,
confirmation_policy: None,
permission_policy: None,
parent_id: None,
security_config: None,
hook_engine: None,
planning_mode: PlanningMode::default(),
goal_tracking: false,
}
}
}
/// A single agent session: conversation history, usage accounting, and the
/// runtime handles (queue, confirmation, permissions, events) tied to it.
pub struct Session {
/// Session identifier; validated as path-safe by `Session::new`.
pub id: String,
/// Configuration the session was created with.
pub config: SessionConfig,
/// Current lifecycle state; `Active` right after construction.
pub state: SessionState,
/// Conversation history, appended to by `Session::add_message`.
pub messages: Vec<Message>,
/// Context-window usage snapshot.
pub context_usage: ContextUsage,
/// Cumulative token usage, summed by `Session::update_usage`.
pub total_usage: TokenUsage,
/// Cumulative cost, summed by `Session::update_usage` when pricing for
/// the model is available.
pub total_cost: f64,
/// Model name used for pricing lookup and metrics; `None` at construction.
pub model_name: Option<String>,
/// Tool definitions passed to `Session::new`.
pub tools: Vec<ToolDefinition>,
/// Thinking flag; `false` at construction.
pub thinking_enabled: bool,
/// Optional thinking budget; `None` at construction.
pub thinking_budget: Option<usize>,
/// Optional LLM client; `None` at construction.
pub llm_client: Option<Arc<dyn LlmClient>>,
/// Creation time in seconds since the Unix epoch.
pub created_at: i64,
/// Last-modification time in seconds since the Unix epoch, refreshed by
/// the private `touch` helper.
pub updated_at: i64,
/// Per-session command queue.
pub command_queue: SessionLaneQueue,
/// Confirmation provider; a `ConfirmationManager` at construction.
pub confirmation_manager: Arc<dyn ConfirmationProvider>,
/// Permission checker consulted by `Session::check_permission`.
pub permission_checker: Arc<dyn PermissionChecker>,
// Broadcast sender for agent events (channel created with capacity 100 in
// `Session::new`); shared with the queue and the confirmation manager.
event_tx: broadcast::Sender<AgentEvent>,
/// Registered context providers; empty at construction.
pub context_providers: Vec<Arc<dyn crate::context::ContextProvider>>,
/// Current task list, replaced wholesale by `Session::set_tasks`.
pub tasks: Vec<Task>,
/// ID of the parent session, if this is a child session.
pub parent_id: Option<String>,
/// Optional agent memory; `None` at construction.
pub memory: Option<Arc<RwLock<crate::memory::AgentMemory>>>,
/// Current execution plan, if any; starts as `None`.
pub current_plan: Arc<RwLock<Option<crate::planning::ExecutionPlan>>>,
/// Optional security provider; see `SessionConfig::security_config`.
pub security_provider: Option<Arc<dyn crate::security::SecurityProvider>>,
/// Tool metrics collector.
pub tool_metrics: Arc<RwLock<crate::telemetry::ToolMetrics>>,
/// One record per `Session::update_usage` call.
pub cost_records: Vec<crate::telemetry::LlmCostRecord>,
}
/// Checks that `id` is non-empty and limited to a conservative character set.
///
/// An ID is accepted when every character is an ASCII letter, ASCII digit,
/// `-`, `_`, or `.`, it does not begin with `.`, and it contains no `..`
/// sequence. `label` names the ID kind in the error message.
///
/// # Errors
///
/// Returns an error when `id` is empty or fails any of the checks above.
fn validate_path_safe_id(id: &str, label: &str) -> Result<()> {
    if id.is_empty() {
        anyhow::bail!("{label} must not be empty");
    }

    let has_forbidden_char = id
        .chars()
        .any(|ch| !(ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_' | '.')));

    if has_forbidden_char || id.starts_with('.') || id.contains("..") {
        anyhow::bail!("{label} contains unsafe characters: {id:?}");
    }

    Ok(())
}
impl Session {
/// Creates a new session in the `Active` state.
///
/// Validates `id`, creates the event broadcast channel (capacity 100), and
/// builds the command queue, confirmation manager, and permission checker
/// from `config`, falling back to defaults for any unset policy. A
/// `NoOpSecurityProvider` is installed only when `config.security_config` is
/// present and enabled.
///
/// # Errors
///
/// Returns an error when `id` fails validation or when the command queue
/// cannot be constructed.
pub async fn new(
    id: String,
    config: SessionConfig,
    tools: Vec<ToolDefinition>,
) -> Result<Self> {
    validate_path_safe_id(&id, "Session ID")?;

    // Seconds since the Unix epoch; 0 if the clock reads before the epoch.
    let created = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    };

    let (event_tx, _) = broadcast::channel(100);

    let command_queue = SessionLaneQueue::new(
        &id,
        config.queue_config.clone().unwrap_or_default(),
        event_tx.clone(),
    )
    .await?;

    let confirmation_manager = Arc::new(ConfirmationManager::new(
        config
            .confirmation_policy
            .clone()
            .unwrap_or_else(ConfirmationPolicy::enabled),
        event_tx.clone(),
    ));

    let permission_checker: Arc<dyn PermissionChecker> =
        Arc::new(config.permission_policy.clone().unwrap_or_default());

    let security_provider = config
        .security_config
        .as_ref()
        .filter(|sc| sc.enabled)
        .map(|_| {
            Arc::new(crate::security::NoOpSecurityProvider)
                as Arc<dyn crate::security::SecurityProvider>
        });

    // Must be read before `config` is moved into the struct below.
    let parent_id = config.parent_id.clone();

    Ok(Self {
        id,
        config,
        state: SessionState::Active,
        messages: Vec::new(),
        context_usage: ContextUsage::default(),
        total_usage: TokenUsage::default(),
        total_cost: 0.0,
        model_name: None,
        tools,
        thinking_enabled: false,
        thinking_budget: None,
        llm_client: None,
        created_at: created,
        updated_at: created,
        command_queue,
        confirmation_manager,
        permission_checker,
        event_tx,
        context_providers: Vec::new(),
        tasks: Vec::new(),
        parent_id,
        memory: None,
        current_plan: Arc::new(RwLock::new(None)),
        security_provider,
        tool_metrics: Arc::new(RwLock::new(crate::telemetry::ToolMetrics::new())),
        cost_records: Vec::new(),
    })
}
/// Returns `true` when this session has a parent (`parent_id` is set).
pub fn is_child_session(&self) -> bool {
self.parent_id.is_some()
}
/// Returns the parent session's ID, or `None` when there is no parent.
pub fn parent_session_id(&self) -> Option<&str> {
self.parent_id.as_deref()
}
/// Creates a new receiver on this session's event broadcast channel.
pub fn subscribe_events(&self) -> broadcast::Receiver<AgentEvent> {
self.event_tx.subscribe()
}
/// Returns a clone of this session's event broadcast sender.
pub fn event_tx(&self) -> broadcast::Sender<AgentEvent> {
self.event_tx.clone()
}
/// Forwards `policy` to the confirmation manager's `set_policy`.
///
/// Unlike `set_permission_policy`, this does not write the policy back into
/// `self.config`.
pub async fn set_confirmation_policy(&self, policy: ConfirmationPolicy) {
self.confirmation_manager.set_policy(policy).await;
}
pub fn set_permission_policy(&mut self, policy: PermissionPolicy) {
self.permission_checker = Arc::new(policy.clone());
self.config.permission_policy = Some(policy);
}
/// Returns the policy currently reported by the confirmation manager.
pub async fn confirmation_policy(&self) -> ConfirmationPolicy {
self.confirmation_manager.policy().await
}
/// Asks the session's permission checker for a decision on invoking
/// `tool_name` with `args`, and returns that decision unchanged.
pub fn check_permission(
&self,
tool_name: &str,
args: &serde_json::Value,
) -> PermissionDecision {
self.permission_checker.check(tool_name, args)
}
/// Appends `provider` to the list of context providers.
///
/// No de-duplication is performed; a provider with an already-registered
/// name is added again.
pub fn add_context_provider(&mut self, provider: Arc<dyn crate::context::ContextProvider>) {
self.context_providers.push(provider);
}
/// Removes every context provider whose `name()` equals `name`.
///
/// Returns `true` when at least one provider was removed.
pub fn remove_context_provider(&mut self, name: &str) -> bool {
    let count_before = self.context_providers.len();
    self.context_providers
        .retain(|provider| provider.name() != name);
    let removed = count_before - self.context_providers.len();
    removed > 0
}
/// Returns the names of all registered context providers, in registration
/// order.
pub fn context_provider_names(&self) -> Vec<String> {
    let mut names = Vec::with_capacity(self.context_providers.len());
    for provider in &self.context_providers {
        names.push(provider.name().to_string());
    }
    names
}
/// Returns the session's current task list as a slice.
pub fn get_tasks(&self) -> &[Task] {
&self.tasks
}
/// Replaces the task list and broadcasts the change.
///
/// After storing `tasks` and bumping `updated_at`, an
/// `AgentEvent::TaskUpdated` carrying the new list is sent on the event
/// channel. The result of the send is deliberately discarded.
pub fn set_tasks(&mut self, tasks: Vec<Task>) {
    self.tasks.clone_from(&tasks);
    self.touch();

    let update = AgentEvent::TaskUpdated {
        session_id: self.id.clone(),
        tasks,
    };
    let _ = self.event_tx.send(update);
}
/// Counts the tasks for which `Task::is_active` returns `true`.
pub fn active_task_count(&self) -> usize {
    let mut active = 0;
    for task in &self.tasks {
        if task.is_active() {
            active += 1;
        }
    }
    active
}
/// Sets the handler configuration for `lane` by delegating to the command
/// queue's `set_lane_handler`.
pub async fn set_lane_handler(
&self,
lane: crate::hitl::SessionLane,
config: LaneHandlerConfig,
) {
self.command_queue.set_lane_handler(lane, config).await;
}
/// Returns the handler configuration the command queue reports for `lane`.
pub async fn get_lane_handler(&self, lane: crate::hitl::SessionLane) -> LaneHandlerConfig {
self.command_queue.get_lane_handler(lane).await
}
/// Hands `result` for external task `task_id` to the command queue and
/// returns the queue's boolean answer unchanged (presumably whether a
/// matching pending task was found; confirm in `SessionLaneQueue`).
pub async fn complete_external_task(&self, task_id: &str, result: ExternalTaskResult) -> bool {
self.command_queue
.complete_external_task(task_id, result)
.await
}
/// Returns the external tasks the command queue reports as pending.
pub async fn pending_external_tasks(&self) -> Vec<crate::queue::ExternalTask> {
self.command_queue.pending_external_tasks().await
}
/// Returns the command queue's dead-letter entries.
pub async fn dead_letters(&self) -> Vec<a3s_lane::DeadLetter> {
self.command_queue.dead_letters().await
}
/// Returns the command queue's metrics snapshot, if it provides one.
/// When `None` is returned is decided by `SessionLaneQueue::metrics_snapshot`.
pub async fn queue_metrics(&self) -> Option<a3s_lane::MetricsSnapshot> {
self.command_queue.metrics_snapshot().await
}
/// Returns the command queue's current statistics.
pub async fn queue_stats(&self) -> crate::queue::SessionQueueStats {
self.command_queue.stats().await
}
/// Starts the command queue.
///
/// # Errors
///
/// Propagates any error returned by `SessionLaneQueue::start`.
pub async fn start_queue(&self) -> Result<()> {
self.command_queue.start().await
}
/// Stops the command queue by awaiting `SessionLaneQueue::stop`.
pub async fn stop_queue(&self) {
self.command_queue.stop().await;
}
/// Returns the configured system prompt, or `None` when none is set.
pub fn system(&self) -> Option<&str> {
self.config.system_prompt.as_deref()
}
/// Returns the conversation history as a slice, oldest message first
/// (messages are only ever appended by `add_message`).
pub fn history(&self) -> &[Message] {
&self.messages
}
/// Appends `message` to the history, sets `context_usage.turns` to the new
/// message count, and bumps `updated_at`.
pub fn add_message(&mut self, message: Message) {
self.messages.push(message);
// `turns` mirrors the number of stored messages.
self.context_usage.turns = self.messages.len();
self.touch();
}
/// Records the token usage of one LLM call against this session.
///
/// Steps, in order:
/// 1. Adds `usage` to the cumulative `total_usage` counters.
/// 2. When `model_name` is set and has an entry in
///    `telemetry::default_model_pricing()`, prices the call and adds the
///    result to `total_cost`.
/// 3. Appends an `LlmCostRecord`. Its `cost_usd` is `None` when the call
///    could not be priced, and `provider` is left empty.
/// 4. Reports the call through `telemetry::record_llm_metrics`, labelled
///    `"unknown"` when the model name is unset or empty.
/// 5. Sets `context_usage.used_tokens` to this call's prompt tokens and
///    recomputes `context_usage.percent`.
/// 6. Bumps `updated_at`.
///
/// `percent` is reported as `0.0` when `context_usage.max_tokens` is zero.
/// Dividing would otherwise yield NaN or infinity, which breaks threshold
/// comparisons and does not survive a JSON round-trip.
pub fn update_usage(&mut self, usage: &TokenUsage) {
    self.total_usage.prompt_tokens += usage.prompt_tokens;
    self.total_usage.completion_tokens += usage.completion_tokens;
    self.total_usage.total_tokens += usage.total_tokens;

    // The pricing table is only consulted when a model name is known.
    let cost_usd = if let Some(ref model) = self.model_name {
        let pricing_map = crate::telemetry::default_model_pricing();
        if let Some(pricing) = pricing_map.get(model) {
            let cost = pricing.calculate_cost(usage.prompt_tokens, usage.completion_tokens);
            self.total_cost += cost;
            Some(cost)
        } else {
            None
        }
    } else {
        None
    };

    let model_str = self.model_name.clone().unwrap_or_default();
    self.cost_records.push(crate::telemetry::LlmCostRecord {
        model: model_str.clone(),
        provider: String::new(),
        prompt_tokens: usage.prompt_tokens,
        completion_tokens: usage.completion_tokens,
        total_tokens: usage.total_tokens,
        cost_usd,
        timestamp: chrono::Utc::now(),
        session_id: Some(self.id.clone()),
    });

    // The final argument is always 0.0 here; its meaning is defined by
    // `record_llm_metrics`, which is not visible in this file.
    crate::telemetry::record_llm_metrics(
        if model_str.is_empty() {
            "unknown"
        } else {
            &model_str
        },
        usage.prompt_tokens,
        usage.completion_tokens,
        cost_usd.unwrap_or(0.0),
        0.0,
    );

    self.context_usage.used_tokens = usage.prompt_tokens;
    // Guard the division: a zero-sized window would produce NaN (0/0) or
    // +inf (n/0) instead of a usable ratio.
    self.context_usage.percent = if self.context_usage.max_tokens == 0 {
        0.0
    } else {
        self.context_usage.used_tokens as f32 / self.context_usage.max_tokens as f32
    };
    self.touch();
}
/// Drops all messages, resets `context_usage` to its default, and bumps
/// `updated_at`.
///
/// Cumulative accounting (`total_usage`, `total_cost`, `cost_records`) is
/// left untouched.
pub fn clear(&mut self) {
self.messages.clear();
// NOTE(review): this also resets `max_tokens` to the 200_000 default; if
// `max_tokens` is customised elsewhere, confirm that losing it is intended.
self.context_usage = ContextUsage::default();
self.touch();
}
/// Compacts the conversation history using `llm_client`.
///
/// Delegates to `compaction::compact_messages`. When it returns a
/// replacement history, that history is installed, `context_usage.turns` is
/// resynchronised with the new message count, and `updated_at` is bumped.
/// When it returns `None`, the session is left unchanged.
///
/// # Errors
///
/// Propagates any error returned by `compaction::compact_messages`.
pub async fn compact(&mut self, llm_client: &Arc<dyn LlmClient>) -> Result<()> {
    if let Some(new_messages) =
        compaction::compact_messages(&self.id, &self.messages, llm_client).await?
    {
        self.messages = new_messages;
        // `add_message` keeps `turns` equal to `messages.len()`; keep that
        // true here so the count is not stale after compaction.
        self.context_usage.turns = self.messages.len();
        self.touch();
    }
    Ok(())
}
/// Moves an `Active` session to `Paused`.
///
/// Returns `true` when the transition happened. A session in any other state
/// is left untouched and `false` is returned.
pub fn pause(&mut self) -> bool {
    if self.state != SessionState::Active {
        return false;
    }
    self.state = SessionState::Paused;
    self.touch();
    true
}
/// Moves a `Paused` session back to `Active`.
///
/// Returns `true` when the transition happened. A session in any other state
/// is left untouched and `false` is returned.
pub fn resume(&mut self) -> bool {
    if self.state != SessionState::Paused {
        return false;
    }
    self.state = SessionState::Active;
    self.touch();
    true
}
/// Unconditionally sets the state to `Error` and bumps `updated_at`.
pub fn set_error(&mut self) {
self.state = SessionState::Error;
self.touch();
}
/// Unconditionally sets the state to `Completed` and bumps `updated_at`.
pub fn set_completed(&mut self) {
self.state = SessionState::Completed;
self.touch();
}
/// Sets `updated_at` to the current time in seconds since the Unix epoch,
/// or to 0 when the system clock reads earlier than the epoch.
fn touch(&mut self) {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    self.updated_at = match since_epoch {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    };
}
/// Builds a `SessionData` snapshot of this session.
///
/// Owned fields are cloned. `tool_names` is derived from `tools` through
/// `SessionData::tool_names_from_definitions`, and `llm_config` is stored
/// as passed in. Runtime handles (queue, event channel, providers, LLM
/// client, memory, plan, metrics) are not part of the snapshot.
pub fn to_session_data(&self, llm_config: Option<LlmConfigData>) -> SessionData {
SessionData {
id: self.id.clone(),
config: self.config.clone(),
state: self.state,
messages: self.messages.clone(),
context_usage: self.context_usage.clone(),
total_usage: self.total_usage.clone(),
total_cost: self.total_cost,
model_name: self.model_name.clone(),
cost_records: self.cost_records.clone(),
tool_names: SessionData::tool_names_from_definitions(&self.tools),
thinking_enabled: self.thinking_enabled,
thinking_budget: self.thinking_budget,
created_at: self.created_at,
updated_at: self.updated_at,
llm_config,
tasks: self.tasks.clone(),
parent_id: self.parent_id.clone(),
}
}
/// Overwrites this session's persisted state with the contents of `data`.
///
/// Restores state, history, usage and cost accounting, model name, thinking
/// settings, timestamps, tasks, and parent ID. `id`, `config`, `tools`, and
/// all runtime handles are left as they are. `data.llm_config` and
/// `data.tool_names` are not read.
pub fn restore_from_data(&mut self, data: &SessionData) {
    // Plain-value fields.
    self.state = data.state;
    self.total_cost = data.total_cost;
    self.thinking_enabled = data.thinking_enabled;
    self.thinking_budget = data.thinking_budget;
    self.created_at = data.created_at;
    self.updated_at = data.updated_at;

    // Owned fields.
    self.messages.clone_from(&data.messages);
    self.context_usage.clone_from(&data.context_usage);
    self.total_usage.clone_from(&data.total_usage);
    self.model_name.clone_from(&data.model_name);
    self.cost_records.clone_from(&data.cost_records);
    self.tasks.clone_from(&data.tasks);
    self.parent_id.clone_from(&data.parent_id);
}
}