1use crate::providers::traits::{
2 ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
3 Provider, ProviderCapabilities, TokenUsage, ToolCall as ProviderToolCall, ToolsPayload,
4};
5use crate::tools::ToolSpec;
6use async_trait::async_trait;
7use reqwest::Client;
8use serde::{Deserialize, Serialize};
9
/// Azure OpenAI REST API version used when the caller does not supply one.
const DEFAULT_API_VERSION: &str = "2024-08-01-preview";
11
/// Provider backed by a single Azure OpenAI deployment.
///
/// Requests go to
/// `https://{resource}.openai.azure.com/openai/deployments/{deployment}` and
/// are authenticated with the `api-key` header (not a Bearer token).
pub struct AzureOpenAiProvider {
    // Azure API key; None means unconfigured — requests fail with a setup error.
    credential: Option<String>,
    resource_name: String,
    deployment_name: String,
    api_version: String,
    // Precomputed deployment base URL (no trailing slash).
    base_url: String,
}
19
/// Minimal request body for the plain (tool-free) chat-completions call.
#[derive(Debug, Serialize)]
struct ChatRequest {
    messages: Vec<Message>,
    temperature: f64,
}
25
/// A single role/content message in the simple chat request.
#[derive(Debug, Serialize)]
struct Message {
    role: String,
    content: String,
}
31
/// Response body for the simple chat call; only `choices` is read.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}
36
/// One completion candidate in a simple chat response.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ResponseMessage,
}
41
/// Assistant message in a simple chat response.
///
/// Both fields default to None so missing/`null` values deserialize cleanly;
/// some models return `reasoning_content` instead of `content`.
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
}
49
50impl ResponseMessage {
51 fn effective_content(&self) -> String {
52 match &self.content {
53 Some(c) if !c.is_empty() => c.clone(),
54 _ => self.reasoning_content.clone().unwrap_or_default(),
55 }
56 }
57}
58
/// Request body for the tool-capable chat-completions call.
#[derive(Debug, Serialize)]
struct NativeChatRequest {
    messages: Vec<NativeMessage>,
    temperature: f64,
    // Omitted from the JSON entirely when None so tool-free requests stay minimal.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<NativeToolSpec>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
}
68
/// Outgoing message in the tool-capable request.
///
/// Optional fields are skipped during serialization so each role only carries
/// the keys it needs: `tool_call_id` for tool-result turns, `tool_calls` and
/// `reasoning_content` for assistant turns.
#[derive(Debug, Serialize)]
struct NativeMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<NativeToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}
81
/// OpenAI-format tool definition: `{"type": "function", "function": {...}}`.
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolSpec {
    // Serialized as the JSON key "type"; expected value is "function".
    #[serde(rename = "type")]
    kind: String,
    function: NativeToolFunctionSpec,
}
88
/// The `function` payload of a tool definition; `parameters` holds the
/// tool's JSON Schema verbatim.
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolFunctionSpec {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
95
96fn parse_native_tool_spec(value: serde_json::Value) -> anyhow::Result<NativeToolSpec> {
97 let spec: NativeToolSpec = serde_json::from_value(value)
98 .map_err(|e| anyhow::anyhow!("Invalid Azure OpenAI tool specification: {e}"))?;
99
100 if spec.kind != "function" {
101 anyhow::bail!(
102 "Invalid Azure OpenAI tool specification: unsupported tool type '{}', expected 'function'",
103 spec.kind
104 );
105 }
106
107 Ok(spec)
108}
109
/// A tool invocation; used both when replaying assistant history in requests
/// and when reading model responses, hence Serialize + Deserialize.
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    kind: Option<String>,
    function: NativeFunctionCall,
}
118
/// Function name plus its arguments as a raw JSON-encoded string
/// (the API transmits arguments as a string, not a JSON object).
#[derive(Debug, Serialize, Deserialize)]
struct NativeFunctionCall {
    name: String,
    arguments: String,
}
124
/// Response body for the tool-capable chat call.
#[derive(Debug, Deserialize)]
struct NativeChatResponse {
    choices: Vec<NativeChoice>,
    // Token accounting; may be absent, hence the default.
    #[serde(default)]
    usage: Option<UsageInfo>,
}
131
/// Token usage block as reported by the API; each counter may be absent.
#[derive(Debug, Deserialize)]
struct UsageInfo {
    #[serde(default)]
    prompt_tokens: Option<u64>,
    #[serde(default)]
    completion_tokens: Option<u64>,
}
139
/// One completion candidate in a tool-capable response.
#[derive(Debug, Deserialize)]
struct NativeChoice {
    message: NativeResponseMessage,
}
144
/// Assistant message in a tool-capable response; every field may be absent
/// (e.g. `content` is null on pure tool-call turns).
#[derive(Debug, Deserialize)]
struct NativeResponseMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<NativeToolCall>>,
}
154
155impl NativeResponseMessage {
156 fn effective_content(&self) -> Option<String> {
157 match &self.content {
158 Some(c) if !c.is_empty() => Some(c.clone()),
159 _ => self.reasoning_content.clone(),
160 }
161 }
162}
163
164impl AzureOpenAiProvider {
165 pub fn new(
166 credential: Option<&str>,
167 resource_name: &str,
168 deployment_name: &str,
169 api_version: Option<&str>,
170 ) -> Self {
171 let version = api_version.unwrap_or(DEFAULT_API_VERSION);
172 let base_url = format!(
173 "https://{}.openai.azure.com/openai/deployments/{}",
174 resource_name, deployment_name
175 );
176 Self {
177 credential: credential.map(ToString::to_string),
178 resource_name: resource_name.to_string(),
179 deployment_name: deployment_name.to_string(),
180 api_version: version.to_string(),
181 base_url,
182 }
183 }
184
185 fn chat_completions_url(&self) -> String {
186 format!(
187 "{}/chat/completions?api-version={}",
188 self.base_url, self.api_version
189 )
190 }
191
192 fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
193 tools.map(|items| {
194 items
195 .iter()
196 .map(|tool| NativeToolSpec {
197 kind: "function".to_string(),
198 function: NativeToolFunctionSpec {
199 name: tool.name.clone(),
200 description: tool.description.clone(),
201 parameters: tool.parameters.clone(),
202 },
203 })
204 .collect()
205 })
206 }
207
208 fn convert_messages(messages: &[ChatMessage]) -> Vec<NativeMessage> {
209 messages
210 .iter()
211 .map(|m| {
212 if m.role == "assistant" {
213 if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
214 if let Some(tool_calls_value) = value.get("tool_calls") {
215 if let Ok(parsed_calls) =
216 serde_json::from_value::<Vec<ProviderToolCall>>(
217 tool_calls_value.clone(),
218 )
219 {
220 let tool_calls = parsed_calls
221 .into_iter()
222 .map(|tc| NativeToolCall {
223 id: Some(tc.id),
224 kind: Some("function".to_string()),
225 function: NativeFunctionCall {
226 name: tc.name,
227 arguments: tc.arguments,
228 },
229 })
230 .collect::<Vec<_>>();
231 let content = value
232 .get("content")
233 .and_then(serde_json::Value::as_str)
234 .map(ToString::to_string);
235 let reasoning_content = value
236 .get("reasoning_content")
237 .and_then(serde_json::Value::as_str)
238 .map(ToString::to_string);
239 return NativeMessage {
240 role: "assistant".to_string(),
241 content,
242 tool_call_id: None,
243 tool_calls: Some(tool_calls),
244 reasoning_content,
245 };
246 }
247 }
248 }
249 }
250
251 if m.role == "tool" {
252 if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
253 let tool_call_id = value
254 .get("tool_call_id")
255 .and_then(serde_json::Value::as_str)
256 .map(ToString::to_string);
257 let content = value
258 .get("content")
259 .and_then(serde_json::Value::as_str)
260 .map(ToString::to_string);
261 return NativeMessage {
262 role: "tool".to_string(),
263 content,
264 tool_call_id,
265 tool_calls: None,
266 reasoning_content: None,
267 };
268 }
269 }
270
271 NativeMessage {
272 role: m.role.clone(),
273 content: Some(m.content.clone()),
274 tool_call_id: None,
275 tool_calls: None,
276 reasoning_content: None,
277 }
278 })
279 .collect()
280 }
281
282 fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse {
283 let text = message.effective_content();
284 let reasoning_content = message.reasoning_content.clone();
285 let tool_calls = message
286 .tool_calls
287 .unwrap_or_default()
288 .into_iter()
289 .map(|tc| ProviderToolCall {
290 id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
291 name: tc.function.name,
292 arguments: tc.function.arguments,
293 })
294 .collect::<Vec<_>>();
295
296 ProviderChatResponse {
297 text,
298 tool_calls,
299 usage: None,
300 reasoning_content,
301 }
302 }
303
304 fn http_client(&self) -> Client {
305 crate::config::build_runtime_proxy_client_with_timeouts("provider.azure_openai", 120, 10)
306 }
307}
308
#[async_trait]
impl Provider for AzureOpenAiProvider {
    /// Static capability flags: native tool calling and vision are enabled,
    /// prompt caching is not.
    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities {
            native_tool_calling: true,
            vision: true,
            prompt_caching: false,
        }
    }

    /// Renders tool specs as OpenAI-format JSON values
    /// (`{"type": "function", "function": {...}}`).
    fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
        ToolsPayload::OpenAI {
            tools: tools
                .iter()
                .map(|tool| {
                    serde_json::json!({
                        "type": "function",
                        "function": {
                            "name": tool.name,
                            "description": tool.description,
                            "parameters": tool.parameters,
                        }
                    })
                })
                .collect(),
        }
    }

    fn supports_native_tools(&self) -> bool {
        true
    }

    fn supports_vision(&self) -> bool {
        true
    }

    /// One-shot chat: optional system prompt plus a single user message.
    /// `_model` is ignored — the Azure deployment baked into the URL selects
    /// the model.
    ///
    /// Errors when the API key is unset, on transport failure, on a non-2xx
    /// status, or when the response contains no choices.
    async fn chat_with_system(
        &self,
        system_prompt: Option<&str>,
        message: &str,
        _model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        let credential = self.credential.as_ref().ok_or_else(|| {
            anyhow::anyhow!(
                "Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml."
            )
        })?;

        let mut messages = Vec::new();

        if let Some(sys) = system_prompt {
            messages.push(Message {
                role: "system".to_string(),
                content: sys.to_string(),
            });
        }

        messages.push(Message {
            role: "user".to_string(),
            content: message.to_string(),
        });

        let request = ChatRequest {
            messages,
            temperature,
        };

        // Azure authenticates via the "api-key" header, not "Authorization".
        let response = self
            .http_client()
            .post(self.chat_completions_url())
            .header("api-key", credential.as_str())
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(super::api_error("Azure OpenAI", response).await);
        }

        let chat_response: ChatResponse = response.json().await?;

        // Only the first choice is surfaced.
        chat_response
            .choices
            .into_iter()
            .next()
            .map(|c| c.message.effective_content())
            .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))
    }

    /// Structured chat with optional native tool calling. `_model` is
    /// ignored (the deployment in the URL selects the model); `tool_choice`
    /// is set to "auto" whenever tools are present.
    async fn chat(
        &self,
        request: ProviderChatRequest<'_>,
        _model: &str,
        temperature: f64,
    ) -> anyhow::Result<ProviderChatResponse> {
        let credential = self.credential.as_ref().ok_or_else(|| {
            anyhow::anyhow!(
                "Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml."
            )
        })?;

        let tools = Self::convert_tools(request.tools);
        let native_request = NativeChatRequest {
            messages: Self::convert_messages(request.messages),
            temperature,
            tool_choice: tools.as_ref().map(|_| "auto".to_string()),
            tools,
        };

        let response = self
            .http_client()
            .post(self.chat_completions_url())
            .header("api-key", credential.as_str())
            .json(&native_request)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(super::api_error("Azure OpenAI", response).await);
        }

        let native_response: NativeChatResponse = response.json().await?;
        // Capture usage before consuming the choices vector.
        let usage = native_response.usage.map(|u| TokenUsage {
            input_tokens: u.prompt_tokens,
            output_tokens: u.completion_tokens,
            cached_input_tokens: None,
        });
        let message = native_response
            .choices
            .into_iter()
            .next()
            .map(|c| c.message)
            .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))?;
        let mut result = Self::parse_native_response(message);
        result.usage = usage;
        Ok(result)
    }

    /// Chat with caller-supplied raw JSON tool definitions (validated to be
    /// OpenAI "function" tools). An empty `tools` slice omits the `tools` and
    /// `tool_choice` fields from the request entirely.
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: &[serde_json::Value],
        _model: &str,
        temperature: f64,
    ) -> anyhow::Result<ProviderChatResponse> {
        let credential = self.credential.as_ref().ok_or_else(|| {
            anyhow::anyhow!(
                "Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml."
            )
        })?;

        // Validate every tool up front; the first invalid one aborts the call.
        let native_tools: Option<Vec<NativeToolSpec>> = if tools.is_empty() {
            None
        } else {
            Some(
                tools
                    .iter()
                    .cloned()
                    .map(parse_native_tool_spec)
                    .collect::<Result<Vec<_>, _>>()?,
            )
        };

        let native_request = NativeChatRequest {
            messages: Self::convert_messages(messages),
            temperature,
            tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
            tools: native_tools,
        };

        let response = self
            .http_client()
            .post(self.chat_completions_url())
            .header("api-key", credential.as_str())
            .json(&native_request)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(super::api_error("Azure OpenAI", response).await);
        }

        let native_response: NativeChatResponse = response.json().await?;
        let usage = native_response.usage.map(|u| TokenUsage {
            input_tokens: u.prompt_tokens,
            output_tokens: u.completion_tokens,
            cached_input_tokens: None,
        });
        let message = native_response
            .choices
            .into_iter()
            .next()
            .map(|c| c.message)
            .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))?;
        let mut result = Self::parse_native_response(message);
        result.usage = usage;
        Ok(result)
    }

    /// No warmup needed for this provider; always succeeds.
    async fn warmup(&self) -> anyhow::Result<()> {
        Ok(())
    }
}
515
// Unit tests: URL construction, constructor/config handling, serde request
// and response shapes, and credential gating. No network calls are made —
// every async test fails fast on the missing API key.
#[cfg(test)]
mod tests {
    use super::*;

