axon/pem/
continuity_token.rs1use axon_csys::{ContinuityWire, ContinuityWireError};
45use chrono::{DateTime, Duration as ChronoDuration, Utc};
46
47#[derive(Debug)]
50pub enum ContinuityTokenError {
51 Malformed(String),
54 ForgedOrRotated,
57 Expired { expired_at: DateTime<Utc> },
59}
60
61impl std::fmt::Display for ContinuityTokenError {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 match self {
64 Self::Malformed(msg) => {
65 write!(f, "continuity token malformed: {msg}")
66 }
67 Self::ForgedOrRotated => write!(
68 f,
69 "continuity token failed HMAC verification (forged or \
70 signer key rotated)"
71 ),
72 Self::Expired { expired_at } => {
73 write!(f, "continuity token expired at {expired_at}")
74 }
75 }
76 }
77}
78
79impl std::error::Error for ContinuityTokenError {}
80
81impl From<ContinuityWireError> for ContinuityTokenError {
82 fn from(value: ContinuityWireError) -> Self {
83 match value {
84 ContinuityWireError::ForgedOrRotated => Self::ForgedOrRotated,
85 other => Self::Malformed(other.to_string()),
89 }
90 }
91}
92
93#[derive(Debug, Clone, PartialEq, Eq)]
97pub struct ContinuityToken {
98 pub session_id: String,
99 pub expires_at: DateTime<Utc>,
100}
101
102impl ContinuityToken {
103 pub fn new(session_id: impl Into<String>, ttl: ChronoDuration) -> Self {
105 ContinuityToken {
106 session_id: session_id.into(),
107 expires_at: Utc::now() + ttl,
108 }
109 }
110}
111
/// Signs and verifies [`ContinuityToken`]s with an HMAC key.
///
/// The key is secret material, so `Debug` is implemented by hand and never
/// prints the key bytes. A derived `Debug` would leak the key into any log
/// line or panic message that formats the signer with `{:?}`.
#[derive(Clone)]
pub struct ContinuityTokenSigner {
    key: Vec<u8>,
}

impl std::fmt::Debug for ContinuityTokenSigner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ContinuityTokenSigner")
            .field("key", &format_args!("<redacted>"))
            .finish()
    }
}
121
122impl ContinuityTokenSigner {
123 pub fn new(key: impl Into<Vec<u8>>) -> Self {
127 ContinuityTokenSigner { key: key.into() }
128 }
129
130 pub fn sign(&self, token: &ContinuityToken) -> String {
134 let expiry_ms = token.expires_at.timestamp_millis();
135 ContinuityWire::sign(&self.key, &token.session_id, expiry_ms)
144 .expect("ContinuityToken.session_id must not contain 0x1e and must be ≤ 1024 bytes")
145 }
146
147 pub fn verify(
150 &self,
151 raw: &str,
152 ) -> Result<ContinuityToken, ContinuityTokenError> {
153 let (session_id, expiry_ms) = ContinuityWire::verify(&self.key, raw)?;
154 let expires_at =
155 DateTime::<Utc>::from_timestamp_millis(expiry_ms).ok_or_else(|| {
156 ContinuityTokenError::Malformed(
157 "expiry timestamp out of range".into(),
158 )
159 })?;
160 if expires_at <= Utc::now() {
161 return Err(ContinuityTokenError::Expired { expired_at: expires_at });
162 }
163 Ok(ContinuityToken {
164 session_id,
165 expires_at,
166 })
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sign_verify_roundtrip() {
        let signer = ContinuityTokenSigner::new([7u8; 32]);
        let token = ContinuityToken::new("sess-1", ChronoDuration::minutes(15));
        let wire = signer.sign(&token);
        let decoded = signer.verify(&wire).expect("verify");
        assert_eq!(decoded.session_id, "sess-1");
        // The wire format carries whole milliseconds, so compare at that precision.
        assert_eq!(
            decoded.expires_at.timestamp_millis(),
            token.expires_at.timestamp_millis()
        );
    }

    #[test]
    fn verify_rejects_tampered_session_id() {
        use axon_csys::{b64url_decode, b64url_encode};
        let signer = ContinuityTokenSigner::new([7u8; 32]);
        let token = ContinuityToken::new("sess-a", ChronoDuration::minutes(15));
        let wire = signer.sign(&token);
        let decoded_bytes = b64url_decode(&wire).unwrap();
        let text = std::str::from_utf8(&decoded_bytes).unwrap();
        let tampered = text.replacen("sess-a", "sess-b", 1);
        let tampered_wire = b64url_encode(tampered.as_bytes());

        let err = signer.verify(&tampered_wire).unwrap_err();
        assert!(matches!(err, ContinuityTokenError::ForgedOrRotated));
    }

    #[test]
    fn verify_rejects_different_signer_key() {
        let s1 = ContinuityTokenSigner::new([1u8; 32]);
        let s2 = ContinuityTokenSigner::new([2u8; 32]);
        let token = ContinuityToken::new("sess-1", ChronoDuration::minutes(15));
        let wire = s1.sign(&token);
        let err = s2.verify(&wire).unwrap_err();
        assert!(matches!(err, ContinuityTokenError::ForgedOrRotated));
    }

    #[test]
    fn verify_rejects_expired_token() {
        let signer = ContinuityTokenSigner::new([7u8; 32]);
        let token = ContinuityToken::new("sess-1", ChronoDuration::seconds(-1));
        let wire = signer.sign(&token);
        let err = signer.verify(&wire).unwrap_err();
        assert!(matches!(err, ContinuityTokenError::Expired { .. }));
    }

    #[test]
    fn verify_rejects_malformed_base64() {
        let signer = ContinuityTokenSigner::new([7u8; 32]);
        let err = signer.verify("not-valid-base64!@#").unwrap_err();
        assert!(matches!(err, ContinuityTokenError::Malformed(_)));
    }

    #[test]
    fn verify_rejects_wrong_field_count() {
        use axon_csys::b64url_encode;
        let signer = ContinuityTokenSigner::new([7u8; 32]);
        let bad = b64url_encode(b"sess-1\x1e9999");
        let err = signer.verify(&bad).unwrap_err();
        assert!(matches!(err, ContinuityTokenError::Malformed(_)));
    }

    /// Flips the final character of the decoded token and expects the HMAC
    /// check to reject it.
    ///
    /// This does NOT (and cannot) prove the comparison is constant-time; timing
    /// behaviour is not observable from a unit test. It only pins that a
    /// one-character tamper is reported as `ForgedOrRotated`, not `Malformed`.
    #[test]
    fn verify_rejects_tampered_signature() {
        use axon_csys::{b64url_decode, b64url_encode};
        let signer = ContinuityTokenSigner::new([7u8; 32]);
        let token = ContinuityToken::new("sess-1", ChronoDuration::minutes(5));
        let wire_good = signer.sign(&token);

        let decoded = b64url_decode(&wire_good).unwrap();
        let mut text = std::str::from_utf8(&decoded).unwrap().to_string();
        // `pop`/`push` work on whole chars, so this cannot split a multi-byte
        // UTF-8 sequence the way a byte-offset `replace_range(len - 1..)` could.
        let last = text.pop().expect("decoded token is non-empty");
        text.push(if last == 'a' { 'b' } else { 'a' });
        let wire_bad = b64url_encode(text.as_bytes());

        let err = signer.verify(&wire_bad).unwrap_err();
        assert!(matches!(err, ContinuityTokenError::ForgedOrRotated));
    }
}