use crate::providers::ModelProvider;
use serde::{Deserialize, Serialize};
/// Logical model id meaning "let the router decide".
///
/// [`resolve_model_precedence`] returns it when no step, workflow, or agent model
/// is supplied, and [`ModelRouter::resolve`] replaces it with the provider's
/// default model (`provider.model()`).
pub const DEFAULT_MODEL_ROUTER_ID: &str = "enact/model-router";
/// Identifies which configuration level supplied the model for a request.
///
/// Produced by [`resolve_model_precedence`]. Serialized in `snake_case`
/// (`"step"`, `"workflow"`, `"agent"`, `"default_router"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum ModelSelectionSource {
    /// The model came from the step-level setting.
    Step,
    /// The model came from the workflow-level setting.
    Workflow,
    /// The model came from the agent-level setting.
    Agent,
    /// No level supplied a model; [`DEFAULT_MODEL_ROUTER_ID`] was selected.
    DefaultRouter,
}
/// Named routing profile; `Balanced` is the default variant.
///
/// [`ModelRouter::resolve`] only copies the profile onto
/// [`RoutingDecision::profile`]; it does not use it to pick a model.
/// NOTE(review): the intended effect of each variant is not defined here — confirm
/// with the consumers of `RoutingDecision`.
///
/// Serialized in `snake_case` (`"eco"`, `"balanced"`, `"quality"`, `"deterministic"`).
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum RoutingProfile {
    Eco,
    #[default]
    Balanced,
    Quality,
    Deterministic,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingPolicy {
pub profile: RoutingProfile,
pub default_confidence: f32,
}
impl Default for RoutingPolicy {
    /// Baseline policy: the default [`RoutingProfile`] variant and a confidence
    /// of `0.70`.
    fn default() -> Self {
        Self {
            // Delegate to the enum's own `#[default]` variant so the two
            // defaults cannot silently diverge if that variant ever changes.
            profile: RoutingProfile::default(),
            default_confidence: 0.70,
        }
    }
}
/// Outcome of [`ModelRouter::resolve`]: which model was chosen and why.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RoutingDecision {
    /// Model name produced by [`resolve_model_precedence`]; equals
    /// [`DEFAULT_MODEL_ROUTER_ID`] when the default router was selected.
    pub requested_model: String,
    /// Logical model name; [`ModelRouter::resolve`] sets it to the same value as
    /// `requested_model`.
    pub logical_model: String,
    /// Resolved model: the provider's default model when the default router was
    /// used, otherwise the pinned model name verbatim.
    pub concrete_model: String,
    /// Profile copied from [`RoutingPolicy::profile`].
    pub profile: RoutingProfile,
    /// Confidence copied from [`RoutingPolicy::default_confidence`].
    pub confidence: f32,
    /// `true` when the logical model was [`DEFAULT_MODEL_ROUTER_ID`] and the
    /// provider default was substituted for it.
    pub used_default_router: bool,
    /// Human-readable explanation of the decision.
    pub rationale: String,
    /// Configuration level that supplied the model.
    pub source: ModelSelectionSource,
}
/// Picks the model to use with precedence step > workflow > agent > default router.
///
/// A level is honoured only if it carries a non-blank value. `None` and
/// whitespace-only strings both fall through to the next level. When no level
/// supplies a model, returns [`DEFAULT_MODEL_ROUTER_ID`] with
/// [`ModelSelectionSource::DefaultRouter`].
///
/// A non-blank model string is returned verbatim (it is not trimmed).
pub fn resolve_model_precedence(
    step_model: Option<&str>,
    workflow_model: Option<&str>,
    agent_model: Option<&str>,
) -> (String, ModelSelectionSource) {
    // A blank value (e.g. `model: ""` in a config file) is not a real pin.
    // Treating it as one would shadow lower levels and hand an empty model name
    // to the provider, so treat it exactly like `None`.
    fn pinned(model: Option<&str>) -> Option<&str> {
        model.filter(|m| !m.trim().is_empty())
    }

    if let Some(model) = pinned(step_model) {
        return (model.to_string(), ModelSelectionSource::Step);
    }
    if let Some(model) = pinned(workflow_model) {
        return (model.to_string(), ModelSelectionSource::Workflow);
    }
    if let Some(model) = pinned(agent_model) {
        return (model.to_string(), ModelSelectionSource::Agent);
    }
    (
        DEFAULT_MODEL_ROUTER_ID.to_string(),
        ModelSelectionSource::DefaultRouter,
    )
}
/// Stateless entry point for model routing; see [`ModelRouter::resolve`].
#[derive(Debug, Clone, Copy, Default)]
pub struct ModelRouter;
impl ModelRouter {
pub fn resolve(
requested_model: Option<&str>,
provider: &dyn ModelProvider,
policy: &RoutingPolicy,
) -> RoutingDecision {
let (requested, source) = resolve_model_precedence(requested_model, None, None);
let used_default_router = requested == DEFAULT_MODEL_ROUTER_ID;
let concrete_model = if used_default_router {
provider.model().to_string()
} else {
requested.clone()
};
let rationale = if used_default_router {
format!(
"No explicit model pin provided; resolved '{}' to provider default '{}'",
DEFAULT_MODEL_ROUTER_ID, concrete_model
)
} else {
format!("Explicit model pin '{}' selected", concrete_model)
};
RoutingDecision {
requested_model: requested.clone(),
logical_model: requested,
concrete_model,
profile: policy.profile,
confidence: policy.default_confidence,
used_default_router,
rationale,
source,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{ChatRequest, ChatResponse};
    use async_trait::async_trait;

    /// Provider stub that reports a fixed default model; `chat` is never expected
    /// to be called by the router.
    struct MockProvider;

    #[async_trait]
    impl ModelProvider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }
        fn model(&self) -> &str {
            "gpt-4o-mini"
        }
        async fn chat(&self, _request: ChatRequest) -> anyhow::Result<ChatResponse> {
            anyhow::bail!("not used")
        }
    }

