Skip to main content

pi/
agent.rs

1//! Agent runtime - the core orchestration loop.
2//!
3//! The agent coordinates between:
4//! - Provider: Makes LLM API calls
5//! - Tools: Executes tool calls from the assistant
6//! - Session: Persists conversation history
7//!
8//! The main loop:
9//! 1. Receive user input
10//! 2. Build context (system prompt + history + tools)
11//! 3. Stream completion from provider
12//! 4. If tool calls: execute tools, append results, goto 3
13//! 5. If done: return final message
14
15use crate::auth::AuthStorage;
16use crate::compaction::{self, ResolvedCompactionSettings};
17use crate::compaction_worker::{CompactionQuota, CompactionWorkerState};
18use crate::error::{Error, Result};
19use crate::extension_events::{InputEventOutcome, apply_input_event_response};
20use crate::extension_tools::collect_extension_tool_wrappers;
21use crate::extensions::{
22    EXTENSION_EVENT_TIMEOUT_MS, ExtensionDeliverAs, ExtensionEventName, ExtensionHostActions,
23    ExtensionLoadSpec, ExtensionManager, ExtensionPolicy, ExtensionRegion, ExtensionRuntimeHandle,
24    ExtensionSendMessage, ExtensionSendUserMessage, JsExtensionLoadSpec, JsExtensionRuntimeHandle,
25    NativeRustExtensionLoadSpec, NativeRustExtensionRuntimeHandle, RepairPolicyMode,
26    resolve_extension_load_spec,
27};
28#[cfg(feature = "wasm-host")]
29use crate::extensions::{WasmExtensionHost, WasmExtensionLoadSpec};
30use crate::extensions_js::{PiJsRuntimeConfig, RepairMode};
31use crate::model::{
32    AssistantMessage, AssistantMessageEvent, ContentBlock, CustomMessage, ImageContent, Message,
33    StopReason, StreamEvent, TextContent, ThinkingContent, ToolCall, ToolResultMessage, Usage,
34    UserContent, UserMessage,
35};
36use crate::models::{ModelEntry, ModelRegistry};
37use crate::provider::{Context, Provider, StreamOptions, ToolDef};
38use crate::session::{AutosaveFlushTrigger, Session, SessionHandle};
39use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
40use asupersync::sync::{Mutex, Notify};
41use async_trait::async_trait;
42use chrono::Utc;
43use futures::FutureExt;
44use futures::StreamExt;
45use futures::future::BoxFuture;
46use futures::stream;
47use serde::Serialize;
48use serde_json::{Value, json};
49use std::borrow::Cow;
50use std::collections::VecDeque;
51use std::sync::Arc;
52use std::sync::Mutex as StdMutex;
53use std::sync::atomic::{AtomicBool, Ordering};
54
55const MAX_CONCURRENT_TOOLS: usize = 8;
56
57// ============================================================================
58// Agent Configuration
59// ============================================================================
60
/// Configuration for the agent.
///
/// Cloneable bundle of settings the agent reads when it builds provider
/// context and drives the tool loop.
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// System prompt to use for all requests.
    ///
    /// When `None`, the built context carries no system prompt.
    pub system_prompt: Option<String>,

    /// Maximum tool call iterations before stopping.
    ///
    /// The run loop counts one iteration per assistant response that contains
    /// tool calls and ends the run with an error result once the count
    /// exceeds this value.
    pub max_tool_iterations: usize,

    /// Default stream options.
    ///
    /// The run loop also reads `session_id` from here to label emitted events.
    pub stream_options: StreamOptions,

    /// Strip image blocks before sending context to providers.
    pub block_images: bool,
}
76
77impl Default for AgentConfig {
78    fn default() -> Self {
79        Self {
80            system_prompt: None,
81            max_tool_iterations: 50,
82            stream_options: StreamOptions::default(),
83            block_images: false,
84        }
85    }
86}
87
/// Async fetcher for queued messages (steering or follow-up).
///
/// A shared, thread-safe callback; each call returns a boxed `'static` future
/// that resolves to a batch of messages.
pub type MessageFetcher = Arc<dyn Fn() -> BoxFuture<'static, Vec<Message>> + Send + Sync + 'static>;

/// Shared, thread-safe callback that receives [`AgentEvent`]s emitted during a run.
type AgentEventHandler = Arc<dyn Fn(AgentEvent) + Send + Sync + 'static>;
92
/// How many queued messages are handed out per delivery.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueueMode {
    /// Deliver every queued message at once.
    All,
    /// Deliver only the oldest queued message.
    OneAtATime,
}

impl QueueMode {
    /// Stable string name of the mode (`"all"` or `"one-at-a-time"`).
    pub const fn as_str(self) -> &'static str {
        const ALL: &str = "all";
        const ONE_AT_A_TIME: &str = "one-at-a-time";

        match self {
            Self::OneAtATime => ONE_AT_A_TIME,
            Self::All => ALL,
        }
    }
}
107
/// Selects which of the two internal queues an operation targets.
#[derive(Debug, Clone, Copy)]
enum QueueKind {
    /// The steering queue.
    Steering,
    /// The follow-up queue.
    FollowUp,
}
113
/// A message waiting in a [`MessageQueue`], with bookkeeping set at enqueue time.
#[derive(Debug, Clone)]
struct QueuedMessage {
    /// Sequence number assigned by the queue when the message was pushed.
    seq: u64,
    /// Enqueue time from `Utc::now().timestamp_millis()`.
    enqueued_at: i64,
    /// The queued message itself.
    message: Message,
}
120
/// FIFO storage for steering and follow-up messages, each with its own delivery mode.
#[derive(Debug)]
struct MessageQueue {
    /// Pending steering messages, oldest first.
    steering: VecDeque<QueuedMessage>,
    /// Pending follow-up messages, oldest first.
    follow_up: VecDeque<QueuedMessage>,
    /// Delivery mode applied when popping steering messages.
    steering_mode: QueueMode,
    /// Delivery mode applied when popping follow-up messages.
    follow_up_mode: QueueMode,
    /// Sequence number handed to the next pushed message (shared by both queues).
    next_seq: u64,
}
129
130impl MessageQueue {
131    const fn new(steering_mode: QueueMode, follow_up_mode: QueueMode) -> Self {
132        Self {
133            steering: VecDeque::new(),
134            follow_up: VecDeque::new(),
135            steering_mode,
136            follow_up_mode,
137            next_seq: 0,
138        }
139    }
140
141    const fn set_modes(&mut self, steering_mode: QueueMode, follow_up_mode: QueueMode) {
142        self.steering_mode = steering_mode;
143        self.follow_up_mode = follow_up_mode;
144    }
145
146    fn pending_count(&self) -> usize {
147        self.steering.len() + self.follow_up.len()
148    }
149
150    fn push(&mut self, kind: QueueKind, message: Message) -> u64 {
151        let seq = self.next_seq;
152        self.next_seq = self.next_seq.saturating_add(1);
153        let entry = QueuedMessage {
154            seq,
155            enqueued_at: Utc::now().timestamp_millis(),
156            message,
157        };
158        match kind {
159            QueueKind::Steering => self.steering.push_back(entry),
160            QueueKind::FollowUp => self.follow_up.push_back(entry),
161        }
162        seq
163    }
164
165    fn push_steering(&mut self, message: Message) -> u64 {
166        self.push(QueueKind::Steering, message)
167    }
168
169    fn push_follow_up(&mut self, message: Message) -> u64 {
170        self.push(QueueKind::FollowUp, message)
171    }
172
173    fn pop_steering(&mut self) -> Vec<Message> {
174        self.pop_kind(QueueKind::Steering)
175    }
176
177    fn pop_follow_up(&mut self) -> Vec<Message> {
178        self.pop_kind(QueueKind::FollowUp)
179    }
180
181    fn pop_kind(&mut self, kind: QueueKind) -> Vec<Message> {
182        let (queue, mode) = match kind {
183            QueueKind::Steering => (&mut self.steering, self.steering_mode),
184            QueueKind::FollowUp => (&mut self.follow_up, self.follow_up_mode),
185        };
186
187        match mode {
188            QueueMode::All => queue.drain(..).map(|entry| entry.message).collect(),
189            QueueMode::OneAtATime => queue
190                .pop_front()
191                .into_iter()
192                .map(|entry| entry.message)
193                .collect(),
194        }
195    }
196}
197
198// ============================================================================
199// Agent Event
200// ============================================================================
201
/// Events emitted by the agent during execution.
///
/// Serialized as an internally tagged enum: the variant name is written to a
/// `type` field in snake_case, while multi-word payload fields are renamed to
/// camelCase through the per-field `serde(rename)` attributes below.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AgentEvent {
    /// Agent lifecycle start.
    AgentStart {
        #[serde(rename = "sessionId")]
        session_id: Arc<str>,
    },
    /// Agent lifecycle end with all new messages.
    ///
    /// `error` is omitted from the serialized form when it is `None`.
    AgentEnd {
        #[serde(rename = "sessionId")]
        session_id: Arc<str>,
        messages: Vec<Message>,
        #[serde(skip_serializing_if = "Option::is_none")]
        error: Option<String>,
    },
    /// Turn lifecycle start (assistant response + tool calls).
    TurnStart {
        #[serde(rename = "sessionId")]
        session_id: Arc<str>,
        #[serde(rename = "turnIndex")]
        turn_index: usize,
        timestamp: i64,
    },
    /// Turn lifecycle end with tool results.
    TurnEnd {
        #[serde(rename = "sessionId")]
        session_id: Arc<str>,
        #[serde(rename = "turnIndex")]
        turn_index: usize,
        message: Message,
        #[serde(rename = "toolResults")]
        tool_results: Vec<Message>,
    },
    /// Message lifecycle start (user, assistant, or tool result).
    MessageStart { message: Message },
    /// Message update (assistant streaming).
    MessageUpdate {
        message: Message,
        #[serde(rename = "assistantMessageEvent")]
        assistant_message_event: AssistantMessageEvent,
    },
    /// Message lifecycle end.
    MessageEnd { message: Message },
    /// Tool execution start.
    ToolExecutionStart {
        #[serde(rename = "toolCallId")]
        tool_call_id: String,
        #[serde(rename = "toolName")]
        tool_name: String,
        args: serde_json::Value,
    },
    /// Tool execution update.
    ToolExecutionUpdate {
        #[serde(rename = "toolCallId")]
        tool_call_id: String,
        #[serde(rename = "toolName")]
        tool_name: String,
        args: serde_json::Value,
        #[serde(rename = "partialResult")]
        partial_result: ToolOutput,
    },
    /// Tool execution end.
    ToolExecutionEnd {
        #[serde(rename = "toolCallId")]
        tool_call_id: String,
        #[serde(rename = "toolName")]
        tool_name: String,
        result: ToolOutput,
        #[serde(rename = "isError")]
        is_error: bool,
    },
    /// Auto-compaction lifecycle start.
    AutoCompactionStart { reason: String },
    /// Auto-compaction lifecycle end.
    ///
    /// `result` and `error_message` are omitted from the serialized form when `None`.
    AutoCompactionEnd {
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<serde_json::Value>,
        aborted: bool,
        #[serde(rename = "willRetry")]
        will_retry: bool,
        #[serde(rename = "errorMessage", skip_serializing_if = "Option::is_none")]
        error_message: Option<String>,
    },
    /// Auto-retry lifecycle start.
    AutoRetryStart {
        attempt: u32,
        #[serde(rename = "maxAttempts")]
        max_attempts: u32,
        #[serde(rename = "delayMs")]
        delay_ms: u64,
        #[serde(rename = "errorMessage")]
        error_message: String,
    },
    /// Auto-retry lifecycle end.
    ///
    /// `final_error` is omitted from the serialized form when `None`.
    AutoRetryEnd {
        success: bool,
        attempt: u32,
        #[serde(rename = "finalError", skip_serializing_if = "Option::is_none")]
        final_error: Option<String>,
    },
    /// Extension error during event dispatch or execution.
    ///
    /// `extension_id` is omitted from the serialized form when `None`.
    ExtensionError {
        #[serde(rename = "extensionId", skip_serializing_if = "Option::is_none")]
        extension_id: Option<String>,
        event: String,
        error: String,
    },
}
312
313// ============================================================================
314// Agent
315// ============================================================================
316
/// Handle to request an abort of an in-flight agent run.
///
/// Cloning yields another handle to the same shared abort state.
#[derive(Debug, Clone)]
pub struct AbortHandle {
    /// State shared with the paired [`AbortSignal`].
    inner: Arc<AbortSignalInner>,
}
322
/// Signal for observing abort requests.
///
/// Cloning yields another observer of the same shared abort state.
#[derive(Debug, Clone)]
pub struct AbortSignal {
    /// State shared with the paired [`AbortHandle`].
    inner: Arc<AbortSignalInner>,
}
328
/// Abort state shared between an [`AbortHandle`] and its [`AbortSignal`]s.
#[derive(Debug)]
struct AbortSignalInner {
    /// Set to `true` once an abort has been requested; never reset in this file's visible code.
    aborted: AtomicBool,
    /// Used to wake tasks waiting for an abort.
    notify: Notify,
}
334
335impl AbortHandle {
336    /// Create a new abort handle + signal pair.
337    #[must_use]
338    pub fn new() -> (Self, AbortSignal) {
339        let inner = Arc::new(AbortSignalInner {
340            aborted: AtomicBool::new(false),
341            notify: Notify::new(),
342        });
343        (
344            Self {
345                inner: Arc::clone(&inner),
346            },
347            AbortSignal { inner },
348        )
349    }
350
351    /// Trigger an abort.
352    pub fn abort(&self) {
353        if !self.inner.aborted.swap(true, Ordering::SeqCst) {
354            self.inner.notify.notify_waiters();
355        }
356    }
357}
358
359impl AbortSignal {
360    /// Check if an abort has already been requested.
361    #[must_use]
362    pub fn is_aborted(&self) -> bool {
363        self.inner.aborted.load(Ordering::SeqCst)
364    }
365
366    pub async fn wait(&self) {
367        if self.is_aborted() {
368            return;
369        }
370
371        loop {
372            self.inner.notify.notified().await;
373            if self.is_aborted() {
374                return;
375            }
376        }
377    }
378}
379
/// The agent runtime that orchestrates LLM calls and tool execution.
///
/// Owns the in-memory conversation history, the tool registry, and the
/// internal steering/follow-up queues used by the run loop.
pub struct Agent {
    /// The LLM provider.
    provider: Arc<dyn Provider>,

    /// Tool registry.
    tools: ToolRegistry,

    /// Agent configuration.
    config: AgentConfig,

    /// Optional extension manager for tool/event hooks. `Agent::new` leaves this `None`.
    extensions: Option<ExtensionManager>,

    /// Message history.
    messages: Vec<Message>,

    /// Fetchers for queued steering messages (interrupts).
    steering_fetchers: Vec<MessageFetcher>,

    /// Fetchers for queued follow-up messages (idle).
    follow_up_fetchers: Vec<MessageFetcher>,

    /// Internal queue for steering/follow-up messages.
    message_queue: MessageQueue,

    /// Cached tool definitions. Invalidated when tools change via `extend_tools`.
    ///
    /// `None` means "not built yet"; `build_context` fills it on demand.
    cached_tool_defs: Option<Vec<ToolDef>>,
}
409
410impl Agent {
411    /// Create a new agent with the given provider and tools.
412    pub fn new(provider: Arc<dyn Provider>, tools: ToolRegistry, config: AgentConfig) -> Self {
413        Self {
414            provider,
415            tools,
416            config,
417            extensions: None,
418            messages: Vec::new(),
419            steering_fetchers: Vec::new(),
420            follow_up_fetchers: Vec::new(),
421            message_queue: MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime),
422            cached_tool_defs: None,
423        }
424    }
425
    /// Get the current message history.
    ///
    /// Returns a borrowed view of the in-memory history, oldest message first.
    #[must_use]
    pub fn messages(&self) -> &[Message] {
        &self.messages
    }
431
    /// Clear the message history.
    ///
    /// Only the in-memory history is emptied; the internal steering/follow-up
    /// queue is not touched.
    pub fn clear_messages(&mut self) {
        self.messages.clear();
    }
436
    /// Add a message to the history.
    ///
    /// The message is appended at the end; no event is emitted.
    pub fn add_message(&mut self, message: Message) {
        self.messages.push(message);
    }
441
    /// Replace the message history.
    ///
    /// The previous history is dropped and `messages` becomes the new history as-is.
    pub fn replace_messages(&mut self, messages: Vec<Message>) {
        self.messages = messages;
    }
446
    /// Replace the provider implementation (used for model/provider switching).
    ///
    /// History, tools, and cached tool definitions are left unchanged.
    pub fn set_provider(&mut self, provider: Arc<dyn Provider>) {
        self.provider = provider;
    }
451
452    /// Register async fetchers for queued steering/follow-up messages.
453    ///
454    /// This is additive: multiple sources (e.g. RPC, extensions) can register
455    /// fetchers, and the agent will poll all of them.
456    pub fn register_message_fetchers(
457        &mut self,
458        steering: Option<MessageFetcher>,
459        follow_up: Option<MessageFetcher>,
460    ) {
461        if let Some(fetcher) = steering {
462            self.steering_fetchers.push(fetcher);
463        }
464        if let Some(fetcher) = follow_up {
465            self.follow_up_fetchers.push(fetcher);
466        }
467    }
468
    /// Extend the tool registry with additional tools (e.g. extension-registered tools).
    ///
    /// Also drops the cached tool definitions so the next `build_context`
    /// call rebuilds them from the updated registry.
    pub fn extend_tools<I>(&mut self, tools: I)
    where
        I: IntoIterator<Item = Box<dyn Tool>>,
    {
        self.tools.extend(tools);
        self.cached_tool_defs = None; // Invalidate cache when tools change
    }
477
    /// Queue a steering message (delivered after tool completion).
    ///
    /// Returns the sequence number assigned to the queued message.
    pub fn queue_steering(&mut self, message: Message) -> u64 {
        self.message_queue.push_steering(message)
    }
482
    /// Queue a follow-up message (delivered when agent becomes idle).
    ///
    /// Returns the sequence number assigned to the queued message.
    pub fn queue_follow_up(&mut self, message: Message) -> u64 {
        self.message_queue.push_follow_up(message)
    }
487
    /// Configure queue delivery modes.
    ///
    /// Sets the steering and follow-up modes on the internal queue; messages
    /// that are already queued stay queued.
    pub const fn set_queue_modes(&mut self, steering: QueueMode, follow_up: QueueMode) {
        self.message_queue.set_modes(steering, follow_up);
    }
492
    /// Count queued messages (steering + follow-up).
    ///
    /// Counts only the agent's internal queue, not messages held by registered fetchers.
    #[must_use]
    pub fn queued_message_count(&self) -> usize {
        self.message_queue.pending_count()
    }
498
    /// Returns a new shared handle (`Arc` clone) to the current provider.
    pub fn provider(&self) -> Arc<dyn Provider> {
        Arc::clone(&self.provider)
    }
502
    /// Borrows the stream options stored in the agent configuration.
    pub const fn stream_options(&self) -> &StreamOptions {
        &self.config.stream_options
    }
506
    /// Mutably borrows the stream options stored in the agent configuration.
    pub const fn stream_options_mut(&mut self) -> &mut StreamOptions {
        &mut self.config.stream_options
    }
510
511    /// Build context for a completion request.
512    fn build_context(&mut self) -> Context<'_> {
513        let messages: Cow<'_, [Message]> = if self.config.block_images {
514            let mut msgs = self.messages.clone();
515            // Filter out hidden custom messages.
516            msgs.retain(|m| match m {
517                Message::Custom(c) => c.display,
518                _ => true,
519            });
520            let stats = filter_images_for_provider(&mut msgs);
521            if stats.removed_images > 0 {
522                tracing::debug!(
523                    filtered_images = stats.removed_images,
524                    affected_messages = stats.affected_messages,
525                    "Filtered image content from outbound provider context (images.block_images=true)"
526                );
527            }
528            Cow::Owned(msgs)
529        } else {
530            // Check if we need to filter hidden custom messages to avoid cloning if not needed.
531            let has_hidden = self.messages.iter().any(|m| match m {
532                Message::Custom(c) => !c.display,
533                _ => false,
534            });
535
536            if has_hidden {
537                let mut msgs = self.messages.clone();
538                msgs.retain(|m| match m {
539                    Message::Custom(c) => c.display,
540                    _ => true,
541                });
542                Cow::Owned(msgs)
543            } else {
544                Cow::Borrowed(self.messages.as_slice())
545            }
546        };
547
548        // Borrow cached tool defs if available; otherwise build + cache + borrow.
549        if self.cached_tool_defs.is_none() {
550            let defs: Vec<ToolDef> = self
551                .tools
552                .tools()
553                .iter()
554                .map(|t| ToolDef {
555                    name: t.name().to_string(),
556                    description: t.description().to_string(),
557                    parameters: t.parameters(),
558                })
559                .collect();
560            self.cached_tool_defs = Some(defs);
561        }
562        let tools = Cow::Borrowed(self.cached_tool_defs.as_deref().unwrap());
563
564        Context {
565            system_prompt: self.config.system_prompt.as_deref().map(Cow::Borrowed),
566            messages,
567            tools,
568        }
569    }
570
    /// Run the agent with a user message.
    ///
    /// Events are delivered to `on_event` as the run progresses; the returned
    /// future resolves to the final assistant message. Equivalent to
    /// `run_with_abort` with no abort signal.
    pub async fn run(
        &mut self,
        user_input: impl Into<String>,
        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
    ) -> Result<AssistantMessage> {
        self.run_with_abort(user_input, None, on_event).await
    }
581
582    /// Run the agent with a user message and abort support.
583    pub async fn run_with_abort(
584        &mut self,
585        user_input: impl Into<String>,
586        abort: Option<AbortSignal>,
587        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
588    ) -> Result<AssistantMessage> {
589        // Add user message
590        let user_message = Message::User(UserMessage {
591            content: UserContent::Text(user_input.into()),
592            timestamp: Utc::now().timestamp_millis(),
593        });
594
595        // Run the agent loop
596        self.run_loop(vec![user_message], Arc::new(on_event), abort)
597            .await
598    }
599
    /// Run the agent with structured content (text + images).
    ///
    /// Equivalent to `run_with_content_with_abort` with no abort signal.
    pub async fn run_with_content(
        &mut self,
        content: Vec<ContentBlock>,
        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
    ) -> Result<AssistantMessage> {
        self.run_with_content_with_abort(content, None, on_event)
            .await
    }
609
610    /// Run the agent with structured content (text + images) and abort support.
611    pub async fn run_with_content_with_abort(
612        &mut self,
613        content: Vec<ContentBlock>,
614        abort: Option<AbortSignal>,
615        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
616    ) -> Result<AssistantMessage> {
617        // Add user message
618        let user_message = Message::User(UserMessage {
619            content: UserContent::Blocks(content),
620            timestamp: Utc::now().timestamp_millis(),
621        });
622
623        // Run the agent loop
624        self.run_loop(vec![user_message], Arc::new(on_event), abort)
625            .await
626    }
627
    /// Run the agent with a pre-constructed user message and abort support.
    ///
    /// `message` is passed to the agent loop unchanged as the only prompt.
    pub async fn run_with_message_with_abort(
        &mut self,
        message: Message,
        abort: Option<AbortSignal>,
        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
    ) -> Result<AssistantMessage> {
        self.run_loop(vec![message], Arc::new(on_event), abort)
            .await
    }
638
    /// Continue the agent loop without adding a new prompt message (used for retries).
    ///
    /// The loop starts with an empty prompt list, so it works from the
    /// existing history plus any pending queued messages.
    pub async fn run_continue_with_abort(
        &mut self,
        abort: Option<AbortSignal>,
        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
    ) -> Result<AssistantMessage> {
        self.run_loop(Vec::new(), Arc::new(on_event), abort).await
    }
