use std::{collections::HashMap, sync::Arc};
#[cfg(feature = "azure_openai")]
use crate::{
builder::LLMBackend,
chat::Tool,
chat::{
ChatMessage, ChatProvider, ChatRole, MessageType, StreamChunk, StreamResponse,
StructuredOutputFormat,
},
completion::{CompletionProvider, CompletionRequest, CompletionResponse},
embedding::EmbeddingProvider,
error::LLMError,
models::{ModelListRequest, ModelListResponse, ModelsProvider, StandardModelListResponse},
stt::SpeechToTextProvider,
tts::TextToSpeechProvider,
LLMProvider,
};
#[cfg(feature = "azure_openai")]
use futures::{Stream, StreamExt};
use crate::{
chat::{ChatResponse, ToolChoice},
FunctionCall, ToolCall,
};
use async_trait::async_trait;
use either::*;
use reqwest::{Client, Url};
use serde::{Deserialize, Serialize};
#[derive(Debug)]
pub struct AzureOpenAIConfig {
pub api_key: String,
pub api_version: Option<String>,
pub base_url: Url,
pub model: String,
pub max_tokens: Option<u32>,
pub temperature: Option<f32>,
pub system: Option<String>,
pub timeout_seconds: Option<u64>,
pub top_p: Option<f32>,
pub top_k: Option<u32>,
pub tools: Option<Vec<Tool>>,
pub tool_choice: Option<ToolChoice>,
pub embedding_encoding_format: Option<String>,
pub embedding_dimensions: Option<u32>,
pub reasoning_effort: Option<String>,
pub json_schema: Option<StructuredOutputFormat>,
}
#[derive(Debug, Clone)]
pub struct AzureOpenAI {
pub config: Arc<AzureOpenAIConfig>,
pub client: Client,
}
#[derive(Serialize, Debug)]
struct AzureOpenAIChatMessage<'a> {
#[allow(dead_code)]
role: &'a str,
#[serde(
skip_serializing_if = "Option::is_none",
with = "either::serde_untagged_optional"
)]
content: Option<Either<Vec<AzureMessageContent<'a>>, String>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<AzureOpenAIToolCall<'a>>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_call_id: Option<String>,
}
impl<'a> From<&'a ChatMessage> for AzureOpenAIChatMessage<'a> {
fn from(chat_msg: &'a ChatMessage) -> Self {
Self {
role: match chat_msg.role {
ChatRole::User => "user",
ChatRole::Assistant => "assistant",
},
tool_call_id: None,
content: match &chat_msg.message_type {
MessageType::Text => Some(Right(chat_msg.content.clone())),
MessageType::Image(_) => unreachable!(),
MessageType::Pdf(_) => unimplemented!(),
MessageType::Audio(_) => None,
MessageType::ImageURL(url) => {
Some(Left(vec![AzureMessageContent {
message_type: Some("image_url"),
text: None,
image_url: Some(ImageUrlContent { url }),
tool_output: None,
tool_call_id: None,
}]))
}
MessageType::ToolUse(_) => None,
MessageType::ToolResult(_) => None,
},
tool_calls: match &chat_msg.message_type {
MessageType::ToolUse(calls) => {
let owned_calls: Vec<AzureOpenAIToolCall> =
calls.iter().map(|c| c.into()).collect();
Some(owned_calls)
}
_ => None,
},
}
}
}
#[derive(Serialize, Debug)]
struct AzureOpenAIFunctionCall<'a> {
name: &'a str,
arguments: &'a str,
}
impl<'a> From<&'a FunctionCall> for AzureOpenAIFunctionCall<'a> {
fn from(value: &'a FunctionCall) -> Self {
Self {
name: &value.name,
arguments: &value.arguments,
}
}
}
#[derive(Serialize, Debug)]
struct AzureOpenAIToolCall<'a> {
id: &'a str,
#[serde(rename = "type")]
content_type: &'a str,
function: AzureOpenAIFunctionCall<'a>,
}
impl<'a> From<&'a ToolCall> for AzureOpenAIToolCall<'a> {
fn from(value: &'a ToolCall) -> Self {
Self {
id: &value.id,
content_type: "function",
function: AzureOpenAIFunctionCall::from(&value.function),
}
}
}
#[derive(Serialize, Debug)]
struct AzureMessageContent<'a> {
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
message_type: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
text: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
image_url: Option<ImageUrlContent<'a>>,
#[serde(skip_serializing_if = "Option::is_none", rename = "tool_call_id")]
tool_call_id: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none", rename = "content")]
tool_output: Option<&'a str>,
}
#[derive(Serialize, Debug)]
struct ImageUrlContent<'a> {
url: &'a str,
}
#[derive(Serialize)]
struct OpenAIEmbeddingRequest {
model: String,
input: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
encoding_format: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
dimensions: Option<u32>,
}
#[derive(Serialize, Debug)]
struct AzureOpenAIChatRequest<'a> {
model: &'a str,
messages: Vec<AzureOpenAIChatMessage<'a>>,
#[serde(rename = "max_completion_tokens", skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_k: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<Tool>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<ToolChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_effort: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
response_format: Option<OpenAIResponseFormat>,
}
#[derive(Deserialize, Debug)]
struct AzureOpenAIChatResponse {
choices: Vec<AzureOpenAIChatChoice>,
}
#[derive(Deserialize, Debug)]
struct AzureOpenAIChatChoice {
message: AzureOpenAIChatMsg,
}
#[derive(Deserialize, Debug)]
struct AzureOpenAIChatMsg {
#[allow(dead_code)]
role: String,
content: Option<String>,
tool_calls: Option<Vec<ToolCall>>,
}
#[derive(Deserialize, Debug)]
struct AzureOpenAIEmbeddingData {
embedding: Vec<f32>,
}
#[derive(Deserialize, Debug)]
struct OpenAIEmbeddingResponse {
data: Vec<AzureOpenAIEmbeddingData>,
}
#[derive(Debug, Default)]
struct AzureToolUseState {
id: String,
name: String,
arguments_buffer: String,
started: bool,
}
#[derive(Debug, Deserialize)]
struct AzureToolStreamChunk {
choices: Vec<AzureToolStreamChoice>,
}
#[derive(Debug, Deserialize)]
struct AzureToolStreamChoice {
delta: AzureToolStreamDelta,
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct AzureToolStreamDelta {
content: Option<String>,
tool_calls: Option<Vec<AzureToolStreamToolCall>>,
}
#[derive(Debug, Deserialize)]
struct AzureToolStreamToolCall {
index: Option<usize>,
id: Option<String>,
function: AzureToolStreamFunction,
}
#[derive(Debug, Deserialize)]
struct AzureToolStreamFunction {
name: Option<String>,
#[serde(default)]
arguments: String,
}
#[derive(Deserialize, Debug, Serialize)]
enum OpenAIResponseType {
#[serde(rename = "text")]
Text,
#[serde(rename = "json_schema")]
JsonSchema,
#[serde(rename = "json_object")]
JsonObject,
}
#[derive(Deserialize, Debug, Serialize)]
struct OpenAIResponseFormat {
#[serde(rename = "type")]
response_type: OpenAIResponseType,
#[serde(skip_serializing_if = "Option::is_none")]
json_schema: Option<StructuredOutputFormat>,
}
impl From<StructuredOutputFormat> for OpenAIResponseFormat {
fn from(structured_response_format: StructuredOutputFormat) -> Self {
match structured_response_format.schema {
None => OpenAIResponseFormat {
response_type: OpenAIResponseType::JsonSchema,
json_schema: Some(structured_response_format),
},
Some(mut schema) => {
schema = if schema.get("additionalProperties").is_none() {
schema["additionalProperties"] = serde_json::json!(false);
schema
} else {
schema
};
OpenAIResponseFormat {
response_type: OpenAIResponseType::JsonSchema,
json_schema: Some(StructuredOutputFormat {
name: structured_response_format.name,
description: structured_response_format.description,
schema: Some(schema),
strict: structured_response_format.strict,
}),
}
}
}
}
}
impl ChatResponse for AzureOpenAIChatResponse {
fn text(&self) -> Option<String> {
self.choices.first().and_then(|c| c.message.content.clone())
}
fn tool_calls(&self) -> Option<Vec<ToolCall>> {
self.choices
.first()
.and_then(|c| c.message.tool_calls.clone())
}
}
impl std::fmt::Display for AzureOpenAIChatResponse {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Some(choice) = self.choices.first() else {
return Ok(());
};
match (&choice.message.content, &choice.message.tool_calls) {
(Some(content), Some(tool_calls)) => {
for tool_call in tool_calls {
write!(f, "{tool_call}")?;
}
write!(f, "{content}")
}
(Some(content), None) => write!(f, "{content}"),
(None, Some(tool_calls)) => {
for tool_call in tool_calls {
write!(f, "{tool_call}")?;
}
Ok(())
}
(None, None) => Ok(()),
}
}
}
impl AzureOpenAI {
#[allow(clippy::too_many_arguments)]
pub fn new(
api_key: impl Into<String>,
api_version: Option<String>,
deployment_id: impl Into<String>,
endpoint: impl Into<String>,
model: Option<String>,
max_tokens: Option<u32>,
temperature: Option<f32>,
timeout_seconds: Option<u64>,
system: Option<String>,
top_p: Option<f32>,
top_k: Option<u32>,
embedding_encoding_format: Option<String>,
embedding_dimensions: Option<u32>,
tools: Option<Vec<Tool>>,
tool_choice: Option<ToolChoice>,
reasoning_effort: Option<String>,
json_schema: Option<StructuredOutputFormat>,
) -> Self {
let mut builder = Client::builder();
if let Some(sec) = timeout_seconds {
builder = builder.timeout(std::time::Duration::from_secs(sec));
}
Self::with_client(
builder.build().expect("Failed to build reqwest Client"),
api_key,
api_version,
deployment_id,
endpoint,
model,
max_tokens,
temperature,
timeout_seconds,
system,
top_p,
top_k,
embedding_encoding_format,
embedding_dimensions,
tools,
tool_choice,
reasoning_effort,
json_schema,
)
}
#[allow(clippy::too_many_arguments)]
pub fn with_client(
client: Client,
api_key: impl Into<String>,
api_version: Option<String>,
deployment_id: impl Into<String>,
endpoint: impl Into<String>,
model: Option<String>,
max_tokens: Option<u32>,
temperature: Option<f32>,
timeout_seconds: Option<u64>,
system: Option<String>,
top_p: Option<f32>,
top_k: Option<u32>,
embedding_encoding_format: Option<String>,
embedding_dimensions: Option<u32>,
tools: Option<Vec<Tool>>,
tool_choice: Option<ToolChoice>,
reasoning_effort: Option<String>,
json_schema: Option<StructuredOutputFormat>,
) -> Self {
let endpoint = endpoint.into();
let deployment_id = deployment_id.into();
Self {
config: Arc::new(AzureOpenAIConfig {
api_key: api_key.into(),
api_version,
base_url: Url::parse(&format!("{endpoint}/openai/v1/"))
.expect("Failed to parse base Url"),
model: model.unwrap_or(deployment_id),
max_tokens,
temperature,
system,
timeout_seconds,
top_p,
top_k,
tools,
tool_choice,
embedding_encoding_format,
embedding_dimensions,
reasoning_effort,
json_schema,
}),
client,
}
}
pub fn api_key(&self) -> &str {
&self.config.api_key
}
pub fn api_version(&self) -> &Option<String> {
&self.config.api_version
}
pub fn base_url(&self) -> &Url {
&self.config.base_url
}
pub fn model(&self) -> &str {
&self.config.model
}
pub fn max_tokens(&self) -> Option<u32> {
self.config.max_tokens
}
pub fn temperature(&self) -> Option<f32> {
self.config.temperature
}
pub fn timeout_seconds(&self) -> Option<u64> {
self.config.timeout_seconds
}
pub fn system(&self) -> Option<&str> {
self.config.system.as_deref()
}
pub fn top_p(&self) -> Option<f32> {
self.config.top_p
}
pub fn top_k(&self) -> Option<u32> {
self.config.top_k
}
pub fn tools(&self) -> Option<&[Tool]> {
self.config.tools.as_deref()
}
pub fn tool_choice(&self) -> Option<&ToolChoice> {
self.config.tool_choice.as_ref()
}
pub fn embedding_encoding_format(&self) -> Option<&str> {
self.config.embedding_encoding_format.as_deref()
}
pub fn embedding_dimensions(&self) -> Option<u32> {
self.config.embedding_dimensions
}
pub fn reasoning_effort(&self) -> Option<&str> {
self.config.reasoning_effort.as_deref()
}
pub fn json_schema(&self) -> Option<&StructuredOutputFormat> {
self.config.json_schema.as_ref()
}
pub fn client(&self) -> &Client {
&self.client
}
}
#[cfg(feature = "azure_openai")]
fn parse_azure_sse_chunk_with_tools(
event: &str,
tool_states: &mut HashMap<usize, AzureToolUseState>,
) -> Result<Vec<StreamChunk>, LLMError> {
let mut results = Vec::new();
for line in event.lines() {
let line = line.trim();
let Some(data) = line.strip_prefix("data: ") else {
continue;
};
if data == "[DONE]" {
finish_azure_tool_calls(&mut results, tool_states);
results.push(StreamChunk::Done {
stop_reason: "end_turn".to_string(),
});
return Ok(results);
}
if let Ok(chunk) = serde_json::from_str::<AzureToolStreamChunk>(data) {
for choice in &chunk.choices {
if let Some(content) = &choice.delta.content {
if !content.is_empty() {
results.push(StreamChunk::Text(content.clone()));
}
}
if let Some(tool_calls) = &choice.delta.tool_calls {
for tc in tool_calls {
let index = tc.index.unwrap_or(0);
let state = tool_states.entry(index).or_default();
if let Some(id) = &tc.id {
state.id = id.clone();
}
if let Some(name) = &tc.function.name {
state.name = name.clone();
if !state.started {
state.started = true;
results.push(StreamChunk::ToolUseStart {
index,
id: state.id.clone(),
name: state.name.clone(),
});
}
}
if !tc.function.arguments.is_empty() {
state.arguments_buffer.push_str(&tc.function.arguments);
results.push(StreamChunk::ToolUseInputDelta {
index,
partial_json: tc.function.arguments.clone(),
});
}
}
}
if let Some(finish_reason) = &choice.finish_reason {
finish_azure_tool_calls(&mut results, tool_states);
let stop_reason = match finish_reason.as_str() {
"tool_calls" => "tool_use",
"stop" => "end_turn",
other => other,
};
results.push(StreamChunk::Done {
stop_reason: stop_reason.to_string(),
});
}
}
}
}
Ok(results)
}
#[cfg(feature = "azure_openai")]
fn finish_azure_tool_calls(
results: &mut Vec<StreamChunk>,
tool_states: &mut HashMap<usize, AzureToolUseState>,
) {
for (index, state) in tool_states.drain() {
if state.started {
results.push(StreamChunk::ToolUseComplete {
index,
tool_call: ToolCall {
id: state.id,
call_type: "function".to_string(),
function: FunctionCall {
name: state.name,
arguments: state.arguments_buffer,
},
},
});
}
}
}
#[cfg(feature = "azure_openai")]
fn create_azure_sse_stream_with_tools(
response: reqwest::Response,
) -> std::pin::Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>> {
let bytes_stream = response.bytes_stream();
let stream = bytes_stream
.scan(
(String::new(), HashMap::<usize, AzureToolUseState>::new()),
|(event_buffer, tool_states), chunk| {
let results = match chunk {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes);
let mut results = Vec::new();
for line in text.lines() {
let line = line.trim_end();
if line.is_empty() {
if !event_buffer.is_empty() {
match parse_azure_sse_chunk_with_tools(event_buffer, tool_states)
{
Ok(chunks) => results.extend(chunks.into_iter().map(Ok)),
Err(e) => results.push(Err(e)),
}
event_buffer.clear();
}
} else {
event_buffer.push_str(line);
event_buffer.push('\n');
}
}
results
}
Err(e) => vec![Err(LLMError::HttpError(e.to_string()))],
};
futures::future::ready(Some(results))
},
)
.flat_map(futures::stream::iter);
Box::pin(stream)
}
const AUDIO_UNSUPPORTED: &str = "Audio messages are not supported by Azure OpenAI chat";
#[async_trait]
impl ChatProvider for AzureOpenAI {
async fn chat_with_tools(
&self,
messages: &[ChatMessage],
tools: Option<&[Tool]>,
) -> Result<Box<dyn ChatResponse>, LLMError> {
crate::chat::ensure_no_audio(messages, AUDIO_UNSUPPORTED)?;
if self.config.api_key.is_empty() {
return Err(LLMError::AuthError(
"Missing Azure OpenAI API key".to_string(),
));
}
let mut openai_msgs: Vec<AzureOpenAIChatMessage> = vec![];
for msg in messages {
if let MessageType::ToolResult(ref results) = msg.message_type {
for result in results {
openai_msgs.push(
AzureOpenAIChatMessage {
role: "tool",
tool_call_id: Some(result.id.clone()),
tool_calls: None,
content: Some(Right(result.function.arguments.clone())),
},
);
}
} else {
openai_msgs.push(msg.into())
}
}
if let Some(system) = &self.config.system {
openai_msgs.insert(
0,
AzureOpenAIChatMessage {
role: "system",
content: Some(Left(vec![AzureMessageContent {
message_type: Some("text"),
text: Some(system),
image_url: None,
tool_call_id: None,
tool_output: None,
}])),
tool_calls: None,
tool_call_id: None,
},
);
}
let response_format: Option<OpenAIResponseFormat> =
self.config.json_schema.clone().map(|s| s.into());
let request_tools = tools
.map(|t| t.to_vec())
.or_else(|| self.config.tools.clone());
let request_tool_choice = if request_tools.is_some() {
self.config.tool_choice.clone()
} else {
None
};
let body = AzureOpenAIChatRequest {
model: &self.config.model,
messages: openai_msgs,
max_tokens: self.config.max_tokens,
temperature: self.config.temperature,
stream: false,
top_p: self.config.top_p,
top_k: self.config.top_k,
tools: request_tools,
tool_choice: request_tool_choice,
reasoning_effort: self.config.reasoning_effort.clone(),
response_format,
};
if log::log_enabled!(log::Level::Trace) {
if let Ok(json) = serde_json::to_string(&body) {
log::trace!("Azure OpenAI request payload: {}", json);
}
}
let mut url = self
.config
.base_url
.join("chat/completions")
.map_err(|e| LLMError::HttpError(e.to_string()))?;
if let Some(api_version) = &self.config.api_version {
url.query_pairs_mut()
.append_pair("api-version", api_version);
}
log::info!("Azure OpenAI HTTP Request {}", url);
let mut request = self
.client
.post(url)
.header("api-key", &self.config.api_key)
.json(&body);
if let Some(timeout) = self.config.timeout_seconds {
request = request.timeout(std::time::Duration::from_secs(timeout));
}
let response = request.send().await?;
log::debug!("Azure OpenAI HTTP status: {}", response.status());
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await?;
return Err(LLMError::ResponseFormatError {
message: format!("Azure OpenAI API returned error status: {status}"),
raw_response: error_text,
});
}
let resp_text = response.text().await?;
let json_resp: Result<AzureOpenAIChatResponse, serde_json::Error> =
serde_json::from_str(&resp_text);
match json_resp {
Ok(response) => Ok(Box::new(response)),
Err(e) => Err(LLMError::ResponseFormatError {
message: format!("Failed to decode Azure OpenAI API response: {e}"),
raw_response: resp_text,
}),
}
}
async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
self.chat_with_tools(messages, None).await
}
async fn chat_stream(
&self,
messages: &[ChatMessage],
) -> Result<
std::pin::Pin<Box<dyn futures::Stream<Item = Result<String, LLMError>> + Send>>,
LLMError,
> {
let struct_stream = self.chat_stream_struct(messages).await?;
let content_stream = struct_stream.filter_map(|result| async move {
match result {
Ok(stream_response) => {
if let Some(choice) = stream_response.choices.first() {
if let Some(content) = &choice.delta.content {
if !content.is_empty() {
return Some(Ok(content.clone()));
}
}
}
None
}
Err(e) => Some(Err(e)),
}
});
Ok(Box::pin(content_stream))
}
async fn chat_stream_struct(
&self,
messages: &[ChatMessage],
) -> Result<
std::pin::Pin<Box<dyn futures::Stream<Item = Result<StreamResponse, LLMError>> + Send>>,
LLMError,
> {
use crate::providers::openai_compatible::create_sse_stream;
crate::chat::ensure_no_audio(messages, AUDIO_UNSUPPORTED)?;
if self.config.api_key.is_empty() {
return Err(LLMError::AuthError(
"Missing Azure OpenAI API key".to_string(),
));
}
let mut openai_msgs: Vec<AzureOpenAIChatMessage> = vec![];
for msg in messages {
if let MessageType::ToolResult(ref results) = msg.message_type {
for result in results {
openai_msgs.push(AzureOpenAIChatMessage {
role: "tool",
tool_call_id: Some(result.id.clone()),
tool_calls: None,
content: Some(either::Right(result.function.arguments.clone())),
});
}
} else {
openai_msgs.push(msg.into())
}
}
if let Some(system) = &self.config.system {
openai_msgs.insert(
0,
AzureOpenAIChatMessage {
role: "system",
content: Some(either::Left(vec![AzureMessageContent {
message_type: Some("text"),
text: Some(system),
image_url: None,
tool_call_id: None,
tool_output: None,
}])),
tool_calls: None,
tool_call_id: None,
},
);
}
let response_format: Option<OpenAIResponseFormat> =
self.config.json_schema.clone().map(|s| s.into());
let request_tools = self.config.tools.clone();
let request_tool_choice = if request_tools.is_some() {
self.config.tool_choice.clone()
} else {
None
};
let body = AzureOpenAIChatRequest {
model: &self.config.model,
messages: openai_msgs,
max_tokens: self.config.max_tokens,
temperature: self.config.temperature,
stream: true,
top_p: self.config.top_p,
top_k: self.config.top_k,
tools: request_tools,
tool_choice: request_tool_choice,
reasoning_effort: self.config.reasoning_effort.clone(),
response_format,
};
if log::log_enabled!(log::Level::Trace) {
if let Ok(json) = serde_json::to_string(&body) {
log::trace!("Azure OpenAI stream request payload: {}", json);
}
}
let mut url = self
.config
.base_url
.join("chat/completions")
.map_err(|e| LLMError::HttpError(e.to_string()))?;
if let Some(api_version) = &self.config.api_version {
url.query_pairs_mut()
.append_pair("api-version", api_version);
}
log::info!("Azure OpenAI stream HTTP Request {}", url);
let mut request = self
.client
.post(url)
.header("api-key", &self.config.api_key)
.json(&body);
if let Some(timeout) = self.config.timeout_seconds {
request = request.timeout(std::time::Duration::from_secs(timeout));
}
let response = request.send().await?;
log::debug!("Azure OpenAI stream HTTP status: {}", response.status());
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await?;
return Err(LLMError::ResponseFormatError {
message: format!("Azure OpenAI API returned error status: {status}"),
raw_response: error_text,
});
}
Ok(create_sse_stream(response, false))
}
async fn chat_stream_with_tools(
&self,
messages: &[ChatMessage],
tools: Option<&[Tool]>,
) -> Result<
std::pin::Pin<Box<dyn futures::Stream<Item = Result<StreamChunk, LLMError>> + Send>>,
LLMError,
> {
crate::chat::ensure_no_audio(messages, AUDIO_UNSUPPORTED)?;
if self.config.api_key.is_empty() {
return Err(LLMError::AuthError(
"Missing Azure OpenAI API key".to_string(),
));
}
let mut openai_msgs: Vec<AzureOpenAIChatMessage> = vec![];
for msg in messages {
if let MessageType::ToolResult(ref results) = msg.message_type {
for result in results {
openai_msgs.push(AzureOpenAIChatMessage {
role: "tool",
tool_call_id: Some(result.id.clone()),
tool_calls: None,
content: Some(either::Right(result.function.arguments.clone())),
});
}
} else {
openai_msgs.push(msg.into())
}
}
if let Some(system) = &self.config.system {
openai_msgs.insert(
0,
AzureOpenAIChatMessage {
role: "system",
content: Some(either::Left(vec![AzureMessageContent {
message_type: Some("text"),
text: Some(system),
image_url: None,
tool_call_id: None,
tool_output: None,
}])),
tool_calls: None,
tool_call_id: None,
},
);
}
let response_format: Option<OpenAIResponseFormat> =
self.config.json_schema.clone().map(|s| s.into());
let effective_tools = tools
.map(|t| t.to_vec())
.or_else(|| self.config.tools.clone());
let request_tool_choice = if effective_tools.is_some() {
self.config.tool_choice.clone()
} else {
None
};
let body = AzureOpenAIChatRequest {
model: &self.config.model,
messages: openai_msgs,
max_tokens: self.config.max_tokens,
temperature: self.config.temperature,
stream: true,
top_p: self.config.top_p,
top_k: self.config.top_k,
tools: effective_tools,
tool_choice: request_tool_choice,
reasoning_effort: self.config.reasoning_effort.clone(),
response_format,
};
let mut url = self
.config
.base_url
.join("chat/completions")
.map_err(|e| LLMError::HttpError(e.to_string()))?;
if let Some(api_version) = &self.config.api_version {
url.query_pairs_mut()
.append_pair("api-version", api_version);
}
let mut request = self
.client
.post(url)
.header("api-key", &self.config.api_key)
.json(&body);
if let Some(timeout) = self.config.timeout_seconds {
request = request.timeout(std::time::Duration::from_secs(timeout));
}
let response = request.send().await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await?;
return Err(LLMError::ResponseFormatError {
message: format!("Azure OpenAI API returned error status: {status}"),
raw_response: error_text,
});
}
Ok(create_azure_sse_stream_with_tools(response))
}
}
#[async_trait]
impl CompletionProvider for AzureOpenAI {
async fn complete(&self, _req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
Ok(CompletionResponse {
text: "OpenAI completion not implemented.".into(),
})
}
}
#[cfg(feature = "azure_openai")]
#[async_trait]
impl EmbeddingProvider for AzureOpenAI {
async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
if self.config.api_key.is_empty() {
return Err(LLMError::AuthError("Missing OpenAI API key".into()));
}
let emb_format = self
.config
.embedding_encoding_format
.clone()
.unwrap_or_else(|| "float".to_string());
let body = OpenAIEmbeddingRequest {
model: self.config.model.clone(),
input,
encoding_format: Some(emb_format),
dimensions: self.config.embedding_dimensions,
};
let mut url = self
.config
.base_url
.join("embeddings")
.map_err(|e| LLMError::HttpError(e.to_string()))?;
if let Some(api_version) = &self.config.api_version {
url.query_pairs_mut()
.append_pair("api-version", api_version);
}
let resp = self
.client
.post(url)
.header("api-key", &self.config.api_key)
.json(&body)
.send()
.await?
.error_for_status()?;
let json_resp: OpenAIEmbeddingResponse = resp.json().await?;
let embeddings = json_resp.data.into_iter().map(|d| d.embedding).collect();
Ok(embeddings)
}
}
impl LLMProvider for AzureOpenAI {
fn tools(&self) -> Option<&[Tool]> {
self.config.tools.as_deref()
}
}
#[async_trait]
impl SpeechToTextProvider for AzureOpenAI {
async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
Err(LLMError::ProviderError(
"Azure OpenAI does not implement speech to text endpoint yet.".into(),
))
}
}
#[async_trait]
impl TextToSpeechProvider for AzureOpenAI {
async fn speech(&self, _text: &str) -> Result<Vec<u8>, LLMError> {
Err(LLMError::ProviderError(
"Text to speech not supported".to_string(),
))
}
}
#[async_trait]
impl ModelsProvider for AzureOpenAI {
async fn list_models(
&self,
_request: Option<&ModelListRequest>,
) -> Result<Box<dyn ModelListResponse>, LLMError> {
if self.config.api_key.is_empty() {
return Err(LLMError::AuthError(
"Missing Azure OpenAI API key".to_string(),
));
}
let mut url = self
.config
.base_url
.join("models")
.map_err(|e| LLMError::HttpError(e.to_string()))?;
if let Some(api_version) = &self.config.api_version {
url.query_pairs_mut()
.append_pair("api-version", api_version);
}
let mut request = self.client.get(url).header("api-key", &self.config.api_key);
if let Some(timeout) = self.config.timeout_seconds {
request = request.timeout(std::time::Duration::from_secs(timeout));
}
let resp = request.send().await?.error_for_status()?;
let result = StandardModelListResponse {
inner: resp.json().await?,
backend: LLMBackend::AzureOpenAI,
};
Ok(Box::new(result))
}
}
#[cfg(all(test, feature = "azure_openai"))]
mod tests {
use super::*;
use reqwest::Client;
#[test]
fn parse_azure_stream_tool_only_deltas() {
let mut tool_states = HashMap::new();
let start = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_abc123","type":"function","function":{"name":"get_weather","arguments":""}}]},"finish_reason":null}]}"#;
let chunks = parse_azure_sse_chunk_with_tools(start, &mut tool_states).unwrap();
assert!(matches!(
&chunks[0],
StreamChunk::ToolUseStart { index: 0, id, name }
if id == "call_abc123" && name == "get_weather"
));
let delta = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"city\":\"Paris\"}"}}]},"finish_reason":null}]}"#;
let chunks = parse_azure_sse_chunk_with_tools(delta, &mut tool_states).unwrap();
assert!(matches!(
&chunks[0],
StreamChunk::ToolUseInputDelta { index: 0, partial_json }
if partial_json == "{\"city\":\"Paris\"}"
));
let finish =
r#"data: {"choices":[{"delta":{},"finish_reason":"tool_calls"}]}"#;
let chunks = parse_azure_sse_chunk_with_tools(finish, &mut tool_states).unwrap();
assert_eq!(chunks.len(), 2);
assert!(matches!(
&chunks[0],
StreamChunk::ToolUseComplete { index: 0, tool_call }
if tool_call.id == "call_abc123"
&& tool_call.function.name == "get_weather"
&& tool_call.function.arguments == "{\"city\":\"Paris\"}"
));
assert!(matches!(
&chunks[1],
StreamChunk::Done { stop_reason } if stop_reason == "tool_use"
));
}
#[test]
fn falls_back_to_deployment_id_when_model_is_unset() {
let client = AzureOpenAI::with_client(
Client::new(),
"test-key",
Some("2024-10-21".to_string()),
"my-deployment",
"https://example.openai.azure.com",
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
);
assert_eq!(client.config.model, "my-deployment");
}
}