Skip to main content

aivpn_server/
session.rs

1//! Session Manager
2//!
3//! Manages active VPN sessions with O(1) tag validation
4
5use std::collections::{BTreeSet, HashMap};
6use std::net::{Ipv4Addr, SocketAddr};
7use std::sync::Arc;
8
9use std::time::{Duration, Instant};
10
11use chacha20poly1305::aead::OsRng;
12use dashmap::DashMap;
13use hex;
14use parking_lot::Mutex;
15use rand::RngCore;
16use subtle::ConstantTimeEq;
17use tracing::{debug, info, trace};
18
19use aivpn_common::crypto::{
20    self, KeyPair, SessionKeys, DEFAULT_WINDOW_MS, NONCE_SIZE, TAG_SIZE, X25519_PUBLIC_KEY_SIZE,
21};
22use aivpn_common::error::{Error, Result};
23use aivpn_common::mask::MaskProfile;
24use aivpn_common::protocol::{ControlPayload, InnerHeader, InnerType};
25
/// Upper bound on concurrently tracked sessions, sized for a 1 GB VPS.
pub const MAX_SESSIONS: usize = 500;

/// Default idle timeout: a session with no inbound traffic for this long
/// (five minutes) is considered idle.
pub const IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);

/// Default hard session lifetime. `Duration::ZERO` means "unlimited" (Issue #33).
///
/// Can be overridden with `session_timeout_secs` in server.json. Forced
/// expiration is unnecessary because the PFS ratchet already rotates keys,
/// and expiring sessions caused reconnect failures.
pub const HARD_TIMEOUT: Duration = Duration::ZERO;

/// Width of the anti-replay / tag-lookup window, in packet counters.
/// It tolerates out-of-order delivery.
pub const TAG_WINDOW_SIZE: usize = 1 << 8;
40
/// Lifecycle state of a [`Session`].
///
/// `Session::new` starts every session in `Pending`. `SessionManager::create_session`
/// moves it to `Active` once its key material and tag windows are set up.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// Freshly constructed, not yet fully initialised.
    Pending,
    /// Handshake material prepared; traffic may flow.
    Active,
    /// NOTE(review): presumably no recent traffic; confirm where this is set.
    Idle,
    /// NOTE(review): presumably a key rotation is in progress; confirm.
    Rotating,
    /// NOTE(review): presumably a mask switch is in progress; confirm.
    MaskChange,
    /// NOTE(review): presumably past its lifetime; confirm.
    Expired,
    /// NOTE(review): presumably torn down; confirm.
    Closed,
}
52
53/// Session information
54pub struct Session {
55    pub session_id: [u8; 16],
56    pub client_addr: SocketAddr,
57    pub state: SessionState,
58    pub keys: SessionKeys,
59    pub eph_pub: [u8; X25519_PUBLIC_KEY_SIZE],
60
61    /// Packet counter for tag generation
62    pub counter: u64,
63    /// Last seen timestamp
64    pub last_seen: Instant,
65    /// Created timestamp
66    pub created_at: Instant,
67    /// Last server-to-client packet timestamp (for downlink recording IAT)
68    pub last_server_send: Instant,
69
70    /// Current mask profile
71    pub mask: Option<MaskProfile>,
72    /// Pending mask awaiting grace period before activation.
73    /// Stored as (new_mask, timestamp_when_MaskUpdate_was_sent).
74    pub pending_mask: Option<(MaskProfile, Instant)>,
75    /// Current FSM state
76    pub fsm_state: u16,
77    /// Packets in current FSM state
78    pub fsm_packets: u32,
79    /// Duration in current FSM state
80    pub fsm_state_start: Instant,
81
82    /// Sequence number for outgoing packets
83    pub send_seq: u32,
84    /// Last received sequence (for ACK)
85    pub recv_seq: u32,
86    /// Send counter for nonce generation (u64, same space as tags)
87    pub send_counter: u64,
88
89    /// Expected tags (counter -> tag)
90    pub expected_tags: HashMap<u64, [u8; TAG_SIZE]>,
91    /// Counter value used as the base for the currently precomputed tag window.
92    pub tag_window_base: u64,
93    /// Received tag bitmap (for anti-replay)
94    pub received_bitmap: u256,
95    /// Accumulated inbound bytes to flush into client_db in batches.
96    pub pending_bytes_in: u64,
97    /// Accumulated outbound (downlink) bytes to flush into client_db in batches.
98    pub pending_bytes_out: u64,
99
100    // --- PFS Ratchet fields (CRIT-3) ---
101    /// Server's ephemeral public key for this session
102    pub server_eph_pub: Option<[u8; 32]>,
103    /// Ed25519 signature for ServerHello
104    pub server_hello_signature: Option<[u8; 64]>,
105    /// Ratcheted session keys (PFS)
106    pub ratcheted_keys: Option<SessionKeys>,
107    /// Ratcheted tags for validation (counter -> tag)
108    pub ratcheted_expected_tags: HashMap<u64, [u8; TAG_SIZE]>,
109    /// Whether session has completed PFS ratchet
110    pub is_ratcheted: bool,
111    /// Assigned VPN IP (e.g. 10.0.0.2)
112    pub vpn_ip: Option<Ipv4Addr>,
113    /// Registered client ID (from client_db) for traffic accounting
114    pub client_id: Option<String>,
115
116    /// Pre-ratchet expected tags preserved for a 2-second grace window after
117    /// complete_ratchet() so client packets that were already in-flight with
118    /// the old keys are not silently dropped as unrecognised.
119    pub pre_ratchet_tags: HashMap<u64, [u8; TAG_SIZE]>,
120    /// Deadline until which pre_ratchet_tags are still accepted.
121    pub pre_ratchet_expire: Option<Instant>,
122}
123
/// 256-bit bitmap for tracking received packets.
///
/// Bits 0..128 live in `lo` and bits 128..256 live in `hi`.
/// `Session::mark_tag_received` sets bit `i` for counter `highest - i`.
#[derive(Debug, Clone, Copy, Default)]
// The lower-case name mirrors the primitive integer types (`u128`) it extends.
// It is kept as-is for compatibility with existing users of the type.
#[allow(non_camel_case_types)]
pub struct u256 {
    lo: u128,
    hi: u128,
}

