Skip to main content

pi/providers/
azure.rs

1//! Azure OpenAI Chat Completions API provider implementation.
2//!
3//! This module implements the Provider trait for Azure OpenAI, using the same
4//! streaming protocol as OpenAI but with Azure-specific authentication and endpoints.
5//!
6//! Azure OpenAI URL format:
7//! `https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}`
8
9use crate::error::{Error, Result};
10use crate::http::client::Client;
11use crate::model::{
12    AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, Usage, UserContent,
13};
14use crate::models::CompatConfig;
15use crate::provider::{Context, Provider, StreamOptions, ToolDef};
16use crate::sse::SseStream;
17use async_trait::async_trait;
18use futures::StreamExt;
19use futures::stream::{self, Stream};
20use serde::{Deserialize, Serialize};
21use std::collections::VecDeque;
22use std::pin::Pin;
23
24// ============================================================================
25// Constants
26// ============================================================================
27
/// Default Azure OpenAI API version.  Override via `PI_AZURE_API_VERSION`.
pub(crate) const DEFAULT_API_VERSION: &str = "2024-12-01-preview";
/// Fallback `max_tokens` applied when the caller does not specify one.
const DEFAULT_MAX_TOKENS: u32 = 4096;

/// Resolve the API version: a non-empty `PI_AZURE_API_VERSION` environment
/// variable wins, otherwise [`DEFAULT_API_VERSION`] is used.
pub(crate) fn azure_api_version() -> String {
    match std::env::var("PI_AZURE_API_VERSION") {
        Ok(version) if !version.is_empty() => version,
        _ => DEFAULT_API_VERSION.to_string(),
    }
}
38
/// Normalize Azure role names while preserving unknown compat overrides as-is.
///
/// Known roles are returned in lowercase (after trimming); anything else is
/// returned trimmed but otherwise untouched, so custom compat role names
/// survive verbatim.
fn normalize_role(role: &str) -> String {
    const KNOWN_ROLES: [&str; 6] = [
        "system",
        "developer",
        "user",
        "assistant",
        "tool",
        "function",
    ];

    let trimmed = role.trim();
    let lowered = trimmed.to_ascii_lowercase();
    // A known role matched case-insensitively normalizes to its lowercase
    // form; the exact-match arm of the original is subsumed because all
    // known role names are already lowercase.
    if KNOWN_ROLES.contains(&lowered.as_str()) {
        lowered
    } else {
        trimmed.to_string()
    }
}
53
54fn authorization_override(
55    options: &StreamOptions,
56    compat: Option<&CompatConfig>,
57) -> Option<String> {
58    super::first_non_empty_header_value_case_insensitive(&options.headers, &["authorization"])
59        .or_else(|| {
60            compat
61                .and_then(|compat| compat.custom_headers.as_ref())
62                .and_then(|headers| {
63                    super::first_non_empty_header_value_case_insensitive(
64                        headers,
65                        &["authorization"],
66                    )
67                })
68        })
69}
70
71fn api_key_override(options: &StreamOptions, compat: Option<&CompatConfig>) -> Option<String> {
72    super::first_non_empty_header_value_case_insensitive(&options.headers, &["api-key"]).or_else(
73        || {
74            compat
75                .and_then(|compat| compat.custom_headers.as_ref())
76                .and_then(|headers| {
77                    super::first_non_empty_header_value_case_insensitive(headers, &["api-key"])
78                })
79        },
80    )
81}
82
83// ============================================================================
84// Azure OpenAI Provider
85// ============================================================================
86
/// Azure OpenAI Chat Completions API provider.
///
/// Speaks the OpenAI-compatible streaming protocol against an Azure-hosted
/// deployment, authenticating via the Azure `api-key` header.
pub struct AzureOpenAIProvider {
    /// HTTP client for all requests (replaceable for tests via `with_client`).
    client: Client,
    /// Provider name for event reporting (defaults to "azure").
    provider: String,
    /// The deployment name (model deployment in Azure)
    deployment: String,
    /// Azure resource name (part of the URL)
    resource: String,
    /// API version string
    api_version: String,
    /// Optional override for the full endpoint URL (primarily for deterministic tests).
    endpoint_url_override: Option<String>,
    /// Provider-specific compatibility overrides (custom headers, role names).
    compat: Option<CompatConfig>,
}
102
103impl AzureOpenAIProvider {
104    /// Create a new Azure OpenAI provider.
105    ///
106    /// # Arguments
107    /// * `resource` - Azure OpenAI resource name
108    /// * `deployment` - Model deployment name
109    pub fn new(resource: impl Into<String>, deployment: impl Into<String>) -> Self {
110        Self {
111            client: Client::new(),
112            provider: "azure".to_string(),
113            deployment: deployment.into(),
114            resource: resource.into(),
115            api_version: azure_api_version(),
116            endpoint_url_override: None,
117            compat: None,
118        }
119    }
120
121    /// Set the provider name for event reporting.
122    #[must_use]
123    pub fn with_provider_name(mut self, provider: impl Into<String>) -> Self {
124        self.provider = provider.into();
125        self
126    }
127
128    /// Set the API version.
129    #[must_use]
130    pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
131        self.api_version = version.into();
132        self
133    }
134
135    /// Override the full endpoint URL.
136    ///
137    /// This is intended for deterministic, offline tests (e.g. mock servers). Production
138    /// code should rely on the standard Azure endpoint URL format.
139    #[must_use]
140    pub fn with_endpoint_url(mut self, endpoint_url: impl Into<String>) -> Self {
141        self.endpoint_url_override = Some(endpoint_url.into());
142        self
143    }
144
145    /// Create with a custom HTTP client (VCR, test harness, etc.).
146    #[must_use]
147    pub fn with_client(mut self, client: Client) -> Self {
148        self.client = client;
149        self
150    }
151
152    /// Attach provider-specific compatibility overrides.
153    #[must_use]
154    pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
155        self.compat = compat;
156        self
157    }
158
159    /// Get the full endpoint URL.
160    fn endpoint_url(&self) -> String {
161        if let Some(url) = &self.endpoint_url_override {
162            return url.clone();
163        }
164        format!(
165            "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
166            self.resource, self.deployment, self.api_version
167        )
168    }
169
170    /// Build the request body for Azure OpenAI (same format as OpenAI).
171    #[allow(clippy::unused_self)]
172    pub fn build_request(&self, context: &Context<'_>, options: &StreamOptions) -> AzureRequest {
173        let messages = self.build_messages(context);
174
175        let tools: Option<Vec<AzureTool>> = if context.tools.is_empty() {
176            None
177        } else {
178            Some(context.tools.iter().map(convert_tool_to_azure).collect())
179        };
180
181        AzureRequest {
182            messages,
183            max_tokens: options.max_tokens.or(Some(DEFAULT_MAX_TOKENS)),
184            temperature: options.temperature,
185            tools,
186            stream: true,
187            stream_options: Some(AzureStreamOptions {
188                include_usage: true,
189            }),
190        }
191    }
192
193    /// Build the messages array with system prompt prepended.
194    fn build_messages(&self, context: &Context<'_>) -> Vec<AzureMessage> {
195        let mut messages = Vec::new();
196        let system_role = self
197            .compat
198            .as_ref()
199            .and_then(|c| c.system_role_name.as_deref())
200            .unwrap_or("system");
201
202        // Add system prompt as first message
203        if let Some(system) = &context.system_prompt {
204            messages.push(AzureMessage {
205                role: normalize_role(system_role),
206                content: Some(AzureContent::Text(system.to_string())),
207                tool_calls: None,
208                tool_call_id: None,
209            });
210        }
211
212        // Convert conversation messages
213        for message in context.messages.iter() {
214            messages.extend(convert_message_to_azure(message));
215        }
216
217        messages
218    }
219}
220
#[async_trait]
#[allow(clippy::too_many_lines)]
impl Provider for AzureOpenAIProvider {
    /// Provider name used in event reporting (configurable; defaults to "azure").
    fn name(&self) -> &str {
        &self.provider
    }

