/// Request paths that are eligible for a `/deployments/{model}` prefix.
///
/// [`rewrite_path`] compares the incoming path against this list using exact
/// string equality, so every entry starts with `/` and carries no trailing
/// slash or query string.
pub const DEPLOYMENTS_ENDPOINTS: [&str; 8] = [
    // Completion-style endpoints.
    "/completions",
    "/chat/completions",
    // Embeddings.
    "/embeddings",
    // Audio endpoints.
    "/audio/transcriptions",
    "/audio/translations",
    "/audio/speech",
    // Image endpoints.
    "/images/generations",
    "/images/edits",
];
/// Credential used to authenticate a request.
///
/// `Debug` is implemented by hand rather than derived so that the wrapped
/// secret is never included in formatted output.
#[derive(Clone)]
pub enum AzureAuth {
    /// Key sent verbatim as the value of the `api-key` header.
    ApiKey(String),
    /// Token sent in the `authorization` header behind a `Bearer ` prefix.
    BearerToken(String),
}

impl std::fmt::Debug for AzureAuth {
    /// Writes the variant name with a `[REDACTED]` placeholder in place of
    /// the secret.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the fixed label first; the payload is deliberately ignored.
        let label = match self {
            Self::ApiKey(_) => "ApiKey([REDACTED])",
            Self::BearerToken(_) => "BearerToken([REDACTED])",
        };
        f.write_str(label)
    }
}

impl AzureAuth {
    /// Returns the `(header name, header value)` pair for this credential.
    ///
    /// * `ApiKey` yields `("api-key", <key>)`.
    /// * `BearerToken` yields `("authorization", "Bearer <token>")`.
    ///
    /// The value is a freshly allocated `String`; `self` is left untouched.
    pub fn header(&self) -> (&'static str, String) {
        match self {
            Self::ApiKey(secret) => ("api-key", secret.to_owned()),
            Self::BearerToken(secret) => ("authorization", format!("Bearer {secret}")),
        }
    }
}
/// Builds the base URL for requests against `endpoint`.
///
/// All trailing `/` characters are removed from `endpoint` and `/openai` is
/// appended. When `deployment` is `Some`, `/deployments/{deployment}` is
/// appended after that. The deployment name is inserted verbatim; no
/// escaping or validation is performed.
pub fn azure_base_url(endpoint: &str, deployment: Option<&str>) -> String {
    let mut url = String::from(endpoint.trim_end_matches('/'));
    url.push_str("/openai");
    if let Some(name) = deployment {
        url.push_str("/deployments/");
        url.push_str(name);
    }
    url
}
/// Prefixes `path` with `/deployments/{model}` when the deployment has to be
/// taken from the request body.
///
/// The rewrite happens only when all three conditions hold:
///
/// * `has_deployment` is `false`,
/// * `path` is exactly one of [`DEPLOYMENTS_ENDPOINTS`],
/// * `model_from_body` is `Some`.
///
/// In every other case an owned copy of `path` is returned unchanged. The
/// model name is inserted verbatim, without escaping.
pub fn rewrite_path(path: &str, has_deployment: bool, model_from_body: Option<&str>) -> String {
    match model_from_body {
        Some(model) if !has_deployment && DEPLOYMENTS_ENDPOINTS.contains(&path) => {
            format!("/deployments/{model}{path}")
        }
        _ => path.to_owned(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint without a trailing slash, shared by the base-URL tests.
    const ENDPOINT: &str = "https://example.openai.azure.com";

    #[test]
    fn base_url_with_deployment() {
        let url = azure_base_url(ENDPOINT, Some("gpt-4o"));
        assert_eq!(
            url,
            "https://example.openai.azure.com/openai/deployments/gpt-4o"
        );
    }

    #[test]
    fn base_url_without_deployment() {
        let url = azure_base_url(ENDPOINT, None);
        assert_eq!(url, "https://example.openai.azure.com/openai");
    }

    #[test]
    fn base_url_trims_trailing_slash() {
        // Same endpoint with a trailing slash, with and without a deployment.
        let cases = [
            (
                Some("dep"),
                "https://example.openai.azure.com/openai/deployments/dep",
            ),
            (None, "https://example.openai.azure.com/openai"),
        ];
        for (deployment, expected) in cases {
            assert_eq!(
                azure_base_url("https://example.openai.azure.com/", deployment),
                expected
            );
        }
    }

    #[test]
    fn rewrite_path_adds_deployment_for_known_endpoint_with_model() {
        // (path, model taken from the body, expected rewritten path)
        let cases = [
            (
                "/chat/completions",
                "gpt-4o",
                "/deployments/gpt-4o/chat/completions",
            ),
            (
                "/embeddings",
                "text-embedding-3-small",
                "/deployments/text-embedding-3-small/embeddings",
            ),
        ];
        for (path, model, expected) in cases {
            assert_eq!(rewrite_path(path, false, Some(model)), expected);
        }
    }

    #[test]
    fn rewrite_path_unchanged_when_deployment_configured() {
        let rewritten = rewrite_path("/chat/completions", true, Some("gpt-4o"));
        assert_eq!(rewritten, "/chat/completions");
    }

    #[test]
    fn rewrite_path_unchanged_without_model() {
        let rewritten = rewrite_path("/chat/completions", false, None);
        assert_eq!(rewritten, "/chat/completions");
    }

    #[test]
    fn rewrite_path_unchanged_for_non_deployment_endpoint() {
        // Paths outside DEPLOYMENTS_ENDPOINTS must come back untouched.
        for path in ["/models", "/files"] {
            assert_eq!(rewrite_path(path, false, Some("gpt-4o")), path);
        }
    }

    #[test]
    fn auth_header_api_key() {
        let auth = AzureAuth::ApiKey(String::from("secret-key"));
        assert_eq!(auth.header(), ("api-key", String::from("secret-key")));
    }

    #[test]
    fn auth_header_bearer_token() {
        let auth = AzureAuth::BearerToken(String::from("ad-token"));
        assert_eq!(
            auth.header(),
            ("authorization", String::from("Bearer ad-token"))
        );
    }

    #[test]
    fn deployments_endpoints_match_python() {
        assert_eq!(DEPLOYMENTS_ENDPOINTS.len(), 8);
        for expected in ["/audio/speech", "/images/edits"] {
            assert!(DEPLOYMENTS_ENDPOINTS.contains(&expected));
        }
    }
}