1use dashmap::DashMap;
12use std::net::{IpAddr, Ipv4Addr, SocketAddr};
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::UdpSocket;
17use tokio::sync::mpsc;
18use tracing::{debug, error, info, warn};
19
20use aivpn_common::crypto::{self, decrypt_payload, encrypt_payload, NONCE_SIZE, TAG_SIZE};
21use aivpn_common::error::{Error, Result};
22use aivpn_common::mask::{
23 current_unix_secs, derive_bootstrap_candidates, BootstrapDescriptor, MaskProfile,
24};
25use aivpn_common::network_config::VpnNetworkConfig;
26use aivpn_common::protocol::{
27 ControlPayload, ControlSubtype, InnerHeader, InnerType, MAX_PACKET_SIZE,
28};
29
30use crate::client_db::ClientDatabase;
31use crate::mask_gen::generate_and_store_mask;
32use crate::mask_store::MaskStore;
33use crate::metrics::MetricsCollector;
34use crate::nat::NatForwarder;
35use crate::neural::{NeuralConfig, NeuralResonanceModule, ResonanceStatus};
36use crate::recording::RecordingManager;
37use crate::recording::{RecordingStopOutcome, RecordingStopReason};
38use crate::session::{Session, SessionManager, MAX_SESSIONS};
39
/// A received UDP datagram together with its sender, as handed from the
/// receive loop to a receive worker over an mpsc channel.
struct QueuedPacket {
    /// Owned copy of the datagram bytes read from the socket.
    packet_data: Vec<u8>,
    /// Source address reported by `recv_from` for this datagram.
    client_addr: SocketAddr,
}
44
/// Settings consumed by [`Gateway::new`] and [`Gateway::run`].
#[derive(Clone)]
pub struct GatewayConfig {
    /// UDP listen address; parsed as a `SocketAddr` when the gateway starts.
    pub listen_addr: String,
    /// Maximum datagrams accepted per source IP within a one-second window;
    /// datagrams above the limit are dropped by the receive loop.
    pub per_ip_pps_limit: u64,
    /// TUN device name passed to `NatForwarder::new`.
    pub tun_name: String,
    /// TUN device address passed to `NatForwarder::new`.
    pub tun_addr: String,
    /// TUN device netmask passed to `NatForwarder::new`.
    pub tun_netmask: String,
    /// VPN network settings; passed to `NatForwarder::new`, and its
    /// `server_vpn_ip` is the address the TUN read loop answers pings for.
    pub network_config: VpnNetworkConfig,
    /// Server private key. An all-zero value is treated as "unset" and makes
    /// `Gateway::new` generate a key pair instead.
    pub server_private_key: [u8; 32],
    /// NOTE(review): not read in the visible part of this file — the gateway
    /// derives its Ed25519 signing key from `server_private_key`. Confirm
    /// whether this field is still used elsewhere.
    pub signing_key: [u8; 64],
    /// When true, `run` creates the TUN/NAT forwarder and its read/write loops.
    pub enable_nat: bool,
    /// When true, masks are registered with the neural module and the
    /// resonance check loop is spawned.
    pub enable_neural: bool,
    /// Configuration handed to `NeuralResonanceModule::new`; its
    /// `check_interval_secs` sets the resonance check period.
    pub neural_config: NeuralConfig,
    /// Optional client database, shared with the gateway's background tasks.
    pub client_db: Option<Arc<ClientDatabase>>,
    /// Directory handed to `MaskStore::new`.
    pub mask_dir: std::path::PathBuf,
    /// Forwarded to `SessionManager::with_timeouts` (first timeout argument).
    pub session_timeout_secs: Option<u64>,
    /// Forwarded to `SessionManager::with_timeouts` (second timeout argument).
    pub idle_timeout_secs: Option<u64>,
    /// Masks embedded into bootstrap descriptors; when empty, the descriptors
    /// list the ids of the preset masks instead.
    pub bootstrap_masks: Vec<MaskProfile>,
}
72
73impl Default for GatewayConfig {
74 fn default() -> Self {
75 Self {
76 listen_addr: "0.0.0.0:443".to_string(),
77 per_ip_pps_limit: 1000,
78 tun_name: "aivpn0".to_string(),
79 tun_addr: "10.0.0.1".to_string(),
80 tun_netmask: "255.255.255.0".to_string(),
81 network_config: VpnNetworkConfig::default(),
82 server_private_key: [0u8; 32],
83 signing_key: [0u8; 64],
84 enable_nat: true,
85 enable_neural: true,
86 neural_config: NeuralConfig::default(),
87 client_db: None,
88 mask_dir: std::path::PathBuf::from("/var/lib/aivpn/masks"),
89 session_timeout_secs: None,
90 idle_timeout_secs: None,
91 bootstrap_masks: Vec::new(),
92 }
93 }
94}
95
/// Concurrent registry of usable masks plus the ids that were marked
/// compromised.
pub struct MaskCatalog {
    /// Usable masks, keyed by `MaskProfile::mask_id`.
    masks: DashMap<String, MaskProfile>,
    /// Ids marked compromised, mapped to the instant they were marked;
    /// `register_mask` refuses ids present here.
    compromised: DashMap<String, Instant>,
    /// Id of the preferred mask; empty until `set_primary_mask_id` is called.
    primary_mask_id: parking_lot::Mutex<String>,
}
108
109impl MaskCatalog {
110 pub fn new() -> Self {
111 Self {
112 masks: DashMap::new(),
113 compromised: DashMap::new(),
114 primary_mask_id: parking_lot::Mutex::new(String::new()),
115 }
116 }
117
118 pub fn set_primary_mask_id(&self, mask_id: String) {
120 *self.primary_mask_id.lock() = mask_id;
121 }
122
123 pub fn register_mask(&self, mask: MaskProfile) {
125 if !self.compromised.contains_key(&mask.mask_id) {
126 self.masks.insert(mask.mask_id.clone(), mask);
127 }
128 }
129
130 pub fn mark_compromised(&self, mask_id: &str) {
132 self.compromised.insert(mask_id.to_string(), Instant::now());
133 self.masks.remove(mask_id);
134 }
135
136 pub fn remove_mask(&self, mask_id: &str) {
138 self.masks.remove(mask_id);
139 }
140
141 pub fn select_fallback(&self, current_mask_id: &str) -> Option<MaskProfile> {
143 self.masks
144 .iter()
145 .filter(|e| e.key() != current_mask_id)
146 .map(|e| e.value().clone())
147 .next()
148 }
149
150 pub fn available_count(&self) -> usize {
152 self.masks.len()
153 }
154
155 pub fn packet_layout(&self) -> (usize, usize, usize, usize) {
160 let fallback = (20usize, 52usize, 20usize, 32usize);
161 let Some(mask) = self.primary_mask() else {
162 return fallback;
163 };
164
165 packet_layout_for_mask(&mask)
166 }
167
168 pub fn packet_mdh_bytes(&self) -> Vec<u8> {
171 self.primary_mask()
172 .map(|mask| packet_mdh_bytes_for_mask(&mask))
173 .unwrap_or_else(|| vec![0u8; 20])
174 }
175
176 pub fn primary_mask(&self) -> Option<MaskProfile> {
177 let primary_id = self.primary_mask_id.lock().clone();
178 self.masks
179 .get(&primary_id)
180 .map(|entry| entry.value().clone())
181 .or_else(|| self.masks.iter().next().map(|entry| entry.value().clone()))
182 }
183}
184
185fn packet_layout_for_mask(mask: &MaskProfile) -> (usize, usize, usize, usize) {
186 let eph_offset = mask.eph_pub_offset as usize;
187 let eph_len = mask.eph_pub_length as usize;
188 let packet_mdh_len = mask
189 .header_spec
190 .as_ref()
191 .map(|spec| spec.min_length())
192 .unwrap_or_else(|| mask.header_template.len());
193 let handshake_mdh_len = packet_mdh_len.max(eph_offset.saturating_add(eph_len));
194 (packet_mdh_len, handshake_mdh_len, eph_offset, eph_len)
195}
196
197fn packet_mdh_bytes_for_mask(mask: &MaskProfile) -> Vec<u8> {
198 if let Some(ref spec) = mask.header_spec {
199 let mut rng = rand::thread_rng();
200 spec.generate(&mut rng)
201 } else {
202 mask.header_template.clone()
203 }
204}
205
206fn hash_addr(addr: &SocketAddr) -> String {
208 let hash = crypto::blake3_hash(addr.to_string().as_bytes());
209 format!(
210 "{:02x}{:02x}{:02x}{:02x}",
211 hash[0], hash[1], hash[2], hash[3]
212 )
213}
214
/// The gateway server: owns the configuration, session state and the shared
/// handles used by the packet-processing and background tasks.
pub struct Gateway {
    /// Configuration the gateway was built from.
    config: GatewayConfig,
    /// Session registry shared with background tasks.
    session_manager: Arc<SessionManager>,
    /// UDP listener; `None` until `run` binds it.
    udp_socket: Option<Arc<UdpSocket>>,
    /// TUN/NAT forwarder; set by `run` when `enable_nat` is true.
    nat_forwarder: Option<Arc<NatForwarder>>,
    /// Sender feeding the TUN write loop; set by `run` when the TUN reader
    /// could be taken.
    tun_write_tx: Option<mpsc::Sender<Vec<u8>>>,
    /// Per-source-IP `(packet count, window start)` used for rate limiting.
    rate_limits: Arc<DashMap<IpAddr, (u64, Instant)>>,
    /// Per-IP `(counter, instant)` pairs — presumably handshake back-off
    /// state; not used in the visible part of this file (TODO confirm).
    handshake_cooldowns: Arc<DashMap<IpAddr, (u32, Instant)>>,
    /// Per-IP async mutexes — presumably serialising handshakes from one
    /// address; not used in the visible part of this file (TODO confirm).
    handshake_locks: Arc<DashMap<IpAddr, Arc<tokio::sync::Mutex<()>>>>,
    /// Neural resonance module, shared with the resonance check loop and the
    /// session cleanup task.
    neural_module: Arc<parking_lot::Mutex<NeuralResonanceModule>>,
    /// Catalog of usable and compromised masks.
    mask_catalog: Arc<MaskCatalog>,
    /// Metrics sink.
    metrics: Arc<MetricsCollector>,
    /// Optional client database (copied from the configuration).
    client_db: Option<Arc<ClientDatabase>>,
    /// Recording manager; always `Some` when built through `Gateway::new`.
    recording_manager: Option<Arc<RecordingManager>>,
    /// Mask store backing the catalog; kept alive here even though it is not
    /// read through this field.
    #[allow(dead_code)]
    mask_store: Option<Arc<MaskStore>>,
    /// Signed bootstrap descriptors built at construction time and sent by
    /// `send_bootstrap_descriptors`.
    bootstrap_descriptors: Vec<BootstrapDescriptor>,
}
248
/// Length of one bootstrap epoch in seconds (24 hours).
const BOOTSTRAP_ROTATION_SECS: u64 = 24 * 60 * 60;
/// Value written into each descriptor's `candidate_count` field.
const BOOTSTRAP_DESCRIPTOR_CANDIDATES: u8 = 4;

