use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use bitrouter_core::{
errors::{BitrouterError, Result},
routers::routing_table::{RouteEntry, RoutingTable, RoutingTarget},
};
use crate::config::{ApiProtocol, ModelConfig, ProviderConfig, RoutingStrategy};
/// A concrete upstream target produced by [`ConfigRoutingTable::resolve`].
///
/// Besides the provider and model to call, it carries the optional per-endpoint
/// credential / base-URL overrides taken from the selected model endpoint.
#[derive(Debug, Clone)]
pub struct ResolvedTarget {
    /// Name of the provider the request should be sent to.
    pub provider_name: String,
    /// Model identifier to send to that provider.
    pub model_id: String,
    /// API key from the selected endpoint, if it sets one.
    /// Always `None` for direct `"<provider>:<model_id>"` routes.
    pub api_key_override: Option<String>,
    /// API base from the selected endpoint, if it sets one.
    /// Always `None` for direct `"<provider>:<model_id>"` routes.
    pub api_base_override: Option<String>,
}
/// A [`RoutingTable`] backed by static provider and model configuration.
///
/// Supports direct `"<provider>:<model_id>"` routing to a configured provider, and
/// named model aliases that map to one or more endpoints.
pub struct ConfigRoutingTable {
    /// Provider configuration keyed by provider name.
    providers: HashMap<String, ProviderConfig>,
    /// Model alias configuration keyed by the incoming model name.
    models: HashMap<String, ModelConfig>,
    /// One round-robin counter per model alias, used by the load-balance strategy.
    /// Built once in `new` from the keys of `models`, so every alias has a counter.
    counters: HashMap<String, AtomicUsize>,
}
impl ConfigRoutingTable {
    /// Creates a routing table from provider and model configuration.
    ///
    /// A round-robin counter starting at zero is allocated for every configured model, so
    /// load-balanced selection only ever needs shared (`&self`) access.
    pub fn new(
        providers: HashMap<String, ProviderConfig>,
        models: HashMap<String, ModelConfig>,
    ) -> Self {
        let counters = models
            .keys()
            .map(|k| (k.clone(), AtomicUsize::new(0)))
            .collect();
        Self {
            providers,
            models,
            counters,
        }
    }

    /// Returns the configured providers, keyed by provider name.
    pub fn providers(&self) -> &HashMap<String, ProviderConfig> {
        &self.providers
    }

    /// Resolves an incoming model name to a concrete upstream target.
    ///
    /// Resolution order:
    /// 1. **Direct routing**: `"<provider>:<model_id>"`, where `<provider>` is a configured
    ///    provider and `<model_id>` is non-empty. Only the first `:` splits, so the model
    ///    id may itself contain `:` or `/`. No overrides are applied.
    /// 2. **Model alias**: the whole string is looked up among the configured models, and
    ///    an endpoint is chosen according to the model's [`RoutingStrategy`].
    ///
    /// # Errors
    ///
    /// Returns an invalid-request error when neither rule matches, or when the matched
    /// model has no endpoints configured.
    pub fn resolve(&self, incoming: &str) -> Result<ResolvedTarget> {
        // Require a non-empty model id. Without this, "openai:" would be forwarded upstream
        // with an empty model name; instead it falls through to alias lookup.
        if let Some((prefix, suffix)) = incoming.split_once(':')
            && !suffix.is_empty()
            && self.providers.contains_key(prefix)
        {
            return Ok(ResolvedTarget {
                provider_name: prefix.to_owned(),
                model_id: suffix.to_owned(),
                api_key_override: None,
                api_base_override: None,
            });
        }

        match self.models.get(incoming) {
            Some(model_config) => self.select_endpoint(incoming, model_config),
            None => Err(BitrouterError::invalid_request(
                None,
                format!("no route found for model: {incoming}"),
                None,
            )),
        }
    }