    /// Wire-protocol identifier for this implementation.
    fn api(&self) -> &'static str {
        "azure-openai"
    }

    /// The Azure deployment name doubles as the model id.
    fn model_id(&self) -> &str {
        &self.deployment
    }

    /// Issue a streaming chat-completions request and adapt the SSE response
    /// into a stream of [`StreamEvent`]s.
    ///
    /// Authentication: if the caller or compat config supplies an explicit
    /// `authorization`/`api-key` header, that value is used verbatim (applied
    /// below via the header-merge helpers) and no key is derived here.
    /// Otherwise the key comes from `options.api_key` or the
    /// `AZURE_OPENAI_API_KEY` env var, and a missing key is an error.
    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        let has_auth_override = api_key_override(options, self.compat.as_ref()).is_some()
            || authorization_override(options, self.compat.as_ref()).is_some();
        let auth_value = if has_auth_override {
            None
        } else {
            Some(
                options
                    .api_key
                    .clone()
                    .or_else(|| std::env::var("AZURE_OPENAI_API_KEY").ok())
                    .ok_or_else(|| Error::provider("azure-openai", "Missing API key for provider. Configure credentials with /login <provider> or set the provider's API key env var."))?,
            )
        };

        let request_body = self.build_request(context, options);

        let endpoint_url = self.endpoint_url();

        // Build request with Azure-specific headers (Content-Type set by .json() below)
        let mut request = self
            .client
            .post(&endpoint_url)
            .header("Accept", "text/event-stream");

        if let Some(auth_value) = auth_value {
            request = request.header("api-key", &auth_value); // Azure uses api-key header, not Authorization
        }

        // Apply provider-specific custom headers from compat config.
        // Blank authorization/api-key values are ignored so they cannot
        // accidentally wipe out the credential set above.
        if let Some(compat) = &self.compat {
            if let Some(custom_headers) = &compat.custom_headers {
                request = super::apply_headers_ignoring_blank_auth_overrides(
                    request,
                    custom_headers,
                    &["authorization", "api-key"],
                );
            }
        }

        // Per-request headers are applied last so they win over compat headers.
        request = super::apply_headers_ignoring_blank_auth_overrides(
            request,
            &options.headers,
            &["authorization", "api-key"],
        );

        let request = request.json(&request_body)?;

        let response = Box::pin(request.send()).await?;
        let status = response.status();
        // Non-2xx: surface the response body in the error for diagnosis.
        if !(200..300).contains(&status) {
            let body = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
            return Err(Error::provider(
                "azure-openai",
                format!("Azure OpenAI API error (HTTP {status}): {body}"),
            ));
        }

        // Create SSE stream for streaming responses.
        let event_source = SseStream::new(response.bytes_stream());

        // Create stream state
        let model = self.deployment.clone();
        let api = self.api().to_string();
        let provider = self.name().to_string();

        // Pull SSE messages through StreamState, which buffers derived events
        // in pending_events; drain that queue before reading the next message.
        let stream = stream::unfold(
            StreamState::new(event_source, model, api, provider),
            |mut state| async move {
                if state.done {
                    return None;
                }
                loop {
                    if let Some(event) = state.pending_events.pop_front() {
                        return Some((Ok(event), state));
                    }

                    match state.event_source.next().await {
                        Some(Ok(msg)) => {
                            state.transient_error_count = 0;
                            // Azure also sends "[DONE]" as final message
                            if msg.data == "[DONE]" {
                                state.done = true;
                                let reason = state.partial.stop_reason;
                                let message = std::mem::take(&mut state.partial);
                                return Some((Ok(StreamEvent::Done { reason, message }), state));
                            }

                            if let Err(e) = state.process_event(&msg.data) {
                                state.done = true;
                                return Some((Err(e), state));
                            }
                        }
                        Some(Err(e)) => {
                            // WriteZero, WouldBlock, and TimedOut errors are treated as transient.
                            // Skip them and keep reading the stream, but cap
                            // consecutive occurrences to avoid infinite loops.
                            const MAX_CONSECUTIVE_TRANSIENT_ERRORS: usize = 5;
                            if e.kind() == std::io::ErrorKind::WriteZero
                                || e.kind() == std::io::ErrorKind::WouldBlock
                                || e.kind() == std::io::ErrorKind::TimedOut
                            {
                                state.transient_error_count += 1;
                                if state.transient_error_count <= MAX_CONSECUTIVE_TRANSIENT_ERRORS {
                                    tracing::warn!(
                                        kind = ?e.kind(),
                                        count = state.transient_error_count,
                                        "Transient error in SSE stream, continuing"
                                    );
                                    continue;
                                }
                                tracing::warn!(
                                    kind = ?e.kind(),
                                    "Error persisted after {MAX_CONSECUTIVE_TRANSIENT_ERRORS} \
                                     consecutive attempts, treating as fatal"
                                );
                            }
                            state.done = true;
                            let err = Error::api(format!("SSE error: {e}"));
                            return Some((Err(err), state));
                        }
                        // Stream ended without [DONE] sentinel (e.g.
                        // premature server disconnect).  Emit Done so the
                        // agent loop receives the accumulated partial
                        // instead of silently losing it.
                        None => {
                            state.done = true;
                            let reason = state.partial.stop_reason;
                            let message = std::mem::take(&mut state.partial);
                            return Some((Ok(StreamEvent::Done { reason, message }), state));
                        }
                    }
                }
            },
        );

        Ok(Box::pin(stream))
    }
}
382
383// ============================================================================
384// Stream State
385// ============================================================================
386
/// Mutable accumulator for one streaming response: the SSE source plus the
/// partial assistant message and the queue of events derived so far.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Underlying SSE message source.
    event_source: SseStream<S>,
    /// Assistant message accumulated from deltas; taken on Done.
    partial: AssistantMessage,
    /// In-flight tool calls, matched by the provider's logical index.
    tool_calls: Vec<ToolCallState>,
    /// Events produced by `process_event`, drained before reading more SSE.
    pending_events: VecDeque<StreamEvent>,
    /// Whether StreamEvent::Start has been emitted yet.
    started: bool,
    /// Set once a terminal event/error has been yielded; stream then ends.
    done: bool,
    /// Consecutive transient errors (WriteZero/WouldBlock/TimedOut) seen
    /// without a successful event in between.
    transient_error_count: usize,
}
400
/// Accumulation state for one streamed tool call.
struct ToolCallState {
    /// Logical tool-call index as reported by the provider (may be sparse).
    index: usize,
    /// Position of the corresponding ToolCall block in `partial.content`.
    content_index: usize,
    /// Tool call id, concatenated from id fragments.
    id: String,
    /// Function name, concatenated from name fragments.
    name: String,
    /// Raw JSON argument text, concatenated from fragments; parsed on finish.
    arguments: String,
}
408
impl<S> StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Create a fresh state with an empty partial message attributed to the
    /// given model/api/provider and stamped with the current time.
    fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
        Self {
            event_source,
            partial: AssistantMessage {
                content: Vec::new(),
                api,
                provider,
                model,
                usage: Usage::default(),
                stop_reason: StopReason::Stop,
                error_message: None,
                timestamp: chrono::Utc::now().timestamp_millis(),
            },
            tool_calls: Vec::new(),
            pending_events: VecDeque::new(),
            started: false,
            done: false,
            transient_error_count: 0,
        }
    }

    /// Parse each tool call's accumulated argument text as JSON and write it
    /// into the corresponding ToolCall block. Unparseable arguments are
    /// logged and replaced with JSON null rather than failing the stream.
    fn finalize_tool_call_arguments(&mut self) {
        for tc in &self.tool_calls {
            let arguments: serde_json::Value = match serde_json::from_str(&tc.arguments) {
                Ok(args) => args,
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        raw = %tc.arguments,
                        "Failed to parse tool arguments as JSON"
                    );
                    serde_json::Value::Null
                }
            };

            if let Some(ContentBlock::ToolCall(block)) =
                self.partial.content.get_mut(tc.content_index)
            {
                block.arguments = arguments;
            }
        }
    }

    /// Append a text delta to the current text block, opening a new block
    /// (and queueing TextStart) if the last block is not text. Returns the
    /// TextDelta event for the caller to queue.
    fn push_text_delta(&mut self, text: String) -> StreamEvent {
        let last_is_text = matches!(self.partial.content.last(), Some(ContentBlock::Text(_)));
        if !last_is_text {
            let content_index = self.partial.content.len();
            self.partial
                .content
                .push(ContentBlock::Text(crate::model::TextContent::new("")));
            self.pending_events
                .push_back(StreamEvent::TextStart { content_index });
        }
        // After the push above, the last block is guaranteed to be text.
        let content_index = self.partial.content.len() - 1;

        if let Some(ContentBlock::Text(t)) = self.partial.content.get_mut(content_index) {
            t.text.push_str(&text);
        }

        StreamEvent::TextDelta {
            content_index,
            delta: text,
        }
    }

    /// Queue StreamEvent::Start exactly once, before any content events.
    fn ensure_started(&mut self) {
        if !self.started {
            self.started = true;
            self.pending_events.push_back(StreamEvent::Start {
                partial: self.partial.clone(),
            });
        }
    }

    /// Parse one SSE data payload and translate it into queued StreamEvents,
    /// updating the partial message (text, tool calls, usage, stop reason).
    ///
    /// # Errors
    /// Returns an error if the payload is not valid JSON for the expected
    /// chunk shape; the caller terminates the stream on error.
    #[allow(clippy::unnecessary_wraps, clippy::too_many_lines)]
    fn process_event(&mut self, data: &str) -> Result<()> {
        let chunk: AzureStreamChunk = serde_json::from_str(data)
            .map_err(|e| Error::api(format!("JSON parse error: {e}\nData: {data}")))?;

        // Process usage if present
        if let Some(usage) = chunk.usage {
            self.partial.usage.input = usage.prompt_tokens;
            self.partial.usage.output = usage.completion_tokens.unwrap_or(0);
            self.partial.usage.total_tokens = usage.total_tokens;
        }

        // Before Start is emitted, a chunk whose first choice carries no
        // content, tool calls, or finish reason (e.g. a role-only prelude)
        // still triggers Start but produces no content events.
        let choices = chunk.choices;
        if !self.started {
            let first = choices.first();
            let delta_is_empty = first.is_some_and(|choice| {
                choice.finish_reason.is_none()
                    && choice.delta.content.is_none()
                    && choice.delta.tool_calls.is_none()
            });
            if delta_is_empty {
                self.ensure_started();
                return Ok(());
            }
        }

        // Process choices — handle content deltas BEFORE finish_reason so that
        // TextEnd/ToolCallEnd events always follow the final delta (matching the
        // OpenAI provider event ordering contract).
        for choice in choices {
            // Handle text content
            if let Some(text) = choice.delta.content {
                self.ensure_started();
                let event = self.push_text_delta(text);
                self.pending_events.push_back(event);
            }

            // Handle tool calls
            if let Some(tool_calls) = choice.delta.tool_calls {
                self.ensure_started();

                for tc in tool_calls {
                    let idx = tc.index as usize;

                    // Azure may emit sparse tool-call indices. Match by logical index
                    // instead of assuming contiguous 0..N ordering in arrival order.
                    let tool_state_idx = if let Some(existing_idx) =
                        self.tool_calls.iter().position(|tc| tc.index == idx)
                    {
                        existing_idx
                    } else {
                        let content_index = self.partial.content.len();
                        self.tool_calls.push(ToolCallState {
                            index: idx,
                            content_index,
                            id: String::new(),
                            name: String::new(),
                            arguments: String::new(),
                        });

                        // Initialize block in partial
                        self.partial
                            .content
                            .push(ContentBlock::ToolCall(crate::model::ToolCall {
                                id: String::new(),
                                name: String::new(),
                                arguments: serde_json::Value::Null,
                                thought_signature: None,
                            }));

                        // Emit ToolCallStart
                        self.pending_events
                            .push_back(StreamEvent::ToolCallStart { content_index });
                        self.tool_calls.len() - 1
                    };

                    let tc_state = &mut self.tool_calls[tool_state_idx];
                    let content_index = tc_state.content_index;

                    // Update the tool call state; id and name fragments are
                    // concatenated and mirrored into the partial's block.
                    if let Some(id) = tc.id {
                        tc_state.id.push_str(&id);
                        if let Some(ContentBlock::ToolCall(block)) =
                            self.partial.content.get_mut(content_index)
                        {
                            block.id.clone_from(&tc_state.id);
                        }
                    }
                    if let Some(func) = tc.function {
                        if let Some(name) = func.name {
                            tc_state.name.push_str(&name);
                            if let Some(ContentBlock::ToolCall(block)) =
                                self.partial.content.get_mut(content_index)
                            {
                                block.name.clone_from(&tc_state.name);
                            }
                        }
                        if let Some(args) = func.arguments {
                            tc_state.arguments.push_str(&args);
                            // Note: we don't update partial arguments here as they need to be valid JSON.
                            // We do that at the end.

                            self.pending_events.push_back(StreamEvent::ToolCallDelta {
                                content_index,
                                delta: args,
                            });
                        }
                    }
                }
            }

            // Handle finish reason (MUST come after delta processing so TextEnd/ToolCallEnd
            // events contain the complete accumulated content).
            // Ensure Start is emitted even when finish arrives in an empty-delta chunk.
            if choice.finish_reason.is_some() {
                self.ensure_started();
            }
            if let Some(reason) = choice.finish_reason {
                self.partial.stop_reason = match reason.as_str() {
                    "length" => StopReason::Length,
                    "content_filter" => StopReason::Error,
                    "tool_calls" => StopReason::ToolUse,
                    // "stop" and any other reason treated as normal stop
                    _ => StopReason::Stop,
                };

                // Finalize tool call arguments
                self.finalize_tool_call_arguments();

                // Emit TextEnd/ThinkingEnd for all open text/thinking blocks.
                for (content_index, block) in self.partial.content.iter().enumerate() {
                    if let ContentBlock::Text(t) = block {
                        self.pending_events.push_back(StreamEvent::TextEnd {
                            content_index,
                            content: t.text.clone(),
                        });
                    } else if let ContentBlock::Thinking(t) = block {
                        self.pending_events.push_back(StreamEvent::ThinkingEnd {
                            content_index,
                            content: t.thinking.clone(),
                        });
                    }
                }

                // Emit ToolCallEnd for each accumulated tool call
                for tc in &self.tool_calls {
                    if let Some(ContentBlock::ToolCall(tool_call)) =
                        self.partial.content.get(tc.content_index)
                    {
                        self.pending_events.push_back(StreamEvent::ToolCallEnd {
                            content_index: tc.content_index,
                            tool_call: tool_call.clone(),
                        });
                    }
                }
            }
        }

        Ok(())
    }
}
648
649// ============================================================================
650// Request Types
651// ============================================================================
652
/// Top-level Chat Completions request body (OpenAI-compatible shape).
#[derive(Debug, Serialize)]
pub struct AzureRequest {
    /// Conversation messages; system prompt (if any) comes first.
    messages: Vec<AzureMessage>,
    /// Maximum tokens to generate; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Sampling temperature; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Tool definitions; omitted entirely when the context has no tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<AzureTool>>,
    /// Always `true` here: this provider only issues streaming requests.
    stream: bool,
    /// Extra streaming options (used to request usage in the final chunk).
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<AzureStreamOptions>,
}

