use std::{
pin::Pin,
task::{Context, Poll},
};
use futures::{
future::{ok, Ready},
Future,
};
use actix_web::{
body::EitherBody,
dev::{Service, ServiceRequest, ServiceResponse, Transform},
http, web, Error as ActixError, FromRequest, HttpMessage, HttpResponse, Result,
};
use crate::core::auth0::jwt_token::Claims;
use crate::core::auth0::jwt_token::JwtToken;
use crate::core::auth0::user_session::TypedSession;
use crate::core::auth0::Agent;
use crate::core::auth0::Language;
use crate::core::auth0::Requestor;
use crate::core::auth0::UserDetails;
use crate::core::auth0::UserId;
use crate::core::result::R;
use crate::server::AppContext;
use crate::services::user_service::get_user_details;
/// Role names attached to a protection rule.
type Permission = Vec<String>;

/// Access rule a route (or scope) is wrapped with via [`Authorized`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Protected {
    /// The route may also be served to requests without an identified user.
    AnonymousUser,
    /// An identified user is required; if the list is non-empty the user must
    /// hold at least one of these roles.
    LoginUser(Permission),
    /// Role-restricted access; this middleware checks it exactly like `LoginUser`.
    InnerService(Permission),
}

impl Protected {
    /// Creates a [`Protected::LoginUser`] rule from borrowed role names.
    pub fn build(roles: Vec<&str>) -> Self {
        let owned: Permission = roles.into_iter().map(String::from).collect();
        Self::LoginUser(owned)
    }
}
/// Middleware factory that enforces a [`Protected`] rule on the wrapped service.
///
/// Holds the rule (shared via `Rc` with every middleware instance it creates)
/// and the application state used to look up users.
pub struct Authorized(std::rc::Rc<Required>, web::Data<AppContext>);

impl Authorized {
    /// Creates the factory from the application state and the rule to enforce.
    pub fn builder(app_state: web::Data<AppContext>, required: Protected) -> Authorized {
        let rule = std::rc::Rc::new(Required(required));
        Self(rule, app_state)
    }
}
/// Internal wrapper around the configured [`Protected`] rule.
#[derive(Debug, Clone)]
struct Required(Protected);

impl Required {
    /// Returns `true` when the rule is [`Protected::AnonymousUser`], i.e. the
    /// route may be served without an identified user.
    pub fn can_anonymous_access(&self) -> bool {
        self.0 == Protected::AnonymousUser
    }

