use std::fmt;
use std::future::Future;
use std::pin::Pin;
use http::StatusCode;
use http::header::AUTHORIZATION;
use crate::middleware::IntoMiddleware;
use crate::middleware::Next;
use crate::responder::Responder;
use crate::types::Request;
use crate::types::Response;
/// Checks a raw bearer token and, when it is acceptable, yields the claims it carries.
///
/// Implementations are owned by the [`JwtAuth`] middleware, which hands them to
/// per-request futures; the `Clone + Send + Sync + 'static` supertraits make that possible.
pub trait JwtVerifier: Clone + Send + Sync + 'static {
    /// Claims produced by a successful verification.
    ///
    /// The middleware stores this value in the request extensions before calling the
    /// next handler, so it must be shareable across threads.
    type Claims: Clone + Send + Sync + 'static;

    /// Reason a token was rejected.
    ///
    /// The middleware renders it with `Display` into the body of its 401 response.
    type Error: fmt::Display;

    /// Verifies `token` (the credentials after the `Bearer` scheme) and returns its claims.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the token is malformed, has a bad signature, or is
    /// otherwise unacceptable to this verifier.
    fn verify(&self, token: &str) -> Result<Self::Claims, Self::Error>;
}
/// Bearer-token JWT authentication middleware.
///
/// Construct it with `JwtAuth::new(verifier)` and convert it with `into_middleware`.
/// The `V: JwtVerifier` requirement lives on the impl blocks that need it rather than on
/// the type itself, so the struct places no constraints on code that merely names it.
#[derive(Clone, Debug)]
pub struct JwtAuth<V> {
    /// Verifier used to validate the token of every incoming request.
    verifier: V,
}
impl<V: JwtVerifier> JwtAuth<V> {
pub fn new(verifier: V) -> Self {
Self { verifier }
}
}
/// Extracts the token from an `Authorization` header value of the form `Bearer <token>`.
///
/// The auth-scheme is compared case-insensitively, because HTTP auth-schemes are
/// case-insensitive (RFC 7235 §2.1). Whitespace around the token is trimmed.
/// Returns `None` when the scheme is not `Bearer` or the token is empty.
fn bearer_token(value: &str) -> Option<&str> {
    let (scheme, credentials) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = credentials.trim();
    (!token.is_empty()).then_some(token)
}

impl<V: JwtVerifier> IntoMiddleware for JwtAuth<V> {
    /// Builds a middleware function that authenticates every request with a bearer JWT.
    ///
    /// For each request:
    /// - If the `Authorization` header is absent, not valid visible ASCII, not using the
    ///   `Bearer` scheme, or carries an empty token, it responds `401 Unauthorized` with
    ///   `Missing or invalid Authorization header`.
    /// - If the verifier rejects the token, it responds `401 Unauthorized` with
    ///   `Invalid token: <verifier error>`.
    /// - Otherwise it inserts the verified `V::Claims` into the request extensions and
    ///   forwards the request to `next`.
    fn into_middleware(
        self,
    ) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
    + Clone
    + Send
    + Sync
    + 'static {
        // The closure runs once per request and must give each `'static` future its own
        // handle to the verifier. Cloning an `Arc` is a refcount bump, whereas `V::clone`
        // may copy whole key maps on every request.
        let verifier = std::sync::Arc::new(self.verifier);
        move |mut req: Request, next: Next| {
            let verifier = std::sync::Arc::clone(&verifier);
            Box::pin(async move {
                let Some(token) = req
                    .headers()
                    .get(AUTHORIZATION)
                    .and_then(|v| v.to_str().ok())
                    .and_then(bearer_token)
                else {
                    return (
                        StatusCode::UNAUTHORIZED,
                        "Missing or invalid Authorization header",
                    )
                        .into_response();
                };
                let claims = match verifier.verify(token) {
                    Ok(c) => c,
                    Err(e) => {
                        // NOTE(review): the verifier's error text is echoed to the client;
                        // consider a generic message if it may reveal configuration details.
                        return (StatusCode::UNAUTHORIZED, format!("Invalid token: {e}"))
                            .into_response();
                    }
                };
                // Downstream handlers read the verified claims from the extensions.
                req.extensions_mut().insert(claims);
                next.run(req).await.into_response()
            })
        }
    }
}
#[cfg(feature = "jwt-simple")]
mod jwt_simple_impl {
    use std::collections::HashMap;
    use std::sync::Arc;

    use ::jwt_simple::prelude::*;
    use serde::Serialize;
    use serde::de::DeserializeOwned;

    use super::*;
    use crate::types::BuildHasher;

    /// A `jwt-simple` verification key for any of the supported JWT algorithms.
    ///
    /// Every key sits behind an [`Arc`], so cloning an `AnyVerifyKey` only bumps a
    /// reference count.
    #[derive(Clone)]
    pub enum AnyVerifyKey {
        HS256(Arc<HS256Key>),
        HS384(Arc<HS384Key>),
        HS512(Arc<HS512Key>),
        Blake2b(Arc<Blake2bKey>),
        RS256(Arc<RS256PublicKey>),
        RS384(Arc<RS384PublicKey>),
        RS512(Arc<RS512PublicKey>),
        PS256(Arc<PS256PublicKey>),
        PS384(Arc<PS384PublicKey>),
        PS512(Arc<PS512PublicKey>),
        ES256(Arc<ES256PublicKey>),
        ES256K(Arc<ES256kPublicKey>),
        ES384(Arc<ES384PublicKey>),
        EdDSA(Arc<Ed25519PublicKey>),
    }

