use serde_json::Value;
/// Azure OpenAI `api-version` query value used by [`AzureOpenAiAdapter::new`] until it is
/// overridden with [`AzureOpenAiAdapter::with_api_version`].
// NOTE(review): this is a `-preview` API version; confirm it is still the intended default.
const DEFAULT_API_VERSION: &str = "2024-02-15-preview";
/// Provider adapter that sends chat-completion requests to a specific Azure OpenAI deployment.
#[derive(Debug, Clone)]
pub struct AzureOpenAiAdapter {
/// Resource base URL as supplied to `new`; trailing slashes are trimmed only when building `api_url`.
base_url: String,
/// Deployment name inserted into the request path.
deployment: String,
/// Value sent as the `api-version` query parameter.
api_version: String,
/// Full request URL derived from `base_url`, `deployment` and `api_version`; recomputed whenever `api_version` changes.
api_url: String,
}
impl AzureOpenAiAdapter {
    /// Creates an adapter for the given Azure resource `base_url` and `deployment` name.
    ///
    /// The adapter uses [`DEFAULT_API_VERSION`] until overridden with
    /// [`with_api_version`](Self::with_api_version). The request URL is computed eagerly via
    /// [`build_azure_url`].
    #[must_use]
    pub fn new(base_url: impl Into<String>, deployment: impl Into<String>) -> Self {
        let adapter = Self {
            base_url: base_url.into(),
            deployment: deployment.into(),
            api_version: String::new(),
            api_url: String::new(),
        };
        // Route through the builder so `api_url` is derived in exactly one place and can
        // never disagree with `api_version`.
        adapter.with_api_version(DEFAULT_API_VERSION)
    }

    /// Replaces the `api-version` query parameter and recomputes the request URL.
    ///
    /// Consumes and returns `self`, so the result must be used.
    #[must_use]
    pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
        self.api_version = version.into();
        self.api_url = build_azure_url(&self.base_url, &self.deployment, &self.api_version);
        self
    }

    /// Removes the top-level `"model"` key from `payload` when it is a JSON object.
    /// Non-object payloads are left untouched.
    // NOTE(review): presumably dropped because Azure selects the model through the deployment
    // segment of the URL rather than the request body — confirm.
    fn strip_model(payload: &mut Value) {
        if let Some(obj) = payload.as_object_mut() {
            obj.remove("model");
        }
    }
}
/// Builds the Azure OpenAI chat-completions endpoint URL:
/// `{base}/openai/deployments/{deployment}/chat/completions?api-version={api_version}`.
///
/// All trailing `/` characters are stripped from `base_url` first, so the join never produces
/// a doubled separator. `deployment` and `api_version` are inserted verbatim (no
/// percent-encoding).
pub fn build_azure_url(base_url: &str, deployment: &str, api_version: &str) -> String {
    let trimmed_base = base_url.trim_end_matches('/');
    [
        trimmed_base,
        "/openai/deployments/",
        deployment,
        "/chat/completions?api-version=",
        api_version,
    ]
    .concat()
}
#[async_trait::async_trait]
impl super::base::ProviderAdapter for AzureOpenAiAdapter {
    /// Identifier for this provider.
    fn provider_name(&self) -> &str {
        "azure"
    }

    /// Removes the top-level `"model"` and `"_reasoning_effort"` keys from an object payload.
    /// Everything else is forwarded unchanged.
    // NOTE(review): `_reasoning_effort` looks like an internal hint that is not part of Azure's
    // request schema — confirm against the code that sets it.
    fn convert_request(&self, mut payload: Value) -> Value {
        Self::strip_model(&mut payload);
        // `if let` rather than `Option::map`: this is a pure side effect, and the removed value
        // is intentionally discarded.
        if let Some(obj) = payload.as_object_mut() {
            obj.remove("_reasoning_effort");
        }
        payload
    }

    /// Responses are passed through unchanged.
    fn convert_response(&self, response: Value) -> Value {
        response
    }

    /// The chat-completions URL built by `build_azure_url` from this adapter's base URL,
    /// deployment and API version.
    fn api_url(&self) -> &str {
        &self.api_url
    }

    /// This adapter contributes no extra headers.
    // NOTE(review): presumably authentication (e.g. an `api-key` header) is attached elsewhere —
    // confirm.
    fn extra_headers(&self) -> Vec<(String, String)> {
        Vec::new()
    }
}
// Unit tests live in the sibling file `azure_tests.rs` and are compiled only for test builds.
#[cfg(test)]
#[path = "azure_tests.rs"]
mod tests;