Skip to main content

roder_api/
inference.rs

1use std::pin::Pin;
2
3use futures::Stream;
4use serde::{Deserialize, Serialize};
5
6use crate::extension::InferenceEngineId;
7use crate::reliability::ReliabilityRequestPolicy;
8use crate::tools::{ToolChoice, ToolSpec};
9use crate::transcript::TranscriptItem;
10
/// Selects the model a turn runs against: the provider serving it plus the
/// model identifier.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelSelection {
    /// Provider identifier.
    pub provider: String,
    /// Model identifier within that provider.
    pub model: String,
}
16
/// How an inference provider authenticates (see
/// `InferenceProviderMetadata::auth_type`).
///
/// Serialized with serde's `snake_case` rule. NOTE(review): that rule inserts an
/// underscore before every interior capital, so `OAuth` is written as `"o_auth"`
/// (not `"oauth"`) and `ApiKey` as `"api_key"` — confirm clients expect these.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ProviderAuthType {
    /// No credentials are required.
    None,
    /// Authenticates with an API key.
    ApiKey,
    /// Authenticates through an OAuth flow.
    OAuth,
}
24
25#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
26pub struct InferenceProviderMetadata {
27    pub name: String,
28    pub description: Option<String>,
29    pub auth_type: ProviderAuthType,
30    pub auth_label: Option<String>,
31    pub auth_configured: Option<bool>,
32    pub recommended: bool,
33    pub sort_order: i32,
34}
35
36impl InferenceProviderMetadata {
37    pub fn local(name: impl Into<String>) -> Self {
38        Self {
39            name: name.into(),
40            description: None,
41            auth_type: ProviderAuthType::None,
42            auth_label: None,
43            auth_configured: Some(true),
44            recommended: false,
45            sort_order: 100,
46        }
47    }
48}
49
50#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
51#[serde(rename_all = "snake_case")]
52pub enum ToolSearchMode {
53    #[default]
54    Explicit,
55    Auto,
56    ProviderNative,
57}
58
59impl ToolSearchMode {
60    pub fn allows_provider_native(self) -> bool {
61        matches!(self, Self::Auto | Self::ProviderNative)
62    }
63}
64
/// Which flavour of provider-native tool search to request.
///
/// Serialized in snake_case (`default`, `regex`, `bm25`). NOTE(review): what each
/// variant maps to is decided by the provider integrations, not in this file.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ToolSearchProviderVariant {
    #[default]
    Default,
    Regex,
    Bm25,
}
73
74#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
75#[serde(rename_all = "camelCase")]
76pub struct ToolSearchConfig {
77    #[serde(default)]
78    pub mode: ToolSearchMode,
79    #[serde(default, skip_serializing_if = "Option::is_none")]
80    pub max_catalog_items: Option<u32>,
81    #[serde(default)]
82    pub include_mcp: bool,
83    #[serde(default)]
84    pub include_skills: bool,
85    #[serde(default)]
86    pub fallback_to_explicit_tools: bool,
87    #[serde(default)]
88    pub provider_variant: ToolSearchProviderVariant,
89}
90
91impl Default for ToolSearchConfig {
92    fn default() -> Self {
93        Self {
94            mode: ToolSearchMode::Explicit,
95            max_catalog_items: None,
96            include_mcp: true,
97            include_skills: true,
98            fallback_to_explicit_tools: true,
99            provider_variant: ToolSearchProviderVariant::Default,
100        }
101    }
102}
103
104impl ToolSearchConfig {
105    pub fn explicit() -> Self {
106        Self {
107            mode: ToolSearchMode::Explicit,
108            ..Self::default()
109        }
110    }
111
112    pub fn provider_native() -> Self {
113        Self {
114            mode: ToolSearchMode::ProviderNative,
115            ..Self::default()
116        }
117    }
118
119    pub fn is_provider_native_requested(&self) -> bool {
120        self.mode.allows_provider_native()
121    }
122
123    /**
124     * Resolve the effective tool-search mode for one provider/model turn.
125     *
126     * `Auto` silently falls back to explicit tools when the provider/model
127     * does not support native tool search. An explicit `ProviderNative`
128     * request only falls back when `fallback_to_explicit_tools` allows it;
129     * otherwise the turn must fail closed with the returned diagnostic.
130     */
131    pub fn resolve_effective_mode(
132        &self,
133        provider_native_supported: bool,
134    ) -> Result<EffectiveToolSearchMode, ToolSearchModeError> {
135        match self.mode {
136            ToolSearchMode::Explicit => Ok(EffectiveToolSearchMode::Explicit),
137            ToolSearchMode::Auto => {
138                if provider_native_supported {
139                    Ok(EffectiveToolSearchMode::ProviderNative)
140                } else {
141                    Ok(EffectiveToolSearchMode::Explicit)
142                }
143            }
144            ToolSearchMode::ProviderNative => {
145                if provider_native_supported {
146                    Ok(EffectiveToolSearchMode::ProviderNative)
147                } else if self.fallback_to_explicit_tools {
148                    Ok(EffectiveToolSearchMode::Explicit)
149                } else {
150                    Err(ToolSearchModeError::ProviderNativeUnsupported)
151                }
152            }
153        }
154    }
155}
156
/// Tool-search mode actually used for a turn, as produced by
/// `ToolSearchConfig::resolve_effective_mode` (no `Auto` here: it is resolved).
/// Serialized in snake_case.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum EffectiveToolSearchMode {
    /// Tools are sent explicitly.
    Explicit,
    /// The provider's native tool search is used.
    ProviderNative,
}
163
/// Failure to honour the requested tool-search mode for a turn.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolSearchModeError {
    /// `ProviderNative` was requested, the selected provider/model lacks
    /// support, and `fallback_to_explicit_tools` is disabled.
    ProviderNativeUnsupported,
}

