use crate::context::auto_compaction::{CompactionConfig, CompactionReason};
use crate::extensions::{
ExtensionContext, ExtensionContextBuilder, ExtensionRunner, InputEvent as ExtInputEvent,
InputEventResult as ExtInputEventResult, SessionShutdownEvent, SessionShutdownReason,
};
use anyhow::{Context, Result};
use oxi_agent::{Agent, AgentEvent, AgentState};
use oxi_ai::Message;
use oxi_store::session::{AgentMessage, SessionManager};
use oxi_store::settings::{Settings, ThinkingLevel};
use parking_lot::RwLock;
use std::collections::VecDeque;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
/// Events broadcast by [`AgentSession`] to its subscribed listeners.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub enum SessionEvent {
/// Snapshot of the pending steering and follow-up queues.
QueueUpdate {
steering: Vec<String>,
follow_up: Vec<String>,
},
/// A compaction run is about to start.
CompactionStart {
reason: CompactionReason,
},
/// A compaction run finished; `error_message` is `Some` when it failed.
CompactionEnd {
reason: CompactionReason,
error_message: Option<String>,
},
/// Emitted after a session name has been recorded.
SessionInfoChanged,
/// An event produced by the underlying agent, forwarded unchanged.
Agent(AgentEvent),
/// The thinking level was changed to `level`.
ThinkingLevelChanged {
level: ThinkingLevel,
},
}
/// Outcome of a successful compaction run.
#[derive(Debug, Clone)]
pub struct CompactionResult {
/// Token estimate of the agent state snapshot taken before the messages were replaced.
pub tokens_before: usize,
}
/// A provider/model pair stored in the session's scoped-model list.
#[derive(Debug, Clone)]
pub struct ScopedModel {
/// Provider name, e.g. `"anthropic"`.
pub provider: String,
/// Model identifier within that provider.
pub model_id: String,
}
/// Options accepted by `AgentSession::prompt`.
///
/// NOTE(review): of these fields, only `streaming_behavior` is read by the
/// code visible in this file; the others may be consumed elsewhere.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct PromptOptions {
/// Defaults to `true`.
pub expand_templates: bool,
/// Image attachments; empty by default.
pub images: Vec<oxi_ai::ImageContent>,
/// How to queue the prompt when the session is already streaming.
/// `None` makes `prompt` return an error in that situation.
pub streaming_behavior: Option<StreamingBehavior>,
/// Where the input came from; defaults to `InputSource::Interactive`.
pub source: InputSource,
}
impl Default for PromptOptions {
fn default() -> Self {
Self {
expand_templates: true,
images: Vec::new(),
streaming_behavior: None,
source: InputSource::Interactive,
}
}
}
/// How `AgentSession::prompt` queues a message that arrives while the session
/// is already streaming.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
pub enum StreamingBehavior {
/// Route the text through `AgentSession::steer`.
Steer,
/// Route the text through `AgentSession::follow_up`.
FollowUp,
}
/// Origin of a piece of input handed to the session.
#[allow(dead_code)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum InputSource {
    /// The default source.
    #[default]
    Interactive,
    Extension,
    Rpc,
}
/// Per-role message counts produced by `AgentSession::session_stats`.
#[derive(Debug, Clone)]
pub struct SessionStats {
pub session_id: String,
/// Number of `Message::User` entries.
pub user_messages: usize,
/// Number of `Message::Assistant` entries.
pub assistant_messages: usize,
/// Number of `ContentBlock::ToolCall` blocks across all assistant messages.
pub tool_calls: usize,
/// Number of `Message::ToolResult` entries.
pub tool_results: usize,
/// Length of the agent's message list.
pub total_messages: usize,
}
/// Session-level wrapper around an [`Agent`]: settings, persistence, message
/// queues, compaction, listeners and the optional extension runner.
///
/// All mutable state lives behind `Arc`s so `clone_inner` can produce another
/// value that shares it.
pub struct AgentSession {
agent: Arc<Agent>,
settings: Arc<RwLock<Settings>>,
session_manager: Arc<RwLock<SessionManager>>,
// Listener callbacks; a `SessionListenerGuard` refers to its slot by index.
#[allow(clippy::type_complexity)]
listeners: Arc<RwLock<Vec<Box<dyn Fn(&SessionEvent) + Send + Sync>>>>,
// NOTE(review): `new` drops the matching receiver right away, so sends on
// this channel always fail and are ignored by `emit`.
event_tx: mpsc::UnboundedSender<SessionEvent>,
scoped_models: Arc<RwLock<Vec<ScopedModel>>>,
// Messages queued by `steer`, drained by the agent hook installed in `prompt`.
steering_messages: Arc<RwLock<VecDeque<String>>>,
// Messages queued by `follow_up`.
follow_up_messages: Arc<RwLock<VecDeque<String>>>,
compaction_config: Arc<RwLock<CompactionConfig>>,
// NOTE(review): only read/taken in this file, never assigned a handle.
compaction_abort: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
// Cleared by `reset`.
overflow_recovery_attempted: Arc<RwLock<bool>>,
// Captured at construction; the `session_id()` getter asks the session
// manager instead of reading this field.
session_id: Arc<RwLock<String>>,
cwd: String,
// Set while `prompt_streaming` is running.
streaming: Arc<AtomicBool>,
// Raised by `abort`, cleared by `reset_should_stop`.
should_stop: Arc<AtomicBool>,
extension_runner: Arc<RwLock<Option<ExtensionRunner>>>,
}
#[allow(dead_code)]
impl AgentSession {
/// Builds a session around `agent`.
///
/// The auto-compaction switch is seeded from `settings.auto_compaction`; all
/// other compaction parameters use `CompactionConfig::default()`. Queues and
/// listeners start empty and no extension runner is installed.
pub fn new(
    agent: Arc<Agent>,
    settings: Settings,
    session_manager: SessionManager,
    cwd: String,
) -> Self {
    // Read everything needed from the arguments before they are moved.
    let initial_id = session_manager.get_session_id();
    let compaction = CompactionConfig {
        enabled: settings.auto_compaction,
        ..CompactionConfig::default()
    };
    // The receiving half is dropped here; only the sender is stored.
    let (event_tx, _event_rx) = mpsc::unbounded_channel();

    Self {
        agent,
        cwd,
        event_tx,
        settings: Arc::new(RwLock::new(settings)),
        session_manager: Arc::new(RwLock::new(session_manager)),
        session_id: Arc::new(RwLock::new(initial_id)),
        compaction_config: Arc::new(RwLock::new(compaction)),
        compaction_abort: Arc::new(Mutex::new(None)),
        overflow_recovery_attempted: Arc::new(RwLock::new(false)),
        listeners: Arc::new(RwLock::new(Vec::new())),
        scoped_models: Arc::new(RwLock::new(Vec::new())),
        steering_messages: Arc::new(RwLock::new(VecDeque::new())),
        follow_up_messages: Arc::new(RwLock::new(VecDeque::new())),
        streaming: Arc::new(AtomicBool::new(false)),
        should_stop: Arc::new(AtomicBool::new(false)),
        extension_runner: Arc::new(RwLock::new(None)),
    }
}
/// Returns the model id reported by the agent.
pub fn model_id(&self) -> String {
self.agent.model_id()
}
/// Returns the agent's current state as reported by `Agent::state`.
#[allow(dead_code)]
pub fn state(&self) -> AgentState {
self.agent.state()
}
/// Returns the thinking level stored in the session settings.
pub fn thinking_level(&self) -> ThinkingLevel {
self.settings.read().thinking_level
}
/// Reports whether the streaming flag is currently set.
#[allow(dead_code)]
pub fn is_streaming(&self) -> bool {
self.streaming.load(Ordering::SeqCst)
}
/// Returns the message list taken from the agent's current state.
#[allow(dead_code)]
pub fn messages(&self) -> Vec<Message> {
    let snapshot = self.agent.state();
    snapshot.messages
}
/// Returns the session id as currently reported by the session manager.
pub fn session_id(&self) -> String {
self.session_manager.read().get_session_id()
}
/// Reports whether a compaction task handle is registered.
///
/// If the handle mutex cannot be acquired without waiting, the session is
/// reported as compacting.
#[allow(dead_code)]
pub fn is_compacting(&self) -> bool {
    self.compaction_abort
        .try_lock()
        .map_or(true, |handle_slot| handle_slot.is_some())
}
/// Always returns `true`; there is no setting backing this yet.
#[allow(dead_code)]
pub fn auto_retry_enabled(&self) -> bool {
true
}
/// Tallies the agent's current message list by role and counts the tool-call
/// blocks inside assistant messages.
pub fn session_stats(&self) -> SessionStats {
    let snapshot = self.agent.state();
    let (mut users, mut assistants, mut calls, mut results) = (0usize, 0usize, 0usize, 0usize);

    for message in &snapshot.messages {
        match message {
            Message::User(_) => users += 1,
            Message::ToolResult(_) => results += 1,
            Message::Assistant(reply) => {
                assistants += 1;
                calls += reply
                    .content
                    .iter()
                    .filter(|block| matches!(block, oxi_ai::ContentBlock::ToolCall(_)))
                    .count();
            }
        }
    }

    SessionStats {
        session_id: self.session_id(),
        user_messages: users,
        assistant_messages: assistants,
        tool_calls: calls,
        tool_results: results,
        total_messages: snapshot.messages.len(),
    }
}
/// Total number of queued steering and follow-up messages.
#[allow(dead_code)]
pub fn pending_message_count(&self) -> usize {
// Both read guards live until the end of this expression.
self.steering_messages.read().len() + self.follow_up_messages.read().len()
}
/// Returns a copy of the queued steering messages, oldest first.
pub fn steering_messages(&self) -> Vec<String> {
    let queued = self.steering_messages.read().clone();
    Vec::from(queued)
}
/// Returns a copy of the queued follow-up messages, oldest first.
pub fn follow_up_messages(&self) -> Vec<String> {
    let queued = self.follow_up_messages.read().clone();
    Vec::from(queued)
}
pub fn steering_queue(&self) -> Arc<RwLock<std::collections::VecDeque<String>>> {
self.steering_messages.clone()
}
pub fn follow_up_queue(&self) -> Arc<RwLock<std::collections::VecDeque<String>>> {
self.follow_up_messages.clone()
}
/// Working directory this session was created with.
#[allow(dead_code)]
pub fn cwd(&self) -> &str {
    self.cwd.as_str()
}
/// Returns a copy of the scoped-model list.
pub fn scoped_models(&self) -> Vec<ScopedModel> {
    let models = self.scoped_models.read();
    models.to_vec()
}
/// Reports whether auto-compaction is enabled in the compaction config.
pub fn auto_compaction_enabled(&self) -> bool {
self.compaction_config.read().enabled
}
/// Registers `listener` for every event passed to `emit`.
///
/// The returned guard remembers the listener's slot index; dropping the guard
/// replaces that slot with a no-op, so callers must keep it alive for as long
/// as they want to receive events.
pub fn subscribe(
    &self,
    listener: Box<dyn Fn(&SessionEvent) + Send + Sync>,
) -> SessionListenerGuard {
    let mut registered = self.listeners.write();
    // The new listener lands at the current end of the vector.
    let key = registered.len();
    registered.push(listener);
    drop(registered);

    SessionListenerGuard {
        key,
        listeners: Arc::clone(&self.listeners),
    }
}
/// Returns a receiver that gets a clone of every event emitted from now on.
///
/// The forwarding listener is pushed onto the listener list directly instead
/// of going through `subscribe`: the guard returned by `subscribe` would be
/// dropped at the end of the statement, and its `Drop` impl replaces the
/// listener with a no-op, leaving the receiver permanently silent.
///
/// The listener stays registered for the lifetime of the listener list; once
/// the receiver is dropped its sends fail and are ignored.
#[allow(dead_code)]
pub fn subscribe_channel(&self) -> mpsc::UnboundedReceiver<SessionEvent> {
    let (tx, rx) = mpsc::unbounded_channel();
    self.listeners
        .write()
        .push(Box::new(move |event: &SessionEvent| {
            let _ = tx.send(event.clone());
        }));
    rx
}
/// Delivers `event` to every registered listener, in registration order, then
/// offers it to the internal channel (send failures are ignored).
///
/// The listener list stays read-locked while the callbacks run.
fn emit(&self, event: SessionEvent) {
    let registered = self.listeners.read();
    registered.iter().for_each(|notify| notify(&event));
    let _ = self.event_tx.send(event);
}
/// Emits a `QueueUpdate` carrying copies of both pending queues.
fn emit_queue_update(&self) {
    let steering = self.steering_messages();
    let follow_up = self.follow_up_messages();
    self.emit(SessionEvent::QueueUpdate { steering, follow_up });
}
#[allow(dead_code)]
pub async fn prompt(&self, text: String, options: PromptOptions) -> Result<()> {
if self.is_streaming() {
return match options.streaming_behavior {
Some(StreamingBehavior::Steer) => self.steer(text).await,
Some(StreamingBehavior::FollowUp) => self.follow_up(text).await,
None => {
anyhow::bail!(
"Agent is already processing. Specify streaming_behavior to queue the message."
);
}
};
}
let model_id = self.model_id();
if model_id.is_empty() {
anyhow::bail!("No model selected");
}
let steering_q = self.steering_messages.clone();
let follow_up_q = self.follow_up_messages.clone();
let hooks = oxi_agent::AgentHooks {
get_steering_messages: Some(Box::new(move || {
steering_q.write().drain(..).collect::<Vec<String>>()
})),
get_follow_up_messages: Some(Box::new(move || {
follow_up_q.write().drain(..).collect::<Vec<String>>()
})),
tool_execution: oxi_agent::ToolExecutionMode::Sequential,
..Default::default()
};
self.agent.set_hooks(hooks);
let (_response, events) = self.agent.run(text.clone()).await?;
self.process_events(events).await?;
Ok(())
}
#[allow(dead_code)]
pub fn prompt_streaming(&self, text: String) -> mpsc::UnboundedReceiver<AgentEvent> {
let (tx, rx) = mpsc::unbounded_channel();
self.streaming.store(true, Ordering::SeqCst);
let agent = Arc::clone(&self.agent);
let streaming = Arc::clone(&self.streaming);
tokio::task::spawn_blocking(move || {
let rt = tokio::runtime::Handle::current();
rt.block_on(async {
let local = tokio::task::LocalSet::new();
local
.run_until(async move {
let (agent_tx, agent_rx) = std::sync::mpsc::channel::<AgentEvent>();
let agent_for_task = Arc::clone(&agent);
let agent_handle = tokio::task::spawn_local(async move {
agent_for_task.run_with_channel(text, agent_tx).await
});
while let Ok(event) = agent_rx.recv() {
let _ = tx.send(event);
}
match agent_handle.await {
Ok(Ok(_response)) => {
}
Ok(Err(e)) => {
let _ = tx.send(AgentEvent::Error {
message: e.to_string(),
session_id: None,
});
}
Err(join_err) => {
let _ = tx.send(AgentEvent::Error {
message: format!("Agent task failed: {}", join_err),
session_id: None,
});
}
}
streaming.store(false, Ordering::SeqCst);
})
.await;
});
});
rx
}
/// Queues `text` as a steering message and notifies listeners of the new
/// queue contents.
///
/// NOTE(review): the last statement calls `add_user_message` on the value
/// returned by `self.agent.state()`. Elsewhere in this file `state()` is used
/// as an owned value (`self.agent.state().messages`), so this may only touch a
/// temporary copy — or, if it does reach the agent, the text may be delivered
/// twice alongside the steering hook. Confirm against `oxi_agent`.
pub async fn steer(&self, text: String) -> Result<()> {
{
let mut queue = self.steering_messages.write();
queue.push_back(text.clone());
}
// The queue lock is released before listeners are called.
self.emit_queue_update();
self.agent.state().add_user_message(text);
Ok(())
}
/// Queues `text` as a follow-up message and notifies listeners of the new
/// queue contents. Always returns `Ok(())`.
#[allow(dead_code)]
pub async fn follow_up(&self, text: String) -> Result<()> {
    // The write guard is a temporary and is released at the end of this line.
    self.follow_up_messages.write().push_back(text);
    self.emit_queue_update();
    Ok(())
}
/// Raises the `should_stop` flag and discards all queued messages.
pub async fn abort(&self) {
tracing::debug!("AgentSession::abort() — setting should_stop flag");
self.should_stop.store(true, Ordering::SeqCst);
self.clear_queue();
}
/// Returns a shared handle to the stop flag raised by `abort`.
pub fn should_stop_flag(&self) -> Arc<AtomicBool> {
    let flag = &self.should_stop;
    Arc::clone(flag)
}
/// Clears the stop flag raised by `abort`.
pub fn reset_should_stop(&self) {
self.should_stop.store(false, Ordering::SeqCst);
}
/// Empties both queues and returns what they held as
/// `(steering, follow_up)`, oldest first. A queue update is emitted even when
/// both queues were already empty.
pub fn clear_queue(&self) -> (Vec<String>, Vec<String>) {
    let steering = Vec::from(std::mem::take(&mut *self.steering_messages.write()));
    let follow_up = Vec::from(std::mem::take(&mut *self.follow_up_messages.write()));
    self.emit_queue_update();
    (steering, follow_up)
}
/// Switches the agent to `model_id` and records the choice.
///
/// A `provider/model` id is split at the first `/`: the change is appended to
/// the session log and both default provider and default model are updated.
/// An id without `/` only updates the default model and is not logged.
///
/// # Errors
/// Returns the agent's error when the switch itself fails; nothing is
/// recorded in that case.
pub fn set_model(&self, model_id: &str) -> Result<()> {
    self.agent.switch_model(model_id)?;

    let qualified = model_id.split_once('/');

    if let Some((provider, model)) = qualified {
        self.session_manager
            .write()
            .append_model_change(provider, model);
    }

    let mut settings = self.settings.write();
    match qualified {
        Some((provider, model)) => {
            settings.default_provider = Some(provider.to_string());
            settings.default_model = Some(model.to_string());
        }
        None => settings.default_model = Some(model_id.to_string()),
    }
    Ok(())
}
/// Replaces the scoped-model list with `models`.
pub fn set_scoped_models(&self, models: Vec<ScopedModel>) {
*self.scoped_models.write() = models;
}
/// Changes the thinking level.
///
/// Does nothing when `level` is already active. Otherwise the setting is
/// updated, the lower-cased debug name of the level is appended to the
/// session log, and `ThinkingLevelChanged` is emitted after both locks have
/// been released.
pub fn set_thinking_level(&self, level: ThinkingLevel) {
    if self.thinking_level() == level {
        return;
    }

    self.settings.write().thinking_level = level;

    let label = format!("{:?}", level).to_lowercase();
    self.session_manager
        .write()
        .append_thinking_level_change(&label);

    self.emit(SessionEvent::ThinkingLevelChanged { level });
}
/// Advances to the next thinking level, wrapping from `XHigh` back to `Off`,
/// and returns the level that was selected (always `Some`).
///
/// A current level missing from the table is treated as the first entry.
pub fn cycle_thinking_level(&self) -> Option<ThinkingLevel> {
    let order = [
        ThinkingLevel::Off,
        ThinkingLevel::Minimal,
        ThinkingLevel::Low,
        ThinkingLevel::Medium,
        ThinkingLevel::High,
        ThinkingLevel::XHigh,
    ];
    let active = self.thinking_level();
    let position = order
        .iter()
        .position(|candidate| *candidate == active)
        .unwrap_or(0);
    let upcoming = order[(position + 1) % order.len()];
    self.set_thinking_level(upcoming);
    Some(upcoming)
}
/// Runs a manual compaction, bracketed by `CompactionStart` and
/// `CompactionEnd` events. On failure the end event carries the error text and
/// the error is returned to the caller.
pub async fn compact(&self, custom_instructions: Option<String>) -> Result<CompactionResult> {
    self.emit(SessionEvent::CompactionStart {
        reason: CompactionReason::Manual,
    });

    let outcome = self.run_compaction(custom_instructions).await;
    let error_message = outcome.as_ref().err().map(|e| e.to_string());

    self.emit(SessionEvent::CompactionEnd {
        reason: CompactionReason::Manual,
        error_message,
    });
    outcome
}
/// Compacts the conversation when its estimated size reaches the configured
/// share of the context window.
///
/// The size is estimated from the JSON serialization of the message list and
/// compared against a fixed window of 128 000 tokens. Failures are logged and
/// reported through the `CompactionEnd` event; nothing is returned.
#[allow(dead_code)]
async fn check_auto_compaction(&self) {
    let config = self.compaction_config.read().clone();
    if !config.enabled {
        return;
    }

    let state = self.agent.state();
    if state.messages.is_empty() {
        return;
    }

    let serialized = serde_json::to_string(&state.messages).unwrap_or_default();
    let estimated_tokens = oxi_ai::estimate_tokens(&serialized);
    let context_window = 128_000;
    let ratio = estimated_tokens as f32 / context_window as f32;

    if ratio >= config.threshold {
        tracing::info!(
            "Auto-compaction triggered: {} tokens ({:.0}%) >= {:.0}% of {}",
            estimated_tokens,
            ratio * 100.0,
            config.threshold * 100.0,
            context_window,
        );
        self.emit(SessionEvent::CompactionStart {
            reason: CompactionReason::Threshold,
        });

        let error_message = match self.run_compaction(None).await {
            Ok(_) => None,
            Err(e) => {
                tracing::warn!("Auto-compaction failed: {}", e);
                Some(format!("Auto-compaction failed: {}", e))
            }
        };
        self.emit(SessionEvent::CompactionEnd {
            reason: CompactionReason::Threshold,
            error_message,
        });
    }
}
async fn run_compaction(
&self,
_custom_instructions: Option<String>,
) -> Result<CompactionResult> {
let state = self.agent.state();
let messages = state.messages.clone();
if messages.len() < 3 {
anyhow::bail!("Nothing to compact (session too small)");
}
let compacted = self
.agent
.compaction_manager()
.compact_if_needed(&messages, None, state.estimate_tokens(), state.iteration)
.await
.context("Compaction failed")?;
match compacted {
Some(ctx) => {
let tokens_before = state.estimate_tokens();
self.agent
.state()
.replace_messages(ctx.kept_messages.clone());
self.persist_session();
Ok(CompactionResult { tokens_before })
}
None => {
anyhow::bail!("Nothing to compact");
}
}
}
/// Aborts the registered compaction task, if any, and clears the slot.
#[allow(dead_code)]
pub async fn abort_compaction(&self) {
    let mut slot = self.compaction_abort.lock().await;
    match slot.take() {
        Some(task) => task.abort(),
        None => {}
    }
}
/// Turns auto-compaction on or off in both the compaction config and the
/// session settings.
#[allow(dead_code)]
pub fn set_auto_compaction_enabled(&self, enabled: bool) {
self.compaction_config.write().enabled = enabled;
self.settings.write().auto_compaction = enabled;
}
/// Appends every not-yet-persisted agent message to the session manager.
///
/// The session manager's persisted count marks how many leading messages of
/// the agent history have already been written; only the tail beyond it is
/// appended, after which the count is set to the history length.
///
/// If the history is non-empty but shorter than the persisted count (for
/// example after compaction replaced the message list), the count is lowered
/// to the current length. Without this the stale, larger count would suppress
/// persistence of every later message until the history grew past it again.
/// An empty history leaves the count untouched.
///
/// User and tool-result content is flattened to the concatenation of its text
/// blocks; assistant blocks are mapped one-to-one, with unknown blocks stored
/// as their string form.
fn persist_session(&self) {
    use oxi_store::session::{AssistantContentBlock, ContentValue, Usage};

    let state = self.agent.state();
    let messages = &state.messages;
    let total = messages.len();
    if total == 0 {
        return;
    }

    let mut sm = self.session_manager.write();
    let persisted = sm.persisted_count();
    if persisted > total {
        // History shrank: everything currently in memory is treated as
        // already stored so that future messages are picked up again.
        sm.set_persisted_count(total);
        return;
    }
    if persisted == total {
        return;
    }

    for msg in &messages[persisted..] {
        match msg {
            Message::User(u) => {
                let content = match &u.content {
                    oxi_ai::MessageContent::Text(t) => t.clone(),
                    oxi_ai::MessageContent::Blocks(blocks) => blocks
                        .iter()
                        .filter_map(|b| b.as_text())
                        .collect::<Vec<_>>()
                        .join(""),
                };
                sm.append_message(AgentMessage::User {
                    content: ContentValue::String(content),
                });
            }
            Message::Assistant(a) => {
                let content_blocks: Vec<AssistantContentBlock> = a
                    .content
                    .iter()
                    .map(|b| match b {
                        oxi_ai::ContentBlock::Text(t) => AssistantContentBlock::Text {
                            text: t.text.clone(),
                        },
                        oxi_ai::ContentBlock::Thinking(t) => AssistantContentBlock::Thinking {
                            thinking: t.thinking.clone(),
                        },
                        oxi_ai::ContentBlock::ToolCall(tc) => AssistantContentBlock::ToolCall {
                            id: tc.id.clone(),
                            name: tc.name.clone(),
                            arguments: tc.arguments.clone(),
                        },
                        oxi_ai::ContentBlock::Image(img) => AssistantContentBlock::ImageResult {
                            data: img.data.clone(),
                            media_type: img.mime_type.clone(),
                        },
                        oxi_ai::ContentBlock::Unknown(v) => AssistantContentBlock::Text {
                            text: v.to_string(),
                        },
                    })
                    .collect();
                sm.append_message(AgentMessage::Assistant {
                    content: content_blocks,
                    provider: Some(a.provider.clone()),
                    model_id: Some(a.model.clone()),
                    usage: Some(Usage {
                        input: Some(a.usage.input as i64),
                        output: Some(a.usage.output as i64),
                        cache_read: Some(a.usage.cache_read as i64),
                        cache_write: Some(a.usage.cache_write as i64),
                        total_tokens: Some(a.usage.total_tokens as i64),
                    }),
                    stop_reason: Some(format!("{:?}", a.stop_reason)),
                });
            }
            Message::ToolResult(t) => {
                let content = t
                    .content
                    .iter()
                    .filter_map(|b| b.as_text())
                    .collect::<Vec<_>>()
                    .join("");
                sm.append_message(AgentMessage::ToolResult {
                    content: ContentValue::String(content),
                    tool_call_id: t.tool_call_id.clone(),
                });
            }
        }
    }
    sm.set_persisted_count(total);
}
async fn process_events(&self, events: Vec<AgentEvent>) -> Result<()> {
for event in &events {
self.emit(SessionEvent::Agent(event.clone()));
let guard = self.extension_runner.read();
if let Some(runner) = guard.as_ref() {
runner.registry().emit_event(event);
match event {
AgentEvent::ToolCall { tool_call } => {
runner.emit_tool_call(&tool_call.name, &tool_call.arguments);
}
AgentEvent::ToolExecutionStart {
tool_name, args, ..
} => {
runner.emit_tool_call(tool_name, args);
}
AgentEvent::ToolExecutionEnd {
tool_name, result, ..
} => {
let tool_result = oxi_agent::AgentToolResult::success(&result.content);
runner.emit_tool_result_event(tool_name, &tool_result);
}
AgentEvent::Error { message, .. } => {
let err = anyhow::anyhow!("{}", message);
runner.registry().emit_error(&err);
}
_ => {}
}
}
}
let has_complete = events
.iter()
.any(|e| matches!(e, AgentEvent::AgentEnd { .. } | AgentEvent::Complete { .. }));
if has_complete {
self.check_auto_compaction().await;
let follow_ups: Vec<String> = self.follow_up_messages.write().drain(..).collect();
if !follow_ups.is_empty() {
self.emit_queue_update();
for msg in follow_ups {
let _ = self.agent.run(msg).await;
}
}
}
self.persist_session();
Ok(())
}
/// Records `name` in the session log and emits `SessionInfoChanged`.
///
/// The session-manager write lock is released before listeners are notified:
/// the lock is not reentrant, so a listener calling e.g. `session_id()` from
/// its callback would otherwise deadlock.
pub fn set_session_name(&self, name: String) {
    {
        let mut sm = self.session_manager.write();
        sm.append_session_info(&name);
    }
    self.emit(SessionEvent::SessionInfoChanged);
}
/// Resets the agent, clears the overflow-recovery flag and empties both
/// message queues.
pub fn reset(&self) {
self.agent.reset();
*self.overflow_recovery_attempted.write() = false;
self.clear_queue();
}
/// Returns another shared handle to the underlying agent.
pub fn agent_ref(&self) -> Arc<Agent> {
    let agent = &self.agent;
    Arc::clone(agent)
}
/// Public entry point for `persist_session`.
pub fn persist(&self) {
self.persist_session();
}
/// Wraps a state-sharing copy of this session in a cloneable handle.
pub fn clone_handle(&self) -> AgentSessionHandle {
    let inner = Arc::new(self.clone_inner());
    AgentSessionHandle { inner }
}
/// Builds another `AgentSession` that shares all state with this one.
///
/// Every `Arc` field is cloned as a handle, so both values observe the same
/// settings, queues, listeners and flags; only `cwd` is copied by value.
fn clone_inner(&self) -> Self {
Self {
agent: Arc::clone(&self.agent),
settings: Arc::clone(&self.settings),
session_manager: Arc::clone(&self.session_manager),
listeners: Arc::clone(&self.listeners),
event_tx: self.event_tx.clone(),
scoped_models: Arc::clone(&self.scoped_models),
steering_messages: Arc::clone(&self.steering_messages),
follow_up_messages: Arc::clone(&self.follow_up_messages),
compaction_config: Arc::clone(&self.compaction_config),
compaction_abort: Arc::clone(&self.compaction_abort),
overflow_recovery_attempted: Arc::clone(&self.overflow_recovery_attempted),
session_id: Arc::clone(&self.session_id),
cwd: self.cwd.clone(),
streaming: Arc::clone(&self.streaming),
should_stop: Arc::clone(&self.should_stop),
extension_runner: Arc::clone(&self.extension_runner),
}
}
/// Installs `runner` as the session's extension runner.
///
/// Runs in three separately locked phases: the previous runner (if any) gets
/// shutdown (`Reload`), session-end and unload notifications; the slot is
/// overwritten; the new runner then gets load and session-start
/// notifications.
pub fn set_extension_runner(&self, runner: ExtensionRunner) {
// Phase 1: say goodbye to the runner being replaced.
{
let guard = self.extension_runner.read();
if let Some(existing) = guard.as_ref() {
let session_id = self.session_id();
let shutdown_event = SessionShutdownEvent {
reason: SessionShutdownReason::Reload,
target_session_file: None,
};
existing.emit_session_shutdown_event(&shutdown_event);
existing.registry().emit_session_end(&session_id);
existing.registry().emit_unload();
}
}
// Phase 2: swap in the new runner under a short write lock.
{
let mut guard = self.extension_runner.write();
*guard = Some(runner);
}
// Phase 3: announce load and session start to the installed runner.
{
let guard = self.extension_runner.read();
if let Some(runner) = guard.as_ref() {
let ctx = self.build_extension_context();
runner.registry().emit_load(&ctx);
let session_id = self.session_id();
runner.registry().emit_session_start(&session_id);
}
}
tracing::debug!("ExtensionRunner installed into AgentSession");
}
/// Returns a read guard over the extension-runner slot. The slot stays
/// read-locked for as long as the caller holds the guard.
pub fn extension_runner(&self) -> parking_lot::RwLockReadGuard<'_, Option<ExtensionRunner>> {
self.extension_runner.read()
}
/// Removes and returns the installed extension runner, if any.
///
/// Before removal the runner gets shutdown (`Quit`), session-end and unload
/// notifications, sent under a read lock that is released before the slot is
/// emptied.
pub fn take_extension_runner(&self) -> Option<ExtensionRunner> {
    if let Some(runner) = self.extension_runner.read().as_ref() {
        let session_id = self.session_id();
        let farewell = SessionShutdownEvent {
            reason: SessionShutdownReason::Quit,
            target_session_file: None,
        };
        runner.emit_session_shutdown_event(&farewell);
        let registry = runner.registry();
        registry.emit_session_end(&session_id);
        runner.registry().emit_unload();
    }
    let mut slot = self.extension_runner.write();
    slot.take()
}
/// Builds an extension context rooted at the session's working directory and
/// sharing the session's settings handle.
pub fn build_extension_context(&self) -> ExtensionContext {
    let working_dir = PathBuf::from(&self.cwd);
    let shared_settings = Arc::clone(&self.settings);
    ExtensionContextBuilder::new(working_dir)
        .settings(shared_settings)
        .build()
}
/// Emits `event` to session listeners and forwards it to the extension runner
/// when one is installed.
///
/// Tool-call and tool-execution-start events trigger the tool-call hook;
/// tool-execution-end triggers the tool-result hook.
///
/// NOTE(review): unlike `process_events`, `AgentEvent::Error` is not passed to
/// the registry's error hook here — confirm whether that is intentional.
pub fn forward_event_to_extensions(&self, event: &AgentEvent) {
    self.emit(SessionEvent::Agent(event.clone()));

    let slot = self.extension_runner.read();
    let runner = match slot.as_ref() {
        Some(runner) => runner,
        None => return,
    };

    runner.registry().emit_event(event);
    match event {
        AgentEvent::ToolCall { tool_call } => {
            runner.emit_tool_call(&tool_call.name, &tool_call.arguments);
        }
        AgentEvent::ToolExecutionStart {
            tool_name, args, ..
        } => {
            runner.emit_tool_call(tool_name, args);
        }
        AgentEvent::ToolExecutionEnd {
            tool_name, result, ..
        } => {
            let outcome = oxi_agent::AgentToolResult::success(&result.content);
            runner.emit_tool_result_event(tool_name, &outcome);
        }
        _ => {}
    }
}
/// Reports whether the installed runner has handlers for `event_type`;
/// `false` when no runner is installed.
pub fn has_extension_handlers(&self, event_type: &str) -> bool {
    let slot = self.extension_runner.read();
    slot.as_ref()
        .map_or(false, |runner| runner.has_handlers(event_type))
}
/// Returns the tools provided by the installed runner, or an empty list when
/// no runner is installed.
pub fn extension_tools(&self) -> Vec<Arc<dyn oxi_agent::AgentTool>> {
    let slot = self.extension_runner.read();
    match slot.as_ref() {
        None => Vec::new(),
        Some(runner) => runner.all_tools(),
    }
}
/// Returns the commands provided by the installed runner, or an empty list
/// when no runner is installed.
pub fn extension_commands(&self) -> Vec<crate::extensions::Command> {
    let slot = self.extension_runner.read();
    match slot.as_ref() {
        None => Vec::new(),
        Some(runner) => runner.all_commands(),
    }
}
/// Runs the installed runner's tool-call hook for `tool_name` and returns its
/// verdict; without a runner the default result is returned.
pub fn emit_before_tool_call(
    &self,
    tool_name: &str,
    params: &serde_json::Value,
) -> crate::extensions::ToolCallEmitResult {
    let slot = self.extension_runner.read();
    slot.as_ref().map_or_else(
        crate::extensions::ToolCallEmitResult::default,
        |runner| runner.emit_tool_call(tool_name, params),
    )
}
/// Runs the installed runner's tool-result hook for `tool_name` and returns
/// its verdict; without a runner the default result is returned.
pub fn emit_after_tool_result(
    &self,
    tool_name: &str,
    result: &oxi_agent::AgentToolResult,
) -> crate::extensions::ToolResultEmitResult {
    let slot = self.extension_runner.read();
    slot.as_ref().map_or_else(
        crate::extensions::ToolResultEmitResult::default,
        |runner| runner.emit_tool_result_event(tool_name, result),
    )
}
/// Passes user input through the installed runner's input hook.
///
/// Returns the runner's decision, or `Continue` when no runner is installed.
pub fn process_input_through_extensions(
    &self,
    text: &str,
    source: InputSource,
) -> ExtInputEventResult {
    let slot = self.extension_runner.read();
    let runner = match slot.as_ref() {
        Some(runner) => runner,
        None => return ExtInputEventResult::Continue,
    };

    // Translate the session-level source into the extension API's enum.
    let mut input = ExtInputEvent {
        text: text.to_string(),
        source: match source {
            InputSource::Interactive => crate::extensions::InputSource::Interactive,
            InputSource::Extension => crate::extensions::InputSource::Extension,
            InputSource::Rpc => crate::extensions::InputSource::Rpc,
        },
    };
    runner.emit_input_event(&mut input)
}
/// Tells the installed runner's registry that `msg` was sent; no-op without a
/// runner.
pub fn notify_extensions_message_sent(&self, msg: &str) {
    if let Some(runner) = self.extension_runner.read().as_ref() {
        runner.registry().emit_message_sent(msg);
    }
}
/// Tells the installed runner's registry that `msg` was received; no-op
/// without a runner.
pub fn notify_extensions_message_received(&self, msg: &str) {
    if let Some(runner) = self.extension_runner.read().as_ref() {
        runner.registry().emit_message_received(msg);
    }
}
/// Hands a copy of the current settings to the installed runner's registry;
/// no-op without a runner.
pub fn notify_extensions_settings_changed(&self) {
    let slot = self.extension_runner.read();
    let runner = match slot.as_ref() {
        Some(runner) => runner,
        None => return,
    };
    let snapshot = self.settings.read().clone();
    runner.registry().emit_settings_changed(&snapshot);
}
}
/// Subscription handle returned by `AgentSession::subscribe`.
///
/// Dropping it silences the listener stored at index `key`.
pub struct SessionListenerGuard {
#[allow(clippy::type_complexity)]
listeners: Arc<RwLock<Vec<Box<dyn Fn(&SessionEvent) + Send + Sync>>>>,
// Index of the listener's slot in `listeners`.
key: usize,
}
impl Drop for SessionListenerGuard {
    /// Replaces the guarded listener with a no-op. The slot itself is kept so
    /// the indices held by other guards stay valid.
    fn drop(&mut self) {
        let mut registered = self.listeners.write();
        if let Some(slot) = registered.get_mut(self.key) {
            *slot = Box::new(|_| {});
        }
    }
}
/// Cheaply cloneable handle to an [`AgentSession`], produced by
/// `AgentSession::clone_handle`.
#[derive(Clone)]
pub struct AgentSessionHandle {
inner: Arc<AgentSession>,
}
impl std::ops::Deref for AgentSessionHandle {
    type Target = AgentSession;

