// pi/providers/azure.rs

//! Azure OpenAI Chat Completions API provider implementation.
//!
//! This module implements the Provider trait for Azure OpenAI, using the same
//! streaming protocol as OpenAI but with Azure-specific authentication and endpoints.
//!
//! Azure OpenAI URL format:
//! `https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}`
9use crate::error::{Error, Result};
10use crate::http::client::Client;
11use crate::model::{
12    AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, Usage, UserContent,
13};
14use crate::models::CompatConfig;
15use crate::provider::{Context, Provider, StreamOptions, ToolDef};
16use crate::sse::SseStream;
17use async_trait::async_trait;
18use futures::StreamExt;
19use futures::stream::{self, Stream};
20use serde::{Deserialize, Serialize};
21use std::collections::VecDeque;
22use std::pin::Pin;
23
// ============================================================================
// Constants
// ============================================================================

/// Default Azure OpenAI REST API version, used unless the caller overrides it
/// via `AzureOpenAIProvider::with_api_version`.
pub(crate) const DEFAULT_API_VERSION: &str = "2024-02-15-preview";
/// Fallback `max_tokens` applied when `StreamOptions::max_tokens` is unset.
const DEFAULT_MAX_TOKENS: u32 = 4096;
30
/// Normalize Azure role names while preserving unknown compat overrides as-is.
///
/// The input is trimmed of surrounding whitespace. Known role names are
/// canonicalized to lowercase (`"ASSISTANT"` -> `"assistant"`); any other
/// value is returned trimmed but otherwise untouched so custom compat role
/// overrides pass through verbatim.
fn normalize_role(role: &str) -> String {
    const KNOWN: [&str; 6] = ["system", "developer", "user", "assistant", "tool", "function"];
    let trimmed = role.trim();
    if KNOWN.contains(&trimmed) {
        return trimmed.to_string();
    }
    let lowered = trimmed.to_ascii_lowercase();
    if KNOWN.contains(&lowered.as_str()) {
        lowered
    } else {
        trimmed.to_string()
    }
}
45
// ============================================================================
// Azure OpenAI Provider
// ============================================================================

