1use crate::auth::AuthStorage;
16use tracing::warn;
17use crate::compaction::{self, ResolvedCompactionSettings};
18use crate::compaction_worker::{CompactionQuota, CompactionWorkerState};
19use crate::error::{Error, Result};
20use crate::extension_events::{
21 BeforeAgentStartOutcome, InputEventOutcome, SessionBeforeCompactOutcome,
22 apply_before_agent_start_response, apply_input_event_response,
23 apply_session_before_compact_response,
24};
25use crate::extension_tools::collect_extension_tool_wrappers;
26use crate::extensions::{
27 EXTENSION_EVENT_TIMEOUT_MS, ExtensionDeliverAs, ExtensionEventName, ExtensionHostActions,
28 ExtensionLoadSpec, ExtensionManager, ExtensionPolicy, ExtensionRegion, ExtensionRuntimeHandle,
29 ExtensionSendMessage, ExtensionSendUserMessage, JsExtensionLoadSpec, JsExtensionRuntimeHandle,
30 NativeRustExtensionLoadSpec, NativeRustExtensionRuntimeHandle, RepairPolicyMode,
31 resolve_extension_load_spec,
32};
33#[cfg(feature = "wasm-host")]
34use crate::extensions::{WasmExtensionHost, WasmExtensionLoadSpec};
35use crate::extensions_js::{PiJsRuntimeConfig, RepairMode};
36use crate::model::{
37 AssistantMessage, AssistantMessageEvent, ContentBlock, CustomMessage, ImageContent, Message,
38 StopReason, StreamEvent, TextContent, ThinkingContent, ToolCall, ToolResultMessage, Usage,
39 UserContent, UserMessage,
40};
41use crate::models::{ModelEntry, ModelRegistry, model_requires_configured_credential};
42use crate::provider::{Context, Provider, StreamOptions, ToolDef};
43use crate::session::{AutosaveFlushTrigger, Session, SessionHandle};
44use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
45use asupersync::runtime::{Runtime, RuntimeBuilder, RuntimeHandle};
46use asupersync::sync::{Mutex, Notify};
47use async_trait::async_trait;
48use chrono::Utc;
49use futures::FutureExt;
50use futures::StreamExt;
51use futures::future::BoxFuture;
52use futures::stream;
53use serde::Serialize;
54use serde_json::{Value, json};
55use std::borrow::Cow;
56use std::collections::VecDeque;
57use std::sync::Arc;
58use std::sync::Mutex as StdMutex;
59use std::sync::atomic::{AtomicBool, Ordering};
60
/// Concurrency limit for tool execution. Not referenced in this chunk —
/// presumably caps how many tool calls run in parallel; TODO confirm.
const MAX_CONCURRENT_TOOLS: usize = 8;
/// Capacity of the steering queue; `MessageQueue::push` drops the oldest entry
/// once this many messages are queued.
const MAX_STEERING_QUEUE_SIZE: usize = 100;
/// Capacity of the follow-up queue; `MessageQueue::push` drops the oldest entry
/// once this many messages are queued.
const MAX_FOLLOW_UP_QUEUE_SIZE: usize = 100;
/// History cap enforced by `Agent::add_message` (oldest messages are dropped).
const MAX_AGENT_MESSAGES: usize = 10_000;
68
/// Static configuration for an [`Agent`].
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// Optional system prompt placed into the provider context by `build_context`.
    pub system_prompt: Option<String>,

    /// Maximum number of assistant turns that request tool calls within a single
    /// run; exceeding it ends the run with an error stop reason.
    pub max_tool_iterations: usize,

    /// Options forwarded to `Provider::stream`. Its `session_id` (or `""`) also
    /// tags the lifecycle events emitted by the agent loop.
    pub stream_options: StreamOptions,

    /// When true, image content is stripped from the messages sent to the provider.
    pub block_images: bool,

    /// Not read in this chunk — presumably makes extension hook failures abort
    /// instead of failing open; TODO confirm.
    pub fail_closed_hooks: bool,
}
91
92impl Default for AgentConfig {
93 fn default() -> Self {
94 Self {
95 system_prompt: None,
96 max_tool_iterations: 50,
97 stream_options: StreamOptions::default(),
98 block_images: false,
99 fail_closed_hooks: false,
100 }
101 }
102}
103
/// Asynchronous source of messages polled by the agent loop. Each call yields a
/// batch that is appended to the steering or follow-up queue it was registered for.
pub type MessageFetcher = Arc<dyn Fn() -> BoxFuture<'static, Vec<Message>> + Send + Sync + 'static>;

/// Shared callback that receives every [`AgentEvent`] emitted during a run.
type AgentEventHandler = Arc<dyn Fn(AgentEvent) + Send + Sync + 'static>;
108
/// How many queued messages a single drain of a queue delivers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueueMode {
    /// Deliver everything currently queued in one drain.
    All,
    /// Deliver only the oldest queued message per drain.
    OneAtATime,
}

impl QueueMode {
    /// Stable kebab-case label for this mode.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::OneAtATime => "one-at-a-time",
            Self::All => "all",
        }
    }
}
123
/// Origin of a piece of user input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InputSource {
    /// Typed by a user in an interactive session.
    Interactive,
    /// Received over an RPC interface.
    Rpc,
    /// Injected by an extension.
    Extension,
}

impl InputSource {
    /// Stable lowercase label for this source.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Extension => "extension",
            Self::Rpc => "rpc",
            Self::Interactive => "interactive",
        }
    }
}
140
/// Selects which of the two `MessageQueue` deques an operation targets.
#[derive(Debug, Clone, Copy)]
enum QueueKind {
    /// Messages injected between turns of the current run.
    Steering,
    /// Messages that start a new round of turns once the current one finishes.
    FollowUp,
}
146
/// A message waiting in a `MessageQueue`, tagged with bookkeeping metadata.
#[derive(Debug, Clone)]
struct QueuedMessage {
    /// Sequence number assigned by `MessageQueue::push` (shared across both queues).
    seq: u64,
    /// Wall-clock enqueue time in Unix milliseconds (`Utc::now().timestamp_millis()`).
    enqueued_at: i64,
    /// The queued message itself.
    message: Message,
}
153
/// Bounded FIFO queues for steering and follow-up messages.
#[derive(Debug)]
struct MessageQueue {
    /// Pending steering messages, oldest first.
    steering: VecDeque<QueuedMessage>,
    /// Pending follow-up messages, oldest first.
    follow_up: VecDeque<QueuedMessage>,
    /// Drain mode applied to `steering`.
    steering_mode: QueueMode,
    /// Drain mode applied to `follow_up`.
    follow_up_mode: QueueMode,
    /// Next sequence number to hand out; saturates at `u64::MAX`.
    next_seq: u64,
}
162
163impl MessageQueue {
164 const fn new(steering_mode: QueueMode, follow_up_mode: QueueMode) -> Self {
165 Self {
166 steering: VecDeque::new(),
167 follow_up: VecDeque::new(),
168 steering_mode,
169 follow_up_mode,
170 next_seq: 0,
171 }
172 }
173
174 const fn set_modes(&mut self, steering_mode: QueueMode, follow_up_mode: QueueMode) {
175 self.steering_mode = steering_mode;
176 self.follow_up_mode = follow_up_mode;
177 }
178
179 fn pending_count(&self) -> usize {
180 self.steering.len() + self.follow_up.len()
181 }
182
183 fn push(&mut self, kind: QueueKind, message: Message) -> u64 {
184 let seq = self.next_seq;
185 self.next_seq = self.next_seq.saturating_add(1);
186 let entry = QueuedMessage {
187 seq,
188 enqueued_at: Utc::now().timestamp_millis(),
189 message,
190 };
191 match kind {
192 QueueKind::Steering => {
193 if self.steering.len() >= MAX_STEERING_QUEUE_SIZE {
194 tracing::warn!(
195 "Steering queue full ({} messages), dropping oldest message",
196 MAX_STEERING_QUEUE_SIZE
197 );
198 self.steering.pop_front();
199 }
200 self.steering.push_back(entry);
201 }
202 QueueKind::FollowUp => {
203 if self.follow_up.len() >= MAX_FOLLOW_UP_QUEUE_SIZE {
204 tracing::warn!(
205 "Follow-up queue full ({} messages), dropping oldest message",
206 MAX_FOLLOW_UP_QUEUE_SIZE
207 );
208 self.follow_up.pop_front();
209 }
210 self.follow_up.push_back(entry);
211 }
212 }
213 seq
214 }
215
216 fn push_steering(&mut self, message: Message) -> u64 {
217 self.push(QueueKind::Steering, message)
218 }
219
220 fn push_follow_up(&mut self, message: Message) -> u64 {
221 self.push(QueueKind::FollowUp, message)
222 }
223
224 fn pop_steering(&mut self) -> Vec<Message> {
225 self.pop_kind(QueueKind::Steering)
226 }
227
228 fn pop_follow_up(&mut self) -> Vec<Message> {
229 self.pop_kind(QueueKind::FollowUp)
230 }
231
232 fn pop_kind(&mut self, kind: QueueKind) -> Vec<Message> {
233 let (queue, mode) = match kind {
234 QueueKind::Steering => (&mut self.steering, self.steering_mode),
235 QueueKind::FollowUp => (&mut self.follow_up, self.follow_up_mode),
236 };
237
238 match mode {
239 QueueMode::All => queue.drain(..).map(|entry| entry.message).collect(),
240 QueueMode::OneAtATime => queue
241 .pop_front()
242 .into_iter()
243 .map(|entry| entry.message)
244 .collect(),
245 }
246 }
247}
248
/// Event emitted by the agent loop to the `on_event` callback; the `AgentStart`,
/// `AgentEnd`, `TurnStart` and `TurnEnd` variants are also forwarded to extensions.
///
/// Serialized as an internally tagged JSON object: the snake_case variant name is
/// stored under `"type"` (e.g. `"agent_start"`), and several fields are renamed
/// to camelCase via per-field `serde(rename)` attributes.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AgentEvent {
    /// A run has started; emitted once at the beginning of `run_loop`.
    AgentStart {
        #[serde(rename = "sessionId")]
        session_id: Arc<str>,
    },
    /// A run has finished. `messages` are the messages produced during the run.
    /// `error` carries the failure text when the run ended abnormally; it can
    /// still be `None` if a terminal assistant message had no error text.
    AgentEnd {
        #[serde(rename = "sessionId")]
        session_id: Arc<str>,
        messages: Vec<Message>,
        #[serde(skip_serializing_if = "Option::is_none")]
        error: Option<String>,
    },
    /// A turn (one assistant response plus its tool calls) is starting.
    /// `timestamp` is Unix milliseconds.
    TurnStart {
        #[serde(rename = "sessionId")]
        session_id: Arc<str>,
        #[serde(rename = "turnIndex")]
        turn_index: usize,
        timestamp: i64,
    },
    /// A turn finished. `message` is the turn's assistant message and
    /// `tool_results` are the tool-result messages produced during the turn.
    TurnEnd {
        #[serde(rename = "sessionId")]
        session_id: Arc<str>,
        #[serde(rename = "turnIndex")]
        turn_index: usize,
        message: Message,
        #[serde(rename = "toolResults")]
        tool_results: Vec<Message>,
    },
    /// A message is being added to the conversation.
    MessageStart { message: Message },
    /// Incremental update for an in-progress (streaming) assistant message.
    MessageUpdate {
        message: Message,
        #[serde(rename = "assistantMessageEvent")]
        assistant_message_event: AssistantMessageEvent,
    },
    /// A message has been fully added to the conversation.
    MessageEnd { message: Message },
    /// Tool execution started. Emitted outside this chunk; meaning inferred from
    /// field names — TODO confirm.
    ToolExecutionStart {
        #[serde(rename = "toolCallId")]
        tool_call_id: String,
        #[serde(rename = "toolName")]
        tool_name: String,
        args: serde_json::Value,
    },
    /// Partial progress of a running tool (emitted outside this chunk).
    ToolExecutionUpdate {
        #[serde(rename = "toolCallId")]
        tool_call_id: String,
        #[serde(rename = "toolName")]
        tool_name: String,
        args: serde_json::Value,
        #[serde(rename = "partialResult")]
        partial_result: ToolOutput,
    },
    /// Tool execution finished with `result`; `is_error` flags a failed call
    /// (emitted outside this chunk).
    ToolExecutionEnd {
        #[serde(rename = "toolCallId")]
        tool_call_id: String,
        #[serde(rename = "toolName")]
        tool_name: String,
        result: ToolOutput,
        #[serde(rename = "isError")]
        is_error: bool,
    },
    /// Automatic history compaction started (emitted outside this chunk).
    AutoCompactionStart { reason: String },
    /// Automatic history compaction finished (emitted outside this chunk).
    AutoCompactionEnd {
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<serde_json::Value>,
        aborted: bool,
        #[serde(rename = "willRetry")]
        will_retry: bool,
        #[serde(rename = "errorMessage", skip_serializing_if = "Option::is_none")]
        error_message: Option<String>,
    },
    /// An automatic retry is scheduled after `delay_ms` (emitted outside this chunk).
    AutoRetryStart {
        attempt: u32,
        #[serde(rename = "maxAttempts")]
        max_attempts: u32,
        #[serde(rename = "delayMs")]
        delay_ms: u64,
        #[serde(rename = "errorMessage")]
        error_message: String,
    },
    /// Automatic retrying ended (emitted outside this chunk).
    AutoRetryEnd {
        success: bool,
        attempt: u32,
        #[serde(rename = "finalError", skip_serializing_if = "Option::is_none")]
        final_error: Option<String>,
    },
    /// An extension failed while handling `event` (emitted outside this chunk).
    ExtensionError {
        #[serde(rename = "extensionId", skip_serializing_if = "Option::is_none")]
        extension_id: Option<String>,
        event: String,
        error: String,
    },
}
363
/// Sender side of an abort flag; calling [`AbortHandle::abort`] trips every
/// linked [`AbortSignal`]. Cloning shares the same flag.
#[derive(Debug, Clone)]
pub struct AbortHandle {
    inner: Arc<AbortSignalInner>,
}

/// Receiver side of an abort flag, polled or awaited by the agent loop.
#[derive(Debug, Clone)]
pub struct AbortSignal {
    inner: Arc<AbortSignalInner>,
}

/// State shared by a handle/signal pair.
#[derive(Debug)]
struct AbortSignalInner {
    /// Set once by `AbortHandle::abort`; never reset.
    aborted: AtomicBool,
    /// Wakes tasks blocked in `AbortSignal::wait`.
    notify: Notify,
}
385
386impl AbortHandle {
387 #[must_use]
389 pub fn new() -> (Self, AbortSignal) {
390 let inner = Arc::new(AbortSignalInner {
391 aborted: AtomicBool::new(false),
392 notify: Notify::new(),
393 });
394 (
395 Self {
396 inner: Arc::clone(&inner),
397 },
398 AbortSignal { inner },
399 )
400 }
401
402 pub fn abort(&self) {
404 if !self.inner.aborted.swap(true, Ordering::SeqCst) {
405 self.inner.notify.notify_waiters();
406 }
407 }
408}
409
410impl AbortSignal {
411 #[must_use]
413 pub fn is_aborted(&self) -> bool {
414 self.inner.aborted.load(Ordering::SeqCst)
415 }
416
417 pub async fn wait(&self) {
418 if self.is_aborted() {
419 return;
420 }
421
422 loop {
423 self.inner.notify.notified().await;
424 if self.is_aborted() {
425 return;
426 }
427 }
428 }
429}
430
/// Conversational agent that streams assistant responses from a provider,
/// executes requested tools, and manages message history and input queues.
pub struct Agent {
    /// Model backend used for streaming responses.
    provider: Arc<dyn Provider>,

    /// Tools offered to the model.
    tools: ToolRegistry,

    /// Static agent configuration.
    config: AgentConfig,

    /// Optional extension manager that receives lifecycle and context events.
    extensions: Option<ExtensionManager>,

    /// Full conversation history (including hidden custom messages).
    messages: Vec<Message>,

    /// Sources polled whenever steering messages are drained.
    steering_fetchers: Vec<MessageFetcher>,

    /// Sources polled whenever follow-up messages are drained.
    follow_up_fetchers: Vec<MessageFetcher>,

    /// Locally queued steering and follow-up messages.
    message_queue: MessageQueue,

    /// Lazily built tool definitions; reset to `None` by `extend_tools`.
    cached_tool_defs: Option<Vec<ToolDef>>,
}
460
461impl Agent {
462 pub fn new(provider: Arc<dyn Provider>, tools: ToolRegistry, config: AgentConfig) -> Self {
464 Self {
465 provider,
466 tools,
467 config,
468 extensions: None,
469 messages: Vec::new(),
470 steering_fetchers: Vec::new(),
471 follow_up_fetchers: Vec::new(),
472 message_queue: MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime),
473 cached_tool_defs: None,
474 }
475 }
476
477 #[must_use]
479 pub fn messages(&self) -> &[Message] {
480 &self.messages
481 }
482
483 pub fn clear_messages(&mut self) {
485 self.messages.clear();
486 }
487
488 pub fn add_message(&mut self, message: Message) {
490 if self.messages.len() >= MAX_AGENT_MESSAGES {
491 tracing::warn!(
492 "Agent message history full ({} messages), dropping oldest message",
493 MAX_AGENT_MESSAGES
494 );
495 self.messages.remove(0);
496 }
497 self.messages.push(message);
498 }
499
500 pub fn replace_messages(&mut self, messages: Vec<Message>) {
502 self.messages = messages;
503 }
504
505 pub fn set_provider(&mut self, provider: Arc<dyn Provider>) {
507 self.provider = provider;
508 }
509
510 pub fn register_message_fetchers(
515 &mut self,
516 steering: Option<MessageFetcher>,
517 follow_up: Option<MessageFetcher>,
518 ) {
519 if let Some(fetcher) = steering {
520 self.steering_fetchers.push(fetcher);
521 }
522 if let Some(fetcher) = follow_up {
523 self.follow_up_fetchers.push(fetcher);
524 }
525 }
526
527 pub fn extend_tools<I>(&mut self, tools: I)
529 where
530 I: IntoIterator<Item = Box<dyn Tool>>,
531 {
532 self.tools.extend(tools);
533 self.cached_tool_defs = None; }
535
536 pub fn queue_steering(&mut self, message: Message) -> u64 {
538 self.message_queue.push_steering(message)
539 }
540
541 pub fn queue_follow_up(&mut self, message: Message) -> u64 {
543 self.message_queue.push_follow_up(message)
544 }
545
546 pub const fn set_queue_modes(&mut self, steering: QueueMode, follow_up: QueueMode) {
548 self.message_queue.set_modes(steering, follow_up);
549 }
550
551 pub const fn queue_modes(&self) -> (QueueMode, QueueMode) {
552 (
553 self.message_queue.steering_mode,
554 self.message_queue.follow_up_mode,
555 )
556 }
557
558 #[must_use]
560 pub fn queued_message_count(&self) -> usize {
561 self.message_queue.pending_count()
562 }
563
564 pub fn provider(&self) -> Arc<dyn Provider> {
565 Arc::clone(&self.provider)
566 }
567
568 pub const fn stream_options(&self) -> &StreamOptions {
569 &self.config.stream_options
570 }
571
572 pub const fn stream_options_mut(&mut self) -> &mut StreamOptions {
573 &mut self.config.stream_options
574 }
575
576 pub fn system_prompt(&self) -> Option<&str> {
577 self.config.system_prompt.as_deref()
578 }
579
580 pub fn set_system_prompt(&mut self, system_prompt: Option<String>) {
581 self.config.system_prompt = system_prompt;
582 }
583
584 fn build_context(&mut self) -> Context<'_> {
586 let messages: Cow<'_, [Message]> = if self.config.block_images {
587 let mut msgs = self.messages.clone();
588 msgs.retain(|m| match m {
590 Message::Custom(c) => c.display,
591 _ => true,
592 });
593 let stats = filter_images_for_provider(&mut msgs);
594 if stats.removed_images > 0 {
595 tracing::debug!(
596 filtered_images = stats.removed_images,
597 affected_messages = stats.affected_messages,
598 "Filtered image content from outbound provider context (images.block_images=true)"
599 );
600 }
601 Cow::Owned(msgs)
602 } else {
603 let has_hidden = self.messages.iter().any(|m| match m {
605 Message::Custom(c) => !c.display,
606 _ => false,
607 });
608
609 if has_hidden {
610 let mut msgs = self.messages.clone();
611 msgs.retain(|m| match m {
612 Message::Custom(c) => c.display,
613 _ => true,
614 });
615 Cow::Owned(msgs)
616 } else {
617 Cow::Borrowed(self.messages.as_slice())
618 }
619 };
620
621 if self.cached_tool_defs.is_none() {
623 let defs: Vec<ToolDef> = self
624 .tools
625 .tools()
626 .iter()
627 .map(|t| ToolDef {
628 name: t.name().to_string(),
629 description: t.description().to_string(),
630 parameters: t.parameters(),
631 })
632 .collect();
633 self.cached_tool_defs = Some(defs);
634 }
635 let tools = Cow::Borrowed(self.cached_tool_defs.as_deref().unwrap());
636
637 Context {
638 system_prompt: self.config.system_prompt.as_deref().map(Cow::Borrowed),
639 messages,
640 tools,
641 }
642 }
643
644 pub async fn run(
648 &mut self,
649 user_input: impl Into<String>,
650 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
651 ) -> Result<AssistantMessage> {
652 self.run_with_abort(user_input, None, on_event).await
653 }
654
655 pub async fn run_with_abort(
657 &mut self,
658 user_input: impl Into<String>,
659 abort: Option<AbortSignal>,
660 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
661 ) -> Result<AssistantMessage> {
662 let user_message = Message::User(UserMessage {
664 content: UserContent::Text(user_input.into()),
665 timestamp: Utc::now().timestamp_millis(),
666 });
667
668 self.run_loop(vec![user_message], Arc::new(on_event), abort)
670 .await
671 }
672
673 pub async fn run_with_content(
675 &mut self,
676 content: Vec<ContentBlock>,
677 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
678 ) -> Result<AssistantMessage> {
679 self.run_with_content_with_abort(content, None, on_event)
680 .await
681 }
682
683 pub async fn run_with_content_with_abort(
685 &mut self,
686 content: Vec<ContentBlock>,
687 abort: Option<AbortSignal>,
688 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
689 ) -> Result<AssistantMessage> {
690 let user_message = Message::User(UserMessage {
692 content: UserContent::Blocks(content),
693 timestamp: Utc::now().timestamp_millis(),
694 });
695
696 self.run_loop(vec![user_message], Arc::new(on_event), abort)
698 .await
699 }
700
701 pub async fn run_with_message_with_abort(
703 &mut self,
704 message: Message,
705 abort: Option<AbortSignal>,
706 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
707 ) -> Result<AssistantMessage> {
708 self.run_loop(vec![message], Arc::new(on_event), abort)
709 .await
710 }
711
712 pub async fn run_with_messages_with_abort(
714 &mut self,
715 messages: Vec<Message>,
716 abort: Option<AbortSignal>,
717 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
718 ) -> Result<AssistantMessage> {
719 self.run_loop(messages, Arc::new(on_event), abort).await
720 }
721
722 pub async fn run_continue_with_abort(
724 &mut self,
725 abort: Option<AbortSignal>,
726 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
727 ) -> Result<AssistantMessage> {
728 self.run_loop(Vec::new(), Arc::new(on_event), abort).await
729 }
730
731 fn build_abort_message(&self, partial: Option<&AssistantMessage>) -> AssistantMessage {
732 let mut message = partial.cloned().unwrap_or_else(|| AssistantMessage {
733 content: Vec::new(),
734 api: self.provider.api().to_string(),
735 provider: self.provider.name().to_string(),
736 model: self.provider.model_id().to_string(),
737 usage: Usage::default(),
738 stop_reason: StopReason::Aborted,
739 error_message: Some("Aborted".to_string()),
740 timestamp: Utc::now().timestamp_millis(),
741 });
742 message.stop_reason = StopReason::Aborted;
743 message.error_message = Some("Aborted".to_string());
744 message.timestamp = Utc::now().timestamp_millis();
745 message
746 }
747
748 fn build_error_message(
749 &self,
750 partial: Option<&AssistantMessage>,
751 error_message: impl Into<String>,
752 ) -> AssistantMessage {
753 let error_message = error_message.into();
754 let mut message = partial.cloned().unwrap_or_else(|| AssistantMessage {
755 content: Vec::new(),
756 api: self.provider.api().to_string(),
757 provider: self.provider.name().to_string(),
758 model: self.provider.model_id().to_string(),
759 usage: Usage::default(),
760 stop_reason: StopReason::Error,
761 error_message: Some(error_message.clone()),
762 timestamp: Utc::now().timestamp_millis(),
763 });
764 message.stop_reason = StopReason::Error;
765 message.error_message = Some(error_message);
766 message.timestamp = Utc::now().timestamp_millis();
767 message
768 }
769
    /// Core agent loop shared by every `run_*` entry point.
    ///
    /// Appends `prompts` to the history, then repeatedly streams an assistant
    /// response and executes any tool calls it contains. Queued steering messages
    /// are fed in between turns. The loop continues until the model stops
    /// requesting tools and no steering or follow-up messages remain.
    /// Lifecycle events (`AgentStart`, `TurnStart`, `TurnEnd`, `AgentEnd`) go to
    /// extensions (via `dispatch_extension_lifecycle_event`) and to `on_event`.
    ///
    /// Returns `Ok` with the terminal assistant message on normal completion,
    /// on abort, when the assistant stops with `Error`/`Aborted`, or when
    /// `max_tool_iterations` is exceeded. Streaming and tool-execution errors are
    /// returned as `Err` after `AgentEnd` has been emitted. If the loop ends
    /// without any assistant message, `Err` is returned without an `AgentEnd`.
    #[allow(clippy::too_many_lines)]
    async fn run_loop(
        &mut self,
        prompts: Vec<Message>,
        on_event: AgentEventHandler,
        abort: Option<AbortSignal>,
    ) -> Result<AssistantMessage> {
        // Cancellation context handed to the streaming call as its checkpoint.
        let loop_cx = crate::agent_cx::AgentCx::for_current_or_request();
        let session_id: Arc<str> = self
            .config
            .stream_options
            .session_id
            .as_deref()
            .unwrap_or("")
            .into();
        // Number of turns that requested tool calls; bounded by max_tool_iterations.
        let mut iterations = 0usize;
        let mut turn_index: usize = 0;
        // Messages produced by this run; reported in the AgentEnd event.
        let mut new_messages: Vec<Message> = Vec::with_capacity(prompts.len() + 8);
        let mut last_assistant: Option<Arc<AssistantMessage>> = None;

        let agent_start_event = AgentEvent::AgentStart {
            session_id: session_id.clone(),
        };
        self.dispatch_extension_lifecycle_event(&agent_start_event)
            .await;
        on_event(agent_start_event);

        // Record the caller's prompts and report each as a complete message.
        for prompt in prompts {
            self.messages.push(prompt.clone());
            on_event(AgentEvent::MessageStart {
                message: prompt.clone(),
            });
            on_event(AgentEvent::MessageEnd {
                message: prompt.clone(),
            });
            new_messages.push(prompt);
        }

        // Steering already queued before the first turn is delivered with it.
        let mut pending_messages = self.drain_steering_messages().await;

        // Outer loop: one pass per batch of follow-up messages.
        loop {
            let mut has_more_tool_calls = true;
            let mut steering_after_tools: Option<Vec<Message>> = None;

            // Inner loop: one iteration per turn, while the model keeps calling
            // tools or steering messages are pending.
            while has_more_tool_calls || !pending_messages.is_empty() {
                let current_turn_index = turn_index;
                let turn_start_event = AgentEvent::TurnStart {
                    session_id: session_id.clone(),
                    turn_index: current_turn_index,
                    timestamp: Utc::now().timestamp_millis(),
                };
                self.dispatch_extension_lifecycle_event(&turn_start_event)
                    .await;
                on_event(turn_start_event);

                for message in std::mem::take(&mut pending_messages) {
                    self.messages.push(message.clone());
                    on_event(AgentEvent::MessageStart {
                        message: message.clone(),
                    });
                    on_event(AgentEvent::MessageEnd {
                        message: message.clone(),
                    });
                    new_messages.push(message);
                }

                // Abort requested before streaming: record a synthetic aborted
                // assistant message and finish the run successfully.
                if abort.as_ref().is_some_and(AbortSignal::is_aborted) {
                    let abort_message = self.build_abort_message(None);
                    let message = Message::assistant(abort_message.clone());

                    self.messages.push(message.clone());
                    new_messages.push(message.clone());
                    on_event(AgentEvent::MessageStart {
                        message: message.clone(),
                    });
                    on_event(AgentEvent::MessageEnd {
                        message: message.clone(),
                    });

                    let turn_end_event = AgentEvent::TurnEnd {
                        session_id: session_id.clone(),
                        turn_index: current_turn_index,
                        message,
                        tool_results: Vec::new(),
                    };
                    self.dispatch_extension_lifecycle_event(&turn_end_event)
                        .await;
                    on_event(turn_end_event);
                    let agent_end_event = AgentEvent::AgentEnd {
                        session_id: session_id.clone(),
                        messages: std::mem::take(&mut new_messages),
                        error: Some(
                            abort_message
                                .error_message
                                .clone()
                                .unwrap_or_else(|| "Aborted".to_string()),
                        ),
                    };
                    self.dispatch_extension_lifecycle_event(&agent_end_event)
                        .await;
                    on_event(agent_end_event);
                    return Ok(abort_message);
                }

                let assistant_message = match self
                    .stream_assistant_response(Arc::clone(&on_event), abort.clone(), &loop_cx)
                    .await
                {
                    Ok(msg) => msg,
                    Err(err) => {
                        // Streaming failed: flush pending steering into history,
                        // record an error assistant message, close the turn and
                        // the run, then propagate the error.
                        let err_string = err.to_string();
                        let steering_to_add = self.drain_steering_messages().await;
                        for message in steering_to_add {
                            self.messages.push(message.clone());
                            on_event(AgentEvent::MessageStart {
                                message: message.clone(),
                            });
                            on_event(AgentEvent::MessageEnd {
                                message: message.clone(),
                            });
                            new_messages.push(message);
                        }

                        let error_message = self.build_error_message(None, err_string.clone());
                        let assistant_event_message = Message::assistant(error_message.clone());
                        self.messages.push(assistant_event_message.clone());
                        new_messages.push(assistant_event_message.clone());
                        on_event(AgentEvent::MessageStart {
                            message: assistant_event_message.clone(),
                        });
                        on_event(AgentEvent::MessageEnd {
                            message: assistant_event_message.clone(),
                        });

                        let turn_end_event = AgentEvent::TurnEnd {
                            session_id: session_id.clone(),
                            turn_index: current_turn_index,
                            message: assistant_event_message,
                            tool_results: Vec::new(),
                        };
                        self.dispatch_extension_lifecycle_event(&turn_end_event)
                            .await;
                        on_event(turn_end_event);

                        let agent_end_event = AgentEvent::AgentEnd {
                            session_id: session_id.clone(),
                            messages: std::mem::take(&mut new_messages),
                            error: Some(err_string),
                        };
                        self.dispatch_extension_lifecycle_event(&agent_end_event)
                            .await;
                        on_event(agent_end_event);
                        return Err(err);
                    }
                };
                // The assistant message is not pushed onto `self.messages` here;
                // the max-iterations branch below relies on it already being in
                // the history (presumably added while streaming — verify).
                let assistant_arc = Arc::new(assistant_message);
                last_assistant = Some(Arc::clone(&assistant_arc));

                let assistant_event_message = Message::Assistant(Arc::clone(&assistant_arc));
                new_messages.push(assistant_event_message.clone());

                // A response that itself ended in error/abort terminates the run
                // with `Ok`, carrying that message.
                if matches!(
                    assistant_arc.stop_reason,
                    StopReason::Error | StopReason::Aborted
                ) {
                    let steering_to_add = self.drain_steering_messages().await;
                    for message in steering_to_add {
                        self.messages.push(message.clone());
                        on_event(AgentEvent::MessageStart {
                            message: message.clone(),
                        });
                        on_event(AgentEvent::MessageEnd {
                            message: message.clone(),
                        });
                        new_messages.push(message);
                    }

                    let turn_end_event = AgentEvent::TurnEnd {
                        session_id: session_id.clone(),
                        turn_index: current_turn_index,
                        message: assistant_event_message.clone(),
                        tool_results: Vec::new(),
                    };
                    self.dispatch_extension_lifecycle_event(&turn_end_event)
                        .await;
                    on_event(turn_end_event);
                    let agent_end_event = AgentEvent::AgentEnd {
                        session_id: session_id.clone(),
                        messages: std::mem::take(&mut new_messages),
                        error: assistant_arc.error_message.clone(),
                    };
                    self.dispatch_extension_lifecycle_event(&agent_end_event)
                        .await;
                    on_event(agent_end_event);
                    return Ok(Arc::unwrap_or_clone(assistant_arc));
                }

                let tool_calls = extract_tool_calls(&assistant_arc.content);
                has_more_tool_calls = !tool_calls.is_empty();

                let mut tool_results: Vec<Arc<ToolResultMessage>> = Vec::new();
                if has_more_tool_calls {
                    iterations += 1;
                    if iterations > self.config.max_tool_iterations {
                        // Tool budget exhausted: turn this response into an error
                        // message with its (unexecuted) tool calls stripped, and
                        // substitute it for the assistant message already recorded.
                        let error_message = format!(
                            "Maximum tool iterations ({}) exceeded",
                            self.config.max_tool_iterations
                        );
                        let mut stop_message = (*assistant_arc).clone();
                        stop_message.stop_reason = StopReason::Error;
                        stop_message.error_message = Some(error_message.clone());

                        stop_message
                            .content
                            .retain(|b| !matches!(b, crate::model::ContentBlock::ToolCall(_)));

                        let stop_arc = Arc::new(stop_message.clone());
                        let stop_event_message = Message::Assistant(Arc::clone(&stop_arc));

                        // Replace the most recent assistant message in the history
                        // and in this run's message list.
                        if let Some(last @ Message::Assistant(_)) = self
                            .messages
                            .iter_mut()
                            .rev()
                            .find(|m| matches!(m, Message::Assistant(_)))
                        {
                            *last = stop_event_message.clone();
                        }
                        if let Some(last @ Message::Assistant(_)) = new_messages.last_mut() {
                            *last = stop_event_message.clone();
                        }

                        let steering_to_add = self.drain_steering_messages().await;
                        for message in steering_to_add {
                            self.messages.push(message.clone());
                            on_event(AgentEvent::MessageStart {
                                message: message.clone(),
                            });
                            on_event(AgentEvent::MessageEnd {
                                message: message.clone(),
                            });
                            new_messages.push(message);
                        }

                        let turn_end_event = AgentEvent::TurnEnd {
                            session_id: session_id.clone(),
                            turn_index: current_turn_index,
                            message: stop_event_message,
                            tool_results: Vec::new(),
                        };
                        self.dispatch_extension_lifecycle_event(&turn_end_event)
                            .await;
                        on_event(turn_end_event);

                        let agent_end_event = AgentEvent::AgentEnd {
                            session_id: session_id.clone(),
                            messages: std::mem::take(&mut new_messages),
                            error: Some(error_message),
                        };
                        self.dispatch_extension_lifecycle_event(&agent_end_event)
                            .await;
                        on_event(agent_end_event);

                        return Ok(stop_message);
                    }

                    let outcome = match self
                        .execute_tool_calls(
                            &tool_calls,
                            Arc::clone(&on_event),
                            &mut new_messages,
                            abort.clone(),
                        )
                        .await
                    {
                        Ok(outcome) => outcome,
                        Err(err) => {
                            // Tool execution failed: flush steering, close the turn
                            // and the run, then propagate the error.
                            let steering_to_add = self.drain_steering_messages().await;
                            for message in steering_to_add {
                                self.messages.push(message.clone());
                                on_event(AgentEvent::MessageStart {
                                    message: message.clone(),
                                });
                                on_event(AgentEvent::MessageEnd {
                                    message: message.clone(),
                                });
                                new_messages.push(message);
                            }

                            let turn_end_event = AgentEvent::TurnEnd {
                                session_id: session_id.clone(),
                                turn_index: current_turn_index,
                                message: assistant_event_message.clone(),
                                tool_results: Vec::new(),
                            };
                            self.dispatch_extension_lifecycle_event(&turn_end_event)
                                .await;
                            on_event(turn_end_event);

                            let agent_end_event = AgentEvent::AgentEnd {
                                session_id: session_id.clone(),
                                messages: std::mem::take(&mut new_messages),
                                error: Some(err.to_string()),
                            };
                            self.dispatch_extension_lifecycle_event(&agent_end_event)
                                .await;
                            on_event(agent_end_event);
                            return Err(err);
                        }
                    };
                    tool_results = outcome.tool_results;
                    steering_after_tools = outcome.steering_messages;
                }

                let tool_messages = tool_results
                    .iter()
                    .map(|r| Message::ToolResult(Arc::clone(r)))
                    .collect::<Vec<_>>();

                let turn_end_event = AgentEvent::TurnEnd {
                    session_id: session_id.clone(),
                    turn_index: current_turn_index,
                    message: assistant_event_message.clone(),
                    tool_results: tool_messages,
                };
                self.dispatch_extension_lifecycle_event(&turn_end_event)
                    .await;
                on_event(turn_end_event);

                turn_index = turn_index.saturating_add(1);

                // Steering returned by tool execution takes precedence; otherwise
                // poll the steering sources for the next turn.
                if let Some(steering) = steering_after_tools.take() {
                    pending_messages = steering;
                } else {
                    pending_messages = self.drain_steering_messages().await;
                }
            }

            // No tool calls or steering left: a non-empty follow-up batch starts
            // another round of turns; otherwise the run is complete.
            let follow_up = self.drain_follow_up_messages().await;
            if follow_up.is_empty() {
                break;
            }
            pending_messages = follow_up;
        }

        let Some(final_arc) = last_assistant else {
            return Err(Error::api("Agent completed without assistant message"));
        };

        let agent_end_event = AgentEvent::AgentEnd {
            session_id: session_id.clone(),
            messages: new_messages,
            error: None,
        };
        self.dispatch_extension_lifecycle_event(&agent_end_event)
            .await;
        on_event(agent_end_event);
        Ok(Arc::unwrap_or_clone(final_arc))
    }
