use actix_web::{
body::EitherBody,
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
Error, HttpMessage, HttpResponse,
};
use futures_util::future::LocalBoxFuture;
use std::future::{ready, Ready};
use std::rc::Rc;
use crate::security::AuthenticatedUser;
use crate::{database::queries::is_admin_query, pool::AppState};
use crate::{
database::query_views::IsAdminQueryView,
jwt_manager::{check_jwt_validity, get_jwt_from_request, get_user_id_from_jwt, JWTCheckError},
};
use lazy_static::lazy_static;
use regex::Regex;
/// Middleware factory guarding admin routes.
///
/// Its `Transform` implementation wraps the downstream service in an
/// [`AdminMiddlewareService`], which performs the actual per-request checks.
pub struct AdminMiddleware;
/// Factory side of the admin middleware: turns a downstream service `S` into
/// an [`AdminMiddlewareService`] that owns it behind an `Rc`.
impl<S, B> Transform<S, ServiceRequest> for AdminMiddleware
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    B: 'static,
{
    // Responses carry either the downstream body (left) or a body produced by
    // the middleware itself (right).
    type Response = ServiceResponse<EitherBody<B>>;
    type Error = Error;
    type InitError = ();
    type Transform = AdminMiddlewareService<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `service` and returns the wrapper in an already-completed future;
    /// construction cannot fail.
    fn new_transform(&self, service: S) -> Self::Future {
        let wrapped = AdminMiddlewareService {
            service: Rc::new(service),
        };
        ready(Ok(wrapped))
    }
}
/// Service wrapper produced by `AdminMiddleware`.
///
/// Holds the downstream service so that its `Service::call` implementation can
/// decide, per request, whether to forward to it.
pub struct AdminMiddlewareService<S> {
// Wrapped in `Rc` so a handle can be cloned into the `'static` future
// created for each request.
service: Rc<S>,
}
impl<S, B> Service<ServiceRequest> for AdminMiddlewareService<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
let svc = self.service.clone();
let app_state = req.app_data::<actix_web::web::Data<AppState>>();
let pool = match app_state {
Some(state) => state.db_pool.clone(),
None => None,
};
let path = req.path();
lazy_static! {
static ref ADMIN_PATH_REGEX: Regex = Regex::new(r"/api/v\d+/admin").unwrap();
}
if !ADMIN_PATH_REGEX.is_match(path) {
return Box::pin(async move {
let res = svc.call(req).await?;
Ok(res.map_into_left_body())
});
}
Box::pin(async move {
let pool = match pool {
Some(p) => p,
None => {
let res = HttpResponse::InternalServerError()
.body("DB Pool missing")
.map_into_right_body();
return Ok(req.into_response(res));
}
};
let jwt_option = get_jwt_from_request(req.request());
let jwt = match jwt_option {
Some(token) => token,
None => {
let response = HttpResponse::Unauthorized()
.body("Unauthorized: No JWT token provided.")
.map_into_right_body();
return Ok(req.into_response(response));
}
};
match check_jwt_validity(&jwt, pool.clone()).await {
Ok(_) => {
let view: IsAdminQueryView = IsAdminQueryView::new(
get_user_id_from_jwt(&jwt).unwrap().parse().unwrap_or(0),
);
if is_admin_query(view, pool).await.unwrap() {
req.extensions_mut().insert(AuthenticatedUser {
id: get_user_id_from_jwt(&jwt).unwrap().parse().unwrap_or(0),
});
let res = svc.call(req).await?;
Ok(res.map_into_left_body())
} else {
let response = HttpResponse::Forbidden()
.body("Forbidden: User is not an admin.")
.map_into_right_body();
Ok(req.into_response(response))
}
}
Err(error) => {
let response =
match error {
JWTCheckError::DatabaseError => HttpResponse::InternalServerError()
.body("Internal server error: Database not initialized."),
JWTCheckError::NoTokenProvided => HttpResponse::Unauthorized()
.body("Unauthorized: No JWT token provided."),
JWTCheckError::ExpiredToken => HttpResponse::Unauthorized()
.body("Unauthorized: JWT token is expired."),
JWTCheckError::InvalidToken => HttpResponse::Unauthorized()
.body("Unauthorized: Invalid JWT token."),
JWTCheckError::UnknownUser => {
HttpResponse::NotFound().body("User not found.")
}
};
Ok(req.into_response(response.map_into_right_body()))
}
}
})
}
}