use std::borrow::Cow;
use crate::error::{LiterLlmError, Result};
use crate::provider::Provider;
/// Provider for the Azure OpenAI service.
///
/// Azure differs from vanilla OpenAI in three ways handled here: requests are
/// addressed to a *deployment* (the model name is embedded in the URL path,
/// not the body), every URL carries an `api-version` query parameter, and
/// authentication uses an `api-key` header rather than a Bearer token.
pub struct AzureProvider {
    // Resource endpoint, e.g. `https://{resource}.openai.azure.com`; stored
    // with any trailing `/` removed. Empty when no endpoint env var was set
    // (rejected later by `validate`).
    base_url: String,
    // Value for the `api-version` query parameter on every request URL.
    api_version: String,
}
impl AzureProvider {
    /// Builds a provider from the process environment.
    ///
    /// The base URL is read from `AZURE_OPENAI_ENDPOINT`, falling back to
    /// `AZURE_ENDPOINT`; a trailing `/` is stripped so URL joining never
    /// produces a double slash. When neither variable is set, the base URL is
    /// left empty and `validate` will reject the provider. The API version is
    /// read from `AZURE_API_VERSION`, defaulting to `2025-02-01-preview`.
    #[must_use]
    pub fn new() -> Self {
        let endpoint = std::env::var("AZURE_OPENAI_ENDPOINT")
            .or_else(|_| std::env::var("AZURE_ENDPOINT"))
            .unwrap_or_default();
        let api_version = std::env::var("AZURE_API_VERSION")
            .unwrap_or_else(|_| String::from("2025-02-01-preview"));
        Self {
            base_url: endpoint.trim_end_matches('/').to_owned(),
            api_version,
        }
    }
}
impl Default for AzureProvider {
fn default() -> Self {
Self::new()
}
}
impl Provider for AzureProvider {
    fn name(&self) -> &str {
        "azure"
    }

    fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Azure OpenAI authenticates with an `api-key: <key>` header instead of
    /// the usual `Authorization: Bearer <key>` scheme.
    fn auth_header<'a>(&'a self, api_key: &'a str) -> Option<(Cow<'static, str>, Cow<'a, str>)> {
        Some((Cow::Borrowed("api-key"), Cow::Borrowed(api_key)))
    }

    /// A model string is routed to this provider when it carries the
    /// `azure/` prefix (e.g. `azure/gpt-4`).
    fn matches_model(&self, model: &str) -> bool {
        model.starts_with("azure/")
    }

    /// Removes the routing prefix, returning the bare deployment/model name.
    /// Strings without the prefix are returned unchanged.
    fn strip_model_prefix<'m>(&self, model: &'m str) -> &'m str {
        model.strip_prefix("azure/").unwrap_or(model)
    }

    /// # Errors
    ///
    /// Returns [`LiterLlmError::BadRequest`] when no base URL was configured
    /// (neither `AZURE_OPENAI_ENDPOINT` nor `AZURE_ENDPOINT` was set).
    fn validate(&self) -> Result<()> {
        if self.base_url.is_empty() {
            return Err(LiterLlmError::BadRequest {
                message: "Azure OpenAI requires a base URL. \
                          Set AZURE_OPENAI_ENDPOINT=https://{resource}.openai.azure.com \
                          (or AZURE_ENDPOINT as a fallback)."
                    .into(),
            });
        }
        Ok(())
    }

    /// Builds the request URL in Azure's deployment-addressed form:
    /// `{base}/openai/deployments/{model}{path}?api-version={ver}`.
    ///
    /// If the configured base URL already contains `/openai/deployments/`,
    /// the deployment segment is not inserted a second time. With an empty
    /// base URL the bare `endpoint_path` is returned as a fallback.
    fn build_url(&self, endpoint_path: &str, model: &str) -> String {
        if self.base_url.is_empty() {
            return endpoint_path.to_owned();
        }
        if self.base_url.contains("/openai/deployments/") {
            return format!("{}{}?api-version={}", self.base_url, endpoint_path, self.api_version);
        }
        format!(
            "{}/openai/deployments/{}{}?api-version={}",
            self.base_url, model, endpoint_path, self.api_version
        )
    }

    /// Adapts an OpenAI-style request body for Azure.
    ///
    /// The `model` field is always removed (the deployment name travels in
    /// the URL instead — see `build_url`). For O-series reasoning models,
    /// `temperature` and `top_p` are dropped as well; for the `o1` family,
    /// `stream`/`stream_options` are additionally dropped. Non-object bodies
    /// are left untouched.
    fn transform_request(&self, body: &mut serde_json::Value) -> Result<()> {
        if let Some(obj) = body.as_object_mut() {
            // Copy the name out first so we can keep mutating `obj`.
            let model_name = obj
                .get("model")
                .and_then(|m| m.as_str())
                .unwrap_or("")
                .to_owned();
            obj.remove("model");
            if is_o_series_model(&model_name) {
                obj.remove("temperature");
                obj.remove("top_p");
                // Only the o1 family (o1, o1-*, o1.*) loses streaming fields.
                if model_name == "o1"
                    || model_name.starts_with("o1-")
                    || model_name.starts_with("o1.")
                {
                    obj.remove("stream");
                    obj.remove("stream_options");
                }
            }
        }
        Ok(())
    }

    /// Normalizes content-filtered responses.
    ///
    /// When Azure's content-safety layer blocks a completion, the affected
    /// choice can arrive with `finish_reason: "content_filter"`,
    /// `content_filter_results`, and no `message` at all. Downstream code
    /// expects `choices[i].message` to exist, so a synthetic assistant
    /// refusal message is inserted into each such choice.
    ///
    /// Fix over the previous version: the patch is applied to the choice
    /// that was actually filtered (the old code always patched index 0, and
    /// mixed an immutable iteration over `choices` with a mutable re-borrow
    /// of `body`, which the borrow checker rejects).
    fn transform_response(&self, body: &mut serde_json::Value) -> Result<()> {
        if let Some(choices) = body.get_mut("choices").and_then(|c| c.as_array_mut()) {
            for choice in choices.iter_mut() {
                let is_filtered =
                    choice.get("finish_reason").and_then(|fr| fr.as_str()) == Some("content_filter");
                if is_filtered
                    && choice.get("content_filter_results").is_some()
                    && choice.get("message").is_none()
                    && let Some(choice_obj) = choice.as_object_mut()
                {
                    choice_obj.insert(
                        "message".to_owned(),
                        serde_json::json!({
                            "role": "assistant",
                            "content": null,
                            "refusal": "Content filtered by Azure content safety."
                        }),
                    );
                }
            }
        }
        Ok(())
    }
}
/// Returns `true` for Azure/OpenAI O-series reasoning models: exactly `o1`,
/// `o3`, or `o4`, or one of those followed by `-` or `.` (e.g. `o1-mini`,
/// `o3.5`). Plain `o2`, `o10`, or names merely starting with `o` do not match.
fn is_o_series_model(model: &str) -> bool {
    ["o1", "o3", "o4"].iter().any(|&family| {
        model == family
            || model
                .strip_prefix(family)
                .is_some_and(|rest| rest.starts_with('-') || rest.starts_with('.'))
    })
}
#[cfg(test)]
mod tests {
    use serde_json::json;
    use super::*;