1137
1138 async fn fetch_messages(&self, fetcher: Option<&MessageFetcher>) -> Vec<Message> {
1139 if let Some(fetcher) = fetcher {
1140 (fetcher)().await
1141 } else {
1142 Vec::new()
1143 }
1144 }
1145
1146 async fn dispatch_extension_lifecycle_event(&self, event: &AgentEvent) {
1147 let Some(extensions) = &self.extensions else {
1148 return;
1149 };
1150
1151 let name = match event {
1152 AgentEvent::AgentStart { .. } => ExtensionEventName::AgentStart,
1153 AgentEvent::AgentEnd { .. } => ExtensionEventName::AgentEnd,
1154 AgentEvent::TurnStart { .. } => ExtensionEventName::TurnStart,
1155 AgentEvent::TurnEnd { .. } => ExtensionEventName::TurnEnd,
1156 _ => return,
1157 };
1158
1159 let payload = match serde_json::to_value(event) {
1160 Ok(payload) => payload,
1161 Err(err) => {
1162 tracing::warn!("failed to serialize agent lifecycle event (fail-open): {err}");
1163 return;
1164 }
1165 };
1166
1167 if let Err(err) = extensions.dispatch_event(name, Some(payload)).await {
1168 tracing::warn!("agent lifecycle extension hook failed (fail-open): {err}");
1169 }
1170 }
1171
1172 async fn dispatch_context_event(&self, messages: &[Message]) -> Option<Vec<Message>> {
1173 let Some(extensions) = &self.extensions else {
1174 return None;
1175 };
1176
1177 let payload = json!({ "messages": messages });
1178 let response = extensions
1179 .dispatch_event_with_response(
1180 ExtensionEventName::Context,
1181 Some(payload),
1182 EXTENSION_EVENT_TIMEOUT_MS,
1183 )
1184 .await
1185 .ok()?;
1186
1187 let value = response?;
1188
1189 if value.is_null() {
1190 return None;
1191 }
1192
1193 let messages_value = if let Some(obj) = value.as_object() {
1194 obj.get("messages").cloned()?
1195 } else if value.is_array() {
1196 value
1197 } else {
1198 return None;
1199 };
1200
1201 if messages_value.is_null() {
1202 return Some(Vec::new());
1203 }
1204
1205 match serde_json::from_value(messages_value) {
1206 Ok(messages) => Some(messages),
1207 Err(err) => {
1208 tracing::warn!("context extension hook returned invalid messages: {err}");
1209 None
1210 }
1211 }
1212 }
1213
1214 async fn drain_steering_messages(&mut self) -> Vec<Message> {
1215 for fetcher in &self.steering_fetchers {
1216 let fetched = self.fetch_messages(Some(fetcher)).await;
1217 for message in fetched {
1218 self.message_queue.push_steering(message);
1219 }
1220 }
1221 self.message_queue.pop_steering()
1222 }
1223
1224 async fn drain_follow_up_messages(&mut self) -> Vec<Message> {
1225 for fetcher in &self.follow_up_fetchers {
1226 let fetched = self.fetch_messages(Some(fetcher)).await;
1227 for message in fetched {
1228 self.message_queue.push_follow_up(message);
1229 }
1230 }
1231 self.message_queue.pop_follow_up()
1232 }
1233
1234 #[allow(clippy::too_many_lines)]
1236 async fn stream_assistant_response(
1237 &mut self,
1238 on_event: AgentEventHandler,
1239 abort: Option<AbortSignal>,
1240 checkpoint_cx: &crate::agent_cx::AgentCx,
1241 ) -> Result<AssistantMessage> {
1242 let provider = Arc::clone(&self.provider);
1244 let stream_options = self.config.stream_options.clone();
1245 let (system_prompt, tools, base_messages) = {
1246 let context = self.build_context();
1247 (
1248 context.system_prompt.as_deref().map(str::to_string),
1249 context.tools.to_vec(),
1250 context.messages.to_vec(),
1251 )
1252 };
1253 let messages = self
1254 .dispatch_context_event(&base_messages)
1255 .await
1256 .unwrap_or(base_messages);
1257 let context = Context::owned(system_prompt, messages, tools);
1258 let mut stream = provider.stream(&context, &stream_options).await?;
1259
1260 let mut added_partial = false;
1261 let mut sent_start = false;
1264
1265 'stream: loop {
1266 if checkpoint_cx.checkpoint().is_err() {
1267 let last_partial = if added_partial {
1268 match self
1269 .messages
1270 .iter()
1271 .rev()
1272 .find(|m| matches!(m, Message::Assistant(_)))
1273 {
1274 Some(Message::Assistant(a)) => Some(a.as_ref()),
1275 _ => None,
1276 }
1277 } else {
1278 None
1279 };
1280 let abort_arc = Arc::new(self.build_abort_message(last_partial));
1281 if !sent_start {
1282 on_event(AgentEvent::MessageStart {
1283 message: Message::Assistant(Arc::clone(&abort_arc)),
1284 });
1285 self.messages
1286 .push(Message::Assistant(Arc::clone(&abort_arc)));
1287 added_partial = true;
1288 }
1289 on_event(AgentEvent::MessageUpdate {
1290 message: Message::Assistant(Arc::clone(&abort_arc)),
1291 assistant_message_event: AssistantMessageEvent::Error {
1292 reason: StopReason::Aborted,
1293 error: Arc::clone(&abort_arc),
1294 },
1295 });
1296 return Ok(self.finalize_assistant_message(
1297 Arc::try_unwrap(abort_arc).unwrap_or_else(|a| (*a).clone()),
1298 &on_event,
1299 added_partial,
1300 ));
1301 }
1302
1303 let event_result = if let Some(signal) = abort.as_ref() {
1304 let abort_fut = signal.wait().fuse();
1305 let event_fut = stream.next().fuse();
1306 futures::pin_mut!(abort_fut, event_fut);
1307
1308 match futures::future::select(abort_fut, event_fut).await {
1309 futures::future::Either::Left(((), _event_fut)) => {
1310 let last_partial = if added_partial {
1311 match self
1312 .messages
1313 .iter()
1314 .rev()
1315 .find(|m| matches!(m, Message::Assistant(_)))
1316 {
1317 Some(Message::Assistant(a)) => Some(a.as_ref()),
1318 _ => None,
1319 }
1320 } else {
1321 None
1322 };
1323 let abort_arc = Arc::new(self.build_abort_message(last_partial));
1324 if !sent_start {
1325 on_event(AgentEvent::MessageStart {
1326 message: Message::Assistant(Arc::clone(&abort_arc)),
1327 });
1328 self.messages
1329 .push(Message::Assistant(Arc::clone(&abort_arc)));
1330 added_partial = true;
1331 }
1335 on_event(AgentEvent::MessageUpdate {
1336 message: Message::Assistant(Arc::clone(&abort_arc)),
1337 assistant_message_event: AssistantMessageEvent::Error {
1338 reason: StopReason::Aborted,
1339 error: Arc::clone(&abort_arc),
1340 },
1341 });
1342 return Ok(self.finalize_assistant_message(
1343 Arc::try_unwrap(abort_arc).unwrap_or_else(|a| (*a).clone()),
1344 &on_event,
1345 added_partial,
1346 ));
1347 }
1348 futures::future::Either::Right((event, _abort_fut)) => event,
1349 }
1350 } else {
1351 let event_fut = stream.next().fuse();
1352 futures::pin_mut!(event_fut);
1353 loop {
1354 let now = checkpoint_cx
1355 .cx()
1356 .timer_driver()
1357 .map_or_else(asupersync::time::wall_now, |timer| timer.now());
1358 let tick_fut =
1359 asupersync::time::sleep(now, std::time::Duration::from_millis(25)).fuse();
1360 futures::pin_mut!(tick_fut);
1361
1362 match futures::future::select(tick_fut, &mut event_fut).await {
1363 futures::future::Either::Left(((), _event_fut)) => {
1364 if checkpoint_cx.checkpoint().is_err() {
1365 continue 'stream;
1366 }
1367 }
1368 futures::future::Either::Right((result, _tick_fut)) => break result,
1369 }
1370 }
1371 };
1372
1373 let Some(event_result) = event_result else {
1374 break;
1375 };
1376 let event = match event_result {
1377 Ok(e) => e,
1378 Err(err) => {
1379 let partial = if added_partial {
1380 match self
1381 .messages
1382 .iter()
1383 .rev()
1384 .find(|m| matches!(m, Message::Assistant(_)))
1385 {
1386 Some(Message::Assistant(a)) => Some(a.as_ref()),
1387 _ => None,
1388 }
1389 } else {
1390 None
1391 };
1392 let msg = self.build_error_message(partial, err.to_string());
1393
1394 return Ok(self.finalize_assistant_message(msg, &on_event, added_partial));
1398 }
1399 };
1400
1401 match event {
1402 StreamEvent::Start { partial } => {
1403 if added_partial {
1404 if let Some(Message::Assistant(msg_arc)) = self
1405 .messages
1406 .iter_mut()
1407 .rev()
1408 .find(|m| matches!(m, Message::Assistant(_)))
1409 {
1410 let msg = Arc::make_mut(msg_arc);
1411 if msg.content.is_empty() {
1412 *msg = partial;
1413 } else {
1414 msg.api = partial.api;
1415 msg.provider = partial.provider;
1416 msg.model = partial.model;
1417 msg.usage = partial.usage;
1418 msg.stop_reason = partial.stop_reason;
1419 msg.error_message = partial.error_message;
1420 msg.timestamp = partial.timestamp;
1421 }
1422 let shared = Arc::clone(msg_arc);
1423 if !sent_start {
1424 on_event(AgentEvent::MessageStart {
1425 message: Message::Assistant(Arc::clone(&shared)),
1426 });
1427 sent_start = true;
1428 }
1429 on_event(AgentEvent::MessageUpdate {
1430 message: Message::Assistant(Arc::clone(&shared)),
1431 assistant_message_event: AssistantMessageEvent::Start {
1432 partial: shared,
1433 },
1434 });
1435 } else {
1436 let shared = Arc::new(partial);
1437 self.update_partial_message(Arc::clone(&shared), &mut added_partial);
1438 on_event(AgentEvent::MessageStart {
1439 message: Message::Assistant(Arc::clone(&shared)),
1440 });
1441 sent_start = true;
1442 on_event(AgentEvent::MessageUpdate {
1443 message: Message::Assistant(Arc::clone(&shared)),
1444 assistant_message_event: AssistantMessageEvent::Start {
1445 partial: shared,
1446 },
1447 });
1448 }
1449 } else {
1450 let shared = Arc::new(partial);
1451 self.update_partial_message(Arc::clone(&shared), &mut added_partial);
1452 on_event(AgentEvent::MessageStart {
1453 message: Message::Assistant(Arc::clone(&shared)),
1454 });
1455 sent_start = true;
1456 on_event(AgentEvent::MessageUpdate {
1457 message: Message::Assistant(Arc::clone(&shared)),
1458 assistant_message_event: AssistantMessageEvent::Start {
1459 partial: shared,
1460 },
1461 });
1462 }
1463 }
1464 StreamEvent::TextStart { content_index, .. } => {
1465 self.seed_partial_message_if_missing(&mut added_partial);
1466 if let Some(Message::Assistant(msg_arc)) = self
1467 .messages
1468 .iter_mut()
1469 .rev()
1470 .find(|m| matches!(m, Message::Assistant(_)))
1471 {
1472 let msg = Arc::make_mut(msg_arc);
1473 if content_index == msg.content.len() {
1474 msg.content.push(ContentBlock::Text(TextContent::new("")));
1475 }
1476 let shared = Arc::clone(msg_arc);
1477 if !sent_start {
1478 on_event(AgentEvent::MessageStart {
1479 message: Message::Assistant(Arc::clone(&shared)),
1480 });
1481 sent_start = true;
1482 }
1483 on_event(AgentEvent::MessageUpdate {
1484 message: Message::Assistant(Arc::clone(&shared)),
1485 assistant_message_event: AssistantMessageEvent::TextStart {
1486 content_index,
1487 partial: shared,
1488 },
1489 });
1490 }
1491 }
1492 StreamEvent::TextDelta {
1493 content_index,
1494 delta,
1495 ..
1496 } => {
1497 self.seed_partial_message_if_missing(&mut added_partial);
1498 if let Some(Message::Assistant(msg_arc)) = self
1499 .messages
1500 .iter_mut()
1501 .rev()
1502 .find(|m| matches!(m, Message::Assistant(_)))
1503 {
1504 {
1505 let msg = Arc::make_mut(msg_arc);
1506 if msg.content.get(content_index).is_none()
1507 && content_index == msg.content.len()
1508 {
1509 msg.content.push(ContentBlock::Text(TextContent::new("")));
1510 }
1511 if let Some(ContentBlock::Text(text)) =
1512 msg.content.get_mut(content_index)
1513 {
1514 text.text.push_str(&delta);
1515 }
1516 }
1517 let shared = Arc::clone(msg_arc);
1518 if !sent_start {
1519 on_event(AgentEvent::MessageStart {
1520 message: Message::Assistant(Arc::clone(&shared)),
1521 });
1522 sent_start = true;
1523 }
1524 on_event(AgentEvent::MessageUpdate {
1525 message: Message::Assistant(Arc::clone(&shared)),
1526 assistant_message_event: AssistantMessageEvent::TextDelta {
1527 content_index,
1528 delta,
1529 partial: shared,
1530 },
1531 });
1532 }
1533 }
1534 StreamEvent::TextEnd {
1535 content_index,
1536 content,
1537 ..
1538 } => {
1539 self.seed_partial_message_if_missing(&mut added_partial);
1540 if let Some(Message::Assistant(msg_arc)) = self
1541 .messages
1542 .iter_mut()
1543 .rev()
1544 .find(|m| matches!(m, Message::Assistant(_)))
1545 {
1546 {
1547 let msg = Arc::make_mut(msg_arc);
1548 if msg.content.get(content_index).is_none()
1549 && content_index == msg.content.len()
1550 {
1551 msg.content.push(ContentBlock::Text(TextContent::new("")));
1552 }
1553 if let Some(ContentBlock::Text(text)) =
1554 msg.content.get_mut(content_index)
1555 {
1556 text.text.clone_from(&content);
1557 }
1558 }
1559 let shared = Arc::clone(msg_arc);
1560 if !sent_start {
1561 on_event(AgentEvent::MessageStart {
1562 message: Message::Assistant(Arc::clone(&shared)),
1563 });
1564 sent_start = true;
1565 }
1566 on_event(AgentEvent::MessageUpdate {
1567 message: Message::Assistant(Arc::clone(&shared)),
1568 assistant_message_event: AssistantMessageEvent::TextEnd {
1569 content_index,
1570 content,
1571 partial: shared,
1572 },
1573 });
1574 }
1575 }
1576 StreamEvent::ThinkingStart { content_index, .. } => {
1577 self.seed_partial_message_if_missing(&mut added_partial);
1578 if let Some(Message::Assistant(msg_arc)) = self
1579 .messages
1580 .iter_mut()
1581 .rev()
1582 .find(|m| matches!(m, Message::Assistant(_)))
1583 {
1584 let msg = Arc::make_mut(msg_arc);
1585 if content_index == msg.content.len() {
1586 msg.content.push(ContentBlock::Thinking(ThinkingContent {
1587 thinking: String::new(),
1588 thinking_signature: None,
1589 }));
1590 }
1591 let shared = Arc::clone(msg_arc);
1592 if !sent_start {
1593 on_event(AgentEvent::MessageStart {
1594 message: Message::Assistant(Arc::clone(&shared)),
1595 });
1596 sent_start = true;
1597 }
1598 on_event(AgentEvent::MessageUpdate {
1599 message: Message::Assistant(Arc::clone(&shared)),
1600 assistant_message_event: AssistantMessageEvent::ThinkingStart {
1601 content_index,
1602 partial: shared,
1603 },
1604 });
1605 }
1606 }
1607 StreamEvent::ThinkingDelta {
1608 content_index,
1609 delta,
1610 ..
1611 } => {
1612 self.seed_partial_message_if_missing(&mut added_partial);
1613 if let Some(Message::Assistant(msg_arc)) = self
1614 .messages
1615 .iter_mut()
1616 .rev()
1617 .find(|m| matches!(m, Message::Assistant(_)))
1618 {
1619 {
1620 let msg = Arc::make_mut(msg_arc);
1621 if msg.content.get(content_index).is_none()
1622 && content_index == msg.content.len()
1623 {
1624 msg.content.push(ContentBlock::Thinking(ThinkingContent {
1625 thinking: String::new(),
1626 thinking_signature: None,
1627 }));
1628 }
1629 if let Some(ContentBlock::Thinking(thinking)) =
1630 msg.content.get_mut(content_index)
1631 {
1632 thinking.thinking.push_str(&delta);
1633 }
1634 }
1635 let shared = Arc::clone(msg_arc);
1636 if !sent_start {
1637 on_event(AgentEvent::MessageStart {
1638 message: Message::Assistant(Arc::clone(&shared)),
1639 });
1640 sent_start = true;
1641 }
1642 on_event(AgentEvent::MessageUpdate {
1643 message: Message::Assistant(Arc::clone(&shared)),
1644 assistant_message_event: AssistantMessageEvent::ThinkingDelta {
1645 content_index,
1646 delta,
1647 partial: shared,
1648 },
1649 });
1650 }
1651 }
1652 StreamEvent::ThinkingEnd {
1653 content_index,
1654 content,
1655 ..
1656 } => {
1657 self.seed_partial_message_if_missing(&mut added_partial);
1658 if let Some(Message::Assistant(msg_arc)) = self
1659 .messages
1660 .iter_mut()
1661 .rev()
1662 .find(|m| matches!(m, Message::Assistant(_)))
1663 {
1664 {
1665 let msg = Arc::make_mut(msg_arc);
1666 if msg.content.get(content_index).is_none()
1667 && content_index == msg.content.len()
1668 {
1669 msg.content.push(ContentBlock::Thinking(ThinkingContent {
1670 thinking: String::new(),
1671 thinking_signature: None,
1672 }));
1673 }
1674 if let Some(ContentBlock::Thinking(thinking)) =
1675 msg.content.get_mut(content_index)
1676 {
1677 thinking.thinking.clone_from(&content);
1678 }
1679 }
1680 let shared = Arc::clone(msg_arc);
1681 if !sent_start {
1682 on_event(AgentEvent::MessageStart {
1683 message: Message::Assistant(Arc::clone(&shared)),
1684 });
1685 sent_start = true;
1686 }
1687 on_event(AgentEvent::MessageUpdate {
1688 message: Message::Assistant(Arc::clone(&shared)),
1689 assistant_message_event: AssistantMessageEvent::ThinkingEnd {
1690 content_index,
1691 content,
1692 partial: shared,
1693 },
1694 });
1695 }
1696 }
1697 StreamEvent::ToolCallStart { content_index, .. } => {
1698 self.seed_partial_message_if_missing(&mut added_partial);
1699 if let Some(Message::Assistant(msg_arc)) = self
1700 .messages
1701 .iter_mut()
1702 .rev()
1703 .find(|m| matches!(m, Message::Assistant(_)))
1704 {
1705 let msg = Arc::make_mut(msg_arc);
1706 if content_index == msg.content.len() {
1707 msg.content.push(ContentBlock::ToolCall(ToolCall {
1708 id: String::new(),
1709 name: String::new(),
1710 arguments: serde_json::Value::Null,
1711 thought_signature: None,
1712 }));
1713 }
1714 let shared = Arc::clone(msg_arc);
1715 if !sent_start {
1716 on_event(AgentEvent::MessageStart {
1717 message: Message::Assistant(Arc::clone(&shared)),
1718 });
1719 sent_start = true;
1720 }
1721 on_event(AgentEvent::MessageUpdate {
1722 message: Message::Assistant(Arc::clone(&shared)),
1723 assistant_message_event: AssistantMessageEvent::ToolCallStart {
1724 content_index,
1725 partial: shared,
1726 },
1727 });
1728 }
1729 }
1730 StreamEvent::ToolCallDelta {
1731 content_index,
1732 delta,
1733 ..
1734 } => {
1735 self.seed_partial_message_if_missing(&mut added_partial);
1736 if let Some(Message::Assistant(msg_arc)) = self
1737 .messages
1738 .iter_mut()
1739 .rev()
1740 .find(|m| matches!(m, Message::Assistant(_)))
1741 {
1742 if msg_arc.content.get(content_index).is_none()
1743 && content_index == msg_arc.content.len()
1744 {
1745 let msg = Arc::make_mut(msg_arc);
1746 msg.content.push(ContentBlock::ToolCall(ToolCall {
1747 id: String::new(),
1748 name: String::new(),
1749 arguments: serde_json::Value::Null,
1750 thought_signature: None,
1751 }));
1752 }
1753 let shared = Arc::clone(msg_arc);
1756 if !sent_start {
1757 on_event(AgentEvent::MessageStart {
1758 message: Message::Assistant(Arc::clone(&shared)),
1759 });
1760 sent_start = true;
1761 }
1762 on_event(AgentEvent::MessageUpdate {
1763 message: Message::Assistant(Arc::clone(&shared)),
1764 assistant_message_event: AssistantMessageEvent::ToolCallDelta {
1765 content_index,
1766 delta,
1767 partial: shared,
1768 },
1769 });
1770 }
1771 }
1772 StreamEvent::ToolCallEnd {
1773 content_index,
1774 tool_call,
1775 ..
1776 } => {
1777 self.seed_partial_message_if_missing(&mut added_partial);
1778 if let Some(Message::Assistant(msg_arc)) = self
1779 .messages
1780 .iter_mut()
1781 .rev()
1782 .find(|m| matches!(m, Message::Assistant(_)))
1783 {
1784 {
1785 let msg = Arc::make_mut(msg_arc);
1786 if msg.content.get(content_index).is_none()
1787 && content_index == msg.content.len()
1788 {
1789 msg.content.push(ContentBlock::ToolCall(ToolCall {
1790 id: String::new(),
1791 name: String::new(),
1792 arguments: serde_json::Value::Null,
1793 thought_signature: None,
1794 }));
1795 }
1796 if let Some(ContentBlock::ToolCall(tc)) =
1797 msg.content.get_mut(content_index)
1798 {
1799 *tc = tool_call.clone();
1800 }
1801 }
1802 let shared = Arc::clone(msg_arc);
1803 if !sent_start {
1804 on_event(AgentEvent::MessageStart {
1805 message: Message::Assistant(Arc::clone(&shared)),
1806 });
1807 sent_start = true;
1808 }
1809 on_event(AgentEvent::MessageUpdate {
1810 message: Message::Assistant(Arc::clone(&shared)),
1811 assistant_message_event: AssistantMessageEvent::ToolCallEnd {
1812 content_index,
1813 tool_call,
1814 partial: shared,
1815 },
1816 });
1817 }
1818 }
1819 StreamEvent::Done { message, .. } => {
1820 return Ok(self.finalize_assistant_message(message, &on_event, added_partial));
1821 }
1822 StreamEvent::Error { error, .. } => {
1823 return Ok(self.finalize_assistant_message(error, &on_event, added_partial));
1824 }
1825 }
1826 }
1827
1828 if added_partial {
1832 if let Some(Message::Assistant(last_msg)) = self
1833 .messages
1834 .iter()
1835 .rev()
1836 .find(|m| matches!(m, Message::Assistant(_)))
1837 {
1838 let mut final_msg = (**last_msg).clone();
1839 final_msg.stop_reason = StopReason::Error;
1840 final_msg.error_message = Some("Stream ended without Done event".to_string());
1841 return Ok(self.finalize_assistant_message(final_msg, &on_event, true));
1842 }
1843 }
1844 Err(Error::api("Stream ended without Done event"))
1845 }
1846
1847 fn seed_partial_message_if_missing(&mut self, added_partial: &mut bool) {
1852 if *added_partial {
1853 return;
1854 }
1855
1856 let message = AssistantMessage {
1857 content: Vec::new(),
1858 api: self.provider.api().to_string(),
1859 provider: self.provider.name().to_string(),
1860 model: self.provider.model_id().to_string(),
1861 usage: Usage::default(),
1862 stop_reason: StopReason::Stop,
1863 error_message: None,
1864 timestamp: Utc::now().timestamp_millis(),
1865 };
1866 self.messages.push(Message::Assistant(Arc::new(message)));
1867 *added_partial = true;
1868 }
1869
1870 fn update_partial_message(
1875 &mut self,
1876 partial: Arc<AssistantMessage>,
1877 added_partial: &mut bool,
1878 ) -> bool {
1879 if *added_partial {
1880 if let Some(target) = self
1881 .messages
1882 .iter_mut()
1883 .rev()
1884 .find(|m| matches!(m, Message::Assistant(_)))
1885 {
1886 *target = Message::Assistant(partial);
1887 } else {
1888 tracing::warn!("update_partial_message: expected an Assistant message in history");
1891 self.messages.push(Message::Assistant(partial));
1892 }
1893 false
1894 } else {
1895 self.messages.push(Message::Assistant(partial));
1896 *added_partial = true;
1897 true
1898 }
1899 }
1900
1901 fn finalize_assistant_message(
1902 &mut self,
1903 message: AssistantMessage,
1904 on_event: &Arc<dyn Fn(AgentEvent) + Send + Sync>,
1905 added_partial: bool,
1906 ) -> AssistantMessage {
1907 let arc = Arc::new(message);
1908 if added_partial {
1909 if let Some(target) = self
1910 .messages
1911 .iter_mut()
1912 .rev()
1913 .find(|m| matches!(m, Message::Assistant(_)))
1914 {
1915 *target = Message::Assistant(Arc::clone(&arc));
1916 } else {
1917 tracing::warn!(
1920 "finalize_assistant_message: expected an Assistant message in history"
1921 );
1922 self.messages.push(Message::Assistant(Arc::clone(&arc)));
1923 on_event(AgentEvent::MessageStart {
1924 message: Message::Assistant(Arc::clone(&arc)),
1925 });
1926 }
1927 } else {
1928 self.messages.push(Message::Assistant(Arc::clone(&arc)));
1929 on_event(AgentEvent::MessageStart {
1930 message: Message::Assistant(Arc::clone(&arc)),
1931 });
1932 }
1933
1934 on_event(AgentEvent::MessageEnd {
1935 message: Message::Assistant(Arc::clone(&arc)),
1936 });
1937 Arc::try_unwrap(arc).unwrap_or_else(|a| (*a).clone())
1938 }
1939
1940 async fn execute_parallel_batch(
1941 &self,
1942 batch: Vec<(usize, ToolCall)>,
1943 on_event: AgentEventHandler,
1944 abort: Option<AbortSignal>,
1945 ) -> Vec<(usize, (ToolOutput, bool))> {
1946 let futures = batch.into_iter().map(|(idx, tc)| {
1947 let on_event = Arc::clone(&on_event);
1948 async move { (idx, self.execute_tool_owned(tc, on_event).await) }
1949 });
1950
1951 if let Some(signal) = abort.as_ref() {
1952 use futures::future::{Either, select};
1953 let all_fut = stream::iter(futures)
1954 .buffer_unordered(MAX_CONCURRENT_TOOLS)
1955 .collect::<Vec<_>>()
1956 .fuse();
1957 let abort_fut = signal.wait().fuse();
1958 futures::pin_mut!(all_fut, abort_fut);
1959
1960 match select(all_fut, abort_fut).await {
1961 Either::Left((batch_results, _)) => batch_results,
1962 Either::Right(_) => Vec::new(), }
1964 } else {
1965 stream::iter(futures)
1966 .buffer_unordered(MAX_CONCURRENT_TOOLS)
1967 .collect::<Vec<_>>()
1968 .await
1969 }
1970 }
1971
1972 #[allow(clippy::too_many_lines)]
1973 async fn execute_tool_calls(
1974 &mut self,
1975 tool_calls: &[ToolCall],
1976 on_event: AgentEventHandler,
1977 new_messages: &mut Vec<Message>,
1978 abort: Option<AbortSignal>,
1979 ) -> Result<ToolExecutionOutcome> {
1980 let mut results = Vec::new();
1981 let mut steering_messages: Option<Vec<Message>> = None;
1982
1983 for tool_call in tool_calls {
1985 on_event(AgentEvent::ToolExecutionStart {
1986 tool_call_id: tool_call.id.clone(),
1987 tool_name: tool_call.name.clone(),
1988 args: tool_call.arguments.clone(),
1989 });
1990 }
1991
1992 let mut pending_parallel: Vec<(usize, ToolCall)> = Vec::new();
1994 let mut tool_outputs: Vec<Option<(ToolOutput, bool)>> = vec![None; tool_calls.len()];
1995
1996 for (index, tool_call) in tool_calls.iter().enumerate() {
1998 if abort.as_ref().is_some_and(AbortSignal::is_aborted) {
1999 break;
2000 }
2001
2002 let is_read_only =
2003 matches!(self.tools.get(&tool_call.name), Some(tool) if tool.is_read_only());
2004
2005 if is_read_only {
2006 pending_parallel.push((index, tool_call.clone()));
2007 } else {
2008 let steering = self.drain_steering_messages().await;
2010 if !steering.is_empty() {
2011 steering_messages = Some(steering);
2012 break;
2013 }
2014
2015 if !pending_parallel.is_empty() {
2017 let batch = std::mem::take(&mut pending_parallel);
2018 let results = self
2019 .execute_parallel_batch(batch, Arc::clone(&on_event), abort.clone())
2020 .await;
2021 for (idx, result) in results {
2022 tool_outputs[idx] = Some(result);
2023 }
2024 }
2025
2026 if abort.as_ref().is_some_and(AbortSignal::is_aborted) {
2027 break;
2028 }
2029
2030 let steering = self.drain_steering_messages().await;
2033 if !steering.is_empty() {
2034 steering_messages = Some(steering);
2035 break;
2036 }
2037
2038 if let Some(signal) = abort.as_ref() {
2041 use futures::future::{Either, select};
2042 let tool_fut = self
2043 .execute_tool(tool_call.clone(), Arc::clone(&on_event))
2044 .fuse();
2045 let abort_fut = signal.wait().fuse();
2046 futures::pin_mut!(tool_fut, abort_fut);
2047 match select(tool_fut, abort_fut).await {
2048 Either::Left((result, _)) => {
2049 tool_outputs[index] = Some(result);
2050 }
2051 Either::Right(_) => {
2052 break;
2055 }
2056 }
2057 } else {
2058 let result = self
2059 .execute_tool(tool_call.clone(), Arc::clone(&on_event))
2060 .await;
2061 tool_outputs[index] = Some(result);
2062 }
2063 }
2064 }
2065
2066 if !pending_parallel.is_empty()
2068 && !abort.as_ref().is_some_and(AbortSignal::is_aborted)
2069 && steering_messages.is_none()
2070 {
2071 let batch = std::mem::take(&mut pending_parallel);
2072 let steering = self.drain_steering_messages().await;
2074 if steering.is_empty() {
2075 let results = self
2076 .execute_parallel_batch(batch, Arc::clone(&on_event), abort.clone())
2077 .await;
2078 for (idx, result) in results {
2079 tool_outputs[idx] = Some(result);
2080 }
2081 } else {
2082 steering_messages = Some(steering);
2083 }
2084 }
2085
2086 for (index, tool_call) in tool_calls.iter().enumerate() {
2088 if steering_messages.is_none() && !abort.as_ref().is_some_and(AbortSignal::is_aborted) {
2091 let steering = self.drain_steering_messages().await;
2092 if !steering.is_empty() {
2093 steering_messages = Some(steering);
2094 }
2095 }
2096
2097 if let Some((output, is_error)) = tool_outputs[index].take() {
2101 let tool_result = Arc::new(ToolResultMessage {
2106 tool_call_id: tool_call.id.clone(),
2107 tool_name: tool_call.name.clone(),
2108 content: output.content,
2109 details: output.details,
2110 is_error,
2111 timestamp: Utc::now().timestamp_millis(),
2112 });
2113
2114 on_event(AgentEvent::ToolExecutionEnd {
2117 tool_call_id: tool_result.tool_call_id.clone(),
2118 tool_name: tool_result.tool_name.clone(),
2119 result: ToolOutput {
2120 content: tool_result.content.clone(),
2121 details: tool_result.details.clone(),
2122 is_error,
2123 },
2124 is_error,
2125 });
2126
2127 let msg = Message::ToolResult(Arc::clone(&tool_result));
2128 self.messages.push(msg.clone());
2129 on_event(AgentEvent::MessageStart {
2130 message: msg.clone(),
2131 });
2132 new_messages.push(msg.clone());
2133 on_event(AgentEvent::MessageEnd { message: msg });
2134
2135 results.push(tool_result);
2136 } else if steering_messages.is_some() {
2137 results.push(self.skip_tool_call(tool_call, &on_event, new_messages));
2139 } else {
2140 let output = ToolOutput {
2142 content: vec![ContentBlock::Text(TextContent::new(
2143 "Tool execution aborted",
2144 ))],
2145 details: None,
2146 is_error: true,
2147 };
2148
2149 on_event(AgentEvent::ToolExecutionUpdate {
2150 tool_call_id: tool_call.id.clone(),
2151 tool_name: tool_call.name.clone(),
2152 args: tool_call.arguments.clone(),
2153 partial_result: ToolOutput {
2154 content: output.content.clone(),
2155 details: output.details.clone(),
2156 is_error: true,
2157 },
2158 });
2159
2160 on_event(AgentEvent::ToolExecutionEnd {
2161 tool_call_id: tool_call.id.clone(),
2162 tool_name: tool_call.name.clone(),
2163 result: ToolOutput {
2164 content: output.content.clone(),
2165 details: output.details.clone(),
2166 is_error: true,
2167 },
2168 is_error: true,
2169 });
2170
2171 let tool_result = Arc::new(ToolResultMessage {
2172 tool_call_id: tool_call.id.clone(),
2173 tool_name: tool_call.name.clone(),
2174 content: output.content,
2175 details: output.details,
2176 is_error: true,
2177 timestamp: Utc::now().timestamp_millis(),
2178 });
2179
2180 let msg = Message::ToolResult(Arc::clone(&tool_result));
2181 self.messages.push(msg.clone());
2182 on_event(AgentEvent::MessageStart {
2183 message: msg.clone(),
2184 });
2185 let end_msg = msg.clone();
2186 new_messages.push(msg);
2187 on_event(AgentEvent::MessageEnd { message: end_msg });
2188
2189 results.push(tool_result);
2190 }
2191 }
2192
2193 Ok(ToolExecutionOutcome {
2194 tool_results: results,
2195 steering_messages,
2196 })
2197 }
2198
2199 async fn execute_tool(
2200 &self,
2201 tool_call: ToolCall,
2202 on_event: AgentEventHandler,
2203 ) -> (ToolOutput, bool) {
2204 let extensions = self.extensions.clone();
2205
2206 let (mut output, is_error) = if let Some(extensions) = &extensions {
2207 match Self::dispatch_tool_call_hook(
2208 extensions,
2209 &tool_call,
2210 self.config.fail_closed_hooks,
2211 )
2212 .await
2213 {
2214 Some(blocked_output) => (blocked_output, true),
2215 None => {
2216 self.execute_tool_without_hooks(&tool_call, Arc::clone(&on_event))
2217 .await
2218 }
2219 }
2220 } else {
2221 self.execute_tool_without_hooks(&tool_call, Arc::clone(&on_event))
2222 .await
2223 };
2224
2225 if let Some(extensions) = &extensions {
2226 Self::apply_tool_result_hook(extensions, &tool_call, &mut output, is_error).await;
2227 }
2228
2229 (output, is_error)
2230 }
2231
2232 async fn execute_tool_owned(
2233 &self,
2234 tool_call: ToolCall,
2235 on_event: AgentEventHandler,
2236 ) -> (ToolOutput, bool) {
2237 self.execute_tool(tool_call, on_event).await
2238 }
2239
2240 async fn execute_tool_without_hooks(
2241 &self,
2242 tool_call: &ToolCall,
2243 on_event: AgentEventHandler,
2244 ) -> (ToolOutput, bool) {
2245 let Some(tool) = self.tools.get(&tool_call.name) else {
2247 return (Self::tool_not_found_output(&tool_call.name), true);
2248 };
2249
2250 let tool_name = tool_call.name.clone();
2251 let tool_id = tool_call.id.clone();
2252 let tool_args = tool_call.arguments.clone();
2253 let on_event = Arc::clone(&on_event);
2254
2255 let update_callback = move |update: ToolUpdate| {
2256 on_event(AgentEvent::ToolExecutionUpdate {
2257 tool_call_id: tool_id.clone(),
2258 tool_name: tool_name.clone(),
2259 args: tool_args.clone(),
2260 partial_result: ToolOutput {
2261 content: update.content,
2262 details: update.details,
2263 is_error: false,
2264 },
2265 });
2266 };
2267
2268 match tool
2269 .execute(
2270 &tool_call.id,
2271 tool_call.arguments.clone(),
2272 Some(Box::new(update_callback)),
2273 )
2274 .await
2275 {
2276 Ok(output) => {
2277 let is_error = output.is_error;
2278 (output, is_error)
2279 }
2280 Err(e) => (
2281 ToolOutput {
2282 content: vec![ContentBlock::Text(TextContent::new(format!("Error: {e}")))],
2283 details: None,
2284 is_error: true,
2285 },
2286 true,
2287 ),
2288 }
2289 }
2290
2291 fn tool_not_found_output(tool_name: &str) -> ToolOutput {
2292 ToolOutput {
2293 content: vec![ContentBlock::Text(TextContent::new(format!(
2294 "Error: Tool '{tool_name}' not found"
2295 )))],
2296 details: None,
2297 is_error: true,
2298 }
2299 }
2300
2301 async fn dispatch_tool_call_hook(
2302 extensions: &ExtensionManager,
2303 tool_call: &ToolCall,
2304 fail_closed_hooks: bool,
2305 ) -> Option<ToolOutput> {
2306 match extensions
2307 .dispatch_tool_call(tool_call, EXTENSION_EVENT_TIMEOUT_MS)
2308 .await
2309 {
2310 Ok(Some(result)) if result.block => {
2311 Some(Self::tool_call_blocked_output(result.reason.as_deref()))
2312 }
2313 Ok(_) => None,
2314 Err(err) => {
2315 if fail_closed_hooks {
2316 tracing::warn!(
2317 error = ?err,
2318 "tool_call extension hook failed (fail-closed)"
2319 );
2320 Some(Self::tool_call_blocked_output(Some(
2321 "extension hook failed",
2322 )))
2323 } else {
2324 tracing::warn!("tool_call extension hook failed (fail-open): {err}");
2325 None
2326 }
2327 }
2328 }
2329 }
2330
2331 fn tool_call_blocked_output(reason: Option<&str>) -> ToolOutput {
2332 let reason = reason.map(str::trim).filter(|reason| !reason.is_empty());
2333 let message = reason.map_or_else(
2334 || "Tool execution was blocked by an extension".to_string(),
2335 |reason| format!("Tool execution blocked: {reason}"),
2336 );
2337
2338 ToolOutput {
2339 content: vec![ContentBlock::Text(TextContent::new(message))],
2340 details: None,
2341 is_error: true,
2342 }
2343 }
2344
2345 async fn apply_tool_result_hook(
2346 extensions: &ExtensionManager,
2347 tool_call: &ToolCall,
2348 output: &mut ToolOutput,
2349 is_error: bool,
2350 ) {
2351 match extensions
2352 .dispatch_tool_result(tool_call, &*output, is_error, EXTENSION_EVENT_TIMEOUT_MS)
2353 .await
2354 {
2355 Ok(Some(result)) => {
2356 if let Some(content) = result.content {
2357 output.content = content;
2358 }
2359 if let Some(details) = result.details {
2360 output.details = Some(details);
2361 }
2362 }
2363 Ok(None) => {}
2364 Err(err) => tracing::warn!("tool_result extension hook failed (fail-open): {err}"),
2365 }
2366 }
2367
2368 fn skip_tool_call(
2369 &mut self,
2370 tool_call: &ToolCall,
2371 on_event: &Arc<dyn Fn(AgentEvent) + Send + Sync>,
2372 new_messages: &mut Vec<Message>,
2373 ) -> Arc<ToolResultMessage> {
2374 let output = ToolOutput {
2375 content: vec![ContentBlock::Text(TextContent::new(
2376 "Skipped due to queued user message.",
2377 ))],
2378 details: None,
2379 is_error: true,
2380 };
2381
2382 on_event(AgentEvent::ToolExecutionUpdate {
2385 tool_call_id: tool_call.id.clone(),
2386 tool_name: tool_call.name.clone(),
2387 args: tool_call.arguments.clone(),
2388 partial_result: output.clone(),
2389 });
2390 on_event(AgentEvent::ToolExecutionEnd {
2391 tool_call_id: tool_call.id.clone(),
2392 tool_name: tool_call.name.clone(),
2393 result: output.clone(),
2394 is_error: true,
2395 });
2396
2397 let tool_result = Arc::new(ToolResultMessage {
2398 tool_call_id: tool_call.id.clone(),
2399 tool_name: tool_call.name.clone(),
2400 content: output.content,
2401 details: output.details,
2402 is_error: true,
2403 timestamp: Utc::now().timestamp_millis(),
2404 });
2405
2406 let msg = Message::ToolResult(Arc::clone(&tool_result));
2407 self.messages.push(msg.clone());
2408 new_messages.push(msg.clone());
2409
2410 on_event(AgentEvent::MessageStart {
2411 message: msg.clone(),
2412 });
2413 on_event(AgentEvent::MessageEnd { message: msg });
2414
2415 tool_result
2416 }
2417}
2418
/// Outcome of executing the tool calls from one assistant message.
struct ToolExecutionOutcome {
    /// One result message per tool call handled (executed or skipped).
    tool_results: Vec<Arc<ToolResultMessage>>,
    /// Steering messages collected during tool execution, if any.
    /// NOTE(review): semantics inferred from the name — confirm at the producer.
    steering_messages: Option<Vec<Message>>,
}
2427
/// Extension state that has already been loaded and can be handed over as one unit.
///
/// NOTE(review): presumably built ahead of session setup so extension loading can overlap
/// other startup work — confirm against the code that consumes it.
pub struct PreWarmedExtensionRuntime {
    /// The extension manager for this pre-warmed runtime.
    pub manager: ExtensionManager,
    /// Handle to the runtime hosting the extensions.
    pub runtime: ExtensionRuntimeHandle,
    /// Tool registry to use with these extensions (assumed to already include
    /// extension-provided tools — confirm at the construction site).
    pub tools: Arc<ToolRegistry>,
}
2440
/// RAII flag guard: raises a shared `AtomicBool` on creation and lowers it when dropped.
///
/// Dropping always stores `false`, regardless of the value the flag held before activation.
struct AtomicBoolGuard(Arc<AtomicBool>);