    /// Picks one endpoint of `config` according to its routing strategy.
    ///
    /// - `Priority` always returns the first endpoint.
    /// - `LoadBalance` cycles through the endpoints in order, driven by the model's counter.
    ///
    /// `model_name` must be a key of `self.models`; it is used to find the counter.
    ///
    /// # Errors
    ///
    /// Returns an invalid-request error if `config` has no endpoints.
    fn select_endpoint(&self, model_name: &str, config: &ModelConfig) -> Result<ResolvedTarget> {
        let Some(first) = config.endpoints.first() else {
            return Err(BitrouterError::invalid_request(
                None,
                format!("model '{model_name}' has no configured endpoints"),
                None,
            ));
        };

        let endpoint = match config.strategy {
            RoutingStrategy::Priority => first,
            RoutingStrategy::LoadBalance => {
                // Invariant: `new` creates a counter for every key of `self.models`, and the
                // only caller (`resolve`) passes such a key.
                let counter = self
                    .counters
                    .get(model_name)
                    .expect("counter must exist for every model");
                // Relaxed is enough: the counter only needs atomic increments and publishes no
                // other data. `fetch_add` wraps on overflow, and the modulo keeps the index in
                // bounds because the endpoint list was checked non-empty above.
                let idx = counter.fetch_add(1, Ordering::Relaxed) % config.endpoints.len();
                &config.endpoints[idx]
            }
        };

        Ok(ResolvedTarget {
            provider_name: endpoint.provider.clone(),
            model_id: endpoint.model_id.clone(),
            api_key_override: endpoint.api_key.clone(),
            api_base_override: endpoint.api_base.clone(),
        })
    }
}
impl RoutingTable for ConfigRoutingTable {
    /// Resolves `incoming_model_name` via [`ConfigRoutingTable::resolve`].
    ///
    /// The endpoint-level overrides are dropped, because the core routing target does not
    /// carry them.
    async fn route(&self, incoming_model_name: &str) -> Result<RoutingTarget> {
        self.resolve(incoming_model_name)
            .map(|resolved| RoutingTarget {
                provider_name: resolved.provider_name,
                model_id: resolved.model_id,
            })
    }

