use std::collections::HashMap;
use crate::core::traits::provider::ProviderConfig;
use crate::define_provider_config;
// Declares `AzureAIConfig` through the shared provider-config macro.
// NOTE(review): the empty braces presumably add no Azure-specific fields — everything in this
// file reaches the shared settings through `self.base` (`api_base`, `api_key`, `timeout`,
// `max_retries`, `timeout_duration()`, `validate(name)`). The macro also appears to generate
// `new(provider_name)` and a `Default` impl (both are used by the tests below); confirm against
// the macro definition.
define_provider_config!(AzureAIConfig {});
impl AzureAIConfig {
    /// Builds an `azure_ai` config and fills any unset connection fields from the environment.
    ///
    /// * API base: `AZURE_AI_API_BASE`, falling back to `AZURE_AI_ENDPOINT`.
    /// * API key: `AZURE_AI_API_KEY`, falling back to `AZURE_API_KEY`.
    ///
    /// Fields already populated by `Self::new` are left untouched. Variables that are unset,
    /// not valid Unicode, or blank (empty / whitespace-only) are skipped, so an empty primary
    /// variable no longer shadows a usable fallback.
    pub fn from_env() -> Self {
        let mut config = Self::new("azure_ai");
        if config.base.api_base.is_none() {
            config.base.api_base =
                Self::first_non_empty_env(&["AZURE_AI_API_BASE", "AZURE_AI_ENDPOINT"]);
        }
        if config.base.api_key.is_none() {
            config.base.api_key = Self::first_non_empty_env(&["AZURE_AI_API_KEY", "AZURE_API_KEY"]);
        }
        config
    }

    /// Returns the value of the first variable in `names` that is set to a non-blank,
    /// valid-Unicode string, or `None` when no such variable exists.
    fn first_non_empty_env(names: &[&str]) -> Option<String> {
        names
            .iter()
            .find_map(|&name| std::env::var(name).ok().filter(|value| !value.trim().is_empty()))
    }

    /// Joins `path` onto the configured API base with exactly one `/` between them.
    ///
    /// Trailing slashes on the base and leading slashes on `path` are stripped first, so
    /// `"https://host/"` + `"/chat/completions"` yields `"https://host/chat/completions"`.
    ///
    /// # Errors
    ///
    /// Returns an error if the API base is unset or blank.
    pub fn build_endpoint_url(&self, path: &str) -> Result<String, String> {
        let base = self
            .base
            .api_base
            .as_deref()
            .map(|url| url.trim_end_matches('/'))
            // A blank base would silently produce a host-less URL such as "/chat/completions".
            .filter(|url| !url.trim().is_empty())
            .ok_or_else(|| "Azure AI API base URL not set".to_string())?;
        Ok(format!("{}/{}", base, path.trim_start_matches('/')))
    }

    /// Builds the default request headers: bearer `Authorization`, JSON `Content-Type`,
    /// `User-Agent` and `api-version`.
    ///
    /// # Errors
    ///
    /// Returns an error if the API key is unset or blank.
    pub fn create_default_headers(&self) -> Result<HashMap<String, String>, String> {
        let api_key = self
            .base
            .api_key
            .as_deref()
            // A blank key would otherwise be sent as the meaningless header "Bearer ".
            .filter(|key| !key.trim().is_empty())
            .ok_or_else(|| "Azure AI API key not set".to_string())?;
        let mut headers = HashMap::with_capacity(4);
        headers.insert("Authorization".to_string(), format!("Bearer {}", api_key));
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert("User-Agent".to_string(), "litellm-rust/0.1.0".to_string());
        // NOTE(review): the Azure AI inference REST API documents `api-version` as a query
        // parameter; confirm the target endpoints also honour it as a header.
        headers.insert("api-version".to_string(), "2024-05-01-preview".to_string());
        Ok(headers)
    }

    /// Request timeout, taken from the shared base config.
    pub fn timeout(&self) -> std::time::Duration {
        self.base.timeout_duration()
    }

    /// Validates the shared base config under the `azure_ai` provider name.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the base config's validation reports.
    pub fn validate(&self) -> Result<(), String> {
        self.base.validate("azure_ai")
    }
}
/// [`ProviderConfig`] view of [`AzureAIConfig`]: every accessor reads straight from the shared
/// base settings, and validation runs under the `azure_ai` provider name.
impl ProviderConfig for AzureAIConfig {
    /// Configured API base URL, borrowed from the base settings.
    fn api_base(&self) -> Option<&str> {
        self.base.api_base.as_deref()
    }

    /// Configured API key, borrowed from the base settings.
    fn api_key(&self) -> Option<&str> {
        self.base.api_key.as_deref()
    }

    /// Retry budget copied from the base settings.
    fn max_retries(&self) -> u32 {
        self.base.max_retries
    }

    /// Request timeout computed by the base settings.
    fn timeout(&self) -> std::time::Duration {
        self.base.timeout_duration()
    }