impl std::fmt::Display for ToolSearchModeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::ProviderNativeUnsupported => {
                "provider-native tool search was requested but the selected provider/model does \
                 not support it and fallback_to_explicit_tools is disabled; enable fallback or \
                 pick a supported model"
            }
        };
        f.write_str(message)
    }
}

/// Carries no underlying cause, so the default `source()` (`None`) applies.
impl std::error::Error for ToolSearchModeError {}
183
184#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
185#[serde(rename_all = "camelCase")]
186pub struct ToolSearchConfigOverlay {
187    #[serde(default, skip_serializing_if = "Option::is_none")]
188    pub mode: Option<ToolSearchMode>,
189    #[serde(default, skip_serializing_if = "Option::is_none")]
190    pub max_catalog_items: Option<u32>,
191    #[serde(default, skip_serializing_if = "Option::is_none")]
192    pub include_mcp: Option<bool>,
193    #[serde(default, skip_serializing_if = "Option::is_none")]
194    pub include_skills: Option<bool>,
195    #[serde(default, skip_serializing_if = "Option::is_none")]
196    pub fallback_to_explicit_tools: Option<bool>,
197    #[serde(default, skip_serializing_if = "Option::is_none")]
198    pub provider_variant: Option<ToolSearchProviderVariant>,
199}
200
201impl ToolSearchConfigOverlay {
202    pub fn overlay(&mut self, other: &Self) {
203        if other.mode.is_some() {
204            self.mode = other.mode;
205        }
206        if other.max_catalog_items.is_some() {
207            self.max_catalog_items = other.max_catalog_items;
208        }
209        if other.include_mcp.is_some() {
210            self.include_mcp = other.include_mcp;
211        }
212        if other.include_skills.is_some() {
213            self.include_skills = other.include_skills;
214        }
215        if other.fallback_to_explicit_tools.is_some() {
216            self.fallback_to_explicit_tools = other.fallback_to_explicit_tools;
217        }
218        if other.provider_variant.is_some() {
219            self.provider_variant = other.provider_variant;
220        }
221    }
222
223    pub fn apply_to(&self, config: &mut ToolSearchConfig) {
224        if let Some(mode) = self.mode {
225            config.mode = mode;
226        }
227        if let Some(max_catalog_items) = self.max_catalog_items {
228            config.max_catalog_items = Some(max_catalog_items);
229        }
230        if let Some(include_mcp) = self.include_mcp {
231            config.include_mcp = include_mcp;
232        }
233        if let Some(include_skills) = self.include_skills {
234            config.include_skills = include_skills;
235        }
236        if let Some(fallback_to_explicit_tools) = self.fallback_to_explicit_tools {
237            config.fallback_to_explicit_tools = fallback_to_explicit_tools;
238        }
239        if let Some(provider_variant) = self.provider_variant {
240            config.provider_variant = provider_variant;
241        }
242    }
243}
244
/// Instruction text attached to one inference request, split by authority level.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct InstructionBundle {
    /// System-level instructions, if any.
    pub system: Option<String>,
    /// Developer-level instructions, if any.
    pub developer: Option<String>,
    /**
     * Per-turn developer-authority context supplied on turn/start. Volatile:
     * providers must render it after `system` and `developer` so prompt-cache
     * breakpoints on the stable prefix survive per-turn changes. Never
     * persisted to thread state.
     */
    pub developer_context: Option<String>,
}
257
258#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
259#[serde(rename_all = "snake_case")]
260pub enum RuntimeProfile {
261    #[default]
262    Interactive,
263    NonInteractive,
264    Eval,
265}
266
267impl RuntimeProfile {
268    pub fn as_str(self) -> &'static str {
269        match self {
270            Self::Interactive => "interactive",
271            Self::NonInteractive => "non_interactive",
272            Self::Eval => "eval",
273        }
274    }
275
276    pub fn is_non_interactive(self) -> bool {
277        matches!(self, Self::NonInteractive | Self::Eval)
278    }
279}
280
281impl std::str::FromStr for RuntimeProfile {
282    type Err = anyhow::Error;
283
284    fn from_str(value: &str) -> Result<Self, Self::Err> {
285        match value.trim().to_ascii_lowercase().as_str() {
286            "interactive" => Ok(Self::Interactive),
287            "non_interactive" | "non-interactive" | "headless" => Ok(Self::NonInteractive),
288            "eval" => Ok(Self::Eval),
289            other => anyhow::bail!(
290                "unsupported runtime profile {other:?}; expected interactive, non_interactive, or eval"
291            ),
292        }
293    }
294}
295
/// Reasoning settings for an inference request.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct ReasoningConfig {
    /// Reasoning on/off flag.
    pub enabled: bool,
    /// Optional free-form reasoning level. NOTE(review): accepted values appear
    /// to be provider-defined (compare `ModelDescriptor::supported_reasoning`) — confirm.
    pub level: Option<String>,
}
301
/// Provider family a model harness profile belongs to
/// (`ModelHarnessProfile::provider_family`).
///
/// Serialized in snake_case. NOTE(review): serde's rule renders `OpenAi` as
/// `"open_ai"` — confirm that is the intended wire spelling.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ProviderFamily {
    #[default]
    Mock,
    OpenAi,
    Anthropic,
    Gemini,
    Xai,
    Opencode,
    Poolside,
    Cursor,
}

