use std::{
fmt,
net::{IpAddr, SocketAddr},
};
use bytes::{Buf, BufMut, Bytes};
use rand::{CryptoRng, RngExt};
use crate::{
Duration, RESET_TOKEN_SIZE, ServerConfig, SystemTime, UNIX_EPOCH,
coding::{BufExt, BufMutExt},
crypto::{HandshakeTokenKey, HmacKey},
packet::InitialHeader,
shared::ConnectionId,
};
/// Responsible for limiting clients' ability to reuse validation tokens.
///
/// The server consults the log only after a validation token has decrypted, matched the
/// client's IP address, and been found unexpired. If the log returns `Err`, the connection is
/// handled exactly as if no token had been presented.
pub trait TokenLog: Send + Sync {
    /// Records a use of the validation token identified by `nonce`.
    ///
    /// `issued` is the issue time carried inside the token. `lifetime` is the server's
    /// configured validation-token lifetime.
    ///
    /// # Errors
    ///
    /// Returns [`TokenReuseError`] if the token must be treated as (possibly) reused; the client
    /// then remains unvalidated.
    fn check_and_insert(
        &self,
        nonce: u128,
        issued: SystemTime,
        lifetime: Duration,
    ) -> Result<(), TokenReuseError>;
}

/// Error returned by a [`TokenLog`] when a validation token may have been reused.
#[derive(Debug)]
pub struct TokenReuseError;

/// [`TokenLog`] that rejects every token, so validation tokens never validate an address.
#[derive(Debug)]
pub struct NoneTokenLog;

impl TokenLog for NoneTokenLog {
    fn check_and_insert(&self, _: u128, _: SystemTime, _: Duration) -> Result<(), TokenReuseError> {
        Err(TokenReuseError)
    }
}
/// Storage for tokens, keyed by server name.
pub trait TokenStore: Send + Sync {
    /// Stores `token`, associating it with `server_name`.
    fn insert(&self, server_name: &str, token: Bytes);

    /// Returns a stored token for `server_name`, if one is available.
    fn take(&self, server_name: &str) -> Option<Bytes>;
}

/// [`TokenStore`] that discards every inserted token and never returns one.
#[derive(Debug)]
pub struct NoneTokenStore;

impl TokenStore for NoneTokenStore {
    fn insert(&self, _: &str, _: Bytes) {}

    fn take(&self, _: &str) -> Option<Bytes> {
        None
    }
}
/// Result of interpreting the token (if any) carried by a client's Initial packet.
#[derive(Debug)]
pub(crate) struct IncomingToken {
    /// `Some` (holding the Initial header's destination CID) only when a Retry token was
    /// accepted; `None` otherwise.
    pub(crate) retry_src_cid: Option<ConnectionId>,
    /// The original destination CID taken from an accepted Retry token; otherwise the Initial
    /// header's own destination CID.
    pub(crate) orig_dst_cid: ConnectionId,
    /// Whether the presented token validated the client's address.
    pub(crate) validated: bool,
}
impl IncomingToken {
pub(crate) fn from_header(
header: &InitialHeader,
server_config: &ServerConfig,
remote_address: SocketAddr,
) -> Result<Self, InvalidRetryTokenError> {
let unvalidated = Self {
retry_src_cid: None,
orig_dst_cid: header.dst_cid,
validated: false,
};
if header.token.is_empty() {
return Ok(unvalidated);
}
let Some(retry) = Token::decode(&*server_config.token_key, &header.token) else {
return Ok(unvalidated);
};
match retry.payload {
TokenPayload::Retry {
address,
orig_dst_cid,
issued,
} => {
if address != remote_address {
return Err(InvalidRetryTokenError);
}
if issued + server_config.retry_token_lifetime < server_config.time_source.now() {
return Err(InvalidRetryTokenError);
}
Ok(Self {
retry_src_cid: Some(header.dst_cid),
orig_dst_cid,
validated: true,
})
}
TokenPayload::Validation { ip, issued } => {
if ip != remote_address.ip() {
return Ok(unvalidated);
}
if issued + server_config.validation_token.lifetime
< server_config.time_source.now()
{
return Ok(unvalidated);
}
if server_config
.validation_token
.log
.check_and_insert(retry.nonce, issued, server_config.validation_token.lifetime)
.is_err()
{
return Ok(unvalidated);
}
Ok(Self {
retry_src_cid: None,
orig_dst_cid: header.dst_cid,
validated: true,
})
}
}
}
}
/// Error returned by `IncomingToken::from_header` when a decoded Retry token was issued for a
/// different socket address or has outlived the configured Retry token lifetime.
#[derive(Debug)]
pub(crate) struct InvalidRetryTokenError;
/// A token payload together with the nonce used to seal it.
pub(crate) struct Token {
    /// What the token asserts: Retry data or address-validation data.
    pub(crate) payload: TokenPayload,
    /// Random per-token value. It is used in three places:
    /// - passed to the key when sealing and opening;
    /// - appended to the encoded token in little-endian form;
    /// - given to the `TokenLog` when a validation token is checked.
    nonce: u128,
}
impl Token {
    /// Wraps `payload` in a token whose nonce is drawn from `rng`.
    pub(crate) fn new(payload: TokenPayload, rng: &mut impl CryptoRng) -> Self {
        let nonce: u128 = rng.random();
        Self { payload, nonce }
    }

