aether_cli/
provider_connection_args.rs1use std::collections::BTreeMap;
2use std::str::FromStr;
3
4use llm::{ProviderAuthMode, ProviderConnectionOverride, ProviderConnectionOverrides};
5
6#[derive(Clone, Debug, Default, clap::Args)]
7pub struct ProviderConnectionArgs {
8 #[arg(long = "provider", value_name = "PROVIDER.url=URL|PROVIDER.auth=default|none")]
9 pub providers: Vec<ProviderArg>,
10}
11
12impl ProviderConnectionArgs {
13 pub fn into_overrides(self) -> ProviderConnectionOverrides {
14 let mut providers = BTreeMap::new();
15 for arg in self.providers {
16 providers.entry(arg.provider).or_insert_with(ProviderConnectionOverride::default).merge(arg.connection);
17 }
18 ProviderConnectionOverrides::new(providers)
19 }
20}
21
22#[derive(Clone, Debug, PartialEq, Eq)]
23pub struct ProviderArg {
24 provider: String,
25 connection: ProviderConnectionOverride,
26}
27
28impl FromStr for ProviderArg {
29 type Err = String;
30
31 fn from_str(value: &str) -> Result<Self, Self::Err> {
32 let (key, setting) = split_key_value(value)?;
33 let (provider, field) = key
34 .split_once('.')
35 .ok_or_else(|| "provider override must be PROVIDER.url=URL or PROVIDER.auth=default|none".to_string())?;
36 validate_provider(provider)?;
37 if setting.trim().is_empty() {
38 return Err("provider value cannot be empty".to_string());
39 }
40
41 let connection = match field {
42 "url" => {
43 validate_url(setting)?;
44 ProviderConnectionOverride::url(setting)
45 }
46 "auth" => ProviderConnectionOverride::auth(parse_auth_mode(setting)?),
47 _ => return Err("provider override field must be url or auth".to_string()),
48 };
49
50 Ok(Self { provider: provider.to_string(), connection })
51 }
52}
53
/// Splits `KEY=VALUE` at the first `=`.
///
/// Only the first `=` separates, so everything after it (including further `=` characters,
/// e.g. in a URL query string) belongs to the value.
///
/// # Errors
///
/// Returns a usage message when the input contains no `=` at all.
fn split_key_value(value: &str) -> Result<(&str, &str), String> {
    match value.split_once('=') {
        Some(pair) => Ok(pair),
        None => Err(String::from("provider override must be PROVIDER.FIELD=VALUE")),
    }
}
57
/// Checks that a provider name contains at least one non-whitespace character.
///
/// # Errors
///
/// Returns `"provider name cannot be empty"` for empty or whitespace-only names.
fn validate_provider(provider: &str) -> Result<(), String> {
    match provider.trim() {
        "" => Err(String::from("provider name cannot be empty")),
        _ => Ok(()),
    }
}
64
65fn validate_url(url: &str) -> Result<(), String> {
66 let parsed = url::Url::parse(url).map_err(|error| format!("invalid provider URL: {error}"))?;
67 match parsed.scheme() {
68 "http" | "https" => Ok(()),
69 scheme => Err(format!("provider URL must use http or https, got {scheme}")),
70 }
71}
72
73fn parse_auth_mode(value: &str) -> Result<ProviderAuthMode, String> {
74 match value {
75 "default" => Ok(ProviderAuthMode::Default),
76 "none" => Ok(ProviderAuthMode::None),
77 _ => Err("provider auth mode must be default or none".to_string()),
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a `--provider` value, panicking if it is rejected.
    fn parse(input: &str) -> ProviderArg {
        input.parse().unwrap()
    }

    #[test]
    fn parses_provider_url() {
        let ProviderArg { provider, connection } = parse("bedrock.url=http://127.0.0.1:8787");
        assert_eq!(provider, "bedrock");
        assert_eq!(connection.base_url.as_deref(), Some("http://127.0.0.1:8787"));
    }

    #[test]
    fn parses_provider_auth_modes() {
        let cases = [
            ("bedrock.auth=none", ProviderAuthMode::None),
            ("bedrock.auth=default", ProviderAuthMode::Default),
        ];
        for (input, expected) in cases {
            assert_eq!(parse(input).connection.auth_mode, Some(expected), "input {input:?}");
        }
    }

    #[test]
    fn combines_repeated_provider_overrides() {
        let providers: Vec<ProviderArg> = ["bedrock.url=http://127.0.0.1:8787", "bedrock.auth=none"]
            .iter()
            .map(|input| parse(input))
            .collect();

        let config = ProviderConnectionArgs { providers }.into_overrides().config_for("bedrock");

        assert_eq!(config.base_url.as_deref(), Some("http://127.0.0.1:8787"));
        assert_eq!(config.auth_mode, ProviderAuthMode::None);
    }

    #[test]
    fn rejects_invalid_values() {
        let invalid = [
            "bedrock",
            "bedrock.url",
            ".url=http://127.0.0.1:8787",
            "bedrock.url=",
            "bedrock.url=file:///tmp/proxy",
            "bedrock.auth=disabled",
            "bedrock.region=us-west-2",
        ];
        for input in invalid {
            assert!(input.parse::<ProviderArg>().is_err(), "expected {input:?} to be rejected");
        }
    }
}
126}