use crate::model::Claims;
use jsonwebtoken::{encode, decode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation};
use serde_json::Value;
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH, Duration};
use crate::errors::JwtServiceError;
/// Module-local result alias whose error type is always `JwtServiceError`.
type Result<T> = std::result::Result<T, JwtServiceError>;
/// Signing algorithm together with the raw PEM key material it needs.
///
/// Each variant carries two separate key pairs: one for access tokens and one
/// for refresh tokens. `JwtKeys::from_algorithm` parses the `RS256` keys as
/// RSA PEM and the `ES256` keys as EC PEM.
pub enum JwtAlgorithm {
    /// Keys for RS256 (RSA) signing.
    RS256 {
        /// PEM private key used to sign access tokens.
        access_private: Vec<u8>,
        /// PEM public key used to verify access tokens.
        access_public: Vec<u8>,
        /// PEM private key used to sign refresh tokens.
        refresh_private: Vec<u8>,
        /// PEM public key used to verify refresh tokens.
        refresh_public: Vec<u8>,
    },
    /// Keys for ES256 (elliptic-curve) signing.
    ES256 {
        /// PEM private key used to sign access tokens.
        access_private: Vec<u8>,
        /// PEM public key used to verify access tokens.
        access_public: Vec<u8>,
        /// PEM private key used to sign refresh tokens.
        refresh_private: Vec<u8>,
        /// PEM public key used to verify refresh tokens.
        refresh_public: Vec<u8>,
    },
}
/// Algorithm-specific backend that `JwtKeys` delegates signing and
/// verification to.
///
/// The `Send + Sync` supertraits make a `Box<dyn JwtKeyPair>` (and therefore a
/// `JwtKeys`) safe to move to and share between threads.
pub trait JwtKeyPair: Send + Sync {
    /// Produces a signed token for subject `sub`.
    ///
    /// In this module's implementations, the token expires `expires_in`
    /// seconds from now and carries `kid` in its header. `extra` claims are
    /// optional. `is_access` chooses the access signing key (`true`) or the
    /// refresh signing key (`false`).
    fn generate_token(
        &self,
        kid: &str,
        sub: &str,
        expires_in: usize,
        extra: Option<HashMap<String, Value>>,
        is_access: bool,
    ) -> Result<String>;

    /// Verifies `token` and returns its decoded claims.
    ///
    /// In this module's implementations, `token_type` must be `"access"` or
    /// `"refresh"` and selects the verification key. `audiences`, when
    /// present, is installed as the set of accepted audiences.
    fn decode_token(
        &self,
        token: &str,
        token_type: &str,
        audiences: Option<Vec<String>>,
    ) -> Result<TokenData<Claims>>;
}
/// Public entry point for issuing and verifying access/refresh JWTs.
///
/// Built by `JwtKeys::from_algorithm`, which selects the RS256 or ES256
/// backend from the supplied `JwtAlgorithm`.
pub struct JwtKeys {
    /// Algorithm-specific implementation that all operations delegate to.
    backend: Box<dyn JwtKeyPair>,
}
impl JwtKeys {
pub fn from_algorithm(algo: JwtAlgorithm) -> Result<Self> {
match algo {
JwtAlgorithm::RS256 {
access_private,
access_public,
refresh_private,
refresh_public,
} => Ok(Self {
backend: Box::new(Rs256KeyPair {
access_enc: EncodingKey::from_rsa_pem(&access_private)?,
access_dec: DecodingKey::from_rsa_pem(&access_public)?,
refresh_enc: EncodingKey::from_rsa_pem(&refresh_private)?,
refresh_dec: DecodingKey::from_rsa_pem(&refresh_public)?,
}),
}),
JwtAlgorithm::ES256 {
access_private,
access_public,
refresh_private,
refresh_public,
} => Ok(Self {
backend: Box::new(Es256KeyPair {
access_enc: EncodingKey::from_ec_pem(&access_private)?,
access_dec: DecodingKey::from_ec_pem(&access_public)?,
refresh_enc: EncodingKey::from_ec_pem(&refresh_private)?,
refresh_dec: DecodingKey::from_ec_pem(&refresh_public)?,
}),
}),
}
}
pub fn generate_access_token(&self, kid: &str, user_id: &str, expires_in: usize, extra: Option<HashMap<String, Value>>) -> Result<String> {
self.backend.generate_token(kid, user_id, expires_in, extra, true)
}
pub fn generate_refresh_token(&self, kid: &str, user_id: &str, expires_in: usize, extra: Option<HashMap<String, Value>>) -> Result<String> {
self.backend.generate_token(kid, user_id, expires_in, extra, false)
}
pub fn decode_token(&self, token: &str, token_type: &str, audiences: Option<Vec<String>>) -> Result<TokenData<Claims>> {
self.backend.decode_token(token, token_type, audiences)
}
}
/// `JwtKeyPair` backend that signs and verifies tokens with RS256.
///
/// Holds separate parsed key pairs for access tokens and refresh tokens.
pub struct Rs256KeyPair {
    /// Signs access tokens.
    access_enc: EncodingKey,
    /// Verifies access tokens.
    access_dec: DecodingKey,
    /// Signs refresh tokens.
    refresh_enc: EncodingKey,
    /// Verifies refresh tokens.
    refresh_dec: DecodingKey,
}
impl JwtKeyPair for Rs256KeyPair {
    /// Signs a new RS256 token for `sub`.
    ///
    /// The token's `exp` is set to `expires_in` seconds after the current
    /// Unix time. `kid` goes into the header, and `extra` (empty when `None`)
    /// becomes the claims' `extra` map. `is_access` picks the access signing
    /// key when `true`, otherwise the refresh signing key.
    fn generate_token(
        &self,
        kid: &str,
        sub: &str,
        expires_in: usize,
        extra: Option<HashMap<String, Value>>,
        is_access: bool,
    ) -> Result<String> {
        let expires_at = current_timestamp() + expires_in;
        let payload = Claims {
            sub: String::from(sub),
            exp: expires_at,
            extra: extra.unwrap_or_default(),
        };

        let mut jwt_header = Header::new(Algorithm::RS256);
        jwt_header.kid = Some(kid.to_owned());

        let signing_key = match is_access {
            true => &self.access_enc,
            false => &self.refresh_enc,
        };
        Ok(encode(&jwt_header, &payload, signing_key)?)
    }