    // --- URL construction ---

    #[test]
    fn url_construction_default_version() {
        let p = AzureOpenAiProvider::new(Some("test-key"), "my-resource", "gpt-4o", None);
        assert_eq!(
            p.chat_completions_url(),
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
        );
    }

    #[test]
    fn url_construction_custom_version() {
        let p = AzureOpenAiProvider::new(
            Some("test-key"),
            "my-resource",
            "gpt-4o",
            Some("2024-06-01"),
        );
        assert_eq!(
            p.chat_completions_url(),
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-06-01"
        );
    }

    #[test]
    fn url_construction_preserves_resource_and_deployment() {
        let p = AzureOpenAiProvider::new(Some("key"), "contoso-ai", "my-gpt35-deployment", None);
        let url = p.chat_completions_url();
        assert!(url.contains("contoso-ai.openai.azure.com"));
        assert!(url.contains("/deployments/my-gpt35-deployment/"));
        assert!(url.contains("api-version=2024-08-01-preview"));
    }

    // --- Constructor / credential handling ---

    #[test]
    fn auth_header_uses_api_key_not_bearer() {
        let p = AzureOpenAiProvider::new(Some("my-azure-key"), "resource", "deployment", None);
        assert_eq!(p.credential.as_deref(), Some("my-azure-key"));
    }

