Skip to main content

aster/providers/
azure.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use serde::Serialize;
4use serde_json::Value;
5
6use super::api_client::{ApiClient, AuthMethod, AuthProvider};
7use super::azureauth::{AuthError, AzureAuth};
8use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
9use super::errors::ProviderError;
10use super::formats::openai::{create_request, get_usage, response_to_message};
11use super::retry::ProviderRetry;
12use super::utils::{get_model, handle_response_openai_compat, ImageFormat};
13use crate::conversation::message::Message;
14use crate::model::ModelConfig;
15use crate::providers::utils::RequestLog;
16use rmcp::model::Tool;
17
/// Default model deployment name used when none is configured.
pub const AZURE_DEFAULT_MODEL: &str = "gpt-4o";
/// Reference documentation for available Azure OpenAI models.
pub const AZURE_DOC_URL: &str =
    "https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models";
/// Default `api-version` query parameter for Azure OpenAI requests.
pub const AZURE_DEFAULT_API_VERSION: &str = "2024-10-21";
/// Models advertised in provider metadata (for selection UIs).
pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4o-mini", "gpt-4"];
23
/// Provider backed by the Azure OpenAI Service.
///
/// Requests are routed to a specific deployment (`deployment_name`) with the
/// configured `api_version`; authentication (API key or Azure credential
/// chain) is handled by the wrapped `ApiClient`.
#[derive(Debug)]
pub struct AzureProvider {
    api_client: ApiClient,
    deployment_name: String,
    api_version: String,
    model: ModelConfig,
    name: String,
}
32
33impl Serialize for AzureProvider {
34    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
35    where
36        S: serde::Serializer,
37    {
38        use serde::ser::SerializeStruct;
39        let mut state = serializer.serialize_struct("AzureProvider", 2)?;
40        state.serialize_field("deployment_name", &self.deployment_name)?;
41        state.serialize_field("api_version", &self.api_version)?;
42        state.end()
43    }
44}
45
/// Adapter exposing [`AzureAuth`] through the generic [`AuthProvider`] trait
/// so it can be plugged into an [`ApiClient`] as a custom auth method.
struct AzureAuthProvider {
    auth: AzureAuth,
}
50
51#[async_trait]
52impl AuthProvider for AzureAuthProvider {
53    async fn get_auth_header(&self) -> Result<(String, String)> {
54        let auth_token = self
55            .auth
56            .get_token()
57            .await
58            .map_err(|e| anyhow::anyhow!("Failed to get authentication token: {}", e))?;
59
60        match self.auth.credential_type() {
61            super::azureauth::AzureCredentials::ApiKey(_) => {
62                Ok(("api-key".to_string(), auth_token.token_value))
63            }
64            super::azureauth::AzureCredentials::DefaultCredential => Ok((
65                "Authorization".to_string(),
66                format!("Bearer {}", auth_token.token_value),
67            )),
68        }
69    }
70}
71
72impl AzureProvider {
73    pub async fn from_env(model: ModelConfig) -> Result<Self> {
74        let config = crate::config::Config::global();
75        let endpoint: String = config.get_param("AZURE_OPENAI_ENDPOINT")?;
76        let deployment_name: String = config.get_param("AZURE_OPENAI_DEPLOYMENT_NAME")?;
77        let api_version: String = config
78            .get_param("AZURE_OPENAI_API_VERSION")
79            .unwrap_or_else(|_| AZURE_DEFAULT_API_VERSION.to_string());
80
81        let api_key = config
82            .get_secret("AZURE_OPENAI_API_KEY")
83            .ok()
84            .filter(|key: &String| !key.is_empty());
85        let auth = AzureAuth::new(api_key).map_err(|e| match e {
86            AuthError::Credentials(msg) => anyhow::anyhow!("Credentials error: {}", msg),
87            AuthError::TokenExchange(msg) => anyhow::anyhow!("Token exchange error: {}", msg),
88        })?;
89
90        let auth_provider = AzureAuthProvider { auth };
91        let api_client = ApiClient::new(endpoint, AuthMethod::Custom(Box::new(auth_provider)))?;
92
93        Ok(Self {
94            api_client,
95            deployment_name,
96            api_version,
97            model,
98            name: Self::metadata().name,
99        })
100    }
101
102    async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
103        // Build the path for Azure OpenAI
104        let path = format!(
105            "openai/deployments/{}/chat/completions?api-version={}",
106            self.deployment_name, self.api_version
107        );
108
109        let response = self.api_client.response_post(&path, payload).await?;
110        handle_response_openai_compat(response).await
111    }
112}
113
114#[async_trait]
115impl Provider for AzureProvider {
116    fn metadata() -> ProviderMetadata {
117        ProviderMetadata::new(
118            "azure_openai",
119            "Azure OpenAI",
120            "Models through Azure OpenAI Service (uses Azure credential chain by default)",
121            "gpt-4o",
122            AZURE_OPENAI_KNOWN_MODELS.to_vec(),
123            AZURE_DOC_URL,
124            vec![
125                ConfigKey::new("AZURE_OPENAI_ENDPOINT", true, false, None),
126                ConfigKey::new("AZURE_OPENAI_DEPLOYMENT_NAME", true, false, None),
127                ConfigKey::new("AZURE_OPENAI_API_VERSION", true, false, Some("2024-10-21")),
128                ConfigKey::new("AZURE_OPENAI_API_KEY", false, true, Some("")),
129            ],
130        )
131    }
132
133    fn get_name(&self) -> &str {
134        &self.name
135    }
136
137    fn get_model_config(&self) -> ModelConfig {
138        self.model.clone()
139    }
140
141    #[tracing::instrument(
142        skip(self, model_config, system, messages, tools),
143        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
144    )]
145    async fn complete_with_model(
146        &self,
147        model_config: &ModelConfig,
148        system: &str,
149        messages: &[Message],
150        tools: &[Tool],
151    ) -> Result<(Message, ProviderUsage), ProviderError> {
152        let payload = create_request(
153            model_config,
154            system,
155            messages,
156            tools,
157            &ImageFormat::OpenAi,
158            false,
159        )?;
160        let response = self
161            .with_retry(|| async {
162                let payload_clone = payload.clone();
163                self.post(&payload_clone).await
164            })
165            .await?;
166
167        let message = response_to_message(&response)?;
168        let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
169            tracing::debug!("Failed to get usage data");
170            Usage::default()
171        });
172        let response_model = get_model(&response);
173        let mut log = RequestLog::start(model_config, &payload)?;
174        log.write(&response, Some(&usage))?;
175        Ok((message, ProviderUsage::new(response_model, usage)))
176    }
177}