Skip to main content

aivpn_server/
gateway.rs

1//! Gateway Engine - Full Implementation
2//!
3//! Handles:
4//! - UDP packet reception with O(1) tag validation
5//! - Decryption and de-mimicry
6//! - NAT forwarding to internet
7//! - Bidirectional traffic shaping
8//! - Neural Resonance validation (Patent 1)
9//! - Automatic mask rotation on compromise (Patent 3)
10
11use dashmap::DashMap;
12use std::net::{IpAddr, Ipv4Addr, SocketAddr};
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::UdpSocket;
17use tokio::sync::mpsc;
18use tracing::{debug, error, info, warn};
19
20use aivpn_common::crypto::{self, decrypt_payload, encrypt_payload, NONCE_SIZE, TAG_SIZE};
21use aivpn_common::error::{Error, Result};
22use aivpn_common::mask::{
23    current_unix_secs, derive_bootstrap_candidates, BootstrapDescriptor, MaskProfile,
24};
25use aivpn_common::network_config::VpnNetworkConfig;
26use aivpn_common::protocol::{
27    ControlPayload, ControlSubtype, InnerHeader, InnerType, MAX_PACKET_SIZE,
28};
29
30use crate::client_db::ClientDatabase;
31use crate::mask_gen::generate_and_store_mask;
32use crate::mask_store::MaskStore;
33use crate::metrics::MetricsCollector;
34use crate::nat::NatForwarder;
35use crate::neural::{NeuralConfig, NeuralResonanceModule, ResonanceStatus};
36use crate::recording::RecordingManager;
37use crate::recording::{RecordingStopOutcome, RecordingStopReason};
38use crate::session::{Session, SessionManager, MAX_SESSIONS};
39
/// A packet's raw bytes bundled with a client socket address.
///
/// NOTE(review): the producer/consumer of this type is not visible in this
/// part of the file — presumably used by the receive path to hand datagrams
/// to worker tasks; confirm against `process_packets_concurrent`.
struct QueuedPacket {
    /// Raw packet bytes.
    packet_data: Vec<u8>,
    /// Socket address of the client this packet is associated with.
    client_addr: SocketAddr,
}
44
/// Gateway configuration
///
/// Consumed by `Gateway::new` (key material, masks, timeouts, neural setup)
/// and `Gateway::run` (listen address, TUN/NAT setup).
#[derive(Clone)]
pub struct GatewayConfig {
    /// UDP listen address; parsed as a `SocketAddr` in `Gateway::run`.
    pub listen_addr: String,
    /// Per-IP packets-per-second limit (logged at startup by `Gateway::run`).
    pub per_ip_pps_limit: u64,
    /// TUN device name passed to `NatForwarder::new`.
    pub tun_name: String,
    /// TUN device address passed to `NatForwarder::new`.
    pub tun_addr: String,
    /// TUN device netmask passed to `NatForwarder::new`.
    pub tun_netmask: String,
    /// VPN network settings; passed to `NatForwarder::new`, and its
    /// `server_vpn_ip` is handed to the TUN read loop.
    pub network_config: VpnNetworkConfig,
    /// Server private key. An all-zero value makes `Gateway::new` generate an
    /// ephemeral keypair instead; the same bytes are also the input for the
    /// derived Ed25519 signing key and the bootstrap descriptor seed.
    pub server_private_key: [u8; 32],
    /// NOTE(review): not read anywhere in the visible part of this file —
    /// the signing key actually used is derived from `server_private_key`.
    pub signing_key: [u8; 64],
    /// When true, `Gateway::run` creates the `NatForwarder` / TUN device.
    pub enable_nat: bool,
    /// Enable neural resonance module (Patent 1)
    pub enable_neural: bool,
    /// Neural resonance configuration
    pub neural_config: NeuralConfig,
    /// Client database for PSK-based authentication
    pub client_db: Option<Arc<ClientDatabase>>,
    /// Directory for mask storage (default: /var/lib/aivpn/masks)
    pub mask_dir: std::path::PathBuf,
    /// Session hard timeout in seconds (default: 7 days). `None` uses the default.
    pub session_timeout_secs: Option<u64>,
    /// Session idle timeout in seconds (default: 300). `None` uses the default.
    pub idle_timeout_secs: Option<u64>,
    /// Optional custom bootstrap masks embedded into signed descriptors.
    pub bootstrap_masks: Vec<MaskProfile>,
}
72
73impl Default for GatewayConfig {
74    fn default() -> Self {
75        Self {
76            listen_addr: "0.0.0.0:443".to_string(),
77            per_ip_pps_limit: 1000,
78            tun_name: "aivpn0".to_string(),
79            tun_addr: "10.0.0.1".to_string(),
80            tun_netmask: "255.255.255.0".to_string(),
81            network_config: VpnNetworkConfig::default(),
82            server_private_key: [0u8; 32],
83            signing_key: [0u8; 64],
84            enable_nat: true,
85            enable_neural: true,
86            neural_config: NeuralConfig::default(),
87            client_db: None,
88            mask_dir: std::path::PathBuf::from("/var/lib/aivpn/masks"),
89            session_timeout_secs: None,
90            idle_timeout_secs: None,
91            bootstrap_masks: Vec::new(),
92        }
93    }
94}
95
/// Mask catalog for automatic rotation (Patent 3 + Patent 9)
///
/// Holds a pool of pre-generated masks. When neural resonance detects
/// that a mask is compromised by DPI, the catalog provides a replacement.
///
/// All fields use interior mutability (`DashMap` / `Mutex`), so every method
/// takes `&self`.
pub struct MaskCatalog {
    /// Available masks (mask_id → MaskProfile)
    masks: DashMap<String, MaskProfile>,
    /// Compromised mask IDs — never reuse. The value is the `Instant` at
    /// which the mask was marked compromised.
    compromised: DashMap<String, Instant>,
    /// Primary mask used for initial handshake parsing. Empty string until
    /// `set_primary_mask_id` is called.
    primary_mask_id: parking_lot::Mutex<String>,
}
108
109impl MaskCatalog {
110    pub fn new() -> Self {
111        Self {
112            masks: DashMap::new(),
113            compromised: DashMap::new(),
114            primary_mask_id: parking_lot::Mutex::new(String::new()),
115        }
116    }
117
118    /// Set the primary mask ID (first mask loaded from disk)
119    pub fn set_primary_mask_id(&self, mask_id: String) {
120        *self.primary_mask_id.lock() = mask_id;
121    }
122
123    /// Register a new mask (e.g., received via passive distribution or neural unpack)
124    pub fn register_mask(&self, mask: MaskProfile) {
125        if !self.compromised.contains_key(&mask.mask_id) {
126            self.masks.insert(mask.mask_id.clone(), mask);
127        }
128    }
129
130    /// Mark a mask as compromised — remove from rotation
131    pub fn mark_compromised(&self, mask_id: &str) {
132        self.compromised.insert(mask_id.to_string(), Instant::now());
133        self.masks.remove(mask_id);
134    }
135
136    /// Remove a mask from live rotation without marking it as compromised.
137    pub fn remove_mask(&self, mask_id: &str) {
138        self.masks.remove(mask_id);
139    }
140
141    /// Select the best non-compromised mask, excluding `current_mask_id`
142    pub fn select_fallback(&self, current_mask_id: &str) -> Option<MaskProfile> {
143        self.masks
144            .iter()
145            .filter(|e| e.key() != current_mask_id)
146            .map(|e| e.value().clone())
147            .next()
148    }
149
150    /// Get mask count
151    pub fn available_count(&self) -> usize {
152        self.masks.len()
153    }
154
155    /// Get the primary packet layout for client->server traffic.
156    /// Returns `(packet_mdh_len, handshake_mdh_len, eph_offset, eph_len)`.
157    /// Normal packets use only the protocol header, while the initial
158    /// handshake embeds `eph_pub` inside the MDH at `eph_offset`.
159    pub fn packet_layout(&self) -> (usize, usize, usize, usize) {
160        let fallback = (20usize, 52usize, 20usize, 32usize);
161        let Some(mask) = self.primary_mask() else {
162            return fallback;
163        };
164
165        packet_layout_for_mask(&mask)
166    }
167
168    /// Get the regular MDH bytes used for server->client packets.
169    /// Uses HeaderSpec for dynamic per-packet generation when available (Issue #30 fix).
170    pub fn packet_mdh_bytes(&self) -> Vec<u8> {
171        self.primary_mask()
172            .map(|mask| packet_mdh_bytes_for_mask(&mask))
173            .unwrap_or_else(|| vec![0u8; 20])
174    }
175
176    pub fn primary_mask(&self) -> Option<MaskProfile> {
177        let primary_id = self.primary_mask_id.lock().clone();
178        self.masks
179            .get(&primary_id)
180            .map(|entry| entry.value().clone())
181            .or_else(|| self.masks.iter().next().map(|entry| entry.value().clone()))
182    }
183}
184
185fn packet_layout_for_mask(mask: &MaskProfile) -> (usize, usize, usize, usize) {
186    let eph_offset = mask.eph_pub_offset as usize;
187    let eph_len = mask.eph_pub_length as usize;
188    let packet_mdh_len = mask
189        .header_spec
190        .as_ref()
191        .map(|spec| spec.min_length())
192        .unwrap_or_else(|| mask.header_template.len());
193    let handshake_mdh_len = packet_mdh_len.max(eph_offset.saturating_add(eph_len));
194    (packet_mdh_len, handshake_mdh_len, eph_offset, eph_len)
195}
196
197fn packet_mdh_bytes_for_mask(mask: &MaskProfile) -> Vec<u8> {
198    if let Some(ref spec) = mask.header_spec {
199        let mut rng = rand::thread_rng();
200        spec.generate(&mut rng)
201    } else {
202        mask.header_template.clone()
203    }
204}
205
206/// Hash a socket address for privacy-preserving logging (MED-4)
207fn hash_addr(addr: &SocketAddr) -> String {
208    let hash = crypto::blake3_hash(addr.to_string().as_bytes());
209    format!(
210        "{:02x}{:02x}{:02x}{:02x}",
211        hash[0], hash[1], hash[2], hash[3]
212    )
213}
214
/// Gateway server
///
/// Built by `Gateway::new` and consumed by `Gateway::run`, which fills in the
/// runtime-only fields (`udp_socket`, `nat_forwarder`, `tun_write_tx`) before
/// entering the packet-processing loop.
pub struct Gateway {
    /// Configuration this gateway was created with.
    config: GatewayConfig,
    /// Session manager shared with the background tasks.
    session_manager: Arc<SessionManager>,
    /// Bound UDP socket; `None` until `run` binds the listener.
    udp_socket: Option<Arc<UdpSocket>>,
    /// NAT forwarder / TUN device; `None` until `run` creates it (only when
    /// `config.enable_nat` is set).
    nat_forwarder: Option<Arc<NatForwarder>>,
    /// Channel-based TUN writer (replaces Mutex for upload throughput)
    tun_write_tx: Option<mpsc::Sender<Vec<u8>>>,
    /// Per-IP rate limiter: (packet_count, window_start)
    rate_limits: Arc<DashMap<IpAddr, (u64, Instant)>>,
    /// Per-IP handshake failure cooldown: (failure_count, last_failure_time)
    /// Prevents rapid session-creation loops when client retries with stale keys
    handshake_cooldowns: Arc<DashMap<IpAddr, (u32, Instant)>>,
    /// Per-IP handshake mutex: serializes concurrent handshakes arriving on
    /// different source ports from the same client, preventing duplicate sessions
    /// that compete for the same VPN IP and cause aead::Error on data packets.
    handshake_locks: Arc<DashMap<IpAddr, Arc<tokio::sync::Mutex<()>>>>,
    /// Neural Resonance Module (Patent 1) — periodic traffic validation
    neural_module: Arc<parking_lot::Mutex<NeuralResonanceModule>>,
    /// Mask catalog for automatic rotation (Patent 3)
    mask_catalog: Arc<MaskCatalog>,
    /// Metrics collector
    metrics: Arc<MetricsCollector>,
    /// Client database for PSK-based authentication
    client_db: Option<Arc<ClientDatabase>>,
    /// Recording manager for auto mask recording
    recording_manager: Option<Arc<RecordingManager>>,
    /// Mask store for auto-generated masks
    #[allow(dead_code)]
    mask_store: Option<Arc<MaskStore>>,
    /// Active bootstrap descriptors for previous/current/next epochs.
    bootstrap_descriptors: Vec<BootstrapDescriptor>,
}
248
/// Length of one bootstrap epoch in seconds (24 hours).
const BOOTSTRAP_ROTATION_SECS: u64 = 24 * 60 * 60;
/// Candidate count written into every generated bootstrap descriptor.
const BOOTSTRAP_DESCRIPTOR_CANDIDATES: u8 = 4;