impl AtomicBoolGuard {
    /// Sets `flag` to `true` and returns a guard that resets it to `false` on drop.
    fn activate(flag: &Arc<AtomicBool>) -> Self {
        let guard = Self(Arc::clone(flag));
        guard.0.store(true, Ordering::SeqCst);
        guard
    }
}

impl Drop for AtomicBoolGuard {
    fn drop(&mut self) {
        let Self(flag) = self;
        flag.store(false, Ordering::SeqCst);
    }
}
2457
/// An [`Agent`] bound to a shared [`Session`], together with extension and compaction state.
pub struct AgentSession {
    /// The agent that runs turns.
    pub agent: Agent,
    /// Shared session, guarded by an async (`asupersync::sync`) mutex.
    pub session: Arc<Mutex<Session>>,
    /// NOTE(review): presumably controls whether the session is persisted — confirm.
    save_enabled: bool,
    input_source: InputSource,
    /// Extension region; `None` when no extensions are active (enabling an empty extension
    /// list clears it).
    pub extensions: Option<ExtensionRegion>,
    // Flags and queues below are shared via `Arc`. They are assumed to be the same handles
    // given to the extension host actions/session — TODO confirm at the construction site.
    extensions_is_streaming: Arc<AtomicBool>,
    extensions_is_compacting: Arc<AtomicBool>,
    extensions_turn_active: Arc<AtomicBool>,
    extensions_pending_idle_actions: Arc<StdMutex<VecDeque<PendingIdleAction>>>,
    /// Queue-mode mirror exposed to extensions; cleared when the extension list is empty.
    extension_queue_modes: Option<Arc<StdMutex<ExtensionQueueModeState>>>,
    /// Messages injected by extensions; cleared when the extension list is empty.
    extension_injected_queue: Option<Arc<StdMutex<ExtensionInjectedQueue>>>,
    compaction_settings: ResolvedCompactionSettings,
    compaction_runtime: Option<Runtime>,
    runtime_handle: Option<RuntimeHandle>,
    compaction_worker: CompactionWorkerState,
    model_registry: Option<ModelRegistry>,
    auth_storage: Option<AuthStorage>,
}
2479
2480#[derive(Debug, Clone, Copy)]
2481struct ExtensionQueueModeState {
2482 steering_mode: QueueMode,
2483 follow_up_mode: QueueMode,
2484}
2485
2486impl ExtensionQueueModeState {
2487 const fn new(steering_mode: QueueMode, follow_up_mode: QueueMode) -> Self {
2488 Self {
2489 steering_mode,
2490 follow_up_mode,
2491 }
2492 }
2493
2494 const fn set_modes(&mut self, steering_mode: QueueMode, follow_up_mode: QueueMode) {
2495 self.steering_mode = steering_mode;
2496 self.follow_up_mode = follow_up_mode;
2497 }
2498}
2499
2500#[derive(Debug)]
2501struct ExtensionInjectedQueue {
2502 steering: VecDeque<Message>,
2503 follow_up: VecDeque<Message>,
2504 steering_mode: QueueMode,
2505 follow_up_mode: QueueMode,
2506}
2507
2508impl ExtensionInjectedQueue {
2509 const fn new(steering_mode: QueueMode, follow_up_mode: QueueMode) -> Self {
2510 Self {
2511 steering: VecDeque::new(),
2512 follow_up: VecDeque::new(),
2513 steering_mode,
2514 follow_up_mode,
2515 }
2516 }
2517
2518 const fn set_modes(&mut self, steering_mode: QueueMode, follow_up_mode: QueueMode) {
2519 self.steering_mode = steering_mode;
2520 self.follow_up_mode = follow_up_mode;
2521 }
2522
2523 fn push_steering(&mut self, message: Message) {
2524 if self.steering.len() >= MAX_STEERING_QUEUE_SIZE {
2525 tracing::warn!(
2526 "Extension steering queue full ({} messages), dropping oldest message",
2527 MAX_STEERING_QUEUE_SIZE
2528 );
2529 self.steering.pop_front();
2530 }
2531 self.steering.push_back(message);
2532 }
2533
2534 fn push_follow_up(&mut self, message: Message) {
2535 if self.follow_up.len() >= MAX_FOLLOW_UP_QUEUE_SIZE {
2536 tracing::warn!(
2537 "Extension follow-up queue full ({} messages), dropping oldest message",
2538 MAX_FOLLOW_UP_QUEUE_SIZE
2539 );
2540 self.follow_up.pop_front();
2541 }
2542 self.follow_up.push_back(message);
2543 }
2544
2545 fn pop_steering(&mut self) -> Vec<Message> {
2546 match self.steering_mode {
2547 QueueMode::All => self.steering.drain(..).collect(),
2548 QueueMode::OneAtATime => self.steering.pop_front().into_iter().collect(),
2549 }
2550 }
2551
2552 fn pop_follow_up(&mut self) -> Vec<Message> {
2553 match self.follow_up_mode {
2554 QueueMode::All => self.follow_up.drain(..).collect(),
2555 QueueMode::OneAtATime => self.follow_up.pop_front().into_iter().collect(),
2556 }
2557 }
2558}
2559
2560impl Default for ExtensionInjectedQueue {
2561 fn default() -> Self {
2562 Self::new(QueueMode::OneAtATime, QueueMode::OneAtATime)
2563 }
2564}
2565
/// Work requested by an extension while the agent was neither streaming nor mid-turn.
///
/// NOTE(review): presumably consumed once the agent is idle to start a new turn — confirm
/// at the consumer.
#[derive(Debug)]
enum PendingIdleAction {
    /// A custom message queued by `send_message` with `trigger_turn` set.
    CustomMessage(Message),
    /// User text queued by `send_user_message`.
    UserText(String),
}
2571
/// Implements [`ExtensionHostActions`] for an agent session.
///
/// Each message sent by an extension is routed to the session, to the injected queue, or to
/// the pending idle actions, based on the shared streaming and turn flags.
#[derive(Clone)]
struct AgentSessionHostActions {
    /// Session that messages are appended to directly.
    session: Arc<Mutex<Session>>,
    /// Steering/follow-up queue used while the agent is streaming.
    injected: Arc<StdMutex<ExtensionInjectedQueue>>,
    /// `true` while the agent is streaming.
    is_streaming: Arc<AtomicBool>,
    /// `true` while a turn is active.
    is_turn_active: Arc<AtomicBool>,
    /// Actions deferred until the agent is idle.
    pending_idle_actions: Arc<StdMutex<VecDeque<PendingIdleAction>>>,
}
2580
2581impl AgentSessionHostActions {
2582 fn enqueue(&self, deliver_as: Option<ExtensionDeliverAs>, message: Message) {
2583 let deliver_as = deliver_as.unwrap_or(ExtensionDeliverAs::Steer);
2584 let Ok(mut queue) = self.injected.lock() else {
2585 tracing::error!("injected queue mutex poisoned; dropping extension message");
2586 return;
2587 };
2588 match deliver_as {
2589 ExtensionDeliverAs::FollowUp => {
2590 queue.push_follow_up(message);
2591 }
2592 ExtensionDeliverAs::Steer | ExtensionDeliverAs::NextTurn => {
2593 queue.push_steering(message);
2594 }
2595 }
2596 }
2597
2598 async fn append_to_session(&self, message: Message) -> Result<()> {
2599 let cx = crate::agent_cx::AgentCx::for_current_or_request();
2600 let mut session = self
2601 .session
2602 .lock(cx.cx())
2603 .await
2604 .map_err(|e| Error::session(e.to_string()))?;
2605 session.append_model_message(message);
2606 Ok(())
2607 }
2608
2609 fn queue_pending_idle_action(&self, action: PendingIdleAction) {
2610 let Ok(mut actions) = self.pending_idle_actions.lock() else {
2611 tracing::error!("pending idle actions mutex poisoned; dropping idle action");
2612 return;
2613 };
2614 actions.push_back(action);
2615 }
2616}
2617
2618#[async_trait]
2619impl ExtensionHostActions for AgentSessionHostActions {
2620 async fn send_message(&self, message: ExtensionSendMessage) -> Result<()> {
2621 let custom_message = Message::Custom(CustomMessage {
2622 content: message.content,
2623 custom_type: message.custom_type,
2624 display: message.display,
2625 details: message.details,
2626 timestamp: Utc::now().timestamp_millis(),
2627 });
2628
2629 if matches!(message.deliver_as, Some(ExtensionDeliverAs::NextTurn)) {
2630 return self.append_to_session(custom_message).await;
2631 }
2632
2633 if self.is_streaming.load(Ordering::SeqCst) {
2634 self.enqueue(message.deliver_as, custom_message);
2635 return Ok(());
2636 }
2637
2638 if self.is_turn_active.load(Ordering::SeqCst) {
2639 return self.append_to_session(custom_message).await;
2640 }
2641
2642 if message.trigger_turn {
2643 self.queue_pending_idle_action(PendingIdleAction::CustomMessage(custom_message));
2644 return Ok(());
2645 }
2646
2647 self.append_to_session(custom_message).await
2648 }
2649
2650 async fn send_user_message(&self, message: ExtensionSendUserMessage) -> Result<()> {
2651 let text = message.text;
2652 let user_message = Message::User(UserMessage {
2653 content: UserContent::Text(text.clone()),
2654 timestamp: Utc::now().timestamp_millis(),
2655 });
2656
2657 if self.is_streaming.load(Ordering::SeqCst) {
2658 self.enqueue(message.deliver_as, user_message);
2659 return Ok(());
2660 }
2661
2662 if self.is_turn_active.load(Ordering::SeqCst) {
2663 return self.append_to_session(user_message).await;
2664 }
2665
2666 self.queue_pending_idle_action(PendingIdleAction::UserText(text));
2667 Ok(())
2668 }
2669}
2670
/// Unit tests for `MessageQueue`: FIFO order, drain modes, kind separation and sequence numbers.
#[cfg(test)]
mod message_queue_tests {
    use super::*;

    /// Builds a plain-text user message with a zero timestamp.
    fn user_message(text: &str) -> Message {
        Message::User(UserMessage {
            content: UserContent::Text(text.to_string()),
            timestamp: 0,
        })
    }

    /// In `OneAtATime` mode, steering pops return one message each, in FIFO order, then nothing.
    #[test]
    fn message_queue_one_at_a_time() {
        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
        queue.push_steering(user_message("a"));
        queue.push_steering(user_message("b"));

        let first = queue.pop_steering();
        assert_eq!(first.len(), 1);
        assert!(matches!(
            first.first(),
            Some(Message::User(UserMessage { content, .. }))
                if matches!(content, UserContent::Text(text) if text == "a")
        ));

        let second = queue.pop_steering();
        assert_eq!(second.len(), 1);
        assert!(matches!(
            second.first(),
            Some(Message::User(UserMessage { content, .. }))
                if matches!(content, UserContent::Text(text) if text == "b")
        ));

        assert!(queue.pop_steering().is_empty());
    }

    /// In `All` steering mode, a single pop drains every queued steering message.
    #[test]
    fn message_queue_all_mode() {
        let mut queue = MessageQueue::new(QueueMode::All, QueueMode::OneAtATime);
        queue.push_steering(user_message("a"));
        queue.push_steering(user_message("b"));

        let drained = queue.pop_steering();
        assert_eq!(drained.len(), 2);
        assert!(queue.pop_steering().is_empty());
    }

    /// Steering and follow-up messages are popped separately; `pending_count` covers both kinds.
    #[test]
    fn message_queue_separates_kinds() {
        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
        queue.push_steering(user_message("steer"));
        queue.push_follow_up(user_message("follow"));

        let steering = queue.pop_steering();
        assert_eq!(steering.len(), 1);
        assert_eq!(queue.pending_count(), 1);

        let follow = queue.pop_follow_up();
        assert_eq!(follow.len(), 1);
        assert_eq!(queue.pending_count(), 0);
    }

    /// Pushes return increasing sequence numbers, shared across both message kinds.
    #[test]
    fn message_queue_seq_increments() {
        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
        let first = queue.push_steering(user_message("a"));
        let second = queue.push_follow_up(user_message("b"));
        assert!(second > first);
    }

    /// At `u64::MAX` the sequence number saturates instead of wrapping, and both messages are still queued.
    #[test]
    fn message_queue_seq_saturates_at_u64_max() {
        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
        queue.next_seq = u64::MAX;

        let first = queue.push_steering(user_message("a"));
        let second = queue.push_follow_up(user_message("b"));

        assert_eq!(first, u64::MAX);
        assert_eq!(second, u64::MAX);
        assert_eq!(queue.pending_count(), 2);
    }

    /// In `All` follow-up mode, one pop drains the whole follow-up queue in insertion order.
    #[test]
    fn message_queue_follow_up_all_mode_drains_entire_queue_in_order() {
        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::All);
        queue.push_follow_up(user_message("f1"));
        queue.push_follow_up(user_message("f2"));

        let follow_up = queue.pop_follow_up();
        assert_eq!(follow_up.len(), 2);
        assert!(matches!(
            follow_up.first(),
            Some(Message::User(UserMessage { content, .. }))
                if matches!(content, UserContent::Text(text) if text == "f1")
        ));
        assert!(matches!(
            follow_up.get(1),
            Some(Message::User(UserMessage { content, .. }))
                if matches!(content, UserContent::Text(text) if text == "f2")
        ));
        assert!(queue.pop_follow_up().is_empty());
    }
}
2775
2776#[cfg(test)]
2777mod extensions_integration_tests {
2778 use super::*;
2779
2780 use crate::session::Session;
2781 use asupersync::runtime::RuntimeBuilder;
2782 use async_trait::async_trait;
2783 use futures::Stream;
2784 use serde_json::json;
2785 use std::path::Path;
2786 use std::pin::Pin;
2787 use std::sync::atomic::AtomicUsize;
2788 use std::time::Duration;
2789
    /// Test provider whose `stream` yields no events at all.
    #[derive(Debug)]
    struct NoopProvider;

    #[async_trait]
    // The trait signatures return `&str`; returning string literals here is intentional.
    #[allow(clippy::unnecessary_literal_bound)]
    impl Provider for NoopProvider {
        fn name(&self) -> &str {
            "test-provider"
        }

        fn api(&self) -> &str {
            "test-api"
        }

        fn model_id(&self) -> &str {
            "test-model"
        }

        /// Returns an empty event stream: no `Start`, no `Done`.
        async fn stream(
            &self,
            _context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<
            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
        > {
            Ok(Box::pin(futures::stream::empty()))
        }
    }
2818
    /// Test provider that answers every request with a single text reply `"resumed-response-0"`.
    #[derive(Debug)]
    struct IdleCommandProvider;

    #[async_trait]
    // The trait signatures return `&str`; returning string literals here is intentional.
    #[allow(clippy::unnecessary_literal_bound)]
    impl Provider for IdleCommandProvider {
        fn name(&self) -> &str {
            "test-provider"
        }

        fn api(&self) -> &str {
            "test-api"
        }

        fn model_id(&self) -> &str {
            "test-model"
        }

        /// Streams `Start` (empty partial), then `Done` with the fixed text reply and `Stop`.
        async fn stream(
            &self,
            _context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<
            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
        > {
            let partial = AssistantMessage {
                content: Vec::new(),
                api: self.api().to_string(),
                provider: self.name().to_string(),
                model: self.model_id().to_string(),
                usage: Usage::default(),
                stop_reason: StopReason::Stop,
                error_message: None,
                timestamp: 0,
            };
            let done = AssistantMessage {
                content: vec![ContentBlock::Text(TextContent::new(
                    "resumed-response-0".to_string(),
                ))],
                api: self.api().to_string(),
                provider: self.name().to_string(),
                model: self.model_id().to_string(),
                usage: Usage::default(),
                stop_reason: StopReason::Stop,
                error_message: None,
                timestamp: 0,
            };
            Ok(Box::pin(futures::stream::iter(vec![
                Ok(StreamEvent::Start { partial }),
                Ok(StreamEvent::Done {
                    reason: StopReason::Stop,
                    message: done,
                }),
            ])))
        }
    }
2875
    /// Test tool `count_tool` that counts its invocations and always returns `"ok"`.
    #[derive(Debug)]
    struct CountingTool {
        /// Incremented once per `execute` call.
        calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    // The trait signatures return `&str`; returning string literals here is intentional.
    #[allow(clippy::unnecessary_literal_bound)]
    impl Tool for CountingTool {
        fn name(&self) -> &str {
            "count_tool"
        }

        fn label(&self) -> &str {
            "count_tool"
        }

        fn description(&self) -> &str {
            "counting tool"
        }

        fn parameters(&self) -> serde_json::Value {
            json!({ "type": "object" })
        }

        /// Bumps the call counter and returns a successful `"ok"` text output.
        async fn execute(
            &self,
            _tool_call_id: &str,
            _input: serde_json::Value,
            _on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
        ) -> Result<ToolOutput> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(ToolOutput {
                content: vec![ContentBlock::Text(TextContent::new("ok"))],
                details: None,
                is_error: false,
            })
        }
    }
2914
    /// Test provider that requests two `count_tool` calls on its first stream, then answers `"done"`.
    #[derive(Debug)]
    struct ToolUseProvider {
        /// Number of `stream` calls so far; selects the scripted response.
        stream_calls: AtomicUsize,
    }

    impl ToolUseProvider {
        const fn new() -> Self {
            Self {
                stream_calls: AtomicUsize::new(0),
            }
        }

        /// Builds an assistant message tagged with this provider's identifiers.
        fn assistant_message(
            &self,
            stop_reason: StopReason,
            content: Vec<ContentBlock>,
        ) -> AssistantMessage {
            AssistantMessage {
                content,
                api: self.api().to_string(),
                provider: self.name().to_string(),
                model: self.model_id().to_string(),
                usage: Usage::default(),
                stop_reason,
                error_message: None,
                timestamp: 0,
            }
        }
    }

    #[async_trait]
    // The trait signatures return `&str`; returning string literals here is intentional.
    #[allow(clippy::unnecessary_literal_bound)]
    impl Provider for ToolUseProvider {
        fn name(&self) -> &str {
            "test-provider"
        }

        fn api(&self) -> &str {
            "test-api"
        }

        fn model_id(&self) -> &str {
            "test-model"
        }

        /// First call: `ToolUse` with tool calls `call-1` and `call-2` (both `count_tool`).
        /// Every later call: `Stop` with the text `"done"`.
        async fn stream(
            &self,
            _context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<
            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
        > {
            let call_index = self.stream_calls.fetch_add(1, Ordering::SeqCst);

            let partial = self.assistant_message(StopReason::Stop, Vec::new());

            let (reason, message) = if call_index == 0 {
                let tool_calls = vec![
                    ToolCall {
                        id: "call-1".to_string(),
                        name: "count_tool".to_string(),
                        arguments: json!({}),
                        thought_signature: None,
                    },
                    ToolCall {
                        id: "call-2".to_string(),
                        name: "count_tool".to_string(),
                        arguments: json!({}),
                        thought_signature: None,
                    },
                ];

                (
                    StopReason::ToolUse,
                    self.assistant_message(
                        StopReason::ToolUse,
                        tool_calls
                            .into_iter()
                            .map(ContentBlock::ToolCall)
                            .collect::<Vec<_>>(),
                    ),
                )
            } else {
                (
                    StopReason::Stop,
                    self.assistant_message(
                        StopReason::Stop,
                        vec![ContentBlock::Text(TextContent::new("done"))],
                    ),
                )
            };

            let events = vec![
                Ok(StreamEvent::Start { partial }),
                Ok(StreamEvent::Done { reason, message }),
            ];
            Ok(Box::pin(futures::stream::iter(events)))
        }
    }
3014
    /// A JS extension's `registerTool` makes the tool available on the agent, and executing it
    /// returns the extension's content and details.
    #[test]
    fn agent_session_enable_extensions_registers_extension_tools() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let temp_dir = tempfile::tempdir().expect("tempdir");
            let entry_path = temp_dir.path().join("ext.mjs");
            std::fs::write(
                &entry_path,
                r#"
            export default function init(pi) {
              pi.registerTool({
                name: "hello_tool",
                label: "hello_tool",
                description: "test tool",
                parameters: { type: "object", properties: { name: { type: "string" } } },
                execute: async (_callId, input, _onUpdate, _abort, ctx) => {
                  const who = input && input.name ? String(input.name) : "world";
                  const cwd = ctx && ctx.cwd ? String(ctx.cwd) : "";
                  return {
                    content: [{ type: "text", text: `hello ${who}` }],
                    details: { from: "extension", cwd: cwd },
                    isError: false
                  };
                }
              });
            }
            "#,
            )
            .expect("write extension entry");

            let provider = Arc::new(NoopProvider);
            let tools = ToolRegistry::new(&[], Path::new("."), None);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session =
                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

            agent_session
                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
                .await
                .expect("enable extensions");

            let tool = agent_session
                .agent
                .tools
                .get("hello_tool")
                .expect("hello_tool registered");

            let output = tool
                .execute("call-1", json!({ "name": "pi" }), None)
                .await
                .expect("execute tool");

            assert!(!output.is_error);
            assert!(
                matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
                "Expected single text content block, got {:?}",
                output.content
            );
            // The shape was asserted above, so this else branch is not expected to run.
            let [ContentBlock::Text(text)] = output.content.as_slice() else {
                return;
            };
            assert_eq!(text.text, "hello pi");

            let details = output.details.expect("details present");
            assert_eq!(
                details.get("from").and_then(serde_json::Value::as_str),
                Some("extension")
            );
        });
    }
3089
    /// Enabling an empty extension list succeeds and clears any previously installed extension
    /// region, agent extensions, queue-mode mirror and injected queue.
    #[test]
    fn agent_session_enable_extensions_with_no_entries_clears_and_is_noop() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let temp_dir = tempfile::tempdir().expect("tempdir");
            let provider = Arc::new(NoopProvider);
            let tools = ToolRegistry::new(&[], Path::new("."), None);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session =
                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

            // Pre-populate extension state so the test can observe it being cleared.
            let dummy_manager = ExtensionManager::new();
            agent_session.extensions = Some(crate::extensions::ExtensionRegion::new(dummy_manager.clone()));
            agent_session.agent.extensions = Some(dummy_manager.clone());
            agent_session.extension_queue_modes = Some(Arc::new(std::sync::Mutex::new(ExtensionQueueModeState::new(
                QueueMode::OneAtATime,
                QueueMode::OneAtATime,
            ))));
            agent_session.extension_injected_queue = Some(Arc::new(std::sync::Mutex::new(ExtensionInjectedQueue::default())));

            agent_session
                .enable_extensions(&[], temp_dir.path(), None, &[])
                .await
                .expect("empty extension list should be a no-op");

            assert!(
                agent_session.extensions.is_none(),
                "no extension region should be created (and existing should be cleared) for an empty extension list"
            );
            assert!(
                agent_session.agent.extensions.is_none(),
                "agent should not report extensions active when nothing was requested"
            );
            assert!(
                agent_session.extension_queue_modes.is_none(),
                "empty extension list should clear queue mode mirrors"
            );
            assert!(
                agent_session.extension_injected_queue.is_none(),
                "empty extension list should clear injected extension queues"
            );
        });
    }
3138
    /// Mixing a JS entry (`.mjs`) with a native descriptor (`.native.json`) in one call fails
    /// with a "Mixed extension runtimes are not supported" error.
    #[test]
    fn agent_session_enable_extensions_rejects_mixed_js_and_native_entries() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let temp_dir = tempfile::tempdir().expect("tempdir");
            let js_entry = temp_dir.path().join("ext.mjs");
            let native_entry = temp_dir.path().join("ext.native.json");
            std::fs::write(
                &js_entry,
                r"
            export default function init(_pi) {}
            ",
            )
            .expect("write js extension entry");
            std::fs::write(&native_entry, "{}").expect("write native extension descriptor");

            let provider = Arc::new(NoopProvider);
            let tools = ToolRegistry::new(&[], Path::new("."), None);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session =
                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

            let err = agent_session
                .enable_extensions(&[], temp_dir.path(), None, &[js_entry, native_entry])
                .await
                .expect_err("mixed extension runtimes should be rejected");
            let msg = err.to_string();
            assert!(
                msg.contains("Mixed extension runtimes are not supported"),
                "unexpected mixed-runtime error message: {msg}"
            );
        });
    }
3176
    /// A `pi.sendMessage` call made synchronously inside an extension tool, while the agent is
    /// idle, persists a custom message to the session.
    #[test]
    fn extension_send_message_persists_custom_message_entry_when_idle() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let temp_dir = tempfile::tempdir().expect("tempdir");
            let entry_path = temp_dir.path().join("ext.mjs");
            std::fs::write(
                &entry_path,
                r#"
            export default function init(pi) {
              pi.registerTool({
                name: "emit_message",
                label: "emit_message",
                description: "emit a custom message",
                parameters: { type: "object" },
                execute: async () => {
                  pi.sendMessage({
                    customType: "note",
                    content: "hello",
                    display: true,
                    details: { from: "test" }
                  }, {});
                  return { content: [{ type: "text", text: "ok" }], isError: false };
                }
              });
            }
            "#,
            )
            .expect("write extension entry");

            let provider = Arc::new(NoopProvider);
            let tools = ToolRegistry::new(&[], Path::new("."), None);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session = AgentSession::new(
                agent,
                Arc::clone(&session),
                false,
                ResolvedCompactionSettings::default(),
            );

            agent_session
                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
                .await
                .expect("enable extensions");

            let tool = agent_session
                .agent
                .tools
                .get("emit_message")
                .expect("emit_message registered");

            let _ = tool
                .execute("call-1", json!({}), None)
                .await
                .expect("execute tool");

            let cx = crate::agent_cx::AgentCx::for_request();
            let session_guard = session.lock(cx.cx()).await.expect("lock session");
            let messages = session_guard.to_messages_for_current_path();

            assert!(
                messages.iter().any(|msg| {
                    matches!(
                        msg,
                        Message::Custom(CustomMessage { custom_type, content, display, details, .. })
                            if custom_type == "note"
                                && content == "hello"
                                && *display
                                && details.as_ref().and_then(|v| v.get("from").and_then(Value::as_str)) == Some("test")
                    )
                }),
                "expected custom message to be persisted, got {messages:?}"
            );
        });
    }