/// Streaming options; `include_usage` asks the server to append a
/// usage-bearing chunk at the end of the stream.
#[derive(Debug, Serialize)]
struct AzureStreamOptions {
    include_usage: bool,
}

/// One message in the Chat Completions `messages` array.
#[derive(Debug, Serialize)]
struct AzureMessage {
    /// Role name ("system", "user", "assistant", "tool", ...).
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<AzureContent>,
    /// Tool calls carried by an assistant message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<AzureToolCallRef>>,
    /// Id linking a "tool" role message back to the call it answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

/// Message content: either a plain string or a multi-part array
/// (serialized untagged, matching the OpenAI wire format).
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum AzureContent {
    Text(String),
    Parts(Vec<AzureContentPart>),
}

/// One element of a multi-part content array (text or image).
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
enum AzureContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: AzureImageUrl },
}

/// Image reference; `url` may be a data: URI with base64 content.
#[derive(Debug, Serialize)]
struct AzureImageUrl {
    url: String,
}

/// A completed tool call echoed back on an assistant message.
#[derive(Debug, Serialize)]
struct AzureToolCallRef {
    id: String,
    /// Always "function" in this provider.
    r#type: &'static str,
    function: AzureFunctionRef,
}

/// Function name plus its arguments serialized as a JSON string.
#[derive(Debug, Serialize)]
struct AzureFunctionRef {
    name: String,
    arguments: String,
}

