1use std::{
9 fmt,
10 mem::size_of,
11 net::{IpAddr, SocketAddr},
12};
13
14use bytes::{Buf, BufMut, Bytes};
15use rand::Rng;
16
17use crate::{
18 Duration, RESET_TOKEN_SIZE, ServerConfig, SystemTime, UNIX_EPOCH,
19 coding::{BufExt, BufMutExt},
20 crypto::{HandshakeTokenKey, HmacKey},
21 nat_traversal_api::PeerId,
22 packet::InitialHeader,
23 shared::ConnectionId,
24 token_v2::{TokenKey, decode_retry_token},
25};
26
/// Record of token nonces, consulted before a validation token is accepted.
///
/// Implementations must be shareable across threads (`Send + Sync`).
pub trait TokenLog: Send + Sync {
    /// Checks the token identified by `nonce` and records it.
    ///
    /// `issued` and `lifetime` are the issue time and validity period the caller
    /// associates with the token.
    ///
    /// Returns `Ok(())` if the token may be accepted and `Err(TokenReuseError)`
    /// if it must be rejected.
    fn check_and_insert(
        &self,
        nonce: u128,
        issued: SystemTime,
        lifetime: Duration,
    ) -> Result<(), TokenReuseError>;
}
74
/// Error returned by [`TokenLog::check_and_insert`] when a token is rejected.
pub struct TokenReuseError;
77
/// Stateless [`TokenLog`] that rejects every token.
pub(crate) struct NoneTokenLog;

impl TokenLog for NoneTokenLog {
    /// Ignores all arguments and always returns `Err(TokenReuseError)`.
    fn check_and_insert(&self, _: u128, _: SystemTime, _: Duration) -> Result<(), TokenReuseError> {
        Err(TokenReuseError)
    }
}
86
/// Storage for opaque tokens, keyed by server name.
///
/// Implementations must be shareable across threads (`Send + Sync`).
pub trait TokenStore: Send + Sync {
    /// Hands `token` to the store under `server_name`.
    fn insert(&self, server_name: &str, token: Bytes);

    /// Returns a token for `server_name`, or `None` if the store has none to offer.
    // NOTE(review): the name suggests the returned token is removed from the store
    // (single use) — confirm against the concrete implementations.
    fn take(&self, server_name: &str) -> Option<Bytes>;
}
103
/// [`TokenStore`] that stores nothing and never returns a token.
// NOTE(review): `dead_code` is silenced here; presumably the type is not referenced in
// every build configuration — confirm before removing the attribute.
#[allow(dead_code)]
pub(crate) struct NoneTokenStore;