    /// Verifies an RS256 `token` and returns its decoded claims.
    ///
    /// `token_type` must be `"access"` or `"refresh"`; it selects the public
    /// key used to check the signature. When `audiences` is `Some`, those
    /// values become the accepted audiences on the `Validation`.
    ///
    /// # Errors
    /// - `JwtServiceError::InvalidToken` for an unknown `token_type`, or when
    ///   the system clock reads earlier than the Unix epoch.
    /// - The `JwtServiceError` converted from any `decode` failure.
    /// - `JwtServiceError::TokenExpired` when the claims' `exp` is earlier
    ///   than the current Unix time.
    fn decode_token(
        &self,
        token: &str,
        token_type: &str,
        audiences: Option<Vec<String>>,
    ) -> Result<TokenData<Claims>> {
        let verifying_key: &DecodingKey = match token_type {
            "access" => &self.access_dec,
            "refresh" => &self.refresh_dec,
            _ => return Err(JwtServiceError::InvalidToken),
        };

        let mut rules = Validation::new(Algorithm::RS256);
        if let Some(accepted) = audiences {
            let accepted_refs: Vec<&str> = accepted.iter().map(|aud| aud.as_str()).collect();
            rules.set_audience(&accepted_refs);
        }

        let token_data = decode::<Claims>(token, verifying_key, &rules)?;

        // NOTE(review): `Validation` may already enforce `exp` itself (with the
        // library's default leeway). This explicit check rejects any token
        // whose `exp` is in the past, with no leeway. Confirm that is intended.
        let now_secs = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_secs() as usize,
            Err(_) => return Err(JwtServiceError::InvalidToken),
        };
        if now_secs > token_data.claims.exp {
            return Err(JwtServiceError::TokenExpired);
        }
        Ok(token_data)
    }
}
/// `JwtKeyPair` backend that signs and verifies tokens with ES256.
///
/// Holds separate parsed key pairs for access tokens and refresh tokens.
pub struct Es256KeyPair {
    /// Signs access tokens.
    access_enc: EncodingKey,
    /// Verifies access tokens.
    access_dec: DecodingKey,
    /// Signs refresh tokens.
    refresh_enc: EncodingKey,
    /// Verifies refresh tokens.
    refresh_dec: DecodingKey,
}
impl JwtKeyPair for Es256KeyPair {
fn generate_token(&self, kid: &str, sub: &str, expires_in: usize, extra: Option<HashMap<String, Value>>, is_access: bool) -> Result<String> {
let exp = current_timestamp() + expires_in;
let claims = Claims {
sub: sub.to_string(),
exp,
extra: extra.unwrap_or_default(),
};
let mut header = Header::new(Algorithm::ES256);
header.kid = Some(kid.to_string());
let enc = if is_access { &self.access_enc } else { &self.refresh_enc };
encode(&header, &claims, enc).map_err(JwtServiceError::from)
}
fn decode_token(&self, token: &str, token_type: &str, audiences: Option<Vec<String>>) -> Result<TokenData<Claims>> {
let is_access = match token_type {
"access" => true,
"refresh" => false,
_ => return Err(JwtServiceError::InvalidToken),
};
let dec_key = if is_access { &self.access_dec } else { &self.refresh_dec };
let mut validation = Validation::new(Algorithm::ES256);
if let Some(auds) = audiences {
let aud_refs: Vec<&str> = auds.iter().map(String::as_str).collect();
validation.set_audience(&aud_refs);
}
let decoded = decode::<Claims>(token, dec_key, &validation)
.map_err(JwtServiceError::from)?;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|_| JwtServiceError::InvalidToken)?
.as_secs() as usize;
if decoded.claims.exp < now {
return Err(JwtServiceError::TokenExpired);
}
Ok(decoded)
}
}
/// Returns the current Unix time in whole seconds.
///
/// If the system clock reports a time before the Unix epoch, this returns `0`
/// instead of failing.
pub fn current_timestamp() -> usize {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::ZERO,
    };
    since_epoch.as_secs() as usize
}
/// Abbreviates a token for logging.
///
/// Only its first and last 32 characters are kept, joined by `...`. Tokens
/// of at most 64 characters are returned unchanged.
///
/// Lengths and cut points are measured in `char`s rather than bytes, so the
/// slices always fall on UTF-8 character boundaries. Non-ASCII input can
/// therefore never cause a slicing panic. For ASCII input, characters and
/// bytes coincide and the output matches a plain byte-based cut.
pub fn shorten_token(token: &str) -> String {
    const LEN: usize = 32;
    let char_count = token.chars().count();
    if char_count <= LEN * 2 {
        return token.to_string();
    }
    // Byte offset at which the `n`-th char starts. Here `n < char_count`, so
    // `nth` always finds it; the `token.len()` fallback is never reached.
    let byte_offset = |n: usize| {
        token
            .char_indices()
            .nth(n)
            .map_or(token.len(), |(idx, _)| idx)
    };
    let head_end = byte_offset(LEN);
    let tail_start = byte_offset(char_count - LEN);
    format!("{}...{}", &token[..head_end], &token[tail_start..])
}