use std::{
collections::HashMap,
net::{IpAddr, Ipv4Addr, SocketAddr},
path::{Path, PathBuf},
};
use serde::{Deserialize, Serialize};
use crate::env::{load_env, substitute_in_value};
use crate::registry::{builtin_providers, merge_provider, resolve_providers};
/// Top-level configuration document.
///
/// Every section is optional in the input: the container-level `#[serde(default)]`
/// fills any missing field from `BitrouterConfig::default()`.
///
/// NOTE(review): the derived `Debug` prints `master_key` (and provider keys) verbatim —
/// avoid logging this type with `{:?}`.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct BitrouterConfig {
    /// Listener, control-socket and logging settings.
    pub server: ServerConfig,
    /// Optional master key; omitted from serialized output when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub master_key: Option<String>,
    /// Database connection settings.
    pub database: DatabaseConfig,
    /// Provider definitions keyed by provider name.
    pub providers: HashMap<String, ProviderConfig>,
    /// Model routing tables keyed by model alias.
    pub models: HashMap<String, ModelConfig>,
}
impl BitrouterConfig {
pub fn has_configured_providers(&self) -> bool {
self.providers.values().any(|p| p.api_key.is_some())
}
pub fn configured_provider_names(&self) -> Vec<String> {
let mut names: Vec<String> = self
.providers
.iter()
.filter(|(_, p)| p.api_key.is_some())
.map(|(name, _)| name.clone())
.collect();
names.sort();
names
}
pub fn load_from_file(path: &Path, env_file: Option<&Path>) -> crate::error::Result<Self> {
let raw =
std::fs::read_to_string(path).map_err(|e| crate::error::ConfigError::ConfigRead {
path: path.to_path_buf(),
source: e,
})?;
Self::load_from_str(&raw, env_file)
}
pub fn load_from_str(raw: &str, env_file: Option<&Path>) -> crate::error::Result<Self> {
let env = load_env(env_file);
let yaml_value: serde_yaml::Value = serde_yaml::from_str(raw)
.map_err(|e| crate::error::ConfigError::ConfigParse(e.to_string()))?;
let substituted = substitute_in_value(yaml_value, &env);
let mut config: BitrouterConfig = serde_yaml::from_value(substituted)
.map_err(|e| crate::error::ConfigError::ConfigParse(e.to_string()))?;
let mut providers = builtin_providers();
for (name, user_provider) in config.providers.drain() {
if let Some(existing) = providers.get_mut(&name) {
merge_provider(existing, user_provider);
} else {
providers.insert(name, user_provider);
}
}
config.providers = resolve_providers(providers, &env);
Ok(config)
}
}
/// Database settings; the whole section may be omitted.
///
/// The container-level `#[serde(default)]` leaves `url` as `None` when it is absent.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct DatabaseConfig {
    /// Connection URL; omitted from serialized output when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub url: Option<String>,
}
/// HTTP listener, control endpoint and logging settings.
///
/// Each field falls back to its own default when missing from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Address the server binds to; defaults to `default_listen()`.
    #[serde(default = "default_listen")]
    pub listen: SocketAddr,

    /// Control endpoint; defaults to `ControlEndpoint::default()`.
    #[serde(default)]
    pub control: ControlEndpoint,

    /// Log level as a free-form string; defaults to `default_log_level()`.
    #[serde(default = "default_log_level")]
    pub log_level: String,
}
impl Default for ServerConfig {
    /// Builds a `ServerConfig` from the same helpers the serde field defaults use.
    fn default() -> Self {
        let listen = default_listen();
        let control = ControlEndpoint::default();
        let log_level = default_log_level();
        Self {
            listen,
            control,
            log_level,
        }
    }
}
/// Default listen address: IPv4 loopback (`127.0.0.1`) on port 8787.
fn default_listen() -> SocketAddr {
    const DEFAULT_PORT: u16 = 8787;
    let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
    SocketAddr::new(loopback, DEFAULT_PORT)
}
/// Default log level string used when `server.log_level` is omitted.
fn default_log_level() -> String {
    String::from("info")
}
/// Location of the control socket.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ControlEndpoint {
    /// Socket path; defaults to `default_socket_path()` when missing.
    #[serde(default = "default_socket_path")]
    pub socket: PathBuf,
}
impl Default for ControlEndpoint {
    /// Uses the same socket path as the serde field default.
    fn default() -> Self {
        let socket = default_socket_path();
        Self { socket }
    }
}
/// Default control socket path: the relative path `bitrouter.sock`.
fn default_socket_path() -> PathBuf {
    "bitrouter.sock".into()
}
/// Wire protocol spoken by a provider.
///
/// In YAML, variants are written in snake_case: `openai`, `anthropic`, `google`.
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ApiProtocol {
    /// Serialized as `openai`.
    Openai,
    /// Serialized as `anthropic`.
    Anthropic,
    /// Serialized as `google`.
    Google,
}
/// A single provider definition, as written by the user or supplied as a builtin.
///
/// Every field is optional. The container-level `#[serde(default)]` leaves absent fields
/// as `None`, and unset fields are omitted when serializing.
///
/// NOTE(review): the derived `Debug` prints `api_key` verbatim — avoid `{:?}` in logs.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct ProviderConfig {
    /// Name of a provider to inherit settings from. It is presumably consumed by
    /// `registry::resolve_providers`; verify there.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub derives: Option<String>,
    /// Wire protocol spoken by this provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_protocol: Option<ApiProtocol>,
    /// Base URL of the provider's API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_base: Option<String>,
    /// API key for this provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Explicit authentication scheme.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auth: Option<AuthConfig>,
    /// Environment-variable prefix. Its exact use lives in the registry — TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub env_prefix: Option<String>,
    /// Extra headers attached to requests for this provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default_headers: Option<HashMap<String, String>>,
}
/// Authentication scheme for a provider, selected by the YAML `type` key.
///
/// The tag values are snake_case variant names: `bearer`, `header`, `custom`.
///
/// NOTE(review): the derived `Debug` prints `api_key` verbatim — avoid `{:?}` in logs.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum AuthConfig {
    /// `type: bearer` with an `api_key`.
    Bearer {
        api_key: String,
    },
    /// `type: header` sending `api_key` under the header named `header_name`.
    Header {
        header_name: String,
        api_key: String,
    },
    /// `type: custom` naming an auth `method`. Its free-form `params` default to JSON
    /// `null` when omitted.
    Custom {
        method: String,
        #[serde(default)]
        params: serde_json::Value,
    },
}
/// How requests for a model are spread across its endpoints.
///
/// In YAML this is written as `priority` or `load_balance`; when omitted it is `priority`.
#[derive(Debug, Default, Clone, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutingStrategy {
    /// Default strategy; serialized as `priority`.
    #[default]
    Priority,
    /// Serialized as `load_balance`.
    LoadBalance,
}
/// One upstream target for a routed model.
///
/// NOTE(review): the derived `Debug` prints `api_key` verbatim — avoid `{:?}` in logs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelEndpoint {
    /// Provider name (a key of `BitrouterConfig::providers`). This is presumably checked
    /// elsewhere — TODO confirm.
    pub provider: String,
    /// Upstream model identifier sent to the provider.
    pub model_id: String,
    /// Per-endpoint API key; omitted when unset.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Per-endpoint API base URL; omitted when unset.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_base: Option<String>,
}
/// Routing table for one model alias.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Routing strategy; falls back to `RoutingStrategy::default()` when omitted.
    #[serde(default)]
    pub strategy: RoutingStrategy,
    /// Endpoints to route across. This field is required, because it has no serde default.
    pub endpoints: Vec<ModelEndpoint>,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_round_trips_through_yaml() {
        let config = BitrouterConfig::default();
        let yaml = serde_yaml::to_string(&config).unwrap();
        let parsed: BitrouterConfig = serde_yaml::from_str(&yaml).unwrap();
        assert_eq!(parsed.server.listen, config.server.listen);
        assert_eq!(parsed.server.log_level, config.server.log_level);
        assert_eq!(parsed.server.control.socket, config.server.control.socket);
    }

    #[test]
    fn load_minimal_yaml() {
        let yaml = r#"
server:
  listen: "127.0.0.1:9090"
providers:
  openai:
    api_key: "sk-test"
"#;
        let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
        assert_eq!(config.server.listen, "127.0.0.1:9090".parse().unwrap());
        assert!(config.providers.contains_key("openai"));
        // Builtins are kept even when the user only configures one provider.
        assert!(config.providers.contains_key("anthropic"));
        assert_eq!(
            config.providers["openai"].api_key.as_deref(),
            Some("sk-test")
        );
    }

    #[test]
    fn load_with_custom_derived_provider() {
        let yaml = r#"
providers:
  my-company:
    derives: openai
    api_base: "https://api.mycompany.com/v1"
    api_key: "sk-custom"
"#;
        let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
        let p = &config.providers["my-company"];
        assert_eq!(p.api_protocol, Some(ApiProtocol::Openai));
        assert_eq!(p.api_base.as_deref(), Some("https://api.mycompany.com/v1"));
        assert_eq!(p.api_key.as_deref(), Some("sk-custom"));
        // `derives` is consumed during resolution.
        assert!(p.derives.is_none());
    }

    #[test]
    fn load_with_model_routing() {
        let yaml = r#"
providers:
  openai:
    api_key: "sk-test"
models:
  my-gpt4:
    strategy: load_balance
    endpoints:
      - provider: openai
        model_id: gpt-4o
        api_key: "sk-key-a"
      - provider: openai
        model_id: gpt-4o
        api_key: "sk-key-b"
"#;
        let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
        let model = &config.models["my-gpt4"];
        assert_eq!(model.strategy, RoutingStrategy::LoadBalance);
        assert_eq!(model.endpoints.len(), 2);
        assert_eq!(model.endpoints[0].api_key.as_deref(), Some("sk-key-a"));
        assert_eq!(model.endpoints[1].api_key.as_deref(), Some("sk-key-b"));
    }

    #[test]
    fn load_with_custom_auth() {
        let yaml = r#"
providers:
  aimo:
    derives: openai
    api_base: "https://api.aimo.network/v1"
    auth:
      type: custom
      method: siwx
      params:
        chain_id: 1
"#;
        let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
        let p = &config.providers["aimo"];
        let Some(AuthConfig::Custom { method, .. }) = &p.auth else {
            panic!("expected custom auth, got {:?}", p.auth);
        };
        assert_eq!(method, "siwx");
    }

    #[test]
    fn empty_yaml_gets_full_builtins() {
        let config = BitrouterConfig::load_from_str("{}", None).unwrap();
        assert!(config.providers.contains_key("openai"));
        assert!(config.providers.contains_key("anthropic"));
        assert!(config.providers.contains_key("google"));
    }

    #[test]
    fn empty_yaml_uses_server_defaults() {
        let config = BitrouterConfig::load_from_str("{}", None).unwrap();
        assert_eq!(config.server.listen, "127.0.0.1:8787".parse().unwrap());
        assert_eq!(config.server.log_level, "info");
        assert_eq!(
            config.server.control.socket,
            std::path::PathBuf::from("bitrouter.sock")
        );
        assert!(config.models.is_empty());
    }
}