1use crate::error::{Error, Result};
10use crate::http::client::Client;
11use crate::model::{
12 AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, Usage, UserContent,
13};
14use crate::models::CompatConfig;
15use crate::provider::{Context, Provider, StreamOptions, ToolDef};
16use crate::sse::SseStream;
17use async_trait::async_trait;
18use futures::StreamExt;
19use futures::stream::{self, Stream};
20use serde::{Deserialize, Serialize};
21use std::collections::VecDeque;
22use std::pin::Pin;
23
/// Default Azure OpenAI REST API version used when none is configured.
pub(crate) const DEFAULT_API_VERSION: &str = "2024-02-15-preview";
/// Fallback `max_tokens` applied when the caller does not specify one.
const DEFAULT_MAX_TOKENS: u32 = 4096;
30
/// Normalizes a chat role name: trims whitespace and lowercases it when the
/// lowercased form is one of the known OpenAI-style roles; otherwise returns
/// the trimmed input unchanged (preserving its original casing).
///
/// The previous implementation first matched the trimmed string exactly and
/// then re-matched its lowercased form; the first match was redundant because
/// a string that already equals a known lowercase role lowercases to itself.
fn normalize_role(role: &str) -> String {
    let trimmed = role.trim();
    let lowered = trimmed.to_ascii_lowercase();
    match lowered.as_str() {
        "system" | "developer" | "user" | "assistant" | "tool" | "function" => lowered,
        // Unknown roles pass through with original casing preserved.
        _ => trimmed.to_string(),
    }
}
45
/// Provider implementation for Azure OpenAI chat-completions deployments.
pub struct AzureOpenAIProvider {
    /// HTTP client used to issue the streaming request.
    client: Client,
    /// Azure deployment name; also reported as the model id.
    deployment: String,
    /// Azure resource name (the `<resource>.openai.azure.com` subdomain).
    resource: String,
    /// REST API version passed as the `api-version` query parameter.
    api_version: String,
    /// When set, requests go to this exact URL instead of the constructed one.
    endpoint_url_override: Option<String>,
    /// Optional compatibility tweaks (custom headers, system role name).
    compat: Option<CompatConfig>,
}
63
64impl AzureOpenAIProvider {
65 pub fn new(resource: impl Into<String>, deployment: impl Into<String>) -> Self {
71 Self {
72 client: Client::new(),
73 deployment: deployment.into(),
74 resource: resource.into(),
75 api_version: DEFAULT_API_VERSION.to_string(),
76 endpoint_url_override: None,
77 compat: None,
78 }
79 }
80
81 #[must_use]
83 pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
84 self.api_version = version.into();
85 self
86 }
87
88 #[must_use]
93 pub fn with_endpoint_url(mut self, endpoint_url: impl Into<String>) -> Self {
94 self.endpoint_url_override = Some(endpoint_url.into());
95 self
96 }
97
98 #[must_use]
100 pub fn with_client(mut self, client: Client) -> Self {
101 self.client = client;
102 self
103 }
104
105 #[must_use]
107 pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
108 self.compat = compat;
109 self
110 }
111
112 fn endpoint_url(&self) -> String {
114 if let Some(url) = &self.endpoint_url_override {
115 return url.clone();
116 }
117 format!(
118 "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
119 self.resource, self.deployment, self.api_version
120 )
121 }
122
123 #[allow(clippy::unused_self)]
125 pub fn build_request(&self, context: &Context<'_>, options: &StreamOptions) -> AzureRequest {
126 let messages = self.build_messages(context);
127
128 let tools: Option<Vec<AzureTool>> = if context.tools.is_empty() {
129 None
130 } else {
131 Some(context.tools.iter().map(convert_tool_to_azure).collect())
132 };
133
134 AzureRequest {
135 messages,
136 max_tokens: options.max_tokens.or(Some(DEFAULT_MAX_TOKENS)),
137 temperature: options.temperature,
138 tools,
139 stream: true,
140 stream_options: Some(AzureStreamOptions {
141 include_usage: true,
142 }),
143 }
144 }
145
146 fn build_messages(&self, context: &Context<'_>) -> Vec<AzureMessage> {
148 let mut messages = Vec::new();
149 let system_role = self
150 .compat
151 .as_ref()
152 .and_then(|c| c.system_role_name.as_deref())
153 .unwrap_or("system");
154
155 if let Some(system) = &context.system_prompt {
157 messages.push(AzureMessage {
158 role: normalize_role(system_role),
159 content: Some(AzureContent::Text(system.to_string())),
160 tool_calls: None,
161 tool_call_id: None,
162 });
163 }
164
165 for message in context.messages.iter() {
167 messages.extend(convert_message_to_azure(message));
168 }
169
170 messages
171 }
172}
173
174#[async_trait]
175impl Provider for AzureOpenAIProvider {
176 fn name(&self) -> &'static str {
177 "azure"
178 }
179
180 fn api(&self) -> &'static str {
181 "azure-openai"
182 }
183
184 fn model_id(&self) -> &str {
185 &self.deployment
186 }
187
188 async fn stream(
189 &self,
190 context: &Context<'_>,
191 options: &StreamOptions,
192 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
193 let auth_value = options
194 .api_key
195 .clone()
196 .or_else(|| std::env::var("AZURE_OPENAI_API_KEY").ok())
197 .ok_or_else(|| Error::provider("azure-openai", "Missing API key for Azure OpenAI. Set AZURE_OPENAI_API_KEY or configure in settings."))?;
198
199 let request_body = self.build_request(context, options);
200
201 let endpoint_url = self.endpoint_url();
202
203 let mut request = self
205 .client
206 .post(&endpoint_url)
207 .header("Accept", "text/event-stream")
208 .header("api-key", &auth_value); if let Some(compat) = &self.compat {
212 if let Some(custom_headers) = &compat.custom_headers {
213 for (key, value) in custom_headers {
214 request = request.header(key, value);
215 }
216 }
217 }
218
219 for (key, value) in &options.headers {
220 request = request.header(key, value);
221 }
222
223 let request = request.json(&request_body)?;
224
225 let response = Box::pin(request.send()).await?;
226 let status = response.status();
227 if !(200..300).contains(&status) {
228 let body = response
229 .text()
230 .await
231 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
232 return Err(Error::provider(
233 "azure-openai",
234 format!("Azure OpenAI API error (HTTP {status}): {body}"),
235 ));
236 }
237
238 let event_source = SseStream::new(response.bytes_stream());
240
241 let model = self.deployment.clone();
243 let api = self.api().to_string();
244 let provider = self.name().to_string();
245
246 let stream = stream::unfold(
247 StreamState::new(event_source, model, api, provider),
248 |mut state| async move {
249 if state.done {
250 return None;
251 }
252 loop {
253 if let Some(event) = state.pending_events.pop_front() {
254 return Some((Ok(event), state));
255 }
256
257 match state.event_source.next().await {
258 Some(Ok(msg)) => {
259 if msg.data == "[DONE]" {
261 state.done = true;
262 let reason = state.partial.stop_reason;
263 let message = std::mem::take(&mut state.partial);
264 return Some((Ok(StreamEvent::Done { reason, message }), state));
265 }
266
267 if let Err(e) = state.process_event(&msg.data) {
268 state.done = true;
269 return Some((Err(e), state));
270 }
271 }
272 Some(Err(e)) => {
273 state.done = true;
274 let err = Error::api(format!("SSE error: {e}"));
275 return Some((Err(err), state));
276 }
277 None => {
282 state.done = true;
283 let reason = state.partial.stop_reason;
284 let message = std::mem::take(&mut state.partial);
285 return Some((Ok(StreamEvent::Done { reason, message }), state));
286 }
287 }
288 }
289 },
290 );
291
292 Ok(Box::pin(stream))
293 }
294}
295
/// Mutable state threaded through the SSE unfold loop while a response streams.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Underlying SSE message source.
    event_source: SseStream<S>,
    /// Assistant message accumulated from deltas so far; moved out on Done.
    partial: AssistantMessage,
    /// In-progress tool calls, matched by the provider's tool_call index.
    tool_calls: Vec<ToolCallState>,
    /// Events produced by `process_event`, drained one per stream poll.
    pending_events: VecDeque<StreamEvent>,
    /// Whether the one-time Start event has been queued.
    started: bool,
    /// Whether the stream has terminated (Done or error emitted).
    done: bool,
}
311
/// Accumulator for one streamed tool call.
struct ToolCallState {
    /// Tool-call index as reported by the API chunk (may be sparse).
    index: usize,
    /// Position of the matching ToolCall block in `partial.content`.
    content_index: usize,
    /// Tool call id, filled in as fragments arrive.
    id: String,
    /// Function name, filled in as fragments arrive.
    name: String,
    /// Raw JSON argument fragments, concatenated; parsed at finish.
    arguments: String,
}
319
impl<S> StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Creates fresh state with an empty partial assistant message stamped
    /// with the current time.
    fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
        Self {
            event_source,
            partial: AssistantMessage {
                content: Vec::new(),
                api,
                provider,
                model,
                usage: Usage::default(),
                stop_reason: StopReason::Stop,
                error_message: None,
                timestamp: chrono::Utc::now().timestamp_millis(),
            },
            tool_calls: Vec::new(),
            pending_events: VecDeque::new(),
            started: false,
            done: false,
        }
    }

    /// Parses each tool call's accumulated argument fragments as JSON and
    /// stores the result on the matching ToolCall content block.
    /// Unparseable arguments are logged and replaced with JSON null rather
    /// than failing the stream.
    fn finalize_tool_call_arguments(&mut self) {
        for tc in &self.tool_calls {
            let arguments: serde_json::Value = match serde_json::from_str(&tc.arguments) {
                Ok(args) => args,
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        raw = %tc.arguments,
                        "Failed to parse tool arguments as JSON"
                    );
                    serde_json::Value::Null
                }
            };

            if let Some(ContentBlock::ToolCall(block)) =
                self.partial.content.get_mut(tc.content_index)
            {
                block.arguments = arguments;
            }
        }
    }

    /// Appends `text` to the trailing text block — opening a new block (and
    /// queueing a TextStart event) if the last block is not text — and
    /// returns the TextDelta event for the caller to queue.
    fn push_text_delta(&mut self, text: String) -> StreamEvent {
        let last_is_text = matches!(self.partial.content.last(), Some(ContentBlock::Text(_)));
        if !last_is_text {
            let content_index = self.partial.content.len();
            self.partial
                .content
                .push(ContentBlock::Text(crate::model::TextContent::new("")));
            self.pending_events
                .push_back(StreamEvent::TextStart { content_index });
        }
        // The trailing block is now guaranteed to be text.
        let content_index = self.partial.content.len() - 1;

        if let Some(ContentBlock::Text(t)) = self.partial.content.get_mut(content_index) {
            t.text.push_str(&text);
        }

        StreamEvent::TextDelta {
            content_index,
            delta: text,
        }
    }

    /// Queues the one-time Start event carrying a snapshot of the partial
    /// message; subsequent calls are no-ops.
    fn ensure_started(&mut self) {
        if !self.started {
            self.started = true;
            self.pending_events.push_back(StreamEvent::Start {
                partial: self.partial.clone(),
            });
        }
    }

    /// Processes one SSE `data:` payload: updates usage, text, tool-call, and
    /// finish-reason state, queueing stream events as they become available.
    ///
    /// # Errors
    /// Returns an error when the payload is not valid JSON.
    #[allow(clippy::unnecessary_wraps, clippy::too_many_lines)]
    fn process_event(&mut self, data: &str) -> Result<()> {
        let chunk: AzureStreamChunk =
            serde_json::from_str(data).map_err(|e| Error::api(format!("JSON parse error: {e}")))?;

        // Usage may arrive on any chunk (typically the final one when
        // include_usage was requested).
        if let Some(usage) = chunk.usage {
            self.partial.usage.input = usage.prompt_tokens;
            self.partial.usage.output = usage.completion_tokens.unwrap_or(0);
            self.partial.usage.total_tokens = usage.total_tokens;
        }

        let choices = chunk.choices;
        if !self.started {
            // An initial chunk whose first choice carries no content, no tool
            // calls, and no finish reason (e.g. role-only) just triggers Start.
            let first = choices.first();
            let delta_is_empty = first.is_some_and(|choice| {
                choice.finish_reason.is_none()
                    && choice.delta.content.is_none()
                    && choice.delta.tool_calls.is_none()
            });
            if delta_is_empty {
                self.ensure_started();
                return Ok(());
            }
        }

        for choice in choices {
            if let Some(text) = choice.delta.content {
                self.ensure_started();
                let event = self.push_text_delta(text);
                self.pending_events.push_back(event);
            }

            if let Some(tool_calls) = choice.delta.tool_calls {
                self.ensure_started();

                for tc in tool_calls {
                    let idx = tc.index as usize;

                    // Look up the accumulator by API index (indices may be
                    // sparse); on first sight, create it together with its
                    // placeholder ToolCall content block and a Start event.
                    let tool_state_idx = if let Some(existing_idx) =
                        self.tool_calls.iter().position(|tc| tc.index == idx)
                    {
                        existing_idx
                    } else {
                        let content_index = self.partial.content.len();
                        self.tool_calls.push(ToolCallState {
                            index: idx,
                            content_index,
                            id: String::new(),
                            name: String::new(),
                            arguments: String::new(),
                        });

                        self.partial
                            .content
                            .push(ContentBlock::ToolCall(crate::model::ToolCall {
                                id: String::new(),
                                name: String::new(),
                                arguments: serde_json::Value::Null,
                                thought_signature: None,
                            }));

                        self.pending_events
                            .push_back(StreamEvent::ToolCallStart { content_index });
                        self.tool_calls.len() - 1
                    };

                    let tc_state = &mut self.tool_calls[tool_state_idx];
                    let content_index = tc_state.content_index;

                    // Mirror id/name fragments into both the accumulator and
                    // the content block so partial snapshots stay coherent.
                    if let Some(id) = tc.id {
                        tc_state.id.clone_from(&id);
                        if let Some(ContentBlock::ToolCall(block)) =
                            self.partial.content.get_mut(content_index)
                        {
                            block.id = id;
                        }
                    }
                    if let Some(func) = tc.function {
                        if let Some(name) = func.name {
                            tc_state.name.clone_from(&name);
                            if let Some(ContentBlock::ToolCall(block)) =
                                self.partial.content.get_mut(content_index)
                            {
                                block.name = name;
                            }
                        }
                        if let Some(args) = func.arguments {
                            // Arguments stream as raw string fragments; they
                            // are parsed only at finalize time.
                            tc_state.arguments.push_str(&args);
                            self.pending_events.push_back(StreamEvent::ToolCallDelta {
                                content_index,
                                delta: args,
                            });
                        }
                    }
                }
            }

            // A finish reason on an otherwise-empty stream still needs Start
            // to precede the End events queued below.
            if choice.finish_reason.is_some() {
                self.ensure_started();
            }
            if let Some(reason) = choice.finish_reason {
                self.partial.stop_reason = match reason.as_str() {
                    "length" => StopReason::Length,
                    "content_filter" => StopReason::Error,
                    "tool_calls" => StopReason::ToolUse,
                    _ => StopReason::Stop,
                };

                self.finalize_tool_call_arguments();

                // Close out every text/thinking block...
                for (content_index, block) in self.partial.content.iter().enumerate() {
                    if let ContentBlock::Text(t) = block {
                        self.pending_events.push_back(StreamEvent::TextEnd {
                            content_index,
                            content: t.text.clone(),
                        });
                    } else if let ContentBlock::Thinking(t) = block {
                        self.pending_events.push_back(StreamEvent::ThinkingEnd {
                            content_index,
                            content: t.thinking.clone(),
                        });
                    }
                }

                // ...and every tool call, now that arguments are finalized.
                for tc in &self.tool_calls {
                    if let Some(ContentBlock::ToolCall(tool_call)) =
                        self.partial.content.get(tc.content_index)
                    {
                        self.pending_events.push_back(StreamEvent::ToolCallEnd {
                            content_index: tc.content_index,
                            tool_call: tool_call.clone(),
                        });
                    }
                }
            }
        }

        Ok(())
    }
}
558
/// JSON request body for the Azure chat-completions endpoint (streaming).
#[derive(Debug, Serialize)]
pub struct AzureRequest {
    messages: Vec<AzureMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<AzureTool>>,
    // Always true here: this provider only issues streaming requests.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<AzureStreamOptions>,
}

