corevpn_protocol/
session.rs

1//! Protocol Session Management
2
3use std::time::{Duration, Instant};
4
5use bytes::Bytes;
6
7use corevpn_crypto::{CipherSuite, KeyMaterial};
8
9use crate::{
10    KeyId, OpCode, Packet, DataPacket, DataChannel,
11    ReliableTransport, ReliableConfig, TlsRecordReassembler,
12    ProtocolError, Result,
13};
14use crate::packet::ControlPacketData;
15
/// Lifecycle phase of a [`ProtocolSession`].
///
/// A session normally moves `Initial` -> `TlsHandshake` -> ... -> `Established`,
/// may enter `Rekeying` on renegotiation, and ends in `Terminated`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProtocolState {
    /// Nothing exchanged yet; waiting for the client's hello.
    Initial,
    /// The TLS handshake is underway.
    TlsHandshake,
    /// Keys are being exchanged.
    KeyExchange,
    /// The peer is being authenticated.
    Authenticating,
    /// Handshake complete; data may flow.
    Established,
    /// Keys are being renegotiated.
    Rekeying,
    /// The session has ended.
    Terminated,
}
34
/// Number of bytes in a protocol session ID.
const SESSION_ID_LEN: usize = 8;

/// Raw session identifier, exactly [`SESSION_ID_LEN`] (8) bytes long.
pub type SessionIdBytes = [u8; SESSION_ID_LEN];
37
38/// Protocol session
39pub struct ProtocolSession {
40    /// Local session ID
41    local_session_id: SessionIdBytes,
42    /// Remote session ID
43    remote_session_id: Option<SessionIdBytes>,
44    /// Current protocol state
45    state: ProtocolState,
46    /// Current key ID
47    current_key_id: KeyId,
48    /// Reliable transport for control channel
49    reliable: ReliableTransport,
50    /// TLS record reassembler
51    tls_reassembler: TlsRecordReassembler,
52    /// Data channels (one per key ID)
53    data_channels: [Option<DataChannel>; 8],
54    /// Peer ID (for P_DATA_V2)
55    peer_id: Option<u32>,
56    /// Use tls-auth
57    use_tls_auth: bool,
58    /// tls-auth key
59    tls_auth_key: Option<corevpn_crypto::HmacAuth>,
60    /// Session creation time
61    created_at: Instant,
62    /// Last activity time
63    last_activity: Instant,
64    /// Cipher suite to use
65    cipher_suite: CipherSuite,
66}
67
68impl ProtocolSession {
69    /// Create a new server-side session
70    pub fn new_server(cipher_suite: CipherSuite) -> Self {
71        Self {
72            local_session_id: corevpn_crypto::generate_session_id(),
73            remote_session_id: None,
74            state: ProtocolState::Initial,
75            current_key_id: KeyId::default(),
76            reliable: ReliableTransport::new(ReliableConfig::default()),
77            tls_reassembler: TlsRecordReassembler::new(65536),
78            data_channels: Default::default(),
79            peer_id: None,
80            use_tls_auth: false,
81            tls_auth_key: None,
82            created_at: Instant::now(),
83            last_activity: Instant::now(),
84            cipher_suite,
85        }
86    }
87
88    /// Create a new client-side session
89    pub fn new_client(cipher_suite: CipherSuite) -> Self {
90        let mut session = Self::new_server(cipher_suite);
91        session.state = ProtocolState::Initial;
92        session
93    }
94
95    /// Get local session ID
96    pub fn local_session_id(&self) -> &SessionIdBytes {
97        &self.local_session_id
98    }
99
100    /// Get remote session ID
101    pub fn remote_session_id(&self) -> Option<&SessionIdBytes> {
102        self.remote_session_id.as_ref()
103    }
104
105    /// Get current state
106    pub fn state(&self) -> ProtocolState {
107        self.state
108    }
109
110    /// Set state
111    pub fn set_state(&mut self, state: ProtocolState) {
112        self.state = state;
113        self.last_activity = Instant::now();
114    }
115
116    /// Set remote session ID
117    pub fn set_remote_session_id(&mut self, id: SessionIdBytes) {
118        self.remote_session_id = Some(id);
119    }
120
121    /// Enable tls-auth
122    pub fn set_tls_auth(&mut self, key: corevpn_crypto::HmacAuth) {
123        self.use_tls_auth = true;
124        self.tls_auth_key = Some(key);
125    }
126
127    /// Process incoming packet
128    pub fn process_packet(&mut self, data: &[u8]) -> Result<ProcessedPacket> {
129        self.last_activity = Instant::now();
130
131        // Verify HMAC if tls-auth enabled
132        let data = if self.use_tls_auth {
133            if let Some(key) = &self.tls_auth_key {
134                // First byte is opcode, check if control
135                if !data.is_empty() && OpCode::from_byte(data[0])?.is_control() {
136                    key.unwrap(data)?
137                } else {
138                    data.to_vec()
139                }
140            } else {
141                data.to_vec()
142            }
143        } else {
144            data.to_vec()
145        };
146
147        let packet = Packet::parse(&data, false)?;
148
149        match packet {
150            Packet::Control(ctrl) => self.process_control_packet(ctrl),
151            Packet::Data(data_pkt) => self.process_data_packet(data_pkt),
152        }
153    }
154
155    fn process_control_packet(&mut self, ctrl: ControlPacketData) -> Result<ProcessedPacket> {
156        // Process ACKs
157        if !ctrl.acks.is_empty() {
158            self.reliable.process_acks(&ctrl.acks);
159        }
160
161        // Handle different opcodes
162        match ctrl.header.opcode {
163            OpCode::HardResetClientV2 | OpCode::HardResetClientV3 => {
164                // Client initiating connection
165                if let Some(remote_sid) = ctrl.header.session_id {
166                    self.remote_session_id = Some(remote_sid);
167                }
168                self.state = ProtocolState::TlsHandshake;
169
170                Ok(ProcessedPacket::HardReset {
171                    session_id: ctrl.header.session_id.unwrap_or([0; 8]),
172                })
173            }
174            OpCode::HardResetServerV2 => {
175                // Server response to hard reset
176                if let Some(remote_sid) = ctrl.header.session_id {
177                    self.remote_session_id = Some(remote_sid);
178                }
179                Ok(ProcessedPacket::HardResetAck)
180            }
181            OpCode::ControlV1 => {
182                // TLS data
183                if let Some(packet_id) = ctrl.message_packet_id {
184                    if let Some(data) = self.reliable.receive(packet_id, ctrl.payload.clone()) {
185                        self.tls_reassembler.add(&data)?;
186                        let records = self.tls_reassembler.extract_records();
187                        if !records.is_empty() {
188                            return Ok(ProcessedPacket::TlsData(records));
189                        }
190                    }
191                }
192                Ok(ProcessedPacket::None)
193            }
194            OpCode::AckV1 => {
195                // Pure ACK, already processed above
196                Ok(ProcessedPacket::None)
197            }
198            OpCode::SoftResetV1 => {
199                // Key renegotiation
200                self.state = ProtocolState::Rekeying;
201                Ok(ProcessedPacket::SoftReset)
202            }
203            _ => Err(ProtocolError::UnknownOpcode(ctrl.header.opcode as u8)),
204        }
205    }
206
207    fn process_data_packet(&mut self, data_pkt: crate::packet::DataPacketData) -> Result<ProcessedPacket> {
208        let packet = DataPacket {
209            key_id: data_pkt.header.key_id,
210            peer_id: data_pkt.peer_id,
211            payload: data_pkt.payload,
212        };
213
214        let key_id = packet.key_id.0 as usize;
215        if let Some(channel) = &mut self.data_channels[key_id] {
216            let decrypted = channel.decrypt(&packet)?;
217            Ok(ProcessedPacket::Data(decrypted))
218        } else {
219            Err(ProtocolError::KeyNotAvailable(packet.key_id.0))
220        }
221    }
222
223    /// Create a hard reset response packet
224    pub fn create_hard_reset_response(&mut self) -> Result<Bytes> {
225        let packet = crate::packet::ControlPacketData {
226            header: crate::PacketHeader {
227                opcode: OpCode::HardResetServerV2,
228                key_id: KeyId::default(),
229                session_id: Some(self.local_session_id),
230                hmac: None,
231                packet_id: None,
232                timestamp: None,
233            },
234            remote_session_id: self.remote_session_id,
235            acks: self.reliable.get_acks(),
236            message_packet_id: None,
237            payload: Bytes::new(),
238        };
239
240        let serialized = Packet::Control(packet).serialize();
241        Ok(self.maybe_wrap_tls_auth(serialized.freeze()))
242    }
243
244    /// Create a control packet with TLS data
245    pub fn create_control_packet(&mut self, tls_data: Bytes) -> Result<Bytes> {
246        let (packet_id, _) = self.reliable.send(tls_data.clone())?;
247
248        let packet = crate::packet::ControlPacketData {
249            header: crate::PacketHeader {
250                opcode: OpCode::ControlV1,
251                key_id: self.current_key_id,
252                session_id: Some(self.local_session_id),
253                hmac: None,
254                packet_id: None,
255                timestamp: None,
256            },
257            remote_session_id: self.remote_session_id,
258            acks: self.reliable.get_acks(),
259            message_packet_id: Some(packet_id),
260            payload: tls_data,
261        };
262
263        let serialized = Packet::Control(packet).serialize();
264        Ok(self.maybe_wrap_tls_auth(serialized.freeze()))
265    }
266
267    /// Create an ACK packet
268    pub fn create_ack_packet(&mut self) -> Option<Bytes> {
269        let acks = self.reliable.get_acks();
270        if acks.is_empty() {
271            return None;
272        }
273
274        let packet = crate::packet::ControlPacketData {
275            header: crate::PacketHeader {
276                opcode: OpCode::AckV1,
277                key_id: self.current_key_id,
278                session_id: Some(self.local_session_id),
279                hmac: None,
280                packet_id: None,
281                timestamp: None,
282            },
283            remote_session_id: self.remote_session_id,
284            acks,
285            message_packet_id: None,
286            payload: Bytes::new(),
287        };
288
289        self.reliable.ack_sent();
290        let serialized = Packet::Control(packet).serialize();
291        Some(self.maybe_wrap_tls_auth(serialized.freeze()))
292    }
293
294    /// Install data channel keys
295    pub fn install_keys(&mut self, key_material: &KeyMaterial, is_server: bool) {
296        let key_id = self.current_key_id;
297        let idx = key_id.0 as usize;
298
299        let (encrypt_key, decrypt_key) = if is_server {
300            (
301                key_material.server_data_key(self.cipher_suite),
302                key_material.client_data_key(self.cipher_suite),
303            )
304        } else {
305            (
306                key_material.client_data_key(self.cipher_suite),
307                key_material.server_data_key(self.cipher_suite),
308            )
309        };
310
311        self.data_channels[idx] = Some(DataChannel::new(
312            key_id,
313            encrypt_key,
314            decrypt_key,
315            true,
316            self.peer_id,
317        ));
318    }
319
320    /// Encrypt data for transmission
321    pub fn encrypt_data(&mut self, data: &[u8]) -> Result<Bytes> {
322        let idx = self.current_key_id.0 as usize;
323        if let Some(channel) = &mut self.data_channels[idx] {
324            let packet = channel.encrypt(data)?;
325            Ok(packet.serialize().freeze())
326        } else {
327            Err(ProtocolError::KeyNotAvailable(self.current_key_id.0))
328        }
329    }
330
331    /// Get packets needing retransmission
332    pub fn get_retransmits(&mut self) -> Vec<Bytes> {
333        self.reliable
334            .get_retransmits()
335            .into_iter()
336            .map(|(id, data)| {
337                // Rebuild packet with same ID
338                let packet = crate::packet::ControlPacketData {
339                    header: crate::PacketHeader {
340                        opcode: OpCode::ControlV1,
341                        key_id: self.current_key_id,
342                        session_id: Some(self.local_session_id),
343                        hmac: None,
344                        packet_id: None,
345                        timestamp: None,
346                    },
347                    remote_session_id: self.remote_session_id,
348                    acks: vec![],
349                    message_packet_id: Some(id),
350                    payload: data,
351                };
352                let serialized = Packet::Control(packet).serialize();
353                self.maybe_wrap_tls_auth(serialized.freeze())
354            })
355            .collect()
356    }
357
358    /// Check if we should send an ACK
359    pub fn should_send_ack(&self) -> bool {
360        self.reliable.should_send_ack()
361    }
362
363    /// Get next timeout
364    pub fn next_timeout(&self) -> Option<Duration> {
365        self.reliable.next_timeout()
366    }
367
368    /// Check if session is established
369    pub fn is_established(&self) -> bool {
370        self.state == ProtocolState::Established
371    }
372
373    /// Get session duration
374    pub fn duration(&self) -> Duration {
375        self.created_at.elapsed()
376    }
377
378    /// Get idle time
379    pub fn idle_time(&self) -> Duration {
380        self.last_activity.elapsed()
381    }
382
383    fn maybe_wrap_tls_auth(&self, data: Bytes) -> Bytes {
384        if self.use_tls_auth {
385            if let Some(key) = &self.tls_auth_key {
386                return Bytes::from(key.wrap(&data));
387            }
388        }
389        data
390    }
391
392    /// Rotate to next key ID (for rekeying)
393    pub fn rotate_key(&mut self) {
394        self.current_key_id = self.current_key_id.next();
395    }
396}
397
398/// Result of processing a packet
399#[derive(Debug)]
400pub enum ProcessedPacket {
401    /// No action needed
402    None,
403    /// Hard reset from client
404    HardReset {
405        /// Session ID for the new connection
406        session_id: SessionIdBytes,
407    },
408    /// Hard reset acknowledged
409    HardResetAck,
410    /// TLS records to process
411    TlsData(Vec<Bytes>),
412    /// Decrypted data packet
413    Data(Bytes),
414    /// Soft reset (rekey)
415    SoftReset,
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_creation() {
        let session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);
        assert_eq!(session.state(), ProtocolState::Initial);
        assert!(session.remote_session_id().is_none());
    }

    #[test]
    fn test_client_session_creation() {
        let session = ProtocolSession::new_client(CipherSuite::ChaCha20Poly1305);
        assert_eq!(session.state(), ProtocolState::Initial);
        assert!(session.remote_session_id().is_none());
        assert!(!session.is_established());
    }

    #[test]
    fn test_hard_reset() {
        let mut session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);
        let client_sid: SessionIdBytes = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];

        // Simulate receiving hard reset from client
        let hard_reset = [
            0x38, // opcode=7 (HardResetClientV2), key_id=0
            0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // session_id
            0x00, // ack_count = 0
        ];

        let result = session.process_packet(&hard_reset).unwrap();
        // `matches!` only yields a bool; it must be asserted to test anything.
        assert!(
            matches!(&result, ProcessedPacket::HardReset { session_id } if *session_id == client_sid),
            "unexpected result: {:?}",
            result
        );
        assert_eq!(session.state(), ProtocolState::TlsHandshake);
        assert_eq!(session.remote_session_id(), Some(&client_sid));
    }

    #[test]
    fn test_encrypt_without_keys_fails() {
        let mut session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);
        assert!(matches!(
            session.encrypt_data(b"hello"),
            Err(ProtocolError::KeyNotAvailable(0))
        ));
    }
}