1use async_trait::async_trait;
11use futures::StreamExt;
12use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
13use tokio::sync::mpsc;
14use tracing::debug;
15
16use super::message::{ContentBlock, Message, StopReason, Usage};
17use super::provider::{Provider, ProviderError, ProviderRequest};
18use super::stream::StreamEvent;
19
/// Streaming chat-completions provider backed by an Azure OpenAI endpoint.
pub struct AzureOpenAiProvider {
    /// HTTP client shared by every request; built with a 300 second timeout.
    http: reqwest::Client,
    /// Endpoint base URL, stored without trailing `/`; `/chat/completions` is
    /// appended to it for each request.
    base_url: String,
    /// Key sent in the `api-key` header when `AZURE_OPENAI_AD_TOKEN` is not set.
    api_key: String,
    /// Value sent as the `api-version` query parameter.
    api_version: String,
}
27
28impl AzureOpenAiProvider {
29 pub fn new(base_url: &str, api_key: &str) -> Self {
30 let http = reqwest::Client::builder()
31 .timeout(std::time::Duration::from_secs(300))
32 .build()
33 .expect("failed to build HTTP client");
34
35 let api_version =
36 std::env::var("AZURE_OPENAI_API_VERSION").unwrap_or_else(|_| "2024-10-21".to_string());
37
38 Self {
39 http,
40 base_url: base_url.trim_end_matches('/').to_string(),
41 api_key: api_key.to_string(),
42 api_version,
43 }
44 }
45
46 fn build_body(&self, request: &ProviderRequest) -> serde_json::Value {
49 let mut messages = Vec::new();
50
51 if !request.system_prompt.is_empty() {
53 messages.push(serde_json::json!({
54 "role": "system",
55 "content": request.system_prompt,
56 }));
57 }
58
59 for msg in &request.messages {
61 match msg {
62 Message::User(u) => {
63 let content = blocks_to_openai_content(&u.content);
64 messages.push(serde_json::json!({
65 "role": "user",
66 "content": content,
67 }));
68 }
69 Message::Assistant(a) => {
70 let mut msg_json = serde_json::json!({
71 "role": "assistant",
72 });
73
74 let tool_calls: Vec<serde_json::Value> = a
75 .content
76 .iter()
77 .filter_map(|b| match b {
78 ContentBlock::ToolUse { id, name, input } => Some(serde_json::json!({
79 "id": id,
80 "type": "function",
81 "function": {
82 "name": name,
83 "arguments": serde_json::to_string(input).unwrap_or_default(),
84 }
85 })),
86 _ => None,
87 })
88 .collect();
89
90 let text: String = a
91 .content
92 .iter()
93 .filter_map(|b| match b {
94 ContentBlock::Text { text } => Some(text.as_str()),
95 _ => None,
96 })
97 .collect::<Vec<_>>()
98 .join("");
99
100 msg_json["content"] = serde_json::Value::String(text);
101 if !tool_calls.is_empty() {
102 msg_json["tool_calls"] = serde_json::Value::Array(tool_calls);
103 }
104
105 messages.push(msg_json);
106 }
107 Message::System(_) => {} }
109 }
110
111 let mut final_messages = Vec::new();
113 for msg in messages {
114 if msg.get("role").and_then(|r| r.as_str()) == Some("user")
115 && let Some(content) = msg.get("content")
116 && let Some(arr) = content.as_array()
117 {
118 let mut tool_results = Vec::new();
119 let mut other_content = Vec::new();
120
121 for block in arr {
122 if block.get("type").and_then(|t| t.as_str()) == Some("tool_result") {
123 tool_results.push(serde_json::json!({
124 "role": "tool",
125 "tool_call_id": block.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or(""),
126 "content": block.get("content").and_then(|v| v.as_str()).unwrap_or(""),
127 }));
128 } else {
129 other_content.push(block.clone());
130 }
131 }
132
133 if !tool_results.is_empty() {
134 for tr in tool_results {
135 final_messages.push(tr);
136 }
137 if !other_content.is_empty() {
138 let mut m = msg.clone();
139 m["content"] = serde_json::Value::Array(other_content);
140 final_messages.push(m);
141 }
142 continue;
143 }
144 }
145 final_messages.push(msg);
146 }
147
148 let tools: Vec<serde_json::Value> = request
150 .tools
151 .iter()
152 .map(|t| {
153 serde_json::json!({
154 "type": "function",
155 "function": {
156 "name": t.name,
157 "description": t.description,
158 "parameters": t.input_schema,
159 }
160 })
161 })
162 .collect();
163
164 let mut body = serde_json::json!({
166 "messages": final_messages,
167 "stream": true,
168 "stream_options": { "include_usage": true },
169 "max_tokens": request.max_tokens,
170 });
171
172 if !tools.is_empty() {
173 body["tools"] = serde_json::Value::Array(tools);
174
175 use super::provider::ToolChoice;
176 match &request.tool_choice {
177 ToolChoice::Auto => {
178 body["tool_choice"] = serde_json::json!("auto");
179 }
180 ToolChoice::Any => {
181 body["tool_choice"] = serde_json::json!("required");
182 }
183 ToolChoice::None => {
184 body["tool_choice"] = serde_json::json!("none");
185 }
186 ToolChoice::Specific(name) => {
187 body["tool_choice"] = serde_json::json!({
188 "type": "function",
189 "function": { "name": name }
190 });
191 }
192 }
193 }
194 if let Some(temp) = request.temperature {
195 body["temperature"] = serde_json::json!(temp);
196 }
197
198 body
199 }
200}
201
202#[async_trait]
203impl Provider for AzureOpenAiProvider {
204 fn name(&self) -> &str {
205 "azure-openai"
206 }
207
208 async fn stream(
209 &self,
210 request: &ProviderRequest,
211 ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError> {
212 let url = format!(
213 "{}/chat/completions?api-version={}",
214 self.base_url, self.api_version
215 );
216 let body = self.build_body(request);
217
218 let mut headers = HeaderMap::new();
219 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
220
221 if let Ok(ad_token) = std::env::var("AZURE_OPENAI_AD_TOKEN") {
223 headers.insert(
224 AUTHORIZATION,
225 HeaderValue::from_str(&format!("Bearer {ad_token}"))
226 .map_err(|e| ProviderError::Auth(e.to_string()))?,
227 );
228 } else {
229 headers.insert(
230 HeaderName::from_static("api-key"),
231 HeaderValue::from_str(&self.api_key)
232 .map_err(|e| ProviderError::Auth(e.to_string()))?,
233 );
234 }
235
236 debug!("Azure OpenAI request to {url}");
237
238 let response = self
239 .http
240 .post(&url)
241 .headers(headers)
242 .json(&body)
243 .send()
244 .await
245 .map_err(|e| ProviderError::Network(e.to_string()))?;
246
247 let status = response.status();
248 if !status.is_success() {
249 let body_text = response.text().await.unwrap_or_default();
250 return match status.as_u16() {
251 401 | 403 => Err(ProviderError::Auth(body_text)),
252 429 => Err(ProviderError::RateLimited {
253 retry_after_ms: 1000,
254 }),
255 529 => Err(ProviderError::Overloaded),
256 413 => Err(ProviderError::RequestTooLarge(body_text)),
257 _ => Err(ProviderError::Network(format!("{status}: {body_text}"))),
258 };
259 }
260
261 let (tx, rx) = mpsc::channel(64);
263 tokio::spawn(async move {
264 let mut byte_stream = response.bytes_stream();
265 let mut buffer = String::new();
266 let mut current_tool_id = String::new();
267 let mut current_tool_name = String::new();
268 let mut current_tool_args = String::new();
269 let mut usage = Usage::default();
270 let mut stop_reason: Option<StopReason> = None;
271
272 while let Some(chunk_result) = byte_stream.next().await {
273 let chunk = match chunk_result {
274 Ok(c) => c,
275 Err(e) => {
276 let _ = tx.send(StreamEvent::Error(e.to_string())).await;
277 break;
278 }
279 };
280
281 buffer.push_str(&String::from_utf8_lossy(&chunk));
282
283 while let Some(pos) = buffer.find("\n\n") {
284 let event_text = buffer[..pos].to_string();
285 buffer = buffer[pos + 2..].to_string();
286
287 for line in event_text.lines() {
288 let data = if let Some(d) = line.strip_prefix("data: ") {
289 d
290 } else {
291 continue;
292 };
293
294 if data == "[DONE]" {
295 if !current_tool_id.is_empty() {
296 let input: serde_json::Value =
297 serde_json::from_str(¤t_tool_args).unwrap_or_default();
298 let _ = tx
299 .send(StreamEvent::ContentBlockComplete(
300 ContentBlock::ToolUse {
301 id: current_tool_id.clone(),
302 name: current_tool_name.clone(),
303 input,
304 },
305 ))
306 .await;
307 current_tool_id.clear();
308 current_tool_name.clear();
309 current_tool_args.clear();
310 }
311
312 let _ = tx
313 .send(StreamEvent::Done {
314 usage: usage.clone(),
315 stop_reason: stop_reason.clone().or(Some(StopReason::EndTurn)),
316 })
317 .await;
318 return;
319 }
320
321 let parsed: serde_json::Value = match serde_json::from_str(data) {
322 Ok(v) => v,
323 Err(_) => continue,
324 };
325
326 let delta = match parsed
327 .get("choices")
328 .and_then(|c| c.get(0))
329 .and_then(|c| c.get("delta"))
330 {
331 Some(d) => d,
332 None => {
333 if let Some(u) = parsed.get("usage") {
334 usage.input_tokens = u
335 .get("prompt_tokens")
336 .and_then(|v| v.as_u64())
337 .unwrap_or(0);
338 usage.output_tokens = u
339 .get("completion_tokens")
340 .and_then(|v| v.as_u64())
341 .unwrap_or(0);
342 }
343 continue;
344 }
345 };
346
347 if let Some(content) = delta.get("content").and_then(|c| c.as_str())
348 && !content.is_empty()
349 {
350 debug!(
351 "Azure OpenAI text delta: {}",
352 &content[..content.len().min(80)]
353 );
354 let _ = tx.send(StreamEvent::TextDelta(content.to_string())).await;
355 }
356
357 if let Some(finish) = parsed
358 .get("choices")
359 .and_then(|c| c.get(0))
360 .and_then(|c| c.get("finish_reason"))
361 .and_then(|f| f.as_str())
362 {
363 debug!("Azure OpenAI finish_reason: {finish}");
364 match finish {
365 "stop" => {
366 stop_reason = Some(StopReason::EndTurn);
367 }
368 "tool_calls" => {
369 stop_reason = Some(StopReason::ToolUse);
370 }
371 "length" => {
372 stop_reason = Some(StopReason::MaxTokens);
373 }
374 _ => {}
375 }
376 }
377
378 if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array())
379 {
380 for tc in tool_calls {
381 if let Some(func) = tc.get("function") {
382 if let Some(name) = func.get("name").and_then(|n| n.as_str()) {
383 if !current_tool_id.is_empty()
384 && !current_tool_args.is_empty()
385 {
386 let input: serde_json::Value =
387 serde_json::from_str(¤t_tool_args)
388 .unwrap_or_default();
389 let _ = tx
390 .send(StreamEvent::ContentBlockComplete(
391 ContentBlock::ToolUse {
392 id: current_tool_id.clone(),
393 name: current_tool_name.clone(),
394 input,
395 },
396 ))
397 .await;
398 }
399 current_tool_id = tc
400 .get("id")
401 .and_then(|i| i.as_str())
402 .unwrap_or("")
403 .to_string();
404 current_tool_name = name.to_string();
405 current_tool_args.clear();
406 }
407 if let Some(args) =
408 func.get("arguments").and_then(|a| a.as_str())
409 {
410 current_tool_args.push_str(args);
411 }
412 }
413 }
414 }
415 }
416 }
417 }
418
419 if !current_tool_id.is_empty() {
421 let input: serde_json::Value =
422 serde_json::from_str(¤t_tool_args).unwrap_or_default();
423 let _ = tx
424 .send(StreamEvent::ContentBlockComplete(ContentBlock::ToolUse {
425 id: current_tool_id,
426 name: current_tool_name,
427 input,
428 }))
429 .await;
430 }
431
432 let _ = tx
433 .send(StreamEvent::Done {
434 usage,
435 stop_reason: Some(StopReason::EndTurn),
436 })
437 .await;
438 });
439
440 Ok(rx)
441 }
442}
443
444fn blocks_to_openai_content(blocks: &[ContentBlock]) -> serde_json::Value {
446 if blocks.len() == 1
447 && let ContentBlock::Text { text } = &blocks[0]
448 {
449 return serde_json::Value::String(text.clone());
450 }
451
452 let parts: Vec<serde_json::Value> = blocks
453 .iter()
454 .map(|b| match b {
455 ContentBlock::Text { text } => serde_json::json!({
456 "type": "text",
457 "text": text,
458 }),
459 ContentBlock::Image { media_type, data } => serde_json::json!({
460 "type": "image_url",
461 "image_url": {
462 "url": format!("data:{media_type};base64,{data}"),
463 }
464 }),
465 ContentBlock::ToolResult {
466 tool_use_id,
467 content,
468 is_error,
469 ..
470 } => serde_json::json!({
471 "type": "tool_result",
472 "tool_use_id": tool_use_id,
473 "content": content,
474 "is_error": is_error,
475 }),
476 ContentBlock::Thinking { thinking, .. } => serde_json::json!({
477 "type": "text",
478 "text": thinking,
479 }),
480 ContentBlock::ToolUse { name, input, .. } => serde_json::json!({
481 "type": "text",
482 "text": format!("[Tool call: {name}({input})]"),
483 }),
484 ContentBlock::Document { title, .. } => serde_json::json!({
485 "type": "text",
486 "text": format!("[Document: {}]", title.as_deref().unwrap_or("untitled")),
487 }),
488 })
489 .collect();
490
491 serde_json::Value::Array(parts)
492}