// llm/provider_connection.rs

use std::collections::btree_map::Entry;
use std::collections::BTreeMap;
2
3#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
4#[serde(rename_all = "kebab-case")]
5pub enum ProviderAuthMode {
6    #[default]
7    Default,
8    None,
9}
10
11#[derive(Clone, Debug, Default, PartialEq, Eq)]
12pub struct ProviderConnectionConfig {
13    pub base_url: Option<String>,
14    pub auth_mode: ProviderAuthMode,
15    pub inference_profile_arn: Option<String>,
16}
17
18#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
19#[serde(rename_all = "camelCase", deny_unknown_fields)]
20pub struct ProviderConnectionOverride {
21    #[serde(default, rename = "url", skip_serializing_if = "Option::is_none")]
22    pub base_url: Option<String>,
23    #[serde(default, rename = "auth", skip_serializing_if = "Option::is_none")]
24    pub auth_mode: Option<ProviderAuthMode>,
25    #[serde(default, skip_serializing_if = "Option::is_none")]
26    pub inference_profile_arn: Option<String>,
27}
28
29#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
30#[serde(transparent)]
31pub struct ProviderConnectionOverrides {
32    providers: BTreeMap<String, ProviderConnectionOverride>,
33}
34
35impl ProviderConnectionConfig {
36    pub fn from_override(value: ProviderConnectionOverride) -> Self {
37        Self {
38            base_url: value.base_url,
39            auth_mode: value.auth_mode.unwrap_or_default(),
40            inference_profile_arn: value.inference_profile_arn,
41        }
42    }
43}
44
45impl ProviderConnectionOverride {
46    pub fn url(url: impl Into<String>) -> Self {
47        Self { base_url: Some(url.into()), ..Self::default() }
48    }
49
50    pub fn auth(auth_mode: ProviderAuthMode) -> Self {
51        Self { auth_mode: Some(auth_mode), ..Self::default() }
52    }
53
54    pub fn inference_profile_arn(arn: impl Into<String>) -> Self {
55        Self { inference_profile_arn: Some(arn.into()), ..Self::default() }
56    }
57
58    pub fn merge(&mut self, override_value: Self) {
59        if override_value.base_url.is_some() {
60            self.base_url = override_value.base_url;
61        }
62        if override_value.auth_mode.is_some() {
63            self.auth_mode = override_value.auth_mode;
64        }
65        if override_value.inference_profile_arn.is_some() {
66            self.inference_profile_arn = override_value.inference_profile_arn;
67        }
68    }
69}
70
71impl ProviderConnectionOverrides {
72    pub fn new(providers: BTreeMap<String, ProviderConnectionOverride>) -> Self {
73        Self { providers }
74    }
75
76    pub fn is_empty(&self) -> bool {
77        self.providers.is_empty()
78    }
79
80    pub fn merge(&mut self, overrides: ProviderConnectionOverrides) {
81        for (provider, override_value) in overrides.providers {
82            self.providers
83                .entry(provider)
84                .and_modify(|existing| existing.merge(override_value.clone()))
85                .or_insert(override_value);
86        }
87    }
88
89    pub fn config_for(&self, provider: &str) -> ProviderConnectionConfig {
90        self.providers.get(provider).cloned().map(ProviderConnectionConfig::from_override).unwrap_or_default()
91    }
92
93    pub fn into_inner(self) -> BTreeMap<String, ProviderConnectionOverride> {
94        self.providers
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deserializes_bedrock_inference_profile_arn() {
        let overrides: ProviderConnectionOverrides = serde_json::from_str(
            r#"{"bedrock":{"inferenceProfileArn":"arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000"}}"#,
        )
        .unwrap();

        let config = overrides.config_for("bedrock");

        assert_eq!(
            config.inference_profile_arn.as_deref(),
            Some("arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000")
        );
    }

    // The struct fields are renamed on the wire (`url`, `auth`); pin those names.
    #[test]
    fn deserializes_url_and_auth_keys() {
        let overrides: ProviderConnectionOverrides = serde_json::from_str(
            r#"{"openai":{"url":"http://localhost:8080/v1","auth":"none"}}"#,
        )
        .unwrap();

        let config = overrides.config_for("openai");

        assert_eq!(config.base_url.as_deref(), Some("http://localhost:8080/v1"));
        assert_eq!(config.auth_mode, ProviderAuthMode::None);
        assert_eq!(config.inference_profile_arn, None);
    }

    // `deny_unknown_fields` plus the explicit `url` rename means the field name is not accepted.
    #[test]
    fn rejects_unknown_override_fields() {
        let result = serde_json::from_str::<ProviderConnectionOverrides>(
            r#"{"openai":{"baseUrl":"http://localhost"}}"#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn config_for_unknown_provider_is_default() {
        let overrides = ProviderConnectionOverrides::default();

        assert!(overrides.is_empty());
        assert_eq!(overrides.config_for("missing"), ProviderConnectionConfig::default());
    }

    #[test]
    fn merge_replaces_inference_profile_arn() {
        let mut first = ProviderConnectionOverride::inference_profile_arn("arn:first");

        first.merge(ProviderConnectionOverride::inference_profile_arn("arn:second"));

        assert_eq!(first.inference_profile_arn.as_deref(), Some("arn:second"));
    }

    #[test]
    fn merge_keeps_fields_absent_from_override() {
        let mut base = ProviderConnectionOverride::url("http://localhost:8080");

        base.merge(ProviderConnectionOverride::auth(ProviderAuthMode::None));

        assert_eq!(base.base_url.as_deref(), Some("http://localhost:8080"));
        assert_eq!(base.auth_mode, Some(ProviderAuthMode::None));
        assert_eq!(base.inference_profile_arn, None);
    }

    #[test]
    fn provider_overrides_merge_inference_profile_arn() {
        let mut first = ProviderConnectionOverrides::new(BTreeMap::from([(
            "bedrock".to_string(),
            ProviderConnectionOverride::inference_profile_arn("arn:first"),
        )]));
        let second = ProviderConnectionOverrides::new(BTreeMap::from([(
            "bedrock".to_string(),
            ProviderConnectionOverride::inference_profile_arn("arn:second"),
        )]));

        first.merge(second);

        assert_eq!(first.config_for("bedrock").inference_profile_arn.as_deref(), Some("arn:second"));
    }

    // Exercises the path where the incoming provider has no existing entry.
    #[test]
    fn provider_overrides_merge_inserts_new_providers() {
        let mut first = ProviderConnectionOverrides::default();
        let second = ProviderConnectionOverrides::new(BTreeMap::from([(
            "openai".to_string(),
            ProviderConnectionOverride::url("http://localhost:8080"),
        )]));

        first.merge(second);

        assert!(!first.is_empty());
        assert_eq!(first.config_for("openai").base_url.as_deref(), Some("http://localhost:8080"));
        assert_eq!(first.into_inner().len(), 1);
    }
}