use std::collections::HashMap;
use std::fmt;
use std::time::{Duration, Instant};
/// Length of a handoff token in bytes (128 bits).
pub const HANDOFF_TOKEN_BYTES: usize = 128 / 8;
/// Default cap on how many tokens a store keeps pending at once.
pub const DEFAULT_MAX_PENDING_HANDOFF_TOKENS: usize = 1_024;
/// Default lifetime of an issued token (30 seconds).
pub const DEFAULT_HANDOFF_TOKEN_TTL: Duration = Duration::new(30, 0);
/// Default number of random draws tried when a fresh token collides with a pending one.
pub const DEFAULT_HANDOFF_TOKEN_COLLISION_ATTEMPTS: usize = 16;
/// An opaque fixed-size handoff secret of [`HANDOFF_TOKEN_BYTES`] bytes.
///
/// `Debug` is deliberately not derived: formatting a token with `{:?}` prints a
/// placeholder rather than the raw bytes.
#[derive(Clone, Copy, Hash, PartialEq, Eq)]
pub struct HandoffToken([u8; HANDOFF_TOKEN_BYTES]);
impl HandoffToken {
pub fn generate() -> Result<Self, HandoffTokenError> {
let mut bytes = [0_u8; HANDOFF_TOKEN_BYTES];
getrandom::fill(&mut bytes)?;
Ok(Self(bytes))
}
pub fn from_bytes(bytes: [u8; HANDOFF_TOKEN_BYTES]) -> Self {
Self(bytes)
}
pub fn as_bytes(&self) -> &[u8; HANDOFF_TOKEN_BYTES] {
&self.0
}
pub fn into_bytes(self) -> [u8; HANDOFF_TOKEN_BYTES] {
self.0
}
}
impl fmt::Debug for HandoffToken {
    /// Emits a fixed placeholder so token bytes never appear in `{:?}` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "HandoffToken(<redacted>)")
    }
}
impl From<[u8; HANDOFF_TOKEN_BYTES]> for HandoffToken {
fn from(value: [u8; HANDOFF_TOKEN_BYTES]) -> Self {
Self::from_bytes(value)
}
}
/// Tuning knobs for a handoff token store.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HandoffTokenStoreConfig {
    /// Upper bound on simultaneously pending tokens; issuing beyond it fails.
    pub max_pending_tokens: usize,
    /// Lifetime of each issued token, measured from the `now` passed at issue time.
    pub token_ttl: Duration,
    /// How many random draws to try when a generated token is already pending.
    pub collision_attempts: usize,
}
impl HandoffTokenStoreConfig {
    /// Builds a configuration, clamping degenerate inputs.
    ///
    /// A `max_pending_tokens` of 0 is raised to 1, and a zero `token_ttl` is raised to
    /// 1 ms. `collision_attempts` starts at [`DEFAULT_HANDOFF_TOKEN_COLLISION_ATTEMPTS`].
    pub fn new(max_pending_tokens: usize, token_ttl: Duration) -> Self {
        let max_pending_tokens = max_pending_tokens.max(1);
        let token_ttl = if token_ttl == Duration::ZERO {
            Duration::from_millis(1)
        } else {
            token_ttl
        };
        Self {
            max_pending_tokens,
            token_ttl,
            collision_attempts: DEFAULT_HANDOFF_TOKEN_COLLISION_ATTEMPTS,
        }
    }

    /// Returns a copy with `collision_attempts` replaced, raising 0 to 1.
    pub fn with_collision_attempts(self, collision_attempts: usize) -> Self {
        Self {
            collision_attempts: collision_attempts.max(1),
            ..self
        }
    }
}
impl Default for HandoffTokenStoreConfig {
    /// Configuration built from the crate-level `DEFAULT_*` constants.
    fn default() -> Self {
        Self {
            collision_attempts: DEFAULT_HANDOFF_TOKEN_COLLISION_ATTEMPTS,
            token_ttl: DEFAULT_HANDOFF_TOKEN_TTL,
            max_pending_tokens: DEFAULT_MAX_PENDING_HANDOFF_TOKENS,
        }
    }
}
/// Tracks issued-but-unconsumed handoff tokens and their expiry deadlines.
#[derive(Debug)]
pub struct HandoffTokenStore {
    /// Limits and TTL applied to every issued token.
    config: HandoffTokenStoreConfig,
    /// Pending tokens keyed by value; the value holds each token's deadline.
    pending: HashMap<HandoffToken, PendingHandoffToken>,
}
impl HandoffTokenStore {
    /// Creates an empty store using [`HandoffTokenStoreConfig::default`].
    pub fn new() -> Self {
        Self::with_config(HandoffTokenStoreConfig::default())
    }

