Skip to main content

roder_api/
inference_routing.rs

1use serde::{Deserialize, Serialize};
2
3use crate::extension::InferenceRouterId;
4use crate::inference::{
5    InferenceCapabilities, InferenceProviderMetadata, ModelDescriptor, ModelSelection,
6    ReasoningConfig, RuntimeProfile, SpeedPolicyPhase,
7};
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
10#[serde(tag = "type", rename_all = "camelCase")]
11pub enum ModelSelectionMode {
12    Manual {
13        provider: String,
14        model: String,
15        #[serde(default, skip_serializing_if = "Option::is_none")]
16        reasoning: Option<String>,
17    },
18    Auto {
19        option_id: String,
20        router_id: InferenceRouterId,
21        label: String,
22        baseline: ModelSelection,
23        #[serde(default, skip_serializing_if = "Option::is_none")]
24        profile: Option<String>,
25        #[serde(default, skip_serializing_if = "Option::is_none")]
26        reasoning: Option<String>,
27    },
28}
29
30impl ModelSelectionMode {
31    pub fn manual(
32        provider: impl Into<String>,
33        model: impl Into<String>,
34        reasoning: Option<String>,
35    ) -> Self {
36        Self::Manual {
37            provider: provider.into(),
38            model: model.into(),
39            reasoning,
40        }
41    }
42
43    pub fn auto(
44        option_id: impl Into<String>,
45        router_id: impl Into<String>,
46        label: impl Into<String>,
47        baseline: ModelSelection,
48        profile: Option<String>,
49        reasoning: Option<String>,
50    ) -> Self {
51        Self::Auto {
52            option_id: option_id.into(),
53            router_id: router_id.into(),
54            label: label.into(),
55            baseline,
56            profile,
57            reasoning,
58        }
59    }
60
61    pub fn concrete_selection(&self) -> ModelSelection {
62        match self {
63            Self::Manual {
64                provider, model, ..
65            } => ModelSelection {
66                provider: provider.clone(),
67                model: model.clone(),
68            },
69            Self::Auto { baseline, .. } => baseline.clone(),
70        }
71    }
72
73    pub fn reasoning(&self) -> Option<&str> {
74        match self {
75            Self::Manual { reasoning, .. } | Self::Auto { reasoning, .. } => reasoning.as_deref(),
76        }
77    }
78
79    pub fn is_auto(&self) -> bool {
80        matches!(self, Self::Auto { .. })
81    }
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
85#[serde(rename_all = "camelCase")]
86pub struct InferenceRoutingOptionDescriptor {
87    pub id: String,
88    pub label: String,
89    pub router_id: InferenceRouterId,
90    pub baseline: ModelSelection,
91    #[serde(default, skip_serializing_if = "Option::is_none")]
92    pub profile: Option<String>,
93    #[serde(default, skip_serializing_if = "Option::is_none")]
94    pub objective: Option<String>,
95    #[serde(default, skip_serializing_if = "Option::is_none")]
96    pub reasoning: Option<String>,
97    #[serde(default = "default_true")]
98    pub available: bool,
99    #[serde(default, skip_serializing_if = "Option::is_none")]
100    pub unavailable_reason: Option<String>,
101    #[serde(default)]
102    pub metadata: serde_json::Value,
103}
104
/// Serde `default` hook: fields annotated with `default = "default_true"`
/// deserialize as `true` when their key is absent.
const fn default_true() -> bool {
    true
}
108
109impl InferenceRoutingOptionDescriptor {
110    pub fn selectable(
111        id: impl Into<String>,
112        label: impl Into<String>,
113        router_id: impl Into<String>,
114        baseline: ModelSelection,
115    ) -> Self {
116        Self {
117            id: id.into(),
118            label: label.into(),
119            router_id: router_id.into(),
120            baseline,
121            profile: None,
122            objective: None,
123            reasoning: None,
124            available: true,
125            unavailable_reason: None,
126            metadata: serde_json::Value::Null,
127        }
128    }
129
130    pub fn unavailable(mut self, reason: impl Into<String>) -> Self {
131        self.available = false;
132        self.unavailable_reason = Some(reason.into());
133        self
134    }
135
136    pub fn selection_mode(&self) -> ModelSelectionMode {
137        ModelSelectionMode::auto(
138            self.id.clone(),
139            self.router_id.clone(),
140            self.label.clone(),
141            self.baseline.clone(),
142            self.profile.clone(),
143            self.reasoning.clone(),
144        )
145    }
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
149#[serde(rename_all = "camelCase")]
150pub struct InferenceRoutingContext {
151    pub thread_id: String,
152    pub turn_id: String,
153    #[serde(default)]
154    pub round_index: u32,
155    #[serde(default)]
156    pub runtime_profile: RuntimeProfile,
157    pub default_selection: ModelSelection,
158    #[serde(default, skip_serializing_if = "Option::is_none")]
159    pub requested_selection: Option<ModelSelection>,
160    #[serde(default, skip_serializing_if = "Option::is_none")]
161    pub phase: Option<SpeedPolicyPhase>,
162    #[serde(default)]
163    pub transcript: InferenceRoutingTranscriptSummary,
164    #[serde(default)]
165    pub tools: InferenceRoutingToolSummary,
166    #[serde(default)]
167    pub candidates: Vec<InferenceRoutingCandidate>,
168    #[serde(default)]
169    pub signals: Vec<InferenceRoutingSignal>,
170    #[serde(default)]
171    pub prior_failures: u32,
172    #[serde(default)]
173    pub prior_escalations: u32,
174    #[serde(default, skip_serializing_if = "Option::is_none")]
175    pub estimated_input_tokens: Option<u32>,
176}
177
178#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
179#[serde(rename_all = "camelCase")]
180pub struct InferenceRoutingTranscriptSummary {
181    #[serde(default)]
182    pub item_count: u32,
183    #[serde(default)]
184    pub user_message_count: u32,
185    #[serde(default)]
186    pub assistant_message_count: u32,
187    #[serde(default)]
188    pub tool_result_count: u32,
189    #[serde(default)]
190    pub has_image_input: bool,
191    #[serde(default, skip_serializing_if = "Option::is_none")]
192    pub latest_user_message_preview: Option<String>,
193    #[serde(default, skip_serializing_if = "Vec::is_empty")]
194    pub recent_tool_names: Vec<String>,
195    #[serde(default, skip_serializing_if = "Option::is_none")]
196    pub approximate_tokens: Option<u32>,
197}
198
199#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
200#[serde(rename_all = "camelCase")]
201pub struct InferenceRoutingToolSummary {
202    #[serde(default)]
203    pub available_count: u32,
204    #[serde(default)]
205    pub has_file_tools: bool,
206    #[serde(default)]
207    pub has_shell_tools: bool,
208    #[serde(default)]
209    pub has_network_tools: bool,
210    #[serde(default)]
211    pub requires_tool_calls: bool,
212}
213
214#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
215#[serde(rename_all = "camelCase")]
216pub struct InferenceRoutingCandidate {
217    pub selection: ModelSelection,
218    pub provider: InferenceProviderMetadata,
219    pub model: ModelDescriptor,
220    pub capabilities: InferenceCapabilities,
221    #[serde(default)]
222    pub available: bool,
223    #[serde(default, skip_serializing_if = "Option::is_none")]
224    pub unavailable_reason: Option<String>,
225}
226
227impl InferenceRoutingCandidate {
228    pub fn available(
229        selection: ModelSelection,
230        provider: InferenceProviderMetadata,
231        model: ModelDescriptor,
232        capabilities: InferenceCapabilities,
233    ) -> Self {
234        Self {
235            selection,
236            provider,
237            model,
238            capabilities,
239            available: true,
240            unavailable_reason: None,
241        }
242    }
243
244    pub fn unavailable(
245        selection: ModelSelection,
246        provider: InferenceProviderMetadata,
247        model: ModelDescriptor,
248        capabilities: InferenceCapabilities,
249        reason: impl Into<String>,
250    ) -> Self {
251        Self {
252            selection,
253            provider,
254            model,
255            capabilities,
256            available: false,
257            unavailable_reason: Some(reason.into()),
258        }
259    }
260}
261
262#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
263#[serde(rename_all = "camelCase")]
264pub struct InferenceRoutingSignal {
265    pub key: String,
266    pub value: String,
267    #[serde(default, skip_serializing_if = "Option::is_none")]
268    pub source: Option<String>,
269    #[serde(default, skip_serializing_if = "Option::is_none")]
270    pub weight: Option<f64>,
271}
272
273impl InferenceRoutingSignal {
274    pub fn new(key: impl Into<String>, value: impl Into<String>) -> Self {
275        Self {
276            key: key.into(),
277            value: value.into(),
278            source: None,
279            weight: None,
280        }
281    }
282}
283
284#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
285#[serde(rename_all = "snake_case")]
286pub enum InferenceRoutingOutcome {
287    Selected,
288    Escalated,
289    Fallback,
290    #[default]
291    Abstained,
292}
293
294#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
295#[serde(rename_all = "camelCase")]
296pub struct InferenceRoutingDecision {
297    pub router_id: InferenceRouterId,
298    pub outcome: InferenceRoutingOutcome,
299    #[serde(default, skip_serializing_if = "Option::is_none")]
300    pub selected: Option<ModelSelection>,
301    #[serde(default, skip_serializing_if = "Option::is_none")]
302    pub reasoning: Option<ReasoningConfig>,
303    pub reason: String,
304    #[serde(default, skip_serializing_if = "Option::is_none")]
305    pub confidence: Option<f64>,
306    #[serde(default)]
307    pub matched_signals: Vec<InferenceRoutingSignal>,
308    #[serde(default, skip_serializing_if = "Option::is_none")]
309    pub baseline: Option<ModelSelection>,
310    #[serde(default, skip_serializing_if = "Option::is_none")]
311    pub cost_delta: Option<InferenceRoutingCostDelta>,
312    #[serde(default)]
313    pub metadata: serde_json::Value,
314}
315
316impl InferenceRoutingDecision {
317    pub fn selected(
318        router_id: impl Into<String>,
319        selection: ModelSelection,
320        reason: impl Into<String>,
321    ) -> Self {
322        Self {
323            router_id: router_id.into(),
324            outcome: InferenceRoutingOutcome::Selected,
325            selected: Some(selection),
326            reasoning: None,
327            reason: reason.into(),
328            confidence: None,
329            matched_signals: Vec::new(),
330            baseline: None,
331            cost_delta: None,
332            metadata: serde_json::Value::Null,
333        }
334    }
335
336    pub fn abstain(router_id: impl Into<String>, reason: impl Into<String>) -> Self {
337        Self {
338            router_id: router_id.into(),
339            outcome: InferenceRoutingOutcome::Abstained,
340            selected: None,
341            reasoning: None,
342            reason: reason.into(),
343            confidence: None,
344            matched_signals: Vec::new(),
345            baseline: None,
346            cost_delta: None,
347            metadata: serde_json::Value::Null,
348        }
349    }
350
351    pub fn fallback(router_id: impl Into<String>, reason: impl Into<String>) -> Self {
352        Self {
353            router_id: router_id.into(),
354            outcome: InferenceRoutingOutcome::Fallback,
355            selected: None,
356            reasoning: None,
357            reason: reason.into(),
358            confidence: None,
359            matched_signals: Vec::new(),
360            baseline: None,
361            cost_delta: None,
362            metadata: serde_json::Value::Null,
363        }
364    }
365
366    pub fn escalated(
367        router_id: impl Into<String>,
368        selection: ModelSelection,
369        reason: impl Into<String>,
370    ) -> Self {
371        Self {
372            router_id: router_id.into(),
373            outcome: InferenceRoutingOutcome::Escalated,
374            selected: Some(selection),
375            reasoning: None,
376            reason: reason.into(),
377            confidence: None,
378            matched_signals: Vec::new(),
379            baseline: None,
380            cost_delta: None,
381            metadata: serde_json::Value::Null,
382        }
383    }
384}
385
386#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
387#[serde(rename_all = "camelCase")]
388pub struct InferenceRoutingCostDelta {
389    pub selected_estimate: InferenceRoutingCostEstimate,
390    pub baseline_estimate: InferenceRoutingCostEstimate,
391    pub estimated_savings_usd: f64,
392    #[serde(default, skip_serializing_if = "Option::is_none")]
393    pub classifier_overhead_usd: Option<f64>,
394}
395
396#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
397#[serde(rename_all = "camelCase")]
398pub struct InferenceRoutingCostEstimate {
399    pub selection: ModelSelection,
400    pub prompt_cost_usd: f64,
401    pub completion_cost_usd: f64,
402    pub total_cost_usd: f64,
403    pub price_source: String,
404    pub usage_source: String,
405    #[serde(default)]
406    pub incomplete: bool,
407}
408
409#[async_trait::async_trait]
410pub trait InferenceRouter: Send + Sync + 'static {
411    fn id(&self) -> InferenceRouterId;
412
413    fn routing_options(&self) -> Vec<InferenceRoutingOptionDescriptor> {
414        Vec::new()
415    }
416
417    async fn route(
418        &self,
419        context: InferenceRoutingContext,
420    ) -> anyhow::Result<InferenceRoutingDecision>;
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use crate::inference::{ProviderAuthType, ReasoningEffortDescriptor};

    #[test]
    fn routing_context_serializes_camel_case_fields() {
        let context = InferenceRoutingContext {
            thread_id: "thread-1".to_string(),
            turn_id: "turn-1".to_string(),
            round_index: 2,
            runtime_profile: RuntimeProfile::Interactive,
            default_selection: ModelSelection {
                provider: "openai".to_string(),
                model: "gpt-5.4".to_string(),
            },
            requested_selection: None,
            phase: Some(SpeedPolicyPhase::Verification),
            transcript: InferenceRoutingTranscriptSummary {
                item_count: 3,
                has_image_input: true,
                latest_user_message_preview: Some("review auth changes".to_string()),
                ..InferenceRoutingTranscriptSummary::default()
            },
            tools: InferenceRoutingToolSummary {
                available_count: 8,
                requires_tool_calls: true,
                ..InferenceRoutingToolSummary::default()
            },
            candidates: vec![candidate("openai", "gpt-5.4")],
            signals: vec![InferenceRoutingSignal::new("phase", "verification")],
            prior_failures: 1,
            prior_escalations: 0,
            estimated_input_tokens: Some(4096),
        };

        let value = serde_json::to_value(&context).expect("serialize context");

        assert_eq!(value["threadId"], "thread-1");
        assert_eq!(value["turnId"], "turn-1");
        assert_eq!(value["roundIndex"], 2);
        assert_eq!(value["defaultSelection"]["provider"], "openai");
        assert_eq!(value["phase"], "verification");
        assert_eq!(value["transcript"]["hasImageInput"], true);
        assert_eq!(
            value["transcript"]["latestUserMessagePreview"],
            "review auth changes"
        );
        assert_eq!(value["tools"]["requiresToolCalls"], true);
        assert_eq!(value["estimatedInputTokens"], 4096);
        assert_eq!(value["candidates"][0]["selection"]["model"], "gpt-5.4");
    }

    #[test]
    fn routing_decision_serializes_selected_abstain_and_fallback() {
        let selected = InferenceRoutingDecision {
            reasoning: Some(ReasoningConfig {
                enabled: true,
                level: Some("low".to_string()),
            }),
            confidence: Some(0.82),
            matched_signals: vec![InferenceRoutingSignal::new("intent", "file_lookup")],
            baseline: Some(ModelSelection {
                provider: "openai".to_string(),
                model: "gpt-5.4".to_string(),
            }),
            ..InferenceRoutingDecision::selected(
                "local-router",
                ModelSelection {
                    provider: "openai".to_string(),
                    model: "gpt-5.4-mini".to_string(),
                },
                "routine lookup",
            )
        };
        let selected_value = serde_json::to_value(selected).expect("serialize selected decision");

        assert_eq!(selected_value["routerId"], "local-router");
        assert_eq!(selected_value["outcome"], "selected");
        assert_eq!(selected_value["selected"]["model"], "gpt-5.4-mini");
        assert_eq!(selected_value["reasoning"]["level"], "low");
        assert_eq!(selected_value["matchedSignals"][0]["key"], "intent");

        let abstain = serde_json::to_value(InferenceRoutingDecision::abstain(
            "local-router",
            "no safe candidate",
        ))
        .expect("serialize abstain decision");
        assert_eq!(abstain["outcome"], "abstained");
        assert_eq!(abstain["reason"], "no safe candidate");
        assert!(abstain.get("selected").is_none());

        let fallback = serde_json::to_value(InferenceRoutingDecision::fallback(
            "local-router",
            "invalid router decision",
        ))
        .expect("serialize fallback decision");
        assert_eq!(fallback["outcome"], "fallback");
        assert_eq!(fallback["reason"], "invalid router decision");
        assert!(fallback.get("selected").is_none());
    }

    #[test]
    fn routing_decision_escalated_keeps_selection() {
        let escalated = serde_json::to_value(InferenceRoutingDecision::escalated(
            "local-router",
            selection("openai", "gpt-5.4"),
            "prior failure",
        ))
        .expect("serialize escalated decision");

        assert_eq!(escalated["routerId"], "local-router");
        assert_eq!(escalated["outcome"], "escalated");
        assert_eq!(escalated["selected"]["provider"], "openai");
        assert_eq!(escalated["selected"]["model"], "gpt-5.4");
        assert_eq!(escalated["reason"], "prior failure");
        assert!(escalated.get("confidence").is_none());
        assert!(escalated.get("costDelta").is_none());
        assert!(escalated["metadata"].is_null());
    }

    #[test]
    fn routing_option_descriptor_round_trips_with_selection_mode() {
        let option = InferenceRoutingOptionDescriptor {
            profile: Some("coding".to_string()),
            objective: Some("minimize latency without losing code quality".to_string()),
            reasoning: Some("low".to_string()),
            metadata: serde_json::json!({ "source": "test" }),
            ..InferenceRoutingOptionDescriptor::selectable(
                "local-router:coding",
                "Auto: Coding",
                "local-router",
                ModelSelection {
                    provider: "codex".to_string(),
                    model: "gpt-5.5".to_string(),
                },
            )
        };

        let value = serde_json::to_value(&option).expect("serialize routing option");

        assert_eq!(value["id"], "local-router:coding");
        assert_eq!(value["label"], "Auto: Coding");
        assert_eq!(value["routerId"], "local-router");
        assert_eq!(value["baseline"]["provider"], "codex");
        assert_eq!(value["baseline"]["model"], "gpt-5.5");
        assert_eq!(value["profile"], "coding");
        assert_eq!(
            value["objective"],
            "minimize latency without losing code quality"
        );
        assert_eq!(value["reasoning"], "low");
        assert_eq!(value["available"], true);

        let round_trip: InferenceRoutingOptionDescriptor =
            serde_json::from_value(value).expect("deserialize routing option");
        assert_eq!(round_trip, option);

        assert_eq!(
            round_trip.selection_mode(),
            ModelSelectionMode::Auto {
                option_id: "local-router:coding".to_string(),
                router_id: "local-router".to_string(),
                label: "Auto: Coding".to_string(),
                baseline: ModelSelection {
                    provider: "codex".to_string(),
                    model: "gpt-5.5".to_string(),
                },
                profile: Some("coding".to_string()),
                reasoning: Some("low".to_string()),
            }
        );
    }

    #[test]
    fn unavailable_routing_option_round_trips() {
        let option = InferenceRoutingOptionDescriptor::selectable(
            "local-router:coding",
            "Auto: Coding",
            "local-router",
            selection("codex", "gpt-5.5"),
        )
        .unavailable("missing api key");

        let value = serde_json::to_value(&option).expect("serialize unavailable option");
        assert_eq!(value["available"], false);
        assert_eq!(value["unavailableReason"], "missing api key");
        assert!(value.get("profile").is_none());

        let round_trip: InferenceRoutingOptionDescriptor =
            serde_json::from_value(value).expect("deserialize unavailable option");
        assert_eq!(round_trip, option);
    }

    #[test]
    fn routing_option_descriptor_defaults_missing_fields() {
        // `available` uses `default_true`, so an omitted key means "selectable".
        let parsed: InferenceRoutingOptionDescriptor = serde_json::from_value(serde_json::json!({
            "id": "local-router:coding",
            "label": "Auto: Coding",
            "routerId": "local-router",
            "baseline": { "provider": "codex", "model": "gpt-5.5" },
        }))
        .expect("deserialize minimal routing option");

        assert!(parsed.available);
        assert!(parsed.unavailable_reason.is_none());
        assert!(parsed.profile.is_none());
        assert!(parsed.metadata.is_null());
    }

    #[test]
    fn model_selection_mode_uses_internal_type_tag() {
        let manual = ModelSelectionMode::manual("openai", "gpt-5.4", None);
        let manual_value = serde_json::to_value(&manual).expect("serialize manual mode");
        assert_eq!(manual_value["type"], "manual");
        assert_eq!(manual_value["provider"], "openai");
        assert_eq!(manual_value["model"], "gpt-5.4");
        assert!(manual_value.get("reasoning").is_none());
        assert!(!manual.is_auto());
        assert_eq!(manual.reasoning(), None);
        assert_eq!(manual.concrete_selection(), selection("openai", "gpt-5.4"));
        let manual_round_trip: ModelSelectionMode =
            serde_json::from_value(manual_value).expect("deserialize manual mode");
        assert_eq!(manual_round_trip, manual);

        let auto = ModelSelectionMode::auto(
            "local-router:coding",
            "local-router",
            "Auto: Coding",
            selection("codex", "gpt-5.5"),
            None,
            Some("high".to_string()),
        );
        let auto_value = serde_json::to_value(&auto).expect("serialize auto mode");
        assert_eq!(auto_value["type"], "auto");
        // Pins the current wire format: a container-level `rename_all` on an enum
        // renames only the variants, so struct-variant fields stay snake_case.
        assert_eq!(auto_value["router_id"], "local-router");
        assert!(auto.is_auto());
        assert_eq!(auto.reasoning(), Some("high"));
        assert_eq!(auto.concrete_selection(), selection("codex", "gpt-5.5"));
        let auto_round_trip: ModelSelectionMode =
            serde_json::from_value(auto_value).expect("deserialize auto mode");
        assert_eq!(auto_round_trip, auto);
    }

    fn selection(provider: &str, model: &str) -> ModelSelection {
        ModelSelection {
            provider: provider.to_string(),
            model: model.to_string(),
        }
    }

    fn candidate(provider: &str, model: &str) -> InferenceRoutingCandidate {
        InferenceRoutingCandidate::available(
            ModelSelection {
                provider: provider.to_string(),
                model: model.to_string(),
            },
            InferenceProviderMetadata {
                name: provider.to_string(),
                description: None,
                auth_type: ProviderAuthType::ApiKey,
                auth_label: Some("API key".to_string()),
                auth_configured: Some(true),
                recommended: true,
                sort_order: 10,
            },
            ModelDescriptor {
                id: model.to_string(),
                name: model.to_string(),
                context_window: Some(128_000),
                default_reasoning: Some("medium".to_string()),
                supported_reasoning: vec![ReasoningEffortDescriptor {
                    effort: "low".to_string(),
                    description: "Low".to_string(),
                }],
            },
            InferenceCapabilities::coding_agent_default(),
        )
    }
}