mod algorithm;
mod error;
mod keygen;
pub use algorithm::*;
pub use error::*;
pub use keygen::*;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
/// Type-state marker for a token that is still being built: its header and claims can be
/// modified and it can be signed with `Jwt::encode`.
#[derive(Clone, Debug)]
pub struct Encoder;

/// Type-state marker for a token returned by `Jwt::decode`; it only exposes read-only
/// accessors for the header and claims.
#[derive(Clone, Debug)]
pub struct Decoded;
/// JOSE header parameters of a token, as a JSON object (e.g. `typ`, `alg`, custom keys).
pub type Headers = Map<String, Value>;

/// A JSON Web Token whose available operations are selected by the `State` type parameter.
///
/// * `Jwt<Encoder, C>` — a token under construction; configure it and sign it with `encode`.
/// * `Jwt<Decoded, C>` — a token returned by `decode` after its `alg` and signature were
///   checked; exposes read-only accessors.
///
/// `C` is the claims payload type.
#[derive(Clone, Debug)]
pub struct Jwt<State, C> {
    /// Header parameters; serialized as the first segment of the compact form.
    headers: Headers,
    /// Claims payload; serialized as the second segment of the compact form.
    claims: C,
    /// Zero-sized marker recording the type-state; carries no runtime data.
    _state: std::marker::PhantomData<State>,
}
impl<C> Jwt<Encoder, C>
where
C: Serialize,
{
pub fn new(claims: C) -> Self {
let mut headers = Map::new();
headers.insert("typ".to_string(), Value::String("JWT".to_string()));
Self {
headers,
claims,
_state: std::marker::PhantomData,
}
}
pub fn header<V: Serialize>(mut self, key: &str, value: V) -> Self {
if let Ok(value) = serde_json::to_value(value) {
self.headers.insert(key.to_string(), value);
}
self
}
pub fn claims_mut(&mut self) -> &mut C {
&mut self.claims
}
pub fn encode<S: Signer>(mut self, signer: &S) -> Result<String, JwtError> {
self.headers
.insert("alg".to_string(), Value::String(signer.name().to_string()));
let header_json =
serde_json::to_string(&self.headers).map_err(|_| JwtError::SerializationError)?;
let claims_json =
serde_json::to_string(&self.claims).map_err(|_| JwtError::SerializationError)?;
let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json.as_bytes());
let signing_input = format!("{}.{}", header_b64, claims_b64);
let signature = signer.sign(signing_input.as_bytes())?;
let signature_b64 = URL_SAFE_NO_PAD.encode(&signature);
Ok(format!("{}.{}.{}", header_b64, claims_b64, signature_b64))
}
}
impl<C> Jwt<Decoded, C>
where
    C: for<'de> Deserialize<'de>,
{
    /// Parses a compact-serialized token (`header.claims.signature`) and verifies it.
    ///
    /// Steps, in order:
    /// 1. split into exactly three `.`-separated segments and decode each as unpadded
    ///    URL-safe base64;
    /// 2. parse the header as a JSON object and require its `alg` to be a string equal to
    ///    `verifier.name()`;
    /// 3. verify the signature over the raw `header.claims` text of `token`;
    /// 4. only then deserialize the claims into `C`.
    ///
    /// Deserializing the claims last means an unauthenticated payload is never handed to
    /// `C`'s `Deserialize` impl. It also means a tampered token is reported as
    /// [`JwtError::InvalidSignature`] rather than as a parse error.
    ///
    /// # Errors
    ///
    /// * [`JwtError::InvalidFormat`] — the token does not have exactly three segments, or a
    ///   segment is not valid unpadded URL-safe base64.
    /// * [`JwtError::SerializationError`] — the header is not a JSON object, or the verified
    ///   claims do not deserialize into `C`.
    /// * [`JwtError::InvalidAlgorithm`] — `alg` is missing, not a string, or does not match
    ///   `verifier.name()`.
    /// * [`JwtError::InvalidSignature`] — `verifier.verify` returned `false`.
    /// * Any error returned by `verifier.verify` itself is propagated unchanged.
    pub fn decode<V: Verifier>(token: &str, verifier: &V) -> Result<Jwt<Decoded, C>, JwtError> {
        // Walk the segments without allocating; a fourth segment means too many dots.
        let mut segments = token.split('.');
        let (header_b64, claims_b64, signature_b64) = match (
            segments.next(),
            segments.next(),
            segments.next(),
            segments.next(),
        ) {
            (Some(header), Some(claims), Some(signature), None) => (header, claims, signature),
            _ => return Err(JwtError::InvalidFormat),
        };

        // Decode all three segments up front so malformed base64 is reported as a format
        // error before any JSON parsing or signature check is attempted.
        let decode_segment =
            |segment: &str| URL_SAFE_NO_PAD.decode(segment).map_err(|_| JwtError::InvalidFormat);
        let header_bytes = decode_segment(header_b64)?;
        let claims_bytes = decode_segment(claims_b64)?;
        let signature = decode_segment(signature_b64)?;

        let headers: Headers =
            serde_json::from_slice(&header_bytes).map_err(|_| JwtError::SerializationError)?;

        // The verifier dictates the algorithm. A token cannot select its own (e.g. `none`),
        // because a missing or mismatching `alg` is rejected here.
        match headers.get("alg") {
            Some(Value::String(alg)) if alg == verifier.name() => {}
            _ => return Err(JwtError::InvalidAlgorithm),
        }

        // The signature covers the exact encoded text `header.claims`. That text is the
        // prefix of `token` ending just before the second dot, so borrow it rather than
        // re-formatting it.
        let signing_input = &token[..header_b64.len() + 1 + claims_b64.len()];
        if !verifier.verify(signing_input.as_bytes(), &signature)? {
            return Err(JwtError::InvalidSignature);
        }

        // Only now that the payload is authenticated is it deserialized.
        let claims: C =
            serde_json::from_slice(&claims_bytes).map_err(|_| JwtError::SerializationError)?;

        Ok(Jwt {
            headers,
            claims,
            _state: std::marker::PhantomData,
        })
    }

    /// Returns all header parameters of the token.
    pub fn headers(&self) -> &Headers {
        &self.headers
    }

    /// Returns the claims carried by the token.
    ///
    /// For tokens produced by `decode`, these have passed the `alg` and signature checks.
    pub fn claims(&self) -> &C {
        &self.claims
    }

    /// Looks up a single header parameter by name; `None` if it is absent.
    pub fn header(&self, key: &str) -> Option<&Value> {
        self.headers.get(key)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Imported explicitly so `URL_SAFE_NO_PAD.encode` works in the token-forging helpers.
    use base64::Engine as _;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct TestClaims {
        sub: String,
        name: String,
        iat: u64,
        admin: Option<bool>,
    }

    /// Builds a `TestClaims` fixture; only `sub` and `admin` vary between tests.
    fn test_claims(sub: &str, admin: Option<bool>) -> TestClaims {
        TestClaims {
            sub: sub.to_string(),
            name: "Test User".to_string(),
            iat: 1234567890,
            admin,
        }
    }

    /// Serializes `value` to JSON and encodes it as an unpadded URL-safe base64 segment.
    fn json_segment<T: Serialize>(value: &T) -> String {
        URL_SAFE_NO_PAD.encode(serde_json::to_vec(value).unwrap())
    }

    /// Assembles a token from a raw JSON header, claims and signature bytes. This bypasses
    /// `Jwt::encode`, so tests can craft headers it would never produce.
    fn raw_token(header: &Value, claims: &TestClaims, signature: &[u8]) -> String {
        format!(
            "{}.{}.{}",
            json_segment(header),
            json_segment(claims),
            URL_SAFE_NO_PAD.encode(signature)
        )
    }

    #[test]
    fn test_new() {
        let jwt = Jwt::new(test_claims("user123", Some(true)));
        assert_eq!(jwt.claims.sub, "user123");
        assert_eq!(jwt.claims.name, "Test User");
        assert_eq!(jwt.claims.iat, 1234567890);
        assert_eq!(jwt.claims.admin, Some(true));
    }

    #[test]
    fn test_claims_mut() {
        let mut jwt = Jwt::new(test_claims("user123", Some(false)));
        jwt.claims_mut().admin = Some(true);
        assert_eq!(jwt.claims.admin, Some(true));
    }

    #[test]
    fn test_header_access() {
        let jwt = Jwt::<Encoder, TestClaims>::new(test_claims("user123", None))
            .header("custom", "header_value");
        let decoded_jwt = Jwt::<Decoded, TestClaims> {
            headers: jwt.headers,
            claims: jwt.claims,
            _state: std::marker::PhantomData,
        };
        assert_eq!(
            decoded_jwt.header("typ"),
            Some(&Value::String("JWT".to_string()))
        );
        assert_eq!(
            decoded_jwt.header("custom"),
            Some(&Value::String("header_value".to_string()))
        );
        assert_eq!(decoded_jwt.header("nonexistent"), None);
    }

    #[test]
    fn test_headers_and_claims_getters() {
        let mut headers = Map::new();
        headers.insert("typ".to_string(), Value::String("JWT".to_string()));
        let jwt = Jwt::<Decoded, TestClaims> {
            headers,
            claims: test_claims("test", None),
            _state: std::marker::PhantomData,
        };
        assert_eq!(
            jwt.headers().get("typ"),
            Some(&Value::String("JWT".to_string()))
        );
        assert_eq!(jwt.claims().sub, "test");
    }

    #[test]
    fn test_hs256_encode_decode() {
        let secret = random_secret();
        let algorithm = HS256::new(&secret);
        let claims = test_claims("1234567890", Some(true));
        let token = Jwt::<Encoder, TestClaims>::new(claims.clone())
            .encode(&algorithm)
            .unwrap();
        let decoded = Jwt::<Decoded, TestClaims>::decode(&token, &algorithm).unwrap();
        assert_eq!(decoded.claims(), &claims);
    }

    #[test]
    fn test_invalid_signature() {
        let algorithm = HS256::new(b"256-bit-secret");
        let wrong_algorithm = HS256::new(b"wrong-secret");
        let token = Jwt::<Encoder, TestClaims>::new(test_claims("1234567890", None))
            .encode(&algorithm)
            .unwrap();
        let result = Jwt::<Decoded, TestClaims>::decode(&token, &wrong_algorithm);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_claims_rejected() {
        let secret = random_secret();
        let algorithm = HS256::new(&secret);
        let token = Jwt::<Encoder, TestClaims>::new(test_claims("user123", Some(false)))
            .encode(&algorithm)
            .unwrap();
        let parts: Vec<&str> = token.split('.').collect();
        // Swap in a well-formed but different payload while keeping the original signature.
        let forged = format!(
            "{}.{}.{}",
            parts[0],
            json_segment(&test_claims("user123", Some(true))),
            parts[2]
        );
        let result = Jwt::<Decoded, TestClaims>::decode(&forged, &algorithm);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_overrides_custom_alg() {
        let algorithm = HS256::new(b"secret");
        let claims = test_claims("user123", None);
        let token = Jwt::<Encoder, TestClaims>::new(claims.clone())
            .header("alg", "none")
            .encode(&algorithm)
            .unwrap();
        // Decoding only succeeds if `alg` matches the verifier, i.e. `encode` replaced "none".
        let decoded = Jwt::<Decoded, TestClaims>::decode(&token, &algorithm).unwrap();
        assert_eq!(decoded.claims(), &claims);
    }

    #[test]
    fn test_missing_alg_rejected() {
        let algorithm = HS256::new(b"secret");
        let token = raw_token(
            &serde_json::json!({ "typ": "JWT" }),
            &test_claims("user123", None),
            b"sig",
        );
        let result = Jwt::<Decoded, TestClaims>::decode(&token, &algorithm);
        assert!(matches!(result, Err(JwtError::InvalidAlgorithm)));
    }

    #[test]
    fn test_alg_none_rejected() {
        let algorithm = HS256::new(b"secret");
        let token = raw_token(
            &serde_json::json!({ "alg": "none", "typ": "JWT" }),
            &test_claims("user123", Some(true)),
            b"",
        );
        let result = Jwt::<Decoded, TestClaims>::decode(&token, &algorithm);
        assert!(matches!(result, Err(JwtError::InvalidAlgorithm)));
    }

    #[test]
    fn test_rs256_encode_decode() {
        #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
        struct RS256Claims {
            sub: String,
            admin: bool,
            roles: Vec<String>,
        }
        let (private_key, public_key) = rsa_keypair().unwrap();
        let signer = RS256Signer::new(private_key);
        let verifier = RS256Verifier::new(public_key);
        let claims = RS256Claims {
            sub: "test-user".to_string(),
            admin: true,
            roles: vec!["editor".to_string(), "viewer".to_string()],
        };
        let token = Jwt::<Encoder, RS256Claims>::new(claims.clone())
            .encode(&signer)
            .unwrap();
        let decoded = Jwt::<Decoded, RS256Claims>::decode(&token, &verifier).unwrap();
        assert_eq!(decoded.claims(), &claims);
        let decoded_with_signer = Jwt::<Decoded, RS256Claims>::decode(&token, &signer).unwrap();
        assert_eq!(decoded.claims(), decoded_with_signer.claims());
    }

    #[test]
    fn test_rs256_invalid_signature() {
        let (private_key_signer, _) = rsa_keypair().unwrap();
        let signer = RS256Signer::new(private_key_signer);
        let (_, public_key_verifier) = rsa_keypair().unwrap();
        let wrong_verifier = RS256Verifier::new(public_key_verifier);
        let token = Jwt::<Encoder, TestClaims>::new(test_claims("some-user", None))
            .encode(&signer)
            .unwrap();
        let result = Jwt::<Decoded, TestClaims>::decode(&token, &wrong_verifier);
        assert!(matches!(result, Err(JwtError::InvalidSignature)));
    }

    #[test]
    fn test_invalid_format() {
        let algorithm = HS256::new(b"secret");
        let result = Jwt::<Decoded, TestClaims>::decode("invalid", &algorithm);
        assert!(matches!(result, Err(JwtError::InvalidFormat)));
        let result = Jwt::<Decoded, TestClaims>::decode("too.many.parts.here", &algorithm);
        assert!(matches!(result, Err(JwtError::InvalidFormat)));
        let result =
            Jwt::<Decoded, TestClaims>::decode("invalid_base64.claims.signature", &algorithm);
        assert!(matches!(result, Err(JwtError::InvalidFormat)));
    }

    #[test]
    fn test_invalid_algorithm() {
        let hs_algorithm = HS256::new(b"secret");
        let (_, public_key) = rsa_keypair().unwrap();
        let rs_algorithm = RS256Verifier::new(public_key);
        let token = Jwt::<Encoder, TestClaims>::new(test_claims("test", None))
            .encode(&hs_algorithm)
            .unwrap();
        let result = Jwt::<Decoded, TestClaims>::decode(&token, &rs_algorithm);
        assert!(matches!(result, Err(JwtError::InvalidAlgorithm)));
    }
}