#![allow(clippy::exhaustive_structs)]
use as_variant::as_variant;
use http::{HeaderMap, header};
use serde::Deserialize;
/// An authentication mechanism that an endpoint can use.
///
/// Implementors describe both directions: how credentials are attached to a request, and how
/// they are read back out of one.
pub trait AuthScheme: Sized {
    /// Credentials supplied by the caller when building a request.
    type Input<'a>;

    /// Error returned when the credentials cannot be attached to a request.
    type AddAuthenticationError: Into<Box<dyn std::error::Error + Send + Sync + 'static>>;

    /// Data produced by [`AuthScheme::extract_authentication`].
    type Output;

    /// Error returned when the credentials cannot be extracted from a request.
    type ExtractAuthenticationError: Into<Box<dyn std::error::Error + Send + Sync + 'static>>;

    /// Attaches the authentication data in `input` to `request`.
    fn add_authentication<T>(
        request: &mut http::Request<T>,
        input: Self::Input<'_>,
    ) -> Result<(), Self::AddAuthenticationError>
    where
        T: AsRef<[u8]>;

    /// Reads the authentication data carried by `request`.
    fn extract_authentication<T>(
        request: &http::Request<T>,
    ) -> Result<Self::Output, Self::ExtractAuthenticationError>
    where
        T: AsRef<[u8]>;
}
/// Authentication scheme for endpoints that do not require authentication.
///
/// A token is attached to outgoing requests only when the caller passes
/// [`SendAccessToken::Always`]. Incoming requests are never inspected.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoAuthentication;

impl AuthScheme for NoAuthentication {
    type Input<'a> = SendAccessToken<'a>;
    type AddAuthenticationError = header::InvalidHeaderValue;
    type Output = ();
    type ExtractAuthenticationError = std::convert::Infallible;

    fn add_authentication<T: AsRef<[u8]>>(
        request: &mut http::Request<T>,
        access_token: SendAccessToken<'_>,
    ) -> Result<(), Self::AddAuthenticationError> {
        match access_token.get_not_required_for_endpoint() {
            Some(token) => add_access_token_as_authorization_header(request.headers_mut(), token),
            None => Ok(()),
        }
    }

    fn extract_authentication<T: AsRef<[u8]>>(
        _request: &http::Request<T>,
    ) -> Result<Self::Output, Self::ExtractAuthenticationError> {
        // Nothing to extract: this scheme carries no authentication data.
        Ok(())
    }
}
/// Authentication scheme for endpoints that require an access token.
///
/// Outgoing requests fail with [`AddRequiredTokenError::MissingAccessToken`] when the input is
/// [`SendAccessToken::None`]. Incoming requests must carry a token either in the `Authorization`
/// header or in the `access_token` query parameter.
#[derive(Clone, Copy, Debug, Default)]
pub struct AccessToken;

impl AuthScheme for AccessToken {
    type Input<'a> = SendAccessToken<'a>;
    type AddAuthenticationError = AddRequiredTokenError;
    type Output = String;
    type ExtractAuthenticationError = ExtractTokenError;

    fn add_authentication<T: AsRef<[u8]>>(
        request: &mut http::Request<T>,
        access_token: SendAccessToken<'_>,
    ) -> Result<(), Self::AddAuthenticationError> {
        let Some(token) = access_token.get_required_for_endpoint() else {
            return Err(AddRequiredTokenError::MissingAccessToken);
        };
        add_access_token_as_authorization_header(request.headers_mut(), token)
            .map_err(AddRequiredTokenError::from)
    }

    fn extract_authentication<T: AsRef<[u8]>>(
        request: &http::Request<T>,
    ) -> Result<Self::Output, Self::ExtractAuthenticationError> {
        match extract_bearer_or_query_token(request)? {
            Some(token) => Ok(token),
            None => Err(ExtractTokenError::MissingAccessToken),
        }
    }
}
/// Authentication scheme for endpoints that accept, but do not require, an access token.
///
/// Any token other than [`SendAccessToken::None`] is attached to outgoing requests. Extraction
/// yields `Ok(None)` when the incoming request carries no token.
#[derive(Clone, Copy, Debug, Default)]
pub struct AccessTokenOptional;

impl AuthScheme for AccessTokenOptional {
    type Input<'a> = SendAccessToken<'a>;
    type AddAuthenticationError = header::InvalidHeaderValue;
    type Output = Option<String>;
    type ExtractAuthenticationError = ExtractTokenError;

