use anyhow::{Context, Result};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct AccessClaims {
pub sub: String,
pub room: String,
pub session_id: String,
pub exp: usize,
pub iss: Option<String>,
pub aud: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct RefreshClaims {
pub sub: String,
pub exp: usize,
}
/// Validation and signing options for `JwtAuth`.
///
/// The `Default` value has zero leeway and no issuer or audience.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct JwtAuthOptions {
    /// Clock-skew tolerance in seconds, copied into jsonwebtoken's `Validation::leeway`.
    pub leeway: u64,
    /// Issuer embedded in access tokens and checked when verifying them.
    /// Refresh tokens do not use it.
    pub issuer: Option<String>,
    /// Audience embedded in access tokens and checked when verifying them.
    /// Refresh tokens do not use it.
    pub audience: Option<String>,
}
/// Signs and verifies access and refresh tokens with a shared HMAC secret.
///
/// No `Debug` derive here: keep it that way (or write a redacting impl) so the secret
/// cannot end up in logs via `{:?}`.
#[derive(Clone)]
pub struct JwtAuth {
    /// Shared secret used as the HMAC key for both signing and verification.
    secret: String,
    /// Leeway, issuer and audience settings.
    options: JwtAuthOptions,
}
impl JwtAuth {
pub fn new(secret: &str) -> Self {
Self {
secret: secret.into(),
options: JwtAuthOptions::default(),
}
}
pub fn with_options(secret: &str, options: JwtAuthOptions) -> Self {
Self {
secret: secret.into(),
options,
}
}
pub fn sign_access(
&self,
user_id: String,
room_id: String,
session_id: String,
ttl_secs: usize,
) -> Result<String> {
let now = chrono::Utc::now().timestamp();
let exp = now.saturating_add(ttl_secs as i64) as usize;
let claims = AccessClaims {
sub: user_id,
room: room_id,
session_id,
exp,
iss: self.options.issuer.clone(),
aud: self.options.audience.clone(),
};
encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(self.secret.as_ref()),
)
.context("Failed to encode access token.")
}
pub fn sign_refresh(&self, user_id: String, ttl_secs: usize) -> Result<String> {
let exp = chrono::Utc::now().timestamp() as usize + ttl_secs;
let claims = RefreshClaims { sub: user_id, exp };
encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(self.secret.as_ref()),
)
.context("Failed to encode refresh token.")
}
pub fn verify_access(&self, token: &str) -> Result<AccessClaims> {
let mut validation = Validation::default();
validation.leeway = self.options.leeway;
if let Some(ref iss) = self.options.issuer {
validation.set_issuer(&[iss]);
}
if let Some(ref aud) = self.options.audience {
validation.set_audience(&[aud]);
}
let data = decode::<AccessClaims>(
token,
&DecodingKey::from_secret(self.secret.as_ref()),
&validation,
)
.context("Failed to decode access token")?;
Ok(data.claims)
}
pub fn verify_refresh(&self, token: &str) -> Result<RefreshClaims> {
let mut validation = Validation::default();
validation.leeway = self.options.leeway;
let data = decode::<RefreshClaims>(
token,
&DecodingKey::from_secret(self.secret.as_ref()),
&validation,
)
.context("Failed to decode refresh token")?;
Ok(data.claims)
}
pub fn refresh_access(
&self,
refresh_token: &str,
room_id: String,
session_id: String,
access_ttl: usize,
) -> Result<String> {
let claims = self.verify_refresh(refresh_token)?;
self.sign_access(claims.sub, room_id, session_id, access_ttl)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const SECRET: &str = "test-secret";

    fn auth() -> JwtAuth {
        JwtAuth::new(SECRET)
    }

    /// Current Unix time in seconds.
    fn now() -> usize {
        chrono::Utc::now().timestamp() as usize
    }

    /// Signs arbitrary claims with the test secret, bypassing `JwtAuth`. This lets tests
    /// mint tokens that are already expired instead of sleeping past a short TTL.
    fn sign_raw<T: Serialize>(claims: &T) -> String {
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(SECRET.as_bytes()),
        )
        .unwrap()
    }

    #[test]
    fn access_token_roundtrip() {
        let auth = auth();
        let before = now();
        let token = auth
            .sign_access("user1".into(), "roomA".into(), "session1".into(), 60)
            .unwrap();
        let after = now();
        let claims = auth.verify_access(&token).unwrap();
        assert_eq!(claims.sub, "user1");
        assert_eq!(claims.room, "roomA");
        assert_eq!(claims.session_id, "session1");
        assert!((before + 60..=after + 60).contains(&claims.exp));
    }

    #[test]
    fn refresh_token_roundtrip() {
        let auth = auth();
        let before = now();
        let token = auth.sign_refresh("user2".into(), 60).unwrap();
        let after = now();
        let claims = auth.verify_refresh(&token).unwrap();
        assert_eq!(claims.sub, "user2");
        assert!((before + 60..=after + 60).contains(&claims.exp));
    }

    #[test]
    fn refresh_access_flow() {
        let auth = auth();
        let refresh = auth.sign_refresh("user3".into(), 60).unwrap();
        let new_access = auth
            .refresh_access(&refresh, "roomB".into(), "session3".into(), 60)
            .unwrap();
        let claims = auth.verify_access(&new_access).unwrap();
        assert_eq!(claims.sub, "user3");
        assert_eq!(claims.room, "roomB");
        assert_eq!(claims.session_id, "session3");
    }

    #[test]
    fn expired_access_token_fails() {
        let token = sign_raw(&AccessClaims {
            sub: "user4".into(),
            room: "roomC".into(),
            session_id: "session4".into(),
            exp: now() - 10,
            iss: None,
            aud: None,
        });
        assert!(auth().verify_access(&token).is_err());
    }

    #[test]
    fn expired_refresh_token_fails() {
        let token = sign_raw(&RefreshClaims {
            sub: "user5".into(),
            exp: now() - 10,
        });
        assert!(auth().verify_refresh(&token).is_err());
    }
}