use crate::error::{Result, TidewayError};
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};
use serde::{Deserialize, Serialize};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Settings for a [`JwtIssuer`]: signing key, registered-claim values, and token lifetimes.
///
/// Create one with [`JwtIssuerConfig::with_secret`] (HS256) or
/// [`JwtIssuerConfig::with_rsa_private_key`] (RS256), then adjust it with the builder methods.
#[derive(Clone)]
pub struct JwtIssuerConfig {
    /// Signing key material; private, set only by the constructors.
    secret: SecretKey,
    /// Signing algorithm, chosen by the constructor to match `secret`.
    algorithm: Algorithm,
    /// Value written into the `iss` claim of every token.
    pub issuer: String,
    /// Value written into the `aud` claim; the claim is omitted when `None`.
    pub audience: Option<String>,
    /// Optional `kid` placed in the JWT header.
    pub key_id: Option<String>,
    /// Lifetime of access tokens.
    pub access_token_ttl: Duration,
    /// Lifetime of refresh tokens issued without "remember me".
    pub refresh_token_ttl: Duration,
    /// Lifetime of refresh tokens issued with "remember me".
    pub remember_me_ttl: Duration,
}
/// Raw key material from which the signing key is built.
#[derive(Clone)]
enum SecretKey {
    /// PEM-encoded RSA private key (paired with RS256 by the constructor).
    Rsa { private_pem: Vec<u8> },
    /// Bytes of a shared HMAC secret (paired with HS256 by the constructor).
    Symmetric(Vec<u8>),
}
impl JwtIssuerConfig {
pub fn with_secret(secret: impl Into<String>, issuer: impl Into<String>) -> Self {
Self {
secret: SecretKey::Symmetric(secret.into().into_bytes()),
issuer: issuer.into(),
audience: None,
key_id: None,
access_token_ttl: Duration::from_secs(15 * 60), refresh_token_ttl: Duration::from_secs(7 * 24 * 60 * 60), remember_me_ttl: Duration::from_secs(30 * 24 * 60 * 60), algorithm: Algorithm::HS256,
}
}
pub fn with_rsa_private_key(
private_pem: impl Into<Vec<u8>>,
issuer: impl Into<String>,
) -> Self {
Self {
secret: SecretKey::Rsa {
private_pem: private_pem.into(),
},
issuer: issuer.into(),
audience: None,
key_id: None,
access_token_ttl: Duration::from_secs(15 * 60),
refresh_token_ttl: Duration::from_secs(7 * 24 * 60 * 60),
remember_me_ttl: Duration::from_secs(30 * 24 * 60 * 60),
algorithm: Algorithm::RS256,
}
}
pub fn audience(mut self, aud: impl Into<String>) -> Self {
self.audience = Some(aud.into());
self
}
pub fn key_id(mut self, kid: impl Into<String>) -> Self {
self.key_id = Some(kid.into());
self
}
pub fn access_token_ttl(mut self, ttl: Duration) -> Self {
self.access_token_ttl = ttl;
self
}
pub fn refresh_token_ttl(mut self, ttl: Duration) -> Self {
self.refresh_token_ttl = ttl;
self
}
pub fn remember_me_ttl(mut self, ttl: Duration) -> Self {
self.remember_me_ttl = ttl;
self
}
}
/// Registered JWT claims shared by access and refresh tokens.
///
/// Flattened into [`AccessTokenClaims`] and [`RefreshTokenClaims`], so these fields appear at
/// the top level of the token payload. Timestamps produced by `JwtIssuer` are whole seconds
/// since the Unix epoch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StandardClaims {
/// Subject: the user ID the token was issued for.
pub sub: String,
/// Issuer name copied from the issuer configuration.
pub iss: String,
/// Audience; left out of the serialized payload when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub aud: Option<String>,
/// Expiry time.
pub exp: u64,
/// Issued-at time.
pub iat: u64,
/// Not-before time; left out of the serialized payload when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub nbf: Option<u64>,
/// Unique token identifier.
pub jti: String,
}
/// Distinguishes access tokens from refresh tokens in the `token_type` claim.
///
/// Serialized in snake_case, i.e. as `"access"` or `"refresh"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
/// Short-lived token presented on API requests.
Access,
/// Token exchanged for new tokens.
Refresh,
}
/// Payload of an access token.
///
/// `T` holds optional application-specific claims. Both `standard` and `custom` are flattened,
/// so every field ends up at the top level of the JSON payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessTokenClaims<T = ()>
where
T: Serialize,
{
/// Registered claims (`sub`, `iss`, `exp`, ...).
#[serde(flatten)]
pub standard: StandardClaims,
/// Always [`TokenType::Access`] for tokens issued by `JwtIssuer`.
pub token_type: TokenType,
/// Optional email claim; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub email: Option<String>,
/// Optional display-name claim; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Optional custom claims merged into the top level of the payload; omitted when `None`.
#[serde(flatten, skip_serializing_if = "Option::is_none")]
pub custom: Option<T>,
}
/// Payload of a refresh token.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefreshTokenClaims {
/// Registered claims, flattened to the top level of the payload.
#[serde(flatten)]
pub standard: StandardClaims,
/// Always [`TokenType::Refresh`] for tokens issued by `JwtIssuer`.
pub token_type: TokenType,
/// Random identifier created at issue time and carried unchanged through every rotation.
pub family: String,
/// Rotation counter: 0 when issued, incremented by one on each rotation.
pub generation: u32,
}
/// Access/refresh token pair returned from `JwtIssuer::issue`, serializable as a token response.
#[derive(Debug, Clone, Serialize)]
pub struct TokenPair {
/// Signed access token.
pub access_token: String,
/// Signed refresh token.
pub refresh_token: String,
/// Access-token lifetime in seconds.
pub expires_in: u64,
/// Token scheme; `issue` sets this to `"Bearer"`.
pub token_type: &'static str,
/// Refresh-token family ID; kept server-side only (never serialized).
#[serde(skip)]
pub family: String,
}
/// Identity and optional profile data to embed in an access token.
///
/// `T` carries optional application-specific claims; it defaults to `()` (no custom claims).
pub struct TokenSubject<'a, T: Serialize = ()> {
    /// Becomes the `sub` claim.
    pub user_id: &'a str,
    /// Becomes the optional `email` claim.
    pub email: Option<&'a str>,
    /// Becomes the optional `name` claim.
    pub name: Option<&'a str>,
    /// Extra claims merged into the access-token payload.
    pub custom: Option<T>,
}
impl<'a> TokenSubject<'a, ()> {
pub fn new(user_id: &'a str) -> Self {
Self {
user_id,
email: None,
name: None,
custom: None,
}
}
}
impl<'a, T: Serialize> TokenSubject<'a, T> {
    /// Attaches an `email` claim, replacing any previous one.
    pub fn with_email(self, email: &'a str) -> Self {
        Self {
            email: Some(email),
            ..self
        }
    }

