1use std::time::{Duration, Instant};
4
5use bytes::Bytes;
6
7use corevpn_crypto::{CipherSuite, KeyMaterial};
8
9use crate::{
10 KeyId, OpCode, Packet, DataPacket, DataChannel,
11 ReliableTransport, ReliableConfig, TlsRecordReassembler,
12 ProtocolError, Result,
13};
14use crate::packet::ControlPacketData;
15
/// Lifecycle phase of a [`ProtocolSession`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtocolState {
    /// Freshly created session; no hard reset has been processed yet.
    Initial,
    /// A hard reset (client or server) was processed and the TLS handshake
    /// is in progress over the control channel.
    TlsHandshake,
    /// Not assigned inside `ProtocolSession` itself; presumably set by the
    /// caller via `set_state` once TLS completes — confirm.
    KeyExchange,
    /// Not assigned inside `ProtocolSession` itself; presumably set by the
    /// caller via `set_state` during user authentication — confirm.
    Authenticating,
    /// Session is usable for data; `is_established` checks for this state.
    Established,
    /// The peer sent a `SoftResetV1` requesting renegotiation of keys.
    Rekeying,
    /// Not assigned inside `ProtocolSession` itself; presumably set by the
    /// caller via `set_state` when the session is torn down — confirm.
    Terminated,
}
34
/// Raw 8-byte session identifier carried in control-packet headers.
pub type SessionIdBytes = [u8; 8];
37
/// Sliding-window replay detector for incoming tls-auth packet IDs.
///
/// Remembers the largest ID accepted so far plus a 64-bit history in which
/// bit `n` marks whether ID `highest - n` has already been accepted.
struct ReplayWindow {
    /// Largest packet ID accepted so far (0 until the first acceptance).
    highest: u32,
    /// Bit `n` set means packet ID `highest - n` was already accepted.
    bitmap: u64,
}

impl ReplayWindow {
    /// How many IDs, counting `highest` itself, the history can remember.
    const WINDOW_SIZE: u32 = 64;

    /// Returns an empty window that has not accepted any ID yet.
    fn new() -> Self {
        ReplayWindow {
            highest: 0,
            bitmap: 0,
        }
    }

    /// Records `packet_id` and returns `true` if it has not been seen before.
    ///
    /// Returns `false` for ID 0, for an ID that was already accepted, and for
    /// an ID that lies `WINDOW_SIZE` or more below the highest accepted ID.
    fn check_and_update(&mut self, packet_id: u32) -> bool {
        // ID 0 is never accepted.
        if packet_id == 0 {
            return false;
        }

        if packet_id > self.highest {
            // New high-water mark: slide the history forward. A jump of a
            // full window or more leaves only the new ID marked.
            let advance = packet_id - self.highest;
            self.bitmap = if advance < Self::WINDOW_SIZE {
                (self.bitmap << advance) | 1
            } else {
                1
            };
            self.highest = packet_id;
            return true;
        }

        // Older (or equal) ID: acceptable only while still inside the window
        // and not yet recorded.
        let age = self.highest - packet_id;
        if age >= Self::WINDOW_SIZE {
            return false;
        }
        let bit = 1u64 << age;
        let fresh = (self.bitmap & bit) == 0;
        self.bitmap |= bit;
        fresh
    }