    #[test]
    fn creates_with_credential() {
        let p = AzureOpenAiProvider::new(
            Some("azure-test-credential"),
            "resource",
            "deployment",
            None,
        );
        assert_eq!(p.credential.as_deref(), Some("azure-test-credential"));
        assert_eq!(p.resource_name, "resource");
        assert_eq!(p.deployment_name, "deployment");
        assert_eq!(p.api_version, DEFAULT_API_VERSION);
    }

    #[test]
    fn creates_without_credential() {
        let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
        assert!(p.credential.is_none());
    }

    #[tokio::test]
    async fn chat_fails_without_key() {
        let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
        let result = p.chat_with_system(None, "hello", "gpt-4o", 0.7).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("API key not set"));
    }

    #[tokio::test]
    async fn chat_with_system_fails_without_key() {
        let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
        let result = p
            .chat_with_system(Some("You are Construct"), "test", "gpt-4o", 0.5)
            .await;
        assert!(result.is_err());
    }

    // --- Serde request/response shapes ---

    #[test]
    fn request_serializes_with_system_message() {
        let req = ChatRequest {
            messages: vec![
                Message {
                    role: "system".to_string(),
                    content: "You are Construct".to_string(),
                },
                Message {
                    role: "user".to_string(),
                    content: "hello".to_string(),
                },
            ],
            temperature: 0.7,
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"role\":\"system\""));
        assert!(json.contains("\"role\":\"user\""));
        assert!(!json.contains("\"model\""));
    }

    #[test]
    fn request_serializes_without_system() {
        let req = ChatRequest {
            messages: vec![Message {
                role: "user".to_string(),
                content: "hello".to_string(),
            }],
            temperature: 0.0,
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(!json.contains("system"));
        assert!(json.contains("\"temperature\":0.0"));
    }

    #[test]
    fn response_deserializes_single_choice() {
        let json = r#"{"choices":[{"message":{"content":"Hi!"}}]}"#;
        let resp: ChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.choices.len(), 1);
        assert_eq!(resp.choices[0].message.effective_content(), "Hi!");
    }

    #[test]
    fn response_deserializes_empty_choices() {
        let json = r#"{"choices":[]}"#;
        let resp: ChatResponse = serde_json::from_str(json).unwrap();
        assert!(resp.choices.is_empty());
    }

    #[test]
    fn response_deserializes_multiple_choices() {
        let json = r#"{"choices":[{"message":{"content":"A"}},{"message":{"content":"B"}}]}"#;
        let resp: ChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.choices.len(), 2);
        assert_eq!(resp.choices[0].message.effective_content(), "A");
    }

    // --- Tool calling ---

    #[test]
    fn tool_call_response_parsing() {
        let json = r#"{"choices":[{"message":{
            "content":"Let me check",
            "tool_calls":[{
                "id":"call_abc123",
                "type":"function",
                "function":{"name":"shell","arguments":"{\"command\":\"ls\"}"}
            }]
        }}],"usage":{"prompt_tokens":50,"completion_tokens":25}}"#;
        let resp: NativeChatResponse = serde_json::from_str(json).unwrap();
        let message = resp.choices.into_iter().next().unwrap().message;
        let parsed = AzureOpenAiProvider::parse_native_response(message);
        assert_eq!(parsed.text.as_deref(), Some("Let me check"));
        assert_eq!(parsed.tool_calls.len(), 1);
        assert_eq!(parsed.tool_calls[0].id, "call_abc123");
        assert_eq!(parsed.tool_calls[0].name, "shell");
        assert!(parsed.tool_calls[0].arguments.contains("ls"));
    }

    #[test]
    fn tool_call_response_without_id_generates_uuid() {
        let json = r#"{"choices":[{"message":{
            "content":null,
            "tool_calls":[{
                "function":{"name":"test","arguments":"{}"}
            }]
        }}]}"#;
        let resp: NativeChatResponse = serde_json::from_str(json).unwrap();
        let message = resp.choices.into_iter().next().unwrap().message;
        let parsed = AzureOpenAiProvider::parse_native_response(message);
        assert_eq!(parsed.tool_calls.len(), 1);
        assert!(!parsed.tool_calls[0].id.is_empty());
    }

    #[tokio::test]
    async fn chat_with_tools_fails_without_key() {
        let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
        let messages = vec![ChatMessage::user("hello".to_string())];
        let tools = vec![serde_json::json!({
            "type": "function",
            "function": {
                "name": "shell",
                "description": "Run a shell command",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "command": { "type": "string" }
                    },
                    "required": ["command"]
                }
            }
        })];
        let result = p.chat_with_tools(&messages, &tools, "gpt-4o", 0.7).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("API key not set"));
    }

    #[test]
    fn native_response_parses_usage() {
        let json = r#"{
            "choices": [{"message": {"content": "Hello"}}],
            "usage": {"prompt_tokens": 100, "completion_tokens": 50}
        }"#;
        let resp: NativeChatResponse = serde_json::from_str(json).unwrap();
        let usage = resp.usage.unwrap();
        assert_eq!(usage.prompt_tokens, Some(100));
        assert_eq!(usage.completion_tokens, Some(50));
    }

    // --- Capability flags / misc ---

    #[test]
    fn capabilities_reports_native_tools_and_vision() {
        let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None);
        let caps = <AzureOpenAiProvider as Provider>::capabilities(&p);
        assert!(caps.native_tool_calling);
        assert!(caps.vision);
    }

    #[test]
    fn supports_native_tools_returns_true() {
        let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None);
        assert!(p.supports_native_tools());
    }

    #[test]
    fn supports_vision_returns_true() {
        let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None);
        assert!(p.supports_vision());
    }

    #[tokio::test]
    async fn warmup_is_noop() {
        let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
        let result = p.warmup().await;
        assert!(result.is_ok());
    }

    #[test]
    fn custom_api_version_stored() {
        let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", Some("2025-01-01"));
        assert_eq!(p.api_version, "2025-01-01");
    }
}