use actix_web::{
body::EitherBody,
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
Error, HttpMessage, HttpResponse,
};
use futures_util::future::LocalBoxFuture;
use std::future::{ready, Ready};
use std::rc::Rc;
use crate::jwt_manager::{
check_jwt_validity, get_jwt_from_request, get_user_id_from_jwt, JWTCheckError,
};
use crate::pool::AppState;
use crate::security::AuthenticatedUser;
/// Middleware factory that puts JWT authentication in front of the wrapped service.
///
/// It is stateless; each `Transform::new_transform` call produces a `JwtMiddlewareService`.
#[derive(Debug, Clone, Copy, Default)]
pub struct JwtMiddleware;
/// Wires `JwtMiddleware` into the service chain by wrapping the next service.
impl<S, B> Transform<S, ServiceRequest> for JwtMiddleware
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<EitherBody<B>>;
    type Error = Error;
    type Transform = JwtMiddlewareService<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `service` in a `JwtMiddlewareService`.
    ///
    /// Construction cannot fail, so an already-resolved future is returned.
    fn new_transform(&self, service: S) -> Self::Future {
        let shared = Rc::new(service);
        let wrapped = JwtMiddlewareService { service: shared };
        ready(Ok(wrapped))
    }
}
/// Service produced by `JwtMiddleware`; authenticates requests before
/// delegating them to the wrapped service.
pub struct JwtMiddlewareService<S> {
    /// The next service in the chain. It is held in an `Rc` so `call` can clone
    /// a handle into the `'static` boxed future it returns.
    service: Rc<S>,
}
impl<S, B> Service<ServiceRequest> for JwtMiddlewareService<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
let svc = self.service.clone();
let app_state = req.app_data::<actix_web::web::Data<AppState>>();
let pool = match app_state {
Some(state) => state.db_pool.clone(),
None => None,
};
let path = req.path();
if path == "/"
|| path.starts_with("/swagger-ui")
|| path.starts_with("/api-docs")
|| path.contains("/auth")
{
return Box::pin(async move {
let res = svc.call(req).await?;
Ok(res.map_into_left_body())
});
}
Box::pin(async move {
let pool = match pool {
Some(p) => p,
None => {
let res = HttpResponse::InternalServerError()
.body("DB Pool missing")
.map_into_right_body();
return Ok(req.into_response(res));
}
};
let jwt_option = get_jwt_from_request(req.request());
let jwt = match jwt_option {
Some(token) => token,
None => {
eprint!("no jwt");
let response = HttpResponse::Unauthorized()
.body("Unauthorized: No JWT token provided.")
.map_into_right_body();
return Ok(req.into_response(response));
}
};
match check_jwt_validity(&jwt, pool).await {
Ok(_) => {
req.extensions_mut().insert(AuthenticatedUser {
id: get_user_id_from_jwt(&jwt).unwrap().parse().unwrap_or(0),
});
let res = svc.call(req).await?;
Ok(res.map_into_left_body())
}
Err(error) => {
eprint!("error no jwt");
let response =
match error {
JWTCheckError::DatabaseError => HttpResponse::InternalServerError()
.body("Internal server error: Database not initialized."),
JWTCheckError::NoTokenProvided => HttpResponse::Unauthorized()
.body("Unauthorized: No JWT token provided."),
JWTCheckError::ExpiredToken => HttpResponse::Unauthorized()
.body("Unauthorized: JWT token is expired."),
JWTCheckError::InvalidToken => HttpResponse::Unauthorized()
.body("Unauthorized: Invalid JWT token."),
JWTCheckError::UnknownUser => {
HttpResponse::NotFound().body("User not found.")
}
};
Ok(req.into_response(response.map_into_right_body()))
}
}
})
}
}