647
648    fn build_abort_message(&self, partial: Option<&AssistantMessage>) -> AssistantMessage {
649        let mut message = partial.cloned().unwrap_or_else(|| AssistantMessage {
650            content: Vec::new(),
651            api: self.provider.api().to_string(),
652            provider: self.provider.name().to_string(),
653            model: self.provider.model_id().to_string(),
654            usage: Usage::default(),
655            stop_reason: StopReason::Aborted,
656            error_message: Some("Aborted".to_string()),
657            timestamp: Utc::now().timestamp_millis(),
658        });
659        message.stop_reason = StopReason::Aborted;
660        message.error_message = Some("Aborted".to_string());
661        message.timestamp = Utc::now().timestamp_millis();
662        message
663    }
664
665    /// The main agent loop.
666    #[allow(clippy::too_many_lines)]
667    async fn run_loop(
668        &mut self,
669        prompts: Vec<Message>,
670        on_event: AgentEventHandler,
671        abort: Option<AbortSignal>,
672    ) -> Result<AssistantMessage> {
673        let session_id: Arc<str> = self
674            .config
675            .stream_options
676            .session_id
677            .as_deref()
678            .unwrap_or("")
679            .into();
680        let mut iterations = 0usize;
681        let mut turn_index: usize = 0;
682        let mut new_messages: Vec<Message> = Vec::with_capacity(prompts.len() + 8);
683        let mut last_assistant: Option<Arc<AssistantMessage>> = None;
684
685        let agent_start_event = AgentEvent::AgentStart {
686            session_id: session_id.clone(),
687        };
688        self.dispatch_extension_lifecycle_event(&agent_start_event)
689            .await;
690        on_event(agent_start_event);
691
692        for prompt in prompts {
693            self.messages.push(prompt.clone());
694            on_event(AgentEvent::MessageStart {
695                message: prompt.clone(),
696            });
697            on_event(AgentEvent::MessageEnd {
698                message: prompt.clone(),
699            });
700            new_messages.push(prompt);
701        }
702
703        // Delivery boundary: start of turn (steering messages queued while idle).
704        let mut pending_messages = self.drain_steering_messages().await;
705
706        loop {
707            let mut has_more_tool_calls = true;
708            let mut steering_after_tools: Option<Vec<Message>> = None;
709
710            while has_more_tool_calls || !pending_messages.is_empty() {
711                let current_turn_index = turn_index;
712                let turn_start_event = AgentEvent::TurnStart {
713                    session_id: session_id.clone(),
714                    turn_index: current_turn_index,
715                    timestamp: Utc::now().timestamp_millis(),
716                };
717                self.dispatch_extension_lifecycle_event(&turn_start_event)
718                    .await;
719                on_event(turn_start_event);
720
721                for message in std::mem::take(&mut pending_messages) {
722                    self.messages.push(message.clone());
723                    on_event(AgentEvent::MessageStart {
724                        message: message.clone(),
725                    });
726                    on_event(AgentEvent::MessageEnd {
727                        message: message.clone(),
728                    });
729                    new_messages.push(message);
730                }
731
732                if abort.as_ref().is_some_and(AbortSignal::is_aborted) {
733                    let abort_message = self.build_abort_message(None);
734                    let message = Message::assistant(abort_message.clone());
735
736                    self.messages.push(message.clone());
737                    new_messages.push(message.clone());
738                    on_event(AgentEvent::MessageStart {
739                        message: message.clone(),
740                    });
741                    on_event(AgentEvent::MessageEnd {
742                        message: message.clone(),
743                    });
744
745                    let turn_end_event = AgentEvent::TurnEnd {
746                        session_id: session_id.clone(),
747                        turn_index: current_turn_index,
748                        message,
749                        tool_results: Vec::new(),
750                    };
751                    self.dispatch_extension_lifecycle_event(&turn_end_event)
752                        .await;
753                    on_event(turn_end_event);
754                    let agent_end_event = AgentEvent::AgentEnd {
755                        session_id: session_id.clone(),
756                        messages: std::mem::take(&mut new_messages),
757                        error: Some(
758                            abort_message
759                                .error_message
760                                .clone()
761                                .unwrap_or_else(|| "Aborted".to_string()),
762                        ),
763                    };
764                    self.dispatch_extension_lifecycle_event(&agent_end_event)
765                        .await;
766                    on_event(agent_end_event);
767                    return Ok(abort_message);
768                }
769
770                let assistant_message = match self
771                    .stream_assistant_response(Arc::clone(&on_event), abort.clone())
772                    .await
773                {
774                    Ok(msg) => msg,
775                    Err(err) => {
776                        let agent_end_event = AgentEvent::AgentEnd {
777                            session_id: session_id.clone(),
778                            messages: std::mem::take(&mut new_messages),
779                            error: Some(err.to_string()),
780                        };
781                        self.dispatch_extension_lifecycle_event(&agent_end_event)
782                            .await;
783                        on_event(agent_end_event);
784                        return Err(err);
785                    }
786                };
787                // Wrap in Arc once; share via Arc::clone (O(1)) instead of deep
788                // cloning the full AssistantMessage for every consumer.
789                let assistant_arc = Arc::new(assistant_message);
790                last_assistant = Some(Arc::clone(&assistant_arc));
791
792                let assistant_event_message = Message::Assistant(Arc::clone(&assistant_arc));
793                new_messages.push(assistant_event_message.clone());
794
795                if matches!(
796                    assistant_arc.stop_reason,
797                    StopReason::Error | StopReason::Aborted
798                ) {
799                    let turn_end_event = AgentEvent::TurnEnd {
800                        session_id: session_id.clone(),
801                        turn_index: current_turn_index,
802                        message: assistant_event_message.clone(),
803                        tool_results: Vec::new(),
804                    };
805                    self.dispatch_extension_lifecycle_event(&turn_end_event)
806                        .await;
807                    on_event(turn_end_event);
808                    let agent_end_event = AgentEvent::AgentEnd {
809                        session_id: session_id.clone(),
810                        messages: std::mem::take(&mut new_messages),
811                        error: assistant_arc.error_message.clone(),
812                    };
813                    self.dispatch_extension_lifecycle_event(&agent_end_event)
814                        .await;
815                    on_event(agent_end_event);
816                    return Ok(Arc::unwrap_or_clone(assistant_arc));
817                }
818
819                let tool_calls = extract_tool_calls(&assistant_arc.content);
820                has_more_tool_calls = !tool_calls.is_empty();
821
822                let mut tool_results: Vec<Arc<ToolResultMessage>> = Vec::new();
823                if has_more_tool_calls {
824                    iterations += 1;
825                    if iterations > self.config.max_tool_iterations {
826                        let error_message = format!(
827                            "Maximum tool iterations ({}) exceeded",
828                            self.config.max_tool_iterations
829                        );
830                        let mut stop_message = (*assistant_arc).clone();
831                        stop_message.stop_reason = StopReason::Error;
832                        stop_message.error_message = Some(error_message.clone());
833                        let stop_arc = Arc::new(stop_message.clone());
834                        let stop_event_message = Message::Assistant(Arc::clone(&stop_arc));
835
836                        // Keep in-memory transcript and event payloads aligned with the
837                        // error stop result returned to callers.
838                        if let Some(last @ Message::Assistant(_)) = self.messages.last_mut() {
839                            *last = stop_event_message.clone();
840                        }
841                        if let Some(last @ Message::Assistant(_)) = new_messages.last_mut() {
842                            *last = stop_event_message.clone();
843                        }
844
845                        let turn_end_event = AgentEvent::TurnEnd {
846                            session_id: session_id.clone(),
847                            turn_index: current_turn_index,
848                            message: stop_event_message,
849                            tool_results: Vec::new(),
850                        };
851                        self.dispatch_extension_lifecycle_event(&turn_end_event)
852                            .await;
853                        on_event(turn_end_event);
854
855                        let agent_end_event = AgentEvent::AgentEnd {
856                            session_id: session_id.clone(),
857                            messages: std::mem::take(&mut new_messages),
858                            error: Some(error_message),
859                        };
860                        self.dispatch_extension_lifecycle_event(&agent_end_event)
861                            .await;
862                        on_event(agent_end_event);
863
864                        return Ok(stop_message);
865                    }
866
867                    let outcome = match self
868                        .execute_tool_calls(
869                            &tool_calls,
870                            Arc::clone(&on_event),
871                            &mut new_messages,
872                            abort.clone(),
873                        )
874                        .await
875                    {
876                        Ok(outcome) => outcome,
877                        Err(err) => {
878                            let agent_end_event = AgentEvent::AgentEnd {
879                                session_id: session_id.clone(),
880                                messages: std::mem::take(&mut new_messages),
881                                error: Some(err.to_string()),
882                            };
883                            self.dispatch_extension_lifecycle_event(&agent_end_event)
884                                .await;
885                            on_event(agent_end_event);
886                            return Err(err);
887                        }
888                    };
889                    tool_results = outcome.tool_results;
890                    steering_after_tools = outcome.steering_messages;
891                }
892
893                let tool_messages = tool_results
894                    .iter()
895                    .map(|r| Message::ToolResult(Arc::clone(r)))
896                    .collect::<Vec<_>>();
897
898                let turn_end_event = AgentEvent::TurnEnd {
899                    session_id: session_id.clone(),
900                    turn_index: current_turn_index,
901                    message: assistant_event_message.clone(),
902                    tool_results: tool_messages,
903                };
904                self.dispatch_extension_lifecycle_event(&turn_end_event)
905                    .await;
906                on_event(turn_end_event);
907
908                turn_index = turn_index.saturating_add(1);
909
910                if let Some(steering) = steering_after_tools.take() {
911                    pending_messages = steering;
912                } else {
913                    // Delivery boundary: after assistant completion (no tool calls).
914                    pending_messages = self.drain_steering_messages().await;
915                }
916            }
917
918            // Delivery boundary: agent idle (after all tool calls + steering).
919            let follow_up = self.drain_follow_up_messages().await;
920            if follow_up.is_empty() {
921                break;
922            }
923            pending_messages = follow_up;
924        }
925
926        let Some(final_arc) = last_assistant else {
927            return Err(Error::api("Agent completed without assistant message"));
928        };
929
930        let agent_end_event = AgentEvent::AgentEnd {
931            session_id: session_id.clone(),
932            messages: new_messages,
933            error: None,
934        };
935        self.dispatch_extension_lifecycle_event(&agent_end_event)
936            .await;
937        on_event(agent_end_event);
938        Ok(Arc::unwrap_or_clone(final_arc))
939    }
940
941    async fn fetch_messages(&self, fetcher: Option<&MessageFetcher>) -> Vec<Message> {
942        if let Some(fetcher) = fetcher {
943            (fetcher)().await
944        } else {
945            Vec::new()
946        }
947    }
948
949    async fn dispatch_extension_lifecycle_event(&self, event: &AgentEvent) {
950        let Some(extensions) = &self.extensions else {
951            return;
952        };
953
954        let name = match event {
955            AgentEvent::AgentStart { .. } => ExtensionEventName::AgentStart,
956            AgentEvent::AgentEnd { .. } => ExtensionEventName::AgentEnd,
957            AgentEvent::TurnStart { .. } => ExtensionEventName::TurnStart,
958            AgentEvent::TurnEnd { .. } => ExtensionEventName::TurnEnd,
959            _ => return,
960        };
961
962        let payload = match serde_json::to_value(event) {
963            Ok(payload) => payload,
964            Err(err) => {
965                tracing::warn!("failed to serialize agent lifecycle event (fail-open): {err}");
966                return;
967            }
968        };
969
970        if let Err(err) = extensions.dispatch_event(name, Some(payload)).await {
971            tracing::warn!("agent lifecycle extension hook failed (fail-open): {err}");
972        }
973    }
974
975    async fn drain_steering_messages(&mut self) -> Vec<Message> {
976        for fetcher in &self.steering_fetchers {
977            let fetched = self.fetch_messages(Some(fetcher)).await;
978            for message in fetched {
979                self.message_queue.push_steering(message);
980            }
981        }
982        self.message_queue.pop_steering()
983    }
984
985    async fn drain_follow_up_messages(&mut self) -> Vec<Message> {
986        for fetcher in &self.follow_up_fetchers {
987            let fetched = self.fetch_messages(Some(fetcher)).await;
988            for message in fetched {
989                self.message_queue.push_follow_up(message);
990            }
991        }
992        self.message_queue.pop_follow_up()
993    }
994
995    /// Stream an assistant response and emit message events.
996    #[allow(clippy::too_many_lines)]
997    async fn stream_assistant_response(
998        &mut self,
999        on_event: AgentEventHandler,
1000        abort: Option<AbortSignal>,
1001    ) -> Result<AssistantMessage> {
1002        // Build context and stream completion
1003        let provider = Arc::clone(&self.provider);
1004        let stream_options = self.config.stream_options.clone();
1005        let context = self.build_context();
1006        let mut stream = provider.stream(&context, &stream_options).await?;
1007
1008        let mut added_partial = false;
1009        // Track whether we've already emitted `MessageStart` for this streaming response.
1010        // Avoids cloning the full message on every event just to re-emit a redundant start.
1011        let mut sent_start = false;
1012
1013        loop {
1014            let event_result = if let Some(signal) = abort.as_ref() {
1015                let abort_fut = signal.wait().fuse();
1016                let event_fut = stream.next().fuse();
1017                futures::pin_mut!(abort_fut, event_fut);
1018
1019                match futures::future::select(abort_fut, event_fut).await {
1020                    futures::future::Either::Left(((), _event_fut)) => {
1021                        let last_partial = if added_partial {
1022                            match self.messages.last() {
1023                                Some(Message::Assistant(a)) => Some(a.as_ref()),
1024                                _ => None,
1025                            }
1026                        } else {
1027                            None
1028                        };
1029                        let abort_arc = Arc::new(self.build_abort_message(last_partial));
1030                        if !sent_start {
1031                            on_event(AgentEvent::MessageStart {
1032                                message: Message::Assistant(Arc::clone(&abort_arc)),
1033                            });
1034                            self.messages
1035                                .push(Message::Assistant(Arc::clone(&abort_arc)));
1036                            added_partial = true;
1037                            // We do NOT set sent_start = true here because we are returning immediately,
1038                            // but setting added_partial = true prevents finalize_assistant_message from
1039                            // emitting a second MessageStart.
1040                        }
1041                        on_event(AgentEvent::MessageUpdate {
1042                            message: Message::Assistant(Arc::clone(&abort_arc)),
1043                            assistant_message_event: AssistantMessageEvent::Error {
1044                                reason: StopReason::Aborted,
1045                                error: Arc::clone(&abort_arc),
1046                            },
1047                        });
1048                        return Ok(self.finalize_assistant_message(
1049                            Arc::try_unwrap(abort_arc).unwrap_or_else(|a| (*a).clone()),
1050                            &on_event,
1051                            added_partial,
1052                        ));
1053                    }
1054                    futures::future::Either::Right((event, _abort_fut)) => event,
1055                }
1056            } else {
1057                stream.next().await
1058            };
1059
1060            let Some(event_result) = event_result else {
1061                break;
1062            };
1063            let event = event_result?;
1064
1065            match event {
1066                StreamEvent::Start { partial } => {
1067                    let shared = Arc::new(partial);
1068                    self.update_partial_message(Arc::clone(&shared), &mut added_partial);
1069                    on_event(AgentEvent::MessageStart {
1070                        message: Message::Assistant(Arc::clone(&shared)),
1071                    });
1072                    sent_start = true;
1073                    on_event(AgentEvent::MessageUpdate {
1074                        message: Message::Assistant(Arc::clone(&shared)),
1075                        assistant_message_event: AssistantMessageEvent::Start { partial: shared },
1076                    });
1077                }
1078                StreamEvent::TextStart { content_index, .. } => {
1079                    if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1080                        let msg = Arc::make_mut(msg_arc);
1081                        if content_index == msg.content.len() {
1082                            msg.content.push(ContentBlock::Text(TextContent::new("")));
1083                        }
1084                        let shared = Arc::clone(msg_arc);
1085                        if !sent_start {
1086                            on_event(AgentEvent::MessageStart {
1087                                message: Message::Assistant(Arc::clone(&shared)),
1088                            });
1089                            sent_start = true;
1090                        }
1091                        on_event(AgentEvent::MessageUpdate {
1092                            message: Message::Assistant(Arc::clone(&shared)),
1093                            assistant_message_event: AssistantMessageEvent::TextStart {
1094                                content_index,
1095                                partial: shared,
1096                            },
1097                        });
1098                    }
1099                }
1100                StreamEvent::TextDelta {
1101                    content_index,
1102                    delta,
1103                    ..
1104                } => {
1105                    if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1106                        {
1107                            let msg = Arc::make_mut(msg_arc);
1108                            if let Some(ContentBlock::Text(text)) =
1109                                msg.content.get_mut(content_index)
1110                            {
1111                                text.text.push_str(&delta);
1112                            }
1113                        }
1114                        let shared = Arc::clone(msg_arc);
1115                        if !sent_start {
1116                            on_event(AgentEvent::MessageStart {
1117                                message: Message::Assistant(Arc::clone(&shared)),
1118                            });
1119                            sent_start = true;
1120                        }
1121                        on_event(AgentEvent::MessageUpdate {
1122                            message: Message::Assistant(Arc::clone(&shared)),
1123                            assistant_message_event: AssistantMessageEvent::TextDelta {
1124                                content_index,
1125                                delta,
1126                                partial: shared,
1127                            },
1128                        });
1129                    }
1130                }
1131                StreamEvent::TextEnd {
1132                    content_index,
1133                    content,
1134                    ..
1135                } => {
1136                    if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1137                        {
1138                            let msg = Arc::make_mut(msg_arc);
1139                            if let Some(ContentBlock::Text(text)) =
1140                                msg.content.get_mut(content_index)
1141                            {
1142                                text.text.clone_from(&content);
1143                            }
1144                        }
1145                        let shared = Arc::clone(msg_arc);
1146                        if !sent_start {
1147                            on_event(AgentEvent::MessageStart {
1148                                message: Message::Assistant(Arc::clone(&shared)),
1149                            });
1150                            sent_start = true;
1151                        }
1152                        on_event(AgentEvent::MessageUpdate {
1153                            message: Message::Assistant(Arc::clone(&shared)),
1154                            assistant_message_event: AssistantMessageEvent::TextEnd {
1155                                content_index,
1156                                content,
1157                                partial: shared,
1158                            },
1159                        });
1160                    }
1161                }
1162                StreamEvent::ThinkingStart { content_index, .. } => {
1163                    if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1164                        let msg = Arc::make_mut(msg_arc);
1165                        if content_index == msg.content.len() {
1166                            msg.content.push(ContentBlock::Thinking(ThinkingContent {
1167                                thinking: String::new(),
1168                                thinking_signature: None,
1169                            }));
1170                        }
1171                        let shared = Arc::clone(msg_arc);
1172                        if !sent_start {
1173                            on_event(AgentEvent::MessageStart {
1174                                message: Message::Assistant(Arc::clone(&shared)),
1175                            });
1176                            sent_start = true;
1177                        }
1178                        on_event(AgentEvent::MessageUpdate {
1179                            message: Message::Assistant(Arc::clone(&shared)),
1180                            assistant_message_event: AssistantMessageEvent::ThinkingStart {
1181                                content_index,
1182                                partial: shared,
1183                            },
1184                        });
1185                    }
1186                }
1187                StreamEvent::ThinkingDelta {
1188                    content_index,
1189                    delta,
1190                    ..
1191                } => {
1192                    if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1193                        {
1194                            let msg = Arc::make_mut(msg_arc);
1195                            if let Some(ContentBlock::Thinking(thinking)) =
1196                                msg.content.get_mut(content_index)
1197                            {
1198                                thinking.thinking.push_str(&delta);
1199                            }
1200                        }
1201                        let shared = Arc::clone(msg_arc);
1202                        if !sent_start {
1203                            on_event(AgentEvent::MessageStart {
1204                                message: Message::Assistant(Arc::clone(&shared)),
1205                            });
1206                            sent_start = true;
1207                        }
1208                        on_event(AgentEvent::MessageUpdate {
1209                            message: Message::Assistant(Arc::clone(&shared)),
1210                            assistant_message_event: AssistantMessageEvent::ThinkingDelta {
1211                                content_index,
1212                                delta,
1213                                partial: shared,
1214                            },
1215                        });
1216                    }
1217                }
1218                StreamEvent::ThinkingEnd {
1219                    content_index,
1220                    content,
1221                    ..
1222                } => {
1223                    if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1224                        {
1225                            let msg = Arc::make_mut(msg_arc);
1226                            if let Some(ContentBlock::Thinking(thinking)) =
1227                                msg.content.get_mut(content_index)
1228                            {
1229                                thinking.thinking.clone_from(&content);
1230                            }
1231                        }
1232                        let shared = Arc::clone(msg_arc);
1233                        if !sent_start {
1234                            on_event(AgentEvent::MessageStart {
1235                                message: Message::Assistant(Arc::clone(&shared)),
1236                            });
1237                            sent_start = true;
1238                        }
1239                        on_event(AgentEvent::MessageUpdate {
1240                            message: Message::Assistant(Arc::clone(&shared)),
1241                            assistant_message_event: AssistantMessageEvent::ThinkingEnd {
1242                                content_index,
1243                                content,
1244                                partial: shared,
1245                            },
1246                        });
1247                    }
1248                }
1249                StreamEvent::ToolCallStart { content_index, .. } => {
1250                    if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1251                        let msg = Arc::make_mut(msg_arc);
1252                        if content_index == msg.content.len() {
1253                            msg.content.push(ContentBlock::ToolCall(ToolCall {
1254                                id: String::new(),
1255                                name: String::new(),
1256                                arguments: serde_json::Value::Null,
1257                                thought_signature: None,
1258                            }));
1259                        }
1260                        let shared = Arc::clone(msg_arc);
1261                        if !sent_start {
1262                            on_event(AgentEvent::MessageStart {
1263                                message: Message::Assistant(Arc::clone(&shared)),
1264                            });
1265                            sent_start = true;
1266                        }
1267                        on_event(AgentEvent::MessageUpdate {
1268                            message: Message::Assistant(Arc::clone(&shared)),
1269                            assistant_message_event: AssistantMessageEvent::ToolCallStart {
1270                                content_index,
1271                                partial: shared,
1272                            },
1273                        });
1274                    }
1275                }
1276                StreamEvent::ToolCallDelta {
1277                    content_index,
1278                    delta,
1279                    ..
1280                } => {
1281                    if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1282                        // No mutation needed for ToolCallDelta – args stay Null until ToolCallEnd.
1283                        // Just share the current Arc (O(1) refcount bump, zero deep copies).
1284                        let shared = Arc::clone(msg_arc);
1285                        if !sent_start {
1286                            on_event(AgentEvent::MessageStart {
1287                                message: Message::Assistant(Arc::clone(&shared)),
1288                            });
1289                            sent_start = true;
1290                        }
1291                        on_event(AgentEvent::MessageUpdate {
1292                            message: Message::Assistant(Arc::clone(&shared)),
1293                            assistant_message_event: AssistantMessageEvent::ToolCallDelta {
1294                                content_index,
1295                                delta,
1296                                partial: shared,
1297                            },
1298                        });
1299                    }
1300                }
1301                StreamEvent::ToolCallEnd {
1302                    content_index,
1303                    tool_call,
1304                    ..
1305                } => {
1306                    if let Some(Message::Assistant(msg_arc)) = self.messages.last_mut() {
1307                        {
1308                            let msg = Arc::make_mut(msg_arc);
1309                            if let Some(ContentBlock::ToolCall(tc)) =
1310                                msg.content.get_mut(content_index)
1311                            {
1312                                *tc = tool_call.clone();
1313                            }
1314                        }
1315                        let shared = Arc::clone(msg_arc);
1316                        if !sent_start {
1317                            on_event(AgentEvent::MessageStart {
1318                                message: Message::Assistant(Arc::clone(&shared)),
1319                            });
1320                            sent_start = true;
1321                        }
1322                        on_event(AgentEvent::MessageUpdate {
1323                            message: Message::Assistant(Arc::clone(&shared)),
1324                            assistant_message_event: AssistantMessageEvent::ToolCallEnd {
1325                                content_index,
1326                                tool_call,
1327                                partial: shared,
1328                            },
1329                        });
1330                    }
1331                }
1332                StreamEvent::Done { message, .. } => {
1333                    return Ok(self.finalize_assistant_message(message, &on_event, added_partial));
1334                }
1335                StreamEvent::Error { error, .. } => {
1336                    return Ok(self.finalize_assistant_message(error, &on_event, added_partial));
1337                }
1338            }
1339        }
1340
1341        // If the stream ends without a Done/Error event, we may have a partial message.
1342        // Instead of discarding it, we finalize it with an error state so the user/session
1343        // retains the partial content.
1344        if added_partial {
1345            if let Some(Message::Assistant(last_msg)) = self.messages.last() {
1346                let mut final_msg = (**last_msg).clone();
1347                final_msg.stop_reason = StopReason::Error;
1348                final_msg.error_message = Some("Stream ended without Done event".to_string());
1349                return Ok(self.finalize_assistant_message(final_msg, &on_event, true));
1350            }
1351        }
1352        Err(Error::api("Stream ended without Done event"))
1353    }
1354
1355    /// Update the partial assistant message in `self.messages`.
1356    ///
1357    /// Takes an `Arc<AssistantMessage>` and moves it into the message list
1358    /// (one Arc move, zero deep-copies).
1359    fn update_partial_message(
1360        &mut self,
1361        partial: Arc<AssistantMessage>,
1362        added_partial: &mut bool,
1363    ) -> bool {
1364        if *added_partial {
1365            if let Some(last @ Message::Assistant(_)) = self.messages.last_mut() {
1366                *last = Message::Assistant(partial);
1367            } else {
1368                // Defensive: added_partial is true but last message isn't Assistant.
1369                // Push as new message rather than silently dropping the update.
1370                tracing::warn!("update_partial_message: expected last message to be Assistant");
1371                self.messages.push(Message::Assistant(partial));
1372            }
1373            false
1374        } else {
1375            self.messages.push(Message::Assistant(partial));
1376            *added_partial = true;
1377            true
1378        }
1379    }
1380
1381    fn finalize_assistant_message(
1382        &mut self,
1383        message: AssistantMessage,
1384        on_event: &Arc<dyn Fn(AgentEvent) + Send + Sync>,
1385        added_partial: bool,
1386    ) -> AssistantMessage {
1387        let arc = Arc::new(message);
1388        if added_partial {
1389            if let Some(last @ Message::Assistant(_)) = self.messages.last_mut() {
1390                *last = Message::Assistant(Arc::clone(&arc));
1391            } else {
1392                // Defensive: added_partial is true but last message isn't Assistant.
1393                // Push as new message rather than overwriting an unrelated message.
1394                tracing::warn!("finalize_assistant_message: expected last message to be Assistant");
1395                self.messages.push(Message::Assistant(Arc::clone(&arc)));
1396                on_event(AgentEvent::MessageStart {
1397                    message: Message::Assistant(Arc::clone(&arc)),
1398                });
1399            }
1400        } else {
1401            self.messages.push(Message::Assistant(Arc::clone(&arc)));
1402            on_event(AgentEvent::MessageStart {
1403                message: Message::Assistant(Arc::clone(&arc)),
1404            });
1405        }
1406
1407        on_event(AgentEvent::MessageEnd {
1408            message: Message::Assistant(Arc::clone(&arc)),
1409        });
1410        Arc::try_unwrap(arc).unwrap_or_else(|a| (*a).clone())
1411    }
1412
1413    async fn execute_parallel_batch(
1414        &self,
1415        batch: Vec<(usize, ToolCall)>,
1416        on_event: AgentEventHandler,
1417        abort: Option<AbortSignal>,
1418    ) -> Vec<(usize, (ToolOutput, bool))> {
1419        let futures = batch.into_iter().map(|(idx, tc)| {
1420            let on_event = Arc::clone(&on_event);
1421            async move { (idx, self.execute_tool_owned(tc, on_event).await) }
1422        });
1423
1424        if let Some(signal) = abort.as_ref() {
1425            use futures::future::{Either, select};
1426            let all_fut = stream::iter(futures)
1427                .buffer_unordered(MAX_CONCURRENT_TOOLS)
1428                .collect::<Vec<_>>()
1429                .fuse();
1430            let abort_fut = signal.wait().fuse();
1431            futures::pin_mut!(all_fut, abort_fut);
1432
1433            match select(all_fut, abort_fut).await {
1434                Either::Left((batch_results, _)) => batch_results,
1435                Either::Right(_) => Vec::new(), // Aborted
1436            }
1437        } else {
1438            stream::iter(futures)
1439                .buffer_unordered(MAX_CONCURRENT_TOOLS)
1440                .collect::<Vec<_>>()
1441                .await
1442        }
1443    }
1444
1445    #[allow(clippy::too_many_lines)]
1446    async fn execute_tool_calls(
1447        &mut self,
1448        tool_calls: &[ToolCall],
1449        on_event: AgentEventHandler,
1450        new_messages: &mut Vec<Message>,
1451        abort: Option<AbortSignal>,
1452    ) -> Result<ToolExecutionOutcome> {
1453        let mut results = Vec::new();
1454        let mut steering_messages: Option<Vec<Message>> = None;
1455
1456        // Phase 1: Emit start events for ALL tools up front.
1457        for tool_call in tool_calls {
1458            on_event(AgentEvent::ToolExecutionStart {
1459                tool_call_id: tool_call.id.clone(),
1460                tool_name: tool_call.name.clone(),
1461                args: tool_call.arguments.clone(),
1462            });
1463        }
1464
1465        // Phase 2: Execute tools with safety barriers.
1466        let mut pending_parallel: Vec<(usize, ToolCall)> = Vec::new();
1467        let mut tool_outputs: Vec<Option<(ToolOutput, bool)>> = vec![None; tool_calls.len()];
1468
1469        // Iterate through tools. If read-only, buffer. If unsafe, flush buffer then run unsafe.
1470        for (index, tool_call) in tool_calls.iter().enumerate() {
1471            if abort.as_ref().is_some_and(AbortSignal::is_aborted) {
1472                break;
1473            }
1474
1475            let is_read_only =
1476                matches!(self.tools.get(&tool_call.name), Some(tool) if tool.is_read_only());
1477
1478            if is_read_only {
1479                pending_parallel.push((index, tool_call.clone()));
1480            } else {
1481                // Check steering BEFORE flushing parallel or running unsafe.
1482                let steering = self.drain_steering_messages().await;
1483                if !steering.is_empty() {
1484                    steering_messages = Some(steering);
1485                    break;
1486                }
1487
1488                // Barrier: flush parallel buffer first
1489                if !pending_parallel.is_empty() {
1490                    let batch = std::mem::take(&mut pending_parallel);
1491                    let results = self
1492                        .execute_parallel_batch(batch, Arc::clone(&on_event), abort.clone())
1493                        .await;
1494                    for (idx, result) in results {
1495                        tool_outputs[idx] = Some(result);
1496                    }
1497                }
1498
1499                if abort.as_ref().is_some_and(AbortSignal::is_aborted) {
1500                    break;
1501                }
1502
1503                // Execute unsafe tool sequentially
1504                // Check steering AGAIN before the potentially expensive unsafe tool
1505                let steering = self.drain_steering_messages().await;
1506                if !steering.is_empty() {
1507                    steering_messages = Some(steering);
1508                    break;
1509                }
1510
1511                let result = self
1512                    .execute_tool(tool_call.clone(), Arc::clone(&on_event))
1513                    .await;
1514                tool_outputs[index] = Some(result);
1515            }
1516        }
1517
1518        // Flush remaining parallel tools
1519        if !pending_parallel.is_empty()
1520            && !abort.as_ref().is_some_and(AbortSignal::is_aborted)
1521            && steering_messages.is_none()
1522        {
1523            let batch = std::mem::take(&mut pending_parallel);
1524            // Check steering one last time before final flush
1525            let steering = self.drain_steering_messages().await;
1526            if steering.is_empty() {
1527                let results = self
1528                    .execute_parallel_batch(batch, Arc::clone(&on_event), abort.clone())
1529                    .await;
1530                for (idx, result) in results {
1531                    tool_outputs[idx] = Some(result);
1532                }
1533            } else {
1534                steering_messages = Some(steering);
1535            }
1536        }
1537
1538        // Phase 3: Process results sequentially and handle skips.
1539        for (index, tool_call) in tool_calls.iter().enumerate() {
1540            // Check for new steering if we haven't already found some.
1541            // This catches steering messages that arrived during the *last* tool's execution.
1542            if steering_messages.is_none() && !abort.as_ref().is_some_and(AbortSignal::is_aborted) {
1543                let steering = self.drain_steering_messages().await;
1544                if !steering.is_empty() {
1545                    steering_messages = Some(steering);
1546                }
1547            }
1548
1549            // Extract the result, tracking whether the tool actually executed.
1550            // If `tool_outputs[index]` is `Some`, `execute_tool` ran.
1551            // If `None`, the tool was skipped/aborted.
1552            if let Some((output, is_error)) = tool_outputs[index].take() {
1553                // Tool executed normally.
1554                // Build ToolResultMessage first and wrap in Arc; the message
1555                // clone below is O(1) Arc refcount bump since ToolResult is
1556                // already Arc-wrapped in the Message enum.
1557                let tool_result = Arc::new(ToolResultMessage {
1558                    tool_call_id: tool_call.id.clone(),
1559                    tool_name: tool_call.name.clone(),
1560                    content: output.content,
1561                    details: output.details,
1562                    is_error,
1563                    timestamp: Utc::now().timestamp_millis(),
1564                });
1565
1566                // Emit ToolExecutionEnd. We clone content/details from the
1567                // Arc'd result — same data, no extra source clone.
1568                on_event(AgentEvent::ToolExecutionEnd {
1569                    tool_call_id: tool_result.tool_call_id.clone(),
1570                    tool_name: tool_result.tool_name.clone(),
1571                    result: ToolOutput {
1572                        content: tool_result.content.clone(),
1573                        details: tool_result.details.clone(),
1574                        is_error,
1575                    },
1576                    is_error,
1577                });
1578
1579                let msg = Message::ToolResult(Arc::clone(&tool_result));
1580                self.messages.push(msg.clone());
1581                on_event(AgentEvent::MessageStart {
1582                    message: msg.clone(),
1583                });
1584                new_messages.push(msg.clone());
1585                on_event(AgentEvent::MessageEnd { message: msg });
1586
1587                results.push(tool_result);
1588            } else if steering_messages.is_some() {
1589                // Skipped due to steering.
1590                results.push(self.skip_tool_call(tool_call, &on_event, new_messages));
1591            } else {
1592                // Aborted or otherwise failed to run (e.g. abort signal).
1593                let output = ToolOutput {
1594                    content: vec![ContentBlock::Text(TextContent::new(
1595                        "Tool execution aborted",
1596                    ))],
1597                    details: None,
1598                    is_error: true,
1599                };
1600
1601                on_event(AgentEvent::ToolExecutionUpdate {
1602                    tool_call_id: tool_call.id.clone(),
1603                    tool_name: tool_call.name.clone(),
1604                    args: tool_call.arguments.clone(),
1605                    partial_result: ToolOutput {
1606                        content: output.content.clone(),
1607                        details: output.details.clone(),
1608                        is_error: true,
1609                    },
1610                });
1611
1612                on_event(AgentEvent::ToolExecutionEnd {
1613                    tool_call_id: tool_call.id.clone(),
1614                    tool_name: tool_call.name.clone(),
1615                    result: ToolOutput {
1616                        content: output.content.clone(),
1617                        details: output.details.clone(),
1618                        is_error: true,
1619                    },
1620                    is_error: true,
1621                });
1622
1623                let tool_result = Arc::new(ToolResultMessage {
1624                    tool_call_id: tool_call.id.clone(),
1625                    tool_name: tool_call.name.clone(),
1626                    content: output.content,
1627                    details: output.details,
1628                    is_error: true,
1629                    timestamp: Utc::now().timestamp_millis(),
1630                });
1631
1632                let msg = Message::ToolResult(Arc::clone(&tool_result));
1633                self.messages.push(msg.clone());
1634                on_event(AgentEvent::MessageStart {
1635                    message: msg.clone(),
1636                });
1637                let end_msg = msg.clone();
1638                new_messages.push(msg);
1639                on_event(AgentEvent::MessageEnd { message: end_msg });
1640
1641                results.push(tool_result);
1642            }
1643        }
1644
1645        Ok(ToolExecutionOutcome {
1646            tool_results: results,
1647            steering_messages,
1648        })
1649    }
1650
1651    async fn execute_tool(
1652        &self,
1653        tool_call: ToolCall,
1654        on_event: AgentEventHandler,
1655    ) -> (ToolOutput, bool) {
1656        let extensions = self.extensions.clone();
1657
1658        let (mut output, is_error) = if let Some(extensions) = &extensions {
1659            match Self::dispatch_tool_call_hook(extensions, &tool_call).await {
1660                Some(blocked_output) => (blocked_output, true),
1661                None => {
1662                    self.execute_tool_without_hooks(&tool_call, Arc::clone(&on_event))
1663                        .await
1664                }
1665            }
1666        } else {
1667            self.execute_tool_without_hooks(&tool_call, Arc::clone(&on_event))
1668                .await
1669        };
1670
1671        if let Some(extensions) = &extensions {
1672            Self::apply_tool_result_hook(extensions, &tool_call, &mut output, is_error).await;
1673        }
1674
1675        (output, is_error)
1676    }
1677
    /// Owned-argument entry point that forwards directly to `execute_tool`.
    ///
    /// NOTE(review): `execute_tool` already takes its arguments by value, so this
    /// wrapper adds no behavior of its own — presumably kept for existing call
    /// sites; confirm before removing.
    async fn execute_tool_owned(
        &self,
        tool_call: ToolCall,
        on_event: AgentEventHandler,
    ) -> (ToolOutput, bool) {
        self.execute_tool(tool_call, on_event).await
    }
