use crate::auth::AuthStorage;
use crate::compaction::{self, ResolvedCompactionSettings};
use crate::compaction_worker::{CompactionQuota, CompactionWorkerState};
use crate::error::{Error, Result};
use crate::extension_events::{InputEventOutcome, apply_input_event_response};
use crate::extension_tools::collect_extension_tool_wrappers;
use crate::extensions::{
EXTENSION_EVENT_TIMEOUT_MS, ExtensionDeliverAs, ExtensionEventName, ExtensionHostActions,
ExtensionLoadSpec, ExtensionManager, ExtensionPolicy, ExtensionRegion, ExtensionRuntimeHandle,
ExtensionSendMessage, ExtensionSendUserMessage, JsExtensionLoadSpec, JsExtensionRuntimeHandle,
NativeRustExtensionLoadSpec, NativeRustExtensionRuntimeHandle, RepairPolicyMode,
resolve_extension_load_spec,
};
#[cfg(feature = "wasm-host")]
use crate::extensions::{WasmExtensionHost, WasmExtensionLoadSpec};
use crate::extensions_js::{PiJsRuntimeConfig, RepairMode};
use crate::model::{
AssistantMessage, AssistantMessageEvent, ContentBlock, CustomMessage, ImageContent, Message,
StopReason, StreamEvent, TextContent, ThinkingContent, ToolCall, ToolResultMessage, Usage,
UserContent, UserMessage,
};
use crate::models::{ModelEntry, ModelRegistry};
use crate::provider::{Context, Provider, StreamOptions, ToolDef};
use crate::session::{AutosaveFlushTrigger, Session, SessionHandle};
use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
use asupersync::sync::{Mutex, Notify};
use async_trait::async_trait;
use chrono::Utc;
use futures::FutureExt;
use futures::StreamExt;
use futures::future::BoxFuture;
use futures::stream;
use serde::Serialize;
use serde_json::{Value, json};
use std::borrow::Cow;
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::{AtomicBool, Ordering};
const MAX_CONCURRENT_TOOLS: usize = 8;
#[derive(Debug, Clone)]
pub struct AgentConfig {
pub system_prompt: Option<String>,
pub max_tool_iterations: usize,
pub stream_options: StreamOptions,
pub block_images: bool,
}
impl Default for AgentConfig {
fn default() -> Self {
Self {
system_prompt: None,
max_tool_iterations: 50,
stream_options: StreamOptions::default(),
block_images: false,
}
}
}
pub type MessageFetcher = Arc<dyn Fn() -> BoxFuture<'static, Vec<Message>> + Send + Sync + 'static>;
type AgentEventHandler = Arc<dyn Fn(AgentEvent) + Send + Sync + 'static>;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueueMode {
All,
OneAtATime,
}
impl QueueMode {
pub const fn as_str(self) -> &'static str {
match self {
Self::All => "all",
Self::OneAtATime => "one-at-a-time",
}
}
}
#[derive(Debug, Clone, Copy)]
enum QueueKind {
Steering,
FollowUp,
}
#[derive(Debug, Clone)]
struct QueuedMessage {
seq: u64,
enqueued_at: i64,
message: Message,
}
#[derive(Debug)]
struct MessageQueue {
steering: VecDeque<QueuedMessage>,
follow_up: VecDeque<QueuedMessage>,
steering_mode: QueueMode,
follow_up_mode: QueueMode,
next_seq: u64,
}
impl MessageQueue {
const fn new(steering_mode: QueueMode, follow_up_mode: QueueMode) -> Self {
Self {
steering: VecDeque::new(),
follow_up: VecDeque::new(),
steering_mode,
follow_up_mode,
next_seq: 0,
}
}
const fn set_modes(&mut self, steering_mode: QueueMode, follow_up_mode: QueueMode) {
self.steering_mode = steering_mode;
self.follow_up_mode = follow_up_mode;
}
fn pending_count(&self) -> usize {
self.steering.len() + self.follow_up.len()
}
fn push(&mut self, kind: QueueKind, message: Message) -> u64 {
let seq = self.next_seq;
self.next_seq = self.next_seq.saturating_add(1);
let entry = QueuedMessage {
seq,
enqueued_at: Utc::now().timestamp_millis(),
message,
};
match kind {
QueueKind::Steering => self.steering.push_back(entry),
QueueKind::FollowUp => self.follow_up.push_back(entry),
}
seq
}
fn push_steering(&mut self, message: Message) -> u64 {
self.push(QueueKind::Steering, message)
}
fn push_follow_up(&mut self, message: Message) -> u64 {
self.push(QueueKind::FollowUp, message)
}
fn pop_steering(&mut self) -> Vec<Message> {
self.pop_kind(QueueKind::Steering)
}
fn pop_follow_up(&mut self) -> Vec<Message> {
self.pop_kind(QueueKind::FollowUp)
}
fn pop_kind(&mut self, kind: QueueKind) -> Vec<Message> {
let (queue, mode) = match kind {
QueueKind::Steering => (&mut self.steering, self.steering_mode),
QueueKind::FollowUp => (&mut self.follow_up, self.follow_up_mode),
};
match mode {
QueueMode::All => queue.drain(..).map(|entry| entry.message).collect(),
QueueMode::OneAtATime => queue
.pop_front()
.into_iter()
.map(|entry| entry.message)
.collect(),
}
}
}
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AgentEvent {
AgentStart {
#[serde(rename = "sessionId")]
session_id: Arc<str>,
},
AgentEnd {
#[serde(rename = "sessionId")]
session_id: Arc<str>,
messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<String>,
},
TurnStart {
#[serde(rename = "sessionId")]
session_id: Arc<str>,
#[serde(rename = "turnIndex")]
turn_index: usize,
timestamp: i64,
},
TurnEnd {
#[serde(rename = "sessionId")]
session_id: Arc<str>,
#[serde(rename = "turnIndex")]
turn_index: usize,
message: Message,
#[serde(rename = "toolResults")]
tool_results: Vec<Message>,
},
MessageStart { message: Message },
MessageUpdate {
message: Message,
#[serde(rename = "assistantMessageEvent")]
assistant_message_event: AssistantMessageEvent,
},
MessageEnd { message: Message },
ToolExecutionStart {
#[serde(rename = "toolCallId")]
tool_call_id: String,
#[serde(rename = "toolName")]
tool_name: String,
args: serde_json::Value,
},
ToolExecutionUpdate {
#[serde(rename = "toolCallId")]
tool_call_id: String,
#[serde(rename = "toolName")]
tool_name: String,
args: serde_json::Value,
#[serde(rename = "partialResult")]
partial_result: ToolOutput,
},
ToolExecutionEnd {
#[serde(rename = "toolCallId")]
tool_call_id: String,
#[serde(rename = "toolName")]
tool_name: String,
result: ToolOutput,
#[serde(rename = "isError")]
is_error: bool,
},
AutoCompactionStart { reason: String },
AutoCompactionEnd {
#[serde(skip_serializing_if = "Option::is_none")]
result: Option<serde_json::Value>,
aborted: bool,
#[serde(rename = "willRetry")]
will_retry: bool,
#[serde(rename = "errorMessage", skip_serializing_if = "Option::is_none")]
error_message: Option<String>,
},
AutoRetryStart {
attempt: u32,
#[serde(rename = "maxAttempts")]
max_attempts: u32,
#[serde(rename = "delayMs")]
delay_ms: u64,
#[serde(rename = "errorMessage")]
error_message: String,
},
AutoRetryEnd {
success: bool,
attempt: u32,
#[serde(rename = "finalError", skip_serializing_if = "Option::is_none")]
final_error: Option<String>,
},
ExtensionError {
#[serde(rename = "extensionId", skip_serializing_if = "Option::is_none")]
extension_id: Option<String>,
event: String,
error: String,
},
}
#[derive(Debug, Clone)]
pub struct AbortHandle {
inner: Arc<AbortSignalInner>,
}
#[derive(Debug, Clone)]
pub struct AbortSignal {
inner: Arc<AbortSignalInner>,
}
#[derive(Debug)]
struct AbortSignalInner {
aborted: AtomicBool,
notify: Notify,
}
impl AbortHandle {
#[must_use]
pub fn new() -> (Self, AbortSignal) {
let inner = Arc::new(AbortSignalInner {
aborted: AtomicBool::new(false),
notify: Notify::new(),
});
(
Self {
inner: Arc::clone(&inner),
},
AbortSignal { inner },
)
}
pub fn abort(&self) {
if !self.inner.aborted.swap(true, Ordering::SeqCst) {
self.inner.notify.notify_waiters();
}
}
}
impl AbortSignal {
#[must_use]
pub fn is_aborted(&self) -> bool {
self.inner.aborted.load(Ordering::SeqCst)
}
pub async fn wait(&self) {
if self.is_aborted() {
return;
}
loop {
self.inner.notify.notified().await;
if self.is_aborted() {
return;
}
}
}
}
pub struct Agent {
provider: Arc<dyn Provider>,
tools: ToolRegistry,
config: AgentConfig,
extensions: Option<ExtensionManager>,
messages: Vec<Message>,
steering_fetchers: Vec<MessageFetcher>,
follow_up_fetchers: Vec<MessageFetcher>,
message_queue: MessageQueue,
cached_tool_defs: Option<Vec<ToolDef>>,
}
impl Agent {
pub fn new(provider: Arc<dyn Provider>, tools: ToolRegistry, config: AgentConfig) -> Self {
Self {
provider,
tools,
config,
extensions: None,
messages: Vec::new(),
steering_fetchers: Vec::new(),
follow_up_fetchers: Vec::new(),
message_queue: MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime),
cached_tool_defs: None,
}
}
#[must_use]
pub fn messages(&self) -> &[Message] {
&self.messages
}
pub fn clear_messages(&mut self) {
self.messages.clear();
}
pub fn add_message(&mut self, message: Message) {
self.messages.push(message);
}
pub fn replace_messages(&mut self, messages: Vec<Message>) {
self.messages = messages;
}
pub fn set_provider(&mut self, provider: Arc<dyn Provider>) {
self.provider = provider;
}
pub fn register_message_fetchers(
&mut self,
steering: Option<MessageFetcher>,
follow_up: Option<MessageFetcher>,
) {
if let Some(fetcher) = steering {
self.steering_fetchers.push(fetcher);
}
if let Some(fetcher) = follow_up {
self.follow_up_fetchers.push(fetcher);
}
}
pub fn extend_tools<I>(&mut self, tools: I)
where
I: IntoIterator<Item = Box<dyn Tool>>,
{
self.tools.extend(tools);
self.cached_tool_defs = None; }
pub fn queue_steering(&mut self, message: Message) -> u64 {
self.message_queue.push_steering(message)
}
pub fn queue_follow_up(&mut self, message: Message) -> u64 {
self.message_queue.push_follow_up(message)
}
pub const fn set_queue_modes(&mut self, steering: QueueMode, follow_up: QueueMode) {
self.message_queue.set_modes(steering, follow_up);
}
#[must_use]
pub fn queued_message_count(&self) -> usize {
self.message_queue.pending_count()
}
pub fn provider(&self) -> Arc<dyn Provider> {
Arc::clone(&self.provider)
}
pub const fn stream_options(&self) -> &StreamOptions {
&self.config.stream_options
}
pub const fn stream_options_mut(&mut self) -> &mut StreamOptions {
&mut self.config.stream_options
}
fn build_context(&mut self) -> Context<'_> {
let messages: Cow<'_, [Message]> = if self.config.block_images {
let mut msgs = self.messages.clone();
msgs.retain(|m| match m {
Message::Custom(c) => c.display,
_ => true,
});
let stats = filter_images_for_provider(&mut msgs);
if stats.removed_images > 0 {
tracing::debug!(
filtered_images = stats.removed_images,
affected_messages = stats.affected_messages,
"Filtered image content from outbound provider context (images.block_images=true)"
);
}
Cow::Owned(msgs)
} else {
let has_hidden = self.messages.iter().any(|m| match m {
Message::Custom(c) => !c.display,
_ => false,
});
if has_hidden {
let mut msgs = self.messages.clone();
msgs.retain(|m| match m {
Message::Custom(c) => c.display,
_ => true,
});
Cow::Owned(msgs)
} else {
Cow::Borrowed(self.messages.as_slice())
}
};
if self.cached_tool_defs.is_none() {
let defs: Vec<ToolDef> = self
.tools
.tools()
.iter()
.map(|t| ToolDef {
name: t.name().to_string(),
description: t.description().to_string(),
parameters: t.parameters(),
})
.collect();
self.cached_tool_defs = Some(defs);
}
let tools = Cow::Borrowed(self.cached_tool_defs.as_deref().unwrap());
Context {
system_prompt: self.config.system_prompt.as_deref().map(Cow::Borrowed),
messages,
tools,
}
}
pub async fn run(
&mut self,
user_input: impl Into<String>,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<AssistantMessage> {
self.run_with_abort(user_input, None, on_event).await
}
pub async fn run_with_abort(
&mut self,
user_input: impl Into<String>,
abort: Option<AbortSignal>,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<AssistantMessage> {
let user_message = Message::User(UserMessage {
content: UserContent::Text(user_input.into()),
timestamp: Utc::now().timestamp_millis(),
});
self.run_loop(vec![user_message], Arc::new(on_event), abort)
.await
}
pub async fn run_with_content(
&mut self,
content: Vec<ContentBlock>,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<AssistantMessage> {
self.run_with_content_with_abort(content, None, on_event)
.await
}
pub async fn run_with_content_with_abort(
&mut self,
content: Vec<ContentBlock>,
abort: Option<AbortSignal>,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<AssistantMessage> {
let user_message = Message::User(UserMessage {
content: UserContent::Blocks(content),
timestamp: Utc::now().timestamp_millis(),
});
self.run_loop(vec![user_message], Arc::new(on_event), abort)
.await
}
pub async fn run_with_message_with_abort(
&mut self,
message: Message,
abort: Option<AbortSignal>,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<AssistantMessage> {
self.run_loop(vec![message], Arc::new(on_event), abort)
.await
}
pub async fn run_continue_with_abort(
&mut self,
abort: Option<AbortSignal>,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<AssistantMessage> {
self.run_loop(Vec::new(), Arc::new(on_event), abort).await
}
fn build_abort_message(&self, partial: Option<&AssistantMessage>) -> AssistantMessage {
let mut message = partial.cloned().unwrap_or_else(|| AssistantMessage {
content: Vec::new(),
api: self.provider.api().to_string(),
provider: self.provider.name().to_string(),
model: self.provider.model_id().to_string(),
usage: Usage::default(),
stop_reason: StopReason::Aborted,
error_message: Some("Aborted".to_string()),
timestamp: Utc::now().timestamp_millis(),
});
message.stop_reason = StopReason::Aborted;
message.error_message = Some("Aborted".to_string());
message.timestamp = Utc::now().timestamp_millis();
message
}
#[allow(clippy::too_many_lines)]
async fn run_loop(
&mut self,
prompts: Vec<Message>,
on_event: AgentEventHandler,
abort: Option<AbortSignal>,
) -> Result<AssistantMessage> {
let session_id: Arc<str> = self
.config
.stream_options
.session_id
.as_deref()
.unwrap_or("")
.into();
let mut iterations = 0usize;
let mut turn_index: usize = 0;
let mut new_messages: Vec<Message> = Vec::with_capacity(prompts.len() + 8);
let mut last_assistant: Option<Arc<AssistantMessage>> = None;
let agent_start_event = AgentEvent::AgentStart {
session_id: session_id.clone(),
};
self.dispatch_extension_lifecycle_event(&agent_start_event)
.await;
on_event(agent_start_event);
for prompt in prompts {
self.messages.push(prompt.clone());
on_event(AgentEvent::MessageStart {
message: prompt.clone(),
});
on_event(AgentEvent::MessageEnd {
message: prompt.clone(),
});
new_messages.push(prompt);
}
let mut pending_messages = self.drain_steering_messages().await;
loop {
let mut has_more_tool_calls = true;
let mut steering_after_tools: Option<Vec<Message>> = None;
while has_more_tool_calls || !pending_messages.is_empty() {
let current_turn_index = turn_index;
let turn_start_event = AgentEvent::TurnStart {
session_id: session_id.clone(),
turn_index: current_turn_index,
timestamp: Utc::now().timestamp_millis(),
};
self.dispatch_extension_lifecycle_event(&turn_start_event)
.await;
on_event(turn_start_event);
for message in std::mem::take(&mut pending_messages) {
self.messages.push(message.clone());
on_event(AgentEvent::MessageStart {
message: message.clone(),
});
on_event(AgentEvent::MessageEnd {
message: message.clone(),
});
new_messages.push(message);
}
if abort.as_ref().is_some_and(AbortSignal::is_aborted) {
let abort_message = self.build_abort_message(None);
let message = Message::assistant(abort_message.clone());
self.messages.push(message.clone());
new_messages.push(message.clone());
on_event(AgentEvent::MessageStart {
message: message.clone(),
});
on_event(AgentEvent::MessageEnd {
message: message.clone(),
});
let turn_end_event = AgentEvent::TurnEnd {
session_id: session_id.clone(),
turn_index: current_turn_index,
message,
tool_results: Vec::new(),
};
self.dispatch_extension_lifecycle_event(&turn_end_event)
.await;
on_event(turn_end_event);
let agent_end_event = AgentEvent::AgentEnd {
session_id: session_id.clone(),
messages: std::mem::take(&mut new_messages),
error: Some(
abort_message
.error_message
.clone()
.unwrap_or_else(|| "Aborted".to_string()),
),
};
self.dispatch_extension_lifecycle_event(&agent_end_event)
.await;
on_event(agent_end_event);
return Ok(abort_message);
}
let assistant_message = match self
.stream_assistant_response(Arc::clone(&on_event), abort.clone())
.await
{
Ok(msg) => msg,
Err(err) => {
let agent_end_event = AgentEvent::AgentEnd {
session_id: session_id.clone(),
messages: std::mem::take(&mut new_messages),
error: Some(err.to_string()),
};
self.dispatch_extension_lifecycle_event(&agent_end_event)
.await;
on_event(agent_end_event);
return Err(err);
}
};
let assistant_arc = Arc::new(assistant_message);
last_assistant = Some(Arc::clone(&assistant_arc));
let assistant_event_message = Message::Assistant(Arc::clone(&assistant_arc));
new_messages.push(assistant_event_message.clone());
if matches!(
assistant_arc.stop_reason,
StopReason::Error | StopReason::Aborted
) {
let turn_end_event = AgentEvent::TurnEnd {
session_id: session_id.clone(),
turn_index: current_turn_index,
message: assistant_event_message.clone(),
tool_results: Vec::new(),
};
self.dispatch_extension_lifecycle_event(&turn_end_event)
.await;
on_event(turn_end_event);
let agent_end_event = AgentEvent::AgentEnd {
session_id: session_id.clone(),
messages: std::mem::take(&mut new_messages),
error: assistant_arc.error_message.clone(),
};
self.dispatch_extension_lifecycle_event(&agent_end_event)
.await;
on_event(agent_end_event);
return Ok(Arc::unwrap_or_clone(assistant_arc));
}
let tool_calls = extract_tool_calls(&assistant_arc.content);
has_more_tool_calls = !tool_calls.is_empty();
let mut tool_results: Vec<Arc<ToolResultMessage>> = Vec::new();
if has_more_tool_calls {
iterations += 1;
if iterations > self.config.max_tool_iterations {
let error_message = format!(
"Maximum tool iterations ({}) exceeded",
self.config.max_tool_iterations
);
let mut stop_message = (*assistant_arc).clone();
stop_message.stop_reason = StopReason::Error;
stop_message.error_message = Some(error_message.clone());
let stop_arc = Arc::new(stop_message.clone());
let stop_event_message = Message::Assistant(Arc::clone(&stop_arc));
if let Some(last @ Message::Assistant(_)) = self.messages.last_mut() {
*last = stop_event_message.clone();
}
if let Some(last @ Message::Assistant(_)) = new_messages.last_mut() {
*last = stop_event_message.clone();
}
let turn_end_event = AgentEvent::TurnEnd {
session_id: session_id.clone(),
turn_index: current_turn_index,
message: stop_event_message,
tool_results: Vec::new(),
};
self.dispatch_extension_lifecycle_event(&turn_end_event)
.await;
on_event(turn_end_event);
let agent_end_event = AgentEvent::AgentEnd {
session_id: session_id.clone(),
messages: std::mem::take(&mut new_messages),
error: Some(error_message),
};
self.dispatch_extension_lifecycle_event(&agent_end_event)
.await;
on_event(agent_end_event);
return Ok(stop_message);
}
let outcome = match self
.execute_tool_calls(
&tool_calls,
Arc::clone(&on_event),
&mut new_messages,
abort.clone(),
)
.await
{
Ok(outcome) => outcome,
Err(err) => {
let agent_end_event = AgentEvent::AgentEnd {
session_id: session_id.clone(),
messages: std::mem::take(&mut new_messages),
error: Some(err.to_string()),
};
self.dispatch_extension_lifecycle_event(&agent_end_event)
.await;
on_event(agent_end_event);
return Err(err);
}
};
tool_results = outcome.tool_results;
steering_after_tools = outcome.steering_messages;
}
let tool_messages = tool_results
.iter()
.map(|r| Message::ToolResult(Arc::clone(r)))
.collect::<Vec<_>>();
let turn_end_event = AgentEvent::TurnEnd {
session_id: session_id.clone(),
turn_index: current_turn_index,
message: assistant_event_message.clone(),
tool_results: tool_messages,
};
self.dispatch_extension_lifecycle_event(&turn_end_event)
.await;
on_event(turn_end_event);
turn_index = turn_index.saturating_add(1);
if let Some(steering) = steering_after_tools.take() {
pending_messages = steering;
} else {
pending_messages = self.drain_steering_messages().await;
}
}
let follow_up = self.drain_follow_up_messages().await;
if follow_up.is_empty() {
break;
}
pending_messages = follow_up;
}
let Some(final_arc) = last_assistant else {
return Err(Error::api("Agent completed without assistant message"));
};
let agent_end_event = AgentEvent::AgentEnd {
session_id: session_id.clone(),
messages: new_messages,
error: None,
};
self.dispatch_extension_lifecycle_event(&agent_end_event)
.await;
on_event(agent_end_event);
Ok(Arc::unwrap_or_clone(final_arc))
}
async fn fetch_messages(&self, fetcher: Option<&MessageFetcher>) -> Vec<Message> {
if let Some(fetcher) = fetcher {
(fetcher)().await
} else {
Vec::new()
}
}
async fn dispatch_extension_lifecycle_event(&self, event: &AgentEvent) {
let Some(extensions) = &self.extensions else {
return;
};
let name = match event {
AgentEvent::AgentStart { .. } => ExtensionEventName::AgentStart,
AgentEvent::AgentEnd { .. } => ExtensionEventName::AgentEnd,
AgentEvent::TurnStart { .. } => ExtensionEventName::TurnStart,
AgentEvent::TurnEnd { .. } => ExtensionEventName::TurnEnd,
_ => return,
};
let payload = match serde_json::to_value(event) {
Ok(payload) => payload,
Err(err) => {
tracing::warn!("failed to serialize agent lifecycle event (fail-open): {err}");
return;
}
};
if let Err(err) = extensions.dispatch_event(name, Some(payload)).await {
tracing::warn!("agent lifecycle extension hook failed (fail-open): {err}");
}
}
async fn drain_steering_messages(&mut self) -> Vec<Message> {
for fetcher in &self.steering_fetchers {
let fetched = self.fetch_messages(Some(fetcher)).await;
for message in fetched {
self.message_queue.push_steering(message);
}
}
self.message_queue.pop_steering()
}
async fn drain_follow_up_messages(&mut self) -> Vec<Message> {
for fetcher in &self.follow_up_fetchers {
let fetched = self.fetch_messages(Some(fetcher)).await;
for message in fetched {
self.message_queue.push_follow_up(message);
}
}
self.message_queue.pop_follow_up()
}
#[allow(clippy::too_many_lines)]
async fn stream_assistant_response(
&mut self,
on_event: AgentEventHandler,
abort: Option<AbortSignal>,
) -> Result<AssistantMessage> {
let provider = Arc::clone(&self.provider);
let stream_options = self.config.stream_options.clone();
let context = self.build_context();
let mut stream = provider.stream(&context, &stream_options).await?;
let mut added_partial = false;
let mut sent_start = false;
loop {
let event_result = if let Some(signal) = abort.as_ref() {
let abort_fut = signal.wait().fuse();
let event_fut = stream.next().fuse();
futures::pin_mut!(abort_fut, event_fut);
match futures::future::select(abort_fut, event_fut).await {
futures::future::Either::Left(((), _event_fut)) => {
let last_partial = if added_partial {
match self.messages.last() {
Some(Message::Assistant(a)) => Some(a.as_ref()),
_ => None,
}
} else {
None
};
let abort_arc = Arc::new(self.build_abort_message(last_partial));
if !sent_start {
on_event(AgentEvent::MessageStart {
message: Message::Assistant(Arc::clone(&abort_arc)),
});
self.messages
.push(Message::Assistant(Arc::clone(&abort_arc)));
added_partial = true;
}
on_event(AgentEvent::MessageUpdate {
message: Message::Assistant(Arc::clone(&abort_arc)),
assistant_message_event: AssistantMessageEvent::Error {
reason: StopReason::Aborted,
error: Arc::clone(&abort_arc),
},
});
return Ok(self.finalize_assistant_message(
Arc::try_unwrap(abort_arc).unwrap_or_else(|a| (*a).clone()),
&on_event,
added_partial,
));
}
futures::future::Either::Right((event, _abort_fut)) => event,
}
} else {
stream.next().await
};
let Some(event_result) = event_result else {
break;
};
let event = event_result?;
match event {
StreamEvent::Start { partial } => {
let shared = Arc::new(partial);
self.update_partial_message(Arc::clone(&shared), &mut added_partial);
on_event(AgentEvent::MessageStart {
message: Message::Assistant(Arc::clone(&shared)),
});
sent_start = true;
on_event(AgentEvent::MessageUpdate {
message: Message::Assistant(Arc::clone(&shared)),
assistant_message_event: AssistantMessageEvent::Start { partial: shared },
});
}
StreamEvent::TextStart { content_index, .. } => {
if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
let msg = Arc::make_mut(msg_arc);
if content_index == msg.content.len() {
msg.content.push(ContentBlock::Text(TextContent::new("")));
}
let shared = Arc::clone(msg_arc);
if !sent_start {
on_event(AgentEvent::MessageStart {
message: Message::Assistant(Arc::clone(&shared)),
});
sent_start = true;
}
on_event(AgentEvent::MessageUpdate {
message: Message::Assistant(Arc::clone(&shared)),
assistant_message_event: AssistantMessageEvent::TextStart {
content_index,
partial: shared,
},
});
}
}
StreamEvent::TextDelta {
content_index,
delta,
..
} => {
if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
{
let msg = Arc::make_mut(msg_arc);
if let Some(ContentBlock::Text(text)) =
msg.content.get_mut(content_index)
{
text.text.push_str(&delta);
}
}
let shared = Arc::clone(msg_arc);
if !sent_start {
on_event(AgentEvent::MessageStart {
message: Message::Assistant(Arc::clone(&shared)),
});
sent_start = true;
}
on_event(AgentEvent::MessageUpdate {
message: Message::Assistant(Arc::clone(&shared)),
assistant_message_event: AssistantMessageEvent::TextDelta {
content_index,
delta,
partial: shared,
},
});
}
}
StreamEvent::TextEnd {
content_index,
content,
..
} => {
if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
{
let msg = Arc::make_mut(msg_arc);
if let Some(ContentBlock::Text(text)) =
msg.content.get_mut(content_index)
{
text.text.clone_from(&content);
}
}
let shared = Arc::clone(msg_arc);
if !sent_start {
on_event(AgentEvent::MessageStart {
message: Message::Assistant(Arc::clone(&shared)),
});
sent_start = true;
}
on_event(AgentEvent::MessageUpdate {
message: Message::Assistant(Arc::clone(&shared)),
assistant_message_event: AssistantMessageEvent::TextEnd {
content_index,
content,
partial: shared,
},
});
}
}
StreamEvent::ThinkingStart { content_index, .. } => {
if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
let msg = Arc::make_mut(msg_arc);
if content_index == msg.content.len() {
msg.content.push(ContentBlock::Thinking(ThinkingContent {
thinking: String::new(),
thinking_signature: None,
}));
}
let shared = Arc::clone(msg_arc);
if !sent_start {
on_event(AgentEvent::MessageStart {
message: Message::Assistant(Arc::clone(&shared)),
});
sent_start = true;
}
on_event(AgentEvent::MessageUpdate {
message: Message::Assistant(Arc::clone(&shared)),
assistant_message_event: AssistantMessageEvent::ThinkingStart {
content_index,
partial: shared,
},
});
}
}
StreamEvent::ThinkingDelta {
content_index,
delta,
..
} => {
if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
{
let msg = Arc::make_mut(msg_arc);
if let Some(ContentBlock::Thinking(thinking)) =
msg.content.get_mut(content_index)
{
thinking.thinking.push_str(&delta);
}
}
let shared = Arc::clone(msg_arc);
if !sent_start {
on_event(AgentEvent::MessageStart {
message: Message::Assistant(Arc::clone(&shared)),
});
sent_start = true;
}
on_event(AgentEvent::MessageUpdate {
message: Message::Assistant(Arc::clone(&shared)),
assistant_message_event: AssistantMessageEvent::ThinkingDelta {
content_index,
delta,
partial: shared,
},
});
}
}
StreamEvent::ThinkingEnd {
content_index,
content,
..
} => {
if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
{
let msg = Arc::make_mut(msg_arc);
if let Some(ContentBlock::Thinking(thinking)) =
msg.content.get_mut(content_index)
{
thinking.thinking.clone_from(&content);
}
}
let shared = Arc::clone(msg_arc);
if !sent_start {
on_event(AgentEvent::MessageStart {
message: Message::Assistant(Arc::clone(&shared)),
});
sent_start = true;
}
on_event(AgentEvent::MessageUpdate {
message: Message::Assistant(Arc::clone(&shared)),
assistant_message_event: AssistantMessageEvent::ThinkingEnd {
content_index,
content,
partial: shared,
},
});
}
}
StreamEvent::ToolCallStart { content_index, .. } => {
if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
let msg = Arc::make_mut(msg_arc);
if content_index == msg.content.len() {
msg.content.push(ContentBlock::ToolCall(ToolCall {
id: String::new(),
name: String::new(),
arguments: serde_json::Value::Null,
thought_signature: None,
}));
}
let shared = Arc::clone(msg_arc);
if !sent_start {
on_event(AgentEvent::MessageStart {
message: Message::Assistant(Arc::clone(&shared)),
});
sent_start = true;
}
on_event(AgentEvent::MessageUpdate {
message: Message::Assistant(Arc::clone(&shared)),
assistant_message_event: AssistantMessageEvent::ToolCallStart {
content_index,
partial: shared,
},
});
}
}
StreamEvent::ToolCallDelta {
content_index,
delta,
..
} => {
if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
let shared = Arc::clone(msg_arc);
if !sent_start {
on_event(AgentEvent::MessageStart {
message: Message::Assistant(Arc::clone(&shared)),
});
sent_start = true;
}
on_event(AgentEvent::MessageUpdate {
message: Message::Assistant(Arc::clone(&shared)),
assistant_message_event: AssistantMessageEvent::ToolCallDelta {
content_index,
delta,
partial: shared,
},
});
}
}
StreamEvent::ToolCallEnd {
content_index,
tool_call,
..
} => {
if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
{
let msg = Arc::make_mut(msg_arc);
if let Some(ContentBlock::ToolCall(tc)) =
msg.content.get_mut(content_index)
{
*tc = tool_call.clone();
}
}
let shared = Arc::clone(msg_arc);
if !sent_start {
on_event(AgentEvent::MessageStart {
message: Message::Assistant(Arc::clone(&shared)),
});
sent_start = true;
}
on_event(AgentEvent::MessageUpdate {
message: Message::Assistant(Arc::clone(&shared)),
assistant_message_event: AssistantMessageEvent::ToolCallEnd {
content_index,
tool_call,
partial: shared,
},
});
}
}
StreamEvent::Done { message, .. } => {
return Ok(self.finalize_assistant_message(message, &on_event, added_partial));
}
StreamEvent::Error { error, .. } => {
return Ok(self.finalize_assistant_message(error, &on_event, added_partial));
}
}
}
if added_partial {
if let Some(Message::Assistant(last_msg)) = self.messages.last() {
let mut final_msg = (**last_msg).clone();
final_msg.stop_reason = StopReason::Error;
final_msg.error_message = Some("Stream ended without Done event".to_string());
return Ok(self.finalize_assistant_message(final_msg, &on_event, true));
}
}
Err(Error::api("Stream ended without Done event"))
}
fn update_partial_message(
&mut self,
partial: Arc<AssistantMessage>,
added_partial: &mut bool,
) -> bool {
if *added_partial {
if let Some(last @ Message::Assistant(_)) = self.messages.last_mut() {
*last = Message::Assistant(partial);
} else {
tracing::warn!("update_partial_message: expected last message to be Assistant");
self.messages.push(Message::Assistant(partial));
}
false
} else {
self.messages.push(Message::Assistant(partial));
*added_partial = true;
true
}
}
fn finalize_assistant_message(
&mut self,
message: AssistantMessage,
on_event: &Arc<dyn Fn(AgentEvent) + Send + Sync>,
added_partial: bool,
) -> AssistantMessage {
let arc = Arc::new(message);
if added_partial {
if let Some(last @ Message::Assistant(_)) = self.messages.last_mut() {
*last = Message::Assistant(Arc::clone(&arc));
} else {
tracing::warn!("finalize_assistant_message: expected last message to be Assistant");
self.messages.push(Message::Assistant(Arc::clone(&arc)));
on_event(AgentEvent::MessageStart {
message: Message::Assistant(Arc::clone(&arc)),
});
}
} else {
self.messages.push(Message::Assistant(Arc::clone(&arc)));
on_event(AgentEvent::MessageStart {
message: Message::Assistant(Arc::clone(&arc)),
});
}
on_event(AgentEvent::MessageEnd {
message: Message::Assistant(Arc::clone(&arc)),
});
Arc::try_unwrap(arc).unwrap_or_else(|a| (*a).clone())
}
async fn execute_parallel_batch(
&self,
batch: Vec<(usize, ToolCall)>,
on_event: AgentEventHandler,
abort: Option<AbortSignal>,
) -> Vec<(usize, (ToolOutput, bool))> {
let futures = batch.into_iter().map(|(idx, tc)| {
let on_event = Arc::clone(&on_event);
async move { (idx, self.execute_tool_owned(tc, on_event).await) }
});
if let Some(signal) = abort.as_ref() {
use futures::future::{Either, select};
let all_fut = stream::iter(futures)
.buffer_unordered(MAX_CONCURRENT_TOOLS)
.collect::<Vec<_>>()
.fuse();
let abort_fut = signal.wait().fuse();
futures::pin_mut!(all_fut, abort_fut);
match select(all_fut, abort_fut).await {
Either::Left((batch_results, _)) => batch_results,
Either::Right(_) => Vec::new(), }
} else {
stream::iter(futures)
.buffer_unordered(MAX_CONCURRENT_TOOLS)
.collect::<Vec<_>>()
.await
}
}
#[allow(clippy::too_many_lines)]
async fn execute_tool_calls(
&mut self,
tool_calls: &[ToolCall],
on_event: AgentEventHandler,
new_messages: &mut Vec<Message>,
abort: Option<AbortSignal>,
) -> Result<ToolExecutionOutcome> {
let mut results = Vec::new();
let mut steering_messages: Option<Vec<Message>> = None;
for tool_call in tool_calls {
on_event(AgentEvent::ToolExecutionStart {
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
args: tool_call.arguments.clone(),
});
}
let mut pending_parallel: Vec<(usize, ToolCall)> = Vec::new();
let mut tool_outputs: Vec<Option<(ToolOutput, bool)>> = vec![None; tool_calls.len()];
for (index, tool_call) in tool_calls.iter().enumerate() {
if abort.as_ref().is_some_and(AbortSignal::is_aborted) {
break;
}
let is_read_only =
matches!(self.tools.get(&tool_call.name), Some(tool) if tool.is_read_only());
if is_read_only {
pending_parallel.push((index, tool_call.clone()));
} else {
let steering = self.drain_steering_messages().await;
if !steering.is_empty() {
steering_messages = Some(steering);
break;
}
if !pending_parallel.is_empty() {
let batch = std::mem::take(&mut pending_parallel);
let results = self
.execute_parallel_batch(batch, Arc::clone(&on_event), abort.clone())
.await;
for (idx, result) in results {
tool_outputs[idx] = Some(result);
}
}
if abort.as_ref().is_some_and(AbortSignal::is_aborted) {
break;
}
let steering = self.drain_steering_messages().await;
if !steering.is_empty() {
steering_messages = Some(steering);
break;
}
let result = self
.execute_tool(tool_call.clone(), Arc::clone(&on_event))
.await;
tool_outputs[index] = Some(result);
}
}
if !pending_parallel.is_empty()
&& !abort.as_ref().is_some_and(AbortSignal::is_aborted)
&& steering_messages.is_none()
{
let batch = std::mem::take(&mut pending_parallel);
let steering = self.drain_steering_messages().await;
if steering.is_empty() {
let results = self
.execute_parallel_batch(batch, Arc::clone(&on_event), abort.clone())
.await;
for (idx, result) in results {
tool_outputs[idx] = Some(result);
}
} else {
steering_messages = Some(steering);
}
}
for (index, tool_call) in tool_calls.iter().enumerate() {
if steering_messages.is_none() && !abort.as_ref().is_some_and(AbortSignal::is_aborted) {
let steering = self.drain_steering_messages().await;
if !steering.is_empty() {
steering_messages = Some(steering);
}
}
if let Some((output, is_error)) = tool_outputs[index].take() {
let tool_result = Arc::new(ToolResultMessage {
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
content: output.content,
details: output.details,
is_error,
timestamp: Utc::now().timestamp_millis(),
});
on_event(AgentEvent::ToolExecutionEnd {
tool_call_id: tool_result.tool_call_id.clone(),
tool_name: tool_result.tool_name.clone(),
result: ToolOutput {
content: tool_result.content.clone(),
details: tool_result.details.clone(),
is_error,
},
is_error,
});
let msg = Message::ToolResult(Arc::clone(&tool_result));
self.messages.push(msg.clone());
on_event(AgentEvent::MessageStart {
message: msg.clone(),
});
new_messages.push(msg.clone());
on_event(AgentEvent::MessageEnd { message: msg });
results.push(tool_result);
} else if steering_messages.is_some() {
results.push(self.skip_tool_call(tool_call, &on_event, new_messages));
} else {
let output = ToolOutput {
content: vec![ContentBlock::Text(TextContent::new(
"Tool execution aborted",
))],
details: None,
is_error: true,
};
on_event(AgentEvent::ToolExecutionUpdate {
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
args: tool_call.arguments.clone(),
partial_result: ToolOutput {
content: output.content.clone(),
details: output.details.clone(),
is_error: true,
},
});
on_event(AgentEvent::ToolExecutionEnd {
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
result: ToolOutput {
content: output.content.clone(),
details: output.details.clone(),
is_error: true,
},
is_error: true,
});
let tool_result = Arc::new(ToolResultMessage {
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
content: output.content,
details: output.details,
is_error: true,
timestamp: Utc::now().timestamp_millis(),
});
let msg = Message::ToolResult(Arc::clone(&tool_result));
self.messages.push(msg.clone());
on_event(AgentEvent::MessageStart {
message: msg.clone(),
});
let end_msg = msg.clone();
new_messages.push(msg);
on_event(AgentEvent::MessageEnd { message: end_msg });
results.push(tool_result);
}
}
Ok(ToolExecutionOutcome {
tool_results: results,
steering_messages,
})
}
async fn execute_tool(
&self,
tool_call: ToolCall,
on_event: AgentEventHandler,
) -> (ToolOutput, bool) {
let extensions = self.extensions.clone();
let (mut output, is_error) = if let Some(extensions) = &extensions {
match Self::dispatch_tool_call_hook(extensions, &tool_call).await {
Some(blocked_output) => (blocked_output, true),
None => {
self.execute_tool_without_hooks(&tool_call, Arc::clone(&on_event))
.await
}
}
} else {
self.execute_tool_without_hooks(&tool_call, Arc::clone(&on_event))
.await
};
if let Some(extensions) = &extensions {
Self::apply_tool_result_hook(extensions, &tool_call, &mut output, is_error).await;
}
(output, is_error)
}
async fn execute_tool_owned(
&self,
tool_call: ToolCall,
on_event: AgentEventHandler,
) -> (ToolOutput, bool) {
self.execute_tool(tool_call, on_event).await
}
async fn execute_tool_without_hooks(
&self,
tool_call: &ToolCall,
on_event: AgentEventHandler,
) -> (ToolOutput, bool) {
let Some(tool) = self.tools.get(&tool_call.name) else {
return (Self::tool_not_found_output(&tool_call.name), true);
};
let tool_name = tool_call.name.clone();
let tool_id = tool_call.id.clone();
let tool_args = tool_call.arguments.clone();
let on_event = Arc::clone(&on_event);
let update_callback = move |update: ToolUpdate| {
on_event(AgentEvent::ToolExecutionUpdate {
tool_call_id: tool_id.clone(),
tool_name: tool_name.clone(),
args: tool_args.clone(),
partial_result: ToolOutput {
content: update.content,
details: update.details,
is_error: false,
},
});
};
match tool
.execute(
&tool_call.id,
tool_call.arguments.clone(),
Some(Box::new(update_callback)),
)
.await
{
Ok(output) => {
let is_error = output.is_error;
(output, is_error)
}
Err(e) => (
ToolOutput {
content: vec![ContentBlock::Text(TextContent::new(format!("Error: {e}")))],
details: None,
is_error: true,
},
true,
),
}
}
fn tool_not_found_output(tool_name: &str) -> ToolOutput {
ToolOutput {
content: vec![ContentBlock::Text(TextContent::new(format!(
"Error: Tool '{tool_name}' not found"
)))],
details: None,
is_error: true,
}
}
async fn dispatch_tool_call_hook(
extensions: &ExtensionManager,
tool_call: &ToolCall,
) -> Option<ToolOutput> {
match extensions
.dispatch_tool_call(tool_call, EXTENSION_EVENT_TIMEOUT_MS)
.await
{
Ok(Some(result)) if result.block => {
Some(Self::tool_call_blocked_output(result.reason.as_deref()))
}
Ok(_) => None,
Err(err) => {
tracing::warn!("tool_call extension hook failed (fail-open): {err}");
None
}
}
}
fn tool_call_blocked_output(reason: Option<&str>) -> ToolOutput {
let reason = reason.map(str::trim).filter(|reason| !reason.is_empty());
let message = reason.map_or_else(
|| "Tool execution was blocked by an extension".to_string(),
|reason| format!("Tool execution blocked: {reason}"),
);
ToolOutput {
content: vec![ContentBlock::Text(TextContent::new(message))],
details: None,
is_error: true,
}
}
async fn apply_tool_result_hook(
extensions: &ExtensionManager,
tool_call: &ToolCall,
output: &mut ToolOutput,
is_error: bool,
) {
match extensions
.dispatch_tool_result(tool_call, &*output, is_error, EXTENSION_EVENT_TIMEOUT_MS)
.await
{
Ok(Some(result)) => {
if let Some(content) = result.content {
output.content = content;
}
if let Some(details) = result.details {
output.details = Some(details);
}
}
Ok(None) => {}
Err(err) => tracing::warn!("tool_result extension hook failed (fail-open): {err}"),
}
}
fn skip_tool_call(
&mut self,
tool_call: &ToolCall,
on_event: &Arc<dyn Fn(AgentEvent) + Send + Sync>,
new_messages: &mut Vec<Message>,
) -> Arc<ToolResultMessage> {
let output = ToolOutput {
content: vec![ContentBlock::Text(TextContent::new(
"Skipped due to queued user message.",
))],
details: None,
is_error: true,
};
on_event(AgentEvent::ToolExecutionUpdate {
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
args: tool_call.arguments.clone(),
partial_result: output.clone(),
});
on_event(AgentEvent::ToolExecutionEnd {
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
result: output.clone(),
is_error: true,
});
let tool_result = Arc::new(ToolResultMessage {
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
content: output.content,
details: output.details,
is_error: true,
timestamp: Utc::now().timestamp_millis(),
});
let msg = Message::ToolResult(Arc::clone(&tool_result));
self.messages.push(msg.clone());
new_messages.push(msg.clone());
on_event(AgentEvent::MessageStart {
message: msg.clone(),
});
on_event(AgentEvent::MessageEnd { message: msg });
tool_result
}
}
struct ToolExecutionOutcome {
tool_results: Vec<Arc<ToolResultMessage>>,
steering_messages: Option<Vec<Message>>,
}
pub struct PreWarmedExtensionRuntime {
pub manager: ExtensionManager,
pub runtime: ExtensionRuntimeHandle,
pub tools: Arc<ToolRegistry>,
}
pub struct AgentSession {
pub agent: Agent,
pub session: Arc<Mutex<Session>>,
save_enabled: bool,
pub extensions: Option<ExtensionRegion>,
extensions_is_streaming: Arc<AtomicBool>,
compaction_settings: ResolvedCompactionSettings,
compaction_worker: CompactionWorkerState,
model_registry: Option<ModelRegistry>,
auth_storage: Option<AuthStorage>,
}
#[derive(Debug, Default)]
struct ExtensionInjectedQueue {
steering: VecDeque<Message>,
follow_up: VecDeque<Message>,
}
impl ExtensionInjectedQueue {
fn push_steering(&mut self, message: Message) {
self.steering.push_back(message);
}
fn push_follow_up(&mut self, message: Message) {
self.follow_up.push_back(message);
}
fn pop_steering(&mut self) -> Vec<Message> {
self.steering.drain(..).collect()
}
fn pop_follow_up(&mut self) -> Vec<Message> {
self.follow_up.drain(..).collect()
}
}
#[derive(Clone)]
struct AgentSessionHostActions {
session: Arc<Mutex<Session>>,
injected: Arc<StdMutex<ExtensionInjectedQueue>>,
is_streaming: Arc<AtomicBool>,
}
impl AgentSessionHostActions {
fn enqueue(&self, deliver_as: Option<ExtensionDeliverAs>, message: Message) {
let deliver_as = deliver_as.unwrap_or(ExtensionDeliverAs::Steer);
let Ok(mut queue) = self.injected.lock() else {
return;
};
match deliver_as {
ExtensionDeliverAs::FollowUp => {
queue.push_follow_up(message);
}
ExtensionDeliverAs::Steer | ExtensionDeliverAs::NextTurn => {
queue.push_steering(message);
}
}
}
async fn append_to_session(&self, message: Message) -> Result<()> {
let cx = crate::agent_cx::AgentCx::for_request();
let mut session = self
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
session.append_model_message(message);
Ok(())
}
}
#[async_trait]
impl ExtensionHostActions for AgentSessionHostActions {
async fn send_message(&self, message: ExtensionSendMessage) -> Result<()> {
let custom_message = Message::Custom(CustomMessage {
content: message.content,
custom_type: message.custom_type,
display: message.display,
details: message.details,
timestamp: Utc::now().timestamp_millis(),
});
if matches!(message.deliver_as, Some(ExtensionDeliverAs::NextTurn)) {
return self.append_to_session(custom_message).await;
}
if self.is_streaming.load(Ordering::SeqCst) {
self.enqueue(message.deliver_as, custom_message);
return Ok(());
}
let _ = message.trigger_turn;
self.append_to_session(custom_message).await
}
async fn send_user_message(&self, message: ExtensionSendUserMessage) -> Result<()> {
let user_message = Message::User(UserMessage {
content: UserContent::Text(message.text),
timestamp: Utc::now().timestamp_millis(),
});
if self.is_streaming.load(Ordering::SeqCst) {
self.enqueue(message.deliver_as, user_message);
return Ok(());
}
self.append_to_session(user_message).await
}
}
#[cfg(test)]
mod message_queue_tests {
use super::*;
fn user_message(text: &str) -> Message {
Message::User(UserMessage {
content: UserContent::Text(text.to_string()),
timestamp: 0,
})
}
#[test]
fn message_queue_one_at_a_time() {
let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
queue.push_steering(user_message("a"));
queue.push_steering(user_message("b"));
let first = queue.pop_steering();
assert_eq!(first.len(), 1);
assert!(matches!(
first.first(),
Some(Message::User(UserMessage { content, .. }))
if matches!(content, UserContent::Text(text) if text == "a")
));
let second = queue.pop_steering();
assert_eq!(second.len(), 1);
assert!(matches!(
second.first(),
Some(Message::User(UserMessage { content, .. }))
if matches!(content, UserContent::Text(text) if text == "b")
));
assert!(queue.pop_steering().is_empty());
}
#[test]
fn message_queue_all_mode() {
let mut queue = MessageQueue::new(QueueMode::All, QueueMode::OneAtATime);
queue.push_steering(user_message("a"));
queue.push_steering(user_message("b"));
let drained = queue.pop_steering();
assert_eq!(drained.len(), 2);
assert!(queue.pop_steering().is_empty());
}
#[test]
fn message_queue_separates_kinds() {
let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
queue.push_steering(user_message("steer"));
queue.push_follow_up(user_message("follow"));
let steering = queue.pop_steering();
assert_eq!(steering.len(), 1);
assert_eq!(queue.pending_count(), 1);
let follow = queue.pop_follow_up();
assert_eq!(follow.len(), 1);
assert_eq!(queue.pending_count(), 0);
}
#[test]
fn message_queue_seq_increments() {
let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
let first = queue.push_steering(user_message("a"));
let second = queue.push_follow_up(user_message("b"));
assert!(second > first);
}
#[test]
fn message_queue_seq_saturates_at_u64_max() {
let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
queue.next_seq = u64::MAX;
let first = queue.push_steering(user_message("a"));
let second = queue.push_follow_up(user_message("b"));
assert_eq!(first, u64::MAX);
assert_eq!(second, u64::MAX);
assert_eq!(queue.pending_count(), 2);
}
#[test]
fn message_queue_follow_up_all_mode_drains_entire_queue_in_order() {
let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::All);
queue.push_follow_up(user_message("f1"));
queue.push_follow_up(user_message("f2"));
let follow_up = queue.pop_follow_up();
assert_eq!(follow_up.len(), 2);
assert!(matches!(
follow_up.first(),
Some(Message::User(UserMessage { content, .. }))
if matches!(content, UserContent::Text(text) if text == "f1")
));
assert!(matches!(
follow_up.get(1),
Some(Message::User(UserMessage { content, .. }))
if matches!(content, UserContent::Text(text) if text == "f2")
));
assert!(queue.pop_follow_up().is_empty());
}
}
#[cfg(test)]
mod extensions_integration_tests {
use super::*;
use crate::session::Session;
use asupersync::runtime::RuntimeBuilder;
use async_trait::async_trait;
use futures::Stream;
use serde_json::json;
use std::path::Path;
use std::pin::Pin;
use std::sync::atomic::AtomicUsize;
#[derive(Debug)]
struct NoopProvider;
#[async_trait]
#[allow(clippy::unnecessary_literal_bound)]
impl Provider for NoopProvider {
fn name(&self) -> &str {
"test-provider"
}
fn api(&self) -> &str {
"test-api"
}
fn model_id(&self) -> &str {
"test-model"
}
async fn stream(
&self,
_context: &Context<'_>,
_options: &StreamOptions,
) -> crate::error::Result<
Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
> {
Ok(Box::pin(futures::stream::empty()))
}
}
#[derive(Debug)]
struct CountingTool {
calls: Arc<AtomicUsize>,
}
#[async_trait]
#[allow(clippy::unnecessary_literal_bound)]
impl Tool for CountingTool {
fn name(&self) -> &str {
"count_tool"
}
fn label(&self) -> &str {
"count_tool"
}
fn description(&self) -> &str {
"counting tool"
}
fn parameters(&self) -> serde_json::Value {
json!({ "type": "object" })
}
async fn execute(
&self,
_tool_call_id: &str,
_input: serde_json::Value,
_on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
) -> Result<ToolOutput> {
self.calls.fetch_add(1, Ordering::SeqCst);
Ok(ToolOutput {
content: vec![ContentBlock::Text(TextContent::new("ok"))],
details: None,
is_error: false,
})
}
}
#[derive(Debug)]
struct ToolUseProvider {
stream_calls: AtomicUsize,
}
impl ToolUseProvider {
const fn new() -> Self {
Self {
stream_calls: AtomicUsize::new(0),
}
}
fn assistant_message(
&self,
stop_reason: StopReason,
content: Vec<ContentBlock>,
) -> AssistantMessage {
AssistantMessage {
content,
api: self.api().to_string(),
provider: self.name().to_string(),
model: self.model_id().to_string(),
usage: Usage::default(),
stop_reason,
error_message: None,
timestamp: 0,
}
}
}
#[async_trait]
#[allow(clippy::unnecessary_literal_bound)]
impl Provider for ToolUseProvider {
fn name(&self) -> &str {
"test-provider"
}
fn api(&self) -> &str {
"test-api"
}
fn model_id(&self) -> &str {
"test-model"
}
async fn stream(
&self,
_context: &Context<'_>,
_options: &StreamOptions,
) -> crate::error::Result<
Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
> {
let call_index = self.stream_calls.fetch_add(1, Ordering::SeqCst);
let partial = self.assistant_message(StopReason::Stop, Vec::new());
let (reason, message) = if call_index == 0 {
let tool_calls = vec![
ToolCall {
id: "call-1".to_string(),
name: "count_tool".to_string(),
arguments: json!({}),
thought_signature: None,
},
ToolCall {
id: "call-2".to_string(),
name: "count_tool".to_string(),
arguments: json!({}),
thought_signature: None,
},
];
(
StopReason::ToolUse,
self.assistant_message(
StopReason::ToolUse,
tool_calls
.into_iter()
.map(ContentBlock::ToolCall)
.collect::<Vec<_>>(),
),
)
} else {
(
StopReason::Stop,
self.assistant_message(
StopReason::Stop,
vec![ContentBlock::Text(TextContent::new("done"))],
),
)
};
let events = vec![
Ok(StreamEvent::Start { partial }),
Ok(StreamEvent::Done { reason, message }),
];
Ok(Box::pin(futures::stream::iter(events)))
}
}
#[test]
fn agent_session_enable_extensions_registers_extension_tools() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let temp_dir = tempfile::tempdir().expect("tempdir");
let entry_path = temp_dir.path().join("ext.mjs");
std::fs::write(
&entry_path,
r#"
export default function init(pi) {
pi.registerTool({
name: "hello_tool",
label: "hello_tool",
description: "test tool",
parameters: { type: "object", properties: { name: { type: "string" } } },
execute: async (_callId, input, _onUpdate, _abort, ctx) => {
const who = input && input.name ? String(input.name) : "world";
const cwd = ctx && ctx.cwd ? String(ctx.cwd) : "";
return {
content: [{ type: "text", text: `hello ${who}` }],
details: { from: "extension", cwd: cwd },
isError: false
};
}
});
}
"#,
)
.expect("write extension entry");
let provider = Arc::new(NoopProvider);
let tools = ToolRegistry::new(&[], Path::new("."), None);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
agent_session
.enable_extensions(&[], temp_dir.path(), None, &[entry_path])
.await
.expect("enable extensions");
let tool = agent_session
.agent
.tools
.get("hello_tool")
.expect("hello_tool registered");
let output = tool
.execute("call-1", json!({ "name": "pi" }), None)
.await
.expect("execute tool");
assert!(!output.is_error);
assert!(
matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
"Expected single text content block, got {:?}",
output.content
);
let [ContentBlock::Text(text)] = output.content.as_slice() else {
return;
};
assert_eq!(text.text, "hello pi");
let details = output.details.expect("details present");
assert_eq!(
details.get("from").and_then(serde_json::Value::as_str),
Some("extension")
);
});
}
#[test]
fn agent_session_enable_extensions_rejects_mixed_js_and_native_entries() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let temp_dir = tempfile::tempdir().expect("tempdir");
let js_entry = temp_dir.path().join("ext.mjs");
let native_entry = temp_dir.path().join("ext.native.json");
std::fs::write(
&js_entry,
r"
export default function init(_pi) {}
",
)
.expect("write js extension entry");
std::fs::write(&native_entry, "{}").expect("write native extension descriptor");
let provider = Arc::new(NoopProvider);
let tools = ToolRegistry::new(&[], Path::new("."), None);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
let err = agent_session
.enable_extensions(&[], temp_dir.path(), None, &[js_entry, native_entry])
.await
.expect_err("mixed extension runtimes should be rejected");
let msg = err.to_string();
assert!(
msg.contains("Mixed extension runtimes are not supported"),
"unexpected mixed-runtime error message: {msg}"
);
});
}
#[test]
fn extension_send_message_persists_custom_message_entry_when_idle() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let temp_dir = tempfile::tempdir().expect("tempdir");
let entry_path = temp_dir.path().join("ext.mjs");
std::fs::write(
&entry_path,
r#"
export default function init(pi) {
pi.registerTool({
name: "emit_message",
label: "emit_message",
description: "emit a custom message",
parameters: { type: "object" },
execute: async () => {
pi.sendMessage({
customType: "note",
content: "hello",
display: true,
details: { from: "test" }
}, {});
return { content: [{ type: "text", text: "ok" }], isError: false };
}
});
}
"#,
)
.expect("write extension entry");
let provider = Arc::new(NoopProvider);
let tools = ToolRegistry::new(&[], Path::new("."), None);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session = AgentSession::new(
agent,
Arc::clone(&session),
false,
ResolvedCompactionSettings::default(),
);
agent_session
.enable_extensions(&[], temp_dir.path(), None, &[entry_path])
.await
.expect("enable extensions");
let tool = agent_session
.agent
.tools
.get("emit_message")
.expect("emit_message registered");
let _ = tool
.execute("call-1", json!({}), None)
.await
.expect("execute tool");
let cx = crate::agent_cx::AgentCx::for_request();
let session_guard = session
.lock(cx.cx())
.await
.expect("lock session");
let messages = session_guard.to_messages_for_current_path();
assert!(
messages.iter().any(|msg| {
matches!(
msg,
Message::Custom(CustomMessage { custom_type, content, display, details, .. })
if custom_type == "note"
&& content == "hello"
&& *display
&& details.as_ref().and_then(|v| v.get("from").and_then(Value::as_str)) == Some("test")
)
}),
"expected custom message to be persisted, got {messages:?}"
);
});
}
#[test]
fn extension_send_message_persists_custom_message_entry_when_idle_after_await() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let temp_dir = tempfile::tempdir().expect("tempdir");
let entry_path = temp_dir.path().join("ext.mjs");
std::fs::write(
&entry_path,
r#"
export default function init(pi) {
pi.registerTool({
name: "emit_message",
label: "emit_message",
description: "emit a custom message",
parameters: { type: "object" },
execute: async () => {
await Promise.resolve();
pi.sendMessage({
customType: "note",
content: "hello-after-await",
display: true,
details: { from: "test" }
}, {});
return { content: [{ type: "text", text: "ok" }], isError: false };
}
});
}
"#,
)
.expect("write extension entry");
let provider = Arc::new(NoopProvider);
let tools = ToolRegistry::new(&[], Path::new("."), None);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session = AgentSession::new(
agent,
Arc::clone(&session),
false,
ResolvedCompactionSettings::default(),
);
agent_session
.enable_extensions(&[], temp_dir.path(), None, &[entry_path])
.await
.expect("enable extensions");
let tool = agent_session
.agent
.tools
.get("emit_message")
.expect("emit_message registered");
let _ = tool
.execute("call-1", json!({}), None)
.await
.expect("execute tool");
let cx = crate::agent_cx::AgentCx::for_request();
let session_guard = session
.lock(cx.cx())
.await
.expect("lock session");
let messages = session_guard.to_messages_for_current_path();
assert!(
messages.iter().any(|msg| {
matches!(
msg,
Message::Custom(CustomMessage { custom_type, content, display, details, .. })
if custom_type == "note"
&& content == "hello-after-await"
&& *display
&& details.as_ref().and_then(|v| v.get("from").and_then(Value::as_str)) == Some("test")
)
}),
"expected custom message to be persisted, got {messages:?}"
);
});
}
#[test]
fn send_user_message_steer_skips_remaining_tools() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let temp_dir = tempfile::tempdir().expect("tempdir");
let entry_path = temp_dir.path().join("ext.mjs");
std::fs::write(
&entry_path,
r#"
export default function init(pi) {
let sent = false;
pi.on("tool_call", async (event) => {
if (sent) return {};
if (event && event.toolName === "count_tool") {
sent = true;
await pi.events("sendUserMessage", {
text: "steer-now",
options: { deliverAs: "steer" }
});
}
return {};
});
}
"#,
)
.expect("write extension entry");
let provider = Arc::new(ToolUseProvider::new());
let calls = Arc::new(AtomicUsize::new(0));
let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
calls: Arc::clone(&calls),
})]);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
agent_session
.enable_extensions(&[], temp_dir.path(), None, &[entry_path])
.await
.expect("enable extensions");
let _ = agent_session
.run_text("go".to_string(), |_| {})
.await
.expect("run_text");
assert_eq!(calls.load(Ordering::SeqCst), 1);
});
}
#[test]
fn send_user_message_follow_up_does_not_skip_tools() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let temp_dir = tempfile::tempdir().expect("tempdir");
let entry_path = temp_dir.path().join("ext.mjs");
std::fs::write(
&entry_path,
r#"
export default function init(pi) {
let sent = false;
pi.on("tool_call", async (event) => {
if (sent) return {};
if (event && event.toolName === "count_tool") {
sent = true;
await pi.events("sendUserMessage", {
text: "follow-up",
options: { deliverAs: "followUp" }
});
}
return {};
});
}
"#,
)
.expect("write extension entry");
let provider = Arc::new(ToolUseProvider::new());
let calls = Arc::new(AtomicUsize::new(0));
let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
calls: Arc::clone(&calls),
})]);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
agent_session
.enable_extensions(&[], temp_dir.path(), None, &[entry_path])
.await
.expect("enable extensions");
let _ = agent_session
.run_text("go".to_string(), |_| {})
.await
.expect("run_text");
assert_eq!(calls.load(Ordering::SeqCst), 2);
});
}
#[test]
fn tool_call_hook_can_block_tool_execution() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let temp_dir = tempfile::tempdir().expect("tempdir");
let entry_path = temp_dir.path().join("ext.mjs");
std::fs::write(
&entry_path,
r#"
export default function init(pi) {
pi.on("tool_call", async (event) => {
if (event && event.toolName === "count_tool") {
return { block: true, reason: "blocked in test" };
}
return {};
});
}
"#,
)
.expect("write extension entry");
let provider = Arc::new(NoopProvider);
let calls = Arc::new(AtomicUsize::new(0));
let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
calls: Arc::clone(&calls),
})]);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
agent_session
.enable_extensions(&[], temp_dir.path(), None, &[entry_path])
.await
.expect("enable extensions");
let tool_call = ToolCall {
id: "call-1".to_string(),
name: "count_tool".to_string(),
arguments: json!({}),
thought_signature: None,
};
let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
assert!(is_error);
assert!(output.is_error);
assert_eq!(calls.load(Ordering::SeqCst), 0);
assert_eq!(output.details, None);
assert!(
matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
"Expected text output, got {:?}",
output.content
);
if let [ContentBlock::Text(text)] = output.content.as_slice() {
assert_eq!(text.text, "Tool execution blocked: blocked in test");
}
});
}
#[test]
fn tool_call_hook_errors_fail_open() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let temp_dir = tempfile::tempdir().expect("tempdir");
let entry_path = temp_dir.path().join("ext.mjs");
std::fs::write(
&entry_path,
r#"
export default function init(pi) {
pi.on("tool_call", async (_event) => {
throw new Error("boom");
});
}
"#,
)
.expect("write extension entry");
let provider = Arc::new(NoopProvider);
let calls = Arc::new(AtomicUsize::new(0));
let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
calls: Arc::clone(&calls),
})]);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
agent_session
.enable_extensions(&[], temp_dir.path(), None, &[entry_path])
.await
.expect("enable extensions");
let tool_call = ToolCall {
id: "call-1".to_string(),
name: "count_tool".to_string(),
arguments: json!({}),
thought_signature: None,
};
let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
assert!(!is_error);
assert!(!output.is_error);
assert_eq!(calls.load(Ordering::SeqCst), 1);
});
}
#[test]
fn tool_call_hook_absent_allows_tool_execution() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let temp_dir = tempfile::tempdir().expect("tempdir");
let entry_path = temp_dir.path().join("ext.mjs");
std::fs::write(
&entry_path,
r"
export default function init(_pi) {}
",
)
.expect("write extension entry");
let provider = Arc::new(NoopProvider);
let calls = Arc::new(AtomicUsize::new(0));
let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
calls: Arc::clone(&calls),
})]);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
agent_session
.enable_extensions(&[], temp_dir.path(), None, &[entry_path])
.await
.expect("enable extensions");
let tool_call = ToolCall {
id: "call-1".to_string(),
name: "count_tool".to_string(),
arguments: json!({}),
thought_signature: None,
};
let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
assert!(!is_error);
assert!(!output.is_error);
assert_eq!(calls.load(Ordering::SeqCst), 1);
});
}
#[test]
fn tool_call_hook_returns_empty_allows_tool_execution() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let temp_dir = tempfile::tempdir().expect("tempdir");
let entry_path = temp_dir.path().join("ext.mjs");
std::fs::write(
&entry_path,
r#"
export default function init(pi) {
pi.on("tool_call", async (_event) => ({}));
}
"#,
)
.expect("write extension entry");
let provider = Arc::new(NoopProvider);
let calls = Arc::new(AtomicUsize::new(0));
let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
calls: Arc::clone(&calls),
})]);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
agent_session
.enable_extensions(&[], temp_dir.path(), None, &[entry_path])
.await
.expect("enable extensions");
let tool_call = ToolCall {
id: "call-1".to_string(),
name: "count_tool".to_string(),
arguments: json!({}),
thought_signature: None,
};
let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
assert!(!is_error);
assert!(!output.is_error);
assert_eq!(calls.load(Ordering::SeqCst), 1);
});
}
#[test]
fn tool_call_hook_can_block_bash_tool_execution() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let temp_dir = tempfile::tempdir().expect("tempdir");
let entry_path = temp_dir.path().join("ext.mjs");
std::fs::write(
&entry_path,
r#"
export default function init(pi) {
pi.on("tool_call", async (event) => {
const name = event && event.toolName ? String(event.toolName) : "";
if (name === "bash") return { block: true, reason: "blocked bash in test" };
return {};
});
}
"#,
)
.expect("write extension entry");
let provider = Arc::new(NoopProvider);
let tools = ToolRegistry::new(&["bash"], temp_dir.path(), None);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
agent_session
.enable_extensions(&["bash"], temp_dir.path(), None, &[entry_path])
.await
.expect("enable extensions");
let tool_call = ToolCall {
id: "call-1".to_string(),
name: "bash".to_string(),
arguments: json!({ "command": "printf 'hi' > blocked.txt" }),
thought_signature: None,
};
let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
assert!(is_error);
assert!(output.is_error);
assert_eq!(output.details, None);
assert!(
!temp_dir.path().join("blocked.txt").exists(),
"expected bash command not to run when blocked"
);
assert!(
matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
"Expected text output, got {:?}",
output.content
);
if let [ContentBlock::Text(text)] = output.content.as_slice() {
assert_eq!(text.text, "Tool execution blocked: blocked bash in test");
}
});
}
#[test]
fn tool_result_hook_can_modify_tool_output() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let temp_dir = tempfile::tempdir().expect("tempdir");
let entry_path = temp_dir.path().join("ext.mjs");
std::fs::write(
&entry_path,
r#"
export default function init(pi) {
pi.on("tool_result", async (event) => {
if (event && event.toolName === "count_tool") {
return {
content: [{ type: "text", text: "modified" }],
details: { from: "tool_result" }
};
}
return {};
});
}
"#,
)
.expect("write extension entry");
let provider = Arc::new(NoopProvider);
let calls = Arc::new(AtomicUsize::new(0));
let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
calls: Arc::clone(&calls),
})]);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
agent_session
.enable_extensions(&[], temp_dir.path(), None, &[entry_path])
.await
.expect("enable extensions");
let tool_call = ToolCall {
id: "call-1".to_string(),
name: "count_tool".to_string(),
arguments: json!({}),
thought_signature: None,
};
let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
assert!(!is_error);
assert!(!output.is_error);
assert_eq!(calls.load(Ordering::SeqCst), 1);
assert_eq!(output.details, Some(json!({ "from": "tool_result" })));
assert!(
matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
"Expected text output, got {:?}",
output.content
);
if let [ContentBlock::Text(text)] = output.content.as_slice() {
assert_eq!(text.text, "modified");
}
});
}
#[test]
fn tool_result_hook_can_modify_tool_not_found_error() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let temp_dir = tempfile::tempdir().expect("tempdir");
let entry_path = temp_dir.path().join("ext.mjs");
std::fs::write(
&entry_path,
r#"
export default function init(pi) {
pi.on("tool_result", async (event) => {
if (event && event.toolName === "missing_tool" && event.isError) {
return {
content: [{ type: "text", text: "overridden" }],
details: { handled: true }
};
}
return {};
});
}
"#,
)
.expect("write extension entry");
let provider = Arc::new(NoopProvider);
let tools = ToolRegistry::from_tools(Vec::new());
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
agent_session
.enable_extensions(&[], temp_dir.path(), None, &[entry_path])
.await
.expect("enable extensions");
let tool_call = ToolCall {
id: "call-1".to_string(),
name: "missing_tool".to_string(),
arguments: json!({}),
thought_signature: None,
};
let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
assert!(is_error);
assert!(output.is_error);
assert_eq!(output.details, Some(json!({ "handled": true })));
assert!(
matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
"Expected text output, got {:?}",
output.content
);
if let [ContentBlock::Text(text)] = output.content.as_slice() {
assert_eq!(text.text, "overridden");
}
});
}
#[test]
fn tool_result_hook_errors_fail_open() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let temp_dir = tempfile::tempdir().expect("tempdir");
let entry_path = temp_dir.path().join("ext.mjs");
std::fs::write(
&entry_path,
r#"
export default function init(pi) {
pi.on("tool_result", async (_event) => {
throw new Error("boom");
});
}
"#,
)
.expect("write extension entry");
let provider = Arc::new(NoopProvider);
let calls = Arc::new(AtomicUsize::new(0));
let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
calls: Arc::clone(&calls),
})]);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
agent_session
.enable_extensions(&[], temp_dir.path(), None, &[entry_path])
.await
.expect("enable extensions");
let tool_call = ToolCall {
id: "call-1".to_string(),
name: "count_tool".to_string(),
arguments: json!({}),
thought_signature: None,
};
let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
assert!(!is_error);
assert!(!output.is_error);
assert_eq!(calls.load(Ordering::SeqCst), 1);
assert_eq!(output.details, None);
assert!(
matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
"Expected text output, got {:?}",
output.content
);
if let [ContentBlock::Text(text)] = output.content.as_slice() {
assert_eq!(text.text, "ok");
}
});
}
#[test]
fn tool_result_hook_runs_on_blocked_tool_call() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
runtime.block_on(async {
let temp_dir = tempfile::tempdir().expect("tempdir");
let entry_path = temp_dir.path().join("ext.mjs");
std::fs::write(
&entry_path,
r#"
export default function init(pi) {
pi.on("tool_call", async (event) => {
if (event && event.toolName === "count_tool") {
return { block: true, reason: "blocked in test" };
}
return {};
});
pi.on("tool_result", async (event) => {
if (event && event.toolName === "count_tool" && event.isError) {
return { content: [{ type: "text", text: "override" }] };
}
return {};
});
}
"#,
)
.expect("write extension entry");
let provider = Arc::new(NoopProvider);
let calls = Arc::new(AtomicUsize::new(0));
let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
calls: Arc::clone(&calls),
})]);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
agent_session
.enable_extensions(&[], temp_dir.path(), None, &[entry_path])
.await
.expect("enable extensions");
let tool_call = ToolCall {
id: "call-1".to_string(),
name: "count_tool".to_string(),
arguments: json!({}),
thought_signature: None,
};
let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
assert!(is_error);
assert!(output.is_error);
assert_eq!(calls.load(Ordering::SeqCst), 0);
assert!(
matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
"Expected text output, got {:?}",
output.content
);
if let [ContentBlock::Text(text)] = output.content.as_slice() {
assert_eq!(text.text, "override");
}
});
}
}
#[cfg(test)]
mod abort_tests {
use super::*;
use crate::session::Session;
use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
use asupersync::runtime::RuntimeBuilder;
use async_trait::async_trait;
use futures::Stream;
use serde_json::json;
use std::path::Path;
use std::pin::Pin;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::AtomicUsize;
use std::task::{Context as TaskContext, Poll};
struct StartThenPending {
start: Option<StreamEvent>,
}
impl Stream for StartThenPending {
type Item = crate::error::Result<StreamEvent>;
fn poll_next(
mut self: Pin<&mut Self>,
_cx: &mut TaskContext<'_>,
) -> Poll<Option<Self::Item>> {
if let Some(event) = self.start.take() {
return Poll::Ready(Some(Ok(event)));
}
Poll::Pending
}
}
#[derive(Debug)]
struct HangingProvider;
#[async_trait]
#[allow(clippy::unnecessary_literal_bound)]
impl Provider for HangingProvider {
fn name(&self) -> &str {
"test-provider"
}
fn api(&self) -> &str {
"test-api"
}
fn model_id(&self) -> &str {
"test-model"
}
async fn stream(
&self,
_context: &Context<'_>,
_options: &StreamOptions,
) -> crate::error::Result<
Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
> {
let partial = AssistantMessage {
content: Vec::new(),
api: self.api().to_string(),
provider: self.name().to_string(),
model: self.model_id().to_string(),
usage: Usage::default(),
stop_reason: StopReason::Stop,
error_message: None,
timestamp: 0,
};
Ok(Box::pin(StartThenPending {
start: Some(StreamEvent::Start { partial }),
}))
}
}
#[derive(Debug)]
struct CountingProvider {
calls: Arc<std::sync::atomic::AtomicUsize>,
}
#[async_trait]
#[allow(clippy::unnecessary_literal_bound)]
impl Provider for CountingProvider {
fn name(&self) -> &str {
"test-provider"
}
fn api(&self) -> &str {
"test-api"
}
fn model_id(&self) -> &str {
"test-model"
}
async fn stream(
&self,
_context: &Context<'_>,
_options: &StreamOptions,
) -> crate::error::Result<
Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
> {
self.calls.fetch_add(1, Ordering::SeqCst);
Ok(Box::pin(futures::stream::empty()))
}
}
#[derive(Debug)]
struct PhasedProvider {
pending_calls: usize,
calls: AtomicUsize,
}
impl PhasedProvider {
const fn new(pending_calls: usize) -> Self {
Self {
pending_calls,
calls: AtomicUsize::new(0),
}
}
fn base_message() -> AssistantMessage {
AssistantMessage {
content: Vec::new(),
api: "test-api".to_string(),
provider: "test-provider".to_string(),
model: "test-model".to_string(),
usage: Usage::default(),
stop_reason: StopReason::Stop,
error_message: None,
timestamp: 0,
}
}
}
#[async_trait]
#[allow(clippy::unnecessary_literal_bound)]
impl Provider for PhasedProvider {
fn name(&self) -> &str {
"test-provider"
}
fn api(&self) -> &str {
"test-api"
}
fn model_id(&self) -> &str {
"test-model"
}
async fn stream(
&self,
_context: &Context<'_>,
_options: &StreamOptions,
) -> crate::error::Result<
Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
> {
let call = self.calls.fetch_add(1, Ordering::SeqCst);
if call < self.pending_calls {
return Ok(Box::pin(StartThenPending {
start: Some(StreamEvent::Start {
partial: Self::base_message(),
}),
}));
}
let partial = Self::base_message();
let mut done = Self::base_message();
done.content = vec![ContentBlock::Text(TextContent::new(format!(
"resumed-response-{call}"
)))];
Ok(Box::pin(futures::stream::iter(vec![
Ok(StreamEvent::Start { partial }),
Ok(StreamEvent::Done {
reason: StopReason::Stop,
message: done,
}),
])))
}
}
#[derive(Debug)]
struct ToolCallProvider;
#[async_trait]
#[allow(clippy::unnecessary_literal_bound)]
impl Provider for ToolCallProvider {
fn name(&self) -> &str {
"test-provider"
}
fn api(&self) -> &str {
"test-api"
}
fn model_id(&self) -> &str {
"test-model"
}
async fn stream(
&self,
_context: &Context<'_>,
_options: &StreamOptions,
) -> crate::error::Result<
Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
> {
let message = AssistantMessage {
content: vec![ContentBlock::ToolCall(ToolCall {
id: "call-1".to_string(),
name: "hanging_tool".to_string(),
arguments: json!({}),
thought_signature: None,
})],
api: "test-api".to_string(),
provider: "test-provider".to_string(),
model: "test-model".to_string(),
usage: Usage::default(),
stop_reason: StopReason::ToolUse,
error_message: None,
timestamp: 0,
};
Ok(Box::pin(futures::stream::iter(vec![Ok(
StreamEvent::Done {
reason: StopReason::ToolUse,
message,
},
)])))
}
}
#[derive(Debug)]
struct HangingTool;
#[async_trait]
#[allow(clippy::unnecessary_literal_bound)]
impl Tool for HangingTool {
fn name(&self) -> &str {
"hanging_tool"
}
fn label(&self) -> &str {
"Hanging Tool"
}
fn description(&self) -> &str {
"Never completes unless aborted by the host"
}
fn parameters(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {},
"additionalProperties": false
})
}
async fn execute(
&self,
_tool_call_id: &str,
_input: serde_json::Value,
_on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
) -> crate::error::Result<ToolOutput> {
futures::future::pending::<()>().await;
unreachable!("hanging tool should be aborted by the agent")
}
}
fn event_tag(event: &AgentEvent) -> &'static str {
match event {
AgentEvent::AgentStart { .. } => "agent_start",
AgentEvent::AgentEnd { error, .. } => {
if error.as_deref() == Some("Aborted") {
"agent_end_aborted"
} else {
"agent_end"
}
}
AgentEvent::TurnStart { .. } => "turn_start",
AgentEvent::TurnEnd { .. } => "turn_end",
AgentEvent::MessageStart { .. } => "message_start",
AgentEvent::MessageUpdate {
assistant_message_event,
..
} => match &assistant_message_event {
AssistantMessageEvent::Error {
reason: StopReason::Aborted,
..
} => "assistant_error_aborted",
AssistantMessageEvent::Done { .. } => "assistant_done",
_ => "assistant_update",
},
AgentEvent::MessageEnd { .. } => "message_end",
AgentEvent::ToolExecutionStart { .. } => "tool_start",
AgentEvent::ToolExecutionUpdate { .. } => "tool_update",
AgentEvent::ToolExecutionEnd { .. } => "tool_end",
AgentEvent::AutoCompactionStart { .. } => "auto_compaction_start",
AgentEvent::AutoCompactionEnd { .. } => "auto_compaction_end",
AgentEvent::AutoRetryStart { .. } => "auto_retry_start",
AgentEvent::AutoRetryEnd { .. } => "auto_retry_end",
AgentEvent::ExtensionError { .. } => "extension_error",
}
}
fn assert_abort_resume_message_sequence(persisted: &[Message]) {
assert_eq!(
persisted.len(),
6,
"expected three user+assistant pairs, got: {persisted:?}"
);
let assistant_states = persisted
.iter()
.filter_map(|message| match message {
Message::Assistant(assistant) => Some(assistant.stop_reason),
_ => None,
})
.collect::<Vec<_>>();
assert_eq!(
assistant_states,
vec![StopReason::Aborted, StopReason::Aborted, StopReason::Stop]
);
}
fn assert_abort_resume_timeline_boundaries(timeline: &[String]) {
assert!(
timeline
.iter()
.any(|event| event == "run0:agent_end_aborted"),
"missing aborted boundary for first run: {timeline:?}"
);
assert!(
timeline
.iter()
.any(|event| event == "run1:agent_end_aborted"),
"missing aborted boundary for second run: {timeline:?}"
);
assert!(
timeline.iter().any(|event| event == "run2:agent_end"),
"missing successful boundary for resumed run: {timeline:?}"
);
}
#[test]
fn abort_interrupts_in_flight_stream() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
let handle = runtime.handle();
let started = Arc::new(Notify::new());
let started_wait = started.notified();
let (abort_handle, abort_signal) = AbortHandle::new();
let provider = Arc::new(HangingProvider);
let tools = ToolRegistry::new(&[], Path::new("."), None);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
let started_tx = Arc::clone(&started);
let join = handle.spawn(async move {
agent_session
.run_text_with_abort("hello".to_string(), Some(abort_signal), move |event| {
if matches!(
event,
AgentEvent::MessageStart {
message: Message::Assistant(_)
}
) {
started_tx.notify_one();
}
})
.await
});
runtime.block_on(async move {
started_wait.await;
abort_handle.abort();
let message = join.await.expect("run_text_with_abort");
assert_eq!(message.stop_reason, StopReason::Aborted);
assert_eq!(message.error_message.as_deref(), Some("Aborted"));
});
}
#[test]
fn abort_before_run_skips_provider_stream_call() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
let calls = Arc::new(std::sync::atomic::AtomicUsize::new(0));
let provider = Arc::new(CountingProvider {
calls: Arc::clone(&calls),
});
let tools = ToolRegistry::new(&[], Path::new("."), None);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
let (abort_handle, abort_signal) = AbortHandle::new();
abort_handle.abort();
runtime.block_on(async move {
let message = agent_session
.run_text_with_abort("hello".to_string(), Some(abort_signal), |_| {})
.await
.expect("run_text_with_abort");
assert_eq!(message.stop_reason, StopReason::Aborted);
assert_eq!(calls.load(Ordering::SeqCst), 0);
});
}
#[test]
fn abort_then_resume_preserves_session_history() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
let handle = runtime.handle();
runtime.block_on(async move {
let provider = Arc::new(PhasedProvider::new(1));
let tools = ToolRegistry::new(&[], Path::new("."), None);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session = AgentSession::new(
agent,
Arc::clone(&session),
false,
ResolvedCompactionSettings::default(),
);
let started = Arc::new(Notify::new());
let (abort_handle, abort_signal) = AbortHandle::new();
let started_for_abort = Arc::clone(&started);
let abort_join = handle.spawn(async move {
started_for_abort.notified().await;
abort_handle.abort();
});
let aborted = agent_session
.run_text_with_abort("first".to_string(), Some(abort_signal), {
let started = Arc::clone(&started);
move |event| {
if matches!(
event,
AgentEvent::MessageStart {
message: Message::Assistant(_)
}
) {
started.notify_one();
}
}
})
.await
.expect("first run");
abort_join.await;
assert_eq!(aborted.stop_reason, StopReason::Aborted);
assert_eq!(aborted.error_message.as_deref(), Some("Aborted"));
let resumed = agent_session
.run_text("second".to_string(), |_| {})
.await
.expect("resumed run");
assert_eq!(resumed.stop_reason, StopReason::Stop);
assert!(resumed.error_message.is_none());
let cx = crate::agent_cx::AgentCx::for_request();
let persisted = session
.lock(cx.cx())
.await
.expect("lock session")
.to_messages_for_current_path();
assert_eq!(
persisted.len(),
4,
"unexpected message history after abort+resume: {persisted:?}"
);
assert!(matches!(persisted.first(), Some(Message::User(_))));
assert!(matches!(
persisted.get(1),
Some(Message::Assistant(assistant)) if assistant.stop_reason == StopReason::Aborted
));
assert!(matches!(persisted.get(2), Some(Message::User(_))));
assert!(matches!(
persisted.get(3),
Some(Message::Assistant(assistant))
if assistant.stop_reason == StopReason::Stop && assistant.error_message.is_none()
));
});
}
#[test]
fn repeated_abort_then_resume_has_consistent_timeline_and_state() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
let handle = runtime.handle();
runtime.block_on(async move {
let provider = Arc::new(PhasedProvider::new(2));
let tools = ToolRegistry::new(&[], Path::new("."), None);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session = AgentSession::new(
agent,
Arc::clone(&session),
false,
ResolvedCompactionSettings::default(),
);
let timeline = Arc::new(StdMutex::new(Vec::<String>::new()));
for run_idx in 0..2 {
let started = Arc::new(Notify::new());
let (abort_handle, abort_signal) = AbortHandle::new();
let started_for_abort = Arc::clone(&started);
let abort_join = handle.spawn(async move {
started_for_abort.notified().await;
abort_handle.abort();
});
let run_timeline = Arc::clone(&timeline);
let aborted = agent_session
.run_text_with_abort(format!("abort-run-{run_idx}"), Some(abort_signal), {
let started = Arc::clone(&started);
move |event| {
if let Ok(mut events) = run_timeline.lock() {
events.push(format!("run{run_idx}:{}", event_tag(&event)));
}
if matches!(
event,
AgentEvent::MessageStart {
message: Message::Assistant(_)
}
) {
started.notify_one();
}
}
})
.await
.expect("aborted run");
abort_join.await;
assert_eq!(
aborted.stop_reason,
StopReason::Aborted,
"run {run_idx} should abort cleanly"
);
}
let run_timeline = Arc::clone(&timeline);
let resumed = agent_session
.run_text("final-run".to_string(), move |event| {
if let Ok(mut events) = run_timeline.lock() {
events.push(format!("run2:{}", event_tag(&event)));
}
})
.await
.expect("final resumed run");
assert_eq!(resumed.stop_reason, StopReason::Stop);
assert!(resumed.error_message.is_none());
let cx = crate::agent_cx::AgentCx::for_request();
let persisted = session
.lock(cx.cx())
.await
.expect("lock session")
.to_messages_for_current_path();
assert_abort_resume_message_sequence(&persisted);
let timeline = timeline.lock().expect("timeline lock").clone();
assert_abort_resume_timeline_boundaries(&timeline);
});
}
#[test]
fn abort_during_tool_execution_records_aborted_tool_result() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
let handle = runtime.handle();
runtime.block_on(async move {
let provider = Arc::new(ToolCallProvider);
let tools = ToolRegistry::from_tools(vec![Box::new(HangingTool)]);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session = AgentSession::new(
agent,
Arc::clone(&session),
false,
ResolvedCompactionSettings::default(),
);
let tool_started = Arc::new(Notify::new());
let (abort_handle, abort_signal) = AbortHandle::new();
let tool_started_for_abort = Arc::clone(&tool_started);
let abort_join = handle.spawn(async move {
tool_started_for_abort.notified().await;
abort_handle.abort();
});
let result = agent_session
.run_text_with_abort("trigger tool".to_string(), Some(abort_signal), {
let tool_started = Arc::clone(&tool_started);
move |event| {
if matches!(event, AgentEvent::ToolExecutionStart { .. }) {
tool_started.notify_one();
}
}
})
.await
.expect("tool-abort run");
abort_join.await;
assert_eq!(result.stop_reason, StopReason::Aborted);
let cx = crate::agent_cx::AgentCx::for_request();
let persisted = session
.lock(cx.cx())
.await
.expect("lock session")
.to_messages_for_current_path();
let tool_result = persisted
.iter()
.find_map(|message| match message {
Message::ToolResult(result) => Some(result),
_ => None,
})
.expect("expected tool result message");
assert!(tool_result.is_error);
assert!(
tool_result.content.iter().any(|block| {
matches!(
block,
ContentBlock::Text(text) if text.text.contains("Tool execution aborted")
)
}),
"missing aborted tool marker in tool output: {:?}",
tool_result.content
);
});
}
}
#[cfg(test)]
mod turn_event_tests {
use super::*;
use crate::session::Session;
use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
use asupersync::runtime::RuntimeBuilder;
use async_trait::async_trait;
use futures::Stream;
use serde_json::json;
use std::path::Path;
use std::pin::Pin;
use std::sync::atomic::AtomicUsize;
fn assistant_message(text: &str) -> AssistantMessage {
AssistantMessage {
content: vec![ContentBlock::Text(TextContent::new(text))],
api: "test-api".to_string(),
provider: "test-provider".to_string(),
model: "test-model".to_string(),
usage: Usage::default(),
stop_reason: StopReason::Stop,
error_message: None,
timestamp: 0,
}
}
struct SingleShotProvider;
#[async_trait]
#[allow(clippy::unnecessary_literal_bound)]
impl Provider for SingleShotProvider {
fn name(&self) -> &str {
"test-provider"
}
fn api(&self) -> &str {
"test-api"
}
fn model_id(&self) -> &str {
"test-model"
}
async fn stream(
&self,
_context: &Context<'_>,
_options: &StreamOptions,
) -> crate::error::Result<
Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
> {
let partial = assistant_message("");
let final_message = assistant_message("hello");
let events = vec![
Ok(StreamEvent::Start { partial }),
Ok(StreamEvent::Done {
reason: StopReason::Stop,
message: final_message,
}),
];
Ok(Box::pin(futures::stream::iter(events)))
}
}
#[derive(Debug)]
struct EchoTool;
#[async_trait]
#[allow(clippy::unnecessary_literal_bound)]
impl Tool for EchoTool {
fn name(&self) -> &str {
"echo_tool"
}
fn label(&self) -> &str {
"echo_tool"
}
fn description(&self) -> &str {
"echo test tool"
}
fn parameters(&self) -> serde_json::Value {
json!({ "type": "object" })
}
async fn execute(
&self,
_tool_call_id: &str,
_input: serde_json::Value,
_on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
) -> Result<ToolOutput> {
Ok(ToolOutput {
content: vec![ContentBlock::Text(TextContent::new("tool-ok"))],
details: None,
is_error: false,
})
}
}
#[derive(Debug)]
struct ToolTurnProvider {
calls: AtomicUsize,
}
impl ToolTurnProvider {
const fn new() -> Self {
Self {
calls: AtomicUsize::new(0),
}
}
fn assistant_message_with(
&self,
stop_reason: StopReason,
content: Vec<ContentBlock>,
) -> AssistantMessage {
AssistantMessage {
content,
api: self.api().to_string(),
provider: self.name().to_string(),
model: self.model_id().to_string(),
usage: Usage::default(),
stop_reason,
error_message: None,
timestamp: 0,
}
}
}
#[async_trait]
#[allow(clippy::unnecessary_literal_bound)]
impl Provider for ToolTurnProvider {
fn name(&self) -> &str {
"test-provider"
}
fn api(&self) -> &str {
"test-api"
}
fn model_id(&self) -> &str {
"test-model"
}
async fn stream(
&self,
_context: &Context<'_>,
_options: &StreamOptions,
) -> crate::error::Result<
Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
> {
let call_index = self.calls.fetch_add(1, Ordering::SeqCst);
let partial = self.assistant_message_with(StopReason::Stop, Vec::new());
let done = if call_index == 0 {
self.assistant_message_with(
StopReason::ToolUse,
vec![ContentBlock::ToolCall(ToolCall {
id: "tool-1".to_string(),
name: "echo_tool".to_string(),
arguments: json!({}),
thought_signature: None,
})],
)
} else {
self.assistant_message_with(
StopReason::Stop,
vec![ContentBlock::Text(TextContent::new("final"))],
)
};
Ok(Box::pin(futures::stream::iter(vec![
Ok(StreamEvent::Start { partial }),
Ok(StreamEvent::Done {
reason: done.stop_reason,
message: done,
}),
])))
}
}
#[test]
fn turn_events_wrap_assistant_response() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
let handle = runtime.handle();
let provider = Arc::new(SingleShotProvider);
let tools = ToolRegistry::new(&[], Path::new("."), None);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
Arc::new(std::sync::Mutex::new(Vec::new()));
let events_capture = Arc::clone(&events);
let join = handle.spawn(async move {
agent_session
.run_text("hello".to_string(), move |event| {
events_capture.lock().unwrap().push(event);
})
.await
.expect("run_text")
});
runtime.block_on(async move {
let message = join.await;
assert_eq!(message.stop_reason, StopReason::Stop);
let events = events.lock().unwrap();
let turn_start_indices = events
.iter()
.enumerate()
.filter_map(|(idx, event)| {
matches!(event, AgentEvent::TurnStart { .. }).then_some(idx)
})
.collect::<Vec<_>>();
let turn_end_indices = events
.iter()
.enumerate()
.filter_map(|(idx, event)| {
matches!(event, AgentEvent::TurnEnd { .. }).then_some(idx)
})
.collect::<Vec<_>>();
assert_eq!(turn_start_indices.len(), 1);
assert_eq!(turn_end_indices.len(), 1);
assert!(turn_start_indices[0] < turn_end_indices[0]);
let assistant_message_end = events
.iter()
.enumerate()
.find_map(|(idx, event)| match event {
AgentEvent::MessageEnd {
message: Message::Assistant(_),
} => Some(idx),
_ => None,
})
.expect("assistant message end");
assert!(assistant_message_end < turn_end_indices[0]);
let (message_is_assistant, tool_results_empty) = {
let turn_end_event = &events[turn_end_indices[0]];
assert!(
matches!(turn_end_event, AgentEvent::TurnEnd { .. }),
"Expected TurnEnd event, got {turn_end_event:?}"
);
match turn_end_event {
AgentEvent::TurnEnd {
message,
tool_results,
..
} => (
matches!(message, Message::Assistant(_)),
tool_results.is_empty(),
),
_ => (false, false),
}
};
drop(events);
assert!(message_is_assistant);
assert!(tool_results_empty);
});
}
#[test]
fn turn_events_include_tool_execution_and_tool_result_messages() {
let runtime = RuntimeBuilder::current_thread()
.build()
.expect("runtime build");
let handle = runtime.handle();
let provider = Arc::new(ToolTurnProvider::new());
let tools = ToolRegistry::from_tools(vec![Box::new(EchoTool)]);
let agent = Agent::new(provider, tools, AgentConfig::default());
let session = Arc::new(Mutex::new(Session::in_memory()));
let mut agent_session =
AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
Arc::new(std::sync::Mutex::new(Vec::new()));
let events_capture = Arc::clone(&events);
let join = handle.spawn(async move {
agent_session
.run_text("hello".to_string(), move |event| {
events_capture.lock().expect("events lock").push(event);
})
.await
.expect("run_text")
});
runtime.block_on(async move {
let message = join.await;
assert_eq!(message.stop_reason, StopReason::Stop);
let events = events.lock().expect("events lock");
let turn_start_count = events
.iter()
.filter(|event| matches!(event, AgentEvent::TurnStart { .. }))
.count();
let turn_end_count = events
.iter()
.filter(|event| matches!(event, AgentEvent::TurnEnd { .. }))
.count();
assert_eq!(
turn_start_count, 2,
"expected one tool turn and one final turn"
);
assert_eq!(
turn_end_count, 2,
"expected one tool turn and one final turn"
);
let tool_start_idx = events
.iter()
.position(|event| matches!(event, AgentEvent::ToolExecutionStart { .. }))
.expect("tool execution start event");
let tool_end_idx = events
.iter()
.position(|event| matches!(event, AgentEvent::ToolExecutionEnd { .. }))
.expect("tool execution end event");
assert!(tool_start_idx < tool_end_idx);
let first_turn_end_idx = events
.iter()
.position(|event| matches!(event, AgentEvent::TurnEnd { turn_index: 0, .. }))
.expect("first turn end");
assert!(
tool_end_idx < first_turn_end_idx,
"tool execution should complete before first turn end"
);
let first_turn_tool_results = events.iter().find_map(|event| match event {
AgentEvent::TurnEnd {
turn_index,
tool_results,
..
} if *turn_index == 0 => Some(tool_results),
_ => None,
});
let Some(first_turn_tool_results) = first_turn_tool_results else {
panic!("missing first turn tool results");
};
assert_eq!(first_turn_tool_results.len(), 1);
let first_result = first_turn_tool_results.first().unwrap();
if let Message::ToolResult(tr) = first_result {
assert_eq!(tr.tool_name, "echo_tool");
assert!(!tr.is_error);
} else {
panic!("expected ToolResult message");
}
drop(events);
});
}
}
impl AgentSession {
pub const fn runtime_repair_mode_from_policy_mode(mode: RepairPolicyMode) -> RepairMode {
match mode {
RepairPolicyMode::Off => RepairMode::Off,
RepairPolicyMode::Suggest => RepairMode::Suggest,
RepairPolicyMode::AutoSafe => RepairMode::AutoSafe,
RepairPolicyMode::AutoStrict => RepairMode::AutoStrict,
}
}
#[allow(clippy::too_many_arguments)]
async fn start_js_extension_runtime(
stage: &'static str,
cwd: &std::path::Path,
tools: Arc<ToolRegistry>,
manager: ExtensionManager,
policy: ExtensionPolicy,
repair_mode: RepairMode,
memory_limit_bytes: usize,
) -> Result<ExtensionRuntimeHandle> {
let mut config = PiJsRuntimeConfig {
cwd: cwd.display().to_string(),
repair_mode,
..PiJsRuntimeConfig::default()
};
config.limits.memory_limit_bytes = Some(memory_limit_bytes).filter(|bytes| *bytes > 0);
let runtime =
JsExtensionRuntimeHandle::start_with_policy(config, tools, manager, policy).await?;
tracing::info!(
event = "pi.extension_runtime.engine_decision",
stage,
requested = "quickjs",
selected = "quickjs",
fallback = false,
"Extension runtime engine selected (legacy JS/TS)"
);
Ok(ExtensionRuntimeHandle::Js(runtime))
}
#[allow(clippy::too_many_arguments)]
async fn start_native_extension_runtime(
stage: &'static str,
_cwd: &std::path::Path,
_tools: Arc<ToolRegistry>,
_manager: ExtensionManager,
_policy: ExtensionPolicy,
_repair_mode: RepairMode,
_memory_limit_bytes: usize,
) -> Result<ExtensionRuntimeHandle> {
let runtime = NativeRustExtensionRuntimeHandle::start().await?;
tracing::info!(
event = "pi.extension_runtime.engine_decision",
stage,
requested = "native-rust",
selected = "native-rust",
fallback = false,
"Extension runtime engine selected (native-rust)"
);
Ok(ExtensionRuntimeHandle::NativeRust(runtime))
}
pub fn new(
agent: Agent,
session: Arc<Mutex<Session>>,
save_enabled: bool,
compaction_settings: ResolvedCompactionSettings,
) -> Self {
Self {
agent,
session,
save_enabled,
extensions: None,
extensions_is_streaming: Arc::new(AtomicBool::new(false)),
compaction_settings,
compaction_worker: CompactionWorkerState::new(CompactionQuota::default()),
model_registry: None,
auth_storage: None,
}
}
#[must_use]
pub fn with_model_registry(mut self, registry: ModelRegistry) -> Self {
self.model_registry = Some(registry);
self
}
#[must_use]
pub fn with_auth_storage(mut self, auth: AuthStorage) -> Self {
self.auth_storage = Some(auth);
self
}
pub fn set_model_registry(&mut self, registry: ModelRegistry) {
self.model_registry = Some(registry);
}
pub fn set_auth_storage(&mut self, auth: AuthStorage) {
self.auth_storage = Some(auth);
}
pub async fn set_provider_model(&mut self, provider_id: &str, model_id: &str) -> Result<()> {
{
let cx = crate::agent_cx::AgentCx::for_request();
let mut session = self
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
session.set_model_header(
Some(provider_id.to_string()),
Some(model_id.to_string()),
None,
);
}
self.apply_session_model_selection(provider_id, model_id);
let provider = self.agent.provider();
if provider.name() != provider_id || provider.model_id() != model_id {
return Err(Error::validation(format!(
"Unable to switch provider/model to {provider_id}/{model_id}"
)));
}
self.persist_session().await
}
fn resolve_stream_api_key_for_model(&self, entry: &ModelEntry) -> Option<String> {
let normalize = |key_opt: Option<String>| {
key_opt.and_then(|key| {
let trimmed = key.trim();
(!trimmed.is_empty()).then(|| trimmed.to_string())
})
};
self.auth_storage
.as_ref()
.and_then(|auth| normalize(auth.resolve_api_key(&entry.model.provider, None)))
.or_else(|| normalize(entry.api_key.clone()))
}
fn apply_session_model_selection(&mut self, provider_id: &str, model_id: &str) {
if self.agent.provider().name() == provider_id
&& self.agent.provider().model_id() == model_id
{
return;
}
let Some(registry) = &self.model_registry else {
return;
};
let Some(entry) = registry.find(provider_id, model_id) else {
tracing::warn!("Session model {provider_id}/{model_id} not found in model registry");
return;
};
match crate::providers::create_provider(
&entry,
self.extensions.as_ref().map(ExtensionRegion::manager),
) {
Ok(provider) => {
tracing::info!("Updating agent provider to {provider_id}/{model_id}");
self.agent.set_provider(provider);
let resolved_key = self.resolve_stream_api_key_for_model(&entry);
if resolved_key.is_none() {
tracing::warn!(
"No API key resolved for session model {provider_id}/{model_id}; clearing stream API key"
);
}
let stream_options = self.agent.stream_options_mut();
stream_options.api_key = resolved_key;
stream_options.headers.clone_from(&entry.headers);
}
Err(e) => {
tracing::warn!("Failed to create provider for session model: {e}");
}
}
}
pub const fn save_enabled(&self) -> bool {
self.save_enabled
}
pub async fn compact_now(
&mut self,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<()> {
self.compact_synchronous(Arc::new(on_event)).await
}
async fn maybe_compact(&mut self, on_event: AgentEventHandler) -> Result<()> {
if !self.compaction_settings.enabled {
return Ok(());
}
if let Some(outcome) = self.compaction_worker.try_recv() {
match outcome {
Ok(result) => {
self.apply_compaction_result(result, Arc::clone(&on_event))
.await?;
}
Err(e) => {
on_event(AgentEvent::AutoCompactionEnd {
result: None,
aborted: false,
will_retry: false,
error_message: Some(e.to_string()),
});
}
}
}
if !self.compaction_worker.can_start() {
return Ok(());
}
let preparation = {
let cx = crate::agent_cx::AgentCx::for_request();
let session = self
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
let entries = session
.entries_for_current_path()
.into_iter()
.cloned()
.collect::<Vec<_>>();
compaction::prepare_compaction(&entries, self.compaction_settings.clone())
};
if let Some(prep) = preparation {
on_event(AgentEvent::AutoCompactionStart {
reason: "threshold".to_string(),
});
let provider = self.agent.provider();
let api_key = self
.agent
.stream_options()
.api_key
.clone()
.unwrap_or_default();
self.compaction_worker.start(prep, provider, api_key, None);
}
Ok(())
}
async fn apply_compaction_result(
&self,
result: compaction::CompactionResult,
on_event: AgentEventHandler,
) -> Result<()> {
let cx = crate::agent_cx::AgentCx::for_request();
let mut session = self
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
let details = compaction::compaction_details_to_value(&result.details).ok();
let result_value = details.clone();
session.append_compaction(
result.summary,
result.first_kept_entry_id,
result.tokens_before,
details,
None, );
if self.save_enabled {
session
.flush_autosave(AutosaveFlushTrigger::Periodic)
.await?;
}
on_event(AgentEvent::AutoCompactionEnd {
result: result_value,
aborted: false,
will_retry: false,
error_message: None,
});
Ok(())
}
async fn compact_synchronous(&self, on_event: AgentEventHandler) -> Result<()> {
if !self.compaction_settings.enabled {
return Ok(());
}
let preparation = {
let cx = crate::agent_cx::AgentCx::for_request();
let session = self
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
let entries = session
.entries_for_current_path()
.into_iter()
.cloned()
.collect::<Vec<_>>();
compaction::prepare_compaction(&entries, self.compaction_settings.clone())
};
if let Some(prep) = preparation {
on_event(AgentEvent::AutoCompactionStart {
reason: "threshold".to_string(),
});
let provider = self.agent.provider();
let api_key = self
.agent
.stream_options()
.api_key
.clone()
.unwrap_or_default();
match compaction::compact(prep, provider, &api_key, None).await {
Ok(result) => {
self.apply_compaction_result(result, Arc::clone(&on_event))
.await?;
}
Err(e) => {
on_event(AgentEvent::AutoCompactionEnd {
result: None,
aborted: false,
will_retry: false,
error_message: Some(e.to_string()),
});
return Err(e);
}
}
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub async fn enable_extensions(
&mut self,
enabled_tools: &[&str],
cwd: &std::path::Path,
config: Option<&crate::config::Config>,
extension_entries: &[std::path::PathBuf],
) -> Result<()> {
self.enable_extensions_with_policy(
enabled_tools,
cwd,
config,
extension_entries,
None,
None,
None,
)
.await
}
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
pub async fn enable_extensions_with_policy(
&mut self,
enabled_tools: &[&str],
cwd: &std::path::Path,
config: Option<&crate::config::Config>,
extension_entries: &[std::path::PathBuf],
policy: Option<ExtensionPolicy>,
repair_policy: Option<RepairPolicyMode>,
pre_warmed: Option<PreWarmedExtensionRuntime>,
) -> Result<()> {
let mut js_specs: Vec<JsExtensionLoadSpec> = Vec::new();
let mut native_specs: Vec<NativeRustExtensionLoadSpec> = Vec::new();
#[cfg(feature = "wasm-host")]
let mut wasm_specs: Vec<WasmExtensionLoadSpec> = Vec::new();
for entry in extension_entries {
match resolve_extension_load_spec(entry)? {
ExtensionLoadSpec::Js(spec) => js_specs.push(spec),
ExtensionLoadSpec::NativeRust(spec) => native_specs.push(spec),
#[cfg(feature = "wasm-host")]
ExtensionLoadSpec::Wasm(spec) => wasm_specs.push(spec),
}
}
if !js_specs.is_empty() && !native_specs.is_empty() {
return Err(Error::validation(
"Mixed extension runtimes are not supported in one session yet. Use either JS/TS extensions (QuickJS) or native-rust descriptors (*.native.json), but not both at once."
.to_string(),
));
}
let resolved_policy = policy.clone().unwrap_or_default();
let resolved_repair_policy = repair_policy
.or_else(|| config.map(|cfg| cfg.resolve_repair_policy(None)))
.unwrap_or(RepairPolicyMode::AutoSafe);
let runtime_repair_mode =
Self::runtime_repair_mode_from_policy_mode(resolved_repair_policy);
let memory_limit_bytes =
(resolved_policy.max_memory_mb as usize).saturating_mul(1024 * 1024);
let wants_js_runtime = !js_specs.is_empty();
#[allow(unused_variables)]
let (manager, tools) = if let Some(pre) = pre_warmed {
let manager = pre.manager;
let tools = pre.tools;
let runtime = match pre.runtime {
ExtensionRuntimeHandle::NativeRust(runtime) => {
if wants_js_runtime {
tracing::warn!(
event = "pi.extension_runtime.prewarm.mismatch",
expected = "quickjs",
got = "native-rust",
"Pre-warmed runtime mismatched requested JS mode; creating quickjs runtime"
);
Self::start_js_extension_runtime(
"agent_enable_extensions_prewarm_mismatch",
cwd,
Arc::clone(&tools),
manager.clone(),
resolved_policy.clone(),
runtime_repair_mode,
memory_limit_bytes,
)
.await?
} else {
tracing::info!(
event = "pi.extension_runtime.engine_decision",
stage = "agent_enable_extensions_prewarmed",
requested = "native-rust",
selected = "native-rust",
fallback = false,
"Using pre-warmed extension runtime"
);
ExtensionRuntimeHandle::NativeRust(runtime)
}
}
ExtensionRuntimeHandle::Js(runtime) => {
if wants_js_runtime {
tracing::info!(
event = "pi.extension_runtime.engine_decision",
stage = "agent_enable_extensions_prewarmed",
requested = "quickjs",
selected = "quickjs",
fallback = false,
"Using pre-warmed extension runtime"
);
ExtensionRuntimeHandle::Js(runtime)
} else {
tracing::warn!(
event = "pi.extension_runtime.prewarm.mismatch",
expected = "native-rust",
got = "quickjs",
"Pre-warmed runtime mismatched requested native mode; creating native-rust runtime"
);
Self::start_native_extension_runtime(
"agent_enable_extensions_prewarm_mismatch",
cwd,
Arc::clone(&tools),
manager.clone(),
resolved_policy.clone(),
runtime_repair_mode,
memory_limit_bytes,
)
.await?
}
}
};
manager.set_runtime(runtime);
(manager, tools)
} else {
let manager = ExtensionManager::new();
manager.set_cwd(cwd.display().to_string());
let tools = Arc::new(ToolRegistry::new(enabled_tools, cwd, config));
if let Some(cfg) = config {
let resolved_risk = cfg.resolve_extension_risk_with_metadata();
tracing::info!(
event = "pi.extension_runtime_risk.config",
source = resolved_risk.source,
enabled = resolved_risk.settings.enabled,
alpha = resolved_risk.settings.alpha,
window_size = resolved_risk.settings.window_size,
ledger_limit = resolved_risk.settings.ledger_limit,
fail_closed = resolved_risk.settings.fail_closed,
"Resolved extension runtime risk settings"
);
manager.set_runtime_risk_config(resolved_risk.settings);
}
let runtime = if wants_js_runtime {
Self::start_js_extension_runtime(
"agent_enable_extensions_boot",
cwd,
Arc::clone(&tools),
manager.clone(),
resolved_policy,
runtime_repair_mode,
memory_limit_bytes,
)
.await?
} else {
Self::start_native_extension_runtime(
"agent_enable_extensions_boot",
cwd,
Arc::clone(&tools),
manager.clone(),
resolved_policy,
runtime_repair_mode,
memory_limit_bytes,
)
.await?
};
manager.set_runtime(runtime);
(manager, tools)
};
manager.set_session(Arc::new(SessionHandle(self.session.clone())));
let injected = Arc::new(StdMutex::new(ExtensionInjectedQueue::default()));
let host_actions = AgentSessionHostActions {
session: Arc::clone(&self.session),
injected: Arc::clone(&injected),
is_streaming: Arc::clone(&self.extensions_is_streaming),
};
manager.set_host_actions(Arc::new(host_actions));
{
let steering_queue = Arc::clone(&injected);
let follow_up_queue = Arc::clone(&injected);
let steering_fetcher = move || -> BoxFuture<'static, Vec<Message>> {
let steering_queue = Arc::clone(&steering_queue);
Box::pin(async move {
let Ok(mut queue) = steering_queue.lock() else {
return Vec::new();
};
queue.pop_steering()
})
};
let follow_up_fetcher = move || -> BoxFuture<'static, Vec<Message>> {
let follow_up_queue = Arc::clone(&follow_up_queue);
Box::pin(async move {
let Ok(mut queue) = follow_up_queue.lock() else {
return Vec::new();
};
queue.pop_follow_up()
})
};
self.agent.register_message_fetchers(
Some(Arc::new(steering_fetcher)),
Some(Arc::new(follow_up_fetcher)),
);
}
if !js_specs.is_empty() {
manager.load_js_extensions(js_specs).await?;
}
if !native_specs.is_empty() {
manager.load_native_extensions(native_specs).await?;
}
if let Some(rt) = manager.runtime() {
let events = rt.drain_repair_events().await;
if !events.is_empty() {
log_repair_diagnostics(&events);
}
}
#[cfg(feature = "wasm-host")]
if !wasm_specs.is_empty() {
let host = WasmExtensionHost::new(cwd, policy.unwrap_or_default())?;
manager
.load_wasm_extensions(&host, wasm_specs, Arc::clone(&tools))
.await?;
}
let session_path = {
let cx = crate::agent_cx::AgentCx::for_request();
let session = self
.session
.lock(cx.cx())
.await
.map_err(|e| Error::extension(e.to_string()))?;
session.path.as_ref().map(|p| p.display().to_string())
};
if let Err(err) = manager
.dispatch_event(
ExtensionEventName::Startup,
Some(serde_json::json!({
"version": env!("CARGO_PKG_VERSION"),
"sessionFile": session_path,
})),
)
.await
{
tracing::warn!("startup extension hook failed (fail-open): {err}");
}
let ctx_payload = serde_json::json!({ "cwd": cwd.display().to_string() });
let wrappers = collect_extension_tool_wrappers(&manager, ctx_payload).await?;
self.agent.extend_tools(wrappers);
self.agent.extensions = Some(manager.clone());
self.extensions = Some(ExtensionRegion::new(manager));
Ok(())
}
pub async fn save_and_index(&mut self) -> Result<()> {
if self.save_enabled {
let cx = crate::agent_cx::AgentCx::for_request();
let mut session = self
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
session
.flush_autosave(AutosaveFlushTrigger::Periodic)
.await?;
}
Ok(())
}
pub async fn persist_session(&mut self) -> Result<()> {
if !self.save_enabled {
return Ok(());
}
let cx = crate::agent_cx::AgentCx::for_request();
let mut session = self
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
session
.flush_autosave(AutosaveFlushTrigger::Periodic)
.await?;
Ok(())
}
pub async fn run_text(
&mut self,
input: String,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<AssistantMessage> {
self.run_text_with_abort(input, None, on_event).await
}
pub async fn run_text_with_abort(
&mut self,
input: String,
abort: Option<AbortSignal>,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<AssistantMessage> {
let outcome = self.dispatch_input_event(input, Vec::new()).await?;
let (text, images) = match outcome {
InputEventOutcome::Continue { text, images } => (text, images),
InputEventOutcome::Block { reason } => {
let message = reason.unwrap_or_else(|| "Input blocked".to_string());
return Err(Error::extension(message));
}
};
self.dispatch_before_agent_start().await;
if images.is_empty() {
self.run_agent_with_text(text, abort, on_event).await
} else {
let content = Self::build_content_blocks_for_input(&text, &images);
self.run_agent_with_content(content, abort, on_event).await
}
}
pub async fn run_with_content(
&mut self,
content: Vec<ContentBlock>,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<AssistantMessage> {
self.run_with_content_with_abort(content, None, on_event)
.await
}
pub async fn run_with_content_with_abort(
&mut self,
content: Vec<ContentBlock>,
abort: Option<AbortSignal>,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<AssistantMessage> {
let (text, images) = Self::split_content_blocks_for_input(&content);
let outcome = self.dispatch_input_event(text, images).await?;
let (text, images) = match outcome {
InputEventOutcome::Continue { text, images } => (text, images),
InputEventOutcome::Block { reason } => {
let message = reason.unwrap_or_else(|| "Input blocked".to_string());
return Err(Error::extension(message));
}
};
self.dispatch_before_agent_start().await;
let content_for_agent = Self::build_content_blocks_for_input(&text, &images);
self.run_agent_with_content(content_for_agent, abort, on_event)
.await
}
async fn dispatch_input_event(
&self,
text: String,
images: Vec<ImageContent>,
) -> Result<InputEventOutcome> {
let Some(region) = &self.extensions else {
return Ok(InputEventOutcome::Continue { text, images });
};
let images_value = serde_json::to_value(&images).unwrap_or(Value::Null);
let payload = json!({
"text": text,
"images": images_value,
"source": "user",
});
let response = region
.manager()
.dispatch_event_with_response(
ExtensionEventName::Input,
Some(payload),
EXTENSION_EVENT_TIMEOUT_MS,
)
.await?;
Ok(apply_input_event_response(response, text, images))
}
async fn dispatch_before_agent_start(&self) {
if let Some(region) = &self.extensions {
if let Err(err) = region
.manager()
.dispatch_event(ExtensionEventName::BeforeAgentStart, None)
.await
{
tracing::warn!("before_agent_start extension hook failed (fail-open): {err}");
}
}
}
fn split_content_blocks_for_input(blocks: &[ContentBlock]) -> (String, Vec<ImageContent>) {
let mut text = String::new();
let mut images = Vec::new();
for block in blocks {
match block {
ContentBlock::Text(text_block) => {
if !text_block.text.trim().is_empty() {
if !text.is_empty() {
text.push('\n');
}
text.push_str(&text_block.text);
}
}
ContentBlock::Image(image) => images.push(image.clone()),
_ => {}
}
}
(text, images)
}
fn build_content_blocks_for_input(text: &str, images: &[ImageContent]) -> Vec<ContentBlock> {
let mut content = Vec::new();
if !text.trim().is_empty() {
content.push(ContentBlock::Text(TextContent::new(text.to_string())));
}
for image in images {
content.push(ContentBlock::Image(image.clone()));
}
content
}
pub(crate) async fn run_agent_with_text(
&mut self,
input: String,
abort: Option<AbortSignal>,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<AssistantMessage> {
let on_event: AgentEventHandler = Arc::new(on_event);
let session_model = {
let cx = crate::agent_cx::AgentCx::for_request();
let session = self
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
(
session.header.provider.clone(),
session.header.model_id.clone(),
)
};
if let (Some(provider_id), Some(model_id)) = session_model {
self.apply_session_model_selection(provider_id.as_str(), model_id.as_str());
}
self.maybe_compact(Arc::clone(&on_event)).await?;
let history = {
let cx = crate::agent_cx::AgentCx::for_request();
let session = self
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
session.to_messages_for_current_path()
};
self.agent.replace_messages(history);
let start_len = self.agent.messages().len();
let user_message = Message::User(UserMessage {
content: UserContent::Text(input),
timestamp: Utc::now().timestamp_millis(),
});
{
let cx = crate::agent_cx::AgentCx::for_request();
let mut session = self
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
session.append_model_message(user_message.clone());
if self.save_enabled {
session.flush_autosave(AutosaveFlushTrigger::Manual).await?;
}
}
self.extensions_is_streaming.store(true, Ordering::SeqCst);
let on_event_for_run = Arc::clone(&on_event);
let result = self
.agent
.run_with_message_with_abort(user_message, abort, move |event| {
on_event_for_run(event);
})
.await;
self.extensions_is_streaming.store(false, Ordering::SeqCst);
let result = result?;
self.persist_new_messages(start_len + 1).await?;
Ok(result)
}
pub(crate) async fn run_agent_with_content(
&mut self,
content: Vec<ContentBlock>,
abort: Option<AbortSignal>,
on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
) -> Result<AssistantMessage> {
let on_event: AgentEventHandler = Arc::new(on_event);
let session_model = {
let cx = crate::agent_cx::AgentCx::for_request();
let session = self
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
(
session.header.provider.clone(),
session.header.model_id.clone(),
)
};
if let (Some(provider_id), Some(model_id)) = session_model {
self.apply_session_model_selection(provider_id.as_str(), model_id.as_str());
}
self.maybe_compact(Arc::clone(&on_event)).await?;
let history = {
let cx = crate::agent_cx::AgentCx::for_request();
let session = self
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
session.to_messages_for_current_path()
};
self.agent.replace_messages(history);
let start_len = self.agent.messages().len();
let user_message = Message::User(UserMessage {
content: UserContent::Blocks(content),
timestamp: Utc::now().timestamp_millis(),
});
{
let cx = crate::agent_cx::AgentCx::for_request();
let mut session = self
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
session.append_model_message(user_message.clone());
if self.save_enabled {
session.flush_autosave(AutosaveFlushTrigger::Manual).await?;
}
}
self.extensions_is_streaming.store(true, Ordering::SeqCst);
let on_event_for_run = Arc::clone(&on_event);
let result = self
.agent
.run_with_message_with_abort(user_message, abort, move |event| {
on_event_for_run(event);
})
.await;
self.extensions_is_streaming.store(false, Ordering::SeqCst);
let result = result?;
self.persist_new_messages(start_len + 1).await?;
Ok(result)
}
async fn persist_new_messages(&self, start_len: usize) -> Result<()> {
let new_messages = self.agent.messages()[start_len..].to_vec();
{
let cx = crate::agent_cx::AgentCx::for_request();
let mut session = self
.session
.lock(cx.cx())
.await
.map_err(|e| Error::session(e.to_string()))?;
for message in new_messages {
session.append_model_message(message);
}
if self.save_enabled {
session
.flush_autosave(AutosaveFlushTrigger::Periodic)
.await?;
}
}
Ok(())
}
}
fn log_repair_diagnostics(events: &[crate::extensions_js::ExtensionRepairEvent]) {
use std::collections::BTreeMap;
for ev in events {
tracing::info!(
event = "extension.auto_repair",
extension_id = %ev.extension_id,
pattern = %ev.pattern,
success = ev.success,
original_error = %ev.original_error,
repair_action = %ev.repair_action,
);
}
let mut by_pattern: BTreeMap<String, Vec<&str>> = BTreeMap::new();
for ev in events {
by_pattern
.entry(ev.pattern.to_string())
.or_default()
.push(&ev.extension_id);
}
let verbose = std::env::var("PI_AUTO_REPAIR_VERBOSE")
.is_ok_and(|v| v == "1" || v.eq_ignore_ascii_case("true"));
if verbose {
eprintln!(
"[auto-repair] {} extension{} auto-repaired:",
events.len(),
if events.len() == 1 { "" } else { "s" }
);
for ev in events {
eprintln!(
" {}: {} ({})",
ev.pattern, ev.extension_id, ev.repair_action
);
}
} else {
let patterns: Vec<String> = by_pattern
.iter()
.map(|(pat, ids)| format!("{pat}:{}", ids.len()))
.collect();
tracing::info!(
event = "extension.auto_repair.summary",
count = events.len(),
patterns = %patterns.join(", "),
"auto-repaired {} extension(s)",
events.len(),
);
}
}
const BLOCK_IMAGES_PLACEHOLDER: &str = "Image reading is disabled.";
#[derive(Debug, Default, Clone, Copy)]
struct ImageFilterStats {
removed_images: usize,
affected_messages: usize,
}
fn filter_images_for_provider(messages: &mut [Message]) -> ImageFilterStats {
let mut stats = ImageFilterStats::default();
for message in messages {
let removed = filter_images_from_message(message);
if removed > 0 {
stats.removed_images += removed;
stats.affected_messages += 1;
}
}
stats
}
fn filter_images_from_message(message: &mut Message) -> usize {
match message {
Message::User(user) => match &mut user.content {
UserContent::Text(_) => 0,
UserContent::Blocks(blocks) => filter_image_blocks(blocks),
},
Message::Assistant(assistant) => {
let assistant = Arc::make_mut(assistant);
filter_image_blocks(&mut assistant.content)
}
Message::ToolResult(tool_result) => {
filter_image_blocks(&mut Arc::make_mut(tool_result).content)
}
Message::Custom(_) => 0,
}
}
fn filter_image_blocks(blocks: &mut Vec<ContentBlock>) -> usize {
let mut removed = 0usize;
let mut filtered = Vec::with_capacity(blocks.len());
for block in blocks.drain(..) {
match block {
ContentBlock::Image(_) => {
removed += 1;
let previous_is_placeholder =
filtered
.last()
.is_some_and(|prev| matches!(prev, ContentBlock::Text(TextContent { text, .. }) if text == BLOCK_IMAGES_PLACEHOLDER));
if !previous_is_placeholder {
filtered.push(ContentBlock::Text(TextContent::new(
BLOCK_IMAGES_PLACEHOLDER,
)));
}
}
other => filtered.push(other),
}
}
*blocks = filtered;
removed
}
fn extract_tool_calls(content: &[ContentBlock]) -> Vec<ToolCall> {
content
.iter()
.filter_map(|block| {
if let ContentBlock::ToolCall(tc) = block {
Some(tc.clone())
} else {
None
}
})
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::AuthCredential;
use crate::provider::{InputType, Model, ModelCost};
use async_trait::async_trait;
use futures::Stream;
use std::collections::HashMap;
use std::path::Path;
use std::pin::Pin;
fn user_message(text: &str) -> Message {
Message::User(UserMessage {
content: UserContent::Text(text.to_string()),
timestamp: 0,
})
}
fn assert_user_text(message: &Message, expected: &str) {
assert!(
matches!(
message,
Message::User(UserMessage {
content: UserContent::Text(_),
..
})
),
"expected user text message, got {message:?}"
);
if let Message::User(UserMessage {
content: UserContent::Text(text),
..
}) = message
{
assert_eq!(text, expected);
}
}
fn sample_image_block() -> ContentBlock {
ContentBlock::Image(ImageContent {
data: "aGVsbG8=".to_string(),
mime_type: "image/png".to_string(),
})
}
fn image_count_in_message(message: &Message) -> usize {
let count_images = |blocks: &[ContentBlock]| {
blocks
.iter()
.filter(|block| matches!(block, ContentBlock::Image(_)))
.count()
};
match message {
Message::User(UserMessage {
content: UserContent::Blocks(blocks),
..
}) => count_images(blocks),
Message::Assistant(msg) => count_images(&msg.content),
Message::ToolResult(tool_result) => count_images(&tool_result.content),
Message::User(UserMessage {
content: UserContent::Text(_),
..
})
| Message::Custom(_) => 0,
}
}
#[derive(Debug)]
struct SilentProvider;
#[async_trait]
#[allow(clippy::unnecessary_literal_bound)]
impl Provider for SilentProvider {
fn name(&self) -> &str {
"silent-provider"
}
fn api(&self) -> &str {
"test-api"
}
fn model_id(&self) -> &str {
"test-model"
}
async fn stream(
&self,
_context: &Context<'_>,
_options: &StreamOptions,
) -> crate::error::Result<
Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
> {
Ok(Box::pin(futures::stream::empty()))
}
}
#[test]
fn test_extract_tool_calls() {
let content = vec![
ContentBlock::Text(TextContent::new("Hello")),
ContentBlock::ToolCall(ToolCall {
id: "tc1".to_string(),
name: "read".to_string(),
arguments: serde_json::json!({"path": "file.txt"}),
thought_signature: None,
}),
ContentBlock::Text(TextContent::new("World")),
ContentBlock::ToolCall(ToolCall {
id: "tc2".to_string(),
name: "bash".to_string(),
arguments: serde_json::json!({"command": "ls"}),
thought_signature: None,
}),
];
let tool_calls = extract_tool_calls(&content);
assert_eq!(tool_calls.len(), 2);
assert_eq!(tool_calls[0].name, "read");
assert_eq!(tool_calls[1].name, "bash");
}
#[test]
fn test_agent_config_default() {
let config = AgentConfig::default();
assert_eq!(config.max_tool_iterations, 50);
assert!(config.system_prompt.is_none());
assert!(!config.block_images);
}
#[test]
fn filter_image_blocks_replaces_images_with_deduped_placeholder_text() {
let mut blocks = vec![
sample_image_block(),
sample_image_block(),
ContentBlock::Text(TextContent::new("tail")),
sample_image_block(),
];
let removed = filter_image_blocks(&mut blocks);
assert_eq!(removed, 3);
assert!(
!blocks
.iter()
.any(|block| matches!(block, ContentBlock::Image(_)))
);
assert!(matches!(
blocks.first(),
Some(ContentBlock::Text(TextContent { text, .. })) if text == BLOCK_IMAGES_PLACEHOLDER
));
assert!(matches!(
blocks.get(1),
Some(ContentBlock::Text(TextContent { text, .. })) if text == "tail"
));
assert!(matches!(
blocks.get(2),
Some(ContentBlock::Text(TextContent { text, .. })) if text == BLOCK_IMAGES_PLACEHOLDER
));
}
#[test]
fn filter_images_for_provider_filters_images_from_all_block_message_types() {
let mut messages = vec![
Message::User(UserMessage {
content: UserContent::Blocks(vec![
ContentBlock::Text(TextContent::new("hello")),
sample_image_block(),
]),
timestamp: 0,
}),
Message::Assistant(Arc::new(AssistantMessage {
content: vec![sample_image_block()],
api: "test".to_string(),
provider: "test".to_string(),
model: "test".to_string(),
usage: Usage::default(),
stop_reason: StopReason::Stop,
error_message: None,
timestamp: 0,
})),
Message::tool_result(ToolResultMessage {
tool_call_id: "tc1".to_string(),
tool_name: "read".to_string(),
content: vec![
sample_image_block(),
ContentBlock::Text(TextContent::new("ok")),
],
details: None,
is_error: false,
timestamp: 0,
}),
];
let stats = filter_images_for_provider(&mut messages);
assert_eq!(stats.removed_images, 3);
assert_eq!(stats.affected_messages, 3);
assert_eq!(
messages.iter().map(image_count_in_message).sum::<usize>(),
0,
"no images should remain in provider-bound context"
);
}
#[test]
fn build_context_strips_images_when_block_images_enabled() {
let mut agent = Agent::new(
Arc::new(SilentProvider),
ToolRegistry::new(&[], Path::new("."), None),
AgentConfig {
system_prompt: None,
max_tool_iterations: 50,
stream_options: StreamOptions::default(),
block_images: true,
},
);
agent.add_message(Message::User(UserMessage {
content: UserContent::Blocks(vec![sample_image_block()]),
timestamp: 0,
}));
let context = agent.build_context();
assert_eq!(context.messages.len(), 1);
assert_eq!(image_count_in_message(&context.messages[0]), 0);
assert!(matches!(
&context.messages[0],
Message::User(UserMessage {
content: UserContent::Blocks(blocks),
..
}) if blocks
.iter()
.any(|block| matches!(block, ContentBlock::Text(TextContent { text, .. }) if text == BLOCK_IMAGES_PLACEHOLDER))
));
}
#[test]
fn build_context_keeps_images_when_block_images_disabled() {
let mut agent = Agent::new(
Arc::new(SilentProvider),
ToolRegistry::new(&[], Path::new("."), None),
AgentConfig {
system_prompt: None,
max_tool_iterations: 50,
stream_options: StreamOptions::default(),
block_images: false,
},
);
agent.add_message(Message::User(UserMessage {
content: UserContent::Blocks(vec![sample_image_block()]),
timestamp: 0,
}));
let context = agent.build_context();
assert_eq!(context.messages.len(), 1);
assert_eq!(image_count_in_message(&context.messages[0]), 1);
}
#[test]
fn auto_compaction_start_serializes_with_pi_mono_compatible_type_tag() {
let event = AgentEvent::AutoCompactionStart {
reason: "threshold".to_string(),
};
let json = serde_json::to_value(&event).unwrap();
assert_eq!(json["type"], "auto_compaction_start");
assert_eq!(json["reason"], "threshold");
}
#[test]
fn auto_compaction_end_serializes_with_pi_mono_compatible_fields() {
let event = AgentEvent::AutoCompactionEnd {
result: Some(serde_json::json!({"tokens_before": 5000, "tokens_after": 2000})),
aborted: false,
will_retry: false,
error_message: None,
};
let json = serde_json::to_value(&event).unwrap();
assert_eq!(json["type"], "auto_compaction_end");
assert_eq!(json["aborted"], false);
assert_eq!(json["willRetry"], false);
assert!(json.get("errorMessage").is_none()); assert!(json["result"].is_object());
}
#[test]
fn auto_compaction_end_includes_error_message_when_present() {
let event = AgentEvent::AutoCompactionEnd {
result: None,
aborted: true,
will_retry: false,
error_message: Some("Compaction failed".to_string()),
};
let json = serde_json::to_value(&event).unwrap();
assert_eq!(json["type"], "auto_compaction_end");
assert_eq!(json["aborted"], true);
assert_eq!(json["errorMessage"], "Compaction failed");
}
#[test]
fn auto_retry_start_serializes_with_camel_case_fields() {
let event = AgentEvent::AutoRetryStart {
attempt: 1,
max_attempts: 3,
delay_ms: 2000,
error_message: "Rate limited".to_string(),
};
let json = serde_json::to_value(&event).unwrap();
assert_eq!(json["type"], "auto_retry_start");
assert_eq!(json["attempt"], 1);
assert_eq!(json["maxAttempts"], 3);
assert_eq!(json["delayMs"], 2000);
assert_eq!(json["errorMessage"], "Rate limited");
}
#[test]
fn auto_retry_end_serializes_success_and_omits_null_final_error() {
let event = AgentEvent::AutoRetryEnd {
success: true,
attempt: 2,
final_error: None,
};
let json = serde_json::to_value(&event).unwrap();
assert_eq!(json["type"], "auto_retry_end");
assert_eq!(json["success"], true);
assert_eq!(json["attempt"], 2);
assert!(json.get("finalError").is_none());
}
#[test]
fn auto_retry_end_includes_final_error_on_failure() {
let event = AgentEvent::AutoRetryEnd {
success: false,
attempt: 3,
final_error: Some("Max retries exceeded".to_string()),
};
let json = serde_json::to_value(&event).unwrap();
assert_eq!(json["type"], "auto_retry_end");
assert_eq!(json["success"], false);
assert_eq!(json["attempt"], 3);
assert_eq!(json["finalError"], "Max retries exceeded");
}
#[test]
fn message_queue_push_increments_seq_and_counts_both_queues() {
let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
assert_eq!(queue.pending_count(), 0);
assert_eq!(queue.push_steering(user_message("s1")), 0);
assert_eq!(queue.push_follow_up(user_message("f1")), 1);
assert_eq!(queue.push_steering(user_message("s2")), 2);
assert_eq!(queue.pending_count(), 3);
}
#[test]
fn message_queue_pop_steering_one_at_a_time_preserves_order() {
let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
queue.push_steering(user_message("s1"));
queue.push_steering(user_message("s2"));
let first = queue.pop_steering();
assert_eq!(first.len(), 1);
assert_user_text(&first[0], "s1");
assert_eq!(queue.pending_count(), 1);
let second = queue.pop_steering();
assert_eq!(second.len(), 1);
assert_user_text(&second[0], "s2");
assert_eq!(queue.pending_count(), 0);
let empty = queue.pop_steering();
assert!(empty.is_empty());
}
#[test]
fn message_queue_pop_respects_queue_modes_per_kind() {
let mut queue = MessageQueue::new(QueueMode::All, QueueMode::OneAtATime);
queue.push_steering(user_message("s1"));
queue.push_steering(user_message("s2"));
queue.push_follow_up(user_message("f1"));
queue.push_follow_up(user_message("f2"));
let steering = queue.pop_steering();
assert_eq!(steering.len(), 2);
assert_user_text(&steering[0], "s1");
assert_user_text(&steering[1], "s2");
assert_eq!(queue.pending_count(), 2);
let follow_up = queue.pop_follow_up();
assert_eq!(follow_up.len(), 1);
assert_user_text(&follow_up[0], "f1");
assert_eq!(queue.pending_count(), 1);
let follow_up = queue.pop_follow_up();
assert_eq!(follow_up.len(), 1);
assert_user_text(&follow_up[0], "f2");
assert_eq!(queue.pending_count(), 0);
}
#[test]
fn message_queue_set_modes_applies_to_existing_messages() {
let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
queue.push_steering(user_message("s1"));
queue.push_steering(user_message("s2"));
let first = queue.pop_steering();
assert_eq!(first.len(), 1);
assert_user_text(&first[0], "s1");
queue.set_modes(QueueMode::All, QueueMode::OneAtATime);
let remaining = queue.pop_steering();
assert_eq!(remaining.len(), 1);
assert_user_text(&remaining[0], "s2");
}
fn build_switch_test_session(auth: &AuthStorage) -> AgentSession {
let registry = ModelRegistry::load(auth, None);
let current_entry = registry
.find("anthropic", "claude-sonnet-4-5")
.expect("anthropic model in registry");
let provider = crate::providers::create_provider(¤t_entry, None)
.expect("create anthropic provider");
let tools = ToolRegistry::new(&[], Path::new("."), None);
let mut stream_options = StreamOptions {
api_key: Some("stale-key".to_string()),
..Default::default()
};
let _ = stream_options
.headers
.insert("x-stale-header".to_string(), "stale-value".to_string());
let agent = Agent::new(
provider,
tools,
AgentConfig {
system_prompt: None,
max_tool_iterations: 50,
stream_options,
block_images: false,
},
);
let mut session = Session::in_memory();
session.header.provider = Some("openai".to_string());
session.header.model_id = Some("gpt-4o".to_string());
let mut agent_session = AgentSession::new(
agent,
Arc::new(Mutex::new(session)),
false,
ResolvedCompactionSettings::default(),
);
agent_session.set_model_registry(registry);
agent_session.set_auth_storage(auth.clone());
agent_session
}
#[test]
fn apply_session_model_selection_updates_stream_credentials_and_headers() {
let dir = tempfile::tempdir().expect("tempdir");
let auth_path = dir.path().join("auth.json");
let mut auth = AuthStorage::load(auth_path).expect("load auth");
auth.set(
"anthropic",
AuthCredential::ApiKey {
key: "anthropic-key".to_string(),
},
);
auth.set(
"openai",
AuthCredential::ApiKey {
key: "openai-key".to_string(),
},
);
let mut agent_session = build_switch_test_session(&auth);
agent_session.apply_session_model_selection("openai", "gpt-4o");
assert_eq!(agent_session.agent.provider().name(), "openai");
assert_eq!(agent_session.agent.provider().model_id(), "gpt-4o");
assert_eq!(
agent_session.agent.stream_options().api_key.as_deref(),
Some("openai-key")
);
assert!(
agent_session.agent.stream_options().headers.is_empty(),
"stream headers should be refreshed from selected model entry"
);
}
#[test]
fn apply_session_model_selection_clears_stale_key_when_target_has_no_key() {
let dir = tempfile::tempdir().expect("tempdir");
let auth_path = dir.path().join("auth.json");
let mut auth = AuthStorage::load(auth_path).expect("load auth");
auth.set(
"anthropic",
AuthCredential::ApiKey {
key: "anthropic-key".to_string(),
},
);
let mut agent_session = build_switch_test_session(&auth);
agent_session.apply_session_model_selection("openai", "gpt-4o");
assert_eq!(agent_session.agent.provider().name(), "openai");
assert_eq!(
agent_session.agent.stream_options().api_key,
None,
"stale key must be cleared when target model has no configured key"
);
}
#[test]
fn apply_session_model_selection_treats_blank_model_key_as_missing() {
let dir = tempfile::tempdir().expect("tempdir");
let auth_path = dir.path().join("auth.json");
let auth = AuthStorage::load(auth_path).expect("load auth");
let mut registry = ModelRegistry::load(&auth, None);
registry.merge_entries(vec![ModelEntry {
model: Model {
id: "blank-model".to_string(),
name: "Blank Model".to_string(),
api: "openai-completions".to_string(),
provider: "acme".to_string(),
base_url: "https://example.invalid/v1".to_string(),
reasoning: true,
input: vec![InputType::Text],
cost: ModelCost {
input: 0.0,
output: 0.0,
cache_read: 0.0,
cache_write: 0.0,
},
context_window: 128_000,
max_tokens: 8_192,
headers: HashMap::new(),
},
api_key: Some(" ".to_string()),
headers: HashMap::new(),
auth_header: true,
compat: None,
oauth_config: None,
}]);
let mut agent_session = build_switch_test_session(&auth);
agent_session.set_model_registry(registry);
agent_session.apply_session_model_selection("acme", "blank-model");
assert_eq!(agent_session.agent.provider().name(), "acme");
assert_eq!(
agent_session.agent.stream_options().api_key,
None,
"blank model keys must not be treated as valid credentials"
);
}
#[test]
fn auto_compaction_start_serializes_to_pi_mono_format() {
let event = AgentEvent::AutoCompactionStart {
reason: "threshold".to_string(),
};
let json = serde_json::to_value(&event).unwrap();
assert_eq!(json["type"], "auto_compaction_start");
assert_eq!(json["reason"], "threshold");
}
#[test]
fn auto_compaction_end_serializes_to_pi_mono_format() {
let event = AgentEvent::AutoCompactionEnd {
result: Some(serde_json::json!({
"summary": "Compacted",
"firstKeptEntryId": "abc123",
"tokensBefore": 50000,
"details": { "readFiles": [], "modifiedFiles": [] }
})),
aborted: false,
will_retry: true,
error_message: None,
};
let json = serde_json::to_value(&event).unwrap();
assert_eq!(json["type"], "auto_compaction_end");
assert!(json["result"].is_object());
assert_eq!(json["aborted"], false);
assert_eq!(json["willRetry"], true);
assert!(json.get("errorMessage").is_none());
}
#[test]
fn auto_compaction_end_with_error_serializes_error_message() {
let event = AgentEvent::AutoCompactionEnd {
result: None,
aborted: false,
will_retry: false,
error_message: Some("compaction failed".to_string()),
};
let json = serde_json::to_value(&event).unwrap();
assert_eq!(json["type"], "auto_compaction_end");
assert!(json["result"].is_null());
assert_eq!(json["errorMessage"], "compaction failed");
}
#[test]
fn auto_retry_start_serializes_to_pi_mono_format() {
let event = AgentEvent::AutoRetryStart {
attempt: 2,
max_attempts: 3,
delay_ms: 4000,
error_message: "rate limited".to_string(),
};
let json = serde_json::to_value(&event).unwrap();
assert_eq!(json["type"], "auto_retry_start");
assert_eq!(json["attempt"], 2);
assert_eq!(json["maxAttempts"], 3);
assert_eq!(json["delayMs"], 4000);
assert_eq!(json["errorMessage"], "rate limited");
}
#[test]
fn auto_retry_end_success_serializes_to_pi_mono_format() {
let event = AgentEvent::AutoRetryEnd {
success: true,
attempt: 2,
final_error: None,
};
let json = serde_json::to_value(&event).unwrap();
assert_eq!(json["type"], "auto_retry_end");
assert_eq!(json["success"], true);
assert_eq!(json["attempt"], 2);
assert!(json.get("finalError").is_none());
}
#[test]
fn auto_retry_end_failure_serializes_final_error() {
let event = AgentEvent::AutoRetryEnd {
success: false,
attempt: 3,
final_error: Some("max retries exceeded".to_string()),
};
let json = serde_json::to_value(&event).unwrap();
assert_eq!(json["type"], "auto_retry_end");
assert_eq!(json["success"], false);
assert_eq!(json["attempt"], 3);
assert_eq!(json["finalError"], "max retries exceeded");
}
}