//! zero4rs 2.0.0
//!
//! zero4rs is a powerful, pragmatic, and extremely fast web framework for Rust.
//!
//! Documentation
use std::{
    pin::Pin,
    task::{Context, Poll},
};

use futures::{
    future::{ok, Ready},
    Future,
};

use actix_web::{
    body::EitherBody,
    dev::{Service, ServiceRequest, ServiceResponse, Transform},
    http, web, Error as ActixError, FromRequest, HttpMessage, HttpResponse, Result,
};

use crate::core::auth0::jwt_token::Claims;
use crate::core::auth0::jwt_token::JwtToken;
use crate::core::auth0::user_session::TypedSession;

use crate::core::auth0::Agent;
use crate::core::auth0::Language;
use crate::core::auth0::Requestor;
use crate::core::auth0::UserDetails;
use crate::core::auth0::UserId;

use crate::core::result::R;

use crate::server::AppContext;

use crate::services::user_service::get_user_details;

/// Role names attached to a protected route.
type Permission = Vec<String>;

/// Access requirement for a route, enforced by the `Authorized` middleware.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Protected {
    /// The route may be reached without any authenticated user.
    AnonymousUser,
    /// An authenticated user holding at least one of these roles is required;
    /// an empty list accepts any authenticated user.
    LoginUser(Permission),
    /// Role-restricted access for internal callers; the middleware checks it
    /// exactly like `LoginUser`.
    InnerService(Permission),
}

impl Protected {
    /// Builds a [`Protected::LoginUser`] requirement from borrowed role names.
    pub fn build(roles: Vec<&str>) -> Self {
        let owned: Permission = roles.into_iter().map(String::from).collect();
        Self::LoginUser(owned)
    }
}

/// Middleware factory that enforces a [`Protected`] requirement on the services it wraps.
///
/// Field 0 is the requirement shared with every middleware instance; field 1 is
/// the application state handed to each instance.
pub struct Authorized(std::rc::Rc<Required>, web::Data<AppContext>);

impl Authorized {
    /// Creates the factory for `required`, keeping `app_state` for user lookups.
    pub fn builder(app_state: web::Data<AppContext>, required: Protected) -> Authorized {
        let shared_requirement = std::rc::Rc::new(Required(required));
        Self(shared_requirement, app_state)
    }
}

/// Internal wrapper around the [`Protected`] requirement configured for a route.
#[derive(Debug, Clone)]
struct Required(Protected);

impl Required {
    /// Returns `true` when the route may be accessed anonymously.
    pub fn can_anonymous_access(&self) -> bool {
        self.0 == Protected::AnonymousUser
    }

    /// Returns `true` when the comma-separated `user_roles` contains at least one
    /// role allowed by this requirement.
    ///
    /// An empty allowed-role list (always the case for `AnonymousUser`) accepts
    /// every user. Whitespace around each entry of `user_roles` is ignored.
    /// `user_name` is used only for the debug log.
    pub fn has_any_role(&self, user_name: &str, user_roles: &str) -> bool {
        // Borrow the configured roles instead of cloning the Vec on every request.
        let allowed_roles: &[String] = match &self.0 {
            Protected::AnonymousUser => &[],
            Protected::LoginUser(roles) | Protected::InnerService(roles) => roles.as_slice(),
        };

        if allowed_roles.is_empty() {
            return true;
        }

        // Compare &str slices directly; no String allocation per user role.
        let b = user_roles
            .split(',')
            .map(str::trim)
            .any(|role| allowed_roles.iter().any(|allowed| allowed.as_str() == role));

        log::debug!(
            "has_any_role: {}, username={}, allowed_roles={:?}, user_roles={:?}",
            b,
            user_name,
            allowed_roles,
            user_roles
        );

        b
    }
}

/// Per-service middleware created by [`Authorized`]: checks each request against
/// the route's requirement before forwarding it to the wrapped `service`.
pub struct AuthorizedMiddleware<S> {
    /// Application state, used to load user details.
    app_state: web::Data<AppContext>,
    /// Access requirement shared with the factory that created this middleware.
    required: std::rc::Rc<Required>,
    /// The wrapped service; reference-counted so `call` can move a handle into its future.
    service: std::rc::Rc<S>,
}

impl<S, B> Transform<S, ServiceRequest> for Authorized
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = ActixError> + 'static,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<EitherBody<B>>;
    type Error = S::Error;
    type Transform = AuthorizedMiddleware<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ok(AuthorizedMiddleware {
            app_state: self.1.clone(),
            service: std::rc::Rc::new(service),
            required: self.0.clone(),
        })
    }
}

