//! litellm-rs 0.1.1
//!
//! A high-performance AI Gateway written in Rust, providing OpenAI-compatible
//! APIs with intelligent routing, load balancing, and enterprise features.
//! Azure OpenAI provider implementation
//!
//! This module provides Azure OpenAI API integration.

use super::{BaseProvider, ModelPricing, Provider, ProviderError, ProviderType};
use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::utils::error::Result;
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use tracing::{debug, info};

/// Azure OpenAI provider implementation.
///
/// Wraps [`BaseProvider`] with Azure-specific request routing: calls are sent
/// to deployment-scoped endpoints (`openai/deployments/{name}/...`) with an
/// `api-version` query parameter, and authenticated via the `api-key` header
/// rather than OpenAI's `Authorization: Bearer` scheme.
#[derive(Debug, Clone)]
pub struct AzureProvider {
    /// Base provider functionality (name, API key, HTTP plumbing)
    base: BaseProvider,
    /// Azure API version appended to every endpoint as `api-version=...`
    api_version: String,
    /// Azure deployment name; when `None`, each request's model name is used
    /// as the deployment instead
    deployment_name: Option<String>,
    /// Model pricing cache (initialized empty in `new`; Azure pricing varies
    /// by region)
    pricing_cache: HashMap<String, ModelPricing>,
}

impl AzureProvider {
    /// Create a new Azure provider from the given configuration.
    ///
    /// The deployment name is taken from `config.organization` (Azure has no
    /// organization concept, so the field appears to be repurposed here —
    /// NOTE(review): confirm this mapping against the config documentation).
    /// When unset, each request falls back to its model name as the
    /// deployment.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying [`BaseProvider`] cannot be
    /// constructed from `config`.
    pub async fn new(config: &ProviderConfig) -> Result<Self> {
        let base = BaseProvider::new(config)?;

        let provider = Self {
            base,
            // Default API version; not yet configurable.
            api_version: "2023-12-01-preview".to_string(),
            // `Option::clone` directly; `.as_ref().cloned()` was redundant.
            deployment_name: config.organization.clone(),
            pricing_cache: HashMap::new(), // Azure pricing varies by region
        };

        info!("Azure provider '{}' initialized successfully", config.name);
        Ok(provider)
    }

    /// Create request headers for the Azure OpenAI API.
    ///
    /// Azure authenticates with an `api-key` header instead of the
    /// `Authorization: Bearer` scheme used by OpenAI proper.
    ///
    /// # Panics
    ///
    /// Panics if the configured API key contains bytes that are not valid in
    /// an HTTP header value (an invariant violation in the configuration).
    fn create_headers(&self) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();

        headers.insert(
            "api-key",
            self.base
                .api_key
                .parse()
                .expect("API key contains characters invalid in an HTTP header value"),
        );

        headers.insert(
            reqwest::header::CONTENT_TYPE,
            "application/json"
                .parse()
                .expect("static literal is a valid header value"),
        );

        headers
    }
}

/// Build an Azure OpenAI endpoint path of the form
/// `openai/deployments/{deployment}/{operation}?api-version={api_version}`.
///
/// Extracted because the same format string was previously duplicated across
/// chat, completion, and embedding requests.
fn azure_endpoint(deployment: &str, operation: &str, api_version: &str) -> String {
    format!(
        "openai/deployments/{}/{}?api-version={}",
        deployment, operation, api_version
    )
}

#[async_trait]
impl Provider for AzureProvider {
    /// Configured provider name (from the provider config).
    fn name(&self) -> &str {
        &self.base.name
    }

    fn provider_type(&self) -> ProviderType {
        ProviderType::Azure
    }

    async fn supports_model(&self, model: &str) -> bool {
        self.base.is_model_supported(model)
    }

    async fn supports_images(&self) -> bool {
        true
    }

    async fn supports_embeddings(&self) -> bool {
        true
    }

    async fn supports_streaming(&self) -> bool {
        true
    }

    /// List the models this provider serves.
    ///
    /// Azure exposes deployments rather than a model catalog, so this returns
    /// the statically configured model list with synthetic metadata.
    async fn list_models(&self) -> Result<Vec<Model>> {
        let models = self
            .base
            .get_supported_models()
            .iter()
            .map(|model| Model {
                id: model.clone(),
                object: "model".to_string(),
                // No real creation timestamp is available; use "now".
                created: chrono::Utc::now().timestamp() as u64,
                owned_by: "azure".to_string(),
            })
            .collect();

        Ok(models)
    }

