Skip to main content

aliyun_oss/http/
middleware.rs

1//! HTTP middleware chain for request processing.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use async_trait::async_trait;
7
8use super::client::{HttpClient, HttpRequest, HttpResponse};
9use crate::error::Result;
10
11/// Trait for middleware that intercepts HTTP requests.
12#[async_trait]
13pub trait Middleware: Send + Sync {
14    /// Handles a request, potentially passing it through the chain.
15    async fn handle(&self, request: HttpRequest, chain: &MiddlewareChain) -> Result<HttpResponse>;
16}
17
18/// A chain of middleware that processes HTTP requests sequentially.
19pub struct MiddlewareChain {
20    middlewares: Vec<Box<dyn Middleware>>,
21    http: Arc<dyn HttpClient>,
22}
23
24impl MiddlewareChain {
25    /// Creates a new `MiddlewareChain` with the underlying HTTP client.
26    pub fn new(http: Arc<dyn HttpClient>) -> Self {
27        Self {
28            middlewares: Vec::new(),
29            http,
30        }
31    }
32
33    /// Adds a middleware to the chain.
34    pub fn with_middleware(mut self, middleware: impl Middleware + 'static) -> Self {
35        self.middlewares.push(Box::new(middleware));
36        self
37    }
38
39    /// Sends the request through the middleware chain.
40    pub async fn send(&self, request: HttpRequest) -> Result<HttpResponse> {
41        if self.middlewares.is_empty() {
42            return self.http.send(request).await;
43        }
44        self.middlewares[0].handle(request, self).await
45    }
46
47    pub(crate) async fn send_through(
48        &self,
49        request: HttpRequest,
50        start_index: usize,
51    ) -> Result<HttpResponse> {
52        let next_index = start_index + 1;
53        if next_index < self.middlewares.len() {
54            self.middlewares[next_index].handle(request, self).await
55        } else {
56            self.http.send(request).await
57        }
58    }
59}
60
61/// Middleware that signs requests using the configured OSS signer and credentials.
62pub struct SigningMiddleware {
63    signer: Arc<dyn crate::signer::Signer>,
64    credentials: Arc<dyn crate::config::credentials::CredentialsProvider>,
65    region: String,
66}
67
68impl SigningMiddleware {
69    /// Creates a new `SigningMiddleware`.
70    pub fn new(
71        signer: Arc<dyn crate::signer::Signer>,
72        credentials: Arc<dyn crate::config::credentials::CredentialsProvider>,
73        region: impl Into<String>,
74    ) -> Self {
75        Self {
76            signer,
77            credentials,
78            region: region.into(),
79        }
80    }
81}
82
83#[async_trait]
84impl Middleware for SigningMiddleware {
85    async fn handle(&self, request: HttpRequest, chain: &MiddlewareChain) -> Result<HttpResponse> {
86        let creds = self.credentials.credentials().await?;
87
88        let headers: Vec<(String, String)> = request
89            .headers
90            .iter()
91            .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
92            .collect();
93
94        let mut signing_request = crate::signer::SigningRequest {
95            method: request.method.as_str().to_string(),
96            uri: extract_path(&request.uri),
97            region: self.region.clone(),
98            query_params: Vec::new(),
99            headers,
100            timestamp: chrono::Utc::now().format("%Y%m%dT%H%M%SZ").to_string(),
101        };
102
103        self.signer.sign(&mut signing_request, &creds)?;
104
105        let mut signed_request = HttpRequest::builder()
106            .method(request.method.clone())
107            .uri(&request.uri);
108
109        for (key, value) in &signing_request.headers {
110            if let (Ok(name), Ok(val)) = (
111                http::HeaderName::from_bytes(key.as_bytes()),
112                http::HeaderValue::from_str(value),
113            ) {
114                signed_request = signed_request.header(name, val);
115            }
116        }
117
118        signed_request = signed_request.body(request.body.clone().unwrap_or_default());
119
120        chain.send_through(signed_request.build(), 0).await
121    }
122}
123
/// Returns the path component of `uri`, without query string or fragment.
///
/// `uri` may be an absolute URL (`scheme://authority/path?query#fragment`) or an
/// origin-relative path starting with `/`. Anything else, and any URL without a path, yields
/// `"/"`. The result always starts with `/`.
pub(crate) fn extract_path(uri: &str) -> String {
    let is_query_or_fragment = |c: char| c == '?' || c == '#';

    // Check relative paths first so a path that merely contains "://" is not treated as a URL.
    let path = if uri.starts_with('/') {
        uri
    } else if let Some(pos) = uri.find("://") {
        let after_scheme = &uri[pos + 3..];
        // The authority ends at the first '/', '?' or '#'; only a '/' starts a path. This keeps
        // a '/' inside the query (e.g. `https://host?x=/y`) from being taken as the path.
        match after_scheme.find(|c: char| c == '/' || is_query_or_fragment(c)) {
            Some(start) if after_scheme[start..].starts_with('/') => &after_scheme[start..],
            _ => return "/".to_string(),
        }
    } else {
        return "/".to_string();
    };

    // `path` starts with '/', which is not a delimiter, so the result is never empty.
    let end = path.find(is_query_or_fragment).unwrap_or(path.len());
    path[..end].to_string()
}
146
/// Middleware that stamps a `User-Agent` header onto every outgoing request.
pub struct UserAgentMiddleware {
    /// The configured user agent string.
    user_agent: String,
}

