pub use jsonwebtoken::{DecodingKey, Validation};
use crate::http::header::{HeaderValue, WWW_AUTHENTICATE};
use crate::http::StatusCode;
use crate::{
async_trait, join, Context, Error, Middleware, Next, Result, State, SyncContext,
};
use headers::{authorization::Bearer, Authorization, HeaderMapExt};
use jsonwebtoken::{dangerous_unsafe_decode, decode};
use serde::de::DeserializeOwned;
use serde_json::Value;
use std::sync::Arc;
/// `WWW-Authenticate` challenge attached to every `401 Unauthorized` response
/// that passes back through [`catch_www_authenticate`].
const INVALID_TOKEN: &str = "Bearer realm=\"<jwt>\", error=\"invalid_token\"";
/// Scope marker under which `JwtGuard` stores, and `JwtVerifier` loads, the
/// `"secret"` and `"token"` context values.
struct JwtScope;
/// Access to the JWT claims of a request that has passed [`guard`] or [`guard_by`].
pub trait JwtVerifier<S> {
    /// Deserializes the claims of the bearer token stored by the guard.
    ///
    /// The token is not verified again here; the guard has already checked it.
    ///
    /// # Errors
    ///
    /// Fails with `500 Internal Server Error` if the guard did not run for this
    /// request, or if the claims cannot be deserialized into `C`.
    fn claims<C: 'static + DeserializeOwned>(&self) -> Result<C>;

    /// Verifies the stored bearer token again against a custom `validation`
    /// and deserializes its claims.
    ///
    /// # Errors
    ///
    /// Fails with `401 Unauthorized` if the token does not satisfy `validation`,
    /// or with `500 Internal Server Error` if the guard did not run for this request.
    fn verify<C: 'static + DeserializeOwned>(&self, validation: &Validation) -> Result<C>;
}
pub fn guard<S: State>(secret: DecodingKey) -> impl Middleware<S> {
guard_by(secret, Validation::default())
}
/// Builds a JWT guard that verifies bearer tokens against `secret` with the
/// given `validation` rules.
///
/// The returned middleware first installs [`catch_www_authenticate`], so any
/// `401` produced further down the chain carries the [`INVALID_TOKEN`]
/// challenge. It then runs the token check itself.
pub fn guard_by<S: State>(
    secret: DecodingKey,
    validation: Validation,
) -> impl Middleware<S> {
    let jwt_guard = JwtGuard {
        validation,
        // The guard outlives any borrowed key material, so take an owned copy.
        secret: secret.into_static(),
    };
    join(Arc::new(catch_www_authenticate), jwt_guard)
}
/// Runs the rest of the chain and, if it fails with `401 Unauthorized`, sets
/// the `WWW-Authenticate` header to [`INVALID_TOKEN`].
///
/// The result of `next` is returned unchanged either way.
#[inline]
async fn catch_www_authenticate<S: State>(mut ctx: Context<S>, next: Next) -> Result {
    let outcome = next.await;
    let needs_challenge = outcome
        .as_ref()
        .err()
        .map_or(false, |err| err.status_code == StatusCode::UNAUTHORIZED);
    if needs_challenge {
        let challenge = HeaderValue::from_static(INVALID_TOKEN);
        ctx.resp_mut().headers.insert(WWW_AUTHENTICATE, challenge);
    }
    outcome
}
/// Middleware that rejects requests lacking a valid `Authorization: Bearer` token.
struct JwtGuard {
    /// Rules every incoming token is checked against.
    validation: Validation,
    /// Key used to verify token signatures; made `'static` by [`guard_by`].
    secret: DecodingKey<'static>,
}
#[inline]
fn unauthorized(_err: impl ToString) -> Error {
Error::new(StatusCode::UNAUTHORIZED, "".to_string(), false)
}
#[inline]
fn guard_not_set() -> Error {
Error::new(
StatusCode::INTERNAL_SERVER_ERROR,
"middleware `JwtGuard` is not set correctly",
false,
)
}
impl<S> JwtVerifier<S> for SyncContext<S>
where
S: State,
{
#[inline]
fn claims<C>(&self) -> Result<C>
where
C: 'static + DeserializeOwned,
{
let token = self.load_scoped::<JwtScope, Bearer>("token");
match token {
Some(token) => dangerous_unsafe_decode(token.token())
.map(|data| data.claims)
.map_err(|err| {
Error::new(
StatusCode::INTERNAL_SERVER_ERROR,
format!(
"{}\ntoken deserialized fails, this maybe a bug of JwtGuard.",
err
),
false,
)
}),
None => Err(guard_not_set()),
}
}
#[inline]
fn verify<C>(&self, validation: &Validation) -> Result<C>
where
C: 'static + DeserializeOwned,
{
let secret = self.load_scoped::<JwtScope, DecodingKey<'static>>("secret");
let token = self.load_scoped::<JwtScope, Bearer>("token");
match (secret, token) {
(Some(secret), Some(token)) => decode(token.token(), &*secret, validation)
.map(|data| data.claims)
.map_err(unauthorized),
_ => Err(guard_not_set()),
}
}
}
#[async_trait(?Send)]
impl<S: State> Middleware<S> for JwtGuard {
#[inline]
async fn handle(self: Arc<Self>, mut ctx: Context<S>, next: Next) -> Result {
let bearer = ctx
.req()
.headers
.typed_get::<Authorization<Bearer>>()
.ok_or_else(|| unauthorized(""))?
.0;
decode::<Value>(bearer.token(), &self.secret, &self.validation)
.map_err(unauthorized)?;
ctx.store_scoped(JwtScope, "secret", self.secret.clone());
ctx.store_scoped(JwtScope, "token", bearer);
next.await
}
}
#[cfg(test)]
mod tests {
    use super::{guard, DecodingKey, INVALID_TOKEN};
    use crate::http::header::{AUTHORIZATION, WWW_AUTHENTICATE};
    use crate::http::StatusCode;
    use crate::preload::*;
    use crate::{App, Error};
    use async_std::task::spawn;
    use jsonwebtoken::{encode, EncodingKey, Header};
    use serde::{Deserialize, Serialize};
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    /// Claims carried by the test tokens; `exp` is a UNIX timestamp in seconds.
    #[derive(Debug, Serialize, Deserialize)]
    struct User {
        sub: String,
        company: String,
        exp: u64,
        id: u64,
        name: String,
    }

