// llm/backends/azure_openai.rs

1//! Azure OpenAI API client implementation for chat and completion functionality.
2//!
3//! This module provides integration with Azure OpenAI's GPT models through their API.
4
5#[cfg(feature = "azure_openai")]
6use crate::{
7    chat::Tool,
8    chat::{ChatMessage, ChatProvider, ChatRole, MessageType, StructuredOutputFormat},
9    completion::{CompletionProvider, CompletionRequest, CompletionResponse},
10    embedding::EmbeddingProvider,
11    error::LLMError,
12    stt::SpeechToTextProvider,
13    tts::TextToSpeechProvider,
14    LLMProvider,
15};
16use crate::{
17    chat::{ChatResponse, ToolChoice},
18    FunctionCall, ToolCall,
19};
20use async_trait::async_trait;
21use either::*;
22use reqwest::{Client, Url};
23use serde::{Deserialize, Serialize};
24
/// Client for interacting with Azure OpenAI's API.
///
/// Provides methods for chat and completion requests using Azure OpenAI's models.
pub struct AzureOpenAI {
    /// API key sent as the `api-key` header on every request.
    pub api_key: String,
    /// Value of the `api-version` query parameter appended to every request URL.
    pub api_version: String,
    /// Base URL of the form `{endpoint}/openai/deployments/{deployment_id}/`.
    pub base_url: Url,
    /// Model name; defaults to "gpt-3.5-turbo" when not supplied to `new`.
    pub model: String,
    /// Maximum number of tokens to generate, if limited.
    pub max_tokens: Option<u32>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Optional system prompt inserted as the first message of every chat.
    pub system: Option<String>,
    /// Request timeout in seconds (applied to the client builder and again per request).
    pub timeout_seconds: Option<u64>,
    /// Whether responses should be streamed; treated as `false` when unset.
    pub stream: Option<bool>,
    /// Top-p (nucleus) sampling parameter.
    pub top_p: Option<f32>,
    /// Top-k sampling parameter.
    pub top_k: Option<u32>,
    /// Function tools the model may call.
    pub tools: Option<Vec<Tool>>,
    /// Tool-selection strategy; only forwarded when tools are present.
    pub tool_choice: Option<ToolChoice>,
    /// Embedding parameters: output encoding format (defaults to "float" when unset).
    pub embedding_encoding_format: Option<String>,
    /// Embedding parameters: requested number of output dimensions.
    pub embedding_dimensions: Option<u32>,
    /// Reasoning effort level, forwarded verbatim in the chat request.
    pub reasoning_effort: Option<String>,
    /// JSON schema for structured output
    pub json_schema: Option<StructuredOutputFormat>,
    // Reusable HTTP client, built once in `new`.
    client: Client,
}
50
/// Individual message in an OpenAI chat conversation.
#[derive(Serialize, Debug)]
struct AzureOpenAIChatMessage<'a> {
    // Wire role: "user", "assistant", "system", or "tool".
    #[allow(dead_code)]
    role: &'a str,
    // Serialized untagged: either a list of structured content parts
    // (text / image_url) or a plain string.
    #[serde(
        skip_serializing_if = "Option::is_none",
        with = "either::serde_untagged_optional"
    )]
    content: Option<Either<Vec<AzureMessageContent<'a>>, String>>,
    // Tool invocations issued by the assistant, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<AzureOpenAIToolCall<'a>>>,
    // For `role: "tool"` messages: the id of the call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