    /// Returns `true` if the rule lists no roles, or if `user_roles` (a
    /// comma-separated list) contains at least one of the allowed roles.
    ///
    /// Each entry of `user_roles` is trimmed before comparison; allowed roles
    /// are compared as stored. `user_name` is only used for debug logging.
    pub fn has_any_role(&self, user_name: &str, user_roles: &str) -> bool {
        // Borrow the configured roles rather than cloning the Vec per request.
        let allowed_roles: &[String] = match &self.0 {
            Protected::AnonymousUser => &[],
            Protected::LoginUser(roles) | Protected::InnerService(roles) => roles.as_slice(),
        };
        if allowed_roles.is_empty() {
            return true;
        }
        // Compare as &str; no per-role String allocation is needed.
        let b = user_roles
            .split(',')
            .map(str::trim)
            .any(|role| allowed_roles.iter().any(|allowed| allowed == role));
        log::debug!(
            "has_any_role: {}, username={}, allowed_roles={:?}, user_roles={:?}",
            b,
            user_name,
            allowed_roles,
            user_roles
        );
        b
    }
}
pub struct AuthorizedMiddleware<S> {
app_state: web::Data<AppContext>,
required: std::rc::Rc<Required>,
service: std::rc::Rc<S>,
}
impl<S, B> Transform<S, ServiceRequest> for Authorized
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = ActixError> + 'static,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = S::Error;
type Transform = AuthorizedMiddleware<S>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ok(AuthorizedMiddleware {
app_state: self.1.clone(),
service: std::rc::Rc::new(service),
required: self.0.clone(),
})
}
}
#[allow(clippy::type_complexity)]
impl<S, B> Service<ServiceRequest> for AuthorizedMiddleware<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = ActixError> + 'static,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(ctx)
}
fn call(&self, mut req: ServiceRequest) -> Self::Future {
let svc = self.service.clone();
let required = self.required.clone();
let app_state = self.app_state.clone();
Box::pin(async move {
let claims = JwtToken::get_claims(&req);
let user_id = if let Some(claims) = claims.0.as_ref() {
if !claims.is_expired() {
Some(UserId(claims.sub.to_owned()))
} else {
if !required.can_anonymous_access() {
let response = reject::<B>(&req, 401, "jwt token is expired!");
let (request, _) = req.into_parts();
return Ok(ServiceResponse::new(request, response));
}
insert_user(&req, None, Some(claims));
return svc.call(req).await.map(ServiceResponse::map_into_left_body);
}
} else {
if claims.1 == 400 {
let response = reject::<B>(&req, 401, "jwt token bad!");
let (request, _) = req.into_parts();
return Ok(ServiceResponse::new(request, response));
}
if claims.1 == 401 && !required.can_anonymous_access() {
let response = reject::<B>(&req, 401, "jwt token expired!");
let (request, _) = req.into_parts();
return Ok(ServiceResponse::new(request, response));
}
if claims.2 && required.can_anonymous_access() {
insert_user(&req, None, claims.0.as_ref());
return svc.call(req).await.map(ServiceResponse::map_into_left_body);
}
session_user_id(&mut req).await
};
let (can_access, code) = if let Some(user_id) = user_id {
match get_user_details(app_state.mysql(), &user_id.0).await {
Ok(Some(user_details)) => {
if required.can_anonymous_access() {
insert_user(&req, Some(user_details), claims.0.as_ref());
(true, 200)
} else {
if required
.has_any_role(&user_details.user_name, &user_details.user_role)
{
insert_user(&req, Some(user_details), claims.0.as_ref());
(true, 200)
} else {
(false, 403)
}
}
}
Ok(None) | Err(_) => {
if required.can_anonymous_access() {
insert_user(&req, None, claims.0.as_ref());
(true, 200)
} else {
(false, 401)
}
}
}
} else if required.can_anonymous_access() {
insert_user(&req, None, claims.0.as_ref());
(true, 200)
} else {
(false, 401)
};
if can_access {
svc.call(req).await.map(ServiceResponse::map_into_left_body)
} else {
let response = reject::<B>(&req, code, "Unauthorized");
let (request, _) = req.into_parts();
Ok(ServiceResponse::new(request, response))
}
})
}
}
/// Reads the current user's id from the request session.
///
/// The session is extracted with `TypedSession::from_request`, which needs
/// mutable access to the request payload (hence `&mut`). Both an extraction
/// failure and a failure to read the id are logged and reported as `None`,
/// which is the same result as "no user in session".
pub async fn session_user_id(req: &mut ServiceRequest) -> Option<UserId> {
    let (http_request, payload) = req.parts_mut();
    let session = match TypedSession::from_request(http_request, payload).await {
        Ok(session) => session,
        Err(e) => {
            log::error!("read_session_error: error={:?}", e);
            return None;
        }
    };
    session.get_user_id().unwrap_or_else(|e| {
        log::error!("read_session_failed: error={:?}", e);
        None
    })
}
/// Builds the response sent when a request is not allowed through.
///
/// * Clients whose `Accept` header mentions `json`, and WebSocket upgrade
///   requests, get a JSON body `R::failed(code, msg)`. The HTTP status equals
///   `code`, or 401 if `code` is not a valid status code.
/// * All other clients are redirected (302 Found) to `/login` when `code` is
///   401, and to `/login#access_denied` for any other code.
fn reject<B>(req: &ServiceRequest, code: u16, msg: &str) -> HttpResponse<EitherBody<B>> {
    let accept = get_header_value(req, "Accept");
    let upgrade = get_header_value(req, "upgrade");
    if accept.contains("json") || upgrade == "websocket" {
        // Use the denial code as the HTTP status so API clients can tell a
        // 403 (missing role) from a 401 (not authenticated).
        let status = http::StatusCode::from_u16(code).unwrap_or(http::StatusCode::UNAUTHORIZED);
        HttpResponse::build(status)
            .json(R::failed(code, msg))
            .map_into_right_body()
    } else {
        let location = if code == 401 {
            "/login"
        } else {
            "/login#access_denied"
        };
        HttpResponse::Found()
            .insert_header((http::header::LOCATION, location))
            .finish()
            .map_into_right_body()
    }
}
/// Stores per-request identity data in the request extensions.
///
/// `Agent` and `Language` always receive the `User-Agent` / `Accept-Language`
/// header values (empty when missing or unreadable). When `user_details` is
/// given, its id is stored as [`UserId`] and the user becomes the [`Requestor`]:
/// `JwtUser` if `claims` is present, `LoginUser` otherwise. Without user details
/// the requestor is `Anonymous` with an empty string.
fn insert_user(req: &ServiceRequest, user_details: Option<UserDetails>, claims: Option<&Claims>) {
    let agent = Agent(get_header_value(req, "User-Agent"));
    let language = Language(get_header_value(req, "Accept-Language"));

    // A single borrow of the extensions map covers all inserts.
    let mut extensions = req.extensions_mut();
    extensions.insert(agent);
    extensions.insert(language);

    let requestor = match user_details {
        None => Requestor::Anonymous("".to_string()),
        Some(details) => {
            extensions.insert(UserId(details.user_id.to_owned()));
            if claims.is_some() {
                Requestor::JwtUser(details)
            } else {
                Requestor::LoginUser(details)
            }
        }
    };
    extensions.insert(requestor);
}
/// Returns the value of header `key` as an owned string.
///
/// Yields an empty string when the header is absent or when
/// `HeaderValue::to_str` rejects its bytes.
fn get_header_value(req: &ServiceRequest, key: &str) -> String {
    req.headers()
        .get(key)
        .and_then(|value| value.to_str().ok())
        .unwrap_or_default()
        .to_owned()
}