Skip to main content

aether_cli/
provider_connection_args.rs

1use std::collections::BTreeMap;
2use std::str::FromStr;
3
4use llm::{ProviderAuthMode, ProviderConnectionOverride, ProviderConnectionOverrides};
5
// Command-line arguments that override how individual providers are reached.
//
// Plain `//` comments are used here on purpose: clap's derive macros turn
// `///` doc comments into help text, so doc comments would change the
// generated `--help` output.
#[derive(Clone, Debug, Default, clap::Args)]
pub struct ProviderConnectionArgs {
    // Values of the `--provider` flag. Each value carries one setting for one
    // provider, shaped like `PROVIDER.url=URL` or `PROVIDER.auth=default|none`,
    // and is represented as a `ProviderArg`.
    #[arg(long = "provider", value_name = "PROVIDER.url=URL|PROVIDER.auth=default|none")]
    pub providers: Vec<ProviderArg>,
}
11
12impl ProviderConnectionArgs {
13    pub fn into_overrides(self) -> ProviderConnectionOverrides {
14        let mut providers = BTreeMap::new();
15        for arg in self.providers {
16            providers.entry(arg.provider).or_insert_with(ProviderConnectionOverride::default).merge(arg.connection);
17        }
18        ProviderConnectionOverrides::new(providers)
19    }
20}
21
/// A single parsed `--provider` override: one setting for one provider.
///
/// Values are built by the `FromStr` implementation from text shaped like
/// `PROVIDER.url=URL` or `PROVIDER.auth=default|none`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ProviderArg {
    /// Provider name, taken from the text before the first `.` of the key.
    provider: String,
    /// The connection setting (URL or auth mode) carried by this argument.
    connection: ProviderConnectionOverride,
}
27
28impl FromStr for ProviderArg {
29    type Err = String;
30
31    fn from_str(value: &str) -> Result<Self, Self::Err> {
32        let (key, setting) = split_key_value(value)?;
33        let (provider, field) = key
34            .split_once('.')
35            .ok_or_else(|| "provider override must be PROVIDER.url=URL or PROVIDER.auth=default|none".to_string())?;
36        validate_provider(provider)?;
37        if setting.trim().is_empty() {
38            return Err("provider value cannot be empty".to_string());
39        }
40
41        let connection = match field {
42            "url" => {
43                validate_url(setting)?;
44                ProviderConnectionOverride::url(setting)
45            }
46            "auth" => ProviderConnectionOverride::auth(parse_auth_mode(setting)?),
47            _ => return Err("provider override field must be url or auth".to_string()),
48        };
49
50        Ok(Self { provider: provider.to_string(), connection })
51    }
52}
53
/// Splits `value` into `(key, setting)` at the first `=`.
///
/// Everything after the first `=` belongs to the setting, so later `=`
/// characters (for example in a URL query string) are preserved.
///
/// # Errors
///
/// Returns a message describing the expected shape when `value` contains no
/// `=` at all.
fn split_key_value(value: &str) -> Result<(&str, &str), String> {
    match value.split_once('=') {
        Some(parts) => Ok(parts),
        None => Err("provider override must be PROVIDER.FIELD=VALUE".to_string()),
    }
}
57
/// Checks that a provider name contains at least one non-whitespace character.
///
/// # Errors
///
/// Returns a message when `provider` is empty or consists only of whitespace.
fn validate_provider(provider: &str) -> Result<(), String> {
    match provider.trim() {
        "" => Err("provider name cannot be empty".to_string()),
        _ => Ok(()),
    }
}
64
65fn validate_url(url: &str) -> Result<(), String> {
66    let parsed = url::Url::parse(url).map_err(|error| format!("invalid provider URL: {error}"))?;
67    match parsed.scheme() {
68        "http" | "https" => Ok(()),
69        scheme => Err(format!("provider URL must use http or https, got {scheme}")),
70    }
71}
72
73fn parse_auth_mode(value: &str) -> Result<ProviderAuthMode, String> {
74    match value {
75        "default" => Ok(ProviderAuthMode::Default),
76        "none" => Ok(ProviderAuthMode::None),
77        _ => Err("provider auth mode must be default or none".to_string()),
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// A `url` setting keeps both the provider name and the URL text.
    #[test]
    fn parses_provider_url() {
        let parsed = "bedrock.url=http://127.0.0.1:8787".parse::<ProviderArg>().unwrap();

        assert_eq!(parsed.provider, "bedrock");
        assert_eq!(parsed.connection.base_url.as_deref(), Some("http://127.0.0.1:8787"));
    }

    /// Both accepted `auth` values map onto their enum variants.
    #[test]
    fn parses_provider_auth_modes() {
        let without_auth: ProviderArg = "bedrock.auth=none".parse().unwrap();
        assert_eq!(without_auth.connection.auth_mode, Some(ProviderAuthMode::None));

        let default_auth: ProviderArg = "bedrock.auth=default".parse().unwrap();
        assert_eq!(default_auth.connection.auth_mode, Some(ProviderAuthMode::Default));
    }

    /// Separate flags for the same provider end up in one merged configuration.
    #[test]
    fn combines_repeated_provider_overrides() {
        let url_override: ProviderArg = "bedrock.url=http://127.0.0.1:8787".parse().unwrap();
        let auth_override: ProviderArg = "bedrock.auth=none".parse().unwrap();
        let args = ProviderConnectionArgs { providers: vec![url_override, auth_override] };

        let merged = args.into_overrides().config_for("bedrock");

        assert_eq!(merged.base_url.as_deref(), Some("http://127.0.0.1:8787"));
        assert_eq!(merged.auth_mode, ProviderAuthMode::None);
    }

    /// Every malformed or unsupported input is reported as an error.
    #[test]
    fn rejects_invalid_values() {
        let rejected = [
            // No `=` separator at all.
            "bedrock",
            // Field present but still no `=`.
            "bedrock.url",
            // Empty provider name.
            ".url=http://127.0.0.1:8787",
            // Empty value.
            "bedrock.url=",
            // Scheme other than http/https.
            "bedrock.url=file:///tmp/proxy",
            // Unknown auth mode.
            "bedrock.auth=disabled",
            // Unknown field.
            "bedrock.region=us-west-2",
        ];

        for input in rejected.iter() {
            assert!(input.parse::<ProviderArg>().is_err(), "expected {input:?} to be rejected");
        }
    }
}