//! Azure OpenAI chat-completions provider (construct/providers/azure_openai.rs).

1use crate::providers::traits::{
2    ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
3    Provider, ProviderCapabilities, TokenUsage, ToolCall as ProviderToolCall, ToolsPayload,
4};
5use crate::tools::ToolSpec;
6use async_trait::async_trait;
7use reqwest::Client;
8use serde::{Deserialize, Serialize};
9
/// Default Azure OpenAI REST API version used when the caller does not supply one.
const DEFAULT_API_VERSION: &str = "2024-08-01-preview";

/// Chat provider backed by a single Azure OpenAI deployment.
///
/// Azure selects the model from the deployment named in the URL
/// (`https://{resource}.openai.azure.com/openai/deployments/{deployment}`),
/// so requests carry no `model` field.
pub struct AzureOpenAiProvider {
    // API key sent via the `api-key` header; `None` until configured,
    // in which case chat calls fail with a setup error.
    credential: Option<String>,
    // Azure resource name (the `{name}.openai.azure.com` host prefix).
    resource_name: String,
    // Name of the model deployment within the resource.
    deployment_name: String,
    // Value of the `api-version` query parameter.
    api_version: String,
    // Precomputed `https://{resource}.openai.azure.com/openai/deployments/{deployment}`.
    base_url: String,
}
19
/// Minimal request body for the plain (tool-free) chat-completions call.
/// Deliberately omits `model`: the deployment in the URL selects it.
#[derive(Debug, Serialize)]
struct ChatRequest {
    messages: Vec<Message>,
    temperature: f64,
}

/// A single role/content message for the plain request path.
#[derive(Debug, Serialize)]
struct Message {
    role: String,
    content: String,
}
31
/// Response body for the plain chat-completions call.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

/// One completion choice; only the message payload is consumed.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ResponseMessage,
}

/// Assistant message from the plain call. `reasoning_content` is an extra
/// field some models return alongside (or instead of) `content`; both default
/// to `None` when absent from the JSON.
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
}
49
50impl ResponseMessage {
51    fn effective_content(&self) -> String {
52        match &self.content {
53            Some(c) if !c.is_empty() => c.clone(),
54            _ => self.reasoning_content.clone().unwrap_or_default(),
55        }
56    }
57}
58
/// Request body for the tool-aware ("native") chat-completions call.
#[derive(Debug, Serialize)]
struct NativeChatRequest {
    messages: Vec<NativeMessage>,
    temperature: f64,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<NativeToolSpec>>,
    // Set to "auto" whenever tools are supplied; omitted otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
}

/// Outgoing chat message in the OpenAI native wire format. Optional fields
/// are skipped entirely in the JSON so role-specific shapes stay valid:
/// tool results carry `tool_call_id`, assistant turns may carry `tool_calls`.
#[derive(Debug, Serialize)]
struct NativeMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<NativeToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}
81
/// OpenAI-style tool declaration: `{"type": "function", "function": {...}}`.
/// The JSON key `type` is mapped to `kind` because `type` is a Rust keyword.
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolSpec {
    #[serde(rename = "type")]
    kind: String,
    function: NativeToolFunctionSpec,
}

/// Function payload of a tool declaration; `parameters` holds an arbitrary
/// JSON value (a JSON-Schema-shaped object in practice — see the tests).
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolFunctionSpec {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
95
96fn parse_native_tool_spec(value: serde_json::Value) -> anyhow::Result<NativeToolSpec> {
97    let spec: NativeToolSpec = serde_json::from_value(value)
98        .map_err(|e| anyhow::anyhow!("Invalid Azure OpenAI tool specification: {e}"))?;
99
100    if spec.kind != "function" {
101        anyhow::bail!(
102            "Invalid Azure OpenAI tool specification: unsupported tool type '{}', expected 'function'",
103            spec.kind
104        );
105    }
106
107    Ok(spec)
108}
109
/// A tool invocation as it appears both in assistant history (serialized) and
/// in model responses (deserialized). `id` and `type` are optional because
/// some responses omit them; missing ids are replaced with UUIDs downstream.
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    kind: Option<String>,
    function: NativeFunctionCall,
}

