use std::error::Error as StdError;
use std::time::SystemTime;
use base64::{engine::general_purpose, Engine};
use hmac::{Hmac, Mac};
use memchr::{memchr_iter, memrchr};
#[cfg(feature = "rsa")]
use rsa::{
pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey},
pkcs1v15::{Signature, SigningKey, VerifyingKey},
signature::SignatureEncoding,
signature::{Signer, Verifier},
};
use serde::de::DeserializeOwned;
use serde_json::{Map, Value};
use sha2::Sha256;
use smallvec::SmallVec;
use thiserror::Error;
#[cfg(feature = "rsa")]
pub use rsa::{RsaPrivateKey, RsaPublicKey};
// Result alias used by the functions in this file: the error side is always `JwtError`.
type Result<T> = std::result::Result<T, JwtError>;
// HMAC keyed with SHA-256; used below for HS256 signing and verification.
type HmacSha256 = Hmac<Sha256>;
// Byte buffer with an inline capacity of 512 bytes; the decode functions use it
// to assemble the `header.claims` signing input.
type StackVec = SmallVec<[u8; 512]>;
// Builds a `JwtError` tagged with the file and line of the macro call site.
//
// `jwt_error!(msg)` forwards to `JwtError::new`.
// `jwt_error!(msg, err)` forwards to `JwtError::with_source`, attaching `err`
// as the underlying cause. `msg` is passed through as-is (the constructors take
// a `String`), and `err` must be a plain identifier, not an arbitrary expression.
macro_rules! jwt_error {
// Message only, no underlying cause.
($msg:expr) => {
JwtError::new($msg, file!(), line!())
};
// Message plus the error that caused it.
($msg:expr, $err:ident) => {
JwtError::with_source($msg, file!(), line!(), $err)
};
}
/// Error type carried by the `Result` alias in this file.
///
/// Its `Display` output has the form `"{message} at {file}:{line}"`, where
/// `file` and `line` identify the place the error was constructed (see the
/// `jwt_error!` macro). An optional underlying cause is exposed through
/// `std::error::Error::source`.
#[derive(Error, Debug)]
#[error("{message} at {file}:{line}")]
pub struct JwtError {
/// Description of what went wrong.
message: String,
/// Source file in which the error was raised.
file: &'static str,
/// Line number in `file` at which the error was raised.
line: u32,
/// Underlying error that triggered this one, if any.
#[source]
source: Option<Box<dyn StdError + Send + Sync + 'static>>,
}
impl JwtError {
pub fn new(message: String, file: &'static str, line: u32) -> Self {
JwtError {
message,
file,
line,
source: None,
}
}
pub fn with_source<E>(message: String, file: &'static str, line: u32, source: E) -> Self
where
E: std::error::Error + Send + Sync + 'static,
{
JwtError {
message,
file,
line,
source: Some(Box::new(source)),
}
}
}
/// The string `"Authorization"` (presumably the HTTP header name; it is not
/// referenced by the functions visible in this file).
pub const AUTHORIZATION: &str = "Authorization";
/// The string `"Bearer "`. Note the trailing space.
pub const BEARER: &str = "Bearer ";
/// The string `"WWW-Authenticate"` (presumably the HTTP header name; not
/// referenced by the functions visible in this file).
pub const WWW_AUTHENTICATE: &str = "WWW-Authenticate";
/// JSON key of the issuer claim.
pub const ISSUER_KEY: &str = "iss";
/// JSON key of the expiration-time claim.
pub const EXP_KEY: &str = "exp";
// Pre-encoded (base64url, no padding) JWT header `{"alg":"HS256","typ":"JWT"}`.
// HS256 tokens are built with it, and decoding compares the header segment
// against it byte-for-byte.
const HEADER_B64: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9";
// Pre-encoded (base64url, no padding) JWT header `{"alg":"RS256","typ":"JWT"}`,
// used the same way by the RSA variants.
#[cfg(feature = "rsa")]
const HEADER_RS256_B64: &str = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9";
/// Creates an HS256-signed JWT from `claims`.
///
/// `claims` must be a JSON object or `null`. Before signing, an `exp` claim
/// set to "now + `ttl` seconds" is inserted, and an `iss` claim set to
/// `issuer` is inserted when `issuer` is non-empty. `key` is the HMAC secret.
///
/// # Errors
///
/// Returns an error if the claims cannot be prepared or signing fails.
pub fn encode(claims: Value, key: &str, issuer: &str, ttl: u64) -> Result<String> {
    merge_claims(claims, issuer, ttl).and_then(|payload_b64| encode_raw(&payload_b64, key))
}
pub fn encode_raw(claims: &str, key: &str) -> Result<String> {
let mut jwt_str = format!("{}.{}", HEADER_B64, claims);
let mut hs256 = HmacSha256::new_from_slice(key.as_bytes())
.map_err(|e| jwt_error!("hmac sha256 error".to_string(), e))?;
hs256.update(jwt_str.as_bytes());
let sign_bs = hs256.finalize();
let sign_b64 = general_purpose::URL_SAFE_NO_PAD.encode(sign_bs.into_bytes());
jwt_str.push('.');
jwt_str.push_str(&sign_b64);
Ok(jwt_str)
}
/// Verifies an HS256 token and returns its claims as a JSON [`Value`].
///
/// This is [`decode_custom`] specialised to `Value`, reading `iss` and `exp`
/// directly from the JSON object. An empty `issuer` disables the issuer check.
///
/// # Errors
///
/// Returns whatever error [`decode_custom`] reports.
pub fn decode(jwt: &str, key: &str, issuer: &str) -> Result<Value> {
    decode_custom::<Value>(jwt, key, issuer, get_iss_exp_by_value)
}
pub fn decode_custom<T: DeserializeOwned>(
jwt: &str,
key: &str,
issuer: &str,
get_iss_exp: fn(&T) -> (Option<&str>, Option<u64>),
) -> Result<T> {
if key.is_empty() {
return Err(jwt_error!("jwt key is empty".to_string()));
}
let jwt = jwt.as_bytes();
let sp_ret = jwt_split_slice(jwt).map_err(|e| jwt_error!("jwt_split error".to_string(), e))?;
let (header_b64, claims_b64, sign_b64) = sp_ret;
if HEADER_B64.as_bytes() != header_b64 {
return Err(jwt_error!("token header error".to_string()));
}
let claims_str = decode_base64_slice(claims_b64)?;
let claims: T = serde_json::from_slice(&claims_str)
.map_err(|e| jwt_error!("deerialze json error".to_string(), e))?;
let (iss, exp) = get_iss_exp(&claims);
check_claims(issuer, &iss, &exp)?;
let mut header_claims_b64 = StackVec::new();
header_claims_b64.extend_from_slice(header_b64);
header_claims_b64.push(b'.');
header_claims_b64.extend_from_slice(claims_b64);
let sign_bs = decode_base64_slice(sign_b64)?;
let mut hs256 = HmacSha256::new_from_slice(key.as_bytes())
.map_err(|e| jwt_error!("hmac sha256 error".to_string(), e))?;
hs256.update(&header_claims_b64);
if let Err(e) = hs256.verify_slice(&sign_bs) {
return Err(jwt_error!("Signature verification failed".to_string(), e));
}
Ok(claims)
}
/// Creates an RS256-signed JWT using the private key embedded in this crate.
///
/// Behaves like [`encode_with_rsa`] with the key parsed from
/// `rsa_key_data::RSA_PRIVATE_KEY`.
///
/// # Errors
///
/// Returns an error if the embedded PEM cannot be parsed, or if
/// [`encode_with_rsa`] fails.
#[cfg(feature = "rsa")]
pub fn encode_with_rsa_default(claims: Value, issuer: &str, ttl: u64) -> Result<String> {
    match RsaPrivateKey::from_pkcs1_pem(rsa_key_data::RSA_PRIVATE_KEY) {
        Ok(private_key) => encode_with_rsa(claims, private_key, issuer, ttl),
        Err(e) => Err(jwt_error!(
            "load private key from pkcs1 pem failed".to_string(),
            e
        )),
    }
}
/// Creates an RS256-signed JWT from `claims` using `pri_key`.
///
/// `claims` must be a JSON object or `null`. An `exp` claim of
/// "now + `ttl` seconds" is inserted, plus an `iss` claim when `issuer` is
/// non-empty, before the token is signed.
///
/// # Errors
///
/// Returns an error if the claims cannot be prepared or signing fails.
#[cfg(feature = "rsa")]
pub fn encode_with_rsa(claims: Value, pri_key: RsaPrivateKey, issuer: &str, ttl: u64) -> Result<String> {
    merge_claims(claims, issuer, ttl)
        .and_then(|payload_b64| encode_with_rsa_raw(&payload_b64, pri_key))
}
/// Signs an already-encoded claims segment with RSA PKCS#1 v1.5 / SHA-256.
///
/// `claims` is used verbatim as the middle segment of the token, so it is
/// expected to be the base64url form of the claims JSON. The result is
/// `<header>.<claims>.<signature>` with the fixed RS256 header.
///
/// # Errors
///
/// Returns an error if the signing operation fails (for example when
/// `pri_key` is too small for the digest) instead of panicking.
#[cfg(feature = "rsa")]
pub fn encode_with_rsa_raw(claims: &str, pri_key: RsaPrivateKey) -> Result<String> {
    let mut jwt_str = format!("{}.{}", HEADER_RS256_B64, claims);
    let rsa_sign = SigningKey::<Sha256>::new(pri_key);
    // `Signer::sign` panics on failure; `try_sign` lets the caller handle it.
    let sign_bs: Signature = rsa_sign
        .try_sign(jwt_str.as_bytes())
        .map_err(|e| jwt_error!("rsa sign error".to_string(), e))?;
    let sign_b64 = general_purpose::URL_SAFE_NO_PAD.encode(sign_bs.to_bytes());
    jwt_str.push('.');
    jwt_str.push_str(&sign_b64);
    Ok(jwt_str)
}
/// Verifies an RS256 token with the public key embedded in this crate.
///
/// Behaves like [`decode_with_rsa`] with the key parsed from
/// `rsa_key_data::RSA_PUBLIC_KEY`.
///
/// # Errors
///
/// Returns an error if the embedded PEM cannot be parsed, or if
/// [`decode_with_rsa`] fails.
#[cfg(feature = "rsa")]
pub fn decode_with_rsa_default(jwt: &str, issuer: &str) -> Result<Value> {
    match RsaPublicKey::from_pkcs1_pem(rsa_key_data::RSA_PUBLIC_KEY) {
        Ok(public_key) => decode_with_rsa(jwt, public_key, issuer),
        Err(e) => Err(jwt_error!(
            "load public key from pkcs1 pem failed".to_string(),
            e
        )),
    }
}
/// Verifies an RS256 token and returns its claims as a JSON [`Value`].
///
/// This is [`decode_custom_with_rsa`] specialised to `Value`, reading `iss`
/// and `exp` directly from the JSON object. An empty `issuer` disables the
/// issuer check.
///
/// # Errors
///
/// Returns whatever error [`decode_custom_with_rsa`] reports.
#[cfg(feature = "rsa")]
pub fn decode_with_rsa(jwt: &str, pub_key: RsaPublicKey, issuer: &str) -> Result<Value> {
    decode_custom_with_rsa::<Value>(jwt, pub_key, issuer, get_iss_exp_by_value)
}
/// Verifies an RS256 token and deserializes its claims into `T`.
///
/// The signature is verified *before* the claims segment is decoded or
/// deserialized, so no attacker-controlled payload is parsed (or echoed in an
/// error message) unless it was signed with the matching private key.
///
/// # Parameters
///
/// * `jwt` - the compact token, `<header>.<claims>.<signature>`.
/// * `pub_key` - RSA public key used to verify the signature.
/// * `issuer` - expected `iss` value; an empty string disables the issuer check.
/// * `get_iss_exp` - extracts the `iss` and `exp` claims from the decoded `T`.
///
/// # Errors
///
/// Returns an error when the token does not have three dot-separated
/// segments, the header is not the fixed RS256 header, the signature is
/// malformed or does not verify, the claims are not valid base64url/JSON for
/// `T`, the issuer does not match, or `exp` is missing or in the past.
#[cfg(feature = "rsa")]
pub fn decode_custom_with_rsa<T: DeserializeOwned>(
    jwt: &str,
    pub_key: RsaPublicKey,
    issuer: &str,
    get_iss_exp: fn(&T) -> (Option<&str>, Option<u64>),
) -> Result<T> {
    let jwt = jwt.as_bytes();
    let (header_b64, claims_b64, sign_b64) =
        jwt_split_slice(jwt).map_err(|e| jwt_error!("jwt_split error".to_string(), e))?;
    if HEADER_RS256_B64.as_bytes() != header_b64 {
        return Err(jwt_error!("token header error".to_string()));
    }

    // Authenticate first: rebuild the signing input `<header>.<claims>` and
    // check the signature before touching the (still untrusted) claims payload.
    let mut header_claims_b64 = StackVec::new();
    header_claims_b64.extend_from_slice(header_b64);
    header_claims_b64.push(b'.');
    header_claims_b64.extend_from_slice(claims_b64);

    let sign_bs = decode_base64_slice(sign_b64)?;
    let rsa_verify = VerifyingKey::<Sha256>::new(pub_key);
    let sign = Signature::try_from(sign_bs.as_ref())
        .map_err(|e| jwt_error!("rsa sign data error".to_string(), e))?;
    rsa_verify
        .verify(&header_claims_b64, &sign)
        .map_err(|e| jwt_error!("Signature verification failed".to_string(), e))?;

    // The payload is authentic from here on; decode and validate the claims.
    let claims_json = decode_base64_slice(claims_b64)?;
    let claims: T = serde_json::from_slice(&claims_json)
        .map_err(|e| jwt_error!("deserialize json error".to_string(), e))?;
    let (iss, exp) = get_iss_exp(&claims);
    check_claims(issuer, &iss, &exp)?;

    Ok(claims)
}
/// Splits a compact token into its header, claims and signature segments.
///
/// The split happens at the first two `.` characters; everything after the
/// second dot is returned as the third segment. Nothing is decoded.
///
/// # Errors
///
/// Returns an error if fewer than two dots are present, or if a segment is
/// not valid UTF-8.
pub fn jwt_split(jwt: &str) -> Result<(&str, &str, &str)> {
    // Converts one byte segment back to `&str`.
    fn segment_as_str(segment: &[u8]) -> Result<&str> {
        std::str::from_utf8(segment).map_err(|_| jwt_error!("slice is not utf8".to_string()))
    }

    let (header, claims, sign) = jwt_split_slice(jwt.as_bytes())
        .map_err(|e| jwt_error!("jwt_split error".to_string(), e))?;
    let header = segment_as_str(header)?;
    let claims = segment_as_str(claims)?;
    let sign = segment_as_str(sign)?;
    Ok((header, claims, sign))
}
/// Splits the bytes of a compact token into header, claims and signature.
///
/// The first two `.` bytes act as separators; the third segment is everything
/// after the second dot (including any further dots). The returned slices
/// borrow from `jwt` and nothing is decoded.
///
/// # Errors
///
/// Returns an error if `jwt` contains fewer than two `.` bytes.
pub fn jwt_split_slice(jwt: &[u8]) -> Result<(&[u8], &[u8], &[u8])> {
    let mut dots = memchr_iter(b'.', jwt);
    let (first_dot, second_dot) = match (dots.next(), dots.next()) {
        (Some(first), Some(second)) => (first, second),
        _ => return Err(jwt_error!("token find '.' not found".to_string())),
    };

    Ok((
        &jwt[..first_dot],
        &jwt[first_dot + 1..second_dot],
        &jwt[second_dot + 1..],
    ))
}
/// Returns the text after the last `.` in `jwt`, i.e. the signature segment.
///
/// Yields `None` when `jwt` contains no dot or when nothing follows the last
/// dot.
pub fn get_sign(jwt: &str) -> Option<&str> {
    memrchr(b'.', jwt.as_bytes())
        // `.` is a single ASCII byte, so `dot + 1` is always a char boundary.
        .map(|dot| &jwt[dot + 1..])
        .filter(|sign| !sign.is_empty())
}
/// Reads the `iss` claim from `claims`.
///
/// # Errors
///
/// Returns an error if the claim is absent or is not a JSON string.
pub fn get_issuer(claims: &Value) -> Result<&str> {
    claims
        .get(ISSUER_KEY)
        .and_then(Value::as_str)
        .ok_or_else(|| jwt_error!("issuer not found".to_string()))
}
pub fn get_exp(claims: &Value) -> Result<u64> {
match claims.get(EXP_KEY) {
Some(exp) => match exp.as_u64() {
Some(exp) => Ok(exp),
None => Err(jwt_error!(format!("exp format error: {}", exp))),
},
None => Err(jwt_error!("not found exp".to_string())),
}
}
/// Validates the `iss` claim of `claims` against `issuer`.
///
/// An empty `issuer` accepts anything: the claim's own value is returned when
/// it is a string, otherwise `""`. A non-empty `issuer` must equal the claim
/// exactly.
///
/// # Errors
///
/// Returns an error if `issuer` is non-empty and the claim is missing, is not
/// a string, or differs from `issuer`.
pub fn check_issuer<'a>(claims: &'a Value, issuer: &str) -> Result<&'a str> {
    let accept_any = issuer.is_empty();
    match claims.get(ISSUER_KEY) {
        Some(Value::String(iss)) if accept_any || issuer == iss => Ok(iss.as_str()),
        Some(Value::String(iss)) => Err(jwt_error!(format!(
            "incorrect issuer, expected: [{}], actual: [{}]",
            issuer, iss
        ))),
        // Claim missing or not a string.
        _ if accept_any => Ok(""),
        _ => Err(jwt_error!("not found issuer".to_string())),
    }
}
pub fn check_exp(claims: &Value) -> Result<u64> {
match claims.get(EXP_KEY) {
Some(exp) => {
match exp.as_u64() {
Some(exp) => {
let now = unix_timestamp_after(0)?;
if exp >= now {
Ok(exp)
} else {
Err(jwt_error!("incorrect exp".to_string()))
}
}
None => Err(jwt_error!("exp format error".to_string()))
}
}
None => Err(jwt_error!("exp not found".to_string()))
}
}
/// Validates already-extracted `iss` and `exp` claims.
///
/// The issuer is only checked when `issuer` is non-empty, in which case `iss`
/// must be present and equal to it. `exp` is always required and must not be
/// earlier than the current Unix time in seconds.
///
/// # Errors
///
/// Returns an error if the issuer check fails, `exp` is missing or expired,
/// or the current time cannot be determined.
pub fn check_claims(issuer: &str, iss: &Option<&str>, exp: &Option<u64>) -> Result<()> {
    if !issuer.is_empty() {
        let iss = (*iss).ok_or_else(|| jwt_error!("jwt token iss not found".to_string()))?;
        if iss != issuer {
            let msg = format!("incorrect issuer, expected: [{issuer}], actual: [{iss}]");
            return Err(jwt_error!(msg));
        }
    }

    let exp = (*exp).ok_or_else(|| jwt_error!("jwt token exp not found".to_string()))?;
    if exp < unix_timestamp_after(0)? {
        return Err(jwt_error!("Incorrect exp".to_string()));
    }
    Ok(())
}
/// Decodes an unpadded, URL-safe base64 string into bytes.
///
/// # Errors
///
/// Returns an error if `base64` is not valid unpadded URL-safe base64.
pub fn decode_base64(base64: &str) -> Result<Vec<u8>> {
    let encoded = base64.as_bytes();
    decode_base64_slice(encoded)
}
pub fn decode_base64_slice(base64: &[u8]) -> Result<Vec<u8>> {
general_purpose::URL_SAFE_NO_PAD
.decode(base64)
.map_err(|e| jwt_error!("base64 decode error".to_string(), e))
}
pub fn unix_timestamp_after(after_now: u64) -> Result<u64> {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map(|v| v.as_secs() + after_now)
.map_err(|_| jwt_error!("unix_timestamp error".to_string()))
}
// Adds the standard claims to `claims` and returns the base64url (unpadded)
// encoding of the resulting JSON object.
//
// `exp` is always set to "now + `ttl` seconds"; `iss` is set to `issuer` only
// when `issuer` is non-empty. Both overwrite any value already present.
//
// `claims` must be `null` or a JSON object. In debug builds anything else
// trips the assertion below; otherwise it is reported as an error by
// `value_to_map`.
fn merge_claims(claims: Value, issuer: &str, ttl: u64) -> Result<String> {
    debug_assert!(claims.is_null() || claims.is_object());

    let expires_at = unix_timestamp_after(ttl)?;
    let mut payload = value_to_map(claims)?;
    if !issuer.is_empty() {
        payload.insert(ISSUER_KEY.to_string(), Value::String(issuer.to_owned()));
    }
    payload.insert(EXP_KEY.to_owned(), Value::from(expires_at));

    claims_to_json(&payload).map(|json| general_purpose::URL_SAFE_NO_PAD.encode(json))
}
fn value_to_map(val: Value) -> Result<Map<String, Value>> {
match val {
Value::Null => Ok(Map::new()),
Value::Object(v) => Ok(v),
_ => Err(jwt_error!("token claims format error".to_string())),
}
}
// Serializes the claims map to compact JSON bytes.
//
// Fails only if `serde_json` cannot serialize the map; the serializer's error
// is attached as the source.
fn claims_to_json(claims: &Map<String, Value>) -> Result<Vec<u8>> {
    // `claims` is already a reference; pass it straight through.
    serde_json::to_vec(claims)
        .map_err(|e| jwt_error!("serialize claims to json failed".to_string(), e))
}
// Extracts `(iss, exp)` from claims held as a JSON `Value`.
//
// `iss` is `Some` only when the claim is a JSON string; `exp` is `Some` only
// when the claim is representable as `u64`.
fn get_iss_exp_by_value(claims: &Value) -> (Option<&str>, Option<u64>) {
    let iss = claims.get(ISSUER_KEY).and_then(Value::as_str);
    let exp = match claims.get(EXP_KEY) {
        Some(raw) => raw.as_u64(),
        None => None,
    };
    (iss, exp)
}
// Built-in RSA key material used by `encode_with_rsa_default` and
// `decode_with_rsa_default`.
//
// NOTE(review): the private key below is checked into source, so anyone with
// access to this file can sign tokens that `decode_with_rsa_default` accepts.
// Treat the `*_default` functions as suitable for demos/tests only unless this
// is confirmed to be intentional.
#[cfg(feature = "rsa")]
mod rsa_key_data {
// PKCS#1 PEM private key, parsed with `RsaPrivateKey::from_pkcs1_pem`.
// The literal starts with a newline before the `-----BEGIN` line.
pub const RSA_PRIVATE_KEY: &str = r#"
-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEA1h9pD+0KYp5Bpda/OTWFVxaXKPO4+36LzNWk53PG4LOrrg2o
rJzdbhFwoqB20ceFQOZm9dK8udY3LFn4Pv1M01pkPRV39+URLds3W+CujnTQiCJ/
vBeWrIf7HYq6TM0oQcmzESvRsVf37xZpvlK21pxgsxg2pyYoqZrx24ttxm81ZtJj
v76QmzbaU3Lz+rYOfwxzeQelXZ0KDxWrptm+FlIspUzSGYxV6RmV66svaxzNi8mi
bQuIx8BWbVHGWU45cISC4+oSnqUirB8i/URgZhwHyYfFO/Tmmf/+NROjXAqbH7lr
bJBtAP+OXYuKfkBNLUqRmXax2uttbaefF0sxQQIDAQABAoIBADEKbKOrJK/Fkz+K
Wa2epnV1xRUqDPn818QIQoaIK8qXHAD3O+Sc4NIuyF9W5R/S1KAypO40X+koOOa9
jG/Qz+GwWDjtS9bI7hBUnu86HICgHIqxbBQGSwok8synU1f3vPqkWZDbOmGlxjFK
LtnaU+n/Ut5x80KBKNr/k9k2q+PAdWrna/LfcjXtVqkZVUl/3xD+Dp7IIM7nQQuX
Lnfvyk86gP+KKUPs23cnPFwqpWibJO7swJ9tlHQXOefVurvCJFzFhjyG8CkpMcn9
R1B9uUZGGTOYrD797vJ3Ll5JM9Wg+9cy2C8ovOGvEY0L6Y9UhGUqESVjYnXM9/bp
tixAG2ECgYEA8elVewkf29b0/5bHNiGimS/6Qo3/Hz4Ijb5xr9ye7LrbjaQbeDAW
4RwnDvz11zJAR9wRKtZApCGH5NEexZykdxjmxQGOGTz/PUpsRhXFpDW+u4BrwHBZ
3l4j/xoOxQbxIvE3hSZLg22FZVhwfWmBOM4P8L9V34IX8awqyYCBx7UCgYEA4pfF
22e7gtbjqkIlTh8ekZJi3obWdOFCR6r+4ySeUG+b3bL91Js0/2RIu5wB0+t9OZ5d
F7N54e0S2p5au2YcD8q1DmJj8b0tgOffHcCgGWnIhDtu6QA8NKxR3lWYHzJTq9Ju
ZhY5yqopXYOhifBQoliAWFecrgh4pqKFTzGk4t0CgYB6t+GzPpe40D0NA5IfdcSk
bWBJLvuC/9cbAMdvbT353XjPS7bbq5mPrNZrlguolUdirNLQpku4d4IWo7c2jBYq
jKlUu0s4pmbc0spGa3kNqm4NdEI1J0mPsrYUDUX80V62WSPPGfQowgBvvwOhu0ng
ZThU6ttHPRmkcbBq9BPiGQKBgGvZY1IHsIcY8qmB7DGfvDP7YdWahg6BfMORztmb
/0I3rQ87d3cvHG2GdNve6DvOtP6sspBqW1O+PCAUCQlzE14s1DpxeDKCIVtegaKu
oUUXRVoy05pRA1bqwdi6ErqegJaihOtQHteoYCHjWgrGeAqdZxElOizXWV2usxa7
gUh9AoGAZiIlLmGG0dDBvWCWjs0oiPCSIpeINtbNEUIZ5CtI9pUAdR2kMOUIRbta
VlGTzzTlP/uGZlLqRhe6QnLii0MeR7B6suM9JHg0bcLN0diwYpiit73+el1KJYwg
1aM3mKtopVX0gWXIAYbgPDGfxCR/0SbK1XVbBEHthO9Ns563CAU=
-----END RSA PRIVATE KEY-----
"#;
// PKCS#1 PEM public key, parsed with `RsaPublicKey::from_pkcs1_pem`.
// Presumably the public half of `RSA_PRIVATE_KEY` - TODO confirm.
pub const RSA_PUBLIC_KEY: &str = r#"
-----BEGIN RSA PUBLIC KEY-----
MIIBCgKCAQEA1h9pD+0KYp5Bpda/OTWFVxaXKPO4+36LzNWk53PG4LOrrg2orJzd
bhFwoqB20ceFQOZm9dK8udY3LFn4Pv1M01pkPRV39+URLds3W+CujnTQiCJ/vBeW
rIf7HYq6TM0oQcmzESvRsVf37xZpvlK21pxgsxg2pyYoqZrx24ttxm81ZtJjv76Q
mzbaU3Lz+rYOfwxzeQelXZ0KDxWrptm+FlIspUzSGYxV6RmV66svaxzNi8mibQuI
x8BWbVHGWU45cISC4+oSnqUirB8i/URgZhwHyYfFO/Tmmf/+NROjXAqbH7lrbJBt
AP+OXYuKfkBNLUqRmXax2uttbaefF0sxQQIDAQAB
-----END RSA PUBLIC KEY-----
"#;
// A `-----BEGIN PUBLIC KEY-----` PEM block; presumably the same public key in
// SubjectPublicKeyInfo form - TODO confirm. It is not referenced by the code
// visible in this file, hence the `dead_code` allowance.
#[allow(dead_code)]
pub const PUBLIC_KEY: &str = r#"
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1h9pD+0KYp5Bpda/OTWF
VxaXKPO4+36LzNWk53PG4LOrrg2orJzdbhFwoqB20ceFQOZm9dK8udY3LFn4Pv1M
01pkPRV39+URLds3W+CujnTQiCJ/vBeWrIf7HYq6TM0oQcmzESvRsVf37xZpvlK2
1pxgsxg2pyYoqZrx24ttxm81ZtJjv76QmzbaU3Lz+rYOfwxzeQelXZ0KDxWrptm+
FlIspUzSGYxV6RmV66svaxzNi8mibQuIx8BWbVHGWU45cISC4+oSnqUirB8i/URg
ZhwHyYfFO/Tmmf/+NROjXAqbH7lrbJBtAP+OXYuKfkBNLUqRmXax2uttbaefF0sx
QQIDAQAB
-----END PUBLIC KEY-----
"#;
}