/// Tool-schema shaping policy for a model profile. Serialized in snake_case.
/// NOTE(review): the behaviour of each variant is implemented outside this file;
/// the names suggest required properties are ordered first, with the `Flat`
/// variant also flattening the schema — confirm.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ModelSchemaPolicy {
    #[default]
    StandardRequiredFirst,
    RequiredFirstFlat,
}

/// Model-specific instruction overlay selected for a profile. Serialized in
/// snake_case; what each overlay adds is defined outside this file.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ModelInstructionOverlay {
    #[default]
    Standard,
    LiteralToolOutputs,
    IntuitiveContext,
}
332
/// Per-phase reasoning settings for a model profile.
///
/// Field names mirror [`SpeedPolicyPhase`]. Presumably each holds the reasoning
/// level for that phase (TODO confirm). Unset fields are omitted when serialized
/// and may be absent on input.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct ModelProfileReasoning {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub orientation: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub execution: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub verification: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub recovery: Option<String>,
}

/// Harness tuning for one provider/model pair.
///
/// Keys are camelCase. `model`, `provider` and `provider_family` are required.
/// Every other field falls back to its default when absent, and unset `Option`
/// fields are omitted when serialized.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct ModelHarnessProfile {
    pub model: String,
    pub provider: String,
    pub provider_family: ProviderFamily,
    /// Name of the edit tool to expose, if overridden.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub edit_tool: Option<String>,
    #[serde(default)]
    pub schema_policy: ModelSchemaPolicy,
    #[serde(default)]
    pub instruction_overlay: ModelInstructionOverlay,
    #[serde(default)]
    pub reasoning: ModelProfileReasoning,
    /// Per-profile override for parallel tool calls, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Per-profile auto-compaction token limit, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub auto_compact_token_limit: Option<u32>,
}
365
366#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
367#[serde(rename_all = "snake_case")]
368pub enum SpeedPolicyPhase {
369    #[default]
370    Orientation,
371    Execution,
372    Verification,
373    Recovery,
374}
375
376impl SpeedPolicyPhase {
377    pub fn as_str(self) -> &'static str {
378        match self {
379            Self::Orientation => "orientation",
380            Self::Execution => "execution",
381            Self::Verification => "verification",
382            Self::Recovery => "recovery",
383        }
384    }
385}
386
/// Speed-policy outcome for one inference step, carried in
/// `RuntimeHints::speed_policy`. Keys are camelCase.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct SpeedPolicyDecision {
    /// Phase the decision was made for.
    pub phase: SpeedPolicyPhase,
    /// Reasoning level the policy wanted.
    pub desired_reasoning: String,
    /// Reasoning level actually applied, if any.
    pub applied_reasoning: Option<String>,
    /// NOTE(review): presumably whether the model supports `desired_reasoning` — confirm.
    pub supported: bool,
}