/// Function name plus its arguments as a raw JSON string (OpenAI wire format).
#[derive(Debug, Serialize, Deserialize)]
struct NativeFunctionCall {
    name: String,
    arguments: String,
}
124
/// Response body for the native chat-completions call.
#[derive(Debug, Deserialize)]
struct NativeChatResponse {
    choices: Vec<NativeChoice>,
    // Token accounting; absent on some responses, hence the default.
    #[serde(default)]
    usage: Option<UsageInfo>,
}

/// Token usage reported by the API; individual counters may be missing.
#[derive(Debug, Deserialize)]
struct UsageInfo {
    #[serde(default)]
    prompt_tokens: Option<u64>,
    #[serde(default)]
    completion_tokens: Option<u64>,
}
139
/// One native completion choice; only the message payload is consumed.
#[derive(Debug, Deserialize)]
struct NativeChoice {
    message: NativeResponseMessage,
}

/// Assistant message from the native call: plain text, optional reasoning
/// text, and any tool calls the model requested. All fields default to `None`
/// when absent from the JSON.
#[derive(Debug, Deserialize)]
struct NativeResponseMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<NativeToolCall>>,
}
154
155impl NativeResponseMessage {
156    fn effective_content(&self) -> Option<String> {
157        match &self.content {
158            Some(c) if !c.is_empty() => Some(c.clone()),
159            _ => self.reasoning_content.clone(),
160        }
161    }
162}
163
impl AzureOpenAiProvider {
    /// Builds a provider for `{resource_name}.openai.azure.com` /
    /// `{deployment_name}`. `api_version` falls back to
    /// [`DEFAULT_API_VERSION`]; `credential` may be `None`, in which case
    /// every chat call returns a configuration error.
    pub fn new(
        credential: Option<&str>,
        resource_name: &str,
        deployment_name: &str,
        api_version: Option<&str>,
    ) -> Self {
        let version = api_version.unwrap_or(DEFAULT_API_VERSION);
        let base_url = format!(
            "https://{}.openai.azure.com/openai/deployments/{}",
            resource_name, deployment_name
        );
        Self {
            credential: credential.map(ToString::to_string),
            resource_name: resource_name.to_string(),
            deployment_name: deployment_name.to_string(),
            api_version: version.to_string(),
            base_url,
        }
    }

    /// Full chat-completions endpoint URL, including the `api-version` query
    /// parameter (required by Azure).
    fn chat_completions_url(&self) -> String {
        format!(
            "{}/chat/completions?api-version={}",
            self.base_url, self.api_version
        )
    }