/// Index of the bootstrap epoch that contains the Unix timestamp `unix_secs`.
fn bootstrap_epoch(unix_secs: u64) -> u64 {
    unix_secs.div_euclid(BOOTSTRAP_ROTATION_SECS)
}
255
256pub fn derive_server_signing_key(server_private_key: &[u8; 32]) -> ed25519_dalek::SigningKey {
257    let seed = blake3::derive_key("aivpn-ed25519-signing-v1", server_private_key);
258    ed25519_dalek::SigningKey::from_bytes(&seed)
259}
260
261fn sign_bootstrap_descriptor(
262    mut descriptor: BootstrapDescriptor,
263    signing_key: &ed25519_dalek::SigningKey,
264) -> BootstrapDescriptor {
265    use ed25519_dalek::Signer;
266    descriptor.signature = signing_key.sign(&descriptor.signing_bytes()).to_bytes();
267    descriptor
268}
269
270fn build_bootstrap_descriptor(
271    server_seed: &[u8; 32],
272    signing_key: &ed25519_dalek::SigningKey,
273    epoch: u64,
274    bootstrap_masks: &[MaskProfile],
275) -> BootstrapDescriptor {
276    let mut hasher = blake3::Hasher::new_keyed(server_seed);
277    hasher.update(&epoch.to_le_bytes());
278    let hash = hasher.finalize();
279    let mut kdf_salt = [0u8; 32];
280    kdf_salt.copy_from_slice(&hash.as_bytes()[..32]);
281    let created_at = epoch * BOOTSTRAP_ROTATION_SECS;
282    let expires_at = created_at + (2 * BOOTSTRAP_ROTATION_SECS);
283    let (base_mask_ids, embedded_masks) = if bootstrap_masks.is_empty() {
284        (
285            aivpn_common::mask::preset_masks::all()
286                .into_iter()
287                .map(|mask| mask.mask_id)
288                .collect(),
289            Vec::new(),
290        )
291    } else {
292        (Vec::new(), bootstrap_masks.to_vec())
293    };
294
295    sign_bootstrap_descriptor(
296        BootstrapDescriptor {
297            descriptor_id: format!("epoch-{}", epoch),
298            version: 1,
299            created_at,
300            expires_at,
301            base_mask_ids,
302            embedded_masks,
303            candidate_count: BOOTSTRAP_DESCRIPTOR_CANDIDATES,
304            kdf_salt,
305            signature: [0u8; 64],
306        },
307        signing_key,
308    )
309}
310
311pub fn build_bootstrap_descriptors(
312    server_seed: &[u8; 32],
313    signing_key: &ed25519_dalek::SigningKey,
314    bootstrap_masks: &[MaskProfile],
315) -> Vec<BootstrapDescriptor> {
316    let epoch = bootstrap_epoch(current_unix_secs());
317    [epoch.saturating_sub(1), epoch, epoch.saturating_add(1)]
318        .into_iter()
319        .map(|value| build_bootstrap_descriptor(server_seed, signing_key, value, bootstrap_masks))
320        .collect()
321}
322
323impl Gateway {
324    fn can_start_recording(&self, client_id: Option<&str>) -> bool {
325        let Some(client_id) = client_id else {
326            return false;
327        };
328
329        if client_id == "admin" {
330            return true;
331        }
332
333        self.client_db
334            .as_ref()
335            .and_then(|db| db.find_by_id(client_id))
336            .map(|client| client.name.starts_with("recording-admin"))
337            .unwrap_or(false)
338    }
339
340    async fn handle_recording_outcome(
341        socket: &Arc<UdpSocket>,
342        sessions: &Arc<SessionManager>,
343        store: &Arc<MaskStore>,
344        mdh: &[u8],
345        outcome: RecordingStopOutcome,
346        notify_session: Option<Arc<parking_lot::Mutex<Session>>>,
347    ) {
348        match outcome {
349            RecordingStopOutcome::Completed(completed) => {
350                if let Some(ref session) = notify_session {
351                    let ack = ControlPayload::RecordingAck {
352                        session_id: completed.session_id,
353                        status: "analyzing".into(),
354                    };
355                    if let Err(e) =
356                        Self::send_control_message_via(socket.as_ref(), mdh, &ack, session).await
357                    {
358                        warn!("Failed to send RecordingAck: {}", e);
359                    }
360                }
361
362                info!(
363                    "Recording stopped for '{}' ({} packets, {}s), analyzing...",
364                    completed.service, completed.total_packets, completed.duration_secs
365                );
366
367                let socket = socket.clone();
368                let sessions = sessions.clone();
369                let store = store.clone();
370                let mdh = mdh.to_vec();
371                tokio::spawn(async move {
372                    match generate_and_store_mask(&completed.service, &completed.packets, &store)
373                        .await
374                    {
375                        Ok(mask_id) => {
376                            info!(
377                                "✅ Mask generated: '{}' for service '{}' by {}",
378                                mask_id, completed.service, completed.admin_key_id
379                            );
380                            if let Some(target_session) =
381                                sessions.get_session(&completed.session_id)
382                            {
383                                let confidence = store
384                                    .get_mask(&mask_id)
385                                    .map(|entry| entry.stats.confidence)
386                                    .unwrap_or(0.0);
387                                let payload = ControlPayload::RecordingComplete {
388                                    service: completed.service.clone(),
389                                    mask_id,
390                                    confidence,
391                                };
392                                if let Err(e) = Self::send_control_message_via(
393                                    socket.as_ref(),
394                                    &mdh,
395                                    &payload,
396                                    &target_session,
397                                )
398                                .await
399                                {
400                                    warn!("Failed to send RecordingComplete: {}", e);
401                                }
402                            }
403                        }
404                        Err(e) => {
405                            warn!("Mask generation failed for '{}': {}", completed.service, e);
406                            if let Some(target_session) =
407                                sessions.get_session(&completed.session_id)
408                            {
409                                let payload = ControlPayload::RecordingFailed {
410                                    reason: e.to_string(),
411                                };
412                                if let Err(send_err) = Self::send_control_message_via(
413                                    socket.as_ref(),
414                                    &mdh,
415                                    &payload,
416                                    &target_session,
417                                )
418                                .await
419                                {
420                                    warn!("Failed to send RecordingFailed: {}", send_err);
421                                }
422                            }
423                        }
424                    }
425                });
426            }
427            RecordingStopOutcome::Incomplete(incomplete) => {
428                let reason = match incomplete.reason {
429                    RecordingStopReason::IdleTimeout => {
430                        "Recording stopped after idle timeout before enough traffic was captured"
431                    }
432                    RecordingStopReason::SessionEnded => {
433                        "Recording ended with the session before enough traffic was captured"
434                    }
435                    _ => "Too few packets or too short duration",
436                };
437                if let Some(ref session) = notify_session {
438                    let payload = ControlPayload::RecordingFailed {
439                        reason: reason.into(),
440                    };
441                    if let Err(e) =
442                        Self::send_control_message_via(socket.as_ref(), mdh, &payload, session)
443                            .await
444                    {
445                        warn!("Failed to send RecordingFailed: {}", e);
446                    }
447                }
448                warn!(
449                    "Recording for '{}' ended without mask generation: {} packets, {}s ({:?})",
450                    incomplete.service,
451                    incomplete.total_packets,
452                    incomplete.duration_secs,
453                    incomplete.reason
454                );
455            }
456            RecordingStopOutcome::NotFound => {}
457        }
458    }
459
460    pub fn new(config: GatewayConfig) -> Result<Self> {
461        // Create server keypair (use config key if provided, otherwise generate ephemeral)
462        let server_keys = if config.server_private_key != [0u8; 32] {
463            crypto::KeyPair::from_private_key(config.server_private_key)
464        } else {
465            crypto::KeyPair::generate()
466        };
467
468        // Create Ed25519 signing key
469        let signing_key = derive_server_signing_key(&config.server_private_key);
470        let bootstrap_descriptors = build_bootstrap_descriptors(
471            &config.server_private_key,
472            &signing_key,
473            &config.bootstrap_masks,
474        );
475
476        // Initialize mask catalog (empty — populated from disk only)
477        let mask_catalog = Arc::new(MaskCatalog::new());
478
479        // Initialize mask store — loads masks from disk into catalog
480        let mask_store = Arc::new(MaskStore::new(
481            mask_catalog.clone(),
482            config.mask_dir.clone(),
483        ));
484
485        // Runtime primary mask is selected from the masks loaded on disk.
486        // Bootstrap compatibility is handled separately using built-in presets.
487        let primary_id = if let Some(first) = mask_catalog.masks.iter().next() {
488            let id = first.key().clone();
489            id
490        } else {
491            String::new()
492        };
493        if !primary_id.is_empty() {
494            info!("Primary mask set to '{}' (loaded from disk)", primary_id);
495            mask_catalog.set_primary_mask_id(primary_id);
496        } else {
497            warn!("No masks found in {:?} — server will not accept connections until masks are recorded", config.mask_dir);
498        }
499
500        // Get default mask from catalog (required — at least one mask must exist on disk)
501        let default_mask = mask_catalog.primary_mask().ok_or_else(|| {
502            Error::Session(format!(
503                "No masks found in {:?} — place mask JSON files there before starting the server",
504                config.mask_dir
505            ))
506        })?;
507
508        let session_manager = Arc::new(SessionManager::with_timeouts(
509            server_keys,
510            signing_key,
511            default_mask,
512            config.session_timeout_secs,
513            config.idle_timeout_secs,
514        ));
515
516        // Initialize neural resonance module (Patent 1)
517        let mut neural = NeuralResonanceModule::new(config.neural_config.clone())
518            .map_err(|e| Error::Session(format!("Neural module init failed: {}", e)))?;
519
520        if config.enable_neural {
521            // Register all catalog masks for signature-based resonance checking
522            for entry in mask_catalog.masks.iter() {
523                let _ = neural.register_mask(entry.value());
524            }
525            // Load neural model (Baked Mask Encoder — ~66KB per mask)
526            let _ = neural.load_model();
527            info!("Neural Resonance Module initialized (Patent 1)");
528        }
529
530        let recording_manager = Arc::new(RecordingManager::new(mask_store.clone()));
531        info!(
532            "Auto Mask Recording system initialized ({} masks loaded from disk)",
533            mask_catalog.available_count()
534        );
535
536        Ok(Self {
537            config: config.clone(),
538            session_manager,
539            udp_socket: None,
540            nat_forwarder: None,
541            tun_write_tx: None,
542            rate_limits: Arc::new(DashMap::new()),
543            handshake_cooldowns: Arc::new(DashMap::new()),
544            handshake_locks: Arc::new(DashMap::new()),
545            neural_module: Arc::new(parking_lot::Mutex::new(neural)),
546            mask_catalog,
547            metrics: Arc::new(MetricsCollector::new()),
548            client_db: config.client_db,
549            recording_manager: Some(recording_manager),
550            mask_store: Some(mask_store),
551            bootstrap_descriptors,
552        })
553    }
554
555    async fn send_bootstrap_descriptors(
556        &self,
557        session: &Arc<parking_lot::Mutex<Session>>,
558    ) -> Result<()> {
559        for descriptor in &self.bootstrap_descriptors {
560            let payload = ControlPayload::BootstrapDescriptorUpdate {
561                descriptor_data: rmp_serde::to_vec(descriptor).map_err(|e| {
562                    Error::Session(format!("Failed to serialize bootstrap descriptor: {}", e))
563                })?,
564            };
565            self.send_control_message(&payload, session).await?;
566        }
567        Ok(())
568    }
569
570    /// Start the gateway
571    pub async fn run(mut self) -> Result<()> {
572        info!("Starting AIVPN Gateway on {}", self.config.listen_addr);
573        info!(
574            "Per-IP UDP rate limit: {} pps",
575            self.config.per_ip_pps_limit
576        );
577
578        // Create NAT forwarder (requires root — deferred from constructor for testability)
579        if self.config.enable_nat {
580            let mut nat = NatForwarder::new(
581                &self.config.tun_name,
582                &self.config.tun_addr,
583                &self.config.tun_netmask,
584                self.config.network_config,
585            )?;
586            nat.create()?;
587            self.nat_forwarder = Some(Arc::new(nat));
588            info!(
589                "TUN device: {} ({}/{})",
590                self.config.tun_name, self.config.tun_addr, self.config.tun_netmask
591            );
592        }
593
594        // Create UDP socket with 4MB OS buffers (OPTIMIZATION)
595        let bind_addr: SocketAddr =
596            self.config
597                .listen_addr
598                .parse()
599                .map_err(|e: std::net::AddrParseError| {
600                    Error::Io(std::io::Error::new(
601                        std::io::ErrorKind::InvalidInput,
602                        e.to_string(),
603                    ))
604                })?;
605
606        let socket2_sock = socket2::Socket::new(
607            if bind_addr.is_ipv4() {
608                socket2::Domain::IPV4
609            } else {
610                socket2::Domain::IPV6
611            },
612            socket2::Type::DGRAM,
613            Some(socket2::Protocol::UDP),
614        )
615        .map_err(Error::Io)?;
616
617        socket2_sock.set_nonblocking(true).map_err(Error::Io)?;
618        let _ = socket2_sock.set_recv_buffer_size(4 * 1024 * 1024);
619        let _ = socket2_sock.set_send_buffer_size(4 * 1024 * 1024);
620        socket2_sock.bind(&bind_addr.into()).map_err(Error::Io)?;
621
622        let std_sock: std::net::UdpSocket = socket2_sock.into();
623        let socket = UdpSocket::from_std(std_sock).map_err(Error::Io)?;
624
625        info!(
626            "UDP listener bound to {} (4MB buffers via socket2)",
627            self.config.listen_addr
628        );
629
630        self.udp_socket = Some(Arc::new(socket));
631
632        // Spawn neural resonance check loop (Patent 1 — periodic validation)
633        if self.config.enable_neural {
634            let neural = self.neural_module.clone();
635            let sessions = self.session_manager.clone();
636            let catalog = self.mask_catalog.clone();
637            let metrics = self.metrics.clone();
638            let check_interval = self.config.neural_config.check_interval_secs;
639            let socket = self.udp_socket.as_ref().unwrap().clone();
640
641            tokio::spawn(async move {
642                Self::resonance_check_loop(
643                    neural,
644                    sessions,
645                    catalog,
646                    metrics,
647                    check_interval,
648                    socket,
649                )
650                .await;
651            });
652            info!(
653                "Neural resonance check loop spawned (interval: {}s)",
654                check_interval
655            );
656        }
657
658        // Spawn TUN → Client read loop (reads packets from TUN, routes back to clients)
659        // Also set up channel-based TUN writer for upload path (avoids Mutex contention)
660        if let Some(ref nat) = self.nat_forwarder {
661            if let Some(tun_reader) = nat.take_reader().await {
662                let sessions = self.session_manager.clone();
663                let socket = self.udp_socket.as_ref().unwrap().clone();
664                let mask = self
665                    .mask_catalog
666                    .masks
667                    .iter()
668                    .next()
669                    .map(|e| e.value().clone())
670                    .expect("at least one mask must be loaded");
671                let server_vpn_ip = self.config.network_config.server_vpn_ip;
672                let recorder = self.recording_manager.clone();
673
674                // Channel for writing packets to TUN device (upload + ICMP replies)
675                let (tun_tx, tun_rx) = mpsc::channel::<Vec<u8>>(4096);
676                self.tun_write_tx = Some(tun_tx.clone());
677
678                // Spawn dedicated TUN writer task — owns the DeviceWriter, no Mutex needed
679                if let Some(tun_writer) = nat.take_writer().await {
680                    tokio::spawn(async move {
681                        Self::tun_write_loop(tun_writer, tun_rx).await;
682                    });
683                    info!("TUN write loop spawned (channel-based, no Mutex)");
684                } else {
685                    warn!("Could not take TUN writer — falling back to forward_packet");
686                }
687
688                let client_db = self.client_db.clone();
689                tokio::spawn(async move {
690                    Self::tun_read_loop(
691                        tun_reader,
692                        tun_tx,
693                        sessions,
694                        socket,
695                        mask,
696                        server_vpn_ip,
697                        recorder,
698                        client_db,
699                    )
700                    .await;
701                });
702                info!("TUN read loop spawned");
703            }
704        }
705
706        // Spawn periodic session cleanup task (remove expired/idle sessions and stop recordings)
707        {
708            let sessions = self.session_manager.clone();
709            let recorder = self.recording_manager.clone();
710            let socket = self.udp_socket.as_ref().unwrap().clone();
711            let mdh = self.mask_catalog.packet_mdh_bytes();
712            let neural = self.neural_module.clone();
713            tokio::spawn(async move {
714                loop {
715                    tokio::time::sleep(Duration::from_secs(5)).await;
716                    if let Some(ref rec) = recorder {
717                        let store = rec.store();
718                        for outcome in rec.take_ready_or_stale(
719                            aivpn_common::recording::RECORDING_IDLE_TIMEOUT_SECS,
720                        ) {
721                            let notify_session = match &outcome {
722                                RecordingStopOutcome::Completed(completed) => {
723                                    sessions.get_session(&completed.session_id)
724                                }
725                                RecordingStopOutcome::Incomplete(incomplete) => {
726                                    sessions.get_session(&incomplete.session_id)
727                                }
728                                RecordingStopOutcome::NotFound => None,
729                            };
730                            Self::handle_recording_outcome(
731                                &socket,
732                                &sessions,
733                                &store,
734                                &mdh,
735                                outcome,
736                                notify_session,
737                            )
738                            .await;
739                        }
740                    }
741
742                    let removed = sessions.cleanup_expired();
743                    for session_id in &removed {
744                        // Release per-session neural traffic stats; without this the
745                        // neural_module's DashMap grows unbounded as sessions expire.
746                        neural.lock().cleanup_stats(*session_id);
747                    }
748                    // Stop active recordings for removed sessions
749                    if let Some(ref rec) = recorder {
750                        let store = rec.store();
751                        for session_id in removed {
752                            let outcome = rec.stop_for_session_end(session_id);
753                            Self::handle_recording_outcome(
754                                &socket, &sessions, &store, &mdh, outcome, None,
755                            )
756                            .await;
757                        }
758                    }
759                }
760            });
761            info!("Session cleanup / recording auto-finish task spawned (5s interval)");
762        }
763
764        // Spawn client DB stats flush task (persist traffic stats every 5 min)
765        if let Some(ref db) = self.client_db {
766            let db = db.clone();
767            tokio::spawn(async move {
768                loop {
769                    tokio::time::sleep(Duration::from_secs(300)).await;
770                    db.flush_stats();
771                }
772            });
773            info!("Client stats flush task spawned (300s interval)");
774        }
775
776        // Spawn client DB hot-reload task (pick up new clients without restart)
777        if let Some(ref db) = self.client_db {
778            let db = db.clone();
779            tokio::spawn(async move {
780                loop {
781                    tokio::time::sleep(Duration::from_secs(10)).await;
782                    db.reload_if_changed();
783                }
784            });
785            info!("Client DB hot-reload task spawned (10s interval)");
786        }
787
788        // Use session-aware receive sharding: preserve ordering within one
789        // session, but allow different sessions to make progress in parallel.
790        let gateway = Arc::new(self);
791        Self::process_packets_concurrent(gateway).await?;
792
793        Ok(())
794    }
795
796    /// Background task: periodic neural resonance checks (Patent 1)
797    ///
798    /// For each active session, computes reconstruction error between
799    /// observed traffic features and the assigned mask's signature vector.
800    /// If MSE exceeds threshold → mask is detected as compromised by DPI.
801    /// Triggers automatic mask rotation (Patent 3).
802    async fn resonance_check_loop(
803        neural: Arc<parking_lot::Mutex<NeuralResonanceModule>>,
804        sessions: Arc<SessionManager>,
805        catalog: Arc<MaskCatalog>,
806        metrics: Arc<MetricsCollector>,
807        check_interval_secs: u64,
808        socket: Arc<UdpSocket>,
809    ) {
810        let interval = Duration::from_secs(check_interval_secs);
811
812        loop {
813            tokio::time::sleep(interval).await;
814
815            // Collect session IDs and their mask IDs
816            let session_checks: Vec<([u8; 16], String)> = sessions
817                .iter_sessions()
818                .filter_map(|entry| {
819                    let sess = entry.value().lock();
820                    let mask_id = sess
821                        .mask
822                        .as_ref()
823                        .map(|m| m.mask_id.clone())
824                        .unwrap_or_else(|| "unknown".to_string());
825                    Some((sess.session_id, mask_id))
826                })
827                .collect();
828
829            if session_checks.is_empty() {
830                continue;
831            }
832
833            // Collect mask update packets to send AFTER releasing the neural lock
834            // (parking_lot::MutexGuard is !Send, cannot hold across .await)
835            let mut pending_sends: Vec<(Vec<u8>, std::net::SocketAddr, [u8; 16], MaskProfile)> =
836                Vec::new();
837
838            {
839                let neural_guard = neural.lock();
840
841                for (session_id, mask_id) in &session_checks {
842                    // Check neural resonance (Patent 1: Signal Reconstruction Resonance)
843                    match neural_guard.check_resonance(*session_id, mask_id) {
844                        Ok(result) => {
845                            metrics
846                                .record_neural_check(result.status == ResonanceStatus::Compromised);
847
848                            match result.status {
849                                ResonanceStatus::Compromised => {
850                                    warn!(
851                                        "Mask '{}' compromised (MSE={:.4}) — triggering rotation (Patent 3)",
852                                        mask_id, result.mse
853                                    );
854
855                                    // Mark mask as compromised in catalog
856                                    catalog.mark_compromised(mask_id);
857
858                                    // Select fallback mask
859                                    if let Some(new_mask) = catalog.select_fallback(mask_id) {
860                                        info!(
861                                            "Auto-rotating to mask '{}' ({} masks remaining)",
862                                            new_mask.mask_id,
863                                            catalog.available_count()
864                                        );
865
866                                        if let Some(session) = sessions.get_session(session_id) {
867                                            let client_addr = session.lock().client_addr;
868                                            match sessions
869                                                .build_mask_update_packet(&session, &new_mask)
870                                            {
871                                                Ok(packet) => {
872                                                    pending_sends.push((
873                                                        packet,
874                                                        client_addr,
875                                                        *session_id,
876                                                        new_mask.clone(),
877                                                    ));
878                                                }
879                                                Err(e) => {
880                                                    warn!(
881                                                        "Failed to build MaskUpdate packet: {}",
882                                                        e
883                                                    );
884                                                }
885                                            }
886                                        }
887
888                                        metrics.record_mask_rotation();
889                                    } else {
890                                        error!(
891                                            "No fallback masks available! All masks compromised."
892                                        );
893                                    }
894                                }
895                                ResonanceStatus::Warning => {
896                                    debug!(
897                                        "Mask '{}' warning (MSE={:.4}) — monitoring",
898                                        mask_id, result.mse
899                                    );
900                                }
901                                ResonanceStatus::Healthy => {
902                                    // All good
903                                }
904                                ResonanceStatus::Skip => {
905                                    // Not enough data or model not loaded
906                                }
907                            }
908                        }
909                        Err(e) => {
910                            debug!("Resonance check error for session: {}", e);
911                        }
912                    }
913
914                    // Check anomaly detection (DPI blocking indicators)
915                    if neural_guard.is_mask_anomalous(mask_id) {
916                        warn!(
917                            "Anomaly detected for mask '{}' (packet loss / RTT spike)",
918                            mask_id
919                        );
920                        metrics.record_dpi_attack();
921                        catalog.mark_compromised(mask_id);
922
923                        if let Some(new_mask) = catalog.select_fallback(mask_id) {
924                            info!("Anomaly-triggered rotation to mask '{}'", new_mask.mask_id);
925                            if let Some(session) = sessions.get_session(session_id) {
926                                let client_addr = session.lock().client_addr;
927                                if let Ok(packet) =
928                                    sessions.build_mask_update_packet(&session, &new_mask)
929                                {
930                                    pending_sends.push((
931                                        packet,
932                                        client_addr,
933                                        *session_id,
934                                        new_mask.clone(),
935                                    ));
936                                }
937                            }
938                            metrics.record_mask_rotation();
939                        }
940                    }
941                }
942            } // neural_guard dropped here
943
944            // Send collected MaskUpdate packets (async, safe now)
945            for (packet, client_addr, session_id, new_mask) in pending_sends {
946                if let Err(e) = socket.send_to(&packet, client_addr).await {
947                    warn!("Failed to send MaskUpdate to {}: {}", client_addr, e);
948                } else {
949                    sessions.update_session_mask(&session_id, new_mask);
950                    info!("MaskUpdate control message sent to {}", client_addr);
951                }
952            }
953        }
954    }
955
956    /// TUN read loop: reads packets from TUN device and routes them back to clients
957    async fn tun_read_loop(
958        mut tun_reader: tun::DeviceReader,
959        tun_writer: tokio::sync::mpsc::Sender<Vec<u8>>,
960        sessions: Arc<SessionManager>,
961        socket: Arc<UdpSocket>,
962        mask: MaskProfile,
963        server_vpn_ip: Ipv4Addr,
964        recorder: Option<Arc<RecordingManager>>,
965        client_db: Option<Arc<ClientDatabase>>,
966    ) {
967        let mut buf = vec![0u8; MAX_PACKET_SIZE];
968        let server_ip = server_vpn_ip;
969
970        loop {
971            match tun_reader.read(&mut buf).await {
972                Ok(0) => continue,
973                Ok(n) => {
974                    let packet = &buf[..n];
975
976                    // Parse destination IP from IP header
977                    if packet.len() < 20 || (packet[0] >> 4) != 4 {
978                        continue; // Not IPv4
979                    }
980                    let dst_ip = Ipv4Addr::new(packet[16], packet[17], packet[18], packet[19]);
981
982                    // Handle ICMP echo request to server's own IP (ping to gateway)
983                    if dst_ip == server_ip && packet.len() >= 28 && packet[9] == 1 {
984                        // ICMP packet to server — generate echo reply
985                        if let Some(reply) = Self::build_icmp_echo_reply(packet, &server_ip) {
986                            let _ = tun_writer.send(reply).await;
987                        }
988                        continue;
989                    }
990
991                    // Find session by VPN IP
992                    let session = match sessions.get_session_by_vpn_ip(&dst_ip) {
993                        Some(s) => s,
994                        None => {
995                            debug!("TUN: no session for VPN IP {}", dst_ip);
996                            continue;
997                        }
998                    };
999
1000                    // Build encrypted response packet
1001                    // Minimize lock duration: extract only what we need under lock, then encrypt outside
1002                    let (session_id, client_addr, downlink_iat_ms, tag, mdh, ciphertext) = {
1003                        let mut sess = session.lock();
1004                        // Commit deferred mask switch if grace period has elapsed
1005                        sess.commit_pending_mask();
1006                        let session_id = sess.session_id;
1007                        let client_addr = sess.client_addr;
1008                        let seq_num = sess.next_seq() as u16;
1009                        let (nonce, counter) = sess.next_send_nonce();
1010                        let key = sess.keys.session_key.clone();
1011                        let tag_secret = sess.keys.tag_secret;
1012                        let downlink_iat_ms =
1013                            sess.last_server_send.elapsed().as_secs_f64() * 1000.0;
1014                        sess.last_server_send = Instant::now();
1015                        // Use the session's own mask for MDH so the client can
1016                        // decode with the mask it currently expects (bootstrap
1017                        // or runtime after MaskUpdate is processed).
1018                        let session_mdh = sess
1019                            .mask
1020                            .as_ref()
1021                            .map(packet_mdh_bytes_for_mask)
1022                            .unwrap_or_else(|| mask.header_template.clone());
1023                        // Pre-accumulate downlink bytes estimate (IP packet + overhead)
1024                        // This avoids a second lock after send_to
1025                        let estimated_out = (n + 64) as u64; // packet + AIVPN overhead
1026                        sess.pending_bytes_out =
1027                            sess.pending_bytes_out.saturating_add(estimated_out);
1028                        // Flush downlink-only traffic to client_db when threshold reached
1029                        let flush_out = if sess.pending_bytes_out >= 64 * 1024 {
1030                            let bytes = sess.pending_bytes_out;
1031                            let cid = sess.client_id.clone();
1032                            sess.pending_bytes_out = 0;
1033                            cid.map(|c| (c, bytes))
1034                        } else {
1035                            None
1036                        };
1037                        drop(sess); // Release lock BEFORE expensive encryption
1038                                    // Flush outside lock
1039                        if let (Some(ref db), Some((cid, bytes))) = (&client_db, flush_out) {
1040                            db.record_traffic(&cid, 0, bytes);
1041                        }
1042
1043                        // Build inner payload: Data type + IP packet
1044                        let inner_header = InnerHeader {
1045                            inner_type: InnerType::Data,
1046                            seq_num,
1047                        };
1048                        let mut inner_payload = inner_header.encode().to_vec();
1049                        inner_payload.extend_from_slice(packet);
1050
1051                        // Build MDH using session mask (not global runtime mask)
1052                        let mdh = session_mdh;
1053
1054                        // Pad and encrypt (outside lock)
1055                        let pad_len: u16 = 0;
1056                        let mut padded = Vec::with_capacity(2 + inner_payload.len());
1057                        padded.extend_from_slice(&pad_len.to_le_bytes());
1058                        padded.extend_from_slice(&inner_payload);
1059
1060                        let ciphertext = match encrypt_payload(&key, &nonce, &padded) {
1061                            Ok(ct) => ct,
1062                            Err(e) => {
1063                                debug!("TUN: encrypt error: {}", e);
1064                                continue;
1065                            }
1066                        };
1067
1068                        // Generate tag (outside lock)
1069                        let time_window = crypto::compute_time_window(
1070                            crypto::current_timestamp_ms(),
1071                            aivpn_common::crypto::DEFAULT_WINDOW_MS,
1072                        );
1073                        let tag = crypto::generate_resonance_tag(&tag_secret, counter, time_window);
1074
1075                        (
1076                            session_id,
1077                            client_addr,
1078                            downlink_iat_ms,
1079                            tag,
1080                            mdh,
1081                            ciphertext,
1082                        )
1083                    };
1084
1085                    // Assemble: TAG | MDH | ciphertext
1086                    let mut aivpn_packet =
1087                        Vec::with_capacity(TAG_SIZE + mdh.len() + ciphertext.len());
1088                    aivpn_packet.extend_from_slice(&tag);
1089                    aivpn_packet.extend_from_slice(&mdh);
1090                    aivpn_packet.extend_from_slice(&ciphertext);
1091
1092                    // Send to client
1093                    if let Err(e) = socket.send_to(&aivpn_packet, client_addr).await {
1094                        debug!("TUN: send failed: {}", e);
1095                    } else {
1096                        // bytes_out already tracked inside the earlier lock scope
1097                        if let Some(ref recorder) = recorder {
1098                            if recorder.is_recording(&session_id) {
1099                                let meta = aivpn_common::recording::PacketMetadata {
1100                                    direction: aivpn_common::recording::Direction::Downlink,
1101                                    size: aivpn_packet.len() as u16,
1102                                    iat_ms: downlink_iat_ms,
1103                                    entropy: Self::compute_entropy(&ciphertext) as f32,
1104                                    header_prefix: aivpn_packet[TAG_SIZE
1105                                        ..TAG_SIZE + 16.min(aivpn_packet.len() - TAG_SIZE)]
1106                                        .to_vec(),
1107                                    timestamp_ns: std::time::SystemTime::now()
1108                                        .duration_since(std::time::UNIX_EPOCH)
1109                                        .unwrap_or_default()
1110                                        .as_nanos()
1111                                        as u64,
1112                                };
1113                                recorder.record_packet(session_id, meta);
1114                            }
1115                        }
1116                    }
1117                }
1118                Err(e) => {
1119                    error!("TUN read error: {}", e);
1120                    tokio::time::sleep(Duration::from_millis(10)).await;
1121                }
1122            }
1123        }
1124    }
1125
1126    /// Build ICMP Echo Reply from Echo Request
1127    fn build_icmp_echo_reply(request: &[u8], server_ip: &Ipv4Addr) -> Option<Vec<u8>> {
1128        if request.len() < 28 {
1129            return None;
1130        }
1131
1132        // Parse source IP
1133        let src_ip = Ipv4Addr::new(request[12], request[13], request[14], request[15]);
1134
1135        // Parse ICMP type and code
1136        let icmp_type = request[20];
1137        if icmp_type != 8 {
1138            return None; // Not echo request
1139        }
1140
1141        // Build reply: swap src/dst IP, change ICMP type to 0 (echo reply)
1142        let mut reply = Vec::with_capacity(request.len());
1143
1144        // IP header
1145        reply.push(0x45); // Version 4, IHL 5
1146        reply.push(0x00); // DSCP/ECN
1147        let total_len = (request.len() as u16).to_be_bytes();
1148        reply.extend_from_slice(&total_len);
1149        reply.extend_from_slice(&request[4..6]); // Identification
1150        reply.extend_from_slice(&request[6..8]); // Flags/Fragment
1151        reply.push(64); // TTL
1152        reply.push(1); // Protocol: ICMP
1153        reply.push(0); // Header checksum (will be computed by kernel)
1154        reply.push(0);
1155        reply.extend_from_slice(&server_ip.octets()); // Source IP (server)
1156        reply.extend_from_slice(&src_ip.octets()); // Dest IP (client)
1157
1158        // ICMP header
1159        reply.push(0); // Type: Echo Reply
1160        reply.push(request[21]); // Code
1161        reply.push(0); // Checksum placeholder
1162        reply.push(0);
1163        reply.extend_from_slice(&request[24..28]); // ID + Sequence
1164        reply.extend_from_slice(&request[28..]); // Data
1165
1166        // Compute ICMP checksum
1167        let checksum = Self::compute_checksum(&reply[20..]);
1168        reply[22] = (checksum >> 8) as u8;
1169        reply[23] = (checksum & 0xFF) as u8;
1170
1171        Some(reply)
1172    }
1173
1174    /// Compute Internet checksum (RFC 1071)
1175    fn compute_checksum(data: &[u8]) -> u16 {
1176        let mut sum: u32 = 0;
1177        let mut i = 0;
1178
1179        // Process 16-bit words
1180        while i + 1 < data.len() {
1181            sum += u16::from_be_bytes([data[i], data[i + 1]]) as u32;
1182            i += 2;
1183        }
1184
1185        // Add remaining byte
1186        if i < data.len() {
1187            sum += (data[i] as u32) << 8;
1188        }
1189
1190        // Fold 32-bit sum to 16 bits
1191        while (sum >> 16) != 0 {
1192            sum = (sum & 0xFFFF) + (sum >> 16);
1193        }
1194
1195        !sum as u16
1196    }
1197
1198    /// Dedicated TUN writer task — owns the DeviceWriter, no Mutex contention
1199    async fn tun_write_loop(mut writer: tun::DeviceWriter, mut rx: mpsc::Receiver<Vec<u8>>) {
1200        while let Some(packet) = rx.recv().await {
1201            if let Err(e) = writer.write_all(&packet).await {
1202                error!("TUN write error: {}", e);
1203            }
1204            // No flush() — let the OS buffer writes for throughput
1205        }
1206        warn!("TUN write loop ended — channel closed");
1207    }
1208
1209    fn receive_worker_count() -> usize {
1210        std::thread::available_parallelism()
1211            .map(|count| count.get())
1212            .unwrap_or(4)
1213            .clamp(2, 16)
1214    }
1215
1216    fn worker_index_for_packet(
1217        &self,
1218        packet_data: &[u8],
1219        client_addr: SocketAddr,
1220        worker_count: usize,
1221    ) -> usize {
1222        if worker_count <= 1 {
1223            return 0;
1224        }
1225
1226        let mut shard_addr = client_addr;
1227
1228        if packet_data.len() >= TAG_SIZE {
1229            let mut tag = [0u8; TAG_SIZE];
1230            tag.copy_from_slice(&packet_data[..TAG_SIZE]);
1231
1232            if let Some(session) = self.session_manager.get_session_by_tag(&tag) {
1233                shard_addr = session.lock().client_addr;
1234            }
1235        }
1236
1237        let key = match shard_addr.ip() {
1238            IpAddr::V4(ip) => ((u32::from(ip) as u64) << 16) | shard_addr.port() as u64,
1239            IpAddr::V6(ip) => {
1240                let octets = ip.octets();
1241                u64::from_le_bytes(octets[..8].try_into().unwrap()) ^ shard_addr.port() as u64
1242            }
1243        };
1244
1245        (key as usize) % worker_count
1246    }
1247
1248    /// Concurrent packet processing loop with shard workers.
1249    /// Packets for the same session stay on the same worker and preserve order,
1250    /// while different sessions can be processed in parallel.
1251    async fn process_packets_concurrent(gateway: Arc<Self>) -> Result<()> {
1252        let socket = gateway.udp_socket.as_ref().unwrap().clone();
1253        let mut buf = vec![0u8; MAX_PACKET_SIZE];
1254        let worker_count = Self::receive_worker_count();
1255        let queue_depth = 4096;
1256        let mut worker_txs = Vec::with_capacity(worker_count);
1257
1258        for worker_id in 0..worker_count {
1259            let (tx, mut rx) = mpsc::channel::<QueuedPacket>(queue_depth);
1260            worker_txs.push(tx);
1261
1262            let gw = gateway.clone();
1263            tokio::spawn(async move {
1264                while let Some(packet) = rx.recv().await {
1265                    if let Err(e) = gw
1266                        .handle_packet(&packet.packet_data, packet.client_addr)
1267                        .await
1268                    {
1269                        debug!(
1270                            "Worker {} packet error from {}: {}",
1271                            worker_id,
1272                            hash_addr(&packet.client_addr),
1273                            e
1274                        );
1275                    }
1276                }
1277                warn!("Receive worker {} ended — channel closed", worker_id);
1278            });
1279        }
1280
1281        loop {
1282            match socket.recv_from(&mut buf).await {
1283                Ok((len, client_addr)) => {
1284                    // Per-IP rate limiting (fast, stays in recv task)
1285                    {
1286                        let now = Instant::now();
1287                        let mut entry = gateway
1288                            .rate_limits
1289                            .entry(client_addr.ip())
1290                            .or_insert((0, now));
1291                        if entry.1.elapsed() > Duration::from_secs(1) {
1292                            entry.0 = 0;
1293                            entry.1 = now;
1294                        }
1295                        entry.0 += 1;
1296                        if entry.0 > gateway.config.per_ip_pps_limit {
1297                            continue;
1298                        }
1299                    }
1300
1301                    let packet_data = buf[..len].to_vec();
1302                    let worker_idx =
1303                        gateway.worker_index_for_packet(&packet_data, client_addr, worker_count);
1304                    let packet = QueuedPacket {
1305                        packet_data,
1306                        client_addr,
1307                    };
1308
1309                    if worker_txs[worker_idx].send(packet).await.is_err() {
1310                        return Err(Error::Channel(format!(
1311                            "Receive worker {worker_idx} channel closed"
1312                        )));
1313                    }
1314                }
1315                Err(e) => {
1316                    error!("UDP recv error: {}", e);
1317                    tokio::time::sleep(Duration::from_millis(10)).await;
1318                }
1319            }
1320        }
1321    }
1322
1323    /// Main packet processing loop (legacy sequential — unused, kept for reference)
1324    #[allow(dead_code)]
1325    async fn process_packets(&self) -> Result<()> {
1326        let socket = self.udp_socket.as_ref().unwrap();
1327        let mut buf = vec![0u8; MAX_PACKET_SIZE];
1328
1329        loop {
1330            match socket.recv_from(&mut buf).await {
1331                Ok((len, client_addr)) => {
1332                    // Per-IP rate limiting.
1333                    {
1334                        let now = Instant::now();
1335                        let mut entry =
1336                            self.rate_limits.entry(client_addr.ip()).or_insert((0, now));
1337                        if entry.1.elapsed() > Duration::from_secs(1) {
1338                            entry.0 = 0;
1339                            entry.1 = now;
1340                        }
1341                        entry.0 += 1;
1342                        if entry.0 > self.config.per_ip_pps_limit {
1343                            continue;
1344                        }
1345                    }
1346
1347                    let packet_data = &buf[..len];
1348
1349                    // Process packet
1350                    if let Err(e) = self.handle_packet(packet_data, client_addr).await {
1351                        debug!("Packet error from {}: {}", hash_addr(&client_addr), e);
1352                        // Silent drop - no response for invalid packets
1353                    }
1354                }
1355                Err(e) => {
1356                    error!("UDP recv error: {}", e);
1357                    tokio::time::sleep(Duration::from_millis(10)).await;
1358                }
1359            }
1360        }
1361    }
1362
1363    /// Handle incoming packet
1364    async fn handle_packet(&self, packet_data: &[u8], client_addr: SocketAddr) -> Result<()> {
1365        // Minimum packet size check
1366        if packet_data.len() < TAG_SIZE + 2 {
1367            return Err(Error::InvalidPacket("Too short"));
1368        }
1369
1370        // Extract resonance tag
1371        let mut tag = [0u8; TAG_SIZE];
1372        tag.copy_from_slice(&packet_data[0..TAG_SIZE]);
1373
1374        // Default layout from runtime primary mask (used for handshake fallback).
1375        let (catalog_mdh_len, catalog_hs_mdh_len, _eph_offset, _eph_len) =
1376            self.mask_catalog.packet_layout();
1377        let mut is_new_session = false;
1378        let (session, counter, is_ratcheted_tag) = if let Some(session) =
1379            self.session_manager.get_session_by_tag(&tag)
1380        {
1381            // Existing session — validate tag
1382            let (counter, is_ratcheted) = {
1383                let sess = session.lock();
1384                sess.validate_tag(&tag)
1385                    .ok_or(Error::InvalidPacket("Invalid tag"))?
1386            };
1387            (session, counter, is_ratcheted)
1388        } else if let Some((session, counter, is_ratcheted)) =
1389            self.session_manager.refresh_and_find_by_tag(&tag)
1390        {
1391            // Tag not in map — time window may have advanced. Refresh all sessions and retry.
1392            debug!(
1393                "Tag matched after refresh (counter={}, ratcheted={})",
1394                counter, is_ratcheted
1395            );
1396            (session, counter, is_ratcheted)
1397        } else if let Some((session, counter, is_ratcheted)) = self
1398            .session_manager
1399            .recover_session_by_tag(&tag, &client_addr.ip())
1400        {
1401            // Counter drift recovery — client counter was out of range but session keys match
1402            (session, counter, is_ratcheted)
1403        } else {
1404            // NOTE: We intentionally do NOT drop packets from the same public IP
1405            // on a different port. Multiple clients behind the same NAT must be
1406            // able to handshake independently (different PSKs → different sessions).
1407
1408            // FIX Issue #42: Skip handshake if this IP already has a fresh
1409            // ratcheted session on a different port (NAT rebind / stale packets).
1410            if self
1411                .session_manager
1412                .has_recent_ratcheted_session_on_other_endpoint(
1413                    &client_addr,
1414                    Duration::from_secs(30),
1415                )
1416            {
1417                return Err(Error::InvalidPacket("Active session exists on other port"));
1418            }
1419
1420            // Serialize concurrent handshakes from the same source IP.
1421            // When a client reconnects rapidly, multiple shard workers may receive
1422            // init packets on different source ports simultaneously and each enter
1423            // this branch before any session is registered in tag_map. Without
1424            // serialization both complete PSK-matching, create sessions for the same
1425            // VPN IP, and the last cleanup_old_sessions_for_vpn_ip call removes the
1426            // session the client already ratcheted to, causing aead::Error on all
1427            // subsequent data packets. try_lock_owned is non-blocking: if another
1428            // handshake is in progress for this IP we drop the packet silently;
1429            // the client retransmits naturally and hits the existing-session path.
1430            let _handshake_guard = {
1431                let lock = {
1432                    let entry = self
1433                        .handshake_locks
1434                        .entry(client_addr.ip())
1435                        .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())));
1436                    entry.value().clone()
1437                };
1438                match lock.try_lock_owned() {
1439                    Ok(guard) => guard,
1440                    Err(_) => return Ok(()),
1441                }
1442            };
1443
1444            // Guard against session pool exhaustion: the handshake path calls
1445            // create_session() speculatively for every (client × bootstrap_mask)
1446            // combination before tag validation confirms which one is correct.
1447            // An attacker spoofing many source IPs can fill the pool with temporary
1448            // sessions and block legitimate clients. Reserve 10 slots so ratchet
1449            // renewals for existing sessions always have capacity.
1450            if self.session_manager.session_count() + 10 >= MAX_SESSIONS {
1451                debug!("Session pool near capacity ({}/{}), dropping unauthenticated handshake from {}",
1452                    self.session_manager.session_count(), MAX_SESSIONS, hash_addr(&client_addr));
1453                return Ok(());
1454            }
1455
1456            // No session found — try handshake
1457            // Rate-limit failed handshake attempts to prevent rapid session-creation loops.
1458            // After mask rotation or session timeout, stale clients may flood the server
1459            // with packets that consistently fail tag validation (issue #21, #42).
1460            {
1461                let ip = client_addr.ip();
1462                if let Some(entry) = self.handshake_cooldowns.get(&ip) {
1463                    let (fail_count, last_fail) = *entry;
1464                    // Exponential cooldown: 2s → 4s → 8s → 16s (max)
1465                    let cooldown = Duration::from_millis((2000 * (1 << fail_count.min(3))) as u64);
1466                    if last_fail.elapsed() < cooldown {
1467                        debug!("Handshake cooldown active for {}: fail_count={}, elapsed={:?}, cooldown={:?}",
1468                            hash_addr(&client_addr), fail_count, last_fail.elapsed(), cooldown);
1469                        return Err(Error::InvalidPacket("Handshake cooldown active"));
1470                    }
1471                }
1472            }
1473
1474            // Try to establish a new session using one of the built-in bootstrap masks.
1475            // Runtime masks can be server-generated, but bootstrap must remain compatible
1476            // with clients that only know the shipped presets.
1477            // If client_db is configured, iterate registered clients and try
1478            // DH + PSK to find one whose derived tags match.
1479            // Falls back to no-PSK for backward compatibility.
1480            let builtin_bootstrap_masks = aivpn_common::mask::preset_masks::all();
1481            let (session, matched_client_id, bootstrap_mask) = if let Some(ref db) = self.client_db
1482            {
1483                let clients = db.list_clients();
1484                let mut found = None;
1485                'bootstrap: for client_cfg in &clients {
1486                    if !client_cfg.enabled {
1487                        continue;
1488                    }
1489
1490                    let psk = client_cfg.psk;
1491                    let candidate_masks = self
1492                        .bootstrap_descriptors
1493                        .iter()
1494                        .flat_map(|descriptor| derive_bootstrap_candidates(descriptor, Some(&psk)))
1495                        .chain(builtin_bootstrap_masks.clone().into_iter())
1496                        .collect::<Vec<_>>();
1497
1498                    for bootstrap_mask in candidate_masks {
1499                        let (
1500                            _,
1501                            candidate_handshake_mdh_len,
1502                            candidate_eph_offset,
1503                            candidate_eph_len,
1504                        ) = packet_layout_for_mask(&bootstrap_mask);
1505                        if packet_data.len() < TAG_SIZE + candidate_handshake_mdh_len {
1506                            continue;
1507                        }
1508                        let eph_start = TAG_SIZE + candidate_eph_offset;
1509                        if packet_data.len() < eph_start + candidate_eph_len {
1510                            continue;
1511                        }
1512
1513                        let mut eph_pub = [0u8; 32];
1514                        eph_pub.copy_from_slice(
1515                            &packet_data[eph_start..eph_start + candidate_eph_len],
1516                        );
1517                        crypto::obfuscate_eph_pub(
1518                            &mut eph_pub,
1519                            &self.session_manager.server_public_key(),
1520                        );
1521
1522                        match self.session_manager.create_session(
1523                            client_addr,
1524                            eph_pub,
1525                            Some(psk),
1526                            Some(client_cfg.vpn_ip),
1527                        ) {
1528                            Ok(sess) => {
1529                                let validation = sess.lock().validate_tag(&tag);
1530                                if validation.is_some() {
1531                                    debug!(
1532                                        "Tag validation SUCCESS for client {} via bootstrap mask {}",
1533                                        client_cfg.id,
1534                                        bootstrap_mask.mask_id
1535                                    );
1536                                    found =
1537                                        Some((sess, Some(client_cfg.id.clone()), bootstrap_mask));
1538                                    break 'bootstrap;
1539                                }
1540                                let sid = sess.lock().session_id;
1541                                self.session_manager.rollback_failed_session(&sid);
1542                            }
1543                            Err(e) => {
1544                                debug!("create_session failed: {}", e);
1545                                continue;
1546                            }
1547                        }
1548                    }
1549                }
1550                match found {
1551                    Some(f) => f,
1552                    None => {
1553                        // Track failed handshake for cooldown
1554                        let ip = client_addr.ip();
1555                        let fail_count =
1556                            self.handshake_cooldowns.get(&ip).map(|e| e.0).unwrap_or(0);
1557                        self.handshake_cooldowns
1558                            .insert(ip, (fail_count + 1, Instant::now()));
1559                        warn!(
1560                            "Handshake failed for {} (attempt #{}) — tag mismatch for all {} registered clients",
1561                            hash_addr(&client_addr),
1562                            fail_count + 1,
1563                            clients.len()
1564                        );
1565                        return Err(Error::InvalidPacket(
1566                            "No registered client matches this handshake",
1567                        ));
1568                    }
1569                }
1570            } else {
1571                // No client DB — legacy mode without PSK
1572                let mut found = None;
1573                let candidate_masks = self
1574                    .bootstrap_descriptors
1575                    .iter()
1576                    .flat_map(|descriptor| derive_bootstrap_candidates(descriptor, None))
1577                    .chain(builtin_bootstrap_masks.clone().into_iter())
1578                    .collect::<Vec<_>>();
1579                for bootstrap_mask in candidate_masks {
1580                    let (_, candidate_handshake_mdh_len, candidate_eph_offset, candidate_eph_len) =
1581                        packet_layout_for_mask(&bootstrap_mask);
1582                    if packet_data.len() < TAG_SIZE + candidate_handshake_mdh_len {
1583                        continue;
1584                    }
1585                    let eph_start = TAG_SIZE + candidate_eph_offset;
1586                    if packet_data.len() < eph_start + candidate_eph_len {
1587                        continue;
1588                    }
1589
1590                    let mut eph_pub = [0u8; 32];
1591                    eph_pub.copy_from_slice(&packet_data[eph_start..eph_start + candidate_eph_len]);
1592                    crypto::obfuscate_eph_pub(
1593                        &mut eph_pub,
1594                        &self.session_manager.server_public_key(),
1595                    );
1596
1597                    let sess =
1598                        self.session_manager
1599                            .create_session(client_addr, eph_pub, None, None)?;
1600                    let validation = sess.lock().validate_tag(&tag);
1601                    if validation.is_some() {
1602                        found = Some((sess, None, bootstrap_mask));
1603                        break;
1604                    }
1605                    let sid = sess.lock().session_id;
1606                    self.session_manager.rollback_failed_session(&sid);
1607                }
1608
1609                found.ok_or_else(|| {
1610                    Error::InvalidPacket("No bootstrap mask matched this handshake")
1611                })?
1612            };
1613
1614            // Validate the tag against the session.
1615            let validation = {
1616                let sess = session.lock();
1617                sess.validate_tag(&tag)
1618            };
1619            let (counter, is_ratcheted) = match validation {
1620                Some(result) => result,
1621                None => {
1622                    let session_id = session.lock().session_id;
1623                    self.session_manager.rollback_failed_session(&session_id);
1624                    return Err(Error::InvalidPacket("Tag mismatch on new session"));
1625                }
1626            };
1627
1628            // Tag is valid — this is a real handshake.
1629            // Clean up old sessions for the SAME CLIENT (by VPN IP), not
1630            // all sessions from this source IP — different clients behind
1631            // the same NAT must coexist.
1632            {
1633                let (session_id, vpn_ip) = {
1634                    let sess_lock = session.lock();
1635                    (sess_lock.session_id, sess_lock.vpn_ip)
1636                };
1637                if let Some(vpn_ip) = vpn_ip {
1638                    let removed = self
1639                        .session_manager
1640                        .cleanup_old_sessions_for_vpn_ip(&vpn_ip, &session_id);
1641                    // Stop active recordings for removed stale sessions
1642                    if let Some(ref recorder) = self.recording_manager {
1643                        let socket = self.udp_socket.as_ref().unwrap().clone();
1644                        let store = recorder.store();
1645                        let mdh = self.mask_catalog.packet_mdh_bytes();
1646                        for sid in removed {
1647                            let outcome = recorder.stop_for_session_end(sid);
1648                            Self::handle_recording_outcome(
1649                                &socket,
1650                                &self.session_manager,
1651                                &store,
1652                                &mdh,
1653                                outcome,
1654                                None,
1655                            )
1656                            .await;
1657                        }
1658                    }
1659                }
1660            }
1661
1662            // Successful handshake — clear cooldown for this IP
1663            self.handshake_cooldowns.remove(&client_addr.ip());
1664
1665            {
1666                let mut sess = session.lock();
1667                sess.mask = Some(bootstrap_mask.clone());
1668            }
1669
1670            // Record handshake in client DB
1671            if let (Some(ref db), Some(ref cid)) = (&self.client_db, &matched_client_id) {
1672                db.record_handshake(cid);
1673                // Store client_id in session for traffic accounting
1674                session.lock().client_id = Some(cid.clone());
1675                debug!("Client '{}' authenticated via PSK", cid);
1676            }
1677
1678            self.send_server_hello(&session, client_addr).await?;
1679            self.send_bootstrap_descriptors(&session).await?;
1680
1681            if let Some(runtime_mask) = self.mask_catalog.primary_mask() {
1682                if runtime_mask.mask_id != bootstrap_mask.mask_id {
1683                    match self
1684                        .session_manager
1685                        .build_mask_update_packet(&session, &runtime_mask)
1686                    {
1687                        Ok(packet) => {
1688                            self.udp_socket
1689                                .as_ref()
1690                                .unwrap()
1691                                .send_to(&packet, client_addr)
1692                                .await?;
1693                            // NOTE: Do NOT call update_session_mask here.
1694                            // The client still sends packets with bootstrap mask layout
1695                            // until it processes MaskUpdate. Keep sess.mask = bootstrap
1696                            // so per-session decryption uses bootstrap layout, with
1697                            // catalog (runtime) layout as fallback for transition.
1698                        }
1699                        Err(e) => {
1700                            warn!("Failed to send initial runtime MaskUpdate: {}", e);
1701                        }
1702                    }
1703                }
1704            }
1705
1706            // NOTE: PFS ratchet is deferred until AFTER decrypting the init packet,
1707            // which was encrypted with pre-ratchet keys.
1708
1709            is_new_session = true;
1710            debug!(
1711                "New session from {} (ServerHello sent)",
1712                hash_addr(&client_addr)
1713            );
1714            (session, counter, is_ratcheted)
1715        };
1716
1717        // Parse packet — pad_len is inside encrypted area (CRIT-5 fix).
1718        // Use the session's own mask layout for decryption. This is critical
1719        // because the client may still be using its bootstrap mask before
1720        // receiving and applying a MaskUpdate from the server.
1721        // We try both the session mask layout AND the catalog (runtime) layout
1722        // to handle the transition window.
1723        let (session_mdh_len, session_hs_mdh_len) = {
1724            let sess = session.lock();
1725            if let Some(ref mask) = sess.mask {
1726                let (p, h, _, _) = packet_layout_for_mask(mask);
1727                (p, h)
1728            } else {
1729                (catalog_mdh_len, catalog_hs_mdh_len)
1730            }
1731        };
1732        let packet_mdh_len = session_mdh_len;
1733        let handshake_mdh_len = session_hs_mdh_len;
1734        // Android retransmits the initial handshake packet with the client
1735        // eph_pub still embedded inside the MDH. Once a session already exists,
1736        // those retries validate against the existing tag window, so the
1737        // ciphertext still starts immediately after the full MDH.
1738        let is_pre_ratchet_retry = !is_new_session && !is_ratcheted_tag && {
1739            let sess = session.lock();
1740            !sess.is_ratcheted && packet_data.len() >= TAG_SIZE + handshake_mdh_len + 16
1741        };
1742        let mut payload_offsets: Vec<usize> = if is_new_session {
1743            vec![TAG_SIZE + handshake_mdh_len]
1744        } else if is_pre_ratchet_retry && handshake_mdh_len != packet_mdh_len {
1745            vec![TAG_SIZE + packet_mdh_len, TAG_SIZE + handshake_mdh_len]
1746        } else {
1747            vec![TAG_SIZE + packet_mdh_len]
1748        };
1749        // During mask transition (bootstrap → runtime), also try the catalog
1750        // (runtime) layout in case the client already applied MaskUpdate.
1751        if catalog_mdh_len != session_mdh_len {
1752            let catalog_offset = TAG_SIZE + catalog_mdh_len;
1753            if !payload_offsets.contains(&catalog_offset) {
1754                payload_offsets.push(catalog_offset);
1755            }
1756        }
1757
1758        let (payload_offset, padded_plaintext) = {
1759            let sess = session.lock();
1760            let nonce = self.compute_nonce(counter);
1761            // For new sessions, always use initial keys for decryption since the
1762            // client hasn't received ServerHello yet and is still sending with
1763            // initial keys. Only use ratcheted keys when the client proves it
1764            // has switched by sending a ratcheted tag on an existing session.
1765            let key = if is_new_session {
1766                &sess.keys.session_key
1767            } else if is_ratcheted_tag {
1768                &sess
1769                    .ratcheted_keys
1770                    .as_ref()
1771                    .ok_or(Error::InvalidPacket("Ratcheted keys missing"))?
1772                    .session_key
1773            } else {
1774                &sess.keys.session_key
1775            };
1776
1777            let mut decrypted = None;
1778            let mut last_error = None;
1779            for payload_offset in payload_offsets {
1780                if packet_data.len() <= payload_offset {
1781                    continue;
1782                }
1783                let encrypted_payload = &packet_data[payload_offset..];
1784                match decrypt_payload(key, &nonce, encrypted_payload) {
1785                    Ok(padded_plaintext) => {
1786                        decrypted = Some((payload_offset, padded_plaintext));
1787                        break;
1788                    }
1789                    Err(err) => last_error = Some(err),
1790                }
1791            }
1792
1793            match decrypted {
1794                Some(result) => result,
1795                None => {
1796                    return Err(last_error.unwrap_or_else(|| Error::InvalidPacket("Invalid length")))
1797                }
1798            }
1799        };
1800        let encrypted_payload = &packet_data[payload_offset..];
1801
1802        // Complete PFS ratchet only when the CLIENT proves it has ratcheted
1803        // by sending a packet with ratcheted-key tags.
1804        // Do NOT ratchet on is_new_session — the client hasn't received
1805        // ServerHello yet and will keep sending packets with initial keys.
1806        if is_ratcheted_tag {
1807            let session_id = session.lock().session_id;
1808            self.session_manager.complete_session_ratchet(&session_id);
1809            self.session_manager.refresh_session_tags(&session_id);
1810            let sess = session.lock();
1811            info!(
1812                "PFS ratchet complete for {} — send_counter={}, counter={}",
1813                hash_addr(&client_addr),
1814                sess.send_counter,
1815                sess.counter
1816            );
1817        }
1818
1819        // Extract pad_len from inside decrypted data and strip padding
1820        if padded_plaintext.len() < 2 {
1821            return Err(Error::InvalidPacket("Decrypted payload too short"));
1822        }
1823        let pad_len = u16::from_le_bytes([padded_plaintext[0], padded_plaintext[1]]) as usize;
1824        if 2 + pad_len > padded_plaintext.len() {
1825            return Err(Error::InvalidPacket("Invalid padding length"));
1826        }
1827        let plaintext = &padded_plaintext[2..padded_plaintext.len() - pad_len];
1828
1829        // Update session state. Avoid expensive O(window) tag-map rebuild on every packet.
1830        let mut client_db_flush: Option<(String, u64, u64)> = None;
1831        let (session_id, refresh_tags) = {
1832            let mut sess = session.lock();
1833            sess.mark_tag_received(counter);
1834            sess.last_seen = std::time::Instant::now();
1835
1836            // IP migration: update stored client address when a validated packet
1837            // arrives from a different endpoint (e.g. WiFi → cellular switchover).
1838            // Safe because the packet passed full cryptographic validation.
1839            if !is_new_session && sess.client_addr != client_addr {
1840                info!(
1841                    "Client endpoint migrated: {} → {} (session keepalive active)",
1842                    hash_addr(&sess.client_addr),
1843                    hash_addr(&client_addr)
1844                );
1845                sess.client_addr = client_addr;
1846            }
1847
1848            // Refresh precomputed tag window only when we've moved far enough.
1849            // Window size is 256; refreshing every 64 packets keeps enough headroom
1850            // while reducing CPU spent in HashMap/tag_map maintenance.
1851            let refresh_tags = counter.saturating_sub(sess.tag_window_base) >= 64;
1852            if refresh_tags {
1853                sess.update_tag_window();
1854            }
1855
1856            // Batch client stats updates to avoid taking a global write lock per packet.
1857            sess.pending_bytes_in = sess
1858                .pending_bytes_in
1859                .saturating_add(packet_data.len() as u64);
1860            if sess.pending_bytes_in >= 16 * 1024 || sess.pending_bytes_out >= 16 * 1024 {
1861                if let Some(cid) = sess.client_id.clone() {
1862                    client_db_flush = Some((cid, sess.pending_bytes_in, sess.pending_bytes_out));
1863                }
1864                sess.pending_bytes_in = 0;
1865                sess.pending_bytes_out = 0;
1866            }
1867
1868            sess.update_fsm();
1869            (sess.session_id, refresh_tags)
1870        };
1871
1872        // Refresh tag_map only when the precomputed window moves.
1873        if refresh_tags {
1874            self.session_manager.refresh_session_tags(&session_id);
1875        }
1876
1877        // Record traffic stats for neural resonance (Patent 1)
1878        if self.config.enable_neural {
1879            let packet_size = packet_data.len() as u16;
1880            // Compute byte-level entropy of the encrypted payload
1881            let entropy = Self::compute_entropy(encrypted_payload);
1882            // Compute real IAT from session's last_seen timestamp
1883            let iat_ms = {
1884                let sess = session.lock();
1885                let elapsed = sess.last_seen.elapsed();
1886                elapsed.as_secs_f64() * 1000.0
1887            };
1888            // Neural model update is expensive under lock. Sampling every 16th packet
1889            // preserves trends while reducing lock contention in the receive hot path.
1890            if counter & 0x0f == 0 {
1891                self.neural_module
1892                    .lock()
1893                    .record_traffic(session_id, packet_size, iat_ms, entropy);
1894            }
1895            self.metrics.record_packet_received(packet_data.len());
1896        }
1897
1898        // Record uplink packet metadata for auto mask recording
1899        if let Some(ref recorder) = self.recording_manager {
1900            let session_id = session.lock().session_id;
1901            if recorder.is_recording(&session_id) {
1902                let iat_ms = {
1903                    let sess = session.lock();
1904                    sess.last_seen.elapsed().as_secs_f64() * 1000.0
1905                };
1906                let meta = aivpn_common::recording::PacketMetadata {
1907                    direction: aivpn_common::recording::Direction::Uplink,
1908                    size: packet_data.len() as u16,
1909                    iat_ms,
1910                    entropy: Self::compute_entropy(encrypted_payload) as f32,
1911                    header_prefix: packet_data
1912                        [TAG_SIZE..TAG_SIZE + 16.min(packet_data.len() - TAG_SIZE)]
1913                        .to_vec(),
1914                    timestamp_ns: std::time::SystemTime::now()
1915                        .duration_since(std::time::UNIX_EPOCH)
1916                        .unwrap_or_default()
1917                        .as_nanos() as u64,
1918                };
1919                recorder.record_packet(session_id, meta);
1920            }
1921        }
1922
1923        // Record traffic in client DB in batches (see pending_bytes_in/out above).
1924        if let (Some(ref db), Some((cid, bytes_in, bytes_out))) = (&self.client_db, client_db_flush)
1925        {
1926            db.record_traffic(&cid, bytes_in, bytes_out);
1927        }
1928
1929        // Process inner payload (skip for new sessions — ServerHello is already the response,
1930        // and any ControlAck sent here would use pre-ratchet keys that the client can't validate)
1931        if !is_new_session {
1932            self.process_inner_payload(plaintext, &session, client_addr)
1933                .await?;
1934        }
1935
1936        Ok(())
1937    }
1938
1939    /// Compute nonce from counter
1940    fn compute_nonce(&self, counter: u64) -> [u8; NONCE_SIZE] {
1941        let mut nonce = [0u8; NONCE_SIZE];
1942        nonce[0..8].copy_from_slice(&counter.to_le_bytes());
1943        nonce
1944    }
1945
1946    /// Process decrypted inner payload
1947    async fn process_inner_payload(
1948        &self,
1949        plaintext: &[u8],
1950        session: &Arc<parking_lot::Mutex<Session>>,
1951        client_addr: SocketAddr,
1952    ) -> Result<()> {
1953        if plaintext.len() < 4 {
1954            return Err(Error::InvalidPacket("Inner payload too short"));
1955        }
1956
1957        let inner_header = InnerHeader::decode(plaintext)?;
1958        let payload = &plaintext[4..];
1959
1960        match inner_header.inner_type {
1961            InnerType::Data => {
1962                // Forward to NAT/internet via TUN write channel (lock-free)
1963                debug!(
1964                    "DATA packet from {} ({} bytes)",
1965                    hash_addr(&client_addr),
1966                    payload.len()
1967                );
1968
1969                if let Some(ref tx) = self.tun_write_tx {
1970                    if tx.send(payload.to_vec()).await.is_err() {
1971                        debug!("TUN write channel closed, dropping packet");
1972                    }
1973                } else if let Some(ref nat) = self.nat_forwarder {
1974                    nat.forward_packet(payload).await?;
1975                } else {
1976                    debug!("NAT disabled, dropping packet");
1977                }
1978            }
1979            InnerType::Control => {
1980                self.handle_control_message(payload, session, client_addr)
1981                    .await?;
1982            }
1983            InnerType::Fragment => {
1984                // TODO: Implement fragmentation
1985                debug!("FRAGMENT packet (not implemented)");
1986            }
1987            InnerType::Ack => {
1988                // Handle ACK
1989                debug!("ACK packet received");
1990            }
1991        }
1992
1993        Ok(())
1994    }
1995
1996    /// Handle control message
1997    async fn handle_control_message(
1998        &self,
1999        payload: &[u8],
2000        session: &Arc<parking_lot::Mutex<Session>>,
2001        client_addr: SocketAddr,
2002    ) -> Result<()> {
2003        let control = ControlPayload::decode(payload)?;
2004
2005        match control {
2006            ControlPayload::KeyRotate { new_eph_pub: _ } => {
2007                info!("Key rotation request from {}", hash_addr(&client_addr));
2008                // TODO: Implement key rotation
2009            }
2010            ControlPayload::MaskUpdate { .. } => {
2011                warn!("Unexpected MASK_UPDATE from client");
2012            }
2013            ControlPayload::Keepalive => {
2014                debug!("Keepalive from {}", hash_addr(&client_addr));
2015                if !session.lock().is_ratcheted {
2016                    // The client is still retrying the initial handshake. If the
2017                    // first ServerHello was lost, replying with a normal pre-ratchet
2018                    // ControlAck leaves the client stuck forever.
2019                    self.send_server_hello(session, client_addr).await?;
2020                    return Ok(());
2021                }
2022                // Send ACK
2023                let ack = ControlPayload::ControlAck {
2024                    ack_seq: 0,
2025                    ack_for_subtype: ControlSubtype::Keepalive as u8,
2026                };
2027                self.send_control_message(&ack, session).await?;
2028            }
2029            ControlPayload::TelemetryRequest { metric_flags: _ } => {
2030                debug!("Telemetry request from {}", hash_addr(&client_addr));
2031                // Send response
2032                let response = ControlPayload::TelemetryResponse {
2033                    packet_loss: 0,
2034                    rtt_ms: 10,
2035                    jitter_ms: 2,
2036                    buffer_pct: 25,
2037                };
2038                self.send_control_message(&response, session).await?;
2039            }
2040            ControlPayload::TelemetryResponse { .. } => {
2041                debug!("Telemetry response received");
2042            }
2043            ControlPayload::TimeSync { .. } => {
2044                debug!("Time sync request");
2045            }
2046            ControlPayload::Shutdown { reason } => {
2047                info!(
2048                    "Shutdown request from {} (reason: {})",
2049                    hash_addr(&client_addr),
2050                    reason
2051                );
2052                // Close session and stop active recording if any
2053                let session_id = session.lock().session_id;
2054                self.session_manager.remove_session(&session_id);
2055                self.neural_module.lock().cleanup_stats(session_id);
2056                if let Some(ref recorder) = self.recording_manager {
2057                    let socket = self.udp_socket.as_ref().unwrap().clone();
2058                    let store = recorder.store();
2059                    let mdh = self.mask_catalog.packet_mdh_bytes();
2060                    let outcome = recorder.stop_for_session_end(session_id);
2061                    Self::handle_recording_outcome(
2062                        &socket,
2063                        &self.session_manager,
2064                        &store,
2065                        &mdh,
2066                        outcome,
2067                        None,
2068                    )
2069                    .await;
2070                }
2071            }
2072            ControlPayload::ControlAck { .. } => {
2073                // ACK received, nothing to do
2074            }
2075            ControlPayload::ServerHello { .. } => {
2076                warn!(
2077                    "Unexpected ServerHello from client {}",
2078                    hash_addr(&client_addr)
2079                );
2080            }
2081            ControlPayload::RecordingStart { service } => {
2082                // Only allow from admin sessions (check client_id)
2083                let admin_key_id = {
2084                    let sess = session.lock();
2085                    sess.client_id.clone()
2086                };
2087                if !self.can_start_recording(admin_key_id.as_deref()) {
2088                    warn!(
2089                        "Recording rejected: unauthenticated client {}",
2090                        hash_addr(&client_addr)
2091                    );
2092                    let failed = ControlPayload::RecordingFailed {
2093                        reason: "Recording requires a recording-admin key".into(),
2094                    };
2095                    self.send_control_message(&failed, session).await?;
2096                    return Ok(());
2097                }
2098                if let Some(ref recorder) = self.recording_manager {
2099                    let session_id = session.lock().session_id;
2100                    recorder.start(
2101                        session_id,
2102                        service.clone(),
2103                        admin_key_id.unwrap_or_else(|| "admin".into()),
2104                    );
2105                    let ack = ControlPayload::RecordingAck {
2106                        session_id,
2107                        status: "started".into(),
2108                    };
2109                    self.send_control_message(&ack, session).await?;
2110                    info!(
2111                        "Recording started for '{}' from {}",
2112                        service,
2113                        hash_addr(&client_addr)
2114                    );
2115                }
2116            }
2117            ControlPayload::RecordingStop {
2118                session_id: rec_session_id,
2119            } => {
2120                if let Some(ref recorder) = self.recording_manager {
2121                    let owner_session_id = session.lock().session_id;
2122                    if rec_session_id != owner_session_id {
2123                        let failed = ControlPayload::RecordingFailed {
2124                            reason: "Recording session does not belong to this client".into(),
2125                        };
2126                        self.send_control_message(&failed, session).await?;
2127                        return Ok(());
2128                    }
2129
2130                    let socket = self.udp_socket.as_ref().unwrap().clone();
2131                    let store = recorder.store();
2132                    let mdh = self.mask_catalog.packet_mdh_bytes();
2133                    let outcome = recorder.stop(owner_session_id);
2134                    Self::handle_recording_outcome(
2135                        &socket,
2136                        &self.session_manager,
2137                        &store,
2138                        &mdh,
2139                        outcome,
2140                        Some(session.clone()),
2141                    )
2142                    .await;
2143                }
2144            }
2145            ControlPayload::RecordingStatusRequest => {
2146                let client_id = {
2147                    let sess = session.lock();
2148                    sess.client_id.clone()
2149                };
2150                let can_record = self.can_start_recording(client_id.as_deref());
2151                let active_service = self
2152                    .recording_manager
2153                    .as_ref()
2154                    .and_then(|recorder| recorder.status(&session.lock().session_id))
2155                    .map(|status| status.service);
2156                let response = ControlPayload::RecordingStatus {
2157                    can_record,
2158                    active_service,
2159                };
2160                self.send_control_message(&response, session).await?;
2161            }
2162            ControlPayload::RecordingAck { .. } => {
2163                // Client-side only, ignore on server
2164            }
2165            ControlPayload::RecordingComplete { .. } => {
2166                // Client-side only, ignore on server
2167            }
2168            ControlPayload::RecordingFailed { .. } => {
2169                // Client-side only, ignore on server
2170            }
2171            ControlPayload::RecordingStatus { .. } => {
2172                // Client-side only, ignore on server
2173            }
2174            ControlPayload::BootstrapDescriptorUpdate { .. } => {
2175                // Client-side only, ignore on server
2176            }
2177        }
2178
2179        Ok(())
2180    }
2181
2182    /// Send control message to client
2183    async fn send_control_message(
2184        &self,
2185        payload: &ControlPayload,
2186        session: &Arc<parking_lot::Mutex<Session>>,
2187    ) -> Result<()> {
2188        let socket = self.udp_socket.as_ref().unwrap();
2189        let mdh = {
2190            let mut sess = session.lock();
2191            sess.commit_pending_mask();
2192            sess.mask
2193                .as_ref()
2194                .map(packet_mdh_bytes_for_mask)
2195                .unwrap_or_else(|| self.mask_catalog.packet_mdh_bytes())
2196        };
2197        Self::send_control_message_via(socket, &mdh, payload, session).await
2198    }
2199
2200    async fn send_control_message_via(
2201        socket: &UdpSocket,
2202        mdh: &[u8],
2203        payload: &ControlPayload,
2204        session: &Arc<parking_lot::Mutex<Session>>,
2205    ) -> Result<()> {
2206        let encoded = payload.encode()?;
2207        let (mut inner_payload, nonce, counter, keys, client_addr) = {
2208            let mut sess = session.lock();
2209            let inner_header = InnerHeader {
2210                inner_type: InnerType::Control,
2211                seq_num: sess.next_seq() as u16,
2212            };
2213            let inner_payload = inner_header.encode().to_vec();
2214            let (nonce, counter) = sess.next_send_nonce();
2215            let keys = sess.keys.clone();
2216            let client_addr = sess.client_addr;
2217            (inner_payload, nonce, counter, keys, client_addr)
2218        };
2219        inner_payload.extend_from_slice(&encoded);
2220        let pad_len = 16u16;
2221        let mut padded = Vec::with_capacity(2 + inner_payload.len() + pad_len as usize);
2222        padded.extend_from_slice(&pad_len.to_le_bytes());
2223        padded.extend_from_slice(&inner_payload);
2224        {
2225            use rand::Rng;
2226            let mut rng = rand::thread_rng();
2227            for _ in 0..pad_len {
2228                padded.push(rng.gen::<u8>());
2229            }
2230        }
2231        let ciphertext = encrypt_payload(&keys.session_key, &nonce, &padded)?;
2232        let time_window = crypto::compute_time_window(
2233            crypto::current_timestamp_ms(),
2234            aivpn_common::crypto::DEFAULT_WINDOW_MS,
2235        );
2236        let tag = crypto::generate_resonance_tag(&keys.tag_secret, counter, time_window);
2237        let mut packet = Vec::with_capacity(TAG_SIZE + mdh.len() + ciphertext.len());
2238        packet.extend_from_slice(&tag);
2239        packet.extend_from_slice(mdh);
2240        packet.extend_from_slice(&ciphertext);
2241        socket.send_to(&packet, client_addr).await?;
2242        Ok(())
2243    }
2244
2245    async fn send_server_hello(
2246        &self,
2247        session: &Arc<parking_lot::Mutex<Session>>,
2248        client_addr: SocketAddr,
2249    ) -> Result<()> {
2250        let (server_eph_pub, signature, network_config) = {
2251            let sess = session.lock();
2252            match (sess.server_eph_pub, sess.server_hello_signature) {
2253                (Some(pub_key), Some(sig)) => {
2254                    let network_config = sess
2255                        .vpn_ip
2256                        .and_then(|vpn_ip| self.config.network_config.client_config(vpn_ip).ok());
2257                    (pub_key, sig, network_config)
2258                }
2259                _ => return Err(Error::Session("Missing ratchet data".into())),
2260            }
2261        };
2262
2263        let hello = ControlPayload::ServerHello {
2264            server_eph_pub,
2265            signature,
2266            network_config,
2267        };
2268        let encoded = hello.encode()?;
2269        let inner_header = InnerHeader {
2270            inner_type: InnerType::Control,
2271            seq_num: 0,
2272        };
2273        let mut inner_payload = inner_header.encode().to_vec();
2274        inner_payload.extend_from_slice(&encoded);
2275        let packet = self.build_packet(&inner_payload, session)?;
2276        let socket = self.udp_socket.as_ref().unwrap();
2277        let sent = socket.send_to(&packet, client_addr).await?;
2278        debug!("ServerHello sent: {} bytes to {}", sent, client_addr);
2279        Ok(())
2280    }
2281
2282    /// Build AIVPN packet
2283    /// Wire format: TAG | MDH | encrypt(pad_len_u16 || plaintext || random_padding)
2284    fn build_packet(
2285        &self,
2286        plaintext: &[u8],
2287        session: &Arc<parking_lot::Mutex<Session>>,
2288    ) -> Result<Vec<u8>> {
2289        let mut sess = session.lock();
2290
2291        // Use unified counter for both nonce and tag
2292        let (nonce, counter) = sess.next_send_nonce();
2293
2294        // Build padded plaintext: pad_len(u16) || plaintext || random_padding
2295        // pad_len is inside encryption — invisible to DPI (CRIT-5 fix)
2296        let pad_len = 16u16;
2297        let mut padded = Vec::with_capacity(2 + plaintext.len() + pad_len as usize);
2298        padded.extend_from_slice(&pad_len.to_le_bytes());
2299        padded.extend_from_slice(plaintext);
2300        use rand::Rng;
2301        let mut rng = rand::thread_rng();
2302        for _ in 0..pad_len {
2303            padded.push(rng.gen::<u8>());
2304        }
2305
2306        let ciphertext = encrypt_payload(&sess.keys.session_key, &nonce, &padded)?;
2307
2308        // Generate tag
2309        let time_window = crypto::compute_time_window(
2310            crypto::current_timestamp_ms(),
2311            aivpn_common::crypto::DEFAULT_WINDOW_MS,
2312        );
2313        let tag = crypto::generate_resonance_tag(&sess.keys.tag_secret, counter, time_window);
2314        let current_mask = sess.mask.clone();
2315        drop(sess);
2316
2317        // Build MDH using the session's current packet mask so the peer can
2318        // decode bootstrap traffic before any runtime MaskUpdate arrives.
2319        let mdh = current_mask
2320            .as_ref()
2321            .map(packet_mdh_bytes_for_mask)
2322            .unwrap_or_else(|| self.mask_catalog.packet_mdh_bytes());
2323
2324        // Assemble packet: TAG | MDH | ciphertext (no cleartext padding)
2325        let mut packet = Vec::with_capacity(TAG_SIZE + mdh.len() + ciphertext.len());
2326        packet.extend_from_slice(&tag);
2327        packet.extend_from_slice(&mdh);
2328        packet.extend_from_slice(&ciphertext);
2329
2330        Ok(packet)
2331    }
2332
2333    /// Compute Shannon entropy of a byte slice (0.0 = uniform, 8.0 = max)
2334    fn compute_entropy(data: &[u8]) -> f64 {
2335        if data.is_empty() {
2336            return 0.0;
2337        }
2338        let mut counts = [0u32; 256];
2339        for &b in data {
2340            counts[b as usize] += 1;
2341        }
2342        let len = data.len() as f64;
2343        let mut entropy = 0.0;
2344        for &c in &counts {
2345            if c > 0 {
2346                let p = c as f64 / len;
2347                entropy -= p * p.log2();
2348            }
2349        }
2350        entropy
2351    }
2352
    /// Borrow the gateway's mask catalog handle.
    ///
    /// Returns a reference to the stored `Arc<MaskCatalog>`; clone the `Arc`
    /// if an owned handle is needed.
    pub fn mask_catalog(&self) -> &Arc<MaskCatalog> {
        &self.mask_catalog
    }