/// `stream_options` payload; used to request token usage in the final chunk.
#[derive(Debug, Serialize)]
struct AzureStreamOptions {
    include_usage: bool,
}

/// One chat message on the wire (system/user/assistant/tool role).
#[derive(Debug, Serialize)]
struct AzureMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<AzureContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<AzureToolCallRef>>,
    // Set only on "tool" messages, linking the result to its originating call.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

/// Message content: either a bare string or a multimodal parts array.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum AzureContent {
    Text(String),
    Parts(Vec<AzureContentPart>),
}

/// A single multimodal content part (text or image URL).
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
enum AzureContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: AzureImageUrl },
}

/// Image reference; `url` carries a base64 data URL in this codebase.
#[derive(Debug, Serialize)]
struct AzureImageUrl {
    url: String,
}

/// A prior assistant tool call echoed back in conversation history.
#[derive(Debug, Serialize)]
struct AzureToolCallRef {
    id: String,
    r#type: &'static str,
    function: AzureFunctionRef,
}

/// Function name plus raw JSON-encoded argument string for history replay.
#[derive(Debug, Serialize)]
struct AzureFunctionRef {
    name: String,
    arguments: String,
}

/// Tool definition envelope (`type` is always "function").
#[derive(Debug, Serialize)]
struct AzureTool {
    r#type: &'static str,
    function: AzureFunction,
}