impl TokenStore for NoneTokenStore {
    /// Drops the token without storing it.
    fn insert(&self, _: &str, _: Bytes) {}
    /// Always returns `None`.
    fn take(&self, _: &str) -> Option<Bytes> {
        None
    }
}
114
/// Outcome of inspecting the token carried by an incoming Initial packet.
#[derive(Debug)]
pub(crate) struct IncomingToken {
    /// Set to the header's destination connection ID when a retry-style token was accepted;
    /// `None` otherwise.
    pub(crate) retry_src_cid: Option<ConnectionId>,
    /// Original destination connection ID: taken from the retry token when one was accepted,
    /// otherwise the header's destination connection ID.
    pub(crate) orig_dst_cid: ConnectionId,
    /// Whether the token was accepted for the remote address.
    pub(crate) validated: bool,
}
122
123impl IncomingToken {
124 pub(crate) fn from_header(
127 header: &InitialHeader,
128 server_config: &ServerConfig,
129 remote_address: SocketAddr,
130 ) -> Result<Self, InvalidRetryTokenError> {
131 let unvalidated = Self {
132 retry_src_cid: None,
133 orig_dst_cid: header.dst_cid,
134 validated: false,
135 };
136
137 if header.token.is_empty() {
139 return Ok(unvalidated);
140 }
141
142 let Some(retry) = Token::decode(server_config.token_key.as_ref(), &header.token) else {
154 if let Some(_v2_token) =
156 try_decode_v2_token(&header.token, server_config.token_key.as_ref())
157 {
158 return Ok(Self {
161 retry_src_cid: Some(header.dst_cid),
162 orig_dst_cid: header.dst_cid,
163 validated: true,
164 });
165 }
166 return Ok(unvalidated);
167 };
168
169 match retry.payload {
171 TokenPayload::Retry {
172 address,
173 orig_dst_cid,
174 issued,
175 } => {
176 if address != remote_address {
177 return Err(InvalidRetryTokenError);
178 }
179 if issued + server_config.retry_token_lifetime < server_config.time_source.now() {
180 return Err(InvalidRetryTokenError);
181 }
182
183 Ok(Self {
184 retry_src_cid: Some(header.dst_cid),
185 orig_dst_cid,
186 validated: true,
187 })
188 }
189 TokenPayload::Validation { ip, issued } => {
190 if ip != remote_address.ip() {
191 return Ok(unvalidated);
192 }
193 if issued + server_config.validation_token.lifetime
194 < server_config.time_source.now()
195 {
196 return Ok(unvalidated);
197 }
198 if server_config
199 .validation_token
200 .log
201 .check_and_insert(retry.nonce, issued, server_config.validation_token.lifetime)
202 .is_err()
203 {
204 return Ok(unvalidated);
205 }
206
207 Ok(Self {
208 retry_src_cid: None,
209 orig_dst_cid: header.dst_cid,
210 validated: true,
211 })
212 }
213 }
214 }
215}
216
/// Error returned by [`IncomingToken::from_header`] when a retry token decodes correctly but
/// was issued for a different address or has expired.
pub(crate) struct InvalidRetryTokenError;
221
/// A token in decoded form: a payload plus the nonce it is sealed under.
pub(crate) struct Token {
    /// Content of the token.
    pub(crate) payload: TokenPayload,
    /// Value drawn from the RNG at creation. It is the HKDF input for the AEAD key and is
    /// appended in the clear (little-endian) to the encoded token.
    nonce: u128,
}
229
230impl Token {
231 pub(crate) fn new(payload: TokenPayload, rng: &mut impl Rng) -> Self {
233 Self {
234 nonce: rng.r#gen(),
235 payload,
236 }
237 }
238
239 pub(crate) fn encode(&self, key: &dyn HandshakeTokenKey) -> Vec<u8> {
241 let mut buf = Vec::new();
242
243 match self.payload {
245 TokenPayload::Retry {
246 address,
247 orig_dst_cid,
248 issued,
249 } => {
250 buf.put_u8(TokenType::Retry as u8);
251 encode_addr(&mut buf, address);
252 orig_dst_cid.encode_long(&mut buf);
253 encode_unix_secs(&mut buf, issued);
254 }
255 TokenPayload::Validation { ip, issued } => {
256 buf.put_u8(TokenType::Validation as u8);
257 encode_ip(&mut buf, ip);
258 encode_unix_secs(&mut buf, issued);
259 }
260 }
261
262 let aead_key = key.aead_from_hkdf(&self.nonce.to_le_bytes());
264 if aead_key.seal(&mut buf, &[]).is_err() {
265 return Vec::new();
267 }
268 buf.extend(&self.nonce.to_le_bytes());
269
270 buf
271 }
272
273 fn decode(key: &dyn HandshakeTokenKey, raw_token_bytes: &[u8]) -> Option<Self> {
275 let nonce_slice_start = raw_token_bytes.len().checked_sub(size_of::<u128>())?;
279 let (sealed_token, nonce_bytes) = raw_token_bytes.split_at(nonce_slice_start);
280
281 let nonce = u128::from_le_bytes(nonce_bytes.try_into().ok()?);
282
283 let aead_key = key.aead_from_hkdf(nonce_bytes);
284 let mut sealed_token = sealed_token.to_vec();
285 let data = aead_key.open(&mut sealed_token, &[]).ok()?;
286
287 let mut reader = &data[..];
289 let payload = match TokenType::from_byte((&mut reader).get::<u8>().ok()?)? {
290 TokenType::Retry => TokenPayload::Retry {
291 address: decode_addr(&mut reader)?,
292 orig_dst_cid: ConnectionId::decode_long(&mut reader)?,
293 issued: decode_unix_secs(&mut reader)?,
294 },
295 TokenType::Validation => TokenPayload::Validation {
296 ip: decode_ip(&mut reader)?,
297 issued: decode_unix_secs(&mut reader)?,
298 },
299 };
300
301 if !reader.is_empty() {
302 return None;
304 }
305
306 Some(Self { nonce, payload })
307 }
308}
309
/// Content carried inside a [`Token`].
pub(crate) enum TokenPayload {
    /// Token bound to a full socket address.
    Retry {
        /// Socket address the token is bound to.
        address: SocketAddr,
        /// Destination connection ID recorded in the token.
        orig_dst_cid: ConnectionId,
        /// Time the token was issued (encoded with whole-second resolution).
        issued: SystemTime,
    },
    /// Token bound to an IP address only.
    Validation {
        /// IP address the token is bound to.
        ip: IpAddr,
        /// Time the token was issued (encoded with whole-second resolution).
        issued: SystemTime,
    },
}
329
/// Tag byte identifying which payload variant an encoded token carries.
#[derive(Copy, Clone)]
#[repr(u8)]
enum TokenType {
    Retry = 0,
    Validation = 1,
}

