1use crate::auth::AuthStorage;
16use crate::compaction::{self, ResolvedCompactionSettings};
17use crate::compaction_worker::{CompactionQuota, CompactionWorkerState};
18use crate::error::{Error, Result};
19use crate::extension_events::{InputEventOutcome, apply_input_event_response};
20use crate::extension_tools::collect_extension_tool_wrappers;
21use crate::extensions::{
22 EXTENSION_EVENT_TIMEOUT_MS, ExtensionDeliverAs, ExtensionEventName, ExtensionHostActions,
23 ExtensionLoadSpec, ExtensionManager, ExtensionPolicy, ExtensionRegion, ExtensionRuntimeHandle,
24 ExtensionSendMessage, ExtensionSendUserMessage, JsExtensionLoadSpec, JsExtensionRuntimeHandle,
25 NativeRustExtensionLoadSpec, NativeRustExtensionRuntimeHandle, RepairPolicyMode,
26 resolve_extension_load_spec,
27};
28#[cfg(feature = "wasm-host")]
29use crate::extensions::{WasmExtensionHost, WasmExtensionLoadSpec};
30use crate::extensions_js::{PiJsRuntimeConfig, RepairMode};
31use crate::model::{
32 AssistantMessage, AssistantMessageEvent, ContentBlock, CustomMessage, ImageContent, Message,
33 StopReason, StreamEvent, TextContent, ThinkingContent, ToolCall, ToolResultMessage, Usage,
34 UserContent, UserMessage,
35};
36use crate::models::{ModelEntry, ModelRegistry};
37use crate::provider::{Context, Provider, StreamOptions, ToolDef};
38use crate::session::{AutosaveFlushTrigger, Session, SessionHandle};
39use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
40use asupersync::sync::{Mutex, Notify};
41use async_trait::async_trait;
42use chrono::Utc;
43use futures::FutureExt;
44use futures::StreamExt;
45use futures::future::BoxFuture;
46use futures::stream;
47use serde::Serialize;
48use serde_json::{Value, json};
49use std::borrow::Cow;
50use std::collections::VecDeque;
51use std::sync::Arc;
52use std::sync::Mutex as StdMutex;
53use std::sync::atomic::{AtomicBool, Ordering};
54
/// Upper bound of 8 on concurrently running tools.
///
/// NOTE(review): the constant is not referenced in the visible part of this file;
/// presumably it caps parallel tool execution — confirm against the tool executor.
const MAX_CONCURRENT_TOOLS: usize = 8;
56
/// Static configuration for an [`Agent`].
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// Optional system prompt forwarded to the provider with every request context.
    pub system_prompt: Option<String>,

    /// Maximum number of tool-calling iterations allowed in a single run; once
    /// exceeded, the run ends with an error-flagged assistant message.
    pub max_tool_iterations: usize,

    /// Options passed to the provider when streaming (also carries the session id
    /// that is attached to lifecycle events).
    pub stream_options: StreamOptions,

    /// When `true`, image content is filtered out of the message copy sent to the
    /// provider.
    pub block_images: bool,
}
76
77impl Default for AgentConfig {
78 fn default() -> Self {
79 Self {
80 system_prompt: None,
81 max_tool_iterations: 50,
82 stream_options: StreamOptions::default(),
83 block_images: false,
84 }
85 }
86}
87
/// Shared async callback that yields additional messages on demand.
///
/// Fetchers are registered on the agent and polled when it drains its steering
/// or follow-up queues.
pub type MessageFetcher = Arc<dyn Fn() -> BoxFuture<'static, Vec<Message>> + Send + Sync + 'static>;
90
/// Shared, thread-safe callback that receives every [`AgentEvent`] emitted during a run.
type AgentEventHandler = Arc<dyn Fn(AgentEvent) + Send + Sync + 'static>;
92
/// How queued messages are released when a queue is drained.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueueMode {
    /// Release every pending message in one batch.
    All,
    /// Release only the oldest pending message per drain.
    OneAtATime,
}

impl QueueMode {
    /// Returns the stable string name of this mode (`"all"` or `"one-at-a-time"`).
    pub const fn as_str(self) -> &'static str {
        match self {
            QueueMode::OneAtATime => "one-at-a-time",
            QueueMode::All => "all",
        }
    }
}
107
/// Selects which of the two internal queues an operation targets.
#[derive(Debug, Clone, Copy)]
enum QueueKind {
    /// The steering queue (drained before and between turns).
    Steering,
    /// The follow-up queue (drained once the inner turn loop has finished).
    FollowUp,
}
113
/// A message waiting in a [`MessageQueue`], with bookkeeping metadata.
#[derive(Debug, Clone)]
struct QueuedMessage {
    /// Sequence number assigned from the queue's counter at enqueue time.
    seq: u64,
    /// Enqueue time as Unix milliseconds (UTC).
    enqueued_at: i64,
    /// The queued message itself.
    message: Message,
}
120
/// Holds pending steering and follow-up messages together with their release modes.
#[derive(Debug)]
struct MessageQueue {
    /// Pending steering messages, oldest first.
    steering: VecDeque<QueuedMessage>,
    /// Pending follow-up messages, oldest first.
    follow_up: VecDeque<QueuedMessage>,
    /// Release mode applied when popping steering messages.
    steering_mode: QueueMode,
    /// Release mode applied when popping follow-up messages.
    follow_up_mode: QueueMode,
    /// Next sequence number to hand out; shared by both queues.
    next_seq: u64,
}
129
130impl MessageQueue {
131 const fn new(steering_mode: QueueMode, follow_up_mode: QueueMode) -> Self {
132 Self {
133 steering: VecDeque::new(),
134 follow_up: VecDeque::new(),
135 steering_mode,
136 follow_up_mode,
137 next_seq: 0,
138 }
139 }
140
141 const fn set_modes(&mut self, steering_mode: QueueMode, follow_up_mode: QueueMode) {
142 self.steering_mode = steering_mode;
143 self.follow_up_mode = follow_up_mode;
144 }
145
146 fn pending_count(&self) -> usize {
147 self.steering.len() + self.follow_up.len()
148 }
149
150 fn push(&mut self, kind: QueueKind, message: Message) -> u64 {
151 let seq = self.next_seq;
152 self.next_seq = self.next_seq.saturating_add(1);
153 let entry = QueuedMessage {
154 seq,
155 enqueued_at: Utc::now().timestamp_millis(),
156 message,
157 };
158 match kind {
159 QueueKind::Steering => self.steering.push_back(entry),
160 QueueKind::FollowUp => self.follow_up.push_back(entry),
161 }
162 seq
163 }
164
165 fn push_steering(&mut self, message: Message) -> u64 {
166 self.push(QueueKind::Steering, message)
167 }
168
169 fn push_follow_up(&mut self, message: Message) -> u64 {
170 self.push(QueueKind::FollowUp, message)
171 }
172
173 fn pop_steering(&mut self) -> Vec<Message> {
174 self.pop_kind(QueueKind::Steering)
175 }
176
177 fn pop_follow_up(&mut self) -> Vec<Message> {
178 self.pop_kind(QueueKind::FollowUp)
179 }
180
181 fn pop_kind(&mut self, kind: QueueKind) -> Vec<Message> {
182 let (queue, mode) = match kind {
183 QueueKind::Steering => (&mut self.steering, self.steering_mode),
184 QueueKind::FollowUp => (&mut self.follow_up, self.follow_up_mode),
185 };
186
187 match mode {
188 QueueMode::All => queue.drain(..).map(|entry| entry.message).collect(),
189 QueueMode::OneAtATime => queue
190 .pop_front()
191 .into_iter()
192 .map(|entry| entry.message)
193 .collect(),
194 }
195 }
196}
197
/// Events reported by the agent to its `on_event` callback.
///
/// Serialized as an internally tagged object: the `type` key holds the variant
/// name in `snake_case`, and fields with an explicit `rename` use camelCase keys.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AgentEvent {
    /// Emitted once at the beginning of a run.
    AgentStart {
        /// Session id taken from the stream options (empty string when unset).
        #[serde(rename = "sessionId")]
        session_id: Arc<str>,
    },
    /// Emitted once when a run finishes, successfully or not.
    AgentEnd {
        #[serde(rename = "sessionId")]
        session_id: Arc<str>,
        /// Messages produced during this run.
        messages: Vec<Message>,
        /// Error description when the run ended abnormally; omitted from JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        error: Option<String>,
    },
    /// Emitted at the start of each turn (one assistant response plus its tool calls).
    TurnStart {
        #[serde(rename = "sessionId")]
        session_id: Arc<str>,
        /// Zero-based index of the turn within the run.
        #[serde(rename = "turnIndex")]
        turn_index: usize,
        /// Turn start time as Unix milliseconds (UTC).
        timestamp: i64,
    },
    /// Emitted at the end of each turn.
    TurnEnd {
        #[serde(rename = "sessionId")]
        session_id: Arc<str>,
        #[serde(rename = "turnIndex")]
        turn_index: usize,
        /// The assistant message that concluded the turn.
        message: Message,
        /// Tool result messages produced during the turn (empty when no tools ran).
        #[serde(rename = "toolResults")]
        tool_results: Vec<Message>,
    },
    /// A message has started (for streamed assistant messages this carries the partial).
    MessageStart { message: Message },
    /// A streamed assistant message changed.
    MessageUpdate {
        /// Current snapshot of the message.
        message: Message,
        /// The streaming event that caused this update.
        #[serde(rename = "assistantMessageEvent")]
        assistant_message_event: AssistantMessageEvent,
    },
    /// A message is complete.
    MessageEnd { message: Message },
    /// A tool invocation is starting.
    /// NOTE(review): emission site is outside the visible part of this file.
    ToolExecutionStart {
        #[serde(rename = "toolCallId")]
        tool_call_id: String,
        #[serde(rename = "toolName")]
        tool_name: String,
        /// Arguments of the tool call as JSON.
        args: serde_json::Value,
    },
    /// Intermediate output from a running tool.
    /// NOTE(review): emission site is outside the visible part of this file.
    ToolExecutionUpdate {
        #[serde(rename = "toolCallId")]
        tool_call_id: String,
        #[serde(rename = "toolName")]
        tool_name: String,
        args: serde_json::Value,
        #[serde(rename = "partialResult")]
        partial_result: ToolOutput,
    },
    /// A tool invocation finished.
    /// NOTE(review): emission site is outside the visible part of this file.
    ToolExecutionEnd {
        #[serde(rename = "toolCallId")]
        tool_call_id: String,
        #[serde(rename = "toolName")]
        tool_name: String,
        result: ToolOutput,
        /// Whether the tool result represents an error.
        #[serde(rename = "isError")]
        is_error: bool,
    },
    /// Automatic compaction is starting; `reason` says why.
    /// NOTE(review): producer not visible here — confirm semantics with the compaction code.
    AutoCompactionStart { reason: String },
    /// Automatic compaction finished.
    AutoCompactionEnd {
        /// Compaction result payload, omitted from JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<serde_json::Value>,
        aborted: bool,
        #[serde(rename = "willRetry")]
        will_retry: bool,
        #[serde(rename = "errorMessage", skip_serializing_if = "Option::is_none")]
        error_message: Option<String>,
    },
    /// An automatic retry is about to be attempted.
    /// NOTE(review): producer not visible here.
    AutoRetryStart {
        attempt: u32,
        #[serde(rename = "maxAttempts")]
        max_attempts: u32,
        /// Delay before the retry; the field name indicates milliseconds.
        #[serde(rename = "delayMs")]
        delay_ms: u64,
        #[serde(rename = "errorMessage")]
        error_message: String,
    },
    /// An automatic retry sequence ended.
    AutoRetryEnd {
        success: bool,
        attempt: u32,
        #[serde(rename = "finalError", skip_serializing_if = "Option::is_none")]
        final_error: Option<String>,
    },
    /// An extension reported an error while handling `event`.
    /// NOTE(review): producer not visible here.
    ExtensionError {
        /// Id of the failing extension, omitted from JSON when `None`.
        #[serde(rename = "extensionId", skip_serializing_if = "Option::is_none")]
        extension_id: Option<String>,
        event: String,
        error: String,
    },
}
312
/// Producer side of an abort pair: calling `abort` flips the shared flag and wakes waiters.
///
/// Cloning yields another handle to the same shared state.
#[derive(Debug, Clone)]
pub struct AbortHandle {
    /// State shared with the paired [`AbortSignal`].
    inner: Arc<AbortSignalInner>,
}
322
/// Consumer side of an abort pair: can be polled or awaited for an abort request.
///
/// Cloning yields another signal observing the same shared state.
#[derive(Debug, Clone)]
pub struct AbortSignal {
    /// State shared with the paired [`AbortHandle`].
    inner: Arc<AbortSignalInner>,
}
328
/// State shared between an [`AbortHandle`] and its [`AbortSignal`]s.
#[derive(Debug)]
struct AbortSignalInner {
    /// Set to `true` once abort has been requested; never reset in the visible code.
    aborted: AtomicBool,
    /// Wakes tasks parked in `AbortSignal::wait`.
    notify: Notify,
}
334
335impl AbortHandle {
336 #[must_use]
338 pub fn new() -> (Self, AbortSignal) {
339 let inner = Arc::new(AbortSignalInner {
340 aborted: AtomicBool::new(false),
341 notify: Notify::new(),
342 });
343 (
344 Self {
345 inner: Arc::clone(&inner),
346 },
347 AbortSignal { inner },
348 )
349 }
350
351 pub fn abort(&self) {
353 if !self.inner.aborted.swap(true, Ordering::SeqCst) {
354 self.inner.notify.notify_waiters();
355 }
356 }
357}
358
359impl AbortSignal {
360 #[must_use]
362 pub fn is_aborted(&self) -> bool {
363 self.inner.aborted.load(Ordering::SeqCst)
364 }
365
366 pub async fn wait(&self) {
367 if self.is_aborted() {
368 return;
369 }
370
371 loop {
372 self.inner.notify.notified().await;
373 if self.is_aborted() {
374 return;
375 }
376 }
377 }
378}
379
/// Conversation driver: streams assistant responses from a provider, executes
/// requested tools, and reports progress through [`AgentEvent`]s.
pub struct Agent {
    /// Provider used to stream assistant responses.
    provider: Arc<dyn Provider>,

    /// Tools the assistant may call.
    tools: ToolRegistry,

    /// Static configuration (system prompt, iteration cap, stream options, image blocking).
    config: AgentConfig,

    /// Extension manager that receives lifecycle events; `None` when unset.
    extensions: Option<ExtensionManager>,

    /// Full conversation history held by the agent.
    messages: Vec<Message>,

    /// Callbacks polled for extra steering messages whenever the steering queue is drained.
    steering_fetchers: Vec<MessageFetcher>,

    /// Callbacks polled for extra follow-up messages whenever the follow-up queue is drained.
    follow_up_fetchers: Vec<MessageFetcher>,

    /// Pending steering and follow-up messages.
    message_queue: MessageQueue,

    /// Lazily built tool definitions for the provider context; reset to `None`
    /// when tools are added.
    cached_tool_defs: Option<Vec<ToolDef>>,
}
409
410impl Agent {
411 pub fn new(provider: Arc<dyn Provider>, tools: ToolRegistry, config: AgentConfig) -> Self {
413 Self {
414 provider,
415 tools,
416 config,
417 extensions: None,
418 messages: Vec::new(),
419 steering_fetchers: Vec::new(),
420 follow_up_fetchers: Vec::new(),
421 message_queue: MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime),
422 cached_tool_defs: None,
423 }
424 }
425
426 #[must_use]
428 pub fn messages(&self) -> &[Message] {
429 &self.messages
430 }
431
    /// Removes every message from the conversation history.
    pub fn clear_messages(&mut self) {
        self.messages.clear();
    }
436
    /// Appends `message` to the conversation history without emitting any event.
    pub fn add_message(&mut self, message: Message) {
        self.messages.push(message);
    }
441
    /// Replaces the entire conversation history with `messages`.
    pub fn replace_messages(&mut self, messages: Vec<Message>) {
        self.messages = messages;
    }
446
    /// Swaps the provider used for subsequent requests.
    pub fn set_provider(&mut self, provider: Arc<dyn Provider>) {
        self.provider = provider;
    }
451
452 pub fn register_message_fetchers(
457 &mut self,
458 steering: Option<MessageFetcher>,
459 follow_up: Option<MessageFetcher>,
460 ) {
461 if let Some(fetcher) = steering {
462 self.steering_fetchers.push(fetcher);
463 }
464 if let Some(fetcher) = follow_up {
465 self.follow_up_fetchers.push(fetcher);
466 }
467 }
468
    /// Adds `tools` to the agent's tool registry.
    pub fn extend_tools<I>(&mut self, tools: I)
    where
        I: IntoIterator<Item = Box<dyn Tool>>,
    {
        self.tools.extend(tools);
        // Drop the cached definitions so `build_context` rebuilds them with the new tools.
        self.cached_tool_defs = None;
    }
477
478 pub fn queue_steering(&mut self, message: Message) -> u64 {
480 self.message_queue.push_steering(message)
481 }
482
483 pub fn queue_follow_up(&mut self, message: Message) -> u64 {
485 self.message_queue.push_follow_up(message)
486 }
487
488 pub const fn set_queue_modes(&mut self, steering: QueueMode, follow_up: QueueMode) {
490 self.message_queue.set_modes(steering, follow_up);
491 }
492
    /// Number of messages currently waiting in the steering and follow-up queues combined.
    #[must_use]
    pub fn queued_message_count(&self) -> usize {
        self.message_queue.pending_count()
    }
498
    /// Returns a new shared reference to the current provider.
    pub fn provider(&self) -> Arc<dyn Provider> {
        Arc::clone(&self.provider)
    }
502
    /// Borrows the stream options from the agent configuration.
    pub const fn stream_options(&self) -> &StreamOptions {
        &self.config.stream_options
    }
506
    /// Mutably borrows the stream options from the agent configuration.
    pub const fn stream_options_mut(&mut self) -> &mut StreamOptions {
        &mut self.config.stream_options
    }
510
511 fn build_context(&mut self) -> Context<'_> {
513 let messages: Cow<'_, [Message]> = if self.config.block_images {
514 let mut msgs = self.messages.clone();
515 msgs.retain(|m| match m {
517 Message::Custom(c) => c.display,
518 _ => true,
519 });
520 let stats = filter_images_for_provider(&mut msgs);
521 if stats.removed_images > 0 {
522 tracing::debug!(
523 filtered_images = stats.removed_images,
524 affected_messages = stats.affected_messages,
525 "Filtered image content from outbound provider context (images.block_images=true)"
526 );
527 }
528 Cow::Owned(msgs)
529 } else {
530 let has_hidden = self.messages.iter().any(|m| match m {
532 Message::Custom(c) => !c.display,
533 _ => false,
534 });
535
536 if has_hidden {
537 let mut msgs = self.messages.clone();
538 msgs.retain(|m| match m {
539 Message::Custom(c) => c.display,
540 _ => true,
541 });
542 Cow::Owned(msgs)
543 } else {
544 Cow::Borrowed(self.messages.as_slice())
545 }
546 };
547
548 if self.cached_tool_defs.is_none() {
550 let defs: Vec<ToolDef> = self
551 .tools
552 .tools()
553 .iter()
554 .map(|t| ToolDef {
555 name: t.name().to_string(),
556 description: t.description().to_string(),
557 parameters: t.parameters(),
558 })
559 .collect();
560 self.cached_tool_defs = Some(defs);
561 }
562 let tools = Cow::Borrowed(self.cached_tool_defs.as_deref().unwrap());
563
564 Context {
565 system_prompt: self.config.system_prompt.as_deref().map(Cow::Borrowed),
566 messages,
567 tools,
568 }
569 }
570
    /// Runs the agent on a plain-text user input without an abort signal.
    ///
    /// Equivalent to `run_with_abort(user_input, None, on_event)`.
    ///
    /// # Errors
    /// Propagates any error returned by the underlying run loop.
    pub async fn run(
        &mut self,
        user_input: impl Into<String>,
        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
    ) -> Result<AssistantMessage> {
        self.run_with_abort(user_input, None, on_event).await
    }
581
582 pub async fn run_with_abort(
584 &mut self,
585 user_input: impl Into<String>,
586 abort: Option<AbortSignal>,
587 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
588 ) -> Result<AssistantMessage> {
589 let user_message = Message::User(UserMessage {
591 content: UserContent::Text(user_input.into()),
592 timestamp: Utc::now().timestamp_millis(),
593 });
594
595 self.run_loop(vec![user_message], Arc::new(on_event), abort)
597 .await
598 }
599
    /// Runs the agent on structured user content blocks without an abort signal.
    ///
    /// Equivalent to `run_with_content_with_abort(content, None, on_event)`.
    ///
    /// # Errors
    /// Propagates any error returned by the underlying run loop.
    pub async fn run_with_content(
        &mut self,
        content: Vec<ContentBlock>,
        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
    ) -> Result<AssistantMessage> {
        self.run_with_content_with_abort(content, None, on_event)
            .await
    }