    const SECRET: &[u8] = b"123456";

    /// Exercises the guard with missing, malformed, non-bearer, garbage,
    /// expired and finally valid tokens.
    #[tokio::test]
    async fn claims() -> Result<(), Box<dyn std::error::Error>> {
        let mut app = App::new(());
        let (addr, server) = app
            .gate(guard(DecodingKey::from_secret(SECRET)))
            .end(move |ctx| async move {
                let user: User = ctx.claims()?;
                assert_eq!(0, user.id);
                assert_eq!("Hexilee", &user.name);
                Ok(())
            })
            .run()?;
        spawn(server);
        let url = format!("http://{}", addr);

        // No Authorization header at all.
        let resp = reqwest::get(&url).await?;
        assert_eq!(StatusCode::UNAUTHORIZED, resp.status());
        assert_eq!(INVALID_TOKEN, resp.headers()[WWW_AUTHENTICATE].to_str()?);

        let client = reqwest::Client::new();

        // Header value that is not valid text.
        let resp = client
            .get(&url)
            .header(AUTHORIZATION, [255].as_ref())
            .send()
            .await?;
        assert_eq!(StatusCode::UNAUTHORIZED, resp.status());
        assert_eq!(INVALID_TOKEN, resp.headers()[WWW_AUTHENTICATE].to_str()?);

        // Wrong authorization scheme.
        let resp = client
            .get(&url)
            .header(AUTHORIZATION, "Basic hahaha")
            .send()
            .await?;
        assert_eq!(StatusCode::UNAUTHORIZED, resp.status());
        assert_eq!(INVALID_TOKEN, resp.headers()[WWW_AUTHENTICATE].to_str()?);

        // Bearer scheme, but not a JWT.
        let resp = client
            .get(&url)
            .header(AUTHORIZATION, "Bearer hahaha")
            .send()
            .await?;
        assert_eq!(StatusCode::UNAUTHORIZED, resp.status());
        assert_eq!(INVALID_TOKEN, resp.headers()[WWW_AUTHENTICATE].to_str()?);

        // Correctly signed but already expired.
        let mut user = User {
            sub: "user".to_string(),
            company: "None".to_string(),
            exp: (SystemTime::now() - Duration::from_secs(1))
                .duration_since(UNIX_EPOCH)?
                .as_secs(),
            id: 0,
            name: "Hexilee".to_string(),
        };
        let resp = client
            .get(&url)
            .header(
                AUTHORIZATION,
                format!(
                    "Bearer {}",
                    encode(
                        &Header::default(),
                        &user,
                        &EncodingKey::from_secret(SECRET)
                    )?
                ),
            )
            .send()
            .await?;
        assert_eq!(StatusCode::UNAUTHORIZED, resp.status());
        assert_eq!(INVALID_TOKEN, resp.headers()[WWW_AUTHENTICATE].to_str()?);

        // Valid token. `exp` is truncated to whole seconds, so a millisecond-scale
        // margin usually rounds down to the current second and the token could
        // expire mid-request; use a generous margin to keep the test deterministic.
        user.exp = (SystemTime::now() + Duration::from_secs(60))
            .duration_since(UNIX_EPOCH)?
            .as_secs();
        let resp = client
            .get(&url)
            .header(
                AUTHORIZATION,
                format!(
                    "Bearer {}",
                    encode(
                        &Header::default(),
                        &user,
                        &EncodingKey::from_secret(SECRET)
                    )?
                ),
            )
            .send()
            .await?;
        assert_eq!(StatusCode::OK, resp.status());
        Ok(())
    }

    /// Using `JwtVerifier` without installing the guard must report a 500.
    #[tokio::test]
    async fn jwt_verify_not_set() -> Result<(), Box<dyn std::error::Error>> {
        let mut app = App::new(());
        let (addr, server) = app
            .gate_fn(move |ctx, _next| async move {
                let result: Result<User, Error> = ctx.claims();
                assert!(result.is_err());
                let status = result.unwrap_err();
                assert_eq!(StatusCode::INTERNAL_SERVER_ERROR, status.status_code);
                assert_eq!("middleware `JwtGuard` is not set correctly", status.message);
                Ok(())
            })
            .run()?;
        spawn(server);
        reqwest::get(&format!("http://{}", addr)).await?;
        Ok(())
    }
}