impl TokenType {
    /// Maps a tag byte back to its variant, or `None` for an unknown value.
    fn from_byte(n: u8) -> Option<Self> {
        const RETRY: u8 = TokenType::Retry as u8;
        const VALIDATION: u8 = TokenType::Validation as u8;

        match n {
            RETRY => Some(Self::Retry),
            VALIDATION => Some(Self::Validation),
            _ => None,
        }
    }
}
344
345fn encode_addr(buf: &mut Vec<u8>, address: SocketAddr) {
346 encode_ip(buf, address.ip());
347 buf.put_u16(address.port());
348}
349
350fn decode_addr<B: Buf>(buf: &mut B) -> Option<SocketAddr> {
351 let ip = decode_ip(buf)?;
352 let port = buf.get().ok()?;
353 Some(SocketAddr::new(ip, port))
354}
355
/// Appends `ip` to `buf` as a one-byte family tag (`0` = IPv4, `1` = IPv6) followed by the
/// address octets (4 or 16 bytes).
fn encode_ip(buf: &mut Vec<u8>, ip: IpAddr) {
    match ip {
        IpAddr::V4(v4) => {
            buf.push(0);
            buf.extend_from_slice(&v4.octets());
        }
        IpAddr::V6(v6) => {
            buf.push(1);
            buf.extend_from_slice(&v6.octets());
        }
    }
}
368
369fn decode_ip<B: Buf>(buf: &mut B) -> Option<IpAddr> {
370 match buf.get::<u8>().ok()? {
371 0 => buf.get().ok().map(IpAddr::V4),
372 1 => buf.get().ok().map(IpAddr::V6),
373 _ => None,
374 }
375}
376
377fn encode_unix_secs(buf: &mut Vec<u8>, time: SystemTime) {
378 buf.write::<u64>(
379 time.duration_since(UNIX_EPOCH)
380 .unwrap_or_default()
381 .as_secs(),
382 );
383}
384
385fn decode_unix_secs<B: Buf>(buf: &mut B) -> Option<SystemTime> {
386 Some(UNIX_EPOCH + Duration::from_secs(buf.get::<u64>().ok()?))
387}
388
389fn try_decode_v2_token(
394 token_bytes: &[u8],
395 _token_key: &dyn HandshakeTokenKey,
396) -> Option<crate::token_v2::RetryTokenDecoded> {
397 let fallback_key = TokenKey([0u8; 32]); decode_retry_token(&fallback_key, token_bytes)
403}
404
/// Fixed-size reset token: exactly `RESET_TOKEN_SIZE` bytes.
// `Hash` is derived while `PartialEq` is written by hand for this type, which is what the
// silenced clippy lint flags.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Copy, Clone, Hash)]
pub(crate) struct ResetToken([u8; RESET_TOKEN_SIZE]);
411
412impl ResetToken {
413 pub(crate) fn new(key: &dyn HmacKey, id: ConnectionId) -> Self {
414 let mut signature = vec![0; key.signature_len()];
415 key.sign(&id, &mut signature);
416 let mut result = [0; RESET_TOKEN_SIZE];
418 result.copy_from_slice(&signature[..RESET_TOKEN_SIZE]);
419 result.into()
420 }
421}
422
impl PartialEq for ResetToken {
    /// Compares the token bytes through `crate::constant_time::eq`.
    // NOTE(review): per the helper's module name this is presumably a constant-time
    // comparison, used so that timing does not reveal how many leading bytes match — confirm.
    fn eq(&self, other: &Self) -> bool {
        crate::constant_time::eq(&self.0, &other.0)
    }
}