66
67impl<'a> From<&'a ChatMessage> for AzureOpenAIChatMessage<'a> {
68    fn from(chat_msg: &'a ChatMessage) -> Self {
69        Self {
70            role: match chat_msg.role {
71                ChatRole::User => "user",
72                ChatRole::Assistant => "assistant",
73            },
74            tool_call_id: None,
75            content: match &chat_msg.message_type {
76                MessageType::Text => Some(Right(chat_msg.content.clone())),
77                // Image case is handled separately above
78                MessageType::Image(_) => unreachable!(),
79                MessageType::Pdf(_) => unimplemented!(),
80                MessageType::ImageURL(url) => {
81                    // Clone the URL to create an owned version
82
83                    Some(Left(vec![AzureMessageContent {
84                        message_type: Some("image_url"),
85                        text: None,
86                        image_url: Some(ImageUrlContent { url }),
87                        tool_output: None,
88                        tool_call_id: None,
89                    }]))
90                }
91                MessageType::ToolUse(_) => None,
92                MessageType::ToolResult(_) => None,
93            },
94            tool_calls: match &chat_msg.message_type {
95                MessageType::ToolUse(calls) => {
96                    let owned_calls: Vec<AzureOpenAIToolCall> =
97                        calls.iter().map(|c| c.into()).collect();
98                    Some(owned_calls)
99                }
100                _ => None,
101            },
102        }
103    }
104}
105
/// Borrowed serialization view of a function call: name plus raw argument string.
#[derive(Serialize, Debug)]
struct AzureOpenAIFunctionCall<'a> {
    name: &'a str,
    // Arguments passed through verbatim as a string.
    arguments: &'a str,
}
111
112impl<'a> From<&'a FunctionCall> for AzureOpenAIFunctionCall<'a> {
113    fn from(value: &'a FunctionCall) -> Self {
114        Self {
115            name: &value.name,
116            arguments: &value.arguments,
117        }
118    }
119}
120
/// Borrowed serialization view of a tool call issued by the model.
#[derive(Serialize, Debug)]
struct AzureOpenAIToolCall<'a> {
    id: &'a str,
    // Serialized as "type"; always "function" for function-style calls.
    #[serde(rename = "type")]
    content_type: &'a str,
    function: AzureOpenAIFunctionCall<'a>,
}
128
129impl<'a> From<&'a ToolCall> for AzureOpenAIToolCall<'a> {
130    fn from(value: &'a ToolCall) -> Self {
131        Self {
132            id: &value.id,
133            content_type: "function",
134            function: AzureOpenAIFunctionCall::from(&value.function),
135        }
136    }
137}
138
/// One structured content part of a chat message (text, image URL, or tool output).
#[derive(Serialize, Debug)]
struct AzureMessageContent<'a> {
    // Part discriminator on the wire, e.g. "text" or "image_url".
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    message_type: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    image_url: Option<ImageUrlContent<'a>>,
    // Id of the tool call this part responds to, if any.
    #[serde(skip_serializing_if = "Option::is_none", rename = "tool_call_id")]
    tool_call_id: Option<&'a str>,
    // Tool output payload; serialized under the JSON key "content".
    #[serde(skip_serializing_if = "Option::is_none", rename = "content")]
    tool_output: Option<&'a str>,
}
152
/// Individual image message in an OpenAI chat conversation.
#[derive(Serialize, Debug)]
struct ImageUrlContent<'a> {
    // URL pointing at the image; borrowed from the source message.
    url: &'a str,
}
158
/// Request payload for the embeddings endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest {
    model: String,
    input: Vec<String>,
    // e.g. "float"; omitted from the JSON when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<String>,
    // Requested output dimensionality; omitted when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<u32>,
}
168
/// Request payload for Azure OpenAI's chat API endpoint.
#[derive(Serialize, Debug)]
struct AzureOpenAIChatRequest<'a> {
    model: &'a str,
    messages: Vec<AzureOpenAIChatMessage<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // Always serialized (no skip), defaulting to false upstream.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    // NOTE(review): top_k is serialized when set; confirm the Azure endpoint
    // accepts this parameter, as it is not part of the standard chat schema.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<String>,
    // Structured-output configuration built from the client's json_schema.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<OpenAIResponseFormat>,
}
192
/// Response from OpenAI's chat API endpoint.
#[derive(Deserialize, Debug)]
struct AzureOpenAIChatResponse {
    choices: Vec<AzureOpenAIChatChoice>,
}

/// Individual choice within an OpenAI chat API response.
#[derive(Deserialize, Debug)]
struct AzureOpenAIChatChoice {
    message: AzureOpenAIChatMsg,
}

