aliyun-oss 0.2.0

aliyun oss sdk
Documentation
//! HTTP middleware chain for request processing.

use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;

use super::client::{HttpClient, HttpRequest, HttpResponse};
use crate::error::Result;

/// An interceptor that sits between callers and the underlying HTTP client.
///
/// Implementors receive the outgoing request together with the
/// [`MiddlewareChain`] that invoked them, and decide whether (and how) to pass
/// the request onward through that chain.
#[async_trait]
pub trait Middleware: Send + Sync {
    /// Processes `request`, usually forwarding it onward via `chain`, and
    /// returns the resulting response.
    async fn handle(
        &self,
        request: HttpRequest,
        chain: &MiddlewareChain,
    ) -> Result<HttpResponse>;
}

/// An ordered chain of [`Middleware`] that terminates in an [`HttpClient`].
///
/// Middleware run in the order they were added. Each middleware is handed a
/// view of the chain that records its own position. Forwarding with
/// `chain.send_through(request, 0)` therefore always reaches the *next*
/// middleware, or the HTTP client once the chain is exhausted. It never
/// re-enters the same middleware.
pub struct MiddlewareChain {
    /// Registered middleware. Shared via `Arc` so the per-call position views
    /// created in `dispatch` do not copy the middleware themselves.
    middlewares: Arc<Vec<Arc<dyn Middleware>>>,
    /// Client used once every middleware has run.
    http: Arc<dyn HttpClient>,
    /// Index of the middleware this view was handed to; `send_through`
    /// dispatches relative to it. The chain built by `new` uses `0`.
    position: usize,
}

impl MiddlewareChain {
    /// Creates an empty `MiddlewareChain` that forwards directly to `http`.
    pub fn new(http: Arc<dyn HttpClient>) -> Self {
        Self {
            middlewares: Arc::new(Vec::new()),
            http,
            position: 0,
        }
    }

    /// Appends `middleware` to the end of the chain and returns the chain.
    #[must_use]
    pub fn with_middleware(mut self, middleware: impl Middleware + 'static) -> Self {
        // `make_mut` clones the list only if another handle to it is still
        // alive. The elements are `Arc`s, so such a clone is cheap.
        Arc::make_mut(&mut self.middlewares).push(Arc::new(middleware));
        self
    }

    /// Sends `request` through the whole chain, starting at the first
    /// middleware.
    ///
    /// With no middleware registered, the request goes straight to the HTTP
    /// client.
    pub async fn send(&self, request: HttpRequest) -> Result<HttpResponse> {
        self.dispatch(request, 0).await
    }

    /// Forwards `request` past the middleware that owns this view.
    ///
    /// `start_index` is an offset relative to the current middleware:
    /// - `0` targets the next middleware.
    /// - `1` skips one middleware, and so on.
    ///
    /// If no middleware remains at the target position, the request is sent
    /// with the HTTP client.
    pub(crate) async fn send_through(
        &self,
        request: HttpRequest,
        start_index: usize,
    ) -> Result<HttpResponse> {
        // Saturating: an absurd offset simply runs off the end of the chain.
        let next_index = self.position.saturating_add(start_index).saturating_add(1);
        self.dispatch(request, next_index).await
    }

    /// Runs the middleware at `index`, handing it a view positioned at `index`.
    /// Falls back to the HTTP client when `index` is past the end of the chain.
    async fn dispatch(&self, request: HttpRequest, index: usize) -> Result<HttpResponse> {
        match self.middlewares.get(index) {
            Some(middleware) => {
                let view = Self {
                    middlewares: Arc::clone(&self.middlewares),
                    http: Arc::clone(&self.http),
                    position: index,
                };
                middleware.handle(request, &view).await
            }
            None => self.http.send(request).await,
        }
    }
}

/// Middleware that signs each outgoing request with an OSS signer before
/// forwarding it.
pub struct SigningMiddleware {
    /// Signer invoked on every request.
    signer: Arc<dyn crate::signer::Signer>,
    /// Provider queried for credentials on every request.
    credentials: Arc<dyn crate::config::credentials::CredentialsProvider>,
    /// Region copied into every signing request.
    region: String,
}

