1use serde::{Deserialize, Serialize};
2
3use crate::extension::InferenceRouterId;
4use crate::inference::{
5 InferenceCapabilities, InferenceProviderMetadata, ModelDescriptor, ModelSelection,
6 ReasoningConfig, RuntimeProfile, SpeedPolicyPhase,
7};
8
/// Describes how the model for a request is chosen: pinned by hand
/// ([`ModelSelectionMode::Manual`]) or tied to a router option
/// ([`ModelSelectionMode::Auto`]).
///
/// Serialized as an internally tagged enum: the variant name, converted to
/// camelCase (`"manual"` / `"auto"`), is stored in a `type` field next to the
/// variant's own fields.
///
/// NOTE(review): serde's container-level `rename_all` on an enum renames the
/// variants only, so fields such as `option_id` and `router_id` appear to keep
/// their snake_case names on the wire — confirm that this is intended.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum ModelSelectionMode {
    /// An explicitly chosen provider/model pair.
    Manual {
        /// Provider half of the selection.
        provider: String,
        /// Model half of the selection.
        model: String,
        /// Optional reasoning setting; defaults to `None` when missing and is
        /// left out of the serialized form when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning: Option<String>,
    },
    /// A selection that refers to a routing option and carries a concrete
    /// baseline selection alongside it.
    Auto {
        /// Identifier of the routing option this mode was created from.
        option_id: String,
        /// Identifier of the router associated with the option.
        router_id: InferenceRouterId,
        /// Label of the option (presumably for display — confirm with callers).
        label: String,
        /// Concrete selection returned by
        /// [`ModelSelectionMode::concrete_selection`] for this variant.
        baseline: ModelSelection,
        /// Optional profile name; defaults to `None` when missing and is
        /// left out of the serialized form when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        profile: Option<String>,
        /// Optional reasoning setting; defaults to `None` when missing and is
        /// left out of the serialized form when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning: Option<String>,
    },
}
29
30impl ModelSelectionMode {
31 pub fn manual(
32 provider: impl Into<String>,
33 model: impl Into<String>,
34 reasoning: Option<String>,
35 ) -> Self {
36 Self::Manual {
37 provider: provider.into(),
38 model: model.into(),
39 reasoning,
40 }
41 }
42
43 pub fn auto(
44 option_id: impl Into<String>,
45 router_id: impl Into<String>,
46 label: impl Into<String>,
47 baseline: ModelSelection,
48 profile: Option<String>,
49 reasoning: Option<String>,
50 ) -> Self {
51 Self::Auto {
52 option_id: option_id.into(),
53 router_id: router_id.into(),
54 label: label.into(),
55 baseline,
56 profile,
57 reasoning,
58 }
59 }
60
61 pub fn concrete_selection(&self) -> ModelSelection {
62 match self {
63 Self::Manual {
64 provider, model, ..
65 } => ModelSelection {
66 provider: provider.clone(),
67 model: model.clone(),
68 },
69 Self::Auto { baseline, .. } => baseline.clone(),
70 }
71 }
72
73 pub fn reasoning(&self) -> Option<&str> {
74 match self {
75 Self::Manual { reasoning, .. } | Self::Auto { reasoning, .. } => reasoning.as_deref(),
76 }
77 }
78
79 pub fn is_auto(&self) -> bool {
80 matches!(self, Self::Auto { .. })
81 }
82}
83
/// Describes one routing option, as returned by
/// `InferenceRouter::routing_options`.
///
/// Field names are serialized in camelCase (for example `routerId`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct InferenceRoutingOptionDescriptor {
    /// Identifier of the option; becomes `option_id` in
    /// [`InferenceRoutingOptionDescriptor::selection_mode`].
    pub id: String,
    /// Label of the option (presumably for display — confirm with callers).
    pub label: String,
    /// Identifier of the router associated with this option.
    pub router_id: InferenceRouterId,
    /// Concrete selection attached to the option.
    pub baseline: ModelSelection,
    /// Optional profile name; defaults to `None` and is omitted from output
    /// when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub profile: Option<String>,
    /// Optional free-form objective text; defaults to `None` and is omitted
    /// from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub objective: Option<String>,
    /// Optional reasoning setting; defaults to `None` and is omitted from
    /// output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
    /// Availability flag. A missing field deserializes to `true` (via
    /// `default_true`), not to the usual `bool` default of `false`.
    #[serde(default = "default_true")]
    pub available: bool,
    /// Optional explanation set together with `available = false` by
    /// [`InferenceRoutingOptionDescriptor::unavailable`]; omitted from output
    /// when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub unavailable_reason: Option<String>,
    /// Arbitrary JSON payload; a missing field deserializes to
    /// `serde_json::Value::Null`. Always serialized.
    #[serde(default)]
    pub metadata: serde_json::Value,
}
104
/// Serde default helper: supplies `true` for
/// `InferenceRoutingOptionDescriptor::available` when the field is missing
/// from the input.
fn default_true() -> bool {
    true
}
108
109impl InferenceRoutingOptionDescriptor {
110 pub fn selectable(
111 id: impl Into<String>,
112 label: impl Into<String>,
113 router_id: impl Into<String>,
114 baseline: ModelSelection,
115 ) -> Self {
116 Self {
117 id: id.into(),
118 label: label.into(),
119 router_id: router_id.into(),
120 baseline,
121 profile: None,
122 objective: None,
123 reasoning: None,
124 available: true,
125 unavailable_reason: None,
126 metadata: serde_json::Value::Null,
127 }
128 }
129
130 pub fn unavailable(mut self, reason: impl Into<String>) -> Self {
131 self.available = false;
132 self.unavailable_reason = Some(reason.into());
133 self
134 }
135
136 pub fn selection_mode(&self) -> ModelSelectionMode {
137 ModelSelectionMode::auto(
138 self.id.clone(),
139 self.router_id.clone(),
140 self.label.clone(),
141 self.baseline.clone(),
142 self.profile.clone(),
143 self.reasoning.clone(),
144 )
145 }
146}
147
/// Input passed by value to `InferenceRouter::route`.
///
/// Field names are serialized in camelCase (for example `threadId`,
/// `estimatedInputTokens`). Fields marked `#[serde(default)]` fall back to
/// their type's `Default` value when missing from the input.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct InferenceRoutingContext {
    /// Thread identifier (required).
    pub thread_id: String,
    /// Turn identifier (required).
    pub turn_id: String,
    /// Round index; `0` when missing.
    #[serde(default)]
    pub round_index: u32,
    /// Runtime profile; `RuntimeProfile::default()` when missing.
    #[serde(default)]
    pub runtime_profile: RuntimeProfile,
    /// Default selection (required).
    pub default_selection: ModelSelection,
    /// Optional requested selection; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub requested_selection: Option<ModelSelection>,
    /// Optional phase; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub phase: Option<SpeedPolicyPhase>,
    /// Transcript summary; all-default when missing.
    #[serde(default)]
    pub transcript: InferenceRoutingTranscriptSummary,
    /// Tool summary; all-default when missing.
    #[serde(default)]
    pub tools: InferenceRoutingToolSummary,
    /// Candidate list; empty when missing.
    #[serde(default)]
    pub candidates: Vec<InferenceRoutingCandidate>,
    /// Key/value signals; empty when missing.
    #[serde(default)]
    pub signals: Vec<InferenceRoutingSignal>,
    /// Count of prior failures; `0` when missing.
    #[serde(default)]
    pub prior_failures: u32,
    /// Count of prior escalations; `0` when missing.
    #[serde(default)]
    pub prior_escalations: u32,
    /// Optional input-token estimate; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub estimated_input_tokens: Option<u32>,
}
177
/// Summary counters and flags about a transcript, embedded in
/// [`InferenceRoutingContext`] as its `transcript` field.
///
/// Every field is optional on input and falls back to its `Default` value;
/// field names are serialized in camelCase.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct InferenceRoutingTranscriptSummary {
    /// Number of items; `0` when missing.
    #[serde(default)]
    pub item_count: u32,
    /// Number of user messages; `0` when missing.
    #[serde(default)]
    pub user_message_count: u32,
    /// Number of assistant messages; `0` when missing.
    #[serde(default)]
    pub assistant_message_count: u32,
    /// Number of tool results; `0` when missing.
    #[serde(default)]
    pub tool_result_count: u32,
    /// Image-input flag; `false` when missing.
    #[serde(default)]
    pub has_image_input: bool,
    /// Optional preview of the latest user message; omitted from output when
    /// `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub latest_user_message_preview: Option<String>,
    /// Recent tool names; omitted from output when the list is empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub recent_tool_names: Vec<String>,
    /// Optional approximate token count; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub approximate_tokens: Option<u32>,
}
198
/// Summary of tool availability, embedded in [`InferenceRoutingContext`] as
/// its `tools` field.
///
/// Every field is optional on input and falls back to `0` / `false`; field
/// names are serialized in camelCase.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct InferenceRoutingToolSummary {
    /// Number of available tools; `0` when missing.
    #[serde(default)]
    pub available_count: u32,
    /// File-tools flag; `false` when missing.
    #[serde(default)]
    pub has_file_tools: bool,
    /// Shell-tools flag; `false` when missing.
    #[serde(default)]
    pub has_shell_tools: bool,
    /// Network-tools flag; `false` when missing.
    #[serde(default)]
    pub has_network_tools: bool,
    /// Tool-calls-required flag; `false` when missing.
    #[serde(default)]
    pub requires_tool_calls: bool,
}
213
/// One candidate entry in [`InferenceRoutingContext`]'s `candidates` list:
/// a selection together with provider, model and capability descriptions.
///
/// Field names are serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct InferenceRoutingCandidate {
    /// Provider/model pair of this candidate.
    pub selection: ModelSelection,
    /// Provider metadata.
    pub provider: InferenceProviderMetadata,
    /// Model descriptor.
    pub model: ModelDescriptor,
    /// Capability description.
    pub capabilities: InferenceCapabilities,
    /// Availability flag; a missing field deserializes to `false`.
    ///
    /// NOTE(review): `InferenceRoutingOptionDescriptor::available` defaults
    /// to `true` instead — confirm the difference is deliberate.
    #[serde(default)]
    pub available: bool,
    /// Optional explanation for an unavailable candidate; omitted from
    /// output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub unavailable_reason: Option<String>,
}
226
227impl InferenceRoutingCandidate {
228 pub fn available(
229 selection: ModelSelection,
230 provider: InferenceProviderMetadata,
231 model: ModelDescriptor,
232 capabilities: InferenceCapabilities,
233 ) -> Self {
234 Self {
235 selection,
236 provider,
237 model,
238 capabilities,
239 available: true,
240 unavailable_reason: None,
241 }
242 }
243
244 pub fn unavailable(
245 selection: ModelSelection,
246 provider: InferenceProviderMetadata,
247 model: ModelDescriptor,
248 capabilities: InferenceCapabilities,
249 reason: impl Into<String>,
250 ) -> Self {
251 Self {
252 selection,
253 provider,
254 model,
255 capabilities,
256 available: false,
257 unavailable_reason: Some(reason.into()),
258 }
259 }
260}
261
/// A key/value signal, used in [`InferenceRoutingContext`]'s `signals` and in
/// [`InferenceRoutingDecision`]'s `matched_signals`.
///
/// Field names are serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct InferenceRoutingSignal {
    /// Signal key (required).
    pub key: String,
    /// Signal value (required).
    pub value: String,
    /// Optional source string; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source: Option<String>,
    /// Optional numeric weight; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub weight: Option<f64>,
}
272
273impl InferenceRoutingSignal {
274 pub fn new(key: impl Into<String>, value: impl Into<String>) -> Self {
275 Self {
276 key: key.into(),
277 value: value.into(),
278 source: None,
279 weight: None,
280 }
281 }
282}
283
/// Outcome tag carried by [`InferenceRoutingDecision`].
///
/// Serialized in snake_case (`"selected"`, `"escalated"`, `"fallback"`,
/// `"abstained"`). The `Default` value is [`InferenceRoutingOutcome::Abstained`].
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum InferenceRoutingOutcome {
    /// Set by `InferenceRoutingDecision::selected`.
    Selected,
    /// Set by `InferenceRoutingDecision::escalated`.
    Escalated,
    /// Set by `InferenceRoutingDecision::fallback`.
    Fallback,
    /// Set by `InferenceRoutingDecision::abstain`; also the default variant.
    #[default]
    Abstained,
}
293
/// Result returned by `InferenceRouter::route`.
///
/// Field names are serialized in camelCase (for example `routerId`,
/// `matchedSignals`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct InferenceRoutingDecision {
    /// Identifier of the router that produced the decision (required).
    pub router_id: InferenceRouterId,
    /// Outcome tag (required).
    pub outcome: InferenceRoutingOutcome,
    /// Optional chosen selection; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub selected: Option<ModelSelection>,
    /// Optional reasoning configuration; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<ReasoningConfig>,
    /// Free-form reason text (required).
    pub reason: String,
    /// Optional confidence value; omitted from output when `None`.
    /// NOTE(review): the valid range is not established here — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub confidence: Option<f64>,
    /// Matched signals; empty when missing. Always serialized.
    #[serde(default)]
    pub matched_signals: Vec<InferenceRoutingSignal>,
    /// Optional baseline selection; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub baseline: Option<ModelSelection>,
    /// Optional cost comparison; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cost_delta: Option<InferenceRoutingCostDelta>,
    /// Arbitrary JSON payload; a missing field deserializes to
    /// `serde_json::Value::Null`. Always serialized.
    #[serde(default)]
    pub metadata: serde_json::Value,
}
315
316impl InferenceRoutingDecision {
317 pub fn selected(
318 router_id: impl Into<String>,
319 selection: ModelSelection,
320 reason: impl Into<String>,
321 ) -> Self {
322 Self {
323 router_id: router_id.into(),
324 outcome: InferenceRoutingOutcome::Selected,
325 selected: Some(selection),
326 reasoning: None,
327 reason: reason.into(),
328 confidence: None,
329 matched_signals: Vec::new(),
330 baseline: None,
331 cost_delta: None,
332 metadata: serde_json::Value::Null,
333 }
334 }
335
336 pub fn abstain(router_id: impl Into<String>, reason: impl Into<String>) -> Self {
337 Self {
338 router_id: router_id.into(),
339 outcome: InferenceRoutingOutcome::Abstained,
340 selected: None,
341 reasoning: None,
342 reason: reason.into(),
343 confidence: None,
344 matched_signals: Vec::new(),
345 baseline: None,
346 cost_delta: None,
347 metadata: serde_json::Value::Null,
348 }
349 }
350
351 pub fn fallback(router_id: impl Into<String>, reason: impl Into<String>) -> Self {
352 Self {
353 router_id: router_id.into(),
354 outcome: InferenceRoutingOutcome::Fallback,
355 selected: None,
356 reasoning: None,
357 reason: reason.into(),
358 confidence: None,
359 matched_signals: Vec::new(),
360 baseline: None,
361 cost_delta: None,
362 metadata: serde_json::Value::Null,
363 }
364 }
365
366 pub fn escalated(
367 router_id: impl Into<String>,
368 selection: ModelSelection,
369 reason: impl Into<String>,
370 ) -> Self {
371 Self {
372 router_id: router_id.into(),
373 outcome: InferenceRoutingOutcome::Escalated,
374 selected: Some(selection),
375 reasoning: None,
376 reason: reason.into(),
377 confidence: None,
378 matched_signals: Vec::new(),
379 baseline: None,
380 cost_delta: None,
381 metadata: serde_json::Value::Null,
382 }
383 }
384}
385
/// Cost comparison between a selected and a baseline estimate, carried in
/// [`InferenceRoutingDecision`]'s `cost_delta` field.
///
/// Field names are serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct InferenceRoutingCostDelta {
    /// Cost estimate for the selected model.
    pub selected_estimate: InferenceRoutingCostEstimate,
    /// Cost estimate for the baseline model.
    pub baseline_estimate: InferenceRoutingCostEstimate,
    /// Estimated savings in USD.
    /// NOTE(review): the sign convention and how this relates to the two
    /// estimates are not established in this file — confirm with the producer.
    pub estimated_savings_usd: f64,
    /// Optional classifier overhead in USD; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub classifier_overhead_usd: Option<f64>,
}
395
/// Cost estimate for a single selection, used by
/// [`InferenceRoutingCostDelta`].
///
/// Field names are serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct InferenceRoutingCostEstimate {
    /// Selection the estimate applies to.
    pub selection: ModelSelection,
    /// Prompt cost in USD.
    pub prompt_cost_usd: f64,
    /// Completion cost in USD.
    pub completion_cost_usd: f64,
    /// Total cost in USD (presumably prompt + completion — not enforced by
    /// this type; confirm with the producer).
    pub total_cost_usd: f64,
    /// Free-form description of where the price data came from.
    pub price_source: String,
    /// Free-form description of where the usage data came from.
    pub usage_source: String,
    /// Incompleteness flag; `false` when missing.
    #[serde(default)]
    pub incomplete: bool,
}
408
/// A pluggable router that turns an [`InferenceRoutingContext`] into an
/// [`InferenceRoutingDecision`].
///
/// Implementors must be `Send + Sync + 'static`. The trait uses
/// `async_trait` so that `route` can be declared as an `async fn`.
#[async_trait::async_trait]
pub trait InferenceRouter: Send + Sync + 'static {
    /// Returns this router's identifier.
    fn id(&self) -> InferenceRouterId;

    /// Returns the routing options this router offers.
    ///
    /// The default implementation returns an empty list.
    fn routing_options(&self) -> Vec<InferenceRoutingOptionDescriptor> {
        Vec::new()
    }

    /// Produces a routing decision for `context`, which is taken by value.
    ///
    /// # Errors
    ///
    /// Returns an `anyhow::Error` when the implementation fails; the
    /// conditions are implementation-specific.
    async fn route(
        &self,
        context: InferenceRoutingContext,
    ) -> anyhow::Result<InferenceRoutingDecision>;
}
422
// Serialization-focused unit tests for the routing types in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::inference::{ProviderAuthType, ReasoningEffortDescriptor};

    /// Serializes a fully populated context and checks that top-level and
    /// nested keys come out in camelCase.
    #[test]
    fn routing_context_serializes_camel_case_fields() {
        let context = InferenceRoutingContext {
            thread_id: "thread-1".to_string(),
            turn_id: "turn-1".to_string(),
            round_index: 2,
            runtime_profile: RuntimeProfile::Interactive,
            default_selection: ModelSelection {
                provider: "openai".to_string(),
                model: "gpt-5.4".to_string(),
            },
            requested_selection: None,
            phase: Some(SpeedPolicyPhase::Verification),
            transcript: InferenceRoutingTranscriptSummary {
                item_count: 3,
                has_image_input: true,
                latest_user_message_preview: Some("review auth changes".to_string()),
                ..InferenceRoutingTranscriptSummary::default()
            },
            tools: InferenceRoutingToolSummary {
                available_count: 8,
                requires_tool_calls: true,
                ..InferenceRoutingToolSummary::default()
            },
            candidates: vec![candidate("openai", "gpt-5.4")],
            signals: vec![InferenceRoutingSignal::new("phase", "verification")],
            prior_failures: 1,
            prior_escalations: 0,
            estimated_input_tokens: Some(4096),
        };

        let value = serde_json::to_value(&context).expect("serialize context");

        assert_eq!(value["threadId"], "thread-1");
        assert_eq!(value["turnId"], "turn-1");
        assert_eq!(value["roundIndex"], 2);
        assert_eq!(value["defaultSelection"]["provider"], "openai");
        assert_eq!(value["phase"], "verification");
        assert_eq!(value["transcript"]["hasImageInput"], true);
        assert_eq!(
            value["transcript"]["latestUserMessagePreview"],
            "review auth changes"
        );
        assert_eq!(value["tools"]["requiresToolCalls"], true);
        assert_eq!(value["estimatedInputTokens"], 4096);
        assert_eq!(value["candidates"][0]["selection"]["model"], "gpt-5.4");
    }

    /// Checks the JSON shape of decisions built via `selected`, `abstain`
    /// and `fallback`, including the snake_case outcome tags.
    #[test]
    fn routing_decision_serializes_selected_abstain_and_fallback() {
        // Start from the `selected` constructor and override the optional
        // fields with struct-update syntax.
        let selected = InferenceRoutingDecision {
            reasoning: Some(ReasoningConfig {
                enabled: true,
                level: Some("low".to_string()),
            }),
            confidence: Some(0.82),
            matched_signals: vec![InferenceRoutingSignal::new("intent", "file_lookup")],
            baseline: Some(ModelSelection {
                provider: "openai".to_string(),
                model: "gpt-5.4".to_string(),
            }),
            ..InferenceRoutingDecision::selected(
                "local-router",
                ModelSelection {
                    provider: "openai".to_string(),
                    model: "gpt-5.4-mini".to_string(),
                },
                "routine lookup",
            )
        };
        let selected_value = serde_json::to_value(selected).expect("serialize selected decision");

        assert_eq!(selected_value["routerId"], "local-router");
        assert_eq!(selected_value["outcome"], "selected");
        assert_eq!(selected_value["selected"]["model"], "gpt-5.4-mini");
        assert_eq!(selected_value["reasoning"]["level"], "low");
        assert_eq!(selected_value["matchedSignals"][0]["key"], "intent");

        let abstain = serde_json::to_value(InferenceRoutingDecision::abstain(
            "local-router",
            "no safe candidate",
        ))
        .expect("serialize abstain decision");
        assert_eq!(abstain["outcome"], "abstained");
        assert_eq!(abstain["reason"], "no safe candidate");
        // `selected` is `None` here, so the key must be absent altogether.
        assert!(abstain.get("selected").is_none());

        let fallback = serde_json::to_value(InferenceRoutingDecision::fallback(
            "local-router",
            "invalid router decision",
        ))
        .expect("serialize fallback decision");
        assert_eq!(fallback["outcome"], "fallback");
        assert_eq!(fallback["reason"], "invalid router decision");
    }

    /// Serializes an option descriptor, deserializes it back, and checks that
    /// `selection_mode` yields the matching `Auto` variant.
    #[test]
    fn routing_option_descriptor_round_trips_with_selection_mode() {
        let option = InferenceRoutingOptionDescriptor {
            profile: Some("coding".to_string()),
            objective: Some("minimize latency without losing code quality".to_string()),
            reasoning: Some("low".to_string()),
            metadata: serde_json::json!({ "source": "test" }),
            ..InferenceRoutingOptionDescriptor::selectable(
                "local-router:coding",
                "Auto: Coding",
                "local-router",
                ModelSelection {
                    provider: "codex".to_string(),
                    model: "gpt-5.5".to_string(),
                },
            )
        };

        let value = serde_json::to_value(&option).expect("serialize routing option");

        assert_eq!(value["id"], "local-router:coding");
        assert_eq!(value["label"], "Auto: Coding");
        assert_eq!(value["routerId"], "local-router");
        assert_eq!(value["baseline"]["provider"], "codex");
        assert_eq!(value["baseline"]["model"], "gpt-5.5");
        assert_eq!(value["profile"], "coding");
        assert_eq!(
            value["objective"],
            "minimize latency without losing code quality"
        );
        assert_eq!(value["reasoning"], "low");
        assert_eq!(value["available"], true);

        let round_trip: InferenceRoutingOptionDescriptor =
            serde_json::from_value(value).expect("deserialize routing option");
        assert_eq!(round_trip, option);

        assert_eq!(
            round_trip.selection_mode(),
            ModelSelectionMode::Auto {
                option_id: "local-router:coding".to_string(),
                router_id: "local-router".to_string(),
                label: "Auto: Coding".to_string(),
                baseline: ModelSelection {
                    provider: "codex".to_string(),
                    model: "gpt-5.5".to_string(),
                },
                profile: Some("coding".to_string()),
                reasoning: Some("low".to_string()),
            }
        );
    }

    /// Test fixture: builds an available candidate for the given provider and
    /// model names with fixed provider metadata and model descriptor values.
    fn candidate(provider: &str, model: &str) -> InferenceRoutingCandidate {
        InferenceRoutingCandidate::available(
            ModelSelection {
                provider: provider.to_string(),
                model: model.to_string(),
            },
            InferenceProviderMetadata {
                name: provider.to_string(),
                description: None,
                auth_type: ProviderAuthType::ApiKey,
                auth_label: Some("API key".to_string()),
                auth_configured: Some(true),
                recommended: true,
                sort_order: 10,
            },
            ModelDescriptor {
                id: model.to_string(),
                name: model.to_string(),
                context_window: Some(128_000),
                default_reasoning: Some("medium".to_string()),
                supported_reasoning: vec![ReasoningEffortDescriptor {
                    effort: "low".to_string(),
                    description: "Low".to_string(),
                }],
            },
            InferenceCapabilities::coding_agent_default(),
        )
    }
}