/// Maps a Unix timestamp (seconds) to the index of the bootstrap epoch that
/// contains it.
fn bootstrap_epoch(unix_secs: u64) -> u64 {
    let epoch_len = BOOTSTRAP_ROTATION_SECS;
    unix_secs / epoch_len
}
255
256pub fn derive_server_signing_key(server_private_key: &[u8; 32]) -> ed25519_dalek::SigningKey {
257 let seed = blake3::derive_key("aivpn-ed25519-signing-v1", server_private_key);
258 ed25519_dalek::SigningKey::from_bytes(&seed)
259}
260
261fn sign_bootstrap_descriptor(
262 mut descriptor: BootstrapDescriptor,
263 signing_key: &ed25519_dalek::SigningKey,
264) -> BootstrapDescriptor {
265 use ed25519_dalek::Signer;
266 descriptor.signature = signing_key.sign(&descriptor.signing_bytes()).to_bytes();
267 descriptor
268}
269
270fn build_bootstrap_descriptor(
271 server_seed: &[u8; 32],
272 signing_key: &ed25519_dalek::SigningKey,
273 epoch: u64,
274 bootstrap_masks: &[MaskProfile],
275) -> BootstrapDescriptor {
276 let mut hasher = blake3::Hasher::new_keyed(server_seed);
277 hasher.update(&epoch.to_le_bytes());
278 let hash = hasher.finalize();
279 let mut kdf_salt = [0u8; 32];
280 kdf_salt.copy_from_slice(&hash.as_bytes()[..32]);
281 let created_at = epoch * BOOTSTRAP_ROTATION_SECS;
282 let expires_at = created_at + (2 * BOOTSTRAP_ROTATION_SECS);
283 let (base_mask_ids, embedded_masks) = if bootstrap_masks.is_empty() {
284 (
285 aivpn_common::mask::preset_masks::all()
286 .into_iter()
287 .map(|mask| mask.mask_id)
288 .collect(),
289 Vec::new(),
290 )
291 } else {
292 (Vec::new(), bootstrap_masks.to_vec())
293 };
294
295 sign_bootstrap_descriptor(
296 BootstrapDescriptor {
297 descriptor_id: format!("epoch-{}", epoch),
298 version: 1,
299 created_at,
300 expires_at,
301 base_mask_ids,
302 embedded_masks,
303 candidate_count: BOOTSTRAP_DESCRIPTOR_CANDIDATES,
304 kdf_salt,
305 signature: [0u8; 64],
306 },
307 signing_key,
308 )
309}
310
311pub fn build_bootstrap_descriptors(
312 server_seed: &[u8; 32],
313 signing_key: &ed25519_dalek::SigningKey,
314 bootstrap_masks: &[MaskProfile],
315) -> Vec<BootstrapDescriptor> {
316 let epoch = bootstrap_epoch(current_unix_secs());
317 [epoch.saturating_sub(1), epoch, epoch.saturating_add(1)]
318 .into_iter()
319 .map(|value| build_bootstrap_descriptor(server_seed, signing_key, value, bootstrap_masks))
320 .collect()
321}
322
323impl Gateway {
324 fn can_start_recording(&self, client_id: Option<&str>) -> bool {
325 let Some(client_id) = client_id else {
326 return false;
327 };
328
329 if client_id == "admin" {
330 return true;
331 }
332
333 self.client_db
334 .as_ref()
335 .and_then(|db| db.find_by_id(client_id))
336 .map(|client| client.name.starts_with("recording-admin"))
337 .unwrap_or(false)
338 }
339
    /// Reacts to the result of stopping a recording.
    ///
    /// * `Completed`: optionally acknowledges to `notify_session`, then
    ///   spawns a background task that generates and stores a mask and
    ///   reports success or failure to the recording's session, if that
    ///   session still exists when generation finishes.
    /// * `Incomplete`: optionally sends a `RecordingFailed` message to
    ///   `notify_session` and logs a warning.
    /// * `NotFound`: does nothing.
    ///
    /// `socket` and `mdh` are passed through to `send_control_message_via`;
    /// `store` is handed to `generate_and_store_mask`. Send failures are
    /// logged and otherwise ignored.
    async fn handle_recording_outcome(
        socket: &Arc<UdpSocket>,
        sessions: &Arc<SessionManager>,
        store: &Arc<MaskStore>,
        mdh: &[u8],
        outcome: RecordingStopOutcome,
        notify_session: Option<Arc<parking_lot::Mutex<Session>>>,
    ) {
        match outcome {
            RecordingStopOutcome::Completed(completed) => {
                // Tell the requesting client that analysis has started.
                if let Some(ref session) = notify_session {
                    let ack = ControlPayload::RecordingAck {
                        session_id: completed.session_id,
                        status: "analyzing".into(),
                    };
                    if let Err(e) =
                        Self::send_control_message_via(socket.as_ref(), mdh, &ack, session).await
                    {
                        warn!("Failed to send RecordingAck: {}", e);
                    }
                }

                info!(
                    "Recording stopped for '{}' ({} packets, {}s), analyzing...",
                    completed.service, completed.total_packets, completed.duration_secs
                );

                // Owned copies so the spawned task does not borrow from the caller.
                let socket = socket.clone();
                let sessions = sessions.clone();
                let store = store.clone();
                let mdh = mdh.to_vec();
                // Mask generation runs detached; this function returns without
                // waiting for it.
                tokio::spawn(async move {
                    match generate_and_store_mask(&completed.service, &completed.packets, &store)
                        .await
                    {
                        Ok(mask_id) => {
                            info!(
                                "✅ Mask generated: '{}' for service '{}' by {}",
                                mask_id, completed.service, completed.admin_key_id
                            );
                            // The session is looked up again here because it may
                            // have ended while the mask was being generated.
                            if let Some(target_session) =
                                sessions.get_session(&completed.session_id)
                            {
                                // Confidence defaults to 0.0 when the stored mask
                                // cannot be found.
                                let confidence = store
                                    .get_mask(&mask_id)
                                    .map(|entry| entry.stats.confidence)
                                    .unwrap_or(0.0);
                                let payload = ControlPayload::RecordingComplete {
                                    service: completed.service.clone(),
                                    mask_id,
                                    confidence,
                                };
                                if let Err(e) = Self::send_control_message_via(
                                    socket.as_ref(),
                                    &mdh,
                                    &payload,
                                    &target_session,
                                )
                                .await
                                {
                                    warn!("Failed to send RecordingComplete: {}", e);
                                }
                            }
                        }
                        Err(e) => {
                            warn!("Mask generation failed for '{}': {}", completed.service, e);
                            if let Some(target_session) =
                                sessions.get_session(&completed.session_id)
                            {
                                let payload = ControlPayload::RecordingFailed {
                                    reason: e.to_string(),
                                };
                                if let Err(send_err) = Self::send_control_message_via(
                                    socket.as_ref(),
                                    &mdh,
                                    &payload,
                                    &target_session,
                                )
                                .await
                                {
                                    warn!("Failed to send RecordingFailed: {}", send_err);
                                }
                            }
                        }
                    }
                });
            }
            RecordingStopOutcome::Incomplete(incomplete) => {
                // Human-readable reason sent to the client; any stop reason
                // other than the two named ones gets the generic text.
                let reason = match incomplete.reason {
                    RecordingStopReason::IdleTimeout => {
                        "Recording stopped after idle timeout before enough traffic was captured"
                    }
                    RecordingStopReason::SessionEnded => {
                        "Recording ended with the session before enough traffic was captured"
                    }
                    _ => "Too few packets or too short duration",
                };
                if let Some(ref session) = notify_session {
                    let payload = ControlPayload::RecordingFailed {
                        reason: reason.into(),
                    };
                    if let Err(e) =
                        Self::send_control_message_via(socket.as_ref(), mdh, &payload, session)
                            .await
                    {
                        warn!("Failed to send RecordingFailed: {}", e);
                    }
                }
                warn!(
                    "Recording for '{}' ended without mask generation: {} packets, {}s ({:?})",
                    incomplete.service,
                    incomplete.total_packets,
                    incomplete.duration_secs,
                    incomplete.reason
                );
            }
            // No recording existed for the session; nothing to report.
            RecordingStopOutcome::NotFound => {}
        }
    }
459
460 pub fn new(config: GatewayConfig) -> Result<Self> {
461 let server_keys = if config.server_private_key != [0u8; 32] {
463 crypto::KeyPair::from_private_key(config.server_private_key)
464 } else {
465 crypto::KeyPair::generate()
466 };
467
468 let signing_key = derive_server_signing_key(&config.server_private_key);
470 let bootstrap_descriptors = build_bootstrap_descriptors(
471 &config.server_private_key,
472 &signing_key,
473 &config.bootstrap_masks,
474 );
475
476 let mask_catalog = Arc::new(MaskCatalog::new());
478
479 let mask_store = Arc::new(MaskStore::new(
481 mask_catalog.clone(),
482 config.mask_dir.clone(),
483 ));
484
485 let primary_id = if let Some(first) = mask_catalog.masks.iter().next() {
488 let id = first.key().clone();
489 id
490 } else {
491 String::new()
492 };
493 if !primary_id.is_empty() {
494 info!("Primary mask set to '{}' (loaded from disk)", primary_id);
495 mask_catalog.set_primary_mask_id(primary_id);
496 } else {
497 warn!("No masks found in {:?} — server will not accept connections until masks are recorded", config.mask_dir);
498 }
499
500 let default_mask = mask_catalog.primary_mask().ok_or_else(|| {
502 Error::Session(format!(
503 "No masks found in {:?} — place mask JSON files there before starting the server",
504 config.mask_dir
505 ))
506 })?;
507
508 let session_manager = Arc::new(SessionManager::with_timeouts(
509 server_keys,
510 signing_key,
511 default_mask,
512 config.session_timeout_secs,
513 config.idle_timeout_secs,
514 ));
515
516 let mut neural = NeuralResonanceModule::new(config.neural_config.clone())
518 .map_err(|e| Error::Session(format!("Neural module init failed: {}", e)))?;
519
520 if config.enable_neural {
521 for entry in mask_catalog.masks.iter() {
523 let _ = neural.register_mask(entry.value());
524 }
525 let _ = neural.load_model();
527 info!("Neural Resonance Module initialized (Patent 1)");
528 }
529
530 let recording_manager = Arc::new(RecordingManager::new(mask_store.clone()));
531 info!(
532 "Auto Mask Recording system initialized ({} masks loaded from disk)",
533 mask_catalog.available_count()
534 );
535
536 Ok(Self {
537 config: config.clone(),
538 session_manager,
539 udp_socket: None,
540 nat_forwarder: None,
541 tun_write_tx: None,
542 rate_limits: Arc::new(DashMap::new()),
543 handshake_cooldowns: Arc::new(DashMap::new()),
544 handshake_locks: Arc::new(DashMap::new()),
545 neural_module: Arc::new(parking_lot::Mutex::new(neural)),
546 mask_catalog,
547 metrics: Arc::new(MetricsCollector::new()),
548 client_db: config.client_db,
549 recording_manager: Some(recording_manager),
550 mask_store: Some(mask_store),
551 bootstrap_descriptors,
552 })
553 }
554
555 async fn send_bootstrap_descriptors(
556 &self,
557 session: &Arc<parking_lot::Mutex<Session>>,
558 ) -> Result<()> {
559 for descriptor in &self.bootstrap_descriptors {
560 let payload = ControlPayload::BootstrapDescriptorUpdate {
561 descriptor_data: rmp_serde::to_vec(descriptor).map_err(|e| {
562 Error::Session(format!("Failed to serialize bootstrap descriptor: {}", e))
563 })?,
564 };
565 self.send_control_message(&payload, session).await?;
566 }
567 Ok(())
568 }
569
    /// Starts the gateway and processes packets until the receive loop fails.
    ///
    /// In order: optionally creates the TUN/NAT forwarder, binds the UDP
    /// socket, spawns the background tasks (resonance checks, TUN read/write
    /// loops, session cleanup and recording auto-finish, client statistics
    /// flush, client DB reload) and finally runs the concurrent
    /// packet-processing loop on the calling task.
    ///
    /// # Errors
    ///
    /// Returns an error when the NAT forwarder cannot be created, the listen
    /// address cannot be parsed, the socket cannot be set up, or packet
    /// processing ends with an error.
    pub async fn run(mut self) -> Result<()> {
        info!("Starting AIVPN Gateway on {}", self.config.listen_addr);
        info!(
            "Per-IP UDP rate limit: {} pps",
            self.config.per_ip_pps_limit
        );

        // TUN device / NAT forwarder (optional).
        if self.config.enable_nat {
            let mut nat = NatForwarder::new(
                &self.config.tun_name,
                &self.config.tun_addr,
                &self.config.tun_netmask,
                self.config.network_config,
            )?;
            nat.create()?;
            self.nat_forwarder = Some(Arc::new(nat));
            info!(
                "TUN device: {} ({}/{})",
                self.config.tun_name, self.config.tun_addr, self.config.tun_netmask
            );
        }

        // A parse failure is reported as an `InvalidInput` I/O error.
        let bind_addr: SocketAddr =
            self.config
                .listen_addr
                .parse()
                .map_err(|e: std::net::AddrParseError| {
                    Error::Io(std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        e.to_string(),
                    ))
                })?;

        // The socket is built through socket2 so buffer sizes can be set
        // before binding.
        let socket2_sock = socket2::Socket::new(
            if bind_addr.is_ipv4() {
                socket2::Domain::IPV4
            } else {
                socket2::Domain::IPV6
            },
            socket2::Type::DGRAM,
            Some(socket2::Protocol::UDP),
        )
        .map_err(Error::Io)?;

        socket2_sock.set_nonblocking(true).map_err(Error::Io)?;
        // Buffer sizing (4 MiB each) is best-effort; failures are ignored.
        let _ = socket2_sock.set_recv_buffer_size(4 * 1024 * 1024);
        let _ = socket2_sock.set_send_buffer_size(4 * 1024 * 1024);
        socket2_sock.bind(&bind_addr.into()).map_err(Error::Io)?;

        let std_sock: std::net::UdpSocket = socket2_sock.into();
        let socket = UdpSocket::from_std(std_sock).map_err(Error::Io)?;

        info!(
            "UDP listener bound to {} (4MB buffers via socket2)",
            self.config.listen_addr
        );

        self.udp_socket = Some(Arc::new(socket));

        // Background task: periodic resonance checks.
        if self.config.enable_neural {
            let neural = self.neural_module.clone();
            let sessions = self.session_manager.clone();
            let catalog = self.mask_catalog.clone();
            let metrics = self.metrics.clone();
            let check_interval = self.config.neural_config.check_interval_secs;
            let socket = self.udp_socket.as_ref().unwrap().clone();

            tokio::spawn(async move {
                Self::resonance_check_loop(
                    neural,
                    sessions,
                    catalog,
                    metrics,
                    check_interval,
                    socket,
                )
                .await;
            });
            info!(
                "Neural resonance check loop spawned (interval: {}s)",
                check_interval
            );
        }

        // Background tasks: TUN read loop and, when a writer is available,
        // the TUN write loop.
        if let Some(ref nat) = self.nat_forwarder {
            if let Some(tun_reader) = nat.take_reader().await {
                let sessions = self.session_manager.clone();
                let socket = self.udp_socket.as_ref().unwrap().clone();
                // Fallback mask for sessions without their own mask. Panics
                // if the catalog is empty at this point.
                let mask = self
                    .mask_catalog
                    .masks
                    .iter()
                    .next()
                    .map(|e| e.value().clone())
                    .expect("at least one mask must be loaded");
                let server_vpn_ip = self.config.network_config.server_vpn_ip;
                let recorder = self.recording_manager.clone();

                // Packets destined for the TUN device are funnelled through
                // this channel to a single writer task.
                let (tun_tx, tun_rx) = mpsc::channel::<Vec<u8>>(4096);
                self.tun_write_tx = Some(tun_tx.clone());

                if let Some(tun_writer) = nat.take_writer().await {
                    tokio::spawn(async move {
                        Self::tun_write_loop(tun_writer, tun_rx).await;
                    });
                    info!("TUN write loop spawned (channel-based, no Mutex)");
                } else {
                    // NOTE(review): `tun_rx` is dropped on this path, so sends
                    // on `tun_tx` will fail; confirm the fallback copes.
                    warn!("Could not take TUN writer — falling back to forward_packet");
                }

                let client_db = self.client_db.clone();
                tokio::spawn(async move {
                    Self::tun_read_loop(
                        tun_reader,
                        tun_tx,
                        sessions,
                        socket,
                        mask,
                        server_vpn_ip,
                        recorder,
                        client_db,
                    )
                    .await;
                });
                info!("TUN read loop spawned");
            }
        }

        // Background task: every 5 seconds, finish ready/stale recordings
        // and clean up expired sessions.
        {
            let sessions = self.session_manager.clone();
            let recorder = self.recording_manager.clone();
            let socket = self.udp_socket.as_ref().unwrap().clone();
            // Header bytes are computed once here and reused for every
            // control message this task sends.
            let mdh = self.mask_catalog.packet_mdh_bytes();
            let neural = self.neural_module.clone();
            tokio::spawn(async move {
                loop {
                    tokio::time::sleep(Duration::from_secs(5)).await;
                    if let Some(ref rec) = recorder {
                        let store = rec.store();
                        for outcome in rec.take_ready_or_stale(
                            aivpn_common::recording::RECORDING_IDLE_TIMEOUT_SECS,
                        ) {
                            // Notify the session that owns the recording, if
                            // it still exists.
                            let notify_session = match &outcome {
                                RecordingStopOutcome::Completed(completed) => {
                                    sessions.get_session(&completed.session_id)
                                }
                                RecordingStopOutcome::Incomplete(incomplete) => {
                                    sessions.get_session(&incomplete.session_id)
                                }
                                RecordingStopOutcome::NotFound => None,
                            };
                            Self::handle_recording_outcome(
                                &socket,
                                &sessions,
                                &store,
                                &mdh,
                                outcome,
                                notify_session,
                            )
                            .await;
                        }
                    }

                    // Drop expired sessions and their neural statistics.
                    let removed = sessions.cleanup_expired();
                    for session_id in &removed {
                        neural.lock().cleanup_stats(*session_id);
                    }
                    // Stop recordings of removed sessions; nobody is left to
                    // notify, hence `None`.
                    if let Some(ref rec) = recorder {
                        let store = rec.store();
                        for session_id in removed {
                            let outcome = rec.stop_for_session_end(session_id);
                            Self::handle_recording_outcome(
                                &socket, &sessions, &store, &mdh, outcome, None,
                            )
                            .await;
                        }
                    }
                }
            });
            info!("Session cleanup / recording auto-finish task spawned (5s interval)");
        }

        // Background task: flush client statistics every 300 seconds.
        if let Some(ref db) = self.client_db {
            let db = db.clone();
            tokio::spawn(async move {
                loop {
                    tokio::time::sleep(Duration::from_secs(300)).await;
                    db.flush_stats();
                }
            });
            info!("Client stats flush task spawned (300s interval)");
        }

        // Background task: reload the client DB every 10 seconds if it changed.
        if let Some(ref db) = self.client_db {
            let db = db.clone();
            tokio::spawn(async move {
                loop {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                    db.reload_if_changed();
                }
            });
            info!("Client DB hot-reload task spawned (10s interval)");
        }

        // Share the gateway with the receive workers and run the main loop.
        let gateway = Arc::new(self);
        Self::process_packets_concurrent(gateway).await?;

        Ok(())
    }
