1use std::time::{Duration, Instant};
4
5use bytes::Bytes;
6
7use corevpn_crypto::{CipherSuite, KeyMaterial};
8
9use crate::{
10 KeyId, OpCode, Packet, DataPacket, DataChannel,
11 ReliableTransport, ReliableConfig, TlsRecordReassembler,
12 ProtocolError, Result,
13};
14use crate::packet::ControlPacketData;
15
/// Lifecycle stage of a protocol session.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum ProtocolState {
    /// State of a newly created session.
    Initial,
    /// Entered when a client hard reset is processed; the TLS handshake runs
    /// over the control channel.
    TlsHandshake,
    /// Key-exchange stage (not entered by the session code itself; presumably
    /// set by callers via `set_state`).
    KeyExchange,
    /// Authentication stage (presumably set by callers via `set_state`).
    Authenticating,
    /// Session is fully set up; this is the state `is_established` checks for.
    Established,
    /// Entered when a soft reset (`SoftResetV1`) is received.
    Rekeying,
    /// Session has ended (presumably set by callers via `set_state`).
    Terminated,
}
34
/// Raw 8-byte session identifier carried in control-packet headers.
pub type SessionIdBytes = [u8; 8];
37
/// Sliding-window replay filter for 32-bit packet IDs.
///
/// Remembers the highest ID accepted so far plus a 64-bit bitmap recording
/// which of the 64 most recent IDs (bit `n` = ID `highest - n`) were seen.
struct ReplayWindow {
    /// Highest packet ID accepted so far (0 means nothing accepted yet).
    highest: u32,
    /// Bit `n` set means packet ID `highest - n` has already been accepted.
    bitmap: u64,
}

impl ReplayWindow {
    /// Number of IDs, ending at `highest`, that the bitmap can remember.
    const WINDOW_SIZE: u32 = 64;

    /// Creates an empty window that has not accepted any packet yet.
    fn new() -> Self {
        ReplayWindow { highest: 0, bitmap: 0 }
    }

    /// Returns `true` and records `packet_id` when it is fresh.
    ///
    /// Returns `false` for ID 0, for IDs already accepted, and for IDs that
    /// fall `WINDOW_SIZE` or more behind the highest accepted ID.
    fn check_and_update(&mut self, packet_id: u32) -> bool {
        // ID 0 is never accepted.
        if packet_id == 0 {
            return false;
        }

        if packet_id <= self.highest {
            // At or behind the newest ID: consult the bitmap.
            let age = self.highest - packet_id;
            if age >= Self::WINDOW_SIZE {
                return false;
            }
            let bit = 1u64 << age;
            let already_seen = self.bitmap & bit != 0;
            if !already_seen {
                self.bitmap |= bit;
            }
            return !already_seen;
        }

        // Newer than anything seen: slide the window forward, mark the new head.
        let advance = packet_id - self.highest;
        self.bitmap = if advance < Self::WINDOW_SIZE {
            (self.bitmap << advance) | 1
        } else {
            1
        };
        self.highest = packet_id;
        true
    }