    /// Lists one entry per configured model alias, sorted by model name.
    ///
    /// Each entry describes the alias's first endpoint only, and aliases without endpoints
    /// are omitted. The protocol comes from that endpoint's provider. It defaults to
    /// `"openai"` when the provider is not configured or sets no explicit protocol.
    fn list_routes(&self) -> Vec<RouteEntry> {
        let mut routes: Vec<RouteEntry> = self
            .models
            .iter()
            .filter_map(|(model_name, model_config)| {
                let primary = model_config.endpoints.first()?;
                let configured = self
                    .providers
                    .get(&primary.provider)
                    .and_then(|provider| provider.api_protocol.as_ref());
                let protocol = match configured {
                    Some(ApiProtocol::Anthropic) => "anthropic",
                    Some(ApiProtocol::Google) => "google",
                    Some(ApiProtocol::Openai) | None => "openai",
                };
                Some(RouteEntry {
                    model: model_name.clone(),
                    provider: primary.provider.clone(),
                    protocol: protocol.to_owned(),
                })
            })
            .collect();
        routes.sort_by(|lhs, rhs| lhs.model.cmp(&rhs.model));
        routes
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{ApiProtocol, ModelEndpoint};

    /// Two providers, "openai" and "anthropic", with explicit protocols and base URLs.
    fn test_providers() -> HashMap<String, ProviderConfig> {
        let openai = ProviderConfig {
            api_protocol: Some(ApiProtocol::Openai),
            api_base: Some("https://api.openai.com/v1".into()),
            ..Default::default()
        };
        let anthropic = ProviderConfig {
            api_protocol: Some(ApiProtocol::Anthropic),
            api_base: Some("https://api.anthropic.com".into()),
            ..Default::default()
        };
        HashMap::from([
            ("openai".to_owned(), openai),
            ("anthropic".to_owned(), anthropic),
        ])
    }

    /// A table containing only the test providers and no model aliases.
    fn providers_only_table() -> ConfigRoutingTable {
        ConfigRoutingTable::new(test_providers(), HashMap::new())
    }

    /// Builds a model endpoint with no `api_base` override.
    fn endpoint(provider: &str, model_id: &str, api_key: Option<&str>) -> ModelEndpoint {
        ModelEndpoint {
            provider: provider.to_owned(),
            model_id: model_id.to_owned(),
            api_key: api_key.map(str::to_owned),
            api_base: None,
        }
    }

    /// A table containing the test providers plus a single model alias.
    fn table_with_model(
        name: &str,
        strategy: RoutingStrategy,
        endpoints: Vec<ModelEndpoint>,
    ) -> ConfigRoutingTable {
        let models = HashMap::from([(name.to_owned(), ModelConfig { strategy, endpoints })]);
        ConfigRoutingTable::new(test_providers(), models)
    }

    #[test]
    fn direct_provider_routing() {
        let target = providers_only_table().resolve("openai:gpt-4o").unwrap();
        assert_eq!(target.provider_name, "openai");
        assert_eq!(target.model_id, "gpt-4o");
    }

    #[test]
    fn direct_provider_routing_with_slash_in_model() {
        let target = providers_only_table()
            .resolve("openai:deepseek/deepseek-v3")
            .unwrap();
        assert_eq!(target.provider_name, "openai");
        assert_eq!(target.model_id, "deepseek/deepseek-v3");
    }

    #[test]
    fn anthropic_direct_routing() {
        let target = providers_only_table()
            .resolve("anthropic:claude-opus-4-6")
            .unwrap();
        assert_eq!(target.provider_name, "anthropic");
        assert_eq!(target.model_id, "claude-opus-4-6");
    }

    #[test]
    fn unknown_provider_prefix_falls_through_to_models() {
        let table = table_with_model(
            "unknown:custom-model",
            RoutingStrategy::Priority,
            vec![endpoint("openai", "custom-model", None)],
        );
        let target = table.resolve("unknown:custom-model").unwrap();
        assert_eq!(target.provider_name, "openai");
        assert_eq!(target.model_id, "custom-model");
    }

    #[test]
    fn model_lookup_without_colon() {
        let table = table_with_model(
            "my-gpt4",
            RoutingStrategy::Priority,
            vec![endpoint("openai", "gpt-4o", Some("sk-override"))],
        );
        let target = table.resolve("my-gpt4").unwrap();
        assert_eq!(target.provider_name, "openai");
        assert_eq!(target.model_id, "gpt-4o");
        assert_eq!(target.api_key_override.as_deref(), Some("sk-override"));
    }

    #[test]
    fn slash_separator_does_not_match_provider() {
        assert!(providers_only_table().resolve("openai/gpt-4o").is_err());
    }

    #[test]
    fn load_balance_round_robin() {
        let table = table_with_model(
            "balanced",
            RoutingStrategy::LoadBalance,
            vec![
                endpoint("openai", "gpt-4o", Some("key-a")),
                endpoint("openai", "gpt-4o", Some("key-b")),
            ],
        );
        let picked: Vec<Option<String>> = (0..3)
            .map(|_| table.resolve("balanced").unwrap().api_key_override)
            .collect();
        let expected = ["key-a", "key-b", "key-a"].map(|key| Some(key.to_owned()));
        assert_eq!(picked, expected);
    }

    #[test]
    fn priority_always_picks_first() {
        let table = table_with_model(
            "primary",
            RoutingStrategy::Priority,
            vec![
                endpoint("openai", "gpt-4o", Some("primary-key")),
                endpoint("openai", "gpt-4o", Some("fallback-key")),
            ],
        );
        for attempt in 0..5 {
            let target = table.resolve("primary").unwrap();
            assert_eq!(
                target.api_key_override.as_deref(),
                Some("primary-key"),
                "attempt {attempt}"
            );
        }
    }

    #[test]
    fn no_route_found() {
        assert!(providers_only_table().resolve("nonexistent-model").is_err());
    }

    #[test]
    fn empty_endpoints_is_error() {
        let table = table_with_model("empty", RoutingStrategy::Priority, Vec::new());
        assert!(table.resolve("empty").is_err());
    }
}