    /// Attaches a `name` claim, replacing any previous one.
    pub fn with_name(self, name: &'a str) -> Self {
        Self {
            name: Some(name),
            ..self
        }
    }

    /// Replaces the custom claims with `custom`, changing the custom-claims type to `U`.
    pub fn with_custom<U: Serialize>(self, custom: U) -> TokenSubject<'a, U> {
        let Self {
            user_id,
            email,
            name,
            ..
        } = self;
        TokenSubject {
            user_id,
            email,
            name,
            custom: Some(custom),
        }
    }
}
/// Signs access and refresh tokens as described by a [`JwtIssuerConfig`].
///
/// The signing key is derived once in `JwtIssuer::new` and reused for every token.
#[derive(Clone)]
pub struct JwtIssuer {
    /// Key prepared from the configuration's secret at construction time.
    encoding_key: EncodingKey,
    /// Configuration the issuer was built from.
    config: JwtIssuerConfig,
}
impl JwtIssuer {
pub fn new(config: JwtIssuerConfig) -> Result<Self> {
let encoding_key = match &config.secret {
SecretKey::Symmetric(secret) => EncodingKey::from_secret(secret),
SecretKey::Rsa { private_pem } => EncodingKey::from_rsa_pem(private_pem)
.map_err(|e| TidewayError::Internal(format!("Invalid RSA private key: {}", e)))?,
};
Ok(Self {
config,
encoding_key,
})
}
fn build_header(&self) -> Header {
let mut header = Header::new(self.config.algorithm);
if let Some(ref kid) = self.config.key_id {
header.kid = Some(kid.clone());
}
header
}
pub fn issue<T: Serialize>(
&self,
subject: TokenSubject<'_, T>,
remember_me: bool,
) -> Result<TokenPair> {
let now = current_timestamp();
let jti_access = generate_jti();
let jti_refresh = generate_jti();
let family = generate_token_family();
let access_claims = AccessTokenClaims {
standard: StandardClaims {
sub: subject.user_id.to_string(),
iss: self.config.issuer.clone(),
aud: self.config.audience.clone(),
exp: now + self.config.access_token_ttl.as_secs(),
iat: now,
nbf: Some(now),
jti: jti_access,
},
token_type: TokenType::Access,
email: subject.email.map(String::from),
name: subject.name.map(String::from),
custom: subject.custom,
};
let refresh_ttl = if remember_me {
self.config.remember_me_ttl
} else {
self.config.refresh_token_ttl
};
let refresh_claims = RefreshTokenClaims {
standard: StandardClaims {
sub: subject.user_id.to_string(),
iss: self.config.issuer.clone(),
aud: self.config.audience.clone(),
exp: now + refresh_ttl.as_secs(),
iat: now,
nbf: Some(now),
jti: jti_refresh,
},
token_type: TokenType::Refresh,
family: family.clone(),
generation: 0,
};
let header = self.build_header();
let access_token = encode(&header, &access_claims, &self.encoding_key)
.map_err(|e| TidewayError::Internal(format!("Failed to encode access token: {}", e)))?;
let refresh_token = encode(&header, &refresh_claims, &self.encoding_key).map_err(|e| {
TidewayError::Internal(format!("Failed to encode refresh token: {}", e))
})?;
Ok(TokenPair {
access_token,
refresh_token,
expires_in: self.config.access_token_ttl.as_secs(),
token_type: "Bearer",
family,
})
}
pub fn issue_access_token<T: Serialize>(
&self,
subject: TokenSubject<'_, T>,
) -> Result<(String, u64)> {
let now = current_timestamp();
let jti = generate_jti();
let claims = AccessTokenClaims {
standard: StandardClaims {
sub: subject.user_id.to_string(),
iss: self.config.issuer.clone(),
aud: self.config.audience.clone(),
exp: now + self.config.access_token_ttl.as_secs(),
iat: now,
nbf: Some(now),
jti,
},
token_type: TokenType::Access,
email: subject.email.map(String::from),
name: subject.name.map(String::from),
custom: subject.custom,
};
let header = self.build_header();
let token = encode(&header, &claims, &self.encoding_key)
.map_err(|e| TidewayError::Internal(format!("Failed to encode access token: {}", e)))?;
Ok((token, self.config.access_token_ttl.as_secs()))
}
pub fn rotate_refresh_token(&self, old_claims: &RefreshTokenClaims) -> Result<String> {
let now = current_timestamp();
let jti = generate_jti();
let remaining = old_claims.standard.exp.saturating_sub(now);
let new_exp = now + remaining;
let claims = RefreshTokenClaims {
standard: StandardClaims {
sub: old_claims.standard.sub.clone(),
iss: self.config.issuer.clone(),
aud: self.config.audience.clone(),
exp: new_exp,
iat: now,
nbf: Some(now),
jti,
},
token_type: TokenType::Refresh,
family: old_claims.family.clone(),
generation: old_claims.generation + 1,
};
let header = self.build_header();
encode(&header, &claims, &self.encoding_key)
.map_err(|e| TidewayError::Internal(format!("Failed to encode refresh token: {}", e)))
}
pub fn issuer(&self) -> &str {
&self.config.issuer
}
pub fn audience(&self) -> Option<&str> {
self.config.audience.as_deref()
}
pub fn algorithm(&self) -> Algorithm {
self.config.algorithm
}
pub fn key_id(&self) -> Option<&str> {
self.config.key_id.as_deref()
}
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this returns `0` instead of failing.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Returns `N` bytes from `rand::rngs::OsRng`, encoded as unpadded URL-safe base64.
///
/// Shared by the token-ID generators so they differ only in entropy size.
fn random_url_safe_id<const N: usize>() -> String {
    use rand::RngCore;
    let mut bytes = [0u8; N];
    rand::rngs::OsRng.fill_bytes(&mut bytes);
    URL_SAFE_NO_PAD.encode(bytes)
}

/// Generates a `jti` value from 16 random bytes (128 bits).
fn generate_jti() -> String {
    random_url_safe_id::<16>()
}

/// Generates a refresh-token family ID from 12 random bytes (96 bits).
fn generate_token_family() -> String {
    random_url_safe_id::<12>()
}
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{DecodingKey, Validation, decode, decode_header};
    use serde::de::DeserializeOwned;

