#![allow(dead_code)]
use alloc::vec::Vec;
use crate::ct::ConstantTimeEq;
use crate::hash::HmacSha256;
use crate::tls::Error;
/// Oldest token age, in seconds, that `validate` still accepts (five minutes).
/// A token whose age is exactly this value is accepted; one second more is rejected.
pub(crate) const MAX_TOKEN_AGE_SECS: u64 = 5 * 60;
/// Width of the encoded client address used in tokens: a 16-byte IPv6 address
/// (IPv4 is stored IPv4-mapped) followed by a 2-byte big-endian port.
pub(crate) const CLIENT_ADDR_BYTES: usize = 16 + 2;
/// Number of leading `HmacSha256::mac` output bytes kept as the token's tag.
const TAG_LEN: usize = 16;
/// Builds a Retry token binding `client_addr_bytes`, `odcid` and `now_secs`.
///
/// Token layout (all fields concatenated, in this order):
/// - `client_addr_bytes` (`CLIENT_ADDR_BYTES` bytes)
/// - ODCID length (1 byte)
/// - ODCID (`odcid.len()` bytes)
/// - `now_secs` as a big-endian `u64` (8 bytes)
/// - tag: the first `TAG_LEN` bytes of `HmacSha256::mac(retry_secret, <all preceding bytes>)`
///
/// The ODCID length limit of 20 bytes is only checked in debug builds. In release builds a
/// longer ODCID still produces a token, but `validate` will reject it.
pub(crate) fn mint(
    retry_secret: &[u8; 32],
    client_addr_bytes: &[u8; CLIENT_ADDR_BYTES],
    odcid: &[u8],
    now_secs: u64,
) -> Vec<u8> {
    debug_assert!(odcid.len() <= 20, "QUIC v1 CID length must be ≤ 20 bytes");

    // address + length byte + ODCID + timestamp + tag
    let total_len = CLIENT_ADDR_BYTES + 1 + odcid.len() + 8 + TAG_LEN;
    let mut token = Vec::with_capacity(total_len);

    token.extend_from_slice(client_addr_bytes);
    token.push(odcid.len() as u8);
    token.extend_from_slice(odcid);
    token.extend_from_slice(&now_secs.to_be_bytes());

    // Everything written so far is authenticated; append the truncated tag last.
    let tag = HmacSha256::mac(retry_secret, token.as_slice());
    token.extend_from_slice(&tag[..TAG_LEN]);
    token
}
pub(crate) fn validate(
retry_secret: &[u8; 32],
client_addr_bytes: &[u8; CLIENT_ADDR_BYTES],
token: &[u8],
now_secs: u64,
) -> Result<Vec<u8>, Error> {
if token.len() < CLIENT_ADDR_BYTES + 1 + 8 + TAG_LEN {
return Err(Error::Decode);
}
let addr_in_token = &token[..CLIENT_ADDR_BYTES];
if addr_in_token != client_addr_bytes.as_slice() {
return Err(Error::Decode);
}
let odcid_len = token[CLIENT_ADDR_BYTES] as usize;
if odcid_len > 20 {
return Err(Error::Decode);
}
let odcid_start = CLIENT_ADDR_BYTES + 1;
let odcid_end = odcid_start + odcid_len;
let ts_start = odcid_end;
let ts_end = ts_start + 8;
let tag_start = ts_end;
let tag_end = tag_start + TAG_LEN;
if token.len() != tag_end {
return Err(Error::Decode);
}
let body = &token[..tag_start];
let computed = HmacSha256::mac(retry_secret, body);
let provided = &token[tag_start..tag_end];
let ok = computed[..TAG_LEN].ct_eq(provided);
if !bool::from(ok) {
return Err(Error::Decode);
}
let mut ts_bytes = [0u8; 8];
ts_bytes.copy_from_slice(&token[ts_start..ts_end]);
let ts = u64::from_be_bytes(ts_bytes);
if ts > now_secs {
return Err(Error::Decode);
}
if now_secs - ts > MAX_TOKEN_AGE_SECS {
return Err(Error::Decode);
}
Ok(token[odcid_start..odcid_end].to_vec())
}
#[cfg(feature = "std")]
/// Encodes a socket address into the fixed-width form embedded in retry tokens.
///
/// The first 16 bytes are the IPv6 address octets; an IPv4 address is first converted to
/// its IPv4-mapped IPv6 form, so `a.b.c.d:p` and `[::ffff:a.b.c.d]:p` encode identically.
/// The last 2 bytes are the port in big-endian order. Only the IP and port are encoded;
/// IPv6 flow info and scope id are not part of the encoding.
pub(crate) fn encode_addr(addr: &std::net::SocketAddr) -> [u8; CLIENT_ADDR_BYTES] {
    let ip_octets = match addr.ip() {
        std::net::IpAddr::V6(v6) => v6.octets(),
        std::net::IpAddr::V4(v4) => v4.to_ipv6_mapped().octets(),
    };
    let port_octets = addr.port().to_be_bytes();

    let mut encoded = [0u8; CLIENT_ADDR_BYTES];
    let (ip_part, port_part) = encoded.split_at_mut(16);
    ip_part.copy_from_slice(&ip_octets);
    port_part.copy_from_slice(&port_octets);
    encoded
}
// Every test builds its address through `encode_addr`, which only exists with the `std`
// feature. Gating on both `test` and `std` keeps `cargo test --no-default-features` compiling.
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

    /// Deterministic 32-byte secret (`0, 1, 2, …, 31`) shared by the tests.
    fn fixed_secret() -> [u8; 32] {
        let mut s = [0u8; 32];
        for (i, b) in s.iter_mut().enumerate() {
            *b = i as u8;
        }
        s
    }

    #[test]
    fn retry_token_roundtrip() {
        let secret = fixed_secret();
        let addr = encode_addr(&SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(192, 0, 2, 1)),
            4433,
        ));
        let odcid = [0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08];
        let now = 1000u64;
        let tok = mint(&secret, &addr, &odcid, now);
        let got = validate(&secret, &addr, &tok, now).expect("validate ok");
        assert_eq!(got, odcid);
    }

    #[test]
    fn retry_token_roundtrip_max_cid_len() {
        // 20 bytes is the largest ODCID `validate` accepts.
        let secret = fixed_secret();
        let addr = encode_addr(&SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0));
        let odcid = [0x5a; 20];
        let tok = mint(&secret, &addr, &odcid, 100);
        let got = validate(&secret, &addr, &tok, 100).expect("validate ok");
        assert_eq!(got, odcid);
    }

    #[test]
    fn retry_token_rejects_wrong_addr() {
        let secret = fixed_secret();
        let addr1 = encode_addr(&SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(192, 0, 2, 1)),
            4433,
        ));
        let addr2 = encode_addr(&SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(192, 0, 2, 2)),
            4433,
        ));
        let odcid = [0xaa; 8];
        let tok = mint(&secret, &addr1, &odcid, 1000);
        let err = validate(&secret, &addr2, &tok, 1000);
        assert!(err.is_err());
    }

    #[test]
    fn retry_token_rejects_wrong_secret() {
        let secret_a = fixed_secret();
        let mut secret_b = fixed_secret();
        secret_b[0] ^= 1;
        let addr = encode_addr(&SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0));
        let tok = mint(&secret_a, &addr, &[1, 2, 3, 4], 100);
        let err = validate(&secret_b, &addr, &tok, 100);
        assert!(err.is_err());
    }

    #[test]
    fn retry_token_rejects_expired() {
        let secret = fixed_secret();
        let addr = encode_addr(&SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0));
        let odcid = [0xab; 8];
        let tok = mint(&secret, &addr, &odcid, 100);
        // Age of exactly MAX_TOKEN_AGE_SECS (300) is still accepted; one second more is not.
        assert!(validate(&secret, &addr, &tok, 400).is_ok());
        assert!(validate(&secret, &addr, &tok, 401).is_err());
    }

    #[test]
    fn retry_token_rejects_future_timestamp() {
        let secret = fixed_secret();
        let addr = encode_addr(&SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0));
        let tok = mint(&secret, &addr, &[0xcd; 4], 500);
        let err = validate(&secret, &addr, &tok, 100);
        assert!(err.is_err());
    }

    #[test]
    fn retry_token_rejects_tampered_hmac() {
        let secret = fixed_secret();
        let addr = encode_addr(&SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(192, 0, 2, 9)),
            7777,
        ));
        let odcid = [0xde, 0xad, 0xbe, 0xef];
        let mut tok = mint(&secret, &addr, &odcid, 1234);
        let last = tok.len() - 1;
        tok[last] ^= 1;
        assert!(validate(&secret, &addr, &tok, 1234).is_err());
    }

    #[test]
    fn retry_token_rejects_tampered_body_bytes() {
        let secret = fixed_secret();
        let addr = encode_addr(&SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(192, 0, 2, 9)),
            7777,
        ));
        let odcid = [0xde, 0xad, 0xbe, 0xef];
        let mut tok = mint(&secret, &addr, &odcid, 1234);
        // First ODCID byte: just past the address and the length prefix.
        let body_offset = CLIENT_ADDR_BYTES + 1;
        tok[body_offset] ^= 1;
        assert!(validate(&secret, &addr, &tok, 1234).is_err());
    }

    #[test]
    fn retry_token_rejects_tampered_timestamp() {
        let secret = fixed_secret();
        let addr = encode_addr(&SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0));
        let odcid = [0x11; 8];
        let mut tok = mint(&secret, &addr, &odcid, 1000);
        assert!(validate(&secret, &addr, &tok, 1100).is_ok());
        // Flip the low bit of the big-endian timestamp (1000 -> 1001). The altered time is
        // still inside the age window at 1100, so only the MAC can reject it.
        let ts_last = CLIENT_ADDR_BYTES + 1 + odcid.len() + 7;
        tok[ts_last] ^= 1;
        assert!(validate(&secret, &addr, &tok, 1100).is_err());
    }

    #[test]
    fn retry_token_rejects_oversized_cid_len() {
        let secret = fixed_secret();
        let addr = encode_addr(&SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0));
        let mut tok = mint(&secret, &addr, &[0u8; 20], 100);
        assert!(validate(&secret, &addr, &tok, 100).is_ok());
        tok[CLIENT_ADDR_BYTES] = 21;
        assert!(validate(&secret, &addr, &tok, 100).is_err());
    }

    #[test]
    fn retry_token_rejects_short_token() {
        let secret = fixed_secret();
        let addr = encode_addr(&SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0));
        assert!(validate(&secret, &addr, &[], 0).is_err());
        // One byte shorter than the 43-byte minimum (address + len + timestamp + tag).
        assert!(validate(&secret, &addr, &[0u8; 42], 0).is_err());
    }

    #[test]
    fn retry_token_rejects_extra_trailing_bytes() {
        let secret = fixed_secret();
        let addr = encode_addr(&SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0));
        let mut tok = mint(&secret, &addr, &[0u8; 8], 100);
        tok.push(0);
        assert!(validate(&secret, &addr, &tok, 100).is_err());
    }

    #[test]
    fn encode_addr_ipv4_mapped_matches_ipv6() {
        let a = encode_addr(&SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            4242,
        ));
        let v6 = Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x7f00, 0x0001);
        let b = encode_addr(&SocketAddr::new(IpAddr::V6(v6), 4242));
        assert_eq!(a, b);
    }

    /// Checks that corrupting any single tag byte is rejected. Despite the name, this is a
    /// functional check only; it does not measure timing of the constant-time comparison.
    #[test]
    fn retry_token_constant_time_compare() {
        let secret = fixed_secret();
        let addr = encode_addr(&SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            1234,
        ));
        let tok = mint(&secret, &addr, &[1, 2, 3, 4], 1000);
        let tag_start = tok.len() - TAG_LEN;
        for i in tag_start..tok.len() {
            let mut bad = tok.clone();
            bad[i] ^= 1;
            assert!(
                validate(&secret, &addr, &bad, 1000).is_err(),
                "tag corruption at byte {i} accepted"
            );
        }
    }
}