Skip to main content

aivpn_common/
client_wire.rs

1use rand::RngCore;
2
3use crate::crypto::{
4    self, compute_time_window, current_timestamp_ms, decrypt_payload, derive_session_keys,
5    encrypt_payload, generate_resonance_tag, KeyPair, SessionKeys, DEFAULT_WINDOW_MS, NONCE_SIZE,
6    TAG_SIZE,
7};
8use crate::error::{Error, Result};
9use crate::protocol::{ControlPayload, InnerHeader, InnerType};
10
/// Default MDH length, sized to match the primary mask (STUN/WebRTC header = 20 bytes).
pub const DEFAULT_MDH_LEN: usize = 20;

/// Four-byte all-zero MDH from the legacy wire format.
///
/// Kept so existing references keep compiling; within this module only its length is used.
pub const DEFAULT_ZERO_MDH: [u8; 4] = [0; 4];
16
17pub struct DecodedPacket {
18    pub counter: u64,
19    pub header: InnerHeader,
20    pub payload: Vec<u8>,
21}
22
/// Number of counters (the highest accepted one and those just below it) tracked for replay
/// detection. Equal to the capacity of the two-`u128` `Bitset256`, i.e. 256.
const RECV_REORDER_WINDOW: usize = 2 * u128::BITS as usize;
/// How many counters past the highest accepted one `RecvWindow::find_counter` will try.
const RECV_FUTURE_SEARCH_WINDOW: usize = 4096;
25
/// Fixed-size 256-bit bitmap stored as two 128-bit words:
/// `lo` holds bits 0..128 and `hi` holds bits 128..256.
#[derive(Clone, Copy)]
struct Bitset256 {
    lo: u128,
    hi: u128,
}

impl Bitset256 {
    /// Returns a bitmap with every bit cleared.
    fn new() -> Self {
        Bitset256 { lo: 0, hi: 0 }
    }

    /// Clears every bit in place.
    fn clear(&mut self) {
        *self = Self::new();
    }

    /// Moves every bit `shift` positions towards higher indices.
    /// Bits pushed past index 255 are discarded; a shift of 256 or more yields an empty set.
    fn shl(self, shift: u64) -> Self {
        match shift {
            0 => self,
            // Bits leaving the top of `lo` carry into the bottom of `hi`.
            1..=127 => Bitset256 {
                lo: self.lo << shift,
                hi: (self.hi << shift) | (self.lo >> (128 - shift)),
            },
            // Everything in `hi` is shifted out; `lo` lands entirely in `hi`.
            128..=255 => Bitset256 {
                lo: 0,
                hi: self.lo << (shift - 128),
            },
            _ => Self::new(),
        }
    }

    /// Sets bit `bit`; indices of 256 and above are silently ignored.
    fn set_bit(&mut self, bit: usize) {
        match bit {
            0..=127 => self.lo |= 1u128 << bit,
            128..=255 => self.hi |= 1u128 << (bit - 128),
            _ => {}
        }
    }

