use crate::{CLIENT, MAuthInfo, PUBKEY_CACHE};
use axum::extract::Request;
use bytes::Bytes;
use chrono::prelude::*;
use mauth_core::verifier::Verifier;
use thiserror::Error;
use tracing::error;
use uuid::Uuid;
/// Information about a request whose MAuth signature validated successfully.
///
/// `MAuthInfo::validate_request` and `MAuthInfo::validate_request_optionally` insert this into
/// the request's extensions so later handlers can see which app signed the request.
///
/// Because of `#[non_exhaustive]`, code outside this crate cannot build it with a struct
/// literal or destructure it exhaustively.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ValidatedRequestDetails {
    /// UUID of the app whose key produced the verified signature.
    pub app_uuid: Uuid,
}
// MAuth protocol V2 headers.
/// Header carrying the V2 signature; its value is expected to start with the `MWSV2` prefix,
/// followed by the signing app's UUID and the encoded signature.
const MAUTH_V2_SIGNATURE_HEADER: &str = "MCC-Authentication";
/// Header carrying the V2 signing time, as an integer number of seconds since the Unix epoch.
const MAUTH_V2_TIMESTAMP_HEADER: &str = "MCC-Time";

// MAuth protocol V1 headers.
/// Header carrying the V1 signature; its value is expected to start with the `MWS` prefix,
/// followed by the signing app's UUID and the encoded signature.
const MAUTH_V1_SIGNATURE_HEADER: &str = "X-MWS-Authentication";
/// Header carrying the V1 signing time, as an integer number of seconds since the Unix epoch.
const MAUTH_V1_TIMESTAMP_HEADER: &str = "X-MWS-Time";
impl MAuthInfo {
pub(crate) async fn validate_request(
&self,
req: Request,
) -> Result<Request, MAuthValidationError> {
let (mut parts, body) = req.into_parts();
let body_bytes = axum::body::to_bytes(body, usize::MAX)
.await
.map_err(|_| MAuthValidationError::InvalidBody)?;
match self.validate_request_v2(&parts, &body_bytes).await {
Ok(host_app_uuid) => {
parts.extensions.insert(ValidatedRequestDetails {
app_uuid: host_app_uuid,
});
let new_body = axum::body::Body::from(body_bytes);
let new_request = Request::from_parts(parts, new_body);
Ok(new_request)
}
Err(err) => {
if self.allow_v1_auth {
match self.validate_request_v1(&parts, &body_bytes).await {
Ok(host_app_uuid) => {
parts.extensions.insert(ValidatedRequestDetails {
app_uuid: host_app_uuid,
});
let new_body = axum::body::Body::from(body_bytes);
let new_request = Request::from_parts(parts, new_body);
Ok(new_request)
}
Err(err) => Err(err),
}
} else {
Err(err)
}
}
}
}
pub(crate) async fn validate_request_optionally(&self, req: Request) -> Request {
let (mut parts, body) = req.into_parts();
if parts.headers.contains_key(MAUTH_V2_SIGNATURE_HEADER)
|| parts.headers.contains_key(MAUTH_V1_SIGNATURE_HEADER)
{
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes,
Err(error) => {
error!(
?error,
"Failed to retrieve request body, continuing with empty body"
);
Bytes::new()
}
};
match self.validate_request_v2(&parts, &body_bytes).await {
Ok(host_app_uuid) => {
parts.extensions.insert(ValidatedRequestDetails {
app_uuid: host_app_uuid,
});
}
Err(error_v2) => {
if self.allow_v1_auth {
match self.validate_request_v1(&parts, &body_bytes).await {
Ok(host_app_uuid) => {
parts.extensions.insert(ValidatedRequestDetails {
app_uuid: host_app_uuid,
});
}
Err(error_v1) => {
error!(
?error_v2,
?error_v1,
"Error attempting to validate MAuth signatures"
);
parts.extensions.insert(error_v1);
}
}
} else {
error!(?error_v2, "Error attempting to validate MAuth V2 signature");
parts.extensions.insert(error_v2);
}
}
}
let new_body = axum::body::Body::from(body_bytes);
Request::from_parts(parts, new_body)
} else {
Request::from_parts(parts, body)
}
}
async fn validate_request_v2(
&self,
req: &http::request::Parts,
body_bytes: &bytes::Bytes,
) -> Result<Uuid, MAuthValidationError> {
let sig_header = req
.headers
.get(MAUTH_V2_SIGNATURE_HEADER)
.ok_or(MAuthValidationError::NoSig)?
.to_str()
.map_err(|_| MAuthValidationError::InvalidSignature)?;
let (host_app_uuid, raw_signature) = Self::split_auth_string(sig_header, "MWSV2")?;
let ts_str = req
.headers
.get(MAUTH_V2_TIMESTAMP_HEADER)
.ok_or(MAuthValidationError::NoTime)?
.to_str()
.map_err(|_| MAuthValidationError::InvalidTime)?;
Self::validate_timestamp(ts_str)?;
match self.get_app_pub_key(&host_app_uuid).await {
None => Err(MAuthValidationError::KeyUnavailable),
Some(verifier) => {
if let Ok(signature) = String::from_utf8(raw_signature) {
match verifier.verify_signature(
2,
req.method.as_str(),
req.uri.path(),
req.uri.query().unwrap_or(""),
body_bytes,
ts_str,
signature,
) {
Ok(()) => Ok(host_app_uuid),
Err(_) => Err(MAuthValidationError::SignatureVerifyFailure),
}
} else {
Err(MAuthValidationError::SignatureVerifyFailure)
}
}
}
}
async fn validate_request_v1(
&self,
req: &http::request::Parts,
body_bytes: &bytes::Bytes,
) -> Result<Uuid, MAuthValidationError> {
let sig_header = req
.headers
.get(MAUTH_V1_SIGNATURE_HEADER)
.ok_or(MAuthValidationError::NoSig)?
.to_str()
.map_err(|_| MAuthValidationError::InvalidSignature)?;
let (host_app_uuid, raw_signature) = Self::split_auth_string(sig_header, "MWS")?;
let ts_str = req
.headers
.get(MAUTH_V1_TIMESTAMP_HEADER)
.ok_or(MAuthValidationError::NoTime)?
.to_str()
.map_err(|_| MAuthValidationError::InvalidTime)?;
Self::validate_timestamp(ts_str)?;
match self.get_app_pub_key(&host_app_uuid).await {
None => Err(MAuthValidationError::KeyUnavailable),
Some(verifier) => {
if let Ok(signature) = String::from_utf8(raw_signature) {
match verifier.verify_signature(
1,
req.method.as_str(),
req.uri.path(),
req.uri.query().unwrap_or(""),
body_bytes,
ts_str,
signature,
) {
Ok(()) => Ok(host_app_uuid),
Err(_) => Err(MAuthValidationError::SignatureVerifyFailure),
}
} else {
Err(MAuthValidationError::SignatureVerifyFailure)
}
}
}
}
fn validate_timestamp(timestamp_str: &str) -> Result<(), MAuthValidationError> {
let ts_num: i64 = timestamp_str
.parse()
.map_err(|_| MAuthValidationError::InvalidTime)?;
let ts_diff = ts_num - Utc::now().timestamp();
if !(-300..=300).contains(&ts_diff) {
Err(MAuthValidationError::InvalidTime)
} else {
Ok(())
}
}
fn split_auth_string(
auth_str: &str,
expected_prefix: &str,
) -> Result<(Uuid, Vec<u8>), MAuthValidationError> {
let header_pattern = vec![' ', ':', ';'];
let mut header_split = auth_str.split(header_pattern.as_slice());
let start_str = header_split
.next()
.ok_or(MAuthValidationError::InvalidSignature)?;
if start_str != expected_prefix {
return Err(MAuthValidationError::InvalidSignature);
}
let host_uuid_str = header_split
.next()
.ok_or(MAuthValidationError::InvalidSignature)?;
let host_app_uuid =
Uuid::parse_str(host_uuid_str).map_err(|_| MAuthValidationError::InvalidSignature)?;
let signature_encoded_string = header_split
.next()
.ok_or(MAuthValidationError::InvalidSignature)?;
Ok((host_app_uuid, signature_encoded_string.into()))
}
async fn get_app_pub_key(&self, app_uuid: &Uuid) -> Option<Verifier> {
{
let key_store = PUBKEY_CACHE.read().unwrap();
if let Some(pub_key) = key_store.get(app_uuid) {
return Some(pub_key.clone());
}
}
let uri = self.mauth_uri_base.join(&format!("{}", &app_uuid)).unwrap();
let mauth_response = CLIENT.get().unwrap().get(uri).send().await;
match mauth_response {
Err(_) => None,
Ok(response) => {
if let Ok(response_obj) = response.json::<serde_json::Value>().await
&& let Some(pub_key_str) = response_obj
.pointer("/security_token/public_key_str")
.and_then(|s| s.as_str())
.map(|st| st.to_owned())
&& let Ok(verifier) = Verifier::new(*app_uuid, pub_key_str)
{
let mut key_store = PUBKEY_CACHE.write().unwrap();
key_store.insert(*app_uuid, verifier.clone());
Some(verifier)
} else {
None
}
}
}
}
}
#[derive(Debug, Error, Clone)]
pub enum MAuthValidationError {
#[error("The timestamp of the response was either invalid or outside of the permitted range")]
InvalidTime,
#[error("The MAuth signature of the response was either missing or incorrectly formatted")]
InvalidSignature,
#[error("The timestamp header of the response was missing")]
NoTime,
#[error("The signature header of the response was missing")]
NoSig,
#[error("An error occurred while attempting to retrieve part of the response body")]
ResponseProblem,
#[error("The response body failed to parse")]
InvalidBody,
#[error("Attempt to retrieve a key to verify the response failed")]
KeyUnavailable,
#[error("The body of the response did not match the signature")]
SignatureVerifyFailure,
}