use axum::extract::{FromRequestParts, OptionalFromRequestParts};
use http::request::Parts;
use crate::Error;
use super::claims::Claims;
use super::error::JwtError;
/// A raw bearer token taken from an `Authorization: Bearer <token>` request header.
///
/// Usable as an axum extractor: extraction rejects the request as unauthorized when
/// no bearer token is present. The token itself is not validated by this type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Bearer(pub String);
impl<S: Send + Sync> FromRequestParts<S> for Bearer {
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let header = parts
.headers
.get(http::header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.ok_or_else(|| {
Error::unauthorized("unauthorized")
.chain(JwtError::MissingToken)
.with_code(JwtError::MissingToken.code())
})?;
let token = header
.strip_prefix("Bearer ")
.or_else(|| header.strip_prefix("bearer "))
.ok_or_else(|| {
Error::unauthorized("unauthorized")
.chain(JwtError::MissingToken)
.with_code(JwtError::MissingToken.code())
})?;
if token.is_empty() {
return Err(Error::unauthorized("unauthorized")
.chain(JwtError::MissingToken)
.with_code(JwtError::MissingToken.code()));
}
Ok(Bearer(token.to_string()))
}
}
impl<S: Send + Sync, T> FromRequestParts<S> for Claims<T>
where
T: Clone + Send + Sync + 'static,
{
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<Claims<T>>()
.cloned()
.ok_or_else(|| Error::unauthorized("unauthorized"))
}
}
impl<S: Send + Sync, T> OptionalFromRequestParts<S> for Claims<T>
where
    T: Clone + Send + Sync + 'static,
{
    type Rejection = Error;

    /// Optional variant of the claims extractor.
    ///
    /// Yields `Some` with a copy of the stored `Claims<T>` when the request
    /// extensions contain one, and `None` otherwise. This never returns `Err`.
    async fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> Result<Option<Self>, Self::Rejection> {
        let found = parts.extensions.get::<Self>().cloned();
        Ok(found)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Minimal custom claim payload used by the extractor tests.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestClaims {
        role: String,
    }

    /// Builds empty request parts, optionally carrying an `Authorization` header.
    fn request_parts(authorization: Option<&str>) -> Parts {
        let mut builder = http::Request::builder();
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }
        let (parts, _) = builder.body(()).unwrap().into_parts();
        parts
    }

    /// Runs the `Bearer` extractor with a unit state.
    async fn extract_bearer(parts: &mut Parts) -> Result<Bearer, Error> {
        <Bearer as FromRequestParts<()>>::from_request_parts(parts, &()).await
    }

    /// Runs the required `Claims<TestClaims>` extractor with a unit state.
    async fn extract_claims(parts: &mut Parts) -> Result<Claims<TestClaims>, Error> {
        <Claims<TestClaims> as FromRequestParts<()>>::from_request_parts(parts, &()).await
    }

    /// Runs the optional `Claims<TestClaims>` extractor with a unit state.
    async fn extract_optional_claims(
        parts: &mut Parts,
    ) -> Result<Option<Claims<TestClaims>>, Error> {
        <Claims<TestClaims> as OptionalFromRequestParts<()>>::from_request_parts(parts, &())
            .await
    }

    #[tokio::test]
    async fn bearer_extracts_token() {
        let mut parts = request_parts(Some("Bearer my-token"));
        let bearer = extract_bearer(&mut parts).await.unwrap();
        assert_eq!(bearer.0, "my-token");
    }

    #[tokio::test]
    async fn bearer_accepts_lowercase_scheme() {
        let mut parts = request_parts(Some("bearer my-token"));
        let bearer = extract_bearer(&mut parts).await.unwrap();
        assert_eq!(bearer.0, "my-token");
    }

    #[tokio::test]
    async fn bearer_missing_header_returns_401() {
        let mut parts = request_parts(None);
        let err = extract_bearer(&mut parts).await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_wrong_scheme_returns_401() {
        let mut parts = request_parts(Some("Basic abc"));
        let err = extract_bearer(&mut parts).await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_empty_token_returns_401() {
        let mut parts = request_parts(Some("Bearer "));
        let err = extract_bearer(&mut parts).await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn claims_extract_from_extensions() {
        let mut parts = request_parts(None);
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_sub("user_1")
        .with_exp(9999999999);
        parts.extensions.insert(claims.clone());

        let extracted = extract_claims(&mut parts).await.unwrap();
        assert_eq!(extracted.custom.role, "admin");
        assert_eq!(extracted.sub, Some("user_1".into()));
    }

    #[tokio::test]
    async fn claims_missing_returns_401() {
        let mut parts = request_parts(None);
        let err = extract_claims(&mut parts).await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn option_claims_none_when_missing() {
        let mut parts = request_parts(None);
        let result = extract_optional_claims(&mut parts).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn option_claims_some_when_present() {
        let mut parts = request_parts(None);
        parts.extensions.insert(Claims::new(TestClaims {
            role: "admin".into(),
        }));
        let result = extract_optional_claims(&mut parts).await;
        assert!(result.unwrap().is_some());
    }
}