use std::sync::Arc;
use tracing::debug;
use crate::provider::{AgentConfig, AgentProvider, InvokeFuture, LogSink};
/// Predicate over an agent configuration's model name, used to pick a provider.
///
/// Derives `PartialEq`/`Eq` so matchers can be compared directly (for example
/// in assertions or when de-duplicating routes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProviderMatcher {
    /// Matches when the configured model name starts with the given string.
    ModelPrefix(String),
    /// Matches only when the configured model name equals the given string.
    ModelExact(String),
}
impl ProviderMatcher {
    /// Reports whether `config.model` satisfies this matcher.
    ///
    /// `ModelExact` requires the model name to equal the stored string, while
    /// `ModelPrefix` only requires the model name to start with it.
    fn matches(&self, config: &AgentConfig) -> bool {
        let model = &config.model;
        match self {
            Self::ModelExact(expected) => model == expected,
            Self::ModelPrefix(head) => model.starts_with(head.as_str()),
        }
    }
}
/// Dispatches agent invocations to one of several providers based on the
/// model named in the request's `AgentConfig`.
pub struct ProviderRouter {
// Ordered (matcher, provider) pairs; lookup scans them front to back and
// stops at the first matcher that accepts the config.
routes: Vec<(ProviderMatcher, Arc<dyn AgentProvider>)>,
// Provider used when none of the routes match.
fallback: Arc<dyn AgentProvider>,
}
impl ProviderRouter {
    /// Creates a router with no routes; every request goes to `fallback`
    /// until routes are added with [`ProviderRouter::route`].
    #[must_use]
    pub fn new(fallback: Arc<dyn AgentProvider>) -> Self {
        Self {
            routes: Vec::new(),
            fallback,
        }
    }

    /// Appends a route and returns the updated router (builder style).
    ///
    /// Routes are checked in the order they were added, so when several
    /// matchers accept the same config the earliest-added one wins.
    ///
    /// This method consumes `self`; discarding the return value drops the
    /// router, hence the `#[must_use]`.
    #[must_use]
    pub fn route(mut self, matcher: ProviderMatcher, provider: Arc<dyn AgentProvider>) -> Self {
        self.routes.push((matcher, provider));
        self
    }

