1use std::collections::{BTreeSet, HashMap};
6use std::net::{Ipv4Addr, SocketAddr};
7use std::sync::Arc;
8
9use std::time::{Duration, Instant};
10
11use chacha20poly1305::aead::OsRng;
12use dashmap::DashMap;
13use hex;
14use parking_lot::Mutex;
15use rand::RngCore;
16use subtle::ConstantTimeEq;
17use tracing::{debug, info, trace};
18
19use aivpn_common::crypto::{
20 self, KeyPair, SessionKeys, DEFAULT_WINDOW_MS, NONCE_SIZE, TAG_SIZE, X25519_PUBLIC_KEY_SIZE,
21};
22use aivpn_common::error::{Error, Result};
23use aivpn_common::mask::MaskProfile;
24use aivpn_common::protocol::{ControlPayload, InnerHeader, InnerType};
25
/// Hard cap on concurrently registered sessions; `SessionManager::create_session` refuses to
/// register a new one once this many exist.
pub const MAX_SESSIONS: usize = 500;

/// Default inactivity limit (five minutes) after which a session counts as idle.
pub const IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);

/// Default absolute session lifetime. A zero value means "no hard limit":
/// `SessionManager::cleanup_expired` only applies the hard timeout when it is non-zero.
pub const HARD_TIMEOUT: Duration = Duration::from_secs(0);

/// Number of counters tracked on each side of the receive counter by the tag window, and the
/// depth of the anti-replay history.
pub const TAG_WINDOW_SIZE: usize = 256;
40
/// Lifecycle states of a session.
///
/// Within this module only `Pending` (assigned by `Session::new`) and `Active` (assigned by
/// `SessionManager::create_session` and `SessionManager::update_session_mask`) are ever set;
/// the remaining variants are presumably driven by other modules.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// Initial state of a freshly constructed session.
    Pending,
    /// Set once the manager has derived keys and registered the session.
    Active,
    Idle,
    Rotating,
    MaskChange,
    Expired,
    Closed,
}
52
/// Server-side state of one client session: current keys, receive tag window and anti-replay
/// history, traffic-mask state machine, and pending ratchet material.
pub struct Session {
    /// Random 16-byte identifier (filled from `OsRng` in `SessionManager::create_session`).
    pub session_id: [u8; 16],
    /// Socket address of the client; per-IP limits and lookups compare its `ip()`.
    pub client_addr: SocketAddr,
    /// Lifecycle state; `Session::new` starts it as `SessionState::Pending`.
    pub state: SessionState,
    /// Keys currently in use (`session_key` encrypts payloads, `tag_secret` derives tags).
    pub keys: SessionKeys,
    /// Client's ephemeral X25519 public key from the handshake.
    pub eph_pub: [u8; X25519_PUBLIC_KEY_SIZE],

    /// Highest receive counter accepted so far; advanced by `mark_tag_received`.
    pub counter: u64,
    /// Last time the client was heard from; compared against the idle timeout.
    /// NOTE(review): not written in this file — presumably updated by the packet path.
    pub last_seen: Instant,
    /// Construction time; compared against the hard timeout.
    pub created_at: Instant,
    /// Time of the last server-originated send (not read or written elsewhere in this file).
    pub last_server_send: Instant,

    /// Session-specific traffic mask; the manager falls back to its default mask when `None`.
    pub mask: Option<MaskProfile>,
    /// Mask scheduled to replace `mask`, plus the time it was scheduled; applied by
    /// `commit_pending_mask` once its 500 ms grace period has elapsed.
    pub pending_mask: Option<(MaskProfile, Instant)>,
    /// Current state of the mask's state machine.
    pub fsm_state: u16,
    /// Packets counted since `fsm_state` was entered.
    pub fsm_packets: u32,
    /// When `fsm_state` was entered.
    pub fsm_state_start: Instant,

    /// Next outbound inner-header sequence number (see `next_seq`).
    pub send_seq: u32,
    /// Inbound sequence number (not used in this file).
    pub recv_seq: u32,
    /// Next outbound nonce / tag counter (see `next_send_nonce`).
    pub send_counter: u64,

    /// Receive counter -> expected resonance tag under `keys`, rebuilt by `update_tag_window`.
    pub expected_tags: HashMap<u64, [u8; TAG_SIZE]>,
    /// Value of `counter` when `expected_tags` was last rebuilt.
    pub tag_window_base: u64,
    /// Anti-replay history: bit k set means counter `counter - k` was already accepted.
    pub received_bitmap: u256,
    /// Inbound byte accounting; reset when the ratchet completes (not otherwise used here).
    pub pending_bytes_in: u64,
    /// Outbound byte accounting; reset when the ratchet completes (not otherwise used here).
    pub pending_bytes_out: u64,

