use serde::{Deserialize, Serialize};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use rmp_serde::{from_slice, to_vec};
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use chrono::Utc;
use std::error::Error;
use std::fmt;
/// HMAC over SHA-256; the MAC used by this module to sign token payloads.
type HmacSha256 = Hmac<Sha256>;
/// Error returned by this module's token functions for token-level failures
/// (e.g. malformed token, signature mismatch, expired payload).
///
/// The wrapped string is the human-readable reason; `Display` prints it verbatim.
#[derive(Debug)]
pub struct TokenError(String);

impl fmt::Display for TokenError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let TokenError(reason) = self;
        f.write_str(reason)
    }
}

/// Marker impl so `TokenError` can be boxed as `Box<dyn Error>`.
impl Error for TokenError {}
/// Implemented by token payloads that carry an expiry time.
///
/// `verify_token` rejects a payload whose `exp()` is earlier than the current
/// Unix time as returned by `chrono::Utc::now().timestamp()` (whole seconds).
pub trait Expirable {
    /// Expiry instant as a Unix timestamp in seconds.
    fn exp(&self) -> i64;
}
/// Builds a signed token for `payload`.
///
/// The payload is serialized with `rmp_serde::to_vec` (MessagePack) and signed with
/// HMAC-SHA256 keyed by `secret`. The token is `"<payload>.<signature>"`, with both
/// parts encoded as URL-safe base64 without padding.
///
/// # Errors
/// Returns an error if serialization fails or if keying the MAC fails.
pub fn create_token<T>(payload: &T, secret: &str) -> Result<String, Box<dyn Error>>
where
    T: Serialize,
{
    let body = to_vec(payload)?;
    let tag = sign_payload(secret, &body)?;

    let mut token = URL_SAFE_NO_PAD.encode(&body);
    token.push('.');
    token.push_str(&URL_SAFE_NO_PAD.encode(&tag));
    Ok(token)
}
/// Computes the raw HMAC-SHA256 tag of `payload` keyed by the UTF-8 bytes of `secret`.
///
/// # Errors
/// Propagates any error returned while keying the MAC (`new_from_slice`).
fn sign_payload(secret: &str, payload: &[u8]) -> Result<Vec<u8>, Box<dyn Error>> {
    let key = secret.as_bytes();
    let mut hasher = HmacSha256::new_from_slice(key)?;
    hasher.update(payload);
    let tag = hasher.finalize().into_bytes();
    Ok(tag.as_slice().to_vec())
}
pub fn verify_token<T>(secret: &str, token: &str) -> Result<T, Box<dyn Error>>
where
T: for<'de> Deserialize<'de> + Expirable,
{
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 2 {
return Err(Box::new(TokenError("Invalid token format".to_string())));
}
let payload_bytes = URL_SAFE_NO_PAD.decode(parts[0])?;
let signature = URL_SAFE_NO_PAD.decode(parts[1])?;
let expected_signature = sign_payload(secret, &payload_bytes)?;
if signature != expected_signature {
return Err(Box::new(TokenError("Invalid token signature".to_string())));
}
let payload: T = from_slice(&payload_bytes)?;
let exp_timestamp = payload.exp();
let now_timestamp = Utc::now().timestamp();
if exp_timestamp < now_timestamp {
return Err(Box::new(TokenError("Token has expired".to_string())));
}
Ok(payload)
}
/// Decodes a token's payload **without** verifying its signature or expiry.
///
/// Only the token's shape is checked: it must have exactly two `.`-separated parts.
/// The second (signature) part is otherwise ignored. The returned value is
/// unauthenticated; use `verify_token` when authenticity matters.
///
/// # Errors
/// - `TokenError("Invalid token format")` if the token does not have exactly two parts.
/// - A base64 decoding error if the payload part is not valid URL-safe base64.
/// - A deserialization error if the payload bytes do not decode to `T`.
pub fn decode_token<T>(token: &str) -> Result<T, Box<dyn Error>>
where
    T: for<'de> Deserialize<'de>,
{
    let mut segments = token.split('.');
    let encoded_payload = match (segments.next(), segments.next(), segments.next()) {
        (Some(head), Some(_), None) => head,
        _ => return Err(Box::new(TokenError("Invalid token format".to_string()))),
    };
    let raw = URL_SAFE_NO_PAD.decode(encoded_payload)?;
    let decoded: T = from_slice(&raw)?;
    Ok(decoded)
}