impl SigningMiddleware {
    /// Builds a `SigningMiddleware` from a signer, a credentials provider and
    /// the target region.
    pub fn new(
        signer: Arc<dyn crate::signer::Signer>,
        credentials: Arc<dyn crate::config::credentials::CredentialsProvider>,
        region: impl Into<String>,
    ) -> Self {
        let region: String = region.into();
        Self {
            region,
            credentials,
            signer,
        }
    }
}

#[async_trait]
impl Middleware for SigningMiddleware {
    /// Signs `request` and forwards the signed copy to the next link in `chain`.
    ///
    /// Credentials are fetched from the provider on every call. The signer
    /// receives:
    /// - the request method;
    /// - the path (via `extract_path`);
    /// - the configured region;
    /// - the request headers;
    /// - a UTC timestamp formatted as `%Y%m%dT%H%M%SZ`.
    ///
    /// The outgoing request is rebuilt from the method, URI, body and the
    /// headers left in the signing request after `sign`.
    ///
    /// # Errors
    /// Propagates errors from the credentials provider, the signer, and the
    /// rest of the chain.
    async fn handle(
        &self,
        mut request: HttpRequest,
        chain: &MiddlewareChain,
    ) -> Result<HttpResponse> {
        let creds = self.credentials.credentials().await?;

        // `HeaderValue::to_str` fails on anything outside visible ASCII. Mapping
        // that failure to "" would sign and send an empty value for, e.g.,
        // UTF-8 `x-oss-meta-*` headers. Decode the raw bytes instead; valid
        // UTF-8 round-trips exactly.
        let headers: Vec<(String, String)> = request
            .headers
            .iter()
            .map(|(name, value)| {
                let value = String::from_utf8_lossy(value.as_bytes()).into_owned();
                (name.as_str().to_string(), value)
            })
            .collect();

        // NOTE(review): the URI's query string is not handed to the signer
        // (`extract_path` strips it). If the signer needs sub-resources such as
        // `?acl` or `?uploadId=` in the canonical request, parse them into
        // `query_params` here. TODO confirm against `SigningRequest`.
        let mut signing_request = crate::signer::SigningRequest {
            method: request.method.as_str().to_string(),
            uri: extract_path(&request.uri),
            region: self.region.clone(),
            query_params: Vec::new(),
            headers,
            timestamp: chrono::Utc::now().format("%Y%m%dT%H%M%SZ").to_string(),
        };

        self.signer.sign(&mut signing_request, &creds)?;

        let mut signed_request = HttpRequest::builder()
            .method(request.method.clone())
            .uri(&request.uri);

        for (key, value) in &signing_request.headers {
            // `from_bytes` (unlike `from_str`) accepts the non-ASCII bytes kept
            // above. Entries that are still not valid header names/values are
            // skipped.
            if let (Ok(name), Ok(val)) = (
                http::HeaderName::from_bytes(key.as_bytes()),
                http::HeaderValue::from_bytes(value.as_bytes()),
            ) {
                signed_request = signed_request.header(name, val);
            }
        }

        // `request` is owned and not reused, so move the body out instead of
        // cloning a potentially large upload payload.
        signed_request = signed_request.body(request.body.take().unwrap_or_default());

        chain.send_through(signed_request.build(), 0).await
    }
}

/// Returns the path component of `uri`, without query string or fragment.
///
/// Handles two input forms:
/// - An origin-form URI (leading `/`) is returned with its `?query` / `#fragment`
///   removed.
/// - An absolute URI (`scheme://authority/path`) yields the text from the first
///   `/` after the authority.
///
/// Anything else, including an absolute URI with an empty path, yields `"/"`.
///
/// The query and fragment are stripped *before* the URI is parsed. As a result,
/// a `://` or `/` inside them (e.g. `/obj?src=https://h/x` or `https://h?next=/a`)
/// cannot be mistaken for the scheme separator or the start of the path.
pub(crate) fn extract_path(uri: &str) -> String {
    let uri = match uri.find(|c: char| c == '?' || c == '#') {
        Some(end) => &uri[..end],
        None => uri,
    };

    if uri.starts_with('/') {
        return uri.to_string();
    }

    uri.find("://")
        .map(|pos| &uri[pos + 3..])
        .and_then(|rest| rest.find('/').map(|start| &rest[start..]))
        .unwrap_or("/")
        .to_string()
}

