use axum::async_trait;
use axum::extract::{FromRequest, RequestParts};
use http::{header::AUTHORIZATION, StatusCode};
/// Credentials carried by an HTTP Basic `Authorization` header.
///
/// The inner tuple is `(id, password)`. `password` is `None` when the decoded
/// credentials contained no `:` separator at all.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AuthBasic(pub (String, Option<String>));
/// Extracts HTTP Basic credentials from the request's `Authorization` header.
///
/// The header must look like `<scheme> <credentials>`, where `<scheme>` is
/// `Basic` in any letter case. Auth schemes are case-insensitive per
/// RFC 7235 §2.1. Extra spaces between the scheme and the credentials are
/// skipped. The credentials are then handed to `decode_basic`.
///
/// # Errors
///
/// Rejects with `400 Bad Request` when:
/// - the header is missing;
/// - the value cannot be read as a string (`HeaderValue::to_str` fails);
/// - the scheme is not `Basic` or no credentials part follows it;
/// - `decode_basic` rejects the credentials.
#[async_trait]
impl<B> FromRequest<B> for AuthBasic
where
    B: Send,
{
    type Rejection = (StatusCode, &'static str);

    async fn from_request(req: &mut RequestParts<B>) -> std::result::Result<Self, Self::Rejection> {
        let authorisation = req
            .headers()
            .and_then(|headers| headers.get(AUTHORIZATION))
            .ok_or((StatusCode::BAD_REQUEST, "`Authorization` header is missing"))?
            .to_str()
            .map_err(|_| {
                (
                    StatusCode::BAD_REQUEST,
                    "`Authorization` header contains invalid characters",
                )
            })?;

        match authorisation.split_once(' ') {
            // Scheme names are case-insensitive, so `basic` and `BASIC` must be
            // accepted too. The grammar allows one or more spaces before the
            // credentials, and base64 never contains spaces, so leading spaces
            // are safe to strip.
            Some((scheme, credentials)) if scheme.eq_ignore_ascii_case("Basic") => {
                decode_basic(credentials.trim_start_matches(' '))
            }
            _ => Err((
                StatusCode::BAD_REQUEST,
                "`Authorization` header must be for basic authentication",
            )),
        }
    }
}
/// Decodes the base64 credentials of a Basic `Authorization` header.
///
/// The decoded text is split on its first `:`. The part before it becomes the
/// id, and everything after it (which may include more colons) becomes the
/// password. Text with no `:` is returned whole as the id with no password.
///
/// # Errors
///
/// Returns `400 Bad Request` if `input` is not valid base64 or if the decoded
/// bytes are not valid UTF-8.
fn decode_basic(input: &str) -> Result<AuthBasic, (StatusCode, &'static str)> {
    let improperly_encoded = || {
        (
            StatusCode::BAD_REQUEST,
            "`Authorization` header's basic authentication was improperly encoded",
        )
    };

    let raw_bytes = base64::decode(input).map_err(|_| improperly_encoded())?;
    let text = String::from_utf8(raw_bytes).map_err(|_| improperly_encoded())?;

    let credentials = match text.split_once(':') {
        Some((user_id, secret)) => (user_id.to_owned(), Some(secret.to_owned())),
        None => (text, None),
    };
    Ok(AuthBasic(credentials))
}
/// Token carried by an HTTP `Authorization: Bearer <token>` header.
///
/// The wrapped `String` is the raw token text that followed the scheme name.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AuthBearer(pub String);
/// Extracts a bearer token from the request's `Authorization` header.
///
/// The header must look like `<scheme> <token>`, where `<scheme>` is `Bearer`
/// in any letter case. Auth schemes are case-insensitive per RFC 7235 §2.1.
/// Extra spaces before the token are skipped, and the rest of the value is
/// returned unchanged as the token.
///
/// # Errors
///
/// Rejects with `400 Bad Request` when:
/// - the header is missing;
/// - the value cannot be read as a string (`HeaderValue::to_str` fails);
/// - the scheme is not `Bearer`;
/// - the token is empty (RFC 6750 requires at least one character).
#[async_trait]
impl<B> FromRequest<B> for AuthBearer
where
    B: Send,
{
    type Rejection = (StatusCode, &'static str);

    async fn from_request(req: &mut RequestParts<B>) -> std::result::Result<Self, Self::Rejection> {
        let authorisation = req
            .headers()
            .and_then(|headers| headers.get(AUTHORIZATION))
            .ok_or((StatusCode::BAD_REQUEST, "`Authorization` header is missing"))?
            .to_str()
            .map_err(|_| {
                (
                    StatusCode::BAD_REQUEST,
                    "`Authorization` header contains invalid characters",
                )
            })?;

        authorisation
            .split_once(' ')
            // Scheme names are case-insensitive: `bearer` is as valid as `Bearer`.
            .filter(|(scheme, _)| scheme.eq_ignore_ascii_case("Bearer"))
            // One or more spaces may separate scheme and token, and a token
            // never contains spaces, so leading spaces are dropped.
            .map(|(_, token)| token.trim_start_matches(' '))
            // `Bearer ` with nothing after it is not a usable credential.
            .filter(|token| !token.is_empty())
            .map(|token| Self(token.to_string()))
            .ok_or((
                StatusCode::BAD_REQUEST,
                "`Authorization` header must be a bearer token",
            ))
    }
}