    /// Forgets every accepted ID, returning the window to its initial state.
    fn reset(&mut self) {
        *self = Self::new();
    }
}
109
/// Per-peer protocol session.
///
/// Holds control-channel reliability and TLS record reassembly state,
/// optional tls-auth wrapping with replay protection, and up to eight data
/// channels indexed by key ID.
pub struct ProtocolSession {
    /// Session ID generated for this endpoint; placed in outgoing control-packet headers.
    local_session_id: SessionIdBytes,
    /// Peer's session ID, learned from hard-reset packets or set explicitly.
    remote_session_id: Option<SessionIdBytes>,
    /// Current lifecycle stage.
    state: ProtocolState,
    /// Key ID used for outgoing packets; advanced by `rotate_key`.
    current_key_id: KeyId,
    /// Reliability layer for control-channel messages (acks and retransmits).
    reliable: ReliableTransport,
    /// Collects control-channel payloads into complete TLS records.
    tls_reassembler: TlsRecordReassembler,
    /// Data-channel crypto state, one slot per key ID.
    data_channels: [Option<DataChannel>; 8],
    /// Peer ID handed to data channels created by `install_keys`
    /// (never set in this module — NOTE(review): confirm where it is assigned).
    peer_id: Option<u32>,
    /// Whether control packets are wrapped/unwrapped with `tls_auth_key`.
    use_tls_auth: bool,
    /// HMAC key used for tls-auth when enabled.
    tls_auth_key: Option<corevpn_crypto::HmacAuth>,
    /// Replay filter for tls-auth control-packet IDs.
    replay_window: ReplayWindow,
    /// When the session was created.
    created_at: Instant,
    /// Last time a packet was processed or the state changed.
    last_activity: Instant,
    /// Cipher suite used to derive data-channel keys.
    cipher_suite: CipherSuite,
}
141
142impl ProtocolSession {
143 pub fn new_server(cipher_suite: CipherSuite) -> Self {
145 Self {
146 local_session_id: corevpn_crypto::generate_session_id(),
147 remote_session_id: None,
148 state: ProtocolState::Initial,
149 current_key_id: KeyId::default(),
150 reliable: ReliableTransport::new(ReliableConfig::default()),
151 tls_reassembler: TlsRecordReassembler::new(65536),
152 data_channels: Default::default(),
153 peer_id: None,
154 use_tls_auth: false,
155 tls_auth_key: None,
156 replay_window: ReplayWindow::new(),
157 created_at: Instant::now(),
158 last_activity: Instant::now(),
159 cipher_suite,
160 }
161 }
162
163 pub fn new_client(cipher_suite: CipherSuite) -> Self {
165 let mut session = Self::new_server(cipher_suite);
166 session.state = ProtocolState::Initial;
167 session
168 }
169
170 pub fn local_session_id(&self) -> &SessionIdBytes {
172 &self.local_session_id
173 }
174
175 pub fn remote_session_id(&self) -> Option<&SessionIdBytes> {
177 self.remote_session_id.as_ref()
178 }
179
180 pub fn state(&self) -> ProtocolState {
182 self.state
183 }
184
185 pub fn set_state(&mut self, state: ProtocolState) {
187 self.state = state;
188 self.last_activity = Instant::now();
189 }
190
191 pub fn set_remote_session_id(&mut self, id: SessionIdBytes) {
193 self.remote_session_id = Some(id);
194 }
195
196 pub fn set_tls_auth(&mut self, key: corevpn_crypto::HmacAuth) {
198 self.use_tls_auth = true;
199 self.tls_auth_key = Some(key);
200 }
201
202 pub fn process_packet(&mut self, data: &[u8]) -> Result<ProcessedPacket> {
204 self.last_activity = Instant::now();
205
206 let data = if self.use_tls_auth {
208 if let Some(key) = &self.tls_auth_key {
209 if !data.is_empty() && OpCode::from_byte(data[0])?.is_control() {
211 key.unwrap(data)?
212 } else {
213 data.to_vec()
214 }
215 } else {
216 data.to_vec()
217 }
218 } else {
219 data.to_vec()
220 };
221
222 let packet = Packet::parse(&data, self.use_tls_auth)?;
223
224 match packet {
225 Packet::Control(ctrl) => self.process_control_packet(ctrl),
226 Packet::Data(data_pkt) => self.process_data_packet(data_pkt),
227 }
228 }
229
230 fn process_control_packet(&mut self, ctrl: ControlPacketData) -> Result<ProcessedPacket> {
231 if self.use_tls_auth {
233 if let Some(packet_id) = ctrl.header.packet_id {
234 if !self.replay_window.check_and_update(packet_id) {
235 return Err(ProtocolError::ReplayDetected);
236 }
237 }
238 }
239
240 if !ctrl.acks.is_empty() {
242 self.reliable.process_acks(&ctrl.acks);
243 }
244
245 match ctrl.header.opcode {
247 OpCode::HardResetClientV2 | OpCode::HardResetClientV3 => {
248 if let Some(remote_sid) = ctrl.header.session_id {
252 if remote_sid == [0; 8] {
254 return Err(ProtocolError::InvalidSessionId);
255 }
256 self.remote_session_id = Some(remote_sid);
258 }
259 self.state = ProtocolState::TlsHandshake;
260
261 Ok(ProcessedPacket::HardReset {
262 session_id: self.local_session_id,
263 })
264 }
265 OpCode::HardResetServerV2 => {
266 if let Some(remote_sid) = ctrl.header.session_id {
268 self.remote_session_id = Some(remote_sid);
269 }
270 Ok(ProcessedPacket::HardResetAck)
271 }
272 OpCode::ControlV1 => {
273 if let Some(packet_id) = ctrl.message_packet_id {
275 if let Some(data) = self.reliable.receive(packet_id, ctrl.payload.clone())? {
276 self.tls_reassembler.add(&data)?;
277 let records = self.tls_reassembler.extract_records();
278 if !records.is_empty() {
279 return Ok(ProcessedPacket::TlsData(records));
280 }
281 }
282 }
283 Ok(ProcessedPacket::None)
284 }
285 OpCode::AckV1 => {
286 Ok(ProcessedPacket::None)
288 }
289 OpCode::SoftResetV1 => {
290 self.state = ProtocolState::Rekeying;
292 Ok(ProcessedPacket::SoftReset)
293 }
294 _ => Err(ProtocolError::UnknownOpcode(ctrl.header.opcode as u8)),
295 }
296 }
297
298 fn process_data_packet(&mut self, data_pkt: crate::packet::DataPacketData) -> Result<ProcessedPacket> {
299 let packet = DataPacket {
300 key_id: data_pkt.header.key_id,
301 peer_id: data_pkt.peer_id,
302 payload: data_pkt.payload,
303 };
304
305 let key_id = packet.key_id.0 as usize;
306 if let Some(channel) = &mut self.data_channels[key_id] {
307 let decrypted = channel.decrypt(&packet)?;
308 Ok(ProcessedPacket::Data(decrypted))
309 } else {
310 Err(ProtocolError::KeyNotAvailable(packet.key_id.0))
311 }
312 }
313
314 pub fn create_hard_reset_response(&mut self) -> Result<Bytes> {
316 let packet = crate::packet::ControlPacketData {
317 header: crate::PacketHeader {
318 opcode: OpCode::HardResetServerV2,
319 key_id: KeyId::default(),
320 session_id: Some(self.local_session_id),
321 hmac: None,
322 packet_id: None,
323 timestamp: None,
324 },
325 remote_session_id: self.remote_session_id,
326 acks: self.reliable.get_acks(),
327 message_packet_id: None,
328 payload: Bytes::new(),
329 };
330
331 let serialized = Packet::Control(packet).serialize();
332 Ok(self.maybe_wrap_tls_auth(serialized.freeze()))
333 }
334
335 pub fn create_control_packet(&mut self, tls_data: Bytes) -> Result<Bytes> {
337 let (packet_id, _) = self.reliable.send(tls_data.clone())?;
338
339 let packet = crate::packet::ControlPacketData {
340 header: crate::PacketHeader {
341 opcode: OpCode::ControlV1,
342 key_id: self.current_key_id,
343 session_id: Some(self.local_session_id),
344 hmac: None,
345 packet_id: None,
346 timestamp: None,
347 },
348 remote_session_id: self.remote_session_id,
349 acks: self.reliable.get_acks(),
350 message_packet_id: Some(packet_id),
351 payload: tls_data,
352 };
353
354 let serialized = Packet::Control(packet).serialize();
355 Ok(self.maybe_wrap_tls_auth(serialized.freeze()))
356 }
357
358 pub fn create_ack_packet(&mut self) -> Option<Bytes> {
360 let acks = self.reliable.get_acks();
361 if acks.is_empty() {
362 return None;
363 }
364
365 let packet = crate::packet::ControlPacketData {
366 header: crate::PacketHeader {
367 opcode: OpCode::AckV1,
368 key_id: self.current_key_id,
369 session_id: Some(self.local_session_id),
370 hmac: None,
371 packet_id: None,
372 timestamp: None,
373 },
374 remote_session_id: self.remote_session_id,
375 acks,
376 message_packet_id: None,
377 payload: Bytes::new(),
378 };
379
380 self.reliable.ack_sent();
381 let serialized = Packet::Control(packet).serialize();
382 Some(self.maybe_wrap_tls_auth(serialized.freeze()))
383 }
384
385 pub fn install_keys(&mut self, key_material: &KeyMaterial, is_server: bool) {
387 let key_id = self.current_key_id;
388 let idx = key_id.0 as usize;
389
390 let (encrypt_key, decrypt_key) = if is_server {
391 (
392 key_material.server_data_key(self.cipher_suite),
393 key_material.client_data_key(self.cipher_suite),
394 )
395 } else {
396 (
397 key_material.client_data_key(self.cipher_suite),
398 key_material.server_data_key(self.cipher_suite),
399 )
400 };
401
402 self.data_channels[idx] = Some(DataChannel::new(
403 key_id,
404 encrypt_key,
405 decrypt_key,
406 true,
407 self.peer_id,
408 ));
409 }
410
411 pub fn encrypt_data(&mut self, data: &[u8]) -> Result<Bytes> {
413 let idx = self.current_key_id.0 as usize;
414 if let Some(channel) = &mut self.data_channels[idx] {
415 let packet = channel.encrypt(data)?;
416 Ok(packet.serialize().freeze())
417 } else {
418 Err(ProtocolError::KeyNotAvailable(self.current_key_id.0))
419 }
420 }
421
422 pub fn get_retransmits(&mut self) -> Vec<Bytes> {
424 self.reliable
425 .get_retransmits()
426 .into_iter()
427 .map(|(id, data)| {
428 let packet = crate::packet::ControlPacketData {
430 header: crate::PacketHeader {
431 opcode: OpCode::ControlV1,
432 key_id: self.current_key_id,
433 session_id: Some(self.local_session_id),
434 hmac: None,
435 packet_id: None,
436 timestamp: None,
437 },
438 remote_session_id: self.remote_session_id,
439 acks: vec![],
440 message_packet_id: Some(id),
441 payload: data,
442 };
443 let serialized = Packet::Control(packet).serialize();
444 self.maybe_wrap_tls_auth(serialized.freeze())
445 })
446 .collect()
447 }
448
449 pub fn should_send_ack(&self) -> bool {
451 self.reliable.should_send_ack()
452 }
453
454 pub fn next_timeout(&self) -> Option<Duration> {
456 self.reliable.next_timeout()
457 }
458
459 pub fn is_established(&self) -> bool {
461 self.state == ProtocolState::Established
462 }
463
464 pub fn duration(&self) -> Duration {
466 self.created_at.elapsed()
467 }
468
469 pub fn idle_time(&self) -> Duration {
471 self.last_activity.elapsed()
472 }
473
474 fn maybe_wrap_tls_auth(&self, data: Bytes) -> Bytes {
475 if self.use_tls_auth {
476 if let Some(key) = &self.tls_auth_key {
477 return Bytes::from(key.wrap(&data));
478 }
479 }
480 data
481 }
482
483 pub fn rotate_key(&mut self) {
485 self.current_key_id = self.current_key_id.next();
486 self.replay_window.reset();
488 }
489}
490
/// Outcome of processing one incoming packet.
#[derive(Debug)]
pub enum ProcessedPacket {
    /// Nothing for the caller to act on (e.g. a pure ACK, or control data
    /// that has not yet completed a TLS record).
    None,
    /// A client hard reset was accepted.
    HardReset {
        /// This endpoint's local session ID.
        session_id: SessionIdBytes,
    },
    /// A server hard reset (`HardResetServerV2`) was received.
    HardResetAck,
    /// Complete TLS records reassembled from control-channel payloads.
    TlsData(Vec<Bytes>),
    /// Decrypted data-channel payload.
    Data(Bytes),
    /// A soft reset (`SoftResetV1`) was received; the session is now `Rekeying`.
    SoftReset,
}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal client hard-reset packet without tls-auth: opcode byte
    /// 0x38, the 8-byte client session ID, then (presumably) an empty ACK array.
    fn client_hard_reset(session_id: [u8; 8]) -> [u8; 10] {
        let mut packet = [0u8; 10];
        packet[0] = 0x38;
        packet[1..9].copy_from_slice(&session_id);
        packet
    }

    #[test]
    fn test_session_creation() {
        let session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);
        assert_eq!(session.state(), ProtocolState::Initial);
        assert!(session.remote_session_id().is_none());
    }

    #[test]
    fn test_hard_reset() {
        let mut session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);
        let client_sid = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];

        let result = session.process_packet(&client_hard_reset(client_sid)).unwrap();

        // A bare `matches!` only yields a bool; the variant must be asserted.
        match result {
            ProcessedPacket::HardReset { session_id } => {
                assert_eq!(&session_id, session.local_session_id());
            }
            other => panic!("expected HardReset, got {:?}", other),
        }
        assert_eq!(session.state(), ProtocolState::TlsHandshake);
        assert_eq!(session.remote_session_id(), Some(&client_sid));
    }

    #[test]
    fn test_hard_reset_rejects_zero_session_id() {
        let mut session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);

        let result = session.process_packet(&client_hard_reset([0; 8]));

        assert!(matches!(result, Err(ProtocolError::InvalidSessionId)));
        assert_eq!(session.state(), ProtocolState::Initial);
        assert!(session.remote_session_id().is_none());
    }

    #[test]
    fn test_encrypt_without_keys_fails() {
        let mut session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);
        assert!(matches!(
            session.encrypt_data(b"payload"),
            Err(ProtocolError::KeyNotAvailable(_))
        ));
    }

    #[test]
    fn test_replay_window() {
        let mut window = ReplayWindow::new();
        assert!(!window.check_and_update(0), "packet id 0 is never accepted");
        assert!(window.check_and_update(1));
        assert!(!window.check_and_update(1), "duplicate must be rejected");
        assert!(window.check_and_update(3));
        assert!(window.check_and_update(2), "out-of-order but unseen is accepted");
        assert!(window.check_and_update(100));
        assert!(!window.check_and_update(36), "64 behind the head is too old");
        assert!(window.check_and_update(37), "63 behind the head is still tracked");
        assert!(!window.check_and_update(37));
    }
}
537}