    /// Creates an empty store with the given configuration.
    ///
    /// `HandoffTokenStoreConfig`'s fields are public, so a struct literal can bypass the
    /// clamping done by [`HandoffTokenStoreConfig::new`] and
    /// [`HandoffTokenStoreConfig::with_collision_attempts`]. The same clamping is
    /// re-applied here to prevent three degenerate configs:
    /// - a zero pending limit, which would make every `issue` fail;
    /// - a zero collision budget, which would also make every `issue` fail;
    /// - a zero TTL, which would make every token expire on arrival.
    ///
    /// Non-degenerate values are kept unchanged.
    pub fn with_config(config: HandoffTokenStoreConfig) -> Self {
        let config = HandoffTokenStoreConfig::new(config.max_pending_tokens, config.token_ttl)
            .with_collision_attempts(config.collision_attempts);
        Self {
            config,
            pending: HashMap::new(),
        }
    }

    /// Number of tokens currently tracked.
    ///
    /// Pruning only happens inside `issue*`, `consume_matching`, and `prune_expired`.
    /// This count may therefore include tokens that have expired but not been swept yet.
    pub fn pending_len(&self) -> usize {
        self.pending.len()
    }

    /// Issues a new token using [`HandoffToken::generate`] as the random source.
    ///
    /// # Errors
    /// See [`HandoffTokenStore::issue_with_random128`].
    pub fn issue(&mut self, now: Instant) -> Result<HandoffToken, HandoffTokenError> {
        self.issue_with_random128(now, || HandoffToken::generate().map(HandoffToken::into_bytes))
    }

    /// Issues a new token, drawing candidate bytes from `next_random128`.
    ///
    /// Expired tokens are pruned first, and then the pending limit is checked. A
    /// candidate that equals an already-pending token is discarded and a new one is
    /// drawn, up to `collision_attempts` times.
    ///
    /// # Errors
    /// - Any error returned by `next_random128` is propagated unchanged.
    /// - [`HandoffTokenError::PendingLimitReached`] if the store is full after pruning.
    /// - [`HandoffTokenError::CollisionExhausted`] if every draw collided.
    pub fn issue_with_random128<F>(
        &mut self,
        now: Instant,
        mut next_random128: F,
    ) -> Result<HandoffToken, HandoffTokenError>
    where
        F: FnMut() -> Result<[u8; HANDOFF_TOKEN_BYTES], HandoffTokenError>,
    {
        self.prune_expired(now);
        if self.pending.len() >= self.config.max_pending_tokens {
            return Err(HandoffTokenError::PendingLimitReached {
                max_pending_tokens: self.config.max_pending_tokens,
            });
        }
        // The deadline depends only on `now` and the TTL, so compute it once.
        let deadline = expires_at(now, self.config.token_ttl);
        for _ in 0..self.config.collision_attempts {
            let token = HandoffToken::from_bytes(next_random128()?);
            if self.pending.contains_key(&token) {
                continue;
            }
            self.pending.insert(
                token,
                PendingHandoffToken {
                    expires_at: deadline,
                },
            );
            return Ok(token);
        }
        Err(HandoffTokenError::CollisionExhausted {
            attempts: self.config.collision_attempts,
        })
    }

    /// Consumes `expected` if it is pending, unexpired, and equal to `presented`.
    ///
    /// On success the token is removed, making it single-use. An expired `expected`
    /// token is removed and reported as expired. On a mismatch, `expected` stays
    /// pending.
    ///
    /// # Errors
    /// - [`HandoffTokenError::TokenNotPending`] if `expected` is not tracked.
    /// - [`HandoffTokenError::TokenExpired`] if `expected`'s deadline has passed.
    /// - [`HandoffTokenError::TokenMismatch`] if `presented` differs from `expected`.
    pub fn consume_matching(
        &mut self,
        expected: &HandoffToken,
        presented: &HandoffToken,
        now: Instant,
    ) -> Result<(), HandoffTokenError> {
        // Keep `expected` through the sweep even if stale, so an expired token is
        // reported as `TokenExpired` rather than `TokenNotPending`.
        self.prune_expired_except(now, Some(expected));
        let Some(pending) = self.pending.get(expected) else {
            return Err(HandoffTokenError::TokenNotPending);
        };
        if now >= pending.expires_at {
            self.pending.remove(expected);
            return Err(HandoffTokenError::TokenExpired);
        }
        // `presented` is supplied by the other party. Avoid a short-circuiting `!=` so
        // response timing does not hint at how many leading bytes matched.
        if !Self::tokens_match_constant_time(expected, presented) {
            return Err(HandoffTokenError::TokenMismatch);
        }
        self.pending.remove(expected);
        Ok(())
    }