/// A tool definition offered to the model.
#[derive(Debug, Serialize)]
struct AzureTool {
    /// Always "function" in this provider.
    r#type: &'static str,
    function: AzureFunction,
}

/// Function schema: name, description, and JSON-schema parameters.
#[derive(Debug, Serialize)]
struct AzureFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
729
730// ============================================================================
731// Streaming Response Types
732// ============================================================================
733
/// One SSE chunk of a streaming Chat Completions response.
#[derive(Debug, Deserialize)]
struct AzureStreamChunk {
    /// Incremental choices; may be empty (e.g. on a usage-only chunk).
    #[serde(default)]
    choices: Vec<AzureChoice>,
    /// Token usage; present on the final chunk when `include_usage` was set.
    #[serde(default)]
    usage: Option<AzureUsage>,
}

/// A single streamed choice: a delta plus an optional terminal finish reason.
#[derive(Debug, Deserialize)]
struct AzureChoice {
    delta: AzureDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental message content within one chunk.
#[derive(Debug, Deserialize)]
struct AzureDelta {
    /// Text fragment to append to the current text block.
    #[serde(default)]
    content: Option<String>,
    /// Tool-call fragments, each addressed by its logical `index`.
    #[serde(default)]
    tool_calls: Option<Vec<AzureToolCallDelta>>,
}

/// A fragment of one streamed tool call.
#[derive(Debug, Deserialize)]
struct AzureToolCallDelta {
    /// Logical tool-call index; may be sparse/non-contiguous across chunks.
    index: u32,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<AzureFunctionDelta>,
}

/// A fragment of a tool call's function name and/or argument text.
#[derive(Debug, Deserialize)]
struct AzureFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    /// Argument text fragments; concatenated they form one JSON document.
    #[serde(default)]
    arguments: Option<String>,
}

