1use async_trait::async_trait;
4use bytes::Bytes;
5use futures::{Stream, StreamExt};
6use reqwest::Client;
7use serde::Deserialize;
8use serde_json::Value as JsonValue;
9use std::pin::Pin;
10use std::sync::Arc;
11
12use crate::{
13 error::ProviderError, Api, AssistantMessage, ContentBlock, Context, Model, Provider,
14 ProviderEvent, StopReason, StreamOptions, Usage,
15};
16
17use super::shared_client;
18
/// Provider that sends streaming chat-completion requests to an Azure OpenAI
/// deployment.
#[derive(Clone)]
pub struct AzureProvider {
    /// HTTP client obtained from `shared_client()`.
    client: &'static Client,
    /// Fallback API key, used when the per-request options do not carry one.
    api_key: Option<String>,
    /// Resource name interpolated into the `{resource}.openai.azure.com` host.
    resource_name: Option<String>,
    /// Deployment name interpolated into the `/openai/deployments/{deployment}` path.
    deployment_name: Option<String>,
}
40
41impl AzureProvider {
42 pub fn new() -> Self {
46 Self {
47 client: shared_client(),
48 api_key: None,
49 resource_name: None,
50 deployment_name: None,
51 }
52 }
53
54 #[cfg(test)]
56 pub fn with_config(
57 api_key: impl Into<String>,
58 resource_name: impl Into<String>,
59 deployment_name: impl Into<String>,
60 ) -> Self {
61 Self {
62 client: shared_client(),
63 api_key: Some(api_key.into()),
64 resource_name: Some(resource_name.into()),
65 deployment_name: Some(deployment_name.into()),
66 }
67 }
68
69 fn build_url(&self, model: &Model) -> Result<String, ProviderError> {
71 if !model.base_url.is_empty() && model.base_url != "https://api.openai.com" {
73 return Ok(format!(
75 "{}/chat/completions?api-version=2024-02-15-preview",
76 model.base_url.trim_end_matches('/')
77 ));
78 }
79
80 let resource = self.resource_name.as_ref().ok_or_else(|| {
82 ProviderError::InvalidResponse("AZURE_OPENAI_RESOURCE_NAME not set".into())
83 })?;
84
85 let deployment = self.deployment_name.as_ref().ok_or_else(|| {
86 ProviderError::InvalidResponse("AZURE_OPENAI_DEPLOYMENT_NAME not set".into())
87 })?;
88
89 let url = format!(
90 "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version=2024-02-15-preview",
91 resource, deployment
92 );
93
94 Ok(url)
95 }
96
97 fn get_api_key(&self, options: &Option<StreamOptions>) -> Result<String, ProviderError> {
99 options
100 .as_ref()
101 .and_then(|o| o.api_key.as_ref())
102 .or(self.api_key.as_ref())
103 .cloned()
104 .ok_or_else(|| ProviderError::MissingApiKey)
105 }
106
107 fn build_headers(
109 &self,
110 api_key: &str,
111 options: &Option<StreamOptions>,
112 ) -> reqwest::header::HeaderMap {
113 let mut headers = reqwest::header::HeaderMap::new();
114
115 headers.insert("api-key", api_key.parse().expect("valid header value"));
117 headers.insert(
118 reqwest::header::CONTENT_TYPE,
119 "application/json".parse().expect("valid header value"),
120 );
121
122 if let Some(opts) = options {
124 for (k, v) in &opts.headers {
125 if let (Ok(name), Ok(value)) = (
126 k.parse::<reqwest::header::HeaderName>(),
127 v.parse::<reqwest::header::HeaderValue>(),
128 ) {
129 headers.insert(name, value);
130 }
131 }
132 }
133
134 headers
135 }
136}
137
138impl Default for AzureProvider {
139 fn default() -> Self {
140 Self::new()
141 }
142}
143
144#[async_trait]
145impl Provider for AzureProvider {
146 async fn stream(
147 &self,
148 model: &Model,
149 context: &Context,
150 options: Option<StreamOptions>,
151 ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
152 let url = self.build_url(model)?;
154
155 let api_key = self.get_api_key(&options)?;
157
158 let messages = build_messages(context)?;
160
161 let mut body = serde_json::json!({
163 "messages": messages,
164 "stream": true,
165 });
166
167 if model.id != "default" && model.id != "azure" {
169 body["model"] = serde_json::json!(model.id);
170 }
171
172 if let Some(ref opts) = options {
174 if let Some(temp) = opts.temperature {
175 body["temperature"] = serde_json::json!(temp);
176 }
177
178 if let Some(max) = opts.max_tokens {
179 body["max_tokens"] = serde_json::json!(max);
180 }
181 }
182
183 if !context.tools.is_empty() {
185 body["tools"] = build_tools(&context.tools)?;
186 }
187
188 let headers = self.build_headers(&api_key, &options);
190
191 let response = self
193 .client
194 .post(&url)
195 .headers(headers)
196 .json(&body)
197 .send()
198 .await
199 .map_err(ProviderError::RequestFailed)?;
200
201 if !response.status().is_success() {
202 let status = response.status();
203 let body: String = response.text().await.unwrap_or_default();
204 return Err(ProviderError::HttpError(status.as_u16(), body));
205 }
206
207 let provider_name = model.provider.clone();
209 let model_id = model.id.clone();
210
211 let stream = response.bytes_stream().flat_map(
212 move |chunk: Result<Bytes, reqwest::Error>| match chunk {
213 Ok(bytes) => {
214 let text = String::from_utf8_lossy(&bytes).to_string();
215 futures::stream::iter(parse_sse_events(&text, &provider_name, &model_id))
216 }
217 Err(e) => futures::stream::iter(vec![ProviderEvent::Error {
218 reason: StopReason::Error,
219 error: create_error_message(&e.to_string(), &provider_name, &model_id),
220 }]),
221 },
222 );
223
224 Ok(Box::pin(stream))
225 }
226
227 fn name(&self) -> &str {
228 "azure"
229 }
230}
231
232fn build_messages(context: &Context) -> Result<Vec<JsonValue>, ProviderError> {
234 let mut messages = Vec::new();
235
236 if let Some(ref prompt) = context.system_prompt {
238 messages.push(serde_json::json!({
239 "role": "system",
240 "content": prompt,
241 }));
242 }
243
244 for msg in &context.messages {
246 match msg {
247 crate::Message::User(u) => {
248 let content: String = match &u.content {
249 crate::MessageContent::Text(s) => s.clone(),
250 crate::MessageContent::Blocks(blocks) => blocks_to_content(blocks)?.to_string(),
251 };
252 messages.push(serde_json::json!({
253 "role": "user",
254 "content": content,
255 }));
256 }
257 crate::Message::Assistant(a) => {
258 let content = blocks_to_content(&a.content)?.to_string();
259 messages.push(serde_json::json!({
260 "role": "assistant",
261 "content": content,
262 }));
263 }
264 crate::Message::ToolResult(t) => {
265 let content = blocks_to_content(&t.content)?.to_string();
266 messages.push(serde_json::json!({
267 "role": "tool",
268 "tool_call_id": t.tool_call_id,
269 "tool_name": t.tool_name,
270 "content": content,
271 }));
272 }
273 }
274 }
275
276 Ok(messages)
277}
278
279fn blocks_to_content(blocks: &[ContentBlock]) -> Result<JsonValue, ProviderError> {
281 if blocks.len() == 1 {
282 if let Some(text) = blocks[0].as_text() {
283 return Ok(JsonValue::String(text.to_string()));
284 }
285 }
286
287 let items: Result<Vec<_>, _> = blocks
288 .iter()
289 .map(|block| match block {
290 ContentBlock::Text(t) => Ok(serde_json::json!({
291 "type": "text",
292 "text": t.text,
293 })),
294 ContentBlock::ToolCall(tc) => Ok(serde_json::json!({
295 "type": "function",
296 "id": tc.id,
297 "function": {
298 "name": tc.name,
299 "arguments": tc.arguments.to_string(),
300 },
301 })),
302 ContentBlock::Thinking(th) => Ok(serde_json::json!({
303 "type": "thinking",
304 "thinking": th.thinking,
305 })),
306 ContentBlock::Image(img) => Ok(serde_json::json!({
307 "type": "image_url",
308 "image_url": {
309 "url": format!("data:{};base64,{}", img.mime_type, img.data),
310 },
311 })),
312 ContentBlock::Unknown(_) => Err(ProviderError::InvalidResponse(
313 "Unknown content block type".into(),
314 )),
315 })
316 .collect();
317
318 Ok(serde_json::json!(items?))
319}
320
321fn build_tools(tools: &[crate::Tool]) -> Result<JsonValue, ProviderError> {
323 let items: Vec<_> = tools
324 .iter()
325 .map(|tool| {
326 serde_json::json!({
327 "type": "function",
328 "function": {
329 "name": tool.name,
330 "description": tool.description,
331 "parameters": tool.parameters,
332 },
333 })
334 })
335 .collect();
336
337 Ok(serde_json::json!(items))
338}
339
340fn parse_sse_events(text: &str, provider: &str, model_id: &str) -> Vec<ProviderEvent> {
344 let mut events = Vec::new();
345 let mut partial_message = AssistantMessage::new(Api::OpenAiCompletions, provider, model_id);
346
347 let estimated_events = text.split('\n').filter(|l| l.starts_with("data: ")).count();
349 events.reserve(estimated_events);
350
351 let mut accumulated_usage = Usage::default();
352
353 for line in text.split('\n') {
354 let line = line.trim_end_matches('\r');
355 if line.is_empty() {
356 continue;
357 }
358
359 if !line.starts_with("data: ") {
361 continue;
362 }
363
364 let data = &line[6..]; if data == "[DONE]" {
368 break;
369 }
370
371 if data.is_empty() {
372 continue;
373 }
374
375 let chunk = match serde_json::from_str::<SSEChunk>(data) {
376 Ok(c) => c,
377 Err(_) => continue,
378 };
379
380 let this_chunk_usage = chunk.usage.as_ref();
382
383 for choice in &chunk.choices {
384 if let Some(delta) = &choice.delta {
385 if let Some(content) = &delta.content {
386 let last_text_idx = partial_message
389 .content
390 .iter()
391 .rposition(|b| matches!(b, ContentBlock::Text(_)));
392 if let Some(idx) = last_text_idx {
393 if let ContentBlock::Text(t) = &mut partial_message.content[idx] {
394 t.text.push_str(content);
395 }
396 } else {
397 partial_message
398 .content
399 .push(ContentBlock::Text(crate::TextContent::new(content.clone())));
400 }
401 events.push(ProviderEvent::TextDelta {
402 content_index: choice.index,
403 delta: content.clone(),
404 partial: Arc::new(partial_message.clone()),
405 });
406 }
407
408 if let Some(tool_calls) = &delta.tool_calls {
409 for tc in tool_calls {
410 if let Some(func) = &tc.function {
411 events.push(ProviderEvent::ToolCallDelta {
412 content_index: choice.index,
413 delta: func.arguments.clone().unwrap_or_default(),
414 partial: Arc::new(partial_message.clone()),
415 });
416 }
417 }
418 }
419 }
420
421 if choice.finish_reason.is_some() {
422 let reason = match choice.finish_reason.as_deref() {
425 Some("stop") => StopReason::Stop,
426 Some("length") => StopReason::Length,
427 Some("tool_calls") => StopReason::ToolUse,
428 _ => StopReason::Stop,
429 };
430
431 let mut done_msg = partial_message.clone();
432
433 if let Some(usage) = this_chunk_usage {
435 done_msg.usage.input = usage.prompt_tokens;
436 done_msg.usage.output = usage.completion_tokens;
437 done_msg.usage.cache_read = usage
438 .prompt_tokens_details
439 .as_ref()
440 .map(|d| d.cached_tokens)
441 .unwrap_or(0);
442 done_msg.usage.total_tokens = usage.total_tokens;
443 } else {
444 done_msg.usage = accumulated_usage.clone();
445 }
446
447 events.push(ProviderEvent::Done {
448 reason,
449 message: done_msg,
450 });
451 }
452 }
453
454 if let Some(usage) = this_chunk_usage {
456 accumulated_usage.input = usage.prompt_tokens;
457 accumulated_usage.output = usage.completion_tokens;
458 accumulated_usage.cache_read = usage
459 .prompt_tokens_details
460 .as_ref()
461 .map(|d| d.cached_tokens)
462 .unwrap_or(0);
463 accumulated_usage.total_tokens = usage.total_tokens;
464 }
465 }
466
467 events
468}
469
470fn create_error_message(msg: &str, provider: &str, model_id: &str) -> AssistantMessage {
472 let mut message = AssistantMessage::new(Api::OpenAiCompletions, provider, model_id);
473 message.stop_reason = StopReason::Error;
474 message.error_message = Some(msg.to_string());
475 message
476}
477
478#[derive(Debug, Deserialize)]
480struct SSEChunk {
481 _id: Option<String>,
482 #[serde(rename = "model")]
483 _model: Option<String>,
484 choices: Vec<Choice>,
485 usage: Option<UsageInfo>,
486}
487
/// One entry of a chunk's `choices` array.
#[derive(Debug, Deserialize)]
struct Choice {
    /// Index of the choice as reported by the service; `parse_sse_events`
    /// forwards it as the `content_index` of emitted events.
    index: usize,
    /// Incremental content for this choice, if the chunk carries any.
    delta: Option<Delta>,
    /// Present on the chunk that ends the choice; `parse_sse_events` maps it
    /// to a `StopReason`.
    finish_reason: Option<String>,
}
494
/// Incremental payload of a streamed choice.
#[derive(Debug, Deserialize)]
struct Delta {
    /// Text fragment to append to the assistant's reply.
    content: Option<String>,
    /// Tool-call fragments carried by this delta.
    tool_calls: Option<Vec<ToolCallDelta>>,
}
500
501#[derive(Debug, Deserialize)]
502struct ToolCallDelta {
503 _index: Option<usize>,
504 _id: Option<String>,
505 #[serde(rename = "type")]
506 _type_: Option<String>,
507 function: Option<FunctionDelta>,
508}
509
510#[derive(Debug, Deserialize)]
511struct FunctionDelta {
512 _name: Option<String>,
513 arguments: Option<String>,
514}
515
516#[derive(Debug, Deserialize, Clone)]
517struct UsageInfo {
518 prompt_tokens: usize,
519 completion_tokens: usize,
520 total_tokens: usize,
521 #[serde(rename = "prompt_tokens_details")]
522 prompt_tokens_details: Option<PromptTokensDetails>,
523}
524
525#[derive(Debug, Deserialize, Clone)]
526struct PromptTokensDetails {
527 #[serde(rename = "cached_tokens")]
528 cached_tokens: usize,
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    /// Two text deltas; the second one finishes with "stop".
    const TEXT_STREAM: &str = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":"stop"}]}

data: [DONE]"#;

    /// Three tool-call deltas; the last one finishes with "tool_calls".
    const TOOL_CALL_STREAM: &str = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"id":"call_abc123","type":"function","function":{"name":"get_weather","arguments":""}}]},"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"location\":"}}]},"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"Boston\"}"}}]},"finish_reason":"tool_calls"}]}