    /// Serializes the token and seals it with `key`.
    ///
    /// The output is the sealed form of `type byte || fields || issue time`, followed by the
    /// 16-byte nonce in little-endian order.
    ///
    /// # Panics
    ///
    /// Panics if `key.seal` reports an error.
    pub(crate) fn encode(&self, key: &dyn HandshakeTokenKey) -> Vec<u8> {
        let mut out = Vec::new();

        // Write the variant-specific prefix. Both variants end with the same timestamp field,
        // so that field is written once, after the match.
        let issued = match self.payload {
            TokenPayload::Retry {
                address,
                orig_dst_cid,
                issued,
            } => {
                out.put_u8(TokenType::Retry as u8);
                encode_addr(&mut out, address);
                orig_dst_cid.encode_long(&mut out);
                issued
            }
            TokenPayload::Validation { ip, issued } => {
                out.put_u8(TokenType::Validation as u8);
                encode_ip(&mut out, ip);
                issued
            }
        };
        encode_unix_secs(&mut out, issued);

        key.seal(self.nonce, &mut out).unwrap();
        // The nonce travels in the clear so that `decode` can open the token.
        out.extend_from_slice(&self.nonce.to_le_bytes());
        out
    }

    /// Inverse of [`Token::encode`].
    ///
    /// Returns `None` for any of the following:
    /// - the input is too short to hold a nonce;
    /// - the token fails to open;
    /// - the payload is malformed;
    /// - bytes are left over after the payload.
    fn decode(key: &dyn HandshakeTokenKey, raw_token_bytes: &[u8]) -> Option<Self> {
        let (sealed_token, nonce_bytes) = raw_token_bytes.split_last_chunk()?;
        let nonce = u128::from_le_bytes(*nonce_bytes);

        // `open` takes the ciphertext mutably, so work on an owned copy.
        let mut ciphertext = sealed_token.to_vec();
        let mut reader = key.open(nonce, &mut ciphertext).ok()?;

        let type_byte = (&mut reader).get::<u8>().ok()?;
        let payload = match TokenType::from_byte(type_byte)? {
            TokenType::Retry => TokenPayload::Retry {
                address: decode_addr(&mut reader)?,
                orig_dst_cid: ConnectionId::decode_long(&mut reader)?,
                issued: decode_unix_secs(&mut reader)?,
            },
            TokenType::Validation => TokenPayload::Validation {
                ip: decode_ip(&mut reader)?,
                issued: decode_unix_secs(&mut reader)?,
            },
        };

        // Leftover bytes mean the payload did not come from a compatible encoder.
        reader.is_empty().then_some(Self { nonce, payload })
    }
}
/// The content of a [`Token`]. On the wire, the variant is identified by a leading
/// `TokenType` byte.
pub(crate) enum TokenPayload {
    /// Retry token. If a decoded token of this kind fails the address or expiry check, the
    /// server rejects the connection attempt rather than treating it as unvalidated.
    Retry {
        /// Socket address the token was issued to; it must match the remote address exactly.
        address: SocketAddr,
        /// Becomes `IncomingToken::orig_dst_cid` when the token is accepted.
        orig_dst_cid: ConnectionId,
        /// Issue time, serialized with whole-second precision.
        issued: SystemTime,
    },
    /// Address-validation token. It is checked against the client's IP address only; the port
    /// is ignored.
    Validation {
        /// IP address the token was issued to.
        ip: IpAddr,
        /// Issue time, serialized with whole-second precision.
        issued: SystemTime,
    },
}
/// Leading byte of a decrypted token, identifying which `TokenPayload` variant follows.
#[derive(Copy, Clone)]
#[repr(u8)]
enum TokenType {
    Retry = 0,
    Validation = 1,
}

