use std::collections::HashMap;
use std::time::Duration;
use crate::crypto::KeyDerivation;
use crate::error::{Result, SessionError, SrxError};
use crate::session::Session;
/// A resumption ticket produced by `TicketStore::issue`.
///
/// The ticket carries secret material (`psk`, `seed`). The `Debug`
/// implementation below redacts those two fields so a ticket can be logged
/// without leaking key material.
#[derive(Clone)]
pub struct SessionTicket {
    /// Random identifier used to look the ticket up in a `TicketStore`.
    pub ticket_id: [u8; 16],
    /// Random pre-shared key; `resume_session` derives the resumed keys from it.
    pub psk: [u8; 32],
    /// Seed bytes copied from the originating session's RNG at issue time.
    pub seed: [u8; 32],
    /// Key index of the originating session at issue time.
    pub key_index: u64,
    /// Issue time in whole seconds since the Unix epoch.
    pub issued_at: u64,
    /// Validity window in seconds, measured from `issued_at`.
    pub lifetime_secs: u64,
}

impl std::fmt::Debug for SessionTicket {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Secrets are replaced by a placeholder; everything else is printed as-is.
        f.debug_struct("SessionTicket")
            .field("ticket_id", &self.ticket_id)
            .field("psk", &format_args!("<redacted>"))
            .field("seed", &format_args!("<redacted>"))
            .field("key_index", &self.key_index)
            .field("issued_at", &self.issued_at)
            .field("lifetime_secs", &self.lifetime_secs)
            .finish()
    }
}
impl SessionTicket {
    /// Returns `true` once more than `lifetime_secs` seconds have elapsed since
    /// `issued_at`.
    ///
    /// If `issued_at` lies after `now_unix`, the elapsed time saturates at zero,
    /// so the ticket is reported as not expired.
    pub fn is_expired(&self, now_unix: u64) -> bool {
        now_unix.saturating_sub(self.issued_at) > self.lifetime_secs
    }

    /// Encodes a resume request for this ticket.
    ///
    /// Wire layout (integers big-endian):
    /// `"SRXH"` | `1` | `4` | payload length `u32` (= 40) |
    /// ticket id (16 bytes) | `client_nonce` (16 bytes) | `timestamp` (8 bytes).
    pub fn to_resume_request(&self, client_nonce: &[u8; 16], timestamp: u64) -> Vec<u8> {
        const PAYLOAD_LEN: u32 = 16 + 16 + 8;
        let mut buf = Vec::with_capacity(10 + PAYLOAD_LEN as usize);
        buf.extend_from_slice(b"SRXH");
        // Header byte 4 is always 1 (presumably the protocol version).
        buf.push(1);
        // Header byte 5 is 4 for a resume request (the response uses 5).
        buf.push(4);
        buf.extend_from_slice(&PAYLOAD_LEN.to_be_bytes());
        buf.extend_from_slice(&self.ticket_id);
        buf.extend_from_slice(client_nonce);
        buf.extend_from_slice(&timestamp.to_be_bytes());
        buf
    }

    /// Decodes a message produced by [`SessionTicket::to_resume_request`].
    ///
    /// Returns `(ticket_id, client_nonce, timestamp)`.
    ///
    /// # Errors
    ///
    /// Returns `SessionError::HandshakeFailed` in these cases:
    /// - the input is shorter than the 10-byte header;
    /// - the magic, byte 4, or byte 5 do not match;
    /// - the declared payload length is not 40;
    /// - the input length does not equal header + payload.
    pub fn parse_resume_request(data: &[u8]) -> Result<([u8; 16], [u8; 16], u64)> {
        if data.len() < 10 {
            return Err(SrxError::Session(SessionError::HandshakeFailed(
                "Resume request too short".into(),
            )));
        }
        if &data[0..4] != b"SRXH" || data[4] != 1 || data[5] != 4 {
            return Err(SrxError::Session(SessionError::HandshakeFailed(
                "Invalid resume request header".into(),
            )));
        }
        let payload_len = u32::from_be_bytes([data[6], data[7], data[8], data[9]]) as usize;
        // Validate the declared length before using it in arithmetic: an
        // attacker-chosen length could otherwise overflow `10 + payload_len`
        // on 32-bit targets.
        if payload_len != 40 || data.len() != 10 + payload_len {
            return Err(SrxError::Session(SessionError::HandshakeFailed(
                "Resume request length mismatch".into(),
            )));
        }
        let payload = &data[10..];
        let mut ticket_id = [0u8; 16];
        ticket_id.copy_from_slice(&payload[0..16]);
        let mut client_nonce = [0u8; 16];
        client_nonce.copy_from_slice(&payload[16..32]);
        let mut ts_bytes = [0u8; 8];
        ts_bytes.copy_from_slice(&payload[32..40]);
        let timestamp = u64::from_be_bytes(ts_bytes);
        Ok((ticket_id, client_nonce, timestamp))
    }
}
/// Server-side registry of issued [`SessionTicket`]s, keyed by ticket id.
pub struct TicketStore {
    /// Outstanding tickets, keyed by their `ticket_id`.
    tickets: HashMap<[u8; 16], SessionTicket>,
    /// Lifetime stamped onto every ticket this store issues.
    max_lifetime: Duration,
}

