1use std::pin::Pin;
2
3use futures::Stream;
4use serde::{Deserialize, Serialize};
5
6use crate::extension::InferenceEngineId;
7use crate::reliability::ReliabilityRequestPolicy;
8use crate::tools::{ToolChoice, ToolSpec};
9use crate::transcript::TranscriptItem;
10
/// Names the provider and model a request should be sent to.
///
/// Both fields are plain strings; this type performs no validation of them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelSelection {
    /// Provider identifier.
    pub provider: String,
    /// Model identifier within the provider.
    pub model: String,
}
16
/// Authentication scheme a provider uses, as reported in [`InferenceProviderMetadata`].
///
/// Serialized with serde's `snake_case` rule: `none`, `api_key`, and `o_auth`.
/// NOTE(review): `OAuth` becomes `o_auth`, not `oauth`. Confirm consumers expect that.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ProviderAuthType {
    None,
    ApiKey,
    OAuth,
}
24
/// Descriptive information about an inference provider.
///
/// Returned by [`InferenceEngine::metadata`]. [`InferenceProviderMetadata::local`]
/// builds the variant used for providers that need no credentials.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct InferenceProviderMetadata {
    /// Display name of the provider.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// How the provider authenticates.
    pub auth_type: ProviderAuthType,
    /// Optional label for the credential (presumably shown in a UI; confirm).
    pub auth_label: Option<String>,
    /// Whether credentials are configured. `None` presumably means "unknown"; confirm.
    pub auth_configured: Option<bool>,
    /// Whether the provider should be highlighted as recommended.
    pub recommended: bool,
    /// Ordering key for listings (presumably ascending; confirm with callers).
    pub sort_order: i32,
}
35
36impl InferenceProviderMetadata {
37 pub fn local(name: impl Into<String>) -> Self {
38 Self {
39 name: name.into(),
40 description: None,
41 auth_type: ProviderAuthType::None,
42 auth_label: None,
43 auth_configured: Some(true),
44 recommended: false,
45 sort_order: 100,
46 }
47 }
48}
49
/// Requested tool-search strategy.
///
/// [`ToolSearchConfig::resolve_effective_mode`] resolves it against provider support:
/// - `Explicit` (default) always resolves to explicit tools.
/// - `Auto` uses provider-native search when supported and explicit tools otherwise.
/// - `ProviderNative` uses provider-native search when supported. Otherwise it falls
///   back to explicit tools only if `fallback_to_explicit_tools` is set, and errors
///   if not.
///
/// Serialized as `explicit`, `auto`, or `provider_native`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ToolSearchMode {
    #[default]
    Explicit,
    Auto,
    ProviderNative,
}
58
59impl ToolSearchMode {
60 pub fn allows_provider_native(self) -> bool {
61 matches!(self, Self::Auto | Self::ProviderNative)
62 }
63}
64
/// Variant of provider-native tool search to request.
///
/// Serialized as `default`, `regex`, or `bm25`. How each variant is honoured is up to
/// the engine; nothing in this module interprets it.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ToolSearchProviderVariant {
    #[default]
    Default,
    Regex,
    Bm25,
}
73
74#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
75#[serde(rename_all = "camelCase")]
76pub struct ToolSearchConfig {
77 #[serde(default)]
78 pub mode: ToolSearchMode,
79 #[serde(default, skip_serializing_if = "Option::is_none")]
80 pub max_catalog_items: Option<u32>,
81 #[serde(default)]
82 pub include_mcp: bool,
83 #[serde(default)]
84 pub include_skills: bool,
85 #[serde(default)]
86 pub fallback_to_explicit_tools: bool,
87 #[serde(default)]
88 pub provider_variant: ToolSearchProviderVariant,
89}
90
91impl Default for ToolSearchConfig {
92 fn default() -> Self {
93 Self {
94 mode: ToolSearchMode::Explicit,
95 max_catalog_items: None,
96 include_mcp: true,
97 include_skills: true,
98 fallback_to_explicit_tools: true,
99 provider_variant: ToolSearchProviderVariant::Default,
100 }
101 }
102}
103
104impl ToolSearchConfig {
105 pub fn explicit() -> Self {
106 Self {
107 mode: ToolSearchMode::Explicit,
108 ..Self::default()
109 }
110 }
111
112 pub fn provider_native() -> Self {
113 Self {
114 mode: ToolSearchMode::ProviderNative,
115 ..Self::default()
116 }
117 }
118
119 pub fn is_provider_native_requested(&self) -> bool {
120 self.mode.allows_provider_native()
121 }
122
123 pub fn resolve_effective_mode(
132 &self,
133 provider_native_supported: bool,
134 ) -> Result<EffectiveToolSearchMode, ToolSearchModeError> {
135 match self.mode {
136 ToolSearchMode::Explicit => Ok(EffectiveToolSearchMode::Explicit),
137 ToolSearchMode::Auto => {
138 if provider_native_supported {
139 Ok(EffectiveToolSearchMode::ProviderNative)
140 } else {
141 Ok(EffectiveToolSearchMode::Explicit)
142 }
143 }
144 ToolSearchMode::ProviderNative => {
145 if provider_native_supported {
146 Ok(EffectiveToolSearchMode::ProviderNative)
147 } else if self.fallback_to_explicit_tools {
148 Ok(EffectiveToolSearchMode::Explicit)
149 } else {
150 Err(ToolSearchModeError::ProviderNativeUnsupported)
151 }
152 }
153 }
154 }
155}
156
/// Tool-search mode after resolving the request against provider support.
///
/// Produced by [`ToolSearchConfig::resolve_effective_mode`]. Serialized as `explicit`
/// or `provider_native`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum EffectiveToolSearchMode {
    Explicit,
    ProviderNative,
}
163
/// Error returned when a tool-search request cannot be satisfied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolSearchModeError {
    /// `ProviderNative` was requested, the provider lacks support, and
    /// `fallback_to_explicit_tools` is disabled.
    ProviderNativeUnsupported,
}
168
169impl std::fmt::Display for ToolSearchModeError {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 match self {
172 Self::ProviderNativeUnsupported => write!(
173 f,
174 "provider-native tool search was requested but the selected provider/model does \
175 not support it and fallback_to_explicit_tools is disabled; enable fallback or \
176 pick a supported model"
177 ),
178 }
179 }
180}
181
182impl std::error::Error for ToolSearchModeError {}
183
/// Partial [`ToolSearchConfig`] in which every field is optional.
///
/// `Some` fields override the corresponding setting. See
/// [`ToolSearchConfigOverlay::overlay`] for stacking overlays and
/// [`ToolSearchConfigOverlay::apply_to`] for applying one to a config. Serialized with
/// camelCase keys; `None` fields are omitted.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct ToolSearchConfigOverlay {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mode: Option<ToolSearchMode>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_catalog_items: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub include_mcp: Option<bool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub include_skills: Option<bool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub fallback_to_explicit_tools: Option<bool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provider_variant: Option<ToolSearchProviderVariant>,
}
200
201impl ToolSearchConfigOverlay {
202 pub fn overlay(&mut self, other: &Self) {
203 if other.mode.is_some() {
204 self.mode = other.mode;
205 }
206 if other.max_catalog_items.is_some() {
207 self.max_catalog_items = other.max_catalog_items;
208 }
209 if other.include_mcp.is_some() {
210 self.include_mcp = other.include_mcp;
211 }
212 if other.include_skills.is_some() {
213 self.include_skills = other.include_skills;
214 }
215 if other.fallback_to_explicit_tools.is_some() {
216 self.fallback_to_explicit_tools = other.fallback_to_explicit_tools;
217 }
218 if other.provider_variant.is_some() {
219 self.provider_variant = other.provider_variant;
220 }
221 }
222
223 pub fn apply_to(&self, config: &mut ToolSearchConfig) {
224 if let Some(mode) = self.mode {
225 config.mode = mode;
226 }
227 if let Some(max_catalog_items) = self.max_catalog_items {
228 config.max_catalog_items = Some(max_catalog_items);
229 }
230 if let Some(include_mcp) = self.include_mcp {
231 config.include_mcp = include_mcp;
232 }
233 if let Some(include_skills) = self.include_skills {
234 config.include_skills = include_skills;
235 }
236 if let Some(fallback_to_explicit_tools) = self.fallback_to_explicit_tools {
237 config.fallback_to_explicit_tools = fallback_to_explicit_tools;
238 }
239 if let Some(provider_variant) = self.provider_variant {
240 config.provider_variant = provider_variant;
241 }
242 }
243}
244
/// Instruction text sent alongside the transcript.
///
/// Each part is optional. How the parts map onto provider roles is decided by the
/// engine, not by this type (presumably system / developer messages; confirm).
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct InstructionBundle {
    pub system: Option<String>,
    pub developer: Option<String>,
    pub developer_context: Option<String>,
}
257
/// Execution profile of the current run.
///
/// Serialized as `interactive` (the default), `non_interactive`, or `eval`.
/// `FromStr` also accepts `non-interactive` and `headless`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RuntimeProfile {
    #[default]
    Interactive,
    NonInteractive,
    Eval,
}
266
267impl RuntimeProfile {
268 pub fn as_str(self) -> &'static str {
269 match self {
270 Self::Interactive => "interactive",
271 Self::NonInteractive => "non_interactive",
272 Self::Eval => "eval",
273 }
274 }
275
276 pub fn is_non_interactive(self) -> bool {
277 matches!(self, Self::NonInteractive | Self::Eval)
278 }
279}
280
281impl std::str::FromStr for RuntimeProfile {
282 type Err = anyhow::Error;
283
284 fn from_str(value: &str) -> Result<Self, Self::Err> {
285 match value.trim().to_ascii_lowercase().as_str() {
286 "interactive" => Ok(Self::Interactive),
287 "non_interactive" | "non-interactive" | "headless" => Ok(Self::NonInteractive),
288 "eval" => Ok(Self::Eval),
289 other => anyhow::bail!(
290 "unsupported runtime profile {other:?}; expected interactive, non_interactive, or eval"
291 ),
292 }
293 }
294}
295
/// Reasoning settings for a request.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct ReasoningConfig {
    /// Whether reasoning is requested.
    pub enabled: bool,
    /// Optional provider-specific effort level (free-form string; not validated here).
    pub level: Option<String>,
}
301
/// Provider family a model belongs to. Defaults to `Mock`.
///
/// Serialized with serde's `snake_case` rule, which maps `OpenAi` to `open_ai`.
/// NOTE(review): confirm consumers expect `open_ai` rather than `openai`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ProviderFamily {
    #[default]
    Mock,
    OpenAi,
    Anthropic,
    Gemini,
    Xai,
    Opencode,
    Poolside,
    Cursor,
}
315
/// Tool-schema shaping policy for a model profile.
///
/// Serialized as `standard_required_first` (default) or `required_first_flat`. The
/// exact transformation each applies is implemented elsewhere; presumably it governs
/// field ordering and nesting of tool schemas (confirm).
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ModelSchemaPolicy {
    #[default]
    StandardRequiredFirst,
    RequiredFirstFlat,
}
323
/// Model-specific instruction overlay selector.
///
/// Serialized as `standard` (default), `literal_tool_outputs`, or `intuitive_context`.
/// The overlay text itself lives outside this module.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ModelInstructionOverlay {
    #[default]
    Standard,
    LiteralToolOutputs,
    IntuitiveContext,
}
332
/// Optional reasoning level for each [`SpeedPolicyPhase`] in a model profile.
///
/// Serialized with camelCase keys; unset phases are omitted.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct ModelProfileReasoning {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub orientation: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub execution: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub verification: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub recovery: Option<String>,
}
345
/// Per-model tuning knobs for the agent harness.
///
/// Serialized with camelCase keys. `model`, `provider`, and `providerFamily` are
/// required. All other fields fall back to their defaults (or `None`) when missing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct ModelHarnessProfile {
    pub model: String,
    pub provider: String,
    pub provider_family: ProviderFamily,
    /// Optional name of the edit tool to prefer for this model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub edit_tool: Option<String>,
    #[serde(default)]
    pub schema_policy: ModelSchemaPolicy,
    #[serde(default)]
    pub instruction_overlay: ModelInstructionOverlay,
    #[serde(default)]
    pub reasoning: ModelProfileReasoning,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub auto_compact_token_limit: Option<u32>,
}
365
/// Phase of agent work used by the speed policy to choose a reasoning level.
///
/// Serialized in snake_case. Defaults to `Orientation`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SpeedPolicyPhase {
    #[default]
    Orientation,
    Execution,
    Verification,
    Recovery,
}
375
376impl SpeedPolicyPhase {
377 pub fn as_str(self) -> &'static str {
378 match self {
379 Self::Orientation => "orientation",
380 Self::Execution => "execution",
381 Self::Verification => "verification",
382 Self::Recovery => "recovery",
383 }
384 }
385}
386
/// Reasoning level the speed policy chose for a phase, recorded in [`RuntimeHints`].
///
/// Serialized with camelCase keys (`phase`, `desiredReasoning`, `appliedReasoning`,
/// `supported`).
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct SpeedPolicyDecision {
    pub phase: SpeedPolicyPhase,
    /// Level the policy wanted.
    pub desired_reasoning: String,
    /// Level actually applied, if any.
    pub applied_reasoning: Option<String>,
    /// Presumably whether the model supports the desired level; confirm with the producer.
    pub supported: bool,
}
395
/// Optional sampling and output limits for a request.
///
/// Derives `PartialEq` but not `Eq`, because `temperature` and `top_p` are `f32`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct OutputConfig {
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    /// Provider-specific response-format payload, passed through as raw JSON.
    pub response_format: Option<serde_json::Value>,
}
403
/// Hosted (provider-side) web search setting.
///
/// Serialized as `disabled` (default), `cached`, or `live`.
/// [`HostedWebSearchConfig::is_enabled`] is `false` only for `Disabled`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum HostedWebSearchMode {
    #[default]
    Disabled,
    Cached,
    Live,
}
412
/// Hosted web search configuration carried in [`RuntimeHints`].
///
/// `mode` has no `#[serde(default)]`, so it must be present when this struct is
/// deserialized on its own.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct HostedWebSearchConfig {
    pub mode: HostedWebSearchMode,
}
417
418impl HostedWebSearchConfig {
419 pub fn disabled() -> Self {
420 Self {
421 mode: HostedWebSearchMode::Disabled,
422 }
423 }
424
425 pub fn cached() -> Self {
426 Self {
427 mode: HostedWebSearchMode::Cached,
428 }
429 }
430
431 pub fn live() -> Self {
432 Self {
433 mode: HostedWebSearchMode::Live,
434 }
435 }
436
437 pub fn is_enabled(&self) -> bool {
438 self.mode != HostedWebSearchMode::Disabled
439 }
440}
441
/// Per-request runtime hints for the engine.
///
/// This struct has no `rename_all`, so its keys serialize as the snake_case field
/// names (for example `speed_policy`). Nested types apply their own casing.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct RuntimeHints {
    pub trace_id: Option<String>,
    pub prompt_cache_key: Option<String>,
    pub auto_compact_token_limit: Option<u32>,
    /// Defaults to [`RuntimeProfile::Interactive`] when missing.
    #[serde(default)]
    pub profile: RuntimeProfile,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Defaults to disabled when missing.
    #[serde(default)]
    pub hosted_web_search: HostedWebSearchConfig,
    /// Defaults to [`ToolSearchConfig::default`] when missing.
    #[serde(default)]
    pub tool_search: ToolSearchConfig,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub speed_policy: Option<SpeedPolicyDecision>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub deadline_remaining_seconds: Option<u64>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reliability: Option<ReliabilityRequestPolicy>,
}
462
/// Everything an [`InferenceEngine`] needs to run one turn.
///
/// Derives `PartialEq` but not `Eq`, because [`OutputConfig`] contains floats.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AgentInferenceRequest {
    pub model: ModelSelection,
    pub instructions: InstructionBundle,
    pub transcript: Vec<TranscriptItem>,
    pub tools: Vec<ToolSpec>,
    pub tool_choice: ToolChoice,
    pub reasoning: ReasoningConfig,
    pub output: OutputConfig,
    pub runtime: RuntimeHints,
    /// Free-form JSON metadata attached to the request.
    pub metadata: serde_json::Value,
}
475
/// Incremental assistant message text from a streaming turn.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MessageDelta {
    pub text: String,
    /// Optional phase label; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub phase: Option<String>,
}
482
/// Incremental reasoning text from a streaming turn.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ReasoningDelta {
    pub text: String,
}
487
/// Signals that the model began a tool call identified by `id`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ToolCallStarted {
    pub id: String,
    pub name: String,
}
493
/// A fragment of a tool call's argument text, keyed by call `id`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ToolCallDelta {
    pub id: String,
    pub arguments_delta: String,
}
499
/// A finished tool call with its full argument text.
///
/// Also the input type of [`TurnToolExecutor::execute`]. `arguments` is a raw string,
/// presumably JSON (confirm with the engines that produce it).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ToolCallCompleted {
    pub id: String,
    pub name: String,
    pub arguments: String,
}
506
/// Signals the start of a provider-hosted tool call (for example hosted web search).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct HostedToolCallStarted {
    pub id: String,
    pub name: String,
}
512
/// A finished provider-hosted tool call with its argument text.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct HostedToolCallCompleted {
    pub id: String,
    pub name: String,
    pub arguments: String,
}
519
/// Token accounting for one or more turns.
///
/// `cache_hit_rate` is derived from `cached_prompt_tokens / prompt_tokens` by the
/// builder methods and by `add_assign`. It is `None` when `prompt_tokens` is zero and
/// is omitted from output then. Derives `PartialEq` but not `Eq`, because of the `f64`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct TokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    /// Prompt tokens served from cache; clamped to `prompt_tokens` by the builder.
    #[serde(default)]
    pub cached_prompt_tokens: u32,
    /// Prompt tokens written to cache; clamped to `prompt_tokens` by the builder.
    #[serde(default)]
    pub cache_creation_prompt_tokens: u32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_hit_rate: Option<f64>,
}
538
539impl TokenUsage {
540 pub fn new(prompt_tokens: u32, completion_tokens: u32, total_tokens: u32) -> Self {
541 Self {
542 prompt_tokens,
543 completion_tokens,
544 total_tokens,
545 cached_prompt_tokens: 0,
546 cache_creation_prompt_tokens: 0,
547 cache_hit_rate: cache_hit_rate(prompt_tokens, 0),
548 }
549 }
550
551 pub fn with_cached_prompt_tokens(mut self, cached_prompt_tokens: u32) -> Self {
552 self.cached_prompt_tokens = cached_prompt_tokens.min(self.prompt_tokens);
553 self.cache_hit_rate = cache_hit_rate(self.prompt_tokens, self.cached_prompt_tokens);
554 self
555 }
556
557 pub fn with_cache_creation_prompt_tokens(mut self, cache_creation_prompt_tokens: u32) -> Self {
558 self.cache_creation_prompt_tokens = cache_creation_prompt_tokens.min(self.prompt_tokens);
559 self
560 }
561
562 pub fn add_assign(&mut self, usage: &TokenUsage) {
563 self.prompt_tokens = self.prompt_tokens.saturating_add(usage.prompt_tokens);
564 self.completion_tokens = self
565 .completion_tokens
566 .saturating_add(usage.completion_tokens);
567 self.total_tokens = self.total_tokens.saturating_add(usage.total_tokens);
568 self.cached_prompt_tokens = self
569 .cached_prompt_tokens
570 .saturating_add(usage.cached_prompt_tokens);
571 self.cache_creation_prompt_tokens = self
572 .cache_creation_prompt_tokens
573 .saturating_add(usage.cache_creation_prompt_tokens);
574 self.cache_hit_rate = cache_hit_rate(self.prompt_tokens, self.cached_prompt_tokens);
575 }
576
577 pub fn is_empty(&self) -> bool {
578 self.prompt_tokens == 0
579 && self.completion_tokens == 0
580 && self.total_tokens == 0
581 && self.cached_prompt_tokens == 0
582 && self.cache_creation_prompt_tokens == 0
583 }
584}
585
/// Fraction of prompt tokens served from cache.
///
/// Returns `None` when `prompt_tokens` is zero. `cached_prompt_tokens` is clamped to
/// `prompt_tokens`, so any `Some` value lies in `0.0..=1.0`.
pub fn cache_hit_rate(prompt_tokens: u32, cached_prompt_tokens: u32) -> Option<f64> {
    (prompt_tokens != 0).then(|| {
        let hits = cached_prompt_tokens.min(prompt_tokens);
        f64::from(hits) / f64::from(prompt_tokens)
    })
}
593
/// Final metadata for a completed turn.
///
/// `stop_reason` is the provider's raw value. See [`finish_reason_from_stop_reason`]
/// for normalization.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct CompletionMetadata {
    pub stop_reason: Option<String>,
    pub provider_response_id: Option<String>,
}
599
/// Normalizes a provider stop reason into a finish-reason string.
///
/// Known aliases are mapped:
/// - `end_turn`, `stop`, `stop_sequence` become `stop`;
/// - `max_tokens`, `length` become `length`;
/// - `tool_use`, `tool_calls` become `toolUse`;
/// - `content_filter` becomes `contentFilter`;
/// - `refusal` stays `refusal`.
///
/// Any other value is returned unchanged. Matching is exact and case-sensitive.
pub fn finish_reason_from_stop_reason(stop_reason: &str) -> String {
    // (provider stop reason, normalized finish reason)
    const ALIASES: &[(&str, &str)] = &[
        ("end_turn", "stop"),
        ("stop", "stop"),
        ("stop_sequence", "stop"),
        ("max_tokens", "length"),
        ("length", "length"),
        ("tool_use", "toolUse"),
        ("tool_calls", "toolUse"),
        ("content_filter", "contentFilter"),
        ("refusal", "refusal"),
    ];
    ALIASES
        .iter()
        .find(|(raw, _)| *raw == stop_reason)
        .map_or(stop_reason, |&(_, normalized)| normalized)
        .to_owned()
}
618
/// Failure reported through the event stream as [`InferenceEvent::Failed`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct InferenceFailure {
    pub message: String,
}
623
/// Progress update for transcript compaction, emitted as [`InferenceEvent::Compaction`].
///
/// `status` is a free-form string; its allowed values are defined by the producer.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct CompactionProgress {
    pub status: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub item_id: Option<String>,
}
630
/// One item of an [`InferenceEventStream`].
///
/// There is no `#[serde(tag)]` attribute, so serde uses its default externally tagged
/// form (for example `{"MessageDelta": {...}}`). Derives `PartialEq` but not `Eq`,
/// because [`TokenUsage`] contains an `f64`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum InferenceEvent {
    MessageDelta(MessageDelta),
    ReasoningDelta(ReasoningDelta),
    ToolCallStarted(ToolCallStarted),
    ToolCallDelta(ToolCallDelta),
    ToolCallCompleted(ToolCallCompleted),
    HostedToolCallStarted(HostedToolCallStarted),
    HostedToolCallCompleted(HostedToolCallCompleted),
    Compaction(CompactionProgress),
    Usage(TokenUsage),
    Completed(CompletionMetadata),
    Failed(InferenceFailure),
    /// Raw provider-specific JSON passed through untouched.
    ProviderMetadata(serde_json::Value),
}
646
/// Boxed, `Send`, `'static` stream of fallible inference events returned by
/// [`InferenceEngine::stream_turn`].
pub type InferenceEventStream =
    Pin<Box<dyn Stream<Item = anyhow::Result<InferenceEvent>> + Send + 'static>>;