/// Output/sampling parameters for a request; each field is optional.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct OutputConfig {
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    /// Arbitrary JSON describing the requested response format, if any.
    pub response_format: Option<serde_json::Value>,
}
403
404#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
405#[serde(rename_all = "snake_case")]
406pub enum HostedWebSearchMode {
407    #[default]
408    Disabled,
409    Cached,
410    Live,
411}
412
413#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
414pub struct HostedWebSearchConfig {
415    pub mode: HostedWebSearchMode,
416}
417
418impl HostedWebSearchConfig {
419    pub fn disabled() -> Self {
420        Self {
421            mode: HostedWebSearchMode::Disabled,
422        }
423    }
424
425    pub fn cached() -> Self {
426        Self {
427            mode: HostedWebSearchMode::Cached,
428        }
429    }
430
431    pub fn live() -> Self {
432        Self {
433            mode: HostedWebSearchMode::Live,
434        }
435    }
436
437    pub fn is_enabled(&self) -> bool {
438        self.mode != HostedWebSearchMode::Disabled
439    }
440}
441
/// Per-request runtime hints accompanying an [`AgentInferenceRequest`].
///
/// Keys are snake_case (there is no `rename_all` here). Nested values such as
/// `tool_search` and `speed_policy` use their own camelCase keys. `Option`
/// fields and fields marked `#[serde(default)]` may be absent on input.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct RuntimeHints {
    /// Trace identifier, if any.
    pub trace_id: Option<String>,
    /// Prompt-cache key, if any.
    pub prompt_cache_key: Option<String>,
    /// Auto-compaction token limit, if any.
    pub auto_compact_token_limit: Option<u32>,
    #[serde(default)]
    pub profile: RuntimeProfile,
    /// Parallel-tool-call preference, if set; omitted from output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(default)]
    pub hosted_web_search: HostedWebSearchConfig,
    #[serde(default)]
    pub tool_search: ToolSearchConfig,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub speed_policy: Option<SpeedPolicyDecision>,
    /// Seconds left before a deadline, if one applies.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub deadline_remaining_seconds: Option<u64>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reliability: Option<ReliabilityRequestPolicy>,
}

/// One inference request for an agent turn, as passed to
/// `InferenceEngine::stream_turn`. `metadata` is an arbitrary JSON value.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AgentInferenceRequest {
    pub model: ModelSelection,
    pub instructions: InstructionBundle,
    pub transcript: Vec<TranscriptItem>,
    pub tools: Vec<ToolSpec>,
    pub tool_choice: ToolChoice,
    pub reasoning: ReasoningConfig,
    pub output: OutputConfig,
    pub runtime: RuntimeHints,
    pub metadata: serde_json::Value,
}
475
/// Incremental assistant message text (payload of `InferenceEvent::MessageDelta`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MessageDelta {
    pub text: String,
    /// Optional phase label; omitted when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub phase: Option<String>,
}

/// Incremental reasoning text (payload of `InferenceEvent::ReasoningDelta`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ReasoningDelta {
    pub text: String,
}

/// A tool call has begun; `id` correlates later deltas and completion.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ToolCallStarted {
    pub id: String,
    pub name: String,
}

/// A fragment of the argument text for the tool call identified by `id`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ToolCallDelta {
    pub id: String,
    pub arguments_delta: String,
}