609
610 pub async fn run_with_content_with_abort(
612 &mut self,
613 content: Vec<ContentBlock>,
614 abort: Option<AbortSignal>,
615 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
616 ) -> Result<AssistantMessage> {
617 let user_message = Message::User(UserMessage {
619 content: UserContent::Blocks(content),
620 timestamp: Utc::now().timestamp_millis(),
621 });
622
623 self.run_loop(vec![user_message], Arc::new(on_event), abort)
625 .await
626 }
627
628 pub async fn run_with_message_with_abort(
630 &mut self,
631 message: Message,
632 abort: Option<AbortSignal>,
633 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
634 ) -> Result<AssistantMessage> {
635 self.run_loop(vec![message], Arc::new(on_event), abort)
636 .await
637 }
638
639 pub async fn run_continue_with_abort(
641 &mut self,
642 abort: Option<AbortSignal>,
643 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
644 ) -> Result<AssistantMessage> {
645 self.run_loop(Vec::new(), Arc::new(on_event), abort).await
646 }
647
648 fn build_abort_message(&self, partial: Option<&AssistantMessage>) -> AssistantMessage {
649 let mut message = partial.cloned().unwrap_or_else(|| AssistantMessage {
650 content: Vec::new(),
651 api: self.provider.api().to_string(),
652 provider: self.provider.name().to_string(),
653 model: self.provider.model_id().to_string(),
654 usage: Usage::default(),
655 stop_reason: StopReason::Aborted,
656 error_message: Some("Aborted".to_string()),
657 timestamp: Utc::now().timestamp_millis(),
658 });
659 message.stop_reason = StopReason::Aborted;
660 message.error_message = Some("Aborted".to_string());
661 message.timestamp = Utc::now().timestamp_millis();
662 message
663 }
664
    /// Core agent loop shared by all public `run*` entry points.
    ///
    /// Appends `prompts` to the history, then repeatedly streams an assistant
    /// response and executes any tool calls it contains, until the assistant
    /// stops calling tools and no steering or follow-up messages remain.
    ///
    /// Lifecycle events (`AgentStart`, `TurnStart`, `TurnEnd`, `AgentEnd`) are
    /// first dispatched to extensions and then passed to `on_event`; message
    /// events go to `on_event` only.
    ///
    /// Returns the last assistant message. Abort, assistant-reported
    /// error/abort stop reasons, and exceeding `max_tool_iterations` are returned
    /// as `Ok` with the corresponding stop reason after `AgentEnd` is emitted.
    ///
    /// # Errors
    /// Returns the error from streaming or tool execution (after emitting
    /// `AgentEnd` with that error), or an API error if the loop finished without
    /// producing any assistant message.
    #[allow(clippy::too_many_lines)]
    async fn run_loop(
        &mut self,
        prompts: Vec<Message>,
        on_event: AgentEventHandler,
        abort: Option<AbortSignal>,
    ) -> Result<AssistantMessage> {
        // Session id attached to lifecycle events; empty string when not configured.
        let session_id: Arc<str> = self
            .config
            .stream_options
            .session_id
            .as_deref()
            .unwrap_or("")
            .into();
        let mut iterations = 0usize;
        let mut turn_index: usize = 0;
        // Every message produced during this run; reported in `AgentEnd`.
        let mut new_messages: Vec<Message> = Vec::with_capacity(prompts.len() + 8);
        let mut last_assistant: Option<Arc<AssistantMessage>> = None;

        let agent_start_event = AgentEvent::AgentStart {
            session_id: session_id.clone(),
        };
        self.dispatch_extension_lifecycle_event(&agent_start_event)
            .await;
        on_event(agent_start_event);

        // Record the prompts; they are already complete, so start and end are emitted back to back.
        for prompt in prompts {
            self.messages.push(prompt.clone());
            on_event(AgentEvent::MessageStart {
                message: prompt.clone(),
            });
            on_event(AgentEvent::MessageEnd {
                message: prompt.clone(),
            });
            new_messages.push(prompt);
        }

        // Steering messages that arrived before the first turn.
        let mut pending_messages = self.drain_steering_messages().await;

        // Outer loop: one pass per batch of follow-up messages.
        loop {
            // Start as `true` so the inner loop always runs at least one turn.
            let mut has_more_tool_calls = true;
            let mut steering_after_tools: Option<Vec<Message>> = None;

            // Inner loop: one turn per iteration, while the assistant keeps calling
            // tools or there are pending messages to deliver.
            while has_more_tool_calls || !pending_messages.is_empty() {
                let current_turn_index = turn_index;
                let turn_start_event = AgentEvent::TurnStart {
                    session_id: session_id.clone(),
                    turn_index: current_turn_index,
                    timestamp: Utc::now().timestamp_millis(),
                };
                self.dispatch_extension_lifecycle_event(&turn_start_event)
                    .await;
                on_event(turn_start_event);

                // Deliver pending steering/follow-up messages at the start of the turn.
                for message in std::mem::take(&mut pending_messages) {
                    self.messages.push(message.clone());
                    on_event(AgentEvent::MessageStart {
                        message: message.clone(),
                    });
                    on_event(AgentEvent::MessageEnd {
                        message: message.clone(),
                    });
                    new_messages.push(message);
                }

                // Abort requested before streaming started: record a synthetic aborted
                // assistant message, close the turn and the run, and return it.
                if abort.as_ref().is_some_and(AbortSignal::is_aborted) {
                    let abort_message = self.build_abort_message(None);
                    let message = Message::assistant(abort_message.clone());

                    self.messages.push(message.clone());
                    new_messages.push(message.clone());
                    on_event(AgentEvent::MessageStart {
                        message: message.clone(),
                    });
                    on_event(AgentEvent::MessageEnd {
                        message: message.clone(),
                    });

                    let turn_end_event = AgentEvent::TurnEnd {
                        session_id: session_id.clone(),
                        turn_index: current_turn_index,
                        message,
                        tool_results: Vec::new(),
                    };
                    self.dispatch_extension_lifecycle_event(&turn_end_event)
                        .await;
                    on_event(turn_end_event);
                    let agent_end_event = AgentEvent::AgentEnd {
                        session_id: session_id.clone(),
                        messages: std::mem::take(&mut new_messages),
                        error: Some(
                            abort_message
                                .error_message
                                .clone()
                                .unwrap_or_else(|| "Aborted".to_string()),
                        ),
                    };
                    self.dispatch_extension_lifecycle_event(&agent_end_event)
                        .await;
                    on_event(agent_end_event);
                    return Ok(abort_message);
                }

                // Stream the assistant response; a hard failure ends the run with `Err`
                // (no `TurnEnd` is emitted on this path).
                let assistant_message = match self
                    .stream_assistant_response(Arc::clone(&on_event), abort.clone())
                    .await
                {
                    Ok(msg) => msg,
                    Err(err) => {
                        let agent_end_event = AgentEvent::AgentEnd {
                            session_id: session_id.clone(),
                            messages: std::mem::take(&mut new_messages),
                            error: Some(err.to_string()),
                        };
                        self.dispatch_extension_lifecycle_event(&agent_end_event)
                            .await;
                        on_event(agent_end_event);
                        return Err(err);
                    }
                };
                // Share the message via `Arc` so events can reference it without deep copies.
                let assistant_arc = Arc::new(assistant_message);
                last_assistant = Some(Arc::clone(&assistant_arc));

                let assistant_event_message = Message::Assistant(Arc::clone(&assistant_arc));
                new_messages.push(assistant_event_message.clone());

                // The response itself ended in error or abort: close the turn and the run.
                if matches!(
                    assistant_arc.stop_reason,
                    StopReason::Error | StopReason::Aborted
                ) {
                    let turn_end_event = AgentEvent::TurnEnd {
                        session_id: session_id.clone(),
                        turn_index: current_turn_index,
                        message: assistant_event_message.clone(),
                        tool_results: Vec::new(),
                    };
                    self.dispatch_extension_lifecycle_event(&turn_end_event)
                        .await;
                    on_event(turn_end_event);
                    let agent_end_event = AgentEvent::AgentEnd {
                        session_id: session_id.clone(),
                        messages: std::mem::take(&mut new_messages),
                        error: assistant_arc.error_message.clone(),
                    };
                    self.dispatch_extension_lifecycle_event(&agent_end_event)
                        .await;
                    on_event(agent_end_event);
                    return Ok(Arc::unwrap_or_clone(assistant_arc));
                }

                let tool_calls = extract_tool_calls(&assistant_arc.content);
                has_more_tool_calls = !tool_calls.is_empty();

                let mut tool_results: Vec<Arc<ToolResultMessage>> = Vec::new();
                if has_more_tool_calls {
                    iterations += 1;
                    // Guard against runaway tool loops.
                    if iterations > self.config.max_tool_iterations {
                        let error_message = format!(
                            "Maximum tool iterations ({}) exceeded",
                            self.config.max_tool_iterations
                        );
                        let mut stop_message = (*assistant_arc).clone();
                        stop_message.stop_reason = StopReason::Error;
                        stop_message.error_message = Some(error_message.clone());
                        let stop_arc = Arc::new(stop_message.clone());
                        let stop_event_message = Message::Assistant(Arc::clone(&stop_arc));

                        // Replace the just-recorded assistant message (in both the history
                        // and this run's message list) with the error-flagged copy.
                        if let Some(last @ Message::Assistant(_)) = self.messages.last_mut() {
                            *last = stop_event_message.clone();
                        }
                        if let Some(last @ Message::Assistant(_)) = new_messages.last_mut() {
                            *last = stop_event_message.clone();
                        }

                        let turn_end_event = AgentEvent::TurnEnd {
                            session_id: session_id.clone(),
                            turn_index: current_turn_index,
                            message: stop_event_message,
                            tool_results: Vec::new(),
                        };
                        self.dispatch_extension_lifecycle_event(&turn_end_event)
                            .await;
                        on_event(turn_end_event);

                        let agent_end_event = AgentEvent::AgentEnd {
                            session_id: session_id.clone(),
                            messages: std::mem::take(&mut new_messages),
                            error: Some(error_message),
                        };
                        self.dispatch_extension_lifecycle_event(&agent_end_event)
                            .await;
                        on_event(agent_end_event);

                        return Ok(stop_message);
                    }

                    // Execute the requested tools; a hard failure ends the run with `Err`.
                    let outcome = match self
                        .execute_tool_calls(
                            &tool_calls,
                            Arc::clone(&on_event),
                            &mut new_messages,
                            abort.clone(),
                        )
                        .await
                    {
                        Ok(outcome) => outcome,
                        Err(err) => {
                            let agent_end_event = AgentEvent::AgentEnd {
                                session_id: session_id.clone(),
                                messages: std::mem::take(&mut new_messages),
                                error: Some(err.to_string()),
                            };
                            self.dispatch_extension_lifecycle_event(&agent_end_event)
                                .await;
                            on_event(agent_end_event);
                            return Err(err);
                        }
                    };
                    tool_results = outcome.tool_results;
                    steering_after_tools = outcome.steering_messages;
                }

                let tool_messages = tool_results
                    .iter()
                    .map(|r| Message::ToolResult(Arc::clone(r)))
                    .collect::<Vec<_>>();

                let turn_end_event = AgentEvent::TurnEnd {
                    session_id: session_id.clone(),
                    turn_index: current_turn_index,
                    message: assistant_event_message.clone(),
                    tool_results: tool_messages,
                };
                self.dispatch_extension_lifecycle_event(&turn_end_event)
                    .await;
                on_event(turn_end_event);

                turn_index = turn_index.saturating_add(1);

                // Prefer steering messages already returned by tool execution;
                // otherwise drain the steering queue for the next turn.
                if let Some(steering) = steering_after_tools.take() {
                    pending_messages = steering;
                } else {
                    pending_messages = self.drain_steering_messages().await;
                }
            }

            // Inner loop finished: continue only if follow-up messages are waiting.
            let follow_up = self.drain_follow_up_messages().await;
            if follow_up.is_empty() {
                break;
            }
            pending_messages = follow_up;
        }

        // The inner loop always runs at least once, so this is a defensive check.
        let Some(final_arc) = last_assistant else {
            return Err(Error::api("Agent completed without assistant message"));
        };

        let agent_end_event = AgentEvent::AgentEnd {
            session_id: session_id.clone(),
            messages: new_messages,
            error: None,
        };
        self.dispatch_extension_lifecycle_event(&agent_end_event)
            .await;
        on_event(agent_end_event);
        Ok(Arc::unwrap_or_clone(final_arc))
    }