/// Token usage reported by the server.
#[derive(Debug, Deserialize)]
#[allow(clippy::struct_field_names)]
struct AzureUsage {
    prompt_tokens: u64,
    #[serde(default)]
    completion_tokens: Option<u64>,
    // NOTE(review): total_tokens is read in StreamState::process_event, so
    // this allow looks stale — confirm before removing.
    #[allow(dead_code)]
    total_tokens: u64,
}
783
784// ============================================================================
785// Conversion Functions
786// ============================================================================
787
788#[allow(clippy::too_many_lines)]
789fn convert_message_to_azure(message: &Message) -> Vec<AzureMessage> {
790    match message {
791        Message::User(user) => vec![AzureMessage {
792            role: "user".to_string(),
793            content: Some(convert_user_content(&user.content)),
794            tool_calls: None,
795            tool_call_id: None,
796        }],
797        Message::Custom(custom) => vec![AzureMessage {
798            role: "user".to_string(),
799            content: Some(AzureContent::Text(custom.content.clone())),
800            tool_calls: None,
801            tool_call_id: None,
802        }],
803        Message::Assistant(assistant) => {
804            let mut messages = Vec::new();
805
806            // Collect text content
807            let text: String = assistant
808                .content
809                .iter()
810                .filter_map(|b| match b {
811                    ContentBlock::Text(t) => Some(t.text.as_str()),
812                    _ => None,
813                })
814                .collect::<Vec<_>>()
815                .join("\n\n");
816
817            // Collect tool calls
818            let tool_calls: Vec<AzureToolCallRef> = assistant
819                .content
820                .iter()
821                .filter_map(|b| match b {
822                    ContentBlock::ToolCall(tc) => Some(AzureToolCallRef {
823                        id: tc.id.clone(),
824                        r#type: "function",
825                        function: AzureFunctionRef {
826                            name: tc.name.clone(),
827                            arguments: tc.arguments.to_string(),
828                        },
829                    }),
830                    _ => None,
831                })
832                .collect();
833
834            let content = if text.is_empty() {
835                None
836            } else {
837                Some(AzureContent::Text(text))
838            };
839
840            let tool_calls = if tool_calls.is_empty() {
841                None
842            } else {
843                Some(tool_calls)
844            };
845
846            messages.push(AzureMessage {
847                role: "assistant".to_string(),
848                content,
849                tool_calls,
850                tool_call_id: None,
851            });
852
853            messages
854        }
855        Message::ToolResult(result) => {
856            let mut text_parts = Vec::new();
857            let mut image_parts = Vec::new();
858
859            for block in &result.content {
860                match block {
861                    ContentBlock::Text(t) => text_parts.push(t.text.clone()),
862                    ContentBlock::Image(img) => {
863                        let url = format!("data:{};base64,{}", img.mime_type, img.data);
864                        image_parts.push(AzureContentPart::ImageUrl {
865                            image_url: AzureImageUrl { url },
866                        });
867                    }
868                    _ => {}
869                }
870            }
871
872            let text_content = if text_parts.is_empty() {
873                if image_parts.is_empty() {
874                    None
875                } else {
876                    Some(AzureContent::Text("(see attached image)".to_string()))
877                }
878            } else {
879                Some(AzureContent::Text(text_parts.join("\n")))
880            };
881
882            let mut messages = vec![AzureMessage {
883                role: "tool".to_string(),
884                content: text_content,
885                tool_calls: None,
886                tool_call_id: Some(result.tool_call_id.clone()),
887            }];
888
889            if !image_parts.is_empty() {
890                let mut parts = vec![AzureContentPart::Text {
891                    text: "Attached image(s) from tool result:".to_string(),
892                }];
893                parts.extend(image_parts);
894                messages.push(AzureMessage {
895                    role: "user".to_string(),
896                    content: Some(AzureContent::Parts(parts)),
897                    tool_calls: None,
898                    tool_call_id: None,
899                });
900            }
901
902            messages
903        }
904    }
905}
906
907fn convert_user_content(content: &UserContent) -> AzureContent {
908    match content {
909        UserContent::Text(text) => AzureContent::Text(text.clone()),
910        UserContent::Blocks(blocks) => {
911            let parts: Vec<AzureContentPart> = blocks
912                .iter()
913                .filter_map(|block| match block {
914                    ContentBlock::Text(t) => Some(AzureContentPart::Text {
915                        text: t.text.clone(),
916                    }),
917                    ContentBlock::Image(img) => {
918                        let url = format!("data:{};base64,{}", img.mime_type, img.data);
919                        Some(AzureContentPart::ImageUrl {
920                            image_url: AzureImageUrl { url },
921                        })
922                    }
923                    _ => None,
924                })
925                .collect();
926            AzureContent::Parts(parts)
927        }
928    }
929}
930
931fn convert_tool_to_azure(tool: &ToolDef) -> AzureTool {
932    AzureTool {
933        r#type: "function",
934        function: AzureFunction {
935            name: tool.name.clone(),
936            description: tool.description.clone(),
937            parameters: tool.parameters.clone(),
938        },
939    }
940}
941
942// ============================================================================
943// Tests
944// ============================================================================
945
946#[cfg(test)]
947mod tests {
948    use super::*;
949    use crate::model::{ImageContent, TextContent, ToolCall, ToolResultMessage, UserMessage};
950    use crate::provider::ToolDef;
951    use asupersync::runtime::RuntimeBuilder;
952    use futures::{StreamExt, stream};
953    use serde::{Deserialize, Serialize};
954    use serde_json::{Value, json};
955    use std::collections::HashMap;
956    use std::io::{Read, Write};
957    use std::net::TcpListener;
958    use std::path::PathBuf;
959    use std::sync::mpsc;
960    use std::time::Duration;
961
    // Constructor exposes the fixed provider/API identifiers.
    #[test]
    fn test_azure_provider_creation() {
        let provider = AzureOpenAIProvider::new("my-resource", "gpt-4");
        assert_eq!(provider.name(), "azure");
        assert_eq!(provider.api(), "azure-openai");
    }

    // `model_id()` reports the Azure deployment name, not a model family.
    #[test]
    fn test_azure_model_id_uses_deployment() {
        let provider = AzureOpenAIProvider::new("my-resource", "gpt-4o-mini");
        assert_eq!(provider.model_id(), "gpt-4o-mini");
    }

    // The derived endpoint embeds resource, deployment, and an api-version query.
    #[test]
    fn test_azure_endpoint_url() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4-turbo");
        let url = provider.endpoint_url();
        assert!(url.contains("contoso.openai.azure.com"));
        assert!(url.contains("gpt-4-turbo"));
        assert!(url.contains("api-version="));
    }

    // `with_api_version` replaces the default API version in the URL.
    #[test]
    fn test_azure_endpoint_url_custom_version() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4").with_api_version("2024-06-01");
        let url = provider.endpoint_url();
        assert!(url.contains("api-version=2024-06-01"));
    }

    // Pin the exact default URL shape (guards DEFAULT_API_VERSION bumps).
    // NOTE(review): appears to assume PI_AZURE_API_VERSION is unset in the
    // test environment — confirm.
    #[test]
    fn test_azure_endpoint_url_exact_default_shape() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4o");
        let url = provider.endpoint_url();
        assert_eq!(
            url,
            "https://contoso.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-12-01-preview"
        );
    }

    // An explicit endpoint override wins over the resource/version-derived URL.
    #[test]
    fn test_azure_endpoint_url_override_takes_precedence() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4o")
            .with_api_version("2025-01-01")
            .with_endpoint_url("http://127.0.0.1:1234/mock-endpoint");
        let url = provider.endpoint_url();
        assert_eq!(url, "http://127.0.0.1:1234/mock-endpoint");
    }
1009
    // Full request shape: the system prompt becomes the leading "system"
    // message, conversation messages follow in order, options map to
    // max_tokens/temperature, and ToolDefs serialize as "function" tools.
    #[test]
    fn test_azure_build_request_includes_system_messages_and_tools() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4o");
        let context = Context {
            system_prompt: Some("You are deterministic.".to_string().into()),
            messages: vec![
                Message::User(UserMessage {
                    content: UserContent::Text("Hello".to_string()),
                    timestamp: 0,
                }),
                Message::assistant(AssistantMessage {
                    content: vec![
                        ContentBlock::Text(TextContent::new("Need tool output")),
                        ContentBlock::ToolCall(ToolCall {
                            id: "tool_1".to_string(),
                            name: "echo".to_string(),
                            arguments: json!({"text":"ping"}),
                            thought_signature: None,
                        }),
                    ],
                    api: "azure-openai".to_string(),
                    provider: "azure".to_string(),
                    model: "gpt-4o".to_string(),
                    usage: Usage::default(),
                    stop_reason: StopReason::ToolUse,
                    error_message: None,
                    timestamp: 0,
                }),
            ]
            .into(),
            tools: vec![ToolDef {
                name: "echo".to_string(),
                description: "Echo text".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "text": {"type":"string"}
                    },
                    "required": ["text"]
                }),
            }]
            .into(),
        };
        let options = StreamOptions {
            max_tokens: Some(512),
            temperature: Some(0.0),
            ..Default::default()
        };

        // Inspect the serialized JSON, which is what actually goes on the wire.
        let request = provider.build_request(&context, &options);
        let request_json = serde_json::to_value(&request).expect("serialize request");
        assert_eq!(request_json["max_tokens"], json!(512));
        assert_eq!(request_json["temperature"], json!(0.0));
        assert_eq!(request_json["stream"], json!(true));
        assert_eq!(request_json["messages"][0]["role"], json!("system"));
        assert_eq!(
            request_json["messages"][0]["content"],
            json!("You are deterministic.")
        );
        assert_eq!(request_json["messages"][1]["role"], json!("user"));
        assert_eq!(request_json["messages"][2]["role"], json!("assistant"));
        assert_eq!(request_json["tools"][0]["type"], json!("function"));
        assert_eq!(request_json["tools"][0]["function"]["name"], json!("echo"));
    }
