use crate::error::{Result, UltimoError};
use jsonwebtoken::{
decode as jwt_decode, encode as jwt_encode, Algorithm, DecodingKey, EncodingKey, Header,
Validation,
};
use serde::{de::DeserializeOwned, Serialize};
/// Where the JWT middleware looks for the raw token on an incoming request.
#[derive(Clone, Debug)]
enum TokenSource {
    /// The `Authorization` header, expected in `Bearer <token>` form.
    Bearer,
    /// A cookie; the payload is the cookie's name.
    Cookie(String),
}
/// HS256 JWT configuration.
///
/// Holds the signing/verification keys and the claim-validation rules. It also
/// holds the per-request policy that the middleware built from it applies.
#[derive(Clone)]
pub struct Jwt {
    /// Secret-derived key used when signing tokens.
    encoding: EncodingKey,
    /// Secret-derived key used when verifying tokens.
    decoding: DecodingKey,
    /// Rules checked on every decode (algorithm, plus any issuer/audience/leeway set).
    validation: Validation,
    /// Where the middleware reads the token from.
    source: TokenSource,
    /// When `true`, a missing or invalid token lets the request through
    /// unauthenticated instead of producing a 401 response.
    optional: bool,
}
impl Jwt {
    /// Creates a configuration that signs and verifies with HMAC-SHA256 using `secret`.
    ///
    /// By default the token is read from the `Authorization: Bearer` header and a
    /// valid token is required.
    pub fn hs256(secret: impl AsRef<[u8]>) -> Self {
        let key: &[u8] = secret.as_ref();
        Self {
            validation: Validation::new(Algorithm::HS256),
            encoding: EncodingKey::from_secret(key),
            decoding: DecodingKey::from_secret(key),
            optional: false,
            source: TokenSource::Bearer,
        }
    }