795
    /// Periodically checks every session's mask and rotates sessions away
    /// from masks judged compromised or anomalous.
    ///
    /// Each round sleeps `check_interval_secs`, snapshots
    /// `(session_id, mask_id)` for all sessions, evaluates them while holding
    /// the neural module lock, and only after releasing that lock sends the
    /// queued MaskUpdate packets. A session's mask is switched only when its
    /// packet was sent successfully. The loop never returns.
    async fn resonance_check_loop(
        neural: Arc<parking_lot::Mutex<NeuralResonanceModule>>,
        sessions: Arc<SessionManager>,
        catalog: Arc<MaskCatalog>,
        metrics: Arc<MetricsCollector>,
        check_interval_secs: u64,
        socket: Arc<UdpSocket>,
    ) {
        let interval = Duration::from_secs(check_interval_secs);

        loop {
            tokio::time::sleep(interval).await;

            // Snapshot ids first so session locks are not held during checks.
            // Sessions without a mask are recorded under the id "unknown".
            let session_checks: Vec<([u8; 16], String)> = sessions
                .iter_sessions()
                .filter_map(|entry| {
                    let sess = entry.value().lock();
                    let mask_id = sess
                        .mask
                        .as_ref()
                        .map(|m| m.mask_id.clone())
                        .unwrap_or_else(|| "unknown".to_string());
                    Some((sess.session_id, mask_id))
                })
                .collect();

            if session_checks.is_empty() {
                continue;
            }

            // (packet, destination, session id, new mask) tuples to send once
            // the neural lock has been released.
            let mut pending_sends: Vec<(Vec<u8>, std::net::SocketAddr, [u8; 16], MaskProfile)> =
                Vec::new();

            {
                // Scoped so the guard is dropped before the awaits below.
                let neural_guard = neural.lock();

                for (session_id, mask_id) in &session_checks {
                    match neural_guard.check_resonance(*session_id, mask_id) {
                        Ok(result) => {
                            metrics
                                .record_neural_check(result.status == ResonanceStatus::Compromised);

                            match result.status {
                                ResonanceStatus::Compromised => {
                                    warn!(
                                        "Mask '{}' compromised (MSE={:.4}) — triggering rotation (Patent 3)",
                                        mask_id, result.mse
                                    );

                                    // Removes the mask from the usable set for
                                    // every session, not just this one.
                                    catalog.mark_compromised(mask_id);

                                    if let Some(new_mask) = catalog.select_fallback(mask_id) {
                                        info!(
                                            "Auto-rotating to mask '{}' ({} masks remaining)",
                                            new_mask.mask_id,
                                            catalog.available_count()
                                        );

                                        if let Some(session) = sessions.get_session(session_id) {
                                            let client_addr = session.lock().client_addr;
                                            match sessions
                                                .build_mask_update_packet(&session, &new_mask)
                                            {
                                                Ok(packet) => {
                                                    pending_sends.push((
                                                        packet,
                                                        client_addr,
                                                        *session_id,
                                                        new_mask.clone(),
                                                    ));
                                                }
                                                Err(e) => {
                                                    warn!(
                                                        "Failed to build MaskUpdate packet: {}",
                                                        e
                                                    );
                                                }
                                            }
                                        }

                                        // Counted even when the session vanished
                                        // or the packet could not be built.
                                        metrics.record_mask_rotation();
                                    } else {
                                        error!(
                                            "No fallback masks available! All masks compromised."
                                        );
                                    }
                                }
                                ResonanceStatus::Warning => {
                                    debug!(
                                        "Mask '{}' warning (MSE={:.4}) — monitoring",
                                        mask_id, result.mse
                                    );
                                }
                                // Nothing to do for these statuses.
                                ResonanceStatus::Healthy => {
                                }
                                ResonanceStatus::Skip => {
                                }
                            }
                        }
                        Err(e) => {
                            debug!("Resonance check error for session: {}", e);
                        }
                    }

                    // Independent anomaly check. NOTE(review): when the check
                    // above already reported `Compromised`, this can queue a
                    // second MaskUpdate for the same session in the same
                    // round — confirm that is intended.
                    if neural_guard.is_mask_anomalous(mask_id) {
                        warn!(
                            "Anomaly detected for mask '{}' (packet loss / RTT spike)",
                            mask_id
                        );
                        metrics.record_dpi_attack();
                        catalog.mark_compromised(mask_id);

                        if let Some(new_mask) = catalog.select_fallback(mask_id) {
                            info!("Anomaly-triggered rotation to mask '{}'", new_mask.mask_id);
                            if let Some(session) = sessions.get_session(session_id) {
                                let client_addr = session.lock().client_addr;
                                // Build failures are silently skipped here.
                                if let Ok(packet) =
                                    sessions.build_mask_update_packet(&session, &new_mask)
                                {
                                    pending_sends.push((
                                        packet,
                                        client_addr,
                                        *session_id,
                                        new_mask.clone(),
                                    ));
                                }
                            }
                            metrics.record_mask_rotation();
                        }
                    }
                }
            }

            // Send queued updates; a session's mask is switched only after its
            // update was sent successfully.
            for (packet, client_addr, session_id, new_mask) in pending_sends {
                if let Err(e) = socket.send_to(&packet, client_addr).await {
                    warn!("Failed to send MaskUpdate to {}: {}", client_addr, e);
                } else {
                    sessions.update_session_mask(&session_id, new_mask);
                    info!("MaskUpdate control message sent to {}", client_addr);
                }
            }
        }
    }
