use std::env;
use crate::client::OpenAI;
use crate::config::ClientConfig;
use crate::error::OpenAIError;
/// Value sent as the `api-version` query parameter when `AzureConfig::api_version` is not set.
const DEFAULT_API_VERSION: &str = "2024-10-21";
/// Builder-style settings for an OpenAI client that talks to an Azure OpenAI resource.
///
/// An endpoint plus exactly one credential (`api_key` or `azure_ad_token`) must be provided
/// before building a client.
///
/// `Debug` is implemented by hand so that credential values are never printed.
#[derive(Clone, Default)]
pub struct AzureConfig {
    /// Azure resource endpoint, e.g. `https://my-resource.openai.azure.com`.
    pub azure_endpoint: Option<String>,
    /// Deployment name; when set, requests are scoped under `/openai/deployments/{name}`.
    pub azure_deployment: Option<String>,
    /// Value of the `api-version` query parameter; a built-in default is used when `None`.
    pub api_version: Option<String>,
    /// Azure API key. Mutually exclusive with `azure_ad_token`.
    pub api_key: Option<String>,
    /// Microsoft Entra ID (Azure AD) token. Mutually exclusive with `api_key`.
    pub azure_ad_token: Option<String>,
}

impl std::fmt::Debug for AzureConfig {
    /// Formats the configuration with credential values replaced by a placeholder, so that
    /// logging the config with `{:?}` cannot leak secrets.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a credential is present, never its value.
        const REDACTED: &str = "<redacted>";
        f.debug_struct("AzureConfig")
            .field("azure_endpoint", &self.azure_endpoint)
            .field("azure_deployment", &self.azure_deployment)
            .field("api_version", &self.api_version)
            .field("api_key", &self.api_key.as_ref().map(|_| REDACTED))
            .field("azure_ad_token", &self.azure_ad_token.as_ref().map(|_| REDACTED))
            .finish()
    }
}
impl AzureConfig {
    /// Creates an empty configuration with every field unset.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the Azure resource endpoint (for example `https://my-resource.openai.azure.com`).
    ///
    /// Surrounding whitespace and trailing slashes are removed when the client is built.
    #[must_use]
    pub fn azure_endpoint(mut self, endpoint: impl Into<String>) -> Self {
        self.azure_endpoint = Some(endpoint.into());
        self
    }

    /// Sets the deployment name.
    ///
    /// When set, the client's base URL becomes `{endpoint}/openai/deployments/{deployment}`;
    /// otherwise it is `{endpoint}/openai`.
    #[must_use]
    pub fn azure_deployment(mut self, deployment: impl Into<String>) -> Self {
        self.azure_deployment = Some(deployment.into());
        self
    }

    /// Sets the value sent as the `api-version` default query parameter.
    ///
    /// When not set, a built-in default version is used.
    #[must_use]
    pub fn api_version(mut self, version: impl Into<String>) -> Self {
        self.api_version = Some(version.into());
        self
    }