    /// Runs the shared base validation for the `azure_ai` provider.
    fn validate(&self) -> Result<(), String> {
        self.base.validate("azure_ai")
    }
}
/// Azure AI inference operations this provider can target.
///
/// Each variant maps to a URL path segment via `as_path`. The enum is fieldless, so it is
/// `Copy` (cheap to pass by value) and `Hash` (usable as a `HashMap`/`HashSet` key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AzureAIEndpointType {
    /// Chat completion requests.
    ChatCompletions,
    /// Embedding requests.
    Embeddings,
    /// Image embedding requests.
    ImageEmbeddings,
    /// Image generation requests.
    ImageGeneration,
    /// Rerank requests.
    Rerank,
}
impl AzureAIEndpointType {
pub fn as_path(&self) -> &'static str {
match self {
AzureAIEndpointType::ChatCompletions => "chat/completions",
AzureAIEndpointType::Embeddings => "embeddings",
AzureAIEndpointType::ImageEmbeddings => "images/embeddings",
AzureAIEndpointType::ImageGeneration => "images/generations",
AzureAIEndpointType::Rerank => "rerank",
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Fresh `azure_ai` config with the given key/base layered on top of `new`'s defaults.
    /// A `None` argument leaves the corresponding field exactly as `new` produced it.
    fn config_with(api_key: Option<&str>, api_base: Option<&str>) -> AzureAIConfig {
        let mut cfg = AzureAIConfig::new("azure_ai");
        if let Some(key) = api_key {
            cfg.base.api_key = Some(key.to_string());
        }
        if let Some(base) = api_base {
            cfg.base.api_base = Some(base.to_string());
        }
        cfg
    }

    #[test]
    fn test_azure_ai_config() {
        let cfg = AzureAIConfig::new("azure_ai");
        assert_eq!(cfg.base.max_retries, 3);
        assert_eq!(cfg.base.timeout, 60);
    }

    #[test]
    fn test_endpoint_types() {
        let cases = [
            (AzureAIEndpointType::ChatCompletions, "chat/completions"),
            (AzureAIEndpointType::Embeddings, "embeddings"),
            (AzureAIEndpointType::ImageGeneration, "images/generations"),
            (AzureAIEndpointType::Rerank, "rerank"),
        ];
        for (endpoint, expected) in cases.iter() {
            assert_eq!(endpoint.as_path(), *expected);
        }
    }

    #[test]
    fn test_build_endpoint_url() {
        let mut cfg = config_with(None, Some("https://test.ai.azure.com"));
        assert_eq!(
            cfg.build_endpoint_url("chat/completions").unwrap(),
            "https://test.ai.azure.com/chat/completions"
        );
        // Trailing slash on the base and leading slash on the path must collapse to one.
        cfg.base.api_base = Some("https://test.ai.azure.com/".to_string());
        assert_eq!(
            cfg.build_endpoint_url("/chat/completions").unwrap(),
            "https://test.ai.azure.com/chat/completions"
        );
    }

    #[test]
    fn test_build_endpoint_url_no_base() {
        let err = AzureAIConfig::default()
            .build_endpoint_url("chat/completions")
            .expect_err("a config without an API base must be rejected");
        assert!(err.contains("API base URL not set"));
    }

    #[test]
    fn test_create_default_headers_with_api_key() {
        let headers = config_with(Some("test-api-key"), None)
            .create_default_headers()
            .unwrap();
        let expected = [
            ("Authorization", "Bearer test-api-key"),
            ("Content-Type", "application/json"),
            ("User-Agent", "litellm-rust/0.1.0"),
            ("api-version", "2024-05-01-preview"),
        ];
        for (name, value) in expected.iter() {
            assert_eq!(headers.get(*name).map(String::as_str), Some(*value));
        }
    }

    #[test]
    fn test_create_default_headers_no_api_key() {
        let err = AzureAIConfig::new("azure_ai")
            .create_default_headers()
            .expect_err("a config without an API key must be rejected");
        assert!(err.contains("API key not set"));
    }

    #[test]
    fn test_timeout() {
        assert_eq!(
            AzureAIConfig::new("azure_ai").timeout(),
            Duration::from_secs(60)
        );
    }

    #[test]
    fn test_validate_success() {
        let cfg = config_with(Some("test-key"), Some("https://test.com"));
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_validate_missing_api_key() {
        let cfg = config_with(None, Some("https://test.com"));
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_image_embeddings_endpoint() {
        assert_eq!(
            AzureAIEndpointType::ImageEmbeddings.as_path(),
            "images/embeddings"
        );
    }

    #[test]
    fn test_endpoint_type_equality() {
        let chat = AzureAIEndpointType::ChatCompletions;
        assert_eq!(chat, AzureAIEndpointType::ChatCompletions);
        assert_ne!(chat, AzureAIEndpointType::Embeddings);
    }

    #[test]
    fn test_provider_config_trait() {
        let cfg = config_with(Some("test-key"), Some("https://test.com"));
        assert_eq!(cfg.api_key(), Some("test-key"));
        assert_eq!(cfg.api_base(), Some("https://test.com"));
        // Fully qualified: an inherent `timeout` also exists and would win method lookup.
        assert_eq!(ProviderConfig::timeout(&cfg), Duration::from_secs(60));
        assert_eq!(cfg.max_retries(), 3);
    }
}