use std::future::{ready, Ready};
use actix_service::{Service, Transform};
use actix_web::{
body::MessageBody,
dev::{forward_ready, ServiceRequest, ServiceResponse},
Error, HttpMessage, HttpRequest, HttpResponse,
};
/// Strategy for pulling an authenticated principal out of an incoming request.
///
/// The configured extractor is cloned into every middleware instance that
/// [`Authentication`] builds. When [`AuthExtractor::extract`] returns `None`,
/// the middleware answers the request with `401 Unauthorized`.
pub trait AuthExtractor
where
    Self: Clone + Send + Sync + 'static,
{
    /// Identity produced on success; the middleware inserts it into the
    /// request's extensions before forwarding the request.
    type Principal: Clone + Send + Sync + 'static;

    /// Returns `Some(principal)` if `req` is authenticated, `None` otherwise.
    fn extract(&self, req: &HttpRequest) -> Option<Self::Principal>;
}
pub struct Authentication<E: AuthExtractor> {
extractor: E,
}
impl<E: AuthExtractor> Authentication<E> {
pub fn new(extractor: E) -> Self {
Self { extractor }
}
}
impl<E: AuthExtractor> Clone for Authentication<E> {
fn clone(&self) -> Self {
Self {
extractor: self.extractor.clone(),
}
}
}
/// Wraps the next service in the chain in an [`AuthMiddleware`] that owns its
/// own clone of the configured extractor.
impl<S, E, B> Transform<S, ServiceRequest> for Authentication<E>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    E: AuthExtractor,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse;
    type Error = Error;
    type InitError = ();
    type Transform = AuthMiddleware<S, E>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        // Building the middleware cannot fail, so hand back an already-resolved future.
        let extractor = self.extractor.clone();
        let middleware = AuthMiddleware { service, extractor };
        ready(Ok(middleware))
    }
}
/// Service produced by [`Authentication`]'s `new_transform`.
///
/// It holds the wrapped inner service `service` and its own copy of the
/// extractor. The `E: AuthExtractor` requirement is expressed on the `Service`
/// impl, not on the struct definition.
pub struct AuthMiddleware<S, E> {
    service: S,
    extractor: E,
}
/// Authenticates each request before it reaches the wrapped service.
///
/// - Extractor returns a principal: the principal is inserted into the
///   request's extensions, the request is forwarded, and the inner response is
///   returned with its body boxed.
/// - Extractor returns `None`: the inner service is not called. A
///   `401 Unauthorized` response with a `www-authenticate: Bearer` header and
///   body `"Unauthorized"` is returned immediately.
impl<S, E, B> Service<ServiceRequest> for AuthMiddleware<S, E>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    E: AuthExtractor,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse;
    type Error = Error;
    type Future =
        std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>>>>;

    // Readiness is delegated unchanged to the inner service.
    forward_ready!(service);

    fn call(&self, req: ServiceRequest) -> Self::Future {
        let maybe_principal = self.extractor.extract(req.request());
        match maybe_principal {
            None => {
                // The payload is discarded: this request is answered here.
                let (http_req, _) = req.into_parts();
                let rejection = HttpResponse::Unauthorized()
                    .insert_header(("www-authenticate", "Bearer"))
                    .body("Unauthorized");
                Box::pin(ready(Ok(ServiceResponse::new(http_req, rejection))))
            }
            Some(principal) => {
                // Stash the principal on the request before handing it on.
                req.extensions_mut().insert(principal);
                let inner = self.service.call(req);
                Box::pin(async move { inner.await.map(|res| res.map_into_boxed_body()) })
            }
        }
    }
}
pub struct Authorization<F>
where
F: Fn(&HttpRequest) -> bool + Clone + Send + Sync + 'static,
{
check: F,
}
impl<F> Authorization<F>
where
F: Fn(&HttpRequest) -> bool + Clone + Send + Sync + 'static,
{
pub fn new(check: F) -> Self {
Self { check }
}
}
impl<F> Clone for Authorization<F>
where
F: Fn(&HttpRequest) -> bool + Clone + Send + Sync + 'static,
{
fn clone(&self) -> Self {
Self {
check: self.check.clone(),
}
}
}
/// Wraps the next service in the chain in an [`AuthzMiddleware`] that owns its
/// own clone of the predicate.
impl<S, F, B> Transform<S, ServiceRequest> for Authorization<F>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    F: Fn(&HttpRequest) -> bool + Clone + Send + Sync + 'static,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse;
    type Error = Error;
    type InitError = ();
    type Transform = AuthzMiddleware<S, F>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        // Building the middleware cannot fail, so hand back an already-resolved future.
        let check = self.check.clone();
        let middleware = AuthzMiddleware { service, check };
        ready(Ok(middleware))
    }
}
/// Service produced by [`Authorization`]'s `new_transform`.
///
/// It holds the wrapped inner service `service` and its own copy of the
/// predicate `check`. The `Fn(&HttpRequest) -> bool` requirement is expressed
/// on the `Service` impl, not on the struct definition.
pub struct AuthzMiddleware<S, F> {
    service: S,
    check: F,
}
/// Applies the authorization predicate to each request.
///
/// If `check` returns `false`, the inner service is not called and a
/// `403 Forbidden` response with body `"Forbidden"` is returned immediately.
/// Otherwise the request is forwarded and the inner response is returned with
/// its body boxed.
impl<S, F, B> Service<ServiceRequest> for AuthzMiddleware<S, F>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    F: Fn(&HttpRequest) -> bool + Clone + Send + Sync + 'static,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse;
    type Error = Error;
    type Future =
        std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>>>>;

    // Readiness is delegated unchanged to the inner service.
    forward_ready!(service);

    fn call(&self, req: ServiceRequest) -> Self::Future {
        let allowed = (self.check)(req.request());
        if !allowed {
            // The payload is discarded: this request is answered here.
            let (http_req, _) = req.into_parts();
            let forbidden = HttpResponse::Forbidden().body("Forbidden");
            return Box::pin(ready(Ok(ServiceResponse::new(http_req, forbidden))));
        }
        let inner = self.service.call(req);
        Box::pin(async move { inner.await.map(|res| res.map_into_boxed_body()) })
    }
}