impl u256 {
    /// Number of addressable bits.
    const BITS: usize = 256;

    /// Set bit `bit`.
    ///
    /// Indices `>= 256` are ignored. Shifting a `u128` by `>= 128` would
    /// panic in debug builds and set the wrong bit in release builds.
    pub fn set_bit(&mut self, bit: usize) {
        match bit {
            0..=127 => self.lo |= 1u128 << bit,
            128..=255 => self.hi |= 1u128 << (bit - 128),
            _ => {}
        }
    }

    /// Shift the whole 256-bit value left by `shift` bits.
    ///
    /// Bits moved past bit 255 are discarded. A shift of 256 or more clears
    /// the bitmap.
    pub fn shift_left(&mut self, shift: usize) {
        if shift == 0 {
            return;
        }
        if shift >= Self::BITS {
            self.clear();
            return;
        }
        if shift >= 128 {
            // Every surviving bit comes from `lo`.
            self.hi = self.lo << (shift - 128);
            self.lo = 0;
            return;
        }

        // 0 < shift < 128: carry the top `shift` bits of `lo` into `hi`.
        self.hi = (self.hi << shift) | (self.lo >> (128 - shift));
        self.lo <<= shift;
    }

    /// Return whether bit `bit` is set. Indices `>= 256` read as unset.
    pub fn get_bit(&self, bit: usize) -> bool {
        match bit {
            0..=127 => (self.lo & (1u128 << bit)) != 0,
            128..=255 => (self.hi & (1u128 << (bit - 128))) != 0,
            _ => false,
        }
    }

