1use crate::auth::AuthStorage;
16use crate::compaction::{self, ResolvedCompactionSettings};
17use crate::compaction_worker::{
18 CompactionAdmissionSignals, CompactionQuota, CompactionWorkerState,
19};
20use crate::error::{Error, Result};
21use crate::extension_events::{
22 BeforeAgentStartOutcome, InputEventOutcome, SessionBeforeCompactOutcome,
23 apply_before_agent_start_response, apply_input_event_response,
24 apply_session_before_compact_response,
25};
26use crate::extension_tools::collect_extension_tool_wrappers;
27use crate::extensions::{
28 EXTENSION_EVENT_TIMEOUT_MS, ExtensionAiCompletionRequest, ExtensionDeliverAs,
29 ExtensionEventName, ExtensionHostActions, ExtensionLoadSpec, ExtensionManager, ExtensionPolicy,
30 ExtensionRegion, ExtensionRuntimeHandle, ExtensionSendMessage, ExtensionSendUserMessage,
31 JsExtensionLoadSpec, JsExtensionRuntimeHandle, NativeRustExtensionLoadSpec,
32 NativeRustExtensionRuntimeHandle, RepairPolicyMode, resolve_extension_load_spec,
33};
34#[cfg(feature = "wasm-host")]
35use crate::extensions::{WasmExtensionHost, WasmExtensionLoadSpec};
36use crate::extensions_js::{PiJsRuntimeConfig, RepairMode};
37use crate::model::{
38 AssistantMessage, AssistantMessageEvent, ContentBlock, CustomMessage, ImageContent, Message,
39 StopReason, StreamEvent, TextContent, ThinkingContent, ToolCall, ToolResultMessage, Usage,
40 UserContent, UserMessage,
41};
42use crate::models::{
43 ModelEntry, ModelRegistry, model_requires_configured_credential, normalize_api_key_opt,
44};
45use crate::provider::{Context, Provider, StreamOptions, ToolDef};
46use crate::semantic_workspace_graph::{ContextBundleItem, SemanticContextBundle};
47use crate::session::{AutosaveFlushTrigger, Session, SessionHandle};
48use crate::tools::{Tool, ToolEffects, ToolOutput, ToolRegistry, ToolUpdate};
49use asupersync::runtime::{Runtime, RuntimeBuilder, RuntimeHandle};
50use asupersync::sync::{Mutex, Notify};
51use async_trait::async_trait;
52use chrono::Utc;
53use futures::FutureExt;
54use futures::StreamExt;
55use futures::future::BoxFuture;
56use futures::stream;
57use serde::Serialize;
58use serde_json::{Value, json};
59use sha2::{Digest as _, Sha256};
60use std::borrow::Cow;
61use std::collections::VecDeque;
62use std::fmt;
63use std::sync::Arc;
64use std::sync::Mutex as StdMutex;
65use std::sync::OnceLock;
66use std::sync::atomic::{AtomicBool, Ordering};
67use std::time::{Duration, Instant};
68use tracing::warn;
69
// ---- Compatible-tool parallelism bounds ----

/// Floor for the host-scaled default number of concurrently executing compatible tools.
const MIN_COMPATIBLE_TOOL_PARALLELISM: usize = 8;
/// Ceiling for the host-scaled default (used when no override is configured).
const MAX_AUTO_COMPATIBLE_TOOL_PARALLELISM: usize = 64;
/// Ceiling applied to an explicit `PI_MAX_CONCURRENT_COMPATIBLE_TOOLS` override.
const MAX_CONFIGURED_COMPATIBLE_TOOL_PARALLELISM: usize = 256;

// ---- Queue and history bounds ----

/// Capacity of the steering queue; pushing onto a full queue evicts its oldest entry.
const MAX_STEERING_QUEUE_SIZE: usize = 100;
/// Capacity of the follow-up queue; pushing onto a full queue evicts its oldest entry.
const MAX_FOLLOW_UP_QUEUE_SIZE: usize = 100;
/// History length at which `Agent::add_message` starts evicting the oldest messages.
const MAX_AGENT_MESSAGES: usize = 10_000;

// ---- Schema identifiers ----

/// Schema tag stamped on [`TurnLatencyBreakdown`] values.
pub const TURN_LATENCY_BREAKDOWN_SCHEMA_V1: &str = "pi.agent.turn_latency_breakdown.v1";
/// Schema tag stamped on [`ToolEffectBatchPlanEvidence`] values.
pub const TOOL_EFFECT_BATCH_PLAN_SCHEMA_V1: &str = "pi.agent.tool_effect_batch_plan.v1";
const TOOL_CANCELLATION_SCHEMA_V1: &str = "pi.tool.cancellation.v1";
const TOOL_APPROVAL_DENIED_SCHEMA_V1: &str = "pi.tool.approval_denied.v1";
const TOOL_APPROVAL_STATUS_SCHEMA_V1: &str = "pi.tool.approval_status.v1";
const SEMANTIC_CONTEXT_PROMPT_SCHEMA_V1: &str = "pi.semantic_context_prompt.v1";
const SEMANTIC_CONTEXT_PROVENANCE_SCHEMA_V1: &str = "pi.semantic_context_provenance.v1";
const SEMANTIC_CONTEXT_CUSTOM_TYPE: &str = "semantic_context_bundle";

// ---- Semantic-context prompt budget defaults ----

/// Default upper bound (16 KiB) on the size of an injected semantic-context prompt.
const DEFAULT_SEMANTIC_CONTEXT_PROMPT_MAX_BYTES: u64 = 16 * 1024;
/// Default upper bound on the number of items in an injected semantic-context prompt.
const DEFAULT_SEMANTIC_CONTEXT_PROMPT_MAX_ITEMS: usize = 16;
91
92fn compatible_tool_parallelism_limit() -> usize {
93 static LIMIT: OnceLock<usize> = OnceLock::new();
94 *LIMIT.get_or_init(|| {
95 let host_parallelism = std::thread::available_parallelism()
96 .map_or(MIN_COMPATIBLE_TOOL_PARALLELISM, |parallelism| {
97 parallelism.get()
98 });
99 resolve_compatible_tool_parallelism(
100 std::env::var("PI_MAX_CONCURRENT_COMPATIBLE_TOOLS")
101 .ok()
102 .as_deref(),
103 host_parallelism,
104 )
105 })
106}
107
108fn resolve_compatible_tool_parallelism(
109 raw_override: Option<&str>,
110 host_parallelism: usize,
111) -> usize {
112 let host_default = host_parallelism.clamp(
113 MIN_COMPATIBLE_TOOL_PARALLELISM,
114 MAX_AUTO_COMPATIBLE_TOOL_PARALLELISM,
115 );
116
117 let Some(raw) = raw_override.map(str::trim).filter(|raw| !raw.is_empty()) else {
118 return host_default;
119 };
120
121 match raw.parse::<usize>() {
122 Ok(0) => {
123 warn!(
124 value = raw,
125 "Ignoring PI_MAX_CONCURRENT_COMPATIBLE_TOOLS=0; using host-scaled default"
126 );
127 host_default
128 }
129 Ok(limit) => limit.clamp(1, MAX_CONFIGURED_COMPATIBLE_TOOL_PARALLELISM),
130 Err(err) => {
131 warn!(
132 value = raw,
133 error = %err,
134 "Ignoring invalid PI_MAX_CONCURRENT_COMPATIBLE_TOOLS; using host-scaled default"
135 );
136 host_default
137 }
138 }
139}
140
/// Converts `duration` to whole milliseconds, saturating at `u64::MAX` instead of
/// truncating when the `u128` count does not fit.
fn duration_millis_saturating(duration: Duration) -> u64 {
    let millis: u128 = duration.as_millis();
    match u64::try_from(millis) {
        Ok(value) => value,
        Err(_) => u64::MAX,
    }
}
144
/// Converts `duration` to whole microseconds, saturating at `u64::MAX` instead of
/// truncating when the `u128` count does not fit.
fn duration_micros_saturating(duration: Duration) -> u64 {
    let micros: u128 = duration.as_micros();
    match u64::try_from(micros) {
        Ok(value) => value,
        Err(_) => u64::MAX,
    }
}
148
149fn record_global_latency(counter: &crate::session_metrics::TimingCounter, duration: Duration) {
150 if crate::session_metrics::global().enabled() {
151 counter.record(duration_micros_saturating(duration));
152 }
153}
154
155#[derive(Debug, Clone, Serialize, Default)]
157#[serde(rename_all = "camelCase")]
158pub struct LatencyPercentiles {
159 pub p50_ms: u64,
161 pub p95_ms: u64,
163 pub p99_ms: u64,
165 pub p999_ms: u64,
167}
168
169impl LatencyPercentiles {
170 fn from_samples(samples: &[u64]) -> Self {
171 Self {
172 p50_ms: percentile_nearest_rank(samples, 50),
173 p95_ms: percentile_nearest_rank(samples, 95),
174 p99_ms: percentile_nearest_rank(samples, 99),
175 p999_ms: percentile_nearest_rank_per_mille(samples, 999),
176 }
177 }
178}
179
180#[derive(Debug, Clone, Serialize, Default)]
182#[serde(rename_all = "camelCase")]
183pub struct LatencyComponentBreakdown {
184 pub duration_ms: u64,
186 pub samples: usize,
188 pub tail_percentiles: LatencyPercentiles,
190}
191
192impl LatencyComponentBreakdown {
193 #[must_use]
195 pub fn from_millis_samples(samples: &[u64]) -> Self {
196 Self {
197 duration_ms: samples.iter().copied().fold(0u64, u64::saturating_add),
198 samples: samples.len(),
199 tail_percentiles: LatencyPercentiles::from_samples(samples),
200 }
201 }
202}
203
204#[derive(Debug, Clone, Serialize)]
206#[serde(rename_all = "camelCase")]
207pub struct TurnLatencyBreakdown {
208 pub schema: &'static str,
210 pub total_ms: u64,
212 pub provider_streaming: LatencyComponentBreakdown,
214 pub local_tools: LatencyComponentBreakdown,
216 pub extension_hostcalls: LatencyComponentBreakdown,
218 pub persistence: LatencyComponentBreakdown,
220 pub dominant_component: String,
222}
223
224impl TurnLatencyBreakdown {
225 #[must_use]
227 pub fn from_component_samples(
228 total_ms: u64,
229 provider_streaming_ms: &[u64],
230 local_tool_ms: &[u64],
231 extension_hostcall_ms: &[u64],
232 persistence_ms: &[u64],
233 ) -> Self {
234 let provider_streaming =
235 LatencyComponentBreakdown::from_millis_samples(provider_streaming_ms);
236 let local_tools = LatencyComponentBreakdown::from_millis_samples(local_tool_ms);
237 let extension_hostcalls =
238 LatencyComponentBreakdown::from_millis_samples(extension_hostcall_ms);
239 let persistence = LatencyComponentBreakdown::from_millis_samples(persistence_ms);
240 let dominant_component = dominant_latency_component(
241 &provider_streaming,
242 &local_tools,
243 &extension_hostcalls,
244 &persistence,
245 );
246
247 Self {
248 schema: TURN_LATENCY_BREAKDOWN_SCHEMA_V1,
249 total_ms,
250 provider_streaming,
251 local_tools,
252 extension_hostcalls,
253 persistence,
254 dominant_component,
255 }
256 }
257}
258
/// Nearest-rank percentile of `samples`, with `percentile` expressed in whole percent
/// (0..=100).
///
/// Returns 0 for an empty slice. Values above 100 clamp to the maximum sample.
fn percentile_nearest_rank(samples: &[u64], percentile: usize) -> u64 {
    nearest_rank_value(samples, percentile, 100)
}

/// Nearest-rank quantile of `samples`, with the rank expressed in per-mille (0..=1000).
///
/// This allows quantiles such as p99.9 (`permille = 999`). Returns 0 for an empty slice.
fn percentile_nearest_rank_per_mille(samples: &[u64], permille: usize) -> u64 {
    nearest_rank_value(samples, permille, 1000)
}

/// Shared nearest-rank implementation for both percentile variants.
///
/// Selects the sample at 1-based rank `ceil(numerator * len / denominator)`, clamped to
/// `[1, len]`. Only one order statistic is needed, so `select_nth_unstable` (expected
/// O(n)) is used on a scratch copy instead of fully sorting it (O(n log n)). The caller's
/// slice is never reordered.
fn nearest_rank_value(samples: &[u64], numerator: usize, denominator: usize) -> u64 {
    debug_assert!(denominator > 0, "nearest-rank denominator must be non-zero");
    if samples.is_empty() {
        return 0;
    }

    let len = samples.len();
    let index = numerator
        .saturating_mul(len)
        .div_ceil(denominator)
        .saturating_sub(1)
        .min(len - 1);

    let mut scratch = samples.to_vec();
    let (_, value, _) = scratch.select_nth_unstable(index);
    *value
}
290
291fn dominant_latency_component(
292 provider_streaming: &LatencyComponentBreakdown,
293 local_tools: &LatencyComponentBreakdown,
294 extension_hostcalls: &LatencyComponentBreakdown,
295 persistence: &LatencyComponentBreakdown,
296) -> String {
297 [
298 ("provider_streaming", provider_streaming.duration_ms),
299 ("local_tools", local_tools.duration_ms),
300 ("extension_hostcalls", extension_hostcalls.duration_ms),
301 ("persistence", persistence.duration_ms),
302 ]
303 .into_iter()
304 .max_by_key(|(_, duration_ms)| *duration_ms)
305 .filter(|(_, duration_ms)| *duration_ms > 0)
306 .map_or_else(|| "none".to_string(), |(name, _)| name.to_string())
307}
308
309#[derive(Debug)]
310struct TurnLatencyAccumulator {
311 started_at: Instant,
312 provider_streaming_ms: Vec<u64>,
313 local_tool_ms: Vec<u64>,
314 extension_hostcall_ms: Vec<u64>,
315 persistence_ms: Vec<u64>,
316}
317
318impl TurnLatencyAccumulator {
319 fn started() -> Self {
320 Self {
321 started_at: Instant::now(),
322 provider_streaming_ms: Vec::new(),
323 local_tool_ms: Vec::new(),
324 extension_hostcall_ms: Vec::new(),
325 persistence_ms: Vec::new(),
326 }
327 }
328
329 fn snapshot(&self) -> TurnLatencyBreakdown {
330 TurnLatencyBreakdown::from_component_samples(
331 duration_millis_saturating(self.started_at.elapsed()),
332 &self.provider_streaming_ms,
333 &self.local_tool_ms,
334 &self.extension_hostcall_ms,
335 &self.persistence_ms,
336 )
337 }
338}
339
340type SharedTurnLatencyAccumulator = Arc<StdMutex<TurnLatencyAccumulator>>;
341
342fn snapshot_turn_latency(
343 latency: &SharedTurnLatencyAccumulator,
344) -> Option<Box<TurnLatencyBreakdown>> {
345 latency.lock().ok().map(|guard| Box::new(guard.snapshot()))
346}
347
348fn record_provider_streaming_latency(latency: &SharedTurnLatencyAccumulator, duration: Duration) {
349 if let Ok(mut guard) = latency.lock() {
350 guard
351 .provider_streaming_ms
352 .push(duration_millis_saturating(duration));
353 }
354 let metrics = crate::session_metrics::global();
355 record_global_latency(&metrics.provider_streaming, duration);
356}
357
358fn record_local_tool_latency(latency: &SharedTurnLatencyAccumulator, duration: Duration) {
359 if let Ok(mut guard) = latency.lock() {
360 guard
361 .local_tool_ms
362 .push(duration_millis_saturating(duration));
363 }
364 let metrics = crate::session_metrics::global();
365 record_global_latency(&metrics.local_tools, duration);
366}
367
368fn record_extension_hostcall_latency(latency: &SharedTurnLatencyAccumulator, duration: Duration) {
369 if let Ok(mut guard) = latency.lock() {
370 guard
371 .extension_hostcall_ms
372 .push(duration_millis_saturating(duration));
373 }
374 let metrics = crate::session_metrics::global();
375 record_global_latency(&metrics.extension_hostcalls, duration);
376}
377
/// Half-open index range `[start, end)` of consecutive tool calls whose effects are
/// mutually compatible and are therefore planned as one batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ToolEffectBatch {
    /// Index of the first tool call in the batch.
    start: usize,
    /// One past the index of the last tool call in the batch.
    end: usize,
}
383
384#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
386#[serde(rename_all = "camelCase")]
387pub struct ToolEffectBatchEvidence {
388 pub start: usize,
390 pub end: usize,
392 pub len: usize,
394 pub combined_effects: Vec<&'static str>,
396 pub parallel_safe: bool,
398 #[serde(skip_serializing_if = "Option::is_none")]
400 pub barrier_reason: Option<&'static str>,
401}
402
403#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
405#[serde(rename_all = "camelCase")]
406pub struct ToolEffectBatchPlanEvidence {
407 pub schema: &'static str,
409 pub tool_count: usize,
411 pub parallelism_cap: usize,
413 pub batches: Vec<ToolEffectBatchEvidence>,
415}
416
417fn plan_tool_effect_batches(effects: &[ToolEffects]) -> Vec<ToolEffectBatch> {
418 let Some((&first_effects, remaining_effects)) = effects.split_first() else {
419 return Vec::new();
420 };
421
422 let mut batches = Vec::new();
423 let mut start = 0;
424 let mut active_effects = first_effects;
425
426 for (offset, candidate_effects) in remaining_effects.iter().copied().enumerate() {
427 let index = offset + 1;
428 if active_effects.compatible_with(candidate_effects) {
429 active_effects = active_effects.union(candidate_effects);
430 } else {
431 batches.push(ToolEffectBatch { start, end: index });
432 start = index;
433 active_effects = candidate_effects;
434 }
435 }
436
437 batches.push(ToolEffectBatch {
438 start,
439 end: effects.len(),
440 });
441 batches
442}
443
444fn combined_tool_effects(effects: &[ToolEffects]) -> Option<ToolEffects> {
445 effects.iter().copied().reduce(ToolEffects::union)
446}
447
448const fn tool_effect_barrier_reason(effects: ToolEffects) -> Option<&'static str> {
449 if effects.parallel_safe() {
450 return None;
451 }
452 match (effects.writes(), effects.appends(), effects.processes()) {
453 (true, true, true) => Some("write_append_process_barrier"),
454 (true, true, false) => Some("write_append_barrier"),
455 (true, false, true) => Some("write_process_barrier"),
456 (false, true, true) => Some("append_process_barrier"),
457 (true, false, false) => Some("write_barrier"),
458 (false, true, false) => Some("append_barrier"),
459 (false, false, true) => Some("process_barrier"),
460 (false, false, false) => Some("undeclared_effects_barrier"),
461 }
462}
463
464#[must_use]
466pub fn tool_effect_batch_plan_evidence(
467 effects: &[ToolEffects],
468 parallelism_cap: usize,
469) -> ToolEffectBatchPlanEvidence {
470 let batches = plan_tool_effect_batches(effects)
471 .into_iter()
472 .map(|batch| {
473 let combined_effects = effects
474 .get(batch.start..batch.end)
475 .and_then(combined_tool_effects)
476 .unwrap_or_else(ToolEffects::read);
477 ToolEffectBatchEvidence {
478 start: batch.start,
479 end: batch.end,
480 len: batch.end.saturating_sub(batch.start),
481 combined_effects: combined_effects.labels(),
482 parallel_safe: combined_effects.parallel_safe(),
483 barrier_reason: tool_effect_barrier_reason(combined_effects),
484 }
485 })
486 .collect();
487
488 ToolEffectBatchPlanEvidence {
489 schema: TOOL_EFFECT_BATCH_PLAN_SCHEMA_V1,
490 tool_count: effects.len(),
491 parallelism_cap,
492 batches,
493 }
494}
495
/// Tool-iteration cap used when no valid override is supplied.
pub const MAX_TOOL_ITERATIONS_DEFAULT: usize = 50;

/// Largest tool-iteration cap accepted from `PI_MAX_TOOL_ITERATIONS` or
/// `--max-tool-iterations`; larger values are clamped to this.
pub const MAX_TOOL_ITERATIONS_CEILING: usize = 1_000;

/// Numerator of the 4/5 (80%) budget fraction at which the handoff warning fires.
const ITERATION_WARN_NUMERATOR: usize = 4;
/// Denominator of the 4/5 (80%) budget fraction.
const ITERATION_WARN_DENOMINATOR: usize = 5;

/// Iteration caps below this never trigger the handoff warning.
const ITERATION_WARN_MIN_CAP: usize = 5;
527
528pub fn resolved_max_tool_iterations_default() -> usize {
534 resolve_max_tool_iterations(std::env::var("PI_MAX_TOOL_ITERATIONS").ok().as_deref())
535}
536
537pub fn resolve_max_tool_iterations(raw_override: Option<&str>) -> usize {
543 let Some(raw) = raw_override.map(str::trim).filter(|raw| !raw.is_empty()) else {
544 return MAX_TOOL_ITERATIONS_DEFAULT;
545 };
546 match raw.parse::<usize>() {
547 Ok(0) => {
548 warn!(
549 "PI_MAX_TOOL_ITERATIONS=0 is invalid; falling back to {}",
550 MAX_TOOL_ITERATIONS_DEFAULT
551 );
552 MAX_TOOL_ITERATIONS_DEFAULT
553 }
554 Ok(n) if n > MAX_TOOL_ITERATIONS_CEILING => {
555 warn!(
556 "PI_MAX_TOOL_ITERATIONS={n} exceeds ceiling {MAX_TOOL_ITERATIONS_CEILING}; clamping to {MAX_TOOL_ITERATIONS_CEILING}"
557 );
558 MAX_TOOL_ITERATIONS_CEILING
559 }
560 Ok(n) => n,
561 Err(err) => {
562 warn!(
563 "PI_MAX_TOOL_ITERATIONS={raw:?} is not a valid usize ({err}); falling back to {}",
564 MAX_TOOL_ITERATIONS_DEFAULT
565 );
566 MAX_TOOL_ITERATIONS_DEFAULT
567 }
568 }
569}
570
571pub fn clamp_max_tool_iterations(value: Option<usize>) -> usize {
578 match value {
579 None => MAX_TOOL_ITERATIONS_DEFAULT,
580 Some(0) => {
581 warn!(
582 "--max-tool-iterations=0 is invalid; falling back to {}",
583 MAX_TOOL_ITERATIONS_DEFAULT
584 );
585 MAX_TOOL_ITERATIONS_DEFAULT
586 }
587 Some(n) if n > MAX_TOOL_ITERATIONS_CEILING => {
588 warn!(
589 "--max-tool-iterations={n} exceeds ceiling {MAX_TOOL_ITERATIONS_CEILING}; clamping to {MAX_TOOL_ITERATIONS_CEILING}"
590 );
591 MAX_TOOL_ITERATIONS_CEILING
592 }
593 Some(n) => n,
594 }
595}
596
597pub const fn should_warn_at_iteration_threshold(current: usize, max: usize) -> bool {
608 max >= ITERATION_WARN_MIN_CAP
609 && current >= max.saturating_mul(ITERATION_WARN_NUMERATOR) / ITERATION_WARN_DENOMINATOR
610}
611
/// Builds the `[runtime]` steering message telling the model that its tool-iteration
/// budget is nearly spent and that it should start a graceful handoff.
pub fn iteration_handoff_steering_text(current: usize, max: usize) -> String {
    const HANDOFF_INSTRUCTIONS: &str = "Per the iteration-aware-handoff protocol in your spec, \
        begin graceful handoff now: commit current work, post a one-line status note, and \
        write an incomplete-handoff envelope with what's done / what remains / next-agent \
        starting position. Do NOT compress remaining work into the last few iterations.";

    format!(
        "[runtime] Tool-iteration budget at >=80% (used {current} of {max}). {HANDOFF_INSTRUCTIONS}"
    )
}
625
626#[derive(Clone)]
628pub struct AgentConfig {
629 pub system_prompt: Option<String>,
631
632 pub max_tool_iterations: usize,
634
635 pub stream_options: StreamOptions,
637
638 pub block_images: bool,
640
641 pub fail_closed_hooks: bool,
643
644 pub tool_approval: Option<ToolApprovalHandler>,
646}
647
648impl fmt::Debug for AgentConfig {
649 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
650 f.debug_struct("AgentConfig")
651 .field("system_prompt", &self.system_prompt)
652 .field("max_tool_iterations", &self.max_tool_iterations)
653 .field("stream_options", &self.stream_options)
654 .field("block_images", &self.block_images)
655 .field("fail_closed_hooks", &self.fail_closed_hooks)
656 .field("tool_approval", &self.tool_approval.is_some())
657 .finish()
658 }
659}
660
661#[derive(Debug, Clone)]
663pub struct ToolApprovalRequest {
664 pub tool_call_id: String,
665 pub tool_name: String,
666 pub arguments: Value,
667}
668
/// Verdict returned by a tool-approval handler.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolApprovalDecision {
    /// The tool call may proceed.
    Allow,
    /// The tool call is refused for the given reason.
    Deny { reason: String },
}

impl ToolApprovalDecision {
    /// Shorthand for building a [`ToolApprovalDecision::Deny`].
    #[must_use]
    pub fn deny(reason: impl Into<String>) -> Self {
        let reason: String = reason.into();
        Self::Deny { reason }
    }
}
684
685pub type ToolApprovalHandler =
686 Arc<dyn Fn(ToolApprovalRequest) -> BoxFuture<'static, ToolApprovalDecision> + Send + Sync>;
687
688impl Default for AgentConfig {
689 fn default() -> Self {
690 Self {
691 system_prompt: None,
692 max_tool_iterations: resolved_max_tool_iterations_default(),
693 stream_options: StreamOptions::default(),
694 block_images: false,
695 fail_closed_hooks: false,
696 tool_approval: None,
697 }
698 }
699}
700
701#[derive(Debug, Clone)]
703pub struct SemanticContextBundleInjection {
704 pub enabled: bool,
705 pub bundle: SemanticContextBundle,
706 pub max_prompt_items: usize,
707 pub max_prompt_bytes: u64,
708 pub include_exclusion_summary: bool,
709 pub include_validation_commands: bool,
710}
711
712impl SemanticContextBundleInjection {
713 pub fn enabled(bundle: SemanticContextBundle) -> Self {
714 let max_prompt_items = bundle
715 .budget
716 .max_items
717 .min(DEFAULT_SEMANTIC_CONTEXT_PROMPT_MAX_ITEMS);
718 let max_prompt_bytes = bundle
719 .budget
720 .max_bytes
721 .min(DEFAULT_SEMANTIC_CONTEXT_PROMPT_MAX_BYTES);
722 Self {
723 enabled: true,
724 bundle,
725 max_prompt_items,
726 max_prompt_bytes,
727 include_exclusion_summary: true,
728 include_validation_commands: true,
729 }
730 }
731
732 pub const fn disabled(bundle: SemanticContextBundle) -> Self {
733 Self {
734 enabled: false,
735 bundle,
736 max_prompt_items: DEFAULT_SEMANTIC_CONTEXT_PROMPT_MAX_ITEMS,
737 max_prompt_bytes: DEFAULT_SEMANTIC_CONTEXT_PROMPT_MAX_BYTES,
738 include_exclusion_summary: true,
739 include_validation_commands: true,
740 }
741 }
742
743 #[must_use]
744 pub const fn with_prompt_budget(mut self, max_items: usize, max_bytes: u64) -> Self {
745 self.max_prompt_items = max_items;
746 self.max_prompt_bytes = max_bytes;
747 self
748 }
749}
750
751#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
752#[serde(rename_all = "snake_case")]
753pub enum SemanticContextPromptShape {
754 CustomUserMessage,
755 SystemPromptAppend,
756}
757
758#[derive(Debug, Clone)]
759struct PreparedSemanticContextPrompt {
760 prompt: String,
761 revision: String,
762 shape: SemanticContextPromptShape,
763 details: Value,
764}
765
/// Item-count and byte limits applied while rendering a semantic-context prompt.
#[derive(Debug, Clone, Copy)]
struct SemanticContextPromptBudget {
    /// Maximum number of items.
    max_items: usize,
    /// Maximum size in bytes.
    max_bytes: u64,
}
771
/// Counters describing what made it into a rendered semantic-context prompt.
#[derive(Debug, Default, Clone, Copy)]
struct SemanticContextPromptStats {
    /// Selected items that were rendered.
    selected_items_included: usize,
    /// Selected items left out.
    selected_items_omitted: usize,
    /// Validation commands that were rendered.
    validation_commands_included: usize,
    /// Validation commands left out.
    validation_commands_omitted: usize,
    /// Exclusions that were rendered.
    exclusions_included: usize,
    /// Exclusions left out.
    exclusions_omitted: usize,
    /// Whether any content was cut short.
    truncated: bool,
}
782
783pub type MessageFetcher = Arc<dyn Fn() -> BoxFuture<'static, Vec<Message>> + Send + Sync + 'static>;
785
786type AgentEventHandler = Arc<dyn Fn(AgentEvent) + Send + Sync + 'static>;
787
/// How many queued messages a single drain releases.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueueMode {
    /// Release every pending message at once.
    All,
    /// Release only the oldest pending message.
    OneAtATime,
}

impl QueueMode {
    /// Stable textual label: `"all"` or `"one-at-a-time"`.
    pub const fn as_str(self) -> &'static str {
        let label = match self {
            QueueMode::All => "all",
            QueueMode::OneAtATime => "one-at-a-time",
        };
        label
    }
}
802
/// Channel through which user input arrived.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InputSource {
    /// Typed interactively.
    Interactive,
    /// Received over RPC.
    Rpc,
    /// Injected by an extension.
    Extension,
}

impl InputSource {
    /// Stable textual label: `"interactive"`, `"rpc"`, or `"extension"`.
    pub const fn as_str(self) -> &'static str {
        let label = match self {
            InputSource::Interactive => "interactive",
            InputSource::Rpc => "rpc",
            InputSource::Extension => "extension",
        };
        label
    }
}
819
/// Selects one of the two message queues.
#[derive(Debug, Clone, Copy)]
enum QueueKind {
    /// The steering queue.
    Steering,
    /// The follow-up queue.
    FollowUp,
}
825
826#[derive(Debug, Clone)]
827struct QueuedMessage {
828 seq: u64,
829 enqueued_at: i64,
830 message: Message,
831}
832
833#[derive(Debug)]
834struct MessageQueue {
835 steering: VecDeque<QueuedMessage>,
836 follow_up: VecDeque<QueuedMessage>,
837 steering_mode: QueueMode,
838 follow_up_mode: QueueMode,
839 next_seq: u64,
840}
841
842impl MessageQueue {
843 const fn new(steering_mode: QueueMode, follow_up_mode: QueueMode) -> Self {
844 Self {
845 steering: VecDeque::new(),
846 follow_up: VecDeque::new(),
847 steering_mode,
848 follow_up_mode,
849 next_seq: 0,
850 }
851 }
852
853 const fn set_modes(&mut self, steering_mode: QueueMode, follow_up_mode: QueueMode) {
854 self.steering_mode = steering_mode;
855 self.follow_up_mode = follow_up_mode;
856 }
857
858 fn pending_count(&self) -> usize {
859 self.steering.len() + self.follow_up.len()
860 }
861
862 fn push(&mut self, kind: QueueKind, message: Message) -> u64 {
863 let seq = self.next_seq;
864 self.next_seq = self.next_seq.saturating_add(1);
865 let entry = QueuedMessage {
866 seq,
867 enqueued_at: Utc::now().timestamp_millis(),
868 message,
869 };
870 match kind {
871 QueueKind::Steering => {
872 if self.steering.len() >= MAX_STEERING_QUEUE_SIZE {
873 tracing::warn!(
874 "Steering queue full ({} messages), dropping oldest message",
875 MAX_STEERING_QUEUE_SIZE
876 );
877 self.steering.pop_front();
878 }
879 self.steering.push_back(entry);
880 }
881 QueueKind::FollowUp => {
882 if self.follow_up.len() >= MAX_FOLLOW_UP_QUEUE_SIZE {
883 tracing::warn!(
884 "Follow-up queue full ({} messages), dropping oldest message",
885 MAX_FOLLOW_UP_QUEUE_SIZE
886 );
887 self.follow_up.pop_front();
888 }
889 self.follow_up.push_back(entry);
890 }
891 }
892 seq
893 }
894
895 fn push_steering(&mut self, message: Message) -> u64 {
896 self.push(QueueKind::Steering, message)
897 }
898
899 fn push_follow_up(&mut self, message: Message) -> u64 {
900 self.push(QueueKind::FollowUp, message)
901 }
902
903 fn pop_steering(&mut self) -> Vec<Message> {
904 self.pop_kind(QueueKind::Steering)
905 }
906
907 fn pop_follow_up(&mut self) -> Vec<Message> {
908 self.pop_kind(QueueKind::FollowUp)
909 }
910
911 fn pop_kind(&mut self, kind: QueueKind) -> Vec<Message> {
912 let (queue, mode) = match kind {
913 QueueKind::Steering => (&mut self.steering, self.steering_mode),
914 QueueKind::FollowUp => (&mut self.follow_up, self.follow_up_mode),
915 };
916
917 match mode {
918 QueueMode::All => queue.drain(..).map(|entry| entry.message).collect(),
919 QueueMode::OneAtATime => queue
920 .pop_front()
921 .into_iter()
922 .map(|entry| entry.message)
923 .collect(),
924 }
925 }
926}
927
928#[derive(Debug, Clone, Serialize)]
934#[serde(tag = "type", rename_all = "snake_case")]
935pub enum AgentEvent {
936 AgentStart {
938 #[serde(rename = "sessionId")]
939 session_id: Arc<str>,
940 },
941 AgentEnd {
943 #[serde(rename = "sessionId")]
944 session_id: Arc<str>,
945 messages: Vec<Message>,
946 #[serde(skip_serializing_if = "Option::is_none")]
947 error: Option<String>,
948 },
949 TurnStart {
951 #[serde(rename = "sessionId")]
952 session_id: Arc<str>,
953 #[serde(rename = "turnIndex")]
954 turn_index: usize,
955 timestamp: i64,
956 },
957 TurnEnd {
959 #[serde(rename = "sessionId")]
960 session_id: Arc<str>,
961 #[serde(rename = "turnIndex")]
962 turn_index: usize,
963 message: Message,
964 #[serde(rename = "toolResults")]
965 tool_results: Vec<Message>,
966 #[serde(rename = "latencyBreakdown", skip_serializing_if = "Option::is_none")]
967 latency_breakdown: Option<Box<TurnLatencyBreakdown>>,
968 },
969 MessageStart { message: Message },
971 MessageUpdate {
973 message: Message,
974 #[serde(rename = "assistantMessageEvent")]
975 assistant_message_event: AssistantMessageEvent,
976 },
977 MessageEnd { message: Message },
979 ToolExecutionStart {
981 #[serde(rename = "toolCallId")]
982 tool_call_id: String,
983 #[serde(rename = "toolName")]
984 tool_name: String,
985 args: serde_json::Value,
986 },
987 ToolExecutionUpdate {
989 #[serde(rename = "toolCallId")]
990 tool_call_id: String,
991 #[serde(rename = "toolName")]
992 tool_name: String,
993 args: serde_json::Value,
994 #[serde(rename = "partialResult")]
995 partial_result: ToolOutput,
996 },
997 ToolExecutionEnd {
999 #[serde(rename = "toolCallId")]
1000 tool_call_id: String,
1001 #[serde(rename = "toolName")]
1002 tool_name: String,
1003 result: ToolOutput,
1004 #[serde(rename = "isError")]
1005 is_error: bool,
1006 },
1007 AutoCompactionStart { reason: String },
1009 AutoCompactionEnd {
1011 #[serde(skip_serializing_if = "Option::is_none")]
1012 result: Option<serde_json::Value>,
1013 aborted: bool,
1014 #[serde(rename = "willRetry")]
1015 will_retry: bool,
1016 #[serde(rename = "errorMessage", skip_serializing_if = "Option::is_none")]
1017 error_message: Option<String>,
1018 },
1019 AutoRetryStart {
1021 attempt: u32,
1022 #[serde(rename = "maxAttempts")]
1023 max_attempts: u32,
1024 #[serde(rename = "delayMs")]
1025 delay_ms: u64,
1026 #[serde(rename = "errorMessage")]
1027 error_message: String,
1028 },
1029 AutoRetryEnd {
1031 success: bool,
1032 attempt: u32,
1033 #[serde(rename = "finalError", skip_serializing_if = "Option::is_none")]
1034 final_error: Option<String>,
1035 },
1036 ExtensionError {
1038 #[serde(rename = "extensionId", skip_serializing_if = "Option::is_none")]
1039 extension_id: Option<String>,
1040 event: String,
1041 error: String,
1042 },
1043}
1044
1045#[derive(Debug, Clone)]
1051pub struct AbortHandle {
1052 inner: Arc<AbortSignalInner>,
1053}
1054
1055#[derive(Debug, Clone)]
1057pub struct AbortSignal {
1058 inner: Arc<AbortSignalInner>,
1059}
1060
1061#[derive(Debug)]
1062struct AbortSignalInner {
1063 aborted: AtomicBool,
1064 notify: Notify,
1065}
1066
1067impl AbortHandle {
1068 #[must_use]
1070 pub fn new() -> (Self, AbortSignal) {
1071 let inner = Arc::new(AbortSignalInner {
1072 aborted: AtomicBool::new(false),
1073 notify: Notify::new(),
1074 });
1075 (
1076 Self {
1077 inner: Arc::clone(&inner),
1078 },
1079 AbortSignal { inner },
1080 )
1081 }
1082
1083 pub fn abort(&self) {
1085 if !self.inner.aborted.swap(true, Ordering::SeqCst) {
1086 self.inner.notify.notify_waiters();
1087 }
1088 }
1089}
1090
1091impl AbortSignal {
1092 #[must_use]
1094 pub fn is_aborted(&self) -> bool {
1095 self.inner.aborted.load(Ordering::SeqCst)
1096 }
1097
1098 pub async fn wait(&self) {
1099 if self.is_aborted() {
1100 return;
1101 }
1102
1103 loop {
1104 self.inner.notify.notified().await;
1105 if self.is_aborted() {
1106 return;
1107 }
1108 }
1109 }
1110}
1111
1112pub struct Agent {
1114 provider: Arc<dyn Provider>,
1116
1117 tools: ToolRegistry,
1119
1120 config: AgentConfig,
1122
1123 extensions: Option<ExtensionManager>,
1125
1126 messages: Vec<Message>,
1128
1129 steering_fetchers: Vec<MessageFetcher>,
1131
1132 follow_up_fetchers: Vec<MessageFetcher>,
1134
1135 message_queue: MessageQueue,
1137
1138 cached_tool_defs: Option<Vec<ToolDef>>,
1140}
1141
1142impl Agent {
1143 pub fn new(provider: Arc<dyn Provider>, tools: ToolRegistry, config: AgentConfig) -> Self {
1145 Self {
1146 provider,
1147 tools,
1148 config,
1149 extensions: None,
1150 messages: Vec::new(),
1151 steering_fetchers: Vec::new(),
1152 follow_up_fetchers: Vec::new(),
1153 message_queue: MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime),
1154 cached_tool_defs: None,
1155 }
1156 }
1157
1158 #[must_use]
1160 pub fn messages(&self) -> &[Message] {
1161 &self.messages
1162 }
1163
1164 pub fn clear_messages(&mut self) {
1166 self.messages.clear();
1167 }
1168
1169 pub fn add_message(&mut self, message: Message) {
1171 if self.messages.len() >= MAX_AGENT_MESSAGES {
1172 tracing::warn!(
1173 "Agent message history full ({} messages), dropping oldest message",
1174 MAX_AGENT_MESSAGES
1175 );
1176 self.messages.remove(0);
1177 }
1178 self.messages.push(message);
1179 }
1180
1181 pub fn replace_messages(&mut self, messages: Vec<Message>) {
1183 self.messages = messages;
1184 }
1185
1186 pub fn set_provider(&mut self, provider: Arc<dyn Provider>) {
1188 self.provider = provider;
1189 }
1190
1191 pub fn register_message_fetchers(
1196 &mut self,
1197 steering: Option<MessageFetcher>,
1198 follow_up: Option<MessageFetcher>,
1199 ) {
1200 if let Some(fetcher) = steering {
1201 self.steering_fetchers.push(fetcher);
1202 }
1203 if let Some(fetcher) = follow_up {
1204 self.follow_up_fetchers.push(fetcher);
1205 }
1206 }
1207
1208 pub fn extend_tools<I>(&mut self, tools: I)
1210 where
1211 I: IntoIterator<Item = Box<dyn Tool>>,
1212 {
1213 self.tools.extend(tools);
1214 self.cached_tool_defs = None; }
1216
1217 pub fn queue_steering(&mut self, message: Message) -> u64 {
1219 self.message_queue.push_steering(message)
1220 }
1221
1222 pub fn queue_follow_up(&mut self, message: Message) -> u64 {
1224 self.message_queue.push_follow_up(message)
1225 }
1226
1227 pub const fn set_queue_modes(&mut self, steering: QueueMode, follow_up: QueueMode) {
1229 self.message_queue.set_modes(steering, follow_up);
1230 }
1231
1232 pub const fn queue_modes(&self) -> (QueueMode, QueueMode) {
1233 (
1234 self.message_queue.steering_mode,
1235 self.message_queue.follow_up_mode,
1236 )
1237 }
1238
1239 #[must_use]
1241 pub fn queued_message_count(&self) -> usize {
1242 self.message_queue.pending_count()
1243 }
1244
1245 pub fn provider(&self) -> Arc<dyn Provider> {
1246 Arc::clone(&self.provider)
1247 }
1248
1249 pub const fn stream_options(&self) -> &StreamOptions {
1250 &self.config.stream_options
1251 }
1252
1253 pub const fn stream_options_mut(&mut self) -> &mut StreamOptions {
1254 &mut self.config.stream_options
1255 }
1256
1257 pub fn system_prompt(&self) -> Option<&str> {
1258 self.config.system_prompt.as_deref()
1259 }
1260
1261 pub fn set_system_prompt(&mut self, system_prompt: Option<String>) {
1262 self.config.system_prompt = system_prompt;
1263 }
1264
1265 fn build_context(&mut self) -> Context<'_> {
1267 let messages: Cow<'_, [Message]> = if self.config.block_images {
1268 let mut msgs = self.messages.clone();
1269 msgs.retain(|m| match m {
1271 Message::Custom(c) => c.display,
1272 _ => true,
1273 });
1274 let stats = filter_images_for_provider(&mut msgs);
1275 if stats.removed_images > 0 {
1276 tracing::debug!(
1277 filtered_images = stats.removed_images,
1278 affected_messages = stats.affected_messages,
1279 "Filtered image content from outbound provider context (images.block_images=true)"
1280 );
1281 }
1282 Cow::Owned(msgs)
1283 } else {
1284 let has_hidden = self.messages.iter().any(|m| match m {
1286 Message::Custom(c) => !c.display,
1287 _ => false,
1288 });
1289
1290 if has_hidden {
1291 let mut msgs = self.messages.clone();
1292 msgs.retain(|m| match m {
1293 Message::Custom(c) => c.display,
1294 _ => true,
1295 });
1296 Cow::Owned(msgs)
1297 } else {
1298 Cow::Borrowed(self.messages.as_slice())
1299 }
1300 };
1301
1302 if self.cached_tool_defs.is_none() {
1304 let defs: Vec<ToolDef> = self
1305 .tools
1306 .tools()
1307 .iter()
1308 .map(|t| ToolDef {
1309 name: t.name().to_string(),
1310 description: t.description().to_string(),
1311 parameters: t.parameters(),
1312 })
1313 .collect();
1314 self.cached_tool_defs = Some(defs);
1315 }
1316 let tools = Cow::Borrowed(self.cached_tool_defs.as_deref().unwrap());
1317
1318 Context {
1319 system_prompt: self.config.system_prompt.as_deref().map(Cow::Borrowed),
1320 messages,
1321 tools,
1322 }
1323 }
1324
1325 pub async fn run(
1329 &mut self,
1330 user_input: impl Into<String>,
1331 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1332 ) -> Result<AssistantMessage> {
1333 self.run_with_abort(user_input, None, on_event).await
1334 }
1335
1336 pub async fn run_with_abort(
1338 &mut self,
1339 user_input: impl Into<String>,
1340 abort: Option<AbortSignal>,
1341 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1342 ) -> Result<AssistantMessage> {
1343 let user_message = Message::User(UserMessage {
1345 content: UserContent::Text(user_input.into()),
1346 timestamp: Utc::now().timestamp_millis(),
1347 });
1348
1349 self.run_loop(vec![user_message], Arc::new(on_event), abort)
1351 .await
1352 }
1353
1354 pub async fn run_with_content(
1356 &mut self,
1357 content: Vec<ContentBlock>,
1358 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1359 ) -> Result<AssistantMessage> {
1360 self.run_with_content_with_abort(content, None, on_event)
1361 .await
1362 }
1363
1364 pub async fn run_with_content_with_abort(
1366 &mut self,
1367 content: Vec<ContentBlock>,
1368 abort: Option<AbortSignal>,
1369 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1370 ) -> Result<AssistantMessage> {
1371 let user_message = Message::User(UserMessage {
1373 content: UserContent::Blocks(content),
1374 timestamp: Utc::now().timestamp_millis(),
1375 });
1376
1377 self.run_loop(vec![user_message], Arc::new(on_event), abort)
1379 .await
1380 }
1381
1382 pub async fn run_with_message_with_abort(
1384 &mut self,
1385 message: Message,
1386 abort: Option<AbortSignal>,
1387 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1388 ) -> Result<AssistantMessage> {
1389 self.run_loop(vec![message], Arc::new(on_event), abort)
1390 .await
1391 }
1392
1393 pub async fn run_with_messages_with_abort(
1395 &mut self,
1396 messages: Vec<Message>,
1397 abort: Option<AbortSignal>,
1398 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1399 ) -> Result<AssistantMessage> {
1400 self.run_loop(messages, Arc::new(on_event), abort).await
1401 }
1402
1403 pub async fn run_continue_with_abort(
1405 &mut self,
1406 abort: Option<AbortSignal>,
1407 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
1408 ) -> Result<AssistantMessage> {
1409 self.run_loop(Vec::new(), Arc::new(on_event), abort).await
1410 }
1411
1412 fn build_abort_message(&self, partial: Option<&AssistantMessage>) -> AssistantMessage {
1413 let mut message = partial.cloned().unwrap_or_else(|| AssistantMessage {
1414 content: Vec::new(),
1415 api: self.provider.api().to_string(),
1416 provider: self.provider.name().to_string(),
1417 model: self.provider.model_id().to_string(),
1418 usage: Usage::default(),
1419 stop_reason: StopReason::Aborted,
1420 error_message: Some("Aborted".to_string()),
1421 timestamp: Utc::now().timestamp_millis(),
1422 });
1423 message.stop_reason = StopReason::Aborted;
1424 message.error_message = Some("Aborted".to_string());
1425 message.timestamp = Utc::now().timestamp_millis();
1426 message
1427 }
1428
1429 fn build_error_message(
1430 &self,
1431 partial: Option<&AssistantMessage>,
1432 error_message: impl Into<String>,
1433 ) -> AssistantMessage {
1434 let error_message = error_message.into();
1435 let mut message = partial.cloned().unwrap_or_else(|| AssistantMessage {
1436 content: Vec::new(),
1437 api: self.provider.api().to_string(),
1438 provider: self.provider.name().to_string(),
1439 model: self.provider.model_id().to_string(),
1440 usage: Usage::default(),
1441 stop_reason: StopReason::Error,
1442 error_message: Some(error_message.clone()),
1443 timestamp: Utc::now().timestamp_millis(),
1444 });
1445 message.stop_reason = StopReason::Error;
1446 message.error_message = Some(error_message);
1447 message.timestamp = Utc::now().timestamp_millis();
1448 message
1449 }
1450
/// Drives the agent conversation for one `run*` invocation and returns the
/// final assistant message.
///
/// Flow, as implemented below:
/// 1. Dispatch/emit `AgentStart`. Then append every `prompts` message to
///    `self.messages`, emitting `MessageStart`/`MessageEnd` for each.
/// 2. Inner loop (one iteration per turn): emit `TurnStart` and append pending
///    steering messages. Then stream an assistant response and execute any
///    tool calls it contains, and finally emit `TurnEnd`. Turns repeat while
///    the assistant keeps requesting tools or steering messages are pending.
/// 3. Outer loop: when a turn ends with no tool calls and no pending steering,
///    follow-up messages are drained. If there are any, another round of turns
///    starts; otherwise the loop ends and `AgentEnd` is emitted with
///    `error: None`.
///
/// Early exits emit `TurnEnd` and then `AgentEnd` (carrying an error string)
/// before returning:
/// - abort signal observed at the start of a turn -> `Ok(aborted message)`;
/// - provider streaming error -> `Err(err)`, after recording an error message;
/// - assistant stop reason `Error`/`Aborted` -> `Ok(that message)`;
/// - tool-iteration budget exceeded -> `Ok(message with tool calls removed)`;
/// - tool execution error -> `Err(err)`.
///
/// Every lifecycle event (`AgentStart`/`TurnStart`/`TurnEnd`/`AgentEnd`) is
/// dispatched to extensions first and then passed to `on_event`.
///
/// # Errors
/// Returns the streaming or tool-execution error, or an API error if the loop
/// finishes without ever producing an assistant message.
// Long by design: the early-exit paths must emit events in a precise order.
#[allow(clippy::too_many_lines)]
async fn run_loop(
    &mut self,
    prompts: Vec<Message>,
    on_event: AgentEventHandler,
    abort: Option<AbortSignal>,
) -> Result<AssistantMessage> {
    // Cancellation/checkpoint context, handed to `stream_assistant_response`.
    let loop_cx = crate::agent_cx::AgentCx::for_current_or_request();
    // A missing session id is reported as the empty string in events.
    let session_id: Arc<str> = self
        .config
        .stream_options
        .session_id
        .as_deref()
        .unwrap_or("")
        .into();
    // Count of turns whose assistant response contained tool calls. It spans
    // the whole run and is not reset between follow-up rounds.
    let mut iterations = 0usize;
    // Ensures the "approaching budget" steering message is injected only once.
    let mut warned_at_handoff_threshold = false;
    let mut turn_index: usize = 0;
    // Messages produced during this run; reported in `AgentEnd`.
    let mut new_messages: Vec<Message> = Vec::with_capacity(prompts.len() + 8);
    let mut last_assistant: Option<Arc<AssistantMessage>> = None;

    let agent_start_event = AgentEvent::AgentStart {
        session_id: session_id.clone(),
    };
    self.dispatch_extension_lifecycle_event(&agent_start_event)
        .await;
    on_event(agent_start_event);

    for prompt in prompts {
        self.messages.push(prompt.clone());
        on_event(AgentEvent::MessageStart {
            message: prompt.clone(),
        });
        on_event(AgentEvent::MessageEnd {
            message: prompt.clone(),
        });
        new_messages.push(prompt);
    }

    let mut pending_messages = self.drain_steering_messages().await;

    loop {
        // Starts `true` so every round runs at least one turn.
        let mut has_more_tool_calls = true;
        // Steering collected by `execute_tool_calls`, if it returned any.
        let mut steering_after_tools: Option<Vec<Message>> = None;

        while has_more_tool_calls || !pending_messages.is_empty() {
            let current_turn_index = turn_index;
            let turn_latency = Arc::new(StdMutex::new(TurnLatencyAccumulator::started()));
            let turn_start_event = AgentEvent::TurnStart {
                session_id: session_id.clone(),
                turn_index: current_turn_index,
                timestamp: Utc::now().timestamp_millis(),
            };
            self.dispatch_extension_lifecycle_event(&turn_start_event)
                .await;
            on_event(turn_start_event);

            for message in std::mem::take(&mut pending_messages) {
                self.messages.push(message.clone());
                on_event(AgentEvent::MessageStart {
                    message: message.clone(),
                });
                on_event(AgentEvent::MessageEnd {
                    message: message.clone(),
                });
                new_messages.push(message);
            }

            // Abort observed before streaming: record a content-less aborted
            // assistant message, close the turn and the run, and return it.
            if abort.as_ref().is_some_and(AbortSignal::is_aborted) {
                let abort_message = self.build_abort_message(None);
                let message = Message::assistant(abort_message.clone());

                self.messages.push(message.clone());
                new_messages.push(message.clone());
                on_event(AgentEvent::MessageStart {
                    message: message.clone(),
                });
                on_event(AgentEvent::MessageEnd {
                    message: message.clone(),
                });

                let turn_end_event = AgentEvent::TurnEnd {
                    session_id: session_id.clone(),
                    turn_index: current_turn_index,
                    message,
                    tool_results: Vec::new(),
                    latency_breakdown: snapshot_turn_latency(&turn_latency),
                };
                self.dispatch_extension_lifecycle_event(&turn_end_event)
                    .await;
                on_event(turn_end_event);
                let agent_end_event = AgentEvent::AgentEnd {
                    session_id: session_id.clone(),
                    messages: std::mem::take(&mut new_messages),
                    error: Some(
                        abort_message
                            .error_message
                            .clone()
                            .unwrap_or_else(|| "Aborted".to_string()),
                    ),
                };
                self.dispatch_extension_lifecycle_event(&agent_end_event)
                    .await;
                on_event(agent_end_event);
                return Ok(abort_message);
            }

            let provider_streaming_started_at = Instant::now();
            let assistant_result = self
                .stream_assistant_response(Arc::clone(&on_event), abort.clone(), &loop_cx)
                .await;
            record_provider_streaming_latency(
                &turn_latency,
                provider_streaming_started_at.elapsed(),
            );

            let assistant_message = match assistant_result {
                Ok(msg) => msg,
                // Streaming failed outright. Flush any steering that arrived,
                // record a synthetic error message, close the turn and the run,
                // then propagate the error.
                Err(err) => {
                    let err_string = err.to_string();
                    let steering_to_add = self.drain_steering_messages().await;
                    for message in steering_to_add {
                        self.messages.push(message.clone());
                        on_event(AgentEvent::MessageStart {
                            message: message.clone(),
                        });
                        on_event(AgentEvent::MessageEnd {
                            message: message.clone(),
                        });
                        new_messages.push(message);
                    }

                    let error_message = self.build_error_message(None, err_string.clone());
                    let assistant_event_message = Message::assistant(error_message.clone());
                    self.messages.push(assistant_event_message.clone());
                    new_messages.push(assistant_event_message.clone());
                    on_event(AgentEvent::MessageStart {
                        message: assistant_event_message.clone(),
                    });
                    on_event(AgentEvent::MessageEnd {
                        message: assistant_event_message.clone(),
                    });

                    let turn_end_event = AgentEvent::TurnEnd {
                        session_id: session_id.clone(),
                        turn_index: current_turn_index,
                        message: assistant_event_message,
                        tool_results: Vec::new(),
                        latency_breakdown: snapshot_turn_latency(&turn_latency),
                    };
                    self.dispatch_extension_lifecycle_event(&turn_end_event)
                        .await;
                    on_event(turn_end_event);

                    let agent_end_event = AgentEvent::AgentEnd {
                        session_id: session_id.clone(),
                        messages: std::mem::take(&mut new_messages),
                        error: Some(err_string),
                    };
                    self.dispatch_extension_lifecycle_event(&agent_end_event)
                        .await;
                    on_event(agent_end_event);
                    return Err(err);
                }
            };
            let assistant_arc = Arc::new(assistant_message);
            last_assistant = Some(Arc::clone(&assistant_arc));

            // Only `new_messages` is updated here. `self.messages` is not pushed
            // by this loop — `stream_assistant_response` appears to record the
            // assistant message there itself (NOTE(review): confirm in
            // `update_partial_message` / `finalize_assistant_message`).
            let assistant_event_message = Message::Assistant(Arc::clone(&assistant_arc));
            new_messages.push(assistant_event_message.clone());

            // The stream ended in an error/abort state: stop the run and
            // return that message as `Ok`.
            if matches!(
                assistant_arc.stop_reason,
                StopReason::Error | StopReason::Aborted
            ) {
                let steering_to_add = self.drain_steering_messages().await;
                for message in steering_to_add {
                    self.messages.push(message.clone());
                    on_event(AgentEvent::MessageStart {
                        message: message.clone(),
                    });
                    on_event(AgentEvent::MessageEnd {
                        message: message.clone(),
                    });
                    new_messages.push(message);
                }

                let turn_end_event = AgentEvent::TurnEnd {
                    session_id: session_id.clone(),
                    turn_index: current_turn_index,
                    message: assistant_event_message.clone(),
                    tool_results: Vec::new(),
                    latency_breakdown: snapshot_turn_latency(&turn_latency),
                };
                self.dispatch_extension_lifecycle_event(&turn_end_event)
                    .await;
                on_event(turn_end_event);
                let agent_end_event = AgentEvent::AgentEnd {
                    session_id: session_id.clone(),
                    messages: std::mem::take(&mut new_messages),
                    error: assistant_arc.error_message.clone(),
                };
                self.dispatch_extension_lifecycle_event(&agent_end_event)
                    .await;
                on_event(agent_end_event);
                return Ok(Arc::unwrap_or_clone(assistant_arc));
            }

            let tool_calls = extract_tool_calls(&assistant_arc.content);
            has_more_tool_calls = !tool_calls.is_empty();

            let mut tool_results: Vec<Arc<ToolResultMessage>> = Vec::new();
            if has_more_tool_calls {
                iterations += 1;
                // Near the budget: inject a one-time steering message. It is
                // queued (not appended) and will be drained like any other
                // steering.
                if !warned_at_handoff_threshold
                    && should_warn_at_iteration_threshold(
                        iterations,
                        self.config.max_tool_iterations,
                    )
                {
                    warned_at_handoff_threshold = true;
                    let warning = Message::User(UserMessage {
                        content: UserContent::Text(iteration_handoff_steering_text(
                            iterations,
                            self.config.max_tool_iterations,
                        )),
                        timestamp: Utc::now().timestamp_millis(),
                    });
                    self.message_queue.push_steering(warning);
                    tracing::warn!(
                        iterations,
                        max = self.config.max_tool_iterations,
                        "tool-iteration budget at >=80%; injected handoff steering message"
                    );
                }
                // Budget exceeded: the tool calls are NOT executed. The final
                // message has its ToolCall blocks removed (presumably so the
                // transcript does not end with calls lacking results) and is
                // marked as an error.
                if iterations > self.config.max_tool_iterations {
                    let error_message = format!(
                        "Maximum tool iterations ({}) exceeded",
                        self.config.max_tool_iterations
                    );
                    let mut stop_message = (*assistant_arc).clone();
                    stop_message.stop_reason = StopReason::Error;
                    stop_message.error_message = Some(error_message.clone());

                    stop_message
                        .content
                        .retain(|b| !matches!(b, crate::model::ContentBlock::ToolCall(_)));

                    let stop_arc = Arc::new(stop_message.clone());
                    let stop_event_message = Message::Assistant(Arc::clone(&stop_arc));

                    // Swap the already-recorded assistant message for the
                    // stripped version, in both the transcript and this run's log.
                    if let Some(last @ Message::Assistant(_)) = self
                        .messages
                        .iter_mut()
                        .rev()
                        .find(|m| matches!(m, Message::Assistant(_)))
                    {
                        *last = stop_event_message.clone();
                    }
                    if let Some(last @ Message::Assistant(_)) = new_messages.last_mut() {
                        *last = stop_event_message.clone();
                    }

                    let steering_to_add = self.drain_steering_messages().await;
                    for message in steering_to_add {
                        self.messages.push(message.clone());
                        on_event(AgentEvent::MessageStart {
                            message: message.clone(),
                        });
                        on_event(AgentEvent::MessageEnd {
                            message: message.clone(),
                        });
                        new_messages.push(message);
                    }

                    let turn_end_event = AgentEvent::TurnEnd {
                        session_id: session_id.clone(),
                        turn_index: current_turn_index,
                        message: stop_event_message,
                        tool_results: Vec::new(),
                        latency_breakdown: snapshot_turn_latency(&turn_latency),
                    };
                    self.dispatch_extension_lifecycle_event(&turn_end_event)
                        .await;
                    on_event(turn_end_event);

                    let agent_end_event = AgentEvent::AgentEnd {
                        session_id: session_id.clone(),
                        messages: std::mem::take(&mut new_messages),
                        error: Some(error_message),
                    };
                    self.dispatch_extension_lifecycle_event(&agent_end_event)
                        .await;
                    on_event(agent_end_event);

                    return Ok(stop_message);
                }

                let outcome = match self
                    .execute_tool_calls(
                        &tool_calls,
                        Arc::clone(&on_event),
                        &mut new_messages,
                        abort.clone(),
                        Arc::clone(&turn_latency),
                    )
                    .await
                {
                    Ok(outcome) => outcome,
                    // Tool execution failed: flush steering, close the turn
                    // and the run, then propagate the error.
                    Err(err) => {
                        let steering_to_add = self.drain_steering_messages().await;
                        for message in steering_to_add {
                            self.messages.push(message.clone());
                            on_event(AgentEvent::MessageStart {
                                message: message.clone(),
                            });
                            on_event(AgentEvent::MessageEnd {
                                message: message.clone(),
                            });
                            new_messages.push(message);
                        }

                        let turn_end_event = AgentEvent::TurnEnd {
                            session_id: session_id.clone(),
                            turn_index: current_turn_index,
                            message: assistant_event_message.clone(),
                            tool_results: Vec::new(),
                            latency_breakdown: snapshot_turn_latency(&turn_latency),
                        };
                        self.dispatch_extension_lifecycle_event(&turn_end_event)
                            .await;
                        on_event(turn_end_event);

                        let agent_end_event = AgentEvent::AgentEnd {
                            session_id: session_id.clone(),
                            messages: std::mem::take(&mut new_messages),
                            error: Some(err.to_string()),
                        };
                        self.dispatch_extension_lifecycle_event(&agent_end_event)
                            .await;
                        on_event(agent_end_event);
                        return Err(err);
                    }
                };
                tool_results = outcome.tool_results;
                steering_after_tools = outcome.steering_messages;
            }

            let tool_messages = tool_results
                .iter()
                .map(|r| Message::ToolResult(Arc::clone(r)))
                .collect::<Vec<_>>();

            let turn_end_event = AgentEvent::TurnEnd {
                session_id: session_id.clone(),
                turn_index: current_turn_index,
                message: assistant_event_message.clone(),
                tool_results: tool_messages,
                latency_breakdown: snapshot_turn_latency(&turn_latency),
            };
            self.dispatch_extension_lifecycle_event(&turn_end_event)
                .await;
            on_event(turn_end_event);

            turn_index = turn_index.saturating_add(1);

            // Steering returned by tool execution takes precedence. Only when
            // there is none are the steering queue/fetchers polled again.
            if let Some(steering) = steering_after_tools.take() {
                pending_messages = steering;
            } else {
                pending_messages = self.drain_steering_messages().await;
            }
        }

        // No tool calls and no steering left: start another round only if
        // follow-up messages are waiting.
        let follow_up = self.drain_follow_up_messages().await;
        if follow_up.is_empty() {
            break;
        }
        pending_messages = follow_up;
    }

    // Defensive: the inner loop always runs at least once (`has_more_tool_calls`
    // starts `true`), and every turn either returns or sets `last_assistant`,
    // so this branch should be unreachable. Note it returns without `AgentEnd`.
    let Some(final_arc) = last_assistant else {
        return Err(Error::api("Agent completed without assistant message"));
    };

    let agent_end_event = AgentEvent::AgentEnd {
        session_id: session_id.clone(),
        messages: new_messages,
        error: None,
    };
    self.dispatch_extension_lifecycle_event(&agent_end_event)
        .await;
    on_event(agent_end_event);
    Ok(Arc::unwrap_or_clone(final_arc))
}
1861
1862 async fn fetch_messages(&self, fetcher: Option<&MessageFetcher>) -> Vec<Message> {
1863 if let Some(fetcher) = fetcher {
1864 (fetcher)().await
1865 } else {
1866 Vec::new()
1867 }
1868 }
1869
1870 async fn dispatch_extension_lifecycle_event(&self, event: &AgentEvent) {
1871 let Some(extensions) = &self.extensions else {
1872 return;
1873 };
1874
1875 let name = match event {
1876 AgentEvent::AgentStart { .. } => ExtensionEventName::AgentStart,
1877 AgentEvent::AgentEnd { .. } => ExtensionEventName::AgentEnd,
1878 AgentEvent::TurnStart { .. } => ExtensionEventName::TurnStart,
1879 AgentEvent::TurnEnd { .. } => ExtensionEventName::TurnEnd,
1880 _ => return,
1881 };
1882
1883 let payload = match serde_json::to_value(event) {
1884 Ok(payload) => payload,
1885 Err(err) => {
1886 tracing::warn!("failed to serialize agent lifecycle event (fail-open): {err}");
1887 return;
1888 }
1889 };
1890
1891 if let Err(err) = extensions.dispatch_event(name, Some(payload)).await {
1892 tracing::warn!("agent lifecycle extension hook failed (fail-open): {err}");
1893 }
1894 }
1895
1896 async fn dispatch_context_event(&self, messages: &[Message]) -> Option<Vec<Message>> {
1897 let Some(extensions) = &self.extensions else {
1898 return None;
1899 };
1900
1901 let payload = json!({ "messages": messages });
1902 let response = extensions
1903 .dispatch_event_with_response(
1904 ExtensionEventName::Context,
1905 Some(payload),
1906 EXTENSION_EVENT_TIMEOUT_MS,
1907 )
1908 .await
1909 .ok()?;
1910
1911 let value = response?;
1912
1913 if value.is_null() {
1914 return None;
1915 }
1916
1917 let messages_value = if let Some(obj) = value.as_object() {
1918 obj.get("messages").cloned()?
1919 } else if value.is_array() {
1920 value
1921 } else {
1922 return None;
1923 };
1924
1925 if messages_value.is_null() {
1926 return Some(Vec::new());
1927 }
1928
1929 match serde_json::from_value(messages_value) {
1930 Ok(messages) => Some(messages),
1931 Err(err) => {
1932 tracing::warn!("context extension hook returned invalid messages: {err}");
1933 None
1934 }
1935 }
1936 }
1937
1938 async fn drain_steering_messages(&mut self) -> Vec<Message> {
1939 for fetcher in &self.steering_fetchers {
1940 let fetched = self.fetch_messages(Some(fetcher)).await;
1941 for message in fetched {
1942 self.message_queue.push_steering(message);
1943 }
1944 }
1945 self.message_queue.pop_steering()
1946 }
1947
1948 async fn drain_follow_up_messages(&mut self) -> Vec<Message> {
1949 for fetcher in &self.follow_up_fetchers {
1950 let fetched = self.fetch_messages(Some(fetcher)).await;
1951 for message in fetched {
1952 self.message_queue.push_follow_up(message);
1953 }
1954 }
1955 self.message_queue.pop_follow_up()
1956 }
1957
1958 #[allow(clippy::too_many_lines)]
1960 async fn stream_assistant_response(
1961 &mut self,
1962 on_event: AgentEventHandler,
1963 abort: Option<AbortSignal>,
1964 checkpoint_cx: &crate::agent_cx::AgentCx,
1965 ) -> Result<AssistantMessage> {
1966 let provider = Arc::clone(&self.provider);
1968 let stream_options = self.config.stream_options.clone();
1969 let (system_prompt, tools, base_messages) = {
1970 let context = self.build_context();
1971 (
1972 context.system_prompt.as_deref().map(str::to_string),
1973 context.tools.to_vec(),
1974 context.messages.to_vec(),
1975 )
1976 };
1977 let messages = self
1978 .dispatch_context_event(&base_messages)
1979 .await
1980 .unwrap_or(base_messages);
1981 let context = Context::owned(system_prompt, messages, tools);
1982 let mut stream = provider.stream(&context, &stream_options).await?;
1983
1984 let mut added_partial = false;
1985 let mut sent_start = false;
1988
1989 'stream: loop {
1990 if checkpoint_cx.checkpoint().is_err() {
1991 let last_partial = if added_partial {
1992 match self
1993 .messages
1994 .iter()
1995 .rev()
1996 .find(|m| matches!(m, Message::Assistant(_)))
1997 {
1998 Some(Message::Assistant(a)) => Some(a.as_ref()),
1999 _ => None,
2000 }
2001 } else {
2002 None
2003 };
2004 let abort_arc = Arc::new(self.build_abort_message(last_partial));
2005 if !sent_start {
2006 on_event(AgentEvent::MessageStart {
2007 message: Message::Assistant(Arc::clone(&abort_arc)),
2008 });
2009 self.messages
2010 .push(Message::Assistant(Arc::clone(&abort_arc)));
2011 added_partial = true;
2012 }
2013 on_event(AgentEvent::MessageUpdate {
2014 message: Message::Assistant(Arc::clone(&abort_arc)),
2015 assistant_message_event: AssistantMessageEvent::Error {
2016 reason: StopReason::Aborted,
2017 error: Arc::clone(&abort_arc),
2018 },
2019 });
2020 return Ok(self.finalize_assistant_message(
2021 Arc::try_unwrap(abort_arc).unwrap_or_else(|a| (*a).clone()),
2022 &on_event,
2023 added_partial,
2024 ));
2025 }
2026
2027 let event_result = if let Some(signal) = abort.as_ref() {
2028 let abort_fut = signal.wait().fuse();
2029 let event_fut = stream.next().fuse();
2030 futures::pin_mut!(abort_fut, event_fut);
2031
2032 match futures::future::select(abort_fut, event_fut).await {
2033 futures::future::Either::Left(((), _event_fut)) => {
2034 let last_partial = if added_partial {
2035 match self
2036 .messages
2037 .iter()
2038 .rev()
2039 .find(|m| matches!(m, Message::Assistant(_)))
2040 {
2041 Some(Message::Assistant(a)) => Some(a.as_ref()),
2042 _ => None,
2043 }
2044 } else {
2045 None
2046 };
2047 let abort_arc = Arc::new(self.build_abort_message(last_partial));
2048 if !sent_start {
2049 on_event(AgentEvent::MessageStart {
2050 message: Message::Assistant(Arc::clone(&abort_arc)),
2051 });
2052 self.messages
2053 .push(Message::Assistant(Arc::clone(&abort_arc)));
2054 added_partial = true;
2055 }
2059 on_event(AgentEvent::MessageUpdate {
2060 message: Message::Assistant(Arc::clone(&abort_arc)),
2061 assistant_message_event: AssistantMessageEvent::Error {
2062 reason: StopReason::Aborted,
2063 error: Arc::clone(&abort_arc),
2064 },
2065 });
2066 return Ok(self.finalize_assistant_message(
2067 Arc::try_unwrap(abort_arc).unwrap_or_else(|a| (*a).clone()),
2068 &on_event,
2069 added_partial,
2070 ));
2071 }
2072 futures::future::Either::Right((event, _abort_fut)) => event,
2073 }
2074 } else {
2075 let event_fut = stream.next().fuse();
2076 futures::pin_mut!(event_fut);
2077 loop {
2078 let now = checkpoint_cx
2079 .cx()
2080 .timer_driver()
2081 .map_or_else(asupersync::time::wall_now, |timer| timer.now());
2082 let tick_fut =
2083 asupersync::time::sleep(now, std::time::Duration::from_millis(25)).fuse();
2084 futures::pin_mut!(tick_fut);
2085
2086 match futures::future::select(tick_fut, &mut event_fut).await {
2087 futures::future::Either::Left(((), _event_fut)) => {
2088 if checkpoint_cx.checkpoint().is_err() {
2089 continue 'stream;
2090 }
2091 }
2092 futures::future::Either::Right((result, _tick_fut)) => break result,
2093 }
2094 }
2095 };
2096
2097 let Some(event_result) = event_result else {
2098 break;
2099 };
2100 let event = match event_result {
2101 Ok(e) => e,
2102 Err(err) => {
2103 let partial = if added_partial {
2104 match self
2105 .messages
2106 .iter()
2107 .rev()
2108 .find(|m| matches!(m, Message::Assistant(_)))
2109 {
2110 Some(Message::Assistant(a)) => Some(a.as_ref()),
2111 _ => None,
2112 }
2113 } else {
2114 None
2115 };
2116 let msg = self.build_error_message(partial, err.to_string());
2117
2118 return Ok(self.finalize_assistant_message(msg, &on_event, added_partial));
2122 }
2123 };
2124
2125 match event {
2126 StreamEvent::Start { partial } => {
2127 if added_partial {
2128 if let Some(Message::Assistant(msg_arc)) = self
2129 .messages
2130 .iter_mut()
2131 .rev()
2132 .find(|m| matches!(m, Message::Assistant(_)))
2133 {
2134 let msg = Arc::make_mut(msg_arc);
2135 if msg.content.is_empty() {
2136 *msg = partial;
2137 } else {
2138 msg.api = partial.api;
2139 msg.provider = partial.provider;
2140 msg.model = partial.model;
2141 msg.usage = partial.usage;
2142 msg.stop_reason = partial.stop_reason;
2143 msg.error_message = partial.error_message;
2144 msg.timestamp = partial.timestamp;
2145 }
2146 let shared = Arc::clone(msg_arc);
2147 if !sent_start {
2148 on_event(AgentEvent::MessageStart {
2149 message: Message::Assistant(Arc::clone(&shared)),
2150 });
2151 sent_start = true;
2152 }
2153 on_event(AgentEvent::MessageUpdate {
2154 message: Message::Assistant(Arc::clone(&shared)),
2155 assistant_message_event: AssistantMessageEvent::Start {
2156 partial: shared,
2157 },
2158 });
2159 } else {
2160 let shared = Arc::new(partial);
2161 self.update_partial_message(Arc::clone(&shared), &mut added_partial);
2162 on_event(AgentEvent::MessageStart {
2163 message: Message::Assistant(Arc::clone(&shared)),
2164 });
2165 sent_start = true;
2166 on_event(AgentEvent::MessageUpdate {
2167 message: Message::Assistant(Arc::clone(&shared)),
2168 assistant_message_event: AssistantMessageEvent::Start {
2169 partial: shared,
2170 },
2171 });
2172 }
2173 } else {
2174 let shared = Arc::new(partial);
2175 self.update_partial_message(Arc::clone(&shared), &mut added_partial);
2176 on_event(AgentEvent::MessageStart {
2177 message: Message::Assistant(Arc::clone(&shared)),
2178 });
2179 sent_start = true;
2180 on_event(AgentEvent::MessageUpdate {
2181 message: Message::Assistant(Arc::clone(&shared)),
2182 assistant_message_event: AssistantMessageEvent::Start {
2183 partial: shared,
2184 },
2185 });
2186 }
2187 }
2188 StreamEvent::TextStart { content_index, .. } => {
2189 self.seed_partial_message_if_missing(&mut added_partial);
2190 if let Some(Message::Assistant(msg_arc)) = self
2191 .messages
2192 .iter_mut()
2193 .rev()
2194 .find(|m| matches!(m, Message::Assistant(_)))
2195 {
2196 let msg = Arc::make_mut(msg_arc);
2197 if content_index == msg.content.len() {
2198 msg.content.push(ContentBlock::Text(TextContent::new("")));
2199 }
2200 let shared = Arc::clone(msg_arc);
2201 if !sent_start {
2202 on_event(AgentEvent::MessageStart {
2203 message: Message::Assistant(Arc::clone(&shared)),
2204 });
2205 sent_start = true;
2206 }
2207 on_event(AgentEvent::MessageUpdate {
2208 message: Message::Assistant(Arc::clone(&shared)),
2209 assistant_message_event: AssistantMessageEvent::TextStart {
2210 content_index,
2211 partial: shared,
2212 },
2213 });
2214 }
2215 }
2216 StreamEvent::TextDelta {
2217 content_index,
2218 delta,
2219 ..
2220 } => {
2221 self.seed_partial_message_if_missing(&mut added_partial);
2222 if let Some(Message::Assistant(msg_arc)) = self
2223 .messages
2224 .iter_mut()
2225 .rev()
2226 .find(|m| matches!(m, Message::Assistant(_)))
2227 {
2228 {
2229 let msg = Arc::make_mut(msg_arc);
2230 if msg.content.get(content_index).is_none()
2231 && content_index == msg.content.len()
2232 {
2233 msg.content.push(ContentBlock::Text(TextContent::new("")));
2234 }
2235 if let Some(ContentBlock::Text(text)) =
2236 msg.content.get_mut(content_index)
2237 {
2238 text.text.push_str(&delta);
2239 }
2240 }
2241 let shared = Arc::clone(msg_arc);
2242 if !sent_start {
2243 on_event(AgentEvent::MessageStart {
2244 message: Message::Assistant(Arc::clone(&shared)),
2245 });
2246 sent_start = true;
2247 }
2248 on_event(AgentEvent::MessageUpdate {
2249 message: Message::Assistant(Arc::clone(&shared)),
2250 assistant_message_event: AssistantMessageEvent::TextDelta {
2251 content_index,
2252 delta,
2253 partial: shared,
2254 },
2255 });
2256 }
2257 }
2258 StreamEvent::TextEnd {
2259 content_index,
2260 content,
2261 ..
2262 } => {
2263 self.seed_partial_message_if_missing(&mut added_partial);
2264 if let Some(Message::Assistant(msg_arc)) = self
2265 .messages
2266 .iter_mut()
2267 .rev()
2268 .find(|m| matches!(m, Message::Assistant(_)))
2269 {
2270 {
2271 let msg = Arc::make_mut(msg_arc);
2272 if msg.content.get(content_index).is_none()
2273 && content_index == msg.content.len()
2274 {
2275 msg.content.push(ContentBlock::Text(TextContent::new("")));
2276 }
2277 if let Some(ContentBlock::Text(text)) =
2278 msg.content.get_mut(content_index)
2279 {
2280 text.text.clone_from(&content);
2281 }
2282 }
2283 let shared = Arc::clone(msg_arc);
2284 if !sent_start {
2285 on_event(AgentEvent::MessageStart {
2286 message: Message::Assistant(Arc::clone(&shared)),
2287 });
2288 sent_start = true;
2289 }
2290 on_event(AgentEvent::MessageUpdate {
2291 message: Message::Assistant(Arc::clone(&shared)),
2292 assistant_message_event: AssistantMessageEvent::TextEnd {
2293 content_index,
2294 content,
2295 partial: shared,
2296 },
2297 });
2298 }
2299 }
2300 StreamEvent::ThinkingStart { content_index, .. } => {
2301 self.seed_partial_message_if_missing(&mut added_partial);
2302 if let Some(Message::Assistant(msg_arc)) = self
2303 .messages
2304 .iter_mut()
2305 .rev()
2306 .find(|m| matches!(m, Message::Assistant(_)))
2307 {
2308 let msg = Arc::make_mut(msg_arc);
2309 if content_index == msg.content.len() {
2310 msg.content.push(ContentBlock::Thinking(ThinkingContent {
2311 thinking: String::new(),
2312 thinking_signature: None,
2313 }));
2314 }
2315 let shared = Arc::clone(msg_arc);
2316 if !sent_start {
2317 on_event(AgentEvent::MessageStart {
2318 message: Message::Assistant(Arc::clone(&shared)),
2319 });
2320 sent_start = true;
2321 }
2322 on_event(AgentEvent::MessageUpdate {
2323 message: Message::Assistant(Arc::clone(&shared)),
2324 assistant_message_event: AssistantMessageEvent::ThinkingStart {
2325 content_index,
2326 partial: shared,
2327 },
2328 });
2329 }
2330 }
2331 StreamEvent::ThinkingDelta {
2332 content_index,
2333 delta,
2334 ..
2335 } => {
2336 self.seed_partial_message_if_missing(&mut added_partial);
2337 if let Some(Message::Assistant(msg_arc)) = self
2338 .messages
2339 .iter_mut()
2340 .rev()
2341 .find(|m| matches!(m, Message::Assistant(_)))
2342 {
2343 {
2344 let msg = Arc::make_mut(msg_arc);
2345 if msg.content.get(content_index).is_none()
2346 && content_index == msg.content.len()
2347 {
2348 msg.content.push(ContentBlock::Thinking(ThinkingContent {
2349 thinking: String::new(),
2350 thinking_signature: None,
2351 }));
2352 }
2353 if let Some(ContentBlock::Thinking(thinking)) =
2354 msg.content.get_mut(content_index)
2355 {
2356 thinking.thinking.push_str(&delta);
2357 }
2358 }
2359 let shared = Arc::clone(msg_arc);
2360 if !sent_start {
2361 on_event(AgentEvent::MessageStart {
2362 message: Message::Assistant(Arc::clone(&shared)),
2363 });
2364 sent_start = true;
2365 }
2366 on_event(AgentEvent::MessageUpdate {
2367 message: Message::Assistant(Arc::clone(&shared)),
2368 assistant_message_event: AssistantMessageEvent::ThinkingDelta {
2369 content_index,
2370 delta,
2371 partial: shared,
2372 },
2373 });
2374 }
2375 }
2376 StreamEvent::ThinkingEnd {
2377 content_index,
2378 content,
2379 ..
2380 } => {
2381 self.seed_partial_message_if_missing(&mut added_partial);
2382 if let Some(Message::Assistant(msg_arc)) = self
2383 .messages
2384 .iter_mut()
2385 .rev()
2386 .find(|m| matches!(m, Message::Assistant(_)))
2387 {
2388 {
2389 let msg = Arc::make_mut(msg_arc);
2390 if msg.content.get(content_index).is_none()
2391 && content_index == msg.content.len()
2392 {
2393 msg.content.push(ContentBlock::Thinking(ThinkingContent {
2394 thinking: String::new(),
2395 thinking_signature: None,
2396 }));
2397 }
2398 if let Some(ContentBlock::Thinking(thinking)) =
2399 msg.content.get_mut(content_index)
2400 {
2401 thinking.thinking.clone_from(&content);
2402 }
2403 }
2404 let shared = Arc::clone(msg_arc);
2405 if !sent_start {
2406 on_event(AgentEvent::MessageStart {
2407 message: Message::Assistant(Arc::clone(&shared)),
2408 });
2409 sent_start = true;
2410 }
2411 on_event(AgentEvent::MessageUpdate {
2412 message: Message::Assistant(Arc::clone(&shared)),
2413 assistant_message_event: AssistantMessageEvent::ThinkingEnd {
2414 content_index,
2415 content,
2416 partial: shared,
2417 },
2418 });
2419 }
2420 }
2421 StreamEvent::ToolCallStart { content_index, .. } => {
2422 self.seed_partial_message_if_missing(&mut added_partial);
2423 if let Some(Message::Assistant(msg_arc)) = self
2424 .messages
2425 .iter_mut()
2426 .rev()
2427 .find(|m| matches!(m, Message::Assistant(_)))
2428 {
2429 let msg = Arc::make_mut(msg_arc);
2430 if content_index == msg.content.len() {
2431 msg.content.push(ContentBlock::ToolCall(ToolCall {
2432 id: String::new(),
2433 name: String::new(),
2434 arguments: serde_json::Value::Null,
2435 thought_signature: None,
2436 }));
2437 }
2438 let shared = Arc::clone(msg_arc);
2439 if !sent_start {
2440 on_event(AgentEvent::MessageStart {
2441 message: Message::Assistant(Arc::clone(&shared)),
2442 });
2443 sent_start = true;
2444 }
2445 on_event(AgentEvent::MessageUpdate {
2446 message: Message::Assistant(Arc::clone(&shared)),
2447 assistant_message_event: AssistantMessageEvent::ToolCallStart {
2448 content_index,
2449 partial: shared,
2450 },
2451 });
2452 }
2453 }
2454 StreamEvent::ToolCallDelta {
2455 content_index,
2456 delta,
2457 ..
2458 } => {
2459 self.seed_partial_message_if_missing(&mut added_partial);
2460 if let Some(Message::Assistant(msg_arc)) = self
2461 .messages
2462 .iter_mut()
2463 .rev()
2464 .find(|m| matches!(m, Message::Assistant(_)))
2465 {
2466 if msg_arc.content.get(content_index).is_none()
2467 && content_index == msg_arc.content.len()
2468 {
2469 let msg = Arc::make_mut(msg_arc);
2470 msg.content.push(ContentBlock::ToolCall(ToolCall {
2471 id: String::new(),
2472 name: String::new(),
2473 arguments: serde_json::Value::Null,
2474 thought_signature: None,
2475 }));
2476 }
2477 let shared = Arc::clone(msg_arc);
2480 if !sent_start {
2481 on_event(AgentEvent::MessageStart {
2482 message: Message::Assistant(Arc::clone(&shared)),
2483 });
2484 sent_start = true;
2485 }
2486 on_event(AgentEvent::MessageUpdate {
2487 message: Message::Assistant(Arc::clone(&shared)),
2488 assistant_message_event: AssistantMessageEvent::ToolCallDelta {
2489 content_index,
2490 delta,
2491 partial: shared,
2492 },
2493 });
2494 }
2495 }
2496 StreamEvent::ToolCallEnd {
2497 content_index,
2498 tool_call,
2499 ..
2500 } => {
2501 self.seed_partial_message_if_missing(&mut added_partial);
2502 if let Some(Message::Assistant(msg_arc)) = self
2503 .messages
2504 .iter_mut()
2505 .rev()
2506 .find(|m| matches!(m, Message::Assistant(_)))
2507 {
2508 {
2509 let msg = Arc::make_mut(msg_arc);
2510 if msg.content.get(content_index).is_none()
2511 && content_index == msg.content.len()
2512 {
2513 msg.content.push(ContentBlock::ToolCall(ToolCall {
2514 id: String::new(),
2515 name: String::new(),
2516 arguments: serde_json::Value::Null,
2517 thought_signature: None,
2518 }));
2519 }
2520 if let Some(ContentBlock::ToolCall(tc)) =
2521 msg.content.get_mut(content_index)
2522 {
2523 *tc = tool_call.clone();
2524 }
2525 }
2526 let shared = Arc::clone(msg_arc);
2527 if !sent_start {
2528 on_event(AgentEvent::MessageStart {
2529 message: Message::Assistant(Arc::clone(&shared)),
2530 });
2531 sent_start = true;
2532 }
2533 on_event(AgentEvent::MessageUpdate {
2534 message: Message::Assistant(Arc::clone(&shared)),
2535 assistant_message_event: AssistantMessageEvent::ToolCallEnd {
2536 content_index,
2537 tool_call,
2538 partial: shared,
2539 },
2540 });
2541 }
2542 }
2543 StreamEvent::Done { message, .. } => {
2544 return Ok(self.finalize_assistant_message(message, &on_event, added_partial));
2545 }
2546 StreamEvent::Error { error, .. } => {
2547 return Ok(self.finalize_assistant_message(error, &on_event, added_partial));
2548 }
2549 }
2550 }
2551
2552 if added_partial {
2556 if let Some(Message::Assistant(last_msg)) = self
2557 .messages
2558 .iter()
2559 .rev()
2560 .find(|m| matches!(m, Message::Assistant(_)))
2561 {
2562 let mut final_msg = (**last_msg).clone();
2563 final_msg.stop_reason = StopReason::Error;
2564 final_msg.error_message = Some("Stream ended without Done event".to_string());
2565 return Ok(self.finalize_assistant_message(final_msg, &on_event, true));
2566 }
2567 }
2568 Err(Error::api("Stream ended without Done event"))
2569 }
2570
2571 fn seed_partial_message_if_missing(&mut self, added_partial: &mut bool) {
2576 if *added_partial {
2577 return;
2578 }
2579
2580 let message = AssistantMessage {
2581 content: Vec::new(),
2582 api: self.provider.api().to_string(),
2583 provider: self.provider.name().to_string(),
2584 model: self.provider.model_id().to_string(),
2585 usage: Usage::default(),
2586 stop_reason: StopReason::Stop,
2587 error_message: None,
2588 timestamp: Utc::now().timestamp_millis(),
2589 };
2590 self.messages.push(Message::Assistant(Arc::new(message)));
2591 *added_partial = true;
2592 }
2593
2594 fn update_partial_message(
2599 &mut self,
2600 partial: Arc<AssistantMessage>,
2601 added_partial: &mut bool,
2602 ) -> bool {
2603 if *added_partial {
2604 if let Some(target) = self
2605 .messages
2606 .iter_mut()
2607 .rev()
2608 .find(|m| matches!(m, Message::Assistant(_)))
2609 {
2610 *target = Message::Assistant(partial);
2611 } else {
2612 tracing::warn!("update_partial_message: expected an Assistant message in history");
2615 self.messages.push(Message::Assistant(partial));
2616 }
2617 false
2618 } else {
2619 self.messages.push(Message::Assistant(partial));
2620 *added_partial = true;
2621 true
2622 }
2623 }
2624
2625 fn finalize_assistant_message(
2626 &mut self,
2627 message: AssistantMessage,
2628 on_event: &Arc<dyn Fn(AgentEvent) + Send + Sync>,
2629 added_partial: bool,
2630 ) -> AssistantMessage {
2631 let arc = Arc::new(message);
2632 if added_partial {
2633 if let Some(target) = self
2634 .messages
2635 .iter_mut()
2636 .rev()
2637 .find(|m| matches!(m, Message::Assistant(_)))
2638 {
2639 *target = Message::Assistant(Arc::clone(&arc));
2640 } else {
2641 tracing::warn!(
2644 "finalize_assistant_message: expected an Assistant message in history"
2645 );
2646 self.messages.push(Message::Assistant(Arc::clone(&arc)));
2647 on_event(AgentEvent::MessageStart {
2648 message: Message::Assistant(Arc::clone(&arc)),
2649 });
2650 }
2651 } else {
2652 self.messages.push(Message::Assistant(Arc::clone(&arc)));
2653 on_event(AgentEvent::MessageStart {
2654 message: Message::Assistant(Arc::clone(&arc)),
2655 });
2656 }
2657
2658 on_event(AgentEvent::MessageEnd {
2659 message: Message::Assistant(Arc::clone(&arc)),
2660 });
2661 Arc::try_unwrap(arc).unwrap_or_else(|a| (*a).clone())
2662 }
2663
2664 async fn execute_tool_batch(
2665 &self,
2666 batch: Vec<(usize, ToolCall)>,
2667 on_event: AgentEventHandler,
2668 abort: Option<AbortSignal>,
2669 latency: SharedTurnLatencyAccumulator,
2670 ) -> Vec<(usize, (ToolOutput, bool))> {
2671 let parallelism = compatible_tool_parallelism_limit();
2672 let futures = batch.into_iter().map(|(idx, tc)| {
2673 let on_event = Arc::clone(&on_event);
2674 let latency = Arc::clone(&latency);
2675 async move { (idx, self.execute_tool_owned(tc, on_event, latency).await) }
2676 });
2677
2678 if let Some(signal) = abort.as_ref() {
2679 use futures::future::{Either, select};
2680 let all_fut = stream::iter(futures)
2681 .buffer_unordered(parallelism)
2682 .collect::<Vec<_>>()
2683 .fuse();
2684 let abort_fut = signal.wait().fuse();
2685 futures::pin_mut!(all_fut, abort_fut);
2686
2687 match select(all_fut, abort_fut).await {
2688 Either::Left((batch_results, _)) => batch_results,
2689 Either::Right(_) => Vec::new(), }
2691 } else {
2692 stream::iter(futures)
2693 .buffer_unordered(parallelism)
2694 .collect::<Vec<_>>()
2695 .await
2696 }
2697 }
2698
2699 #[allow(clippy::too_many_lines)]
2700 async fn execute_tool_calls(
2701 &mut self,
2702 tool_calls: &[ToolCall],
2703 on_event: AgentEventHandler,
2704 new_messages: &mut Vec<Message>,
2705 abort: Option<AbortSignal>,
2706 latency: SharedTurnLatencyAccumulator,
2707 ) -> Result<ToolExecutionOutcome> {
2708 let mut results = Vec::new();
2709 let mut steering_messages: Option<Vec<Message>> = None;
2710
2711 for tool_call in tool_calls {
2713 on_event(AgentEvent::ToolExecutionStart {
2714 tool_call_id: tool_call.id.clone(),
2715 tool_name: tool_call.name.clone(),
2716 args: tool_call.arguments.clone(),
2717 });
2718 }
2719
2720 let effect_plan = tool_calls
2722 .iter()
2723 .map(|tool_call| {
2724 self.tools
2725 .get(&tool_call.name)
2726 .map_or_else(ToolEffects::write, Tool::effects)
2727 })
2728 .collect::<Vec<_>>();
2729 let effect_batches = plan_tool_effect_batches(&effect_plan);
2730 let mut recorded_results: Vec<Option<Arc<ToolResultMessage>>> =
2731 vec![None; tool_calls.len()];
2732
2733 for effect_batch in effect_batches {
2734 if abort.as_ref().is_some_and(AbortSignal::is_aborted) {
2735 break;
2736 }
2737
2738 let steering = self.drain_steering_messages().await;
2739 if !steering.is_empty() {
2740 steering_messages = Some(steering);
2741 break;
2742 }
2743
2744 let batch_len = effect_batch.end.saturating_sub(effect_batch.start);
2745 let batch = tool_calls
2746 .iter()
2747 .cloned()
2748 .enumerate()
2749 .skip(effect_batch.start)
2750 .take(batch_len)
2751 .collect();
2752 let mut batch_results = self
2753 .execute_tool_batch(
2754 batch,
2755 Arc::clone(&on_event),
2756 abort.clone(),
2757 Arc::clone(&latency),
2758 )
2759 .await;
2760 batch_results.sort_by_key(|(idx, _)| *idx);
2761 for (idx, (output, is_error)) in batch_results {
2762 if let (Some(tool_call), Some(recorded_result)) =
2763 (tool_calls.get(idx), recorded_results.get_mut(idx))
2764 {
2765 *recorded_result = Some(self.record_tool_result(
2766 tool_call,
2767 output,
2768 is_error,
2769 &on_event,
2770 new_messages,
2771 ));
2772 }
2773 }
2774 }
2775
2776 for (index, tool_call) in tool_calls.iter().enumerate() {
2778 if steering_messages.is_none() && !abort.as_ref().is_some_and(AbortSignal::is_aborted) {
2781 let steering = self.drain_steering_messages().await;
2782 if !steering.is_empty() {
2783 steering_messages = Some(steering);
2784 }
2785 }
2786
2787 if let Some(tool_result) = recorded_results.get_mut(index).and_then(Option::take) {
2790 results.push(tool_result);
2791 } else if steering_messages.is_some() {
2792 results.push(self.skip_tool_call(tool_call, &on_event, new_messages));
2794 } else {
2795 let output = ToolOutput {
2797 content: vec![ContentBlock::Text(TextContent::new(
2798 "Tool execution aborted",
2799 ))],
2800 details: Some(Self::tool_cancellation_details(
2801 &tool_call.name,
2802 "abort_signal",
2803 )),
2804 is_error: true,
2805 };
2806
2807 on_event(AgentEvent::ToolExecutionUpdate {
2808 tool_call_id: tool_call.id.clone(),
2809 tool_name: tool_call.name.clone(),
2810 args: tool_call.arguments.clone(),
2811 partial_result: ToolOutput {
2812 content: output.content.clone(),
2813 details: output.details.clone(),
2814 is_error: true,
2815 },
2816 });
2817
2818 on_event(AgentEvent::ToolExecutionEnd {
2819 tool_call_id: tool_call.id.clone(),
2820 tool_name: tool_call.name.clone(),
2821 result: ToolOutput {
2822 content: output.content.clone(),
2823 details: output.details.clone(),
2824 is_error: true,
2825 },
2826 is_error: true,
2827 });
2828
2829 let tool_result = Arc::new(ToolResultMessage {
2830 tool_call_id: tool_call.id.clone(),
2831 tool_name: tool_call.name.clone(),
2832 content: output.content,
2833 details: output.details,
2834 is_error: true,
2835 timestamp: Utc::now().timestamp_millis(),
2836 });
2837
2838 let msg = Message::ToolResult(Arc::clone(&tool_result));
2839 self.messages.push(msg.clone());
2840 on_event(AgentEvent::MessageStart {
2841 message: msg.clone(),
2842 });
2843 let end_msg = msg.clone();
2844 new_messages.push(msg);
2845 on_event(AgentEvent::MessageEnd { message: end_msg });
2846
2847 results.push(tool_result);
2848 }
2849 }
2850
2851 Ok(ToolExecutionOutcome {
2852 tool_results: results,
2853 steering_messages,
2854 })
2855 }
2856
2857 fn record_tool_result(
2858 &mut self,
2859 tool_call: &ToolCall,
2860 output: ToolOutput,
2861 is_error: bool,
2862 on_event: &AgentEventHandler,
2863 new_messages: &mut Vec<Message>,
2864 ) -> Arc<ToolResultMessage> {
2865 on_event(AgentEvent::ToolExecutionUpdate {
2866 tool_call_id: tool_call.id.clone(),
2867 tool_name: tool_call.name.clone(),
2868 args: tool_call.arguments.clone(),
2869 partial_result: ToolOutput {
2870 content: output.content.clone(),
2871 details: output.details.clone(),
2872 is_error,
2873 },
2874 });
2875
2876 let tool_result = Arc::new(ToolResultMessage {
2877 tool_call_id: tool_call.id.clone(),
2878 tool_name: tool_call.name.clone(),
2879 content: output.content,
2880 details: output.details,
2881 is_error,
2882 timestamp: Utc::now().timestamp_millis(),
2883 });
2884
2885 on_event(AgentEvent::ToolExecutionEnd {
2886 tool_call_id: tool_result.tool_call_id.clone(),
2887 tool_name: tool_result.tool_name.clone(),
2888 result: ToolOutput {
2889 content: tool_result.content.clone(),
2890 details: tool_result.details.clone(),
2891 is_error,
2892 },
2893 is_error,
2894 });
2895
2896 let msg = Message::ToolResult(Arc::clone(&tool_result));
2897 self.messages.push(msg.clone());
2898 on_event(AgentEvent::MessageStart {
2899 message: msg.clone(),
2900 });
2901 new_messages.push(msg.clone());
2902 on_event(AgentEvent::MessageEnd { message: msg });
2903
2904 tool_result
2905 }
2906
2907 async fn execute_tool(
2908 &self,
2909 tool_call: ToolCall,
2910 on_event: AgentEventHandler,
2911 latency: SharedTurnLatencyAccumulator,
2912 ) -> (ToolOutput, bool) {
2913 let extensions = self.extensions.clone();
2914
2915 let approval_denied_output = self
2916 .request_tool_approval(&tool_call, Arc::clone(&on_event))
2917 .await;
2918
2919 let (mut output, is_error) = if let Some(output) = approval_denied_output {
2920 (output, true)
2921 } else if let Some(extensions) = &extensions {
2922 let hook_started_at = Instant::now();
2923 let hook_outcome = Self::dispatch_tool_call_hook(
2924 extensions,
2925 &tool_call,
2926 self.config.fail_closed_hooks,
2927 )
2928 .await;
2929 record_extension_hostcall_latency(&latency, hook_started_at.elapsed());
2930
2931 if let Some(blocked_output) = hook_outcome {
2932 (blocked_output, true)
2933 } else {
2934 let tool_started_at = Instant::now();
2935 let outcome = self
2936 .execute_tool_without_hooks(&tool_call, Arc::clone(&on_event))
2937 .await;
2938 record_local_tool_latency(&latency, tool_started_at.elapsed());
2939 outcome
2940 }
2941 } else {
2942 let tool_started_at = Instant::now();
2943 let outcome = self
2944 .execute_tool_without_hooks(&tool_call, Arc::clone(&on_event))
2945 .await;
2946 record_local_tool_latency(&latency, tool_started_at.elapsed());
2947 outcome
2948 };
2949
2950 if let Some(extensions) = &extensions {
2951 let hook_started_at = Instant::now();
2952 Self::apply_tool_result_hook(extensions, &tool_call, &mut output, is_error).await;
2953 record_extension_hostcall_latency(&latency, hook_started_at.elapsed());
2954 }
2955
2956 (output, is_error)
2957 }
2958
2959 async fn request_tool_approval(
2960 &self,
2961 tool_call: &ToolCall,
2962 on_event: AgentEventHandler,
2963 ) -> Option<ToolOutput> {
2964 let Some(approval) = &self.config.tool_approval else {
2965 return None;
2966 };
2967
2968 let request = ToolApprovalRequest {
2969 tool_call_id: tool_call.id.clone(),
2970 tool_name: tool_call.name.clone(),
2971 arguments: tool_call.arguments.clone(),
2972 };
2973
2974 match approval(request).await {
2975 ToolApprovalDecision::Allow => {
2976 on_event(AgentEvent::ToolExecutionUpdate {
2977 tool_call_id: tool_call.id.clone(),
2978 tool_name: tool_call.name.clone(),
2979 args: tool_call.arguments.clone(),
2980 partial_result: ToolOutput {
2981 content: Vec::new(),
2982 details: Some(json!({
2983 "schema": TOOL_APPROVAL_STATUS_SCHEMA_V1,
2984 "status": "approved",
2985 })),
2986 is_error: false,
2987 },
2988 });
2989 None
2990 }
2991 ToolApprovalDecision::Deny { reason } => {
2992 Some(Self::tool_approval_denied_output(&reason))
2993 }
2994 }
2995 }
2996
2997 async fn execute_tool_owned(
2998 &self,
2999 tool_call: ToolCall,
3000 on_event: AgentEventHandler,
3001 latency: SharedTurnLatencyAccumulator,
3002 ) -> (ToolOutput, bool) {
3003 self.execute_tool(tool_call, on_event, latency).await
3004 }
3005
3006 async fn execute_tool_without_hooks(
3007 &self,
3008 tool_call: &ToolCall,
3009 on_event: AgentEventHandler,
3010 ) -> (ToolOutput, bool) {
3011 let Some(tool) = self.tools.get(&tool_call.name) else {
3013 return (Self::tool_not_found_output(&tool_call.name), true);
3014 };
3015
3016 let tool_name = tool_call.name.clone();
3017 let tool_id = tool_call.id.clone();
3018 let tool_args = tool_call.arguments.clone();
3019 let on_event = Arc::clone(&on_event);
3020
3021 let update_callback = move |update: ToolUpdate| {
3022 on_event(AgentEvent::ToolExecutionUpdate {
3023 tool_call_id: tool_id.clone(),
3024 tool_name: tool_name.clone(),
3025 args: tool_args.clone(),
3026 partial_result: ToolOutput {
3027 content: update.content,
3028 details: update.details,
3029 is_error: false,
3030 },
3031 });
3032 };
3033
3034 let _artifact_session_guard =
3035 self.config
3036 .stream_options
3037 .session_id
3038 .as_deref()
3039 .map(|session_id| {
3040 crate::tools::register_tool_output_artifact_session(&tool_call.id, session_id)
3041 });
3042
3043 match tool
3044 .execute(
3045 &tool_call.id,
3046 tool_call.arguments.clone(),
3047 Some(Box::new(update_callback)),
3048 )
3049 .await
3050 {
3051 Ok(output) => {
3052 let is_error = output.is_error;
3053 (output, is_error)
3054 }
3055 Err(e) => (
3056 ToolOutput {
3057 content: vec![ContentBlock::Text(TextContent::new(format!("Error: {e}")))],
3058 details: None,
3059 is_error: true,
3060 },
3061 true,
3062 ),
3063 }
3064 }
3065
3066 fn tool_not_found_output(tool_name: &str) -> ToolOutput {
3067 ToolOutput {
3068 content: vec![ContentBlock::Text(TextContent::new(format!(
3069 "Error: Tool '{tool_name}' not found"
3070 )))],
3071 details: None,
3072 is_error: true,
3073 }
3074 }
3075
3076 fn tool_cancellation_details(tool_name: &str, reason: &str) -> Value {
3077 json!({
3078 "schema": TOOL_CANCELLATION_SCHEMA_V1,
3079 "status": "cancelled",
3080 "reason": reason,
3081 "toolName": tool_name,
3082 "cleanup": "tool_result_recorded_no_success",
3083 })
3084 }
3085
3086 async fn dispatch_tool_call_hook(
3087 extensions: &ExtensionManager,
3088 tool_call: &ToolCall,
3089 fail_closed_hooks: bool,
3090 ) -> Option<ToolOutput> {
3091 match extensions
3092 .dispatch_tool_call(tool_call, EXTENSION_EVENT_TIMEOUT_MS)
3093 .await
3094 {
3095 Ok(Some(result)) if result.block => {
3096 Some(Self::tool_call_blocked_output(result.reason.as_deref()))
3097 }
3098 Ok(_) => None,
3099 Err(err) => {
3100 if fail_closed_hooks {
3101 tracing::warn!(
3102 error = ?err,
3103 "tool_call extension hook failed (fail-closed)"
3104 );
3105 Some(Self::tool_call_blocked_output(Some(
3106 "extension hook failed",
3107 )))
3108 } else {
3109 tracing::warn!("tool_call extension hook failed (fail-open): {err}");
3110 None
3111 }
3112 }
3113 }
3114 }
3115
3116 fn tool_call_blocked_output(reason: Option<&str>) -> ToolOutput {
3117 let reason = reason.map(str::trim).filter(|reason| !reason.is_empty());
3118 let message = reason.map_or_else(
3119 || "Tool execution was blocked by an extension".to_string(),
3120 |reason| format!("Tool execution blocked: {reason}"),
3121 );
3122
3123 ToolOutput {
3124 content: vec![ContentBlock::Text(TextContent::new(message))],
3125 details: None,
3126 is_error: true,
3127 }
3128 }
3129
3130 fn tool_approval_denied_output(reason: &str) -> ToolOutput {
3131 let reason = reason.trim();
3132 let reason = if reason.is_empty() {
3133 "tool approval denied"
3134 } else {
3135 reason
3136 };
3137
3138 ToolOutput {
3139 content: vec![ContentBlock::Text(TextContent::new(format!(
3140 "Tool execution denied: {reason}"
3141 )))],
3142 details: Some(json!({
3143 "schema": TOOL_APPROVAL_DENIED_SCHEMA_V1,
3144 "status": "denied",
3145 "reason": reason,
3146 })),
3147 is_error: true,
3148 }
3149 }
3150
3151 async fn apply_tool_result_hook(
3152 extensions: &ExtensionManager,
3153 tool_call: &ToolCall,
3154 output: &mut ToolOutput,
3155 is_error: bool,
3156 ) {
3157 match extensions
3158 .dispatch_tool_result(tool_call, &*output, is_error, EXTENSION_EVENT_TIMEOUT_MS)
3159 .await
3160 {
3161 Ok(Some(result)) => {
3162 if let Some(content) = result.content {
3163 output.content = content;
3164 }
3165 if let Some(details) = result.details {
3166 output.details = Some(details);
3167 }
3168 }
3169 Ok(None) => {}
3170 Err(err) => tracing::warn!("tool_result extension hook failed (fail-open): {err}"),
3171 }
3172 }
3173
3174 fn skip_tool_call(
3175 &mut self,
3176 tool_call: &ToolCall,
3177 on_event: &Arc<dyn Fn(AgentEvent) + Send + Sync>,
3178 new_messages: &mut Vec<Message>,
3179 ) -> Arc<ToolResultMessage> {
3180 let output = ToolOutput {
3181 content: vec![ContentBlock::Text(TextContent::new(
3182 "Skipped due to queued user message.",
3183 ))],
3184 details: None,
3185 is_error: true,
3186 };
3187
3188 on_event(AgentEvent::ToolExecutionUpdate {
3191 tool_call_id: tool_call.id.clone(),
3192 tool_name: tool_call.name.clone(),
3193 args: tool_call.arguments.clone(),
3194 partial_result: output.clone(),
3195 });
3196 on_event(AgentEvent::ToolExecutionEnd {
3197 tool_call_id: tool_call.id.clone(),
3198 tool_name: tool_call.name.clone(),
3199 result: output.clone(),
3200 is_error: true,
3201 });
3202
3203 let tool_result = Arc::new(ToolResultMessage {
3204 tool_call_id: tool_call.id.clone(),
3205 tool_name: tool_call.name.clone(),
3206 content: output.content,
3207 details: output.details,
3208 is_error: true,
3209 timestamp: Utc::now().timestamp_millis(),
3210 });
3211
3212 let msg = Message::ToolResult(Arc::clone(&tool_result));
3213 self.messages.push(msg.clone());
3214 new_messages.push(msg.clone());
3215
3216 on_event(AgentEvent::MessageStart {
3217 message: msg.clone(),
3218 });
3219 on_event(AgentEvent::MessageEnd { message: msg });
3220
3221 tool_result
3222 }
3223}
3224
/// What `execute_tool_calls` produced for one assistant turn.
3225struct ToolExecutionOutcome {
 /// One result per tool call, pushed in the order the calls were issued
 /// (executed, skipped due to steering, or cancelled by the abort signal).
3230 tool_results: Vec<Arc<ToolResultMessage>>,
 /// Steering messages from the first non-empty drain of the steering queue
 /// during execution, if any; calls still lacking a result at that point
 /// were recorded as skipped.
3231 steering_messages: Option<Vec<Message>>,
3232}
3233
/// An extension manager bundled with its runtime handle and tool registry.
///
/// NOTE(review): the name suggests these are created ahead of session
/// construction and handed over later — confirm against the producer/consumer.
3234pub struct PreWarmedExtensionRuntime {
 /// Extension manager (the type used to dispatch tool-call/tool-result hooks).
3239 pub manager: ExtensionManager,
 /// Handle to the runtime that hosts the loaded extensions.
3241 pub runtime: ExtensionRuntimeHandle,
 /// Tool registry, shared via `Arc`; presumably includes extension-provided
 /// tools — confirm.
3243 pub tools: Arc<ToolRegistry>,
3245}
3246
/// RAII guard that holds a shared `AtomicBool` at `true` while it is alive.
///
/// [`AtomicBoolGuard::activate`] sets the flag; dropping the guard clears it,
/// including on early return. The guard is not reference-counted: if two
/// guards cover the same flag, the first one dropped clears it.
#[must_use = "dropping the guard immediately resets the flag to false"]
struct AtomicBoolGuard(Arc<AtomicBool>);

impl AtomicBoolGuard {
    /// Sets `flag` to `true` and returns a guard that resets it on drop.
    fn activate(flag: &Arc<AtomicBool>) -> Self {
        flag.store(true, Ordering::SeqCst);
        Self(Arc::clone(flag))
    }
}

impl Drop for AtomicBoolGuard {
    fn drop(&mut self) {
        self.0.store(false, Ordering::SeqCst);
    }
}
3263
/// Pairs an `Agent` with its session storage and the extension, compaction,
/// and model/auth state that surrounds it.
///
/// NOTE(review): field notes below are inferred from names and types visible
/// in this chunk; the constructor and methods live elsewhere — confirm.
3264pub struct AgentSession {
 /// The agent that runs turns for this session.
3265 pub agent: Agent,
 /// Session storage behind an async mutex (`asupersync::sync::Mutex`).
3266 pub session: Arc<Mutex<Session>>,
 /// Presumably gates whether the session is persisted — confirm.
3267 save_enabled: bool,
3268 input_source: InputSource,
 /// Loaded extensions, if any.
3269 pub extensions: Option<ExtensionRegion>,
 /// Extension-facing status flags. Presumably shared with
 /// `AgentSessionHostActions` (`is_streaming` / `is_turn_active`), which
 /// reads them to route extension messages — confirm at construction.
3272 extensions_is_streaming: Arc<AtomicBool>,
3273 extensions_is_compacting: Arc<AtomicBool>,
3274 extensions_turn_active: Arc<AtomicBool>,
 /// Actions extensions requested while the agent was idle (see
 /// `PendingIdleAction`).
3275 extensions_pending_idle_actions: Arc<StdMutex<VecDeque<PendingIdleAction>>>,
 /// Delivery modes for extension-injected messages, when extensions exist.
3276 extension_queue_modes: Option<Arc<StdMutex<ExtensionQueueModeState>>>,
 /// Steering / follow-up messages injected by extensions while streaming.
3277 extension_injected_queue: Option<Arc<StdMutex<ExtensionInjectedQueue>>>,
 /// Provider, stream options, and model list used for extension AI requests.
3278 extension_ai_completion: Arc<StdMutex<ExtensionAiCompletionHostState>>,
3279 compaction_settings: ResolvedCompactionSettings,
3280 compaction_runtime: Option<Runtime>,
3281 runtime_handle: Option<RuntimeHandle>,
3282 compaction_worker: CompactionWorkerState,
3283 model_registry: Option<ModelRegistry>,
3284 auth_storage: Option<AuthStorage>,
 /// Presumably an explicit API key that takes precedence over stored
 /// credentials — confirm.
3285 api_key_override: Option<String>,
3286 semantic_context_bundle: Option<SemanticContextBundleInjection>,
3287}
3288
3289#[derive(Debug, Clone, Copy)]
3290struct ExtensionQueueModeState {
3291 steering_mode: QueueMode,
3292 follow_up_mode: QueueMode,
3293}
3294
3295impl ExtensionQueueModeState {
3296 const fn new(steering_mode: QueueMode, follow_up_mode: QueueMode) -> Self {
3297 Self {
3298 steering_mode,
3299 follow_up_mode,
3300 }
3301 }
3302
3303 const fn set_modes(&mut self, steering_mode: QueueMode, follow_up_mode: QueueMode) {
3304 self.steering_mode = steering_mode;
3305 self.follow_up_mode = follow_up_mode;
3306 }
3307}
3308
3309#[derive(Debug)]
3310struct ExtensionInjectedQueue {
3311 steering: VecDeque<Message>,
3312 follow_up: VecDeque<Message>,
3313 steering_mode: QueueMode,
3314 follow_up_mode: QueueMode,
3315}
3316
3317impl ExtensionInjectedQueue {
3318 const fn new(steering_mode: QueueMode, follow_up_mode: QueueMode) -> Self {
3319 Self {
3320 steering: VecDeque::new(),
3321 follow_up: VecDeque::new(),
3322 steering_mode,
3323 follow_up_mode,
3324 }
3325 }
3326
3327 const fn set_modes(&mut self, steering_mode: QueueMode, follow_up_mode: QueueMode) {
3328 self.steering_mode = steering_mode;
3329 self.follow_up_mode = follow_up_mode;
3330 }
3331
3332 fn push_steering(&mut self, message: Message) {
3333 if self.steering.len() >= MAX_STEERING_QUEUE_SIZE {
3334 tracing::warn!(
3335 "Extension steering queue full ({} messages), dropping oldest message",
3336 MAX_STEERING_QUEUE_SIZE
3337 );
3338 self.steering.pop_front();
3339 }
3340 self.steering.push_back(message);
3341 }
3342
3343 fn push_follow_up(&mut self, message: Message) {
3344 if self.follow_up.len() >= MAX_FOLLOW_UP_QUEUE_SIZE {
3345 tracing::warn!(
3346 "Extension follow-up queue full ({} messages), dropping oldest message",
3347 MAX_FOLLOW_UP_QUEUE_SIZE
3348 );
3349 self.follow_up.pop_front();
3350 }
3351 self.follow_up.push_back(message);
3352 }
3353
3354 fn pop_steering(&mut self) -> Vec<Message> {
3355 match self.steering_mode {
3356 QueueMode::All => self.steering.drain(..).collect(),
3357 QueueMode::OneAtATime => self.steering.pop_front().into_iter().collect(),
3358 }
3359 }
3360
3361 fn pop_follow_up(&mut self) -> Vec<Message> {
3362 match self.follow_up_mode {
3363 QueueMode::All => self.follow_up.drain(..).collect(),
3364 QueueMode::OneAtATime => self.follow_up.pop_front().into_iter().collect(),
3365 }
3366 }
3367}
3368
3369impl Default for ExtensionInjectedQueue {
3370 fn default() -> Self {
3371 Self::new(QueueMode::OneAtATime, QueueMode::OneAtATime)
3372 }
3373}
3374
/// Work an extension requested while the agent was idle (neither streaming
/// nor in an active turn), queued by `AgentSessionHostActions`.
3375#[derive(Debug)]
3376enum PendingIdleAction {
 /// A custom message from `send_message` with `trigger_turn` set.
3377 CustomMessage(Message),
 /// User text from `send_user_message`.
3378 UserText(String),
3379}
3380
/// Host-side implementation of `ExtensionHostActions`: routes messages and AI
/// requests from extensions into the agent session.
3381#[derive(Clone)]
3382struct AgentSessionHostActions {
 /// Session that `append_to_session` writes to.
3383 session: Arc<Mutex<Session>>,
 /// Steering / follow-up queue that `enqueue` pushes into while streaming.
3384 injected: Arc<StdMutex<ExtensionInjectedQueue>>,
 /// When set, extension messages are queued instead of appended.
3385 is_streaming: Arc<AtomicBool>,
 /// When set (and not streaming), messages are appended to the session.
3386 is_turn_active: Arc<AtomicBool>,
 /// Actions deferred until the agent is idle.
3387 pending_idle_actions: Arc<StdMutex<VecDeque<PendingIdleAction>>>,
 /// Provider/options/models used by `complete_ai` and `list_ai_models`.
3388 ai_completion: Arc<StdMutex<ExtensionAiCompletionHostState>>,
3389}
3390
/// State needed to serve extension AI requests; `complete_ai` clones the
/// provider handle and stream options out of the mutex for each request.
3391#[derive(Clone)]
3392struct ExtensionAiCompletionHostState {
 /// Provider used to stream completions.
3393 provider: Arc<dyn Provider>,
 /// Base stream options; per-request options are applied to a copy.
3394 stream_options: StreamOptions,
 /// Model descriptors returned by `list_ai_models`; when empty, a single
 /// entry derived from `provider` is returned instead.
3395 models: Vec<Value>,
3396}
3397
3398impl AgentSessionHostActions {
3399 fn enqueue(&self, deliver_as: Option<ExtensionDeliverAs>, message: Message) {
3400 let deliver_as = deliver_as.unwrap_or(ExtensionDeliverAs::Steer);
3401 let Ok(mut queue) = self.injected.lock() else {
3402 tracing::error!("injected queue mutex poisoned; dropping extension message");
3403 return;
3404 };
3405 match deliver_as {
3406 ExtensionDeliverAs::FollowUp => {
3407 queue.push_follow_up(message);
3408 }
3409 ExtensionDeliverAs::Steer | ExtensionDeliverAs::NextTurn => {
3410 queue.push_steering(message);
3411 }
3412 }
3413 }
3414
3415 async fn append_to_session(&self, message: Message) -> Result<()> {
3416 let cx = crate::agent_cx::AgentCx::for_current_or_request();
3417 let mut session = self
3418 .session
3419 .lock(cx.cx())
3420 .await
3421 .map_err(|e| Error::session(e.to_string()))?;
3422 session.append_model_message(message);
3423 Ok(())
3424 }
3425
3426 fn queue_pending_idle_action(&self, action: PendingIdleAction) {
3427 let Ok(mut actions) = self.pending_idle_actions.lock() else {
3428 tracing::error!("pending idle actions mutex poisoned; dropping idle action");
3429 return;
3430 };
3431 actions.push_back(action);
3432 }
3433}
3434
3435#[async_trait]
3436impl ExtensionHostActions for AgentSessionHostActions {
3437 async fn send_message(&self, message: ExtensionSendMessage) -> Result<()> {
3438 let custom_message = Message::Custom(CustomMessage {
3439 content: message.content,
3440 custom_type: message.custom_type,
3441 display: message.display,
3442 details: message.details,
3443 timestamp: Utc::now().timestamp_millis(),
3444 });
3445
3446 if matches!(message.deliver_as, Some(ExtensionDeliverAs::NextTurn)) {
3447 return self.append_to_session(custom_message).await;
3448 }
3449
3450 if self.is_streaming.load(Ordering::SeqCst) {
3451 self.enqueue(message.deliver_as, custom_message);
3452 return Ok(());
3453 }
3454
3455 if self.is_turn_active.load(Ordering::SeqCst) {
3456 return self.append_to_session(custom_message).await;
3457 }
3458
3459 if message.trigger_turn {
3460 self.queue_pending_idle_action(PendingIdleAction::CustomMessage(custom_message));
3461 return Ok(());
3462 }
3463
3464 self.append_to_session(custom_message).await
3465 }
3466
3467 async fn send_user_message(&self, message: ExtensionSendUserMessage) -> Result<()> {
3468 let text = message.text;
3469 let user_message = Message::User(UserMessage {
3470 content: UserContent::Text(text.clone()),
3471 timestamp: Utc::now().timestamp_millis(),
3472 });
3473
3474 if self.is_streaming.load(Ordering::SeqCst) {
3475 self.enqueue(message.deliver_as, user_message);
3476 return Ok(());
3477 }
3478
3479 if self.is_turn_active.load(Ordering::SeqCst) {
3480 return self.append_to_session(user_message).await;
3481 }
3482
3483 self.queue_pending_idle_action(PendingIdleAction::UserText(text));
3484 Ok(())
3485 }
3486
3487 async fn complete_ai(&self, request: ExtensionAiCompletionRequest) -> Result<Value> {
3488 let (provider, mut stream_options) = {
3489 let state = self.ai_completion.lock().map_err(|_| {
3490 Error::extension("extension completion host state mutex poisoned".to_string())
3491 })?;
3492 (Arc::clone(&state.provider), state.stream_options.clone())
3493 };
3494
3495 apply_pi_ai_completion_options(&request.options, &mut stream_options)?;
3496 let context = build_pi_ai_completion_context(&request)?;
3497 let provider_name = provider.name().to_string();
3498 let mut events = provider.stream(&context, &stream_options).await?;
3499 let mut streamed_text = String::new();
3500
3501 while let Some(event) = events.next().await {
3502 match event.map_err(|err| Error::provider(provider_name.clone(), err.to_string()))? {
3503 StreamEvent::TextDelta { delta, .. } => streamed_text.push_str(&delta),
3504 StreamEvent::TextEnd { content, .. } => {
3505 streamed_text.push_str(&content);
3506 }
3507 StreamEvent::Done { message, .. } => {
3508 if message.stop_reason == StopReason::Error {
3509 return Err(Error::provider(
3510 provider_name,
3511 pi_ai_assistant_error_message(&message),
3512 ));
3513 }
3514 return pi_ai_completion_response(&message, request.simple);
3515 }
3516 StreamEvent::Error { error, .. } => {
3517 return Err(Error::provider(
3518 provider_name,
3519 pi_ai_assistant_error_message(&error),
3520 ));
3521 }
3522 StreamEvent::Start { .. }
3523 | StreamEvent::TextStart { .. }
3524 | StreamEvent::ThinkingStart { .. }
3525 | StreamEvent::ThinkingDelta { .. }
3526 | StreamEvent::ThinkingEnd { .. }
3527 | StreamEvent::ToolCallStart { .. }
3528 | StreamEvent::ToolCallDelta { .. }
3529 | StreamEvent::ToolCallEnd { .. } => {}
3530 }
3531 }
3532
3533 let suffix = if streamed_text.is_empty() {
3534 String::new()
3535 } else {
3536 format!(" after streaming {} text bytes", streamed_text.len())
3537 };
3538 Err(Error::provider(
3539 provider_name,
3540 format!("pi-ai completion stream ended without Done event{suffix}"),
3541 ))
3542 }
3543
3544 async fn list_ai_models(&self) -> Result<Value> {
3545 let state = self.ai_completion.lock().map_err(|_| {
3546 Error::extension("extension completion host state mutex poisoned".to_string())
3547 })?;
3548 if state.models.is_empty() {
3549 return Ok(json!([{
3550 "id": state.provider.model_id(),
3551 "name": state.provider.model_id(),
3552 "api": state.provider.api(),
3553 "provider": state.provider.name(),
3554 }]));
3555 }
3556 Ok(Value::Array(state.models.clone()))
3557 }
3558}
3559
3560fn pi_ai_model_entry_value(entry: &ModelEntry) -> Value {
3561 json!({
3562 "id": entry.model.id,
3563 "name": entry.model.name,
3564 "api": entry.model.api,
3565 "provider": entry.model.provider,
3566 "baseUrl": entry.model.base_url,
3567 "reasoning": entry.model.reasoning,
3568 "input": entry.model.input,
3569 "cost": entry.model.cost,
3570 "contextWindow": entry.model.context_window,
3571 "maxTokens": entry.model.max_tokens,
3572 "authHeader": entry.auth_header,
3573 "hasCredentials": entry.api_key.is_some(),
3574 })
3575}
3576
3577fn pi_ai_model_registry_values(registry: &ModelRegistry) -> Vec<Value> {
3578 registry
3579 .models()
3580 .iter()
3581 .map(pi_ai_model_entry_value)
3582 .collect()
3583}
3584
3585fn apply_pi_ai_completion_options(
3586 options: &Value,
3587 stream_options: &mut StreamOptions,
3588) -> Result<()> {
3589 if let Some(value) = options
3590 .get("temperature")
3591 .or_else(|| options.get("temp"))
3592 .filter(|value| !value.is_null())
3593 {
3594 let temperature = serde_json::from_value::<f32>(value.clone()).map_err(|err| {
3595 Error::validation(format!(
3596 "pi-ai completion temperature must be numeric: {err}"
3597 ))
3598 })?;
3599 if !(0.0..=2.0).contains(&temperature) {
3600 return Err(Error::validation(
3601 "pi-ai completion temperature must be between 0 and 2".to_string(),
3602 ));
3603 }
3604 stream_options.temperature = Some(temperature);
3605 }
3606
3607 if let Some(value) = options
3608 .get("maxTokens")
3609 .or_else(|| options.get("max_tokens"))
3610 .filter(|value| !value.is_null())
3611 {
3612 let raw = value.as_u64().ok_or_else(|| {
3613 Error::validation("pi-ai completion maxTokens must be an unsigned integer".to_string())
3614 })?;
3615 let max_tokens = u32::try_from(raw).map_err(|_| {
3616 Error::validation("pi-ai completion maxTokens exceeds u32::MAX".to_string())
3617 })?;
3618 if max_tokens == 0 {
3619 return Err(Error::validation(
3620 "pi-ai completion maxTokens must be greater than zero".to_string(),
3621 ));
3622 }
3623 stream_options.max_tokens = Some(max_tokens);
3624 }
3625
3626 Ok(())
3627}
3628
3629fn build_pi_ai_completion_context(
3630 request: &ExtensionAiCompletionRequest,
3631) -> Result<Context<'static>> {
3632 let mut system_prompts = Vec::new();
3633 let mut messages = Vec::new();
3634 collect_pi_ai_context_messages(&request.context, &mut system_prompts, &mut messages)?;
3635
3636 if messages.is_empty() {
3637 return Err(Error::validation(
3638 "@mariozechner/pi-ai completion requires at least one user or assistant message"
3639 .to_string(),
3640 ));
3641 }
3642
3643 let system_prompt = system_prompts
3644 .into_iter()
3645 .filter(|text| !text.trim().is_empty())
3646 .collect::<Vec<_>>()
3647 .join("\n\n");
3648 Ok(Context::owned(
3649 if system_prompt.is_empty() {
3650 None
3651 } else {
3652 Some(system_prompt)
3653 },
3654 messages,
3655 Vec::new(),
3656 ))
3657}
3658
3659fn collect_pi_ai_context_messages(
3660 value: &Value,
3661 system_prompts: &mut Vec<String>,
3662 messages: &mut Vec<Message>,
3663) -> Result<()> {
3664 match value {
3665 Value::Null => {}
3666 Value::String(text) => push_pi_ai_user_message(text, messages),
3667 Value::Array(items) => {
3668 for item in items {
3669 push_pi_ai_message(item, system_prompts, messages)?;
3670 }
3671 }
3672 Value::Object(map) => {
3673 if let Some(system) = map
3674 .get("systemPrompt")
3675 .or_else(|| map.get("system_prompt"))
3676 .or_else(|| map.get("system"))
3677 .and_then(pi_ai_text_from_value)
3678 {
3679 system_prompts.push(system);
3680 }
3681
3682 if let Some(items) = map.get("messages").and_then(Value::as_array) {
3683 for item in items {
3684 push_pi_ai_message(item, system_prompts, messages)?;
3685 }
3686 } else if let Some(prompt) = map
3687 .get("prompt")
3688 .or_else(|| map.get("input"))
3689 .or_else(|| map.get("message"))
3690 .and_then(pi_ai_text_from_value)
3691 {
3692 push_pi_ai_user_message(&prompt, messages);
3693 } else if map.contains_key("role") {
3694 push_pi_ai_message(value, system_prompts, messages)?;
3695 }
3696 }
3697 Value::Bool(_) | Value::Number(_) => push_pi_ai_user_message(&value.to_string(), messages),
3698 }
3699 Ok(())
3700}
3701
3702fn push_pi_ai_message(
3703 value: &Value,
3704 system_prompts: &mut Vec<String>,
3705 messages: &mut Vec<Message>,
3706) -> Result<()> {
3707 let Value::Object(map) = value else {
3708 if let Some(text) = pi_ai_text_from_value(value) {
3709 push_pi_ai_user_message(&text, messages);
3710 }
3711 return Ok(());
3712 };
3713
3714 let role = map
3715 .get("role")
3716 .and_then(Value::as_str)
3717 .unwrap_or("user")
3718 .trim()
3719 .to_ascii_lowercase();
3720 let content = map
3721 .get("content")
3722 .or_else(|| map.get("text"))
3723 .and_then(pi_ai_text_from_value)
3724 .unwrap_or_default();
3725
3726 match role.as_str() {
3727 "system" => {
3728 if !content.trim().is_empty() {
3729 system_prompts.push(content);
3730 }
3731 }
3732 "user" => push_pi_ai_user_message(&content, messages),
3733 "assistant" => push_pi_ai_assistant_message(&content, messages),
3734 other => {
3735 return Err(Error::validation(format!(
3736 "@mariozechner/pi-ai completion does not support {other:?} context messages"
3737 )));
3738 }
3739 }
3740 Ok(())
3741}
3742
3743fn push_pi_ai_user_message(text: &str, messages: &mut Vec<Message>) {
3744 messages.push(Message::User(UserMessage {
3745 content: UserContent::Text(text.to_string()),
3746 timestamp: Utc::now().timestamp_millis(),
3747 }));
3748}
3749
3750fn push_pi_ai_assistant_message(text: &str, messages: &mut Vec<Message>) {
3751 messages.push(Message::assistant(AssistantMessage {
3752 content: vec![ContentBlock::Text(TextContent::new(text.to_string()))],
3753 timestamp: Utc::now().timestamp_millis(),
3754 ..AssistantMessage::default()
3755 }));
3756}
3757
3758fn pi_ai_text_from_value(value: &Value) -> Option<String> {
3759 match value {
3760 Value::Null => None,
3761 Value::String(text) => Some(text.clone()),
3762 Value::Bool(_) | Value::Number(_) => Some(value.to_string()),
3763 Value::Array(items) => {
3764 let mut text = String::new();
3765 for item in items {
3766 if let Some(part) = pi_ai_text_from_value(item)
3767 && !part.is_empty()
3768 {
3769 text.push_str(&part);
3770 }
3771 }
3772 Some(text)
3773 }
3774 Value::Object(map) => map
3775 .get("text")
3776 .or_else(|| map.get("content"))
3777 .or_else(|| map.get("delta"))
3778 .and_then(pi_ai_text_from_value),
3779 }
3780}
3781
3782fn pi_ai_assistant_text(message: &AssistantMessage) -> String {
3783 let mut text = String::new();
3784 for block in &message.content {
3785 if let ContentBlock::Text(text_block) = block {
3786 text.push_str(&text_block.text);
3787 }
3788 }
3789 text
3790}
3791
3792fn pi_ai_assistant_error_message(message: &AssistantMessage) -> String {
3793 message
3794 .error_message
3795 .clone()
3796 .filter(|text| !text.trim().is_empty())
3797 .unwrap_or_else(|| {
3798 let text = pi_ai_assistant_text(message);
3799 if text.trim().is_empty() {
3800 "provider returned an error without a message".to_string()
3801 } else {
3802 text
3803 }
3804 })
3805}
3806
3807fn pi_ai_completion_response(message: &AssistantMessage, simple: bool) -> Result<Value> {
3808 let text = pi_ai_assistant_text(message);
3809 if simple {
3810 return Ok(Value::String(text));
3811 }
3812
3813 Ok(json!({
3814 "message": serde_json::to_value(message)?,
3815 "content": serde_json::to_value(&message.content)?,
3816 "text": text,
3817 "usage": serde_json::to_value(&message.usage)?,
3818 "model": message.model,
3819 "provider": message.provider,
3820 "api": message.api,
3821 "stopReason": message.stop_reason,
3822 }))
3823}
3824
#[cfg(test)]
mod message_queue_tests {
    use super::*;

    /// Builds a plain-text user message with a zero timestamp.
    fn user_message(text: &str) -> Message {
        Message::User(UserMessage {
            content: UserContent::Text(text.to_string()),
            timestamp: 0,
        })
    }

    /// Text of a plain-text user message; `None` for a missing or any other message.
    fn user_text(message: Option<&Message>) -> Option<&str> {
        match message {
            Some(Message::User(UserMessage {
                content: UserContent::Text(text),
                ..
            })) => Some(text.as_str()),
            _ => None,
        }
    }

    #[test]
    fn message_queue_one_at_a_time() {
        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
        for text in ["a", "b"] {
            queue.push_steering(user_message(text));
        }

        // Each pop releases exactly one message, in insertion order.
        for expected in ["a", "b"] {
            let batch = queue.pop_steering();
            assert_eq!(batch.len(), 1);
            assert_eq!(user_text(batch.first()), Some(expected));
        }

        assert!(queue.pop_steering().is_empty());
    }

    #[test]
    fn message_queue_all_mode() {
        let mut queue = MessageQueue::new(QueueMode::All, QueueMode::OneAtATime);
        for text in ["a", "b"] {
            queue.push_steering(user_message(text));
        }

        assert_eq!(queue.pop_steering().len(), 2);
        assert!(queue.pop_steering().is_empty());
    }

    #[test]
    fn message_queue_separates_kinds() {
        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
        queue.push_steering(user_message("steer"));
        queue.push_follow_up(user_message("follow"));

        // Draining steering leaves the follow-up message pending, and vice versa.
        assert_eq!(queue.pop_steering().len(), 1);
        assert_eq!(queue.pending_count(), 1);

        assert_eq!(queue.pop_follow_up().len(), 1);
        assert_eq!(queue.pending_count(), 0);
    }

    #[test]
    fn message_queue_seq_increments() {
        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
        let steering_seq = queue.push_steering(user_message("a"));
        let follow_up_seq = queue.push_follow_up(user_message("b"));
        assert!(follow_up_seq > steering_seq);
    }

    #[test]
    fn message_queue_seq_saturates_at_u64_max() {
        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
        queue.next_seq = u64::MAX;

        let seqs = [
            queue.push_steering(user_message("a")),
            queue.push_follow_up(user_message("b")),
        ];

        assert_eq!(seqs, [u64::MAX, u64::MAX]);
        assert_eq!(queue.pending_count(), 2);
    }

    #[test]
    fn message_queue_follow_up_all_mode_drains_entire_queue_in_order() {
        let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::All);
        queue.push_follow_up(user_message("f1"));
        queue.push_follow_up(user_message("f2"));

        let drained = queue.pop_follow_up();
        let texts = drained
            .iter()
            .map(|message| user_text(Some(message)))
            .collect::<Vec<_>>();
        assert_eq!(texts, [Some("f1"), Some("f2")]);
        assert!(queue.pop_follow_up().is_empty());
    }
}
3929
#[cfg(test)]
mod compatible_tool_parallelism_tests {
    use super::*;

    #[test]
    fn compatible_tool_parallelism_preserves_historical_floor() {
        // Small hosts never drop below the historical default of 8.
        for cores in [1, 8] {
            assert_eq!(resolve_compatible_tool_parallelism(None, cores), 8);
        }
    }

    #[test]
    fn compatible_tool_parallelism_scales_on_many_core_hosts() {
        // Without an override the value tracks the host, up to a ceiling of 64.
        for (cores, expected) in [(32, 32), (64, 64), (128, 64)] {
            assert_eq!(
                resolve_compatible_tool_parallelism(None, cores),
                expected,
                "cores={cores}"
            );
        }
    }

    #[test]
    fn compatible_tool_parallelism_accepts_bounded_override() {
        // A valid override wins over the host value but is clamped to 256.
        for (raw, cores, expected) in [("16", 4, 16), ("512", 64, 256), ("1", 64, 1)] {
            assert_eq!(
                resolve_compatible_tool_parallelism(Some(raw), cores),
                expected,
                "override={raw:?}"
            );
        }
    }

    #[test]
    fn compatible_tool_parallelism_ignores_invalid_override() {
        // Unparsable, zero and blank overrides fall back to the host-derived value.
        for raw in ["not-a-number", "0", " "] {
            assert_eq!(
                resolve_compatible_tool_parallelism(Some(raw), 24),
                24,
                "override={raw:?}"
            );
        }
    }
}
3964
#[cfg(test)]
mod tool_effect_batch_planning_tests {
    // Tests for tool-effect batch planning (`plan_tool_effect_batches`) and its serialized
    // evidence (`tool_effect_batch_plan_evidence`). The `metamorphic_*` tests also replay
    // each planned batch with permuted completion order. They check that the recorded
    // transcript equals a strictly sequential oracle.
    use super::*;

    /// Whether a synthetic tool call produces a success or an error result.
    #[derive(Debug, Clone, Copy)]
    enum SyntheticOutcome {
        Success,
        Error,
    }

    /// A fake tool call. `registered_effects == None` models a tool with no registered
    /// effects.
    #[derive(Debug, Clone)]
    struct SyntheticToolCase {
        id: String,
        name: String,
        registered_effects: Option<ToolEffects>,
        outcome: SyntheticOutcome,
    }

    /// Permutation applied to a batch's indices to simulate out-of-order completion.
    #[derive(Debug, Clone, Copy)]
    enum BatchArrivalOrder {
        Forward,
        Reverse,
        RotateLeft(usize),
    }

    /// Comparable projection of a `ToolResultMessage`; the timestamp is not included.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct TranscriptEntry {
        tool_call_id: String,
        tool_name: String,
        text: String,
        details: serde_json::Value,
        is_error: bool,
    }

    /// Plans batches for `effects` and returns each as a half-open `(start, end)` range.
    fn batch_ranges(effects: &[ToolEffects]) -> Vec<(usize, usize)> {
        plan_tool_effect_batches(effects)
            .into_iter()
            .map(|batch| (batch.start, batch.end))
            .collect()
    }

    /// Serializes the batch-plan evidence for `effects` with the given parallelism cap.
    fn batch_plan_json(effects: &[ToolEffects], parallelism_cap: usize) -> serde_json::Value {
        serde_json::to_value(tool_effect_batch_plan_evidence(effects, parallelism_cap))
            .expect("tool-effect batch evidence should serialize")
    }

    /// Builds a synthetic case whose id is `call-NNN` (the index, zero-padded to 3 digits).
    fn synthetic_tool_case(
        index: usize,
        name: impl Into<String>,
        registered_effects: Option<ToolEffects>,
        outcome: SyntheticOutcome,
    ) -> SyntheticToolCase {
        SyntheticToolCase {
            id: format!("call-{index:03}"),
            name: name.into(),
            registered_effects,
            outcome,
        }
    }

    /// Effects fed to the planner; cases without registered effects are planned as
    /// `ToolEffects::write()`.
    fn effect_plan(cases: &[SyntheticToolCase]) -> Vec<ToolEffects> {
        cases
            .iter()
            .map(|case| case.registered_effects.unwrap_or_else(ToolEffects::write))
            .collect()
    }

    /// Deterministic fake result for `case`. `details` records the ordinal, the tool name
    /// and an `ok`/`error` status.
    fn make_tool_result(case: &SyntheticToolCase, index: usize) -> ToolResultMessage {
        let (content, is_error) = match case.outcome {
            SyntheticOutcome::Success => (format!("ok:{}", case.name), false),
            SyntheticOutcome::Error => (format!("error:{}", case.name), true),
        };
        ToolResultMessage {
            tool_call_id: case.id.clone(),
            tool_name: case.name.clone(),
            content: vec![ContentBlock::Text(TextContent::new(content))],
            details: Some(serde_json::json!({
                "ordinal": index,
                "tool": case.name,
                "status": if is_error { "error" } else { "ok" },
            })),
            is_error,
            timestamp: 42,
        }
    }

    /// Projects a result into a `TranscriptEntry`. Panics unless the result has exactly
    /// one content block.
    fn transcript_entry(message: &ToolResultMessage) -> TranscriptEntry {
        assert_eq!(message.content.len(), 1, "synthetic result content drifted");
        let text = message
            .content
            .first()
            .and_then(|block| match block {
                ContentBlock::Text(text) => Some(text.text.clone()),
                _ => None,
            })
            .unwrap_or_else(|| "non-text synthetic result".to_string());
        TranscriptEntry {
            tool_call_id: message.tool_call_id.clone(),
            tool_name: message.tool_name.clone(),
            text,
            details: message.details.clone().unwrap_or(serde_json::Value::Null),
            is_error: message.is_error,
        }
    }

    /// Reference transcript: every case produced one at a time, in original order.
    fn sequential_oracle(cases: &[SyntheticToolCase]) -> Vec<TranscriptEntry> {
        cases
            .iter()
            .enumerate()
            .map(|(index, case)| transcript_entry(&make_tool_result(case, index)))
            .collect()
    }

    /// Permutes one batch's indices in place according to `order`.
    fn reorder_batch(indices: &mut [usize], order: BatchArrivalOrder) {
        match order {
            BatchArrivalOrder::Forward => {}
            BatchArrivalOrder::Reverse => indices.reverse(),
            BatchArrivalOrder::RotateLeft(amount) => {
                // Guards the modulo against an empty batch.
                if !indices.is_empty() {
                    indices.rotate_left(amount % indices.len());
                }
            }
        }
    }

    /// Simulated batched execution. Within each planned batch, results arrive in the
    /// permuted `order`, are re-sorted by original index, and are stored in their slots.
    /// Panics if any case ends up without a recorded result.
    fn scheduled_transcript(
        cases: &[SyntheticToolCase],
        order: BatchArrivalOrder,
    ) -> Vec<TranscriptEntry> {
        let effects = effect_plan(cases);
        let batches = plan_tool_effect_batches(&effects);
        let mut recorded_results: Vec<Option<ToolResultMessage>> = vec![None; cases.len()];

        for batch in batches {
            let mut completion_order = (batch.start..batch.end).collect::<Vec<_>>();
            reorder_batch(&mut completion_order, order);
            let mut batch_results = completion_order
                .into_iter()
                .filter_map(|index| {
                    cases
                        .get(index)
                        .map(|case| (index, make_tool_result(case, index)))
                })
                .collect::<Vec<_>>();
            // Commit in original index order regardless of arrival order.
            batch_results.sort_by_key(|(index, _)| *index);
            for (index, result) in batch_results {
                if let Some(slot) = recorded_results.get_mut(index) {
                    *slot = Some(result);
                }
            }
        }

        assert!(
            recorded_results.iter().all(Option::is_some),
            "scheduled execution should record every result"
        );
        recorded_results
            .into_iter()
            .flatten()
            .map(|result| transcript_entry(&result))
            .collect()
    }

    /// Asserts that every planned batch whose union of effects is not parallel-safe
    /// contains exactly one call.
    fn assert_barrier_effects_are_singleton_batches(cases: &[SyntheticToolCase]) {
        let effects = effect_plan(cases);
        for batch in plan_tool_effect_batches(&effects) {
            let batch_effects = effects
                .get(batch.start..batch.end)
                .unwrap_or(&[])
                .iter()
                .copied()
                .fold(ToolEffects::read(), ToolEffects::union);
            if !batch_effects.parallel_safe() {
                assert_eq!(
                    batch.end - batch.start,
                    1,
                    "barrier batch must serialize original index {}",
                    batch.start
                );
            }
        }
    }

    /// Adjacent read and network calls are planned into one batch.
    #[test]
    fn read_and_network_effects_share_compatible_batch() {
        let ranges = batch_ranges(&[
            ToolEffects::read(),
            ToolEffects::network(),
            ToolEffects::read(),
        ]);

        assert_eq!(ranges, vec![(0, 3)]);
    }

    /// 72 compatible calls stay in one batch; the cap of 64 is recorded in the evidence
    /// but does not split the batch.
    #[test]
    fn evidence_records_64_plus_compatible_batch_with_parallelism_cap() {
        let effects = (0..72)
            .map(|index| {
                if index % 3 == 0 {
                    ToolEffects::network()
                } else {
                    ToolEffects::read()
                }
            })
            .collect::<Vec<_>>();

        assert_eq!(
            batch_plan_json(&effects, 64),
            serde_json::json!({
                "schema": TOOL_EFFECT_BATCH_PLAN_SCHEMA_V1,
                "toolCount": 72,
                "parallelismCap": 64,
                "batches": [
                    {
                        "start": 0,
                        "end": 72,
                        "len": 72,
                        "combinedEffects": ["read", "network"],
                        "parallelSafe": true
                    }
                ]
            })
        );
    }

    /// A write call splits the surrounding reads and sits alone in its own batch.
    #[test]
    fn write_effect_creates_deterministic_barrier() {
        let ranges = batch_ranges(&[
            ToolEffects::read(),
            ToolEffects::read(),
            ToolEffects::write(),
            ToolEffects::read(),
        ]);

        assert_eq!(ranges, vec![(0, 2), (2, 3), (3, 4)]);
    }

    /// Consecutive append calls, and a process call, each get their own batch.
    #[test]
    fn append_and_process_effects_remain_serialized() {
        let ranges = batch_ranges(&[
            ToolEffects::append(),
            ToolEffects::append(),
            ToolEffects::process(),
            ToolEffects::read(),
        ]);

        assert_eq!(ranges, vec![(0, 1), (1, 2), (2, 3), (3, 4)]);
    }

    /// A call combining process and write effects is isolated from its neighbours.
    #[test]
    fn combined_process_write_effect_is_exclusive() {
        let ranges = batch_ranges(&[
            ToolEffects::read(),
            ToolEffects::process().union(ToolEffects::write()),
            ToolEffects::network(),
        ]);

        assert_eq!(ranges, vec![(0, 1), (1, 2), (2, 3)]);
    }

    /// Evidence names a barrier reason for every non-parallel-safe batch, and omits it for
    /// parallel-safe ones.
    #[test]
    fn evidence_records_barrier_reasons_for_mixed_effects() {
        let effects = [
            ToolEffects::read(),
            ToolEffects::network(),
            ToolEffects::write(),
            ToolEffects::append(),
            ToolEffects::process(),
            ToolEffects::read(),
            ToolEffects::process().union(ToolEffects::write()),
        ];

        assert_eq!(
            batch_plan_json(&effects, 32),
            serde_json::json!({
                "schema": TOOL_EFFECT_BATCH_PLAN_SCHEMA_V1,
                "toolCount": 7,
                "parallelismCap": 32,
                "batches": [
                    {
                        "start": 0,
                        "end": 2,
                        "len": 2,
                        "combinedEffects": ["read", "network"],
                        "parallelSafe": true
                    },
                    {
                        "start": 2,
                        "end": 3,
                        "len": 1,
                        "combinedEffects": ["write"],
                        "parallelSafe": false,
                        "barrierReason": "write_barrier"
                    },
                    {
                        "start": 3,
                        "end": 4,
                        "len": 1,
                        "combinedEffects": ["append"],
                        "parallelSafe": false,
                        "barrierReason": "append_barrier"
                    },
                    {
                        "start": 4,
                        "end": 5,
                        "len": 1,
                        "combinedEffects": ["process"],
                        "parallelSafe": false,
                        "barrierReason": "process_barrier"
                    },
                    {
                        "start": 5,
                        "end": 6,
                        "len": 1,
                        "combinedEffects": ["read"],
                        "parallelSafe": true
                    },
                    {
                        "start": 6,
                        "end": 7,
                        "len": 1,
                        "combinedEffects": ["write", "process"],
                        "parallelSafe": false,
                        "barrierReason": "write_process_barrier"
                    }
                ]
            })
        );
    }

    /// No calls means no batches and an empty transcript.
    #[test]
    fn metamorphic_empty_tool_batch_matches_sequential_oracle() {
        let cases = Vec::new();

        assert!(plan_tool_effect_batches(&effect_plan(&cases)).is_empty());
        assert_eq!(
            scheduled_transcript(&cases, BatchArrivalOrder::Forward),
            sequential_oracle(&cases)
        );
    }

    /// Mixed effects, including failures and an unregistered tool, produce the expected
    /// batches. Permuted arrival within batches still reproduces the sequential transcript.
    #[test]
    fn metamorphic_mixed_effect_batches_match_sequential_oracle() {
        let cases = vec![
            synthetic_tool_case(
                0,
                "read",
                Some(ToolEffects::read()),
                SyntheticOutcome::Success,
            ),
            synthetic_tool_case(
                1,
                "network",
                Some(ToolEffects::network()),
                SyntheticOutcome::Success,
            ),
            synthetic_tool_case(
                2,
                "write",
                Some(ToolEffects::write()),
                SyntheticOutcome::Success,
            ),
            synthetic_tool_case(
                3,
                "read",
                Some(ToolEffects::read()),
                SyntheticOutcome::Success,
            ),
            synthetic_tool_case(
                4,
                "append",
                Some(ToolEffects::append()),
                SyntheticOutcome::Error,
            ),
            synthetic_tool_case(
                5,
                "network",
                Some(ToolEffects::network()),
                SyntheticOutcome::Success,
            ),
            synthetic_tool_case(
                6,
                "process",
                Some(ToolEffects::process()),
                SyntheticOutcome::Success,
            ),
            synthetic_tool_case(
                7,
                "read",
                Some(ToolEffects::read()),
                SyntheticOutcome::Error,
            ),
            synthetic_tool_case(8, "unknown", None, SyntheticOutcome::Success),
            synthetic_tool_case(
                9,
                "network",
                Some(ToolEffects::network()),
                SyntheticOutcome::Success,
            ),
        ];

        assert_eq!(
            batch_ranges(&effect_plan(&cases)),
            vec![
                (0, 2),
                (2, 3),
                (3, 4),
                (4, 5),
                (5, 6),
                (6, 7),
                (7, 8),
                (8, 9),
                (9, 10)
            ]
        );
        let evidence = tool_effect_batch_plan_evidence(&effect_plan(&cases), 16);
        assert_eq!(evidence.schema, TOOL_EFFECT_BATCH_PLAN_SCHEMA_V1);
        assert_eq!(evidence.parallelism_cap, 16);
        assert_eq!(evidence.batches.len(), 9);
        assert!(evidence.batches.iter().any(|batch| {
            batch.barrier_reason == Some("append_barrier") && batch.combined_effects == ["append"]
        }));
        assert!(
            cases
                .iter()
                .any(|case| matches!(case.outcome, SyntheticOutcome::Error)),
            "mixed-effect fixture must include failure cases"
        );
        assert_barrier_effects_are_singleton_batches(&cases);

        let oracle = sequential_oracle(&cases);
        assert_eq!(
            scheduled_transcript(&cases, BatchArrivalOrder::Reverse),
            oracle
        );
        assert_eq!(
            scheduled_transcript(&cases, BatchArrivalOrder::RotateLeft(1)),
            oracle
        );
    }

    /// 96 calls with a repeating pattern over `index % 12`: process at 0, append at 5, an
    /// unregistered failing tool at 9, network at 3 and 7, and read elsewhere. Every arrival
    /// order must yield the sequential transcript.
    #[test]
    fn metamorphic_high_count_batches_keep_transcript_deterministic() {
        let cases = (0..96)
            .map(|index| match index % 12 {
                0 => synthetic_tool_case(
                    index,
                    format!("process-{index}"),
                    Some(ToolEffects::process()),
                    SyntheticOutcome::Success,
                ),
                5 => synthetic_tool_case(
                    index,
                    format!("append-{index}"),
                    Some(ToolEffects::append()),
                    SyntheticOutcome::Success,
                ),
                9 => synthetic_tool_case(
                    index,
                    format!("unknown-{index}"),
                    None,
                    SyntheticOutcome::Error,
                ),
                3 | 7 => synthetic_tool_case(
                    index,
                    format!("network-{index}"),
                    Some(ToolEffects::network()),
                    SyntheticOutcome::Success,
                ),
                _ => synthetic_tool_case(
                    index,
                    format!("read-{index}"),
                    Some(ToolEffects::read()),
                    SyntheticOutcome::Success,
                ),
            })
            .collect::<Vec<_>>();

        assert_barrier_effects_are_singleton_batches(&cases);
        let oracle = sequential_oracle(&cases);
        assert_eq!(
            scheduled_transcript(&cases, BatchArrivalOrder::Forward),
            oracle
        );
        assert_eq!(
            scheduled_transcript(&cases, BatchArrivalOrder::Reverse),
            oracle
        );
        assert_eq!(
            scheduled_transcript(&cases, BatchArrivalOrder::RotateLeft(3)),
            oracle
        );
    }
}
4459
4460#[cfg(test)]
4461mod extensions_integration_tests {
4462 use super::*;
4463
4464 use crate::session::Session;
4465 use asupersync::runtime::RuntimeBuilder;
4466 use async_trait::async_trait;
4467 use futures::Stream;
4468 use serde_json::json;
4469 use std::path::Path;
4470 use std::pin::Pin;
4471 use std::sync::atomic::AtomicUsize;
4472 use std::time::Duration;
4473
4474 #[derive(Debug)]
4475 struct NoopProvider;
4476
4477 #[async_trait]
4478 #[allow(clippy::unnecessary_literal_bound)]
4479 impl Provider for NoopProvider {
4480 fn name(&self) -> &str {
4481 "test-provider"
4482 }
4483
4484 fn api(&self) -> &str {
4485 "test-api"
4486 }
4487
4488 fn model_id(&self) -> &str {
4489 "test-model"
4490 }
4491
4492 async fn stream(
4493 &self,
4494 _context: &Context<'_>,
4495 _options: &StreamOptions,
4496 ) -> crate::error::Result<
4497 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
4498 > {
4499 Ok(Box::pin(futures::stream::empty()))
4500 }
4501 }
4502
4503 #[derive(Debug)]
4504 struct IdleCommandProvider;
4505
4506 #[async_trait]
4507 #[allow(clippy::unnecessary_literal_bound)]
4508 impl Provider for IdleCommandProvider {
4509 fn name(&self) -> &str {
4510 "test-provider"
4511 }
4512
4513 fn api(&self) -> &str {
4514 "test-api"
4515 }
4516
4517 fn model_id(&self) -> &str {
4518 "test-model"
4519 }
4520
4521 async fn stream(
4522 &self,
4523 _context: &Context<'_>,
4524 _options: &StreamOptions,
4525 ) -> crate::error::Result<
4526 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
4527 > {
4528 let partial = AssistantMessage {
4529 content: Vec::new(),
4530 api: self.api().to_string(),
4531 provider: self.name().to_string(),
4532 model: self.model_id().to_string(),
4533 usage: Usage::default(),
4534 stop_reason: StopReason::Stop,
4535 error_message: None,
4536 timestamp: 0,
4537 };
4538 let done = AssistantMessage {
4539 content: vec![ContentBlock::Text(TextContent::new(
4540 "resumed-response-0".to_string(),
4541 ))],
4542 api: self.api().to_string(),
4543 provider: self.name().to_string(),
4544 model: self.model_id().to_string(),
4545 usage: Usage::default(),
4546 stop_reason: StopReason::Stop,
4547 error_message: None,
4548 timestamp: 0,
4549 };
4550 Ok(Box::pin(futures::stream::iter(vec![
4551 Ok(StreamEvent::Start { partial }),
4552 Ok(StreamEvent::Done {
4553 reason: StopReason::Stop,
4554 message: done,
4555 }),
4556 ])))
4557 }
4558 }
4559
4560 #[derive(Debug)]
4561 struct CountingTool {
4562 calls: Arc<AtomicUsize>,
4563 }
4564
4565 #[async_trait]
4566 #[allow(clippy::unnecessary_literal_bound)]
4567 impl Tool for CountingTool {
4568 fn name(&self) -> &str {
4569 "count_tool"
4570 }
4571
4572 fn label(&self) -> &str {
4573 "count_tool"
4574 }
4575
4576 fn description(&self) -> &str {
4577 "counting tool"
4578 }
4579
4580 fn parameters(&self) -> serde_json::Value {
4581 json!({ "type": "object" })
4582 }
4583
4584 async fn execute(
4585 &self,
4586 _tool_call_id: &str,
4587 _input: serde_json::Value,
4588 _on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
4589 ) -> Result<ToolOutput> {
4590 self.calls.fetch_add(1, Ordering::SeqCst);
4591 Ok(ToolOutput {
4592 content: vec![ContentBlock::Text(TextContent::new("ok"))],
4593 details: None,
4594 is_error: false,
4595 })
4596 }
4597 }
4598
4599 #[derive(Debug)]
4600 struct ToolUseProvider {
4601 stream_calls: AtomicUsize,
4602 }
4603
4604 impl ToolUseProvider {
4605 const fn new() -> Self {
4606 Self {
4607 stream_calls: AtomicUsize::new(0),
4608 }
4609 }
4610
4611 fn assistant_message(
4612 &self,
4613 stop_reason: StopReason,
4614 content: Vec<ContentBlock>,
4615 ) -> AssistantMessage {
4616 AssistantMessage {
4617 content,
4618 api: self.api().to_string(),
4619 provider: self.name().to_string(),
4620 model: self.model_id().to_string(),
4621 usage: Usage::default(),
4622 stop_reason,
4623 error_message: None,
4624 timestamp: 0,
4625 }
4626 }
4627 }
4628
4629 #[async_trait]
4630 #[allow(clippy::unnecessary_literal_bound)]
4631 impl Provider for ToolUseProvider {
4632 fn name(&self) -> &str {
4633 "test-provider"
4634 }
4635
4636 fn api(&self) -> &str {
4637 "test-api"
4638 }
4639
4640 fn model_id(&self) -> &str {
4641 "test-model"
4642 }
4643
4644 async fn stream(
4645 &self,
4646 _context: &Context<'_>,
4647 _options: &StreamOptions,
4648 ) -> crate::error::Result<
4649 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
4650 > {
4651 let call_index = self.stream_calls.fetch_add(1, Ordering::SeqCst);
4652
4653 let partial = self.assistant_message(StopReason::Stop, Vec::new());
4654
4655 let (reason, message) = if call_index == 0 {
4656 let tool_calls = vec![
4657 ToolCall {
4658 id: "call-1".to_string(),
4659 name: "count_tool".to_string(),
4660 arguments: json!({}),
4661 thought_signature: None,
4662 },
4663 ToolCall {
4664 id: "call-2".to_string(),
4665 name: "count_tool".to_string(),
4666 arguments: json!({}),
4667 thought_signature: None,
4668 },
4669 ];
4670
4671 (
4672 StopReason::ToolUse,
4673 self.assistant_message(
4674 StopReason::ToolUse,
4675 tool_calls
4676 .into_iter()
4677 .map(ContentBlock::ToolCall)
4678 .collect::<Vec<_>>(),
4679 ),
4680 )
4681 } else {
4682 (
4683 StopReason::Stop,
4684 self.assistant_message(
4685 StopReason::Stop,
4686 vec![ContentBlock::Text(TextContent::new("done"))],
4687 ),
4688 )
4689 };
4690
4691 let events = vec![
4692 Ok(StreamEvent::Start { partial }),
4693 Ok(StreamEvent::Done { reason, message }),
4694 ];
4695 Ok(Box::pin(futures::stream::iter(events)))
4696 }
4697 }
4698
4699 #[test]
4700 fn agent_session_enable_extensions_registers_extension_tools() {
4701 let runtime = RuntimeBuilder::current_thread()
4702 .build()
4703 .expect("runtime build");
4704
4705 runtime.block_on(async {
4706 let temp_dir = tempfile::tempdir().expect("tempdir");
4707 let entry_path = temp_dir.path().join("ext.mjs");
4708 std::fs::write(
4709 &entry_path,
4710 r#"
4711 export default function init(pi) {
4712 pi.registerTool({
4713 name: "hello_tool",
4714 label: "hello_tool",
4715 description: "test tool",
4716 parameters: { type: "object", properties: { name: { type: "string" } } },
4717 execute: async (_callId, input, _onUpdate, _abort, ctx) => {
4718 const who = input && input.name ? String(input.name) : "world";
4719 const cwd = ctx && ctx.cwd ? String(ctx.cwd) : "";
4720 return {
4721 content: [{ type: "text", text: `hello ${who}` }],
4722 details: { from: "extension", cwd: cwd },
4723 isError: false
4724 };
4725 }
4726 });
4727 }
4728 "#,
4729 )
4730 .expect("write extension entry");
4731
4732 let provider = Arc::new(NoopProvider);
4733 let tools = ToolRegistry::new(&[], Path::new("."), None);
4734 let agent = Agent::new(provider, tools, AgentConfig::default());
4735 let session = Arc::new(Mutex::new(Session::in_memory()));
4736 let mut agent_session =
4737 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
4738
4739 agent_session
4740 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
4741 .await
4742 .expect("enable extensions");
4743
4744 let tool = agent_session
4745 .agent
4746 .tools
4747 .get("hello_tool")
4748 .expect("hello_tool registered");
4749
4750 let output = tool
4751 .execute("call-1", json!({ "name": "pi" }), None)
4752 .await
4753 .expect("execute tool");
4754
4755 assert!(!output.is_error);
4756 assert!(
4757 matches!(output.content.as_slice(), [ContentBlock::Text(_)]),
4758 "Expected single text content block, got {:?}",
4759 output.content
4760 );
4761 let [ContentBlock::Text(text)] = output.content.as_slice() else {
4762 return;
4763 };
4764 assert_eq!(text.text, "hello pi");
4765
4766 let details = output.details.expect("details present");
4767 assert_eq!(
4768 details.get("from").and_then(serde_json::Value::as_str),
4769 Some("extension")
4770 );
4771 });
4772 }
4773
4774 #[test]
4775 fn agent_session_enable_extensions_with_no_entries_clears_and_is_noop() {
4776 let runtime = RuntimeBuilder::current_thread()
4777 .build()
4778 .expect("runtime build");
4779
4780 runtime.block_on(async {
4781 let temp_dir = tempfile::tempdir().expect("tempdir");
4782 let provider = Arc::new(NoopProvider);
4783 let tools = ToolRegistry::new(&[], Path::new("."), None);
4784 let agent = Agent::new(provider, tools, AgentConfig::default());
4785 let session = Arc::new(Mutex::new(Session::in_memory()));
4786 let mut agent_session =
4787 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
4788
4789 let dummy_manager = ExtensionManager::new();
4791 agent_session.extensions = Some(crate::extensions::ExtensionRegion::new(dummy_manager.clone()));
4792 agent_session.agent.extensions = Some(dummy_manager.clone());
4793 agent_session.extension_queue_modes = Some(Arc::new(std::sync::Mutex::new(ExtensionQueueModeState::new(
4794 QueueMode::OneAtATime,
4795 QueueMode::OneAtATime,
4796 ))));
4797 agent_session.extension_injected_queue = Some(Arc::new(std::sync::Mutex::new(ExtensionInjectedQueue::default())));
4798
4799 agent_session
4800 .enable_extensions(&[], temp_dir.path(), None, &[])
4801 .await
4802 .expect("empty extension list should be a no-op");
4803
4804 assert!(
4805 agent_session.extensions.is_none(),
4806 "no extension region should be created (and existing should be cleared) for an empty extension list"
4807 );
4808 assert!(
4809 agent_session.agent.extensions.is_none(),
4810 "agent should not report extensions active when nothing was requested"
4811 );
4812 assert!(
4813 agent_session.extension_queue_modes.is_none(),
4814 "empty extension list should clear queue mode mirrors"
4815 );
4816 assert!(
4817 agent_session.extension_injected_queue.is_none(),
4818 "empty extension list should clear injected extension queues"
4819 );
4820 });
4821 }
4822
4823 #[test]
4824 fn agent_session_enable_extensions_rejects_mixed_js_and_native_entries() {
4825 let runtime = RuntimeBuilder::current_thread()
4826 .build()
4827 .expect("runtime build");
4828
4829 runtime.block_on(async {
4830 let temp_dir = tempfile::tempdir().expect("tempdir");
4831 let js_entry = temp_dir.path().join("ext.mjs");
4832 let native_entry = temp_dir.path().join("ext.native.json");
4833 std::fs::write(
4834 &js_entry,
4835 r"
4836 export default function init(_pi) {}
4837 ",
4838 )
4839 .expect("write js extension entry");
4840 std::fs::write(&native_entry, "{}").expect("write native extension descriptor");
4841
4842 let provider = Arc::new(NoopProvider);
4843 let tools = ToolRegistry::new(&[], Path::new("."), None);
4844 let agent = Agent::new(provider, tools, AgentConfig::default());
4845 let session = Arc::new(Mutex::new(Session::in_memory()));
4846 let mut agent_session =
4847 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
4848
4849 let err = agent_session
4850 .enable_extensions(&[], temp_dir.path(), None, &[js_entry, native_entry])
4851 .await
4852 .expect_err("mixed extension runtimes should be rejected");
4853 let msg = err.to_string();
4854 assert!(
4855 msg.contains("Mixed extension runtimes are not supported"),
4856 "unexpected mixed-runtime error message: {msg}"
4857 );
4858 });
4859 }
4860
4861 #[test]
4862 fn extension_send_message_persists_custom_message_entry_when_idle() {
4863 let runtime = RuntimeBuilder::current_thread()
4864 .build()
4865 .expect("runtime build");
4866
4867 runtime.block_on(async {
4868 let temp_dir = tempfile::tempdir().expect("tempdir");
4869 let entry_path = temp_dir.path().join("ext.mjs");
4870 std::fs::write(
4871 &entry_path,
4872 r#"
4873 export default function init(pi) {
4874 pi.registerTool({
4875 name: "emit_message",
4876 label: "emit_message",
4877 description: "emit a custom message",
4878 parameters: { type: "object" },
4879 execute: async () => {
4880 pi.sendMessage({
4881 customType: "note",
4882 content: "hello",
4883 display: true,
4884 details: { from: "test" }
4885 }, {});
4886 return { content: [{ type: "text", text: "ok" }], isError: false };
4887 }
4888 });
4889 }
4890 "#,
4891 )
4892 .expect("write extension entry");
4893
4894 let provider = Arc::new(NoopProvider);
4895 let tools = ToolRegistry::new(&[], Path::new("."), None);
4896 let agent = Agent::new(provider, tools, AgentConfig::default());
4897 let session = Arc::new(Mutex::new(Session::in_memory()));
4898 let mut agent_session = AgentSession::new(
4899 agent,
4900 Arc::clone(&session),
4901 false,
4902 ResolvedCompactionSettings::default(),
4903 );
4904
4905 agent_session
4906 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
4907 .await
4908 .expect("enable extensions");
4909
4910 let tool = agent_session
4911 .agent
4912 .tools
4913 .get("emit_message")
4914 .expect("emit_message registered");
4915
4916 let _ = tool
4917 .execute("call-1", json!({}), None)
4918 .await
4919 .expect("execute tool");
4920
4921 let cx = crate::agent_cx::AgentCx::for_request();
4922 let session_guard = session.lock(cx.cx()).await.expect("lock session");
4923 let messages = session_guard.to_messages_for_current_path();
4924
4925 assert!(
4926 messages.iter().any(|msg| {
4927 matches!(
4928 msg,
4929 Message::Custom(CustomMessage { custom_type, content, display, details, .. })
4930 if custom_type == "note"
4931 && content == "hello"
4932 && *display
4933 && details
4934 .as_ref()
4935 .and_then(|v| v.get("from").and_then(Value::as_str))
4936 .is_some_and(|from| from.eq("test"))
4937 )
4938 }),
4939 "expected custom message to be persisted, got {messages:?}"
4940 );
4941 });
4942 }
4943
4944 #[test]
4945 fn extension_send_message_persists_custom_message_entry_when_idle_after_await() {
4946 let runtime = RuntimeBuilder::current_thread()
4947 .build()
4948 .expect("runtime build");
4949
4950 runtime.block_on(async {
4951 let temp_dir = tempfile::tempdir().expect("tempdir");
4952 let entry_path = temp_dir.path().join("ext.mjs");
4953 std::fs::write(
4954 &entry_path,
4955 r#"
4956 export default function init(pi) {
4957 pi.registerTool({
4958 name: "emit_message",
4959 label: "emit_message",
4960 description: "emit a custom message",
4961 parameters: { type: "object" },
4962 execute: async () => {
4963 await Promise.resolve();
4964 pi.sendMessage({
4965 customType: "note",
4966 content: "hello-after-await",
4967 display: true,
4968 details: { from: "test" }
4969 }, {});
4970 return { content: [{ type: "text", text: "ok" }], isError: false };
4971 }
4972 });
4973 }
4974 "#,
4975 )
4976 .expect("write extension entry");
4977
4978 let provider = Arc::new(NoopProvider);
4979 let tools = ToolRegistry::new(&[], Path::new("."), None);
4980 let agent = Agent::new(provider, tools, AgentConfig::default());
4981 let session = Arc::new(Mutex::new(Session::in_memory()));
4982 let mut agent_session = AgentSession::new(
4983 agent,
4984 Arc::clone(&session),
4985 false,
4986 ResolvedCompactionSettings::default(),
4987 );
4988
4989 agent_session
4990 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
4991 .await
4992 .expect("enable extensions");
4993
4994 let tool = agent_session
4995 .agent
4996 .tools
4997 .get("emit_message")
4998 .expect("emit_message registered");
4999
5000 let _ = tool
5001 .execute("call-1", json!({}), None)
5002 .await
5003 .expect("execute tool");
5004
5005 let cx = crate::agent_cx::AgentCx::for_request();
5006 let session_guard = session.lock(cx.cx()).await.expect("lock session");
5007 let messages = session_guard.to_messages_for_current_path();
5008
5009 assert!(
5010 messages.iter().any(|msg| {
5011 matches!(
5012 msg,
5013 Message::Custom(CustomMessage { custom_type, content, display, details, .. })
5014 if custom_type == "note"
5015 && content == "hello-after-await"
5016 && *display
5017 && details
5018 .as_ref()
5019 .and_then(|v| v.get("from").and_then(Value::as_str))
5020 .is_some_and(|from| from.eq("test"))
5021 )
5022 }),
5023 "expected custom message to be persisted, got {messages:?}"
5024 );
5025 });
5026 }
5027
    /// `send_message` must honour an already-cancelled ambient `Cx`.
    ///
    /// While the session mutex is held elsewhere, the append should fail with a
    /// "mutex lock cancelled" error within the timeout, and no message may be
    /// persisted.
    #[test]
    fn agent_host_actions_send_message_inherits_cancelled_context_when_locked() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let actions = AgentSessionHostActions {
                session: Arc::clone(&session),
                injected: Arc::new(StdMutex::new(ExtensionInjectedQueue::default())),
                is_streaming: Arc::new(AtomicBool::new(false)),
                is_turn_active: Arc::new(AtomicBool::new(false)),
                pending_idle_actions: Arc::new(StdMutex::new(VecDeque::new())),
                ai_completion: Arc::new(StdMutex::new(ExtensionAiCompletionHostState {
                    provider: Arc::new(NoopProvider),
                    stream_options: StreamOptions::default(),
                    models: Vec::new(),
                })),
            };

            // Hold the session lock so the append inside `send_message` cannot
            // acquire it immediately.
            let hold_cx = crate::agent_cx::AgentCx::for_request();
            let held_guard = session.lock(hold_cx.cx()).await.expect("lock session");

            // Install a cancelled ambient context. `_current` is bound (not `_`), so
            // whatever it guards lives until the end of this async block.
            // NOTE(review): presumably dropping it restores the previous ambient Cx —
            // confirm against asupersync's `Cx::set_current`.
            let ambient_cx = asupersync::Cx::for_testing();
            ambient_cx.set_cancel_requested(true);
            let _current = asupersync::Cx::set_current(Some(ambient_cx));
            // Bound the wait. If cancellation were not inherited, the helper would
            // presumably block on the held lock and this timeout would fire instead.
            let inner = asupersync::time::timeout(
                asupersync::time::wall_now(),
                Duration::from_millis(100),
                actions.send_message(ExtensionSendMessage {
                    extension_id: Some("ext".to_string()),
                    custom_type: "note".to_string(),
                    content: "blocked".to_string(),
                    display: false,
                    details: None,
                    deliver_as: Some(ExtensionDeliverAs::NextTurn),
                    trigger_turn: false,
                }),
            )
            .await;
            // Outer layer: the timeout did not elapse. Inner layer: the call itself failed.
            let outcome = inner.expect("cancelled helper should finish before timeout");
            let err = outcome.expect_err("session append should fail under inherited cancellation");
            assert!(
                err.to_string().contains("mutex lock cancelled"),
                "unexpected error: {err}"
            );

            // Release the lock and confirm the cancelled call left the session empty.
            drop(held_guard);

            let cx = crate::agent_cx::AgentCx::for_request();
            let guard = session.lock(cx.cx()).await.expect("lock session");
            assert!(
                guard.to_messages_for_current_path().is_empty(),
                "cancelled send_message should not append a message"
            );
        });
    }
5086
    /// Snapshot of the provider `Context` fields recorded by
    /// `PiAiCaptureProvider::stream` on each call.
    #[derive(Debug, Default)]
    struct PiAiCapturedProviderContext {
        /// System prompt handed to the provider, if any.
        system_prompt: Option<String>,
        /// Clone of the conversation messages handed to the provider.
        messages: Vec<Message>,
    }
5092
    /// Test provider that records every `stream` call's context and replies with
    /// a fixed "captured" assistant message.
    #[derive(Debug)]
    struct PiAiCaptureProvider {
        /// Shared log of captured contexts, one entry per `stream` call. Tests keep
        /// their own `Arc` clone so they can inspect it afterwards.
        calls: Arc<StdMutex<Vec<PiAiCapturedProviderContext>>>,
    }
5097
5098 #[async_trait]
5099 impl Provider for PiAiCaptureProvider {
5100 fn name(&self) -> &'static str {
5101 "capturing-provider"
5102 }
5103
5104 fn api(&self) -> &'static str {
5105 "test-api"
5106 }
5107
5108 fn model_id(&self) -> &'static str {
5109 "capture-model"
5110 }
5111
5112 async fn stream(
5113 &self,
5114 context: &Context<'_>,
5115 _options: &StreamOptions,
5116 ) -> crate::error::Result<
5117 std::pin::Pin<
5118 Box<dyn futures::Stream<Item = crate::error::Result<StreamEvent>> + Send>,
5119 >,
5120 > {
5121 self.calls
5122 .lock()
5123 .unwrap_or_else(std::sync::PoisonError::into_inner)
5124 .push(PiAiCapturedProviderContext {
5125 system_prompt: context.system_prompt.as_ref().map(ToString::to_string),
5126 messages: context.messages.iter().cloned().collect(),
5127 });
5128 let final_message = AssistantMessage {
5129 content: vec![ContentBlock::Text(TextContent::new("captured"))],
5130 api: "test-api".to_string(),
5131 provider: "capturing-provider".to_string(),
5132 model: "capture-model".to_string(),
5133 usage: Usage::default(),
5134 stop_reason: StopReason::Stop,
5135 error_message: None,
5136 timestamp: 0,
5137 };
5138 Ok(Box::pin(futures::stream::iter(vec![Ok(
5139 StreamEvent::Done {
5140 reason: StopReason::Stop,
5141 message: final_message,
5142 },
5143 )])))
5144 }
5145 }
5146
5147 #[test]
5148 fn agent_host_actions_complete_ai_streams_configured_provider() {
5149 let runtime = RuntimeBuilder::current_thread()
5150 .build()
5151 .expect("runtime build");
5152
5153 runtime.block_on(async {
5154 let session = Arc::new(Mutex::new(Session::in_memory()));
5155 let calls = Arc::new(StdMutex::new(Vec::new()));
5156 let provider = Arc::new(PiAiCaptureProvider {
5157 calls: Arc::clone(&calls),
5158 });
5159 let actions = AgentSessionHostActions {
5160 session,
5161 injected: Arc::new(StdMutex::new(ExtensionInjectedQueue::default())),
5162 is_streaming: Arc::new(AtomicBool::new(false)),
5163 is_turn_active: Arc::new(AtomicBool::new(false)),
5164 pending_idle_actions: Arc::new(StdMutex::new(VecDeque::new())),
5165 ai_completion: Arc::new(StdMutex::new(ExtensionAiCompletionHostState {
5166 provider,
5167 stream_options: StreamOptions::default(),
5168 models: vec![json!({
5169 "id": "capture-model",
5170 "provider": "capturing-provider",
5171 "api": "test-api",
5172 })],
5173 })),
5174 };
5175
5176 let result = actions
5177 .complete_ai(ExtensionAiCompletionRequest {
5178 model: json!({ "id": "capture-model" }),
5179 context: json!({
5180 "systemPrompt": "answer tersely",
5181 "messages": [
5182 { "role": "user", "content": "ping" }
5183 ]
5184 }),
5185 options: json!({ "maxTokens": 16 }),
5186 simple: false,
5187 })
5188 .await
5189 .expect("complete through provider");
5190
5191 assert_eq!(result["text"], json!("captured"));
5192 assert_eq!(result["provider"], json!("capturing-provider"));
5193 assert_eq!(result["api"], json!("test-api"));
5194
5195 let (captured_len, captured_system_prompt, captured_messages) = {
5196 let captured = match calls.lock() {
5197 Ok(guard) => guard,
5198 Err(poisoned) => poisoned.into_inner(),
5199 };
5200 (
5201 captured.len(),
5202 captured.first().and_then(|call| call.system_prompt.clone()),
5203 captured
5204 .first()
5205 .map(|call| call.messages.clone())
5206 .unwrap_or_default(),
5207 )
5208 };
5209 assert_eq!(captured_len, 1);
5210 assert_eq!(captured_system_prompt.as_deref(), Some("answer tersely"));
5211 assert_eq!(captured_messages.len(), 1);
5212 assert!(
5213 matches!(
5214 captured_messages.first(),
5215 Some(Message::User(UserMessage { content: UserContent::Text(text), .. }))
5216 if text == "ping"
5217 ),
5218 "expected user message context, got {captured_messages:?}"
5219 );
5220
5221 let models = actions.list_ai_models().await.expect("list models");
5222 assert_eq!(models[0]["id"], json!("capture-model"));
5223 });
5224 }
5225
    /// An extension command whose handler sends a custom message with
    /// `triggerTurn: true` while the agent is idle must persist that message and
    /// run an agent turn that appends an assistant reply.
    #[test]
    fn extension_command_send_message_trigger_turn_runs_agent_turn_when_idle() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let temp_dir = tempfile::tempdir().expect("tempdir");
            let entry_path = temp_dir.path().join("ext.mjs");
            std::fs::write(
                &entry_path,
                r#"
export default function init(pi) {
  pi.registerCommand("emit-now", {
    description: "emit a custom message and trigger a turn",
    handler: async () => {
      await pi.events("sendMessage", {
        message: {
          customType: "note",
          content: "turn-now",
          display: true
        },
        options: {
          deliverAs: "steer",
          triggerTurn: true
        }
      });
      return "queued";
    }
  });
}
"#,
            )
            .expect("write extension entry");

            // NOTE(review): `IdleCommandProvider` is defined elsewhere. The final
            // assertion assumes it replies with "resumed-response-0" — confirm.
            let provider = Arc::new(IdleCommandProvider);
            let tools = ToolRegistry::new(&[], Path::new("."), None);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session = AgentSession::new(
                agent,
                Arc::clone(&session),
                false,
                ResolvedCompactionSettings::default(),
            );

            agent_session
                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
                .await
                .expect("enable extensions");

            // The handler's return value ("queued") comes back as the command result.
            let value = agent_session
                .execute_extension_command("emit-now", "", 5_000, |_| {})
                .await
                .expect("execute extension command");
            assert_eq!(value.as_str(), Some("queued"));

            // Inspect the session only after the command (and any turn it
            // triggered) has completed.
            let cx = crate::agent_cx::AgentCx::for_request();
            let session_guard = session.lock(cx.cx()).await.expect("lock session");
            let messages = session_guard.to_messages_for_current_path();

            assert!(
                messages.iter().any(|msg| {
                    matches!(
                        msg,
                        Message::Custom(CustomMessage { custom_type, content, .. })
                            if custom_type == "note" && content == "turn-now"
                    )
                }),
                "expected custom message prompt in session, got {messages:?}"
            );
            assert!(
                messages.iter().any(|msg| {
                    matches!(
                        msg,
                        Message::Assistant(assistant)
                            if assistant.content.iter().any(|block| matches!(
                                block,
                                ContentBlock::Text(TextContent { text, .. })
                                    if text.as_str().eq("resumed-response-0")
                            ))
                    )
                }),
                "expected assistant response after triggered turn, got {messages:?}"
            );
        });
    }
5313
5314 #[test]
5315 fn agent_extension_session_get_state_reports_agent_runtime_state() {
5316 let runtime = RuntimeBuilder::current_thread()
5317 .build()
5318 .expect("runtime build");
5319
5320 runtime.block_on(async {
5321 let mut session = Session::in_memory();
5322 session.set_model_header(
5323 Some("test-provider".to_string()),
5324 Some("test-model".to_string()),
5325 Some("high".to_string()),
5326 );
5327 session.append_message(crate::session::SessionMessage::User {
5328 content: UserContent::Text("hello".to_string()),
5329 timestamp: Some(1),
5330 });
5331 let session = Arc::new(Mutex::new(session));
5332
5333 let extension_session = AgentExtensionSession {
5334 handle: SessionHandle(Arc::clone(&session)),
5335 is_streaming: Arc::new(AtomicBool::new(true)),
5336 is_compacting: Arc::new(AtomicBool::new(true)),
5337 queue_modes: Arc::new(StdMutex::new(ExtensionQueueModeState::new(
5338 QueueMode::All,
5339 QueueMode::OneAtATime,
5340 ))),
5341 auto_compaction_enabled: true,
5342 };
5343
5344 let state = <AgentExtensionSession as crate::extensions::ExtensionSession>::get_state(
5345 &extension_session,
5346 )
5347 .await;
5348
5349 assert_eq!(state["model"]["provider"], "test-provider");
5350 assert_eq!(state["model"]["id"], "test-model");
5351 assert_eq!(state["thinkingLevel"], "high");
5352 assert_eq!(state["isStreaming"], true);
5353 assert_eq!(state["isCompacting"], true);
5354 assert_eq!(state["steeringMode"], "all");
5355 assert_eq!(state["followUpMode"], "one-at-a-time");
5356 assert_eq!(state["autoCompactionEnabled"], true);
5357 assert_eq!(state["messageCount"], 1);
5358 });
5359 }
5360
    /// `get_state` must report the model and thinking level of the branch that is
    /// currently navigated to, not the most recently written session header.
    #[test]
    fn agent_extension_session_get_state_uses_branch_local_model_and_thinking() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            // Branch A: root user message, then openai/gpt-4o with "low" thinking.
            let mut session = Session::in_memory();
            let root_id = session.append_message(crate::session::SessionMessage::User {
                content: UserContent::Text("root".to_string()),
                timestamp: Some(1),
            });
            session.append_model_change("openai".to_string(), "gpt-4o".to_string());
            let branch_a_thinking = session.append_thinking_level_change("low".to_string());
            session.set_model_header(
                Some("openai".to_string()),
                Some("gpt-4o".to_string()),
                Some("low".to_string()),
            );

            // Branch B, forked from the root: anthropic/claude-sonnet-4-5 with "high"
            // thinking. The header is overwritten with branch B's values.
            assert!(session.create_branch_from(&root_id));
            session.append_model_change("anthropic".to_string(), "claude-sonnet-4-5".to_string());
            session.append_thinking_level_change("high".to_string());
            session.set_model_header(
                Some("anthropic".to_string()),
                Some("claude-sonnet-4-5".to_string()),
                Some("high".to_string()),
            );

            // Return to branch A's tip. The header still holds branch B's values, so
            // the assertions below only pass if state comes from the current branch.
            assert!(session.navigate_to(&branch_a_thinking));
            let session = Arc::new(Mutex::new(session));

            let extension_session = AgentExtensionSession {
                handle: SessionHandle(Arc::clone(&session)),
                is_streaming: Arc::new(AtomicBool::new(false)),
                is_compacting: Arc::new(AtomicBool::new(false)),
                queue_modes: Arc::new(StdMutex::new(ExtensionQueueModeState::new(
                    QueueMode::OneAtATime,
                    QueueMode::OneAtATime,
                ))),
                auto_compaction_enabled: false,
            };

            let state = <AgentExtensionSession as crate::extensions::ExtensionSession>::get_state(
                &extension_session,
            )
            .await;

            assert_eq!(state["model"]["provider"], "openai");
            assert_eq!(state["model"]["id"], "gpt-4o");
            assert_eq!(state["thinkingLevel"], "low");
        });
    }
5414
5415 #[test]
5416 fn agent_session_set_queue_modes_updates_extension_delivery_state() {
5417 let provider = Arc::new(NoopProvider);
5418 let tools = ToolRegistry::new(&[], Path::new("."), None);
5419 let agent = Agent::new(provider, tools, AgentConfig::default());
5420 let session = Arc::new(Mutex::new(Session::in_memory()));
5421 let mut agent_session =
5422 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
5423
5424 let queue_modes = Arc::new(StdMutex::new(ExtensionQueueModeState::new(
5425 QueueMode::OneAtATime,
5426 QueueMode::OneAtATime,
5427 )));
5428 let injected_queue = Arc::new(StdMutex::new(ExtensionInjectedQueue::new(
5429 QueueMode::OneAtATime,
5430 QueueMode::OneAtATime,
5431 )));
5432 agent_session.extension_queue_modes = Some(Arc::clone(&queue_modes));
5433 agent_session.extension_injected_queue = Some(Arc::clone(&injected_queue));
5434
5435 agent_session.set_queue_modes(QueueMode::All, QueueMode::All);
5436
5437 assert_eq!(
5438 agent_session.agent.queue_modes(),
5439 (QueueMode::All, QueueMode::All)
5440 );
5441 let mirrored = queue_modes.lock().expect("lock queue mode mirror");
5442 assert_eq!(mirrored.steering_mode, QueueMode::All);
5443 assert_eq!(mirrored.follow_up_mode, QueueMode::All);
5444 drop(mirrored);
5445
5446 let queued_follow_up_len = {
5447 let mut queue = injected_queue.lock().expect("lock injected queue");
5448 queue.push_follow_up(Message::User(UserMessage {
5449 content: UserContent::Text("first".to_string()),
5450 timestamp: 0,
5451 }));
5452 queue.push_follow_up(Message::User(UserMessage {
5453 content: UserContent::Text("second".to_string()),
5454 timestamp: 0,
5455 }));
5456 queue.pop_follow_up().len()
5457 };
5458 assert_eq!(
5459 queued_follow_up_len, 2,
5460 "updated queue modes should apply to extension-injected follow-ups"
5461 );
5462 }
5463
    /// An extension command that injects a user message via `sendUserMessage`
    /// while the agent is idle must persist that user message and run an agent
    /// turn that appends an assistant reply.
    #[test]
    fn extension_command_send_user_message_runs_agent_turn_when_idle() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async {
            let temp_dir = tempfile::tempdir().expect("tempdir");
            let entry_path = temp_dir.path().join("ext.mjs");
            std::fs::write(
                &entry_path,
                r#"
export default function init(pi) {
  pi.registerCommand("inject-user", {
    description: "inject a user message",
    handler: async () => {
      await pi.events("sendUserMessage", {
        text: "Please review the changes"
      });
      return "queued";
    }
  });
}
"#,
            )
            .expect("write extension entry");

            // NOTE(review): `IdleCommandProvider` is defined elsewhere. The final
            // assertion assumes it replies with "resumed-response-0" — confirm.
            let provider = Arc::new(IdleCommandProvider);
            let tools = ToolRegistry::new(&[], Path::new("."), None);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session = AgentSession::new(
                agent,
                Arc::clone(&session),
                false,
                ResolvedCompactionSettings::default(),
            );

            agent_session
                .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
                .await
                .expect("enable extensions");

            // The handler's return value ("queued") comes back as the command result.
            let value = agent_session
                .execute_extension_command("inject-user", "", 5_000, |_| {})
                .await
                .expect("execute extension command");
            assert_eq!(value.as_str(), Some("queued"));

            // Inspect the session only after the command (and any turn it
            // triggered) has completed.
            let cx = crate::agent_cx::AgentCx::for_request();
            let session_guard = session.lock(cx.cx()).await.expect("lock session");
            let messages = session_guard.to_messages_for_current_path();

            assert!(
                messages.iter().any(|msg| {
                    matches!(
                        msg,
                        Message::User(UserMessage {
                            content: UserContent::Text(text),
                            ..
                        }) if text == "Please review the changes"
                    )
                }),
                "expected injected user message in session, got {messages:?}"
            );
            assert!(
                messages.iter().any(|msg| {
                    matches!(
                        msg,
                        Message::Assistant(assistant)
                            if assistant.content.iter().any(|block| matches!(
                                block,
                                ContentBlock::Text(TextContent { text, .. })
                                    if text.as_str().eq("resumed-response-0")
                            ))
                    )
                }),
                "expected assistant response after injected user turn, got {messages:?}"
            );
        });
    }
5545
5546 #[test]
5547 fn send_user_message_steer_skips_remaining_tools() {
5548 let runtime = RuntimeBuilder::current_thread()
5549 .build()
5550 .expect("runtime build");
5551
5552 runtime.block_on(async {
5553 let temp_dir = tempfile::tempdir().expect("tempdir");
5554 let entry_path = temp_dir.path().join("ext.mjs");
5555 std::fs::write(
5556 &entry_path,
5557 r#"
5558 export default function init(pi) {
5559 let sent = false;
5560 pi.on("tool_call", async (event) => {
5561 if (sent) return {};
5562 if (Object.is(event && event.toolName, "count_tool")) {
5563 sent = true;
5564 await pi.events("sendUserMessage", {
5565 text: "steer-now",
5566 options: { deliverAs: "steer" }
5567 });
5568 }
5569 return {};
5570 });
5571 }
5572 "#,
5573 )
5574 .expect("write extension entry");
5575
5576 let provider = Arc::new(ToolUseProvider::new());
5577 let calls = Arc::new(AtomicUsize::new(0));
5578 let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
5579 calls: Arc::clone(&calls),
5580 })]);
5581 let agent = Agent::new(provider, tools, AgentConfig::default());
5582 let session = Arc::new(Mutex::new(Session::in_memory()));
5583 let mut agent_session =
5584 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
5585
5586 agent_session
5587 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
5588 .await
5589 .expect("enable extensions");
5590
5591 let _ = agent_session
5592 .run_text("go".to_string(), |_| {})
5593 .await
5594 .expect("run_text");
5595
5596 assert_eq!(calls.load(Ordering::SeqCst), 1);
5598 });
5599 }
5600
5601 #[test]
5602 fn send_user_message_follow_up_does_not_skip_tools() {
5603 let runtime = RuntimeBuilder::current_thread()
5604 .build()
5605 .expect("runtime build");
5606
5607 runtime.block_on(async {
5608 let temp_dir = tempfile::tempdir().expect("tempdir");
5609 let entry_path = temp_dir.path().join("ext.mjs");
5610 std::fs::write(
5611 &entry_path,
5612 r#"
5613 export default function init(pi) {
5614 let sent = false;
5615 pi.on("tool_call", async (event) => {
5616 if (sent) return {};
5617 if (Object.is(event && event.toolName, "count_tool")) {
5618 sent = true;
5619 await pi.events("sendUserMessage", {
5620 text: "follow-up",
5621 options: { deliverAs: "followUp" }
5622 });
5623 }
5624 return {};
5625 });
5626 }
5627 "#,
5628 )
5629 .expect("write extension entry");
5630
5631 let provider = Arc::new(ToolUseProvider::new());
5632 let calls = Arc::new(AtomicUsize::new(0));
5633 let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
5634 calls: Arc::clone(&calls),
5635 })]);
5636 let agent = Agent::new(provider, tools, AgentConfig::default());
5637 let session = Arc::new(Mutex::new(Session::in_memory()));
5638 let mut agent_session =
5639 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
5640
5641 agent_session
5642 .enable_extensions(&[], temp_dir.path(), None, &[entry_path])
5643 .await
5644 .expect("enable extensions");
5645
5646 let _ = agent_session
5647 .run_text("go".to_string(), |_| {})
5648 .await
5649 .expect("run_text");
5650
5651 assert_eq!(calls.load(Ordering::SeqCst), 2);
5652 });
5653 }
5654
5655 fn test_turn_latency() -> SharedTurnLatencyAccumulator {
5656 Arc::new(StdMutex::new(TurnLatencyAccumulator::started()))
5657 }
5658
5659 #[test]
5660 fn latency_breakdown_reports_component_tail_percentiles() {
5661 let breakdown =
5662 TurnLatencyBreakdown::from_component_samples(250, &[10, 30, 20], &[40, 5], &[2], &[]);
5663
5664 assert_eq!(breakdown.schema, TURN_LATENCY_BREAKDOWN_SCHEMA_V1);
5665 assert_eq!(breakdown.provider_streaming.duration_ms, 60);
5666 assert_eq!(breakdown.provider_streaming.samples, 3);
5667 assert_eq!(breakdown.provider_streaming.tail_percentiles.p50_ms, 20);
5668 assert_eq!(breakdown.provider_streaming.tail_percentiles.p95_ms, 30);
5669 assert_eq!(breakdown.provider_streaming.tail_percentiles.p99_ms, 30);
5670 assert_eq!(breakdown.provider_streaming.tail_percentiles.p999_ms, 30);
5671 assert_eq!(breakdown.local_tools.duration_ms, 45);
5672 assert_eq!(breakdown.extension_hostcalls.duration_ms, 2);
5673 assert_eq!(breakdown.persistence.duration_ms, 0);
5674 assert_eq!(breakdown.dominant_component, "provider_streaming");
5675 }
5676
5677 #[test]
5678 fn latency_breakdown_serializes_without_provider_secrets() {
5679 let breakdown =
5680 TurnLatencyBreakdown::from_component_samples(125, &[100], &[20], &[5], &[0]);
5681 let serialized = serde_json::to_string(&breakdown).expect("serialize latency breakdown");
5682
5683 assert!(serialized.contains(TURN_LATENCY_BREAKDOWN_SCHEMA_V1));
5684 assert!(serialized.contains("providerStreaming"));
5685 assert!(serialized.contains("localTools"));
5686 assert!(serialized.contains("extensionHostcalls"));
5687 assert!(serialized.contains("persistence"));
5688 assert!(!serialized.contains("api_key"));
5689 assert!(!serialized.contains("authorization"));
5690 assert!(!serialized.contains("bearer"));
5691 assert!(!serialized.contains("sk-"));
5692 }
5693
#[test]
fn tool_call_hook_can_block_tool_execution() {
    // A `tool_call` hook answering `{ block: true, reason }` must stop the tool before it
    // runs and surface the reason as an error tool output without details.
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let workspace = tempfile::tempdir().expect("tempdir");
        let extension_path = workspace.path().join("ext.mjs");
        let extension_source = r#"
export default function init(pi) {
pi.on("tool_call", async (event) => {
if (Object.is(event && event.toolName, "count_tool")) {
return { block: true, reason: "blocked in test" };
}
return {};
});
}
"#;
        std::fs::write(&extension_path, extension_source).expect("write extension entry");

        let tool_runs = Arc::new(AtomicUsize::new(0));
        let registry = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&tool_runs),
        })]);
        let agent = Agent::new(Arc::new(NoopProvider), registry, AgentConfig::default());
        let mut agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );
        agent_session
            .enable_extensions(&[], workspace.path(), None, &[extension_path])
            .await
            .expect("enable extensions");

        let call = ToolCall {
            id: "call-1".to_string(),
            name: "count_tool".to_string(),
            arguments: json!({}),
            thought_signature: None,
        };
        let ignore_events: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session
            .agent
            .execute_tool(call, ignore_events, test_turn_latency())
            .await;

        assert!(is_error);
        assert!(output.is_error);
        assert_eq!(tool_runs.load(Ordering::SeqCst), 0);
        assert_eq!(output.details, None);

        let [ContentBlock::Text(text)] = output.content.as_slice() else {
            panic!("Expected text output, got {:?}", output.content);
        };
        assert_eq!(text.text, "Tool execution blocked: blocked in test");
    });
}
5761
#[test]
fn tool_call_hook_errors_fail_open() {
    // Under the default configuration a throwing `tool_call` hook must not stop the tool:
    // the call succeeds and the tool executes exactly once.
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let workspace = tempfile::tempdir().expect("tempdir");
        let extension_path = workspace.path().join("ext.mjs");
        let extension_source = r#"
export default function init(pi) {
pi.on("tool_call", async (_event) => {
throw new Error("boom");
});
}
"#;
        std::fs::write(&extension_path, extension_source).expect("write extension entry");

        let tool_runs = Arc::new(AtomicUsize::new(0));
        let registry = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&tool_runs),
        })]);
        let agent = Agent::new(Arc::new(NoopProvider), registry, AgentConfig::default());
        let mut agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );
        agent_session
            .enable_extensions(&[], workspace.path(), None, &[extension_path])
            .await
            .expect("enable extensions");

        let call = ToolCall {
            id: "call-1".to_string(),
            name: "count_tool".to_string(),
            arguments: json!({}),
            thought_signature: None,
        };
        let ignore_events: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session
            .agent
            .execute_tool(call, ignore_events, test_turn_latency())
            .await;

        assert!(!is_error);
        assert!(!output.is_error);
        assert_eq!(tool_runs.load(Ordering::SeqCst), 1);
    });
}
5816
#[test]
fn tool_call_hook_errors_fail_closed_when_configured() {
    // With `fail_closed_hooks` enabled, a throwing `tool_call` hook must block the tool:
    // the call errors, the tool never runs, and the output names the hook failure.
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let workspace = tempfile::tempdir().expect("tempdir");
        let extension_path = workspace.path().join("ext.mjs");
        let extension_source = r#"
export default function init(pi) {
pi.on("tool_call", async (_event) => {
throw new Error("boom");
});
}
"#;
        std::fs::write(&extension_path, extension_source).expect("write extension entry");

        let tool_runs = Arc::new(AtomicUsize::new(0));
        let registry = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&tool_runs),
        })]);
        let fail_closed_config = AgentConfig {
            fail_closed_hooks: true,
            ..AgentConfig::default()
        };
        let agent = Agent::new(Arc::new(NoopProvider), registry, fail_closed_config);
        let mut agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );
        agent_session
            .enable_extensions(&[], workspace.path(), None, &[extension_path])
            .await
            .expect("enable extensions");

        let call = ToolCall {
            id: "call-1".to_string(),
            name: "count_tool".to_string(),
            arguments: json!({}),
            thought_signature: None,
        };
        let ignore_events: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session
            .agent
            .execute_tool(call, ignore_events, test_turn_latency())
            .await;

        assert!(is_error);
        assert!(output.is_error);
        assert_eq!(tool_runs.load(Ordering::SeqCst), 0);

        let [ContentBlock::Text(text)] = output.content.as_slice() else {
            panic!("Expected text output, got {:?}", output.content);
        };
        assert_eq!(text.text, "Tool execution blocked: extension hook failed");
    });
}
5887
#[test]
fn tool_call_hook_absent_allows_tool_execution() {
    // An extension that registers no `tool_call` hook must leave tool execution untouched.
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let workspace = tempfile::tempdir().expect("tempdir");
        let extension_path = workspace.path().join("ext.mjs");
        let extension_source = r"
export default function init(_pi) {}
";
        std::fs::write(&extension_path, extension_source).expect("write extension entry");

        let tool_runs = Arc::new(AtomicUsize::new(0));
        let registry = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&tool_runs),
        })]);
        let agent = Agent::new(Arc::new(NoopProvider), registry, AgentConfig::default());
        let mut agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );
        agent_session
            .enable_extensions(&[], workspace.path(), None, &[extension_path])
            .await
            .expect("enable extensions");

        let call = ToolCall {
            id: "call-1".to_string(),
            name: "count_tool".to_string(),
            arguments: json!({}),
            thought_signature: None,
        };
        let ignore_events: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session
            .agent
            .execute_tool(call, ignore_events, test_turn_latency())
            .await;

        assert!(!is_error);
        assert!(!output.is_error);
        assert_eq!(tool_runs.load(Ordering::SeqCst), 1);
    });
}
5938
#[test]
fn tool_approval_allow_executes_tool() {
    // An approval callback answering `Allow` must be consulted exactly once, let the tool
    // run, and emit a `ToolExecutionUpdate` whose details report the "approved" status.
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let tool_runs = Arc::new(AtomicUsize::new(0));
        let approval_calls = Arc::new(AtomicUsize::new(0));
        let approval_counter = Arc::clone(&approval_calls);
        let config = AgentConfig {
            tool_approval: Some(Arc::new(move |request| {
                assert_eq!(request.tool_call_id, "call-1");
                assert_eq!(request.tool_name, "count_tool");
                approval_counter.fetch_add(1, Ordering::SeqCst);
                Box::pin(async { ToolApprovalDecision::Allow })
            })),
            ..AgentConfig::default()
        };
        let registry = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&tool_runs),
        })]);
        let agent = Agent::new(Arc::new(NoopProvider), registry, config);
        let agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );

        // Record every emitted event so the approval status update can be inspected later.
        let recorded: Arc<std::sync::Mutex<Vec<AgentEvent>>> = Arc::default();
        let sink = Arc::clone(&recorded);
        let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(move |event| {
            if let Ok(mut events) = sink.lock() {
                events.push(event);
            }
        });

        let call = ToolCall {
            id: "call-1".to_string(),
            name: "count_tool".to_string(),
            arguments: json!({}),
            thought_signature: None,
        };
        let (output, is_error) = agent_session
            .agent
            .execute_tool(call, on_event, test_turn_latency())
            .await;

        assert!(!is_error);
        assert!(!output.is_error);
        assert_eq!(approval_calls.load(Ordering::SeqCst), 1);
        assert_eq!(tool_runs.load(Ordering::SeqCst), 1);

        let is_approval_update = |event: &AgentEvent| {
            matches!(
                event,
                AgentEvent::ToolExecutionUpdate { partial_result, .. }
                    if partial_result.details.as_ref().is_some_and(|details| {
                        details["schema"] == TOOL_APPROVAL_STATUS_SCHEMA_V1
                            && details["status"] == "approved"
                    })
            )
        };
        let saw_approval_update = recorded
            .lock()
            .is_ok_and(|events| events.iter().any(is_approval_update));
        assert!(saw_approval_update);
    });
}
6010
#[test]
fn tool_approval_deny_blocks_tool_execution() {
    // An approval callback answering `deny(reason)` must be consulted exactly once, keep
    // the tool from running, and produce an error output that carries the denial schema
    // in its details and the denial reason in its text.
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let provider = Arc::new(NoopProvider);
        let calls = Arc::new(AtomicUsize::new(0));
        let tools = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&calls),
        })]);
        // Count approval invocations, mirroring the allow-path test, so the test proves
        // the denial came from the callback rather than from some other gate.
        let approval_calls = Arc::new(AtomicUsize::new(0));
        let approval_counter = Arc::clone(&approval_calls);
        let agent = Agent::new(
            provider,
            tools,
            AgentConfig {
                tool_approval: Some(Arc::new(move |request| {
                    assert_eq!(request.tool_name, "count_tool");
                    approval_counter.fetch_add(1, Ordering::SeqCst);
                    Box::pin(async { ToolApprovalDecision::deny("denied by approval test") })
                })),
                ..AgentConfig::default()
            },
        );
        let session = Arc::new(Mutex::new(Session::in_memory()));
        let agent_session =
            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

        let tool_call = ToolCall {
            id: "call-1".to_string(),
            name: "count_tool".to_string(),
            arguments: json!({}),
            thought_signature: None,
        };

        let on_event: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session
            .agent
            .execute_tool(tool_call, on_event, test_turn_latency())
            .await;

        assert!(is_error);
        assert!(output.is_error);
        assert_eq!(
            approval_calls.load(Ordering::SeqCst),
            1,
            "approval callback should be consulted exactly once"
        );
        assert_eq!(calls.load(Ordering::SeqCst), 0);

        let details = output
            .details
            .as_ref()
            .expect("denied tool output should carry approval details");
        assert_eq!(details["schema"], TOOL_APPROVAL_DENIED_SCHEMA_V1);
        assert!(
            matches!(output.content.as_slice(), [ContentBlock::Text(text)] if text
                .text
                .contains("denied by approval test")),
            "Expected denial reason in text output, got {:?}",
            output.content
        );
    });
}
6065
#[test]
fn tool_call_hook_returns_empty_allows_tool_execution() {
    // A `tool_call` hook resolving to `{}` expresses no opinion, so the tool must run.
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let workspace = tempfile::tempdir().expect("tempdir");
        let extension_path = workspace.path().join("ext.mjs");
        let extension_source = r#"
export default function init(pi) {
pi.on("tool_call", async (_event) => ({}));
}
"#;
        std::fs::write(&extension_path, extension_source).expect("write extension entry");

        let tool_runs = Arc::new(AtomicUsize::new(0));
        let registry = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&tool_runs),
        })]);
        let agent = Agent::new(Arc::new(NoopProvider), registry, AgentConfig::default());
        let mut agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );
        agent_session
            .enable_extensions(&[], workspace.path(), None, &[extension_path])
            .await
            .expect("enable extensions");

        let call = ToolCall {
            id: "call-1".to_string(),
            name: "count_tool".to_string(),
            arguments: json!({}),
            thought_signature: None,
        };
        let ignore_events: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session
            .agent
            .execute_tool(call, ignore_events, test_turn_latency())
            .await;

        assert!(!is_error);
        assert!(!output.is_error);
        assert_eq!(tool_runs.load(Ordering::SeqCst), 1);
    });
}
6118
#[test]
fn tool_call_hook_can_block_bash_tool_execution() {
    // A hook blocking the built-in `bash` tool must prevent the shell command from running,
    // which is checked via the absence of the file the command would have written.
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let workspace = tempfile::tempdir().expect("tempdir");
        let extension_path = workspace.path().join("ext.mjs");
        let marker_file = workspace.path().join("blocked.txt");
        let extension_source = r#"
export default function init(pi) {
pi.on("tool_call", async (event) => {
const name = event && event.toolName ? String(event.toolName) : "";
if (name === "bash") return { block: true, reason: "blocked bash in test" };
return {};
});
}
"#;
        std::fs::write(&extension_path, extension_source).expect("write extension entry");

        let registry = ToolRegistry::new(&["bash"], workspace.path(), None);
        let agent = Agent::new(Arc::new(NoopProvider), registry, AgentConfig::default());
        let mut agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );
        agent_session
            .enable_extensions(&["bash"], workspace.path(), None, &[extension_path])
            .await
            .expect("enable extensions");

        let call = ToolCall {
            id: "call-1".to_string(),
            name: "bash".to_string(),
            arguments: json!({ "command": "printf 'hi' > blocked.txt" }),
            thought_signature: None,
        };
        let ignore_events: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session
            .agent
            .execute_tool(call, ignore_events, test_turn_latency())
            .await;

        assert!(is_error);
        assert!(output.is_error);
        assert_eq!(output.details, None);
        assert!(
            !marker_file.exists(),
            "expected bash command not to run when blocked"
        );

        let [ContentBlock::Text(text)] = output.content.as_slice() else {
            panic!("Expected text output, got {:?}", output.content);
        };
        assert_eq!(text.text, "Tool execution blocked: blocked bash in test");
    });
}
6184
#[test]
fn tool_result_hook_can_modify_tool_output() {
    // A `tool_result` hook returning replacement `content`/`details` must override what the
    // tool produced, while the tool itself still runs once.
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let workspace = tempfile::tempdir().expect("tempdir");
        let extension_path = workspace.path().join("ext.mjs");
        let extension_source = r#"
export default function init(pi) {
pi.on("tool_result", async (event) => {
if (Object.is(event && event.toolName, "count_tool")) {
return {
content: [{ type: "text", text: "modified" }],
details: { from: "tool_result" }
};
}
return {};
});
}
"#;
        std::fs::write(&extension_path, extension_source).expect("write extension entry");

        let tool_runs = Arc::new(AtomicUsize::new(0));
        let registry = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&tool_runs),
        })]);
        let agent = Agent::new(Arc::new(NoopProvider), registry, AgentConfig::default());
        let mut agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );
        agent_session
            .enable_extensions(&[], workspace.path(), None, &[extension_path])
            .await
            .expect("enable extensions");

        let call = ToolCall {
            id: "call-1".to_string(),
            name: "count_tool".to_string(),
            arguments: json!({}),
            thought_signature: None,
        };
        let ignore_events: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session
            .agent
            .execute_tool(call, ignore_events, test_turn_latency())
            .await;

        assert!(!is_error);
        assert!(!output.is_error);
        assert_eq!(tool_runs.load(Ordering::SeqCst), 1);
        assert_eq!(output.details, Some(json!({ "from": "tool_result" })));

        let [ContentBlock::Text(text)] = output.content.as_slice() else {
            panic!("Expected text output, got {:?}", output.content);
        };
        assert_eq!(text.text, "modified");
    });
}
6255
#[test]
fn tool_result_hook_can_modify_tool_not_found_error() {
    // Calling an unregistered tool yields an error result; a `tool_result` hook must still
    // see it (with `isError` set) and be able to replace its content and details.
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let workspace = tempfile::tempdir().expect("tempdir");
        let extension_path = workspace.path().join("ext.mjs");
        let extension_source = r#"
export default function init(pi) {
pi.on("tool_result", async (event) => {
if (Object.is(event && event.toolName, "missing_tool") && event.isError) {
return {
content: [{ type: "text", text: "overridden" }],
details: { handled: true }
};
}
return {};
});
}
"#;
        std::fs::write(&extension_path, extension_source).expect("write extension entry");

        let empty_registry = ToolRegistry::from_tools(Vec::new());
        let agent = Agent::new(Arc::new(NoopProvider), empty_registry, AgentConfig::default());
        let mut agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );
        agent_session
            .enable_extensions(&[], workspace.path(), None, &[extension_path])
            .await
            .expect("enable extensions");

        let call = ToolCall {
            id: "call-1".to_string(),
            name: "missing_tool".to_string(),
            arguments: json!({}),
            thought_signature: None,
        };
        let ignore_events: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session
            .agent
            .execute_tool(call, ignore_events, test_turn_latency())
            .await;

        assert!(is_error);
        assert!(output.is_error);
        assert_eq!(output.details, Some(json!({ "handled": true })));

        let [ContentBlock::Text(text)] = output.content.as_slice() else {
            panic!("Expected text output, got {:?}", output.content);
        };
        assert_eq!(text.text, "overridden");
    });
}
6322
#[test]
fn tool_result_hook_errors_fail_open() {
    // A throwing `tool_result` hook must leave the tool's own output ("ok", no details)
    // in place rather than turning the call into an error.
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let workspace = tempfile::tempdir().expect("tempdir");
        let extension_path = workspace.path().join("ext.mjs");
        let extension_source = r#"
export default function init(pi) {
pi.on("tool_result", async (_event) => {
throw new Error("boom");
});
}
"#;
        std::fs::write(&extension_path, extension_source).expect("write extension entry");

        let tool_runs = Arc::new(AtomicUsize::new(0));
        let registry = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&tool_runs),
        })]);
        let agent = Agent::new(Arc::new(NoopProvider), registry, AgentConfig::default());
        let mut agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );
        agent_session
            .enable_extensions(&[], workspace.path(), None, &[extension_path])
            .await
            .expect("enable extensions");

        let call = ToolCall {
            id: "call-1".to_string(),
            name: "count_tool".to_string(),
            arguments: json!({}),
            thought_signature: None,
        };
        let ignore_events: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session
            .agent
            .execute_tool(call, ignore_events, test_turn_latency())
            .await;

        assert!(!is_error);
        assert!(!output.is_error);
        assert_eq!(tool_runs.load(Ordering::SeqCst), 1);
        assert_eq!(output.details, None);

        let [ContentBlock::Text(text)] = output.content.as_slice() else {
            panic!("Expected text output, got {:?}", output.content);
        };
        assert_eq!(text.text, "ok");
    });
}
6387
#[test]
fn tool_result_hook_runs_on_blocked_tool_call() {
    // When a `tool_call` hook blocks a tool, the resulting error still flows through the
    // `tool_result` hook, whose replacement content must win.
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let workspace = tempfile::tempdir().expect("tempdir");
        let extension_path = workspace.path().join("ext.mjs");
        let extension_source = r#"
export default function init(pi) {
pi.on("tool_call", async (event) => {
if (Object.is(event && event.toolName, "count_tool")) {
return { block: true, reason: "blocked in test" };
}
return {};
});

pi.on("tool_result", async (event) => {
if (Object.is(event && event.toolName, "count_tool") && event.isError) {
return { content: [{ type: "text", text: "override" }] };
}
return {};
});
}
"#;
        std::fs::write(&extension_path, extension_source).expect("write extension entry");

        let tool_runs = Arc::new(AtomicUsize::new(0));
        let registry = ToolRegistry::from_tools(vec![Box::new(CountingTool {
            calls: Arc::clone(&tool_runs),
        })]);
        let agent = Agent::new(Arc::new(NoopProvider), registry, AgentConfig::default());
        let mut agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );
        agent_session
            .enable_extensions(&[], workspace.path(), None, &[extension_path])
            .await
            .expect("enable extensions");

        let call = ToolCall {
            id: "call-1".to_string(),
            name: "count_tool".to_string(),
            arguments: json!({}),
            thought_signature: None,
        };
        let ignore_events: Arc<dyn Fn(AgentEvent) + Send + Sync> = Arc::new(|_| {});
        let (output, is_error) = agent_session
            .agent
            .execute_tool(call, ignore_events, test_turn_latency())
            .await;

        assert!(is_error);
        assert!(output.is_error);
        assert_eq!(tool_runs.load(Ordering::SeqCst), 0);

        let [ContentBlock::Text(text)] = output.content.as_slice() else {
            panic!("Expected text output, got {:?}", output.content);
        };
        assert_eq!(text.text, "override");
    });
}
6461}
6462
6463#[cfg(test)]
6464mod abort_tests {
6465 use super::*;
6466 use crate::session::Session;
6467 use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
6468 use asupersync::runtime::RuntimeBuilder;
6469 use async_trait::async_trait;
6470 use futures::Stream;
6471 use serde_json::json;
6472 use std::path::Path;
6473 use std::pin::Pin;
6474 use std::sync::Mutex as StdMutex;
6475 use std::sync::atomic::AtomicUsize;
6476 use std::task::{Context as TaskContext, Poll};
6477
6478 struct StartThenPending {
6479 start: Option<StreamEvent>,
6480 }
6481
6482 impl Stream for StartThenPending {
6483 type Item = crate::error::Result<StreamEvent>;
6484
6485 fn poll_next(
6486 mut self: Pin<&mut Self>,
6487 _cx: &mut TaskContext<'_>,
6488 ) -> Poll<Option<Self::Item>> {
6489 if let Some(event) = self.start.take() {
6490 return Poll::Ready(Some(Ok(event)));
6491 }
6492 Poll::Pending
6493 }
6494 }
6495
6496 #[derive(Debug)]
6497 struct HangingProvider;
6498
6499 #[async_trait]
6500 #[allow(clippy::unnecessary_literal_bound)]
6501 impl Provider for HangingProvider {
6502 fn name(&self) -> &str {
6503 "test-provider"
6504 }
6505
6506 fn api(&self) -> &str {
6507 "test-api"
6508 }
6509
6510 fn model_id(&self) -> &str {
6511 "test-model"
6512 }
6513
6514 async fn stream(
6515 &self,
6516 _context: &Context<'_>,
6517 _options: &StreamOptions,
6518 ) -> crate::error::Result<
6519 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
6520 > {
6521 let partial = AssistantMessage {
6522 content: Vec::new(),
6523 api: self.api().to_string(),
6524 provider: self.name().to_string(),
6525 model: self.model_id().to_string(),
6526 usage: Usage::default(),
6527 stop_reason: StopReason::Stop,
6528 error_message: None,
6529 timestamp: 0,
6530 };
6531
6532 Ok(Box::pin(StartThenPending {
6533 start: Some(StreamEvent::Start { partial }),
6534 }))
6535 }
6536 }
6537
6538 #[derive(Debug)]
6539 struct CountingProvider {
6540 calls: Arc<std::sync::atomic::AtomicUsize>,
6541 }
6542
6543 #[async_trait]
6544 #[allow(clippy::unnecessary_literal_bound)]
6545 impl Provider for CountingProvider {
6546 fn name(&self) -> &str {
6547 "test-provider"
6548 }
6549
6550 fn api(&self) -> &str {
6551 "test-api"
6552 }
6553
6554 fn model_id(&self) -> &str {
6555 "test-model"
6556 }
6557
6558 async fn stream(
6559 &self,
6560 _context: &Context<'_>,
6561 _options: &StreamOptions,
6562 ) -> crate::error::Result<
6563 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
6564 > {
6565 self.calls.fetch_add(1, Ordering::SeqCst);
6566 Ok(Box::pin(futures::stream::empty()))
6567 }
6568 }
6569
6570 #[derive(Debug)]
6571 struct PhasedProvider {
6572 pending_calls: usize,
6573 calls: AtomicUsize,
6574 }
6575
6576 impl PhasedProvider {
6577 const fn new(pending_calls: usize) -> Self {
6578 Self {
6579 pending_calls,
6580 calls: AtomicUsize::new(0),
6581 }
6582 }
6583
6584 fn base_message() -> AssistantMessage {
6585 AssistantMessage {
6586 content: Vec::new(),
6587 api: "test-api".to_string(),
6588 provider: "test-provider".to_string(),
6589 model: "test-model".to_string(),
6590 usage: Usage::default(),
6591 stop_reason: StopReason::Stop,
6592 error_message: None,
6593 timestamp: 0,
6594 }
6595 }
6596 }
6597
6598 #[async_trait]
6599 #[allow(clippy::unnecessary_literal_bound)]
6600 impl Provider for PhasedProvider {
6601 fn name(&self) -> &str {
6602 "test-provider"
6603 }
6604
6605 fn api(&self) -> &str {
6606 "test-api"
6607 }
6608
6609 fn model_id(&self) -> &str {
6610 "test-model"
6611 }
6612
6613 async fn stream(
6614 &self,
6615 _context: &Context<'_>,
6616 _options: &StreamOptions,
6617 ) -> crate::error::Result<
6618 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
6619 > {
6620 let call = self.calls.fetch_add(1, Ordering::SeqCst);
6621 if call < self.pending_calls {
6622 return Ok(Box::pin(StartThenPending {
6623 start: Some(StreamEvent::Start {
6624 partial: Self::base_message(),
6625 }),
6626 }));
6627 }
6628
6629 let partial = Self::base_message();
6630 let mut done = Self::base_message();
6631 done.content = vec![ContentBlock::Text(TextContent::new(format!(
6632 "resumed-response-{call}"
6633 )))];
6634
6635 Ok(Box::pin(futures::stream::iter(vec![
6636 Ok(StreamEvent::Start { partial }),
6637 Ok(StreamEvent::Done {
6638 reason: StopReason::Stop,
6639 message: done,
6640 }),
6641 ])))
6642 }
6643 }
6644
6645 #[derive(Debug)]
6646 struct ToolCallProvider;
6647
6648 #[async_trait]
6649 #[allow(clippy::unnecessary_literal_bound)]
6650 impl Provider for ToolCallProvider {
6651 fn name(&self) -> &str {
6652 "test-provider"
6653 }
6654
6655 fn api(&self) -> &str {
6656 "test-api"
6657 }
6658
6659 fn model_id(&self) -> &str {
6660 "test-model"
6661 }
6662
6663 async fn stream(
6664 &self,
6665 _context: &Context<'_>,
6666 _options: &StreamOptions,
6667 ) -> crate::error::Result<
6668 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
6669 > {
6670 let message = AssistantMessage {
6671 content: vec![ContentBlock::ToolCall(ToolCall {
6672 id: "call-1".to_string(),
6673 name: "hanging_tool".to_string(),
6674 arguments: json!({}),
6675 thought_signature: None,
6676 })],
6677 api: "test-api".to_string(),
6678 provider: "test-provider".to_string(),
6679 model: "test-model".to_string(),
6680 usage: Usage::default(),
6681 stop_reason: StopReason::ToolUse,
6682 error_message: None,
6683 timestamp: 0,
6684 };
6685
6686 Ok(Box::pin(futures::stream::iter(vec![Ok(
6687 StreamEvent::Done {
6688 reason: StopReason::ToolUse,
6689 message,
6690 },
6691 )])))
6692 }
6693 }
6694
6695 #[derive(Debug)]
6696 struct HangingTool;
6697
6698 #[async_trait]
6699 #[allow(clippy::unnecessary_literal_bound)]
6700 impl Tool for HangingTool {
6701 fn name(&self) -> &str {
6702 "hanging_tool"
6703 }
6704
6705 fn label(&self) -> &str {
6706 "Hanging Tool"
6707 }
6708
6709 fn description(&self) -> &str {
6710 "Never completes unless aborted by the host"
6711 }
6712
6713 fn parameters(&self) -> serde_json::Value {
6714 json!({
6715 "type": "object",
6716 "properties": {},
6717 "additionalProperties": false
6718 })
6719 }
6720
6721 async fn execute(
6722 &self,
6723 _tool_call_id: &str,
6724 _input: serde_json::Value,
6725 _on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
6726 ) -> crate::error::Result<ToolOutput> {
6727 futures::future::pending::<()>().await;
6728 unreachable!("hanging tool should be aborted by the agent")
6729 }
6730 }
6731
6732 fn event_tag(event: &AgentEvent) -> &'static str {
6733 match event {
6734 AgentEvent::AgentStart { .. } => "agent_start",
6735 AgentEvent::AgentEnd { error, .. } => {
6736 if error.as_deref() == Some("Aborted") {
6737 "agent_end_aborted"
6738 } else {
6739 "agent_end"
6740 }
6741 }
6742 AgentEvent::TurnStart { .. } => "turn_start",
6743 AgentEvent::TurnEnd { .. } => "turn_end",
6744 AgentEvent::MessageStart { .. } => "message_start",
6745 AgentEvent::MessageUpdate {
6746 assistant_message_event,
6747 ..
6748 } => match &assistant_message_event {
6749 AssistantMessageEvent::Error {
6750 reason: StopReason::Aborted,
6751 ..
6752 } => "assistant_error_aborted",
6753 AssistantMessageEvent::Done { .. } => "assistant_done",
6754 _ => "assistant_update",
6755 },
6756 AgentEvent::MessageEnd { .. } => "message_end",
6757 AgentEvent::ToolExecutionStart { .. } => "tool_start",
6758 AgentEvent::ToolExecutionUpdate { .. } => "tool_update",
6759 AgentEvent::ToolExecutionEnd { .. } => "tool_end",
6760 AgentEvent::AutoCompactionStart { .. } => "auto_compaction_start",
6761 AgentEvent::AutoCompactionEnd { .. } => "auto_compaction_end",
6762 AgentEvent::AutoRetryStart { .. } => "auto_retry_start",
6763 AgentEvent::AutoRetryEnd { .. } => "auto_retry_end",
6764 AgentEvent::ExtensionError { .. } => "extension_error",
6765 }
6766 }
6767
6768 fn assert_abort_resume_message_sequence(persisted: &[Message]) {
6769 assert_eq!(
6770 persisted.len(),
6771 6,
6772 "expected three user+assistant pairs, got: {persisted:?}"
6773 );
6774
6775 let assistant_states = persisted
6776 .iter()
6777 .filter_map(|message| match message {
6778 Message::Assistant(assistant) => Some(assistant.stop_reason),
6779 _ => None,
6780 })
6781 .collect::<Vec<_>>();
6782 assert_eq!(
6783 assistant_states,
6784 vec![StopReason::Aborted, StopReason::Aborted, StopReason::Stop]
6785 );
6786 }
6787
/// Asserts that `timeline` contains each expected run boundary tag verbatim: aborted ends
/// for runs 0 and 1 and a normal end for run 2.
fn assert_abort_resume_timeline_boundaries(timeline: &[String]) {
    let expected_boundaries = [
        ("run0:agent_end_aborted", "aborted boundary for first run"),
        ("run1:agent_end_aborted", "aborted boundary for second run"),
        ("run2:agent_end", "successful boundary for resumed run"),
    ];

    for &(tag, description) in &expected_boundaries {
        assert!(
            timeline.iter().any(|event| event == tag),
            "missing {description}: {timeline:?}"
        );
    }
}
6808
    /// Triggering the `AbortHandle` while the provider stream is still in flight
    /// must end the run with an `Aborted` assistant message whose error text is
    /// `"Aborted"`.
    #[test]
    fn abort_interrupts_in_flight_stream() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        let handle = runtime.handle();

        // `started` fires when the assistant message starts streaming. The waiter
        // future is created here, before the run is spawned — presumably so that
        // a notification sent before `block_on` polls it is not lost; confirm
        // against asupersync `Notify` semantics.
        let started = Arc::new(Notify::new());
        let started_wait = started.notified();

        let (abort_handle, abort_signal) = AbortHandle::new();

        // NOTE(review): `HangingProvider` is defined elsewhere in this module; by
        // its name it presumably never completes its stream, so only the abort
        // can finish this run.
        let provider = Arc::new(HangingProvider);
        let tools = ToolRegistry::new(&[], Path::new("."), None);
        let agent = Agent::new(provider, tools, AgentConfig::default());
        let session = Arc::new(Mutex::new(Session::in_memory()));
        let mut agent_session =
            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

        let started_tx = Arc::clone(&started);
        // Run the session as a separate task so the main future can abort it.
        let join = handle.spawn(async move {
            agent_session
                .run_text_with_abort("hello".to_string(), Some(abort_signal), move |event| {
                    // Signal once the assistant message has begun streaming.
                    if matches!(
                        event,
                        AgentEvent::MessageStart {
                            message: Message::Assistant(_)
                        }
                    ) {
                        started_tx.notify_one();
                    }
                })
                .await
        });

        runtime.block_on(async move {
            // Abort only after streaming has demonstrably started.
            started_wait.await;
            abort_handle.abort();

            let message = join.await.expect("run_text_with_abort");
            assert_eq!(message.stop_reason, StopReason::Aborted);
            assert_eq!(message.error_message.as_deref(), Some("Aborted"));
        });
    }
6853
    /// With no explicit abort signal, requesting cancellation on the ambient
    /// `asupersync::Cx` while the stream is in flight must still end the run
    /// with an `Aborted` assistant message (error text `"Aborted"`) within one
    /// second.
    #[test]
    fn ambient_cancellation_interrupts_in_flight_stream() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            // std channel: the event callback signals a plain OS thread that the
            // assistant stream has started.
            let (started_tx, started_rx) = std::sync::mpsc::channel();

            let provider = Arc::new(HangingProvider);
            let tools = ToolRegistry::new(&[], Path::new("."), None);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session =
                AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

            // Install a test Cx as the current ambient context; `cancel_cx` is a
            // clone used from another thread to request cancellation. The
            // `_current` guard is kept alive for the rest of this future.
            let ambient_cx = asupersync::Cx::for_testing();
            let cancel_cx = ambient_cx.clone();
            let _current = asupersync::Cx::set_current(Some(ambient_cx));

            // Wait (up to 1s) for the stream to start, then request cancellation.
            let cancel_thread = std::thread::spawn(move || {
                started_rx
                    .recv_timeout(std::time::Duration::from_secs(1))
                    .expect("stream start");
                cancel_cx.set_cancel_requested(true);
            });

            // `None`: no abort signal, so only the ambient cancellation can stop it.
            let run = agent_session.run_text_with_abort("hello".to_string(), None, move |event| {
                if matches!(
                    event,
                    AgentEvent::MessageStart {
                        message: Message::Assistant(_)
                    }
                ) {
                    // The receiver may already be gone; a send failure is irrelevant.
                    let _ = started_tx.send(());
                }
            });
            futures::pin_mut!(run);

            // Bound the run so a missed cancellation fails the test instead of hanging.
            let message = asupersync::time::timeout(
                asupersync::time::wall_now(),
                std::time::Duration::from_secs(1),
                run,
            )
            .await
            .expect("ambient cancellation should finish before timeout")
            .expect("run_text_with_abort");

            cancel_thread.join().expect("cancel thread");

            assert_eq!(message.stop_reason, StopReason::Aborted);
            assert_eq!(message.error_message.as_deref(), Some("Aborted"));
        });
    }
6908
6909 #[test]
6910 fn abort_before_run_skips_provider_stream_call() {
6911 let runtime = RuntimeBuilder::current_thread()
6912 .build()
6913 .expect("runtime build");
6914
6915 let calls = Arc::new(std::sync::atomic::AtomicUsize::new(0));
6916 let provider = Arc::new(CountingProvider {
6917 calls: Arc::clone(&calls),
6918 });
6919 let tools = ToolRegistry::new(&[], Path::new("."), None);
6920 let agent = Agent::new(provider, tools, AgentConfig::default());
6921 let session = Arc::new(Mutex::new(Session::in_memory()));
6922 let mut agent_session =
6923 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
6924
6925 let (abort_handle, abort_signal) = AbortHandle::new();
6926 abort_handle.abort();
6927
6928 runtime.block_on(async move {
6929 let message = agent_session
6930 .run_text_with_abort("hello".to_string(), Some(abort_signal), |_| {})
6931 .await
6932 .expect("run_text_with_abort");
6933 assert_eq!(message.stop_reason, StopReason::Aborted);
6934 assert_eq!(calls.load(Ordering::SeqCst), 0);
6935 });
6936 }
6937
    /// Aborting a run mid-stream and then running again on the same session must
    /// leave a four-message history: user, aborted assistant, user, and a clean
    /// (`Stop`, no error) assistant.
    #[test]
    fn abort_then_resume_preserves_session_history() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        let handle = runtime.handle();

        runtime.block_on(async move {
            // NOTE(review): `PhasedProvider::new(1)` is defined elsewhere; the
            // test relies on its first call being abortable and a later call
            // completing with `Stop` — confirm against its definition.
            let provider = Arc::new(PhasedProvider::new(1));
            let tools = ToolRegistry::new(&[], Path::new("."), None);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            // Keep a second handle to the session to inspect persisted history.
            let mut agent_session = AgentSession::new(
                agent,
                Arc::clone(&session),
                false,
                ResolvedCompactionSettings::default(),
            );

            // A helper task waits for the assistant stream to start, then aborts.
            let started = Arc::new(Notify::new());
            let (abort_handle, abort_signal) = AbortHandle::new();
            let started_for_abort = Arc::clone(&started);
            let abort_join = handle.spawn(async move {
                started_for_abort.notified().await;
                abort_handle.abort();
            });

            let aborted = agent_session
                .run_text_with_abort("first".to_string(), Some(abort_signal), {
                    let started = Arc::clone(&started);
                    move |event| {
                        if matches!(
                            event,
                            AgentEvent::MessageStart {
                                message: Message::Assistant(_)
                            }
                        ) {
                            started.notify_one();
                        }
                    }
                })
                .await
                .expect("first run");
            abort_join.await;

            assert_eq!(aborted.stop_reason, StopReason::Aborted);
            assert_eq!(aborted.error_message.as_deref(), Some("Aborted"));

            // Resume on the same session without an abort signal.
            let resumed = agent_session
                .run_text("second".to_string(), |_| {})
                .await
                .expect("resumed run");
            assert_eq!(resumed.stop_reason, StopReason::Stop);
            assert!(resumed.error_message.is_none());

            let cx = crate::agent_cx::AgentCx::for_request();
            let persisted = session
                .lock(cx.cx())
                .await
                .expect("lock session")
                .to_messages_for_current_path();

            // Expected shape: [User, Assistant(Aborted), User, Assistant(Stop, no error)].
            assert_eq!(
                persisted.len(),
                4,
                "unexpected message history after abort+resume: {persisted:?}"
            );
            assert!(matches!(persisted.first(), Some(Message::User(_))));
            assert!(matches!(
                persisted.get(1),
                Some(Message::Assistant(assistant))
                    if matches!(assistant.stop_reason, StopReason::Aborted)
            ));
            assert!(matches!(persisted.get(2), Some(Message::User(_))));
            assert!(matches!(
                persisted.get(3),
                Some(Message::Assistant(assistant))
                    if matches!(assistant.stop_reason, StopReason::Stop)
                        && assistant.error_message.is_none()
            ));
        });
    }
7020
    /// Two consecutive mid-stream aborts followed by a normal run must produce
    /// the expected persisted history (checked by
    /// `assert_abort_resume_message_sequence`) and an event timeline with a
    /// terminal tag for each run (checked by
    /// `assert_abort_resume_timeline_boundaries`).
    #[test]
    fn repeated_abort_then_resume_has_consistent_timeline_and_state() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        let handle = runtime.handle();

        runtime.block_on(async move {
            // NOTE(review): `PhasedProvider::new(2)` is defined elsewhere; the
            // test assumes the first two calls are abortable and the third
            // completes with `Stop` — confirm against its definition.
            let provider = Arc::new(PhasedProvider::new(2));
            let tools = ToolRegistry::new(&[], Path::new("."), None);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session = AgentSession::new(
                agent,
                Arc::clone(&session),
                false,
                ResolvedCompactionSettings::default(),
            );

            // Shared log of "run{idx}:{event_tag}" entries across all runs.
            let timeline = Arc::new(StdMutex::new(Vec::<String>::new()));

            for run_idx in 0..2 {
                // Fresh notify/abort pair per run; the helper task aborts as soon
                // as the assistant stream starts.
                let started = Arc::new(Notify::new());
                let (abort_handle, abort_signal) = AbortHandle::new();
                let started_for_abort = Arc::clone(&started);
                let abort_join = handle.spawn(async move {
                    started_for_abort.notified().await;
                    abort_handle.abort();
                });

                let run_timeline = Arc::clone(&timeline);
                let aborted = agent_session
                    .run_text_with_abort(format!("abort-run-{run_idx}"), Some(abort_signal), {
                        let started = Arc::clone(&started);
                        move |event| {
                            // Best-effort logging: a poisoned lock just skips the entry.
                            if let Ok(mut events) = run_timeline.lock() {
                                events.push(format!("run{run_idx}:{}", event_tag(&event)));
                            }
                            if matches!(
                                event,
                                AgentEvent::MessageStart {
                                    message: Message::Assistant(_)
                                }
                            ) {
                                started.notify_one();
                            }
                        }
                    })
                    .await
                    .expect("aborted run");
                abort_join.await;

                assert_eq!(
                    aborted.stop_reason,
                    StopReason::Aborted,
                    "run {run_idx} should abort cleanly"
                );
            }

            // Third run: no abort, must complete normally.
            let run_timeline = Arc::clone(&timeline);
            let resumed = agent_session
                .run_text("final-run".to_string(), move |event| {
                    if let Ok(mut events) = run_timeline.lock() {
                        events.push(format!("run2:{}", event_tag(&event)));
                    }
                })
                .await
                .expect("final resumed run");
            assert_eq!(resumed.stop_reason, StopReason::Stop);
            assert!(resumed.error_message.is_none());

            let cx = crate::agent_cx::AgentCx::for_request();
            let persisted = session
                .lock(cx.cx())
                .await
                .expect("lock session")
                .to_messages_for_current_path();

            assert_abort_resume_message_sequence(&persisted);

            // Snapshot the timeline, tolerating a poisoned lock.
            let timeline = timeline
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .clone();
            assert_abort_resume_timeline_boundaries(&timeline);
        });
    }
7108
    /// Aborting while a tool is executing must end the run as `Aborted`. It must
    /// also persist an error `ToolResult` that contains the "Tool execution
    /// aborted" marker text and structured cancellation details.
    #[test]
    fn abort_during_tool_execution_records_aborted_tool_result() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        let handle = runtime.handle();

        runtime.block_on(async move {
            // NOTE(review): `ToolCallProvider` / `HangingTool` are defined
            // elsewhere; presumably the provider requests `hanging_tool` and the
            // tool never finishes on its own — confirm.
            let provider = Arc::new(ToolCallProvider);
            let tools = ToolRegistry::from_tools(vec![Box::new(HangingTool)]);
            let agent = Agent::new(provider, tools, AgentConfig::default());
            let session = Arc::new(Mutex::new(Session::in_memory()));
            let mut agent_session = AgentSession::new(
                agent,
                Arc::clone(&session),
                false,
                ResolvedCompactionSettings::default(),
            );

            // Helper task aborts as soon as tool execution starts.
            let tool_started = Arc::new(Notify::new());
            let (abort_handle, abort_signal) = AbortHandle::new();
            let tool_started_for_abort = Arc::clone(&tool_started);
            let abort_join = handle.spawn(async move {
                tool_started_for_abort.notified().await;
                abort_handle.abort();
            });

            let result = agent_session
                .run_text_with_abort("trigger tool".to_string(), Some(abort_signal), {
                    let tool_started = Arc::clone(&tool_started);
                    move |event| {
                        if matches!(event, AgentEvent::ToolExecutionStart { .. }) {
                            tool_started.notify_one();
                        }
                    }
                })
                .await
                .expect("tool-abort run");
            abort_join.await;
            assert_eq!(result.stop_reason, StopReason::Aborted);

            let cx = crate::agent_cx::AgentCx::for_request();
            let persisted = session
                .lock(cx.cx())
                .await
                .expect("lock session")
                .to_messages_for_current_path();

            // The first persisted ToolResult is the one for the aborted call.
            let tool_result = persisted
                .iter()
                .find_map(|message| match message {
                    Message::ToolResult(result) => Some(result),
                    _ => None,
                })
                .expect("expected tool result message");
            assert!(tool_result.is_error);
            assert!(
                tool_result.content.iter().any(|block| {
                    matches!(
                        block,
                        ContentBlock::Text(text) if text.text.contains("Tool execution aborted")
                    )
                }),
                "missing aborted tool marker in tool output: {:?}",
                tool_result.content
            );
            // Structured cancellation payload attached to the tool result.
            let details = tool_result
                .details
                .as_ref()
                .expect("aborted tool result should include structured details");
            assert_eq!(details["schema"], TOOL_CANCELLATION_SCHEMA_V1);
            assert_eq!(details["status"], "cancelled");
            assert_eq!(details["reason"], "abort_signal");
            assert_eq!(details["toolName"], "hanging_tool");
            assert_eq!(details["cleanup"], "tool_result_recorded_no_success");
        });
    }
7186}
7187
#[cfg(test)]
mod turn_event_tests {
    //! Tests for the turn-level event envelope (`TurnStart` / `TurnEnd`) that
    //! `AgentSession::run_text` emits around assistant responses, tool
    //! execution, and stream-setup failures.

    use super::*;
    use crate::session::Session;
    use crate::tools::{Tool, ToolOutput, ToolRegistry, ToolUpdate};
    use asupersync::runtime::RuntimeBuilder;
    use async_trait::async_trait;
    use futures::Stream;
    use serde_json::json;
    use std::path::Path;
    use std::pin::Pin;
    use std::sync::atomic::AtomicUsize;

    /// Builds a `Stop` assistant message with a single text block and the fixed
    /// `test-api` / `test-provider` / `test-model` identifiers.
    fn assistant_message(text: &str) -> AssistantMessage {
        AssistantMessage {
            content: vec![ContentBlock::Text(TextContent::new(text))],
            api: "test-api".to_string(),
            provider: "test-provider".to_string(),
            model: "test-model".to_string(),
            usage: Usage::default(),
            stop_reason: StopReason::Stop,
            error_message: None,
            timestamp: 0,
        }
    }

    /// Provider whose stream yields a `Start` followed by a `Done("hello")`.
    struct SingleShotProvider;

    #[async_trait]
    // The trait signature fixes `&str` returns; the literals cannot be `&'static str`.
    #[allow(clippy::unnecessary_literal_bound)]
    impl Provider for SingleShotProvider {
        fn name(&self) -> &str {
            "test-provider"
        }

        fn api(&self) -> &str {
            "test-api"
        }

        fn model_id(&self) -> &str {
            "test-model"
        }

        async fn stream(
            &self,
            _context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<
            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
        > {
            let partial = assistant_message("");
            let final_message = assistant_message("hello");
            let events = vec![
                Ok(StreamEvent::Start { partial }),
                Ok(StreamEvent::Done {
                    reason: StopReason::Stop,
                    message: final_message,
                }),
            ];
            Ok(Box::pin(futures::stream::iter(events)))
        }
    }

    /// Provider whose `stream` call itself fails before yielding any events.
    struct StreamSetupErrorProvider;

    #[async_trait]
    // The trait signature fixes `&str` returns; the literals cannot be `&'static str`.
    #[allow(clippy::unnecessary_literal_bound)]
    impl Provider for StreamSetupErrorProvider {
        fn name(&self) -> &str {
            "test-provider"
        }

        fn api(&self) -> &str {
            "test-api"
        }

        fn model_id(&self) -> &str {
            "test-model"
        }

        async fn stream(
            &self,
            _context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<
            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
        > {
            Err(Error::api("stream setup failed"))
        }
    }

    /// Tool that always succeeds with the text `tool-ok`.
    #[derive(Debug)]
    struct EchoTool;

    #[async_trait]
    // The trait signature fixes `&str` returns; the literals cannot be `&'static str`.
    #[allow(clippy::unnecessary_literal_bound)]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo_tool"
        }

        fn label(&self) -> &str {
            "echo_tool"
        }

        fn description(&self) -> &str {
            "echo test tool"
        }

        fn parameters(&self) -> serde_json::Value {
            json!({ "type": "object" })
        }

        async fn execute(
            &self,
            _tool_call_id: &str,
            _input: serde_json::Value,
            _on_update: Option<Box<dyn Fn(ToolUpdate) + Send + Sync>>,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput {
                content: vec![ContentBlock::Text(TextContent::new("tool-ok"))],
                details: None,
                is_error: false,
            })
        }
    }

    /// Provider that requests `echo_tool` on its first call and answers with
    /// plain text (`final`) on every later call.
    #[derive(Debug)]
    struct ToolTurnProvider {
        calls: AtomicUsize,
    }

    impl ToolTurnProvider {
        const fn new() -> Self {
            Self {
                calls: AtomicUsize::new(0),
            }
        }

        /// Assistant message stamped with this provider's identifiers.
        fn assistant_message_with(
            &self,
            stop_reason: StopReason,
            content: Vec<ContentBlock>,
        ) -> AssistantMessage {
            AssistantMessage {
                content,
                api: self.api().to_string(),
                provider: self.name().to_string(),
                model: self.model_id().to_string(),
                usage: Usage::default(),
                stop_reason,
                error_message: None,
                timestamp: 0,
            }
        }
    }

    #[async_trait]
    // The trait signature fixes `&str` returns; the literals cannot be `&'static str`.
    #[allow(clippy::unnecessary_literal_bound)]
    impl Provider for ToolTurnProvider {
        fn name(&self) -> &str {
            "test-provider"
        }

        fn api(&self) -> &str {
            "test-api"
        }

        fn model_id(&self) -> &str {
            "test-model"
        }

        async fn stream(
            &self,
            _context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<
            Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
        > {
            let call_index = self.calls.fetch_add(1, Ordering::SeqCst);
            let partial = self.assistant_message_with(StopReason::Stop, Vec::new());
            let done = if call_index == 0 {
                self.assistant_message_with(
                    StopReason::ToolUse,
                    vec![ContentBlock::ToolCall(ToolCall {
                        id: "tool-1".to_string(),
                        name: "echo_tool".to_string(),
                        arguments: json!({}),
                        thought_signature: None,
                    })],
                )
            } else {
                self.assistant_message_with(
                    StopReason::Stop,
                    vec![ContentBlock::Text(TextContent::new("final"))],
                )
            };

            Ok(Box::pin(futures::stream::iter(vec![
                Ok(StreamEvent::Start { partial }),
                Ok(StreamEvent::Done {
                    reason: done.stop_reason,
                    message: done,
                }),
            ])))
        }
    }

    /// A tool-free response is wrapped in exactly one `TurnStart`/`TurnEnd`
    /// pair. The assistant `MessageEnd` lands inside that pair, and the
    /// `TurnEnd` carries the assistant message with no tool results.
    #[test]
    fn turn_events_wrap_assistant_response() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        let handle = runtime.handle();

        let provider = Arc::new(SingleShotProvider);
        let tools = ToolRegistry::new(&[], Path::new("."), None);
        let agent = Agent::new(provider, tools, AgentConfig::default());
        let session = Arc::new(Mutex::new(Session::in_memory()));
        let mut agent_session =
            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

        let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let events_capture = Arc::clone(&events);

        let join = handle.spawn(async move {
            agent_session
                .run_text("hello".to_string(), move |event| {
                    events_capture
                        .lock()
                        .unwrap_or_else(std::sync::PoisonError::into_inner)
                        .push(event);
                })
                .await
                .expect("run_text")
        });

        runtime.block_on(async move {
            let message = join.await;
            assert_eq!(message.stop_reason, StopReason::Stop);

            let events = events
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let turn_start_indices = events
                .iter()
                .enumerate()
                .filter_map(|(idx, event)| {
                    matches!(event, AgentEvent::TurnStart { .. }).then_some(idx)
                })
                .collect::<Vec<_>>();
            let turn_end_indices = events
                .iter()
                .enumerate()
                .filter_map(|(idx, event)| {
                    matches!(event, AgentEvent::TurnEnd { .. }).then_some(idx)
                })
                .collect::<Vec<_>>();

            assert_eq!(turn_start_indices.len(), 1);
            assert_eq!(turn_end_indices.len(), 1);
            assert!(turn_start_indices[0] < turn_end_indices[0]);

            let assistant_message_end = events
                .iter()
                .position(|event| {
                    matches!(
                        event,
                        AgentEvent::MessageEnd {
                            message: Message::Assistant(_),
                        }
                    )
                })
                .expect("assistant message end");
            assert!(assistant_message_end < turn_end_indices[0]);

            // One assertion covers variant, message kind and empty tool results.
            let turn_end_event = &events[turn_end_indices[0]];
            assert!(
                matches!(
                    turn_end_event,
                    AgentEvent::TurnEnd {
                        message: Message::Assistant(_),
                        tool_results,
                        ..
                    } if tool_results.is_empty()
                ),
                "expected TurnEnd with the assistant message and no tool results, got {turn_end_event:?}"
            );
            drop(events);
        });
    }

    /// When the provider's `stream` call itself fails, the run still emits
    /// `TurnStart` → assistant `MessageEnd` → `TurnEnd` → `AgentEnd`. The
    /// `TurnEnd` carries an `Error` assistant message stamped with the
    /// provider's identifiers, and `AgentEnd` carries the same error text.
    #[test]
    fn stream_setup_errors_still_emit_turn_end_before_agent_end() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        let handle = runtime.handle();

        let provider = Arc::new(StreamSetupErrorProvider);
        let tools = ToolRegistry::new(&[], Path::new("."), None);
        let agent = Agent::new(provider, tools, AgentConfig::default());
        let session = Arc::new(Mutex::new(Session::in_memory()));
        let mut agent_session =
            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

        let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let events_capture = Arc::clone(&events);

        let join = handle.spawn(async move {
            agent_session
                .run_text("hello".to_string(), move |event| {
                    events_capture
                        .lock()
                        .unwrap_or_else(std::sync::PoisonError::into_inner)
                        .push(event);
                })
                .await
                .expect_err("run_text should fail before streaming starts")
        });

        runtime.block_on(async move {
            let err = join.await;
            assert!(
                err.to_string().contains("stream setup failed"),
                "unexpected error: {err}"
            );

            let events = events
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let turn_start_idx = events
                .iter()
                .position(|event| matches!(event, AgentEvent::TurnStart { turn_index: 0, .. }))
                .expect("turn start");
            let turn_end_idx = events
                .iter()
                .position(|event| matches!(event, AgentEvent::TurnEnd { turn_index: 0, .. }))
                .expect("turn end");
            let agent_end_idx = events
                .iter()
                .position(|event| matches!(event, AgentEvent::AgentEnd { .. }))
                .expect("agent end");

            assert!(turn_start_idx < turn_end_idx);
            assert!(turn_end_idx < agent_end_idx);

            let assistant_message_end = events
                .iter()
                .position(|event| {
                    matches!(
                        event,
                        AgentEvent::MessageEnd {
                            message: Message::Assistant(_),
                        }
                    )
                })
                .expect("assistant message end");
            assert!(assistant_message_end < turn_end_idx);

            // `turn_end_idx` was located by matching `TurnEnd`, so this cannot fail.
            let AgentEvent::TurnEnd {
                message,
                tool_results,
                ..
            } = &events[turn_end_idx]
            else {
                unreachable!("turn_end_idx was located by matching AgentEvent::TurnEnd");
            };
            assert!(tool_results.is_empty());
            assert!(
                matches!(message, Message::Assistant(_)),
                "expected assistant message in TurnEnd, got {message:?}"
            );
            let Message::Assistant(message) = message else {
                unreachable!("variant checked by the assertion above");
            };
            assert_eq!(message.stop_reason, StopReason::Error);
            assert_eq!(
                message.error_message.as_deref(),
                Some("API error: stream setup failed")
            );
            assert_eq!(message.api, "test-api");
            assert_eq!(message.provider, "test-provider");
            assert_eq!(message.model, "test-model");

            // `agent_end_idx` was located by matching `AgentEnd`, so this cannot fail.
            let AgentEvent::AgentEnd { error, .. } = &events[agent_end_idx] else {
                unreachable!("agent_end_idx was located by matching AgentEvent::AgentEnd");
            };
            assert_eq!(error.as_deref(), Some("API error: stream setup failed"));
        });
    }

    /// A tool-using exchange produces two turns. Tool execution starts and ends
    /// inside the first turn, and the first `TurnEnd` carries exactly one
    /// successful `echo_tool` result.
    #[test]
    fn turn_events_include_tool_execution_and_tool_result_messages() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        let handle = runtime.handle();

        let provider = Arc::new(ToolTurnProvider::new());
        let tools = ToolRegistry::from_tools(vec![Box::new(EchoTool)]);
        let agent = Agent::new(provider, tools, AgentConfig::default());
        let session = Arc::new(Mutex::new(Session::in_memory()));
        let mut agent_session =
            AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());

        let events: Arc<std::sync::Mutex<Vec<AgentEvent>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let events_capture = Arc::clone(&events);

        let join = handle.spawn(async move {
            agent_session
                .run_text("hello".to_string(), move |event| {
                    events_capture
                        .lock()
                        .unwrap_or_else(std::sync::PoisonError::into_inner)
                        .push(event);
                })
                .await
                .expect("run_text")
        });

        runtime.block_on(async move {
            let message = join.await;
            assert_eq!(message.stop_reason, StopReason::Stop);

            let events = events
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let turn_start_count = events
                .iter()
                .filter(|event| matches!(event, AgentEvent::TurnStart { .. }))
                .count();
            let turn_end_count = events
                .iter()
                .filter(|event| matches!(event, AgentEvent::TurnEnd { .. }))
                .count();
            assert_eq!(
                turn_start_count, 2,
                "expected one tool turn and one final turn"
            );
            assert_eq!(
                turn_end_count, 2,
                "expected one tool turn and one final turn"
            );

            let tool_start_idx = events
                .iter()
                .position(|event| matches!(event, AgentEvent::ToolExecutionStart { .. }))
                .expect("tool execution start event");
            let tool_end_idx = events
                .iter()
                .position(|event| matches!(event, AgentEvent::ToolExecutionEnd { .. }))
                .expect("tool execution end event");
            assert!(tool_start_idx < tool_end_idx);

            let first_turn_end_idx = events
                .iter()
                .position(|event| matches!(event, AgentEvent::TurnEnd { turn_index: 0, .. }))
                .expect("first turn end");
            assert!(
                tool_end_idx < first_turn_end_idx,
                "tool execution should complete before first turn end"
            );

            // Reuse the located index instead of searching the events a second time.
            let AgentEvent::TurnEnd {
                tool_results: first_turn_tool_results,
                ..
            } = &events[first_turn_end_idx]
            else {
                unreachable!("first_turn_end_idx was located by matching AgentEvent::TurnEnd");
            };
            assert_eq!(first_turn_tool_results.len(), 1);
            let first_result = &first_turn_tool_results[0];
            let Message::ToolResult(tr) = first_result else {
                unreachable!("expected Message::ToolResult, got {first_result:?}");
            };
            assert_eq!(tr.tool_name, "echo_tool");
            assert!(!tr.is_error);
            drop(events);
        });
    }
}
7696
/// Extension-facing session view: wraps a [`SessionHandle`] and overlays
/// agent-owned runtime state (streaming/compacting flags, queue modes,
/// auto-compaction) onto the state the handle reports.
#[derive(Clone)]
struct AgentExtensionSession {
    /// Underlying session handle; the `ExtensionSession` methods delegate to it.
    handle: SessionHandle,
    /// Shared flag reported to extensions as `isStreaming`.
    is_streaming: Arc<AtomicBool>,
    /// Shared flag reported to extensions as `isCompacting`.
    is_compacting: Arc<AtomicBool>,
    /// Shared steering / follow-up queue modes reported as `steeringMode` / `followUpMode`.
    queue_modes: Arc<StdMutex<ExtensionQueueModeState>>,
    /// Reported as `autoCompactionEnabled`. A plain `bool`, so it is a snapshot
    /// taken when this value was constructed rather than a live flag.
    auto_compaction_enabled: bool,
}
7705
7706impl AgentExtensionSession {
7707 fn current_queue_modes(&self) -> (QueueMode, QueueMode) {
7708 self.queue_modes
7709 .lock()
7710 .map_or((QueueMode::OneAtATime, QueueMode::OneAtATime), |state| {
7711 (state.steering_mode, state.follow_up_mode)
7712 })
7713 }
7714
7715 fn state_fallback(&self) -> Value {
7716 let (steering_mode, follow_up_mode) = self.current_queue_modes();
7717 json!({
7718 "model": null,
7719 "thinkingLevel": "off",
7720 "durabilityMode": "balanced",
7721 "isStreaming": self.is_streaming.load(std::sync::atomic::Ordering::SeqCst),
7722 "isCompacting": self.is_compacting.load(std::sync::atomic::Ordering::SeqCst),
7723 "steeringMode": steering_mode.as_str(),
7724 "followUpMode": follow_up_mode.as_str(),
7725 "sessionFile": null,
7726 "sessionId": "",
7727 "sessionName": null,
7728 "autoCompactionEnabled": self.auto_compaction_enabled,
7729 "messageCount": 0,
7730 "pendingMessageCount": 0,
7731 })
7732 }
7733}
7734
7735#[async_trait]
7736impl crate::extensions::ExtensionSession for AgentExtensionSession {
7737 async fn get_state(&self) -> Value {
7738 let (steering_mode, follow_up_mode) = self.current_queue_modes();
7739 let mut state =
7740 <SessionHandle as crate::extensions::ExtensionSession>::get_state(&self.handle).await;
7741 let Some(object) = state.as_object_mut() else {
7742 return self.state_fallback();
7743 };
7744
7745 object.insert(
7746 "isStreaming".to_string(),
7747 Value::Bool(self.is_streaming.load(std::sync::atomic::Ordering::SeqCst)),
7748 );
7749 object.insert(
7750 "isCompacting".to_string(),
7751 Value::Bool(self.is_compacting.load(std::sync::atomic::Ordering::SeqCst)),
7752 );
7753 object.insert(
7754 "steeringMode".to_string(),
7755 Value::String(steering_mode.as_str().to_string()),
7756 );
7757 object.insert(
7758 "followUpMode".to_string(),
7759 Value::String(follow_up_mode.as_str().to_string()),
7760 );
7761 object.insert(
7762 "autoCompactionEnabled".to_string(),
7763 Value::Bool(self.auto_compaction_enabled),
7764 );
7765
7766 state
7767 }
7768
7769 async fn get_messages(&self) -> Vec<crate::session::SessionMessage> {
7770 <SessionHandle as crate::extensions::ExtensionSession>::get_messages(&self.handle).await
7771 }
7772
7773 async fn get_entries(&self) -> Vec<Value> {
7774 <SessionHandle as crate::extensions::ExtensionSession>::get_entries(&self.handle).await
7775 }
7776
7777 async fn get_branch(&self) -> Vec<Value> {
7778 <SessionHandle as crate::extensions::ExtensionSession>::get_branch(&self.handle).await
7779 }
7780
7781 async fn set_name(&self, name: String) -> crate::error::Result<()> {
7782 <SessionHandle as crate::extensions::ExtensionSession>::set_name(&self.handle, name).await
7783 }
7784
7785 async fn append_message(
7786 &self,
7787 message: crate::session::SessionMessage,
7788 ) -> crate::error::Result<()> {
7789 <SessionHandle as crate::extensions::ExtensionSession>::append_message(
7790 &self.handle,
7791 message,
7792 )
7793 .await
7794 }
7795
7796 async fn append_custom_entry(
7797 &self,
7798 custom_type: String,
7799 data: Option<Value>,
7800 ) -> crate::error::Result<()> {
7801 <SessionHandle as crate::extensions::ExtensionSession>::append_custom_entry(
7802 &self.handle,
7803 custom_type,
7804 data,
7805 )
7806 .await
7807 }
7808
7809 async fn set_model(&self, provider: String, model_id: String) -> crate::error::Result<()> {
7810 <SessionHandle as crate::extensions::ExtensionSession>::set_model(
7811 &self.handle,
7812 provider,
7813 model_id,
7814 )
7815 .await
7816 }
7817
7818 async fn get_model(&self) -> (Option<String>, Option<String>) {
7819 <SessionHandle as crate::extensions::ExtensionSession>::get_model(&self.handle).await
7820 }
7821
7822 async fn set_thinking_level(&self, level: String) -> crate::error::Result<()> {
7823 <SessionHandle as crate::extensions::ExtensionSession>::set_thinking_level(
7824 &self.handle,
7825 level,
7826 )
7827 .await
7828 }
7829
7830 async fn get_thinking_level(&self) -> Option<String> {
7831 <SessionHandle as crate::extensions::ExtensionSession>::get_thinking_level(&self.handle)
7832 .await
7833 }
7834
7835 async fn set_label(
7836 &self,
7837 target_id: String,
7838 label: Option<String>,
7839 ) -> crate::error::Result<()> {
7840 <SessionHandle as crate::extensions::ExtensionSession>::set_label(
7841 &self.handle,
7842 target_id,
7843 label,
7844 )
7845 .await
7846 }
7847}
7848
7849impl AgentSession {
7850 pub const fn runtime_repair_mode_from_policy_mode(mode: RepairPolicyMode) -> RepairMode {
7851 match mode {
7852 RepairPolicyMode::Off => RepairMode::Off,
7853 RepairPolicyMode::Suggest => RepairMode::Suggest,
7854 RepairPolicyMode::AutoSafe => RepairMode::AutoSafe,
7855 RepairPolicyMode::AutoStrict => RepairMode::AutoStrict,
7856 }
7857 }
7858
7859 #[allow(clippy::too_many_arguments)]
7860 async fn start_js_extension_runtime(
7861 stage: &'static str,
7862 cwd: &std::path::Path,
7863 tools: Arc<ToolRegistry>,
7864 manager: ExtensionManager,
7865 policy: ExtensionPolicy,
7866 repair_mode: RepairMode,
7867 memory_limit_bytes: usize,
7868 ) -> Result<ExtensionRuntimeHandle> {
7869 let mut config = PiJsRuntimeConfig {
7870 cwd: cwd.display().to_string(),
7871 repair_mode,
7872 ..PiJsRuntimeConfig::default()
7873 };
7874 config.limits.memory_limit_bytes = Some(memory_limit_bytes).filter(|bytes| *bytes > 0);
7875
7876 let runtime =
7877 JsExtensionRuntimeHandle::start_with_policy(config, tools, manager, policy).await?;
7878 tracing::info!(
7879 event = "pi.extension_runtime.engine_decision",
7880 stage,
7881 requested = "quickjs",
7882 selected = "quickjs",
7883 fallback = false,
7884 "Extension runtime engine selected (legacy JS/TS)"
7885 );
7886 Ok(ExtensionRuntimeHandle::Js(runtime))
7887 }
7888
7889 #[allow(clippy::too_many_arguments)]
7890 async fn start_native_extension_runtime(
7891 stage: &'static str,
7892 _cwd: &std::path::Path,
7893 _tools: Arc<ToolRegistry>,
7894 _manager: ExtensionManager,
7895 _policy: ExtensionPolicy,
7896 _repair_mode: RepairMode,
7897 _memory_limit_bytes: usize,
7898 ) -> Result<ExtensionRuntimeHandle> {
7899 let runtime = NativeRustExtensionRuntimeHandle::start().await?;
7900 tracing::info!(
7901 event = "pi.extension_runtime.engine_decision",
7902 stage,
7903 requested = "native-rust",
7904 selected = "native-rust",
7905 fallback = false,
7906 "Extension runtime engine selected (native-rust)"
7907 );
7908 Ok(ExtensionRuntimeHandle::NativeRust(runtime))
7909 }
7910
7911 pub fn new(
7912 agent: Agent,
7913 session: Arc<Mutex<Session>>,
7914 save_enabled: bool,
7915 compaction_settings: ResolvedCompactionSettings,
7916 ) -> Self {
7917 let extension_ai_completion = Arc::new(StdMutex::new(ExtensionAiCompletionHostState {
7918 provider: agent.provider(),
7919 stream_options: agent.stream_options().clone(),
7920 models: Vec::new(),
7921 }));
7922
7923 Self {
7924 agent,
7925 session,
7926 save_enabled,
7927 input_source: InputSource::Interactive,
7928 extensions: None,
7929 extensions_is_streaming: Arc::new(AtomicBool::new(false)),
7930 extensions_is_compacting: Arc::new(AtomicBool::new(false)),
7931 extensions_turn_active: Arc::new(AtomicBool::new(false)),
7932 extensions_pending_idle_actions: Arc::new(StdMutex::new(VecDeque::new())),
7933 extension_queue_modes: None,
7934 extension_injected_queue: None,
7935 extension_ai_completion,
7936 compaction_settings,
7937 compaction_runtime: None,
7938 runtime_handle: None,
7939 compaction_worker: CompactionWorkerState::new(CompactionQuota::default()),
7940 model_registry: None,
7941 auth_storage: None,
7942 api_key_override: None,
7943 semantic_context_bundle: None,
7944 }
7945 }
7946
7947 pub const fn set_input_source(&mut self, source: InputSource) {
7948 self.input_source = source;
7949 }
7950
7951 #[must_use]
7952 pub fn with_runtime_handle(mut self, runtime_handle: RuntimeHandle) -> Self {
7953 self.compaction_runtime = None;
7954 self.runtime_handle = Some(runtime_handle);
7955 self
7956 }
7957
7958 #[must_use]
7959 pub fn with_model_registry(mut self, registry: ModelRegistry) -> Self {
7960 self.set_model_registry(registry);
7961 self
7962 }
7963
7964 #[must_use]
7965 pub fn with_auth_storage(mut self, auth: AuthStorage) -> Self {
7966 self.auth_storage = Some(auth);
7967 self
7968 }
7969
7970 pub fn set_model_registry(&mut self, registry: ModelRegistry) {
7971 self.set_extension_ai_models(pi_ai_model_registry_values(®istry));
7972 self.model_registry = Some(registry);
7973 }
7974
7975 pub fn set_auth_storage(&mut self, auth: AuthStorage) {
7976 self.auth_storage = Some(auth);
7977 }
7978
7979 #[must_use]
7980 pub fn with_api_key_override(mut self, api_key: Option<String>) -> Self {
7981 self.set_api_key_override(api_key);
7982 self
7983 }
7984
7985 pub fn set_api_key_override(&mut self, api_key: Option<String>) {
7986 self.api_key_override = normalize_api_key_opt(api_key);
7987 }
7988
7989 pub fn refresh_extension_completion_host_state(&self) {
7990 let Ok(mut state) = self.extension_ai_completion.lock() else {
7991 tracing::error!("extension completion host state mutex poisoned; keeping stale state");
7992 return;
7993 };
7994 state.provider = self.agent.provider();
7995 state.stream_options = self.agent.stream_options().clone();
7996 }
7997
7998 fn set_extension_ai_models(&self, models: Vec<Value>) {
7999 let Ok(mut state) = self.extension_ai_completion.lock() else {
8000 tracing::error!(
8001 "extension completion host state mutex poisoned; keeping stale model catalog"
8002 );
8003 return;
8004 };
8005 state.models = models;
8006 }
8007
8008 pub fn set_semantic_context_bundle(
8009 &mut self,
8010 injection: Option<SemanticContextBundleInjection>,
8011 ) {
8012 self.semantic_context_bundle = injection;
8013 }
8014
8015 pub const fn semantic_context_bundle(&self) -> Option<&SemanticContextBundleInjection> {
8016 self.semantic_context_bundle.as_ref()
8017 }
8018
8019 pub fn set_queue_modes(&mut self, steering_mode: QueueMode, follow_up_mode: QueueMode) {
8020 self.agent.set_queue_modes(steering_mode, follow_up_mode);
8021
8022 if let Some(queue_modes) = &self.extension_queue_modes
8023 && let Ok(mut state) = queue_modes.lock()
8024 {
8025 state.set_modes(steering_mode, follow_up_mode);
8026 }
8027
8028 if let Some(injected_queue) = &self.extension_injected_queue
8029 && let Ok(mut queue) = injected_queue.lock()
8030 {
8031 queue.set_modes(steering_mode, follow_up_mode);
8032 }
8033 }
8034
8035 pub const fn set_compaction_context_window(&mut self, context_window_tokens: u32) {
8036 self.compaction_settings.context_window_tokens = context_window_tokens;
8037 }
8038
    /// Switches the active provider/model to `provider_id`/`model_id` and
    /// records the selection in the session.
    ///
    /// If the pair is already active, only the thinking level is refreshed: it
    /// is clamped via the registry entry when the registry knows the model, and
    /// kept as-is otherwise. If the pair is not active, the model must be in the
    /// registry. When the model requires configured credentials, a key must also
    /// resolve.
    ///
    /// The session then receives a model-change entry and/or a thinking-level
    /// change entry, each only when its effective value differs. Its header is
    /// overwritten with the new selection, and the session is persisted.
    ///
    /// # Errors
    ///
    /// - Auth error when required credentials are missing.
    /// - Validation error when the model cannot be found or activated.
    /// - Session error when the session lock cannot be acquired.
    /// - Any error returned by `persist_session`.
    pub async fn set_provider_model(&mut self, provider_id: &str, model_id: &str) -> Result<()> {
        let already_active = {
            let provider = self.agent.provider();
            provider.name().eq(provider_id) && provider.model_id().eq(model_id)
        };
        let current_thinking = self
            .agent
            .stream_options()
            .thinking_level
            .unwrap_or_default();

        let target_entry = self
            .model_registry
            .as_ref()
            .and_then(|registry| registry.find(provider_id, model_id));
        let next_thinking = if let Some(target_entry) = target_entry {
            let resolved_key = self.resolve_stream_api_key_for_model(&target_entry);
            // Credentials are only enforced when switching to a different model.
            if !already_active
                && model_requires_configured_credential(&target_entry)
                && resolved_key.is_none()
            {
                return Err(Error::auth(format!(
                    "Missing credentials for {provider_id}/{model_id}"
                )));
            }
            self.clamp_thinking_level_for_model(provider_id, model_id, current_thinking)
        } else if already_active {
            // Active but unknown to the registry: nothing to clamp against.
            current_thinking
        } else {
            return Err(Error::validation(format!(
                "Unable to switch provider/model to {provider_id}/{model_id}"
            )));
        };

        // Update the agent (and the extension host mirror) before touching the session.
        if !already_active {
            self.apply_session_model_selection(provider_id, model_id)?;
        }
        self.agent.stream_options_mut().thinking_level = Some(next_thinking);
        self.refresh_extension_completion_host_state();

        {
            let cx = crate::agent_cx::AgentCx::for_request();
            let mut session = self
                .session
                .lock(cx.cx())
                .await
                .map_err(|e| Error::session(e.to_string()))?;
            let previous_model = session.effective_model_for_current_path();
            // An unparseable stored level is treated as "no previous level".
            let previous_thinking = session
                .effective_thinking_level_for_current_path()
                .as_deref()
                .and_then(|value| value.parse::<crate::model::ThinkingLevel>().ok());
            // Append history entries only when the effective value actually changes.
            if previous_model
                .as_ref()
                .map(|(provider, model_id)| (provider.as_str(), model_id.as_str()))
                != Some((provider_id, model_id))
            {
                session.append_model_change(provider_id.to_string(), model_id.to_string());
            }
            session.set_model_header(
                Some(provider_id.to_string()),
                Some(model_id.to_string()),
                Some(next_thinking.to_string()),
            );
            if !previous_thinking.is_some_and(|previous| previous.eq(&next_thinking)) {
                session.append_thinking_level_change(next_thinking.to_string());
            }
        }

        self.persist_session().await
    }
8110
8111 pub(crate) fn clamp_thinking_level_for_model(
8112 &self,
8113 provider_id: &str,
8114 model_id: &str,
8115 level: crate::model::ThinkingLevel,
8116 ) -> crate::model::ThinkingLevel {
8117 self.model_registry
8118 .as_ref()
8119 .and_then(|registry| registry.find(provider_id, model_id))
8120 .map_or(level, |entry| entry.clamp_thinking_level(level))
8121 }
8122
8123 fn resolve_stream_api_key_for_model(&self, entry: &ModelEntry) -> Option<String> {
8124 let normalize = |key_opt: Option<String>| {
8125 key_opt.and_then(|key| {
8126 let trimmed = key.trim();
8127 (!trimmed.is_empty()).then(|| trimmed.to_string())
8128 })
8129 };
8130
8131 normalize(self.api_key_override.clone())
8132 .or_else(|| {
8133 self.auth_storage
8134 .as_ref()
8135 .and_then(|auth| normalize(auth.resolve_api_key(&entry.model.provider, None)))
8136 })
8137 .or_else(|| normalize(entry.api_key.clone()))
8138 }
8139
8140 pub(crate) async fn sync_runtime_selection_from_session_header(&mut self) -> Result<()> {
8141 let session_state = {
8142 let cx = crate::agent_cx::AgentCx::for_request();
8143 let session = self
8144 .session
8145 .lock(cx.cx())
8146 .await
8147 .map_err(|e| Error::session(e.to_string()))?;
8148 (
8149 session.effective_model_for_current_path(),
8150 session.effective_thinking_level_for_current_path(),
8151 )
8152 };
8153
8154 let (session_model, session_thinking) = session_state;
8155 let current_thinking = self
8156 .agent
8157 .stream_options()
8158 .thinking_level
8159 .unwrap_or_default();
8160
8161 if let Some((provider_id, model_id)) = session_model.as_ref() {
8162 self.apply_session_model_selection(provider_id, model_id)?;
8163 }
8164
8165 let parsed_session_thinking = session_thinking.as_deref().and_then(|raw| {
8166 raw.parse::<crate::model::ThinkingLevel>().map_or_else(
8167 |_| {
8168 tracing::warn!("Ignoring invalid session thinking level: {raw}");
8169 None
8170 },
8171 Some,
8172 )
8173 });
8174 let requested = parsed_session_thinking.unwrap_or(current_thinking);
8175
8176 let effective = if let Some((provider_id, model_id)) = session_model.as_ref() {
8177 self.clamp_thinking_level_for_model(provider_id, model_id, requested)
8178 } else {
8179 requested
8180 };
8181
8182 self.agent.stream_options_mut().thinking_level = Some(effective);
8183 self.refresh_extension_completion_host_state();
8184
8185 let thinking_changed = !effective.eq(¤t_thinking);
8186 let persist_needed = if session_thinking.is_some() {
8187 !parsed_session_thinking.is_some_and(|parsed| parsed.eq(&effective))
8188 } else {
8189 thinking_changed
8190 };
8191 if !persist_needed {
8192 return Ok(());
8193 }
8194
8195 {
8196 let cx = crate::agent_cx::AgentCx::for_request();
8197 let mut session = self
8198 .session
8199 .lock(cx.cx())
8200 .await
8201 .map_err(|e| Error::session(e.to_string()))?;
8202 let previous_thinking = session
8203 .header
8204 .thinking_level
8205 .as_deref()
8206 .and_then(|value| value.parse::<crate::model::ThinkingLevel>().ok());
8207 session.set_model_header(None, None, Some(effective.to_string()));
8208 if thinking_changed
8209 && !previous_thinking.is_some_and(|previous| previous.eq(&effective))
8210 {
8211 session.append_thinking_level_change(effective.to_string());
8212 }
8213 }
8214
8215 self.persist_session().await
8216 }
8217
8218 fn apply_session_model_selection(&mut self, provider_id: &str, model_id: &str) -> Result<()> {
8219 if self.agent.provider().name().eq(provider_id)
8220 && self.agent.provider().model_id().eq(model_id)
8221 {
8222 return Ok(());
8223 }
8224
8225 let Some(registry) = &self.model_registry else {
8226 return Err(Error::validation(format!(
8227 "Unable to switch provider/model to {provider_id}/{model_id}"
8228 )));
8229 };
8230
8231 let Some(entry) = registry.find(provider_id, model_id) else {
8232 return Err(Error::validation(format!(
8233 "Unable to switch provider/model to {provider_id}/{model_id}"
8234 )));
8235 };
8236
8237 let resolved_key = self.resolve_stream_api_key_for_model(&entry);
8238 if model_requires_configured_credential(&entry) && resolved_key.is_none() {
8239 return Err(Error::auth(format!(
8240 "Missing credentials for {provider_id}/{model_id}"
8241 )));
8242 }
8243
8244 match crate::providers::create_provider(
8245 &entry,
8246 self.extensions.as_ref().map(ExtensionRegion::manager),
8247 ) {
8248 Ok(provider) => {
8249 tracing::info!("Updating agent provider to {provider_id}/{model_id}");
8250 self.agent.set_provider(provider);
8251
8252 let stream_options = self.agent.stream_options_mut();
8253 stream_options.api_key.clone_from(&resolved_key);
8254 stream_options.headers.clone_from(&entry.headers);
8255 self.refresh_extension_completion_host_state();
8256 Ok(())
8257 }
8258 Err(e) => Err(Error::validation(format!(
8259 "Unable to switch provider/model to {provider_id}/{model_id}: {e}"
8260 ))),
8261 }
8262 }
8263
8264 pub const fn save_enabled(&self) -> bool {
8265 self.save_enabled
8266 }
8267
8268 pub async fn compact_now(
8270 &mut self,
8271 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
8272 ) -> Result<()> {
8273 self.compact_synchronous(Arc::new(on_event)).await
8274 }
8275
8276 pub async fn execute_extension_command(
8277 &mut self,
8278 command_name: &str,
8279 args: &str,
8280 timeout_ms: u64,
8281 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
8282 ) -> Result<Value> {
8283 self.execute_extension_command_with_abort(command_name, args, timeout_ms, None, on_event)
8284 .await
8285 }
8286
8287 pub async fn execute_extension_command_with_abort(
8288 &mut self,
8289 command_name: &str,
8290 args: &str,
8291 timeout_ms: u64,
8292 abort: Option<AbortSignal>,
8293 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
8294 ) -> Result<Value> {
8295 let manager = self
8296 .extensions
8297 .as_ref()
8298 .map(ExtensionRegion::manager)
8299 .ok_or_else(|| Error::extension("Extensions are disabled"))?
8300 .clone();
8301 let on_event: AgentEventHandler = Arc::new(on_event);
8302
8303 self.run_pending_idle_actions_with_abort(abort.clone(), Arc::clone(&on_event))
8304 .await?;
8305
8306 let command_result = manager
8307 .execute_command(command_name, args, timeout_ms)
8308 .await;
8309 let replay_result = self
8310 .run_pending_idle_actions_with_abort(abort, Arc::clone(&on_event))
8311 .await;
8312
8313 match command_result {
8314 Ok(value) => {
8315 replay_result?;
8316 Ok(value)
8317 }
8318 Err(err) => {
8319 if let Err(replay_err) = replay_result {
8320 tracing::warn!(
8321 "extension command follow-up replay failed after command error: {replay_err}"
8322 );
8323 }
8324 Err(err)
8325 }
8326 }
8327 }
8328
    /// Advances background auto-compaction by one step.
    ///
    /// In order, this:
    /// 1. returns immediately when compaction is disabled;
    /// 2. collects a finished background job, if any. It clears the
    ///    `extensions_is_compacting` flag, then either applies the result or
    ///    emits `AutoCompactionEnd` with the error message;
    /// 3. returns if the worker cannot start another job;
    /// 4. snapshots the current-path entries and asks `prepare_compaction`
    ///    whether compaction is warranted;
    /// 5. checks the worker's admission decision, emits `AutoCompactionStart`,
    ///    and lets extensions cancel the compaction or supply their own via the
    ///    before-compact hook;
    /// 6. otherwise hands the preparation to the background worker on the
    ///    compaction runtime and marks compaction as in progress.
    ///
    /// Worker failures and runtime-initialisation failures are reported through
    /// `on_event` rather than returned.
    ///
    /// # Errors
    ///
    /// Returns an error if the session lock cannot be acquired, or if applying
    /// a compaction fails. That covers both a worker result and an
    /// extension-supplied compaction.
    #[allow(clippy::too_many_lines)]
    async fn maybe_compact(&mut self, on_event: AgentEventHandler) -> Result<()> {
        if !self.compaction_settings.enabled {
            return Ok(());
        }

        // Harvest a previously started background job before considering a new one.
        if let Some(outcome) = self.compaction_worker.try_recv().await {
            self.extensions_is_compacting
                .store(false, std::sync::atomic::Ordering::SeqCst);
            match outcome {
                Ok(result) => {
                    self.apply_compaction_result(result, Arc::clone(&on_event))
                        .await?;
                }
                Err(e) => {
                    on_event(AgentEvent::AutoCompactionEnd {
                        result: None,
                        aborted: false,
                        will_retry: false,
                        error_message: Some(e.to_string()),
                    });
                }
            }
        }

        if !self.compaction_worker.can_start() {
            return Ok(());
        }

        // Snapshot the current path under the session lock; the lock is released at block end.
        let (entries, preparation) = {
            let cx = crate::agent_cx::AgentCx::for_request();
            let mut session = self
                .session
                .lock(cx.cx())
                .await
                .map_err(|e| Error::session(e.to_string()))?;
            session.ensure_entry_ids();
            let entries = session
                .entries_for_current_path()
                .into_iter()
                .cloned()
                .collect::<Vec<_>>();
            let prep = compaction::prepare_compaction(&entries, self.compaction_settings.clone());
            (entries, prep)
        };

        if let Some(prep) = preparation {
            let admission = self
                .compaction_worker
                .admission_decision(Some(&prep), &CompactionAdmissionSignals::default());
            if !admission.allowed {
                tracing::info!(
                    reason = admission.reason.as_str(),
                    tokens_before = admission.tokens_before,
                    "Background compaction admission denied"
                );
                return Ok(());
            }

            on_event(AgentEvent::AutoCompactionStart {
                reason: format!("threshold;admission={}", admission.reason.as_str()),
            });

            let before_outcome = self.dispatch_before_compact(&prep, &entries, None).await;
            if before_outcome.cancel {
                on_event(AgentEvent::AutoCompactionEnd {
                    result: None,
                    aborted: true,
                    will_retry: false,
                    error_message: None,
                });
                return Ok(());
            }

            // An extension supplied its own compaction: apply it directly, no worker needed.
            if let Some(compaction) = before_outcome.compaction {
                let result_value = Some(Self::auto_compaction_result_payload(
                    compaction.summary.clone(),
                    compaction.first_kept_entry_id.clone(),
                    compaction.tokens_before,
                    compaction.details.clone(),
                ));
                // The flag is raised only while the entry is written and cleared before
                // any apply error is propagated.
                self.extensions_is_compacting
                    .store(true, std::sync::atomic::Ordering::SeqCst);
                let apply_result = self
                    .apply_compaction_entry(
                        compaction.summary,
                        compaction.first_kept_entry_id,
                        compaction.tokens_before,
                        compaction.details,
                        true,
                    )
                    .await;
                self.extensions_is_compacting
                    .store(false, std::sync::atomic::Ordering::SeqCst);
                apply_result?;
                on_event(AgentEvent::AutoCompactionEnd {
                    result: result_value,
                    aborted: false,
                    will_retry: false,
                    error_message: None,
                });
                return Ok(());
            }

            let provider = self.agent.provider();
            // A missing API key is handed to the worker as an empty string.
            let credential = self
                .agent
                .stream_options()
                .api_key
                .clone()
                .unwrap_or_default();

            let runtime_handle = match self.compaction_runtime_handle() {
                Ok(runtime_handle) => runtime_handle,
                Err(e) => {
                    on_event(AgentEvent::AutoCompactionEnd {
                        result: None,
                        aborted: false,
                        will_retry: false,
                        error_message: Some(e.to_string()),
                    });
                    return Ok(());
                }
            };

            self.compaction_worker
                .start(&runtime_handle, prep, provider, credential, None);
            // Cleared again when the job's outcome is harvested at the top of a later call.
            self.extensions_is_compacting
                .store(true, std::sync::atomic::Ordering::SeqCst);
        }

        Ok(())
    }
8469
8470 fn compaction_runtime_handle(&mut self) -> Result<RuntimeHandle> {
8471 if let Some(runtime_handle) = self.runtime_handle.clone() {
8472 return Ok(runtime_handle);
8473 }
8474
8475 let runtime = RuntimeBuilder::new().build().map_err(|e| {
8476 Error::session(format!("Background compaction runtime init failed: {e}"))
8477 })?;
8478 let runtime_handle = runtime.handle();
8479 self.compaction_runtime = Some(runtime);
8480 self.runtime_handle = Some(runtime_handle.clone());
8481 Ok(runtime_handle)
8482 }
8483
8484 fn auto_compaction_result_payload(
8485 summary: String,
8486 first_kept_entry_id: String,
8487 tokens_before: u64,
8488 details: Option<Value>,
8489 ) -> Value {
8490 let mut payload = serde_json::Map::new();
8491 payload.insert("summary".to_string(), Value::String(summary));
8492 payload.insert(
8493 "firstKeptEntryId".to_string(),
8494 Value::String(first_kept_entry_id),
8495 );
8496 payload.insert("tokensBefore".to_string(), Value::from(tokens_before));
8497 if let Some(details) = details {
8498 payload.insert("details".to_string(), details);
8499 }
8500 Value::Object(payload)
8501 }
8502
8503 async fn apply_compaction_entry(
8504 &self,
8505 summary: String,
8506 first_kept_entry_id: String,
8507 tokens_before: u64,
8508 details: Option<Value>,
8509 from_extension: bool,
8510 ) -> Result<()> {
8511 let cx = crate::agent_cx::AgentCx::for_request();
8512 let mut session = self
8513 .session
8514 .lock(cx.cx())
8515 .await
8516 .map_err(|e| Error::session(e.to_string()))?;
8517
8518 let from_hook = if from_extension { Some(true) } else { None };
8519 let entry_id = session.append_compaction(
8520 summary,
8521 first_kept_entry_id,
8522 tokens_before,
8523 details,
8524 from_hook,
8525 );
8526
8527 if self.save_enabled {
8528 session
8529 .flush_autosave(AutosaveFlushTrigger::Periodic)
8530 .await?;
8531 }
8532
8533 let compaction_entry = session.get_entry(&entry_id).and_then(|entry| {
8534 if let crate::session::SessionEntry::Compaction(compaction) = entry {
8535 Some(compaction.clone())
8536 } else {
8537 None
8538 }
8539 });
8540 drop(session);
8541
8542 if let (Some(region), Some(compaction_entry)) = (&self.extensions, compaction_entry) {
8543 let payload = json!({
8544 "compactionEntry": compaction_entry,
8545 "fromExtension": from_extension,
8546 });
8547 if let Err(err) = region
8548 .manager()
8549 .dispatch_event(ExtensionEventName::SessionCompact, Some(payload))
8550 .await
8551 {
8552 tracing::warn!("session_compact extension hook failed (fail-open): {err}");
8553 }
8554 }
8555
8556 Ok(())
8557 }
8558
8559 async fn apply_compaction_result(
8561 &self,
8562 result: compaction::CompactionResult,
8563 on_event: AgentEventHandler,
8564 ) -> Result<()> {
8565 let details = Some(compaction::compaction_details_to_value(&result.details)?);
8566 let result_value = Some(Self::auto_compaction_result_payload(
8567 result.summary.clone(),
8568 result.first_kept_entry_id.clone(),
8569 result.tokens_before,
8570 details.clone(),
8571 ));
8572
8573 self.apply_compaction_entry(
8574 result.summary,
8575 result.first_kept_entry_id,
8576 result.tokens_before,
8577 details,
8578 false,
8579 )
8580 .await?;
8581
8582 on_event(AgentEvent::AutoCompactionEnd {
8583 result: result_value,
8584 aborted: false,
8585 will_retry: false,
8586 error_message: None,
8587 });
8588
8589 Ok(())
8590 }
8591
    /// Runs one compaction pass inline, without the background worker or its
    /// admission decision. This is the path `compact_now` uses.
    ///
    /// Returns `Ok(())` without doing anything when compaction is disabled or
    /// when `prepare_compaction` finds nothing to compact. Otherwise it:
    /// 1. emits `AutoCompactionStart` with reason `"threshold"`;
    /// 2. lets extensions cancel the compaction or supply their own via the
    ///    before-compact hook;
    /// 3. otherwise awaits `compaction::compact` with the agent's current
    ///    provider.
    ///
    /// # Errors
    ///
    /// Unlike `maybe_compact`, this method returns errors for cancellation by an
    /// extension and for failures of `compaction::compact`. Each is returned
    /// after the matching `AutoCompactionEnd` event is emitted. Session-lock and
    /// apply failures are propagated as well.
    async fn compact_synchronous(&self, on_event: AgentEventHandler) -> Result<()> {
        if !self.compaction_settings.enabled {
            return Ok(());
        }

        // Snapshot the current path under the session lock; the lock is released at block end.
        let (entries, preparation) = {
            let cx = crate::agent_cx::AgentCx::for_request();
            let mut session = self
                .session
                .lock(cx.cx())
                .await
                .map_err(|e| Error::session(e.to_string()))?;
            session.ensure_entry_ids();
            let entries = session
                .entries_for_current_path()
                .into_iter()
                .cloned()
                .collect::<Vec<_>>();
            let prep = compaction::prepare_compaction(&entries, self.compaction_settings.clone());
            (entries, prep)
        };

        if let Some(prep) = preparation {
            on_event(AgentEvent::AutoCompactionStart {
                reason: "threshold".to_string(),
            });

            let before_outcome = self.dispatch_before_compact(&prep, &entries, None).await;
            if before_outcome.cancel {
                on_event(AgentEvent::AutoCompactionEnd {
                    result: None,
                    aborted: true,
                    will_retry: false,
                    error_message: None,
                });
                return Err(Error::extension("Compaction cancelled".to_string()));
            }

            // An extension supplied its own compaction: apply it instead of calling the provider.
            if let Some(compaction) = before_outcome.compaction {
                let result_value = Some(Self::auto_compaction_result_payload(
                    compaction.summary.clone(),
                    compaction.first_kept_entry_id.clone(),
                    compaction.tokens_before,
                    compaction.details.clone(),
                ));
                self.extensions_is_compacting
                    .store(true, std::sync::atomic::Ordering::SeqCst);
                let apply_result = self
                    .apply_compaction_entry(
                        compaction.summary,
                        compaction.first_kept_entry_id,
                        compaction.tokens_before,
                        compaction.details,
                        true,
                    )
                    .await;
                // Clear the flag before propagating any apply error.
                self.extensions_is_compacting
                    .store(false, std::sync::atomic::Ordering::SeqCst);
                apply_result?;
                on_event(AgentEvent::AutoCompactionEnd {
                    result: result_value,
                    aborted: false,
                    will_retry: false,
                    error_message: None,
                });
                return Ok(());
            }
            self.extensions_is_compacting
                .store(true, std::sync::atomic::Ordering::SeqCst);

            let provider = self.agent.provider();
            // A missing API key is passed as an empty string.
            let credential = self
                .agent
                .stream_options()
                .api_key
                .clone()
                .unwrap_or_default();

            let compaction_result = compaction::compact(prep, provider, &credential, None).await;
            // The flag is cleared whether or not compaction succeeded.
            self.extensions_is_compacting
                .store(false, std::sync::atomic::Ordering::SeqCst);

            match compaction_result {
                Ok(result) => {
                    self.apply_compaction_result(result, Arc::clone(&on_event))
                        .await?;
                }
                Err(e) => {
                    on_event(AgentEvent::AutoCompactionEnd {
                        result: None,
                        aborted: false,
                        will_retry: false,
                        error_message: Some(e.to_string()),
                    });
                    return Err(e);
                }
            }
        }
        Ok(())
    }
8693
8694 fn resolve_extension_policy_for_enable(
8695 config: Option<&crate::config::Config>,
8696 policy: Option<ExtensionPolicy>,
8697 ) -> ExtensionPolicy {
8698 policy.unwrap_or_else(|| {
8699 config.map_or_else(
8700 || crate::config::Config::default().resolve_extension_policy(None),
8701 |cfg| cfg.resolve_extension_policy(None),
8702 )
8703 })
8704 }
8705
8706 pub async fn enable_extensions(
8707 &mut self,
8708 enabled_tools: &[&str],
8709 cwd: &std::path::Path,
8710 config: Option<&crate::config::Config>,
8711 extension_entries: &[std::path::PathBuf],
8712 ) -> Result<()> {
8713 self.enable_extensions_with_policy(
8714 enabled_tools,
8715 cwd,
8716 config,
8717 extension_entries,
8718 None,
8719 None,
8720 None,
8721 )
8722 .await
8723 }
8724
8725 #[allow(clippy::too_many_lines, clippy::too_many_arguments)]
8726 pub async fn enable_extensions_with_policy(
8727 &mut self,
8728 enabled_tools: &[&str],
8729 cwd: &std::path::Path,
8730 config: Option<&crate::config::Config>,
8731 extension_entries: &[std::path::PathBuf],
8732 policy: Option<ExtensionPolicy>,
8733 repair_policy: Option<RepairPolicyMode>,
8734 pre_warmed: Option<PreWarmedExtensionRuntime>,
8735 ) -> Result<()> {
8736 let mut js_specs: Vec<JsExtensionLoadSpec> = Vec::new();
8737 let mut native_specs: Vec<NativeRustExtensionLoadSpec> = Vec::new();
8738 #[cfg(feature = "wasm-host")]
8739 let mut wasm_specs: Vec<WasmExtensionLoadSpec> = Vec::new();
8740
8741 for entry in extension_entries {
8742 match resolve_extension_load_spec(entry)? {
8743 ExtensionLoadSpec::Js(spec) => js_specs.push(spec),
8744 ExtensionLoadSpec::NativeRust(spec) => native_specs.push(spec),
8745 #[cfg(feature = "wasm-host")]
8746 ExtensionLoadSpec::Wasm(spec) => wasm_specs.push(spec),
8747 }
8748 }
8749
8750 if !js_specs.is_empty() && !native_specs.is_empty() {
8751 return Err(Error::validation(
8752 "Mixed extension runtimes are not supported in one session yet. Use either JS/TS extensions (QuickJS) or native-rust descriptors (*.native.json), but not both at once."
8753 .to_string(),
8754 ));
8755 }
8756
8757 #[cfg(feature = "wasm-host")]
8758 if js_specs.is_empty() && native_specs.is_empty() && wasm_specs.is_empty() {
8759 self.extensions = None;
8760 self.agent.extensions = None;
8761 self.extension_queue_modes = None;
8762 self.extension_injected_queue = None;
8763 return Ok(());
8764 }
8765
8766 #[cfg(not(feature = "wasm-host"))]
8767 if js_specs.is_empty() && native_specs.is_empty() {
8768 self.extensions = None;
8769 self.agent.extensions = None;
8770 self.extension_queue_modes = None;
8771 self.extension_injected_queue = None;
8772 return Ok(());
8773 }
8774
8775 let resolved_policy = Self::resolve_extension_policy_for_enable(config, policy);
8776 let resolved_repair_policy = repair_policy
8777 .or_else(|| config.map(|cfg| cfg.resolve_repair_policy(None)))
8778 .unwrap_or(RepairPolicyMode::AutoSafe);
8779 let runtime_repair_mode =
8780 Self::runtime_repair_mode_from_policy_mode(resolved_repair_policy);
8781 let memory_limit_bytes =
8782 (resolved_policy.max_memory_mb as usize).saturating_mul(1024 * 1024);
8783 let wants_js_runtime = !js_specs.is_empty();
8784
8785 #[allow(unused_variables)]
8788 let (manager, tools) = if let Some(pre) = pre_warmed {
8789 let manager = pre.manager;
8790 let tools = pre.tools;
8791 let runtime = match pre.runtime {
8792 ExtensionRuntimeHandle::NativeRust(runtime) => {
8793 if wants_js_runtime {
8794 tracing::warn!(
8795 event = "pi.extension_runtime.prewarm.mismatch",
8796 expected = "quickjs",
8797 got = "native-rust",
8798 "Pre-warmed runtime mismatched requested JS mode; creating quickjs runtime"
8799 );
8800 Self::start_js_extension_runtime(
8801 "agent_enable_extensions_prewarm_mismatch",
8802 cwd,
8803 Arc::clone(&tools),
8804 manager.clone(),
8805 resolved_policy.clone(),
8806 runtime_repair_mode,
8807 memory_limit_bytes,
8808 )
8809 .await?
8810 } else {
8811 tracing::info!(
8812 event = "pi.extension_runtime.engine_decision",
8813 stage = "agent_enable_extensions_prewarmed",
8814 requested = "native-rust",
8815 selected = "native-rust",
8816 fallback = false,
8817 "Using pre-warmed extension runtime"
8818 );
8819 ExtensionRuntimeHandle::NativeRust(runtime)
8820 }
8821 }
8822 ExtensionRuntimeHandle::Js(runtime) => {
8823 if wants_js_runtime {
8824 tracing::info!(
8825 event = "pi.extension_runtime.engine_decision",
8826 stage = "agent_enable_extensions_prewarmed",
8827 requested = "quickjs",
8828 selected = "quickjs",
8829 fallback = false,
8830 "Using pre-warmed extension runtime"
8831 );
8832 ExtensionRuntimeHandle::Js(runtime)
8833 } else {
8834 tracing::warn!(
8835 event = "pi.extension_runtime.prewarm.mismatch",
8836 expected = "native-rust",
8837 got = "quickjs",
8838 "Pre-warmed runtime mismatched requested native mode; creating native-rust runtime"
8839 );
8840 Self::start_native_extension_runtime(
8841 "agent_enable_extensions_prewarm_mismatch",
8842 cwd,
8843 Arc::clone(&tools),
8844 manager.clone(),
8845 resolved_policy.clone(),
8846 runtime_repair_mode,
8847 memory_limit_bytes,
8848 )
8849 .await?
8850 }
8851 }
8852 };
8853 manager.set_runtime(runtime);
8854 (manager, tools)
8855 } else {
8856 let manager = ExtensionManager::new();
8857 manager.set_cwd(cwd.display().to_string());
8858 let tools = Arc::new(ToolRegistry::new(enabled_tools, cwd, config));
8859
8860 if let Some(cfg) = config {
8861 let resolved_risk = cfg.resolve_extension_risk_with_metadata();
8862 tracing::info!(
8863 event = "pi.extension_runtime_risk.config",
8864 source = resolved_risk.source,
8865 enabled = resolved_risk.settings.enabled,
8866 alpha = resolved_risk.settings.alpha,
8867 window_size = resolved_risk.settings.window_size,
8868 ledger_limit = resolved_risk.settings.ledger_limit,
8869 fail_closed = resolved_risk.settings.fail_closed,
8870 "Resolved extension runtime risk settings"
8871 );
8872 manager.set_runtime_risk_config(resolved_risk.settings);
8873 }
8874
8875 let runtime = if wants_js_runtime {
8876 Self::start_js_extension_runtime(
8877 "agent_enable_extensions_boot",
8878 cwd,
8879 Arc::clone(&tools),
8880 manager.clone(),
8881 resolved_policy.clone(),
8882 runtime_repair_mode,
8883 memory_limit_bytes,
8884 )
8885 .await?
8886 } else {
8887 Self::start_native_extension_runtime(
8888 "agent_enable_extensions_boot",
8889 cwd,
8890 Arc::clone(&tools),
8891 manager.clone(),
8892 resolved_policy.clone(),
8893 runtime_repair_mode,
8894 memory_limit_bytes,
8895 )
8896 .await?
8897 };
8898 manager.set_runtime(runtime);
8899 (manager, tools)
8900 };
8901
8902 let (steering_mode, follow_up_mode) = self.agent.queue_modes();
8906 let queue_modes = Arc::new(StdMutex::new(ExtensionQueueModeState::new(
8907 steering_mode,
8908 follow_up_mode,
8909 )));
8910 manager.set_session(Arc::new(AgentExtensionSession {
8911 handle: SessionHandle(self.session.clone()),
8912 is_streaming: Arc::clone(&self.extensions_is_streaming),
8913 is_compacting: Arc::clone(&self.extensions_is_compacting),
8914 queue_modes: Arc::clone(&queue_modes),
8915 auto_compaction_enabled: self.compaction_settings.enabled,
8916 }));
8917
8918 let injected = Arc::new(StdMutex::new(ExtensionInjectedQueue::new(
8919 steering_mode,
8920 follow_up_mode,
8921 )));
8922 let host_actions = AgentSessionHostActions {
8923 session: Arc::clone(&self.session),
8924 injected: Arc::clone(&injected),
8925 is_streaming: Arc::clone(&self.extensions_is_streaming),
8926 is_turn_active: Arc::clone(&self.extensions_turn_active),
8927 pending_idle_actions: Arc::clone(&self.extensions_pending_idle_actions),
8928 ai_completion: Arc::clone(&self.extension_ai_completion),
8929 };
8930 self.extension_queue_modes = Some(Arc::clone(&queue_modes));
8931 self.extension_injected_queue = Some(Arc::clone(&injected));
8932 manager.set_host_actions(Arc::new(host_actions));
8933 {
8934 let steering_queue = Arc::clone(&injected);
8935 let follow_up_queue = Arc::clone(&injected);
8936 let steering_fetcher = move || -> BoxFuture<'static, Vec<Message>> {
8937 let steering_queue = Arc::clone(&steering_queue);
8938 Box::pin(async move {
8939 let Ok(mut queue) = steering_queue.lock() else {
8940 return Vec::new();
8941 };
8942 queue.pop_steering()
8943 })
8944 };
8945 let follow_up_fetcher = move || -> BoxFuture<'static, Vec<Message>> {
8946 let follow_up_queue = Arc::clone(&follow_up_queue);
8947 Box::pin(async move {
8948 let Ok(mut queue) = follow_up_queue.lock() else {
8949 return Vec::new();
8950 };
8951 queue.pop_follow_up()
8952 })
8953 };
8954 self.agent.register_message_fetchers(
8955 Some(Arc::new(steering_fetcher)),
8956 Some(Arc::new(follow_up_fetcher)),
8957 );
8958 }
8959
8960 if !js_specs.is_empty() {
8961 manager.load_js_extensions(js_specs).await?;
8962 }
8963
8964 if !native_specs.is_empty() {
8965 manager.load_native_extensions(native_specs).await?;
8966 }
8967
8968 if let Some(rt) = manager.runtime() {
8970 let events = rt.drain_repair_events().await;
8971 if !events.is_empty() {
8972 log_repair_diagnostics(&events);
8973 }
8974 }
8975
8976 #[cfg(feature = "wasm-host")]
8977 if !wasm_specs.is_empty() {
8978 let host = WasmExtensionHost::new(cwd, resolved_policy.clone())?;
8979 manager
8980 .load_wasm_extensions(&host, wasm_specs, Arc::clone(&tools))
8981 .await?;
8982 }
8983
8984 let session_path = {
8987 let cx = crate::agent_cx::AgentCx::for_request();
8988 let session = self
8989 .session
8990 .lock(cx.cx())
8991 .await
8992 .map_err(|e| Error::extension(e.to_string()))?;
8993 session.path.as_ref().map(|p| p.display().to_string())
8994 };
8995
8996 if let Err(err) = manager
8997 .dispatch_event(
8998 ExtensionEventName::Startup,
8999 Some(serde_json::json!({
9000 "version": env!("CARGO_PKG_VERSION"),
9001 "sessionFile": session_path,
9002 })),
9003 )
9004 .await
9005 {
9006 tracing::warn!("startup extension hook failed (fail-open): {err}");
9007 }
9008
9009 if let Err(err) = manager
9010 .dispatch_event(ExtensionEventName::SessionStart, None)
9011 .await
9012 {
9013 tracing::warn!("session_start extension hook failed (fail-open): {err}");
9014 }
9015
9016 let ctx_payload = serde_json::json!({ "cwd": cwd.display().to_string() });
9017 let wrappers = collect_extension_tool_wrappers(&manager, ctx_payload).await?;
9018 self.agent.extend_tools(wrappers);
9019 self.agent.extensions = Some(manager.clone());
9020 self.extensions = Some(ExtensionRegion::new(manager));
9021 Ok(())
9022 }
9023
9024 pub async fn save_and_index(&mut self) -> Result<()> {
9025 if self.save_enabled {
9026 let cx = crate::agent_cx::AgentCx::for_request();
9027 let mut session = self
9028 .session
9029 .lock(cx.cx())
9030 .await
9031 .map_err(|e| Error::session(e.to_string()))?;
9032 session
9033 .flush_autosave(AutosaveFlushTrigger::Periodic)
9034 .await?;
9035 }
9036 Ok(())
9037 }
9038
9039 pub async fn persist_session(&mut self) -> Result<()> {
9040 if !self.save_enabled {
9041 return Ok(());
9042 }
9043 let cx = crate::agent_cx::AgentCx::for_request();
9044 let mut session = self
9045 .session
9046 .lock(cx.cx())
9047 .await
9048 .map_err(|e| Error::session(e.to_string()))?;
9049 session
9050 .flush_autosave(AutosaveFlushTrigger::Periodic)
9051 .await?;
9052 Ok(())
9053 }
9054
9055 pub async fn run_text(
9056 &mut self,
9057 input: String,
9058 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
9059 ) -> Result<AssistantMessage> {
9060 self.run_text_with_abort(input, None, on_event).await
9061 }
9062
9063 pub async fn run_text_with_abort(
9064 &mut self,
9065 input: String,
9066 abort: Option<AbortSignal>,
9067 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
9068 ) -> Result<AssistantMessage> {
9069 self.extensions_turn_active.store(true, Ordering::SeqCst);
9070 let result = async {
9071 let outcome = self.dispatch_input_event(input, Vec::new()).await?;
9072 let (text, images) = match outcome {
9073 InputEventOutcome::Continue { text, images } => (text, images),
9074 InputEventOutcome::Block { reason } => {
9075 let message = reason.unwrap_or_else(|| "Input blocked".to_string());
9076 return Err(Error::extension(message));
9077 }
9078 };
9079
9080 let base_system_prompt = self.agent.system_prompt().map(str::to_string);
9081 let BeforeAgentStartOutcome {
9082 messages: custom_messages,
9083 system_prompt,
9084 } = self
9085 .dispatch_before_agent_start(
9086 &text,
9087 &images,
9088 base_system_prompt.as_deref().unwrap_or(""),
9089 )
9090 .await;
9091 if let Some(prompt) = system_prompt {
9092 self.agent.set_system_prompt(Some(prompt));
9093 } else {
9094 self.agent.set_system_prompt(base_system_prompt.clone());
9095 }
9096
9097 let result = if images.is_empty() {
9098 self.run_agent_with_text(text, abort, on_event, custom_messages)
9099 .await
9100 } else {
9101 let content = Self::build_content_blocks_for_input(&text, &images);
9102 self.run_agent_with_content(content, abort, on_event, custom_messages)
9103 .await
9104 };
9105
9106 self.agent.set_system_prompt(base_system_prompt);
9107 result
9108 }
9109 .await;
9110 self.extensions_turn_active.store(false, Ordering::SeqCst);
9111 result
9112 }
9113
9114 pub async fn run_with_content(
9115 &mut self,
9116 content: Vec<ContentBlock>,
9117 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
9118 ) -> Result<AssistantMessage> {
9119 self.run_with_content_with_abort(content, None, on_event)
9120 .await
9121 }
9122
9123 pub async fn run_with_content_with_abort(
9124 &mut self,
9125 content: Vec<ContentBlock>,
9126 abort: Option<AbortSignal>,
9127 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
9128 ) -> Result<AssistantMessage> {
9129 self.extensions_turn_active.store(true, Ordering::SeqCst);
9130 let result = async {
9131 let (text, images) = Self::split_content_blocks_for_input(&content);
9132 let outcome = self.dispatch_input_event(text, images).await?;
9133 let (text, images) = match outcome {
9134 InputEventOutcome::Continue { text, images } => (text, images),
9135 InputEventOutcome::Block { reason } => {
9136 let message = reason.unwrap_or_else(|| "Input blocked".to_string());
9137 return Err(Error::extension(message));
9138 }
9139 };
9140
9141 let base_system_prompt = self.agent.system_prompt().map(str::to_string);
9142 let BeforeAgentStartOutcome {
9143 messages: custom_messages,
9144 system_prompt,
9145 } = self
9146 .dispatch_before_agent_start(
9147 &text,
9148 &images,
9149 base_system_prompt.as_deref().unwrap_or(""),
9150 )
9151 .await;
9152 if let Some(prompt) = system_prompt {
9153 self.agent.set_system_prompt(Some(prompt));
9154 } else {
9155 self.agent.set_system_prompt(base_system_prompt.clone());
9156 }
9157
9158 let content_for_agent = Self::build_content_blocks_for_input(&text, &images);
9159 let result = self
9160 .run_agent_with_content(content_for_agent, abort, on_event, custom_messages)
9161 .await;
9162
9163 self.agent.set_system_prompt(base_system_prompt);
9164 result
9165 }
9166 .await;
9167 self.extensions_turn_active.store(false, Ordering::SeqCst);
9168 result
9169 }
9170
9171 pub async fn revert_last_user_message(&mut self) -> Result<bool> {
9172 let cx = crate::agent_cx::AgentCx::for_request();
9173 let mut session = self
9174 .session
9175 .lock(cx.cx())
9176 .await
9177 .map_err(|e| Error::session(e.to_string()))?;
9178
9179 let reverted = session.revert_last_user_message();
9180 if reverted {
9181 let messages = session.to_messages_for_current_path();
9182 self.agent.replace_messages(messages);
9183 }
9184 Ok(reverted)
9185 }
9186
9187 async fn dispatch_input_event(
9188 &self,
9189 text: String,
9190 images: Vec<ImageContent>,
9191 ) -> Result<InputEventOutcome> {
9192 let Some(region) = &self.extensions else {
9193 return Ok(InputEventOutcome::Continue { text, images });
9194 };
9195
9196 let images_value = serde_json::to_value(&images).unwrap_or(Value::Null);
9197 let attachments_value = images_value.clone();
9198 let text_clone = text.clone();
9199 let payload = json!({
9200 "text": text,
9201 "content": text_clone,
9202 "images": images_value,
9203 "attachments": attachments_value,
9204 "source": self.input_source.as_str(),
9205 });
9206
9207 let response = region
9208 .manager()
9209 .dispatch_event_with_response(
9210 ExtensionEventName::Input,
9211 Some(payload),
9212 EXTENSION_EVENT_TIMEOUT_MS,
9213 )
9214 .await?;
9215
9216 Ok(apply_input_event_response(response, text, images))
9217 }
9218
9219 async fn dispatch_before_agent_start(
9220 &self,
9221 prompt: &str,
9222 images: &[ImageContent],
9223 system_prompt: &str,
9224 ) -> BeforeAgentStartOutcome {
9225 let Some(region) = &self.extensions else {
9226 return BeforeAgentStartOutcome {
9227 messages: Vec::new(),
9228 system_prompt: None,
9229 };
9230 };
9231
9232 let images_value = serde_json::to_value(images).unwrap_or(Value::Null);
9233 let payload = json!({
9234 "prompt": prompt,
9235 "images": images_value,
9236 "systemPrompt": system_prompt,
9237 });
9238
9239 let response = region
9240 .manager()
9241 .dispatch_event_with_response(
9242 ExtensionEventName::BeforeAgentStart,
9243 Some(payload),
9244 EXTENSION_EVENT_TIMEOUT_MS,
9245 )
9246 .await;
9247
9248 match response {
9249 Ok(value) => apply_before_agent_start_response(value, Utc::now().timestamp_millis()),
9250 Err(err) => {
9251 tracing::warn!("before_agent_start extension hook failed (fail-open): {err}");
9252 BeforeAgentStartOutcome {
9253 messages: Vec::new(),
9254 system_prompt: None,
9255 }
9256 }
9257 }
9258 }
9259
9260 async fn dispatch_before_compact(
9261 &self,
9262 preparation: &compaction::CompactionPreparation,
9263 branch_entries: &[crate::session::SessionEntry],
9264 custom_instructions: Option<&str>,
9265 ) -> SessionBeforeCompactOutcome {
9266 let Some(region) = &self.extensions else {
9267 return SessionBeforeCompactOutcome::default();
9268 };
9269
9270 let prep_value = compaction::compaction_preparation_to_value(preparation);
9271 let branch_entries_value =
9272 serde_json::to_value(branch_entries).unwrap_or(Value::Array(Vec::new()));
9273 let mut payload = serde_json::Map::new();
9274 payload.insert("preparation".to_string(), prep_value);
9275 payload.insert("branchEntries".to_string(), branch_entries_value);
9276 if let Some(custom_instructions) = custom_instructions {
9277 payload.insert(
9278 "customInstructions".to_string(),
9279 Value::String(custom_instructions.to_string()),
9280 );
9281 }
9282
9283 let response = region
9284 .manager()
9285 .dispatch_event_with_response(
9286 ExtensionEventName::SessionBeforeCompact,
9287 Some(Value::Object(payload)),
9288 EXTENSION_EVENT_TIMEOUT_MS,
9289 )
9290 .await;
9291
9292 match response {
9293 Ok(value) => apply_session_before_compact_response(value, preparation.tokens_before),
9294 Err(err) => {
9295 tracing::warn!("session_before_compact extension hook failed (fail-open): {err}");
9296 SessionBeforeCompactOutcome::default()
9297 }
9298 }
9299 }
9300
9301 fn prepare_semantic_context_prompt(&self) -> Option<PreparedSemanticContextPrompt> {
9302 let injection = self.semantic_context_bundle.as_ref()?;
9303 if !injection.enabled {
9304 return None;
9305 }
9306
9307 let provider = self.agent.provider();
9308 let shape = semantic_context_prompt_shape_for_provider(provider.api());
9309 let budget = semantic_context_prompt_budget_for_provider(provider.api(), injection);
9310 let revision = semantic_context_bundle_revision(&injection.bundle);
9311 let (prompt, stats) =
9312 render_semantic_context_prompt(&injection.bundle, injection, budget, &revision);
9313 if prompt.trim().is_empty() {
9314 tracing::warn!(
9315 event = "pi.semantic_context.prompt.skipped",
9316 provider = provider.name(),
9317 api = provider.api(),
9318 model = provider.model_id(),
9319 revision = %revision,
9320 max_bytes = budget.max_bytes,
9321 "semantic context bundle prompt skipped because prompt budget was too small"
9322 );
9323 return None;
9324 }
9325
9326 tracing::info!(
9327 event = "pi.semantic_context.prompt.injected",
9328 provider = provider.name(),
9329 api = provider.api(),
9330 model = provider.model_id(),
9331 revision = %revision,
9332 shape = ?shape,
9333 prompt_bytes = prompt.len(),
9334 selected_items = stats.selected_items_included,
9335 selected_items_omitted = stats.selected_items_omitted,
9336 validation_commands = stats.validation_commands_included,
9337 truncated = stats.truncated,
9338 "semantic context bundle attached to agent turn"
9339 );
9340
9341 let details = json!({
9342 "schema": SEMANTIC_CONTEXT_PROVENANCE_SCHEMA_V1,
9343 "bundleSchema": injection.bundle.schema.as_str(),
9344 "bundleRevision": revision.as_str(),
9345 "provider": {
9346 "name": provider.name(),
9347 "api": provider.api(),
9348 "model": provider.model_id(),
9349 "promptShape": shape,
9350 },
9351 "budget": {
9352 "requestedMaxItems": injection.max_prompt_items,
9353 "requestedMaxBytes": injection.max_prompt_bytes,
9354 "effectiveMaxItems": budget.max_items,
9355 "effectiveMaxBytes": budget.max_bytes,
9356 },
9357 "prompt": {
9358 "bytes": prompt.len(),
9359 "selectedItemsIncluded": stats.selected_items_included,
9360 "selectedItemsOmitted": stats.selected_items_omitted,
9361 "validationCommandsIncluded": stats.validation_commands_included,
9362 "validationCommandsOmitted": stats.validation_commands_omitted,
9363 "exclusionsIncluded": stats.exclusions_included,
9364 "exclusionsOmitted": stats.exclusions_omitted,
9365 "truncated": stats.truncated,
9366 },
9367 "bundle": {
9368 "selectedItems": injection.bundle.selected_items.len(),
9369 "excludedItems": injection.bundle.excluded_items.len(),
9370 "staleEvidenceSuppressions": injection.bundle.stale_evidence_suppressions.len(),
9371 "estimatedBytes": injection.bundle.estimated_bytes,
9372 "estimatedTokens": injection.bundle.estimated_tokens,
9373 "redactionStatus": injection.bundle.redaction_summary.overall_status,
9374 "inputFingerprintSha256": injection.bundle.invalidation_policy.input_fingerprint_sha256.as_str(),
9375 "cacheable": injection.bundle.invalidation_policy.cacheable,
9376 "workspaceId": injection.bundle.invalidation_policy.workspace_id.as_str(),
9377 "branch": injection.bundle.invalidation_policy.branch.as_deref(),
9378 "sessionId": injection.bundle.invalidation_policy.session_id.as_deref(),
9379 }
9380 });
9381
9382 Some(PreparedSemanticContextPrompt {
9383 prompt,
9384 revision,
9385 shape,
9386 details,
9387 })
9388 }
9389
9390 fn semantic_context_prompt_messages(
9391 prepared: &PreparedSemanticContextPrompt,
9392 timestamp: i64,
9393 ) -> Vec<Message> {
9394 match prepared.shape {
9395 SemanticContextPromptShape::CustomUserMessage => {
9396 vec![Message::Custom(CustomMessage {
9397 content: prepared.prompt.clone(),
9398 custom_type: SEMANTIC_CONTEXT_CUSTOM_TYPE.to_string(),
9399 display: true,
9400 details: Some(prepared.details.clone()),
9401 timestamp,
9402 })]
9403 }
9404 SemanticContextPromptShape::SystemPromptAppend => {
9405 vec![Message::Custom(CustomMessage {
9406 content: format!(
9407 "Semantic context bundle revision {} attached to system prompt.",
9408 prepared.revision
9409 ),
9410 custom_type: SEMANTIC_CONTEXT_CUSTOM_TYPE.to_string(),
9411 display: false,
9412 details: Some(prepared.details.clone()),
9413 timestamp,
9414 })]
9415 }
9416 }
9417 }
9418
9419 fn semantic_context_system_prompt_for_turn(
9420 base_system_prompt: Option<String>,
9421 prepared: Option<&PreparedSemanticContextPrompt>,
9422 ) -> Option<String> {
9423 let Some(prepared) = prepared else {
9424 return base_system_prompt;
9425 };
9426 if !matches!(
9427 prepared.shape,
9428 SemanticContextPromptShape::SystemPromptAppend
9429 ) {
9430 return base_system_prompt;
9431 }
9432
9433 let mut prompt = base_system_prompt.unwrap_or_default();
9434 if !prompt.is_empty() {
9435 prompt.push_str("\n\n");
9436 }
9437 prompt.push_str(&prepared.prompt);
9438 Some(prompt)
9439 }
9440
9441 fn split_content_blocks_for_input(blocks: &[ContentBlock]) -> (String, Vec<ImageContent>) {
9442 let mut text = String::new();
9443 let mut images = Vec::new();
9444 for block in blocks {
9445 match block {
9446 ContentBlock::Text(text_block) if !text_block.text.trim().is_empty() => {
9447 if !text.is_empty() {
9448 text.push('\n');
9449 }
9450 text.push_str(&text_block.text);
9451 }
9452 ContentBlock::Image(image) => images.push(image.clone()),
9453 _ => {}
9454 }
9455 }
9456 (text, images)
9457 }
9458
9459 fn build_content_blocks_for_input(text: &str, images: &[ImageContent]) -> Vec<ContentBlock> {
9460 let mut content = Vec::new();
9461 if !text.trim().is_empty() {
9462 content.push(ContentBlock::Text(TextContent::new(text.to_string())));
9463 }
9464 for image in images {
9465 content.push(ContentBlock::Image(image.clone()));
9466 }
9467 content
9468 }
9469
9470 fn take_pending_idle_actions(&self) -> Vec<PendingIdleAction> {
9471 let Ok(mut actions) = self.extensions_pending_idle_actions.lock() else {
9472 return Vec::new();
9473 };
9474 actions.drain(..).collect()
9475 }
9476
9477 async fn run_pending_idle_actions_with_abort(
9478 &mut self,
9479 abort: Option<AbortSignal>,
9480 on_event: AgentEventHandler,
9481 ) -> Result<()> {
9482 let actions = self.take_pending_idle_actions();
9483 if actions.is_empty() {
9484 return Ok(());
9485 }
9486
9487 let previous_source = self.input_source;
9488 self.input_source = InputSource::Extension;
9489 let result = async {
9490 for action in actions {
9491 match action {
9492 PendingIdleAction::CustomMessage(message) => {
9493 let handler = Arc::clone(&on_event);
9494 self.run_custom_message_with_abort(message, abort.clone(), move |event| {
9495 handler(event);
9496 })
9497 .await?;
9498 }
9499 PendingIdleAction::UserText(text) => {
9500 let handler = Arc::clone(&on_event);
9501 self.run_text_with_abort(text, abort.clone(), move |event| {
9502 handler(event);
9503 })
9504 .await?;
9505 }
9506 }
9507 }
9508 Ok(())
9509 }
9510 .await;
9511 self.input_source = previous_source;
9512 result
9513 }
9514
9515 async fn run_custom_message_with_abort(
9516 &mut self,
9517 message: Message,
9518 abort: Option<AbortSignal>,
9519 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
9520 ) -> Result<AssistantMessage> {
9521 self.extensions_turn_active.store(true, Ordering::SeqCst);
9522 let result = async {
9523 let base_system_prompt = self.agent.system_prompt().map(str::to_string);
9524 let BeforeAgentStartOutcome {
9525 messages: custom_messages,
9526 system_prompt,
9527 } = self
9528 .dispatch_before_agent_start("", &[], base_system_prompt.as_deref().unwrap_or(""))
9529 .await;
9530 if let Some(prompt) = system_prompt {
9531 self.agent.set_system_prompt(Some(prompt));
9532 } else {
9533 self.agent.set_system_prompt(base_system_prompt.clone());
9534 }
9535
9536 let result = self
9537 .run_agent_with_prompt_message(message, abort, on_event, custom_messages)
9538 .await;
9539
9540 self.agent.set_system_prompt(base_system_prompt);
9541 result
9542 }
9543 .await;
9544 self.extensions_turn_active.store(false, Ordering::SeqCst);
9545 result
9546 }
9547
9548 async fn run_agent_with_prompt_message(
9549 &mut self,
9550 prompt_message: Message,
9551 abort: Option<AbortSignal>,
9552 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
9553 custom_messages: Vec<CustomMessage>,
9554 ) -> Result<AssistantMessage> {
9555 let on_event: AgentEventHandler = Arc::new(on_event);
9556 self.sync_runtime_selection_from_session_header().await?;
9557
9558 self.maybe_compact(Arc::clone(&on_event)).await?;
9559 let history = {
9560 let cx = crate::agent_cx::AgentCx::for_request();
9561 let session = self
9562 .session
9563 .lock(cx.cx())
9564 .await
9565 .map_err(|e| Error::session(e.to_string()))?;
9566 session.to_messages_for_current_path()
9567 };
9568 self.agent.replace_messages(history);
9569
9570 let start_len = self.agent.messages().len();
9571 let mut prompts = Vec::with_capacity(1 + custom_messages.len());
9572 prompts.push(prompt_message.clone());
9573 prompts.extend(custom_messages.into_iter().map(Message::Custom));
9574
9575 {
9576 let cx = crate::agent_cx::AgentCx::for_request();
9577 let mut session = self
9578 .session
9579 .lock(cx.cx())
9580 .await
9581 .map_err(|e| Error::session(e.to_string()))?;
9582 session.append_model_message(prompt_message.clone());
9583 if self.save_enabled {
9584 session.flush_autosave(AutosaveFlushTrigger::Manual).await?;
9585 }
9586 }
9587
9588 let semantic_context = self.prepare_semantic_context_prompt();
9589 let semantic_context_messages = semantic_context
9590 .as_ref()
9591 .map(|prepared| {
9592 Self::semantic_context_prompt_messages(prepared, Utc::now().timestamp_millis())
9593 })
9594 .unwrap_or_default();
9595 let streaming_guard = AtomicBoolGuard::activate(&self.extensions_is_streaming);
9596 let base_system_prompt = self.agent.system_prompt().map(str::to_string);
9597 self.agent
9598 .set_system_prompt(Self::semantic_context_system_prompt_for_turn(
9599 base_system_prompt.clone(),
9600 semantic_context.as_ref(),
9601 ));
9602 let on_event_for_run = Arc::clone(&on_event);
9603 prompts.extend(semantic_context_messages);
9604 let result = self
9605 .agent
9606 .run_with_messages_with_abort(prompts, abort, move |event| {
9607 on_event_for_run(event);
9608 })
9609 .await;
9610 drop(streaming_guard);
9611 self.agent.set_system_prompt(base_system_prompt);
9612
9613 let persist_result = self
9614 .persist_new_messages(start_len + 1, result.is_err())
9615 .await;
9616
9617 let result = result?;
9618 persist_result?;
9619 Ok(result)
9620 }
9621
9622 pub(crate) async fn run_agent_with_text(
9623 &mut self,
9624 input: String,
9625 abort: Option<AbortSignal>,
9626 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
9627 custom_messages: Vec<CustomMessage>,
9628 ) -> Result<AssistantMessage> {
9629 let on_event: AgentEventHandler = Arc::new(on_event);
9630 self.sync_runtime_selection_from_session_header().await?;
9631
9632 self.maybe_compact(Arc::clone(&on_event)).await?;
9633 let history = {
9634 let cx = crate::agent_cx::AgentCx::for_request();
9635 let session = self
9636 .session
9637 .lock(cx.cx())
9638 .await
9639 .map_err(|e| Error::session(e.to_string()))?;
9640 session.to_messages_for_current_path()
9641 };
9642 self.agent.replace_messages(history);
9643
9644 let start_len = self.agent.messages().len();
9645
9646 let user_message = Message::User(UserMessage {
9648 content: UserContent::Text(input),
9649 timestamp: Utc::now().timestamp_millis(),
9650 });
9651 let mut prompts = Vec::with_capacity(1 + custom_messages.len());
9652 prompts.push(user_message.clone());
9653 let semantic_context = self.prepare_semantic_context_prompt();
9654 let semantic_context_messages = semantic_context
9655 .as_ref()
9656 .map(|prepared| {
9657 Self::semantic_context_prompt_messages(prepared, Utc::now().timestamp_millis())
9658 })
9659 .unwrap_or_default();
9660 prompts.extend(semantic_context_messages);
9661 prompts.extend(custom_messages.into_iter().map(Message::Custom));
9662
9663 {
9664 let cx = crate::agent_cx::AgentCx::for_request();
9665 let mut session = self
9666 .session
9667 .lock(cx.cx())
9668 .await
9669 .map_err(|e| Error::session(e.to_string()))?;
9670 session.append_model_message(user_message.clone());
9671 if self.save_enabled {
9672 session.flush_autosave(AutosaveFlushTrigger::Manual).await?;
9673 }
9674 }
9675
9676 let streaming_guard = AtomicBoolGuard::activate(&self.extensions_is_streaming);
9677 let base_system_prompt = self.agent.system_prompt().map(str::to_string);
9678 self.agent
9679 .set_system_prompt(Self::semantic_context_system_prompt_for_turn(
9680 base_system_prompt.clone(),
9681 semantic_context.as_ref(),
9682 ));
9683 let on_event_for_run = Arc::clone(&on_event);
9684 let result = self
9685 .agent
9686 .run_with_messages_with_abort(prompts, abort, move |event| {
9687 on_event_for_run(event);
9688 })
9689 .await;
9690 drop(streaming_guard);
9691 self.agent.set_system_prompt(base_system_prompt);
9692
9693 let persist_result = self
9696 .persist_new_messages(start_len + 1, result.is_err())
9697 .await;
9698
9699 let result = result?;
9700 persist_result?;
9701 Ok(result)
9702 }
9703
9704 pub(crate) async fn run_agent_with_content(
9705 &mut self,
9706 content: Vec<ContentBlock>,
9707 abort: Option<AbortSignal>,
9708 on_event: impl Fn(AgentEvent) + Send + Sync + 'static,
9709 custom_messages: Vec<CustomMessage>,
9710 ) -> Result<AssistantMessage> {
9711 let on_event: AgentEventHandler = Arc::new(on_event);
9712 self.sync_runtime_selection_from_session_header().await?;
9713
9714 self.maybe_compact(Arc::clone(&on_event)).await?;
9715 let history = {
9716 let cx = crate::agent_cx::AgentCx::for_request();
9717 let session = self
9718 .session
9719 .lock(cx.cx())
9720 .await
9721 .map_err(|e| Error::session(e.to_string()))?;
9722 session.to_messages_for_current_path()
9723 };
9724 self.agent.replace_messages(history);
9725
9726 let start_len = self.agent.messages().len();
9727
9728 let user_message = Message::User(UserMessage {
9730 content: UserContent::Blocks(content),
9731 timestamp: Utc::now().timestamp_millis(),
9732 });
9733 let mut prompts = Vec::with_capacity(1 + custom_messages.len());
9734 prompts.push(user_message.clone());
9735 let semantic_context = self.prepare_semantic_context_prompt();
9736 let semantic_context_messages = semantic_context
9737 .as_ref()
9738 .map(|prepared| {
9739 Self::semantic_context_prompt_messages(prepared, Utc::now().timestamp_millis())
9740 })
9741 .unwrap_or_default();
9742 prompts.extend(semantic_context_messages);
9743 prompts.extend(custom_messages.into_iter().map(Message::Custom));
9744
9745 {
9746 let cx = crate::agent_cx::AgentCx::for_request();
9747 let mut session = self
9748 .session
9749 .lock(cx.cx())
9750 .await
9751 .map_err(|e| Error::session(e.to_string()))?;
9752 session.append_model_message(user_message.clone());
9753 if self.save_enabled {
9754 session.flush_autosave(AutosaveFlushTrigger::Manual).await?;
9755 }
9756 }
9757
9758 let streaming_guard = AtomicBoolGuard::activate(&self.extensions_is_streaming);
9759 let base_system_prompt = self.agent.system_prompt().map(str::to_string);
9760 self.agent
9761 .set_system_prompt(Self::semantic_context_system_prompt_for_turn(
9762 base_system_prompt.clone(),
9763 semantic_context.as_ref(),
9764 ));
9765 let on_event_for_run = Arc::clone(&on_event);
9766 let result = self
9767 .agent
9768 .run_with_messages_with_abort(prompts, abort, move |event| {
9769 on_event_for_run(event);
9770 })
9771 .await;
9772 drop(streaming_guard);
9773 self.agent.set_system_prompt(base_system_prompt);
9774
9775 let persist_result = self
9778 .persist_new_messages(start_len + 1, result.is_err())
9779 .await;
9780
9781 let result = result?;
9782 persist_result?;
9783 Ok(result)
9784 }
9785
9786 async fn persist_new_messages(&self, start_len: usize, run_failed: bool) -> Result<()> {
9787 let new_messages = self.agent.messages()[start_len..].to_vec();
9788 {
9789 let cx = crate::agent_cx::AgentCx::for_request();
9790 let mut session = self
9791 .session
9792 .lock(cx.cx())
9793 .await
9794 .map_err(|e| Error::session(e.to_string()))?;
9795 for message in new_messages {
9796 if run_failed && is_synthetic_empty_error_assistant(&message) {
9797 continue;
9798 }
9799 session.append_model_message(message);
9800 }
9801 if self.save_enabled {
9802 session
9803 .flush_autosave(AutosaveFlushTrigger::Periodic)
9804 .await?;
9805 }
9806 }
9807 Ok(())
9808 }
9809}
9810
9811fn is_synthetic_empty_error_assistant(message: &Message) -> bool {
9812 matches!(
9813 message,
9814 Message::Assistant(assistant)
9815 if assistant.content.is_empty()
9816 && matches!(assistant.stop_reason, StopReason::Error)
9817 && assistant.error_message.is_some()
9818 )
9819}
9820
9821fn semantic_context_prompt_shape_for_provider(api: &str) -> SemanticContextPromptShape {
9822 match api {
9823 "bedrock-converse-stream" | "gitlab-chat" => SemanticContextPromptShape::SystemPromptAppend,
9824 _ => SemanticContextPromptShape::CustomUserMessage,
9825 }
9826}
9827
9828fn semantic_context_prompt_budget_for_provider(
9829 api: &str,
9830 injection: &SemanticContextBundleInjection,
9831) -> SemanticContextPromptBudget {
9832 let provider_max_bytes = match api {
9833 "gitlab-chat" => 8 * 1024,
9834 "bedrock-converse-stream" | "google-gemini" | "google-vertex" => 12 * 1024,
9835 "openai-responses" | "openai-completions" | "azure-openai" => 24 * 1024,
9836 "anthropic" => 32 * 1024,
9837 _ => DEFAULT_SEMANTIC_CONTEXT_PROMPT_MAX_BYTES,
9838 };
9839 let provider_max_items = match api {
9840 "gitlab-chat" => 8,
9841 "bedrock-converse-stream" | "google-gemini" | "google-vertex" => 12,
9842 _ => DEFAULT_SEMANTIC_CONTEXT_PROMPT_MAX_ITEMS,
9843 };
9844
9845 SemanticContextPromptBudget {
9846 max_items: injection
9847 .max_prompt_items
9848 .min(injection.bundle.budget.max_items)
9849 .min(provider_max_items),
9850 max_bytes: injection
9851 .max_prompt_bytes
9852 .min(injection.bundle.budget.max_bytes)
9853 .min(provider_max_bytes),
9854 }
9855}
9856
9857fn semantic_context_bundle_revision(bundle: &SemanticContextBundle) -> String {
9858 let bytes = serde_json::to_vec(bundle).unwrap_or_else(|_| {
9859 format!(
9860 "{}:{}:{}:{}",
9861 bundle.schema,
9862 bundle.invalidation_policy.input_fingerprint_sha256,
9863 bundle.selected_items.len(),
9864 bundle.estimated_bytes
9865 )
9866 .into_bytes()
9867 });
9868 format!("{:x}", Sha256::digest(bytes))
9869}
9870
9871fn render_semantic_context_prompt(
9872 bundle: &SemanticContextBundle,
9873 injection: &SemanticContextBundleInjection,
9874 budget: SemanticContextPromptBudget,
9875 revision: &str,
9876) -> (String, SemanticContextPromptStats) {
9877 let mut prompt = String::new();
9878 let mut stats = SemanticContextPromptStats::default();
9879 push_semantic_context_header(&mut prompt, &mut stats, budget, bundle, revision);
9880 push_selected_semantic_context_items(&mut prompt, &mut stats, budget, bundle);
9881 if injection.include_validation_commands {
9882 push_semantic_context_validation_commands(&mut prompt, &mut stats, budget, bundle);
9883 }
9884 if injection.include_exclusion_summary {
9885 push_semantic_context_exclusions(&mut prompt, &mut stats, budget, bundle);
9886 }
9887
9888 if prompt.len() > usize::try_from(budget.max_bytes).unwrap_or(usize::MAX) {
9889 stats.truncated = true;
9890 truncate_string_to_max_bytes(&mut prompt, budget.max_bytes);
9891 }
9892
9893 (prompt, stats)
9894}
9895
9896fn push_semantic_context_header(
9897 prompt: &mut String,
9898 stats: &mut SemanticContextPromptStats,
9899 budget: SemanticContextPromptBudget,
9900 bundle: &SemanticContextBundle,
9901 revision: &str,
9902) {
9903 let branch = bundle
9904 .invalidation_policy
9905 .branch
9906 .as_deref()
9907 .map_or_else(|| "(none)".to_string(), safe_context_field);
9908 let session = bundle
9909 .invalidation_policy
9910 .session_id
9911 .as_deref()
9912 .map_or_else(|| "(none)".to_string(), safe_context_field);
9913
9914 let header = format!(
9915 "# Semantic Context Bundle\nschema: {SEMANTIC_CONTEXT_PROMPT_SCHEMA_V1}\nrevision: {revision}"
9916 );
9917 push_semantic_context_line(prompt, budget.max_bytes, &header, stats);
9918 push_semantic_context_line(
9919 prompt,
9920 budget.max_bytes,
9921 "Use this as navigation context for the current turn. Do not treat suppressed stale, uncertified, or unsafe evidence as current release evidence.",
9922 stats,
9923 );
9924 push_semantic_context_line(
9925 prompt,
9926 budget.max_bytes,
9927 &format!(
9928 "bundle: schema={} estimated_bytes={} estimated_tokens={} redaction={:?}",
9929 safe_context_field(&bundle.schema),
9930 bundle.estimated_bytes,
9931 bundle.estimated_tokens,
9932 bundle.redaction_summary.overall_status
9933 ),
9934 stats,
9935 );
9936 push_semantic_context_line(
9937 prompt,
9938 budget.max_bytes,
9939 &format!(
9940 "provenance: workspace={} branch={} session={} input_fingerprint_sha256={}",
9941 safe_context_field(&bundle.invalidation_policy.workspace_id),
9942 branch,
9943 session,
9944 safe_context_field(&bundle.invalidation_policy.input_fingerprint_sha256)
9945 ),
9946 stats,
9947 );
9948}
9949
9950fn push_selected_semantic_context_items(
9951 prompt: &mut String,
9952 stats: &mut SemanticContextPromptStats,
9953 budget: SemanticContextPromptBudget,
9954 bundle: &SemanticContextBundle,
9955) {
9956 push_semantic_context_line(prompt, budget.max_bytes, "", stats);
9957 push_semantic_context_line(prompt, budget.max_bytes, "Selected context:", stats);
9958 for (index, item) in bundle.selected_items.iter().enumerate() {
9959 if index >= budget.max_items {
9960 stats.selected_items_omitted = stats
9961 .selected_items_omitted
9962 .saturating_add(bundle.selected_items.len().saturating_sub(index));
9963 break;
9964 }
9965 if push_semantic_context_item(prompt, stats, budget, item, index + 1) {
9966 stats.selected_items_included = stats.selected_items_included.saturating_add(1);
9967 } else {
9968 stats.selected_items_omitted = stats
9969 .selected_items_omitted
9970 .saturating_add(bundle.selected_items.len().saturating_sub(index));
9971 break;
9972 }
9973 }
9974 if bundle.selected_items.is_empty() {
9975 push_semantic_context_line(prompt, budget.max_bytes, "- (none)", stats);
9976 }
9977}
9978
9979fn push_semantic_context_validation_commands(
9980 prompt: &mut String,
9981 stats: &mut SemanticContextPromptStats,
9982 budget: SemanticContextPromptBudget,
9983 bundle: &SemanticContextBundle,
9984) {
9985 push_semantic_context_line(prompt, budget.max_bytes, "", stats);
9986 push_semantic_context_line(
9987 prompt,
9988 budget.max_bytes,
9989 "Suggested validation commands:",
9990 stats,
9991 );
9992 if bundle.suggested_validation_commands.is_empty() {
9993 push_semantic_context_line(prompt, budget.max_bytes, "- (none)", stats);
9994 return;
9995 }
9996
9997 for (index, command) in bundle.suggested_validation_commands.iter().enumerate() {
9998 let line = format!("- {}", safe_context_field(command));
9999 if push_semantic_context_line(prompt, budget.max_bytes, &line, stats) {
10000 stats.validation_commands_included =
10001 stats.validation_commands_included.saturating_add(1);
10002 } else {
10003 stats.validation_commands_omitted = bundle
10004 .suggested_validation_commands
10005 .len()
10006 .saturating_sub(index);
10007 break;
10008 }
10009 }
10010}
10011
10012fn push_semantic_context_exclusions(
10013 prompt: &mut String,
10014 stats: &mut SemanticContextPromptStats,
10015 budget: SemanticContextPromptBudget,
10016 bundle: &SemanticContextBundle,
10017) {
10018 push_semantic_context_line(prompt, budget.max_bytes, "", stats);
10019 push_semantic_context_line(
10020 prompt,
10021 budget.max_bytes,
10022 "Suppressed or excluded context:",
10023 stats,
10024 );
10025 if bundle.stale_evidence_suppressions.is_empty() && bundle.excluded_items.is_empty() {
10026 push_semantic_context_line(prompt, budget.max_bytes, "- (none)", stats);
10027 return;
10028 }
10029
10030 for (index, item) in bundle
10031 .stale_evidence_suppressions
10032 .iter()
10033 .chain(bundle.excluded_items.iter())
10034 .take(8)
10035 .enumerate()
10036 {
10037 let line = format!(
10038 "- {:?} {} :: {} reason={}",
10039 item.node_type,
10040 safe_context_field(&item.source_path),
10041 safe_context_field(&item.title),
10042 safe_context_field(&item.reason)
10043 );
10044 if push_semantic_context_line(prompt, budget.max_bytes, &line, stats) {
10045 stats.exclusions_included = stats.exclusions_included.saturating_add(1);
10046 } else {
10047 stats.exclusions_omitted = bundle
10048 .stale_evidence_suppressions
10049 .len()
10050 .saturating_add(bundle.excluded_items.len())
10051 .saturating_sub(index);
10052 break;
10053 }
10054 }
10055}
10056
10057fn push_semantic_context_item(
10058 prompt: &mut String,
10059 stats: &mut SemanticContextPromptStats,
10060 budget: SemanticContextPromptBudget,
10061 item: &ContextBundleItem,
10062 ordinal: usize,
10063) -> bool {
10064 let freshness = item.freshness_status.map_or_else(
10065 || "not_applicable".to_string(),
10066 |status| format!("{status:?}"),
10067 );
10068 let line = format!(
10069 "{ordinal}. {:?} {} :: {}",
10070 item.node_type,
10071 safe_context_field(&item.source_path),
10072 safe_context_field(&item.title)
10073 );
10074 let detail = format!(
10075 " reason={} score={} tokens={} freshness={} redaction={:?}",
10076 safe_context_field(&item.reason),
10077 item.score,
10078 item.estimated_tokens,
10079 freshness,
10080 item.redaction_status
10081 );
10082 push_semantic_context_line(prompt, budget.max_bytes, &line, stats)
10083 && push_semantic_context_line(prompt, budget.max_bytes, &detail, stats)
10084}
10085
10086fn push_semantic_context_line(
10087 prompt: &mut String,
10088 max_bytes: u64,
10089 line: &str,
10090 stats: &mut SemanticContextPromptStats,
10091) -> bool {
10092 let max_bytes = usize::try_from(max_bytes).unwrap_or(usize::MAX);
10093 let required = line.len().saturating_add(1);
10094 if prompt.len().saturating_add(required) > max_bytes {
10095 stats.truncated = true;
10096 return false;
10097 }
10098 prompt.push_str(line);
10099 prompt.push('\n');
10100 true
10101}
10102
10103fn truncate_string_to_max_bytes(value: &mut String, max_bytes: u64) {
10104 let max_bytes = usize::try_from(max_bytes).unwrap_or(usize::MAX);
10105 if value.len() <= max_bytes {
10106 return;
10107 }
10108 let mut end = max_bytes;
10109 while !value.is_char_boundary(end) {
10110 end = end.saturating_sub(1);
10111 }
10112 value.truncate(end);
10113}
10114
10115fn safe_context_field(value: &str) -> String {
10116 let mut output = String::with_capacity(value.len().min(512));
10117 for ch in value.chars() {
10118 if matches!(ch, '\n' | '\r' | '\t') {
10119 output.push(' ');
10120 } else if ch.is_control() {
10121 output.push('?');
10122 } else {
10123 output.push(ch);
10124 }
10125 if output.len() >= 512 {
10126 output.push_str("...");
10127 break;
10128 }
10129 }
10130 output
10131}
10132
10133fn log_repair_diagnostics(events: &[crate::extensions_js::ExtensionRepairEvent]) {
10142 use std::collections::BTreeMap;
10143
10144 for ev in events {
10146 tracing::info!(
10147 event = "extension.auto_repair",
10148 extension_id = %ev.extension_id,
10149 pattern = %ev.pattern,
10150 success = ev.success,
10151 original_error = %ev.original_error,
10152 repair_action = %ev.repair_action,
10153 );
10154 }
10155
10156 let mut by_pattern: BTreeMap<String, Vec<&str>> = BTreeMap::new();
10158 for ev in events {
10159 by_pattern
10160 .entry(ev.pattern.to_string())
10161 .or_default()
10162 .push(&ev.extension_id);
10163 }
10164
10165 let verbose = std::env::var("PI_AUTO_REPAIR_VERBOSE")
10166 .is_ok_and(|v| v == "1" || v.eq_ignore_ascii_case("true"));
10167
10168 if verbose {
10169 warn!(
10170 "[auto-repair] {} extension{} auto-repaired:",
10171 events.len(),
10172 if events.len() == 1 { "" } else { "s" }
10173 );
10174 for ev in events {
10175 warn!(
10176 " {}: {} ({})",
10177 ev.pattern, ev.extension_id, ev.repair_action
10178 );
10179 }
10180 } else {
10181 let patterns: Vec<String> = by_pattern
10183 .iter()
10184 .map(|(pat, ids)| format!("{pat}:{}", ids.len()))
10185 .collect();
10186 tracing::info!(
10187 event = "extension.auto_repair.summary",
10188 count = events.len(),
10189 patterns = %patterns.join(", "),
10190 "auto-repaired {} extension(s)",
10191 events.len(),
10192 );
10193 }
10194}
10195
10196const BLOCK_IMAGES_PLACEHOLDER: &str = "Image reading is disabled.";
10197
10198#[derive(Debug, Default, Clone, Copy)]
10199struct ImageFilterStats {
10200 removed_images: usize,
10201 affected_messages: usize,
10202}
10203
10204fn filter_images_for_provider(messages: &mut [Message]) -> ImageFilterStats {
10205 let mut stats = ImageFilterStats::default();
10206 for message in messages {
10207 let removed = filter_images_from_message(message);
10208 if removed > 0 {
10209 stats.removed_images += removed;
10210 stats.affected_messages += 1;
10211 }
10212 }
10213 stats
10214}
10215
10216fn filter_images_from_message(message: &mut Message) -> usize {
10217 match message {
10218 Message::User(user) => match &mut user.content {
10219 UserContent::Text(_) => 0,
10220 UserContent::Blocks(blocks) => filter_image_blocks(blocks),
10221 },
10222 Message::Assistant(assistant) => {
10223 let assistant = Arc::make_mut(assistant);
10224 filter_image_blocks(&mut assistant.content)
10225 }
10226 Message::ToolResult(tool_result) => {
10227 filter_image_blocks(&mut Arc::make_mut(tool_result).content)
10228 }
10229 Message::Custom(_) => 0,
10230 }
10231}
10232
10233fn filter_image_blocks(blocks: &mut Vec<ContentBlock>) -> usize {
10234 let mut removed = 0usize;
10235 let mut filtered = Vec::with_capacity(blocks.len());
10236
10237 for block in blocks.drain(..) {
10238 match block {
10239 ContentBlock::Image(_) => {
10240 removed += 1;
10241 let previous_is_placeholder =
10242 filtered
10243 .last()
10244 .is_some_and(|prev| matches!(prev, ContentBlock::Text(TextContent { text, .. }) if text.as_str().eq(BLOCK_IMAGES_PLACEHOLDER)));
10245 if !previous_is_placeholder {
10246 filtered.push(ContentBlock::Text(TextContent::new(
10247 BLOCK_IMAGES_PLACEHOLDER,
10248 )));
10249 }
10250 }
10251 other => filtered.push(other),
10252 }
10253 }
10254
10255 *blocks = filtered;
10256 removed
10257}
10258
10259fn extract_tool_calls(content: &[ContentBlock]) -> Vec<ToolCall> {
10261 content
10262 .iter()
10263 .filter_map(|block| {
10264 if let ContentBlock::ToolCall(tc) = block {
10265 Some(tc.clone())
10266 } else {
10267 None
10268 }
10269 })
10270 .collect()
10271}
10272
10273#[cfg(test)]
10278mod tests {
10279 use super::*;
10280 use crate::auth::AuthCredential;
10281 use crate::provider::{InputType, Model, ModelCost};
10282 use asupersync::runtime::RuntimeBuilder;
10283 use async_trait::async_trait;
10284 use futures::Stream;
10285 use std::collections::BTreeSet;
10286 use std::collections::HashMap;
10287 use std::path::Path;
10288 use std::pin::Pin;
10289 use std::sync::{Arc as StdArc, Mutex as StdTestMutex};
10290
10291 fn user_message(text: &str) -> Message {
10292 Message::User(UserMessage {
10293 content: UserContent::Text(text.to_string()),
10294 timestamp: 0,
10295 })
10296 }
10297
10298 fn assert_user_text(message: &Message, expected: &str) {
10299 assert!(
10300 matches!(
10301 message,
10302 Message::User(UserMessage {
10303 content: UserContent::Text(_),
10304 ..
10305 })
10306 ),
10307 "expected user text message, got {message:?}"
10308 );
10309 if let Message::User(UserMessage {
10310 content: UserContent::Text(text),
10311 ..
10312 }) = message
10313 {
10314 assert_eq!(text, expected);
10315 }
10316 }
10317
10318 fn sample_image_block() -> ContentBlock {
10319 ContentBlock::Image(ImageContent {
10320 data: "aGVsbG8=".to_string(),
10321 mime_type: "image/png".to_string(),
10322 })
10323 }
10324
10325 fn image_count_in_message(message: &Message) -> usize {
10326 let count_images = |blocks: &[ContentBlock]| {
10327 blocks
10328 .iter()
10329 .filter(|block| matches!(block, ContentBlock::Image(_)))
10330 .count()
10331 };
10332 match message {
10333 Message::User(UserMessage {
10334 content: UserContent::Blocks(blocks),
10335 ..
10336 }) => count_images(blocks),
10337 Message::Assistant(msg) => count_images(&msg.content),
10338 Message::ToolResult(tool_result) => count_images(&tool_result.content),
10339 Message::User(UserMessage {
10340 content: UserContent::Text(_),
10341 ..
10342 })
10343 | Message::Custom(_) => 0,
10344 }
10345 }
10346
10347 fn assistant_message(text: &str) -> AssistantMessage {
10348 AssistantMessage {
10349 content: vec![ContentBlock::Text(TextContent::new(text))],
10350 api: "test-api".to_string(),
10351 provider: "test-provider".to_string(),
10352 model: "test-model".to_string(),
10353 usage: Usage::default(),
10354 stop_reason: StopReason::Stop,
10355 error_message: None,
10356 timestamp: 0,
10357 }
10358 }
10359
10360 #[derive(Debug)]
10361 struct SilentProvider;
10362
10363 #[async_trait]
10364 #[allow(clippy::unnecessary_literal_bound)]
10365 impl Provider for SilentProvider {
10366 fn name(&self) -> &str {
10367 "silent-provider"
10368 }
10369
10370 fn api(&self) -> &str {
10371 "test-api"
10372 }
10373
10374 fn model_id(&self) -> &str {
10375 "test-model"
10376 }
10377
10378 async fn stream(
10379 &self,
10380 _context: &Context<'_>,
10381 _options: &StreamOptions,
10382 ) -> crate::error::Result<
10383 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
10384 > {
10385 Ok(Box::pin(futures::stream::empty()))
10386 }
10387 }
10388
10389 #[derive(Debug)]
10390 struct DeltaOnlyProvider;
10391
10392 #[async_trait]
10393 #[allow(clippy::unnecessary_literal_bound)]
10394 impl Provider for DeltaOnlyProvider {
10395 fn name(&self) -> &str {
10396 "test-provider"
10397 }
10398
10399 fn api(&self) -> &str {
10400 "test-api"
10401 }
10402
10403 fn model_id(&self) -> &str {
10404 "test-model"
10405 }
10406
10407 async fn stream(
10408 &self,
10409 _context: &Context<'_>,
10410 _options: &StreamOptions,
10411 ) -> crate::error::Result<
10412 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
10413 > {
10414 let final_message = assistant_message("hello");
10415 let events = vec![
10416 Ok(StreamEvent::TextDelta {
10417 content_index: 0,
10418 delta: "hello".to_string(),
10419 }),
10420 Ok(StreamEvent::Done {
10421 reason: StopReason::Stop,
10422 message: final_message,
10423 }),
10424 ];
10425 Ok(Box::pin(futures::stream::iter(events)))
10426 }
10427 }
10428
10429 #[derive(Debug, Default)]
10430 struct CapturedProviderContext {
10431 system_prompt: Option<String>,
10432 messages: Vec<Message>,
10433 }
10434
10435 #[derive(Debug)]
10436 struct CapturingProvider {
10437 api: &'static str,
10438 calls: StdArc<StdTestMutex<Vec<CapturedProviderContext>>>,
10439 }
10440
10441 impl CapturingProvider {
10442 fn new(api: &'static str) -> Self {
10443 Self {
10444 api,
10445 calls: StdArc::new(StdTestMutex::new(Vec::new())),
10446 }
10447 }
10448
10449 fn calls(&self) -> StdArc<StdTestMutex<Vec<CapturedProviderContext>>> {
10450 StdArc::clone(&self.calls)
10451 }
10452 }
10453
10454 #[async_trait]
10455 #[allow(clippy::unnecessary_literal_bound)]
10456 impl Provider for CapturingProvider {
10457 fn name(&self) -> &str {
10458 "capturing-provider"
10459 }
10460
10461 fn api(&self) -> &str {
10462 self.api
10463 }
10464
10465 fn model_id(&self) -> &str {
10466 "capture-model"
10467 }
10468
10469 async fn stream(
10470 &self,
10471 context: &Context<'_>,
10472 _options: &StreamOptions,
10473 ) -> crate::error::Result<
10474 Pin<Box<dyn Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
10475 > {
10476 self.calls
10477 .lock()
10478 .expect("capture context lock")
10479 .push(CapturedProviderContext {
10480 system_prompt: context.system_prompt.as_ref().map(ToString::to_string),
10481 messages: context.messages.iter().cloned().collect(),
10482 });
10483 let final_message = assistant_message("captured");
10484 Ok(Box::pin(futures::stream::iter(vec![Ok(
10485 StreamEvent::Done {
10486 reason: StopReason::Stop,
10487 message: final_message,
10488 },
10489 )])))
10490 }
10491 }
10492
10493 fn sample_semantic_context_bundle() -> SemanticContextBundle {
10494 use crate::semantic_workspace_graph::{
10495 ContextBundleBudget, ContextBundleExclusion, ContextBundleInvalidationPolicy,
10496 ContextRedactionSummary, EvidenceFreshnessStatus, RedactionStatus, SemanticNodeType,
10497 };
10498
10499 SemanticContextBundle {
10500 schema: crate::semantic_workspace_graph::SEMANTIC_CONTEXT_BUNDLE_SCHEMA.to_string(),
10501 budget: ContextBundleBudget {
10502 max_items: 8,
10503 max_bytes: 4096,
10504 },
10505 selected_items: vec![
10506 ContextBundleItem {
10507 node_id: "node-session".to_string(),
10508 node_type: SemanticNodeType::CodeSymbol,
10509 source_path: "src/agent.rs".to_string(),
10510 title: "AgentSession::run_agent_with_text".to_string(),
10511 reason: "query_match,related_to_bead_or_changed_path".to_string(),
10512 score: 420,
10513 estimated_bytes: 700,
10514 estimated_tokens: 175,
10515 freshness_status: None,
10516 redaction_status: RedactionStatus::None,
10517 },
10518 ContextBundleItem {
10519 node_id: "node-test".to_string(),
10520 node_type: SemanticNodeType::TestCase,
10521 source_path: "tests/agent_loop_reliability.rs".to_string(),
10522 title: "semantic context session coverage".to_string(),
10523 reason: "validation_context".to_string(),
10524 score: 300,
10525 estimated_bytes: 400,
10526 estimated_tokens: 100,
10527 freshness_status: Some(EvidenceFreshnessStatus::Current),
10528 redaction_status: RedactionStatus::Redacted,
10529 },
10530 ],
10531 excluded_items: vec![ContextBundleExclusion {
10532 node_id: "stale-doc".to_string(),
10533 node_type: SemanticNodeType::DocSection,
10534 source_path: "README.md".to_string(),
10535 title: "obsolete drop-in claim".to_string(),
10536 reason: "suppressed_stale_or_unsafe_evidence".to_string(),
10537 score: 250,
10538 estimated_bytes: 300,
10539 freshness_status: Some(EvidenceFreshnessStatus::Uncertified),
10540 redaction_status: RedactionStatus::SensitiveOmitted,
10541 }],
10542 stale_evidence_suppressions: Vec::new(),
10543 redaction_summary: ContextRedactionSummary {
10544 policy_version: "test-policy".to_string(),
10545 overall_status: RedactionStatus::Redacted,
10546 selected_redacted_nodes: 1,
10547 selected_sensitive_omissions: 0,
10548 suppressed_unsafe_nodes: 0,
10549 redacted_metadata_keys: BTreeSet::from(["api_key".to_string()]),
10550 sensitive_path_kinds: BTreeSet::new(),
10551 },
10552 invalidation_policy: ContextBundleInvalidationPolicy {
10553 policy_version: "test-policy".to_string(),
10554 workspace_id: "workspace:test".to_string(),
10555 branch: Some("main".to_string()),
10556 session_id: Some("session-123".to_string()),
10557 input_fingerprint_sha256: "abc123".repeat(10),
10558 cache_ttl_seconds: 900,
10559 generated_at_utc: Some("2026-05-13T00:00:00Z".to_string()),
10560 expires_at_utc: Some("2026-05-13T00:15:00Z".to_string()),
10561 invalidates_on: vec!["input_fingerprint_change".to_string()],
10562 cacheable: true,
10563 },
10564 path_normalization: Vec::new(),
10565 suggested_validation_commands: vec![
10566 "cargo test agent_semantic_context".to_string(),
10567 "cargo check --all-targets".to_string(),
10568 ],
10569 estimated_bytes: 1100,
10570 estimated_tokens: 275,
10571 }
10572 }
10573
10574 #[test]
10575 fn delta_without_start_does_not_mutate_previous_message() {
10576 let runtime = RuntimeBuilder::current_thread()
10577 .build()
10578 .expect("runtime build");
10579
10580 runtime.block_on(async {
10581 let provider = Arc::new(DeltaOnlyProvider);
10582 let tools = ToolRegistry::from_tools(Vec::new());
10583 let mut agent = Agent::new(provider, tools, AgentConfig::default());
10584
10585 agent.add_message(Message::Assistant(Arc::new(assistant_message("prev"))));
10586
10587 agent
10588 .run_with_message_with_abort(user_message("hi"), None, |_| {})
10589 .await
10590 .expect("run");
10591
10592 let assistant_texts = agent
10593 .messages()
10594 .iter()
10595 .filter_map(|message| match message {
10596 Message::Assistant(msg)
10597 if matches!(msg.content.as_slice(), [ContentBlock::Text(_)]) =>
10598 {
10599 if let [ContentBlock::Text(text)] = msg.content.as_slice() {
10600 Some(text.text.clone())
10601 } else {
10602 None
10603 }
10604 }
10605 _ => None,
10606 })
10607 .collect::<Vec<_>>();
10608
10609 assert_eq!(
10610 assistant_texts.as_slice(),
10611 ["prev".to_string(), "hello".to_string()]
10612 );
10613 });
10614 }
10615
10616 #[test]
10617 fn semantic_context_bundle_injection_is_disabled_by_default() {
10618 let runtime = RuntimeBuilder::current_thread()
10619 .build()
10620 .expect("runtime build");
10621
10622 runtime.block_on(async {
10623 let provider = CapturingProvider::new("openai-responses");
10624 let calls = provider.calls();
10625 let agent = Agent::new(
10626 Arc::new(provider),
10627 ToolRegistry::from_tools(Vec::new()),
10628 AgentConfig::default(),
10629 );
10630 let session = Arc::new(Mutex::new(Session::in_memory()));
10631 let mut agent_session =
10632 AgentSession::new(agent, session, false, ResolvedCompactionSettings::default());
10633
10634 agent_session
10635 .run_text("hello".to_string(), |_| {})
10636 .await
10637 .expect("run with default context settings");
10638
10639 let calls = match calls.lock() {
10640 Ok(calls) => calls,
10641 Err(poisoned) => poisoned.into_inner(),
10642 };
10643 assert_eq!(calls.len(), 1);
10644 assert_eq!(calls[0].messages.len(), 1);
10645 assert_user_text(&calls[0].messages[0], "hello");
10646 assert!(calls[0].system_prompt.is_none());
10647 drop(calls);
10648 });
10649 }
10650
10651 #[test]
10652 fn semantic_context_bundle_injection_adds_bounded_custom_message_and_session_provenance() {
10653 let runtime = RuntimeBuilder::current_thread()
10654 .build()
10655 .expect("runtime build");
10656
10657 runtime.block_on(async {
10658 let bundle = sample_semantic_context_bundle();
10659 let revision = semantic_context_bundle_revision(&bundle);
10660 let provider = CapturingProvider::new("openai-responses");
10661 let calls = provider.calls();
10662 let agent = Agent::new(
10663 Arc::new(provider),
10664 ToolRegistry::from_tools(Vec::new()),
10665 AgentConfig::default(),
10666 );
10667 let session = Arc::new(Mutex::new(Session::in_memory()));
10668 let mut agent_session = AgentSession::new(
10669 agent,
10670 Arc::clone(&session),
10671 false,
10672 ResolvedCompactionSettings::default(),
10673 );
10674 agent_session.set_semantic_context_bundle(Some(
10675 SemanticContextBundleInjection::enabled(bundle).with_prompt_budget(4, 2048),
10676 ));
10677
10678 agent_session
10679 .run_text("use context".to_string(), |_| {})
10680 .await
10681 .expect("run with context bundle");
10682
10683 {
10684 let calls = match calls.lock() {
10685 Ok(calls) => calls,
10686 Err(poisoned) => poisoned.into_inner(),
10687 };
10688 assert_eq!(calls.len(), 1);
10689 assert_eq!(calls[0].messages.len(), 2);
10690 assert_user_text(&calls[0].messages[0], "use context");
10691 let custom = match &calls[0].messages[1] {
10692 Message::Custom(custom) => custom,
10693 other => {
10694 assert!(
10695 matches!(other, Message::Custom(_)),
10696 "expected custom semantic context message"
10697 );
10698 return;
10699 }
10700 };
10701 assert_eq!(custom.custom_type, SEMANTIC_CONTEXT_CUSTOM_TYPE);
10702 assert!(custom.display);
10703 assert!(custom.content.len() <= 2048);
10704 assert!(custom.content.contains("Semantic Context Bundle"));
10705 assert!(custom.content.contains("src/agent.rs"));
10706 let details = custom.details.as_ref().expect("context provenance");
10707 assert_eq!(
10708 details.get("bundleRevision").and_then(Value::as_str),
10709 Some(revision.as_str())
10710 );
10711 assert_eq!(
10712 details
10713 .pointer("/provider/promptShape")
10714 .and_then(Value::as_str),
10715 Some("custom_user_message")
10716 );
10717 drop(calls);
10718 }
10719
10720 let cx = crate::agent_cx::AgentCx::for_request();
10721 let stored = session
10722 .lock(cx.cx())
10723 .await
10724 .expect("session lock")
10725 .to_messages_for_current_path();
10726 assert!(
10727 stored.iter().any(|message| matches!(
10728 message,
10729 Message::Custom(CustomMessage { custom_type, details, display: true, .. })
10730 if custom_type == SEMANTIC_CONTEXT_CUSTOM_TYPE
10731 && details
10732 .as_ref()
10733 .and_then(|value| value.get("bundleRevision"))
10734 .and_then(Value::as_str)
10735 == Some(revision.as_str())
10736 )),
10737 "semantic context provenance was not persisted in session messages: {stored:?}"
10738 );
10739 });
10740 }
10741
10742 #[test]
10743 fn semantic_context_bundle_uses_system_prompt_append_for_providers_without_custom_context() {
10744 let runtime = RuntimeBuilder::current_thread()
10745 .build()
10746 .expect("runtime build");
10747
10748 runtime.block_on(async {
10749 let bundle = sample_semantic_context_bundle();
10750 let revision = semantic_context_bundle_revision(&bundle);
10751 let provider = CapturingProvider::new("gitlab-chat");
10752 let calls = provider.calls();
10753 let agent = Agent::new(
10754 Arc::new(provider),
10755 ToolRegistry::from_tools(Vec::new()),
10756 AgentConfig {
10757 system_prompt: Some("base prompt".to_string()),
10758 ..AgentConfig::default()
10759 },
10760 );
10761 let session = Arc::new(Mutex::new(Session::in_memory()));
10762 let mut agent_session = AgentSession::new(
10763 agent,
10764 Arc::clone(&session),
10765 false,
10766 ResolvedCompactionSettings::default(),
10767 );
10768 agent_session.set_semantic_context_bundle(Some(
10769 SemanticContextBundleInjection::enabled(bundle).with_prompt_budget(4, 2048),
10770 ));
10771
10772 agent_session
10773 .run_text("gitlab turn".to_string(), |_| {})
10774 .await
10775 .expect("run with system prompt context");
10776
10777 {
10778 let calls = match calls.lock() {
10779 Ok(calls) => calls,
10780 Err(poisoned) => poisoned.into_inner(),
10781 };
10782 assert_eq!(calls.len(), 1);
10783 assert_eq!(calls[0].messages.len(), 1);
10784 assert_user_text(&calls[0].messages[0], "gitlab turn");
10785 let system_prompt = calls[0].system_prompt.as_deref().expect("system prompt");
10786 assert!(system_prompt.contains("base prompt"));
10787 assert!(system_prompt.contains("Semantic Context Bundle"));
10788 assert!(system_prompt.contains("src/agent.rs"));
10789 drop(calls);
10790 }
10791
10792 let cx = crate::agent_cx::AgentCx::for_request();
10793 let stored = session
10794 .lock(cx.cx())
10795 .await
10796 .expect("session lock")
10797 .to_messages_for_current_path();
10798 assert!(
10799 stored.iter().any(|message| matches!(
10800 message,
10801 Message::Custom(CustomMessage { custom_type, details, display: false, .. })
10802 if custom_type == SEMANTIC_CONTEXT_CUSTOM_TYPE
10803 && details
10804 .as_ref()
10805 .and_then(|value| value.get("bundleRevision"))
10806 .and_then(Value::as_str)
10807 == Some(revision.as_str())
10808 )),
10809 "hidden semantic context provenance was not persisted in session messages: {stored:?}"
10810 );
10811 assert_eq!(agent_session.agent.system_prompt(), Some("base prompt"));
10812 });
10813 }
10814
10815 #[test]
10816 fn enable_extensions_policy_resolution_defaults_to_permissive() {
10817 let policy = AgentSession::resolve_extension_policy_for_enable(None, None);
10818 assert_eq!(
10819 policy.mode,
10820 crate::extensions::ExtensionPolicyMode::Permissive
10821 );
10822 }
10823
10824 #[test]
10825 fn enable_extensions_policy_resolution_respects_config_default_toggle() {
10826 let config = crate::config::Config {
10827 extension_policy: Some(crate::config::ExtensionPolicyConfig {
10828 profile: None,
10829 default_permissive: Some(false),
10830 allow_dangerous: None,
10831 }),
10832 ..Default::default()
10833 };
10834 let policy = AgentSession::resolve_extension_policy_for_enable(Some(&config), None);
10835 assert_eq!(policy.mode, crate::extensions::ExtensionPolicyMode::Strict);
10836 }
10837
10838 #[test]
10839 fn enable_extensions_policy_resolution_prefers_explicit_policy() {
10840 let config = crate::config::Config {
10841 extension_policy: Some(crate::config::ExtensionPolicyConfig {
10842 profile: None,
10843 default_permissive: Some(false),
10844 allow_dangerous: None,
10845 }),
10846 ..Default::default()
10847 };
10848 let explicit = crate::extensions::PolicyProfile::Permissive.to_policy();
10849 let policy =
10850 AgentSession::resolve_extension_policy_for_enable(Some(&config), Some(explicit));
10851 assert_eq!(
10852 policy.mode,
10853 crate::extensions::ExtensionPolicyMode::Permissive
10854 );
10855 }
10856
10857 #[test]
10858 fn test_extract_tool_calls() {
10859 let content = vec![
10860 ContentBlock::Text(TextContent::new("Hello")),
10861 ContentBlock::ToolCall(ToolCall {
10862 id: "tc1".to_string(),
10863 name: "read".to_string(),
10864 arguments: serde_json::json!({"path": "file.txt"}),
10865 thought_signature: None,
10866 }),
10867 ContentBlock::Text(TextContent::new("World")),
10868 ContentBlock::ToolCall(ToolCall {
10869 id: "tc2".to_string(),
10870 name: "bash".to_string(),
10871 arguments: serde_json::json!({"command": "ls"}),
10872 thought_signature: None,
10873 }),
10874 ];
10875
10876 let tool_calls = extract_tool_calls(&content);
10877 assert_eq!(tool_calls.len(), 2);
10878 assert_eq!(tool_calls[0].name, "read");
10879 assert_eq!(tool_calls[1].name, "bash");
10880 }
10881
10882 #[test]
10883 fn test_agent_config_default() {
10884 let config = AgentConfig::default();
10891 let expected = resolved_max_tool_iterations_default();
10892 assert_eq!(config.max_tool_iterations, expected);
10893 assert!(config.system_prompt.is_none());
10894 assert!(!config.block_images);
10895 }
10896
10897 #[test]
10898 fn resolve_max_tool_iterations_handles_unset_empty_and_whitespace() {
10899 assert_eq!(
10900 resolve_max_tool_iterations(None),
10901 MAX_TOOL_ITERATIONS_DEFAULT
10902 );
10903 assert_eq!(
10904 resolve_max_tool_iterations(Some("")),
10905 MAX_TOOL_ITERATIONS_DEFAULT
10906 );
10907 assert_eq!(
10908 resolve_max_tool_iterations(Some(" ")),
10909 MAX_TOOL_ITERATIONS_DEFAULT
10910 );
10911 }
10912
10913 #[test]
10914 fn resolve_max_tool_iterations_rejects_zero_and_invalid() {
10915 assert_eq!(
10916 resolve_max_tool_iterations(Some("0")),
10917 MAX_TOOL_ITERATIONS_DEFAULT
10918 );
10919 assert_eq!(
10920 resolve_max_tool_iterations(Some("not-a-number")),
10921 MAX_TOOL_ITERATIONS_DEFAULT
10922 );
10923 assert_eq!(
10924 resolve_max_tool_iterations(Some("-5")),
10925 MAX_TOOL_ITERATIONS_DEFAULT
10926 );
10927 assert_eq!(
10928 resolve_max_tool_iterations(Some("3.14")),
10929 MAX_TOOL_ITERATIONS_DEFAULT
10930 );
10931 }
10932
10933 #[test]
10934 fn resolve_max_tool_iterations_accepts_valid_overrides_and_trims_whitespace() {
10935 assert_eq!(resolve_max_tool_iterations(Some("1")), 1);
10936 assert_eq!(resolve_max_tool_iterations(Some("100")), 100);
10937 assert_eq!(resolve_max_tool_iterations(Some(" 200 ")), 200);
10938 assert_eq!(resolve_max_tool_iterations(Some("999")), 999);
10939 }
10940
10941 #[test]
10942 fn resolve_max_tool_iterations_clamps_above_ceiling() {
10943 assert_eq!(
10944 resolve_max_tool_iterations(Some("99999")),
10945 MAX_TOOL_ITERATIONS_CEILING
10946 );
10947 assert_eq!(
10949 resolve_max_tool_iterations(Some("1000")),
10950 MAX_TOOL_ITERATIONS_CEILING
10951 );
10952 }
10953
10954 #[test]
10955 fn clamp_max_tool_iterations_matches_resolve_semantics() {
10956 assert_eq!(clamp_max_tool_iterations(None), MAX_TOOL_ITERATIONS_DEFAULT);
10958 assert_eq!(
10959 clamp_max_tool_iterations(Some(0)),
10960 MAX_TOOL_ITERATIONS_DEFAULT
10961 );
10962 assert_eq!(clamp_max_tool_iterations(Some(7)), 7);
10963 assert_eq!(
10964 clamp_max_tool_iterations(Some(usize::MAX)),
10965 MAX_TOOL_ITERATIONS_CEILING
10966 );
10967 }
10968
10969 #[test]
10970 fn iteration_warning_fires_at_80_percent_for_default_cap() {
10971 assert!(!should_warn_at_iteration_threshold(39, 50));
10973 assert!(should_warn_at_iteration_threshold(40, 50));
10974 assert!(should_warn_at_iteration_threshold(50, 50));
10975 assert!(!should_warn_at_iteration_threshold(0, 50));
10977 }
10978
10979 #[test]
10980 fn iteration_warning_fires_at_80_percent_for_custom_caps() {
10981 for (cap, threshold) in [(100usize, 80usize), (200, 160), (1000, 800)] {
10982 assert!(
10983 !should_warn_at_iteration_threshold(threshold - 1, cap),
10984 "expected no warning below threshold (current=cap={cap}, threshold={threshold})"
10985 );
10986 assert!(
10987 should_warn_at_iteration_threshold(threshold, cap),
10988 "expected warning at threshold (cap={cap}, threshold={threshold})"
10989 );
10990 }
10991 }
10992
10993 #[test]
10994 fn iteration_warning_skipped_for_caps_below_minimum() {
10995 for cap in 0..ITERATION_WARN_MIN_CAP {
10999 for current in 0..=cap.saturating_add(2) {
11000 assert!(
11001 !should_warn_at_iteration_threshold(current, cap),
11002 "should not warn at current={current} cap={cap}"
11003 );
11004 }
11005 }
11006 }
11007
11008 #[test]
11009 fn iteration_warning_handles_minimum_warnable_cap_boundary() {
11010 assert!(!should_warn_at_iteration_threshold(3, 5));
11012 assert!(should_warn_at_iteration_threshold(4, 5));
11013 assert!(should_warn_at_iteration_threshold(5, 5));
11014 }
11015
11016 #[test]
11017 fn iteration_warning_handles_overflow_resistant_caps() {
11018 assert!(!should_warn_at_iteration_threshold(1_000_000, usize::MAX));
11025 assert!(!should_warn_at_iteration_threshold(
11026 usize::MAX / 6,
11027 usize::MAX
11028 ));
11029 assert!(should_warn_at_iteration_threshold(
11031 usize::MAX / 5,
11032 usize::MAX
11033 ));
11034 }
11035
#[test]
fn iteration_handoff_steering_text_is_self_describing() {
    let text = iteration_handoff_steering_text(42, 50);
    // Every fragment the steering message is expected to carry.
    let required_fragments = [
        "[runtime]",
        "Tool-iteration budget at >=80%",
        "used 42 of 50",
        "graceful handoff",
        "incomplete-handoff",
        "Do NOT compress",
    ];
    for fragment in required_fragments {
        assert!(text.contains(fragment));
    }
}
11050
#[test]
fn filter_image_blocks_replaces_images_with_deduped_placeholder_text() {
    // Layout: [img, img, "tail", img]. The leading pair of images collapses into a
    // single placeholder, and the trailing image gets its own.
    let mut blocks = vec![
        sample_image_block(),
        sample_image_block(),
        ContentBlock::Text(TextContent::new("tail")),
        sample_image_block(),
    ];

    assert_eq!(filter_image_blocks(&mut blocks), 3);

    let any_image_left = blocks
        .iter()
        .any(|block| matches!(block, ContentBlock::Image(_)));
    assert!(!any_image_left);

    let is_text = |index: usize, expected: &str| {
        matches!(
            blocks.get(index),
            Some(ContentBlock::Text(TextContent { text, .. })) if text.as_str().eq(expected)
        )
    };
    assert!(is_text(0, BLOCK_IMAGES_PLACEHOLDER));
    assert!(is_text(1, "tail"));
    assert!(is_text(2, BLOCK_IMAGES_PLACEHOLDER));
}
11083
#[test]
fn filter_images_for_provider_filters_images_from_all_block_message_types() {
    // One image in each of the three block-carrying message kinds.
    let user = Message::User(UserMessage {
        content: UserContent::Blocks(vec![
            ContentBlock::Text(TextContent::new("hello")),
            sample_image_block(),
        ]),
        timestamp: 0,
    });
    let assistant = Message::Assistant(Arc::new(AssistantMessage {
        content: vec![sample_image_block()],
        api: "test".to_string(),
        provider: "test".to_string(),
        model: "test".to_string(),
        usage: Usage::default(),
        stop_reason: StopReason::Stop,
        error_message: None,
        timestamp: 0,
    }));
    let tool_result = Message::tool_result(ToolResultMessage {
        tool_call_id: "tc1".to_string(),
        tool_name: "read".to_string(),
        content: vec![
            sample_image_block(),
            ContentBlock::Text(TextContent::new("ok")),
        ],
        details: None,
        is_error: false,
        timestamp: 0,
    });
    let mut messages = vec![user, assistant, tool_result];

    let stats = filter_images_for_provider(&mut messages);
    assert_eq!(stats.removed_images, 3);
    assert_eq!(stats.affected_messages, 3);

    let images_left: usize = messages.iter().map(image_count_in_message).sum();
    assert_eq!(
        images_left, 0,
        "no images should remain in provider-bound context"
    );
}
11127
#[test]
fn build_context_strips_images_when_block_images_enabled() {
    let config = AgentConfig {
        system_prompt: None,
        max_tool_iterations: 50,
        stream_options: StreamOptions::default(),
        block_images: true,
        fail_closed_hooks: false,
        tool_approval: None,
    };
    let mut agent = Agent::new(
        Arc::new(SilentProvider),
        ToolRegistry::new(&[], Path::new("."), None),
        config,
    );
    agent.add_message(Message::User(UserMessage {
        content: UserContent::Blocks(vec![sample_image_block()]),
        timestamp: 0,
    }));

    // The user message still reaches the context, but its image is replaced by placeholder text.
    let context = agent.build_context();
    assert_eq!(context.messages.len(), 1);
    let first = &context.messages[0];
    assert_eq!(image_count_in_message(first), 0);

    let placeholder_present = match first {
        Message::User(UserMessage {
            content: UserContent::Blocks(blocks),
            ..
        }) => blocks.iter().any(|block| {
            matches!(
                block,
                ContentBlock::Text(TextContent { text, .. })
                    if text.as_str().eq(BLOCK_IMAGES_PLACEHOLDER)
            )
        }),
        _ => false,
    };
    assert!(placeholder_present);
}
11160
#[test]
fn build_context_keeps_images_when_block_images_disabled() {
    let config = AgentConfig {
        system_prompt: None,
        max_tool_iterations: 50,
        stream_options: StreamOptions::default(),
        block_images: false,
        fail_closed_hooks: false,
        tool_approval: None,
    };
    let mut agent = Agent::new(
        Arc::new(SilentProvider),
        ToolRegistry::new(&[], Path::new("."), None),
        config,
    );
    let image_message = Message::User(UserMessage {
        content: UserContent::Blocks(vec![sample_image_block()]),
        timestamp: 0,
    });
    agent.add_message(image_message);

    // With image blocking off, the single image survives into the provider context.
    let context = agent.build_context();
    assert_eq!(context.messages.len(), 1);
    assert_eq!(image_count_in_message(&context.messages[0]), 1);
}
11184
#[test]
fn auto_compaction_start_serializes_with_pi_mono_compatible_type_tag() {
    // The variant is tagged with a snake_case `type` alongside its `reason` field.
    let serialized = serde_json::to_value(&AgentEvent::AutoCompactionStart {
        reason: String::from("threshold"),
    })
    .unwrap();
    assert_eq!(serialized["type"], "auto_compaction_start");
    assert_eq!(serialized["reason"], "threshold");
}
11194
#[test]
fn auto_compaction_end_serializes_with_pi_mono_compatible_fields() {
    let event = AgentEvent::AutoCompactionEnd {
        result: Some(serde_json::json!({"tokens_before": 5000, "tokens_after": 2000})),
        aborted: false,
        will_retry: false,
        error_message: None,
    };
    let json = serde_json::to_value(&event).unwrap();
    assert_eq!(json["type"], "auto_compaction_end");
    assert_eq!(json["aborted"], false);
    assert_eq!(json["willRetry"], false);
    // A `None` error message is omitted entirely rather than serialized as null.
    assert!(json.get("errorMessage").is_none());
    // The result payload is expected to pass through with its fields intact.
    assert!(json["result"].is_object());
    assert_eq!(json["result"]["tokens_before"], 5000);
    assert_eq!(json["result"]["tokens_after"], 2000);
}
11210
#[test]
fn auto_compaction_end_includes_error_message_when_present() {
    let failed = AgentEvent::AutoCompactionEnd {
        result: None,
        aborted: true,
        will_retry: false,
        error_message: Some(String::from("Compaction failed")),
    };
    let value = serde_json::to_value(&failed).unwrap();

    // A present error message is emitted under its camelCase key.
    let expected = [
        ("type", serde_json::json!("auto_compaction_end")),
        ("aborted", serde_json::json!(true)),
        ("errorMessage", serde_json::json!("Compaction failed")),
    ];
    for (key, want) in expected {
        assert_eq!(value[key], want);
    }
}
11224
#[test]
fn auto_retry_start_serializes_with_camel_case_fields() {
    let value = serde_json::to_value(&AgentEvent::AutoRetryStart {
        attempt: 1,
        max_attempts: 3,
        delay_ms: 2000,
        error_message: String::from("Rate limited"),
    })
    .unwrap();

    // snake_case Rust fields map onto camelCase JSON keys.
    let expected = [
        ("type", serde_json::json!("auto_retry_start")),
        ("attempt", serde_json::json!(1)),
        ("maxAttempts", serde_json::json!(3)),
        ("delayMs", serde_json::json!(2000)),
        ("errorMessage", serde_json::json!("Rate limited")),
    ];
    for (key, want) in expected {
        assert_eq!(value[key], want);
    }
}
11240
#[test]
fn auto_retry_end_serializes_success_and_omits_null_final_error() {
    let succeeded = AgentEvent::AutoRetryEnd {
        success: true,
        attempt: 2,
        final_error: None,
    };
    let value = serde_json::to_value(&succeeded).unwrap();

    assert_eq!(value["type"], "auto_retry_end");
    assert_eq!(value["success"], true);
    assert_eq!(value["attempt"], 2);
    // A `None` final error leaves the key out entirely.
    let final_error = value.get("finalError");
    assert!(final_error.is_none());
}
11254
#[test]
fn auto_retry_end_includes_final_error_on_failure() {
    let exhausted = AgentEvent::AutoRetryEnd {
        success: false,
        attempt: 3,
        final_error: Some(String::from("Max retries exceeded")),
    };
    let value = serde_json::to_value(&exhausted).unwrap();

    let expected = [
        ("type", serde_json::json!("auto_retry_end")),
        ("success", serde_json::json!(false)),
        ("attempt", serde_json::json!(3)),
        ("finalError", serde_json::json!("Max retries exceeded")),
    ];
    for (key, want) in expected {
        assert_eq!(value[key], want);
    }
}
11268
#[test]
fn message_queue_push_increments_seq_and_counts_both_queues() {
    let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
    assert_eq!(queue.pending_count(), 0);

    // Sequence numbers are shared across the steering and follow-up queues
    // (array elements are evaluated left to right).
    let seqs = [
        queue.push_steering(user_message("s1")),
        queue.push_follow_up(user_message("f1")),
        queue.push_steering(user_message("s2")),
    ];
    assert_eq!(seqs, [0, 1, 2]);

    assert_eq!(queue.pending_count(), 3);
}
11280
#[test]
fn message_queue_pop_steering_one_at_a_time_preserves_order() {
    let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
    for text in ["s1", "s2"] {
        queue.push_steering(user_message(text));
    }

    // Each pop yields exactly one message, oldest first.
    for (expected_text, expected_pending) in [("s1", 1), ("s2", 0)] {
        let popped = queue.pop_steering();
        assert_eq!(popped.len(), 1);
        assert_user_text(&popped[0], expected_text);
        assert_eq!(queue.pending_count(), expected_pending);
    }

    // Once drained, further pops come back empty.
    assert!(queue.pop_steering().is_empty());
}
11300
#[test]
fn message_queue_pop_respects_queue_modes_per_kind() {
    let mut queue = MessageQueue::new(QueueMode::All, QueueMode::OneAtATime);
    for text in ["s1", "s2"] {
        queue.push_steering(user_message(text));
    }
    for text in ["f1", "f2"] {
        queue.push_follow_up(user_message(text));
    }

    // Steering is in `All` mode: both queued steering messages drain in one pop.
    let steering = queue.pop_steering();
    assert_eq!(steering.len(), 2);
    assert_user_text(&steering[0], "s1");
    assert_user_text(&steering[1], "s2");
    assert_eq!(queue.pending_count(), 2);

    // Follow-ups are in `OneAtATime` mode: one message per pop, FIFO.
    for (expected_text, expected_pending) in [("f1", 1), ("f2", 0)] {
        let batch = queue.pop_follow_up();
        assert_eq!(batch.len(), 1);
        assert_user_text(&batch[0], expected_text);
        assert_eq!(queue.pending_count(), expected_pending);
    }
}
11325
#[test]
fn message_queue_set_modes_applies_to_existing_messages() {
    let mut queue = MessageQueue::new(QueueMode::OneAtATime, QueueMode::OneAtATime);
    for text in ["s1", "s2"] {
        queue.push_steering(user_message(text));
    }

    // In OneAtATime mode only the oldest steering message comes out.
    let head = queue.pop_steering();
    assert_eq!(head.len(), 1);
    assert_user_text(&head[0], "s1");

    // Switching steering to `All` takes effect for the message already waiting.
    queue.set_modes(QueueMode::All, QueueMode::OneAtATime);
    let rest = queue.pop_steering();
    assert_eq!(rest.len(), 1);
    assert_user_text(&rest[0], "s2");
}
11341
11342 fn build_switch_test_session(auth: &AuthStorage) -> AgentSession {
11343 let registry = ModelRegistry::load(auth, None);
11344 let current_entry = registry
11345 .find("anthropic", "claude-sonnet-4-5")
11346 .expect("anthropic model in registry");
11347 let provider = crate::providers::create_provider(¤t_entry, None)
11348 .expect("create anthropic provider");
11349 let tools = ToolRegistry::new(&[], Path::new("."), None);
11350 let mut stream_options = StreamOptions {
11351 api_key: Some("stale-key".to_string()),
11352 ..Default::default()
11353 };
11354 let _ = stream_options
11355 .headers
11356 .insert("x-stale-header".to_string(), "stale-value".to_string());
11357 let agent = Agent::new(
11358 provider,
11359 tools,
11360 AgentConfig {
11361 system_prompt: None,
11362 max_tool_iterations: 50,
11363 stream_options,
11364 block_images: false,
11365 fail_closed_hooks: false,
11366 tool_approval: None,
11367 },
11368 );
11369
11370 let mut session = Session::in_memory();
11371 session.header.provider = Some("anthropic".to_string());
11372 session.header.model_id = Some("claude-sonnet-4-5".to_string());
11373
11374 let mut agent_session = AgentSession::new(
11375 agent,
11376 Arc::new(Mutex::new(session)),
11377 false,
11378 ResolvedCompactionSettings::default(),
11379 );
11380 agent_session.set_model_registry(registry);
11381 agent_session.set_auth_storage(auth.clone());
11382 agent_session
11383 }
11384
#[test]
fn compaction_runtime_handle_creates_fallback_runtime() {
    let tmp = tempfile::tempdir().expect("tempdir");
    let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");
    let mut agent_session = build_switch_test_session(&auth);

    // A freshly built session has no runtime of its own yet.
    assert!(agent_session.compaction_runtime.is_none());
    assert!(agent_session.runtime_handle.is_none());

    let handle = agent_session
        .compaction_runtime_handle()
        .expect("create fallback compaction runtime");
    let answer = futures::executor::block_on(handle.spawn(async { 7_u8 }));
    assert_eq!(answer, 7);

    // Requesting the handle left both the runtime and its handle populated.
    assert!(agent_session.compaction_runtime.is_some());
    assert!(agent_session.runtime_handle.is_some());
}
11404
#[test]
fn apply_session_model_selection_updates_stream_credentials_and_headers() {
    let tmp = tempfile::tempdir().expect("tempdir");
    let mut auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");
    for (provider, key) in [("anthropic", "anthropic-key"), ("openai", "openai-key")] {
        auth.set(
            provider,
            AuthCredential::ApiKey {
                key: key.to_string(),
            },
        );
    }

    let mut switched = build_switch_test_session(&auth);
    switched
        .apply_session_model_selection("openai", "gpt-4o")
        .expect("switch should update stream options");

    assert_eq!(switched.agent.provider().name(), "openai");
    assert_eq!(switched.agent.provider().model_id(), "gpt-4o");

    // The stale key and header from the fixture must both be gone.
    let options = switched.agent.stream_options();
    assert_eq!(options.api_key.as_deref(), Some("openai-key"));
    assert!(
        options.headers.is_empty(),
        "stream headers should be refreshed from selected model entry"
    );
}
11439
#[test]
fn apply_session_model_selection_clears_stale_key_for_keyless_target() {
    let tmp = tempfile::tempdir().expect("tempdir");
    let mut auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");
    auth.set(
        "anthropic",
        AuthCredential::ApiKey {
            key: "anthropic-key".to_string(),
        },
    );

    let mut registry = ModelRegistry::load(&auth, None);
    // Custom model with no API key of its own and no auth header.
    let keyless_model = Model {
        id: "local-model".to_string(),
        name: "Local Model".to_string(),
        api: "openai-completions".to_string(),
        provider: "acme-local".to_string(),
        base_url: "https://example.invalid/v1".to_string(),
        reasoning: true,
        input: vec![InputType::Text],
        cost: ModelCost {
            input: 0.0,
            output: 0.0,
            cache_read: 0.0,
            cache_write: 0.0,
        },
        context_window: 128_000,
        max_tokens: 8_192,
        headers: HashMap::new(),
    };
    registry.merge_entries(vec![ModelEntry {
        model: keyless_model,
        api_key: None,
        headers: HashMap::new(),
        auth_header: false,
        compat: None,
        oauth_config: None,
    }]);

    let mut switched = build_switch_test_session(&auth);
    switched.set_model_registry(registry);
    switched
        .apply_session_model_selection("acme-local", "local-model")
        .expect("keyless local model should still activate");

    assert_eq!(switched.agent.provider().name(), "acme-local");
    assert!(
        switched.agent.stream_options().api_key.is_none(),
        "stale key must be cleared when target model has no configured key"
    );
}
11492
#[test]
fn apply_session_model_selection_treats_blank_model_key_as_missing_credential() {
    let tmp = tempfile::tempdir().expect("tempdir");
    let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");

    let mut registry = ModelRegistry::load(&auth, None);
    // Whitespace-only API key with `auth_header` set; the switch must treat it as absent.
    let blank_key_entry = ModelEntry {
        model: Model {
            id: "blank-model".to_string(),
            name: "Blank Model".to_string(),
            api: "openai-completions".to_string(),
            provider: "acme".to_string(),
            base_url: "https://example.invalid/v1".to_string(),
            reasoning: true,
            input: vec![InputType::Text],
            cost: ModelCost {
                input: 0.0,
                output: 0.0,
                cache_read: 0.0,
                cache_write: 0.0,
            },
            context_window: 128_000,
            max_tokens: 8_192,
            headers: HashMap::new(),
        },
        api_key: Some(" ".to_string()),
        headers: HashMap::new(),
        auth_header: true,
        compat: None,
        oauth_config: None,
    };
    registry.merge_entries(vec![blank_key_entry]);

    let mut agent_session = build_switch_test_session(&auth);
    agent_session.set_model_registry(registry);
    let err = agent_session
        .apply_session_model_selection("acme", "blank-model")
        .expect_err("blank keys must not satisfy credential requirements");

    let rendered = err.to_string();
    assert!(
        rendered.contains("Missing credentials for acme/blank-model"),
        "unexpected error: {err}"
    );

    // The previously active provider and its credentials are left untouched.
    assert_eq!(agent_session.agent.provider().name(), "anthropic");
    assert_eq!(
        agent_session.agent.stream_options().api_key.as_deref(),
        Some("stale-key"),
        "failed switches must preserve the prior runtime credentials"
    );
}
11544
#[test]
fn set_provider_model_preserves_session_header_when_switch_fails() {
    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build runtime");

    runtime.block_on(async {
        let tmp = tempfile::tempdir().expect("tempdir");
        let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");
        let mut agent_session = build_switch_test_session(&auth);

        // Pin the header to the active model; the lock is released at the end of this scope.
        {
            let cx = crate::agent_cx::AgentCx::for_request();
            let mut guard = agent_session
                .session
                .lock(cx.cx())
                .await
                .expect("session lock");
            let header = &mut guard.header;
            header.provider = Some("anthropic".to_string());
            header.model_id = Some("claude-sonnet-4-5".to_string());
        }

        let err = agent_session
            .set_provider_model("missing-provider", "missing-model")
            .await
            .expect_err("missing model should not switch");
        let rendered = err.to_string();
        assert!(
            rendered.contains("Unable to switch provider/model to missing-provider/missing-model"),
            "unexpected error: {err}"
        );

        // Runtime selection is unchanged...
        assert_eq!(agent_session.agent.provider().name(), "anthropic");
        assert_eq!(
            agent_session.agent.provider().model_id(),
            "claude-sonnet-4-5"
        );

        // ...and so is the persisted header.
        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = agent_session
            .session
            .lock(cx.cx())
            .await
            .expect("session lock");
        let header = &guard.header;
        assert_eq!(header.provider.as_deref(), Some("anthropic"));
        assert_eq!(header.model_id.as_deref(), Some("claude-sonnet-4-5"));
    });
}
11596
#[test]
fn set_provider_model_rejects_missing_credentials_without_switching() {
    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build runtime");

    runtime.block_on(async {
        let tmp = tempfile::tempdir().expect("tempdir");
        // No credentials are stored for any provider in this fresh auth file.
        let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");
        let mut agent_session = build_switch_test_session(&auth);

        {
            let cx = crate::agent_cx::AgentCx::for_request();
            let mut guard = agent_session
                .session
                .lock(cx.cx())
                .await
                .expect("session lock");
            let header = &mut guard.header;
            header.provider = Some("anthropic".to_string());
            header.model_id = Some("claude-sonnet-4-5".to_string());
        }

        let err = agent_session
            .set_provider_model("openai", "gpt-4o")
            .await
            .expect_err("missing credentials should abort model switch");
        let rendered = err.to_string();
        assert!(
            rendered.contains("Missing credentials for openai/gpt-4o"),
            "unexpected error: {err}"
        );

        assert_eq!(agent_session.agent.provider().name(), "anthropic");
        assert_eq!(
            agent_session.agent.provider().model_id(),
            "claude-sonnet-4-5"
        );

        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = agent_session
            .session
            .lock(cx.cx())
            .await
            .expect("session lock");
        let header = &guard.header;
        assert_eq!(header.provider.as_deref(), Some("anthropic"));
        assert_eq!(header.model_id.as_deref(), Some("claude-sonnet-4-5"));
    });
}
11648
#[test]
fn set_provider_model_clamps_thinking_for_non_reasoning_targets() {
    /// Registry entry for a model with `reasoning: false`; switching to it is expected
    /// to clamp the thinking level to Off.
    fn plain_model_entry() -> ModelEntry {
        ModelEntry {
            model: Model {
                id: "plain-model".to_string(),
                name: "Plain Model".to_string(),
                api: "openai-completions".to_string(),
                provider: "acme".to_string(),
                base_url: "https://example.invalid/v1".to_string(),
                reasoning: false,
                input: vec![InputType::Text],
                cost: ModelCost {
                    input: 0.0,
                    output: 0.0,
                    cache_read: 0.0,
                    cache_write: 0.0,
                },
                context_window: 128_000,
                max_tokens: 8_192,
                headers: HashMap::new(),
            },
            api_key: None,
            headers: HashMap::new(),
            auth_header: false,
            compat: None,
            oauth_config: None,
        }
    }

    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build runtime");

    runtime.block_on(async {
        let tmp = tempfile::tempdir().expect("tempdir");
        let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");

        let mut registry = ModelRegistry::load(&auth, None);
        registry.merge_entries(vec![plain_model_entry()]);

        let mut agent_session = build_switch_test_session(&auth);
        agent_session.set_model_registry(registry);
        agent_session.agent.stream_options_mut().thinking_level =
            Some(crate::model::ThinkingLevel::High);

        // Mirror the runtime's High thinking level in the persisted header.
        {
            let cx = crate::agent_cx::AgentCx::for_request();
            let mut guard = agent_session
                .session
                .lock(cx.cx())
                .await
                .expect("session lock");
            guard.header.thinking_level = Some("high".to_string());
        }

        agent_session
            .set_provider_model("acme", "plain-model")
            .await
            .expect("switch should clamp unsupported thinking");

        assert_eq!(agent_session.agent.provider().name(), "acme");
        assert_eq!(agent_session.agent.provider().model_id(), "plain-model");
        assert_eq!(
            agent_session.agent.stream_options().thinking_level,
            Some(crate::model::ThinkingLevel::Off)
        );

        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = agent_session
            .session
            .lock(cx.cx())
            .await
            .expect("session lock");
        let header = &guard.header;
        assert_eq!(header.provider.as_deref(), Some("acme"));
        assert_eq!(header.model_id.as_deref(), Some("plain-model"));
        assert_eq!(header.thinking_level.as_deref(), Some("off"));
    });
}
11725
#[test]
fn set_provider_model_records_model_change_once() {
    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build runtime");

    runtime.block_on(async {
        let tmp = tempfile::tempdir().expect("tempdir");
        let mut auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");
        for (provider, key) in [("anthropic", "anthropic-key"), ("openai", "openai-key")] {
            auth.set(
                provider,
                AuthCredential::ApiKey {
                    key: key.to_string(),
                },
            );
        }

        let mut agent_session = build_switch_test_session(&auth);
        // The second call re-selects the model that is already active.
        for failure_context in ["switch model", "repeat same model"] {
            agent_session
                .set_provider_model("openai", "gpt-4o")
                .await
                .expect(failure_context);
        }

        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = agent_session
            .session
            .lock(cx.cx())
            .await
            .expect("session lock");
        let entries = guard.entries_for_current_path();
        let model_changes = entries
            .iter()
            .filter(|entry| matches!(entry, crate::session::SessionEntry::ModelChange(_)))
            .count();
        assert_eq!(model_changes, 1);
    });
}
11773
#[test]
fn sync_runtime_selection_from_session_header_clamps_and_normalizes_thinking() {
    /// Registry entry for a model with `reasoning: false`.
    fn plain_model_entry() -> ModelEntry {
        ModelEntry {
            model: Model {
                id: "plain-model".to_string(),
                name: "Plain Model".to_string(),
                api: "openai-completions".to_string(),
                provider: "acme".to_string(),
                base_url: "https://example.invalid/v1".to_string(),
                reasoning: false,
                input: vec![InputType::Text],
                cost: ModelCost {
                    input: 0.0,
                    output: 0.0,
                    cache_read: 0.0,
                    cache_write: 0.0,
                },
                context_window: 128_000,
                max_tokens: 8_192,
                headers: HashMap::new(),
            },
            api_key: None,
            headers: HashMap::new(),
            auth_header: false,
            compat: None,
            oauth_config: None,
        }
    }

    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build runtime");

    runtime.block_on(async {
        let tmp = tempfile::tempdir().expect("tempdir");
        let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");

        let mut registry = ModelRegistry::load(&auth, None);
        registry.merge_entries(vec![plain_model_entry()]);

        let mut agent_session = build_switch_test_session(&auth);
        agent_session.set_model_registry(registry);
        agent_session.agent.stream_options_mut().thinking_level =
            Some(crate::model::ThinkingLevel::High);

        // The header requests the non-reasoning model with High thinking.
        {
            let cx = crate::agent_cx::AgentCx::for_request();
            let mut guard = agent_session
                .session
                .lock(cx.cx())
                .await
                .expect("session lock");
            let header = &mut guard.header;
            header.provider = Some("acme".to_string());
            header.model_id = Some("plain-model".to_string());
            header.thinking_level = Some("high".to_string());
        }

        agent_session
            .sync_runtime_selection_from_session_header()
            .await
            .expect("sync runtime selection");

        assert_eq!(agent_session.agent.provider().name(), "acme");
        assert_eq!(agent_session.agent.provider().model_id(), "plain-model");
        assert_eq!(
            agent_session.agent.stream_options().thinking_level,
            Some(crate::model::ThinkingLevel::Off)
        );

        // The clamp is written back to the header and logged as exactly one change entry.
        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = agent_session
            .session
            .lock(cx.cx())
            .await
            .expect("session lock");
        assert_eq!(guard.header.thinking_level.as_deref(), Some("off"));
        let entries = guard.entries_for_current_path();
        let thinking_changes = entries
            .iter()
            .filter(|entry| matches!(entry, crate::session::SessionEntry::ThinkingLevelChange(_)))
            .count();
        assert_eq!(thinking_changes, 1);
    });
}
11858
#[test]
fn sync_runtime_selection_from_session_header_clamps_current_thinking_when_header_omits_it() {
    /// Registry entry for a model with `reasoning: false`.
    fn plain_model_entry() -> ModelEntry {
        ModelEntry {
            model: Model {
                id: "plain-model".to_string(),
                name: "Plain Model".to_string(),
                api: "openai-completions".to_string(),
                provider: "acme".to_string(),
                base_url: "https://example.invalid/v1".to_string(),
                reasoning: false,
                input: vec![InputType::Text],
                cost: ModelCost {
                    input: 0.0,
                    output: 0.0,
                    cache_read: 0.0,
                    cache_write: 0.0,
                },
                context_window: 128_000,
                max_tokens: 8_192,
                headers: HashMap::new(),
            },
            api_key: None,
            headers: HashMap::new(),
            auth_header: false,
            compat: None,
            oauth_config: None,
        }
    }

    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build runtime");

    runtime.block_on(async {
        let tmp = tempfile::tempdir().expect("tempdir");
        let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");

        let mut registry = ModelRegistry::load(&auth, None);
        registry.merge_entries(vec![plain_model_entry()]);

        let mut agent_session = build_switch_test_session(&auth);
        agent_session.set_model_registry(registry);
        agent_session.agent.stream_options_mut().thinking_level =
            Some(crate::model::ThinkingLevel::High);

        // The header names the model but carries no thinking level at all, so the
        // runtime's current High level is the value that needs clamping.
        {
            let cx = crate::agent_cx::AgentCx::for_request();
            let mut guard = agent_session
                .session
                .lock(cx.cx())
                .await
                .expect("session lock");
            let header = &mut guard.header;
            header.provider = Some("acme".to_string());
            header.model_id = Some("plain-model".to_string());
            header.thinking_level = None;
        }

        agent_session
            .sync_runtime_selection_from_session_header()
            .await
            .expect("sync runtime selection");

        assert_eq!(agent_session.agent.provider().name(), "acme");
        assert_eq!(agent_session.agent.provider().model_id(), "plain-model");
        assert_eq!(
            agent_session.agent.stream_options().thinking_level,
            Some(crate::model::ThinkingLevel::Off)
        );

        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = agent_session
            .session
            .lock(cx.cx())
            .await
            .expect("session lock");
        assert_eq!(guard.header.thinking_level.as_deref(), Some("off"));
        let entries = guard.entries_for_current_path();
        let thinking_changes = entries
            .iter()
            .filter(|entry| matches!(entry, crate::session::SessionEntry::ThinkingLevelChange(_)))
            .count();
        assert_eq!(thinking_changes, 1);
    });
}
11943
#[test]
fn sync_runtime_selection_from_session_header_rejects_missing_credentials() {
    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build runtime");

    runtime.block_on(async {
        let tmp = tempfile::tempdir().expect("tempdir");
        // Fresh auth file: no key is stored for openai.
        let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");
        let mut agent_session = build_switch_test_session(&auth);

        {
            let cx = crate::agent_cx::AgentCx::for_request();
            let mut guard = agent_session
                .session
                .lock(cx.cx())
                .await
                .expect("session lock");
            let header = &mut guard.header;
            header.provider = Some("openai".to_string());
            header.model_id = Some("gpt-4o".to_string());
        }

        let err = agent_session
            .sync_runtime_selection_from_session_header()
            .await
            .expect_err("sync should reject switching to a credentialed target without a key");
        let rendered = err.to_string();
        assert!(
            rendered.contains("Missing credentials for openai/gpt-4o"),
            "unexpected error: {err}"
        );

        // The runtime stays on the original model...
        assert_eq!(agent_session.agent.provider().name(), "anthropic");
        assert_eq!(
            agent_session.agent.provider().model_id(),
            "claude-sonnet-4-5"
        );

        // ...while the header keeps the requested (rejected) selection.
        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = agent_session
            .session
            .lock(cx.cx())
            .await
            .expect("session lock");
        let header = &guard.header;
        assert_eq!(header.provider.as_deref(), Some("openai"));
        assert_eq!(header.model_id.as_deref(), Some("gpt-4o"));
    });
}
11992
#[test]
fn set_provider_model_allows_current_model_without_registry() {
    let runtime = asupersync::runtime::RuntimeBuilder::current_thread()
        .build()
        .expect("build runtime");

    runtime.block_on(async {
        let tmp = tempfile::tempdir().expect("tempdir");
        let auth = AuthStorage::load(tmp.path().join("auth.json")).expect("load auth");
        let mut agent_session = build_switch_test_session(&auth);

        // Without a registry, re-selecting the already-active model must still succeed.
        agent_session.model_registry = None;
        agent_session.agent.stream_options_mut().thinking_level =
            Some(crate::model::ThinkingLevel::High);

        agent_session
            .set_provider_model("anthropic", "claude-sonnet-4-5")
            .await
            .expect("re-persisting the current model should succeed without a registry");

        // Runtime selection, including the High thinking level, is kept as-is.
        assert_eq!(agent_session.agent.provider().name(), "anthropic");
        assert_eq!(
            agent_session.agent.provider().model_id(),
            "claude-sonnet-4-5"
        );
        assert_eq!(
            agent_session.agent.stream_options().thinking_level,
            Some(crate::model::ThinkingLevel::High)
        );

        let cx = crate::agent_cx::AgentCx::for_request();
        let guard = agent_session
            .session
            .lock(cx.cx())
            .await
            .expect("session lock");
        let header = &guard.header;
        assert_eq!(header.provider.as_deref(), Some("anthropic"));
        assert_eq!(header.model_id.as_deref(), Some("claude-sonnet-4-5"));
        assert_eq!(header.thinking_level.as_deref(), Some("high"));
    });
}
12037
/// `AutoCompactionStart` serializes with `type: "auto_compaction_start"` and
/// its `reason` string as a sibling field.
#[test]
fn auto_compaction_start_serializes_to_pi_mono_format() {
    let serialized = serde_json::to_value(AgentEvent::AutoCompactionStart {
        reason: "threshold".to_string(),
    })
    .unwrap();

    assert_eq!(serialized["type"], "auto_compaction_start");
    assert_eq!(serialized["reason"], "threshold");
}
12047
/// A successful `AutoCompactionEnd` serializes with `type: "auto_compaction_end"`,
/// camelCase flag fields, and no `errorMessage` key when there is no error.
///
/// The `result` payload is compared against the exact input value rather than
/// only checked with `is_object()`. Without that, a serializer that dropped or
/// renamed payload keys (`firstKeptEntryId`, `tokensBefore`, `details`, ...)
/// would still pass.
#[test]
fn auto_compaction_end_serializes_to_pi_mono_format() {
    let result_payload = serde_json::json!({
        "summary": "Compacted",
        "firstKeptEntryId": "abc123",
        "tokensBefore": 50000,
        "details": { "readFiles": [], "modifiedFiles": [] }
    });
    let event = AgentEvent::AutoCompactionEnd {
        result: Some(result_payload.clone()),
        aborted: false,
        will_retry: true,
        error_message: None,
    };
    let json = serde_json::to_value(&event).unwrap();
    assert_eq!(json["type"], "auto_compaction_end");
    // The payload must pass through serialization unchanged.
    assert_eq!(json["result"], result_payload);
    assert_eq!(json["aborted"], false);
    assert_eq!(json["willRetry"], true);
    assert!(json.get("errorMessage").is_none());
}
12068
/// A failed `AutoCompactionEnd` carries its error under `errorMessage`. A
/// `None` result is left out of the JSON object entirely rather than being
/// written as `null`.
#[test]
fn auto_compaction_end_with_error_serializes_error_message() {
    let serialized = serde_json::to_value(AgentEvent::AutoCompactionEnd {
        result: None,
        aborted: false,
        will_retry: false,
        error_message: Some("compaction failed".to_string()),
    })
    .unwrap();

    assert_eq!(serialized["type"], "auto_compaction_end");
    // `get` (unlike indexing) distinguishes an absent key from an explicit null.
    assert!(serialized.get("result").is_none());
    assert_eq!(serialized["errorMessage"], "compaction failed");
}
12082
/// `apply_compaction_result` must emit an `AutoCompactionEnd` event whose
/// `result` payload exposes the compaction summary, first kept entry id, token
/// count, and touched-file lists under camelCase keys.
#[test]
fn apply_compaction_result_emits_structured_result_payload() {
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");

    runtime.block_on(async {
        let agent = Agent::new(
            Arc::new(SilentProvider),
            ToolRegistry::new(&[], Path::new("."), None),
            AgentConfig::default(),
        );
        let agent_session = AgentSession::new(
            agent,
            Arc::new(Mutex::new(Session::in_memory())),
            false,
            ResolvedCompactionSettings::default(),
        );

        // Record every event the session publishes so the payload can be inspected.
        let captured: Arc<std::sync::Mutex<Vec<AgentEvent>>> = Arc::default();
        let recorder = Arc::clone(&captured);
        let on_event: AgentEventHandler = Arc::new(move |event| {
            recorder.lock().expect("lock compaction events").push(event);
        });

        let compaction_outcome = compaction::CompactionResult {
            summary: "Compacted 10 messages into 2".to_string(),
            first_kept_entry_id: "entry-5".to_string(),
            tokens_before: 12_000,
            details: compaction::CompactionDetails {
                read_files: vec!["src/main.rs".to_string()],
                modified_files: vec!["src/agent.rs".to_string()],
            },
        };

        agent_session
            .apply_compaction_result(compaction_outcome, on_event)
            .await
            .expect("apply compaction result");

        // The guard is a temporary of this statement, so the lock is released
        // as soon as the first end-event payload has been cloned out.
        let payload = captured
            .lock()
            .expect("lock captured events")
            .iter()
            .find_map(|event| {
                if let AgentEvent::AutoCompactionEnd {
                    result: Some(result),
                    ..
                } = event
                {
                    Some(result.clone())
                } else {
                    None
                }
            })
            .expect("auto compaction end payload");

        assert_eq!(payload["summary"], "Compacted 10 messages into 2");
        assert_eq!(payload["firstKeptEntryId"], "entry-5");
        assert_eq!(payload["tokensBefore"], 12_000);

        let details = &payload["details"];
        assert_eq!(details["readFiles"], json!(["src/main.rs"]));
        assert_eq!(details["modifiedFiles"], json!(["src/agent.rs"]));
    });
}
12140
/// `AutoRetryStart` serializes with `type: "auto_retry_start"`. The attempt
/// counters, delay, and error text appear under camelCase keys.
#[test]
fn auto_retry_start_serializes_to_pi_mono_format() {
    let serialized = serde_json::to_value(AgentEvent::AutoRetryStart {
        attempt: 2,
        max_attempts: 3,
        delay_ms: 4000,
        error_message: "rate limited".to_string(),
    })
    .unwrap();

    assert_eq!(serialized["type"], "auto_retry_start");
    // Retry counters.
    assert_eq!(serialized["attempt"], 2);
    assert_eq!(serialized["maxAttempts"], 3);
    // Backoff and cause.
    assert_eq!(serialized["delayMs"], 4000);
    assert_eq!(serialized["errorMessage"], "rate limited");
}
12156
/// A successful `AutoRetryEnd` reports `success: true` and the attempt number.
/// It omits the `finalError` key because there is no error to report.
#[test]
fn auto_retry_end_success_serializes_to_pi_mono_format() {
    let serialized = serde_json::to_value(AgentEvent::AutoRetryEnd {
        success: true,
        attempt: 2,
        final_error: None,
    })
    .unwrap();

    assert_eq!(serialized["type"], "auto_retry_end");
    assert_eq!(serialized["success"], true);
    assert_eq!(serialized["attempt"], 2);
    assert!(serialized.get("finalError").is_none());
}
12170
/// An exhausted `AutoRetryEnd` reports `success: false`, the last attempt
/// number, and the terminal error under `finalError`.
#[test]
fn auto_retry_end_failure_serializes_final_error() {
    let serialized = serde_json::to_value(AgentEvent::AutoRetryEnd {
        success: false,
        attempt: 3,
        final_error: Some("max retries exceeded".to_string()),
    })
    .unwrap();

    assert_eq!(serialized["type"], "auto_retry_end");
    assert_eq!(serialized["success"], false);
    assert_eq!(serialized["attempt"], 3);
    assert_eq!(serialized["finalError"], "max retries exceeded");
}
12184}