    /// Reports whether bit `bit` is set; indices of 256 and above always read as unset.
    fn get_bit(&self, bit: usize) -> bool {
        let (word, offset) = match bit {
            0..=127 => (self.lo, bit),
            128..=255 => (self.hi, bit - 128),
            _ => return false,
        };
        (word & (1u128 << offset)) != 0
    }
}
80
81pub struct RecvWindow {
82    highest: i64,
83    bitmap: Bitset256,
84}
85
86impl Default for RecvWindow {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92impl RecvWindow {
93    pub fn new() -> Self {
94        Self {
95            highest: -1,
96            bitmap: Bitset256::new(),
97        }
98    }
99
100    pub fn reset(&mut self) {
101        self.highest = -1;
102        self.bitmap.clear();
103    }
104
105    pub fn find_counter(&self, tag: &[u8; TAG_SIZE], keys: &SessionKeys) -> Option<u64> {
106        let base_tw = compute_time_window(current_timestamp_ms(), DEFAULT_WINDOW_MS);
107        let start = if self.highest < 0 {
108            0
109        } else {
110            (self.highest as u64).saturating_sub((RECV_REORDER_WINDOW - 1) as u64)
111        };
112        let end = if self.highest < 0 {
113            RECV_FUTURE_SEARCH_WINDOW as u64
114        } else {
115            std::cmp::max(
116                RECV_FUTURE_SEARCH_WINDOW as u64,
117                self.highest as u64 + RECV_FUTURE_SEARCH_WINDOW as u64 + 1,
118            )
119        };
120
121        for tw_offset in [0i64, -1, 1] {
122            let tw = (base_tw as i64 + tw_offset) as u64;
123            for counter in start..end {
124                if !self.is_new(counter) {
125                    continue;
126                }
127                let expected = generate_resonance_tag(&keys.tag_secret, counter, tw);
128                if &expected == tag {
129                    return Some(counter);
130                }
131            }
132        }
133
134        None
135    }
136
137    pub fn mark(&mut self, counter: u64) {
138        if self.highest < 0 || counter > self.highest as u64 {
139            let shift = if self.highest < 0 {
140                RECV_REORDER_WINDOW as u64
141            } else {
142                counter - self.highest as u64
143            };
144            self.bitmap = if shift >= RECV_REORDER_WINDOW as u64 {
145                let mut bitmap = Bitset256::new();
146                bitmap.set_bit(0);
147                bitmap
148            } else {
149                let mut bitmap = self.bitmap.shl(shift);
150                bitmap.set_bit(0);
151                bitmap
152            };
153            self.highest = counter as i64;
154        } else {
155            let diff = (self.highest as u64 - counter) as usize;
156            if diff < RECV_REORDER_WINDOW {
157                self.bitmap.set_bit(diff);
158            }
159        }
160    }
161
162    fn is_new(&self, counter: u64) -> bool {
163        if self.highest < 0 {
164            return true;
165        }
166
167        let highest = self.highest as u64;
168        if counter > highest {
169            return true;
170        }
171
172        let diff = highest - counter;
173        if diff >= RECV_REORDER_WINDOW as u64 {
174            return false;
175        }
176
177        !self.bitmap.get_bit(diff as usize)
178    }
179}
180
181pub fn build_inner_packet(inner_type: InnerType, seq_num: u16, payload: &[u8]) -> Vec<u8> {
182    let mut inner = Vec::with_capacity(4 + payload.len());
183    inner.extend_from_slice(&(inner_type as u16).to_le_bytes());
184    inner.extend_from_slice(&seq_num.to_le_bytes());
185    inner.extend_from_slice(payload);
186    inner
187}
188
189/// Build a packet with random MDH of given length (Issue #30 fix).
190/// Each call generates fresh random MDH bytes, eliminating static fingerprints.
191pub fn build_random_mdh_packet(
192    keys: &SessionKeys,
193    counter: &mut u64,
194    inner: &[u8],
195    obfuscated_eph_pub: Option<&[u8; 32]>,
196    mdh_len: usize,
197) -> Result<Vec<u8>> {
198    let pad_len: u16 = 8 + rand::thread_rng().next_u32() as u16 % 16;
199    let mut plaintext = Vec::with_capacity(2 + inner.len() + pad_len as usize);
200    plaintext.extend_from_slice(&pad_len.to_le_bytes());
201    plaintext.extend_from_slice(inner);
202    plaintext.resize(2 + inner.len() + pad_len as usize, 0);
203    rand::thread_rng().fill_bytes(&mut plaintext[2 + inner.len()..]);
204
205    let current_counter = *counter;
206    *counter += 1;
207
208    let nonce = counter_to_nonce(current_counter);
209    let ciphertext = encrypt_payload(&keys.session_key, &nonce, &plaintext)?;
210    let time_window = compute_time_window(current_timestamp_ms(), DEFAULT_WINDOW_MS);
211    let tag = generate_resonance_tag(&keys.tag_secret, current_counter, time_window);
212
213    // Generate random MDH bytes — no static fingerprint
214    let mut mdh = vec![0u8; mdh_len];
215    rand::thread_rng().fill_bytes(&mut mdh);
216
217    let eph_len = if obfuscated_eph_pub.is_some() { 32 } else { 0 };
218    let mut packet = Vec::with_capacity(TAG_SIZE + mdh_len + eph_len + ciphertext.len());
219    packet.extend_from_slice(&tag);
220    packet.extend_from_slice(&mdh);
221    if let Some(eph) = obfuscated_eph_pub {
222        packet.extend_from_slice(eph);
223    }
224    packet.extend_from_slice(&ciphertext);
225
226    Ok(packet)
227}
228
229/// Legacy: build packet with 4-byte zero MDH (kept for backward compatibility).
230pub fn build_zero_mdh_packet(
231    keys: &SessionKeys,
232    counter: &mut u64,
233    inner: &[u8],
234    obfuscated_eph_pub: Option<&[u8; 32]>,
235) -> Result<Vec<u8>> {
236    build_random_mdh_packet(
237        keys,
238        counter,
239        inner,
240        obfuscated_eph_pub,
241        DEFAULT_ZERO_MDH.len(),
242    )
243}
244
245pub fn decode_packet_with_mdh_len(
246    packet: &[u8],
247    keys: &SessionKeys,
248    recv_window: &mut RecvWindow,
249    mdh_len: usize,
250) -> Result<DecodedPacket> {
251    if packet.len() < TAG_SIZE + mdh_len + 16 {
252        return Err(Error::InvalidPacket("Packet too short"));
253    }
254
255    let tag: [u8; TAG_SIZE] = packet[..TAG_SIZE]
256        .try_into()
257        .map_err(|_| Error::InvalidPacket("Packet tag malformed"))?;
258    let counter = recv_window
259        .find_counter(&tag, keys)
260        .ok_or(Error::InvalidPacket("Invalid resonance tag"))?;
261
262    let nonce = counter_to_nonce(counter);
263    let ciphertext = &packet[TAG_SIZE + mdh_len..];
264    let padded = decrypt_payload(&keys.session_key, &nonce, ciphertext)?;
265    recv_window.mark(counter);
266
267    if padded.len() < 2 {
268        return Err(Error::InvalidPacket("Decrypted payload too short"));
269    }
270
271    let pad_len = u16::from_le_bytes([padded[0], padded[1]]) as usize;
272    let end = padded
273        .len()
274        .checked_sub(pad_len)
275        .ok_or(Error::InvalidPacket("Invalid padding length"))?;
276    if end < 2 {
277        return Err(Error::InvalidPacket("Invalid padding length"));
278    }
279
280    let inner = &padded[2..end];
281    if inner.len() < 4 {
282        return Err(Error::InvalidPacket("Inner payload too short"));
283    }
284
285    let header = InnerHeader::decode(inner)?;
286    let payload = inner[4..].to_vec();
287
288    Ok(DecodedPacket {
289        counter,
290        header,
291        payload,
292    })
293}
294
295pub fn process_server_hello_with_mdh_len(
296    packet: &[u8],
297    keys: &mut SessionKeys,
298    keypair: &KeyPair,
299    recv_window: &mut RecvWindow,
300    send_counter: &mut u64,
301    mdh_len: usize,
302) -> Result<()> {
303    let decoded = decode_packet_with_mdh_len(packet, keys, recv_window, mdh_len)?;
304
305    if decoded.header.inner_type != InnerType::Control {
306        return Err(Error::InvalidPacket(
307            "Expected control packet for ServerHello",
308        ));
309    }
310
311    match ControlPayload::decode(&decoded.payload)? {
312        ControlPayload::ServerHello { server_eph_pub, .. } => {
313            let dh2 = keypair.compute_shared(&server_eph_pub)?;
314            let old_session_key = keys.session_key;
315            *keys = derive_session_keys(&dh2, Some(&old_session_key), &keypair.public_key_bytes());
316            *send_counter = 0;
317            recv_window.reset();
318            Ok(())
319        }
320        _ => Err(Error::InvalidPacket("Expected ServerHello control payload")),
321    }
322}
323
324pub fn obfuscate_client_eph_pub(keypair: &KeyPair, server_public_key: &[u8; 32]) -> [u8; 32] {
325    let mut obfuscated = keypair.public_key_bytes();
326    crypto::obfuscate_eph_pub(&mut obfuscated, server_public_key);
327    obfuscated
328}
329
330pub fn counter_to_nonce(counter: u64) -> [u8; NONCE_SIZE] {
331    let mut nonce = [0u8; NONCE_SIZE];
332    nonce[..8].copy_from_slice(&counter.to_le_bytes());
333    nonce
334}