    fn add_authentication<T: AsRef<[u8]>>(
        request: &mut http::Request<T>,
        access_token: SendAccessToken<'_>,
    ) -> Result<(), Self::AddAuthenticationError> {
        access_token.get_required_for_endpoint().map_or(Ok(()), |token| {
            add_access_token_as_authorization_header(request.headers_mut(), token)
        })
    }

    fn extract_authentication<T: AsRef<[u8]>>(
        request: &http::Request<T>,
    ) -> Result<Self::Output, Self::ExtractAuthenticationError> {
        // A missing token is acceptable here, so the helper's `Option` is returned as-is.
        extract_bearer_or_query_token(request)
    }
}
/// Authentication scheme for endpoints that require an appservice token.
///
/// Only [`SendAccessToken::Appservice`] and [`SendAccessToken::Always`] provide a token for
/// outgoing requests; any other input yields [`AddRequiredTokenError::MissingAccessToken`].
/// Extraction reads the token from the `Authorization` header or the `access_token` query
/// parameter. It does not check here what kind of token was supplied.
#[derive(Clone, Copy, Debug, Default)]
pub struct AppserviceToken;

impl AuthScheme for AppserviceToken {
    type Input<'a> = SendAccessToken<'a>;
    type AddAuthenticationError = AddRequiredTokenError;
    type Output = String;
    type ExtractAuthenticationError = ExtractTokenError;

    fn add_authentication<T: AsRef<[u8]>>(
        request: &mut http::Request<T>,
        access_token: SendAccessToken<'_>,
    ) -> Result<(), Self::AddAuthenticationError> {
        let token = match access_token.get_required_for_appservice() {
            Some(token) => token,
            None => return Err(AddRequiredTokenError::MissingAccessToken),
        };
        add_access_token_as_authorization_header(request.headers_mut(), token)?;
        Ok(())
    }

    fn extract_authentication<T: AsRef<[u8]>>(
        request: &http::Request<T>,
    ) -> Result<Self::Output, Self::ExtractAuthenticationError> {
        let maybe_token = extract_bearer_or_query_token(request)?;
        maybe_token.ok_or(ExtractTokenError::MissingAccessToken)
    }
}
/// Authentication scheme for endpoints that accept, but do not require, an appservice token.
///
/// A token is attached to outgoing requests only for [`SendAccessToken::Appservice`] and
/// [`SendAccessToken::Always`]. Extraction yields `Ok(None)` when no token is present.
#[derive(Clone, Copy, Debug, Default)]
pub struct AppserviceTokenOptional;

impl AuthScheme for AppserviceTokenOptional {
    type Input<'a> = SendAccessToken<'a>;
    type AddAuthenticationError = header::InvalidHeaderValue;
    type Output = Option<String>;
    type ExtractAuthenticationError = ExtractTokenError;

    fn add_authentication<T: AsRef<[u8]>>(
        request: &mut http::Request<T>,
        access_token: SendAccessToken<'_>,
    ) -> Result<(), Self::AddAuthenticationError> {
        let Some(token) = access_token.get_required_for_appservice() else {
            return Ok(());
        };
        add_access_token_as_authorization_header(request.headers_mut(), token)
    }