    /// Converts internal [`ToolSpec`]s into OpenAI `function`-type tool
    /// declarations. Returns `None` when no tool slice was provided so the
    /// `tools` field can be omitted from the request entirely.
    fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
        tools.map(|items| {
            items
                .iter()
                .map(|tool| NativeToolSpec {
                    kind: "function".to_string(),
                    function: NativeToolFunctionSpec {
                        name: tool.name.clone(),
                        description: tool.description.clone(),
                        parameters: tool.parameters.clone(),
                    },
                })
                .collect()
        })
    }

    /// Maps provider-agnostic [`ChatMessage`]s onto the OpenAI wire format.
    ///
    /// Assistant and tool messages may carry structured data JSON-encoded in
    /// their `content` string:
    /// - assistant: `{"content", "reasoning_content", "tool_calls": [...]}`
    ///   is expanded into a native message with real `tool_calls`;
    /// - tool: `{"tool_call_id", "content"}` becomes a `role: "tool"` message.
    ///
    /// Anything that does not parse as such JSON falls through to the plain
    /// role/content mapping at the bottom, so ordinary text is passed verbatim.
    fn convert_messages(messages: &[ChatMessage]) -> Vec<NativeMessage> {
        messages
            .iter()
            .map(|m| {
                if m.role == "assistant" {
                    if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
                        if let Some(tool_calls_value) = value.get("tool_calls") {
                            if let Ok(parsed_calls) =
                                serde_json::from_value::<Vec<ProviderToolCall>>(
                                    tool_calls_value.clone(),
                                )
                            {
                                let tool_calls = parsed_calls
                                    .into_iter()
                                    .map(|tc| NativeToolCall {
                                        id: Some(tc.id),
                                        kind: Some("function".to_string()),
                                        function: NativeFunctionCall {
                                            name: tc.name,
                                            arguments: tc.arguments,
                                        },
                                    })
                                    .collect::<Vec<_>>();
                                let content = value
                                    .get("content")
                                    .and_then(serde_json::Value::as_str)
                                    .map(ToString::to_string);
                                let reasoning_content = value
                                    .get("reasoning_content")
                                    .and_then(serde_json::Value::as_str)
                                    .map(ToString::to_string);
                                return NativeMessage {
                                    role: "assistant".to_string(),
                                    content,
                                    tool_call_id: None,
                                    tool_calls: Some(tool_calls),
                                    reasoning_content,
                                };
                            }
                        }
                    }
                }

                if m.role == "tool" {
                    if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
                        let tool_call_id = value
                            .get("tool_call_id")
                            .and_then(serde_json::Value::as_str)
                            .map(ToString::to_string);
                        let content = value
                            .get("content")
                            .and_then(serde_json::Value::as_str)
                            .map(ToString::to_string);
                        return NativeMessage {
                            role: "tool".to_string(),
                            content,
                            tool_call_id,
                            tool_calls: None,
                            reasoning_content: None,
                        };
                    }
                }

                // Default: plain role/content pass-through.
                NativeMessage {
                    role: m.role.clone(),
                    content: Some(m.content.clone()),
                    tool_call_id: None,
                    tool_calls: None,
                    reasoning_content: None,
                }
            })
            .collect()
    }

    /// Flattens a native response message into the provider-agnostic
    /// [`ProviderChatResponse`]. Tool calls missing an `id` get a fresh UUID
    /// so downstream correlation still works. `usage` is left `None` here;
    /// callers fill it in from the response envelope.
    fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse {
        let text = message.effective_content();
        let reasoning_content = message.reasoning_content.clone();
        let tool_calls = message
            .tool_calls
            .unwrap_or_default()
            .into_iter()
            .map(|tc| ProviderToolCall {
                id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
                name: tc.function.name,
                arguments: tc.function.arguments,
            })
            .collect::<Vec<_>>();

        ProviderChatResponse {
            text,
            tool_calls,
            usage: None,
            reasoning_content,
        }
    }

    /// HTTP client honoring the runtime proxy configuration.
    /// NOTE(review): the 120/10 arguments look like request/connect timeout
    /// seconds — confirm against `build_runtime_proxy_client_with_timeouts`.
    fn http_client(&self) -> Client {
        crate::config::build_runtime_proxy_client_with_timeouts("provider.azure_openai", 120, 10)
    }
}
308
#[async_trait]
impl Provider for AzureOpenAiProvider {
    /// Static capability flags: native tool calling and vision are supported;
    /// prompt caching is not.
    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities {
            native_tool_calling: true,
            vision: true,
            prompt_caching: false,
        }
    }

    /// Wraps each [`ToolSpec`] in the OpenAI `{"type": "function", ...}` JSON
    /// shape expected by the chat-completions API.
    fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
        ToolsPayload::OpenAI {
            tools: tools
                .iter()
                .map(|tool| {
                    serde_json::json!({
                        "type": "function",
                        "function": {
                            "name": tool.name,
                            "description": tool.description,
                            "parameters": tool.parameters,
                        }
                    })
                })
                .collect(),
        }
    }

    fn supports_native_tools(&self) -> bool {
        true
    }

    fn supports_vision(&self) -> bool {
        true
    }

    /// Simple one-shot chat: optional system prompt plus one user message.
    /// `_model` is ignored because the Azure deployment (in the URL) selects
    /// the model. Returns the first choice's effective text content.
    ///
    /// # Errors
    /// Fails when no credential is configured, on transport errors, on
    /// non-success HTTP status, or when the response has no choices.
    async fn chat_with_system(
        &self,
        system_prompt: Option<&str>,
        message: &str,
        _model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        let credential = self.credential.as_ref().ok_or_else(|| {
            anyhow::anyhow!(
                "Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml."
            )
        })?;

        let mut messages = Vec::new();

        if let Some(sys) = system_prompt {
            messages.push(Message {
                role: "system".to_string(),
                content: sys.to_string(),
            });
        }

        messages.push(Message {
            role: "user".to_string(),
            content: message.to_string(),
        });

        let request = ChatRequest {
            messages,
            temperature,
        };

        // Azure authenticates via the `api-key` header, not a Bearer token.
        let response = self
            .http_client()
            .post(self.chat_completions_url())
            .header("api-key", credential.as_str())
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(super::api_error("Azure OpenAI", response).await);
        }

        let chat_response: ChatResponse = response.json().await?;

        chat_response
            .choices
            .into_iter()
            .next()
            .map(|c| c.message.effective_content())
            .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))
    }

    /// Full chat with typed tool specs: converts messages/tools to the native
    /// wire format, sends the request, and returns the parsed first choice
    /// with usage attached. `tool_choice` is "auto" whenever tools are given.
    async fn chat(
        &self,
        request: ProviderChatRequest<'_>,
        _model: &str,
        temperature: f64,
    ) -> anyhow::Result<ProviderChatResponse> {
        let credential = self.credential.as_ref().ok_or_else(|| {
            anyhow::anyhow!(
                "Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml."
            )
        })?;

        let tools = Self::convert_tools(request.tools);
        let native_request = NativeChatRequest {
            messages: Self::convert_messages(request.messages),
            temperature,
            tool_choice: tools.as_ref().map(|_| "auto".to_string()),
            tools,
        };

        let response = self
            .http_client()
            .post(self.chat_completions_url())
            .header("api-key", credential.as_str())
            .json(&native_request)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(super::api_error("Azure OpenAI", response).await);
        }

        let native_response: NativeChatResponse = response.json().await?;
        let usage = native_response.usage.map(|u| TokenUsage {
            input_tokens: u.prompt_tokens,
            output_tokens: u.completion_tokens,
            cached_input_tokens: None,
        });
        let message = native_response
            .choices
            .into_iter()
            .next()
            .map(|c| c.message)
            .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))?;
        let mut result = Self::parse_native_response(message);
        result.usage = usage;
        Ok(result)
    }

    /// Like [`Self::chat`], but takes raw JSON tool specs, validating each one
    /// via `parse_native_tool_spec` (first invalid spec aborts the call).
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: &[serde_json::Value],
        _model: &str,
        temperature: f64,
    ) -> anyhow::Result<ProviderChatResponse> {
        let credential = self.credential.as_ref().ok_or_else(|| {
            anyhow::anyhow!(
                "Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml."
            )
        })?;

        // Empty slice means "no tools" rather than an empty tools array.
        let native_tools: Option<Vec<NativeToolSpec>> = if tools.is_empty() {
            None
        } else {
            Some(
                tools
                    .iter()
                    .cloned()
                    .map(parse_native_tool_spec)
                    .collect::<Result<Vec<_>, _>>()?,
            )
        };

        let native_request = NativeChatRequest {
            messages: Self::convert_messages(messages),
            temperature,
            tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
            tools: native_tools,
        };

        let response = self
            .http_client()
            .post(self.chat_completions_url())
            .header("api-key", credential.as_str())
            .json(&native_request)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(super::api_error("Azure OpenAI", response).await);
        }

        let native_response: NativeChatResponse = response.json().await?;
        let usage = native_response.usage.map(|u| TokenUsage {
            input_tokens: u.prompt_tokens,
            output_tokens: u.completion_tokens,
            cached_input_tokens: None,
        });
        let message = native_response
            .choices
            .into_iter()
            .next()
            .map(|c| c.message)
            .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))?;
        let mut result = Self::parse_native_response(message);
        result.usage = usage;
        Ok(result)
    }

    async fn warmup(&self) -> anyhow::Result<()> {
        // Azure OpenAI does not have a lightweight models endpoint,
        // so warmup is a no-op to avoid unnecessary API calls.
        Ok(())
    }
}
515
#[cfg(test)]
mod tests {
    use super::*;

