1use std::{
9 fmt,
10 mem::size_of,
11 net::{IpAddr, SocketAddr},
12};
13
14use bytes::{Buf, BufMut, Bytes};
15use rand::Rng;
16
17use crate::{
18 Duration, RESET_TOKEN_SIZE, ServerConfig, SystemTime, UNIX_EPOCH,
19 coding::{BufExt, BufMutExt},
20 crypto::{HandshakeTokenKey, HmacKey},
21 nat_traversal_api::PeerId,
22 packet::InitialHeader,
23 shared::ConnectionId,
24 token_v2::{TokenKey, decode_retry_token},
25};
26
/// Log of accepted validation tokens, consulted to decide whether a token may be accepted.
///
/// `IncomingToken::from_header` calls this for `Validation` tokens only after the IP and
/// lifetime checks pass; an `Err` makes that function treat the connection as unvalidated.
pub trait TokenLog: Send + Sync {
    /// Decides whether the token identified by `nonce` may be accepted.
    ///
    /// As the method name suggests, implementations presumably check whether `nonce` was
    /// already recorded and record it if not. NOTE(review): confirm against implementations.
    ///
    /// `issued` is the issue time decoded from the token, and `lifetime` is the configured
    /// validation-token lifetime, both as passed by `IncomingToken::from_header`.
    ///
    /// # Errors
    /// Returns `TokenReuseError` to reject the token.
    fn check_and_insert(
        &self,
        nonce: u128,
        issued: SystemTime,
        lifetime: Duration,
    ) -> Result<(), TokenReuseError>;
}
74
/// Error returned by `TokenLog::check_and_insert` to reject a validation token.
///
/// As the name indicates, the typical reason is that the token may already have been used.
#[derive(Debug)]
pub struct TokenReuseError;
77
/// `TokenLog` that keeps no state and rejects every token passed to `check_and_insert`.
pub(crate) struct NoneTokenLog;
80
81impl TokenLog for NoneTokenLog {
82 fn check_and_insert(&self, _: u128, _: SystemTime, _: Duration) -> Result<(), TokenReuseError> {
83 Err(TokenReuseError)
84 }
85}
86
/// Storage for tokens, keyed by server name.
///
/// NOTE(review): presumably used on the client side to keep tokens received from a server for
/// use on later connections; confirm with callers.
pub trait TokenStore: Send + Sync {
    /// Stores `token` under `server_name`.
    fn insert(&self, server_name: &str, token: Bytes);

