use std::collections::HashMap;
use std::time::{Duration, Instant};
/// Upper bound, in bytes, on the size of a serialized [`SessionTicket`].
pub const MAX_TICKET_SIZE: usize = 256;
/// Lifetime given to newly issued tickets unless overridden (300 seconds).
pub const DEFAULT_TICKET_LIFETIME: Duration = Duration::from_secs(300);

/// Opaque ticket blob exchanged between client and server.
///
/// The bytes are stored inline in a `MAX_TICKET_SIZE`-byte buffer; only the first `len`
/// bytes are meaningful, the rest is zero padding.
#[derive(Clone)]
pub struct SessionTicket {
    data: [u8; MAX_TICKET_SIZE],
    len: usize,
}

impl SessionTicket {
    /// Copies `bytes` into a new ticket.
    ///
    /// Returns `None` when `bytes` is longer than [`MAX_TICKET_SIZE`]. An empty slice
    /// produces an empty ticket.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        let len = bytes.len();
        let mut data = [0u8; MAX_TICKET_SIZE];
        // `get_mut` yields `None` exactly when `len` exceeds the buffer capacity.
        let prefix = data.get_mut(..len)?;
        prefix.copy_from_slice(bytes);
        Some(Self { data, len })
    }

    /// Returns the meaningful ticket bytes, excluding unused capacity.
    #[inline]
    pub fn as_bytes(&self) -> &[u8] {
        let Self { data, len } = self;
        &data[..*len]
    }
}
/// Session state recovered from a ticket that passed validation.
///
/// Deliberately has no `Debug` derive, which keeps the raw key material out of
/// formatted output.
#[derive(Clone)]
pub struct RestoredSession {
    /// Peer identifier that was sealed into the ticket.
    pub peer_id: u16,
    /// Sending key that was sealed into the ticket.
    pub send_key: [u8; 32],
    /// Receiving key that was sealed into the ticket.
    pub recv_key: [u8; 32],
}
// Serialized ticket layout (byte offsets):
//
//   0..16    random ticket id              plaintext, authenticated
//   16..18   peer id (u16, LE)             encrypted, authenticated
//   18..50   send key                      encrypted, authenticated
//   50..82   receive key                   encrypted, authenticated
//   82..90   expiry (u64 Unix secs, LE)    encrypted, authenticated
//   90..122  keyed BLAKE3 tag over 0..90   encrypted
/// Length of the plaintext ticket-id prefix; the encrypted region starts right after it.
const TICKET_ID_LEN: usize = 16;
/// Offset of the authentication tag, i.e. the end of the authenticated region.
const TICKET_TAG_OFFSET: usize = 90;
/// Total length of a ticket produced by [`SessionStore::create_ticket`].
const TICKET_LEN: usize = 122;
/// Minimum time between sweeps of the redeemed-ticket set.
const USED_TICKET_CLEANUP_INTERVAL: Duration = Duration::from_secs(60);

/// Issues and redeems encrypted, authenticated, single-use session tickets.
///
/// Ticket contents are encrypted and authenticated under `server_key`. The ids of redeemed
/// tickets are remembered in memory, so this store accepts a given ticket at most once.
pub struct SessionStore {
    /// Secret used, with distinct domain-separation prefixes, for encryption and tagging.
    server_key: [u8; 32],
    /// Ids of redeemed tickets, mapped to the monotonic time they were redeemed.
    used_tickets: HashMap<[u8; TICKET_ID_LEN], Instant>,
    /// Lifetime stamped into newly created tickets.
    ticket_lifetime: Duration,
    /// When `used_tickets` was last swept.
    last_cleanup: Instant,
}

impl SessionStore {
    /// Creates a store keyed with `server_key` that issues tickets valid for
    /// [`DEFAULT_TICKET_LIFETIME`].
    pub fn new(server_key: [u8; 32]) -> Self {
        Self {
            server_key,
            used_tickets: HashMap::with_capacity(1024),
            ticket_lifetime: DEFAULT_TICKET_LIFETIME,
            last_cleanup: Instant::now(),
        }
    }