1685
1686    async fn execute_tool_without_hooks(
1687        &self,
1688        tool_call: &ToolCall,
1689        on_event: AgentEventHandler,
1690    ) -> (ToolOutput, bool) {
1691        // Find the tool
1692        let Some(tool) = self.tools.get(&tool_call.name) else {
1693            return (Self::tool_not_found_output(&tool_call.name), true);
1694        };
1695
1696        let tool_name = tool_call.name.clone();
1697        let tool_id = tool_call.id.clone();
1698        let tool_args = tool_call.arguments.clone();
1699        let on_event = Arc::clone(&on_event);
1700
1701        let update_callback = move |update: ToolUpdate| {
1702            on_event(AgentEvent::ToolExecutionUpdate {
1703                tool_call_id: tool_id.clone(),
1704                tool_name: tool_name.clone(),
1705                args: tool_args.clone(),
1706                partial_result: ToolOutput {
1707                    content: update.content,
1708                    details: update.details,
1709                    is_error: false,
1710                },
1711            });
1712        };
1713
1714        match tool
1715            .execute(
1716                &tool_call.id,
1717                tool_call.arguments.clone(),
1718                Some(Box::new(update_callback)),
1719            )
1720            .await
1721        {
1722            Ok(output) => {
1723                let is_error = output.is_error;
1724                (output, is_error)
1725            }
1726            Err(e) => (
1727                ToolOutput {
1728                    content: vec![ContentBlock::Text(TextContent::new(format!("Error: {e}")))],
1729                    details: None,
1730                    is_error: true,
1731                },
1732                true,
1733            ),
1734        }
1735    }
1736
1737    fn tool_not_found_output(tool_name: &str) -> ToolOutput {
1738        ToolOutput {
1739            content: vec![ContentBlock::Text(TextContent::new(format!(
1740                "Error: Tool '{tool_name}' not found"
1741            )))],
1742            details: None,
1743            is_error: true,
1744        }
1745    }
1746
1747    async fn dispatch_tool_call_hook(
1748        extensions: &ExtensionManager,
1749        tool_call: &ToolCall,
1750    ) -> Option<ToolOutput> {
1751        match extensions
1752            .dispatch_tool_call(tool_call, EXTENSION_EVENT_TIMEOUT_MS)
1753            .await
1754        {
1755            Ok(Some(result)) if result.block => {
1756                Some(Self::tool_call_blocked_output(result.reason.as_deref()))
1757            }
1758            Ok(_) => None,
1759            Err(err) => {
1760                tracing::warn!("tool_call extension hook failed (fail-open): {err}");
1761                None
1762            }
1763        }
1764    }
1765
1766    fn tool_call_blocked_output(reason: Option<&str>) -> ToolOutput {
1767        let reason = reason.map(str::trim).filter(|reason| !reason.is_empty());
1768        let message = reason.map_or_else(
1769            || "Tool execution was blocked by an extension".to_string(),
1770            |reason| format!("Tool execution blocked: {reason}"),
1771        );
1772
1773        ToolOutput {
1774            content: vec![ContentBlock::Text(TextContent::new(message))],
1775            details: None,
1776            is_error: true,
1777        }
1778    }
1779
1780    async fn apply_tool_result_hook(
1781        extensions: &ExtensionManager,
1782        tool_call: &ToolCall,
1783        output: &mut ToolOutput,
1784        is_error: bool,
1785    ) {
1786        match extensions
1787            .dispatch_tool_result(tool_call, &*output, is_error, EXTENSION_EVENT_TIMEOUT_MS)
1788            .await
1789        {
1790            Ok(Some(result)) => {
1791                if let Some(content) = result.content {
1792                    output.content = content;
1793                }
1794                if let Some(details) = result.details {
1795                    output.details = Some(details);
1796                }
1797            }
1798            Ok(None) => {}
1799            Err(err) => tracing::warn!("tool_result extension hook failed (fail-open): {err}"),
1800        }
1801    }
1802
1803    fn skip_tool_call(
1804        &mut self,
1805        tool_call: &ToolCall,
1806        on_event: &Arc<dyn Fn(AgentEvent) + Send + Sync>,
1807        new_messages: &mut Vec<Message>,
1808    ) -> Arc<ToolResultMessage> {
1809        let output = ToolOutput {
1810            content: vec![ContentBlock::Text(TextContent::new(
1811                "Skipped due to queued user message.",
1812            ))],
1813            details: None,
1814            is_error: true,
1815        };
1816
1817        // Note: Phase 1 already emitted ToolExecutionStart for all tools,
1818        // so we only emit Update and End here.
1819        on_event(AgentEvent::ToolExecutionUpdate {
1820            tool_call_id: tool_call.id.clone(),
1821            tool_name: tool_call.name.clone(),
1822            args: tool_call.arguments.clone(),
1823            partial_result: output.clone(),
1824        });
1825        on_event(AgentEvent::ToolExecutionEnd {
1826            tool_call_id: tool_call.id.clone(),
1827            tool_name: tool_call.name.clone(),
1828            result: output.clone(),
1829            is_error: true,
1830        });
1831
1832        let tool_result = Arc::new(ToolResultMessage {
1833            tool_call_id: tool_call.id.clone(),
1834            tool_name: tool_call.name.clone(),
1835            content: output.content,
1836            details: output.details,
1837            is_error: true,
1838            timestamp: Utc::now().timestamp_millis(),
1839        });
1840
1841        let msg = Message::ToolResult(Arc::clone(&tool_result));
1842        self.messages.push(msg.clone());
1843        new_messages.push(msg.clone());
1844
1845        on_event(AgentEvent::MessageStart {
1846            message: msg.clone(),
1847        });
1848        on_event(AgentEvent::MessageEnd { message: msg });
1849
1850        tool_result
1851    }
1852}
1853
1854// ============================================================================
1855// Agent Session (Agent + Session persistence)
1856// ============================================================================
1857
/// What came out of processing one batch of tool calls.
struct ToolExecutionOutcome {
    /// Result messages for the processed tool calls, including the synthetic
    /// error results recorded for calls that were skipped or aborted.
    tool_results: Vec<Arc<ToolResultMessage>>,
    /// Steering messages drained while the batch was being processed, or `None`
    /// when no steering message was found.
    steering_messages: Option<Vec<Message>>,
}
1862
/// Pre-created extension runtime state for overlapping startup I/O.
///
/// By spawning runtime boot as a background task *before* session creation and
/// model selection, expensive runtime startup can overlap with other work.
/// This struct bundles the pieces produced by that early boot so they can be
/// handed over together.
pub struct PreWarmedExtensionRuntime {
    /// The extension manager (already has `cwd` and risk config set).
    pub manager: ExtensionManager,
    /// The booted runtime handle.
    pub runtime: ExtensionRuntimeHandle,
    /// The tool registry passed to the runtime during boot.
    pub tools: Arc<ToolRegistry>,
}
1875
/// An [`Agent`] bundled with its [`Session`] and the supporting state for
/// extensions, compaction, and model/auth lookup.
pub struct AgentSession {
    /// The wrapped agent.
    pub agent: Agent,
    /// Shared, mutex-guarded session.
    pub session: Arc<Mutex<Session>>,
    // NOTE(review): presumably gates whether session changes are persisted —
    // confirm against the save path.
    save_enabled: bool,
    /// Extension lifecycle region — ensures the JS runtime thread is shut
    /// down when the session ends.
    pub extensions: Option<ExtensionRegion>,
    // NOTE(review): looks like the flag that `AgentSessionHostActions` reads as
    // `is_streaming` to decide between queueing and appending — confirm where
    // it is set and shared.
    extensions_is_streaming: Arc<AtomicBool>,
    // Resolved compaction configuration supplied at construction.
    compaction_settings: ResolvedCompactionSettings,
    // State of the compaction worker; its usage is outside this excerpt.
    compaction_worker: CompactionWorkerState,
    // Optional model registry; `None` when not provided.
    model_registry: Option<ModelRegistry>,
    // Optional auth storage; `None` when not provided.
    auth_storage: Option<AuthStorage>,
}
1889
1890#[derive(Debug, Default)]
1891struct ExtensionInjectedQueue {
1892    steering: VecDeque<Message>,
1893    follow_up: VecDeque<Message>,
1894}
1895
1896impl ExtensionInjectedQueue {
1897    fn push_steering(&mut self, message: Message) {
1898        self.steering.push_back(message);
1899    }
1900
1901    fn push_follow_up(&mut self, message: Message) {
1902        self.follow_up.push_back(message);
1903    }
1904
1905    fn pop_steering(&mut self) -> Vec<Message> {
1906        self.steering.drain(..).collect()
1907    }
1908
1909    fn pop_follow_up(&mut self) -> Vec<Message> {
1910        self.follow_up.drain(..).collect()
1911    }
1912}
1913
/// State behind the `ExtensionHostActions` implementation that lets extensions
/// send messages into an agent session.
#[derive(Clone)]
struct AgentSessionHostActions {
    /// Session that receives messages appended directly (see
    /// `append_to_session`).
    session: Arc<Mutex<Session>>,
    /// Buffers for messages queued by `enqueue`.
    injected: Arc<StdMutex<ExtensionInjectedQueue>>,
    /// Read by `send_message` / `send_user_message` to choose between queueing
    /// a message and appending it to the session.
    is_streaming: Arc<AtomicBool>,
}
1920
1921impl AgentSessionHostActions {
1922    fn enqueue(&self, deliver_as: Option<ExtensionDeliverAs>, message: Message) {
1923        let deliver_as = deliver_as.unwrap_or(ExtensionDeliverAs::Steer);
1924        let Ok(mut queue) = self.injected.lock() else {
1925            return;
1926        };
1927        match deliver_as {
1928            ExtensionDeliverAs::FollowUp => {
1929                queue.push_follow_up(message);
1930            }
1931            ExtensionDeliverAs::Steer | ExtensionDeliverAs::NextTurn => {
1932                queue.push_steering(message);
1933            }
1934        }
1935    }
1936
1937    async fn append_to_session(&self, message: Message) -> Result<()> {
1938        let cx = crate::agent_cx::AgentCx::for_request();
1939        let mut session = self
1940            .session
1941            .lock(cx.cx())
1942            .await
1943            .map_err(|e| Error::session(e.to_string()))?;
1944        session.append_model_message(message);
1945        Ok(())
1946    }
1947}
1948
1949#[async_trait]
1950impl ExtensionHostActions for AgentSessionHostActions {
1951    async fn send_message(&self, message: ExtensionSendMessage) -> Result<()> {
1952        let custom_message = Message::Custom(CustomMessage {
1953            content: message.content,
1954            custom_type: message.custom_type,
1955            display: message.display,
1956            details: message.details,
1957            timestamp: Utc::now().timestamp_millis(),
1958        });
1959
1960        if matches!(message.deliver_as, Some(ExtensionDeliverAs::NextTurn)) {
1961            return self.append_to_session(custom_message).await;
1962        }
1963
1964        if self.is_streaming.load(Ordering::SeqCst) {
1965            self.enqueue(message.deliver_as, custom_message);
1966            return Ok(());
1967        }
1968
1969        // Non-streaming, best-effort: persist to session. Triggering a new turn is handled by the
1970        // interactive layer; non-interactive modes will pick this up on the next prompt.
1971        let _ = message.trigger_turn;
1972        self.append_to_session(custom_message).await
1973    }
1974
1975    async fn send_user_message(&self, message: ExtensionSendUserMessage) -> Result<()> {
1976        let user_message = Message::User(UserMessage {
1977            content: UserContent::Text(message.text),
1978            timestamp: Utc::now().timestamp_millis(),
1979        });
1980
1981        if self.is_streaming.load(Ordering::SeqCst) {
1982            self.enqueue(message.deliver_as, user_message);
1983            return Ok(());
1984        }
1985
1986        // Non-streaming, best-effort: persist to session. Interactive mode triggers turns via UI.
1987        self.append_to_session(user_message).await
1988    }
1989}
1990
#[cfg(test)]
mod message_queue_tests {
    use super::*;

    /// Builds a user text message with a fixed timestamp.
    fn user_message(text: &str) -> Message {
        let content = UserContent::Text(text.to_string());
        Message::User(UserMessage {
            content,
            timestamp: 0,
        })
    }

    /// True when `message` is a user text message whose text equals `expected`.
    fn is_user_text(message: Option<&Message>, expected: &str) -> bool {
        matches!(
            message,
            Some(Message::User(UserMessage {
                content: UserContent::Text(text),
                ..
            })) if text == expected
        )
    }

