Skip to main content

llm/
provider_connection.rs

1use std::collections::BTreeMap;
2
3#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
4#[serde(rename_all = "kebab-case")]
5pub enum ProviderAuthMode {
6    #[default]
7    Default,
8    None,
9}
10
11#[derive(Clone, Debug, Default, PartialEq, Eq)]
12pub struct ProviderConnectionConfig {
13    pub base_url: Option<String>,
14    pub auth_mode: ProviderAuthMode,
15    pub inference_profile_arn: Option<String>,
16}
17
18#[doc = include_str!("docs/provider_connection_override.md")]
19#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
20#[serde(rename_all = "camelCase", deny_unknown_fields)]
21pub struct ProviderConnectionOverride {
22    /// Base URL override for the provider's API endpoint.
23    #[serde(default, rename = "url", skip_serializing_if = "Option::is_none")]
24    pub base_url: Option<String>,
25    /// Authentication mode. `default` uses the provider's normal credential
26    /// chain; `none` disables auth, for local or unauthenticated servers.
27    #[serde(default, rename = "auth", skip_serializing_if = "Option::is_none")]
28    pub auth_mode: Option<ProviderAuthMode>,
29    /// AWS Bedrock application inference profile ARN to route requests through.
30    #[serde(default, skip_serializing_if = "Option::is_none")]
31    pub inference_profile_arn: Option<String>,
32}
33
34#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
35#[serde(transparent)]
36pub struct ProviderConnectionOverrides {
37    providers: BTreeMap<String, ProviderConnectionOverride>,
38}
39
40impl ProviderConnectionConfig {
41    pub fn from_override(value: ProviderConnectionOverride) -> Self {
42        Self {
43            base_url: value.base_url,
44            auth_mode: value.auth_mode.unwrap_or_default(),
45            inference_profile_arn: value.inference_profile_arn,
46        }
47    }
48}
49
50impl ProviderConnectionOverride {
51    pub fn url(url: impl Into<String>) -> Self {
52        Self { base_url: Some(url.into()), ..Self::default() }
53    }
54
55    pub fn auth(auth_mode: ProviderAuthMode) -> Self {
56        Self { auth_mode: Some(auth_mode), ..Self::default() }
57    }
58
59    pub fn inference_profile_arn(arn: impl Into<String>) -> Self {
60        Self { inference_profile_arn: Some(arn.into()), ..Self::default() }
61    }
62
63    pub fn merge(&mut self, override_value: Self) {
64        if override_value.base_url.is_some() {
65            self.base_url = override_value.base_url;
66        }
67        if override_value.auth_mode.is_some() {
68            self.auth_mode = override_value.auth_mode;
69        }
70        if override_value.inference_profile_arn.is_some() {
71            self.inference_profile_arn = override_value.inference_profile_arn;
72        }
73    }
74}
75
76impl ProviderConnectionOverrides {
77    pub fn new(providers: BTreeMap<String, ProviderConnectionOverride>) -> Self {
78        Self { providers }
79    }
80
81    pub fn is_empty(&self) -> bool {
82        self.providers.is_empty()
83    }
84
85    pub fn merge(&mut self, overrides: ProviderConnectionOverrides) {
86        for (provider, override_value) in overrides.providers {
87            self.providers
88                .entry(provider)
89                .and_modify(|existing| existing.merge(override_value.clone()))
90                .or_insert(override_value);
91        }
92    }
93
94    pub fn config_for(&self, provider: &str) -> ProviderConnectionConfig {
95        self.providers.get(provider).cloned().map(ProviderConnectionConfig::from_override).unwrap_or_default()
96    }
97
98    pub fn into_inner(self) -> BTreeMap<String, ProviderConnectionOverride> {
99        self.providers
100    }
101}
102
#[cfg(test)]
mod tests {
    // Covers the serde wire format (field names, strictness, omission of unset
    // fields) and the field-wise merge semantics of overrides.
    use super::*;

    #[test]
    fn deserializes_bedrock_inference_profile_arn() {
        let overrides: ProviderConnectionOverrides = serde_json::from_str(
            r#"{"bedrock":{"inferenceProfileArn":"arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000"}}"#,
        )
        .unwrap();

        let config = overrides.config_for("bedrock");

        assert_eq!(
            config.inference_profile_arn.as_deref(),
            Some("arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000")
        );
    }