#[allow(clippy::type_complexity)]
impl<S, B> Service<ServiceRequest> for AuthorizedMiddleware<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = ActixError> + 'static,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<EitherBody<B>>;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(ctx)
    }

    fn call(&self, mut req: ServiceRequest) -> Self::Future {
        let svc = self.service.clone();
        let required = self.required.clone();
        let app_state = self.app_state.clone();

        Box::pin(async move {
            let claims = JwtToken::get_claims(&req);

            let user_id = if let Some(claims) = claims.0.as_ref() {
                if !claims.is_expired() {
                    Some(UserId(claims.sub.to_owned()))
                } else {
                    if !required.can_anonymous_access() {
                        // access denied!
                        let response = reject::<B>(&req, 401, "jwt token is expired!");
                        let (request, _) = req.into_parts();

                        return Ok(ServiceResponse::new(request, response));
                    }

                    insert_user(&req, None, Some(claims));

                    return svc.call(req).await.map(ServiceResponse::map_into_left_body);
                }
            } else {
                if claims.1 == 400 {
                    // access denied!
                    let response = reject::<B>(&req, 401, "jwt token bad!");
                    let (request, _) = req.into_parts();

                    return Ok(ServiceResponse::new(request, response));
                }

                if claims.1 == 401 && !required.can_anonymous_access() {
                    // access denied!
                    let response = reject::<B>(&req, 401, "jwt token expired!");
                    let (request, _) = req.into_parts();

                    return Ok(ServiceResponse::new(request, response));
                }

                if claims.2 && required.can_anonymous_access() {
                    insert_user(&req, None, claims.0.as_ref());

                    return svc.call(req).await.map(ServiceResponse::map_into_left_body);
                }

                // 没有 jwt token, 通过 session 读取 user_id
                session_user_id(&mut req).await
            };

            let (can_access, code) = if let Some(user_id) = user_id {
                // 验证权限, 读取用户信息放入上下文
                match get_user_details(app_state.mysql(), &user_id.0).await {
                    Ok(Some(user_details)) => {
                        if required.can_anonymous_access() {
                            insert_user(&req, Some(user_details), claims.0.as_ref());

                            (true, 200)
                        } else {
                            // 验证用户权限
                            if required
                                .has_any_role(&user_details.user_name, &user_details.user_role)
                            {
                                insert_user(&req, Some(user_details), claims.0.as_ref());

                                (true, 200)
                            } else {
                                (false, 403)
                            }
                        }
                    }
                    Ok(None) | Err(_) => {
                        if required.can_anonymous_access() {
                            insert_user(&req, None, claims.0.as_ref());

                            (true, 200)
                        } else {
                            (false, 401)
                        }
                    }
                }
            } else if required.can_anonymous_access() {
                insert_user(&req, None, claims.0.as_ref());

                (true, 200)
            } else {
                (false, 401)
            };

            if can_access {
                svc.call(req).await.map(ServiceResponse::map_into_left_body)
            } else {
                // access denied!
                let response = reject::<B>(&req, code, "Unauthorized");
                let (request, _) = req.into_parts();

                Ok(ServiceResponse::new(request, response))
            }
        })
    }
}

/// Reads the user id stored in the request's session, if any.
///
/// Failures while extracting the session or reading the id are logged and
/// reported as `None`, so callers treat them like a missing session.
pub async fn session_user_id(req: &mut ServiceRequest) -> Option<UserId> {
    let (http_request, payload) = req.parts_mut();

    let session = match TypedSession::from_request(http_request, payload).await {
        Ok(session) => session,
        Err(e) => {
            log::error!("read_session_error: error={:?}", e);
            return None;
        }
    };

    session.get_user_id().unwrap_or_else(|e| {
        log::error!("read_session_failed: error={:?}", e);
        None
    })
}

/// Builds the response for a request that failed authentication or authorization.
///
/// JSON clients and WebSocket upgrades get an `R::failed(code, msg)` JSON body
/// with HTTP status `code`; a code that is not a valid HTTP status falls back
/// to 401. Other clients are redirected to `/login` when `code` is 401, and to
/// `/login#access_denied` for any other code. A JSON client is one whose
/// `Accept` header contains "json".
fn reject<B>(req: &ServiceRequest, code: u16, msg: &str) -> HttpResponse<EitherBody<B>> {
    let accept = get_header_value(req, "Accept");
    let upgrade = get_header_value(req, "upgrade");
    // The "websocket" upgrade token is ASCII case-insensitive (RFC 6455 §4.2.1).
    let wants_json = accept.contains("json") || upgrade.eq_ignore_ascii_case("websocket");

    if wants_json {
        // Send the real rejection code (e.g. 403) as the HTTP status, so it
        // matches the code in the JSON body instead of always being 401.
        let status = http::StatusCode::from_u16(code).unwrap_or(http::StatusCode::UNAUTHORIZED);

        HttpResponse::build(status)
            .json(R::failed(code, msg))
            .map_into_right_body()
    } else {
        let location = if code == 401 {
            "/login"
        } else {
            "/login#access_denied"
        };

        HttpResponse::Found()
            .insert_header((http::header::LOCATION, location))
            .finish()
            .map_into_right_body()
    }
}

/// Stores request-scoped identity data in the request extensions.
///
/// The `User-Agent` and `Accept-Language` header values are always stored.
/// When `user_details` is present, its user id is stored too, and the requestor
/// becomes `Requestor::JwtUser` if `claims` is present, otherwise
/// `Requestor::LoginUser`. Without `user_details`, the requestor is
/// `Requestor::Anonymous` with an empty string.
fn insert_user(req: &ServiceRequest, user_details: Option<UserDetails>, claims: Option<&Claims>) {
    let agent = Agent(get_header_value(req, "User-Agent"));
    let language = Language(get_header_value(req, "Accept-Language"));

    // One mutable borrow of the extensions map for all inserts.
    let mut extensions = req.extensions_mut();
    extensions.insert(agent);
    extensions.insert(language);

    let requestor = match user_details {
        None => Requestor::Anonymous("".to_string()),
        Some(details) => {
            extensions.insert(UserId(details.user_id.to_owned()));

            match claims {
                Some(_) => Requestor::JwtUser(details),
                None => Requestor::LoginUser(details),
            }
        }
    };

    extensions.insert(requestor);
}

/// Returns the value of header `key` as an owned `String`.
///
/// An empty string is returned when the header is missing or its value cannot
/// be viewed as a `&str` (`HeaderValue::to_str` fails).
fn get_header_value(req: &ServiceRequest, key: &str) -> String {
    req.headers()
        .get(key)
        .and_then(|value| value.to_str().ok())
        .unwrap_or_default()
        .to_owned()
}