/// Azure OpenAI Chat Completions API provider.
///
/// Speaks the OpenAI-compatible streaming protocol against an Azure-hosted
/// deployment, authenticating with the `api-key` header instead of a
/// `Bearer` token.
pub struct AzureOpenAIProvider {
    /// HTTP client used for all requests (swappable for tests via `with_client`).
    client: Client,
    /// The deployment name (model deployment in Azure)
    deployment: String,
    /// Azure resource name (part of the URL)
    resource: String,
    /// API version string
    api_version: String,
    /// Optional override for the full endpoint URL (primarily for deterministic tests).
    endpoint_url_override: Option<String>,
    /// Optional compatibility overrides (custom headers, system role name).
    compat: Option<CompatConfig>,
}
63
64impl AzureOpenAIProvider {
65    /// Create a new Azure OpenAI provider.
66    ///
67    /// # Arguments
68    /// * `resource` - Azure OpenAI resource name
69    /// * `deployment` - Model deployment name
70    pub fn new(resource: impl Into<String>, deployment: impl Into<String>) -> Self {
71        Self {
72            client: Client::new(),
73            deployment: deployment.into(),
74            resource: resource.into(),
75            api_version: DEFAULT_API_VERSION.to_string(),
76            endpoint_url_override: None,
77            compat: None,
78        }
79    }
80
81    /// Set the API version.
82    #[must_use]
83    pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
84        self.api_version = version.into();
85        self
86    }
87
88    /// Override the full endpoint URL.
89    ///
90    /// This is intended for deterministic, offline tests (e.g. mock servers). Production
91    /// code should rely on the standard Azure endpoint URL format.
92    #[must_use]
93    pub fn with_endpoint_url(mut self, endpoint_url: impl Into<String>) -> Self {
94        self.endpoint_url_override = Some(endpoint_url.into());
95        self
96    }
97
98    /// Create with a custom HTTP client (VCR, test harness, etc.).
99    #[must_use]
100    pub fn with_client(mut self, client: Client) -> Self {
101        self.client = client;
102        self
103    }
104
105    /// Attach provider-specific compatibility overrides.
106    #[must_use]
107    pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
108        self.compat = compat;
109        self
110    }
111
112    /// Get the full endpoint URL.
113    fn endpoint_url(&self) -> String {
114        if let Some(url) = &self.endpoint_url_override {
115            return url.clone();
116        }
117        format!(
118            "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
119            self.resource, self.deployment, self.api_version
120        )
121    }
122
123    /// Build the request body for Azure OpenAI (same format as OpenAI).
124    #[allow(clippy::unused_self)]
125    pub fn build_request(&self, context: &Context<'_>, options: &StreamOptions) -> AzureRequest {
126        let messages = self.build_messages(context);
127
128        let tools: Option<Vec<AzureTool>> = if context.tools.is_empty() {
129            None
130        } else {
131            Some(context.tools.iter().map(convert_tool_to_azure).collect())
132        };
133
134        AzureRequest {
135            messages,
136            max_tokens: options.max_tokens.or(Some(DEFAULT_MAX_TOKENS)),
137            temperature: options.temperature,
138            tools,
139            stream: true,
140            stream_options: Some(AzureStreamOptions {
141                include_usage: true,
142            }),
143        }
144    }
145
146    /// Build the messages array with system prompt prepended.
147    fn build_messages(&self, context: &Context<'_>) -> Vec<AzureMessage> {
148        let mut messages = Vec::new();
149        let system_role = self
150            .compat
151            .as_ref()
152            .and_then(|c| c.system_role_name.as_deref())
153            .unwrap_or("system");
154
155        // Add system prompt as first message
156        if let Some(system) = &context.system_prompt {
157            messages.push(AzureMessage {
158                role: normalize_role(system_role),
159                content: Some(AzureContent::Text(system.to_string())),
160                tool_calls: None,
161                tool_call_id: None,
162            });
163        }
164
165        // Convert conversation messages
166        for message in context.messages.iter() {
167            messages.extend(convert_message_to_azure(message));
168        }
169
170        messages
171    }
172}
173
#[async_trait]
impl Provider for AzureOpenAIProvider {
    /// Provider identifier (used in logs and assistant-message metadata).
    fn name(&self) -> &'static str {
        "azure"
    }

    /// Wire-protocol identifier: OpenAI Chat Completions hosted on Azure.
    fn api(&self) -> &'static str {
        "azure-openai"
    }

    /// Azure addresses models by deployment name, so that is the model id.
    fn model_id(&self) -> &str {
        &self.deployment
    }

    /// Start a streaming chat-completions call and adapt the SSE chunks
    /// into [`StreamEvent`]s.
    ///
    /// API key resolution order: `options.api_key`, then the
    /// `AZURE_OPENAI_API_KEY` environment variable.
    ///
    /// # Errors
    /// * No API key could be resolved.
    /// * The HTTP response status is non-2xx (body included in the error).
    /// * SSE/JSON decoding failures are surfaced as `Err` stream items.
    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        let auth_value = options
            .api_key
            .clone()
            .or_else(|| std::env::var("AZURE_OPENAI_API_KEY").ok())
            .ok_or_else(|| Error::provider("azure-openai", "Missing API key for Azure OpenAI. Set AZURE_OPENAI_API_KEY or configure in settings."))?;

        let request_body = self.build_request(context, options);

        let endpoint_url = self.endpoint_url();

        // Build request with Azure-specific headers (Content-Type set by .json() below)
        let mut request = self
            .client
            .post(&endpoint_url)
            .header("Accept", "text/event-stream")
            .header("api-key", &auth_value); // Azure uses api-key header, not Authorization

        // Apply provider-specific custom headers from compat config.
        if let Some(compat) = &self.compat {
            if let Some(custom_headers) = &compat.custom_headers {
                for (key, value) in custom_headers {
                    request = request.header(key, value);
                }
            }
        }

        // Per-call headers are applied last, so they win over compat headers.
        for (key, value) in &options.headers {
            request = request.header(key, value);
        }

        let request = request.json(&request_body)?;

        let response = Box::pin(request.send()).await?;
        let status = response.status();
        if !(200..300).contains(&status) {
            let body = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
            return Err(Error::provider(
                "azure-openai",
                format!("Azure OpenAI API error (HTTP {status}): {body}"),
            ));
        }

        // Create SSE stream for streaming responses.
        let event_source = SseStream::new(response.bytes_stream());

        // Create stream state
        let model = self.deployment.clone();
        let api = self.api().to_string();
        let provider = self.name().to_string();

        // Drive the SSE source with `unfold`: each turn either flushes a
        // queued event or pulls and parses the next SSE message.
        let stream = stream::unfold(
            StreamState::new(event_source, model, api, provider),
            |mut state| async move {
                if state.done {
                    return None;
                }
                loop {
                    // Flush events queued by a previous chunk before reading more.
                    if let Some(event) = state.pending_events.pop_front() {
                        return Some((Ok(event), state));
                    }

                    match state.event_source.next().await {
                        Some(Ok(msg)) => {
                            // Azure also sends "[DONE]" as final message
                            if msg.data == "[DONE]" {
                                state.done = true;
                                let reason = state.partial.stop_reason;
                                let message = std::mem::take(&mut state.partial);
                                return Some((Ok(StreamEvent::Done { reason, message }), state));
                            }

                            if let Err(e) = state.process_event(&msg.data) {
                                state.done = true;
                                return Some((Err(e), state));
                            }
                        }
                        Some(Err(e)) => {
                            state.done = true;
                            let err = Error::api(format!("SSE error: {e}"));
                            return Some((Err(err), state));
                        }
                        // Stream ended without [DONE] sentinel (e.g.
                        // premature server disconnect).  Emit Done so the
                        // agent loop receives the accumulated partial
                        // instead of silently losing it.
                        None => {
                            state.done = true;
                            let reason = state.partial.stop_reason;
                            let message = std::mem::take(&mut state.partial);
                            return Some((Ok(StreamEvent::Done { reason, message }), state));
                        }
                    }
                }
            },
        );

        Ok(Box::pin(stream))
    }
}

// ============================================================================
// Stream State
// ============================================================================

