1#![allow(clippy::significant_drop_tightening)]
7#![allow(clippy::too_many_arguments)]
8#![allow(clippy::too_many_lines)]
9#![allow(clippy::cast_possible_truncation)]
10#![allow(clippy::cast_lossless)]
11#![allow(clippy::ignored_unit_patterns)]
12#![allow(clippy::needless_pass_by_value)]
13
14use crate::agent::{AbortHandle, AgentEvent, AgentSession, InputSource, QueueMode};
15use crate::agent_cx::AgentCx;
16use crate::auth::AuthStorage;
17use crate::compaction::{
18 ResolvedCompactionSettings, compact, compaction_details_to_value, prepare_compaction,
19};
20use crate::config::Config;
21use crate::error::{Error, Result};
22use crate::error_hints;
23use crate::extensions::{
24 EXTENSION_EVENT_TIMEOUT_MS, ExtensionEventName, ExtensionManager, ExtensionUiRequest,
25 ExtensionUiResponse,
26};
27use crate::model::{
28 ContentBlock, ImageContent, Message, StopReason, TextContent, UserContent, UserMessage,
29};
30use crate::models::{ModelEntry, model_requires_configured_credential, normalize_api_key_opt};
31use crate::provider_metadata::provider_ids_match;
32use crate::providers;
33use crate::resources::ResourceLoader;
34use crate::session::SessionMessage;
35use crate::tools::{DEFAULT_MAX_BYTES, DEFAULT_MAX_LINES, truncate_tail};
36use asupersync::channel::{mpsc, oneshot};
37use asupersync::runtime::RuntimeHandle;
38use asupersync::sync::{Mutex, OwnedMutexGuard};
39use asupersync::time::{sleep, wall_now};
40use memchr::memchr_iter;
41use serde_json::{Value, json};
42use std::collections::VecDeque;
43use std::future::Future;
44use std::io::{self, BufRead, Write};
45use std::path::PathBuf;
46use std::sync::Arc;
47use std::sync::atomic::{AtomicBool, Ordering};
48use std::time::Duration;
49
/// Configuration and shared resources handed to the RPC loop (`run` / `run_stdio`).
#[derive(Clone)]
pub struct RpcOptions {
    /// Loaded configuration; the initial queue modes and auto-compaction /
    /// auto-retry toggles of the RPC state are read from it.
    pub config: Config,
    /// Resource loader; user text is passed through `expand_input` before it is
    /// queued or prompted.
    pub resources: ResourceLoader,
    /// Models reported by `get_available_models` and selectable via `set_model`.
    pub available_models: Vec<ModelEntry>,
    /// Models paired with optional thinking levels; presumably the scoped set used
    /// when cycling models — TODO confirm against `cycle_model_for_rpc`.
    pub scoped_models: Vec<RpcScopedModel>,
    /// API key given on the command line, if any; consulted when resolving the
    /// credential for a newly selected model.
    pub cli_api_key: Option<String>,
    /// Credential storage consulted when resolving the key for a selected model.
    pub auth: AuthStorage,
    /// Runtime on which prompt runs, extension commands and the UI bridge are spawned.
    pub runtime_handle: RuntimeHandle,
}
60
/// A model entry paired with an optional thinking level.
#[derive(Debug, Clone)]
pub struct RpcScopedModel {
    /// The model this entry refers to.
    pub model: ModelEntry,
    /// Thinking level associated with the model, if one was specified; presumably
    /// applied when this model is selected — confirm with the model-cycling code.
    pub thinking_level: Option<crate::model::ThinkingLevel>,
}
66
/// How a `prompt` that arrives while the agent is already streaming gets queued.
///
/// Parsed from the command's `streamingBehavior` (or `streaming_behavior`) field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamingBehavior {
    /// Push the message onto the steering queue.
    Steer,
    /// Push the message onto the follow-up queue.
    FollowUp,
}
72
/// Point-in-time copy of the queue counters, queue modes and toggles held in
/// `RpcSharedState`.
///
/// It is captured while the shared-state lock is held and used after that lock
/// is released (e.g. when building the `get_state` response).
#[derive(Debug, Clone)]
struct RpcStateSnapshot {
    /// Number of queued steering messages.
    steering_count: usize,
    /// Number of queued follow-up messages.
    follow_up_count: usize,
    /// Mode used when popping steering messages.
    steering_mode: QueueMode,
    /// Mode used when popping follow-up messages.
    follow_up_mode: QueueMode,
    /// Whether auto-compaction is currently enabled.
    auto_compaction_enabled: bool,
    /// Whether auto-retry is currently enabled.
    auto_retry_enabled: bool,
}
82
83impl From<&RpcSharedState> for RpcStateSnapshot {
84 fn from(state: &RpcSharedState) -> Self {
85 Self {
86 steering_count: state.steering.len(),
87 follow_up_count: state.follow_up.len(),
88 steering_mode: state.steering_mode,
89 follow_up_mode: state.follow_up_mode,
90 auto_compaction_enabled: state.auto_compaction_enabled,
91 auto_retry_enabled: state.auto_retry_enabled,
92 }
93 }
94}
95
96impl RpcStateSnapshot {
97 const fn pending_count(&self) -> usize {
98 self.steering_count + self.follow_up_count
99 }
100}
101
102use crate::config::parse_queue_mode;
103
104fn streaming_behavior_value(parsed: &Value) -> Option<&Value> {
105 parsed
106 .get("streamingBehavior")
107 .or_else(|| parsed.get("streaming_behavior"))
108}
109
110fn parse_streaming_behavior(value: Option<&Value>) -> Result<Option<StreamingBehavior>> {
111 let Some(value) = value else {
112 return Ok(None);
113 };
114 let Some(s) = value.as_str() else {
115 return Err(Error::validation("streamingBehavior must be a string"));
116 };
117 match s {
118 "steer" => Ok(Some(StreamingBehavior::Steer)),
119 "follow-up" | "followUp" | "follow_up" => Ok(Some(StreamingBehavior::FollowUp)),
120 _ => Err(Error::validation(format!("Invalid streamingBehavior: {s}"))),
121 }
122}
123
124fn parse_optional_u32_field(parsed: &Value, field: &str) -> Result<Option<u32>> {
125 let Some(value) = parsed.get(field) else {
126 return Ok(None);
127 };
128 let number = value
129 .as_u64()
130 .ok_or_else(|| Error::Validation(format!("{field} must be a non-negative integer")))?;
131 u32::try_from(number)
132 .map(Some)
133 .map_err(|_| Error::Validation(format!("{field} exceeds the maximum supported value")))
134}
135
136fn future_with_current_cx<F>(
137 current_cx: asupersync::Cx,
138 future: F,
139) -> impl Future<Output = F::Output> + Send + 'static
140where
141 F: Future + Send + 'static,
142{
143 let mut future = Box::pin(future);
144 std::future::poll_fn(move |poll_cx| {
145 let _guard = asupersync::Cx::set_current(Some(current_cx.clone()));
146 future.as_mut().poll(poll_cx)
147 })
148}
149
/// Maps the accepted kebab-case / camelCase spellings of a command type onto the
/// canonical snake_case name the dispatcher matches on.
///
/// Anything not in the alias table, including names that are already canonical,
/// is returned unchanged.
fn normalize_command_type(command_type: &str) -> &str {
    const ALIASES: &[(&[&str], &str)] = &[
        (
            &["follow-up", "followUp", "queue-follow-up", "queueFollowUp"],
            "follow_up",
        ),
        (&["get-state", "getState"], "get_state"),
        (&["set-model", "setModel"], "set_model"),
        (&["set-steering-mode", "setSteeringMode"], "set_steering_mode"),
        (&["set-follow-up-mode", "setFollowUpMode"], "set_follow_up_mode"),
        (&["set-auto-compaction", "setAutoCompaction"], "set_auto_compaction"),
        (&["set-auto-retry", "setAutoRetry"], "set_auto_retry"),
    ];

    ALIASES
        .iter()
        .find(|(spellings, _)| spellings.contains(&command_type))
        .map_or(command_type, |&(_, canonical)| canonical)
}
162
163fn build_user_message(text: &str, images: &[ImageContent]) -> Message {
164 let timestamp = chrono::Utc::now().timestamp_millis();
165 if images.is_empty() {
166 return Message::User(UserMessage {
167 content: UserContent::Text(text.to_string()),
168 timestamp,
169 });
170 }
171 let blocks = build_prompt_content_blocks(text, images);
172 Message::User(UserMessage {
173 content: UserContent::Blocks(blocks),
174 timestamp,
175 })
176}
177
178fn build_prompt_content_blocks(text: &str, images: &[ImageContent]) -> Vec<ContentBlock> {
179 let mut blocks = Vec::new();
180 if !text.trim().is_empty() {
181 blocks.push(ContentBlock::Text(TextContent::new(text.to_string())));
182 }
183 for image in images {
184 blocks.push(ContentBlock::Image(image.clone()));
185 }
186 blocks
187}
188
/// Splits a slash-command line such as `"/deploy prod"` into `(name, args)`.
///
/// Leading whitespace before the `/` is ignored. The name runs up to the first
/// whitespace character. `args` is everything after it with leading whitespace
/// removed; trailing whitespace is kept. Returns `None` when the line does not
/// start with `/` or the name is empty (e.g. `"/"` or `"/ foo"`).
fn parse_extension_command_line(message: &str) -> Option<(String, String)> {
    let body = message.trim_start().strip_prefix('/')?;
    let name_end = body.find(char::is_whitespace).unwrap_or(body.len());
    let (name, remainder) = body.split_at(name_end);
    if name.is_empty() {
        return None;
    }
    Some((name.to_owned(), remainder.trim_start().to_owned()))
}
201
202fn resolve_extension_command(
203 message: &str,
204 manager: Option<&ExtensionManager>,
205) -> Option<(String, String)> {
206 if !message.trim_start().starts_with('/') {
207 return None;
208 }
209
210 let manager = manager?;
211 let (command_name, args) = parse_extension_command_line(message)?;
212 manager
213 .has_command(&command_name)
214 .then_some((command_name, args))
215}
216
217fn rpc_agent_event_handler(
218 out_tx: std::sync::mpsc::SyncSender<String>,
219 runtime_handle: RuntimeHandle,
220 extensions: Option<ExtensionManager>,
221) -> impl Fn(AgentEvent) + Send + Sync + 'static {
222 let coalescer = extensions.map(crate::extensions::EventCoalescer::new);
223
224 move |event: AgentEvent| {
225 let serialized = if let AgentEvent::AgentEnd {
226 messages, error, ..
227 } = &event
228 {
229 json!({
230 "type": "agent_end",
231 "messages": messages,
232 "error": error,
233 })
234 .to_string()
235 } else {
236 serde_json::to_string(&event).unwrap_or_else(|err| {
237 json!({
238 "type": "event_serialize_error",
239 "error": err.to_string(),
240 })
241 .to_string()
242 })
243 };
244 let _ = out_tx.send(serialized);
245 if let Some(coalescer) = &coalescer {
246 coalescer.dispatch_agent_event_lazy(&event, &runtime_handle);
247 }
248 }
249}
250
251async fn rpc_dispatch_session_before_switch(
252 manager: Option<ExtensionManager>,
253 reason: &str,
254 target_session_file: Option<&str>,
255) -> bool {
256 let Some(manager) = manager else {
257 return false;
258 };
259
260 let payload = target_session_file.map_or_else(
261 || json!({ "reason": reason }),
262 |target_session_file| json!({ "reason": reason, "targetSessionFile": target_session_file }),
263 );
264
265 manager
266 .dispatch_cancellable_event(
267 ExtensionEventName::SessionBeforeSwitch,
268 Some(payload),
269 EXTENSION_EVENT_TIMEOUT_MS,
270 )
271 .await
272 .unwrap_or(false)
273}
274
275async fn rpc_dispatch_session_switch_event(manager: Option<ExtensionManager>, payload: Value) {
276 let Some(manager) = manager else {
277 return;
278 };
279
280 let _ = manager
281 .dispatch_event(ExtensionEventName::SessionSwitch, Some(payload))
282 .await;
283}
284
285fn try_send_line_with_backpressure(tx: &mpsc::Sender<String>, mut line: String) -> bool {
286 loop {
287 match tx.try_send(line) {
288 Ok(()) => return true,
289 Err(mpsc::SendError::Full(unsent)) => {
290 line = unsent;
291 std::thread::sleep(Duration::from_millis(10));
292 }
293 Err(mpsc::SendError::Disconnected(_) | mpsc::SendError::Cancelled(_)) => {
294 return false;
295 }
296 }
297 }
298}
299
/// Mutable RPC-side state, shared behind a mutex between the command loop and the
/// agent's steering / follow-up message fetchers.
#[derive(Debug)]
struct RpcSharedState {
    /// Messages queued for steering while the agent is streaming.
    steering: VecDeque<Message>,
    /// Messages queued as follow-ups while the agent is streaming.
    follow_up: VecDeque<Message>,
    /// Whether `pop_steering` drains the whole queue or one message at a time.
    steering_mode: QueueMode,
    /// Whether `pop_follow_up` drains the whole queue or one message at a time.
    follow_up_mode: QueueMode,
    /// Auto-compaction toggle; initialized from the config.
    auto_compaction_enabled: bool,
    /// Auto-retry toggle; initialized from the config.
    auto_retry_enabled: bool,
}