    /// HMAC secret shared by the test issuer and every decoder below.
    const SECRET: &str = "test-secret-key-32-bytes-long!!";
    /// Issuer name stamped into `iss` by the test issuer.
    const ISSUER: &str = "test-app";

    fn test_issuer() -> JwtIssuer {
        JwtIssuer::new(JwtIssuerConfig::with_secret(SECRET, ISSUER)).unwrap()
    }

    /// HS256 validation that also checks the `iss` claim.
    fn issuer_validation() -> Validation {
        let mut validation = Validation::new(Algorithm::HS256);
        validation.set_issuer(&[ISSUER]);
        validation
    }

    /// Verifies `token` with the test secret and returns its claims; panics on failure.
    fn decode_claims<C: DeserializeOwned>(token: &str, validation: &Validation) -> C {
        decode::<C>(token, &DecodingKey::from_secret(SECRET.as_bytes()), validation)
            .unwrap()
            .claims
    }

    #[test]
    fn test_issue_token_pair() {
        let issuer = test_issuer();
        let subject = TokenSubject::new("user-123")
            .with_email("test@example.com")
            .with_name("Test User");
        let pair = issuer.issue(subject, false).unwrap();
        assert!(!pair.access_token.is_empty());
        assert!(!pair.refresh_token.is_empty());
        assert_eq!(pair.token_type, "Bearer");
        assert!(pair.expires_in > 0);
    }

