use axon_csys::{ContinuityWire, ContinuityWireError};
use chrono::{DateTime, Duration as ChronoDuration, Utc};
/// Reasons a raw continuity token can be rejected by
/// `ContinuityTokenSigner::verify`.
#[derive(Debug)]
pub enum ContinuityTokenError {
    /// The token could not be decoded: any wire-level failure other than an
    /// HMAC mismatch, or an embedded expiry that chrono cannot represent.
    /// The payload is a human-readable description of the problem.
    Malformed(String),
    /// HMAC verification failed: the token was altered, or it was signed with
    /// a different (for example, rotated) key.
    ForgedOrRotated,
    /// The token is authentic, but its expiry instant is at or before the
    /// time of verification.
    Expired { expired_at: DateTime<Utc> },
}
impl std::fmt::Display for ContinuityTokenError {
    /// Writes a single-line, human-readable description of why the token was
    /// rejected. Formatter width/fill flags are ignored, as with `write!`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ForgedOrRotated => f.write_str(
                "continuity token failed HMAC verification (forged or signer key rotated)",
            ),
            Self::Malformed(detail) => write!(f, "continuity token malformed: {}", detail),
            Self::Expired { expired_at: when } => {
                write!(f, "continuity token expired at {}", when)
            }
        }
    }
}
impl std::error::Error for ContinuityTokenError {}
/// Maps wire-layer failures onto token errors.
///
/// An HMAC mismatch keeps its identity as `ForgedOrRotated`. Every other wire
/// error is reported as `Malformed`, carrying the wire error's `Display` text.
impl From<ContinuityWireError> for ContinuityTokenError {
    fn from(err: ContinuityWireError) -> Self {
        // Matching a field-less variant binds nothing, so `err` is not moved
        // and is still available for the fallback below.
        if let ContinuityWireError::ForgedOrRotated = err {
            return Self::ForgedOrRotated;
        }
        Self::Malformed(err.to_string())
    }
}
/// A session continuity claim: the session it refers to and the instant at
/// which it stops being accepted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContinuityToken {
    /// Session identifier carried inside the signed token.
    ///
    /// `ContinuityTokenSigner::sign` panics if this contains the byte `0x1e`
    /// or is longer than 1024 bytes.
    pub session_id: String,
    /// Expiry instant (UTC). `ContinuityTokenSigner::verify` rejects the token
    /// once the current time reaches this value. Only millisecond precision
    /// survives a sign/verify round trip.
    pub expires_at: DateTime<Utc>,
}
impl ContinuityToken {
    /// Builds a token for `session_id` that expires `ttl` after the current
    /// UTC wall-clock time.
    ///
    /// A negative `ttl` produces a token that is already expired.
    ///
    /// # Panics
    ///
    /// Panics if `Utc::now() + ttl` falls outside chrono's representable
    /// range, because chrono's `DateTime + Duration` panics on overflow.
    pub fn new(session_id: impl Into<String>, ttl: ChronoDuration) -> Self {
        let session_id = session_id.into();
        let expires_at = Utc::now() + ttl;
        Self {
            session_id,
            expires_at,
        }
    }
}
/// Signs and verifies continuity tokens with a secret HMAC key.
///
/// `Debug` is implemented by hand rather than derived. A derived impl would
/// print the raw key bytes wherever the signer (or any struct containing it)
/// is formatted with `{:?}`, for example in logs or panic messages.
#[derive(Clone)]
pub struct ContinuityTokenSigner {
    key: Vec<u8>,
}

impl std::fmt::Debug for ContinuityTokenSigner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never expose key material; `format_args!` renders without quotes.
        f.debug_struct("ContinuityTokenSigner")
            .field("key", &format_args!("<redacted>"))
            .finish()
    }
}
impl ContinuityTokenSigner {
    /// Creates a signer that uses `key` as the HMAC secret for both
    /// [`sign`](Self::sign) and [`verify`](Self::verify).
    pub fn new(key: impl Into<Vec<u8>>) -> Self {
        Self { key: key.into() }
    }

    /// Encodes `token` into its signed wire form.
    ///
    /// The expiry is encoded as Unix milliseconds, so any sub-millisecond
    /// part of `token.expires_at` is lost.
    ///
    /// # Panics
    ///
    /// Panics if `ContinuityWire::sign` rejects the input. The `expect`
    /// message gives the requirements: `session_id` must not contain `0x1e`
    /// and must be at most 1024 bytes.
    pub fn sign(&self, token: &ContinuityToken) -> String {
        let ContinuityToken {
            session_id,
            expires_at,
        } = token;
        let expiry_millis = expires_at.timestamp_millis();
        ContinuityWire::sign(&self.key, session_id, expiry_millis)
            .expect("ContinuityToken.session_id must not contain 0x1e and must be ≤ 1024 bytes")
    }

