#[cfg(feature = "jwt")]
use std::collections::HashMap;
use std::{borrow::Cow, collections::BTreeMap, str::FromStr};
use axum::{
body::Body,
http::{
header::{AUTHORIZATION, WWW_AUTHENTICATE},
HeaderName, HeaderValue, Request, Response, StatusCode,
},
response::IntoResponse,
};
use base64::{engine::general_purpose::STANDARD as B64, Engine};
#[cfg(feature = "jwt")]
use deboog::Deboog;
use dyn_clone::{clone_box, DynClone};
#[cfg(feature = "jwt")]
use jsonwebtoken as jwt;
use okapi::{openapi3, Map};
#[cfg(feature = "jwt")]
use serde_json::Value;
use tracing::error;
use crate::{
auth::{errors::AuthError, token::AuthToken, user::UserId},
errors,
};
/// Source of authentication data for incoming requests.
///
/// An implementation extracts the (optional) user identity and credentials from a request,
/// renders the HTTP response used when authentication fails, and describes itself via
/// OpenAPI security schemes. Implementors must be `Debug`, cloneable through [`DynClone`]
/// (so boxed trait objects can be cloned), `Send + Sync` and `'static`.
pub trait AuthExtractor: std::fmt::Debug + DynClone + Send + Sync + 'static {
    /// Extracts the user ID (if any) and the authentication token from `req`.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthError`] when credentials are missing, malformed, or rejected.
    fn extract_auth(&self, req: &Request<Body>) -> Result<(Option<UserId>, AuthToken), AuthError>;

    /// Builds the HTTP response sent to the client for `err`.
    #[must_use]
    fn error_response(&self, err: AuthError) -> Response<Body>;

    /// OpenAPI security schemes describing this extractor, keyed by scheme name.
    ///
    /// Defaults to no schemes.
    #[must_use]
    fn security_schemes(&self) -> BTreeMap<String, openapi3::SecurityScheme> {
        BTreeMap::new()
    }

    /// Request headers that carry secrets and should be treated as sensitive.
    ///
    /// NOTE(review): presumably used by the caller to redact these headers from
    /// logs/traces — confirm. Defaults to an empty list.
    #[must_use]
    fn sensitive_headers(&self) -> Vec<HeaderName> {
        vec![]
    }
}
/// Extractor that performs no authentication at all.
///
/// Every request is accepted as anonymous: no user ID and an absent token.
#[derive(Debug, Default, Clone)]
pub struct NoOpAuthExtractor;

impl AuthExtractor for NoOpAuthExtractor {
    /// Always succeeds with `(None, AuthToken::Absent)`; the request is not inspected.
    fn extract_auth(&self, _req: &Request<Body>) -> Result<(Option<UserId>, AuthToken), AuthError> {
        let anonymous: Option<UserId> = None;
        Ok((anonymous, AuthToken::Absent))
    }

    /// Logs an error and answers `500 Internal Server Error`.
    ///
    /// `extract_auth` never fails for this extractor, so being asked to render an auth
    /// error here is unexpected and is reported as a server-side problem.
    fn error_response(&self, err: AuthError) -> Response<Body> {
        error!("tried to generate auth error response for NoOpAuthExtractor");
        let title = err.to_string();
        problemdetails::new(StatusCode::INTERNAL_SERVER_ERROR)
            .with_type(errors::TAG_UXUM_AUTH)
            .with_title(title)
            .into_response()
    }
}
/// Extractor for HTTP Basic authentication (`Authorization: Basic ...`).
#[derive(Debug, Clone)]
pub struct BasicAuthExtractor {
    // Pre-formatted `WWW-Authenticate` challenge attached to 401 responses.
    www_auth: Cow<'static, str>,
}

