use std::sync::Arc;
use std::time::Duration;
use serde::Serialize;
use crate::encoding::base64url;
use crate::{Error, Result};
use super::config::JwtSessionsConfig;
use super::error::JwtError;
use super::signer::{HmacSigner, TokenSigner};
/// Produces signed JWTs (`header.payload.signature`) from serializable claims.
///
/// All state lives behind an [`Arc`], so cloning an encoder only bumps a
/// reference count and every clone shares the same signer and settings.
pub struct JwtEncoder {
    // Shared, immutable encoder state.
    inner: Arc<JwtEncoderInner>,
}
/// Immutable state shared by every clone of a [`JwtEncoder`].
struct JwtEncoderInner {
    /// Lifetime used to fill in `exp` when the serialized claims lack one;
    /// `None` leaves the claims untouched.
    default_expiry: Option<Duration>,
    /// Produces the signature over the `header.payload` signing input and
    /// reports the `alg` value written into the header.
    signer: Arc<dyn TokenSigner>,
    /// Validation settings handed out through `JwtEncoder::validation`.
    validation: super::validation::ValidationConfig,
}
impl JwtEncoder {
pub fn from_config(config: &JwtSessionsConfig) -> Self {
let signer = HmacSigner::new(config.signing_secret.as_bytes());
Self {
inner: Arc::new(JwtEncoderInner {
signer: Arc::new(signer),
default_expiry: Some(Duration::from_secs(config.access_ttl_secs)),
validation: super::validation::ValidationConfig {
leeway: Duration::ZERO,
require_issuer: config.issuer.clone(),
require_audience: None,
},
}),
}
}
pub(super) fn verifier(&self) -> Arc<dyn super::signer::TokenVerifier> {
self.inner.signer.clone() as Arc<dyn super::signer::TokenVerifier>
}
pub(super) fn validation(&self) -> super::validation::ValidationConfig {
self.inner.validation.clone()
}
pub fn encode<T: Serialize>(&self, claims: &T) -> Result<String> {
let claims_json = if let Some(default_exp) = self.inner.default_expiry {
let mut value = serde_json::to_value(claims).map_err(|_| {
Error::internal("failed to serialize token")
.chain(JwtError::SerializationFailed)
.with_code(JwtError::SerializationFailed.code())
})?;
if value.get("exp").is_none() {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("system clock before UNIX epoch")
.as_secs();
value["exp"] = serde_json::Value::Number((now + default_exp.as_secs()).into());
}
serde_json::to_vec(&value)
} else {
serde_json::to_vec(claims)
}
.map_err(|_| {
Error::internal("unauthorized")
.chain(JwtError::SerializationFailed)
.with_code(JwtError::SerializationFailed.code())
})?;
let alg = self.inner.signer.algorithm_name();
let header = format!(r#"{{"alg":"{alg}","typ":"JWT"}}"#);
let header_b64 = base64url::encode(header.as_bytes());
let payload_b64 = base64url::encode(&claims_json);
let header_payload = format!("{header_b64}.{payload_b64}");
let signature = self.inner.signer.sign(header_payload.as_bytes())?;
let signature_b64 = base64url::encode(&signature);
Ok(format!("{header_payload}.{signature_b64}"))
}
}
impl Clone for JwtEncoder {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use super::super::claims::Claims;

    /// Config with a fixed signing secret; every other field uses its default.
    fn test_config() -> JwtSessionsConfig {
        JwtSessionsConfig {
            signing_secret: "test-secret-key-at-least-32-bytes-long!".into(),
            ..JwtSessionsConfig::default()
        }
    }

    /// Current wall-clock time in whole seconds since the UNIX epoch.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("test machine clock is after the UNIX epoch")
            .as_secs()
    }

    /// Base64url-decodes the `index`-th dot-separated segment of `token` and
    /// parses it as JSON (0 = header, 1 = payload).
    fn decode_segment(token: &str, index: usize) -> serde_json::Value {
        let segment = token.split('.').nth(index).expect("token segment missing");
        let bytes = base64url::decode(segment).expect("segment is not valid base64url");
        serde_json::from_slice(&bytes).expect("segment is not valid JSON")
    }

    #[test]
    fn encode_produces_three_part_token() {
        let encoder = JwtEncoder::from_config(&test_config());
        let claims = Claims::new().with_exp(9999999999);
        let token = encoder.encode(&claims).unwrap();
        let parts: Vec<&str> = token.split('.').collect();
        assert_eq!(parts.len(), 3);
        assert!(parts.iter().all(|part| !part.is_empty()));
    }

    #[test]
    fn encode_header_contains_hs256() {
        let encoder = JwtEncoder::from_config(&test_config());
        let claims = Claims::new().with_exp(9999999999);
        let token = encoder.encode(&claims).unwrap();
        let header = decode_segment(&token, 0);
        assert_eq!(header["alg"], "HS256");
        assert_eq!(header["typ"], "JWT");
    }

    #[test]
    fn encode_with_default_expiry_auto_sets_exp() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let before = unix_now();
        let token = encoder.encode(&Claims::new()).unwrap();
        let after = unix_now();
        let exp = decode_segment(&token, 1)["exp"]
            .as_u64()
            .expect("exp should be a non-negative integer");
        // The injected exp must be "now + ttl" for some "now" observed during encode.
        assert!(exp >= before.saturating_add(config.access_ttl_secs));
        assert!(exp <= after.saturating_add(config.access_ttl_secs));
    }

    #[test]
    fn encode_explicit_exp_not_overwritten() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let claims = Claims::new().with_exp(42);
        let token = encoder.encode(&claims).unwrap();
        let payload = decode_segment(&token, 1);
        assert_eq!(payload["exp"], 42);
    }

    #[test]
    fn encode_custom_struct_directly() {
        #[derive(Debug, Clone, Serialize, Deserialize)]
        struct CustomPayload {
            sub: String,
            role: String,
            exp: u64,
        }
        let encoder = JwtEncoder::from_config(&test_config());
        let payload = CustomPayload {
            sub: "user_1".into(),
            role: "admin".into(),
            exp: 9999999999,
        };
        let token = encoder.encode(&payload).unwrap();
        assert_eq!(token.split('.').count(), 3);
    }

    #[test]
    fn clone_produces_working_encoder() {
        let encoder = JwtEncoder::from_config(&test_config());
        let cloned = encoder.clone();
        let claims = Claims::new().with_exp(9999999999);
        assert!(cloned.encode(&claims).is_ok());
    }
}