data: [DONE]"#;

    /// A single finishing chunk that also reports usage.
    const USAGE_STREAM: &str = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15,"prompt_tokens_details":{"cached_tokens":0}}}

data: [DONE]"#;

    fn make_test_model(id: &str, base_url: &str) -> Model {
        Model::new(id, id, Api::OpenAiCompletions, "azure", base_url)
    }

    /// Builds a provider directly from its fields so individual settings can
    /// be left unset.
    fn provider_with(
        api_key: &str,
        resource: Option<&str>,
        deployment: Option<&str>,
    ) -> AzureProvider {
        AzureProvider {
            client: shared_client(),
            api_key: Some(api_key.to_string()),
            resource_name: resource.map(str::to_string),
            deployment_name: deployment.map(str::to_string),
        }
    }

    /// Asserts that `result` is an `InvalidResponse` whose message contains `needle`.
    fn assert_invalid_response_mentions(result: Result<String, ProviderError>, needle: &str) {
        assert!(result.is_err());
        match result.unwrap_err() {
            ProviderError::InvalidResponse(msg) => assert!(msg.contains(needle)),
            _ => panic!("Expected InvalidResponse"),
        }
    }

    #[test]
    fn test_provider_name() {
        assert_eq!(AzureProvider::new().name(), "azure");
    }

    #[test]
    fn test_build_url_from_base_url() {
        let model = make_test_model(
            "gpt-4o",
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4o",
        );

        let url = AzureProvider::new().build_url(&model).unwrap();

        for expected in ["api-version=2024-02-15-preview", "my-resource", "gpt-4o"] {
            assert!(url.contains(expected));
        }
    }

    #[test]
    fn test_build_url_missing_resource() {
        let provider = provider_with("test-key", None, Some("gpt-4o"));
        let model = make_test_model("default", "");

        assert_invalid_response_mentions(provider.build_url(&model), "AZURE_OPENAI_RESOURCE_NAME");
    }

    #[test]
    fn test_build_url_missing_deployment() {
        let provider = provider_with("test-key", Some("my-resource"), None);
        let model = make_test_model("default", "");

        assert_invalid_response_mentions(
            provider.build_url(&model),
            "AZURE_OPENAI_DEPLOYMENT_NAME",
        );
    }

    #[test]
    fn test_build_url_from_env_vars() {
        let provider = provider_with("test-key", Some("my-resource"), Some("gpt-4o"));
        let model = make_test_model("default", "");

        let url = provider.build_url(&model).unwrap();

        assert_eq!(url, "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-02-15-preview");
    }

    #[test]
    fn test_parse_sse_events_text() {
        let events = parse_sse_events(TEXT_STREAM, "azure", "gpt-4o");

        assert!(events.len() >= 3);

        if let Some(ProviderEvent::TextDelta { delta, .. }) = events.first() {
            assert_eq!(delta, "Hello");
        } else {
            panic!("Expected TextDelta event");
        }

        if let Some(ProviderEvent::Done { reason, .. }) = events.last() {
            assert_eq!(*reason, StopReason::Stop);
        } else {
            panic!("Expected Done event");
        }
    }

    #[test]
    fn test_parse_sse_events_with_tool_calls() {
        let events = parse_sse_events(TOOL_CALL_STREAM, "azure", "gpt-4o");

        assert!(events.len() >= 4);

        let tool_call_deltas = events
            .iter()
            .filter(|e| matches!(e, ProviderEvent::ToolCallDelta { .. }))
            .count();
        assert!(
            tool_call_deltas > 0,
            "Should have at least one ToolCallDelta event"
        );

        if let Some(ProviderEvent::Done { reason, .. }) = events.last() {
            assert_eq!(*reason, StopReason::ToolUse);
        } else {
            panic!("Expected Done event with ToolUse reason");
        }
    }

    #[test]
    fn test_parse_sse_events_usage() {
        let events = parse_sse_events(USAGE_STREAM, "azure", "gpt-4o");

        let usage = events.iter().find_map(|e| match e {
            ProviderEvent::Done { message, .. } => Some(&message.usage),
            _ => None,
        });
        assert!(usage.is_some());

        let usage = usage.unwrap();
        assert_eq!(usage.input, 10);
        assert_eq!(usage.output, 5);
        assert_eq!(usage.total_tokens, 15);
    }

    #[test]
    fn test_build_headers_includes_api_key() {
        let api_key = "test-api-key-12345";

        let headers = AzureProvider::new().build_headers(api_key, &None);

        let sent_key = headers.get("api-key");
        assert!(sent_key.is_some());
        assert_eq!(sent_key.unwrap().to_str().unwrap(), api_key);

        assert!(headers.get(reqwest::header::CONTENT_TYPE).is_some());
    }

    #[test]
    fn test_build_headers_no_bearer_token() {
        let headers = AzureProvider::new().build_headers("test-api-key-12345", &None);

        assert!(
            headers.get(reqwest::header::AUTHORIZATION).is_none(),
            "Azure should not use Bearer token authentication"
        );
    }

    #[test]
    fn test_with_config_constructor() {
        let provider = AzureProvider::with_config("my-api-key", "my-resource", "gpt-4o");
        let model = make_test_model("default", "");

        let url = provider.build_url(&model).unwrap();

        assert!(url.contains("my-resource"));
        assert!(url.contains("gpt-4o"));
    }

    #[test]
    fn test_azure_endpoint_format() {
        let provider = provider_with("key", Some("my-resource"), Some("gpt-4-turbo"));
        let model = make_test_model("default", "");

        let url = provider.build_url(&model).unwrap();

        assert!(url.starts_with("https://"));
        for expected in [
            ".openai.azure.com",
            "/openai/deployments/",
            "gpt-4-turbo",
            "chat/completions",
            "api-version=2024-02-15-preview",
        ] {
            assert!(url.contains(expected));
        }
    }
}