    fn extract_authentication<T: AsRef<[u8]>>(
        request: &http::Request<T>,
    ) -> Result<Self::Output, Self::ExtractAuthenticationError> {
        // Absence of a token is not an error for this scheme.
        extract_bearer_or_query_token(request)
    }
}
/// Sets the `Authorization` header in `headers` to `Bearer {token}`, replacing any existing
/// value.
///
/// The value is marked as sensitive because it carries a credential. `HeaderValue`'s `Debug`
/// output then masks it, and HTTP/2 (HPACK) encoders may avoid adding it to their compression
/// tables.
///
/// # Errors
///
/// Returns [`header::InvalidHeaderValue`] if `token` contains bytes that are not allowed in an
/// HTTP header value.
fn add_access_token_as_authorization_header(
    headers: &mut HeaderMap,
    token: &str,
) -> Result<(), header::InvalidHeaderValue> {
    let mut value: header::HeaderValue = format!("Bearer {token}").try_into()?;
    value.set_sensitive(true);
    headers.insert(header::AUTHORIZATION, value);
    Ok(())
}
fn extract_bearer_or_query_token<T>(
request: &http::Request<T>,
) -> Result<Option<String>, ExtractTokenError> {
if let Some(token) = extract_bearer_token_from_authorization_header(request.headers())? {
return Ok(Some(token));
}
if let Some(query) = request.uri().query() {
Ok(extract_access_token_from_query(query)?)
} else {
Ok(None)
}
}
/// Extracts a bearer token from the `Authorization` header.
///
/// Returns `Ok(None)` when the header is absent. The `Bearer` scheme name is matched
/// case-insensitively. Any run of spaces separating it from the token is skipped, because
/// RFC 6750 §2.1 defines the credentials as `"Bearer" 1*SP b64token`.
///
/// # Errors
///
/// * [`ExtractTokenError::FromHeader`] if the header value cannot be read as a string.
/// * [`ExtractTokenError::InvalidAuthorizationScheme`] if the value does not start with the
///   `Bearer` scheme followed by a space.
fn extract_bearer_token_from_authorization_header(
    headers: &HeaderMap,
) -> Result<Option<String>, ExtractTokenError> {
    const EXPECTED_START: &str = "bearer ";

    let Some(value) = headers.get(header::AUTHORIZATION) else {
        return Ok(None);
    };

    // `to_str` only succeeds for ASCII content, so byte offsets are char boundaries. `get`
    // also handles values shorter than the prefix, without a separate length check.
    let value = value.to_str()?;
    let scheme_matches = matches!(
        value.get(..EXPECTED_START.len()),
        Some(start) if start.eq_ignore_ascii_case(EXPECTED_START)
    );
    if !scheme_matches {
        return Err(ExtractTokenError::InvalidAuthorizationScheme);
    }

    // The prefix consumed exactly one separator space. The spec allows more, so strip any
    // extras instead of leaking them into the token.
    let token = value[EXPECTED_START.len()..].trim_start_matches(' ');
    Ok(Some(token.to_owned()))
}
/// Reads the `access_token` parameter from a URL query string.
///
/// Returns `Ok(None)` when the parameter is absent.
///
/// # Errors
///
/// Returns the deserializer's error if `query` cannot be parsed into the expected shape.
fn extract_access_token_from_query(
    query: &str,
) -> Result<Option<String>, serde_html_form::de::Error> {
    #[derive(Deserialize)]
    struct QueryParams {
        access_token: Option<String>,
    }

    let QueryParams { access_token } = serde_html_form::from_str::<QueryParams>(query)?;
    Ok(access_token)
}
/// Controls whether, and which, token the authentication schemes attach to outgoing requests.
#[allow(clippy::exhaustive_enums)]
#[derive(Clone, Copy, Debug)]
pub enum SendAccessToken<'a> {
    /// Sent only to endpoints whose scheme asks for an access token ([`AccessToken`] and
    /// [`AccessTokenOptional`]).
    IfRequired(&'a str),

    /// Sent to every endpoint, including those using [`NoAuthentication`].
    Always(&'a str),

    /// Sent to endpoints using the access-token or appservice-token schemes, but not to
    /// [`NoAuthentication`] endpoints.
    Appservice(&'a str),

    /// No token is sent. Schemes that require one report a missing-token error.
    None,
}
impl<'a> SendAccessToken<'a> {
    /// Returns the token to send to an endpoint that requires an access token.
    ///
    /// Every variant except [`SendAccessToken::None`] yields its token.
    pub fn get_required_for_endpoint(self) -> Option<&'a str> {
        match self {
            Self::IfRequired(token) | Self::Always(token) | Self::Appservice(token) => Some(token),
            Self::None => None,
        }
    }

    /// Returns the token to send to an endpoint that does not require authentication.
    ///
    /// Only [`SendAccessToken::Always`] yields a token here.
    pub fn get_not_required_for_endpoint(self) -> Option<&'a str> {
        as_variant!(self, Self::Always)
    }

    /// Returns the token to send to an endpoint that requires an appservice token.
    ///
    /// Yields a token for [`SendAccessToken::Appservice`] and [`SendAccessToken::Always`].
    pub fn get_required_for_appservice(self) -> Option<&'a str> {
        match self {
            Self::Appservice(token) | Self::Always(token) => Some(token),
            Self::IfRequired(_) | Self::None => None,
        }
    }
}
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum AddRequiredTokenError {
#[error("no access token provided, but this endpoint requires one")]
MissingAccessToken,
#[error(transparent)]
IntoHeader(#[from] header::InvalidHeaderValue),
}
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ExtractTokenError {
#[error("no access token found, but this endpoint requires one")]
MissingAccessToken,
#[error(transparent)]
FromHeader(#[from] header::ToStrError),
#[error("invalid authorization header scheme")]
InvalidAuthorizationScheme,
#[error("failed to deserialize query string: {0}")]
FromQuery(#[from] serde_html_form::de::Error),
}