use std::collections::HashMap;
/// Authentication backend that turns a presented bearer token into the
/// identifier string it authenticates as.
///
/// The `Send + Sync + 'static` supertraits let a single backend instance be
/// shared across threads and stored without borrowing from anything.
pub trait Auth: Send + Sync + 'static {
    /// Resolves `token` to an identifier.
    ///
    /// Returns `Some(identifier)` when the token is accepted and `None` when it
    /// is rejected for any reason.
    fn resolve(&self, token: &str) -> Option<String>;
}
/// An authentication backend backed by a static table of opaque bearer tokens.
///
/// Each stored token maps to the identifier (`vid`) string returned when that
/// exact token is presented.
///
/// `Debug` is implemented by hand rather than derived: the table holds live
/// credentials, and a derived `Debug` would print every token (and its vid)
/// into whatever log or panic message formats this value.
#[derive(Clone, Default)]
pub struct BearerAuth {
    /// token -> vid
    tokens: HashMap<String, String>,
}

impl std::fmt::Debug for BearerAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the table size is shown; tokens and vids are deliberately omitted.
        f.debug_struct("BearerAuth")
            .field("token_count", &self.tokens.len())
            .finish_non_exhaustive()
    }
}
impl BearerAuth {
pub fn new() -> Self {
Self::default()
}
pub fn from_pairs<I, T, V>(iter: I) -> Self
where
I: IntoIterator<Item = (T, V)>,
T: Into<String>,
V: Into<String>,
{
let mut auth = Self::new();
for (t, v) in iter {
auth.insert(t, v);
}
auth
}
pub fn with(mut self, token: impl Into<String>, vid: impl Into<String>) -> Self {
self.insert(token, vid);
self
}
pub fn insert(&mut self, token: impl Into<String>, vid: impl Into<String>) {
self.tokens.insert(token.into(), vid.into());
}
}
/// Table lookup: the presented token must match a stored token exactly (no
/// trimming or prefix stripping). Unknown tokens yield `None`; known tokens
/// yield an owned copy of the stored vid.
impl Auth for BearerAuth {
    fn resolve(&self, token: &str) -> Option<String> {
        let stored_vid = self.tokens.get(token)?;
        Some(stored_vid.clone())
    }
}
#[cfg(feature = "jwt")]
#[cfg_attr(docsrs, doc(cfg(feature = "jwt")))]
pub use jwt::JwtBearerAuth;
#[cfg(feature = "jwt")]
#[cfg_attr(docsrs, doc(cfg(feature = "jwt")))]
mod jwt {
use std::marker::PhantomData;
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::de::DeserializeOwned;
use super::Auth;
pub struct JwtBearerAuth<C: DeserializeOwned + Send + Sync + 'static> {
key: DecodingKey,
validation: Validation,
extract: Box<ExtractFn<C>>,
_phantom: PhantomData<fn() -> C>,
}
type ExtractFn<C> = dyn Fn(&C) -> Option<String> + Send + Sync;
impl<C: DeserializeOwned + Send + Sync + 'static> JwtBearerAuth<C> {
pub fn new<F>(key: DecodingKey, validation: Validation, extract: F) -> Self
where
F: Fn(&C) -> Option<String> + Send + Sync + 'static,
{
assert!(
!validation.algorithms.is_empty(),
"JwtBearerAuth requires Validation to pin at least one algorithm; \
an empty `algorithms` list disables algorithm checking and \
exposes algorithm-confusion attacks"
);
Self {
key,
validation,
extract: Box::new(extract),
_phantom: PhantomData,
}
}
}
/// Hand-written so that neither the decoding key nor the validation settings
/// are printed; only the claims type name is shown.
impl<C: DeserializeOwned + Send + Sync + 'static> std::fmt::Debug for JwtBearerAuth<C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let claims_type: &'static str = std::any::type_name::<C>();
        let mut out = f.debug_struct("JwtBearerAuth");
        out.field("claims_type", &claims_type);
        out.finish_non_exhaustive()
    }
}
/// Accepts `token` only if it decodes and validates against the configured key
/// and [`Validation`]; any decode error maps to `None`. Successfully decoded
/// claims are handed to the extractor, whose result is returned unchanged.
impl<C: DeserializeOwned + Send + Sync + 'static> Auth for JwtBearerAuth<C> {
    fn resolve(&self, token: &str) -> Option<String> {
        match decode::<C>(token, &self.key, &self.validation) {
            Ok(token_data) => (self.extract)(&token_data.claims),
            Err(_) => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
    use serde::{Deserialize, Serialize};

    /// `exp` far enough in the future that test tokens never expire.
    ///
    /// The value exceeds `u32::MAX`, which is why `Claims::exp` is a `u64`: as
    /// a `usize` this literal is rejected by the deny-by-default
    /// `overflowing_literals` lint on 32-bit targets.
    const FAR_FUTURE: u64 = 10_000_000_000;

    #[derive(Serialize, Deserialize)]
    struct Claims {
        sub: String,
        vid: Option<String>,
        exp: u64,
    }

    /// Claims with the given subject/vid and a far-future expiry.
    fn claims(sub: &str, vid: Option<&str>) -> Claims {
        Claims {
            sub: sub.to_string(),
            vid: vid.map(String::from),
            exp: FAR_FUTURE,
        }
    }

    /// Signs `claims` under `secret` using algorithm `alg`.
    fn sign_with(alg: Algorithm, secret: &[u8], claims: &Claims) -> String {
        encode(&Header::new(alg), claims, &EncodingKey::from_secret(secret))
            .expect("HMAC signing of serializable claims should not fail")
    }

    /// Signs `claims` under `secret` with HS256, the algorithm `validation` pins.
    fn sign(secret: &[u8], claims: &Claims) -> String {
        sign_with(Algorithm::HS256, secret, claims)
    }

    /// HS256-only validation that checks `exp` but requires no registered
    /// claims to be present.
    fn validation() -> Validation {
        let mut v = Validation::new(Algorithm::HS256);
        v.validate_exp = true;
        v.required_spec_claims.clear();
        v
    }

    /// Authenticator under test: prefers the `vid` claim, else falls back to `sub`.
    fn auth(secret: &[u8]) -> JwtBearerAuth<Claims> {
        JwtBearerAuth::new(DecodingKey::from_secret(secret), validation(), |c: &Claims| {
            c.vid.clone().or_else(|| Some(c.sub.clone()))
        })
    }

    #[test]
    fn valid_jwt_resolves_to_vid_via_claims_closure() {
        let secret = b"test-secret";
        let token = sign(secret, &claims("user-123", Some("did:web:alice.example")));
        assert_eq!(
            auth(secret).resolve(&token),
            Some("did:web:alice.example".to_string())
        );
    }

    #[test]
    #[should_panic(expected = "at least one algorithm")]
    fn empty_validation_algorithms_is_rejected_at_construction() {
        let mut v = Validation::new(Algorithm::HS256);
        v.algorithms.clear();
        let _ = JwtBearerAuth::new(DecodingKey::from_secret(b"x"), v, |c: &Claims| {
            Some(c.sub.clone())
        });
    }

    #[test]
    fn token_with_unexpected_alg_is_rejected() {
        let secret = b"test-secret";
        let token = sign_with(
            Algorithm::HS512,
            secret,
            &claims("did:web:eve.example", None),
        );
        assert_eq!(auth(secret).resolve(&token), None);
    }

    #[test]
    fn jwt_without_vid_claim_falls_back_to_subject() {
        let secret = b"test-secret";
        let token = sign(secret, &claims("did:web:fallback.example", None));
        assert_eq!(
            auth(secret).resolve(&token),
            Some("did:web:fallback.example".to_string())
        );
    }

    #[test]
    fn extractor_returning_none_rejects_otherwise_valid_token() {
        let secret = b"test-secret";
        let token = sign(secret, &claims("did:web:nobody.example", None));
        // Sanity check: the token itself is valid.
        assert_eq!(
            auth(secret).resolve(&token),
            Some("did:web:nobody.example".to_string())
        );
        let strict = JwtBearerAuth::new(
            DecodingKey::from_secret(secret),
            validation(),
            |c: &Claims| c.vid.clone(),
        );
        assert_eq!(strict.resolve(&token), None);
    }

    #[test]
    fn wrong_signature_returns_none() {
        let token = sign(b"correct-secret", &claims("x", None));
        assert!(auth(b"wrong-secret").resolve(&token).is_none());
    }

    #[test]
    fn expired_jwt_returns_none() {
        let token = sign(
            b"test-secret",
            &Claims {
                sub: "x".into(),
                vid: None,
                exp: 1,
            },
        );
        assert!(auth(b"test-secret").resolve(&token).is_none());
    }

    #[test]
    fn malformed_jwt_returns_none() {
        assert!(auth(b"test-secret").resolve("not-a-jwt").is_none());
    }
}
}