Skip to main content

agent_code_lib/llm/
azure_openai.rs

1//! Azure OpenAI provider.
2//!
3//! Uses the same OpenAI Chat Completions wire format but with Azure-specific
4//! URL patterns and authentication. The deployment name is part of the URL,
5//! so the model field is omitted from the request body.
6//!
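//! Auth: `api-key` header by default, or `Authorization: Bearer {ad_token}`
//! when `AZURE_OPENAI_AD_TOKEN` is set (the variable is read on every request).
//!
//! The `api-version` query parameter comes from `AZURE_OPENAI_API_VERSION` when
//! the provider is constructed, defaulting to `2024-10-21`.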
9
10use async_trait::async_trait;
11use futures::StreamExt;
12use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
13use tokio::sync::mpsc;
14use tracing::debug;
15
16use super::message::{ContentBlock, Message, StopReason, Usage};
17use super::provider::{Provider, ProviderError, ProviderRequest};
18use super::stream::StreamEvent;
19
20/// Azure OpenAI provider with `api-key` header auth and AD token support.
21pub struct AzureOpenAiProvider {
22    http: reqwest::Client,
23    base_url: String,
24    api_key: String,
25    api_version: String,
26}
27
28impl AzureOpenAiProvider {
29    pub fn new(base_url: &str, api_key: &str) -> Self {
30        let http = reqwest::Client::builder()
31            .timeout(std::time::Duration::from_secs(300))
32            .build()
33            .expect("failed to build HTTP client");
34
35        let api_version =
36            std::env::var("AZURE_OPENAI_API_VERSION").unwrap_or_else(|_| "2024-10-21".to_string());
37
38        Self {
39            http,
40            base_url: base_url.trim_end_matches('/').to_string(),
41            api_key: api_key.to_string(),
42            api_version,
43        }
44    }
45
46    /// Build the request body in OpenAI format, but without the `model` field
47    /// (Azure uses the deployment name from the URL instead).
48    fn build_body(&self, request: &ProviderRequest) -> serde_json::Value {
49        let mut messages = Vec::new();
50
51        // System message as first message.
52        if !request.system_prompt.is_empty() {
53            messages.push(serde_json::json!({
54                "role": "system",
55                "content": request.system_prompt,
56            }));
57        }
58
59        // Convert conversation messages.
60        for msg in &request.messages {
61            match msg {
62                Message::User(u) => {
63                    let content = blocks_to_openai_content(&u.content);
64                    messages.push(serde_json::json!({
65                        "role": "user",
66                        "content": content,
67                    }));
68                }
69                Message::Assistant(a) => {
70                    let mut msg_json = serde_json::json!({
71                        "role": "assistant",
72                    });
73
74                    let tool_calls: Vec<serde_json::Value> = a
75                        .content
76                        .iter()
77                        .filter_map(|b| match b {
78                            ContentBlock::ToolUse { id, name, input } => Some(serde_json::json!({
79                                "id": id,
80                                "type": "function",
81                                "function": {
82                                    "name": name,
83                                    "arguments": serde_json::to_string(input).unwrap_or_default(),
84                                }
85                            })),
86                            _ => None,
87                        })
88                        .collect();
89
90                    let text: String = a
91                        .content
92                        .iter()
93                        .filter_map(|b| match b {
94                            ContentBlock::Text { text } => Some(text.as_str()),
95                            _ => None,
96                        })
97                        .collect::<Vec<_>>()
98                        .join("");
99
100                    msg_json["content"] = serde_json::Value::String(text);
101                    if !tool_calls.is_empty() {
102                        msg_json["tool_calls"] = serde_json::Value::Array(tool_calls);
103                    }
104
105                    messages.push(msg_json);
106                }
107                Message::System(_) => {} // Already handled above.
108            }
109        }
110
111        // Handle tool results (OpenAI uses role: "tool").
112        let mut final_messages = Vec::new();
113        for msg in messages {
114            if msg.get("role").and_then(|r| r.as_str()) == Some("user")
115                && let Some(content) = msg.get("content")
116                && let Some(arr) = content.as_array()
117            {
118                let mut tool_results = Vec::new();
119                let mut other_content = Vec::new();
120
121                for block in arr {
122                    if block.get("type").and_then(|t| t.as_str()) == Some("tool_result") {
123                        tool_results.push(serde_json::json!({
124                                "role": "tool",
125                                "tool_call_id": block.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or(""),
126                                "content": block.get("content").and_then(|v| v.as_str()).unwrap_or(""),
127                            }));
128                    } else {
129                        other_content.push(block.clone());
130                    }
131                }
132
133                if !tool_results.is_empty() {
134                    for tr in tool_results {
135                        final_messages.push(tr);
136                    }
137                    if !other_content.is_empty() {
138                        let mut m = msg.clone();
139                        m["content"] = serde_json::Value::Array(other_content);
140                        final_messages.push(m);
141                    }
142                    continue;
143                }
144            }
145            final_messages.push(msg);
146        }
147
148        // Build tools in OpenAI format.
149        let tools: Vec<serde_json::Value> = request
150            .tools
151            .iter()
152            .map(|t| {
153                serde_json::json!({
154                    "type": "function",
155                    "function": {
156                        "name": t.name,
157                        "description": t.description,
158                        "parameters": t.input_schema,
159                    }
160                })
161            })
162            .collect();
163
164        // Azure: no "model" field — deployment name is in the URL.
165        let mut body = serde_json::json!({
166            "messages": final_messages,
167            "stream": true,
168            "stream_options": { "include_usage": true },
169            "max_tokens": request.max_tokens,
170        });
171
172        if !tools.is_empty() {
173            body["tools"] = serde_json::Value::Array(tools);
174
175            use super::provider::ToolChoice;
176            match &request.tool_choice {
177                ToolChoice::Auto => {
178                    body["tool_choice"] = serde_json::json!("auto");
179                }
180                ToolChoice::Any => {
181                    body["tool_choice"] = serde_json::json!("required");
182                }
183                ToolChoice::None => {
184                    body["tool_choice"] = serde_json::json!("none");
185                }
186                ToolChoice::Specific(name) => {
187                    body["tool_choice"] = serde_json::json!({
188                        "type": "function",
189                        "function": { "name": name }
190                    });
191                }
192            }
193        }
194        if let Some(temp) = request.temperature {
195            body["temperature"] = serde_json::json!(temp);
196        }
197
198        body
199    }
200}
201
202#[async_trait]
203impl Provider for AzureOpenAiProvider {
204    fn name(&self) -> &str {
205        "azure-openai"
206    }
207
208    async fn stream(
209        &self,
210        request: &ProviderRequest,
211    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError> {
212        let url = format!(
213            "{}/chat/completions?api-version={}",
214            self.base_url, self.api_version
215        );
216        let body = self.build_body(request);
217
218        let mut headers = HeaderMap::new();
219        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
220
221        // Azure AD token takes precedence over api-key header.
222        if let Ok(ad_token) = std::env::var("AZURE_OPENAI_AD_TOKEN") {
223            headers.insert(
224                AUTHORIZATION,
225                HeaderValue::from_str(&format!("Bearer {ad_token}"))
226                    .map_err(|e| ProviderError::Auth(e.to_string()))?,
227            );
228        } else {
229            headers.insert(
230                HeaderName::from_static("api-key"),
231                HeaderValue::from_str(&self.api_key)
232                    .map_err(|e| ProviderError::Auth(e.to_string()))?,
233            );
234        }
235
236        debug!("Azure OpenAI request to {url}");
237
238        let response = self
239            .http
240            .post(&url)
241            .headers(headers)
242            .json(&body)
243            .send()
244            .await
245            .map_err(|e| ProviderError::Network(e.to_string()))?;
246
247        let status = response.status();
248        if !status.is_success() {
249            let body_text = response.text().await.unwrap_or_default();
250            return match status.as_u16() {
251                401 | 403 => Err(ProviderError::Auth(body_text)),
252                429 => Err(ProviderError::RateLimited {
253                    retry_after_ms: 1000,
254                }),
255                529 => Err(ProviderError::Overloaded),
256                413 => Err(ProviderError::RequestTooLarge(body_text)),
257                _ => Err(ProviderError::Network(format!("{status}: {body_text}"))),
258            };
259        }
260
261        // Parse SSE stream — identical to OpenAI format.
262        let (tx, rx) = mpsc::channel(64);
263        tokio::spawn(async move {
264            let mut byte_stream = response.bytes_stream();
265            let mut buffer = String::new();
266            let mut current_tool_id = String::new();
267            let mut current_tool_name = String::new();
268            let mut current_tool_args = String::new();
269            let mut usage = Usage::default();
270            let mut stop_reason: Option<StopReason> = None;
271
272            while let Some(chunk_result) = byte_stream.next().await {
273                let chunk = match chunk_result {
274                    Ok(c) => c,
275                    Err(e) => {
276                        let _ = tx.send(StreamEvent::Error(e.to_string())).await;
277                        break;
278                    }
279                };
280
281                buffer.push_str(&String::from_utf8_lossy(&chunk));
282
283                while let Some(pos) = buffer.find("\n\n") {
284                    let event_text = buffer[..pos].to_string();
285                    buffer = buffer[pos + 2..].to_string();
286
287                    for line in event_text.lines() {
288                        let data = if let Some(d) = line.strip_prefix("data: ") {
289                            d
290                        } else {
291                            continue;
292                        };
293
294                        if data == "[DONE]" {
295                            if !current_tool_id.is_empty() {
296                                let input: serde_json::Value =
297                                    serde_json::from_str(&current_tool_args).unwrap_or_default();
298                                let _ = tx
299                                    .send(StreamEvent::ContentBlockComplete(
300                                        ContentBlock::ToolUse {
301                                            id: current_tool_id.clone(),
302                                            name: current_tool_name.clone(),
303                                            input,
304                                        },
305                                    ))
306                                    .await;
307                                current_tool_id.clear();
308                                current_tool_name.clear();
309                                current_tool_args.clear();
310                            }
311
312                            let _ = tx
313                                .send(StreamEvent::Done {
314                                    usage: usage.clone(),
315                                    stop_reason: stop_reason.clone().or(Some(StopReason::EndTurn)),
316                                })
317                                .await;
318                            return;
319                        }
320
321                        let parsed: serde_json::Value = match serde_json::from_str(data) {
322                            Ok(v) => v,
323                            Err(_) => continue,
324                        };
325
326                        let delta = match parsed
327                            .get("choices")
328                            .and_then(|c| c.get(0))
329                            .and_then(|c| c.get("delta"))
330                        {
331                            Some(d) => d,
332                            None => {
333                                if let Some(u) = parsed.get("usage") {
334                                    usage.input_tokens = u
335                                        .get("prompt_tokens")
336                                        .and_then(|v| v.as_u64())
337                                        .unwrap_or(0);
338                                    usage.output_tokens = u
339                                        .get("completion_tokens")
340                                        .and_then(|v| v.as_u64())
341                                        .unwrap_or(0);
342                                }
343                                continue;
344                            }
345                        };
346
347                        if let Some(content) = delta.get("content").and_then(|c| c.as_str())
348                            && !content.is_empty()
349                        {
350                            debug!(
351                                "Azure OpenAI text delta: {}",
352                                &content[..content.len().min(80)]
353                            );
354                            let _ = tx.send(StreamEvent::TextDelta(content.to_string())).await;
355                        }
356
357                        if let Some(finish) = parsed
358                            .get("choices")
359                            .and_then(|c| c.get(0))
360                            .and_then(|c| c.get("finish_reason"))
361                            .and_then(|f| f.as_str())
362                        {
363                            debug!("Azure OpenAI finish_reason: {finish}");
364                            match finish {
365                                "stop" => {
366                                    stop_reason = Some(StopReason::EndTurn);
367                                }
368                                "tool_calls" => {
369                                    stop_reason = Some(StopReason::ToolUse);
370                                }
371                                "length" => {
372                                    stop_reason = Some(StopReason::MaxTokens);
373                                }
374                                _ => {}
375                            }
376                        }
377
378                        if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array())
379                        {
380                            for tc in tool_calls {
381                                if let Some(func) = tc.get("function") {
382                                    if let Some(name) = func.get("name").and_then(|n| n.as_str()) {
383                                        if !current_tool_id.is_empty()
384                                            && !current_tool_args.is_empty()
385                                        {
386                                            let input: serde_json::Value =
387                                                serde_json::from_str(&current_tool_args)
388                                                    .unwrap_or_default();
389                                            let _ = tx
390                                                .send(StreamEvent::ContentBlockComplete(
391                                                    ContentBlock::ToolUse {
392                                                        id: current_tool_id.clone(),
393                                                        name: current_tool_name.clone(),
394                                                        input,
395                                                    },
396                                                ))
397                                                .await;
398                                        }
399                                        current_tool_id = tc
400                                            .get("id")
401                                            .and_then(|i| i.as_str())
402                                            .unwrap_or("")
403                                            .to_string();
404                                        current_tool_name = name.to_string();
405                                        current_tool_args.clear();
406                                    }
407                                    if let Some(args) =
408                                        func.get("arguments").and_then(|a| a.as_str())
409                                    {
410                                        current_tool_args.push_str(args);
411                                    }
412                                }
413                            }
414                        }
415                    }
416                }
417            }
418
419            // Emit any remaining tool call.
420            if !current_tool_id.is_empty() {
421                let input: serde_json::Value =
422                    serde_json::from_str(&current_tool_args).unwrap_or_default();
423                let _ = tx
424                    .send(StreamEvent::ContentBlockComplete(ContentBlock::ToolUse {
425                        id: current_tool_id,
426                        name: current_tool_name,
427                        input,
428                    }))
429                    .await;
430            }
431
432            let _ = tx
433                .send(StreamEvent::Done {
434                    usage,
435                    stop_reason: Some(StopReason::EndTurn),
436                })
437                .await;
438        });
439
440        Ok(rx)
441    }
442}
443
444/// Convert content blocks to OpenAI format.
445fn blocks_to_openai_content(blocks: &[ContentBlock]) -> serde_json::Value {
446    if blocks.len() == 1
447        && let ContentBlock::Text { text } = &blocks[0]
448    {
449        return serde_json::Value::String(text.clone());
450    }
451
452    let parts: Vec<serde_json::Value> = blocks
453        .iter()
454        .map(|b| match b {
455            ContentBlock::Text { text } => serde_json::json!({
456                "type": "text",
457                "text": text,
458            }),
459            ContentBlock::Image { media_type, data } => serde_json::json!({
460                "type": "image_url",
461                "image_url": {
462                    "url": format!("data:{media_type};base64,{data}"),
463                }
464            }),
465            ContentBlock::ToolResult {
466                tool_use_id,
467                content,
468                is_error,
469                ..
470            } => serde_json::json!({
471                "type": "tool_result",
472                "tool_use_id": tool_use_id,
473                "content": content,
474                "is_error": is_error,
475            }),
476            ContentBlock::Thinking { thinking, .. } => serde_json::json!({
477                "type": "text",
478                "text": thinking,
479            }),
480            ContentBlock::ToolUse { name, input, .. } => serde_json::json!({
481                "type": "text",
482                "text": format!("[Tool call: {name}({input})]"),
483            }),
484            ContentBlock::Document { title, .. } => serde_json::json!({
485                "type": "text",
486                "text": format!("[Document: {}]", title.as_deref().unwrap_or("untitled")),
487            }),
488        })
489        .collect();
490
491    serde_json::Value::Array(parts)
492}