    /// Picks the provider for `config`: the first route whose matcher accepts
    /// it, or the fallback provider when no route matches.
    ///
    /// Emits a `debug!` event describing which path was taken.
    fn resolve(&self, config: &AgentConfig) -> &Arc<dyn AgentProvider> {
        let hit = self
            .routes
            .iter()
            .find(|(matcher, _)| matcher.matches(config));

        match hit {
            Some((matcher, provider)) => {
                debug!(
                    model = %config.model,
                    matcher = ?matcher,
                    "routed to matched provider"
                );
                provider
            }
            None => {
                debug!(model = %config.model, "using fallback provider");
                &self.fallback
            }
        }
    }
}
impl AgentProvider for ProviderRouter {
    /// Forwards the invocation to whichever provider `resolve` selects for
    /// `config`.
    fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
        self.resolve(config).invoke(config)
    }

    /// Same routing as `invoke`, but hands `log_sink` through to the selected
    /// provider's `invoke_with_logs`.
    fn invoke_with_logs<'a>(
        &'a self,
        config: &'a AgentConfig,
        log_sink: Arc<dyn LogSink>,
    ) -> InvokeFuture<'a> {
        self.resolve(config).invoke_with_logs(config, log_sink)
    }
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use serde_json::json;

    use super::*;
    use crate::provider::AgentOutput;

    /// Test double: counts `invoke` calls and answers with its own name as a
    /// JSON string, so a test can tell which provider handled a request.
    struct CountingProvider {
        name: &'static str,
        count: AtomicUsize,
    }

    impl CountingProvider {
        /// Builds a shared provider whose call counter starts at zero.
        fn new(name: &'static str) -> Arc<Self> {
            Arc::new(Self {
                name,
                count: AtomicUsize::new(0),
            })
        }

        /// Number of times `invoke` has been called so far.
        fn call_count(&self) -> usize {
            self.count.load(Ordering::Relaxed)
        }
    }

    impl AgentProvider for CountingProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            self.count.fetch_add(1, Ordering::Relaxed);
            let name = self.name;
            Box::pin(async move { Ok(AgentOutput::new(json!(name))) })
        }
    }

    /// With no routes registered, every request reaches the fallback.
    #[tokio::test]
    async fn router_fallback_when_no_routes() {
        let fallback = CountingProvider::new("fallback");
        let router = ProviderRouter::new(fallback.clone());

        let config = AgentConfig::new("hello");
        let output = router.invoke(&config).await.expect("should succeed");

        assert_eq!(output.value, json!("fallback"));
        assert_eq!(fallback.call_count(), 1);
    }

    /// A prefix matcher routes models that start with the prefix.
    #[tokio::test]
    async fn router_matches_model_prefix() {
        let fallback = CountingProvider::new("fallback");
        let nvidia = CountingProvider::new("nvidia");
        let router = ProviderRouter::new(fallback.clone()).route(
            ProviderMatcher::ModelPrefix("nvidia/".into()),
            nvidia.clone(),
        );

        let config = AgentConfig::new("hello").model("nvidia/deepseek-v4-flash");
        let output = router.invoke(&config).await.expect("should succeed");

        assert_eq!(output.value, json!("nvidia"));
        assert_eq!(nvidia.call_count(), 1);
        assert_eq!(fallback.call_count(), 0);
    }

    /// An exact matcher routes a model whose name equals the stored string.
    #[tokio::test]
    async fn router_matches_model_exact() {
        let fallback = CountingProvider::new("fallback");
        let special = CountingProvider::new("special");
        let router = ProviderRouter::new(fallback.clone()).route(
            ProviderMatcher::ModelExact("my-model".into()),
            special.clone(),
        );

        let config = AgentConfig::new("hello").model("my-model");
        let output = router.invoke(&config).await.expect("should succeed");

        assert_eq!(output.value, json!("special"));
        assert_eq!(special.call_count(), 1);
        // The matched route must fully replace the fallback, not run alongside it.
        assert_eq!(fallback.call_count(), 0);
    }

    /// An exact matcher must not behave like a prefix matcher.
    #[tokio::test]
    async fn router_exact_does_not_match_prefix() {
        let fallback = CountingProvider::new("fallback");
        let special = CountingProvider::new("special");
        let router = ProviderRouter::new(fallback.clone()).route(
            ProviderMatcher::ModelExact("nvidia".into()),
            special.clone(),
        );

        let config = AgentConfig::new("hello").model("nvidia/something");
        let output = router.invoke(&config).await.expect("should succeed");

        assert_eq!(output.value, json!("fallback"));
        assert_eq!(special.call_count(), 0);
        assert_eq!(fallback.call_count(), 1);
    }

    /// When two routes accept the same model, the one added first handles it.
    #[tokio::test]
    async fn router_first_match_wins() {
        let fallback = CountingProvider::new("fallback");
        let first = CountingProvider::new("first");
        let second = CountingProvider::new("second");
        let router = ProviderRouter::new(fallback.clone())
            .route(
                ProviderMatcher::ModelPrefix("nvidia/".into()),
                first.clone(),
            )
            .route(
                ProviderMatcher::ModelPrefix("nvidia/".into()),
                second.clone(),
            );

        let config = AgentConfig::new("hello").model("nvidia/test");
        let output = router.invoke(&config).await.expect("should succeed");

        assert_eq!(output.value, json!("first"));
        assert_eq!(first.call_count(), 1);
        assert_eq!(second.call_count(), 0);
        assert_eq!(fallback.call_count(), 0);
    }

    /// Distinct prefixes dispatch to their own providers; anything else falls
    /// through to the fallback.
    #[tokio::test]
    async fn router_multiple_routes() {
        let fallback = CountingProvider::new("claude");
        let nvidia = CountingProvider::new("nvidia");
        let openai = CountingProvider::new("openai");
        let router = ProviderRouter::new(fallback.clone())
            .route(
                ProviderMatcher::ModelPrefix("nvidia/".into()),
                nvidia.clone(),
            )
            .route(ProviderMatcher::ModelPrefix("gpt-".into()), openai.clone());

        let config1 = AgentConfig::new("hello").model("nvidia/nemotron");
        let config2 = AgentConfig::new("hello").model("gpt-5.5");
        let config3 = AgentConfig::new("hello").model("sonnet");

        let out1 = router.invoke(&config1).await.expect("should succeed");
        let out2 = router.invoke(&config2).await.expect("should succeed");
        let out3 = router.invoke(&config3).await.expect("should succeed");

        assert_eq!(out1.value, json!("nvidia"));
        assert_eq!(out2.value, json!("openai"));
        assert_eq!(out3.value, json!("claude"));

        // Each request must reach exactly one provider, exactly once.
        assert_eq!(nvidia.call_count(), 1);
        assert_eq!(openai.call_count(), 1);
        assert_eq!(fallback.call_count(), 1);
    }
}