    /// Authenticates with an Azure API key, which the client sends via the Azure API-key header.
    ///
    /// Mutually exclusive with `azure_ad_token`; setting both makes [`build`](Self::build) fail.
    #[must_use]
    pub fn api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = Some(key.into());
        self
    }

    /// Authenticates with a Microsoft Entra ID (Azure AD) token, sent as a bearer token.
    ///
    /// Mutually exclusive with `api_key`; setting both makes [`build`](Self::build) fail.
    #[must_use]
    pub fn azure_ad_token(mut self, token: impl Into<String>) -> Self {
        self.azure_ad_token = Some(token.into());
        self
    }

    /// Validates the configuration and builds an [`OpenAI`] client targeting Azure.
    ///
    /// The base URL is `{endpoint}/openai`, or `{endpoint}/openai/deployments/{deployment}`
    /// when a deployment is set. `api-version` is attached as a default query parameter.
    ///
    /// # Errors
    ///
    /// Returns [`OpenAIError::InvalidArgument`] if:
    /// - no endpoint is set, or it is empty once whitespace and trailing slashes are removed;
    /// - both `api_key` and `azure_ad_token` are set;
    /// - neither `api_key` nor `azure_ad_token` is set.
    pub fn build(self) -> Result<OpenAI, OpenAIError> {
        // Exhaustive destructuring: adding a field to the struct forces this fn to handle it.
        let Self {
            azure_endpoint,
            azure_deployment,
            api_version,
            api_key,
            azure_ad_token,
        } = self;

        // Normalise before joining with `/openai` so we never produce `//openai`; an endpoint
        // that is empty after normalisation would yield a bare `/openai` base URL, so it is
        // treated the same as a missing one.
        let endpoint = azure_endpoint
            .as_deref()
            .map(|raw| raw.trim().trim_end_matches('/'))
            .filter(|trimmed| !trimmed.is_empty())
            .ok_or_else(|| {
                OpenAIError::InvalidArgument(
                    "Azure endpoint is required. Set azure_endpoint() or AZURE_OPENAI_ENDPOINT env var"
                        .to_string(),
                )
            })?;

        // `true` selects the Azure API-key header; `false` means the token is sent as bearer auth.
        let (auth_key, use_azure_api_key_header) = match (api_key, azure_ad_token) {
            (Some(key), None) => (key, true),
            (None, Some(token)) => (token, false),
            (Some(_), Some(_)) => {
                return Err(OpenAIError::InvalidArgument(
                    "api_key and azure_ad_token are mutually exclusive; only one can be set"
                        .to_string(),
                ));
            }
            (None, None) => {
                return Err(OpenAIError::InvalidArgument(
                    "Azure credentials required. Set api_key() or azure_ad_token()".to_string(),
                ));
            }
        };

        let api_version = api_version.unwrap_or_else(|| DEFAULT_API_VERSION.to_string());

        let base_url = match azure_deployment {
            Some(deployment) => format!("{endpoint}/openai/deployments/{deployment}"),
            None => format!("{endpoint}/openai"),
        };

        let config = ClientConfig::new(auth_key)
            .base_url(base_url)
            .default_query(vec![("api-version".to_string(), api_version)])
            .use_azure_api_key_header(use_azure_api_key_header);
        Ok(OpenAI::with_config(config))
    }

    /// Builds a client from environment variables.
    ///
    /// Reads `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_AD_TOKEN` and
    /// `OPENAI_API_VERSION`. A variable that is unset, empty, or not valid Unicode is treated
    /// as absent. No deployment is read from the environment, so the base URL is
    /// `{endpoint}/openai`.
    ///
    /// # Errors
    ///
    /// Same conditions as [`build`](Self::build).
    pub fn from_env() -> Result<OpenAI, OpenAIError> {
        // An exported-but-empty variable (e.g. `AZURE_OPENAI_AD_TOKEN=`) must not count as
        // "set": it would otherwise trigger the mutual-exclusivity error or send empty values.
        let read = |name: &str| env::var(name).ok().filter(|value| !value.is_empty());

        Self {
            azure_endpoint: read("AZURE_OPENAI_ENDPOINT"),
            azure_deployment: None,
            api_version: read("OPENAI_API_VERSION"),
            api_key: read("AZURE_OPENAI_API_KEY"),
            azure_ad_token: read("AZURE_OPENAI_AD_TOKEN"),
        }
        .build()
    }
}
// Tests for `AzureConfig`: base-URL construction, api-version defaulting, auth-header
// selection (api-key vs bearer), validation errors, and environment loading.
// HTTP-level tests run against a local mockito server: each mock only matches requests whose
// path, headers and query satisfy its matchers, and `assert_async` fails if it was never hit.
#[cfg(test)]
mod tests {
use super::*;
// With a deployment, the base URL is scoped to `/openai/deployments/{deployment}`.
#[test]
fn test_azure_url_with_deployment() {
let client = AzureConfig::new()
.azure_endpoint("https://my-resource.openai.azure.com")
.azure_deployment("gpt-4")
.api_key("test-key")
.build()
.unwrap();
assert_eq!(
client.config.base_url(),
"https://my-resource.openai.azure.com/openai/deployments/gpt-4"
);
}
// Without a deployment, the base URL stops at `/openai`.
#[test]
fn test_azure_url_without_deployment() {
let client = AzureConfig::new()
.azure_endpoint("https://my-resource.openai.azure.com")
.api_key("test-key")
.build()
.unwrap();
assert_eq!(
client.config.base_url(),
"https://my-resource.openai.azure.com/openai"
);
}
// A trailing slash on the endpoint must not produce `//openai`.
#[test]
fn test_azure_url_trailing_slash_stripped() {
let client = AzureConfig::new()
.azure_endpoint("https://my-resource.openai.azure.com/")
.azure_deployment("gpt-4")
.api_key("test-key")
.build()
.unwrap();
assert_eq!(
client.config.base_url(),
"https://my-resource.openai.azure.com/openai/deployments/gpt-4"
);
}
// When no api_version is given, the default version is stored as a default query pair.
#[test]
fn test_azure_default_api_version() {
let client = AzureConfig::new()
.azure_endpoint("https://example.openai.azure.com")
.api_key("test-key")
.build()
.unwrap();
let query = client.options.query.as_ref().unwrap();
assert!(
query
.iter()
.any(|(k, v)| k == "api-version" && v == "2024-10-21")
);
}
// An explicit api_version overrides the default.
#[test]
fn test_azure_custom_api_version() {
let client = AzureConfig::new()
.azure_endpoint("https://example.openai.azure.com")
.api_key("test-key")
.api_version("2024-06-01")
.build()
.unwrap();
let query = client.options.query.as_ref().unwrap();
assert!(
query
.iter()
.any(|(k, v)| k == "api-version" && v == "2024-06-01")
);
}
// The default api-version actually reaches the wire as a query parameter.
#[tokio::test]
async fn test_azure_sends_api_version_query_param() {
let mut server = mockito::Server::new_async().await;
let mock = server
.mock("GET", "/openai/models")
.match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
"api-version".into(),
"2024-10-21".into(),
)]))
.with_status(200)
.with_body(r#"{"data":[],"object":"list"}"#)
.create_async()
.await;
let client = AzureConfig::new()
.azure_endpoint(&server.url())
.api_key("test-key")
.build()
.unwrap();
// Minimal response shape; only `object` is checked.
#[derive(serde::Deserialize)]
struct ListResp {
object: String,
}
let resp: ListResp = client.get("/models").await.unwrap();
assert_eq!(resp.object, "list");
mock.assert_async().await;
}
// api_key auth is sent in the Azure `api-key` header.
#[tokio::test]
async fn test_azure_sends_api_key_header() {
let mut server = mockito::Server::new_async().await;
let mock = server
.mock("GET", "/openai/test")
.match_header("api-key", "my-azure-key")
.match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
"api-version".into(),
"2024-10-21".into(),
)]))
.with_status(200)
.with_body(r#"{"ok":true}"#)
.create_async()
.await;
let client = AzureConfig::new()
.azure_endpoint(&server.url())
.api_key("my-azure-key")
.build()
.unwrap();
#[derive(serde::Deserialize)]
struct Resp {
ok: bool,
}
let resp: Resp = client.get("/test").await.unwrap();
assert!(resp.ok);
mock.assert_async().await;
}
// api_key auth must not also add an `Authorization` header.
#[tokio::test]
async fn test_azure_does_not_send_bearer_auth() {
let mut server = mockito::Server::new_async().await;
let mock = server
.mock("GET", "/openai/test")
.match_header("api-key", "my-azure-key")
.match_header("authorization", mockito::Matcher::Missing)
.match_query(mockito::Matcher::Any)
.with_status(200)
.with_body(r#"{"ok":true}"#)
.create_async()
.await;
let client = AzureConfig::new()
.azure_endpoint(&server.url())
.api_key("my-azure-key")
.build()
.unwrap();
#[derive(serde::Deserialize)]
struct Resp {
ok: bool,
}
let resp: Resp = client.get("/test").await.unwrap();
assert!(resp.ok);
mock.assert_async().await;
}
// azure_ad_token auth is sent as `Authorization: Bearer <token>`.
#[tokio::test]
async fn test_azure_ad_token_sends_bearer() {
let mut server = mockito::Server::new_async().await;
let mock = server
.mock("GET", "/openai/test")
.match_header("authorization", "Bearer my-ad-token")
.match_query(mockito::Matcher::Any)
.with_status(200)
.with_body(r#"{"ok":true}"#)
.create_async()
.await;
let client = AzureConfig::new()
.azure_endpoint(&server.url())
.azure_ad_token("my-ad-token")
.build()
.unwrap();
#[derive(serde::Deserialize)]
struct Resp {
ok: bool,
}
let resp: Resp = client.get("/test").await.unwrap();
assert!(resp.ok);
mock.assert_async().await;
}
// Supplying both credentials is rejected at build time.
#[test]
fn test_mutual_exclusivity_error() {
let result = AzureConfig::new()
.azure_endpoint("https://example.openai.azure.com")
.api_key("key")
.azure_ad_token("token")
.build();
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.to_string().contains("mutually exclusive"),
"unexpected error: {err}"
);
}
// Supplying no credential is rejected at build time.
#[test]
fn test_no_credentials_error() {
let result = AzureConfig::new()
.azure_endpoint("https://example.openai.azure.com")
.build();
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.to_string().contains("credentials required"),
"unexpected error: {err}"
);
}
// A missing endpoint is rejected at build time.
#[test]
fn test_no_endpoint_error() {
let result = AzureConfig::new().api_key("key").build();
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.to_string().contains("endpoint is required"),
"unexpected error: {err}"
);
}
// from_env picks up endpoint, key and api-version from the process environment.
// NOTE(review): `env::set_var`/`env::remove_var` are `unsafe` because mutating the process
// environment while another thread reads it is a data race. The default test harness runs
// tests on multiple threads, so this block is only sound if no concurrently running test (or
// code it calls) reads the environment. Consider running it with `--test-threads=1` or making
// `from_env` testable through an injected lookup function.
// NOTE(review): the cleanup `unsafe` block at the end is skipped if an assertion panics,
// leaving these variables set for the remainder of the test process.
#[test]
fn test_from_env_reads_variables() {
unsafe {
env::set_var("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com");
env::set_var("AZURE_OPENAI_API_KEY", "env-key");
env::set_var("OPENAI_API_VERSION", "2024-06-01");
// Removed so a token inherited from the outer environment cannot trigger the
// mutual-exclusivity error alongside AZURE_OPENAI_API_KEY.
env::remove_var("AZURE_OPENAI_AD_TOKEN");
}
let client = AzureConfig::from_env().unwrap();
assert_eq!(
client.config.base_url(),
"https://test.openai.azure.com/openai"
);
assert_eq!(client.config.api_key(), "env-key");
let query = client.options.query.as_ref().unwrap();
assert!(
query
.iter()
.any(|(k, v)| k == "api-version" && v == "2024-06-01")
);
unsafe {
env::remove_var("AZURE_OPENAI_ENDPOINT");
env::remove_var("AZURE_OPENAI_API_KEY");
env::remove_var("OPENAI_API_VERSION");
}
}
// End-to-end chat completion through a deployment-scoped URL with api-key auth.
#[tokio::test]
async fn test_azure_chat_completion_e2e() {
let mut server = mockito::Server::new_async().await;
let mock = server
.mock("POST", "/openai/deployments/gpt-4/chat/completions")
.match_header("api-key", "azure-key")
.match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
"api-version".into(),
"2024-10-21".into(),
)]))
.with_status(200)
.with_header("content-type", "application/json")
.with_body(
r#"{
"id": "chatcmpl-azure-123",
"object": "chat.completion",
"created": 1700000000,
"model": "gpt-4",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello from Azure!"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15
}
}"#,
)
.create_async()
.await;
let client = AzureConfig::new()
.azure_endpoint(&server.url())
.azure_deployment("gpt-4")
.api_key("azure-key")
.build()
.unwrap();
use crate::types::chat::{ChatCompletionMessageParam, ChatCompletionRequest, UserContent};
let request = ChatCompletionRequest::new(
"gpt-4",
vec![ChatCompletionMessageParam::User {
content: UserContent::Text("Hello!".into()),
name: None,
}],
);
let response = client.chat().completions().create(request).await.unwrap();
assert_eq!(response.id, "chatcmpl-azure-123");
assert_eq!(
response.choices[0].message.content.as_deref().unwrap_or(""),
"Hello from Azure!"
);
mock.assert_async().await;
}
}