649
/// Feature flags an [`InferenceEngine`] advertises via `capabilities()`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct InferenceCapabilities {
    pub streaming: bool,
    pub tool_calls: bool,
    pub parallel_tool_calls: bool,
    pub reasoning_summaries: bool,
    pub structured_output: bool,
    pub image_input: bool,
    pub prompt_cache: bool,
    pub provider_metadata: bool,
    /// Whether provider-native tool search is supported. Presumably this feeds
    /// `ToolSearchConfig::resolve_effective_mode`; confirm at the call site.
    pub tool_search: bool,
}
662
663impl InferenceCapabilities {
664 pub fn text_only() -> Self {
665 Self {
666 streaming: true,
667 tool_calls: false,
668 parallel_tool_calls: false,
669 reasoning_summaries: false,
670 structured_output: false,
671 image_input: false,
672 prompt_cache: false,
673 provider_metadata: false,
674 tool_search: false,
675 }
676 }
677
678 pub fn coding_agent_default() -> Self {
679 Self {
680 streaming: true,
681 tool_calls: true,
682 parallel_tool_calls: true,
683 reasoning_summaries: false,
684 structured_output: false,
685 image_input: false,
686 prompt_cache: false,
687 provider_metadata: true,
688 tool_search: false,
689 }
690 }
691}
692
/// A model as listed by [`InferenceEngine::list_models`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelDescriptor {
    pub id: String,
    pub name: String,
    /// Context window size, if known (presumably in tokens; confirm).
    pub context_window: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_reasoning: Option<String>,
    /// Supported reasoning efforts; omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub supported_reasoning: Vec<ReasoningEffortDescriptor>,
}
703
/// A reasoning effort a model supports, with a human-readable description.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ReasoningEffortDescriptor {
    pub effort: String,
    pub description: String,
}
709
/// Borrowed context passed to [`InferenceEngine::list_models`].
pub struct InferenceProviderContext<'a> {
    pub provider_id: &'a str,
}
713
/// Borrowed per-turn context passed to [`InferenceEngine::stream_turn`].
pub struct InferenceTurnContext<'a> {
    pub thread_id: &'a str,
    pub turn_id: &'a str,
    /// Optional executor for tool calls. Presumably engines that run tools inside the
    /// turn use it; confirm the contract with engine implementations.
    pub tool_executor: Option<std::sync::Arc<dyn TurnToolExecutor>>,
}
725
/// Result of executing a tool call through a [`TurnToolExecutor`].
#[derive(Debug, Clone)]
pub struct TurnToolOutcome {
    /// Tool output text.
    pub result: String,
    /// Whether `result` describes a tool-level error rather than a success.
    pub is_error: bool,
}
732
/// Executes completed tool calls on behalf of an engine during a turn.
///
/// Must be `Send + Sync` so it can be shared via `Arc` in [`InferenceTurnContext`].
/// An `Err` presumably signals an execution failure, as distinct from a tool-reported
/// error (`TurnToolOutcome::is_error`); confirm with implementors.
#[async_trait::async_trait]
pub trait TurnToolExecutor: Send + Sync {
    async fn execute(&self, call: ToolCallCompleted) -> anyhow::Result<TurnToolOutcome>;
}
740
/// A backend that can list models and stream agent turns.
#[async_trait::async_trait]
pub trait InferenceEngine: Send + Sync + 'static {
    /// Stable identifier of this engine.
    fn id(&self) -> InferenceEngineId;
    /// Features this engine supports.
    fn capabilities(&self) -> InferenceCapabilities;

    /// Provider metadata.
    ///
    /// The default treats the engine as a credential-free local provider named after
    /// `id()`.
    fn metadata(&self) -> InferenceProviderMetadata {
        InferenceProviderMetadata::local(self.id())
    }

    /// Lists the models this engine can serve.
    async fn list_models(
        &self,
        ctx: InferenceProviderContext<'_>,
    ) -> anyhow::Result<Vec<ModelDescriptor>>;

    /// Starts a turn and returns its event stream.
    ///
    /// The outer `Result` covers setup failures. Per-event errors arrive as `Err`
    /// items in the stream.
    async fn stream_turn(
        &self,
        ctx: InferenceTurnContext<'_>,
        request: AgentInferenceRequest,
    ) -> anyhow::Result<InferenceEventStream>;
}
761
// Unit tests for the inference data model: stop-reason normalization, token
// accounting, runtime-hint serialization, and tool-search mode resolution.
#[cfg(test)]
mod tests {
    use super::*;