    /// A queue where both kinds are delivered one message per pop.
    fn one_at_a_time_queue() -> MessageQueue {
        MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime)
    }

    #[test]
    fn message_queue_one_at_a_time() {
        let mut queue = one_at_a_time_queue();
        queue.push_steering(user_message("a"));
        queue.push_steering(user_message("b"));

        // Each pop hands back exactly one message, in insertion order.
        for expected in ["a", "b"] {
            let popped = queue.pop_steering();
            assert_eq!(popped.len(), 1);
            assert!(is_user_text(popped.first(), expected));
        }

        assert!(queue.pop_steering().is_empty());
    }

    #[test]
    fn message_queue_all_mode() {
        let mut queue = MessageQueue::new(QueueMode::All, QueueMode::OneAtATime);
        queue.push_steering(user_message("a"));
        queue.push_steering(user_message("b"));

        assert_eq!(queue.pop_steering().len(), 2);
        assert!(queue.pop_steering().is_empty());
    }

    #[test]
    fn message_queue_separates_kinds() {
        let mut queue = one_at_a_time_queue();
        queue.push_steering(user_message("steer"));
        queue.push_follow_up(user_message("follow"));

        assert_eq!(queue.pop_steering().len(), 1);
        assert_eq!(queue.pending_count(), 1);

        assert_eq!(queue.pop_follow_up().len(), 1);
        assert_eq!(queue.pending_count(), 0);
    }

    #[test]
    fn message_queue_seq_increments() {
        let mut queue = one_at_a_time_queue();
        let steering_seq = queue.push_steering(user_message("a"));
        let follow_up_seq = queue.push_follow_up(user_message("b"));
        assert!(follow_up_seq > steering_seq);
    }

    #[test]
    fn message_queue_seq_saturates_at_u64_max() {
        let mut queue = one_at_a_time_queue();
        queue.next_seq = u64::MAX;

        let steering_seq = queue.push_steering(user_message("a"));
        let follow_up_seq = queue.push_follow_up(user_message("b"));

        assert_eq!(steering_seq, u64::MAX);
        assert_eq!(follow_up_seq, u64::MAX);
        assert_eq!(queue.pending_count(), 2);
    }

    #[test]
    fn message_queue_follow_up_all_mode_drains_entire_queue_in_order() {
        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::All);
        queue.push_follow_up(user_message("f1"));
        queue.push_follow_up(user_message("f2"));

        let drained = queue.pop_follow_up();
        assert_eq!(drained.len(), 2);
        assert!(is_user_text(drained.first(), "f1"));
        assert!(is_user_text(drained.get(1), "f2"));
        assert!(queue.pop_follow_up().is_empty());
    }
}
2095
2096#[cfg(test)]
2097mod extensions_integration_tests {
2098    use super::*;
2099
2100    use crate::session::Session;
2101    use asupersync::runtime::RuntimeBuilder;
2102    use async_trait::async_trait;
2103    use futures::Stream;
2104    use serde_json::json;
2105    use std::path::Path;
2106    use std::pin::Pin;
2107    use std::sync::atomic::AtomicUsize;
2108
    /// Provider stub whose `stream` returns an empty event stream; for tests that
    /// only need an `Agent` to exist and never consume a model response.
    #[derive(Debug)]
    struct NoopProvider;

    #[async_trait]
    #[allow(clippy::unnecessary_literal_bound)]
    impl Provider for NoopProvider {
        fn name(&self) -> &str {
            "test-provider"
        }

        fn api(&self) -> &str {
            "test-api"
        }

        fn model_id(&self) -> &str {
            "test-model"
        }

        /// Always succeeds with a stream that yields no events.
        async fn stream(
            &self,
            _context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<
            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
        > {
            Ok(Box::pin(futures::stream::empty()))
        }
    }
2137
    /// Tool stub named `count_tool` that records how many times it was executed.
    #[derive(Debug)]
    struct CountingTool {
        /// Shared counter incremented once per `execute` call.
        calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    #[allow(clippy::unnecessary_literal_bound)]
    impl Tool for CountingTool {
        fn name(&self) -> &str {
            "count_tool"
        }

        fn label(&self) -> &str {
            "count_tool"
        }

        fn description(&self) -> &str {
            "counting tool"
        }

        fn parameters(&self) -> serde_json::Value {
            json!({ "type": "object" })
        }

        /// Bumps the call counter and returns a successful `"ok"` text output;
        /// the call id, input, and update callback are ignored.
        async fn execute(
            &self,
            _tool_call_id: &str,
            _input: serde_json::Value,
            _on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
        ) -> Result<ToolOutput> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(ToolOutput {
                content: vec![ContentBlock::Text(TextContent::new("ok"))],
                details: None,
                is_error: false,
            })
        }
    }
2176
    /// Provider stub that tracks how many times `stream` has been called so its
    /// response can differ between the first call and later ones.
    #[derive(Debug)]
    struct ToolUseProvider {
        /// Number of `stream` invocations so far.
        stream_calls: AtomicUsize,
    }

    impl ToolUseProvider {
        /// Creates a provider with the call counter at zero.
        const fn new() -> Self {
            Self {
                stream_calls: AtomicUsize::new(0),
            }
        }

        /// Builds an assistant message carrying `content` and `stop_reason`,
        /// stamped with this provider's api/name/model, default usage, no error
        /// message, and a zero timestamp.
        fn assistant_message(
            &self,
            stop_reason: StopReason,
            content: Vec<ContentBlock>,
        ) -> AssistantMessage {
            AssistantMessage {
                content,
                api: self.api().to_string(),
                provider: self.name().to_string(),
                model: self.model_id().to_string(),
                usage: Usage::default(),
                stop_reason,
                error_message: None,
                timestamp: 0,
            }
        }
    }
2206
2207    #[async_trait]
2208    #[allow(clippy::unnecessary_literal_bound)]
2209    impl Provider for ToolUseProvider {
2210        fn name(&self) -> &str {
2211            "test-provider"
2212        }
2213
2214        fn api(&self) -> &str {
2215            "test-api"
2216        }
2217
2218        fn model_id(&self) -> &str {
2219            "test-model"
2220        }
2221
2222        async fn stream(
2223            &self,
2224            _context: &Context<'_>,
2225            _options: &StreamOptions,
2226        ) -> crate::error::Result<
2227            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
2228        > {
2229            let call_index = self.stream_calls.fetch_add(1, Ordering::SeqCst);
2230
2231            let partial = self.assistant_message(StopReason::Stop, Vec::new());
2232
2233            let (reason, message) = if call_index == 0 {
2234                let tool_calls = vec![
2235                    ToolCall {
2236                        id: "call-1".to_string(),
2237                        name: "count_tool".to_string(),
2238                        arguments: json!({}),
2239                        thought_signature: None,
2240                    },
2241                    ToolCall {
2242                        id: "call-2".to_string(),
2243                        name: "count_tool".to_string(),
2244                        arguments: json!({}),
2245                        thought_signature: None,
2246                    },
2247                ];
2248
2249                (
2250                    StopReason::ToolUse,
2251                    self.assistant_message(
2252                        StopReason::ToolUse,
2253                        tool_calls
2254                            .into_iter()
2255                            .map(ContentBlock::ToolCall)
2256                            .collect::<Vec<_>>(),
2257                    ),
2258                )
2259            } else {
2260                (
2261                    StopReason::Stop,
2262                    self.assistant_message(
2263                        StopReason::Stop,
2264                        vec![ContentBlock::Text(TextContent::new("done"))],
2265                    ),
2266                )
2267            };
2268
2269            let events = vec![
2270                Ok(StreamEvent::Start { partial }),
2271                Ok(StreamEvent::Done { reason, message }),
2272            ];
2273            Ok(Box::pin(futures::stream::iter(events)))
2274        }
2275    }
2276
    /// Enabling a JS extension that calls `pi.registerTool` must make that tool
    /// available in the agent's tool registry, and executing it must return the
    /// extension's content and details.
    #[test]
    fn agent_session_enable_extensions_registers_extension_tools() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            // Write a minimal extension that registers `hello_tool`.
            let temp_dir = tempfile::tempdir().expect("tempdir");
            let entry_path = temp_dir.path().join("ext.mjs");
            std::fs::write(
                &entry_path,
                r#"
                export default function init(pi) {
                  pi.registerTool({
                    name: "hello_tool",
                    label: "hello_tool",
                    description: "test tool",
                    parameters: { type: "object", properties: { name: { type: "string" } } },
                    execute: async (_callId, input, _onUpdate, _abort, ctx) => {
                      const who = input && input.name ? String(input.name) : "world";
                      const cwd = ctx && ctx.cwd ? String(ctx.cwd) : "";
                      return {
                        content: [{ type: "text", text: `hello ${who}` }],
                        details: { from: "extension", cwd: cwd },
                        isError: false
                      };
                    }
                  });
                }
                "#,
            )
            .expect("write extension entry");

            // Agent with no built-in tools, in-memory session, saving disabled.
            let provider = Arc::new(NoopProvider);
            let tools = ToolRegistry::new(&[], Path::new("."), None);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session =
                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

            agent_session
                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
                .await
                .expect("enable extensions");

            let tool = agent_session
                .agent
                .tools
                .get("hello_tool")
                .expect("hello_tool registered");

            let output = tool
                .execute("call-1", json!({ "name": "pi" }), None)
                .await
                .expect("execute tool");

            assert!(!output.is_error);
            assert!(
                matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
                "Expected single text content block, got {:?}",
                output.content
            );
            // The assert above already guarantees this shape; the `else` arm only
            // exists to satisfy the pattern binding.
            let [ContentBlock::Text(text)] = output.content.as_slice() else {
                return;
            };
            assert_eq!(text.text, "hello pi");

            let details = output.details.expect("details present");
            assert_eq!(
                details.get("from").and_then(serde_json::Value::as_str),
                Some("extension")
            );
        });
    }
2351
    /// Passing a JS entry and a native descriptor in the same `enable_extensions`
    /// call must fail with the "Mixed extension runtimes are not supported" error.
    #[test]
    fn agent_session_enable_extensions_rejects_mixed_js_and_native_entries() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            // One entry of each kind: a no-op JS module and an empty native descriptor.
            let temp_dir = tempfile::tempdir().expect("tempdir");
            let js_entry = temp_dir.path().join("ext.mjs");
            let native_entry = temp_dir.path().join("ext.native.json");
            std::fs::write(
                &js_entry,
                r"
                export default function init(_pi) {}
                ",
            )
            .expect("write js extension entry");
            std::fs::write(&native_entry, "{}").expect("write native extension descriptor");

            let provider = Arc::new(NoopProvider);
            let tools = ToolRegistry::new(&[], Path::new("."), None);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session =
                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

            let err = agent_session
                .enable_extensions(&[], temp_dir.path(), None, &[js_entry, native_entry])
                .await
                .expect_err("mixed extension runtimes should be rejected");
            let msg = err.to_string();
            assert!(
                msg.contains("Mixed extension runtimes are not supported"),
                "unexpected mixed-runtime error message: {msg}"
            );
        });
    }