impl TokenType {
    /// Maps a wire byte back to its [`TokenType`], or returns `None` for an unknown byte.
    fn from_byte(n: u8) -> Option<Self> {
        // Match against the discriminants themselves, so the numeric values are written only
        // once, in the enum definition.
        const RETRY: u8 = TokenType::Retry as u8;
        const VALIDATION: u8 = TokenType::Validation as u8;
        match n {
            RETRY => Some(Self::Retry),
            VALIDATION => Some(Self::Validation),
            _ => None,
        }
    }
}
fn encode_addr(buf: &mut Vec<u8>, address: SocketAddr) {
encode_ip(buf, address.ip());
buf.put_u16(address.port());
}
/// Reads a socket address in the layout written by `encode_addr`: an IP address followed by
/// a `u16` port.
///
/// Returns `None` if the buffer runs short or the IP family tag is unknown.
fn decode_addr<B: Buf>(buf: &mut B) -> Option<SocketAddr> {
    let ip = decode_ip(buf)?;
    buf.get::<u16>().ok().map(|port| SocketAddr::new(ip, port))
}
/// Appends `ip` to `buf`: one family tag byte (`0` = IPv4, `1` = IPv6) followed by the raw
/// address octets.
fn encode_ip(buf: &mut Vec<u8>, ip: IpAddr) {
    match ip {
        IpAddr::V4(v4) => {
            buf.push(0);
            buf.extend_from_slice(&v4.octets());
        }
        IpAddr::V6(v6) => {
            buf.push(1);
            buf.extend_from_slice(&v6.octets());
        }
    }
}
/// Reads an IP address in the layout written by `encode_ip`: a family tag byte, then 4 or 16
/// address octets.
///
/// Returns `None` on an unknown tag or a short buffer.
fn decode_ip<B: Buf>(buf: &mut B) -> Option<IpAddr> {
    let tag = buf.get::<u8>().ok()?;
    match tag {
        0 => Some(IpAddr::V4(buf.get().ok()?)),
        1 => Some(IpAddr::V6(buf.get().ok()?)),
        _ => None,
    }
}
/// Appends `time` to `buf` as a `u64` count of whole seconds since the Unix epoch.
///
/// Sub-second precision is discarded. Times before the epoch are written as `0`.
fn encode_unix_secs(buf: &mut Vec<u8>, time: SystemTime) {
    let secs = match time.duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    };
    buf.write::<u64>(secs);
}
/// Reads a timestamp in the layout written by `encode_unix_secs`: a `u64` count of whole
/// seconds since the Unix epoch.
///
/// Returns `None` if the buffer is too short, or if the value cannot be represented as a
/// `SystemTime` on this platform. In the second case, plain `+` would panic.
fn decode_unix_secs<B: Buf>(buf: &mut B) -> Option<SystemTime> {
    let secs = buf.get::<u64>().ok()?;
    UNIX_EPOCH.checked_add(Duration::from_secs(secs))
}
/// Reset token of `RESET_TOKEN_SIZE` bytes, derived from a connection ID by
/// `ResetToken::new`.
// The `allow` is needed because `Hash` is derived while `PartialEq` is hand-written. They stay
// consistent because the manual `PartialEq` compares the same inner bytes the derived `Hash`
// hashes; it only routes the comparison through `crate::constant_time::eq`.
#[allow(clippy::derived_hash_with_manual_eq)] #[derive(Debug, Copy, Clone, Hash)]
pub(crate) struct ResetToken([u8; RESET_TOKEN_SIZE]);
impl ResetToken {
    /// Derives the reset token for `id`.
    ///
    /// The connection ID is signed with `key`, and the first `RESET_TOKEN_SIZE` bytes of the
    /// signature become the token.
    ///
    /// # Panics
    ///
    /// Panics if `key.signature_len()` is smaller than `RESET_TOKEN_SIZE`.
    pub(crate) fn new(key: &dyn HmacKey, id: ConnectionId) -> Self {
        let mut mac = vec![0; key.signature_len()];
        key.sign(&id, &mut mac);
        let mut token = [0u8; RESET_TOKEN_SIZE];
        token.copy_from_slice(&mac[..RESET_TOKEN_SIZE]);
        Self(token)
    }
}
// Equality goes through `crate::constant_time::eq` instead of a derived `==`. Presumably this
// keeps comparison time independent of how many leading bytes match; confirm against
// `constant_time::eq`. Keep this hand-written: a derived impl would lose that property.
impl PartialEq for ResetToken {
    fn eq(&self, other: &Self) -> bool {
        crate::constant_time::eq(&self.0, &other.0)
    }
}

