use axum::{
body::Body,
extract::{FromRequestParts, OptionalFromRequestParts, Request},
response::IntoResponse,
};
use futures_core::future::BoxFuture;
use http::{Response, StatusCode, request::Parts};
use std::convert::Infallible;
use std::error::Error;
use std::task::{Context, Poll};
use tower::{Layer, Service};
use tracing::error;
use crate::validate_incoming::{MAuthValidationError, ValidatedRequestDetails};
use crate::{
MAuthInfo,
config::{ConfigFileSection, ConfigReadError},
};
/// Tower [`Service`] that requires a valid MAuth signature on every request.
///
/// Each request is passed through [`MAuthInfo::validate_request`]. On success the
/// resulting request is forwarded to the wrapped service and its response converted
/// into a `Response<Body>`. On failure the error is logged and the request is answered
/// with `401 Unauthorized` without ever reaching the wrapped service.
pub struct RequiredMAuthValidationService<S> {
    /// Validator built from `config_info`.
    mauth_info: MAuthInfo,
    /// Kept so that every clone can rebuild its own `MAuthInfo`.
    config_info: ConfigFileSection,
    /// The wrapped (inner) service.
    service: S,
}

impl<S> Service<Request> for RequiredMAuthValidationService<S>
where
    S: Service<Request> + Send + Clone + 'static,
    S::Future: Send + 'static,
    S::Error: Into<Box<dyn Error + Sync + Send>>,
    S::Response: Into<Response<Body>>,
{
    type Response = Response<Body>;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    fn call(&mut self, request: Request) -> Self::Future {
        // `poll_ready` only readied `self.service`; a freshly cloned inner service is not
        // guaranteed to be ready. Move the readied instance into the future and leave the
        // fresh clone in `self`, where the next `poll_ready` will drive it to readiness.
        let fresh = self.clone();
        let mut this = std::mem::replace(self, fresh);
        Box::pin(async move {
            match this.mauth_info.validate_request(request).await {
                Ok(valid_request) => this.service.call(valid_request).await.map(Into::into),
                Err(err) => {
                    error!(
                        error = ?err,
                        "Failed to validate MAuth signature, rejecting request"
                    );
                    Ok(StatusCode::UNAUTHORIZED.into_response())
                }
            }
        })
    }
}

impl<S: Clone> Clone for RequiredMAuthValidationService<S> {
    /// Clones the wrapped service and rebuilds `MAuthInfo` from the stored config section.
    ///
    /// NOTE(review): `MAuthInfo` is rebuilt rather than cloned, presumably because it does
    /// not implement `Clone`. `call` clones once per request, so this runs per request;
    /// confirm the cost of `MAuthInfo::from_config_section`.
    ///
    /// # Panics
    /// Panics if `MAuthInfo` can no longer be built from the config section. The layer
    /// constructors check that it can be built, so this indicates the configuration's
    /// inputs changed after the layer was created.
    fn clone(&self) -> Self {
        Self {
            mauth_info: MAuthInfo::from_config_section(&self.config_info).expect(
                "failed to rebuild MAuthInfo from a config section that was valid when the layer was built",
            ),
            config_info: self.config_info.clone(),
            service: self.service.clone(),
        }
    }
}
/// Tower [`Layer`] that wraps services in [`RequiredMAuthValidationService`].
///
/// The constructors check that `MAuthInfo` can be built from the config section, so a
/// bad configuration surfaces as a [`ConfigReadError`] rather than a panic in
/// [`Layer::layer`].
#[derive(Clone)]
pub struct RequiredMAuthValidationLayer {
    /// Config section every produced service builds its `MAuthInfo` from.
    config_info: ConfigFileSection,
}

impl<S> Layer<S> for RequiredMAuthValidationLayer {
    type Service = RequiredMAuthValidationService<S>;

    /// Wraps `service` so that only requests with a valid MAuth signature reach it.
    ///
    /// # Panics
    /// Panics if `MAuthInfo` can no longer be built from the config section that was
    /// accepted when this layer was constructed.
    fn layer(&self, service: S) -> Self::Service {
        RequiredMAuthValidationService {
            mauth_info: MAuthInfo::from_config_section(&self.config_info).expect(
                "failed to rebuild MAuthInfo from a config section that was valid when the layer was built",
            ),
            config_info: self.config_info.clone(),
            service,
        }
    }
}

impl RequiredMAuthValidationLayer {
    /// Builds the layer from the section returned by
    /// `MAuthInfo::config_section_from_default_file`.
    ///
    /// # Errors
    /// Returns [`ConfigReadError`] if that section cannot be read, or if `MAuthInfo`
    /// cannot be built from it.
    pub fn from_default_file() -> Result<Self, ConfigReadError> {
        Self::from_config_section(MAuthInfo::config_section_from_default_file()?)
    }

    /// Builds the layer from an already-loaded config section.
    ///
    /// # Errors
    /// Returns [`ConfigReadError`] if `MAuthInfo` cannot be built from `config_info`.
    pub fn from_config_section(config_info: ConfigFileSection) -> Result<Self, ConfigReadError> {
        // Built only to surface configuration errors here instead of as a panic later in
        // `layer`/`clone`; the instance itself is discarded.
        MAuthInfo::from_config_section(&config_info)?;
        Ok(Self { config_info })
    }
}
/// Tower [`Service`] that runs MAuth validation but always forwards the request.
///
/// Each request is passed through `MAuthInfo::validate_request_optionally` and the
/// returned request is handed to the wrapped service, whose response and error are
/// returned unchanged.
///
/// NOTE(review): the validation outcome is presumably recorded in the request
/// extensions (the `ValidatedRequestDetails` / `MAuthValidationError` extractors in this
/// module read it from there); confirm in `validate_incoming`.
pub struct OptionalMAuthValidationService<S> {
    /// Validator built from `config_info`.
    mauth_info: MAuthInfo,
    /// Kept so that every clone can rebuild its own `MAuthInfo`.
    config_info: ConfigFileSection,
    /// The wrapped (inner) service.
    service: S,
}