955
    /// Reads IPv4 packets from the TUN device and forwards each one,
    /// encrypted, to the client whose session owns the destination VPN IP.
    ///
    /// * Packets shorter than 20 bytes or not IPv4 are ignored.
    /// * ICMP packets addressed to `server_vpn_ip` are answered locally by
    ///   queueing an echo reply on `tun_writer`.
    /// * `mask` supplies the header template only for sessions that have no
    ///   mask of their own.
    /// * `recorder`, when present and recording the session, receives
    ///   metadata for every packet that was sent successfully.
    /// * `client_db`, when present, receives outbound byte counts in batches.
    ///
    /// The loop never returns; read errors are logged and retried after 10 ms.
    async fn tun_read_loop(
        mut tun_reader: tun::DeviceReader,
        tun_writer: tokio::sync::mpsc::Sender<Vec<u8>>,
        sessions: Arc<SessionManager>,
        socket: Arc<UdpSocket>,
        mask: MaskProfile,
        server_vpn_ip: Ipv4Addr,
        recorder: Option<Arc<RecordingManager>>,
        client_db: Option<Arc<ClientDatabase>>,
    ) {
        let mut buf = vec![0u8; MAX_PACKET_SIZE];
        let server_ip = server_vpn_ip;

        loop {
            match tun_reader.read(&mut buf).await {
                Ok(0) => continue,
                Ok(n) => {
                    let packet = &buf[..n];

                    // Require a full 20-byte header with IP version 4.
                    if packet.len() < 20 || (packet[0] >> 4) != 4 {
                        continue;
                    }
                    // Destination address is bytes 16..20 of the IPv4 header.
                    let dst_ip = Ipv4Addr::new(packet[16], packet[17], packet[18], packet[19]);

                    // Protocol byte 1 is ICMP: answer pings to the server's
                    // VPN address locally; other ICMP to it is dropped.
                    if dst_ip == server_ip && packet.len() >= 28 && packet[9] == 1 {
                        if let Some(reply) = Self::build_icmp_echo_reply(packet, &server_ip) {
                            let _ = tun_writer.send(reply).await;
                        }
                        continue;
                    }

                    let session = match sessions.get_session_by_vpn_ip(&dst_ip) {
                        Some(s) => s,
                        None => {
                            debug!("TUN: no session for VPN IP {}", dst_ip);
                            continue;
                        }
                    };

                    let (session_id, client_addr, downlink_iat_ms, tag, mdh, ciphertext) = {
                        // Everything needing the session lock is gathered here;
                        // the lock is released before encrypting and sending.
                        let mut sess = session.lock();
                        sess.commit_pending_mask();
                        let session_id = sess.session_id;
                        let client_addr = sess.client_addr;
                        // NOTE(review): `as u16` truncates if `next_seq`
                        // returns a wider integer.
                        let seq_num = sess.next_seq() as u16;
                        let (nonce, counter) = sess.next_send_nonce();
                        let key = sess.keys.session_key.clone();
                        let tag_secret = sess.keys.tag_secret;
                        // Milliseconds since the previous downlink packet.
                        let downlink_iat_ms =
                            sess.last_server_send.elapsed().as_secs_f64() * 1000.0;
                        sess.last_server_send = Instant::now();
                        // Prefer the session's own mask; fall back to the
                        // template of the mask passed to this loop.
                        let session_mdh = sess
                            .mask
                            .as_ref()
                            .map(packet_mdh_bytes_for_mask)
                            .unwrap_or_else(|| mask.header_template.clone());
                        // Outbound bytes are estimated as payload + 64 and
                        // accumulated; flushed once at least 64 KiB is pending.
                        let estimated_out = (n + 64) as u64;
                        sess.pending_bytes_out =
                            sess.pending_bytes_out.saturating_add(estimated_out);
                        let flush_out = if sess.pending_bytes_out >= 64 * 1024 {
                            let bytes = sess.pending_bytes_out;
                            let cid = sess.client_id.clone();
                            sess.pending_bytes_out = 0;
                            cid.map(|c| (c, bytes))
                        } else {
                            None
                        };
                        // Release the session lock before touching the database.
                        drop(sess);
                        if let (Some(ref db), Some((cid, bytes))) = (&client_db, flush_out) {
                            db.record_traffic(&cid, 0, bytes);
                        }

                        // Inner payload: inner header followed by the IP packet.
                        let inner_header = InnerHeader {
                            inner_type: InnerType::Data,
                            seq_num,
                        };
                        let mut inner_payload = inner_header.encode().to_vec();
                        inner_payload.extend_from_slice(packet);

                        let mdh = session_mdh;

                        // Plaintext starts with a little-endian u16 padding
                        // length, always 0 here.
                        let pad_len: u16 = 0;
                        let mut padded = Vec::with_capacity(2 + inner_payload.len());
                        padded.extend_from_slice(&pad_len.to_le_bytes());
                        padded.extend_from_slice(&inner_payload);

                        // NOTE(review): on failure the sequence number and
                        // nonce taken above are already consumed.
                        let ciphertext = match encrypt_payload(&key, &nonce, &padded) {
                            Ok(ct) => ct,
                            Err(e) => {
                                debug!("TUN: encrypt error: {}", e);
                                continue;
                            }
                        };

                        // The tag is bound to the send counter and the current
                        // time window.
                        let time_window = crypto::compute_time_window(
                            crypto::current_timestamp_ms(),
                            aivpn_common::crypto::DEFAULT_WINDOW_MS,
                        );
                        let tag = crypto::generate_resonance_tag(&tag_secret, counter, time_window);

                        (
                            session_id,
                            client_addr,
                            downlink_iat_ms,
                            tag,
                            mdh,
                            ciphertext,
                        )
                    };

                    // Wire format: tag || mask header || ciphertext.
                    let mut aivpn_packet =
                        Vec::with_capacity(TAG_SIZE + mdh.len() + ciphertext.len());
                    aivpn_packet.extend_from_slice(&tag);
                    aivpn_packet.extend_from_slice(&mdh);
                    aivpn_packet.extend_from_slice(&ciphertext);

                    if let Err(e) = socket.send_to(&aivpn_packet, client_addr).await {
                        debug!("TUN: send failed: {}", e);
                    } else {
                        // Only successfully sent packets are recorded.
                        if let Some(ref recorder) = recorder {
                            if recorder.is_recording(&session_id) {
                                let meta = aivpn_common::recording::PacketMetadata {
                                    direction: aivpn_common::recording::Direction::Downlink,
                                    size: aivpn_packet.len() as u16,
                                    iat_ms: downlink_iat_ms,
                                    entropy: Self::compute_entropy(&ciphertext) as f32,
                                    // Up to 16 bytes following the tag.
                                    header_prefix: aivpn_packet[TAG_SIZE
                                        ..TAG_SIZE + 16.min(aivpn_packet.len() - TAG_SIZE)]
                                        .to_vec(),
                                    timestamp_ns: std::time::SystemTime::now()
                                        .duration_since(std::time::UNIX_EPOCH)
                                        .unwrap_or_default()
                                        .as_nanos()
                                        as u64,
                                };
                                recorder.record_packet(session_id, meta);
                            }
                        }
                    }
                }
                Err(e) => {
                    error!("TUN read error: {}", e);
                    // Brief pause so a persistent error does not spin the loop.
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
            }
        }
    }
1125
1126 fn build_icmp_echo_reply(request: &[u8], server_ip: &Ipv4Addr) -> Option<Vec<u8>> {
1128 if request.len() < 28 {
1129 return None;
1130 }
1131
1132 let src_ip = Ipv4Addr::new(request[12], request[13], request[14], request[15]);
1134
1135 let icmp_type = request[20];
1137 if icmp_type != 8 {
1138 return None; }
1140
1141 let mut reply = Vec::with_capacity(request.len());
1143
1144 reply.push(0x45); reply.push(0x00); let total_len = (request.len() as u16).to_be_bytes();
1148 reply.extend_from_slice(&total_len);
1149 reply.extend_from_slice(&request[4..6]); reply.extend_from_slice(&request[6..8]); reply.push(64); reply.push(1); reply.push(0); reply.push(0);
1155 reply.extend_from_slice(&server_ip.octets()); reply.extend_from_slice(&src_ip.octets()); reply.push(0); reply.push(request[21]); reply.push(0); reply.push(0);
1163 reply.extend_from_slice(&request[24..28]); reply.extend_from_slice(&request[28..]); let checksum = Self::compute_checksum(&reply[20..]);
1168 reply[22] = (checksum >> 8) as u8;
1169 reply[23] = (checksum & 0xFF) as u8;
1170
1171 Some(reply)
1172 }
1173
1174 fn compute_checksum(data: &[u8]) -> u16 {
1176 let mut sum: u32 = 0;
1177 let mut i = 0;
1178
1179 while i + 1 < data.len() {
1181 sum += u16::from_be_bytes([data[i], data[i + 1]]) as u32;
1182 i += 2;
1183 }
1184
1185 if i < data.len() {
1187 sum += (data[i] as u32) << 8;
1188 }
1189
1190 while (sum >> 16) != 0 {
1192 sum = (sum & 0xFFFF) + (sum >> 16);
1193 }
1194
1195 !sum as u16
1196 }
1197
1198 async fn tun_write_loop(mut writer: tun::DeviceWriter, mut rx: mpsc::Receiver<Vec<u8>>) {
1200 while let Some(packet) = rx.recv().await {
1201 if let Err(e) = writer.write_all(&packet).await {
1202 error!("TUN write error: {}", e);
1203 }
1204 }
1206 warn!("TUN write loop ended — channel closed");
1207 }
1208
1209 fn receive_worker_count() -> usize {
1210 std::thread::available_parallelism()
1211 .map(|count| count.get())
1212 .unwrap_or(4)
1213 .clamp(2, 16)
1214 }
1215
1216 fn worker_index_for_packet(
1217 &self,
1218 packet_data: &[u8],
1219 client_addr: SocketAddr,
1220 worker_count: usize,
1221 ) -> usize {
1222 if worker_count <= 1 {
1223 return 0;
1224 }
1225
1226 let mut shard_addr = client_addr;
1227
1228 if packet_data.len() >= TAG_SIZE {
1229 let mut tag = [0u8; TAG_SIZE];
1230 tag.copy_from_slice(&packet_data[..TAG_SIZE]);
1231
1232 if let Some(session) = self.session_manager.get_session_by_tag(&tag) {
1233 shard_addr = session.lock().client_addr;
1234 }
1235 }
1236
1237 let key = match shard_addr.ip() {
1238 IpAddr::V4(ip) => ((u32::from(ip) as u64) << 16) | shard_addr.port() as u64,
1239 IpAddr::V6(ip) => {
1240 let octets = ip.octets();
1241 u64::from_le_bytes(octets[..8].try_into().unwrap()) ^ shard_addr.port() as u64
1242 }
1243 };
1244
1245 (key as usize) % worker_count
1246 }
1247
1248 async fn process_packets_concurrent(gateway: Arc<Self>) -> Result<()> {
1252 let socket = gateway.udp_socket.as_ref().unwrap().clone();
1253 let mut buf = vec![0u8; MAX_PACKET_SIZE];
1254 let worker_count = Self::receive_worker_count();
1255 let queue_depth = 4096;
1256 let mut worker_txs = Vec::with_capacity(worker_count);
1257
1258 for worker_id in 0..worker_count {
1259 let (tx, mut rx) = mpsc::channel::<QueuedPacket>(queue_depth);
1260 worker_txs.push(tx);
1261
1262 let gw = gateway.clone();
1263 tokio::spawn(async move {
1264 while let Some(packet) = rx.recv().await {
1265 if let Err(e) = gw
1266 .handle_packet(&packet.packet_data, packet.client_addr)
1267 .await
1268 {
1269 debug!(
1270 "Worker {} packet error from {}: {}",
1271 worker_id,
1272 hash_addr(&packet.client_addr),
1273 e
1274 );
1275 }
1276 }
1277 warn!("Receive worker {} ended — channel closed", worker_id);
1278 });
1279 }
1280
1281 loop {
1282 match socket.recv_from(&mut buf).await {
1283 Ok((len, client_addr)) => {
1284 {
1286 let now = Instant::now();
1287 let mut entry = gateway
1288 .rate_limits
1289 .entry(client_addr.ip())
1290 .or_insert((0, now));
1291 if entry.1.elapsed() > Duration::from_secs(1) {
1292 entry.0 = 0;
1293 entry.1 = now;
1294 }
1295 entry.0 += 1;
1296 if entry.0 > gateway.config.per_ip_pps_limit {
1297 continue;
1298 }
1299 }
1300
1301 let packet_data = buf[..len].to_vec();
1302 let worker_idx =
1303 gateway.worker_index_for_packet(&packet_data, client_addr, worker_count);
1304 let packet = QueuedPacket {
1305 packet_data,
1306 client_addr,
1307 };
1308
1309 if worker_txs[worker_idx].send(packet).await.is_err() {
1310 return Err(Error::Channel(format!(
1311 "Receive worker {worker_idx} channel closed"
1312 )));
1313 }
1314 }
1315 Err(e) => {
1316 error!("UDP recv error: {}", e);
1317 tokio::time::sleep(Duration::from_millis(10)).await;
1318 }
1319 }
1320 }
1321 }
1322
1323 #[allow(dead_code)]
1325 async fn process_packets(&self) -> Result<()> {
1326 let socket = self.udp_socket.as_ref().unwrap();
1327 let mut buf = vec![0u8; MAX_PACKET_SIZE];
1328
1329 loop {
1330 match socket.recv_from(&mut buf).await {
1331 Ok((len, client_addr)) => {
1332 {
1334 let now = Instant::now();
1335 let mut entry =
1336 self.rate_limits.entry(client_addr.ip()).or_insert((0, now));
1337 if entry.1.elapsed() > Duration::from_secs(1) {
1338 entry.0 = 0;
1339 entry.1 = now;
1340 }
1341 entry.0 += 1;
1342 if entry.0 > self.config.per_ip_pps_limit {
1343 continue;
1344 }
1345 }
1346
1347 let packet_data = &buf[..len];
1348
1349 if let Err(e) = self.handle_packet(packet_data, client_addr).await {
1351 debug!("Packet error from {}: {}", hash_addr(&client_addr), e);
1352 }
1354 }
1355 Err(e) => {
1356 error!("UDP recv error: {}", e);
1357 tokio::time::sleep(Duration::from_millis(10)).await;
1358 }
1359 }
1360 }
1361 }
1362
1363 async fn handle_packet(&self, packet_data: &[u8], client_addr: SocketAddr) -> Result<()> {
1365 if packet_data.len() < TAG_SIZE + 2 {
1367 return Err(Error::InvalidPacket("Too short"));
1368 }
1369
1370 let mut tag = [0u8; TAG_SIZE];
1372 tag.copy_from_slice(&packet_data[0..TAG_SIZE]);
1373
1374 let (catalog_mdh_len, catalog_hs_mdh_len, _eph_offset, _eph_len) =
1376 self.mask_catalog.packet_layout();
1377 let mut is_new_session = false;
1378 let (session, counter, is_ratcheted_tag) = if let Some(session) =
1379 self.session_manager.get_session_by_tag(&tag)
1380 {
1381 let (counter, is_ratcheted) = {
1383 let sess = session.lock();
1384 sess.validate_tag(&tag)
1385 .ok_or(Error::InvalidPacket("Invalid tag"))?
1386 };
1387 (session, counter, is_ratcheted)
1388 } else if let Some((session, counter, is_ratcheted)) =
1389 self.session_manager.refresh_and_find_by_tag(&tag)
1390 {
1391 debug!(
1393 "Tag matched after refresh (counter={}, ratcheted={})",
1394 counter, is_ratcheted
1395 );
1396 (session, counter, is_ratcheted)
1397 } else if let Some((session, counter, is_ratcheted)) = self
1398 .session_manager
1399 .recover_session_by_tag(&tag, &client_addr.ip())
1400 {
1401 (session, counter, is_ratcheted)
1403 } else {
1404 if self
1411 .session_manager
1412 .has_recent_ratcheted_session_on_other_endpoint(
1413 &client_addr,
1414 Duration::from_secs(30),
1415 )
1416 {
1417 return Err(Error::InvalidPacket("Active session exists on other port"));
1418 }
1419
1420 let _handshake_guard = {
1431 let lock = {
1432 let entry = self
1433 .handshake_locks
1434 .entry(client_addr.ip())
1435 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())));
1436 entry.value().clone()
1437 };
1438 match lock.try_lock_owned() {
1439 Ok(guard) => guard,
1440 Err(_) => return Ok(()),
1441 }
1442 };
1443
1444 if self.session_manager.session_count() + 10 >= MAX_SESSIONS {
1451 debug!("Session pool near capacity ({}/{}), dropping unauthenticated handshake from {}",
1452 self.session_manager.session_count(), MAX_SESSIONS, hash_addr(&client_addr));
1453 return Ok(());
1454 }
1455
1456 {
1461 let ip = client_addr.ip();
1462 if let Some(entry) = self.handshake_cooldowns.get(&ip) {
1463 let (fail_count, last_fail) = *entry;
1464 let cooldown = Duration::from_millis((2000 * (1 << fail_count.min(3))) as u64);
1466 if last_fail.elapsed() < cooldown {
1467 debug!("Handshake cooldown active for {}: fail_count={}, elapsed={:?}, cooldown={:?}",
1468 hash_addr(&client_addr), fail_count, last_fail.elapsed(), cooldown);
1469 return Err(Error::InvalidPacket("Handshake cooldown active"));
1470 }
1471 }
1472 }
1473
1474 let builtin_bootstrap_masks = aivpn_common::mask::preset_masks::all();
1481 let (session, matched_client_id, bootstrap_mask) = if let Some(ref db) = self.client_db
1482 {
1483 let clients = db.list_clients();
1484 let mut found = None;
1485 'bootstrap: for client_cfg in &clients {
1486 if !client_cfg.enabled {
1487 continue;
1488 }
1489
1490 let psk = client_cfg.psk;
1491 let candidate_masks = self
1492 .bootstrap_descriptors
1493 .iter()
1494 .flat_map(|descriptor| derive_bootstrap_candidates(descriptor, Some(&psk)))
1495 .chain(builtin_bootstrap_masks.clone().into_iter())
1496 .collect::<Vec<_>>();
1497
1498 for bootstrap_mask in candidate_masks {
1499 let (
1500 _,
1501 candidate_handshake_mdh_len,
1502 candidate_eph_offset,
1503 candidate_eph_len,
1504 ) = packet_layout_for_mask(&bootstrap_mask);
1505 if packet_data.len() < TAG_SIZE + candidate_handshake_mdh_len {
1506 continue;
1507 }
1508 let eph_start = TAG_SIZE + candidate_eph_offset;
1509 if packet_data.len() < eph_start + candidate_eph_len {
1510 continue;
1511 }
1512
1513 let mut eph_pub = [0u8; 32];
1514 eph_pub.copy_from_slice(
1515 &packet_data[eph_start..eph_start + candidate_eph_len],
1516 );
1517 crypto::obfuscate_eph_pub(
1518 &mut eph_pub,
1519 &self.session_manager.server_public_key(),
1520 );
1521
1522 match self.session_manager.create_session(
1523 client_addr,
1524 eph_pub,
1525 Some(psk),
1526 Some(client_cfg.vpn_ip),
1527 ) {
1528 Ok(sess) => {
1529 let validation = sess.lock().validate_tag(&tag);
1530 if validation.is_some() {
1531 debug!(
1532 "Tag validation SUCCESS for client {} via bootstrap mask {}",
1533 client_cfg.id,
1534 bootstrap_mask.mask_id
1535 );
1536 found =
1537 Some((sess, Some(client_cfg.id.clone()), bootstrap_mask));
1538 break 'bootstrap;
1539 }
1540 let sid = sess.lock().session_id;
1541 self.session_manager.rollback_failed_session(&sid);
1542 }
1543 Err(e) => {
1544 debug!("create_session failed: {}", e);
1545 continue;
1546 }
1547 }
1548 }
1549 }
1550 match found {
1551 Some(f) => f,
1552 None => {
1553 let ip = client_addr.ip();
1555 let fail_count =
1556 self.handshake_cooldowns.get(&ip).map(|e| e.0).unwrap_or(0);
1557 self.handshake_cooldowns
1558 .insert(ip, (fail_count + 1, Instant::now()));
1559 warn!(
1560 "Handshake failed for {} (attempt #{}) — tag mismatch for all {} registered clients",
1561 hash_addr(&client_addr),
1562 fail_count + 1,
1563 clients.len()
1564 );
1565 return Err(Error::InvalidPacket(
1566 "No registered client matches this handshake",
1567 ));
1568 }
1569 }
1570 } else {
1571 let mut found = None;
1573 let candidate_masks = self
1574 .bootstrap_descriptors
1575 .iter()
1576 .flat_map(|descriptor| derive_bootstrap_candidates(descriptor, None))
1577 .chain(builtin_bootstrap_masks.clone().into_iter())
1578 .collect::<Vec<_>>();
1579 for bootstrap_mask in candidate_masks {
1580 let (_, candidate_handshake_mdh_len, candidate_eph_offset, candidate_eph_len) =
1581 packet_layout_for_mask(&bootstrap_mask);
1582 if packet_data.len() < TAG_SIZE + candidate_handshake_mdh_len {
1583 continue;
1584 }
1585 let eph_start = TAG_SIZE + candidate_eph_offset;
1586 if packet_data.len() < eph_start + candidate_eph_len {
1587 continue;
1588 }
1589
1590 let mut eph_pub = [0u8; 32];
1591 eph_pub.copy_from_slice(&packet_data[eph_start..eph_start + candidate_eph_len]);
1592 crypto::obfuscate_eph_pub(
1593 &mut eph_pub,
1594 &self.session_manager.server_public_key(),
1595 );
1596
1597 let sess =
1598 self.session_manager
1599 .create_session(client_addr, eph_pub, None, None)?;
1600 let validation = sess.lock().validate_tag(&tag);
1601 if validation.is_some() {
1602 found = Some((sess, None, bootstrap_mask));
1603 break;
1604 }
1605 let sid = sess.lock().session_id;
1606 self.session_manager.rollback_failed_session(&sid);
1607 }
1608
1609 found.ok_or_else(|| {
1610 Error::InvalidPacket("No bootstrap mask matched this handshake")
1611 })?
1612 };
1613
1614 let validation = {
1616 let sess = session.lock();
1617 sess.validate_tag(&tag)
1618 };
1619 let (counter, is_ratcheted) = match validation {
1620 Some(result) => result,
1621 None => {
1622 let session_id = session.lock().session_id;
1623 self.session_manager.rollback_failed_session(&session_id);
1624 return Err(Error::InvalidPacket("Tag mismatch on new session"));
1625 }
1626 };
1627
1628 {
1633 let (session_id, vpn_ip) = {
1634 let sess_lock = session.lock();
1635 (sess_lock.session_id, sess_lock.vpn_ip)
1636 };
1637 if let Some(vpn_ip) = vpn_ip {
1638 let removed = self
1639 .session_manager
1640 .cleanup_old_sessions_for_vpn_ip(&vpn_ip, &session_id);
1641 if let Some(ref recorder) = self.recording_manager {
1643 let socket = self.udp_socket.as_ref().unwrap().clone();
1644 let store = recorder.store();
1645 let mdh = self.mask_catalog.packet_mdh_bytes();
1646 for sid in removed {
1647 let outcome = recorder.stop_for_session_end(sid);
1648 Self::handle_recording_outcome(
1649 &socket,
1650 &self.session_manager,
1651 &store,
1652 &mdh,
1653 outcome,
1654 None,
1655 )
1656 .await;
1657 }
1658 }
1659 }
1660 }
1661
1662 self.handshake_cooldowns.remove(&client_addr.ip());
1664
1665 {
1666 let mut sess = session.lock();
1667 sess.mask = Some(bootstrap_mask.clone());
1668 }
1669
1670 if let (Some(ref db), Some(ref cid)) = (&self.client_db, &matched_client_id) {
1672 db.record_handshake(cid);
1673 session.lock().client_id = Some(cid.clone());
1675 debug!("Client '{}' authenticated via PSK", cid);
1676 }
1677
1678 self.send_server_hello(&session, client_addr).await?;
1679 self.send_bootstrap_descriptors(&session).await?;
1680
1681 if let Some(runtime_mask) = self.mask_catalog.primary_mask() {
1682 if runtime_mask.mask_id != bootstrap_mask.mask_id {
1683 match self
1684 .session_manager
1685 .build_mask_update_packet(&session, &runtime_mask)
1686 {
1687 Ok(packet) => {
1688 self.udp_socket
1689 .as_ref()
1690 .unwrap()
1691 .send_to(&packet, client_addr)
1692 .await?;
1693 }
1699 Err(e) => {
1700 warn!("Failed to send initial runtime MaskUpdate: {}", e);
1701 }
1702 }
1703 }
1704 }
1705
1706 is_new_session = true;
1710 debug!(
1711 "New session from {} (ServerHello sent)",
1712 hash_addr(&client_addr)
1713 );
1714 (session, counter, is_ratcheted)
1715 };
1716
1717 let (session_mdh_len, session_hs_mdh_len) = {
1724 let sess = session.lock();
1725 if let Some(ref mask) = sess.mask {
1726 let (p, h, _, _) = packet_layout_for_mask(mask);
1727 (p, h)
1728 } else {
1729 (catalog_mdh_len, catalog_hs_mdh_len)
1730 }
1731 };
1732 let packet_mdh_len = session_mdh_len;
1733 let handshake_mdh_len = session_hs_mdh_len;
1734 let is_pre_ratchet_retry = !is_new_session && !is_ratcheted_tag && {
1739 let sess = session.lock();
1740 !sess.is_ratcheted && packet_data.len() >= TAG_SIZE + handshake_mdh_len + 16
1741 };
1742 let mut payload_offsets: Vec<usize> = if is_new_session {
1743 vec![TAG_SIZE + handshake_mdh_len]
1744 } else if is_pre_ratchet_retry && handshake_mdh_len != packet_mdh_len {
1745 vec![TAG_SIZE + packet_mdh_len, TAG_SIZE + handshake_mdh_len]
1746 } else {
1747 vec![TAG_SIZE + packet_mdh_len]
1748 };
1749 if catalog_mdh_len != session_mdh_len {
1752 let catalog_offset = TAG_SIZE + catalog_mdh_len;
1753 if !payload_offsets.contains(&catalog_offset) {
1754 payload_offsets.push(catalog_offset);
1755 }
1756 }
1757
1758 let (payload_offset, padded_plaintext) = {
1759 let sess = session.lock();
1760 let nonce = self.compute_nonce(counter);
1761 let key = if is_new_session {
1766 &sess.keys.session_key
1767 } else if is_ratcheted_tag {
1768 &sess
1769 .ratcheted_keys
1770 .as_ref()
1771 .ok_or(Error::InvalidPacket("Ratcheted keys missing"))?
1772 .session_key
1773 } else {
1774 &sess.keys.session_key
1775 };
1776
1777 let mut decrypted = None;
1778 let mut last_error = None;
1779 for payload_offset in payload_offsets {
1780 if packet_data.len() <= payload_offset {
1781 continue;
1782 }
1783 let encrypted_payload = &packet_data[payload_offset..];
1784 match decrypt_payload(key, &nonce, encrypted_payload) {
1785 Ok(padded_plaintext) => {
1786 decrypted = Some((payload_offset, padded_plaintext));
1787 break;
1788 }
1789 Err(err) => last_error = Some(err),
1790 }
1791 }
1792
1793 match decrypted {
1794 Some(result) => result,
1795 None => {
1796 return Err(last_error.unwrap_or_else(|| Error::InvalidPacket("Invalid length")))
1797 }
1798 }
1799 };
1800 let encrypted_payload = &packet_data[payload_offset..];
1801
1802 if is_ratcheted_tag {
1807 let session_id = session.lock().session_id;
1808 self.session_manager.complete_session_ratchet(&session_id);
1809 self.session_manager.refresh_session_tags(&session_id);
1810 let sess = session.lock();
1811 info!(
1812 "PFS ratchet complete for {} — send_counter={}, counter={}",
1813 hash_addr(&client_addr),
1814 sess.send_counter,
1815 sess.counter
1816 );
1817 }
1818
1819 if padded_plaintext.len() < 2 {
1821 return Err(Error::InvalidPacket("Decrypted payload too short"));
1822 }
1823 let pad_len = u16::from_le_bytes([padded_plaintext[0], padded_plaintext[1]]) as usize;
1824 if 2 + pad_len > padded_plaintext.len() {
1825 return Err(Error::InvalidPacket("Invalid padding length"));
1826 }
1827 let plaintext = &padded_plaintext[2..padded_plaintext.len() - pad_len];
1828
1829 let mut client_db_flush: Option<(String, u64, u64)> = None;
1831 let (session_id, refresh_tags) = {
1832 let mut sess = session.lock();
1833 sess.mark_tag_received(counter);
1834 sess.last_seen = std::time::Instant::now();
1835
1836 if !is_new_session && sess.client_addr != client_addr {
1840 info!(
1841 "Client endpoint migrated: {} → {} (session keepalive active)",
1842 hash_addr(&sess.client_addr),
1843 hash_addr(&client_addr)
1844 );
1845 sess.client_addr = client_addr;
1846 }
1847
1848 let refresh_tags = counter.saturating_sub(sess.tag_window_base) >= 64;
1852 if refresh_tags {
1853 sess.update_tag_window();
1854 }
1855
1856 sess.pending_bytes_in = sess
1858 .pending_bytes_in
1859 .saturating_add(packet_data.len() as u64);
1860 if sess.pending_bytes_in >= 16 * 1024 || sess.pending_bytes_out >= 16 * 1024 {
1861 if let Some(cid) = sess.client_id.clone() {
1862 client_db_flush = Some((cid, sess.pending_bytes_in, sess.pending_bytes_out));
1863 }
1864 sess.pending_bytes_in = 0;
1865 sess.pending_bytes_out = 0;
1866 }
1867
1868 sess.update_fsm();
1869 (sess.session_id, refresh_tags)
1870 };
1871
1872 if refresh_tags {
1874 self.session_manager.refresh_session_tags(&session_id);
1875 }
1876
1877 if self.config.enable_neural {
1879 let packet_size = packet_data.len() as u16;
1880 let entropy = Self::compute_entropy(encrypted_payload);
1882 let iat_ms = {
1884 let sess = session.lock();
1885 let elapsed = sess.last_seen.elapsed();
1886 elapsed.as_secs_f64() * 1000.0
1887 };
1888 if counter & 0x0f == 0 {
1891 self.neural_module
1892 .lock()
1893 .record_traffic(session_id, packet_size, iat_ms, entropy);
1894 }
1895 self.metrics.record_packet_received(packet_data.len());
1896 }
1897
1898 if let Some(ref recorder) = self.recording_manager {
1900 let session_id = session.lock().session_id;
1901 if recorder.is_recording(&session_id) {
1902 let iat_ms = {
1903 let sess = session.lock();
1904 sess.last_seen.elapsed().as_secs_f64() * 1000.0
1905 };
1906 let meta = aivpn_common::recording::PacketMetadata {
1907 direction: aivpn_common::recording::Direction::Uplink,
1908 size: packet_data.len() as u16,
1909 iat_ms,
1910 entropy: Self::compute_entropy(encrypted_payload) as f32,
1911 header_prefix: packet_data
1912 [TAG_SIZE..TAG_SIZE + 16.min(packet_data.len() - TAG_SIZE)]
1913 .to_vec(),
1914 timestamp_ns: std::time::SystemTime::now()
1915 .duration_since(std::time::UNIX_EPOCH)
1916 .unwrap_or_default()
1917 .as_nanos() as u64,
1918 };
1919 recorder.record_packet(session_id, meta);
1920 }
1921 }
1922
1923 if let (Some(ref db), Some((cid, bytes_in, bytes_out))) = (&self.client_db, client_db_flush)
1925 {
1926 db.record_traffic(&cid, bytes_in, bytes_out);
1927 }
1928
1929 if !is_new_session {
1932 self.process_inner_payload(plaintext, &session, client_addr)
1933 .await?;
1934 }
1935
1936 Ok(())
1937 }
1938
1939 fn compute_nonce(&self, counter: u64) -> [u8; NONCE_SIZE] {
1941 let mut nonce = [0u8; NONCE_SIZE];
1942 nonce[0..8].copy_from_slice(&counter.to_le_bytes());
1943 nonce
1944 }
1945
1946 async fn process_inner_payload(
1948 &self,
1949 plaintext: &[u8],
1950 session: &Arc<parking_lot::Mutex<Session>>,
1951 client_addr: SocketAddr,
1952 ) -> Result<()> {
1953 if plaintext.len() < 4 {
1954 return Err(Error::InvalidPacket("Inner payload too short"));
1955 }
1956
1957 let inner_header = InnerHeader::decode(plaintext)?;
1958 let payload = &plaintext[4..];
1959
1960 match inner_header.inner_type {
1961 InnerType::Data => {
1962 debug!(
1964 "DATA packet from {} ({} bytes)",
1965 hash_addr(&client_addr),
1966 payload.len()
1967 );
1968
1969 if let Some(ref tx) = self.tun_write_tx {
1970 if tx.send(payload.to_vec()).await.is_err() {
1971 debug!("TUN write channel closed, dropping packet");
1972 }
1973 } else if let Some(ref nat) = self.nat_forwarder {
1974 nat.forward_packet(payload).await?;
1975 } else {
1976 debug!("NAT disabled, dropping packet");
1977 }
1978 }
1979 InnerType::Control => {
1980 self.handle_control_message(payload, session, client_addr)
1981 .await?;
1982 }
1983 InnerType::Fragment => {
1984 debug!("FRAGMENT packet (not implemented)");
1986 }
1987 InnerType::Ack => {
1988 debug!("ACK packet received");
1990 }
1991 }
1992
1993 Ok(())
1994 }
1995
1996 async fn handle_control_message(
1998 &self,
1999 payload: &[u8],
2000 session: &Arc<parking_lot::Mutex<Session>>,
2001 client_addr: SocketAddr,
2002 ) -> Result<()> {
2003 let control = ControlPayload::decode(payload)?;
2004
2005 match control {
2006 ControlPayload::KeyRotate { new_eph_pub: _ } => {
2007 info!("Key rotation request from {}", hash_addr(&client_addr));
2008 }
2010 ControlPayload::MaskUpdate { .. } => {
2011 warn!("Unexpected MASK_UPDATE from client");
2012 }
2013 ControlPayload::Keepalive => {
2014 debug!("Keepalive from {}", hash_addr(&client_addr));
2015 if !session.lock().is_ratcheted {
2016 self.send_server_hello(session, client_addr).await?;
2020 return Ok(());
2021 }
2022 let ack = ControlPayload::ControlAck {
2024 ack_seq: 0,
2025 ack_for_subtype: ControlSubtype::Keepalive as u8,
2026 };
2027 self.send_control_message(&ack, session).await?;
2028 }
2029 ControlPayload::TelemetryRequest { metric_flags: _ } => {
2030 debug!("Telemetry request from {}", hash_addr(&client_addr));
2031 let response = ControlPayload::TelemetryResponse {
2033 packet_loss: 0,
2034 rtt_ms: 10,
2035 jitter_ms: 2,
2036 buffer_pct: 25,
2037 };
2038 self.send_control_message(&response, session).await?;
2039 }
2040 ControlPayload::TelemetryResponse { .. } => {
2041 debug!("Telemetry response received");
2042 }
2043 ControlPayload::TimeSync { .. } => {
2044 debug!("Time sync request");
2045 }
2046 ControlPayload::Shutdown { reason } => {
2047 info!(
2048 "Shutdown request from {} (reason: {})",
2049 hash_addr(&client_addr),
2050 reason
2051 );
2052 let session_id = session.lock().session_id;
2054 self.session_manager.remove_session(&session_id);
2055 self.neural_module.lock().cleanup_stats(session_id);
2056 if let Some(ref recorder) = self.recording_manager {
2057 let socket = self.udp_socket.as_ref().unwrap().clone();
2058 let store = recorder.store();
2059 let mdh = self.mask_catalog.packet_mdh_bytes();
2060 let outcome = recorder.stop_for_session_end(session_id);
2061 Self::handle_recording_outcome(
2062 &socket,
2063 &self.session_manager,
2064 &store,
2065 &mdh,
2066 outcome,
2067 None,
2068 )
2069 .await;
2070 }
2071 }
2072 ControlPayload::ControlAck { .. } => {
2073 }
2075 ControlPayload::ServerHello { .. } => {
2076 warn!(
2077 "Unexpected ServerHello from client {}",
2078 hash_addr(&client_addr)
2079 );
2080 }
2081 ControlPayload::RecordingStart { service } => {
2082 let admin_key_id = {
2084 let sess = session.lock();
2085 sess.client_id.clone()
2086 };
2087 if !self.can_start_recording(admin_key_id.as_deref()) {
2088 warn!(
2089 "Recording rejected: unauthenticated client {}",
2090 hash_addr(&client_addr)
2091 );
2092 let failed = ControlPayload::RecordingFailed {
2093 reason: "Recording requires a recording-admin key".into(),
2094 };
2095 self.send_control_message(&failed, session).await?;
2096 return Ok(());
2097 }
2098 if let Some(ref recorder) = self.recording_manager {
2099 let session_id = session.lock().session_id;
2100 recorder.start(
2101 session_id,
2102 service.clone(),
2103 admin_key_id.unwrap_or_else(|| "admin".into()),
2104 );
2105 let ack = ControlPayload::RecordingAck {
2106 session_id,
2107 status: "started".into(),
2108 };
2109 self.send_control_message(&ack, session).await?;
2110 info!(
2111 "Recording started for '{}' from {}",
2112 service,
2113 hash_addr(&client_addr)
2114 );
2115 }
2116 }
2117 ControlPayload::RecordingStop {
2118 session_id: rec_session_id,
2119 } => {
2120 if let Some(ref recorder) = self.recording_manager {
2121 let owner_session_id = session.lock().session_id;
2122 if rec_session_id != owner_session_id {
2123 let failed = ControlPayload::RecordingFailed {
2124 reason: "Recording session does not belong to this client".into(),
2125 };
2126 self.send_control_message(&failed, session).await?;
2127 return Ok(());
2128 }
2129
2130 let socket = self.udp_socket.as_ref().unwrap().clone();
2131 let store = recorder.store();
2132 let mdh = self.mask_catalog.packet_mdh_bytes();
2133 let outcome = recorder.stop(owner_session_id);
2134 Self::handle_recording_outcome(
2135 &socket,
2136 &self.session_manager,
2137 &store,
2138 &mdh,
2139 outcome,
2140 Some(session.clone()),
2141 )
2142 .await;
2143 }
2144 }
2145 ControlPayload::RecordingStatusRequest => {
2146 let client_id = {
2147 let sess = session.lock();
2148 sess.client_id.clone()
2149 };
2150 let can_record = self.can_start_recording(client_id.as_deref());
2151 let active_service = self
2152 .recording_manager
2153 .as_ref()
2154 .and_then(|recorder| recorder.status(&session.lock().session_id))
2155 .map(|status| status.service);
2156 let response = ControlPayload::RecordingStatus {
2157 can_record,
2158 active_service,
2159 };
2160 self.send_control_message(&response, session).await?;
2161 }
2162 ControlPayload::RecordingAck { .. } => {
2163 }
2165 ControlPayload::RecordingComplete { .. } => {
2166 }
2168 ControlPayload::RecordingFailed { .. } => {
2169 }
2171 ControlPayload::RecordingStatus { .. } => {
2172 }
2174 ControlPayload::BootstrapDescriptorUpdate { .. } => {
2175 }
2177 }
2178
2179 Ok(())
2180 }
2181
2182 async fn send_control_message(
2184 &self,
2185 payload: &ControlPayload,
2186 session: &Arc<parking_lot::Mutex<Session>>,
2187 ) -> Result<()> {
2188 let socket = self.udp_socket.as_ref().unwrap();
2189 let mdh = {
2190 let mut sess = session.lock();
2191 sess.commit_pending_mask();
2192 sess.mask
2193 .as_ref()
2194 .map(packet_mdh_bytes_for_mask)
2195 .unwrap_or_else(|| self.mask_catalog.packet_mdh_bytes())
2196 };
2197 Self::send_control_message_via(socket, &mdh, payload, session).await
2198 }
2199
2200 async fn send_control_message_via(
2201 socket: &UdpSocket,
2202 mdh: &[u8],
2203 payload: &ControlPayload,
2204 session: &Arc<parking_lot::Mutex<Session>>,
2205 ) -> Result<()> {
2206 let encoded = payload.encode()?;
2207 let (mut inner_payload, nonce, counter, keys, client_addr) = {
2208 let mut sess = session.lock();
2209 let inner_header = InnerHeader {
2210 inner_type: InnerType::Control,
2211 seq_num: sess.next_seq() as u16,
2212 };
2213 let inner_payload = inner_header.encode().to_vec();
2214 let (nonce, counter) = sess.next_send_nonce();
2215 let keys = sess.keys.clone();
2216 let client_addr = sess.client_addr;
2217 (inner_payload, nonce, counter, keys, client_addr)
2218 };
2219 inner_payload.extend_from_slice(&encoded);
2220 let pad_len = 16u16;
2221 let mut padded = Vec::with_capacity(2 + inner_payload.len() + pad_len as usize);
2222 padded.extend_from_slice(&pad_len.to_le_bytes());
2223 padded.extend_from_slice(&inner_payload);
2224 {
2225 use rand::Rng;
2226 let mut rng = rand::thread_rng();
2227 for _ in 0..pad_len {
2228 padded.push(rng.gen::<u8>());
2229 }
2230 }
2231 let ciphertext = encrypt_payload(&keys.session_key, &nonce, &padded)?;
2232 let time_window = crypto::compute_time_window(
2233 crypto::current_timestamp_ms(),
2234 aivpn_common::crypto::DEFAULT_WINDOW_MS,
2235 );
2236 let tag = crypto::generate_resonance_tag(&keys.tag_secret, counter, time_window);
2237 let mut packet = Vec::with_capacity(TAG_SIZE + mdh.len() + ciphertext.len());
2238 packet.extend_from_slice(&tag);
2239 packet.extend_from_slice(mdh);
2240 packet.extend_from_slice(&ciphertext);
2241 socket.send_to(&packet, client_addr).await?;
2242 Ok(())
2243 }
2244
2245 async fn send_server_hello(
2246 &self,
2247 session: &Arc<parking_lot::Mutex<Session>>,
2248 client_addr: SocketAddr,
2249 ) -> Result<()> {
2250 let (server_eph_pub, signature, network_config) = {
2251 let sess = session.lock();
2252 match (sess.server_eph_pub, sess.server_hello_signature) {
2253 (Some(pub_key), Some(sig)) => {
2254 let network_config = sess
2255 .vpn_ip
2256 .and_then(|vpn_ip| self.config.network_config.client_config(vpn_ip).ok());
2257 (pub_key, sig, network_config)
2258 }
2259 _ => return Err(Error::Session("Missing ratchet data".into())),
2260 }
2261 };
2262
2263 let hello = ControlPayload::ServerHello {
2264 server_eph_pub,
2265 signature,
2266 network_config,
2267 };
2268 let encoded = hello.encode()?;
2269 let inner_header = InnerHeader {
2270 inner_type: InnerType::Control,
2271 seq_num: 0,
2272 };
2273 let mut inner_payload = inner_header.encode().to_vec();
2274 inner_payload.extend_from_slice(&encoded);
2275 let packet = self.build_packet(&inner_payload, session)?;
2276 let socket = self.udp_socket.as_ref().unwrap();
2277 let sent = socket.send_to(&packet, client_addr).await?;
2278 debug!("ServerHello sent: {} bytes to {}", sent, client_addr);
2279 Ok(())
2280 }
2281
2282 fn build_packet(
2285 &self,
2286 plaintext: &[u8],
2287 session: &Arc<parking_lot::Mutex<Session>>,
2288 ) -> Result<Vec<u8>> {
2289 let mut sess = session.lock();
2290
2291 let (nonce, counter) = sess.next_send_nonce();
2293
2294 let pad_len = 16u16;
2297 let mut padded = Vec::with_capacity(2 + plaintext.len() + pad_len as usize);
2298 padded.extend_from_slice(&pad_len.to_le_bytes());
2299 padded.extend_from_slice(plaintext);
2300 use rand::Rng;
2301 let mut rng = rand::thread_rng();
2302 for _ in 0..pad_len {
2303 padded.push(rng.gen::<u8>());
2304 }
2305
2306 let ciphertext = encrypt_payload(&sess.keys.session_key, &nonce, &padded)?;
2307
2308 let time_window = crypto::compute_time_window(
2310 crypto::current_timestamp_ms(),
2311 aivpn_common::crypto::DEFAULT_WINDOW_MS,
2312 );
2313 let tag = crypto::generate_resonance_tag(&sess.keys.tag_secret, counter, time_window);
2314 let current_mask = sess.mask.clone();
2315 drop(sess);
2316
2317 let mdh = current_mask
2320 .as_ref()
2321 .map(packet_mdh_bytes_for_mask)
2322 .unwrap_or_else(|| self.mask_catalog.packet_mdh_bytes());
2323
2324 let mut packet = Vec::with_capacity(TAG_SIZE + mdh.len() + ciphertext.len());
2326 packet.extend_from_slice(&tag);
2327 packet.extend_from_slice(&mdh);
2328 packet.extend_from_slice(&ciphertext);
2329
2330 Ok(packet)
2331 }
2332
/// Computes the Shannon entropy of `data` in bits per byte.
///
/// Returns `0.0` for empty input; the result is at most `8.0`.
fn compute_entropy(data: &[u8]) -> f64 {
    if data.is_empty() {
        return 0.0;
    }

    // Histogram of byte values.
    let mut histogram = [0u32; 256];
    data.iter().for_each(|&byte| histogram[byte as usize] += 1);

    let total = data.len() as f64;
    histogram
        .iter()
        .filter(|&&occurrences| occurrences > 0)
        .fold(0.0, |acc, &occurrences| {
            let p = occurrences as f64 / total;
            acc - p * p.log2()
        })
}
2352
2353 pub fn mask_catalog(&self) -> &Arc<MaskCatalog> {
2355 &self.mask_catalog
2356 }
2357
2358 pub fn metrics(&self) -> &Arc<MetricsCollector> {
2360 &self.metrics
2361 }
2362}
2363
#[cfg(test)]
mod tests {
    use super::MaskCatalog;
    use aivpn_common::crypto::TAG_SIZE;
    use aivpn_common::mask::preset_masks::webrtc_zoom_v3;

