use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
use base64::Engine;
use rand::Rng;
use serde::{Deserialize, Serialize};
use subtle::ConstantTimeEq;
use super::operator_console::session::{sign, SessionError};
pub use super::operator_console::SessionSecret;
/// Default lifetime of a handoff token, in seconds (one minute).
pub const HANDOFF_TTL_SECS: i64 = 60;
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum HandoffError {
#[error("handoff token malformed")]
Malformed,
#[error("handoff token signature invalid")]
BadSignature,
#[error("handoff token expired")]
Expired,
#[error("handoff token bound to a different tenant")]
WrongTenant,
#[error("handoff token already used")]
AlreadyUsed,
}
impl From<SessionError> for HandoffError {
fn from(e: SessionError) -> Self {
match e {
SessionError::Malformed => Self::Malformed,
SessionError::BadSignature => Self::BadSignature,
SessionError::Expired => Self::Expired,
SessionError::WrongTenant => Self::WrongTenant,
}
}
}
/// Claims carried inside a handoff token.
///
/// The field names double as the JSON keys on the wire, so renaming any of
/// them changes the token format.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct HandoffPayload {
    /// Operator id the token was minted for.
    pub op: i64,
    /// Tenant slug; [`decode`] rejects the token unless it equals the slug
    /// the caller expects.
    pub slug: String,
    /// Expiry as a Unix timestamp in seconds; the token counts as expired
    /// once the current time is `>= exp`.
    pub exp: i64,
    /// Token id used for replay tracking. `HandoffPayload::new` fills it with
    /// 16 random bytes encoded as unpadded URL-safe base64.
    pub jti: String,
}
impl HandoffPayload {
    /// Builds a payload for operator `op_id` on tenant `slug` that expires
    /// `ttl_secs` seconds from now and carries a fresh random `jti`.
    ///
    /// A zero or negative `ttl_secs` yields a payload that is already
    /// expired. The expiry saturates at the `i64` bounds instead of
    /// overflowing, so an absurdly large TTL means "never expires" rather
    /// than panicking (debug) or wrapping into the past (release).
    #[must_use]
    pub fn new(op_id: i64, slug: impl Into<String>, ttl_secs: i64) -> Self {
        let now = chrono::Utc::now().timestamp();
        // 128 random bits per jti keeps accidental collisions negligible.
        let mut bytes = [0u8; 16];
        rand::thread_rng().fill(&mut bytes[..]);
        let jti = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes);
        Self {
            op: op_id,
            slug: slug.into(),
            exp: now.saturating_add(ttl_secs),
            jti,
        }
    }

    /// Returns `true` once the current Unix time has reached `exp`. The
    /// expiry instant itself counts as expired.
    fn is_expired(&self) -> bool {
        chrono::Utc::now().timestamp() >= self.exp
    }
}
/// Serialises `payload` to JSON and returns a signed token of the form
/// `<base64url(json)>.<base64url(signature)>`, using unpadded URL-safe base64.
///
/// The signature is computed over the base64 text of the payload, which is
/// exactly what [`decode`] re-signs when verifying.
///
/// # Panics
/// Panics only if `payload` fails to serialise to JSON. The `expect` treats
/// that as impossible for this struct.
#[must_use]
pub fn mint(secret: &SessionSecret, payload: &HandoffPayload) -> String {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD as B64;

    let body = B64.encode(serde_json::to_vec(payload).expect("payload serializes"));
    let signature = B64.encode(sign(secret, body.as_bytes()));
    format!("{body}.{signature}")
}
/// Verifies a token produced by [`mint`] and returns its payload.
///
/// Checks run in this order:
/// 1. The token must split on its first `.`.
/// 2. The signature part must be valid base64.
/// 3. The signature must match `sign(secret, payload_part)` (constant-time
///    comparison).
/// 4. The payload part must be valid base64 and valid JSON.
/// 5. The payload must not be expired.
/// 6. The payload slug must equal `expected_slug`.
///
/// The signature is verified before the payload is parsed, so
/// unauthenticated bytes never reach the JSON parser.
///
/// This function does not consult [`JtiBlacklist`]. Callers that need
/// single-use tokens must record the returned `jti` themselves.
///
/// # Errors
/// - [`HandoffError::Malformed`]: no `.` separator, invalid base64 in either
///   part, or invalid payload JSON.
/// - [`HandoffError::BadSignature`]: signature mismatch.
/// - [`HandoffError::Expired`]: the payload's `exp` has been reached.
/// - [`HandoffError::WrongTenant`]: the payload's `slug` differs from
///   `expected_slug`.
pub fn decode(
    secret: &SessionSecret,
    expected_slug: &str,
    value: &str,
) -> Result<HandoffPayload, HandoffError> {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD as B64;

    let Some((body, signature)) = value.split_once('.') else {
        return Err(HandoffError::Malformed);
    };

    let recomputed = sign(secret, body.as_bytes());
    let presented = B64.decode(signature).map_err(|_| HandoffError::Malformed)?;
    // Constant-time so response timing does not reveal how many leading
    // signature bytes were correct.
    let signature_matches = recomputed.ct_eq(&presented[..]).unwrap_u8() != 0;
    if !signature_matches {
        return Err(HandoffError::BadSignature);
    }

    let json = B64.decode(body).map_err(|_| HandoffError::Malformed)?;
    let payload: HandoffPayload =
        serde_json::from_slice(&json).map_err(|_| HandoffError::Malformed)?;

    if payload.is_expired() {
        Err(HandoffError::Expired)
    } else if payload.slug != expected_slug {
        Err(HandoffError::WrongTenant)
    } else {
        Ok(payload)
    }
}
/// In-memory record of handoff-token `jti`s that have already been consumed.
///
/// Each entry is stored with its token's `exp` and becomes irrelevant once
/// that instant passes, because [`decode`] rejects an expired token before
/// its `jti` could be replayed. The state lives only in this process's
/// memory. It is not shared with other processes and does not survive a
/// restart.
pub struct JtiBlacklist {
    /// `jti` -> expiry (Unix seconds) of the token that carried it.
    inner: Mutex<HashMap<String, i64>>,
}