/// Upper bound on steering plus follow-up messages queued together; pushes
/// beyond this are rejected with an error rather than growing the queues.
const MAX_RPC_PENDING_MESSAGES: usize = 128;
311
312impl RpcSharedState {
313 fn new(config: &Config) -> Self {
314 Self {
315 steering: VecDeque::new(),
316 follow_up: VecDeque::new(),
317 steering_mode: config.steering_queue_mode(),
318 follow_up_mode: config.follow_up_queue_mode(),
319 auto_compaction_enabled: config.compaction_enabled(),
320 auto_retry_enabled: config.retry_enabled(),
321 }
322 }
323
324 fn pending_count(&self) -> usize {
325 self.steering.len() + self.follow_up.len()
326 }
327
328 fn push_steering(&mut self, message: Message) -> Result<()> {
329 if self.pending_count() >= MAX_RPC_PENDING_MESSAGES {
330 return Err(Error::session(
331 "Steering queue is full (Do you have too many pending commands?)",
332 ));
333 }
334 self.steering.push_back(message);
335 Ok(())
336 }
337
338 fn push_follow_up(&mut self, message: Message) -> Result<()> {
339 if self.pending_count() >= MAX_RPC_PENDING_MESSAGES {
340 return Err(Error::session("Follow-up queue is full"));
341 }
342 self.follow_up.push_back(message);
343 Ok(())
344 }
345
346 fn pop_steering(&mut self) -> Vec<Message> {
347 match self.steering_mode {
348 QueueMode::All => self.steering.drain(..).collect(),
349 QueueMode::OneAtATime => self.steering.pop_front().into_iter().collect(),
350 }
351 }
352
353 fn pop_follow_up(&mut self) -> Vec<Message> {
354 match self.follow_up_mode {
355 QueueMode::All => self.follow_up.drain(..).collect(),
356 QueueMode::OneAtATime => self.follow_up.pop_front().into_iter().collect(),
357 }
358 }
359}
360
/// An in-flight bash execution, held in the `bash_state` slot of the RPC loop.
///
/// NOTE(review): the command that populates this slot is outside this view —
/// confirm the lifecycle there.
struct RunningBash {
    /// Identifier of the running execution.
    id: String,
    /// One-shot sender; presumably fired to ask the running command to abort —
    /// TODO confirm against the bash handler.
    abort_tx: oneshot::Sender<()>,
}
366
/// Bookkeeping for extension UI requests that expect a response.
///
/// At most one request is `active` (emitted to the client); later requests wait
/// in `queue`. The bridge task caps the queue and cancels overflowing requests.
#[derive(Debug, Default)]
struct RpcUiBridgeState {
    /// Request currently emitted to the client, presumably awaiting its response.
    active: Option<ExtensionUiRequest>,
    /// Requests waiting for `active` to become free, in arrival order.
    queue: VecDeque<ExtensionUiRequest>,
}
372
373pub async fn run_stdio(mut session: AgentSession, options: RpcOptions) -> Result<()> {
374 session.set_input_source(InputSource::Rpc);
375 let (in_tx, in_rx) = mpsc::channel::<String>(1024);
376 let (out_tx, out_rx) = std::sync::mpsc::sync_channel::<String>(1024);
377
378 std::thread::spawn(move || {
379 let stdin = io::stdin();
380 let mut reader = io::BufReader::new(stdin.lock());
381 let mut line = String::new();
382 loop {
383 line.clear();
384 match reader.read_line(&mut line) {
385 Ok(0) | Err(_) => break,
386 Ok(_) => {
387 let line_to_send = line.clone();
388 if !try_send_line_with_backpressure(&in_tx, line_to_send) {
391 break;
392 }
393 }
394 }
395 }
396 });
397
398 std::thread::spawn(move || {
399 let stdout = io::stdout();
400 let mut writer = io::BufWriter::new(stdout.lock());
401 for line in out_rx {
402 if writer.write_all(line.as_bytes()).is_err() {
403 break;
404 }
405 if writer.write_all(b"\n").is_err() {
406 break;
407 }
408 if writer.flush().is_err() {
409 break;
410 }
411 }
412 });
413
414 run(session, options, in_rx, out_tx).await
415}
416
417#[allow(clippy::too_many_lines)]
418#[allow(
419 clippy::significant_drop_tightening,
420 clippy::significant_drop_in_scrutinee
421)]
422pub async fn run(
423 session: AgentSession,
424 options: RpcOptions,
425 mut in_rx: mpsc::Receiver<String>,
426 out_tx: std::sync::mpsc::SyncSender<String>,
427) -> Result<()> {
428 let cx = AgentCx::for_current_or_request();
429 let session_handle = Arc::clone(&session.session);
430 let session = Arc::new(Mutex::new(session));
431 let shared_state = Arc::new(Mutex::new(RpcSharedState::new(&options.config)));
432 let is_streaming = Arc::new(AtomicBool::new(false));
433 let is_compacting = Arc::new(AtomicBool::new(false));
434 let abort_handle: Arc<Mutex<Option<AbortHandle>>> = Arc::new(Mutex::new(None));
435 let bash_state: Arc<Mutex<Option<RunningBash>>> = Arc::new(Mutex::new(None));
436 let retry_abort = Arc::new(AtomicBool::new(false));
437
438 {
439 use futures::future::BoxFuture;
440 let steering_state = Arc::clone(&shared_state);
441 let follow_state = Arc::clone(&shared_state);
442 let steering_cx = cx.clone();
443 let follow_cx = cx.clone();
444 let mut guard = session
445 .lock(&cx)
446 .await
447 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
448 guard.set_queue_modes(
449 options.config.steering_queue_mode(),
450 options.config.follow_up_queue_mode(),
451 );
452 let steering_fetcher = move || -> BoxFuture<'static, Vec<Message>> {
453 let steering_state = Arc::clone(&steering_state);
454 let steering_cx = steering_cx.clone();
455 Box::pin(async move {
456 steering_state
457 .lock(&steering_cx)
458 .await
459 .map_or_else(|_| Vec::new(), |mut state| state.pop_steering())
460 })
461 };
462 let follow_fetcher = move || -> BoxFuture<'static, Vec<Message>> {
463 let follow_state = Arc::clone(&follow_state);
464 let follow_cx = follow_cx.clone();
465 Box::pin(async move {
466 follow_state
467 .lock(&follow_cx)
468 .await
469 .map_or_else(|_| Vec::new(), |mut state| state.pop_follow_up())
470 })
471 };
472 guard.agent.register_message_fetchers(
473 Some(Arc::new(steering_fetcher)),
474 Some(Arc::new(follow_fetcher)),
475 );
476 }
477
478 let rpc_extension_manager = {
482 let cx_ui = cx.clone();
483 let guard = session
484 .lock(&cx_ui)
485 .await
486 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
487 guard
488 .extensions
489 .as_ref()
490 .map(crate::extensions::ExtensionRegion::manager)
491 .cloned()
492 };
493
494 let rpc_ui_state: Option<Arc<Mutex<RpcUiBridgeState>>> = rpc_extension_manager
495 .as_ref()
496 .map(|_| Arc::new(Mutex::new(RpcUiBridgeState::default())));
497
498 if let Some(ref manager) = rpc_extension_manager {
499 let (extension_ui_tx, mut extension_ui_rx) =
500 asupersync::channel::mpsc::channel::<ExtensionUiRequest>(64);
501 manager.set_ui_sender(extension_ui_tx);
502
503 let out_tx_ui = out_tx.clone();
504 let ui_state = rpc_ui_state
505 .as_ref()
506 .map(Arc::clone)
507 .expect("rpc ui state should exist when extension manager exists");
508 let manager_ui = (*manager).clone();
509 let runtime_handle_ui = options.runtime_handle.clone();
510 options.runtime_handle.spawn(async move {
511 const MAX_UI_PENDING_REQUESTS: usize = 64;
512 let cx = AgentCx::for_request();
513 while let Ok(request) = extension_ui_rx.recv(&cx).await {
514 if request.expects_response() {
515 let emit_now = {
516 let Ok(mut guard) = ui_state.lock(&cx).await else {
517 return;
518 };
519 if guard.active.is_none() {
520 guard.active = Some(request.clone());
521 true
522 } else if guard.queue.len() < MAX_UI_PENDING_REQUESTS {
523 guard.queue.push_back(request.clone());
524 false
525 } else {
526 drop(guard);
527 let _ = manager_ui.respond_ui(ExtensionUiResponse {
528 id: request.id.clone(),
529 value: None,
530 cancelled: true,
531 });
532 false
533 }
534 };
535
536 if emit_now {
537 rpc_emit_extension_ui_request(
538 &runtime_handle_ui,
539 Arc::clone(&ui_state),
540 manager_ui.clone(),
541 out_tx_ui.clone(),
542 request,
543 );
544 }
545 } else {
546 let rpc_event = request.to_rpc_event();
548 let _ = out_tx_ui.send(event(&rpc_event));
549 }
550 }
551 });
552 }
553
554 while let Ok(line) = in_rx.recv(&cx).await {
555 if line.trim().is_empty() {
556 continue;
557 }
558
559 let parsed: Value = match serde_json::from_str(&line) {
560 Ok(v) => v,
561 Err(err) => {
562 let resp = response_error(None, "parse", format!("Failed to parse command: {err}"));
563 let _ = out_tx.send(resp);
564 continue;
565 }
566 };
567
568 let Some(command_type_raw) = parsed.get("type").and_then(Value::as_str) else {
569 let resp = response_error(None, "parse", "Missing command type".to_string());
570 let _ = out_tx.send(resp);
571 continue;
572 };
573 let command_type = normalize_command_type(command_type_raw);
574
575 let id = parsed.get("id").and_then(Value::as_str).map(str::to_string);
576
577 match command_type {
578 "prompt" => {
579 let Some(message) = parsed
580 .get("message")
581 .and_then(Value::as_str)
582 .map(String::from)
583 else {
584 let resp = response_error(id, "prompt", "Missing message".to_string());
585 let _ = out_tx.send(resp);
586 continue;
587 };
588
589 let images = match parse_prompt_images(parsed.get("images")) {
590 Ok(images) => images,
591 Err(err) => {
592 let resp = response_error_with_hints(id, "prompt", &err);
593 let _ = out_tx.send(resp);
594 continue;
595 }
596 };
597
598 let streaming_behavior =
599 match parse_streaming_behavior(streaming_behavior_value(&parsed)) {
600 Ok(value) => value,
601 Err(err) => {
602 let resp = response_error_with_hints(id, "prompt", &err);
603 let _ = out_tx.send(resp);
604 continue;
605 }
606 };
607
608 let extension_command =
609 resolve_extension_command(&message, rpc_extension_manager.as_ref());
610
611 if is_streaming.load(Ordering::SeqCst) {
612 if extension_command.is_some() {
613 let resp = response_error(
614 id,
615 "prompt",
616 "Extension commands are not allowed while agent is streaming"
617 .to_string(),
618 );
619 let _ = out_tx.send(resp);
620 continue;
621 }
622
623 if streaming_behavior.is_none() {
624 let resp = response_error(
625 id,
626 "prompt",
627 "Agent is currently streaming; specify streamingBehavior".to_string(),
628 );
629 let _ = out_tx.send(resp);
630 continue;
631 }
632
633 let expanded = options.resources.expand_input(&message);
634 let queued_result = {
635 let mut state = shared_state
636 .lock(&cx)
637 .await
638 .map_err(|err| Error::session(format!("state lock failed: {err}")))?;
639 match streaming_behavior {
640 Some(StreamingBehavior::Steer) => {
641 state.push_steering(build_user_message(&expanded, &images))
642 }
643 Some(StreamingBehavior::FollowUp) => {
644 state.push_follow_up(build_user_message(&expanded, &images))
645 }
646 None => Ok(()), }
648 };
649
650 match queued_result {
651 Ok(()) => {
652 let _ = out_tx.send(response_ok(id, "prompt", None));
653 }
654 Err(err) => {
655 let resp = response_error_with_hints(id, "prompt", &err);
656 let _ = out_tx.send(resp);
657 }
658 }
659 continue;
660 }
661
662 let _ = out_tx.send(response_ok(id, "prompt", None));
664
665 is_streaming.store(true, Ordering::SeqCst);
666
667 let out_tx = out_tx.clone();
668 let session = Arc::clone(&session);
669 let shared_state = Arc::clone(&shared_state);
670 let is_streaming = Arc::clone(&is_streaming);
671 let is_compacting = Arc::clone(&is_compacting);
672 let abort_handle_slot = Arc::clone(&abort_handle);
673 let runtime_handle = options.runtime_handle.clone();
674 if let Some((command_name, args)) = extension_command {
675 let command_runtime = runtime_handle.clone();
676 let command_cx = cx.clone();
677 runtime_handle.spawn(future_with_current_cx(
678 command_cx.cx().clone(),
679 async move {
680 run_extension_command(
681 session,
682 is_streaming,
683 abort_handle_slot,
684 out_tx,
685 command_runtime,
686 command_name,
687 args,
688 command_cx,
689 )
690 .await;
691 },
692 ));
693 } else {
694 let retry_abort = retry_abort.clone();
695 let options = options.clone();
696 let expanded = options.resources.expand_input(&message);
697 let prompt_cx = cx.clone();
698 runtime_handle.spawn(future_with_current_cx(
699 prompt_cx.cx().clone(),
700 async move {
701 run_prompt_with_retry(
702 session,
703 shared_state,
704 is_streaming,
705 is_compacting,
706 abort_handle_slot,
707 out_tx,
708 retry_abort,
709 options,
710 expanded,
711 images,
712 prompt_cx,
713 )
714 .await;
715 },
716 ));
717 }
718 }
719
720 "steer" => {
721 let Some(message) = parsed
722 .get("message")
723 .and_then(Value::as_str)
724 .map(String::from)
725 else {
726 let resp = response_error(id, "steer", "Missing message".to_string());
727 let _ = out_tx.send(resp);
728 continue;
729 };
730
731 if resolve_extension_command(&message, rpc_extension_manager.as_ref()).is_some() {
732 let resp = response_error(
733 id,
734 "steer",
735 "Extension commands are not allowed with steer".to_string(),
736 );
737 let _ = out_tx.send(resp);
738 continue;
739 }
740
741 let expanded = options.resources.expand_input(&message);
742 if is_streaming.load(Ordering::SeqCst) {
743 let result = shared_state
744 .lock(&cx)
745 .await
746 .map_err(|err| Error::session(format!("state lock failed: {err}")))?
747 .push_steering(build_user_message(&expanded, &[]));
748
749 match result {
750 Ok(()) => {
751 let _ = out_tx.send(response_ok(id, "steer", None));
752 }
753 Err(err) => {
754 let _ = out_tx.send(response_error_with_hints(id, "steer", &err));
755 }
756 }
757 continue;
758 }
759
760 let _ = out_tx.send(response_ok(id, "steer", None));
761
762 is_streaming.store(true, Ordering::SeqCst);
763
764 let out_tx = out_tx.clone();
765 let session = Arc::clone(&session);
766 let shared_state = Arc::clone(&shared_state);
767 let is_streaming = Arc::clone(&is_streaming);
768 let is_compacting = Arc::clone(&is_compacting);
769 let abort_handle_slot = Arc::clone(&abort_handle);
770 let retry_abort = retry_abort.clone();
771 let options = options.clone();
772 let expanded = expanded.clone();
773 let runtime_handle = options.runtime_handle.clone();
774 let prompt_cx = cx.clone();
775 runtime_handle.spawn(future_with_current_cx(prompt_cx.cx().clone(), async move {
776 run_prompt_with_retry(
777 session,
778 shared_state,
779 is_streaming,
780 is_compacting,
781 abort_handle_slot,
782 out_tx,
783 retry_abort,
784 options,
785 expanded,
786 Vec::new(),
787 prompt_cx,
788 )
789 .await;
790 }));
791 }
792
793 "follow_up" => {
794 let Some(message) = parsed
795 .get("message")
796 .and_then(Value::as_str)
797 .map(String::from)
798 else {
799 let resp = response_error(id, "follow_up", "Missing message".to_string());
800 let _ = out_tx.send(resp);
801 continue;
802 };
803
804 if resolve_extension_command(&message, rpc_extension_manager.as_ref()).is_some() {
805 let resp = response_error(
806 id,
807 "follow_up",
808 "Extension commands are not allowed with follow_up".to_string(),
809 );
810 let _ = out_tx.send(resp);
811 continue;
812 }
813
814 let expanded = options.resources.expand_input(&message);
815 if is_streaming.load(Ordering::SeqCst) {
816 let result = shared_state
817 .lock(&cx)
818 .await
819 .map_err(|err| Error::session(format!("state lock failed: {err}")))?
820 .push_follow_up(build_user_message(&expanded, &[]));
821
822 match result {
823 Ok(()) => {
824 let _ = out_tx.send(response_ok(id, "follow_up", None));
825 }
826 Err(err) => {
827 let _ = out_tx.send(response_error_with_hints(id, "follow_up", &err));
828 }
829 }
830 continue;
831 }
832
833 let _ = out_tx.send(response_ok(id, "follow_up", None));
834
835 is_streaming.store(true, Ordering::SeqCst);
836
837 let out_tx = out_tx.clone();
838 let session = Arc::clone(&session);
839 let shared_state = Arc::clone(&shared_state);
840 let is_streaming = Arc::clone(&is_streaming);
841 let is_compacting = Arc::clone(&is_compacting);
842 let abort_handle_slot = Arc::clone(&abort_handle);
843 let retry_abort = retry_abort.clone();
844 let options = options.clone();
845 let expanded = expanded.clone();
846 let runtime_handle = options.runtime_handle.clone();
847 let prompt_cx = cx.clone();
848 runtime_handle.spawn(future_with_current_cx(prompt_cx.cx().clone(), async move {
849 run_prompt_with_retry(
850 session,
851 shared_state,
852 is_streaming,
853 is_compacting,
854 abort_handle_slot,
855 out_tx,
856 retry_abort,
857 options,
858 expanded,
859 Vec::new(),
860 prompt_cx,
861 )
862 .await;
863 }));
864 }
865
866 "abort" => {
867 let handle = abort_handle
868 .lock(&cx)
869 .await
870 .map_err(|err| Error::session(format!("abort lock failed: {err}")))?
871 .clone();
872 if let Some(handle) = handle {
873 handle.abort();
874 }
875 let _ = out_tx.send(response_ok(id, "abort", None));
876 }
877
878 "get_state" => {
879 let snapshot = {
880 let state = shared_state
881 .lock(&cx)
882 .await
883 .map_err(|err| Error::session(format!("state lock failed: {err}")))?;
884 RpcStateSnapshot::from(&*state)
885 };
886 let data = {
887 let inner_session = session_handle.lock(&cx).await.map_err(|err| {
888 Error::session(format!("inner session lock failed: {err}"))
889 })?;
890 session_state(
891 &inner_session,
892 &options,
893 &snapshot,
894 is_streaming.load(Ordering::SeqCst),
895 is_compacting.load(Ordering::SeqCst),
896 )
897 };
898 let _ = out_tx.send(response_ok(id, "get_state", Some(data)));
899 }
900
901 "get_session_stats" => {
902 let data = {
903 let inner_session = session_handle.lock(&cx).await.map_err(|err| {
904 Error::session(format!("inner session lock failed: {err}"))
905 })?;
906 session_stats(&inner_session)
907 };
908 let _ = out_tx.send(response_ok(id, "get_session_stats", Some(data)));
909 }
910
911 "get_messages" => {
912 let messages = {
913 let inner_session = session_handle.lock(&cx).await.map_err(|err| {
914 Error::session(format!("inner session lock failed: {err}"))
915 })?;
916 inner_session
917 .entries_for_current_path()
918 .iter()
919 .filter_map(|entry| match entry {
920 crate::session::SessionEntry::Message(msg) => match msg.message {
921 SessionMessage::User { .. }
922 | SessionMessage::Assistant { .. }
923 | SessionMessage::ToolResult { .. }
924 | SessionMessage::BashExecution { .. }
925 | SessionMessage::Custom { .. } => Some(msg.message.clone()),
926 _ => None,
927 },
928 _ => None,
929 })
930 .collect::<Vec<_>>()
931 };
932 let messages = messages
933 .into_iter()
934 .map(rpc_session_message_value)
935 .collect::<Vec<_>>();
936 let _ = out_tx.send(response_ok(
937 id,
938 "get_messages",
939 Some(json!({ "messages": messages })),
940 ));
941 }
942
943 "get_available_models" => {
944 let models = options
945 .available_models
946 .iter()
947 .map(rpc_model_from_entry)
948 .collect::<Vec<_>>();
949 let _ = out_tx.send(response_ok(
950 id,
951 "get_available_models",
952 Some(json!({ "models": models })),
953 ));
954 }
955
956 "set_model" => {
957 let Some(provider) = parsed.get("provider").and_then(Value::as_str) else {
958 let _ = out_tx.send(response_error(
959 id,
960 "set_model",
961 "Missing provider".to_string(),
962 ));
963 continue;
964 };
965 let Some(model_id) = parsed.get("modelId").and_then(Value::as_str) else {
966 let _ = out_tx.send(response_error(
967 id,
968 "set_model",
969 "Missing modelId".to_string(),
970 ));
971 continue;
972 };
973
974 let Some(entry) = options
975 .available_models
976 .iter()
977 .find(|m| {
978 provider_ids_match(&m.model.provider, provider)
979 && m.model.id.eq_ignore_ascii_case(model_id)
980 })
981 .cloned()
982 else {
983 let _ = out_tx.send(response_error(
984 id,
985 "set_model",
986 format!("Model not found: {provider}/{model_id}"),
987 ));
988 continue;
989 };
990
991 let key = resolve_model_key(options.cli_api_key.as_deref(), &options.auth, &entry);
992 if model_requires_configured_credential(&entry) && key.is_none() {
993 let err = Error::auth(format!(
994 "Missing credentials for {}/{}",
995 entry.model.provider, entry.model.id
996 ));
997 let _ = out_tx.send(response_error_with_hints(id, "set_model", &err));
998 continue;
999 }
1000
1001 let result: Result<()> = async {
1002 let clamped_level = {
1003 let mut guard = session
1004 .lock(&cx)
1005 .await
1006 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
1007 let provider_impl = providers::create_provider(
1008 &entry,
1009 guard
1010 .extensions
1011 .as_ref()
1012 .map(crate::extensions::ExtensionRegion::manager),
1013 )?;
1014 guard.agent.set_provider(provider_impl);
1015 guard.agent.stream_options_mut().api_key.clone_from(&key);
1016 guard
1017 .agent
1018 .stream_options_mut()
1019 .headers
1020 .clone_from(&entry.headers);
1021
1022 apply_model_change(&mut guard, &entry).await?;
1023
1024 let current_thinking = guard
1025 .agent
1026 .stream_options()
1027 .thinking_level
1028 .unwrap_or_default();
1029 entry.clamp_thinking_level(current_thinking)
1030 }; apply_thinking_level(Arc::clone(&session), clamped_level).await?;
1034 Ok(())
1035 }
1036 .await;
1037
1038 match result {
1039 Ok(()) => {
1040 let _ = out_tx.send(response_ok(
1041 id,
1042 "set_model",
1043 Some(rpc_model_from_entry(&entry)),
1044 ));
1045 }
1046 Err(err) => {
1047 let _ = out_tx.send(response_error_with_hints(id, "set_model", &err));
1048 }
1049 }
1050 }
1051
1052 "cycle_model" => {
1053 let result = async {
1054 let cycle_result = {
1055 let mut guard = session
1056 .lock(&cx)
1057 .await
1058 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
1059 cycle_model_for_rpc(&mut guard, &options).await?
1060 };
1061
1062 if let Some((entry, thinking_level, is_scoped)) = cycle_result {
1063 apply_thinking_level_for_session(session.clone(), thinking_level, &cx).await?;
1065 Ok(Some((entry, thinking_level, is_scoped)))
1066 } else {
1067 Ok(None)
1068 }
1069 }
1070 .await;
1071
1072 match result {
1073 Ok(Some((entry, thinking_level, is_scoped))) => {
1074 let _ = out_tx.send(response_ok(
1075 id,
1076 "cycle_model",
1077 Some(json!({
1078 "model": rpc_model_from_entry(&entry),
1079 "thinkingLevel": thinking_level.to_string(),
1080 "isScoped": is_scoped,
1081 })),
1082 ));
1083 }
1084 Ok(None) => {
1085 let _ =
1086 out_tx.send(response_ok(id.clone(), "cycle_model", Some(Value::Null)));
1087 }
1088 Err(err) => {
1089 let _ = out_tx.send(response_error_with_hints(id, "cycle_model", &err));
1090 }
1091 }
1092 }
1093
1094 "set_thinking_level" => {
1095 let Some(level) = parsed.get("level").and_then(Value::as_str) else {
1096 let _ = out_tx.send(response_error(
1097 id,
1098 "set_thinking_level",
1099 "Missing level".to_string(),
1100 ));
1101 continue;
1102 };
1103 let level = match parse_thinking_level(level) {
1104 Ok(level) => level,
1105 Err(err) => {
1106 let _ =
1107 out_tx.send(response_error_with_hints(id, "set_thinking_level", &err));
1108 continue;
1109 }
1110 };
1111
1112 let clamped_level = {
1114 let mut guard = session
1115 .lock(&cx)
1116 .await
1117 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
1118 let runtime_provider = guard.agent.provider().name().to_string();
1119 let runtime_model_id = guard.agent.provider().model_id().to_string();
1120 let inner_session = guard.session.lock(&cx).await.map_err(|err| {
1121 Error::session(format!("inner session lock failed: {err}"))
1122 })?;
1123 current_or_runtime_model_entry(
1124 &inner_session,
1125 &runtime_provider,
1126 &runtime_model_id,
1127 &options,
1128 )
1129 .map_or(level, |entry| entry.clamp_thinking_level(level))
1130 };
1131
1132 let result = {
1134 let session_clone = Arc::clone(&session);
1135 apply_thinking_level_for_session(session_clone, clamped_level, &cx).await
1136 };
1137
1138 if let Err(err) = result {
1139 let _ = out_tx.send(response_error_with_hints(
1140 id.clone(),
1141 "set_thinking_level",
1142 &err,
1143 ));
1144 continue;
1145 }
1146 let _ = out_tx.send(response_ok(id, "set_thinking_level", None));
1147 }
1148
1149 "cycle_thinking_level" => {
1150 let next = {
1152 let mut guard = session
1153 .lock(&cx)
1154 .await
1155 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
1156 let runtime_provider = guard.agent.provider().name().to_string();
1157 let runtime_model_id = guard.agent.provider().model_id().to_string();
1158 let entry = {
1159 let inner_session = guard.session.lock(&cx).await.map_err(|err| {
1160 Error::session(format!("inner session lock failed: {err}"))
1161 })?;
1162 current_or_runtime_model_entry(
1163 &inner_session,
1164 &runtime_provider,
1165 &runtime_model_id,
1166 &options,
1167 )
1168 .cloned()
1169 };
1170 let Some(entry) = entry else {
1171 let _ =
1172 out_tx.send(response_ok(id, "cycle_thinking_level", Some(Value::Null)));
1173 continue;
1174 };
1175 if !entry.model.reasoning {
1176 let _ =
1177 out_tx.send(response_ok(id, "cycle_thinking_level", Some(Value::Null)));
1178 continue;
1179 }
1180
1181 let levels = available_thinking_levels(&entry);
1182 let current = guard
1183 .agent
1184 .stream_options()
1185 .thinking_level
1186 .unwrap_or_default();
1187 let current_index = levels
1188 .iter()
1189 .position(|level| *level == current)
1190 .unwrap_or(0);
1191 levels[(current_index + 1) % levels.len()]
1192 }; if let Err(err) = apply_thinking_level(Arc::clone(&session), next).await {
1196 let _ = out_tx.send(response_error_with_hints(
1197 id.clone(),
1198 "cycle_thinking_level",
1199 &err,
1200 ));
1201 continue;
1202 }
1203 let _ = out_tx.send(response_ok(
1204 id,
1205 "cycle_thinking_level",
1206 Some(json!({ "level": next.to_string() })),
1207 ));
1208 }
1209
1210 "set_steering_mode" => {
1211 let Some(mode) = parsed.get("mode").and_then(Value::as_str) else {
1212 let _ = out_tx.send(response_error(
1213 id,
1214 "set_steering_mode",
1215 "Missing mode".to_string(),
1216 ));
1217 continue;
1218 };
1219 let Some(mode) = parse_queue_mode(Some(mode)) else {
1220 let _ = out_tx.send(response_error(
1221 id,
1222 "set_steering_mode",
1223 "Invalid steering mode".to_string(),
1224 ));
1225 continue;
1226 };
1227 let follow_up_mode = {
1228 let mut state = shared_state
1229 .lock(&cx)
1230 .await
1231 .map_err(|err| Error::session(format!("state lock failed: {err}")))?;
1232 state.steering_mode = mode;
1233 state.follow_up_mode
1234 };
1235 let mut guard = session
1236 .lock(&cx)
1237 .await
1238 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
1239 guard.set_queue_modes(mode, follow_up_mode);
1240 drop(guard);
1241 let _ = out_tx.send(response_ok(id, "set_steering_mode", None));
1242 }
1243
1244 "set_follow_up_mode" => {
1245 let Some(mode) = parsed.get("mode").and_then(Value::as_str) else {
1246 let _ = out_tx.send(response_error(
1247 id,
1248 "set_follow_up_mode",
1249 "Missing mode".to_string(),
1250 ));
1251 continue;
1252 };
1253 let Some(mode) = parse_queue_mode(Some(mode)) else {
1254 let _ = out_tx.send(response_error(
1255 id,
1256 "set_follow_up_mode",
1257 "Invalid follow-up mode".to_string(),
1258 ));
1259 continue;
1260 };
1261 let steering_mode = {
1262 let mut state = shared_state
1263 .lock(&cx)
1264 .await
1265 .map_err(|err| Error::session(format!("state lock failed: {err}")))?;
1266 state.follow_up_mode = mode;
1267 state.steering_mode
1268 };
1269 let mut guard = session
1270 .lock(&cx)
1271 .await
1272 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
1273 guard.set_queue_modes(steering_mode, mode);
1274 drop(guard);
1275 let _ = out_tx.send(response_ok(id, "set_follow_up_mode", None));
1276 }
1277
1278 "set_auto_compaction" => {
1279 let Some(enabled) = parsed.get("enabled").and_then(Value::as_bool) else {
1280 let _ = out_tx.send(response_error(
1281 id,
1282 "set_auto_compaction",
1283 "Missing enabled".to_string(),
1284 ));
1285 continue;
1286 };
1287 let mut state = shared_state
1288 .lock(&cx)
1289 .await
1290 .map_err(|err| Error::session(format!("state lock failed: {err}")))?;
1291 state.auto_compaction_enabled = enabled;
1292 drop(state);
1293 let _ = out_tx.send(response_ok(id, "set_auto_compaction", None));
1294 }
1295
1296 "set_auto_retry" => {
1297 let Some(enabled) = parsed.get("enabled").and_then(Value::as_bool) else {
1298 let _ = out_tx.send(response_error(
1299 id,
1300 "set_auto_retry",
1301 "Missing enabled".to_string(),
1302 ));
1303 continue;
1304 };
1305 let mut state = shared_state
1306 .lock(&cx)
1307 .await
1308 .map_err(|err| Error::session(format!("state lock failed: {err}")))?;
1309 state.auto_retry_enabled = enabled;
1310 drop(state);
1311 let _ = out_tx.send(response_ok(id, "set_auto_retry", None));
1312 }
1313
1314 "abort_retry" => {
1315 retry_abort.store(true, Ordering::SeqCst);
1316 let _ = out_tx.send(response_ok(id, "abort_retry", None));
1317 }
1318
1319 "set_session_name" => {
1320 let Some(name) = parsed.get("name").and_then(Value::as_str) else {
1321 let _ = out_tx.send(response_error(
1322 id,
1323 "set_session_name",
1324 "Missing name".to_string(),
1325 ));
1326 continue;
1327 };
1328 let result: Result<()> = async {
1329 {
1331 let mut guard = session
1332 .lock(&cx)
1333 .await
1334 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
1335 let mut inner_session = guard.session.lock(&cx).await.map_err(|err| {
1336 Error::session(format!("inner session lock failed: {err}"))
1337 })?;
1338 inner_session.append_session_info(Some(name.to_string()));
1339 } let mut guard = session
1343 .lock(&cx)
1344 .await
1345 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
1346 guard.persist_session().await?;
1347 Ok(())
1348 }
1349 .await;
1350
1351 match result {
1352 Ok(()) => {
1353 let _ = out_tx.send(response_ok(id, "set_session_name", None));
1354 }
1355 Err(err) => {
1356 let _ =
1357 out_tx.send(response_error_with_hints(id, "set_session_name", &err));
1358 }
1359 }
1360 }
1361
1362 "get_last_assistant_text" => {
1363 let text = {
1364 let inner_session = session_handle.lock(&cx).await.map_err(|err| {
1365 Error::session(format!("inner session lock failed: {err}"))
1366 })?;
1367 last_assistant_text(&inner_session)
1368 };
1369 let _ = out_tx.send(response_ok(
1370 id,
1371 "get_last_assistant_text",
1372 Some(json!({ "text": text })),
1373 ));
1374 }
1375
1376 "export_html" => {
1377 let output_path = parsed
1378 .get("outputPath")
1379 .and_then(Value::as_str)
1380 .map(str::to_string);
1381 let snapshot = {
1386 let guard = session
1387 .lock(&cx)
1388 .await
1389 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
1390 let inner = guard.session.lock(&cx).await.map_err(|err| {
1391 Error::session(format!("inner session lock failed: {err}"))
1392 })?;
1393 inner.export_snapshot()
1394 };
1395 match export_html_snapshot(&snapshot, output_path.as_deref()).await {
1396 Ok(path) => {
1397 let _ = out_tx.send(response_ok(
1398 id,
1399 "export_html",
1400 Some(json!({ "path": path })),
1401 ));
1402 }
1403 Err(err) => {
1404 let _ = out_tx.send(response_error_with_hints(id, "export_html", &err));
1405 }
1406 }
1407 }
1408
1409 "bash" => {
1410 let Some(command) = parsed.get("command").and_then(Value::as_str) else {
1411 let _ = out_tx.send(response_error(id, "bash", "Missing command".to_string()));
1412 continue;
1413 };
1414
1415 let mut running = bash_state
1416 .lock(&cx)
1417 .await
1418 .map_err(|err| Error::session(format!("bash state lock failed: {err}")))?;
1419 if running.is_some() {
1420 let _ = out_tx.send(response_error(
1421 id,
1422 "bash",
1423 "Bash command already running".to_string(),
1424 ));
1425 continue;
1426 }
1427
1428 let run_id = uuid::Uuid::new_v4().to_string();
1429 let (abort_tx, abort_rx) = oneshot::channel();
1430 *running = Some(RunningBash {
1431 id: run_id.clone(),
1432 abort_tx,
1433 });
1434
1435 let out_tx = out_tx.clone();
1436 let session = Arc::clone(&session);
1437 let bash_state = Arc::clone(&bash_state);
1438 let command = command.to_string();
1439 let id_clone = id.clone();
1440 let runtime_handle = options.runtime_handle.clone();
1441 let bash_cx = cx.clone();
1442
1443 runtime_handle.spawn(async move {
1444 let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
1445 let result = run_bash_rpc(&cwd, &command, abort_rx).await;
1446
1447 let response = match result {
1448 Ok(result) => {
1449 let should_persist = if let Ok(mut guard) = session.lock(&bash_cx).await {
1451 if let Ok(mut inner_session) = guard.session.lock(&bash_cx).await {
1452 inner_session.append_message(SessionMessage::BashExecution {
1453 command: command.clone(),
1454 output: result.output.clone(),
1455 exit_code: result.exit_code,
1456 cancelled: Some(result.cancelled),
1457 truncated: Some(result.truncated),
1458 full_output_path: result.full_output_path.clone(),
1459 timestamp: Some(chrono::Utc::now().timestamp_millis()),
1460 extra: std::collections::HashMap::default(),
1461 });
1462 true
1463 } else {
1464 false
1465 }
1466 } else {
1467 false
1468 };
1469
1470 if should_persist {
1472 if let Ok(mut guard) = session.lock(&bash_cx).await {
1473 let _ = guard.persist_session().await;
1474 }
1475 }
1476
1477 response_ok(
1478 id_clone,
1479 "bash",
1480 Some(json!({
1481 "output": result.output,
1482 "exitCode": result.exit_code,
1483 "cancelled": result.cancelled,
1484 "truncated": result.truncated,
1485 "fullOutputPath": result.full_output_path,
1486 })),
1487 )
1488 }
1489 Err(err) => response_error_with_hints(id_clone, "bash", &err),
1490 };
1491
1492 let _ = out_tx.send(response);
1493 if let Ok(mut running) = bash_state.lock(&bash_cx).await {
1494 if running.as_ref().is_some_and(|r| r.id == run_id) {
1495 *running = None;
1496 }
1497 }
1498 });
1499 }
1500
1501 "abort_bash" => {
1502 let mut running = bash_state
1503 .lock(&cx)
1504 .await
1505 .map_err(|err| Error::session(format!("bash state lock failed: {err}")))?;
1506 if let Some(running_bash) = running.take() {
1507 let _ = running_bash.abort_tx.send(&cx, ());
1508 }
1509 let _ = out_tx.send(response_ok(id, "abort_bash", None));
1510 }
1511
1512 "compact" => {
1513 let custom_instructions = parsed
1514 .get("customInstructions")
1515 .and_then(Value::as_str)
1516 .map(str::to_string);
1517 let reserve_tokens_override =
1518 match parse_optional_u32_field(&parsed, "reserveTokens") {
1519 Ok(value) => value,
1520 Err(err) => {
1521 let _ =
1522 out_tx.send(response_error_with_hints(id.clone(), "compact", &err));
1523 continue;
1524 }
1525 };
1526 let keep_recent_tokens_override =
1527 match parse_optional_u32_field(&parsed, "keepRecentTokens") {
1528 Ok(value) => value,
1529 Err(err) => {
1530 let _ =
1531 out_tx.send(response_error_with_hints(id.clone(), "compact", &err));
1532 continue;
1533 }
1534 };
1535
1536 let result: Result<Value> = async {
1537 let mut guard = session
1538 .lock(&cx)
1539 .await
1540 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
1541 let path_entries = {
1542 let mut inner_session = guard.session.lock(&cx).await.map_err(|err| {
1543 Error::session(format!("inner session lock failed: {err}"))
1544 })?;
1545 inner_session.ensure_entry_ids();
1546 inner_session
1547 .entries_for_current_path()
1548 .into_iter()
1549 .cloned()
1550 .collect::<Vec<_>>()
1551 };
1552
1553 let key = guard
1554 .agent
1555 .stream_options()
1556 .api_key
1557 .as_deref()
1558 .ok_or_else(|| Error::auth("Missing API key for compaction"))?;
1559
1560 let provider = guard.agent.provider();
1561
1562 let settings = ResolvedCompactionSettings {
1563 enabled: options.config.compaction_enabled(),
1564 reserve_tokens: reserve_tokens_override
1565 .unwrap_or_else(|| options.config.compaction_reserve_tokens()),
1566 keep_recent_tokens: keep_recent_tokens_override
1567 .unwrap_or_else(|| options.config.compaction_keep_recent_tokens()),
1568 ..Default::default()
1569 };
1570
1571 let prep = prepare_compaction(&path_entries, settings).ok_or_else(|| {
1572 Error::session(
1573 "Compaction not available (already compacted or missing IDs)",
1574 )
1575 })?;
1576
1577 is_compacting.store(true, Ordering::SeqCst);
1578 let compact_res =
1579 compact(prep, provider, key, custom_instructions.as_deref()).await;
1580 is_compacting.store(false, Ordering::SeqCst);
1581 let result_data = compact_res?;
1582
1583 let details_value = compaction_details_to_value(&result_data.details)?;
1584
1585 let messages = {
1586 let mut inner_session = guard.session.lock(&cx).await.map_err(|err| {
1587 Error::session(format!("inner session lock failed: {err}"))
1588 })?;
1589 inner_session.append_compaction(
1590 result_data.summary.clone(),
1591 result_data.first_kept_entry_id.clone(),
1592 result_data.tokens_before,
1593 Some(details_value.clone()),
1594 None,
1595 );
1596 inner_session.to_messages_for_current_path()
1597 };
1598 guard.persist_session().await?;
1599 guard.agent.replace_messages(messages);
1600
1601 Ok(json!({
1602 "summary": result_data.summary,
1603 "firstKeptEntryId": result_data.first_kept_entry_id,
1604 "tokensBefore": result_data.tokens_before,
1605 "details": details_value,
1606 }))
1607 }
1608 .await;
1609
1610 match result {
1611 Ok(data) => {
1612 let _ = out_tx.send(response_ok(id, "compact", Some(data)));
1613 }
1614 Err(err) => {
1615 let _ = out_tx.send(response_error_with_hints(id, "compact", &err));
1616 }
1617 }
1618 }
1619
1620 "new_session" => {
1621 if rpc_dispatch_session_before_switch(rpc_extension_manager.clone(), "new", None)
1622 .await
1623 {
1624 let _ = out_tx.send(response_ok(
1625 id,
1626 "new_session",
1627 Some(json!({ "cancelled": true })),
1628 ));
1629 continue;
1630 }
1631
1632 let parent = parsed
1633 .get("parentSession")
1634 .and_then(Value::as_str)
1635 .map(str::to_string);
1636 let (session_id, previous_session_file) = {
1637 let mut guard = session
1638 .lock(&cx)
1639 .await
1640 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
1641 let (session_dir, provider, model_id, thinking_level, previous_session_file) = {
1642 let inner_session = guard.session.lock(&cx).await.map_err(|err| {
1643 Error::session(format!("inner session lock failed: {err}"))
1644 })?;
1645 (
1646 inner_session.session_dir.clone(),
1647 inner_session.header.provider.clone(),
1648 inner_session.header.model_id.clone(),
1649 inner_session.header.thinking_level.clone(),
1650 inner_session.path.as_ref().map(|p| p.display().to_string()),
1651 )
1652 };
1653 let mut new_session = if guard.save_enabled() {
1654 crate::session::Session::create_with_dir(session_dir)
1655 } else {
1656 crate::session::Session::in_memory()
1657 };
1658 new_session.header.parent_session = parent;
1659 new_session.header.provider.clone_from(&provider);
1661 new_session.header.model_id.clone_from(&model_id);
1662 new_session
1663 .header
1664 .thinking_level
1665 .clone_from(&thinking_level);
1666
1667 let session_id = new_session.header.id.clone();
1668 {
1669 let mut inner_session = guard.session.lock(&cx).await.map_err(|err| {
1670 Error::session(format!("inner session lock failed: {err}"))
1671 })?;
1672 *inner_session = new_session;
1673 }
1674 guard.agent.clear_messages();
1675 guard.agent.stream_options_mut().session_id = Some(session_id.clone());
1676
1677 (session_id, previous_session_file)
1678 };
1679 {
1680 let mut state = shared_state
1681 .lock(&cx)
1682 .await
1683 .map_err(|err| Error::session(format!("state lock failed: {err}")))?;
1684 state.steering.clear();
1685 state.follow_up.clear();
1686 }
1687 rpc_dispatch_session_switch_event(
1688 rpc_extension_manager.clone(),
1689 json!({
1690 "reason": "new",
1691 "previousSessionFile": previous_session_file,
1692 "sessionId": session_id,
1693 }),
1694 )
1695 .await;
1696 let _ = out_tx.send(response_ok(
1697 id,
1698 "new_session",
1699 Some(json!({ "cancelled": false })),
1700 ));
1701 }
1702
1703 "switch_session" => {
1704 let Some(session_path) = parsed.get("sessionPath").and_then(Value::as_str) else {
1705 let _ = out_tx.send(response_error(
1706 id,
1707 "switch_session",
1708 "Missing sessionPath".to_string(),
1709 ));
1710 continue;
1711 };
1712
1713 if rpc_dispatch_session_before_switch(
1714 rpc_extension_manager.clone(),
1715 "resume",
1716 Some(session_path),
1717 )
1718 .await
1719 {
1720 let _ = out_tx.send(response_ok(
1721 id,
1722 "switch_session",
1723 Some(json!({ "cancelled": true })),
1724 ));
1725 continue;
1726 }
1727
1728 let session_path_buf = std::path::PathBuf::from(session_path);
1730 let sessions_dir = crate::config::Config::sessions_dir();
1731 let resolved_path = if session_path_buf.is_relative() {
1732 sessions_dir.join(&session_path_buf)
1733 } else {
1734 session_path_buf.clone()
1735 };
1736 if session_path_buf.is_relative() {
1737 let canonical_session = crate::extensions::safe_canonicalize(&resolved_path);
1738 let canonical_sessions_dir =
1739 crate::extensions::safe_canonicalize(&sessions_dir);
1740 if !canonical_session.starts_with(&canonical_sessions_dir) {
1741 let _ = out_tx.send(response_error(
1742 id,
1743 "switch_session",
1744 "Session path is outside the sessions directory".to_string(),
1745 ));
1746 continue;
1747 }
1748 }
1749
1750 let loaded =
1751 crate::session::Session::open(resolved_path.to_string_lossy().as_ref()).await;
1752 match loaded {
1753 Ok(new_session) => {
1754 let target_session_file = new_session.path.as_ref().map_or_else(
1755 || resolved_path.display().to_string(),
1756 |p| p.display().to_string(),
1757 );
1758 let messages = new_session.to_messages_for_current_path();
1759 let session_id = new_session.header.id.clone();
1760 let previous_session_file;
1761 let mut guard = session
1762 .lock(&cx)
1763 .await
1764 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
1765 {
1766 let mut inner_session =
1767 guard.session.lock(&cx).await.map_err(|err| {
1768 Error::session(format!("inner session lock failed: {err}"))
1769 })?;
1770 previous_session_file =
1771 inner_session.path.as_ref().map(|p| p.display().to_string());
1772 *inner_session = new_session;
1773 }
1774 guard.agent.replace_messages(messages);
1775 guard.agent.stream_options_mut().session_id = Some(session_id.clone());
1776 let mut state = shared_state
1777 .lock(&cx)
1778 .await
1779 .map_err(|err| Error::session(format!("state lock failed: {err}")))?;
1780 state.steering.clear();
1781 state.follow_up.clear();
1782 drop(state);
1783 drop(guard);
1784
1785 rpc_dispatch_session_switch_event(
1786 rpc_extension_manager.clone(),
1787 json!({
1788 "reason": "resume",
1789 "previousSessionFile": previous_session_file,
1790 "targetSessionFile": target_session_file,
1791 "sessionId": session_id,
1792 }),
1793 )
1794 .await;
1795
1796 let _ = out_tx.send(response_ok(
1797 id,
1798 "switch_session",
1799 Some(json!({ "cancelled": false })),
1800 ));
1801 }
1802 Err(err) => {
1803 let _ = out_tx.send(response_error_with_hints(id, "switch_session", &err));
1804 }
1805 }
1806 }
1807
1808 "fork" => {
1809 let Some(entry_id) = parsed.get("entryId").and_then(Value::as_str) else {
1810 let _ = out_tx.send(response_error(id, "fork", "Missing entryId".to_string()));
1811 continue;
1812 };
1813
1814 let result: Result<String> =
1815 async {
1816 let (fork_plan, parent_path, session_dir, save_enabled, header_snapshot) = {
1818 let guard = session.lock(&cx).await.map_err(|err| {
1819 Error::session(format!("session lock failed: {err}"))
1820 })?;
1821 let inner = guard.session.lock(&cx).await.map_err(|err| {
1822 Error::session(format!("inner session lock failed: {err}"))
1823 })?;
1824 let plan = inner.plan_fork_from_user_message(entry_id)?;
1825 let parent_path = inner.path.as_ref().map(|p| p.display().to_string());
1826 let session_dir = inner.session_dir.clone();
1827 let header = inner.header.clone();
1828 (plan, parent_path, session_dir, guard.save_enabled(), header)
1829 };
1831
1832 let selected_text = fork_plan.selected_text.clone();
1834
1835 let mut new_session = if save_enabled {
1836 crate::session::Session::create_with_dir(session_dir)
1837 } else {
1838 crate::session::Session::in_memory()
1839 };
1840 new_session.header.parent_session = parent_path;
1841 new_session
1842 .header
1843 .provider
1844 .clone_from(&header_snapshot.provider);
1845 new_session
1846 .header
1847 .model_id
1848 .clone_from(&header_snapshot.model_id);
1849 new_session
1850 .header
1851 .thinking_level
1852 .clone_from(&header_snapshot.thinking_level);
1853 new_session.init_from_fork_plan(fork_plan);
1854
1855 let messages = new_session.to_messages_for_current_path();
1856 let session_id = new_session.header.id.clone();
1857
1858 {
1860 let mut guard = session.lock(&cx).await.map_err(|err| {
1861 Error::session(format!("session lock failed: {err}"))
1862 })?;
1863 let mut inner = guard.session.lock(&cx).await.map_err(|err| {
1864 Error::session(format!("inner session lock failed: {err}"))
1865 })?;
1866 *inner = new_session;
1867 drop(inner);
1868 guard.agent.replace_messages(messages);
1869 guard.agent.stream_options_mut().session_id = Some(session_id);
1870 }
1871
1872 {
1873 let mut state = shared_state.lock(&cx).await.map_err(|err| {
1874 Error::session(format!("state lock failed: {err}"))
1875 })?;
1876 state.steering.clear();
1877 state.follow_up.clear();
1878 }
1879
1880 Ok(selected_text)
1881 }
1882 .await;
1883
1884 match result {
1885 Ok(selected_text) => {
1886 let _ = out_tx.send(response_ok(
1887 id,
1888 "fork",
1889 Some(json!({ "text": selected_text, "cancelled": false })),
1890 ));
1891 }
1892 Err(err) => {
1893 let _ = out_tx.send(response_error_with_hints(id, "fork", &err));
1894 }
1895 }
1896 }
1897
1898 "get_fork_messages" => {
1899 let path_entries = {
1901 let guard = session
1902 .lock(&cx)
1903 .await
1904 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
1905 let inner_session = guard.session.lock(&cx).await.map_err(|err| {
1906 Error::session(format!("inner session lock failed: {err}"))
1907 })?;
1908 inner_session
1909 .entries_for_current_path()
1910 .into_iter()
1911 .cloned()
1912 .collect::<Vec<_>>()
1913 };
1914 let messages = fork_messages_from_entries(&path_entries);
1915 let _ = out_tx.send(response_ok(
1916 id,
1917 "get_fork_messages",
1918 Some(json!({ "messages": messages })),
1919 ));
1920 }
1921
1922 "get_commands" => {
1923 let commands = options.resources.list_commands();
1924 let _ = out_tx.send(response_ok(
1925 id,
1926 "get_commands",
1927 Some(json!({ "commands": commands })),
1928 ));
1929 }
1930
1931 "extension_ui_response" => {
1932 if let (Some(manager), Some(ui_state)) =
1933 (rpc_extension_manager.as_ref(), rpc_ui_state.as_ref())
1934 {
1935 let Some(request_id) = rpc_parse_extension_ui_response_id(&parsed) else {
1936 let _ = out_tx.send(response_error(
1937 id,
1938 "extension_ui_response",
1939 "Missing requestId (or id) field",
1940 ));
1941 continue;
1942 };
1943
1944 let (response, next_request) = {
1945 let Ok(mut guard) = ui_state.lock(&cx).await else {
1946 let _ = out_tx.send(response_error(
1947 id,
1948 "extension_ui_response",
1949 "Extension UI bridge unavailable",
1950 ));
1951 continue;
1952 };
1953
1954 let Some(active) = guard.active.clone() else {
1955 let _ = out_tx.send(response_error(
1956 id,
1957 "extension_ui_response",
1958 "No active extension UI request",
1959 ));
1960 continue;
1961 };
1962
1963 if active.id != request_id {
1964 let _ = out_tx.send(response_error(
1965 id,
1966 "extension_ui_response",
1967 format!(
1968 "Unexpected requestId: {request_id} (active: {})",
1969 active.id
1970 ),
1971 ));
1972 continue;
1973 }
1974
1975 let response = match rpc_parse_extension_ui_response(&parsed, &active) {
1976 Ok(response) => response,
1977 Err(message) => {
1978 let _ = out_tx.send(response_error(
1979 id,
1980 "extension_ui_response",
1981 message,
1982 ));
1983 continue;
1984 }
1985 };
1986
1987 guard.active = None;
1988 let next = guard.queue.pop_front();
1989 if let Some(ref next) = next {
1990 guard.active = Some(next.clone());
1991 }
1992 (response, next)
1993 };
1994
1995 let resolved = manager.respond_ui(response);
1996 let _ = out_tx.send(response_ok(
1997 id,
1998 "extension_ui_response",
1999 Some(json!({ "resolved": resolved })),
2000 ));
2001
2002 if let Some(next) = next_request {
2003 rpc_emit_extension_ui_request(
2004 &options.runtime_handle,
2005 Arc::clone(ui_state),
2006 (*manager).clone(),
2007 out_tx.clone(),
2008 next,
2009 );
2010 }
2011 } else {
2012 let _ = out_tx.send(response_ok(id, "extension_ui_response", None));
2013 }
2014 }
2015
2016 _ => {
2017 let _ = out_tx.send(response_error(
2018 id,
2019 command_type_raw,
2020 format!("Unknown command: {command_type_raw}"),
2021 ));
2022 }
2023 }
2024 }
2025
2026 let extension_region = session
2030 .lock(&cx)
2031 .await
2032 .ok()
2033 .and_then(|mut guard| guard.extensions.take());
2034 if let Some(ext) = extension_region {
2035 ext.shutdown().await;
2036 }
2037
2038 Ok(())
2039}
2040
2041#[allow(clippy::too_many_lines)]
2046async fn run_prompt_with_retry(
2047 session: Arc<Mutex<AgentSession>>,
2048 shared_state: Arc<Mutex<RpcSharedState>>,
2049 is_streaming: Arc<AtomicBool>,
2050 is_compacting: Arc<AtomicBool>,
2051 abort_handle_slot: Arc<Mutex<Option<AbortHandle>>>,
2052 out_tx: std::sync::mpsc::SyncSender<String>,
2053 retry_abort: Arc<AtomicBool>,
2054 options: RpcOptions,
2055 message: String,
2056 images: Vec<ImageContent>,
2057 cx: AgentCx,
2058) {
2059 retry_abort.store(false, Ordering::SeqCst);
2060 is_streaming.store(true, Ordering::SeqCst);
2061
2062 let max_retries = options.config.retry_max_retries();
2063 let mut retry_count: u32 = 0;
2064 let mut success = false;
2065 let mut final_error: Option<String> = None;
2066 let mut final_error_hints: Option<Value> = None;
2067
2068 loop {
2069 if retry_count > 0 && cx.checkpoint().is_err() {
2070 final_error = Some("Retry aborted".to_string());
2071 final_error_hints = None;
2072 break;
2073 }
2074
2075 let (abort_handle, abort_signal) = AbortHandle::new();
2076 if let Ok(mut guard) = OwnedMutexGuard::lock(Arc::clone(&abort_handle_slot), &cx).await {
2077 *guard = Some(abort_handle);
2078 } else {
2079 is_streaming.store(false, Ordering::SeqCst);
2080 return;
2081 }
2082
2083 let runtime_for_events = options.runtime_handle.clone();
2084
2085 let result = {
2086 let mut guard = match OwnedMutexGuard::lock(Arc::clone(&session), &cx).await {
2087 Ok(guard) => guard,
2088 Err(err) => {
2089 final_error = Some(format!("session lock failed: {err}"));
2090 final_error_hints = None;
2091 break;
2092 }
2093 };
2094 let extensions = guard.extensions.as_ref().map(|r| r.manager().clone());
2095 let event_handler =
2096 rpc_agent_event_handler(out_tx.clone(), runtime_for_events, extensions);
2097
2098 if images.is_empty() {
2099 guard
2100 .run_text_with_abort(message.clone(), Some(abort_signal), event_handler)
2101 .await
2102 } else {
2103 let blocks = build_prompt_content_blocks(&message, &images);
2104 guard
2105 .run_with_content_with_abort(blocks, Some(abort_signal), event_handler)
2106 .await
2107 }
2108 };
2109
2110 if let Ok(mut guard) = OwnedMutexGuard::lock(Arc::clone(&abort_handle_slot), &cx).await {
2111 *guard = None;
2112 }
2113
2114 match result {
2115 Ok(message) => {
2116 if matches!(message.stop_reason, StopReason::Error | StopReason::Aborted) {
2117 final_error = message
2118 .error_message
2119 .clone()
2120 .or_else(|| Some("Request error".to_string()));
2121 final_error_hints = None;
2122 if message.stop_reason == StopReason::Aborted {
2123 break;
2124 }
2125 if let Some(ref err_msg) = final_error {
2128 let context_window = if let Ok(guard) =
2129 OwnedMutexGuard::lock(Arc::clone(&session), &cx).await
2130 {
2131 let runtime_provider = guard.agent.provider().name().to_string();
2132 let runtime_model_id = guard.agent.provider().model_id().to_string();
2133 guard.session.lock(&cx).await.map_or(None, |inner| {
2134 current_or_runtime_model_entry(
2135 &inner,
2136 &runtime_provider,
2137 &runtime_model_id,
2138 &options,
2139 )
2140 .map(|e| e.model.context_window)
2141 })
2142 } else {
2143 None
2144 };
2145 if !crate::error::is_retryable_error(
2146 err_msg,
2147 Some(message.usage.input),
2148 context_window,
2149 ) {
2150 break;
2151 }
2152 }
2153 } else {
2154 success = true;
2155 break;
2156 }
2157 }
2158 Err(err) => {
2159 let err_str = err.to_string();
2160 if !crate::error::is_retryable_error(&err_str, None, None) {
2163 final_error = Some(err_str);
2164 final_error_hints = Some(error_hints_value(&err));
2165 break;
2166 }
2167 final_error = Some(err_str);
2168 final_error_hints = Some(error_hints_value(&err));
2169 }
2170 }
2171
2172 let retry_enabled = OwnedMutexGuard::lock(Arc::clone(&shared_state), &cx)
2173 .await
2174 .is_ok_and(|state| state.auto_retry_enabled);
2175 if !retry_enabled || retry_count >= max_retries {
2176 break;
2177 }
2178
2179 retry_count += 1;
2180 let delay_ms = retry_delay_ms(&options.config, retry_count);
2181 let error_message = final_error
2182 .clone()
2183 .unwrap_or_else(|| "Request error".to_string());
2184 let _ = out_tx.send(event(&json!({
2185 "type": "auto_retry_start",
2186 "attempt": retry_count,
2187 "maxAttempts": max_retries,
2188 "delayMs": delay_ms,
2189 "errorMessage": error_message,
2190 })));
2191
2192 let delay = Duration::from_millis(delay_ms as u64);
2193 let start = std::time::Instant::now();
2194 let mut retry_cancelled = false;
2195 while start.elapsed() < delay {
2196 if retry_abort.load(Ordering::SeqCst) {
2197 retry_cancelled = true;
2198 break;
2199 }
2200 if cx.checkpoint().is_err() {
2201 retry_cancelled = true;
2202 break;
2203 }
2204 let now = cx
2205 .cx()
2206 .timer_driver()
2207 .map_or_else(wall_now, |timer| timer.now());
2208 sleep(now, Duration::from_millis(50)).await;
2209 }
2210
2211 if retry_cancelled || retry_abort.load(Ordering::SeqCst) {
2212 final_error = Some("Retry aborted".to_string());
2213 break;
2214 }
2215
2216 if let Ok(mut guard) = OwnedMutexGuard::lock(Arc::clone(&session), &cx).await {
2218 let _ = guard.revert_last_user_message().await;
2219 }
2220 }
2221
2222 if retry_count > 0 {
2223 let _ = out_tx.send(event(&json!({
2224 "type": "auto_retry_end",
2225 "success": success,
2226 "attempt": retry_count,
2227 "finalError": if success { Value::Null } else { json!(final_error.clone()) },
2228 })));
2229 }
2230
2231 is_streaming.store(false, Ordering::SeqCst);
2232
2233 if !success {
2234 if let Some(err) = final_error {
2235 let mut payload = json!({
2236 "type": "agent_end",
2237 "messages": [],
2238 "error": err
2239 });
2240 if let Some(hints) = final_error_hints {
2241 payload["errorHints"] = hints;
2242 }
2243 let _ = out_tx.send(event(&payload));
2244 }
2245 return;
2246 }
2247
2248 let auto_compaction_enabled = OwnedMutexGuard::lock(Arc::clone(&shared_state), &cx)
2249 .await
2250 .is_ok_and(|state| state.auto_compaction_enabled);
2251 if auto_compaction_enabled {
2252 maybe_auto_compact(session, options, is_compacting, out_tx).await;
2253 }
2254}
2255
2256async fn run_extension_command(
2257 session: Arc<Mutex<AgentSession>>,
2258 is_streaming: Arc<AtomicBool>,
2259 abort_handle_slot: Arc<Mutex<Option<AbortHandle>>>,
2260 out_tx: std::sync::mpsc::SyncSender<String>,
2261 runtime_handle: RuntimeHandle,
2262 command_name: String,
2263 args: String,
2264 cx: AgentCx,
2265) {
2266 is_streaming.store(true, Ordering::SeqCst);
2267
2268 let (abort_handle, abort_signal) = AbortHandle::new();
2269 if let Ok(mut guard) = OwnedMutexGuard::lock(Arc::clone(&abort_handle_slot), &cx).await {
2270 *guard = Some(abort_handle);
2271 } else {
2272 is_streaming.store(false, Ordering::SeqCst);
2273 return;
2274 }
2275
2276 let result = {
2277 let mut guard = match OwnedMutexGuard::lock(Arc::clone(&session), &cx).await {
2278 Ok(guard) => guard,
2279 Err(err) => {
2280 let err = Error::session(format!("session lock failed: {err}"));
2281 let mut payload = json!({
2282 "type": "agent_end",
2283 "messages": [],
2284 "error": err.to_string(),
2285 });
2286 payload["errorHints"] = error_hints_value(&err);
2287 let _ = out_tx.send(event(&payload));
2288 is_streaming.store(false, Ordering::SeqCst);
2289 return;
2290 }
2291 };
2292 let extensions = guard
2293 .extensions
2294 .as_ref()
2295 .map(|region| region.manager().clone());
2296 let event_handler = rpc_agent_event_handler(out_tx.clone(), runtime_handle, extensions);
2297 guard
2298 .execute_extension_command_with_abort(
2299 &command_name,
2300 &args,
2301 EXTENSION_EVENT_TIMEOUT_MS,
2302 Some(abort_signal),
2303 event_handler,
2304 )
2305 .await
2306 };
2307
2308 if let Ok(mut guard) = OwnedMutexGuard::lock(Arc::clone(&abort_handle_slot), &cx).await {
2309 *guard = None;
2310 }
2311 is_streaming.store(false, Ordering::SeqCst);
2312
2313 if let Err(err) = result {
2314 let mut payload = json!({
2315 "type": "agent_end",
2316 "messages": [],
2317 "error": err.to_string(),
2318 });
2319 payload["errorHints"] = error_hints_value(&err);
2320 let _ = out_tx.send(event(&payload));
2321 }
2322}
2323
2324fn response_ok(id: Option<String>, command: &str, data: Option<Value>) -> String {
2329 let mut resp = json!({
2330 "type": "response",
2331 "command": command,
2332 "success": true,
2333 });
2334 if let Some(id) = id {
2335 resp["id"] = Value::String(id);
2336 }
2337 if let Some(data) = data {
2338 resp["data"] = data;
2339 }
2340 resp.to_string()
2341}
2342
2343fn response_error(id: Option<String>, command: &str, error: impl Into<String>) -> String {
2344 let mut resp = json!({
2345 "type": "response",
2346 "command": command,
2347 "success": false,
2348 "error": error.into(),
2349 });
2350 if let Some(id) = id {
2351 resp["id"] = Value::String(id);
2352 }
2353 resp.to_string()
2354}
2355
2356fn response_error_with_hints(id: Option<String>, command: &str, error: &Error) -> String {
2357 let mut resp = json!({
2358 "type": "response",
2359 "command": command,
2360 "success": false,
2361 "error": error.to_string(),
2362 "errorHints": error_hints_value(error),
2363 });
2364 if let Some(id) = id {
2365 resp["id"] = Value::String(id);
2366 }
2367 resp.to_string()
2368}
2369
2370fn event(value: &Value) -> String {
2371 value.to_string()
2372}
2373
2374fn rpc_emit_extension_ui_request(
2375 runtime_handle: &RuntimeHandle,
2376 ui_state: Arc<Mutex<RpcUiBridgeState>>,
2377 manager: ExtensionManager,
2378 out_tx_ui: std::sync::mpsc::SyncSender<String>,
2379 request: ExtensionUiRequest,
2380) {
2381 let rpc_event = request.to_rpc_event();
2383 let _ = out_tx_ui.send(event(&rpc_event));
2384
2385 if !request.expects_response() {
2386 return;
2387 }
2388
2389 let Some(timeout_ms) = request.effective_timeout_ms() else {
2392 return;
2393 };
2394
2395 let fire_ms = timeout_ms.saturating_sub(10).max(1);
2397 let request_id = request.id;
2398 let ui_state_timeout = Arc::clone(&ui_state);
2399 let manager_timeout = manager;
2400 let out_tx_timeout = out_tx_ui;
2401 let runtime_handle_inner = runtime_handle.clone();
2402
2403 runtime_handle.spawn(async move {
2404 sleep(wall_now(), Duration::from_millis(fire_ms)).await;
2405 let cx = AgentCx::for_request();
2406
2407 let next = {
2408 let Ok(mut guard) = ui_state_timeout.lock(cx.cx()).await else {
2409 return;
2410 };
2411
2412 let Some(active) = guard.active.as_ref() else {
2413 return;
2414 };
2415
2416 if active.id != request_id {
2418 return;
2419 }
2420
2421 let _ = manager_timeout.respond_ui(ExtensionUiResponse {
2423 id: request_id,
2424 value: None,
2425 cancelled: true,
2426 });
2427
2428 guard.active = None;
2429 let next = guard.queue.pop_front();
2430 if let Some(ref next) = next {
2431 guard.active = Some(next.clone());
2432 }
2433 next
2434 };
2435
2436 if let Some(next) = next {
2437 rpc_emit_extension_ui_request(
2438 &runtime_handle_inner,
2439 ui_state_timeout,
2440 manager_timeout,
2441 out_tx_timeout,
2442 next,
2443 );
2444 }
2445 });
2446}
2447
2448fn rpc_parse_extension_ui_response_id(parsed: &Value) -> Option<String> {
2449 let request_id = parsed
2450 .get("requestId")
2451 .and_then(Value::as_str)
2452 .map(str::trim)
2453 .filter(|value| !value.is_empty())
2454 .map(String::from);
2455
2456 request_id.or_else(|| {
2457 parsed
2458 .get("id")
2459 .and_then(Value::as_str)
2460 .map(str::trim)
2461 .filter(|value| !value.is_empty())
2462 .map(String::from)
2463 })
2464}
2465
2466fn rpc_parse_extension_ui_response(
2467 parsed: &Value,
2468 active: &ExtensionUiRequest,
2469) -> std::result::Result<ExtensionUiResponse, String> {
2470 let cancelled = parsed
2471 .get("cancelled")
2472 .and_then(Value::as_bool)
2473 .unwrap_or(false);
2474
2475 if cancelled && active.method != "custom" {
2476 return Ok(ExtensionUiResponse {
2477 id: active.id.clone(),
2478 value: None,
2479 cancelled: true,
2480 });
2481 }
2482
2483 match active.method.as_str() {
2484 "confirm" => {
2485 let value = parsed
2486 .get("confirmed")
2487 .and_then(Value::as_bool)
2488 .or_else(|| parsed.get("value").and_then(Value::as_bool))
2489 .ok_or_else(|| "confirm requires boolean `confirmed` (or `value`)".to_string())?;
2490 Ok(ExtensionUiResponse {
2491 id: active.id.clone(),
2492 value: Some(Value::Bool(value)),
2493 cancelled: false,
2494 })
2495 }
2496 "select" => {
2497 let Some(value) = parsed.get("value") else {
2498 return Err("select requires `value` field".to_string());
2499 };
2500
2501 let options = active
2502 .payload
2503 .get("options")
2504 .and_then(Value::as_array)
2505 .ok_or_else(|| "select request missing `options` array".to_string())?;
2506
2507 let mut allowed = Vec::with_capacity(options.len());
2508 for opt in options {
2509 match opt {
2510 Value::String(s) => allowed.push(Value::String(s.clone())),
2511 Value::Object(map) => {
2512 let label = map
2513 .get("label")
2514 .and_then(Value::as_str)
2515 .unwrap_or("")
2516 .trim();
2517 if label.is_empty() {
2518 continue;
2519 }
2520 if let Some(v) = map.get("value") {
2521 allowed.push(v.clone());
2522 } else {
2523 allowed.push(Value::String(label.to_string()));
2524 }
2525 }
2526 _ => {}
2527 }
2528 }
2529
2530 if !allowed.iter().any(|candidate| candidate == value) {
2531 return Err("select response value did not match any option".to_string());
2532 }
2533
2534 Ok(ExtensionUiResponse {
2535 id: active.id.clone(),
2536 value: Some(value.clone()),
2537 cancelled: false,
2538 })
2539 }
2540 "input" | "editor" => {
2541 let Some(value) = parsed.get("value") else {
2542 return Err(format!("{} requires `value` field", active.method));
2543 };
2544 if !value.is_string() {
2545 return Err(format!("{} requires string `value`", active.method));
2546 }
2547 Ok(ExtensionUiResponse {
2548 id: active.id.clone(),
2549 value: Some(value.clone()),
2550 cancelled: false,
2551 })
2552 }
2553 "custom" => {
2554 if let Some(value) = parsed.get("value").filter(|value| !value.is_null()) {
2555 return Ok(ExtensionUiResponse {
2556 id: active.id.clone(),
2557 value: Some(value.clone()),
2558 cancelled: false,
2559 });
2560 }
2561
2562 let mut payload = serde_json::Map::new();
2563 if let Some(key) = parsed.get("key").and_then(Value::as_str) {
2564 payload.insert("key".to_string(), Value::String(key.to_string()));
2565 }
2566 if let Some(width) = parsed.get("width").and_then(Value::as_u64) {
2567 payload.insert("width".to_string(), Value::from(width));
2568 }
2569 if let Some(close) = parsed
2570 .get("cancelled")
2571 .or_else(|| parsed.get("closed"))
2572 .and_then(Value::as_bool)
2573 {
2574 payload.insert("closed".to_string(), Value::Bool(close));
2575 }
2576 if payload.is_empty() {
2577 return Err("custom requires `value`, `key`, `width`, or `cancelled`".to_string());
2578 }
2579 Ok(ExtensionUiResponse {
2580 id: active.id.clone(),
2581 value: Some(Value::Object(payload)),
2582 cancelled: false,
2583 })
2584 }
2585 "notify" => Ok(ExtensionUiResponse {
2586 id: active.id.clone(),
2587 value: None,
2588 cancelled: false,
2589 }),
2590 other => Err(format!("Unsupported extension UI method: {other}")),
2591 }
2592}
2593
#[cfg(test)]
mod ui_bridge_tests {
    // Unit tests for the RPC extension-UI bridge helpers:
    // `rpc_parse_extension_ui_response_id`, `rpc_parse_extension_ui_response`,
    // and `RpcUiBridgeState::default`.
    use super::*;

