use super::{BaseProvider, ModelPricing, Provider, ProviderError, ProviderType};
use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::utils::error::Result;
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use tracing::{debug, info};
/// Provider backed by the Azure OpenAI Service.
///
/// Requests are sent to `openai/deployments/{deployment}/...?api-version={api_version}`
/// endpoints through the shared `BaseProvider`.
#[derive(Debug, Clone)]
pub struct AzureProvider {
/// Shared provider state: name, API key, supported-model list and the HTTP
/// request/response helpers used by the `Provider` impl.
base: BaseProvider,
/// Value sent as the `api-version` query parameter on every request URL.
api_version: String,
/// Fixed deployment to target; when `Some`, it is used in the URL path instead of
/// the request's `model`. Populated from `ProviderConfig::organization` in `new`.
deployment_name: Option<String>,
/// Per-model pricing entries keyed by model name; `new` initializes this empty.
pricing_cache: HashMap<String, ModelPricing>,
}
impl AzureProvider {
pub async fn new(config: &ProviderConfig) -> Result<Self> {
let base = BaseProvider::new(config)?;
let provider = Self {
base,
api_version: "2023-12-01-preview".to_string(), deployment_name: config.organization.as_ref().cloned(),
pricing_cache: HashMap::new(), };
info!("Azure provider '{}' initialized successfully", config.name);
Ok(provider)
}
fn create_headers(&self) -> reqwest::header::HeaderMap {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("api-key", self.base.api_key.parse().unwrap());
headers.insert(
reqwest::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
headers
}
}
#[async_trait]
impl Provider for AzureProvider {
fn name(&self) -> &str {
&self.base.name
}
fn provider_type(&self) -> ProviderType {
ProviderType::Azure
}
async fn supports_model(&self, model: &str) -> bool {
self.base.is_model_supported(model)
}
async fn supports_images(&self) -> bool {
true
}
async fn supports_embeddings(&self) -> bool {
true
}
async fn supports_streaming(&self) -> bool {
true
}
async fn list_models(&self) -> Result<Vec<Model>> {
let models = self
.base
.get_supported_models()
.iter()
.map(|model| Model {
id: model.clone(),
object: "model".to_string(),
created: chrono::Utc::now().timestamp() as u64,
owned_by: "azure".to_string(),
})
.collect();
Ok(models)
}
async fn health_check(&self) -> Result<()> {
debug!("Performing Azure health check");
Ok(())
}
async fn chat_completion(
&self,
request: ChatCompletionRequest,
_context: RequestContext,
) -> Result<ChatCompletionResponse> {
debug!("Azure chat completion for model: {}", request.model);
let deployment = self.deployment_name.as_ref().unwrap_or(&request.model);
let endpoint = format!(
"openai/deployments/{}/chat/completions?api-version={}",
deployment, self.api_version
);
let body = json!({
"messages": request.messages,
"max_tokens": request.max_tokens,
"temperature": request.temperature,
"top_p": request.top_p,
"n": request.n,
"stream": request.stream,
"stop": request.stop,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"user": request.user
});
let response = self
.base
.make_request(reqwest::Method::POST, &endpoint, Some(body))
.await?;
let chat_response: ChatCompletionResponse = self.base.parse_json_response(response).await?;
Ok(chat_response)
}
async fn completion(
&self,
request: CompletionRequest,
_context: RequestContext,
) -> Result<CompletionResponse> {
debug!("Azure completion for model: {}", request.model);
let deployment = self.deployment_name.as_ref().unwrap_or(&request.model);
let endpoint = format!(
"openai/deployments/{}/completions?api-version={}",
deployment, self.api_version
);
let body = json!({
"prompt": request.prompt,
"max_tokens": request.max_tokens,
"temperature": request.temperature,
"top_p": request.top_p,
"n": request.n,
"stream": request.stream,
"stop": request.stop,
"user": request.user
});
let response = self
.base
.make_request(reqwest::Method::POST, &endpoint, Some(body))
.await?;
let completion_response: CompletionResponse =
self.base.parse_json_response(response).await?;
Ok(completion_response)
}
async fn embedding(
&self,
request: EmbeddingRequest,
_context: RequestContext,
) -> Result<EmbeddingResponse> {
debug!("Azure embedding for model: {}", request.model);
let deployment = self.deployment_name.as_ref().unwrap_or(&request.model);
let endpoint = format!(
"openai/deployments/{}/embeddings?api-version={}",
deployment, self.api_version
);
let body = json!({
"input": request.input,
"user": request.user
});
let response = self
.base
.make_request(reqwest::Method::POST, &endpoint, Some(body))
.await?;
let embedding_response: EmbeddingResponse = self.base.parse_json_response(response).await?;
Ok(embedding_response)
}
async fn image_generation(
&self,
_request: ImageGenerationRequest,
_context: RequestContext,
) -> Result<ImageGenerationResponse> {
Err(
ProviderError::InvalidRequest("Azure image generation not implemented yet".to_string())
.into(),
)
}
async fn get_model_pricing(&self, model: &str) -> Result<ModelPricing> {
Ok(ModelPricing {
model: model.to_string(),
input_cost_per_1k: 0.01, output_cost_per_1k: 0.03,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
})
}
async fn calculate_cost(
&self,
model: &str,
input_tokens: u32,
output_tokens: u32,
) -> Result<f64> {
let pricing = self.get_model_pricing(model).await?;
let input_cost = (input_tokens as f64 / 1000.0) * pricing.input_cost_per_1k;
let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_cost_per_1k;
Ok(input_cost + output_cost)
}
}