940
941 async fn fetch_messages(&self, fetcher: Option<&MessageFetcher>) -> Vec<Message> {
942 if let Some(fetcher) = fetcher {
943 (fetcher)().await
944 } else {
945 Vec::new()
946 }
947 }
948
949 async fn dispatch_extension_lifecycle_event(&self, event: &AgentEvent) {
950 let Some(extensions) = &self.extensions else {
951 return;
952 };
953
954 let name = match event {
955 AgentEvent::AgentStart { .. } => ExtensionEventName::AgentStart,
956 AgentEvent::AgentEnd { .. } => ExtensionEventName::AgentEnd,
957 AgentEvent::TurnStart { .. } => ExtensionEventName::TurnStart,
958 AgentEvent::TurnEnd { .. } => ExtensionEventName::TurnEnd,
959 _ => return,
960 };
961
962 let payload = match serde_json::to_value(event) {
963 Ok(payload) => payload,
964 Err(err) => {
965 tracing::warn!("failed to serialize agent lifecycle event (fail-open): {err}");
966 return;
967 }
968 };
969
970 if let Err(err) = extensions.dispatch_event(name, Some(payload)).await {
971 tracing::warn!("agent lifecycle extension hook failed (fail-open): {err}");
972 }
973 }
974
975 async fn drain_steering_messages(&mut self) -> Vec<Message> {
976 for fetcher in &self.steering_fetchers {
977 let fetched = self.fetch_messages(Some(fetcher)).await;
978 for message in fetched {
979 self.message_queue.push_steering(message);
980 }
981 }
982 self.message_queue.pop_steering()
983 }
984
985 async fn drain_follow_up_messages(&mut self) -> Vec<Message> {
986 for fetcher in &self.follow_up_fetchers {
987 let fetched = self.fetch_messages(Some(fetcher)).await;
988 for message in fetched {
989 self.message_queue.push_follow_up(message);
990 }
991 }
992 self.message_queue.pop_follow_up()
993 }
994
995 #[allow(clippy::too_many_lines)]
997 async fn stream_assistant_response(
998 &mut self,
999 on_event: AgentEventHandler,
1000 abort: Option<AbortSignal>,
1001 ) -> Result<AssistantMessage> {
1002 let provider = Arc::clone(&self.provider);
1004 let stream_options = self.config.stream_options.clone();
1005 let context = self.build_context();
1006 let mut stream = provider.stream(&context, &stream_options).await?;
1007
1008 let mut added_partial = false;
1009 let mut sent_start = false;
1012
1013 loop {
1014 let event_result = if let Some(signal) = abort.as_ref() {
1015 let abort_fut = signal.wait().fuse();
1016 let event_fut = stream.next().fuse();
1017 futures::pin_mut!(abort_fut, event_fut);
1018
1019 match futures::future::select(abort_fut, event_fut).await {
1020 futures::future::Either::Left(((), _event_fut)) => {
1021 let last_partial = if added_partial {
1022 match self.messages.last() {
1023 Some(Message::Assistant(a)) => Some(a.as_ref()),
1024 _ => None,
1025 }
1026 } else {
1027 None
1028 };
1029 let abort_arc = Arc::new(self.build_abort_message(last_partial));
1030 if !sent_start {
1031 on_event(AgentEvent::MessageStart {
1032 message: Message::Assistant(Arc::clone(&abort_arc)),
1033 });
1034 self.messages
1035 .push(Message::Assistant(Arc::clone(&abort_arc)));
1036 added_partial = true;
1037 }
1041 on_event(AgentEvent::MessageUpdate {
1042 message: Message::Assistant(Arc::clone(&abort_arc)),
1043 assistant_message_event: AssistantMessageEvent::Error {
1044 reason: StopReason::Aborted,
1045 error: Arc::clone(&abort_arc),
1046 },
1047 });
1048 return Ok(self.finalize_assistant_message(
1049 Arc::try_unwrap(abort_arc).unwrap_or_else(|a| (*a).clone()),
1050 &on_event,
1051 added_partial,
1052 ));
1053 }
1054 futures::future::Either::Right((event, _abort_fut)) => event,
1055 }
1056 } else {
1057 stream.next().await
1058 };
1059
1060 let Some(event_result) = event_result else {
1061 break;
1062 };
1063 let event = event_result?;
1064
1065 match event {
1066 StreamEvent::Start { partial } => {
1067 let shared = Arc::new(partial);
1068 self.update_partial_message(Arc::clone(&shared), &mut added_partial);
1069 on_event(AgentEvent::MessageStart {
1070 message: Message::Assistant(Arc::clone(&shared)),
1071 });
1072 sent_start = true;
1073 on_event(AgentEvent::MessageUpdate {
1074 message: Message::Assistant(Arc::clone(&shared)),
1075 assistant_message_event: AssistantMessageEvent::Start { partial: shared },
1076 });
1077 }
1078 StreamEvent::TextStart { content_index, .. } => {
1079 if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1080 let msg = Arc::make_mut(msg_arc);
1081 if content_index == msg.content.len() {
1082 msg.content.push(ContentBlock::Text(TextContent::new("")));
1083 }
1084 let shared = Arc::clone(msg_arc);
1085 if !sent_start {
1086 on_event(AgentEvent::MessageStart {
1087 message: Message::Assistant(Arc::clone(&shared)),
1088 });
1089 sent_start = true;
1090 }
1091 on_event(AgentEvent::MessageUpdate {
1092 message: Message::Assistant(Arc::clone(&shared)),
1093 assistant_message_event: AssistantMessageEvent::TextStart {
1094 content_index,
1095 partial: shared,
1096 },
1097 });
1098 }
1099 }
1100 StreamEvent::TextDelta {
1101 content_index,
1102 delta,
1103 ..
1104 } => {
1105 if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1106 {
1107 let msg = Arc::make_mut(msg_arc);
1108 if let Some(ContentBlock::Text(text)) =
1109 msg.content.get_mut(content_index)
1110 {
1111 text.text.push_str(&delta);
1112 }
1113 }
1114 let shared = Arc::clone(msg_arc);
1115 if !sent_start {
1116 on_event(AgentEvent::MessageStart {
1117 message: Message::Assistant(Arc::clone(&shared)),
1118 });
1119 sent_start = true;
1120 }
1121 on_event(AgentEvent::MessageUpdate {
1122 message: Message::Assistant(Arc::clone(&shared)),
1123 assistant_message_event: AssistantMessageEvent::TextDelta {
1124 content_index,
1125 delta,
1126 partial: shared,
1127 },
1128 });
1129 }
1130 }
1131 StreamEvent::TextEnd {
1132 content_index,
1133 content,
1134 ..
1135 } => {
1136 if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1137 {
1138 let msg = Arc::make_mut(msg_arc);
1139 if let Some(ContentBlock::Text(text)) =
1140 msg.content.get_mut(content_index)
1141 {
1142 text.text.clone_from(&content);
1143 }
1144 }
1145 let shared = Arc::clone(msg_arc);
1146 if !sent_start {
1147 on_event(AgentEvent::MessageStart {
1148 message: Message::Assistant(Arc::clone(&shared)),
1149 });
1150 sent_start = true;
1151 }
1152 on_event(AgentEvent::MessageUpdate {
1153 message: Message::Assistant(Arc::clone(&shared)),
1154 assistant_message_event: AssistantMessageEvent::TextEnd {
1155 content_index,
1156 content,
1157 partial: shared,
1158 },
1159 });
1160 }
1161 }
1162 StreamEvent::ThinkingStart { content_index, .. } => {
1163 if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1164 let msg = Arc::make_mut(msg_arc);
1165 if content_index == msg.content.len() {
1166 msg.content.push(ContentBlock::Thinking(ThinkingContent {
1167 thinking: String::new(),
1168 thinking_signature: None,
1169 }));
1170 }
1171 let shared = Arc::clone(msg_arc);
1172 if !sent_start {
1173 on_event(AgentEvent::MessageStart {
1174 message: Message::Assistant(Arc::clone(&shared)),
1175 });
1176 sent_start = true;
1177 }
1178 on_event(AgentEvent::MessageUpdate {
1179 message: Message::Assistant(Arc::clone(&shared)),
1180 assistant_message_event: AssistantMessageEvent::ThinkingStart {
1181 content_index,
1182 partial: shared,
1183 },
1184 });
1185 }
1186 }
1187 StreamEvent::ThinkingDelta {
1188 content_index,
1189 delta,
1190 ..
1191 } => {
1192 if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1193 {
1194 let msg = Arc::make_mut(msg_arc);
1195 if let Some(ContentBlock::Thinking(thinking)) =
1196 msg.content.get_mut(content_index)
1197 {
1198 thinking.thinking.push_str(&delta);
1199 }
1200 }
1201 let shared = Arc::clone(msg_arc);
1202 if !sent_start {
1203 on_event(AgentEvent::MessageStart {
1204 message: Message::Assistant(Arc::clone(&shared)),
1205 });
1206 sent_start = true;
1207 }
1208 on_event(AgentEvent::MessageUpdate {
1209 message: Message::Assistant(Arc::clone(&shared)),
1210 assistant_message_event: AssistantMessageEvent::ThinkingDelta {
1211 content_index,
1212 delta,
1213 partial: shared,
1214 },
1215 });
1216 }
1217 }
1218 StreamEvent::ThinkingEnd {
1219 content_index,
1220 content,
1221 ..
1222 } => {
1223 if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1224 {
1225 let msg = Arc::make_mut(msg_arc);
1226 if let Some(ContentBlock::Thinking(thinking)) =
1227 msg.content.get_mut(content_index)
1228 {
1229 thinking.thinking.clone_from(&content);
1230 }
1231 }
1232 let shared = Arc::clone(msg_arc);
1233 if !sent_start {
1234 on_event(AgentEvent::MessageStart {
1235 message: Message::Assistant(Arc::clone(&shared)),
1236 });
1237 sent_start = true;
1238 }
1239 on_event(AgentEvent::MessageUpdate {
1240 message: Message::Assistant(Arc::clone(&shared)),
1241 assistant_message_event: AssistantMessageEvent::ThinkingEnd {
1242 content_index,
1243 content,
1244 partial: shared,
1245 },
1246 });
1247 }
1248 }
1249 StreamEvent::ToolCallStart { content_index, .. } => {
1250 if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1251 let msg = Arc::make_mut(msg_arc);
1252 if content_index == msg.content.len() {
1253 msg.content.push(ContentBlock::ToolCall(ToolCall {
1254 id: String::new(),
1255 name: String::new(),
1256 arguments: serde_json::Value::Null,
1257 thought_signature: None,
1258 }));
1259 }
1260 let shared = Arc::clone(msg_arc);
1261 if !sent_start {
1262 on_event(AgentEvent::MessageStart {
1263 message: Message::Assistant(Arc::clone(&shared)),
1264 });
1265 sent_start = true;
1266 }
1267 on_event(AgentEvent::MessageUpdate {
1268 message: Message::Assistant(Arc::clone(&shared)),
1269 assistant_message_event: AssistantMessageEvent::ToolCallStart {
1270 content_index,
1271 partial: shared,
1272 },
1273 });
1274 }
1275 }
1276 StreamEvent::ToolCallDelta {
1277 content_index,
1278 delta,
1279 ..
1280 } => {
1281 if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1282 let shared = Arc::clone(msg_arc);
1285 if !sent_start {
1286 on_event(AgentEvent::MessageStart {
1287 message: Message::Assistant(Arc::clone(&shared)),
1288 });
1289 sent_start = true;
1290 }
1291 on_event(AgentEvent::MessageUpdate {
1292 message: Message::Assistant(Arc::clone(&shared)),
1293 assistant_message_event: AssistantMessageEvent::ToolCallDelta {
1294 content_index,
1295 delta,
1296 partial: shared,
1297 },
1298 });
1299 }
1300 }
1301 StreamEvent::ToolCallEnd {
1302 content_index,
1303 tool_call,
1304 ..
1305 } => {
1306 if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1307 {
1308 let msg = Arc::make_mut(msg_arc);
1309 if let Some(ContentBlock::ToolCall(tc)) =
1310 msg.content.get_mut(content_index)
1311 {
1312 *tc = tool_call.clone();
1313 }
1314 }
1315 let shared = Arc::clone(msg_arc);
1316 if !sent_start {
1317 on_event(AgentEvent::MessageStart {
1318 message: Message::Assistant(Arc::clone(&shared)),
1319 });
1320 sent_start = true;
1321 }
1322 on_event(AgentEvent::MessageUpdate {
1323 message: Message::Assistant(Arc::clone(&shared)),
1324 assistant_message_event: AssistantMessageEvent::ToolCallEnd {
1325 content_index,
1326 tool_call,
1327 partial: shared,
1328 },
1329 });
1330 }
1331 }
1332 StreamEvent::Done { message, .. } => {
1333 return Ok(self.finalize_assistant_message(message, &on_event, added_partial));
1334 }
1335 StreamEvent::Error { error, .. } => {
1336 return Ok(self.finalize_assistant_message(error, &on_event, added_partial));
1337 }
1338 }
1339 }
1340
1341 if added_partial {
1345 if let Some(Message::Assistant(last_msg)) = self.messages.last() {
1346 let mut final_msg = (**last_msg).clone();
1347 final_msg.stop_reason = StopReason::Error;
1348 final_msg.error_message = Some("Stream ended without Done event".to_string());
1349 return Ok(self.finalize_assistant_message(final_msg, &on_event, true));
1350 }
1351 }
1352 Err(Error::api("Stream ended without Done event"))
1353 }
1354
1355 fn update_partial_message(
1360 &mut self,
1361 partial: Arc<AssistantMessage>,
1362 added_partial: &mut bool,
1363 ) -> bool {
1364 if *added_partial {
1365 if let Some(last @ Message::Assistant(_)) = self.messages.last_mut() {
1366 *last = Message::Assistant(partial);
1367 } else {
1368 tracing::warn!("update_partial_message: expected last message to be Assistant");
1371 self.messages.push(Message::Assistant(partial));
1372 }
1373 false
1374 } else {
1375 self.messages.push(Message::Assistant(partial));
1376 *added_partial = true;
1377 true
1378 }
1379 }
1380
1381 fn finalize_assistant_message(
1382 &mut self,
1383 message: AssistantMessage,
1384 on_event: &Arc<dyn Fn(AgentEvent) + Send + Sync>,
1385 added_partial: bool,
1386 ) -> AssistantMessage {
1387 let arc = Arc::new(message);
1388 if added_partial {
1389 if let Some(last @ Message::Assistant(_)) = self.messages.last_mut() {
1390 *last = Message::Assistant(Arc::clone(&arc));
1391 } else {
1392 tracing::warn!("finalize_assistant_message: expected last message to be Assistant");
1395 self.messages.push(Message::Assistant(Arc::clone(&arc)));
1396 on_event(AgentEvent::MessageStart {
1397 message: Message::Assistant(Arc::clone(&arc)),
1398 });
1399 }
1400 } else {
1401 self.messages.push(Message::Assistant(Arc::clone(&arc)));
1402 on_event(AgentEvent::MessageStart {
1403 message: Message::Assistant(Arc::clone(&arc)),
1404 });
1405 }
1406
1407 on_event(AgentEvent::MessageEnd {
1408 message: Message::Assistant(Arc::clone(&arc)),
1409 });
1410 Arc::try_unwrap(arc).unwrap_or_else(|a| (*a).clone())
1411 }
1412
1413 async fn execute_parallel_batch(
1414 &self,
1415 batch: Vec<(usize, ToolCall)>,
1416 on_event: AgentEventHandler,
1417 abort: Option<AbortSignal>,
1418 ) -> Vec<(usize, (ToolOutput, bool))> {
1419 let futures = batch.into_iter().map(|(idx, tc)| {
1420 let on_event = Arc::clone(&on_event);
1421 async move { (idx, self.execute_tool_owned(tc, on_event).await) }
1422 });
1423
1424 if let Some(signal) = abort.as_ref() {
1425 use futures::future::{Either, select};
1426 let all_fut = stream::iter(futures)
1427 .buffer_unordered(MAX_CONCURRENT_TOOLS)
1428 .collect::<Vec<_>>()
1429 .fuse();
1430 let abort_fut = signal.wait().fuse();
1431 futures::pin_mut!(all_fut, abort_fut);
1432
1433 match select(all_fut, abort_fut).await {
1434 Either::Left((batch_results, _)) => batch_results,
1435 Either::Right(_) => Vec::new(), }
1437 } else {
1438 stream::iter(futures)
1439 .buffer_unordered(MAX_CONCURRENT_TOOLS)
1440 .collect::<Vec<_>>()
1441 .await
1442 }
1443 }
1444
/// Executes the given tool calls and records one tool-result message per call.
///
/// Flow, as implemented below:
/// 1. `ToolExecutionStart` is emitted for every call up front.
/// 2. Calls are walked in order. Calls whose tool reports `is_read_only()` are queued
///    and later run together through `execute_parallel_batch`; any other call (including
///    one whose tool is not registered) first flushes the queued batch and then runs
///    on its own through `execute_tool`.
/// 3. The walk stops early when `abort` reports aborted or when
///    `drain_steering_messages` returns a non-empty list.
/// 4. Finally, for every call in its original order, a `ToolResultMessage` is pushed to
///    both `self.messages` and `new_messages`: the real output if one exists, a
///    "skipped" result if steering messages arrived, or an "aborted" result otherwise.
///
/// The visible body always returns `Ok`, carrying the result messages and any steering
/// messages that were drained.
#[allow(clippy::too_many_lines)]
async fn execute_tool_calls(
    &mut self,
    tool_calls: &[ToolCall],
    on_event: AgentEventHandler,
    new_messages: &mut Vec<Message>,
    abort: Option<AbortSignal>,
) -> Result<ToolExecutionOutcome> {
    let mut results = Vec::new();
    let mut steering_messages: Option<Vec<Message>> = None;

    // Announce every call before any of them runs.
    for tool_call in tool_calls {
        on_event(AgentEvent::ToolExecutionStart {
            tool_call_id: tool_call.id.clone(),
            tool_name: tool_call.name.clone(),
            args: tool_call.arguments.clone(),
        });
    }

    // Read-only calls waiting to be run as one concurrent batch, with their original index.
    let mut pending_parallel: Vec<(usize, ToolCall)> = Vec::new();
    // Slot per call, filled in as outputs become available; `None` means "did not run".
    let mut tool_outputs: Vec<Option<(ToolOutput, bool)>> = vec![None; tool_calls.len()];

    for (index, tool_call) in tool_calls.iter().enumerate() {
        if abort.as_ref().is_some_and(AbortSignal::is_aborted) {
            break;
        }

        // An unknown tool name yields `false` here and takes the sequential path.
        let is_read_only =
            matches!(self.tools.get(&tool_call.name), Some(tool) if tool.is_read_only());

        if is_read_only {
            pending_parallel.push((index, tool_call.clone()));
        } else {
            // Steering that arrived before this call stops all further execution.
            let steering = self.drain_steering_messages().await;
            if !steering.is_empty() {
                steering_messages = Some(steering);
                break;
            }

            // Flush the queued read-only calls before running this one.
            if !pending_parallel.is_empty() {
                let batch = std::mem::take(&mut pending_parallel);
                let results = self
                    .execute_parallel_batch(batch, Arc::clone(&on_event), abort.clone())
                    .await;
                for (idx, result) in results {
                    tool_outputs[idx] = Some(result);
                }
            }

            // The batch may have been cut short by an abort.
            if abort.as_ref().is_some_and(AbortSignal::is_aborted) {
                break;
            }

            // Check again: steering may have arrived while the batch was running.
            let steering = self.drain_steering_messages().await;
            if !steering.is_empty() {
                steering_messages = Some(steering);
                break;
            }

            let result = self
                .execute_tool(tool_call.clone(), Arc::clone(&on_event))
                .await;
            tool_outputs[index] = Some(result);
        }
    }

    // Run any trailing read-only batch, unless we already stopped for abort or steering.
    if !pending_parallel.is_empty()
        && !abort.as_ref().is_some_and(AbortSignal::is_aborted)
        && steering_messages.is_none()
    {
        let batch = std::mem::take(&mut pending_parallel);
        let steering = self.drain_steering_messages().await;
        if steering.is_empty() {
            let results = self
                .execute_parallel_batch(batch, Arc::clone(&on_event), abort.clone())
                .await;
            for (idx, result) in results {
                tool_outputs[idx] = Some(result);
            }
        } else {
            // The batch is dropped without running; its calls are reported as skipped below.
            steering_messages = Some(steering);
        }
    }

    // Emit results in the original call order, regardless of completion order.
    for (index, tool_call) in tool_calls.iter().enumerate() {
        // Keep polling for steering until some is found (or the run was aborted).
        if steering_messages.is_none() && !abort.as_ref().is_some_and(AbortSignal::is_aborted) {
            let steering = self.drain_steering_messages().await;
            if !steering.is_empty() {
                steering_messages = Some(steering);
            }
        }

        if let Some((output, is_error)) = tool_outputs[index].take() {
            // The call ran: publish its actual output.
            let tool_result = Arc::new(ToolResultMessage {
                tool_call_id: tool_call.id.clone(),
                tool_name: tool_call.name.clone(),
                content: output.content,
                details: output.details,
                is_error,
                timestamp: Utc::now().timestamp_millis(),
            });

            on_event(AgentEvent::ToolExecutionEnd {
                tool_call_id: tool_result.tool_call_id.clone(),
                tool_name: tool_result.tool_name.clone(),
                result: ToolOutput {
                    content: tool_result.content.clone(),
                    details: tool_result.details.clone(),
                    is_error,
                },
                is_error,
            });

            let msg = Message::ToolResult(Arc::clone(&tool_result));
            self.messages.push(msg.clone());
            on_event(AgentEvent::MessageStart {
                message: msg.clone(),
            });
            new_messages.push(msg.clone());
            on_event(AgentEvent::MessageEnd { message: msg });

            results.push(tool_result);
        } else if steering_messages.is_some() {
            // The call never ran because steering messages arrived.
            results.push(self.skip_tool_call(tool_call, &on_event, new_messages));
        } else {
            // No output and no steering: the loops above only stop early on abort or
            // steering, so this call is reported as aborted.
            let output = ToolOutput {
                content: vec![ContentBlock::Text(TextContent::new(
                    "Tool execution aborted",
                ))],
                details: None,
                is_error: true,
            };

            on_event(AgentEvent::ToolExecutionUpdate {
                tool_call_id: tool_call.id.clone(),
                tool_name: tool_call.name.clone(),
                args: tool_call.arguments.clone(),
                partial_result: ToolOutput {
                    content: output.content.clone(),
                    details: output.details.clone(),
                    is_error: true,
                },
            });

            on_event(AgentEvent::ToolExecutionEnd {
                tool_call_id: tool_call.id.clone(),
                tool_name: tool_call.name.clone(),
                result: ToolOutput {
                    content: output.content.clone(),
                    details: output.details.clone(),
                    is_error: true,
                },
                is_error: true,
            });

            let tool_result = Arc::new(ToolResultMessage {
                tool_call_id: tool_call.id.clone(),
                tool_name: tool_call.name.clone(),
                content: output.content,
                details: output.details,
                is_error: true,
                timestamp: Utc::now().timestamp_millis(),
            });

            let msg = Message::ToolResult(Arc::clone(&tool_result));
            self.messages.push(msg.clone());
            on_event(AgentEvent::MessageStart {
                message: msg.clone(),
            });
            let end_msg = msg.clone();
            new_messages.push(msg);
            on_event(AgentEvent::MessageEnd { message: end_msg });

            results.push(tool_result);
        }
    }

    Ok(ToolExecutionOutcome {
        tool_results: results,
        steering_messages,
    })
}
1650
1651 async fn execute_tool(
1652 &self,
1653 tool_call: ToolCall,
1654 on_event: AgentEventHandler,
1655 ) -> (ToolOutput, bool) {
1656 let extensions = self.extensions.clone();
1657
1658 let (mut output, is_error) = if let Some(extensions) = &extensions {
1659 match Self::dispatch_tool_call_hook(extensions, &tool_call).await {
1660 Some(blocked_output) => (blocked_output, true),
1661 None => {
1662 self.execute_tool_without_hooks(&tool_call, Arc::clone(&on_event))
1663 .await
1664 }
1665 }
1666 } else {
1667 self.execute_tool_without_hooks(&tool_call, Arc::clone(&on_event))
1668 .await
1669 };
1670
1671 if let Some(extensions) = &extensions {
1672 Self::apply_tool_result_hook(extensions, &tool_call, &mut output, is_error).await;
1673 }
1674
1675 (output, is_error)
1676 }
1677
1678 async fn execute_tool_owned(
1679 &self,
1680 tool_call: ToolCall,
1681 on_event: AgentEventHandler,
1682 ) -> (ToolOutput, bool) {
1683 self.execute_tool(tool_call, on_event).await
1684 }
1685
1686 async fn execute_tool_without_hooks(
1687 &self,
1688 tool_call: &ToolCall,
1689 on_event: AgentEventHandler,
1690 ) -> (ToolOutput, bool) {
1691 let Some(tool) = self.tools.get(&tool_call.name) else {
1693 return (Self::tool_not_found_output(&tool_call.name), true);
1694 };
1695
1696 let tool_name = tool_call.name.clone();
1697 let tool_id = tool_call.id.clone();
1698 let tool_args = tool_call.arguments.clone();
1699 let on_event = Arc::clone(&on_event);
1700
1701 let update_callback = move |update: ToolUpdate| {
1702 on_event(AgentEvent::ToolExecutionUpdate {
1703 tool_call_id: tool_id.clone(),
1704 tool_name: tool_name.clone(),
1705 args: tool_args.clone(),
1706 partial_result: ToolOutput {
1707 content: update.content,
1708 details: update.details,
1709 is_error: false,
1710 },
1711 });
1712 };
1713
1714 match tool
1715 .execute(
1716 &tool_call.id,
1717 tool_call.arguments.clone(),
1718 Some(Box::new(update_callback)),
1719 )
1720 .await
1721 {
1722 Ok(output) => {
1723 let is_error = output.is_error;
1724 (output, is_error)
1725 }
1726 Err(e) => (
1727 ToolOutput {
1728 content: vec![ContentBlock::Text(TextContent::new(format!("Error: {e}")))],
1729 details: None,
1730 is_error: true,
1731 },
1732 true,
1733 ),
1734 }
1735 }
1736
1737 fn tool_not_found_output(tool_name: &str) -> ToolOutput {
1738 ToolOutput {
1739 content: vec![ContentBlock::Text(TextContent::new(format!(
1740 "Error: Tool '{tool_name}' not found"
1741 )))],
1742 details: None,
1743 is_error: true,
1744 }
1745 }
1746
1747 async fn dispatch_tool_call_hook(
1748 extensions: &ExtensionManager,
1749 tool_call: &ToolCall,
1750 ) -> Option<ToolOutput> {
1751 match extensions
1752 .dispatch_tool_call(tool_call, EXTENSION_EVENT_TIMEOUT_MS)
1753 .await
1754 {
1755 Ok(Some(result)) if result.block => {
1756 Some(Self::tool_call_blocked_output(result.reason.as_deref()))
1757 }
1758 Ok(_) => None,
1759 Err(err) => {
1760 tracing::warn!("tool_call extension hook failed (fail-open): {err}");
1761 None
1762 }
1763 }
1764 }
1765
1766 fn tool_call_blocked_output(reason: Option<&str>) -> ToolOutput {
1767 let reason = reason.map(str::trim).filter(|reason| !reason.is_empty());
1768 let message = reason.map_or_else(
1769 || "Tool execution was blocked by an extension".to_string(),
1770 |reason| format!("Tool execution blocked: {reason}"),
1771 );
1772
1773 ToolOutput {
1774 content: vec![ContentBlock::Text(TextContent::new(message))],
1775 details: None,
1776 is_error: true,
1777 }
1778 }
1779
1780 async fn apply_tool_result_hook(
1781 extensions: &ExtensionManager,
1782 tool_call: &ToolCall,
1783 output: &mut ToolOutput,
1784 is_error: bool,
1785 ) {
1786 match extensions
1787 .dispatch_tool_result(tool_call, &*output, is_error, EXTENSION_EVENT_TIMEOUT_MS)
1788 .await
1789 {
1790 Ok(Some(result)) => {
1791 if let Some(content) = result.content {
1792 output.content = content;
1793 }
1794 if let Some(details) = result.details {
1795 output.details = Some(details);
1796 }
1797 }
1798 Ok(None) => {}
1799 Err(err) => tracing::warn!("tool_result extension hook failed (fail-open): {err}"),
1800 }
1801 }
1802
1803 fn skip_tool_call(
1804 &mut self,
1805 tool_call: &ToolCall,
1806 on_event: &Arc<dyn Fn(AgentEvent) + Send + Sync>,
1807 new_messages: &mut Vec<Message>,
1808 ) -> Arc<ToolResultMessage> {
1809 let output = ToolOutput {
1810 content: vec![ContentBlock::Text(TextContent::new(
1811 "Skipped due to queued user message.",
1812 ))],
1813 details: None,
1814 is_error: true,
1815 };
1816
1817 on_event(AgentEvent::ToolExecutionUpdate {
1820 tool_call_id: tool_call.id.clone(),
1821 tool_name: tool_call.name.clone(),
1822 args: tool_call.arguments.clone(),
1823 partial_result: output.clone(),
1824 });
1825 on_event(AgentEvent::ToolExecutionEnd {
1826 tool_call_id: tool_call.id.clone(),
1827 tool_name: tool_call.name.clone(),
1828 result: output.clone(),
1829 is_error: true,
1830 });
1831
1832 let tool_result = Arc::new(ToolResultMessage {
1833 tool_call_id: tool_call.id.clone(),
1834 tool_name: tool_call.name.clone(),
1835 content: output.content,
1836 details: output.details,
1837 is_error: true,
1838 timestamp: Utc::now().timestamp_millis(),
1839 });
1840
1841 let msg = Message::ToolResult(Arc::clone(&tool_result));
1842 self.messages.push(msg.clone());
1843 new_messages.push(msg.clone());
1844
1845 on_event(AgentEvent::MessageStart {
1846 message: msg.clone(),
1847 });
1848 on_event(AgentEvent::MessageEnd { message: msg });
1849
1850 tool_result
1851 }
1852}
1853
/// Value returned by `execute_tool_calls`.
struct ToolExecutionOutcome {
    /// One tool-result message per requested tool call, whether it was executed,
    /// skipped, or aborted.
    tool_results: Vec<Arc<ToolResultMessage>>,
    /// Steering messages drained while the tool calls were being processed;
    /// `None` when none arrived.
    steering_messages: Option<Vec<Message>>,
}
1862
/// Bundle of an extension manager, its runtime handle, and a tool registry.
///
/// NOTE(review): the name suggests these are created ahead of time and handed to a
/// session later; the code that builds and consumes this struct is not visible in this
/// part of the file — confirm.
pub struct PreWarmedExtensionRuntime {
    /// The extension manager.
    pub manager: ExtensionManager,
    /// Handle to the extension runtime.
    pub runtime: ExtensionRuntimeHandle,
    /// Shared tool registry.
    pub tools: Arc<ToolRegistry>,
}
1875
/// Couples an [`Agent`] with a shared [`Session`] and optional supporting subsystems.
///
/// NOTE(review): the constructor and methods live outside this part of the file; field
/// notes marked "presumably" are inferred from names and types and should be confirmed.
pub struct AgentSession {
    /// The agent driven by this session.
    pub agent: Agent,
    /// Session state, shared behind an async mutex.
    pub session: Arc<Mutex<Session>>,
    /// Presumably controls whether the session is persisted — TODO confirm.
    save_enabled: bool,
    /// Extension region, present only when extensions are in use.
    pub extensions: Option<ExtensionRegion>,
    /// Shared flag; presumably the same flag extension host actions read to decide
    /// between queueing and appending a message — TODO confirm.
    extensions_is_streaming: Arc<AtomicBool>,
    /// Resolved compaction settings.
    compaction_settings: ResolvedCompactionSettings,
    /// State of the compaction worker.
    compaction_worker: CompactionWorkerState,
    /// Optional model registry.
    model_registry: Option<ModelRegistry>,
    /// Optional auth storage.
    auth_storage: Option<AuthStorage>,
}
1889
1890#[derive(Debug, Default)]
1891struct ExtensionInjectedQueue {
1892 steering: VecDeque<Message>,
1893 follow_up: VecDeque<Message>,
1894}
1895
1896impl ExtensionInjectedQueue {
1897 fn push_steering(&mut self, message: Message) {
1898 self.steering.push_back(message);
1899 }
1900
1901 fn push_follow_up(&mut self, message: Message) {
1902 self.follow_up.push_back(message);
1903 }
1904
1905 fn pop_steering(&mut self) -> Vec<Message> {
1906 self.steering.drain(..).collect()
1907 }
1908
1909 fn pop_follow_up(&mut self) -> Vec<Message> {
1910 self.follow_up.drain(..).collect()
1911 }
1912}
1913
/// Host-side handler for messages that extensions send into an agent session.
///
/// Messages are either appended to the session directly or placed in the injected
/// queue, depending on the `is_streaming` flag and the requested delivery mode.
#[derive(Clone)]
struct AgentSessionHostActions {
    /// Session that messages are appended to.
    session: Arc<Mutex<Session>>,
    /// Queue that receives messages sent while `is_streaming` is set.
    injected: Arc<StdMutex<ExtensionInjectedQueue>>,
    /// Flag checked before each send to choose between queueing and appending.
    is_streaming: Arc<AtomicBool>,
}
1920
1921impl AgentSessionHostActions {
1922 fn enqueue(&self, deliver_as: Option<ExtensionDeliverAs>, message: Message) {
1923 let deliver_as = deliver_as.unwrap_or(ExtensionDeliverAs::Steer);
1924 let Ok(mut queue) = self.injected.lock() else {
1925 return;
1926 };
1927 match deliver_as {
1928 ExtensionDeliverAs::FollowUp => {
1929 queue.push_follow_up(message);
1930 }
1931 ExtensionDeliverAs::Steer | ExtensionDeliverAs::NextTurn => {
1932 queue.push_steering(message);
1933 }
1934 }
1935 }
1936
1937 async fn append_to_session(&self, message: Message) -> Result<()> {
1938 let cx = crate::agent_cx::AgentCx::for_request();
1939 let mut session = self
1940 .session
1941 .lock(cx.cx())
1942 .await
1943 .map_err(|e| Error::session(e.to_string()))?;
1944 session.append_model_message(message);
1945 Ok(())
1946 }
1947}
1948
1949#[async_trait]
1950impl ExtensionHostActions for AgentSessionHostActions {
1951 async fn send_message(&self, message: ExtensionSendMessage) -> Result<()> {
1952 let custom_message = Message::Custom(CustomMessage {
1953 content: message.content,
1954 custom_type: message.custom_type,
1955 display: message.display,
1956 details: message.details,
1957 timestamp: Utc::now().timestamp_millis(),
1958 });
1959
1960 if matches!(message.deliver_as, Some(ExtensionDeliverAs::NextTurn)) {
1961 return self.append_to_session(custom_message).await;
1962 }
1963
1964 if self.is_streaming.load(Ordering::SeqCst) {
1965 self.enqueue(message.deliver_as, custom_message);
1966 return Ok(());
1967 }
1968
1969 let _ = message.trigger_turn;
1972 self.append_to_session(custom_message).await
1973 }
1974
1975 async fn send_user_message(&self, message: ExtensionSendUserMessage) -> Result<()> {
1976 let user_message = Message::User(UserMessage {
1977 content: UserContent::Text(message.text),
1978 timestamp: Utc::now().timestamp_millis(),
1979 });
1980
1981 if self.is_streaming.load(Ordering::SeqCst) {
1982 self.enqueue(message.deliver_as, user_message);
1983 return Ok(());
1984 }
1985
1986 self.append_to_session(user_message).await
1988 }
1989}
1990
#[cfg(test)]
mod message_queue_tests {
    use super::*;

    /// Builds a plain-text user message with a zero timestamp.
    fn user_message(text: &str) -> Message {
        let content = UserContent::Text(text.to_string());
        Message::User(UserMessage {
            content,
            timestamp: 0,
        })
    }

    /// Returns the text of a plain-text user message, or `None` for anything else.
    fn user_text(message: &Message) -> Option<&str> {
        match message {
            Message::User(UserMessage {
                content: UserContent::Text(text),
                ..
            }) => Some(text.as_str()),
            _ => None,
        }
    }

    /// In one-at-a-time mode each pop yields exactly one message, oldest first.
    #[test]
    fn message_queue_one_at_a_time() {
        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
        queue.push_steering(user_message("a"));
        queue.push_steering(user_message("b"));

        for expected in ["a", "b"] {
            let popped = queue.pop_steering();
            assert_eq!(popped.len(), 1);
            assert_eq!(popped.first().and_then(user_text), Some(expected));
        }

        assert!(queue.pop_steering().is_empty());
    }

    /// In all mode a single pop drains every queued steering message.
    #[test]
    fn message_queue_all_mode() {
        let mut queue = MessageQueue::new(QueueMode::All, QueueMode::OneAtATime);
        for text in ["a", "b"] {
            queue.push_steering(user_message(text));
        }

        assert_eq!(queue.pop_steering().len(), 2);
        assert!(queue.pop_steering().is_empty());
    }

    /// Steering and follow-up messages are tracked independently.
    #[test]
    fn message_queue_separates_kinds() {
        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
        queue.push_steering(user_message("steer"));
        queue.push_follow_up(user_message("follow"));

        assert_eq!(queue.pop_steering().len(), 1);
        assert_eq!(queue.pending_count(), 1);

        assert_eq!(queue.pop_follow_up().len(), 1);
        assert_eq!(queue.pending_count(), 0);
    }

    /// Sequence numbers grow across pushes, regardless of kind.
    #[test]
    fn message_queue_seq_increments() {
        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
        let steering_seq = queue.push_steering(user_message("a"));
        let follow_up_seq = queue.push_follow_up(user_message("b"));
        assert!(follow_up_seq > steering_seq);
    }

    /// At `u64::MAX` the sequence number saturates instead of wrapping.
    #[test]
    fn message_queue_seq_saturates_at_u64_max() {
        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
        queue.next_seq = u64::MAX;

        let steering_seq = queue.push_steering(user_message("a"));
        let follow_up_seq = queue.push_follow_up(user_message("b"));

        assert_eq!(steering_seq, u64::MAX);
        assert_eq!(follow_up_seq, u64::MAX);
        assert_eq!(queue.pending_count(), 2);
    }

    /// Follow-ups in all mode are drained in one pop, preserving insertion order.
    #[test]
    fn message_queue_follow_up_all_mode_drains_entire_queue_in_order() {
        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::All);
        queue.push_follow_up(user_message("f1"));
        queue.push_follow_up(user_message("f2"));