    /// Returns a token for `server_name`, if one is available.
    ///
    /// The name `take` suggests the token is removed from the store; confirm in implementations.
    fn take(&self, server_name: &str) -> Option<Bytes>;
}
103
/// `TokenStore` that stores nothing: `insert` discards tokens and `take` always returns `None`.
// NOTE(review): `dead_code` is presumably allowed because this type is not constructed in every
// build configuration; confirm whether the allow is still needed.
#[allow(dead_code)]
pub(crate) struct NoneTokenStore;
107
108impl TokenStore for NoneTokenStore {
109 fn insert(&self, _: &str, _: Bytes) {}
110 fn take(&self, _: &str) -> Option<Bytes> {
111 None
112 }
113}
114
/// Outcome of interpreting the token in a client's Initial packet (see `from_header`).
#[derive(Debug)]
pub(crate) struct IncomingToken {
    /// Set to the Initial's destination CID when a retry token was accepted; `None` otherwise.
    pub(crate) retry_src_cid: Option<ConnectionId>,
    /// Taken from an accepted `Retry` token; otherwise the Initial's destination CID.
    pub(crate) orig_dst_cid: ConnectionId,
    /// `true` when an accepted retry or validation token was present.
    pub(crate) validated: bool,
    /// Peer ID carried by a v2 retry token, when the token was decoded in that format.
    pub(crate) peer_id: Option<PeerId>,
}
123
124impl IncomingToken {
125 pub(crate) fn from_header(
128 header: &InitialHeader,
129 server_config: &ServerConfig,
130 remote_address: SocketAddr,
131 ) -> Result<Self, InvalidRetryTokenError> {
132 let unvalidated = Self {
133 retry_src_cid: None,
134 orig_dst_cid: header.dst_cid,
135 validated: false,
136 peer_id: None,
137 };
138
139 if header.token.is_empty() {
141 return Ok(unvalidated);
142 }
143
144 let Some(retry) = Token::decode(server_config.token_key.as_ref(), &header.token) else {
156 if let Some(v2_token) = try_decode_v2_token(
158 &header.token,
159 server_config.token_v2_key.as_ref(),
160 &header.dst_cid,
161 ) {
162 return Ok(Self {
163 retry_src_cid: Some(header.dst_cid),
164 orig_dst_cid: header.dst_cid,
165 validated: true,
166 peer_id: Some(v2_token.peer_id),
167 });
168 }
169 return Ok(unvalidated);
170 };
171
172 match retry.payload {
174 TokenPayload::Retry {
175 address,
176 orig_dst_cid,
177 issued,
178 } => {
179 if address != remote_address {
180 return Err(InvalidRetryTokenError);
181 }
182 if issued + server_config.retry_token_lifetime < server_config.time_source.now() {
183 return Err(InvalidRetryTokenError);
184 }
185
186 Ok(Self {
187 retry_src_cid: Some(header.dst_cid),
188 orig_dst_cid,
189 validated: true,
190 peer_id: None,
191 })
192 }
193 TokenPayload::Validation { ip, issued } => {
194 if ip != remote_address.ip() {
195 return Ok(unvalidated);
196 }
197 if issued + server_config.validation_token.lifetime
198 < server_config.time_source.now()
199 {
200 return Ok(unvalidated);
201 }
202 if server_config
203 .validation_token
204 .log
205 .check_and_insert(retry.nonce, issued, server_config.validation_token.lifetime)
206 .is_err()
207 {
208 return Ok(unvalidated);
209 }
210
211 Ok(Self {
212 retry_src_cid: None,
213 orig_dst_cid: header.dst_cid,
214 validated: true,
215 peer_id: None,
216 })
217 }
218 }
219 }
220}
221
/// Error returned by `IncomingToken::from_header` for a decodable `Retry` token that is
/// invalid, i.e. bound to a different remote address or past its lifetime.
#[derive(Debug)]
pub(crate) struct InvalidRetryTokenError;
226
/// A server-issued token: a payload plus the random nonce from which its encryption key is
/// derived.
pub(crate) struct Token {
    /// Content of the token.
    pub(crate) payload: TokenPayload,
    /// Random value drawn in `Token::new`. It is used to derive the AEAD key and is appended
    /// unencrypted to the encoded token.
    nonce: u128,
}
234
235impl Token {
236 pub(crate) fn new(payload: TokenPayload, rng: &mut impl Rng) -> Self {
238 Self {
239 nonce: rng.r#gen(),
240 payload,
241 }
242 }
243
244 pub(crate) fn encode(&self, key: &dyn HandshakeTokenKey) -> Vec<u8> {
246 let mut buf = Vec::new();
247
248 match self.payload {
250 TokenPayload::Retry {
251 address,
252 orig_dst_cid,
253 issued,
254 } => {
255 buf.put_u8(TokenType::Retry as u8);
256 encode_addr(&mut buf, address);
257 orig_dst_cid.encode_long(&mut buf);
258 encode_unix_secs(&mut buf, issued);
259 }
260 TokenPayload::Validation { ip, issued } => {
261 buf.put_u8(TokenType::Validation as u8);
262 encode_ip(&mut buf, ip);
263 encode_unix_secs(&mut buf, issued);
264 }
265 }
266
267 let aead_key = key.aead_from_hkdf(&self.nonce.to_le_bytes());
269 if aead_key.seal(&mut buf, &[]).is_err() {
270 return Vec::new();
272 }
273 buf.extend(&self.nonce.to_le_bytes());
274
275 buf
276 }
277
278 fn decode(key: &dyn HandshakeTokenKey, raw_token_bytes: &[u8]) -> Option<Self> {
280 let nonce_slice_start = raw_token_bytes.len().checked_sub(size_of::<u128>())?;
284 let (sealed_token, nonce_bytes) = raw_token_bytes.split_at(nonce_slice_start);
285
286 let nonce = u128::from_le_bytes(nonce_bytes.try_into().ok()?);
287
288 let aead_key = key.aead_from_hkdf(nonce_bytes);
289 let mut sealed_token = sealed_token.to_vec();
290 let data = aead_key.open(&mut sealed_token, &[]).ok()?;
291
292 let mut reader = &data[..];
294 let payload = match TokenType::from_byte((&mut reader).get::<u8>().ok()?)? {
295 TokenType::Retry => TokenPayload::Retry {
296 address: decode_addr(&mut reader)?,
297 orig_dst_cid: ConnectionId::decode_long(&mut reader)?,
298 issued: decode_unix_secs(&mut reader)?,
299 },
300 TokenType::Validation => TokenPayload::Validation {
301 ip: decode_ip(&mut reader)?,
302 issued: decode_unix_secs(&mut reader)?,
303 },
304 };
305
306 if !reader.is_empty() {
307 return None;
309 }
310
311 Some(Self { nonce, payload })
312 }
313}
314
/// Content of a server-issued `Token`.
pub(crate) enum TokenPayload {
    /// Retry token: on a mismatch or expiry, `IncomingToken::from_header` rejects it with
    /// `InvalidRetryTokenError`.
    Retry {
        /// Remote socket address (IP and port) the token is bound to.
        address: SocketAddr,
        /// Original destination CID recorded at issue; becomes `IncomingToken::orig_dst_cid`.
        orig_dst_cid: ConnectionId,
        /// Issue time; encoded with whole-second precision by `encode_unix_secs`.
        issued: SystemTime,
    },
    /// Validation token: checked against the remote IP, its lifetime, and the token log.
    Validation {
        /// Remote IP address the token is bound to (the port is not included).
        ip: IpAddr,
        /// Issue time; encoded with whole-second precision by `encode_unix_secs`.
        issued: SystemTime,
    },
}
334
/// Type byte that begins every token plaintext, selecting the `TokenPayload` variant.
#[derive(Copy, Clone)]
#[repr(u8)]
enum TokenType {
    Retry = 0,
    Validation = 1,
}

