llm_kit_azure/
settings.rs1use std::collections::HashMap;
2
/// Configuration for the Azure OpenAI provider.
///
/// Build it with [`AzureOpenAIProviderSettings::new`] followed by the `with_*`
/// builder methods. Either `base_url` or `resource_name` must be set (and be
/// non-blank) for `validate` to succeed.
///
/// `Debug` is implemented by hand rather than derived so that the API key and
/// custom header values never end up in logs.
#[derive(Clone)]
pub struct AzureOpenAIProviderSettings {
    /// Azure resource name. When no `base_url` is set, the base URL becomes
    /// `https://{resource_name}.openai.azure.com/openai`.
    pub resource_name: Option<String>,

    /// Explicit base URL. Takes precedence over `resource_name` and is
    /// returned verbatim by `get_base_url`.
    pub base_url: Option<String>,

    /// API key. How it is attached to requests is decided by the code that
    /// consumes these settings (not visible in this module).
    pub api_key: Option<String>,

    /// Extra HTTP headers. `with_headers` replaces the whole map;
    /// `with_header` inserts a single entry.
    pub headers: Option<HashMap<String, String>>,

    /// API version string; defaults to `"v1"`.
    pub api_version: String,

    /// Defaults to `false`. NOTE(review): presumably switches to
    /// deployment-scoped request URLs — confirm against the provider code that
    /// reads this flag.
    pub use_deployment_based_urls: bool,
}

impl std::fmt::Debug for AzureOpenAIProviderSettings {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Header values may carry credentials too (e.g. auth tokens), so only
        // the header names are shown, sorted for deterministic output.
        let header_names = self.headers.as_ref().map(|headers| {
            let mut names: Vec<&str> = headers.keys().map(String::as_str).collect();
            names.sort_unstable();
            names
        });

        f.debug_struct("AzureOpenAIProviderSettings")
            .field("resource_name", &self.resource_name)
            .field("base_url", &self.base_url)
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("headers", &header_names)
            .field("api_version", &self.api_version)
            .field("use_deployment_based_urls", &self.use_deployment_based_urls)
            .finish()
    }
}

impl Default for AzureOpenAIProviderSettings {
    /// All optional fields unset, `api_version` = `"v1"`,
    /// `use_deployment_based_urls` = `false`.
    fn default() -> Self {
        Self {
            resource_name: None,
            base_url: None,
            api_key: None,
            headers: None,
            api_version: "v1".to_string(),
            use_deployment_based_urls: false,
        }
    }
}

impl AzureOpenAIProviderSettings {
    /// Creates settings with the [`Default`] values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the Azure resource name used to derive the base URL.
    pub fn with_resource_name(mut self, resource_name: impl Into<String>) -> Self {
        self.resource_name = Some(resource_name.into());
        self
    }

    /// Sets an explicit base URL; it takes precedence over the resource name.
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    /// Sets the API key.
    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Replaces all custom headers with `headers`.
    pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
        self.headers = Some(headers);
        self
    }

    /// Adds (or overwrites) a single custom header, creating the header map
    /// on first use.
    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers
            .get_or_insert_with(HashMap::new)
            .insert(key.into(), value.into());
        self
    }

    /// Sets the API version string.
    pub fn with_api_version(mut self, api_version: impl Into<String>) -> Self {
        self.api_version = api_version.into();
        self
    }

    /// Sets the `use_deployment_based_urls` flag.
    pub fn with_use_deployment_based_urls(mut self, use_deployment_based_urls: bool) -> Self {
        self.use_deployment_based_urls = use_deployment_based_urls;
        self
    }

    /// Resolves the base URL.
    ///
    /// Returns `base_url` verbatim if set; otherwise derives
    /// `https://{resource_name}.openai.azure.com/openai`; otherwise `None`.
    pub(crate) fn get_base_url(&self) -> Option<String> {
        self.base_url.clone().or_else(|| {
            self.resource_name
                .as_ref()
                .map(|resource_name| format!("https://{resource_name}.openai.azure.com/openai"))
        })
    }

    /// Checks that a usable base URL can be resolved.
    ///
    /// Follows the same precedence as `get_base_url`: a set `base_url` is the
    /// one that matters; otherwise `resource_name` is used.
    ///
    /// # Errors
    ///
    /// Returns a description of the problem when neither `base_url` nor
    /// `resource_name` is provided, or when the value that would be used to
    /// build the URL is empty/whitespace-only (a blank resource name would
    /// otherwise produce `https://.openai.azure.com/openai`).
    pub(crate) fn validate(&self) -> Result<(), String> {
        match (self.base_url.as_deref(), self.resource_name.as_deref()) {
            (Some(base_url), _) if base_url.trim().is_empty() => {
                Err("base_url must not be empty for Azure OpenAI".to_string())
            }
            (Some(_), _) => Ok(()),
            (None, Some(resource_name)) if resource_name.trim().is_empty() => {
                Err("resource_name must not be empty for Azure OpenAI".to_string())
            }
            (None, Some(_)) => Ok(()),
            (None, None) => Err(
                "Either base_url or resource_name must be provided for Azure OpenAI".to_string(),
            ),
        }
    }
}
146
#[cfg(test)]
mod tests {
    use super::*;