impl Eq for ResetToken {}
430
impl From<[u8; RESET_TOKEN_SIZE]> for ResetToken {
    /// Wraps raw token bytes without any transformation.
    fn from(x: [u8; RESET_TOKEN_SIZE]) -> Self {
        Self(x)
    }
}
436
437impl std::ops::Deref for ResetToken {
438 type Target = [u8];
439 fn deref(&self) -> &[u8] {
440 &self.0
441 }
442}
443
444impl fmt::Display for ResetToken {
445 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
446 for byte in self.iter() {
447 write!(f, "{byte:02x}")?;
448 }
449 Ok(())
450 }
451}
452
#[cfg(all(test, any(feature = "aws-lc-rs", feature = "ring")))]
mod test {
    use super::*;
    #[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
    use aws_lc_rs::hkdf;
    use rand::prelude::*;
    #[cfg(feature = "ring")]
    use ring::hkdf;

    /// Builds an HKDF pseudo-random key from 64 freshly drawn random bytes.
    fn random_prk(rng: &mut impl RngCore) -> hkdf::Prk {
        let mut master_key = [0; 64];
        rng.fill_bytes(&mut master_key);
        hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key)
    }

    /// Encodes a token carrying `payload` under a random key, decodes it again, checks that
    /// the nonce survived, and returns the decoded payload.
    fn token_round_trip(payload: TokenPayload) -> TokenPayload {
        let rng = &mut rand::thread_rng();
        let token = Token::new(payload, rng);
        let prk = random_prk(rng);

        let wire = token.encode(&prk);
        let decoded = Token::decode(&prk, &wire).expect("token didn't decrypt / decode");
        assert_eq!(token.nonce, decoded.nonce);
        decoded.payload
    }

    /// A retry payload survives an encode/decode round trip unchanged.
    #[test]
    fn retry_token_sanity() {
        use crate::MAX_CID_SIZE;
        use crate::cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator};

        use std::net::Ipv6Addr;

        let address_1 = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 4433);
        let orig_dst_cid_1 = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid();
        let issued_1 = UNIX_EPOCH + Duration::from_secs(42);

        let sent = TokenPayload::Retry {
            address: address_1,
            orig_dst_cid: orig_dst_cid_1,
            issued: issued_1,
        };

        match token_round_trip(sent) {
            TokenPayload::Retry {
                address,
                orig_dst_cid,
                issued,
            } => {
                assert_eq!(address_1, address);
                assert_eq!(orig_dst_cid_1, orig_dst_cid);
                assert_eq!(issued_1, issued);
            }
            TokenPayload::Validation { .. } => panic!("token decoded as wrong variant"),
        }
    }

    /// A validation payload survives an encode/decode round trip unchanged.
    #[test]
    fn validation_token_sanity() {
        use std::net::Ipv6Addr;

        let ip_1 = IpAddr::V6(Ipv6Addr::LOCALHOST);
        let issued_1 = UNIX_EPOCH + Duration::from_secs(42);

        let sent = TokenPayload::Validation {
            ip: ip_1,
            issued: issued_1,
        };

        match token_round_trip(sent) {
            TokenPayload::Validation { ip, issued } => {
                assert_eq!(ip_1, ip);
                assert_eq!(issued_1, issued);
            }
            TokenPayload::Retry { .. } => panic!("token decoded as wrong variant"),
        }
    }

    /// 32 random bytes must not decode as a token under a random key.
    #[test]
    fn invalid_token_returns_err() {
        let rng = &mut rand::thread_rng();
        let prk = random_prk(rng);

        let mut invalid_token = [0u8; 32];
        rng.fill_bytes(&mut invalid_token);

        assert!(Token::decode(&prk, &invalid_token).is_none());
    }
}