    /// Forgets every accepted ID, returning the window to its initial state.
    fn reset(&mut self) {
        *self = Self::new();
    }
}
109
/// Per-peer protocol state: control-channel reliability, optional tls-auth
/// wrapping, TLS record reassembly and the data channels for each key ID.
pub struct ProtocolSession {
    /// Our session ID, generated at construction and sent in every
    /// outgoing control-packet header.
    local_session_id: SessionIdBytes,
    /// The peer's session ID, learned from hard-reset packets or set via
    /// `set_remote_session_id`.
    remote_session_id: Option<SessionIdBytes>,
    /// Current lifecycle phase.
    state: ProtocolState,
    /// Key ID used for outgoing control packets and data encryption;
    /// advanced by `rotate_key`.
    current_key_id: KeyId,
    /// Reliability layer for control messages (send, receive, ACKs,
    /// retransmits).
    reliable: ReliableTransport,
    /// Reassembles TLS records from `ControlV1` payloads; constructed with
    /// 65536 (presumably a buffer limit in bytes — confirm).
    tls_reassembler: TlsRecordReassembler,
    /// One optional data channel per key-ID slot, indexed by `KeyId.0`.
    data_channels: [Option<DataChannel>; 8],
    /// Peer ID; when `Some`, `install_keys` passes `use_v2 = true` to
    /// `DataChannel::new`.
    peer_id: Option<u32>,
    /// Whether control packets are HMAC-wrapped on send and verified on
    /// receive.
    use_tls_auth: bool,
    /// HMAC key used for tls-auth wrapping/verification.
    tls_auth_key: Option<corevpn_crypto::HmacAuth>,
    /// Replay detection for tls-auth packet IDs of incoming control packets.
    replay_window: ReplayWindow,
    /// Next tls-auth packet ID to stamp on an outgoing control packet.
    tls_auth_packet_id: u32,
    /// When the session was created.
    created_at: Instant,
    /// Time of the last recorded activity (see `idle_time`).
    last_activity: Instant,
    /// Cipher suite used when deriving data-channel keys in `install_keys`.
    cipher_suite: CipherSuite,
}
143
144impl ProtocolSession {
145 pub fn new_server(cipher_suite: CipherSuite) -> Self {
147 Self {
148 local_session_id: corevpn_crypto::generate_session_id(),
149 remote_session_id: None,
150 state: ProtocolState::Initial,
151 current_key_id: KeyId::default(),
152 reliable: ReliableTransport::new(ReliableConfig::default()),
153 tls_reassembler: TlsRecordReassembler::new(65536),
154 data_channels: Default::default(),
155 peer_id: None,
156 use_tls_auth: false,
157 tls_auth_key: None,
158 replay_window: ReplayWindow::new(),
159 tls_auth_packet_id: 1, created_at: Instant::now(),
161 last_activity: Instant::now(),
162 cipher_suite,
163 }
164 }
165
166 pub fn new_client(cipher_suite: CipherSuite) -> Self {
168 let mut session = Self::new_server(cipher_suite);
169 session.state = ProtocolState::Initial;
170 session
171 }
172
173 pub fn local_session_id(&self) -> &SessionIdBytes {
175 &self.local_session_id
176 }
177
178 pub fn remote_session_id(&self) -> Option<&SessionIdBytes> {
180 self.remote_session_id.as_ref()
181 }
182
183 pub fn state(&self) -> ProtocolState {
185 self.state
186 }
187
188 pub fn set_state(&mut self, state: ProtocolState) {
190 self.state = state;
191 self.last_activity = Instant::now();
192 }
193
194 pub fn set_remote_session_id(&mut self, id: SessionIdBytes) {
196 self.remote_session_id = Some(id);
197 }
198
199 pub fn set_tls_auth(&mut self, key: corevpn_crypto::HmacAuth) {
201 self.use_tls_auth = true;
202 self.tls_auth_key = Some(key);
203 }
204
205 pub fn set_cipher_suite(&mut self, cipher_suite: CipherSuite) {
207 self.cipher_suite = cipher_suite;
208 }
209
210 pub fn process_packet(&mut self, data: &[u8]) -> Result<ProcessedPacket> {
212 self.last_activity = Instant::now();
213
214 if self.use_tls_auth {
216 if let Some(key) = &self.tls_auth_key {
217 if !data.is_empty() && OpCode::from_byte(data[0])?.is_control() {
218 if data.len() < 49 {
225 return Err(ProtocolError::PacketTooShort {
226 expected: 49,
227 got: data.len(),
228 });
229 }
230
231 let mut hmac = [0u8; 32];
232 hmac.copy_from_slice(&data[9..41]); let mut hmac_input = Vec::with_capacity(8 + 9 + data.len() - 49);
236 hmac_input.extend_from_slice(&data[41..49]); hmac_input.extend_from_slice(&data[0..9]); hmac_input.extend_from_slice(&data[49..]); key.verify(&hmac_input, &hmac)?;
241 }
242 }
243 }
244
245 let packet = Packet::parse(data, self.use_tls_auth)?;
246
247 match packet {
248 Packet::Control(ctrl) => self.process_control_packet(ctrl),
249 Packet::Data(data_pkt) => self.process_data_packet(data_pkt),
250 }
251 }
252
253 fn process_control_packet(&mut self, ctrl: ControlPacketData) -> Result<ProcessedPacket> {
254 if self.use_tls_auth {
256 if let Some(packet_id) = ctrl.header.packet_id {
257 if !self.replay_window.check_and_update(packet_id) {
258 return Err(ProtocolError::ReplayDetected);
259 }
260 }
261 }
262
263 if !ctrl.acks.is_empty() {
265 self.reliable.process_acks(&ctrl.acks);
266 }
267
268 match ctrl.header.opcode {
270 OpCode::HardResetClientV2 | OpCode::HardResetClientV3 => {
271 if let Some(remote_sid) = ctrl.header.session_id {
275 if remote_sid == [0; 8] {
277 return Err(ProtocolError::InvalidSessionId);
278 }
279 self.remote_session_id = Some(remote_sid);
281 }
282
283 if let Some(packet_id) = ctrl.message_packet_id {
286 let _ = self.reliable.receive(packet_id, Bytes::new())?;
287 }
288
289 self.state = ProtocolState::TlsHandshake;
290
291 Ok(ProcessedPacket::HardReset {
292 session_id: self.local_session_id,
293 })
294 }
295 OpCode::HardResetServerV2 => {
296 if let Some(remote_sid) = ctrl.header.session_id {
298 self.remote_session_id = Some(remote_sid);
299 }
300
301 if let Some(packet_id) = ctrl.message_packet_id {
303 let _ = self.reliable.receive(packet_id, Bytes::new())?;
304 }
305
306 self.state = ProtocolState::TlsHandshake;
307
308 Ok(ProcessedPacket::HardResetAck)
309 }
310 OpCode::ControlV1 => {
311 if let Some(packet_id) = ctrl.message_packet_id {
313 if let Some(data) = self.reliable.receive(packet_id, ctrl.payload.clone())? {
314 self.tls_reassembler.add(&data)?;
315 let records = self.tls_reassembler.extract_records();
316 if !records.is_empty() {
317 return Ok(ProcessedPacket::TlsData(records));
318 }
319 }
320 }
321 Ok(ProcessedPacket::None)
322 }
323 OpCode::AckV1 => {
324 Ok(ProcessedPacket::None)
326 }
327 OpCode::SoftResetV1 => {
328 self.state = ProtocolState::Rekeying;
330 Ok(ProcessedPacket::SoftReset)
331 }
332 _ => Err(ProtocolError::UnknownOpcode(ctrl.header.opcode as u8)),
333 }
334 }
335
336 fn process_data_packet(&mut self, data_pkt: crate::packet::DataPacketData) -> Result<ProcessedPacket> {
337 let packet = DataPacket {
338 key_id: data_pkt.header.key_id,
339 peer_id: data_pkt.peer_id,
340 payload: data_pkt.payload,
341 };
342
343 let key_id = packet.key_id.0 as usize;
344 if let Some(channel) = &mut self.data_channels[key_id] {
345 let decrypted = channel.decrypt(&packet)?;
346 Ok(ProcessedPacket::Data(decrypted))
347 } else {
348 Err(ProtocolError::KeyNotAvailable(packet.key_id.0))
349 }
350 }
351
352 pub fn create_hard_reset_client(&mut self) -> Result<Bytes> {
354 let (packet_id, _) = self.reliable.send(Bytes::new())?;
355
356 let packet = crate::packet::ControlPacketData {
357 header: crate::PacketHeader {
358 opcode: OpCode::HardResetClientV2,
359 key_id: KeyId::default(),
360 session_id: Some(self.local_session_id),
361 hmac: None,
362 packet_id: None,
363 timestamp: None,
364 },
365 remote_session_id: None, acks: vec![],
367 message_packet_id: Some(packet_id),
368 payload: Bytes::new(),
369 };
370
371 let serialized = Packet::Control(packet).serialize();
372 Ok(self.maybe_wrap_tls_auth(serialized.freeze()))
373 }
374
375 pub fn create_hard_reset_response(&mut self) -> Result<Bytes> {
377 let (packet_id, _) = self.reliable.send(Bytes::new())?;
381
382 let packet = crate::packet::ControlPacketData {
383 header: crate::PacketHeader {
384 opcode: OpCode::HardResetServerV2,
385 key_id: KeyId::default(),
386 session_id: Some(self.local_session_id),
387 hmac: None,
388 packet_id: None,
389 timestamp: None,
390 },
391 remote_session_id: self.remote_session_id,
392 acks: self.reliable.get_acks(),
393 message_packet_id: Some(packet_id),
394 payload: Bytes::new(),
395 };
396
397 let serialized = Packet::Control(packet).serialize();
398 Ok(self.maybe_wrap_tls_auth(serialized.freeze()))
399 }
400
401 const MAX_CONTROL_PAYLOAD: usize = 1100;
412
413 pub fn create_control_packets(&mut self, tls_data: Bytes) -> Result<Vec<Bytes>> {
417 let mut packets = Vec::new();
418
419 if tls_data.len() <= Self::MAX_CONTROL_PAYLOAD {
420 packets.push(self.create_single_control_packet(tls_data)?);
422 } else {
423 let mut offset = 0;
425 while offset < tls_data.len() {
426 let end = std::cmp::min(offset + Self::MAX_CONTROL_PAYLOAD, tls_data.len());
427 let chunk = tls_data.slice(offset..end);
428 packets.push(self.create_single_control_packet(chunk)?);
429 offset = end;
430 }
431 }
432
433 Ok(packets)
434 }
435
436 fn create_single_control_packet(&mut self, tls_data: Bytes) -> Result<Bytes> {
438 let (packet_id, _) = self.reliable.send(tls_data.clone())?;
439
440 let packet = crate::packet::ControlPacketData {
441 header: crate::PacketHeader {
442 opcode: OpCode::ControlV1,
443 key_id: self.current_key_id,
444 session_id: Some(self.local_session_id),
445 hmac: None,
446 packet_id: None,
447 timestamp: None,
448 },
449 remote_session_id: self.remote_session_id,
450 acks: self.reliable.get_acks(),
451 message_packet_id: Some(packet_id),
452 payload: tls_data,
453 };
454
455 let serialized = Packet::Control(packet).serialize();
456 Ok(self.maybe_wrap_tls_auth(serialized.freeze()))
457 }
458
459 pub fn create_control_packet(&mut self, tls_data: Bytes) -> Result<Bytes> {
461 self.create_single_control_packet(tls_data)
462 }
463
464 pub fn create_ack_packet(&mut self) -> Option<Bytes> {
466 let acks = self.reliable.get_acks();
467 if acks.is_empty() {
468 return None;
469 }
470
471 let packet = crate::packet::ControlPacketData {
472 header: crate::PacketHeader {
473 opcode: OpCode::AckV1,
474 key_id: self.current_key_id,
475 session_id: Some(self.local_session_id),
476 hmac: None,
477 packet_id: None,
478 timestamp: None,
479 },
480 remote_session_id: self.remote_session_id,
481 acks,
482 message_packet_id: None,
483 payload: Bytes::new(),
484 };
485
486 self.reliable.ack_sent();
487 let serialized = Packet::Control(packet).serialize();
488 Some(self.maybe_wrap_tls_auth(serialized.freeze()))
489 }
490
491 pub fn install_keys(&mut self, key_material: &KeyMaterial, is_server: bool) {
493 let key_id = self.current_key_id;
494 let idx = key_id.0 as usize;
495
496 let (encrypt_key, decrypt_key) = if is_server {
497 (
498 key_material.server_data_key(self.cipher_suite),
499 key_material.client_data_key(self.cipher_suite),
500 )
501 } else {
502 (
503 key_material.client_data_key(self.cipher_suite),
504 key_material.server_data_key(self.cipher_suite),
505 )
506 };
507
508 let use_v2 = self.peer_id.is_some();
512 self.data_channels[idx] = Some(DataChannel::new(
513 key_id,
514 encrypt_key,
515 decrypt_key,
516 use_v2,
517 self.peer_id,
518 ));
519 }
520
521 pub fn encrypt_data(&mut self, data: &[u8]) -> Result<Bytes> {
523 let idx = self.current_key_id.0 as usize;
524 if let Some(channel) = &mut self.data_channels[idx] {
525 let packet = channel.encrypt(data)?;
526 Ok(packet.serialize().freeze())
527 } else {
528 Err(ProtocolError::KeyNotAvailable(self.current_key_id.0))
529 }
530 }
531
532 pub fn get_retransmits(&mut self) -> Vec<Bytes> {
534 self.reliable
535 .get_retransmits()
536 .into_iter()
537 .map(|(id, data)| {
538 let packet = crate::packet::ControlPacketData {
540 header: crate::PacketHeader {
541 opcode: OpCode::ControlV1,
542 key_id: self.current_key_id,
543 session_id: Some(self.local_session_id),
544 hmac: None,
545 packet_id: None,
546 timestamp: None,
547 },
548 remote_session_id: self.remote_session_id,
549 acks: vec![],
550 message_packet_id: Some(id),
551 payload: data,
552 };
553 let serialized = Packet::Control(packet).serialize();
554 self.maybe_wrap_tls_auth(serialized.freeze())
555 })
556 .collect()
557 }
558
559 pub fn should_send_ack(&self) -> bool {
561 self.reliable.should_send_ack()
562 }
563
564 pub fn next_timeout(&self) -> Option<Duration> {
566 self.reliable.next_timeout()
567 }
568
569 pub fn is_established(&self) -> bool {
571 self.state == ProtocolState::Established
572 }
573
574 pub fn duration(&self) -> Duration {
576 self.created_at.elapsed()
577 }
578
579 pub fn idle_time(&self) -> Duration {
581 self.last_activity.elapsed()
582 }
583
584 fn maybe_wrap_tls_auth(&mut self, data: Bytes) -> Bytes {
585 if self.use_tls_auth {
586 if let Some(key) = &self.tls_auth_key {
587 if data.len() < 9 {
599 return data;
600 }
601
602 let packet_id = self.tls_auth_packet_id;
604 self.tls_auth_packet_id += 1;
605
606 let timestamp = std::time::SystemTime::now()
608 .duration_since(std::time::UNIX_EPOCH)
609 .unwrap_or_default()
610 .as_secs() as u32;
611
612 let pid_bytes = packet_id.to_be_bytes();
613 let time_bytes = timestamp.to_be_bytes();
614
615 let mut hmac_input = Vec::with_capacity(8 + data.len());
618 hmac_input.extend_from_slice(&pid_bytes);
619 hmac_input.extend_from_slice(&time_bytes);
620 hmac_input.extend_from_slice(&data); let hmac = key.authenticate(&hmac_input);
623
624 let mut output = Vec::with_capacity(data.len() + 32 + 8);
626 output.extend_from_slice(&data[0..9]); output.extend_from_slice(&hmac); output.extend_from_slice(&pid_bytes); output.extend_from_slice(&time_bytes); output.extend_from_slice(&data[9..]); return Bytes::from(output);
633 }
634 }
635 data
636 }
637
638 pub fn rotate_key(&mut self) {
640 self.current_key_id = self.current_key_id.next();
641 self.replay_window.reset();
643 }
644}
645
/// Outcome of [`ProtocolSession::process_packet`] for the caller to act on.
#[derive(Debug)]
pub enum ProcessedPacket {
    /// Nothing for the caller to do (ACK-only packet, or `ControlV1` data
    /// that did not complete a TLS record).
    None,
    /// The peer sent a client hard reset; the session is now in
    /// `TlsHandshake`. The caller presumably replies with
    /// `create_hard_reset_response` — confirm.
    HardReset {
        /// Our local session ID.
        session_id: SessionIdBytes,
    },
    /// The peer sent a server hard reset; the session is now in
    /// `TlsHandshake`.
    HardResetAck,
    /// Complete TLS records reassembled from the control channel.
    TlsData(Vec<Bytes>),
    /// Decrypted data-channel payload.
    Data(Bytes),
    /// The peer requested a rekey; the session is now in `Rekeying`.
    SoftReset,
}
665
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `HardResetClientV2` without tls-auth: opcode byte 0x38,
    /// session ID 01..08, zero ACKs, message packet ID 0.
    const CLIENT_HARD_RESET: [u8; 14] = [
        0x38, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00,
    ];

    #[test]
    fn test_session_creation() {
        let session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);
        assert_eq!(session.state(), ProtocolState::Initial);
        assert!(session.remote_session_id().is_none());
    }

    #[test]
    fn test_hard_reset() {
        let mut session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);

        let result = session.process_packet(&CLIENT_HARD_RESET).unwrap();
        // A bare `matches!` only evaluates to a bool; assert the variant and
        // its contents explicitly.
        match result {
            ProcessedPacket::HardReset { session_id } => {
                assert_eq!(&session_id, session.local_session_id());
            }
            other => panic!("expected HardReset, got {:?}", other),
        }
        assert_eq!(session.state(), ProtocolState::TlsHandshake);
        assert_eq!(
            session.remote_session_id(),
            Some(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08])
        );
    }

    #[test]
    fn test_hard_reset_rejects_zero_session_id() {
        let mut session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);

        let mut packet = CLIENT_HARD_RESET;
        packet[1..9].fill(0);

        assert!(session.process_packet(&packet).is_err());
        assert_eq!(session.state(), ProtocolState::Initial);
        assert!(session.remote_session_id().is_none());
    }

    #[test]
    fn test_hard_reset_response_has_packet_id() {
        let mut session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);
        session.process_packet(&CLIENT_HARD_RESET).unwrap();

        let response = session.create_hard_reset_response().unwrap();

        // opcode(1) + session id(8) + ack count(1) + ack id(4)
        // + remote session id(8) + message packet id(4)
        assert!(response.len() >= 26, "Response too short: {} bytes", response.len());

        // HardResetServerV2 opcode byte.
        assert_eq!(response[0], 0x40);

        // Exactly one ACK: the client's hard reset with message packet ID 0.
        assert_eq!(response[9], 1);
        let ack_id = u32::from_be_bytes(response[10..14].try_into().unwrap());
        assert_eq!(ack_id, 0);

        // The client's session ID is echoed alongside the ACK.
        assert_eq!(&response[14..22], &CLIENT_HARD_RESET[1..9]);

        // Our own reset is the first reliable message we send.
        let msg_pid = u32::from_be_bytes(response[22..26].try_into().unwrap());
        assert_eq!(msg_pid, 0);
    }
}