1use std::pin::Pin;
8use std::time::{Duration, Instant};
9
10use futures::stream::{self, Stream, StreamExt as _};
11use serde::Deserialize;
12use tokio_util::sync::CancellationToken;
13use tracing::debug;
14
15use swink_agent::{AgentContext, AssistantMessageEvent, ModelSpec, StreamFn, StreamOptions};
16use swink_agent_auth::{ExpiringValue, SingleFlightTokenSource};
17
18use crate::classify::{HttpErrorKind, classify_with_overrides};
19use crate::oai_transport::{OaiAdapterShell, oai_send_and_parse, prepare_oai_request};
20
/// Credentials used to authenticate Azure requests.
#[derive(Clone)]
pub enum AzureAuth {
    /// Static key sent in the `api-key` request header.
    ApiKey(String),
    /// Entra ID (Azure AD) app registration whose credentials are exchanged for
    /// bearer tokens through the OAuth2 client-credentials grant.
    EntraId {
        /// Directory (tenant) ID; selects the default token endpoint URL.
        tenant_id: String,
        /// Application (client) ID of the app registration.
        client_id: String,
        /// Client secret of the app registration.
        client_secret: String,
    },
}

impl std::fmt::Debug for AzureAuth {
    /// Prints only the variant shape; every credential field is rendered as
    /// `"[REDACTED]"` so secrets never end up in logs or panic messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[REDACTED]";
        match self {
            Self::ApiKey(_) => f.debug_tuple("ApiKey").field(&REDACTED).finish(),
            Self::EntraId { .. } => {
                let mut entra = f.debug_struct("EntraId");
                for name in ["tenant_id", "client_id", "client_secret"] {
                    entra.field(name, &REDACTED);
                }
                entra.finish()
            }
        }
    }
}
47
/// Margin handed to `SingleFlightTokenSource::new` (five minutes); presumably a
/// cached Entra ID token is refreshed once it is within this long of expiring.
const REFRESH_MARGIN: Duration = Duration::from_secs(5 * 60);
50
/// The fields this adapter reads from a successful token-endpoint JSON response.
#[derive(Deserialize)]
struct TokenResponse {
    /// Access token, later sent as `Authorization: Bearer <token>`.
    access_token: String,
    /// Token lifetime in seconds; converted with `Duration::from_secs` to
    /// compute the expiry instant.
    expires_in: u64,
}
57
/// Failure while obtaining an Entra ID access token, classified so `apply_auth`
/// can emit the matching `AssistantMessageEvent` error kind.
///
/// NOTE(review): `Clone` is presumably required by `SingleFlightTokenSource` to
/// hand one failure to every concurrent waiter — confirm in `swink_agent_auth`.
#[derive(Clone)]
enum TokenAcquireError {
    /// Token endpoint answered with a 4xx other than 408/429.
    Auth(String),
    /// Token endpoint status classified as throttling (presumably HTTP 429).
    Throttled(String),
    /// The request could not be sent, or the status was classified as a
    /// network/server error.
    Network(String),
    /// Unclassified status, or a success response that failed to parse.
    Other(String),
}
65
/// [`StreamFn`] for Azure chat-completions endpoints, authenticating with either
/// an API key or Entra ID client credentials.
pub struct AzureStreamFn {
    /// Shared OpenAI-compatible transport: HTTP client, base URL, provider name.
    shell: OaiAdapterShell,
    /// Credentials supplied at construction.
    auth: AzureAuth,
    /// Cache for Entra ID access tokens; only consulted for `AzureAuth::EntraId`.
    token_source: SingleFlightTokenSource<String, TokenAcquireError>,
    /// Token endpoint URL set via `with_token_endpoint`, replacing the default.
    token_endpoint_override: Option<String>,
}
73
74impl AzureStreamFn {
75 #[must_use]
76 pub fn new(base_url: impl Into<String>, auth: AzureAuth) -> Self {
77 let shell_api_key = match &auth {
78 AzureAuth::ApiKey(key) => key.clone(),
79 AzureAuth::EntraId { .. } => String::new(),
80 };
81
82 Self {
83 shell: OaiAdapterShell::new_with_path(
84 "Azure",
85 base_url,
86 shell_api_key,
87 "/chat/completions",
88 ),
89 auth,
90 token_source: SingleFlightTokenSource::new(REFRESH_MARGIN),
91 token_endpoint_override: None,
92 }
93 }
94
95 #[must_use]
97 pub fn with_token_endpoint(mut self, url: impl Into<String>) -> Self {
98 self.token_endpoint_override = Some(url.into());
99 self
100 }
101}
102
103impl AzureStreamFn {
104 async fn acquire_token(
106 client: reqwest::Client,
107 token_url: String,
108 client_id: String,
109 client_secret: String,
110 ) -> Result<ExpiringValue<String>, TokenAcquireError> {
111 let params = [
112 ("grant_type", "client_credentials".to_string()),
113 ("client_id", client_id),
114 ("client_secret", client_secret),
115 (
116 "scope",
117 "https://cognitiveservices.azure.com/.default".to_string(),
118 ),
119 ];
120
121 let resp = client
122 .post(&token_url)
123 .form(¶ms)
124 .send()
125 .await
126 .map_err(|e| TokenAcquireError::Network(format!("token request failed: {e}")))?;
127
128 if !resp.status().is_success() {
129 let status = resp.status().as_u16();
130 let body = resp.text().await.unwrap_or_default();
131 return Err(match classify_token_endpoint_status(status) {
132 Some(HttpErrorKind::Auth) => TokenAcquireError::Auth(format!(
133 "token endpoint auth error (HTTP {status}): {body}"
134 )),
135 Some(HttpErrorKind::Throttled) => TokenAcquireError::Throttled(format!(
136 "token endpoint rate limit (HTTP {status}): {body}"
137 )),
138 Some(HttpErrorKind::Network) => TokenAcquireError::Network(format!(
139 "token endpoint server error (HTTP {status}): {body}"
140 )),
141 None => TokenAcquireError::Other(format!(
142 "token endpoint returned error (HTTP {status}): {body}"
143 )),
144 });
145 }
146
147 let token_resp: TokenResponse = resp.json().await.map_err(|e| {
148 TokenAcquireError::Other(format!("failed to parse token response: {e}"))
149 })?;
150
151 Ok(ExpiringValue::new(
152 token_resp.access_token,
153 Instant::now() + Duration::from_secs(token_resp.expires_in),
154 ))
155 }
156
157 async fn get_or_refresh_token(
159 &self,
160 tenant_id: &str,
161 client_id: &str,
162 client_secret: &str,
163 ) -> Result<String, TokenAcquireError> {
164 let client = self.shell.client().clone();
165 let token_url = self.token_url(tenant_id);
166 let client_id = client_id.to_string();
167 let client_secret = client_secret.to_string();
168
169 self.token_source
170 .get_or_refresh(move || {
171 Self::acquire_token(client, token_url, client_id, client_secret)
172 })
173 .await
174 }
175
176 fn token_url(&self, tenant_id: &str) -> String {
178 self.token_endpoint_override.as_ref().map_or_else(
179 || format!("https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token"),
180 Clone::clone,
181 )
182 }
183
184 async fn apply_auth(
186 &self,
187 request: reqwest::RequestBuilder,
188 options: &StreamOptions,
189 ) -> Result<reqwest::RequestBuilder, AssistantMessageEvent> {
190 match &self.auth {
191 AzureAuth::ApiKey(key) => {
192 let api_key = options.api_key.as_deref().unwrap_or(key);
193 Ok(request.header("api-key", api_key))
194 }
195 AzureAuth::EntraId {
196 tenant_id,
197 client_id,
198 client_secret,
199 } => {
200 let token = self
201 .get_or_refresh_token(tenant_id, client_id, client_secret)
202 .await
203 .map_err(|e| match e {
204 TokenAcquireError::Auth(message) => AssistantMessageEvent::error_auth(
205 format!("Azure token error: {message}"),
206 ),
207 TokenAcquireError::Throttled(message) => {
208 AssistantMessageEvent::error_throttled(format!(
209 "Azure token error: {message}"
210 ))
211 }
212 TokenAcquireError::Network(message) => {
213 AssistantMessageEvent::error_network(format!(
214 "Azure token error: {message}"
215 ))
216 }
217 TokenAcquireError::Other(message) => {
218 AssistantMessageEvent::error(format!("Azure token error: {message}"))
219 }
220 })?;
221 Ok(request.header("Authorization", format!("Bearer {token}")))
222 }
223 }
224 }
225}
226
227fn classify_token_endpoint_status(status: u16) -> Option<HttpErrorKind> {
228 match status {
229 400..=499 if status != 408 && status != 429 => Some(HttpErrorKind::Auth),
230 _ => classify_with_overrides(status, &[]),
231 }
232}
233
234impl std::fmt::Debug for AzureStreamFn {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 f.debug_struct("AzureStreamFn")
237 .field("base_url", &self.shell.base_url())
238 .field("auth", &self.auth)
239 .finish_non_exhaustive()
240 }
241}
242
243impl StreamFn for AzureStreamFn {
244 fn stream<'a>(
245 &'a self,
246 model: &'a ModelSpec,
247 context: &'a AgentContext,
248 options: &'a StreamOptions,
249 cancellation_token: CancellationToken,
250 ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>> {
251 Box::pin(azure_stream(
252 self,
253 model,
254 context,
255 options,
256 cancellation_token,
257 ))
258 }
259}
260
261fn azure_stream<'a>(
262 azure: &'a AzureStreamFn,
263 model: &'a ModelSpec,
264 context: &'a AgentContext,
265 options: &'a StreamOptions,
266 cancellation_token: CancellationToken,
267) -> impl Stream<Item = AssistantMessageEvent> + Send + 'a {
268 stream::once(async move {
269 let url = azure.shell.chat_completions_url();
270 debug!(
271 %url,
272 model = %model.model_id,
273 messages = context.messages.len(),
274 "sending Azure request"
275 );
276
277 let request = prepare_oai_request(azure.shell.client(), &url, model, context, options);
278 let request = match crate::base::race_pre_stream_cancellation(
279 &cancellation_token,
280 "Azure request cancelled",
281 azure.apply_auth(request, options),
282 )
283 .await
284 {
285 Ok(r) => r,
286 Err(event) => return stream::iter(crate::base::pre_stream_error(event)).left_stream(),
287 };
288
289 oai_send_and_parse(
290 request,
291 azure.shell.provider(),
292 cancellation_token,
293 options.on_raw_payload.clone(),
294 |status, body| {
295 if is_content_filter_error(body) {
296 Some(AssistantMessageEvent::error_content_filtered(format!(
297 "Azure content filter blocked request (HTTP {status})"
298 )))
299 } else {
300 None
301 }
302 },
303 )
304 .right_stream()
305 })
306 .flatten()
307}
308
309fn is_content_filter_error(body: &str) -> bool {
314 serde_json::from_str::<serde_json::Value>(body)
315 .ok()
316 .and_then(|v| v.get("error")?.get("code")?.as_str().map(String::from))
317 .is_some_and(|code| code == "ContentFilterBlocked")
318}
319
320const _: () = {
321 const fn assert_send_sync<T: Send + Sync>() {}
322 assert_send_sync::<AzureStreamFn>();
323};