1074
    // Omitted options fall back to defaults, and no "tools" key is emitted
    // when the context has no tools.
    #[test]
    fn test_azure_build_request_defaults_max_tokens() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4o");
        let context = Context {
            system_prompt: None,
            messages: vec![Message::User(UserMessage {
                content: UserContent::Text("Hello".to_string()),
                timestamp: 0,
            })]
            .into(),
            tools: Vec::new().into(),
        };
        let options = StreamOptions::default();

        let request = provider.build_request(&context, &options);
        let request_json = serde_json::to_value(&request).expect("serialize request");
        assert_eq!(request_json["max_tokens"], json!(DEFAULT_MAX_TOKENS));
        assert_eq!(request_json["stream"], json!(true));
        assert!(request_json.get("tools").is_none());
    }

    // A compat system_role_name that is a known role (modulo case and
    // surrounding whitespace) is normalized to the canonical spelling.
    #[test]
    fn test_azure_build_request_normalizes_known_system_role_name() {
        let provider =
            AzureOpenAIProvider::new("contoso", "gpt-4o").with_compat(Some(CompatConfig {
                system_role_name: Some("SYSTEM ".to_string()),
                ..CompatConfig::default()
            }));
        let context = Context {
            system_prompt: Some("You are deterministic.".to_string().into()),
            messages: Vec::new().into(),
            tools: Vec::new().into(),
        };

        let request = provider.build_request(&context, &StreamOptions::default());
        let request_json = serde_json::to_value(&request).expect("serialize request");
        assert_eq!(request_json["messages"][0]["role"], json!("system"));
    }

    // An unrecognized compat role name passes through untouched.
    #[test]
    fn test_azure_build_request_preserves_unknown_system_role_name() {
        let provider =
            AzureOpenAIProvider::new("contoso", "gpt-4o").with_compat(Some(CompatConfig {
                system_role_name: Some("custom_role".to_string()),
                ..CompatConfig::default()
            }));
        let context = Context {
            system_prompt: Some("You are deterministic.".to_string().into()),
            messages: Vec::new().into(),
            tools: Vec::new().into(),
        };

        let request = provider.build_request(&context, &StreamOptions::default());
        let request_json = serde_json::to_value(&request).expect("serialize request");
        assert_eq!(request_json["messages"][0]["role"], json!("custom_role"));
    }

    // Smoke test: a plain user message converts to a single "user" entry.
    #[test]
    fn test_azure_message_conversion() {
        let message = Message::User(UserMessage {
            content: UserContent::Text("Hello".to_string()),
            timestamp: chrono::Utc::now().timestamp_millis(),
        });

        let azure_messages = convert_message_to_azure(&message);
        assert_eq!(azure_messages.len(), 1);
        assert_eq!(azure_messages[0].role, "user");
    }
1143
    // Deserialized shape of tests/fixtures/provider_responses/*.json.
    #[derive(Debug, Deserialize)]
    struct ProviderFixture {
        cases: Vec<ProviderCase>,
    }

    // One named fixture case: raw SSE JSON `events` in, expected event
    // summaries out.
    #[derive(Debug, Deserialize)]
    struct ProviderCase {
        name: String,
        events: Vec<Value>,
        expected: Vec<EventSummary>,
    }

    // Flattened, comparable view of a StreamEvent; optional fields default
    // to None so fixtures only spell out the fields a given kind carries.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct EventSummary {
        kind: String,
        #[serde(default)]
        content_index: Option<usize>,
        #[serde(default)]
        delta: Option<String>,
        #[serde(default)]
        content: Option<String>,
        #[serde(default)]
        reason: Option<String>,
    }
1168
    // Replay every fixture case through the stream state machine and compare
    // the summarized event sequence against the fixture's expectations.
    #[test]
    fn test_stream_fixtures() {
        let fixture = load_fixture("azure_stream.json");
        for case in fixture.cases {
            let events = collect_events(&case.events);
            let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
            assert_eq!(summaries, case.expected, "case {}", case.name);
        }
    }

    // A tool_calls delta whose "index" skips ahead (here: 3 with no 0..2)
    // must not panic and must still surface the single tool call.
    #[test]
    fn test_stream_handles_sparse_tool_call_index_without_panic() {
        let events = vec![
            json!({ "choices": [{ "delta": {} }] }),
            json!({
                "choices": [{
                    "delta": {
                        "tool_calls": [{
                            "index": 3,
                            "id": "call_sparse",
                            "function": {
                                "name": "lookup",
                                "arguments": "{\"q\":\"azure\"}"
                            }
                        }]
                    }
                }]
            }),
            json!({ "choices": [{ "delta": {}, "finish_reason": "tool_calls" }] }),
            Value::String("[DONE]".to_string()),
        ];

        let out = collect_events(&events);
        let done = out
            .iter()
            .find_map(|event| match event {
                StreamEvent::Done { message, .. } => Some(message),
                _ => None,
            })
            .expect("done event");
        let tool_calls: Vec<&ToolCall> = done
            .content
            .iter()
            .filter_map(|block| match block {
                ContentBlock::ToolCall(tc) => Some(tc),
                _ => None,
            })
            .collect();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_sparse");
        assert_eq!(tool_calls[0].name, "lookup");
        assert_eq!(tool_calls[0].arguments, json!({ "q": "azure" }));
        assert!(
            out.iter()
                .any(|event| matches!(event, StreamEvent::ToolCallStart { .. })),
            "expected tool call start event"
        );
    }
1227
    // Request snapshot captured by the mock server: lowercase-keyed headers
    // plus the raw body text.
    #[derive(Debug)]
    struct CapturedRequest {
        headers: HashMap<String, String>,
        body: String,
    }
1233
    // A compat custom "api-key" header is sent as-is even when no api_key
    // is configured via StreamOptions.
    #[test]
    fn test_stream_compat_api_key_header_works_without_api_key() {
        let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
        let mut custom_headers = HashMap::new();
        custom_headers.insert("api-key".to_string(), "compat-azure-key".to_string());
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4o")
            .with_endpoint_url(base_url)
            .with_compat(Some(CompatConfig {
                custom_headers: Some(custom_headers),
                ..CompatConfig::default()
            }));
        let context = Context {
            system_prompt: None,
            messages: vec![Message::User(UserMessage {
                content: UserContent::Text("ping".to_string()),
                timestamp: 0,
            })]
            .into(),
            tools: Vec::new().into(),
        };

        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        runtime.block_on(async {
            let mut stream = provider
                .stream(&context, &StreamOptions::default())
                .await
                .expect("stream");
            // Drain the stream until Done so the request fully completes.
            while let Some(event) = stream.next().await {
                if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
                    break;
                }
            }
        });

        let captured = rx.recv_timeout(Duration::from_secs(2)).expect("captured");
        assert_eq!(
            captured.headers.get("api-key").map(String::as_str),
            Some("compat-azure-key")
        );
        let body: Value = serde_json::from_str(&captured.body).expect("body json");
        assert_eq!(body["stream"], true);
    }
1278
    // A compat custom "Authorization" bearer header also works without a
    // configured api_key (header names are matched case-insensitively).
    #[test]
    fn test_stream_compat_authorization_header_works_without_api_key() {
        let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
        let mut custom_headers = HashMap::new();
        custom_headers.insert(
            "Authorization".to_string(),
            "Bearer compat-azure-token".to_string(),
        );
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4o")
            .with_endpoint_url(base_url)
            .with_compat(Some(CompatConfig {
                custom_headers: Some(custom_headers),
                ..CompatConfig::default()
            }));
        let context = Context {
            system_prompt: None,
            messages: vec![Message::User(UserMessage {
                content: UserContent::Text("ping".to_string()),
                timestamp: 0,
            })]
            .into(),
            tools: Vec::new().into(),
        };

        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        runtime.block_on(async {
            let mut stream = provider
                .stream(&context, &StreamOptions::default())
                .await
                .expect("stream");
            // Drain the stream until Done so the request fully completes.
            while let Some(event) = stream.next().await {
                if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
                    break;
                }
            }
        });

        let captured = rx.recv_timeout(Duration::from_secs(2)).expect("captured");
        assert_eq!(
            captured.headers.get("authorization").map(String::as_str),
            Some("Bearer compat-azure-token")
        );
        let body: Value = serde_json::from_str(&captured.body).expect("body json");
        assert_eq!(body["stream"], true);
    }