impl Default for BasicAuthExtractor {
    /// Uses realm `auth` with a UTF-8 charset.
    fn default() -> Self {
        let www_auth: Cow<'static, str> = r#"Basic realm="auth", charset="UTF-8""#.into();
        Self { www_auth }
    }
}
impl AuthExtractor for BasicAuthExtractor {
fn extract_auth(&self, req: &Request<Body>) -> Result<(Option<UserId>, AuthToken), AuthError> {
match req.headers().get(AUTHORIZATION) {
Some(header) => {
Self::parse_header(header).map(|(user, pwd)| (Some(user.into()), pwd.into()))
}
None => Err(AuthError::NoAuthProvided),
}
}
fn error_response(&self, err: AuthError) -> Response<Body> {
let status = match err {
AuthError::NoAuthProvided | AuthError::UserNotFound | AuthError::AuthFailed => {
StatusCode::UNAUTHORIZED
}
AuthError::NoPermission(_) => StatusCode::FORBIDDEN,
_ => StatusCode::BAD_REQUEST,
};
let mut resp = problemdetails::new(status)
.with_type(errors::TAG_UXUM_AUTH)
.with_title(err.to_string())
.into_response();
if status == StatusCode::UNAUTHORIZED {
let header_value = match HeaderValue::from_str(&self.www_auth) {
Ok(val) => val,
Err(err) => {
return problemdetails::new(StatusCode::INTERNAL_SERVER_ERROR)
.with_type(errors::TAG_UXUM_AUTH)
.with_title("Invalid HTTP Basic realm value")
.with_detail(err.to_string())
.into_response()
}
};
let _ = resp.headers_mut().insert(WWW_AUTHENTICATE, header_value);
}
resp
}
fn security_schemes(&self) -> BTreeMap<String, openapi3::SecurityScheme> {
maplit::btreemap! {
"basic".into() => openapi3::SecurityScheme {
description: Some("HTTP Basic authentication".into()),
data: openapi3::SecuritySchemeData::Http {
scheme: "basic".into(),
bearer_format: None,
},
extensions: Map::default(),
},
}
}
}
impl BasicAuthExtractor {
    /// Authentication scheme name expected in the `Authorization` header.
    const SCHEME: &'static str = "Basic";

    /// Creates an extractor advertising `realm` in its `WWW-Authenticate` challenge.
    ///
    /// `None` falls back to [`Default`], i.e. realm `auth`.
    pub fn new(realm: Option<impl AsRef<str>>) -> Self {
        match realm {
            Some(realm) => Self {
                www_auth: Cow::Owned(Self::format_www_authenticate(realm)),
            },
            None => Default::default(),
        }
    }

    /// Formats the `WWW-Authenticate` challenge value for `realm`.
    ///
    /// The realm is escaped so that the result stays a valid HTTP quoted-string even if the
    /// realm contains `"` or `\`.
    fn format_www_authenticate(realm: impl AsRef<str>) -> String {
        format!(
            r#"{} realm="{}", charset="UTF-8""#,
            Self::SCHEME,
            Self::escape_quoted(realm.as_ref())
        )
    }

    /// Backslash-escapes `"` and `\` for use inside an HTTP quoted-string.
    fn escape_quoted(value: &str) -> String {
        let mut out = String::with_capacity(value.len());
        for ch in value.chars() {
            if matches!(ch, '"' | '\\') {
                out.push('\\');
            }
            out.push(ch);
        }
        out
    }

    /// Splits an `Authorization` header into `(user, password)`.
    ///
    /// # Errors
    ///
    /// - [`AuthError::InvalidAuthHeader`] if the header is not visible ASCII or contains no
    ///   space separating scheme and credentials;
    /// - [`AuthError::UnknownAuthScheme`] if the scheme is not `Basic` (case-insensitive);
    /// - whatever [`Self::parse_payload`] returns for the credentials.
    fn parse_header(header: &HeaderValue) -> Result<(String, String), AuthError> {
        let Ok(header) = header.to_str() else {
            return Err(AuthError::InvalidAuthHeader);
        };
        match header.split_once(' ') {
            Some((scheme, payload)) if scheme.eq_ignore_ascii_case(Self::SCHEME) => {
                // The credentials grammar allows one or more spaces after the scheme; strip
                // them instead of handing leading blanks to the Base64 decoder.
                Self::parse_payload(payload.trim())
            }
            Some((scheme, _)) => Err(AuthError::UnknownAuthScheme(scheme.to_string())),
            None => Err(AuthError::InvalidAuthHeader),
        }
    }

    /// Decodes Base64 `user:password` credentials.
    ///
    /// Only the first `:` separates user and password, so passwords may contain colons.
    ///
    /// # Errors
    ///
    /// Propagates Base64 and UTF-8 decoding failures (converted into [`AuthError`]) and
    /// returns [`AuthError::InvalidAuthPayload`] when no `:` is present.
    fn parse_payload(payload: &str) -> Result<(String, String), AuthError> {
        let raw = String::from_utf8(B64.decode(payload)?)?;
        raw.split_once(':')
            .map(|(user, pwd)| (user.to_string(), pwd.to_string()))
            .ok_or(AuthError::InvalidAuthPayload)
    }

    /// Replaces the realm advertised in the `WWW-Authenticate` challenge.
    pub fn set_realm(&mut self, realm: impl AsRef<str>) {
        self.www_auth = Cow::Owned(Self::format_www_authenticate(realm));
    }
}
/// Extractor reading a user name and an API key from two separate request headers.
#[derive(Debug, Clone)]
pub struct HeaderAuthExtractor {
    // Name of the header carrying the user name.
    user_header: Cow<'static, str>,
    // Name of the header carrying the API key / token.
    token_header: Cow<'static, str>,
}