/// Mutable state threaded through the `unfold` loop that adapts Azure SSE
/// chunks into [`StreamEvent`]s.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// SSE message source over the HTTP response body.
    event_source: SseStream<S>,
    /// Assistant message accumulated from deltas; taken on `Done`.
    partial: AssistantMessage,
    /// In-flight tool calls, matched by the provider's logical index.
    tool_calls: Vec<ToolCallState>,
    /// Events parsed from a chunk but not yet yielded to the consumer.
    pending_events: VecDeque<StreamEvent>,
    /// Whether the one-time `Start` event has been queued.
    started: bool,
    /// Set once the stream has terminated (after `Done` or an error).
    done: bool,
}
311
/// Accumulator for one streamed tool call.
struct ToolCallState {
    /// Logical index reported by Azure (may be sparse / non-contiguous).
    index: usize,
    /// Position of the matching `ToolCall` block in `partial.content`.
    content_index: usize,
    /// Tool call id (arrives incrementally).
    id: String,
    /// Tool/function name (arrives incrementally).
    name: String,
    /// Raw argument JSON fragments, concatenated; parsed once at finish.
    arguments: String,
}
319
impl<S> StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Create a fresh stream state with an empty partial assistant message
    /// stamped with the given api/provider/model metadata.
    fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
        Self {
            event_source,
            partial: AssistantMessage {
                content: Vec::new(),
                api,
                provider,
                model,
                usage: Usage::default(),
                stop_reason: StopReason::Stop,
                error_message: None,
                timestamp: chrono::Utc::now().timestamp_millis(),
            },
            tool_calls: Vec::new(),
            pending_events: VecDeque::new(),
            started: false,
            done: false,
        }
    }

    /// Parse each tool call's accumulated argument string as JSON and store
    /// the result in the matching content block. Unparseable arguments are
    /// logged and replaced with `null` rather than failing the stream.
    fn finalize_tool_call_arguments(&mut self) {
        for tc in &self.tool_calls {
            let arguments: serde_json::Value = match serde_json::from_str(&tc.arguments) {
                Ok(args) => args,
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        raw = %tc.arguments,
                        "Failed to parse tool arguments as JSON"
                    );
                    serde_json::Value::Null
                }
            };

            if let Some(ContentBlock::ToolCall(block)) =
                self.partial.content.get_mut(tc.content_index)
            {
                block.arguments = arguments;
            }
        }
    }

    /// Append a text delta to the partial message, opening a new text block
    /// (and queuing `TextStart`) when the last block is not text. Returns
    /// the `TextDelta` event for the caller to queue.
    fn push_text_delta(&mut self, text: String) -> StreamEvent {
        let last_is_text = matches!(self.partial.content.last(), Some(ContentBlock::Text(_)));
        if !last_is_text {
            let content_index = self.partial.content.len();
            self.partial
                .content
                .push(ContentBlock::Text(crate::model::TextContent::new("")));
            self.pending_events
                .push_back(StreamEvent::TextStart { content_index });
        }
        // After the push above, the last block is guaranteed to be text.
        let content_index = self.partial.content.len() - 1;

        if let Some(ContentBlock::Text(t)) = self.partial.content.get_mut(content_index) {
            t.text.push_str(&text);
        }

        StreamEvent::TextDelta {
            content_index,
            delta: text,
        }
    }

    /// Queue the one-time `Start` event carrying a snapshot of the partial.
    fn ensure_started(&mut self) {
        if !self.started {
            self.started = true;
            self.pending_events.push_back(StreamEvent::Start {
                partial: self.partial.clone(),
            });
        }
    }

    /// Parse one SSE `data:` payload and translate it into queued events.
    ///
    /// Ordering contract (matching the OpenAI provider): content deltas are
    /// processed before `finish_reason`, so `TextEnd`/`ToolCallEnd` always
    /// follow the final delta and carry the complete accumulated content.
    ///
    /// # Errors
    /// Returns an error when the payload is not valid chunk JSON.
    #[allow(clippy::unnecessary_wraps, clippy::too_many_lines)]
    fn process_event(&mut self, data: &str) -> Result<()> {
        let chunk: AzureStreamChunk =
            serde_json::from_str(data).map_err(|e| Error::api(format!("JSON parse error: {e}")))?;

        // Process usage if present (Azure sends it on a trailing chunk when
        // stream_options.include_usage is set).
        if let Some(usage) = chunk.usage {
            self.partial.usage.input = usage.prompt_tokens;
            self.partial.usage.output = usage.completion_tokens.unwrap_or(0);
            self.partial.usage.total_tokens = usage.total_tokens;
        }

        let choices = chunk.choices;
        if !self.started {
            // Some servers send an initial empty-delta chunk (e.g. role-only);
            // emit Start for it without producing any content events.
            let first = choices.first();
            let delta_is_empty = first.is_some_and(|choice| {
                choice.finish_reason.is_none()
                    && choice.delta.content.is_none()
                    && choice.delta.tool_calls.is_none()
            });
            if delta_is_empty {
                self.ensure_started();
                return Ok(());
            }
        }

        // Process choices — handle content deltas BEFORE finish_reason so that
        // TextEnd/ToolCallEnd events always follow the final delta (matching the
        // OpenAI provider event ordering contract).
        for choice in choices {
            // Handle text content
            if let Some(text) = choice.delta.content {
                self.ensure_started();
                let event = self.push_text_delta(text);
                self.pending_events.push_back(event);
            }

            // Handle tool calls
            if let Some(tool_calls) = choice.delta.tool_calls {
                self.ensure_started();

                for tc in tool_calls {
                    let idx = tc.index as usize;

                    // Azure may emit sparse tool-call indices. Match by logical index
                    // instead of assuming contiguous 0..N ordering in arrival order.
                    let tool_state_idx = if let Some(existing_idx) =
                        self.tool_calls.iter().position(|tc| tc.index == idx)
                    {
                        existing_idx
                    } else {
                        let content_index = self.partial.content.len();
                        self.tool_calls.push(ToolCallState {
                            index: idx,
                            content_index,
                            id: String::new(),
                            name: String::new(),
                            arguments: String::new(),
                        });

                        // Initialize block in partial
                        self.partial
                            .content
                            .push(ContentBlock::ToolCall(crate::model::ToolCall {
                                id: String::new(),
                                name: String::new(),
                                arguments: serde_json::Value::Null,
                                thought_signature: None,
                            }));

                        // Emit ToolCallStart
                        self.pending_events
                            .push_back(StreamEvent::ToolCallStart { content_index });
                        self.tool_calls.len() - 1
                    };

                    let tc_state = &mut self.tool_calls[tool_state_idx];
                    let content_index = tc_state.content_index;

                    // Update the tool call state
                    if let Some(id) = tc.id {
                        tc_state.id.clone_from(&id);
                        if let Some(ContentBlock::ToolCall(block)) =
                            self.partial.content.get_mut(content_index)
                        {
                            block.id = id;
                        }
                    }
                    if let Some(func) = tc.function {
                        if let Some(name) = func.name {
                            tc_state.name.clone_from(&name);
                            if let Some(ContentBlock::ToolCall(block)) =
                                self.partial.content.get_mut(content_index)
                            {
                                block.name = name;
                            }
                        }
                        if let Some(args) = func.arguments {
                            tc_state.arguments.push_str(&args);
                            // Note: we don't update partial arguments here as they need to be valid JSON.
                            // We do that at the end.

                            self.pending_events.push_back(StreamEvent::ToolCallDelta {
                                content_index,
                                delta: args,
                            });
                        }
                    }
                }
            }

            // Handle finish reason (MUST come after delta processing so TextEnd/ToolCallEnd
            // events contain the complete accumulated content).
            // Ensure Start is emitted even when finish arrives in an empty-delta chunk.
            if choice.finish_reason.is_some() {
                self.ensure_started();
            }
            if let Some(reason) = choice.finish_reason {
                self.partial.stop_reason = match reason.as_str() {
                    "length" => StopReason::Length,
                    "content_filter" => StopReason::Error,
                    "tool_calls" => StopReason::ToolUse,
                    // "stop" and any other reason treated as normal stop
                    _ => StopReason::Stop,
                };

                // Finalize tool call arguments
                self.finalize_tool_call_arguments();

                // Emit TextEnd/ThinkingEnd for all open text/thinking blocks.
                for (content_index, block) in self.partial.content.iter().enumerate() {
                    if let ContentBlock::Text(t) = block {
                        self.pending_events.push_back(StreamEvent::TextEnd {
                            content_index,
                            content: t.text.clone(),
                        });
                    } else if let ContentBlock::Thinking(t) = block {
                        self.pending_events.push_back(StreamEvent::ThinkingEnd {
                            content_index,
                            content: t.thinking.clone(),
                        });
                    }
                }

                // Emit ToolCallEnd for each accumulated tool call
                for tc in &self.tool_calls {
                    if let Some(ContentBlock::ToolCall(tool_call)) =
                        self.partial.content.get(tc.content_index)
                    {
                        self.pending_events.push_back(StreamEvent::ToolCallEnd {
                            content_index: tc.content_index,
                            tool_call: tool_call.clone(),
                        });
                    }
                }
            }
        }

        Ok(())
    }
}
558
// ============================================================================
// Request Types
// ============================================================================

