1use rand::RngCore;
2
3use crate::crypto::{
4 self, compute_time_window, current_timestamp_ms, decrypt_payload, derive_session_keys,
5 encrypt_payload, generate_resonance_tag, KeyPair, SessionKeys, DEFAULT_WINDOW_MS, NONCE_SIZE,
6 TAG_SIZE,
7};
8use crate::error::{Error, Result};
9use crate::protocol::{ControlPayload, InnerHeader, InnerType};
10
/// Default MDH length in bytes.
///
/// NOTE(review): not referenced in this file; presumably callers pass it as the
/// `mdh_len` argument of `build_random_mdh_packet` / `decode_packet_with_mdh_len`
/// — confirm.
pub const DEFAULT_MDH_LEN: usize = 20;

/// A 4-byte all-zero MDH. Within this file only its length is read (by
/// `build_zero_mdh_packet`); its contents are never used.
pub const DEFAULT_ZERO_MDH: [u8; 4] = [0; 4];
16
/// Result of successfully decoding one received packet with
/// `decode_packet_with_mdh_len`: the packet was decrypted, its padding removed,
/// and its inner header parsed.
pub struct DecodedPacket {
    /// Counter recovered from the packet's resonance tag; the decryption nonce
    /// is derived from it via `counter_to_nonce`.
    pub counter: u64,
    /// Header decoded from the start of the inner packet (`build_inner_packet`
    /// writes it as a little-endian `u16` type followed by a `u16` sequence number).
    pub header: InnerHeader,
    /// Inner packet bytes that follow the 4-byte header.
    pub payload: Vec<u8>,
}
22
/// Number of counters, at and below the highest one received, that the replay
/// bitmap remembers. `Bitset256` stores only 256 bits, so this must not exceed 256.
const RECV_REORDER_WINDOW: usize = 256;
/// How many counters past the highest one received `RecvWindow::find_counter`
/// is willing to try.
const RECV_FUTURE_SEARCH_WINDOW: usize = 4_096;
25
/// Fixed-width 256-bit set split into two halves: bits 0..=127 live in `lo`,
/// bits 128..=255 in `hi`. Indices at or above 256 are ignored on write and
/// read back as unset.
#[derive(Clone, Copy)]
struct Bitset256 {
    lo: u128,
    hi: u128,
}

impl Bitset256 {
    /// Returns a set with every bit cleared.
    fn new() -> Self {
        Bitset256 { lo: 0, hi: 0 }
    }

    /// Clears every bit in place.
    fn clear(&mut self) {
        *self = Self::new();
    }

    /// Returns a copy with every bit moved `shift` positions towards higher
    /// indices; bits pushed past index 255 are dropped.
    fn shl(self, shift: u64) -> Self {
        match shift {
            0 => self,
            1..=127 => Self {
                lo: self.lo << shift,
                // Carry the top `shift` bits of `lo` into the bottom of `hi`.
                hi: (self.hi << shift) | (self.lo >> (128 - shift)),
            },
            128..=255 => Self {
                lo: 0,
                hi: self.lo << (shift - 128),
            },
            _ => Self::new(),
        }
    }

    /// Sets bit `bit`; out-of-range indices are silently ignored.
    fn set_bit(&mut self, bit: usize) {
        match bit {
            0..=127 => self.lo |= 1u128 << bit,
            128..=255 => self.hi |= 1u128 << (bit - 128),
            _ => {}
        }
    }