    /// Server ephemeral public key generated for the handshake; cleared when the ratchet completes.
    pub server_eph_pub: Option<[u8; 32]>,
    /// Ed25519 signature over `server_eph_pub || eph_pub`; cleared when the ratchet completes.
    pub server_hello_signature: Option<[u8; 64]>,
    /// Keys that replace `keys` when the ratchet completes.
    pub ratcheted_keys: Option<SessionKeys>,
    /// Expected tags for counters `0..TAG_WINDOW_SIZE` under `ratcheted_keys`.
    pub ratcheted_expected_tags: HashMap<u64, [u8; TAG_SIZE]>,
    /// Whether the ratchet has completed.
    pub is_ratcheted: bool,
    /// Address assigned inside the VPN, if any.
    pub vpn_ip: Option<Ipv4Addr>,
    /// Client identifier (not set in this file).
    pub client_id: Option<String>,

    /// Expected tags of the pre-ratchet keys, still accepted until `pre_ratchet_expire`.
    pub pre_ratchet_tags: HashMap<u64, [u8; TAG_SIZE]>,
    /// Deadline (2 s after the ratchet completes) for accepting `pre_ratchet_tags`.
    pub pre_ratchet_expire: Option<Instant>,
}
123
/// Fixed-width 256-bit bitmap backing the receive anti-replay window.
///
/// Bit `k` lives in `lo` for `k < 128` and in `hi` (at position `k - 128`) otherwise.
/// Indices `>= 256` are outside the bitmap: `set_bit` ignores them and `get_bit` reports
/// `false`. Without that guard the `u128` shift would overflow, which panics in debug builds
/// and silently touches the wrong bit in release builds (the shift amount is masked).
#[derive(Debug, Clone, Copy, Default)]
// The lower-case name deliberately mirrors the primitive integer types (`u128`, ...).
#[allow(non_camel_case_types)]
pub struct u256 {
    lo: u128,
    hi: u128,
}

impl u256 {
    /// Number of addressable bits.
    const BITS: usize = 256;
    /// Width of each `u128` half.
    const HALF: usize = 128;

    /// Sets bit `bit`; indices `>= 256` are ignored.
    pub fn set_bit(&mut self, bit: usize) {
        if bit < Self::HALF {
            self.lo |= 1u128 << bit;
        } else if bit < Self::BITS {
            self.hi |= 1u128 << (bit - Self::HALF);
        }
    }

    /// Moves every bit `shift` positions towards higher indices.
    ///
    /// Bits pushed past index 255 are discarded and vacated low bits become zero; a shift of
    /// 256 or more clears the bitmap.
    pub fn shift_left(&mut self, shift: usize) {
        if shift == 0 {
            return;
        }
        if shift >= Self::BITS {
            self.clear();
            return;
        }
        if shift >= Self::HALF {
            // The whole low half moves into the high half; old high bits fall off the top.
            self.hi = self.lo << (shift - Self::HALF);
            self.lo = 0;
            return;
        }
        // 0 < shift < 128: the high half also receives the bits spilling out of the low half.
        self.hi = (self.hi << shift) | (self.lo >> (Self::HALF - shift));
        self.lo <<= shift;
    }

    /// Returns whether bit `bit` is set; indices `>= 256` read as `false`.
    pub fn get_bit(&self, bit: usize) -> bool {
        if bit < Self::HALF {
            ((self.lo >> bit) & 1) == 1
        } else if bit < Self::BITS {
            ((self.hi >> (bit - Self::HALF)) & 1) == 1
        } else {
            false
        }
    }

