openai_ergonomic/
azure_middleware.rs

1//! Middleware for Azure `OpenAI` authentication.
2//!
3//! This module provides middleware that adds the appropriate authentication
4//! headers for Azure `OpenAI` API requests and transforms paths to Azure format.
5
6use http::Extensions;
7use reqwest::{Request, Response};
8use reqwest_middleware::{Middleware, Next};
9use std::sync::Arc;
10
/// Middleware that adds Azure `OpenAI` authentication headers and transforms paths.
///
/// Cloning is cheap: every field is reference counted, so clones share the
/// same underlying strings.
#[derive(Clone)]
pub struct AzureAuthMiddleware {
    /// Secret sent in the `api-key` request header.
    api_key: Arc<String>,
    /// Value used for the `api-version` query parameter.
    api_version: Arc<String>,
    /// Deployment name substituted into `/openai/deployments/{deployment}/...`;
    /// `None` disables path rewriting.
    deployment: Arc<Option<String>>,
}

// `#[derive(Debug)]` would print the API key verbatim, so `Debug` is written by
// hand to keep the secret out of logs, traces and panic messages.
impl std::fmt::Debug for AzureAuthMiddleware {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AzureAuthMiddleware")
            .field("api_key", &"<redacted>")
            .field("api_version", &self.api_version)
            .field("deployment", &self.deployment)
            .finish()
    }
}
18
19impl AzureAuthMiddleware {
20    /// Create a new Azure authentication middleware.
21    pub fn new(api_key: String, api_version: Option<String>, deployment: Option<String>) -> Self {
22        Self {
23            api_key: Arc::new(api_key),
24            api_version: Arc::new(api_version.unwrap_or_else(|| "2024-02-01".to_string())),
25            deployment: Arc::new(deployment),
26        }
27    }
28}
29
30#[async_trait::async_trait]
31impl Middleware for AzureAuthMiddleware {
32    async fn handle(
33        &self,
34        mut req: Request,
35        extensions: &mut Extensions,
36        next: Next<'_>,
37    ) -> reqwest_middleware::Result<Response> {
38        // Add api-key header for Azure OpenAI
39        req.headers_mut()
40            .insert("api-key", self.api_key.parse().unwrap());
41
42        // Transform the URL path for Azure `OpenAI`
43        let url = req.url().clone();
44        let path = url.path();
45
46        // Azure `OpenAI` uses paths like: /openai/deployments/{deployment-id}/chat/completions
47        // Standard `OpenAI` uses: /v1/chat/completions
48        // We need to transform both /v1/* and /* to /openai/deployments/{deployment}/*
49        if let Some(deployment) = self.deployment.as_ref() {
50            let new_path = if path.starts_with("/v1/") {
51                // Handle /v1/chat/completions -> /openai/deployments/{deployment}/chat/completions
52                path.replacen("/v1/", &format!("/openai/deployments/{deployment}/"), 1)
53            } else if !path.starts_with("/openai/") {
54                // Handle /chat/completions -> /openai/deployments/{deployment}/chat/completions
55                format!("/openai/deployments/{deployment}{path}")
56            } else {
57                // Path already in correct format
58                path.to_string()
59            };
60
61            if new_path != path {
62                let mut new_url = url.clone();
63                new_url.set_path(&new_path);
64
65                // Add api-version as query parameter if not already present
66                if new_url.query_pairs().all(|(key, _)| key != "api-version") {
67                    new_url
68                        .query_pairs_mut()
69                        .append_pair("api-version", &self.api_version);
70                }
71
72                *req.url_mut() = new_url;
73            }
74        }
75
76        next.run(req, extensions).await
77    }
78}