actix-cloud 0.5.1

Actix Cloud is an all-in-one web framework based on Actix Web.
use std::{future::Future, rc::Rc};

use actix_web::{
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    HttpMessage, HttpRequest,
};
use futures::future::{ready, LocalBoxFuture, Ready};
use qstring::QString;

use crate::router::CSRFType;

pub struct Middleware<F> {
    cookie: Rc<String>,
    header: Rc<String>,
    checker: Rc<F>,
}

impl<F> Clone for Middleware<F> {
    fn clone(&self) -> Self {
        Self {
            cookie: self.cookie.clone(),
            header: self.header.clone(),
            checker: self.checker.clone(),
        }
    }
}

impl<F, Fut> Middleware<F>
where
    F: Fn(HttpRequest, String) -> Fut,
    Fut: Future<Output = Result<bool, actix_web::Error>>,
{
    pub fn new(cookie: String, header: String, checker: F) -> Self {
        Self {
            cookie: Rc::new(cookie),
            header: Rc::new(header),
            checker: Rc::new(checker),
        }
    }
}

impl<S, B, F, Fut> Transform<S, ServiceRequest> for Middleware<F>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
    S::Future: 'static,
    B: 'static,
    F: Fn(HttpRequest, String) -> Fut + 'static,
    Fut: Future<Output = Result<bool, actix_web::Error>>,
{
    type Response = ServiceResponse<B>;
    type Error = actix_web::Error;
    type InitError = ();
    type Transform = MiddlewareService<S, F>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ready(Ok(MiddlewareService {
            service: Rc::new(service),
            cookie: self.cookie.clone(),
            header: self.header.clone(),
            checker: self.checker.clone(),
        }))
    }
}

pub struct MiddlewareService<S, F> {
    service: Rc<S>,
    cookie: Rc<String>,
    header: Rc<String>,
    checker: Rc<F>,
}

impl<S, B, F, Fut> MiddlewareService<S, F>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
    S::Future: 'static,
    B: 'static,
    F: Fn(HttpRequest, String) -> Fut + 'static,
    Fut: Future<Output = Result<bool, actix_web::Error>>,
{
    fn get_safe_header(req: &ServiceRequest, name: &str) -> Option<String> {
        let mut ret: Vec<&str> = req
            .headers()
            .get_all(name)
            .map(|x| x.to_str().unwrap_or_default())
            .filter(|x| !x.is_empty())
            .collect();
        if ret.len() != 1 {
            return None;
        }
        ret.pop().map(ToOwned::to_owned)
    }

    async fn check_csrf(
        req: &ServiceRequest,
        cookie: &str,
        header: &str,
        checker: Rc<F>,
        allow_param: bool,
    ) -> Result<bool, actix_web::Error> {
        let Some(cookie) = req.cookie(cookie) else {
            return Ok(false);
        };
        let mut csrf = Self::get_safe_header(req, header);
        if csrf.is_none() && allow_param {
            let qs = QString::from(req.query_string());
            csrf = qs.get(header).map(ToOwned::to_owned);
        }
        let Some(csrf) = csrf else {
            return Ok(false);
        };
        if csrf != cookie.value() {
            return Ok(false);
        }
        checker(req.request().clone(), csrf).await
    }
}

impl<S, B, F, Fut> Service<ServiceRequest> for MiddlewareService<S, F>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
    S::Future: 'static,
    B: 'static,
    F: Fn(HttpRequest, String) -> Fut + 'static,
    Fut: Future<Output = Result<bool, actix_web::Error>>,
{
    type Response = ServiceResponse<B>;
    type Error = actix_web::Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    fn call(&self, req: ServiceRequest) -> Self::Future {
        let srv = self.service.clone();
        let header = self.header.clone();
        let cookie = self.cookie.clone();
        let checker = self.checker.clone();
        Box::pin(async move {
            let csrf = req.extensions().get::<CSRFType>().unwrap().to_owned();
            if csrf.is_force_header() || csrf.is_force_param() || !req.method().is_safe() {
                let ret = match csrf {
                    CSRFType::Header => {
                        Self::check_csrf(&req, &cookie, &header, checker, false).await
                    }
                    CSRFType::Param => {
                        Self::check_csrf(&req, &cookie, &header, checker, true).await
                    }
                    CSRFType::ForceHeader => {
                        Self::check_csrf(&req, &cookie, &header, checker, false).await
                    }
                    CSRFType::ForceParam => {
                        Self::check_csrf(&req, &cookie, &header, checker, true).await
                    }
                    CSRFType::Disabled => Ok(true),
                }?;
                if !ret {
                    return Err(actix_web::error::ErrorBadRequest("CSRF check failed"));
                }
            }
            srv.call(req).await
        })
    }
}