1#![allow(missing_docs)]
15
16use rand::RngCore;
20
21use crate::{nat_traversal_api::PeerId, shared::ConnectionId};
22
23use aws_lc_rs::aead::{AES_256_GCM, Aad, LessSafeKey, NONCE_LEN, Nonce, UnboundKey};
24
/// 256-bit symmetric key used to seal and open retry tokens.
///
/// The raw bytes (`.0`) are passed directly to AES-256-GCM by the sealing and
/// opening helpers in this module. `Debug` is intentionally not derived here;
/// NOTE(review): presumably to avoid printing key material — confirm.
#[derive(Clone)]
pub struct TokenKey(pub [u8; 32]);
29
30#[derive(Debug, Clone, PartialEq, Eq)]
33pub struct RetryTokenDecoded {
34 pub peer_id: PeerId,
36 pub cid: ConnectionId,
38 pub nonce: u128,
40}
41
42pub fn test_key_from_rng(rng: &mut dyn RngCore) -> TokenKey {
45 let mut k = [0u8; 32];
46 rng.fill_bytes(&mut k);
47 TokenKey(k)
48}
49
50pub fn encode_retry_token_with_rng<R: RngCore>(
54 key: &TokenKey,
55 peer_id: &PeerId,
56 cid: &ConnectionId,
57 rng: &mut R,
58) -> Vec<u8> {
59 let mut nonce_bytes = [0u8; 12]; rng.fill_bytes(&mut nonce_bytes);
61
62 let mut pt = Vec::with_capacity(32 + 1 + crate::MAX_CID_SIZE + 12);
63 pt.extend_from_slice(&peer_id.0);
64 pt.push(cid.len() as u8);
65 pt.extend_from_slice(&cid[..]);
66 pt.extend_from_slice(&nonce_bytes); seal(&key.0, &nonce_bytes, &pt)
68}
69
70pub fn encode_retry_token(key: &TokenKey, peer_id: &PeerId, cid: &ConnectionId) -> Vec<u8> {
71 encode_retry_token_with_rng(key, peer_id, cid, &mut rand::thread_rng())
72}
73
74pub fn decode_retry_token(key: &TokenKey, token: &[u8]) -> Option<RetryTokenDecoded> {
78 let (ct, nonce_suffix) = token.split_at(token.len().checked_sub(12)?);
80 let mut nonce12 = [0u8; 12];
81 nonce12.copy_from_slice(nonce_suffix);
82 let plaintext = open(&key.0, &nonce12, ct).ok()?;
83 if plaintext.len() < 32 + 1 + 12 {
84 return None;
85 } let mut off = 0usize;
87 let mut pid = [0u8; 32];
88 pid.copy_from_slice(&plaintext[off..off + 32]);
89 off += 32;
90 let cid_len = plaintext[off] as usize;
91 off += 1;
92 if plaintext.len() < off + cid_len + 12 {
93 return None;
94 }
95 let mut cid_buf = [0u8; crate::MAX_CID_SIZE];
96 cid_buf[..cid_len].copy_from_slice(&plaintext[off..off + cid_len]);
97 let cid = ConnectionId::new(&cid_buf[..cid_len]);
98 off += cid_len;
99 let mut nonce_arr = [0u8; 12];
100 nonce_arr.copy_from_slice(&plaintext[off..off + 12]);
101 let mut nonce_bytes_16 = [0u8; 16];
102 nonce_bytes_16[..12].copy_from_slice(&nonce_arr);
103 let nonce = u128::from_le_bytes(nonce_bytes_16); Some(RetryTokenDecoded {
105 peer_id: PeerId(pid),
106 cid,
107 nonce,
108 })
109}
110
111pub fn validate_token(
114 key: &TokenKey,
115 token: &[u8],
116 expected_peer: &PeerId,
117 expected_cid: &ConnectionId,
118) -> bool {
119 match decode_retry_token(key, token) {
120 Some(dec) => dec.peer_id == *expected_peer && dec.cid == *expected_cid,
121 None => false,
122 }
123}
124
125#[allow(clippy::expect_used, clippy::let_unit_value)]
128fn seal(key: &[u8; 32], nonce: &[u8; 12], pt: &[u8]) -> Vec<u8> {
129 let unbound_key = UnboundKey::new(&AES_256_GCM, key).expect("invalid key length");
130 let key = LessSafeKey::new(unbound_key);
131
132 let nonce_bytes = *nonce;
134
135 let nonce = Nonce::try_assume_unique_for_key(&nonce_bytes).expect("invalid nonce length");
137
138 let mut in_out = pt.to_vec();
139 key.seal_in_place_append_tag(nonce, Aad::empty(), &mut in_out)
140 .expect("encryption failed");
141
142 in_out.extend_from_slice(&nonce_bytes);
144 in_out
145}
146
147fn open(key: &[u8; 32], nonce12: &[u8; 12], ct_without_suffix: &[u8]) -> Result<Vec<u8>, ()> {
154 let unbound_key = UnboundKey::new(&AES_256_GCM, key).map_err(|_| ())?;
155 let key = LessSafeKey::new(unbound_key);
156
157 let nonce = Nonce::try_assume_unique_for_key(nonce12).map_err(|_| ())?;
160
161 let mut in_out = ct_without_suffix.to_vec();
162 key.open_in_place(nonce, Aad::empty(), &mut in_out)
163 .map_err(|_| ())?;
164
165 in_out.truncate(in_out.len() - 16);
167 Ok(in_out)
168}