// Equality over fixed-size byte arrays is total, so `Eq` holds.
impl Eq for ResetToken {}
impl From<[u8; RESET_TOKEN_SIZE]> for ResetToken {
fn from(x: [u8; RESET_TOKEN_SIZE]) -> Self {
Self(x)
}
}
impl std::ops::Deref for ResetToken {
    type Target = [u8];

    /// Exposes the token's bytes as a slice.
    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
impl fmt::Display for ResetToken {
    /// Formats the token as lowercase hex, two digits per byte, with no separators.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.iter().try_for_each(|byte| write!(f, "{byte:02x}"))
    }
}
#[cfg(all(test, any(feature = "aws-lc-rs", feature = "ring")))]
mod test {
    use rand::Rng;

    use crate::crypto::ring_like::RetryTokenKey;

    use super::*;

    /// Seals `payload` under a fresh random key and nonce, then decodes it again.
    ///
    /// Asserts that the nonce survived the round trip and returns the decoded payload for the
    /// caller to compare.
    fn token_round_trip(payload: TokenPayload) -> TokenPayload {
        let rng = &mut rand::rng();
        let token = Token::new(payload, rng);
        let master_key = RetryTokenKey::new(rng);
        let encoded = token.encode(&master_key);
        let decoded = Token::decode(&master_key, &encoded).expect("token didn't decrypt / decode");
        assert_eq!(token.nonce, decoded.nonce);
        decoded.payload
    }

    #[test]
    fn retry_token_sanity() {
        use crate::MAX_CID_SIZE;
        use crate::cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator};
        use std::net::Ipv6Addr;

        let address_1 = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 4433);
        let orig_dst_cid_1 = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid();
        // Tokens store whole seconds only, so use a whole-second time to get exact equality.
        let issued_1 = UNIX_EPOCH + Duration::from_secs(42);
        let payload_1 = TokenPayload::Retry {
            address: address_1,
            orig_dst_cid: orig_dst_cid_1,
            issued: issued_1,
        };
        let TokenPayload::Retry {
            address: address_2,
            orig_dst_cid: orig_dst_cid_2,
            issued: issued_2,
        } = token_round_trip(payload_1)
        else {
            panic!("token decoded as wrong variant");
        };
        assert_eq!(address_1, address_2);
        assert_eq!(orig_dst_cid_1, orig_dst_cid_2);
        assert_eq!(issued_1, issued_2);
    }

    #[test]
    fn validation_token_sanity() {
        use std::net::Ipv6Addr;

        let ip_1 = Ipv6Addr::LOCALHOST.into();
        let issued_1 = UNIX_EPOCH + Duration::from_secs(42);
        let payload_1 = TokenPayload::Validation {
            ip: ip_1,
            issued: issued_1,
        };
        let TokenPayload::Validation {
            ip: ip_2,
            issued: issued_2,
        } = token_round_trip(payload_1)
        else {
            panic!("token decoded as wrong variant");
        };
        assert_eq!(ip_1, ip_2);
        assert_eq!(issued_1, issued_2);
    }

    /// Random bytes should not authenticate under a freshly generated key, so decoding must
    /// fail rather than yield a token.
    #[test]
    fn invalid_token_returns_err() {
        let master_key = RetryTokenKey::new(&mut rand::rng());
        let mut invalid_token = [0; 32];
        rand::rng().fill_bytes(&mut invalid_token);
        assert!(Token::decode(&master_key, &invalid_token).is_none());
    }
}