1use std::{
2 collections::HashMap,
3 fmt,
4 net::{IpAddr, Ipv4Addr, SocketAddr},
5 path::{Path, PathBuf},
6};
7
8use serde::{Deserialize, Serialize};
9
10use bitrouter_core::routers::routing_table::ApiProtocol;
11
12use crate::env::{load_env, substitute_in_value};
13use crate::registry::{
14 builtin_agent_defs, builtin_providers, builtin_tool_provider_defs, merge_provider,
15 resolve_providers,
16};
17
18fn default_true() -> bool {
19 true
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, Default)]
28pub struct BitrouterConfig {
29 #[serde(default)]
30 pub server: ServerConfig,
31
32 #[serde(default)]
34 pub database: DatabaseConfig,
35
36 #[serde(default)]
38 pub guardrails: bitrouter_guardrails::GuardrailConfig,
39
40 #[serde(default, skip_serializing_if = "Option::is_none")]
42 pub solana_rpc_url: Option<String>,
43
44 #[serde(default, skip_serializing_if = "Option::is_none")]
46 pub mpp: Option<MppConfig>,
47
48 #[serde(default, skip_serializing_if = "Option::is_none")]
53 pub wallet: Option<WalletConfig>,
54
55 #[serde(default = "default_true")]
59 pub inherit_defaults: bool,
60
61 #[serde(default)]
63 pub providers: HashMap<String, ProviderConfig>,
64
65 #[serde(default)]
67 pub models: HashMap<String, ModelConfig>,
68
69 #[serde(default)]
71 pub tools: HashMap<String, ToolConfig>,
72
73 #[serde(default)]
75 pub agents: HashMap<String, AgentConfig>,
76
77 #[serde(default)]
84 pub routing: HashMap<String, RoutingRuleConfig>,
85}
86
87impl BitrouterConfig {
88 pub fn has_configured_providers(&self) -> bool {
90 self.providers.values().any(|p| p.api_key.is_some())
91 }
92
93 pub fn configured_provider_names(&self) -> Vec<String> {
95 let mut names: Vec<String> = self
96 .providers
97 .iter()
98 .filter(|(_, p)| p.api_key.is_some())
99 .map(|(name, _)| name.clone())
100 .collect();
101 names.sort();
102 names
103 }
104
105 pub fn load_from_file(path: &Path, env_file: Option<&Path>) -> crate::error::Result<Self> {
114 let raw =
115 std::fs::read_to_string(path).map_err(|e| crate::error::ConfigError::ConfigRead {
116 path: path.to_path_buf(),
117 source: e,
118 })?;
119 Self::load_from_str(&raw, env_file)
120 }
121
122 pub fn load_from_str(raw: &str, env_file: Option<&Path>) -> crate::error::Result<Self> {
126 let env = load_env(env_file);
128
129 let yaml_value: serde_json::Value = serde_saphyr::from_str(raw)
133 .map_err(|e| crate::error::ConfigError::ConfigParse(e.to_string()))?;
134 let substituted = substitute_in_value(yaml_value, &env);
135 let substituted = if substituted.is_null() {
136 serde_json::Value::Object(serde_json::Map::new())
137 } else {
138 substituted
139 };
140 let mut config: BitrouterConfig = serde_json::from_value(substituted)
141 .map_err(|e| crate::error::ConfigError::ConfigParse(e.to_string()))?;
142
143 let mut providers = if config.inherit_defaults {
145 let mut base = builtin_providers();
146 for (name, user_provider) in config.providers.drain() {
147 if let Some(existing) = base.get_mut(&name) {
148 merge_provider(existing, user_provider);
149 } else {
150 base.insert(name, user_provider);
151 }
152 }
153 base
154 } else {
155 std::mem::take(&mut config.providers)
156 };
157
158 if config.inherit_defaults {
163 for (name, builtin) in builtin_tool_provider_defs() {
164 if let Some(existing) = providers.get_mut(&name) {
165 let mut base = builtin.config;
167 merge_provider(&mut base, std::mem::take(existing));
168 *existing = base;
169 } else {
170 providers.insert(name, builtin.config);
171 }
172 for (tool_name, tool_config) in builtin.tool_configs {
173 config.tools.entry(tool_name).or_insert(tool_config);
174 }
175 }
176 }
177
178 if config.inherit_defaults {
181 for (name, builtin) in builtin_agent_defs() {
182 config.agents.entry(name).or_insert(builtin);
183 }
184 }
185
186 config.providers = resolve_providers(providers, &env);
188
189 Ok(config)
190 }
191}
192
193#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
197#[serde(rename_all = "lowercase")]
198pub enum AgentProtocol {
199 #[default]
201 Acp,
202}
203
204impl fmt::Display for AgentProtocol {
205 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206 match self {
207 Self::Acp => write!(f, "acp"),
208 }
209 }
210}
211
212#[derive(Debug, Clone, Serialize, Deserialize)]
214pub struct BinaryArchive {
215 pub archive: String,
217 pub cmd: String,
219 #[serde(default)]
221 pub args: Vec<String>,
222}
223
224#[derive(Debug, Clone, Serialize, Deserialize)]
226#[serde(rename_all = "lowercase")]
227pub enum Distribution {
228 Npx {
230 package: String,
231 #[serde(default)]
232 args: Vec<String>,
233 },
234 Uvx {
236 package: String,
237 #[serde(default)]
238 args: Vec<String>,
239 },
240 Binary {
242 platforms: HashMap<String, BinaryArchive>,
244 },
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
253pub struct AgentSessionConfig {
254 #[serde(default = "default_idle_timeout_secs")]
257 pub idle_timeout_secs: u64,
258
259 #[serde(default = "default_max_concurrent")]
262 pub max_concurrent: usize,
263}
264
265fn default_idle_timeout_secs() -> u64 {
266 600
267}
268
269fn default_max_concurrent() -> usize {
270 1
271}
272
273impl Default for AgentSessionConfig {
274 fn default() -> Self {
275 Self {
276 idle_timeout_secs: default_idle_timeout_secs(),
277 max_concurrent: default_max_concurrent(),
278 }
279 }
280}
281
282#[derive(Debug, Clone, Default, Serialize, Deserialize)]
287pub struct AgentA2aConfig {
288 #[serde(default)]
290 pub enabled: bool,
291
292 #[serde(default, skip_serializing_if = "Vec::is_empty")]
294 pub skills: Vec<String>,
295}
296
297#[derive(Debug, Clone, Serialize, Deserialize)]
299pub struct AgentConfig {
300 #[serde(default)]
302 pub protocol: AgentProtocol,
303
304 pub binary: String,
306
307 #[serde(default)]
309 pub args: Vec<String>,
310
311 #[serde(default = "default_true")]
313 pub enabled: bool,
314
315 #[serde(default, skip_serializing_if = "Vec::is_empty")]
317 pub distribution: Vec<Distribution>,
318
319 #[serde(default, skip_serializing_if = "Option::is_none")]
323 pub session: Option<AgentSessionConfig>,
324
325 #[serde(default, skip_serializing_if = "Option::is_none")]
329 pub a2a: Option<AgentA2aConfig>,
330}
331
332#[derive(Debug, Clone, Serialize, Deserialize, Default)]
336pub struct DatabaseConfig {
337 #[serde(default, skip_serializing_if = "Option::is_none")]
342 pub url: Option<String>,
343}
344
345#[derive(Debug, Clone, Serialize, Deserialize)]
348pub struct ServerConfig {
349 #[serde(default = "default_listen")]
350 pub listen: SocketAddr,
351
352 #[serde(default)]
353 pub control: ControlEndpoint,
354
355 #[serde(default = "default_log_level")]
356 pub log_level: String,
357}
358
359impl Default for ServerConfig {
360 fn default() -> Self {
361 Self {
362 listen: default_listen(),
363 control: ControlEndpoint::default(),
364 log_level: default_log_level(),
365 }
366 }
367}
368
369fn default_listen() -> SocketAddr {
370 SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8787)
371}
372
373fn default_log_level() -> String {
374 "info".into()
375}
376
377#[derive(Debug, Clone, Serialize, Deserialize)]
378pub struct ControlEndpoint {
379 #[serde(default = "default_socket_path")]
380 pub socket: PathBuf,
381}
382
383impl Default for ControlEndpoint {
384 fn default() -> Self {
385 Self {
386 socket: default_socket_path(),
387 }
388 }
389}
390
391fn default_socket_path() -> PathBuf {
392 PathBuf::from("bitrouter.sock")
393}
394
395#[derive(Debug, Clone, Default, Serialize, Deserialize)]
402pub struct ProviderConfig {
403 #[serde(default, skip_serializing_if = "Option::is_none")]
405 pub derives: Option<String>,
406
407 #[serde(default, skip_serializing_if = "Option::is_none")]
409 pub api_protocol: Option<ApiProtocol>,
410
411 #[serde(default, skip_serializing_if = "Option::is_none")]
413 pub api_base: Option<String>,
414
415 #[serde(default, skip_serializing_if = "Option::is_none")]
417 pub api_key: Option<String>,
418
419 #[serde(default, skip_serializing_if = "Option::is_none")]
421 pub auth: Option<AuthConfig>,
422
423 #[serde(default, skip_serializing_if = "Option::is_none")]
426 pub env_prefix: Option<String>,
427
428 #[serde(default, skip_serializing_if = "Option::is_none")]
430 pub default_headers: Option<HashMap<String, String>>,
431
432 #[serde(default, skip_serializing_if = "Option::is_none")]
438 pub models: Option<HashMap<String, ModelInfo>>,
439
440 #[serde(default, skip_serializing_if = "Option::is_none")]
444 pub bridge: Option<bool>,
445}
446
447#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
451#[serde(rename_all = "snake_case")]
452pub enum Modality {
453 Text,
454 Image,
455 Audio,
456 Video,
457 File,
458}
459
460impl fmt::Display for Modality {
461 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
462 f.write_str(match self {
463 Self::Text => "text",
464 Self::Image => "image",
465 Self::Audio => "audio",
466 Self::Video => "video",
467 Self::File => "file",
468 })
469 }
470}
471
472#[derive(Debug, Clone, Default, Serialize, Deserialize)]
474pub struct ModelInfo {
475 #[serde(default, skip_serializing_if = "Option::is_none")]
477 pub name: Option<String>,
478
479 #[serde(default, skip_serializing_if = "Option::is_none")]
481 pub description: Option<String>,
482
483 #[serde(
488 default,
489 skip_serializing_if = "Option::is_none",
490 alias = "context_length"
491 )]
492 pub max_input_tokens: Option<u64>,
493
494 #[serde(default, skip_serializing_if = "Option::is_none")]
496 pub max_output_tokens: Option<u64>,
497
498 #[serde(default, skip_serializing_if = "Vec::is_empty")]
500 pub input_modalities: Vec<Modality>,
501
502 #[serde(default, skip_serializing_if = "Vec::is_empty")]
504 pub output_modalities: Vec<Modality>,
505
506 #[serde(default)]
508 pub pricing: ModelPricing,
509}
510
511pub use bitrouter_core::routers::routing_table::{
514 InputTokenPricing, ModelPricing, OutputTokenPricing,
515};
516
517#[derive(Debug, Clone, Serialize, Deserialize)]
521pub struct MppConfig {
522 #[serde(default)]
524 pub enabled: bool,
525
526 #[serde(default, skip_serializing_if = "Option::is_none")]
530 pub realm: Option<String>,
531
532 #[serde(default, skip_serializing_if = "Option::is_none")]
536 pub secret_key: Option<String>,
537
538 #[serde(default)]
543 pub networks: MppNetworksConfig,
544}
545
546#[derive(Debug, Clone, Default, Serialize, Deserialize)]
548pub struct MppNetworksConfig {
549 #[serde(default, skip_serializing_if = "Option::is_none")]
551 pub tempo: Option<TempoMppConfig>,
552
553 #[cfg(feature = "mpp-solana")]
555 #[serde(default, skip_serializing_if = "Option::is_none")]
556 pub solana: Option<SolanaMppConfig>,
557}
558
559#[derive(Debug, Clone, Serialize, Deserialize)]
561pub struct TempoMppConfig {
562 pub recipient: String,
564
565 pub escrow_contract: String,
567
568 #[serde(default, skip_serializing_if = "Option::is_none")]
570 pub rpc_url: Option<String>,
571
572 #[serde(default, skip_serializing_if = "Option::is_none")]
574 pub currency: Option<String>,
575
576 #[serde(default)]
578 pub fee_payer: bool,
579
580 #[serde(default, skip_serializing_if = "Option::is_none")]
584 pub close_signer: Option<String>,
585
586 #[serde(default, skip_serializing_if = "Option::is_none")]
589 pub default_deposit: Option<String>,
590}
591
592#[cfg(feature = "mpp-solana")]
594#[derive(Debug, Clone, Serialize, Deserialize)]
595pub struct SolanaMppConfig {
596 pub recipient: String,
598
599 pub channel_program: String,
601
602 #[serde(default = "default_solana_network")]
604 pub network: String,
605
606 #[serde(default)]
608 pub asset: SolanaAssetConfig,
609
610 #[serde(default, skip_serializing_if = "Option::is_none")]
614 pub suggested_deposit: Option<String>,
615}
616
617#[cfg(feature = "mpp-solana")]
619#[derive(Debug, Clone, Serialize, Deserialize)]
620pub struct SolanaAssetConfig {
621 #[serde(default = "default_solana_asset_kind")]
623 pub kind: String,
624
625 #[serde(default = "default_solana_asset_decimals")]
627 pub decimals: u8,
628
629 #[serde(default, skip_serializing_if = "Option::is_none")]
631 pub mint: Option<String>,
632
633 #[serde(default, skip_serializing_if = "Option::is_none")]
635 pub symbol: Option<String>,
636}
637
638#[cfg(feature = "mpp-solana")]
639impl Default for SolanaAssetConfig {
640 fn default() -> Self {
641 Self {
642 kind: default_solana_asset_kind(),
643 decimals: default_solana_asset_decimals(),
644 mint: None,
645 symbol: None,
646 }
647 }
648}
649
650#[cfg(feature = "mpp-solana")]
651fn default_solana_asset_kind() -> String {
652 "sol".into()
653}
654
655#[cfg(feature = "mpp-solana")]
656fn default_solana_asset_decimals() -> u8 {
657 9
658}
659
660#[cfg(feature = "mpp-solana")]
661fn default_solana_network() -> String {
662 "mainnet-beta".into()
663}
664
665#[derive(Debug, Clone, Serialize, Deserialize)]
685pub struct WalletConfig {
686 pub name: String,
688
689 #[serde(default, skip_serializing_if = "Option::is_none")]
691 pub vault_path: Option<String>,
692
693 #[serde(default, skip_serializing_if = "Option::is_none")]
699 pub payment: Option<PaymentClientConfig>,
700}
701
702#[derive(Debug, Clone, Serialize, Deserialize)]
707pub struct PaymentClientConfig {
708 #[serde(default, skip_serializing_if = "Option::is_none")]
713 pub tempo_rpc_url: Option<String>,
714
715 #[serde(default, skip_serializing_if = "Option::is_none")]
720 pub solana_rpc_url: Option<String>,
721
722 #[serde(default, skip_serializing_if = "Option::is_none")]
726 pub session_max_deposit: Option<u128>,
727
728 #[serde(default, skip_serializing_if = "Option::is_none")]
732 pub session_default_deposit: Option<u128>,
733}
734
735#[derive(Debug, Clone, Serialize, Deserialize)]
737#[serde(tag = "type", rename_all = "snake_case")]
738pub enum AuthConfig {
739 Bearer { api_key: String },
741 Header {
743 header_name: String,
744 api_key: String,
745 },
746 X402,
748 Mpp,
750 Wallet,
752 #[serde(rename = "oauth")]
757 OAuth {
758 grant: OAuthGrant,
760 client_id: String,
762 #[serde(default, skip_serializing_if = "Option::is_none")]
764 scope: Option<String>,
765 #[serde(default, skip_serializing_if = "Option::is_none")]
767 device_auth_url: Option<String>,
768 #[serde(default, skip_serializing_if = "Option::is_none")]
770 token_url: Option<String>,
771 },
772 Custom {
774 method: String,
775 #[serde(default)]
776 params: serde_json::Value,
777 },
778}
779
780#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
782#[serde(rename_all = "snake_case")]
783pub enum OAuthGrant {
784 DeviceCode,
786}
787
788#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
792#[serde(rename_all = "snake_case")]
793pub enum RoutingStrategy {
794 #[default]
796 Priority,
797 LoadBalance,
799}
800
801#[derive(Debug, Clone, Serialize, Deserialize)]
803pub struct Endpoint {
804 pub provider: String,
806
807 #[serde(alias = "model_id", alias = "tool_id")]
809 pub service_id: String,
810
811 #[serde(default, skip_serializing_if = "Option::is_none")]
816 pub api_protocol: Option<ApiProtocol>,
817
818 #[serde(default, skip_serializing_if = "Option::is_none")]
820 pub api_key: Option<String>,
821
822 #[serde(default, skip_serializing_if = "Option::is_none")]
824 pub api_base: Option<String>,
825}
826
827#[derive(Debug, Clone, Default, Serialize, Deserialize)]
829pub struct ModelConfig {
830 #[serde(default)]
831 pub strategy: RoutingStrategy,
832
833 pub endpoints: Vec<Endpoint>,
834
835 #[serde(default, skip_serializing_if = "Option::is_none")]
837 pub name: Option<String>,
838
839 #[serde(default, skip_serializing_if = "Option::is_none")]
841 pub max_input_tokens: Option<u64>,
842
843 #[serde(default, skip_serializing_if = "Option::is_none")]
845 pub max_output_tokens: Option<u64>,
846
847 #[serde(default, skip_serializing_if = "Vec::is_empty")]
849 pub input_modalities: Vec<Modality>,
850
851 #[serde(default, skip_serializing_if = "Vec::is_empty")]
853 pub output_modalities: Vec<Modality>,
854
855 #[serde(default)]
857 pub pricing: ModelPricing,
858}
859
860#[derive(Debug, Clone, Default, Serialize, Deserialize)]
864pub struct ToolConfig {
865 #[serde(default)]
867 pub strategy: RoutingStrategy,
868
869 pub endpoints: Vec<Endpoint>,
871
872 #[serde(default, skip_serializing_if = "Option::is_none")]
874 pub pricing: Option<bitrouter_core::pricing::FlatPricing>,
875
876 #[serde(default, skip_serializing_if = "Option::is_none")]
878 pub description: Option<String>,
879
880 #[serde(default, skip_serializing_if = "Option::is_none")]
882 pub input_schema: Option<serde_json::Value>,
883
884 #[serde(default, skip_serializing_if = "Option::is_none")]
890 pub skill: Option<String>,
891}
892
893#[derive(Debug, Clone, Default, Serialize, Deserialize)]
901pub struct RoutingRuleConfig {
902 #[serde(default = "default_true")]
906 pub inherit_defaults: bool,
907
908 #[serde(default)]
910 pub signals: HashMap<String, SignalConfig>,
911
912 #[serde(default)]
915 pub complexity: ComplexityConfig,
916
917 #[serde(default)]
922 pub models: HashMap<String, String>,
923}
924
925#[derive(Debug, Clone, Default, Serialize, Deserialize)]
927pub struct SignalConfig {
928 #[serde(default)]
930 pub keywords: Vec<String>,
931}
932
933#[derive(Debug, Clone, Default, Serialize, Deserialize)]
934pub struct ComplexityConfig {
935 #[serde(default)]
937 pub high_keywords: Vec<String>,
938
939 #[serde(default)]
942 pub message_length_threshold: Option<usize>,
943
944 #[serde(default)]
947 pub turn_count_threshold: Option<usize>,
948
949 #[serde(default)]
951 pub code_blocks_increase_complexity: bool,
952}
953
954#[cfg(test)]
955mod tests {
956 use super::*;
957
958 #[test]
959 fn default_config_round_trips_through_yaml() {
960 let config = BitrouterConfig::default();
961 let yaml = serde_saphyr::to_string(&config).unwrap();
962 let parsed: BitrouterConfig = serde_saphyr::from_str(&yaml).unwrap();
963 assert_eq!(parsed.server.listen, config.server.listen);
964 }
965
966 #[test]
967 fn load_minimal_yaml() {
968 let yaml = r#"
969server:
970 listen: "127.0.0.1:9090"
971providers:
972 openai:
973 api_key: "sk-test"
974"#;
975 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
976 assert_eq!(config.server.listen, "127.0.0.1:9090".parse().unwrap());
977 assert!(config.providers.contains_key("openai"));
979 assert!(config.providers.contains_key("anthropic"));
980 assert_eq!(
981 config.providers["openai"].api_key.as_deref(),
982 Some("sk-test")
983 );
984 }
985
986 #[test]
987 fn load_with_custom_derived_provider() {
988 let yaml = r#"
989providers:
990 my-company:
991 derives: openai
992 api_base: "https://api.mycompany.com/v1"
993 api_key: "sk-custom"
994"#;
995 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
996 let p = &config.providers["my-company"];
997 assert_eq!(p.api_protocol, Some(ApiProtocol::Openai)); assert_eq!(p.api_base.as_deref(), Some("https://api.mycompany.com/v1")); assert_eq!(p.api_key.as_deref(), Some("sk-custom"));
1000 assert!(p.derives.is_none()); }
1002
1003 #[test]
1004 fn load_with_model_routing() {
1005 let yaml = r#"
1006providers:
1007 openai:
1008 api_key: "sk-test"
1009models:
1010 my-gpt4:
1011 strategy: load_balance
1012 endpoints:
1013 - provider: openai
1014 model_id: gpt-4o
1015 api_key: "sk-key-a"
1016 - provider: openai
1017 model_id: gpt-4o
1018 api_key: "sk-key-b"
1019"#;
1020 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1021 let model = &config.models["my-gpt4"];
1022 assert_eq!(model.strategy, RoutingStrategy::LoadBalance);
1023 assert_eq!(model.endpoints.len(), 2);
1024 assert_eq!(model.endpoints[0].api_key.as_deref(), Some("sk-key-a"));
1025 }
1026
1027 #[test]
1028 fn load_with_custom_auth() {
1029 let yaml = r#"
1030providers:
1031 aimo:
1032 derives: openai
1033 api_base: "https://api.aimo.network/v1"
1034 auth:
1035 type: custom
1036 method: siwx
1037 params:
1038 chain_id: 1
1039"#;
1040 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1041 let p = &config.providers["aimo"];
1042 assert!(matches!(p.auth, Some(AuthConfig::Custom { .. })));
1043 if let Some(AuthConfig::Custom { method, .. }) = &p.auth {
1044 assert_eq!(method, "siwx");
1045 }
1046 }
1047
1048 #[test]
1049 fn empty_yaml_gets_full_builtins() {
1050 let config = BitrouterConfig::load_from_str("{}", None).unwrap();
1051 assert!(config.providers.contains_key("openai"));
1052 assert!(config.providers.contains_key("anthropic"));
1053 assert!(config.providers.contains_key("google"));
1054 }
1055
1056 #[test]
1057 fn load_with_provider_model_metadata() {
1058 let yaml = r#"
1059providers:
1060 openai:
1061 api_key: "sk-test"
1062 models:
1063 gpt-4o:
1064 name: "GPT-4o"
1065 max_input_tokens: 128000
1066 max_output_tokens: 16384
1067 input_modalities: [text, image]
1068 output_modalities: [text]
1069 pricing:
1070 input_tokens:
1071 no_cache: 2.50
1072 output_tokens:
1073 text: 10.00
1074 gpt-4o-mini:
1075 name: "GPT-4o Mini"
1076"#;
1077 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1078 let openai = &config.providers["openai"];
1079 let models = openai.models.as_ref().unwrap();
1080
1081 let gpt4o = &models["gpt-4o"];
1082 assert_eq!(gpt4o.name.as_deref(), Some("GPT-4o"));
1083 assert_eq!(gpt4o.max_input_tokens, Some(128000));
1084 assert_eq!(gpt4o.max_output_tokens, Some(16384));
1085 assert_eq!(
1086 gpt4o.input_modalities,
1087 vec![Modality::Text, Modality::Image]
1088 );
1089 assert_eq!(gpt4o.pricing.input_tokens.no_cache, Some(2.50));
1090 assert_eq!(gpt4o.pricing.output_tokens.text, Some(10.00));
1091
1092 let mini = &models["gpt-4o-mini"];
1093 assert_eq!(mini.name.as_deref(), Some("GPT-4o Mini"));
1094 assert_eq!(mini.pricing.input_tokens.no_cache, None); }
1096
1097 #[test]
1098 fn derives_inherits_model_catalog() {
1099 let yaml = r#"
1100providers:
1101 my-openai:
1102 derives: openai
1103 api_key: "sk-custom"
1104"#;
1105 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1106 let my_openai = &config.providers["my-openai"];
1107 let models = my_openai.models.as_ref().unwrap();
1109 assert!(models.contains_key("gpt-4o"));
1110 }
1111
1112 #[test]
1113 fn inherit_defaults_true_by_default() {
1114 let config = BitrouterConfig::load_from_str("{}", None).unwrap();
1115 assert!(config.inherit_defaults);
1116 assert!(config.providers.contains_key("openai"));
1117 assert!(config.providers.contains_key("bitrouter"));
1118 }
1119
1120 #[test]
1121 fn inherit_defaults_false_excludes_builtins() {
1122 let yaml = r#"
1123inherit_defaults: false
1124providers:
1125 custom:
1126 api_protocol: openai
1127 api_base: "https://custom.example.com/v1"
1128 api_key: "sk-custom"
1129"#;
1130 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1131 assert!(!config.inherit_defaults);
1132 assert!(config.providers.contains_key("custom"));
1133 assert!(!config.providers.contains_key("openai"));
1134 assert!(!config.providers.contains_key("bitrouter"));
1135 assert_eq!(config.providers.len(), 1);
1136 }
1137
1138 #[test]
1139 fn load_with_tool_routing() {
1140 let yaml = r#"
1141providers:
1142 github-mcp:
1143 api_protocol: mcp
1144 api_base: "https://api.githubcopilot.com/mcp"
1145 api_key: "ghp-test"
1146tools:
1147 create_issue:
1148 strategy: priority
1149 endpoints:
1150 - provider: github-mcp
1151 tool_id: create_issue
1152 search_code:
1153 endpoints:
1154 - provider: github-mcp
1155 tool_id: search_code
1156 api_protocol: mcp
1157"#;
1158 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1159 assert!(config.tools.len() >= 2);
1161 assert!(config.tools.contains_key("create_issue"));
1162 assert!(config.tools.contains_key("search_code"));
1163
1164 let tool = &config.tools["create_issue"];
1165 assert_eq!(tool.strategy, RoutingStrategy::Priority);
1166 assert_eq!(tool.endpoints.len(), 1);
1167 assert_eq!(tool.endpoints[0].provider, "github-mcp");
1168 assert_eq!(tool.endpoints[0].service_id, "create_issue");
1169 assert!(tool.endpoints[0].api_protocol.is_none());
1170
1171 let search = &config.tools["search_code"];
1172 assert_eq!(search.endpoints[0].api_protocol, Some(ApiProtocol::Mcp));
1173 }
1174
1175 #[test]
1176 fn full_template_deserializes() {
1177 let yaml = include_str!("../templates/full.yaml");
1178 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1179
1180 assert_eq!(config.server.listen, "127.0.0.1:8787".parse().unwrap());
1182 assert_eq!(config.server.log_level, "info");
1183
1184 assert!(config.database.url.is_some());
1186
1187 assert!(config.providers.contains_key("openai"));
1189 assert!(config.providers.contains_key("anthropic"));
1190 assert!(config.providers.contains_key("google"));
1191 assert!(config.providers.contains_key("my-proxy"));
1192 assert!(config.providers.contains_key("custom-llm"));
1193 assert!(config.providers.contains_key("github-mcp"));
1194 assert!(config.providers.contains_key("header-auth-provider"));
1195 assert!(config.providers.contains_key("paid-provider"));
1196
1197 let my_proxy = &config.providers["my-proxy"];
1199 assert_eq!(my_proxy.api_protocol, Some(ApiProtocol::Openai));
1200 assert!(my_proxy.derives.is_none()); let custom = &config.providers["custom-llm"];
1204 assert_eq!(custom.api_protocol, Some(ApiProtocol::Openai));
1205 let models = custom.models.as_ref().unwrap();
1206 assert!(models.contains_key("my-model-7b"));
1207
1208 assert_eq!(config.models.len(), 3);
1210 assert!(config.models.contains_key("smart"));
1211 assert!(config.models.contains_key("fast"));
1212 assert!(config.models.contains_key("coding"));
1213 assert_eq!(config.models["smart"].strategy, RoutingStrategy::Priority);
1214 assert_eq!(config.models["fast"].strategy, RoutingStrategy::LoadBalance);
1215
1216 assert!(config.tools.contains_key("create_issue"));
1218 assert!(config.tools.contains_key("web_search"));
1219
1220 assert!(config.guardrails.enabled);
1222 assert!(!config.guardrails.disabled_patterns.is_empty());
1223 assert!(!config.guardrails.custom_patterns.is_empty());
1224 assert!(!config.guardrails.upgoing.is_empty());
1225 assert!(!config.guardrails.downgoing.is_empty());
1226
1227 let wallet = config.wallet.as_ref().unwrap();
1229 assert_eq!(wallet.name, "my-wallet");
1230 assert!(wallet.payment.is_some());
1231
1232 let mpp = config.mpp.as_ref().unwrap();
1234 assert!(mpp.enabled);
1235 }
1236
1237 #[test]
1238 fn minimal_template_deserializes() {
1239 let yaml = "";
1240 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1241
1242 assert_eq!(config.server.listen, "127.0.0.1:8787".parse().unwrap());
1244 assert_eq!(config.server.log_level, "info");
1245
1246 assert!(config.inherit_defaults);
1248 assert!(config.providers.contains_key("openai"));
1249 assert!(config.providers.contains_key("anthropic"));
1250 assert!(config.providers.contains_key("google"));
1251 assert!(config.providers.contains_key("bitrouter"));
1252
1253 assert!(config.models.is_empty());
1255
1256 assert!(config.wallet.is_none());
1258 assert!(config.mpp.is_none());
1259
1260 assert!(config.guardrails.enabled);
1262 }
1263
1264 #[test]
1265 fn empty_string_deserializes() {
1266 let config = BitrouterConfig::load_from_str("", None).unwrap();
1267
1268 assert_eq!(config.server.listen, "127.0.0.1:8787".parse().unwrap());
1270 assert!(config.inherit_defaults);
1271 assert!(config.providers.contains_key("openai"));
1272 assert!(config.providers.contains_key("anthropic"));
1273 assert!(config.providers.contains_key("google"));
1274 assert!(config.models.is_empty());
1275 assert!(config.guardrails.enabled);
1276 }
1277
1278 #[test]
1279 fn load_with_oauth_auth() {
1280 let yaml = r#"
1281providers:
1282 github-copilot:
1283 api_protocol: openai
1284 api_base: "https://api.githubcopilot.com"
1285 auth:
1286 type: oauth
1287 grant: device_code
1288 client_id: "Iv23limb4eFHH5zfOCr2"
1289 scope: "read:user"
1290 device_auth_url: "https://github.com/login/device/code"
1291 token_url: "https://github.com/login/oauth/access_token"
1292"#;
1293 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1294 let p = &config.providers["github-copilot"];
1295 assert!(matches!(p.auth, Some(AuthConfig::OAuth { .. })));
1296 if let Some(AuthConfig::OAuth {
1297 grant,
1298 client_id,
1299 scope,
1300 device_auth_url,
1301 token_url,
1302 }) = &p.auth
1303 {
1304 assert_eq!(*grant, OAuthGrant::DeviceCode);
1305 assert_eq!(client_id, "Iv23limb4eFHH5zfOCr2");
1306 assert_eq!(scope.as_deref(), Some("read:user"));
1307 assert_eq!(
1308 device_auth_url.as_deref(),
1309 Some("https://github.com/login/device/code")
1310 );
1311 assert_eq!(
1312 token_url.as_deref(),
1313 Some("https://github.com/login/oauth/access_token")
1314 );
1315 }
1316 }
1317
1318 #[test]
1319 fn load_oauth_with_defaults() {
1320 let yaml = r#"
1321providers:
1322 test-oauth:
1323 api_protocol: openai
1324 api_base: "https://api.example.com"
1325 auth:
1326 type: oauth
1327 grant: device_code
1328 client_id: "test-client-id"
1329"#;
1330 let config = BitrouterConfig::load_from_str(yaml, None).unwrap();
1331 let p = &config.providers["test-oauth"];
1332 if let Some(AuthConfig::OAuth {
1333 scope,
1334 device_auth_url,
1335 token_url,
1336 ..
1337 }) = &p.auth
1338 {
1339 assert!(scope.is_none());
1340 assert!(device_auth_url.is_none());
1341 assert!(token_url.is_none());
1342 } else {
1343 panic!("expected OAuth auth config");
1344 }
1345 }
1346}