/// Chat-completions request body (OpenAI-compatible wire format).
#[derive(Debug, Serialize)]
pub struct AzureRequest {
    messages: Vec<AzureMessage>,
    /// Maximum tokens to generate; always populated by `build_request`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<AzureTool>>,
    /// Always `true`; this provider only supports streaming responses.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<AzureStreamOptions>,
}
576
/// `stream_options` payload; `include_usage` asks the service to append a
/// final chunk carrying token usage.
#[derive(Debug, Serialize)]
struct AzureStreamOptions {
    include_usage: bool,
}
581
/// A single chat message in Azure wire format.
#[derive(Debug, Serialize)]
struct AzureMessage {
    /// Role name: `system`/`user`/`assistant`/`tool` (or a compat override).
    role: String,
    /// Payload; omitted for assistant messages that only carry tool calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<AzureContent>,
    /// Tool calls issued by an assistant message.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<AzureToolCallRef>>,
    /// For `tool` role messages: id of the call being answered.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
592
/// Message content: either a bare string or a list of typed parts.
/// `untagged` makes serde emit whichever shape the variant holds.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum AzureContent {
    Text(String),
    Parts(Vec<AzureContentPart>),
}
599
/// One typed content part (tagged with `"type"` on the wire).
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
enum AzureContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: AzureImageUrl },
}
608
/// Image reference; `url` is a base64 data URL in this provider.
#[derive(Debug, Serialize)]
struct AzureImageUrl {
    url: String,
}
613
/// A completed tool call echoed back in assistant history.
#[derive(Debug, Serialize)]
struct AzureToolCallRef {
    id: String,
    /// Always `"function"`.
    r#type: &'static str,
    function: AzureFunctionRef,
}
620
/// Function name plus raw JSON-encoded arguments for a historical tool call.
#[derive(Debug, Serialize)]
struct AzureFunctionRef {
    name: String,
    arguments: String,
}
626
/// Tool definition advertised to the model (`type` is always `"function"`).
#[derive(Debug, Serialize)]
struct AzureTool {
    r#type: &'static str,
    function: AzureFunction,
}
632
/// Function schema: name, description, and JSON Schema parameters.
#[derive(Debug, Serialize)]
struct AzureFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
639
// ============================================================================
// Streaming Response Types
// ============================================================================