/// Message content within an OpenAI chat API response.
#[derive(Deserialize, Debug)]
struct AzureOpenAIChatMsg {
    // Deserialized but currently unused.
    #[allow(dead_code)]
    role: String,
    // Text content; may be absent when the model only returns tool calls.
    content: Option<String>,
    // Tool calls requested by the model, if any.
    tool_calls: Option<Vec<ToolCall>>,
}
213
/// One embedding vector in the embeddings response.
#[derive(Deserialize, Debug)]
struct AzureOpenAIEmbeddingData {
    embedding: Vec<f32>,
}
/// Response payload from the embeddings endpoint.
#[derive(Deserialize, Debug)]
struct OpenAIEmbeddingResponse {
    data: Vec<AzureOpenAIEmbeddingData>,
}
222
/// An object specifying the format that the model must output.
/// Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which ensures the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
/// Setting to `{ "type": "json_object" }` enables the older JSON mode, which ensures the message the model generates is valid JSON. Using `json_schema` is preferred for models that support it.
#[derive(Deserialize, Debug, Serialize)]
enum OpenAIResponseType {
    #[serde(rename = "text")]
    Text,
    #[serde(rename = "json_schema")]
    JsonSchema,
    #[serde(rename = "json_object")]
    JsonObject,
}
235
/// `response_format` object sent in chat requests to request structured output.
#[derive(Deserialize, Debug, Serialize)]
struct OpenAIResponseFormat {
    // Serialized as "type": "text" | "json_schema" | "json_object".
    #[serde(rename = "type")]
    response_type: OpenAIResponseType,
    // Present only for the "json_schema" response type.
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<StructuredOutputFormat>,
}
243
244impl From<StructuredOutputFormat> for OpenAIResponseFormat {
245    /// Modify the schema to ensure that it meets OpenAI's requirements.
246    fn from(structured_response_format: StructuredOutputFormat) -> Self {
247        // It's possible to pass a StructuredOutputJsonSchema without an actual schema.
248        // In this case, just pass the StructuredOutputJsonSchema object without modifying it.
249        match structured_response_format.schema {
250            None => OpenAIResponseFormat {
251                response_type: OpenAIResponseType::JsonSchema,
252                json_schema: Some(structured_response_format),
253            },
254            Some(mut schema) => {
255                // Although [OpenAI's specifications](https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat#additionalproperties-false-must-always-be-set-in-objects) say that the "additionalProperties" field is required, my testing shows that it is not.
256                // Just to be safe, add it to the schema if it is missing.
257                schema = if schema.get("additionalProperties").is_none() {
258                    schema["additionalProperties"] = serde_json::json!(false);
259                    schema
260                } else {
261                    schema
262                };
263
264                OpenAIResponseFormat {
265                    response_type: OpenAIResponseType::JsonSchema,
266                    json_schema: Some(StructuredOutputFormat {
267                        name: structured_response_format.name,
268                        description: structured_response_format.description,
269                        schema: Some(schema),
270                        strict: structured_response_format.strict,
271                    }),
272                }
273            }
274        }
275    }
276}
277
278impl ChatResponse for AzureOpenAIChatResponse {
279    fn text(&self) -> Option<String> {
280        self.choices.first().and_then(|c| c.message.content.clone())
281    }
282
283    fn tool_calls(&self) -> Option<Vec<ToolCall>> {
284        self.choices
285            .first()
286            .and_then(|c| c.message.tool_calls.clone())
287    }
288}
289
290impl std::fmt::Display for AzureOpenAIChatResponse {
291    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
292        match (
293            &self.choices.first().unwrap().message.content,
294            &self.choices.first().unwrap().message.tool_calls,
295        ) {
296            (Some(content), Some(tool_calls)) => {
297                for tool_call in tool_calls {
298                    write!(f, "{}", tool_call)?;
299                }
300                write!(f, "{}", content)
301            }
302            (Some(content), None) => write!(f, "{}", content),
303            (None, Some(tool_calls)) => {
304                for tool_call in tool_calls {
305                    write!(f, "{}", tool_call)?;
306                }
307                Ok(())
308            }
309            (None, None) => write!(f, ""),
310        }
311    }
312}
313
314impl AzureOpenAI {
315    /// Creates a new OpenAI client with the specified configuration.
316    ///
317    /// # Arguments
318    ///
319    /// * `api_key` - OpenAI API key
320    /// * `model` - Model to use (defaults to "gpt-3.5-turbo")
321    /// * `max_tokens` - Maximum tokens to generate
322    /// * `temperature` - Sampling temperature
323    /// * `timeout_seconds` - Request timeout in seconds
324    /// * `system` - System prompt
325    /// * `stream` - Whether to stream responses
326    /// * `top_p` - Top-p sampling parameter
327    /// * `top_k` - Top-k sampling parameter
328    /// * `embedding_encoding_format` - Format for embedding outputs
329    /// * `embedding_dimensions` - Dimensions for embedding vectors
330    /// * `tools` - Function tools that the model can use
331    /// * `tool_choice` - Determines how the model uses tools
332    /// * `reasoning_effort` - Reasoning effort level
333    /// * `json_schema` - JSON schema for structured output
334    #[allow(clippy::too_many_arguments)]
335    pub fn new(
336        api_key: impl Into<String>,
337        api_version: impl Into<String>,
338        deployment_id: impl Into<String>,
339        endpoint: impl Into<String>,
340        model: Option<String>,
341        max_tokens: Option<u32>,
342        temperature: Option<f32>,
343        timeout_seconds: Option<u64>,
344        system: Option<String>,
345        stream: Option<bool>,
346        top_p: Option<f32>,
347        top_k: Option<u32>,
348        embedding_encoding_format: Option<String>,
349        embedding_dimensions: Option<u32>,
350        tools: Option<Vec<Tool>>,
351        tool_choice: Option<ToolChoice>,
352        reasoning_effort: Option<String>,
353        json_schema: Option<StructuredOutputFormat>,
354    ) -> Self {
355        let mut builder = Client::builder();
356        if let Some(sec) = timeout_seconds {
357            builder = builder.timeout(std::time::Duration::from_secs(sec));
358        }
359
360        let endpoint = endpoint.into();
361        let deployment_id = deployment_id.into();
362
363        Self {
364            api_key: api_key.into(),
365            api_version: api_version.into(),
366            base_url: Url::parse(&format!("{endpoint}/openai/deployments/{deployment_id}/"))
367                .expect("Failed to parse base Url"),
368            model: model.unwrap_or("gpt-3.5-turbo".to_string()),
369            max_tokens,
370            temperature,
371            system,
372            timeout_seconds,
373            stream,
374            top_p,
375            top_k,
376            tools,
377            tool_choice,
378            embedding_encoding_format,
379            embedding_dimensions,
380            client: builder.build().expect("Failed to build reqwest Client"),
381            reasoning_effort,
382            json_schema,
383        }
384    }
385}
386
387#[async_trait]
388impl ChatProvider for AzureOpenAI {
389    /// Sends a chat request to OpenAI's API.
390    ///
391    /// # Arguments
392    ///
393    /// * `messages` - Slice of chat messages representing the conversation
394    /// * `tools` - Optional slice of tools to use in the chat
395    /// # Returns
396    ///
397    /// The model's response text or an error
398    async fn chat_with_tools(
399        &self,
400        messages: &[ChatMessage],
401        tools: Option<&[Tool]>,
402    ) -> Result<Box<dyn ChatResponse>, LLMError> {
403        if self.api_key.is_empty() {
404            return Err(LLMError::AuthError(
405                "Missing Azure OpenAI API key".to_string(),
406            ));
407        }
408
409        let mut openai_msgs: Vec<AzureOpenAIChatMessage> = vec![];
410
411        for msg in messages {
412            if let MessageType::ToolResult(ref results) = msg.message_type {
413                for result in results {
414                    openai_msgs.push(
415                        // Clone strings to own them
416                        AzureOpenAIChatMessage {
417                            role: "tool",
418                            tool_call_id: Some(result.id.clone()),
419                            tool_calls: None,
420                            content: Some(Right(result.function.arguments.clone())),
421                        },
422                    );
423                }
424            } else {
425                openai_msgs.push(msg.into())
426            }
427        }
428
429        if let Some(system) = &self.system {
430            openai_msgs.insert(
431                0,
432                AzureOpenAIChatMessage {
433                    role: "system",
434                    content: Some(Left(vec![AzureMessageContent {
435                        message_type: Some("text"),
436                        text: Some(system),
437                        image_url: None,
438                        tool_call_id: None,
439                        tool_output: None,
440                    }])),
441                    tool_calls: None,
442                    tool_call_id: None,
443                },
444            );
445        }
446
447        // Build the response format object
448        let response_format: Option<OpenAIResponseFormat> =
449            self.json_schema.clone().map(|s| s.into());
450
451        let request_tools = tools.map(|t| t.to_vec()).or_else(|| self.tools.clone());
452        let request_tool_choice = if request_tools.is_some() {
453            self.tool_choice.clone()
454        } else {
455            None
456        };
457
458        let body = AzureOpenAIChatRequest {
459            model: &self.model,
460            messages: openai_msgs,
461            max_tokens: self.max_tokens,
462            temperature: self.temperature,
463            stream: self.stream.unwrap_or(false),
464            top_p: self.top_p,
465            top_k: self.top_k,
466            tools: request_tools,
467            tool_choice: request_tool_choice,
468            reasoning_effort: self.reasoning_effort.clone(),
469            response_format,
470        };
471
472        if log::log_enabled!(log::Level::Trace) {
473            if let Ok(json) = serde_json::to_string(&body) {
474                log::trace!("Azure OpenAI request payload: {}", json);
475            }
476        }
477
478        let mut url = self
479            .base_url
480            .join("chat/completions")
481            .map_err(|e| LLMError::HttpError(e.to_string()))?;
482
483        url.query_pairs_mut()
484            .append_pair("api-version", &self.api_version);
485
486        let mut request = self
487            .client
488            .post(url)
489            .header("api-key", &self.api_key)
490            .json(&body);
491
492        if let Some(timeout) = self.timeout_seconds {
493            request = request.timeout(std::time::Duration::from_secs(timeout));
494        }
495
496        // Send the request
497        let response = request.send().await?;
498
499        log::debug!("Azure OpenAI HTTP status: {}", response.status());
500
501        // If we got a non-200 response, let's get the error details
502        if !response.status().is_success() {
503            let status = response.status();
504            let error_text = response.text().await?;
505            return Err(LLMError::ResponseFormatError {
506                message: format!("OpenAI API returned error status: {}", status),
507                raw_response: error_text,
508            });
509        }
510
511        // Parse the successful response
512        let resp_text = response.text().await?;
513        let json_resp: Result<AzureOpenAIChatResponse, serde_json::Error> =
514            serde_json::from_str(&resp_text);
515
516        match json_resp {
517            Ok(response) => Ok(Box::new(response)),
518            Err(e) => Err(LLMError::ResponseFormatError {
519                message: format!("Failed to decode Azure OpenAI API response: {}", e),
520                raw_response: resp_text,
521            }),
522        }
523    }
524
525    async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
526        self.chat_with_tools(messages, None).await
527    }
528}
529
530#[async_trait]
531impl CompletionProvider for AzureOpenAI {
532    /// Sends a completion request to OpenAI's API.
533    ///
534    /// Currently not implemented.
535    async fn complete(&self, _req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
536        Ok(CompletionResponse {
537            text: "OpenAI completion not implemented.".into(),
538        })
539    }
540}
541
#[cfg(feature = "azure_openai")]
#[async_trait]
impl EmbeddingProvider for AzureOpenAI {
    /// Requests embedding vectors for `input` from the deployment's
    /// `embeddings` endpoint.
    ///
    /// # Errors
    ///
    /// Returns `LLMError::AuthError` when no API key is configured, and an
    /// HTTP error for non-success responses (via `error_for_status`).
    async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
        if self.api_key.is_empty() {
            // Message aligned with chat_with_tools for consistency.
            return Err(LLMError::AuthError("Missing Azure OpenAI API key".into()));
        }