3256
    /// Same as the synchronous case, but `pi.sendMessage` runs after an `await` inside the tool,
    /// and the custom message is still persisted.
    #[test]
    fn extension_send_message_persists_custom_message_entry_when_idle_after_await() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let temp_dir = tempfile::tempdir().expect("tempdir");
            let entry_path = temp_dir.path().join("ext.mjs");
            std::fs::write(
                &entry_path,
                r#"
            export default function init(pi) {
              pi.registerTool({
                name: "emit_message",
                label: "emit_message",
                description: "emit a custom message",
                parameters: { type: "object" },
                execute: async () => {
                  await Promise.resolve();
                  pi.sendMessage({
                    customType: "note",
                    content: "hello-after-await",
                    display: true,
                    details: { from: "test" }
                  }, {});
                  return { content: [{ type: "text", text: "ok" }], isError: false };
                }
              });
            }
            "#,
            )
            .expect("write extension entry");

            let provider = Arc::new(NoopProvider);
            let tools = ToolRegistry::new(&[], Path::new("."), None);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session = AgentSession::new(
                agent,
                Arc::clone(&session),
                false,
                ResolvedCompactionSettings::default(),
            );

            agent_session
                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
                .await
                .expect("enable extensions");

            let tool = agent_session
                .agent
                .tools
                .get("emit_message")
                .expect("emit_message registered");

            let _ = tool
                .execute("call-1", json!({}), None)
                .await
                .expect("execute tool");

            let cx = crate::agent_cx::AgentCx::for_request();
            let session_guard = session.lock(cx.cx()).await.expect("lock session");
            let messages = session_guard.to_messages_for_current_path();

            assert!(
                messages.iter().any(|msg| {
                    matches!(
                        msg,
                        Message::Custom(CustomMessage { custom_type, content, display, details, .. })
                            if custom_type == "note"
                                && content == "hello-after-await"
                                && *display
                                && details.as_ref().and_then(|v| v.get("from").and_then(Value::as_str)) == Some("test")
                    )
                }),
                "expected custom message to be persisted, got {messages:?}"
            );
        });
    }
3337
    /// With the session lock held elsewhere and a cancelled ambient `Cx` installed,
    /// `send_message(NextTurn)` fails fast with "mutex lock cancelled" and appends nothing.
    #[test]
    fn agent_host_actions_send_message_inherits_cancelled_context_when_locked() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let actions = AgentSessionHostActions {
                session: Arc::clone(&session),
                injected: Arc::new(StdMutex::new(ExtensionInjectedQueue::default())),
                is_streaming: Arc::new(AtomicBool::new(false)),
                is_turn_active: Arc::new(AtomicBool::new(false)),
                pending_idle_actions: Arc::new(StdMutex::new(VecDeque::new())),
            };

            // Hold the session lock so the append inside `send_message` cannot acquire it.
            let hold_cx = crate::agent_cx::AgentCx::for_request();
            let held_guard = session.lock(hold_cx.cx()).await.expect("lock session");

            // Install a cancel-requested context as the current one for the helper to inherit.
            let ambient_cx = asupersync::Cx::for_testing();
            ambient_cx.set_cancel_requested(true);
            let _current = asupersync::Cx::set_current(Some(ambient_cx));
            // The 100 ms timeout guards against the helper waiting on the held lock.
            let inner = asupersync::time::timeout(
                asupersync::time::wall_now(),
                Duration::from_millis(100),
                actions.send_message(ExtensionSendMessage {
                    extension_id: Some("ext".to_string()),
                    custom_type: "note".to_string(),
                    content: "blocked".to_string(),
                    display: false,
                    details: None,
                    deliver_as: Some(ExtensionDeliverAs::NextTurn),
                    trigger_turn: false,
                }),
            )
            .await;
            let outcome = inner.expect("cancelled helper should finish before timeout");
            let err = outcome.expect_err("session append should fail under inherited cancellation");
            assert!(
                err.to_string().contains("mutex lock cancelled"),
                "unexpected error: {err}"
            );

            drop(held_guard);

            let cx = crate::agent_cx::AgentCx::for_request();
            let guard = session.lock(cx.cx()).await.expect("lock session");
            assert!(
                guard.to_messages_for_current_path().is_empty(),
                "cancelled send_message should not append a message"
            );
        });
    }
3391
    /// An extension command that sends a message with `triggerTurn` while the agent is idle
    /// leaves both the custom message and the provider's reply in the session.
    #[test]
    fn extension_command_send_message_trigger_turn_runs_agent_turn_when_idle() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let temp_dir = tempfile::tempdir().expect("tempdir");
            let entry_path = temp_dir.path().join("ext.mjs");
            std::fs::write(
                &entry_path,
                r#"
            export default function init(pi) {
              pi.registerCommand("emit-now", {
                description: "emit a custom message and trigger a turn",
                handler: async () => {
                  await pi.events("sendMessage", {
                    message: {
                      customType: "note",
                      content: "turn-now",
                      display: true
                    },
                    options: {
                      deliverAs: "steer",
                      triggerTurn: true
                    }
                  });
                  return "queued";
                }
              });
            }
            "#,
            )
            .expect("write extension entry");

            let provider = Arc::new(IdleCommandProvider);
            let tools = ToolRegistry::new(&[], Path::new("."), None);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session = AgentSession::new(
                agent,
                Arc::clone(&session),
                false,
                ResolvedCompactionSettings::default(),
            );

            agent_session
                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
                .await
                .expect("enable extensions");

            let value = agent_session
                .execute_extension_command("emit-now", "", 5_000, |_| {})
                .await
                .expect("execute extension command");
            assert_eq!(value.as_str(), Some("queued"));

            let cx = crate::agent_cx::AgentCx::for_request();
            let session_guard = session.lock(cx.cx()).await.expect("lock session");
            let messages = session_guard.to_messages_for_current_path();

            assert!(
                messages.iter().any(|msg| {
                    matches!(
                        msg,
                        Message::Custom(CustomMessage { custom_type, content, .. })
                            if custom_type == "note" && content == "turn-now"
                    )
                }),
                "expected custom message prompt in session, got {messages:?}"
            );
            // "resumed-response-0" is the fixed reply produced by `IdleCommandProvider`.
            assert!(
                messages.iter().any(|msg| {
                    matches!(
                        msg,
                        Message::Assistant(assistant)
                            if assistant.content.iter().any(|block| matches!(
                                block,
                                ContentBlock::Text(TextContent { text, .. })
                                    if text == "resumed-response-0"
                            ))
                    )
                }),
                "expected assistant response after triggered turn, got {messages:?}"
            );
        });
    }
3479
    /// `get_state` reports the model header, thinking level, streaming/compacting flags, queue
    /// modes, the auto-compaction setting and the message count.
    #[test]
    fn agent_extension_session_get_state_reports_agent_runtime_state() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let mut session = Session::in_memory();
            session.set_model_header(
                Some("test-provider".to_string()),
                Some("test-model".to_string()),
                Some("high".to_string()),
            );
            session.append_message(crate::session::SessionMessage::User {
                content: UserContent::Text("hello".to_string()),
                timestamp: Some(1),
            });
            let session = Arc::new(Mutex::new(session));

            let extension_session = AgentExtensionSession {
                handle: SessionHandle(Arc::clone(&session)),
                is_streaming: Arc::new(AtomicBool::new(true)),
                is_compacting: Arc::new(AtomicBool::new(true)),
                queue_modes: Arc::new(StdMutex::new(ExtensionQueueModeState::new(
                    QueueMode::All,
                    QueueMode::OneAtATime,
                ))),
                auto_compaction_enabled: true,
            };

            let state = <AgentExtensionSession as crate::extensions::ExtensionSession>::get_state(
                &extension_session,
            )
            .await;

            assert_eq!(state["model"]["provider"], "test-provider");
            assert_eq!(state["model"]["id"], "test-model");
            assert_eq!(state["thinkingLevel"], "high");
            assert_eq!(state["isStreaming"], true);
            assert_eq!(state["isCompacting"], true);
            assert_eq!(state["steeringMode"], "all");
            assert_eq!(state["followUpMode"], "one-at-a-time");
            assert_eq!(state["autoCompactionEnabled"], true);
            assert_eq!(state["messageCount"], 1);
        });
    }
3526
/// `get_state` must report the model and thinking level recorded on the
/// currently selected branch, not whatever the session header was last set to.
#[test]
fn agent_extension_session_get_state_uses_branch_local_model_and_thinking() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    rt.block_on(async {
        let mut inner = Session::in_memory();
        let root_id = inner.append_message(crate::session::SessionMessage::User {
            content: UserContent::Text(String::from("root")),
            timestamp: Some(1),
        });

        // Branch A: openai / gpt-4o with "low" thinking.
        inner.append_model_change(String::from("openai"), String::from("gpt-4o"));
        let branch_a_tip = inner.append_thinking_level_change(String::from("low"));
        inner.set_model_header(
            Some(String::from("openai")),
            Some(String::from("gpt-4o")),
            Some(String::from("low")),
        );

        // Branch B (forked from the root): anthropic with "high" thinking.
        // This also rewrites the header, so the header alone now points at B.
        assert!(inner.create_branch_from(&root_id));
        inner.append_model_change(
            String::from("anthropic"),
            String::from("claude-sonnet-4-5"),
        );
        inner.append_thinking_level_change(String::from("high"));
        inner.set_model_header(
            Some(String::from("anthropic")),
            Some(String::from("claude-sonnet-4-5")),
            Some(String::from("high")),
        );

        // Return to branch A before handing the session to the extension view.
        assert!(inner.navigate_to(&branch_a_tip));
        let shared = Arc::new(Mutex::new(inner));

        let modes = ExtensionQueueModeState::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
        let ext_session = AgentExtensionSession {
            handle: SessionHandle(Arc::clone(&shared)),
            is_streaming: Arc::new(AtomicBool::new(false)),
            is_compacting: Arc::new(AtomicBool::new(false)),
            queue_modes: Arc::new(StdMutex::new(modes)),
            auto_compaction_enabled: false,
        };

        let snapshot =
            <AgentExtensionSession as crate::extensions::ExtensionSession>::get_state(&ext_session)
                .await;

        let model = &snapshot["model"];
        assert_eq!(model["provider"], "openai");
        assert_eq!(model["id"], "gpt-4o");
        assert_eq!(snapshot["thinkingLevel"], "low");
    });
}
3580
/// `AgentSession::set_queue_modes` must update the agent itself, the
/// extension-visible queue-mode mirror, and the extension-injected message
/// queue, so injected follow-ups are drained according to the new mode.
#[test]
fn agent_session_set_queue_modes_updates_extension_delivery_state() {
    let agent = Agent::new(
        Arc::new(NoopProvider),
        ToolRegistry::new(&[], Path::new("."), None),
        AgentConfig::default(),
    );
    let mut agent_session = AgentSession::new(
        agent,
        Arc::new(Mutex::new(Session::in_memory())),
        false,
        ResolvedCompactionSettings::default(),
    );

    // Both extension-side structures start in one-at-a-time mode.
    let mode_mirror = Arc::new(StdMutex::new(ExtensionQueueModeState::new(
        QueueMode::OneAtATime,
        QueueMode::OneAtATime,
    )));
    let pending = Arc::new(StdMutex::new(ExtensionInjectedQueue::new(
        QueueMode::OneAtATime,
        QueueMode::OneAtATime,
    )));
    agent_session.extension_queue_modes = Some(Arc::clone(&mode_mirror));
    agent_session.extension_injected_queue = Some(Arc::clone(&pending));

    agent_session.set_queue_modes(QueueMode::All, QueueMode::All);

    assert_eq!(
        agent_session.agent.queue_modes(),
        (QueueMode::All, QueueMode::All)
    );
    {
        let mirrored = mode_mirror.lock().expect("lock queue mode mirror");
        assert_eq!(mirrored.steering_mode, QueueMode::All);
        assert_eq!(mirrored.follow_up_mode, QueueMode::All);
    }

    // In "all" mode a single pop should drain both queued follow-ups.
    let drained_count = {
        let mut queue = pending.lock().expect("lock injected queue");
        for text in ["first", "second"] {
            queue.push_follow_up(Message::User(UserMessage {
                content: UserContent::Text(text.to_string()),
                timestamp: 0,
            }));
        }
        queue.pop_follow_up().len()
    };
    assert_eq!(
        drained_count, 2,
        "updated queue modes should apply to extension-injected follow-ups"
    );
}
3629
/// An extension command that calls `sendUserMessage` while the agent is idle
/// should record the user message and then run a turn, so the provider's
/// reply (`resumed-response-0`) also appears in the session.
#[test]
fn extension_command_send_user_message_runs_agent_turn_when_idle() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    rt.block_on(async {
        let scratch = tempfile::tempdir().expect("tempdir");
        let script_path = scratch.path().join("ext.mjs");
        let script = r#"
export default function init(pi) {
pi.registerCommand("inject-user", {
description: "inject a user message",
handler: async () => {
await pi.events("sendUserMessage", {
text: "Please review the changes"
});
return "queued";
}
});
}
"#;
        std::fs::write(&script_path, script).expect("write extension entry");

        let agent = Agent::new(
            Arc::new(IdleCommandProvider),
            ToolRegistry::new(&[], Path::new("."), None),
            AgentConfig::default(),
        );
        let shared = Arc::new(Mutex::new(Session::in_memory()));
        let mut agent_session = AgentSession::new(
            agent,
            Arc::clone(&shared),
            false,
            ResolvedCompactionSettings::default(),
        );

        agent_session
            .enable_extensions(&[], scratch.path(), None, &[script_path])
            .await
            .expect("enable extensions");

        let reply = agent_session
            .execute_extension_command("inject-user", "", 5_000, |_| {})
            .await
            .expect("execute extension command");
        assert_eq!(reply.as_str(), Some("queued"));

        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = shared.lock(cx.cx()).await.expect("lock session");
        let messages = guard.to_messages_for_current_path();

        let saw_injected_user = messages.iter().any(|msg| match msg {
            Message::User(UserMessage {
                content: UserContent::Text(text),
                ..
            }) => text == "Please review the changes",
            _ => false,
        });
        assert!(
            saw_injected_user,
            "expected injected user message in session, got {messages:?}"
        );

        let saw_assistant_reply = messages.iter().any(|msg| match msg {
            Message::Assistant(assistant) => assistant.content.iter().any(|block| {
                matches!(
                    block,
                    ContentBlock::Text(TextContent { text, .. }) if text == "resumed-response-0"
                )
            }),
            _ => false,
        });
        assert!(
            saw_assistant_reply,
            "expected assistant response after injected user turn, got {messages:?}"
        );
    });
}
3711
/// A `tool_call` hook that injects a user message with `deliverAs: "steer"`
/// should cause the remaining tool calls of that turn to be skipped, so the
/// counting tool ends up executing exactly once.
#[test]
fn send_user_message_steer_skips_remaining_tools() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    rt.block_on(async {
        let scratch = tempfile::tempdir().expect("tempdir");
        let script_path = scratch.path().join("ext.mjs");
        // The hook fires `sendUserMessage` only on the first count_tool call.
        let script = r#"
export default function init(pi) {
let sent = false;
pi.on("tool_call", async (event) => {
if (sent) return {};
if (event && event.toolName === "count_tool") {
sent = true;
await pi.events("sendUserMessage", {
text: "steer-now",
options: { deliverAs: "steer" }
});
}
return {};
});
}
"#;
        std::fs::write(&script_path, script).expect("write extension entry");

        let model_stub = Arc::new(ToolUseProvider::new());
        let invocations = Arc::new(AtomicUsize::new(0));
        let registry = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&invocations),
        })]);
        let agent = Agent::new(model_stub, registry, AgentConfig::default());
        let mut agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );

        agent_session
            .enable_extensions(&[], scratch.path(), None, &[script_path])
            .await
            .expect("enable extensions");

        let _ = agent_session
            .run_text(String::from("go"), |_| {})
            .await
            .expect("run_text");

        let executed = invocations.load(Ordering::SeqCst);
        assert_eq!(executed, 1);
    });
}
3766
/// Unlike a steer, a message injected with `deliverAs: "followUp"` must not
/// interrupt the current turn, so the counting tool executes twice.
#[test]
fn send_user_message_follow_up_does_not_skip_tools() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    rt.block_on(async {
        let scratch = tempfile::tempdir().expect("tempdir");
        let script_path = scratch.path().join("ext.mjs");
        // The hook fires `sendUserMessage` only on the first count_tool call.
        let script = r#"
export default function init(pi) {
let sent = false;
pi.on("tool_call", async (event) => {
if (sent) return {};
if (event && event.toolName === "count_tool") {
sent = true;
await pi.events("sendUserMessage", {
text: "follow-up",
options: { deliverAs: "followUp" }
});
}
return {};
});
}
"#;
        std::fs::write(&script_path, script).expect("write extension entry");

        let model_stub = Arc::new(ToolUseProvider::new());
        let invocations = Arc::new(AtomicUsize::new(0));
        let registry = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&invocations),
        })]);
        let agent = Agent::new(model_stub, registry, AgentConfig::default());
        let mut agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );

        agent_session
            .enable_extensions(&[], scratch.path(), None, &[script_path])
            .await
            .expect("enable extensions");

        let _ = agent_session
            .run_text(String::from("go"), |_| {})
            .await
            .expect("run_text");

        let executed = invocations.load(Ordering::SeqCst);
        assert_eq!(executed, 2);
    });
}
3820
/// A `tool_call` hook returning `{ block: true, reason }` must stop the tool
/// from running and surface the reason as the tool's (error) text output.
#[test]
fn tool_call_hook_can_block_tool_execution() {
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let entry_path = temp_dir.path().join("ext.mjs");
        std::fs::write(
            &entry_path,
            r#"
export default function init(pi) {
pi.on("tool_call", async (event) => {
if (event && event.toolName === "count_tool") {
return { block: true, reason: "blocked in test" };
}
return {};
});
}
"#,
        )
        .expect("write extension entry");

        let provider = Arc::new(NoopProvider);
        let calls = Arc::new(AtomicUsize::new(0));
        let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&calls),
        })]);
        let agent = Agent::new(provider, tools, AgentConfig::default());
        let session = Arc::new(Mutex::new(Session::in_memory()));
        let mut agent_session =
            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

        agent_session
            .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
            .await
            .expect("enable extensions");

        let tool_call = ToolCall {
            id: "call-1".to_string(),
            name: "count_tool".to_string(),
            arguments: json!({}),
            thought_signature: None,
        };

        let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;

        assert!(is_error);
        assert!(output.is_error);
        // The blocked tool must never have been invoked.
        assert_eq!(calls.load(Ordering::SeqCst), 0);

        assert_eq!(output.details, None);
        // One pattern match both checks the shape and reads the text, so there
        // is no separate `matches!` check that can drift out of sync with it.
        if let [ContentBlock::Text(text)] = output.content.as_slice() {
            assert_eq!(text.text, "Tool execution blocked: blocked in test");
        } else {
            panic!("Expected text output, got {:?}", output.content);
        }
    });
}
3885
/// With the default configuration a throwing `tool_call` hook must not block
/// the tool: it still runs once and reports success.
#[test]
fn tool_call_hook_errors_fail_open() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    rt.block_on(async {
        let scratch = tempfile::tempdir().expect("tempdir");
        let script_path = scratch.path().join("ext.mjs");
        let script = r#"
export default function init(pi) {
pi.on("tool_call", async (_event) => {
throw new Error("boom");
});
}
"#;
        std::fs::write(&script_path, script).expect("write extension entry");

        let model_stub = Arc::new(NoopProvider);
        let invocations = Arc::new(AtomicUsize::new(0));
        let registry = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&invocations),
        })]);
        let agent = Agent::new(model_stub, registry, AgentConfig::default());
        let mut agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );

        agent_session
            .enable_extensions(&[], scratch.path(), None, &[script_path])
            .await
            .expect("enable extensions");

        let request = ToolCall {
            id: String::from("call-1"),
            name: String::from("count_tool"),
            arguments: json!({}),
            thought_signature: None,
        };
        let sink: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session.agent.execute_tool(request, sink).await;

        assert!(!is_error);
        assert!(!output.is_error);
        assert_eq!(invocations.load(Ordering::SeqCst), 1);
    });
}
3937
/// With `fail_closed_hooks` enabled, a throwing `tool_call` hook must block
/// the tool and report the generic "extension hook failed" reason.
#[test]
fn tool_call_hook_errors_fail_closed_when_configured() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    rt.block_on(async {
        let scratch = tempfile::tempdir().expect("tempdir");
        let script_path = scratch.path().join("ext.mjs");
        let script = r#"
export default function init(pi) {
pi.on("tool_call", async (_event) => {
throw new Error("boom");
});
}
"#;
        std::fs::write(&script_path, script).expect("write extension entry");

        let model_stub = Arc::new(NoopProvider);
        let invocations = Arc::new(AtomicUsize::new(0));
        let registry = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&invocations),
        })]);
        let strict_config = AgentConfig {
            fail_closed_hooks: true,
            ..AgentConfig::default()
        };
        let agent = Agent::new(model_stub, registry, strict_config);
        let mut agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );

        agent_session
            .enable_extensions(&[], scratch.path(), None, &[script_path])
            .await
            .expect("enable extensions");

        let request = ToolCall {
            id: String::from("call-1"),
            name: String::from("count_tool"),
            arguments: json!({}),
            thought_signature: None,
        };
        let sink: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session.agent.execute_tool(request, sink).await;

        assert!(is_error);
        assert!(output.is_error);
        assert_eq!(invocations.load(Ordering::SeqCst), 0);
        match output.content.as_slice() {
            [ContentBlock::Text(text)] => {
                assert_eq!(text.text, "Tool execution blocked: extension hook failed");
            }
            _ => panic!("Expected text output, got {:?}", output.content),
        }
    });
}
4001
/// An extension that registers no `tool_call` hook must leave tool execution
/// untouched: the tool runs once and succeeds.
#[test]
fn tool_call_hook_absent_allows_tool_execution() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    rt.block_on(async {
        let scratch = tempfile::tempdir().expect("tempdir");
        let script_path = scratch.path().join("ext.mjs");
        let script = r"
export default function init(_pi) {}
";
        std::fs::write(&script_path, script).expect("write extension entry");

        let model_stub = Arc::new(NoopProvider);
        let invocations = Arc::new(AtomicUsize::new(0));
        let registry = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&invocations),
        })]);
        let agent = Agent::new(model_stub, registry, AgentConfig::default());
        let mut agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );

        agent_session
            .enable_extensions(&[], scratch.path(), None, &[script_path])
            .await
            .expect("enable extensions");

        let request = ToolCall {
            id: String::from("call-1"),
            name: String::from("count_tool"),
            arguments: json!({}),
            thought_signature: None,
        };
        let sink: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session.agent.execute_tool(request, sink).await;

        assert!(!is_error);
        assert!(!output.is_error);
        assert_eq!(invocations.load(Ordering::SeqCst), 1);
    });
}
4049
/// A `tool_call` hook that returns an empty object must not block the tool:
/// it runs once and succeeds.
#[test]
fn tool_call_hook_returns_empty_allows_tool_execution() {
    let rt = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    rt.block_on(async {
        let scratch = tempfile::tempdir().expect("tempdir");
        let script_path = scratch.path().join("ext.mjs");
        let script = r#"
export default function init(pi) {
pi.on("tool_call", async (_event) => ({}));
}
"#;
        std::fs::write(&script_path, script).expect("write extension entry");

        let model_stub = Arc::new(NoopProvider);
        let invocations = Arc::new(AtomicUsize::new(0));
        let registry = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&invocations),
        })]);
        let agent = Agent::new(model_stub, registry, AgentConfig::default());
        let mut agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );

        agent_session
            .enable_extensions(&[], scratch.path(), None, &[script_path])
            .await
            .expect("enable extensions");

        let request = ToolCall {
            id: String::from("call-1"),
            name: String::from("count_tool"),
            arguments: json!({}),
            thought_signature: None,
        };
        let sink: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session.agent.execute_tool(request, sink).await;

        assert!(!is_error);
        assert!(!output.is_error);
        assert_eq!(invocations.load(Ordering::SeqCst), 1);
    });
}
4099
/// A `tool_call` hook must be able to block the built-in `bash` tool before
/// the command runs: no side-effect file is created and the block reason is
/// returned as the error text.
#[test]
fn tool_call_hook_can_block_bash_tool_execution() {
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let entry_path = temp_dir.path().join("ext.mjs");
        std::fs::write(
            &entry_path,
            r#"
export default function init(pi) {
pi.on("tool_call", async (event) => {
const name = event && event.toolName ? String(event.toolName) : "";
if (name === "bash") return { block: true, reason: "blocked bash in test" };
return {};
});
}
"#,
        )
        .expect("write extension entry");

        let provider = Arc::new(NoopProvider);
        let tools = ToolRegistry::new(&["bash"], temp_dir.path(), None);
        let agent = Agent::new(provider, tools, AgentConfig::default());
        let session = Arc::new(Mutex::new(Session::in_memory()));
        let mut agent_session =
            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

        agent_session
            .enable_extensions(&["bash"], temp_dir.path(), None, &[entry_path])
            .await
            .expect("enable extensions");

        // If the command ran, it would create `blocked.txt` in the temp dir.
        let tool_call = ToolCall {
            id: "call-1".to_string(),
            name: "bash".to_string(),
            arguments: json!({ "command": "printf 'hi' > blocked.txt" }),
            thought_signature: None,
        };

        let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;

        assert!(is_error);
        assert!(output.is_error);
        assert_eq!(output.details, None);
        assert!(
            !temp_dir.path().join("blocked.txt").exists(),
            "expected bash command not to run when blocked"
        );
        // Single pattern match: checks the shape and the text together.
        if let [ContentBlock::Text(text)] = output.content.as_slice() {
            assert_eq!(text.text, "Tool execution blocked: blocked bash in test");
        } else {
            panic!("Expected text output, got {:?}", output.content);
        }
    });
}
4162
/// A `tool_result` hook must be able to replace both the content and the
/// details of a successful tool result.
#[test]
fn tool_result_hook_can_modify_tool_output() {
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let entry_path = temp_dir.path().join("ext.mjs");
        std::fs::write(
            &entry_path,
            r#"
export default function init(pi) {
pi.on("tool_result", async (event) => {
if (event && event.toolName === "count_tool") {
return {
content: [{ type: "text", text: "modified" }],
details: { from: "tool_result" }
};
}
return {};
});
}
"#,
        )
        .expect("write extension entry");

        let provider = Arc::new(NoopProvider);
        let calls = Arc::new(AtomicUsize::new(0));
        let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&calls),
        })]);
        let agent = Agent::new(provider, tools, AgentConfig::default());
        let session = Arc::new(Mutex::new(Session::in_memory()));
        let mut agent_session =
            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

        agent_session
            .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
            .await
            .expect("enable extensions");

        let tool_call = ToolCall {
            id: "call-1".to_string(),
            name: "count_tool".to_string(),
            arguments: json!({}),
            thought_signature: None,
        };

        let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;

        assert!(!is_error);
        assert!(!output.is_error);
        // The tool itself still ran; only its reported result was rewritten.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(output.details, Some(json!({ "from": "tool_result" })));

        // Single pattern match: checks the shape and the text together.
        if let [ContentBlock::Text(text)] = output.content.as_slice() {
            assert_eq!(text.text, "modified");
        } else {
            panic!("Expected text output, got {:?}", output.content);
        }
    });
}
4230
/// The `tool_result` hook also runs for calls to unknown tools, sees
/// `isError`, and can override the error content and details. The result
/// stays flagged as an error.
#[test]
fn tool_result_hook_can_modify_tool_not_found_error() {
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let entry_path = temp_dir.path().join("ext.mjs");
        std::fs::write(
            &entry_path,
            r#"
export default function init(pi) {
pi.on("tool_result", async (event) => {
if (event && event.toolName === "missing_tool" && event.isError) {
return {
content: [{ type: "text", text: "overridden" }],
details: { handled: true }
};
}
return {};
});
}
"#,
        )
        .expect("write extension entry");

        let provider = Arc::new(NoopProvider);
        // Empty registry: `missing_tool` cannot be resolved.
        let tools = ToolRegistry::from_tools(Vec::new());
        let agent = Agent::new(provider, tools, AgentConfig::default());
        let session = Arc::new(Mutex::new(Session::in_memory()));
        let mut agent_session =
            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

        agent_session
            .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
            .await
            .expect("enable extensions");

        let tool_call = ToolCall {
            id: "call-1".to_string(),
            name: "missing_tool".to_string(),
            arguments: json!({}),
            thought_signature: None,
        };

        let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;

        assert!(is_error);
        assert!(output.is_error);
        assert_eq!(output.details, Some(json!({ "handled": true })));

        // Single pattern match: checks the shape and the text together.
        if let [ContentBlock::Text(text)] = output.content.as_slice() {
            assert_eq!(text.text, "overridden");
        } else {
            panic!("Expected text output, got {:?}", output.content);
        }
    });
}
4294
/// A throwing `tool_result` hook must not alter the tool's result. The
/// original successful `"ok"` output with no details is returned unchanged.
#[test]
fn tool_result_hook_errors_fail_open() {
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let entry_path = temp_dir.path().join("ext.mjs");
        std::fs::write(
            &entry_path,
            r#"
export default function init(pi) {
pi.on("tool_result", async (_event) => {
throw new Error("boom");
});
}
"#,
        )
        .expect("write extension entry");

        let provider = Arc::new(NoopProvider);
        let calls = Arc::new(AtomicUsize::new(0));
        let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&calls),
        })]);
        let agent = Agent::new(provider, tools, AgentConfig::default());
        let session = Arc::new(Mutex::new(Session::in_memory()));
        let mut agent_session =
            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

        agent_session
            .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
            .await
            .expect("enable extensions");

        let tool_call = ToolCall {
            id: "call-1".to_string(),
            name: "count_tool".to_string(),
            arguments: json!({}),
            thought_signature: None,
        };

        let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;

        assert!(!is_error);
        assert!(!output.is_error);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        assert_eq!(output.details, None);
        // Single pattern match: checks the shape and the text together.
        if let [ContentBlock::Text(text)] = output.content.as_slice() {
            assert_eq!(text.text, "ok");
        } else {
            panic!("Expected text output, got {:?}", output.content);
        }
    });
}
4356
/// When a `tool_call` hook blocks a tool, the `tool_result` hook still runs
/// on the resulting error and can replace its content. The tool itself is
/// never executed.
#[test]
fn tool_result_hook_runs_on_blocked_tool_call() {
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let entry_path = temp_dir.path().join("ext.mjs");
        std::fs::write(
            &entry_path,
            r#"
export default function init(pi) {
pi.on("tool_call", async (event) => {
if (event && event.toolName === "count_tool") {
return { block: true, reason: "blocked in test" };
}
return {};
});

pi.on("tool_result", async (event) => {
if (event && event.toolName === "count_tool" && event.isError) {
return { content: [{ type: "text", text: "override" }] };
}
return {};
});
}
"#,
        )
        .expect("write extension entry");

        let provider = Arc::new(NoopProvider);
        let calls = Arc::new(AtomicUsize::new(0));
        let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&calls),
        })]);
        let agent = Agent::new(provider, tools, AgentConfig::default());
        let session = Arc::new(Mutex::new(Session::in_memory()));
        let mut agent_session =
            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

        agent_session
            .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
            .await
            .expect("enable extensions");

        let tool_call = ToolCall {
            id: "call-1".to_string(),
            name: "count_tool".to_string(),
            arguments: json!({}),
            thought_signature: None,
        };

        let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;

        assert!(is_error);
        assert!(output.is_error);
        assert_eq!(calls.load(Ordering::SeqCst), 0);

        // Single pattern match: checks the shape and the text together.
        if let [ContentBlock::Text(text)] = output.content.as_slice() {
            assert_eq!(text.text, "override");
        } else {
            panic!("Expected text output, got {:?}", output.content);
        }
    });
}
4427}
4428
4429#[cfg(test)]
4430mod abort_tests {
4431 use super::*;
4432 use crate::session::Session;
4433 use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
4434 use asupersync::runtime::RuntimeBuilder;
4435 use async_trait::async_trait;
4436 use futures::Stream;
4437 use serde_json::json;
4438 use std::path::Path;
4439 use std::pin::Pin;
4440 use std::sync::Mutex as StdMutex;
4441 use std::sync::atomic::AtomicUsize;
4442 use std::task::{Context as TaskContext, Poll};
4443
4444 struct StartThenPending {
4445 start: Option<StreamEvent>,
4446 }
4447
4448 impl Stream for StartThenPending {
4449 type Item = crate::error::Result<StreamEvent>;
4450
4451 fn poll_next(
4452 mut self: Pin<&mut Self>,
4453 _cx: &mut TaskContext<'_>,
4454 ) -> Poll<Option<Self::Item>> {
4455 if let Some(event) = self.start.take() {
4456 return Poll::Ready(Some(Ok(event)));
4457 }
4458 Poll::Pending
4459 }
4460 }
4461
4462 #[derive(Debug)]
4463 struct HangingProvider;
4464
4465 #[async_trait]
4466 #[allow(clippy::unnecessary_literal_bound)]
4467 impl Provider for HangingProvider {
4468 fn name(&self) -> &str {
4469 "test-provider"
4470 }
4471
4472 fn api(&self) -> &str {
4473 "test-api"
4474 }
4475
4476 fn model_id(&self) -> &str {
4477 "test-model"
4478 }
4479
4480 async fn stream(
4481 &self,
4482 _context: &Context<'_>,
4483 _options: &StreamOptions,
4484 ) -> crate::error::Result<
4485 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
4486 > {
4487 let partial = AssistantMessage {
4488 content: Vec::new(),
4489 api: self.api().to_string(),
4490 provider: self.name().to_string(),
4491 model: self.model_id().to_string(),
4492 usage: Usage::default(),
4493 stop_reason: StopReason::Stop,
4494 error_message: None,
4495 timestamp: 0,
4496 };
4497
4498 Ok(Box::pin(StartThenPending {
4499 start: Some(StreamEvent::Start { partial }),
4500 }))
4501 }
4502 }
4503
4504 #[derive(Debug)]
4505 struct CountingProvider {
4506 calls: Arc<std::sync::atomic::AtomicUsize>,
4507 }
4508
4509 #[async_trait]
4510 #[allow(clippy::unnecessary_literal_bound)]
4511 impl Provider for CountingProvider {
4512 fn name(&self) -> &str {
4513 "test-provider"
4514 }
4515
4516 fn api(&self) -> &str {
4517 "test-api"
4518 }
4519
4520 fn model_id(&self) -> &str {
4521 "test-model"
4522 }
4523
4524 async fn stream(
4525 &self,
4526 _context: &Context<'_>,
4527 _options: &StreamOptions,
4528 ) -> crate::error::Result<
4529 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
4530 > {
4531 self.calls.fetch_add(1, Ordering::SeqCst);
4532 Ok(Box::pin(futures::stream::empty()))
4533 }
4534 }
4535
4536 #[derive(Debug)]
4537 struct PhasedProvider {
4538 pending_calls: usize,
4539 calls: AtomicUsize,
4540 }
4541
4542 impl PhasedProvider {
4543 const fn new(pending_calls: usize) -> Self {
4544 Self {
4545 pending_calls,
4546 calls: AtomicUsize::new(0),
4547 }
4548 }
4549
4550 fn base_message() -> AssistantMessage {
4551 AssistantMessage {
4552 content: Vec::new(),
4553 api: "test-api".to_string(),
4554 provider: "test-provider".to_string(),
4555 model: "test-model".to_string(),
4556 usage: Usage::default(),
4557 stop_reason: StopReason::Stop,
4558 error_message: None,
4559 timestamp: 0,
4560 }
4561 }
4562 }
4563
4564 #[async_trait]
4565 #[allow(clippy::unnecessary_literal_bound)]
4566 impl Provider for PhasedProvider {
4567 fn name(&self) -> &str {
4568 "test-provider"
4569 }
4570
4571 fn api(&self) -> &str {
4572 "test-api"
4573 }
4574
4575 fn model_id(&self) -> &str {
4576 "test-model"
4577 }
4578
4579 async fn stream(
4580 &self,
4581 _context: &Context<'_>,
4582 _options: &StreamOptions,
4583 ) -> crate::error::Result<
4584 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
4585 > {
4586 let call = self.calls.fetch_add(1, Ordering::SeqCst);
4587 if call < self.pending_calls {
4588 return Ok(Box::pin(StartThenPending {
4589 start: Some(StreamEvent::Start {
4590 partial: Self::base_message(),
4591 }),
4592 }));
4593 }
4594
4595 let partial = Self::base_message();
4596 let mut done = Self::base_message();
4597 done.content = vec![ContentBlock::Text(TextContent::new(format!(
4598 "resumed-response-{call}"
4599 )))];
4600
4601 Ok(Box::pin(futures::stream::iter(vec![
4602 Ok(StreamEvent::Start { partial }),
4603 Ok(StreamEvent::Done {
4604 reason: StopReason::Stop,
4605 message: done,
4606 }),
4607 ])))
4608 }
4609 }
4610
4611 #[derive(Debug)]
4612 struct ToolCallProvider;
4613
4614 #[async_trait]
4615 #[allow(clippy::unnecessary_literal_bound)]
4616 impl Provider for ToolCallProvider {
4617 fn name(&self) -> &str {
4618 "test-provider"
4619 }
4620
4621 fn api(&self) -> &str {
4622 "test-api"
4623 }
4624
4625 fn model_id(&self) -> &str {
4626 "test-model"
4627 }
4628
4629 async fn stream(
4630 &self,
4631 _context: &Context<'_>,
4632 _options: &StreamOptions,
4633 ) -> crate::error::Result<
4634 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
4635 > {
4636 let message = AssistantMessage {
4637 content: vec![ContentBlock::ToolCall(ToolCall {
4638 id: "call-1".to_string(),
4639 name: "hanging_tool".to_string(),
4640 arguments: json!({}),
4641 thought_signature: None,
4642 })],
4643 api: "test-api".to_string(),
4644 provider: "test-provider".to_string(),
4645 model: "test-model".to_string(),
4646 usage: Usage::default(),
4647 stop_reason: StopReason::ToolUse,
4648 error_message: None,
4649 timestamp: 0,
4650 };
4651
4652 Ok(Box::pin(futures::stream::iter(vec![Ok(
4653 StreamEvent::Done {
4654 reason: StopReason::ToolUse,
4655 message,
4656 },
4657 )])))
4658 }
4659 }
4660
4661 #[derive(Debug)]
4662 struct HangingTool;
4663
4664 #[async_trait]
4665 #[allow(clippy::unnecessary_literal_bound)]
4666 impl Tool for HangingTool {
4667 fn name(&self) -> &str {
4668 "hanging_tool"
4669 }
4670
4671 fn label(&self) -> &str {
4672 "Hanging Tool"
4673 }
4674
4675 fn description(&self) -> &str {
4676 "Never completes unless aborted by the host"
4677 }
4678
4679 fn parameters(&self) -> serde_json::Value {
4680 json!({
4681 "type": "object",
4682 "properties": {},
4683 "additionalProperties": false
4684 })
4685 }
4686
4687 async fn execute(
4688 &self,
4689 _tool_call_id: &str,
4690 _input: serde_json::Value,
4691 _on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
4692 ) -> crate::error::Result<ToolOutput> {
4693 futures::future::pending::<()>().await;
4694 unreachable!("hanging tool should be aborted by the agent")
4695 }
4696 }
4697
4698 fn event_tag(event: &AgentEvent) -> &'static str {
4699 match event {
4700 AgentEvent::AgentStart { .. } => "agent_start",
4701 AgentEvent::AgentEnd { error, .. } => {
4702 if error.as_deref() == Some("Aborted") {
4703 "agent_end_aborted"
4704 } else {
4705 "agent_end"
4706 }
4707 }
4708 AgentEvent::TurnStart { .. } => "turn_start",
4709 AgentEvent::TurnEnd { .. } => "turn_end",
4710 AgentEvent::MessageStart { .. } => "message_start",
4711 AgentEvent::MessageUpdate {
4712 assistant_message_event,
4713 ..
4714 } => match &assistant_message_event {
4715 AssistantMessageEvent::Error {
4716 reason: StopReason::Aborted,
4717 ..
4718 } => "assistant_error_aborted",
4719 AssistantMessageEvent::Done { .. } => "assistant_done",
4720 _ => "assistant_update",
4721 },
4722 AgentEvent::MessageEnd { .. } => "message_end",
4723 AgentEvent::ToolExecutionStart { .. } => "tool_start",
4724 AgentEvent::ToolExecutionUpdate { .. } => "tool_update",
4725 AgentEvent::ToolExecutionEnd { .. } => "tool_end",
4726 AgentEvent::AutoCompactionStart { .. } => "auto_compaction_start",
4727 AgentEvent::AutoCompactionEnd { .. } => "auto_compaction_end",
4728 AgentEvent::AutoRetryStart { .. } => "auto_retry_start",
4729 AgentEvent::AutoRetryEnd { .. } => "auto_retry_end",
4730 AgentEvent::ExtensionError { .. } => "extension_error",
4731 }
4732 }
4733
4734 fn assert_abort_resume_message_sequence(persisted: &[Message]) {
4735 assert_eq!(
4736 persisted.len(),
4737 6,
4738 "expected three user+assistant pairs, got: {persisted:?}"
4739 );
4740
4741 let assistant_states = persisted
4742 .iter()
4743 .filter_map(|message| match message {
4744 Message::Assistant(assistant) => Some(assistant.stop_reason),
4745 _ => None,
4746 })
4747 .collect::<Vec<_>>();
4748 assert_eq!(
4749 assistant_states,
4750 vec![StopReason::Aborted, StopReason::Aborted, StopReason::Stop]
4751 );
4752 }
4753
4754 fn assert_abort_resume_timeline_boundaries(timeline: &[String]) {
4755 assert!(
4756 timeline
4757 .iter()
4758 .any(|event| event == "run0:agent_end_aborted"),
4759 "missing aborted boundary for first run: {timeline:?}"
4760 );
4761 assert!(
4762 timeline
4763 .iter()
4764 .any(|event| event == "run1:agent_end_aborted"),
4765 "missing aborted boundary for second run: {timeline:?}"
4766 );
4767 assert!(
4768 timeline.iter().any(|event| event == "run2:agent_end"),
4769 "missing successful boundary for resumed run: {timeline:?}"
4770 );
4771 }
4772
    /// Aborting while the provider stream is in flight (after the assistant
    /// `MessageStart` was emitted) ends the run with an assistant message whose
    /// stop reason is `Aborted` and whose error message is `"Aborted"`.
    #[test]
    fn abort_interrupts_in_flight_stream() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        let handle = runtime.handle();

        // The `notified()` future is created here, before the run is spawned.
        // NOTE(review): presumably this avoids missing a `notify_one` issued before
        // the waiter exists; confirm against asupersync `Notify` semantics.
        let started = Arc::new(Notify::new());
        let started_wait = started.notified();

        let (abort_handle, abort_signal) = AbortHandle::new();

        let provider = Arc::new(HangingProvider);
        let tools = ToolRegistry::new(&[], Path::new("."), None);
        let agent = Agent::new(provider, tools, AgentConfig::default());
        let session = Arc::new(Mutex::new(Session::in_memory()));
        let mut agent_session =
            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

        let started_tx = Arc::clone(&started);
        let join = handle.spawn(async move {
            agent_session
                .run_text_with_abort("hello".to_string(), Some(abort_signal), move |event| {
                    // Signal once the assistant message has started streaming.
                    if matches!(
                        event,
                        AgentEvent::MessageStart {
                            message: Message::Assistant(_)
                        }
                    ) {
                        started_tx.notify_one();
                    }
                })
                .await
        });

        runtime.block_on(async move {
            // Abort only after streaming has begun, so the abort hits an in-flight stream.
            started_wait.await;
            abort_handle.abort();

            let message = join.await.expect("run_text_with_abort");
            assert_eq!(message.stop_reason, StopReason::Aborted);
            assert_eq!(message.error_message.as_deref(), Some("Aborted"));
        });
    }
