Skip to main content

agent_code_lib/llm/
azure_openai.rs

//! Azure OpenAI provider.
//!
//! Uses the same OpenAI Chat Completions wire format, but with Azure-specific
//! URL patterns and authentication. The deployment name is part of the URL,
//! so the `model` field is omitted from the request body. The `api-version`
//! query parameter comes from `AZURE_OPENAI_API_VERSION` (default `2024-10-21`).
//!
//! Auth: `api-key` header by default, or `Authorization: Bearer {ad_token}`
//! when `AZURE_OPENAI_AD_TOKEN` is set.
9
10use async_trait::async_trait;
11use futures::StreamExt;
12use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
13use tokio::sync::mpsc;
14use tracing::debug;
15
16use super::message::{ContentBlock, Message, StopReason, Usage};
17use super::provider::{Provider, ProviderError, ProviderRequest};
18use super::stream::StreamEvent;
19
20/// Azure OpenAI provider with `api-key` header auth and AD token support.
21pub struct AzureOpenAiProvider {
22    http: reqwest::Client,
23    base_url: String,
24    api_key: String,
25    api_version: String,
26}
27
28impl AzureOpenAiProvider {
29    pub fn new(base_url: &str, api_key: &str) -> Self {
30        let http = reqwest::Client::builder()
31            .timeout(std::time::Duration::from_secs(300))
32            .build()
33            .expect("failed to build HTTP client");
34
35        let api_version =
36            std::env::var("AZURE_OPENAI_API_VERSION").unwrap_or_else(|_| "2024-10-21".to_string());
37
38        Self {
39            http,
40            base_url: base_url.trim_end_matches('/').to_string(),
41            api_key: api_key.to_string(),
42            api_version,
43        }
44    }
45
46    /// Build the request body in OpenAI format, but without the `model` field
47    /// (Azure uses the deployment name from the URL instead).
48    fn build_body(&self, request: &ProviderRequest) -> serde_json::Value {
49        let mut messages = Vec::new();
50
51        // System message as first message.
52        if !request.system_prompt.is_empty() {
53            messages.push(serde_json::json!({
54                "role": "system",
55                "content": request.system_prompt,
56            }));
57        }
58
59        // Convert conversation messages.
60        for msg in &request.messages {
61            match msg {
62                Message::User(u) => {
63                    let content = blocks_to_openai_content(&u.content);
64                    messages.push(serde_json::json!({
65                        "role": "user",
66                        "content": content,
67                    }));
68                }
69                Message::Assistant(a) => {
70                    let mut msg_json = serde_json::json!({
71                        "role": "assistant",
72                    });
73
74                    let tool_calls: Vec<serde_json::Value> = a
75                        .content
76                        .iter()
77                        .filter_map(|b| match b {
78                            ContentBlock::ToolUse { id, name, input } => Some(serde_json::json!({
79                                "id": id,
80                                "type": "function",
81                                "function": {
82                                    "name": name,
83                                    "arguments": serde_json::to_string(input).unwrap_or_default(),
84                                }
85                            })),
86                            _ => None,
87                        })
88                        .collect();
89
90                    let text: String = a
91                        .content
92                        .iter()
93                        .filter_map(|b| match b {
94                            ContentBlock::Text { text } => Some(text.as_str()),
95                            _ => None,
96                        })
97                        .collect::<Vec<_>>()
98                        .join("");
99
100                    msg_json["content"] = serde_json::Value::String(text);
101                    if !tool_calls.is_empty() {
102                        msg_json["tool_calls"] = serde_json::Value::Array(tool_calls);
103                    }
104
105                    messages.push(msg_json);
106                }
107                Message::System(_) => {} // Already handled above.
108            }
109        }
110
111        // Handle tool results (OpenAI uses role: "tool").
112        let mut final_messages = Vec::new();
113        for msg in messages {
114            if msg.get("role").and_then(|r| r.as_str()) == Some("user")
115                && let Some(content) = msg.get("content")
116                && let Some(arr) = content.as_array()
117            {
118                let mut tool_results = Vec::new();
119                let mut other_content = Vec::new();
120
121                for block in arr {
122                    if block.get("type").and_then(|t| t.as_str()) == Some("tool_result") {
123                        tool_results.push(serde_json::json!({
124                                "role": "tool",
125                                "tool_call_id": block.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or(""),
126                                "content": block.get("content").and_then(|v| v.as_str()).unwrap_or(""),
127                            }));
128                    } else {
129                        other_content.push(block.clone());
130                    }
131                }
132
133                if !tool_results.is_empty() {
134                    for tr in tool_results {
135                        final_messages.push(tr);
136                    }
137                    if !other_content.is_empty() {
138                        let mut m = msg.clone();
139                        m["content"] = serde_json::Value::Array(other_content);
140                        final_messages.push(m);
141                    }
142                    continue;
143                }
144            }
145            final_messages.push(msg);
146        }
147
148        // Build tools in OpenAI format.
149        let tools: Vec<serde_json::Value> = request
150            .tools
151            .iter()
152            .map(|t| {
153                serde_json::json!({
154                    "type": "function",
155                    "function": {
156                        "name": t.name,
157                        "description": t.description,
158                        "parameters": t.input_schema,
159                    }
160                })
161            })
162            .collect();
163
164        // Azure: no "model" field — deployment name is in the URL.
165        let mut body = serde_json::json!({
166            "messages": final_messages,
167            "stream": true,
168            "stream_options": { "include_usage": true },
169            "max_tokens": request.max_tokens,
170        });
171
172        if !tools.is_empty() {
173            body["tools"] = serde_json::Value::Array(tools);
174
175            use super::provider::ToolChoice;
176            match &request.tool_choice {
177                ToolChoice::Auto => {
178                    body["tool_choice"] = serde_json::json!("auto");
179                }
180                ToolChoice::Any => {
181                    body["tool_choice"] = serde_json::json!("required");
182                }
183                ToolChoice::None => {
184                    body["tool_choice"] = serde_json::json!("none");
185                }
186                ToolChoice::Specific(name) => {
187                    body["tool_choice"] = serde_json::json!({
188                        "type": "function",
189                        "function": { "name": name }
190                    });
191                }
192            }
193        }
194        if let Some(temp) = request.temperature {
195            body["temperature"] = serde_json::json!(temp);
196        }
197
198        body
199    }
200}
201
202#[async_trait]
203impl Provider for AzureOpenAiProvider {
204    fn name(&self) -> &str {
205        "azure-openai"
206    }
207
208    async fn stream(
209        &self,
210        request: &ProviderRequest,
211    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError> {
212        let url = format!(
213            "{}/chat/completions?api-version={}",
214            self.base_url, self.api_version
215        );
216        let body = self.build_body(request);
217
218        let mut headers = HeaderMap::new();
219        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
220
221        // Azure AD token takes precedence over api-key header.
222        if let Ok(ad_token) = std::env::var("AZURE_OPENAI_AD_TOKEN") {
223            headers.insert(
224                AUTHORIZATION,
225                HeaderValue::from_str(&format!("Bearer {ad_token}"))
226                    .map_err(|e| ProviderError::Auth(e.to_string()))?,
227            );
228        } else {
229            headers.insert(
230                HeaderName::from_static("api-key"),
231                HeaderValue::from_str(&self.api_key)
232                    .map_err(|e| ProviderError::Auth(e.to_string()))?,
233            );
234        }
235
236        debug!("Azure OpenAI request to {url}");
237
238        let response = self
239            .http
240            .post(&url)
241            .headers(headers)
242            .json(&body)
243            .send()
244            .await
245            .map_err(|e| ProviderError::Network(e.to_string()))?;
246
247        let status = response.status();
248        if !status.is_success() {
249            let body_text = response.text().await.unwrap_or_default();
250            return match status.as_u16() {
251                401 | 403 => Err(ProviderError::Auth(body_text)),
252                429 => Err(ProviderError::RateLimited {
253                    retry_after_ms: 1000,
254                }),
255                529 => Err(ProviderError::Overloaded),
256                413 => Err(ProviderError::RequestTooLarge(body_text)),
257                _ => Err(ProviderError::Network(format!("{status}: {body_text}"))),
258            };
259        }
260
261        // Parse SSE stream — identical to OpenAI format.
262        let (tx, rx) = mpsc::channel(64);
263        let cancel = request.cancel.clone();
264        tokio::spawn(async move {
265            let mut byte_stream = response.bytes_stream();
266            let mut buffer = String::new();
267            let mut current_tool_id = String::new();
268            let mut current_tool_name = String::new();
269            let mut current_tool_args = String::new();
270            let mut usage = Usage::default();
271            let mut stop_reason: Option<StopReason> = None;
272
273            loop {
274                // Race the next SSE chunk against cancellation. On cancel,
275                // drop the byte_stream (and therefore the reqwest::Response),
276                // which aborts the underlying HTTP connection immediately.
277                let chunk_result = tokio::select! {
278                    biased;
279                    _ = cancel.cancelled() => return,
280                    chunk = byte_stream.next() => match chunk {
281                        Some(c) => c,
282                        None => break,
283                    },
284                };
285                let chunk = match chunk_result {
286                    Ok(c) => c,
287                    Err(e) => {
288                        let _ = tx.send(StreamEvent::Error(e.to_string())).await;
289                        break;
290                    }
291                };
292
293                buffer.push_str(&String::from_utf8_lossy(&chunk));
294
295                while let Some(pos) = buffer.find("\n\n") {
296                    let event_text = buffer[..pos].to_string();
297                    buffer = buffer[pos + 2..].to_string();
298
299                    for line in event_text.lines() {
300                        let data = if let Some(d) = line.strip_prefix("data: ") {
301                            d
302                        } else {
303                            continue;
304                        };
305
306                        if data == "[DONE]" {
307                            if !current_tool_id.is_empty() {
308                                let input: serde_json::Value =
309                                    serde_json::from_str(&current_tool_args).unwrap_or_default();
310                                let _ = tx
311                                    .send(StreamEvent::ContentBlockComplete(
312                                        ContentBlock::ToolUse {
313                                            id: current_tool_id.clone(),
314                                            name: current_tool_name.clone(),
315                                            input,
316                                        },
317                                    ))
318                                    .await;
319                                current_tool_id.clear();
320                                current_tool_name.clear();
321                                current_tool_args.clear();
322                            }
323
324                            let _ = tx
325                                .send(StreamEvent::Done {
326                                    usage: usage.clone(),
327                                    stop_reason: stop_reason.clone().or(Some(StopReason::EndTurn)),
328                                })
329                                .await;
330                            return;
331                        }
332
333                        let parsed: serde_json::Value = match serde_json::from_str(data) {
334                            Ok(v) => v,
335                            Err(_) => continue,
336                        };
337
338                        let delta = match parsed
339                            .get("choices")
340                            .and_then(|c| c.get(0))
341                            .and_then(|c| c.get("delta"))
342                        {
343                            Some(d) => d,
344                            None => {
345                                if let Some(u) = parsed.get("usage") {
346                                    usage.input_tokens = u
347                                        .get("prompt_tokens")
348                                        .and_then(|v| v.as_u64())
349                                        .unwrap_or(0);
350                                    usage.output_tokens = u
351                                        .get("completion_tokens")
352                                        .and_then(|v| v.as_u64())
353                                        .unwrap_or(0);
354                                }
355                                continue;
356                            }
357                        };
358
359                        if let Some(content) = delta.get("content").and_then(|c| c.as_str())
360                            && !content.is_empty()
361                        {
362                            debug!(
363                                "Azure OpenAI text delta: {}",
364                                &content[..content.len().min(80)]
365                            );
366                            let _ = tx.send(StreamEvent::TextDelta(content.to_string())).await;
367                        }
368
369                        if let Some(finish) = parsed
370                            .get("choices")
371                            .and_then(|c| c.get(0))
372                            .and_then(|c| c.get("finish_reason"))
373                            .and_then(|f| f.as_str())
374                        {
375                            debug!("Azure OpenAI finish_reason: {finish}");
376                            match finish {
377                                "stop" => {
378                                    stop_reason = Some(StopReason::EndTurn);
379                                }
380                                "tool_calls" => {
381                                    stop_reason = Some(StopReason::ToolUse);
382                                }
383                                "length" => {
384                                    stop_reason = Some(StopReason::MaxTokens);
385                                }
386                                _ => {}
387                            }
388                        }
389
390                        if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array())
391                        {
392                            for tc in tool_calls {
393                                if let Some(func) = tc.get("function") {
394                                    if let Some(name) = func.get("name").and_then(|n| n.as_str()) {
395                                        if !current_tool_id.is_empty()
396                                            && !current_tool_args.is_empty()
397                                        {
398                                            let input: serde_json::Value =
399                                                serde_json::from_str(&current_tool_args)
400                                                    .unwrap_or_default();
401                                            let _ = tx
402                                                .send(StreamEvent::ContentBlockComplete(
403                                                    ContentBlock::ToolUse {
404                                                        id: current_tool_id.clone(),
405                                                        name: current_tool_name.clone(),
406                                                        input,
407                                                    },
408                                                ))
409                                                .await;
410                                        }
411                                        current_tool_id = tc
412                                            .get("id")
413                                            .and_then(|i| i.as_str())
414                                            .unwrap_or("")
415                                            .to_string();
416                                        current_tool_name = name.to_string();
417                                        current_tool_args.clear();
418                                    }
419                                    if let Some(args) =
420                                        func.get("arguments").and_then(|a| a.as_str())
421                                    {
422                                        current_tool_args.push_str(args);
423                                    }
424                                }
425                            }
426                        }
427                    }
428                }
429            }
430
431            // Emit any remaining tool call.
432            if !current_tool_id.is_empty() {
433                let input: serde_json::Value =
434                    serde_json::from_str(&current_tool_args).unwrap_or_default();
435                let _ = tx
436                    .send(StreamEvent::ContentBlockComplete(ContentBlock::ToolUse {
437                        id: current_tool_id,
438                        name: current_tool_name,
439                        input,
440                    }))
441                    .await;
442            }
443
444            let _ = tx
445                .send(StreamEvent::Done {
446                    usage,
447                    stop_reason: Some(StopReason::EndTurn),
448                })
449                .await;
450        });
451
452        Ok(rx)
453    }
454}
455
456/// Convert content blocks to OpenAI format.
457fn blocks_to_openai_content(blocks: &[ContentBlock]) -> serde_json::Value {
458    if blocks.len() == 1
459        && let ContentBlock::Text { text } = &blocks[0]
460    {
461        return serde_json::Value::String(text.clone());
462    }
463
464    let parts: Vec<serde_json::Value> = blocks
465        .iter()
466        .map(|b| match b {
467            ContentBlock::Text { text } => serde_json::json!({
468                "type": "text",
469                "text": text,
470            }),
471            ContentBlock::Image { media_type, data } => serde_json::json!({
472                "type": "image_url",
473                "image_url": {
474                    "url": format!("data:{media_type};base64,{data}"),
475                }
476            }),
477            ContentBlock::ToolResult {
478                tool_use_id,
479                content,
480                is_error,
481                ..
482            } => serde_json::json!({
483                "type": "tool_result",
484                "tool_use_id": tool_use_id,
485                "content": content,
486                "is_error": is_error,
487            }),
488            ContentBlock::Thinking { thinking, .. } => serde_json::json!({
489                "type": "text",
490                "text": thinking,
491            }),
492            ContentBlock::ToolUse { name, input, .. } => serde_json::json!({
493                "type": "text",
494                "text": format!("[Tool call: {name}({input})]"),
495            }),
496            ContentBlock::Document { title, .. } => serde_json::json!({
497                "type": "text",
498                "text": format!("[Document: {}]", title.as_deref().unwrap_or("untitled")),
499            }),
500        })
501        .collect();
502
503    serde_json::Value::Array(parts)
504}