1use anyhow::Result;
2use async_trait::async_trait;
3use serde::Serialize;
4use serde_json::Value;
5
6use super::api_client::{ApiClient, AuthMethod, AuthProvider};
7use super::azureauth::{AuthError, AzureAuth};
8use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
9use super::errors::ProviderError;
10use super::formats::openai::{create_request, get_usage, response_to_message};
11use super::retry::ProviderRetry;
12use super::utils::{get_model, handle_response_openai_compat, ImageFormat};
13use crate::conversation::message::Message;
14use crate::model::ModelConfig;
15use crate::providers::utils::RequestLog;
16use rmcp::model::Tool;
17
/// Default model name for the Azure OpenAI provider.
pub const AZURE_DEFAULT_MODEL: &str = "gpt-4o";

/// Documentation page listing the models offered through Azure OpenAI.
pub const AZURE_DOC_URL: &str =
    "https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models";

/// `api-version` query value used by `AzureProvider::from_env` when
/// `AZURE_OPENAI_API_VERSION` cannot be read from the configuration.
pub const AZURE_DEFAULT_API_VERSION: &str = "2024-10-21";

/// Model names reported as known for this provider.
pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &[
    "gpt-4o",
    "gpt-4o-mini",
    "gpt-4",
];
23
/// Chat-completions provider backed by an Azure OpenAI deployment.
///
/// Built by `AzureProvider::from_env`. Requests go to the relative path
/// `openai/deployments/{deployment_name}/chat/completions?api-version={api_version}`.
#[derive(Debug)]
pub struct AzureProvider {
    /// HTTP client created from the configured endpoint and an `AzureAuthProvider`.
    api_client: ApiClient,
    /// Azure deployment name; used as a path segment of every request.
    deployment_name: String,
    /// Value sent as the `api-version` query parameter.
    api_version: String,
    /// Model configuration handed back by `get_model_config`.
    model: ModelConfig,
    /// Provider name, copied from `metadata().name` at construction time.
    name: String,
}
32
33impl Serialize for AzureProvider {
34 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
35 where
36 S: serde::Serializer,
37 {
38 use serde::ser::SerializeStruct;
39 let mut state = serializer.serialize_struct("AzureProvider", 2)?;
40 state.serialize_field("deployment_name", &self.deployment_name)?;
41 state.serialize_field("api_version", &self.api_version)?;
42 state.end()
43 }
44}
45
/// `AuthProvider` adapter that turns a token from `AzureAuth` into an HTTP auth header.
struct AzureAuthProvider {
    /// Supplies tokens and the credential type that decides which header is sent.
    auth: AzureAuth,
}
50
51#[async_trait]
52impl AuthProvider for AzureAuthProvider {
53 async fn get_auth_header(&self) -> Result<(String, String)> {
54 let auth_token = self
55 .auth
56 .get_token()
57 .await
58 .map_err(|e| anyhow::anyhow!("Failed to get authentication token: {}", e))?;
59
60 match self.auth.credential_type() {
61 super::azureauth::AzureCredentials::ApiKey(_) => {
62 Ok(("api-key".to_string(), auth_token.token_value))
63 }
64 super::azureauth::AzureCredentials::DefaultCredential => Ok((
65 "Authorization".to_string(),
66 format!("Bearer {}", auth_token.token_value),
67 )),
68 }
69 }
70}
71
72impl AzureProvider {
73 pub async fn from_env(model: ModelConfig) -> Result<Self> {
74 let config = crate::config::Config::global();
75 let endpoint: String = config.get_param("AZURE_OPENAI_ENDPOINT")?;
76 let deployment_name: String = config.get_param("AZURE_OPENAI_DEPLOYMENT_NAME")?;
77 let api_version: String = config
78 .get_param("AZURE_OPENAI_API_VERSION")
79 .unwrap_or_else(|_| AZURE_DEFAULT_API_VERSION.to_string());
80
81 let api_key = config
82 .get_secret("AZURE_OPENAI_API_KEY")
83 .ok()
84 .filter(|key: &String| !key.is_empty());
85 let auth = AzureAuth::new(api_key).map_err(|e| match e {
86 AuthError::Credentials(msg) => anyhow::anyhow!("Credentials error: {}", msg),
87 AuthError::TokenExchange(msg) => anyhow::anyhow!("Token exchange error: {}", msg),
88 })?;
89
90 let auth_provider = AzureAuthProvider { auth };
91 let api_client = ApiClient::new(endpoint, AuthMethod::Custom(Box::new(auth_provider)))?;
92
93 Ok(Self {
94 api_client,
95 deployment_name,
96 api_version,
97 model,
98 name: Self::metadata().name,
99 })
100 }
101
102 async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
103 let path = format!(
105 "openai/deployments/{}/chat/completions?api-version={}",
106 self.deployment_name, self.api_version
107 );
108
109 let response = self.api_client.response_post(&path, payload).await?;
110 handle_response_openai_compat(response).await
111 }
112}
113
114#[async_trait]
115impl Provider for AzureProvider {
116 fn metadata() -> ProviderMetadata {
117 ProviderMetadata::new(
118 "azure_openai",
119 "Azure OpenAI",
120 "Models through Azure OpenAI Service (uses Azure credential chain by default)",
121 "gpt-4o",
122 AZURE_OPENAI_KNOWN_MODELS.to_vec(),
123 AZURE_DOC_URL,
124 vec![
125 ConfigKey::new("AZURE_OPENAI_ENDPOINT", true, false, None),
126 ConfigKey::new("AZURE_OPENAI_DEPLOYMENT_NAME", true, false, None),
127 ConfigKey::new("AZURE_OPENAI_API_VERSION", true, false, Some("2024-10-21")),
128 ConfigKey::new("AZURE_OPENAI_API_KEY", false, true, Some("")),
129 ],
130 )
131 }
132
133 fn get_name(&self) -> &str {
134 &self.name
135 }
136
137 fn get_model_config(&self) -> ModelConfig {
138 self.model.clone()
139 }
140
141 #[tracing::instrument(
142 skip(self, model_config, system, messages, tools),
143 fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
144 )]
145 async fn complete_with_model(
146 &self,
147 model_config: &ModelConfig,
148 system: &str,
149 messages: &[Message],
150 tools: &[Tool],
151 ) -> Result<(Message, ProviderUsage), ProviderError> {
152 let payload = create_request(
153 model_config,
154 system,
155 messages,
156 tools,
157 &ImageFormat::OpenAi,
158 false,
159 )?;
160 let response = self
161 .with_retry(|| async {
162 let payload_clone = payload.clone();
163 self.post(&payload_clone).await
164 })
165 .await?;
166
167 let message = response_to_message(&response)?;
168 let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
169 tracing::debug!("Failed to get usage data");
170 Usage::default()
171 });
172 let response_model = get_model(&response);
173 let mut log = RequestLog::start(model_config, &payload)?;
174 log.write(&response, Some(&usage))?;
175 Ok((message, ProviderUsage::new(response_model, usage)))
176 }
177}