use crate::model::Claims;
use jsonwebtoken::{encode, decode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, errors::{Error, ErrorKind, Result}};
use serde_json::Value;
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH, Duration};
/// Signing scheme selector together with the raw key material it needs.
///
/// Consumed by `JwtKeys::from_algorithm`, which parses the bytes into
/// `jsonwebtoken` encoding/decoding keys.
pub enum JwtAlgorithm {
    /// RSA + SHA-256 with separate key slots for access and refresh tokens.
    ///
    /// Every field is parsed with `from_rsa_pem`, so each must contain
    /// PEM-encoded RSA key bytes.
    RS256 {
        /// Private key used to sign access tokens.
        access_private: Vec<u8>,
        /// Public key used to verify access tokens.
        access_public: Vec<u8>,
        /// Private key used to sign refresh tokens.
        refresh_private: Vec<u8>,
        /// Public key used to verify refresh tokens.
        refresh_public: Vec<u8>,
    },
}
/// Backend that can both issue and verify access/refresh JWTs.
///
/// Implementors must be `Send + Sync`.
pub trait JwtKeyPair: Send + Sync {
    /// Issues a signed token whose subject is `sub` and whose header `kid` is `kid`.
    ///
    /// `expires_in` is added to the current time to form `exp` (seconds, as
    /// used by the RS256 backend in this file). `extra` supplies additional
    /// claims. `is_access` chooses the access-token key (`true`) or the
    /// refresh-token key (`false`).
    fn generate_token(
        &self,
        kid: &str,
        sub: &str,
        expires_in: usize,
        extra: Option<HashMap<String, Value>>,
        is_access: bool,
    ) -> Result<String>;

    /// Verifies `token` and returns its decoded header and claims.
    ///
    /// `token_type` names which key set to verify against (the RS256 backend
    /// accepts `"access"` or `"refresh"`). `audiences` optionally restricts
    /// the accepted `aud` values.
    fn decode_token(
        &self,
        token: &str,
        token_type: &str,
        audiences: Option<Vec<String>>,
    ) -> Result<TokenData<Claims>>;
}
/// Public entry point for issuing and verifying tokens.
///
/// Wraps an algorithm-specific [`JwtKeyPair`] chosen when the value is built
/// from a [`JwtAlgorithm`].
pub struct JwtKeys {
    /// Implementation that performs the actual signing and verification.
    backend: Box<dyn JwtKeyPair>,
}
impl JwtKeys {
pub fn from_algorithm(algo: JwtAlgorithm) -> Result<Self> {
match algo {
JwtAlgorithm::RS256 {
access_private,
access_public,
refresh_private,
refresh_public,
} => Ok(Self {
backend: Box::new(Rs256KeyPair {
access_enc: EncodingKey::from_rsa_pem(&access_private)?,
access_dec: DecodingKey::from_rsa_pem(&access_public)?,
refresh_enc: EncodingKey::from_rsa_pem(&refresh_private)?,
refresh_dec: DecodingKey::from_rsa_pem(&refresh_public)?,
}),
}),
}
}
pub fn generate_access_token(&self, kid: &str, user_id: &str, expires_in: usize, extra: Option<HashMap<String, Value>>) -> Result<String> {
self.backend.generate_token(kid, user_id, expires_in, extra, true)
}
pub fn generate_refresh_token(&self, kid: &str, user_id: &str, expires_in: usize, extra: Option<HashMap<String, Value>>) -> Result<String> {
self.backend.generate_token(kid, user_id, expires_in, extra, false)
}
pub fn decode_token(&self, token: &str, token_type: &str, audiences: Option<Vec<String>>) -> Result<TokenData<Claims>> {
self.backend.decode_token(token, token_type, audiences)
}
}
/// RS256 backend: parsed keys in separate slots for access and refresh tokens.
struct Rs256KeyPair {
    /// Signs access tokens.
    access_enc: EncodingKey,
    /// Verifies access tokens.
    access_dec: DecodingKey,
    /// Signs refresh tokens.
    refresh_enc: EncodingKey,
    /// Verifies refresh tokens.
    refresh_dec: DecodingKey,
}
impl JwtKeyPair for Rs256KeyPair {
fn generate_token(&self, kid: &str, sub: &str, expires_in: usize, extra: Option<HashMap<String, Value>>, is_access: bool) -> Result<String> {
let exp = current_timestamp() + expires_in;
let claims = Claims {
sub: sub.to_string(),
exp,
extra: extra.unwrap_or_default(),
};
let mut header = Header::new(Algorithm::RS256);
header.kid = Some(kid.to_string());
let enc = if is_access { &self.access_enc } else { &self.refresh_enc };
encode(&header, &claims, enc)
}
fn decode_token(&self, token: &str, token_type: &str, audiences: Option<Vec<String>>) -> Result<TokenData<Claims>> {
let is_access = match token_type {
"access" => true,
"refresh" => false,
_ => {
return Err(Error::from(ErrorKind::InvalidToken));
}
};
let dec = if is_access { &self.access_dec } else { &self.refresh_dec };
let mut validation = Validation::new(Algorithm::RS256);
let audience_slice: Vec<&str> = audiences
.as_ref()
.map(|aud_list| aud_list.iter().map(AsRef::as_ref).collect())
.unwrap_or_else(Vec::new);
validation.set_audience(&audience_slice);
let decoded = match decode::<Claims>(token, dec, &validation) {
Ok(decoded_token) => decoded_token,
Err(_e) => {
return Err(_e);
}
};
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|_| Error::from(ErrorKind::InvalidIssuer))?
.as_secs() as usize;
if decoded.claims.exp < now {
return Err(Error::from(ErrorKind::ExpiredSignature));
}
Ok(decoded)
}
}
/// Returns the current Unix time in whole seconds.
///
/// If the system clock reports a time before the Unix epoch, this returns `0`
/// instead of failing.
pub fn current_timestamp() -> usize {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::ZERO,
    };
    since_epoch.as_secs() as usize
}
/// Abbreviates a long token for display or logging.
///
/// Tokens of at most 64 characters are returned unchanged. Longer tokens
/// become their first 32 characters, then `...`, then their last 32
/// characters.
///
/// Lengths and cut points are counted in `char`s, not bytes, so non-ASCII
/// input never causes a slice through the middle of a UTF-8 sequence (which
/// would panic). For ASCII input, such as base64url JWTs, the result is the
/// same as byte-based slicing.
pub fn shorten_token(token: &str) -> String {
    const LEN: usize = 32;
    let char_count = token.chars().count();
    if char_count <= LEN * 2 {
        return token.to_string();
    }
    // `char_count > 2 * LEN`, so both `nth` lookups below are in range. The
    // fallback to `token.len()` is defensive only.
    let head_end = token
        .char_indices()
        .nth(LEN)
        .map_or(token.len(), |(idx, _)| idx);
    let tail_start = token
        .char_indices()
        .nth(char_count - LEN)
        .map_or(token.len(), |(idx, _)| idx);
    format!("{}...{}", &token[..head_end], &token[tail_start..])
}