        let drained = queue.pop_follow_up();
        assert_eq!(drained.len(), 2);
        assert_eq!(drained.first().and_then(user_text), Some("f1"));
        assert_eq!(drained.get(1).and_then(user_text), Some("f2"));
        assert!(queue.pop_follow_up().is_empty());
    }
}
2095
2096#[cfg(test)]
2097mod extensions_integration_tests {
2098 use super::*;
2099
2100 use crate::session::Session;
2101 use asupersync::runtime::RuntimeBuilder;
2102 use async_trait::async_trait;
2103 use futures::Stream;
2104 use serde_json::json;
2105 use std::path::Path;
2106 use std::pin::Pin;
2107 use std::sync::atomic::AtomicUsize;
2108
2109 #[derive(Debug)]
2110 struct NoopProvider;
2111
2112 #[async_trait]
2113 #[allow(clippy::unnecessary_literal_bound)]
2114 impl Provider for NoopProvider {
2115 fn name(&self) -> &str {
2116 "test-provider"
2117 }
2118
2119 fn api(&self) -> &str {
2120 "test-api"
2121 }
2122
2123 fn model_id(&self) -> &str {
2124 "test-model"
2125 }
2126
2127 async fn stream(
2128 &self,
2129 _context: &Context<'_>,
2130 _options: &StreamOptions,
2131 ) -> crate::error::Result<
2132 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
2133 > {
2134 Ok(Box::pin(futures::stream::empty()))
2135 }
2136 }
2137
2138 #[derive(Debug)]
2139 struct CountingTool {
2140 calls: Arc<AtomicUsize>,
2141 }
2142
2143 #[async_trait]
2144 #[allow(clippy::unnecessary_literal_bound)]
2145 impl Tool for CountingTool {
2146 fn name(&self) -> &str {
2147 "count_tool"
2148 }
2149
2150 fn label(&self) -> &str {
2151 "count_tool"
2152 }
2153
2154 fn description(&self) -> &str {
2155 "counting tool"
2156 }
2157
2158 fn parameters(&self) -> serde_json::Value {
2159 json!({ "type": "object" })
2160 }
2161
2162 async fn execute(
2163 &self,
2164 _tool_call_id: &str,
2165 _input: serde_json::Value,
2166 _on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
2167 ) -> Result<ToolOutput> {
2168 self.calls.fetch_add(1, Ordering::SeqCst);
2169 Ok(ToolOutput {
2170 content: vec![ContentBlock::Text(TextContent::new("ok"))],
2171 details: None,
2172 is_error: false,
2173 })
2174 }
2175 }
2176
2177 #[derive(Debug)]
2178 struct ToolUseProvider {
2179 stream_calls: AtomicUsize,
2180 }
2181
2182 impl ToolUseProvider {
2183 const fn new() -> Self {
2184 Self {
2185 stream_calls: AtomicUsize::new(0),
2186 }
2187 }
2188
2189 fn assistant_message(
2190 &self,
2191 stop_reason: StopReason,
2192 content: Vec<ContentBlock>,
2193 ) -> AssistantMessage {
2194 AssistantMessage {
2195 content,
2196 api: self.api().to_string(),
2197 provider: self.name().to_string(),
2198 model: self.model_id().to_string(),
2199 usage: Usage::default(),
2200 stop_reason,
2201 error_message: None,
2202 timestamp: 0,
2203 }
2204 }
2205 }
2206
2207 #[async_trait]
2208 #[allow(clippy::unnecessary_literal_bound)]
2209 impl Provider for ToolUseProvider {
2210 fn name(&self) -> &str {
2211 "test-provider"
2212 }
2213
2214 fn api(&self) -> &str {
2215 "test-api"
2216 }
2217
2218 fn model_id(&self) -> &str {
2219 "test-model"
2220 }
2221
2222 async fn stream(
2223 &self,
2224 _context: &Context<'_>,
2225 _options: &StreamOptions,
2226 ) -> crate::error::Result<
2227 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
2228 > {
2229 let call_index = self.stream_calls.fetch_add(1, Ordering::SeqCst);
2230
2231 let partial = self.assistant_message(StopReason::Stop, Vec::new());
2232
2233 let (reason, message) = if call_index == 0 {
2234 let tool_calls = vec![
2235 ToolCall {
2236 id: "call-1".to_string(),
2237 name: "count_tool".to_string(),
2238 arguments: json!({}),
2239 thought_signature: None,
2240 },
2241 ToolCall {
2242 id: "call-2".to_string(),
2243 name: "count_tool".to_string(),
2244 arguments: json!({}),
2245 thought_signature: None,
2246 },
2247 ];
2248
2249 (
2250 StopReason::ToolUse,
2251 self.assistant_message(
2252 StopReason::ToolUse,
2253 tool_calls
2254 .into_iter()
2255 .map(ContentBlock::ToolCall)
2256 .collect::<Vec<_>>(),
2257 ),
2258 )
2259 } else {
2260 (
2261 StopReason::Stop,
2262 self.assistant_message(
2263 StopReason::Stop,
2264 vec![ContentBlock::Text(TextContent::new("done"))],
2265 ),
2266 )
2267 };
2268
2269 let events = vec![
2270 Ok(StreamEvent::Start { partial }),
2271 Ok(StreamEvent::Done { reason, message }),
2272 ];
2273 Ok(Box::pin(futures::stream::iter(events)))
2274 }
2275 }
2276
2277 #[test]
2278 fn agent_session_enable_extensions_registers_extension_tools() {
2279 let runtime = RuntimeBuilder::current_thread()
2280 .build()
2281 .expect("runtime build");
2282
2283 runtime.block_on(async {
2284 let temp_dir = tempfile::tempdir().expect("tempdir");
2285 let entry_path = temp_dir.path().join("ext.mjs");
2286 std::fs::write(
2287 &entry_path,
2288 r#"
2289 export default function init(pi) {
2290 pi.registerTool({
2291 name: "hello_tool",
2292 label: "hello_tool",
2293 description: "test tool",
2294 parameters: { type: "object", properties: { name: { type: "string" } } },
2295 execute: async (_callId, input, _onUpdate, _abort, ctx) => {
2296 const who = input && input.name ? String(input.name) : "world";
2297 const cwd = ctx && ctx.cwd ? String(ctx.cwd) : "";
2298 return {
2299 content: [{ type: "text", text: `hello ${who}` }],
2300 details: { from: "extension", cwd: cwd },
2301 isError: false
2302 };
2303 }
2304 });
2305 }
2306 "#,
2307 )
2308 .expect("write extension entry");
2309
2310 let provider = Arc::new(NoopProvider);
2311 let tools = ToolRegistry::new(&[], Path::new("."), None);
2312 let agent = Agent::new(provider, tools, AgentConfig::default());
2313 let session = Arc::new(Mutex::new(Session::in_memory()));
2314 let mut agent_session =
2315 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2316
2317 agent_session
2318 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2319 .await
2320 .expect("enable extensions");
2321
2322 let tool = agent_session
2323 .agent
2324 .tools
2325 .get("hello_tool")
2326 .expect("hello_tool registered");
2327
2328 let output = tool
2329 .execute("call-1", json!({ "name": "pi" }), None)
2330 .await
2331 .expect("execute tool");
2332
2333 assert!(!output.is_error);
2334 assert!(
2335 matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
2336 "Expected single text content block, got {:?}",
2337 output.content
2338 );
2339 let [ContentBlock::Text(text)] = output.content.as_slice() else {
2340 return;
2341 };
2342 assert_eq!(text.text, "hello pi");
2343
2344 let details = output.details.expect("details present");
2345 assert_eq!(
2346 details.get("from").and_then(serde_json::Value::as_str),
2347 Some("extension")
2348 );
2349 });
2350 }
2351
2352 #[test]
2353 fn agent_session_enable_extensions_rejects_mixed_js_and_native_entries() {
2354 let runtime = RuntimeBuilder::current_thread()
2355 .build()
2356 .expect("runtime build");
2357
2358 runtime.block_on(async {
2359 let temp_dir = tempfile::tempdir().expect("tempdir");
2360 let js_entry = temp_dir.path().join("ext.mjs");
2361 let native_entry = temp_dir.path().join("ext.native.json");
2362 std::fs::write(
2363 &js_entry,
2364 r"
2365 export default function init(_pi) {}
2366 ",
2367 )
2368 .expect("write js extension entry");
2369 std::fs::write(&native_entry, "{}").expect("write native extension descriptor");
2370
2371 let provider = Arc::new(NoopProvider);
2372 let tools = ToolRegistry::new(&[], Path::new("."), None);
2373 let agent = Agent::new(provider, tools, AgentConfig::default());
2374 let session = Arc::new(Mutex::new(Session::in_memory()));
2375 let mut agent_session =
2376 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2377
2378 let err = agent_session
2379 .enable_extensions(&[], temp_dir.path(), None, &[js_entry, native_entry])
2380 .await
2381 .expect_err("mixed extension runtimes should be rejected");
2382 let msg = err.to_string();
2383 assert!(
2384 msg.contains("Mixed extension runtimes are not supported"),
2385 "unexpected mixed-runtime error message: {msg}"
2386 );
2387 });
2388 }
2389
2390 #[test]
2391 fn extension_send_message_persists_custom_message_entry_when_idle() {
2392 let runtime = RuntimeBuilder::current_thread()
2393 .build()
2394 .expect("runtime build");
2395
2396 runtime.block_on(async {
2397 let temp_dir = tempfile::tempdir().expect("tempdir");
2398 let entry_path = temp_dir.path().join("ext.mjs");
2399 std::fs::write(
2400 &entry_path,
2401 r#"
2402 export default function init(pi) {
2403 pi.registerTool({
2404 name: "emit_message",
2405 label: "emit_message",
2406 description: "emit a custom message",
2407 parameters: { type: "object" },
2408 execute: async () => {
2409 pi.sendMessage({
2410 customType: "note",
2411 content: "hello",
2412 display: true,
2413 details: { from: "test" }
2414 }, {});
2415 return { content: [{ type: "text", text: "ok" }], isError: false };
2416 }
2417 });
2418 }
2419 "#,
2420 )
2421 .expect("write extension entry");
2422
2423 let provider = Arc::new(NoopProvider);
2424 let tools = ToolRegistry::new(&[], Path::new("."), None);
2425 let agent = Agent::new(provider, tools, AgentConfig::default());
2426 let session = Arc::new(Mutex::new(Session::in_memory()));
2427 let mut agent_session = AgentSession::new(
2428 agent,
2429 Arc::clone(&session),
2430 false,
2431 ResolvedCompactionSettings::default(),
2432 );
2433
2434 agent_session
2435 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2436 .await
2437 .expect("enable extensions");
2438
2439 let tool = agent_session
2440 .agent
2441 .tools
2442 .get("emit_message")
2443 .expect("emit_message registered");
2444
2445 let _ = tool
2446 .execute("call-1", json!({}), None)
2447 .await
2448 .expect("execute tool");
2449
2450 let cx = crate::agent_cx::AgentCx::for_request();
2451 let session_guard = session
2452 .lock(cx.cx())
2453 .await
2454 .expect("lock session");
2455 let messages = session_guard.to_messages_for_current_path();
2456
2457 assert!(
2458 messages.iter().any(|msg| {
2459 matches!(
2460 msg,
2461 Message::Custom(CustomMessage { custom_type, content, display, details, .. })
2462 if custom_type == "note"
2463 && content == "hello"
2464 && *display
2465 && details.as_ref().and_then(|v| v.get("from").and_then(Value::as_str)) == Some("test")
2466 )
2467 }),
2468 "expected custom message to be persisted, got {messages:?}"
2469 );
2470 });
2471 }
2472
2473 #[test]
2474 fn extension_send_message_persists_custom_message_entry_when_idle_after_await() {
2475 let runtime = RuntimeBuilder::current_thread()
2476 .build()
2477 .expect("runtime build");
2478
2479 runtime.block_on(async {
2480 let temp_dir = tempfile::tempdir().expect("tempdir");
2481 let entry_path = temp_dir.path().join("ext.mjs");
2482 std::fs::write(
2483 &entry_path,
2484 r#"
2485 export default function init(pi) {
2486 pi.registerTool({
2487 name: "emit_message",
2488 label: "emit_message",
2489 description: "emit a custom message",
2490 parameters: { type: "object" },
2491 execute: async () => {
2492 await Promise.resolve();
2493 pi.sendMessage({
2494 customType: "note",
2495 content: "hello-after-await",
2496 display: true,
2497 details: { from: "test" }
2498 }, {});
2499 return { content: [{ type: "text", text: "ok" }], isError: false };
2500 }
2501 });
2502 }
2503 "#,
2504 )
2505 .expect("write extension entry");
2506
2507 let provider = Arc::new(NoopProvider);
2508 let tools = ToolRegistry::new(&[], Path::new("."), None);
2509 let agent = Agent::new(provider, tools, AgentConfig::default());
2510 let session = Arc::new(Mutex::new(Session::in_memory()));
2511 let mut agent_session = AgentSession::new(
2512 agent,
2513 Arc::clone(&session),
2514 false,
2515 ResolvedCompactionSettings::default(),
2516 );
2517
2518 agent_session
2519 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2520 .await
2521 .expect("enable extensions");
2522
2523 let tool = agent_session
2524 .agent
2525 .tools
2526 .get("emit_message")
2527 .expect("emit_message registered");
2528
2529 let _ = tool
2530 .execute("call-1", json!({}), None)
2531 .await
2532 .expect("execute tool");
2533
2534 let cx = crate::agent_cx::AgentCx::for_request();
2535 let session_guard = session
2536 .lock(cx.cx())
2537 .await
2538 .expect("lock session");
2539 let messages = session_guard.to_messages_for_current_path();
2540
2541 assert!(
2542 messages.iter().any(|msg| {
2543 matches!(
2544 msg,
2545 Message::Custom(CustomMessage { custom_type, content, display, details, .. })
2546 if custom_type == "note"
2547 && content == "hello-after-await"
2548 && *display
2549 && details.as_ref().and_then(|v| v.get("from").and_then(Value::as_str)) == Some("test")
2550 )
2551 }),
2552 "expected custom message to be persisted, got {messages:?}"
2553 );
2554 });
2555 }
2556
2557 #[test]
2558 fn send_user_message_steer_skips_remaining_tools() {
2559 let runtime = RuntimeBuilder::current_thread()
2560 .build()
2561 .expect("runtime build");
2562
2563 runtime.block_on(async {
2564 let temp_dir = tempfile::tempdir().expect("tempdir");
2565 let entry_path = temp_dir.path().join("ext.mjs");
2566 std::fs::write(
2567 &entry_path,
2568 r#"
2569 export default function init(pi) {
2570 let sent = false;
2571 pi.on("tool_call", async (event) => {
2572 if (sent) return {};
2573 if (event && event.toolName === "count_tool") {
2574 sent = true;
2575 await pi.events("sendUserMessage", {
2576 text: "steer-now",
2577 options: { deliverAs: "steer" }
2578 });
2579 }
2580 return {};
2581 });
2582 }
2583 "#,
2584 )
2585 .expect("write extension entry");
2586
2587 let provider = Arc::new(ToolUseProvider::new());
2588 let calls = Arc::new(AtomicUsize::new(0));
2589 let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
2590 calls: Arc::clone(&calls),
2591 })]);
2592 let agent = Agent::new(provider, tools, AgentConfig::default());
2593 let session = Arc::new(Mutex::new(Session::in_memory()));
2594 let mut agent_session =
2595 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2596
2597 agent_session
2598 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2599 .await
2600 .expect("enable extensions");
2601
2602 let _ = agent_session
2603 .run_text("go".to_string(), |_| {})
2604 .await
2605 .expect("run_text");
2606
2607 assert_eq!(calls.load(Ordering::SeqCst), 1);
2609 });
2610 }
2611
2612 #[test]
2613 fn send_user_message_follow_up_does_not_skip_tools() {
2614 let runtime = RuntimeBuilder::current_thread()
2615 .build()
2616 .expect("runtime build");
2617
2618 runtime.block_on(async {
2619 let temp_dir = tempfile::tempdir().expect("tempdir");
2620 let entry_path = temp_dir.path().join("ext.mjs");
2621 std::fs::write(
2622 &entry_path,
2623 r#"
2624 export default function init(pi) {
2625 let sent = false;
2626 pi.on("tool_call", async (event) => {
2627 if (sent) return {};
2628 if (event && event.toolName === "count_tool") {
2629 sent = true;
2630 await pi.events("sendUserMessage", {
2631 text: "follow-up",
2632 options: { deliverAs: "followUp" }
2633 });
2634 }
2635 return {};
2636 });
2637 }
2638 "#,
2639 )
2640 .expect("write extension entry");
2641
2642 let provider = Arc::new(ToolUseProvider::new());
2643 let calls = Arc::new(AtomicUsize::new(0));
2644 let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
2645 calls: Arc::clone(&calls),
2646 })]);
2647 let agent = Agent::new(provider, tools, AgentConfig::default());
2648 let session = Arc::new(Mutex::new(Session::in_memory()));
2649 let mut agent_session =
2650 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2651
2652 agent_session
2653 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2654 .await
2655 .expect("enable extensions");
2656
2657 let _ = agent_session
2658 .run_text("go".to_string(), |_| {})
2659 .await
2660 .expect("run_text");
2661
2662 assert_eq!(calls.load(Ordering::SeqCst), 2);
2663 });
2664 }
2665
2666 #[test]
2667 fn tool_call_hook_can_block_tool_execution() {
2668 let runtime = RuntimeBuilder::current_thread()
2669 .build()
2670 .expect("runtime build");
2671
2672 runtime.block_on(async {
2673 let temp_dir = tempfile::tempdir().expect("tempdir");
2674 let entry_path = temp_dir.path().join("ext.mjs");
2675 std::fs::write(
2676 &entry_path,
2677 r#"
2678 export default function init(pi) {
2679 pi.on("tool_call", async (event) => {
2680 if (event && event.toolName === "count_tool") {
2681 return { block: true, reason: "blocked in test" };
2682 }
2683 return {};
2684 });
2685 }
2686 "#,
2687 )
2688 .expect("write extension entry");
2689
2690 let provider = Arc::new(NoopProvider);
2691 let calls = Arc::new(AtomicUsize::new(0));
2692 let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
2693 calls: Arc::clone(&calls),
2694 })]);
2695 let agent = Agent::new(provider, tools, AgentConfig::default());
2696 let session = Arc::new(Mutex::new(Session::in_memory()));
2697 let mut agent_session =
2698 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2699
2700 agent_session
2701 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2702 .await
2703 .expect("enable extensions");
2704
2705 let tool_call = ToolCall {
2706 id: "call-1".to_string(),
2707 name: "count_tool".to_string(),
2708 arguments: json!({}),
2709 thought_signature: None,
2710 };
2711
2712 let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
2713 let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
2714
2715 assert!(is_error);
2716 assert!(output.is_error);
2717 assert_eq!(calls.load(Ordering::SeqCst), 0);
2718
2719 assert_eq!(output.details, None);
2720 assert!(
2721 matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
2722 "Expected text output, got {:?}",
2723 output.content
2724 );
2725 if let [ContentBlock::Text(text)] = output.content.as_slice() {
2726 assert_eq!(text.text, "Tool execution blocked: blocked in test");
2727 }
2728 });
2729 }
2730
2731 #[test]
2732 fn tool_call_hook_errors_fail_open() {
2733 let runtime = RuntimeBuilder::current_thread()
2734 .build()
2735 .expect("runtime build");
2736
2737 runtime.block_on(async {
2738 let temp_dir = tempfile::tempdir().expect("tempdir");
2739 let entry_path = temp_dir.path().join("ext.mjs");
2740 std::fs::write(
2741 &entry_path,
2742 r#"
2743 export default function init(pi) {
2744 pi.on("tool_call", async (_event) => {
2745 throw new Error("boom");
2746 });
2747 }
2748 "#,
2749 )
2750 .expect("write extension entry");
2751
2752 let provider = Arc::new(NoopProvider);
2753 let calls = Arc::new(AtomicUsize::new(0));
2754 let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
2755 calls: Arc::clone(&calls),
2756 })]);
2757 let agent = Agent::new(provider, tools, AgentConfig::default());
2758 let session = Arc::new(Mutex::new(Session::in_memory()));
2759 let mut agent_session =
2760 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2761
2762 agent_session
2763 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2764 .await
2765 .expect("enable extensions");
2766
2767 let tool_call = ToolCall {
2768 id: "call-1".to_string(),
2769 name: "count_tool".to_string(),
2770 arguments: json!({}),
2771 thought_signature: None,
2772 };
2773
2774 let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
2775 let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
2776
2777 assert!(!is_error);
2778 assert!(!output.is_error);
2779 assert_eq!(calls.load(Ordering::SeqCst), 1);
2780 });
2781 }
2782
2783 #[test]
2784 fn tool_call_hook_absent_allows_tool_execution() {
2785 let runtime = RuntimeBuilder::current_thread()
2786 .build()
2787 .expect("runtime build");
2788
2789 runtime.block_on(async {
2790 let temp_dir = tempfile::tempdir().expect("tempdir");
2791 let entry_path = temp_dir.path().join("ext.mjs");
2792 std::fs::write(
2793 &entry_path,
2794 r"
2795 export default function init(_pi) {}
2796 ",
2797 )
2798 .expect("write extension entry");
2799
2800 let provider = Arc::new(NoopProvider);
2801 let calls = Arc::new(AtomicUsize::new(0));
2802 let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
2803 calls: Arc::clone(&calls),
2804 })]);
2805 let agent = Agent::new(provider, tools, AgentConfig::default());
2806 let session = Arc::new(Mutex::new(Session::in_memory()));
2807 let mut agent_session =
2808 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2809
2810 agent_session
2811 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2812 .await
2813 .expect("enable extensions");
2814
2815 let tool_call = ToolCall {
2816 id: "call-1".to_string(),
2817 name: "count_tool".to_string(),
2818 arguments: json!({}),
2819 thought_signature: None,
2820 };
2821
2822 let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
2823 let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
2824
2825 assert!(!is_error);
2826 assert!(!output.is_error);
2827 assert_eq!(calls.load(Ordering::SeqCst), 1);
2828 });
2829 }
2830
2831 #[test]
2832 fn tool_call_hook_returns_empty_allows_tool_execution() {
2833 let runtime = RuntimeBuilder::current_thread()
2834 .build()
2835 .expect("runtime build");
2836
2837 runtime.block_on(async {
2838 let temp_dir = tempfile::tempdir().expect("tempdir");
2839 let entry_path = temp_dir.path().join("ext.mjs");
2840 std::fs::write(
2841 &entry_path,
2842 r#"
2843 export default function init(pi) {
2844 pi.on("tool_call", async (_event) => ({}));
2845 }
2846 "#,
2847 )
2848 .expect("write extension entry");
2849
2850 let provider = Arc::new(NoopProvider);
2851 let calls = Arc::new(AtomicUsize::new(0));
2852 let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
2853 calls: Arc::clone(&calls),
2854 })]);
2855 let agent = Agent::new(provider, tools, AgentConfig::default());
2856 let session = Arc::new(Mutex::new(Session::in_memory()));
2857 let mut agent_session =
2858 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2859
2860 agent_session
2861 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2862 .await
2863 .expect("enable extensions");
2864
2865 let tool_call = ToolCall {
2866 id: "call-1".to_string(),
2867 name: "count_tool".to_string(),
2868 arguments: json!({}),
2869 thought_signature: None,
2870 };
2871
2872 let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
2873 let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
2874
2875 assert!(!is_error);
2876 assert!(!output.is_error);
2877 assert_eq!(calls.load(Ordering::SeqCst), 1);
2878 });
2879 }
2880
2881 #[test]
2882 fn tool_call_hook_can_block_bash_tool_execution() {
2883 let runtime = RuntimeBuilder::current_thread()
2884 .build()
2885 .expect("runtime build");
2886
2887 runtime.block_on(async {
2888 let temp_dir = tempfile::tempdir().expect("tempdir");
2889 let entry_path = temp_dir.path().join("ext.mjs");
2890 std::fs::write(
2891 &entry_path,
2892 r#"
2893 export default function init(pi) {
2894 pi.on("tool_call", async (event) => {
2895 const name = event && event.toolName ? String(event.toolName) : "";
2896 if (name === "bash") return { block: true, reason: "blocked bash in test" };
2897 return {};
2898 });
2899 }
2900 "#,
2901 )
2902 .expect("write extension entry");
2903
2904 let provider = Arc::new(NoopProvider);
2905 let tools = ToolRegistry::new(&["bash"], temp_dir.path(), None);
2906 let agent = Agent::new(provider, tools, AgentConfig::default());
2907 let session = Arc::new(Mutex::new(Session::in_memory()));
2908 let mut agent_session =
2909 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2910
2911 agent_session
2912 .enable_extensions(&["bash"], temp_dir.path(), None, &[entry_path])
2913 .await
2914 .expect("enable extensions");
2915
2916 let tool_call = ToolCall {
2917 id: "call-1".to_string(),
2918 name: "bash".to_string(),
2919 arguments: json!({ "command": "printf 'hi' > blocked.txt" }),
2920 thought_signature: None,
2921 };
2922
2923 let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
2924 let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
2925
2926 assert!(is_error);
2927 assert!(output.is_error);
2928 assert_eq!(output.details, None);
2929 assert!(
2930 !temp_dir.path().join("blocked.txt").exists(),
2931 "expected bash command not to run when blocked"
2932 );
2933 assert!(
2934 matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
2935 "Expected text output, got {:?}",
2936 output.content
2937 );
2938 if let [ContentBlock::Text(text)] = output.content.as_slice() {
2939 assert_eq!(text.text, "Tool execution blocked: blocked bash in test");
2940 }
2941 });
2942 }
2943
2944 #[test]
2945 fn tool_result_hook_can_modify_tool_output() {
2946 let runtime = RuntimeBuilder::current_thread()
2947 .build()
2948 .expect("runtime build");
2949
2950 runtime.block_on(async {
2951 let temp_dir = tempfile::tempdir().expect("tempdir");
2952 let entry_path = temp_dir.path().join("ext.mjs");
2953 std::fs::write(
2954 &entry_path,
2955 r#"
2956 export default function init(pi) {
2957 pi.on("tool_result", async (event) => {
2958 if (event && event.toolName === "count_tool") {
2959 return {
2960 content: [{ type: "text", text: "modified" }],
2961 details: { from: "tool_result" }
2962 };
2963 }
2964 return {};
2965 });
2966 }
2967 "#,
2968 )
2969 .expect("write extension entry");
2970
2971 let provider = Arc::new(NoopProvider);
2972 let calls = Arc::new(AtomicUsize::new(0));
2973 let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
2974 calls: Arc::clone(&calls),
2975 })]);
2976 let agent = Agent::new(provider, tools, AgentConfig::default());
2977 let session = Arc::new(Mutex::new(Session::in_memory()));
2978 let mut agent_session =
2979 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2980
2981 agent_session
2982 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2983 .await
2984 .expect("enable extensions");
2985
2986 let tool_call = ToolCall {
2987 id: "call-1".to_string(),
2988 name: "count_tool".to_string(),
2989 arguments: json!({}),
2990 thought_signature: None,
2991 };
2992
2993 let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
2994 let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
2995
2996 assert!(!is_error);
2997 assert!(!output.is_error);
2998 assert_eq!(calls.load(Ordering::SeqCst), 1);
2999 assert_eq!(output.details, Some(json!({ "from": "tool_result" })));
3000
3001 assert!(
3002 matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
3003 "Expected text output, got {:?}",
3004 output.content
3005 );
3006 if let [ContentBlock::Text(text)] = output.content.as_slice() {
3007 assert_eq!(text.text, "modified");
3008 }
3009 });
3010 }
3011
3012 #[test]
3013 fn tool_result_hook_can_modify_tool_not_found_error() {
3014 let runtime = RuntimeBuilder::current_thread()
3015 .build()
3016 .expect("runtime build");
3017
3018 runtime.block_on(async {
3019 let temp_dir = tempfile::tempdir().expect("tempdir");
3020 let entry_path = temp_dir.path().join("ext.mjs");
3021 std::fs::write(
3022 &entry_path,
3023 r#"
3024 export default function init(pi) {
3025 pi.on("tool_result", async (event) => {
3026 if (event && event.toolName === "missing_tool" && event.isError) {
3027 return {
3028 content: [{ type: "text", text: "overridden" }],
3029 details: { handled: true }
3030 };
3031 }
3032 return {};
3033 });
3034 }
3035 "#,
3036 )
3037 .expect("write extension entry");
3038
3039 let provider = Arc::new(NoopProvider);
3040 let tools = ToolRegistry::from_tools(Vec::new());
3041 let agent = Agent::new(provider, tools, AgentConfig::default());
3042 let session = Arc::new(Mutex::new(Session::in_memory()));
3043 let mut agent_session =
3044 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
3045
3046 agent_session
3047 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
3048 .await
3049 .expect("enable extensions");
3050
3051 let tool_call = ToolCall {
3052 id: "call-1".to_string(),
3053 name: "missing_tool".to_string(),
3054 arguments: json!({}),
3055 thought_signature: None,
3056 };
3057
3058 let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
3059 let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
3060
3061 assert!(is_error);
3062 assert!(output.is_error);
3063 assert_eq!(output.details, Some(json!({ "handled": true })));
3064
3065 assert!(
3066 matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
3067 "Expected text output, got {:?}",
3068 output.content
3069 );
3070 if let [ContentBlock::Text(text)] = output.content.as_slice() {
3071 assert_eq!(text.text, "overridden");
3072 }
3073 });
3074 }
3075
3076 #[test]
3077 fn tool_result_hook_errors_fail_open() {
3078 let runtime = RuntimeBuilder::current_thread()
3079 .build()
3080 .expect("runtime build");
3081
3082 runtime.block_on(async {
3083 let temp_dir = tempfile::tempdir().expect("tempdir");
3084 let entry_path = temp_dir.path().join("ext.mjs");
3085 std::fs::write(
3086 &entry_path,
3087 r#"
3088 export default function init(pi) {
3089 pi.on("tool_result", async (_event) => {
3090 throw new Error("boom");
3091 });
3092 }
3093 "#,
3094 )
3095 .expect("write extension entry");
3096
3097 let provider = Arc::new(NoopProvider);
3098 let calls = Arc::new(AtomicUsize::new(0));
3099 let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
3100 calls: Arc::clone(&calls),
3101 })]);
3102 let agent = Agent::new(provider, tools, AgentConfig::default());
3103 let session = Arc::new(Mutex::new(Session::in_memory()));
3104 let mut agent_session =
3105 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
3106
3107 agent_session
3108 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
3109 .await
3110 .expect("enable extensions");
3111
3112 let tool_call = ToolCall {
3113 id: "call-1".to_string(),
3114 name: "count_tool".to_string(),
3115 arguments: json!({}),
3116 thought_signature: None,
3117 };
3118
3119 let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
3120 let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
3121
3122 assert!(!is_error);
3123 assert!(!output.is_error);
3124 assert_eq!(calls.load(Ordering::SeqCst), 1);
3125
3126 assert_eq!(output.details, None);
3127 assert!(
3128 matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
3129 "Expected text output, got {:?}",
3130 output.content
3131 );
3132 if let [ContentBlock::Text(text)] = output.content.as_slice() {
3133 assert_eq!(text.text, "ok");
3134 }
3135 });
3136 }
3137
3138 #[test]
3139 fn tool_result_hook_runs_on_blocked_tool_call() {
3140 let runtime = RuntimeBuilder::current_thread()
3141 .build()
3142 .expect("runtime build");
3143
3144 runtime.block_on(async {
3145 let temp_dir = tempfile::tempdir().expect("tempdir");
3146 let entry_path = temp_dir.path().join("ext.mjs");
3147 std::fs::write(
3148 &entry_path,
3149 r#"
3150 export default function init(pi) {
3151 pi.on("tool_call", async (event) => {
3152 if (event && event.toolName === "count_tool") {
3153 return { block: true, reason: "blocked in test" };
3154 }
3155 return {};
3156 });
3157
3158 pi.on("tool_result", async (event) => {
3159 if (event && event.toolName === "count_tool" && event.isError) {
3160 return { content: [{ type: "text", text: "override" }] };
3161 }
3162 return {};
3163 });
3164 }
3165 "#,
3166 )
3167 .expect("write extension entry");
3168
3169 let provider = Arc::new(NoopProvider);
3170 let calls = Arc::new(AtomicUsize::new(0));
3171 let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
3172 calls: Arc::clone(&calls),
3173 })]);
3174 let agent = Agent::new(provider, tools, AgentConfig::default());
3175 let session = Arc::new(Mutex::new(Session::in_memory()));
3176 let mut agent_session =
3177 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
3178
3179 agent_session
3180 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
3181 .await
3182 .expect("enable extensions");
3183
3184 let tool_call = ToolCall {
3185 id: "call-1".to_string(),
3186 name: "count_tool".to_string(),
3187 arguments: json!({}),
3188 thought_signature: None,
3189 };
3190
3191 let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
3192 let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
3193
3194 assert!(is_error);
3195 assert!(output.is_error);
3196 assert_eq!(calls.load(Ordering::SeqCst), 0);
3197
3198 assert!(
3199 matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
3200 "Expected text output, got {:?}",
3201 output.content
3202 );
3203 if let [ContentBlock::Text(text)] = output.content.as_slice() {
3204 assert_eq!(text.text, "override");
3205 }
3206 });
3207 }
3208}
3209
3210#[cfg(test)]
3211mod abort_tests {
3212 use super::*;
3213 use crate::session::Session;
3214 use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
3215 use asupersync::runtime::RuntimeBuilder;
3216 use async_trait::async_trait;
3217 use futures::Stream;
3218 use serde_json::json;
3219 use std::path::Path;
3220 use std::pin::Pin;
3221 use std::sync::Mutex as StdMutex;
3222 use std::sync::atomic::AtomicUsize;
3223 use std::task::{Context as TaskContext, Poll};
3224
3225 struct StartThenPending {
3226 start: Option<StreamEvent>,
3227 }
3228
3229 impl Stream for StartThenPending {
3230 type Item = crate::error::Result<StreamEvent>;
3231
3232 fn poll_next(
3233 mut self: Pin<&mut Self>,
3234 _cx: &mut TaskContext<'_>,
3235 ) -> Poll<Option<Self::Item>> {
3236 if let Some(event) = self.start.take() {
3237 return Poll::Ready(Some(Ok(event)));
3238 }
3239 Poll::Pending
3240 }
3241 }
3242
3243 #[derive(Debug)]
3244 struct HangingProvider;
3245
3246 #[async_trait]
3247 #[allow(clippy::unnecessary_literal_bound)]
3248 impl Provider for HangingProvider {
3249 fn name(&self) -> &str {
3250 "test-provider"
3251 }
3252
3253 fn api(&self) -> &str {
3254 "test-api"
3255 }
3256
3257 fn model_id(&self) -> &str {
3258 "test-model"
3259 }
3260
3261 async fn stream(
3262 &self,
3263 _context: &Context<'_>,
3264 _options: &StreamOptions,
3265 ) -> crate::error::Result<
3266 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
3267 > {
3268 let partial = AssistantMessage {
3269 content: Vec::new(),
3270 api: self.api().to_string(),
3271 provider: self.name().to_string(),
3272 model: self.model_id().to_string(),
3273 usage: Usage::default(),
3274 stop_reason: StopReason::Stop,
3275 error_message: None,
3276 timestamp: 0,
3277 };
3278
3279 Ok(Box::pin(StartThenPending {
3280 start: Some(StreamEvent::Start { partial }),
3281 }))
3282 }
3283 }
3284
3285 #[derive(Debug)]
3286 struct CountingProvider {
3287 calls: Arc<std::sync::atomic::AtomicUsize>,
3288 }
3289
3290 #[async_trait]
3291 #[allow(clippy::unnecessary_literal_bound)]
3292 impl Provider for CountingProvider {
3293 fn name(&self) -> &str {
3294 "test-provider"
3295 }
3296
3297 fn api(&self) -> &str {
3298 "test-api"
3299 }
3300
3301 fn model_id(&self) -> &str {
3302 "test-model"
3303 }
3304
3305 async fn stream(
3306 &self,
3307 _context: &Context<'_>,
3308 _options: &StreamOptions,
3309 ) -> crate::error::Result<
3310 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
3311 > {
3312 self.calls.fetch_add(1, Ordering::SeqCst);
3313 Ok(Box::pin(futures::stream::empty()))
3314 }
3315 }
3316
3317 #[derive(Debug)]
3318 struct PhasedProvider {
3319 pending_calls: usize,
3320 calls: AtomicUsize,
3321 }
3322
3323 impl PhasedProvider {
3324 const fn new(pending_calls: usize) -> Self {
3325 Self {
3326 pending_calls,
3327 calls: AtomicUsize::new(0),
3328 }
3329 }
3330
3331 fn base_message() -> AssistantMessage {
3332 AssistantMessage {
3333 content: Vec::new(),
3334 api: "test-api".to_string(),
3335 provider: "test-provider".to_string(),
3336 model: "test-model".to_string(),
3337 usage: Usage::default(),
3338 stop_reason: StopReason::Stop,
3339 error_message: None,
3340 timestamp: 0,
3341 }
3342 }
3343 }
3344
3345 #[async_trait]
3346 #[allow(clippy::unnecessary_literal_bound)]
3347 impl Provider for PhasedProvider {
3348 fn name(&self) -> &str {
3349 "test-provider"
3350 }
3351
3352 fn api(&self) -> &str {
3353 "test-api"
3354 }
3355
3356 fn model_id(&self) -> &str {
3357 "test-model"
3358 }
3359
3360 async fn stream(
3361 &self,
3362 _context: &Context<'_>,
3363 _options: &StreamOptions,
3364 ) -> crate::error::Result<
3365 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
3366 > {
3367 let call = self.calls.fetch_add(1, Ordering::SeqCst);
3368 if call < self.pending_calls {
3369 return Ok(Box::pin(StartThenPending {
3370 start: Some(StreamEvent::Start {
3371 partial: Self::base_message(),
3372 }),
3373 }));
3374 }
3375
3376 let partial = Self::base_message();
3377 let mut done = Self::base_message();
3378 done.content = vec![ContentBlock::Text(TextContent::new(format!(
3379 "resumed-response-{call}"
3380 )))];
3381
3382 Ok(Box::pin(futures::stream::iter(vec![
3383 Ok(StreamEvent::Start { partial }),
3384 Ok(StreamEvent::Done {
3385 reason: StopReason::Stop,
3386 message: done,
3387 }),
3388 ])))
3389 }
3390 }
3391
3392 #[derive(Debug)]
3393 struct ToolCallProvider;
3394
3395 #[async_trait]
3396 #[allow(clippy::unnecessary_literal_bound)]
3397 impl Provider for ToolCallProvider {
3398 fn name(&self) -> &str {
3399 "test-provider"
3400 }
3401
3402 fn api(&self) -> &str {
3403 "test-api"
3404 }
3405
3406 fn model_id(&self) -> &str {
3407 "test-model"
3408 }
3409
3410 async fn stream(
3411 &self,
3412 _context: &Context<'_>,
3413 _options: &StreamOptions,
3414 ) -> crate::error::Result<
3415 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
3416 > {
3417 let message = AssistantMessage {
3418 content: vec![ContentBlock::ToolCall(ToolCall {
3419 id: "call-1".to_string(),
3420 name: "hanging_tool".to_string(),
3421 arguments: json!({}),
3422 thought_signature: None,
3423 })],
3424 api: "test-api".to_string(),
3425 provider: "test-provider".to_string(),
3426 model: "test-model".to_string(),
3427 usage: Usage::default(),
3428 stop_reason: StopReason::ToolUse,
3429 error_message: None,
3430 timestamp: 0,
3431 };
3432
3433 Ok(Box::pin(futures::stream::iter(vec![Ok(
3434 StreamEvent::Done {
3435 reason: StopReason::ToolUse,
3436 message,
3437 },
3438 )])))
3439 }
3440 }
3441
3442 #[derive(Debug)]
3443 struct HangingTool;
3444
3445 #[async_trait]
3446 #[allow(clippy::unnecessary_literal_bound)]
3447 impl Tool for HangingTool {
3448 fn name(&self) -> &str {
3449 "hanging_tool"
3450 }
3451
3452 fn label(&self) -> &str {
3453 "Hanging Tool"
3454 }
3455
3456 fn description(&self) -> &str {
3457 "Never completes unless aborted by the host"
3458 }
3459
3460 fn parameters(&self) -> serde_json::Value {
3461 json!({
3462 "type": "object",
3463 "properties": {},
3464 "additionalProperties": false
3465 })
3466 }
3467
3468 async fn execute(
3469 &self,
3470 _tool_call_id: &str,
3471 _input: serde_json::Value,
3472 _on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
3473 ) -> crate::error::Result<ToolOutput> {
3474 futures::future::pending::<()>().await;
3475 unreachable!("hanging tool should be aborted by the agent")
3476 }
3477 }
3478
3479 fn event_tag(event: &AgentEvent) -> &'static str {
3480 match event {
3481 AgentEvent::AgentStart { .. } => "agent_start",
3482 AgentEvent::AgentEnd { error, .. } => {
3483 if error.as_deref() == Some("Aborted") {
3484 "agent_end_aborted"
3485 } else {
3486 "agent_end"
3487 }
3488 }
3489 AgentEvent::TurnStart { .. } => "turn_start",
3490 AgentEvent::TurnEnd { .. } => "turn_end",
3491 AgentEvent::MessageStart { .. } => "message_start",
3492 AgentEvent::MessageUpdate {
3493 assistant_message_event,
3494 ..
3495 } => match &assistant_message_event {
3496 AssistantMessageEvent::Error {
3497 reason: StopReason::Aborted,
3498 ..
3499 } => "assistant_error_aborted",
3500 AssistantMessageEvent::Done { .. } => "assistant_done",
3501 _ => "assistant_update",
3502 },
3503 AgentEvent::MessageEnd { .. } => "message_end",
3504 AgentEvent::ToolExecutionStart { .. } => "tool_start",
3505 AgentEvent::ToolExecutionUpdate { .. } => "tool_update",
3506 AgentEvent::ToolExecutionEnd { .. } => "tool_end",
3507 AgentEvent::AutoCompactionStart { .. } => "auto_compaction_start",
3508 AgentEvent::AutoCompactionEnd { .. } => "auto_compaction_end",
3509 AgentEvent::AutoRetryStart { .. } => "auto_retry_start",
3510 AgentEvent::AutoRetryEnd { .. } => "auto_retry_end",
3511 AgentEvent::ExtensionError { .. } => "extension_error",
3512 }
3513 }
3514
3515 fn assert_abort_resume_message_sequence(persisted: &[Message]) {
3516 assert_eq!(
3517 persisted.len(),
3518 6,
3519 "expected three user+assistant pairs, got: {persisted:?}"
3520 );
3521
3522 let assistant_states = persisted
3523 .iter()
3524 .filter_map(|message| match message {
3525 Message::Assistant(assistant) => Some(assistant.stop_reason),
3526 _ => None,
3527 })
3528 .collect::<Vec<_>>();
3529 assert_eq!(
3530 assistant_states,
3531 vec![StopReason::Aborted, StopReason::Aborted, StopReason::Stop]
3532 );
3533 }
3534
3535 fn assert_abort_resume_timeline_boundaries(timeline: &[String]) {
3536 assert!(
3537 timeline
3538 .iter()
3539 .any(|event| event == "run0:agent_end_aborted"),
3540 "missing aborted boundary for first run: {timeline:?}"
3541 );
3542 assert!(
3543 timeline
3544 .iter()
3545 .any(|event| event == "run1:agent_end_aborted"),
3546 "missing aborted boundary for second run: {timeline:?}"
3547 );
3548 assert!(
3549 timeline.iter().any(|event| event == "run2:agent_end"),
3550 "missing successful boundary for resumed run: {timeline:?}"
3551 );
3552 }
3553
3554 #[test]
3555 fn abort_interrupts_in_flight_stream() {
3556 let runtime = RuntimeBuilder::current_thread()
3557 .build()
3558 .expect("runtime build");
3559 let handle = runtime.handle();
3560
3561 let started = Arc::new(Notify::new());
3562 let started_wait = started.notified();
3563
3564 let (abort_handle, abort_signal) = AbortHandle::new();
3565
3566 let provider = Arc::new(HangingProvider);
3567 let tools = ToolRegistry::new(&[], Path::new("."), None);
3568 let agent = Agent::new(provider, tools, AgentConfig::default());
3569 let session = Arc::new(Mutex::new(Session::in_memory()));
3570 let mut agent_session =
3571 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
3572
3573 let started_tx = Arc::clone(&started);
3574 let join = handle.spawn(async move {
3575 agent_session
3576 .run_text_with_abort("hello".to_string(), Some(abort_signal), move |event| {
3577 if matches!(
3578 event,
3579 AgentEvent::MessageStart {
3580 message: Message::Assistant(_)
3581 }
3582 ) {
3583 started_tx.notify_one();
3584 }
3585 })
3586 .await
3587 });
3588
3589 runtime.block_on(async move {
3590 started_wait.await;
3591 abort_handle.abort();
3592
3593 let message = join.await.expect("run_text_with_abort");
3594 assert_eq!(message.stop_reason, StopReason::Aborted);
3595 assert_eq!(message.error_message.as_deref(), Some("Aborted"));
3596 });
3597 }
3598
3599 #[test]
3600 fn abort_before_run_skips_provider_stream_call() {
3601 let runtime = RuntimeBuilder::current_thread()
3602 .build()
3603 .expect("runtime build");
3604
3605 let calls = Arc::new(std::sync::atomic::AtomicUsize::new(0));
3606 let provider = Arc::new(CountingProvider {
3607 calls: Arc::clone(&calls),
3608 });
3609 let tools = ToolRegistry::new(&[], Path::new("."), None);
3610 let agent = Agent::new(provider, tools, AgentConfig::default());
3611 let session = Arc::new(Mutex::new(Session::in_memory()));
3612 let mut agent_session =
3613 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
3614
3615 let (abort_handle, abort_signal) = AbortHandle::new();
3616 abort_handle.abort();
3617
3618 runtime.block_on(async move {
3619 let message = agent_session
3620 .run_text_with_abort("hello".to_string(), Some(abort_signal), |_| {})
3621 .await
3622 .expect("run_text_with_abort");
3623 assert_eq!(message.stop_reason, StopReason::Aborted);
3624 assert_eq!(calls.load(Ordering::SeqCst), 0);
3625 });
3626 }
3627
3628 #[test]
3629 fn abort_then_resume_preserves_session_history() {
3630 let runtime = RuntimeBuilder::current_thread()
3631 .build()
3632 .expect("runtime build");
3633 let handle = runtime.handle();
3634
3635 runtime.block_on(async move {
3636 let provider = Arc::new(PhasedProvider::new(1));
3637 let tools = ToolRegistry::new(&[], Path::new("."), None);
3638 let agent = Agent::new(provider, tools, AgentConfig::default());
3639 let session = Arc::new(Mutex::new(Session::in_memory()));
3640 let mut agent_session = AgentSession::new(
3641 agent,
3642 Arc::clone(&session),
3643 false,
3644 ResolvedCompactionSettings::default(),
3645 );
3646
3647 let started = Arc::new(Notify::new());
3648 let (abort_handle, abort_signal) = AbortHandle::new();
3649 let started_for_abort = Arc::clone(&started);
3650 let abort_join = handle.spawn(async move {
3651 started_for_abort.notified().await;
3652 abort_handle.abort();
3653 });
3654
3655 let aborted = agent_session
3656 .run_text_with_abort("first".to_string(), Some(abort_signal), {
3657 let started = Arc::clone(&started);
3658 move |event| {
3659 if matches!(
3660 event,
3661 AgentEvent::MessageStart {
3662 message: Message::Assistant(_)
3663 }
3664 ) {
3665 started.notify_one();
3666 }
3667 }
3668 })
3669 .await
3670 .expect("first run");
3671 abort_join.await;
3672
3673 assert_eq!(aborted.stop_reason, StopReason::Aborted);
3674 assert_eq!(aborted.error_message.as_deref(), Some("Aborted"));
3675
3676 let resumed = agent_session
3677 .run_text("second".to_string(), |_| {})
3678 .await
3679 .expect("resumed run");
3680 assert_eq!(resumed.stop_reason, StopReason::Stop);
3681 assert!(resumed.error_message.is_none());
3682
3683 let cx = crate::agent_cx::AgentCx::for_request();
3684 let persisted = session
3685 .lock(cx.cx())
3686 .await
3687 .expect("lock session")
3688 .to_messages_for_current_path();
3689
3690 assert_eq!(
3691 persisted.len(),
3692 4,
3693 "unexpected message history after abort+resume: {persisted:?}"
3694 );
3695 assert!(matches!(persisted.first(), Some(Message::User(_))));
3696 assert!(matches!(
3697 persisted.get(1),
3698 Some(Message::Assistant(assistant)) if assistant.stop_reason == StopReason::Aborted
3699 ));
3700 assert!(matches!(persisted.get(2), Some(Message::User(_))));
3701 assert!(matches!(
3702 persisted.get(3),
3703 Some(Message::Assistant(assistant))
3704 if assistant.stop_reason == StopReason::Stop && assistant.error_message.is_none()
3705 ));
3706 });
3707 }
3708
3709 #[test]
3710 fn repeated_abort_then_resume_has_consistent_timeline_and_state() {
3711 let runtime = RuntimeBuilder::current_thread()
3712 .build()
3713 .expect("runtime build");
3714 let handle = runtime.handle();
3715
3716 runtime.block_on(async move {
3717 let provider = Arc::new(PhasedProvider::new(2));
3718 let tools = ToolRegistry::new(&[], Path::new("."), None);
3719 let agent = Agent::new(provider, tools, AgentConfig::default());
3720 let session = Arc::new(Mutex::new(Session::in_memory()));
3721 let mut agent_session = AgentSession::new(
3722 agent,
3723 Arc::clone(&session),
3724 false,
3725 ResolvedCompactionSettings::default(),
3726 );
3727
3728 let timeline = Arc::new(StdMutex::new(Vec::<String>::new()));
3729
3730 for run_idx in 0..2 {
3731 let started = Arc::new(Notify::new());
3732 let (abort_handle, abort_signal) = AbortHandle::new();
3733 let started_for_abort = Arc::clone(&started);
3734 let abort_join = handle.spawn(async move {
3735 started_for_abort.notified().await;
3736 abort_handle.abort();
3737 });
3738
3739 let run_timeline = Arc::clone(&timeline);
3740 let aborted = agent_session
3741 .run_text_with_abort(format!("abort-run-{run_idx}"), Some(abort_signal), {
3742 let started = Arc::clone(&started);
3743 move |event| {
3744 if let Ok(mut events) = run_timeline.lock() {
3745 events.push(format!("run{run_idx}:{}", event_tag(&event)));
3746 }
3747 if matches!(
3748 event,
3749 AgentEvent::MessageStart {
3750 message: Message::Assistant(_)
3751 }
3752 ) {
3753 started.notify_one();
3754 }
3755 }
3756 })
3757 .await
3758 .expect("aborted run");
3759 abort_join.await;
3760
3761 assert_eq!(
3762 aborted.stop_reason,
3763 StopReason::Aborted,
3764 "run {run_idx} should abort cleanly"
3765 );
3766 }
3767
3768 let run_timeline = Arc::clone(&timeline);
3769 let resumed = agent_session
3770 .run_text("final-run".to_string(), move |event| {
3771 if let Ok(mut events) = run_timeline.lock() {
3772 events.push(format!("run2:{}", event_tag(&event)));
3773 }
3774 })
3775 .await
3776 .expect("final resumed run");
3777 assert_eq!(resumed.stop_reason, StopReason::Stop);
3778 assert!(resumed.error_message.is_none());
3779
3780 let cx = crate::agent_cx::AgentCx::for_request();
3781 let persisted = session
3782 .lock(cx.cx())
3783 .await
3784 .expect("lock session")
3785 .to_messages_for_current_path();
3786
3787 assert_abort_resume_message_sequence(&persisted);
3788
3789 let timeline = timeline.lock().expect("timeline lock").clone();
3790 assert_abort_resume_timeline_boundaries(&timeline);
3791 });
3792 }
3793
3794 #[test]
3795 fn abort_during_tool_execution_records_aborted_tool_result() {
3796 let runtime = RuntimeBuilder::current_thread()
3797 .build()
3798 .expect("runtime build");
3799 let handle = runtime.handle();
3800
3801 runtime.block_on(async move {
3802 let provider = Arc::new(ToolCallProvider);
3803 let tools = ToolRegistry::from_tools(vec![Box::new(HangingTool)]);
3804 let agent = Agent::new(provider, tools, AgentConfig::default());
3805 let session = Arc::new(Mutex::new(Session::in_memory()));
3806 let mut agent_session = AgentSession::new(
3807 agent,
3808 Arc::clone(&session),
3809 false,
3810 ResolvedCompactionSettings::default(),
3811 );
3812
3813 let tool_started = Arc::new(Notify::new());
3814 let (abort_handle, abort_signal) = AbortHandle::new();
3815 let tool_started_for_abort = Arc::clone(&tool_started);
3816 let abort_join = handle.spawn(async move {
3817 tool_started_for_abort.notified().await;
3818 abort_handle.abort();
3819 });
3820
3821 let result = agent_session
3822 .run_text_with_abort("trigger tool".to_string(), Some(abort_signal), {
3823 let tool_started = Arc::clone(&tool_started);
3824 move |event| {
3825 if matches!(event, AgentEvent::ToolExecutionStart { .. }) {
3826 tool_started.notify_one();
3827 }
3828 }
3829 })
3830 .await
3831 .expect("tool-abort run");
3832 abort_join.await;
3833 assert_eq!(result.stop_reason, StopReason::Aborted);
3834
3835 let cx = crate::agent_cx::AgentCx::for_request();
3836 let persisted = session
3837 .lock(cx.cx())
3838 .await
3839 .expect("lock session")
3840 .to_messages_for_current_path();
3841
3842 let tool_result = persisted
3843 .iter()
3844 .find_map(|message| match message {
3845 Message::ToolResult(result) => Some(result),
3846 _ => None,
3847 })
3848 .expect("expected tool result message");
3849 assert!(tool_result.is_error);
3850 assert!(
3851 tool_result.content.iter().any(|block| {
3852 matches!(
3853 block,
3854 ContentBlock::Text(text) if text.text.contains("Tool execution aborted")
3855 )
3856 }),
3857 "missing aborted tool marker in tool output: {:?}",
3858 tool_result.content
3859 );
3860 });
3861 }
3862}
3863
3864#[cfg(test)]
3865mod turn_event_tests {
3866 use super::*;
3867 use crate::session::Session;
3868 use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
3869 use asupersync::runtime::RuntimeBuilder;
3870 use async_trait::async_trait;
3871 use futures::Stream;
3872 use serde_json::json;
3873 use std::path::Path;
3874 use std::pin::Pin;
3875 use std::sync::atomic::AtomicUsize;
3876 fn assistant_message(text: &str) -> AssistantMessage {
3880 AssistantMessage {
3881 content: vec![ContentBlock::Text(TextContent::new(text))],
3882 api: "test-api".to_string(),
3883 provider: "test-provider".to_string(),
3884 model: "test-model".to_string(),
3885 usage: Usage::default(),
3886 stop_reason: StopReason::Stop,
3887 error_message: None,
3888 timestamp: 0,
3889 }
3890 }
3891
3892 struct SingleShotProvider;
3893
3894 #[async_trait]
3895 #[allow(clippy::unnecessary_literal_bound)]
3896 impl Provider for SingleShotProvider {
3897 fn name(&self) -> &str {
3898 "test-provider"
3899 }
3900
3901 fn api(&self) -> &str {
3902 "test-api"
3903 }
3904
3905 fn model_id(&self) -> &str {
3906 "test-model"
3907 }
3908
3909 async fn stream(
3910 &self,
3911 _context: &Context<'_>,
3912 _options: &StreamOptions,
3913 ) -> crate::error::Result<
3914 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
3915 > {
3916 let partial = assistant_message("");
3917 let final_message = assistant_message("hello");
3918 let events = vec![
3919 Ok(StreamEvent::Start { partial }),
3920 Ok(StreamEvent::Done {
3921 reason: StopReason::Stop,
3922 message: final_message,
3923 }),
3924 ];
3925 Ok(Box::pin(futures::stream::iter(events)))
3926 }
3927 }
3928
3929 #[derive(Debug)]
3930 struct EchoTool;
3931
3932 #[async_trait]
3933 #[allow(clippy::unnecessary_literal_bound)]
3934 impl Tool for EchoTool {
3935 fn name(&self) -> &str {
3936 "echo_tool"
3937 }
3938
3939 fn label(&self) -> &str {
3940 "echo_tool"
3941 }
3942
3943 fn description(&self) -> &str {
3944 "echo test tool"
3945 }
3946
3947 fn parameters(&self) -> serde_json::Value {
3948 json!({ "type": "object" })
3949 }
3950
3951 async fn execute(
3952 &self,
3953 _tool_call_id: &str,
3954 _input: serde_json::Value,
3955 _on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
3956 ) -> Result<ToolOutput> {
3957 Ok(ToolOutput {
3958 content: vec![ContentBlock::Text(TextContent::new("tool-ok"))],
3959 details: None,
3960 is_error: false,
3961 })
3962 }
3963 }
3964
3965 #[derive(Debug)]
3966 struct ToolTurnProvider {
3967 calls: AtomicUsize,
3968 }
3969
3970 impl ToolTurnProvider {
3971 const fn new() -> Self {
3972 Self {
3973 calls: AtomicUsize::new(0),
3974 }
3975 }
3976
3977 fn assistant_message_with(
3978 &self,
3979 stop_reason: StopReason,
3980 content: Vec<ContentBlock>,
3981 ) -> AssistantMessage {
3982 AssistantMessage {
3983 content,
3984 api: self.api().to_string(),
3985 provider: self.name().to_string(),
3986 model: self.model_id().to_string(),
3987 usage: Usage::default(),
3988 stop_reason,
3989 error_message: None,
3990 timestamp: 0,
3991 }
3992 }
3993 }
3994
3995 #[async_trait]
3996 #[allow(clippy::unnecessary_literal_bound)]
3997 impl Provider for ToolTurnProvider {
3998 fn name(&self) -> &str {
3999 "test-provider"
4000 }
4001
4002 fn api(&self) -> &str {
4003 "test-api"
4004 }
4005
4006 fn model_id(&self) -> &str {
4007 "test-model"
4008 }
4009
4010 async fn stream(
4011 &self,
4012 _context: &Context<'_>,
4013 _options: &StreamOptions,
4014 ) -> crate::error::Result<
4015 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
4016 > {
4017 let call_index = self.calls.fetch_add(1, Ordering::SeqCst);
4018 let partial = self.assistant_message_with(StopReason::Stop, Vec::new());
4019 let done = if call_index == 0 {
4020 self.assistant_message_with(
4021 StopReason::ToolUse,
4022 vec![ContentBlock::ToolCall(ToolCall {
4023 id: "tool-1".to_string(),
4024 name: "echo_tool".to_string(),
4025 arguments: json!({}),
4026 thought_signature: None,
4027 })],
4028 )
4029 } else {
4030 self.assistant_message_with(
4031 StopReason::Stop,
4032 vec![ContentBlock::Text(TextContent::new("final"))],
4033 )
4034 };
4035
4036 Ok(Box::pin(futures::stream::iter(vec![
4037 Ok(StreamEvent::Start { partial }),
4038 Ok(StreamEvent::Done {
4039 reason: done.stop_reason,
4040 message: done,
4041 }),
4042 ])))
4043 }
4044 }
4045
4046 #[test]
4047 fn turn_events_wrap_assistant_response() {
4048 let runtime = RuntimeBuilder::current_thread()
4049 .build()
4050 .expect("runtime build");
4051 let handle = runtime.handle();
4052
4053 let provider = Arc::new(SingleShotProvider);
4054 let tools = ToolRegistry::new(&[], Path::new("."), None);
4055 let agent = Agent::new(provider, tools, AgentConfig::default());
4056 let session = Arc::new(Mutex::new(Session::in_memory()));
4057 let mut agent_session =
4058 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
4059
4060 let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
4061 Arc::new(std::sync::Mutex::new(Vec::new()));
4062 let events_capture = Arc::clone(&events);
4063
4064 let join = handle.spawn(async move {
4065 agent_session
4066 .run_text("hello".to_string(), move |event| {
4067 events_capture.lock().unwrap().push(event);
4068 })
4069 .await
4070 .expect("run_text")
4071 });
4072
4073 runtime.block_on(async move {
4074 let message = join.await;
4075 assert_eq!(message.stop_reason, StopReason::Stop);
4076
4077 let events = events.lock().unwrap();
4078 let turn_start_indices = events
4079 .iter()
4080 .enumerate()
4081 .filter_map(|(idx, event)| {
4082 matches!(event, AgentEvent::TurnStart { .. }).then_some(idx)
4083 })
4084 .collect::<Vec<_>>();
4085 let turn_end_indices = events
4086 .iter()
4087 .enumerate()
4088 .filter_map(|(idx, event)| {
4089 matches!(event, AgentEvent::TurnEnd { .. }).then_some(idx)
4090 })
4091 .collect::<Vec<_>>();
4092
4093 assert_eq!(turn_start_indices.len(), 1);
4094 assert_eq!(turn_end_indices.len(), 1);
4095 assert!(turn_start_indices[0] < turn_end_indices[0]);
4096
4097 let assistant_message_end = events
4098 .iter()
4099 .enumerate()
4100 .find_map(|(idx, event)| match event {
4101 AgentEvent::MessageEnd {
4102 message: Message::Assistant(_),
4103 } => Some(idx),
4104 _ => None,
4105 })
4106 .expect("assistant message end");
4107
4108 assert!(assistant_message_end < turn_end_indices[0]);
4109
4110 let (message_is_assistant, tool_results_empty) = {
4111 let turn_end_event = &events[turn_end_indices[0]];
4112 assert!(
4113 matches!(turn_end_event, AgentEvent::TurnEnd { .. }),
4114 "Expected TurnEnd event, got {turn_end_event:?}"
4115 );
4116 match turn_end_event {
4117 AgentEvent::TurnEnd {
4118 message,
4119 tool_results,
4120 ..
4121 } => (
4122 matches!(message, Message::Assistant(_)),
4123 tool_results.is_empty(),
4124 ),
4125 _ => (false, false),
4126 }
4127 };
4128 drop(events);
4129 assert!(message_is_assistant);
4130 assert!(tool_results_empty);
4131 });
4132 }
4133
4134 #[test]
4135 fn turn_events_include_tool_execution_and_tool_result_messages() {
4136 let runtime = RuntimeBuilder::current_thread()
4137 .build()
4138 .expect("runtime build");
4139 let handle = runtime.handle();
4140
4141 let provider = Arc::new(ToolTurnProvider::new());
4142 let tools = ToolRegistry::from_tools(vec![Box::new(EchoTool)]);
4143 let agent = Agent::new(provider, tools, AgentConfig::default());
4144 let session = Arc::new(Mutex::new(Session::in_memory()));
4145 let mut agent_session =
4146 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
4147
4148 let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
4149 Arc::new(std::sync::Mutex::new(Vec::new()));
4150 let events_capture = Arc::clone(&events);
4151
4152 let join = handle.spawn(async move {
4153 agent_session
4154 .run_text("hello".to_string(), move |event| {
4155 events_capture.lock().expect("events lock").push(event);
4156 })
4157 .await
4158 .expect("run_text")
4159 });
4160
4161 runtime.block_on(async move {
4162 let message = join.await;
4163 assert_eq!(message.stop_reason, StopReason::Stop);
4164
4165 let events = events.lock().expect("events lock");
4166 let turn_start_count = events
4167 .iter()
4168 .filter(|event| matches!(event, AgentEvent::TurnStart { .. }))
4169 .count();
4170 let turn_end_count = events
4171 .iter()
4172 .filter(|event| matches!(event, AgentEvent::TurnEnd { .. }))
4173 .count();
4174 assert_eq!(
4175 turn_start_count, 2,
4176 "expected one tool turn and one final turn"
4177 );
4178 assert_eq!(
4179 turn_end_count, 2,
4180 "expected one tool turn and one final turn"
4181 );
4182
4183 let tool_start_idx = events
4184 .iter()
4185 .position(|event| matches!(event, AgentEvent::ToolExecutionStart { .. }))
4186 .expect("tool execution start event");
4187 let tool_end_idx = events
4188 .iter()
4189 .position(|event| matches!(event, AgentEvent::ToolExecutionEnd { .. }))
4190 .expect("tool execution end event");
4191 assert!(tool_start_idx < tool_end_idx);
4192
4193 let first_turn_end_idx = events
4194 .iter()
4195 .position(|event| matches!(event, AgentEvent::TurnEnd { turn_index: 0, .. }))
4196 .expect("first turn end");
4197 assert!(
4198 tool_end_idx < first_turn_end_idx,
4199 "tool execution should complete before first turn end"
4200 );
4201
4202 let first_turn_tool_results = events.iter().find_map(|event| match event {
4203 AgentEvent::TurnEnd {
4204 turn_index,
4205 tool_results,
4206 ..
4207 } if *turn_index == 0 => Some(tool_results),
4208 _ => None,
4209 });
4210
4211 let Some(first_turn_tool_results) = first_turn_tool_results else {
4212 panic!("missing first turn tool results");
4213 };
4214 assert_eq!(first_turn_tool_results.len(), 1);
4215 let first_result = first_turn_tool_results.first().unwrap();
4216 if let Message::ToolResult(tr) = first_result {
4217 assert_eq!(tr.tool_name, "echo_tool");
4218 assert!(!tr.is_error);
4219 } else {
4220 panic!("expected ToolResult message");
4221 }
4222 drop(events);
4223 });
4224 }
4225}
4226
4227impl AgentSession {
4228 pub const fn runtime_repair_mode_from_policy_mode(mode: RepairPolicyMode) -> RepairMode {
4229 match mode {
4230 RepairPolicyMode::Off => RepairMode::Off,
4231 RepairPolicyMode::Suggest => RepairMode::Suggest,
4232 RepairPolicyMode::AutoSafe => RepairMode::AutoSafe,
4233 RepairPolicyMode::AutoStrict => RepairMode::AutoStrict,
4234 }
4235 }
4236
4237 #[allow(clippy::too_many_arguments)]
4238 async fn start_js_extension_runtime(
4239 stage: &'static str,
4240 cwd: &std::path::Path,
4241 tools: Arc<ToolRegistry>,
4242 manager: ExtensionManager,
4243 policy: ExtensionPolicy,
4244 repair_mode: RepairMode,
4245 memory_limit_bytes: usize,
4246 ) -> Result<ExtensionRuntimeHandle> {
4247 let mut config = PiJsRuntimeConfig {
4248 cwd: cwd.display().to_string(),
4249 repair_mode,
4250 ..PiJsRuntimeConfig::default()
4251 };
4252 config.limits.memory_limit_bytes = Some(memory_limit_bytes).filter(|bytes| *bytes > 0);
4253
4254 let runtime =
4255 JsExtensionRuntimeHandle::start_with_policy(config, tools, manager, policy).await?;
4256 tracing::info!(
4257 event = "pi.extension_runtime.engine_decision",
4258 stage,
4259 requested = "quickjs",
4260 selected = "quickjs",
4261 fallback = false,
4262 "Extension runtime engine selected (legacy JS/TS)"
4263 );
4264 Ok(ExtensionRuntimeHandle::Js(runtime))
4265 }
4266
4267 #[allow(clippy::too_many_arguments)]
4268 async fn start_native_extension_runtime(
4269 stage: &'static str,
4270 _cwd: &std::path::Path,
4271 _tools: Arc<ToolRegistry>,
4272 _manager: ExtensionManager,
4273 _policy: ExtensionPolicy,
4274 _repair_mode: RepairMode,
4275 _memory_limit_bytes: usize,
4276 ) -> Result<ExtensionRuntimeHandle> {
4277 let runtime = NativeRustExtensionRuntimeHandle::start().await?;
4278 tracing::info!(
4279 event = "pi.extension_runtime.engine_decision",
4280 stage,
4281 requested = "native-rust",
4282 selected = "native-rust",
4283 fallback = false,
4284 "Extension runtime engine selected (native-rust)"
4285 );
4286 Ok(ExtensionRuntimeHandle::NativeRust(runtime))
4287 }
4288
4289 pub fn new(
4290 agent: Agent,
4291 session: Arc<Mutex<Session>>,
4292 save_enabled: bool,
4293 compaction_settings: ResolvedCompactionSettings,
4294 ) -> Self {
4295 Self {
4296 agent,
4297 session,
4298 save_enabled,
4299 extensions: None,
4300 extensions_is_streaming: Arc::new(AtomicBool::new(false)),
4301 compaction_settings,
4302 compaction_worker: CompactionWorkerState::new(CompactionQuota::default()),
4303 model_registry: None,
4304 auth_storage: None,
4305 }
4306 }
4307
4308 #[must_use]
4309 pub fn with_model_registry(mut self, registry: ModelRegistry) -> Self {
4310 self.model_registry = Some(registry);
4311 self
4312 }
4313
4314 #[must_use]
4315 pub fn with_auth_storage(mut self, auth: AuthStorage) -> Self {
4316 self.auth_storage = Some(auth);
4317 self
4318 }
4319
4320 pub fn set_model_registry(&mut self, registry: ModelRegistry) {
4321 self.model_registry = Some(registry);
4322 }
4323
4324 pub fn set_auth_storage(&mut self, auth: AuthStorage) {
4325 self.auth_storage = Some(auth);
4326 }
4327
4328 pub async fn set_provider_model(&mut self, provider_id: &str, model_id: &str) -> Result<()> {
4329 {
4330 let cx = crate::agent_cx::AgentCx::for_request();
4331 let mut session = self
4332 .session
4333 .lock(cx.cx())
4334 .await
4335 .map_err(|e| Error::session(e.to_string()))?;
4336 session.set_model_header(
4337 Some(provider_id.to_string()),
4338 Some(model_id.to_string()),
4339 None,
4340 );
4341 }
4342
4343 self.apply_session_model_selection(provider_id, model_id);
4344 let provider = self.agent.provider();
4345 if provider.name() != provider_id || provider.model_id() != model_id {
4346 return Err(Error::validation(format!(
4347 "Unable to switch provider/model to {provider_id}/{model_id}"
4348 )));
4349 }
4350
4351 self.persist_session().await
4352 }
4353
4354 fn resolve_stream_api_key_for_model(&self, entry: &ModelEntry) -> Option<String> {
4355 let normalize = |key_opt: Option<String>| {
4356 key_opt.and_then(|key| {
4357 let trimmed = key.trim();
4358 (!trimmed.is_empty()).then(|| trimmed.to_string())
4359 })
4360 };
4361
4362 self.auth_storage
4363 .as_ref()
4364 .and_then(|auth| normalize(auth.resolve_api_key(&entry.model.provider, None)))
4365 .or_else(|| normalize(entry.api_key.clone()))
4366 }
4367
4368 fn apply_session_model_selection(&mut self, provider_id: &str, model_id: &str) {
4369 if self.agent.provider().name() == provider_id
4370 && self.agent.provider().model_id() == model_id
4371 {
4372 return;
4373 }
4374
4375 let Some(registry) = &self.model_registry else {
4376 return;
4377 };
4378
4379 let Some(entry) = registry.find(provider_id, model_id) else {
4380 tracing::warn!("Session model {provider_id}/{model_id} not found in model registry");
4381 return;
4382 };
4383
4384 match crate::providers::create_provider(
4385 &entry,
4386 self.extensions.as_ref().map(ExtensionRegion::manager),
4387 ) {
4388 Ok(provider) => {
4389 tracing::info!("Updating agent provider to {provider_id}/{model_id}");
4390 self.agent.set_provider(provider);
4391
4392 let resolved_key = self.resolve_stream_api_key_for_model(&entry);
4393 if resolved_key.is_none() {
4394 tracing::warn!(
4395 "No API key resolved for session model {provider_id}/{model_id}; clearing stream API key"
4396 );
4397 }
4398
4399 let stream_options = self.agent.stream_options_mut();
4400 stream_options.api_key = resolved_key;
4401 stream_options.headers.clone_from(&entry.headers);
4402 }
4403 Err(e) => {
4404 tracing::warn!("Failed to create provider for session model: {e}");
4405 }
4406 }
4407 }
4408
4409 pub const fn save_enabled(&self) -> bool {
4410 self.save_enabled
4411 }
4412
4413 pub async fn compact_now(
4415 &mut self,
4416 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
4417 ) -> Result<()> {
4418 self.compact_synchronous(Arc::new(on_event)).await
4419 }
4420
4421 async fn maybe_compact(&mut self, on_event: AgentEventHandler) -> Result<()> {
4427 if !self.compaction_settings.enabled {
4428 return Ok(());
4429 }
4430
4431 if let Some(outcome) = self.compaction_worker.try_recv() {
4433 match outcome {
4434 Ok(result) => {
4435 self.apply_compaction_result(result, Arc::clone(&on_event))
4436 .await?;
4437 }
4438 Err(e) => {
4439 on_event(AgentEvent::AutoCompactionEnd {
4440 result: None,
4441 aborted: false,
4442 will_retry: false,
4443 error_message: Some(e.to_string()),
4444 });
4445 }
4446 }
4447 }
4448
4449 if !self.compaction_worker.can_start() {
4451 return Ok(());
4452 }
4453
4454 let preparation = {
4455 let cx = crate::agent_cx::AgentCx::for_request();
4456 let session = self
4457 .session
4458 .lock(cx.cx())
4459 .await
4460 .map_err(|e| Error::session(e.to_string()))?;
4461 let entries = session
4462 .entries_for_current_path()
4463 .into_iter()
4464 .cloned()
4465 .collect::<Vec<_>>();
4466 compaction::prepare_compaction(&entries, self.compaction_settings.clone())
4467 };
4468
4469 if let Some(prep) = preparation {
4470 on_event(AgentEvent::AutoCompactionStart {
4471 reason: "threshold".to_string(),
4472 });
4473
4474 let provider = self.agent.provider();
4475 let api_key = self
4476 .agent
4477 .stream_options()
4478 .api_key
4479 .clone()
4480 .unwrap_or_default();
4481
4482 self.compaction_worker.start(prep, provider, api_key, None);
4483 }
4484
4485 Ok(())
4486 }
4487
4488 async fn apply_compaction_result(
4490 &self,
4491 result: compaction::CompactionResult,
4492 on_event: AgentEventHandler,
4493 ) -> Result<()> {
4494 let cx = crate::agent_cx::AgentCx::for_request();
4495 let mut session = self
4496 .session
4497 .lock(cx.cx())
4498 .await
4499 .map_err(|e| Error::session(e.to_string()))?;
4500
4501 let details = compaction::compaction_details_to_value(&result.details).ok();
4502 let result_value = details.clone();
4503
4504 session.append_compaction(
4505 result.summary,
4506 result.first_kept_entry_id,
4507 result.tokens_before,
4508 details,
4509 None, );
4511
4512 if self.save_enabled {
4513 session
4514 .flush_autosave(AutosaveFlushTrigger::Periodic)
4515 .await?;
4516 }
4517
4518 on_event(AgentEvent::AutoCompactionEnd {
4519 result: result_value,
4520 aborted: false,
4521 will_retry: false,
4522 error_message: None,
4523 });
4524
4525 Ok(())
4526 }
4527
4528 async fn compact_synchronous(&self, on_event: AgentEventHandler) -> Result<()> {
4530 if !self.compaction_settings.enabled {
4531 return Ok(());
4532 }
4533
4534 let preparation = {
4535 let cx = crate::agent_cx::AgentCx::for_request();
4536 let session = self
4537 .session
4538 .lock(cx.cx())
4539 .await
4540 .map_err(|e| Error::session(e.to_string()))?;
4541 let entries = session
4542 .entries_for_current_path()
4543 .into_iter()
4544 .cloned()
4545 .collect::<Vec<_>>();
4546 compaction::prepare_compaction(&entries, self.compaction_settings.clone())
4547 };
4548
4549 if let Some(prep) = preparation {
4550 on_event(AgentEvent::AutoCompactionStart {
4551 reason: "threshold".to_string(),
4552 });
4553
4554 let provider = self.agent.provider();
4555 let api_key = self
4556 .agent
4557 .stream_options()
4558 .api_key
4559 .clone()
4560 .unwrap_or_default();
4561
4562 match compaction::compact(prep, provider, &api_key, None).await {
4563 Ok(result) => {
4564 self.apply_compaction_result(result, Arc::clone(&on_event))
4565 .await?;
4566 }
4567 Err(e) => {
4568 on_event(AgentEvent::AutoCompactionEnd {
4569 result: None,
4570 aborted: false,
4571 will_retry: false,
4572 error_message: Some(e.to_string()),
4573 });
4574 return Err(e);
4575 }
4576 }
4577 }
4578 Ok(())
4579 }
4580
4581 #[allow(clippy::too_many_arguments)]
4582 pub async fn enable_extensions(
4583 &mut self,
4584 enabled_tools: &[&str],
4585 cwd: &std::path::Path,
4586 config: Option<&crate::config::Config>,
4587 extension_entries: &[std::path::PathBuf],
4588 ) -> Result<()> {
4589 self.enable_extensions_with_policy(
4590 enabled_tools,
4591 cwd,
4592 config,
4593 extension_entries,
4594 None,
4595 None,
4596 None,
4597 )
4598 .await
4599 }
4600
4601 #[allow(clippy::too_many_lines, clippy::too_many_arguments)]
4602 pub async fn enable_extensions_with_policy(
4603 &mut self,
4604 enabled_tools: &[&str],
4605 cwd: &std::path::Path,
4606 config: Option<&crate::config::Config>,
4607 extension_entries: &[std::path::PathBuf],
4608 policy: Option<ExtensionPolicy>,
4609 repair_policy: Option<RepairPolicyMode>,
4610 pre_warmed: Option<PreWarmedExtensionRuntime>,
4611 ) -> Result<()> {
4612 let mut js_specs: Vec<JsExtensionLoadSpec> = Vec::new();
4613 let mut native_specs: Vec<NativeRustExtensionLoadSpec> = Vec::new();
4614 #[cfg(feature = "wasm-host")]
4615 let mut wasm_specs: Vec<WasmExtensionLoadSpec> = Vec::new();
4616
4617 for entry in extension_entries {
4618 match resolve_extension_load_spec(entry)? {
4619 ExtensionLoadSpec::Js(spec) => js_specs.push(spec),
4620 ExtensionLoadSpec::NativeRust(spec) => native_specs.push(spec),
4621 #[cfg(feature = "wasm-host")]
4622 ExtensionLoadSpec::Wasm(spec) => wasm_specs.push(spec),
4623 }
4624 }
4625
4626 if !js_specs.is_empty() && !native_specs.is_empty() {
4627 return Err(Error::validation(
4628 "Mixed extension runtimes are not supported in one session yet. Use either JS/TS extensions (QuickJS) or native-rust descriptors (*.native.json), but not both at once."
4629 .to_string(),
4630 ));
4631 }
4632
4633 let resolved_policy = policy.clone().unwrap_or_default();
4634 let resolved_repair_policy = repair_policy
4635 .or_else(|| config.map(|cfg| cfg.resolve_repair_policy(None)))
4636 .unwrap_or(RepairPolicyMode::AutoSafe);
4637 let runtime_repair_mode =
4638 Self::runtime_repair_mode_from_policy_mode(resolved_repair_policy);
4639 let memory_limit_bytes =
4640 (resolved_policy.max_memory_mb as usize).saturating_mul(1024 * 1024);
4641 let wants_js_runtime = !js_specs.is_empty();
4642
4643 #[allow(unused_variables)]
4646 let (manager, tools) = if let Some(pre) = pre_warmed {
4647 let manager = pre.manager;
4648 let tools = pre.tools;
4649 let runtime = match pre.runtime {
4650 ExtensionRuntimeHandle::NativeRust(runtime) => {
4651 if wants_js_runtime {
4652 tracing::warn!(
4653 event = "pi.extension_runtime.prewarm.mismatch",
4654 expected = "quickjs",
4655 got = "native-rust",
4656 "Pre-warmed runtime mismatched requested JS mode; creating quickjs runtime"
4657 );
4658 Self::start_js_extension_runtime(
4659 "agent_enable_extensions_prewarm_mismatch",
4660 cwd,
4661 Arc::clone(&tools),
4662 manager.clone(),
4663 resolved_policy.clone(),
4664 runtime_repair_mode,
4665 memory_limit_bytes,
4666 )
4667 .await?
4668 } else {
4669 tracing::info!(
4670 event = "pi.extension_runtime.engine_decision",
4671 stage = "agent_enable_extensions_prewarmed",
4672 requested = "native-rust",
4673 selected = "native-rust",
4674 fallback = false,
4675 "Using pre-warmed extension runtime"
4676 );
4677 ExtensionRuntimeHandle::NativeRust(runtime)
4678 }
4679 }
4680 ExtensionRuntimeHandle::Js(runtime) => {
4681 if wants_js_runtime {
4682 tracing::info!(
4683 event = "pi.extension_runtime.engine_decision",
4684 stage = "agent_enable_extensions_prewarmed",
4685 requested = "quickjs",
4686 selected = "quickjs",
4687 fallback = false,
4688 "Using pre-warmed extension runtime"
4689 );
4690 ExtensionRuntimeHandle::Js(runtime)
4691 } else {
4692 tracing::warn!(
4693 event = "pi.extension_runtime.prewarm.mismatch",
4694 expected = "native-rust",
4695 got = "quickjs",
4696 "Pre-warmed runtime mismatched requested native mode; creating native-rust runtime"
4697 );
4698 Self::start_native_extension_runtime(
4699 "agent_enable_extensions_prewarm_mismatch",
4700 cwd,
4701 Arc::clone(&tools),
4702 manager.clone(),
4703 resolved_policy.clone(),
4704 runtime_repair_mode,
4705 memory_limit_bytes,
4706 )
4707 .await?
4708 }
4709 }
4710 };
4711 manager.set_runtime(runtime);
4712 (manager, tools)
4713 } else {
4714 let manager = ExtensionManager::new();
4715 manager.set_cwd(cwd.display().to_string());
4716 let tools = Arc::new(ToolRegistry::new(enabled_tools, cwd, config));
4717
4718 if let Some(cfg) = config {
4719 let resolved_risk = cfg.resolve_extension_risk_with_metadata();
4720 tracing::info!(
4721 event = "pi.extension_runtime_risk.config",
4722 source = resolved_risk.source,
4723 enabled = resolved_risk.settings.enabled,
4724 alpha = resolved_risk.settings.alpha,
4725 window_size = resolved_risk.settings.window_size,
4726 ledger_limit = resolved_risk.settings.ledger_limit,
4727 fail_closed = resolved_risk.settings.fail_closed,
4728 "Resolved extension runtime risk settings"
4729 );
4730 manager.set_runtime_risk_config(resolved_risk.settings);
4731 }
4732
4733 let runtime = if wants_js_runtime {
4734 Self::start_js_extension_runtime(
4735 "agent_enable_extensions_boot",
4736 cwd,
4737 Arc::clone(&tools),
4738 manager.clone(),
4739 resolved_policy,
4740 runtime_repair_mode,
4741 memory_limit_bytes,
4742 )
4743 .await?
4744 } else {
4745 Self::start_native_extension_runtime(
4746 "agent_enable_extensions_boot",
4747 cwd,
4748 Arc::clone(&tools),
4749 manager.clone(),
4750 resolved_policy,
4751 runtime_repair_mode,
4752 memory_limit_bytes,
4753 )
4754 .await?
4755 };
4756 manager.set_runtime(runtime);
4757 (manager, tools)
4758 };
4759
4760 manager.set_session(Arc::new(SessionHandle(self.session.clone())));
4764
4765 let injected = Arc::new(StdMutex::new(ExtensionInjectedQueue::default()));
4766 let host_actions = AgentSessionHostActions {
4767 session: Arc::clone(&self.session),
4768 injected: Arc::clone(&injected),
4769 is_streaming: Arc::clone(&self.extensions_is_streaming),
4770 };
4771 manager.set_host_actions(Arc::new(host_actions));
4772 {
4773 let steering_queue = Arc::clone(&injected);
4774 let follow_up_queue = Arc::clone(&injected);
4775 let steering_fetcher = move || -> BoxFuture<'static, Vec<Message>> {
4776 let steering_queue = Arc::clone(&steering_queue);
4777 Box::pin(async move {
4778 let Ok(mut queue) = steering_queue.lock() else {
4779 return Vec::new();
4780 };
4781 queue.pop_steering()
4782 })
4783 };
4784 let follow_up_fetcher = move || -> BoxFuture<'static, Vec<Message>> {
4785 let follow_up_queue = Arc::clone(&follow_up_queue);
4786 Box::pin(async move {
4787 let Ok(mut queue) = follow_up_queue.lock() else {
4788 return Vec::new();
4789 };
4790 queue.pop_follow_up()
4791 })
4792 };
4793 self.agent.register_message_fetchers(
4794 Some(Arc::new(steering_fetcher)),
4795 Some(Arc::new(follow_up_fetcher)),
4796 );
4797 }
4798
4799 if !js_specs.is_empty() {
4800 manager.load_js_extensions(js_specs).await?;
4801 }
4802
4803 if !native_specs.is_empty() {
4804 manager.load_native_extensions(native_specs).await?;
4805 }
4806
4807 if let Some(rt) = manager.runtime() {
4809 let events = rt.drain_repair_events().await;
4810 if !events.is_empty() {
4811 log_repair_diagnostics(&events);
4812 }
4813 }
4814
4815 #[cfg(feature = "wasm-host")]
4816 if !wasm_specs.is_empty() {
4817 let host = WasmExtensionHost::new(cwd, policy.unwrap_or_default())?;
4818 manager
4819 .load_wasm_extensions(&host, wasm_specs, Arc::clone(&tools))
4820 .await?;
4821 }
4822
4823 let session_path = {
4826 let cx = crate::agent_cx::AgentCx::for_request();
4827 let session = self
4828 .session
4829 .lock(cx.cx())
4830 .await
4831 .map_err(|e| Error::extension(e.to_string()))?;
4832 session.path.as_ref().map(|p| p.display().to_string())
4833 };
4834
4835 if let Err(err) = manager
4836 .dispatch_event(
4837 ExtensionEventName::Startup,
4838 Some(serde_json::json!({
4839 "version": env!("CARGO_PKG_VERSION"),
4840 "sessionFile": session_path,
4841 })),
4842 )
4843 .await
4844 {
4845 tracing::warn!("startup extension hook failed (fail-open): {err}");
4846 }
4847
4848 let ctx_payload = serde_json::json!({ "cwd": cwd.display().to_string() });
4849 let wrappers = collect_extension_tool_wrappers(&manager, ctx_payload).await?;
4850 self.agent.extend_tools(wrappers);
4851 self.agent.extensions = Some(manager.clone());
4852 self.extensions = Some(ExtensionRegion::new(manager));
4853 Ok(())
4854 }
4855
4856 pub async fn save_and_index(&mut self) -> Result<()> {
4857 if self.save_enabled {
4858 let cx = crate::agent_cx::AgentCx::for_request();
4859 let mut session = self
4860 .session
4861 .lock(cx.cx())
4862 .await
4863 .map_err(|e| Error::session(e.to_string()))?;
4864 session
4865 .flush_autosave(AutosaveFlushTrigger::Periodic)
4866 .await?;
4867 }
4868 Ok(())
4869 }
4870
4871 pub async fn persist_session(&mut self) -> Result<()> {
4872 if !self.save_enabled {
4873 return Ok(());
4874 }
4875 let cx = crate::agent_cx::AgentCx::for_request();
4876 let mut session = self
4877 .session
4878 .lock(cx.cx())
4879 .await
4880 .map_err(|e| Error::session(e.to_string()))?;
4881 session
4882 .flush_autosave(AutosaveFlushTrigger::Periodic)
4883 .await?;
4884 Ok(())
4885 }
4886
4887 pub async fn run_text(
4888 &mut self,
4889 input: String,
4890 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
4891 ) -> Result<AssistantMessage> {
4892 self.run_text_with_abort(input, None, on_event).await
4893 }
4894
4895 pub async fn run_text_with_abort(
4896 &mut self,
4897 input: String,
4898 abort: Option<AbortSignal>,
4899 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
4900 ) -> Result<AssistantMessage> {
4901 let outcome = self.dispatch_input_event(input, Vec::new()).await?;
4902 let (text, images) = match outcome {
4903 InputEventOutcome::Continue { text, images } => (text, images),
4904 InputEventOutcome::Block { reason } => {
4905 let message = reason.unwrap_or_else(|| "Input blocked".to_string());
4906 return Err(Error::extension(message));
4907 }
4908 };
4909
4910 self.dispatch_before_agent_start().await;
4911
4912 if images.is_empty() {
4913 self.run_agent_with_text(text, abort, on_event).await
4914 } else {
4915 let content = Self::build_content_blocks_for_input(&text, &images);
4916 self.run_agent_with_content(content, abort, on_event).await
4917 }
4918 }
4919
4920 pub async fn run_with_content(
4921 &mut self,
4922 content: Vec<ContentBlock>,
4923 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
4924 ) -> Result<AssistantMessage> {
4925 self.run_with_content_with_abort(content, None, on_event)
4926 .await
4927 }
4928
4929 pub async fn run_with_content_with_abort(
4930 &mut self,
4931 content: Vec<ContentBlock>,
4932 abort: Option<AbortSignal>,
4933 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
4934 ) -> Result<AssistantMessage> {
4935 let (text, images) = Self::split_content_blocks_for_input(&content);
4936 let outcome = self.dispatch_input_event(text, images).await?;
4937 let (text, images) = match outcome {
4938 InputEventOutcome::Continue { text, images } => (text, images),
4939 InputEventOutcome::Block { reason } => {
4940 let message = reason.unwrap_or_else(|| "Input blocked".to_string());
4941 return Err(Error::extension(message));
4942 }
4943 };
4944
4945 self.dispatch_before_agent_start().await;
4946
4947 let content_for_agent = Self::build_content_blocks_for_input(&text, &images);
4948 self.run_agent_with_content(content_for_agent, abort, on_event)
4949 .await
4950 }
4951
4952 async fn dispatch_input_event(
4953 &self,
4954 text: String,
4955 images: Vec<ImageContent>,
4956 ) -> Result<InputEventOutcome> {
4957 let Some(region) = &self.extensions else {
4958 return Ok(InputEventOutcome::Continue { text, images });
4959 };
4960
4961 let images_value = serde_json::to_value(&images).unwrap_or(Value::Null);
4962 let payload = json!({
4963 "text": text,
4964 "images": images_value,
4965 "source": "user",
4966 });
4967
4968 let response = region
4969 .manager()
4970 .dispatch_event_with_response(
4971 ExtensionEventName::Input,
4972 Some(payload),
4973 EXTENSION_EVENT_TIMEOUT_MS,
4974 )
4975 .await?;
4976
4977 Ok(apply_input_event_response(response, text, images))
4978 }
4979
4980 async fn dispatch_before_agent_start(&self) {
4981 if let Some(region) = &self.extensions {
4982 if let Err(err) = region
4983 .manager()
4984 .dispatch_event(ExtensionEventName::BeforeAgentStart, None)
4985 .await
4986 {
4987 tracing::warn!("before_agent_start extension hook failed (fail-open): {err}");
4988 }
4989 }
4990 }
4991
4992 fn split_content_blocks_for_input(blocks: &[ContentBlock]) -> (String, Vec<ImageContent>) {
4993 let mut text = String::new();
4994 let mut images = Vec::new();
4995 for block in blocks {
4996 match block {
4997 ContentBlock::Text(text_block) => {
4998 if !text_block.text.trim().is_empty() {
4999 if !text.is_empty() {
5000 text.push('\n');
5001 }
5002 text.push_str(&text_block.text);
5003 }
5004 }
5005 ContentBlock::Image(image) => images.push(image.clone()),
5006 _ => {}
5007 }
5008 }
5009 (text, images)
5010 }
5011
5012 fn build_content_blocks_for_input(text: &str, images: &[ImageContent]) -> Vec<ContentBlock> {
5013 let mut content = Vec::new();
5014 if !text.trim().is_empty() {
5015 content.push(ContentBlock::Text(TextContent::new(text.to_string())));
5016 }
5017 for image in images {
5018 content.push(ContentBlock::Image(image.clone()));
5019 }
5020 content
5021 }
5022
5023 pub(crate) async fn run_agent_with_text(
5024 &mut self,
5025 input: String,
5026 abort: Option<AbortSignal>,
5027 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
5028 ) -> Result<AssistantMessage> {
5029 let on_event: AgentEventHandler = Arc::new(on_event);
5030 let session_model = {
5031 let cx = crate::agent_cx::AgentCx::for_request();
5032 let session = self
5033 .session
5034 .lock(cx.cx())
5035 .await
5036 .map_err(|e| Error::session(e.to_string()))?;
5037 (
5038 session.header.provider.clone(),
5039 session.header.model_id.clone(),
5040 )
5041 };
5042
5043 if let (Some(provider_id), Some(model_id)) = session_model {
5044 self.apply_session_model_selection(provider_id.as_str(), model_id.as_str());
5045 }
5046
5047 self.maybe_compact(Arc::clone(&on_event)).await?;
5048 let history = {
5049 let cx = crate::agent_cx::AgentCx::for_request();
5050 let session = self
5051 .session
5052 .lock(cx.cx())
5053 .await
5054 .map_err(|e| Error::session(e.to_string()))?;
5055 session.to_messages_for_current_path()
5056 };
5057 self.agent.replace_messages(history);
5058
5059 let start_len = self.agent.messages().len();
5060
5061 let user_message = Message::User(UserMessage {
5063 content: UserContent::Text(input),
5064 timestamp: Utc::now().timestamp_millis(),
5065 });
5066
5067 {
5068 let cx = crate::agent_cx::AgentCx::for_request();
5069 let mut session = self
5070 .session
5071 .lock(cx.cx())
5072 .await
5073 .map_err(|e| Error::session(e.to_string()))?;
5074 session.append_model_message(user_message.clone());
5075 if self.save_enabled {
5076 session.flush_autosave(AutosaveFlushTrigger::Manual).await?;
5077 }
5078 }
5079
5080 self.extensions_is_streaming.store(true, Ordering::SeqCst);
5081 let on_event_for_run = Arc::clone(&on_event);
5082 let result = self
5083 .agent
5084 .run_with_message_with_abort(user_message, abort, move |event| {
5085 on_event_for_run(event);
5086 })
5087 .await;
5088 self.extensions_is_streaming.store(false, Ordering::SeqCst);
5089 let result = result?;
5090 self.persist_new_messages(start_len + 1).await?;
5092 Ok(result)
5093 }
5094
5095 pub(crate) async fn run_agent_with_content(
5096 &mut self,
5097 content: Vec<ContentBlock>,
5098 abort: Option<AbortSignal>,
5099 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
5100 ) -> Result<AssistantMessage> {
5101 let on_event: AgentEventHandler = Arc::new(on_event);
5102 let session_model = {
5103 let cx = crate::agent_cx::AgentCx::for_request();
5104 let session = self
5105 .session
5106 .lock(cx.cx())
5107 .await
5108 .map_err(|e| Error::session(e.to_string()))?;
5109 (
5110 session.header.provider.clone(),
5111 session.header.model_id.clone(),
5112 )
5113 };
5114
5115 if let (Some(provider_id), Some(model_id)) = session_model {
5116 self.apply_session_model_selection(provider_id.as_str(), model_id.as_str());
5117 }
5118
5119 self.maybe_compact(Arc::clone(&on_event)).await?;
5120 let history = {
5121 let cx = crate::agent_cx::AgentCx::for_request();
5122 let session = self
5123 .session
5124 .lock(cx.cx())
5125 .await
5126 .map_err(|e| Error::session(e.to_string()))?;
5127 session.to_messages_for_current_path()
5128 };
5129 self.agent.replace_messages(history);
5130
5131 let start_len = self.agent.messages().len();
5132
5133 let user_message = Message::User(UserMessage {
5135 content: UserContent::Blocks(content),
5136 timestamp: Utc::now().timestamp_millis(),
5137 });
5138
5139 {
5140 let cx = crate::agent_cx::AgentCx::for_request();
5141 let mut session = self
5142 .session
5143 .lock(cx.cx())
5144 .await
5145 .map_err(|e| Error::session(e.to_string()))?;
5146 session.append_model_message(user_message.clone());
5147 if self.save_enabled {
5148 session.flush_autosave(AutosaveFlushTrigger::Manual).await?;
5149 }
5150 }
5151
5152 self.extensions_is_streaming.store(true, Ordering::SeqCst);
5153 let on_event_for_run = Arc::clone(&on_event);
5154 let result = self
5155 .agent
5156 .run_with_message_with_abort(user_message, abort, move |event| {
5157 on_event_for_run(event);
5158 })
5159 .await;
5160 self.extensions_is_streaming.store(false, Ordering::SeqCst);
5161 let result = result?;
5162 self.persist_new_messages(start_len + 1).await?;
5164 Ok(result)
5165 }
5166
5167 async fn persist_new_messages(&self, start_len: usize) -> Result<()> {
5168 let new_messages = self.agent.messages()[start_len..].to_vec();
5169 {
5170 let cx = crate::agent_cx::AgentCx::for_request();
5171 let mut session = self
5172 .session
5173 .lock(cx.cx())
5174 .await
5175 .map_err(|e| Error::session(e.to_string()))?;
5176 for message in new_messages {
5177 session.append_model_message(message);
5178 }
5179 if self.save_enabled {
5180 session
5181 .flush_autosave(AutosaveFlushTrigger::Periodic)
5182 .await?;
5183 }
5184 }
5185 Ok(())
5186 }
5187}
5188
5189fn log_repair_diagnostics(events: &[crate::extensions_js::ExtensionRepairEvent]) {
5198 use std::collections::BTreeMap;
5199
5200 for ev in events {
5202 tracing::info!(
5203 event = "extension.auto_repair",
5204 extension_id = %ev.extension_id,
5205 pattern = %ev.pattern,
5206 success = ev.success,
5207 original_error = %ev.original_error,
5208 repair_action = %ev.repair_action,
5209 );
5210 }
5211
5212 let mut by_pattern: BTreeMap<String, Vec<&str>> = BTreeMap::new();
5214 for ev in events {
5215 by_pattern
5216 .entry(ev.pattern.to_string())
5217 .or_default()
5218 .push(&ev.extension_id);
5219 }
5220
5221 let verbose = std::env::var("PI_AUTO_REPAIR_VERBOSE")
5222 .is_ok_and(|v| v == "1" || v.eq_ignore_ascii_case("true"));
5223
5224 if verbose {
5225 eprintln!(
5226 "[auto-repair] {} extension{} auto-repaired:",
5227 events.len(),
5228 if events.len() == 1 { "" } else { "s" }
5229 );
5230 for ev in events {
5231 eprintln!(
5232 " {}: {} ({})",
5233 ev.pattern, ev.extension_id, ev.repair_action
5234 );
5235 }
5236 } else {
5237 let patterns: Vec<String> = by_pattern
5239 .iter()
5240 .map(|(pat, ids)| format!("{pat}:{}", ids.len()))
5241 .collect();
5242 tracing::info!(
5243 event = "extension.auto_repair.summary",
5244 count = events.len(),
5245 patterns = %patterns.join(", "),
5246 "auto-repaired {} extension(s)",
5247 events.len(),
5248 );
5249 }
5250}
5251
/// Text inserted by [`filter_image_blocks`] in place of removed image blocks.
const BLOCK_IMAGES_PLACEHOLDER: &str = "Image reading is disabled.";
5253
/// Counters returned by [`filter_images_for_provider`].
#[derive(Debug, Default, Clone, Copy)]
struct ImageFilterStats {
    /// Total number of image blocks removed across all messages.
    removed_images: usize,
    /// Number of messages from which at least one image was removed.
    affected_messages: usize,
}
5259
5260fn filter_images_for_provider(messages: &mut [Message]) -> ImageFilterStats {
5261 let mut stats = ImageFilterStats::default();
5262 for message in messages {
5263 let removed = filter_images_from_message(message);
5264 if removed > 0 {
5265 stats.removed_images += removed;
5266 stats.affected_messages += 1;
5267 }
5268 }
5269 stats
5270}
5271
5272fn filter_images_from_message(message: &mut Message) -> usize {
5273 match message {
5274 Message::User(user) => match &mut user.content {
5275 UserContent::Text(_) => 0,
5276 UserContent::Blocks(blocks) => filter_image_blocks(blocks),
5277 },
5278 Message::Assistant(assistant) => {
5279 let assistant = Arc::make_mut(assistant);
5280 filter_image_blocks(&mut assistant.content)
5281 }
5282 Message::ToolResult(tool_result) => {
5283 filter_image_blocks(&mut Arc::make_mut(tool_result).content)
5284 }
5285 Message::Custom(_) => 0,
5286 }
5287}
5288
5289fn filter_image_blocks(blocks: &mut Vec<ContentBlock>) -> usize {
5290 let mut removed = 0usize;
5291 let mut filtered = Vec::with_capacity(blocks.len());
5292
5293 for block in blocks.drain(..) {
5294 match block {
5295 ContentBlock::Image(_) => {
5296 removed += 1;
5297 let previous_is_placeholder =
5298 filtered
5299 .last()
5300 .is_some_and(|prev| matches!(prev, ContentBlock::Text(TextContent { text, .. }) if text == BLOCK_IMAGES_PLACEHOLDER));
5301 if !previous_is_placeholder {
5302 filtered.push(ContentBlock::Text(TextContent::new(
5303 BLOCK_IMAGES_PLACEHOLDER,
5304 )));
5305 }
5306 }
5307 other => filtered.push(other),
5308 }
5309 }
5310
5311 *blocks = filtered;
5312 removed
5313}
5314
5315fn extract_tool_calls(content: &[ContentBlock]) -> Vec<ToolCall> {
5317 content
5318 .iter()
5319 .filter_map(|block| {
5320 if let ContentBlock::ToolCall(tc) = block {
5321 Some(tc.clone())
5322 } else {
5323 None
5324 }
5325 })
5326 .collect()
5327}
5328
5329#[cfg(test)]
5334mod tests {
5335 use super::*;
5336 use crate::auth::AuthCredential;
5337 use crate::provider::{InputType, Model, ModelCost};
5338 use async_trait::async_trait;
5339 use futures::Stream;
5340 use std::collections::HashMap;
5341 use std::path::Path;
5342 use std::pin::Pin;
5343
5344 fn user_message(text: &str) -> Message {
5345 Message::User(UserMessage {
5346 content: UserContent::Text(text.to_string()),
5347 timestamp: 0,
5348 })
5349 }
5350
5351 fn assert_user_text(message: &Message, expected: &str) {
5352 assert!(
5353 matches!(
5354 message,
5355 Message::User(UserMessage {
5356 content: UserContent::Text(_),
5357 ..
5358 })
5359 ),
5360 "expected user text message, got {message:?}"
5361 );
5362 if let Message::User(UserMessage {
5363 content: UserContent::Text(text),
5364 ..
5365 }) = message
5366 {
5367 assert_eq!(text, expected);
5368 }
5369 }
5370
5371 fn sample_image_block() -> ContentBlock {
5372 ContentBlock::Image(ImageContent {
5373 data: "aGVsbG8=".to_string(),
5374 mime_type: "image/png".to_string(),
5375 })
5376 }
5377
5378 fn image_count_in_message(message: &Message) -> usize {
5379 let count_images = |blocks: &[ContentBlock]| {
5380 blocks
5381 .iter()
5382 .filter(|block| matches!(block, ContentBlock::Image(_)))
5383 .count()
5384 };
5385 match message {
5386 Message::User(UserMessage {
5387 content: UserContent::Blocks(blocks),
5388 ..
5389 }) => count_images(blocks),
5390 Message::Assistant(msg) => count_images(&msg.content),
5391 Message::ToolResult(tool_result) => count_images(&tool_result.content),
5392 Message::User(UserMessage {
5393 content: UserContent::Text(_),
5394 ..
5395 })
5396 | Message::Custom(_) => 0,
5397 }
5398 }
5399
5400 #[derive(Debug)]
5401 struct SilentProvider;
5402
5403 #[async_trait]
5404 #[allow(clippy::unnecessary_literal_bound)]
5405 impl Provider for SilentProvider {
5406 fn name(&self) -> &str {
5407 "silent-provider"
5408 }
5409
5410 fn api(&self) -> &str {
5411 "test-api"
5412 }
5413
5414 fn model_id(&self) -> &str {
5415 "test-model"
5416 }
5417
5418 async fn stream(
5419 &self,
5420 _context: &Context<'_>,
5421 _options: &StreamOptions,
5422 ) -> crate::error::Result<
5423 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
5424 > {
5425 Ok(Box::pin(futures::stream::empty()))
5426 }
5427 }
5428
5429 #[test]
5430 fn test_extract_tool_calls() {
5431 let content = vec![
5432 ContentBlock::Text(TextContent::new("Hello")),
5433 ContentBlock::ToolCall(ToolCall {
5434 id: "tc1".to_string(),
5435 name: "read".to_string(),
5436 arguments: serde_json::json!({"path": "file.txt"}),
5437 thought_signature: None,
5438 }),
5439 ContentBlock::Text(TextContent::new("World")),
5440 ContentBlock::ToolCall(ToolCall {
5441 id: "tc2".to_string(),
5442 name: "bash".to_string(),
5443 arguments: serde_json::json!({"command": "ls"}),
5444 thought_signature: None,
5445 }),
5446 ];
5447
5448 let tool_calls = extract_tool_calls(&content);
5449 assert_eq!(tool_calls.len(), 2);
5450 assert_eq!(tool_calls[0].name, "read");
5451 assert_eq!(tool_calls[1].name, "bash");
5452 }
5453
5454 #[test]
5455 fn test_agent_config_default() {
5456 let config = AgentConfig::default();
5457 assert_eq!(config.max_tool_iterations, 50);
5458 assert!(config.system_prompt.is_none());
5459 assert!(!config.block_images);
5460 }
5461
5462 #[test]
5463 fn filter_image_blocks_replaces_images_with_deduped_placeholder_text() {
5464 let mut blocks = vec![
5465 sample_image_block(),
5466 sample_image_block(),
5467 ContentBlock::Text(TextContent::new("tail")),
5468 sample_image_block(),
5469 ];
5470
5471 let removed = filter_image_blocks(&mut blocks);
5472
5473 assert_eq!(removed, 3);
5474 assert!(
5475 !blocks
5476 .iter()
5477 .any(|block| matches!(block, ContentBlock::Image(_)))
5478 );
5479 assert!(matches!(
5480 blocks.first(),
5481 Some(ContentBlock::Text(TextContent { text, .. })) if text == BLOCK_IMAGES_PLACEHOLDER
5482 ));
5483 assert!(matches!(
5484 blocks.get(1),
5485 Some(ContentBlock::Text(TextContent { text, .. })) if text == "tail"
5486 ));
5487 assert!(matches!(
5488 blocks.get(2),
5489 Some(ContentBlock::Text(TextContent { text, .. })) if text == BLOCK_IMAGES_PLACEHOLDER
5490 ));
5491 }
5492
5493 #[test]
5494 fn filter_images_for_provider_filters_images_from_all_block_message_types() {
5495 let mut messages = vec![
5496 Message::User(UserMessage {
5497 content: UserContent::Blocks(vec![
5498 ContentBlock::Text(TextContent::new("hello")),
5499 sample_image_block(),
5500 ]),
5501 timestamp: 0,
5502 }),
5503 Message::Assistant(Arc::new(AssistantMessage {
5504 content: vec![sample_image_block()],
5505 api: "test".to_string(),
5506 provider: "test".to_string(),
5507 model: "test".to_string(),
5508 usage: Usage::default(),
5509 stop_reason: StopReason::Stop,
5510 error_message: None,
5511 timestamp: 0,
5512 })),
5513 Message::tool_result(ToolResultMessage {
5514 tool_call_id: "tc1".to_string(),
5515 tool_name: "read".to_string(),
5516 content: vec![
5517 sample_image_block(),
5518 ContentBlock::Text(TextContent::new("ok")),
5519 ],
5520 details: None,
5521 is_error: false,
5522 timestamp: 0,
5523 }),
5524 ];
5525
5526 let stats = filter_images_for_provider(&mut messages);
5527
5528 assert_eq!(stats.removed_images, 3);
5529 assert_eq!(stats.affected_messages, 3);
5530 assert_eq!(
5531 messages.iter().map(image_count_in_message).sum::<usize>(),
5532 0,
5533 "no images should remain in provider-bound context"
5534 );
5535 }
5536
5537 #[test]
5538 fn build_context_strips_images_when_block_images_enabled() {
5539 let mut agent = Agent::new(
5540 Arc::new(SilentProvider),
5541 ToolRegistry::new(&[], Path::new("."), None),
5542 AgentConfig {
5543 system_prompt: None,
5544 max_tool_iterations: 50,
5545 stream_options: StreamOptions::default(),
5546 block_images: true,
5547 },
5548 );
5549 agent.add_message(Message::User(UserMessage {
5550 content: UserContent::Blocks(vec![sample_image_block()]),
5551 timestamp: 0,
5552 }));
5553
5554 let context = agent.build_context();
5555 assert_eq!(context.messages.len(), 1);
5556 assert_eq!(image_count_in_message(&context.messages[0]), 0);
5557 assert!(matches!(
5558 &context.messages[0],
5559 Message::User(UserMessage {
5560 content: UserContent::Blocks(blocks),
5561 ..
5562 }) if blocks
5563 .iter()
5564 .any(|block| matches!(block, ContentBlock::Text(TextContent { text, .. }) if text == BLOCK_IMAGES_PLACEHOLDER))
5565 ));
5566 }
5567
5568 #[test]
5569 fn build_context_keeps_images_when_block_images_disabled() {
5570 let mut agent = Agent::new(
5571 Arc::new(SilentProvider),
5572 ToolRegistry::new(&[], Path::new("."), None),
5573 AgentConfig {
5574 system_prompt: None,
5575 max_tool_iterations: 50,
5576 stream_options: StreamOptions::default(),
5577 block_images: false,
5578 },
5579 );
5580 agent.add_message(Message::User(UserMessage {
5581 content: UserContent::Blocks(vec![sample_image_block()]),
5582 timestamp: 0,
5583 }));
5584
5585 let context = agent.build_context();
5586 assert_eq!(context.messages.len(), 1);
5587 assert_eq!(image_count_in_message(&context.messages[0]), 1);
5588 }
5589
5590 #[test]
5591 fn auto_compaction_start_serializes_with_pi_mono_compatible_type_tag() {
5592 let event = AgentEvent::AutoCompactionStart {
5593 reason: "threshold".to_string(),
5594 };
5595 let json = serde_json::to_value(&event).unwrap();
5596 assert_eq!(json["type"], "auto_compaction_start");
5597 assert_eq!(json["reason"], "threshold");
5598 }
5599
5600 #[test]
5601 fn auto_compaction_end_serializes_with_pi_mono_compatible_fields() {
5602 let event = AgentEvent::AutoCompactionEnd {
5603 result: Some(serde_json::json!({"tokens_before": 5000, "tokens_after": 2000})),
5604 aborted: false,
5605 will_retry: false,
5606 error_message: None,
5607 };
5608 let json = serde_json::to_value(&event).unwrap();
5609 assert_eq!(json["type"], "auto_compaction_end");
5610 assert_eq!(json["aborted"], false);
5611 assert_eq!(json["willRetry"], false);
5612 assert!(json.get("errorMessage").is_none()); assert!(json["result"].is_object());
5614 }
5615
5616 #[test]
5617 fn auto_compaction_end_includes_error_message_when_present() {
5618 let event = AgentEvent::AutoCompactionEnd {
5619 result: None,
5620 aborted: true,
5621 will_retry: false,
5622 error_message: Some("Compaction failed".to_string()),
5623 };
5624 let json = serde_json::to_value(&event).unwrap();
5625 assert_eq!(json["type"], "auto_compaction_end");
5626 assert_eq!(json["aborted"], true);
5627 assert_eq!(json["errorMessage"], "Compaction failed");
5628 }
5629
5630 #[test]
5631 fn auto_retry_start_serializes_with_camel_case_fields() {
5632 let event = AgentEvent::AutoRetryStart {
5633 attempt: 1,
5634 max_attempts: 3,
5635 delay_ms: 2000,
5636 error_message: "Rate limited".to_string(),
5637 };
5638 let json = serde_json::to_value(&event).unwrap();
5639 assert_eq!(json["type"], "auto_retry_start");
5640 assert_eq!(json["attempt"], 1);
5641 assert_eq!(json["maxAttempts"], 3);
5642 assert_eq!(json["delayMs"], 2000);
5643 assert_eq!(json["errorMessage"], "Rate limited");
5644 }
5645
5646 #[test]
5647 fn auto_retry_end_serializes_success_and_omits_null_final_error() {
5648 let event = AgentEvent::AutoRetryEnd {
5649 success: true,
5650 attempt: 2,
5651 final_error: None,
5652 };
5653 let json = serde_json::to_value(&event).unwrap();
5654 assert_eq!(json["type"], "auto_retry_end");
5655 assert_eq!(json["success"], true);
5656 assert_eq!(json["attempt"], 2);
5657 assert!(json.get("finalError").is_none());
5658 }
5659
5660 #[test]
5661 fn auto_retry_end_includes_final_error_on_failure() {
5662 let event = AgentEvent::AutoRetryEnd {
5663 success: false,
5664 attempt: 3,
5665 final_error: Some("Max retries exceeded".to_string()),
5666 };
5667 let json = serde_json::to_value(&event).unwrap();
5668 assert_eq!(json["type"], "auto_retry_end");
5669 assert_eq!(json["success"], false);
5670 assert_eq!(json["attempt"], 3);
5671 assert_eq!(json["finalError"], "Max retries exceeded");
5672 }
5673
5674 #[test]
5675 fn message_queue_push_increments_seq_and_counts_both_queues() {
5676 let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
5677 assert_eq!(queue.pending_count(), 0);
5678
5679 assert_eq!(queue.push_steering(user_message("s1")), 0);
5680 assert_eq!(queue.push_follow_up(user_message("f1")), 1);
5681 assert_eq!(queue.push_steering(user_message("s2")), 2);
5682
5683 assert_eq!(queue.pending_count(), 3);
5684 }
5685
5686 #[test]
5687 fn message_queue_pop_steering_one_at_a_time_preserves_order() {
5688 let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
5689 queue.push_steering(user_message("s1"));
5690 queue.push_steering(user_message("s2"));
5691
5692 let first = queue.pop_steering();
5693 assert_eq!(first.len(), 1);
5694 assert_user_text(&first[0], "s1");
5695 assert_eq!(queue.pending_count(), 1);
5696
5697 let second = queue.pop_steering();
5698 assert_eq!(second.len(), 1);
5699 assert_user_text(&second[0], "s2");
5700 assert_eq!(queue.pending_count(), 0);
5701
5702 let empty = queue.pop_steering();
5703 assert!(empty.is_empty());
5704 }
5705
5706 #[test]
5707 fn message_queue_pop_respects_queue_modes_per_kind() {
5708 let mut queue = MessageQueue::new(QueueMode::All, QueueMode::OneAtATime);
5709 queue.push_steering(user_message("s1"));
5710 queue.push_steering(user_message("s2"));
5711 queue.push_follow_up(user_message("f1"));
5712 queue.push_follow_up(user_message("f2"));
5713
5714 let steering = queue.pop_steering();
5715 assert_eq!(steering.len(), 2);
5716 assert_user_text(&steering[0], "s1");
5717 assert_user_text(&steering[1], "s2");
5718 assert_eq!(queue.pending_count(), 2);
5719
5720 let follow_up = queue.pop_follow_up();
5721 assert_eq!(follow_up.len(), 1);
5722 assert_user_text(&follow_up[0], "f1");
5723 assert_eq!(queue.pending_count(), 1);
5724
5725 let follow_up = queue.pop_follow_up();
5726 assert_eq!(follow_up.len(), 1);
5727 assert_user_text(&follow_up[0], "f2");
5728 assert_eq!(queue.pending_count(), 0);
5729 }
5730
5731 #[test]
5732 fn message_queue_set_modes_applies_to_existing_messages() {
5733 let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
5734 queue.push_steering(user_message("s1"));
5735 queue.push_steering(user_message("s2"));
5736
5737 let first = queue.pop_steering();
5738 assert_eq!(first.len(), 1);
5739 assert_user_text(&first[0], "s1");
5740
5741 queue.set_modes(QueueMode::All, QueueMode::OneAtATime);
5742 let remaining = queue.pop_steering();
5743 assert_eq!(remaining.len(), 1);
5744 assert_user_text(&remaining[0], "s2");
5745 }
5746
5747 fn build_switch_test_session(auth: &AuthStorage) -> AgentSession {
5748 let registry = ModelRegistry::load(auth, None);
5749 let current_entry = registry
5750 .find("anthropic", "claude-sonnet-4-5")
5751 .expect("anthropic model in registry");
5752 let provider = crate::providers::create_provider(¤t_entry, None)
5753 .expect("create anthropic provider");
5754 let tools = ToolRegistry::new(&[], Path::new("."), None);
5755 let mut stream_options = StreamOptions {
5756 api_key: Some("stale-key".to_string()),
5757 ..Default::default()
5758 };
5759 let _ = stream_options
5760 .headers
5761 .insert("x-stale-header".to_string(), "stale-value".to_string());
5762 let agent = Agent::new(
5763 provider,
5764 tools,
5765 AgentConfig {
5766 system_prompt: None,
5767 max_tool_iterations: 50,
5768 stream_options,
5769 block_images: false,
5770 },
5771 );
5772
5773 let mut session = Session::in_memory();
5774 session.header.provider = Some("openai".to_string());
5775 session.header.model_id = Some("gpt-4o".to_string());
5776
5777 let mut agent_session = AgentSession::new(
5778 agent,
5779 Arc::new(Mutex::new(session)),
5780 false,
5781 ResolvedCompactionSettings::default(),
5782 );
5783 agent_session.set_model_registry(registry);
5784 agent_session.set_auth_storage(auth.clone());
5785 agent_session
5786 }
5787
5788 #[test]
5789 fn apply_session_model_selection_updates_stream_credentials_and_headers() {
5790 let dir = tempfile::tempdir().expect("tempdir");
5791 let auth_path = dir.path().join("auth.json");
5792 let mut auth = AuthStorage::load(auth_path).expect("load auth");
5793 auth.set(
5794 "anthropic",
5795 AuthCredential::ApiKey {
5796 key: "anthropic-key".to_string(),
5797 },
5798 );
5799 auth.set(
5800 "openai",
5801 AuthCredential::ApiKey {
5802 key: "openai-key".to_string(),
5803 },
5804 );
5805
5806 let mut agent_session = build_switch_test_session(&auth);
5807 agent_session.apply_session_model_selection("openai", "gpt-4o");
5808
5809 assert_eq!(agent_session.agent.provider().name(), "openai");
5810 assert_eq!(agent_session.agent.provider().model_id(), "gpt-4o");
5811 assert_eq!(
5812 agent_session.agent.stream_options().api_key.as_deref(),
5813 Some("openai-key")
5814 );
5815 assert!(
5816 agent_session.agent.stream_options().headers.is_empty(),
5817 "stream headers should be refreshed from selected model entry"
5818 );
5819 }
5820
5821 #[test]
5822 fn apply_session_model_selection_clears_stale_key_when_target_has_no_key() {
5823 let dir = tempfile::tempdir().expect("tempdir");
5824 let auth_path = dir.path().join("auth.json");
5825 let mut auth = AuthStorage::load(auth_path).expect("load auth");
5826 auth.set(
5827 "anthropic",
5828 AuthCredential::ApiKey {
5829 key: "anthropic-key".to_string(),
5830 },
5831 );
5832
5833 let mut agent_session = build_switch_test_session(&auth);
5834 agent_session.apply_session_model_selection("openai", "gpt-4o");
5835
5836 assert_eq!(agent_session.agent.provider().name(), "openai");
5837 assert_eq!(
5838 agent_session.agent.stream_options().api_key,
5839 None,
5840 "stale key must be cleared when target model has no configured key"
5841 );
5842 }
5843
5844 #[test]
5845 fn apply_session_model_selection_treats_blank_model_key_as_missing() {
5846 let dir = tempfile::tempdir().expect("tempdir");
5847 let auth_path = dir.path().join("auth.json");
5848 let auth = AuthStorage::load(auth_path).expect("load auth");
5849
5850 let mut registry = ModelRegistry::load(&auth, None);
5851 registry.merge_entries(vec![ModelEntry {
5852 model: Model {
5853 id: "blank-model".to_string(),
5854 name: "Blank Model".to_string(),
5855 api: "openai-completions".to_string(),
5856 provider: "acme".to_string(),
5857 base_url: "https://example.invalid/v1".to_string(),
5858 reasoning: true,
5859 input: vec![InputType::Text],
5860 cost: ModelCost {
5861 input: 0.0,
5862 output: 0.0,
5863 cache_read: 0.0,
5864 cache_write: 0.0,
5865 },
5866 context_window: 128_000,
5867 max_tokens: 8_192,
5868 headers: HashMap::new(),
5869 },
5870 api_key: Some(" ".to_string()),
5871 headers: HashMap::new(),
5872 auth_header: true,
5873 compat: None,
5874 oauth_config: None,
5875 }]);
5876
5877 let mut agent_session = build_switch_test_session(&auth);
5878 agent_session.set_model_registry(registry);
5879 agent_session.apply_session_model_selection("acme", "blank-model");
5880
5881 assert_eq!(agent_session.agent.provider().name(), "acme");
5882 assert_eq!(
5883 agent_session.agent.stream_options().api_key,
5884 None,
5885 "blank model keys must not be treated as valid credentials"
5886 );
5887 }
5888
5889 #[test]
5890 fn auto_compaction_start_serializes_to_pi_mono_format() {
5891 let event = AgentEvent::AutoCompactionStart {
5892 reason: "threshold".to_string(),
5893 };
5894 let json = serde_json::to_value(&event).unwrap();
5895 assert_eq!(json["type"], "auto_compaction_start");
5896 assert_eq!(json["reason"], "threshold");
5897 }
5898
5899 #[test]
5900 fn auto_compaction_end_serializes_to_pi_mono_format() {
5901 let event = AgentEvent::AutoCompactionEnd {
5902 result: Some(serde_json::json!({
5903 "summary": "Compacted",
5904 "firstKeptEntryId": "abc123",
5905 "tokensBefore": 50000,
5906 "details": { "readFiles": [], "modifiedFiles": [] }
5907 })),
5908 aborted: false,
5909 will_retry: true,
5910 error_message: None,
5911 };
5912 let json = serde_json::to_value(&event).unwrap();
5913 assert_eq!(json["type"], "auto_compaction_end");
5914 assert!(json["result"].is_object());
5915 assert_eq!(json["aborted"], false);
5916 assert_eq!(json["willRetry"], true);
5917 assert!(json.get("errorMessage").is_none());
5918 }
5919
5920 #[test]
5921 fn auto_compaction_end_with_error_serializes_error_message() {
5922 let event = AgentEvent::AutoCompactionEnd {
5923 result: None,
5924 aborted: false,
5925 will_retry: false,
5926 error_message: Some("compaction failed".to_string()),
5927 };
5928 let json = serde_json::to_value(&event).unwrap();
5929 assert_eq!(json["type"], "auto_compaction_end");
5930 assert!(json["result"].is_null());
5931 assert_eq!(json["errorMessage"], "compaction failed");
5932 }
5933
5934 #[test]
5935 fn auto_retry_start_serializes_to_pi_mono_format() {
5936 let event = AgentEvent::AutoRetryStart {
5937 attempt: 2,
5938 max_attempts: 3,
5939 delay_ms: 4000,
5940 error_message: "rate limited".to_string(),
5941 };
5942 let json = serde_json::to_value(&event).unwrap();
5943 assert_eq!(json["type"], "auto_retry_start");
5944 assert_eq!(json["attempt"], 2);
5945 assert_eq!(json["maxAttempts"], 3);
5946 assert_eq!(json["delayMs"], 4000);
5947 assert_eq!(json["errorMessage"], "rate limited");
5948 }
5949
5950 #[test]
5951 fn auto_retry_end_success_serializes_to_pi_mono_format() {
5952 let event = AgentEvent::AutoRetryEnd {
5953 success: true,
5954 attempt: 2,
5955 final_error: None,
5956 };
5957 let json = serde_json::to_value(&event).unwrap();
5958 assert_eq!(json["type"], "auto_retry_end");
5959 assert_eq!(json["success"], true);
5960 assert_eq!(json["attempt"], 2);
5961 assert!(json.get("finalError").is_none());
5962 }
5963
5964 #[test]
5965 fn auto_retry_end_failure_serializes_final_error() {
5966 let event = AgentEvent::AutoRetryEnd {
5967 success: false,
5968 attempt: 3,
5969 final_error: Some("max retries exceeded".to_string()),
5970 };
5971 let json = serde_json::to_value(&event).unwrap();
5972 assert_eq!(json["type"], "auto_retry_end");
5973 assert_eq!(json["success"], false);
5974 assert_eq!(json["attempt"], 3);
5975 assert_eq!(json["finalError"], "max retries exceeded");
5976 }
5977}