    // --- URL construction -------------------------------------------------

    #[test]
    fn url_construction_default_version() {
        let p = AzureOpenAiProvider::new(Some("test-key"), "my-resource", "gpt-4o", None);
        assert_eq!(
            p.chat_completions_url(),
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
        );
    }

    #[test]
    fn url_construction_custom_version() {
        let p = AzureOpenAiProvider::new(
            Some("test-key"),
            "my-resource",
            "gpt-4o",
            Some("2024-06-01"),
        );
        assert_eq!(
            p.chat_completions_url(),
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-06-01"
        );
    }

    #[test]
    fn url_construction_preserves_resource_and_deployment() {
        let p = AzureOpenAiProvider::new(Some("key"), "contoso-ai", "my-gpt35-deployment", None);
        let url = p.chat_completions_url();
        assert!(url.contains("contoso-ai.openai.azure.com"));
        assert!(url.contains("/deployments/my-gpt35-deployment/"));
        assert!(url.contains("api-version=2024-08-01-preview"));
    }

    // --- Construction / credential handling -------------------------------

    #[test]
    fn auth_header_uses_api_key_not_bearer() {
        // This test verifies the provider stores the credential correctly
        // and that the auth header name is "api-key" (verified via the
        // implementation in chat_with_system which uses .header("api-key", ...)).
        let p = AzureOpenAiProvider::new(Some("my-azure-key"), "resource", "deployment", None);
        assert_eq!(p.credential.as_deref(), Some("my-azure-key"));
    }