    /// Reports whether bit `bit` is set; out-of-range indices report `false`.
    fn get_bit(&self, bit: usize) -> bool {
        match bit {
            0..=127 => (self.lo & (1u128 << bit)) != 0,
            128..=255 => (self.hi & (1u128 << (bit - 128))) != 0,
            _ => false,
        }
    }
}
80
81pub struct RecvWindow {
82 highest: i64,
83 bitmap: Bitset256,
84}
85
86impl Default for RecvWindow {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92impl RecvWindow {
93 pub fn new() -> Self {
94 Self {
95 highest: -1,
96 bitmap: Bitset256::new(),
97 }
98 }
99
100 pub fn reset(&mut self) {
101 self.highest = -1;
102 self.bitmap.clear();
103 }
104
105 pub fn find_counter(&self, tag: &[u8; TAG_SIZE], keys: &SessionKeys) -> Option<u64> {
106 let base_tw = compute_time_window(current_timestamp_ms(), DEFAULT_WINDOW_MS);
107 let start = if self.highest < 0 {
108 0
109 } else {
110 (self.highest as u64).saturating_sub((RECV_REORDER_WINDOW - 1) as u64)
111 };
112 let end = if self.highest < 0 {
113 RECV_FUTURE_SEARCH_WINDOW as u64
114 } else {
115 std::cmp::max(
116 RECV_FUTURE_SEARCH_WINDOW as u64,
117 self.highest as u64 + RECV_FUTURE_SEARCH_WINDOW as u64 + 1,
118 )
119 };
120
121 for tw_offset in [0i64, -1, 1] {
122 let tw = (base_tw as i64 + tw_offset) as u64;
123 for counter in start..end {
124 if !self.is_new(counter) {
125 continue;
126 }
127 let expected = generate_resonance_tag(&keys.tag_secret, counter, tw);
128 if &expected == tag {
129 return Some(counter);
130 }
131 }
132 }
133
134 None
135 }
136
137 pub fn mark(&mut self, counter: u64) {
138 if self.highest < 0 || counter > self.highest as u64 {
139 let shift = if self.highest < 0 {
140 RECV_REORDER_WINDOW as u64
141 } else {
142 counter - self.highest as u64
143 };
144 self.bitmap = if shift >= RECV_REORDER_WINDOW as u64 {
145 let mut bitmap = Bitset256::new();
146 bitmap.set_bit(0);
147 bitmap
148 } else {
149 let mut bitmap = self.bitmap.shl(shift);
150 bitmap.set_bit(0);
151 bitmap
152 };
153 self.highest = counter as i64;
154 } else {
155 let diff = (self.highest as u64 - counter) as usize;
156 if diff < RECV_REORDER_WINDOW {
157 self.bitmap.set_bit(diff);
158 }
159 }
160 }
161
162 fn is_new(&self, counter: u64) -> bool {
163 if self.highest < 0 {
164 return true;
165 }
166
167 let highest = self.highest as u64;
168 if counter > highest {
169 return true;
170 }
171
172 let diff = highest - counter;
173 if diff >= RECV_REORDER_WINDOW as u64 {
174 return false;
175 }
176
177 !self.bitmap.get_bit(diff as usize)
178 }
179}
180
181pub fn build_inner_packet(inner_type: InnerType, seq_num: u16, payload: &[u8]) -> Vec<u8> {
182 let mut inner = Vec::with_capacity(4 + payload.len());
183 inner.extend_from_slice(&(inner_type as u16).to_le_bytes());
184 inner.extend_from_slice(&seq_num.to_le_bytes());
185 inner.extend_from_slice(payload);
186 inner
187}
188
189pub fn build_random_mdh_packet(
192 keys: &SessionKeys,
193 counter: &mut u64,
194 inner: &[u8],
195 obfuscated_eph_pub: Option<&[u8; 32]>,
196 mdh_len: usize,
197) -> Result<Vec<u8>> {
198 let pad_len: u16 = 8 + rand::thread_rng().next_u32() as u16 % 16;
199 let mut plaintext = Vec::with_capacity(2 + inner.len() + pad_len as usize);
200 plaintext.extend_from_slice(&pad_len.to_le_bytes());
201 plaintext.extend_from_slice(inner);
202 plaintext.resize(2 + inner.len() + pad_len as usize, 0);
203 rand::thread_rng().fill_bytes(&mut plaintext[2 + inner.len()..]);
204
205 let current_counter = *counter;
206 *counter += 1;
207
208 let nonce = counter_to_nonce(current_counter);
209 let ciphertext = encrypt_payload(&keys.session_key, &nonce, &plaintext)?;
210 let time_window = compute_time_window(current_timestamp_ms(), DEFAULT_WINDOW_MS);
211 let tag = generate_resonance_tag(&keys.tag_secret, current_counter, time_window);
212
213 let mut mdh = vec![0u8; mdh_len];
215 rand::thread_rng().fill_bytes(&mut mdh);
216
217 let eph_len = if obfuscated_eph_pub.is_some() { 32 } else { 0 };
218 let mut packet = Vec::with_capacity(TAG_SIZE + mdh_len + eph_len + ciphertext.len());
219 packet.extend_from_slice(&tag);
220 packet.extend_from_slice(&mdh);
221 if let Some(eph) = obfuscated_eph_pub {
222 packet.extend_from_slice(eph);
223 }
224 packet.extend_from_slice(&ciphertext);
225
226 Ok(packet)
227}
228
229pub fn build_zero_mdh_packet(
231 keys: &SessionKeys,
232 counter: &mut u64,
233 inner: &[u8],
234 obfuscated_eph_pub: Option<&[u8; 32]>,
235) -> Result<Vec<u8>> {
236 build_random_mdh_packet(
237 keys,
238 counter,
239 inner,
240 obfuscated_eph_pub,
241 DEFAULT_ZERO_MDH.len(),
242 )
243}
244
245pub fn decode_packet_with_mdh_len(
246 packet: &[u8],
247 keys: &SessionKeys,
248 recv_window: &mut RecvWindow,
249 mdh_len: usize,
250) -> Result<DecodedPacket> {
251 if packet.len() < TAG_SIZE + mdh_len + 16 {
252 return Err(Error::InvalidPacket("Packet too short"));
253 }
254
255 let tag: [u8; TAG_SIZE] = packet[..TAG_SIZE]
256 .try_into()
257 .map_err(|_| Error::InvalidPacket("Packet tag malformed"))?;
258 let counter = recv_window
259 .find_counter(&tag, keys)
260 .ok_or(Error::InvalidPacket("Invalid resonance tag"))?;
261
262 let nonce = counter_to_nonce(counter);
263 let ciphertext = &packet[TAG_SIZE + mdh_len..];
264 let padded = decrypt_payload(&keys.session_key, &nonce, ciphertext)?;
265 recv_window.mark(counter);
266
267 if padded.len() < 2 {
268 return Err(Error::InvalidPacket("Decrypted payload too short"));
269 }
270
271 let pad_len = u16::from_le_bytes([padded[0], padded[1]]) as usize;
272 let end = padded
273 .len()
274 .checked_sub(pad_len)
275 .ok_or(Error::InvalidPacket("Invalid padding length"))?;
276 if end < 2 {
277 return Err(Error::InvalidPacket("Invalid padding length"));
278 }
279
280 let inner = &padded[2..end];
281 if inner.len() < 4 {
282 return Err(Error::InvalidPacket("Inner payload too short"));
283 }
284
285 let header = InnerHeader::decode(inner)?;
286 let payload = inner[4..].to_vec();
287
288 Ok(DecodedPacket {
289 counter,
290 header,
291 payload,
292 })
293}
294
295pub fn process_server_hello_with_mdh_len(
296 packet: &[u8],
297 keys: &mut SessionKeys,
298 keypair: &KeyPair,
299 recv_window: &mut RecvWindow,
300 send_counter: &mut u64,
301 mdh_len: usize,
302) -> Result<()> {
303 let decoded = decode_packet_with_mdh_len(packet, keys, recv_window, mdh_len)?;
304
305 if decoded.header.inner_type != InnerType::Control {
306 return Err(Error::InvalidPacket(
307 "Expected control packet for ServerHello",
308 ));
309 }
310
311 match ControlPayload::decode(&decoded.payload)? {
312 ControlPayload::ServerHello { server_eph_pub, .. } => {
313 let dh2 = keypair.compute_shared(&server_eph_pub)?;
314 let old_session_key = keys.session_key;
315 *keys = derive_session_keys(&dh2, Some(&old_session_key), &keypair.public_key_bytes());
316 *send_counter = 0;
317 recv_window.reset();
318 Ok(())
319 }
320 _ => Err(Error::InvalidPacket("Expected ServerHello control payload")),
321 }
322}
323
324pub fn obfuscate_client_eph_pub(keypair: &KeyPair, server_public_key: &[u8; 32]) -> [u8; 32] {
325 let mut obfuscated = keypair.public_key_bytes();
326 crypto::obfuscate_eph_pub(&mut obfuscated, server_public_key);
327 obfuscated
328}
329
330pub fn counter_to_nonce(counter: u64) -> [u8; NONCE_SIZE] {
331 let mut nonce = [0u8; NONCE_SIZE];
332 nonce[..8].copy_from_slice(&counter.to_le_bytes());
333 nonce
334}