    #[test]
    fn test_access_token_claims() {
        let issuer = test_issuer();
        let subject = TokenSubject::new("user-123").with_email("test@example.com");
        let pair = issuer.issue(subject, false).unwrap();
        let mut validation = issuer_validation();
        validation.set_required_spec_claims(&["exp", "iat", "sub"]);
        let claims: AccessTokenClaims = decode_claims(&pair.access_token, &validation);
        assert_eq!(claims.standard.sub, "user-123");
        assert_eq!(claims.email, Some("test@example.com".to_string()));
        assert_eq!(claims.token_type, TokenType::Access);
    }

    #[test]
    fn test_issue_access_token_only() {
        let issuer = test_issuer();
        let subject = TokenSubject::new("user-123").with_name("Test User");
        let (token, expires_in) = issuer.issue_access_token(subject).unwrap();
        assert_eq!(expires_in, 15 * 60);
        let claims: AccessTokenClaims = decode_claims(&token, &issuer_validation());
        assert_eq!(claims.standard.sub, "user-123");
        assert_eq!(claims.name.as_deref(), Some("Test User"));
        assert_eq!(claims.token_type, TokenType::Access);
    }

    #[test]
    fn test_refresh_token_rotation() {
        let issuer = test_issuer();
        let pair = issuer.issue(TokenSubject::new("user-123"), false).unwrap();
        let validation = issuer_validation();
        let original: RefreshTokenClaims = decode_claims(&pair.refresh_token, &validation);
        assert_eq!(original.generation, 0);

        let rotated = issuer.rotate_refresh_token(&original).unwrap();
        let rotated: RefreshTokenClaims = decode_claims(&rotated, &validation);
        assert_eq!(rotated.generation, 1);
        assert_eq!(rotated.family, original.family);
        assert_eq!(rotated.standard.sub, "user-123");
        // Rotation must not extend the session's absolute expiry.
        assert_eq!(rotated.standard.exp, original.standard.exp);
        assert_ne!(rotated.standard.jti, original.standard.jti);
    }

    #[test]
    fn test_custom_claims() {
        #[derive(Serialize, Deserialize)]
        struct CustomClaims {
            org_id: String,
            role: String,
        }
        #[derive(Deserialize)]
        struct FullClaims {
            sub: String,
            org_id: String,
            role: String,
        }

        let issuer = test_issuer();
        let subject = TokenSubject::new("user-123").with_custom(CustomClaims {
            org_id: "org-456".to_string(),
            role: "admin".to_string(),
        });
        let pair = issuer.issue(subject, false).unwrap();
        let claims: FullClaims = decode_claims(&pair.access_token, &issuer_validation());
        assert_eq!(claims.sub, "user-123");
        assert_eq!(claims.org_id, "org-456");
        assert_eq!(claims.role, "admin");
    }

    #[test]
    fn test_remember_me_extends_refresh() {
        let issuer = test_issuer();
        let normal = issuer.issue(TokenSubject::new("user-123"), false).unwrap();
        let remembered = issuer.issue(TokenSubject::new("user-123"), true).unwrap();
        let validation = issuer_validation();
        let normal: RefreshTokenClaims = decode_claims(&normal.refresh_token, &validation);
        let remembered: RefreshTokenClaims =
            decode_claims(&remembered.refresh_token, &validation);
        assert!(remembered.standard.exp > normal.standard.exp);
    }

    #[test]
    fn test_key_id_in_header() {
        let config = JwtIssuerConfig::with_secret(SECRET, ISSUER).key_id("key-1");
        let issuer = JwtIssuer::new(config).unwrap();
        assert_eq!(issuer.key_id(), Some("key-1"));
        let pair = issuer.issue(TokenSubject::new("user-123"), false).unwrap();
        let header = decode_header(&pair.access_token).unwrap();
        assert_eq!(header.kid.as_deref(), Some("key-1"));
    }

    #[test]
    fn test_invalid_rsa_key_rejected() {
        let config = JwtIssuerConfig::with_rsa_private_key(b"not a pem".to_vec(), ISSUER);
        assert!(JwtIssuer::new(config).is_err());
    }
}