    // Constructs a provider directly from fields, bypassing `new()` so tests
    // never depend on environment variables.
    fn make_provider(base_url: &str, api_version: &str) -> AzureProvider {
        AzureProvider {
            base_url: base_url.to_owned(),
            api_version: api_version.to_owned(),
        }
    }

    // --- build_url -------------------------------------------------------

    #[test]
    fn build_url_embeds_deployment_name() {
        let provider = make_provider("https://myresource.openai.azure.com", "2024-10-21");
        let url = provider.build_url("/chat/completions", "gpt-4");
        assert_eq!(
            url,
            "https://myresource.openai.azure.com/openai/deployments/gpt-4/chat/completions?api-version=2024-10-21"
        );
    }

    #[test]
    fn build_url_includes_api_version_query_param() {
        let provider = make_provider("https://example.openai.azure.com", "2025-01-01");
        let url = provider.build_url("/chat/completions", "gpt-4o");
        assert!(url.contains("?api-version=2025-01-01"), "url = {url}");
    }

    #[test]
    fn build_url_embeddings_endpoint() {
        let provider = make_provider("https://myresource.openai.azure.com", "2024-10-21");
        let url = provider.build_url("/embeddings", "text-embedding-3-large");
        assert_eq!(
            url,
            "https://myresource.openai.azure.com/openai/deployments/text-embedding-3-large/embeddings?api-version=2024-10-21"
        );
    }

    // NOTE(review): the trailing-slash stripping happens in `new()`, which
    // this test bypasses via `make_provider` with an already-clean URL — it
    // only verifies `build_url` itself introduces no double slash.
    #[test]
    fn build_url_with_trailing_slash_stripped() {
        let provider = make_provider("https://myresource.openai.azure.com", "2024-10-21");
        let url = provider.build_url("/chat/completions", "gpt-4");
        assert!(!url.contains("//openai"), "double slash in url: {url}");
    }