impl TokenType {
    /// Maps a type byte back to its variant, or `None` for an unknown value.
    fn from_byte(n: u8) -> Option<Self> {
        // Derive the patterns from the discriminants so they cannot drift apart.
        const RETRY: u8 = TokenType::Retry as u8;
        const VALIDATION: u8 = TokenType::Validation as u8;
        match n {
            RETRY => Some(Self::Retry),
            VALIDATION => Some(Self::Validation),
            _ => None,
        }
    }
}
349
350fn encode_addr(buf: &mut Vec<u8>, address: SocketAddr) {
351 encode_ip(buf, address.ip());
352 buf.put_u16(address.port());
353}
354
355fn decode_addr<B: Buf>(buf: &mut B) -> Option<SocketAddr> {
356 let ip = decode_ip(buf)?;
357 let port = buf.get().ok()?;
358 Some(SocketAddr::new(ip, port))
359}
360
361fn encode_ip(buf: &mut Vec<u8>, ip: IpAddr) {
362 match ip {
363 IpAddr::V4(x) => {
364 buf.put_u8(0);
365 buf.put_slice(&x.octets());
366 }
367 IpAddr::V6(x) => {
368 buf.put_u8(1);
369 buf.put_slice(&x.octets());
370 }
371 }
372}
373
374fn decode_ip<B: Buf>(buf: &mut B) -> Option<IpAddr> {
375 match buf.get::<u8>().ok()? {
376 0 => buf.get().ok().map(IpAddr::V4),
377 1 => buf.get().ok().map(IpAddr::V6),
378 _ => None,
379 }
380}
381
382fn encode_unix_secs(buf: &mut Vec<u8>, time: SystemTime) {
383 buf.write::<u64>(
384 time.duration_since(UNIX_EPOCH)
385 .unwrap_or_default()
386 .as_secs(),
387 );
388}
389
390fn decode_unix_secs<B: Buf>(buf: &mut B) -> Option<SystemTime> {
391 Some(UNIX_EPOCH + Duration::from_secs(buf.get::<u64>().ok()?))
392}
393
394fn try_decode_v2_token(
399 token_bytes: &[u8],
400 token_key: Option<&TokenKey>,
401 expected_cid: &ConnectionId,
402) -> Option<crate::token_v2::RetryTokenDecoded> {
403 let key = token_key?;
404 let decoded = decode_retry_token(key, token_bytes)?;
405 if &decoded.cid != expected_cid {
406 return None;
407 }
408 Some(decoded)
409}
410
/// Stateless reset token: `RESET_TOKEN_SIZE` bytes derived from a connection ID.
// `Hash` is derived while `PartialEq` is implemented by hand. The manual impl compares the
// same byte array, via `crate::constant_time::eq`, so the two stay consistent. Hence the
// clippy allow.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Copy, Clone, Hash)]
pub(crate) struct ResetToken([u8; RESET_TOKEN_SIZE]);
417
418impl ResetToken {
419 pub(crate) fn new(key: &dyn HmacKey, id: ConnectionId) -> Self {
420 let mut signature = vec![0; key.signature_len()];
421 key.sign(&id, &mut signature);
422 let mut result = [0; RESET_TOKEN_SIZE];
424 result.copy_from_slice(&signature[..RESET_TOKEN_SIZE]);
425 result.into()
426 }
427}
428
429impl PartialEq for ResetToken {
430 fn eq(&self, other: &Self) -> bool {
431 crate::constant_time::eq(&self.0, &other.0)
432 }
433}
434
435impl Eq for ResetToken {}
436
437impl From<[u8; RESET_TOKEN_SIZE]> for ResetToken {
438 fn from(x: [u8; RESET_TOKEN_SIZE]) -> Self {
439 Self(x)
440 }
441}
442
443impl std::ops::Deref for ResetToken {
444 type Target = [u8];
445 fn deref(&self) -> &[u8] {
446 &self.0
447 }
448}
449
450impl fmt::Display for ResetToken {
451 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
452 for byte in self.iter() {
453 write!(f, "{byte:02x}")?;
454 }
455 Ok(())
456 }
457}
458
#[cfg(all(test, any(feature = "aws-lc-rs", feature = "ring")))]
mod test {
    use super::*;
    #[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
    use aws_lc_rs::hkdf;
    use rand::prelude::*;
    #[cfg(feature = "ring")]
    use ring::hkdf;

