1use std::{
2 collections::HashMap,
3 fmt,
4 net::{IpAddr, Ipv4Addr, SocketAddr},
5 path::{Path, PathBuf},
6};
7
8use serde::{Deserialize, Serialize};
9
10use bitrouter_core::routers::routing_table::ApiProtocol;
11
12use crate::env::{load_env, substitute_in_value};
13use crate::registry::{
14 builtin_agent_defs, builtin_providers, builtin_tool_provider_defs, merge_provider,
15 resolve_providers,
16};
17
/// Serde `default = "..."` helper for boolean fields that must be `true`
/// when the key is absent from the input.
fn default_true() -> bool {
    true
}
21
22#[derive(Debug, Clone, Serialize, Deserialize, Default)]
28pub struct BitrouterConfig {
29 #[serde(default)]
30 pub server: ServerConfig,
31
32 #[serde(default)]
34 pub database: DatabaseConfig,
35
36 #[serde(default)]
38 pub guardrails: bitrouter_guardrails::GuardrailConfig,
39
40 #[serde(default, skip_serializing_if = "Option::is_none")]
42 pub solana_rpc_url: Option<String>,
43
44 #[serde(default, skip_serializing_if = "Option::is_none")]
46 pub mpp: Option<MppConfig>,
47
48 #[serde(default, skip_serializing_if = "Option::is_none")]
53 pub wallet: Option<WalletConfig>,
54
55 #[serde(default = "default_true")]
59 pub inherit_defaults: bool,
60
61 #[serde(default)]
63 pub providers: HashMap<String, ProviderConfig>,
64
65 #[serde(default)]
67 pub models: HashMap<String, ModelConfig>,
68
69 #[serde(default)]
71 pub tools: HashMap<String, ToolConfig>,
72
73 #[serde(default)]
75 pub agents: HashMap<String, AgentConfig>,
76
77 #[serde(default)]
84 pub routing: HashMap<String, RoutingRuleConfig>,
85}
86
87impl BitrouterConfig {
88 pub fn has_configured_providers(&self) -> bool {
90 self.providers.values().any(|p| p.api_key.is_some())
91 }
92
93 pub fn configured_provider_names(&self) -> Vec<String> {
95 let mut names: Vec<String> = self
96 .providers
97 .iter()
98 .filter(|(_, p)| p.api_key.is_some())
99 .map(|(name, _)| name.clone())
100 .collect();
101 names.sort();
102 names
103 }
104
105 pub fn load_from_file(path: &Path, env_file: Option<&Path>) -> crate::error::Result<Self> {
114 let raw =
115 std::fs::read_to_string(path).map_err(|e| crate::error::ConfigError::ConfigRead {
116 path: path.to_path_buf(),
117 source: e,
118 })?;
119 Self::load_from_str(&raw, env_file)
120 }
121
122 pub fn load_from_str(raw: &str, env_file: Option<&Path>) -> crate::error::Result<Self> {
126 let env = load_env(env_file);
128
129 let yaml_value: serde_json::Value = serde_saphyr::from_str(raw)
133 .map_err(|e| crate::error::ConfigError::ConfigParse(e.to_string()))?;
134 let substituted = substitute_in_value(yaml_value, &env);
135 let substituted = if substituted.is_null() {
136 serde_json::Value::Object(serde_json::Map::new())
137 } else {
138 substituted
139 };
140 let mut config: BitrouterConfig = serde_json::from_value(substituted)
141 .map_err(|e| crate::error::ConfigError::ConfigParse(e.to_string()))?;
142
143 let mut providers = if config.inherit_defaults {
145 let mut base = builtin_providers();
146 for (name, user_provider) in config.providers.drain() {
147 if let Some(existing) = base.get_mut(&name) {
148 merge_provider(existing, user_provider);
149 } else {
150 base.insert(name, user_provider);
151 }
152 }
153 base
154 } else {
155 std::mem::take(&mut config.providers)
156 };
157
158 if config.inherit_defaults {
163 for (name, builtin) in builtin_tool_provider_defs() {
164 if let Some(existing) = providers.get_mut(&name) {
165 let mut base = builtin.config;
167 merge_provider(&mut base, std::mem::take(existing));
168 *existing = base;
169 } else {
170 providers.insert(name, builtin.config);
171 }
172 for (tool_name, tool_config) in builtin.tool_configs {
173 config.tools.entry(tool_name).or_insert(tool_config);
174 }
175 }
176 }
177
178 if config.inherit_defaults {
181 for (name, builtin) in builtin_agent_defs() {
182 config.agents.entry(name).or_insert(builtin);
183 }
184 }
185
186 config.providers = resolve_providers(providers, &env);
188
189 Ok(config)
190 }
191}
192
193#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
197#[serde(rename_all = "lowercase")]
198pub enum AgentProtocol {
199 #[default]
201 Acp,
202}
203
204impl fmt::Display for AgentProtocol {
205 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206 match self {
207 Self::Acp => write!(f, "acp"),
208 }
209 }
210}
211
212#[derive(Debug, Clone, Serialize, Deserialize)]
214pub struct BinaryArchive {
215 pub archive: String,
217 pub cmd: String,
219 #[serde(default)]
221 pub args: Vec<String>,
222}
223
224#[derive(Debug, Clone, Serialize, Deserialize)]
226#[serde(rename_all = "lowercase")]
227pub enum Distribution {
228 Npx {
230 package: String,
231 #[serde(default)]
232 args: Vec<String>,
233 },
234 Uvx {
236 package: String,
237 #[serde(default)]
238 args: Vec<String>,
239 },
240 Binary {
242 platforms: HashMap<String, BinaryArchive>,
244 },
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
253pub struct AgentSessionConfig {
254 #[serde(default = "default_idle_timeout_secs")]
257 pub idle_timeout_secs: u64,
258
259 #[serde(default = "default_max_concurrent")]
262 pub max_concurrent: usize,
263}
264
/// Default for [`AgentSessionConfig::idle_timeout_secs`]: ten minutes.
fn default_idle_timeout_secs() -> u64 {
    const TEN_MINUTES_SECS: u64 = 10 * 60;
    TEN_MINUTES_SECS
}
268
/// Default for [`AgentSessionConfig::max_concurrent`]: a single session.
fn default_max_concurrent() -> usize {
    const SINGLE_SESSION: usize = 1;
    SINGLE_SESSION
}
272
273impl Default for AgentSessionConfig {
274 fn default() -> Self {
275 Self {
276 idle_timeout_secs: default_idle_timeout_secs(),
277 max_concurrent: default_max_concurrent(),
278 }
279 }
280}
281
282#[derive(Debug, Clone, Default, Serialize, Deserialize)]
287pub struct AgentA2aConfig {
288 #[serde(default)]
290 pub enabled: bool,
291
292 #[serde(default, skip_serializing_if = "Vec::is_empty")]
294 pub skills: Vec<String>,
295}
296
297#[derive(Debug, Clone, Serialize, Deserialize)]
299pub struct AgentConfig {
300 #[serde(default)]
302 pub protocol: AgentProtocol,
303
304 pub binary: String,
306
307 #[serde(default)]
309 pub args: Vec<String>,
310
311 #[serde(default = "default_true")]
313 pub enabled: bool,
314
315 #[serde(default, skip_serializing_if = "Vec::is_empty")]
317 pub distribution: Vec<Distribution>,
318
319 #[serde(default, skip_serializing_if = "Option::is_none")]
323 pub session: Option<AgentSessionConfig>,
324
325 #[serde(default, skip_serializing_if = "Option::is_none")]
329 pub a2a: Option<AgentA2aConfig>,
330}
331
332#[derive(Debug, Clone, Serialize, Deserialize, Default)]
336pub struct DatabaseConfig {
337 #[serde(default, skip_serializing_if = "Option::is_none")]
342 pub url: Option<String>,
343}
344
345#[derive(Debug, Clone, Serialize, Deserialize)]
348pub struct ServerConfig {
349 #[serde(default = "default_listen")]
350 pub listen: SocketAddr,
351
352 #[serde(default)]
353 pub control: ControlEndpoint,
354
355 #[serde(default = "default_log_level")]
356 pub log_level: String,
357}
358
359impl Default for ServerConfig {
360 fn default() -> Self {
361 Self {
362 listen: default_listen(),
363 control: ControlEndpoint::default(),
364 log_level: default_log_level(),
365 }
366 }
367}
368
/// Default for [`ServerConfig::listen`]: IPv4 loopback on port 8787.
fn default_listen() -> SocketAddr {
    const DEFAULT_PORT: u16 = 8787;
    let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
    SocketAddr::new(loopback, DEFAULT_PORT)
}
372
/// Default for [`ServerConfig::log_level`]: `"info"`.
fn default_log_level() -> String {
    String::from("info")
}
376
377#[derive(Debug, Clone, Serialize, Deserialize)]
378pub struct ControlEndpoint {
379 #[serde(default = "default_socket_path")]
380 pub socket: PathBuf,
381}
382
383impl Default for ControlEndpoint {
384 fn default() -> Self {
385 Self {
386 socket: default_socket_path(),
387 }
388 }
389}
390
/// Default for [`ControlEndpoint::socket`]: the relative path
/// `bitrouter.sock`.
fn default_socket_path() -> PathBuf {
    Path::new("bitrouter.sock").to_path_buf()
}
394
395#[derive(Debug, Clone, Default, Serialize, Deserialize)]
402pub struct ProviderConfig {
403 #[serde(default, skip_serializing_if = "Option::is_none")]
405 pub derives: Option<String>,
406
407 #[serde(default, skip_serializing_if = "Option::is_none")]
409 pub api_protocol: Option<ApiProtocol>,
410
411 #[serde(default, skip_serializing_if = "Option::is_none")]
413 pub api_base: Option<String>,
414
415 #[serde(default, skip_serializing_if = "Option::is_none")]
417 pub api_key: Option<String>,
418
419 #[serde(default, skip_serializing_if = "Option::is_none")]
421 pub auth: Option<AuthConfig>,
422
423 #[serde(default, skip_serializing_if = "Option::is_none")]
426 pub env_prefix: Option<String>,
427
428 #[serde(default, skip_serializing_if = "Option::is_none")]
430 pub default_headers: Option<HashMap<String, String>>,
431
432 #[serde(default, skip_serializing_if = "Option::is_none")]
438 pub models: Option<HashMap<String, ModelInfo>>,
439
440 #[serde(default, skip_serializing_if = "Option::is_none")]
444 pub bridge: Option<bool>,
445}
446
447#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
451#[serde(rename_all = "snake_case")]
452pub enum Modality {
453 Text,
454 Image,
455 Audio,
456 Video,
457 File,
458}
459
460impl fmt::Display for Modality {
461 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
462 f.write_str(match self {
463 Self::Text => "text",
464 Self::Image => "image",
465 Self::Audio => "audio",
466 Self::Video => "video",
467 Self::File => "file",
468 })
469 }
470}
471
472#[derive(Debug, Clone, Default, Serialize, Deserialize)]
474pub struct ModelInfo {
475 #[serde(default, skip_serializing_if = "Option::is_none")]
477 pub name: Option<String>,
478
479 #[serde(default, skip_serializing_if = "Option::is_none")]
481 pub description: Option<String>,
482
483 #[serde(
488 default,
489 skip_serializing_if = "Option::is_none",
490 alias = "context_length"
491 )]
492 pub max_input_tokens: Option<u64>,
493
494 #[serde(default, skip_serializing_if = "Option::is_none")]
496 pub max_output_tokens: Option<u64>,
497
498 #[serde(default, skip_serializing_if = "Vec::is_empty")]
500 pub input_modalities: Vec<Modality>,
501
502 #[serde(default, skip_serializing_if = "Vec::is_empty")]
504 pub output_modalities: Vec<Modality>,
505
506 #[serde(default)]
508 pub pricing: ModelPricing,
509}
510
511pub use bitrouter_core::routers::routing_table::{
514 InputTokenPricing, ModelPricing, OutputTokenPricing,
515};
516
517#[derive(Debug, Clone, Serialize, Deserialize)]
521pub struct MppConfig {
522 #[serde(default)]
524 pub enabled: bool,
525
526 #[serde(default, skip_serializing_if = "Option::is_none")]
530 pub realm: Option<String>,
531
532 #[serde(default, skip_serializing_if = "Option::is_none")]
536 pub secret_key: Option<String>,
537
538 #[serde(default)]
543 pub networks: MppNetworksConfig,
544}
545
546#[derive(Debug, Clone, Default, Serialize, Deserialize)]
548pub struct MppNetworksConfig {
549 #[serde(default, skip_serializing_if = "Option::is_none")]
551 pub tempo: Option<TempoMppConfig>,
552
553 #[cfg(feature = "mpp-solana")]
555 #[serde(default, skip_serializing_if = "Option::is_none")]
556 pub solana: Option<SolanaMppConfig>,
557}
558
559#[derive(Debug, Clone, Serialize, Deserialize)]
561pub struct TempoMppConfig {
562 pub recipient: String,
564
565 pub escrow_contract: String,
567
568 #[serde(default, skip_serializing_if = "Option::is_none")]
570 pub rpc_url: Option<String>,
571
572 #[serde(default, skip_serializing_if = "Option::is_none")]
574 pub currency: Option<String>,
575
576 #[serde(default)]
578 pub fee_payer: bool,
579
580 #[serde(default, skip_serializing_if = "Option::is_none")]
584 pub close_signer: Option<String>,
585
586 #[serde(default, skip_serializing_if = "Option::is_none")]
589 pub default_deposit: Option<String>,
590}
591
/// MPP settings for the Solana network (requires the `mpp-solana` feature).
///
/// `recipient` and `channel_program` are required.
#[cfg(feature = "mpp-solana")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SolanaMppConfig {
    /// Recipient (required).
    pub recipient: String,

    /// Channel program (required).
    pub channel_program: String,

    /// Network name; defaults to `default_solana_network()`.
    #[serde(default = "default_solana_network")]
    pub network: String,

    /// Asset settings; defaults to `SolanaAssetConfig::default()`.
    #[serde(default)]
    pub asset: SolanaAssetConfig,

    /// Optional suggested deposit, kept as a string.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub suggested_deposit: Option<String>,
}
616
/// Asset description used by [`SolanaMppConfig`] (requires the `mpp-solana`
/// feature).
#[cfg(feature = "mpp-solana")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SolanaAssetConfig {
    /// Asset kind; defaults to `default_solana_asset_kind()`.
    #[serde(default = "default_solana_asset_kind")]
    pub kind: String,

    /// Number of decimals; defaults to `default_solana_asset_decimals()`.
    #[serde(default = "default_solana_asset_decimals")]
    pub decimals: u8,

    /// Optional mint.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mint: Option<String>,

    /// Optional symbol.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub symbol: Option<String>,
}
637
/// Mirrors the per-field serde defaults of [`SolanaAssetConfig`].
#[cfg(feature = "mpp-solana")]
impl Default for SolanaAssetConfig {
    fn default() -> Self {
        let kind = default_solana_asset_kind();
        let decimals = default_solana_asset_decimals();
        Self {
            kind,
            decimals,
            mint: None,
            symbol: None,
        }
    }
}
649
/// Default for `SolanaAssetConfig::kind`: `"sol"`.
#[cfg(feature = "mpp-solana")]
fn default_solana_asset_kind() -> String {
    String::from("sol")
}
654
/// Default for `SolanaAssetConfig::decimals`: `9`.
#[cfg(feature = "mpp-solana")]
fn default_solana_asset_decimals() -> u8 {
    const DEFAULT_DECIMALS: u8 = 9;
    DEFAULT_DECIMALS
}
659
/// Default for `SolanaMppConfig::network`: `"mainnet-beta"`.
#[cfg(feature = "mpp-solana")]
fn default_solana_network() -> String {
    String::from("mainnet-beta")
}
664
665#[derive(Debug, Clone, Serialize, Deserialize)]
685pub struct WalletConfig {
686 pub name: String,
688
689 #[serde(default, skip_serializing_if = "Option::is_none")]
691 pub vault_path: Option<String>,
692
693 #[serde(default, skip_serializing_if = "Option::is_none")]
699 pub payment: Option<PaymentClientConfig>,
700}
701
702#[derive(Debug, Clone, Serialize, Deserialize)]
707pub struct PaymentClientConfig {
708 #[serde(default, skip_serializing_if = "Option::is_none")]
713 pub tempo_rpc_url: Option<String>,
714
715 #[serde(default, skip_serializing_if = "Option::is_none")]
720 pub solana_rpc_url: Option<String>,
721
722 #[serde(default, skip_serializing_if = "Option::is_none")]
726 pub session_max_deposit: Option<u128>,
727
728 #[serde(default, skip_serializing_if = "Option::is_none")]
732 pub session_default_deposit: Option<u128>,
733}
734
735#[derive(Debug, Clone, Serialize, Deserialize)]
737#[serde(tag = "type", rename_all = "snake_case")]
738pub enum AuthConfig {
739 Bearer { api_key: String },
741 Header {
743 header_name: String,
744 api_key: String,
745 },
746 X402,
748 Mpp,
750 Wallet,
752 #[serde(rename = "oauth")]
757 OAuth {
758 grant: OAuthGrant,
760 client_id: String,
762 #[serde(default, skip_serializing_if = "Option::is_none")]
764 scope: Option<String>,
765 #[serde(default, skip_serializing_if = "Option::is_none")]
767 device_auth_url: Option<String>,
768 #[serde(default, skip_serializing_if = "Option::is_none")]
770 token_url: Option<String>,
771 #[serde(default, skip_serializing_if = "Option::is_none")]
778 domain: Option<String>,
779 },
780 Custom {
782 method: String,
783 #[serde(default)]
784 params: serde_json::Value,
785 },
786}
787
788#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
790#[serde(rename_all = "snake_case")]
791pub enum OAuthGrant {
792 DeviceCode,
794}
795
796#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
800#[serde(rename_all = "snake_case")]
801pub enum RoutingStrategy {
802 #[default]
804 Priority,
805 LoadBalance,
807}
808
809#[derive(Debug, Clone, Serialize, Deserialize)]
811pub struct Endpoint {
812 pub provider: String,
814
815 #[serde(alias = "model_id", alias = "tool_id")]
817 pub service_id: String,
818
819 #[serde(default, skip_serializing_if = "Option::is_none")]
824 pub api_protocol: Option<ApiProtocol>,
825
826 #[serde(default, skip_serializing_if = "Option::is_none")]
828 pub api_key: Option<String>,
829
830 #[serde(default, skip_serializing_if = "Option::is_none")]
832 pub api_base: Option<String>,
833}
834
835#[derive(Debug, Clone, Default, Serialize, Deserialize)]
837pub struct ModelConfig {
838 #[serde(default)]
839 pub strategy: RoutingStrategy,
840
841 pub endpoints: Vec<Endpoint>,
842
843 #[serde(default, skip_serializing_if = "Option::is_none")]
845 pub name: Option<String>,
846
847 #[serde(default, skip_serializing_if = "Option::is_none")]
849 pub max_input_tokens: Option<u64>,
850
851 #[serde(default, skip_serializing_if = "Option::is_none")]
853 pub max_output_tokens: Option<u64>,
854
855 #[serde(default, skip_serializing_if = "Vec::is_empty")]
857 pub input_modalities: Vec<Modality>,
858
859 #[serde(default, skip_serializing_if = "Vec::is_empty")]
861 pub output_modalities: Vec<Modality>,
862
863 #[serde(default)]
865 pub pricing: ModelPricing,
866}
867
868#[derive(Debug, Clone, Default, Serialize, Deserialize)]
872pub struct ToolConfig {
873 #[serde(default)]
875 pub strategy: RoutingStrategy,
876
877 pub endpoints: Vec<Endpoint>,
879
880 #[serde(default, skip_serializing_if = "Option::is_none")]
882 pub pricing: Option<bitrouter_core::pricing::FlatPricing>,
883
884 #[serde(default, skip_serializing_if = "Option::is_none")]
886 pub description: Option<String>,
887
888 #[serde(default, skip_serializing_if = "Option::is_none")]
890 pub input_schema: Option<serde_json::Value>,
891
892 #[serde(default, skip_serializing_if = "Option::is_none")]
898 pub skill: Option<String>,
899}
900
901#[derive(Debug, Clone, Default, Serialize, Deserialize)]
909pub struct RoutingRuleConfig {
910 #[serde(default = "default_true")]
914 pub inherit_defaults: bool,
915
916 #[serde(default)]
918 pub signals: HashMap<String, SignalConfig>,
919
920 #[serde(default)]
923 pub complexity: ComplexityConfig,
924
925 #[serde(default)]
930 pub models: HashMap<String, String>,
931}
932
933#[derive(Debug, Clone, Default, Serialize, Deserialize)]
935pub struct SignalConfig {
936 #[serde(default)]
938 pub keywords: Vec<String>,
939}
940
941#[derive(Debug, Clone, Default, Serialize, Deserialize)]
942pub struct ComplexityConfig {
943 #[serde(default)]
945 pub high_keywords: Vec<String>,
946
947 #[serde(default)]
950 pub message_length_threshold: Option<usize>,
951
952 #[serde(default)]
955 pub turn_count_threshold: Option<usize>,
956
957 #[serde(default)]
959 pub code_blocks_increase_complexity: bool,
960}
961
962#[cfg(test)]
963mod tests {
964 use super::*;
965
966 #[test]
967 fn default_config_round_trips_through_yaml() {
968 let config = BitrouterConfig::default();
969 let yaml = serde_saphyr::to_string(&config).unwrap();
970 let parsed: BitrouterConfig = serde_saphyr::from_str(&yaml).unwrap();
971 assert_eq!(parsed.server.listen, config.server.listen);
972 }
973
974 #[test]
975 fn load_minimal_yaml() {
976 let yaml = r#"
977server:
978 listen: "127.0.0.1:9090"
979providers:
980 openai:
981 api_key: "sk-test"
982"#;
983 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
984 assert_eq!(config.server.listen, "127.0.0.1:9090".parse().unwrap());
985 assert!(config.providers.contains_key("openai"));
987 assert!(config.providers.contains_key("anthropic"));
988 assert_eq!(
989 config.providers["openai"].api_key.as_deref(),
990 Some("sk-test")
991 );
992 }
993
994 #[test]
995 fn load_with_custom_derived_provider() {
996 let yaml = r#"
997providers:
998 my-company:
999 derives: openai
1000 api_base: "https://api.mycompany.com/v1"
1001 api_key: "sk-custom"
1002"#;
1003 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1004 let p = &config.providers["my-company"];
1005 assert_eq!(p.api_protocol, Some(ApiProtocol::Openai)); assert_eq!(p.api_base.as_deref(), Some("https://api.mycompany.com/v1")); assert_eq!(p.api_key.as_deref(), Some("sk-custom"));
1008 assert!(p.derives.is_none()); }
1010
1011 #[test]
1012 fn load_with_model_routing() {
1013 let yaml = r#"
1014providers:
1015 openai:
1016 api_key: "sk-test"
1017models:
1018 my-gpt4:
1019 strategy: load_balance
1020 endpoints:
1021 - provider: openai
1022 model_id: gpt-4o
1023 api_key: "sk-key-a"
1024 - provider: openai
1025 model_id: gpt-4o
1026 api_key: "sk-key-b"
1027"#;
1028 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1029 let model = &config.models["my-gpt4"];
1030 assert_eq!(model.strategy, RoutingStrategy::LoadBalance);
1031 assert_eq!(model.endpoints.len(), 2);
1032 assert_eq!(model.endpoints[0].api_key.as_deref(), Some("sk-key-a"));
1033 }
1034
1035 #[test]
1036 fn load_with_custom_auth() {
1037 let yaml = r#"
1038providers:
1039 aimo:
1040 derives: openai
1041 api_base: "https://api.aimo.network/v1"
1042 auth:
1043 type: custom
1044 method: siwx
1045 params:
1046 chain_id: 1
1047"#;
1048 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1049 let p = &config.providers["aimo"];
1050 assert!(matches!(p.auth, Some(AuthConfig::Custom { .. })));
1051 if let Some(AuthConfig::Custom { method, .. }) = &p.auth {
1052 assert_eq!(method, "siwx");
1053 }
1054 }
1055
1056 #[test]
1057 fn empty_yaml_gets_full_builtins() {
1058 let config = BitrouterConfig::load_from_str("{}", None).unwrap();
1059 assert!(config.providers.contains_key("openai"));
1060 assert!(config.providers.contains_key("anthropic"));
1061 assert!(config.providers.contains_key("google"));
1062 }
1063
1064 #[test]
1065 fn load_with_provider_model_metadata() {
1066 let yaml = r#"
1067providers:
1068 openai:
1069 api_key: "sk-test"
1070 models:
1071 gpt-4o:
1072 name: "GPT-4o"
1073 max_input_tokens: 128000
1074 max_output_tokens: 16384
1075 input_modalities: [text, image]
1076 output_modalities: [text]
1077 pricing:
1078 input_tokens:
1079 no_cache: 2.50
1080 output_tokens:
1081 text: 10.00
1082 gpt-4o-mini:
1083 name: "GPT-4o Mini"
1084"#;
1085 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1086 let openai = &config.providers["openai"];
1087 let models = openai.models.as_ref().unwrap();
1088
1089 let gpt4o = &models["gpt-4o"];
1090 assert_eq!(gpt4o.name.as_deref(), Some("GPT-4o"));
1091 assert_eq!(gpt4o.max_input_tokens, Some(128000));
1092 assert_eq!(gpt4o.max_output_tokens, Some(16384));
1093 assert_eq!(
1094 gpt4o.input_modalities,
1095 vec![Modality::Text, Modality::Image]
1096 );
1097 assert_eq!(gpt4o.pricing.input_tokens.no_cache, Some(2.50));
1098 assert_eq!(gpt4o.pricing.output_tokens.text, Some(10.00));
1099
1100 let mini = &models["gpt-4o-mini"];
1101 assert_eq!(mini.name.as_deref(), Some("GPT-4o Mini"));
1102 assert_eq!(mini.pricing.input_tokens.no_cache, None); }
1104
1105 #[test]
1106 fn derives_inherits_model_catalog() {
1107 let yaml = r#"
1108providers:
1109 my-openai:
1110 derives: openai
1111 api_key: "sk-custom"
1112"#;
1113 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1114 let my_openai = &config.providers["my-openai"];
1115 let models = my_openai.models.as_ref().unwrap();
1117 assert!(models.contains_key("gpt-4o"));
1118 }
1119
1120 #[test]
1121 fn inherit_defaults_true_by_default() {
1122 let config = BitrouterConfig::load_from_str("{}", None).unwrap();
1123 assert!(config.inherit_defaults);
1124 assert!(config.providers.contains_key("openai"));
1125 assert!(config.providers.contains_key("bitrouter"));
1126 }
1127
1128 #[test]
1129 fn inherit_defaults_false_excludes_builtins() {
1130 let yaml = r#"
1131inherit_defaults: false
1132providers:
1133 custom:
1134 api_protocol: openai
1135 api_base: "https://custom.example.com/v1"
1136 api_key: "sk-custom"
1137"#;
1138 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1139 assert!(!config.inherit_defaults);
1140 assert!(config.providers.contains_key("custom"));
1141 assert!(!config.providers.contains_key("openai"));
1142 assert!(!config.providers.contains_key("bitrouter"));
1143 assert_eq!(config.providers.len(), 1);
1144 }
1145
1146 #[test]
1147 fn load_with_tool_routing() {
1148 let yaml = r#"
1149providers:
1150 github-mcp:
1151 api_protocol: mcp
1152 api_base: "https://api.githubcopilot.com/mcp"
1153 api_key: "ghp-test"
1154tools:
1155 create_issue:
1156 strategy: priority
1157 endpoints:
1158 - provider: github-mcp
1159 tool_id: create_issue
1160 search_code:
1161 endpoints:
1162 - provider: github-mcp
1163 tool_id: search_code
1164 api_protocol: mcp
1165"#;
1166 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1167 assert!(config.tools.len() >= 2);
1169 assert!(config.tools.contains_key("create_issue"));
1170 assert!(config.tools.contains_key("search_code"));
1171
1172 let tool = &config.tools["create_issue"];
1173 assert_eq!(tool.strategy, RoutingStrategy::Priority);
1174 assert_eq!(tool.endpoints.len(), 1);
1175 assert_eq!(tool.endpoints[0].provider, "github-mcp");
1176 assert_eq!(tool.endpoints[0].service_id, "create_issue");
1177 assert!(tool.endpoints[0].api_protocol.is_none());
1178
1179 let search = &config.tools["search_code"];
1180 assert_eq!(search.endpoints[0].api_protocol, Some(ApiProtocol::Mcp));
1181 }
1182
1183 #[test]
1184 fn full_template_deserializes() {
1185 let yaml = include_str!("../templates/full.yaml");
1186 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1187
1188 assert_eq!(config.server.listen, "127.0.0.1:8787".parse().unwrap());
1190 assert_eq!(config.server.log_level, "info");
1191
1192 assert!(config.database.url.is_some());
1194
1195 assert!(config.providers.contains_key("openai"));
1197 assert!(config.providers.contains_key("anthropic"));
1198 assert!(config.providers.contains_key("google"));
1199 assert!(config.providers.contains_key("my-proxy"));
1200 assert!(config.providers.contains_key("custom-llm"));
1201 assert!(config.providers.contains_key("github-mcp"));
1202 assert!(config.providers.contains_key("header-auth-provider"));
1203 assert!(config.providers.contains_key("paid-provider"));
1204
1205 let my_proxy = &config.providers["my-proxy"];
1207 assert_eq!(my_proxy.api_protocol, Some(ApiProtocol::Openai));
1208 assert!(my_proxy.derives.is_none()); let custom = &config.providers["custom-llm"];
1212 assert_eq!(custom.api_protocol, Some(ApiProtocol::Openai));
1213 let models = custom.models.as_ref().unwrap();
1214 assert!(models.contains_key("my-model-7b"));
1215
1216 assert_eq!(config.models.len(), 3);
1218 assert!(config.models.contains_key("smart"));
1219 assert!(config.models.contains_key("fast"));
1220 assert!(config.models.contains_key("coding"));
1221 assert_eq!(config.models["smart"].strategy, RoutingStrategy::Priority);
1222 assert_eq!(config.models["fast"].strategy, RoutingStrategy::LoadBalance);
1223
1224 assert!(config.tools.contains_key("create_issue"));
1226 assert!(config.tools.contains_key("web_search"));
1227
1228 assert!(config.guardrails.enabled);
1230 assert!(!config.guardrails.disabled_patterns.is_empty());
1231 assert!(!config.guardrails.custom_patterns.is_empty());
1232 assert!(!config.guardrails.upgoing.is_empty());
1233 assert!(!config.guardrails.downgoing.is_empty());
1234
1235 let wallet = config.wallet.as_ref().unwrap();
1237 assert_eq!(wallet.name, "my-wallet");
1238 assert!(wallet.payment.is_some());
1239
1240 let mpp = config.mpp.as_ref().unwrap();
1242 assert!(mpp.enabled);
1243 }
1244
1245 #[test]
1246 fn minimal_template_deserializes() {
1247 let yaml = "";
1248 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1249
1250 assert_eq!(config.server.listen, "127.0.0.1:8787".parse().unwrap());
1252 assert_eq!(config.server.log_level, "info");
1253
1254 assert!(config.inherit_defaults);
1256 assert!(config.providers.contains_key("openai"));
1257 assert!(config.providers.contains_key("anthropic"));
1258 assert!(config.providers.contains_key("google"));
1259 assert!(config.providers.contains_key("bitrouter"));
1260
1261 assert!(config.models.is_empty());
1263
1264 assert!(config.wallet.is_none());
1266 assert!(config.mpp.is_none());
1267
1268 assert!(config.guardrails.enabled);
1270 }
1271
1272 #[test]
1273 fn empty_string_deserializes() {
1274 let config = BitrouterConfig::load_from_str("", None).unwrap();
1275
1276 assert_eq!(config.server.listen, "127.0.0.1:8787".parse().unwrap());
1278 assert!(config.inherit_defaults);
1279 assert!(config.providers.contains_key("openai"));
1280 assert!(config.providers.contains_key("anthropic"));
1281 assert!(config.providers.contains_key("google"));
1282 assert!(config.models.is_empty());
1283 assert!(config.guardrails.enabled);
1284 }
1285
1286 #[test]
1287 fn load_with_oauth_auth() {
1288 let yaml = r#"
1289providers:
1290 github-copilot:
1291 api_protocol: openai
1292 api_base: "https://api.githubcopilot.com"
1293 auth:
1294 type: oauth
1295 grant: device_code
1296 client_id: "Iv23limb4eFHH5zfOCr2"
1297 scope: "read:user"
1298 device_auth_url: "https://github.com/login/device/code"
1299 token_url: "https://github.com/login/oauth/access_token"
1300"#;
1301 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1302 let p = &config.providers["github-copilot"];
1303 assert!(matches!(p.auth, Some(AuthConfig::OAuth { .. })));
1304 if let Some(AuthConfig::OAuth {
1305 grant,
1306 client_id,
1307 scope,
1308 device_auth_url,
1309 token_url,
1310 ..
1311 }) = &p.auth
1312 {
1313 assert_eq!(*grant, OAuthGrant::DeviceCode);
1314 assert_eq!(client_id, "Iv23limb4eFHH5zfOCr2");
1315 assert_eq!(scope.as_deref(), Some("read:user"));
1316 assert_eq!(
1317 device_auth_url.as_deref(),
1318 Some("https://github.com/login/device/code")
1319 );
1320 assert_eq!(
1321 token_url.as_deref(),
1322 Some("https://github.com/login/oauth/access_token")
1323 );
1324 }
1325 }
1326
1327 #[test]
1328 fn load_oauth_with_defaults() {
1329 let yaml = r#"
1330providers:
1331 test-oauth:
1332 api_protocol: openai
1333 api_base: "https://api.example.com"
1334 auth:
1335 type: oauth
1336 grant: device_code
1337 client_id: "test-client-id"
1338"#;
1339 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1340 let p = &config.providers["test-oauth"];
1341 if let Some(AuthConfig::OAuth {
1342 scope,
1343 device_auth_url,
1344 token_url,
1345 ..
1346 }) = &p.auth
1347 {
1348 assert!(scope.is_none());
1349 assert!(device_auth_url.is_none());
1350 assert!(token_url.is_none());
1351 } else {
1352 panic!("expected OAuth auth config");
1353 }
1354 }
1355}