    /// Resets every bit to zero.
    pub fn clear(&mut self) {
        *self = Self::default();
    }
}
172
173impl Session {
174 pub fn new(
175 session_id: [u8; 16],
176 client_addr: SocketAddr,
177 keys: SessionKeys,
178 eph_pub: [u8; X25519_PUBLIC_KEY_SIZE],
179 ) -> Self {
180 let now = Instant::now();
181 Self {
182 session_id,
183 client_addr,
184 state: SessionState::Pending,
185 keys,
186 eph_pub,
187 counter: 0,
188 last_seen: now,
189 created_at: now,
190 last_server_send: now,
191 mask: None,
192 pending_mask: None,
193 fsm_state: 0,
194 fsm_packets: 0,
195 fsm_state_start: now,
196 send_seq: 0,
197 recv_seq: 0,
198 send_counter: 0,
199 expected_tags: HashMap::with_capacity(TAG_WINDOW_SIZE),
200 tag_window_base: 0,
201 received_bitmap: u256::default(),
202 pending_bytes_in: 0,
203 pending_bytes_out: 0,
204 server_eph_pub: None,
205 server_hello_signature: None,
206 ratcheted_keys: None,
207 ratcheted_expected_tags: HashMap::new(),
208 is_ratcheted: false,
209 vpn_ip: None,
210 client_id: None,
211 pre_ratchet_tags: HashMap::new(),
212 pre_ratchet_expire: None,
213 }
214 }
215
216 pub fn next_send_nonce(&mut self) -> ([u8; NONCE_SIZE], u64) {
219 let counter = self.send_counter;
220 let mut nonce = [0u8; NONCE_SIZE];
221 nonce[0..8].copy_from_slice(&counter.to_le_bytes());
222 self.send_counter += 1;
223 (nonce, counter)
224 }
225
226 pub fn update_tag_window(&mut self) {
228 let time_window =
229 crypto::compute_time_window(crypto::current_timestamp_ms(), DEFAULT_WINDOW_MS);
230
231 self.expected_tags.clear();
235 self.tag_window_base = self.counter;
236 let window_back = TAG_WINDOW_SIZE as u64 - 1;
237 let window_start = self.counter.saturating_sub(window_back);
238 let window_end = self.counter.saturating_add(TAG_WINDOW_SIZE as u64 - 1);
239
240 for counter_val in window_start..=window_end {
241 let tag =
242 crypto::generate_resonance_tag(&self.keys.tag_secret, counter_val, time_window);
243 self.expected_tags.insert(counter_val, tag);
244 }
245 }
246
247 pub fn validate_tag(&self, tag: &[u8; TAG_SIZE]) -> Option<(u64, bool)> {
252 let is_replay = |counter_val: u64| {
253 if counter_val > self.counter {
254 return false;
255 }
256
257 let bit_index = (self.counter - counter_val) as usize;
258 bit_index < TAG_WINDOW_SIZE && self.received_bitmap.get_bit(bit_index)
259 };
260
261 let history_window = TAG_WINDOW_SIZE as u64 - 1;
262 let window_start = self.counter.saturating_sub(history_window);
263 let window_end = self.counter.saturating_add(TAG_WINDOW_SIZE as u64 - 1);
264
265 for (counter, expected) in &self.expected_tags {
267 if bool::from(expected.ct_eq(tag)) {
268 if is_replay(*counter) {
269 return None; }
271 return Some((*counter, false));
272 }
273 }
274 let current_tw =
276 crypto::compute_time_window(crypto::current_timestamp_ms(), DEFAULT_WINDOW_MS);
277 for tw_offset in [current_tw.wrapping_sub(1), current_tw.wrapping_add(1)] {
278 for counter_val in window_start..=window_end {
279 let expected =
280 crypto::generate_resonance_tag(&self.keys.tag_secret, counter_val, tw_offset);
281 if bool::from(expected.ct_eq(tag)) {
282 if is_replay(counter_val) {
283 return None;
284 }
285 return Some((counter_val, false));
286 }
287 }
288 }
289 if let Some(expire) = self.pre_ratchet_expire {
292 if Instant::now() < expire {
293 for (counter, expected) in &self.pre_ratchet_tags {
294 if bool::from(expected.ct_eq(tag)) {
295 if is_replay(*counter) {
296 return None;
297 }
298 return Some((*counter, false));
299 }
300 }
301 }
302 }
303
304 if !self.is_ratcheted {
306 for (counter, expected) in &self.ratcheted_expected_tags {
307 if bool::from(expected.ct_eq(tag)) {
308 return Some((*counter, true));
309 }
310 }
311 if let Some(ratcheted_keys) = &self.ratcheted_keys {
313 for tw_offset in [current_tw.wrapping_sub(1), current_tw.wrapping_add(1)] {
314 for i in 0..TAG_WINDOW_SIZE {
315 let expected = crypto::generate_resonance_tag(
316 &ratcheted_keys.tag_secret,
317 i as u64,
318 tw_offset,
319 );
320 if bool::from(expected.ct_eq(tag)) {
321 return Some((i as u64, true));
322 }
323 }
324 }
325 }
326 }
327 None
328 }
329
330 pub fn mark_tag_received(&mut self, counter: u64) {
332 if counter > self.counter {
333 let shift = (counter - self.counter) as usize;
334 self.received_bitmap.shift_left(shift);
335 self.counter = counter;
336 self.received_bitmap.set_bit(0);
337 return;
338 }
339
340 let bit_index = (self.counter - counter) as usize;
341 if bit_index < 256 {
342 self.received_bitmap.set_bit(bit_index);
343 }
344 }
345
346 pub fn next_seq(&mut self) -> u32 {
348 let seq = self.send_seq;
349 self.send_seq = self.send_seq.wrapping_add(1);
350 seq
351 }
352
353 pub fn update_fsm(&mut self) {
355 if let Some(mask) = &self.mask {
356 let duration_ms = self.fsm_state_start.elapsed().as_millis() as u64;
357 let (new_state, _size_override, _iat_override, _padding_override) =
358 mask.process_transition(self.fsm_state, self.fsm_packets, duration_ms);
359
360 if new_state != self.fsm_state {
361 self.fsm_state = new_state;
362 self.fsm_packets = 0;
363 self.fsm_state_start = Instant::now();
364 }
365 }
366 self.fsm_packets += 1;
367 }
368
369 pub fn is_idle(&self) -> bool {
371 self.last_seen.elapsed() > IDLE_TIMEOUT
372 }
373
374 pub fn is_expired(&self) -> bool {
376 self.created_at.elapsed() > HARD_TIMEOUT
377 }
378
379 pub fn update_ratcheted_tag_window(&mut self) {
381 if let Some(ratcheted_keys) = &self.ratcheted_keys {
382 let time_window =
383 crypto::compute_time_window(crypto::current_timestamp_ms(), DEFAULT_WINDOW_MS);
384 self.ratcheted_expected_tags.clear();
385 for i in 0..TAG_WINDOW_SIZE {
387 let tag = crypto::generate_resonance_tag(
388 &ratcheted_keys.tag_secret,
389 i as u64,
390 time_window,
391 );
392 self.ratcheted_expected_tags.insert(i as u64, tag);
393 }
394 }
395 }
396
    /// Switches the session from its initial keys to the pending `ratcheted_keys`.
    ///
    /// Does nothing when no ratcheted keys are pending. Otherwise:
    /// - keeps the old `expected_tags` as `pre_ratchet_tags`, accepted by `validate_tag` for
    ///   2 seconds;
    /// - installs the new keys and resets the receive/send counters and replay bitmap;
    /// - promotes `ratcheted_expected_tags` to `expected_tags`;
    /// - clears the handshake-only fields.
    pub fn complete_ratchet(&mut self) {
        if let Some(ratcheted_keys) = self.ratcheted_keys.take() {
            // Must happen before `expected_tags` is overwritten below.
            self.pre_ratchet_tags = std::mem::take(&mut self.expected_tags);
            self.pre_ratchet_expire = Some(Instant::now() + Duration::from_secs(2));

            // New key epoch: counters and replay history start over.
            self.keys = ratcheted_keys;
            self.counter = 0;
            self.send_counter = 0;
            self.tag_window_base = self.counter;
            // The ratcheted window (counters 0..TAG_WINDOW_SIZE) becomes the live window;
            // `ratcheted_expected_tags` is left empty.
            self.expected_tags = std::mem::take(&mut self.ratcheted_expected_tags);
            self.received_bitmap.clear();
            self.pending_bytes_in = 0;
            self.pending_bytes_out = 0;
            self.is_ratcheted = true;
            // Handshake material is no longer needed once the ratchet has completed.
            self.server_eph_pub = None;
            self.server_hello_signature = None;
        }
    }
418
419 pub fn commit_pending_mask(&mut self) -> bool {
423 const MASK_GRACE_PERIOD: Duration = Duration::from_millis(500);
424 if let Some((_, sent_at)) = &self.pending_mask {
425 if sent_at.elapsed() >= MASK_GRACE_PERIOD {
426 let (new_mask, _) = self.pending_mask.take().unwrap();
427 info!("Committing deferred mask switch to '{}'", new_mask.mask_id);
428 self.mask = Some(new_mask);
429 self.fsm_state = 0;
431 self.fsm_packets = 0;
432 self.fsm_state_start = Instant::now();
433 return true;
434 }
435 }
436 false
437 }
438}
439
/// Registry of all live sessions, plus the indexes and key material needed to create them and
/// route packets to them.
pub struct SessionManager {
    /// Live sessions keyed by their random 16-byte session id.
    sessions: DashMap<[u8; 16], Arc<Mutex<Session>>>,
    /// Reverse index: expected resonance tag -> id of the session that expects it.
    tag_map: DashMap<[u8; TAG_SIZE], [u8; 16]>,
    /// VPN address -> id of the session that address is currently routed to.
    vpn_ip_map: DashMap<Ipv4Addr, [u8; 16]>,
    /// Free last octets for automatically assigned `10.0.0.x` addresses (seeded with 2..=254).
    ip_pool: Mutex<BTreeSet<u8>>,
    /// Server X25519 key pair used for the first handshake DH; its public half is exposed by
    /// `server_public_key`.
    server_keys: KeyPair,
    /// Ed25519 key that signs server handshake data and mask updates.
    signing_key: ed25519_dalek::SigningKey,
    /// Mask used for outgoing packet headers when a session has none of its own.
    default_mask: MaskProfile,
    /// Absolute session lifetime; `Duration::ZERO` disables it in `cleanup_expired`.
    hard_timeout: Duration,
    /// Inactivity limit applied by `cleanup_expired`.
    idle_timeout: Duration,
}
462
463impl SessionManager {
464 pub fn new(
465 server_keys: KeyPair,
466 signing_key: ed25519_dalek::SigningKey,
467 default_mask: MaskProfile,
468 ) -> Self {
469 Self::with_timeouts(server_keys, signing_key, default_mask, None, None)
470 }
471
472 pub fn with_timeouts(
473 server_keys: KeyPair,
474 signing_key: ed25519_dalek::SigningKey,
475 default_mask: MaskProfile,
476 session_timeout_secs: Option<u64>,
477 idle_timeout_secs: Option<u64>,
478 ) -> Self {
479 let hard_timeout = session_timeout_secs
480 .map(|s| Duration::from_secs(s))
481 .unwrap_or(HARD_TIMEOUT);
482 let idle_timeout = idle_timeout_secs
483 .map(|s| Duration::from_secs(s))
484 .unwrap_or(IDLE_TIMEOUT);
485 Self {
486 sessions: DashMap::new(),
487 tag_map: DashMap::new(),
488 vpn_ip_map: DashMap::new(),
489 ip_pool: Mutex::new((2..=254u8).collect()),
490 server_keys,
491 signing_key,
492 default_mask,
493 hard_timeout,
494 idle_timeout,
495 }
496 }
497
498 pub fn create_session(
503 &self,
504 client_addr: SocketAddr,
505 eph_pub: [u8; X25519_PUBLIC_KEY_SIZE],
506 preshared_key: Option<[u8; 32]>,
507 static_vpn_ip: Option<Ipv4Addr>,
508 ) -> Result<Arc<Mutex<Session>>> {
509 let reused_vpn_ip: Option<Ipv4Addr> = self
513 .sessions
514 .iter()
515 .filter_map(|entry| {
516 let session = entry.value().lock();
517 if session.client_addr.ip() == client_addr.ip() {
518 session.vpn_ip
519 } else {
520 None
521 }
522 })
523 .next();
524
525 if self.sessions.len() >= MAX_SESSIONS {
526 return Err(Error::Session("Max sessions reached".into()));
527 }
528
529 let ip_count = self
531 .sessions
532 .iter()
533 .filter(|e| e.value().lock().client_addr.ip() == client_addr.ip())
534 .count();
535 if ip_count >= 5 {
536 return Err(Error::Session("Per-IP session limit reached".into()));
537 }
538
539 if let std::net::IpAddr::V4(v4) = client_addr.ip() {
544 let subnet24 = u32::from(v4) >> 8;
545 let subnet_count = self
546 .sessions
547 .iter()
548 .filter(|e| {
549 if let std::net::IpAddr::V4(ip) = e.value().lock().client_addr.ip() {
550 (u32::from(ip) >> 8) == subnet24
551 } else {
552 false
553 }
554 })
555 .count();
556 if subnet_count >= 10 {
557 return Err(Error::Session(
558 "Per-subnet (/24) session limit reached".into(),
559 ));
560 }
561 }
562
563 let dh1 = self.server_keys.compute_shared(&eph_pub)?;
565 trace!("Server DH result: {}", hex::encode(&dh1));
566 trace!(
567 "Server eph_pub (after deobfuscation): {}",
568 hex::encode(&eph_pub)
569 );
570 trace!("Server PSK: {:?}", preshared_key.as_ref().map(hex::encode));
571 let initial_keys = crypto::derive_session_keys(&dh1, preshared_key.as_ref(), &eph_pub);
572 trace!(
573 "Server tag_secret: {}",
574 hex::encode(&initial_keys.tag_secret)
575 );
576
577 let server_eph_kp = crypto::KeyPair::generate();
580 let server_eph_pub = server_eph_kp.public_key_bytes();
581
582 let dh2 = server_eph_kp.compute_shared(&eph_pub)?;
584 let ratcheted_keys =
586 crypto::derive_session_keys(&dh2, Some(&initial_keys.session_key), &eph_pub);
587
588 use ed25519_dalek::Signer;
590 let mut sign_message = Vec::with_capacity(64);
591 sign_message.extend_from_slice(&server_eph_pub);
592 sign_message.extend_from_slice(&eph_pub);
593 let signature = self.signing_key.sign(&sign_message).to_bytes();
594
595 let mut session_id = [0u8; 16];
597 OsRng.fill_bytes(&mut session_id);
598
599 let session = Arc::new(Mutex::new(Session::new(
601 session_id,
602 client_addr,
603 initial_keys,
604 eph_pub,
605 )));
606
607 {
609 let mut sess = session.lock();
610 sess.state = SessionState::Active;
611
612 sess.server_eph_pub = Some(server_eph_pub);
614 sess.server_hello_signature = Some(signature);
615 sess.ratcheted_keys = Some(ratcheted_keys);
616
617 sess.update_tag_window();
619 for tag in sess.expected_tags.values() {
620 self.tag_map.insert(*tag, session_id);
621 }
622
623 sess.update_ratcheted_tag_window();
625 for tag in sess.ratcheted_expected_tags.values() {
626 self.tag_map.insert(*tag, session_id);
627 }
628 }
629
630 self.sessions.insert(session_id, session.clone());
632
633 let vpn_ip = if let Some(ip) = static_vpn_ip.or(reused_vpn_ip) {
636 self.ip_pool.lock().remove(&ip.octets()[3]);
638 Some(ip)
639 } else {
640 self.ip_pool
642 .lock()
643 .pop_first()
644 .map(|octet| Ipv4Addr::new(10, 0, 0, octet))
645 };
646
647 if let Some(vpn_ip) = vpn_ip {
648 session.lock().vpn_ip = Some(vpn_ip);
649 self.vpn_ip_map.insert(vpn_ip, session_id);
650 debug!("Assigned VPN IP {} to session", vpn_ip);
651 }
652
653 Ok(session)
654 }
655
656 pub fn cleanup_old_sessions_for_ip(
660 &self,
661 ip: &std::net::IpAddr,
662 keep_session_id: &[u8; 16],
663 ) -> Vec<[u8; 16]> {
664 let to_remove: Vec<[u8; 16]> = self
665 .sessions
666 .iter()
667 .filter_map(|entry| {
668 let session = entry.value().lock();
669 if session.client_addr.ip() == *ip && entry.key() != keep_session_id {
670 Some(*entry.key())
671 } else {
672 None
673 }
674 })
675 .collect();
676
677 let mut removed = Vec::new();
678 for session_id in to_remove {
679 info!(
680 "Removing stale session for IP {} after successful re-handshake",
681 ip
682 );
683 if self.remove_session(&session_id).is_some() {
684 removed.push(session_id);
685 }
686 }
687 removed
688 }
689
690 pub fn cleanup_old_sessions_for_vpn_ip(
695 &self,
696 vpn_ip: &Ipv4Addr,
697 keep_session_id: &[u8; 16],
698 ) -> Vec<[u8; 16]> {
699 let to_remove: Vec<[u8; 16]> = self
700 .sessions
701 .iter()
702 .filter_map(|entry| {
703 let session = entry.value().lock();
704 if session.vpn_ip == Some(*vpn_ip) && entry.key() != keep_session_id {
705 Some(*entry.key())
706 } else {
707 None
708 }
709 })
710 .collect();
711
712 let mut removed = Vec::new();
713 for session_id in to_remove {
714 info!(
715 "Removing stale session for VPN IP {} after successful re-handshake",
716 vpn_ip
717 );
718 if self.remove_session(&session_id).is_some() {
719 removed.push(session_id);
720 }
721 }
722 removed
723 }
724
725 pub fn rollback_failed_session(&self, session_id: &[u8; 16]) {
728 let vpn_ip = self
730 .sessions
731 .get(session_id)
732 .map(|e| e.value().lock().vpn_ip)
733 .flatten();
734
735 self.remove_session(session_id);
736
737 if let Some(vpn_ip) = vpn_ip {
740 for entry in self.sessions.iter() {
741 let sess = entry.value().lock();
742 if sess.vpn_ip == Some(vpn_ip) {
743 self.vpn_ip_map.insert(vpn_ip, *entry.key());
744 self.ip_pool.lock().remove(&vpn_ip.octets()[3]);
745 break;
746 }
747 }
748 }
749 }
750
751 pub fn has_recent_ratcheted_session_on_other_endpoint(
755 &self,
756 client_addr: &SocketAddr,
757 max_age: Duration,
758 ) -> bool {
759 self.sessions.iter().any(|entry| {
760 let sess = entry.value().lock();
761 sess.client_addr.ip() == client_addr.ip()
762 && sess.client_addr != *client_addr
763 && sess.is_ratcheted
764 && sess.last_seen.elapsed() <= max_age
765 })
766 }
767
768 pub fn get_session_by_tag(&self, tag: &[u8; TAG_SIZE]) -> Option<Arc<Mutex<Session>>> {
770 if let Some(entry) = self.tag_map.get(tag) {
771 let session_id = *entry;
772 drop(entry);
773 self.sessions.get(&session_id).map(|e| e.clone())
774 } else {
775 None
776 }
777 }
778
779 pub fn refresh_and_find_by_tag(
782 &self,
783 tag: &[u8; TAG_SIZE],
784 ) -> Option<(Arc<Mutex<Session>>, u64, bool)> {
785 for entry in self.sessions.iter() {
786 let session = entry.value().clone();
787 let session_id = *entry.key();
788 let mut sess = session.lock();
789
790 let old_tags: Vec<[u8; TAG_SIZE]> = sess.expected_tags.values().cloned().collect();
792 for old_tag in &old_tags {
793 self.tag_map.remove(old_tag);
794 }
795 sess.update_tag_window();
796 for t in sess.expected_tags.values() {
797 self.tag_map.insert(*t, session_id);
798 }
799
800 let old_ratcheted: Vec<[u8; TAG_SIZE]> =
802 sess.ratcheted_expected_tags.values().cloned().collect();
803 for old_tag in &old_ratcheted {
804 self.tag_map.remove(old_tag);
805 }
806 sess.update_ratcheted_tag_window();
807 for t in sess.ratcheted_expected_tags.values() {
808 self.tag_map.insert(*t, session_id);
809 }
810
811 if let Some((counter, is_ratcheted)) = sess.validate_tag(tag) {
813 drop(sess);
814 return Some((session, counter, is_ratcheted));
815 }
816 }
817 None
818 }
819
820 pub fn recover_session_by_tag(
825 &self,
826 tag: &[u8; TAG_SIZE],
827 client_ip: &std::net::IpAddr,
828 ) -> Option<(Arc<Mutex<Session>>, u64, bool)> {
829 let current_tw =
830 crypto::compute_time_window(crypto::current_timestamp_ms(), DEFAULT_WINDOW_MS);
831 const RECOVERY_RANGE: u64 = 65536;
833
834 for entry in self.sessions.iter() {
835 let session = entry.value().clone();
836 let session_id = *entry.key();
837 let sess = session.lock();
838 if sess.client_addr.ip() != *client_ip {
839 continue;
840 }
841
842 let base = sess.counter;
843 let tag_secret = &sess.keys.tag_secret;
844
845 for tw_offset in [0i64, -1, 1] {
846 let tw = (current_tw as i64 + tw_offset) as u64;
847 for i in 0..RECOVERY_RANGE {
848 let c = base + i;
849 let expected = crypto::generate_resonance_tag(tag_secret, c, tw);
850 if bool::from(expected.ct_eq(tag)) {
851 info!(
852 "Counter recovery: found counter {} (drift={}) for session",
853 c, i
854 );
855 drop(sess);
857 {
858 let mut s = session.lock();
859 s.counter = c;
860 s.update_tag_window();
861 }
862 self.tag_map.retain(|_, id| id != &session_id);
864 let s = session.lock();
865 for t in s.expected_tags.values() {
866 self.tag_map.insert(*t, session_id);
867 }
868 drop(s);
869 return Some((session, c, false));
870 }
871 }
872 }
873 }
874 None
875 }
876
877 pub fn get_session(&self, session_id: &[u8; 16]) -> Option<Arc<Mutex<Session>>> {
879 self.sessions.get(session_id).map(|e| e.clone())
880 }
881
882 pub fn get_session_by_vpn_ip(&self, vpn_ip: &Ipv4Addr) -> Option<Arc<Mutex<Session>>> {
884 if let Some(entry) = self.vpn_ip_map.get(vpn_ip) {
885 let session_id = *entry;
886 drop(entry);
887 self.sessions.get(&session_id).map(|e| e.clone())
888 } else {
889 None
890 }
891 }
892
893 pub fn remove_session(&self, session_id: &[u8; 16]) -> Option<[u8; 16]> {
896 if let Some((_, session)) = self.sessions.remove(session_id) {
897 let sess = session.lock();
898 for tag in sess.expected_tags.values() {
900 self.tag_map.remove(tag);
901 }
902 for tag in sess.ratcheted_expected_tags.values() {
903 self.tag_map.remove(tag);
904 }
905 if let Some(vpn_ip) = sess.vpn_ip {
908 if self
909 .vpn_ip_map
910 .remove_if(&vpn_ip, |_, sid| sid == session_id)
911 .is_some()
912 {
913 let octet = vpn_ip.octets()[3];
915 if octet >= 2 {
916 self.ip_pool.lock().insert(octet);
917 }
918 }
919 }
920 Some(*session_id)
921 } else {
922 None
923 }
924 }
925
926 pub fn refresh_session_tags(&self, session_id: &[u8; 16]) {
928 if let Some(session) = self.sessions.get(session_id) {
929 let sess = session.lock();
930 self.tag_map.retain(|_, id| id != session_id);
932 for tag in sess.expected_tags.values() {
934 self.tag_map.insert(*tag, *session_id);
935 }
936 for tag in sess.ratcheted_expected_tags.values() {
937 self.tag_map.insert(*tag, *session_id);
938 }
939 }
940 }
941
942 pub fn complete_session_ratchet(&self, session_id: &[u8; 16]) {
944 if let Some(session) = self.sessions.get(session_id) {
945 let mut sess = session.lock();
946 for tag in sess.expected_tags.values() {
948 self.tag_map.remove(tag);
949 }
950 sess.complete_ratchet();
952 for tag in sess.expected_tags.values() {
954 self.tag_map.insert(*tag, *session_id);
955 }
956 }
957 }
958
959 pub fn cleanup_expired(&self) -> Vec<[u8; 16]> {
962 let expired: Vec<[u8; 16]> = self
963 .sessions
964 .iter()
965 .filter(|e| {
966 let sess = e.value().lock();
967 sess.last_seen.elapsed() > self.idle_timeout
968 || (self.hard_timeout > Duration::ZERO
969 && sess.created_at.elapsed() > self.hard_timeout)
970 })
971 .map(|e| *e.key())
972 .collect();
973
974 let mut removed = Vec::new();
975 for session_id in expired {
976 if self.remove_session(&session_id).is_some() {
977 removed.push(session_id);
978 }
979 }
980 removed
981 }
982
983 pub fn session_count(&self) -> usize {
985 self.sessions.len()
986 }
987
988 pub fn log_session_diagnostics(&self, incoming_tag: &[u8; TAG_SIZE]) {
990 let tag_map_size = self.tag_map.len();
991 let current_tw =
992 crypto::compute_time_window(crypto::current_timestamp_ms(), DEFAULT_WINDOW_MS);
993 info!(
994 "DIAG: tag_map_size={}, current_tw={}",
995 tag_map_size, current_tw
996 );
997 for entry in self.sessions.iter() {
998 let sess = entry.value().lock();
999 let sid_hex = format!(
1000 "{:02x}{:02x}{:02x}{:02x}",
1001 entry.key()[0],
1002 entry.key()[1],
1003 entry.key()[2],
1004 entry.key()[3]
1005 );
1006 let is_ratcheted = sess.is_ratcheted;
1007 let counter = sess.counter;
1008 let expected_count = sess.expected_tags.len();
1009 let ratcheted_count = sess.ratcheted_expected_tags.len();
1010 let has_ratcheted_keys = sess.ratcheted_keys.is_some();
1011 let mut found = false;
1013 for (c, t) in &sess.expected_tags {
1014 if t == incoming_tag {
1015 found = true;
1016 info!(
1017 "DIAG: Session {} — expected tag MATCHES at counter {}",
1018 sid_hex, c
1019 );
1020 break;
1021 }
1022 }
1023 info!(
1024 "DIAG: Session {} — ratcheted={}, counter={}, expected_tags={}, ratcheted_tags={}, has_ratchet_keys={}, tag_matched={}",
1025 sid_hex, is_ratcheted, counter, expected_count, ratcheted_count, has_ratcheted_keys, found
1026 );
1027 }
1028 }
1029
1030 pub fn server_public_key(&self) -> [u8; X25519_PUBLIC_KEY_SIZE] {
1032 self.server_keys.public_key_bytes()
1033 }
1034
1035 pub fn sign_mask(&self, mask_data: &[u8]) -> [u8; 64] {
1037 use ed25519_dalek::Signer;
1038 let signature = self.signing_key.sign(mask_data);
1039 signature.to_bytes()
1040 }
1041
1042 pub fn iter_sessions(&self) -> dashmap::iter::Iter<'_, [u8; 16], Arc<Mutex<Session>>> {
1044 self.sessions.iter()
1045 }
1046
1047 pub fn update_session_mask(
1052 &self,
1053 session_id: &[u8; 16],
1054 new_mask: MaskProfile,
1055 ) -> Option<(Arc<Mutex<Session>>, SocketAddr)> {
1056 if let Some(session) = self.sessions.get(session_id) {
1057 let client_addr;
1058 {
1059 let mut sess = session.lock();
1060 info!(
1061 "Session mask scheduled: {} → {} (grace period 500ms)",
1062 sess.mask
1063 .as_ref()
1064 .map(|m| m.mask_id.as_str())
1065 .unwrap_or("default"),
1066 new_mask.mask_id
1067 );
1068 sess.pending_mask = Some((new_mask, Instant::now()));
1070 sess.state = SessionState::Active;
1071 client_addr = sess.client_addr;
1072 }
1073 Some((session.clone(), client_addr))
1074 } else {
1075 None
1076 }
1077 }
1078
1079 pub fn build_mask_update_packet(
1082 &self,
1083 session: &Arc<Mutex<Session>>,
1084 new_mask: &MaskProfile,
1085 ) -> Result<Vec<u8>> {
1086 use aivpn_common::crypto::encrypt_payload;
1087
1088 let mask_data = rmp_serde::to_vec(new_mask)
1090 .map_err(|e| Error::Session(format!("Failed to serialize mask: {}", e)))?;
1091
1092 let signature = self.sign_mask(&mask_data);
1094
1095 let control = ControlPayload::MaskUpdate {
1097 mask_data,
1098 signature,
1099 };
1100 let encoded = control.encode()?;
1101
1102 let mut sess = session.lock();
1103 let inner_header = InnerHeader {
1104 inner_type: InnerType::Control,
1105 seq_num: sess.next_seq() as u16,
1106 };
1107 let mut inner_payload = inner_header.encode().to_vec();
1108 inner_payload.extend_from_slice(&encoded);
1109
1110 let (nonce, counter) = sess.next_send_nonce();
1112 let pad_len = 16u16;
1113 let mut padded = Vec::with_capacity(2 + inner_payload.len() + pad_len as usize);
1114 padded.extend_from_slice(&pad_len.to_le_bytes());
1115 padded.extend_from_slice(&inner_payload);
1116 {
1117 use rand::Rng;
1118 let mut rng = rand::thread_rng();
1119 for _ in 0..pad_len {
1120 padded.push(rng.gen::<u8>());
1121 }
1122 }
1123
1124 let ciphertext = encrypt_payload(&sess.keys.session_key, &nonce, &padded)?;
1125
1126 let time_window =
1128 crypto::compute_time_window(crypto::current_timestamp_ms(), DEFAULT_WINDOW_MS);
1129 let tag = crypto::generate_resonance_tag(&sess.keys.tag_secret, counter, time_window);
1130
1131 let transport_mask = sess.mask.as_ref().unwrap_or(&self.default_mask);
1134 let mdh = if let Some(ref spec) = transport_mask.header_spec {
1135 let mut rng = rand::thread_rng();
1136 spec.generate(&mut rng)
1137 } else {
1138 transport_mask.header_template.clone()
1139 };
1140
1141 let mut packet = Vec::with_capacity(TAG_SIZE + mdh.len() + ciphertext.len());
1143 packet.extend_from_slice(&tag);
1144 packet.extend_from_slice(&mdh);
1145 packet.extend_from_slice(&ciphertext);
1146
1147 Ok(packet)
1148 }
1149}