1326
    // A whitespace-only compat "api-key" value is treated as absent; the
    // api_key from StreamOptions is used instead.
    #[test]
    fn test_blank_compat_api_key_header_does_not_override_builtin_api_key() {
        let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
        let mut custom_headers = HashMap::new();
        custom_headers.insert("api-key".to_string(), "   ".to_string());
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4o")
            .with_endpoint_url(base_url)
            .with_compat(Some(CompatConfig {
                custom_headers: Some(custom_headers),
                ..CompatConfig::default()
            }));
        let context = Context {
            system_prompt: None,
            messages: vec![Message::User(UserMessage {
                content: UserContent::Text("ping".to_string()),
                timestamp: 0,
            })]
            .into(),
            tools: Vec::new().into(),
        };
        let options = StreamOptions {
            api_key: Some("fallback-azure-key".to_string()),
            ..Default::default()
        };

        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        runtime.block_on(async {
            let mut stream = provider.stream(&context, &options).await.expect("stream");
            // Drain the stream until Done so the request fully completes.
            while let Some(event) = stream.next().await {
                if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
                    break;
                }
            }
        });

        let captured = rx.recv_timeout(Duration::from_secs(2)).expect("captured");
        assert_eq!(
            captured.headers.get("api-key").map(String::as_str),
            Some("fallback-azure-key")
        );
    }
1370
1371    fn load_fixture(file_name: &str) -> ProviderFixture {
1372        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1373            .join("tests/fixtures/provider_responses")
1374            .join(file_name);
1375        let raw = std::fs::read_to_string(path).expect("fixture read");
1376        serde_json::from_str(&raw).expect("fixture parse")
1377    }
1378
    // Drive the provider's SSE decoding state machine over pre-baked JSON
    // events (a `Value::String` entry is sent verbatim, e.g. "[DONE]") and
    // collect every StreamEvent it emits, including the final Done.
    fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        runtime.block_on(async move {
            // Re-encode each event as an SSE "data: ..." frame.
            let byte_stream = stream::iter(
                events
                    .iter()
                    .map(|event| {
                        let data = match event {
                            Value::String(text) => text.clone(),
                            _ => serde_json::to_string(event).expect("serialize event"),
                        };
                        format!("data: {data}\n\n").into_bytes()
                    })
                    .map(Ok),
            );
            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
            let mut state = StreamState::new(
                event_source,
                "gpt-test".to_string(),
                "azure-openai".to_string(),
                "azure".to_string(),
            );
            let mut out = Vec::new();

            while let Some(item) = state.event_source.next().await {
                let msg = item.expect("SSE event");
                if msg.data == "[DONE]" {
                    // Terminator: flush anything still pending, then emit a
                    // Done event carrying the accumulated partial message.
                    out.extend(state.pending_events.drain(..));
                    let reason = state.partial.stop_reason;
                    out.push(StreamEvent::Done {
                        reason,
                        message: std::mem::take(&mut state.partial),
                    });
                    break;
                }
                state.process_event(&msg.data).expect("process_event");
                out.extend(state.pending_events.drain(..));
            }

            out
        })
    }
1423
1424    fn success_sse_body() -> String {
1425        [
1426            r#"data: {"choices":[{"delta":{}}]}"#,
1427            "",
1428            r#"data: {"choices":[{"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}"#,
1429            "",
1430            "data: [DONE]",
1431            "",
1432        ]
1433        .join("\n")
1434    }
1435
    // Spin up a one-shot HTTP server on an ephemeral port. It captures the
    // first request (headers + body) into the returned channel, then replies
    // with the given status/content-type/body and closes the connection.
    // Returns the base URL to point the provider at plus the capture channel.
    fn spawn_test_server(
        status_code: u16,
        content_type: &str,
        body: &str,
    ) -> (String, mpsc::Receiver<CapturedRequest>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
        let addr = listener.local_addr().expect("local addr");
        let (tx, rx) = mpsc::channel();
        let body = body.to_string();
        let content_type = content_type.to_string();

        std::thread::spawn(move || {
            let (mut socket, _) = listener.accept().expect("accept");
            // Bounded reads so a malformed request can't hang the test.
            socket
                .set_read_timeout(Some(Duration::from_secs(2)))
                .expect("set read timeout");

            // Read until the end of the header section (CRLF CRLF).
            let mut bytes = Vec::new();
            let mut chunk = [0_u8; 4096];
            loop {
                match socket.read(&mut chunk) {
                    Ok(0) => break,
                    Ok(n) => {
                        bytes.extend_from_slice(&chunk[..n]);
                        if bytes.windows(4).any(|window| window == b"\r\n\r\n") {
                            break;
                        }
                    }
                    Err(err)
                        if err.kind() == std::io::ErrorKind::WouldBlock
                            || err.kind() == std::io::ErrorKind::TimedOut =>
                    {
                        break;
                    }
                    Err(err) => panic!("{err}"),
                }
            }

            let header_end = bytes
                .windows(4)
                .position(|window| window == b"\r\n\r\n")
                .expect("request header boundary");
            let header_text = String::from_utf8_lossy(&bytes[..header_end]).to_string();
            let headers = parse_headers(&header_text);
            // Whatever arrived after the boundary is the start of the body.
            let mut request_body = bytes[header_end + 4..].to_vec();

            // Best effort: keep reading until Content-Length bytes of body
            // have arrived, or a timeout/EOF ends the loop early.
            let content_length = headers
                .get("content-length")
                .and_then(|value| value.parse::<usize>().ok())
                .unwrap_or(0);
            while request_body.len() < content_length {
                match socket.read(&mut chunk) {
                    Ok(0) => break,
                    Ok(n) => request_body.extend_from_slice(&chunk[..n]),
                    Err(err)
                        if err.kind() == std::io::ErrorKind::WouldBlock
                            || err.kind() == std::io::ErrorKind::TimedOut =>
                    {
                        break;
                    }
                    Err(err) => panic!("{err}"),
                }
            }

            tx.send(CapturedRequest {
                headers,
                body: String::from_utf8_lossy(&request_body).to_string(),
            })
            .expect("send captured request");

            // Only the status codes these tests use get a proper reason phrase.
            let reason = match status_code {
                401 => "Unauthorized",
                500 => "Internal Server Error",
                _ => "OK",
            };
            let response = format!(
                "HTTP/1.1 {status_code} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
                body.len()
            );
            socket
                .write_all(response.as_bytes())
                .expect("write response");
            socket.flush().expect("flush response");
        });

        (format!("http://{addr}/azure"), rx)
    }