/// Middleware that stamps a fixed `User-Agent` header onto every request.
pub struct UserAgentMiddleware {
    /// Header value to send as `User-Agent`.
    user_agent: String,
}

impl UserAgentMiddleware {
    /// Returns a middleware that sends `user_agent` as the `User-Agent` header.
    pub fn new(user_agent: impl Into<String>) -> Self {
        let user_agent: String = user_agent.into();
        Self { user_agent }
    }
}

#[async_trait]
impl Middleware for UserAgentMiddleware {
    /// Sets the `User-Agent` header and forwards the request to the next link in
    /// `chain`.
    ///
    /// If the configured string is not a valid header value, the literal
    /// `aliyun-oss` is sent instead.
    async fn handle(
        &self,
        mut request: HttpRequest,
        chain: &MiddlewareChain,
    ) -> Result<HttpResponse> {
        let agent = match http::HeaderValue::from_str(&self.user_agent) {
            Ok(valid) => valid,
            Err(_) => http::HeaderValue::from_static("aliyun-oss"),
        };
        request.headers.insert(http::header::USER_AGENT, agent);
        chain.send_through(request, 0).await
    }
}

/// Tunables for retrying failed requests.
///
/// The [`Default`] values are 3 retries, a 100 ms base delay and a 10 s
/// maximum backoff.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retries.
    pub max_retries: u32,
    /// Base delay between retries.
    pub base_delay: Duration,
    /// Upper bound on the backoff duration.
    pub max_backoff: Duration,
}

impl Default for RetryConfig {
    fn default() -> Self {
        let base_delay = Duration::from_millis(100);
        let max_backoff = Duration::from_secs(10);
        Self {
            max_retries: 3,
            base_delay,
            max_backoff,
        }
    }
}

impl RetryConfig {
    /// Returns a configuration populated with the [`Default`] values.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns this configuration with `max_retries` replaced by `max`.
    pub fn with_max_retries(self, max: u32) -> Self {
        Self {
            max_retries: max,
            ..self
        }
    }

    /// Returns this configuration with `base_delay` replaced by `delay`.
    pub fn with_base_delay(self, delay: Duration) -> Self {
        Self {
            base_delay: delay,
            ..self
        }
    }

    /// Returns this configuration with `max_backoff` replaced by `backoff`.
    pub fn with_max_backoff(self, backoff: Duration) -> Self {
        Self {
            max_backoff: backoff,
            ..self
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::client::ReqwestHttpClient;

    /// Builds a chain backed by a default reqwest client, with no middleware.
    fn reqwest_chain() -> MiddlewareChain {
        MiddlewareChain::new(Arc::new(ReqwestHttpClient::default()))
    }

    #[test]
    fn retry_config_default_values() {
        let config = RetryConfig::default();
        assert_eq!(
            (config.max_retries, config.base_delay, config.max_backoff),
            (3, Duration::from_millis(100), Duration::from_secs(10))
        );
    }

    #[test]
    fn retry_config_builder() {
        let config = RetryConfig::new()
            .with_max_retries(5)
            .with_base_delay(Duration::from_millis(200));
        assert_eq!(
            (config.max_retries, config.base_delay),
            (5, Duration::from_millis(200))
        );
    }

    #[test]
    fn middleware_chain_no_middleware_passes_to_http() {
        assert!(reqwest_chain().middlewares.is_empty());
    }

    #[test]
    fn middleware_chain_with_middleware() {
        let chain = reqwest_chain().with_middleware(UserAgentMiddleware::new("aliyun-oss/0.1"));
        assert_eq!(chain.middlewares.len(), 1);
    }

    /// Asserts `extract_path(input) == expected` for every pair in `cases`.
    fn check_paths(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(extract_path(input), expected, "input: {}", input);
        }
    }

    #[test]
    fn extract_path_from_full_url() {
        check_paths(&[
            ("https://oss-cn-hangzhou.aliyuncs.com/bucket/key", "/bucket/key"),
            ("https://oss-cn-hangzhou.aliyuncs.com/", "/"),
        ]);
    }

    #[test]
    fn extract_path_from_relative() {
        check_paths(&[("/bucket/key", "/bucket/key"), ("/", "/")]);
    }

    #[test]
    fn extract_path_with_query_string() {
        check_paths(&[("https://example.com/path?a=1&b=2", "/path")]);
    }
}