    /// Registers the `webrtc_zoom_v3` preset as the primary mask and checks that the
    /// layout reported by the catalog locates the embedded ephemeral key and the
    /// payload inside a hand-built handshake packet.
    #[test]
    fn packet_layout_extracts_embedded_eph_pub_from_mdh() {
        let mask = webrtc_zoom_v3();
        let catalog = MaskCatalog::new();
        catalog.register_mask(mask.clone());
        catalog.set_primary_mask_id(mask.mask_id.clone());
        let (packet_mdh_len, handshake_mdh_len, eph_offset, eph_len) = catalog.packet_layout();

        // Header template, zero-extended to the handshake header length if shorter.
        let mut mdh = mask.header_template.clone();
        let padded_len = mdh.len().max(handshake_mdh_len);
        mdh.resize(padded_len, 0);

        let expected_eph = [0x5au8; 32];
        mdh[eph_offset..eph_offset + eph_len].copy_from_slice(&expected_eph);

        // Packet = zeroed tag + header + 24 payload bytes.
        let trailer = [0xabu8; 24];
        let packet: Vec<u8> = std::iter::repeat(0u8)
            .take(TAG_SIZE)
            .chain(mdh.iter().copied())
            .chain(trailer.iter().copied())
            .collect();

        let eph_start = TAG_SIZE + eph_offset;
        let payload_start = TAG_SIZE + handshake_mdh_len;

        assert_eq!(
            packet_mdh_len, 20,
            "regular STUN packet MDH length must stay at 20 bytes"
        );
        assert_eq!(
            handshake_mdh_len, 52,
            "handshake MDH length must include embedded eph_pub"
        );
        assert_eq!(&packet[eph_start..eph_start + eph_len], &expected_eph);
        assert_eq!(&packet[payload_start..], &[0xabu8; 24]);
    }
}