    /// Overrides the lifetime of tickets issued by this store.
    ///
    /// The expiry is stored with one-second resolution, so any sub-second part of
    /// `lifetime` is truncated when tickets are created.
    pub fn with_lifetime(mut self, lifetime: Duration) -> Self {
        self.ticket_lifetime = lifetime;
        self
    }

    /// Seals a new ticket carrying `peer_id` and both session keys.
    ///
    /// The ticket id is 16 fresh random bytes. The ticket expires `ticket_lifetime`
    /// (whole seconds) after the current system time.
    pub fn create_ticket(
        &self,
        peer_id: u16,
        send_key: &[u8; 32],
        recv_key: &[u8; 32],
    ) -> SessionTicket {
        let mut ticket_id = [0u8; TICKET_ID_LEN];
        rand::RngCore::fill_bytes(&mut rand::rng(), &mut ticket_id);

        let mut data = [0u8; MAX_TICKET_SIZE];
        data[0..16].copy_from_slice(&ticket_id);
        data[16..18].copy_from_slice(&peer_id.to_le_bytes());
        data[18..50].copy_from_slice(send_key);
        data[50..82].copy_from_slice(recv_key);
        // Saturate so a huge configured lifetime means "effectively never expires". Plain
        // addition would panic in debug builds and wrap to an already-expired value in release.
        let expiry = Self::unix_now_secs().saturating_add(self.ticket_lifetime.as_secs());
        data[82..90].copy_from_slice(&expiry.to_le_bytes());

        let tag = self.compute_hmac(&data[..TICKET_TAG_OFFSET]);
        data[TICKET_TAG_OFFSET..TICKET_LEN].copy_from_slice(&tag);

        // Encrypt everything after the id (tag included); the id stays readable so the
        // validator can re-derive the same keystream.
        self.apply_keystream(&ticket_id, &mut data[TICKET_ID_LEN..TICKET_LEN]);
        SessionTicket {
            data,
            len: TICKET_LEN,
        }
    }

    /// Authenticates and redeems `ticket`, returning the sealed session on success.
    ///
    /// Returns `None` in any of these cases:
    /// - the ticket is shorter than a full ticket;
    /// - its id was already redeemed by this store;
    /// - the authentication tag does not match;
    /// - its expiry has passed.
    ///
    /// On success the ticket id is recorded, so presenting the same ticket again fails.
    pub fn validate_ticket(&mut self, ticket: &SessionTicket) -> Option<RestoredSession> {
        self.maybe_cleanup();

        let bytes = ticket.as_bytes();
        if bytes.len() < TICKET_LEN {
            return None;
        }

        let mut ticket_id = [0u8; TICKET_ID_LEN];
        ticket_id.copy_from_slice(&bytes[..TICKET_ID_LEN]);
        // Cheap replay rejection before any cryptographic work.
        if self.used_tickets.contains_key(&ticket_id) {
            return None;
        }

        let mut plain = [0u8; TICKET_LEN];
        plain.copy_from_slice(&bytes[..TICKET_LEN]);
        self.apply_keystream(&ticket_id, &mut plain[TICKET_ID_LEN..]);

        let expected_tag = self.compute_hmac(&plain[..TICKET_TAG_OFFSET]);
        if !constant_time_eq(&expected_tag, &plain[TICKET_TAG_OFFSET..]) {
            return None;
        }

        let mut expiry_bytes = [0u8; 8];
        expiry_bytes.copy_from_slice(&plain[82..90]);
        if Self::unix_now_secs() > u64::from_le_bytes(expiry_bytes) {
            return None;
        }

        // Record the id only after the tag verified, so forged tickets cannot grow the set.
        self.used_tickets.insert(ticket_id, Instant::now());

        let mut send_key = [0u8; 32];
        let mut recv_key = [0u8; 32];
        send_key.copy_from_slice(&plain[18..50]);
        recv_key.copy_from_slice(&plain[50..82]);
        Some(RestoredSession {
            peer_id: u16::from_le_bytes([plain[16], plain[17]]),
            send_key,
            recv_key,
        })
    }

