Skip to main content

phi_core/provider/
azure_openai.rs

1//! Azure OpenAI provider.
2//!
3//! Uses the OpenAI Responses API format but with Azure-specific authentication
4//! and URL patterns.
5//!
6//! Base URL format: `https://{resource}.openai.azure.com/openai/deployments/{deployment}`
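//! Auth: Azure accepts an `api-key` header or an Azure AD Bearer token; this provider always
//! sends the resolved key in the `api-key` header (extra headers from `model_config.headers`
//! are passed through unchanged).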
8/*
9ARCHITECTURE: AzureOpenAiProvider — Azure's twist on the OpenAI API
10
11Azure hosts OpenAI models but with several differences from api.openai.com:
12
13URL structure:
14  OpenAI:   POST https://api.openai.com/v1/responses
15  Azure:    POST https://{resource}.openai.azure.com/openai/deployments/{model}/responses?api-version=...
16
17Authentication:
18  OpenAI:   `Authorization: Bearer sk-...`
19  Azure:    `api-key: {azure_api_key}` header (NOT Authorization Bearer)
20            OR `Authorization: Bearer {azure_ad_token}` for Azure AD auth
21
22The `api-version` query parameter is required for Azure and controls which
23version of the API specification to use.
24
25Model name:
26  In Azure, the model name in the URL is the "deployment name" (user-chosen),
27  not the OpenAI model ID. The body still sends `model` for display, but the
28  actual dispatch is done by the URL path.
29
Since the event format is the same as the OpenAI Responses API, the same SSE event names are
handled here; this file defines its own minimal deserialization types for them (DeltaEvent,
FnCallStartEvent, CompletedEvent, ...).
32*/
33
34use super::traits::*;
35use crate::types::*;
36use async_trait::async_trait;
37use futures::StreamExt;
38use reqwest_eventsource::EventSource;
39use serde::Deserialize;
40use tokio::sync::mpsc;
41use tracing::{debug, warn};
42
/// Azure OpenAI provider (Responses API wire format, `api-key` header authentication).
///
/// Unit struct — no state. All logic lives in the `StreamProvider` impl, so the value is
/// freely copyable and can be default-constructed wherever providers are registered.
#[derive(Debug, Clone, Copy, Default)]
pub struct AzureOpenAiProvider;
45
46#[async_trait]
47impl StreamProvider for AzureOpenAiProvider {
48    fn provider_id(&self) -> &str {
49        "azure"
50    }
51
52    async fn stream(
53        &self,
54        config: StreamConfig, // REQUEST — api_key sent as `api-key` header (NOT Bearer); base_url includes api-version
55        tx: mpsc::UnboundedSender<StreamEvent>, // OBSERVER — delegate to OpenAiCompatProvider::stream internally
56        cancel: tokio_util::sync::CancellationToken, // ABORT — forwarded to delegate
57    ) -> Result<Message, ProviderError> {
58        let model_config = &config.model_config;
59        // Resolve via CredentialProvider when set, else use the static `api_key`.
60        let api_key = model_config.resolve_api_key().await?;
61
62        /*
63        ARCHITECTURE: Azure URL construction
64
65        Azure requires an `api-version` query parameter. We append it here.
66        The `model_config.base_url` is the deployment-specific URL:
67          "https://{resource}.openai.azure.com/openai/deployments/{deployment}"
68        We append "/responses?api-version=..." to get the full endpoint.
69
70        RUST QUIRK: `format!("{}/responses?api-version=...", base_url)`
71          String formatting using positional `{}` placeholders.
72          `model_config.base_url` is `String`, which Display-formats as its content.
73          Python analogy: f"{base_url}/responses?api-version=2025-01-01-preview"
74        */
75        let url = format!(
76            "{}/responses?api-version=2025-01-01-preview",
77            model_config.base_url
78        );
79
80        let body = build_azure_request_body(&config);
81        debug!(
82            "Azure OpenAI request: model={} url={}",
83            config.model_config.id, url
84        );
85
86        let client = reqwest::Client::new();
87        let mut request = client
88            .post(&url)
89            .header("content-type", "application/json")
90            .header("api-key", &api_key); // Azure uses `api-key` header, NOT `Authorization: Bearer`
91
92        for (k, v) in &model_config.headers {
93            request = request.header(k, v);
94        }
95
96        let request = request.json(&body);
97        let mut es =
98            EventSource::new(request).map_err(|e| ProviderError::Network(e.to_string()))?;
99
100        let mut content: Vec<Content> = Vec::new();
101        let mut usage = Usage::default();
102        let mut stop_reason = StopReason::Stop;
103        let mut tool_call_buffers: Vec<ToolCallBuffer> = Vec::new();
104
105        let _ = tx.send(StreamEvent::Start);
106
107        loop {
108            tokio::select! {
109                _ = cancel.cancelled() => {
110                    es.close();
111                    return Err(ProviderError::Cancelled);
112                }
113                event = es.next() => {
114                    match event {
115                        None => break,
116                        Some(Ok(reqwest_eventsource::Event::Open)) => {}
117                        Some(Ok(reqwest_eventsource::Event::Message(msg))) => {
118                            match msg.event.as_str() {
119                                "response.output_text.delta" => {
120                                    if let Ok(data) = serde_json::from_str::<DeltaEvent>(&msg.data) {
121                                        let idx = content.iter().position(|c| matches!(c, Content::Text { .. }));
122                                        let idx = match idx {
123                                            Some(i) => i,
124                                            None => {
125                                                content.push(Content::Text { text: String::new() });
126                                                content.len() - 1
127                                            }
128                                        };
129                                        if let Some(Content::Text { text }) = content.get_mut(idx) {
130                                            text.push_str(&data.delta);
131                                        }
132                                        let _ = tx.send(StreamEvent::TextDelta {
133                                            content_index: idx,
134                                            delta: data.delta,
135                                        });
136                                    }
137                                }
138                                "response.function_call_arguments.start" => {
139                                    if let Ok(data) = serde_json::from_str::<FnCallStartEvent>(&msg.data) {
140                                        tool_call_buffers.push(ToolCallBuffer {
141                                            id: data.call_id.unwrap_or_default(),
142                                            name: data.name.unwrap_or_default(),
143                                            arguments: String::new(),
144                                        });
145                                        let buf = tool_call_buffers.last().unwrap();
146                                        let _ = tx.send(StreamEvent::ToolCallStart {
147                                            content_index: content.len() + tool_call_buffers.len() - 1,
148                                            id: buf.id.clone(),
149                                            name: buf.name.clone(),
150                                        });
151                                    }
152                                }
153                                "response.function_call_arguments.delta" => {
154                                    if let Ok(data) = serde_json::from_str::<DeltaEvent>(&msg.data) {
155                                        if let Some(buf) = tool_call_buffers.last_mut() {
156                                            buf.arguments.push_str(&data.delta);
157                                            let _ = tx.send(StreamEvent::ToolCallDelta {
158                                                content_index: content.len() + tool_call_buffers.len() - 1,
159                                                delta: data.delta,
160                                            });
161                                        }
162                                    }
163                                }
164                                "response.completed" => {
165                                    if let Ok(data) = serde_json::from_str::<CompletedEvent>(&msg.data) {
166                                        if let Some(resp) = data.response {
167                                            if let Some(u) = resp.usage {
168                                                usage.input = u.input_tokens;
169                                                usage.output = u.output_tokens;
170                                                usage.total_tokens = u.total_tokens;
171                                            }
172                                        }
173                                    }
174                                    break;
175                                }
176                                "error" => {
177                                    warn!("Azure OpenAI error: {}", msg.data);
178                                    let err_msg = Message::Assistant {
179                                        content: vec![Content::Text { text: String::new() }],
180                                        stop_reason: StopReason::Error,
181                                        model: config.model_config.id.clone(),
182                                        provider: model_config.provider.clone(),
183                                        usage: usage.clone(),
184                                        timestamp: now_ms(),
185                                        error_message: Some(msg.data),
186                                    };
187                                    let _ = tx.send(StreamEvent::Error { message: err_msg.clone() });
188                                    return Ok(err_msg);
189                                }
190                                _ => {}
191                            }
192                        }
193                        Some(Err(e)) => {
194                            let err_str = e.to_string();
195                            warn!("Azure SSE error: {}", err_str);
196                            let err_msg = Message::Assistant {
197                                content: vec![Content::Text { text: String::new() }],
198                                stop_reason: StopReason::Error,
199                                model: config.model_config.id.clone(),
200                                provider: model_config.provider.clone(),
201                                usage: usage.clone(),
202                                timestamp: now_ms(),
203                                error_message: Some(err_str),
204                            };
205                            let _ = tx.send(StreamEvent::Error { message: err_msg.clone() });
206                            return Ok(err_msg);
207                        }
208                    }
209                }
210            }
211        }
212
213        for buf in &tool_call_buffers {
214            let args = serde_json::from_str(&buf.arguments)
215                .unwrap_or(serde_json::Value::Object(Default::default()));
216            content.push(Content::ToolCall {
217                id: buf.id.clone(),
218                name: buf.name.clone(),
219                arguments: args,
220            });
221            let _ = tx.send(StreamEvent::ToolCallEnd {
222                content_index: content.len() - 1,
223            });
224        }
225
226        if content
227            .iter()
228            .any(|c| matches!(c, Content::ToolCall { .. }))
229        {
230            stop_reason = StopReason::ToolUse;
231        }
232
233        let message = Message::Assistant {
234            content,
235            stop_reason,
236            model: config.model_config.id.clone(),
237            provider: model_config.provider.clone(),
238            usage,
239            timestamp: now_ms(),
240            error_message: None,
241        };
242
243        let _ = tx.send(StreamEvent::Done {
244            message: message.clone(),
245        });
246        Ok(message)
247    }
248}
249
/// Accumulates one streamed function call until the SSE stream finishes.
struct ToolCallBuffer {
    /// Call id reported by the service (empty when the event omitted it).
    id: String,
    /// Name of the function being called.
    name: String,
    /// Raw argument JSON text, concatenated from the argument delta events.
    arguments: String,
}
255
256fn build_azure_request_body(config: &StreamConfig) -> serde_json::Value {
257    // Same format as OpenAI Responses API
258    let mut input: Vec<serde_json::Value> = Vec::new();
259
260    for msg in &config.messages {
261        match msg {
262            Message::User { content, .. } => {
263                // Build content array for user message (supports text + images)
264                let user_content: Vec<serde_json::Value> = content
265                    .iter()
266                    .filter_map(|c| match c {
267                        Content::Text { text } => Some(serde_json::json!({
268                            "type": "input_text",
269                            "text": text,
270                        })),
271                        Content::Image { data, mime_type } => Some(serde_json::json!({
272                            "type": "input_image",
273                            "image_url": format!("data:{};base64,{}", mime_type, data),
274                        })),
275                        _ => None,
276                    })
277                    .collect();
278
279                if user_content.len() == 1 && user_content[0]["type"] == "input_text" {
280                    // Simple text-only message can use shorthand format
281                    input.push(serde_json::json!({
282                        "role": "user",
283                        "content": user_content[0]["text"].as_str().unwrap_or(""),
284                    }));
285                } else {
286                    // Multi-modal content uses array format
287                    input.push(serde_json::json!({
288                        "role": "user",
289                        "content": user_content,
290                    }));
291                }
292            }
293            Message::Assistant { content, .. } => {
294                for c in content {
295                    match c {
296                        Content::Text { text } => {
297                            input.push(serde_json::json!({
298                                "type": "message",
299                                "role": "assistant",
300                                "content": [{"type": "output_text", "text": text}],
301                            }));
302                        }
303                        Content::ToolCall {
304                            id,
305                            name,
306                            arguments,
307                        } => {
308                            input.push(serde_json::json!({
309                                "type": "function_call",
310                                "call_id": id,
311                                "name": name,
312                                "arguments": arguments.to_string(),
313                            }));
314                        }
315                        _ => {}
316                    }
317                }
318            }
319            Message::ToolResult {
320                tool_call_id,
321                content,
322                ..
323            } => {
324                let output_val = if content.iter().any(|c| matches!(c, Content::Image { .. })) {
325                    let parts: Vec<serde_json::Value> = content
326                        .iter()
327                        .filter_map(|c| match c {
328                            Content::Text { text } => Some(serde_json::json!({
329                                "type": "input_text",
330                                "text": text,
331                            })),
332                            Content::Image { data, mime_type } => Some(serde_json::json!({
333                                "type": "input_image",
334                                "image_url": format!("data:{};base64,{}", mime_type, data),
335                            })),
336                            _ => None,
337                        })
338                        .collect();
339                    serde_json::json!(parts)
340                } else {
341                    let text = content
342                        .iter()
343                        .find_map(|c| match c {
344                            Content::Text { text } => Some(text.clone()),
345                            _ => None,
346                        })
347                        .unwrap_or_default();
348                    serde_json::json!(text)
349                };
350                input.push(serde_json::json!({
351                    "type": "function_call_output",
352                    "call_id": tool_call_id,
353                    "output": output_val,
354                }));
355            }
356        }
357    }
358
359    let mut body = serde_json::json!({
360        "model": config.model_config.id,
361        "stream": true,
362        "input": input,
363    });
364
365    if !config.system_prompt.is_empty() {
366        body["instructions"] = serde_json::json!(config.system_prompt);
367    }
368
369    if let Some(max) = config.max_tokens {
370        body["max_output_tokens"] = serde_json::json!(max);
371    }
372
373    if !config.tools.is_empty() {
374        let tools: Vec<serde_json::Value> = config
375            .tools
376            .iter()
377            .map(|t| {
378                serde_json::json!({
379                    "type": "function",
380                    "name": t.name,
381                    "description": t.description,
382                    "parameters": t.parameters,
383                })
384            })
385            .collect();
386        body["tools"] = serde_json::json!(tools);
387    }
388
389    if let Some(temp) = config.temperature {
390        body["temperature"] = serde_json::json!(temp);
391    }
392
393    // Structured-output wiring (Azure Responses API shares the OpenAI Responses shape:
394    // `text.format` rather than top-level `response_format`).
395    match &config.response_format {
396        ResponseFormat::Text => {} // default; omit the field
397        ResponseFormat::JsonObject => {
398            body["text"] = serde_json::json!({"format": {"type": "json_object"}});
399        }
400        ResponseFormat::JsonSchema {
401            schema,
402            name,
403            strict,
404        } => {
405            body["text"] = serde_json::json!({
406                "format": {
407                    "type": "json_schema",
408                    "name": name,
409                    "schema": schema,
410                    "strict": *strict,
411                },
412            });
413        }
414    }
415
416    body
417}
418
419// Event types
420#[derive(Deserialize)]
421struct DeltaEvent {
422    delta: String,
423}
424
425#[derive(Deserialize)]
426struct FnCallStartEvent {
427    #[serde(default)]
428    call_id: Option<String>,
429    #[serde(default)]
430    name: Option<String>,
431}
432
433#[derive(Deserialize)]
434struct CompletedEvent {
435    #[serde(default)]
436    response: Option<ResponseData>,
437}
438
439#[derive(Deserialize)]
440struct ResponseData {
441    #[serde(default)]
442    usage: Option<AzureUsage>,
443}
444
445#[derive(Deserialize)]
446struct AzureUsage {
447    #[serde(default)]
448    input_tokens: u64,
449    #[serde(default)]
450    output_tokens: u64,
451    #[serde(default)]
452    total_tokens: u64,
453}