1use async_trait::async_trait;
11use futures::StreamExt;
12use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
13use tokio::sync::mpsc;
14use tracing::debug;
15
16use super::message::{ContentBlock, Message, StopReason, Usage};
17use super::provider::{Provider, ProviderError, ProviderRequest};
18use super::stream::StreamEvent;
19
20pub struct AzureOpenAiProvider {
22 http: reqwest::Client,
23 base_url: String,
24 api_key: String,
25 api_version: String,
26}
27
28impl AzureOpenAiProvider {
29 pub fn new(base_url: &str, api_key: &str) -> Self {
30 let http = reqwest::Client::builder()
31 .timeout(std::time::Duration::from_secs(300))
32 .build()
33 .expect("failed to build HTTP client");
34
35 let api_version =
36 std::env::var("AZURE_OPENAI_API_VERSION").unwrap_or_else(|_| "2024-10-21".to_string());
37
38 Self {
39 http,
40 base_url: base_url.trim_end_matches('/').to_string(),
41 api_key: api_key.to_string(),
42 api_version,
43 }
44 }
45
46 fn build_body(&self, request: &ProviderRequest) -> serde_json::Value {
49 let mut messages = Vec::new();
50
51 if !request.system_prompt.is_empty() {
53 messages.push(serde_json::json!({
54 "role": "system",
55 "content": request.system_prompt,
56 }));
57 }
58
59 for msg in &request.messages {
61 match msg {
62 Message::User(u) => {
63 let content = blocks_to_openai_content(&u.content);
64 messages.push(serde_json::json!({
65 "role": "user",
66 "content": content,
67 }));
68 }
69 Message::Assistant(a) => {
70 let mut msg_json = serde_json::json!({
71 "role": "assistant",
72 });
73
74 let tool_calls: Vec<serde_json::Value> = a
75 .content
76 .iter()
77 .filter_map(|b| match b {
78 ContentBlock::ToolUse { id, name, input } => Some(serde_json::json!({
79 "id": id,
80 "type": "function",
81 "function": {
82 "name": name,
83 "arguments": serde_json::to_string(input).unwrap_or_default(),
84 }
85 })),
86 _ => None,
87 })
88 .collect();
89
90 let text: String = a
91 .content
92 .iter()
93 .filter_map(|b| match b {
94 ContentBlock::Text { text } => Some(text.as_str()),
95 _ => None,
96 })
97 .collect::<Vec<_>>()
98 .join("");
99
100 msg_json["content"] = serde_json::Value::String(text);
101 if !tool_calls.is_empty() {
102 msg_json["tool_calls"] = serde_json::Value::Array(tool_calls);
103 }
104
105 messages.push(msg_json);
106 }
107 Message::System(_) => {} }
109 }
110
111 let mut final_messages = Vec::new();
113 for msg in messages {
114 if msg.get("role").and_then(|r| r.as_str()) == Some("user")
115 && let Some(content) = msg.get("content")
116 && let Some(arr) = content.as_array()
117 {
118 let mut tool_results = Vec::new();
119 let mut other_content = Vec::new();
120
121 for block in arr {
122 if block.get("type").and_then(|t| t.as_str()) == Some("tool_result") {
123 tool_results.push(serde_json::json!({
124 "role": "tool",
125 "tool_call_id": block.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or(""),
126 "content": block.get("content").and_then(|v| v.as_str()).unwrap_or(""),
127 }));
128 } else {
129 other_content.push(block.clone());
130 }
131 }
132
133 if !tool_results.is_empty() {
134 for tr in tool_results {
135 final_messages.push(tr);
136 }
137 if !other_content.is_empty() {
138 let mut m = msg.clone();
139 m["content"] = serde_json::Value::Array(other_content);
140 final_messages.push(m);
141 }
142 continue;
143 }
144 }
145 final_messages.push(msg);
146 }
147
148 let tools: Vec<serde_json::Value> = request
150 .tools
151 .iter()
152 .map(|t| {
153 serde_json::json!({
154 "type": "function",
155 "function": {
156 "name": t.name,
157 "description": t.description,
158 "parameters": t.input_schema,
159 }
160 })
161 })
162 .collect();
163
164 let mut body = serde_json::json!({
166 "messages": final_messages,
167 "stream": true,
168 "stream_options": { "include_usage": true },
169 "max_tokens": request.max_tokens,
170 });
171
172 if !tools.is_empty() {
173 body["tools"] = serde_json::Value::Array(tools);
174
175 use super::provider::ToolChoice;
176 match &request.tool_choice {
177 ToolChoice::Auto => {
178 body["tool_choice"] = serde_json::json!("auto");
179 }
180 ToolChoice::Any => {
181 body["tool_choice"] = serde_json::json!("required");
182 }
183 ToolChoice::None => {
184 body["tool_choice"] = serde_json::json!("none");
185 }
186 ToolChoice::Specific(name) => {
187 body["tool_choice"] = serde_json::json!({
188 "type": "function",
189 "function": { "name": name }
190 });
191 }
192 }
193 }
194 if let Some(temp) = request.temperature {
195 body["temperature"] = serde_json::json!(temp);
196 }
197
198 body
199 }
200}
201
202#[async_trait]
203impl Provider for AzureOpenAiProvider {
204 fn name(&self) -> &str {
205 "azure-openai"
206 }
207
208 async fn stream(
209 &self,
210 request: &ProviderRequest,
211 ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError> {
212 let url = format!(
213 "{}/chat/completions?api-version={}",
214 self.base_url, self.api_version
215 );
216 let body = self.build_body(request);
217
218 let mut headers = HeaderMap::new();
219 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
220
221 if let Ok(ad_token) = std::env::var("AZURE_OPENAI_AD_TOKEN") {
223 headers.insert(
224 AUTHORIZATION,
225 HeaderValue::from_str(&format!("Bearer {ad_token}"))
226 .map_err(|e| ProviderError::Auth(e.to_string()))?,
227 );
228 } else {
229 headers.insert(
230 HeaderName::from_static("api-key"),
231 HeaderValue::from_str(&self.api_key)
232 .map_err(|e| ProviderError::Auth(e.to_string()))?,
233 );
234 }
235
236 debug!("Azure OpenAI request to {url}");
237
238 let response = self
239 .http
240 .post(&url)
241 .headers(headers)
242 .json(&body)
243 .send()
244 .await
245 .map_err(|e| ProviderError::Network(e.to_string()))?;
246
247 let status = response.status();
248 if !status.is_success() {
249 let body_text = response.text().await.unwrap_or_default();
250 return match status.as_u16() {
251 401 | 403 => Err(ProviderError::Auth(body_text)),
252 429 => Err(ProviderError::RateLimited {
253 retry_after_ms: 1000,
254 }),
255 529 => Err(ProviderError::Overloaded),
256 413 => Err(ProviderError::RequestTooLarge(body_text)),
257 _ => Err(ProviderError::Network(format!("{status}: {body_text}"))),
258 };
259 }
260
261 let (tx, rx) = mpsc::channel(64);
263 let cancel = request.cancel.clone();
264 tokio::spawn(async move {
265 let mut byte_stream = response.bytes_stream();
266 let mut buffer = String::new();
267 let mut current_tool_id = String::new();
268 let mut current_tool_name = String::new();
269 let mut current_tool_args = String::new();
270 let mut usage = Usage::default();
271 let mut stop_reason: Option<StopReason> = None;
272
273 loop {
274 let chunk_result = tokio::select! {
278 biased;
279 _ = cancel.cancelled() => return,
280 chunk = byte_stream.next() => match chunk {
281 Some(c) => c,
282 None => break,
283 },
284 };
285 let chunk = match chunk_result {
286 Ok(c) => c,
287 Err(e) => {
288 let _ = tx.send(StreamEvent::Error(e.to_string())).await;
289 break;
290 }
291 };
292
293 buffer.push_str(&String::from_utf8_lossy(&chunk));
294
295 while let Some(pos) = buffer.find("\n\n") {
296 let event_text = buffer[..pos].to_string();
297 buffer = buffer[pos + 2..].to_string();
298
299 for line in event_text.lines() {
300 let data = if let Some(d) = line.strip_prefix("data: ") {
301 d
302 } else {
303 continue;
304 };
305
306 if data == "[DONE]" {
307 if !current_tool_id.is_empty() {
308 let input: serde_json::Value =
309 serde_json::from_str(¤t_tool_args).unwrap_or_default();
310 let _ = tx
311 .send(StreamEvent::ContentBlockComplete(
312 ContentBlock::ToolUse {
313 id: current_tool_id.clone(),
314 name: current_tool_name.clone(),
315 input,
316 },
317 ))
318 .await;
319 current_tool_id.clear();
320 current_tool_name.clear();
321 current_tool_args.clear();
322 }
323
324 let _ = tx
325 .send(StreamEvent::Done {
326 usage: usage.clone(),
327 stop_reason: stop_reason.clone().or(Some(StopReason::EndTurn)),
328 })
329 .await;
330 return;
331 }
332
333 let parsed: serde_json::Value = match serde_json::from_str(data) {
334 Ok(v) => v,
335 Err(_) => continue,
336 };
337
338 let delta = match parsed
339 .get("choices")
340 .and_then(|c| c.get(0))
341 .and_then(|c| c.get("delta"))
342 {
343 Some(d) => d,
344 None => {
345 if let Some(u) = parsed.get("usage") {
346 usage.input_tokens = u
347 .get("prompt_tokens")
348 .and_then(|v| v.as_u64())
349 .unwrap_or(0);
350 usage.output_tokens = u
351 .get("completion_tokens")
352 .and_then(|v| v.as_u64())
353 .unwrap_or(0);
354 }
355 continue;
356 }
357 };
358
359 if let Some(content) = delta.get("content").and_then(|c| c.as_str())
360 && !content.is_empty()
361 {
362 debug!(
363 "Azure OpenAI text delta: {}",
364 &content[..content.len().min(80)]
365 );
366 let _ = tx.send(StreamEvent::TextDelta(content.to_string())).await;
367 }
368
369 if let Some(finish) = parsed
370 .get("choices")
371 .and_then(|c| c.get(0))
372 .and_then(|c| c.get("finish_reason"))
373 .and_then(|f| f.as_str())
374 {
375 debug!("Azure OpenAI finish_reason: {finish}");
376 match finish {
377 "stop" => {
378 stop_reason = Some(StopReason::EndTurn);
379 }
380 "tool_calls" => {
381 stop_reason = Some(StopReason::ToolUse);
382 }
383 "length" => {
384 stop_reason = Some(StopReason::MaxTokens);
385 }
386 _ => {}
387 }
388 }
389
390 if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array())
391 {
392 for tc in tool_calls {
393 if let Some(func) = tc.get("function") {
394 if let Some(name) = func.get("name").and_then(|n| n.as_str()) {
395 if !current_tool_id.is_empty()
396 && !current_tool_args.is_empty()
397 {
398 let input: serde_json::Value =
399 serde_json::from_str(¤t_tool_args)
400 .unwrap_or_default();
401 let _ = tx
402 .send(StreamEvent::ContentBlockComplete(
403 ContentBlock::ToolUse {
404 id: current_tool_id.clone(),
405 name: current_tool_name.clone(),
406 input,
407 },
408 ))
409 .await;
410 }
411 current_tool_id = tc
412 .get("id")
413 .and_then(|i| i.as_str())
414 .unwrap_or("")
415 .to_string();
416 current_tool_name = name.to_string();
417 current_tool_args.clear();
418 }
419 if let Some(args) =
420 func.get("arguments").and_then(|a| a.as_str())
421 {
422 current_tool_args.push_str(args);
423 }
424 }
425 }
426 }
427 }
428 }
429 }
430
431 if !current_tool_id.is_empty() {
433 let input: serde_json::Value =
434 serde_json::from_str(¤t_tool_args).unwrap_or_default();
435 let _ = tx
436 .send(StreamEvent::ContentBlockComplete(ContentBlock::ToolUse {
437 id: current_tool_id,
438 name: current_tool_name,
439 input,
440 }))
441 .await;
442 }
443
444 let _ = tx
445 .send(StreamEvent::Done {
446 usage,
447 stop_reason: Some(StopReason::EndTurn),
448 })
449 .await;
450 });
451
452 Ok(rx)
453 }
454}
455
456fn blocks_to_openai_content(blocks: &[ContentBlock]) -> serde_json::Value {
458 if blocks.len() == 1
459 && let ContentBlock::Text { text } = &blocks[0]
460 {
461 return serde_json::Value::String(text.clone());
462 }
463
464 let parts: Vec<serde_json::Value> = blocks
465 .iter()
466 .map(|b| match b {
467 ContentBlock::Text { text } => serde_json::json!({
468 "type": "text",
469 "text": text,
470 }),
471 ContentBlock::Image { media_type, data } => serde_json::json!({
472 "type": "image_url",
473 "image_url": {
474 "url": format!("data:{media_type};base64,{data}"),
475 }
476 }),
477 ContentBlock::ToolResult {
478 tool_use_id,
479 content,
480 is_error,
481 ..
482 } => serde_json::json!({
483 "type": "tool_result",
484 "tool_use_id": tool_use_id,
485 "content": content,
486 "is_error": is_error,
487 }),
488 ContentBlock::Thinking { thinking, .. } => serde_json::json!({
489 "type": "text",
490 "text": thinking,
491 }),
492 ContentBlock::ToolUse { name, input, .. } => serde_json::json!({
493 "type": "text",
494 "text": format!("[Tool call: {name}({input})]"),
495 }),
496 ContentBlock::Document { title, .. } => serde_json::json!({
497 "type": "text",
498 "text": format!("[Document: {}]", title.as_deref().unwrap_or("untitled")),
499 }),
500 })
501 .collect();
502
503 serde_json::Value::Array(parts)
504}