    #[test]
    fn deserializes_url_and_auth_wire_names() {
        let overrides: ProviderConnectionOverrides =
            serde_json::from_str(r#"{"local":{"url":"http://localhost:11434","auth":"none"}}"#).unwrap();

        let config = overrides.config_for("local");

        assert_eq!(config.base_url.as_deref(), Some("http://localhost:11434"));
        assert_eq!(config.auth_mode, ProviderAuthMode::None);
        assert_eq!(config.inference_profile_arn, None);
    }

    #[test]
    fn rejects_unknown_override_fields() {
        // `baseUrl` is the Rust field name camel-cased, but the wire name is `url`.
        let result =
            serde_json::from_str::<ProviderConnectionOverrides>(r#"{"local":{"baseUrl":"http://localhost"}}"#);

        assert!(result.is_err());
    }

    #[test]
    fn serialization_omits_unset_fields() {
        let json = serde_json::to_string(&ProviderConnectionOverride::auth(ProviderAuthMode::Default)).unwrap();

        assert_eq!(json, r#"{"auth":"default"}"#);
    }

    #[test]
    fn merge_replaces_inference_profile_arn() {
        let mut first = ProviderConnectionOverride::inference_profile_arn("arn:first");

        first.merge(ProviderConnectionOverride::inference_profile_arn("arn:second"));

        assert_eq!(first.inference_profile_arn.as_deref(), Some("arn:second"));
    }

    #[test]
    fn merge_keeps_fields_unset_in_newer_override() {
        let mut base = ProviderConnectionOverride {
            base_url: Some("http://localhost:8080".to_string()),
            auth_mode: Some(ProviderAuthMode::None),
            inference_profile_arn: Some("arn:first".to_string()),
        };

        base.merge(ProviderConnectionOverride::url("http://localhost:9090"));

        assert_eq!(base.base_url.as_deref(), Some("http://localhost:9090"));
        assert_eq!(base.auth_mode, Some(ProviderAuthMode::None));
        assert_eq!(base.inference_profile_arn.as_deref(), Some("arn:first"));
    }

    #[test]
    fn provider_overrides_merge_inference_profile_arn() {
        let mut first = ProviderConnectionOverrides::new(BTreeMap::from([(
            "bedrock".to_string(),
            ProviderConnectionOverride::inference_profile_arn("arn:first"),
        )]));
        let second = ProviderConnectionOverrides::new(BTreeMap::from([(
            "bedrock".to_string(),
            ProviderConnectionOverride::inference_profile_arn("arn:second"),
        )]));

        first.merge(second);

        assert_eq!(first.config_for("bedrock").inference_profile_arn.as_deref(), Some("arn:second"));
    }

    #[test]
    fn provider_overrides_merge_is_field_wise() {
        let mut first = ProviderConnectionOverrides::new(BTreeMap::from([(
            "bedrock".to_string(),
            ProviderConnectionOverride::inference_profile_arn("arn:first"),
        )]));
        let second = ProviderConnectionOverrides::new(BTreeMap::from([(
            "bedrock".to_string(),
            ProviderConnectionOverride::url("https://bedrock.example"),
        )]));

        first.merge(second);

        let config = first.config_for("bedrock");
        assert_eq!(config.base_url.as_deref(), Some("https://bedrock.example"));
        assert_eq!(config.inference_profile_arn.as_deref(), Some("arn:first"));
    }

    #[test]
    fn provider_overrides_merge_adds_new_providers() {
        let mut first = ProviderConnectionOverrides::default();
        let second = ProviderConnectionOverrides::new(BTreeMap::from([(
            "ollama".to_string(),
            ProviderConnectionOverride::auth(ProviderAuthMode::None),
        )]));

        first.merge(second);

        assert!(!first.is_empty());
        let merged = first.into_inner();
        assert_eq!(merged.get("ollama"), Some(&ProviderConnectionOverride::auth(ProviderAuthMode::None)));
    }

    #[test]
    fn config_for_unknown_provider_is_default() {
        let overrides = ProviderConnectionOverrides::default();

        let config = overrides.config_for("missing");

        assert_eq!(config, ProviderConnectionConfig::default());
        assert_eq!(config.auth_mode, ProviderAuthMode::Default);
    }
}