/// A finished tool call. `arguments` is the raw argument string (presumably
/// JSON — confirm with the provider adapters).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ToolCallCompleted {
    pub id: String,
    pub name: String,
    pub arguments: String,
}

/// A hosted tool call has begun. NOTE(review): "hosted" presumably means the
/// provider executes it (e.g. hosted web search) — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct HostedToolCallStarted {
    pub id: String,
    pub name: String,
}

/// A finished hosted tool call with its raw argument string.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct HostedToolCallCompleted {
    pub id: String,
    pub name: String,
    pub arguments: String,
}
519
/// Token accounting for one inference step (or an accumulation of steps; see
/// `TokenUsage::add_assign`).
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct TokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    /// Prompt tokens served from the provider's prompt cache; a subset of
    /// `prompt_tokens`. Defaults to 0 when absent on input.
    #[serde(default)]
    pub cached_prompt_tokens: u32,
    /**
     * Prompt tokens written to the provider's prompt cache this step. Like
     * `cached_prompt_tokens`, this is a subset of `prompt_tokens`, not an
     * additional count; hosts use it to bill cache writes at the provider's
     * cache-write rate.
     */
    #[serde(default)]
    pub cache_creation_prompt_tokens: u32,
    /// Derived `cached_prompt_tokens / prompt_tokens`, recomputed by the builder
    /// and accumulation methods; `None` when `prompt_tokens` is 0. Omitted when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_hit_rate: Option<f64>,
}
538
539impl TokenUsage {
540    pub fn new(prompt_tokens: u32, completion_tokens: u32, total_tokens: u32) -> Self {
541        Self {
542            prompt_tokens,
543            completion_tokens,
544            total_tokens,
545            cached_prompt_tokens: 0,
546            cache_creation_prompt_tokens: 0,
547            cache_hit_rate: cache_hit_rate(prompt_tokens, 0),
548        }
549    }
550
551    pub fn with_cached_prompt_tokens(mut self, cached_prompt_tokens: u32) -> Self {
552        self.cached_prompt_tokens = cached_prompt_tokens.min(self.prompt_tokens);
553        self.cache_hit_rate = cache_hit_rate(self.prompt_tokens, self.cached_prompt_tokens);
554        self
555    }
556
557    pub fn with_cache_creation_prompt_tokens(mut self, cache_creation_prompt_tokens: u32) -> Self {
558        self.cache_creation_prompt_tokens = cache_creation_prompt_tokens.min(self.prompt_tokens);
559        self
560    }
561
562    pub fn add_assign(&mut self, usage: &TokenUsage) {
563        self.prompt_tokens = self.prompt_tokens.saturating_add(usage.prompt_tokens);
564        self.completion_tokens = self
565            .completion_tokens
566            .saturating_add(usage.completion_tokens);
567        self.total_tokens = self.total_tokens.saturating_add(usage.total_tokens);
568        self.cached_prompt_tokens = self
569            .cached_prompt_tokens
570            .saturating_add(usage.cached_prompt_tokens);
571        self.cache_creation_prompt_tokens = self
572            .cache_creation_prompt_tokens
573            .saturating_add(usage.cache_creation_prompt_tokens);
574        self.cache_hit_rate = cache_hit_rate(self.prompt_tokens, self.cached_prompt_tokens);
575    }
576
577    pub fn is_empty(&self) -> bool {
578        self.prompt_tokens == 0
579            && self.completion_tokens == 0
580            && self.total_tokens == 0
581            && self.cached_prompt_tokens == 0
582            && self.cache_creation_prompt_tokens == 0
583    }
584}
585
/// Fraction of prompt tokens served from the provider's prompt cache.
///
/// `cached_prompt_tokens` is clamped to `prompt_tokens`, so the result lies in
/// `[0.0, 1.0]`. Returns `None` when `prompt_tokens` is zero, where the ratio
/// is undefined.
pub fn cache_hit_rate(prompt_tokens: u32, cached_prompt_tokens: u32) -> Option<f64> {
    (prompt_tokens != 0).then(|| {
        let hits = cached_prompt_tokens.min(prompt_tokens);
        f64::from(hits) / f64::from(prompt_tokens)
    })
}
593
/// Terminal metadata for an inference step (payload of `InferenceEvent::Completed`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct CompletionMetadata {
    /// Provider-native stop reason, if reported; normalize it with
    /// `finish_reason_from_stop_reason`.
    pub stop_reason: Option<String>,
    /// Provider's response identifier, if reported.
    pub provider_response_id: Option<String>,
}
599
/// Maps a provider-native stop reason to the canonical finish reason surfaced as
/// `finishReason` on `turn/completed`.
///
/// | provider stop reason                | finish reason   |
/// |-------------------------------------|-----------------|
/// | `end_turn`, `stop`, `stop_sequence` | `stop`          |
/// | `max_tokens`, `length`              | `length`        |
/// | `tool_use`, `tool_calls`            | `toolUse`       |
/// | `content_filter`                    | `contentFilter` |
/// | `refusal`                           | `refusal`       |
///
/// Matching is exact and case-sensitive; any other value is returned unchanged.
/// Only the terminal inference step's stop reason reaches the turn surface, so
/// `toolUse` appears only when a turn genuinely ends on a tool-use step
/// (e.g. tool rounds exhausted).
pub fn finish_reason_from_stop_reason(stop_reason: &str) -> String {
    let canonical: &str = match stop_reason {
        "end_turn" | "stop" | "stop_sequence" => "stop",
        "max_tokens" | "length" => "length",
        "tool_use" | "tool_calls" => "toolUse",
        "content_filter" => "contentFilter",
        "refusal" => "refusal",
        unknown => unknown,
    };
    String::from(canonical)
}
618
/// Failure reported mid-stream (payload of `InferenceEvent::Failed`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct InferenceFailure {
    pub message: String,
}

