Skip to main content

corevpn_protocol/
session.rs

1//! Protocol Session Management
2
3use std::time::{Duration, Instant};
4
5use bytes::Bytes;
6
7use corevpn_crypto::{CipherSuite, KeyMaterial};
8
9use crate::{
10    KeyId, OpCode, Packet, DataPacket, DataChannel,
11    ReliableTransport, ReliableConfig, TlsRecordReassembler,
12    ProtocolError, Result,
13};
14use crate::packet::ControlPacketData;
15
/// Lifecycle phase of a [`ProtocolSession`].
///
/// Every session starts in [`ProtocolState::Initial`]. Some transitions are made
/// while handling received control packets (hard reset, soft reset); the rest
/// are driven by the owner through `ProtocolSession::set_state`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProtocolState {
    /// Freshly created; no hard reset has been processed yet.
    Initial,
    /// A client hard reset was accepted; the TLS handshake is under way.
    TlsHandshake,
    /// Key exchange in progress.
    KeyExchange,
    /// Peer authentication in progress.
    Authenticating,
    /// Handshake finished; the session is fully established.
    Established,
    /// A soft reset was received and keys are being renegotiated.
    Rekeying,
    /// The session has been torn down.
    Terminated,
}
34
/// Raw 8-byte session identifier, as exchanged in control-packet headers.
pub type SessionIdBytes = [u8; 8];
37
/// Sliding-window replay filter for tls-auth control-packet IDs.
///
/// Tracks the highest packet ID accepted so far, plus a 64-bit bitmap recording
/// which of the 64 IDs ending at that one have already been seen.
struct ReplayWindow {
    /// Largest packet ID accepted so far (0 before any packet is accepted).
    highest: u32,
    /// Seen-bitmap anchored at `highest`: bit `k` is set once packet ID
    /// `highest - k` has been accepted.
    bitmap: u64,
}

impl ReplayWindow {
    /// Number of packet IDs tracked at and behind `highest` (one per bitmap bit).
    const WINDOW_SIZE: u32 = u64::BITS;

    /// Create an empty window that has not accepted any packet yet.
    fn new() -> Self {
        ReplayWindow { highest: 0, bitmap: 0 }
    }

    /// Decide whether `packet_id` is fresh, recording it if so.
    ///
    /// Returns `true` when the packet should be processed. Returns `false`
    /// in three cases:
    /// - ID 0, which is never issued because counters start at 1;
    /// - an ID that was already recorded;
    /// - an ID 64 or more below the highest one seen.
    fn check_and_update(&mut self, packet_id: u32) -> bool {
        if packet_id == 0 {
            return false;
        }

        if packet_id <= self.highest {
            // At or behind the leading edge: consult the bitmap.
            let age = self.highest - packet_id;
            if age >= Self::WINDOW_SIZE {
                return false;
            }
            let bit = 1u64 << age;
            let fresh = self.bitmap & bit == 0;
            if fresh {
                self.bitmap |= bit;
            }
            return fresh;
        }

        // New leading edge: slide the window forward. `checked_shl` yields `None`
        // once the jump reaches the full bitmap width. In that case every
        // previously recorded bit falls out and only the new ID stays marked.
        let advance = packet_id - self.highest;
        self.bitmap = self.bitmap.checked_shl(advance).unwrap_or(0) | 1;
        self.highest = packet_id;
        true
    }