    // Every known provider alias maps to its normalized finish reason; unknown values pass through.
    #[test]
    fn finish_reason_mapping_normalizes_known_stop_reasons() {
        assert_eq!(finish_reason_from_stop_reason("end_turn"), "stop");
        assert_eq!(finish_reason_from_stop_reason("stop"), "stop");
        assert_eq!(finish_reason_from_stop_reason("stop_sequence"), "stop");
        assert_eq!(finish_reason_from_stop_reason("max_tokens"), "length");
        assert_eq!(finish_reason_from_stop_reason("length"), "length");
        assert_eq!(finish_reason_from_stop_reason("tool_use"), "toolUse");
        assert_eq!(finish_reason_from_stop_reason("tool_calls"), "toolUse");
        assert_eq!(
            finish_reason_from_stop_reason("content_filter"),
            "contentFilter"
        );
        assert_eq!(finish_reason_from_stop_reason("refusal"), "refusal");
        assert_eq!(finish_reason_from_stop_reason("pause_turn"), "pause_turn");
    }

    // add_assign sums cache-creation tokens, and is_empty treats a creation-only record as non-empty.
    #[test]
    fn token_usage_accumulates_cache_creation_prompt_tokens() {
        let mut usage = TokenUsage::new(100, 10, 110)
            .with_cached_prompt_tokens(80)
            .with_cache_creation_prompt_tokens(15);
        usage.add_assign(
            &TokenUsage::new(50, 5, 55)
                .with_cached_prompt_tokens(40)
                .with_cache_creation_prompt_tokens(10),
        );

        assert_eq!(usage.prompt_tokens, 150);
        assert_eq!(usage.cached_prompt_tokens, 120);
        assert_eq!(usage.cache_creation_prompt_tokens, 25);
        assert!(!usage.is_empty());

        let creation_only = TokenUsage {
            cache_creation_prompt_tokens: 1,
            ..TokenUsage::default()
        };
        assert!(!creation_only.is_empty());
    }