    /// Current Unix time in whole seconds; a clock set before the epoch reads as 0.
    fn unix_now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// XORs `buf` in place with a keystream bound to `ticket_id` and the server key.
    ///
    /// Applying it twice with the same id restores the input. The keystream comes from
    /// BLAKE3's extendable output, so every byte of `buf` gets its own keystream byte.
    /// Repeating one 32-byte digest across the region would make bytes 32 apart share a
    /// key byte. Because the expiry field is guessable, that would expose parts of the
    /// send/receive keys.
    fn apply_keystream(&self, ticket_id: &[u8; TICKET_ID_LEN], buf: &mut [u8]) {
        let mut hasher = blake3::Hasher::new_keyed(&self.server_key);
        hasher.update(b"ticket-encrypt");
        hasher.update(ticket_id);
        let mut reader = hasher.finalize_xof();

        // Successive `fill` calls continue the same output stream.
        let mut block = [0u8; 64];
        for chunk in buf.chunks_mut(block.len()) {
            let keystream = &mut block[..chunk.len()];
            reader.fill(keystream);
            for (byte, key) in chunk.iter_mut().zip(keystream.iter()) {
                *byte ^= *key;
            }
        }
    }

    /// Keyed BLAKE3 tag over `data`, domain-separated from the encryption keystream.
    fn compute_hmac(&self, data: &[u8]) -> [u8; 32] {
        let mut hasher = blake3::Hasher::new_keyed(&self.server_key);
        hasher.update(b"ticket-hmac");
        hasher.update(data);
        *hasher.finalize().as_bytes()
    }

    /// Forgets redeemed ticket ids that can no longer validate.
    ///
    /// Runs at most once per [`USED_TICKET_CLEANUP_INTERVAL`].
    fn maybe_cleanup(&mut self) {
        let now = Instant::now();
        if now.duration_since(self.last_cleanup) < USED_TICKET_CLEANUP_INTERVAL {
            return;
        }
        self.last_cleanup = now;

        // Expiry is stored in whole seconds, so a ticket may remain valid up to one second
        // past its nominal lifetime. The extra second keeps sub-second (including zero)
        // lifetimes from being forgotten while still replayable. Saturating arithmetic
        // avoids the overflow panic of `Duration * 2` for huge lifetimes.
        // NOTE(review): retention uses the monotonic clock while expiry uses the wall clock;
        // a large backwards wall-clock step could still reopen a replay window.
        let retention = self
            .ticket_lifetime
            .saturating_mul(2)
            .saturating_add(Duration::from_secs(1));
        self.used_tickets
            .retain(|_, redeemed_at| now.duration_since(*redeemed_at) < retention);
    }

    /// Number of redeemed ticket ids currently remembered.
    #[inline]
    pub fn used_ticket_count(&self) -> usize {
        self.used_tickets.len()
    }
}
/// Compares two byte slices for equality without stopping at the first mismatch.
///
/// Slices of different lengths compare unequal right away; lengths are not treated as
/// secret. Otherwise every byte pair is folded into one difference accumulator. The intent
/// is that running time does not depend on where the inputs differ, though the optimizer
/// is not formally bound to preserve that.
#[inline]
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a
            .iter()
            .zip(b)
            .fold(0u8, |diff, (lhs, rhs)| diff | (lhs ^ rhs))
            == 0
}
/// Client-side cache holding the most recent session ticket for each server address.
///
/// Holds at most `max_tickets` entries (100 by default). When full, storing a ticket for a
/// server that is not yet cached evicts the least recently stored entry.
pub struct ClientTicketStore {
    /// Ticket plus the time it was stored, keyed by server address.
    tickets: HashMap<String, (SessionTicket, Instant)>,
    /// Capacity limit; always at least 1.
    max_tickets: usize,
}

impl ClientTicketStore {
    /// Capacity used by [`ClientTicketStore::new`].
    const DEFAULT_MAX_TICKETS: usize = 100;

    /// Creates an empty store with the default capacity of 100 tickets.
    pub fn new() -> Self {
        Self {
            tickets: HashMap::with_capacity(16),
            max_tickets: Self::DEFAULT_MAX_TICKETS,
        }
    }

    /// Sets the maximum number of cached tickets; values below 1 are clamped to 1.
    pub fn with_max_tickets(mut self, max_tickets: usize) -> Self {
        self.max_tickets = max_tickets.max(1);
        self
    }

