llm/
provider_connection.rs1use std::collections::BTreeMap;
2
3#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
4#[serde(rename_all = "kebab-case")]
5pub enum ProviderAuthMode {
6 #[default]
7 Default,
8 None,
9}
10
11#[derive(Clone, Debug, Default, PartialEq, Eq)]
12pub struct ProviderConnectionConfig {
13 pub base_url: Option<String>,
14 pub auth_mode: ProviderAuthMode,
15 pub inference_profile_arn: Option<String>,
16}
17
18#[doc = include_str!("docs/provider_connection_override.md")]
19#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
20#[serde(rename_all = "camelCase", deny_unknown_fields)]
21pub struct ProviderConnectionOverride {
22 #[serde(default, rename = "url", skip_serializing_if = "Option::is_none")]
24 pub base_url: Option<String>,
25 #[serde(default, rename = "auth", skip_serializing_if = "Option::is_none")]
28 pub auth_mode: Option<ProviderAuthMode>,
29 #[serde(default, skip_serializing_if = "Option::is_none")]
31 pub inference_profile_arn: Option<String>,
32}
33
34#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
35#[serde(transparent)]
36pub struct ProviderConnectionOverrides {
37 providers: BTreeMap<String, ProviderConnectionOverride>,
38}
39
40impl ProviderConnectionConfig {
41 pub fn from_override(value: ProviderConnectionOverride) -> Self {
42 Self {
43 base_url: value.base_url,
44 auth_mode: value.auth_mode.unwrap_or_default(),
45 inference_profile_arn: value.inference_profile_arn,
46 }
47 }
48}
49
50impl ProviderConnectionOverride {
51 pub fn url(url: impl Into<String>) -> Self {
52 Self { base_url: Some(url.into()), ..Self::default() }
53 }
54
55 pub fn auth(auth_mode: ProviderAuthMode) -> Self {
56 Self { auth_mode: Some(auth_mode), ..Self::default() }
57 }
58
59 pub fn inference_profile_arn(arn: impl Into<String>) -> Self {
60 Self { inference_profile_arn: Some(arn.into()), ..Self::default() }
61 }
62
63 pub fn merge(&mut self, override_value: Self) {
64 if override_value.base_url.is_some() {
65 self.base_url = override_value.base_url;
66 }
67 if override_value.auth_mode.is_some() {
68 self.auth_mode = override_value.auth_mode;
69 }
70 if override_value.inference_profile_arn.is_some() {
71 self.inference_profile_arn = override_value.inference_profile_arn;
72 }
73 }
74}
75
76impl ProviderConnectionOverrides {
77 pub fn new(providers: BTreeMap<String, ProviderConnectionOverride>) -> Self {
78 Self { providers }
79 }
80
81 pub fn is_empty(&self) -> bool {
82 self.providers.is_empty()
83 }
84
85 pub fn merge(&mut self, overrides: ProviderConnectionOverrides) {
86 for (provider, override_value) in overrides.providers {
87 self.providers
88 .entry(provider)
89 .and_modify(|existing| existing.merge(override_value.clone()))
90 .or_insert(override_value);
91 }
92 }
93
94 pub fn config_for(&self, provider: &str) -> ProviderConnectionConfig {
95 self.providers.get(provider).cloned().map(ProviderConnectionConfig::from_override).unwrap_or_default()
96 }
97
98 pub fn into_inner(self) -> BTreeMap<String, ProviderConnectionOverride> {
99 self.providers
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deserializes_bedrock_inference_profile_arn() {
        let overrides: ProviderConnectionOverrides = serde_json::from_str(
            r#"{"bedrock":{"inferenceProfileArn":"arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000"}}"#,
        )
        .unwrap();

        let config = overrides.config_for("bedrock");

        assert_eq!(
            config.inference_profile_arn.as_deref(),
            Some("arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000")
        );
    }

    #[test]
    fn deserializes_url_and_auth_keys() {
        let overrides: ProviderConnectionOverrides =
            serde_json::from_str(r#"{"local":{"url":"http://localhost:8080","auth":"none"}}"#)
                .unwrap();

        let config = overrides.config_for("local");

        assert_eq!(config.base_url.as_deref(), Some("http://localhost:8080"));
        assert_eq!(config.auth_mode, ProviderAuthMode::None);
    }

    #[test]
    fn rejects_unknown_override_fields() {
        let result: Result<ProviderConnectionOverrides, _> =
            serde_json::from_str(r#"{"bedrock":{"bogusField":"x"}}"#);

        assert!(result.is_err());
    }

    #[test]
    fn config_for_unknown_provider_is_default() {
        let overrides = ProviderConnectionOverrides::default();

        assert!(overrides.is_empty());
        assert_eq!(overrides.config_for("bedrock"), ProviderConnectionConfig::default());
        assert_eq!(overrides.config_for("bedrock").auth_mode, ProviderAuthMode::Default);
    }

    #[test]
    fn merge_replaces_inference_profile_arn() {
        let mut first = ProviderConnectionOverride::inference_profile_arn("arn:first");

        first.merge(ProviderConnectionOverride::inference_profile_arn("arn:second"));

        assert_eq!(first.inference_profile_arn.as_deref(), Some("arn:second"));
    }

    #[test]
    fn merge_keeps_fields_the_override_leaves_unset() {
        let mut base = ProviderConnectionOverride::url("https://example.invalid");

        base.merge(ProviderConnectionOverride::auth(ProviderAuthMode::None));

        assert_eq!(base.base_url.as_deref(), Some("https://example.invalid"));
        assert_eq!(base.auth_mode, Some(ProviderAuthMode::None));
        assert_eq!(base.inference_profile_arn, None);
    }

    #[test]
    fn provider_overrides_merge_inference_profile_arn() {
        let mut first = ProviderConnectionOverrides::new(BTreeMap::from([(
            "bedrock".to_string(),
            ProviderConnectionOverride::inference_profile_arn("arn:first"),
        )]));
        let second = ProviderConnectionOverrides::new(BTreeMap::from([(
            "bedrock".to_string(),
            ProviderConnectionOverride::inference_profile_arn("arn:second"),
        )]));

        first.merge(second);

        assert_eq!(first.config_for("bedrock").inference_profile_arn.as_deref(), Some("arn:second"));
    }

    #[test]
    fn provider_overrides_merge_adds_new_providers() {
        let mut first = ProviderConnectionOverrides::new(BTreeMap::from([(
            "bedrock".to_string(),
            ProviderConnectionOverride::inference_profile_arn("arn:first"),
        )]));
        let second = ProviderConnectionOverrides::new(BTreeMap::from([(
            "local".to_string(),
            ProviderConnectionOverride::url("https://example.invalid"),
        )]));

        first.merge(second);

        assert_eq!(first.config_for("bedrock").inference_profile_arn.as_deref(), Some("arn:first"));
        assert_eq!(first.config_for("local").base_url.as_deref(), Some("https://example.invalid"));
    }
}