1use crate::error::{Error, Result};
10use crate::http::client::Client;
11use crate::model::{
12 AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, Usage, UserContent,
13};
14use crate::models::CompatConfig;
15use crate::provider::{Context, Provider, StreamOptions, ToolDef};
16use crate::sse::SseStream;
17use async_trait::async_trait;
18use futures::StreamExt;
19use futures::stream::{self, Stream};
20use serde::{Deserialize, Serialize};
21use std::collections::VecDeque;
22use std::pin::Pin;
23
/// Azure OpenAI REST `api-version` used when `PI_AZURE_API_VERSION` provides no value.
pub(crate) const DEFAULT_API_VERSION: &str = "2024-12-01-preview";
/// Completion-token cap sent when the caller leaves `max_tokens` unset.
const DEFAULT_MAX_TOKENS: u32 = 4_096;
31
32pub(crate) fn azure_api_version() -> String {
33 std::env::var("PI_AZURE_API_VERSION")
34 .ok()
35 .filter(|v| !v.is_empty())
36 .unwrap_or_else(|| DEFAULT_API_VERSION.to_string())
37}
38
/// Canonicalizes a chat role name before it is put on the wire.
///
/// Surrounding whitespace is always stripped. When the remainder matches one of the
/// recognised role names ignoring ASCII case, the lowercase spelling is returned. Any other
/// value is returned trimmed but otherwise untouched, so custom role names pass through.
fn normalize_role(role: &str) -> String {
    const KNOWN_ROLES: [&str; 6] = [
        "system",
        "developer",
        "user",
        "assistant",
        "tool",
        "function",
    ];

    let candidate = role.trim();
    KNOWN_ROLES
        .iter()
        .find(|known| known.eq_ignore_ascii_case(candidate))
        .map_or_else(|| candidate.to_owned(), |known| (*known).to_owned())
}
53
54fn authorization_override(
55 options: &StreamOptions,
56 compat: Option<&CompatConfig>,
57) -> Option<String> {
58 super::first_non_empty_header_value_case_insensitive(&options.headers, &["authorization"])
59 .or_else(|| {
60 compat
61 .and_then(|compat| compat.custom_headers.as_ref())
62 .and_then(|headers| {
63 super::first_non_empty_header_value_case_insensitive(
64 headers,
65 &["authorization"],
66 )
67 })
68 })
69}
70
71fn api_key_override(options: &StreamOptions, compat: Option<&CompatConfig>) -> Option<String> {
72 super::first_non_empty_header_value_case_insensitive(&options.headers, &["api-key"]).or_else(
73 || {
74 compat
75 .and_then(|compat| compat.custom_headers.as_ref())
76 .and_then(|headers| {
77 super::first_non_empty_header_value_case_insensitive(headers, &["api-key"])
78 })
79 },
80 )
81}
82
83pub struct AzureOpenAIProvider {
89 client: Client,
90 provider: String,
92 deployment: String,
94 resource: String,
96 api_version: String,
98 endpoint_url_override: Option<String>,
100 compat: Option<CompatConfig>,
101}
102
103impl AzureOpenAIProvider {
104 pub fn new(resource: impl Into<String>, deployment: impl Into<String>) -> Self {
110 Self {
111 client: Client::new(),
112 provider: "azure".to_string(),
113 deployment: deployment.into(),
114 resource: resource.into(),
115 api_version: azure_api_version(),
116 endpoint_url_override: None,
117 compat: None,
118 }
119 }
120
121 #[must_use]
123 pub fn with_provider_name(mut self, provider: impl Into<String>) -> Self {
124 self.provider = provider.into();
125 self
126 }
127
128 #[must_use]
130 pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
131 self.api_version = version.into();
132 self
133 }
134
135 #[must_use]
140 pub fn with_endpoint_url(mut self, endpoint_url: impl Into<String>) -> Self {
141 self.endpoint_url_override = Some(endpoint_url.into());
142 self
143 }
144
145 #[must_use]
147 pub fn with_client(mut self, client: Client) -> Self {
148 self.client = client;
149 self
150 }
151
152 #[must_use]
154 pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
155 self.compat = compat;
156 self
157 }
158
159 fn endpoint_url(&self) -> String {
161 if let Some(url) = &self.endpoint_url_override {
162 return url.clone();
163 }
164 format!(
165 "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
166 self.resource, self.deployment, self.api_version
167 )
168 }
169
170 #[allow(clippy::unused_self)]
172 pub fn build_request(&self, context: &Context<'_>, options: &StreamOptions) -> AzureRequest {
173 let messages = self.build_messages(context);
174
175 let tools: Option<Vec<AzureTool>> = if context.tools.is_empty() {
176 None
177 } else {
178 Some(context.tools.iter().map(convert_tool_to_azure).collect())
179 };
180
181 AzureRequest {
182 messages,
183 max_tokens: options.max_tokens.or(Some(DEFAULT_MAX_TOKENS)),
184 temperature: options.temperature,
185 tools,
186 stream: true,
187 stream_options: Some(AzureStreamOptions {
188 include_usage: true,
189 }),
190 }
191 }
192
193 fn build_messages(&self, context: &Context<'_>) -> Vec<AzureMessage> {
195 let mut messages = Vec::new();
196 let system_role = self
197 .compat
198 .as_ref()
199 .and_then(|c| c.system_role_name.as_deref())
200 .unwrap_or("system");
201
202 if let Some(system) = &context.system_prompt {
204 messages.push(AzureMessage {
205 role: normalize_role(system_role),
206 content: Some(AzureContent::Text(system.to_string())),
207 tool_calls: None,
208 tool_call_id: None,
209 });
210 }
211
212 for message in context.messages.iter() {
214 messages.extend(convert_message_to_azure(message));
215 }
216
217 messages
218 }
219}
220
221#[async_trait]
222#[allow(clippy::too_many_lines)]
223impl Provider for AzureOpenAIProvider {
224 fn name(&self) -> &str {
225 &self.provider
226 }
227
228 fn api(&self) -> &'static str {
229 "azure-openai"
230 }
231
232 fn model_id(&self) -> &str {
233 &self.deployment
234 }
235
236 async fn stream(
237 &self,
238 context: &Context<'_>,
239 options: &StreamOptions,
240 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
241 let has_auth_override = api_key_override(options, self.compat.as_ref()).is_some()
242 || authorization_override(options, self.compat.as_ref()).is_some();
243 let auth_value = if has_auth_override {
244 None
245 } else {
246 Some(
247 options
248 .api_key
249 .clone()
250 .or_else(|| std::env::var("AZURE_OPENAI_API_KEY").ok())
251 .ok_or_else(|| Error::provider("azure-openai", "Missing API key for provider. Configure credentials with /login <provider> or set the provider's API key env var."))?,
252 )
253 };
254
255 let request_body = self.build_request(context, options);
256
257 let endpoint_url = self.endpoint_url();
258
259 let mut request = self
261 .client
262 .post(&endpoint_url)
263 .header("Accept", "text/event-stream");
264
265 if let Some(auth_value) = auth_value {
266 request = request.header("api-key", &auth_value); }
268
269 if let Some(compat) = &self.compat {
271 if let Some(custom_headers) = &compat.custom_headers {
272 request = super::apply_headers_ignoring_blank_auth_overrides(
273 request,
274 custom_headers,
275 &["authorization", "api-key"],
276 );
277 }
278 }
279
280 request = super::apply_headers_ignoring_blank_auth_overrides(
281 request,
282 &options.headers,
283 &["authorization", "api-key"],
284 );
285
286 let request = request.json(&request_body)?;
287
288 let response = Box::pin(request.send()).await?;
289 let status = response.status();
290 if !(200..300).contains(&status) {
291 let body = response
292 .text()
293 .await
294 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
295 return Err(Error::provider(
296 "azure-openai",
297 format!("Azure OpenAI API error (HTTP {status}): {body}"),
298 ));
299 }
300
301 let event_source = SseStream::new(response.bytes_stream());
303
304 let model = self.deployment.clone();
306 let api = self.api().to_string();
307 let provider = self.name().to_string();
308
309 let stream = stream::unfold(
310 StreamState::new(event_source, model, api, provider),
311 |mut state| async move {
312 if state.done {
313 return None;
314 }
315 loop {
316 if let Some(event) = state.pending_events.pop_front() {
317 return Some((Ok(event), state));
318 }
319
320 match state.event_source.next().await {
321 Some(Ok(msg)) => {
322 state.transient_error_count = 0;
323 if msg.data == "[DONE]" {
325 state.done = true;
326 let reason = state.partial.stop_reason;
327 let message = std::mem::take(&mut state.partial);
328 return Some((Ok(StreamEvent::Done { reason, message }), state));
329 }
330
331 if let Err(e) = state.process_event(&msg.data) {
332 state.done = true;
333 return Some((Err(e), state));
334 }
335 }
336 Some(Err(e)) => {
337 const MAX_CONSECUTIVE_TRANSIENT_ERRORS: usize = 5;
341 if e.kind() == std::io::ErrorKind::WriteZero
342 || e.kind() == std::io::ErrorKind::WouldBlock
343 || e.kind() == std::io::ErrorKind::TimedOut
344 {
345 state.transient_error_count += 1;
346 if state.transient_error_count <= MAX_CONSECUTIVE_TRANSIENT_ERRORS {
347 tracing::warn!(
348 kind = ?e.kind(),
349 count = state.transient_error_count,
350 "Transient error in SSE stream, continuing"
351 );
352 continue;
353 }
354 tracing::warn!(
355 kind = ?e.kind(),
356 "Error persisted after {MAX_CONSECUTIVE_TRANSIENT_ERRORS} \
357 consecutive attempts, treating as fatal"
358 );
359 }
360 state.done = true;
361 let err = Error::api(format!("SSE error: {e}"));
362 return Some((Err(err), state));
363 }
364 None => {
369 state.done = true;
370 let reason = state.partial.stop_reason;
371 let message = std::mem::take(&mut state.partial);
372 return Some((Ok(StreamEvent::Done { reason, message }), state));
373 }
374 }
375 }
376 },
377 );
378
379 Ok(Box::pin(stream))
380 }
381}
382
383struct StreamState<S>
388where
389 S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
390{
391 event_source: SseStream<S>,
392 partial: AssistantMessage,
393 tool_calls: Vec<ToolCallState>,
394 pending_events: VecDeque<StreamEvent>,
395 started: bool,
396 done: bool,
397 transient_error_count: usize,
399}
400
/// Accumulator for one streamed tool call.
struct ToolCallState {
    /// `index` value the server uses to identify this call across deltas.
    index: usize,
    /// Position of the matching `ContentBlock::ToolCall` in the partial message.
    content_index: usize,
    /// Concatenation of all `id` fragments received so far.
    id: String,
    /// Concatenation of all function-name fragments received so far.
    name: String,
    /// Raw JSON argument text, concatenated across deltas and parsed on finish.
    arguments: String,
}
408
409impl<S> StreamState<S>
410where
411 S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
412{
413 fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
414 Self {
415 event_source,
416 partial: AssistantMessage {
417 content: Vec::new(),
418 api,
419 provider,
420 model,
421 usage: Usage::default(),
422 stop_reason: StopReason::Stop,
423 error_message: None,
424 timestamp: chrono::Utc::now().timestamp_millis(),
425 },
426 tool_calls: Vec::new(),
427 pending_events: VecDeque::new(),
428 started: false,
429 done: false,
430 transient_error_count: 0,
431 }
432 }
433
434 fn finalize_tool_call_arguments(&mut self) {
435 for tc in &self.tool_calls {
436 let arguments: serde_json::Value = match serde_json::from_str(&tc.arguments) {
437 Ok(args) => args,
438 Err(e) => {
439 tracing::warn!(
440 error = %e,
441 raw = %tc.arguments,
442 "Failed to parse tool arguments as JSON"
443 );
444 serde_json::Value::Null
445 }
446 };
447
448 if let Some(ContentBlock::ToolCall(block)) =
449 self.partial.content.get_mut(tc.content_index)
450 {
451 block.arguments = arguments;
452 }
453 }
454 }
455
456 fn push_text_delta(&mut self, text: String) -> StreamEvent {
457 let last_is_text = matches!(self.partial.content.last(), Some(ContentBlock::Text(_)));
458 if !last_is_text {
459 let content_index = self.partial.content.len();
460 self.partial
461 .content
462 .push(ContentBlock::Text(crate::model::TextContent::new("")));
463 self.pending_events
464 .push_back(StreamEvent::TextStart { content_index });
465 }
466 let content_index = self.partial.content.len() - 1;
467
468 if let Some(ContentBlock::Text(t)) = self.partial.content.get_mut(content_index) {
469 t.text.push_str(&text);
470 }
471
472 StreamEvent::TextDelta {
473 content_index,
474 delta: text,
475 }
476 }
477
478 fn ensure_started(&mut self) {
479 if !self.started {
480 self.started = true;
481 self.pending_events.push_back(StreamEvent::Start {
482 partial: self.partial.clone(),
483 });
484 }
485 }
486
487 #[allow(clippy::unnecessary_wraps, clippy::too_many_lines)]
488 fn process_event(&mut self, data: &str) -> Result<()> {
489 let chunk: AzureStreamChunk = serde_json::from_str(data)
490 .map_err(|e| Error::api(format!("JSON parse error: {e}\nData: {data}")))?;
491
492 if let Some(usage) = chunk.usage {
494 self.partial.usage.input = usage.prompt_tokens;
495 self.partial.usage.output = usage.completion_tokens.unwrap_or(0);
496 self.partial.usage.total_tokens = usage.total_tokens;
497 }
498
499 let choices = chunk.choices;
500 if !self.started {
501 let first = choices.first();
502 let delta_is_empty = first.is_some_and(|choice| {
503 choice.finish_reason.is_none()
504 && choice.delta.content.is_none()
505 && choice.delta.tool_calls.is_none()
506 });
507 if delta_is_empty {
508 self.ensure_started();
509 return Ok(());
510 }
511 }
512
513 for choice in choices {
517 if let Some(text) = choice.delta.content {
519 self.ensure_started();
520 let event = self.push_text_delta(text);
521 self.pending_events.push_back(event);
522 }
523
524 if let Some(tool_calls) = choice.delta.tool_calls {
526 self.ensure_started();
527
528 for tc in tool_calls {
529 let idx = tc.index as usize;
530
531 let tool_state_idx = if let Some(existing_idx) =
534 self.tool_calls.iter().position(|tc| tc.index == idx)
535 {
536 existing_idx
537 } else {
538 let content_index = self.partial.content.len();
539 self.tool_calls.push(ToolCallState {
540 index: idx,
541 content_index,
542 id: String::new(),
543 name: String::new(),
544 arguments: String::new(),
545 });
546
547 self.partial
549 .content
550 .push(ContentBlock::ToolCall(crate::model::ToolCall {
551 id: String::new(),
552 name: String::new(),
553 arguments: serde_json::Value::Null,
554 thought_signature: None,
555 }));
556
557 self.pending_events
559 .push_back(StreamEvent::ToolCallStart { content_index });
560 self.tool_calls.len() - 1
561 };
562
563 let tc_state = &mut self.tool_calls[tool_state_idx];
564 let content_index = tc_state.content_index;
565
566 if let Some(id) = tc.id {
568 tc_state.id.push_str(&id);
569 if let Some(ContentBlock::ToolCall(block)) =
570 self.partial.content.get_mut(content_index)
571 {
572 block.id.clone_from(&tc_state.id);
573 }
574 }
575 if let Some(func) = tc.function {
576 if let Some(name) = func.name {
577 tc_state.name.push_str(&name);
578 if let Some(ContentBlock::ToolCall(block)) =
579 self.partial.content.get_mut(content_index)
580 {
581 block.name.clone_from(&tc_state.name);
582 }
583 }
584 if let Some(args) = func.arguments {
585 tc_state.arguments.push_str(&args);
586 self.pending_events.push_back(StreamEvent::ToolCallDelta {
590 content_index,
591 delta: args,
592 });
593 }
594 }
595 }
596 }
597
598 if choice.finish_reason.is_some() {
602 self.ensure_started();
603 }
604 if let Some(reason) = choice.finish_reason {
605 self.partial.stop_reason = match reason.as_str() {
606 "length" => StopReason::Length,
607 "content_filter" => StopReason::Error,
608 "tool_calls" => StopReason::ToolUse,
609 _ => StopReason::Stop,
611 };
612
613 self.finalize_tool_call_arguments();
615
616 for (content_index, block) in self.partial.content.iter().enumerate() {
618 if let ContentBlock::Text(t) = block {
619 self.pending_events.push_back(StreamEvent::TextEnd {
620 content_index,
621 content: t.text.clone(),
622 });
623 } else if let ContentBlock::Thinking(t) = block {
624 self.pending_events.push_back(StreamEvent::ThinkingEnd {
625 content_index,
626 content: t.thinking.clone(),
627 });
628 }
629 }
630
631 for tc in &self.tool_calls {
633 if let Some(ContentBlock::ToolCall(tool_call)) =
634 self.partial.content.get(tc.content_index)
635 {
636 self.pending_events.push_back(StreamEvent::ToolCallEnd {
637 content_index: tc.content_index,
638 tool_call: tool_call.clone(),
639 });
640 }
641 }
642 }
643 }
644
645 Ok(())
646 }
647}
648
649#[derive(Debug, Serialize)]
654pub struct AzureRequest {
655 messages: Vec<AzureMessage>,
656 #[serde(skip_serializing_if = "Option::is_none")]
657 max_tokens: Option<u32>,
658 #[serde(skip_serializing_if = "Option::is_none")]
659 temperature: Option<f32>,
660 #[serde(skip_serializing_if = "Option::is_none")]
661 tools: Option<Vec<AzureTool>>,
662 stream: bool,
663 #[serde(skip_serializing_if = "Option::is_none")]
664 stream_options: Option<AzureStreamOptions>,
665}
666
667#[derive(Debug, Serialize)]
668struct AzureStreamOptions {
669 include_usage: bool,
670}
671
672#[derive(Debug, Serialize)]
673struct AzureMessage {
674 role: String,
675 #[serde(skip_serializing_if = "Option::is_none")]
676 content: Option<AzureContent>,
677 #[serde(skip_serializing_if = "Option::is_none")]
678 tool_calls: Option<Vec<AzureToolCallRef>>,
679 #[serde(skip_serializing_if = "Option::is_none")]
680 tool_call_id: Option<String>,
681}
682
683#[derive(Debug, Serialize)]
684#[serde(untagged)]
685enum AzureContent {
686 Text(String),
687 Parts(Vec<AzureContentPart>),
688}
689
690#[derive(Debug, Serialize)]
691#[serde(tag = "type")]
692enum AzureContentPart {
693 #[serde(rename = "text")]
694 Text { text: String },
695 #[serde(rename = "image_url")]
696 ImageUrl { image_url: AzureImageUrl },
697}
698
699#[derive(Debug, Serialize)]
700struct AzureImageUrl {
701 url: String,
702}
703
704#[derive(Debug, Serialize)]
705struct AzureToolCallRef {
706 id: String,
707 r#type: &'static str,
708 function: AzureFunctionRef,
709}
710
711#[derive(Debug, Serialize)]
712struct AzureFunctionRef {
713 name: String,
714 arguments: String,
715}
716
717#[derive(Debug, Serialize)]
718struct AzureTool {
719 r#type: &'static str,
720 function: AzureFunction,
721}
722
723#[derive(Debug, Serialize)]
724struct AzureFunction {
725 name: String,
726 description: String,
727 parameters: serde_json::Value,
728}
729
730#[derive(Debug, Deserialize)]
735struct AzureStreamChunk {
736 #[serde(default)]
737 choices: Vec<AzureChoice>,
738 #[serde(default)]
739 usage: Option<AzureUsage>,
740}
741
742#[derive(Debug, Deserialize)]
743struct AzureChoice {
744 delta: AzureDelta,
745 #[serde(default)]
746 finish_reason: Option<String>,
747}
748
749#[derive(Debug, Deserialize)]
750struct AzureDelta {
751 #[serde(default)]
752 content: Option<String>,
753 #[serde(default)]
754 tool_calls: Option<Vec<AzureToolCallDelta>>,
755}
756
757#[derive(Debug, Deserialize)]
758struct AzureToolCallDelta {
759 index: u32,
760 #[serde(default)]
761 id: Option<String>,
762 #[serde(default)]
763 function: Option<AzureFunctionDelta>,
764}
765
766#[derive(Debug, Deserialize)]
767struct AzureFunctionDelta {
768 #[serde(default)]
769 name: Option<String>,
770 #[serde(default)]
771 arguments: Option<String>,
772}
773
774#[derive(Debug, Deserialize)]
775#[allow(clippy::struct_field_names)]
776struct AzureUsage {
777 prompt_tokens: u64,
778 #[serde(default)]
779 completion_tokens: Option<u64>,
780 #[allow(dead_code)]
781 total_tokens: u64,
782}
783
784#[allow(clippy::too_many_lines)]
789fn convert_message_to_azure(message: &Message) -> Vec<AzureMessage> {
790 match message {
791 Message::User(user) => vec![AzureMessage {
792 role: "user".to_string(),
793 content: Some(convert_user_content(&user.content)),
794 tool_calls: None,
795 tool_call_id: None,
796 }],
797 Message::Custom(custom) => vec![AzureMessage {
798 role: "user".to_string(),
799 content: Some(AzureContent::Text(custom.content.clone())),
800 tool_calls: None,
801 tool_call_id: None,
802 }],
803 Message::Assistant(assistant) => {
804 let mut messages = Vec::new();
805
806 let text: String = assistant
808 .content
809 .iter()
810 .filter_map(|b| match b {
811 ContentBlock::Text(t) => Some(t.text.as_str()),
812 _ => None,
813 })
814 .collect::<Vec<_>>()
815 .join("\n\n");
816
817 let tool_calls: Vec<AzureToolCallRef> = assistant
819 .content
820 .iter()
821 .filter_map(|b| match b {
822 ContentBlock::ToolCall(tc) => Some(AzureToolCallRef {
823 id: tc.id.clone(),
824 r#type: "function",
825 function: AzureFunctionRef {
826 name: tc.name.clone(),
827 arguments: tc.arguments.to_string(),
828 },
829 }),
830 _ => None,
831 })
832 .collect();
833
834 let content = if text.is_empty() {
835 None
836 } else {
837 Some(AzureContent::Text(text))
838 };
839
840 let tool_calls = if tool_calls.is_empty() {
841 None
842 } else {
843 Some(tool_calls)
844 };
845
846 messages.push(AzureMessage {
847 role: "assistant".to_string(),
848 content,
849 tool_calls,
850 tool_call_id: None,
851 });
852
853 messages
854 }
855 Message::ToolResult(result) => {
856 let mut text_parts = Vec::new();
857 let mut image_parts = Vec::new();
858
859 for block in &result.content {
860 match block {
861 ContentBlock::Text(t) => text_parts.push(t.text.clone()),
862 ContentBlock::Image(img) => {
863 let url = format!("data:{};base64,{}", img.mime_type, img.data);
864 image_parts.push(AzureContentPart::ImageUrl {
865 image_url: AzureImageUrl { url },
866 });
867 }
868 _ => {}
869 }
870 }
871
872 let text_content = if text_parts.is_empty() {
873 if image_parts.is_empty() {
874 None
875 } else {
876 Some(AzureContent::Text("(see attached image)".to_string()))
877 }
878 } else {
879 Some(AzureContent::Text(text_parts.join("\n")))
880 };
881
882 let mut messages = vec![AzureMessage {
883 role: "tool".to_string(),
884 content: text_content,
885 tool_calls: None,
886 tool_call_id: Some(result.tool_call_id.clone()),
887 }];
888
889 if !image_parts.is_empty() {
890 let mut parts = vec![AzureContentPart::Text {
891 text: "Attached image(s) from tool result:".to_string(),
892 }];
893 parts.extend(image_parts);
894 messages.push(AzureMessage {
895 role: "user".to_string(),
896 content: Some(AzureContent::Parts(parts)),
897 tool_calls: None,
898 tool_call_id: None,
899 });
900 }
901
902 messages
903 }
904 }
905}
906
907fn convert_user_content(content: &UserContent) -> AzureContent {
908 match content {
909 UserContent::Text(text) => AzureContent::Text(text.clone()),
910 UserContent::Blocks(blocks) => {
911 let parts: Vec<AzureContentPart> = blocks
912 .iter()
913 .filter_map(|block| match block {
914 ContentBlock::Text(t) => Some(AzureContentPart::Text {
915 text: t.text.clone(),
916 }),
917 ContentBlock::Image(img) => {
918 let url = format!("data:{};base64,{}", img.mime_type, img.data);
919 Some(AzureContentPart::ImageUrl {
920 image_url: AzureImageUrl { url },
921 })
922 }
923 _ => None,
924 })
925 .collect();
926 AzureContent::Parts(parts)
927 }
928 }
929}
930
931fn convert_tool_to_azure(tool: &ToolDef) -> AzureTool {
932 AzureTool {
933 r#type: "function",
934 function: AzureFunction {
935 name: tool.name.clone(),
936 description: tool.description.clone(),
937 parameters: tool.parameters.clone(),
938 },
939 }
940}
941
942#[cfg(test)]
947mod tests {
948 use super::*;
949 use crate::model::{ImageContent, TextContent, ToolCall, ToolResultMessage, UserMessage};
950 use crate::provider::ToolDef;
951 use asupersync::runtime::RuntimeBuilder;
952 use futures::{StreamExt, stream};
953 use serde::{Deserialize, Serialize};
954 use serde_json::{Value, json};
955 use std::collections::HashMap;
956 use std::io::{Read, Write};
957 use std::net::TcpListener;
958 use std::path::PathBuf;
959 use std::sync::mpsc;
960 use std::time::Duration;
961
962 #[test]
963 fn test_azure_provider_creation() {
964 let provider = AzureOpenAIProvider::new("my-resource", "gpt-4");
965 assert_eq!(provider.name(), "azure");
966 assert_eq!(provider.api(), "azure-openai");
967 }
968
969 #[test]
970 fn test_azure_model_id_uses_deployment() {
971 let provider = AzureOpenAIProvider::new("my-resource", "gpt-4o-mini");
972 assert_eq!(provider.model_id(), "gpt-4o-mini");
973 }
974
975 #[test]
976 fn test_azure_endpoint_url() {
977 let provider = AzureOpenAIProvider::new("contoso", "gpt-4-turbo");
978 let url = provider.endpoint_url();
979 assert!(url.contains("contoso.openai.azure.com"));
980 assert!(url.contains("gpt-4-turbo"));
981 assert!(url.contains("api-version="));
982 }
983
984 #[test]
985 fn test_azure_endpoint_url_custom_version() {
986 let provider = AzureOpenAIProvider::new("contoso", "gpt-4").with_api_version("2024-06-01");
987 let url = provider.endpoint_url();
988 assert!(url.contains("api-version=2024-06-01"));
989 }
990
991 #[test]
992 fn test_azure_endpoint_url_exact_default_shape() {
993 let provider = AzureOpenAIProvider::new("contoso", "gpt-4o");
994 let url = provider.endpoint_url();
995 assert_eq!(
996 url,
997 "https://contoso.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-12-01-preview"
998 );
999 }
1000
1001 #[test]
1002 fn test_azure_endpoint_url_override_takes_precedence() {
1003 let provider = AzureOpenAIProvider::new("contoso", "gpt-4o")
1004 .with_api_version("2025-01-01")
1005 .with_endpoint_url("http://127.0.0.1:1234/mock-endpoint");
1006 let url = provider.endpoint_url();
1007 assert_eq!(url, "http://127.0.0.1:1234/mock-endpoint");
1008 }
1009
1010 #[test]
1011 fn test_azure_build_request_includes_system_messages_and_tools() {
1012 let provider = AzureOpenAIProvider::new("contoso", "gpt-4o");
1013 let context = Context {
1014 system_prompt: Some("You are deterministic.".to_string().into()),
1015 messages: vec![
1016 Message::User(UserMessage {
1017 content: UserContent::Text("Hello".to_string()),
1018 timestamp: 0,
1019 }),
1020 Message::assistant(AssistantMessage {
1021 content: vec![
1022 ContentBlock::Text(TextContent::new("Need tool output")),
1023 ContentBlock::ToolCall(ToolCall {
1024 id: "tool_1".to_string(),
1025 name: "echo".to_string(),
1026 arguments: json!({"text":"ping"}),
1027 thought_signature: None,
1028 }),
1029 ],
1030 api: "azure-openai".to_string(),
1031 provider: "azure".to_string(),
1032 model: "gpt-4o".to_string(),
1033 usage: Usage::default(),
1034 stop_reason: StopReason::ToolUse,
1035 error_message: None,
1036 timestamp: 0,
1037 }),
1038 ]
1039 .into(),
1040 tools: vec![ToolDef {
1041 name: "echo".to_string(),
1042 description: "Echo text".to_string(),
1043 parameters: json!({
1044 "type": "object",
1045 "properties": {
1046 "text": {"type":"string"}
1047 },
1048 "required": ["text"]
1049 }),
1050 }]
1051 .into(),
1052 };
1053 let options = StreamOptions {
1054 max_tokens: Some(512),
1055 temperature: Some(0.0),
1056 ..Default::default()
1057 };
1058
1059 let request = provider.build_request(&context, &options);
1060 let request_json = serde_json::to_value(&request).expect("serialize request");
1061 assert_eq!(request_json["max_tokens"], json!(512));
1062 assert_eq!(request_json["temperature"], json!(0.0));
1063 assert_eq!(request_json["stream"], json!(true));
1064 assert_eq!(request_json["messages"][0]["role"], json!("system"));
1065 assert_eq!(
1066 request_json["messages"][0]["content"],
1067 json!("You are deterministic.")
1068 );
1069 assert_eq!(request_json["messages"][1]["role"], json!("user"));
1070 assert_eq!(request_json["messages"][2]["role"], json!("assistant"));
1071 assert_eq!(request_json["tools"][0]["type"], json!("function"));
1072 assert_eq!(request_json["tools"][0]["function"]["name"], json!("echo"));
1073 }
1074
1075 #[test]
1076 fn test_azure_build_request_defaults_max_tokens() {
1077 let provider = AzureOpenAIProvider::new("contoso", "gpt-4o");
1078 let context = Context {
1079 system_prompt: None,
1080 messages: vec![Message::User(UserMessage {
1081 content: UserContent::Text("Hello".to_string()),
1082 timestamp: 0,
1083 })]
1084 .into(),
1085 tools: Vec::new().into(),
1086 };
1087 let options = StreamOptions::default();
1088
1089 let request = provider.build_request(&context, &options);
1090 let request_json = serde_json::to_value(&request).expect("serialize request");
1091 assert_eq!(request_json["max_tokens"], json!(DEFAULT_MAX_TOKENS));
1092 assert_eq!(request_json["stream"], json!(true));
1093 assert!(request_json.get("tools").is_none());
1094 }
1095
1096 #[test]
1097 fn test_azure_build_request_normalizes_known_system_role_name() {
1098 let provider =
1099 AzureOpenAIProvider::new("contoso", "gpt-4o").with_compat(Some(CompatConfig {
1100 system_role_name: Some("SYSTEM ".to_string()),
1101 ..CompatConfig::default()
1102 }));
1103 let context = Context {
1104 system_prompt: Some("You are deterministic.".to_string().into()),
1105 messages: Vec::new().into(),
1106 tools: Vec::new().into(),
1107 };
1108
1109 let request = provider.build_request(&context, &StreamOptions::default());
1110 let request_json = serde_json::to_value(&request).expect("serialize request");
1111 assert_eq!(request_json["messages"][0]["role"], json!("system"));
1112 }
1113
1114 #[test]
1115 fn test_azure_build_request_preserves_unknown_system_role_name() {
1116 let provider =
1117 AzureOpenAIProvider::new("contoso", "gpt-4o").with_compat(Some(CompatConfig {
1118 system_role_name: Some("custom_role".to_string()),
1119 ..CompatConfig::default()
1120 }));
1121 let context = Context {
1122 system_prompt: Some("You are deterministic.".to_string().into()),
1123 messages: Vec::new().into(),
1124 tools: Vec::new().into(),
1125 };
1126
1127 let request = provider.build_request(&context, &StreamOptions::default());
1128 let request_json = serde_json::to_value(&request).expect("serialize request");
1129 assert_eq!(request_json["messages"][0]["role"], json!("custom_role"));
1130 }
1131
1132 #[test]
1133 fn test_azure_message_conversion() {
1134 let message = Message::User(UserMessage {
1135 content: UserContent::Text("Hello".to_string()),
1136 timestamp: chrono::Utc::now().timestamp_millis(),
1137 });
1138
1139 let azure_messages = convert_message_to_azure(&message);
1140 assert_eq!(azure_messages.len(), 1);
1141 assert_eq!(azure_messages[0].role, "user");
1142 }
1143
1144 #[derive(Debug, Deserialize)]
1145 struct ProviderFixture {
1146 cases: Vec<ProviderCase>,
1147 }
1148
1149 #[derive(Debug, Deserialize)]
1150 struct ProviderCase {
1151 name: String,
1152 events: Vec<Value>,
1153 expected: Vec<EventSummary>,
1154 }
1155
1156 #[derive(Debug, Deserialize, Serialize, PartialEq)]
1157 struct EventSummary {
1158 kind: String,
1159 #[serde(default)]
1160 content_index: Option<usize>,
1161 #[serde(default)]
1162 delta: Option<String>,
1163 #[serde(default)]
1164 content: Option<String>,
1165 #[serde(default)]
1166 reason: Option<String>,
1167 }
1168
1169 #[test]
1170 fn test_stream_fixtures() {
1171 let fixture = load_fixture("azure_stream.json");
1172 for case in fixture.cases {
1173 let events = collect_events(&case.events);
1174 let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
1175 assert_eq!(summaries, case.expected, "case {}", case.name);
1176 }
1177 }
1178
1179 #[test]
1180 fn test_stream_handles_sparse_tool_call_index_without_panic() {
1181 let events = vec![
1182 json!({ "choices": [{ "delta": {} }] }),
1183 json!({
1184 "choices": [{
1185 "delta": {
1186 "tool_calls": [{
1187 "index": 3,
1188 "id": "call_sparse",
1189 "function": {
1190 "name": "lookup",
1191 "arguments": "{\"q\":\"azure\"}"
1192 }
1193 }]
1194 }
1195 }]
1196 }),
1197 json!({ "choices": [{ "delta": {}, "finish_reason": "tool_calls" }] }),
1198 Value::String("[DONE]".to_string()),
1199 ];
1200
1201 let out = collect_events(&events);
1202 let done = out
1203 .iter()
1204 .find_map(|event| match event {
1205 StreamEvent::Done { message, .. } => Some(message),
1206 _ => None,
1207 })
1208 .expect("done event");
1209 let tool_calls: Vec<&ToolCall> = done
1210 .content
1211 .iter()
1212 .filter_map(|block| match block {
1213 ContentBlock::ToolCall(tc) => Some(tc),
1214 _ => None,
1215 })
1216 .collect();
1217 assert_eq!(tool_calls.len(), 1);
1218 assert_eq!(tool_calls[0].id, "call_sparse");
1219 assert_eq!(tool_calls[0].name, "lookup");
1220 assert_eq!(tool_calls[0].arguments, json!({ "q": "azure" }));
1221 assert!(
1222 out.iter()
1223 .any(|event| matches!(event, StreamEvent::ToolCallStart { .. })),
1224 "expected tool call start event"
1225 );
1226 }
1227
1228 #[derive(Debug)]
1229 struct CapturedRequest {
1230 headers: HashMap<String, String>,
1231 body: String,
1232 }
1233
1234 #[test]
1235 fn test_stream_compat_api_key_header_works_without_api_key() {
1236 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1237 let mut custom_headers = HashMap::new();
1238 custom_headers.insert("api-key".to_string(), "compat-azure-key".to_string());
1239 let provider = AzureOpenAIProvider::new("contoso", "gpt-4o")
1240 .with_endpoint_url(base_url)
1241 .with_compat(Some(CompatConfig {
1242 custom_headers: Some(custom_headers),
1243 ..CompatConfig::default()
1244 }));
1245 let context = Context {
1246 system_prompt: None,
1247 messages: vec![Message::User(UserMessage {
1248 content: UserContent::Text("ping".to_string()),
1249 timestamp: 0,
1250 })]
1251 .into(),
1252 tools: Vec::new().into(),
1253 };
1254
1255 let runtime = RuntimeBuilder::current_thread()
1256 .build()
1257 .expect("runtime build");
1258 runtime.block_on(async {
1259 let mut stream = provider
1260 .stream(&context, &StreamOptions::default())
1261 .await
1262 .expect("stream");
1263 while let Some(event) = stream.next().await {
1264 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1265 break;
1266 }
1267 }
1268 });
1269
1270 let captured = rx.recv_timeout(Duration::from_secs(2)).expect("captured");
1271 assert_eq!(
1272 captured.headers.get("api-key").map(String::as_str),
1273 Some("compat-azure-key")
1274 );
1275 let body: Value = serde_json::from_str(&captured.body).expect("body json");
1276 assert_eq!(body["stream"], true);
1277 }
1278
1279 #[test]
1280 fn test_stream_compat_authorization_header_works_without_api_key() {
1281 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1282 let mut custom_headers = HashMap::new();
1283 custom_headers.insert(
1284 "Authorization".to_string(),
1285 "Bearer compat-azure-token".to_string(),
1286 );
1287 let provider = AzureOpenAIProvider::new("contoso", "gpt-4o")
1288 .with_endpoint_url(base_url)
1289 .with_compat(Some(CompatConfig {
1290 custom_headers: Some(custom_headers),
1291 ..CompatConfig::default()
1292 }));
1293 let context = Context {
1294 system_prompt: None,
1295 messages: vec![Message::User(UserMessage {
1296 content: UserContent::Text("ping".to_string()),
1297 timestamp: 0,
1298 })]
1299 .into(),
1300 tools: Vec::new().into(),
1301 };
1302
1303 let runtime = RuntimeBuilder::current_thread()
1304 .build()
1305 .expect("runtime build");
1306 runtime.block_on(async {
1307 let mut stream = provider
1308 .stream(&context, &StreamOptions::default())
1309 .await
1310 .expect("stream");
1311 while let Some(event) = stream.next().await {
1312 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1313 break;
1314 }
1315 }
1316 });
1317
1318 let captured = rx.recv_timeout(Duration::from_secs(2)).expect("captured");
1319 assert_eq!(
1320 captured.headers.get("authorization").map(String::as_str),
1321 Some("Bearer compat-azure-token")
1322 );
1323 let body: Value = serde_json::from_str(&captured.body).expect("body json");
1324 assert_eq!(body["stream"], true);
1325 }
1326
1327 #[test]
1328 fn test_blank_compat_api_key_header_does_not_override_builtin_api_key() {
1329 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1330 let mut custom_headers = HashMap::new();
1331 custom_headers.insert("api-key".to_string(), " ".to_string());
1332 let provider = AzureOpenAIProvider::new("contoso", "gpt-4o")
1333 .with_endpoint_url(base_url)
1334 .with_compat(Some(CompatConfig {
1335 custom_headers: Some(custom_headers),
1336 ..CompatConfig::default()
1337 }));
1338 let context = Context {
1339 system_prompt: None,
1340 messages: vec![Message::User(UserMessage {
1341 content: UserContent::Text("ping".to_string()),
1342 timestamp: 0,
1343 })]
1344 .into(),
1345 tools: Vec::new().into(),
1346 };
1347 let options = StreamOptions {
1348 api_key: Some("fallback-azure-key".to_string()),
1349 ..Default::default()
1350 };
1351
1352 let runtime = RuntimeBuilder::current_thread()
1353 .build()
1354 .expect("runtime build");
1355 runtime.block_on(async {
1356 let mut stream = provider.stream(&context, &options).await.expect("stream");
1357 while let Some(event) = stream.next().await {
1358 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1359 break;
1360 }
1361 }
1362 });
1363
1364 let captured = rx.recv_timeout(Duration::from_secs(2)).expect("captured");
1365 assert_eq!(
1366 captured.headers.get("api-key").map(String::as_str),
1367 Some("fallback-azure-key")
1368 );
1369 }
1370
1371 fn load_fixture(file_name: &str) -> ProviderFixture {
1372 let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1373 .join("tests/fixtures/provider_responses")
1374 .join(file_name);
1375 let raw = std::fs::read_to_string(path).expect("fixture read");
1376 serde_json::from_str(&raw).expect("fixture parse")
1377 }
1378
1379 fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
1380 let runtime = RuntimeBuilder::current_thread()
1381 .build()
1382 .expect("runtime build");
1383 runtime.block_on(async move {
1384 let byte_stream = stream::iter(
1385 events
1386 .iter()
1387 .map(|event| {
1388 let data = match event {
1389 Value::String(text) => text.clone(),
1390 _ => serde_json::to_string(event).expect("serialize event"),
1391 };
1392 format!("data: {data}\n\n").into_bytes()
1393 })
1394 .map(Ok),
1395 );
1396 let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1397 let mut state = StreamState::new(
1398 event_source,
1399 "gpt-test".to_string(),
1400 "azure-openai".to_string(),
1401 "azure".to_string(),
1402 );
1403 let mut out = Vec::new();
1404
1405 while let Some(item) = state.event_source.next().await {
1406 let msg = item.expect("SSE event");
1407 if msg.data == "[DONE]" {
1408 out.extend(state.pending_events.drain(..));
1409 let reason = state.partial.stop_reason;
1410 out.push(StreamEvent::Done {
1411 reason,
1412 message: std::mem::take(&mut state.partial),
1413 });
1414 break;
1415 }
1416 state.process_event(&msg.data).expect("process_event");
1417 out.extend(state.pending_events.drain(..));
1418 }
1419
1420 out
1421 })
1422 }
1423
1424 fn success_sse_body() -> String {
1425 [
1426 r#"data: {"choices":[{"delta":{}}]}"#,
1427 "",
1428 r#"data: {"choices":[{"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}"#,
1429 "",
1430 "data: [DONE]",
1431 "",
1432 ]
1433 .join("\n")
1434 }
1435
1436 fn spawn_test_server(
1437 status_code: u16,
1438 content_type: &str,
1439 body: &str,
1440 ) -> (String, mpsc::Receiver<CapturedRequest>) {
1441 let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
1442 let addr = listener.local_addr().expect("local addr");
1443 let (tx, rx) = mpsc::channel();
1444 let body = body.to_string();
1445 let content_type = content_type.to_string();
1446
1447 std::thread::spawn(move || {
1448 let (mut socket, _) = listener.accept().expect("accept");
1449 socket
1450 .set_read_timeout(Some(Duration::from_secs(2)))
1451 .expect("set read timeout");
1452
1453 let mut bytes = Vec::new();
1454 let mut chunk = [0_u8; 4096];
1455 loop {
1456 match socket.read(&mut chunk) {
1457 Ok(0) => break,
1458 Ok(n) => {
1459 bytes.extend_from_slice(&chunk[..n]);
1460 if bytes.windows(4).any(|window| window == b"\r\n\r\n") {
1461 break;
1462 }
1463 }
1464 Err(err)
1465 if err.kind() == std::io::ErrorKind::WouldBlock
1466 || err.kind() == std::io::ErrorKind::TimedOut =>
1467 {
1468 break;
1469 }
1470 Err(err) => panic!("{err}"),
1471 }
1472 }
1473
1474 let header_end = bytes
1475 .windows(4)
1476 .position(|window| window == b"\r\n\r\n")
1477 .expect("request header boundary");
1478 let header_text = String::from_utf8_lossy(&bytes[..header_end]).to_string();
1479 let headers = parse_headers(&header_text);
1480 let mut request_body = bytes[header_end + 4..].to_vec();
1481
1482 let content_length = headers
1483 .get("content-length")
1484 .and_then(|value| value.parse::<usize>().ok())
1485 .unwrap_or(0);
1486 while request_body.len() < content_length {
1487 match socket.read(&mut chunk) {
1488 Ok(0) => break,
1489 Ok(n) => request_body.extend_from_slice(&chunk[..n]),
1490 Err(err)
1491 if err.kind() == std::io::ErrorKind::WouldBlock
1492 || err.kind() == std::io::ErrorKind::TimedOut =>
1493 {
1494 break;
1495 }
1496 Err(err) => panic!("{err}"),
1497 }
1498 }
1499
1500 tx.send(CapturedRequest {
1501 headers,
1502 body: String::from_utf8_lossy(&request_body).to_string(),
1503 })
1504 .expect("send captured request");
1505
1506 let reason = match status_code {
1507 401 => "Unauthorized",
1508 500 => "Internal Server Error",
1509 _ => "OK",
1510 };
1511 let response = format!(
1512 "HTTP/1.1 {status_code} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
1513 body.len()
1514 );
1515 socket
1516 .write_all(response.as_bytes())
1517 .expect("write response");
1518 socket.flush().expect("flush response");
1519 });
1520
1521 (format!("http://{addr}/azure"), rx)
1522 }
1523
1524 fn parse_headers(header_text: &str) -> HashMap<String, String> {
1525 let mut headers = HashMap::new();
1526 for line in header_text.lines().skip(1) {
1527 if let Some((name, value)) = line.split_once(':') {
1528 headers.insert(name.trim().to_ascii_lowercase(), value.trim().to_string());
1529 }
1530 }
1531 headers
1532 }
1533
1534 fn summarize_event(event: &StreamEvent) -> EventSummary {
1535 match event {
1536 StreamEvent::Start { .. } => EventSummary {
1537 kind: "start".to_string(),
1538 content_index: None,
1539 delta: None,
1540 content: None,
1541 reason: None,
1542 },
1543 StreamEvent::TextDelta {
1544 content_index,
1545 delta,
1546 ..
1547 } => EventSummary {
1548 kind: "text_delta".to_string(),
1549 content_index: Some(*content_index),
1550 delta: Some(delta.clone()),
1551 content: None,
1552 reason: None,
1553 },
1554 StreamEvent::Done { reason, .. } => EventSummary {
1555 kind: "done".to_string(),
1556 content_index: None,
1557 delta: None,
1558 content: None,
1559 reason: Some(reason_to_string(*reason)),
1560 },
1561 StreamEvent::Error { reason, .. } => EventSummary {
1562 kind: "error".to_string(),
1563 content_index: None,
1564 delta: None,
1565 content: None,
1566 reason: Some(reason_to_string(*reason)),
1567 },
1568 StreamEvent::TextStart { content_index, .. } => EventSummary {
1569 kind: "text_start".to_string(),
1570 content_index: Some(*content_index),
1571 delta: None,
1572 content: None,
1573 reason: None,
1574 },
1575 StreamEvent::TextEnd {
1576 content_index,
1577 content,
1578 ..
1579 } => EventSummary {
1580 kind: "text_end".to_string(),
1581 content_index: Some(*content_index),
1582 delta: None,
1583 content: Some(content.clone()),
1584 reason: None,
1585 },
1586 _ => EventSummary {
1587 kind: "other".to_string(),
1588 content_index: None,
1589 delta: None,
1590 content: None,
1591 reason: None,
1592 },
1593 }
1594 }
1595
1596 fn reason_to_string(reason: StopReason) -> String {
1597 match reason {
1598 StopReason::Stop => "stop",
1599 StopReason::Length => "length",
1600 StopReason::ToolUse => "tool_use",
1601 StopReason::Error => "error",
1602 StopReason::Aborted => "aborted",
1603 }
1604 .to_string()
1605 }
1606
1607 fn make_tool_result(content: Vec<ContentBlock>) -> Message {
1608 Message::tool_result(ToolResultMessage {
1609 tool_call_id: "call_123".to_string(),
1610 tool_name: "test_tool".to_string(),
1611 content,
1612 details: None,
1613 is_error: false,
1614 timestamp: 0,
1615 })
1616 }
1617
1618 #[test]
1619 fn tool_result_text_only_produces_single_tool_message() {
1620 let msg = make_tool_result(vec![ContentBlock::Text(TextContent {
1621 text: "result text".to_string(),
1622 text_signature: None,
1623 })]);
1624 let azure_msgs = convert_message_to_azure(&msg);
1625 assert_eq!(azure_msgs.len(), 1);
1626 assert_eq!(azure_msgs[0].role, "tool");
1627 assert_eq!(azure_msgs[0].tool_call_id.as_deref(), Some("call_123"));
1628 let json = serde_json::to_value(&azure_msgs[0]).expect("serialize");
1629 assert_eq!(json["content"], "result text");
1630 }
1631
1632 #[test]
1633 fn tool_result_image_only_produces_tool_plus_user_message() {
1634 let msg = make_tool_result(vec![ContentBlock::Image(ImageContent {
1635 data: "aW1hZ2U=".to_string(),
1636 mime_type: "image/png".to_string(),
1637 })]);
1638 let azure_msgs = convert_message_to_azure(&msg);
1639 assert_eq!(
1640 azure_msgs.len(),
1641 2,
1642 "image-only should produce tool + user messages"
1643 );
1644 assert_eq!(azure_msgs[0].role, "tool");
1645 assert_eq!(azure_msgs[1].role, "user");
1646
1647 let tool_json = serde_json::to_value(&azure_msgs[0]).expect("serialize tool");
1648 assert_eq!(tool_json["content"], "(see attached image)");
1649
1650 let user_json = serde_json::to_value(&azure_msgs[1]).expect("serialize user");
1651 let parts = user_json["content"].as_array().expect("parts array");
1652 assert_eq!(parts.len(), 2);
1653 assert_eq!(parts[0]["type"], "text");
1654 assert_eq!(parts[1]["type"], "image_url");
1655 assert!(
1656 parts[1]["image_url"]["url"]
1657 .as_str()
1658 .unwrap()
1659 .starts_with("data:image/png;base64,")
1660 );
1661 }
1662
1663 #[test]
1664 fn tool_result_mixed_text_and_image_splits_correctly() {
1665 let msg = make_tool_result(vec![
1666 ContentBlock::Text(TextContent {
1667 text: "line one".to_string(),
1668 text_signature: None,
1669 }),
1670 ContentBlock::Image(ImageContent {
1671 data: "aW1hZ2U=".to_string(),
1672 mime_type: "image/jpeg".to_string(),
1673 }),
1674 ContentBlock::Text(TextContent {
1675 text: "line two".to_string(),
1676 text_signature: None,
1677 }),
1678 ]);
1679 let azure_msgs = convert_message_to_azure(&msg);
1680 assert_eq!(
1681 azure_msgs.len(),
1682 2,
1683 "mixed content should produce tool + user messages"
1684 );
1685
1686 let tool_json = serde_json::to_value(&azure_msgs[0]).expect("serialize tool");
1687 assert_eq!(tool_json["content"], "line one\nline two");
1688 assert_eq!(tool_json["tool_call_id"], "call_123");
1689
1690 let user_json = serde_json::to_value(&azure_msgs[1]).expect("serialize user");
1691 let parts = user_json["content"].as_array().expect("parts array");
1692 assert_eq!(parts.len(), 2);
1693 assert_eq!(parts[0]["type"], "text");
1694 assert_eq!(parts[1]["type"], "image_url");
1695 }
1696
1697 #[test]
1698 fn tool_result_empty_content_produces_single_tool_message_with_no_content() {
1699 let msg = make_tool_result(vec![]);
1700 let azure_msgs = convert_message_to_azure(&msg);
1701 assert_eq!(azure_msgs.len(), 1);
1702 assert_eq!(azure_msgs[0].role, "tool");
1703 let json = serde_json::to_value(&azure_msgs[0]).expect("serialize");
1704 assert!(
1705 json["content"].is_null(),
1706 "empty tool result should have null content"
1707 );
1708 }
1709}
1710
1711#[cfg(feature = "fuzzing")]
1716pub mod fuzz {
1717 use super::*;
1718 use futures::stream;
1719 use std::pin::Pin;
1720
1721 type FuzzStream =
1722 Pin<Box<futures::stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;
1723
1724 pub struct Processor(StreamState<FuzzStream>);
1726
1727 impl Default for Processor {
1728 fn default() -> Self {
1729 Self::new()
1730 }
1731 }
1732
1733 impl Processor {
1734 pub fn new() -> Self {
1736 let empty = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
1737 Self(StreamState::new(
1738 crate::sse::SseStream::new(Box::pin(empty)),
1739 "azure-fuzz".into(),
1740 "azure-openai".into(),
1741 "azure".into(),
1742 ))
1743 }
1744
1745 pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
1747 self.0.process_event(data)?;
1748 Ok(self.0.pending_events.drain(..).collect())
1749 }
1750 }
1751}