Skip to main content

aether_cli/
provider_connection_args.rs

1use std::collections::BTreeMap;
2use std::str::FromStr;
3
4use llm::{ProviderAuthMode, ProviderConnectionOverride, ProviderConnectionOverrides};
5
// Command-line arguments that carry per-provider connection overrides.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macros turn `///` doc
// comments into help text, so switching to doc comments would change the generated `--help`.
#[derive(Clone, Debug, Default, clap::Args)]
pub struct ProviderConnectionArgs {
    // Parsed values of the `--provider` flag, one `ProviderArg` per value.
    // The `value_name` below is only the placeholder shown in usage text; it lists the three
    // accepted shapes of a value.
    #[arg(
        long = "provider",
        value_name = "PROVIDER.url=URL|PROVIDER.auth=default|none|bedrock.inference-profile-arn=ARN"
    )]
    pub providers: Vec<ProviderArg>,
}
14
15impl ProviderConnectionArgs {
16    pub fn into_overrides(self) -> ProviderConnectionOverrides {
17        let mut providers = BTreeMap::new();
18        for arg in self.providers {
19            providers.entry(arg.provider).or_insert_with(ProviderConnectionOverride::default).merge(arg.connection);
20        }
21        ProviderConnectionOverrides::new(providers)
22    }
23}
24
/// One parsed `PROVIDER.FIELD=VALUE` override: which provider it targets and the single
/// connection setting it changes.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ProviderArg {
    /// Provider name, taken from the part of the key before the first `.`.
    provider: String,
    /// Override built from the `FIELD=VALUE` part (a URL, an auth mode, or an inference profile ARN).
    connection: ProviderConnectionOverride,
}
30
31impl FromStr for ProviderArg {
32    type Err = String;
33
34    fn from_str(value: &str) -> Result<Self, Self::Err> {
35        let (key, setting) = split_key_value(value)?;
36        let (provider, field) = key
37            .split_once('.')
38            .ok_or_else(|| "provider override must be PROVIDER.url=URL, PROVIDER.auth=default|none, or bedrock.inference-profile-arn=ARN".to_string())?;
39
40        validate_provider(provider)?;
41        if setting.trim().is_empty() {
42            return Err("provider value cannot be empty".to_string());
43        }
44
45        let connection = match field {
46            "url" => {
47                validate_url(setting)?;
48                ProviderConnectionOverride::url(setting)
49            }
50            "auth" => ProviderConnectionOverride::auth(parse_auth_mode(setting)?),
51            "inference-profile-arn" => {
52                if provider != "bedrock" {
53                    return Err("inference-profile-arn is only supported for the bedrock provider".to_string());
54                }
55                ProviderConnectionOverride::inference_profile_arn(setting)
56            }
57            _ => return Err("provider override field must be url, auth, or inference-profile-arn".to_string()),
58        };
59
60        Ok(Self { provider: provider.to_string(), connection })
61    }
62}
63
/// Splits `value` into the text before and after its first `=`.
///
/// Only the first `=` is a separator; any later `=` stays in the second half. Returns an error
/// message when `value` contains no `=` at all.
fn split_key_value(value: &str) -> Result<(&str, &str), String> {
    match value.split_once('=') {
        Some(pair) => Ok(pair),
        None => Err("provider override must be PROVIDER.FIELD=VALUE".to_string()),
    }
}
67
/// Rejects a provider name that is empty or consists only of whitespace.
///
/// Any other name is accepted as-is; no further checks are made on it.
fn validate_provider(provider: &str) -> Result<(), String> {
    match provider.trim() {
        "" => Err("provider name cannot be empty".to_string()),
        _ => Ok(()),
    }
}
74
75fn validate_url(url: &str) -> Result<(), String> {
76    let parsed = url::Url::parse(url).map_err(|error| format!("invalid provider URL: {error}"))?;
77    match parsed.scheme() {
78        "http" | "https" => Ok(()),
79        scheme => Err(format!("provider URL must use http or https, got {scheme}")),
80    }
81}
82
83fn parse_auth_mode(value: &str) -> Result<ProviderAuthMode, String> {
84    match value {
85        "default" => Ok(ProviderAuthMode::Default),
86        "none" => Ok(ProviderAuthMode::None),
87        _ => Err("provider auth mode must be default or none".to_string()),
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the `FromStr` implementation under test on `input`.
    fn parse(input: &str) -> Result<ProviderArg, String> {
        input.parse()
    }

    #[test]
    fn parses_provider_url() {
        let parsed = parse("bedrock.url=http://127.0.0.1:8787").unwrap();

        assert_eq!(parsed.provider, "bedrock");
        assert_eq!(parsed.connection.base_url.as_deref(), Some("http://127.0.0.1:8787"));
    }

    #[test]
    fn parses_provider_auth_modes() {
        let disabled = parse("bedrock.auth=none").unwrap();
        assert_eq!(disabled.connection.auth_mode, Some(ProviderAuthMode::None));

        let standard = parse("bedrock.auth=default").unwrap();
        assert_eq!(standard.connection.auth_mode, Some(ProviderAuthMode::Default));
    }

    #[test]
    fn combines_repeated_provider_overrides() {
        // Three flags for the same provider, each setting a different field.
        let flags = [
            "bedrock.url=http://127.0.0.1:8787",
            "bedrock.auth=none",
            "bedrock.inference-profile-arn=arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000",
        ];
        let args = ProviderConnectionArgs {
            providers: flags.iter().map(|flag| parse(flag).unwrap()).collect(),
        };

        let merged = args.into_overrides().config_for("bedrock");

        assert_eq!(merged.base_url.as_deref(), Some("http://127.0.0.1:8787"));
        assert_eq!(merged.auth_mode, ProviderAuthMode::None);
        assert_eq!(
            merged.inference_profile_arn.as_deref(),
            Some("arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000")
        );
    }

    #[test]
    fn parses_bedrock_inference_profile_arn() {
        // The ARN contains both `.` and `:`; only the first `=` and the first `.` of the key
        // act as separators.
        let parsed = parse(
            "bedrock.inference-profile-arn=arn:aws:bedrock:us-west-2:000000000000:inference-profile/us.anthropic.claude-sonnet-4-5-20250929-v1:0",
        )
        .unwrap();

        assert_eq!(parsed.provider, "bedrock");
        assert_eq!(
            parsed.connection.inference_profile_arn.as_deref(),
            Some(
                "arn:aws:bedrock:us-west-2:000000000000:inference-profile/us.anthropic.claude-sonnet-4-5-20250929-v1:0"
            )
        );
    }

    #[test]
    fn rejects_invalid_values() {
        let rejected = [
            // No `=` separator.
            "bedrock",
            "bedrock.url",
            // Blank provider name.
            ".url=http://127.0.0.1:8787",
            // Blank value.
            "bedrock.url=",
            // Scheme other than http/https.
            "bedrock.url=file:///tmp/proxy",
            // Unknown auth mode.
            "bedrock.auth=disabled",
            // Unknown field.
            "bedrock.region=us-west-2",
            // Bedrock-only field used with another provider.
            "openai.inference-profile-arn=arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000",
        ];

        for input in rejected {
            assert!(parse(input).is_err(), "expected {input:?} to be rejected");
        }
    }
}