    impl AnyVerifyKey {
        /// Returns the JWT `alg` header name associated with this key's variant
        /// (for example `"HS256"` or `"EdDSA"`).
        pub fn alg_id(&self) -> &'static str {
            match self {
                Self::HS256(_) => "HS256",
                Self::HS384(_) => "HS384",
                Self::HS512(_) => "HS512",
                Self::Blake2b(_) => "BLAKE2B",
                Self::RS256(_) => "RS256",
                Self::RS384(_) => "RS384",
                Self::RS512(_) => "RS512",
                Self::PS256(_) => "PS256",
                Self::PS384(_) => "PS384",
                Self::PS512(_) => "PS512",
                Self::ES256(_) => "ES256",
                Self::ES256K(_) => "ES256K",
                Self::ES384(_) => "ES384",
                Self::EdDSA(_) => "EdDSA",
            }
        }

        /// Verifies `token` with this key using `VerificationOptions::default()`, and
        /// deserializes its claims with custom claims of type `C`.
        fn verify_token<C>(&self, token: &str) -> Result<JWTClaims<C>, ::jwt_simple::Error>
        where
            C: Serialize + DeserializeOwned,
        {
            let opts = VerificationOptions::default();
            match self {
                Self::HS256(k) => k.verify_token::<C>(token, Some(opts)),
                Self::HS384(k) => k.verify_token::<C>(token, Some(opts)),
                Self::HS512(k) => k.verify_token::<C>(token, Some(opts)),
                Self::Blake2b(k) => k.verify_token::<C>(token, Some(opts)),
                Self::RS256(k) => k.verify_token::<C>(token, Some(opts)),
                Self::RS384(k) => k.verify_token::<C>(token, Some(opts)),
                Self::RS512(k) => k.verify_token::<C>(token, Some(opts)),
                Self::PS256(k) => k.verify_token::<C>(token, Some(opts)),
                Self::PS384(k) => k.verify_token::<C>(token, Some(opts)),
                Self::PS512(k) => k.verify_token::<C>(token, Some(opts)),
                Self::ES256(k) => k.verify_token::<C>(token, Some(opts)),
                Self::ES256K(k) => k.verify_token::<C>(token, Some(opts)),
                Self::ES384(k) => k.verify_token::<C>(token, Some(opts)),
                Self::EdDSA(k) => k.verify_token::<C>(token, Some(opts)),
            }
        }
    }

    /// A [`JwtVerifier`] that chooses the verification key from the token's `alg` header.
    ///
    /// Tokens whose `alg` is not a key of the map are rejected. The map is stored behind
    /// an [`Arc`], so cloning the verifier (which middleware may do per request) does not
    /// copy it.
    pub struct MultiKeyVerifier<C> {
        keys: Arc<HashMap<&'static str, AnyVerifyKey, BuildHasher>>,
        _phantom: std::marker::PhantomData<C>,
    }

    // Written by hand: `#[derive(Clone)]` would add an unnecessary `C: Clone` bound.
    impl<C> Clone for MultiKeyVerifier<C> {
        fn clone(&self) -> Self {
            Self {
                keys: Arc::clone(&self.keys),
                _phantom: std::marker::PhantomData,
            }
        }
    }

    impl<C> MultiKeyVerifier<C> {
        /// Creates a verifier from a map from `alg` header name to key.
        ///
        /// Lookups use the token header's `alg` value verbatim, so each entry is presumably
        /// meant to be keyed by its key's [`AnyVerifyKey::alg_id`]. NOTE(review): a
        /// mismatched entry is not detected here.
        pub fn new(keys: HashMap<&'static str, AnyVerifyKey, BuildHasher>) -> Self {
            Self {
                keys: Arc::new(keys),
                _phantom: std::marker::PhantomData,
            }
        }
    }

    impl<C> JwtVerifier for MultiKeyVerifier<C>
    where
        C: Clone + Serialize + DeserializeOwned + Send + Sync + 'static,
    {
        type Claims = JWTClaims<C>;
        type Error = String;

        /// Decodes the token header, looks up the key registered for its `alg`, and
        /// verifies the token with that key.
        ///
        /// # Errors
        ///
        /// Returns a message when the header cannot be decoded, when the algorithm has no
        /// registered key, or when `jwt-simple` rejects the token.
        fn verify(&self, token: &str) -> Result<Self::Claims, Self::Error> {
            let meta = ::jwt_simple::token::Token::decode_metadata(token)
                .map_err(|e| format!("Cannot decode JWT header: {e}"))?;
            let alg = meta.algorithm();
            let key = self
                .keys
                .get(alg)
                .ok_or_else(|| format!("Algorithm {alg} not allowed"))?;
            key.verify_token::<C>(token).map_err(|e| e.to_string())
        }
    }
}
#[cfg(feature = "jwt-simple")]
pub use jwt_simple_impl::{AnyVerifyKey, MultiKeyVerifier};