    // RuntimeHints keys stay snake_case (`speed_policy`), while SpeedPolicyDecision's own keys are camelCase.
    #[test]
    fn inference_speed_policy_decision_serializes_runtime_metadata() {
        let decision = SpeedPolicyDecision {
            phase: SpeedPolicyPhase::Verification,
            desired_reasoning: "high".to_string(),
            applied_reasoning: Some("high".to_string()),
            supported: true,
        };
        let hints = RuntimeHints {
            speed_policy: Some(decision),
            ..RuntimeHints::default()
        };

        let json = serde_json::to_value(hints).unwrap();
        assert_eq!(
            json.get("speed_policy")
                .and_then(|value| value.get("phase"))
                .and_then(serde_json::Value::as_str),
            Some("verification")
        );
        assert_eq!(
            json.get("speed_policy")
                .and_then(|value| value.get("desiredReasoning"))
                .and_then(serde_json::Value::as_str),
            Some("high")
        );
        assert_eq!(
            json.get("speed_policy")
                .and_then(|value| value.get("appliedReasoning"))
                .and_then(serde_json::Value::as_str),
            Some("high")
        );
    }

    // The expected values (3 attempts, retry on empty body) come from
    // ReliabilityRequestPolicy::default() in crate::reliability, which is not defined here.
    #[test]
    fn inference_reliability_policy_serializes_runtime_metadata() {
        let hints = RuntimeHints {
            reliability: Some(ReliabilityRequestPolicy::default()),
            ..RuntimeHints::default()
        };

        let json = serde_json::to_value(hints).unwrap();
        assert_eq!(
            json.get("reliability")
                .and_then(|value| value.get("providerRetryMaxAttempts"))
                .and_then(serde_json::Value::as_u64),
            Some(3)
        );
        assert_eq!(
            json.get("reliability")
                .and_then(|value| value.get("retryEmptyProviderBody"))
                .and_then(serde_json::Value::as_bool),
            Some(true)
        );
    }

