pub mod assistants;
pub mod batches;
pub mod chat;
pub mod client;
pub mod config;
pub mod embed;
pub mod error;
pub mod image;
pub mod responses;
pub mod utils;
pub use crate::core::providers::unified_provider::ProviderError;
pub use client::{AzureClient, AzureConfigFactory, AzureRateLimitInfo};
pub use config::{AzureConfig, AzureModelInfo};
pub use error::{
AzureErrorMapper, azure_ad_error, azure_api_error, azure_config_error, azure_deployment_error,
azure_header_error,
};
pub use utils::{AzureEndpointType, AzureUtils};
pub use crate::core::cost::providers::azure::{
AzureCostCalculator, cost_per_token, get_azure_model_pricing,
};
pub use assistants::{AzureAssistantHandler, AzureAssistantUtils};
pub use batches::{AzureBatchHandler, AzureBatchUtils};
pub use chat::{AzureChatHandler, AzureChatUtils};
pub use embed::{AzureEmbeddingHandler, AzureEmbeddingUtils};
pub use image::{AzureImageHandler, AzureImageUtils};
pub use responses::{AzureResponseHandler, AzureResponseProcessor, AzureResponseUtils};
use futures::Stream;
use serde_json::Value;
use std::pin::Pin;
use crate::core::types::{
chat::ChatRequest,
context::RequestContext,
embedding::EmbeddingRequest,
health::HealthStatus,
image::ImageGenerationRequest,
model::ModelInfo,
model::ProviderCapability,
responses::{ChatChunk, ChatResponse, EmbeddingResponse, ImageGenerationResponse},
};
use crate::core::traits::error_mapper::trait_def::ErrorMapper;
use crate::core::traits::provider::llm_provider::trait_definition::LLMProvider;
/// Azure OpenAI provider: bundles one handler per supported operation
/// (chat, embeddings, image generation) plus cost accounting, and exposes
/// them behind the `LLMProvider` trait.
#[derive(Debug, Clone)]
pub struct AzureOpenAIProvider {
    // Azure-specific settings (endpoint, API key, deployment details).
    config: AzureConfig,
    // Handles chat completions, both blocking and streaming.
    chat_handler: AzureChatHandler,
    // Handles embedding requests.
    embedding_handler: AzureEmbeddingHandler,
    // Handles image-generation requests.
    image_handler: AzureImageHandler,
    // Cost engine. NOTE(review): `LLMProvider::calculate_cost` in this file
    // uses its own hard-coded per-model rates rather than this calculator —
    // confirm which source of pricing is authoritative.
    cost_calculator: AzureCostCalculator,
}
impl AzureOpenAIProvider {
    /// Builds a provider from `config`, constructing one handler per
    /// supported operation.
    ///
    /// # Errors
    /// Propagates the first error returned by any handler constructor.
    pub fn new(config: AzureConfig) -> Result<Self, ProviderError> {
        Ok(Self {
            chat_handler: AzureChatHandler::new(config.clone())?,
            embedding_handler: AzureEmbeddingHandler::new(config.clone())?,
            image_handler: AzureImageHandler::new(config.clone())?,
            cost_calculator: AzureCostCalculator::new(),
            // Moved in last so the handlers above can clone it first.
            config,
        })
    }

    /// Alias for [`Self::new`], kept for call sites that prefer the
    /// `from_config` spelling.
    pub fn from_config(config: AzureConfig) -> Result<Self, ProviderError> {
        Self::new(config)
    }

    /// Borrows the Azure configuration this provider was built with.
    pub fn get_azure_config(&self) -> &AzureConfig {
        &self.config
    }

    /// Borrows the provider's cost calculator.
    pub fn get_cost_calculator(&self) -> &AzureCostCalculator {
        &self.cost_calculator
    }

    /// Builds a provider from a default configuration.
    /// NOTE(review): assumes `AzureConfig::new()` picks up environment
    /// settings — confirm against the config module.
    pub fn from_env() -> Result<Self, ProviderError> {
        Self::new(AzureConfig::new())
    }