        // Default to "float" when no encoding format was configured.
        let emb_format = self
            .embedding_encoding_format
            .clone()
            .unwrap_or_else(|| "float".to_string());

        let body = OpenAIEmbeddingRequest {
            model: self.model.clone(),
            input,
            encoding_format: Some(emb_format),
            dimensions: self.embedding_dimensions,
        };

        let mut url = self
            .base_url
            .join("embeddings")
            .map_err(|e| LLMError::HttpError(e.to_string()))?;

        url.query_pairs_mut()
            .append_pair("api-version", &self.api_version);

        let resp = self
            .client
            .post(url)
            .header("api-key", &self.api_key)
            .json(&body)
            .send()
            .await?
            .error_for_status()?;

        let json_resp: OpenAIEmbeddingResponse = resp.json().await?;

        let embeddings = json_resp.data.into_iter().map(|d| d.embedding).collect();
        Ok(embeddings)
    }
}
585
impl LLMProvider for AzureOpenAI {
    /// Exposes the client's configured function tools, if any.
    fn tools(&self) -> Option<&[Tool]> {
        self.tools.as_deref()
    }
}
591
592#[async_trait]
593impl SpeechToTextProvider for AzureOpenAI {
594    async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
595        Err(LLMError::ProviderError(
596            "Azure OpenAI does not implement speech to text endpoint yet.".into(),
597        ))
598    }
599}
600
601#[async_trait]
602impl TextToSpeechProvider for AzureOpenAI {
603    async fn speech(&self, _text: &str) -> Result<Vec<u8>, LLMError> {
604        Err(LLMError::ProviderError(
605            "Text to speech not supported".to_string(),
606        ))
607    }
608}