    #[test]
    fn creates_with_credential() {
        let p = AzureOpenAiProvider::new(
            Some("azure-test-credential"),
            "resource",
            "deployment",
            None,
        );
        assert_eq!(p.credential.as_deref(), Some("azure-test-credential"));
        assert_eq!(p.resource_name, "resource");
        assert_eq!(p.deployment_name, "deployment");
        assert_eq!(p.api_version, DEFAULT_API_VERSION);
    }

    #[test]
    fn creates_without_credential() {
        let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
        assert!(p.credential.is_none());
    }

    // Missing-credential paths must fail before any network I/O happens.
    #[tokio::test]
    async fn chat_fails_without_key() {
        let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
        let result = p.chat_with_system(None, "hello", "gpt-4o", 0.7).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("API key not set"));
    }

    #[tokio::test]
    async fn chat_with_system_fails_without_key() {
        let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
        let result = p
            .chat_with_system(Some("You are Construct"), "test", "gpt-4o", 0.5)
            .await;
        assert!(result.is_err());
    }

    // --- Request serialization --------------------------------------------

    #[test]
    fn request_serializes_with_system_message() {
        let req = ChatRequest {
            messages: vec![
                Message {
                    role: "system".to_string(),
                    content: "You are Construct".to_string(),
                },
                Message {
                    role: "user".to_string(),
                    content: "hello".to_string(),
                },
            ],
            temperature: 0.7,
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"role\":\"system\""));
        assert!(json.contains("\"role\":\"user\""));
        // Azure requests should NOT contain a model field (deployment is in the URL)
        assert!(!json.contains("\"model\""));
    }

    #[test]
    fn request_serializes_without_system() {
        let req = ChatRequest {
            messages: vec![Message {
                role: "user".to_string(),
                content: "hello".to_string(),
            }],
            temperature: 0.0,
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(!json.contains("system"));
        assert!(json.contains("\"temperature\":0.0"));
    }

    // --- Response deserialization -----------------------------------------

    #[test]
    fn response_deserializes_single_choice() {
        let json = r#"{"choices":[{"message":{"content":"Hi!"}}]}"#;
        let resp: ChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.choices.len(), 1);
        assert_eq!(resp.choices[0].message.effective_content(), "Hi!");
    }

    #[test]
    fn response_deserializes_empty_choices() {
        let json = r#"{"choices":[]}"#;
        let resp: ChatResponse = serde_json::from_str(json).unwrap();
        assert!(resp.choices.is_empty());
    }

    #[test]
    fn response_deserializes_multiple_choices() {
        let json = r#"{"choices":[{"message":{"content":"A"}},{"message":{"content":"B"}}]}"#;
        let resp: ChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.choices.len(), 2);
        assert_eq!(resp.choices[0].message.effective_content(), "A");
    }

    // --- Native tool-call parsing -----------------------------------------

    #[test]
    fn tool_call_response_parsing() {
        let json = r#"{"choices":[{"message":{
            "content":"Let me check",
            "tool_calls":[{
                "id":"call_abc123",
                "type":"function",
                "function":{"name":"shell","arguments":"{\"command\":\"ls\"}"}
            }]
        }}],"usage":{"prompt_tokens":50,"completion_tokens":25}}"#;
        let resp: NativeChatResponse = serde_json::from_str(json).unwrap();
        let message = resp.choices.into_iter().next().unwrap().message;
        let parsed = AzureOpenAiProvider::parse_native_response(message);
        assert_eq!(parsed.text.as_deref(), Some("Let me check"));
        assert_eq!(parsed.tool_calls.len(), 1);
        assert_eq!(parsed.tool_calls[0].id, "call_abc123");
        assert_eq!(parsed.tool_calls[0].name, "shell");
        assert!(parsed.tool_calls[0].arguments.contains("ls"));
    }

    #[test]
    fn tool_call_response_without_id_generates_uuid() {
        let json = r#"{"choices":[{"message":{
            "content":null,
            "tool_calls":[{
                "function":{"name":"test","arguments":"{}"}
            }]
        }}]}"#;
        let resp: NativeChatResponse = serde_json::from_str(json).unwrap();
        let message = resp.choices.into_iter().next().unwrap().message;
        let parsed = AzureOpenAiProvider::parse_native_response(message);
        assert_eq!(parsed.tool_calls.len(), 1);
        assert!(!parsed.tool_calls[0].id.is_empty());
    }

    #[tokio::test]
    async fn chat_with_tools_fails_without_key() {
        let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
        let messages = vec![ChatMessage::user("hello".to_string())];
        let tools = vec![serde_json::json!({
            "type": "function",
            "function": {
                "name": "shell",
                "description": "Run a shell command",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "command": { "type": "string" }
                    },
                    "required": ["command"]
                }
            }
        })];
        let result = p.chat_with_tools(&messages, &tools, "gpt-4o", 0.7).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("API key not set"));
    }

    #[test]
    fn native_response_parses_usage() {
        let json = r#"{
            "choices": [{"message": {"content": "Hello"}}],
            "usage": {"prompt_tokens": 100, "completion_tokens": 50}
        }"#;
        let resp: NativeChatResponse = serde_json::from_str(json).unwrap();
        let usage = resp.usage.unwrap();
        assert_eq!(usage.prompt_tokens, Some(100));
        assert_eq!(usage.completion_tokens, Some(50));
    }

    // --- Capability flags / misc -------------------------------------------

    #[test]
    fn capabilities_reports_native_tools_and_vision() {
        let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None);
        let caps = <AzureOpenAiProvider as Provider>::capabilities(&p);
        assert!(caps.native_tool_calling);
        assert!(caps.vision);
    }

    #[test]
    fn supports_native_tools_returns_true() {
        let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None);
        assert!(p.supports_native_tools());
    }

    #[test]
    fn supports_vision_returns_true() {
        let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None);
        assert!(p.supports_vision());
    }

    #[tokio::test]
    async fn warmup_is_noop() {
        let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
        let result = p.warmup().await;
        assert!(result.is_ok());
    }

    #[test]
    fn custom_api_version_stored() {
        let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", Some("2025-01-01"));
        assert_eq!(p.api_version, "2025-01-01");
    }
}