    /// Stores `ticket` for `server_addr`, replacing any ticket already cached for it.
    ///
    /// Eviction only happens for a server that is not yet cached, because replacing an
    /// existing entry does not grow the map. Evicting on replacement would throw away an
    /// unrelated server's ticket for nothing.
    pub fn store(&mut self, server_addr: &str, ticket: SessionTicket) {
        if !self.tickets.contains_key(server_addr) {
            // `while` rather than `if` keeps the bound even if the map is somehow over it.
            while self.tickets.len() >= self.max_tickets {
                let oldest = self
                    .tickets
                    .iter()
                    .min_by_key(|(_, (_, stored_at))| *stored_at)
                    .map(|(addr, _)| addr.clone());
                match oldest {
                    Some(addr) => {
                        self.tickets.remove(&addr);
                    }
                    None => break,
                }
            }
        }
        self.tickets
            .insert(server_addr.to_string(), (ticket, Instant::now()));
    }

    /// Returns the cached ticket for `server_addr`, if any.
    pub fn get(&self, server_addr: &str) -> Option<&SessionTicket> {
        self.tickets.get(server_addr).map(|(ticket, _)| ticket)
    }

    /// Drops the cached ticket for `server_addr`, if any.
    pub fn remove(&mut self, server_addr: &str) {
        self.tickets.remove(server_addr);
    }

    /// Drops all cached tickets.
    pub fn clear(&mut self) {
        self.tickets.clear();
    }
}

impl Default for ClientTicketStore {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces a fresh random server key so each test uses an independent secret.
    fn test_server_key() -> [u8; 32] {
        let mut key = [0u8; 32];
        rand::RngCore::fill_bytes(&mut rand::rng(), &mut key);
        key
    }

    /// Builds a session store keyed with a fresh random server key.
    fn fresh_store() -> SessionStore {
        SessionStore::new(test_server_key())
    }

    #[test]
    fn test_ticket_roundtrip() {
        let (send, recv) = ([1u8; 32], [2u8; 32]);
        let mut store = fresh_store();

        let ticket = store.create_ticket(42, &send, &recv);
        let session = store.validate_ticket(&ticket).unwrap();

        assert_eq!(session.peer_id, 42);
        assert_eq!(session.send_key, send);
        assert_eq!(session.recv_key, recv);
    }

    #[test]
    fn test_ticket_replay_protection() {
        let mut store = fresh_store();
        let ticket = store.create_ticket(1, &[0u8; 32], &[0u8; 32]);

        // First redemption succeeds; a second one with the same ticket must fail.
        assert!(store.validate_ticket(&ticket).is_some());
        assert!(store.validate_ticket(&ticket).is_none());
    }

    #[test]
    fn test_ticket_tampering() {
        let mut store = fresh_store();
        let mut tampered = store.create_ticket(1, &[0u8; 32], &[0u8; 32]);

        // Flip a byte inside the encrypted region; authentication must reject it.
        tampered.data[50] ^= 0xFF;
        assert!(store.validate_ticket(&tampered).is_none());
    }

    #[test]
    fn test_ticket_expiry() {
        let mut store = fresh_store().with_lifetime(Duration::from_secs(0));
        let ticket = store.create_ticket(1, &[0u8; 32], &[0u8; 32]);

        // Expiry has one-second resolution, so wait a full second past issuance.
        std::thread::sleep(Duration::from_secs(1));
        assert!(store.validate_ticket(&ticket).is_none());
    }

    #[test]
    fn test_client_store() {
        let (known, unknown) = ("server1:7777", "server2:7777");
        let mut client = ClientTicketStore::new();

        client.store(known, SessionTicket::from_bytes(&[0u8; 122]).unwrap());
        assert!(client.get(known).is_some());
        assert!(client.get(unknown).is_none());

        client.remove(known);
        assert!(client.get(known).is_none());
    }

    #[test]
    fn test_constant_time_eq() {
        let base: &[u8] = &[1, 2, 3];
        assert!(constant_time_eq(base, &[1, 2, 3]));
        assert!(!constant_time_eq(base, &[1, 2, 4]));
        assert!(!constant_time_eq(base, &[1, 2]));
    }
}