1use std::time::{Duration, Instant};
4
5use bytes::Bytes;
6
7use corevpn_crypto::{CipherSuite, KeyMaterial};
8
9use crate::{
10 KeyId, OpCode, Packet, DataPacket, DataChannel,
11 ReliableTransport, ReliableConfig, TlsRecordReassembler,
12 ProtocolError, Result,
13};
14use crate::packet::ControlPacketData;
15
/// Lifecycle state of a [`ProtocolSession`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtocolState {
    /// State of a freshly constructed session (both server and client).
    Initial,
    /// Entered when a client hard reset (V2 or V3) is processed.
    TlsHandshake,
    /// Key-exchange phase. NOTE(review): not set by any code in this file;
    /// presumably driven by the caller via `set_state`.
    KeyExchange,
    /// Authentication phase. NOTE(review): not set by any code in this file;
    /// presumably driven by the caller via `set_state`.
    Authenticating,
    /// The state `ProtocolSession::is_established` checks for.
    Established,
    /// Entered when a soft reset (`SoftResetV1`) is processed.
    Rekeying,
    /// Session shut down. NOTE(review): not set by any code in this file.
    Terminated,
}
34
/// An 8-byte session identifier, as carried in control-packet headers.
pub type SessionIdBytes = [u8; 8];
37
38pub struct ProtocolSession {
40 local_session_id: SessionIdBytes,
42 remote_session_id: Option<SessionIdBytes>,
44 state: ProtocolState,
46 current_key_id: KeyId,
48 reliable: ReliableTransport,
50 tls_reassembler: TlsRecordReassembler,
52 data_channels: [Option<DataChannel>; 8],
54 peer_id: Option<u32>,
56 use_tls_auth: bool,
58 tls_auth_key: Option<corevpn_crypto::HmacAuth>,
60 created_at: Instant,
62 last_activity: Instant,
64 cipher_suite: CipherSuite,
66}
67
68impl ProtocolSession {
69 pub fn new_server(cipher_suite: CipherSuite) -> Self {
71 Self {
72 local_session_id: corevpn_crypto::generate_session_id(),
73 remote_session_id: None,
74 state: ProtocolState::Initial,
75 current_key_id: KeyId::default(),
76 reliable: ReliableTransport::new(ReliableConfig::default()),
77 tls_reassembler: TlsRecordReassembler::new(65536),
78 data_channels: Default::default(),
79 peer_id: None,
80 use_tls_auth: false,
81 tls_auth_key: None,
82 created_at: Instant::now(),
83 last_activity: Instant::now(),
84 cipher_suite,
85 }
86 }
87
88 pub fn new_client(cipher_suite: CipherSuite) -> Self {
90 let mut session = Self::new_server(cipher_suite);
91 session.state = ProtocolState::Initial;
92 session
93 }
94
95 pub fn local_session_id(&self) -> &SessionIdBytes {
97 &self.local_session_id
98 }
99
100 pub fn remote_session_id(&self) -> Option<&SessionIdBytes> {
102 self.remote_session_id.as_ref()
103 }
104
105 pub fn state(&self) -> ProtocolState {
107 self.state
108 }
109
110 pub fn set_state(&mut self, state: ProtocolState) {
112 self.state = state;
113 self.last_activity = Instant::now();
114 }
115
116 pub fn set_remote_session_id(&mut self, id: SessionIdBytes) {
118 self.remote_session_id = Some(id);
119 }
120
121 pub fn set_tls_auth(&mut self, key: corevpn_crypto::HmacAuth) {
123 self.use_tls_auth = true;
124 self.tls_auth_key = Some(key);
125 }
126
127 pub fn process_packet(&mut self, data: &[u8]) -> Result<ProcessedPacket> {
129 self.last_activity = Instant::now();
130
131 let data = if self.use_tls_auth {
133 if let Some(key) = &self.tls_auth_key {
134 if !data.is_empty() && OpCode::from_byte(data[0])?.is_control() {
136 key.unwrap(data)?
137 } else {
138 data.to_vec()
139 }
140 } else {
141 data.to_vec()
142 }
143 } else {
144 data.to_vec()
145 };
146
147 let packet = Packet::parse(&data, false)?;
148
149 match packet {
150 Packet::Control(ctrl) => self.process_control_packet(ctrl),
151 Packet::Data(data_pkt) => self.process_data_packet(data_pkt),
152 }
153 }
154
155 fn process_control_packet(&mut self, ctrl: ControlPacketData) -> Result<ProcessedPacket> {
156 if !ctrl.acks.is_empty() {
158 self.reliable.process_acks(&ctrl.acks);
159 }
160
161 match ctrl.header.opcode {
163 OpCode::HardResetClientV2 | OpCode::HardResetClientV3 => {
164 if let Some(remote_sid) = ctrl.header.session_id {
166 self.remote_session_id = Some(remote_sid);
167 }
168 self.state = ProtocolState::TlsHandshake;
169
170 Ok(ProcessedPacket::HardReset {
171 session_id: ctrl.header.session_id.unwrap_or([0; 8]),
172 })
173 }
174 OpCode::HardResetServerV2 => {
175 if let Some(remote_sid) = ctrl.header.session_id {
177 self.remote_session_id = Some(remote_sid);
178 }
179 Ok(ProcessedPacket::HardResetAck)
180 }
181 OpCode::ControlV1 => {
182 if let Some(packet_id) = ctrl.message_packet_id {
184 if let Some(data) = self.reliable.receive(packet_id, ctrl.payload.clone()) {
185 self.tls_reassembler.add(&data)?;
186 let records = self.tls_reassembler.extract_records();
187 if !records.is_empty() {
188 return Ok(ProcessedPacket::TlsData(records));
189 }
190 }
191 }
192 Ok(ProcessedPacket::None)
193 }
194 OpCode::AckV1 => {
195 Ok(ProcessedPacket::None)
197 }
198 OpCode::SoftResetV1 => {
199 self.state = ProtocolState::Rekeying;
201 Ok(ProcessedPacket::SoftReset)
202 }
203 _ => Err(ProtocolError::UnknownOpcode(ctrl.header.opcode as u8)),
204 }
205 }
206
207 fn process_data_packet(&mut self, data_pkt: crate::packet::DataPacketData) -> Result<ProcessedPacket> {
208 let packet = DataPacket {
209 key_id: data_pkt.header.key_id,
210 peer_id: data_pkt.peer_id,
211 payload: data_pkt.payload,
212 };
213
214 let key_id = packet.key_id.0 as usize;
215 if let Some(channel) = &mut self.data_channels[key_id] {
216 let decrypted = channel.decrypt(&packet)?;
217 Ok(ProcessedPacket::Data(decrypted))
218 } else {
219 Err(ProtocolError::KeyNotAvailable(packet.key_id.0))
220 }
221 }
222
223 pub fn create_hard_reset_response(&mut self) -> Result<Bytes> {
225 let packet = crate::packet::ControlPacketData {
226 header: crate::PacketHeader {
227 opcode: OpCode::HardResetServerV2,
228 key_id: KeyId::default(),
229 session_id: Some(self.local_session_id),
230 hmac: None,
231 packet_id: None,
232 timestamp: None,
233 },
234 remote_session_id: self.remote_session_id,
235 acks: self.reliable.get_acks(),
236 message_packet_id: None,
237 payload: Bytes::new(),
238 };
239
240 let serialized = Packet::Control(packet).serialize();
241 Ok(self.maybe_wrap_tls_auth(serialized.freeze()))
242 }
243
244 pub fn create_control_packet(&mut self, tls_data: Bytes) -> Result<Bytes> {
246 let (packet_id, _) = self.reliable.send(tls_data.clone())?;
247
248 let packet = crate::packet::ControlPacketData {
249 header: crate::PacketHeader {
250 opcode: OpCode::ControlV1,
251 key_id: self.current_key_id,
252 session_id: Some(self.local_session_id),
253 hmac: None,
254 packet_id: None,
255 timestamp: None,
256 },
257 remote_session_id: self.remote_session_id,
258 acks: self.reliable.get_acks(),
259 message_packet_id: Some(packet_id),
260 payload: tls_data,
261 };
262
263 let serialized = Packet::Control(packet).serialize();
264 Ok(self.maybe_wrap_tls_auth(serialized.freeze()))
265 }
266
267 pub fn create_ack_packet(&mut self) -> Option<Bytes> {
269 let acks = self.reliable.get_acks();
270 if acks.is_empty() {
271 return None;
272 }
273
274 let packet = crate::packet::ControlPacketData {
275 header: crate::PacketHeader {
276 opcode: OpCode::AckV1,
277 key_id: self.current_key_id,
278 session_id: Some(self.local_session_id),
279 hmac: None,
280 packet_id: None,
281 timestamp: None,
282 },
283 remote_session_id: self.remote_session_id,
284 acks,
285 message_packet_id: None,
286 payload: Bytes::new(),
287 };
288
289 self.reliable.ack_sent();
290 let serialized = Packet::Control(packet).serialize();
291 Some(self.maybe_wrap_tls_auth(serialized.freeze()))
292 }
293
294 pub fn install_keys(&mut self, key_material: &KeyMaterial, is_server: bool) {
296 let key_id = self.current_key_id;
297 let idx = key_id.0 as usize;
298
299 let (encrypt_key, decrypt_key) = if is_server {
300 (
301 key_material.server_data_key(self.cipher_suite),
302 key_material.client_data_key(self.cipher_suite),
303 )
304 } else {
305 (
306 key_material.client_data_key(self.cipher_suite),
307 key_material.server_data_key(self.cipher_suite),
308 )
309 };
310
311 self.data_channels[idx] = Some(DataChannel::new(
312 key_id,
313 encrypt_key,
314 decrypt_key,
315 true,
316 self.peer_id,
317 ));
318 }
319
320 pub fn encrypt_data(&mut self, data: &[u8]) -> Result<Bytes> {
322 let idx = self.current_key_id.0 as usize;
323 if let Some(channel) = &mut self.data_channels[idx] {
324 let packet = channel.encrypt(data)?;
325 Ok(packet.serialize().freeze())
326 } else {
327 Err(ProtocolError::KeyNotAvailable(self.current_key_id.0))
328 }
329 }
330
331 pub fn get_retransmits(&mut self) -> Vec<Bytes> {
333 self.reliable
334 .get_retransmits()
335 .into_iter()
336 .map(|(id, data)| {
337 let packet = crate::packet::ControlPacketData {
339 header: crate::PacketHeader {
340 opcode: OpCode::ControlV1,
341 key_id: self.current_key_id,
342 session_id: Some(self.local_session_id),
343 hmac: None,
344 packet_id: None,
345 timestamp: None,
346 },
347 remote_session_id: self.remote_session_id,
348 acks: vec![],
349 message_packet_id: Some(id),
350 payload: data,
351 };
352 let serialized = Packet::Control(packet).serialize();
353 self.maybe_wrap_tls_auth(serialized.freeze())
354 })
355 .collect()
356 }
357
358 pub fn should_send_ack(&self) -> bool {
360 self.reliable.should_send_ack()
361 }
362
363 pub fn next_timeout(&self) -> Option<Duration> {
365 self.reliable.next_timeout()
366 }
367
368 pub fn is_established(&self) -> bool {
370 self.state == ProtocolState::Established
371 }
372
373 pub fn duration(&self) -> Duration {
375 self.created_at.elapsed()
376 }
377
378 pub fn idle_time(&self) -> Duration {
380 self.last_activity.elapsed()
381 }
382
383 fn maybe_wrap_tls_auth(&self, data: Bytes) -> Bytes {
384 if self.use_tls_auth {
385 if let Some(key) = &self.tls_auth_key {
386 return Bytes::from(key.wrap(&data));
387 }
388 }
389 data
390 }
391
392 pub fn rotate_key(&mut self) {
394 self.current_key_id = self.current_key_id.next();
395 }
396}
397
398#[derive(Debug)]
400pub enum ProcessedPacket {
401 None,
403 HardReset {
405 session_id: SessionIdBytes,
407 },
408 HardResetAck,
410 TlsData(Vec<Bytes>),
412 Data(Bytes),
414 SoftReset,
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_creation() {
        let session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);
        assert_eq!(session.state(), ProtocolState::Initial);
        assert!(session.remote_session_id().is_none());
    }

    #[test]
    fn test_client_creation() {
        let session = ProtocolSession::new_client(CipherSuite::ChaCha20Poly1305);
        assert_eq!(session.state(), ProtocolState::Initial);
        assert!(!session.is_established());
    }

    #[test]
    fn test_hard_reset() {
        let mut session = ProtocolSession::new_server(CipherSuite::ChaCha20Poly1305);

        // HardResetClientV2 with key ID 0 (opcode 7 << 3 = 0x38), an 8-byte
        // session ID, and an empty ack array.
        let hard_reset = [
            0x38, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x00,
        ];
        let expected_sid: SessionIdBytes = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];

        let result = session.process_packet(&hard_reset).unwrap();
        // `matches!` on its own only yields a bool; it must be asserted.
        assert!(
            matches!(result, ProcessedPacket::HardReset { session_id } if session_id == expected_sid),
            "unexpected result: {:?}",
            result
        );
        assert_eq!(session.state(), ProtocolState::TlsHandshake);
        assert_eq!(session.remote_session_id(), Some(&expected_sid));
    }
}