impl JtiBlacklist {
    /// Creates an empty blacklist.
    fn new() -> Self {
        Self {
            inner: Mutex::new(HashMap::new()),
        }
    }

    /// Returns the lazily-initialised, process-wide instance.
    pub fn shared() -> &'static Self {
        static INSTANCE: OnceLock<JtiBlacklist> = OnceLock::new();
        INSTANCE.get_or_init(Self::new)
    }

    /// Reports whether `jti` is currently recorded as used.
    ///
    /// Entries whose expiry has already passed are treated as absent. This
    /// matches [`Self::mark_used`], which prunes such entries before checking
    /// for duplicates. Previously this method reported a stale entry as used
    /// even though `mark_used` would accept the same `jti`.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn is_used(&self, jti: &str) -> bool {
        let now = chrono::Utc::now().timestamp();
        let map = self.inner.lock().expect("jti blacklist not poisoned");
        map.get(jti).is_some_and(|&exp| exp > now)
    }

    /// Records `jti` as used until `exp` (Unix seconds).
    ///
    /// Entries that have already expired are pruned first, at a cost linear
    /// in the map size. Pruning, the duplicate check and the insert all run
    /// under one lock acquisition. Concurrent callers in this process
    /// presenting the same `jti` therefore see exactly one `Ok`.
    ///
    /// # Errors
    /// [`HandoffError::AlreadyUsed`] if an unexpired entry for `jti` exists.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn mark_used(&self, jti: &str, exp: i64) -> Result<(), HandoffError> {
        let mut map = self.inner.lock().expect("jti blacklist not poisoned");
        let now = chrono::Utc::now().timestamp();
        map.retain(|_, &mut e| e > now);
        if map.contains_key(jti) {
            return Err(HandoffError::AlreadyUsed);
        }
        map.insert(jti.to_owned(), exp);
        Ok(())
    }

    /// Number of stored entries, including stale ones not yet pruned.
    #[cfg(test)]
    fn len(&self) -> usize {
        self.inner.lock().unwrap().len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed secret shared by the tests.
    fn key() -> SessionSecret {
        SessionSecret::from_bytes(b"a-test-secret-thirty-two-bytes-x".to_vec())
    }

    /// Splits a minted token into its base64 payload and signature parts.
    fn split(token: &str) -> (&str, &str) {
        token
            .split_once('.')
            .expect("minted token contains a '.' separator")
    }

    /// Decodes unpadded URL-safe base64, the token's encoding.
    fn b64_decode(text: &str) -> Vec<u8> {
        base64::Engine::decode(&base64::engine::general_purpose::URL_SAFE_NO_PAD, text)
            .expect("test input is valid base64")
    }

    /// Encodes bytes as unpadded URL-safe base64, the token's encoding.
    fn b64_encode(bytes: &[u8]) -> String {
        base64::Engine::encode(&base64::engine::general_purpose::URL_SAFE_NO_PAD, bytes)
    }

    #[test]
    fn round_trip_valid_payload() {
        let secret = key();
        let payload = HandoffPayload::new(7, "acme", 60);
        let token = mint(&secret, &payload);
        let back = decode(&secret, "acme", &token).unwrap();
        assert_eq!(back, payload);
    }

    #[test]
    fn rejects_token_minted_for_a_different_tenant() {
        let secret = key();
        let payload = HandoffPayload::new(7, "acme", 60);
        let token = mint(&secret, &payload);
        assert_eq!(
            decode(&secret, "globex", &token).unwrap_err(),
            HandoffError::WrongTenant,
        );
    }

    #[test]
    fn rejects_tampered_signature() {
        let secret = key();
        let token = mint(&secret, &HandoffPayload::new(7, "acme", 60));
        let (payload_b64, sig_b64) = split(&token);
        // Tamper with the decoded bytes, not the base64 text. When the
        // signature length is not a multiple of 3, the final character's low
        // bits are unused. Flipping it 'A' <-> 'B' can therefore hit an
        // unused bit, which the strict decoder reports as Malformed instead
        // of BadSignature, making a text-level edit flaky.
        let mut sig = b64_decode(sig_b64);
        sig[0] ^= 0x01;
        let tampered = format!("{payload_b64}.{}", b64_encode(&sig));
        assert_eq!(
            decode(&secret, "acme", &tampered).unwrap_err(),
            HandoffError::BadSignature,
        );
    }

    #[test]
    fn rejects_payload_paired_with_another_tokens_signature() {
        let secret = key();
        let genuine = mint(&secret, &HandoffPayload::new(7, "acme", 60));
        let other = mint(&secret, &HandoffPayload::new(8, "acme", 60));
        let (_, genuine_sig) = split(&genuine);
        let (other_payload, _) = split(&other);
        let spliced = format!("{other_payload}.{genuine_sig}");
        assert_eq!(
            decode(&secret, "acme", &spliced).unwrap_err(),
            HandoffError::BadSignature,
        );
    }

    #[test]
    fn rejects_token_signed_with_a_different_secret() {
        let s1 = key();
        let s2 = SessionSecret::from_bytes(b"b-other-secret-thirty-two-bytes-x".to_vec());
        let token = mint(&s1, &HandoffPayload::new(7, "acme", 60));
        assert_eq!(
            decode(&s2, "acme", &token).unwrap_err(),
            HandoffError::BadSignature,
        );
    }

    #[test]
    fn rejects_expired_token() {
        let secret = key();
        let token = mint(&secret, &HandoffPayload::new(7, "acme", -10));
        assert_eq!(
            decode(&secret, "acme", &token).unwrap_err(),
            HandoffError::Expired,
        );
    }

    #[test]
    fn rejects_malformed_token() {
        let secret = key();
        assert_eq!(
            decode(&secret, "acme", "not-a-token").unwrap_err(),
            HandoffError::Malformed,
        );
        assert_eq!(
            decode(&secret, "acme", "abc.!!!").unwrap_err(),
            HandoffError::Malformed,
        );
    }

    #[test]
    fn jtis_are_unique_across_mints() {
        let p1 = HandoffPayload::new(7, "acme", 60);
        let p2 = HandoffPayload::new(7, "acme", 60);
        assert_ne!(p1.jti, p2.jti, "random jti collision is unacceptable");
    }

    #[test]
    fn jti_blacklist_first_use_succeeds_second_fails() {
        let bl = JtiBlacklist::new();
        let jti = "abc123";
        let exp = chrono::Utc::now().timestamp() + 60;
        bl.mark_used(jti, exp).unwrap();
        assert!(bl.is_used(jti));
        assert_eq!(
            bl.mark_used(jti, exp).unwrap_err(),
            HandoffError::AlreadyUsed
        );
    }

    #[test]
    fn jti_blacklist_prunes_expired_entries_on_insert() {
        let bl = JtiBlacklist::new();
        let now = chrono::Utc::now().timestamp();
        bl.inner.lock().unwrap().insert("stale".into(), now - 60);
        assert_eq!(bl.len(), 1);
        bl.mark_used("fresh", now + 60).unwrap();
        assert!(!bl.is_used("stale"));
        assert!(bl.is_used("fresh"));
        assert_eq!(bl.len(), 1);
    }

    #[test]
    fn handoff_ttl_default_is_one_minute() {
        assert_eq!(HANDOFF_TTL_SECS, 60);
    }
}