    // ToolSearchConfig serializes with camelCase keys and snake_case enum values.
    #[test]
    fn tool_search_config_serializes_provider_native_request() {
        let config = ToolSearchConfig {
            mode: ToolSearchMode::ProviderNative,
            max_catalog_items: Some(200),
            include_mcp: true,
            include_skills: false,
            fallback_to_explicit_tools: true,
            provider_variant: ToolSearchProviderVariant::Bm25,
        };

        let value = serde_json::to_value(&config).unwrap();

        assert_eq!(value["mode"], "provider_native");
        assert_eq!(value["maxCatalogItems"], 200);
        assert_eq!(value["includeMcp"], true);
        assert_eq!(value["includeSkills"], false);
        assert_eq!(value["providerVariant"], "bm25");
        assert!(config.is_provider_native_requested());
    }

    // The default config is explicit and keeps the explicit-tools fallback enabled.
    #[test]
    fn explicit_tool_search_config_preserves_current_default() {
        let config = ToolSearchConfig::default();

        assert_eq!(config.mode, ToolSearchMode::Explicit);
        assert!(!config.is_provider_native_requested());
        assert!(config.fallback_to_explicit_tools);
    }

    // Covers every (mode, support, fallback) combination that resolve_effective_mode distinguishes.
    #[test]
    fn tool_search_effective_mode_resolution_covers_fallback_matrix() {
        let explicit = ToolSearchConfig::explicit();
        assert_eq!(
            explicit.resolve_effective_mode(true).unwrap(),
            EffectiveToolSearchMode::Explicit
        );

        let auto = ToolSearchConfig {
            mode: ToolSearchMode::Auto,
            ..ToolSearchConfig::default()
        };
        assert_eq!(
            auto.resolve_effective_mode(true).unwrap(),
            EffectiveToolSearchMode::ProviderNative
        );
        assert_eq!(
            auto.resolve_effective_mode(false).unwrap(),
            EffectiveToolSearchMode::Explicit
        );

        let native = ToolSearchConfig::provider_native();
        assert_eq!(
            native.resolve_effective_mode(true).unwrap(),
            EffectiveToolSearchMode::ProviderNative
        );
        assert_eq!(
            native.resolve_effective_mode(false).unwrap(),
            EffectiveToolSearchMode::Explicit
        );

        let strict = ToolSearchConfig {
            fallback_to_explicit_tools: false,
            ..ToolSearchConfig::provider_native()
        };
        let error = strict.resolve_effective_mode(false).unwrap_err();
        assert_eq!(error, ToolSearchModeError::ProviderNativeUnsupported);
        assert!(error.to_string().contains("fallback_to_explicit_tools"));
    }
}
931}