    /// Removes `token` regardless of expiry; returns whether it was pending.
    pub fn revoke(&mut self, token: &HandoffToken) -> bool {
        self.pending.remove(token).is_some()
    }

    /// Drops every token whose deadline is at or before `now`; returns how many were removed.
    pub fn prune_expired(&mut self, now: Instant) -> usize {
        self.prune_expired_except(now, None)
    }

    /// Like [`HandoffTokenStore::prune_expired`], but never removes `except` (if given).
    fn prune_expired_except(&mut self, now: Instant, except: Option<&HandoffToken>) -> usize {
        let before = self.pending.len();
        self.pending.retain(|token, pending| {
            except.is_some_and(|expected| expected == token) || now < pending.expires_at
        });
        before - self.pending.len()
    }

    /// Compares two tokens by OR-folding the XOR of every byte pair.
    ///
    /// Unlike `==`, this does not exit at the first differing byte. This is a
    /// best-effort mitigation; `black_box` discourages the optimizer from reintroducing
    /// an early exit.
    fn tokens_match_constant_time(a: &HandoffToken, b: &HandoffToken) -> bool {
        let diff = a
            .as_bytes()
            .iter()
            .zip(b.as_bytes())
            .fold(0_u8, |acc, (x, y)| acc | (x ^ y));
        std::hint::black_box(diff) == 0
    }
}
impl Default for HandoffTokenStore {
fn default() -> Self {
Self::new()
}
}
/// Failures from generating, issuing, or consuming handoff tokens.
#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
pub enum HandoffTokenError {
    /// The random source failed; holds the underlying error's display text.
    #[error("handoff token random generation failed: {0}")]
    Random(String),
    /// The store already held `max_pending_tokens` unexpired tokens.
    #[error("handoff token pending limit reached ({max_pending_tokens})")]
    PendingLimitReached { max_pending_tokens: usize },
    /// Every random draw within the collision budget matched a pending token.
    #[error("handoff token allocation exhausted after {attempts} collision attempts")]
    CollisionExhausted { attempts: usize },
    /// The presented token did not equal the expected one.
    #[error("handoff token mismatch")]
    TokenMismatch,
    /// The expected token was tracked but its deadline had passed.
    #[error("handoff token expired")]
    TokenExpired,
    /// The expected token is not tracked by the store.
    #[error("handoff token is not pending")]
    TokenNotPending,
}
impl From<getrandom::Error> for HandoffTokenError {
fn from(value: getrandom::Error) -> Self {
Self::Random(value.to_string())
}
}
/// Bookkeeping stored for one issued, not-yet-consumed token.
#[derive(Clone, Debug)]
struct PendingHandoffToken {
    /// The token counts as expired once `now >= expires_at`.
    expires_at: Instant,
}
/// Computes the deadline for a token issued at `now` with lifetime `ttl`.
///
/// `Instant` has a platform-dependent upper bound, so `now + ttl` overflows for very large
/// TTLs such as `Duration::MAX`. Falling back to `now` in that case would make the token
/// expire immediately, the opposite of what a huge TTL asks for. Instead, the TTL is
/// halved until the sum is representable, giving a deadline as far in the future as the
/// platform allows (within a factor of two). Only a zero TTL, or an `Instant` with no
/// headroom at all, yields `now`.
fn expires_at(now: Instant, ttl: Duration) -> Instant {
    let mut ttl = ttl;
    loop {
        if let Some(deadline) = now.checked_add(ttl) {
            return deadline;
        }
        // Halving terminates: a `Duration` reaches zero after at most ~94 halvings,
        // and `checked_add` with a zero duration always succeeds.
        ttl /= 2;
        if ttl.is_zero() {
            return now;
        }
    }
}