impl Default for HeaderAuthExtractor {
    /// Uses `X-API-Name` for the user and `X-API-Key` for the token.
    fn default() -> Self {
        let user_header = Cow::from("X-API-Name");
        let token_header = Cow::from("X-API-Key");
        Self {
            user_header,
            token_header,
        }
    }
}
impl AuthExtractor for HeaderAuthExtractor {
fn extract_auth(&self, req: &Request<Body>) -> Result<(Option<UserId>, AuthToken), AuthError> {
let headers = req.headers();
let user = match headers.get(self.user_header.as_ref()) {
Some(header) => match header.to_str() {
Ok(user) => user.into(),
Err(_) => return Err(AuthError::InvalidAuthPayload),
},
None => return Err(AuthError::NoAuthProvided),
};
let token = match headers.get(self.token_header.as_ref()) {
Some(header) => match header.to_str() {
Ok(user) => user.to_string(),
Err(_) => return Err(AuthError::InvalidAuthPayload),
},
None => return Err(AuthError::NoAuthProvided),
};
Ok((Some(user), token.into()))
}
fn error_response(&self, err: AuthError) -> Response<Body> {
let status = match err {
AuthError::NoAuthProvided
| AuthError::UserNotFound
| AuthError::AuthFailed
| AuthError::NoPermission(_) => StatusCode::FORBIDDEN,
_ => StatusCode::BAD_REQUEST,
};
problemdetails::new(status)
.with_type(errors::TAG_UXUM_AUTH)
.with_title(err.to_string())
.into_response()
}
fn security_schemes(&self) -> BTreeMap<String, openapi3::SecurityScheme> {
maplit::btreemap! {
"api-name".into() => openapi3::SecurityScheme {
description: Some("API user name".into()),
data: openapi3::SecuritySchemeData::ApiKey {
name: self.user_header.to_string(),
location: "header".into(),
},
extensions: Map::default(),
},
"api-key".into() => openapi3::SecurityScheme {
description: Some("API key".into()),
data: openapi3::SecuritySchemeData::ApiKey {
name: self.token_header.to_string(),
location: "header".into(),
},
extensions: Map::default(),
},
}
}
fn sensitive_headers(&self) -> Vec<HeaderName> {
match HeaderName::from_str(self.token_header.as_ref()) {
Ok(hdr) => vec![hdr],
_ => vec![],
}
}
}
impl HeaderAuthExtractor {
    /// Creates an extractor, overriding the default header names where provided.
    ///
    /// `None` keeps the corresponding default from [`Default`].
    pub fn new(
        user_header: Option<impl AsRef<str>>,
        token_header: Option<impl AsRef<str>>,
    ) -> Self {
        let mut extractor = Self::default();
        if let Some(name) = user_header {
            extractor.set_user_header(name);
        }
        if let Some(name) = token_header {
            extractor.set_token_header(name);
        }
        extractor
    }

    /// Sets the name of the header carrying the user name.
    pub fn set_user_header(&mut self, name: impl AsRef<str>) {
        let owned = name.as_ref().to_owned();
        self.user_header = Cow::Owned(owned);
    }

    /// Sets the name of the header carrying the API key / token.
    pub fn set_token_header(&mut self, name: impl AsRef<str>) {
        let owned = name.as_ref().to_owned();
        self.token_header = Cow::Owned(owned);
    }
}
/// Extractor for JWT bearer tokens (`Authorization: Bearer <jwt>`).
#[cfg(feature = "jwt")]
#[derive(Deboog, Clone)]
pub struct JwtAuthExtractor {
    // Key used to verify token signatures; skipped by the `Deboog` derive
    // (presumably to keep key material out of debug output).
    #[deboog(skip)]
    key: jwt::DecodingKey,
    // Validation rules handed to `jsonwebtoken::decode`.
    validation: jwt::Validation,
    // Pre-formatted `WWW-Authenticate` challenge attached to 401 responses.
    www_auth: Cow<'static, str>,
}
#[cfg(feature = "jwt")]
impl AuthExtractor for JwtAuthExtractor {
    /// Verifies `Authorization: Bearer <jwt>` and takes the user from the `sub` claim.
    ///
    /// A string or numeric `sub` becomes the user ID; any other or missing `sub` yields no
    /// user. The token itself is not passed on: the result carries
    /// [`AuthToken::ExternallyVerified`]. A missing header yields
    /// [`AuthError::NoAuthProvided`].
    fn extract_auth(&self, req: &Request<Body>) -> Result<(Option<UserId>, AuthToken), AuthError> {
        match req.headers().get(AUTHORIZATION) {
            Some(header) => {
                let claims = self.parse_header(header)?;
                let user = claims.get("sub").and_then(|v| match v {
                    Value::String(s) => Some(s.as_str().into()),
                    Value::Number(n) => Some(n.to_string().into()),
                    _ => None,
                });
                Ok((user, AuthToken::ExternallyVerified))
            }
            None => Err(AuthError::NoAuthProvided),
        }
    }