    // When the configured base URL already names a deployment, build_url
    // must not insert a second `/openai/deployments/{model}` segment.
    #[test]
    fn build_url_already_contains_deployments_path() {
        let provider = make_provider(
            "https://myresource.openai.azure.com/openai/deployments/gpt-4",
            "2025-02-01-preview",
        );
        let url = provider.build_url("/chat/completions", "gpt-4");
        assert!(
            !url.contains("deployments/gpt-4/openai/deployments"),
            "deployment path must not be doubled: {url}"
        );
        assert!(
            url.contains("/openai/deployments/gpt-4/chat/completions"),
            "url should contain the deployment path: {url}"
        );
    }

    #[test]
    fn build_url_empty_base_returns_fallback() {
        let provider = make_provider("", "2024-10-21");
        let url = provider.build_url("/chat/completions", "gpt-4");
        assert_eq!(url, "/chat/completions");
    }

    // --- transform_request ----------------------------------------------

    // Azure carries the deployment name in the URL, so `model` must be
    // stripped from the body while other fields survive.
    #[test]
    fn transform_request_removes_model_field() {
        let provider = make_provider("https://myresource.openai.azure.com", "2024-10-21");
        let mut body = json!({
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "hello"}],
            "temperature": 0.7
        });
        provider.transform_request(&mut body).expect("transform should succeed");
        assert!(body.get("model").is_none(), "model should be removed from body");
        assert!(body.get("messages").is_some());
        assert!(body.get("temperature").is_some());
    }

    #[test]
    fn transform_request_non_object_body_is_noop() {
        let provider = make_provider("https://myresource.openai.azure.com", "2024-10-21");
        let mut body = json!("not an object");
        assert!(provider.transform_request(&mut body).is_ok());
    }

    // --- validate ---------------------------------------------------------

    #[test]
    fn validate_fails_when_base_url_is_empty() {
        let provider = make_provider("", "2024-10-21");
        let err = provider.validate().expect_err("should fail with empty base_url");
        let msg = err.to_string();
        assert!(
            msg.contains("Azure OpenAI"),
            "error message should mention Azure: {msg}"
        );
        assert!(
            msg.contains("AZURE_OPENAI_ENDPOINT"),
            "error message should mention env var: {msg}"
        );
    }

    #[test]
    fn validate_succeeds_when_base_url_is_set() {
        let provider = make_provider("https://myresource.openai.azure.com", "2024-10-21");
        assert!(provider.validate().is_ok());
    }

    #[test]
    fn explicit_base_url_and_api_version_are_stored() {
        let provider = make_provider("https://test.openai.azure.com", "2099-01-01");
        assert_eq!(provider.base_url, "https://test.openai.azure.com");
        assert_eq!(provider.api_version, "2099-01-01");
    }

    // NOTE(review): this test asserts a value it constructed itself, so it
    // does not actually exercise `new()`'s default; testing the real default
    // would require manipulating AZURE_API_VERSION in the environment.
    #[test]
    fn default_api_version_is_preview() {
        let provider = make_provider("https://test.openai.azure.com", "2025-02-01-preview");
        assert_eq!(provider.api_version, "2025-02-01-preview");
    }

    // --- model routing ----------------------------------------------------

    #[test]
    fn strip_model_prefix_removes_azure_prefix() {
        let provider = make_provider("https://myresource.openai.azure.com", "2024-10-21");
        assert_eq!(provider.strip_model_prefix("azure/gpt-4"), "gpt-4");
        assert_eq!(provider.strip_model_prefix("gpt-4"), "gpt-4");
    }

    #[test]
    fn matches_model_only_for_azure_prefix() {
        let provider = make_provider("https://myresource.openai.azure.com", "2024-10-21");
        assert!(provider.matches_model("azure/gpt-4"));
        assert!(provider.matches_model("azure/gpt-4o-mini"));
        assert!(!provider.matches_model("gpt-4"));
        assert!(!provider.matches_model("openai/gpt-4"));
    }

    #[test]
    fn auth_header_uses_api_key_scheme() {
        let provider = make_provider("https://myresource.openai.azure.com", "2024-10-21");
        let (name, _value) = provider.auth_header("test-key").expect("should return Some");
        assert_eq!(name.as_ref(), "api-key");
    }

    // --- O-series handling ------------------------------------------------

    #[test]
    fn is_o_series_model_detection() {
        assert!(super::is_o_series_model("o1"));
        assert!(super::is_o_series_model("o1-preview"));
        assert!(super::is_o_series_model("o1-mini"));
        assert!(super::is_o_series_model("o3"));
        assert!(super::is_o_series_model("o3-mini"));
        assert!(super::is_o_series_model("o3.5"));
        assert!(super::is_o_series_model("o4"));
        assert!(super::is_o_series_model("o4-mini"));
        assert!(!super::is_o_series_model("gpt-4"));
        // "gpt-4o" ends in "o" but is not a reasoning model.
        assert!(!super::is_o_series_model("gpt-4o"));
        // There is no "o2" family.
        assert!(!super::is_o_series_model("o2"));
        assert!(!super::is_o_series_model("opt-1"));
    }

    #[test]
    fn transform_request_o_series_removes_temperature_and_top_p() {
        let provider = make_provider("https://myresource.openai.azure.com", "2024-10-21");
        let mut body = json!({
            "model": "o3-mini",
            "messages": [{"role": "user", "content": "hello"}],
            "temperature": 0.7,
            "top_p": 0.9,
            "reasoning_effort": "high"
        });
        provider.transform_request(&mut body).expect("transform should succeed");
        assert!(body.get("model").is_none());
        assert!(
            body.get("temperature").is_none(),
            "temperature should be removed for O-series"
        );
        assert!(body.get("top_p").is_none(), "top_p should be removed for O-series");
        assert_eq!(body.get("reasoning_effort").unwrap(), "high");
        assert!(body.get("messages").is_some());
    }

    // Only the o1 family loses streaming fields (o3/o4 keep them).
    #[test]
    fn transform_request_o1_removes_stream() {
        let provider = make_provider("https://myresource.openai.azure.com", "2024-10-21");
        let mut body = json!({
            "model": "o1-preview",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": true,
            "stream_options": {"include_usage": true},
            "temperature": 0.5
        });
        provider.transform_request(&mut body).expect("transform should succeed");
        assert!(body.get("stream").is_none(), "stream should be removed for o1");
        assert!(
            body.get("stream_options").is_none(),
            "stream_options should be removed for o1"
        );
        assert!(
            body.get("temperature").is_none(),
            "temperature should be removed for O-series"
        );
    }

    #[test]
    fn transform_request_o3_keeps_stream() {
        let provider = make_provider("https://myresource.openai.azure.com", "2024-10-21");
        let mut body = json!({
            "model": "o3-mini",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": true
        });
        provider.transform_request(&mut body).expect("transform should succeed");
        assert!(body.get("stream").is_some(), "stream should remain for o3");
    }

    #[test]
    fn transform_request_non_o_series_keeps_all_params() {
        let provider = make_provider("https://myresource.openai.azure.com", "2024-10-21");
        let mut body = json!({
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "hello"}],
            "temperature": 0.7,
            "top_p": 0.9,
            "stream": true
        });
        provider.transform_request(&mut body).expect("transform should succeed");
        assert!(body.get("model").is_none(), "model should be removed");
        assert!(
            body.get("temperature").is_some(),
            "temperature should be kept for non-O-series"
        );
        assert!(body.get("top_p").is_some(), "top_p should be kept for non-O-series");
        assert!(body.get("stream").is_some(), "stream should be kept for non-O-series");
    }

    // --- transform_response -----------------------------------------------

    #[test]
    fn transform_response_passthrough_normal() {
        let provider = make_provider("https://myresource.openai.azure.com", "2024-10-21");
        let mut body = json!({
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": "Hello!"},
                "finish_reason": "stop"
            }]
        });
        let original = body.clone();
        provider
            .transform_response(&mut body)
            .expect("transform should succeed");
        assert_eq!(body, original, "normal responses should pass through unchanged");
    }

    // A filtered choice that still has a message must not be overwritten.
    #[test]
    fn transform_response_content_filter_with_message() {
        let provider = make_provider("https://myresource.openai.azure.com", "2024-10-21");
        let mut body = json!({
            "id": "chatcmpl-123",
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": ""},
                "finish_reason": "content_filter",
                "content_filter_results": {
                    "hate": {"filtered": true, "severity": "high"}
                }
            }]
        });
        provider
            .transform_response(&mut body)
            .expect("transform should succeed");
        assert_eq!(body["choices"][0]["finish_reason"], "content_filter");
        assert!(body["choices"][0]["message"].is_object());
    }

    // A filtered choice with no message at all gets a synthetic refusal
    // message so downstream `choices[i].message` access keeps working.
    #[test]
    fn transform_response_content_filter_blocked_no_message() {
        let provider = make_provider("https://myresource.openai.azure.com", "2024-10-21");
        let mut body = json!({
            "id": "chatcmpl-123",
            "choices": [{
                "index": 0,
                "finish_reason": "content_filter",
                "content_filter_results": {
                    "hate": {"filtered": true, "severity": "high"}
                }
            }]
        });
        provider
            .transform_response(&mut body)
            .expect("transform should succeed");
        let message = &body["choices"][0]["message"];
        assert_eq!(message["role"], "assistant");
        assert!(message["content"].is_null());
        assert!(message["refusal"].as_str().unwrap().contains("Content filtered"));
    }
}