    async fn health_check(&self) -> Result<()> {
        debug!("Performing Azure health check");
        // TODO: Implement Azure-specific health check
        Ok(())
    }

    /// Proxy a chat completion request to the Azure deployment.
    ///
    /// Endpoint format:
    /// `{base_url}/openai/deployments/{deployment}/chat/completions?api-version={v}`.
    /// Note the model is implied by the deployment, so no `model` field is
    /// sent in the body.
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
        _context: RequestContext,
    ) -> Result<ChatCompletionResponse> {
        debug!("Azure chat completion for model: {}", request.model);

        // Fall back to the request's model name when no fixed deployment is
        // configured.
        let deployment = self.deployment_name.as_ref().unwrap_or(&request.model);
        let endpoint = azure_endpoint(deployment, "chat/completions", &self.api_version);

        let body = json!({
            "messages": request.messages,
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "n": request.n,
            "stream": request.stream,
            "stop": request.stop,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "user": request.user
        });

        let response = self
            .base
            .make_request(reqwest::Method::POST, &endpoint, Some(body))
            .await?;
        let chat_response: ChatCompletionResponse = self.base.parse_json_response(response).await?;

        Ok(chat_response)
    }

    /// Proxy a legacy text-completion request to the Azure deployment.
    async fn completion(
        &self,
        request: CompletionRequest,
        _context: RequestContext,
    ) -> Result<CompletionResponse> {
        debug!("Azure completion for model: {}", request.model);

        let deployment = self.deployment_name.as_ref().unwrap_or(&request.model);
        let endpoint = azure_endpoint(deployment, "completions", &self.api_version);

        let body = json!({
            "prompt": request.prompt,
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "n": request.n,
            "stream": request.stream,
            "stop": request.stop,
            "user": request.user
        });

        let response = self
            .base
            .make_request(reqwest::Method::POST, &endpoint, Some(body))
            .await?;
        let completion_response: CompletionResponse =
            self.base.parse_json_response(response).await?;

        Ok(completion_response)
    }

    /// Proxy an embedding request to the Azure deployment.
    async fn embedding(
        &self,
        request: EmbeddingRequest,
        _context: RequestContext,
    ) -> Result<EmbeddingResponse> {
        debug!("Azure embedding for model: {}", request.model);

        let deployment = self.deployment_name.as_ref().unwrap_or(&request.model);
        let endpoint = azure_endpoint(deployment, "embeddings", &self.api_version);

        let body = json!({
            "input": request.input,
            "user": request.user
        });

        let response = self
            .base
            .make_request(reqwest::Method::POST, &endpoint, Some(body))
            .await?;
        let embedding_response: EmbeddingResponse = self.base.parse_json_response(response).await?;

        Ok(embedding_response)
    }

    async fn image_generation(
        &self,
        _request: ImageGenerationRequest,
        _context: RequestContext,
    ) -> Result<ImageGenerationResponse> {
        Err(
            ProviderError::InvalidRequest("Azure image generation not implemented yet".to_string())
                .into(),
        )
    }

    /// Get pricing for a model.
    ///
    /// Checks the provider's pricing cache first (previously the cache was
    /// never read); falls back to a flat placeholder rate, since real Azure
    /// pricing varies by region and deployment.
    async fn get_model_pricing(&self, model: &str) -> Result<ModelPricing> {
        if let Some(pricing) = self.pricing_cache.get(model) {
            return Ok(pricing.clone());
        }

        Ok(ModelPricing {
            model: model.to_string(),
            input_cost_per_1k: 0.01, // Placeholder
            output_cost_per_1k: 0.03,
            currency: "USD".to_string(),
            updated_at: chrono::Utc::now(),
        })
    }

    /// Estimate the cost of a request from token counts, using per-1k-token
    /// rates from [`Self::get_model_pricing`].
    async fn calculate_cost(
        &self,
        model: &str,
        input_tokens: u32,
        output_tokens: u32,
    ) -> Result<f64> {
        let pricing = self.get_model_pricing(model).await?;

        let input_cost = (input_tokens as f64 / 1000.0) * pricing.input_cost_per_1k;
        let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_cost_per_1k;

        Ok(input_cost + output_cost)
    }
}