    /// Maps `err` to a problem-details response.
    ///
    /// Missing or rejected credentials yield `401 Unauthorized` with a `WWW-Authenticate`
    /// challenge, missing permissions `403 Forbidden`, anything else `400 Bad Request`.
    /// If the configured challenge is not a valid header value, a `500` is returned instead.
    fn error_response(&self, err: AuthError) -> Response<Body> {
        let status = match err {
            AuthError::NoAuthProvided | AuthError::UserNotFound | AuthError::AuthFailed => {
                StatusCode::UNAUTHORIZED
            }
            AuthError::NoPermission(_) => StatusCode::FORBIDDEN,
            _ => StatusCode::BAD_REQUEST,
        };
        let mut resp = problemdetails::new(status)
            .with_type(errors::TAG_UXUM_AUTH)
            .with_title(err.to_string())
            .into_response();
        if status == StatusCode::UNAUTHORIZED {
            let header_value = match HeaderValue::from_str(&self.www_auth) {
                Ok(val) => val,
                Err(err) => {
                    // This is the Bearer extractor; the title must not mention Basic auth.
                    return problemdetails::new(StatusCode::INTERNAL_SERVER_ERROR)
                        .with_type(errors::TAG_UXUM_AUTH)
                        .with_title("Invalid HTTP Bearer realm value")
                        .with_detail(err.to_string())
                        .into_response();
                }
            };
            let _ = resp.headers_mut().insert(WWW_AUTHENTICATE, header_value);
        }
        resp
    }

    /// Advertises a single `bearer` HTTP security scheme with JWT format.
    fn security_schemes(&self) -> BTreeMap<String, openapi3::SecurityScheme> {
        maplit::btreemap! {
            "bearer".into() => openapi3::SecurityScheme {
                description: Some("HTTP Bearer authentication".into()),
                data: openapi3::SecuritySchemeData::Http {
                    scheme: "bearer".into(),
                    bearer_format: Some("JWT".into()),
                },
                extensions: Map::default(),
            },
        }
    }

    /// The `Authorization` header carries the bearer token, so it is sensitive.
    fn sensitive_headers(&self) -> Vec<HeaderName> {
        vec![AUTHORIZATION]
    }
}
#[cfg(feature = "jwt")]
impl JwtAuthExtractor {
    /// Authentication scheme name expected in the `Authorization` header.
    const SCHEME: &'static str = "Bearer";

