use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AzureConfig {
pub api_key: Option<String>,
pub azure_endpoint: Option<String>,
pub api_version: String,
pub azure_ad_token_provider: Option<String>,
pub deployment_name: Option<String>,
pub resource_group: Option<String>,
pub subscription_id: Option<String>,
pub custom_headers: HashMap<String, String>,
}
impl Default for AzureConfig {
fn default() -> Self {
Self {
api_key: None,
azure_endpoint: None,
api_version: "2024-02-01".to_string(),
azure_ad_token_provider: None,
deployment_name: None,
resource_group: None,
subscription_id: None,
custom_headers: HashMap::new(),
}
}
}
impl AzureConfig {
pub fn new() -> Self {
Self::default()
}
pub fn with_api_key(mut self, api_key: String) -> Self {
self.api_key = Some(api_key);
self
}
pub fn with_azure_endpoint(mut self, endpoint: String) -> Self {
self.azure_endpoint = Some(endpoint);
self
}
pub fn with_api_version(mut self, version: String) -> Self {
self.api_version = version;
self
}
pub fn with_deployment_name(mut self, deployment: String) -> Self {
self.deployment_name = Some(deployment);
self
}
pub async fn get_effective_api_key(&self) -> Option<String> {
if let Some(key) = &self.api_key {
return Some(key.clone());
}
if let Ok(key) = std::env::var("AZURE_OPENAI_KEY") {
return Some(key);
}
if let Ok(key) = std::env::var("AZURE_API_KEY") {
return Some(key);
}
if self.azure_ad_token_provider.is_some() {
return None;
}
None
}
pub fn get_effective_azure_endpoint(&self) -> Option<String> {
self.azure_endpoint
.clone()
.or_else(|| std::env::var("AZURE_OPENAI_ENDPOINT").ok())
.or_else(|| std::env::var("AZURE_ENDPOINT").ok())
}
pub fn get_effective_deployment_name(&self, model: &str) -> String {
self.deployment_name
.clone()
.or_else(|| std::env::var("AZURE_DEPLOYMENT_NAME").ok())
.unwrap_or_else(|| model.to_string())
}
}
impl crate::core::traits::provider::ProviderConfig for AzureConfig {
fn validate(&self) -> Result<(), String> {
if self.get_effective_azure_endpoint().is_none() {
return Err("Azure endpoint is required".to_string());
}
if self.api_version.is_empty() {
return Err("API version is required".to_string());
}
Ok(())
}
fn api_key(&self) -> Option<&str> {
self.api_key.as_deref()
}
fn api_base(&self) -> Option<&str> {
self.azure_endpoint.as_deref()
}
fn timeout(&self) -> std::time::Duration {
std::time::Duration::from_secs(60) }
fn max_retries(&self) -> u32 {
3 }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AzureModelInfo {
pub deployment_name: String,
pub model_name: String,
pub max_tokens: Option<u32>,
pub supports_functions: bool,
pub supports_streaming: bool,
pub api_version: String,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_azure_config_default() {
let config = AzureConfig::default();
assert!(config.api_key.is_none());
assert!(config.azure_endpoint.is_none());
assert_eq!(config.api_version, "2024-02-01");
assert!(config.deployment_name.is_none());
}
#[test]
fn test_azure_config_builder() {
let config = AzureConfig::new()
.with_api_key("test-key".to_string())
.with_azure_endpoint("https://test.openai.azure.com".to_string())
.with_deployment_name("gpt-4".to_string())
.with_api_version("2024-03-01".to_string());
assert_eq!(config.api_key, Some("test-key".to_string()));
assert_eq!(
config.azure_endpoint,
Some("https://test.openai.azure.com".to_string())
);
assert_eq!(config.deployment_name, Some("gpt-4".to_string()));
assert_eq!(config.api_version, "2024-03-01");
}
#[test]
fn test_azure_config_effective_deployment_name() {
let config = AzureConfig::new().with_deployment_name("my-deployment".to_string());
assert_eq!(
config.get_effective_deployment_name("gpt-4"),
"my-deployment"
);
let config_no_deployment = AzureConfig::new();
assert_eq!(
config_no_deployment.get_effective_deployment_name("gpt-4"),
"gpt-4"
);
}
#[test]
fn test_azure_config_effective_endpoint() {
let config =
AzureConfig::new().with_azure_endpoint("https://test.openai.azure.com".to_string());
assert_eq!(
config.get_effective_azure_endpoint(),
Some("https://test.openai.azure.com".to_string())
);
}
#[test]
fn test_azure_config_validation() {
use crate::core::traits::provider::ProviderConfig;
let config = AzureConfig::new();
assert!(config.validate().is_err());
let config_with_endpoint =
AzureConfig::new().with_azure_endpoint("https://test.openai.azure.com".to_string());
assert!(config_with_endpoint.validate().is_ok());
}
#[test]
fn test_azure_config_provider_config_trait() {
use crate::core::traits::provider::ProviderConfig;
let config = AzureConfig::new()
.with_api_key("test-key".to_string())
.with_azure_endpoint("https://test.openai.azure.com".to_string());
assert_eq!(config.api_key(), Some("test-key"));
assert_eq!(config.api_base(), Some("https://test.openai.azure.com"));
assert_eq!(config.timeout(), std::time::Duration::from_secs(60));
assert_eq!(config.max_retries(), 3);
}
#[test]
fn test_azure_model_info() {
let model_info = AzureModelInfo {
deployment_name: "gpt-4-deployment".to_string(),
model_name: "gpt-4".to_string(),
max_tokens: Some(8192),
supports_functions: true,
supports_streaming: true,
api_version: "2024-02-01".to_string(),
};
assert_eq!(model_info.deployment_name, "gpt-4-deployment");
assert_eq!(model_info.model_name, "gpt-4");
assert!(model_info.supports_functions);
assert!(model_info.supports_streaming);
}
}