4817
    /// With no explicit abort signal, requesting cancellation on the ambient
    /// `Cx` while the stream is in flight must end the run as `Aborted`.
    #[test]
    fn ambient_cancellation_interrupts_in_flight_stream() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            // Cross-thread signal: the event callback reports that streaming began.
            let (started_tx, started_rx) = std::sync::mpsc::channel();

            let provider = Arc::new(HangingProvider);
            let tools = ToolRegistry::new(&[], Path::new("."), None);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session =
                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

            // Install a test `Cx` as the current context. A clone is kept so
            // another thread can request cancellation on it.
            // NOTE(review): `_current` is presumably a guard scoping the
            // installation; confirm in asupersync.
            let ambient_cx = asupersync::Cx::for_testing();
            let cancel_cx = ambient_cx.clone();
            let _current = asupersync::Cx::set_current(Some(ambient_cx));

            // Wait (up to 1s) for the stream to start, then request cancellation.
            let cancel_thread = std::thread::spawn(move || {
                started_rx
                    .recv_timeout(std::time::Duration::from_secs(1))
                    .expect("stream start");
                cancel_cx.set_cancel_requested(true);
            });

            let run = agent_session.run_text_with_abort("hello".to_string(), None, move |event| {
                if matches!(
                    event,
                    AgentEvent::MessageStart {
                        message: Message::Assistant(_)
                    }
                ) {
                    let _ = started_tx.send(());
                }
            });
            futures::pin_mut!(run);

            // Bound the run so a missed cancellation fails the test instead of
            // hanging it.
            let message = asupersync::time::timeout(
                asupersync::time::wall_now(),
                std::time::Duration::from_secs(1),
                run,
            )
            .await
            .expect("ambient cancellation should finish before timeout")
            .expect("run_text_with_abort");

            cancel_thread.join().expect("cancel thread");

            assert_eq!(message.stop_reason, StopReason::Aborted);
            assert_eq!(message.error_message.as_deref(), Some("Aborted"));
        });
    }
4872
4873 #[test]
4874 fn abort_before_run_skips_provider_stream_call() {
4875 let runtime = RuntimeBuilder::current_thread()
4876 .build()
4877 .expect("runtime build");
4878
4879 let calls = Arc::new(std::sync::atomic::AtomicUsize::new(0));
4880 let provider = Arc::new(CountingProvider {
4881 calls: Arc::clone(&calls),
4882 });
4883 let tools = ToolRegistry::new(&[], Path::new("."), None);
4884 let agent = Agent::new(provider, tools, AgentConfig::default());
4885 let session = Arc::new(Mutex::new(Session::in_memory()));
4886 let mut agent_session =
4887 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
4888
4889 let (abort_handle, abort_signal) = AbortHandle::new();
4890 abort_handle.abort();
4891
4892 runtime.block_on(async move {
4893 let message = agent_session
4894 .run_text_with_abort("hello".to_string(), Some(abort_signal), |_| {})
4895 .await
4896 .expect("run_text_with_abort");
4897 assert_eq!(message.stop_reason, StopReason::Aborted);
4898 assert_eq!(calls.load(Ordering::SeqCst), 0);
4899 });
4900 }
4901
    /// After one aborted run, a follow-up run must succeed. The session must then
    /// hold: user, aborted assistant, user, successful assistant.
    #[test]
    fn abort_then_resume_preserves_session_history() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        let handle = runtime.handle();

        runtime.block_on(async move {
            // First `stream` call takes the pending path; later calls complete.
            let provider = Arc::new(PhasedProvider::new(1));
            let tools = ToolRegistry::new(&[], Path::new("."), None);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session = AgentSession::new(
                agent,
                Arc::clone(&session),
                false,
                ResolvedCompactionSettings::default(),
            );

            // Background task that aborts once the assistant message starts streaming.
            let started = Arc::new(Notify::new());
            let (abort_handle, abort_signal) = AbortHandle::new();
            let started_for_abort = Arc::clone(&started);
            let abort_join = handle.spawn(async move {
                started_for_abort.notified().await;
                abort_handle.abort();
            });

            let aborted = agent_session
                .run_text_with_abort("first".to_string(), Some(abort_signal), {
                    let started = Arc::clone(&started);
                    move |event| {
                        if matches!(
                            event,
                            AgentEvent::MessageStart {
                                message: Message::Assistant(_)
                            }
                        ) {
                            started.notify_one();
                        }
                    }
                })
                .await
                .expect("first run");
            abort_join.await;

            assert_eq!(aborted.stop_reason, StopReason::Aborted);
            assert_eq!(aborted.error_message.as_deref(), Some("Aborted"));

            // Second run hits the completing phase of `PhasedProvider`.
            let resumed = agent_session
                .run_text("second".to_string(), |_| {})
                .await
                .expect("resumed run");
            assert_eq!(resumed.stop_reason, StopReason::Stop);
            assert!(resumed.error_message.is_none());

            let cx = crate::agent_cx::AgentCx::for_request();
            let persisted = session
                .lock(cx.cx())
                .await
                .expect("lock session")
                .to_messages_for_current_path();

            // Both runs must be recorded in order, the aborted one included.
            assert_eq!(
                persisted.len(),
                4,
                "unexpected message history after abort+resume: {persisted:?}"
            );
            assert!(matches!(persisted.first(), Some(Message::User(_))));
            assert!(matches!(
                persisted.get(1),
                Some(Message::Assistant(assistant)) if assistant.stop_reason == StopReason::Aborted
            ));
            assert!(matches!(persisted.get(2), Some(Message::User(_))));
            assert!(matches!(
                persisted.get(3),
                Some(Message::Assistant(assistant))
                    if assistant.stop_reason == StopReason::Stop && assistant.error_message.is_none()
            ));
        });
    }
4982
4983 #[test]
4984 fn repeated_abort_then_resume_has_consistent_timeline_and_state() {
4985 let runtime = RuntimeBuilder::current_thread()
4986 .build()
4987 .expect("runtime build");
4988 let handle = runtime.handle();
4989
4990 runtime.block_on(async move {
4991 let provider = Arc::new(PhasedProvider::new(2));
4992 let tools = ToolRegistry::new(&[], Path::new("."), None);
4993 let agent = Agent::new(provider, tools, AgentConfig::default());
4994 let session = Arc::new(Mutex::new(Session::in_memory()));
4995 let mut agent_session = AgentSession::new(
4996 agent,
4997 Arc::clone(&session),
4998 false,
4999 ResolvedCompactionSettings::default(),
5000 );
5001
5002 let timeline = Arc::new(StdMutex::new(Vec::<String>::new()));
5003
5004 for run_idx in 0..2 {
5005 let started = Arc::new(Notify::new());
5006 let (abort_handle, abort_signal) = AbortHandle::new();
5007 let started_for_abort = Arc::clone(&started);
5008 let abort_join = handle.spawn(async move {
5009 started_for_abort.notified().await;
5010 abort_handle.abort();
5011 });
5012
5013 let run_timeline = Arc::clone(&timeline);
5014 let aborted = agent_session
5015 .run_text_with_abort(format!("abort-run-{run_idx}"), Some(abort_signal), {
5016 let started = Arc::clone(&started);
5017 move |event| {
5018 if let Ok(mut events) = run_timeline.lock() {
5019 events.push(format!("run{run_idx}:{}", event_tag(&event)));
5020 }
5021 if matches!(
5022 event,
5023 AgentEvent::MessageStart {
5024 message: Message::Assistant(_)
5025 }
5026 ) {
5027 started.notify_one();
5028 }
5029 }
5030 })
5031 .await
5032 .expect("aborted run");
5033 abort_join.await;
5034
5035 assert_eq!(
5036 aborted.stop_reason,
5037 StopReason::Aborted,
5038 "run {run_idx} should abort cleanly"
5039 );
5040 }
5041
5042 let run_timeline = Arc::clone(&timeline);
5043 let resumed = agent_session
5044 .run_text("final-run".to_string(), move |event| {
5045 if let Ok(mut events) = run_timeline.lock() {
5046 events.push(format!("run2:{}", event_tag(&event)));
5047 }
5048 })
5049 .await
5050 .expect("final resumed run");
5051 assert_eq!(resumed.stop_reason, StopReason::Stop);
5052 assert!(resumed.error_message.is_none());
5053
5054 let cx = crate::agent_cx::AgentCx::for_request();
5055 let persisted = session
5056 .lock(cx.cx())
5057 .await
5058 .expect("lock session")
5059 .to_messages_for_current_path();
5060
5061 assert_abort_resume_message_sequence(&persisted);
5062
5063 let timeline = timeline
5064 .lock()
5065 .unwrap_or_else(std::sync::PoisonError::into_inner)
5066 .clone();
5067 assert_abort_resume_timeline_boundaries(&timeline);
5068 });
5069 }
5070
    /// Aborting while a tool is executing must end the run as `Aborted`. It must
    /// also persist an error tool result containing "Tool execution aborted".
    #[test]
    fn abort_during_tool_execution_records_aborted_tool_result() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        let handle = runtime.handle();

        runtime.block_on(async move {
            // The provider always requests `hanging_tool`, whose execution never resolves.
            let provider = Arc::new(ToolCallProvider);
            let tools = ToolRegistry::from_tools(vec![Box::new(HangingTool)]);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session = AgentSession::new(
                agent,
                Arc::clone(&session),
                false,
                ResolvedCompactionSettings::default(),
            );

            // Background task that aborts once tool execution has started.
            let tool_started = Arc::new(Notify::new());
            let (abort_handle, abort_signal) = AbortHandle::new();
            let tool_started_for_abort = Arc::clone(&tool_started);
            let abort_join = handle.spawn(async move {
                tool_started_for_abort.notified().await;
                abort_handle.abort();
            });

            let result = agent_session
                .run_text_with_abort("trigger tool".to_string(), Some(abort_signal), {
                    let tool_started = Arc::clone(&tool_started);
                    move |event| {
                        if matches!(event, AgentEvent::ToolExecutionStart { .. }) {
                            tool_started.notify_one();
                        }
                    }
                })
                .await
                .expect("tool-abort run");
            abort_join.await;
            assert_eq!(result.stop_reason, StopReason::Aborted);

            let cx = crate::agent_cx::AgentCx::for_request();
            let persisted = session
                .lock(cx.cx())
                .await
                .expect("lock session")
                .to_messages_for_current_path();

            // The interrupted tool call must still leave a (failed) tool result behind.
            let tool_result = persisted
                .iter()
                .find_map(|message| match message {
                    Message::ToolResult(result) => Some(result),
                    _ => None,
                })
                .expect("expected tool result message");
            assert!(tool_result.is_error);
            assert!(
                tool_result.content.iter().any(|block| {
                    matches!(
                        block,
                        ContentBlock::Text(text) if text.text.contains("Tool execution aborted")
                    )
                }),
                "missing aborted tool marker in tool output: {:?}",
                tool_result.content
            );
        });
    }
5139}
5140
#[cfg(test)]
mod turn_event_tests {
    //! Checks that `AgentSession::run_text` brackets each assistant turn with
    //! `TurnStart`/`TurnEnd` events. Covers tool-using turns and turns whose
    //! provider stream fails during setup.

    use super::*;
    use crate::session::Session;
    use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
    use asupersync::runtime::RuntimeBuilder;
    use async_trait::async_trait;
    use futures::Stream;
    use serde_json::json;
    use std::path::Path;
    use std::pin::Pin;
    use std::sync::atomic::AtomicUsize;

    /// Builds a `StopReason::Stop` assistant message holding one text block,
    /// stamped with the test provider/api/model identifiers.
    fn assistant_message(text: &str) -> AssistantMessage {
        AssistantMessage {
            content: vec![ContentBlock::Text(TextContent::new(text))],
            api: "test-api".to_string(),
            provider: "test-provider".to_string(),
            model: "test-model".to_string(),
            usage: Usage::default(),
            stop_reason: StopReason::Stop,
            error_message: None,
            timestamp: 0,
        }
    }

    /// Provider that answers every request with `Start` followed by a `Done`
    /// whose message text is "hello".
    #[derive(Debug)]
    struct SingleShotProvider;

    #[async_trait]
    #[allow(clippy::unnecessary_literal_bound)]
    impl Provider for SingleShotProvider {
        fn name(&self) -> &str {
            "test-provider"
        }

        fn api(&self) -> &str {
            "test-api"
        }

        fn model_id(&self) -> &str {
            "test-model"
        }

        async fn stream(
            &self,
            _context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<
            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
        > {
            let partial = assistant_message("");
            let final_message = assistant_message("hello");
            let events = vec![
                Ok(StreamEvent::Start { partial }),
                Ok(StreamEvent::Done {
                    reason: StopReason::Stop,
                    message: final_message,
                }),
            ];
            Ok(Box::pin(futures::stream::iter(events)))
        }
    }

    /// Provider whose `stream` call itself fails before producing any event.
    #[derive(Debug)]
    struct StreamSetupErrorProvider;

    #[async_trait]
    #[allow(clippy::unnecessary_literal_bound)]
    impl Provider for StreamSetupErrorProvider {
        fn name(&self) -> &str {
            "test-provider"
        }

        fn api(&self) -> &str {
            "test-api"
        }

        fn model_id(&self) -> &str {
            "test-model"
        }

        async fn stream(
            &self,
            _context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<
            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
        > {
            Err(Error::api("stream setup failed"))
        }
    }

    /// Tool that always succeeds with the text "tool-ok".
    #[derive(Debug)]
    struct EchoTool;

    #[async_trait]
    #[allow(clippy::unnecessary_literal_bound)]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo_tool"
        }

        fn label(&self) -> &str {
            "echo_tool"
        }

        fn description(&self) -> &str {
            "echo test tool"
        }

        fn parameters(&self) -> serde_json::Value {
            json!({ "type": "object" })
        }