    // `requestId` takes precedence over the legacy `id` alias.
    #[test]
    fn parse_extension_ui_response_id_prefers_request_id() {
        let value = json!({"type":"extension_ui_response","id":"legacy","requestId":"canonical"});
        assert_eq!(
            rpc_parse_extension_ui_response_id(&value),
            Some("canonical".to_string())
        );
    }

    // With no `requestId`, the `id` field is used.
    #[test]
    fn parse_extension_ui_response_id_accepts_id_alias() {
        let value = json!({"type":"extension_ui_response","id":"legacy"});
        assert_eq!(
            rpc_parse_extension_ui_response_id(&value),
            Some("legacy".to_string())
        );
    }

    // `confirm` reads the boolean `confirmed` field.
    #[test]
    fn parse_confirm_response_accepts_confirmed_alias() {
        let active = ExtensionUiRequest::new("req-1", "confirm", json!({"title":"t"}));
        let value = json!({"type":"extension_ui_response","requestId":"req-1","confirmed":true});
        let resp = rpc_parse_extension_ui_response(&value, &active).expect("parse confirm");
        assert!(!resp.cancelled);
        assert_eq!(resp.value, Some(json!(true)));
    }

    // `confirm` falls back to a boolean `value` field.
    #[test]
    fn parse_confirm_response_accepts_value_bool() {
        let active = ExtensionUiRequest::new("req-1", "confirm", json!({"title":"t"}));
        let value = json!({"type":"extension_ui_response","requestId":"req-1","value":false});
        let resp = rpc_parse_extension_ui_response(&value, &active).expect("parse confirm");
        assert!(!resp.cancelled);
        assert_eq!(resp.value, Some(json!(false)));
    }

    // For non-custom methods, `cancelled: true` discards any supplied value.
    #[test]
    fn parse_cancelled_response_wins_over_value() {
        let active = ExtensionUiRequest::new("req-1", "confirm", json!({"title":"t"}));
        let value = json!({"type":"extension_ui_response","requestId":"req-1","cancelled":true,"value":true});
        let resp = rpc_parse_extension_ui_response(&value, &active).expect("parse cancel");
        assert!(resp.cancelled);
        assert_eq!(resp.value, None);
    }

    // `select` accepts only values listed in the request's `options`.
    #[test]
    fn parse_select_response_validates_against_options() {
        let active = ExtensionUiRequest::new(
            "req-1",
            "select",
            json!({"title":"pick","options":["A","B"]}),
        );
        let ok_value = json!({"type":"extension_ui_response","requestId":"req-1","value":"B"});
        let ok = rpc_parse_extension_ui_response(&ok_value, &active).expect("parse select ok");
        assert_eq!(ok.value, Some(json!("B")));

        let bad_value = json!({"type":"extension_ui_response","requestId":"req-1","value":"C"});
        assert!(
            rpc_parse_extension_ui_response(&bad_value, &active).is_err(),
            "invalid selection should error"
        );
    }

    // `input` rejects non-string values.
    #[test]
    fn parse_input_requires_string_value() {
        let active = ExtensionUiRequest::new("req-1", "input", json!({"title":"t"}));
        let ok_value = json!({"type":"extension_ui_response","requestId":"req-1","value":"hi"});
        let ok = rpc_parse_extension_ui_response(&ok_value, &active).expect("parse input ok");
        assert_eq!(ok.value, Some(json!("hi")));

        let bad_value = json!({"type":"extension_ui_response","requestId":"req-1","value":123});
        assert!(
            rpc_parse_extension_ui_response(&bad_value, &active).is_err(),
            "non-string input should error"
        );
    }

    // `editor` shares the string-only rule and keeps embedded newlines.
    #[test]
    fn parse_editor_requires_string_value() {
        let active = ExtensionUiRequest::new("req-1", "editor", json!({"title":"t"}));
        let ok = json!({"requestId":"req-1","value":"multi\nline"});
        let resp = rpc_parse_extension_ui_response(&ok, &active).expect("editor ok");
        assert_eq!(resp.value, Some(json!("multi\nline")));

        let bad = json!({"requestId":"req-1","value":42});
        assert!(
            rpc_parse_extension_ui_response(&bad, &active).is_err(),
            "editor needs string"
        );
    }

    // `notify` is acknowledged with no value and without cancellation.
    #[test]
    fn parse_notify_returns_no_value() {
        let active = ExtensionUiRequest::new("req-1", "notify", json!({"title":"t"}));
        let val = json!({"requestId":"req-1"});
        let resp = rpc_parse_extension_ui_response(&val, &active).expect("notify ok");
        assert!(!resp.cancelled);
        assert!(resp.value.is_none());
    }

    // `custom` forwards a non-null `value` unchanged.
    #[test]
    fn parse_custom_accepts_value_passthrough() {
        let active = ExtensionUiRequest::new("req-1", "custom", json!({}));
        let val = json!({"requestId":"req-1","value":{"key":"w","width":88}});
        let resp = rpc_parse_extension_ui_response(&val, &active).expect("custom value");
        assert_eq!(resp.value, Some(json!({"key":"w","width":88})));
        assert!(!resp.cancelled);
    }

    // Without `value`, `custom` assembles a payload from top-level `key`/`width`.
    #[test]
    fn parse_custom_accepts_key_width_fields() {
        let active = ExtensionUiRequest::new("req-1", "custom", json!({}));
        let val = json!({"requestId":"req-1","key":"q","width":120});
        let resp = rpc_parse_extension_ui_response(&val, &active).expect("custom key+width");
        assert_eq!(resp.value, Some(json!({"key":"q","width":120})));
        assert!(!resp.cancelled);
    }

    // For `custom`, `cancelled: true` becomes `closed: true` in the payload
    // rather than a cancelled response.
    #[test]
    fn parse_custom_preserves_cancelled_and_width_as_payload() {
        let active = ExtensionUiRequest::new("req-1", "custom", json!({}));
        let val = json!({"requestId":"req-1","width":120,"cancelled":true});
        let resp = rpc_parse_extension_ui_response(&val, &active).expect("custom cancelled+width");
        assert_eq!(resp.value, Some(json!({"width":120,"closed":true})));
        assert!(!resp.cancelled);
    }

    // A JSON null `value` does not suppress the close payload.
    #[test]
    fn parse_custom_treats_null_value_as_absent_for_close_payloads() {
        let active = ExtensionUiRequest::new("req-1", "custom", json!({}));
        let val = json!({"requestId":"req-1","value":null,"cancelled":true});
        let resp = rpc_parse_extension_ui_response(&val, &active).expect("custom null+cancelled");
        assert_eq!(resp.value, Some(json!({"closed":true})));
        assert!(!resp.cancelled);
    }

    // Methods outside the supported set produce an "Unsupported" error.
    #[test]
    fn parse_unsupported_method_errors() {
        let active = ExtensionUiRequest::new("req-1", "custom_method", json!({}));
        let val = json!({"requestId":"req-1","value":"x"});
        let err = rpc_parse_extension_ui_response(&val, &active).unwrap_err();
        assert!(err.contains("Unsupported"), "err={err}");
    }

    // `select` without a `value` field is rejected with a message naming it.
    #[test]
    fn parse_select_missing_value_field() {
        let active =
            ExtensionUiRequest::new("req-1", "select", json!({"title":"pick","options":["A"]}));
        let val = json!({"requestId":"req-1"});
        let err = rpc_parse_extension_ui_response(&val, &active).unwrap_err();
        assert!(err.contains("value"), "err={err}");
    }

    // `confirm` with neither `confirmed` nor `value` is rejected.
    #[test]
    fn parse_confirm_missing_value_errors() {
        let active = ExtensionUiRequest::new("req-1", "confirm", json!({"title":"t"}));
        let val = json!({"requestId":"req-1"});
        let err = rpc_parse_extension_ui_response(&val, &active).unwrap_err();
        assert!(err.contains("confirm"), "err={err}");
    }

    // Object options match on their `value` field.
    #[test]
    fn parse_select_with_label_value_objects() {
        let active = ExtensionUiRequest::new(
            "req-1",
            "select",
            json!({
                "title": "pick",
                "options": [
                    {"label": "Alpha", "value": "a"},
                    {"label": "Beta", "value": "b"},
                ]
            }),
        );
        let val = json!({"requestId":"req-1","value":"a"});
        let resp = rpc_parse_extension_ui_response(&val, &active).expect("select by value");
        assert_eq!(resp.value, Some(json!("a")));
    }

    // Blank or whitespace-only ids in both fields yield no id.
    #[test]
    fn parse_id_rejects_empty_and_whitespace() {
        let val = json!({"requestId":" ","id":""});
        assert!(rpc_parse_extension_ui_response_id(&val).is_none());
    }

