1use std::time::{Duration, Instant};
4
5use bytes::Bytes;
6
7use corevpn_crypto::{CipherSuite, KeyMaterial};
8
9use crate::{
10 KeyId, OpCode, Packet, DataPacket, DataChannel,
11 ReliableTransport, ReliableConfig, TlsRecordReassembler,
12 ProtocolError, Result,
13};
14use crate::packet::ControlPacketData;
15
/// Lifecycle stage of a [`ProtocolSession`].
///
/// The session itself moves to `TlsHandshake` on a client hard reset and to
/// `Rekeying` on a soft reset; other transitions are driven by callers through
/// `ProtocolSession::set_state`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProtocolState {
    /// Freshly created; no reset exchanged yet.
    Initial,
    /// A hard reset was handled and the TLS handshake is in progress.
    TlsHandshake,
    /// Key-exchange phase (set by callers).
    KeyExchange,
    /// Authentication phase (set by callers).
    Authenticating,
    /// Session fully established (set by callers).
    Established,
    /// A soft reset (rekey) was received.
    Rekeying,
    /// Session finished (set by callers).
    Terminated,
}

/// Eight-byte session identifier carried in control-packet headers.
pub type SessionIdBytes = [u8; 8];
37
/// Sliding-window replay filter over 32-bit packet ids.
///
/// Bit `n` of `bitmap` records whether id `highest - n` has already been
/// accepted, so the window spans the 64 ids ending at `highest`.
/// Id 0 is never accepted.
struct ReplayWindow {
    /// Largest id accepted so far (0 before anything was accepted).
    highest: u32,
    /// Seen-bitmap; bit 0 corresponds to `highest`.
    bitmap: u64,
}

impl ReplayWindow {
    /// Number of ids, ending at `highest`, tracked by `bitmap`.
    const WINDOW_SIZE: u32 = 64;

    /// Returns a window that has not accepted any id yet.
    fn new() -> Self {
        Self { highest: 0, bitmap: 0 }
    }

    /// Records `packet_id` and returns `true` if it is fresh.
    ///
    /// Returns `false` (and leaves the window unchanged) for id 0, for ids that
    /// were already accepted, and for ids 64 or more behind `highest`.
    fn check_and_update(&mut self, packet_id: u32) -> bool {
        if packet_id == 0 {
            return false;
        }

        if packet_id <= self.highest {
            // At or behind the newest id: consult the bitmap.
            let age = self.highest - packet_id;
            if age >= Self::WINDOW_SIZE {
                return false;
            }
            let bit = 1u64 << age;
            let seen_before = self.bitmap & bit != 0;
            self.bitmap |= bit;
            return !seen_before;
        }

        // Newer than anything seen: slide the window forward.
        let advance = packet_id - self.highest;
        self.bitmap = if advance < Self::WINDOW_SIZE {
            (self.bitmap << advance) | 1
        } else {
            1
        };
        self.highest = packet_id;
        true
    }