/// One decoded SSE chunk. `choices` may be empty on the trailing
/// usage-only chunk; `usage` is only present on that final chunk.
#[derive(Debug, Deserialize)]
struct AzureStreamChunk {
    #[serde(default)]
    choices: Vec<AzureChoice>,
    #[serde(default)]
    usage: Option<AzureUsage>,
}
651
/// A single streamed choice: an incremental delta plus, on the last
/// content-bearing chunk, a `finish_reason`.
#[derive(Debug, Deserialize)]
struct AzureChoice {
    delta: AzureDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}
658
/// Incremental message fragment: text and/or tool-call deltas.
#[derive(Debug, Deserialize)]
struct AzureDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<AzureToolCallDelta>>,
}
666
/// Tool-call fragment keyed by a logical `index` (indices may be sparse).
#[derive(Debug, Deserialize)]
struct AzureToolCallDelta {
    index: u32,
    /// Present only on the first fragment of a given call.
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<AzureFunctionDelta>,
}
675
/// Function fragment: name arrives once; arguments arrive as JSON slices
/// that must be concatenated before parsing.
#[derive(Debug, Deserialize)]
struct AzureFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}
683
684#[derive(Debug, Deserialize)]
685#[allow(clippy::struct_field_names)]
686struct AzureUsage {
687    prompt_tokens: u64,
688    #[serde(default)]
689    completion_tokens: Option<u64>,
690    #[allow(dead_code)]
691    total_tokens: u64,
692}
693
694// ============================================================================
695// Conversion Functions
696// ============================================================================
697
698fn convert_message_to_azure(message: &Message) -> Vec<AzureMessage> {
699    match message {
700        Message::User(user) => vec![AzureMessage {
701            role: "user".to_string(),
702            content: Some(convert_user_content(&user.content)),
703            tool_calls: None,
704            tool_call_id: None,
705        }],
706        Message::Custom(custom) => vec![AzureMessage {
707            role: "user".to_string(),
708            content: Some(AzureContent::Text(custom.content.clone())),
709            tool_calls: None,
710            tool_call_id: None,
711        }],
712        Message::Assistant(assistant) => {
713            let mut messages = Vec::new();
714
715            // Collect text content
716            let text: String = assistant
717                .content
718                .iter()
719                .filter_map(|b| match b {
720                    ContentBlock::Text(t) => Some(t.text.as_str()),
721                    _ => None,
722                })
723                .collect::<String>();
724
725            // Collect tool calls
726            let tool_calls: Vec<AzureToolCallRef> = assistant
727                .content
728                .iter()
729                .filter_map(|b| match b {
730                    ContentBlock::ToolCall(tc) => Some(AzureToolCallRef {
731                        id: tc.id.clone(),
732                        r#type: "function",
733                        function: AzureFunctionRef {
734                            name: tc.name.clone(),
735                            arguments: tc.arguments.to_string(),
736                        },
737                    }),
738                    _ => None,
739                })
740                .collect();
741
742            let content = if text.is_empty() {
743                None
744            } else {
745                Some(AzureContent::Text(text))
746            };
747
748            let tool_calls = if tool_calls.is_empty() {
749                None
750            } else {
751                Some(tool_calls)
752            };
753
754            messages.push(AzureMessage {
755                role: "assistant".to_string(),
756                content,
757                tool_calls,
758                tool_call_id: None,
759            });
760
761            messages
762        }
763        Message::ToolResult(result) => {
764            let parts: Vec<AzureContentPart> = result
765                .content
766                .iter()
767                .filter_map(|block| match block {
768                    ContentBlock::Text(t) => Some(AzureContentPart::Text {
769                        text: t.text.clone(),
770                    }),
771                    ContentBlock::Image(img) => {
772                        let url = format!("data:{};base64,{}", img.mime_type, img.data);
773                        Some(AzureContentPart::ImageUrl {
774                            image_url: AzureImageUrl { url },
775                        })
776                    }
777                    _ => None,
778                })
779                .collect();
780
781            let content = if parts.is_empty() {
782                None
783            } else if parts.len() == 1 && matches!(parts[0], AzureContentPart::Text { .. }) {
784                if let AzureContentPart::Text { text } = &parts[0] {
785                    Some(AzureContent::Text(text.clone()))
786                } else {
787                    Some(AzureContent::Parts(parts))
788                }
789            } else {
790                Some(AzureContent::Parts(parts))
791            };
792
793            vec![AzureMessage {
794                role: "tool".to_string(),
795                content,
796                tool_calls: None,
797                tool_call_id: Some(result.tool_call_id.clone()),
798            }]
799        }
800    }
801}
802
803fn convert_user_content(content: &UserContent) -> AzureContent {
804    match content {
805        UserContent::Text(text) => AzureContent::Text(text.clone()),
806        UserContent::Blocks(blocks) => {
807            let parts: Vec<AzureContentPart> = blocks
808                .iter()
809                .filter_map(|block| match block {
810                    ContentBlock::Text(t) => Some(AzureContentPart::Text {
811                        text: t.text.clone(),
812                    }),
813                    ContentBlock::Image(img) => {
814                        let url = format!("data:{};base64,{}", img.mime_type, img.data);
815                        Some(AzureContentPart::ImageUrl {
816                            image_url: AzureImageUrl { url },
817                        })
818                    }
819                    _ => None,
820                })
821                .collect();
822            AzureContent::Parts(parts)
823        }
824    }
825}
826
827fn convert_tool_to_azure(tool: &ToolDef) -> AzureTool {
828    AzureTool {
829        r#type: "function",
830        function: AzureFunction {
831            name: tool.name.clone(),
832            description: tool.description.clone(),
833            parameters: tool.parameters.clone(),
834        },
835    }
836}
837
838// ============================================================================
839// Tests
840// ============================================================================
841
// Unit tests for the Azure OpenAI provider: endpoint URL shaping, request
// construction, message conversion, and the SSE stream-state machine driven
// by in-memory byte streams and on-disk JSON fixtures.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{TextContent, ToolCall, UserMessage};
    use crate::provider::ToolDef;
    use asupersync::runtime::RuntimeBuilder;
    use futures::{StreamExt, stream};
    use serde::{Deserialize, Serialize};
    use serde_json::{Value, json};
    use std::path::PathBuf;

    // Provider identity strings ("azure" / "azure-openai") are part of the
    // public contract that callers key on.
    #[test]
    fn test_azure_provider_creation() {
        let provider = AzureOpenAIProvider::new("my-resource", "gpt-4");
        assert_eq!(provider.name(), "azure");
        assert_eq!(provider.api(), "azure-openai");
    }

    // Azure addresses models by deployment name, so model_id() echoes the
    // deployment passed to the constructor.
    #[test]
    fn test_azure_model_id_uses_deployment() {
        let provider = AzureOpenAIProvider::new("my-resource", "gpt-4o-mini");
        assert_eq!(provider.model_id(), "gpt-4o-mini");
    }

    // The generated URL must embed the resource host, the deployment, and an
    // api-version query parameter.
    #[test]
    fn test_azure_endpoint_url() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4-turbo");
        let url = provider.endpoint_url();
        assert!(url.contains("contoso.openai.azure.com"));
        assert!(url.contains("gpt-4-turbo"));
        assert!(url.contains("api-version="));
    }

    // with_api_version replaces the default API version in the URL.
    #[test]
    fn test_azure_endpoint_url_custom_version() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4").with_api_version("2024-06-01");
        let url = provider.endpoint_url();
        assert!(url.contains("api-version=2024-06-01"));
    }

    // Pins the exact default URL shape, including DEFAULT_API_VERSION.
    #[test]
    fn test_azure_endpoint_url_exact_default_shape() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4o");
        let url = provider.endpoint_url();
        assert_eq!(
            url,
            "https://contoso.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-02-15-preview"
        );
    }

    // An explicit with_endpoint_url override replaces the templated URL
    // entirely — the custom api-version set just before is ignored.
    #[test]
    fn test_azure_endpoint_url_override_takes_precedence() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4o")
            .with_api_version("2025-01-01")
            .with_endpoint_url("http://127.0.0.1:1234/mock-endpoint");
        let url = provider.endpoint_url();
        assert_eq!(url, "http://127.0.0.1:1234/mock-endpoint");
    }

    // End-to-end check of build_request: the system prompt becomes the first
    // message (role "system"), user/assistant messages follow in order,
    // streaming is always enabled, and tool definitions serialize as
    // "function" tools.
    #[test]
    fn test_azure_build_request_includes_system_messages_and_tools() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4o");
        let context = Context {
            system_prompt: Some("You are deterministic.".to_string().into()),
            messages: vec![
                Message::User(UserMessage {
                    content: UserContent::Text("Hello".to_string()),
                    timestamp: 0,
                }),
                Message::assistant(AssistantMessage {
                    content: vec![
                        ContentBlock::Text(TextContent::new("Need tool output")),
                        ContentBlock::ToolCall(ToolCall {
                            id: "tool_1".to_string(),
                            name: "echo".to_string(),
                            arguments: json!({"text":"ping"}),
                            thought_signature: None,
                        }),
                    ],
                    api: "azure-openai".to_string(),
                    provider: "azure".to_string(),
                    model: "gpt-4o".to_string(),
                    usage: Usage::default(),
                    stop_reason: StopReason::ToolUse,
                    error_message: None,
                    timestamp: 0,
                }),
            ]
            .into(),
            tools: vec![ToolDef {
                name: "echo".to_string(),
                description: "Echo text".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "text": {"type":"string"}
                    },
                    "required": ["text"]
                }),
            }]
            .into(),
        };
        let options = StreamOptions {
            max_tokens: Some(512),
            temperature: Some(0.0),
            ..Default::default()
        };

        let request = provider.build_request(&context, &options);
        let request_json = serde_json::to_value(&request).expect("serialize request");
        assert_eq!(request_json["max_tokens"], json!(512));
        assert_eq!(request_json["temperature"], json!(0.0));
        assert_eq!(request_json["stream"], json!(true));
        assert_eq!(request_json["messages"][0]["role"], json!("system"));
        assert_eq!(
            request_json["messages"][0]["content"],
            json!("You are deterministic.")
        );
        assert_eq!(request_json["messages"][1]["role"], json!("user"));
        assert_eq!(request_json["messages"][2]["role"], json!("assistant"));
        assert_eq!(request_json["tools"][0]["type"], json!("function"));
        assert_eq!(request_json["tools"][0]["function"]["name"], json!("echo"));
    }

    // When the caller supplies no max_tokens, the request falls back to
    // DEFAULT_MAX_TOKENS; an empty tool list omits the "tools" key entirely.
    #[test]
    fn test_azure_build_request_defaults_max_tokens() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4o");
        let context = Context {
            system_prompt: None,
            messages: vec![Message::User(UserMessage {
                content: UserContent::Text("Hello".to_string()),
                timestamp: 0,
            })]
            .into(),
            tools: Vec::new().into(),
        };
        let options = StreamOptions::default();

        let request = provider.build_request(&context, &options);
        let request_json = serde_json::to_value(&request).expect("serialize request");
        assert_eq!(request_json["max_tokens"], json!(DEFAULT_MAX_TOKENS));
        assert_eq!(request_json["stream"], json!(true));
        assert!(request_json.get("tools").is_none());
    }

    // A compat system_role_name that is a known role modulo case/whitespace
    // ("SYSTEM ") is normalized to the canonical "system" (see normalize_role).
    #[test]
    fn test_azure_build_request_normalizes_known_system_role_name() {
        let provider =
            AzureOpenAIProvider::new("contoso", "gpt-4o").with_compat(Some(CompatConfig {
                system_role_name: Some("SYSTEM ".to_string()),
                ..CompatConfig::default()
            }));
        let context = Context {
            system_prompt: Some("You are deterministic.".to_string().into()),
            messages: Vec::new().into(),
            tools: Vec::new().into(),
        };

        let request = provider.build_request(&context, &StreamOptions::default());
        let request_json = serde_json::to_value(&request).expect("serialize request");
        assert_eq!(request_json["messages"][0]["role"], json!("system"));
    }

    // An unrecognized compat role name is passed through verbatim so
    // deployments with custom role conventions keep working.
    #[test]
    fn test_azure_build_request_preserves_unknown_system_role_name() {
        let provider =
            AzureOpenAIProvider::new("contoso", "gpt-4o").with_compat(Some(CompatConfig {
                system_role_name: Some("custom_role".to_string()),
                ..CompatConfig::default()
            }));
        let context = Context {
            system_prompt: Some("You are deterministic.".to_string().into()),
            messages: Vec::new().into(),
            tools: Vec::new().into(),
        };

        let request = provider.build_request(&context, &StreamOptions::default());
        let request_json = serde_json::to_value(&request).expect("serialize request");
        assert_eq!(request_json["messages"][0]["role"], json!("custom_role"));
    }

    // A plain-text user message converts to exactly one Azure message with
    // role "user".
    #[test]
    fn test_azure_message_conversion() {
        let message = Message::User(UserMessage {
            content: UserContent::Text("Hello".to_string()),
            timestamp: chrono::Utc::now().timestamp_millis(),
        });

        let azure_messages = convert_message_to_azure(&message);
        assert_eq!(azure_messages.len(), 1);
        assert_eq!(azure_messages[0].role, "user");
    }

    // Shape of tests/fixtures/provider_responses/azure_stream.json: a list of
    // cases, each pairing raw SSE event payloads with expected summaries.
    #[derive(Debug, Deserialize)]
    struct ProviderFixture {
        cases: Vec<ProviderCase>,
    }

    #[derive(Debug, Deserialize)]
    struct ProviderCase {
        // Case label used in assertion failure messages.
        name: String,
        // Raw JSON payloads fed to the stream (a JSON string is sent verbatim,
        // which is how "[DONE]" sentinels are encoded).
        events: Vec<Value>,
        expected: Vec<EventSummary>,
    }

    // Flattened, comparison-friendly projection of a StreamEvent; optional
    // fields default to None so fixtures can omit them.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct EventSummary {
        kind: String,
        #[serde(default)]
        content_index: Option<usize>,
        #[serde(default)]
        delta: Option<String>,
        #[serde(default)]
        content: Option<String>,
        #[serde(default)]
        reason: Option<String>,
    }

    // Replays every fixture case through the stream-state machine and compares
    // the summarized output events against the expected list.
    #[test]
    fn test_stream_fixtures() {
        let fixture = load_fixture("azure_stream.json");
        for case in fixture.cases {
            let events = collect_events(&case.events);
            let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
            assert_eq!(summaries, case.expected, "case {}", case.name);
        }
    }

    // Regression test: a tool-call delta whose "index" skips ahead (here 3,
    // with no entries 0-2) must not panic and must still surface the call in
    // the final assembled message.
    #[test]
    fn test_stream_handles_sparse_tool_call_index_without_panic() {
        let events = vec![
            json!({ "choices": [{ "delta": {} }] }),
            json!({
                "choices": [{
                    "delta": {
                        "tool_calls": [{
                            "index": 3,
                            "id": "call_sparse",
                            "function": {
                                "name": "lookup",
                                "arguments": "{\"q\":\"azure\"}"
                            }
                        }]
                    }
                }]
            }),
            json!({ "choices": [{ "delta": {}, "finish_reason": "tool_calls" }] }),
            // OpenAI-style terminal sentinel; sent as a raw string frame.
            Value::String("[DONE]".to_string()),
        ];

        let out = collect_events(&events);
        let done = out
            .iter()
            .find_map(|event| match event {
                StreamEvent::Done { message, .. } => Some(message),
                _ => None,
            })
            .expect("done event");
        let tool_calls: Vec<&ToolCall> = done
            .content
            .iter()
            .filter_map(|block| match block {
                ContentBlock::ToolCall(tc) => Some(tc),
                _ => None,
            })
            .collect();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_sparse");
        assert_eq!(tool_calls[0].name, "lookup");
        assert_eq!(tool_calls[0].arguments, json!({ "q": "azure" }));
        assert!(
            out.iter()
                .any(|event| matches!(event, StreamEvent::ToolCallStart { .. })),
            "expected tool call start event"
        );
    }

    // Reads and parses a fixture file relative to the crate root; panics on
    // missing or malformed fixtures (acceptable in tests).
    fn load_fixture(file_name: &str) -> ProviderFixture {
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/provider_responses")
            .join(file_name);
        let raw = std::fs::read_to_string(path).expect("fixture read");
        serde_json::from_str(&raw).expect("fixture parse")
    }

    // Drives StreamState over a synthetic SSE byte stream built from `events`
    // and collects every emitted StreamEvent. NOTE(review): the "[DONE]"
    // branch below re-implements the terminal handling of the production
    // stream loop — confirm they stay in sync if that loop changes.
    fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        runtime.block_on(async move {
            // Frame each event as one SSE "data: ...\n\n" chunk; a JSON string
            // value is passed through verbatim (used for "[DONE]").
            let byte_stream = stream::iter(
                events
                    .iter()
                    .map(|event| {
                        let data = match event {
                            Value::String(text) => text.clone(),
                            _ => serde_json::to_string(event).expect("serialize event"),
                        };
                        format!("data: {data}\n\n").into_bytes()
                    })
                    .map(Ok),
            );
            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
            let mut state = StreamState::new(
                event_source,
                "gpt-test".to_string(),
                "azure-openai".to_string(),
                "azure".to_string(),
            );
            let mut out = Vec::new();

            while let Some(item) = state.event_source.next().await {
                let msg = item.expect("SSE event");
                if msg.data == "[DONE]" {
                    // Flush anything still pending, then synthesize the Done
                    // event from the accumulated partial message.
                    out.extend(state.pending_events.drain(..));
                    let reason = state.partial.stop_reason;
                    out.push(StreamEvent::Done {
                        reason,
                        message: std::mem::take(&mut state.partial),
                    });
                    break;
                }
                state.process_event(&msg.data).expect("process_event");
                out.extend(state.pending_events.drain(..));
            }

            out
        })
    }

    // Projects a StreamEvent into an EventSummary; variants the fixtures do
    // not distinguish collapse to kind "other".
    fn summarize_event(event: &StreamEvent) -> EventSummary {
        match event {
            StreamEvent::Start { .. } => EventSummary {
                kind: "start".to_string(),
                content_index: None,
                delta: None,
                content: None,
                reason: None,
            },
            StreamEvent::TextDelta {
                content_index,
                delta,
                ..
            } => EventSummary {
                kind: "text_delta".to_string(),
                content_index: Some(*content_index),
                delta: Some(delta.clone()),
                content: None,
                reason: None,
            },
            StreamEvent::Done { reason, .. } => EventSummary {
                kind: "done".to_string(),
                content_index: None,
                delta: None,
                content: None,
                reason: Some(reason_to_string(*reason)),
            },
            StreamEvent::Error { reason, .. } => EventSummary {
                kind: "error".to_string(),
                content_index: None,
                delta: None,
                content: None,
                reason: Some(reason_to_string(*reason)),
            },
            StreamEvent::TextStart { content_index, .. } => EventSummary {
                kind: "text_start".to_string(),
                content_index: Some(*content_index),
                delta: None,
                content: None,
                reason: None,
            },
            StreamEvent::TextEnd {
                content_index,
                content,
                ..
            } => EventSummary {
                kind: "text_end".to_string(),
                content_index: Some(*content_index),
                delta: None,
                content: Some(content.clone()),
                reason: None,
            },
            _ => EventSummary {
                kind: "other".to_string(),
                content_index: None,
                delta: None,
                content: None,
                reason: None,
            },
        }
    }

    // Stable string names for StopReason, matching the fixtures' "reason"
    // field values.
    fn reason_to_string(reason: StopReason) -> String {
        match reason {
            StopReason::Stop => "stop",
            StopReason::Length => "length",
            StopReason::ToolUse => "tool_use",
            StopReason::Error => "error",
            StopReason::Aborted => "aborted",
        }
        .to_string()
    }
}
1245
1246// ============================================================================
1247// Fuzzing support
1248// ============================================================================
1249
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// Stream type used to satisfy `StreamState`'s generic parameter; the fuzz
    /// harness never reads from it, so an empty stream suffices.
    type FuzzStream =
        Pin<Box<futures::stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;

    /// Opaque wrapper around the Azure OpenAI stream processor state.
    pub struct Processor(StreamState<FuzzStream>);

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Processor {
        /// Create a fresh processor with default state.
        pub fn new() -> Self {
            // Back the SSE source with a permanently-empty byte stream;
            // payloads are injected directly via `process_event` instead.
            let source: FuzzStream = Box::pin(stream::empty());
            let inner = StreamState::new(
                crate::sse::SseStream::new(source),
                String::from("azure-fuzz"),
                String::from("azure-openai"),
                String::from("azure"),
            );
            Self(inner)
        }

        /// Feed one SSE data payload and return any emitted `StreamEvent`s.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            self.0.process_event(data)?;
            let emitted: Vec<StreamEvent> = self.0.pending_events.drain(..).collect();
            Ok(emitted)
        }
    }
}