    // A default bridge state has no active request and an empty queue.
    #[test]
    fn bridge_state_default_is_empty() {
        let state = RpcUiBridgeState::default();
        assert!(state.active.is_none());
        assert!(state.queue.is_empty());
    }
}
2790
2791fn error_hints_value(error: &Error) -> Value {
2792 let hint = error_hints::hints_for_error(error);
2793 json!({
2794 "summary": hint.summary,
2795 "hints": hint.hints,
2796 "contextFields": hint.context_fields,
2797 })
2798}
2799
2800fn rpc_session_message_value(message: SessionMessage) -> Value {
2801 let mut value = match serde_json::to_value(message) {
2802 Ok(v) => v,
2803 Err(err) => {
2804 tracing::error!("Failed to serialize SessionMessage: {err}");
2805 return serde_json::json!({"error": format!("serialization error: {err}")});
2806 }
2807 };
2808 rpc_flatten_content_blocks(&mut value);
2809 value
2810}
2811
2812fn rpc_flatten_content_blocks(value: &mut Value) {
2813 let Value::Object(message_obj) = value else {
2814 return;
2815 };
2816 let Some(content) = message_obj.get_mut("content") else {
2817 return;
2818 };
2819 let Value::Array(blocks) = content else {
2820 return;
2821 };
2822
2823 for block in blocks {
2824 let Value::Object(block_obj) = block else {
2825 continue;
2826 };
2827 let Some(inner) = block_obj.remove("0") else {
2828 continue;
2829 };
2830 let Value::Object(inner_obj) = inner else {
2831 block_obj.insert("0".to_string(), inner);
2832 continue;
2833 };
2834 for (key, value) in inner_obj {
2835 block_obj.entry(key).or_insert(value);
2836 }
2837 }
2838}
2839
2840fn retry_delay_ms(config: &Config, attempt: u32) -> u32 {
2841 let base = u64::from(config.retry_base_delay_ms());
2842 let max = u64::from(config.retry_max_delay_ms());
2843 let shift = attempt.saturating_sub(1);
2844 let multiplier = 1u64.checked_shl(shift).unwrap_or(u64::MAX);
2845 let delay = base.saturating_mul(multiplier).min(max);
2846 u32::try_from(delay).unwrap_or(u32::MAX)
2847}
2848
2849#[cfg(test)]
2850mod retry_tests {
2851 use super::*;
2852 use crate::agent::{Agent, AgentConfig, AgentSession};
2853 use crate::model::{AssistantMessage, Usage};
2854 use crate::provider::Provider;
2855 use crate::resources::ResourceLoader;
2856 use crate::session::Session;
2857 use crate::tools::ToolRegistry;
2858 use async_trait::async_trait;
2859 use futures::stream;
2860 use std::path::Path;
2861 use std::pin::Pin;
2862 use std::sync::atomic::{AtomicUsize, Ordering};
2863
2864 #[derive(Debug)]
2865 struct FlakyProvider {
2866 calls: AtomicUsize,
2867 }
2868
2869 impl FlakyProvider {
2870 const fn new() -> Self {
2871 Self {
2872 calls: AtomicUsize::new(0),
2873 }
2874 }
2875 }
2876
2877 #[async_trait]
2878 #[allow(clippy::unnecessary_literal_bound)]
2879 impl Provider for FlakyProvider {
2880 fn name(&self) -> &str {
2881 "test-provider"
2882 }
2883
2884 fn api(&self) -> &str {
2885 "test-api"
2886 }
2887
2888 fn model_id(&self) -> &str {
2889 "test-model"
2890 }
2891
2892 async fn stream(
2893 &self,
2894 _context: &crate::provider::Context<'_>,
2895 _options: &crate::provider::StreamOptions,
2896 ) -> crate::error::Result<
2897 Pin<
2898 Box<
2899 dyn futures::Stream<Item = crate::error::Result<crate::model::StreamEvent>>
2900 + Send,
2901 >,
2902 >,
2903 > {
2904 let call = self.calls.fetch_add(1, Ordering::SeqCst);
2905
2906 let mut partial = AssistantMessage {
2907 content: Vec::new(),
2908 api: self.api().to_string(),
2909 provider: self.name().to_string(),
2910 model: self.model_id().to_string(),
2911 usage: Usage::default(),
2912 stop_reason: StopReason::Stop,
2913 error_message: None,
2914 timestamp: 0,
2915 };
2916
2917 let events = if call == 0 {
2918 partial.stop_reason = StopReason::Error;
2920 partial.error_message = Some("server error".to_string());
2921 vec![
2922 Ok(crate::model::StreamEvent::Start {
2923 partial: partial.clone(),
2924 }),
2925 Ok(crate::model::StreamEvent::Error {
2926 reason: StopReason::Error,
2927 error: partial,
2928 }),
2929 ]
2930 } else {
2931 vec![
2933 Ok(crate::model::StreamEvent::Start {
2934 partial: partial.clone(),
2935 }),
2936 Ok(crate::model::StreamEvent::Done {
2937 reason: StopReason::Stop,
2938 message: partial,
2939 }),
2940 ]
2941 };
2942
2943 Ok(Box::pin(stream::iter(events)))
2944 }
2945 }
2946
2947 #[derive(Debug)]
2948 struct AlwaysErrorProvider;
2949
2950 #[async_trait]
2951 #[allow(clippy::unnecessary_literal_bound)]
2952 impl Provider for AlwaysErrorProvider {
2953 fn name(&self) -> &str {
2954 "test-provider"
2955 }
2956
2957 fn api(&self) -> &str {
2958 "test-api"
2959 }
2960
2961 fn model_id(&self) -> &str {
2962 "test-model"
2963 }
2964
2965 async fn stream(
2966 &self,
2967 _context: &crate::provider::Context<'_>,
2968 _options: &crate::provider::StreamOptions,
2969 ) -> crate::error::Result<
2970 Pin<
2971 Box<
2972 dyn futures::Stream<Item = crate::error::Result<crate::model::StreamEvent>>
2973 + Send,
2974 >,
2975 >,
2976 > {
2977 let mut partial = AssistantMessage {
2978 content: Vec::new(),
2979 api: self.api().to_string(),
2980 provider: self.name().to_string(),
2981 model: self.model_id().to_string(),
2982 usage: Usage::default(),
2983 stop_reason: StopReason::Error,
2984 error_message: Some("server error".to_string()),
2985 timestamp: 0,
2986 };
2987
2988 let events = vec![
2989 Ok(crate::model::StreamEvent::Start {
2990 partial: partial.clone(),
2991 }),
2992 Ok(crate::model::StreamEvent::Error {
2993 reason: StopReason::Error,
2994 error: {
2995 partial.stop_reason = StopReason::Error;
2996 partial
2997 },
2998 }),
2999 ];
3000
3001 Ok(Box::pin(stream::iter(events)))
3002 }
3003 }
3004
3005 #[test]
3006 fn rpc_auto_retry_retries_then_succeeds() {
3007 let runtime = asupersync::runtime::RuntimeBuilder::new()
3008 .blocking_threads(1, 8)
3009 .build()
3010 .expect("runtime build");
3011 let runtime_handle = runtime.handle();
3012
3013 runtime.block_on(async move {
3014 let provider = Arc::new(FlakyProvider::new());
3015 let tools = ToolRegistry::new(&[], Path::new("."), None);
3016 let agent = Agent::new(provider, tools, AgentConfig::default());
3017 let inner_session = Arc::new(Mutex::new(Session::in_memory()));
3018 let agent_session = AgentSession::new(
3019 agent,
3020 inner_session,
3021 false,
3022 crate::compaction::ResolvedCompactionSettings::default(),
3023 );
3024
3025 let session = Arc::new(Mutex::new(agent_session));
3026
3027 let mut config = Config::default();
3028 config.retry = Some(crate::config::RetrySettings {
3029 enabled: Some(true),
3030 max_retries: Some(1),
3031 base_delay_ms: Some(1),
3032 max_delay_ms: Some(1),
3033 });
3034
3035 let mut shared = RpcSharedState::new(&config);
3036 shared.auto_compaction_enabled = false;
3037 let shared_state = Arc::new(Mutex::new(shared));
3038
3039 let is_streaming = Arc::new(AtomicBool::new(false));
3040 let is_compacting = Arc::new(AtomicBool::new(false));
3041 let abort_handle_slot: Arc<Mutex<Option<AbortHandle>>> = Arc::new(Mutex::new(None));
3042 let retry_abort = Arc::new(AtomicBool::new(false));
3043 let (out_tx, out_rx) = std::sync::mpsc::sync_channel::<String>(1024);
3044
3045 let auth_path = tempfile::tempdir()
3046 .expect("tempdir")
3047 .path()
3048 .join("auth.json");
3049 let auth = AuthStorage::load(auth_path).expect("auth load");
3050
3051 let options = RpcOptions {
3052 config,
3053 resources: ResourceLoader::empty(false),
3054 available_models: Vec::new(),
3055 scoped_models: Vec::new(),
3056 cli_api_key: None,
3057 auth,
3058 runtime_handle,
3059 };
3060
3061 run_prompt_with_retry(
3062 session,
3063 shared_state,
3064 is_streaming,
3065 is_compacting,
3066 abort_handle_slot,
3067 out_tx,
3068 retry_abort,
3069 options,
3070 "hello".to_string(),
3071 Vec::new(),
3072 AgentCx::for_request(),
3073 )
3074 .await;
3075
3076 let mut saw_retry_start = false;
3077 let mut saw_retry_end_success = false;
3078
3079 for line in out_rx.try_iter() {
3080 let Ok(value) = serde_json::from_str::<Value>(&line) else {
3081 continue;
3082 };
3083 let Some(kind) = value.get("type").and_then(Value::as_str) else {
3084 continue;
3085 };
3086 match kind {
3087 "auto_retry_start" => {
3088 saw_retry_start = true;
3089 }
3090 "auto_retry_end"
3091 if value.get("success").and_then(Value::as_bool) == Some(true) =>
3092 {
3093 saw_retry_end_success = true;
3094 }
3095 _ => {}
3096 }
3097 }
3098
3099 assert!(saw_retry_start, "missing auto_retry_start event");
3100 assert!(
3101 saw_retry_end_success,
3102 "missing successful auto_retry_end event"
3103 );
3104 });
3105 }
3106
3107 #[test]
3108 fn rpc_abort_retry_emits_ordered_retry_timeline() {
3109 let runtime = asupersync::runtime::RuntimeBuilder::new()
3110 .blocking_threads(1, 8)
3111 .build()
3112 .expect("runtime build");
3113 let runtime_handle = runtime.handle();
3114
3115 runtime.block_on(async move {
3116 let provider = Arc::new(AlwaysErrorProvider);
3117 let tools = ToolRegistry::new(&[], Path::new("."), None);
3118 let agent = Agent::new(provider, tools, AgentConfig::default());
3119 let inner_session = Arc::new(Mutex::new(Session::in_memory()));
3120 let agent_session = AgentSession::new(
3121 agent,
3122 inner_session,
3123 false,
3124 crate::compaction::ResolvedCompactionSettings::default(),
3125 );
3126
3127 let session = Arc::new(Mutex::new(agent_session));
3128
3129 let mut config = Config::default();
3130 config.retry = Some(crate::config::RetrySettings {
3131 enabled: Some(true),
3132 max_retries: Some(3),
3133 base_delay_ms: Some(100),
3134 max_delay_ms: Some(100),
3135 });
3136
3137 let mut shared = RpcSharedState::new(&config);
3138 shared.auto_compaction_enabled = false;
3139 let shared_state = Arc::new(Mutex::new(shared));
3140
3141 let is_streaming = Arc::new(AtomicBool::new(false));
3142 let is_compacting = Arc::new(AtomicBool::new(false));
3143 let abort_handle_slot: Arc<Mutex<Option<AbortHandle>>> = Arc::new(Mutex::new(None));
3144 let retry_abort = Arc::new(AtomicBool::new(false));
3145 let (out_tx, out_rx) = std::sync::mpsc::sync_channel::<String>(1024);
3146
3147 let auth_path = tempfile::tempdir()
3148 .expect("tempdir")
3149 .path()
3150 .join("auth.json");
3151 let auth = AuthStorage::load(auth_path).expect("auth load");
3152
3153 let options = RpcOptions {
3154 config,
3155 resources: ResourceLoader::empty(false),
3156 available_models: Vec::new(),
3157 scoped_models: Vec::new(),
3158 cli_api_key: None,
3159 auth,
3160 runtime_handle,
3161 };
3162
3163 let retry_abort_for_thread = Arc::clone(&retry_abort);
3164 let abort_thread = std::thread::spawn(move || {
3165 std::thread::sleep(std::time::Duration::from_millis(10));
3166 retry_abort_for_thread.store(true, Ordering::SeqCst);
3167 });
3168
3169 run_prompt_with_retry(
3170 session,
3171 shared_state,
3172 is_streaming,
3173 is_compacting,
3174 abort_handle_slot,
3175 out_tx,
3176 retry_abort,
3177 options,
3178 "hello".to_string(),
3179 Vec::new(),
3180 AgentCx::for_request(),
3181 )
3182 .await;
3183 abort_thread.join().expect("abort thread join");
3184
3185 let mut timeline = Vec::new();
3186 let mut last_agent_end_error = None::<String>;
3187
3188 for line in out_rx.try_iter() {
3189 let Ok(value) = serde_json::from_str::<Value>(&line) else {
3190 continue;
3191 };
3192 let Some(kind) = value.get("type").and_then(Value::as_str) else {
3193 continue;
3194 };
3195 timeline.push(kind.to_string());
3196 if kind == "agent_end" {
3197 last_agent_end_error = value
3198 .get("error")
3199 .and_then(Value::as_str)
3200 .map(str::to_string);
3201 }
3202 }
3203
3204 let retry_start_idx = timeline
3205 .iter()
3206 .position(|kind| kind == "auto_retry_start")
3207 .expect("missing auto_retry_start");
3208 let retry_end_idx = timeline
3209 .iter()
3210 .position(|kind| kind == "auto_retry_end")
3211 .expect("missing auto_retry_end");
3212 let agent_end_idx = timeline
3213 .iter()
3214 .rposition(|kind| kind == "agent_end")
3215 .expect("missing agent_end");
3216
3217 assert!(
3218 retry_start_idx < retry_end_idx && retry_end_idx < agent_end_idx,
3219 "unexpected retry timeline ordering: {timeline:?}"
3220 );
3221 assert_eq!(
3222 last_agent_end_error.as_deref(),
3223 Some("Retry aborted"),
3224 "expected retry-abort terminal error, timeline: {timeline:?}"
3225 );
3226 });
3227 }
3228
3229 #[test]
3230 fn rpc_cancelled_agent_cx_aborts_retry_timeline() {
3231 let runtime = asupersync::runtime::RuntimeBuilder::new()
3232 .blocking_threads(1, 8)
3233 .build()
3234 .expect("runtime build");
3235 let runtime_handle = runtime.handle();
3236
3237 runtime.block_on(async move {
3238 let provider = Arc::new(AlwaysErrorProvider);
3239 let tools = ToolRegistry::new(&[], Path::new("."), None);
3240 let agent = Agent::new(provider, tools, AgentConfig::default());
3241 let inner_session = Arc::new(Mutex::new(Session::in_memory()));
3242 let agent_session = AgentSession::new(
3243 agent,
3244 inner_session,
3245 false,
3246 crate::compaction::ResolvedCompactionSettings::default(),
3247 );
3248
3249 let session = Arc::new(Mutex::new(agent_session));
3250
3251 let mut config = Config::default();
3252 config.retry = Some(crate::config::RetrySettings {
3253 enabled: Some(true),
3254 max_retries: Some(3),
3255 base_delay_ms: Some(100),
3256 max_delay_ms: Some(100),
3257 });
3258
3259 let mut shared = RpcSharedState::new(&config);
3260 shared.auto_compaction_enabled = false;
3261 let shared_state = Arc::new(Mutex::new(shared));
3262
3263 let is_streaming = Arc::new(AtomicBool::new(false));
3264 let is_compacting = Arc::new(AtomicBool::new(false));
3265 let abort_handle_slot: Arc<Mutex<Option<AbortHandle>>> = Arc::new(Mutex::new(None));
3266 let retry_abort = Arc::new(AtomicBool::new(false));
3267 let (out_tx, out_rx) = std::sync::mpsc::sync_channel::<String>(1024);
3268
3269 let auth_path = tempfile::tempdir()
3270 .expect("tempdir")
3271 .path()
3272 .join("auth.json");
3273 let auth = AuthStorage::load(auth_path).expect("auth load");
3274
3275 let options = RpcOptions {
3276 config,
3277 resources: ResourceLoader::empty(false),
3278 available_models: Vec::new(),
3279 scoped_models: Vec::new(),
3280 cli_api_key: None,
3281 auth,
3282 runtime_handle,
3283 };
3284
3285 let retry_cx = asupersync::Cx::for_testing();
3286 let cancel_cx = retry_cx.clone();
3287 let cancel_thread = std::thread::spawn(move || {
3288 std::thread::sleep(std::time::Duration::from_millis(10));
3289 cancel_cx.set_cancel_requested(true);
3290 });
3291
3292 run_prompt_with_retry(
3293 session,
3294 shared_state,
3295 is_streaming,
3296 is_compacting,
3297 abort_handle_slot,
3298 out_tx,
3299 retry_abort,
3300 options,
3301 "hello".to_string(),
3302 Vec::new(),
3303 AgentCx::from_cx(retry_cx),
3304 )
3305 .await;
3306 cancel_thread.join().expect("cancel thread join");
3307
3308 let mut timeline = Vec::new();
3309 let mut last_agent_end_error = None::<String>;
3310
3311 for line in out_rx.try_iter() {
3312 let Ok(value) = serde_json::from_str::<Value>(&line) else {
3313 continue;
3314 };
3315 let Some(kind) = value.get("type").and_then(Value::as_str) else {
3316 continue;
3317 };
3318 timeline.push(kind.to_string());
3319 if kind == "agent_end" {
3320 last_agent_end_error = value
3321 .get("error")
3322 .and_then(Value::as_str)
3323 .map(str::to_string);
3324 }
3325 }
3326
3327 let retry_start_idx = timeline
3328 .iter()
3329 .position(|kind| kind == "auto_retry_start")
3330 .expect("missing auto_retry_start");
3331 let retry_end_idx = timeline
3332 .iter()
3333 .position(|kind| kind == "auto_retry_end")
3334 .expect("missing auto_retry_end");
3335 let agent_end_idx = timeline
3336 .iter()
3337 .rposition(|kind| kind == "agent_end")
3338 .expect("missing agent_end");
3339
3340 assert!(
3341 retry_start_idx < retry_end_idx && retry_end_idx < agent_end_idx,
3342 "unexpected retry timeline ordering: {timeline:?}"
3343 );
3344 assert_eq!(
3345 last_agent_end_error.as_deref(),
3346 Some("Retry aborted"),
3347 "expected retry-abort terminal error, timeline: {timeline:?}"
3348 );
3349 });
3350 }
3351
3352 #[test]
3353 fn rpc_prompt_command_inherits_cancelled_context_from_run() {
3354 let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
3355 .build()
3356 .expect("runtime build");
3357 let runtime_handle = runtime.handle();
3358
3359 runtime.block_on(async move {
3360 let provider = Arc::new(AlwaysErrorProvider);
3361 let tools = ToolRegistry::new(&[], Path::new("."), None);
3362 let agent = Agent::new(provider, tools, AgentConfig::default());
3363 let agent_session = AgentSession::new(
3364 agent,
3365 Arc::new(asupersync::sync::Mutex::new(Session::in_memory())),
3366 false,
3367 crate::compaction::ResolvedCompactionSettings::default(),
3368 );
3369
3370 let mut config = Config::default();
3371 config.retry = Some(crate::config::RetrySettings {
3372 enabled: Some(true),
3373 max_retries: Some(10),
3374 base_delay_ms: Some(100),
3375 max_delay_ms: Some(100),
3376 });
3377
3378 let auth_path = tempfile::tempdir()
3379 .expect("tempdir")
3380 .path()
3381 .join("auth.json");
3382 let auth = AuthStorage::load(auth_path).expect("auth load");
3383 let options = RpcOptions {
3384 config,
3385 resources: ResourceLoader::empty(false),
3386 available_models: Vec::new(),
3387 scoped_models: Vec::new(),
3388 cli_api_key: None,
3389 auth,
3390 runtime_handle,
3391 };
3392
3393 let (in_tx, in_rx) = asupersync::channel::mpsc::channel::<String>(16);
3394 let (out_tx, out_rx) = std::sync::mpsc::sync_channel::<String>(1024);
3395 let out_rx = Arc::new(std::sync::Mutex::new(out_rx));
3396
3397 let ambient_cx = asupersync::Cx::for_testing();
3398 let cancel_cx = ambient_cx.clone();
3399 let _current = asupersync::Cx::set_current(Some(ambient_cx));
3400
3401 let client_out_rx = Arc::clone(&out_rx);
3402 let client = async move {
3403 let send_cx = asupersync::Cx::for_testing();
3404 in_tx
3405 .send(
3406 &send_cx,
3407 r#"{"id":"1","type":"prompt","message":"hello"}"#.to_string(),
3408 )
3409 .await
3410 .expect("send prompt command");
3411
3412 let ack_wait = async {
3413 loop {
3414 let recv_result = {
3415 let rx = client_out_rx.lock().expect("lock rpc output receiver");
3416 rx.try_recv()
3417 };
3418
3419 match recv_result {
3420 Ok(line) => {
3421 let value: Value =
3422 serde_json::from_str(&line).expect("parse rpc output");
3423 if value.get("type").and_then(Value::as_str) == Some("response") {
3424 break value;
3425 }
3426 }
3427 Err(std::sync::mpsc::TryRecvError::Disconnected) => {
3428 tracing::warn!(
3429 "prompt(cancel-inherit): output channel disconnected"
3430 );
3431 break Value::Object(serde_json::Map::new());
3432 }
3433 Err(std::sync::mpsc::TryRecvError::Empty) => {
3434 asupersync::time::sleep(
3435 asupersync::time::wall_now(),
3436 Duration::from_millis(5),
3437 )
3438 .await;
3439 }
3440 }
3441 }
3442 };
3443 futures::pin_mut!(ack_wait);
3444 let ack = asupersync::time::timeout(
3445 asupersync::time::wall_now(),
3446 Duration::from_secs(1),
3447 ack_wait,
3448 )
3449 .await;
3450 let ack = ack.expect("prompt acknowledgement");
3451 assert_eq!(ack["command"], "prompt");
3452 assert_eq!(ack["success"], true, "prompt should be accepted: {ack}");
3453
3454 let cancel_thread = std::thread::spawn(move || {
3455 std::thread::sleep(Duration::from_millis(20));
3456 cancel_cx.set_cancel_requested(true);
3457 });
3458
3459 let retry_abort_wait = async {
3460 let mut timeline = Vec::new();
3461 loop {
3462 let recv_result = {
3463 let rx = client_out_rx.lock().expect("lock rpc output receiver");
3464 rx.try_recv()
3465 };
3466
3467 match recv_result {
3468 Ok(line) => {
3469 let value: Value =
3470 serde_json::from_str(&line).expect("parse rpc output");
3471 let Some(kind) = value.get("type").and_then(Value::as_str) else {
3472 continue;
3473 };
3474 timeline.push(kind.to_string());
3475 if kind == "agent_end" {
3476 let agent_end_error = value
3477 .get("error")
3478 .and_then(Value::as_str)
3479 .map(str::to_string);
3480 if agent_end_error.as_deref() == Some("Retry aborted") {
3481 break (timeline, agent_end_error);
3482 }
3483 }
3484 }
3485 Err(std::sync::mpsc::TryRecvError::Disconnected) => {
3486 tracing::warn!(
3487 "prompt(cancel-inherit): output channel disconnected"
3488 );
3489 break (timeline, None);
3490 }
3491 Err(std::sync::mpsc::TryRecvError::Empty) => {
3492 asupersync::time::sleep(
3493 asupersync::time::wall_now(),
3494 Duration::from_millis(5),
3495 )
3496 .await;
3497 }
3498 }
3499 }
3500 };
3501 futures::pin_mut!(retry_abort_wait);
3502 let (timeline, last_agent_end_error) = asupersync::time::timeout(
3503 asupersync::time::wall_now(),
3504 Duration::from_secs(1),
3505 retry_abort_wait,
3506 )
3507 .await
3508 .expect("cancelled prompt should finish before timeout");
3509
3510 cancel_thread.join().expect("cancel thread join");
3511 let retry_start_idx = timeline
3512 .iter()
3513 .position(|kind| kind == "auto_retry_start")
3514 .expect("missing auto_retry_start");
3515 let retry_end_idx = timeline
3516 .iter()
3517 .position(|kind| kind == "auto_retry_end")
3518 .expect("missing auto_retry_end");
3519 let agent_end_idx = timeline
3520 .iter()
3521 .rposition(|kind| kind == "agent_end")
3522 .expect("missing agent_end");
3523 assert!(
3524 retry_start_idx < retry_end_idx && retry_end_idx < agent_end_idx,
3525 "unexpected retry timeline ordering: {timeline:?}"
3526 );
3527 assert_eq!(
3528 last_agent_end_error.as_deref(),
3529 Some("Retry aborted"),
3530 "expected retry-abort terminal error, timeline: {timeline:?}"
3531 );
3532
3533 drop(in_tx);
3534 };
3535
3536 let (server_result, ()) =
3537 futures::future::join(run(agent_session, options, in_rx, out_tx), client).await;
3538 assert!(server_result.is_ok(), "rpc server error: {server_result:?}");
3539 });
3540 }
3541
    /// Cancelling the ambient `Cx` while `run_bash_rpc` is running must yield a
    /// `cancelled` result and terminate the whole process tree, including a
    /// backgrounded child that would otherwise write a marker file later.
    #[cfg(target_os = "linux")]
    #[test]
    fn run_bash_rpc_cancelled_context_kills_process_tree() {
        asupersync::test_utils::run_test(|| async {
            let tmp = tempfile::tempdir().expect("tempdir");
            // Written by the background child only if it survives cancellation.
            let marker = tmp.path().join("leaked_child.txt");

            // Install a test context as the current one (run_bash_rpc reads
            // `Cx::current()`), keeping a clone so another thread can cancel it.
            let ambient_cx = asupersync::Cx::for_testing();
            let cancel_cx = ambient_cx.clone();
            let _current = asupersync::Cx::set_current(Some(ambient_cx));

            // Request cancellation shortly after the command has started.
            let cancel_thread = std::thread::spawn(move || {
                std::thread::sleep(std::time::Duration::from_millis(100));
                cancel_cx.set_cancel_requested(true);
            });

            // The abort sender is kept alive (named `_abort_tx`, not `_`), so only the
            // context cancellation can stop the command.
            let (_abort_tx, abort_rx) = oneshot::channel();
            let result = run_bash_rpc(
                tmp.path(),
                "(sleep 3; echo leaked > leaked_child.txt) & sleep 10",
                abort_rx,
            )
            .await
            .expect("cancelled rpc bash should return a result");

            cancel_thread.join().expect("cancel thread");

            assert!(
                result.cancelled,
                "expected cancelled rpc bash result: {result:?}"
            );

            // Wait past the background child's 3s delay to show it was killed,
            // not merely detached.
            std::thread::sleep(std::time::Duration::from_secs(4));
            assert!(
                !marker.exists(),
                "background child was not terminated on RPC cancellation"
            );
        });
    }
3581
3582 #[test]
3583 fn run_bash_rpc_large_output_completes_without_deadlock() {
3584 asupersync::test_utils::run_test(|| async {
3585 let tmp = tempfile::tempdir().expect("tempdir");
3586 let (_abort_tx, abort_rx) = oneshot::channel();
3587 let run = run_bash_rpc(tmp.path(), "yes x | head -c 1200000", abort_rx);
3588
3589 let result = asupersync::time::timeout(
3590 asupersync::time::wall_now(),
3591 std::time::Duration::from_secs(15),
3592 Box::pin(run),
3593 )
3594 .await
3595 .expect("rpc bash timed out; possible stdout/stderr reader deadlock")
3596 .expect("rpc bash should succeed");
3597
3598 assert_eq!(result.exit_code, 0, "expected successful shell exit");
3599 assert!(
3600 result.truncated,
3601 "large RPC bash output should truncate instead of blocking"
3602 );
3603 });
3604 }
3605
3606 #[test]
3607 fn rpc_spill_file_abandon_clears_path_and_unlinks_file() {
3608 let tmp = tempfile::tempdir().expect("tempdir");
3609 let spill_path = tmp.path().join("partial-rpc-bash.log");
3610 std::fs::write(&spill_path, b"partial output").expect("write spill file");
3611
3612 let mut temp_file = None;
3613 let mut temp_file_path = Some(spill_path.clone());
3614 let mut spill_failed = false;
3615
3616 abandon_bash_rpc_spill_file(&mut temp_file, &mut temp_file_path, &mut spill_failed);
3617
3618 assert!(spill_failed);
3619 assert!(temp_file.is_none());
3620 assert!(temp_file_path.is_none());
3621 assert!(
3622 !spill_path.exists(),
3623 "abandoned RPC spill files should not be left behind"
3624 );
3625 }
3626
3627 #[test]
3628 fn rpc_spill_file_hard_limit_abandons_partial_spill_file() {
3629 asupersync::test_utils::run_test(|| async {
3630 let tmp = tempfile::tempdir().expect("tempdir");
3631 let spill_path = tmp.path().join("hard-limit-rpc-bash.log");
3632 std::fs::write(&spill_path, b"partial output").expect("write spill file");
3633
3634 let spill_file = asupersync::fs::OpenOptions::new()
3635 .append(true)
3636 .open(&spill_path)
3637 .await
3638 .expect("open spill file");
3639
3640 let mut chunks = VecDeque::new();
3641 let mut chunks_bytes = 0usize;
3642 let mut total_bytes = crate::tools::BASH_FILE_LIMIT_BYTES;
3643 let mut total_lines = 0usize;
3644 let mut last_byte_was_newline = false;
3645 let mut temp_file = Some(spill_file);
3646 let mut temp_file_path = Some(spill_path.clone());
3647 let mut spill_failed = false;
3648
3649 ingest_bash_rpc_chunk(
3650 vec![b'x'],
3651 &mut chunks,
3652 &mut chunks_bytes,
3653 &mut total_bytes,
3654 &mut total_lines,
3655 &mut last_byte_was_newline,
3656 &mut temp_file,
3657 &mut temp_file_path,
3658 &mut spill_failed,
3659 DEFAULT_MAX_BYTES,
3660 )
3661 .await;
3662
3663 assert!(spill_failed);
3664 assert!(temp_file.is_none());
3665 assert!(temp_file_path.is_none());
3666 assert!(
3667 !spill_path.exists(),
3668 "hard-limit RPC spill files must be discarded"
3669 );
3670 });
3671 }
3672}
3673
/// Returns `true` once `tokens_before` exceeds the usable part of the context
/// window, i.e. `context_window - reserve_tokens` (saturating at zero, computed
/// in `u64` so the subtraction cannot wrap).
fn should_auto_compact(tokens_before: u64, context_window: u32, reserve_tokens: u32) -> bool {
    let usable_budget = u64::from(context_window).saturating_sub(u64::from(reserve_tokens));
    usable_budget < tokens_before
}
3679
3680#[allow(clippy::too_many_lines)]
3681async fn maybe_auto_compact(
3682 session: Arc<Mutex<AgentSession>>,
3683 options: RpcOptions,
3684 is_compacting: Arc<AtomicBool>,
3685 out_tx: std::sync::mpsc::SyncSender<String>,
3686) {
3687 let cx = AgentCx::for_current_or_request();
3688 let (path_entries, context_window, reserve_tokens, settings) = {
3689 let Ok(guard) = session.lock(cx.cx()).await else {
3690 return;
3691 };
3692 let (path_entries, context_window) = {
3693 let runtime_provider = guard.agent.provider().name().to_string();
3694 let runtime_model_id = guard.agent.provider().model_id().to_string();
3695 let Ok(mut inner_session) = guard.session.lock(cx.cx()).await else {
3696 return;
3697 };
3698 inner_session.ensure_entry_ids();
3699 let Some(entry) = current_or_runtime_model_entry(
3700 &inner_session,
3701 &runtime_provider,
3702 &runtime_model_id,
3703 &options,
3704 ) else {
3705 return;
3706 };
3707 let path_entries = inner_session
3708 .entries_for_current_path()
3709 .into_iter()
3710 .cloned()
3711 .collect::<Vec<_>>();
3712 (path_entries, entry.model.context_window)
3713 };
3714
3715 let reserve_tokens = options.config.compaction_reserve_tokens();
3716 let settings = ResolvedCompactionSettings {
3717 enabled: true,
3718 reserve_tokens,
3719 keep_recent_tokens: options.config.compaction_keep_recent_tokens(),
3720 ..Default::default()
3721 };
3722
3723 (path_entries, context_window, reserve_tokens, settings)
3724 };
3725
3726 let Some(prep) = prepare_compaction(&path_entries, settings) else {
3727 return;
3728 };
3729 if !should_auto_compact(prep.tokens_before, context_window, reserve_tokens) {
3730 return;
3731 }
3732
3733 let _ = out_tx.send(event(&json!({
3734 "type": "auto_compaction_start",
3735 "reason": "threshold",
3736 })));
3737 is_compacting.store(true, Ordering::SeqCst);
3738
3739 let (provider, key) = {
3740 let Ok(guard) = session.lock(cx.cx()).await else {
3741 is_compacting.store(false, Ordering::SeqCst);
3742 return;
3743 };
3744 let Some(key) = guard.agent.stream_options().api_key.clone() else {
3745 is_compacting.store(false, Ordering::SeqCst);
3746 let _ = out_tx.send(event(&json!({
3747 "type": "auto_compaction_end",
3748 "result": Value::Null,
3749 "aborted": false,
3750 "willRetry": false,
3751 "errorMessage": "Missing API key for compaction",
3752 })));
3753 return;
3754 };
3755 (guard.agent.provider(), key)
3756 };
3757
3758 let result = compact(prep, provider, &key, None).await;
3759 is_compacting.store(false, Ordering::SeqCst);
3760
3761 match result {
3762 Ok(result) => {
3763 let details_value = match compaction_details_to_value(&result.details) {
3764 Ok(value) => value,
3765 Err(err) => {
3766 let _ = out_tx.send(event(&json!({
3767 "type": "auto_compaction_end",
3768 "result": Value::Null,
3769 "aborted": false,
3770 "willRetry": false,
3771 "errorMessage": err.to_string(),
3772 })));
3773 return;
3774 }
3775 };
3776
3777 let Ok(mut guard) = session.lock(cx.cx()).await else {
3778 return;
3779 };
3780 let messages = {
3781 let Ok(mut inner_session) = guard.session.lock(cx.cx()).await else {
3782 return;
3783 };
3784 inner_session.append_compaction(
3785 result.summary.clone(),
3786 result.first_kept_entry_id.clone(),
3787 result.tokens_before,
3788 Some(details_value.clone()),
3789 None,
3790 );
3791 inner_session.to_messages_for_current_path()
3792 };
3793 let _ = guard.persist_session().await;
3794 guard.agent.replace_messages(messages);
3795 drop(guard);
3796
3797 let _ = out_tx.send(event(&json!({
3798 "type": "auto_compaction_end",
3799 "result": {
3800 "summary": result.summary,
3801 "firstKeptEntryId": result.first_kept_entry_id,
3802 "tokensBefore": result.tokens_before,
3803 "details": details_value,
3804 },
3805 "aborted": false,
3806 "willRetry": false,
3807 })));
3808 }
3809 Err(err) => {
3810 let _ = out_tx.send(event(&json!({
3811 "type": "auto_compaction_end",
3812 "result": Value::Null,
3813 "aborted": false,
3814 "willRetry": false,
3815 "errorMessage": err.to_string(),
3816 })));
3817 }
3818 }
3819}
3820
3821fn rpc_model_from_entry(entry: &ModelEntry) -> Value {
3822 let input = entry
3823 .model
3824 .input
3825 .iter()
3826 .map(|t| match t {
3827 crate::provider::InputType::Text => "text",
3828 crate::provider::InputType::Image => "image",
3829 })
3830 .collect::<Vec<_>>();
3831
3832 json!({
3833 "id": entry.model.id,
3834 "name": entry.model.name,
3835 "api": entry.model.api,
3836 "provider": entry.model.provider,
3837 "baseUrl": entry.model.base_url,
3838 "reasoning": entry.model.reasoning,
3839 "input": input,
3840 "contextWindow": entry.model.context_window,
3841 "maxTokens": entry.model.max_tokens,
3842 "cost": entry.model.cost,
3843 })
3844}
3845
3846fn session_state(
3847 session: &crate::session::Session,
3848 options: &RpcOptions,
3849 snapshot: &RpcStateSnapshot,
3850 is_streaming: bool,
3851 is_compacting: bool,
3852) -> Value {
3853 let model = session
3854 .header
3855 .provider
3856 .as_deref()
3857 .zip(session.header.model_id.as_deref())
3858 .and_then(|(provider, model_id)| {
3859 options.available_models.iter().find(|m| {
3860 provider_ids_match(&m.model.provider, provider)
3861 && m.model.id.eq_ignore_ascii_case(model_id)
3862 })
3863 })
3864 .map(rpc_model_from_entry);
3865
3866 let message_count = session
3867 .entries_for_current_path()
3868 .iter()
3869 .filter(|entry| matches!(entry, crate::session::SessionEntry::Message(_)))
3870 .count();
3871
3872 let session_name = session
3873 .entries_for_current_path()
3874 .iter()
3875 .rev()
3876 .find_map(|entry| {
3877 let crate::session::SessionEntry::SessionInfo(info) = entry else {
3878 return None;
3879 };
3880 info.name.clone()
3881 });
3882
3883 let mut state = serde_json::Map::new();
3884 state.insert("model".to_string(), model.unwrap_or(Value::Null));
3885 state.insert(
3886 "thinkingLevel".to_string(),
3887 Value::String(
3888 session
3889 .header
3890 .thinking_level
3891 .clone()
3892 .unwrap_or_else(|| "off".to_string()),
3893 ),
3894 );
3895 state.insert("isStreaming".to_string(), Value::Bool(is_streaming));
3896 state.insert("isCompacting".to_string(), Value::Bool(is_compacting));
3897 state.insert(
3898 "steeringMode".to_string(),
3899 Value::String(snapshot.steering_mode.as_str().to_string()),
3900 );
3901 state.insert(
3902 "followUpMode".to_string(),
3903 Value::String(snapshot.follow_up_mode.as_str().to_string()),
3904 );
3905 state.insert(
3906 "sessionFile".to_string(),
3907 session
3908 .path
3909 .as_ref()
3910 .map_or(Value::Null, |p| Value::String(p.display().to_string())),
3911 );
3912 state.insert(
3913 "sessionId".to_string(),
3914 Value::String(session.header.id.clone()),
3915 );
3916 state.insert(
3917 "sessionName".to_string(),
3918 session_name.map_or(Value::Null, Value::String),
3919 );
3920 state.insert(
3921 "autoCompactionEnabled".to_string(),
3922 Value::Bool(snapshot.auto_compaction_enabled),
3923 );
3924 state.insert(
3925 "autoRetryEnabled".to_string(),
3926 Value::Bool(snapshot.auto_retry_enabled),
3927 );
3928 state.insert(
3929 "messageCount".to_string(),
3930 Value::Number(message_count.into()),
3931 );
3932 state.insert(
3933 "pendingMessageCount".to_string(),
3934 Value::Number(snapshot.pending_count().into()),
3935 );
3936 state.insert(
3937 "durabilityMode".to_string(),
3938 Value::String(session.autosave_durability_mode().as_str().to_string()),
3939 );
3940 Value::Object(state)
3941}
3942
3943fn session_stats(session: &crate::session::Session) -> Value {
3944 let mut user_messages: u64 = 0;
3945 let mut assistant_messages: u64 = 0;
3946 let mut tool_results: u64 = 0;
3947 let mut tool_calls: u64 = 0;
3948
3949 let mut total_input: u64 = 0;
3950 let mut total_output: u64 = 0;
3951 let mut total_cache_read: u64 = 0;
3952 let mut total_cache_write: u64 = 0;
3953 let mut total_cost: f64 = 0.0;
3954
3955 let messages = session.to_messages_for_current_path();
3956
3957 for message in &messages {
3958 match message {
3959 Message::User(_) | Message::Custom(_) => user_messages += 1,
3960 Message::Assistant(message) => {
3961 assistant_messages += 1;
3962 tool_calls += message
3963 .content
3964 .iter()
3965 .filter(|block| matches!(block, ContentBlock::ToolCall(_)))
3966 .count() as u64;
3967 total_input += message.usage.input;
3968 total_output += message.usage.output;
3969 total_cache_read += message.usage.cache_read;
3970 total_cache_write += message.usage.cache_write;
3971 total_cost += message.usage.cost.total;
3972 }
3973 Message::ToolResult(_) => tool_results += 1,
3974 }
3975 }
3976
3977 let total_messages = messages.len() as u64;
3978
3979 let total_tokens = total_input + total_output + total_cache_read + total_cache_write;
3980 let autosave = session.autosave_metrics();
3981 let pending_message_count = autosave.pending_mutations as u64;
3982 let durability_mode = session.autosave_durability_mode();
3983 let durability_mode_label = match durability_mode {
3984 crate::session::AutosaveDurabilityMode::Strict => "strict",
3985 crate::session::AutosaveDurabilityMode::Balanced => "balanced",
3986 crate::session::AutosaveDurabilityMode::Throughput => "throughput",
3987 };
3988 let (status_event, status_severity, status_summary, status_action, status_sli_ids) =
3989 if pending_message_count == 0 {
3990 (
3991 "session.persistence.healthy",
3992 "ok",
3993 "Persistence queue is clear.",
3994 "No action required.",
3995 vec!["sli_resume_ready_p95_ms"],
3996 )
3997 } else {
3998 let summary = match durability_mode {
3999 crate::session::AutosaveDurabilityMode::Strict => {
4000 "Pending persistence backlog under strict durability mode."
4001 }
4002 crate::session::AutosaveDurabilityMode::Balanced => {
4003 "Pending persistence backlog under balanced durability mode."
4004 }
4005 crate::session::AutosaveDurabilityMode::Throughput => {
4006 "Pending persistence backlog under throughput durability mode."
4007 }
4008 };
4009 let action = match durability_mode {
4010 crate::session::AutosaveDurabilityMode::Throughput => {
4011 "Expect deferred writes; trigger manual save before critical transitions."
4012 }
4013 _ => "Allow autosave flush to complete or trigger manual save before exit.",
4014 };
4015 (
4016 "session.persistence.backlog",
4017 "warning",
4018 summary,
4019 action,
4020 vec![
4021 "sli_resume_ready_p95_ms",
4022 "sli_failure_recovery_success_rate",
4023 ],
4024 )
4025 };
4026
4027 let mut data = serde_json::Map::new();
4028 data.insert(
4029 "sessionFile".to_string(),
4030 session
4031 .path
4032 .as_ref()
4033 .map_or(Value::Null, |p| Value::String(p.display().to_string())),
4034 );
4035 data.insert(
4036 "sessionId".to_string(),
4037 Value::String(session.header.id.clone()),
4038 );
4039 data.insert(
4040 "userMessages".to_string(),
4041 Value::Number(user_messages.into()),
4042 );
4043 data.insert(
4044 "assistantMessages".to_string(),
4045 Value::Number(assistant_messages.into()),
4046 );
4047 data.insert("toolCalls".to_string(), Value::Number(tool_calls.into()));
4048 data.insert(
4049 "toolResults".to_string(),
4050 Value::Number(tool_results.into()),
4051 );
4052 data.insert(
4053 "totalMessages".to_string(),
4054 Value::Number(total_messages.into()),
4055 );
4056 data.insert(
4057 "durabilityMode".to_string(),
4058 Value::String(durability_mode_label.to_string()),
4059 );
4060 data.insert(
4061 "pendingMessageCount".to_string(),
4062 Value::Number(pending_message_count.into()),
4063 );
4064 data.insert(
4065 "tokens".to_string(),
4066 json!({
4067 "input": total_input,
4068 "output": total_output,
4069 "cacheRead": total_cache_read,
4070 "cacheWrite": total_cache_write,
4071 "total": total_tokens,
4072 }),
4073 );
4074 data.insert(
4075 "persistenceStatus".to_string(),
4076 json!({
4077 "event": status_event,
4078 "severity": status_severity,
4079 "summary": status_summary,
4080 "action": status_action,
4081 "sliIds": status_sli_ids,
4082 "pendingMessageCount": pending_message_count,
4083 "flushCounters": {
4084 "started": autosave.flush_started,
4085 "succeeded": autosave.flush_succeeded,
4086 "failed": autosave.flush_failed,
4087 },
4088 }),
4089 );
4090 data.insert(
4091 "uxEventMarkers".to_string(),
4092 json!([
4093 {
4094 "event": status_event,
4095 "severity": status_severity,
4096 "durabilityMode": durability_mode_label,
4097 "pendingMessageCount": pending_message_count,
4098 "sliIds": status_sli_ids,
4099 }
4100 ]),
4101 );
4102 data.insert("cost".to_string(), Value::from(total_cost));
4103 Value::Object(data)
4104}
4105
4106fn last_assistant_text(session: &crate::session::Session) -> Option<String> {
4107 let entries = session.entries_for_current_path();
4108 for entry in entries.into_iter().rev() {
4109 let crate::session::SessionEntry::Message(msg_entry) = entry else {
4110 continue;
4111 };
4112 let SessionMessage::Assistant { message } = &msg_entry.message else {
4113 continue;
4114 };
4115 let mut text = String::new();
4116 for block in &message.content {
4117 if let ContentBlock::Text(t) = block {
4118 text.push_str(&t.text);
4119 }
4120 }
4121 if !text.is_empty() {
4122 return Some(text);
4123 }
4124 }
4125 None
4126}
4127
4128async fn export_html_snapshot(
4133 snapshot: &crate::session::ExportSnapshot,
4134 output_path: Option<&str>,
4135) -> Result<String> {
4136 let html = snapshot.to_html();
4137
4138 let path = output_path.map_or_else(
4139 || {
4140 snapshot.path.as_ref().map_or_else(
4141 || {
4142 let ts = chrono::Utc::now().format("%Y-%m-%dT%H-%M-%S%.3fZ");
4143 PathBuf::from(format!("pi-session-{ts}.html"))
4144 },
4145 |session_path| {
4146 let basename = session_path
4147 .file_stem()
4148 .and_then(|s| s.to_str())
4149 .unwrap_or("session");
4150 PathBuf::from(format!("pi-session-{basename}.html"))
4151 },
4152 )
4153 },
4154 PathBuf::from,
4155 );
4156
4157 if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) {
4158 asupersync::fs::create_dir_all(parent).await?;
4159 }
4160 asupersync::fs::write(&path, html).await?;
4161 Ok(path.display().to_string())
4162}
4163
/// Outcome of a bash command run on behalf of an RPC client by `run_bash_rpc`.
#[derive(Debug, Clone)]
struct BashRpcResult {
    /// Command output returned to the client — presumably the (possibly truncated)
    /// tail of the combined stdout/stderr; confirm against `run_bash_rpc`.
    output: String,
    /// Shell exit status. `run_bash_rpc` falls back to `-1` when no status code is
    /// available (e.g. the process was killed or the run was cancelled).
    exit_code: i32,
    /// Whether the run was stopped by an abort request or context cancellation.
    cancelled: bool,
    /// Whether `output` omits part of what the command produced.
    truncated: bool,
    /// Path to the on-disk spill file with the full output, when one was kept
    /// (NOTE(review): inferred from the spill-file logic; confirm in `run_bash_rpc`).
    full_output_path: Option<String>,
}
4172
4173fn abandon_bash_rpc_spill_file(
4174 temp_file: &mut Option<asupersync::fs::File>,
4175 temp_file_path: &mut Option<PathBuf>,
4176 spill_failed: &mut bool,
4177) {
4178 *spill_failed = true;
4179 *temp_file = None;
4180 if let Some(path) = temp_file_path.take() {
4181 if let Err(e) = std::fs::remove_file(&path)
4182 && e.kind() != std::io::ErrorKind::NotFound
4183 {
4184 tracing::debug!(
4185 "Failed to remove incomplete RPC bash spill file {}: {}",
4186 path.display(),
4187 e
4188 );
4189 }
4190 }
4191}
4192
/// Converts a running newline count into a line count for `total_bytes` of output.
///
/// Empty output has zero lines. Otherwise, output that does not end in `\n` has
/// one extra, partial line. The increment saturates rather than overflowing.
const fn line_count_from_newline_count(
    total_bytes: usize,
    newline_count: usize,
    last_byte_was_newline: bool,
) -> usize {
    match (total_bytes, last_byte_was_newline) {
        (0, _) => 0,
        (_, true) => newline_count,
        (_, false) => newline_count.saturating_add(1),
    }
}
4206
4207async fn ingest_bash_rpc_chunk(
4208 bytes: Vec<u8>,
4209 chunks: &mut VecDeque<Vec<u8>>,
4210 chunks_bytes: &mut usize,
4211 total_bytes: &mut usize,
4212 total_lines: &mut usize,
4213 last_byte_was_newline: &mut bool,
4214 temp_file: &mut Option<asupersync::fs::File>,
4215 temp_file_path: &mut Option<PathBuf>,
4216 spill_failed: &mut bool,
4217 max_chunks_bytes: usize,
4218) {
4219 if bytes.is_empty() {
4220 return;
4221 }
4222
4223 *last_byte_was_newline = bytes.last().is_some_and(|byte| *byte == b'\n');
4224 *total_bytes = total_bytes.saturating_add(bytes.len());
4225 *total_lines = total_lines.saturating_add(memchr_iter(b'\n', &bytes).count());
4226
4227 if *total_bytes > DEFAULT_MAX_BYTES && temp_file.is_none() && !*spill_failed {
4229 let id_full = uuid::Uuid::new_v4().simple().to_string();
4230 let id = &id_full[..16];
4231 let path = std::env::temp_dir().join(format!("pi-rpc-bash-{id}.log"));
4232
4233 let path_clone = path.clone();
4235 let expected_inode: Option<u64> =
4236 asupersync::runtime::spawn_blocking_io(move || -> std::io::Result<Option<u64>> {
4237 let mut options = std::fs::OpenOptions::new();
4238 options.write(true).create_new(true);
4239 #[cfg(unix)]
4240 {
4241 use std::os::unix::fs::OpenOptionsExt;
4242 options.mode(0o600);
4243 }
4244
4245 match options.open(&path_clone) {
4246 Ok(file) => {
4247 #[cfg(unix)]
4248 {
4249 use std::os::unix::fs::MetadataExt;
4250 Ok(file.metadata().ok().map(|m| m.ino()))
4251 }
4252 #[cfg(not(unix))]
4253 {
4254 drop(file);
4255 Ok(None)
4256 }
4257 }
4258 Err(e) => {
4259 tracing::warn!("Failed to create bash temp file: {e}");
4260 Ok(None)
4261 }
4262 }
4263 })
4264 .await
4265 .unwrap_or(None);
4266
4267 if expected_inode.is_some() || !cfg!(unix) {
4268 match asupersync::fs::OpenOptions::new()
4270 .append(true)
4271 .open(&path)
4272 .await
4273 {
4274 Ok(mut file) => {
4275 #[cfg_attr(not(unix), allow(unused_mut))]
4277 let mut identity_match = true;
4278 #[cfg(unix)]
4279 if let Some(expected) = expected_inode {
4280 use std::os::unix::fs::MetadataExt;
4281 match file.metadata().await {
4282 Ok(meta) => {
4283 if meta.ino() != expected {
4284 tracing::warn!(
4285 "Temp file identity mismatch (possible TOCTOU attack)"
4286 );
4287 identity_match = false;
4288 }
4289 }
4290 Err(e) => {
4291 tracing::warn!("Failed to stat temp file: {e}");
4292 identity_match = false;
4293 }
4294 }
4295 }
4296
4297 if identity_match {
4298 let mut failed_flush = false;
4300 for existing in chunks.iter() {
4301 use asupersync::io::AsyncWriteExt;
4302 if let Err(e) = file.write_all(existing).await {
4303 tracing::warn!("Failed to flush bash chunk to temp file: {e}");
4304 failed_flush = true;
4305 break;
4306 }
4307 }
4308 *temp_file_path = Some(path);
4309 if failed_flush {
4310 abandon_bash_rpc_spill_file(temp_file, temp_file_path, spill_failed);
4311 } else {
4312 *temp_file = Some(file);
4313 }
4314 } else {
4315 *temp_file_path = Some(path);
4316 abandon_bash_rpc_spill_file(temp_file, temp_file_path, spill_failed);
4317 }
4318 }
4319 Err(e) => {
4320 tracing::warn!("Failed to reopen bash temp file async: {e}");
4321 *temp_file_path = Some(path);
4322 abandon_bash_rpc_spill_file(temp_file, temp_file_path, spill_failed);
4323 }
4324 }
4325 } else {
4326 *spill_failed = true;
4327 }
4328 }
4329
4330 let mut abandon_spill_file = false;
4332 let mut close_spill_file = false;
4333 if let Some(file) = temp_file.as_mut() {
4334 if *total_bytes <= crate::tools::BASH_FILE_LIMIT_BYTES {
4335 use asupersync::io::AsyncWriteExt;
4336 if let Err(e) = file.write_all(&bytes).await {
4337 tracing::warn!("Failed to write bash chunk to temp file: {e}");
4338 abandon_spill_file = true;
4339 }
4340 } else {
4341 if !*spill_failed {
4343 tracing::warn!("Bash output exceeded hard limit; stopping file log");
4344 close_spill_file = true;
4345 *spill_failed = true;
4346 }
4347 }
4348 }
4349 if abandon_spill_file || close_spill_file {
4350 abandon_bash_rpc_spill_file(temp_file, temp_file_path, spill_failed);
4351 }
4352
4353 *chunks_bytes = chunks_bytes.saturating_add(bytes.len());
4355 chunks.push_back(bytes);
4356 while *chunks_bytes > max_chunks_bytes && chunks.len() > 1 {
4357 if let Some(front) = chunks.pop_front() {
4358 *chunks_bytes = chunks_bytes.saturating_sub(front.len());
4359 }
4360 }
4361}
4362
4363async fn run_bash_rpc(
4364 cwd: &std::path::Path,
4365 command: &str,
4366 mut abort_rx: oneshot::Receiver<()>,
4367) -> Result<BashRpcResult> {
4368 #[derive(Clone, Copy)]
4369 enum StreamKind {
4370 Stdout,
4371 Stderr,
4372 }
4373
4374 struct StreamChunk {
4375 kind: StreamKind,
4376 bytes: Vec<u8>,
4377 }
4378
4379 fn pump_stream(
4380 mut reader: impl std::io::Read,
4381 tx: std::sync::mpsc::SyncSender<StreamChunk>,
4382 kind: StreamKind,
4383 ) {
4384 let mut buf = [0u8; 8192];
4385 loop {
4386 let read = match reader.read(&mut buf) {
4387 Ok(0) => break,
4388 Ok(read) => read,
4389 Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
4390 Err(_) => break,
4391 };
4392 let chunk = StreamChunk {
4393 kind,
4394 bytes: buf[..read].to_vec(),
4395 };
4396 if tx.send(chunk).is_err() {
4397 break;
4398 }
4399 }
4400 }
4401
4402 let shell = ["/bin/bash", "/usr/bin/bash", "/usr/local/bin/bash"]
4403 .into_iter()
4404 .find(|p| std::path::Path::new(p).exists())
4405 .unwrap_or("sh");
4406
4407 let command = format!("trap 'code=$?; wait; exit $code' EXIT\n{command}");
4408
4409 let mut child = std::process::Command::new(shell);
4410 child
4411 .arg("-c")
4412 .arg(&command)
4413 .current_dir(cwd)
4414 .stdin(std::process::Stdio::null())
4415 .stdout(std::process::Stdio::piped())
4416 .stderr(std::process::Stdio::piped());
4417 crate::tools::isolate_command_process_group(&mut child);
4418 let mut child = child
4419 .spawn()
4420 .map_err(|e| Error::tool("bash", format!("Failed to spawn shell: {e}")))?;
4421
4422 let Some(stdout) = child.stdout.take() else {
4423 return Err(Error::tool("bash", "Missing stdout".to_string()));
4424 };
4425 let Some(stderr) = child.stderr.take() else {
4426 return Err(Error::tool("bash", "Missing stderr".to_string()));
4427 };
4428
4429 let mut guard =
4430 crate::tools::ProcessGuard::new(child, crate::tools::ProcessCleanupMode::ProcessGroupTree);
4431
4432 let (tx, rx) = std::sync::mpsc::sync_channel::<StreamChunk>(1024);
4438 let tx_stdout = tx.clone();
4439 let _stdout_handle =
4440 std::thread::spawn(move || pump_stream(stdout, tx_stdout, StreamKind::Stdout));
4441 let _stderr_handle = std::thread::spawn(move || pump_stream(stderr, tx, StreamKind::Stderr));
4442
4443 let tick = Duration::from_millis(10);
4444 let cx = asupersync::Cx::current().unwrap_or_else(asupersync::Cx::for_request);
4445
4446 let mut chunks: VecDeque<Vec<u8>> = VecDeque::new();
4448 let mut chunks_bytes = 0usize;
4449 let mut total_bytes = 0usize;
4450 let mut total_lines = 0usize;
4451 let mut last_byte_was_newline = false;
4452 let mut temp_file: Option<asupersync::fs::File> = None;
4453 let mut temp_file_path: Option<PathBuf> = None;
4454 let max_chunks_bytes = DEFAULT_MAX_BYTES * 2;
4455
4456 let mut cancelled = false;
4457 let mut spill_failed = false;
4458
4459 let exit_code = loop {
4460 while let Ok(chunk) = rx.try_recv() {
4461 ingest_bash_rpc_chunk(
4462 chunk.bytes,
4463 &mut chunks,
4464 &mut chunks_bytes,
4465 &mut total_bytes,
4466 &mut total_lines,
4467 &mut last_byte_was_newline,
4468 &mut temp_file,
4469 &mut temp_file_path,
4470 &mut spill_failed,
4471 max_chunks_bytes,
4472 )
4473 .await;
4474 }
4475
4476 if !cancelled && abort_rx.try_recv().is_ok() {
4477 cancelled = true;
4478 let status_code = guard
4479 .kill()
4480 .map_or(-1, |status| status.code().unwrap_or(-1));
4481 break status_code;
4482 }
4483
4484 match guard.try_wait_child() {
4485 Ok(Some(status)) => break status.code().unwrap_or(-1),
4486 Ok(None) => {}
4487 Err(err) => {
4488 return Err(Error::tool(
4489 "bash",
4490 format!("Failed to wait for process: {err}"),
4491 ));
4492 }
4493 }
4494
4495 if cx.checkpoint().is_err() {
4496 cancelled = true;
4497 let _ = guard.kill();
4498 let status_code = -1;
4499 break status_code;
4500 }
4501
4502 let now = cx.timer_driver().map_or_else(wall_now, |timer| timer.now());
4503 sleep(now, tick).await;
4504 };
4505
4506 let now_drain = cx.timer_driver().map_or_else(wall_now, |timer| timer.now());
4508 let drain_deadline = now_drain + std::time::Duration::from_secs(2);
4509 let mut drain_timed_out = false;
4510 loop {
4511 match rx.try_recv() {
4512 Ok(chunk) => {
4513 ingest_bash_rpc_chunk(
4514 chunk.bytes,
4515 &mut chunks,
4516 &mut chunks_bytes,
4517 &mut total_bytes,
4518 &mut total_lines,
4519 &mut last_byte_was_newline,
4520 &mut temp_file,
4521 &mut temp_file_path,
4522 &mut spill_failed,
4523 max_chunks_bytes,
4524 )
4525 .await;
4526 }
4527 Err(std::sync::mpsc::TryRecvError::Empty) => {
4528 let now = cx.timer_driver().map_or_else(wall_now, |timer| timer.now());
4529 if now >= drain_deadline {
4530 drain_timed_out = true;
4531 break;
4532 }
4533 if cx.checkpoint().is_err() {
4534 cancelled = true;
4535 break;
4536 }
4537 sleep(now, tick).await;
4538 }
4539 Err(std::sync::mpsc::TryRecvError::Disconnected) => break,
4540 }
4541 }
4542
4543 drop(rx);
4550
4551 let _ = guard.wait();
4555
4556 drop(temp_file);
4559
4560 let mut combined = Vec::with_capacity(chunks_bytes);
4562 for chunk in chunks {
4563 combined.extend_from_slice(&chunk);
4564 }
4565 let tail_output = String::from_utf8_lossy(&combined).to_string();
4566
4567 let mut truncation = truncate_tail(tail_output, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES);
4568 if total_bytes > chunks_bytes {
4569 truncation.truncated = true;
4570 truncation.truncated_by = Some(crate::tools::TruncatedBy::Bytes);
4571 truncation.total_bytes = total_bytes;
4572 truncation.total_lines =
4573 line_count_from_newline_count(total_bytes, total_lines, last_byte_was_newline);
4574 } else if drain_timed_out {
4575 truncation.truncated = true;
4576 truncation.truncated_by = Some(crate::tools::TruncatedBy::Bytes);
4577 }
4578 let will_truncate = truncation.truncated;
4579
4580 let mut output_text = if truncation.content.is_empty() {
4581 "(no output)".to_string()
4582 } else {
4583 truncation.content
4584 };
4585
4586 if drain_timed_out {
4587 output_text.push_str("\n... [Output truncated: drain timeout]");
4588 }
4589
4590 Ok(BashRpcResult {
4591 output: output_text,
4592 exit_code,
4593 cancelled,
4594 truncated: will_truncate,
4595 full_output_path: temp_file_path.map(|p| p.display().to_string()),
4596 })
4597}
4598
4599fn parse_prompt_images(value: Option<&Value>) -> Result<Vec<ImageContent>> {
4600 let Some(value) = value else {
4601 return Ok(Vec::new());
4602 };
4603 let Some(arr) = value.as_array() else {
4604 return Err(Error::validation("images must be an array"));
4605 };
4606
4607 let mut images = Vec::new();
4608 for item in arr {
4609 let Some(obj) = item.as_object() else {
4610 continue;
4611 };
4612 let item_type = obj.get("type").and_then(Value::as_str).unwrap_or("");
4613 if item_type != "image" {
4614 continue;
4615 }
4616 let Some(source) = obj.get("source").and_then(Value::as_object) else {
4617 continue;
4618 };
4619 let source_type = source.get("type").and_then(Value::as_str).unwrap_or("");
4620 if source_type != "base64" {
4621 continue;
4622 }
4623 let Some(media_type) = source.get("mediaType").and_then(Value::as_str) else {
4624 continue;
4625 };
4626 let Some(data) = source.get("data").and_then(Value::as_str) else {
4627 continue;
4628 };
4629 images.push(ImageContent {
4630 data: data.to_string(),
4631 mime_type: media_type.to_string(),
4632 });
4633 }
4634 Ok(images)
4635}
4636
4637fn resolve_model_key(
4638 cli_api_key: Option<&str>,
4639 auth: &AuthStorage,
4640 entry: &ModelEntry,
4641) -> Option<String> {
4642 cli_api_key
4643 .and_then(|key| {
4644 let trimmed = key.trim();
4645 (!trimmed.is_empty()).then(|| trimmed.to_string())
4646 })
4647 .or_else(|| normalize_api_key_opt(auth.resolve_api_key(&entry.model.provider, None)))
4648 .or_else(|| normalize_api_key_opt(entry.api_key.clone()))
4649}
4650
4651fn parse_thinking_level(level: &str) -> Result<crate::model::ThinkingLevel> {
4652 level.parse().map_err(|err: String| Error::validation(err))
4653}
4654
4655fn session_thinking_level(
4656 session: &crate::session::Session,
4657) -> Option<crate::model::ThinkingLevel> {
4658 session
4659 .effective_thinking_level_for_current_path()
4660 .as_deref()
4661 .and_then(|raw| {
4662 raw.parse::<crate::model::ThinkingLevel>().map_or_else(
4663 |_| {
4664 tracing::warn!("Ignoring invalid session thinking level in RPC state: {raw}");
4665 None
4666 },
4667 Some,
4668 )
4669 })
4670}
4671
4672fn current_model_entry<'a>(
4673 session: &crate::session::Session,
4674 options: &'a RpcOptions,
4675) -> Option<&'a ModelEntry> {
4676 let (provider, model_id) = session.effective_model_for_current_path()?;
4677 model_entry_for_provider_and_id(&provider, &model_id, options)
4678}
4679
4680fn current_or_runtime_model_entry<'a>(
4681 session: &crate::session::Session,
4682 runtime_provider: &str,
4683 runtime_model_id: &str,
4684 options: &'a RpcOptions,
4685) -> Option<&'a ModelEntry> {
4686 current_model_entry(session, options)
4687 .or_else(|| model_entry_for_provider_and_id(runtime_provider, runtime_model_id, options))
4688}
4689
4690fn model_entry_for_provider_and_id<'a>(
4691 provider: &str,
4692 model_id: &str,
4693 options: &'a RpcOptions,
4694) -> Option<&'a ModelEntry> {
4695 options.available_models.iter().find(|m| {
4696 provider_ids_match(&m.model.provider, provider) && m.model.id.eq_ignore_ascii_case(model_id)
4697 })
4698}
4699
4700async fn apply_thinking_level(
4701 session: Arc<asupersync::sync::Mutex<AgentSession>>,
4702 level: crate::model::ThinkingLevel,
4703) -> Result<()> {
4704 let cx = AgentCx::for_current_or_request();
4705 let level_str = level.to_string();
4706
4707 {
4709 let mut guard = session
4710 .lock(&cx)
4711 .await
4712 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
4713
4714 {
4715 let mut inner_session = guard
4716 .session
4717 .lock(cx.cx())
4718 .await
4719 .map_err(|err| Error::session(format!("inner session lock failed: {err}")))?;
4720 let previous = session_thinking_level(&inner_session);
4721 inner_session.header.thinking_level = Some(level_str.clone());
4722 if previous != Some(level) {
4723 inner_session.append_thinking_level_change(level_str);
4724 }
4725 }
4726 guard.agent.stream_options_mut().thinking_level = Some(level);
4727 } let mut guard = session
4731 .lock(&cx)
4732 .await
4733 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
4734 guard.persist_session().await
4735}
4736
4737async fn apply_thinking_level_for_session(
4738 session: Arc<asupersync::sync::Mutex<AgentSession>>,
4739 level: crate::model::ThinkingLevel,
4740 cx: &AgentCx,
4741) -> Result<()> {
4742 {
4744 let mut guard = session
4745 .lock(cx)
4746 .await
4747 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
4748
4749 let level_str = level.to_string();
4750 {
4751 let mut inner_session = guard
4752 .session
4753 .lock(cx.cx())
4754 .await
4755 .map_err(|err| Error::session(format!("inner session lock failed: {err}")))?;
4756 let previous = session_thinking_level(&inner_session);
4757 inner_session.header.thinking_level = Some(level_str.clone());
4758 if previous != Some(level) {
4759 inner_session.append_thinking_level_change(level_str);
4760 }
4761 }
4762 guard.agent.stream_options_mut().thinking_level = Some(level);
4763 } let mut guard = session
4767 .lock(cx)
4768 .await
4769 .map_err(|err| Error::session(format!("session lock failed: {err}")))?;
4770 guard.persist_session().await
4771}
4772
4773async fn apply_model_change(guard: &mut AgentSession, entry: &ModelEntry) -> Result<()> {
4774 let cx = AgentCx::for_current_or_request();
4775 {
4776 let mut inner_session = guard
4777 .session
4778 .lock(cx.cx())
4779 .await
4780 .map_err(|err| Error::session(format!("inner session lock failed: {err}")))?;
4781 inner_session.header.provider = Some(entry.model.provider.clone());
4782 inner_session.header.model_id = Some(entry.model.id.clone());
4783 inner_session.append_model_change(entry.model.provider.clone(), entry.model.id.clone());
4784 }
4785 guard.persist_session().await
4786}
4787
4788fn fork_messages_from_entries(entries: &[crate::session::SessionEntry]) -> Vec<Value> {
4793 let mut result = Vec::new();
4794
4795 for entry in entries {
4796 let crate::session::SessionEntry::Message(m) = entry else {
4797 continue;
4798 };
4799 let SessionMessage::User { content, .. } = &m.message else {
4800 continue;
4801 };
4802 let entry_id = m.base.id.clone().unwrap_or_default();
4803 let text = extract_user_text(content);
4804 result.push(json!({
4805 "entryId": entry_id,
4806 "text": text,
4807 }));
4808 }
4809
4810 result
4811}
4812
4813fn extract_user_text(content: &crate::model::UserContent) -> Option<String> {
4814 match content {
4815 crate::model::UserContent::Text(text) => Some(text.clone()),
4816 crate::model::UserContent::Blocks(blocks) => blocks.iter().find_map(|b| {
4817 if let ContentBlock::Text(t) = b {
4818 Some(t.text.clone())
4819 } else {
4820 None
4821 }
4822 }),
4823 }
4824}
4825
4826fn available_thinking_levels(entry: &ModelEntry) -> Vec<crate::model::ThinkingLevel> {
4829 use crate::model::ThinkingLevel;
4830 if entry.model.reasoning {
4831 let mut levels = vec![
4832 ThinkingLevel::Off,
4833 ThinkingLevel::Minimal,
4834 ThinkingLevel::Low,
4835 ThinkingLevel::Medium,
4836 ThinkingLevel::High,
4837 ];
4838 if entry.supports_xhigh() {
4839 levels.push(ThinkingLevel::XHigh);
4840 }
4841 levels
4842 } else {
4843 vec![ThinkingLevel::Off]
4844 }
4845}
4846
4847async fn cycle_model_for_rpc(
4850 guard: &mut AgentSession,
4851 options: &RpcOptions,
4852) -> Result<Option<(ModelEntry, crate::model::ThinkingLevel, bool)>> {
4853 let (candidates, is_scoped) = if options.scoped_models.is_empty() {
4854 (options.available_models.clone(), false)
4855 } else {
4856 (
4857 options
4858 .scoped_models
4859 .iter()
4860 .map(|sm| sm.model.clone())
4861 .collect::<Vec<_>>(),
4862 true,
4863 )
4864 };
4865
4866 if candidates.len() <= 1 {
4867 return Ok(None);
4868 }
4869
4870 let cx = AgentCx::for_current_or_request();
4871 let runtime_provider = guard.agent.provider().name().to_string();
4872 let runtime_model_id = guard.agent.provider().model_id().to_string();
4873 let (current_provider, current_model_id) = {
4874 let inner_session = guard
4875 .session
4876 .lock(cx.cx())
4877 .await
4878 .map_err(|err| Error::session(format!("inner session lock failed: {err}")))?;
4879 current_or_runtime_model_entry(
4880 &inner_session,
4881 &runtime_provider,
4882 &runtime_model_id,
4883 options,
4884 )
4885 .map_or_else(
4886 || {
4887 (
4888 inner_session.header.provider.clone(),
4889 inner_session.header.model_id.clone(),
4890 )
4891 },
4892 |entry| {
4893 (
4894 Some(entry.model.provider.clone()),
4895 Some(entry.model.id.clone()),
4896 )
4897 },
4898 )
4899 };
4900
4901 let current_index = candidates.iter().position(|entry| {
4902 current_provider
4903 .as_deref()
4904 .is_some_and(|provider| provider_ids_match(provider, &entry.model.provider))
4905 && current_model_id
4906 .as_deref()
4907 .is_some_and(|model_id| model_id.eq_ignore_ascii_case(&entry.model.id))
4908 });
4909
4910 let next_index = current_index.map_or(0, |idx| (idx + 1) % candidates.len());
4911
4912 let next_entry = candidates[next_index].clone();
4913 let key = resolve_model_key(options.cli_api_key.as_deref(), &options.auth, &next_entry);
4914 if model_requires_configured_credential(&next_entry) && key.is_none() {
4915 return Err(Error::auth(format!(
4916 "Missing credentials for {}/{}",
4917 next_entry.model.provider, next_entry.model.id
4918 )));
4919 }
4920
4921 let provider_impl = crate::providers::create_provider(
4922 &next_entry,
4923 guard
4924 .extensions
4925 .as_ref()
4926 .map(crate::extensions::ExtensionRegion::manager),
4927 )?;
4928 guard.agent.set_provider(provider_impl);
4929
4930 guard.agent.stream_options_mut().api_key.clone_from(&key);
4931 guard
4932 .agent
4933 .stream_options_mut()
4934 .headers
4935 .clone_from(&next_entry.headers);
4936
4937 apply_model_change(guard, &next_entry).await?;
4938
4939 let desired_thinking = if is_scoped {
4940 options.scoped_models[next_index]
4941 .thinking_level
4942 .unwrap_or(crate::model::ThinkingLevel::Off)
4943 } else {
4944 guard
4945 .agent
4946 .stream_options()
4947 .thinking_level
4948 .unwrap_or_default()
4949 };
4950
4951 let next_thinking = next_entry.clamp_thinking_level(desired_thinking);
4952
4953 Ok(Some((next_entry, next_thinking, is_scoped)))
4954}
4955
4956#[cfg(test)]
4957mod tests {
4958 use super::*;
4959 use crate::agent::{Agent, AgentConfig};
4960 use crate::auth::AuthCredential;
4961 use crate::model::{
4962 AssistantMessage, ContentBlock, ImageContent, StopReason, TextContent, ThinkingLevel,
4963 Usage, UserContent, UserMessage,
4964 };
4965 use crate::package_manager::PackageManager;
4966 use crate::provider::{InputType, Model, ModelCost, Provider};
4967 use crate::resources::{ResourceCliOptions, ResourceLoader};
4968 use crate::session::Session;
4969 use crate::tools::ToolRegistry;
4970 use async_trait::async_trait;
4971 use futures::stream;
4972 use serde_json::json;
4973 use std::collections::HashMap;
4974 use std::path::Path;
4975 use std::path::PathBuf;
4976 use std::pin::Pin;
4977 use std::sync::mpsc::{Receiver, TryRecvError};
4978 use std::sync::{Arc, Mutex};
4979 use std::time::{Duration, Instant};
4980
4981 fn dummy_model(id: &str, reasoning: bool) -> Model {
4986 Model {
4987 id: id.to_string(),
4988 name: id.to_string(),
4989 api: "anthropic".to_string(),
4990 provider: "anthropic".to_string(),
4991 base_url: "https://api.anthropic.com".to_string(),
4992 reasoning,
4993 input: vec![InputType::Text],
4994 cost: ModelCost {
4995 input: 3.0,
4996 output: 15.0,
4997 cache_read: 0.3,
4998 cache_write: 3.75,
4999 },
5000 context_window: 200_000,
5001 max_tokens: 8192,
5002 headers: HashMap::new(),
5003 }
5004 }
5005
5006 fn dummy_entry(id: &str, reasoning: bool) -> ModelEntry {
5007 ModelEntry {
5008 model: dummy_model(id, reasoning),
5009 api_key: None,
5010 headers: HashMap::new(),
5011 auth_header: false,
5012 compat: None,
5013 oauth_config: None,
5014 }
5015 }
5016
5017 fn rpc_options_with_models(available_models: Vec<ModelEntry>) -> RpcOptions {
5018 let runtime = asupersync::runtime::RuntimeBuilder::new()
5019 .blocking_threads(1, 1)
5020 .build()
5021 .expect("runtime build");
5022 let runtime_handle = runtime.handle();
5023
5024 let auth_path = tempfile::tempdir()
5025 .expect("tempdir")
5026 .path()
5027 .join("auth.json");
5028 let auth = AuthStorage::load(auth_path).expect("auth load");
5029
5030 RpcOptions {
5031 config: Config::default(),
5032 resources: ResourceLoader::empty(false),
5033 available_models,
5034 scoped_models: Vec::new(),
5035 cli_api_key: None,
5036 auth,
5037 runtime_handle,
5038 }
5039 }
5040
5041 #[derive(Debug)]
5042 struct NoopProvider;
5043
5044 #[async_trait]
5045 #[allow(clippy::unnecessary_literal_bound)]
5046 impl Provider for NoopProvider {
5047 fn name(&self) -> &str {
5048 "test-provider"
5049 }
5050
5051 fn api(&self) -> &str {
5052 "test-api"
5053 }
5054
5055 fn model_id(&self) -> &str {
5056 "test-model"
5057 }
5058
5059 async fn stream(
5060 &self,
5061 _context: &crate::provider::Context<'_>,
5062 _options: &crate::provider::StreamOptions,
5063 ) -> crate::error::Result<
5064 Pin<
5065 Box<
5066 dyn futures::Stream<Item = crate::error::Result<crate::model::StreamEvent>>
5067 + Send,
5068 >,
5069 >,
5070 > {
5071 let message = AssistantMessage {
5072 content: Vec::new(),
5073 api: self.api().to_string(),
5074 provider: self.name().to_string(),
5075 model: self.model_id().to_string(),
5076 usage: Usage::default(),
5077 stop_reason: StopReason::Stop,
5078 error_message: None,
5079 timestamp: 0,
5080 };
5081 Ok(Box::pin(stream::iter(vec![
5082 Ok(crate::model::StreamEvent::Start {
5083 partial: message.clone(),
5084 }),
5085 Ok(crate::model::StreamEvent::Done {
5086 reason: StopReason::Stop,
5087 message,
5088 }),
5089 ])))
5090 }
5091 }
5092
5093 #[derive(Default)]
5094 struct RpcDeadlineProbeState {
5095 calls: std::sync::atomic::AtomicUsize,
5096 observed_deadlines: Mutex<Vec<Option<asupersync::Time>>>,
5097 }
5098
5099 struct RpcDeadlineProbeProvider {
5100 state: Arc<RpcDeadlineProbeState>,
5101 }
5102
5103 impl RpcDeadlineProbeProvider {
5104 fn assistant_message(&self) -> AssistantMessage {
5105 AssistantMessage {
5106 content: Vec::new(),
5107 api: self.api().to_string(),
5108 provider: self.name().to_string(),
5109 model: self.model_id().to_string(),
5110 usage: Usage::default(),
5111 stop_reason: StopReason::Stop,
5112 error_message: None,
5113 timestamp: 0,
5114 }
5115 }
5116 }
5117
5118 #[async_trait]
5119 #[allow(clippy::unnecessary_literal_bound)]
5120 impl Provider for RpcDeadlineProbeProvider {
5121 fn name(&self) -> &str {
5122 "deadline-probe"
5123 }
5124
5125 fn api(&self) -> &str {
5126 "deadline-probe"
5127 }
5128
5129 fn model_id(&self) -> &str {
5130 "deadline-probe-model"
5131 }
5132
5133 async fn stream(
5134 &self,
5135 _context: &crate::provider::Context<'_>,
5136 _options: &crate::provider::StreamOptions,
5137 ) -> crate::error::Result<
5138 Pin<
5139 Box<
5140 dyn futures::Stream<Item = crate::error::Result<crate::model::StreamEvent>>
5141 + Send,
5142 >,
5143 >,
5144 > {
5145 self.state.calls.fetch_add(1, Ordering::SeqCst);
5146 let deadline = asupersync::Cx::current().and_then(|cx| cx.budget().deadline);
5147 self.state
5148 .observed_deadlines
5149 .lock()
5150 .expect("lock rpc deadline probe")
5151 .push(deadline);
5152
5153 let message = self.assistant_message();
5154 Ok(Box::pin(stream::iter(vec![
5155 Ok(crate::model::StreamEvent::Start {
5156 partial: message.clone(),
5157 }),
5158 Ok(crate::model::StreamEvent::Done {
5159 reason: StopReason::Stop,
5160 message,
5161 }),
5162 ])))
5163 }
5164 }
5165
5166 fn build_test_agent_session_with_provider(
5167 session: Session,
5168 provider: Arc<dyn Provider>,
5169 ) -> AgentSession {
5170 let tools = ToolRegistry::new(&[], &std::env::current_dir().expect("current dir"), None);
5171 let agent = crate::agent::Agent::new(provider, tools, crate::agent::AgentConfig::default());
5172 let session = Arc::new(asupersync::sync::Mutex::new(session));
5173 AgentSession::new(
5174 agent,
5175 session,
5176 false,
5177 crate::compaction::ResolvedCompactionSettings::default(),
5178 )
5179 }
5180
5181 fn build_test_agent_session(session: Session) -> AgentSession {
5182 let provider: Arc<dyn Provider> = Arc::new(NoopProvider);
5183 build_test_agent_session_with_provider(session, provider)
5184 }
5185
5186 fn build_test_rpc_options(
5187 handle: &asupersync::runtime::RuntimeHandle,
5188 auth_path: PathBuf,
5189 ) -> RpcOptions {
5190 let auth = AuthStorage::load(auth_path).expect("load auth storage");
5191 RpcOptions {
5192 config: Config::default(),
5193 resources: ResourceLoader::empty(false),
5194 available_models: Vec::new(),
5195 scoped_models: Vec::new(),
5196 cli_api_key: None,
5197 auth,
5198 runtime_handle: handle.clone(),
5199 }
5200 }
5201
5202 async fn load_test_prompt_template_resources(
5203 cwd: &Path,
5204 template_name: &str,
5205 content: &str,
5206 ) -> ResourceLoader {
5207 let prompt_path = cwd.join(format!("{template_name}.md"));
5208 std::fs::write(&prompt_path, content).expect("write prompt template");
5209
5210 let manager = PackageManager::new(cwd.to_path_buf());
5211 let config = crate::config::Config::default();
5212 let cli = ResourceCliOptions {
5213 no_skills: true,
5214 no_prompt_templates: false,
5215 no_extensions: true,
5216 no_themes: true,
5217 skill_paths: Vec::new(),
5218 prompt_paths: vec![prompt_path.to_string_lossy().to_string()],
5219 extension_paths: Vec::new(),
5220 theme_paths: Vec::new(),
5221 };
5222
5223 ResourceLoader::load(&manager, cwd, &config, &cli)
5224 .await
5225 .expect("load prompt template resources")
5226 }
5227
5228 async fn build_queue_state_rpc_fixture(
5229 handle: &asupersync::runtime::RuntimeHandle,
5230 cwd: &Path,
5231 ) -> (AgentSession, RpcOptions) {
5232 let ext_entry_path = cwd.join("queue-state-ext.mjs");
5233 std::fs::write(&ext_entry_path, RPC_QUEUE_STATE_EXTENSION_EXT)
5234 .expect("write extension source");
5235
5236 let mut agent_session = build_test_agent_session(Session::in_memory());
5237 agent_session
5238 .enable_extensions(&[], cwd, None, &[ext_entry_path])
5239 .await
5240 .expect("enable extensions");
5241
5242 let mut options = build_test_rpc_options(handle, cwd.join("auth.json"));
5243 options.resources = load_test_prompt_template_resources(
5244 cwd,
5245 "report-queue-state",
5246 "Prompt template shadow that should not win.\n",
5247 )
5248 .await;
5249
5250 (agent_session, options)
5251 }
5252
5253 async fn recv_line(
5254 rx: &Arc<Mutex<Receiver<String>>>,
5255 label: &str,
5256 ) -> std::result::Result<String, String> {
5257 let start = Instant::now();
5258 loop {
5259 let recv_result = {
5260 let rx = rx.lock().expect("lock rpc output receiver");
5261 rx.try_recv()
5262 };
5263
5264 match recv_result {
5265 Ok(line) => return Ok(line),
5266 Err(TryRecvError::Disconnected) => {
5267 return Err(format!("{label}: output channel disconnected"));
5268 }
5269 Err(TryRecvError::Empty) => {}
5270 }
5271
5272 if start.elapsed() > Duration::from_secs(10) {
5273 return Err(format!("{label}: timed out waiting for output"));
5274 }
5275
5276 asupersync::time::sleep(asupersync::time::wall_now(), Duration::from_millis(5)).await;
5277 }
5278 }
5279
5280 fn parse_response(line: &str) -> Value {
5281 serde_json::from_str(line.trim()).expect("parse JSON response")
5282 }
5283
5284 async fn recv_response(out_rx: &Arc<Mutex<Receiver<String>>>, label: &str) -> Value {
5285 let start = Instant::now();
5286
5287 loop {
5288 let line = recv_line(out_rx, label)
5289 .await
5290 .unwrap_or_else(|err| panic!("{err}"));
5291 let value = parse_response(&line);
5292
5293 match value.get("type").and_then(Value::as_str) {
5294 Some("response") => return value,
5295 Some("agent_end") => {
5296 let has_error = value
5297 .get("error")
5298 .is_some_and(|error| !error.is_null() && error != "");
5299 assert!(
5300 !has_error,
5301 "{label}: unexpected agent_end error while waiting for response: {value}"
5302 );
5303 }
5304 _ => {}
5305 }
5306
5307 assert!(
5308 start.elapsed() <= Duration::from_secs(10),
5309 "{label}: timed out waiting for RPC response"
5310 );
5311 }
5312 }
5313
5314 async fn send_recv(
5315 in_tx: &asupersync::channel::mpsc::Sender<String>,
5316 out_rx: &Arc<Mutex<Receiver<String>>>,
5317 cmd: &str,
5318 label: &str,
5319 ) -> Value {
5320 let cx = asupersync::Cx::for_testing();
5321 in_tx
5322 .send(&cx, cmd.to_string())
5323 .await
5324 .unwrap_or_else(|_| panic!("send {label}"));
5325 recv_response(out_rx, label).await
5326 }
5327
5328 fn assert_ok(resp: &Value, command: &str) {
5329 assert_eq!(resp["type"], "response", "response type for {command}");
5330 assert_eq!(resp["command"], command);
5331 assert_eq!(resp["success"], true, "success for {command}: {resp}");
5332 }
5333
5334 fn assert_err(resp: &Value, command: &str) {
5335 assert_eq!(resp["type"], "response", "response type for {command}");
5336 assert_eq!(resp["command"], command);
5337 assert_eq!(
5338 resp["success"], false,
5339 "expected error for {command}: {resp}"
5340 );
5341 }
5342
5343 async fn recv_ui_request(out_rx: &Arc<Mutex<Receiver<String>>>, label: &str) -> Value {
5344 let start = Instant::now();
5345 loop {
5346 let recv_result = {
5347 let rx = out_rx.lock().expect("lock rpc output receiver");
5348 rx.try_recv()
5349 };
5350
5351 match recv_result {
5352 Ok(line) => {
5353 if let Ok(val) = serde_json::from_str::<Value>(&line) {
5354 if val.get("type").and_then(Value::as_str) == Some("extension_ui_request") {
5355 return val;
5356 }
5357 }
5358 }
5359 Err(TryRecvError::Disconnected) => {
5360 panic!(
5361 "{label}: output channel disconnected while waiting for extension_ui_request"
5362 );
5363 }
5364 Err(TryRecvError::Empty) => {}
5365 }
5366
5367 assert!(
5368 start.elapsed() <= Duration::from_secs(10),
5369 "{label}: timed out waiting for extension_ui_request"
5370 );
5371 asupersync::time::sleep(asupersync::time::wall_now(), Duration::from_millis(5)).await;
5372 }
5373 }
5374
5375 async fn wait_for_custom_message(
5376 in_tx: &asupersync::channel::mpsc::Sender<String>,
5377 out_rx: &Arc<Mutex<Receiver<String>>>,
5378 custom_type: &str,
5379 label: &str,
5380 ) -> Value {
5381 let start = Instant::now();
5382 let mut attempt = 0usize;
5383
5384 loop {
5385 let response = send_recv(
5386 in_tx,
5387 out_rx,
5388 &format!(r#"{{"id":"poll-{attempt}","type":"get_messages"}}"#),
5389 label,
5390 )
5391 .await;
5392 let messages = response["data"]["messages"]
5393 .as_array()
5394 .expect("messages array");
5395 if let Some(message) = messages
5396 .iter()
5397 .find(|message| message["role"] == "custom" && message["customType"] == custom_type)
5398 {
5399 return message.clone();
5400 }
5401
5402 assert!(
5403 start.elapsed() <= Duration::from_secs(10),
5404 "{label}: timed out waiting for custom message"
5405 );
5406 attempt = attempt.saturating_add(1);
5407 asupersync::time::sleep(asupersync::time::wall_now(), Duration::from_millis(10)).await;
5408 }
5409 }
5410
5411 const RPC_BUSY_EXTENSION_COMMAND_EXT: &str = r#"
5412export default function init(pi) {
5413 pi.registerCommand("wait-confirm", {
5414 description: "Block until RPC confirms",
5415 handler: async () => {
5416 const confirmed = await pi.ui("confirm", {
5417 title: "Wait",
5418 message: "Hold the command open"
5419 });
5420 return confirmed ? "confirmed" : "cancelled";
5421 }
5422 });
5423}
5424"#;
5425
5426 const RPC_QUEUE_STATE_EXTENSION_EXT: &str = r#"
5427export default function init(pi) {
5428 pi.registerCommand("report-queue-state", {
5429 description: "Report queue modes visible to extensions",
5430 handler: async () => {
5431 const state = await pi.session("getState", {});
5432 await pi.events("sendMessage", {
5433 message: {
5434 customType: "queue-state",
5435 content: JSON.stringify({
5436 steeringMode: state.steeringMode,
5437 followUpMode: state.followUpMode
5438 }),
5439 display: false
5440 },
5441 options: {
5442 triggerTurn: false
5443 }
5444 });
5445 return "reported";
5446 }
5447 });
5448}
5449"#;
5450
5451 #[test]
5452 fn line_count_from_newline_count_matches_trailing_newline_semantics() {
5453 assert_eq!(line_count_from_newline_count(0, 0, false), 0);
5454 assert_eq!(line_count_from_newline_count(2, 1, true), 1);
5455 assert_eq!(line_count_from_newline_count(1, 0, false), 1);
5456 assert_eq!(line_count_from_newline_count(3, 1, false), 2);
5457 }
5458
5459 #[test]
5464 fn parse_queue_mode_all() {
5465 assert_eq!(parse_queue_mode(Some("all")), Some(QueueMode::All));
5466 }
5467
5468 #[test]
5469 fn parse_queue_mode_one_at_a_time() {
5470 assert_eq!(
5471 parse_queue_mode(Some("one-at-a-time")),
5472 Some(QueueMode::OneAtATime)
5473 );
5474 }
5475
5476 #[test]
5477 fn parse_queue_mode_none_value() {
5478 assert_eq!(parse_queue_mode(None), None);
5479 }
5480
5481 #[test]
5482 fn parse_queue_mode_unknown_returns_none() {
5483 assert_eq!(parse_queue_mode(Some("batch")), None);
5484 assert_eq!(parse_queue_mode(Some("")), None);
5485 }
5486
5487 #[test]
5488 fn parse_queue_mode_trims_whitespace() {
5489 assert_eq!(parse_queue_mode(Some(" all ")), Some(QueueMode::All));
5490 }
5491
5492 #[test]
5493 fn provider_ids_match_accepts_aliases() {
5494 assert!(provider_ids_match("openrouter", "open-router"));
5495 assert!(provider_ids_match("google-gemini-cli", "gemini-cli"));
5496 assert!(!provider_ids_match("openai", "anthropic"));
5497 }
5498
5499 #[test]
5500 fn resolve_model_key_prefers_stored_auth_key_over_inline_entry_key() {
5501 let mut entry = dummy_entry("gpt-4o-mini", true);
5502 entry.model.provider = "openai".to_string();
5503 entry.auth_header = true;
5504 entry.api_key = Some("dummy-test-key-12345".to_string());
5505
5506 let auth_path = tempfile::tempdir()
5507 .expect("tempdir")
5508 .path()
5509 .join("auth.json");
5510 let mut auth = AuthStorage::load(auth_path).expect("auth load");
5511 auth.set(
5512 "openai".to_string(),
5513 AuthCredential::ApiKey {
5514 key: "stored-auth-key".to_string(),
5515 },
5516 );
5517
5518 assert_eq!(
5519 resolve_model_key(None, &auth, &entry).as_deref(),
5520 Some("stored-auth-key")
5521 );
5522 }
5523
5524 #[test]
5525 fn resolve_model_key_ignores_blank_inline_key_and_falls_back_to_auth_storage() {
5526 let mut entry = dummy_entry("gpt-4o-mini", true);
5527 entry.model.provider = "openai".to_string();
5528 entry.auth_header = true;
5529 entry.api_key = Some(" ".to_string()); let auth_path = tempfile::tempdir()
5532 .expect("tempdir")
5533 .path()
5534 .join("auth.json");
5535 let mut auth = AuthStorage::load(auth_path).expect("auth load");
5536 auth.set(
5537 "openai".to_string(),
5538 AuthCredential::ApiKey {
5539 key: "stored-auth-key".to_string(),
5540 },
5541 );
5542
5543 assert_eq!(
5544 resolve_model_key(None, &auth, &entry).as_deref(),
5545 Some("stored-auth-key")
5546 );
5547 }
5548
5549 #[test]
5550 fn resolve_model_key_prefers_cli_override_over_stored_and_inline_keys() {
5551 let mut entry = dummy_entry("gpt-4o-mini", true);
5552 entry.model.provider = "openai".to_string();
5553 entry.auth_header = true;
5554 entry.api_key = Some("inline-key".to_string());
5555
5556 let temp = tempfile::tempdir().expect("tempdir");
5557 let auth_path = temp.path().join("auth.json");
5558 let mut auth = AuthStorage::load(auth_path).expect("auth load");
5559 auth.set(
5560 "openai".to_string(),
5561 AuthCredential::ApiKey {
5562 key: "stored-auth-key".to_string(),
5563 },
5564 );
5565
5566 assert_eq!(
5567 resolve_model_key(Some("cli-override-key"), &auth, &entry).as_deref(),
5568 Some("cli-override-key")
5569 );
5570 }
5571
5572 #[test]
5573 fn unknown_keyless_model_does_not_require_credentials() {
5574 let mut entry = dummy_entry("dev-model", false);
5575 entry.model.provider = "acme-local".to_string();
5576 entry.auth_header = false;
5577 entry.oauth_config = None;
5578
5579 assert!(!model_requires_configured_credential(&entry));
5580 }
5581
5582 #[test]
5583 fn anthropic_model_requires_credentials_even_without_auth_header() {
5584 let mut entry = dummy_entry("claude-sonnet-4-6", true);
5585 entry.model.provider = "anthropic".to_string();
5586 entry.auth_header = false;
5587 entry.oauth_config = None;
5588
5589 assert!(model_requires_configured_credential(&entry));
5590 }
5591
/// `"steer"` parses to `StreamingBehavior::Steer`.
#[test]
fn parse_streaming_behavior_steer() {
    let raw = json!("steer");
    assert!(matches!(
        parse_streaming_behavior(Some(&raw)),
        Ok(Some(StreamingBehavior::Steer))
    ));
}

/// Kebab-case `"follow-up"` parses to `StreamingBehavior::FollowUp`.
#[test]
fn parse_streaming_behavior_follow_up_hyphenated() {
    let raw = json!("follow-up");
    assert!(matches!(
        parse_streaming_behavior(Some(&raw)),
        Ok(Some(StreamingBehavior::FollowUp))
    ));
}

/// camelCase `"followUp"` parses to `StreamingBehavior::FollowUp`.
#[test]
fn parse_streaming_behavior_follow_up_camel() {
    let raw = json!("followUp");
    assert!(matches!(
        parse_streaming_behavior(Some(&raw)),
        Ok(Some(StreamingBehavior::FollowUp))
    ));
}

/// snake_case `"follow_up"` parses to `StreamingBehavior::FollowUp`.
#[test]
fn parse_streaming_behavior_follow_up_snake() {
    let raw = json!("follow_up");
    assert!(matches!(
        parse_streaming_behavior(Some(&raw)),
        Ok(Some(StreamingBehavior::FollowUp))
    ));
}

/// An absent value is not an error; it simply yields no behavior.
#[test]
fn parse_streaming_behavior_none() {
    assert!(matches!(parse_streaming_behavior(None), Ok(None)));
}

/// Unknown strings are rejected.
#[test]
fn parse_streaming_behavior_invalid_string() {
    let raw = json!("invalid");
    let outcome = parse_streaming_behavior(Some(&raw));
    assert!(outcome.is_err());
}

/// Non-string JSON values are rejected.
#[test]
fn parse_streaming_behavior_non_string_errors() {
    let raw = json!(42);
    let outcome = parse_streaming_behavior(Some(&raw));
    assert!(outcome.is_err());
}

/// `streaming_behavior_value` finds the snake_case `streaming_behavior` key,
/// and its value parses like any other spelling.
#[test]
fn streaming_behavior_field_accepts_snake_case_key() {
    let payload = json!({ "streaming_behavior": "follow_up" });
    let field = streaming_behavior_value(&payload).expect("streaming behavior value");
    assert!(matches!(
        parse_streaming_behavior(Some(field)),
        Ok(Some(StreamingBehavior::FollowUp))
    ));
}
5649
/// A payload without the requested key yields `Ok(None)`.
#[test]
fn parse_optional_u32_field_none() {
    let payload = json!({ "type": "compact" });
    assert!(matches!(
        parse_optional_u32_field(&payload, "reserveTokens"),
        Ok(None)
    ));
}

/// An in-range integer is returned as `Some(value)`.
#[test]
fn parse_optional_u32_field_valid() {
    let payload = json!({ "reserveTokens": 8192 });
    assert!(matches!(
        parse_optional_u32_field(&payload, "reserveTokens"),
        Ok(Some(8192))
    ));
}

/// A numeric string is not accepted in place of a JSON number.
#[test]
fn parse_optional_u32_field_invalid_type() {
    let payload = json!({ "reserveTokens": "8192" });
    let outcome = parse_optional_u32_field(&payload, "reserveTokens");
    assert!(outcome.is_err());
}

/// `u32::MAX + 1` does not fit and must be rejected rather than truncated.
#[test]
fn parse_optional_u32_field_too_large() {
    let just_past_max = u64::from(u32::MAX) + 1;
    let payload = json!({ "reserveTokens": just_past_max });
    let outcome = parse_optional_u32_field(&payload, "reserveTokens");
    assert!(outcome.is_err());
}
5679
/// Canonical snake_case command types are returned unchanged.
#[test]
fn normalize_command_type_passthrough() {
    for canonical in ["prompt", "compact"] {
        assert_eq!(normalize_command_type(canonical), canonical);
    }
}

/// Every accepted spelling of the follow-up command maps to `follow_up`.
#[test]
fn normalize_command_type_follow_up_aliases() {
    for alias in ["follow-up", "followUp", "queue-follow-up", "queueFollowUp"] {
        assert_eq!(normalize_command_type(alias), "follow_up");
    }
}

/// kebab-case and camelCase spellings map to the snake_case command type.
#[test]
fn normalize_command_type_kebab_and_camel_aliases() {
    let cases = [
        ("get-state", "get_state"),
        ("getState", "get_state"),
        ("set-model", "set_model"),
        ("setModel", "set_model"),
        ("set-steering-mode", "set_steering_mode"),
        ("setSteeringMode", "set_steering_mode"),
        ("set-follow-up-mode", "set_follow_up_mode"),
        ("setFollowUpMode", "set_follow_up_mode"),
        ("set-auto-compaction", "set_auto_compaction"),
        ("setAutoCompaction", "set_auto_compaction"),
        ("set-auto-retry", "set_auto_retry"),
        ("setAutoRetry", "set_auto_retry"),
    ];
    for (alias, canonical) in cases {
        assert_eq!(normalize_command_type(alias), canonical);
    }
}
5731
/// Without images, the message content is a plain text body.
#[test]
fn build_user_message_text_only() {
    let built = build_user_message("hello", &[]);
    let text = match built {
        Message::User(UserMessage {
            content: UserContent::Text(text),
            ..
        }) => text,
        other => unreachable!("expected different match, got: {other:?}"),
    };
    assert_eq!(text, "hello");
}

/// With images, the content becomes blocks: the text first, then the image.
#[test]
fn build_user_message_with_images() {
    let png = ImageContent {
        data: String::from("base64data"),
        mime_type: String::from("image/png"),
    };
    let built = build_user_message("look at this", &[png]);
    let blocks = match built {
        Message::User(UserMessage {
            content: UserContent::Blocks(blocks),
            ..
        }) => blocks,
        other => unreachable!("expected different match, got: {other:?}"),
    };
    assert!(matches!(
        blocks.as_slice(),
        [ContentBlock::Text(_), ContentBlock::Image(_)]
    ));
}

/// An empty prompt with images produces only the image block, with no empty
/// text block in front of it.
#[test]
fn build_user_message_image_only_omits_empty_text_block() {
    let png = ImageContent {
        data: String::from("base64data"),
        mime_type: String::from("image/png"),
    };
    let built = build_user_message("", &[png]);
    let blocks = match built {
        Message::User(UserMessage {
            content: UserContent::Blocks(blocks),
            ..
        }) => blocks,
        other => unreachable!("expected different match, got: {other:?}"),
    };
    assert!(matches!(blocks.as_slice(), [ContentBlock::Image(_)]));
}
5786
/// A bare `/name` yields the name and an empty argument string.
#[test]
fn parse_extension_command_line_parses_simple_command() {
    let parsed = parse_extension_command_line("/mycommand");
    assert_eq!(parsed, Some((String::from("mycommand"), String::new())));
}

/// Everything after the command name is returned as its argument string.
#[test]
fn parse_extension_command_line_preserves_arguments() {
    let expected = Some(("mycommand".to_owned(), "alpha beta".to_owned()));
    assert_eq!(parse_extension_command_line("/mycommand alpha beta"), expected);
}

/// Input without a leading `/` is not an extension command.
#[test]
fn parse_extension_command_line_requires_leading_slash() {
    assert!(parse_extension_command_line("hello").is_none());
}

/// Leading whitespace is skipped, and a tab separates the name from the argument.
#[test]
fn parse_extension_command_line_accepts_leading_whitespace() {
    let expected = Some(("cmd".to_owned(), "arg".to_owned()));
    assert_eq!(parse_extension_command_line(" /cmd\targ"), expected);
}

/// A slash followed only by whitespace has no command name.
#[test]
fn parse_extension_command_line_rejects_blank_command_name() {
    assert!(parse_extension_command_line("/ ").is_none());
}
5824
/// While a `/wait-confirm` extension command is waiting on a UI `confirm`
/// request, a second extension-command prompt is answered with an error
/// immediately. Answering the pending UI request is then acknowledged, and the
/// server exits cleanly once the input channel closes.
#[test]
fn rpc_busy_extension_command_rejects_follow_on_extension_prompt_without_blocking() {
    let rt = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build test runtime");
    let rt_handle = rt.handle();

    rt.block_on(async move {
        let tmp_dir = tempfile::tempdir().expect("tempdir");
        let workdir = tmp_dir.path().to_path_buf();
        let extension_file = workdir.join("busy-ext.mjs");
        std::fs::write(&extension_file, RPC_BUSY_EXTENSION_COMMAND_EXT)
            .expect("write extension source");

        let mut session = build_test_agent_session(Session::in_memory());
        session
            .enable_extensions(&[], &workdir, None, &[extension_file])
            .await
            .expect("enable extensions");

        let rpc_options = build_test_rpc_options(&rt_handle, workdir.join("auth.json"));
        let (cmd_tx, cmd_rx) = asupersync::channel::mpsc::channel::<String>(16);
        let (reply_tx, reply_rx) = std::sync::mpsc::sync_channel::<String>(1024);
        let replies = Arc::new(Mutex::new(reply_rx));

        let server_task =
            rt_handle.spawn(async move { run(session, rpc_options, cmd_rx, reply_tx).await });

        // First extension command: accepted, then parks on a UI confirm.
        let first_reply = send_recv(
            &cmd_tx,
            &replies,
            r#"{"id":"1","type":"prompt","message":"/wait-confirm"}"#,
            "prompt(wait-confirm:first)",
        )
        .await;
        assert_ok(&first_reply, "prompt");

        let confirm_request = recv_ui_request(&replies, "wait-confirm ui").await;
        assert_eq!(confirm_request["method"], "confirm");
        let pending_ui_id = confirm_request["id"]
            .as_str()
            .map(str::to_owned)
            .expect("ui request id should be a string");

        // Second extension command while the first is still running: rejected.
        let busy_reply = send_recv(
            &cmd_tx,
            &replies,
            r#"{"id":"2","type":"prompt","message":"/wait-confirm"}"#,
            "prompt(wait-confirm:busy)",
        )
        .await;
        assert_err(&busy_reply, "prompt");
        assert_eq!(
            busy_reply["error"],
            "Extension commands are not allowed while agent is streaming"
        );

        // Unblock the first command by answering its UI request.
        let ui_answer = json!({
            "id": "3",
            "type": "extension_ui_response",
            "requestId": pending_ui_id,
            "confirmed": true,
        })
        .to_string();
        let ui_ack = send_recv(&cmd_tx, &replies, &ui_answer, "wait-confirm response").await;
        assert_ok(&ui_ack, "extension_ui_response");

        // Closing the input channel lets the server loop finish.
        drop(cmd_tx);
        let result = server_task.await;
        assert!(result.is_ok(), "rpc server error: {result:?}");
    });
}
5897
/// Queue modes changed over RPC (`set_steering_mode`, plus the camelCase
/// `setFollowUpMode` alias) are visible to extensions: the
/// `/report-queue-state` extension command reports both as `"all"`.
#[test]
fn rpc_queue_mode_updates_reach_extension_session_state() {
    let rt = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build test runtime");
    let rt_handle = rt.handle();

    rt.block_on(async move {
        let tmp_dir = tempfile::tempdir().expect("tempdir");
        let workdir = tmp_dir.path().to_path_buf();
        let extension_file = workdir.join("queue-state-ext.mjs");
        std::fs::write(&extension_file, RPC_QUEUE_STATE_EXTENSION_EXT)
            .expect("write extension source");

        let mut session = build_test_agent_session(Session::in_memory());
        session
            .enable_extensions(&[], &workdir, None, &[extension_file])
            .await
            .expect("enable extensions");

        let rpc_options = build_test_rpc_options(&rt_handle, workdir.join("auth.json"));
        let (cmd_tx, cmd_rx) = asupersync::channel::mpsc::channel::<String>(16);
        let (reply_tx, reply_rx) = std::sync::mpsc::sync_channel::<String>(1024);
        let replies = Arc::new(Mutex::new(reply_rx));

        let server_task =
            rt_handle.spawn(async move { run(session, rpc_options, cmd_rx, reply_tx).await });

        let steering_reply = send_recv(
            &cmd_tx,
            &replies,
            r#"{"id":"1","type":"set_steering_mode","mode":"all"}"#,
            "set_steering_mode(queue-state)",
        )
        .await;
        assert_ok(&steering_reply, "set_steering_mode");

        // camelCase alias; the reply uses the normalized command name.
        let follow_up_reply = send_recv(
            &cmd_tx,
            &replies,
            r#"{"id":"2","type":"setFollowUpMode","mode":"all"}"#,
            "setFollowUpMode(queue-state)",
        )
        .await;
        assert_ok(&follow_up_reply, "set_follow_up_mode");

        let prompt_reply = send_recv(
            &cmd_tx,
            &replies,
            r#"{"id":"3","type":"prompt","message":"/report-queue-state"}"#,
            "prompt(report-queue-state)",
        )
        .await;
        assert_ok(&prompt_reply, "prompt");

        let custom = wait_for_custom_message(&cmd_tx, &replies, "queue-state", "queue-state message")
            .await;
        let state_json = custom["content"]
            .as_str()
            .expect("queue-state content should be string");
        let reported: Value =
            serde_json::from_str(state_json).expect("queue-state content should be json");
        assert_eq!(reported["steeringMode"], "all");
        assert_eq!(reported["followUpMode"], "all");

        drop(cmd_tx);
        let result = server_task.await;
        assert!(result.is_ok(), "rpc server error: {result:?}");
    });
}
5970
/// A `/report-queue-state` prompt must run the extension command of that name.
/// The fixture presumably also registers a same-named prompt template (verify
/// in `build_queue_state_rpc_fixture`). Seeing the extension's `queue-state`
/// custom message with both mode fields proves the extension won.
#[test]
fn rpc_prompt_prefers_extension_command_over_prompt_template_name_collision() {
    let rt = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build test runtime");
    let rt_handle = rt.handle();

    rt.block_on(async move {
        let tmp_dir = tempfile::tempdir().expect("tempdir");
        let workdir = tmp_dir.path().to_path_buf();
        let (session, rpc_options) = build_queue_state_rpc_fixture(&rt_handle, &workdir).await;
        let (cmd_tx, cmd_rx) = asupersync::channel::mpsc::channel::<String>(16);
        let (reply_tx, reply_rx) = std::sync::mpsc::sync_channel::<String>(1024);
        let replies = Arc::new(Mutex::new(reply_rx));

        let server_task =
            rt_handle.spawn(async move { run(session, rpc_options, cmd_rx, reply_tx).await });

        let prompt_reply = send_recv(
            &cmd_tx,
            &replies,
            r#"{"id":"1","type":"prompt","message":"/report-queue-state"}"#,
            "prompt(report-queue-state:shadowed)",
        )
        .await;
        assert_ok(&prompt_reply, "prompt");

        let custom =
            wait_for_custom_message(&cmd_tx, &replies, "queue-state", "queue-state shadowed")
                .await;
        let state_json = custom["content"]
            .as_str()
            .expect("queue-state content should be string");
        let reported: Value =
            serde_json::from_str(state_json).expect("queue-state content should be json");
        assert!(
            reported["steeringMode"].is_string(),
            "extension command should report steeringMode"
        );
        assert!(
            reported["followUpMode"].is_string(),
            "extension command should report followUpMode"
        );

        drop(cmd_tx);
        let result = server_task.await;
        assert!(result.is_ok(), "rpc server error: {result:?}");
    });
}
6021
/// `steer` with a `/report-queue-state` message is rejected as an extension
/// command, even in the fixture where a prompt template may share that name.
#[test]
fn rpc_steer_rejects_extension_command_even_when_prompt_template_name_matches() {
    let rt = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build test runtime");
    let rt_handle = rt.handle();

    rt.block_on(async move {
        let tmp_dir = tempfile::tempdir().expect("tempdir");
        let workdir = tmp_dir.path().to_path_buf();
        let (session, rpc_options) = build_queue_state_rpc_fixture(&rt_handle, &workdir).await;
        let (cmd_tx, cmd_rx) = asupersync::channel::mpsc::channel::<String>(16);
        let (reply_tx, reply_rx) = std::sync::mpsc::sync_channel::<String>(1024);
        let replies = Arc::new(Mutex::new(reply_rx));

        let server_task =
            rt_handle.spawn(async move { run(session, rpc_options, cmd_rx, reply_tx).await });

        let steer_reply = send_recv(
            &cmd_tx,
            &replies,
            r#"{"id":"1","type":"steer","message":"/report-queue-state"}"#,
            "steer(report-queue-state:shadowed)",
        )
        .await;
        assert_err(&steer_reply, "steer");
        assert_eq!(
            steer_reply["error"],
            "Extension commands are not allowed with steer"
        );

        drop(cmd_tx);
        let result = server_task.await;
        assert!(result.is_ok(), "rpc server error: {result:?}");
    });
}
6058
/// `follow_up` with a `/report-queue-state` message is rejected as an
/// extension command, even in the fixture where a prompt template may share
/// that name.
#[test]
fn rpc_follow_up_rejects_extension_command_even_when_prompt_template_name_matches() {
    let rt = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build test runtime");
    let rt_handle = rt.handle();

    rt.block_on(async move {
        let tmp_dir = tempfile::tempdir().expect("tempdir");
        let workdir = tmp_dir.path().to_path_buf();
        let (session, rpc_options) = build_queue_state_rpc_fixture(&rt_handle, &workdir).await;
        let (cmd_tx, cmd_rx) = asupersync::channel::mpsc::channel::<String>(16);
        let (reply_tx, reply_rx) = std::sync::mpsc::sync_channel::<String>(1024);
        let replies = Arc::new(Mutex::new(reply_rx));

        let server_task =
            rt_handle.spawn(async move { run(session, rpc_options, cmd_rx, reply_tx).await });

        let follow_up_reply = send_recv(
            &cmd_tx,
            &replies,
            r#"{"id":"1","type":"follow_up","message":"/report-queue-state"}"#,
            "follow_up(report-queue-state:shadowed)",
        )
        .await;
        assert_err(&follow_up_reply, "follow_up");
        assert_eq!(
            follow_up_reply["error"],
            "Extension commands are not allowed with follow_up"
        );

        drop(cmd_tx);
        let result = server_task.await;
        assert!(result.is_ok(), "rpc server error: {result:?}");
    });
}
6095
/// Queue modes set in the startup `Config` (`steering_mode` /
/// `follow_up_mode` = `"all"`) are reported by the queue-state extension
/// without any RPC mode commands being sent.
#[test]
fn rpc_startup_queue_modes_reach_extension_session_state() {
    let rt = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build test runtime");
    let rt_handle = rt.handle();

    rt.block_on(async move {
        let tmp_dir = tempfile::tempdir().expect("tempdir");
        let workdir = tmp_dir.path().to_path_buf();
        let extension_file = workdir.join("queue-state-ext.mjs");
        std::fs::write(&extension_file, RPC_QUEUE_STATE_EXTENSION_EXT)
            .expect("write extension source");

        let mut session = build_test_agent_session(Session::in_memory());
        session
            .enable_extensions(&[], &workdir, None, &[extension_file])
            .await
            .expect("enable extensions");

        // Configure both queue modes up front instead of via RPC commands.
        let mut rpc_options = build_test_rpc_options(&rt_handle, workdir.join("auth.json"));
        rpc_options.config.steering_mode = Some(String::from("all"));
        rpc_options.config.follow_up_mode = Some(String::from("all"));

        let (cmd_tx, cmd_rx) = asupersync::channel::mpsc::channel::<String>(16);
        let (reply_tx, reply_rx) = std::sync::mpsc::sync_channel::<String>(1024);
        let replies = Arc::new(Mutex::new(reply_rx));

        let server_task =
            rt_handle.spawn(async move { run(session, rpc_options, cmd_rx, reply_tx).await });

        let prompt_reply = send_recv(
            &cmd_tx,
            &replies,
            r#"{"id":"1","type":"prompt","message":"/report-queue-state"}"#,
            "prompt(report-queue-state-startup)",
        )
        .await;
        assert_ok(&prompt_reply, "prompt");

        let custom = wait_for_custom_message(
            &cmd_tx,
            &replies,
            "queue-state",
            "queue-state startup message",
        )
        .await;
        let state_json = custom["content"]
            .as_str()
            .expect("queue-state content should be string");
        let reported: Value =
            serde_json::from_str(state_json).expect("queue-state content should be json");
        assert_eq!(reported["steeringMode"], "all");
        assert_eq!(reported["followUpMode"], "all");

        drop(cmd_tx);
        let result = server_task.await;
        assert!(result.is_ok(), "rpc server error: {result:?}");
    });
}
6157
/// With a free slot, the line is enqueued immediately and fills the
/// 1-capacity channel, so a further `try_send` reports `Full`.
#[test]
fn try_send_line_with_backpressure_enqueues_when_capacity_available() {
    let (sender, _receiver) = mpsc::channel::<String>(1);
    let delivered = try_send_line_with_backpressure(&sender, String::from("line"));
    assert!(delivered);

    let overflow = sender.try_send(String::from("next"));
    assert!(matches!(overflow, Err(mpsc::SendError::Full(_))));
}

/// With the receiver already dropped, the helper returns `false`.
#[test]
fn try_send_line_with_backpressure_stops_when_receiver_closed() {
    let (sender, receiver) = mpsc::channel::<String>(1);
    drop(receiver);
    let delivered = try_send_line_with_backpressure(&sender, String::from("line"));
    assert!(!delivered);
}
6178
/// With the single slot already occupied, the helper keeps waiting until the
/// receiver drains it, then delivers the new line after the seeded one.
#[test]
fn try_send_line_with_backpressure_waits_until_capacity_is_available() {
    let (sender, mut receiver) = mpsc::channel::<String>(1);
    sender
        .try_send("occupied".to_string())
        .expect("seed initial occupied slot");

    let expected = "delayed-line".to_string();
    let expected_for_thread = expected.clone();
    let drain = std::thread::spawn(move || {
        // Delay draining so the sender is likely already waiting on the full channel.
        std::thread::sleep(Duration::from_millis(30));
        let deadline = Instant::now() + Duration::from_millis(300);
        let mut got = Vec::new();
        while got.len() < 2 && Instant::now() < deadline {
            match receiver.try_recv() {
                Ok(msg) => got.push(msg),
                Err(_) => std::thread::sleep(Duration::from_millis(5)),
            }
        }
        assert_eq!(got.len(), 2, "should receive both queued lines");
        assert_eq!(got[0], "occupied");
        assert_eq!(got[1], expected_for_thread);
    });

    assert!(try_send_line_with_backpressure(&sender, expected));
    drop(sender);
    drain.join().expect("receiver thread should finish");
}
6209
/// A 256 KiB payload queued behind the seeded `"busy"` line must arrive intact
/// and after that line once the receiver frees the single channel slot.
#[test]
fn try_send_line_with_backpressure_preserves_large_payload() {
    let (tx, mut rx) = mpsc::channel::<String>(1);
    tx.try_send("busy".to_string())
        .expect("seed initial busy slot");

    let large = "x".repeat(256 * 1024);
    let large_for_thread = large.clone();
    let recv_handle = std::thread::spawn(move || {
        // Delay draining so the sender has to wait on the full channel.
        std::thread::sleep(Duration::from_millis(30));
        let deadline = Instant::now() + Duration::from_millis(500);
        let mut received = Vec::with_capacity(2);
        while received.len() < 2 && Instant::now() < deadline {
            match rx.try_recv() {
                Ok(msg) => received.push(msg),
                Err(_) => std::thread::sleep(Duration::from_millis(5)),
            }
        }
        assert_eq!(received.len(), 2, "should receive busy + payload lines");
        // Ordering: the waiting payload must not overtake the seeded line.
        assert_eq!(received[0], "busy");
        let payload = &received[1];
        assert_eq!(payload.len(), large_for_thread.len());
        assert_eq!(payload, &large_for_thread);
    });

    assert!(try_send_line_with_backpressure(&tx, large));
    drop(tx);
    recv_handle.join().expect("receiver thread should finish");
}
6239
/// If the receiver is dropped while the sender is waiting on a full channel,
/// the helper returns `false`.
#[test]
fn try_send_line_with_backpressure_detects_disconnect_while_waiting() {
    let (sender, receiver) = mpsc::channel::<String>(1);
    sender
        .try_send("busy".to_string())
        .expect("seed initial busy slot");

    let dropper = std::thread::spawn(move || {
        // Give the sender time to start waiting before disconnecting.
        std::thread::sleep(Duration::from_millis(30));
        drop(receiver);
    });

    let delivered =
        try_send_line_with_backpressure(&sender, "line-after-disconnect".to_string());
    assert!(
        !delivered,
        "send should stop after receiver disconnects while channel is full"
    );
    dropper.join().expect("drop thread should finish");
}
6257
/// Pushing 256 lines through a 4-slot channel exercises sustained
/// backpressure. Every line must arrive, in send order, within the deadline.
#[test]
fn try_send_line_with_backpressure_high_volume_preserves_order_and_count() {
    let (tx, mut rx) = mpsc::channel::<String>(4);
    let lines: Vec<String> = (0..256)
        .map(|idx| format!("line-{idx:03}: {}", "x".repeat(64)))
        .collect();
    let expected = lines.clone();

    let recv_handle = std::thread::spawn(move || {
        let deadline = Instant::now() + Duration::from_secs(4);
        let mut received = Vec::with_capacity(expected.len());
        while received.len() < expected.len() && Instant::now() < deadline {
            match rx.try_recv() {
                Ok(msg) => received.push(msg),
                // Back off only when the channel is empty. Sleeping after every
                // successful receive would force at least one timer sleep per
                // line, and `thread::sleep` may overshoot the requested 1 ms,
                // which can push the drain toward the deadline and cause flakes.
                Err(_) => std::thread::sleep(Duration::from_millis(1)),
            }
        }
        assert_eq!(
            received.len(),
            expected.len(),
            "should receive every line under sustained backpressure"
        );
        assert_eq!(received, expected, "line ordering must remain stable");
    });

    for line in lines {
        assert!(try_send_line_with_backpressure(&tx, line));
    }
    drop(tx);
    recv_handle.join().expect("receiver thread should finish");
}
6289
/// A payload without a trailing newline (a truncated JSON object) is delivered
/// exactly as given, after the seeded line.
#[test]
fn try_send_line_with_backpressure_preserves_partial_line_without_newline() {
    let (sender, mut receiver) = mpsc::channel::<String>(1);
    sender
        .try_send("busy".to_string())
        .expect("seed initial busy slot");

    let partial_json = "{\"type\":\"prompt\",\"message\":\"tail-fragment-ascii\"".to_string();
    let expected = partial_json.clone();

    let drain = std::thread::spawn(move || {
        std::thread::sleep(Duration::from_millis(25));
        let seeded = receiver
            .try_recv()
            .expect("seeded line should be available");
        assert_eq!(seeded, "busy");

        let deadline = Instant::now() + Duration::from_millis(500);
        let tail = loop {
            match receiver.try_recv() {
                Ok(line) => break line,
                Err(_) => {
                    assert!(
                        Instant::now() < deadline,
                        "partial payload should be available"
                    );
                    std::thread::sleep(Duration::from_millis(5));
                }
            }
        };
        assert_eq!(tail, expected);
    });

    assert!(try_send_line_with_backpressure(&sender, partial_json));
    drop(sender);
    drain.join().expect("receiver thread should finish");
}
6321
/// `pending_count` is the sum of the steering and follow-up queue lengths.
#[test]
fn snapshot_pending_count() {
    let snapshot = RpcStateSnapshot {
        steering_mode: QueueMode::All,
        follow_up_mode: QueueMode::OneAtATime,
        steering_count: 3,
        follow_up_count: 7,
        auto_retry_enabled: true,
        auto_compaction_enabled: false,
    };
    let pending = snapshot.pending_count();
    assert_eq!(pending, 3 + 7);
}

/// Two empty queues report zero pending messages.
#[test]
fn snapshot_pending_count_zero() {
    let snapshot = RpcStateSnapshot {
        steering_mode: QueueMode::All,
        follow_up_mode: QueueMode::All,
        steering_count: 0,
        follow_up_count: 0,
        auto_retry_enabled: false,
        auto_compaction_enabled: false,
    };
    let pending = snapshot.pending_count();
    assert_eq!(pending, 0);
}
6351
/// `MAX_RPC_PENDING_MESSAGES` caps steering and follow-up messages together:
/// once steering alone fills it, a follow-up is refused with a session error.
#[test]
fn shared_state_blocks_follow_up_when_steering_queue_reaches_total_cap() {
    let config = Config::default();
    let mut shared = RpcSharedState::new(&config);

    (0..MAX_RPC_PENDING_MESSAGES).for_each(|idx| {
        shared
            .push_steering(build_user_message(&format!("steer-{idx}"), &[]))
            .expect("steering enqueue within total cap");
    });

    let overflow = shared.push_follow_up(build_user_message("follow-up-overflow", &[]));
    let err = overflow.expect_err("follow-up enqueue should respect total pending cap");
    assert!(matches!(err, Error::Session(_)));
    // The rejected message must not have been counted.
    assert_eq!(shared.pending_count(), MAX_RPC_PENDING_MESSAGES);
}

/// Mirror case: a full follow-up queue also blocks new steering messages.
#[test]
fn shared_state_blocks_steering_when_follow_up_queue_reaches_total_cap() {
    let config = Config::default();
    let mut shared = RpcSharedState::new(&config);

    (0..MAX_RPC_PENDING_MESSAGES).for_each(|idx| {
        shared
            .push_follow_up(build_user_message(&format!("follow-up-{idx}"), &[]))
            .expect("follow-up enqueue within total cap");
    });

    let overflow = shared.push_steering(build_user_message("steer-overflow", &[]));
    let err = overflow.expect_err("steering enqueue should respect total pending cap");
    assert!(matches!(err, Error::Session(_)));
    assert_eq!(shared.pending_count(), MAX_RPC_PENDING_MESSAGES);
}
6387
/// Attempts 0 and 1 both wait exactly the configured base delay.
#[test]
fn retry_delay_first_attempt_is_base() {
    let config = Config::default();
    let base = config.retry_base_delay_ms();
    for attempt in [0, 1] {
        assert_eq!(retry_delay_ms(&config, attempt), base);
    }
}

/// From attempt 2 onward, the delay doubles each attempt.
#[test]
fn retry_delay_doubles_each_attempt() {
    let config = Config::default();
    let base = config.retry_base_delay_ms();
    for (attempt, expected) in [(2, base * 2), (3, base * 4)] {
        assert_eq!(retry_delay_ms(&config, attempt), expected);
    }
}

/// A large attempt number is capped at the configured maximum delay.
#[test]
fn retry_delay_capped_at_max() {
    let config = Config::default();
    let ceiling = config.retry_max_delay_ms();
    assert_eq!(retry_delay_ms(&config, 30), ceiling);
}

/// `u32::MAX` attempts must not overflow past the maximum delay.
#[test]
fn retry_delay_saturates_on_overflow() {
    let config = Config::default();
    let ceiling = config.retry_max_delay_ms();
    assert!(retry_delay_ms(&config, u32::MAX) <= ceiling);
}
6425
/// Usage well below `window - reserve` does not trigger compaction.
#[test]
fn auto_compact_below_threshold() {
    let (used, window, reserve) = (50_000, 200_000, 40_000);
    assert!(!should_auto_compact(used, window, reserve));
}

/// Usage above `window - reserve` triggers compaction.
#[test]
fn auto_compact_above_threshold() {
    let (used, window, reserve) = (170_000, 200_000, 40_000);
    assert!(should_auto_compact(used, window, reserve));
}

/// Usage exactly at `window - reserve` does not yet trigger compaction.
#[test]
fn auto_compact_exact_threshold() {
    let (used, window, reserve) = (160_000, 200_000, 40_000);
    assert!(!should_auto_compact(used, window, reserve));
}

/// When the reserve exceeds the window, even a single token triggers compaction.
#[test]
fn auto_compact_reserve_exceeds_window() {
    let (used, window, reserve) = (1, 100, 200);
    assert!(should_auto_compact(used, window, reserve));
}

/// An empty context does not trigger compaction.
#[test]
fn auto_compact_zero_tokens() {
    let (used, window, reserve) = (0, 200_000, 40_000);
    assert!(!should_auto_compact(used, window, reserve));
}
6458
/// A content block wrapped as `{"0": {...}}` is unwrapped in place, and the
/// `"0"` key disappears.
#[test]
fn flatten_content_blocks_unwraps_inner_0() {
    let mut value = json!({
        "content": [
            {"0": {"type": "text", "text": "hello"}}
        ]
    });
    rpc_flatten_content_blocks(&mut value);

    let blocks = value["content"].as_array().unwrap();
    let first = &blocks[0];
    assert_eq!(first["type"], "text");
    assert_eq!(first["text"], "hello");
    assert!(first.get("0").is_none());
}

/// Blocks that are already flat pass through unchanged.
#[test]
fn flatten_content_blocks_preserves_non_wrapped() {
    let mut value = json!({
        "content": [
            {"type": "text", "text": "already flat"}
        ]
    });
    rpc_flatten_content_blocks(&mut value);

    let blocks = value["content"].as_array().unwrap();
    let first = &blocks[0];
    assert_eq!(first["type"], "text");
    assert_eq!(first["text"], "already flat");
}
6489
/// A message without a `content` field has nothing to flatten and is left as is.
#[test]
fn flatten_content_blocks_no_content_field() {
    let mut value = json!({"role": "assistant"});
    rpc_flatten_content_blocks(&mut value);
    assert_eq!(value, json!({"role": "assistant"}));
}

/// A non-object value (here a bare string) must neither panic nor be modified.
/// The assertion makes a silent mutation fail the test, not just a panic.
#[test]
fn flatten_content_blocks_non_object() {
    let mut value = json!("just a string");
    rpc_flatten_content_blocks(&mut value);
    assert_eq!(value, json!("just a string"));
}
6502
/// When a block has both its own keys and a `"0"` wrapper, the block's own
/// `type` wins. A key that exists only in the wrapper (`extra`) still ends up
/// on the block.
#[test]
fn flatten_content_blocks_existing_keys_not_overwritten() {
    let mut value = json!({
        "content": [
            {"type": "existing", "0": {"type": "inner", "extra": "data"}}
        ]
    });
    rpc_flatten_content_blocks(&mut value);

    let blocks = value["content"].as_array().unwrap();
    let merged = &blocks[0];
    assert_eq!(merged["type"], "existing");
    assert_eq!(merged["extra"], "data");
}
6518
/// `None` input yields an empty image list.
#[test]
fn parse_prompt_images_none() {
    let parsed = parse_prompt_images(None).unwrap();
    assert!(parsed.is_empty());
}

/// An empty JSON array yields an empty image list.
#[test]
fn parse_prompt_images_empty_array() {
    let raw = json!([]);
    let parsed = parse_prompt_images(Some(&raw)).unwrap();
    assert!(parsed.is_empty());
}

/// A base64 image source is read into `mime_type` (from `mediaType`) and `data`.
#[test]
fn parse_prompt_images_valid() {
    let raw = json!([{
        "type": "image",
        "source": {
            "type": "base64",
            "mediaType": "image/png",
            "data": "iVBORw0KGgo="
        }
    }]);
    let parsed = parse_prompt_images(Some(&raw)).unwrap();
    assert_eq!(parsed.len(), 1);
    let only = &parsed[0];
    assert_eq!(only.mime_type, "image/png");
    assert_eq!(only.data, "iVBORw0KGgo=");
}

/// Entries whose `type` is not `"image"` are skipped, not treated as errors.
#[test]
fn parse_prompt_images_skips_non_image_type() {
    let raw = json!([{
        "type": "text",
        "text": "hello"
    }]);
    let parsed = parse_prompt_images(Some(&raw)).unwrap();
    assert!(parsed.is_empty());
}

/// Image entries with a non-base64 source (e.g. `url`) are skipped.
#[test]
fn parse_prompt_images_skips_non_base64_source() {
    let raw = json!([{
        "type": "image",
        "source": {
            "type": "url",
            "url": "https://example.com/img.png"
        }
    }]);
    let parsed = parse_prompt_images(Some(&raw)).unwrap();
    assert!(parsed.is_empty());
}

/// A value that is not an array is an error.
#[test]
fn parse_prompt_images_not_array_errors() {
    let raw = json!("not-an-array");
    let outcome = parse_prompt_images(Some(&raw));
    assert!(outcome.is_err());
}

/// Multiple valid images are returned in input order.
#[test]
fn parse_prompt_images_multiple_valid() {
    let raw = json!([
        {
            "type": "image",
            "source": {"type": "base64", "mediaType": "image/jpeg", "data": "abc"}
        },
        {
            "type": "image",
            "source": {"type": "base64", "mediaType": "image/webp", "data": "def"}
        }
    ]);
    let parsed = parse_prompt_images(Some(&raw)).unwrap();
    let mime_types: Vec<&str> = parsed.iter().map(|img| img.mime_type.as_str()).collect();
    assert_eq!(mime_types, ["image/jpeg", "image/webp"]);
}
6598
/// Plain-text content is returned as is.
#[test]
fn extract_user_text_from_text_content() {
    let content = UserContent::Text(String::from("hello world"));
    assert_eq!(extract_user_text(&content).as_deref(), Some("hello world"));
}

/// For block content, the text block is found even when it follows an image.
#[test]
fn extract_user_text_from_blocks() {
    let image = ContentBlock::Image(ImageContent {
        data: String::new(),
        mime_type: String::from("image/png"),
    });
    let text = ContentBlock::Text(TextContent::new("found it"));
    let content = UserContent::Blocks(vec![image, text]);
    assert_eq!(extract_user_text(&content).as_deref(), Some("found it"));
}

/// Block content with no text block yields `None`.
#[test]
fn extract_user_text_blocks_no_text() {
    let image = ContentBlock::Image(ImageContent {
        data: String::new(),
        mime_type: String::from("image/png"),
    });
    let content = UserContent::Blocks(vec![image]);
    assert!(extract_user_text(&content).is_none());
}
6629
/// Each level's name, aliases and numeric shorthand parse to that level.
#[test]
fn parse_thinking_level_all_variants() {
    let cases = [
        ("off", ThinkingLevel::Off),
        ("none", ThinkingLevel::Off),
        ("0", ThinkingLevel::Off),
        ("minimal", ThinkingLevel::Minimal),
        ("min", ThinkingLevel::Minimal),
        ("low", ThinkingLevel::Low),
        ("1", ThinkingLevel::Low),
        ("medium", ThinkingLevel::Medium),
        ("med", ThinkingLevel::Medium),
        ("2", ThinkingLevel::Medium),
        ("high", ThinkingLevel::High),
        ("3", ThinkingLevel::High),
        ("xhigh", ThinkingLevel::XHigh),
        ("4", ThinkingLevel::XHigh),
    ];
    for (raw, level) in cases {
        assert_eq!(parse_thinking_level(raw).unwrap(), level);
    }
}

/// Parsing ignores case and surrounding whitespace.
#[test]
fn parse_thinking_level_case_insensitive() {
    let cases = [
        ("HIGH", ThinkingLevel::High),
        ("Medium", ThinkingLevel::Medium),
        (" Off ", ThinkingLevel::Off),
    ];
    for (raw, level) in cases {
        assert_eq!(parse_thinking_level(raw).unwrap(), level);
    }
}

/// Unknown words, an empty string, and out-of-range digits are rejected.
#[test]
fn parse_thinking_level_invalid() {
    for raw in ["invalid", "", "5"] {
        assert!(parse_thinking_level(raw).is_err());
    }
}
6674
/// The listed GPT-5.x model ids report `xhigh` support.
#[test]
fn supports_xhigh_known_models() {
    let ids = [
        "gpt-5.1-codex-max",
        "gpt-5.2",
        "gpt-5.4",
        "gpt-5.2-codex",
        "gpt-5.3-codex",
    ];
    for id in ids {
        assert!(dummy_entry(id, true).supports_xhigh());
    }
}

/// Other model ids, including an empty one, do not report `xhigh` support.
#[test]
fn supports_xhigh_unknown_models() {
    for id in ["claude-opus-4-6", "gpt-4o", ""] {
        assert!(!dummy_entry(id, true).supports_xhigh());
    }
}
6694
6695 #[test]
6696 fn clamp_thinking_non_reasoning_model() {
6697 let entry = dummy_entry("claude-3-haiku", false);
6698 assert_eq!(
6699 entry.clamp_thinking_level(ThinkingLevel::High),
6700 ThinkingLevel::Off
6701 );
6702 }
6703
6704 #[test]
6705 fn clamp_thinking_xhigh_without_support() {
6706 let entry = dummy_entry("claude-opus-4-6", true);
6707 assert_eq!(
6708 entry.clamp_thinking_level(ThinkingLevel::XHigh),
6709 ThinkingLevel::High
6710 );
6711 }
6712
6713 #[test]
6714 fn clamp_thinking_xhigh_with_support() {
6715 let entry = dummy_entry("gpt-5.2", true);
6716 assert_eq!(
6717 entry.clamp_thinking_level(ThinkingLevel::XHigh),
6718 ThinkingLevel::XHigh
6719 );
6720 }
6721
6722 #[test]
6723 fn clamp_thinking_normal_level_passthrough() {
6724 let entry = dummy_entry("claude-opus-4-6", true);
6725 assert_eq!(
6726 entry.clamp_thinking_level(ThinkingLevel::Medium),
6727 ThinkingLevel::Medium
6728 );
6729 }
6730
6731 #[test]
6736 fn available_thinking_levels_non_reasoning() {
6737 let entry = dummy_entry("gpt-4o-mini", false);
6738 let levels = available_thinking_levels(&entry);
6739 assert_eq!(levels, vec![ThinkingLevel::Off]);
6740 }
6741
6742 #[test]
6743 fn available_thinking_levels_reasoning_no_xhigh() {
6744 let entry = dummy_entry("claude-opus-4-6", true);
6745 let levels = available_thinking_levels(&entry);
6746 assert_eq!(
6747 levels,
6748 vec![
6749 ThinkingLevel::Off,
6750 ThinkingLevel::Minimal,
6751 ThinkingLevel::Low,
6752 ThinkingLevel::Medium,
6753 ThinkingLevel::High,
6754 ]
6755 );
6756 }
6757
6758 #[test]
6759 fn available_thinking_levels_reasoning_with_xhigh() {
6760 let entry = dummy_entry("gpt-5.2", true);
6761 let levels = available_thinking_levels(&entry);
6762 assert_eq!(
6763 levels,
6764 vec![
6765 ThinkingLevel::Off,
6766 ThinkingLevel::Minimal,
6767 ThinkingLevel::Low,
6768 ThinkingLevel::Medium,
6769 ThinkingLevel::High,
6770 ThinkingLevel::XHigh,
6771 ]
6772 );
6773 }
6774
6775 #[test]
6780 fn rpc_model_from_entry_basic() {
6781 let entry = dummy_entry("claude-opus-4-6", true);
6782 let value = rpc_model_from_entry(&entry);
6783 assert_eq!(value["id"], "claude-opus-4-6");
6784 assert_eq!(value["name"], "claude-opus-4-6");
6785 assert_eq!(value["provider"], "anthropic");
6786 assert_eq!(value["reasoning"], true);
6787 assert_eq!(value["contextWindow"], 200_000);
6788 assert_eq!(value["maxTokens"], 8192);
6789 }
6790
6791 #[test]
6792 fn rpc_model_from_entry_input_types() {
6793 let mut entry = dummy_entry("gpt-4o", false);
6794 entry.model.input = vec![InputType::Text, InputType::Image];
6795 let value = rpc_model_from_entry(&entry);
6796 let input = value["input"].as_array().unwrap();
6797 assert_eq!(input.len(), 2);
6798 assert_eq!(input[0], "text");
6799 assert_eq!(input[1], "image");
6800 }
6801
6802 #[test]
6803 fn rpc_model_from_entry_cost_present() {
6804 let entry = dummy_entry("test-model", false);
6805 let value = rpc_model_from_entry(&entry);
6806 assert!(value.get("cost").is_some());
6807 let cost = &value["cost"];
6808 assert_eq!(cost["input"], 3.0);
6809 assert_eq!(cost["output"], 15.0);
6810 }
6811
6812 #[test]
6813 fn current_model_entry_matches_provider_alias_and_model_case() {
6814 let mut model = dummy_entry("gpt-4o-mini", true);
6815 model.model.provider = "openrouter".to_string();
6816 let options = rpc_options_with_models(vec![model]);
6817
6818 let mut session = Session::in_memory();
6819 session.header.provider = Some("open-router".to_string());
6820 session.header.model_id = Some("GPT-4O-MINI".to_string());
6821
6822 let resolved = current_model_entry(&session, &options).expect("resolve aliased model");
6823 assert_eq!(resolved.model.provider, "openrouter");
6824 assert_eq!(resolved.model.id, "gpt-4o-mini");
6825 }
6826
6827 #[test]
6828 fn current_or_runtime_model_entry_falls_back_when_header_is_unresolved() {
6829 let mut runtime = dummy_entry("test-model", false);
6830 runtime.model.provider = "test-provider".to_string();
6831 let options = rpc_options_with_models(vec![runtime]);
6832
6833 let mut session = Session::in_memory();
6834 session.header.provider = Some("missing-provider".to_string());
6835 session.header.model_id = Some("missing-model".to_string());
6836
6837 let resolved =
6838 current_or_runtime_model_entry(&session, "test-provider", "test-model", &options)
6839 .expect("resolve runtime fallback");
6840 assert_eq!(resolved.model.provider, "test-provider");
6841 assert_eq!(resolved.model.id, "test-model");
6842 assert_eq!(
6843 resolved.clamp_thinking_level(ThinkingLevel::High),
6844 ThinkingLevel::Off
6845 );
6846 }
6847
6848 #[test]
6849 fn cycle_model_for_rpc_does_not_mutate_provider_when_credentials_are_missing() {
6850 let runtime = asupersync::runtime::RuntimeBuilder::new()
6851 .blocking_threads(1, 1)
6852 .build()
6853 .expect("runtime build");
6854
6855 runtime.block_on(async move {
6856 let mut current = dummy_entry("gpt-4o-mini", true);
6857 current.model.provider = "openai".to_string();
6858 current.model.api = "openai-completions".to_string();
6859 current.model.base_url = "https://api.openai.com/v1".to_string();
6860 current.auth_header = true;
6861
6862 let next = ModelEntry {
6863 model: Model {
6864 id: "cloud-model".to_string(),
6865 name: "cloud-model".to_string(),
6866 api: "openai-completions".to_string(),
6867 provider: "acme-remote".to_string(),
6868 base_url: "https://example.invalid/v1".to_string(),
6869 reasoning: true,
6870 input: vec![InputType::Text],
6871 cost: ModelCost {
6872 input: 0.0,
6873 output: 0.0,
6874 cache_read: 0.0,
6875 cache_write: 0.0,
6876 },
6877 context_window: 128_000,
6878 max_tokens: 8_192,
6879 headers: HashMap::new(),
6880 },
6881 api_key: None,
6882 headers: HashMap::new(),
6883 auth_header: true,
6884 compat: None,
6885 oauth_config: None,
6886 };
6887
6888 let provider =
6889 crate::providers::create_provider(¤t, None).expect("create current provider");
6890 let agent = Agent::new(
6891 provider,
6892 ToolRegistry::new(&[], Path::new("."), None),
6893 AgentConfig::default(),
6894 );
6895
6896 let mut session = Session::in_memory();
6897 session.header.provider = Some(current.model.provider.clone());
6898 session.header.model_id = Some(current.model.id.clone());
6899 let mut agent_session = AgentSession::new(
6900 agent,
6901 Arc::new(asupersync::sync::Mutex::new(session)),
6902 false,
6903 crate::compaction::ResolvedCompactionSettings::default(),
6904 );
6905
6906 let options = rpc_options_with_models(vec![current.clone(), next]);
6907 let err = cycle_model_for_rpc(&mut agent_session, &options)
6908 .await
6909 .expect_err("missing credentials should abort model cycling");
6910 assert!(
6911 err.to_string().contains("Missing credentials"),
6912 "unexpected error: {err}"
6913 );
6914 assert_eq!(
6915 agent_session.agent.provider().name(),
6916 current.model.provider
6917 );
6918 assert_eq!(agent_session.agent.provider().model_id(), current.model.id);
6919
6920 let cx = AgentCx::for_request();
6921 let session = agent_session
6922 .session
6923 .lock(cx.cx())
6924 .await
6925 .expect("session lock");
6926 assert_eq!(
6927 session.header.provider.as_deref(),
6928 Some(current.model.provider.as_str())
6929 );
6930 assert_eq!(
6931 session.header.model_id.as_deref(),
6932 Some(current.model.id.as_str())
6933 );
6934 });
6935 }
6936
6937 #[test]
6938 fn cycle_model_for_rpc_uses_runtime_model_when_header_is_missing() {
6939 let runtime = asupersync::runtime::RuntimeBuilder::new()
6940 .blocking_threads(1, 1)
6941 .build()
6942 .expect("runtime build");
6943
6944 runtime.block_on(async move {
6945 let mut current = dummy_entry("test-model", false);
6946 current.model.provider = "test-provider".to_string();
6947 current.model.api = "test-api".to_string();
6948 current.model.base_url = "https://example.test/v1".to_string();
6949
6950 let mut next = dummy_entry("after-runtime", true);
6951 next.api_key = Some("inline-next-key".to_string());
6952 let options = rpc_options_with_models(vec![current, next.clone()]);
6953
6954 let mut agent_session = build_test_agent_session(Session::in_memory());
6955 let result = cycle_model_for_rpc(&mut agent_session, &options)
6956 .await
6957 .expect("cycle should succeed")
6958 .expect("should choose next model");
6959
6960 assert_eq!(result.0.model.provider, next.model.provider);
6961 assert_eq!(result.0.model.id, next.model.id);
6962 assert_eq!(agent_session.agent.provider().name(), next.model.provider);
6963 assert_eq!(agent_session.agent.provider().model_id(), next.model.id);
6964
6965 let cx = AgentCx::for_request();
6966 let session = agent_session
6967 .session
6968 .lock(cx.cx())
6969 .await
6970 .expect("session lock");
6971 assert_eq!(
6972 session.header.provider.as_deref(),
6973 Some(next.model.provider.as_str())
6974 );
6975 assert_eq!(
6976 session.header.model_id.as_deref(),
6977 Some(next.model.id.as_str())
6978 );
6979 });
6980 }
6981
6982 #[test]
6983 fn cycle_model_for_rpc_uses_cli_api_key_override_for_remote_model() {
6984 let runtime = asupersync::runtime::RuntimeBuilder::new()
6985 .blocking_threads(1, 1)
6986 .build()
6987 .expect("runtime build");
6988
6989 runtime.block_on(async move {
6990 let mut current = dummy_entry("test-model", false);
6991 current.model.provider = "test-provider".to_string();
6992 current.model.api = "test-api".to_string();
6993 current.model.base_url = "https://example.test/v1".to_string();
6994 current.auth_header = false;
6995 current.api_key = None;
6996
6997 let mut next = dummy_entry("cloud-model", true);
6998 next.model.provider = "openai".to_string();
6999 next.model.api = "openai-completions".to_string();
7000 next.model.base_url = "https://api.openai.com/v1".to_string();
7001 next.auth_header = true;
7002 next.api_key = None;
7003
7004 let mut options = rpc_options_with_models(vec![current, next.clone()]);
7005 options.cli_api_key = Some("cli-override-key".to_string());
7006
7007 let mut agent_session = build_test_agent_session(Session::in_memory());
7008 let result = cycle_model_for_rpc(&mut agent_session, &options)
7009 .await
7010 .expect("cycle should succeed")
7011 .expect("should choose next model");
7012
7013 assert_eq!(result.0.model.provider, next.model.provider);
7014 assert_eq!(result.0.model.id, next.model.id);
7015 assert_eq!(
7016 agent_session.agent.stream_options().api_key.as_deref(),
7017 Some("cli-override-key")
7018 );
7019 });
7020 }
7021
    /// `apply_thinking_level` must honor a cancelled ambient context instead of
    /// blocking on the inner session lock. The lock is held by another guard
    /// while the helper runs. After the failure, neither the session header nor
    /// the agent's stream options may have been updated.
    #[test]
    fn apply_thinking_level_inherits_cancelled_context_when_session_lock_is_held() {
        let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let agent_session = build_test_agent_session(Session::in_memory());
            let session_handle = Arc::new(asupersync::sync::Mutex::new(agent_session));
            // Grab a handle to the inner `Session` mutex; the outer guard is released
            // at the end of this block so the helper can still take the outer lock.
            let inner_session_handle = {
                let guard = session_handle.lock(&AgentCx::for_request()).await.expect("session lock");
                Arc::clone(&guard.session)
            };
            // Hold the inner session lock so the helper cannot acquire it.
            let hold_cx = AgentCx::for_request();
            let held_guard = inner_session_handle
                .lock(hold_cx.cx())
                .await
                .expect("session lock");

            // Install a cancelled context as the current one; the helper is expected
            // to inherit it when it has no explicit context of its own.
            let ambient_cx = asupersync::Cx::for_testing();
            ambient_cx.set_cancel_requested(true);
            let _current = asupersync::Cx::set_current(Some(ambient_cx));

            // Bound the call with a 100ms timeout: a helper that ignored cancellation
            // would block on the held lock and trip the timeout instead of erroring.
            let err = {
                let apply = apply_thinking_level(Arc::clone(&session_handle), ThinkingLevel::High);
                futures::pin_mut!(apply);
                let inner = asupersync::time::timeout(
                    asupersync::time::wall_now(),
                    Duration::from_millis(100),
                    apply,
                )
                .await;
                let outcome =
                    inner.expect("cancelled thinking helper should finish before timeout");
                outcome.expect_err("lock acquisition should honor inherited cancellation")
            };
            assert!(
                err.to_string().contains("inner session lock failed"),
                "unexpected error: {err}"
            );

            drop(held_guard);

            // Verify with fresh, uncancelled contexts that nothing was persisted.
            let verify_cx = AgentCx::for_request();
            let session_arc = {
                let guard = session_handle.lock(&verify_cx).await.expect("session lock");
                Arc::clone(&guard.session)
            };
            let session = session_arc.lock(verify_cx.cx()).await.expect("session lock");
            assert!(session.header.thinking_level.is_none());
            drop(session);
            // The agent's stream options must not have picked up the new level either.
            let agent_thinking_level = {
                let guard = session_handle.lock(&verify_cx).await.expect("session lock");
                guard.agent.stream_options().thinking_level
            };
            assert!(agent_thinking_level.is_none());
        });
    }