    /// Forgets every accepted id, returning to the freshly-constructed state.
    fn reset(&mut self) {
        *self = Self::new();
    }
}
109
110pub struct ProtocolSession {
112 local_session_id: SessionIdBytes,
114 remote_session_id: Option<SessionIdBytes>,
116 state: ProtocolState,
118 current_key_id: KeyId,
120 reliable: ReliableTransport,
122 tls_reassembler: TlsRecordReassembler,
124 data_channels: [Option<DataChannel>; 8],
126 peer_id: Option<u32>,
128 use_tls_auth: bool,
130 tls_auth_key: Option<corevpn_crypto::HmacAuth>,
132 replay_window: ReplayWindow,
134 tls_auth_packet_id: u32,
136 created_at: Instant,
138 last_activity: Instant,
140 cipher_suite: CipherSuite,
142}
143
144impl ProtocolSession {
145 pub fn new_server(cipher_suite: CipherSuite) -> Self {
147 Self {
148 local_session_id: corevpn_crypto::generate_session_id(),
149 remote_session_id: None,
150 state: ProtocolState::Initial,
151 current_key_id: KeyId::default(),
152 reliable: ReliableTransport::new(ReliableConfig::default()),
153 tls_reassembler: TlsRecordReassembler::new(65536),
154 data_channels: Default::default(),
155 peer_id: None,
156 use_tls_auth: false,
157 tls_auth_key: None,
158 replay_window: ReplayWindow::new(),
159 tls_auth_packet_id: 1, created_at: Instant::now(),
161 last_activity: Instant::now(),
162 cipher_suite,
163 }
164 }
165
166 pub fn new_client(cipher_suite: CipherSuite) -> Self {
168 let mut session = Self::new_server(cipher_suite);
169 session.state = ProtocolState::Initial;
170 session
171 }
172
173 pub fn local_session_id(&self) -> &SessionIdBytes {
175 &self.local_session_id
176 }
177
178 pub fn remote_session_id(&self) -> Option<&SessionIdBytes> {
180 self.remote_session_id.as_ref()
181 }
182
183 pub fn state(&self) -> ProtocolState {
185 self.state
186 }
187
188 pub fn set_state(&mut self, state: ProtocolState) {
190 self.state = state;
191 self.last_activity = Instant::now();
192 }
193
194 pub fn set_remote_session_id(&mut self, id: SessionIdBytes) {
196 self.remote_session_id = Some(id);
197 }
198
199 pub fn set_tls_auth(&mut self, key: corevpn_crypto::HmacAuth) {
201 self.use_tls_auth = true;
202 self.tls_auth_key = Some(key);
203 }
204
205 pub fn set_cipher_suite(&mut self, cipher_suite: CipherSuite) {
207 self.cipher_suite = cipher_suite;
208 }
209
210 pub fn process_packet(&mut self, data: &[u8]) -> Result<ProcessedPacket> {
212 self.last_activity = Instant::now();
213
214 if self.use_tls_auth {
216 if let Some(key) = &self.tls_auth_key {
217 if !data.is_empty() && OpCode::from_byte(data[0])?.is_control() {
218 if data.len() < 49 {
225 return Err(ProtocolError::PacketTooShort {
226 expected: 49,
227 got: data.len(),
228 });
229 }
230
231 let mut hmac = [0u8; 32];
232 hmac.copy_from_slice(&data[9..41]); let mut hmac_input = Vec::with_capacity(8 + 9 + data.len() - 49);
236 hmac_input.extend_from_slice(&data[41..49]); hmac_input.extend_from_slice(&data[0..9]); hmac_input.extend_from_slice(&data[49..]); key.verify(&hmac_input, &hmac)?;
241 }
242 }
243 }
244
245 let packet = Packet::parse(data, self.use_tls_auth)?;
246
247 match packet {
248 Packet::Control(ctrl) => self.process_control_packet(ctrl),
249 Packet::Data(data_pkt) => self.process_data_packet(data_pkt),
250 }
251 }
252
253 fn process_control_packet(&mut self, ctrl: ControlPacketData) -> Result<ProcessedPacket> {
254 if self.use_tls_auth {
256 if let Some(packet_id) = ctrl.header.packet_id {
257 if !self.replay_window.check_and_update(packet_id) {
258 return Err(ProtocolError::ReplayDetected);
259 }
260 }
261 }
262
263 if !ctrl.acks.is_empty() {
265 self.reliable.process_acks(&ctrl.acks);
266 }
267
268 match ctrl.header.opcode {
270 OpCode::HardResetClientV2 | OpCode::HardResetClientV3 => {
271 if let Some(remote_sid) = ctrl.header.session_id {
275 if remote_sid == [0; 8] {
277 return Err(ProtocolError::InvalidSessionId);
278 }
279 self.remote_session_id = Some(remote_sid);
281 }
282
283 if let Some(packet_id) = ctrl.message_packet_id {
286 let _ = self.reliable.receive(packet_id, Bytes::new())?;
287 }
288
289 self.state = ProtocolState::TlsHandshake;
290
291 Ok(ProcessedPacket::HardReset {
292 session_id: self.local_session_id,
293 })
294 }
295 OpCode::HardResetServerV2 => {
296 if let Some(remote_sid) = ctrl.header.session_id {
298 self.remote_session_id = Some(remote_sid);
299 }
300 Ok(ProcessedPacket::HardResetAck)
301 }
302 OpCode::ControlV1 => {
303 if let Some(packet_id) = ctrl.message_packet_id {
305 if let Some(data) = self.reliable.receive(packet_id, ctrl.payload.clone())? {
306 self.tls_reassembler.add(&data)?;
307 let records = self.tls_reassembler.extract_records();
308 if !records.is_empty() {
309 return Ok(ProcessedPacket::TlsData(records));
310 }
311 }
312 }
313 Ok(ProcessedPacket::None)
314 }
315 OpCode::AckV1 => {
316 Ok(ProcessedPacket::None)
318 }
319 OpCode::SoftResetV1 => {
320 self.state = ProtocolState::Rekeying;
322 Ok(ProcessedPacket::SoftReset)
323 }
324 _ => Err(ProtocolError::UnknownOpcode(ctrl.header.opcode as u8)),
325 }
326 }
327
328 fn process_data_packet(&mut self, data_pkt: crate::packet::DataPacketData) -> Result<ProcessedPacket> {
329 let packet = DataPacket {
330 key_id: data_pkt.header.key_id,
331 peer_id: data_pkt.peer_id,
332 payload: data_pkt.payload,
333 };
334
335 let key_id = packet.key_id.0 as usize;
336 if let Some(channel) = &mut self.data_channels[key_id] {
337 let decrypted = channel.decrypt(&packet)?;
338 Ok(ProcessedPacket::Data(decrypted))
339 } else {
340 Err(ProtocolError::KeyNotAvailable(packet.key_id.0))
341 }
342 }
343
344 pub fn create_hard_reset_response(&mut self) -> Result<Bytes> {
346 let (packet_id, _) = self.reliable.send(Bytes::new())?;
350
351 let packet = crate::packet::ControlPacketData {
352 header: crate::PacketHeader {
353 opcode: OpCode::HardResetServerV2,
354 key_id: KeyId::default(),
355 session_id: Some(self.local_session_id),
356 hmac: None,
357 packet_id: None,
358 timestamp: None,
359 },
360 remote_session_id: self.remote_session_id,
361 acks: self.reliable.get_acks(),
362 message_packet_id: Some(packet_id),
363 payload: Bytes::new(),
364 };
365
366 let serialized = Packet::Control(packet).serialize();
367 Ok(self.maybe_wrap_tls_auth(serialized.freeze()))
368 }
369
370 const MAX_CONTROL_PAYLOAD: usize = 1100;
381
382 pub fn create_control_packets(&mut self, tls_data: Bytes) -> Result<Vec<Bytes>> {
386 let mut packets = Vec::new();
387
388 if tls_data.len() <= Self::MAX_CONTROL_PAYLOAD {
389 packets.push(self.create_single_control_packet(tls_data)?);
391 } else {
392 let mut offset = 0;
394 while offset < tls_data.len() {
395 let end = std::cmp::min(offset + Self::MAX_CONTROL_PAYLOAD, tls_data.len());
396 let chunk = tls_data.slice(offset..end);
397 packets.push(self.create_single_control_packet(chunk)?);
398 offset = end;
399 }
400 }
401
402 Ok(packets)
403 }
404
405 fn create_single_control_packet(&mut self, tls_data: Bytes) -> Result<Bytes> {
407 let (packet_id, _) = self.reliable.send(tls_data.clone())?;
408
409 let packet = crate::packet::ControlPacketData {
410 header: crate::PacketHeader {
411 opcode: OpCode::ControlV1,
412 key_id: self.current_key_id,
413 session_id: Some(self.local_session_id),
414 hmac: None,
415 packet_id: None,
416 timestamp: None,
417 },
418 remote_session_id: self.remote_session_id,
419 acks: self.reliable.get_acks(),
420 message_packet_id: Some(packet_id),
421 payload: tls_data,
422 };
423
424 let serialized = Packet::Control(packet).serialize();
425 Ok(self.maybe_wrap_tls_auth(serialized.freeze()))
426 }
427
428 pub fn create_control_packet(&mut self, tls_data: Bytes) -> Result<Bytes> {
430 self.create_single_control_packet(tls_data)
431 }
432
433 pub fn create_ack_packet(&mut self) -> Option<Bytes> {
435 let acks = self.reliable.get_acks();
436 if acks.is_empty() {
437 return None;
438 }
439
440 let packet = crate::packet::ControlPacketData {
441 header: crate::PacketHeader {
442 opcode: OpCode::AckV1,
443 key_id: self.current_key_id,
444 session_id: Some(self.local_session_id),
445 hmac: None,
446 packet_id: None,
447 timestamp: None,
448 },
449 remote_session_id: self.remote_session_id,
450 acks,
451 message_packet_id: None,
452 payload: Bytes::new(),
453 };
454
455 self.reliable.ack_sent();
456 let serialized = Packet::Control(packet).serialize();
457 Some(self.maybe_wrap_tls_auth(serialized.freeze()))
458 }
459
460 pub fn install_keys(&mut self, key_material: &KeyMaterial, is_server: bool) {
462 let key_id = self.current_key_id;
463 let idx = key_id.0 as usize;
464
465 let (encrypt_key, decrypt_key) = if is_server {
466 (
467 key_material.server_data_key(self.cipher_suite),
468 key_material.client_data_key(self.cipher_suite),
469 )
470 } else {
471 (
472 key_material.client_data_key(self.cipher_suite),
473 key_material.server_data_key(self.cipher_suite),
474 )
475 };
476
477 let use_v2 = self.peer_id.is_some();
481 self.data_channels[idx] = Some(DataChannel::new(
482 key_id,
483 encrypt_key,
484 decrypt_key,
485 use_v2,
486 self.peer_id,
487 ));
488 }
489
490 pub fn encrypt_data(&mut self, data: &[u8]) -> Result<Bytes> {
492 let idx = self.current_key_id.0 as usize;
493 if let Some(channel) = &mut self.data_channels[idx] {
494 let packet = channel.encrypt(data)?;
495 Ok(packet.serialize().freeze())
496 } else {
497 Err(ProtocolError::KeyNotAvailable(self.current_key_id.0))
498 }
499 }
500
501 pub fn get_retransmits(&mut self) -> Vec<Bytes> {
503 self.reliable
504 .get_retransmits()
505 .into_iter()
506 .map(|(id, data)| {
507 let packet = crate::packet::ControlPacketData {
509 header: crate::PacketHeader {
510 opcode: OpCode::ControlV1,
511 key_id: self.current_key_id,
512 session_id: Some(self.local_session_id),
513 hmac: None,
514 packet_id: None,
515 timestamp: None,
516 },
517 remote_session_id: self.remote_session_id,
518 acks: vec![],
519 message_packet_id: Some(id),
520 payload: data,
521 };
522 let serialized = Packet::Control(packet).serialize();
523 self.maybe_wrap_tls_auth(serialized.freeze())
524 })
525 .collect()
526 }
527
528 pub fn should_send_ack(&self) -> bool {
530 self.reliable.should_send_ack()
531 }
532
533 pub fn next_timeout(&self) -> Option<Duration> {
535 self.reliable.next_timeout()
536 }
537
538 pub fn is_established(&self) -> bool {
540 self.state == ProtocolState::Established
541 }
542
543 pub fn duration(&self) -> Duration {
545 self.created_at.elapsed()
546 }
547
548 pub fn idle_time(&self) -> Duration {
550 self.last_activity.elapsed()
551 }
552
553 fn maybe_wrap_tls_auth(&mut self, data: Bytes) -> Bytes {
554 if self.use_tls_auth {
555 if let Some(key) = &self.tls_auth_key {
556 if data.len() < 9 {
568 return data;
569 }
570
571 let packet_id = self.tls_auth_packet_id;
573 self.tls_auth_packet_id += 1;
574
575 let timestamp = std::time::SystemTime::now()
577 .duration_since(std::time::UNIX_EPOCH)
578 .unwrap_or_default()
579 .as_secs() as u32;
580
581 let pid_bytes = packet_id.to_be_bytes();
582 let time_bytes = timestamp.to_be_bytes();
583
584 let mut hmac_input = Vec::with_capacity(8 + data.len());
587 hmac_input.extend_from_slice(&pid_bytes);
588 hmac_input.extend_from_slice(&time_bytes);
589 hmac_input.extend_from_slice(&data); let hmac = key.authenticate(&hmac_input);
592
593 let mut output = Vec::with_capacity(data.len() + 32 + 8);
595 output.extend_from_slice(&data[0..9]); output.extend_from_slice(&hmac); output.extend_from_slice(&pid_bytes); output.extend_from_slice(&time_bytes); output.extend_from_slice(&data[9..]); return Bytes::from(output);
602 }
603 }
604 data
605 }
606
607 pub fn rotate_key(&mut self) {
609 self.current_key_id = self.current_key_id.next();
610 self.replay_window.reset();
612 }
613}
614
615#[derive(Debug)]
617pub enum ProcessedPacket {
618 None,
620 HardReset {
622 session_id: SessionIdBytes,
624 },
625 HardResetAck,
627 TlsData(Vec<Bytes>),
629 Data(Bytes),
631 SoftReset,
633}
634
#[cfg(test)]
mod tests {
    use super::*;