2389
2390    #[test]
2391    fn extension_send_message_persists_custom_message_entry_when_idle() {
2392        let runtime = RuntimeBuilder::current_thread()
2393            .build()
2394            .expect("runtime build");
2395
2396        runtime.block_on(async {
2397            let temp_dir = tempfile::tempdir().expect("tempdir");
2398            let entry_path = temp_dir.path().join("ext.mjs");
2399            std::fs::write(
2400                &entry_path,
2401                r#"
2402                export default function init(pi) {
2403                  pi.registerTool({
2404                    name: "emit_message",
2405                    label: "emit_message",
2406                    description: "emit a custom message",
2407                    parameters: { type: "object" },
2408                    execute: async () => {
2409                      pi.sendMessage({
2410                        customType: "note",
2411                        content: "hello",
2412                        display: true,
2413                        details: { from: "test" }
2414                      }, {});
2415                      return { content: [{ type: "text", text: "ok" }], isError: false };
2416                    }
2417                  });
2418                }
2419                "#,
2420            )
2421            .expect("write extension entry");
2422
2423            let provider = Arc::new(NoopProvider);
2424            let tools = ToolRegistry::new(&[], Path::new("."), None);
2425            let agent = Agent::new(provider, tools, AgentConfig::default());
2426            let session = Arc::new(Mutex::new(Session::in_memory()));
2427            let mut agent_session = AgentSession::new(
2428                agent,
2429                Arc::clone(&session),
2430                false,
2431                ResolvedCompactionSettings::default(),
2432            );
2433
2434            agent_session
2435                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2436                .await
2437                .expect("enable extensions");
2438
2439            let tool = agent_session
2440                .agent
2441                .tools
2442                .get("emit_message")
2443                .expect("emit_message registered");
2444
2445            let _ = tool
2446                .execute("call-1", json!({}), None)
2447                .await
2448                .expect("execute tool");
2449
2450            let cx = crate::agent_cx::AgentCx::for_request();
2451            let session_guard = session
2452                .lock(cx.cx())
2453                .await
2454                .expect("lock session");
2455            let messages = session_guard.to_messages_for_current_path();
2456
2457            assert!(
2458                messages.iter().any(|msg| {
2459                    matches!(
2460                        msg,
2461                        Message::Custom(CustomMessage { custom_type, content, display, details, .. })
2462                            if custom_type == "note"
2463                                && content == "hello"
2464                                && *display
2465                                && details.as_ref().and_then(|v| v.get("from").and_then(Value::as_str)) == Some("test")
2466                    )
2467                }),
2468                "expected custom message to be persisted, got {messages:?}"
2469            );
2470        });
2471    }
2472
2473    #[test]
2474    fn extension_send_message_persists_custom_message_entry_when_idle_after_await() {
2475        let runtime = RuntimeBuilder::current_thread()
2476            .build()
2477            .expect("runtime build");
2478
2479        runtime.block_on(async {
2480            let temp_dir = tempfile::tempdir().expect("tempdir");
2481            let entry_path = temp_dir.path().join("ext.mjs");
2482            std::fs::write(
2483                &entry_path,
2484                r#"
2485                export default function init(pi) {
2486                  pi.registerTool({
2487                    name: "emit_message",
2488                    label: "emit_message",
2489                    description: "emit a custom message",
2490                    parameters: { type: "object" },
2491                    execute: async () => {
2492                      await Promise.resolve();
2493                      pi.sendMessage({
2494                        customType: "note",
2495                        content: "hello-after-await",
2496                        display: true,
2497                        details: { from: "test" }
2498                      }, {});
2499                      return { content: [{ type: "text", text: "ok" }], isError: false };
2500                    }
2501                  });
2502                }
2503                "#,
2504            )
2505            .expect("write extension entry");
2506
2507            let provider = Arc::new(NoopProvider);
2508            let tools = ToolRegistry::new(&[], Path::new("."), None);
2509            let agent = Agent::new(provider, tools, AgentConfig::default());
2510            let session = Arc::new(Mutex::new(Session::in_memory()));
2511            let mut agent_session = AgentSession::new(
2512                agent,
2513                Arc::clone(&session),
2514                false,
2515                ResolvedCompactionSettings::default(),
2516            );
2517
2518            agent_session
2519                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2520                .await
2521                .expect("enable extensions");
2522
2523            let tool = agent_session
2524                .agent
2525                .tools
2526                .get("emit_message")
2527                .expect("emit_message registered");
2528
2529            let _ = tool
2530                .execute("call-1", json!({}), None)
2531                .await
2532                .expect("execute tool");
2533
2534            let cx = crate::agent_cx::AgentCx::for_request();
2535            let session_guard = session
2536                .lock(cx.cx())
2537                .await
2538                .expect("lock session");
2539            let messages = session_guard.to_messages_for_current_path();
2540
2541            assert!(
2542                messages.iter().any(|msg| {
2543                    matches!(
2544                        msg,
2545                        Message::Custom(CustomMessage { custom_type, content, display, details, .. })
2546                            if custom_type == "note"
2547                                && content == "hello-after-await"
2548                                && *display
2549                                && details.as_ref().and_then(|v| v.get("from").and_then(Value::as_str)) == Some("test")
2550                    )
2551                }),
2552                "expected custom message to be persisted, got {messages:?}"
2553            );
2554        });
2555    }
2556
2557    #[test]
2558    fn send_user_message_steer_skips_remaining_tools() {
2559        let runtime = RuntimeBuilder::current_thread()
2560            .build()
2561            .expect("runtime build");
2562
2563        runtime.block_on(async {
2564            let temp_dir = tempfile::tempdir().expect("tempdir");
2565            let entry_path = temp_dir.path().join("ext.mjs");
2566            std::fs::write(
2567                &entry_path,
2568                r#"
2569                export default function init(pi) {
2570                  let sent = false;
2571                  pi.on("tool_call", async (event) => {
2572                    if (sent) return {};
2573                    if (event && event.toolName === "count_tool") {
2574                      sent = true;
2575                      await pi.events("sendUserMessage", {
2576                        text: "steer-now",
2577                        options: { deliverAs: "steer" }
2578                      });
2579                    }
2580                    return {};
2581                  });
2582                }
2583                "#,
2584            )
2585            .expect("write extension entry");
2586
2587            let provider = Arc::new(ToolUseProvider::new());
2588            let calls = Arc::new(AtomicUsize::new(0));
2589            let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
2590                calls: Arc::clone(&calls),
2591            })]);
2592            let agent = Agent::new(provider, tools, AgentConfig::default());
2593            let session = Arc::new(Mutex::new(Session::in_memory()));
2594            let mut agent_session =
2595                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2596
2597            agent_session
2598                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2599                .await
2600                .expect("enable extensions");
2601
2602            let _ = agent_session
2603                .run_text("go".to_string(), |_| {})
2604                .await
2605                .expect("run_text");
2606
2607            // A steer message should short-circuit remaining tool dispatch.
2608            assert_eq!(calls.load(Ordering::SeqCst), 1);
2609        });
2610    }
2611
2612    #[test]
2613    fn send_user_message_follow_up_does_not_skip_tools() {
2614        let runtime = RuntimeBuilder::current_thread()
2615            .build()
2616            .expect("runtime build");
2617
2618        runtime.block_on(async {
2619            let temp_dir = tempfile::tempdir().expect("tempdir");
2620            let entry_path = temp_dir.path().join("ext.mjs");
2621            std::fs::write(
2622                &entry_path,
2623                r#"
2624                export default function init(pi) {
2625                  let sent = false;
2626                  pi.on("tool_call", async (event) => {
2627                    if (sent) return {};
2628                    if (event && event.toolName === "count_tool") {
2629                      sent = true;
2630                      await pi.events("sendUserMessage", {
2631                        text: "follow-up",
2632                        options: { deliverAs: "followUp" }
2633                      });
2634                    }
2635                    return {};
2636                  });
2637                }
2638                "#,
2639            )
2640            .expect("write extension entry");
2641
2642            let provider = Arc::new(ToolUseProvider::new());
2643            let calls = Arc::new(AtomicUsize::new(0));
2644            let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
2645                calls: Arc::clone(&calls),
2646            })]);
2647            let agent = Agent::new(provider, tools, AgentConfig::default());
2648            let session = Arc::new(Mutex::new(Session::in_memory()));
2649            let mut agent_session =
2650                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2651
2652            agent_session
2653                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2654                .await
2655                .expect("enable extensions");
2656
2657            let _ = agent_session
2658                .run_text("go".to_string(), |_| {})
2659                .await
2660                .expect("run_text");
2661
2662            assert_eq!(calls.load(Ordering::SeqCst), 2);
2663        });
2664    }
2665
2666    #[test]
2667    fn tool_call_hook_can_block_tool_execution() {
2668        let runtime = RuntimeBuilder::current_thread()
2669            .build()
2670            .expect("runtime build");
2671
2672        runtime.block_on(async {
2673            let temp_dir = tempfile::tempdir().expect("tempdir");
2674            let entry_path = temp_dir.path().join("ext.mjs");
2675            std::fs::write(
2676                &entry_path,
2677                r#"
2678                export default function init(pi) {
2679                  pi.on("tool_call", async (event) => {
2680                    if (event && event.toolName === "count_tool") {
2681                      return { block: true, reason: "blocked in test" };
2682                    }
2683                    return {};
2684                  });
2685                }
2686                "#,
2687            )
2688            .expect("write extension entry");
2689
2690            let provider = Arc::new(NoopProvider);
2691            let calls = Arc::new(AtomicUsize::new(0));
2692            let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
2693                calls: Arc::clone(&calls),
2694            })]);
2695            let agent = Agent::new(provider, tools, AgentConfig::default());
2696            let session = Arc::new(Mutex::new(Session::in_memory()));
2697            let mut agent_session =
2698                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2699
2700            agent_session
2701                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2702                .await
2703                .expect("enable extensions");
2704
2705            let tool_call = ToolCall {
2706                id: "call-1".to_string(),
2707                name: "count_tool".to_string(),
2708                arguments: json!({}),
2709                thought_signature: None,
2710            };
2711
2712            let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
2713            let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
2714
2715            assert!(is_error);
2716            assert!(output.is_error);
2717            assert_eq!(calls.load(Ordering::SeqCst), 0);
2718
2719            assert_eq!(output.details, None);
2720            assert!(
2721                matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
2722                "Expected text output, got {:?}",
2723                output.content
2724            );
2725            if let [ContentBlock::Text(text)] = output.content.as_slice() {
2726                assert_eq!(text.text, "Tool execution blocked: blocked in test");
2727            }
2728        });
2729    }
2730
2731    #[test]
2732    fn tool_call_hook_errors_fail_open() {
2733        let runtime = RuntimeBuilder::current_thread()
2734            .build()
2735            .expect("runtime build");
2736
2737        runtime.block_on(async {
2738            let temp_dir = tempfile::tempdir().expect("tempdir");
2739            let entry_path = temp_dir.path().join("ext.mjs");
2740            std::fs::write(
2741                &entry_path,
2742                r#"
2743                export default function init(pi) {
2744                  pi.on("tool_call", async (_event) => {
2745                    throw new Error("boom");
2746                  });
2747                }
2748                "#,
2749            )
2750            .expect("write extension entry");
2751
2752            let provider = Arc::new(NoopProvider);
2753            let calls = Arc::new(AtomicUsize::new(0));
2754            let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
2755                calls: Arc::clone(&calls),
2756            })]);
2757            let agent = Agent::new(provider, tools, AgentConfig::default());
2758            let session = Arc::new(Mutex::new(Session::in_memory()));
2759            let mut agent_session =
2760                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2761
2762            agent_session
2763                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2764                .await
2765                .expect("enable extensions");
2766
2767            let tool_call = ToolCall {
2768                id: "call-1".to_string(),
2769                name: "count_tool".to_string(),
2770                arguments: json!({}),
2771                thought_signature: None,
2772            };
2773
2774            let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
2775            let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
2776
2777            assert!(!is_error);
2778            assert!(!output.is_error);
2779            assert_eq!(calls.load(Ordering::SeqCst), 1);
2780        });
2781    }
2782
2783    #[test]
2784    fn tool_call_hook_absent_allows_tool_execution() {
2785        let runtime = RuntimeBuilder::current_thread()
2786            .build()
2787            .expect("runtime build");
2788
2789        runtime.block_on(async {
2790            let temp_dir = tempfile::tempdir().expect("tempdir");
2791            let entry_path = temp_dir.path().join("ext.mjs");
2792            std::fs::write(
2793                &entry_path,
2794                r"
2795                export default function init(_pi) {}
2796                ",
2797            )
2798            .expect("write extension entry");
2799
2800            let provider = Arc::new(NoopProvider);
2801            let calls = Arc::new(AtomicUsize::new(0));
2802            let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
2803                calls: Arc::clone(&calls),
2804            })]);
2805            let agent = Agent::new(provider, tools, AgentConfig::default());
2806            let session = Arc::new(Mutex::new(Session::in_memory()));
2807            let mut agent_session =
2808                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2809
2810            agent_session
2811                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2812                .await
2813                .expect("enable extensions");
2814
2815            let tool_call = ToolCall {
2816                id: "call-1".to_string(),
2817                name: "count_tool".to_string(),
2818                arguments: json!({}),
2819                thought_signature: None,
2820            };
2821
2822            let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
2823            let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
2824
2825            assert!(!is_error);
2826            assert!(!output.is_error);
2827            assert_eq!(calls.load(Ordering::SeqCst), 1);
2828        });
2829    }
2830
2831    #[test]
2832    fn tool_call_hook_returns_empty_allows_tool_execution() {
2833        let runtime = RuntimeBuilder::current_thread()
2834            .build()
2835            .expect("runtime build");
2836
2837        runtime.block_on(async {
2838            let temp_dir = tempfile::tempdir().expect("tempdir");
2839            let entry_path = temp_dir.path().join("ext.mjs");
2840            std::fs::write(
2841                &entry_path,
2842                r#"
2843                export default function init(pi) {
2844                  pi.on("tool_call", async (_event) => ({}));
2845                }
2846                "#,
2847            )
2848            .expect("write extension entry");
2849
2850            let provider = Arc::new(NoopProvider);
2851            let calls = Arc::new(AtomicUsize::new(0));
2852            let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
2853                calls: Arc::clone(&calls),
2854            })]);
2855            let agent = Agent::new(provider, tools, AgentConfig::default());
2856            let session = Arc::new(Mutex::new(Session::in_memory()));
2857            let mut agent_session =
2858                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2859
2860            agent_session
2861                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2862                .await
2863                .expect("enable extensions");
2864
2865            let tool_call = ToolCall {
2866                id: "call-1".to_string(),
2867                name: "count_tool".to_string(),
2868                arguments: json!({}),
2869                thought_signature: None,
2870            };
2871
2872            let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
2873            let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
2874
2875            assert!(!is_error);
2876            assert!(!output.is_error);
2877            assert_eq!(calls.load(Ordering::SeqCst), 1);
2878        });
2879    }
2880
2881    #[test]
2882    fn tool_call_hook_can_block_bash_tool_execution() {
2883        let runtime = RuntimeBuilder::current_thread()
2884            .build()
2885            .expect("runtime build");
2886
2887        runtime.block_on(async {
2888            let temp_dir = tempfile::tempdir().expect("tempdir");
2889            let entry_path = temp_dir.path().join("ext.mjs");
2890            std::fs::write(
2891                &entry_path,
2892                r#"
2893                export default function init(pi) {
2894                  pi.on("tool_call", async (event) => {
2895                    const name = event && event.toolName ? String(event.toolName) : "";
2896                    if (name === "bash") return { block: true, reason: "blocked bash in test" };
2897                    return {};
2898                  });
2899                }
2900                "#,
2901            )
2902            .expect("write extension entry");
2903
2904            let provider = Arc::new(NoopProvider);
2905            let tools = ToolRegistry::new(&["bash"], temp_dir.path(), None);
2906            let agent = Agent::new(provider, tools, AgentConfig::default());
2907            let session = Arc::new(Mutex::new(Session::in_memory()));
2908            let mut agent_session =
2909                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2910
2911            agent_session
2912                .enable_extensions(&["bash"], temp_dir.path(), None, &[entry_path])
2913                .await
2914                .expect("enable extensions");
2915
2916            let tool_call = ToolCall {
2917                id: "call-1".to_string(),
2918                name: "bash".to_string(),
2919                arguments: json!({ "command": "printf 'hi' > blocked.txt" }),
2920                thought_signature: None,
2921            };
2922
2923            let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
2924            let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
2925
2926            assert!(is_error);
2927            assert!(output.is_error);
2928            assert_eq!(output.details, None);
2929            assert!(
2930                !temp_dir.path().join("blocked.txt").exists(),
2931                "expected bash command not to run when blocked"
2932            );
2933            assert!(
2934                matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
2935                "Expected text output, got {:?}",
2936                output.content
2937            );
2938            if let [ContentBlock::Text(text)] = output.content.as_slice() {
2939                assert_eq!(text.text, "Tool execution blocked: blocked bash in test");
2940            }
2941        });
2942    }
2943
2944    #[test]
2945    fn tool_result_hook_can_modify_tool_output() {
2946        let runtime = RuntimeBuilder::current_thread()
2947            .build()
2948            .expect("runtime build");
2949
2950        runtime.block_on(async {
2951            let temp_dir = tempfile::tempdir().expect("tempdir");
2952            let entry_path = temp_dir.path().join("ext.mjs");
2953            std::fs::write(
2954                &entry_path,
2955                r#"
2956                export default function init(pi) {
2957                  pi.on("tool_result", async (event) => {
2958                    if (event && event.toolName === "count_tool") {
2959                      return {
2960                        content: [{ type: "text", text: "modified" }],
2961                        details: { from: "tool_result" }
2962                      };
2963                    }
2964                    return {};
2965                  });
2966                }
2967                "#,
2968            )
2969            .expect("write extension entry");
2970
2971            let provider = Arc::new(NoopProvider);
2972            let calls = Arc::new(AtomicUsize::new(0));
2973            let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
2974                calls: Arc::clone(&calls),
2975            })]);
2976            let agent = Agent::new(provider, tools, AgentConfig::default());
2977            let session = Arc::new(Mutex::new(Session::in_memory()));
2978            let mut agent_session =
2979                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
2980
2981            agent_session
2982                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
2983                .await
2984                .expect("enable extensions");
2985
2986            let tool_call = ToolCall {
2987                id: "call-1".to_string(),
2988                name: "count_tool".to_string(),
2989                arguments: json!({}),
2990                thought_signature: None,
2991            };
2992
2993            let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
2994            let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
2995
2996            assert!(!is_error);
2997            assert!(!output.is_error);
2998            assert_eq!(calls.load(Ordering::SeqCst), 1);
2999            assert_eq!(output.details, Some(json!({ "from": "tool_result" })));
3000
3001            assert!(
3002                matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
3003                "Expected text output, got {:?}",
3004                output.content
3005            );
3006            if let [ContentBlock::Text(text)] = output.content.as_slice() {
3007                assert_eq!(text.text, "modified");
3008            }
3009        });
3010    }
3011
3012    #[test]
3013    fn tool_result_hook_can_modify_tool_not_found_error() {
3014        let runtime = RuntimeBuilder::current_thread()
3015            .build()
3016            .expect("runtime build");
3017
3018        runtime.block_on(async {
3019            let temp_dir = tempfile::tempdir().expect("tempdir");
3020            let entry_path = temp_dir.path().join("ext.mjs");
3021            std::fs::write(
3022                &entry_path,
3023                r#"
3024                export default function init(pi) {
3025                  pi.on("tool_result", async (event) => {
3026                    if (event && event.toolName === "missing_tool" && event.isError) {
3027                      return {
3028                        content: [{ type: "text", text: "overridden" }],
3029                        details: { handled: true }
3030                      };
3031                    }
3032                    return {};
3033                  });
3034                }
3035                "#,
3036            )
3037            .expect("write extension entry");
3038
3039            let provider = Arc::new(NoopProvider);
3040            let tools = ToolRegistry::from_tools(Vec::new());
3041            let agent = Agent::new(provider, tools, AgentConfig::default());
3042            let session = Arc::new(Mutex::new(Session::in_memory()));
3043            let mut agent_session =
3044                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
3045
3046            agent_session
3047                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
3048                .await
3049                .expect("enable extensions");
3050
3051            let tool_call = ToolCall {
3052                id: "call-1".to_string(),
3053                name: "missing_tool".to_string(),
3054                arguments: json!({}),
3055                thought_signature: None,
3056            };
3057
3058            let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
3059            let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
3060
3061            assert!(is_error);
3062            assert!(output.is_error);
3063            assert_eq!(output.details, Some(json!({ "handled": true })));
3064
3065            assert!(
3066                matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
3067                "Expected text output, got {:?}",
3068                output.content
3069            );
3070            if let [ContentBlock::Text(text)] = output.content.as_slice() {
3071                assert_eq!(text.text, "overridden");
3072            }
3073        });
3074    }
3075
3076    #[test]
3077    fn tool_result_hook_errors_fail_open() {
3078        let runtime = RuntimeBuilder::current_thread()
3079            .build()
3080            .expect("runtime build");
3081
3082        runtime.block_on(async {
3083            let temp_dir = tempfile::tempdir().expect("tempdir");
3084            let entry_path = temp_dir.path().join("ext.mjs");
3085            std::fs::write(
3086                &entry_path,
3087                r#"
3088                export default function init(pi) {
3089                  pi.on("tool_result", async (_event) => {
3090                    throw new Error("boom");
3091                  });
3092                }
3093                "#,
3094            )
3095            .expect("write extension entry");
3096
3097            let provider = Arc::new(NoopProvider);
3098            let calls = Arc::new(AtomicUsize::new(0));
3099            let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
3100                calls: Arc::clone(&calls),
3101            })]);
3102            let agent = Agent::new(provider, tools, AgentConfig::default());
3103            let session = Arc::new(Mutex::new(Session::in_memory()));
3104            let mut agent_session =
3105                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
3106
3107            agent_session
3108                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
3109                .await
3110                .expect("enable extensions");
3111
3112            let tool_call = ToolCall {
3113                id: "call-1".to_string(),
3114                name: "count_tool".to_string(),
3115                arguments: json!({}),
3116                thought_signature: None,
3117            };
3118
3119            let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
3120            let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
3121
3122            assert!(!is_error);
3123            assert!(!output.is_error);
3124            assert_eq!(calls.load(Ordering::SeqCst), 1);
3125
3126            assert_eq!(output.details, None);
3127            assert!(
3128                matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
3129                "Expected text output, got {:?}",
3130                output.content
3131            );
3132            if let [ContentBlock::Text(text)] = output.content.as_slice() {
3133                assert_eq!(text.text, "ok");
3134            }
3135        });
3136    }
3137
3138    #[test]
3139    fn tool_result_hook_runs_on_blocked_tool_call() {
3140        let runtime = RuntimeBuilder::current_thread()
3141            .build()
3142            .expect("runtime build");
3143
3144        runtime.block_on(async {
3145            let temp_dir = tempfile::tempdir().expect("tempdir");
3146            let entry_path = temp_dir.path().join("ext.mjs");
3147            std::fs::write(
3148                &entry_path,
3149                r#"
3150                export default function init(pi) {
3151                  pi.on("tool_call", async (event) => {
3152                    if (event && event.toolName === "count_tool") {
3153                      return { block: true, reason: "blocked in test" };
3154                    }
3155                    return {};
3156                  });
3157
3158                  pi.on("tool_result", async (event) => {
3159                    if (event && event.toolName === "count_tool" && event.isError) {
3160                      return { content: [{ type: "text", text: "override" }] };
3161                    }
3162                    return {};
3163                  });
3164                }
3165                "#,
3166            )
3167            .expect("write extension entry");
3168
3169            let provider = Arc::new(NoopProvider);
3170            let calls = Arc::new(AtomicUsize::new(0));
3171            let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
3172                calls: Arc::clone(&calls),
3173            })]);
3174            let agent = Agent::new(provider, tools, AgentConfig::default());
3175            let session = Arc::new(Mutex::new(Session::in_memory()));
3176            let mut agent_session =
3177                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
3178
3179            agent_session
3180                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
3181                .await
3182                .expect("enable extensions");
3183
3184            let tool_call = ToolCall {
3185                id: "call-1".to_string(),
3186                name: "count_tool".to_string(),
3187                arguments: json!({}),
3188                thought_signature: None,
3189            };
3190
3191            let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
3192            let (output, is_error) = agent_session.agent.execute_tool(tool_call, on_event).await;
3193
3194            assert!(is_error);
3195            assert!(output.is_error);
3196            assert_eq!(calls.load(Ordering::SeqCst), 0);
3197
3198            assert!(
3199                matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
3200                "Expected text output, got {:?}",
3201                output.content
3202            );
3203            if let [ContentBlock::Text(text)] = output.content.as_slice() {
3204                assert_eq!(text.text, "override");
3205            }
3206        });
3207    }
3208}
3209
3210#[cfg(test)]
3211mod abort_tests {
3212    use super::*;
3213    use crate::session::Session;
3214    use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
3215    use asupersync::runtime::RuntimeBuilder;
3216    use async_trait::async_trait;
3217    use futures::Stream;
3218    use serde_json::json;
3219    use std::path::Path;
3220    use std::pin::Pin;
3221    use std::sync::Mutex as StdMutex;
3222    use std::sync::atomic::AtomicUsize;
3223    use std::task::{Context as TaskContext, Poll};
3224
3225    struct StartThenPending {
3226        start: Option<StreamEvent>,
3227    }
3228
3229    impl Stream for StartThenPending {
3230        type Item = crate::error::Result<StreamEvent>;
3231
3232        fn poll_next(
3233            mut self: Pin<&mut Self>,
3234            _cx: &mut TaskContext<'_>,
3235        ) -> Poll<Option<Self::Item>> {
3236            if let Some(event) = self.start.take() {
3237                return Poll::Ready(Some(Ok(event)));
3238            }
3239            Poll::Pending
3240        }
3241    }
3242
3243    #[derive(Debug)]
3244    struct HangingProvider;
3245
3246    #[async_trait]
3247    #[allow(clippy::unnecessary_literal_bound)]
3248    impl Provider for HangingProvider {
3249        fn name(&self) -> &str {
3250            "test-provider"
3251        }
3252
3253        fn api(&self) -> &str {
3254            "test-api"
3255        }
3256
3257        fn model_id(&self) -> &str {
3258            "test-model"
3259        }
3260
3261        async fn stream(
3262            &self,
3263            _context: &Context<'_>,
3264            _options: &StreamOptions,
3265        ) -> crate::error::Result<
3266            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
3267        > {
3268            let partial = AssistantMessage {
3269                content: Vec::new(),
3270                api: self.api().to_string(),
3271                provider: self.name().to_string(),
3272                model: self.model_id().to_string(),
3273                usage: Usage::default(),
3274                stop_reason: StopReason::Stop,
3275                error_message: None,
3276                timestamp: 0,
3277            };
3278
3279            Ok(Box::pin(StartThenPending {
3280                start: Some(StreamEvent::Start { partial }),
3281            }))
3282        }
3283    }
3284
3285    #[derive(Debug)]
3286    struct CountingProvider {
3287        calls: Arc<std::sync::atomic::AtomicUsize>,
3288    }
3289
3290    #[async_trait]
3291    #[allow(clippy::unnecessary_literal_bound)]
3292    impl Provider for CountingProvider {
3293        fn name(&self) -> &str {
3294            "test-provider"
3295        }
3296
3297        fn api(&self) -> &str {
3298            "test-api"
3299        }
3300
3301        fn model_id(&self) -> &str {
3302            "test-model"
3303        }
3304
3305        async fn stream(
3306            &self,
3307            _context: &Context<'_>,
3308            _options: &StreamOptions,
3309        ) -> crate::error::Result<
3310            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
3311        > {
3312            self.calls.fetch_add(1, Ordering::SeqCst);
3313            Ok(Box::pin(futures::stream::empty()))
3314        }
3315    }
3316
3317    #[derive(Debug)]
3318    struct PhasedProvider {
3319        pending_calls: usize,
3320        calls: AtomicUsize,
3321    }
3322
3323    impl PhasedProvider {
3324        const fn new(pending_calls: usize) -> Self {
3325            Self {
3326                pending_calls,
3327                calls: AtomicUsize::new(0),
3328            }
3329        }
3330
3331        fn base_message() -> AssistantMessage {
3332            AssistantMessage {
3333                content: Vec::new(),
3334                api: "test-api".to_string(),
3335                provider: "test-provider".to_string(),
3336                model: "test-model".to_string(),
3337                usage: Usage::default(),
3338                stop_reason: StopReason::Stop,
3339                error_message: None,
3340                timestamp: 0,
3341            }
3342        }
3343    }
3344
3345    #[async_trait]
3346    #[allow(clippy::unnecessary_literal_bound)]
3347    impl Provider for PhasedProvider {
3348        fn name(&self) -> &str {
3349            "test-provider"
3350        }
3351
3352        fn api(&self) -> &str {
3353            "test-api"
3354        }
3355
3356        fn model_id(&self) -> &str {
3357            "test-model"
3358        }
3359
3360        async fn stream(
3361            &self,
3362            _context: &Context<'_>,
3363            _options: &StreamOptions,
3364        ) -> crate::error::Result<
3365            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
3366        > {
3367            let call = self.calls.fetch_add(1, Ordering::SeqCst);
3368            if call < self.pending_calls {
3369                return Ok(Box::pin(StartThenPending {
3370                    start: Some(StreamEvent::Start {
3371                        partial: Self::base_message(),
3372                    }),
3373                }));
3374            }
3375
3376            let partial = Self::base_message();
3377            let mut done = Self::base_message();
3378            done.content = vec![ContentBlock::Text(TextContent::new(format!(
3379                "resumed-response-{call}"
3380            )))];
3381
3382            Ok(Box::pin(futures::stream::iter(vec![
3383                Ok(StreamEvent::Start { partial }),
3384                Ok(StreamEvent::Done {
3385                    reason: StopReason::Stop,
3386                    message: done,
3387                }),
3388            ])))
3389        }
3390    }
3391
3392    #[derive(Debug)]
3393    struct ToolCallProvider;
3394
3395    #[async_trait]
3396    #[allow(clippy::unnecessary_literal_bound)]
3397    impl Provider for ToolCallProvider {
3398        fn name(&self) -> &str {
3399            "test-provider"
3400        }
3401
3402        fn api(&self) -> &str {
3403            "test-api"
3404        }
3405
3406        fn model_id(&self) -> &str {
3407            "test-model"
3408        }
3409
3410        async fn stream(
3411            &self,
3412            _context: &Context<'_>,
3413            _options: &StreamOptions,
3414        ) -> crate::error::Result<
3415            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
3416        > {
3417            let message = AssistantMessage {
3418                content: vec![ContentBlock::ToolCall(ToolCall {
3419                    id: "call-1".to_string(),
3420                    name: "hanging_tool".to_string(),
3421                    arguments: json!({}),
3422                    thought_signature: None,
3423                })],
3424                api: "test-api".to_string(),
3425                provider: "test-provider".to_string(),
3426                model: "test-model".to_string(),
3427                usage: Usage::default(),
3428                stop_reason: StopReason::ToolUse,
3429                error_message: None,
3430                timestamp: 0,
3431            };
3432
3433            Ok(Box::pin(futures::stream::iter(vec![Ok(
3434                StreamEvent::Done {
3435                    reason: StopReason::ToolUse,
3436                    message,
3437                },
3438            )])))
3439        }
3440    }
3441
3442    #[derive(Debug)]
3443    struct HangingTool;
3444
3445    #[async_trait]
3446    #[allow(clippy::unnecessary_literal_bound)]
3447    impl Tool for HangingTool {
3448        fn name(&self) -> &str {
3449            "hanging_tool"
3450        }
3451
3452        fn label(&self) -> &str {
3453            "Hanging Tool"
3454        }
3455
3456        fn description(&self) -> &str {
3457            "Never completes unless aborted by the host"
3458        }
3459
3460        fn parameters(&self) -> serde_json::Value {
3461            json!({
3462                "type": "object",
3463                "properties": {},
3464                "additionalProperties": false
3465            })
3466        }
3467
3468        async fn execute(
3469            &self,
3470            _tool_call_id: &str,
3471            _input: serde_json::Value,
3472            _on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
3473        ) -> crate::error::Result<ToolOutput> {
3474            futures::future::pending::<()>().await;
3475            unreachable!("hanging tool should be aborted by the agent")
3476        }
3477    }
3478
3479    fn event_tag(event: &AgentEvent) -> &'static str {
3480        match event {
3481            AgentEvent::AgentStart { .. } => "agent_start",
3482            AgentEvent::AgentEnd { error, .. } => {
3483                if error.as_deref() == Some("Aborted") {
3484                    "agent_end_aborted"
3485                } else {
3486                    "agent_end"
3487                }
3488            }
3489            AgentEvent::TurnStart { .. } => "turn_start",
3490            AgentEvent::TurnEnd { .. } => "turn_end",
3491            AgentEvent::MessageStart { .. } => "message_start",
3492            AgentEvent::MessageUpdate {
3493                assistant_message_event,
3494                ..
3495            } => match &assistant_message_event {
3496                AssistantMessageEvent::Error {
3497                    reason: StopReason::Aborted,
3498                    ..
3499                } => "assistant_error_aborted",
3500                AssistantMessageEvent::Done { .. } => "assistant_done",
3501                _ => "assistant_update",
3502            },
3503            AgentEvent::MessageEnd { .. } => "message_end",
3504            AgentEvent::ToolExecutionStart { .. } => "tool_start",
3505            AgentEvent::ToolExecutionUpdate { .. } => "tool_update",
3506            AgentEvent::ToolExecutionEnd { .. } => "tool_end",
3507            AgentEvent::AutoCompactionStart { .. } => "auto_compaction_start",
3508            AgentEvent::AutoCompactionEnd { .. } => "auto_compaction_end",
3509            AgentEvent::AutoRetryStart { .. } => "auto_retry_start",
3510            AgentEvent::AutoRetryEnd { .. } => "auto_retry_end",
3511            AgentEvent::ExtensionError { .. } => "extension_error",
3512        }
3513    }
3514
3515    fn assert_abort_resume_message_sequence(persisted: &[Message]) {
3516        assert_eq!(
3517            persisted.len(),
3518            6,
3519            "expected three user+assistant pairs, got: {persisted:?}"
3520        );
3521
3522        let assistant_states = persisted
3523            .iter()
3524            .filter_map(|message| match message {
3525                Message::Assistant(assistant) => Some(assistant.stop_reason),
3526                _ => None,
3527            })
3528            .collect::<Vec<_>>();
3529        assert_eq!(
3530            assistant_states,
3531            vec![StopReason::Aborted, StopReason::Aborted, StopReason::Stop]
3532        );
3533    }
3534
3535    fn assert_abort_resume_timeline_boundaries(timeline: &[String]) {
3536        assert!(
3537            timeline
3538                .iter()
3539                .any(|event| event == "run0:agent_end_aborted"),
3540            "missing aborted boundary for first run: {timeline:?}"
3541        );
3542        assert!(
3543            timeline
3544                .iter()
3545                .any(|event| event == "run1:agent_end_aborted"),
3546            "missing aborted boundary for second run: {timeline:?}"
3547        );
3548        assert!(
3549            timeline.iter().any(|event| event == "run2:agent_end"),
3550            "missing successful boundary for resumed run: {timeline:?}"
3551        );
3552    }
3553
3554    #[test]
3555    fn abort_interrupts_in_flight_stream() {
3556        let runtime = RuntimeBuilder::current_thread()
3557            .build()
3558            .expect("runtime build");
3559        let handle = runtime.handle();
3560
3561        let started = Arc::new(Notify::new());
3562        let started_wait = started.notified();
3563
3564        let (abort_handle, abort_signal) = AbortHandle::new();
3565
3566        let provider = Arc::new(HangingProvider);
3567        let tools = ToolRegistry::new(&[], Path::new("."), None);
3568        let agent = Agent::new(provider, tools, AgentConfig::default());
3569        let session = Arc::new(Mutex::new(Session::in_memory()));
3570        let mut agent_session =
3571            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
3572
3573        let started_tx = Arc::clone(&started);
3574        let join = handle.spawn(async move {
3575            agent_session
3576                .run_text_with_abort("hello".to_string(), Some(abort_signal), move |event| {
3577                    if matches!(
3578                        event,
3579                        AgentEvent::MessageStart {
3580                            message: Message::Assistant(_)
3581                        }
3582                    ) {
3583                        started_tx.notify_one();
3584                    }
3585                })
3586                .await
3587        });
3588
3589        runtime.block_on(async move {
3590            started_wait.await;
3591            abort_handle.abort();
3592
3593            let message = join.await.expect("run_text_with_abort");
3594            assert_eq!(message.stop_reason, StopReason::Aborted);
3595            assert_eq!(message.error_message.as_deref(), Some("Aborted"));
3596        });
3597    }
3598
3599    #[test]
3600    fn abort_before_run_skips_provider_stream_call() {
3601        let runtime = RuntimeBuilder::current_thread()
3602            .build()
3603            .expect("runtime build");
3604
3605        let calls = Arc::new(std::sync::atomic::AtomicUsize::new(0));
3606        let provider = Arc::new(CountingProvider {
3607            calls: Arc::clone(&calls),
3608        });
3609        let tools = ToolRegistry::new(&[], Path::new("."), None);
3610        let agent = Agent::new(provider, tools, AgentConfig::default());
3611        let session = Arc::new(Mutex::new(Session::in_memory()));
3612        let mut agent_session =
3613            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
3614
3615        let (abort_handle, abort_signal) = AbortHandle::new();
3616        abort_handle.abort();
3617
3618        runtime.block_on(async move {
3619            let message = agent_session
3620                .run_text_with_abort("hello".to_string(), Some(abort_signal), |_| {})
3621                .await
3622                .expect("run_text_with_abort");
3623            assert_eq!(message.stop_reason, StopReason::Aborted);
3624            assert_eq!(calls.load(Ordering::SeqCst), 0);
3625        });
3626    }
3627
3628    #[test]
3629    fn abort_then_resume_preserves_session_history() {
3630        let runtime = RuntimeBuilder::current_thread()
3631            .build()
3632            .expect("runtime build");
3633        let handle = runtime.handle();
3634
3635        runtime.block_on(async move {
3636            let provider = Arc::new(PhasedProvider::new(1));
3637            let tools = ToolRegistry::new(&[], Path::new("."), None);
3638            let agent = Agent::new(provider, tools, AgentConfig::default());
3639            let session = Arc::new(Mutex::new(Session::in_memory()));
3640            let mut agent_session = AgentSession::new(
3641                agent,
3642                Arc::clone(&session),
3643                false,
3644                ResolvedCompactionSettings::default(),
3645            );
3646
3647            let started = Arc::new(Notify::new());
3648            let (abort_handle, abort_signal) = AbortHandle::new();
3649            let started_for_abort = Arc::clone(&started);
3650            let abort_join = handle.spawn(async move {
3651                started_for_abort.notified().await;
3652                abort_handle.abort();
3653            });
3654
3655            let aborted = agent_session
3656                .run_text_with_abort("first".to_string(), Some(abort_signal), {
3657                    let started = Arc::clone(&started);
3658                    move |event| {
3659                        if matches!(
3660                            event,
3661                            AgentEvent::MessageStart {
3662                                message: Message::Assistant(_)
3663                            }
3664                        ) {
3665                            started.notify_one();
3666                        }
3667                    }
3668                })
3669                .await
3670                .expect("first run");
3671            abort_join.await;
3672
3673            assert_eq!(aborted.stop_reason, StopReason::Aborted);
3674            assert_eq!(aborted.error_message.as_deref(), Some("Aborted"));
3675
3676            let resumed = agent_session
3677                .run_text("second".to_string(), |_| {})
3678                .await
3679                .expect("resumed run");
3680            assert_eq!(resumed.stop_reason, StopReason::Stop);
3681            assert!(resumed.error_message.is_none());
3682
3683            let cx = crate::agent_cx::AgentCx::for_request();
3684            let persisted = session
3685                .lock(cx.cx())
3686                .await
3687                .expect("lock session")
3688                .to_messages_for_current_path();
3689
3690            assert_eq!(
3691                persisted.len(),
3692                4,
3693                "unexpected message history after abort+resume: {persisted:?}"
3694            );
3695            assert!(matches!(persisted.first(), Some(Message::User(_))));
3696            assert!(matches!(
3697                persisted.get(1),
3698                Some(Message::Assistant(assistant)) if assistant.stop_reason == StopReason::Aborted
3699            ));
3700            assert!(matches!(persisted.get(2), Some(Message::User(_))));
3701            assert!(matches!(
3702                persisted.get(3),
3703                Some(Message::Assistant(assistant))
3704                    if assistant.stop_reason == StopReason::Stop && assistant.error_message.is_none()
3705            ));
3706        });
3707    }
3708
3709    #[test]
3710    fn repeated_abort_then_resume_has_consistent_timeline_and_state() {
3711        let runtime = RuntimeBuilder::current_thread()
3712            .build()
3713            .expect("runtime build");
3714        let handle = runtime.handle();
3715
3716        runtime.block_on(async move {
3717            let provider = Arc::new(PhasedProvider::new(2));
3718            let tools = ToolRegistry::new(&[], Path::new("."), None);
3719            let agent = Agent::new(provider, tools, AgentConfig::default());
3720            let session = Arc::new(Mutex::new(Session::in_memory()));
3721            let mut agent_session = AgentSession::new(
3722                agent,
3723                Arc::clone(&session),
3724                false,
3725                ResolvedCompactionSettings::default(),
3726            );
3727
3728            let timeline = Arc::new(StdMutex::new(Vec::<String>::new()));
3729
3730            for run_idx in 0..2 {
3731                let started = Arc::new(Notify::new());
3732                let (abort_handle, abort_signal) = AbortHandle::new();
3733                let started_for_abort = Arc::clone(&started);
3734                let abort_join = handle.spawn(async move {
3735                    started_for_abort.notified().await;
3736                    abort_handle.abort();
3737                });
3738
3739                let run_timeline = Arc::clone(&timeline);
3740                let aborted = agent_session
3741                    .run_text_with_abort(format!("abort-run-{run_idx}"), Some(abort_signal), {
3742                        let started = Arc::clone(&started);
3743                        move |event| {
3744                            if let Ok(mut events) = run_timeline.lock() {
3745                                events.push(format!("run{run_idx}:{}", event_tag(&event)));
3746                            }
3747                            if matches!(
3748                                event,
3749                                AgentEvent::MessageStart {
3750                                    message: Message::Assistant(_)
3751                                }
3752                            ) {
3753                                started.notify_one();
3754                            }
3755                        }
3756                    })
3757                    .await
3758                    .expect("aborted run");
3759                abort_join.await;
3760
3761                assert_eq!(
3762                    aborted.stop_reason,
3763                    StopReason::Aborted,
3764                    "run {run_idx} should abort cleanly"
3765                );
3766            }
3767
3768            let run_timeline = Arc::clone(&timeline);
3769            let resumed = agent_session
3770                .run_text("final-run".to_string(), move |event| {
3771                    if let Ok(mut events) = run_timeline.lock() {
3772                        events.push(format!("run2:{}", event_tag(&event)));
3773                    }
3774                })
3775                .await
3776                .expect("final resumed run");
3777            assert_eq!(resumed.stop_reason, StopReason::Stop);
3778            assert!(resumed.error_message.is_none());
3779
3780            let cx = crate::agent_cx::AgentCx::for_request();
3781            let persisted = session
3782                .lock(cx.cx())
3783                .await
3784                .expect("lock session")
3785                .to_messages_for_current_path();
3786
3787            assert_abort_resume_message_sequence(&persisted);
3788
3789            let timeline = timeline.lock().expect("timeline lock").clone();
3790            assert_abort_resume_timeline_boundaries(&timeline);
3791        });
3792    }
3793
3794    #[test]
3795    fn abort_during_tool_execution_records_aborted_tool_result() {
3796        let runtime = RuntimeBuilder::current_thread()
3797            .build()
3798            .expect("runtime build");
3799        let handle = runtime.handle();
3800
3801        runtime.block_on(async move {
3802            let provider = Arc::new(ToolCallProvider);
3803            let tools = ToolRegistry::from_tools(vec![Box::new(HangingTool)]);
3804            let agent = Agent::new(provider, tools, AgentConfig::default());
3805            let session = Arc::new(Mutex::new(Session::in_memory()));
3806            let mut agent_session = AgentSession::new(
3807                agent,
3808                Arc::clone(&session),
3809                false,
3810                ResolvedCompactionSettings::default(),
3811            );
3812
3813            let tool_started = Arc::new(Notify::new());
3814            let (abort_handle, abort_signal) = AbortHandle::new();
3815            let tool_started_for_abort = Arc::clone(&tool_started);
3816            let abort_join = handle.spawn(async move {
3817                tool_started_for_abort.notified().await;
3818                abort_handle.abort();
3819            });
3820
3821            let result = agent_session
3822                .run_text_with_abort("trigger tool".to_string(), Some(abort_signal), {
3823                    let tool_started = Arc::clone(&tool_started);
3824                    move |event| {
3825                        if matches!(event, AgentEvent::ToolExecutionStart { .. }) {
3826                            tool_started.notify_one();
3827                        }
3828                    }
3829                })
3830                .await
3831                .expect("tool-abort run");
3832            abort_join.await;
3833            assert_eq!(result.stop_reason, StopReason::Aborted);
3834
3835            let cx = crate::agent_cx::AgentCx::for_request();
3836            let persisted = session
3837                .lock(cx.cx())
3838                .await
3839                .expect("lock session")
3840                .to_messages_for_current_path();
3841
3842            let tool_result = persisted
3843                .iter()
3844                .find_map(|message| match message {
3845                    Message::ToolResult(result) => Some(result),
3846                    _ => None,
3847                })
3848                .expect("expected tool result message");
3849            assert!(tool_result.is_error);
3850            assert!(
3851                tool_result.content.iter().any(|block| {
3852                    matches!(
3853                        block,
3854                        ContentBlock::Text(text) if text.text.contains("Tool execution aborted")
3855                    )
3856                }),
3857                "missing aborted tool marker in tool output: {:?}",
3858                tool_result.content
3859            );
3860        });
3861    }
3862}
3863
3864#[cfg(test)]
3865mod turn_event_tests {
3866    use super::*;
3867    use crate::session::Session;
3868    use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
3869    use asupersync::runtime::RuntimeBuilder;
3870    use async_trait::async_trait;
3871    use futures::Stream;
3872    use serde_json::json;
3873    use std::path::Path;
3874    use std::pin::Pin;
3875    use std::sync::atomic::AtomicUsize;
3876    // Note: Mutex from super::* is asupersync::sync::Mutex (for Session)
3877    // Use std::sync::Mutex directly for synchronous event capture
3878
3879    fn assistant_message(text: &str) -> AssistantMessage {
3880        AssistantMessage {
3881            content: vec![ContentBlock::Text(TextContent::new(text))],
3882            api: "test-api".to_string(),
3883            provider: "test-provider".to_string(),
3884            model: "test-model".to_string(),
3885            usage: Usage::default(),
3886            stop_reason: StopReason::Stop,
3887            error_message: None,
3888            timestamp: 0,
3889        }
3890    }
3891
3892    struct SingleShotProvider;
3893
3894    #[async_trait]
3895    #[allow(clippy::unnecessary_literal_bound)]
3896    impl Provider for SingleShotProvider {
3897        fn name(&self) -> &str {
3898            "test-provider"
3899        }
3900
3901        fn api(&self) -> &str {
3902            "test-api"
3903        }
3904
3905        fn model_id(&self) -> &str {
3906            "test-model"
3907        }
3908
3909        async fn stream(
3910            &self,
3911            _context: &Context<'_>,
3912            _options: &StreamOptions,
3913        ) -> crate::error::Result<
3914            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
3915        > {
3916            let partial = assistant_message("");
3917            let final_message = assistant_message("hello");
3918            let events = vec![
3919                Ok(StreamEvent::Start { partial }),
3920                Ok(StreamEvent::Done {
3921                    reason: StopReason::Stop,
3922                    message: final_message,
3923                }),
3924            ];
3925            Ok(Box::pin(futures::stream::iter(events)))
3926        }
3927    }
3928
3929    #[derive(Debug)]
3930    struct EchoTool;
3931
3932    #[async_trait]
3933    #[allow(clippy::unnecessary_literal_bound)]
3934    impl Tool for EchoTool {
3935        fn name(&self) -> &str {
3936            "echo_tool"
3937        }
3938
3939        fn label(&self) -> &str {
3940            "echo_tool"
3941        }
3942
3943        fn description(&self) -> &str {
3944            "echo test tool"
3945        }
3946
3947        fn parameters(&self) -> serde_json::Value {
3948            json!({ "type": "object" })
3949        }
3950
3951        async fn execute(
3952            &self,
3953            _tool_call_id: &str,
3954            _input: serde_json::Value,
3955            _on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
3956        ) -> Result<ToolOutput> {
3957            Ok(ToolOutput {
3958                content: vec![ContentBlock::Text(TextContent::new("tool-ok"))],
3959                details: None,
3960                is_error: false,
3961            })
3962        }
3963    }
3964
3965    #[derive(Debug)]
3966    struct ToolTurnProvider {
3967        calls: AtomicUsize,
3968    }
3969
3970    impl ToolTurnProvider {
3971        const fn new() -> Self {
3972            Self {
3973                calls: AtomicUsize::new(0),
3974            }
3975        }
3976
3977        fn assistant_message_with(
3978            &self,
3979            stop_reason: StopReason,
3980            content: Vec<ContentBlock>,
3981        ) -> AssistantMessage {
3982            AssistantMessage {
3983                content,
3984                api: self.api().to_string(),
3985                provider: self.name().to_string(),
3986                model: self.model_id().to_string(),
3987                usage: Usage::default(),
3988                stop_reason,
3989                error_message: None,
3990                timestamp: 0,
3991            }
3992        }
3993    }
3994
3995    #[async_trait]
3996    #[allow(clippy::unnecessary_literal_bound)]
3997    impl Provider for ToolTurnProvider {
3998        fn name(&self) -> &str {
3999            "test-provider"
4000        }
4001
4002        fn api(&self) -> &str {
4003            "test-api"
4004        }
4005
4006        fn model_id(&self) -> &str {
4007            "test-model"
4008        }
4009
4010        async fn stream(
4011            &self,
4012            _context: &Context<'_>,
4013            _options: &StreamOptions,
4014        ) -> crate::error::Result<
4015            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
4016        > {
4017            let call_index = self.calls.fetch_add(1, Ordering::SeqCst);
4018            let partial = self.assistant_message_with(StopReason::Stop, Vec::new());
4019            let done = if call_index == 0 {
4020                self.assistant_message_with(
4021                    StopReason::ToolUse,
4022                    vec![ContentBlock::ToolCall(ToolCall {
4023                        id: "tool-1".to_string(),
4024                        name: "echo_tool".to_string(),
4025                        arguments: json!({}),
4026                        thought_signature: None,
4027                    })],
4028                )
4029            } else {
4030                self.assistant_message_with(
4031                    StopReason::Stop,
4032                    vec![ContentBlock::Text(TextContent::new("final"))],
4033                )
4034            };
4035
4036            Ok(Box::pin(futures::stream::iter(vec![
4037                Ok(StreamEvent::Start { partial }),
4038                Ok(StreamEvent::Done {
4039                    reason: done.stop_reason,
4040                    message: done,
4041                }),
4042            ])))
4043        }
4044    }
4045
4046    #[test]
4047    fn turn_events_wrap_assistant_response() {
4048        let runtime = RuntimeBuilder::current_thread()
4049            .build()
4050            .expect("runtime build");
4051        let handle = runtime.handle();
4052
4053        let provider = Arc::new(SingleShotProvider);
4054        let tools = ToolRegistry::new(&[], Path::new("."), None);
4055        let agent = Agent::new(provider, tools, AgentConfig::default());
4056        let session = Arc::new(Mutex::new(Session::in_memory()));
4057        let mut agent_session =
4058            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
4059
4060        let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
4061            Arc::new(std::sync::Mutex::new(Vec::new()));
4062        let events_capture = Arc::clone(&events);
4063
4064        let join = handle.spawn(async move {
4065            agent_session
4066                .run_text("hello".to_string(), move |event| {
4067                    events_capture.lock().unwrap().push(event);
4068                })
4069                .await
4070                .expect("run_text")
4071        });
4072
4073        runtime.block_on(async move {
4074            let message = join.await;
4075            assert_eq!(message.stop_reason, StopReason::Stop);
4076
4077            let events = events.lock().unwrap();
4078            let turn_start_indices = events
4079                .iter()
4080                .enumerate()
4081                .filter_map(|(idx, event)| {
4082                    matches!(event, AgentEvent::TurnStart { .. }).then_some(idx)
4083                })
4084                .collect::<Vec<_>>();
4085            let turn_end_indices = events
4086                .iter()
4087                .enumerate()
4088                .filter_map(|(idx, event)| {
4089                    matches!(event, AgentEvent::TurnEnd { .. }).then_some(idx)
4090                })
4091                .collect::<Vec<_>>();
4092
4093            assert_eq!(turn_start_indices.len(), 1);
4094            assert_eq!(turn_end_indices.len(), 1);
4095            assert!(turn_start_indices[0] < turn_end_indices[0]);
4096
4097            let assistant_message_end = events
4098                .iter()
4099                .enumerate()
4100                .find_map(|(idx, event)| match event {
4101                    AgentEvent::MessageEnd {
4102                        message: Message::Assistant(_),
4103                    } => Some(idx),
4104                    _ => None,
4105                })
4106                .expect("assistant message end");
4107
4108            assert!(assistant_message_end < turn_end_indices[0]);
4109
4110            let (message_is_assistant, tool_results_empty) = {
4111                let turn_end_event = &events[turn_end_indices[0]];
4112                assert!(
4113                    matches!(turn_end_event, AgentEvent::TurnEnd { .. }),
4114                    "Expected TurnEnd event, got {turn_end_event:?}"
4115                );
4116                match turn_end_event {
4117                    AgentEvent::TurnEnd {
4118                        message,
4119                        tool_results,
4120                        ..
4121                    } => (
4122                        matches!(message, Message::Assistant(_)),
4123                        tool_results.is_empty(),
4124                    ),
4125                    _ => (false, false),
4126                }
4127            };
4128            drop(events);
4129            assert!(message_is_assistant);
4130            assert!(tool_results_empty);
4131        });
4132    }
4133
4134    #[test]
4135    fn turn_events_include_tool_execution_and_tool_result_messages() {
4136        let runtime = RuntimeBuilder::current_thread()
4137            .build()
4138            .expect("runtime build");
4139        let handle = runtime.handle();
4140
4141        let provider = Arc::new(ToolTurnProvider::new());
4142        let tools = ToolRegistry::from_tools(vec![Box::new(EchoTool)]);
4143        let agent = Agent::new(provider, tools, AgentConfig::default());
4144        let session = Arc::new(Mutex::new(Session::in_memory()));
4145        let mut agent_session =
4146            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
4147
4148        let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
4149            Arc::new(std::sync::Mutex::new(Vec::new()));
4150        let events_capture = Arc::clone(&events);
4151
4152        let join = handle.spawn(async move {
4153            agent_session
4154                .run_text("hello".to_string(), move |event| {
4155                    events_capture.lock().expect("events lock").push(event);
4156                })
4157                .await
4158                .expect("run_text")
4159        });
4160
4161        runtime.block_on(async move {
4162            let message = join.await;
4163            assert_eq!(message.stop_reason, StopReason::Stop);
4164
4165            let events = events.lock().expect("events lock");
4166            let turn_start_count = events
4167                .iter()
4168                .filter(|event| matches!(event, AgentEvent::TurnStart { .. }))
4169                .count();
4170            let turn_end_count = events
4171                .iter()
4172                .filter(|event| matches!(event, AgentEvent::TurnEnd { .. }))
4173                .count();
4174            assert_eq!(
4175                turn_start_count, 2,
4176                "expected one tool turn and one final turn"
4177            );
4178            assert_eq!(
4179                turn_end_count, 2,
4180                "expected one tool turn and one final turn"
4181            );
4182
4183            let tool_start_idx = events
4184                .iter()
4185                .position(|event| matches!(event, AgentEvent::ToolExecutionStart { .. }))
4186                .expect("tool execution start event");
4187            let tool_end_idx = events
4188                .iter()
4189                .position(|event| matches!(event, AgentEvent::ToolExecutionEnd { .. }))
4190                .expect("tool execution end event");
4191            assert!(tool_start_idx < tool_end_idx);
4192
4193            let first_turn_end_idx = events
4194                .iter()
4195                .position(|event| matches!(event, AgentEvent::TurnEnd { turn_index: 0, .. }))
4196                .expect("first turn end");
4197            assert!(
4198                tool_end_idx < first_turn_end_idx,
4199                "tool execution should complete before first turn end"
4200            );
4201
4202            let first_turn_tool_results = events.iter().find_map(|event| match event {
4203                AgentEvent::TurnEnd {
4204                    turn_index,
4205                    tool_results,
4206                    ..
4207                } if *turn_index == 0 => Some(tool_results),
4208                _ => None,
4209            });
4210
4211            let Some(first_turn_tool_results) = first_turn_tool_results else {
4212                panic!("missing first turn tool results");
4213            };
4214            assert_eq!(first_turn_tool_results.len(), 1);
4215            let first_result = first_turn_tool_results.first().unwrap();
4216            if let Message::ToolResult(tr) = first_result {
4217                assert_eq!(tr.tool_name, "echo_tool");
4218                assert!(!tr.is_error);
4219            } else {
4220                panic!("expected ToolResult message");
4221            }
4222            drop(events);
4223        });
4224    }
4225}
4226
4227impl AgentSession {
4228    pub const fn runtime_repair_mode_from_policy_mode(mode: RepairPolicyMode) -> RepairMode {
4229        match mode {
4230            RepairPolicyMode::Off => RepairMode::Off,
4231            RepairPolicyMode::Suggest => RepairMode::Suggest,
4232            RepairPolicyMode::AutoSafe => RepairMode::AutoSafe,
4233            RepairPolicyMode::AutoStrict => RepairMode::AutoStrict,
4234        }
4235    }
4236
4237    #[allow(clippy::too_many_arguments)]
4238    async fn start_js_extension_runtime(
4239        stage: &'static str,
4240        cwd: &std::path::Path,
4241        tools: Arc<ToolRegistry>,
4242        manager: ExtensionManager,
4243        policy: ExtensionPolicy,
4244        repair_mode: RepairMode,
4245        memory_limit_bytes: usize,
4246    ) -> Result<ExtensionRuntimeHandle> {
4247        let mut config = PiJsRuntimeConfig {
4248            cwd: cwd.display().to_string(),
4249            repair_mode,
4250            ..PiJsRuntimeConfig::default()
4251        };
4252        config.limits.memory_limit_bytes = Some(memory_limit_bytes).filter(|bytes| *bytes > 0);
4253
4254        let runtime =
4255            JsExtensionRuntimeHandle::start_with_policy(config, tools, manager, policy).await?;
4256        tracing::info!(
4257            event = "pi.extension_runtime.engine_decision",
4258            stage,
4259            requested = "quickjs",
4260            selected = "quickjs",
4261            fallback = false,
4262            "Extension runtime engine selected (legacy JS/TS)"
4263        );
4264        Ok(ExtensionRuntimeHandle::Js(runtime))
4265    }
4266
4267    #[allow(clippy::too_many_arguments)]
4268    async fn start_native_extension_runtime(
4269        stage: &'static str,
4270        _cwd: &std::path::Path,
4271        _tools: Arc<ToolRegistry>,
4272        _manager: ExtensionManager,
4273        _policy: ExtensionPolicy,
4274        _repair_mode: RepairMode,
4275        _memory_limit_bytes: usize,
4276    ) -> Result<ExtensionRuntimeHandle> {
4277        let runtime = NativeRustExtensionRuntimeHandle::start().await?;
4278        tracing::info!(
4279            event = "pi.extension_runtime.engine_decision",
4280            stage,
4281            requested = "native-rust",
4282            selected = "native-rust",
4283            fallback = false,
4284            "Extension runtime engine selected (native-rust)"
4285        );
4286        Ok(ExtensionRuntimeHandle::NativeRust(runtime))
4287    }
4288
4289    pub fn new(
4290        agent: Agent,
4291        session: Arc<Mutex<Session>>,
4292        save_enabled: bool,
4293        compaction_settings: ResolvedCompactionSettings,
4294    ) -> Self {
4295        Self {
4296            agent,
4297            session,
4298            save_enabled,
4299            extensions: None,
4300            extensions_is_streaming: Arc::new(AtomicBool::new(false)),
4301            compaction_settings,
4302            compaction_worker: CompactionWorkerState::new(CompactionQuota::default()),
4303            model_registry: None,
4304            auth_storage: None,
4305        }
4306    }
4307
4308    #[must_use]
4309    pub fn with_model_registry(mut self, registry: ModelRegistry) -> Self {
4310        self.model_registry = Some(registry);
4311        self
4312    }
4313
4314    #[must_use]
4315    pub fn with_auth_storage(mut self, auth: AuthStorage) -> Self {
4316        self.auth_storage = Some(auth);
4317        self
4318    }
4319
    /// Replace the stored model registry with `registry`.
    pub fn set_model_registry(&mut self, registry: ModelRegistry) {
        self.model_registry = Some(registry);
    }
4323
    /// Replace the stored auth storage with `auth`.
    pub fn set_auth_storage(&mut self, auth: AuthStorage) {
        self.auth_storage = Some(auth);
    }
4327
4328    pub async fn set_provider_model(&mut self, provider_id: &str, model_id: &str) -> Result<()> {
4329        {
4330            let cx = crate::agent_cx::AgentCx::for_request();
4331            let mut session = self
4332                .session
4333                .lock(cx.cx())
4334                .await
4335                .map_err(|e| Error::session(e.to_string()))?;
4336            session.set_model_header(
4337                Some(provider_id.to_string()),
4338                Some(model_id.to_string()),
4339                None,
4340            );
4341        }
4342
4343        self.apply_session_model_selection(provider_id, model_id);
4344        let provider = self.agent.provider();
4345        if provider.name() != provider_id || provider.model_id() != model_id {
4346            return Err(Error::validation(format!(
4347                "Unable to switch provider/model to {provider_id}/{model_id}"
4348            )));
4349        }
4350
4351        self.persist_session().await
4352    }
4353
4354    fn resolve_stream_api_key_for_model(&self, entry: &ModelEntry) -> Option<String> {
4355        let normalize = |key_opt: Option<String>| {
4356            key_opt.and_then(|key| {
4357                let trimmed = key.trim();
4358                (!trimmed.is_empty()).then(|| trimmed.to_string())
4359            })
4360        };
4361
4362        self.auth_storage
4363            .as_ref()
4364            .and_then(|auth| normalize(auth.resolve_api_key(&entry.model.provider, None)))
4365            .or_else(|| normalize(entry.api_key.clone()))
4366    }
4367
4368    fn apply_session_model_selection(&mut self, provider_id: &str, model_id: &str) {
4369        if self.agent.provider().name() == provider_id
4370            && self.agent.provider().model_id() == model_id
4371        {
4372            return;
4373        }
4374
4375        let Some(registry) = &self.model_registry else {
4376            return;
4377        };
4378
4379        let Some(entry) = registry.find(provider_id, model_id) else {
4380            tracing::warn!("Session model {provider_id}/{model_id} not found in model registry");
4381            return;
4382        };
4383
4384        match crate::providers::create_provider(
4385            &entry,
4386            self.extensions.as_ref().map(ExtensionRegion::manager),
4387        ) {
4388            Ok(provider) => {
4389                tracing::info!("Updating agent provider to {provider_id}/{model_id}");
4390                self.agent.set_provider(provider);
4391
4392                let resolved_key = self.resolve_stream_api_key_for_model(&entry);
4393                if resolved_key.is_none() {
4394                    tracing::warn!(
4395                        "No API key resolved for session model {provider_id}/{model_id}; clearing stream API key"
4396                    );
4397                }
4398
4399                let stream_options = self.agent.stream_options_mut();
4400                stream_options.api_key = resolved_key;
4401                stream_options.headers.clone_from(&entry.headers);
4402            }
4403            Err(e) => {
4404                tracing::warn!("Failed to create provider for session model: {e}");
4405            }
4406        }
4407    }
4408
    /// Returns the session's `save_enabled` flag.
    pub const fn save_enabled(&self) -> bool {
        self.save_enabled
    }
4412
4413    /// Force-run compaction synchronously (used by `/compact` slash command).
4414    pub async fn compact_now(
4415        &mut self,
4416        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
4417    ) -> Result<()> {
4418        self.compact_synchronous(Arc::new(on_event)).await
4419    }
4420
4421    /// Two-phase non-blocking compaction.
4422    ///
4423    /// **Phase 1** — apply a completed background compaction result (if any).
4424    /// **Phase 2** — if quotas allow and the session needs compaction, start a
4425    /// new background compaction thread.
4426    async fn maybe_compact(&mut self, on_event: AgentEventHandler) -> Result<()> {
4427        if !self.compaction_settings.enabled {
4428            return Ok(());
4429        }
4430
4431        // Phase 1: apply completed background result.
4432        if let Some(outcome) = self.compaction_worker.try_recv() {
4433            match outcome {
4434                Ok(result) => {
4435                    self.apply_compaction_result(result, Arc::clone(&on_event))
4436                        .await?;
4437                }
4438                Err(e) => {
4439                    on_event(AgentEvent::AutoCompactionEnd {
4440                        result: None,
4441                        aborted: false,
4442                        will_retry: false,
4443                        error_message: Some(e.to_string()),
4444                    });
4445                }
4446            }
4447        }
4448
4449        // Phase 2: start new background compaction if quotas allow.
4450        if !self.compaction_worker.can_start() {
4451            return Ok(());
4452        }
4453
4454        let preparation = {
4455            let cx = crate::agent_cx::AgentCx::for_request();
4456            let session = self
4457                .session
4458                .lock(cx.cx())
4459                .await
4460                .map_err(|e| Error::session(e.to_string()))?;
4461            let entries = session
4462                .entries_for_current_path()
4463                .into_iter()
4464                .cloned()
4465                .collect::<Vec<_>>();
4466            compaction::prepare_compaction(&entries, self.compaction_settings.clone())
4467        };
4468
4469        if let Some(prep) = preparation {
4470            on_event(AgentEvent::AutoCompactionStart {
4471                reason: "threshold".to_string(),
4472            });
4473
4474            let provider = self.agent.provider();
4475            let api_key = self
4476                .agent
4477                .stream_options()
4478                .api_key
4479                .clone()
4480                .unwrap_or_default();
4481
4482            self.compaction_worker.start(prep, provider, api_key, None);
4483        }
4484
4485        Ok(())
4486    }
4487
4488    /// Apply a completed compaction result to the session.
4489    async fn apply_compaction_result(
4490        &self,
4491        result: compaction::CompactionResult,
4492        on_event: AgentEventHandler,
4493    ) -> Result<()> {
4494        let cx = crate::agent_cx::AgentCx::for_request();
4495        let mut session = self
4496            .session
4497            .lock(cx.cx())
4498            .await
4499            .map_err(|e| Error::session(e.to_string()))?;
4500
4501        let details = compaction::compaction_details_to_value(&result.details).ok();
4502        let result_value = details.clone();
4503
4504        session.append_compaction(
4505            result.summary,
4506            result.first_kept_entry_id,
4507            result.tokens_before,
4508            details,
4509            None, // from_hook
4510        );
4511
4512        if self.save_enabled {
4513            session
4514                .flush_autosave(AutosaveFlushTrigger::Periodic)
4515                .await?;
4516        }
4517
4518        on_event(AgentEvent::AutoCompactionEnd {
4519            result: result_value,
4520            aborted: false,
4521            will_retry: false,
4522            error_message: None,
4523        });
4524
4525        Ok(())
4526    }
4527
4528    /// Run compaction synchronously (inline), blocking until completion.
4529    async fn compact_synchronous(&self, on_event: AgentEventHandler) -> Result<()> {
4530        if !self.compaction_settings.enabled {
4531            return Ok(());
4532        }
4533
4534        let preparation = {
4535            let cx = crate::agent_cx::AgentCx::for_request();
4536            let session = self
4537                .session
4538                .lock(cx.cx())
4539                .await
4540                .map_err(|e| Error::session(e.to_string()))?;
4541            let entries = session
4542                .entries_for_current_path()
4543                .into_iter()
4544                .cloned()
4545                .collect::<Vec<_>>();
4546            compaction::prepare_compaction(&entries, self.compaction_settings.clone())
4547        };
4548
4549        if let Some(prep) = preparation {
4550            on_event(AgentEvent::AutoCompactionStart {
4551                reason: "threshold".to_string(),
4552            });
4553
4554            let provider = self.agent.provider();
4555            let api_key = self
4556                .agent
4557                .stream_options()
4558                .api_key
4559                .clone()
4560                .unwrap_or_default();
4561
4562            match compaction::compact(prep, provider, &api_key, None).await {
4563                Ok(result) => {
4564                    self.apply_compaction_result(result, Arc::clone(&on_event))
4565                        .await?;
4566                }
4567                Err(e) => {
4568                    on_event(AgentEvent::AutoCompactionEnd {
4569                        result: None,
4570                        aborted: false,
4571                        will_retry: false,
4572                        error_message: Some(e.to_string()),
4573                    });
4574                    return Err(e);
4575                }
4576            }
4577        }
4578        Ok(())
4579    }
4580
4581    #[allow(clippy::too_many_arguments)]
4582    pub async fn enable_extensions(
4583        &mut self,
4584        enabled_tools: &[&str],
4585        cwd: &std::path::Path,
4586        config: Option<&crate::config::Config>,
4587        extension_entries: &[std::path::PathBuf],
4588    ) -> Result<()> {
4589        self.enable_extensions_with_policy(
4590            enabled_tools,
4591            cwd,
4592            config,
4593            extension_entries,
4594            None,
4595            None,
4596            None,
4597        )
4598        .await
4599    }
4600
4601    #[allow(clippy::too_many_lines, clippy::too_many_arguments)]
4602    pub async fn enable_extensions_with_policy(
4603        &mut self,
4604        enabled_tools: &[&str],
4605        cwd: &std::path::Path,
4606        config: Option<&crate::config::Config>,
4607        extension_entries: &[std::path::PathBuf],
4608        policy: Option<ExtensionPolicy>,
4609        repair_policy: Option<RepairPolicyMode>,
4610        pre_warmed: Option<PreWarmedExtensionRuntime>,
4611    ) -> Result<()> {
4612        let mut js_specs: Vec<JsExtensionLoadSpec> = Vec::new();
4613        let mut native_specs: Vec<NativeRustExtensionLoadSpec> = Vec::new();
4614        #[cfg(feature = "wasm-host")]
4615        let mut wasm_specs: Vec<WasmExtensionLoadSpec> = Vec::new();
4616
4617        for entry in extension_entries {
4618            match resolve_extension_load_spec(entry)? {
4619                ExtensionLoadSpec::Js(spec) => js_specs.push(spec),
4620                ExtensionLoadSpec::NativeRust(spec) => native_specs.push(spec),
4621                #[cfg(feature = "wasm-host")]
4622                ExtensionLoadSpec::Wasm(spec) => wasm_specs.push(spec),
4623            }
4624        }
4625
4626        if !js_specs.is_empty() && !native_specs.is_empty() {
4627            return Err(Error::validation(
4628                "Mixed extension runtimes are not supported in one session yet. Use either JS/TS extensions (QuickJS) or native-rust descriptors (*.native.json), but not both at once."
4629                    .to_string(),
4630            ));
4631        }
4632
4633        let resolved_policy = policy.clone().unwrap_or_default();
4634        let resolved_repair_policy = repair_policy
4635            .or_else(|| config.map(|cfg| cfg.resolve_repair_policy(None)))
4636            .unwrap_or(RepairPolicyMode::AutoSafe);
4637        let runtime_repair_mode =
4638            Self::runtime_repair_mode_from_policy_mode(resolved_repair_policy);
4639        let memory_limit_bytes =
4640            (resolved_policy.max_memory_mb as usize).saturating_mul(1024 * 1024);
4641        let wants_js_runtime = !js_specs.is_empty();
4642
4643        // Either use the pre-warmed extension runtime (booted concurrently with startup)
4644        // or create a fresh runtime inline.
4645        #[allow(unused_variables)]
4646        let (manager, tools) = if let Some(pre) = pre_warmed {
4647            let manager = pre.manager;
4648            let tools = pre.tools;
4649            let runtime = match pre.runtime {
4650                ExtensionRuntimeHandle::NativeRust(runtime) => {
4651                    if wants_js_runtime {
4652                        tracing::warn!(
4653                            event = "pi.extension_runtime.prewarm.mismatch",
4654                            expected = "quickjs",
4655                            got = "native-rust",
4656                            "Pre-warmed runtime mismatched requested JS mode; creating quickjs runtime"
4657                        );
4658                        Self::start_js_extension_runtime(
4659                            "agent_enable_extensions_prewarm_mismatch",
4660                            cwd,
4661                            Arc::clone(&tools),
4662                            manager.clone(),
4663                            resolved_policy.clone(),
4664                            runtime_repair_mode,
4665                            memory_limit_bytes,
4666                        )
4667                        .await?
4668                    } else {
4669                        tracing::info!(
4670                            event = "pi.extension_runtime.engine_decision",
4671                            stage = "agent_enable_extensions_prewarmed",
4672                            requested = "native-rust",
4673                            selected = "native-rust",
4674                            fallback = false,
4675                            "Using pre-warmed extension runtime"
4676                        );
4677                        ExtensionRuntimeHandle::NativeRust(runtime)
4678                    }
4679                }
4680                ExtensionRuntimeHandle::Js(runtime) => {
4681                    if wants_js_runtime {
4682                        tracing::info!(
4683                            event = "pi.extension_runtime.engine_decision",
4684                            stage = "agent_enable_extensions_prewarmed",
4685                            requested = "quickjs",
4686                            selected = "quickjs",
4687                            fallback = false,
4688                            "Using pre-warmed extension runtime"
4689                        );
4690                        ExtensionRuntimeHandle::Js(runtime)
4691                    } else {
4692                        tracing::warn!(
4693                            event = "pi.extension_runtime.prewarm.mismatch",
4694                            expected = "native-rust",
4695                            got = "quickjs",
4696                            "Pre-warmed runtime mismatched requested native mode; creating native-rust runtime"
4697                        );
4698                        Self::start_native_extension_runtime(
4699                            "agent_enable_extensions_prewarm_mismatch",
4700                            cwd,
4701                            Arc::clone(&tools),
4702                            manager.clone(),
4703                            resolved_policy.clone(),
4704                            runtime_repair_mode,
4705                            memory_limit_bytes,
4706                        )
4707                        .await?
4708                    }
4709                }
4710            };
4711            manager.set_runtime(runtime);
4712            (manager, tools)
4713        } else {
4714            let manager = ExtensionManager::new();
4715            manager.set_cwd(cwd.display().to_string());
4716            let tools = Arc::new(ToolRegistry::new(enabled_tools, cwd, config));
4717
4718            if let Some(cfg) = config {
4719                let resolved_risk = cfg.resolve_extension_risk_with_metadata();
4720                tracing::info!(
4721                    event = "pi.extension_runtime_risk.config",
4722                    source = resolved_risk.source,
4723                    enabled = resolved_risk.settings.enabled,
4724                    alpha = resolved_risk.settings.alpha,
4725                    window_size = resolved_risk.settings.window_size,
4726                    ledger_limit = resolved_risk.settings.ledger_limit,
4727                    fail_closed = resolved_risk.settings.fail_closed,
4728                    "Resolved extension runtime risk settings"
4729                );
4730                manager.set_runtime_risk_config(resolved_risk.settings);
4731            }
4732
4733            let runtime = if wants_js_runtime {
4734                Self::start_js_extension_runtime(
4735                    "agent_enable_extensions_boot",
4736                    cwd,
4737                    Arc::clone(&tools),
4738                    manager.clone(),
4739                    resolved_policy,
4740                    runtime_repair_mode,
4741                    memory_limit_bytes,
4742                )
4743                .await?
4744            } else {
4745                Self::start_native_extension_runtime(
4746                    "agent_enable_extensions_boot",
4747                    cwd,
4748                    Arc::clone(&tools),
4749                    manager.clone(),
4750                    resolved_policy,
4751                    runtime_repair_mode,
4752                    memory_limit_bytes,
4753                )
4754                .await?
4755            };
4756            manager.set_runtime(runtime);
4757            (manager, tools)
4758        };
4759
4760        // Session, host actions, and message fetchers are always set here
4761        // (after runtime boot) — the JS runtime only needs these when
4762        // dispatching hostcalls, which happens during extension loading.
4763        manager.set_session(Arc::new(SessionHandle(self.session.clone())));
4764
4765        let injected = Arc::new(StdMutex::new(ExtensionInjectedQueue::default()));
4766        let host_actions = AgentSessionHostActions {
4767            session: Arc::clone(&self.session),
4768            injected: Arc::clone(&injected),
4769            is_streaming: Arc::clone(&self.extensions_is_streaming),
4770        };
4771        manager.set_host_actions(Arc::new(host_actions));
4772        {
4773            let steering_queue = Arc::clone(&injected);
4774            let follow_up_queue = Arc::clone(&injected);
4775            let steering_fetcher = move || -> BoxFuture<'static, Vec<Message>> {
4776                let steering_queue = Arc::clone(&steering_queue);
4777                Box::pin(async move {
4778                    let Ok(mut queue) = steering_queue.lock() else {
4779                        return Vec::new();
4780                    };
4781                    queue.pop_steering()
4782                })
4783            };
4784            let follow_up_fetcher = move || -> BoxFuture<'static, Vec<Message>> {
4785                let follow_up_queue = Arc::clone(&follow_up_queue);
4786                Box::pin(async move {
4787                    let Ok(mut queue) = follow_up_queue.lock() else {
4788                        return Vec::new();
4789                    };
4790                    queue.pop_follow_up()
4791                })
4792            };
4793            self.agent.register_message_fetchers(
4794                Some(Arc::new(steering_fetcher)),
4795                Some(Arc::new(follow_up_fetcher)),
4796            );
4797        }
4798
4799        if !js_specs.is_empty() {
4800            manager.load_js_extensions(js_specs).await?;
4801        }
4802
4803        if !native_specs.is_empty() {
4804            manager.load_native_extensions(native_specs).await?;
4805        }
4806
4807        // Drain and log auto-repair diagnostics (bd-k5q5.8.11).
4808        if let Some(rt) = manager.runtime() {
4809            let events = rt.drain_repair_events().await;
4810            if !events.is_empty() {
4811                log_repair_diagnostics(&events);
4812            }
4813        }
4814
4815        #[cfg(feature = "wasm-host")]
4816        if !wasm_specs.is_empty() {
4817            let host = WasmExtensionHost::new(cwd, policy.unwrap_or_default())?;
4818            manager
4819                .load_wasm_extensions(&host, wasm_specs, Arc::clone(&tools))
4820                .await?;
4821        }
4822
4823        // Fire the `startup` lifecycle hook once extensions are loaded.
4824        // Fail-open: extension errors must not prevent the agent from running.
4825        let session_path = {
4826            let cx = crate::agent_cx::AgentCx::for_request();
4827            let session = self
4828                .session
4829                .lock(cx.cx())
4830                .await
4831                .map_err(|e| Error::extension(e.to_string()))?;
4832            session.path.as_ref().map(|p| p.display().to_string())
4833        };
4834
4835        if let Err(err) = manager
4836            .dispatch_event(
4837                ExtensionEventName::Startup,
4838                Some(serde_json::json!({
4839                    "version": env!("CARGO_PKG_VERSION"),
4840                    "sessionFile": session_path,
4841                })),
4842            )
4843            .await
4844        {
4845            tracing::warn!("startup extension hook failed (fail-open): {err}");
4846        }
4847
4848        let ctx_payload = serde_json::json!({ "cwd": cwd.display().to_string() });
4849        let wrappers = collect_extension_tool_wrappers(&manager, ctx_payload).await?;
4850        self.agent.extend_tools(wrappers);
4851        self.agent.extensions = Some(manager.clone());
4852        self.extensions = Some(ExtensionRegion::new(manager));
4853        Ok(())
4854    }
4855
4856    pub async fn save_and_index(&mut self) -> Result<()> {
4857        if self.save_enabled {
4858            let cx = crate::agent_cx::AgentCx::for_request();
4859            let mut session = self
4860                .session
4861                .lock(cx.cx())
4862                .await
4863                .map_err(|e| Error::session(e.to_string()))?;
4864            session
4865                .flush_autosave(AutosaveFlushTrigger::Periodic)
4866                .await?;
4867        }
4868        Ok(())
4869    }
4870
4871    pub async fn persist_session(&mut self) -> Result<()> {
4872        if !self.save_enabled {
4873            return Ok(());
4874        }
4875        let cx = crate::agent_cx::AgentCx::for_request();
4876        let mut session = self
4877            .session
4878            .lock(cx.cx())
4879            .await
4880            .map_err(|e| Error::session(e.to_string()))?;
4881        session
4882            .flush_autosave(AutosaveFlushTrigger::Periodic)
4883            .await?;
4884        Ok(())
4885    }
4886
4887    pub async fn run_text(
4888        &mut self,
4889        input: String,
4890        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
4891    ) -> Result<AssistantMessage> {
4892        self.run_text_with_abort(input, None, on_event).await
4893    }
4894
4895    pub async fn run_text_with_abort(
4896        &mut self,
4897        input: String,
4898        abort: Option<AbortSignal>,
4899        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
4900    ) -> Result<AssistantMessage> {
4901        let outcome = self.dispatch_input_event(input, Vec::new()).await?;
4902        let (text, images) = match outcome {
4903            InputEventOutcome::Continue { text, images } => (text, images),
4904            InputEventOutcome::Block { reason } => {
4905                let message = reason.unwrap_or_else(|| "Input blocked".to_string());
4906                return Err(Error::extension(message));
4907            }
4908        };
4909
4910        self.dispatch_before_agent_start().await;
4911
4912        if images.is_empty() {
4913            self.run_agent_with_text(text, abort, on_event).await
4914        } else {
4915            let content = Self::build_content_blocks_for_input(&text, &images);
4916            self.run_agent_with_content(content, abort, on_event).await
4917        }
4918    }
4919
4920    pub async fn run_with_content(
4921        &mut self,
4922        content: Vec<ContentBlock>,
4923        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
4924    ) -> Result<AssistantMessage> {
4925        self.run_with_content_with_abort(content, None, on_event)
4926            .await
4927    }
4928
4929    pub async fn run_with_content_with_abort(
4930        &mut self,
4931        content: Vec<ContentBlock>,
4932        abort: Option<AbortSignal>,
4933        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
4934    ) -> Result<AssistantMessage> {
4935        let (text, images) = Self::split_content_blocks_for_input(&content);
4936        let outcome = self.dispatch_input_event(text, images).await?;
4937        let (text, images) = match outcome {
4938            InputEventOutcome::Continue { text, images } => (text, images),
4939            InputEventOutcome::Block { reason } => {
4940                let message = reason.unwrap_or_else(|| "Input blocked".to_string());
4941                return Err(Error::extension(message));
4942            }
4943        };
4944
4945        self.dispatch_before_agent_start().await;
4946
4947        let content_for_agent = Self::build_content_blocks_for_input(&text, &images);
4948        self.run_agent_with_content(content_for_agent, abort, on_event)
4949            .await
4950    }
4951
4952    async fn dispatch_input_event(
4953        &self,
4954        text: String,
4955        images: Vec<ImageContent>,
4956    ) -> Result<InputEventOutcome> {
4957        let Some(region) = &self.extensions else {
4958            return Ok(InputEventOutcome::Continue { text, images });
4959        };
4960
4961        let images_value = serde_json::to_value(&images).unwrap_or(Value::Null);
4962        let payload = json!({
4963            "text": text,
4964            "images": images_value,
4965            "source": "user",
4966        });
4967
4968        let response = region
4969            .manager()
4970            .dispatch_event_with_response(
4971                ExtensionEventName::Input,
4972                Some(payload),
4973                EXTENSION_EVENT_TIMEOUT_MS,
4974            )
4975            .await?;
4976
4977        Ok(apply_input_event_response(response, text, images))
4978    }
4979
4980    async fn dispatch_before_agent_start(&self) {
4981        if let Some(region) = &self.extensions {
4982            if let Err(err) = region
4983                .manager()
4984                .dispatch_event(ExtensionEventName::BeforeAgentStart, None)
4985                .await
4986            {
4987                tracing::warn!("before_agent_start extension hook failed (fail-open): {err}");
4988            }
4989        }
4990    }
4991
4992    fn split_content_blocks_for_input(blocks: &[ContentBlock]) -> (String, Vec<ImageContent>) {
4993        let mut text = String::new();
4994        let mut images = Vec::new();
4995        for block in blocks {
4996            match block {
4997                ContentBlock::Text(text_block) => {
4998                    if !text_block.text.trim().is_empty() {
4999                        if !text.is_empty() {
5000                            text.push('\n');
5001                        }
5002                        text.push_str(&text_block.text);
5003                    }
5004                }
5005                ContentBlock::Image(image) => images.push(image.clone()),
5006                _ => {}
5007            }
5008        }
5009        (text, images)
5010    }
5011
5012    fn build_content_blocks_for_input(text: &str, images: &[ImageContent]) -> Vec<ContentBlock> {
5013        let mut content = Vec::new();
5014        if !text.trim().is_empty() {
5015            content.push(ContentBlock::Text(TextContent::new(text.to_string())));
5016        }
5017        for image in images {
5018            content.push(ContentBlock::Image(image.clone()));
5019        }
5020        content
5021    }
5022
5023    pub(crate) async fn run_agent_with_text(
5024        &mut self,
5025        input: String,
5026        abort: Option<AbortSignal>,
5027        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
5028    ) -> Result<AssistantMessage> {
5029        let on_event: AgentEventHandler = Arc::new(on_event);
5030        let session_model = {
5031            let cx = crate::agent_cx::AgentCx::for_request();
5032            let session = self
5033                .session
5034                .lock(cx.cx())
5035                .await
5036                .map_err(|e| Error::session(e.to_string()))?;
5037            (
5038                session.header.provider.clone(),
5039                session.header.model_id.clone(),
5040            )
5041        };
5042
5043        if let (Some(provider_id), Some(model_id)) = session_model {
5044            self.apply_session_model_selection(provider_id.as_str(), model_id.as_str());
5045        }
5046
5047        self.maybe_compact(Arc::clone(&on_event)).await?;
5048        let history = {
5049            let cx = crate::agent_cx::AgentCx::for_request();
5050            let session = self
5051                .session
5052                .lock(cx.cx())
5053                .await
5054                .map_err(|e| Error::session(e.to_string()))?;
5055            session.to_messages_for_current_path()
5056        };
5057        self.agent.replace_messages(history);
5058
5059        let start_len = self.agent.messages().len();
5060
5061        // Create and persist user message immediately to avoid data loss on API errors
5062        let user_message = Message::User(UserMessage {
5063            content: UserContent::Text(input),
5064            timestamp: Utc::now().timestamp_millis(),
5065        });
5066
5067        {
5068            let cx = crate::agent_cx::AgentCx::for_request();
5069            let mut session = self
5070                .session
5071                .lock(cx.cx())
5072                .await
5073                .map_err(|e| Error::session(e.to_string()))?;
5074            session.append_model_message(user_message.clone());
5075            if self.save_enabled {
5076                session.flush_autosave(AutosaveFlushTrigger::Manual).await?;
5077            }
5078        }
5079
5080        self.extensions_is_streaming.store(true, Ordering::SeqCst);
5081        let on_event_for_run = Arc::clone(&on_event);
5082        let result = self
5083            .agent
5084            .run_with_message_with_abort(user_message, abort, move |event| {
5085                on_event_for_run(event);
5086            })
5087            .await;
5088        self.extensions_is_streaming.store(false, Ordering::SeqCst);
5089        let result = result?;
5090        // Persist only NEW messages (assistant/tools), skipping the user message we already saved.
5091        self.persist_new_messages(start_len + 1).await?;
5092        Ok(result)
5093    }
5094
5095    pub(crate) async fn run_agent_with_content(
5096        &mut self,
5097        content: Vec<ContentBlock>,
5098        abort: Option<AbortSignal>,
5099        on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
5100    ) -> Result<AssistantMessage> {
5101        let on_event: AgentEventHandler = Arc::new(on_event);
5102        let session_model = {
5103            let cx = crate::agent_cx::AgentCx::for_request();
5104            let session = self
5105                .session
5106                .lock(cx.cx())
5107                .await
5108                .map_err(|e| Error::session(e.to_string()))?;
5109            (
5110                session.header.provider.clone(),
5111                session.header.model_id.clone(),
5112            )
5113        };
5114
5115        if let (Some(provider_id), Some(model_id)) = session_model {
5116            self.apply_session_model_selection(provider_id.as_str(), model_id.as_str());
5117        }
5118
5119        self.maybe_compact(Arc::clone(&on_event)).await?;
5120        let history = {
5121            let cx = crate::agent_cx::AgentCx::for_request();
5122            let session = self
5123                .session
5124                .lock(cx.cx())
5125                .await
5126                .map_err(|e| Error::session(e.to_string()))?;
5127            session.to_messages_for_current_path()
5128        };
5129        self.agent.replace_messages(history);
5130
5131        let start_len = self.agent.messages().len();
5132
5133        // Create and persist user message immediately to avoid data loss on API errors
5134        let user_message = Message::User(UserMessage {
5135            content: UserContent::Blocks(content),
5136            timestamp: Utc::now().timestamp_millis(),
5137        });
5138
5139        {
5140            let cx = crate::agent_cx::AgentCx::for_request();
5141            let mut session = self
5142                .session
5143                .lock(cx.cx())
5144                .await
5145                .map_err(|e| Error::session(e.to_string()))?;
5146            session.append_model_message(user_message.clone());
5147            if self.save_enabled {
5148                session.flush_autosave(AutosaveFlushTrigger::Manual).await?;
5149            }
5150        }
5151
5152        self.extensions_is_streaming.store(true, Ordering::SeqCst);
5153        let on_event_for_run = Arc::clone(&on_event);
5154        let result = self
5155            .agent
5156            .run_with_message_with_abort(user_message, abort, move |event| {
5157                on_event_for_run(event);
5158            })
5159            .await;
5160        self.extensions_is_streaming.store(false, Ordering::SeqCst);
5161        let result = result?;
5162        // Persist only NEW messages (assistant/tools), skipping the user message we already saved.
5163        self.persist_new_messages(start_len + 1).await?;
5164        Ok(result)
5165    }
5166
5167    async fn persist_new_messages(&self, start_len: usize) -> Result<()> {
5168        let new_messages = self.agent.messages()[start_len..].to_vec();
5169        {
5170            let cx = crate::agent_cx::AgentCx::for_request();
5171            let mut session = self
5172                .session
5173                .lock(cx.cx())
5174                .await
5175                .map_err(|e| Error::session(e.to_string()))?;
5176            for message in new_messages {
5177                session.append_model_message(message);
5178            }
5179            if self.save_enabled {
5180                session
5181                    .flush_autosave(AutosaveFlushTrigger::Periodic)
5182                    .await?;
5183            }
5184        }
5185        Ok(())
5186    }
5187}
5188
5189// ============================================================================
5190// Helper Functions
5191// ============================================================================
5192
5193/// Log a summary of auto-repair events that fired during extension loading.
5194///
5195/// Default: one-line summary.  Set `PI_AUTO_REPAIR_VERBOSE=1` for per-extension
5196/// detail.  Structured tracing events are always emitted regardless of verbosity.
5197fn log_repair_diagnostics(events: &[crate::extensions_js::ExtensionRepairEvent]) {
5198    use std::collections::BTreeMap;
5199
5200    // Always emit structured tracing events for each repair.
5201    for ev in events {
5202        tracing::info!(
5203            event = "extension.auto_repair",
5204            extension_id = %ev.extension_id,
5205            pattern = %ev.pattern,
5206            success = ev.success,
5207            original_error = %ev.original_error,
5208            repair_action = %ev.repair_action,
5209        );
5210    }
5211
5212    // Group by pattern for the summary line.
5213    let mut by_pattern: BTreeMap<String, Vec<&str>> = BTreeMap::new();
5214    for ev in events {
5215        by_pattern
5216            .entry(ev.pattern.to_string())
5217            .or_default()
5218            .push(&ev.extension_id);
5219    }
5220
5221    let verbose = std::env::var("PI_AUTO_REPAIR_VERBOSE")
5222        .is_ok_and(|v| v == "1" || v.eq_ignore_ascii_case("true"));
5223
5224    if verbose {
5225        eprintln!(
5226            "[auto-repair] {} extension{} auto-repaired:",
5227            events.len(),
5228            if events.len() == 1 { "" } else { "s" }
5229        );
5230        for ev in events {
5231            eprintln!(
5232                "  {}: {} ({})",
5233                ev.pattern, ev.extension_id, ev.repair_action
5234            );
5235        }
5236    } else {
5237        // Compact one-line summary.
5238        let patterns: Vec<String> = by_pattern
5239            .iter()
5240            .map(|(pat, ids)| format!("{pat}:{}", ids.len()))
5241            .collect();
5242        tracing::info!(
5243            event = "extension.auto_repair.summary",
5244            count = events.len(),
5245            patterns = %patterns.join(", "),
5246            "auto-repaired {} extension(s)",
5247            events.len(),
5248        );
5249    }
5250}
5251
/// Text inserted by `filter_image_blocks` in place of image blocks it removes.
const BLOCK_IMAGES_PLACEHOLDER: &str = "Image reading is disabled.";
5253
/// Counters returned by `filter_images_for_provider`.
#[derive(Debug, Default, Clone, Copy)]
struct ImageFilterStats {
    /// Total number of image blocks removed across all messages.
    removed_images: usize,
    /// Number of messages from which at least one image block was removed.
    affected_messages: usize,
}
5259
5260fn filter_images_for_provider(messages: &mut [Message]) -> ImageFilterStats {
5261    let mut stats = ImageFilterStats::default();
5262    for message in messages {
5263        let removed = filter_images_from_message(message);
5264        if removed > 0 {
5265            stats.removed_images += removed;
5266            stats.affected_messages += 1;
5267        }
5268    }
5269    stats
5270}
5271
5272fn filter_images_from_message(message: &mut Message) -> usize {
5273    match message {
5274        Message::User(user) => match &mut user.content {
5275            UserContent::Text(_) => 0,
5276            UserContent::Blocks(blocks) => filter_image_blocks(blocks),
5277        },
5278        Message::Assistant(assistant) => {
5279            let assistant = Arc::make_mut(assistant);
5280            filter_image_blocks(&mut assistant.content)
5281        }
5282        Message::ToolResult(tool_result) => {
5283            filter_image_blocks(&mut Arc::make_mut(tool_result).content)
5284        }
5285        Message::Custom(_) => 0,
5286    }
5287}
5288
5289fn filter_image_blocks(blocks: &mut Vec<ContentBlock>) -> usize {
5290    let mut removed = 0usize;
5291    let mut filtered = Vec::with_capacity(blocks.len());
5292
5293    for block in blocks.drain(..) {
5294        match block {
5295            ContentBlock::Image(_) => {
5296                removed += 1;
5297                let previous_is_placeholder =
5298                    filtered
5299                        .last()
5300                        .is_some_and(|prev| matches!(prev, ContentBlock::Text(TextContent { text, .. }) if text == BLOCK_IMAGES_PLACEHOLDER));
5301                if !previous_is_placeholder {
5302                    filtered.push(ContentBlock::Text(TextContent::new(
5303                        BLOCK_IMAGES_PLACEHOLDER,
5304                    )));
5305                }
5306            }
5307            other => filtered.push(other),
5308        }
5309    }
5310
5311    *blocks = filtered;
5312    removed
5313}
5314
5315/// Extract tool calls from content blocks.
5316fn extract_tool_calls(content: &[ContentBlock]) -> Vec<ToolCall> {
5317    content
5318        .iter()
5319        .filter_map(|block| {
5320            if let ContentBlock::ToolCall(tc) = block {
5321                Some(tc.clone())
5322            } else {
5323                None
5324            }
5325        })
5326        .collect()
5327}
5328
5329// ============================================================================
5330// Tests
5331// ============================================================================
5332
5333#[cfg(test)]
5334mod tests {
5335    use super::*;
5336    use crate::auth::AuthCredential;
5337    use crate::provider::{InputType, Model, ModelCost};
5338    use async_trait::async_trait;
5339    use futures::Stream;
5340    use std::collections::HashMap;
5341    use std::path::Path;
5342    use std::pin::Pin;
5343
5344    fn user_message(text: &str) -> Message {
5345        Message::User(UserMessage {
5346            content: UserContent::Text(text.to_string()),
5347            timestamp: 0,
5348        })
5349    }
5350
5351    fn assert_user_text(message: &Message, expected: &str) {
5352        assert!(
5353            matches!(
5354                message,
5355                Message::User(UserMessage {
5356                    content: UserContent::Text(_),
5357                    ..
5358                })
5359            ),
5360            "expected user text message, got {message:?}"
5361        );
5362        if let Message::User(UserMessage {
5363            content: UserContent::Text(text),
5364            ..
5365        }) = message
5366        {
5367            assert_eq!(text, expected);
5368        }
5369    }
5370
5371    fn sample_image_block() -> ContentBlock {
5372        ContentBlock::Image(ImageContent {
5373            data: "aGVsbG8=".to_string(),
5374            mime_type: "image/png".to_string(),
5375        })
5376    }
5377
5378    fn image_count_in_message(message: &Message) -> usize {
5379        let count_images = |blocks: &[ContentBlock]| {
5380            blocks
5381                .iter()
5382                .filter(|block| matches!(block, ContentBlock::Image(_)))
5383                .count()
5384        };
5385        match message {
5386            Message::User(UserMessage {
5387                content: UserContent::Blocks(blocks),
5388                ..
5389            }) => count_images(blocks),
5390            Message::Assistant(msg) => count_images(&msg.content),
5391            Message::ToolResult(tool_result) => count_images(&tool_result.content),
5392            Message::User(UserMessage {
5393                content: UserContent::Text(_),
5394                ..
5395            })
5396            | Message::Custom(_) => 0,
5397        }
5398    }
5399
    /// Test provider whose `stream` call succeeds with an empty event stream.
    #[derive(Debug)]
    struct SilentProvider;

    #[async_trait]
    // NOTE(review): presumably required because these methods return string
    // literals while the trait signature returns `&str` — confirm.
    #[allow(clippy::unnecessary_literal_bound)]
    impl Provider for SilentProvider {
        /// Fixed provider name used by the tests.
        fn name(&self) -> &str {
            "silent-provider"
        }

        /// Fixed API identifier used by the tests.
        fn api(&self) -> &str {
            "test-api"
        }

        /// Fixed model identifier used by the tests.
        fn model_id(&self) -> &str {
            "test-model"
        }

        /// Ignores the context and options and returns a stream with no events.
        async fn stream(
            &self,
            _context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<
            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
        > {
            Ok(Box::pin(futures::stream::empty()))
        }
    }
5428
5429    #[test]
5430    fn test_extract_tool_calls() {
5431        let content = vec![
5432            ContentBlock::Text(TextContent::new("Hello")),
5433            ContentBlock::ToolCall(ToolCall {
5434                id: "tc1".to_string(),
5435                name: "read".to_string(),
5436                arguments: serde_json::json!({"path": "file.txt"}),
5437                thought_signature: None,
5438            }),
5439            ContentBlock::Text(TextContent::new("World")),
5440            ContentBlock::ToolCall(ToolCall {
5441                id: "tc2".to_string(),
5442                name: "bash".to_string(),
5443                arguments: serde_json::json!({"command": "ls"}),
5444                thought_signature: None,
5445            }),
5446        ];
5447
5448        let tool_calls = extract_tool_calls(&content);
5449        assert_eq!(tool_calls.len(), 2);
5450        assert_eq!(tool_calls[0].name, "read");
5451        assert_eq!(tool_calls[1].name, "bash");
5452    }
5453
5454    #[test]
5455    fn test_agent_config_default() {
5456        let config = AgentConfig::default();
5457        assert_eq!(config.max_tool_iterations, 50);
5458        assert!(config.system_prompt.is_none());
5459        assert!(!config.block_images);
5460    }
5461
5462    #[test]
5463    fn filter_image_blocks_replaces_images_with_deduped_placeholder_text() {
5464        let mut blocks = vec![
5465            sample_image_block(),
5466            sample_image_block(),
5467            ContentBlock::Text(TextContent::new("tail")),
5468            sample_image_block(),
5469        ];
5470
5471        let removed = filter_image_blocks(&mut blocks);
5472
5473        assert_eq!(removed, 3);
5474        assert!(
5475            !blocks
5476                .iter()
5477                .any(|block| matches!(block, ContentBlock::Image(_)))
5478        );
5479        assert!(matches!(
5480            blocks.first(),
5481            Some(ContentBlock::Text(TextContent { text, .. })) if text == BLOCK_IMAGES_PLACEHOLDER
5482        ));
5483        assert!(matches!(
5484            blocks.get(1),
5485            Some(ContentBlock::Text(TextContent { text, .. })) if text == "tail"
5486        ));
5487        assert!(matches!(
5488            blocks.get(2),
5489            Some(ContentBlock::Text(TextContent { text, .. })) if text == BLOCK_IMAGES_PLACEHOLDER
5490        ));
5491    }
5492
5493    #[test]
5494    fn filter_images_for_provider_filters_images_from_all_block_message_types() {
5495        let mut messages = vec![
5496            Message::User(UserMessage {
5497                content: UserContent::Blocks(vec![
5498                    ContentBlock::Text(TextContent::new("hello")),
5499                    sample_image_block(),
5500                ]),
5501                timestamp: 0,
5502            }),
5503            Message::Assistant(Arc::new(AssistantMessage {
5504                content: vec![sample_image_block()],
5505                api: "test".to_string(),
5506                provider: "test".to_string(),
5507                model: "test".to_string(),
5508                usage: Usage::default(),
5509                stop_reason: StopReason::Stop,
5510                error_message: None,
5511                timestamp: 0,
5512            })),
5513            Message::tool_result(ToolResultMessage {
5514                tool_call_id: "tc1".to_string(),
5515                tool_name: "read".to_string(),
5516                content: vec![
5517                    sample_image_block(),
5518                    ContentBlock::Text(TextContent::new("ok")),
5519                ],
5520                details: None,
5521                is_error: false,
5522                timestamp: 0,
5523            }),
5524        ];
5525
5526        let stats = filter_images_for_provider(&mut messages);
5527
5528        assert_eq!(stats.removed_images, 3);
5529        assert_eq!(stats.affected_messages, 3);
5530        assert_eq!(
5531            messages.iter().map(image_count_in_message).sum::<usize>(),
5532            0,
5533            "no images should remain in provider-bound context"
5534        );
5535    }
5536
5537    #[test]
5538    fn build_context_strips_images_when_block_images_enabled() {
5539        let mut agent = Agent::new(
5540            Arc::new(SilentProvider),
5541            ToolRegistry::new(&[], Path::new("."), None),
5542            AgentConfig {
5543                system_prompt: None,
5544                max_tool_iterations: 50,
5545                stream_options: StreamOptions::default(),
5546                block_images: true,
5547            },
5548        );
5549        agent.add_message(Message::User(UserMessage {
5550            content: UserContent::Blocks(vec![sample_image_block()]),
5551            timestamp: 0,
5552        }));
5553
5554        let context = agent.build_context();
5555        assert_eq!(context.messages.len(), 1);
5556        assert_eq!(image_count_in_message(&context.messages[0]), 0);
5557        assert!(matches!(
5558            &context.messages[0],
5559            Message::User(UserMessage {
5560                content: UserContent::Blocks(blocks),
5561                ..
5562            }) if blocks
5563                .iter()
5564                .any(|block| matches!(block, ContentBlock::Text(TextContent { text, .. }) if text == BLOCK_IMAGES_PLACEHOLDER))
5565        ));
5566    }
5567
5568    #[test]
5569    fn build_context_keeps_images_when_block_images_disabled() {
5570        let mut agent = Agent::new(
5571            Arc::new(SilentProvider),
5572            ToolRegistry::new(&[], Path::new("."), None),
5573            AgentConfig {
5574                system_prompt: None,
5575                max_tool_iterations: 50,
5576                stream_options: StreamOptions::default(),
5577                block_images: false,
5578            },
5579        );
5580        agent.add_message(Message::User(UserMessage {
5581            content: UserContent::Blocks(vec![sample_image_block()]),
5582            timestamp: 0,
5583        }));
5584
5585        let context = agent.build_context();
5586        assert_eq!(context.messages.len(), 1);
5587        assert_eq!(image_count_in_message(&context.messages[0]), 1);
5588    }
5589
5590    #[test]
5591    fn auto_compaction_start_serializes_with_pi_mono_compatible_type_tag() {
5592        let event = AgentEvent::AutoCompactionStart {
5593            reason: "threshold".to_string(),
5594        };
5595        let json = serde_json::to_value(&event).unwrap();
5596        assert_eq!(json["type"], "auto_compaction_start");
5597        assert_eq!(json["reason"], "threshold");
5598    }
5599
5600    #[test]
5601    fn auto_compaction_end_serializes_with_pi_mono_compatible_fields() {
5602        let event = AgentEvent::AutoCompactionEnd {
5603            result: Some(serde_json::json!({"tokens_before": 5000, "tokens_after": 2000})),
5604            aborted: false,
5605            will_retry: false,
5606            error_message: None,
5607        };
5608        let json = serde_json::to_value(&event).unwrap();
5609        assert_eq!(json["type"], "auto_compaction_end");
5610        assert_eq!(json["aborted"], false);
5611        assert_eq!(json["willRetry"], false);
5612        assert!(json.get("errorMessage").is_none()); // skipped when None
5613        assert!(json["result"].is_object());
5614    }
5615
5616    #[test]
5617    fn auto_compaction_end_includes_error_message_when_present() {
5618        let event = AgentEvent::AutoCompactionEnd {
5619            result: None,
5620            aborted: true,
5621            will_retry: false,
5622            error_message: Some("Compaction failed".to_string()),
5623        };
5624        let json = serde_json::to_value(&event).unwrap();
5625        assert_eq!(json["type"], "auto_compaction_end");
5626        assert_eq!(json["aborted"], true);
5627        assert_eq!(json["errorMessage"], "Compaction failed");
5628    }
5629
5630    #[test]
5631    fn auto_retry_start_serializes_with_camel_case_fields() {
5632        let event = AgentEvent::AutoRetryStart {
5633            attempt: 1,
5634            max_attempts: 3,
5635            delay_ms: 2000,
5636            error_message: "Rate limited".to_string(),
5637        };
5638        let json = serde_json::to_value(&event).unwrap();
5639        assert_eq!(json["type"], "auto_retry_start");
5640        assert_eq!(json["attempt"], 1);
5641        assert_eq!(json["maxAttempts"], 3);
5642        assert_eq!(json["delayMs"], 2000);
5643        assert_eq!(json["errorMessage"], "Rate limited");
5644    }
5645
5646    #[test]
5647    fn auto_retry_end_serializes_success_and_omits_null_final_error() {
5648        let event = AgentEvent::AutoRetryEnd {
5649            success: true,
5650            attempt: 2,
5651            final_error: None,
5652        };
5653        let json = serde_json::to_value(&event).unwrap();
5654        assert_eq!(json["type"], "auto_retry_end");
5655        assert_eq!(json["success"], true);
5656        assert_eq!(json["attempt"], 2);
5657        assert!(json.get("finalError").is_none());
5658    }
5659
5660    #[test]
5661    fn auto_retry_end_includes_final_error_on_failure() {
5662        let event = AgentEvent::AutoRetryEnd {
5663            success: false,
5664            attempt: 3,
5665            final_error: Some("Max retries exceeded".to_string()),
5666        };
5667        let json = serde_json::to_value(&event).unwrap();
5668        assert_eq!(json["type"], "auto_retry_end");
5669        assert_eq!(json["success"], false);
5670        assert_eq!(json["attempt"], 3);
5671        assert_eq!(json["finalError"], "Max retries exceeded");
5672    }
5673
5674    #[test]
5675    fn message_queue_push_increments_seq_and_counts_both_queues() {
5676        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
5677        assert_eq!(queue.pending_count(), 0);
5678
5679        assert_eq!(queue.push_steering(user_message("s1")), 0);
5680        assert_eq!(queue.push_follow_up(user_message("f1")), 1);
5681        assert_eq!(queue.push_steering(user_message("s2")), 2);
5682
5683        assert_eq!(queue.pending_count(), 3);
5684    }
5685
5686    #[test]
5687    fn message_queue_pop_steering_one_at_a_time_preserves_order() {
5688        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
5689        queue.push_steering(user_message("s1"));
5690        queue.push_steering(user_message("s2"));
5691
5692        let first = queue.pop_steering();
5693        assert_eq!(first.len(), 1);
5694        assert_user_text(&first[0], "s1");
5695        assert_eq!(queue.pending_count(), 1);
5696
5697        let second = queue.pop_steering();
5698        assert_eq!(second.len(), 1);
5699        assert_user_text(&second[0], "s2");
5700        assert_eq!(queue.pending_count(), 0);
5701
5702        let empty = queue.pop_steering();
5703        assert!(empty.is_empty());
5704    }
5705
5706    #[test]
5707    fn message_queue_pop_respects_queue_modes_per_kind() {
5708        let mut queue = MessageQueue::new(QueueMode::All, QueueMode::OneAtATime);
5709        queue.push_steering(user_message("s1"));
5710        queue.push_steering(user_message("s2"));
5711        queue.push_follow_up(user_message("f1"));
5712        queue.push_follow_up(user_message("f2"));
5713
5714        let steering = queue.pop_steering();
5715        assert_eq!(steering.len(), 2);
5716        assert_user_text(&steering[0], "s1");
5717        assert_user_text(&steering[1], "s2");
5718        assert_eq!(queue.pending_count(), 2);
5719
5720        let follow_up = queue.pop_follow_up();
5721        assert_eq!(follow_up.len(), 1);
5722        assert_user_text(&follow_up[0], "f1");
5723        assert_eq!(queue.pending_count(), 1);
5724
5725        let follow_up = queue.pop_follow_up();
5726        assert_eq!(follow_up.len(), 1);
5727        assert_user_text(&follow_up[0], "f2");
5728        assert_eq!(queue.pending_count(), 0);
5729    }
5730
5731    #[test]
5732    fn message_queue_set_modes_applies_to_existing_messages() {
5733        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
5734        queue.push_steering(user_message("s1"));
5735        queue.push_steering(user_message("s2"));
5736
5737        let first = queue.pop_steering();
5738        assert_eq!(first.len(), 1);
5739        assert_user_text(&first[0], "s1");
5740
5741        queue.set_modes(QueueMode::All, QueueMode::OneAtATime);
5742        let remaining = queue.pop_steering();
5743        assert_eq!(remaining.len(), 1);
5744        assert_user_text(&remaining[0], "s2");
5745    }
5746
5747    fn build_switch_test_session(auth: &AuthStorage) -> AgentSession {
5748        let registry = ModelRegistry::load(auth, None);
5749        let current_entry = registry
5750            .find("anthropic", "claude-sonnet-4-5")
5751            .expect("anthropic model in registry");
5752        let provider = crate::providers::create_provider(&current_entry, None)
5753            .expect("create anthropic provider");
5754        let tools = ToolRegistry::new(&[], Path::new("."), None);
5755        let mut stream_options = StreamOptions {
5756            api_key: Some("stale-key".to_string()),
5757            ..Default::default()
5758        };
5759        let _ = stream_options
5760            .headers
5761            .insert("x-stale-header".to_string(), "stale-value".to_string());
5762        let agent = Agent::new(
5763            provider,
5764            tools,
5765            AgentConfig {
5766                system_prompt: None,
5767                max_tool_iterations: 50,
5768                stream_options,
5769                block_images: false,
5770            },
5771        );
5772
5773        let mut session = Session::in_memory();
5774        session.header.provider = Some("openai".to_string());
5775        session.header.model_id = Some("gpt-4o".to_string());
5776
5777        let mut agent_session = AgentSession::new(
5778            agent,
5779            Arc::new(Mutex::new(session)),
5780            false,
5781            ResolvedCompactionSettings::default(),
5782        );
5783        agent_session.set_model_registry(registry);
5784        agent_session.set_auth_storage(auth.clone());
5785        agent_session
5786    }
5787
5788    #[test]
5789    fn apply_session_model_selection_updates_stream_credentials_and_headers() {
5790        let dir = tempfile::tempdir().expect("tempdir");
5791        let auth_path = dir.path().join("auth.json");
5792        let mut auth = AuthStorage::load(auth_path).expect("load auth");
5793        auth.set(
5794            "anthropic",
5795            AuthCredential::ApiKey {
5796                key: "anthropic-key".to_string(),
5797            },
5798        );
5799        auth.set(
5800            "openai",
5801            AuthCredential::ApiKey {
5802                key: "openai-key".to_string(),
5803            },
5804        );
5805
5806        let mut agent_session = build_switch_test_session(&auth);
5807        agent_session.apply_session_model_selection("openai", "gpt-4o");
5808
5809        assert_eq!(agent_session.agent.provider().name(), "openai");
5810        assert_eq!(agent_session.agent.provider().model_id(), "gpt-4o");
5811        assert_eq!(
5812            agent_session.agent.stream_options().api_key.as_deref(),
5813            Some("openai-key")
5814        );
5815        assert!(
5816            agent_session.agent.stream_options().headers.is_empty(),
5817            "stream headers should be refreshed from selected model entry"
5818        );
5819    }
5820
5821    #[test]
5822    fn apply_session_model_selection_clears_stale_key_when_target_has_no_key() {
5823        let dir = tempfile::tempdir().expect("tempdir");
5824        let auth_path = dir.path().join("auth.json");
5825        let mut auth = AuthStorage::load(auth_path).expect("load auth");
5826        auth.set(
5827            "anthropic",
5828            AuthCredential::ApiKey {
5829                key: "anthropic-key".to_string(),
5830            },
5831        );
5832
5833        let mut agent_session = build_switch_test_session(&auth);
5834        agent_session.apply_session_model_selection("openai", "gpt-4o");
5835
5836        assert_eq!(agent_session.agent.provider().name(), "openai");
5837        assert_eq!(
5838            agent_session.agent.stream_options().api_key,
5839            None,
5840            "stale key must be cleared when target model has no configured key"
5841        );
5842    }
5843
5844    #[test]
5845    fn apply_session_model_selection_treats_blank_model_key_as_missing() {
5846        let dir = tempfile::tempdir().expect("tempdir");
5847        let auth_path = dir.path().join("auth.json");
5848        let auth = AuthStorage::load(auth_path).expect("load auth");
5849
5850        let mut registry = ModelRegistry::load(&auth, None);
5851        registry.merge_entries(vec![ModelEntry {
5852            model: Model {
5853                id: "blank-model".to_string(),
5854                name: "Blank Model".to_string(),
5855                api: "openai-completions".to_string(),
5856                provider: "acme".to_string(),
5857                base_url: "https://example.invalid/v1".to_string(),
5858                reasoning: true,
5859                input: vec![InputType::Text],
5860                cost: ModelCost {
5861                    input: 0.0,
5862                    output: 0.0,
5863                    cache_read: 0.0,
5864                    cache_write: 0.0,
5865                },
5866                context_window: 128_000,
5867                max_tokens: 8_192,
5868                headers: HashMap::new(),
5869            },
5870            api_key: Some("   ".to_string()),
5871            headers: HashMap::new(),
5872            auth_header: true,
5873            compat: None,
5874            oauth_config: None,
5875        }]);
5876
5877        let mut agent_session = build_switch_test_session(&auth);
5878        agent_session.set_model_registry(registry);
5879        agent_session.apply_session_model_selection("acme", "blank-model");
5880
5881        assert_eq!(agent_session.agent.provider().name(), "acme");
5882        assert_eq!(
5883            agent_session.agent.stream_options().api_key,
5884            None,
5885            "blank model keys must not be treated as valid credentials"
5886        );
5887    }
5888
5889    #[test]
5890    fn auto_compaction_start_serializes_to_pi_mono_format() {
5891        let event = AgentEvent::AutoCompactionStart {
5892            reason: "threshold".to_string(),
5893        };
5894        let json = serde_json::to_value(&event).unwrap();
5895        assert_eq!(json["type"], "auto_compaction_start");
5896        assert_eq!(json["reason"], "threshold");
5897    }
5898
5899    #[test]
5900    fn auto_compaction_end_serializes_to_pi_mono_format() {
5901        let event = AgentEvent::AutoCompactionEnd {
5902            result: Some(serde_json::json!({
5903                "summary": "Compacted",
5904                "firstKeptEntryId": "abc123",
5905                "tokensBefore": 50000,
5906                "details": { "readFiles": [], "modifiedFiles": [] }
5907            })),
5908            aborted: false,
5909            will_retry: true,
5910            error_message: None,
5911        };
5912        let json = serde_json::to_value(&event).unwrap();
5913        assert_eq!(json["type"], "auto_compaction_end");
5914        assert!(json["result"].is_object());
5915        assert_eq!(json["aborted"], false);
5916        assert_eq!(json["willRetry"], true);
5917        assert!(json.get("errorMessage").is_none());
5918    }
5919
5920    #[test]
5921    fn auto_compaction_end_with_error_serializes_error_message() {
5922        let event = AgentEvent::AutoCompactionEnd {
5923            result: None,
5924            aborted: false,
5925            will_retry: false,
5926            error_message: Some("compaction failed".to_string()),
5927        };
5928        let json = serde_json::to_value(&event).unwrap();
5929        assert_eq!(json["type"], "auto_compaction_end");
5930        assert!(json["result"].is_null());
5931        assert_eq!(json["errorMessage"], "compaction failed");
5932    }
5933
5934    #[test]
5935    fn auto_retry_start_serializes_to_pi_mono_format() {
5936        let event = AgentEvent::AutoRetryStart {
5937            attempt: 2,
5938            max_attempts: 3,
5939            delay_ms: 4000,
5940            error_message: "rate limited".to_string(),
5941        };
5942        let json = serde_json::to_value(&event).unwrap();
5943        assert_eq!(json["type"], "auto_retry_start");
5944        assert_eq!(json["attempt"], 2);
5945        assert_eq!(json["maxAttempts"], 3);
5946        assert_eq!(json["delayMs"], 4000);
5947        assert_eq!(json["errorMessage"], "rate limited");
5948    }
5949
5950    #[test]
5951    fn auto_retry_end_success_serializes_to_pi_mono_format() {
5952        let event = AgentEvent::AutoRetryEnd {
5953            success: true,
5954            attempt: 2,
5955            final_error: None,
5956        };
5957        let json = serde_json::to_value(&event).unwrap();
5958        assert_eq!(json["type"], "auto_retry_end");
5959        assert_eq!(json["success"], true);
5960        assert_eq!(json["attempt"], 2);
5961        assert!(json.get("finalError").is_none());
5962    }
5963
5964    #[test]
5965    fn auto_retry_end_failure_serializes_final_error() {
5966        let event = AgentEvent::AutoRetryEnd {
5967            success: false,
5968            attempt: 3,
5969            final_error: Some("max retries exceeded".to_string()),
5970        };
5971        let json = serde_json::to_value(&event).unwrap();
5972        assert_eq!(json["type"], "auto_retry_end");
5973        assert_eq!(json["success"], false);
5974        assert_eq!(json["attempt"], 3);
5975        assert_eq!(json["finalError"], "max retries exceeded");
5976    }
5977}