Skip to main content

aivpn_server/
neural.rs

1//! Neural Resonance Module — Baked Mask Encoder
2//!
3//! Implements Signal Reconstruction Resonance (Patent 1) using a lightweight
4//! hand-rolled MLP instead of a full LLM. Each mask's signature_vector is
5//! "baked" into a tiny neural network (64 → 128 → 64) whose weights are
6//! derived deterministically from the mask's 64-float signature.
7//!
8//! Memory per mask: ~66 KB (vs ~400 MB for Qwen-0.5B).
9//! Total for 100 masks: ~6.6 MB — fits any VPS.
10//!
11//! The baked encoder captures the mask's traffic fingerprint (no training involved):
12//! - Input: 64-dim feature vector extracted from live traffic
13//! - Output: 64-dim reconstruction vector
14//! - MSE(input, output) = reconstruction error = resonance score
15//!
16//! Low MSE → traffic matches the mask → healthy
17//! High MSE → traffic deviates from mask signature → DPI compromise detected
18
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use tracing::{debug, info};
22
23use aivpn_common::mask::MaskProfile;
24
25// ── Configuration ────────────────────────────────────────────────────────────
26
27/// Neural Resonance Module configuration
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct NeuralConfig {
30    /// Hidden layer size for baked MLP
31    pub hidden_size: usize,
32
33    /// Resonance check interval (seconds)
34    pub check_interval_secs: u64,
35
36    /// MSE threshold for compromised mask
37    pub compromised_threshold: f32,
38
39    /// MSE threshold for warning
40    pub warning_threshold: f32,
41
42    /// Enable anomaly detection
43    pub enable_anomaly_detection: bool,
44}
45
46impl Default for NeuralConfig {
47    fn default() -> Self {
48        Self {
49            hidden_size: 128,
50            check_interval_secs: 30,
51            compromised_threshold: 0.35,
52            warning_threshold: 0.15,
53            enable_anomaly_detection: true,
54        }
55    }
56}
57
58// ── Traffic Statistics ───────────────────────────────────────────────────────
59
/// Rolling per-session traffic samples and rates used as input to the
/// feature encoder.
#[derive(Debug, Clone, Default)]
pub struct TrafficStats {
    /// Sizes of the most recent packets.
    pub packet_sizes: Vec<u16>,
    /// Inter-arrival times of the most recent packets, in milliseconds.
    pub inter_arrivals: Vec<f64>,
    /// Byte-level entropy samples of the most recent packets.
    pub entropy_samples: Vec<f64>,
    /// Packet rate, in packets per second.
    pub pps: f64,
    /// Throughput, in bytes per second.
    pub bps: f64,
}
74
75impl TrafficStats {
76    pub fn new() -> Self {
77        Self {
78            packet_sizes: Vec::with_capacity(256),
79            inter_arrivals: Vec::with_capacity(256),
80            entropy_samples: Vec::with_capacity(256),
81            pps: 0.0,
82            bps: 0.0,
83        }
84    }
85
86    /// Add packet sample
87    pub fn add_packet(&mut self, size: u16, iat_ms: f64, entropy: f64) {
88        self.packet_sizes.push(size);
89        self.inter_arrivals.push(iat_ms);
90        self.entropy_samples.push(entropy);
91        // Keep last 256 samples
92        if self.packet_sizes.len() > 256 {
93            self.packet_sizes.remove(0);
94            self.inter_arrivals.remove(0);
95            self.entropy_samples.remove(0);
96        }
97    }
98
99    /// Clear stats
100    pub fn clear(&mut self) {
101        self.packet_sizes.clear();
102        self.inter_arrivals.clear();
103        self.entropy_samples.clear();
104        self.pps = 0.0;
105        self.bps = 0.0;
106    }
107}
108
109// ── Baked Mask Encoder (the tiny neural network) ─────────────────────────────
110
/// Width of every feature vector; also the number of leading
/// `signature_vector` floats that feed an encoder's weights.
const FEAT_DIM: usize = 64;
113
/// Small two-layer perceptron whose parameters are computed directly from a
/// mask's signature vector — nothing is trained.
///
/// Shape: Linear(64→H) → ReLU → Linear(H→64), where H is the hidden width.
///
/// Parameter derivation (deterministic):
/// - every weight is the product of a signature component, a mixing factor
///   taken from one byte of the BLAKE3 digest of the signature, and a
///   Xavier-style scale;
/// - every bias is a signature component multiplied by 0.01.
///
/// Storage: (64*H + H + H*64 + 64) * 4 bytes ≈ 66 KB when H = 128.
pub struct BakedMaskEncoder {
    /// First-layer weights, row-major `[hidden × FEAT_DIM]`.
    w1: Vec<f32>,
    /// First-layer biases, `[hidden]`.
    b1: Vec<f32>,
    /// Second-layer weights, row-major `[FEAT_DIM × hidden]`.
    w2: Vec<f32>,
    /// Second-layer biases, `[FEAT_DIM]`.
    b2: Vec<f32>,
    /// Hidden-layer width H.
    hidden: usize,
}
131
132impl BakedMaskEncoder {
133    /// Bake an encoder from a mask's signature vector.
134    pub fn from_signature(signature: &[f32], hidden: usize) -> Self {
135        assert!(
136            signature.len() >= FEAT_DIM,
137            "signature must have at least {} floats",
138            FEAT_DIM
139        );
140
141        // Deterministic seed from signature for mixing
142        let sig_bytes: Vec<u8> = signature.iter().flat_map(|f| f.to_le_bytes()).collect();
143        let seed = blake3::hash(&sig_bytes);
144        let seed_bytes = seed.as_bytes();
145
146        let mut w1 = vec![0.0f32; hidden * FEAT_DIM];
147        let mut b1 = vec![0.0f32; hidden];
148        let mut w2 = vec![0.0f32; FEAT_DIM * hidden];
149        let mut b2 = vec![0.0f32; FEAT_DIM];
150
151        // Xavier-scale initialization seeded by signature
152        let scale = (2.0 / (FEAT_DIM + hidden) as f32).sqrt();
153
154        for i in 0..hidden {
155            for j in 0..FEAT_DIM {
156                let idx = (i * FEAT_DIM + j) % 32;
157                let mix = (seed_bytes[idx] as f32 - 128.0) / 128.0;
158                w1[i * FEAT_DIM + j] = signature[j % FEAT_DIM] * mix * scale;
159            }
160            b1[i] = signature[i % FEAT_DIM] * 0.01;
161        }
162
163        for j in 0..FEAT_DIM {
164            for i in 0..hidden {
165                let idx = (j * hidden + i) % 32;
166                let mix = (seed_bytes[idx] as f32 - 128.0) / 128.0;
167                w2[j * hidden + i] = signature[j % FEAT_DIM] * mix * scale;
168            }
169            b2[j] = signature[j] * 0.01;
170        }
171
172        Self {
173            w1,
174            b1,
175            w2,
176            b2,
177            hidden,
178        }
179    }
180
181    /// Forward pass: x → Linear → ReLU → Linear → output
182    pub fn forward(&self, input: &[f32; FEAT_DIM]) -> [f32; FEAT_DIM] {
183        // Layer 1: hidden = ReLU(W1 · input + b1)
184        let mut h = vec![0.0f32; self.hidden];
185        for i in 0..self.hidden {
186            let mut sum = self.b1[i];
187            let row = &self.w1[i * FEAT_DIM..(i + 1) * FEAT_DIM];
188            for j in 0..FEAT_DIM {
189                sum += row[j] * input[j];
190            }
191            h[i] = sum.max(0.0); // ReLU
192        }
193
194        // Layer 2: output = W2 · hidden + b2
195        let mut output = [0.0f32; FEAT_DIM];
196        for j in 0..FEAT_DIM {
197            let mut sum = self.b2[j];
198            let row = &self.w2[j * self.hidden..(j + 1) * self.hidden];
199            for i in 0..self.hidden {
200                sum += row[i] * h[i];
201            }
202            output[j] = sum;
203        }
204        output
205    }
206
207    /// Reconstruction error (MSE) between input features and reconstruction
208    pub fn reconstruction_error(&self, features: &[f32; FEAT_DIM]) -> f32 {
209        let recon = self.forward(features);
210        let mut mse = 0.0f32;
211        for i in 0..FEAT_DIM {
212            let diff = features[i] - recon[i];
213            mse += diff * diff;
214        }
215        mse / FEAT_DIM as f32
216    }
217
218    /// Memory usage in bytes
219    pub fn memory_bytes(&self) -> usize {
220        (self.w1.len() + self.b1.len() + self.w2.len() + self.b2.len()) * 4
221    }
222}
223
224// ── Feature Encoding ─────────────────────────────────────────────────────────
225
226/// Encode traffic stats into a 64-dim feature vector
227pub fn encode_features(stats: &TrafficStats) -> [f32; FEAT_DIM] {
228    let mut features = [0.0f32; FEAT_DIM];
229
230    // Block 1 (0–15): Packet size histogram (16 bins)
231    if !stats.packet_sizes.is_empty() {
232        let bins: [usize; 16] = [
233            0, 64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 896, 1024, 1280,
234        ];
235        for &size in &stats.packet_sizes {
236            for j in 0..15 {
237                if (size as usize) >= bins[j] && (size as usize) < bins[j + 1] {
238                    features[j] += 1.0;
239                    break;
240                }
241            }
242        }
243        let n = stats.packet_sizes.len() as f32;
244        for f in features[0..16].iter_mut() {
245            *f /= n;
246        }
247    }
248
249    // Block 2 (16–31): IAT statistics
250    if !stats.inter_arrivals.is_empty() {
251        let n = stats.inter_arrivals.len() as f64;
252        let mean = stats.inter_arrivals.iter().sum::<f64>() / n;
253        let variance = stats
254            .inter_arrivals
255            .iter()
256            .map(|&x| (x - mean).powi(2))
257            .sum::<f64>()
258            / n;
259        let std_dev = variance.sqrt();
260        let max_val = stats
261            .inter_arrivals
262            .iter()
263            .cloned()
264            .fold(f64::NEG_INFINITY, f64::max);
265        let min_val = stats
266            .inter_arrivals
267            .iter()
268            .cloned()
269            .fold(f64::INFINITY, f64::min);
270
271        features[16] = (mean / 100.0) as f32;
272        features[17] = (std_dev / 100.0) as f32;
273        features[18] = (max_val / 1000.0) as f32;
274        features[19] = (min_val / 1000.0) as f32;
275        // Percentiles
276        let mut sorted = stats.inter_arrivals.clone();
277        sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
278        features[20] = (sorted[sorted.len() / 4] / 100.0) as f32;
279        features[21] = (sorted[sorted.len() / 2] / 100.0) as f32;
280        features[22] = (sorted[sorted.len() * 3 / 4] / 100.0) as f32;
281        features[23] = if mean > 0.0 {
282            (std_dev / mean) as f32
283        } else {
284            0.0
285        };
286    }
287
288    // Block 3 (32–47): Entropy features
289    if !stats.entropy_samples.is_empty() {
290        let n = stats.entropy_samples.len() as f64;
291        let mean = stats.entropy_samples.iter().sum::<f64>() / n;
292        let variance = stats
293            .entropy_samples
294            .iter()
295            .map(|&x| (x - mean).powi(2))
296            .sum::<f64>()
297            / n;
298        features[32] = (mean / 8.0) as f32;
299        features[33] = (variance.sqrt() / 8.0) as f32;
300        let max_val = stats
301            .entropy_samples
302            .iter()
303            .cloned()
304            .fold(f64::NEG_INFINITY, f64::max);
305        let min_val = stats
306            .entropy_samples
307            .iter()
308            .cloned()
309            .fold(f64::INFINITY, f64::min);
310        features[34] = (max_val / 8.0) as f32;
311        features[35] = (min_val / 8.0) as f32;
312    }
313
314    // Block 4 (48–63): Temporal features
315    features[48] = stats.pps as f32 / 1000.0;
316    features[49] = stats.bps as f32 / 1_000_000.0;
317    if !stats.packet_sizes.is_empty() {
318        let n = stats.packet_sizes.len() as f32;
319        let mean_size: f32 = stats.packet_sizes.iter().map(|&s| s as f32).sum::<f32>() / n;
320        features[50] = mean_size / 1500.0;
321        let var: f32 = stats
322            .packet_sizes
323            .iter()
324            .map(|&s| (s as f32 - mean_size).powi(2))
325            .sum::<f32>()
326            / n;
327        features[51] = var.sqrt() / 1500.0;
328    }
329
330    features
331}
332
333// ── Anomaly Detector ─────────────────────────────────────────────────────────
334
/// Tracks recent per-mask packet loss and RTT and compares their averages
/// against fixed baselines to spot possible DPI interference.
pub struct AnomalyDetector {
    /// Recent packet-loss samples, keyed by mask id.
    mask_packet_loss: HashMap<String, Vec<f64>>,
    /// Recent RTT samples in milliseconds, keyed by mask id.
    mask_rtt: HashMap<String, Vec<f64>>,
    /// Packet-loss level regarded as normal.
    baseline_loss: f64,
    /// RTT, in milliseconds, regarded as normal.
    baseline_rtt: f64,
}
342
343impl AnomalyDetector {
344    pub fn new() -> Self {
345        Self {
346            mask_packet_loss: HashMap::new(),
347            mask_rtt: HashMap::new(),
348            baseline_loss: 0.01,
349            baseline_rtt: 50.0,
350        }
351    }
352
353    pub fn record_metrics(&mut self, mask_id: &str, packet_loss: f64, rtt_ms: f64) {
354        let losses = self
355            .mask_packet_loss
356            .entry(mask_id.to_string())
357            .or_default();
358        losses.push(packet_loss);
359        if losses.len() > 100 {
360            losses.remove(0);
361        }
362
363        let rtts = self.mask_rtt.entry(mask_id.to_string()).or_default();
364        rtts.push(rtt_ms);
365        if rtts.len() > 100 {
366            rtts.remove(0);
367        }
368    }
369
370    pub fn is_anomalous(&self, mask_id: &str) -> bool {
371        if let Some(losses) = self.mask_packet_loss.get(mask_id) {
372            if losses.len() >= 10 {
373                let avg = losses.iter().sum::<f64>() / losses.len() as f64;
374                if avg > self.baseline_loss * 5.0 {
375                    return true;
376                }
377            }
378        }
379        if let Some(rtts) = self.mask_rtt.get(mask_id) {
380            if rtts.len() >= 10 {
381                let avg = rtts.iter().sum::<f64>() / rtts.len() as f64;
382                if avg > self.baseline_rtt * 3.0 {
383                    return true;
384                }
385            }
386        }
387        false
388    }
389}
390
391// ── Neural Resonance Module ──────────────────────────────────────────────────
392
393/// Neural Resonance Module
394///
395/// Uses Baked Mask Encoders instead of an external LLM.
396/// Each mask's signature_vector is baked into a tiny MLP (~66KB).
397/// Total memory: O(num_masks * 66KB) — fits any VPS.
398pub struct NeuralResonanceModule {
399    config: NeuralConfig,
400
401    /// Baked encoders per mask (mask_id -> encoder)
402    encoders: HashMap<String, BakedMaskEncoder>,
403
404    /// Per-session traffic stats
405    session_stats: dashmap::DashMap<[u8; 16], TrafficStats>,
406
407    /// Anomaly detection state
408    anomaly_detector: AnomalyDetector,
409
410    /// Whether the module is loaded
411    loaded: bool,
412}
413
414/// Resonance check result
415#[derive(Debug, Clone)]
416pub struct ResonanceResult {
417    pub mse: f32,
418    pub status: ResonanceStatus,
419    pub message: Option<String>,
420}
421
422impl ResonanceResult {
423    fn skip(msg: &str) -> Self {
424        Self {
425            mse: 0.0,
426            status: ResonanceStatus::Skip,
427            message: Some(msg.to_string()),
428        }
429    }
430}
431
/// Classification attached to a resonance check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResonanceStatus {
    /// Reconstruction error is at or below the warning threshold.
    Healthy,
    /// Reconstruction error is above the warning threshold.
    Warning,
    /// Reconstruction error is above the compromised threshold.
    Compromised,
    /// The check could not be evaluated.
    Skip,
}
440
441impl NeuralResonanceModule {
442    /// Create new neural module
443    pub fn new(config: NeuralConfig) -> Result<Self, String> {
444        Ok(Self {
445            config,
446            encoders: HashMap::new(),
447            session_stats: dashmap::DashMap::new(),
448            anomaly_detector: AnomalyDetector::new(),
449            loaded: false,
450        })
451    }
452
453    /// Load model (marks as ready — no external model files needed)
454    pub fn load_model(&mut self) -> Result<(), String> {
455        self.loaded = true;
456        info!(
457            "Baked Mask Encoder ready (hidden={}, ~{}KB per mask)",
458            self.config.hidden_size,
459            (FEAT_DIM * self.config.hidden_size * 2 + self.config.hidden_size + FEAT_DIM) * 4
460                / 1024
461        );
462        Ok(())
463    }
464
465    /// Register mask — bakes its signature into a dedicated MLP encoder
466    pub fn register_mask(&mut self, mask: &MaskProfile) -> Result<(), String> {
467        if mask.signature_vector.len() < FEAT_DIM {
468            return Err(format!(
469                "Mask '{}' signature_vector too short: {} < {}",
470                mask.mask_id,
471                mask.signature_vector.len(),
472                FEAT_DIM
473            ));
474        }
475        let encoder =
476            BakedMaskEncoder::from_signature(&mask.signature_vector, self.config.hidden_size);
477        debug!(
478            "Baked encoder for mask '{}' ({} bytes)",
479            mask.mask_id,
480            encoder.memory_bytes()
481        );
482        self.encoders.insert(mask.mask_id.clone(), encoder);
483        Ok(())
484    }
485
486    /// Record traffic sample for session
487    pub fn record_traffic(
488        &self,
489        session_id: [u8; 16],
490        packet_size: u16,
491        iat_ms: f64,
492        entropy: f64,
493    ) {
494        if let Some(mut stats) = self.session_stats.get_mut(&session_id) {
495            stats.add_packet(packet_size, iat_ms, entropy);
496        } else {
497            let mut stats = TrafficStats::new();
498            stats.add_packet(packet_size, iat_ms, entropy);
499            self.session_stats.insert(session_id, stats);
500        }
501    }
502
503    /// Perform resonance check (Patent 1: Signal Reconstruction Resonance)
504    ///
505    /// Encodes live traffic into a 64-dim feature vector, passes it through
506    /// the mask's baked encoder, and computes reconstruction MSE.
507    pub fn check_resonance(
508        &self,
509        session_id: [u8; 16],
510        mask_id: &str,
511    ) -> Result<ResonanceResult, String> {
512        if !self.loaded {
513            return Ok(ResonanceResult::skip("Model not loaded"));
514        }
515
516        let Some(stats) = self.session_stats.get(&session_id) else {
517            return Ok(ResonanceResult::skip("No traffic stats"));
518        };
519
520        let Some(encoder) = self.encoders.get(mask_id) else {
521            return Ok(ResonanceResult::skip("Mask encoder not found"));
522        };
523
524        let features = encode_features(&stats);
525        let mse = encoder.reconstruction_error(&features);
526
527        let status = if mse > self.config.compromised_threshold {
528            ResonanceStatus::Compromised
529        } else if mse > self.config.warning_threshold {
530            ResonanceStatus::Warning
531        } else {
532            ResonanceStatus::Healthy
533        };
534
535        Ok(ResonanceResult {
536            mse,
537            status,
538            message: None,
539        })
540    }
541
542    /// Record telemetry for anomaly detection
543    pub fn record_telemetry(&mut self, mask_id: &str, packet_loss: f64, rtt_ms: f64) {
544        if self.config.enable_anomaly_detection {
545            self.anomaly_detector
546                .record_metrics(mask_id, packet_loss, rtt_ms);
547        }
548    }
549
550    /// Check if mask is anomalous (possible DPI blocking)
551    pub fn is_mask_anomalous(&self, mask_id: &str) -> bool {
552        self.anomaly_detector.is_anomalous(mask_id)
553    }
554
555    /// Get or create session stats
556    pub fn get_or_create_stats(&self, session_id: [u8; 16]) -> TrafficStats {
557        self.session_stats
558            .entry(session_id)
559            .or_insert_with(TrafficStats::new)
560            .clone()
561    }
562
563    /// Cleanup old session stats
564    pub fn cleanup_stats(&self, session_id: [u8; 16]) {
565        self.session_stats.remove(&session_id);
566    }
567
568    /// Total memory usage for all baked encoders
569    pub fn total_memory_bytes(&self) -> usize {
570        self.encoders.values().map(|e| e.memory_bytes()).sum()
571    }
572
573    /// Number of registered mask encoders
574    pub fn encoder_count(&self) -> usize {
575        self.encoders.len()
576    }
577}