    /// `HardResetClientV2` (opcode 7, key id 0) without tls-auth.
    const CLIENT_HARD_RESET: [u8; 14] = [
        0x38, // opcode 7 << 3 | key id 0
        0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // sender session id
        0x00, // ACK count
        0x00, 0x00, 0x00, 0x00, // message packet id
    ];

    #[test]
    fn test_session_creation() {
        let session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);
        assert_eq!(session.state(), ProtocolState::Initial);
        assert!(session.remote_session_id().is_none());
    }

    #[test]
    fn test_hard_reset() {
        let mut session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);

        let result = session.process_packet(&CLIENT_HARD_RESET).unwrap();
        // A bare `matches!` only yields a bool; the variant must actually be asserted.
        match result {
            ProcessedPacket::HardReset { session_id } => {
                assert_eq!(&session_id, session.local_session_id());
            }
            other => panic!("expected HardReset, got {:?}", other),
        }
        assert_eq!(session.state(), ProtocolState::TlsHandshake);
        assert_eq!(
            session.remote_session_id(),
            Some(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08])
        );
    }

    #[test]
    fn test_hard_reset_rejects_zero_session_id() {
        let mut session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);
        let mut packet = CLIENT_HARD_RESET;
        packet[1..9].fill(0);

        assert!(session.process_packet(&packet).is_err());
        assert_eq!(session.state(), ProtocolState::Initial);
        assert!(session.remote_session_id().is_none());
    }

    #[test]
    fn test_hard_reset_response_has_packet_id() {
        let mut session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);
        session.process_packet(&CLIENT_HARD_RESET).unwrap();

        let response = session.create_hard_reset_response().unwrap();

        // header (9) + ACK count (1) + ACK id (4) + remote session id (8)
        // + message packet id (4)
        assert!(response.len() >= 26, "Response too short: {} bytes", response.len());

        // Opcode 8 (HardResetServerV2) << 3 | key id 0.
        assert_eq!(response[0], 0x40);

        // Exactly one ACK ...
        assert_eq!(response[9], 1);

        // ... acknowledging the client's message packet id 0.
        let ack_id = u32::from_be_bytes(response[10..14].try_into().unwrap());
        assert_eq!(ack_id, 0);

        // The reset itself is sent reliably with message packet id 0.
        let msg_pid = u32::from_be_bytes(response[22..26].try_into().unwrap());
        assert_eq!(msg_pid, 0);
    }

    #[test]
    fn test_replay_window_rejects_zero_and_duplicates() {
        let mut window = ReplayWindow::new();
        assert!(!window.check_and_update(0));
        assert!(window.check_and_update(1));
        assert!(!window.check_and_update(1));
        assert!(window.check_and_update(5));
        // Out of order but still inside the window.
        assert!(window.check_and_update(3));
        assert!(!window.check_and_update(3));
    }

    #[test]
    fn test_replay_window_rejects_ids_outside_window() {
        let mut window = ReplayWindow::new();
        assert!(window.check_and_update(100));
        // 64 behind the newest id: outside the window.
        assert!(!window.check_and_update(36));
        // 63 behind: the oldest tracked slot.
        assert!(window.check_and_update(37));
        // A jump larger than the window clears history.
        assert!(window.check_and_update(1000));
        assert!(window.check_and_update(999));
        assert!(!window.check_and_update(1000));
    }
}