7080
7081 #[test]
7082 fn apply_thinking_level_canonicalizes_header_without_duplicate_history() {
7083 let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
7084 .build()
7085 .expect("runtime build");
7086
7087 runtime.block_on(async {
7088 let mut session = Session::in_memory();
7089 session.header.thinking_level = Some("HIGH".to_string());
7090 let agent_session = build_test_agent_session(session);
7091 let session_handle = Arc::new(asupersync::sync::Mutex::new(agent_session));
7092
7093 apply_thinking_level(Arc::clone(&session_handle), ThinkingLevel::High)
7094 .await
7095 .expect("apply thinking level");
7096
7097 let verify_cx = AgentCx::for_request();
7098 let guard = session_handle
7099 .lock(&verify_cx)
7100 .await
7101 .expect("session lock");
7102 let session = guard
7103 .session
7104 .lock(verify_cx.cx())
7105 .await
7106 .expect("session lock");
7107 assert_eq!(session.header.thinking_level.as_deref(), Some("high"));
7108 let thinking_changes = session
7109 .entries
7110 .iter()
7111 .filter(|entry| {
7112 matches!(entry, crate::session::SessionEntry::ThinkingLevelChange(_))
7113 })
7114 .count();
7115 assert_eq!(thinking_changes, 0);
7116 drop(session);
7117
7118 assert_eq!(
7119 guard.agent.stream_options().thinking_level,
7120 Some(ThinkingLevel::High)
7121 );
7122 });
7123 }
7124
    /// Switching via `set_model` to a non-reasoning model must persist a clamped
    /// `"off"` thinking level in the session header and record exactly one
    /// `ThinkingLevelChange` entry. Per the test name, this must happen even when
    /// the runtime thinking level is already off.
    #[test]
    fn rpc_set_model_persists_clamped_thinking_header_even_when_runtime_is_already_off() {
        let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        let handle = runtime.handle();

        runtime.block_on(async move {
            // Target model: a local, non-reasoning Ollama entry.
            let mut next = dummy_entry("llama3.2", false);
            next.model.provider = "ollama".to_string();
            next.model.api = "openai-completions".to_string();
            next.model.base_url = "http://127.0.0.1:11434/v1".to_string();

            // `temp` stays bound for the whole test so the auth directory survives.
            let temp = tempfile::tempdir().expect("tempdir");
            let auth_path = temp.path().join("auth.json");
            let mut options = build_test_rpc_options(&handle, auth_path);
            options.available_models = vec![next.clone()];

            let agent_session = build_test_agent_session(Session::in_memory());
            // Keep a handle to the session so it can be inspected after `run` exits.
            let session_handle = Arc::clone(&agent_session.session);
            let (in_tx, in_rx) = asupersync::channel::mpsc::channel::<String>(8);
            let (out_tx, out_rx) = std::sync::mpsc::sync_channel::<String>(1024);
            let out_rx = Arc::new(Mutex::new(out_rx));

            let server =
                handle.spawn(async move { run(agent_session, options, in_rx, out_tx).await });

            let response = send_recv(
                &in_tx,
                &out_rx,
                r#"{"id":"1","type":"set_model","provider":"ollama","modelId":"llama3.2"}"#,
                "set_model(sync-thinking)",
            )
            .await;
            assert_ok(&response, "set_model");

            // Closing the input channel lets the server loop finish.
            drop(in_tx);
            let result = server.await;
            assert!(result.is_ok(), "rpc server error: {result:?}");

            let verify_cx = AgentCx::for_request();
            let session = session_handle
                .lock(verify_cx.cx())
                .await
                .expect("session lock");
            assert_eq!(session.header.provider.as_deref(), Some("ollama"));
            assert_eq!(session.header.model_id.as_deref(), Some("llama3.2"));
            assert_eq!(session.header.thinking_level.as_deref(), Some("off"));
            // Exactly one thinking-level change is recorded for the clamp.
            let thinking_changes = session
                .entries
                .iter()
                .filter(|entry| {
                    matches!(entry, crate::session::SessionEntry::ThinkingLevelChange(_))
                })
                .count();
            assert_eq!(thinking_changes, 1);
        });
    }