impl TicketStore {
    /// Creates an empty store.
    ///
    /// Every ticket it issues is valid for `max_lifetime`, truncated to whole
    /// seconds.
    pub fn new(max_lifetime: Duration) -> Self {
        Self {
            tickets: HashMap::new(),
            max_lifetime,
        }
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// Falls back to `0` if the system clock reports a time before the epoch.
    fn now_unix() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Issues a ticket for `session`, records it in the store, and returns a copy.
    ///
    /// Details:
    /// - The ticket id and PSK are filled with `rand::fill`.
    /// - The seed and key index are copied from `session`.
    /// - The ticket is stamped with the current time and this store's lifetime.
    pub fn issue(&mut self, session: &Session) -> SessionTicket {
        let mut ticket_id = [0u8; 16];
        rand::fill(&mut ticket_id);
        let mut psk = [0u8; 32];
        rand::fill(&mut psk);
        let now = Self::now_unix();
        let ticket = SessionTicket {
            ticket_id,
            psk,
            seed: session.rng.seed_bytes(),
            key_index: session.key_index,
            issued_at: now,
            lifetime_secs: self.max_lifetime.as_secs(),
        };
        self.tickets.insert(ticket_id, ticket.clone());
        ticket
    }

    /// Looks up a ticket without removing it.
    ///
    /// Returns `None` if the ticket is unknown or expired. Expired tickets stay
    /// in the store until `cleanup_expired` or `consume` removes them.
    pub fn validate(&self, ticket_id: &[u8; 16]) -> Option<&SessionTicket> {
        let now = Self::now_unix();
        self.tickets.get(ticket_id).filter(|t| !t.is_expired(now))
    }

    /// Removes a ticket from the store and returns it if it is still valid.
    ///
    /// The ticket is removed even when expired, so a second call with the same
    /// id returns `None`. An expired ticket is never handed back, so callers
    /// cannot redeem it by skipping `validate`.
    pub fn consume(&mut self, ticket_id: &[u8; 16]) -> Option<SessionTicket> {
        let now = Self::now_unix();
        self.tickets.remove(ticket_id).filter(|t| !t.is_expired(now))
    }

    /// Drops every ticket whose lifetime has elapsed.
    pub fn cleanup_expired(&mut self) {
        let now = Self::now_unix();
        self.tickets.retain(|_, t| !t.is_expired(now));
    }

    /// Number of tickets currently held, including expired ones not yet cleaned up.
    pub fn len(&self) -> usize {
        self.tickets.len()
    }

    /// Returns `true` if the store holds no tickets.
    pub fn is_empty(&self) -> bool {
        self.tickets.is_empty()
    }
}
/// Rebuilds a session from a previously issued ticket.
///
/// How the keys are derived:
/// - The new seed comes from `KeyDerivation::derive_initial_seed`.
/// - Its inputs are the ticket's PSK, the first 8 bytes of `server_nonce`
///   read as a big-endian `u64`, and `client_nonce`.
/// - The data key at index 0 is then derived from that seed.
/// - `session_id` is only passed through to `Session::new`.
///
/// # Errors
///
/// Propagates any error from `KeyDerivation::derive_initial_seed` or
/// `KeyDerivation::derive_data_key`.
pub fn resume_session(
    session_id: u64,
    ticket: &SessionTicket,
    client_nonce: &[u8; 16],
    server_nonce: &[u8; 16],
) -> Result<Session> {
    // NOTE(review): only server_nonce[0..8] feeds the derivation; bytes 8..16
    // are ignored. Changing this alters the derived keys, so it would have to
    // be coordinated with every peer.
    let mut nonce_prefix = [0u8; 8];
    nonce_prefix.copy_from_slice(&server_nonce[..8]);
    let seed = KeyDerivation::derive_initial_seed(
        &ticket.psk,
        u64::from_be_bytes(nonce_prefix),
        client_nonce,
    )?;
    let data_key = KeyDerivation::derive_data_key(&seed, 0)?;
    Ok(Session::new(session_id, seed, data_key))
}
/// Encodes the server's reply to a resume request.
///
/// Wire layout: `"SRXH"` | `1` | `5` | payload length as a big-endian `u32`
/// (always 16) | the 16-byte `server_nonce`.
pub fn build_resume_response(server_nonce: &[u8; 16]) -> Vec<u8> {
    let length_field = 16u32.to_be_bytes();
    let pieces: [&[u8]; 4] = [b"SRXH", &[1, 5], &length_field, server_nonce];
    pieces.concat()
}
/// Decodes a message produced by [`build_resume_response`] and returns the
/// server nonce.
///
/// # Errors
///
/// Returns `SessionError::HandshakeFailed` in these cases:
/// - the input is shorter than the 10-byte header;
/// - the magic, byte 4, or byte 5 do not match;
/// - the declared payload length is not 16;
/// - the input length does not equal header + payload.
pub fn parse_resume_response(data: &[u8]) -> Result<[u8; 16]> {
    if data.len() < 10 {
        return Err(SrxError::Session(SessionError::HandshakeFailed(
            "Resume response too short".into(),
        )));
    }
    if &data[0..4] != b"SRXH" || data[4] != 1 || data[5] != 5 {
        return Err(SrxError::Session(SessionError::HandshakeFailed(
            "Invalid resume response header".into(),
        )));
    }
    let payload_len = u32::from_be_bytes([data[6], data[7], data[8], data[9]]) as usize;
    // Check the declared length first so an attacker-chosen value cannot
    // overflow `10 + payload_len` on 32-bit targets.
    if payload_len != 16 || data.len() != 10 + payload_len {
        return Err(SrxError::Session(SessionError::HandshakeFailed(
            "Resume response length mismatch".into(),
        )));
    }
    let mut server_nonce = [0u8; 16];
    server_nonce.copy_from_slice(&data[10..26]);
    Ok(server_nonce)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh store whose tickets live for one hour.
    fn hour_store() -> TicketStore {
        TicketStore::new(Duration::from_secs(3600))
    }

    /// Session built from fixed seed (`0xAA`) and key (`0xBB`) bytes.
    fn dummy_session() -> Session {
        let (seed, key) = ([0xAAu8; 32], [0xBBu8; 32]);
        Session::new(1, seed, key)
    }

    #[test]
    fn ticket_issue_and_validate() {
        let mut store = hour_store();
        let issued = store.issue(&dummy_session());
        assert_eq!(store.len(), 1);
        assert!(store.validate(&issued.ticket_id).is_some());
    }

    #[test]
    fn expired_ticket_rejected() {
        let mut store = hour_store();
        let mut stale = store.issue(&dummy_session());
        // Back-date to the epoch so the ticket is far past its one-hour lifetime.
        stale.issued_at = 0;
        store.tickets.insert(stale.ticket_id, stale.clone());
        assert!(store.validate(&stale.ticket_id).is_none());
    }

    #[test]
    fn consume_removes_ticket() {
        let mut store = hour_store();
        let id = store.issue(&dummy_session()).ticket_id;
        assert!(store.consume(&id).is_some());
        assert!(store.is_empty());
    }

    #[test]
    fn resume_request_roundtrip() {
        let mut store = hour_store();
        let ticket = store.issue(&dummy_session());
        let nonce = [0x11u8; 16];
        let stamp = 1_700_000_000u64;

        let wire = ticket.to_resume_request(&nonce, stamp);
        let (got_id, got_nonce, got_stamp) =
            SessionTicket::parse_resume_request(&wire).unwrap();

        assert_eq!(got_id, ticket.ticket_id);
        assert_eq!(got_nonce, nonce);
        assert_eq!(got_stamp, stamp);
    }

    #[test]
    fn resume_response_roundtrip() {
        let nonce = [0x22u8; 16];
        let decoded = parse_resume_response(&build_resume_response(&nonce)).unwrap();
        assert_eq!(decoded, nonce);
    }

    #[test]
    fn resume_session_derives_new_keys() {
        let original = dummy_session();
        let mut store = hour_store();
        let ticket = store.issue(&original);
        let (client_nonce, server_nonce) = ([0x33u8; 16], [0x44u8; 16]);

        let first = resume_session(10, &ticket, &client_nonce, &server_nonce).unwrap();
        let second = resume_session(20, &ticket, &client_nonce, &server_nonce).unwrap();

        // Same ticket and nonces reproduce the same keys regardless of session id...
        assert_eq!(first.data_key, second.data_key);
        assert_eq!(first.rng.seed_bytes(), second.rng.seed_bytes());
        // ...and those keys differ from the session the ticket was issued for.
        assert_ne!(first.data_key, original.data_key);
    }

    #[test]
    fn cleanup_removes_expired() {
        let mut store = hour_store();
        let id = store.issue(&dummy_session()).ticket_id;
        store.tickets.get_mut(&id).unwrap().issued_at = 0;
        store.cleanup_expired();
        assert!(store.is_empty());
    }
}