#![allow(clippy::significant_drop_tightening)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::ignored_unit_patterns)]
#![allow(clippy::needless_pass_by_value)]
use crate::agent::{AbortHandle, AgentEvent, AgentSession, QueueMode};
use crate::agent_cx::AgentCx;
use crate::auth::AuthStorage;
use crate::compaction::{
ResolvedCompactionSettings, compact, compaction_details_to_value, prepare_compaction,
};
use crate::config::Config;
use crate::error::{Error, Result};
use crate::error_hints;
use crate::extensions::{ExtensionManager, ExtensionUiRequest, ExtensionUiResponse};
use crate::model::{
ContentBlock, ImageContent, Message, StopReason, TextContent, UserContent, UserMessage,
};
use crate::models::ModelEntry;
use crate::provider_metadata::{canonical_provider_id, provider_metadata};
use crate::providers;
use crate::resources::ResourceLoader;
use crate::session::SessionMessage;
use crate::tools::{DEFAULT_MAX_BYTES, DEFAULT_MAX_LINES, truncate_tail};
use asupersync::channel::{mpsc, oneshot};
use asupersync::runtime::RuntimeHandle;
use asupersync::sync::{Mutex, OwnedMutexGuard};
use asupersync::time::{sleep, wall_now};
use memchr::memchr_iter;
use serde_json::{Value, json};
use std::collections::VecDeque;
use std::io::{self, BufRead, Write};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
fn provider_ids_match(left: &str, right: &str) -> bool {
let left = left.trim();
let right = right.trim();
if left.eq_ignore_ascii_case(right) {
return true;
}
let left_canonical = canonical_provider_id(left).unwrap_or(left);
let right_canonical = canonical_provider_id(right).unwrap_or(right);
left_canonical.eq_ignore_ascii_case(right)
|| right_canonical.eq_ignore_ascii_case(left)
|| left_canonical.eq_ignore_ascii_case(right_canonical)
}
#[derive(Clone)]
pub struct RpcOptions {
pub config: Config,
pub resources: ResourceLoader,
pub available_models: Vec<ModelEntry>,
pub scoped_models: Vec<RpcScopedModel>,
pub auth: AuthStorage,
pub runtime_handle: RuntimeHandle,
}
#[derive(Debug, Clone)]
pub struct RpcScopedModel {
pub model: ModelEntry,
pub thinking_level: Option<crate::model::ThinkingLevel>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamingBehavior {
Steer,
FollowUp,
}
#[derive(Debug, Clone)]
struct RpcStateSnapshot {
steering_count: usize,
follow_up_count: usize,
steering_mode: QueueMode,
follow_up_mode: QueueMode,
auto_compaction_enabled: bool,
auto_retry_enabled: bool,
}
impl From<&RpcSharedState> for RpcStateSnapshot {
fn from(state: &RpcSharedState) -> Self {
Self {
steering_count: state.steering.len(),
follow_up_count: state.follow_up.len(),
steering_mode: state.steering_mode,
follow_up_mode: state.follow_up_mode,
auto_compaction_enabled: state.auto_compaction_enabled,
auto_retry_enabled: state.auto_retry_enabled,
}
}
}
impl RpcStateSnapshot {
const fn pending_count(&self) -> usize {
self.steering_count + self.follow_up_count
}
}
use crate::config::parse_queue_mode;
fn parse_streaming_behavior(value: Option<&Value>) -> Result<Option<StreamingBehavior>> {
let Some(value) = value else {
return Ok(None);
};
let Some(s) = value.as_str() else {
return Err(Error::validation("streamingBehavior must be a string"));
};
match s {
"steer" => Ok(Some(StreamingBehavior::Steer)),
"follow-up" | "followUp" => Ok(Some(StreamingBehavior::FollowUp)),
_ => Err(Error::validation(format!("Invalid streamingBehavior: {s}"))),
}
}
fn normalize_command_type(command_type: &str) -> &str {
match command_type {
"follow-up" | "followUp" | "queue-follow-up" | "queueFollowUp" => "follow_up",
"get-state" | "getState" => "get_state",
"set-model" | "setModel" => "set_model",
"set-steering-mode" | "setSteeringMode" => "set_steering_mode",
"set-follow-up-mode" | "setFollowUpMode" => "set_follow_up_mode",
"set-auto-compaction" | "setAutoCompaction" => "set_auto_compaction",
"set-auto-retry" | "setAutoRetry" => "set_auto_retry",
_ => command_type,
}
}
fn build_user_message(text: &str, images: &[ImageContent]) -> Message {
let timestamp = chrono::Utc::now().timestamp_millis();
if images.is_empty() {
return Message::User(UserMessage {
content: UserContent::Text(text.to_string()),
timestamp,
});
}
let mut blocks = vec![ContentBlock::Text(TextContent::new(text.to_string()))];
for image in images {
blocks.push(ContentBlock::Image(image.clone()));
}
Message::User(UserMessage {
content: UserContent::Blocks(blocks),
timestamp,
})
}
fn is_extension_command(message: &str, expanded: &str) -> bool {
message.trim_start().starts_with('/') && message == expanded
}
fn try_send_line_with_backpressure(tx: &mpsc::Sender<String>, mut line: String) -> bool {
loop {
match tx.try_send(line) {
Ok(()) => return true,
Err(mpsc::SendError::Full(unsent)) => {
line = unsent;
std::thread::sleep(Duration::from_millis(10));
}
Err(mpsc::SendError::Disconnected(_) | mpsc::SendError::Cancelled(_)) => {
return false;
}
}
}
}
#[derive(Debug)]
struct RpcSharedState {
steering: VecDeque<Message>,
follow_up: VecDeque<Message>,
steering_mode: QueueMode,
follow_up_mode: QueueMode,
auto_compaction_enabled: bool,
auto_retry_enabled: bool,
}
const MAX_RPC_PENDING_MESSAGES: usize = 128;
impl RpcSharedState {
fn new(config: &Config) -> Self {
Self {
steering: VecDeque::new(),
follow_up: VecDeque::new(),
steering_mode: config.steering_queue_mode(),
follow_up_mode: config.follow_up_queue_mode(),
auto_compaction_enabled: config.compaction_enabled(),
auto_retry_enabled: config.retry_enabled(),
}
}
fn pending_count(&self) -> usize {
self.steering.len() + self.follow_up.len()
}
fn push_steering(&mut self, message: Message) -> Result<()> {
if self.steering.len() >= MAX_RPC_PENDING_MESSAGES {
return Err(Error::session(
"Steering queue is full (Do you have too many pending commands?)",
));
}
self.steering.push_back(message);
Ok(())
}
fn push_follow_up(&mut self, message: Message) -> Result<()> {
if self.follow_up.len() >= MAX_RPC_PENDING_MESSAGES {
return Err(Error::session("Follow-up queue is full"));
}
self.follow_up.push_back(message);
Ok(())
}
fn pop_steering(&mut self) -> Vec<Message> {
match self.steering_mode {
QueueMode::All => self.steering.drain(..).collect(),
QueueMode::OneAtATime => self.steering.pop_front().into_iter().collect(),
}
}
fn pop_follow_up(&mut self) -> Vec<Message> {
match self.follow_up_mode {
QueueMode::All => self.follow_up.drain(..).collect(),
QueueMode::OneAtATime => self.follow_up.pop_front().into_iter().collect(),
}
}
}
struct RunningBash {
id: String,
abort_tx: oneshot::Sender<()>,
}
#[derive(Debug, Default)]
struct RpcUiBridgeState {
active: Option<ExtensionUiRequest>,
queue: VecDeque<ExtensionUiRequest>,
}
pub async fn run_stdio(mut session: AgentSession, options: RpcOptions) -> Result<()> {
session.agent.set_queue_modes(
options.config.steering_queue_mode(),
options.config.follow_up_queue_mode(),
);
let (in_tx, in_rx) = mpsc::channel::<String>(1024);
let (out_tx, out_rx) = std::sync::mpsc::channel::<String>();
std::thread::spawn(move || {
let stdin = io::stdin();
let mut reader = io::BufReader::new(stdin.lock());
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line) {
Ok(0) | Err(_) => break,
Ok(_) => {
let line_to_send = std::mem::take(&mut line);
if !try_send_line_with_backpressure(&in_tx, line_to_send) {
break;
}
}
}
}
});
std::thread::spawn(move || {
let stdout = io::stdout();
let mut writer = io::BufWriter::new(stdout.lock());
for line in out_rx {
if writer.write_all(line.as_bytes()).is_err() {
break;
}
if writer.write_all(b"\n").is_err() {
break;
}
if writer.flush().is_err() {
break;
}
}
});
run(session, options, in_rx, out_tx).await
}
#[allow(clippy::too_many_lines)]
#[allow(
clippy::significant_drop_tightening,
clippy::significant_drop_in_scrutinee
)]
pub async fn run(
session: AgentSession,
options: RpcOptions,
in_rx: mpsc::Receiver<String>,
out_tx: std::sync::mpsc::Sender<String>,
) -> Result<()> {
let cx = AgentCx::for_request();
let session_handle = Arc::clone(&session.session);
let session = Arc::new(Mutex::new(session));
let shared_state = Arc::new(Mutex::new(RpcSharedState::new(&options.config)));
let is_streaming = Arc::new(AtomicBool::new(false));
let is_compacting = Arc::new(AtomicBool::new(false));
let abort_handle: Arc<Mutex<Option<AbortHandle>>> = Arc::new(Mutex::new(None));
let bash_state: Arc<Mutex<Option<RunningBash>>> = Arc::new(Mutex::new(None));
let retry_abort = Arc::new(AtomicBool::new(false));
{
use futures::future::BoxFuture;
let steering_state = Arc::clone(&shared_state);
let follow_state = Arc::clone(&shared_state);
let steering_cx = cx.clone();
let follow_cx = cx.clone();
let mut guard = session
.lock(&cx)
.await
.map_err(|err| Error::session(format!("session lock failed: {err}")))?;
let steering_fetcher = move || -> BoxFuture<'static, Vec<Message>> {
let steering_state = Arc::clone(&steering_state);
let steering_cx = steering_cx.clone();
Box::pin(async move {
steering_state
.lock(&steering_cx)
.await
.map_or_else(|_| Vec::new(), |mut state| state.pop_steering())
})
};
let follow_fetcher = move || -> BoxFuture<'static, Vec<Message>> {
let follow_state = Arc::clone(&follow_state);
let follow_cx = follow_cx.clone();
Box::pin(async move {
follow_state
.lock(&follow_cx)
.await
.map_or_else(|_| Vec::new(), |mut state| state.pop_follow_up())
})
};
guard.agent.register_message_fetchers(
Some(Arc::new(steering_fetcher)),
Some(Arc::new(follow_fetcher)),
);
}
let rpc_extension_manager = {
let cx_ui = cx.clone();
let guard = session
.lock(&cx_ui)
.await
.map_err(|err| Error::session(format!("session lock failed: {err}")))?;
guard
.extensions
.as_ref()
.map(crate::extensions::ExtensionRegion::manager)
.cloned()
};
let rpc_ui_state: Option<Arc<Mutex<RpcUiBridgeState>>> = rpc_extension_manager
.as_ref()
.map(|_| Arc::new(Mutex::new(RpcUiBridgeState::default())));
if let Some(ref manager) = rpc_extension_manager {
let (extension_ui_tx, extension_ui_rx) =
asupersync::channel::mpsc::channel::<ExtensionUiRequest>(64);
manager.set_ui_sender(extension_ui_tx);
let out_tx_ui = out_tx.clone();
let ui_state = rpc_ui_state
.as_ref()
.map(Arc::clone)
.expect("rpc ui state should exist when extension manager exists");
let manager_ui = (*manager).clone();
let runtime_handle_ui = options.runtime_handle.clone();
options.runtime_handle.spawn(async move {
const MAX_UI_PENDING_REQUESTS: usize = 64;
let cx = AgentCx::for_request();
while let Ok(request) = extension_ui_rx.recv(&cx).await {
if request.expects_response() {
let emit_now = {
let Ok(mut guard) = ui_state.lock(&cx).await else {
return;
};
if guard.active.is_none() {
guard.active = Some(request.clone());
true
} else if guard.queue.len() < MAX_UI_PENDING_REQUESTS {
guard.queue.push_back(request.clone());
false
} else {
drop(guard);
let _ = manager_ui.respond_ui(ExtensionUiResponse {
id: request.id.clone(),
value: None,
cancelled: true,
});
false
}
};
if emit_now {
rpc_emit_extension_ui_request(
&runtime_handle_ui,
Arc::clone(&ui_state),
manager_ui.clone(),
out_tx_ui.clone(),
request,
);
}
} else {
let rpc_event = request.to_rpc_event();
let _ = out_tx_ui.send(event(&rpc_event));
}
}
});
}
while let Ok(line) = in_rx.recv(&cx).await {
if line.trim().is_empty() {
continue;
}
let parsed: Value = match serde_json::from_str(&line) {
Ok(v) => v,
Err(err) => {
let resp = response_error(None, "parse", format!("Failed to parse command: {err}"));
let _ = out_tx.send(resp);
continue;
}
};
let Some(command_type_raw) = parsed.get("type").and_then(Value::as_str) else {
let resp = response_error(None, "parse", "Missing command type".to_string());
let _ = out_tx.send(resp);
continue;
};
let command_type = normalize_command_type(command_type_raw);
let id = parsed.get("id").and_then(Value::as_str).map(str::to_string);
match command_type {
"prompt" => {
let Some(message) = parsed
.get("message")
.and_then(Value::as_str)
.map(String::from)
else {
let resp = response_error(id, "prompt", "Missing message".to_string());
let _ = out_tx.send(resp);
continue;
};
let images = match parse_prompt_images(parsed.get("images")) {
Ok(images) => images,
Err(err) => {
let resp = response_error_with_hints(id, "prompt", &err);
let _ = out_tx.send(resp);
continue;
}
};
let streaming_behavior =
match parse_streaming_behavior(parsed.get("streamingBehavior")) {
Ok(value) => value,
Err(err) => {
let resp = response_error_with_hints(id, "prompt", &err);
let _ = out_tx.send(resp);
continue;
}
};
let expanded = options.resources.expand_input(&message);
if is_streaming.load(Ordering::SeqCst) {
if streaming_behavior.is_none() {
let resp = response_error(
id,
"prompt",
"Agent is currently streaming; specify streamingBehavior".to_string(),
);
let _ = out_tx.send(resp);
continue;
}
let queued_result = {
let mut state = shared_state
.lock(&cx)
.await
.map_err(|err| Error::session(format!("state lock failed: {err}")))?;
match streaming_behavior {
Some(StreamingBehavior::Steer) => {
state.push_steering(build_user_message(&expanded, &images))
}
Some(StreamingBehavior::FollowUp) => {
state.push_follow_up(build_user_message(&expanded, &images))
}
None => Ok(()), }
};
match queued_result {
Ok(()) => {
let _ = out_tx.send(response_ok(id, "prompt", None));
}
Err(err) => {
let resp = response_error_with_hints(id, "prompt", &err);
let _ = out_tx.send(resp);
}
}
continue;
}
let _ = out_tx.send(response_ok(id, "prompt", None));
is_streaming.store(true, Ordering::SeqCst);
let out_tx = out_tx.clone();
let session = Arc::clone(&session);
let shared_state = Arc::clone(&shared_state);
let is_streaming = Arc::clone(&is_streaming);
let is_compacting = Arc::clone(&is_compacting);
let abort_handle_slot = Arc::clone(&abort_handle);
let retry_abort = retry_abort.clone();
let options = options.clone();
let expanded = expanded.clone();
let runtime_handle = options.runtime_handle.clone();
runtime_handle.spawn(async move {
let cx = AgentCx::for_request();
run_prompt_with_retry(
session,
shared_state,
is_streaming,
is_compacting,
abort_handle_slot,
out_tx,
retry_abort,
options,
expanded,
images,
cx,
)
.await;
});
}
"steer" => {
let Some(message) = parsed
.get("message")
.and_then(Value::as_str)
.map(String::from)
else {
let resp = response_error(id, "steer", "Missing message".to_string());
let _ = out_tx.send(resp);
continue;
};
let expanded = options.resources.expand_input(&message);
if is_extension_command(&message, &expanded) {
let resp = response_error(
id,
"steer",
"Extension commands are not allowed with steer".to_string(),
);
let _ = out_tx.send(resp);
continue;
}
if is_streaming.load(Ordering::SeqCst) {
let result = shared_state
.lock(&cx)
.await
.map_err(|err| Error::session(format!("state lock failed: {err}")))?
.push_steering(build_user_message(&expanded, &[]));
match result {
Ok(()) => {
let _ = out_tx.send(response_ok(id, "steer", None));
}
Err(err) => {
let _ = out_tx.send(response_error_with_hints(id, "steer", &err));
}
}
continue;
}
let _ = out_tx.send(response_ok(id, "steer", None));
is_streaming.store(true, Ordering::SeqCst);
let out_tx = out_tx.clone();
let session = Arc::clone(&session);
let shared_state = Arc::clone(&shared_state);
let is_streaming = Arc::clone(&is_streaming);
let is_compacting = Arc::clone(&is_compacting);
let abort_handle_slot = Arc::clone(&abort_handle);
let retry_abort = retry_abort.clone();
let options = options.clone();
let expanded = expanded.clone();
let runtime_handle = options.runtime_handle.clone();
runtime_handle.spawn(async move {
let cx = AgentCx::for_request();
run_prompt_with_retry(
session,
shared_state,
is_streaming,
is_compacting,
abort_handle_slot,
out_tx,
retry_abort,
options,
expanded,
Vec::new(),
cx,
)
.await;
});
}
"follow_up" => {
let Some(message) = parsed
.get("message")
.and_then(Value::as_str)
.map(String::from)
else {
let resp = response_error(id, "follow_up", "Missing message".to_string());
let _ = out_tx.send(resp);
continue;
};
let expanded = options.resources.expand_input(&message);
if is_extension_command(&message, &expanded) {
let resp = response_error(
id,
"follow_up",
"Extension commands are not allowed with follow_up".to_string(),
);
let _ = out_tx.send(resp);
continue;
}
if is_streaming.load(Ordering::SeqCst) {
let result = shared_state
.lock(&cx)
.await
.map_err(|err| Error::session(format!("state lock failed: {err}")))?
.push_follow_up(build_user_message(&expanded, &[]));
match result {
Ok(()) => {
let _ = out_tx.send(response_ok(id, "follow_up", None));
}
Err(err) => {
let _ = out_tx.send(response_error_with_hints(id, "follow_up", &err));
}
}
continue;
}
let _ = out_tx.send(response_ok(id, "follow_up", None));
is_streaming.store(true, Ordering::SeqCst);
let out_tx = out_tx.clone();
let session = Arc::clone(&session);
let shared_state = Arc::clone(&shared_state);
let is_streaming = Arc::clone(&is_streaming);
let is_compacting = Arc::clone(&is_compacting);
let abort_handle_slot = Arc::clone(&abort_handle);
let retry_abort = retry_abort.clone();
let options = options.clone();
let expanded = expanded.clone();
let runtime_handle = options.runtime_handle.clone();
runtime_handle.spawn(async move {
let cx = AgentCx::for_request();
run_prompt_with_retry(
session,
shared_state,
is_streaming,
is_compacting,
abort_handle_slot,
out_tx,
retry_abort,
options,
expanded,
Vec::new(),
cx,
)
.await;
});
}
"abort" => {
let handle = abort_handle
.lock(&cx)
.await
.map_err(|err| Error::session(format!("abort lock failed: {err}")))?
.clone();
if let Some(handle) = handle {
handle.abort();
}
let _ = out_tx.send(response_ok(id, "abort", None));
}
"get_state" => {
let snapshot = {
let state = shared_state
.lock(&cx)
.await
.map_err(|err| Error::session(format!("state lock failed: {err}")))?;
RpcStateSnapshot::from(&*state)
};
let data = {
let inner_session = session_handle.lock(&cx).await.map_err(|err| {
Error::session(format!("inner session lock failed: {err}"))
})?;
session_state(
&inner_session,
&options,
&snapshot,
is_streaming.load(Ordering::SeqCst),
is_compacting.load(Ordering::SeqCst),
)
};
let _ = out_tx.send(response_ok(id, "get_state", Some(data)));
}
"get_session_stats" => {
let data = {
let inner_session = session_handle.lock(&cx).await.map_err(|err| {
Error::session(format!("inner session lock failed: {err}"))
})?;
session_stats(&inner_session)
};
let _ = out_tx.send(response_ok(id, "get_session_stats", Some(data)));
}
"get_messages" => {
let messages = {
let inner_session = session_handle.lock(&cx).await.map_err(|err| {
Error::session(format!("inner session lock failed: {err}"))
})?;
inner_session
.entries_for_current_path()
.iter()
.filter_map(|entry| match entry {
crate::session::SessionEntry::Message(msg) => match msg.message {
SessionMessage::User { .. }
| SessionMessage::Assistant { .. }
| SessionMessage::ToolResult { .. }
| SessionMessage::BashExecution { .. } => Some(msg.message.clone()),
_ => None,
},
_ => None,
})
.collect::<Vec<_>>()
};
let messages = messages
.into_iter()
.map(rpc_session_message_value)
.collect::<Vec<_>>();
let _ = out_tx.send(response_ok(
id,
"get_messages",
Some(json!({ "messages": messages })),
));
}
"get_available_models" => {
let models = options
.available_models
.iter()
.map(rpc_model_from_entry)
.collect::<Vec<_>>();
let _ = out_tx.send(response_ok(
id,
"get_available_models",
Some(json!({ "models": models })),
));
}
"set_model" => {
let Some(provider) = parsed.get("provider").and_then(Value::as_str) else {
let _ = out_tx.send(response_error(
id,
"set_model",
"Missing provider".to_string(),
));
continue;
};
let Some(model_id) = parsed.get("modelId").and_then(Value::as_str) else {
let _ = out_tx.send(response_error(
id,
"set_model",
"Missing modelId".to_string(),
));
continue;
};
let Some(entry) = options
.available_models
.iter()
.find(|m| {
provider_ids_match(&m.model.provider, provider)
&& m.model.id.eq_ignore_ascii_case(model_id)
})
.cloned()
else {
let _ = out_tx.send(response_error(
id,
"set_model",
format!("Model not found: {provider}/{model_id}"),
));
continue;
};
let key = resolve_model_key(&options.auth, &entry);
if model_requires_configured_credential(&entry) && key.is_none() {
let err = Error::auth(format!(
"Missing credentials for {}/{}",
entry.model.provider, entry.model.id
));
let _ = out_tx.send(response_error_with_hints(id, "set_model", &err));
continue;
}
let result: Result<()> = async {
let mut guard = session
.lock(&cx)
.await
.map_err(|err| Error::session(format!("session lock failed: {err}")))?;
let provider_impl = providers::create_provider(
&entry,
guard
.extensions
.as_ref()
.map(crate::extensions::ExtensionRegion::manager),
)?;
guard.agent.set_provider(provider_impl);
guard.agent.stream_options_mut().api_key.clone_from(&key);
guard
.agent
.stream_options_mut()
.headers
.clone_from(&entry.headers);
apply_model_change(&mut guard, &entry).await?;
let current_thinking = guard
.agent
.stream_options()
.thinking_level
.unwrap_or_default();
let clamped = entry.clamp_thinking_level(current_thinking);
if clamped != current_thinking {
apply_thinking_level(&mut guard, clamped).await?;
}
Ok(())
}
.await;
match result {
Ok(()) => {
let _ = out_tx.send(response_ok(
id,
"set_model",
Some(rpc_model_from_entry(&entry)),
));
}
Err(err) => {
let _ = out_tx.send(response_error_with_hints(id, "set_model", &err));
}
}
}
"cycle_model" => {
let result = async {
let mut guard = session
.lock(&cx)
.await
.map_err(|err| Error::session(format!("session lock failed: {err}")))?;
cycle_model_for_rpc(&mut guard, &options).await
}
.await;
match result {
Ok(Some((entry, thinking_level, is_scoped))) => {
let _ = out_tx.send(response_ok(
id,
"cycle_model",
Some(json!({
"model": rpc_model_from_entry(&entry),
"thinkingLevel": thinking_level.to_string(),
"isScoped": is_scoped,
})),
));
}
Ok(None) => {
let _ =
out_tx.send(response_ok(id.clone(), "cycle_model", Some(Value::Null)));
}
Err(err) => {
let _ = out_tx.send(response_error_with_hints(id, "cycle_model", &err));
}
}
}
"set_thinking_level" => {
let Some(level) = parsed.get("level").and_then(Value::as_str) else {
let _ = out_tx.send(response_error(
id,
"set_thinking_level",
"Missing level".to_string(),
));
continue;
};
let level = match parse_thinking_level(level) {
Ok(level) => level,
Err(err) => {
let _ =
out_tx.send(response_error_with_hints(id, "set_thinking_level", &err));
continue;
}
};
{
let mut guard = session
.lock(&cx)
.await
.map_err(|err| Error::session(format!("session lock failed: {err}")))?;
let level = {
let inner_session = guard.session.lock(&cx).await.map_err(|err| {
Error::session(format!("inner session lock failed: {err}"))
})?;
current_model_entry(&inner_session, &options)
.map_or(level, |entry| entry.clamp_thinking_level(level))
};
if let Err(err) = apply_thinking_level(&mut guard, level).await {
let _ = out_tx.send(response_error_with_hints(
id.clone(),
"set_thinking_level",
&err,
));
continue;
}
}
let _ = out_tx.send(response_ok(id, "set_thinking_level", None));
}
"cycle_thinking_level" => {
let next = {
let mut guard = session
.lock(&cx)
.await
.map_err(|err| Error::session(format!("session lock failed: {err}")))?;
let entry = {
let inner_session = guard.session.lock(&cx).await.map_err(|err| {
Error::session(format!("inner session lock failed: {err}"))
})?;
current_model_entry(&inner_session, &options).cloned()
};
let Some(entry) = entry else {
let _ =
out_tx.send(response_ok(id, "cycle_thinking_level", Some(Value::Null)));
continue;
};
if !entry.model.reasoning {
let _ =
out_tx.send(response_ok(id, "cycle_thinking_level", Some(Value::Null)));
continue;
}
let levels = available_thinking_levels(&entry);
let current = guard
.agent
.stream_options()
.thinking_level
.unwrap_or_default();
let current_index = levels
.iter()
.position(|level| *level == current)
.unwrap_or(0);
let next = levels[(current_index + 1) % levels.len()];
if let Err(err) = apply_thinking_level(&mut guard, next).await {
let _ = out_tx.send(response_error_with_hints(
id.clone(),
"cycle_thinking_level",
&err,
));
continue;
}
next
};
let _ = out_tx.send(response_ok(
id,
"cycle_thinking_level",
Some(json!({ "level": next.to_string() })),
));
}
"set_steering_mode" => {
let Some(mode) = parsed.get("mode").and_then(Value::as_str) else {
let _ = out_tx.send(response_error(
id,
"set_steering_mode",
"Missing mode".to_string(),
));
continue;
};
let Some(mode) = parse_queue_mode(Some(mode)) else {
let _ = out_tx.send(response_error(
id,
"set_steering_mode",
"Invalid steering mode".to_string(),
));
continue;
};
let mut state = shared_state
.lock(&cx)
.await
.map_err(|err| Error::session(format!("state lock failed: {err}")))?;
state.steering_mode = mode;
drop(state);
let _ = out_tx.send(response_ok(id, "set_steering_mode", None));
}
"set_follow_up_mode" => {
let Some(mode) = parsed.get("mode").and_then(Value::as_str) else {
let _ = out_tx.send(response_error(
id,
"set_follow_up_mode",
"Missing mode".to_string(),
));
continue;
};
let Some(mode) = parse_queue_mode(Some(mode)) else {
let _ = out_tx.send(response_error(
id,
"set_follow_up_mode",
"Invalid follow-up mode".to_string(),
));
continue;
};
let mut state = shared_state
.lock(&cx)
.await
.map_err(|err| Error::session(format!("state lock failed: {err}")))?;
state.follow_up_mode = mode;
drop(state);
let _ = out_tx.send(response_ok(id, "set_follow_up_mode", None));
}
"set_auto_compaction" => {
let Some(enabled) = parsed.get("enabled").and_then(Value::as_bool) else {
let _ = out_tx.send(response_error(
id,
"set_auto_compaction",
"Missing enabled".to_string(),
));
continue;
};
let mut state = shared_state
.lock(&cx)
.await
.map_err(|err| Error::session(format!("state lock failed: {err}")))?;
state.auto_compaction_enabled = enabled;
drop(state);
let _ = out_tx.send(response_ok(id, "set_auto_compaction", None));
}
"set_auto_retry" => {
let Some(enabled) = parsed.get("enabled").and_then(Value::as_bool) else {
let _ = out_tx.send(response_error(
id,
"set_auto_retry",
"Missing enabled".to_string(),
));
continue;
};
let mut state = shared_state
.lock(&cx)
.await
.map_err(|err| Error::session(format!("state lock failed: {err}")))?;
state.auto_retry_enabled = enabled;
drop(state);
let _ = out_tx.send(response_ok(id, "set_auto_retry", None));
}
"abort_retry" => {
retry_abort.store(true, Ordering::SeqCst);
let _ = out_tx.send(response_ok(id, "abort_retry", None));
}
"set_session_name" => {
let Some(name) = parsed.get("name").and_then(Value::as_str) else {
let _ = out_tx.send(response_error(
id,
"set_session_name",
"Missing name".to_string(),
));
continue;
};
let result: Result<()> = async {
let mut guard = session
.lock(&cx)
.await
.map_err(|err| Error::session(format!("session lock failed: {err}")))?;
{
let mut inner_session = guard.session.lock(&cx).await.map_err(|err| {
Error::session(format!("inner session lock failed: {err}"))
})?;
inner_session.append_session_info(Some(name.to_string()));
}
guard.persist_session().await?;
Ok(())
}
.await;
match result {
Ok(()) => {
let _ = out_tx.send(response_ok(id, "set_session_name", None));
}
Err(err) => {
let _ =
out_tx.send(response_error_with_hints(id, "set_session_name", &err));
}
}
}
"get_last_assistant_text" => {
let text = {
let inner_session = session_handle.lock(&cx).await.map_err(|err| {
Error::session(format!("inner session lock failed: {err}"))
})?;
last_assistant_text(&inner_session)
};
let _ = out_tx.send(response_ok(
id,
"get_last_assistant_text",
Some(json!({ "text": text })),
));
}
"export_html" => {
let output_path = parsed
.get("outputPath")
.and_then(Value::as_str)
.map(str::to_string);
let snapshot = {
let guard = session
.lock(&cx)
.await
.map_err(|err| Error::session(format!("session lock failed: {err}")))?;
let inner = guard.session.lock(&cx).await.map_err(|err| {
Error::session(format!("inner session lock failed: {err}"))
})?;
inner.export_snapshot()
};
match export_html_snapshot(&snapshot, output_path.as_deref()).await {
Ok(path) => {
let _ = out_tx.send(response_ok(
id,
"export_html",
Some(json!({ "path": path })),
));
}
Err(err) => {
let _ = out_tx.send(response_error_with_hints(id, "export_html", &err));
}
}
}
"bash" => {
let Some(command) = parsed.get("command").and_then(Value::as_str) else {
let _ = out_tx.send(response_error(id, "bash", "Missing command".to_string()));
continue;
};
let mut running = bash_state
.lock(&cx)
.await
.map_err(|err| Error::session(format!("bash state lock failed: {err}")))?;
if running.is_some() {
let _ = out_tx.send(response_error(
id,
"bash",
"Bash command already running".to_string(),
));
continue;
}
let run_id = uuid::Uuid::new_v4().to_string();
let (abort_tx, abort_rx) = oneshot::channel();
*running = Some(RunningBash {
id: run_id.clone(),
abort_tx,
});
let out_tx = out_tx.clone();
let session = Arc::clone(&session);
let bash_state = Arc::clone(&bash_state);
let command = command.to_string();
let id_clone = id.clone();
let runtime_handle = options.runtime_handle.clone();
runtime_handle.spawn(async move {
let cx = AgentCx::for_request();
let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
let result = run_bash_rpc(&cwd, &command, abort_rx).await;
let response = match result {
Ok(result) => {
if let Ok(mut guard) = session.lock(&cx).await {
if let Ok(mut inner_session) = guard.session.lock(&cx).await {
inner_session.append_message(SessionMessage::BashExecution {
command: command.clone(),
output: result.output.clone(),
exit_code: result.exit_code,
cancelled: Some(result.cancelled),
truncated: Some(result.truncated),
full_output_path: result.full_output_path.clone(),
timestamp: Some(chrono::Utc::now().timestamp_millis()),
extra: std::collections::HashMap::default(),
});
}
let _ = guard.persist_session().await;
}
response_ok(
id_clone,
"bash",
Some(json!({
"output": result.output,
"exitCode": result.exit_code,
"cancelled": result.cancelled,
"truncated": result.truncated,
"fullOutputPath": result.full_output_path,
})),
)
}
Err(err) => response_error_with_hints(id_clone, "bash", &err),
};
let _ = out_tx.send(response);
if let Ok(mut running) = bash_state.lock(&cx).await {
if running.as_ref().is_some_and(|r| r.id == run_id) {
*running = None;
}
}
});
}
"abort_bash" => {
let mut running = bash_state
.lock(&cx)
.await
.map_err(|err| Error::session(format!("bash state lock failed: {err}")))?;
if let Some(running_bash) = running.take() {
let _ = running_bash.abort_tx.send(&cx, ());
}
let _ = out_tx.send(response_ok(id, "abort_bash", None));
}
"compact" => {
let custom_instructions = parsed
.get("customInstructions")
.and_then(Value::as_str)
.map(str::to_string);
let result: Result<Value> = async {
let mut guard = session
.lock(&cx)
.await
.map_err(|err| Error::session(format!("session lock failed: {err}")))?;
let path_entries = {
let mut inner_session = guard.session.lock(&cx).await.map_err(|err| {
Error::session(format!("inner session lock failed: {err}"))
})?;
inner_session.ensure_entry_ids();
inner_session
.entries_for_current_path()
.into_iter()
.cloned()
.collect::<Vec<_>>()
};
let key = guard
.agent
.stream_options()
.api_key
.as_deref()
.ok_or_else(|| Error::auth("Missing API key for compaction"))?;
let provider = guard.agent.provider();
let settings = ResolvedCompactionSettings {
enabled: options.config.compaction_enabled(),
reserve_tokens: options.config.compaction_reserve_tokens(),
keep_recent_tokens: options.config.compaction_keep_recent_tokens(),
..Default::default()
};
let prep = prepare_compaction(&path_entries, settings).ok_or_else(|| {
Error::session(
"Compaction not available (already compacted or missing IDs)",
)
})?;
is_compacting.store(true, Ordering::SeqCst);
let compact_res =
compact(prep, provider, key, custom_instructions.as_deref()).await;
is_compacting.store(false, Ordering::SeqCst);
let result_data = compact_res?;
let details_value = compaction_details_to_value(&result_data.details)?;
let messages = {
let mut inner_session = guard.session.lock(&cx).await.map_err(|err| {
Error::session(format!("inner session lock failed: {err}"))
})?;
inner_session.append_compaction(
result_data.summary.clone(),
result_data.first_kept_entry_id.clone(),
result_data.tokens_before,
Some(details_value.clone()),
None,
);
inner_session.to_messages_for_current_path()
};
guard.persist_session().await?;
guard.agent.replace_messages(messages);
Ok(json!({
"summary": result_data.summary,
"firstKeptEntryId": result_data.first_kept_entry_id,
"tokensBefore": result_data.tokens_before,
"details": details_value,
}))
}
.await;
match result {
Ok(data) => {
let _ = out_tx.send(response_ok(id, "compact", Some(data)));
}
Err(err) => {
let _ = out_tx.send(response_error_with_hints(id, "compact", &err));
}
}
}
"new_session" => {
let parent = parsed
.get("parentSession")
.and_then(Value::as_str)
.map(str::to_string);
{
let mut guard = session
.lock(&cx)
.await
.map_err(|err| Error::session(format!("session lock failed: {err}")))?;
let (session_dir, provider, model_id, thinking_level) = {
let inner_session = guard.session.lock(&cx).await.map_err(|err| {
Error::session(format!("inner session lock failed: {err}"))
})?;
(
inner_session.session_dir.clone(),
inner_session.header.provider.clone(),
inner_session.header.model_id.clone(),
inner_session.header.thinking_level.clone(),
)
};
let mut new_session = if guard.save_enabled() {
crate::session::Session::create_with_dir(session_dir)
} else {
crate::session::Session::in_memory()
};
new_session.header.parent_session = parent;
new_session.header.provider.clone_from(&provider);
new_session.header.model_id.clone_from(&model_id);
new_session
.header
.thinking_level
.clone_from(&thinking_level);
let session_id = new_session.header.id.clone();
{
let mut inner_session = guard.session.lock(&cx).await.map_err(|err| {
Error::session(format!("inner session lock failed: {err}"))
})?;
*inner_session = new_session;
}
guard.agent.clear_messages();
guard.agent.stream_options_mut().session_id = Some(session_id);
}
{
let mut state = shared_state
.lock(&cx)
.await
.map_err(|err| Error::session(format!("state lock failed: {err}")))?;
state.steering.clear();
state.follow_up.clear();
}
let _ = out_tx.send(response_ok(
id,
"new_session",
Some(json!({ "cancelled": false })),
));
}
"switch_session" => {
let Some(session_path) = parsed.get("sessionPath").and_then(Value::as_str) else {
let _ = out_tx.send(response_error(
id,
"switch_session",
"Missing sessionPath".to_string(),
));
continue;
};
let loaded = crate::session::Session::open(session_path).await;
match loaded {
Ok(new_session) => {
let messages = new_session.to_messages_for_current_path();
let session_id = new_session.header.id.clone();
let mut guard = session
.lock(&cx)
.await
.map_err(|err| Error::session(format!("session lock failed: {err}")))?;
{
let mut inner_session =
guard.session.lock(&cx).await.map_err(|err| {
Error::session(format!("inner session lock failed: {err}"))
})?;
*inner_session = new_session;
}
guard.agent.replace_messages(messages);
guard.agent.stream_options_mut().session_id = Some(session_id);
let _ = out_tx.send(response_ok(
id,
"switch_session",
Some(json!({ "cancelled": false })),
));
let mut state = shared_state
.lock(&cx)
.await
.map_err(|err| Error::session(format!("state lock failed: {err}")))?;
state.steering.clear();
state.follow_up.clear();
}
Err(err) => {
let _ = out_tx.send(response_error_with_hints(id, "switch_session", &err));
}
}
}
"fork" => {
let Some(entry_id) = parsed.get("entryId").and_then(Value::as_str) else {
let _ = out_tx.send(response_error(id, "fork", "Missing entryId".to_string()));
continue;
};
let result: Result<String> =
async {
let (fork_plan, parent_path, session_dir, save_enabled, header_snapshot) = {
let guard = session.lock(&cx).await.map_err(|err| {
Error::session(format!("session lock failed: {err}"))
})?;
let inner = guard.session.lock(&cx).await.map_err(|err| {
Error::session(format!("inner session lock failed: {err}"))
})?;
let plan = inner.plan_fork_from_user_message(entry_id)?;
let parent_path = inner.path.as_ref().map(|p| p.display().to_string());
let session_dir = inner.session_dir.clone();
let header = inner.header.clone();
(plan, parent_path, session_dir, guard.save_enabled(), header)
};
let crate::session::ForkPlan {
entries,
leaf_id,
selected_text,
} = fork_plan;
let mut new_session = if save_enabled {
crate::session::Session::create_with_dir(session_dir)
} else {
crate::session::Session::in_memory()
};
new_session.header.parent_session = parent_path;
new_session
.header
.provider
.clone_from(&header_snapshot.provider);
new_session
.header
.model_id
.clone_from(&header_snapshot.model_id);
new_session
.header
.thinking_level
.clone_from(&header_snapshot.thinking_level);
new_session.entries = entries;
new_session.leaf_id = leaf_id;
new_session.ensure_entry_ids();
let messages = new_session.to_messages_for_current_path();
let session_id = new_session.header.id.clone();
{
let mut guard = session.lock(&cx).await.map_err(|err| {
Error::session(format!("session lock failed: {err}"))
})?;
let mut inner = guard.session.lock(&cx).await.map_err(|err| {
Error::session(format!("inner session lock failed: {err}"))
})?;
*inner = new_session;
drop(inner);
guard.agent.replace_messages(messages);
guard.agent.stream_options_mut().session_id = Some(session_id);
}
{
let mut state = shared_state.lock(&cx).await.map_err(|err| {
Error::session(format!("state lock failed: {err}"))
})?;
state.steering.clear();
state.follow_up.clear();
}
Ok(selected_text)
}
.await;
match result {
Ok(selected_text) => {
let _ = out_tx.send(response_ok(
id,
"fork",
Some(json!({ "text": selected_text, "cancelled": false })),
));
}
Err(err) => {
let _ = out_tx.send(response_error_with_hints(id, "fork", &err));
}
}
}
"get_fork_messages" => {
let path_entries = {
let guard = session
.lock(&cx)
.await
.map_err(|err| Error::session(format!("session lock failed: {err}")))?;
let inner_session = guard.session.lock(&cx).await.map_err(|err| {
Error::session(format!("inner session lock failed: {err}"))
})?;
inner_session
.entries_for_current_path()
.into_iter()
.cloned()
.collect::<Vec<_>>()
};
let messages = fork_messages_from_entries(&path_entries);
let _ = out_tx.send(response_ok(
id,
"get_fork_messages",
Some(json!({ "messages": messages })),
));
}
"get_commands" => {
let commands = options.resources.list_commands();
let _ = out_tx.send(response_ok(
id,
"get_commands",
Some(json!({ "commands": commands })),
));
}
"extension_ui_response" => {
if let (Some(manager), Some(ui_state)) =
(rpc_extension_manager.as_ref(), rpc_ui_state.as_ref())
{
let Some(request_id) = rpc_parse_extension_ui_response_id(&parsed) else {
let _ = out_tx.send(response_error(
id,
"extension_ui_response",
"Missing requestId (or id) field",
));
continue;
};
let (response, next_request) = {
let Ok(mut guard) = ui_state.lock(&cx).await else {
let _ = out_tx.send(response_error(
id,
"extension_ui_response",
"Extension UI bridge unavailable",
));
continue;
};
let Some(active) = guard.active.clone() else {
let _ = out_tx.send(response_error(
id,
"extension_ui_response",
"No active extension UI request",
));
continue;
};
if active.id != request_id {
let _ = out_tx.send(response_error(
id,
"extension_ui_response",
format!(
"Unexpected requestId: {request_id} (active: {})",
active.id
),
));
continue;
}
let response = match rpc_parse_extension_ui_response(&parsed, &active) {
Ok(response) => response,
Err(message) => {
let _ = out_tx.send(response_error(
id,
"extension_ui_response",
message,
));
continue;
}
};
guard.active = None;
let next = guard.queue.pop_front();
if let Some(ref next) = next {
guard.active = Some(next.clone());
}
(response, next)
};
let resolved = manager.respond_ui(response);
let _ = out_tx.send(response_ok(
id,
"extension_ui_response",
Some(json!({ "resolved": resolved })),
));
if let Some(next) = next_request {
rpc_emit_extension_ui_request(
&options.runtime_handle,
Arc::clone(ui_state),
(*manager).clone(),
out_tx.clone(),
next,
);
}
} else {
let _ = out_tx.send(response_ok(id, "extension_ui_response", None));
}
}
_ => {
let _ = out_tx.send(response_error(
id,
command_type_raw,
format!("Unknown command: {command_type_raw}"),
));
}
}
}
let extension_region = session
.lock(&cx)
.await
.ok()
.and_then(|mut guard| guard.extensions.take());
if let Some(ext) = extension_region {
ext.shutdown().await;
}
Ok(())
}
#[allow(clippy::too_many_lines)]
async fn run_prompt_with_retry(
session: Arc<Mutex<AgentSession>>,
shared_state: Arc<Mutex<RpcSharedState>>,
is_streaming: Arc<AtomicBool>,
is_compacting: Arc<AtomicBool>,
abort_handle_slot: Arc<Mutex<Option<AbortHandle>>>,
out_tx: std::sync::mpsc::Sender<String>,
retry_abort: Arc<AtomicBool>,
options: RpcOptions,
message: String,
images: Vec<ImageContent>,
cx: AgentCx,
) {
retry_abort.store(false, Ordering::SeqCst);
is_streaming.store(true, Ordering::SeqCst);
let max_retries = options.config.retry_max_retries();
let mut retry_count: u32 = 0;
let mut success = false;
let mut final_error: Option<String> = None;
let mut final_error_hints: Option<Value> = None;
loop {
let (abort_handle, abort_signal) = AbortHandle::new();
if let Ok(mut guard) = OwnedMutexGuard::lock(Arc::clone(&abort_handle_slot), &cx).await {
*guard = Some(abort_handle);
} else {
is_streaming.store(false, Ordering::SeqCst);
return;
}
let runtime_for_events = options.runtime_handle.clone();
let result = {
let mut guard = match OwnedMutexGuard::lock(Arc::clone(&session), &cx).await {
Ok(guard) => guard,
Err(err) => {
final_error = Some(format!("session lock failed: {err}"));
final_error_hints = None;
break;
}
};
let extensions = guard.extensions.as_ref().map(|r| r.manager().clone());
let runtime_for_events_handler = runtime_for_events.clone();
let event_tx = out_tx.clone();
let coalescer = extensions
.as_ref()
.map(|m| crate::extensions::EventCoalescer::new(m.clone()));
let event_handler = move |event: AgentEvent| {
let serialized = if let AgentEvent::AgentEnd {
messages, error, ..
} = &event
{
json!({
"type": "agent_end",
"messages": messages,
"error": error,
})
.to_string()
} else {
serde_json::to_string(&event).unwrap_or_else(|err| {
json!({
"type": "event_serialize_error",
"error": err.to_string(),
})
.to_string()
})
};
let _ = event_tx.send(serialized);
if let Some(coal) = &coalescer {
coal.dispatch_agent_event_lazy(&event, &runtime_for_events_handler);
}
};
if images.is_empty() {
guard
.run_text_with_abort(message.clone(), Some(abort_signal), event_handler)
.await
} else {
let mut blocks = vec![ContentBlock::Text(TextContent::new(message.clone()))];
for image in &images {
blocks.push(ContentBlock::Image(image.clone()));
}
guard
.run_with_content_with_abort(blocks, Some(abort_signal), event_handler)
.await
}
};
if let Ok(mut guard) = OwnedMutexGuard::lock(Arc::clone(&abort_handle_slot), &cx).await {
*guard = None;
}
match result {
Ok(message) => {
if matches!(message.stop_reason, StopReason::Error | StopReason::Aborted) {
final_error = message
.error_message
.clone()
.or_else(|| Some("Request error".to_string()));
final_error_hints = None;
if message.stop_reason == StopReason::Aborted {
break;
}
if let Some(ref err_msg) = final_error {
let context_window = if let Ok(guard) =
OwnedMutexGuard::lock(Arc::clone(&session), &cx).await
{
guard.session.lock(&cx).await.map_or(None, |inner| {
current_model_entry(&inner, &options)
.map(|e| e.model.context_window)
})
} else {
None
};
if !crate::error::is_retryable_error(
err_msg,
Some(message.usage.input),
context_window,
) {
break;
}
}
} else {
success = true;
break;
}
}
Err(err) => {
let err_str = err.to_string();
if !crate::error::is_retryable_error(&err_str, None, None) {
final_error = Some(err_str);
final_error_hints = Some(error_hints_value(&err));
break;
}
final_error = Some(err_str);
final_error_hints = Some(error_hints_value(&err));
}
}
let retry_enabled = OwnedMutexGuard::lock(Arc::clone(&shared_state), &cx)
.await
.is_ok_and(|state| state.auto_retry_enabled);
if !retry_enabled || retry_count >= max_retries {
break;
}
retry_count += 1;
let delay_ms = retry_delay_ms(&options.config, retry_count);
let error_message = final_error
.clone()
.unwrap_or_else(|| "Request error".to_string());
let _ = out_tx.send(event(&json!({
"type": "auto_retry_start",
"attempt": retry_count,
"maxAttempts": max_retries,
"delayMs": delay_ms,
"errorMessage": error_message,
})));
let delay = Duration::from_millis(delay_ms as u64);
let start = std::time::Instant::now();
while start.elapsed() < delay {
if retry_abort.load(Ordering::SeqCst) {
break;
}
sleep(wall_now(), Duration::from_millis(50)).await;
}
if retry_abort.load(Ordering::SeqCst) {
final_error = Some("Retry aborted".to_string());
break;
}
}
if retry_count > 0 {
let _ = out_tx.send(event(&json!({
"type": "auto_retry_end",
"success": success,
"attempt": retry_count,
"finalError": if success { Value::Null } else { json!(final_error.clone()) },
})));
}
is_streaming.store(false, Ordering::SeqCst);
if !success {
if let Some(err) = final_error {
let mut payload = json!({
"type": "agent_end",
"messages": [],
"error": err
});
if let Some(hints) = final_error_hints {
payload["errorHints"] = hints;
}
let _ = out_tx.send(event(&payload));
}
return;
}
let auto_compaction_enabled = OwnedMutexGuard::lock(Arc::clone(&shared_state), &cx)
.await
.is_ok_and(|state| state.auto_compaction_enabled);
if auto_compaction_enabled {
maybe_auto_compact(session, options, is_compacting, out_tx).await;
}
}
fn response_ok(id: Option<String>, command: &str, data: Option<Value>) -> String {
let mut resp = json!({
"type": "response",
"command": command,
"success": true,
});
if let Some(id) = id {
resp["id"] = Value::String(id);
}
if let Some(data) = data {
resp["data"] = data;
}
resp.to_string()
}
fn response_error(id: Option<String>, command: &str, error: impl Into<String>) -> String {
let mut resp = json!({
"type": "response",
"command": command,
"success": false,
"error": error.into(),
});
if let Some(id) = id {
resp["id"] = Value::String(id);
}
resp.to_string()
}
fn response_error_with_hints(id: Option<String>, command: &str, error: &Error) -> String {
let mut resp = json!({
"type": "response",
"command": command,
"success": false,
"error": error.to_string(),
"errorHints": error_hints_value(error),
});
if let Some(id) = id {
resp["id"] = Value::String(id);
}
resp.to_string()
}
fn event(value: &Value) -> String {
value.to_string()
}
fn rpc_emit_extension_ui_request(
runtime_handle: &RuntimeHandle,
ui_state: Arc<Mutex<RpcUiBridgeState>>,
manager: ExtensionManager,
out_tx_ui: std::sync::mpsc::Sender<String>,
request: ExtensionUiRequest,
) {
let rpc_event = request.to_rpc_event();
let _ = out_tx_ui.send(event(&rpc_event));
if !request.expects_response() {
return;
}
let Some(timeout_ms) = request.effective_timeout_ms() else {
return;
};
let fire_ms = timeout_ms.saturating_sub(10).max(1);
let request_id = request.id;
let ui_state_timeout = Arc::clone(&ui_state);
let manager_timeout = manager;
let out_tx_timeout = out_tx_ui;
let runtime_handle_inner = runtime_handle.clone();
runtime_handle.spawn(async move {
sleep(wall_now(), Duration::from_millis(fire_ms)).await;
let cx = AgentCx::for_request();
let next = {
let Ok(mut guard) = ui_state_timeout.lock(cx.cx()).await else {
return;
};
let Some(active) = guard.active.as_ref() else {
return;
};
if active.id != request_id {
return;
}
let _ = manager_timeout.respond_ui(ExtensionUiResponse {
id: request_id,
value: None,
cancelled: true,
});
guard.active = None;
let next = guard.queue.pop_front();
if let Some(ref next) = next {
guard.active = Some(next.clone());
}
next
};
if let Some(next) = next {
rpc_emit_extension_ui_request(
&runtime_handle_inner,
ui_state_timeout,
manager_timeout,
out_tx_timeout,
next,
);
}
});
}
fn rpc_parse_extension_ui_response_id(parsed: &Value) -> Option<String> {
let request_id = parsed
.get("requestId")
.and_then(Value::as_str)
.map(str::trim)
.filter(|value| !value.is_empty())
.map(String::from);
request_id.or_else(|| {
parsed
.get("id")
.and_then(Value::as_str)
.map(str::trim)
.filter(|value| !value.is_empty())
.map(String::from)
})
}
fn rpc_parse_extension_ui_response(
parsed: &Value,
active: &ExtensionUiRequest,
) -> std::result::Result<ExtensionUiResponse, String> {
let cancelled = parsed
.get("cancelled")
.and_then(Value::as_bool)
.unwrap_or(false);
if cancelled {
return Ok(ExtensionUiResponse {
id: active.id.clone(),
value: None,
cancelled: true,
});
}
match active.method.as_str() {
"confirm" => {
let value = parsed
.get("confirmed")
.and_then(Value::as_bool)
.or_else(|| parsed.get("value").and_then(Value::as_bool))
.ok_or_else(|| "confirm requires boolean `confirmed` (or `value`)".to_string())?;
Ok(ExtensionUiResponse {
id: active.id.clone(),
value: Some(Value::Bool(value)),
cancelled: false,
})
}
"select" => {
let Some(value) = parsed.get("value") else {
return Err("select requires `value` field".to_string());
};
let options = active
.payload
.get("options")
.and_then(Value::as_array)
.ok_or_else(|| "select request missing `options` array".to_string())?;
let mut allowed = Vec::with_capacity(options.len());
for opt in options {
match opt {
Value::String(s) => allowed.push(Value::String(s.clone())),
Value::Object(map) => {
let label = map
.get("label")
.and_then(Value::as_str)
.unwrap_or("")
.trim();
if label.is_empty() {
continue;
}
if let Some(v) = map.get("value") {
allowed.push(v.clone());
} else {
allowed.push(Value::String(label.to_string()));
}
}
_ => {}
}
}
if !allowed.iter().any(|candidate| candidate == value) {
return Err("select response value did not match any option".to_string());
}
Ok(ExtensionUiResponse {
id: active.id.clone(),
value: Some(value.clone()),
cancelled: false,
})
}
"input" | "editor" => {
let Some(value) = parsed.get("value") else {
return Err(format!("{} requires `value` field", active.method));
};
if !value.is_string() {
return Err(format!("{} requires string `value`", active.method));
}
Ok(ExtensionUiResponse {
id: active.id.clone(),
value: Some(value.clone()),
cancelled: false,
})
}
"notify" => Ok(ExtensionUiResponse {
id: active.id.clone(),
value: None,
cancelled: false,
}),
other => Err(format!("Unsupported extension UI method: {other}")),
}
}
#[cfg(test)]
mod ui_bridge_tests {
use super::*;
#[test]
fn parse_extension_ui_response_id_prefers_request_id() {
let value = json!({"type":"extension_ui_response","id":"legacy","requestId":"canonical"});
assert_eq!(
rpc_parse_extension_ui_response_id(&value),
Some("canonical".to_string())
);
}
#[test]
fn parse_extension_ui_response_id_accepts_id_alias() {
let value = json!({"type":"extension_ui_response","id":"legacy"});
assert_eq!(
rpc_parse_extension_ui_response_id(&value),
Some("legacy".to_string())
);
}
#[test]
fn parse_confirm_response_accepts_confirmed_alias() {
let active = ExtensionUiRequest::new("req-1", "confirm", json!({"title":"t"}));
let value = json!({"type":"extension_ui_response","requestId":"req-1","confirmed":true});
let resp = rpc_parse_extension_ui_response(&value, &active).expect("parse confirm");
assert!(!resp.cancelled);
assert_eq!(resp.value, Some(json!(true)));
}
#[test]
fn parse_confirm_response_accepts_value_bool() {
let active = ExtensionUiRequest::new("req-1", "confirm", json!({"title":"t"}));
let value = json!({"type":"extension_ui_response","requestId":"req-1","value":false});
let resp = rpc_parse_extension_ui_response(&value, &active).expect("parse confirm");
assert!(!resp.cancelled);
assert_eq!(resp.value, Some(json!(false)));
}
#[test]
fn parse_cancelled_response_wins_over_value() {
let active = ExtensionUiRequest::new("req-1", "confirm", json!({"title":"t"}));
let value = json!({"type":"extension_ui_response","requestId":"req-1","cancelled":true,"value":true});
let resp = rpc_parse_extension_ui_response(&value, &active).expect("parse cancel");
assert!(resp.cancelled);
assert_eq!(resp.value, None);
}
#[test]
fn parse_select_response_validates_against_options() {
let active = ExtensionUiRequest::new(
"req-1",
"select",
json!({"title":"pick","options":["A","B"]}),
);
let ok_value = json!({"type":"extension_ui_response","requestId":"req-1","value":"B"});
let ok = rpc_parse_extension_ui_response(&ok_value, &active).expect("parse select ok");
assert_eq!(ok.value, Some(json!("B")));
let bad_value = json!({"type":"extension_ui_response","requestId":"req-1","value":"C"});
assert!(
rpc_parse_extension_ui_response(&bad_value, &active).is_err(),
"invalid selection should error"
);
}
#[test]
fn parse_input_requires_string_value() {
let active = ExtensionUiRequest::new("req-1", "input", json!({"title":"t"}));
let ok_value = json!({"type":"extension_ui_response","requestId":"req-1","value":"hi"});
let ok = rpc_parse_extension_ui_response(&ok_value, &active).expect("parse input ok");
assert_eq!(ok.value, Some(json!("hi")));
let bad_value = json!({"type":"extension_ui_response","requestId":"req-1","value":123});
assert!(
rpc_parse_extension_ui_response(&bad_value, &active).is_err(),
"non-string input should error"
);
}
#[test]
fn parse_editor_requires_string_value() {
let active = ExtensionUiRequest::new("req-1", "editor", json!({"title":"t"}));
let ok = json!({"requestId":"req-1","value":"multi\nline"});
let resp = rpc_parse_extension_ui_response(&ok, &active).expect("editor ok");
assert_eq!(resp.value, Some(json!("multi\nline")));
let bad = json!({"requestId":"req-1","value":42});
assert!(
rpc_parse_extension_ui_response(&bad, &active).is_err(),
"editor needs string"
);
}
#[test]
fn parse_notify_returns_no_value() {
let active = ExtensionUiRequest::new("req-1", "notify", json!({"title":"t"}));
let val = json!({"requestId":"req-1"});
let resp = rpc_parse_extension_ui_response(&val, &active).expect("notify ok");
assert!(!resp.cancelled);
assert!(resp.value.is_none());
}
#[test]
fn parse_unsupported_method_errors() {
let active = ExtensionUiRequest::new("req-1", "custom_method", json!({}));
let val = json!({"requestId":"req-1","value":"x"});
let err = rpc_parse_extension_ui_response(&val, &active).unwrap_err();
assert!(err.contains("Unsupported"), "err={err}");
}
#[test]
fn parse_select_missing_value_field() {
let active =
ExtensionUiRequest::new("req-1", "select", json!({"title":"pick","options":["A"]}));
let val = json!({"requestId":"req-1"});
let err = rpc_parse_extension_ui_response(&val, &active).unwrap_err();
assert!(err.contains("value"), "err={err}");
}
#[test]
fn parse_confirm_missing_value_errors() {
let active = ExtensionUiRequest::new("req-1", "confirm", json!({"title":"t"}));
let val = json!({"requestId":"req-1"});
let err = rpc_parse_extension_ui_response(&val, &active).unwrap_err();
assert!(err.contains("confirm"), "err={err}");
}
#[test]
fn parse_select_with_label_value_objects() {
let active = ExtensionUiRequest::new(
"req-1",
"select",
json!({
"title": "pick",
"options": [
{"label": "Alpha", "value": "a"},
{"label": "Beta", "value": "b"},
]
}),
);
let val = json!({"requestId":"req-1","value":"a"});
let resp = rpc_parse_extension_ui_response(&val, &active).expect("select by value");
assert_eq!(resp.value, Some(json!("a")));
}
#[test]
fn parse_id_rejects_empty_and_whitespace() {
let val = json!({"requestId":" ","id":""});
assert!(rpc_parse_extension_ui_response_id(&val).is_none());
}
#[test]
fn bridge_state_default_is_empty() {
let state = RpcUiBridgeState::default();
assert!(state.active.is_none());
assert!(state.queue.is_empty());
}
}
fn error_hints_value(error: &Error) -> Value {
let hint = error_hints::hints_for_error(error);
json!({
"summary": hint.summary,
"hints": hint.hints,
"contextFields": hint.context_fields,
})
}
fn rpc_session_message_value(message: SessionMessage) -> Value {
let mut value =
serde_json::to_value(message).expect("SessionMessage should always serialize to JSON");
rpc_flatten_content_blocks(&mut value);
value
}
fn rpc_flatten_content_blocks(value: &mut Value) {
let Value::Object(message_obj) = value else {
return;
};
let Some(content) = message_obj.get_mut("content") else {
return;
};
let Value::Array(blocks) = content else {
return;
};
for block in blocks {
let Value::Object(block_obj) = block else {
continue;
};
let Some(inner) = block_obj.remove("0") else {
continue;
};
let Value::Object(inner_obj) = inner else {
block_obj.insert("0".to_string(), inner);
continue;
};
for (key, value) in inner_obj {
block_obj.entry(key).or_insert(value);
}
}
}
fn retry_delay_ms(config: &Config, attempt: u32) -> u32 {
let base = u64::from(config.retry_base_delay_ms());
let max = u64::from(config.retry_max_delay_ms());
let shift = attempt.saturating_sub(1);
let multiplier = 1u64.checked_shl(shift).unwrap_or(u64::MAX);
let delay = base.saturating_mul(multiplier).min(max);
u32::try_from(delay).unwrap_or(u32::MAX)
}
#[cfg(test)]
mod retry_tests {
use super::*;
use crate::agent::{Agent, AgentConfig, AgentSession};
use crate::model::{AssistantMessage, Usage};
use crate::provider::Provider;
use crate::resources::ResourceLoader;
use crate::session::Session;
use crate::tools::ToolRegistry;
use async_trait::async_trait;
use futures::stream;
use std::path::Path;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
#[derive(Debug)]
struct FlakyProvider {
calls: AtomicUsize,
}
impl FlakyProvider {
const fn new() -> Self {
Self {
calls: AtomicUsize::new(0),
}
}
}
#[async_trait]
#[allow(clippy::unnecessary_literal_bound)]
impl Provider for FlakyProvider {
fn name(&self) -> &str {
"test-provider"
}
fn api(&self) -> &str {
"test-api"
}
fn model_id(&self) -> &str {
"test-model"
}
async fn stream(
&self,
_context: &crate::provider::Context<'_>,
_options: &crate::provider::StreamOptions,
) -> crate::error::Result<
Pin<
Box<
dyn futures::Stream<Item = crate::error::Result<crate::model::StreamEvent>>
+ Send,
>,
>,
> {
let call = self.calls.fetch_add(1, Ordering::SeqCst);
let mut partial = AssistantMessage {
content: Vec::new(),
api: self.api().to_string(),
provider: self.name().to_string(),
model: self.model_id().to_string(),
usage: Usage::default(),
stop_reason: StopReason::Stop,
error_message: None,
timestamp: 0,
};
let events = if call == 0 {
partial.stop_reason = StopReason::Error;
partial.error_message = Some("server error".to_string());
vec![
Ok(crate::model::StreamEvent::Start {
partial: partial.clone(),
}),
Ok(crate::model::StreamEvent::Error {
reason: StopReason::Error,
error: partial,
}),
]
} else {
vec![
Ok(crate::model::StreamEvent::Start {
partial: partial.clone(),
}),
Ok(crate::model::StreamEvent::Done {
reason: StopReason::Stop,
message: partial,
}),
]
};
Ok(Box::pin(stream::iter(events)))
}
}
#[derive(Debug)]
struct AlwaysErrorProvider;
#[async_trait]
#[allow(clippy::unnecessary_literal_bound)]
impl Provider for AlwaysErrorProvider {
fn name(&self) -> &str {
"test-provider"
}
fn api(&self) -> &str {
"test-api"
}
fn model_id(&self) -> &str {
"test-model"
}
async fn stream(
&self,
_context: &crate::provider::Context<'_>,
_options: &crate::provider::StreamOptions,
) -> crate::error::Result<
Pin<
Box<
dyn futures::Stream<Item = crate::error::Result<crate::model::StreamEvent>>
+ Send,
>,
>,
> {
let mut partial = AssistantMessage {
content: Vec::new(),
api: self.api().to_string(),
provider: self.name().to_string(),
model: self.model_id().to_string(),
usage: Usage::default(),
stop_reason: StopReason::Error,
error_message: Some("server error".to_string()),
timestamp: 0,
};
let events = vec![
Ok(crate::model::StreamEvent::Start {
partial: partial.clone(),
}),
Ok(crate::model::StreamEvent::Error {
reason: StopReason::Error,
error: {
partial.stop_reason = StopReason::Error;
partial
},
}),
];
Ok(Box::pin(stream::iter(events)))
}
}
#[test]
fn rpc_auto_retry_retries_then_succeeds() {
let runtime = asupersync::runtime::RuntimeBuilder::new()
.blocking_threads(1, 8)
.build()
.expect("runtime build");
let runtime_handle = runtime.handle();
runtime.block_on(async move {
let provider = Arc::new(FlakyProvider::new());
let tools = ToolRegistry::new(&[], Path::new("."), None);
let agent = Agent::new(provider, tools, AgentConfig::default());
let inner_session = Arc::new(Mutex::new(Session::in_memory()));
let agent_session = AgentSession::new(
agent,
inner_session,
false,
crate::compaction::ResolvedCompactionSettings::default(),
);
let session = Arc::new(Mutex::new(agent_session));
let mut config = Config::default();
config.retry = Some(crate::config::RetrySettings {
enabled: Some(true),
max_retries: Some(1),
base_delay_ms: Some(1),
max_delay_ms: Some(1),
});
let mut shared = RpcSharedState::new(&config);
shared.auto_compaction_enabled = false;
let shared_state = Arc::new(Mutex::new(shared));
let is_streaming = Arc::new(AtomicBool::new(false));
let is_compacting = Arc::new(AtomicBool::new(false));
let abort_handle_slot: Arc<Mutex<Option<AbortHandle>>> = Arc::new(Mutex::new(None));
let retry_abort = Arc::new(AtomicBool::new(false));
let (out_tx, out_rx) = std::sync::mpsc::channel::<String>();
let auth_path = tempfile::tempdir()
.expect("tempdir")
.path()
.join("auth.json");
let auth = AuthStorage::load(auth_path).expect("auth load");
let options = RpcOptions {
config,
resources: ResourceLoader::empty(false),
available_models: Vec::new(),
scoped_models: Vec::new(),
auth,
runtime_handle,
};
run_prompt_with_retry(
session,
shared_state,
is_streaming,
is_compacting,
abort_handle_slot,
out_tx,
retry_abort,
options,
"hello".to_string(),
Vec::new(),
AgentCx::for_request(),
)
.await;
let mut saw_retry_start = false;
let mut saw_retry_end_success = false;
for line in out_rx.try_iter() {
let Ok(value) = serde_json::from_str::<Value>(&line) else {
continue;
};
let Some(kind) = value.get("type").and_then(Value::as_str) else {
continue;
};
match kind {
"auto_retry_start" => {
saw_retry_start = true;
}
"auto_retry_end" => {
if value.get("success").and_then(Value::as_bool) == Some(true) {
saw_retry_end_success = true;
}
}
_ => {}
}
}
assert!(saw_retry_start, "missing auto_retry_start event");
assert!(
saw_retry_end_success,
"missing successful auto_retry_end event"
);
});
}
#[test]
fn rpc_abort_retry_emits_ordered_retry_timeline() {
let runtime = asupersync::runtime::RuntimeBuilder::new()
.blocking_threads(1, 8)
.build()
.expect("runtime build");
let runtime_handle = runtime.handle();
runtime.block_on(async move {
let provider = Arc::new(AlwaysErrorProvider);
let tools = ToolRegistry::new(&[], Path::new("."), None);
let agent = Agent::new(provider, tools, AgentConfig::default());
let inner_session = Arc::new(Mutex::new(Session::in_memory()));
let agent_session = AgentSession::new(
agent,
inner_session,
false,
crate::compaction::ResolvedCompactionSettings::default(),
);
let session = Arc::new(Mutex::new(agent_session));
let mut config = Config::default();
config.retry = Some(crate::config::RetrySettings {
enabled: Some(true),
max_retries: Some(3),
base_delay_ms: Some(100),
max_delay_ms: Some(100),
});
let mut shared = RpcSharedState::new(&config);
shared.auto_compaction_enabled = false;
let shared_state = Arc::new(Mutex::new(shared));
let is_streaming = Arc::new(AtomicBool::new(false));
let is_compacting = Arc::new(AtomicBool::new(false));
let abort_handle_slot: Arc<Mutex<Option<AbortHandle>>> = Arc::new(Mutex::new(None));
let retry_abort = Arc::new(AtomicBool::new(false));
let (out_tx, out_rx) = std::sync::mpsc::channel::<String>();
let auth_path = tempfile::tempdir()
.expect("tempdir")
.path()
.join("auth.json");
let auth = AuthStorage::load(auth_path).expect("auth load");
let options = RpcOptions {
config,
resources: ResourceLoader::empty(false),
available_models: Vec::new(),
scoped_models: Vec::new(),
auth,
runtime_handle,
};
let retry_abort_for_thread = Arc::clone(&retry_abort);
let abort_thread = std::thread::spawn(move || {
std::thread::sleep(std::time::Duration::from_millis(10));
retry_abort_for_thread.store(true, Ordering::SeqCst);
});
run_prompt_with_retry(
session,
shared_state,
is_streaming,
is_compacting,
abort_handle_slot,
out_tx,
retry_abort,
options,
"hello".to_string(),
Vec::new(),
AgentCx::for_request(),
)
.await;
abort_thread.join().expect("abort thread join");
let mut timeline = Vec::new();
let mut last_agent_end_error = None::<String>;
for line in out_rx.try_iter() {
let Ok(value) = serde_json::from_str::<Value>(&line) else {
continue;
};
let Some(kind) = value.get("type").and_then(Value::as_str) else {
continue;
};
timeline.push(kind.to_string());
if kind == "agent_end" {
last_agent_end_error = value
.get("error")
.and_then(Value::as_str)
.map(str::to_string);
}
}
let retry_start_idx = timeline
.iter()
.position(|kind| kind == "auto_retry_start")
.expect("missing auto_retry_start");
let retry_end_idx = timeline
.iter()
.position(|kind| kind == "auto_retry_end")
.expect("missing auto_retry_end");
let agent_end_idx = timeline
.iter()
.rposition(|kind| kind == "agent_end")
.expect("missing agent_end");
assert!(
retry_start_idx < retry_end_idx && retry_end_idx < agent_end_idx,
"unexpected retry timeline ordering: {timeline:?}"
);
assert_eq!(
last_agent_end_error.as_deref(),
Some("Retry aborted"),
"expected retry-abort terminal error, timeline: {timeline:?}"
);
});
}
}
fn should_auto_compact(tokens_before: u64, context_window: u32, reserve_tokens: u32) -> bool {
let reserve = u64::from(reserve_tokens);
let window = u64::from(context_window);
tokens_before > window.saturating_sub(reserve)
}
#[allow(clippy::too_many_lines)]
async fn maybe_auto_compact(
session: Arc<Mutex<AgentSession>>,
options: RpcOptions,
is_compacting: Arc<AtomicBool>,
out_tx: std::sync::mpsc::Sender<String>,
) {
let cx = AgentCx::for_request();
let (path_entries, context_window, reserve_tokens, settings) = {
let Ok(guard) = session.lock(cx.cx()).await else {
return;
};
let (path_entries, context_window) = {
let Ok(mut inner_session) = guard.session.lock(cx.cx()).await else {
return;
};
inner_session.ensure_entry_ids();
let Some(entry) = current_model_entry(&inner_session, &options) else {
return;
};
let path_entries = inner_session
.entries_for_current_path()
.into_iter()
.cloned()
.collect::<Vec<_>>();
(path_entries, entry.model.context_window)
};
let reserve_tokens = options.config.compaction_reserve_tokens();
let settings = ResolvedCompactionSettings {
enabled: true,
reserve_tokens,
keep_recent_tokens: options.config.compaction_keep_recent_tokens(),
..Default::default()
};
(path_entries, context_window, reserve_tokens, settings)
};
let Some(prep) = prepare_compaction(&path_entries, settings) else {
return;
};
if !should_auto_compact(prep.tokens_before, context_window, reserve_tokens) {
return;
}
let _ = out_tx.send(event(&json!({
"type": "auto_compaction_start",
"reason": "threshold",
})));
is_compacting.store(true, Ordering::SeqCst);
let (provider, key) = {
let Ok(guard) = session.lock(cx.cx()).await else {
is_compacting.store(false, Ordering::SeqCst);
return;
};
let Some(key) = guard.agent.stream_options().api_key.clone() else {
is_compacting.store(false, Ordering::SeqCst);
let _ = out_tx.send(event(&json!({
"type": "auto_compaction_end",
"result": Value::Null,
"aborted": false,
"willRetry": false,
"errorMessage": "Missing API key for compaction",
})));
return;
};
(guard.agent.provider(), key)
};
let result = compact(prep, provider, &key, None).await;
is_compacting.store(false, Ordering::SeqCst);
match result {
Ok(result) => {
let details_value = match compaction_details_to_value(&result.details) {
Ok(value) => value,
Err(err) => {
let _ = out_tx.send(event(&json!({
"type": "auto_compaction_end",
"result": Value::Null,
"aborted": false,
"willRetry": false,
"errorMessage": err.to_string(),
})));
return;
}
};
let Ok(mut guard) = session.lock(cx.cx()).await else {
return;
};
let messages = {
let Ok(mut inner_session) = guard.session.lock(cx.cx()).await else {
return;
};
inner_session.append_compaction(
result.summary.clone(),
result.first_kept_entry_id.clone(),
result.tokens_before,
Some(details_value.clone()),
None,
);
inner_session.to_messages_for_current_path()
};
let _ = guard.persist_session().await;
guard.agent.replace_messages(messages);
drop(guard);
let _ = out_tx.send(event(&json!({
"type": "auto_compaction_end",
"result": {
"summary": result.summary,
"firstKeptEntryId": result.first_kept_entry_id,
"tokensBefore": result.tokens_before,
"details": details_value,
},
"aborted": false,
"willRetry": false,
})));
}
Err(err) => {
let _ = out_tx.send(event(&json!({
"type": "auto_compaction_end",
"result": Value::Null,
"aborted": false,
"willRetry": false,
"errorMessage": err.to_string(),
})));
}
}
}
fn rpc_model_from_entry(entry: &ModelEntry) -> Value {
let input = entry
.model
.input
.iter()
.map(|t| match t {
crate::provider::InputType::Text => "text",
crate::provider::InputType::Image => "image",
})
.collect::<Vec<_>>();
json!({
"id": entry.model.id,
"name": entry.model.name,
"api": entry.model.api,
"provider": entry.model.provider,
"baseUrl": entry.model.base_url,
"reasoning": entry.model.reasoning,
"input": input,
"contextWindow": entry.model.context_window,
"maxTokens": entry.model.max_tokens,
"cost": entry.model.cost,
})
}
fn session_state(
session: &crate::session::Session,
options: &RpcOptions,
snapshot: &RpcStateSnapshot,
is_streaming: bool,
is_compacting: bool,
) -> Value {
let model = session
.header
.provider
.as_deref()
.zip(session.header.model_id.as_deref())
.and_then(|(provider, model_id)| {
options.available_models.iter().find(|m| {
provider_ids_match(&m.model.provider, provider)
&& m.model.id.eq_ignore_ascii_case(model_id)
})
})
.map(rpc_model_from_entry);
let message_count = session
.entries_for_current_path()
.iter()
.filter(|entry| matches!(entry, crate::session::SessionEntry::Message(_)))
.count();
let session_name = session
.entries_for_current_path()
.iter()
.rev()
.find_map(|entry| {
let crate::session::SessionEntry::SessionInfo(info) = entry else {
return None;
};
info.name.clone()
});
let mut state = serde_json::Map::new();
state.insert("model".to_string(), model.unwrap_or(Value::Null));
state.insert(
"thinkingLevel".to_string(),
Value::String(
session
.header
.thinking_level
.clone()
.unwrap_or_else(|| "off".to_string()),
),
);
state.insert("isStreaming".to_string(), Value::Bool(is_streaming));
state.insert("isCompacting".to_string(), Value::Bool(is_compacting));
state.insert(
"steeringMode".to_string(),
Value::String(snapshot.steering_mode.as_str().to_string()),
);
state.insert(
"followUpMode".to_string(),
Value::String(snapshot.follow_up_mode.as_str().to_string()),
);
state.insert(
"sessionFile".to_string(),
session
.path
.as_ref()
.map_or(Value::Null, |p| Value::String(p.display().to_string())),
);
state.insert(
"sessionId".to_string(),
Value::String(session.header.id.clone()),
);
state.insert(
"sessionName".to_string(),
session_name.map_or(Value::Null, Value::String),
);
state.insert(
"autoCompactionEnabled".to_string(),
Value::Bool(snapshot.auto_compaction_enabled),
);
state.insert(
"messageCount".to_string(),
Value::Number(message_count.into()),
);
state.insert(
"pendingMessageCount".to_string(),
Value::Number(snapshot.pending_count().into()),
);
state.insert(
"durabilityMode".to_string(),
Value::String(session.autosave_durability_mode().as_str().to_string()),
);
Value::Object(state)
}
fn session_stats(session: &crate::session::Session) -> Value {
let mut user_messages: u64 = 0;
let mut assistant_messages: u64 = 0;
let mut tool_results: u64 = 0;
let mut tool_calls: u64 = 0;
let mut total_input: u64 = 0;
let mut total_output: u64 = 0;
let mut total_cache_read: u64 = 0;
let mut total_cache_write: u64 = 0;
let mut total_cost: f64 = 0.0;
let messages = session.to_messages_for_current_path();
for message in &messages {
match message {
Message::User(_) | Message::Custom(_) => user_messages += 1,
Message::Assistant(message) => {
assistant_messages += 1;
tool_calls += message
.content
.iter()
.filter(|block| matches!(block, ContentBlock::ToolCall(_)))
.count() as u64;
total_input += message.usage.input;
total_output += message.usage.output;
total_cache_read += message.usage.cache_read;
total_cache_write += message.usage.cache_write;
total_cost += message.usage.cost.total;
}
Message::ToolResult(_) => tool_results += 1,
}
}
let total_messages = messages.len() as u64;
let total_tokens = total_input + total_output + total_cache_read + total_cache_write;
let autosave = session.autosave_metrics();
let pending_message_count = autosave.pending_mutations as u64;
let durability_mode = session.autosave_durability_mode();
let durability_mode_label = match durability_mode {
crate::session::AutosaveDurabilityMode::Strict => "strict",
crate::session::AutosaveDurabilityMode::Balanced => "balanced",
crate::session::AutosaveDurabilityMode::Throughput => "throughput",
};
let (status_event, status_severity, status_summary, status_action, status_sli_ids) =
if pending_message_count == 0 {
(
"session.persistence.healthy",
"ok",
"Persistence queue is clear.",
"No action required.",
vec!["sli_resume_ready_p95_ms"],
)
} else {
let summary = match durability_mode {
crate::session::AutosaveDurabilityMode::Strict => {
"Pending persistence backlog under strict durability mode."
}
crate::session::AutosaveDurabilityMode::Balanced => {
"Pending persistence backlog under balanced durability mode."
}
crate::session::AutosaveDurabilityMode::Throughput => {
"Pending persistence backlog under throughput durability mode."
}
};
let action = match durability_mode {
crate::session::AutosaveDurabilityMode::Throughput => {
"Expect deferred writes; trigger manual save before critical transitions."
}
_ => "Allow autosave flush to complete or trigger manual save before exit.",
};
(
"session.persistence.backlog",
"warning",
summary,
action,
vec![
"sli_resume_ready_p95_ms",
"sli_failure_recovery_success_rate",
],
)
};
let mut data = serde_json::Map::new();
data.insert(
"sessionFile".to_string(),
session
.path
.as_ref()
.map_or(Value::Null, |p| Value::String(p.display().to_string())),
);
data.insert(
"sessionId".to_string(),
Value::String(session.header.id.clone()),
);
data.insert(
"userMessages".to_string(),
Value::Number(user_messages.into()),
);
data.insert(
"assistantMessages".to_string(),
Value::Number(assistant_messages.into()),
);
data.insert("toolCalls".to_string(), Value::Number(tool_calls.into()));
data.insert(
"toolResults".to_string(),
Value::Number(tool_results.into()),
);
data.insert(
"totalMessages".to_string(),
Value::Number(total_messages.into()),
);
data.insert(
"durabilityMode".to_string(),
Value::String(durability_mode_label.to_string()),
);
data.insert(
"pendingMessageCount".to_string(),
Value::Number(pending_message_count.into()),
);
data.insert(
"tokens".to_string(),
json!({
"input": total_input,
"output": total_output,
"cacheRead": total_cache_read,
"cacheWrite": total_cache_write,
"total": total_tokens,
}),
);
data.insert(
"persistenceStatus".to_string(),
json!({
"event": status_event,
"severity": status_severity,
"summary": status_summary,
"action": status_action,
"sliIds": status_sli_ids,
"pendingMessageCount": pending_message_count,
"flushCounters": {
"started": autosave.flush_started,
"succeeded": autosave.flush_succeeded,
"failed": autosave.flush_failed,
},
}),
);
data.insert(
"uxEventMarkers".to_string(),
json!([
{
"event": status_event,
"severity": status_severity,
"durabilityMode": durability_mode_label,
"pendingMessageCount": pending_message_count,
"sliIds": status_sli_ids,
}
]),
);
data.insert("cost".to_string(), Value::from(total_cost));
Value::Object(data)
}
fn last_assistant_text(session: &crate::session::Session) -> Option<String> {
let entries = session.entries_for_current_path();
for entry in entries.into_iter().rev() {
let crate::session::SessionEntry::Message(msg_entry) = entry else {
continue;
};
let SessionMessage::Assistant { message } = &msg_entry.message else {
continue;
};
let mut text = String::new();
for block in &message.content {
if let ContentBlock::Text(t) = block {
text.push_str(&t.text);
}
}
if !text.is_empty() {
return Some(text);
}
}
None
}
async fn export_html_snapshot(
snapshot: &crate::session::ExportSnapshot,
output_path: Option<&str>,
) -> Result<String> {
let html = snapshot.to_html();
let path = output_path.map_or_else(
|| {
snapshot.path.as_ref().map_or_else(
|| {
let ts = chrono::Utc::now().format("%Y-%m-%dT%H-%M-%S%.3fZ");
PathBuf::from(format!("pi-session-{ts}.html"))
},
|session_path| {
let basename = session_path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("session");
PathBuf::from(format!("pi-session-{basename}.html"))
},
)
},
PathBuf::from,
);
if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) {
asupersync::fs::create_dir_all(parent).await?;
}
asupersync::fs::write(&path, html).await?;
Ok(path.display().to_string())
}
#[derive(Debug, Clone)]
struct BashRpcResult {
output: String,
exit_code: i32,
cancelled: bool,
truncated: bool,
full_output_path: Option<String>,
}
const fn line_count_from_newline_count(
total_bytes: usize,
newline_count: usize,
last_byte_was_newline: bool,
) -> usize {
if total_bytes == 0 {
0
} else if last_byte_was_newline {
newline_count
} else {
newline_count.saturating_add(1)
}
}
async fn ingest_bash_rpc_chunk(
bytes: Vec<u8>,
chunks: &mut VecDeque<Vec<u8>>,
chunks_bytes: &mut usize,
total_bytes: &mut usize,
total_lines: &mut usize,
last_byte_was_newline: &mut bool,
temp_file: &mut Option<asupersync::fs::File>,
temp_file_path: &mut Option<PathBuf>,
spill_failed: &mut bool,
max_chunks_bytes: usize,
) {
if bytes.is_empty() {
return;
}
*last_byte_was_newline = bytes.last().is_some_and(|byte| *byte == b'\n');
*total_bytes = total_bytes.saturating_add(bytes.len());
*total_lines = total_lines.saturating_add(memchr_iter(b'\n', &bytes).count());
if *total_bytes > DEFAULT_MAX_BYTES && temp_file.is_none() && !*spill_failed {
let id_full = uuid::Uuid::new_v4().simple().to_string();
let id = &id_full[..16];
let path = std::env::temp_dir().join(format!("pi-rpc-bash-{id}.log"));
let expected_inode: Option<u64> = {
let mut options = std::fs::OpenOptions::new();
options.write(true).create_new(true);
#[cfg(unix)]
{
use std::os::unix::fs::OpenOptionsExt;
options.mode(0o600);
}
match options.open(&path) {
Ok(file) => {
#[cfg(unix)]
{
use std::os::unix::fs::MetadataExt;
file.metadata().ok().map(|m| m.ino())
}
#[cfg(not(unix))]
{
None
}
}
Err(e) => {
tracing::warn!("Failed to create bash temp file: {e}");
None
}
}
};
if expected_inode.is_some() || !cfg!(unix) {
match asupersync::fs::OpenOptions::new()
.append(true)
.open(&path)
.await
{
Ok(mut file) => {
let mut identity_match = true;
#[cfg(unix)]
if let Some(expected) = expected_inode {
use std::os::unix::fs::MetadataExt;
match file.metadata().await {
Ok(meta) => {
if meta.ino() != expected {
tracing::warn!(
"Temp file identity mismatch (possible TOCTOU attack)"
);
identity_match = false;
}
}
Err(e) => {
tracing::warn!("Failed to stat temp file: {e}");
identity_match = false;
}
}
}
if identity_match {
for existing in chunks.iter() {
use asupersync::io::AsyncWriteExt;
if let Err(e) = file.write_all(existing).await {
tracing::warn!("Failed to flush bash chunk to temp file: {e}");
*spill_failed = true;
break;
}
}
if !*spill_failed {
*temp_file = Some(file);
*temp_file_path = Some(path);
}
} else {
let _ = std::fs::remove_file(&path);
*spill_failed = true;
}
}
Err(e) => {
tracing::warn!("Failed to reopen bash temp file async: {e}");
let _ = std::fs::remove_file(&path);
*spill_failed = true;
}
}
} else {
*spill_failed = true;
}
}
if let Some(file) = temp_file.as_mut() {
if *total_bytes <= crate::tools::BASH_FILE_LIMIT_BYTES {
use asupersync::io::AsyncWriteExt;
if let Err(e) = file.write_all(&bytes).await {
tracing::warn!("Failed to write bash chunk to temp file: {e}");
*spill_failed = true;
*temp_file = None;
}
} else {
if !*spill_failed {
tracing::warn!("Bash output exceeded hard limit; stopping file log");
*spill_failed = true;
*temp_file = None;
}
}
}
*chunks_bytes = chunks_bytes.saturating_add(bytes.len());
chunks.push_back(bytes);
while *chunks_bytes > max_chunks_bytes && chunks.len() > 1 {
if let Some(front) = chunks.pop_front() {
*chunks_bytes = chunks_bytes.saturating_sub(front.len());
}
}
}
async fn run_bash_rpc(
cwd: &std::path::Path,
command: &str,
abort_rx: oneshot::Receiver<()>,
) -> Result<BashRpcResult> {
#[derive(Clone, Copy)]
enum StreamKind {
Stdout,
Stderr,
}
struct StreamChunk {
kind: StreamKind,
bytes: Vec<u8>,
}
fn pump_stream(
mut reader: impl std::io::Read,
tx: std::sync::mpsc::SyncSender<StreamChunk>,
kind: StreamKind,
) {
let mut buf = [0u8; 8192];
loop {
let read = match reader.read(&mut buf) {
Ok(0) | Err(_) => break,
Ok(read) => read,
};
let chunk = StreamChunk {
kind,
bytes: buf[..read].to_vec(),
};
if tx.send(chunk).is_err() {
break;
}
}
}
let shell = ["/bin/bash", "/usr/bin/bash", "/usr/local/bin/bash"]
.into_iter()
.find(|p| std::path::Path::new(p).exists())
.unwrap_or("sh");
let command = format!("trap 'code=$?; wait; exit $code' EXIT\n{command}");
let mut child = std::process::Command::new(shell)
.arg("-c")
.arg(&command)
.current_dir(cwd)
.stdin(std::process::Stdio::null())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.map_err(|e| Error::tool("bash", format!("Failed to spawn shell: {e}")))?;
let Some(stdout) = child.stdout.take() else {
return Err(Error::tool("bash", "Missing stdout".to_string()));
};
let Some(stderr) = child.stderr.take() else {
return Err(Error::tool("bash", "Missing stderr".to_string()));
};
let mut guard = crate::tools::ProcessGuard::new(child, true);
let (tx, rx) = std::sync::mpsc::sync_channel::<StreamChunk>(128);
let tx_stdout = tx.clone();
let _stdout_handle =
std::thread::spawn(move || pump_stream(stdout, tx_stdout, StreamKind::Stdout));
let _stderr_handle = std::thread::spawn(move || pump_stream(stderr, tx, StreamKind::Stderr));
let tick = Duration::from_millis(10);
let mut chunks: VecDeque<Vec<u8>> = VecDeque::new();
let mut chunks_bytes = 0usize;
let mut total_bytes = 0usize;
let mut total_lines = 0usize;
let mut last_byte_was_newline = false;
let mut temp_file: Option<asupersync::fs::File> = None;
let mut temp_file_path: Option<PathBuf> = None;
let max_chunks_bytes = DEFAULT_MAX_BYTES * 2;
let mut cancelled = false;
let mut spill_failed = false;
let exit_code = loop {
while let Ok(chunk) = rx.try_recv() {
ingest_bash_rpc_chunk(
chunk.bytes,
&mut chunks,
&mut chunks_bytes,
&mut total_bytes,
&mut total_lines,
&mut last_byte_was_newline,
&mut temp_file,
&mut temp_file_path,
&mut spill_failed,
max_chunks_bytes,
)
.await;
}
if !cancelled && abort_rx.try_recv().is_ok() {
cancelled = true;
if let Ok(Some(status)) = guard.kill() {
break status.code().unwrap_or(-1);
}
}
match guard.try_wait_child() {
Ok(Some(status)) => break status.code().unwrap_or(-1),
Ok(None) => {}
Err(err) => {
return Err(Error::tool(
"bash",
format!("Failed to wait for process: {err}"),
));
}
}
sleep(wall_now(), tick).await;
};
let drain_deadline = Instant::now() + Duration::from_secs(2);
let mut drain_timed_out = false;
loop {
match rx.try_recv() {
Ok(chunk) => {
ingest_bash_rpc_chunk(
chunk.bytes,
&mut chunks,
&mut chunks_bytes,
&mut total_bytes,
&mut total_lines,
&mut last_byte_was_newline,
&mut temp_file,
&mut temp_file_path,
&mut spill_failed,
max_chunks_bytes,
)
.await;
}
Err(std::sync::mpsc::TryRecvError::Empty) => {
if Instant::now() >= drain_deadline {
drain_timed_out = true;
break;
}
sleep(wall_now(), tick).await;
}
Err(std::sync::mpsc::TryRecvError::Disconnected) => break,
}
}
drop(rx);
drop(temp_file);
let mut combined = Vec::with_capacity(chunks_bytes);
for chunk in chunks {
combined.extend_from_slice(&chunk);
}
let tail_output = String::from_utf8_lossy(&combined).to_string();
let mut truncation = truncate_tail(tail_output, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES);
if total_bytes > chunks_bytes {
truncation.truncated = true;
truncation.truncated_by = Some(crate::tools::TruncatedBy::Bytes);
truncation.total_bytes = total_bytes;
truncation.total_lines =
line_count_from_newline_count(total_bytes, total_lines, last_byte_was_newline);
} else if drain_timed_out {
truncation.truncated = true;
truncation.truncated_by = Some(crate::tools::TruncatedBy::Bytes);
}
let will_truncate = truncation.truncated;
let mut output_text = if truncation.content.is_empty() {
"(no output)".to_string()
} else {
truncation.content
};
if drain_timed_out {
output_text.push_str("\n... [Output truncated: drain timeout]");
}
Ok(BashRpcResult {
output: output_text,
exit_code,
cancelled,
truncated: will_truncate,
full_output_path: temp_file_path.map(|p| p.display().to_string()),
})
}
fn parse_prompt_images(value: Option<&Value>) -> Result<Vec<ImageContent>> {
let Some(value) = value else {
return Ok(Vec::new());
};
let Some(arr) = value.as_array() else {
return Err(Error::validation("images must be an array"));
};
let mut images = Vec::new();
for item in arr {
let Some(obj) = item.as_object() else {
continue;
};
let item_type = obj.get("type").and_then(Value::as_str).unwrap_or("");
if item_type != "image" {
continue;
}
let Some(source) = obj.get("source").and_then(Value::as_object) else {
continue;
};
let source_type = source.get("type").and_then(Value::as_str).unwrap_or("");
if source_type != "base64" {
continue;
}
let Some(media_type) = source.get("mediaType").and_then(Value::as_str) else {
continue;
};
let Some(data) = source.get("data").and_then(Value::as_str) else {
continue;
};
images.push(ImageContent {
data: data.to_string(),
mime_type: media_type.to_string(),
});
}
Ok(images)
}
fn resolve_model_key(auth: &AuthStorage, entry: &ModelEntry) -> Option<String> {
normalize_api_key_opt(auth.resolve_api_key(&entry.model.provider, None))
.or_else(|| normalize_api_key_opt(entry.api_key.clone()))
}
fn normalize_api_key_opt(api_key: Option<String>) -> Option<String> {
api_key.and_then(|key| {
let trimmed = key.trim();
(!trimmed.is_empty()).then(|| trimmed.to_string())
})
}
fn model_requires_configured_credential(entry: &ModelEntry) -> bool {
let provider = entry.model.provider.as_str();
entry.auth_header
|| provider_metadata(provider).is_some_and(|meta| !meta.auth_env_keys.is_empty())
|| entry.oauth_config.is_some()
}
fn parse_thinking_level(level: &str) -> Result<crate::model::ThinkingLevel> {
level.parse().map_err(|err: String| Error::validation(err))
}
fn current_model_entry<'a>(
session: &crate::session::Session,
options: &'a RpcOptions,
) -> Option<&'a ModelEntry> {
let provider = session.header.provider.as_deref()?;
let model_id = session.header.model_id.as_deref()?;
options.available_models.iter().find(|m| {
provider_ids_match(&m.model.provider, provider) && m.model.id.eq_ignore_ascii_case(model_id)
})
}
async fn apply_thinking_level(
guard: &mut AgentSession,
level: crate::model::ThinkingLevel,
) -> Result<()> {
let cx = AgentCx::for_request();
{
let mut inner_session = guard
.session
.lock(cx.cx())
.await
.map_err(|err| Error::session(format!("inner session lock failed: {err}")))?;
inner_session.header.thinking_level = Some(level.to_string());
inner_session.append_thinking_level_change(level.to_string());
}
guard.agent.stream_options_mut().thinking_level = Some(level);
guard.persist_session().await
}
async fn apply_model_change(guard: &mut AgentSession, entry: &ModelEntry) -> Result<()> {
let cx = AgentCx::for_request();
{
let mut inner_session = guard
.session
.lock(cx.cx())
.await
.map_err(|err| Error::session(format!("inner session lock failed: {err}")))?;
inner_session.header.provider = Some(entry.model.provider.clone());
inner_session.header.model_id = Some(entry.model.id.clone());
inner_session.append_model_change(entry.model.provider.clone(), entry.model.id.clone());
}
guard.persist_session().await
}
fn fork_messages_from_entries(entries: &[crate::session::SessionEntry]) -> Vec<Value> {
let mut result = Vec::new();
for entry in entries {
let crate::session::SessionEntry::Message(m) = entry else {
continue;
};
let SessionMessage::User { content, .. } = &m.message else {
continue;
};
let entry_id = m.base.id.clone().unwrap_or_default();
let text = extract_user_text(content);
result.push(json!({
"entryId": entry_id,
"text": text,
}));
}
result
}
fn extract_user_text(content: &crate::model::UserContent) -> Option<String> {
match content {
crate::model::UserContent::Text(text) => Some(text.clone()),
crate::model::UserContent::Blocks(blocks) => blocks.iter().find_map(|b| {
if let ContentBlock::Text(t) = b {
Some(t.text.clone())
} else {
None
}
}),
}
}
fn available_thinking_levels(entry: &ModelEntry) -> Vec<crate::model::ThinkingLevel> {
use crate::model::ThinkingLevel;
if entry.model.reasoning {
let mut levels = vec![
ThinkingLevel::Off,
ThinkingLevel::Minimal,
ThinkingLevel::Low,
ThinkingLevel::Medium,
ThinkingLevel::High,
];
if entry.supports_xhigh() {
levels.push(ThinkingLevel::XHigh);
}
levels
} else {
vec![ThinkingLevel::Off]
}
}
async fn cycle_model_for_rpc(
guard: &mut AgentSession,
options: &RpcOptions,
) -> Result<Option<(ModelEntry, crate::model::ThinkingLevel, bool)>> {
let (candidates, is_scoped) = if options.scoped_models.is_empty() {
(options.available_models.clone(), false)
} else {
(
options
.scoped_models
.iter()
.map(|sm| sm.model.clone())
.collect::<Vec<_>>(),
true,
)
};
if candidates.len() <= 1 {
return Ok(None);
}
let cx = AgentCx::for_request();
let (current_provider, current_model_id) = {
let inner_session = guard
.session
.lock(cx.cx())
.await
.map_err(|err| Error::session(format!("inner session lock failed: {err}")))?;
(
inner_session.header.provider.clone(),
inner_session.header.model_id.clone(),
)
};
let current_index = candidates.iter().position(|entry| {
current_provider
.as_deref()
.is_some_and(|provider| provider_ids_match(provider, &entry.model.provider))
&& current_model_id
.as_deref()
.is_some_and(|model_id| model_id.eq_ignore_ascii_case(&entry.model.id))
});
let next_index = current_index.map_or(0, |idx| (idx + 1) % candidates.len());
let next_entry = candidates[next_index].clone();
let provider_impl = crate::providers::create_provider(
&next_entry,
guard
.extensions
.as_ref()
.map(crate::extensions::ExtensionRegion::manager),
)?;
guard.agent.set_provider(provider_impl);
let key = resolve_model_key(&options.auth, &next_entry);
if model_requires_configured_credential(&next_entry) && key.is_none() {
return Err(Error::auth(format!(
"Missing credentials for {}/{}",
next_entry.model.provider, next_entry.model.id
)));
}
guard.agent.stream_options_mut().api_key.clone_from(&key);
guard
.agent
.stream_options_mut()
.headers
.clone_from(&next_entry.headers);
apply_model_change(guard, &next_entry).await?;
let desired_thinking = if is_scoped {
options.scoped_models[next_index]
.thinking_level
.unwrap_or(crate::model::ThinkingLevel::Off)
} else {
guard
.agent
.stream_options()
.thinking_level
.unwrap_or_default()
};
let next_thinking = next_entry.clamp_thinking_level(desired_thinking);
apply_thinking_level(guard, next_thinking).await?;
Ok(Some((next_entry, next_thinking, is_scoped)))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::AuthCredential;
use crate::model::{
ContentBlock, ImageContent, TextContent, ThinkingLevel, UserContent, UserMessage,
};
use crate::provider::{InputType, Model, ModelCost};
use crate::session::Session;
use serde_json::json;
use std::collections::HashMap;
fn dummy_model(id: &str, reasoning: bool) -> Model {
Model {
id: id.to_string(),
name: id.to_string(),
api: "anthropic".to_string(),
provider: "anthropic".to_string(),
base_url: "https://api.anthropic.com".to_string(),
reasoning,
input: vec![InputType::Text],
cost: ModelCost {
input: 3.0,
output: 15.0,
cache_read: 0.3,
cache_write: 3.75,
},
context_window: 200_000,
max_tokens: 8192,
headers: HashMap::new(),
}
}
fn dummy_entry(id: &str, reasoning: bool) -> ModelEntry {
ModelEntry {
model: dummy_model(id, reasoning),
api_key: None,
headers: HashMap::new(),
auth_header: false,
compat: None,
oauth_config: None,
}
}
fn rpc_options_with_models(available_models: Vec<ModelEntry>) -> RpcOptions {
let runtime = asupersync::runtime::RuntimeBuilder::new()
.blocking_threads(1, 1)
.build()
.expect("runtime build");
let runtime_handle = runtime.handle();
let auth_path = tempfile::tempdir()
.expect("tempdir")
.path()
.join("auth.json");
let auth = AuthStorage::load(auth_path).expect("auth load");
RpcOptions {
config: Config::default(),
resources: ResourceLoader::empty(false),
available_models,
scoped_models: Vec::new(),
auth,
runtime_handle,
}
}
#[test]
fn line_count_from_newline_count_matches_trailing_newline_semantics() {
assert_eq!(line_count_from_newline_count(0, 0, false), 0);
assert_eq!(line_count_from_newline_count(2, 1, true), 1);
assert_eq!(line_count_from_newline_count(1, 0, false), 1);
assert_eq!(line_count_from_newline_count(3, 1, false), 2);
}
#[test]
fn parse_queue_mode_all() {
assert_eq!(parse_queue_mode(Some("all")), Some(QueueMode::All));
}
#[test]
fn parse_queue_mode_one_at_a_time() {
assert_eq!(
parse_queue_mode(Some("one-at-a-time")),
Some(QueueMode::OneAtATime)
);
}
#[test]
fn parse_queue_mode_none_value() {
assert_eq!(parse_queue_mode(None), None);
}
#[test]
fn parse_queue_mode_unknown_returns_none() {
assert_eq!(parse_queue_mode(Some("batch")), None);
assert_eq!(parse_queue_mode(Some("")), None);
}
#[test]
fn parse_queue_mode_trims_whitespace() {
assert_eq!(parse_queue_mode(Some(" all ")), Some(QueueMode::All));
}
#[test]
fn provider_ids_match_accepts_aliases() {
assert!(provider_ids_match("openrouter", "open-router"));
assert!(provider_ids_match("google-gemini-cli", "gemini-cli"));
assert!(!provider_ids_match("openai", "anthropic"));
}
#[test]
fn resolve_model_key_prefers_stored_auth_key_over_inline_entry_key() {
let mut entry = dummy_entry("gpt-4o-mini", true);
entry.model.provider = "openai".to_string();
entry.auth_header = true;
entry.api_key = Some("inline-model-key".to_string());
let auth_path = tempfile::tempdir()
.expect("tempdir")
.path()
.join("auth.json");
let mut auth = AuthStorage::load(auth_path).expect("auth load");
auth.set(
"openai".to_string(),
AuthCredential::ApiKey {
key: "stored-auth-key".to_string(),
},
);
assert_eq!(
resolve_model_key(&auth, &entry).as_deref(),
Some("stored-auth-key")
);
}
#[test]
fn resolve_model_key_ignores_blank_inline_key_and_falls_back_to_auth_storage() {
let mut entry = dummy_entry("gpt-4o-mini", true);
entry.model.provider = "openai".to_string();
entry.auth_header = true;
entry.api_key = Some(" ".to_string());
let auth_path = tempfile::tempdir()
.expect("tempdir")
.path()
.join("auth.json");
let mut auth = AuthStorage::load(auth_path).expect("auth load");
auth.set(
"openai".to_string(),
AuthCredential::ApiKey {
key: "stored-auth-key".to_string(),
},
);
assert_eq!(
resolve_model_key(&auth, &entry).as_deref(),
Some("stored-auth-key")
);
}
#[test]
fn unknown_keyless_model_does_not_require_credentials() {
let mut entry = dummy_entry("dev-model", false);
entry.model.provider = "acme-local".to_string();
entry.auth_header = false;
entry.oauth_config = None;
assert!(!model_requires_configured_credential(&entry));
}
#[test]
fn anthropic_model_requires_credentials_even_without_auth_header() {
let mut entry = dummy_entry("claude-sonnet-4-6", true);
entry.model.provider = "anthropic".to_string();
entry.auth_header = false;
entry.oauth_config = None;
assert!(model_requires_configured_credential(&entry));
}
#[test]
fn parse_streaming_behavior_steer() {
let val = json!("steer");
let result = parse_streaming_behavior(Some(&val)).unwrap();
assert_eq!(result, Some(StreamingBehavior::Steer));
}
#[test]
fn parse_streaming_behavior_follow_up_hyphenated() {
let val = json!("follow-up");
let result = parse_streaming_behavior(Some(&val)).unwrap();
assert_eq!(result, Some(StreamingBehavior::FollowUp));
}
#[test]
fn parse_streaming_behavior_follow_up_camel() {
let val = json!("followUp");
let result = parse_streaming_behavior(Some(&val)).unwrap();
assert_eq!(result, Some(StreamingBehavior::FollowUp));
}
#[test]
fn parse_streaming_behavior_none() {
let result = parse_streaming_behavior(None).unwrap();
assert_eq!(result, None);
}
#[test]
fn parse_streaming_behavior_invalid_string() {
let val = json!("invalid");
assert!(parse_streaming_behavior(Some(&val)).is_err());
}
#[test]
fn parse_streaming_behavior_non_string_errors() {
let val = json!(42);
assert!(parse_streaming_behavior(Some(&val)).is_err());
}
#[test]
fn normalize_command_type_passthrough() {
assert_eq!(normalize_command_type("prompt"), "prompt");
assert_eq!(normalize_command_type("compact"), "compact");
}
#[test]
fn normalize_command_type_follow_up_aliases() {
assert_eq!(normalize_command_type("follow-up"), "follow_up");
assert_eq!(normalize_command_type("followUp"), "follow_up");
assert_eq!(normalize_command_type("queue-follow-up"), "follow_up");
assert_eq!(normalize_command_type("queueFollowUp"), "follow_up");
}
#[test]
fn normalize_command_type_kebab_and_camel_aliases() {
assert_eq!(normalize_command_type("get-state"), "get_state");
assert_eq!(normalize_command_type("getState"), "get_state");
assert_eq!(normalize_command_type("set-model"), "set_model");
assert_eq!(normalize_command_type("setModel"), "set_model");
assert_eq!(
normalize_command_type("set-steering-mode"),
"set_steering_mode"
);
assert_eq!(
normalize_command_type("setSteeringMode"),
"set_steering_mode"
);
assert_eq!(
normalize_command_type("set-follow-up-mode"),
"set_follow_up_mode"
);
assert_eq!(
normalize_command_type("setFollowUpMode"),
"set_follow_up_mode"
);
assert_eq!(
normalize_command_type("set-auto-compaction"),
"set_auto_compaction"
);
assert_eq!(
normalize_command_type("setAutoCompaction"),
"set_auto_compaction"
);
assert_eq!(normalize_command_type("set-auto-retry"), "set_auto_retry");
assert_eq!(normalize_command_type("setAutoRetry"), "set_auto_retry");
}
#[test]
fn build_user_message_text_only() {
let msg = build_user_message("hello", &[]);
match msg {
Message::User(UserMessage {
content: UserContent::Text(text),
..
}) => assert_eq!(text, "hello"),
other => panic!("expected text user message, got {other:?}"),
}
}
#[test]
fn build_user_message_with_images() {
let images = vec![ImageContent {
data: "base64data".to_string(),
mime_type: "image/png".to_string(),
}];
let msg = build_user_message("look at this", &images);
match msg {
Message::User(UserMessage {
content: UserContent::Blocks(blocks),
..
}) => {
assert_eq!(blocks.len(), 2);
assert!(matches!(&blocks[0], ContentBlock::Text(_)));
assert!(matches!(&blocks[1], ContentBlock::Image(_)));
}
other => panic!("expected blocks user message, got {other:?}"),
}
}
#[test]
fn is_extension_command_slash_unchanged() {
assert!(is_extension_command("/mycommand", "/mycommand"));
}
#[test]
fn is_extension_command_expanded_returns_false() {
assert!(!is_extension_command(
"/prompt-name",
"This is the expanded prompt text."
));
}
#[test]
fn is_extension_command_no_slash() {
assert!(!is_extension_command("hello", "hello"));
}
#[test]
fn is_extension_command_leading_whitespace() {
assert!(is_extension_command(" /cmd", " /cmd"));
}
#[test]
fn try_send_line_with_backpressure_enqueues_when_capacity_available() {
let (tx, _rx) = mpsc::channel::<String>(1);
assert!(try_send_line_with_backpressure(&tx, "line".to_string()));
assert!(matches!(
tx.try_send("next".to_string()),
Err(mpsc::SendError::Full(_))
));
}
#[test]
fn try_send_line_with_backpressure_stops_when_receiver_closed() {
let (tx, rx) = mpsc::channel::<String>(1);
drop(rx);
assert!(!try_send_line_with_backpressure(&tx, "line".to_string()));
}
#[test]
fn try_send_line_with_backpressure_waits_until_capacity_is_available() {
let (tx, rx) = mpsc::channel::<String>(1);
tx.try_send("occupied".to_string())
.expect("seed initial occupied slot");
let expected = "delayed-line".to_string();
let expected_for_thread = expected.clone();
let recv_handle = std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(30));
let deadline = Instant::now() + Duration::from_millis(300);
let mut received = Vec::new();
while received.len() < 2 && Instant::now() < deadline {
if let Ok(msg) = rx.try_recv() {
received.push(msg);
} else {
std::thread::sleep(Duration::from_millis(5));
}
}
assert_eq!(received.len(), 2, "should receive both queued lines");
let first = received.remove(0);
let second = received.remove(0);
assert_eq!(first, "occupied");
assert_eq!(second, expected_for_thread);
});
assert!(try_send_line_with_backpressure(&tx, expected));
drop(tx);
recv_handle.join().expect("receiver thread should finish");
}
#[test]
fn try_send_line_with_backpressure_preserves_large_payload() {
let (tx, rx) = mpsc::channel::<String>(1);
tx.try_send("busy".to_string())
.expect("seed initial busy slot");
let large = "x".repeat(256 * 1024);
let large_for_thread = large.clone();
let recv_handle = std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(30));
let deadline = Instant::now() + Duration::from_millis(500);
let mut received = Vec::new();
while received.len() < 2 && Instant::now() < deadline {
if let Ok(msg) = rx.try_recv() {
received.push(msg);
} else {
std::thread::sleep(Duration::from_millis(5));
}
}
assert_eq!(received.len(), 2, "should receive busy + payload lines");
let payload = received.remove(1);
assert_eq!(payload.len(), large_for_thread.len());
assert_eq!(payload, large_for_thread);
});
assert!(try_send_line_with_backpressure(&tx, large));
drop(tx);
recv_handle.join().expect("receiver thread should finish");
}
#[test]
fn try_send_line_with_backpressure_detects_disconnect_while_waiting() {
let (tx, rx) = mpsc::channel::<String>(1);
tx.try_send("busy".to_string())
.expect("seed initial busy slot");
let drop_handle = std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(30));
drop(rx);
});
assert!(
!try_send_line_with_backpressure(&tx, "line-after-disconnect".to_string()),
"send should stop after receiver disconnects while channel is full"
);
drop_handle.join().expect("drop thread should finish");
}
#[test]
fn try_send_line_with_backpressure_high_volume_preserves_order_and_count() {
let (tx, rx) = mpsc::channel::<String>(4);
let lines: Vec<String> = (0..256)
.map(|idx| format!("line-{idx:03}: {}", "x".repeat(64)))
.collect();
let expected = lines.clone();
let recv_handle = std::thread::spawn(move || {
let deadline = Instant::now() + Duration::from_secs(4);
let mut received = Vec::new();
while received.len() < expected.len() && Instant::now() < deadline {
if let Ok(msg) = rx.try_recv() {
received.push(msg);
}
std::thread::sleep(Duration::from_millis(1));
}
assert_eq!(
received.len(),
expected.len(),
"should receive every line under sustained backpressure"
);
assert_eq!(received, expected, "line ordering must remain stable");
});
for line in lines {
assert!(try_send_line_with_backpressure(&tx, line));
}
drop(tx);
recv_handle.join().expect("receiver thread should finish");
}
#[test]
fn try_send_line_with_backpressure_preserves_partial_line_without_newline() {
let (tx, rx) = mpsc::channel::<String>(1);
tx.try_send("busy".to_string())
.expect("seed initial busy slot");
let partial_json = "{\"type\":\"prompt\",\"message\":\"tail-fragment-ascii\"".to_string();
let expected = partial_json.clone();
let recv_handle = std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(25));
let first = rx.try_recv().expect("seeded line should be available");
assert_eq!(first, "busy");
let deadline = Instant::now() + Duration::from_millis(500);
let second = loop {
if let Ok(line) = rx.try_recv() {
break line;
}
assert!(
Instant::now() < deadline,
"partial payload should be available"
);
std::thread::sleep(Duration::from_millis(5));
};
assert_eq!(second, expected);
});
assert!(try_send_line_with_backpressure(&tx, partial_json));
drop(tx);
recv_handle.join().expect("receiver thread should finish");
}
#[test]
fn snapshot_pending_count() {
let snapshot = RpcStateSnapshot {
steering_count: 3,
follow_up_count: 7,
steering_mode: QueueMode::All,
follow_up_mode: QueueMode::OneAtATime,
auto_compaction_enabled: false,
auto_retry_enabled: true,
};
assert_eq!(snapshot.pending_count(), 10);
}
#[test]
fn snapshot_pending_count_zero() {
let snapshot = RpcStateSnapshot {
steering_count: 0,
follow_up_count: 0,
steering_mode: QueueMode::All,
follow_up_mode: QueueMode::All,
auto_compaction_enabled: false,
auto_retry_enabled: false,
};
assert_eq!(snapshot.pending_count(), 0);
}
#[test]
fn retry_delay_first_attempt_is_base() {
let config = Config::default();
assert_eq!(retry_delay_ms(&config, 0), config.retry_base_delay_ms());
assert_eq!(retry_delay_ms(&config, 1), config.retry_base_delay_ms());
}
#[test]
fn retry_delay_doubles_each_attempt() {
let config = Config::default();
let base = config.retry_base_delay_ms();
assert_eq!(retry_delay_ms(&config, 2), base * 2);
assert_eq!(retry_delay_ms(&config, 3), base * 4);
}
#[test]
fn retry_delay_capped_at_max() {
let config = Config::default();
let max = config.retry_max_delay_ms();
let delay = retry_delay_ms(&config, 30);
assert_eq!(delay, max);
}
#[test]
fn retry_delay_saturates_on_overflow() {
let config = Config::default();
let delay = retry_delay_ms(&config, u32::MAX);
assert!(delay <= config.retry_max_delay_ms());
}
#[test]
fn auto_compact_below_threshold() {
assert!(!should_auto_compact(50_000, 200_000, 40_000));
}
#[test]
fn auto_compact_above_threshold() {
assert!(should_auto_compact(170_000, 200_000, 40_000));
}
#[test]
fn auto_compact_exact_threshold() {
assert!(!should_auto_compact(160_000, 200_000, 40_000));
}
#[test]
fn auto_compact_reserve_exceeds_window() {
assert!(should_auto_compact(1, 100, 200));
}
#[test]
fn auto_compact_zero_tokens() {
assert!(!should_auto_compact(0, 200_000, 40_000));
}
#[test]
fn flatten_content_blocks_unwraps_inner_0() {
let mut value = json!({
"content": [
{"0": {"type": "text", "text": "hello"}}
]
});
rpc_flatten_content_blocks(&mut value);
let blocks = value["content"].as_array().unwrap();
assert_eq!(blocks[0]["type"], "text");
assert_eq!(blocks[0]["text"], "hello");
assert!(blocks[0].get("0").is_none());
}
#[test]
fn flatten_content_blocks_preserves_non_wrapped() {
let mut value = json!({
"content": [
{"type": "text", "text": "already flat"}
]
});
rpc_flatten_content_blocks(&mut value);
let blocks = value["content"].as_array().unwrap();
assert_eq!(blocks[0]["type"], "text");
assert_eq!(blocks[0]["text"], "already flat");
}
#[test]
fn flatten_content_blocks_no_content_field() {
let mut value = json!({"role": "assistant"});
rpc_flatten_content_blocks(&mut value); assert_eq!(value, json!({"role": "assistant"}));
}
#[test]
fn flatten_content_blocks_non_object() {
let mut value = json!("just a string");
rpc_flatten_content_blocks(&mut value); }
#[test]
fn flatten_content_blocks_existing_keys_not_overwritten() {
let mut value = json!({
"content": [
{"type": "existing", "0": {"type": "inner", "extra": "data"}}
]
});
rpc_flatten_content_blocks(&mut value);
let blocks = value["content"].as_array().unwrap();
assert_eq!(blocks[0]["type"], "existing");
assert_eq!(blocks[0]["extra"], "data");
}
#[test]
fn parse_prompt_images_none() {
let images = parse_prompt_images(None).unwrap();
assert!(images.is_empty());
}
#[test]
fn parse_prompt_images_empty_array() {
let val = json!([]);
let images = parse_prompt_images(Some(&val)).unwrap();
assert!(images.is_empty());
}
#[test]
fn parse_prompt_images_valid() {
let val = json!([{
"type": "image",
"source": {
"type": "base64",
"mediaType": "image/png",
"data": "iVBORw0KGgo="
}
}]);
let images = parse_prompt_images(Some(&val)).unwrap();
assert_eq!(images.len(), 1);
assert_eq!(images[0].mime_type, "image/png");
assert_eq!(images[0].data, "iVBORw0KGgo=");
}
#[test]
fn parse_prompt_images_skips_non_image_type() {
let val = json!([{
"type": "text",
"text": "hello"
}]);
let images = parse_prompt_images(Some(&val)).unwrap();
assert!(images.is_empty());
}
#[test]
fn parse_prompt_images_skips_non_base64_source() {
let val = json!([{
"type": "image",
"source": {
"type": "url",
"url": "https://example.com/img.png"
}
}]);
let images = parse_prompt_images(Some(&val)).unwrap();
assert!(images.is_empty());
}
#[test]
fn parse_prompt_images_not_array_errors() {
let val = json!("not-an-array");
assert!(parse_prompt_images(Some(&val)).is_err());
}
#[test]
fn parse_prompt_images_multiple_valid() {
let val = json!([
{
"type": "image",
"source": {"type": "base64", "mediaType": "image/jpeg", "data": "abc"}
},
{
"type": "image",
"source": {"type": "base64", "mediaType": "image/webp", "data": "def"}
}
]);
let images = parse_prompt_images(Some(&val)).unwrap();
assert_eq!(images.len(), 2);
assert_eq!(images[0].mime_type, "image/jpeg");
assert_eq!(images[1].mime_type, "image/webp");
}
#[test]
fn extract_user_text_from_text_content() {
let content = UserContent::Text("hello world".to_string());
assert_eq!(extract_user_text(&content), Some("hello world".to_string()));
}
#[test]
fn extract_user_text_from_blocks() {
let content = UserContent::Blocks(vec![
ContentBlock::Image(ImageContent {
data: String::new(),
mime_type: "image/png".to_string(),
}),
ContentBlock::Text(TextContent::new("found it")),
]);
assert_eq!(extract_user_text(&content), Some("found it".to_string()));
}
#[test]
fn extract_user_text_blocks_no_text() {
let content = UserContent::Blocks(vec![ContentBlock::Image(ImageContent {
data: String::new(),
mime_type: "image/png".to_string(),
})]);
assert_eq!(extract_user_text(&content), None);
}
#[test]
fn parse_thinking_level_all_variants() {
assert_eq!(parse_thinking_level("off").unwrap(), ThinkingLevel::Off);
assert_eq!(parse_thinking_level("none").unwrap(), ThinkingLevel::Off);
assert_eq!(parse_thinking_level("0").unwrap(), ThinkingLevel::Off);
assert_eq!(
parse_thinking_level("minimal").unwrap(),
ThinkingLevel::Minimal
);
assert_eq!(parse_thinking_level("min").unwrap(), ThinkingLevel::Minimal);
assert_eq!(parse_thinking_level("low").unwrap(), ThinkingLevel::Low);
assert_eq!(parse_thinking_level("1").unwrap(), ThinkingLevel::Low);
assert_eq!(
parse_thinking_level("medium").unwrap(),
ThinkingLevel::Medium
);
assert_eq!(parse_thinking_level("med").unwrap(), ThinkingLevel::Medium);
assert_eq!(parse_thinking_level("2").unwrap(), ThinkingLevel::Medium);
assert_eq!(parse_thinking_level("high").unwrap(), ThinkingLevel::High);
assert_eq!(parse_thinking_level("3").unwrap(), ThinkingLevel::High);
assert_eq!(parse_thinking_level("xhigh").unwrap(), ThinkingLevel::XHigh);
assert_eq!(parse_thinking_level("4").unwrap(), ThinkingLevel::XHigh);
}
#[test]
fn parse_thinking_level_case_insensitive() {
assert_eq!(parse_thinking_level("HIGH").unwrap(), ThinkingLevel::High);
assert_eq!(
parse_thinking_level("Medium").unwrap(),
ThinkingLevel::Medium
);
assert_eq!(parse_thinking_level(" Off ").unwrap(), ThinkingLevel::Off);
}
#[test]
fn parse_thinking_level_invalid() {
assert!(parse_thinking_level("invalid").is_err());
assert!(parse_thinking_level("").is_err());
assert!(parse_thinking_level("5").is_err());
}
#[test]
fn supports_xhigh_known_models() {
assert!(dummy_entry("gpt-5.1-codex-max", true).supports_xhigh());
assert!(dummy_entry("gpt-5.2", true).supports_xhigh());
assert!(dummy_entry("gpt-5.2-codex", true).supports_xhigh());
assert!(dummy_entry("gpt-5.3-codex", true).supports_xhigh());
}
#[test]
fn supports_xhigh_unknown_models() {
assert!(!dummy_entry("claude-opus-4-6", true).supports_xhigh());
assert!(!dummy_entry("gpt-4o", true).supports_xhigh());
assert!(!dummy_entry("", true).supports_xhigh());
}
#[test]
fn clamp_thinking_non_reasoning_model() {
let entry = dummy_entry("claude-3-haiku", false);
assert_eq!(
entry.clamp_thinking_level(ThinkingLevel::High),
ThinkingLevel::Off
);
}
#[test]
fn clamp_thinking_xhigh_without_support() {
let entry = dummy_entry("claude-opus-4-6", true);
assert_eq!(
entry.clamp_thinking_level(ThinkingLevel::XHigh),
ThinkingLevel::High
);
}
#[test]
fn clamp_thinking_xhigh_with_support() {
let entry = dummy_entry("gpt-5.2", true);
assert_eq!(
entry.clamp_thinking_level(ThinkingLevel::XHigh),
ThinkingLevel::XHigh
);
}
#[test]
fn clamp_thinking_normal_level_passthrough() {
let entry = dummy_entry("claude-opus-4-6", true);
assert_eq!(
entry.clamp_thinking_level(ThinkingLevel::Medium),
ThinkingLevel::Medium
);
}
#[test]
fn available_thinking_levels_non_reasoning() {
let entry = dummy_entry("gpt-4o-mini", false);
let levels = available_thinking_levels(&entry);
assert_eq!(levels, vec![ThinkingLevel::Off]);
}
#[test]
fn available_thinking_levels_reasoning_no_xhigh() {
let entry = dummy_entry("claude-opus-4-6", true);
let levels = available_thinking_levels(&entry);
assert_eq!(
levels,
vec![
ThinkingLevel::Off,
ThinkingLevel::Minimal,
ThinkingLevel::Low,
ThinkingLevel::Medium,
ThinkingLevel::High,
]
);
}
#[test]
fn available_thinking_levels_reasoning_with_xhigh() {
let entry = dummy_entry("gpt-5.2", true);
let levels = available_thinking_levels(&entry);
assert_eq!(
levels,
vec![
ThinkingLevel::Off,
ThinkingLevel::Minimal,
ThinkingLevel::Low,
ThinkingLevel::Medium,
ThinkingLevel::High,
ThinkingLevel::XHigh,
]
);
}
#[test]
fn rpc_model_from_entry_basic() {
let entry = dummy_entry("claude-opus-4-6", true);
let value = rpc_model_from_entry(&entry);
assert_eq!(value["id"], "claude-opus-4-6");
assert_eq!(value["name"], "claude-opus-4-6");
assert_eq!(value["provider"], "anthropic");
assert_eq!(value["reasoning"], true);
assert_eq!(value["contextWindow"], 200_000);
assert_eq!(value["maxTokens"], 8192);
}
#[test]
fn rpc_model_from_entry_input_types() {
let mut entry = dummy_entry("gpt-4o", false);
entry.model.input = vec![InputType::Text, InputType::Image];
let value = rpc_model_from_entry(&entry);
let input = value["input"].as_array().unwrap();
assert_eq!(input.len(), 2);
assert_eq!(input[0], "text");
assert_eq!(input[1], "image");
}
#[test]
fn rpc_model_from_entry_cost_present() {
let entry = dummy_entry("test-model", false);
let value = rpc_model_from_entry(&entry);
assert!(value.get("cost").is_some());
let cost = &value["cost"];
assert_eq!(cost["input"], 3.0);
assert_eq!(cost["output"], 15.0);
}
#[test]
fn current_model_entry_matches_provider_alias_and_model_case() {
let mut model = dummy_entry("gpt-4o-mini", true);
model.model.provider = "openrouter".to_string();
let options = rpc_options_with_models(vec![model]);
let mut session = Session::in_memory();
session.header.provider = Some("open-router".to_string());
session.header.model_id = Some("GPT-4O-MINI".to_string());
let resolved = current_model_entry(&session, &options).expect("resolve aliased model");
assert_eq!(resolved.model.provider, "openrouter");
assert_eq!(resolved.model.id, "gpt-4o-mini");
}
#[test]
fn session_state_resolves_model_for_provider_alias() {
let mut model = dummy_entry("gpt-4o-mini", true);
model.model.provider = "openrouter".to_string();
let options = rpc_options_with_models(vec![model]);
let mut session = Session::in_memory();
session.header.provider = Some("open-router".to_string());
session.header.model_id = Some("gpt-4o-mini".to_string());
let snapshot = RpcStateSnapshot {
steering_count: 0,
follow_up_count: 0,
steering_mode: QueueMode::OneAtATime,
follow_up_mode: QueueMode::OneAtATime,
auto_compaction_enabled: false,
auto_retry_enabled: false,
};
let state = session_state(&session, &options, &snapshot, false, false);
assert_eq!(state["model"]["provider"], "openrouter");
assert_eq!(state["model"]["id"], "gpt-4o-mini");
}
#[test]
fn error_hints_value_produces_expected_shape() {
let error = Error::validation("test error");
let value = error_hints_value(&error);
assert!(value.get("summary").is_some());
assert!(value.get("hints").is_some());
assert!(value.get("contextFields").is_some());
assert!(value["hints"].is_array());
}
#[test]
fn parse_ui_response_id_empty_string() {
let value = json!({"requestId": ""});
assert_eq!(rpc_parse_extension_ui_response_id(&value), None);
}
#[test]
fn parse_ui_response_id_whitespace_only() {
let value = json!({"requestId": " "});
assert_eq!(rpc_parse_extension_ui_response_id(&value), None);
}
#[test]
fn parse_ui_response_id_trims() {
let value = json!({"requestId": " req-1 "});
assert_eq!(
rpc_parse_extension_ui_response_id(&value),
Some("req-1".to_string())
);
}
#[test]
fn parse_ui_response_id_prefers_request_id_over_id_alias() {
let value = json!({"requestId": "req-1", "id": "legacy-id"});
assert_eq!(
rpc_parse_extension_ui_response_id(&value),
Some("req-1".to_string())
);
}
#[test]
fn parse_ui_response_id_falls_back_to_id_alias_when_request_id_not_string() {
let value = json!({"requestId": 123, "id": "legacy-id"});
assert_eq!(
rpc_parse_extension_ui_response_id(&value),
Some("legacy-id".to_string())
);
}
#[test]
fn parse_ui_response_id_falls_back_to_id_alias_when_request_id_blank() {
let value = json!({"requestId": "", "id": "legacy-id"});
assert_eq!(
rpc_parse_extension_ui_response_id(&value),
Some("legacy-id".to_string())
);
}
#[test]
fn parse_ui_response_id_falls_back_to_id_alias_when_request_id_whitespace() {
let value = json!({"requestId": " ", "id": "legacy-id"});
assert_eq!(
rpc_parse_extension_ui_response_id(&value),
Some("legacy-id".to_string())
);
}
#[test]
fn parse_ui_response_id_neither_field() {
let value = json!({"type": "something"});
assert_eq!(rpc_parse_extension_ui_response_id(&value), None);
}
#[test]
fn parse_editor_response_requires_string() {
let active = ExtensionUiRequest::new("req-1", "editor", json!({"title": "t"}));
let ok = json!({"type": "extension_ui_response", "requestId": "req-1", "value": "code"});
assert!(rpc_parse_extension_ui_response(&ok, &active).is_ok());
let bad = json!({"type": "extension_ui_response", "requestId": "req-1", "value": 42});
assert!(rpc_parse_extension_ui_response(&bad, &active).is_err());
}
#[test]
fn parse_notify_response_returns_ack() {
let active = ExtensionUiRequest::new("req-1", "notify", json!({"title": "t"}));
let val = json!({"type": "extension_ui_response", "requestId": "req-1"});
let resp = rpc_parse_extension_ui_response(&val, &active).unwrap();
assert!(!resp.cancelled);
}
#[test]
fn parse_unknown_method_errors() {
let active = ExtensionUiRequest::new("req-1", "unknown_method", json!({}));
let val = json!({"type": "extension_ui_response", "requestId": "req-1"});
assert!(rpc_parse_extension_ui_response(&val, &active).is_err());
}
#[test]
fn parse_select_with_object_options() {
let active = ExtensionUiRequest::new(
"req-1",
"select",
json!({"title": "pick", "options": [{"label": "Alpha", "value": "a"}, {"label": "Beta"}]}),
);
let val_a = json!({"type": "extension_ui_response", "requestId": "req-1", "value": "a"});
let resp = rpc_parse_extension_ui_response(&val_a, &active).unwrap();
assert_eq!(resp.value, Some(json!("a")));
let val_b = json!({"type": "extension_ui_response", "requestId": "req-1", "value": "Beta"});
let resp = rpc_parse_extension_ui_response(&val_b, &active).unwrap();
assert_eq!(resp.value, Some(json!("Beta")));
}
}