    /// Reset every bit to zero.
    pub fn clear(&mut self) {
        self.lo = 0;
        self.hi = 0;
    }
}
172
173impl Session {
174    pub fn new(
175        session_id: [u8; 16],
176        client_addr: SocketAddr,
177        keys: SessionKeys,
178        eph_pub: [u8; X25519_PUBLIC_KEY_SIZE],
179    ) -> Self {
180        let now = Instant::now();
181        Self {
182            session_id,
183            client_addr,
184            state: SessionState::Pending,
185            keys,
186            eph_pub,
187            counter: 0,
188            last_seen: now,
189            created_at: now,
190            last_server_send: now,
191            mask: None,
192            pending_mask: None,
193            fsm_state: 0,
194            fsm_packets: 0,
195            fsm_state_start: now,
196            send_seq: 0,
197            recv_seq: 0,
198            send_counter: 0,
199            expected_tags: HashMap::with_capacity(TAG_WINDOW_SIZE),
200            tag_window_base: 0,
201            received_bitmap: u256::default(),
202            pending_bytes_in: 0,
203            pending_bytes_out: 0,
204            server_eph_pub: None,
205            server_hello_signature: None,
206            ratcheted_keys: None,
207            ratcheted_expected_tags: HashMap::new(),
208            is_ratcheted: false,
209            vpn_ip: None,
210            client_id: None,
211            pre_ratchet_tags: HashMap::new(),
212            pre_ratchet_expire: None,
213        }
214    }
215
216    /// Compute next nonce for encryption from send_counter (u64)
217    /// Uses the same counter space as tag generation for consistency
218    pub fn next_send_nonce(&mut self) -> ([u8; NONCE_SIZE], u64) {
219        let counter = self.send_counter;
220        let mut nonce = [0u8; NONCE_SIZE];
221        nonce[0..8].copy_from_slice(&counter.to_le_bytes());
222        self.send_counter += 1;
223        (nonce, counter)
224    }
225
226    /// Update expected tags for validation window
227    pub fn update_tag_window(&mut self) {
228        let time_window =
229            crypto::compute_time_window(crypto::current_timestamp_ms(), DEFAULT_WINDOW_MS);
230
231        // Pre-compute tags for a bidirectional window around the highest
232        // validated counter so minor UDP reordering does not fall out of the
233        // fast path lookup map.
234        self.expected_tags.clear();
235        self.tag_window_base = self.counter;
236        let window_back = TAG_WINDOW_SIZE as u64 - 1;
237        let window_start = self.counter.saturating_sub(window_back);
238        let window_end = self.counter.saturating_add(TAG_WINDOW_SIZE as u64 - 1);
239
240        for counter_val in window_start..=window_end {
241            let tag =
242                crypto::generate_resonance_tag(&self.keys.tag_secret, counter_val, time_window);
243            self.expected_tags.insert(counter_val, tag);
244        }
245    }
246
247    /// Validate received tag (constant-time)
248    /// Returns (counter, is_ratcheted_tag) if valid.
249    /// Checks the current time window first, then adjacent windows (±1)
250    /// for clock skew tolerance.
251    pub fn validate_tag(&self, tag: &[u8; TAG_SIZE]) -> Option<(u64, bool)> {
252        let is_replay = |counter_val: u64| {
253            if counter_val > self.counter {
254                return false;
255            }
256
257            let bit_index = (self.counter - counter_val) as usize;
258            bit_index < TAG_WINDOW_SIZE && self.received_bitmap.get_bit(bit_index)
259        };
260
261        let history_window = TAG_WINDOW_SIZE as u64 - 1;
262        let window_start = self.counter.saturating_sub(history_window);
263        let window_end = self.counter.saturating_add(TAG_WINDOW_SIZE as u64 - 1);
264
265        // Check initial keys — current time window (pre-computed)
266        for (counter, expected) in &self.expected_tags {
267            if bool::from(expected.ct_eq(tag)) {
268                if is_replay(*counter) {
269                    return None; // Already received
270                }
271                return Some((*counter, false));
272            }
273        }
274        // Check adjacent time windows (±1) on-the-fly for clock skew
275        let current_tw =
276            crypto::compute_time_window(crypto::current_timestamp_ms(), DEFAULT_WINDOW_MS);
277        for tw_offset in [current_tw.wrapping_sub(1), current_tw.wrapping_add(1)] {
278            for counter_val in window_start..=window_end {
279                let expected =
280                    crypto::generate_resonance_tag(&self.keys.tag_secret, counter_val, tw_offset);
281                if bool::from(expected.ct_eq(tag)) {
282                    if is_replay(counter_val) {
283                        return None;
284                    }
285                    return Some((counter_val, false));
286                }
287            }
288        }
289        // Check pre-ratchet tags during grace window (in-flight packets from client
290        // that were encrypted with old keys before it switched to ratcheted ones).
291        if let Some(expire) = self.pre_ratchet_expire {
292            if Instant::now() < expire {
293                for (counter, expected) in &self.pre_ratchet_tags {
294                    if bool::from(expected.ct_eq(tag)) {
295                        if is_replay(*counter) {
296                            return None;
297                        }
298                        return Some((*counter, false));
299                    }
300                }
301            }
302        }
303
304        // Check ratcheted keys (only during transition, before ratchet is complete)
305        if !self.is_ratcheted {
306            for (counter, expected) in &self.ratcheted_expected_tags {
307                if bool::from(expected.ct_eq(tag)) {
308                    return Some((*counter, true));
309                }
310            }
311            // Also check adjacent windows for ratcheted keys
312            if let Some(ratcheted_keys) = &self.ratcheted_keys {
313                for tw_offset in [current_tw.wrapping_sub(1), current_tw.wrapping_add(1)] {
314                    for i in 0..TAG_WINDOW_SIZE {
315                        let expected = crypto::generate_resonance_tag(
316                            &ratcheted_keys.tag_secret,
317                            i as u64,
318                            tw_offset,
319                        );
320                        if bool::from(expected.ct_eq(tag)) {
321                            return Some((i as u64, true));
322                        }
323                    }
324                }
325            }
326        }
327        None
328    }
329
330    /// Mark tag as received
331    pub fn mark_tag_received(&mut self, counter: u64) {
332        if counter > self.counter {
333            let shift = (counter - self.counter) as usize;
334            self.received_bitmap.shift_left(shift);
335            self.counter = counter;
336            self.received_bitmap.set_bit(0);
337            return;
338        }
339
340        let bit_index = (self.counter - counter) as usize;
341        if bit_index < 256 {
342            self.received_bitmap.set_bit(bit_index);
343        }
344    }
345
346    /// Get next sequence number for inner header
347    pub fn next_seq(&mut self) -> u32 {
348        let seq = self.send_seq;
349        self.send_seq = self.send_seq.wrapping_add(1);
350        seq
351    }
352
353    /// Update FSM state
354    pub fn update_fsm(&mut self) {
355        if let Some(mask) = &self.mask {
356            let duration_ms = self.fsm_state_start.elapsed().as_millis() as u64;
357            let (new_state, _size_override, _iat_override, _padding_override) =
358                mask.process_transition(self.fsm_state, self.fsm_packets, duration_ms);
359
360            if new_state != self.fsm_state {
361                self.fsm_state = new_state;
362                self.fsm_packets = 0;
363                self.fsm_state_start = Instant::now();
364            }
365        }
366        self.fsm_packets += 1;
367    }
368
369    /// Check if session is idle
370    pub fn is_idle(&self) -> bool {
371        self.last_seen.elapsed() > IDLE_TIMEOUT
372    }
373
374    /// Check if session is expired
375    pub fn is_expired(&self) -> bool {
376        self.created_at.elapsed() > HARD_TIMEOUT
377    }
378
379    /// Pre-compute tags for ratcheted keys
380    pub fn update_ratcheted_tag_window(&mut self) {
381        if let Some(ratcheted_keys) = &self.ratcheted_keys {
382            let time_window =
383                crypto::compute_time_window(crypto::current_timestamp_ms(), DEFAULT_WINDOW_MS);
384            self.ratcheted_expected_tags.clear();
385            // Ratcheted counter starts at 0
386            for i in 0..TAG_WINDOW_SIZE {
387                let tag = crypto::generate_resonance_tag(
388                    &ratcheted_keys.tag_secret,
389                    i as u64,
390                    time_window,
391                );
392                self.ratcheted_expected_tags.insert(i as u64, tag);
393            }
394        }
395    }
396
397    /// Complete PFS ratchet: switch to ratcheted keys, zeroize old ones
398    pub fn complete_ratchet(&mut self) {
399        if let Some(ratcheted_keys) = self.ratcheted_keys.take() {
400            // Preserve old expected_tags for 2 s so client packets that were
401            // already in-flight with the pre-ratchet keys are not dropped.
402            self.pre_ratchet_tags = std::mem::take(&mut self.expected_tags);
403            self.pre_ratchet_expire = Some(Instant::now() + Duration::from_secs(2));
404
405            self.keys = ratcheted_keys;
406            self.counter = 0;
407            self.send_counter = 0;
408            self.tag_window_base = self.counter;
409            self.expected_tags = std::mem::take(&mut self.ratcheted_expected_tags);
410            self.received_bitmap.clear();
411            self.pending_bytes_in = 0;
412            self.pending_bytes_out = 0;
413            self.is_ratcheted = true;
414            self.server_eph_pub = None;
415            self.server_hello_signature = None;
416        }
417    }
418
419    /// Check and commit a pending mask if the grace period has elapsed.
420    /// Returns true if a mask was committed.
421    /// Grace period = 500ms — enough for the MaskUpdate packet to reach the client.
422    pub fn commit_pending_mask(&mut self) -> bool {
423        const MASK_GRACE_PERIOD: Duration = Duration::from_millis(500);
424        if let Some((_, sent_at)) = &self.pending_mask {
425            if sent_at.elapsed() >= MASK_GRACE_PERIOD {
426                let (new_mask, _) = self.pending_mask.take().unwrap();
427                info!("Committing deferred mask switch to '{}'", new_mask.mask_id);
428                self.mask = Some(new_mask);
429                // Reset FSM state for the new mask
430                self.fsm_state = 0;
431                self.fsm_packets = 0;
432                self.fsm_state_start = Instant::now();
433                return true;
434            }
435        }
436        false
437    }
438}
439
440/// Session Manager with O(1) tag lookup
441pub struct SessionManager {
442    /// Sessions by ID
443    sessions: DashMap<[u8; 16], Arc<Mutex<Session>>>,
444    /// Tag -> Session ID mapping for O(1) lookup
445    tag_map: DashMap<[u8; TAG_SIZE], [u8; 16]>,
446    /// VPN IP -> Session ID mapping for TUN return routing
447    vpn_ip_map: DashMap<Ipv4Addr, [u8; 16]>,
448    /// Next VPN IP to assign (last octet)
449    /// Pool of free VPN IP octets (2..=254). IPs are returned when sessions end.
450    ip_pool: Mutex<BTreeSet<u8>>,
451    /// Server's long-term keypair
452    server_keys: KeyPair,
453    /// Server's signing key (Ed25519)
454    signing_key: ed25519_dalek::SigningKey,
455    /// Default mask profile
456    default_mask: MaskProfile,
457    /// Configurable session hard timeout
458    hard_timeout: Duration,
459    /// Configurable session idle timeout
460    idle_timeout: Duration,
461}
462
463impl SessionManager {
464    pub fn new(
465        server_keys: KeyPair,
466        signing_key: ed25519_dalek::SigningKey,
467        default_mask: MaskProfile,
468    ) -> Self {
469        Self::with_timeouts(server_keys, signing_key, default_mask, None, None)
470    }
471
472    pub fn with_timeouts(
473        server_keys: KeyPair,
474        signing_key: ed25519_dalek::SigningKey,
475        default_mask: MaskProfile,
476        session_timeout_secs: Option<u64>,
477        idle_timeout_secs: Option<u64>,
478    ) -> Self {
479        let hard_timeout = session_timeout_secs
480            .map(|s| Duration::from_secs(s))
481            .unwrap_or(HARD_TIMEOUT);
482        let idle_timeout = idle_timeout_secs
483            .map(|s| Duration::from_secs(s))
484            .unwrap_or(IDLE_TIMEOUT);
485        Self {
486            sessions: DashMap::new(),
487            tag_map: DashMap::new(),
488            vpn_ip_map: DashMap::new(),
489            ip_pool: Mutex::new((2..=254u8).collect()),
490            server_keys,
491            signing_key,
492            default_mask,
493            hard_timeout,
494            idle_timeout,
495        }
496    }
497
498    /// Create new session from initial packet.
499    /// NOTE: Does NOT remove old sessions for the same client IP.
500    /// The caller must call `cleanup_old_sessions_for_ip()` after
501    /// validating that the new session is legitimate (tag matches).
502    pub fn create_session(
503        &self,
504        client_addr: SocketAddr,
505        eph_pub: [u8; X25519_PUBLIC_KEY_SIZE],
506        preshared_key: Option<[u8; 32]>,
507        static_vpn_ip: Option<Ipv4Addr>,
508    ) -> Result<Arc<Mutex<Session>>> {
509        // Look for a reusable VPN IP from an existing session for the same
510        // client IP, but do NOT remove the old session yet — the caller
511        // will do that only after the handshake tag validates.
512        let reused_vpn_ip: Option<Ipv4Addr> = self
513            .sessions
514            .iter()
515            .filter_map(|entry| {
516                let session = entry.value().lock();
517                if session.client_addr.ip() == client_addr.ip() {
518                    session.vpn_ip
519                } else {
520                    None
521                }
522            })
523            .next();
524
525        if self.sessions.len() >= MAX_SESSIONS {
526            return Err(Error::Session("Max sessions reached".into()));
527        }
528
529        // MED-6: Per-IP session limit (max 5 sessions per IP)
530        let ip_count = self
531            .sessions
532            .iter()
533            .filter(|e| e.value().lock().client_addr.ip() == client_addr.ip())
534            .count();
535        if ip_count >= 5 {
536            return Err(Error::Session("Per-IP session limit reached".into()));
537        }
538
539        // Prevent VPN IP pool exhaustion: cap concurrent sessions per /24 subnet.
540        // The per-IP cap of 5 alone is insufficient — a spoofed-source flood from
541        // 51 distinct IPs in one /24 can drain all 253 assignable VPN addresses
542        // while remaining within the per-IP limit.
543        if let std::net::IpAddr::V4(v4) = client_addr.ip() {
544            let subnet24 = u32::from(v4) >> 8;
545            let subnet_count = self
546                .sessions
547                .iter()
548                .filter(|e| {
549                    if let std::net::IpAddr::V4(ip) = e.value().lock().client_addr.ip() {
550                        (u32::from(ip) >> 8) == subnet24
551                    } else {
552                        false
553                    }
554                })
555                .count();
556            if subnet_count >= 10 {
557                return Err(Error::Session(
558                    "Per-subnet (/24) session limit reached".into(),
559                ));
560            }
561        }
562
563        // DH1: server_static * client_eph → initial keys (0-RTT)
564        let dh1 = self.server_keys.compute_shared(&eph_pub)?;
565        trace!("Server DH result: {}", hex::encode(&dh1));
566        trace!(
567            "Server eph_pub (after deobfuscation): {}",
568            hex::encode(&eph_pub)
569        );
570        trace!("Server PSK: {:?}", preshared_key.as_ref().map(hex::encode));
571        let initial_keys = crypto::derive_session_keys(&dh1, preshared_key.as_ref(), &eph_pub);
572        trace!(
573            "Server tag_secret: {}",
574            hex::encode(&initial_keys.tag_secret)
575        );
576
577        // --- CRIT-3 + HIGH-6: PFS ratchet preparation ---
578        // Generate server ephemeral keypair
579        let server_eph_kp = crypto::KeyPair::generate();
580        let server_eph_pub = server_eph_kp.public_key_bytes();
581
582        // DH2: server_eph * client_eph → PFS keys
583        let dh2 = server_eph_kp.compute_shared(&eph_pub)?;
584        // Use initial session_key as PSK for domain separation
585        let ratcheted_keys =
586            crypto::derive_session_keys(&dh2, Some(&initial_keys.session_key), &eph_pub);
587
588        // Sign (server_eph_pub || client_eph_pub) for server authentication (HIGH-6)
589        use ed25519_dalek::Signer;
590        let mut sign_message = Vec::with_capacity(64);
591        sign_message.extend_from_slice(&server_eph_pub);
592        sign_message.extend_from_slice(&eph_pub);
593        let signature = self.signing_key.sign(&sign_message).to_bytes();
594
595        // Generate session ID
596        let mut session_id = [0u8; 16];
597        OsRng.fill_bytes(&mut session_id);
598
599        // Create session with initial (DH1) keys
600        let session = Arc::new(Mutex::new(Session::new(
601            session_id,
602            client_addr,
603            initial_keys,
604            eph_pub,
605        )));
606
607        // Setup ratchet state + populate tag maps
608        {
609            let mut sess = session.lock();
610            sess.state = SessionState::Active;
611
612            // Store ratchet data
613            sess.server_eph_pub = Some(server_eph_pub);
614            sess.server_hello_signature = Some(signature);
615            sess.ratcheted_keys = Some(ratcheted_keys);
616
617            // Compute initial tags
618            sess.update_tag_window();
619            for tag in sess.expected_tags.values() {
620                self.tag_map.insert(*tag, session_id);
621            }
622
623            // Pre-compute ratcheted tags (for when client switches to PFS keys)
624            sess.update_ratcheted_tag_window();
625            for tag in sess.ratcheted_expected_tags.values() {
626                self.tag_map.insert(*tag, session_id);
627            }
628        }
629
630        // Insert into session map
631        self.sessions.insert(session_id, session.clone());
632
633        // Assign VPN IP and register mapping.
634        // Priority: 1) static IP from client config, 2) reused IP, 3) auto-assign
635        let vpn_ip = if let Some(ip) = static_vpn_ip.or(reused_vpn_ip) {
636            // Static or reused IP — ensure it's removed from the free pool
637            self.ip_pool.lock().remove(&ip.octets()[3]);
638            Some(ip)
639        } else {
640            // Allocate the lowest available IP from the pool
641            self.ip_pool
642                .lock()
643                .pop_first()
644                .map(|octet| Ipv4Addr::new(10, 0, 0, octet))
645        };
646
647        if let Some(vpn_ip) = vpn_ip {
648            session.lock().vpn_ip = Some(vpn_ip);
649            self.vpn_ip_map.insert(vpn_ip, session_id);
650            debug!("Assigned VPN IP {} to session", vpn_ip);
651        }
652
653        Ok(session)
654    }
655
656    /// Remove all sessions for a given IP except the specified one.
657    /// Called after a new handshake is validated to clean up stale sessions.
658    /// Returns list of removed session IDs (for stopping recordings).
659    pub fn cleanup_old_sessions_for_ip(
660        &self,
661        ip: &std::net::IpAddr,
662        keep_session_id: &[u8; 16],
663    ) -> Vec<[u8; 16]> {
664        let to_remove: Vec<[u8; 16]> = self
665            .sessions
666            .iter()
667            .filter_map(|entry| {
668                let session = entry.value().lock();
669                if session.client_addr.ip() == *ip && entry.key() != keep_session_id {
670                    Some(*entry.key())
671                } else {
672                    None
673                }
674            })
675            .collect();
676
677        let mut removed = Vec::new();
678        for session_id in to_remove {
679            info!(
680                "Removing stale session for IP {} after successful re-handshake",
681                ip
682            );
683            if self.remove_session(&session_id).is_some() {
684                removed.push(session_id);
685            }
686        }
687        removed
688    }
689
690    /// Remove old sessions for the same VPN IP (same client) except the
691    /// specified one. Unlike `cleanup_old_sessions_for_ip`, this does NOT
692    /// affect sessions belonging to other clients behind the same NAT.
693    /// Returns list of removed session IDs (for stopping recordings).
694    pub fn cleanup_old_sessions_for_vpn_ip(
695        &self,
696        vpn_ip: &Ipv4Addr,
697        keep_session_id: &[u8; 16],
698    ) -> Vec<[u8; 16]> {
699        let to_remove: Vec<[u8; 16]> = self
700            .sessions
701            .iter()
702            .filter_map(|entry| {
703                let session = entry.value().lock();
704                if session.vpn_ip == Some(*vpn_ip) && entry.key() != keep_session_id {
705                    Some(*entry.key())
706                } else {
707                    None
708                }
709            })
710            .collect();
711
712        let mut removed = Vec::new();
713        for session_id in to_remove {
714            info!(
715                "Removing stale session for VPN IP {} after successful re-handshake",
716                vpn_ip
717            );
718            if self.remove_session(&session_id).is_some() {
719                removed.push(session_id);
720            }
721        }
722        removed
723    }
724
725    /// Rollback a session that was created but failed tag validation.
726    /// Restores vpn_ip_map to the old session that still owns that IP.
727    pub fn rollback_failed_session(&self, session_id: &[u8; 16]) {
728        // Grab the VPN IP before removal so we can restore the old mapping.
729        let vpn_ip = self
730            .sessions
731            .get(session_id)
732            .map(|e| e.value().lock().vpn_ip)
733            .flatten();
734
735        self.remove_session(session_id);
736
737        // If there is still another session that owns this VPN IP, restore
738        // the mapping and take the IP back out of the free pool.
739        if let Some(vpn_ip) = vpn_ip {
740            for entry in self.sessions.iter() {
741                let sess = entry.value().lock();
742                if sess.vpn_ip == Some(vpn_ip) {
743                    self.vpn_ip_map.insert(vpn_ip, *entry.key());
744                    self.ip_pool.lock().remove(&vpn_ip.octets()[3]);
745                    break;
746                }
747            }
748        }
749    }
750
751    /// Return true when the same public IP already has a fresh ratcheted session
752    /// on a different socket endpoint. This helps ignore stale duplicate-port
753    /// probes instead of spawning a new handshake loop.
754    pub fn has_recent_ratcheted_session_on_other_endpoint(
755        &self,
756        client_addr: &SocketAddr,
757        max_age: Duration,
758    ) -> bool {
759        self.sessions.iter().any(|entry| {
760            let sess = entry.value().lock();
761            sess.client_addr.ip() == client_addr.ip()
762                && sess.client_addr != *client_addr
763                && sess.is_ratcheted
764                && sess.last_seen.elapsed() <= max_age
765        })
766    }
767
768    /// Get session by tag (O(1) lookup)
769    pub fn get_session_by_tag(&self, tag: &[u8; TAG_SIZE]) -> Option<Arc<Mutex<Session>>> {
770        if let Some(entry) = self.tag_map.get(tag) {
771            let session_id = *entry;
772            drop(entry);
773            self.sessions.get(&session_id).map(|e| e.clone())
774        } else {
775            None
776        }
777    }
778
779    /// Refresh tag windows for all sessions (time window may have advanced)
780    /// and try to find a session matching the given tag.
781    pub fn refresh_and_find_by_tag(
782        &self,
783        tag: &[u8; TAG_SIZE],
784    ) -> Option<(Arc<Mutex<Session>>, u64, bool)> {
785        for entry in self.sessions.iter() {
786            let session = entry.value().clone();
787            let session_id = *entry.key();
788            let mut sess = session.lock();
789
790            // Refresh initial key tags
791            let old_tags: Vec<[u8; TAG_SIZE]> = sess.expected_tags.values().cloned().collect();
792            for old_tag in &old_tags {
793                self.tag_map.remove(old_tag);
794            }
795            sess.update_tag_window();
796            for t in sess.expected_tags.values() {
797                self.tag_map.insert(*t, session_id);
798            }
799
800            // Refresh ratcheted key tags
801            let old_ratcheted: Vec<[u8; TAG_SIZE]> =
802                sess.ratcheted_expected_tags.values().cloned().collect();
803            for old_tag in &old_ratcheted {
804                self.tag_map.remove(old_tag);
805            }
806            sess.update_ratcheted_tag_window();
807            for t in sess.ratcheted_expected_tags.values() {
808                self.tag_map.insert(*t, session_id);
809            }
810
811            // Try to validate the tag now
812            if let Some((counter, is_ratcheted)) = sess.validate_tag(tag) {
813                drop(sess);
814                return Some((session, counter, is_ratcheted));
815            }
816        }
817        None
818    }
819
820    /// Wide-range counter recovery: brute-force search over a large counter
821    /// range to recover from counter drift (e.g., client race condition).
822    /// Only called when normal tag lookup + refresh both fail but a session
823    /// exists for this client IP.
824    pub fn recover_session_by_tag(
825        &self,
826        tag: &[u8; TAG_SIZE],
827        client_ip: &std::net::IpAddr,
828    ) -> Option<(Arc<Mutex<Session>>, u64, bool)> {
829        let current_tw =
830            crypto::compute_time_window(crypto::current_timestamp_ms(), DEFAULT_WINDOW_MS);
831        // Search up to 65536 counters ahead from the session's last known counter
832        const RECOVERY_RANGE: u64 = 65536;
833
834        for entry in self.sessions.iter() {
835            let session = entry.value().clone();
836            let session_id = *entry.key();
837            let sess = session.lock();
838            if sess.client_addr.ip() != *client_ip {
839                continue;
840            }
841
842            let base = sess.counter;
843            let tag_secret = &sess.keys.tag_secret;
844
845            for tw_offset in [0i64, -1, 1] {
846                let tw = (current_tw as i64 + tw_offset) as u64;
847                for i in 0..RECOVERY_RANGE {
848                    let c = base + i;
849                    let expected = crypto::generate_resonance_tag(tag_secret, c, tw);
850                    if bool::from(expected.ct_eq(tag)) {
851                        info!(
852                            "Counter recovery: found counter {} (drift={}) for session",
853                            c, i
854                        );
855                        // Update tag window to the recovered counter
856                        drop(sess);
857                        {
858                            let mut s = session.lock();
859                            s.counter = c;
860                            s.update_tag_window();
861                        }
862                        // Refresh tag_map
863                        self.tag_map.retain(|_, id| id != &session_id);
864                        let s = session.lock();
865                        for t in s.expected_tags.values() {
866                            self.tag_map.insert(*t, session_id);
867                        }
868                        drop(s);
869                        return Some((session, c, false));
870                    }
871                }
872            }
873        }
874        None
875    }
876
877    /// Get session by ID
878    pub fn get_session(&self, session_id: &[u8; 16]) -> Option<Arc<Mutex<Session>>> {
879        self.sessions.get(session_id).map(|e| e.clone())
880    }
881
882    /// Get session by VPN IP (for routing TUN responses back to clients)
883    pub fn get_session_by_vpn_ip(&self, vpn_ip: &Ipv4Addr) -> Option<Arc<Mutex<Session>>> {
884        if let Some(entry) = self.vpn_ip_map.get(vpn_ip) {
885            let session_id = *entry;
886            drop(entry);
887            self.sessions.get(&session_id).map(|e| e.clone())
888        } else {
889            None
890        }
891    }
892
893    /// Remove session and return its ID if it existed.
894    /// The returned session_id can be used to stop active recording.
895    pub fn remove_session(&self, session_id: &[u8; 16]) -> Option<[u8; 16]> {
896        if let Some((_, session)) = self.sessions.remove(session_id) {
897            let sess = session.lock();
898            // Remove all tags from tag map (initial + ratcheted)
899            for tag in sess.expected_tags.values() {
900                self.tag_map.remove(tag);
901            }
902            for tag in sess.ratcheted_expected_tags.values() {
903                self.tag_map.remove(tag);
904            }
905            // Remove VPN IP mapping only if it still points to THIS session.
906            // A newer session may have already claimed the same VPN IP.
907            if let Some(vpn_ip) = sess.vpn_ip {
908                if self
909                    .vpn_ip_map
910                    .remove_if(&vpn_ip, |_, sid| sid == session_id)
911                    .is_some()
912                {
913                    // No other session owns this IP — return it to the free pool
914                    let octet = vpn_ip.octets()[3];
915                    if octet >= 2 {
916                        self.ip_pool.lock().insert(octet);
917                    }
918                }
919            }
920            Some(*session_id)
921        } else {
922            None
923        }
924    }
925
926    /// Refresh tag_map after session's tag window has been updated
927    pub fn refresh_session_tags(&self, session_id: &[u8; 16]) {
928        if let Some(session) = self.sessions.get(session_id) {
929            let sess = session.lock();
930            // Remove stale tags for this session
931            self.tag_map.retain(|_, id| id != session_id);
932            // Re-add current tags
933            for tag in sess.expected_tags.values() {
934                self.tag_map.insert(*tag, *session_id);
935            }
936            for tag in sess.ratcheted_expected_tags.values() {
937                self.tag_map.insert(*tag, *session_id);
938            }
939        }
940    }
941
942    /// Complete PFS ratchet for a session: switch to ratcheted keys, remove old tags
943    pub fn complete_session_ratchet(&self, session_id: &[u8; 16]) {
944        if let Some(session) = self.sessions.get(session_id) {
945            let mut sess = session.lock();
946            // Remove old initial key tags from tag_map
947            for tag in sess.expected_tags.values() {
948                self.tag_map.remove(tag);
949            }
950            // Complete the ratchet (swaps keys, moves ratcheted_expected_tags → expected_tags)
951            sess.complete_ratchet();
952            // Re-add the now-active tags (which were the ratcheted tags)
953            for tag in sess.expected_tags.values() {
954                self.tag_map.insert(*tag, *session_id);
955            }
956        }
957    }
958
959    /// Cleanup expired sessions and return list of removed session IDs.
960    /// The returned IDs can be used to stop active recordings.
961    pub fn cleanup_expired(&self) -> Vec<[u8; 16]> {
962        let expired: Vec<[u8; 16]> = self
963            .sessions
964            .iter()
965            .filter(|e| {
966                let sess = e.value().lock();
967                sess.last_seen.elapsed() > self.idle_timeout
968                    || (self.hard_timeout > Duration::ZERO
969                        && sess.created_at.elapsed() > self.hard_timeout)
970            })
971            .map(|e| *e.key())
972            .collect();
973
974        let mut removed = Vec::new();
975        for session_id in expired {
976            if self.remove_session(&session_id).is_some() {
977                removed.push(session_id);
978            }
979        }
980        removed
981    }
982
983    /// Get active session count
984    pub fn session_count(&self) -> usize {
985        self.sessions.len()
986    }
987
988    /// Log diagnostic information about all sessions and tag state
989    pub fn log_session_diagnostics(&self, incoming_tag: &[u8; TAG_SIZE]) {
990        let tag_map_size = self.tag_map.len();
991        let current_tw =
992            crypto::compute_time_window(crypto::current_timestamp_ms(), DEFAULT_WINDOW_MS);
993        info!(
994            "DIAG: tag_map_size={}, current_tw={}",
995            tag_map_size, current_tw
996        );
997        for entry in self.sessions.iter() {
998            let sess = entry.value().lock();
999            let sid_hex = format!(
1000                "{:02x}{:02x}{:02x}{:02x}",
1001                entry.key()[0],
1002                entry.key()[1],
1003                entry.key()[2],
1004                entry.key()[3]
1005            );
1006            let is_ratcheted = sess.is_ratcheted;
1007            let counter = sess.counter;
1008            let expected_count = sess.expected_tags.len();
1009            let ratcheted_count = sess.ratcheted_expected_tags.len();
1010            let has_ratcheted_keys = sess.ratcheted_keys.is_some();
1011            // Check if any expected tag matches (manually)
1012            let mut found = false;
1013            for (c, t) in &sess.expected_tags {
1014                if t == incoming_tag {
1015                    found = true;
1016                    info!(
1017                        "DIAG: Session {} — expected tag MATCHES at counter {}",
1018                        sid_hex, c
1019                    );
1020                    break;
1021                }
1022            }
1023            info!(
1024                "DIAG: Session {} — ratcheted={}, counter={}, expected_tags={}, ratcheted_tags={}, has_ratchet_keys={}, tag_matched={}",
1025                sid_hex, is_ratcheted, counter, expected_count, ratcheted_count, has_ratcheted_keys, found
1026            );
1027        }
1028    }
1029
1030    /// Get server public key
1031    pub fn server_public_key(&self) -> [u8; X25519_PUBLIC_KEY_SIZE] {
1032        self.server_keys.public_key_bytes()
1033    }
1034
1035    /// Sign mask data
1036    pub fn sign_mask(&self, mask_data: &[u8]) -> [u8; 64] {
1037        use ed25519_dalek::Signer;
1038        let signature = self.signing_key.sign(mask_data);
1039        signature.to_bytes()
1040    }
1041
1042    /// Iterate over all sessions (for neural resonance checks)
1043    pub fn iter_sessions(&self) -> dashmap::iter::Iter<'_, [u8; 16], Arc<Mutex<Session>>> {
1044        self.sessions.iter()
1045    }
1046
1047    /// Schedule a deferred mask switch for a session.
1048    /// The MaskUpdate control message has already been sent to the client;
1049    /// we store the new mask in `pending_mask` and let it activate after a
1050    /// grace period (see `commit_pending_mask`).
1051    pub fn update_session_mask(
1052        &self,
1053        session_id: &[u8; 16],
1054        new_mask: MaskProfile,
1055    ) -> Option<(Arc<Mutex<Session>>, SocketAddr)> {
1056        if let Some(session) = self.sessions.get(session_id) {
1057            let client_addr;
1058            {
1059                let mut sess = session.lock();
1060                info!(
1061                    "Session mask scheduled: {} → {} (grace period 500ms)",
1062                    sess.mask
1063                        .as_ref()
1064                        .map(|m| m.mask_id.as_str())
1065                        .unwrap_or("default"),
1066                    new_mask.mask_id
1067                );
1068                // Don't switch immediately — store as pending
1069                sess.pending_mask = Some((new_mask, Instant::now()));
1070                sess.state = SessionState::Active;
1071                client_addr = sess.client_addr;
1072            }
1073            Some((session.clone(), client_addr))
1074        } else {
1075            None
1076        }
1077    }
1078
1079    /// Build an encrypted MaskUpdate control packet for the given session.
1080    /// Returns the raw UDP datagram bytes ready to send.
1081    pub fn build_mask_update_packet(
1082        &self,
1083        session: &Arc<Mutex<Session>>,
1084        new_mask: &MaskProfile,
1085    ) -> Result<Vec<u8>> {
1086        use aivpn_common::crypto::encrypt_payload;
1087
1088        // Serialize mask profile → mask_data (MessagePack to match client's rmp_serde::from_slice)
1089        let mask_data = rmp_serde::to_vec(new_mask)
1090            .map_err(|e| Error::Session(format!("Failed to serialize mask: {}", e)))?;
1091
1092        // Sign mask_data with server's Ed25519 key
1093        let signature = self.sign_mask(&mask_data);
1094
1095        // Build control payload
1096        let control = ControlPayload::MaskUpdate {
1097            mask_data,
1098            signature,
1099        };
1100        let encoded = control.encode()?;
1101
1102        let mut sess = session.lock();
1103        let inner_header = InnerHeader {
1104            inner_type: InnerType::Control,
1105            seq_num: sess.next_seq() as u16,
1106        };
1107        let mut inner_payload = inner_header.encode().to_vec();
1108        inner_payload.extend_from_slice(&encoded);
1109
1110        // Encrypt (same logic as Gateway::build_packet)
1111        let (nonce, counter) = sess.next_send_nonce();
1112        let pad_len = 16u16;
1113        let mut padded = Vec::with_capacity(2 + inner_payload.len() + pad_len as usize);
1114        padded.extend_from_slice(&pad_len.to_le_bytes());
1115        padded.extend_from_slice(&inner_payload);
1116        {
1117            use rand::Rng;
1118            let mut rng = rand::thread_rng();
1119            for _ in 0..pad_len {
1120                padded.push(rng.gen::<u8>());
1121            }
1122        }
1123
1124        let ciphertext = encrypt_payload(&sess.keys.session_key, &nonce, &padded)?;
1125
1126        // Generate tag
1127        let time_window =
1128            crypto::compute_time_window(crypto::current_timestamp_ms(), DEFAULT_WINDOW_MS);
1129        let tag = crypto::generate_resonance_tag(&sess.keys.tag_secret, counter, time_window);
1130
1131        // Wrap MaskUpdate in the session's current mask. The switch to `new_mask`
1132        // happens only after the packet is successfully delivered.
1133        let transport_mask = sess.mask.as_ref().unwrap_or(&self.default_mask);
1134        let mdh = if let Some(ref spec) = transport_mask.header_spec {
1135            let mut rng = rand::thread_rng();
1136            spec.generate(&mut rng)
1137        } else {
1138            transport_mask.header_template.clone()
1139        };
1140
1141        // Assemble: TAG | MDH | ciphertext
1142        let mut packet = Vec::with_capacity(TAG_SIZE + mdh.len() + ciphertext.len());
1143        packet.extend_from_slice(&tag);
1144        packet.extend_from_slice(&mdh);
1145        packet.extend_from_slice(&ciphertext);
1146
1147        Ok(packet)
1148    }
1149}