impl UserAgentMiddleware {
    /// Builds the middleware from any string-like `user_agent` value.
    pub fn new(user_agent: impl Into<String>) -> Self {
        let user_agent: String = user_agent.into();
        UserAgentMiddleware { user_agent }
    }
}
160
161#[async_trait]
162impl Middleware for UserAgentMiddleware {
163    async fn handle(
164        &self,
165        mut request: HttpRequest,
166        chain: &MiddlewareChain,
167    ) -> Result<HttpResponse> {
168        request.headers.insert(
169            http::HeaderName::from_static("user-agent"),
170            http::HeaderValue::from_str(&self.user_agent)
171                .unwrap_or(http::HeaderValue::from_static("aliyun-oss")),
172        );
173        chain.send_through(request, 0).await
174    }
175}
176
/// Configuration for retry behavior: how many retries to allow and how long to wait.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retries. Defaults to 3.
    pub max_retries: u32,
    /// Base delay between retries. Defaults to 100 ms.
    pub base_delay: Duration,
    /// Maximum backoff duration. Defaults to 10 s.
    pub max_backoff: Duration,
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self::new()
    }
}

impl RetryConfig {
    /// Returns the default configuration: 3 retries, 100 ms base delay, 10 s maximum backoff.
    pub fn new() -> Self {
        RetryConfig {
            max_retries: 3,
            base_delay: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
        }
    }

    /// Returns `self` with the maximum number of retries replaced by `max`.
    pub fn with_max_retries(self, max: u32) -> Self {
        Self {
            max_retries: max,
            ..self
        }
    }

    /// Returns `self` with the base delay between retries replaced by `delay`.
    pub fn with_base_delay(self, delay: Duration) -> Self {
        Self {
            base_delay: delay,
            ..self
        }
    }

    /// Returns `self` with the maximum backoff duration replaced by `backoff`.
    pub fn with_max_backoff(self, backoff: Duration) -> Self {
        Self {
            max_backoff: backoff,
            ..self
        }
    }
}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::client::ReqwestHttpClient;

    #[test]
    fn retry_config_default_values() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.base_delay, Duration::from_millis(100));
        assert_eq!(config.max_backoff, Duration::from_secs(10));
    }

    #[test]
    fn retry_config_new_matches_default() {
        let from_new = RetryConfig::new();
        let from_default = RetryConfig::default();
        assert_eq!(from_new.max_retries, from_default.max_retries);
        assert_eq!(from_new.base_delay, from_default.base_delay);
        assert_eq!(from_new.max_backoff, from_default.max_backoff);
    }

    #[test]
    fn retry_config_builder() {
        let config = RetryConfig::new()
            .with_max_retries(5)
            .with_base_delay(Duration::from_millis(200))
            .with_max_backoff(Duration::from_secs(30));
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.base_delay, Duration::from_millis(200));
        assert_eq!(config.max_backoff, Duration::from_secs(30));
    }

    #[test]
    fn middleware_chain_starts_without_middleware() {
        let http = Arc::new(ReqwestHttpClient::default());
        let chain = MiddlewareChain::new(http);
        assert!(chain.middlewares.is_empty());
    }

    #[test]
    fn middleware_chain_with_middleware() {
        let http = Arc::new(ReqwestHttpClient::default());
        let chain =
            MiddlewareChain::new(http).with_middleware(UserAgentMiddleware::new("aliyun-oss/0.1"));
        assert_eq!(chain.middlewares.len(), 1);
    }

    #[test]
    fn middleware_chain_keeps_every_added_middleware() {
        let http = Arc::new(ReqwestHttpClient::default());
        let chain = MiddlewareChain::new(http)
            .with_middleware(UserAgentMiddleware::new("first"))
            .with_middleware(UserAgentMiddleware::new("second"));
        assert_eq!(chain.middlewares.len(), 2);
    }

    #[test]
    fn extract_path_from_full_url() {
        assert_eq!(
            extract_path("https://oss-cn-hangzhou.aliyuncs.com/bucket/key"),
            "/bucket/key"
        );
        assert_eq!(extract_path("https://oss-cn-hangzhou.aliyuncs.com/"), "/");
    }

    #[test]
    fn extract_path_from_url_without_path() {
        assert_eq!(extract_path("https://oss-cn-hangzhou.aliyuncs.com"), "/");
    }

    #[test]
    fn extract_path_from_relative() {
        assert_eq!(extract_path("/bucket/key"), "/bucket/key");
        assert_eq!(extract_path("/"), "/");
    }

    #[test]
    fn extract_path_without_leading_slash_is_root() {
        assert_eq!(extract_path("bucket/key"), "/");
    }

    #[test]
    fn extract_path_with_query_string() {
        assert_eq!(extract_path("https://example.com/path?a=1&b=2"), "/path");
        assert_eq!(extract_path("/bucket/key?acl"), "/bucket/key");
    }
}