/// Function tool schema advertised to the model.
#[derive(Debug, Serialize)]
struct AzureFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
639
/// One parsed SSE chunk of a streaming chat-completions response.
#[derive(Debug, Deserialize)]
struct AzureStreamChunk {
    #[serde(default)]
    choices: Vec<AzureChoice>,
    // Present on the final chunk when include_usage was requested.
    #[serde(default)]
    usage: Option<AzureUsage>,
}

/// A streamed choice: an incremental delta plus an optional finish reason.
#[derive(Debug, Deserialize)]
struct AzureChoice {
    delta: AzureDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental message delta (text and/or tool-call fragments).
#[derive(Debug, Deserialize)]
struct AzureDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<AzureToolCallDelta>>,
}

/// Fragment of a streamed tool call, correlated across chunks by `index`.
#[derive(Debug, Deserialize)]
struct AzureToolCallDelta {
    index: u32,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<AzureFunctionDelta>,
}

/// Fragment of a tool call's function name / argument string.
#[derive(Debug, Deserialize)]
struct AzureFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

/// Token usage block from the final chunk.
#[derive(Debug, Deserialize)]
#[allow(clippy::struct_field_names)]
struct AzureUsage {
    prompt_tokens: u64,
    #[serde(default)]
    completion_tokens: Option<u64>,
    // NOTE(review): this allow looks stale — total_tokens is copied into
    // partial.usage in StreamState::process_event; confirm and remove it.
    #[allow(dead_code)]
    total_tokens: u64,
}
693
694fn convert_message_to_azure(message: &Message) -> Vec<AzureMessage> {
699 match message {
700 Message::User(user) => vec![AzureMessage {
701 role: "user".to_string(),
702 content: Some(convert_user_content(&user.content)),
703 tool_calls: None,
704 tool_call_id: None,
705 }],
706 Message::Custom(custom) => vec![AzureMessage {
707 role: "user".to_string(),
708 content: Some(AzureContent::Text(custom.content.clone())),
709 tool_calls: None,
710 tool_call_id: None,
711 }],
712 Message::Assistant(assistant) => {
713 let mut messages = Vec::new();
714
715 let text: String = assistant
717 .content
718 .iter()
719 .filter_map(|b| match b {
720 ContentBlock::Text(t) => Some(t.text.as_str()),
721 _ => None,
722 })
723 .collect::<String>();
724
725 let tool_calls: Vec<AzureToolCallRef> = assistant
727 .content
728 .iter()
729 .filter_map(|b| match b {
730 ContentBlock::ToolCall(tc) => Some(AzureToolCallRef {
731 id: tc.id.clone(),
732 r#type: "function",
733 function: AzureFunctionRef {
734 name: tc.name.clone(),
735 arguments: tc.arguments.to_string(),
736 },
737 }),
738 _ => None,
739 })
740 .collect();
741
742 let content = if text.is_empty() {
743 None
744 } else {
745 Some(AzureContent::Text(text))
746 };
747
748 let tool_calls = if tool_calls.is_empty() {
749 None
750 } else {
751 Some(tool_calls)
752 };
753
754 messages.push(AzureMessage {
755 role: "assistant".to_string(),
756 content,
757 tool_calls,
758 tool_call_id: None,
759 });
760
761 messages
762 }
763 Message::ToolResult(result) => {
764 let parts: Vec<AzureContentPart> = result
765 .content
766 .iter()
767 .filter_map(|block| match block {
768 ContentBlock::Text(t) => Some(AzureContentPart::Text {
769 text: t.text.clone(),
770 }),
771 ContentBlock::Image(img) => {
772 let url = format!("data:{};base64,{}", img.mime_type, img.data);
773 Some(AzureContentPart::ImageUrl {
774 image_url: AzureImageUrl { url },
775 })
776 }
777 _ => None,
778 })
779 .collect();
780
781 let content = if parts.is_empty() {
782 None
783 } else if parts.len() == 1 && matches!(parts[0], AzureContentPart::Text { .. }) {
784 if let AzureContentPart::Text { text } = &parts[0] {
785 Some(AzureContent::Text(text.clone()))
786 } else {
787 Some(AzureContent::Parts(parts))
788 }
789 } else {
790 Some(AzureContent::Parts(parts))
791 };
792
793 vec![AzureMessage {
794 role: "tool".to_string(),
795 content,
796 tool_calls: None,
797 tool_call_id: Some(result.tool_call_id.clone()),
798 }]
799 }
800 }
801}
802
803fn convert_user_content(content: &UserContent) -> AzureContent {
804 match content {
805 UserContent::Text(text) => AzureContent::Text(text.clone()),
806 UserContent::Blocks(blocks) => {
807 let parts: Vec<AzureContentPart> = blocks
808 .iter()
809 .filter_map(|block| match block {
810 ContentBlock::Text(t) => Some(AzureContentPart::Text {
811 text: t.text.clone(),
812 }),
813 ContentBlock::Image(img) => {
814 let url = format!("data:{};base64,{}", img.mime_type, img.data);
815 Some(AzureContentPart::ImageUrl {
816 image_url: AzureImageUrl { url },
817 })
818 }
819 _ => None,
820 })
821 .collect();
822 AzureContent::Parts(parts)
823 }
824 }
825}
826
827fn convert_tool_to_azure(tool: &ToolDef) -> AzureTool {
828 AzureTool {
829 r#type: "function",
830 function: AzureFunction {
831 name: tool.name.clone(),
832 description: tool.description.clone(),
833 parameters: tool.parameters.clone(),
834 },
835 }
836}
837
#[cfg(test)]
mod tests {
    //! Unit tests: provider configuration, request building, message
    //! conversion, and SSE stream processing (fixture-driven and ad hoc).
    use super::*;
    use crate::model::{TextContent, ToolCall, UserMessage};
    use crate::provider::ToolDef;
    use asupersync::runtime::RuntimeBuilder;
    use futures::{StreamExt, stream};
    use serde::{Deserialize, Serialize};
    use serde_json::{Value, json};
    use std::path::PathBuf;