2357
    /// Borrow the gateway's metrics collector handle.
    ///
    /// Returns a reference to the stored `Arc<MetricsCollector>`; clone the
    /// `Arc` if an owned handle is needed.
    pub fn metrics(&self) -> &Arc<MetricsCollector> {
        &self.metrics
    }
2362}
2363
#[cfg(test)]
mod tests {
    use super::MaskCatalog;
    use aivpn_common::crypto::TAG_SIZE;
    use aivpn_common::mask::preset_masks::webrtc_zoom_v3;

    /// With `webrtc_zoom_v3` registered as the primary mask, the layout
    /// reported by `packet_layout()` must locate the embedded ephemeral public
    /// key and the start of the payload inside a handshake packet.
    #[test]
    fn packet_layout_extracts_embedded_eph_pub_from_mdh() {
        let catalog = MaskCatalog::new();
        let mask = webrtc_zoom_v3();
        catalog.register_mask(mask.clone());
        catalog.set_primary_mask_id(mask.mask_id.clone());
        let (packet_mdh_len, handshake_mdh_len, eph_offset, eph_len) = catalog.packet_layout();

        let expected_eph = [0x5au8; 32];

        // Check the reported layout before slicing with it, so a regression
        // fails with these messages instead of an opaque index panic.
        assert_eq!(
            packet_mdh_len, 20,
            "regular STUN packet MDH length must stay at 20 bytes"
        );
        assert_eq!(
            handshake_mdh_len, 52,
            "handshake MDH length must include embedded eph_pub"
        );
        assert_eq!(
            eph_len,
            expected_eph.len(),
            "eph_pub slot must match the 32-byte test key"
        );

        // Build a handshake MDH from the mask's header template, padded with
        // zeros up to the handshake length, and plant the key in its slot.
        let mut mdh = mask.header_template.clone();
        if mdh.len() < handshake_mdh_len {
            mdh.resize(handshake_mdh_len, 0);
        }
        mdh[eph_offset..eph_offset + eph_len].copy_from_slice(&expected_eph);

        // Packet = zeroed TAG | MDH | 24 bytes of recognisable payload.
        let mut packet = vec![0u8; TAG_SIZE];
        packet.extend_from_slice(&mdh);
        packet.extend_from_slice(&[0xabu8; 24]);

        let eph_start = TAG_SIZE + eph_offset;
        let payload_start = TAG_SIZE + handshake_mdh_len;

        assert_eq!(&packet[eph_start..eph_start + eph_len], &expected_eph);
        assert_eq!(&packet[payload_start..], &[0xabu8; 24]);
    }
}