    /// Creates an extractor verifying tokens with `key` according to `validation`.
    ///
    /// `realm` is advertised in the `WWW-Authenticate` challenge; `None` uses `auth`.
    pub fn new(
        realm: Option<impl AsRef<str>>,
        key: jwt::DecodingKey,
        validation: jwt::Validation,
    ) -> Self {
        let www_auth = match realm {
            Some(realm) => Cow::Owned(Self::format_www_authenticate(realm)),
            None => Cow::Borrowed(r#"Bearer realm="auth", charset="UTF-8""#),
        };
        Self {
            key,
            validation,
            www_auth,
        }
    }

    /// Formats the `WWW-Authenticate` challenge value for `realm`.
    ///
    /// The realm is escaped so that the result stays a valid HTTP quoted-string even if the
    /// realm contains `"` or `\`.
    fn format_www_authenticate(realm: impl AsRef<str>) -> String {
        format!(
            r#"{} realm="{}", charset="UTF-8""#,
            Self::SCHEME,
            Self::escape_quoted(realm.as_ref())
        )
    }

    /// Backslash-escapes `"` and `\` for use inside an HTTP quoted-string.
    fn escape_quoted(value: &str) -> String {
        let mut out = String::with_capacity(value.len());
        for ch in value.chars() {
            if matches!(ch, '"' | '\\') {
                out.push('\\');
            }
            out.push(ch);
        }
        out
    }

    /// Extracts and verifies the JWT from an `Authorization` header, returning its claims.
    ///
    /// # Errors
    ///
    /// - [`AuthError::InvalidAuthHeader`] if the header is not visible ASCII or contains no
    ///   space separating scheme and credentials;
    /// - [`AuthError::UnknownAuthScheme`] if the scheme is not `Bearer` (case-insensitive);
    /// - whatever [`Self::parse_payload`] returns for the token.
    fn parse_header(&self, header: &HeaderValue) -> Result<HashMap<String, Value>, AuthError> {
        let Ok(header) = header.to_str() else {
            return Err(AuthError::InvalidAuthHeader);
        };
        match header.split_once(' ') {
            Some((scheme, payload)) if scheme.eq_ignore_ascii_case(Self::SCHEME) => {
                // The credentials grammar allows one or more spaces after the scheme; strip
                // them instead of handing leading blanks to the JWT decoder.
                self.parse_payload(payload.trim())
            }
            Some((scheme, _)) => Err(AuthError::UnknownAuthScheme(scheme.to_string())),
            None => Err(AuthError::InvalidAuthHeader),
        }
    }

    /// Decodes and validates `payload` as a JWT, returning its claims.
    ///
    /// # Errors
    ///
    /// Tokens rejected by signature or claim validation (bad signature, expired, not yet
    /// valid, wrong issuer/audience/subject) yield [`AuthError::AuthFailed`] (-> 401).
    /// Every other decoding failure yields [`AuthError::InvalidAuthPayload`].
    fn parse_payload(&self, payload: &str) -> Result<HashMap<String, Value>, AuthError> {
        use jwt::errors::ErrorKind;
        jwt::decode::<HashMap<String, Value>>(payload, &self.key, &self.validation)
            .map(|token| token.claims)
            .map_err(|err| match err.kind() {
                ErrorKind::InvalidSignature
                | ErrorKind::ExpiredSignature
                | ErrorKind::ImmatureSignature
                | ErrorKind::InvalidIssuer
                | ErrorKind::InvalidAudience
                | ErrorKind::InvalidSubject => AuthError::AuthFailed,
                _ => AuthError::InvalidAuthPayload,
            })
    }

    /// Replaces the realm advertised in the `WWW-Authenticate` challenge.
    pub fn set_realm(&mut self, realm: impl AsRef<str>) {
        self.www_auth = Cow::Owned(Self::format_www_authenticate(realm));
    }

    /// Replaces the key used to verify token signatures.
    pub fn set_key(&mut self, key: jwt::DecodingKey) {
        self.key = key;
    }

    /// Replaces the token validation rules.
    pub fn set_validation(&mut self, valid: jwt::Validation) {
        self.validation = valid;
    }
}
/// Extractor that tries several inner extractors in order; the first success wins.
#[derive(Debug)]
pub struct StackedAuthExtractor {
    extractors: Vec<Box<dyn AuthExtractor>>,
}

impl Clone for StackedAuthExtractor {
    /// Deep-clones every boxed inner extractor.
    fn clone(&self) -> Self {
        // `Box<dyn AuthExtractor>` is not `Clone` on its own; `DynClone` supplies the clone.
        let mut extractors: Vec<Box<dyn AuthExtractor>> =
            Vec::with_capacity(self.extractors.len());
        for inner in &self.extractors {
            extractors.push(clone_box(&**inner));
        }
        Self { extractors }
    }
}
impl AuthExtractor for StackedAuthExtractor {
fn extract_auth(&self, req: &Request<Body>) -> Result<(Option<UserId>, AuthToken), AuthError> {
let mut first_error = None;
for ex in &self.extractors {
match ex.extract_auth(req) {
Ok(pair) => return Ok(pair),
Err(err) => {
if first_error.is_none() {
first_error = Some(err);
}
}
}
}
Err(first_error.unwrap())
}
fn error_response(&self, err: AuthError) -> Response<Body> {
self.extractors[0].error_response(err)
}
fn security_schemes(&self) -> BTreeMap<String, openapi3::SecurityScheme> {
let mut map = BTreeMap::new();
for ex in &self.extractors {
map.append(&mut ex.security_schemes());
}
map
}
fn sensitive_headers(&self) -> Vec<HeaderName> {
let mut list = Vec::new();
for ex in &self.extractors {
list.append(&mut ex.sensitive_headers());
}
list
}
}
impl StackedAuthExtractor {
    /// Creates a stacked extractor whose members are tried in the given order.
    ///
    /// An empty list is replaced by a single [`NoOpAuthExtractor`] (i.e. no authentication).
    pub fn new(extractors: Vec<Box<dyn AuthExtractor>>) -> Self {
        let extractors = if extractors.is_empty() {
            vec![Box::new(NoOpAuthExtractor) as Box<dyn AuthExtractor>]
        } else {
            extractors
        };
        Self { extractors }
    }
}