    /// Exposes the wrapped session so its methods can be called on the handle.
    fn deref(&self) -> &Self::Target {
        &*self.inner
    }
}
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use futures::Stream;
use oxi_agent::AgentConfig;
use oxi_ai::{Model, Provider, ProviderError, ProviderEvent};
use std::pin::Pin;
use std::task::{Context as TaskContext, Poll};
// Provider stand-in used to construct an `Agent` in tests.
struct MockProvider;
// Stream that ends immediately without yielding any provider event.
struct EmptyStream;
impl Stream for EmptyStream {
type Item = ProviderEvent;
// Reports end-of-stream on the first poll.
fn poll_next(self: Pin<&mut Self>, _cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
Poll::Ready(None)
}
}
// Provider stub: `stream` always succeeds with an `EmptyStream`, and the
// provider reports its name as "mock".
#[async_trait]
impl Provider for MockProvider {
async fn stream(
&self,
_model: &Model,
_context: &oxi_ai::Context,
_options: Option<oxi_ai::StreamOptions>,
) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
Ok(Box::pin(EmptyStream))
}
fn name(&self) -> &str {
"mock"
}
}
// Builds a session backed by the mock provider, default settings and an
// in-memory session manager rooted at /tmp/test.
fn make_session() -> AgentSession {
    let tools = Arc::new(oxi_agent::ToolRegistry::new());
    let agent = Agent::new(
        Arc::new(MockProvider),
        AgentConfig::new("anthropic/claude-sonnet-4-20250514"),
        tools,
    );
    AgentSession::new(
        Arc::new(agent),
        Settings::default(),
        SessionManager::in_memory("/tmp/test"),
        "/tmp/test".to_string(),
    )
}
// A fresh session has an id, the given cwd, is idle and has no messages.
#[test]
fn test_session_creation_basic_fields() {
    let fresh = make_session();
    let id = fresh.session_id();
    assert!(!id.is_empty());
    assert_eq!(fresh.cwd(), "/tmp/test");
    assert!(!fresh.is_streaming());
    assert!(fresh.messages().is_empty());
}
// The session reports the model id the agent was configured with.
#[test]
fn test_session_creation_model_id() {
    let model = make_session().model_id();
    assert_eq!(model, "anthropic/claude-sonnet-4-20250514");
}
// Default settings start at the Medium thinking level.
#[test]
fn test_session_creation_default_thinking_level() {
    let level = make_session().thinking_level();
    assert_eq!(level, ThinkingLevel::Medium);
}
// Both message queues start out empty.
#[test]
fn test_session_creation_empty_queues() {
    let fresh = make_session();
    assert_eq!(fresh.pending_message_count(), 0);
    let steering = fresh.steering_messages();
    assert!(steering.is_empty());
    let follow_up = fresh.follow_up_messages();
    assert!(follow_up.is_empty());
}
// No scoped models are configured on a fresh session.
#[test]
fn test_scoped_models_empty_by_default() {
    let scoped = make_session().scoped_models();
    assert!(scoped.is_empty());
}
// Scoped models are stored and returned in the order they were set.
#[test]
fn test_set_scoped_models() {
    let session = make_session();
    let models: Vec<ScopedModel> = [
        ("anthropic", "claude-sonnet-4-20250514"),
        ("openai", "gpt-4o"),
        ("google", "gemini-2.0-flash"),
    ]
    .iter()
    .map(|&(provider, model_id)| ScopedModel {
        provider: provider.to_string(),
        model_id: model_id.to_string(),
    })
    .collect();

    session.set_scoped_models(models);

    let stored = session.scoped_models();
    assert_eq!(stored.len(), 3);
    assert_eq!(stored[0].provider, "anthropic");
    assert_eq!(stored[2].model_id, "gemini-2.0-flash");
}
// ScopedModel exposes the provider and model id it was built with.
#[test]
fn test_scoped_model_fields() {
    let provider = "anthropic".to_string();
    let model_id = "claude-sonnet-4-20250514".to_string();
    let scoped = ScopedModel { provider, model_id };
    assert_eq!(scoped.provider, "anthropic");
    assert_eq!(scoped.model_id, "claude-sonnet-4-20250514");
}
// Each explicitly set level becomes the active level.
#[test]
fn test_set_thinking_level() {
    let session = make_session();
    assert_eq!(session.thinking_level(), ThinkingLevel::Medium);
    for target in [
        ThinkingLevel::High,
        ThinkingLevel::Off,
        ThinkingLevel::Minimal,
    ] {
        session.set_thinking_level(target);
        assert_eq!(session.thinking_level(), target);
    }
}
// Setting the already-active level leaves it unchanged.
#[test]
fn test_set_thinking_level_noop_when_same() {
    let session = make_session();
    let unchanged = ThinkingLevel::Medium;
    session.set_thinking_level(unchanged);
    assert_eq!(session.thinking_level(), unchanged);
}
// Cycling walks Medium -> High -> XHigh -> Off -> ... and wraps around.
#[test]
fn test_cycle_thinking_level() {
    let session = make_session();
    assert_eq!(session.thinking_level(), ThinkingLevel::Medium);

    // First two steps: check both the return value and the stored level.
    for expected in [ThinkingLevel::High, ThinkingLevel::XHigh] {
        let next = session.cycle_thinking_level();
        assert_eq!(next, Some(expected));
        assert_eq!(session.thinking_level(), expected);
    }

    // Remaining steps: the return value alone shows the wrap-around.
    for expected in [
        ThinkingLevel::Off,
        ThinkingLevel::Minimal,
        ThinkingLevel::Low,
        ThinkingLevel::Medium,
        ThinkingLevel::High,
        ThinkingLevel::XHigh,
    ] {
        let next = session.cycle_thinking_level();
        assert_eq!(next, Some(expected));
    }
}
// Stepping through the level table once, modulo its length, returns to the
// starting index.
#[test]
fn test_thinking_level_full_cycle() {
    let levels = [
        ThinkingLevel::Off,
        ThinkingLevel::Minimal,
        ThinkingLevel::Low,
        ThinkingLevel::Medium,
        ThinkingLevel::High,
        ThinkingLevel::XHigh,
    ];
    let count = levels.len();
    let landed = (0..count).fold(0, |index, _| (index + 1) % count);
    assert_eq!(landed, 0);
}
// A steered message shows up in the steering queue.
#[tokio::test]
async fn test_steer_message() {
    let session = make_session();
    let outcome = session.steer("direction 1".to_string()).await;
    outcome.unwrap();
    assert_eq!(session.steering_messages(), vec!["direction 1"]);
    assert_eq!(session.pending_message_count(), 1);
}
// A follow-up message shows up in the follow-up queue.
#[tokio::test]
async fn test_follow_up_message() {
    let session = make_session();
    let outcome = session.follow_up("next task".to_string()).await;
    outcome.unwrap();
    assert_eq!(session.follow_up_messages(), vec!["next task"]);
    assert_eq!(session.pending_message_count(), 1);
}
// Steering messages are kept in submission order.
#[tokio::test]
async fn test_multiple_steer_messages() {
    let session = make_session();
    for text in ["first", "second", "third"] {
        session.steer(text.to_string()).await.unwrap();
    }
    let queued = session.steering_messages();
    assert_eq!(queued, vec!["first", "second", "third"]);
    assert_eq!(session.pending_message_count(), 3);
}
// Follow-up messages are kept in submission order.
#[tokio::test]
async fn test_multiple_follow_up_messages() {
    let session = make_session();
    for text in ["a", "b"] {
        session.follow_up(text.to_string()).await.unwrap();
    }
    assert_eq!(session.follow_up_messages(), vec!["a", "b"]);
}
// Steering and follow-up messages land in separate queues.
#[tokio::test]
async fn test_mixed_steer_and_follow_up() {
    let session = make_session();
    session.steer("steer-1".to_string()).await.unwrap();
    session.follow_up("follow-1".to_string()).await.unwrap();
    session.steer("steer-2".to_string()).await.unwrap();

    let steering = session.steering_messages();
    let follow_up = session.follow_up_messages();
    assert_eq!(session.pending_message_count(), 3);
    assert_eq!(steering, vec!["steer-1", "steer-2"]);
    assert_eq!(follow_up, vec!["follow-1"]);
}
// clear_queue hands back the queued messages and leaves both queues empty.
#[test]
fn test_clear_queue() {
    let session = make_session();
    // Fill the queues directly to avoid going through the async API.
    session
        .steering_messages
        .write()
        .extend(["s1".to_string(), "s2".to_string()]);
    session
        .follow_up_messages
        .write()
        .push_back("f1".to_string());
    assert_eq!(session.pending_message_count(), 3);

    let (drained_steering, drained_follow_up) = session.clear_queue();
    assert_eq!(drained_steering, vec!["s1", "s2"]);
    assert_eq!(drained_follow_up, vec!["f1"]);
    assert_eq!(session.pending_message_count(), 0);
}
// Clearing empty queues yields two empty vectors.
#[test]
fn test_clear_empty_queue() {
    let (steering, follow_up) = make_session().clear_queue();
    assert!(steering.is_empty());
    assert!(follow_up.is_empty());
}
// Default settings leave auto-compaction switched on.
#[test]
fn test_auto_compaction_default_enabled() {
    let enabled = make_session().auto_compaction_enabled();
    assert!(enabled);
}
// The auto-compaction switch reflects the last value written.
#[test]
fn test_set_auto_compaction_enabled() {
    let session = make_session();
    for wanted in [true, false] {
        session.set_auto_compaction_enabled(wanted);
        assert_eq!(session.auto_compaction_enabled(), wanted);
    }
}
// No compaction is in progress on a fresh session.
#[test]
fn test_is_compacting_initially_false() {
    let compacting = make_session().is_compacting();
    assert!(!compacting);
}
// The three compaction reasons are distinct and equal to themselves.
#[test]
fn test_compaction_reason_variants() {
    assert!(CompactionReason::Manual == CompactionReason::Manual);
    assert!(CompactionReason::Manual != CompactionReason::Threshold);
    assert!(CompactionReason::Threshold != CompactionReason::Overflow);
    assert!(CompactionReason::Manual != CompactionReason::Overflow);
}
// The default compaction config is enabled with a positive threshold.
#[test]
fn test_compaction_config_default() {
    let defaults = CompactionConfig::default();
    assert!(defaults.enabled);
    assert!(defaults.threshold > 0.0);
}
/// Statistics for a session from `make_session` carry a non-empty session id
/// and zero for every message/tool counter.
#[test]
fn test_session_stats_empty() {
    let session = make_session();
    let stats = session.session_stats();

    assert!(!stats.session_id.is_empty());

    // Per-category counters, checked in declaration order.
    let per_category = [
        stats.user_messages,
        stats.assistant_messages,
        stats.tool_calls,
        stats.tool_results,
    ];
    for counter in per_category {
        assert_eq!(counter, 0);
    }

    assert_eq!(stats.total_messages, 0);
}
/// Builds a `SessionStats` value by hand with every counter at zero and
/// reads the total back.
#[test]
fn test_session_stats_default() {
    let zeroed = SessionStats {
        session_id: String::from("test"),
        user_messages: 0,
        assistant_messages: 0,
        tool_calls: 0,
        tool_results: 0,
        total_messages: 0,
    };

    let total = zeroed.total_messages;

    assert_eq!(total, 0);
}
/// Smoke test: calling `persist_session` on a session from `make_session`
/// must complete without panicking. Nothing else is asserted.
#[test]
fn test_persist_session_empty_messages() {
    let fresh_session = make_session();

    fresh_session.persist_session();
}
/// After `persist_session` on a session from `make_session`, the session
/// manager's persisted count must still be zero.
#[test]
fn test_persist_session_empty_is_noop() {
    let session = make_session();

    session.persist_session();

    let persisted = session.session_manager.read().persisted_count();
    assert_eq!(persisted, 0);
}
/// A value stored with `set_persisted_count` must be returned unchanged by
/// `persisted_count`.
#[test]
fn test_persist_session_set_persisted_count() {
    let session = make_session();

    // The write guard is a temporary and is released at the end of this
    // statement, before the read below.
    session.session_manager.write().set_persisted_count(5);

    let persisted = session.session_manager.read().persisted_count();
    assert_eq!(persisted, 5);
}
/// With the persisted count preset to 3, a subsequent `persist_session` call
/// must leave that count at 3.
#[test]
fn test_persist_session_idempotent_with_set() {
    let session = make_session();

    // Preset the counter; the write guard is dropped before persisting.
    session.session_manager.write().set_persisted_count(3);

    session.persist_session();

    let persisted = session.session_manager.read().persisted_count();
    assert_eq!(persisted, 3);
}
/// Setting a session name must not panic and must leave the session id
/// non-empty.
///
/// NOTE(review): the name itself is never read back here, so this test does
/// not confirm that the name was actually stored.
#[test]
fn test_set_session_name() {
    let session = make_session();

    session.set_session_name(String::from("My Test Session"));

    let id = session.session_id();
    assert!(!id.is_empty());
}
/// A subscribed listener must be invoked when the thinking level changes,
/// and the event's debug output must mention `ThinkingLevelChanged`.
#[test]
fn test_subscribe_receives_events() {
    let session = make_session();

    // Shared log of the debug rendering of every event the listener sees.
    let received = Arc::new(RwLock::new(Vec::<String>::new()));
    let sink = Arc::clone(&received);
    let _guard = session.subscribe(Box::new(move |event| {
        sink.write().push(format!("{event:?}"));
    }));

    session.set_thinking_level(ThinkingLevel::High);

    let events = received.read();
    assert!(
        !events.is_empty(),
        "Listener should receive at least one event"
    );
    let saw_level_change = events
        .iter()
        .any(|rendered| rendered.contains("ThinkingLevelChanged"));
    assert!(saw_level_change);
}
/// A listener that forwards events into an unbounded channel must deliver
/// the `ThinkingLevelChanged` event, carrying the level that was set.
#[test]
fn test_subscribe_channel_with_guard() {
    let session = make_session();

    let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::<SessionEvent>();
    let _guard = session.subscribe(Box::new(move |event| {
        // A send failure is deliberately ignored, as in any fire-and-forget listener.
        let _ = sender.send(event.clone());
    }));

    session.set_thinking_level(ThinkingLevel::Off);

    // The event must already be queued; no awaiting is involved.
    let event = receiver
        .try_recv()
        .expect("Should receive event via subscribed channel");

    if let SessionEvent::ThinkingLevelChanged { level } = &event {
        assert_eq!(*level, ThinkingLevel::Off);
    } else {
        panic!("Expected ThinkingLevelChanged, got {:?}", event);
    }
}
/// `reset` must empty both message queues and clear the
/// overflow-recovery-attempted flag.
#[test]
fn test_reset_clears_queues_and_overflow() {
    let session = make_session();

    // Put one message in each queue and raise the overflow flag. Each lock
    // guard is a temporary released at the end of its statement.
    session
        .steering_messages
        .write()
        .push_back("steer".to_string());
    session
        .follow_up_messages
        .write()
        .push_back("follow".to_string());
    *session.overflow_recovery_attempted.write() = true;

    let pending_before = session.pending_message_count();
    assert_eq!(pending_before, 2);

    session.reset();

    let pending_after = session.pending_message_count();
    assert_eq!(pending_after, 0);

    let overflow_attempted = *session.overflow_recovery_attempted.read();
    assert!(!overflow_attempted);
}
/// A handle from `clone_handle` must report the same session id as the
/// original, and a thinking level set through the handle must be visible
/// through the original session.
#[test]
fn test_clone_handle_shares_state() {
    let original = make_session();
    let handle = original.clone_handle();

    let original_id = original.session_id();
    let handle_id = handle.session_id();
    assert_eq!(original_id, handle_id);

    handle.set_thinking_level(ThinkingLevel::High);

    let level_seen_by_original = original.thinking_level();
    assert_eq!(level_seen_by_original, ThinkingLevel::High);
}
/// Checks the equality relation on `StreamingBehavior`: `Steer` equals
/// itself and differs from `FollowUp`.
#[test]
fn test_streaming_behavior_variants() {
    let steer = StreamingBehavior::Steer;
    let follow_up = StreamingBehavior::FollowUp;

    assert_eq!(steer, StreamingBehavior::Steer);
    assert_ne!(steer, follow_up);
}
/// The default `InputSource` is `Interactive`.
#[test]
fn test_input_source_default() {
    let source: InputSource = Default::default();

    assert_eq!(source, InputSource::Interactive);
}
/// Default `PromptOptions`: template expansion on, no images, no streaming
/// behavior, and an interactive input source.
#[test]
fn test_prompt_options_default() {
    // Destructure so each field is checked by name.
    let PromptOptions {
        expand_templates,
        images,
        streaming_behavior,
        source,
    } = PromptOptions::default();

    assert!(expand_templates);
    assert!(images.is_empty());
    assert!(streaming_behavior.is_none());
    assert_eq!(source, InputSource::Interactive);
}
/// A session from `make_session` has no extension runner attached.
#[test]
fn test_no_extension_runner_by_default() {
    let session = make_session();

    assert!(session.extension_runner().is_none());
}
/// A session from `make_session` exposes no extension tools.
#[test]
fn test_extension_tools_empty_without_runner() {
    let session = make_session();

    let tools = session.extension_tools();

    assert!(tools.is_empty());
}
/// A session from `make_session` exposes no extension commands.
#[test]
fn test_extension_commands_empty_without_runner() {
    let session = make_session();

    let commands = session.extension_commands();

    assert!(commands.is_empty());
}
/// A session from `make_session` reports no handlers for the `tool_call`
/// event name.
#[test]
fn test_has_extension_handlers_false_without_runner() {
    let session = make_session();

    let has_handlers = session.has_extension_handlers("tool_call");

    assert!(!has_handlers);
}
/// A session from `make_session` reports auto-retry as enabled.
#[test]
fn test_auto_retry_enabled() {
    let session = make_session();

    let retry_enabled = session.auto_retry_enabled();

    assert!(retry_enabled);
}
/// Dropping the guard returned by `subscribe` must unsubscribe the listener.
///
/// The listener records the debug rendering of every event it sees. One
/// event is emitted while the guard is alive (and must be recorded), then a
/// second event is emitted after the guard has been dropped (and must NOT be
/// recorded). Without the second emission the test would pass even if the
/// listener were never removed.
#[test]
fn test_listener_guard_drop_removes() {
    let session = make_session();
    let received = Arc::new(RwLock::new(Vec::new()));
    let received_clone = received.clone();
    {
        let _guard = session.subscribe(Box::new(move |event| {
            received_clone.write().push(format!("{:?}", event));
        }));
        session.set_thinking_level(ThinkingLevel::High);
        // `_guard` goes out of scope here, which should remove the listener.
    }
    let count_after_drop = received.read().len();
    assert_eq!(count_after_drop, 1);

    // Emit another event with a different level; a removed listener must not
    // observe it, so the recorded count stays unchanged.
    session.set_thinking_level(ThinkingLevel::Off);
    let count_after_second_event = received.read().len();
    assert_eq!(
        count_after_second_event, count_after_drop,
        "listener must not receive events after its guard is dropped"
    );
}
}