use std::collections::HashMap;
/// Resolves a bearer token to the identifier of the caller it belongs to.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait Auth: Send + Sync + 'static {
/// Returns the identifier associated with `token`, or `None` when the
/// token is not accepted.
fn resolve(&self, token: &str) -> Option<String>;
}
/// In-memory [`Auth`] implementation backed by a token → identifier table.
///
/// `Debug` is implemented by hand rather than derived: the map keys are
/// bearer tokens (credentials), so the debug output only reports how many
/// tokens are registered and never prints the tokens themselves.
#[derive(Clone, Default)]
pub struct BearerAuth {
    /// Maps each accepted bearer token to the identifier it resolves to.
    tokens: HashMap<String, String>,
}

impl std::fmt::Debug for BearerAuth {
    /// Formats the authenticator without revealing any registered token.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BearerAuth")
            // Only the count is shown; the tokens are secrets.
            .field(
                "tokens",
                &format_args!("<{} redacted>", self.tokens.len()),
            )
            .finish()
    }
}
impl BearerAuth {
pub fn new() -> Self {
Self::default()
}
pub fn from_pairs<I, T, V>(iter: I) -> Self
where
I: IntoIterator<Item = (T, V)>,
T: Into<String>,
V: Into<String>,
{
let mut auth = Self::new();
for (t, v) in iter {
auth.insert(t, v);
}
auth
}
pub fn with(mut self, token: impl Into<String>, vid: impl Into<String>) -> Self {
self.insert(token, vid);
self
}
pub fn insert(&mut self, token: impl Into<String>, vid: impl Into<String>) {
self.tokens.insert(token.into(), vid.into());
}
}
impl Auth for BearerAuth {
    /// Looks `token` up in the table and returns an owned copy of the
    /// identifier registered for it, or `None` when the token is unknown.
    fn resolve(&self, token: &str) -> Option<String> {
        let vid = self.tokens.get(token)?;
        Some(vid.clone())
    }
}
#[cfg(feature = "jwt")]
#[cfg_attr(docsrs, doc(cfg(feature = "jwt")))]
pub use jwt::JwtBearerAuth;
#[cfg(feature = "jwt")]
#[cfg_attr(docsrs, doc(cfg(feature = "jwt")))]
mod jwt {
use std::marker::PhantomData;
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::de::DeserializeOwned;
use super::Auth;
/// [`Auth`] implementation that accepts JSON Web Tokens.
///
/// A token is decoded into claims of type `C` using `key` and `validation`;
/// the `extract` callback then picks the identifier out of those claims.
pub struct JwtBearerAuth<C: DeserializeOwned + Send + Sync + 'static> {
/// Key handed to `jsonwebtoken::decode` for every incoming token.
key: DecodingKey,
/// Validation settings handed to `jsonwebtoken::decode`.
validation: Validation,
/// Callback mapping decoded claims to the resolved identifier (or `None`).
extract: Box<ExtractFn<C>>,
/// Marker for the claims type `C`; no value of `C` is stored.
_phantom: PhantomData<fn() -> C>,
}
/// Signature of the claims → identifier callback stored by [`JwtBearerAuth`].
type ExtractFn<C> = dyn Fn(&C) -> Option<String> + Send + Sync;
impl<C: DeserializeOwned + Send + Sync + 'static> JwtBearerAuth<C> {
pub fn new<F>(key: DecodingKey, validation: Validation, extract: F) -> Self
where
F: Fn(&C) -> Option<String> + Send + Sync + 'static,
{
Self {
key,
validation,
extract: Box::new(extract),
_phantom: PhantomData,
}
}
}
impl<C: DeserializeOwned + Send + Sync + 'static> std::fmt::Debug for JwtBearerAuth<C> {
    /// Prints only the name of the claims type; the key, validation settings
    /// and callback are left out and the output is marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let claims_type = std::any::type_name::<C>();
        let mut out = f.debug_struct("JwtBearerAuth");
        out.field("claims_type", &claims_type);
        out.finish_non_exhaustive()
    }
}
impl<C: DeserializeOwned + Send + Sync + 'static> Auth for JwtBearerAuth<C> {
    /// Decodes `token` into claims of type `C` and runs the extraction
    /// callback on them.
    ///
    /// Returns `None` when decoding fails for any reason, or when the
    /// callback itself returns `None`.
    fn resolve(&self, token: &str) -> Option<String> {
        match decode::<C>(token, &self.key, &self.validation) {
            Ok(decoded) => (self.extract)(&decoded.claims),
            // Every decode error is deliberately collapsed into "not authenticated".
            Err(_) => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
    use serde::{Deserialize, Serialize};

    /// Shared HMAC secret used by most tests.
    const SECRET: &[u8] = b"test-secret";
    /// `exp` value far enough in the future that the token never expires
    /// during a test run.
    const FAR_FUTURE_EXP: usize = 10_000_000_000;

    /// Claims payload used by the tests.
    #[derive(Serialize, Deserialize)]
    struct Claims {
        sub: String,
        vid: Option<String>,
        exp: usize,
    }

    /// Convenience constructor for [`Claims`].
    fn claims(sub: &str, vid: Option<&str>, exp: usize) -> Claims {
        Claims {
            sub: sub.into(),
            vid: vid.map(Into::into),
            exp,
        }
    }

    /// Signs `claims` with HS256 using `secret`.
    fn sign(secret: &[u8], claims: &Claims) -> String {
        let header = Header::new(Algorithm::HS256);
        let key = EncodingKey::from_secret(secret);
        encode(&header, claims, &key).unwrap()
    }

    /// Builds an HS256 authenticator that prefers the `vid` claim and falls
    /// back to `sub` when `vid` is absent.
    fn auth(secret: &[u8]) -> JwtBearerAuth<Claims> {
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = true;
        validation.required_spec_claims.clear();
        let key = DecodingKey::from_secret(secret);
        JwtBearerAuth::new(key, validation, |c: &Claims| {
            Some(c.vid.clone().unwrap_or_else(|| c.sub.clone()))
        })
    }

    #[test]
    fn valid_jwt_resolves_to_vid_via_claims_closure() {
        let payload = claims("user-123", Some("did:web:alice.example"), FAR_FUTURE_EXP);
        let token = sign(SECRET, &payload);

        let resolved = auth(SECRET).resolve(&token);

        assert_eq!(resolved, Some("did:web:alice.example".to_string()));
    }

    #[test]
    fn jwt_without_vid_claim_falls_back_to_subject() {
        let payload = claims("did:web:fallback.example", None, FAR_FUTURE_EXP);
        let token = sign(SECRET, &payload);

        let resolved = auth(SECRET).resolve(&token);

        assert_eq!(resolved, Some("did:web:fallback.example".to_string()));
    }

    #[test]
    fn wrong_signature_returns_none() {
        let payload = claims("x", None, FAR_FUTURE_EXP);
        let token = sign(b"correct-secret", &payload);

        let verifier = auth(b"wrong-secret");

        assert!(verifier.resolve(&token).is_none());
    }

    #[test]
    fn expired_jwt_returns_none() {
        // `exp: 1` is one second after the Unix epoch, i.e. long expired.
        let payload = claims("x", None, 1);
        let token = sign(SECRET, &payload);

        let verifier = auth(SECRET);

        assert!(verifier.resolve(&token).is_none());
    }

    #[test]
    fn malformed_jwt_returns_none() {
        let verifier = auth(SECRET);

        assert!(verifier.resolve("not-a-jwt").is_none());
    }
}
}