/// Progress of a compaction operation (payload of `InferenceEvent::Compaction`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct CompactionProgress {
    /// Free-form status string. NOTE(review): allowed values are not defined here.
    pub status: String,
    /// Related transcript item id, if any; omitted when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub item_id: Option<String>,
}

/// One event emitted by an inference engine while streaming a turn.
/// Serialized with serde's default externally-tagged enum representation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum InferenceEvent {
    MessageDelta(MessageDelta),
    ReasoningDelta(ReasoningDelta),
    ToolCallStarted(ToolCallStarted),
    ToolCallDelta(ToolCallDelta),
    ToolCallCompleted(ToolCallCompleted),
    HostedToolCallStarted(HostedToolCallStarted),
    HostedToolCallCompleted(HostedToolCallCompleted),
    Compaction(CompactionProgress),
    Usage(TokenUsage),
    Completed(CompletionMetadata),
    Failed(InferenceFailure),
    /// Arbitrary provider-specific JSON.
    ProviderMetadata(serde_json::Value),
}

/// Pinned, boxed, `Send + 'static` stream of inference events; each item is
/// either an event or an error.
pub type InferenceEventStream =
    Pin<Box<dyn Stream<Item = anyhow::Result<InferenceEvent>> + Send + 'static>>;
649
650#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
651pub struct InferenceCapabilities {
652    pub streaming: bool,
653    pub tool_calls: bool,
654    pub parallel_tool_calls: bool,
655    pub reasoning_summaries: bool,
656    pub structured_output: bool,
657    pub image_input: bool,
658    pub prompt_cache: bool,
659    pub provider_metadata: bool,
660    pub tool_search: bool,
661}
662
663impl InferenceCapabilities {
664    pub fn text_only() -> Self {
665        Self {
666            streaming: true,
667            tool_calls: false,
668            parallel_tool_calls: false,
669            reasoning_summaries: false,
670            structured_output: false,
671            image_input: false,
672            prompt_cache: false,
673            provider_metadata: false,
674            tool_search: false,
675        }
676    }
677
678    pub fn coding_agent_default() -> Self {
679        Self {
680            streaming: true,
681            tool_calls: true,
682            parallel_tool_calls: true,
683            reasoning_summaries: false,
684            structured_output: false,
685            image_input: false,
686            prompt_cache: false,
687            provider_metadata: true,
688            tool_search: false,
689        }
690    }
691}
692
/// A model as listed by `InferenceEngine::list_models`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelDescriptor {
    /// Model identifier.
    pub id: String,
    /// Display name.
    pub name: String,
    /// Context window size, if known. NOTE(review): presumably in tokens — confirm.
    pub context_window: Option<u32>,
    /// Default reasoning effort, if any; omitted when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_reasoning: Option<String>,
    /// Supported reasoning efforts; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub supported_reasoning: Vec<ReasoningEffortDescriptor>,
}

