// openai_ergonomic/azure_middleware.rs
use http::Extensions;
use reqwest::header::HeaderValue;
use reqwest::{Request, Response};
use reqwest_middleware::{Middleware, Next};
use std::sync::Arc;

/// Request middleware for Azure OpenAI: sets the `api-key` header and, when a
/// deployment is configured, rewrites OpenAI-style paths to the Azure
/// `/openai/deployments/{deployment}/...` layout.
///
/// Cloning is cheap: every field is reference-counted.
#[derive(Clone)]
pub struct AzureAuthMiddleware {
    /// Sent as the `api-key` request header.
    api_key: Arc<String>,
    /// Appended as the `api-version` query parameter when a request path is
    /// rewritten and the URL does not already carry one.
    api_version: Arc<String>,
    /// Deployment used for path rewriting; `None` disables rewriting.
    deployment: Arc<Option<String>>,
}

// Implemented by hand instead of derived so the API key never reaches logs.
impl std::fmt::Debug for AzureAuthMiddleware {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AzureAuthMiddleware")
            .field("api_key", &"<redacted>")
            .field("api_version", &self.api_version)
            .field("deployment", &self.deployment)
            .finish()
    }
}

impl AzureAuthMiddleware {
    /// `api-version` used when [`AzureAuthMiddleware::new`] receives `None`.
    pub const DEFAULT_API_VERSION: &str = "2024-02-01";

    /// Creates the middleware.
    ///
    /// * `api_key` — sent verbatim as the `api-key` header.
    /// * `api_version` — falls back to [`Self::DEFAULT_API_VERSION`] when `None`.
    /// * `deployment` — when `Some`, request paths are rewritten to target it.
    pub fn new(api_key: String, api_version: Option<String>, deployment: Option<String>) -> Self {
        Self {
            api_key: Arc::new(api_key),
            api_version: Arc::new(
                api_version.unwrap_or_else(|| Self::DEFAULT_API_VERSION.to_owned()),
            ),
            deployment: Arc::new(deployment),
        }
    }
}

30#[async_trait::async_trait]
31impl Middleware for AzureAuthMiddleware {
32 async fn handle(
33 &self,
34 mut req: Request,
35 extensions: &mut Extensions,
36 next: Next<'_>,
37 ) -> reqwest_middleware::Result<Response> {
38 req.headers_mut()
40 .insert("api-key", self.api_key.parse().unwrap());
41
42 let url = req.url().clone();
44 let path = url.path();
45
46 if let Some(deployment) = self.deployment.as_ref() {
50 let new_path = if path.starts_with("/v1/") {
51 path.replacen("/v1/", &format!("/openai/deployments/{deployment}/"), 1)
53 } else if !path.starts_with("/openai/") {
54 format!("/openai/deployments/{deployment}{path}")
56 } else {
57 path.to_string()
59 };
60
61 if new_path != path {
62 let mut new_url = url.clone();
63 new_url.set_path(&new_path);
64
65 if new_url.query_pairs().all(|(key, _)| key != "api-version") {
67 new_url
68 .query_pairs_mut()
69 .append_pair("api-version", &self.api_version);
70 }
71
72 *req.url_mut() = new_url;
73 }
74 }
75
76 next.run(req, extensions).await
77 }
78}