#[cfg(feature = "azure_openai")]
use crate::{
    chat::Tool,
    chat::{ChatMessage, ChatProvider, ChatRole, MessageType, StructuredOutputFormat},
    completion::{CompletionProvider, CompletionRequest, CompletionResponse},
    embedding::EmbeddingProvider,
    error::LLMError,
    stt::SpeechToTextProvider,
    tts::TextToSpeechProvider,
    LLMProvider,
};
use crate::{
    chat::{ChatResponse, ToolChoice},
    FunctionCall, ToolCall,
};
use async_trait::async_trait;
use either::*;
use reqwest::{Client, Url};
use serde::{Deserialize, Serialize};

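/// Client for an Azure OpenAI deployment, implementing chat and embeddings,
/// with stubbed completion, speech-to-text, and text-to-speech providers.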
pub struct AzureOpenAI {
    pub api_key: String,
    pub api_version: String,
    pub base_url: Url,
    pub model: String,
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub system: Option<String>,
    pub timeout_seconds: Option<u64>,
    pub stream: Option<bool>,
    pub top_p: Option<f32>,
    pub top_k: Option<u32>,
    pub tools: Option<Vec<Tool>>,
    pub tool_choice: Option<ToolChoice>,
    pub embedding_encoding_format: Option<String>,
    pub embedding_dimensions: Option<u32>,
    pub reasoning_effort: Option<String>,
    pub json_schema: Option<StructuredOutputFormat>,
    client: Client,
}

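/// A chat message in the request payload sent to Azure OpenAI.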
#[derive(Serialize, Debug)]
struct AzureOpenAIChatMessage<'a> {
    #[allow(dead_code)]
    role: &'a str,
    #[serde(
        skip_serializing_if = "Option::is_none",
        with = "either::serde_untagged_optional"
    )]
    content: Option<Either<Vec<AzureMessageContent<'a>>, String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<AzureOpenAIToolCall<'a>>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

impl<'a> From<&'a ChatMessage> for AzureOpenAIChatMessage<'a> {
    fn from(chat_msg: &'a ChatMessage) -> Self {
        Self {
            role: match chat_msg.role {
                ChatRole::User => "user",
                ChatRole::Assistant => "assistant",
            },
            tool_call_id: None,
            content: match &chat_msg.message_type {
                MessageType::Text => Some(Right(chat_msg.content.clone())),
                // Raw image bytes and PDFs are not supported by this backend.
                MessageType::Image(_) => unreachable!(),
                MessageType::Pdf(_) => unimplemented!(),
                MessageType::ImageURL(url) => {
                    // Image URLs are sent as a single-element content array.
                    Some(Left(vec![AzureMessageContent {
                        message_type: Some("image_url"),
                        text: None,
                        image_url: Some(ImageUrlContent { url }),
                        tool_output: None,
                        tool_call_id: None,
                    }]))
                }
                MessageType::ToolUse(_) => None,
                MessageType::ToolResult(_) => None,
            },
            tool_calls: match &chat_msg.message_type {
                MessageType::ToolUse(calls) => {
                    let owned_calls: Vec<AzureOpenAIToolCall> =
                        calls.iter().map(|c| c.into()).collect();
                    Some(owned_calls)
                }
                _ => None,
            },
        }
    }
}

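/// Function invocation details serialized inside a tool call.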
#[derive(Serialize, Debug)]
struct AzureOpenAIFunctionCall<'a> {
    name: &'a str,
    arguments: &'a str,
}

impl<'a> From<&'a FunctionCall> for AzureOpenAIFunctionCall<'a> {
    fn from(value: &'a FunctionCall) -> Self {
        Self {
            name: &value.name,
            arguments: &value.arguments,
        }
    }
}

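/// A tool call entry in a request message.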
#[derive(Serialize, Debug)]
struct AzureOpenAIToolCall<'a> {
    id: &'a str,
    #[serde(rename = "type")]
    content_type: &'a str,
    function: AzureOpenAIFunctionCall<'a>,
}

impl<'a> From<&'a ToolCall> for AzureOpenAIToolCall<'a> {
    fn from(value: &'a ToolCall) -> Self {
        Self {
            id: &value.id,
            content_type: "function",
            function: AzureOpenAIFunctionCall::from(&value.function),
        }
    }
}

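/// One element of a message's content array: text, an image URL, or a tool
/// output.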
#[derive(Serialize, Debug)]
struct AzureMessageContent<'a> {
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    message_type: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    image_url: Option<ImageUrlContent<'a>>,
    #[serde(skip_serializing_if = "Option::is_none", rename = "tool_call_id")]
    tool_call_id: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none", rename = "content")]
    tool_output: Option<&'a str>,
}

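/// Wrapper for an image URL in message content.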
#[derive(Serialize, Debug)]
struct ImageUrlContent<'a> {
    url: &'a str,
}

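/// Request body for the embeddings endpoint.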
#[derive(Serialize)]
struct OpenAIEmbeddingRequest {
    model: String,
    input: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<u32>,
}

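/// Request body for the chat completions endpoint.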
#[derive(Serialize, Debug)]
struct AzureOpenAIChatRequest<'a> {
    model: &'a str,
    messages: Vec<AzureOpenAIChatMessage<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<OpenAIResponseFormat>,
}

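/// Top-level response from the chat completions endpoint.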
#[derive(Deserialize, Debug)]
struct AzureOpenAIChatResponse {
    choices: Vec<AzureOpenAIChatChoice>,
}

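/// A single choice within a chat response.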
#[derive(Deserialize, Debug)]
struct AzureOpenAIChatChoice {
    message: AzureOpenAIChatMsg,
}

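/// Message payload of a chat choice.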
#[derive(Deserialize, Debug)]
struct AzureOpenAIChatMsg {
    #[allow(dead_code)]
    role: String,
    content: Option<String>,
    tool_calls: Option<Vec<ToolCall>>,
}

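/// A single embedding vector in an embeddings response.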
#[derive(Deserialize, Debug)]
struct AzureOpenAIEmbeddingData {
    embedding: Vec<f32>,
}

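/// Response body from the embeddings endpoint.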
#[derive(Deserialize, Debug)]
struct OpenAIEmbeddingResponse {
    data: Vec<AzureOpenAIEmbeddingData>,
}

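/// Response format type requested from the API: plain text, a JSON schema, or
/// a free-form JSON object.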
#[derive(Deserialize, Debug, Serialize)]
enum OpenAIResponseType {
    #[serde(rename = "text")]
    Text,
    #[serde(rename = "json_schema")]
    JsonSchema,
    #[serde(rename = "json_object")]
    JsonObject,
}

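/// `response_format` payload controlling structured output.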
#[derive(Deserialize, Debug, Serialize)]
struct OpenAIResponseFormat {
    #[serde(rename = "type")]
    response_type: OpenAIResponseType,
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<StructuredOutputFormat>,
}

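/// Builds the `response_format` payload from a `StructuredOutputFormat`,
/// injecting `additionalProperties: false` when the schema omits it, which
/// the API expects for strict JSON schemas.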
impl From<StructuredOutputFormat> for OpenAIResponseFormat {
    fn from(structured_response_format: StructuredOutputFormat) -> Self {
        match structured_response_format.schema {
            None => OpenAIResponseFormat {
                response_type: OpenAIResponseType::JsonSchema,
                json_schema: Some(structured_response_format),
            },
            Some(mut schema) => {
                // Add `additionalProperties: false` if the caller left it out.
                if schema.get("additionalProperties").is_none() {
                    schema["additionalProperties"] = serde_json::json!(false);
                }

                OpenAIResponseFormat {
                    response_type: OpenAIResponseType::JsonSchema,
                    json_schema: Some(StructuredOutputFormat {
                        name: structured_response_format.name,
                        description: structured_response_format.description,
                        schema: Some(schema),
                        strict: structured_response_format.strict,
                    }),
                }
            }
        }
    }
}

impl ChatResponse for AzureOpenAIChatResponse {
    fn text(&self) -> Option<String> {
        self.choices.first().and_then(|c| c.message.content.clone())
    }

    fn tool_calls(&self) -> Option<Vec<ToolCall>> {
        self.choices
            .first()
            .and_then(|c| c.message.tool_calls.clone())
    }
}

impl std::fmt::Display for AzureOpenAIChatResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Guard against an empty `choices` array instead of unwrapping.
        let Some(choice) = self.choices.first() else {
            return Ok(());
        };
        match (&choice.message.content, &choice.message.tool_calls) {
            (Some(content), Some(tool_calls)) => {
                for tool_call in tool_calls {
                    write!(f, "{}", tool_call)?;
                }
                write!(f, "{}", content)
            }
            (Some(content), None) => write!(f, "{}", content),
            (None, Some(tool_calls)) => {
                for tool_call in tool_calls {
                    write!(f, "{}", tool_call)?;
                }
                Ok(())
            }
            (None, None) => Ok(()),
        }
    }
}

impl AzureOpenAI {
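    /// Creates a new Azure OpenAI client for a given deployment.
    ///
    /// `endpoint` is the resource URL and `deployment_id` names the model
    /// deployment; both are combined into `base_url`. A minimal construction
    /// sketch (all values below are placeholders, not a live endpoint):
    ///
    /// ```ignore
    /// let client = AzureOpenAI::new(
    ///     "my-api-key",                           // api_key
    ///     "2024-02-15-preview",                   // api_version
    ///     "my-gpt-4-deployment",                  // deployment_id
    ///     "https://my-resource.openai.azure.com", // endpoint
    ///     Some("gpt-4".to_string()),              // model
    ///     Some(512),                              // max_tokens
    ///     Some(0.7),                              // temperature
    ///     None,                                   // timeout_seconds
    ///     None,                                   // system
    ///     None,                                   // stream
    ///     None, None,                             // top_p, top_k
    ///     None, None,                             // embedding_encoding_format, embedding_dimensions
    ///     None, None,                             // tools, tool_choice
    ///     None,                                   // reasoning_effort
    ///     None,                                   // json_schema
    /// );
    /// ```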
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        api_key: impl Into<String>,
        api_version: impl Into<String>,
        deployment_id: impl Into<String>,
        endpoint: impl Into<String>,
        model: Option<String>,
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        timeout_seconds: Option<u64>,
        system: Option<String>,
        stream: Option<bool>,
        top_p: Option<f32>,
        top_k: Option<u32>,
        embedding_encoding_format: Option<String>,
        embedding_dimensions: Option<u32>,
        tools: Option<Vec<Tool>>,
        tool_choice: Option<ToolChoice>,
        reasoning_effort: Option<String>,
        json_schema: Option<StructuredOutputFormat>,
    ) -> Self {
        let mut builder = Client::builder();
        if let Some(sec) = timeout_seconds {
            builder = builder.timeout(std::time::Duration::from_secs(sec));
        }

        let endpoint = endpoint.into();
        let deployment_id = deployment_id.into();

        Self {
            api_key: api_key.into(),
            api_version: api_version.into(),
            base_url: Url::parse(&format!("{endpoint}/openai/deployments/{deployment_id}/"))
                .expect("Failed to parse base Url"),
            model: model.unwrap_or_else(|| "gpt-3.5-turbo".to_string()),
            max_tokens,
            temperature,
            system,
            timeout_seconds,
            stream,
            top_p,
            top_k,
            tools,
            tool_choice,
            embedding_encoding_format,
            embedding_dimensions,
            client: builder.build().expect("Failed to build reqwest Client"),
            reasoning_effort,
            json_schema,
        }
    }
}

#[async_trait]
impl ChatProvider for AzureOpenAI {
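    /// Sends a chat request with optional tool definitions.
    ///
    /// Tools passed as an argument take precedence over tools configured on
    /// the client.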
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: Option<&[Tool]>,
    ) -> Result<Box<dyn ChatResponse>, LLMError> {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError(
                "Missing Azure OpenAI API key".to_string(),
            ));
        }

        let mut openai_msgs: Vec<AzureOpenAIChatMessage> = vec![];

        for msg in messages {
            if let MessageType::ToolResult(ref results) = msg.message_type {
                // Tool results are flattened into separate "tool" role messages.
                for result in results {
                    openai_msgs.push(AzureOpenAIChatMessage {
                        role: "tool",
                        tool_call_id: Some(result.id.clone()),
                        tool_calls: None,
                        content: Some(Right(result.function.arguments.clone())),
                    });
                }
            } else {
                openai_msgs.push(msg.into())
            }
        }

        // Prepend the system prompt, if one is configured.
        if let Some(system) = &self.system {
            openai_msgs.insert(
                0,
                AzureOpenAIChatMessage {
                    role: "system",
                    content: Some(Left(vec![AzureMessageContent {
                        message_type: Some("text"),
                        text: Some(system),
                        image_url: None,
                        tool_call_id: None,
                        tool_output: None,
                    }])),
                    tool_calls: None,
                    tool_call_id: None,
                },
            );
        }

        let response_format: Option<OpenAIResponseFormat> =
            self.json_schema.clone().map(|s| s.into());

        // Tools passed to this call override the client-level configuration.
        let request_tools = tools.map(|t| t.to_vec()).or_else(|| self.tools.clone());
        let request_tool_choice = if request_tools.is_some() {
            self.tool_choice.clone()
        } else {
            None
        };

        let body = AzureOpenAIChatRequest {
            model: &self.model,
            messages: openai_msgs,
            max_tokens: self.max_tokens,
            temperature: self.temperature,
            stream: self.stream.unwrap_or(false),
            top_p: self.top_p,
            top_k: self.top_k,
            tools: request_tools,
            tool_choice: request_tool_choice,
            reasoning_effort: self.reasoning_effort.clone(),
            response_format,
        };

        if log::log_enabled!(log::Level::Trace) {
            if let Ok(json) = serde_json::to_string(&body) {
                log::trace!("Azure OpenAI request payload: {}", json);
            }
        }

        let mut url = self
            .base_url
            .join("chat/completions")
            .map_err(|e| LLMError::HttpError(e.to_string()))?;

        url.query_pairs_mut()
            .append_pair("api-version", &self.api_version);

        let mut request = self
            .client
            .post(url)
            .header("api-key", &self.api_key)
            .json(&body);

        if let Some(timeout) = self.timeout_seconds {
            request = request.timeout(std::time::Duration::from_secs(timeout));
        }

        let response = request.send().await?;

        log::debug!("Azure OpenAI HTTP status: {}", response.status());

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            return Err(LLMError::ResponseFormatError {
                message: format!("Azure OpenAI API returned error status: {}", status),
                raw_response: error_text,
            });
        }

        let resp_text = response.text().await?;
        let json_resp: Result<AzureOpenAIChatResponse, serde_json::Error> =
            serde_json::from_str(&resp_text);

        match json_resp {
            Ok(response) => Ok(Box::new(response)),
            Err(e) => Err(LLMError::ResponseFormatError {
                message: format!("Failed to decode Azure OpenAI API response: {}", e),
                raw_response: resp_text,
            }),
        }
    }

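    /// Sends a chat request without tools.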
    async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
        self.chat_with_tools(messages, None).await
    }
}

#[async_trait]
impl CompletionProvider for AzureOpenAI {
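    /// Text completion is not yet implemented; returns a placeholder response.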
    async fn complete(&self, _req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
        Ok(CompletionResponse {
            text: "Azure OpenAI completion not implemented.".into(),
        })
    }
}

#[cfg(feature = "azure_openai")]
#[async_trait]
impl EmbeddingProvider for AzureOpenAI {
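    /// Embeds the given inputs via the deployment's embeddings endpoint.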
    async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
        if self.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing Azure OpenAI API key".into()));
        }

        let emb_format = self
            .embedding_encoding_format
            .clone()
            .unwrap_or_else(|| "float".to_string());

        let body = OpenAIEmbeddingRequest {
            model: self.model.clone(),
            input,
            encoding_format: Some(emb_format),
            dimensions: self.embedding_dimensions,
        };

        let mut url = self
            .base_url
            .join("embeddings")
            .map_err(|e| LLMError::HttpError(e.to_string()))?;

        url.query_pairs_mut()
            .append_pair("api-version", &self.api_version);

        let resp = self
            .client
            .post(url)
            .header("api-key", &self.api_key)
            .json(&body)
            .send()
            .await?
            .error_for_status()?;

        let json_resp: OpenAIEmbeddingResponse = resp.json().await?;

        let embeddings = json_resp.data.into_iter().map(|d| d.embedding).collect();
        Ok(embeddings)
    }
}

impl LLMProvider for AzureOpenAI {
    fn tools(&self) -> Option<&[Tool]> {
        self.tools.as_deref()
    }
}

#[async_trait]
impl SpeechToTextProvider for AzureOpenAI {
    async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
        Err(LLMError::ProviderError(
            "Azure OpenAI does not implement the speech-to-text endpoint yet.".into(),
        ))
    }
}

#[async_trait]
impl TextToSpeechProvider for AzureOpenAI {
    async fn speech(&self, _text: &str) -> Result<Vec<u8>, LLMError> {
        Err(LLMError::ProviderError(
            "Azure OpenAI does not implement the text-to-speech endpoint yet.".to_string(),
        ))
    }
}