    /// Sets the issuer (`iss`) that decoded tokens are validated against.
    pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
        let expected: String = issuer.into();
        self.validation.set_issuer(&[expected]);
        self
    }

    /// Sets the audience (`aud`) that decoded tokens are validated against.
    pub fn audience(mut self, audience: impl Into<String>) -> Self {
        let expected: String = audience.into();
        self.validation.set_audience(&[expected]);
        self
    }

    /// Sets the validation leeway, in seconds, used for time-based claim checks.
    pub fn leeway(mut self, seconds: u64) -> Self {
        self.validation.leeway = seconds;
        self
    }

    /// Reads the token from the `Authorization: Bearer <token>` header (the default).
    pub fn from_bearer(self) -> Self {
        Self {
            source: TokenSource::Bearer,
            ..self
        }
    }

    /// Reads the token from the cookie called `name` instead of the header.
    pub fn from_cookie(self, name: impl Into<String>) -> Self {
        Self {
            source: TokenSource::Cookie(name.into()),
            ..self
        }
    }

    /// Lets requests without a usable token continue unauthenticated instead of
    /// being rejected with 401.
    pub fn optional(self) -> Self {
        Self {
            optional: true,
            ..self
        }
    }

    /// Serializes `claims` and signs them with an HS256 header.
    ///
    /// # Errors
    /// Returns [`UltimoError::Internal`] when encoding or signing fails.
    pub fn sign<T: Serialize>(&self, claims: &T) -> Result<String> {
        let header = Header::new(Algorithm::HS256);
        jwt_encode(&header, claims, &self.encoding)
            .map_err(|e| UltimoError::Internal(format!("JWT signing failed: {e}")))
    }

    /// Verifies `token` against the configured key and validation rules, then
    /// deserializes its claims into `T`.
    ///
    /// # Errors
    /// Returns [`UltimoError::Unauthorized`] when verification or deserialization fails.
    pub fn decode<T: DeserializeOwned>(&self, token: &str) -> Result<T> {
        let verified = jwt_decode::<T>(token, &self.decoding, &self.validation)
            .map_err(|e| UltimoError::Unauthorized(format!("invalid JWT: {e}")))?;
        Ok(verified.claims)
    }
}
use crate::Context;
/// Extracts the token from an `Authorization` header value of the form `Bearer <token>`.
///
/// The scheme is matched case-insensitively. It must be followed by a single
/// space; whitespace around the token is trimmed. Returns `None` for any other
/// scheme, for a value with no space, or for an empty token.
fn parse_bearer(header_value: &str) -> Option<String> {
    let (scheme, credentials) = header_value.split_once(' ')?;
    scheme
        .eq_ignore_ascii_case("bearer")
        .then(|| credentials.trim())
        .filter(|token| !token.is_empty())
        .map(String::from)
}
/// Returns the raw token string for this request, read from `jwt.source`.
///
/// Returns `None` in three cases:
/// - the `authorization` header is absent;
/// - the header is not a bearer credential accepted by `parse_bearer`;
/// - `ctx.cookie` reports no value for the configured cookie.
fn extract_token(jwt: &Jwt, ctx: &Context) -> Option<String> {
    match &jwt.source {
        TokenSource::Cookie(cookie_name) => ctx.cookie(cookie_name),
        TokenSource::Bearer => {
            let authorization = ctx.req.header("authorization")?;
            parse_bearer(&authorization)
        }
    }
}
use crate::middleware::{BoxedMiddleware, Next};
use crate::response::{Response, ResponseBuilder};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
impl Jwt {
    /// Consumes this configuration and returns it as a boxed middleware.
    ///
    /// For each request the middleware:
    /// 1. reads the raw token from the configured source;
    /// 2. verifies it with [`Jwt::decode`], keeping the claims as a `serde_json::Value`;
    /// 3. on success, stores the claims on the context and sets a principal, then
    ///    calls `next`. The principal's `id` is the `sub` claim when it is a string,
    ///    and its `scopes` come from `extract_scopes`.
    ///
    /// A missing token and a token that fails verification are handled the same
    /// way. If [`Jwt::optional`] was set, the request goes to `next` without a
    /// principal being set. Otherwise the middleware short-circuits with the 401
    /// response from `unauthorized`.
    pub fn build(self) -> BoxedMiddleware {
        let shared = Arc::new(self);
        Arc::new(move |ctx: Context, next: Next| {
            let jwt = Arc::clone(&shared);
            let fut = async move {
                // The decode error itself is never inspected, so a failed decode
                // collapses into the same `None` as an absent token.
                let verified = extract_token(&jwt, &ctx)
                    .and_then(|token| jwt.decode::<serde_json::Value>(&token).ok());
                match verified {
                    Some(claims) => {
                        let scopes = extract_scopes(&claims);
                        let id = claims.get("sub").and_then(|v| v.as_str()).map(String::from);
                        ctx.set_jwt_claims(claims).await;
                        ctx.set_principal(crate::auth::Principal { id, scopes })
                            .await;
                        next(ctx).await
                    }
                    None if jwt.optional => next(ctx).await,
                    None => Ok(unauthorized()),
                }
            };
            Box::pin(fut) as Pin<Box<dyn Future<Output = Result<Response>> + Send>>
        })
    }
}
/// Builds the response the JWT middleware returns when authentication is required
/// but no valid token was presented.
///
/// The response is status 401 with a `WWW-Authenticate: Bearer` challenge and a
/// plain-text `Unauthorized` body.
fn unauthorized() -> Response {
    let attempt = ResponseBuilder::new()
        .status(401)
        .header("WWW-Authenticate", "Bearer")
        .text("Unauthorized")
        .build();
    match attempt {
        Ok(response) => response,
        // NOTE(review): this fallback presumably yields a default-status text
        // response. If so, the 401 status and the challenge header would be lost
        // whenever the builder fails — confirm against `helpers::text`.
        Err(_) => crate::response::helpers::text("Unauthorized").unwrap(),
    }
}
/// Collects granted scopes from decoded JWT claims.
///
/// Three claims are consulted:
/// - `scope`: only a space-delimited string is accepted; any other JSON type is ignored.
/// - `scopes` and `scp`: either an array (non-string entries are skipped) or a
///   space-delimited string.
///
/// The result is sorted and deduplicated. It is empty when none of these claims
/// contribute anything.
fn extract_scopes(claims: &serde_json::Value) -> Vec<String> {
    // Splits a space-delimited scope list into owned scope names.
    fn words(text: &str) -> impl Iterator<Item = String> + '_ {
        text.split_whitespace().map(String::from)
    }

    let mut found: Vec<String> = claims
        .get("scope")
        .and_then(serde_json::Value::as_str)
        .map(|text| words(text).collect())
        .unwrap_or_default();

    for key in ["scopes", "scp"] {
        match claims.get(key) {
            Some(serde_json::Value::String(text)) => found.extend(words(text)),
            Some(serde_json::Value::Array(items)) => found.extend(
                items
                    .iter()
                    .filter_map(serde_json::Value::as_str)
                    .map(String::from),
            ),
            _ => {}
        }
    }

    found.sort();
    found.dedup();
    found
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Minimal claim set: subject plus expiry.
    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct Claims {
        sub: String,
        exp: usize,
    }

    /// Claim set that additionally carries an issuer, for `Jwt::issuer` checks.
    #[derive(Serialize, Deserialize, Debug)]
    struct IssuedClaims {
        sub: String,
        exp: usize,
        iss: String,
    }

    /// Claim set that additionally carries an audience, for `Jwt::audience` checks.
    #[derive(Serialize, Deserialize, Debug)]
    struct AudienceClaims {
        sub: String,
        exp: usize,
        aud: String,
    }

    /// 2100-01-01T00:00:00Z as a Unix timestamp, so test tokens never expire.
    fn far_future() -> usize {
        4_102_444_800
    }

    #[test]
    fn sign_then_decode_roundtrip() {
        let jwt = Jwt::hs256(b"test-secret");
        let token = jwt
            .sign(&Claims {
                sub: "ada".into(),
                exp: far_future(),
            })
            .unwrap();
        assert_eq!(token.split('.').count(), 3);
        let claims: Claims = jwt.decode(&token).unwrap();
        assert_eq!(
            claims,
            Claims {
                sub: "ada".into(),
                exp: far_future()
            }
        );
    }

    #[test]
    fn decode_rejects_bad_signature() {
        let signer = Jwt::hs256(b"secret-a");
        let verifier = Jwt::hs256(b"secret-b");
        let token = signer
            .sign(&Claims {
                sub: "ada".into(),
                exp: far_future(),
            })
            .unwrap();
        assert!(verifier.decode::<Claims>(&token).is_err());
    }

    #[test]
    fn decode_failure_maps_to_unauthorized() {
        let token = Jwt::hs256(b"secret-a")
            .sign(&Claims {
                sub: "ada".into(),
                exp: far_future(),
            })
            .unwrap();
        let result = Jwt::hs256(b"secret-b").decode::<Claims>(&token);
        assert!(matches!(result, Err(UltimoError::Unauthorized(_))));
    }

    #[test]
    fn decode_rejects_expired() {
        let jwt = Jwt::hs256(b"secret");
        let token = jwt
            .sign(&Claims {
                sub: "ada".into(),
                exp: 1,
            })
            .unwrap();
        assert!(jwt.decode::<Claims>(&token).is_err());
    }

    #[test]
    fn issuer_is_enforced() {
        let signer = Jwt::hs256(b"secret");
        let verifier = Jwt::hs256(b"secret").issuer("issuer-a");
        let token_from = |iss: &str| {
            signer
                .sign(&IssuedClaims {
                    sub: "ada".into(),
                    exp: far_future(),
                    iss: iss.into(),
                })
                .unwrap()
        };
        assert!(verifier.decode::<Claims>(&token_from("issuer-a")).is_ok());
        assert!(verifier.decode::<Claims>(&token_from("issuer-b")).is_err());
    }

    #[test]
    fn audience_is_enforced() {
        let signer = Jwt::hs256(b"secret");
        let verifier = Jwt::hs256(b"secret").audience("api");
        let token_for = |aud: &str| {
            signer
                .sign(&AudienceClaims {
                    sub: "ada".into(),
                    exp: far_future(),
                    aud: aud.into(),
                })
                .unwrap()
        };
        assert!(verifier.decode::<Claims>(&token_for("api")).is_ok());
        assert!(verifier.decode::<Claims>(&token_for("other")).is_err());
    }

    #[test]
    fn builder_methods_update_configuration() {
        let jwt = Jwt::hs256(b"secret")
            .from_cookie("session")
            .optional()
            .leeway(5);
        assert!(jwt.optional);
        assert_eq!(jwt.validation.leeway, 5);
        assert!(matches!(&jwt.source, TokenSource::Cookie(name) if name == "session"));
        let jwt = jwt.from_bearer();
        assert!(matches!(jwt.source, TokenSource::Bearer));
        assert!(!Jwt::hs256(b"secret").optional);
    }

    #[test]
    fn extract_scopes_parses_standard_claims() {
        let s = extract_scopes(&serde_json::json!({ "scope": "read write" }));
        assert_eq!(s, vec!["read".to_string(), "write".to_string()]);
        let s = extract_scopes(&serde_json::json!({ "scopes": ["admin", "read"] }));
        assert_eq!(s, vec!["admin".to_string(), "read".to_string()]);
        let s = extract_scopes(&serde_json::json!({ "scope": "read", "scp": "read admin" }));
        assert_eq!(s, vec!["admin".to_string(), "read".to_string()]);
        assert!(extract_scopes(&serde_json::json!({ "sub": "ada" })).is_empty());
    }

    #[test]
    fn extract_scopes_handles_edge_cases() {
        // Non-string array entries are skipped and duplicates collapse.
        let s = extract_scopes(&serde_json::json!({ "scp": ["write", 7, null, "write"] }));
        assert_eq!(s, vec!["write".to_string()]);
        // `scopes` may also be a space-delimited string, with irregular spacing.
        let s = extract_scopes(&serde_json::json!({ "scopes": "b  a" }));
        assert_eq!(s, vec!["a".to_string(), "b".to_string()]);
        // `scope` is only honoured when it is a string.
        assert!(extract_scopes(&serde_json::json!({ "scope": ["admin"] })).is_empty());
        // Non-object claims yield no scopes.
        assert!(extract_scopes(&serde_json::json!(["scope"])).is_empty());
    }

    #[test]
    fn bearer_parsing_extracts_token() {
        assert_eq!(
            parse_bearer("Bearer abc.def.ghi"),
            Some("abc.def.ghi".to_string())
        );
        assert_eq!(parse_bearer("bearer xyz"), Some("xyz".to_string()));
        assert_eq!(parse_bearer("BEARER   tok  "), Some("tok".to_string()));
        assert_eq!(parse_bearer("Basic abc"), None);
        assert_eq!(parse_bearer("Bearer"), None);
        assert_eq!(parse_bearer("Bearer "), None);
        assert_eq!(parse_bearer("Bearer    "), None);
        assert_eq!(parse_bearer(""), None);
    }
}