impl<S> Service<Request> for OptionalMAuthValidationService<S>
where
    S: Service<Request> + Send + Clone + 'static,
    S::Future: Send + 'static,
    S::Error: Into<Box<dyn Error + Sync + Send>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    fn call(&mut self, request: Request) -> Self::Future {
        // `poll_ready` only readied `self.service`; a freshly cloned inner service is not
        // guaranteed to be ready. Move the readied instance into the future and leave the
        // fresh clone in `self`, where the next `poll_ready` will drive it to readiness.
        let fresh = self.clone();
        let mut this = std::mem::replace(self, fresh);
        Box::pin(async move {
            let processed_request = this.mauth_info.validate_request_optionally(request).await;
            this.service.call(processed_request).await
        })
    }
}

impl<S: Clone> Clone for OptionalMAuthValidationService<S> {
    /// Clones the wrapped service and rebuilds `MAuthInfo` from the stored config section.
    ///
    /// NOTE(review): `MAuthInfo` is rebuilt rather than cloned, presumably because it does
    /// not implement `Clone`. `call` clones once per request, so this runs per request;
    /// confirm the cost of `MAuthInfo::from_config_section`.
    ///
    /// # Panics
    /// Panics if `MAuthInfo` can no longer be built from the config section. The layer
    /// constructors check that it can be built, so this indicates the configuration's
    /// inputs changed after the layer was created.
    fn clone(&self) -> Self {
        Self {
            mauth_info: MAuthInfo::from_config_section(&self.config_info).expect(
                "failed to rebuild MAuthInfo from a config section that was valid when the layer was built",
            ),
            config_info: self.config_info.clone(),
            service: self.service.clone(),
        }
    }
}
/// Tower [`Layer`] that wraps services in [`OptionalMAuthValidationService`].
///
/// The constructors check that `MAuthInfo` can be built from the config section, so a
/// bad configuration surfaces as a [`ConfigReadError`] rather than a panic in
/// [`Layer::layer`].
#[derive(Clone)]
pub struct OptionalMAuthValidationLayer {
    /// Config section every produced service builds its `MAuthInfo` from.
    config_info: ConfigFileSection,
}

impl<S> Layer<S> for OptionalMAuthValidationLayer {
    type Service = OptionalMAuthValidationService<S>;

    /// Wraps `service` so that every request is MAuth-validated before reaching it,
    /// without rejecting requests that fail validation.
    ///
    /// # Panics
    /// Panics if `MAuthInfo` can no longer be built from the config section that was
    /// accepted when this layer was constructed.
    fn layer(&self, service: S) -> Self::Service {
        OptionalMAuthValidationService {
            mauth_info: MAuthInfo::from_config_section(&self.config_info).expect(
                "failed to rebuild MAuthInfo from a config section that was valid when the layer was built",
            ),
            config_info: self.config_info.clone(),
            service,
        }
    }
}

impl OptionalMAuthValidationLayer {
    /// Builds the layer from the section returned by
    /// `MAuthInfo::config_section_from_default_file`.
    ///
    /// # Errors
    /// Returns [`ConfigReadError`] if that section cannot be read, or if `MAuthInfo`
    /// cannot be built from it.
    pub fn from_default_file() -> Result<Self, ConfigReadError> {
        Self::from_config_section(MAuthInfo::config_section_from_default_file()?)
    }

    /// Builds the layer from an already-loaded config section.
    ///
    /// # Errors
    /// Returns [`ConfigReadError`] if `MAuthInfo` cannot be built from `config_info`.
    pub fn from_config_section(config_info: ConfigFileSection) -> Result<Self, ConfigReadError> {
        // Built only to surface configuration errors here instead of as a panic later in
        // `layer`/`clone`; the instance itself is discarded.
        MAuthInfo::from_config_section(&config_info)?;
        Ok(Self { config_info })
    }
}
/// Extractor returning the [`ValidatedRequestDetails`] stored in the request extensions.
///
/// Rejects with `401 Unauthorized` when the extensions hold no such value.
impl<S> FromRequestParts<S> for ValidatedRequestDetails
where
    S: Send + Sync,
{
    type Rejection = StatusCode;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let Some(details) = parts.extensions.get::<ValidatedRequestDetails>() else {
            return Err(StatusCode::UNAUTHORIZED);
        };
        Ok(details.clone())
    }
}
/// Optional extractor (`Option<ValidatedRequestDetails>`) that never rejects.
///
/// Yields a clone of the [`ValidatedRequestDetails`] in the request extensions, or
/// `None` when absent.
impl<S> OptionalFromRequestParts<S> for ValidatedRequestDetails
where
    S: Send + Sync,
{
    type Rejection = Infallible;

    async fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> Result<Option<Self>, Self::Rejection> {
        let stored: Option<&ValidatedRequestDetails> = parts.extensions.get();
        Ok(stored.cloned())
    }
}
/// Optional extractor (`Option<MAuthValidationError>`) that never rejects.
///
/// Yields a clone of the [`MAuthValidationError`] in the request extensions, or `None`
/// when absent.
impl<S> OptionalFromRequestParts<S> for MAuthValidationError
where
    S: Send + Sync,
{
    type Rejection = Infallible;

    async fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> Result<Option<Self>, Self::Rejection> {
        let stored: Option<&MAuthValidationError> = parts.extensions.get();
        Ok(stored.cloned())
    }
}