    /// Checks the HMAC on `raw` and decodes it back into a [`ContinuityToken`].
    ///
    /// # Errors
    ///
    /// - [`ContinuityTokenError::ForgedOrRotated`] if the HMAC does not match
    ///   this signer's key.
    /// - [`ContinuityTokenError::Malformed`] for any other wire-level decode
    ///   failure, or if the embedded expiry is outside chrono's range.
    /// - [`ContinuityTokenError::Expired`] if the token is authentic but its
    ///   expiry is at or before the current time.
    pub fn verify(&self, raw: &str) -> Result<ContinuityToken, ContinuityTokenError> {
        // The wire layer is the source of `ForgedOrRotated`, so the expiry is
        // only examined after that check has passed.
        let (session_id, expiry_millis) = ContinuityWire::verify(&self.key, raw)?;

        let expires_at = match DateTime::<Utc>::from_timestamp_millis(expiry_millis) {
            Some(instant) => instant,
            None => {
                return Err(ContinuityTokenError::Malformed(String::from(
                    "expiry timestamp out of range",
                )));
            }
        };

        // An expiry exactly equal to "now" already counts as expired.
        if Utc::now() >= expires_at {
            return Err(ContinuityTokenError::Expired {
                expired_at: expires_at,
            });
        }

        Ok(ContinuityToken {
            session_id,
            expires_at,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly signed token verifies under the same key and keeps its
    /// session id and (millisecond-truncated) expiry.
    #[test]
    fn sign_verify_roundtrip() {
        let signer = ContinuityTokenSigner::new([7u8; 32]);
        let token = ContinuityToken::new("sess-1", ChronoDuration::minutes(15));
        let wire = signer.sign(&token);
        let decoded = signer.verify(&wire).expect("verify");
        assert_eq!(decoded.session_id, "sess-1");
        assert_eq!(
            decoded.expires_at.timestamp_millis(),
            token.expires_at.timestamp_millis()
        );
    }

    /// Rewriting the session id inside the decoded payload must fail HMAC
    /// verification.
    #[test]
    fn verify_rejects_tampered_session_id() {
        use axon_csys::{b64url_decode, b64url_encode};
        let signer = ContinuityTokenSigner::new([7u8; 32]);
        let token = ContinuityToken::new("sess-a", ChronoDuration::minutes(15));
        let wire = signer.sign(&token);
        let decoded_bytes = b64url_decode(&wire).unwrap();
        let text = std::str::from_utf8(&decoded_bytes).unwrap();
        let tampered = text.replacen("sess-a", "sess-b", 1);
        // This test assumes the session id appears in plaintext in the decoded
        // payload. If the wire format stops doing that, fail with a clear
        // message rather than a confusing `unwrap_err` on an `Ok`.
        assert_ne!(tampered, text, "decoded payload should contain the session id");
        let tampered_wire = b64url_encode(tampered.as_bytes());
        let err = signer.verify(&tampered_wire).unwrap_err();
        assert!(matches!(err, ContinuityTokenError::ForgedOrRotated));
    }

    /// A token signed with one key must not verify under another.
    #[test]
    fn verify_rejects_different_signer_key() {
        let s1 = ContinuityTokenSigner::new([1u8; 32]);
        let s2 = ContinuityTokenSigner::new([2u8; 32]);
        let token = ContinuityToken::new("sess-1", ChronoDuration::minutes(15));
        let wire = s1.sign(&token);
        let err = s2.verify(&wire).unwrap_err();
        assert!(matches!(err, ContinuityTokenError::ForgedOrRotated));
    }

    /// An authentic token whose expiry is in the past is reported as expired.
    #[test]
    fn verify_rejects_expired_token() {
        let signer = ContinuityTokenSigner::new([7u8; 32]);
        let token = ContinuityToken::new("sess-1", ChronoDuration::seconds(-1));
        let wire = signer.sign(&token);
        let err = signer.verify(&wire).unwrap_err();
        assert!(matches!(err, ContinuityTokenError::Expired { .. }));
    }

    /// Input that is not valid base64url is reported as malformed.
    #[test]
    fn verify_rejects_malformed_base64() {
        let signer = ContinuityTokenSigner::new([7u8; 32]);
        let err = signer.verify("not-valid-base64!@#").unwrap_err();
        assert!(matches!(err, ContinuityTokenError::Malformed(_)));
    }

    /// A payload with too few `0x1e`-separated fields is reported as malformed.
    #[test]
    fn verify_rejects_wrong_field_count() {
        use axon_csys::b64url_encode;
        let signer = ContinuityTokenSigner::new([7u8; 32]);
        let bad = b64url_encode(b"sess-1\x1e9999");
        let err = signer.verify(&bad).unwrap_err();
        assert!(matches!(err, ContinuityTokenError::Malformed(_)));
    }

    /// Changing the final character of the decoded payload must be rejected as
    /// forged. Presumably the signature sits at the end of the payload; TODO
    /// confirm against the wire format.
    ///
    /// This was formerly named `hmac_uses_constant_time_compare`, but a unit
    /// test cannot observe timing. It only checks that a one-character change
    /// is detected.
    #[test]
    fn verify_rejects_tampered_signature() {
        use axon_csys::{b64url_decode, b64url_encode};
        let signer = ContinuityTokenSigner::new([7u8; 32]);
        let token = ContinuityToken::new("sess-1", ChronoDuration::minutes(5));
        let wire_good = signer.sign(&token);
        let decoded = b64url_decode(&wire_good).unwrap();
        let mut text = std::str::from_utf8(&decoded).unwrap().to_string();
        // `pop`/`push` operate on whole chars, so this cannot split a
        // multi-byte character the way byte-offset `replace_range` could.
        let last = text.pop().expect("decoded payload must not be empty");
        text.push(if last == 'a' { 'b' } else { 'a' });
        let wire_bad = b64url_encode(text.as_bytes());
        let err = signer.verify(&wire_bad).unwrap_err();
        assert!(matches!(err, ContinuityTokenError::ForgedOrRotated));
    }
}