    /// Encodes `payload` under a random key, decodes it again, and returns the decoded payload.
    fn token_round_trip(payload: TokenPayload) -> TokenPayload {
        let rng = &mut rand::thread_rng();
        let token = Token::new(payload, rng);
        let mut master_key = [0; 64];
        rng.fill_bytes(&mut master_key);
        let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key);
        let encoded = token.encode(&prk);
        let decoded = Token::decode(&prk, &encoded).expect("token didn't decrypt / decode");
        assert_eq!(token.nonce, decoded.nonce);
        decoded.payload
    }

    #[test]
    fn retry_token_sanity() {
        use crate::MAX_CID_SIZE;
        use crate::cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator};
        use crate::{Duration, UNIX_EPOCH};

        use std::net::Ipv6Addr;

        let address_1 = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 4433);
        let orig_dst_cid_1 = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid();
        let issued_1 = UNIX_EPOCH + Duration::from_secs(42);
        let payload_1 = TokenPayload::Retry {
            address: address_1,
            orig_dst_cid: orig_dst_cid_1,
            issued: issued_1,
        };
        let TokenPayload::Retry {
            address: address_2,
            orig_dst_cid: orig_dst_cid_2,
            issued: issued_2,
        } = token_round_trip(payload_1)
        else {
            panic!("token decoded as wrong variant");
        };

        assert_eq!(address_1, address_2);
        assert_eq!(orig_dst_cid_1, orig_dst_cid_2);
        assert_eq!(issued_1, issued_2);
    }

    #[test]
    fn validation_token_sanity() {
        use crate::{Duration, UNIX_EPOCH};

        use std::net::Ipv6Addr;

        let ip_1 = Ipv6Addr::LOCALHOST.into();
        let issued_1 = UNIX_EPOCH + Duration::from_secs(42);
        let payload_1 = TokenPayload::Validation {
            ip: ip_1,
            issued: issued_1,
        };
        let TokenPayload::Validation {
            ip: ip_2,
            issued: issued_2,
        } = token_round_trip(payload_1)
        else {
            panic!("token decoded as wrong variant");
        };

        assert_eq!(ip_1, ip_2);
        assert_eq!(issued_1, issued_2);
    }

    #[test]
    fn invalid_token_returns_err() {
        let rng = &mut rand::thread_rng();

        let mut master_key = [0; 64];
        rng.fill_bytes(&mut master_key);
        let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key);

        // Long enough to contain a nonce, but random bytes will not pass AEAD authentication.
        let mut invalid_token = [0u8; 32];
        rng.fill_bytes(&mut invalid_token);

        assert!(Token::decode(&prk, &invalid_token).is_none());
    }

    #[test]
    fn token_shorter_than_nonce_returns_none() {
        let rng = &mut rand::thread_rng();

        let mut master_key = [0; 64];
        rng.fill_bytes(&mut master_key);
        let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key);

        // Inputs shorter than the 16-byte nonce must be rejected without panicking.
        assert!(Token::decode(&prk, &[]).is_none());
        assert!(Token::decode(&prk, &[0u8; 15]).is_none());
    }
}