        async fn execute(
            &self,
            _tool_call_id: &str,
            _input: serde_json::Value,
            _on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput {
                content: vec![ContentBlock::Text(TextContent::new("tool-ok"))],
                details: None,
                is_error: false,
            })
        }
    }

    /// Provider that requests `echo_tool` on its first call and returns a plain
    /// "final" text answer on every later call.
    #[derive(Debug)]
    struct ToolTurnProvider {
        /// Number of `stream` calls made so far.
        calls: AtomicUsize,
    }

    impl ToolTurnProvider {
        const fn new() -> Self {
            Self {
                calls: AtomicUsize::new(0),
            }
        }

        /// Assistant message with this provider's identifiers and the given
        /// stop reason and content.
        fn assistant_message_with(
            &self,
            stop_reason: StopReason,
            content: Vec<ContentBlock>,
        ) -> AssistantMessage {
            AssistantMessage {
                content,
                api: self.api().to_string(),
                provider: self.name().to_string(),
                model: self.model_id().to_string(),
                usage: Usage::default(),
                stop_reason,
                error_message: None,
                timestamp: 0,
            }
        }
    }

    #[async_trait]
    #[allow(clippy::unnecessary_literal_bound)]
    impl Provider for ToolTurnProvider {
        fn name(&self) -> &str {
            "test-provider"
        }

        fn api(&self) -> &str {
            "test-api"
        }

        fn model_id(&self) -> &str {
            "test-model"
        }

        async fn stream(
            &self,
            _context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<
            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
        > {
            let call_index = self.calls.fetch_add(1, Ordering::SeqCst);
            let partial = self.assistant_message_with(StopReason::Stop, Vec::new());
            let done = if call_index == 0 {
                self.assistant_message_with(
                    StopReason::ToolUse,
                    vec![ContentBlock::ToolCall(ToolCall {
                        id: "tool-1".to_string(),
                        name: "echo_tool".to_string(),
                        arguments: json!({}),
                        thought_signature: None,
                    })],
                )
            } else {
                self.assistant_message_with(
                    StopReason::Stop,
                    vec![ContentBlock::Text(TextContent::new("final"))],
                )
            };

            Ok(Box::pin(futures::stream::iter(vec![
                Ok(StreamEvent::Start { partial }),
                Ok(StreamEvent::Done {
                    reason: done.stop_reason,
                    message: done,
                }),
            ])))
        }
    }

    #[test]
    fn turn_events_wrap_assistant_response() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        let handle = runtime.handle();

        let provider = Arc::new(SingleShotProvider);
        let tools = ToolRegistry::new(&[], Path::new("."), None);
        let agent = Agent::new(provider, tools, AgentConfig::default());
        let session = Arc::new(Mutex::new(Session::in_memory()));
        let mut agent_session =
            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

        let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let events_capture = Arc::clone(&events);

        let join = handle.spawn(async move {
            agent_session
                .run_text("hello".to_string(), move |event| {
                    events_capture
                        .lock()
                        .unwrap_or_else(std::sync::PoisonError::into_inner)
                        .push(event);
                })
                .await
                .expect("run_text")
        });

        runtime.block_on(async move {
            let message = join.await;
            assert_eq!(message.stop_reason, StopReason::Stop);

            let events = events
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let turn_start_indices = events
                .iter()
                .enumerate()
                .filter_map(|(idx, event)| {
                    matches!(event, AgentEvent::TurnStart { .. }).then_some(idx)
                })
                .collect::<Vec<_>>();
            let turn_end_indices = events
                .iter()
                .enumerate()
                .filter_map(|(idx, event)| {
                    matches!(event, AgentEvent::TurnEnd { .. }).then_some(idx)
                })
                .collect::<Vec<_>>();

            assert_eq!(turn_start_indices.len(), 1);
            assert_eq!(turn_end_indices.len(), 1);
            assert!(turn_start_indices[0] < turn_end_indices[0]);

            let assistant_message_end = events
                .iter()
                .enumerate()
                .find_map(|(idx, event)| match event {
                    AgentEvent::MessageEnd {
                        message: Message::Assistant(_),
                    } => Some(idx),
                    _ => None,
                })
                .expect("assistant message end");

            assert!(assistant_message_end < turn_end_indices[0]);

            // The single turn must close with the assistant reply and no tool results.
            let turn_end_event = &events[turn_end_indices[0]];
            let AgentEvent::TurnEnd {
                message,
                tool_results,
                ..
            } = turn_end_event
            else {
                panic!("Expected TurnEnd event, got {turn_end_event:?}");
            };
            let message_is_assistant = matches!(message, Message::Assistant(_));
            let tool_results_empty = tool_results.is_empty();
            // Release the capture lock before asserting.
            drop(events);
            assert!(message_is_assistant);
            assert!(tool_results_empty);
        });
    }

    #[test]
    fn stream_setup_errors_still_emit_turn_end_before_agent_end() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        let handle = runtime.handle();

        let provider = Arc::new(StreamSetupErrorProvider);
        let tools = ToolRegistry::new(&[], Path::new("."), None);
        let agent = Agent::new(provider, tools, AgentConfig::default());
        let session = Arc::new(Mutex::new(Session::in_memory()));
        let mut agent_session =
            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

        let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let events_capture = Arc::clone(&events);

        let join = handle.spawn(async move {
            agent_session
                .run_text("hello".to_string(), move |event| {
                    events_capture
                        .lock()
                        .unwrap_or_else(std::sync::PoisonError::into_inner)
                        .push(event);
                })
                .await
                .expect_err("run_text should fail before streaming starts")
        });

        runtime.block_on(async move {
            let err = join.await;
            assert!(
                err.to_string().contains("stream setup failed"),
                "unexpected error: {err}"
            );

            let events = events
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let turn_start_idx = events
                .iter()
                .position(|event| matches!(event, AgentEvent::TurnStart { turn_index: 0, .. }))
                .expect("turn start");
            let turn_end_idx = events
                .iter()
                .position(|event| matches!(event, AgentEvent::TurnEnd { turn_index: 0, .. }))
                .expect("turn end");
            let agent_end_idx = events
                .iter()
                .position(|event| matches!(event, AgentEvent::AgentEnd { .. }))
                .expect("agent end");

            assert!(turn_start_idx < turn_end_idx);
            assert!(turn_end_idx < agent_end_idx);

            let assistant_message_end = events
                .iter()
                .position(|event| {
                    matches!(
                        event,
                        AgentEvent::MessageEnd {
                            message: Message::Assistant(_),
                        }
                    )
                })
                .expect("assistant message end");
            assert!(assistant_message_end < turn_end_idx);

            match &events[turn_end_idx] {
                AgentEvent::TurnEnd {
                    message,
                    tool_results,
                    ..
                } => {
                    assert!(tool_results.is_empty());
                    match message {
                        Message::Assistant(message) => {
                            assert_eq!(message.stop_reason, StopReason::Error);
                            assert_eq!(
                                message.error_message.as_deref(),
                                Some("API error: stream setup failed")
                            );
                            assert_eq!(message.api, "test-api");
                            assert_eq!(message.provider, "test-provider");
                            assert_eq!(message.model, "test-model");
                        }
                        other => panic!("expected assistant message in TurnEnd, got {other:?}"),
                    }
                }
                other => panic!("expected TurnEnd event, got {other:?}"),
            }

            match &events[agent_end_idx] {
                AgentEvent::AgentEnd { error, .. } => {
                    assert_eq!(error.as_deref(), Some("API error: stream setup failed"));
                }
                other => panic!("expected AgentEnd event, got {other:?}"),
            }
        });
    }

    #[test]
    fn turn_events_include_tool_execution_and_tool_result_messages() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        let handle = runtime.handle();

        let provider = Arc::new(ToolTurnProvider::new());
        let tools = ToolRegistry::from_tools(vec![Box::new(EchoTool)]);
        let agent = Agent::new(provider, tools, AgentConfig::default());
        let session = Arc::new(Mutex::new(Session::in_memory()));
        let mut agent_session =
            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

        let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let events_capture = Arc::clone(&events);

        let join = handle.spawn(async move {
            agent_session
                .run_text("hello".to_string(), move |event| {
                    events_capture
                        .lock()
                        .unwrap_or_else(std::sync::PoisonError::into_inner)
                        .push(event);
                })
                .await
                .expect("run_text")
        });

        runtime.block_on(async move {
            let message = join.await;
            assert_eq!(message.stop_reason, StopReason::Stop);

            let events = events
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let turn_start_count = events
                .iter()
                .filter(|event| matches!(event, AgentEvent::TurnStart { .. }))
                .count();
            let turn_end_count = events
                .iter()
                .filter(|event| matches!(event, AgentEvent::TurnEnd { .. }))
                .count();
            assert_eq!(
                turn_start_count, 2,
                "expected one tool turn and one final turn"
            );
            assert_eq!(
                turn_end_count, 2,
                "expected one tool turn and one final turn"
            );

            let tool_start_idx = events
                .iter()
                .position(|event| matches!(event, AgentEvent::ToolExecutionStart { .. }))
                .expect("tool execution start event");
            let tool_end_idx = events
                .iter()
                .position(|event| matches!(event, AgentEvent::ToolExecutionEnd { .. }))
                .expect("tool execution end event");
            assert!(tool_start_idx < tool_end_idx);

            let first_turn_end_idx = events
                .iter()
                .position(|event| matches!(event, AgentEvent::TurnEnd { turn_index: 0, .. }))
                .expect("first turn end");
            assert!(
                tool_end_idx < first_turn_end_idx,
                "tool execution should complete before first turn end"
            );

            let first_turn_tool_results = events
                .iter()
                .find_map(|event| match event {
                    AgentEvent::TurnEnd {
                        turn_index: 0,
                        tool_results,
                        ..
                    } => Some(tool_results),
                    _ => None,
                })
                .expect("expected tool results for first turn");
            assert_eq!(first_turn_tool_results.len(), 1);
            let first_result = first_turn_tool_results
                .first()
                .expect("first turn should carry one tool result");
            // A wrong message kind is a test failure, not an impossible state.
            let Message::ToolResult(tr) = first_result else {
                panic!("expected Message::ToolResult, got {first_result:?}");
            };
            assert_eq!(tr.tool_name, "echo_tool");
            assert!(!tr.is_error);
            drop(events);
        });
    }
}
5642
5643#[derive(Clone)]
5644struct AgentExtensionSession {
5645 handle: SessionHandle,
5646 is_streaming: Arc<AtomicBool>,
5647 is_compacting: Arc<AtomicBool>,
5648 queue_modes: Arc<StdMutex<ExtensionQueueModeState>>,
5649 auto_compaction_enabled: bool,
5650}
5651
5652impl AgentExtensionSession {
5653 fn current_queue_modes(&self) -> (QueueMode, QueueMode) {
5654 self.queue_modes
5655 .lock()
5656 .map_or((QueueMode::OneAtATime, QueueMode::OneAtATime), |state| {
5657 (state.steering_mode, state.follow_up_mode)
5658 })
5659 }
5660
5661 fn state_fallback(&self) -> Value {
5662 let (steering_mode, follow_up_mode) = self.current_queue_modes();
5663 json!({
5664 "model": null,
5665 "thinkingLevel": "off",
5666 "durabilityMode": "balanced",
5667 "isStreaming": self.is_streaming.load(std::sync::atomic::Ordering::SeqCst),
5668 "isCompacting": self.is_compacting.load(std::sync::atomic::Ordering::SeqCst),
5669 "steeringMode": steering_mode.as_str(),
5670 "followUpMode": follow_up_mode.as_str(),
5671 "sessionFile": null,
5672 "sessionId": "",
5673 "sessionName": null,
5674 "autoCompactionEnabled": self.auto_compaction_enabled,
5675 "messageCount": 0,
5676 "pendingMessageCount": 0,
5677 })
5678 }
5679}
5680
5681#[async_trait]
5682impl crate::extensions::ExtensionSession for AgentExtensionSession {
5683 async fn get_state(&self) -> Value {
5684 let (steering_mode, follow_up_mode) = self.current_queue_modes();
5685 let mut state =
5686 <SessionHandle as crate::extensions::ExtensionSession>::get_state(&self.handle).await;
5687 let Some(object) = state.as_object_mut() else {
5688 return self.state_fallback();
5689 };
5690
5691 object.insert(
5692 "isStreaming".to_string(),
5693 Value::Bool(self.is_streaming.load(std::sync::atomic::Ordering::SeqCst)),
5694 );
5695 object.insert(
5696 "isCompacting".to_string(),
5697 Value::Bool(self.is_compacting.load(std::sync::atomic::Ordering::SeqCst)),
5698 );
5699 object.insert(
5700 "steeringMode".to_string(),
5701 Value::String(steering_mode.as_str().to_string()),
5702 );
5703 object.insert(
5704 "followUpMode".to_string(),
5705 Value::String(follow_up_mode.as_str().to_string()),
5706 );
5707 object.insert(
5708 "autoCompactionEnabled".to_string(),
5709 Value::Bool(self.auto_compaction_enabled),
5710 );
5711
5712 state
5713 }
5714
5715 async fn get_messages(&self) -> Vec<crate::session::SessionMessage> {
5716 <SessionHandle as crate::extensions::ExtensionSession>::get_messages(&self.handle).await
5717 }
5718
5719 async fn get_entries(&self) -> Vec<Value> {
5720 <SessionHandle as crate::extensions::ExtensionSession>::get_entries(&self.handle).await
5721 }
5722
5723 async fn get_branch(&self) -> Vec<Value> {
5724 <SessionHandle as crate::extensions::ExtensionSession>::get_branch(&self.handle).await
5725 }
5726
5727 async fn set_name(&self, name: String) -> crate::error::Result<()> {
5728 <SessionHandle as crate::extensions::ExtensionSession>::set_name(&self.handle, name).await
5729 }
5730
5731 async fn append_message(
5732 &self,
5733 message: crate::session::SessionMessage,
5734 ) -> crate::error::Result<()> {
5735 <SessionHandle as crate::extensions::ExtensionSession>::append_message(
5736 &self.handle,
5737 message,
5738 )
5739 .await
5740 }
5741
5742 async fn append_custom_entry(
5743 &self,
5744 custom_type: String,
5745 data: Option<Value>,
5746 ) -> crate::error::Result<()> {
5747 <SessionHandle as crate::extensions::ExtensionSession>::append_custom_entry(
5748 &self.handle,
5749 custom_type,
5750 data,
5751 )
5752 .await
5753 }
5754
5755 async fn set_model(&self, provider: String, model_id: String) -> crate::error::Result<()> {
5756 <SessionHandle as crate::extensions::ExtensionSession>::set_model(
5757 &self.handle,
5758 provider,
5759 model_id,
5760 )
5761 .await
5762 }
5763
5764 async fn get_model(&self) -> (Option<String>, Option<String>) {
5765 <SessionHandle as crate::extensions::ExtensionSession>::get_model(&self.handle).await
5766 }
5767
5768 async fn set_thinking_level(&self, level: String) -> crate::error::Result<()> {
5769 <SessionHandle as crate::extensions::ExtensionSession>::set_thinking_level(
5770 &self.handle,
5771 level,
5772 )
5773 .await
5774 }
5775
5776 async fn get_thinking_level(&self) -> Option<String> {
5777 <SessionHandle as crate::extensions::ExtensionSession>::get_thinking_level(&self.handle)
5778 .await
5779 }
5780
5781 async fn set_label(
5782 &self,
5783 target_id: String,
5784 label: Option<String>,
5785 ) -> crate::error::Result<()> {
5786 <SessionHandle as crate::extensions::ExtensionSession>::set_label(
5787 &self.handle,
5788 target_id,
5789 label,
5790 )
5791 .await
5792 }
5793}
5794
5795impl AgentSession {
5796 pub const fn runtime_repair_mode_from_policy_mode(mode: RepairPolicyMode) -> RepairMode {
5797 match mode {
5798 RepairPolicyMode::Off => RepairMode::Off,
5799 RepairPolicyMode::Suggest => RepairMode::Suggest,
5800 RepairPolicyMode::AutoSafe => RepairMode::AutoSafe,
5801 RepairPolicyMode::AutoStrict => RepairMode::AutoStrict,
5802 }
5803 }
5804
5805 #[allow(clippy::too_many_arguments)]
5806 async fn start_js_extension_runtime(
5807 stage: &'static str,
5808 cwd: &std::path::Path,
5809 tools: Arc<ToolRegistry>,
5810 manager: ExtensionManager,
5811 policy: ExtensionPolicy,
5812 repair_mode: RepairMode,
5813 memory_limit_bytes: usize,
5814 ) -> Result<ExtensionRuntimeHandle> {
5815 let mut config = PiJsRuntimeConfig {
5816 cwd: cwd.display().to_string(),
5817 repair_mode,
5818 ..PiJsRuntimeConfig::default()
5819 };
5820 config.limits.memory_limit_bytes = Some(memory_limit_bytes).filter(|bytes| *bytes > 0);
5821
5822 let runtime =
5823 JsExtensionRuntimeHandle::start_with_policy(config, tools, manager, policy).await?;
5824 tracing::info!(
5825 event = "pi.extension_runtime.engine_decision",
5826 stage,
5827 requested = "quickjs",
5828 selected = "quickjs",
5829 fallback = false,
5830 "Extension runtime engine selected (legacy JS/TS)"
5831 );
5832 Ok(ExtensionRuntimeHandle::Js(runtime))
5833 }
5834
5835 #[allow(clippy::too_many_arguments)]
5836 async fn start_native_extension_runtime(
5837 stage: &'static str,
5838 _cwd: &std::path::Path,
5839 _tools: Arc<ToolRegistry>,
5840 _manager: ExtensionManager,
5841 _policy: ExtensionPolicy,
5842 _repair_mode: RepairMode,
5843 _memory_limit_bytes: usize,
5844 ) -> Result<ExtensionRuntimeHandle> {
5845 let runtime = NativeRustExtensionRuntimeHandle::start().await?;
5846 tracing::info!(
5847 event = "pi.extension_runtime.engine_decision",
5848 stage,
5849 requested = "native-rust",
5850 selected = "native-rust",
5851 fallback = false,
5852 "Extension runtime engine selected (native-rust)"
5853 );
5854 Ok(ExtensionRuntimeHandle::NativeRust(runtime))
5855 }
5856
5857 pub fn new(
5858 agent: Agent,
5859 session: Arc<Mutex<Session>>,
5860 save_enabled: bool,
5861 compaction_settings: ResolvedCompactionSettings,
5862 ) -> Self {
5863 Self {
5864 agent,
5865 session,
5866 save_enabled,
5867 input_source: InputSource::Interactive,
5868 extensions: None,
5869 extensions_is_streaming: Arc::new(AtomicBool::new(false)),
5870 extensions_is_compacting: Arc::new(AtomicBool::new(false)),
5871 extensions_turn_active: Arc::new(AtomicBool::new(false)),
5872 extensions_pending_idle_actions: Arc::new(StdMutex::new(VecDeque::new())),
5873 extension_queue_modes: None,
5874 extension_injected_queue: None,
5875 compaction_settings,
5876 compaction_runtime: None,
5877 runtime_handle: None,
5878 compaction_worker: CompactionWorkerState::new(CompactionQuota::default()),
5879 model_registry: None,
5880 auth_storage: None,
5881 }
5882 }
5883
5884 pub const fn set_input_source(&mut self, source: InputSource) {
5885 self.input_source = source;
5886 }
5887
5888 #[must_use]
5889 pub fn with_runtime_handle(mut self, runtime_handle: RuntimeHandle) -> Self {
5890 self.compaction_runtime = None;
5891 self.runtime_handle = Some(runtime_handle);
5892 self
5893 }
5894
5895 #[must_use]
5896 pub fn with_model_registry(mut self, registry: ModelRegistry) -> Self {
5897 self.model_registry = Some(registry);
5898 self
5899 }
5900
5901 #[must_use]
5902 pub fn with_auth_storage(mut self, auth: AuthStorage) -> Self {
5903 self.auth_storage = Some(auth);
5904 self
5905 }
5906
5907 pub fn set_model_registry(&mut self, registry: ModelRegistry) {
5908 self.model_registry = Some(registry);
5909 }
5910
5911 pub fn set_auth_storage(&mut self, auth: AuthStorage) {
5912 self.auth_storage = Some(auth);
5913 }
5914
5915 pub fn set_queue_modes(&mut self, steering_mode: QueueMode, follow_up_mode: QueueMode) {
5916 self.agent.set_queue_modes(steering_mode, follow_up_mode);
5917
5918 if let Some(queue_modes) = &self.extension_queue_modes
5919 && let Ok(mut state) = queue_modes.lock()
5920 {
5921 state.set_modes(steering_mode, follow_up_mode);
5922 }
5923
5924 if let Some(injected_queue) = &self.extension_injected_queue
5925 && let Ok(mut queue) = injected_queue.lock()
5926 {
5927 queue.set_modes(steering_mode, follow_up_mode);
5928 }
5929 }
5930
5931 pub const fn set_compaction_context_window(&mut self, context_window_tokens: u32) {
5932 self.compaction_settings.context_window_tokens = context_window_tokens;
5933 }
5934
5935 pub async fn set_provider_model(&mut self, provider_id: &str, model_id: &str) -> Result<()> {
5936 let already_active = {
5937 let provider = self.agent.provider();
5938 provider.name() == provider_id && provider.model_id() == model_id
5939 };
5940 let current_thinking = self
5941 .agent
5942 .stream_options()
5943 .thinking_level
5944 .unwrap_or_default();
5945
5946 let target_entry = self
5947 .model_registry
5948 .as_ref()
5949 .and_then(|registry| registry.find(provider_id, model_id));
5950 let next_thinking = if let Some(target_entry) = target_entry {
5951 let resolved_key = self.resolve_stream_api_key_for_model(&target_entry);
5952 if !already_active
5953 && model_requires_configured_credential(&target_entry)
5954 && resolved_key.is_none()
5955 {
5956 return Err(Error::auth(format!(
5957 "Missing credentials for {provider_id}/{model_id}"
5958 )));
5959 }
5960 self.clamp_thinking_level_for_model(provider_id, model_id, current_thinking)
5961 } else if already_active {
5962 current_thinking
5963 } else {
5964 return Err(Error::validation(format!(
5965 "Unable to switch provider/model to {provider_id}/{model_id}"
5966 )));
5967 };
5968
5969 if !already_active {
5970 self.apply_session_model_selection(provider_id, model_id)?;
5971 }
5972 self.agent.stream_options_mut().thinking_level = Some(next_thinking);
5973
5974 {
5975 let cx = crate::agent_cx::AgentCx::for_request();
5976 let mut session = self
5977 .session
5978 .lock(cx.cx())
5979 .await
5980 .map_err(|e| Error::session(e.to_string()))?;
5981 let previous_model = session.effective_model_for_current_path();
5982 let previous_thinking = session
5983 .effective_thinking_level_for_current_path()
5984 .as_deref()
5985 .and_then(|value| value.parse::<crate::model::ThinkingLevel>().ok());
5986 if previous_model
5987 .as_ref()
5988 .map(|(provider, model_id)| (provider.as_str(), model_id.as_str()))
5989 != Some((provider_id, model_id))
5990 {
5991 session.append_model_change(provider_id.to_string(), model_id.to_string());
5992 }
5993 session.set_model_header(
5994 Some(provider_id.to_string()),
5995 Some(model_id.to_string()),
5996 Some(next_thinking.to_string()),
5997 );
5998 if previous_thinking != Some(next_thinking) {
5999 session.append_thinking_level_change(next_thinking.to_string());
6000 }
6001 }
6002
6003 self.persist_session().await
6004 }
6005
6006 pub(crate) fn clamp_thinking_level_for_model(
6007 &self,
6008 provider_id: &str,
6009 model_id: &str,
6010 level: crate::model::ThinkingLevel,
6011 ) -> crate::model::ThinkingLevel {
6012 self.model_registry
6013 .as_ref()
6014 .and_then(|registry| registry.find(provider_id, model_id))
6015 .map_or(level, |entry| entry.clamp_thinking_level(level))
6016 }
6017
6018 fn resolve_stream_api_key_for_model(&self, entry: &ModelEntry) -> Option<String> {
6019 let normalize = |key_opt: Option<String>| {
6020 key_opt.and_then(|key| {
6021 let trimmed = key.trim();
6022 (!trimmed.is_empty()).then(|| trimmed.to_string())
6023 })
6024 };
6025
6026 self.auth_storage
6027 .as_ref()
6028 .and_then(|auth| normalize(auth.resolve_api_key(&entry.model.provider, None)))
6029 .or_else(|| normalize(entry.api_key.clone()))
6030 }
6031
6032 pub(crate) async fn sync_runtime_selection_from_session_header(&mut self) -> Result<()> {
6033 let session_state = {
6034 let cx = crate::agent_cx::AgentCx::for_request();
6035 let session = self
6036 .session
6037 .lock(cx.cx())
6038 .await
6039 .map_err(|e| Error::session(e.to_string()))?;
6040 (
6041 session.effective_model_for_current_path(),
6042 session.effective_thinking_level_for_current_path(),
6043 )
6044 };
6045
6046 let (session_model, session_thinking) = session_state;
6047 let current_thinking = self
6048 .agent
6049 .stream_options()
6050 .thinking_level
6051 .unwrap_or_default();
6052
6053 if let Some((provider_id, model_id)) = session_model.as_ref() {
6054 self.apply_session_model_selection(provider_id, model_id)?;
6055 }
6056
6057 let parsed_session_thinking = session_thinking.as_deref().and_then(|raw| {
6058 raw.parse::<crate::model::ThinkingLevel>().map_or_else(
6059 |_| {
6060 tracing::warn!("Ignoring invalid session thinking level: {raw}");
6061 None
6062 },
6063 Some,
6064 )
6065 });
6066 let requested = parsed_session_thinking.unwrap_or(current_thinking);
6067
6068 let effective = if let Some((provider_id, model_id)) = session_model.as_ref() {
6069 self.clamp_thinking_level_for_model(provider_id, model_id, requested)
6070 } else {
6071 requested
6072 };
6073
6074 self.agent.stream_options_mut().thinking_level = Some(effective);
6075
6076 let thinking_changed = effective != current_thinking;
6077 let persist_needed = if session_thinking.is_some() {
6078 parsed_session_thinking != Some(effective)
6079 } else {
6080 thinking_changed
6081 };
6082 if !persist_needed {
6083 return Ok(());
6084 }
6085
6086 {
6087 let cx = crate::agent_cx::AgentCx::for_request();
6088 let mut session = self
6089 .session
6090 .lock(cx.cx())
6091 .await
6092 .map_err(|e| Error::session(e.to_string()))?;
6093 let previous_thinking = session
6094 .header
6095 .thinking_level
6096 .as_deref()
6097 .and_then(|value| value.parse::<crate::model::ThinkingLevel>().ok());
6098 session.set_model_header(None, None, Some(effective.to_string()));
6099 if thinking_changed && previous_thinking != Some(effective) {
6100 session.append_thinking_level_change(effective.to_string());
6101 }
6102 }
6103
6104 self.persist_session().await
6105 }
6106
6107 fn apply_session_model_selection(&mut self, provider_id: &str, model_id: &str) -> Result<()> {
6108 if self.agent.provider().name() == provider_id
6109 && self.agent.provider().model_id() == model_id
6110 {
6111 return Ok(());
6112 }
6113
6114 let Some(registry) = &self.model_registry else {
6115 return Err(Error::validation(format!(
6116 "Unable to switch provider/model to {provider_id}/{model_id}"
6117 )));
6118 };
6119
6120 let Some(entry) = registry.find(provider_id, model_id) else {
6121 return Err(Error::validation(format!(
6122 "Unable to switch provider/model to {provider_id}/{model_id}"
6123 )));
6124 };
6125
6126 let resolved_key = self.resolve_stream_api_key_for_model(&entry);
6127 if model_requires_configured_credential(&entry) && resolved_key.is_none() {
6128 return Err(Error::auth(format!(
6129 "Missing credentials for {provider_id}/{model_id}"
6130 )));
6131 }
6132
6133 match crate::providers::create_provider(
6134 &entry,
6135 self.extensions.as_ref().map(ExtensionRegion::manager),
6136 ) {
6137 Ok(provider) => {
6138 tracing::info!("Updating agent provider to {provider_id}/{model_id}");
6139 self.agent.set_provider(provider);
6140
6141 let stream_options = self.agent.stream_options_mut();
6142 stream_options.api_key = resolved_key; stream_options.headers.clone_from(&entry.headers);
6144 Ok(())
6145 }
6146 Err(e) => Err(Error::validation(format!(
6147 "Unable to switch provider/model to {provider_id}/{model_id}: {e}"
6148 ))),
6149 }
6150 }
6151
6152 pub const fn save_enabled(&self) -> bool {
6153 self.save_enabled
6154 }
6155
6156 pub async fn compact_now(
6158 &mut self,
6159 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
6160 ) -> Result<()> {
6161 self.compact_synchronous(Arc::new(on_event)).await
6162 }
6163
6164 pub async fn execute_extension_command(
6165 &mut self,
6166 command_name: &str,
6167 args: &str,
6168 timeout_ms: u64,
6169 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
6170 ) -> Result<Value> {
6171 self.execute_extension_command_with_abort(command_name, args, timeout_ms, None, on_event)
6172 .await
6173 }
6174
6175 pub async fn execute_extension_command_with_abort(
6176 &mut self,
6177 command_name: &str,
6178 args: &str,
6179 timeout_ms: u64,
6180 abort: Option<AbortSignal>,
6181 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
6182 ) -> Result<Value> {
6183 let manager = self
6184 .extensions
6185 .as_ref()
6186 .map(ExtensionRegion::manager)
6187 .ok_or_else(|| Error::extension("Extensions are disabled"))?
6188 .clone();
6189 let on_event: AgentEventHandler = Arc::new(on_event);
6190
6191 self.run_pending_idle_actions_with_abort(abort.clone(), Arc::clone(&on_event))
6192 .await?;
6193
6194 let command_result = manager
6195 .execute_command(command_name, args, timeout_ms)
6196 .await;
6197 let replay_result = self
6198 .run_pending_idle_actions_with_abort(abort, Arc::clone(&on_event))
6199 .await;
6200
6201 match command_result {
6202 Ok(value) => {
6203 replay_result?;
6204 Ok(value)
6205 }
6206 Err(err) => {
6207 if let Err(replay_err) = replay_result {
6208 tracing::warn!(
6209 "extension command follow-up replay failed after command error: {replay_err}"
6210 );
6211 }
6212 Err(err)
6213 }
6214 }
6215 }
6216
6217 #[allow(clippy::too_many_lines)]
6223 async fn maybe_compact(&mut self, on_event: AgentEventHandler) -> Result<()> {
6224 if !self.compaction_settings.enabled {
6225 return Ok(());
6226 }
6227
6228 if let Some(outcome) = self.compaction_worker.try_recv().await {
6230 self.extensions_is_compacting
6231 .store(false, std::sync::atomic::Ordering::SeqCst);
6232 match outcome {
6233 Ok(result) => {
6234 self.apply_compaction_result(result, Arc::clone(&on_event))
6235 .await?;
6236 }
6237 Err(e) => {
6238 on_event(AgentEvent::AutoCompactionEnd {
6239 result: None,
6240 aborted: false,
6241 will_retry: false,
6242 error_message: Some(e.to_string()),
6243 });
6244 }
6245 }
6246 }
6247
6248 if !self.compaction_worker.can_start() {
6250 return Ok(());
6251 }
6252
6253 let (entries, preparation) = {
6254 let cx = crate::agent_cx::AgentCx::for_request();
6255 let mut session = self
6256 .session
6257 .lock(cx.cx())
6258 .await
6259 .map_err(|e| Error::session(e.to_string()))?;
6260 session.ensure_entry_ids();
6261 let entries = session
6262 .entries_for_current_path()
6263 .into_iter()
6264 .cloned()
6265 .collect::<Vec<_>>();
6266 let prep = compaction::prepare_compaction(&entries, self.compaction_settings.clone());
6267 (entries, prep)
6268 };
6269
6270 if let Some(prep) = preparation {
6271 on_event(AgentEvent::AutoCompactionStart {
6272 reason: "threshold".to_string(),
6273 });
6274
6275 let before_outcome = self.dispatch_before_compact(&prep, &entries, None).await;
6276 if before_outcome.cancel {
6277 on_event(AgentEvent::AutoCompactionEnd {
6278 result: None,
6279 aborted: true,
6280 will_retry: false,
6281 error_message: None,
6282 });
6283 return Ok(());
6284 }
6285
6286 if let Some(compaction) = before_outcome.compaction {
6287 let result_value = compaction.details.clone();
6288 self.extensions_is_compacting
6289 .store(true, std::sync::atomic::Ordering::SeqCst);
6290 let apply_result = self
6291 .apply_compaction_entry(
6292 compaction.summary,
6293 compaction.first_kept_entry_id,
6294 compaction.tokens_before,
6295 compaction.details,
6296 true,
6297 )
6298 .await;
6299 self.extensions_is_compacting
6300 .store(false, std::sync::atomic::Ordering::SeqCst);
6301 apply_result?;
6302 on_event(AgentEvent::AutoCompactionEnd {
6303 result: result_value,
6304 aborted: false,
6305 will_retry: false,
6306 error_message: None,
6307 });
6308 return Ok(());
6309 }
6310
6311 let provider = self.agent.provider();
6312 let api_key = self .agent
6314 .stream_options()
6315 .api_key
6316 .clone()
6317 .unwrap_or_default();
6318
6319 let runtime_handle = match self.compaction_runtime_handle() {
6320 Ok(runtime_handle) => runtime_handle,
6321 Err(e) => {
6322 on_event(AgentEvent::AutoCompactionEnd {
6323 result: None,
6324 aborted: false,
6325 will_retry: false,
6326 error_message: Some(e.to_string()),
6327 });
6328 return Ok(());
6329 }
6330 };
6331
6332 self.compaction_worker
6333 .start(&runtime_handle, prep, provider, api_key, None);
6334 self.extensions_is_compacting
6335 .store(true, std::sync::atomic::Ordering::SeqCst);
6336 }
6337
6338 Ok(())
6339 }
6340
6341 fn compaction_runtime_handle(&mut self) -> Result<RuntimeHandle> {
6342 if let Some(runtime_handle) = self.runtime_handle.clone() {
6343 return Ok(runtime_handle);
6344 }
6345
6346 let runtime = RuntimeBuilder::new().build().map_err(|e| {
6347 Error::session(format!("Background compaction runtime init failed: {e}"))
6348 })?;
6349 let runtime_handle = runtime.handle();
6350 self.compaction_runtime = Some(runtime);
6351 self.runtime_handle = Some(runtime_handle.clone());
6352 Ok(runtime_handle)
6353 }
6354
6355 async fn apply_compaction_entry(
6356 &self,
6357 summary: String,
6358 first_kept_entry_id: String,
6359 tokens_before: u64,
6360 details: Option<Value>,
6361 from_extension: bool,
6362 ) -> Result<()> {
6363 let cx = crate::agent_cx::AgentCx::for_request();
6364 let mut session = self
6365 .session
6366 .lock(cx.cx())
6367 .await
6368 .map_err(|e| Error::session(e.to_string()))?;
6369
6370 let from_hook = if from_extension { Some(true) } else { None };
6371 let entry_id = session.append_compaction(
6372 summary,
6373 first_kept_entry_id,
6374 tokens_before,
6375 details,
6376 from_hook,
6377 );
6378
6379 if self.save_enabled {
6380 session
6381 .flush_autosave(AutosaveFlushTrigger::Periodic)
6382 .await?;
6383 }
6384
6385 let compaction_entry = session.get_entry(&entry_id).and_then(|entry| {
6386 if let crate::session::SessionEntry::Compaction(compaction) = entry {
6387 Some(compaction.clone())
6388 } else {
6389 None
6390 }
6391 });
6392 drop(session);
6393
6394 if let (Some(region), Some(compaction_entry)) = (&self.extensions, compaction_entry) {
6395 let payload = json!({
6396 "compactionEntry": compaction_entry,
6397 "fromExtension": from_extension,
6398 });
6399 if let Err(err) = region
6400 .manager()
6401 .dispatch_event(ExtensionEventName::SessionCompact, Some(payload))
6402 .await
6403 {
6404 tracing::warn!("session_compact extension hook failed (fail-open): {err}");
6405 }
6406 }
6407
6408 Ok(())
6409 }
6410
6411 async fn apply_compaction_result(
6413 &self,
6414 result: compaction::CompactionResult,
6415 on_event: AgentEventHandler,
6416 ) -> Result<()> {
6417 let details = compaction::compaction_details_to_value(&result.details).ok();
6418 let result_value = details.clone();
6419
6420 self.apply_compaction_entry(
6421 result.summary,
6422 result.first_kept_entry_id,
6423 result.tokens_before,
6424 details,
6425 false,
6426 )
6427 .await?;
6428
6429 on_event(AgentEvent::AutoCompactionEnd {
6430 result: result_value,
6431 aborted: false,
6432 will_retry: false,
6433 error_message: None,
6434 });
6435
6436 Ok(())
6437 }
6438
6439 async fn compact_synchronous(&self, on_event: AgentEventHandler) -> Result<()> {
6441 if !self.compaction_settings.enabled {
6442 return Ok(());
6443 }
6444
6445 let (entries, preparation) = {
6446 let cx = crate::agent_cx::AgentCx::for_request();
6447 let mut session = self
6448 .session
6449 .lock(cx.cx())
6450 .await
6451 .map_err(|e| Error::session(e.to_string()))?;
6452 session.ensure_entry_ids();
6453 let entries = session
6454 .entries_for_current_path()
6455 .into_iter()
6456 .cloned()
6457 .collect::<Vec<_>>();
6458 let prep = compaction::prepare_compaction(&entries, self.compaction_settings.clone());
6459 (entries, prep)
6460 };
6461
6462 if let Some(prep) = preparation {
6463 on_event(AgentEvent::AutoCompactionStart {
6464 reason: "threshold".to_string(),
6465 });
6466
6467 let before_outcome = self.dispatch_before_compact(&prep, &entries, None).await;
6468 if before_outcome.cancel {
6469 on_event(AgentEvent::AutoCompactionEnd {
6470 result: None,
6471 aborted: true,
6472 will_retry: false,
6473 error_message: None,
6474 });
6475 return Err(Error::extension("Compaction cancelled".to_string()));
6476 }
6477
6478 if let Some(compaction) = before_outcome.compaction {
6479 let result_value = compaction.details.clone();
6480 self.extensions_is_compacting
6481 .store(true, std::sync::atomic::Ordering::SeqCst);
6482 let apply_result = self
6483 .apply_compaction_entry(
6484 compaction.summary,
6485 compaction.first_kept_entry_id,
6486 compaction.tokens_before,
6487 compaction.details,
6488 true,
6489 )
6490 .await;
6491 self.extensions_is_compacting
6492 .store(false, std::sync::atomic::Ordering::SeqCst);
6493 apply_result?;
6494 on_event(AgentEvent::AutoCompactionEnd {
6495 result: result_value,
6496 aborted: false,
6497 will_retry: false,
6498 error_message: None,
6499 });
6500 return Ok(());
6501 }
6502 self.extensions_is_compacting
6503 .store(true, std::sync::atomic::Ordering::SeqCst);
6504
6505 let provider = self.agent.provider();
6506 let api_key = self .agent
6508 .stream_options()
6509 .api_key
6510 .clone()
6511 .unwrap_or_default();
6512
6513 let compaction_result = compaction::compact(prep, provider, &api_key, None).await;
6514 self.extensions_is_compacting
6515 .store(false, std::sync::atomic::Ordering::SeqCst);
6516
6517 match compaction_result {
6518 Ok(result) => {
6519 self.apply_compaction_result(result, Arc::clone(&on_event))
6520 .await?;
6521 }
6522 Err(e) => {
6523 on_event(AgentEvent::AutoCompactionEnd {
6524 result: None,
6525 aborted: false,
6526 will_retry: false,
6527 error_message: Some(e.to_string()),
6528 });
6529 return Err(e);
6530 }
6531 }
6532 }
6533 Ok(())
6534 }
6535
6536 fn resolve_extension_policy_for_enable(
6537 config: Option<&crate::config::Config>,
6538 policy: Option<ExtensionPolicy>,
6539 ) -> ExtensionPolicy {
6540 policy.unwrap_or_else(|| {
6541 config.map_or_else(
6542 || crate::config::Config::default().resolve_extension_policy(None),
6543 |cfg| cfg.resolve_extension_policy(None),
6544 )
6545 })
6546 }
6547
6548 pub async fn enable_extensions(
6549 &mut self,
6550 enabled_tools: &[&str],
6551 cwd: &std::path::Path,
6552 config: Option<&crate::config::Config>,
6553 extension_entries: &[std::path::PathBuf],
6554 ) -> Result<()> {
6555 self.enable_extensions_with_policy(
6556 enabled_tools,
6557 cwd,
6558 config,
6559 extension_entries,
6560 None,
6561 None,
6562 None,
6563 )
6564 .await
6565 }
6566
6567 #[allow(clippy::too_many_lines, clippy::too_many_arguments)]
6568 pub async fn enable_extensions_with_policy(
6569 &mut self,
6570 enabled_tools: &[&str],
6571 cwd: &std::path::Path,
6572 config: Option<&crate::config::Config>,
6573 extension_entries: &[std::path::PathBuf],
6574 policy: Option<ExtensionPolicy>,
6575 repair_policy: Option<RepairPolicyMode>,
6576 pre_warmed: Option<PreWarmedExtensionRuntime>,
6577 ) -> Result<()> {
6578 let mut js_specs: Vec<JsExtensionLoadSpec> = Vec::new();
6579 let mut native_specs: Vec<NativeRustExtensionLoadSpec> = Vec::new();
6580 #[cfg(feature = "wasm-host")]
6581 let mut wasm_specs: Vec<WasmExtensionLoadSpec> = Vec::new();
6582
6583 for entry in extension_entries {
6584 match resolve_extension_load_spec(entry)? {
6585 ExtensionLoadSpec::Js(spec) => js_specs.push(spec),
6586 ExtensionLoadSpec::NativeRust(spec) => native_specs.push(spec),
6587 #[cfg(feature = "wasm-host")]
6588 ExtensionLoadSpec::Wasm(spec) => wasm_specs.push(spec),
6589 }
6590 }
6591
6592 if !js_specs.is_empty() && !native_specs.is_empty() {
6593 return Err(Error::validation(
6594 "Mixed extension runtimes are not supported in one session yet. Use either JS/TS extensions (QuickJS) or native-rust descriptors (*.native.json), but not both at once."
6595 .to_string(),
6596 ));
6597 }
6598
6599 #[cfg(feature = "wasm-host")]
6600 if js_specs.is_empty() && native_specs.is_empty() && wasm_specs.is_empty() {
6601 self.extensions = None;
6602 self.agent.extensions = None;
6603 self.extension_queue_modes = None;
6604 self.extension_injected_queue = None;
6605 return Ok(());
6606 }
6607
6608 #[cfg(not(feature = "wasm-host"))]
6609 if js_specs.is_empty() && native_specs.is_empty() {
6610 self.extensions = None;
6611 self.agent.extensions = None;
6612 self.extension_queue_modes = None;
6613 self.extension_injected_queue = None;
6614 return Ok(());
6615 }
6616
6617 let resolved_policy = Self::resolve_extension_policy_for_enable(config, policy);
6618 let resolved_repair_policy = repair_policy
6619 .or_else(|| config.map(|cfg| cfg.resolve_repair_policy(None)))
6620 .unwrap_or(RepairPolicyMode::AutoSafe);
6621 let runtime_repair_mode =
6622 Self::runtime_repair_mode_from_policy_mode(resolved_repair_policy);
6623 let memory_limit_bytes =
6624 (resolved_policy.max_memory_mb as usize).saturating_mul(1024 * 1024);
6625 let wants_js_runtime = !js_specs.is_empty();
6626
6627 #[allow(unused_variables)]
6630 let (manager, tools) = if let Some(pre) = pre_warmed {
6631 let manager = pre.manager;
6632 let tools = pre.tools;
6633 let runtime = match pre.runtime {
6634 ExtensionRuntimeHandle::NativeRust(runtime) => {
6635 if wants_js_runtime {
6636 tracing::warn!(
6637 event = "pi.extension_runtime.prewarm.mismatch",
6638 expected = "quickjs",
6639 got = "native-rust",
6640 "Pre-warmed runtime mismatched requested JS mode; creating quickjs runtime"
6641 );
6642 Self::start_js_extension_runtime(
6643 "agent_enable_extensions_prewarm_mismatch",
6644 cwd,
6645 Arc::clone(&tools),
6646 manager.clone(),
6647 resolved_policy.clone(),
6648 runtime_repair_mode,
6649 memory_limit_bytes,
6650 )
6651 .await?
6652 } else {
6653 tracing::info!(
6654 event = "pi.extension_runtime.engine_decision",
6655 stage = "agent_enable_extensions_prewarmed",
6656 requested = "native-rust",
6657 selected = "native-rust",
6658 fallback = false,
6659 "Using pre-warmed extension runtime"
6660 );
6661 ExtensionRuntimeHandle::NativeRust(runtime)
6662 }
6663 }
6664 ExtensionRuntimeHandle::Js(runtime) => {
6665 if wants_js_runtime {
6666 tracing::info!(
6667 event = "pi.extension_runtime.engine_decision",
6668 stage = "agent_enable_extensions_prewarmed",
6669 requested = "quickjs",
6670 selected = "quickjs",
6671 fallback = false,
6672 "Using pre-warmed extension runtime"
6673 );
6674 ExtensionRuntimeHandle::Js(runtime)
6675 } else {
6676 tracing::warn!(
6677 event = "pi.extension_runtime.prewarm.mismatch",
6678 expected = "native-rust",
6679 got = "quickjs",
6680 "Pre-warmed runtime mismatched requested native mode; creating native-rust runtime"
6681 );
6682 Self::start_native_extension_runtime(
6683 "agent_enable_extensions_prewarm_mismatch",
6684 cwd,
6685 Arc::clone(&tools),
6686 manager.clone(),
6687 resolved_policy.clone(),
6688 runtime_repair_mode,
6689 memory_limit_bytes,
6690 )
6691 .await?
6692 }
6693 }
6694 };
6695 manager.set_runtime(runtime);
6696 (manager, tools)
6697 } else {
6698 let manager = ExtensionManager::new();
6699 manager.set_cwd(cwd.display().to_string());
6700 let tools = Arc::new(ToolRegistry::new(enabled_tools, cwd, config));
6701
6702 if let Some(cfg) = config {
6703 let resolved_risk = cfg.resolve_extension_risk_with_metadata();
6704 tracing::info!(
6705 event = "pi.extension_runtime_risk.config",
6706 source = resolved_risk.source,
6707 enabled = resolved_risk.settings.enabled,
6708 alpha = resolved_risk.settings.alpha,
6709 window_size = resolved_risk.settings.window_size,
6710 ledger_limit = resolved_risk.settings.ledger_limit,
6711 fail_closed = resolved_risk.settings.fail_closed,
6712 "Resolved extension runtime risk settings"
6713 );
6714 manager.set_runtime_risk_config(resolved_risk.settings);
6715 }
6716
6717 let runtime = if wants_js_runtime {
6718 Self::start_js_extension_runtime(
6719 "agent_enable_extensions_boot",
6720 cwd,
6721 Arc::clone(&tools),
6722 manager.clone(),
6723 resolved_policy.clone(),
6724 runtime_repair_mode,
6725 memory_limit_bytes,
6726 )
6727 .await?
6728 } else {
6729 Self::start_native_extension_runtime(
6730 "agent_enable_extensions_boot",
6731 cwd,
6732 Arc::clone(&tools),
6733 manager.clone(),
6734 resolved_policy.clone(),
6735 runtime_repair_mode,
6736 memory_limit_bytes,
6737 )
6738 .await?
6739 };
6740 manager.set_runtime(runtime);
6741 (manager, tools)
6742 };
6743
6744 let (steering_mode, follow_up_mode) = self.agent.queue_modes();
6748 let queue_modes = Arc::new(StdMutex::new(ExtensionQueueModeState::new(
6749 steering_mode,
6750 follow_up_mode,
6751 )));
6752 manager.set_session(Arc::new(AgentExtensionSession {
6753 handle: SessionHandle(self.session.clone()),
6754 is_streaming: Arc::clone(&self.extensions_is_streaming),
6755 is_compacting: Arc::clone(&self.extensions_is_compacting),
6756 queue_modes: Arc::clone(&queue_modes),
6757 auto_compaction_enabled: self.compaction_settings.enabled,
6758 }));
6759
6760 let injected = Arc::new(StdMutex::new(ExtensionInjectedQueue::new(
6761 steering_mode,
6762 follow_up_mode,
6763 )));
6764 let host_actions = AgentSessionHostActions {
6765 session: Arc::clone(&self.session),
6766 injected: Arc::clone(&injected),
6767 is_streaming: Arc::clone(&self.extensions_is_streaming),
6768 is_turn_active: Arc::clone(&self.extensions_turn_active),
6769 pending_idle_actions: Arc::clone(&self.extensions_pending_idle_actions),
6770 };
6771 self.extension_queue_modes = Some(Arc::clone(&queue_modes));
6772 self.extension_injected_queue = Some(Arc::clone(&injected));
6773 manager.set_host_actions(Arc::new(host_actions));
6774 {
6775 let steering_queue = Arc::clone(&injected);
6776 let follow_up_queue = Arc::clone(&injected);
6777 let steering_fetcher = move || -> BoxFuture<'static, Vec<Message>> {
6778 let steering_queue = Arc::clone(&steering_queue);
6779 Box::pin(async move {
6780 let Ok(mut queue) = steering_queue.lock() else {
6781 return Vec::new();
6782 };
6783 queue.pop_steering()
6784 })
6785 };
6786 let follow_up_fetcher = move || -> BoxFuture<'static, Vec<Message>> {
6787 let follow_up_queue = Arc::clone(&follow_up_queue);
6788 Box::pin(async move {
6789 let Ok(mut queue) = follow_up_queue.lock() else {
6790 return Vec::new();
6791 };
6792 queue.pop_follow_up()
6793 })
6794 };
6795 self.agent.register_message_fetchers(
6796 Some(Arc::new(steering_fetcher)),
6797 Some(Arc::new(follow_up_fetcher)),
6798 );
6799 }
6800
6801 if !js_specs.is_empty() {
6802 manager.load_js_extensions(js_specs).await?;
6803 }
6804
6805 if !native_specs.is_empty() {
6806 manager.load_native_extensions(native_specs).await?;
6807 }
6808
6809 if let Some(rt) = manager.runtime() {
6811 let events = rt.drain_repair_events().await;
6812 if !events.is_empty() {
6813 log_repair_diagnostics(&events);
6814 }
6815 }
6816
6817 #[cfg(feature = "wasm-host")]
6818 if !wasm_specs.is_empty() {
6819 let host = WasmExtensionHost::new(cwd, resolved_policy.clone())?;
6820 manager
6821 .load_wasm_extensions(&host, wasm_specs, Arc::clone(&tools))
6822 .await?;
6823 }
6824
6825 let session_path = {
6828 let cx = crate::agent_cx::AgentCx::for_request();
6829 let session = self
6830 .session
6831 .lock(cx.cx())
6832 .await
6833 .map_err(|e| Error::extension(e.to_string()))?;
6834 session.path.as_ref().map(|p| p.display().to_string())
6835 };
6836
6837 if let Err(err) = manager
6838 .dispatch_event(
6839 ExtensionEventName::Startup,
6840 Some(serde_json::json!({
6841 "version": env!("CARGO_PKG_VERSION"),
6842 "sessionFile": session_path,
6843 })),
6844 )
6845 .await
6846 {
6847 tracing::warn!("startup extension hook failed (fail-open): {err}");
6848 }
6849
6850 if let Err(err) = manager
6851 .dispatch_event(ExtensionEventName::SessionStart, None)
6852 .await
6853 {
6854 tracing::warn!("session_start extension hook failed (fail-open): {err}");
6855 }
6856
6857 let ctx_payload = serde_json::json!({ "cwd": cwd.display().to_string() });
6858 let wrappers = collect_extension_tool_wrappers(&manager, ctx_payload).await?;
6859 self.agent.extend_tools(wrappers);
6860 self.agent.extensions = Some(manager.clone());
6861 self.extensions = Some(ExtensionRegion::new(manager));
6862 Ok(())
6863 }
6864
6865 pub async fn save_and_index(&mut self) -> Result<()> {
6866 if self.save_enabled {
6867 let cx = crate::agent_cx::AgentCx::for_request();
6868 let mut session = self
6869 .session
6870 .lock(cx.cx())
6871 .await
6872 .map_err(|e| Error::session(e.to_string()))?;
6873 session
6874 .flush_autosave(AutosaveFlushTrigger::Periodic)
6875 .await?;
6876 }
6877 Ok(())
6878 }
6879
6880 pub async fn persist_session(&mut self) -> Result<()> {
6881 if !self.save_enabled {
6882 return Ok(());
6883 }
6884 let cx = crate::agent_cx::AgentCx::for_request();
6885 let mut session = self
6886 .session
6887 .lock(cx.cx())
6888 .await
6889 .map_err(|e| Error::session(e.to_string()))?;
6890 session
6891 .flush_autosave(AutosaveFlushTrigger::Periodic)
6892 .await?;
6893 Ok(())
6894 }
6895
6896 pub async fn run_text(
6897 &mut self,
6898 input: String,
6899 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
6900 ) -> Result<AssistantMessage> {
6901 self.run_text_with_abort(input, None, on_event).await
6902 }
6903
6904 pub async fn run_text_with_abort(
6905 &mut self,
6906 input: String,
6907 abort: Option<AbortSignal>,
6908 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
6909 ) -> Result<AssistantMessage> {
6910 self.extensions_turn_active.store(true, Ordering::SeqCst);
6911 let result = async {
6912 let outcome = self.dispatch_input_event(input, Vec::new()).await?;
6913 let (text, images) = match outcome {
6914 InputEventOutcome::Continue { text, images } => (text, images),
6915 InputEventOutcome::Block { reason } => {
6916 let message = reason.unwrap_or_else(|| "Input blocked".to_string());
6917 return Err(Error::extension(message));
6918 }
6919 };
6920
6921 let base_system_prompt = self.agent.system_prompt().map(str::to_string);
6922 let BeforeAgentStartOutcome {
6923 messages: custom_messages,
6924 system_prompt,
6925 } = self
6926 .dispatch_before_agent_start(
6927 &text,
6928 &images,
6929 base_system_prompt.as_deref().unwrap_or(""),
6930 )
6931 .await;
6932 if let Some(prompt) = system_prompt {
6933 self.agent.set_system_prompt(Some(prompt));
6934 } else {
6935 self.agent.set_system_prompt(base_system_prompt.clone());
6936 }
6937
6938 let result = if images.is_empty() {
6939 self.run_agent_with_text(text, abort, on_event, custom_messages)
6940 .await
6941 } else {
6942 let content = Self::build_content_blocks_for_input(&text, &images);
6943 self.run_agent_with_content(content, abort, on_event, custom_messages)
6944 .await
6945 };
6946
6947 self.agent.set_system_prompt(base_system_prompt);
6948 result
6949 }
6950 .await;
6951 self.extensions_turn_active.store(false, Ordering::SeqCst);
6952 result
6953 }
6954
6955 pub async fn run_with_content(
6956 &mut self,
6957 content: Vec<ContentBlock>,
6958 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
6959 ) -> Result<AssistantMessage> {
6960 self.run_with_content_with_abort(content, None, on_event)
6961 .await
6962 }
6963
6964 pub async fn run_with_content_with_abort(
6965 &mut self,
6966 content: Vec<ContentBlock>,
6967 abort: Option<AbortSignal>,
6968 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
6969 ) -> Result<AssistantMessage> {
6970 self.extensions_turn_active.store(true, Ordering::SeqCst);
6971 let result = async {
6972 let (text, images) = Self::split_content_blocks_for_input(&content);
6973 let outcome = self.dispatch_input_event(text, images).await?;
6974 let (text, images) = match outcome {
6975 InputEventOutcome::Continue { text, images } => (text, images),
6976 InputEventOutcome::Block { reason } => {
6977 let message = reason.unwrap_or_else(|| "Input blocked".to_string());
6978 return Err(Error::extension(message));
6979 }
6980 };
6981
6982 let base_system_prompt = self.agent.system_prompt().map(str::to_string);
6983 let BeforeAgentStartOutcome {
6984 messages: custom_messages,
6985 system_prompt,
6986 } = self
6987 .dispatch_before_agent_start(
6988 &text,
6989 &images,
6990 base_system_prompt.as_deref().unwrap_or(""),
6991 )
6992 .await;
6993 if let Some(prompt) = system_prompt {
6994 self.agent.set_system_prompt(Some(prompt));
6995 } else {
6996 self.agent.set_system_prompt(base_system_prompt.clone());
6997 }
6998
6999 let content_for_agent = Self::build_content_blocks_for_input(&text, &images);
7000 let result = self
7001 .run_agent_with_content(content_for_agent, abort, on_event, custom_messages)
7002 .await;
7003
7004 self.agent.set_system_prompt(base_system_prompt);
7005 result
7006 }
7007 .await;
7008 self.extensions_turn_active.store(false, Ordering::SeqCst);
7009 result
7010 }
7011
7012 pub async fn revert_last_user_message(&mut self) -> Result<bool> {
7013 let cx = crate::agent_cx::AgentCx::for_request();
7014 let mut session = self
7015 .session
7016 .lock(cx.cx())
7017 .await
7018 .map_err(|e| Error::session(e.to_string()))?;
7019
7020 let reverted = session.revert_last_user_message();
7021 if reverted {
7022 let messages = session.to_messages_for_current_path();
7023 self.agent.replace_messages(messages);
7024 }
7025 Ok(reverted)
7026 }
7027
7028 async fn dispatch_input_event(
7029 &self,
7030 text: String,
7031 images: Vec<ImageContent>,
7032 ) -> Result<InputEventOutcome> {
7033 let Some(region) = &self.extensions else {
7034 return Ok(InputEventOutcome::Continue { text, images });
7035 };
7036
7037 let images_value = serde_json::to_value(&images).unwrap_or(Value::Null);
7038 let attachments_value = images_value.clone();
7039 let text_clone = text.clone();
7040 let payload = json!({
7041 "text": text,
7042 "content": text_clone,
7043 "images": images_value,
7044 "attachments": attachments_value,
7045 "source": self.input_source.as_str(),
7046 });
7047
7048 let response = region
7049 .manager()
7050 .dispatch_event_with_response(
7051 ExtensionEventName::Input,
7052 Some(payload),
7053 EXTENSION_EVENT_TIMEOUT_MS,
7054 )
7055 .await?;
7056
7057 Ok(apply_input_event_response(response, text, images))
7058 }
7059
7060 async fn dispatch_before_agent_start(
7061 &self,
7062 prompt: &str,
7063 images: &[ImageContent],
7064 system_prompt: &str,
7065 ) -> BeforeAgentStartOutcome {
7066 let Some(region) = &self.extensions else {
7067 return BeforeAgentStartOutcome {
7068 messages: Vec::new(),
7069 system_prompt: None,
7070 };
7071 };
7072
7073 let images_value = serde_json::to_value(images).unwrap_or(Value::Null);
7074 let payload = json!({
7075 "prompt": prompt,
7076 "images": images_value,
7077 "systemPrompt": system_prompt,
7078 });
7079
7080 let response = region
7081 .manager()
7082 .dispatch_event_with_response(
7083 ExtensionEventName::BeforeAgentStart,
7084 Some(payload),
7085 EXTENSION_EVENT_TIMEOUT_MS,
7086 )
7087 .await;
7088
7089 match response {
7090 Ok(value) => apply_before_agent_start_response(value, Utc::now().timestamp_millis()),
7091 Err(err) => {
7092 tracing::warn!("before_agent_start extension hook failed (fail-open): {err}");
7093 BeforeAgentStartOutcome {
7094 messages: Vec::new(),
7095 system_prompt: None,
7096 }
7097 }
7098 }
7099 }
7100
7101 async fn dispatch_before_compact(
7102 &self,
7103 preparation: &compaction::CompactionPreparation,
7104 branch_entries: &[crate::session::SessionEntry],
7105 custom_instructions: Option<&str>,
7106 ) -> SessionBeforeCompactOutcome {
7107 let Some(region) = &self.extensions else {
7108 return SessionBeforeCompactOutcome::default();
7109 };
7110
7111 let prep_value = compaction::compaction_preparation_to_value(preparation);
7112 let branch_entries_value =
7113 serde_json::to_value(branch_entries).unwrap_or(Value::Array(Vec::new()));
7114 let mut payload = serde_json::Map::new();
7115 payload.insert("preparation".to_string(), prep_value);
7116 payload.insert("branchEntries".to_string(), branch_entries_value);
7117 if let Some(custom_instructions) = custom_instructions {
7118 payload.insert(
7119 "customInstructions".to_string(),
7120 Value::String(custom_instructions.to_string()),
7121 );
7122 }
7123
7124 let response = region
7125 .manager()
7126 .dispatch_event_with_response(
7127 ExtensionEventName::SessionBeforeCompact,
7128 Some(Value::Object(payload)),
7129 EXTENSION_EVENT_TIMEOUT_MS,
7130 )
7131 .await;
7132
7133 match response {
7134 Ok(value) => apply_session_before_compact_response(value, preparation.tokens_before),
7135 Err(err) => {
7136 tracing::warn!("session_before_compact extension hook failed (fail-open): {err}");
7137 SessionBeforeCompactOutcome::default()
7138 }
7139 }
7140 }
7141
7142 fn split_content_blocks_for_input(blocks: &[ContentBlock]) -> (String, Vec<ImageContent>) {
7143 let mut text = String::new();
7144 let mut images = Vec::new();
7145 for block in blocks {
7146 match block {
7147 ContentBlock::Text(text_block) if !text_block.text.trim().is_empty() => {
7148 if !text.is_empty() {
7149 text.push('\n');
7150 }
7151 text.push_str(&text_block.text);
7152 }
7153 ContentBlock::Image(image) => images.push(image.clone()),
7154 _ => {}
7155 }
7156 }
7157 (text, images)
7158 }
7159
7160 fn build_content_blocks_for_input(text: &str, images: &[ImageContent]) -> Vec<ContentBlock> {
7161 let mut content = Vec::new();
7162 if !text.trim().is_empty() {
7163 content.push(ContentBlock::Text(TextContent::new(text.to_string())));
7164 }
7165 for image in images {
7166 content.push(ContentBlock::Image(image.clone()));
7167 }
7168 content
7169 }
7170
7171 fn take_pending_idle_actions(&self) -> Vec<PendingIdleAction> {
7172 let Ok(mut actions) = self.extensions_pending_idle_actions.lock() else {
7173 return Vec::new();
7174 };
7175 actions.drain(..).collect()
7176 }
7177
7178 async fn run_pending_idle_actions_with_abort(
7179 &mut self,
7180 abort: Option<AbortSignal>,
7181 on_event: AgentEventHandler,
7182 ) -> Result<()> {
7183 let actions = self.take_pending_idle_actions();
7184 if actions.is_empty() {
7185 return Ok(());
7186 }
7187
7188 let previous_source = self.input_source;
7189 self.input_source = InputSource::Extension;
7190 let result = async {
7191 for action in actions {
7192 match action {
7193 PendingIdleAction::CustomMessage(message) => {
7194 let handler = Arc::clone(&on_event);
7195 self.run_custom_message_with_abort(message, abort.clone(), move |event| {
7196 handler(event);
7197 })
7198 .await?;
7199 }
7200 PendingIdleAction::UserText(text) => {
7201 let handler = Arc::clone(&on_event);
7202 self.run_text_with_abort(text, abort.clone(), move |event| {
7203 handler(event);
7204 })
7205 .await?;
7206 }
7207 }
7208 }
7209 Ok(())
7210 }
7211 .await;
7212 self.input_source = previous_source;
7213 result
7214 }
7215
7216 async fn run_custom_message_with_abort(
7217 &mut self,
7218 message: Message,
7219 abort: Option<AbortSignal>,
7220 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
7221 ) -> Result<AssistantMessage> {
7222 self.extensions_turn_active.store(true, Ordering::SeqCst);
7223 let result = async {
7224 let base_system_prompt = self.agent.system_prompt().map(str::to_string);
7225 let BeforeAgentStartOutcome {
7226 messages: custom_messages,
7227 system_prompt,
7228 } = self
7229 .dispatch_before_agent_start("", &[], base_system_prompt.as_deref().unwrap_or(""))
7230 .await;
7231 if let Some(prompt) = system_prompt {
7232 self.agent.set_system_prompt(Some(prompt));
7233 } else {
7234 self.agent.set_system_prompt(base_system_prompt.clone());
7235 }
7236
7237 let result = self
7238 .run_agent_with_prompt_message(message, abort, on_event, custom_messages)
7239 .await;
7240
7241 self.agent.set_system_prompt(base_system_prompt);
7242 result
7243 }
7244 .await;
7245 self.extensions_turn_active.store(false, Ordering::SeqCst);
7246 result
7247 }
7248
7249 async fn run_agent_with_prompt_message(
7250 &mut self,
7251 prompt_message: Message,
7252 abort: Option<AbortSignal>,
7253 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
7254 custom_messages: Vec<CustomMessage>,
7255 ) -> Result<AssistantMessage> {
7256 let on_event: AgentEventHandler = Arc::new(on_event);
7257 self.sync_runtime_selection_from_session_header().await?;
7258
7259 self.maybe_compact(Arc::clone(&on_event)).await?;
7260 let history = {
7261 let cx = crate::agent_cx::AgentCx::for_request();
7262 let session = self
7263 .session
7264 .lock(cx.cx())
7265 .await
7266 .map_err(|e| Error::session(e.to_string()))?;
7267 session.to_messages_for_current_path()
7268 };
7269 self.agent.replace_messages(history);
7270
7271 let start_len = self.agent.messages().len();
7272 let mut prompts = Vec::with_capacity(1 + custom_messages.len());
7273 prompts.push(prompt_message.clone());
7274 prompts.extend(custom_messages.into_iter().map(Message::Custom));
7275
7276 {
7277 let cx = crate::agent_cx::AgentCx::for_request();
7278 let mut session = self
7279 .session
7280 .lock(cx.cx())
7281 .await
7282 .map_err(|e| Error::session(e.to_string()))?;
7283 session.append_model_message(prompt_message.clone());
7284 if self.save_enabled {
7285 session.flush_autosave(AutosaveFlushTrigger::Manual).await?;
7286 }
7287 }
7288
7289 let streaming_guard = AtomicBoolGuard::activate(&self.extensions_is_streaming);
7290 let on_event_for_run = Arc::clone(&on_event);
7291 let result = self
7292 .agent
7293 .run_with_messages_with_abort(prompts, abort, move |event| {
7294 on_event_for_run(event);
7295 })
7296 .await;
7297 drop(streaming_guard);
7298
7299 let persist_result = self.persist_new_messages(start_len + 1).await;
7300
7301 let result = result?;
7302 persist_result?;
7303 Ok(result)
7304 }
7305
7306 pub(crate) async fn run_agent_with_text(
7307 &mut self,
7308 input: String,
7309 abort: Option<AbortSignal>,
7310 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
7311 custom_messages: Vec<CustomMessage>,
7312 ) -> Result<AssistantMessage> {
7313 let on_event: AgentEventHandler = Arc::new(on_event);
7314 self.sync_runtime_selection_from_session_header().await?;
7315
7316 self.maybe_compact(Arc::clone(&on_event)).await?;
7317 let history = {
7318 let cx = crate::agent_cx::AgentCx::for_request();
7319 let session = self
7320 .session
7321 .lock(cx.cx())
7322 .await
7323 .map_err(|e| Error::session(e.to_string()))?;
7324 session.to_messages_for_current_path()
7325 };
7326 self.agent.replace_messages(history);
7327
7328 let start_len = self.agent.messages().len();
7329
7330 let user_message = Message::User(UserMessage {
7332 content: UserContent::Text(input),
7333 timestamp: Utc::now().timestamp_millis(),
7334 });
7335 let mut prompts = Vec::with_capacity(1 + custom_messages.len());
7336 prompts.push(user_message.clone());
7337 prompts.extend(custom_messages.into_iter().map(Message::Custom));
7338
7339 {
7340 let cx = crate::agent_cx::AgentCx::for_request();
7341 let mut session = self
7342 .session
7343 .lock(cx.cx())
7344 .await
7345 .map_err(|e| Error::session(e.to_string()))?;
7346 session.append_model_message(user_message.clone());
7347 if self.save_enabled {
7348 session.flush_autosave(AutosaveFlushTrigger::Manual).await?;
7349 }
7350 }
7351
7352 let streaming_guard = AtomicBoolGuard::activate(&self.extensions_is_streaming);
7353 let on_event_for_run = Arc::clone(&on_event);
7354 let result = self
7355 .agent
7356 .run_with_messages_with_abort(prompts, abort, move |event| {
7357 on_event_for_run(event);
7358 })
7359 .await;
7360 drop(streaming_guard);
7361
7362 let persist_result = self.persist_new_messages(start_len + 1).await;
7365
7366 let result = result?;
7367 persist_result?;
7368 Ok(result)
7369 }
7370
7371 pub(crate) async fn run_agent_with_content(
7372 &mut self,
7373 content: Vec<ContentBlock>,
7374 abort: Option<AbortSignal>,
7375 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
7376 custom_messages: Vec<CustomMessage>,
7377 ) -> Result<AssistantMessage> {
7378 let on_event: AgentEventHandler = Arc::new(on_event);
7379 self.sync_runtime_selection_from_session_header().await?;
7380
7381 self.maybe_compact(Arc::clone(&on_event)).await?;
7382 let history = {
7383 let cx = crate::agent_cx::AgentCx::for_request();
7384 let session = self
7385 .session
7386 .lock(cx.cx())
7387 .await
7388 .map_err(|e| Error::session(e.to_string()))?;
7389 session.to_messages_for_current_path()
7390 };
7391 self.agent.replace_messages(history);
7392
7393 let start_len = self.agent.messages().len();
7394
7395 let user_message = Message::User(UserMessage {
7397 content: UserContent::Blocks(content),
7398 timestamp: Utc::now().timestamp_millis(),
7399 });
7400 let mut prompts = Vec::with_capacity(1 + custom_messages.len());
7401 prompts.push(user_message.clone());
7402 prompts.extend(custom_messages.into_iter().map(Message::Custom));
7403
7404 {
7405 let cx = crate::agent_cx::AgentCx::for_request();
7406 let mut session = self
7407 .session
7408 .lock(cx.cx())
7409 .await
7410 .map_err(|e| Error::session(e.to_string()))?;
7411 session.append_model_message(user_message.clone());
7412 if self.save_enabled {
7413 session.flush_autosave(AutosaveFlushTrigger::Manual).await?;
7414 }
7415 }
7416
7417 let streaming_guard = AtomicBoolGuard::activate(&self.extensions_is_streaming);
7418 let on_event_for_run = Arc::clone(&on_event);
7419 let result = self
7420 .agent
7421 .run_with_messages_with_abort(prompts, abort, move |event| {
7422 on_event_for_run(event);
7423 })
7424 .await;
7425 drop(streaming_guard);
7426
7427 let persist_result = self.persist_new_messages(start_len + 1).await;
7430
7431 let result = result?;
7432 persist_result?;
7433 Ok(result)
7434 }
7435
7436 async fn persist_new_messages(&self, start_len: usize) -> Result<()> {
7437 let new_messages = self.agent.messages()[start_len..].to_vec();
7438 {
7439 let cx = crate::agent_cx::AgentCx::for_request();
7440 let mut session = self
7441 .session
7442 .lock(cx.cx())
7443 .await
7444 .map_err(|e| Error::session(e.to_string()))?;
7445 for message in new_messages {
7446 session.append_model_message(message);
7447 }
7448 if self.save_enabled {
7449 session
7450 .flush_autosave(AutosaveFlushTrigger::Periodic)
7451 .await?;
7452 }
7453 }
7454 Ok(())
7455 }
7456}
7457
7458fn log_repair_diagnostics(events: &[crate::extensions_js::ExtensionRepairEvent]) {
7467 use std::collections::BTreeMap;
7468
7469 for ev in events {
7471 tracing::info!(
7472 event = "extension.auto_repair",
7473 extension_id = %ev.extension_id,
7474 pattern = %ev.pattern,
7475 success = ev.success,
7476 original_error = %ev.original_error,
7477 repair_action = %ev.repair_action,
7478 );
7479 }
7480
7481 let mut by_pattern: BTreeMap<String, Vec<&str>> = BTreeMap::new();
7483 for ev in events {
7484 by_pattern
7485 .entry(ev.pattern.to_string())
7486 .or_default()
7487 .push(&ev.extension_id);
7488 }
7489
7490 let verbose = std::env::var("PI_AUTO_REPAIR_VERBOSE")
7491 .is_ok_and(|v| v == "1" || v.eq_ignore_ascii_case("true"));
7492
7493 if verbose {
7494 warn!(
7495 "[auto-repair] {} extension{} auto-repaired:",
7496 events.len(),
7497 if events.len() == 1 { "" } else { "s" }
7498 );
7499 for ev in events {
7500 warn!(
7501 " {}: {} ({})",
7502 ev.pattern, ev.extension_id, ev.repair_action
7503 );
7504 }
7505 } else {
7506 let patterns: Vec<String> = by_pattern
7508 .iter()
7509 .map(|(pat, ids)| format!("{pat}:{}", ids.len()))
7510 .collect();
7511 tracing::info!(
7512 event = "extension.auto_repair.summary",
7513 count = events.len(),
7514 patterns = %patterns.join(", "),
7515 "auto-repaired {} extension(s)",
7516 events.len(),
7517 );
7518 }
7519}
7520
/// Text that `filter_image_blocks` puts in place of removed image blocks; a run of
/// consecutive images is collapsed into a single copy of this placeholder.
const BLOCK_IMAGES_PLACEHOLDER: &str = "Image reading is disabled.";
7522
/// Counters produced by [`filter_images_for_provider`].
#[derive(Clone, Copy, Debug, Default)]
struct ImageFilterStats {
    /// Total number of image blocks removed across all messages.
    removed_images: usize,
    /// Number of messages that had at least one image removed.
    affected_messages: usize,
}
7528
7529fn filter_images_for_provider(messages: &mut [Message]) -> ImageFilterStats {
7530 let mut stats = ImageFilterStats::default();
7531 for message in messages {
7532 let removed = filter_images_from_message(message);
7533 if removed > 0 {
7534 stats.removed_images += removed;
7535 stats.affected_messages += 1;
7536 }
7537 }
7538 stats
7539}
7540
7541fn filter_images_from_message(message: &mut Message) -> usize {
7542 match message {
7543 Message::User(user) => match &mut user.content {
7544 UserContent::Text(_) => 0,
7545 UserContent::Blocks(blocks) => filter_image_blocks(blocks),
7546 },
7547 Message::Assistant(assistant) => {
7548 let assistant = Arc::make_mut(assistant);
7549 filter_image_blocks(&mut assistant.content)
7550 }
7551 Message::ToolResult(tool_result) => {
7552 filter_image_blocks(&mut Arc::make_mut(tool_result).content)
7553 }
7554 Message::Custom(_) => 0,
7555 }
7556}
7557
7558fn filter_image_blocks(blocks: &mut Vec<ContentBlock>) -> usize {
7559 let mut removed = 0usize;
7560 let mut filtered = Vec::with_capacity(blocks.len());
7561
7562 for block in blocks.drain(..) {
7563 match block {
7564 ContentBlock::Image(_) => {
7565 removed += 1;
7566 let previous_is_placeholder =
7567 filtered
7568 .last()
7569 .is_some_and(|prev| matches!(prev, ContentBlock::Text(TextContent { text, .. }) if text == BLOCK_IMAGES_PLACEHOLDER));
7570 if !previous_is_placeholder {
7571 filtered.push(ContentBlock::Text(TextContent::new(
7572 BLOCK_IMAGES_PLACEHOLDER,
7573 )));
7574 }
7575 }
7576 other => filtered.push(other),
7577 }
7578 }
7579
7580 *blocks = filtered;
7581 removed
7582}
7583
7584fn extract_tool_calls(content: &[ContentBlock]) -> Vec<ToolCall> {
7586 content
7587 .iter()
7588 .filter_map(|block| {
7589 if let ContentBlock::ToolCall(tc) = block {
7590 Some(tc.clone())
7591 } else {
7592 None
7593 }
7594 })
7595 .collect()
7596}
7597
7598#[cfg(test)]
7603mod tests {
7604 use super::*;
7605 use crate::auth::AuthCredential;
7606 use crate::provider::{InputType, Model, ModelCost};
7607 use asupersync::runtime::RuntimeBuilder;
7608 use async_trait::async_trait;
7609 use futures::Stream;
7610 use std::collections::HashMap;
7611 use std::path::Path;
7612 use std::pin::Pin;
7613
7614 fn user_message(text: &str) -> Message {
7615 Message::User(UserMessage {
7616 content: UserContent::Text(text.to_string()),
7617 timestamp: 0,
7618 })
7619 }
7620
7621 fn assert_user_text(message: &Message, expected: &str) {
7622 assert!(
7623 matches!(
7624 message,
7625 Message::User(UserMessage {
7626 content: UserContent::Text(_),
7627 ..
7628 })
7629 ),
7630 "expected user text message, got {message:?}"
7631 );
7632 if let Message::User(UserMessage {
7633 content: UserContent::Text(text),
7634 ..
7635 }) = message
7636 {
7637 assert_eq!(text, expected);
7638 }
7639 }
7640
7641 fn sample_image_block() -> ContentBlock {
7642 ContentBlock::Image(ImageContent {
7643 data: "aGVsbG8=".to_string(),
7644 mime_type: "image/png".to_string(),
7645 })
7646 }
7647
7648 fn image_count_in_message(message: &Message) -> usize {
7649 let count_images = |blocks: &[ContentBlock]| {
7650 blocks
7651 .iter()
7652 .filter(|block| matches!(block, ContentBlock::Image(_)))
7653 .count()
7654 };
7655 match message {
7656 Message::User(UserMessage {
7657 content: UserContent::Blocks(blocks),
7658 ..
7659 }) => count_images(blocks),
7660 Message::Assistant(msg) => count_images(&msg.content),
7661 Message::ToolResult(tool_result) => count_images(&tool_result.content),
7662 Message::User(UserMessage {
7663 content: UserContent::Text(_),
7664 ..
7665 })
7666 | Message::Custom(_) => 0,
7667 }
7668 }
7669
7670 fn assistant_message(text: &str) -> AssistantMessage {
7671 AssistantMessage {
7672 content: vec![ContentBlock::Text(TextContent::new(text))],
7673 api: "test-api".to_string(),
7674 provider: "test-provider".to_string(),
7675 model: "test-model".to_string(),
7676 usage: Usage::default(),
7677 stop_reason: StopReason::Stop,
7678 error_message: None,
7679 timestamp: 0,
7680 }
7681 }
7682
7683 #[derive(Debug)]
7684 struct SilentProvider;
7685
7686 #[async_trait]
7687 #[allow(clippy::unnecessary_literal_bound)]
7688 impl Provider for SilentProvider {
7689 fn name(&self) -> &str {
7690 "silent-provider"
7691 }
7692
7693 fn api(&self) -> &str {
7694 "test-api"
7695 }
7696
7697 fn model_id(&self) -> &str {
7698 "test-model"
7699 }
7700
7701 async fn stream(
7702 &self,
7703 _context: &Context<'_>,
7704 _options: &StreamOptions,
7705 ) -> crate::error::Result<
7706 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
7707 > {
7708 Ok(Box::pin(futures::stream::empty()))
7709 }
7710 }
7711
7712 #[derive(Debug)]
7713 struct DeltaOnlyProvider;
7714
7715 #[async_trait]
7716 #[allow(clippy::unnecessary_literal_bound)]
7717 impl Provider for DeltaOnlyProvider {
7718 fn name(&self) -> &str {
7719 "test-provider"
7720 }
7721
7722 fn api(&self) -> &str {
7723 "test-api"
7724 }
7725
7726 fn model_id(&self) -> &str {
7727 "test-model"
7728 }
7729
7730 async fn stream(
7731 &self,
7732 _context: &Context<'_>,
7733 _options: &StreamOptions,
7734 ) -> crate::error::Result<
7735 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
7736 > {
7737 let final_message = assistant_message("hello");
7738 let events = vec![
7739 Ok(StreamEvent::TextDelta {
7740 content_index: 0,
7741 delta: "hello".to_string(),
7742 }),
7743 Ok(StreamEvent::Done {
7744 reason: StopReason::Stop,
7745 message: final_message,
7746 }),
7747 ];
7748 Ok(Box::pin(futures::stream::iter(events)))
7749 }
7750 }
7751
7752 #[test]
7753 fn delta_without_start_does_not_mutate_previous_message() {
7754 let runtime = RuntimeBuilder::current_thread()
7755 .build()
7756 .expect("runtime build");
7757
7758 runtime.block_on(async {
7759 let provider = Arc::new(DeltaOnlyProvider);
7760 let tools = ToolRegistry::from_tools(Vec::new());
7761 let mut agent = Agent::new(provider, tools, AgentConfig::default());
7762
7763 agent.add_message(Message::Assistant(Arc::new(assistant_message("prev"))));
7764
7765 agent
7766 .run_with_message_with_abort(user_message("hi"), None, |_| {})
7767 .await
7768 .expect("run");
7769
7770 let assistant_texts = agent
7771 .messages()
7772 .iter()
7773 .filter_map(|message| match message {
7774 Message::Assistant(msg)
7775 if matches!(msg.content.as_slice(), [ContentBlock::Text(_)]) =>
7776 {
7777 if let [ContentBlock::Text(text)] = msg.content.as_slice() {
7778 Some(text.text.clone())
7779 } else {
7780 None
7781 }
7782 }
7783 _ => None,
7784 })
7785 .collect::<Vec<_>>();
7786
7787 assert_eq!(
7788 assistant_texts.as_slice(),
7789 ["prev".to_string(), "hello".to_string()]
7790 );
7791 });
7792 }
7793
7794 #[test]
7795 fn enable_extensions_policy_resolution_defaults_to_permissive() {
7796 let policy = AgentSession::resolve_extension_policy_for_enable(None, None);
7797 assert_eq!(
7798 policy.mode,
7799 crate::extensions::ExtensionPolicyMode::Permissive
7800 );
7801 }
7802
7803 #[test]
7804 fn enable_extensions_policy_resolution_respects_config_default_toggle() {
7805 let config = crate::config::Config {
7806 extension_policy: Some(crate::config::ExtensionPolicyConfig {
7807 profile: None,
7808 default_permissive: Some(false),
7809 allow_dangerous: None,
7810 }),
7811 ..Default::default()
7812 };
7813 let policy = AgentSession::resolve_extension_policy_for_enable(Some(&config), None);
7814 assert_eq!(policy.mode, crate::extensions::ExtensionPolicyMode::Strict);
7815 }
7816
7817 #[test]
7818 fn enable_extensions_policy_resolution_prefers_explicit_policy() {
7819 let config = crate::config::Config {
7820 extension_policy: Some(crate::config::ExtensionPolicyConfig {
7821 profile: None,
7822 default_permissive: Some(false),
7823 allow_dangerous: None,
7824 }),
7825 ..Default::default()
7826 };
7827 let explicit = crate::extensions::PolicyProfile::Permissive.to_policy();
7828 let policy =
7829 AgentSession::resolve_extension_policy_for_enable(Some(&config), Some(explicit));
7830 assert_eq!(
7831 policy.mode,
7832 crate::extensions::ExtensionPolicyMode::Permissive
7833 );
7834 }
7835
7836 #[test]
7837 fn test_extract_tool_calls() {
7838 let content = vec![
7839 ContentBlock::Text(TextContent::new("Hello")),
7840 ContentBlock::ToolCall(ToolCall {
7841 id: "tc1".to_string(),
7842 name: "read".to_string(),
7843 arguments: serde_json::json!({"path": "file.txt"}),
7844 thought_signature: None,
7845 }),
7846 ContentBlock::Text(TextContent::new("World")),
7847 ContentBlock::ToolCall(ToolCall {
7848 id: "tc2".to_string(),
7849 name: "bash".to_string(),
7850 arguments: serde_json::json!({"command": "ls"}),
7851 thought_signature: None,
7852 }),
7853 ];
7854
7855 let tool_calls = extract_tool_calls(&content);
7856 assert_eq!(tool_calls.len(), 2);
7857 assert_eq!(tool_calls[0].name, "read");
7858 assert_eq!(tool_calls[1].name, "bash");
7859 }
7860
#[test]
fn test_agent_config_default() {
    let defaults = AgentConfig::default();
    let AgentConfig {
        max_tool_iterations,
        system_prompt,
        block_images,
        ..
    } = &defaults;

    assert_eq!(*max_tool_iterations, 50);
    assert!(system_prompt.is_none());
    assert!(!*block_images);
}
7868
#[test]
fn filter_image_blocks_replaces_images_with_deduped_placeholder_text() {
    // Two adjacent images should collapse into a single placeholder, while the
    // image after the text block gets a placeholder of its own.
    let mut blocks = vec![
        sample_image_block(),
        sample_image_block(),
        ContentBlock::Text(TextContent::new("tail")),
        sample_image_block(),
    ];

    let removed_count = filter_image_blocks(&mut blocks);
    assert_eq!(removed_count, 3);

    let any_image_left = blocks
        .iter()
        .any(|block| matches!(block, ContentBlock::Image(_)));
    assert!(!any_image_left);

    // True when the block at `idx` exists and is text equal to `expected`.
    let text_is = |idx: usize, expected: &str| {
        matches!(
            blocks.get(idx),
            Some(ContentBlock::Text(TextContent { text, .. })) if text == expected
        )
    };
    assert!(text_is(0, BLOCK_IMAGES_PLACEHOLDER));
    assert!(text_is(1, "tail"));
    assert!(text_is(2, BLOCK_IMAGES_PLACEHOLDER));
}
7899
#[test]
fn filter_images_for_provider_filters_images_from_all_block_message_types() {
    // One image in each of the three block-carrying message kinds.
    let user = Message::User(UserMessage {
        content: UserContent::Blocks(vec![
            ContentBlock::Text(TextContent::new("hello")),
            sample_image_block(),
        ]),
        timestamp: 0,
    });
    let assistant = Message::Assistant(Arc::new(AssistantMessage {
        content: vec![sample_image_block()],
        api: "test".to_string(),
        provider: "test".to_string(),
        model: "test".to_string(),
        usage: Usage::default(),
        stop_reason: StopReason::Stop,
        error_message: None,
        timestamp: 0,
    }));
    let tool_result = Message::tool_result(ToolResultMessage {
        tool_call_id: "tc1".to_string(),
        tool_name: "read".to_string(),
        content: vec![
            sample_image_block(),
            ContentBlock::Text(TextContent::new("ok")),
        ],
        details: None,
        is_error: false,
        timestamp: 0,
    });
    let mut messages = vec![user, assistant, tool_result];

    let stats = filter_images_for_provider(&mut messages);
    assert_eq!(stats.removed_images, 3);
    assert_eq!(stats.affected_messages, 3);

    let images_left: usize = messages.iter().map(image_count_in_message).sum();
    assert_eq!(
        images_left, 0,
        "no images should remain in provider-bound context"
    );
}
7943
#[test]
fn build_context_strips_images_when_block_images_enabled() {
    let image_blocking_config = AgentConfig {
        system_prompt: None,
        max_tool_iterations: 50,
        stream_options: StreamOptions::default(),
        block_images: true,
        fail_closed_hooks: false,
    };
    let mut agent = Agent::new(
        Arc::new(SilentProvider),
        ToolRegistry::new(&[], Path::new("."), None),
        image_blocking_config,
    );
    agent.add_message(Message::User(UserMessage {
        content: UserContent::Blocks(vec![sample_image_block()]),
        timestamp: 0,
    }));

    let context = agent.build_context();
    assert_eq!(context.messages.len(), 1);

    let only_message = &context.messages[0];
    assert_eq!(image_count_in_message(only_message), 0);

    // The stripped image must be replaced by the placeholder text block.
    let Message::User(UserMessage {
        content: UserContent::Blocks(blocks),
        ..
    }) = only_message
    else {
        panic!("expected a block-content user message");
    };
    let has_placeholder = blocks.iter().any(|block| {
        matches!(
            block,
            ContentBlock::Text(TextContent { text, .. }) if text == BLOCK_IMAGES_PLACEHOLDER
        )
    });
    assert!(has_placeholder);
}
7975
#[test]
fn build_context_keeps_images_when_block_images_disabled() {
    let passthrough_config = AgentConfig {
        system_prompt: None,
        max_tool_iterations: 50,
        stream_options: StreamOptions::default(),
        block_images: false,
        fail_closed_hooks: false,
    };
    let mut agent = Agent::new(
        Arc::new(SilentProvider),
        ToolRegistry::new(&[], Path::new("."), None),
        passthrough_config,
    );
    agent.add_message(Message::User(UserMessage {
        content: UserContent::Blocks(vec![sample_image_block()]),
        timestamp: 0,
    }));

    // Exactly one message, and its single image survives untouched.
    let context = agent.build_context();
    let per_message_images: Vec<usize> = context
        .messages
        .iter()
        .map(image_count_in_message)
        .collect();
    assert_eq!(per_message_images, [1]);
}
7998
#[test]
fn auto_compaction_start_serializes_with_pi_mono_compatible_type_tag() {
    let serialized = serde_json::to_value(AgentEvent::AutoCompactionStart {
        reason: "threshold".to_string(),
    })
    .unwrap();

    // String-valued field lookup; `None` when the key is missing or not a string.
    let field = |key: &str| serialized.get(key).and_then(serde_json::Value::as_str);
    assert_eq!(field("type"), Some("auto_compaction_start"));
    assert_eq!(field("reason"), Some("threshold"));
}
8008
#[test]
fn auto_compaction_end_serializes_with_pi_mono_compatible_fields() {
    let event = AgentEvent::AutoCompactionEnd {
        result: Some(serde_json::json!({"tokens_before": 5000, "tokens_after": 2000})),
        aborted: false,
        will_retry: false,
        error_message: None,
    };
    let serialized = serde_json::to_value(&event).unwrap();

    assert_eq!(serialized["type"], "auto_compaction_end");
    assert_eq!(serialized["aborted"].as_bool(), Some(false));
    assert_eq!(serialized["willRetry"].as_bool(), Some(false));
    // A `None` error message must not appear as a key at all.
    assert!(serialized.get("errorMessage").is_none());
    assert!(serialized["result"].is_object());
}
8024
#[test]
fn auto_compaction_end_includes_error_message_when_present() {
    let serialized = serde_json::to_value(&AgentEvent::AutoCompactionEnd {
        result: None,
        aborted: true,
        will_retry: false,
        error_message: Some("Compaction failed".to_string()),
    })
    .unwrap();

    let expectations = [
        ("type", serde_json::json!("auto_compaction_end")),
        ("aborted", serde_json::json!(true)),
        ("errorMessage", serde_json::json!("Compaction failed")),
    ];
    for (key, expected) in expectations {
        assert_eq!(serialized[key], expected, "unexpected value for `{key}`");
    }
}
8038
#[test]
fn auto_retry_start_serializes_with_camel_case_fields() {
    let event = AgentEvent::AutoRetryStart {
        attempt: 1,
        max_attempts: 3,
        delay_ms: 2000,
        error_message: "Rate limited".to_string(),
    };
    let serialized = serde_json::to_value(&event).unwrap();

    assert_eq!(serialized["type"].as_str(), Some("auto_retry_start"));
    // Multi-word Rust fields must come out in camelCase.
    assert_eq!(serialized["attempt"].as_u64(), Some(1));
    assert_eq!(serialized["maxAttempts"].as_u64(), Some(3));
    assert_eq!(serialized["delayMs"].as_u64(), Some(2000));
    assert_eq!(serialized["errorMessage"].as_str(), Some("Rate limited"));
}
8054
#[test]
fn auto_retry_end_serializes_success_and_omits_null_final_error() {
    let serialized = serde_json::to_value(&AgentEvent::AutoRetryEnd {
        success: true,
        attempt: 2,
        final_error: None,
    })
    .unwrap();

    assert_eq!(serialized["type"].as_str(), Some("auto_retry_end"));
    assert_eq!(serialized["success"].as_bool(), Some(true));
    assert_eq!(serialized["attempt"].as_u64(), Some(2));
    assert_eq!(
        serialized.get("finalError"),
        None,
        "finalError must be omitted rather than serialized as null"
    );
}
8068
#[test]
fn auto_retry_end_includes_final_error_on_failure() {
    let event = AgentEvent::AutoRetryEnd {
        success: false,
        attempt: 3,
        final_error: Some("Max retries exceeded".to_string()),
    };
    let serialized = serde_json::to_value(&event).unwrap();

    for (key, expected) in [
        ("type", serde_json::json!("auto_retry_end")),
        ("success", serde_json::json!(false)),
        ("attempt", serde_json::json!(3)),
        ("finalError", serde_json::json!("Max retries exceeded")),
    ] {
        assert_eq!(serialized[key], expected, "unexpected value for `{key}`");
    }
}
8082
#[test]
fn message_queue_push_increments_seq_and_counts_both_queues() {
    let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
    assert_eq!(queue.pending_count(), 0);

    // Sequence numbers are shared across the steering and follow-up queues.
    let seqs = [
        queue.push_steering(user_message("s1")),
        queue.push_follow_up(user_message("f1")),
        queue.push_steering(user_message("s2")),
    ];
    assert_eq!(seqs, [0, 1, 2]);

    assert_eq!(queue.pending_count(), 3);
}
8094
#[test]
fn message_queue_pop_steering_one_at_a_time_preserves_order() {
    let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
    for text in ["s1", "s2"] {
        queue.push_steering(user_message(text));
    }

    // Each pop releases exactly one message, oldest first.
    for (expected_text, expected_pending) in [("s1", 1), ("s2", 0)] {
        let batch = queue.pop_steering();
        assert_eq!(batch.len(), 1);
        assert_user_text(&batch[0], expected_text);
        assert_eq!(queue.pending_count(), expected_pending);
    }

    // A drained queue yields an empty batch.
    assert!(queue.pop_steering().is_empty());
}
8114
#[test]
fn message_queue_pop_respects_queue_modes_per_kind() {
    // Steering uses `All`, follow-ups use `OneAtATime`.
    let mut queue = MessageQueue::new(QueueMode::All, QueueMode::OneAtATime);
    for text in ["s1", "s2"] {
        queue.push_steering(user_message(text));
    }
    for text in ["f1", "f2"] {
        queue.push_follow_up(user_message(text));
    }

    let steering = queue.pop_steering();
    assert_eq!(steering.len(), 2);
    for (message, expected_text) in steering.iter().zip(["s1", "s2"]) {
        assert_user_text(message, expected_text);
    }
    assert_eq!(queue.pending_count(), 2);

    for (expected_text, expected_pending) in [("f1", 1), ("f2", 0)] {
        let batch = queue.pop_follow_up();
        assert_eq!(batch.len(), 1);
        assert_user_text(&batch[0], expected_text);
        assert_eq!(queue.pending_count(), expected_pending);
    }
}
8139
#[test]
fn message_queue_set_modes_applies_to_existing_messages() {
    let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
    ["s1", "s2"].into_iter().for_each(|text| {
        queue.push_steering(user_message(text));
    });

    let before_switch = queue.pop_steering();
    assert_eq!(before_switch.len(), 1);
    assert_user_text(&before_switch[0], "s1");

    // Changing the steering mode after enqueueing must govern the remaining message.
    queue.set_modes(QueueMode::All, QueueMode::OneAtATime);
    let after_switch = queue.pop_steering();
    assert_eq!(after_switch.len(), 1);
    assert_user_text(&after_switch[0], "s2");
}
8155
8156 fn build_switch_test_session(auth: &AuthStorage) -> AgentSession {
8157 let registry = ModelRegistry::load(auth, None);
8158 let current_entry = registry
8159 .find("anthropic", "claude-sonnet-4-5")
8160 .expect("anthropic model in registry");
8161 let provider = crate::providers::create_provider(¤t_entry, None)
8162 .expect("create anthropic provider");
8163 let tools = ToolRegistry::new(&[], Path::new("."), None);
8164 let mut stream_options = StreamOptions {
8165 api_key: Some("stale-key".to_string()),
8166 ..Default::default()
8167 };
8168 let _ = stream_options
8169 .headers
8170 .insert("x-stale-header".to_string(), "stale-value".to_string());
8171 let agent = Agent::new(
8172 provider,
8173 tools,
8174 AgentConfig {
8175 system_prompt: None,
8176 max_tool_iterations: 50,
8177 stream_options,
8178 block_images: false,
8179 fail_closed_hooks: false,
8180 },
8181 );
8182
8183 let mut session = Session::in_memory();
8184 session.header.provider = Some("anthropic".to_string());
8185 session.header.model_id = Some("claude-sonnet-4-5".to_string());
8186
8187 let mut agent_session = AgentSession::new(
8188 agent,
8189 Arc::new(Mutex::new(session)),
8190 false,
8191 ResolvedCompactionSettings::default(),
8192 );
8193 agent_session.set_model_registry(registry);
8194 agent_session.set_auth_storage(auth.clone());
8195 agent_session
8196 }
8197
#[test]
fn compaction_runtime_handle_creates_fallback_runtime() {
    let tmp = tempfile::tempdir().expect("tempdir");
    let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");
    let mut agent_session = build_switch_test_session(&auth);

    // A freshly built session owns neither a runtime nor a runtime handle.
    assert_eq!(
        (
            agent_session.compaction_runtime.is_some(),
            agent_session.runtime_handle.is_some(),
        ),
        (false, false)
    );

    let handle = agent_session
        .compaction_runtime_handle()
        .expect("create fallback compaction runtime");
    let spawned_value = futures::executor::block_on(handle.spawn(async { 7_u8 }));
    assert_eq!(spawned_value, 7);

    // Requesting the handle installs both the fallback runtime and its handle.
    assert_eq!(
        (
            agent_session.compaction_runtime.is_some(),
            agent_session.runtime_handle.is_some(),
        ),
        (true, true)
    );
}
8217
#[test]
fn apply_session_model_selection_updates_stream_credentials_and_headers() {
    let tmp = tempfile::tempdir().expect("tempdir");
    let mut auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");
    for (provider, key) in [("anthropic", "anthropic-key"), ("openai", "openai-key")] {
        auth.set(
            provider,
            AuthCredential::ApiKey {
                key: key.to_string(),
            },
        );
    }

    let mut agent_session = build_switch_test_session(&auth);
    agent_session
        .apply_session_model_selection("openai", "gpt-4o")
        .expect("switch should update stream options");

    let provider = agent_session.agent.provider();
    assert_eq!(provider.name(), "openai");
    assert_eq!(provider.model_id(), "gpt-4o");

    // The stale key and header seeded by the helper must both be replaced.
    let options = agent_session.agent.stream_options();
    assert_eq!(options.api_key.as_deref(), Some("openai-key"));
    assert!(
        options.headers.is_empty(),
        "stream headers should be refreshed from selected model entry"
    );
}
8252
#[test]
fn apply_session_model_selection_clears_stale_key_for_keyless_target() {
    let tmp = tempfile::tempdir().expect("tempdir");
    let mut auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");
    auth.set(
        "anthropic",
        AuthCredential::ApiKey {
            key: "anthropic-key".to_string(),
        },
    );

    // Target: a local model entry with no API key and `auth_header` disabled.
    let zero_cost = ModelCost {
        input: 0.0,
        output: 0.0,
        cache_read: 0.0,
        cache_write: 0.0,
    };
    let local_model = Model {
        id: "local-model".to_string(),
        name: "Local Model".to_string(),
        api: "openai-completions".to_string(),
        provider: "acme-local".to_string(),
        base_url: "https://example.invalid/v1".to_string(),
        reasoning: true,
        input: vec![InputType::Text],
        cost: zero_cost,
        context_window: 128_000,
        max_tokens: 8_192,
        headers: HashMap::new(),
    };
    let mut registry = ModelRegistry::load(&auth, None);
    registry.merge_entries(vec![ModelEntry {
        model: local_model,
        api_key: None,
        headers: HashMap::new(),
        auth_header: false,
        compat: None,
        oauth_config: None,
    }]);

    let mut agent_session = build_switch_test_session(&auth);
    agent_session.set_model_registry(registry);
    agent_session
        .apply_session_model_selection("acme-local", "local-model")
        .expect("keyless local model should still activate");

    assert_eq!(agent_session.agent.provider().name(), "acme-local");
    let key_cleared = agent_session.agent.stream_options().api_key.is_none();
    assert!(
        key_cleared,
        "stale key must be cleared when target model has no configured key"
    );
}
8305
#[test]
fn apply_session_model_selection_treats_blank_model_key_as_missing_credential() {
    /// An `acme/blank-model` entry whose API key is only whitespace and which
    /// requests an auth header.
    fn blank_key_entry() -> ModelEntry {
        ModelEntry {
            model: Model {
                id: "blank-model".to_string(),
                name: "Blank Model".to_string(),
                api: "openai-completions".to_string(),
                provider: "acme".to_string(),
                base_url: "https://example.invalid/v1".to_string(),
                reasoning: true,
                input: vec![InputType::Text],
                cost: ModelCost {
                    input: 0.0,
                    output: 0.0,
                    cache_read: 0.0,
                    cache_write: 0.0,
                },
                context_window: 128_000,
                max_tokens: 8_192,
                headers: HashMap::new(),
            },
            api_key: Some(" ".to_string()),
            headers: HashMap::new(),
            auth_header: true,
            compat: None,
            oauth_config: None,
        }
    }

    let tmp = tempfile::tempdir().expect("tempdir");
    let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");

    let mut registry = ModelRegistry::load(&auth, None);
    registry.merge_entries(vec![blank_key_entry()]);

    let mut agent_session = build_switch_test_session(&auth);
    agent_session.set_model_registry(registry);
    let err = agent_session
        .apply_session_model_selection("acme", "blank-model")
        .expect_err("blank keys must not satisfy credential requirements");

    let message = err.to_string();
    assert!(
        message.contains("Missing credentials for acme/blank-model"),
        "unexpected error: {err}"
    );

    // A rejected switch leaves the previous provider and its key in place.
    assert_eq!(agent_session.agent.provider().name(), "anthropic");
    assert_eq!(
        agent_session.agent.stream_options().api_key.as_deref(),
        Some("stale-key"),
        "failed switches must preserve the prior runtime credentials"
    );
}
8357
#[test]
fn set_provider_model_preserves_session_header_when_switch_fails() {
    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build runtime");

    runtime.block_on(async {
        let tmp = tempfile::tempdir().expect("tempdir");
        let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");
        let mut agent_session = build_switch_test_session(&auth);

        // Pin the persisted selection before attempting an impossible switch.
        {
            let cx = crate::agent_cx::AgentCx::for_request();
            let mut guard = agent_session
                .session
                .lock(cx.cx())
                .await
                .expect("session lock");
            guard.header.provider = Some("anthropic".to_string());
            guard.header.model_id = Some("claude-sonnet-4-5".to_string());
        }

        let err = agent_session
            .set_provider_model("missing-provider", "missing-model")
            .await
            .expect_err("missing model should not switch");
        let rendered = err.to_string();
        assert!(
            rendered.contains("Unable to switch provider/model to missing-provider/missing-model"),
            "unexpected error: {err}"
        );

        // Neither the live provider nor the persisted header may change.
        assert_eq!(agent_session.agent.provider().name(), "anthropic");
        assert_eq!(
            agent_session.agent.provider().model_id(),
            "claude-sonnet-4-5"
        );

        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = agent_session
            .session
            .lock(cx.cx())
            .await
            .expect("session lock");
        assert_eq!(
            (
                guard.header.provider.as_deref(),
                guard.header.model_id.as_deref(),
            ),
            (Some("anthropic"), Some("claude-sonnet-4-5"))
        );
    });
}
8409
#[test]
fn set_provider_model_rejects_missing_credentials_without_switching() {
    /// Returns clones of the provider and model id persisted in the session header.
    async fn persisted_selection(agent_session: &AgentSession) -> (Option<String>, Option<String>) {
        let cx = crate::agent_cx::AgentCx::for_request();
        let session = agent_session
            .session
            .lock(cx.cx())
            .await
            .expect("session lock");
        (
            session.header.provider.clone(),
            session.header.model_id.clone(),
        )
    }

    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build runtime");

    runtime.block_on(async {
        let tmp = tempfile::tempdir().expect("tempdir");
        // Nothing is added to the freshly loaded auth storage.
        let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");
        let mut agent_session = build_switch_test_session(&auth);

        {
            let cx = crate::agent_cx::AgentCx::for_request();
            let mut guard = agent_session
                .session
                .lock(cx.cx())
                .await
                .expect("session lock");
            guard.header.provider = Some("anthropic".to_string());
            guard.header.model_id = Some("claude-sonnet-4-5".to_string());
        }

        let outcome = agent_session.set_provider_model("openai", "gpt-4o").await;
        let err = outcome.expect_err("missing credentials should abort model switch");
        assert!(
            err.to_string()
                .contains("Missing credentials for openai/gpt-4o"),
            "unexpected error: {err}"
        );

        assert_eq!(agent_session.agent.provider().name(), "anthropic");
        assert_eq!(
            agent_session.agent.provider().model_id(),
            "claude-sonnet-4-5"
        );
        assert_eq!(
            persisted_selection(&agent_session).await,
            (
                Some("anthropic".to_string()),
                Some("claude-sonnet-4-5".to_string())
            )
        );
    });
}
8461
#[test]
fn set_provider_model_clamps_thinking_for_non_reasoning_targets() {
    /// An `acme/plain-model` entry with `reasoning: false`, no API key and
    /// `auth_header` disabled.
    fn plain_model_entry() -> ModelEntry {
        ModelEntry {
            model: Model {
                id: "plain-model".to_string(),
                name: "Plain Model".to_string(),
                api: "openai-completions".to_string(),
                provider: "acme".to_string(),
                base_url: "https://example.invalid/v1".to_string(),
                reasoning: false,
                input: vec![InputType::Text],
                cost: ModelCost {
                    input: 0.0,
                    output: 0.0,
                    cache_read: 0.0,
                    cache_write: 0.0,
                },
                context_window: 128_000,
                max_tokens: 8_192,
                headers: HashMap::new(),
            },
            api_key: None,
            headers: HashMap::new(),
            auth_header: false,
            compat: None,
            oauth_config: None,
        }
    }

    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build runtime");

    runtime.block_on(async {
        let tmp = tempfile::tempdir().expect("tempdir");
        let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");

        let mut registry = ModelRegistry::load(&auth, None);
        registry.merge_entries(vec![plain_model_entry()]);

        let mut agent_session = build_switch_test_session(&auth);
        agent_session.set_model_registry(registry);

        // Start from a high thinking level, both live and persisted.
        agent_session.agent.stream_options_mut().thinking_level =
            Some(crate::model::ThinkingLevel::High);
        {
            let cx = crate::agent_cx::AgentCx::for_request();
            let mut guard = agent_session
                .session
                .lock(cx.cx())
                .await
                .expect("session lock");
            guard.header.thinking_level = Some("high".to_string());
        }

        agent_session
            .set_provider_model("acme", "plain-model")
            .await
            .expect("switch should clamp unsupported thinking");

        assert_eq!(agent_session.agent.provider().name(), "acme");
        assert_eq!(agent_session.agent.provider().model_id(), "plain-model");
        assert_eq!(
            agent_session.agent.stream_options().thinking_level,
            Some(crate::model::ThinkingLevel::Off)
        );

        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = agent_session
            .session
            .lock(cx.cx())
            .await
            .expect("session lock");
        assert_eq!(
            (
                guard.header.provider.as_deref(),
                guard.header.model_id.as_deref(),
                guard.header.thinking_level.as_deref(),
            ),
            (Some("acme"), Some("plain-model"), Some("off"))
        );
    });
}
8538
#[test]
fn set_provider_model_records_model_change_once() {
    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build runtime");

    runtime.block_on(async {
        use crate::session::SessionEntry;

        let tmp = tempfile::tempdir().expect("tempdir");
        let mut auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");
        for (provider, key) in [("anthropic", "anthropic-key"), ("openai", "openai-key")] {
            auth.set(
                provider,
                AuthCredential::ApiKey {
                    key: key.to_string(),
                },
            );
        }

        let mut agent_session = build_switch_test_session(&auth);
        // The second call targets the model that is already active.
        for context in ["switch model", "repeat same model"] {
            agent_session
                .set_provider_model("openai", "gpt-4o")
                .await
                .expect(context);
        }

        let cx = crate::agent_cx::AgentCx::for_request();
        let session = agent_session
            .session
            .lock(cx.cx())
            .await
            .expect("session lock");
        let model_changes = session
            .entries_for_current_path()
            .iter()
            .filter(|entry| matches!(entry, SessionEntry::ModelChange(_)))
            .count();
        assert_eq!(model_changes, 1);
    });
}
8586
#[test]
fn sync_runtime_selection_from_session_header_clamps_and_normalizes_thinking() {
    /// An `acme/plain-model` entry with `reasoning: false`, no API key and
    /// `auth_header` disabled.
    fn non_reasoning_entry() -> ModelEntry {
        ModelEntry {
            model: Model {
                id: "plain-model".to_string(),
                name: "Plain Model".to_string(),
                api: "openai-completions".to_string(),
                provider: "acme".to_string(),
                base_url: "https://example.invalid/v1".to_string(),
                reasoning: false,
                input: vec![InputType::Text],
                cost: ModelCost {
                    input: 0.0,
                    output: 0.0,
                    cache_read: 0.0,
                    cache_write: 0.0,
                },
                context_window: 128_000,
                max_tokens: 8_192,
                headers: HashMap::new(),
            },
            api_key: None,
            headers: HashMap::new(),
            auth_header: false,
            compat: None,
            oauth_config: None,
        }
    }

    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build runtime");

    runtime.block_on(async {
        use crate::session::SessionEntry;

        let tmp = tempfile::tempdir().expect("tempdir");
        let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");

        let mut registry = ModelRegistry::load(&auth, None);
        registry.merge_entries(vec![non_reasoning_entry()]);

        let mut agent_session = build_switch_test_session(&auth);
        agent_session.set_model_registry(registry);
        agent_session.agent.stream_options_mut().thinking_level =
            Some(crate::model::ThinkingLevel::High);

        // The header asks for a non-reasoning model with high thinking.
        {
            let cx = crate::agent_cx::AgentCx::for_request();
            let mut guard = agent_session
                .session
                .lock(cx.cx())
                .await
                .expect("session lock");
            guard.header.provider = Some("acme".to_string());
            guard.header.model_id = Some("plain-model".to_string());
            guard.header.thinking_level = Some("high".to_string());
        }

        agent_session
            .sync_runtime_selection_from_session_header()
            .await
            .expect("sync runtime selection");

        assert_eq!(agent_session.agent.provider().name(), "acme");
        assert_eq!(agent_session.agent.provider().model_id(), "plain-model");
        assert_eq!(
            agent_session.agent.stream_options().thinking_level,
            Some(crate::model::ThinkingLevel::Off)
        );

        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = agent_session
            .session
            .lock(cx.cx())
            .await
            .expect("session lock");
        assert_eq!(guard.header.thinking_level.as_deref(), Some("off"));
        let recorded_changes = guard
            .entries_for_current_path()
            .iter()
            .filter(|entry| matches!(entry, SessionEntry::ThinkingLevelChange(_)))
            .count();
        assert_eq!(recorded_changes, 1);
    });
}
8671
#[test]
fn sync_runtime_selection_from_session_header_clamps_current_thinking_when_header_omits_it() {
    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build runtime");

    runtime.block_on(async {
        let tmp = tempfile::tempdir().expect("tempdir");
        let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");

        // Non-reasoning `acme/plain-model` with no API key and `auth_header` disabled.
        let plain_model = Model {
            id: "plain-model".to_string(),
            name: "Plain Model".to_string(),
            api: "openai-completions".to_string(),
            provider: "acme".to_string(),
            base_url: "https://example.invalid/v1".to_string(),
            reasoning: false,
            input: vec![InputType::Text],
            cost: ModelCost {
                input: 0.0,
                output: 0.0,
                cache_read: 0.0,
                cache_write: 0.0,
            },
            context_window: 128_000,
            max_tokens: 8_192,
            headers: HashMap::new(),
        };
        let mut registry = ModelRegistry::load(&auth, None);
        registry.merge_entries(vec![ModelEntry {
            model: plain_model,
            api_key: None,
            headers: HashMap::new(),
            auth_header: false,
            compat: None,
            oauth_config: None,
        }]);

        let mut agent_session = build_switch_test_session(&auth);
        agent_session.set_model_registry(registry);
        agent_session.agent.stream_options_mut().thinking_level =
            Some(crate::model::ThinkingLevel::High);

        // The header names the model but carries no thinking level at all, so the
        // live (high) level is the one that has to be clamped.
        {
            let cx = crate::agent_cx::AgentCx::for_request();
            let mut guard = agent_session
                .session
                .lock(cx.cx())
                .await
                .expect("session lock");
            guard.header.provider = Some("acme".to_string());
            guard.header.model_id = Some("plain-model".to_string());
            guard.header.thinking_level = None;
        }

        agent_session
            .sync_runtime_selection_from_session_header()
            .await
            .expect("sync runtime selection");

        let provider_name = agent_session.agent.provider().name().to_string();
        let provider_model = agent_session.agent.provider().model_id().to_string();
        assert_eq!(provider_name, "acme");
        assert_eq!(provider_model, "plain-model");
        assert_eq!(
            agent_session.agent.stream_options().thinking_level,
            Some(crate::model::ThinkingLevel::Off)
        );

        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = agent_session
            .session
            .lock(cx.cx())
            .await
            .expect("session lock");
        assert_eq!(guard.header.thinking_level.as_deref(), Some("off"));
        let thinking_changes = guard
            .entries_for_current_path()
            .iter()
            .filter(|entry| {
                matches!(entry, crate::session::SessionEntry::ThinkingLevelChange(_))
            })
            .count();
        assert_eq!(thinking_changes, 1);
    });
}
8756
#[test]
fn sync_runtime_selection_from_session_header_rejects_missing_credentials() {
    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build runtime");

    runtime.block_on(async {
        let tmp = tempfile::tempdir().expect("tempdir");
        // Nothing is added to the freshly loaded auth storage.
        let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");
        let mut agent_session = build_switch_test_session(&auth);

        // Point the persisted header at a credentialed target.
        {
            let cx = crate::agent_cx::AgentCx::for_request();
            let mut guard = agent_session
                .session
                .lock(cx.cx())
                .await
                .expect("session lock");
            guard.header.provider = Some("openai".to_string());
            guard.header.model_id = Some("gpt-4o".to_string());
        }

        let outcome = agent_session
            .sync_runtime_selection_from_session_header()
            .await;
        let err = outcome
            .expect_err("sync should reject switching to a credentialed target without a key");
        let rendered = err.to_string();
        assert!(
            rendered.contains("Missing credentials for openai/gpt-4o"),
            "unexpected error: {err}"
        );

        // The live provider stays put, while the header keeps the requested target.
        assert_eq!(agent_session.agent.provider().name(), "anthropic");
        assert_eq!(
            agent_session.agent.provider().model_id(),
            "claude-sonnet-4-5"
        );

        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = agent_session
            .session
            .lock(cx.cx())
            .await
            .expect("session lock");
        assert_eq!(
            (
                guard.header.provider.as_deref(),
                guard.header.model_id.as_deref(),
            ),
            (Some("openai"), Some("gpt-4o"))
        );
    });
}
8805
#[test]
fn set_provider_model_allows_current_model_without_registry() {
    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build runtime");

    runtime.block_on(async {
        let tmp = tempfile::tempdir().expect("tempdir");
        let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");
        let mut agent_session = build_switch_test_session(&auth);

        // With no registry, only the already-active model can be (re)selected.
        agent_session.model_registry = None;
        agent_session.agent.stream_options_mut().thinking_level =
            Some(crate::model::ThinkingLevel::High);

        agent_session
            .set_provider_model("anthropic", "claude-sonnet-4-5")
            .await
            .expect("re-persisting the current model should succeed without a registry");

        assert_eq!(agent_session.agent.provider().name(), "anthropic");
        assert_eq!(
            agent_session.agent.provider().model_id(),
            "claude-sonnet-4-5"
        );
        // The live thinking level is left as it was.
        assert_eq!(
            agent_session.agent.stream_options().thinking_level,
            Some(crate::model::ThinkingLevel::High)
        );

        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = agent_session
            .session
            .lock(cx.cx())
            .await
            .expect("session lock");
        assert_eq!(
            (
                guard.header.provider.as_deref(),
                guard.header.model_id.as_deref(),
                guard.header.thinking_level.as_deref(),
            ),
            (Some("anthropic"), Some("claude-sonnet-4-5"), Some("high"))
        );
    });
}
8850
#[test]
fn auto_compaction_start_serializes_to_pi_mono_format() {
    let event = AgentEvent::AutoCompactionStart {
        reason: "threshold".to_string(),
    };
    let serialized = serde_json::to_value(&event).unwrap();
    let object = serialized
        .as_object()
        .expect("compaction events serialize to JSON objects");

    for (key, expected) in [("type", "auto_compaction_start"), ("reason", "threshold")] {
        assert_eq!(
            object.get(key).and_then(serde_json::Value::as_str),
            Some(expected),
            "unexpected value for `{key}`"
        );
    }
}
8860
#[test]
fn auto_compaction_end_serializes_to_pi_mono_format() {
    let compaction_result = serde_json::json!({
        "summary": "Compacted",
        "firstKeptEntryId": "abc123",
        "tokensBefore": 50000,
        "details": { "readFiles": [], "modifiedFiles": [] }
    });
    let serialized = serde_json::to_value(&AgentEvent::AutoCompactionEnd {
        result: Some(compaction_result),
        aborted: false,
        will_retry: true,
        error_message: None,
    })
    .unwrap();

    assert_eq!(serialized["type"].as_str(), Some("auto_compaction_end"));
    assert!(serialized["result"].is_object());
    assert_eq!(
        (
            serialized["aborted"].as_bool(),
            serialized["willRetry"].as_bool(),
        ),
        (Some(false), Some(true))
    );
    assert_eq!(serialized.get("errorMessage"), None);
}
8881
#[test]
fn auto_compaction_end_with_error_serializes_error_message() {
    let event = AgentEvent::AutoCompactionEnd {
        result: None,
        aborted: false,
        will_retry: false,
        error_message: Some("compaction failed".to_string()),
    };
    let json = serde_json::to_value(&event).expect("AutoCompactionEnd should serialize");
    assert_eq!(json["type"], "auto_compaction_end");
    // `is_null` holds whether `result` is emitted as null or omitted entirely.
    assert!(json["result"].is_null());
    // The flags must still be reported on the error path.
    assert_eq!(json["aborted"], false);
    assert_eq!(json["willRetry"], false);
    assert_eq!(json["errorMessage"], "compaction failed");
}
8895
#[test]
fn auto_retry_start_serializes_to_pi_mono_format() {
    let retry_start = AgentEvent::AutoRetryStart {
        attempt: 2,
        max_attempts: 3,
        delay_ms: 4000,
        error_message: String::from("rate limited"),
    };
    let wire = serde_json::to_value(&retry_start).unwrap();

    assert_eq!(wire["type"], "auto_retry_start");
    // Numeric payload fields are emitted under camelCase keys.
    for (field, expected) in [("attempt", 2), ("maxAttempts", 3), ("delayMs", 4000)] {
        assert_eq!(wire[field], expected);
    }
    assert_eq!(wire["errorMessage"], "rate limited");
}
8911
#[test]
fn auto_retry_end_success_serializes_to_pi_mono_format() {
    let wire = serde_json::to_value(&AgentEvent::AutoRetryEnd {
        success: true,
        attempt: 2,
        final_error: None,
    })
    .unwrap();

    assert_eq!(wire["type"], "auto_retry_end");
    assert_eq!(wire["success"], true);
    assert_eq!(wire["attempt"], 2);
    // A successful retry has no final error; the key must not be present.
    assert_eq!(wire.get("finalError"), None);
}
8925
#[test]
fn auto_retry_end_failure_serializes_final_error() {
    let failed_retry = AgentEvent::AutoRetryEnd {
        success: false,
        attempt: 3,
        final_error: Some(String::from("max retries exceeded")),
    };
    let wire = serde_json::to_value(&failed_retry).unwrap();

    // Exhausted retries report the last error under the camelCase `finalError` key.
    assert_eq!(wire["finalError"], "max retries exceeded");
    assert_eq!(wire["attempt"], 3);
    assert_eq!(wire["success"], false);
    assert_eq!(wire["type"], "auto_retry_end");
}
8939}