use http::Extensions;
use reqwest::{Request, Response};
use reqwest_middleware::{Middleware, Next};
use std::sync::Arc;
/// Middleware that authenticates outgoing requests against Azure OpenAI and,
/// optionally, rewrites OpenAI-style paths onto a specific Azure deployment.
///
/// Every field is behind an [`Arc`], so cloning the middleware only bumps
/// reference counts.
#[derive(Clone)]
pub struct AzureAuthMiddleware {
    /// Secret sent in the `api-key` request header.
    api_key: Arc<String>,
    /// Value used for the `api-version` query parameter.
    api_version: Arc<String>,
    /// Deployment name used for path rewriting; `None` leaves paths untouched.
    deployment: Arc<Option<String>>,
}

impl AzureAuthMiddleware {
    /// API version used when [`AzureAuthMiddleware::new`] is given `None`.
    pub const DEFAULT_API_VERSION: &str = "2024-02-01";

    /// Creates the middleware.
    ///
    /// * `api_key` – Azure OpenAI key placed in the `api-key` header.
    /// * `api_version` – value for the `api-version` query parameter; `None`
    ///   falls back to [`Self::DEFAULT_API_VERSION`].
    /// * `deployment` – deployment name to route OpenAI-style paths to; `None`
    ///   disables path rewriting.
    pub fn new(api_key: String, api_version: Option<String>, deployment: Option<String>) -> Self {
        Self {
            api_key: Arc::new(api_key),
            api_version: Arc::new(
                api_version.unwrap_or_else(|| Self::DEFAULT_API_VERSION.to_owned()),
            ),
            deployment: Arc::new(deployment),
        }
    }
}
#[async_trait::async_trait]
impl Middleware for AzureAuthMiddleware {
async fn handle(
&self,
mut req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> reqwest_middleware::Result<Response> {
req.headers_mut()
.insert("api-key", self.api_key.parse().unwrap());
let url = req.url().clone();
let path = url.path();
if let Some(deployment) = self.deployment.as_ref() {
let new_path = if path.starts_with("/v1/") {
path.replacen("/v1/", &format!("/openai/deployments/{deployment}/"), 1)
} else if !path.starts_with("/openai/") {
format!("/openai/deployments/{deployment}{path}")
} else {
path.to_string()
};
if new_path != path {
let mut new_url = url.clone();
new_url.set_path(&new_path);
if new_url.query_pairs().all(|(key, _)| key != "api-version") {
new_url
.query_pairs_mut()
.append_pair("api-version", &self.api_version);
}
*req.url_mut() = new_url;
}
}
next.run(req, extensions).await
}
}