    /// Forget every recorded ID, returning to the freshly constructed state.
    fn reset(&mut self) {
        *self = Self::new();
    }
}
109
110/// Protocol session
111pub struct ProtocolSession {
112    /// Local session ID
113    local_session_id: SessionIdBytes,
114    /// Remote session ID
115    remote_session_id: Option<SessionIdBytes>,
116    /// Current protocol state
117    state: ProtocolState,
118    /// Current key ID
119    current_key_id: KeyId,
120    /// Reliable transport for control channel
121    reliable: ReliableTransport,
122    /// TLS record reassembler
123    tls_reassembler: TlsRecordReassembler,
124    /// Data channels (one per key ID)
125    data_channels: [Option<DataChannel>; 8],
126    /// Peer ID (for P_DATA_V2)
127    peer_id: Option<u32>,
128    /// Use tls-auth
129    use_tls_auth: bool,
130    /// tls-auth key
131    tls_auth_key: Option<corevpn_crypto::HmacAuth>,
132    /// Replay window for tls-auth packet IDs
133    replay_window: ReplayWindow,
134    /// Session creation time
135    created_at: Instant,
136    /// Last activity time
137    last_activity: Instant,
138    /// Cipher suite to use
139    cipher_suite: CipherSuite,
140}
141
142impl ProtocolSession {
143    /// Create a new server-side session
144    pub fn new_server(cipher_suite: CipherSuite) -> Self {
145        Self {
146            local_session_id: corevpn_crypto::generate_session_id(),
147            remote_session_id: None,
148            state: ProtocolState::Initial,
149            current_key_id: KeyId::default(),
150            reliable: ReliableTransport::new(ReliableConfig::default()),
151            tls_reassembler: TlsRecordReassembler::new(65536),
152            data_channels: Default::default(),
153            peer_id: None,
154            use_tls_auth: false,
155            tls_auth_key: None,
156            replay_window: ReplayWindow::new(),
157            created_at: Instant::now(),
158            last_activity: Instant::now(),
159            cipher_suite,
160        }
161    }
162
163    /// Create a new client-side session
164    pub fn new_client(cipher_suite: CipherSuite) -> Self {
165        let mut session = Self::new_server(cipher_suite);
166        session.state = ProtocolState::Initial;
167        session
168    }
169
170    /// Get local session ID
171    pub fn local_session_id(&self) -> &SessionIdBytes {
172        &self.local_session_id
173    }
174
175    /// Get remote session ID
176    pub fn remote_session_id(&self) -> Option<&SessionIdBytes> {
177        self.remote_session_id.as_ref()
178    }
179
180    /// Get current state
181    pub fn state(&self) -> ProtocolState {
182        self.state
183    }
184
185    /// Set state
186    pub fn set_state(&mut self, state: ProtocolState) {
187        self.state = state;
188        self.last_activity = Instant::now();
189    }
190
191    /// Set remote session ID
192    pub fn set_remote_session_id(&mut self, id: SessionIdBytes) {
193        self.remote_session_id = Some(id);
194    }
195
196    /// Enable tls-auth
197    pub fn set_tls_auth(&mut self, key: corevpn_crypto::HmacAuth) {
198        self.use_tls_auth = true;
199        self.tls_auth_key = Some(key);
200    }
201
202    /// Process incoming packet
203    pub fn process_packet(&mut self, data: &[u8]) -> Result<ProcessedPacket> {
204        self.last_activity = Instant::now();
205
206        // Verify HMAC if tls-auth enabled
207        let data = if self.use_tls_auth {
208            if let Some(key) = &self.tls_auth_key {
209                // First byte is opcode, check if control
210                if !data.is_empty() && OpCode::from_byte(data[0])?.is_control() {
211                    key.unwrap(data)?
212                } else {
213                    data.to_vec()
214                }
215            } else {
216                data.to_vec()
217            }
218        } else {
219            data.to_vec()
220        };
221
222        let packet = Packet::parse(&data, self.use_tls_auth)?;
223
224        match packet {
225            Packet::Control(ctrl) => self.process_control_packet(ctrl),
226            Packet::Data(data_pkt) => self.process_data_packet(data_pkt),
227        }
228    }
229
230    fn process_control_packet(&mut self, ctrl: ControlPacketData) -> Result<ProcessedPacket> {
231        // Check replay protection for tls-auth packets
232        if self.use_tls_auth {
233            if let Some(packet_id) = ctrl.header.packet_id {
234                if !self.replay_window.check_and_update(packet_id) {
235                    return Err(ProtocolError::ReplayDetected);
236                }
237            }
238        }
239
240        // Process ACKs
241        if !ctrl.acks.is_empty() {
242            self.reliable.process_acks(&ctrl.acks);
243        }
244
245        // Handle different opcodes
246        match ctrl.header.opcode {
247            OpCode::HardResetClientV2 | OpCode::HardResetClientV3 => {
248                // Client initiating connection
249                // Security: Validate session ID - generate new one instead of accepting client's
250                // This prevents session fixation attacks
251                if let Some(remote_sid) = ctrl.header.session_id {
252                    // Validate session ID is not all zeros or obviously malicious
253                    if remote_sid == [0; 8] {
254                        return Err(ProtocolError::InvalidSessionId);
255                    }
256                    // Accept the session ID but we'll use our own for the response
257                    self.remote_session_id = Some(remote_sid);
258                }
259                self.state = ProtocolState::TlsHandshake;
260
261                Ok(ProcessedPacket::HardReset {
262                    session_id: self.local_session_id,
263                })
264            }
265            OpCode::HardResetServerV2 => {
266                // Server response to hard reset
267                if let Some(remote_sid) = ctrl.header.session_id {
268                    self.remote_session_id = Some(remote_sid);
269                }
270                Ok(ProcessedPacket::HardResetAck)
271            }
272            OpCode::ControlV1 => {
273                // TLS data
274                if let Some(packet_id) = ctrl.message_packet_id {
275                    if let Some(data) = self.reliable.receive(packet_id, ctrl.payload.clone())? {
276                        self.tls_reassembler.add(&data)?;
277                        let records = self.tls_reassembler.extract_records();
278                        if !records.is_empty() {
279                            return Ok(ProcessedPacket::TlsData(records));
280                        }
281                    }
282                }
283                Ok(ProcessedPacket::None)
284            }
285            OpCode::AckV1 => {
286                // Pure ACK, already processed above
287                Ok(ProcessedPacket::None)
288            }
289            OpCode::SoftResetV1 => {
290                // Key renegotiation
291                self.state = ProtocolState::Rekeying;
292                Ok(ProcessedPacket::SoftReset)
293            }
294            _ => Err(ProtocolError::UnknownOpcode(ctrl.header.opcode as u8)),
295        }
296    }
297
298    fn process_data_packet(&mut self, data_pkt: crate::packet::DataPacketData) -> Result<ProcessedPacket> {
299        let packet = DataPacket {
300            key_id: data_pkt.header.key_id,
301            peer_id: data_pkt.peer_id,
302            payload: data_pkt.payload,
303        };
304
305        let key_id = packet.key_id.0 as usize;
306        if let Some(channel) = &mut self.data_channels[key_id] {
307            let decrypted = channel.decrypt(&packet)?;
308            Ok(ProcessedPacket::Data(decrypted))
309        } else {
310            Err(ProtocolError::KeyNotAvailable(packet.key_id.0))
311        }
312    }
313
314    /// Create a hard reset response packet
315    pub fn create_hard_reset_response(&mut self) -> Result<Bytes> {
316        let packet = crate::packet::ControlPacketData {
317            header: crate::PacketHeader {
318                opcode: OpCode::HardResetServerV2,
319                key_id: KeyId::default(),
320                session_id: Some(self.local_session_id),
321                hmac: None,
322                packet_id: None,
323                timestamp: None,
324            },
325            remote_session_id: self.remote_session_id,
326            acks: self.reliable.get_acks(),
327            message_packet_id: None,
328            payload: Bytes::new(),
329        };
330
331        let serialized = Packet::Control(packet).serialize();
332        Ok(self.maybe_wrap_tls_auth(serialized.freeze()))
333    }
334
335    /// Create a control packet with TLS data
336    pub fn create_control_packet(&mut self, tls_data: Bytes) -> Result<Bytes> {
337        let (packet_id, _) = self.reliable.send(tls_data.clone())?;
338
339        let packet = crate::packet::ControlPacketData {
340            header: crate::PacketHeader {
341                opcode: OpCode::ControlV1,
342                key_id: self.current_key_id,
343                session_id: Some(self.local_session_id),
344                hmac: None,
345                packet_id: None,
346                timestamp: None,
347            },
348            remote_session_id: self.remote_session_id,
349            acks: self.reliable.get_acks(),
350            message_packet_id: Some(packet_id),
351            payload: tls_data,
352        };
353
354        let serialized = Packet::Control(packet).serialize();
355        Ok(self.maybe_wrap_tls_auth(serialized.freeze()))
356    }
357
358    /// Create an ACK packet
359    pub fn create_ack_packet(&mut self) -> Option<Bytes> {
360        let acks = self.reliable.get_acks();
361        if acks.is_empty() {
362            return None;
363        }
364
365        let packet = crate::packet::ControlPacketData {
366            header: crate::PacketHeader {
367                opcode: OpCode::AckV1,
368                key_id: self.current_key_id,
369                session_id: Some(self.local_session_id),
370                hmac: None,
371                packet_id: None,
372                timestamp: None,
373            },
374            remote_session_id: self.remote_session_id,
375            acks,
376            message_packet_id: None,
377            payload: Bytes::new(),
378        };
379
380        self.reliable.ack_sent();
381        let serialized = Packet::Control(packet).serialize();
382        Some(self.maybe_wrap_tls_auth(serialized.freeze()))
383    }
384
385    /// Install data channel keys
386    pub fn install_keys(&mut self, key_material: &KeyMaterial, is_server: bool) {
387        let key_id = self.current_key_id;
388        let idx = key_id.0 as usize;
389
390        let (encrypt_key, decrypt_key) = if is_server {
391            (
392                key_material.server_data_key(self.cipher_suite),
393                key_material.client_data_key(self.cipher_suite),
394            )
395        } else {
396            (
397                key_material.client_data_key(self.cipher_suite),
398                key_material.server_data_key(self.cipher_suite),
399            )
400        };
401
402        self.data_channels[idx] = Some(DataChannel::new(
403            key_id,
404            encrypt_key,
405            decrypt_key,
406            true,
407            self.peer_id,
408        ));
409    }
410
411    /// Encrypt data for transmission
412    pub fn encrypt_data(&mut self, data: &[u8]) -> Result<Bytes> {
413        let idx = self.current_key_id.0 as usize;
414        if let Some(channel) = &mut self.data_channels[idx] {
415            let packet = channel.encrypt(data)?;
416            Ok(packet.serialize().freeze())
417        } else {
418            Err(ProtocolError::KeyNotAvailable(self.current_key_id.0))
419        }
420    }
421
422    /// Get packets needing retransmission
423    pub fn get_retransmits(&mut self) -> Vec<Bytes> {
424        self.reliable
425            .get_retransmits()
426            .into_iter()
427            .map(|(id, data)| {
428                // Rebuild packet with same ID
429                let packet = crate::packet::ControlPacketData {
430                    header: crate::PacketHeader {
431                        opcode: OpCode::ControlV1,
432                        key_id: self.current_key_id,
433                        session_id: Some(self.local_session_id),
434                        hmac: None,
435                        packet_id: None,
436                        timestamp: None,
437                    },
438                    remote_session_id: self.remote_session_id,
439                    acks: vec![],
440                    message_packet_id: Some(id),
441                    payload: data,
442                };
443                let serialized = Packet::Control(packet).serialize();
444                self.maybe_wrap_tls_auth(serialized.freeze())
445            })
446            .collect()
447    }
448
449    /// Check if we should send an ACK
450    pub fn should_send_ack(&self) -> bool {
451        self.reliable.should_send_ack()
452    }
453
454    /// Get next timeout
455    pub fn next_timeout(&self) -> Option<Duration> {
456        self.reliable.next_timeout()
457    }
458
459    /// Check if session is established
460    pub fn is_established(&self) -> bool {
461        self.state == ProtocolState::Established
462    }
463
464    /// Get session duration
465    pub fn duration(&self) -> Duration {
466        self.created_at.elapsed()
467    }
468
469    /// Get idle time
470    pub fn idle_time(&self) -> Duration {
471        self.last_activity.elapsed()
472    }
473
474    fn maybe_wrap_tls_auth(&self, data: Bytes) -> Bytes {
475        if self.use_tls_auth {
476            if let Some(key) = &self.tls_auth_key {
477                return Bytes::from(key.wrap(&data));
478            }
479        }
480        data
481    }
482
483    /// Rotate to next key ID (for rekeying)
484    pub fn rotate_key(&mut self) {
485        self.current_key_id = self.current_key_id.next();
486        // Reset replay window on key rotation
487        self.replay_window.reset();
488    }
489}
490
491/// Result of processing a packet
492#[derive(Debug)]
493pub enum ProcessedPacket {
494    /// No action needed
495    None,
496    /// Hard reset from client
497    HardReset {
498        /// Session ID for the new connection
499        session_id: SessionIdBytes,
500    },
501    /// Hard reset acknowledged
502    HardResetAck,
503    /// TLS records to process
504    TlsData(Vec<Bytes>),
505    /// Decrypted data packet
506    Data(Bytes),
507    /// Soft reset (rekey)
508    SoftReset,
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode a client hard reset: opcode 7 (HardResetClientV2), key_id 0,
    /// `session_id`, and an empty ACK list.
    fn client_hard_reset(session_id: SessionIdBytes) -> Vec<u8> {
        let mut packet = vec![0x38]; // opcode 7 << 3 | key_id 0
        packet.extend_from_slice(&session_id);
        packet.push(0x00); // ack_count = 0
        packet
    }

    #[test]
    fn test_session_creation() {
        let session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);
        assert_eq!(session.state(), ProtocolState::Initial);
        assert!(session.remote_session_id().is_none());
    }

    #[test]
    fn test_hard_reset() {
        let mut session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);
        let client_sid = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];

        let result = session.process_packet(&client_hard_reset(client_sid)).unwrap();

        // A bare `matches!(..)` only evaluates to a bool; the variant must be asserted.
        match result {
            ProcessedPacket::HardReset { session_id } => {
                assert_eq!(&session_id, session.local_session_id());
            }
            other => panic!("expected HardReset, got {other:?}"),
        }
        assert_eq!(session.state(), ProtocolState::TlsHandshake);
        assert_eq!(session.remote_session_id(), Some(&client_sid));
    }

    #[test]
    fn test_hard_reset_rejects_zero_session_id() {
        let mut session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);

        assert!(session.process_packet(&client_hard_reset([0; 8])).is_err());
        assert_eq!(session.state(), ProtocolState::Initial);
        assert!(session.remote_session_id().is_none());
    }

    #[test]
    fn test_replay_window() {
        let mut window = ReplayWindow::new();
        assert!(!window.check_and_update(0)); // ID 0 is never valid
        assert!(window.check_and_update(1));
        assert!(!window.check_and_update(1)); // replay
        assert!(window.check_and_update(3));
        assert!(window.check_and_update(2)); // out of order but inside the window
        assert!(!window.check_and_update(2));
        assert!(window.check_and_update(100));
        assert!(!window.check_and_update(36)); // 64 behind: too old
        assert!(window.check_and_update(37)); // 63 behind: still tracked
        assert!(!window.check_and_update(37));

        window.reset();
        assert!(window.check_and_update(1));
    }
}