1523
1524    fn parse_headers(header_text: &str) -> HashMap<String, String> {
1525        let mut headers = HashMap::new();
1526        for line in header_text.lines().skip(1) {
1527            if let Some((name, value)) = line.split_once(':') {
1528                headers.insert(name.trim().to_ascii_lowercase(), value.trim().to_string());
1529            }
1530        }
1531        headers
1532    }
1533
1534    fn summarize_event(event: &StreamEvent) -> EventSummary {
1535        match event {
1536            StreamEvent::Start { .. } => EventSummary {
1537                kind: "start".to_string(),
1538                content_index: None,
1539                delta: None,
1540                content: None,
1541                reason: None,
1542            },
1543            StreamEvent::TextDelta {
1544                content_index,
1545                delta,
1546                ..
1547            } => EventSummary {
1548                kind: "text_delta".to_string(),
1549                content_index: Some(*content_index),
1550                delta: Some(delta.clone()),
1551                content: None,
1552                reason: None,
1553            },
1554            StreamEvent::Done { reason, .. } => EventSummary {
1555                kind: "done".to_string(),
1556                content_index: None,
1557                delta: None,
1558                content: None,
1559                reason: Some(reason_to_string(*reason)),
1560            },
1561            StreamEvent::Error { reason, .. } => EventSummary {
1562                kind: "error".to_string(),
1563                content_index: None,
1564                delta: None,
1565                content: None,
1566                reason: Some(reason_to_string(*reason)),
1567            },
1568            StreamEvent::TextStart { content_index, .. } => EventSummary {
1569                kind: "text_start".to_string(),
1570                content_index: Some(*content_index),
1571                delta: None,
1572                content: None,
1573                reason: None,
1574            },
1575            StreamEvent::TextEnd {
1576                content_index,
1577                content,
1578                ..
1579            } => EventSummary {
1580                kind: "text_end".to_string(),
1581                content_index: Some(*content_index),
1582                delta: None,
1583                content: Some(content.clone()),
1584                reason: None,
1585            },
1586            _ => EventSummary {
1587                kind: "other".to_string(),
1588                content_index: None,
1589                delta: None,
1590                content: None,
1591                reason: None,
1592            },
1593        }
1594    }
1595
1596    fn reason_to_string(reason: StopReason) -> String {
1597        match reason {
1598            StopReason::Stop => "stop",
1599            StopReason::Length => "length",
1600            StopReason::ToolUse => "tool_use",
1601            StopReason::Error => "error",
1602            StopReason::Aborted => "aborted",
1603        }
1604        .to_string()
1605    }
1606
1607    fn make_tool_result(content: Vec<ContentBlock>) -> Message {
1608        Message::tool_result(ToolResultMessage {
1609            tool_call_id: "call_123".to_string(),
1610            tool_name: "test_tool".to_string(),
1611            content,
1612            details: None,
1613            is_error: false,
1614            timestamp: 0,
1615        })
1616    }
1617
    // Text-only tool results stay a single "tool" message whose content is
    // the plain text.
    #[test]
    fn tool_result_text_only_produces_single_tool_message() {
        let msg = make_tool_result(vec![ContentBlock::Text(TextContent {
            text: "result text".to_string(),
            text_signature: None,
        })]);
        let azure_msgs = convert_message_to_azure(&msg);
        assert_eq!(azure_msgs.len(), 1);
        assert_eq!(azure_msgs[0].role, "tool");
        assert_eq!(azure_msgs[0].tool_call_id.as_deref(), Some("call_123"));
        let json = serde_json::to_value(&azure_msgs[0]).expect("serialize");
        assert_eq!(json["content"], "result text");
    }
1631
1632    #[test]
1633    fn tool_result_image_only_produces_tool_plus_user_message() {
1634        let msg = make_tool_result(vec![ContentBlock::Image(ImageContent {
1635            data: "aW1hZ2U=".to_string(),
1636            mime_type: "image/png".to_string(),
1637        })]);
1638        let azure_msgs = convert_message_to_azure(&msg);
1639        assert_eq!(
1640            azure_msgs.len(),
1641            2,
1642            "image-only should produce tool + user messages"
1643        );
1644        assert_eq!(azure_msgs[0].role, "tool");
1645        assert_eq!(azure_msgs[1].role, "user");
1646
1647        let tool_json = serde_json::to_value(&azure_msgs[0]).expect("serialize tool");
1648        assert_eq!(tool_json["content"], "(see attached image)");
1649
1650        let user_json = serde_json::to_value(&azure_msgs[1]).expect("serialize user");
1651        let parts = user_json["content"].as_array().expect("parts array");
1652        assert_eq!(parts.len(), 2);
1653        assert_eq!(parts[0]["type"], "text");
1654        assert_eq!(parts[1]["type"], "image_url");
1655        assert!(
1656            parts[1]["image_url"]["url"]
1657                .as_str()
1658                .unwrap()
1659                .starts_with("data:image/png;base64,")
1660        );
1661    }
1662
1663    #[test]
1664    fn tool_result_mixed_text_and_image_splits_correctly() {
1665        let msg = make_tool_result(vec![
1666            ContentBlock::Text(TextContent {
1667                text: "line one".to_string(),
1668                text_signature: None,
1669            }),
1670            ContentBlock::Image(ImageContent {
1671                data: "aW1hZ2U=".to_string(),
1672                mime_type: "image/jpeg".to_string(),
1673            }),
1674            ContentBlock::Text(TextContent {
1675                text: "line two".to_string(),
1676                text_signature: None,
1677            }),
1678        ]);
1679        let azure_msgs = convert_message_to_azure(&msg);
1680        assert_eq!(
1681            azure_msgs.len(),
1682            2,
1683            "mixed content should produce tool + user messages"
1684        );
1685
1686        let tool_json = serde_json::to_value(&azure_msgs[0]).expect("serialize tool");
1687        assert_eq!(tool_json["content"], "line one\nline two");
1688        assert_eq!(tool_json["tool_call_id"], "call_123");
1689
1690        let user_json = serde_json::to_value(&azure_msgs[1]).expect("serialize user");
1691        let parts = user_json["content"].as_array().expect("parts array");
1692        assert_eq!(parts.len(), 2);
1693        assert_eq!(parts[0]["type"], "text");
1694        assert_eq!(parts[1]["type"], "image_url");
1695    }
1696
1697    #[test]
1698    fn tool_result_empty_content_produces_single_tool_message_with_no_content() {
1699        let msg = make_tool_result(vec![]);
1700        let azure_msgs = convert_message_to_azure(&msg);
1701        assert_eq!(azure_msgs.len(), 1);
1702        assert_eq!(azure_msgs[0].role, "tool");
1703        let json = serde_json::to_value(&azure_msgs[0]).expect("serialize");
1704        assert!(
1705            json["content"].is_null(),
1706            "empty tool result should have null content"
1707        );
1708    }
1709}
1710
1711// ============================================================================
1712// Fuzzing support
1713// ============================================================================
1714
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    //! Stable entry points for fuzz targets exercising the SSE event processor.
    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// Byte-stream type backing the fuzz processor. It is always empty because
    /// fuzz input arrives through `Processor::process_event`, not the stream.
    type FuzzStream =
        Pin<Box<futures::stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;

    /// Opaque wrapper around the Azure OpenAI stream processor state.
    pub struct Processor(StreamState<FuzzStream>);

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Processor {
        /// Create a fresh processor with default state.
        pub fn new() -> Self {
            let bytes = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
            let sse = crate::sse::SseStream::new(Box::pin(bytes));
            Self(StreamState::new(
                sse,
                "azure-fuzz".into(),
                "azure-openai".into(),
                "azure".into(),
            ))
        }

        /// Feed one SSE data payload and return any emitted `StreamEvent`s.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            self.0.process_event(data)?;
            let emitted: Vec<StreamEvent> = self.0.pending_events.drain(..).collect();
            Ok(emitted)
        }
    }
}