    /// Convenience constructor taking an API key and endpoint directly.
    pub fn with_api_key(
        api_key: impl Into<String>,
        endpoint: impl Into<String>,
    ) -> Result<Self, ProviderError> {
        let config = AzureConfig::new()
            .with_api_key(api_key.into())
            .with_azure_endpoint(endpoint.into());
        Self::new(config)
    }
}
impl LLMProvider for AzureOpenAIProvider {
    /// Stable identifier for this provider (used in routing/logging).
    fn name(&self) -> &'static str {
        "azure_openai"
    }

    /// Static capability set advertised by this provider.
    fn capabilities(&self) -> &'static [ProviderCapability] {
        static CAPABILITIES: &[ProviderCapability] = &[
            ProviderCapability::ChatCompletion,
            ProviderCapability::ChatCompletionStream,
            ProviderCapability::Embeddings,
            ProviderCapability::ImageGeneration,
            ProviderCapability::FunctionCalling,
            ProviderCapability::ToolCalling,
        ];
        CAPABILITIES
    }

    /// No static model catalog is exposed here; Azure deployments are
    /// configured per instance.
    fn models(&self) -> &[ModelInfo] {
        &[]
    }

    /// OpenAI-compatible parameter names accepted for any model.
    fn get_supported_openai_params(&self, _model: &str) -> &'static [&'static str] {
        &[
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "presence_penalty",
            "stream",
            "functions",
            "function_call",
            "tools",
            "tool_choice",
        ]
    }

    /// Identity mapping: Azure accepts OpenAI-shaped parameters as-is.
    async fn map_openai_params(
        &self,
        params: std::collections::HashMap<String, serde_json::Value>,
        _model: &str,
    ) -> Result<std::collections::HashMap<String, serde_json::Value>, ProviderError> {
        Ok(params)
    }

    /// Delegates request shaping to the chat handler.
    async fn transform_request(
        &self,
        request: ChatRequest,
        _context: RequestContext,
    ) -> Result<Value, ProviderError> {
        self.chat_handler.transform_request(&request)
    }

    /// Parses the raw HTTP body as JSON, then delegates conversion into a
    /// `ChatResponse` to the chat handler.
    ///
    /// # Errors
    /// Returns an error if the body is not valid JSON or the handler
    /// rejects its shape.
    async fn transform_response(
        &self,
        raw_response: &[u8],
        model: &str,
        _request_id: &str,
    ) -> Result<ChatResponse, ProviderError> {
        let parsed: Value = serde_json::from_slice(raw_response)?;
        self.chat_handler.transform_response(parsed, model)
    }

    /// Azure-specific error mapper for translating provider failures.
    fn get_error_mapper(&self) -> Box<dyn ErrorMapper<ProviderError>> {
        Box::new(AzureErrorMapper)
    }

    /// Estimates request cost in USD from token counts.
    ///
    /// Unknown models cost 0.0. NOTE(review): these per-1K rates are
    /// hard-coded here and bypass `self.cost_calculator` — confirm which
    /// pricing source should win.
    async fn calculate_cost(
        &self,
        model: &str,
        input_tokens: u32,
        output_tokens: u32,
    ) -> Result<f64, ProviderError> {
        // (input, output) USD per 1K tokens.
        let (input_rate, output_rate) = match model {
            "gpt-35-turbo" => (0.0015, 0.002),
            "gpt-4" => (0.03, 0.06),
            "gpt-4-turbo" => (0.01, 0.03),
            _ => (0.0, 0.0),
        };
        let cost =
            (input_tokens as f64 * input_rate + output_tokens as f64 * output_rate) / 1000.0;
        Ok(cost)
    }

    /// Delegates a blocking chat completion to the chat handler.
    async fn chat_completion(
        &self,
        request: ChatRequest,
        context: RequestContext,
    ) -> Result<ChatResponse, ProviderError> {
        self.chat_handler.create_chat_completion(request, context).await
    }

    /// Delegates a streaming chat completion to the chat handler.
    async fn chat_completion_stream(
        &self,
        request: ChatRequest,
        context: RequestContext,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatChunk, ProviderError>> + Send>>, ProviderError>
    {
        self.chat_handler
            .create_chat_completion_stream(request, context)
            .await
    }

    /// Delegates embedding creation to the embedding handler.
    async fn embeddings(
        &self,
        request: EmbeddingRequest,
        context: RequestContext,
    ) -> Result<EmbeddingResponse, ProviderError> {
        self.embedding_handler.create_embeddings(request, context).await
    }

    /// Delegates image generation to the image handler.
    async fn image_generation(
        &self,
        request: ImageGenerationRequest,
        context: RequestContext,
    ) -> Result<ImageGenerationResponse, ProviderError> {
        self.image_handler.generate_image(request, context).await
    }

    /// Shallow health probe: reports healthy when an API key is configured.
    /// No network round-trip is performed.
    async fn health_check(&self) -> HealthStatus {
        match self.config.api_key {
            Some(_) => HealthStatus::Healthy,
            None => HealthStatus::Unhealthy,
        }
    }
}
/// Factory helpers for constructing [`AzureOpenAIProvider`] instances.
pub struct AzureProviderFactory;

impl AzureProviderFactory {
    /// Provider built from a default `AzureConfig`.
    pub fn create_default() -> Result<AzureOpenAIProvider, ProviderError> {
        Self::create_with_config(AzureConfig::new())
    }

    /// Provider built from an explicit configuration.
    pub fn create_with_config(config: AzureConfig) -> Result<AzureOpenAIProvider, ProviderError> {
        AzureOpenAIProvider::new(config)
    }

    /// Provider built from environment-derived defaults.
    pub fn create_from_env() -> Result<AzureOpenAIProvider, ProviderError> {
        AzureOpenAIProvider::from_env()
    }
}