    // URLs shared by several tests below.
    const CUSTOM_URL: &str = "https://custom.endpoint.com/openai";
    const RESOURCE_URL: &str = "https://my-resource.openai.azure.com/openai";

    #[test]
    fn test_default_settings() {
        let defaults = AzureOpenAIProviderSettings::new();
        assert!(defaults.base_url.is_none());
        assert!(defaults.resource_name.is_none());
        assert!(!defaults.use_deployment_based_urls);
        assert_eq!(defaults.api_version, "v1");
    }

    #[test]
    fn test_with_resource_name() {
        let cfg = AzureOpenAIProviderSettings::new().with_resource_name("my-resource");
        assert_eq!(cfg.resource_name.as_deref(), Some("my-resource"));
        assert_eq!(cfg.get_base_url().as_deref(), Some(RESOURCE_URL));
    }

    #[test]
    fn test_with_base_url() {
        let cfg = AzureOpenAIProviderSettings::new().with_base_url(CUSTOM_URL);
        assert_eq!(cfg.base_url.as_deref(), Some(CUSTOM_URL));
        assert_eq!(cfg.get_base_url().as_deref(), Some(CUSTOM_URL));
    }

    #[test]
    fn test_base_url_takes_precedence() {
        // Resource name is set first, then overridden by an explicit URL.
        let cfg = AzureOpenAIProviderSettings::new()
            .with_resource_name("my-resource")
            .with_base_url(CUSTOM_URL);
        assert_eq!(cfg.get_base_url().as_deref(), Some(CUSTOM_URL));
    }

    #[test]
    fn test_with_api_key() {
        let cfg = AzureOpenAIProviderSettings::new().with_api_key("test-key");
        assert_eq!(cfg.api_key.as_deref(), Some("test-key"));
    }

    #[test]
    fn test_with_api_version() {
        let version = "2024-02-15-preview";
        let cfg = AzureOpenAIProviderSettings::new().with_api_version(version);
        assert_eq!(cfg.api_version, version);
    }

    #[test]
    fn test_with_headers() {
        let expected: HashMap<String, String> =
            [("X-Custom".to_string(), "value".to_string())].into_iter().collect();
        let cfg = AzureOpenAIProviderSettings::new().with_headers(expected.clone());
        assert_eq!(cfg.headers, Some(expected));
    }

    #[test]
    fn test_with_header() {
        let cfg = AzureOpenAIProviderSettings::new()
            .with_header("X-Custom-1", "value1")
            .with_header("X-Custom-2", "value2");

        let map = cfg.headers.unwrap();
        for (name, value) in [("X-Custom-1", "value1"), ("X-Custom-2", "value2")] {
            assert_eq!(map.get(name).map(String::as_str), Some(value));
        }
    }

    #[test]
    fn test_with_deployment_based_urls() {
        let cfg = AzureOpenAIProviderSettings::new().with_use_deployment_based_urls(true);
        assert!(cfg.use_deployment_based_urls);
    }

    #[test]
    fn test_validate_success() {
        let by_resource = AzureOpenAIProviderSettings::new().with_resource_name("my-resource");
        let by_url =
            AzureOpenAIProviderSettings::new().with_base_url("https://custom.endpoint.com");
        assert!(by_resource.validate().is_ok());
        assert!(by_url.validate().is_ok());
    }

    #[test]
    fn test_validate_failure() {
        assert!(AzureOpenAIProviderSettings::new().validate().is_err());
    }

    #[test]
    fn test_builder_pattern() {
        let cfg = AzureOpenAIProviderSettings::new()
            .with_resource_name("test-resource")
            .with_api_key("test-key")
            .with_api_version("2024-02-15-preview")
            .with_use_deployment_based_urls(true)
            .with_header("X-Custom", "value");

        assert_eq!(cfg.resource_name.as_deref(), Some("test-resource"));
        assert_eq!(cfg.api_key.as_deref(), Some("test-key"));
        assert_eq!(cfg.api_version, "2024-02-15-preview");
        assert!(cfg.use_deployment_based_urls);

        let headers = cfg.headers.expect("headers should be set by with_header");
        assert_eq!(headers.get("X-Custom").map(String::as_str), Some("value"));
    }
}