7183
7184 #[test]
7185 fn rpc_prompt_command_inherits_deadline_from_run() {
7186 let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
7187 .build()
7188 .expect("runtime build");
7189 let runtime_handle = runtime.handle();
7190
7191 runtime.block_on(async move {
7192 let state = Arc::new(RpcDeadlineProbeState::default());
7193 let provider: Arc<dyn Provider> = Arc::new(RpcDeadlineProbeProvider {
7194 state: Arc::clone(&state),
7195 });
7196 let agent_session =
7197 build_test_agent_session_with_provider(Session::in_memory(), provider);
7198
7199 let auth_path = tempfile::tempdir()
7200 .expect("tempdir")
7201 .path()
7202 .join("auth.json");
7203 let options = build_test_rpc_options(&runtime_handle, auth_path);
7204
7205 let (in_tx, in_rx) = asupersync::channel::mpsc::channel::<String>(16);
7206 let (out_tx, out_rx) = std::sync::mpsc::sync_channel::<String>(1024);
7207 let out_rx = Arc::new(Mutex::new(out_rx));
7208
7209 let expected_deadline = asupersync::time::wall_now() + Duration::from_secs(30);
7210 let ambient_cx = AgentCx::for_request_with_budget(asupersync::Budget {
7211 deadline: Some(expected_deadline),
7212 ..asupersync::Budget::INFINITE
7213 });
7214 let _current = asupersync::Cx::set_current(Some(ambient_cx.cx().clone()));
7215
7216 let client_out_rx = Arc::clone(&out_rx);
7217 let client = async move {
7218 let response = send_recv(
7219 &in_tx,
7220 &client_out_rx,
7221 r#"{"id":"1","type":"prompt","message":"deadline please"}"#,
7222 "prompt(deadline)",
7223 )
7224 .await;
7225 assert_eq!(response["command"], "prompt");
7226 assert_eq!(
7227 response["success"], true,
7228 "prompt should succeed under inherited deadline: {response}"
7229 );
7230
7231 let mut saw_agent_end = false;
7234 for _ in 0..50 {
7235 let Ok(msg) = client_out_rx
7236 .lock()
7237 .expect("lock rx")
7238 .recv_timeout(Duration::from_secs(5))
7239 else {
7240 break;
7241 };
7242 if let Ok(json) = serde_json::from_str::<Value>(&msg) {
7243 if json["type"] == "agent_end" {
7244 saw_agent_end = true;
7245 break;
7246 }
7247 }
7248 }
7249 assert!(saw_agent_end, "expected agent_end event before dropping");
7250
7251 drop(in_tx);
7252 };
7253
7254 let (server_result, ()) =
7255 futures::future::join(run(agent_session, options, in_rx, out_tx), client).await;
7256 assert!(server_result.is_ok(), "rpc server error: {server_result:?}");
7257 assert_eq!(state.calls.load(Ordering::SeqCst), 1);
7258 let deadlines = state
7259 .observed_deadlines
7260 .lock()
7261 .expect("lock rpc deadline probe")
7262 .clone();
7263 assert_eq!(deadlines.as_slice(), &[Some(expected_deadline)]);
7264 });
7265 }
7266
7267 #[test]
7268 fn cycle_model_for_rpc_inherits_cancelled_context_when_session_lock_is_held() {
7269 let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
7270 .build()
7271 .expect("runtime build");
7272
7273 runtime.block_on(async {
7274 let current = dummy_entry("current-model", true);
7275 let mut next = dummy_entry("next-model", true);
7276 next.api_key = Some("inline-next-key".to_string());
7277
7278 let provider =
7279 crate::providers::create_provider(¤t, None).expect("create current provider");
7280 let agent = Agent::new(
7281 provider,
7282 ToolRegistry::new(&[], Path::new("."), None),
7283 AgentConfig::default(),
7284 );
7285
7286 let mut session = Session::in_memory();
7287 session.header.provider = Some(current.model.provider.clone());
7288 session.header.model_id = Some(current.model.id.clone());
7289 let mut agent_session = AgentSession::new(
7290 agent,
7291 Arc::new(asupersync::sync::Mutex::new(session)),
7292 false,
7293 crate::compaction::ResolvedCompactionSettings::default(),
7294 );
7295 let options = rpc_options_with_models(vec![current.clone(), next]);
7296 let session_handle = Arc::clone(&agent_session.session);
7297
7298 let hold_cx = AgentCx::for_request();
7299 let held_guard = session_handle
7300 .lock(hold_cx.cx())
7301 .await
7302 .expect("session lock");
7303
7304 let ambient_cx = asupersync::Cx::for_testing();
7305 ambient_cx.set_cancel_requested(true);
7306 let _current = asupersync::Cx::set_current(Some(ambient_cx));
7307
7308 let err = {
7309 let cycle = cycle_model_for_rpc(&mut agent_session, &options);
7310 futures::pin_mut!(cycle);
7311 let inner = asupersync::time::timeout(
7312 asupersync::time::wall_now(),
7313 Duration::from_millis(100),
7314 cycle,
7315 )
7316 .await;
7317 let outcome = inner.expect("cancelled cycle helper should finish before timeout");
7318 outcome.expect_err("lock acquisition should honor inherited cancellation")
7319 };
7320 assert!(
7321 err.to_string().contains("inner session lock failed"),
7322 "unexpected error: {err}"
7323 );
7324
7325 drop(held_guard);
7326
7327 assert_eq!(
7328 agent_session.agent.provider().name(),
7329 current.model.provider
7330 );
7331 assert_eq!(agent_session.agent.provider().model_id(), current.model.id);
7332
7333 let verify_cx = AgentCx::for_request();
7334 let session = agent_session
7335 .session
7336 .lock(verify_cx.cx())
7337 .await
7338 .expect("session lock");
7339 assert_eq!(
7340 session.header.provider.as_deref(),
7341 Some(current.model.provider.as_str())
7342 );
7343 assert_eq!(
7344 session.header.model_id.as_deref(),
7345 Some(current.model.id.as_str())
7346 );
7347 });
7348 }
7349
7350 #[test]
7351 fn session_state_resolves_model_for_provider_alias() {
7352 let mut model = dummy_entry("gpt-4o-mini", true);
7353 model.model.provider = "openrouter".to_string();
7354 let options = rpc_options_with_models(vec![model]);
7355
7356 let mut session = Session::in_memory();
7357 session.header.provider = Some("open-router".to_string());
7358 session.header.model_id = Some("gpt-4o-mini".to_string());
7359
7360 let snapshot = RpcStateSnapshot {
7361 steering_count: 0,
7362 follow_up_count: 0,
7363 steering_mode: QueueMode::OneAtATime,
7364 follow_up_mode: QueueMode::OneAtATime,
7365 auto_compaction_enabled: false,
7366 auto_retry_enabled: false,
7367 };
7368
7369 let state = session_state(&session, &options, &snapshot, false, false);
7370 assert_eq!(state["model"]["provider"], "openrouter");
7371 assert_eq!(state["model"]["id"], "gpt-4o-mini");
7372 }
7373
7374 #[test]
7379 fn error_hints_value_produces_expected_shape() {
7380 let error = Error::validation("test error");
7381 let value = error_hints_value(&error);
7382 assert!(value.get("summary").is_some());
7383 assert!(value.get("hints").is_some());
7384 assert!(value.get("contextFields").is_some());
7385 assert!(value["hints"].is_array());
7386 }
7387
7388 #[test]
7393 fn parse_ui_response_id_empty_string() {
7394 let value = json!({"requestId": ""});
7395 assert_eq!(rpc_parse_extension_ui_response_id(&value), None);
7396 }
7397
7398 #[test]
7399 fn parse_ui_response_id_whitespace_only() {
7400 let value = json!({"requestId": " "});
7401 assert_eq!(rpc_parse_extension_ui_response_id(&value), None);
7402 }
7403
7404 #[test]
7405 fn parse_ui_response_id_trims() {
7406 let value = json!({"requestId": " req-1 "});
7407 assert_eq!(
7408 rpc_parse_extension_ui_response_id(&value),
7409 Some("req-1".to_string())
7410 );
7411 }
7412
7413 #[test]
7414 fn parse_ui_response_id_prefers_request_id_over_id_alias() {
7415 let value = json!({"requestId": "req-1", "id": "legacy-id"});
7416 assert_eq!(
7417 rpc_parse_extension_ui_response_id(&value),
7418 Some("req-1".to_string())
7419 );
7420 }
7421
7422 #[test]
7423 fn parse_ui_response_id_falls_back_to_id_alias_when_request_id_not_string() {
7424 let value = json!({"requestId": 123, "id": "legacy-id"});
7425 assert_eq!(
7426 rpc_parse_extension_ui_response_id(&value),
7427 Some("legacy-id".to_string())
7428 );
7429 }
7430
7431 #[test]
7432 fn parse_ui_response_id_falls_back_to_id_alias_when_request_id_blank() {
7433 let value = json!({"requestId": "", "id": "legacy-id"});
7434 assert_eq!(
7435 rpc_parse_extension_ui_response_id(&value),
7436 Some("legacy-id".to_string())
7437 );
7438 }
7439
7440 #[test]
7441 fn parse_ui_response_id_falls_back_to_id_alias_when_request_id_whitespace() {
7442 let value = json!({"requestId": " ", "id": "legacy-id"});
7443 assert_eq!(
7444 rpc_parse_extension_ui_response_id(&value),
7445 Some("legacy-id".to_string())
7446 );
7447 }
7448
7449 #[test]
7450 fn parse_ui_response_id_neither_field() {
7451 let value = json!({"type": "something"});
7452 assert_eq!(rpc_parse_extension_ui_response_id(&value), None);
7453 }
7454
7455 #[test]
7460 fn parse_editor_response_requires_string() {
7461 let active = ExtensionUiRequest::new("req-1", "editor", json!({"title": "t"}));
7462 let ok = json!({"type": "extension_ui_response", "requestId": "req-1", "value": "code"});
7463 assert!(rpc_parse_extension_ui_response(&ok, &active).is_ok());
7464
7465 let bad = json!({"type": "extension_ui_response", "requestId": "req-1", "value": 42});
7466 assert!(rpc_parse_extension_ui_response(&bad, &active).is_err());
7467 }
7468
7469 #[test]
7470 fn parse_notify_response_returns_ack() {
7471 let active = ExtensionUiRequest::new("req-1", "notify", json!({"title": "t"}));
7472 let val = json!({"type": "extension_ui_response", "requestId": "req-1"});
7473 let resp = rpc_parse_extension_ui_response(&val, &active).unwrap();
7474 assert!(!resp.cancelled);
7475 }
7476
7477 #[test]
7478 fn parse_unknown_method_errors() {
7479 let active = ExtensionUiRequest::new("req-1", "unknown_method", json!({}));
7480 let val = json!({"type": "extension_ui_response", "requestId": "req-1"});
7481 assert!(rpc_parse_extension_ui_response(&val, &active).is_err());
7482 }
7483
7484 #[test]
7485 fn parse_select_with_object_options() {
7486 let active = ExtensionUiRequest::new(
7487 "req-1",
7488 "select",
7489 json!({"title": "pick", "options": [{"label": "Alpha", "value": "a"}, {"label": "Beta"}]}),
7490 );
7491 let val_a = json!({"type": "extension_ui_response", "requestId": "req-1", "value": "a"});
7493 let resp = rpc_parse_extension_ui_response(&val_a, &active).unwrap();
7494 assert_eq!(resp.value, Some(json!("a")));
7495
7496 let val_b = json!({"type": "extension_ui_response", "requestId": "req-1", "value": "Beta"});
7498 let resp = rpc_parse_extension_ui_response(&val_b, &active).unwrap();
7499 assert_eq!(resp.value, Some(json!("Beta")));
7500 }
7501}