/// One reasoning effort a model supports, with a human-readable description.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ReasoningEffortDescriptor {
    pub effort: String,
    pub description: String,
}
709
/// Borrowed context for provider-level calls such as `InferenceEngine::list_models`.
pub struct InferenceProviderContext<'a> {
    pub provider_id: &'a str,
}

/// Borrowed context for a single `InferenceEngine::stream_turn` call.
pub struct InferenceTurnContext<'a> {
    pub thread_id: &'a str,
    pub turn_id: &'a str,
    /// Optional callback that executes a single tool call through Roder's tool
    /// registry and policy, returning its result. Provided by the runtime for
    /// providers that drive their own in-stream agent loop (e.g. the Cursor
    /// bidi agent-runtime client, which must execute read/write/shell exec
    /// requests mid-stream rather than ending the turn). Most providers ignore
    /// it and surface tool calls as `ToolCallCompleted` events instead.
    pub tool_executor: Option<std::sync::Arc<dyn TurnToolExecutor>>,
}

/// Result of executing one tool call via [`TurnToolExecutor`].
#[derive(Debug, Clone)]
pub struct TurnToolOutcome {
    /// Tool output text.
    pub result: String,
    /// Whether `result` describes a tool-level error.
    pub is_error: bool,
}

/// Executes a single tool call through the runtime's registry + policy.
/// Implemented by the runtime; used by providers that run their own in-stream
/// agent loop.
#[async_trait::async_trait]
pub trait TurnToolExecutor: Send + Sync {
    async fn execute(&self, call: ToolCallCompleted) -> anyhow::Result<TurnToolOutcome>;
}
740
/// An inference backend. It reports its identity, capabilities and metadata,
/// lists its models, and streams one agent turn as [`InferenceEvent`]s.
#[async_trait::async_trait]
pub trait InferenceEngine: Send + Sync + 'static {
    /// Identifier of this engine.
    fn id(&self) -> InferenceEngineId;
    /// Features this engine supports.
    fn capabilities(&self) -> InferenceCapabilities;

    /// Provider metadata. Defaults to [`InferenceProviderMetadata::local`]
    /// named after [`id`](Self::id) (no auth, already configured).
    fn metadata(&self) -> InferenceProviderMetadata {
        InferenceProviderMetadata::local(self.id())
    }

    /// Lists the models this engine offers.
    async fn list_models(
        &self,
        ctx: InferenceProviderContext<'_>,
    ) -> anyhow::Result<Vec<ModelDescriptor>>;

    /// Starts one turn for `request` and returns its event stream. Errors can
    /// surface either from this call or as items of the returned stream.
    async fn stream_turn(
        &self,
        ctx: InferenceTurnContext<'_>,
        request: AgentInferenceRequest,
    ) -> anyhow::Result<InferenceEventStream>;
}
761
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads `json[outer][inner]` as a string, if every step exists.
    fn nested_str<'a>(json: &'a serde_json::Value, outer: &str, inner: &str) -> Option<&'a str> {
        json.get(outer)?.get(inner)?.as_str()
    }

    #[test]
    fn finish_reason_mapping_normalizes_known_stop_reasons() {
        let cases = [
            ("end_turn", "stop"),
            ("stop", "stop"),
            ("stop_sequence", "stop"),
            ("max_tokens", "length"),
            ("length", "length"),
            ("tool_use", "toolUse"),
            ("tool_calls", "toolUse"),
            ("content_filter", "contentFilter"),
            ("refusal", "refusal"),
            // Unknown stop reasons pass through unchanged.
            ("pause_turn", "pause_turn"),
        ];
        for (stop_reason, expected) in cases {
            assert_eq!(
                finish_reason_from_stop_reason(stop_reason),
                expected,
                "stop reason {stop_reason:?}"
            );
        }
    }

    #[test]
    fn token_usage_accumulates_cache_creation_prompt_tokens() {
        let first = TokenUsage::new(100, 10, 110)
            .with_cached_prompt_tokens(80)
            .with_cache_creation_prompt_tokens(15);
        let second = TokenUsage::new(50, 5, 55)
            .with_cached_prompt_tokens(40)
            .with_cache_creation_prompt_tokens(10);

        let mut total = first;
        total.add_assign(&second);

        assert_eq!(total.prompt_tokens, 150);
        assert_eq!(total.cached_prompt_tokens, 120);
        assert_eq!(total.cache_creation_prompt_tokens, 25);
        assert!(!total.is_empty());

        // A cache write on its own is enough to make usage non-empty.
        let creation_only = TokenUsage {
            cache_creation_prompt_tokens: 1,
            ..TokenUsage::default()
        };
        assert!(!creation_only.is_empty());
    }

    #[test]
    fn inference_speed_policy_decision_serializes_runtime_metadata() {
        let hints = RuntimeHints {
            speed_policy: Some(SpeedPolicyDecision {
                phase: SpeedPolicyPhase::Verification,
                desired_reasoning: "high".to_string(),
                applied_reasoning: Some("high".to_string()),
                supported: true,
            }),
            ..RuntimeHints::default()
        };

        let json = serde_json::to_value(hints).unwrap();

        assert_eq!(nested_str(&json, "speed_policy", "phase"), Some("verification"));
        assert_eq!(nested_str(&json, "speed_policy", "desiredReasoning"), Some("high"));
        assert_eq!(nested_str(&json, "speed_policy", "appliedReasoning"), Some("high"));
    }

    #[test]
    fn inference_reliability_policy_serializes_runtime_metadata() {
        let hints = RuntimeHints {
            reliability: Some(ReliabilityRequestPolicy::default()),
            ..RuntimeHints::default()
        };

        let json = serde_json::to_value(hints).unwrap();
        let reliability = json.get("reliability");

        assert_eq!(
            reliability
                .and_then(|policy| policy.get("providerRetryMaxAttempts"))
                .and_then(serde_json::Value::as_u64),
            Some(3)
        );
        assert_eq!(
            reliability
                .and_then(|policy| policy.get("retryEmptyProviderBody"))
                .and_then(serde_json::Value::as_bool),
            Some(true)
        );
    }

    #[test]
    fn tool_search_config_serializes_provider_native_request() {
        let config = ToolSearchConfig {
            mode: ToolSearchMode::ProviderNative,
            max_catalog_items: Some(200),
            include_mcp: true,
            include_skills: false,
            fallback_to_explicit_tools: true,
            provider_variant: ToolSearchProviderVariant::Bm25,
        };
        assert!(config.is_provider_native_requested());

        let value = serde_json::to_value(&config).unwrap();

        assert_eq!(value["mode"], "provider_native");
        assert_eq!(value["maxCatalogItems"], 200);
        assert_eq!(value["includeMcp"], true);
        assert_eq!(value["includeSkills"], false);
        assert_eq!(value["providerVariant"], "bm25");
    }

    #[test]
    fn explicit_tool_search_config_preserves_current_default() {
        let defaults = ToolSearchConfig::default();

        assert_eq!(defaults.mode, ToolSearchMode::Explicit);
        assert!(!defaults.is_provider_native_requested());
        assert!(defaults.fallback_to_explicit_tools);
    }

    #[test]
    fn tool_search_effective_mode_resolution_covers_fallback_matrix() {
        let auto = ToolSearchConfig {
            mode: ToolSearchMode::Auto,
            ..ToolSearchConfig::default()
        };
        let resolvable = [
            (ToolSearchConfig::explicit(), true, EffectiveToolSearchMode::Explicit),
            (auto.clone(), true, EffectiveToolSearchMode::ProviderNative),
            (auto, false, EffectiveToolSearchMode::Explicit),
            (ToolSearchConfig::provider_native(), true, EffectiveToolSearchMode::ProviderNative),
            (ToolSearchConfig::provider_native(), false, EffectiveToolSearchMode::Explicit),
        ];
        for (config, supported, expected) in resolvable {
            assert_eq!(
                config.resolve_effective_mode(supported).unwrap(),
                expected,
                "mode {:?}, provider_native_supported = {supported}",
                config.mode
            );
        }

        // With fallback disabled, an unsupported ProviderNative request fails closed.
        let strict = ToolSearchConfig {
            fallback_to_explicit_tools: false,
            ..ToolSearchConfig::provider_native()
        };
        let error = strict.resolve_effective_mode(false).unwrap_err();
        assert_eq!(error, ToolSearchModeError::ProviderNativeUnsupported);
        assert!(error.to_string().contains("fallback_to_explicit_tools"));
    }
}