llm/
provider_connection.rs1use std::collections::BTreeMap;
2
3#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
4#[serde(rename_all = "kebab-case")]
5pub enum ProviderAuthMode {
6 #[default]
7 Default,
8 None,
9}
10
11#[derive(Clone, Debug, Default, PartialEq, Eq)]
12pub struct ProviderConnectionConfig {
13 pub base_url: Option<String>,
14 pub auth_mode: ProviderAuthMode,
15 pub inference_profile_arn: Option<String>,
16}
17
18#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
19#[serde(rename_all = "camelCase", deny_unknown_fields)]
20pub struct ProviderConnectionOverride {
21 #[serde(default, rename = "url", skip_serializing_if = "Option::is_none")]
22 pub base_url: Option<String>,
23 #[serde(default, rename = "auth", skip_serializing_if = "Option::is_none")]
24 pub auth_mode: Option<ProviderAuthMode>,
25 #[serde(default, skip_serializing_if = "Option::is_none")]
26 pub inference_profile_arn: Option<String>,
27}
28
29#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
30#[serde(transparent)]
31pub struct ProviderConnectionOverrides {
32 providers: BTreeMap<String, ProviderConnectionOverride>,
33}
34
35impl ProviderConnectionConfig {
36 pub fn from_override(value: ProviderConnectionOverride) -> Self {
37 Self {
38 base_url: value.base_url,
39 auth_mode: value.auth_mode.unwrap_or_default(),
40 inference_profile_arn: value.inference_profile_arn,
41 }
42 }
43}
44
45impl ProviderConnectionOverride {
46 pub fn url(url: impl Into<String>) -> Self {
47 Self { base_url: Some(url.into()), ..Self::default() }
48 }
49
50 pub fn auth(auth_mode: ProviderAuthMode) -> Self {
51 Self { auth_mode: Some(auth_mode), ..Self::default() }
52 }
53
54 pub fn inference_profile_arn(arn: impl Into<String>) -> Self {
55 Self { inference_profile_arn: Some(arn.into()), ..Self::default() }
56 }
57
58 pub fn merge(&mut self, override_value: Self) {
59 if override_value.base_url.is_some() {
60 self.base_url = override_value.base_url;
61 }
62 if override_value.auth_mode.is_some() {
63 self.auth_mode = override_value.auth_mode;
64 }
65 if override_value.inference_profile_arn.is_some() {
66 self.inference_profile_arn = override_value.inference_profile_arn;
67 }
68 }
69}
70
71impl ProviderConnectionOverrides {
72 pub fn new(providers: BTreeMap<String, ProviderConnectionOverride>) -> Self {
73 Self { providers }
74 }
75
76 pub fn is_empty(&self) -> bool {
77 self.providers.is_empty()
78 }
79
80 pub fn merge(&mut self, overrides: ProviderConnectionOverrides) {
81 for (provider, override_value) in overrides.providers {
82 self.providers
83 .entry(provider)
84 .and_modify(|existing| existing.merge(override_value.clone()))
85 .or_insert(override_value);
86 }
87 }
88
89 pub fn config_for(&self, provider: &str) -> ProviderConnectionConfig {
90 self.providers.get(provider).cloned().map(ProviderConnectionConfig::from_override).unwrap_or_default()
91 }
92
93 pub fn into_inner(self) -> BTreeMap<String, ProviderConnectionOverride> {
94 self.providers
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    /// Inference profile ARN used by the Bedrock fixtures below.
    const BEDROCK_PROFILE_ARN: &str =
        "arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000";

    #[test]
    fn deserializes_bedrock_inference_profile_arn() {
        let overrides: ProviderConnectionOverrides = serde_json::from_str(
            r#"{"bedrock":{"inferenceProfileArn":"arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000"}}"#,
        )
        .unwrap();

        let config = overrides.config_for("bedrock");

        assert_eq!(config.inference_profile_arn.as_deref(), Some(BEDROCK_PROFILE_ARN));
    }

    #[test]
    fn deserializes_renamed_url_and_auth_keys() {
        let overrides: ProviderConnectionOverrides =
            serde_json::from_str(r#"{"local":{"url":"http://localhost:8080","auth":"none"}}"#).unwrap();

        let config = overrides.config_for("local");

        assert_eq!(config.base_url.as_deref(), Some("http://localhost:8080"));
        assert_eq!(config.auth_mode, ProviderAuthMode::None);
        assert_eq!(config.inference_profile_arn, None);
    }

    #[test]
    fn rejects_unknown_override_fields() {
        // The field is renamed to `url`, so the camelCase spelling is unknown.
        let result = serde_json::from_str::<ProviderConnectionOverrides>(r#"{"bedrock":{"baseUrl":"http://x"}}"#);

        assert!(result.is_err());
    }

    #[test]
    fn serializes_only_set_fields() {
        let json = serde_json::to_string(&ProviderConnectionOverride::url("http://x")).unwrap();

        assert_eq!(json, r#"{"url":"http://x"}"#);
    }

    #[test]
    fn merge_replaces_inference_profile_arn() {
        let mut first = ProviderConnectionOverride::inference_profile_arn("arn:first");

        first.merge(ProviderConnectionOverride::inference_profile_arn("arn:second"));

        assert_eq!(first.inference_profile_arn.as_deref(), Some("arn:second"));
    }

    #[test]
    fn merge_keeps_fields_not_set_by_override() {
        let mut base = ProviderConnectionOverride::url("http://x");

        base.merge(ProviderConnectionOverride::auth(ProviderAuthMode::None));

        assert_eq!(base.base_url.as_deref(), Some("http://x"));
        assert_eq!(base.auth_mode, Some(ProviderAuthMode::None));
    }

    #[test]
    fn provider_overrides_merge_inference_profile_arn() {
        let mut first = ProviderConnectionOverrides::new(BTreeMap::from([(
            "bedrock".to_string(),
            ProviderConnectionOverride::inference_profile_arn("arn:first"),
        )]));
        let second = ProviderConnectionOverrides::new(BTreeMap::from([(
            "bedrock".to_string(),
            ProviderConnectionOverride::inference_profile_arn("arn:second"),
        )]));

        first.merge(second);

        assert_eq!(first.config_for("bedrock").inference_profile_arn.as_deref(), Some("arn:second"));
    }

    #[test]
    fn provider_overrides_merge_inserts_new_providers() {
        let mut overrides = ProviderConnectionOverrides::default();

        overrides.merge(ProviderConnectionOverrides::new(BTreeMap::from([(
            "bedrock".to_string(),
            ProviderConnectionOverride::url("http://x"),
        )])));

        assert_eq!(overrides.config_for("bedrock").base_url.as_deref(), Some("http://x"));
        assert_eq!(overrides.into_inner().len(), 1);
    }

    #[test]
    fn config_for_unknown_provider_is_default() {
        let overrides = ProviderConnectionOverrides::default();

        assert!(overrides.is_empty());
        assert_eq!(overrides.config_for("bedrock"), ProviderConnectionConfig::default());
    }
}