#![allow(dead_code)]
use serde_json::{Value, json};
use std::time::{SystemTime, UNIX_EPOCH};
/// Signing algorithms the signing service is asked to use, identified by their
/// JWA names (e.g. `RS256`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Algorithm {
    RS256,
    RS384,
    RS512,
    ES256,
    ES384,
    ES512,
}

impl Algorithm {
    /// Every variant, in declaration order; backs [`Algorithm::all`].
    const ALL: [Self; 6] = [
        Self::RS256,
        Self::RS384,
        Self::RS512,
        Self::ES256,
        Self::ES384,
        Self::ES512,
    ];

    /// Returns the algorithm's name as used in the signing endpoint path
    /// (`/sign/{name}`).
    #[must_use]
    pub fn name(&self) -> &'static str {
        match self {
            Self::RS256 => "RS256",
            Self::RS384 => "RS384",
            Self::RS512 => "RS512",
            Self::ES256 => "ES256",
            Self::ES384 => "ES384",
            Self::ES512 => "ES512",
        }
    }

    /// Returns every supported algorithm, in declaration order.
    #[must_use]
    pub fn all() -> Vec<Self> {
        Self::ALL.to_vec()
    }
}

impl std::fmt::Display for Algorithm {
    /// Formats the algorithm as its name, identical to [`Algorithm::name`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name())
    }
}
#[derive(Debug)]
pub struct TokenBuilder {
base_url: String,
algorithm: Algorithm,
claims: Value,
}
impl TokenBuilder {
pub fn new(base_url: impl Into<String>, algorithm: Algorithm) -> Self {
Self {
base_url: base_url.into(),
algorithm,
claims: json!({}),
}
}
pub fn issuer(mut self, iss: impl Into<String>) -> Self {
self.claims["iss"] = json!(iss.into());
self
}
pub fn subject(mut self, sub: impl Into<String>) -> Self {
self.claims["sub"] = json!(sub.into());
self
}
pub fn audience(mut self, aud: impl Into<String>) -> Self {
self.claims["aud"] = json!(aud.into());
self
}
pub fn expiration(mut self, exp: u64) -> Self {
self.claims["exp"] = json!(exp);
self
}
pub fn issued_at(mut self, iat: u64) -> Self {
self.claims["iat"] = json!(iat);
self
}
pub fn not_before(mut self, nbf: u64) -> Self {
self.claims["nbf"] = json!(nbf);
self
}
pub fn standard_valid_claims(mut self) -> Self {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
self.claims = json!({
"iss": self.base_url.clone(),
"sub": "test-user",
"aud": "test-app",
"iat": now - 60,
"nbf": now - 60,
"exp": now + 3600,
});
self
}
pub async fn generate(self) -> Result<String, Box<dyn std::error::Error>> {
let client = reqwest::Client::new();
let url = format!("{}/sign/{}", self.base_url, self.algorithm.name());
let json_body = serde_json::to_string(&self.claims)?;
let response = client
.post(&url)
.header("Content-Type", "application/json")
.body(json_body)
.send()
.await?;
if !response.status().is_success() {
return Err(format!("HTTP error: {}", response.status()).into());
}
let body = response.text().await?;
let json: Value = serde_json::from_str(&body)?;
let token = json
.get("token")
.and_then(|t| t.as_str())
.ok_or("Missing token in response")?;
Ok(token.to_string())
}
pub fn custom_claim(mut self, key: impl Into<String>, value: Value) -> Self {
self.claims[key.into()] = value;
self
}
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch,
/// the unit used by the `exp`/`iat`/`nbf` setters on `TokenBuilder`.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
#[must_use]
pub fn now() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_secs()
}
/// Returns `token` with its signature altered so that it no longer verifies.
///
/// The first character of the signature segment is replaced by a different
/// base64url character (`A`, or `B` if it already was `A`). The segment
/// therefore stays within the base64url alphabet (if it was before), so a
/// verifier has to reject the token on the signature check itself rather than
/// while decoding it.
///
/// The first character is chosen because it always encodes 6 significant bits.
/// The last character of an unpadded base64url string may carry only padding
/// bits, which lenient decoders ignore, so changing it might not change the
/// decoded signature at all.
///
/// Tokens without exactly three `.`-separated parts, or with an empty
/// signature, are returned unchanged.
pub fn corrupt_signature(token: &str) -> String {
    let parts: Vec<&str> = token.split('.').collect();
    let [header, payload, signature] = parts.as_slice() else {
        return token.to_string();
    };
    let mut rest = signature.chars();
    let Some(first) = rest.next() else {
        return token.to_string();
    };
    let replacement = if first == 'A' { 'B' } else { 'A' };
    format!("{header}.{payload}.{replacement}{}", rest.as_str())
}
/// Drops the signature segment, turning `header.payload.signature` into
/// `header.payload` (note: no trailing dot).
///
/// Inputs that do not split into exactly three `.`-separated parts are
/// returned unchanged.
pub fn remove_signature(token: &str) -> String {
    match token.split('.').collect::<Vec<_>>().as_slice() {
        [header, payload, _signature] => format!("{header}.{payload}"),
        _ => token.to_owned(),
    }
}
/// Blanks the signature segment while keeping its separator, turning
/// `header.payload.signature` into `header.payload.`.
///
/// Inputs that do not split into exactly three `.`-separated parts are
/// returned unchanged.
pub fn empty_signature(token: &str) -> String {
    let mut segments = token.split('.');
    // Pull up to four segments: a fourth one (or a missing third) means the
    // token does not have exactly three parts.
    match (segments.next(), segments.next(), segments.next(), segments.next()) {
        (Some(header), Some(payload), Some(_), None) => format!("{header}.{payload}."),
        _ => String::from(token),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_algorithm_names() {
        assert_eq!(Algorithm::RS256.name(), "RS256");
        assert_eq!(Algorithm::ES384.name(), "ES384");
        assert_eq!(Algorithm::RS512.name(), "RS512");
    }

    /// `all()` must list every variant exactly once, in declaration order.
    #[test]
    fn test_all_algorithms_listed() {
        let names: Vec<&str> = Algorithm::all().iter().map(Algorithm::name).collect();
        assert_eq!(names, ["RS256", "RS384", "RS512", "ES256", "ES384", "ES512"]);
    }

    #[test]
    fn test_corrupt_signature() {
        let token = "header.payload.signature";
        let corrupted = corrupt_signature(token);
        assert_ne!(token, corrupted);
        assert!(corrupted.starts_with("header.payload."));
    }

    /// An empty signature has nothing to corrupt, so the token is unchanged.
    #[test]
    fn test_corrupt_empty_signature_is_noop() {
        assert_eq!(corrupt_signature("header.payload."), "header.payload.");
    }

    #[test]
    fn test_remove_signature() {
        let token = "header.payload.signature";
        let removed = remove_signature(token);
        assert_eq!(removed, "header.payload");
    }

    #[test]
    fn test_empty_signature() {
        let token = "header.payload.signature";
        let empty = empty_signature(token);
        assert_eq!(empty, "header.payload.");
    }

    /// All signature helpers leave inputs that are not three-part tokens alone.
    #[test]
    fn test_malformed_tokens_pass_through() {
        for token in ["", "onlyone", "header.payload", "a.b.c.d"] {
            assert_eq!(corrupt_signature(token), token);
            assert_eq!(remove_signature(token), token);
            assert_eq!(empty_signature(token), token);
        }
    }
}