    #[test]
    fn resolves_to_default_router_when_unspecified() {
        let provider = MockProvider;
        let policy = RoutingPolicy::default();
        let decision = ModelRouter::resolve(None, &provider, &policy);
        assert_eq!(decision.requested_model, DEFAULT_MODEL_ROUTER_ID);
        assert_eq!(decision.logical_model, DEFAULT_MODEL_ROUTER_ID);
        assert_eq!(decision.concrete_model, "gpt-4o-mini");
        assert!(decision.used_default_router);
        assert_eq!(decision.source, ModelSelectionSource::DefaultRouter);
    }

    #[test]
    fn preserves_explicit_model_pin() {
        let provider = MockProvider;
        let policy = RoutingPolicy::default();
        let decision = ModelRouter::resolve(Some("anthropic/claude-sonnet-4"), &provider, &policy);
        assert_eq!(decision.logical_model, "anthropic/claude-sonnet-4");
        assert_eq!(decision.concrete_model, "anthropic/claude-sonnet-4");
        assert!(!decision.used_default_router);
        assert_eq!(decision.source, ModelSelectionSource::Step);
    }

    #[test]
    fn explicit_router_id_pin_still_uses_provider_default() {
        let provider = MockProvider;
        let policy = RoutingPolicy::default();
        let decision = ModelRouter::resolve(Some(DEFAULT_MODEL_ROUTER_ID), &provider, &policy);
        assert_eq!(decision.logical_model, DEFAULT_MODEL_ROUTER_ID);
        assert_eq!(decision.concrete_model, "gpt-4o-mini");
        assert!(decision.used_default_router);
        assert_eq!(decision.source, ModelSelectionSource::Step);
    }

    #[test]
    fn copies_profile_and_confidence_from_policy() {
        let provider = MockProvider;
        let policy = RoutingPolicy {
            profile: RoutingProfile::Quality,
            default_confidence: 0.25,
        };
        let decision = ModelRouter::resolve(Some("anthropic/claude-sonnet-4"), &provider, &policy);
        assert_eq!(decision.profile, RoutingProfile::Quality);
        assert_eq!(decision.confidence, 0.25);
    }

    #[test]
    fn model_precedence_order_is_step_then_workflow_then_agent_then_default() {
        let (model, source) = resolve_model_precedence(
            Some("step/model"),
            Some("workflow/model"),
            Some("agent/model"),
        );
        assert_eq!(model, "step/model");
        assert_eq!(source, ModelSelectionSource::Step);
        let (model, source) =
            resolve_model_precedence(None, Some("workflow/model"), Some("agent/model"));
        assert_eq!(model, "workflow/model");
        assert_eq!(source, ModelSelectionSource::Workflow);
        let (model, source) = resolve_model_precedence(None, None, Some("agent/model"));
        assert_eq!(model, "agent/model");
        assert_eq!(source, ModelSelectionSource::Agent);
        let (model, source) = resolve_model_precedence(None, None, None);
        assert_eq!(model, DEFAULT_MODEL_ROUTER_ID);
        assert_eq!(source, ModelSelectionSource::DefaultRouter);
    }
}