    // Provider identity strings are fixed regardless of resource/deployment.
    #[test]
    fn test_azure_provider_creation() {
        let provider = AzureOpenAIProvider::new("my-resource", "gpt-4");
        assert_eq!(provider.name(), "azure");
        assert_eq!(provider.api(), "azure-openai");
    }

    // model_id reports the deployment name, not an OpenAI model name.
    #[test]
    fn test_azure_model_id_uses_deployment() {
        let provider = AzureOpenAIProvider::new("my-resource", "gpt-4o-mini");
        assert_eq!(provider.model_id(), "gpt-4o-mini");
    }

    // Default URL embeds resource subdomain, deployment, and api-version.
    #[test]
    fn test_azure_endpoint_url() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4-turbo");
        let url = provider.endpoint_url();
        assert!(url.contains("contoso.openai.azure.com"));
        assert!(url.contains("gpt-4-turbo"));
        assert!(url.contains("api-version="));
    }

    #[test]
    fn test_azure_endpoint_url_custom_version() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4").with_api_version("2024-06-01");
        let url = provider.endpoint_url();
        assert!(url.contains("api-version=2024-06-01"));
    }

    // Pins the exact default URL shape (and DEFAULT_API_VERSION value).
    #[test]
    fn test_azure_endpoint_url_exact_default_shape() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4o");
        let url = provider.endpoint_url();
        assert_eq!(
            url,
            "https://contoso.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-02-15-preview"
        );
    }

    // An explicit endpoint override wins even over a custom api version.
    #[test]
    fn test_azure_endpoint_url_override_takes_precedence() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4o")
            .with_api_version("2025-01-01")
            .with_endpoint_url("http://127.0.0.1:1234/mock-endpoint");
        let url = provider.endpoint_url();
        assert_eq!(url, "http://127.0.0.1:1234/mock-endpoint");
    }

    // End-to-end request building: system prompt first, then history, plus
    // serialized tool definitions and caller-provided sampling options.
    #[test]
    fn test_azure_build_request_includes_system_messages_and_tools() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4o");
        let context = Context {
            system_prompt: Some("You are deterministic.".to_string().into()),
            messages: vec![
                Message::User(UserMessage {
                    content: UserContent::Text("Hello".to_string()),
                    timestamp: 0,
                }),
                Message::assistant(AssistantMessage {
                    content: vec![
                        ContentBlock::Text(TextContent::new("Need tool output")),
                        ContentBlock::ToolCall(ToolCall {
                            id: "tool_1".to_string(),
                            name: "echo".to_string(),
                            arguments: json!({"text":"ping"}),
                            thought_signature: None,
                        }),
                    ],
                    api: "azure-openai".to_string(),
                    provider: "azure".to_string(),
                    model: "gpt-4o".to_string(),
                    usage: Usage::default(),
                    stop_reason: StopReason::ToolUse,
                    error_message: None,
                    timestamp: 0,
                }),
            ]
            .into(),
            tools: vec![ToolDef {
                name: "echo".to_string(),
                description: "Echo text".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "text": {"type":"string"}
                    },
                    "required": ["text"]
                }),
            }]
            .into(),
        };
        let options = StreamOptions {
            max_tokens: Some(512),
            temperature: Some(0.0),
            ..Default::default()
        };

        let request = provider.build_request(&context, &options);
        let request_json = serde_json::to_value(&request).expect("serialize request");
        assert_eq!(request_json["max_tokens"], json!(512));
        assert_eq!(request_json["temperature"], json!(0.0));
        assert_eq!(request_json["stream"], json!(true));
        assert_eq!(request_json["messages"][0]["role"], json!("system"));
        assert_eq!(
            request_json["messages"][0]["content"],
            json!("You are deterministic.")
        );
        assert_eq!(request_json["messages"][1]["role"], json!("user"));
        assert_eq!(request_json["messages"][2]["role"], json!("assistant"));
        assert_eq!(request_json["tools"][0]["type"], json!("function"));
        assert_eq!(request_json["tools"][0]["function"]["name"], json!("echo"));
    }

    // Without explicit options, max_tokens falls back to DEFAULT_MAX_TOKENS
    // and the empty tool list is omitted entirely.
    #[test]
    fn test_azure_build_request_defaults_max_tokens() {
        let provider = AzureOpenAIProvider::new("contoso", "gpt-4o");
        let context = Context {
            system_prompt: None,
            messages: vec![Message::User(UserMessage {
                content: UserContent::Text("Hello".to_string()),
                timestamp: 0,
            })]
            .into(),
            tools: Vec::new().into(),
        };
        let options = StreamOptions::default();

        let request = provider.build_request(&context, &options);
        let request_json = serde_json::to_value(&request).expect("serialize request");
        assert_eq!(request_json["max_tokens"], json!(DEFAULT_MAX_TOKENS));
        assert_eq!(request_json["stream"], json!(true));
        assert!(request_json.get("tools").is_none());
    }

    // A compat system_role_name that is a known role (modulo case/whitespace)
    // is normalized to its canonical lowercase form.
    #[test]
    fn test_azure_build_request_normalizes_known_system_role_name() {
        let provider =
            AzureOpenAIProvider::new("contoso", "gpt-4o").with_compat(Some(CompatConfig {
                system_role_name: Some("SYSTEM ".to_string()),
                ..CompatConfig::default()
            }));
        let context = Context {
            system_prompt: Some("You are deterministic.".to_string().into()),
            messages: Vec::new().into(),
            tools: Vec::new().into(),
        };

        let request = provider.build_request(&context, &StreamOptions::default());
        let request_json = serde_json::to_value(&request).expect("serialize request");
        assert_eq!(request_json["messages"][0]["role"], json!("system"));
    }

    // An unknown compat role name passes through unchanged.
    #[test]
    fn test_azure_build_request_preserves_unknown_system_role_name() {
        let provider =
            AzureOpenAIProvider::new("contoso", "gpt-4o").with_compat(Some(CompatConfig {
                system_role_name: Some("custom_role".to_string()),
                ..CompatConfig::default()
            }));
        let context = Context {
            system_prompt: Some("You are deterministic.".to_string().into()),
            messages: Vec::new().into(),
            tools: Vec::new().into(),
        };

        let request = provider.build_request(&context, &StreamOptions::default());
        let request_json = serde_json::to_value(&request).expect("serialize request");
        assert_eq!(request_json["messages"][0]["role"], json!("custom_role"));
    }

    #[test]
    fn test_azure_message_conversion() {
        let message = Message::User(UserMessage {
            content: UserContent::Text("Hello".to_string()),
            timestamp: chrono::Utc::now().timestamp_millis(),
        });

        let azure_messages = convert_message_to_azure(&message);
        assert_eq!(azure_messages.len(), 1);
        assert_eq!(azure_messages[0].role, "user");
    }

    // Shape of the JSON fixture file consumed by test_stream_fixtures.
    #[derive(Debug, Deserialize)]
    struct ProviderFixture {
        cases: Vec<ProviderCase>,
    }

    #[derive(Debug, Deserialize)]
    struct ProviderCase {
        name: String,
        // Raw SSE payloads (JSON objects, or the literal "[DONE]" string).
        events: Vec<Value>,
        expected: Vec<EventSummary>,
    }

    // Flattened, comparable view of a StreamEvent for fixture assertions.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct EventSummary {
        kind: String,
        #[serde(default)]
        content_index: Option<usize>,
        #[serde(default)]
        delta: Option<String>,
        #[serde(default)]
        content: Option<String>,
        #[serde(default)]
        reason: Option<String>,
    }

    // Replays every fixture case through the real SSE/state machinery and
    // compares summarized events against the expected sequence.
    #[test]
    fn test_stream_fixtures() {
        let fixture = load_fixture("azure_stream.json");
        for case in fixture.cases {
            let events = collect_events(&case.events);
            let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
            assert_eq!(summaries, case.expected, "case {}", case.name);
        }
    }

    // Regression: tool_call indices need not start at 0 or be contiguous;
    // a sparse index must not panic and must still produce a finished call.
    #[test]
    fn test_stream_handles_sparse_tool_call_index_without_panic() {
        let events = vec![
            json!({ "choices": [{ "delta": {} }] }),
            json!({
                "choices": [{
                    "delta": {
                        "tool_calls": [{
                            "index": 3,
                            "id": "call_sparse",
                            "function": {
                                "name": "lookup",
                                "arguments": "{\"q\":\"azure\"}"
                            }
                        }]
                    }
                }]
            }),
            json!({ "choices": [{ "delta": {}, "finish_reason": "tool_calls" }] }),
            Value::String("[DONE]".to_string()),
        ];

        let out = collect_events(&events);
        let done = out
            .iter()
            .find_map(|event| match event {
                StreamEvent::Done { message, .. } => Some(message),
                _ => None,
            })
            .expect("done event");
        let tool_calls: Vec<&ToolCall> = done
            .content
            .iter()
            .filter_map(|block| match block {
                ContentBlock::ToolCall(tc) => Some(tc),
                _ => None,
            })
            .collect();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_sparse");
        assert_eq!(tool_calls[0].name, "lookup");
        assert_eq!(tool_calls[0].arguments, json!({ "q": "azure" }));
        assert!(
            out.iter()
                .any(|event| matches!(event, StreamEvent::ToolCallStart { .. })),
            "expected tool call start event"
        );
    }

    // Loads a fixture JSON file from tests/fixtures/provider_responses.
    fn load_fixture(file_name: &str) -> ProviderFixture {
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/provider_responses")
            .join(file_name);
        let raw = std::fs::read_to_string(path).expect("fixture read");
        serde_json::from_str(&raw).expect("fixture parse")
    }

    // Encodes the given payloads as SSE frames, drives StreamState over them
    // on a current-thread runtime, and collects every emitted StreamEvent.
    fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        runtime.block_on(async move {
            let byte_stream = stream::iter(
                events
                    .iter()
                    .map(|event| {
                        let data = match event {
                            // Raw strings (e.g. "[DONE]") pass through as-is.
                            Value::String(text) => text.clone(),
                            _ => serde_json::to_string(event).expect("serialize event"),
                        };
                        format!("data: {data}\n\n").into_bytes()
                    })
                    .map(Ok),
            );
            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
            let mut state = StreamState::new(
                event_source,
                "gpt-test".to_string(),
                "azure-openai".to_string(),
                "azure".to_string(),
            );
            let mut out = Vec::new();

            // Mirrors the Done handling of Provider::stream's unfold loop.
            while let Some(item) = state.event_source.next().await {
                let msg = item.expect("SSE event");
                if msg.data == "[DONE]" {
                    out.extend(state.pending_events.drain(..));
                    let reason = state.partial.stop_reason;
                    out.push(StreamEvent::Done {
                        reason,
                        message: std::mem::take(&mut state.partial),
                    });
                    break;
                }
                state.process_event(&msg.data).expect("process_event");
                out.extend(state.pending_events.drain(..));
            }

            out
        })
    }

    // Projects a StreamEvent onto the comparable EventSummary shape; event
    // kinds not asserted by fixtures collapse to "other".
    fn summarize_event(event: &StreamEvent) -> EventSummary {
        match event {
            StreamEvent::Start { .. } => EventSummary {
                kind: "start".to_string(),
                content_index: None,
                delta: None,
                content: None,
                reason: None,
            },
            StreamEvent::TextDelta {
                content_index,
                delta,
                ..
            } => EventSummary {
                kind: "text_delta".to_string(),
                content_index: Some(*content_index),
                delta: Some(delta.clone()),
                content: None,
                reason: None,
            },
            StreamEvent::Done { reason, .. } => EventSummary {
                kind: "done".to_string(),
                content_index: None,
                delta: None,
                content: None,
                reason: Some(reason_to_string(*reason)),
            },
            StreamEvent::Error { reason, .. } => EventSummary {
                kind: "error".to_string(),
                content_index: None,
                delta: None,
                content: None,
                reason: Some(reason_to_string(*reason)),
            },
            StreamEvent::TextStart { content_index, .. } => EventSummary {
                kind: "text_start".to_string(),
                content_index: Some(*content_index),
                delta: None,
                content: None,
                reason: None,
            },
            StreamEvent::TextEnd {
                content_index,
                content,
                ..
            } => EventSummary {
                kind: "text_end".to_string(),
                content_index: Some(*content_index),
                delta: None,
                content: Some(content.clone()),
                reason: None,
            },
            _ => EventSummary {
                kind: "other".to_string(),
                content_index: None,
                delta: None,
                content: None,
                reason: None,
            },
        }
    }

    // Fixture-file spelling of each StopReason.
    fn reason_to_string(reason: StopReason) -> String {
        match reason {
            StopReason::Stop => "stop",
            StopReason::Length => "length",
            StopReason::ToolUse => "tool_use",
            StopReason::Error => "error",
            StopReason::Aborted => "aborted",
        }
        .to_string()
    }
}
1245
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    //! Fuzzing entry point: exposes `StreamState::process_event` over raw
    //! SSE payload strings without a live HTTP stream.
    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// Byte-stream type backing the (never-polled) SSE source.
    type FuzzStream =
        Pin<Box<futures::stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;

    /// Thin wrapper around [`StreamState`] for fuzz harnesses.
    pub struct Processor(StreamState<FuzzStream>);

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Processor {
        /// Builds a processor over an empty byte stream; only
        /// [`Self::process_event`] is driven by the fuzzer.
        pub fn new() -> Self {
            let empty = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
            Self(StreamState::new(
                crate::sse::SseStream::new(Box::pin(empty)),
                "azure-fuzz".into(),
                "azure-openai".into(),
                "azure".into(),
            ))
        }

        /// Feeds one raw SSE `data:` payload and returns the events it produced.
        ///
        /// # Errors
        /// Propagates JSON parse failures from `process_event`.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            self.0.process_event(data)?;
            Ok(self.0.pending_events.drain(..).collect())
        }
    }
}