1use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use tracing::{debug, info};
22
23use aivpn_common::mask::MaskProfile;
24
/// Tuning parameters for `NeuralResonanceModule`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NeuralConfig {
    /// Width of the hidden layer of every per-mask `BakedMaskEncoder`
    /// (passed to `BakedMaskEncoder::from_signature` in `register_mask`).
    pub hidden_size: usize,

    /// Presumably the period, in seconds, between resonance checks.
    /// NOTE(review): not read anywhere in this module section — confirm the consumer.
    pub check_interval_secs: u64,

    /// Reconstruction MSE strictly above this value is reported as
    /// `ResonanceStatus::Compromised` by `check_resonance`.
    pub compromised_threshold: f32,

    /// Reconstruction MSE strictly above this value (and not above
    /// `compromised_threshold`) is reported as `ResonanceStatus::Warning`.
    pub warning_threshold: f32,

    /// When `false`, `record_telemetry` discards packet-loss / RTT samples, so
    /// the anomaly detector never accumulates data.
    pub enable_anomaly_detection: bool,
}
45
46impl Default for NeuralConfig {
47 fn default() -> Self {
48 Self {
49 hidden_size: 128,
50 check_interval_secs: 30,
51 compromised_threshold: 0.35,
52 warning_threshold: 0.15,
53 enable_anomaly_detection: true,
54 }
55 }
56}
57
/// Rolling window of per-packet samples for one session, plus rate figures,
/// from which `encode_features` builds a feature vector.
///
/// `add_packet` appends to the three sample buffers together and trims each of
/// them to the newest [`TrafficStats::MAX_SAMPLES`] entries.
#[derive(Debug, Clone, Default)]
pub struct TrafficStats {
    /// Packet sizes (presumably bytes), oldest first.
    pub packet_sizes: Vec<u16>,
    /// Inter-arrival times as passed to `add_packet` (`iat_ms`, milliseconds), oldest first.
    pub inter_arrivals: Vec<f64>,
    /// Per-packet entropy samples, oldest first.
    /// NOTE(review): `encode_features` scales these by 1/8, suggesting bits per byte — confirm.
    pub entropy_samples: Vec<f64>,
    /// Packets per second. Not maintained by this type; presumably set by the caller — TODO confirm.
    pub pps: f64,
    /// Throughput per second (bits vs. bytes not visible here). Not maintained by this type.
    pub bps: f64,
}

impl TrafficStats {
    /// Maximum number of samples retained in each rolling buffer.
    pub const MAX_SAMPLES: usize = 256;

    /// Creates empty stats with buffers pre-sized for a full window.
    pub fn new() -> Self {
        // `+ 1`: `add_packet` pushes before trimming, so a full window briefly
        // holds one extra element; reserving it avoids a reallocation.
        let cap = Self::MAX_SAMPLES + 1;
        Self {
            packet_sizes: Vec::with_capacity(cap),
            inter_arrivals: Vec::with_capacity(cap),
            entropy_samples: Vec::with_capacity(cap),
            pps: 0.0,
            bps: 0.0,
        }
    }

    /// Records one packet and drops the oldest samples beyond the window.
    pub fn add_packet(&mut self, size: u16, iat_ms: f64, entropy: f64) {
        self.packet_sizes.push(size);
        self.inter_arrivals.push(iat_ms);
        self.entropy_samples.push(entropy);
        // Trim each buffer on its own length: the fields are public, so they can
        // differ in length, and trimming all three by `packet_sizes.len()` would
        // leave an over-long buffer over the cap and shorten a short one.
        Self::keep_latest(&mut self.packet_sizes);
        Self::keep_latest(&mut self.inter_arrivals);
        Self::keep_latest(&mut self.entropy_samples);
    }

    /// Empties all sample buffers and zeroes the rate figures.
    pub fn clear(&mut self) {
        self.packet_sizes.clear();
        self.inter_arrivals.clear();
        self.entropy_samples.clear();
        self.pps = 0.0;
        self.bps = 0.0;
    }

    /// Removes the oldest entries of `buf` so at most `MAX_SAMPLES` remain.
    fn keep_latest<T>(buf: &mut Vec<T>) {
        let excess = buf.len().saturating_sub(Self::MAX_SAMPLES);
        if excess > 0 {
            buf.drain(..excess);
        }
    }
}
108
109const FEAT_DIM: usize = 64;
113
114pub struct BakedMaskEncoder {
125 w1: Vec<f32>, b1: Vec<f32>, w2: Vec<f32>, b2: Vec<f32>, hidden: usize,
130}
131
132impl BakedMaskEncoder {
133 pub fn from_signature(signature: &[f32], hidden: usize) -> Self {
135 assert!(
136 signature.len() >= FEAT_DIM,
137 "signature must have at least {} floats",
138 FEAT_DIM
139 );
140
141 let sig_bytes: Vec<u8> = signature.iter().flat_map(|f| f.to_le_bytes()).collect();
143 let seed = blake3::hash(&sig_bytes);
144 let seed_bytes = seed.as_bytes();
145
146 let mut w1 = vec![0.0f32; hidden * FEAT_DIM];
147 let mut b1 = vec![0.0f32; hidden];
148 let mut w2 = vec![0.0f32; FEAT_DIM * hidden];
149 let mut b2 = vec![0.0f32; FEAT_DIM];
150
151 let scale = (2.0 / (FEAT_DIM + hidden) as f32).sqrt();
153
154 for i in 0..hidden {
155 for j in 0..FEAT_DIM {
156 let idx = (i * FEAT_DIM + j) % 32;
157 let mix = (seed_bytes[idx] as f32 - 128.0) / 128.0;
158 w1[i * FEAT_DIM + j] = signature[j % FEAT_DIM] * mix * scale;
159 }
160 b1[i] = signature[i % FEAT_DIM] * 0.01;
161 }
162
163 for j in 0..FEAT_DIM {
164 for i in 0..hidden {
165 let idx = (j * hidden + i) % 32;
166 let mix = (seed_bytes[idx] as f32 - 128.0) / 128.0;
167 w2[j * hidden + i] = signature[j % FEAT_DIM] * mix * scale;
168 }
169 b2[j] = signature[j] * 0.01;
170 }
171
172 Self {
173 w1,
174 b1,
175 w2,
176 b2,
177 hidden,
178 }
179 }
180
181 pub fn forward(&self, input: &[f32; FEAT_DIM]) -> [f32; FEAT_DIM] {
183 let mut h = vec![0.0f32; self.hidden];
185 for i in 0..self.hidden {
186 let mut sum = self.b1[i];
187 let row = &self.w1[i * FEAT_DIM..(i + 1) * FEAT_DIM];
188 for j in 0..FEAT_DIM {
189 sum += row[j] * input[j];
190 }
191 h[i] = sum.max(0.0); }
193
194 let mut output = [0.0f32; FEAT_DIM];
196 for j in 0..FEAT_DIM {
197 let mut sum = self.b2[j];
198 let row = &self.w2[j * self.hidden..(j + 1) * self.hidden];
199 for i in 0..self.hidden {
200 sum += row[i] * h[i];
201 }
202 output[j] = sum;
203 }
204 output
205 }
206
207 pub fn reconstruction_error(&self, features: &[f32; FEAT_DIM]) -> f32 {
209 let recon = self.forward(features);
210 let mut mse = 0.0f32;
211 for i in 0..FEAT_DIM {
212 let diff = features[i] - recon[i];
213 mse += diff * diff;
214 }
215 mse / FEAT_DIM as f32
216 }
217
218 pub fn memory_bytes(&self) -> usize {
220 (self.w1.len() + self.b1.len() + self.w2.len() + self.b2.len()) * 4
221 }
222}
223
224pub fn encode_features(stats: &TrafficStats) -> [f32; FEAT_DIM] {
228 let mut features = [0.0f32; FEAT_DIM];
229
230 if !stats.packet_sizes.is_empty() {
232 let bins: [usize; 16] = [
233 0, 64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 896, 1024, 1280,
234 ];
235 for &size in &stats.packet_sizes {
236 for j in 0..15 {
237 if (size as usize) >= bins[j] && (size as usize) < bins[j + 1] {
238 features[j] += 1.0;
239 break;
240 }
241 }
242 }
243 let n = stats.packet_sizes.len() as f32;
244 for f in features[0..16].iter_mut() {
245 *f /= n;
246 }
247 }
248
249 if !stats.inter_arrivals.is_empty() {
251 let n = stats.inter_arrivals.len() as f64;
252 let mean = stats.inter_arrivals.iter().sum::<f64>() / n;
253 let variance = stats
254 .inter_arrivals
255 .iter()
256 .map(|&x| (x - mean).powi(2))
257 .sum::<f64>()
258 / n;
259 let std_dev = variance.sqrt();
260 let max_val = stats
261 .inter_arrivals
262 .iter()
263 .cloned()
264 .fold(f64::NEG_INFINITY, f64::max);
265 let min_val = stats
266 .inter_arrivals
267 .iter()
268 .cloned()
269 .fold(f64::INFINITY, f64::min);
270
271 features[16] = (mean / 100.0) as f32;
272 features[17] = (std_dev / 100.0) as f32;
273 features[18] = (max_val / 1000.0) as f32;
274 features[19] = (min_val / 1000.0) as f32;
275 let mut sorted = stats.inter_arrivals.clone();
277 sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
278 features[20] = (sorted[sorted.len() / 4] / 100.0) as f32;
279 features[21] = (sorted[sorted.len() / 2] / 100.0) as f32;
280 features[22] = (sorted[sorted.len() * 3 / 4] / 100.0) as f32;
281 features[23] = if mean > 0.0 {
282 (std_dev / mean) as f32
283 } else {
284 0.0
285 };
286 }
287
288 if !stats.entropy_samples.is_empty() {
290 let n = stats.entropy_samples.len() as f64;
291 let mean = stats.entropy_samples.iter().sum::<f64>() / n;
292 let variance = stats
293 .entropy_samples
294 .iter()
295 .map(|&x| (x - mean).powi(2))
296 .sum::<f64>()
297 / n;
298 features[32] = (mean / 8.0) as f32;
299 features[33] = (variance.sqrt() / 8.0) as f32;
300 let max_val = stats
301 .entropy_samples
302 .iter()
303 .cloned()
304 .fold(f64::NEG_INFINITY, f64::max);
305 let min_val = stats
306 .entropy_samples
307 .iter()
308 .cloned()
309 .fold(f64::INFINITY, f64::min);
310 features[34] = (max_val / 8.0) as f32;
311 features[35] = (min_val / 8.0) as f32;
312 }
313
314 features[48] = stats.pps as f32 / 1000.0;
316 features[49] = stats.bps as f32 / 1_000_000.0;
317 if !stats.packet_sizes.is_empty() {
318 let n = stats.packet_sizes.len() as f32;
319 let mean_size: f32 = stats.packet_sizes.iter().map(|&s| s as f32).sum::<f32>() / n;
320 features[50] = mean_size / 1500.0;
321 let var: f32 = stats
322 .packet_sizes
323 .iter()
324 .map(|&s| (s as f32 - mean_size).powi(2))
325 .sum::<f32>()
326 / n;
327 features[51] = var.sqrt() / 1500.0;
328 }
329
330 features
331}
332
/// Keeps a bounded history of packet-loss and RTT samples per mask and flags a
/// mask when either average rises well above a fixed baseline.
pub struct AnomalyDetector {
    /// Most recent packet-loss samples per mask id (at most 100 each).
    mask_packet_loss: HashMap<String, Vec<f64>>,
    /// Most recent RTT samples (ms) per mask id (at most 100 each).
    mask_rtt: HashMap<String, Vec<f64>>,
    /// Reference packet-loss value; an average above 5x this is anomalous.
    baseline_loss: f64,
    /// Reference RTT in ms; an average above 3x this is anomalous.
    baseline_rtt: f64,
}

impl AnomalyDetector {
    /// Samples retained per mask and metric.
    const HISTORY_LEN: usize = 100;
    /// Samples required before a metric's average is judged at all.
    const MIN_SAMPLES: usize = 10;
    /// Multiple of `baseline_loss` above which the loss average is anomalous.
    const LOSS_FACTOR: f64 = 5.0;
    /// Multiple of `baseline_rtt` above which the RTT average is anomalous.
    const RTT_FACTOR: f64 = 3.0;

    /// Creates a detector with no history, a 0.01 loss baseline and a 50 ms RTT baseline.
    pub fn new() -> Self {
        AnomalyDetector {
            baseline_rtt: 50.0,
            baseline_loss: 0.01,
            mask_rtt: HashMap::new(),
            mask_packet_loss: HashMap::new(),
        }
    }

    /// Appends one packet-loss and one RTT sample to `mask_id`'s history,
    /// discarding the oldest sample of each once the history is full.
    pub fn record_metrics(&mut self, mask_id: &str, packet_loss: f64, rtt_ms: f64) {
        let loss_history = self.mask_packet_loss.entry(mask_id.to_string()).or_default();
        Self::push_bounded(loss_history, packet_loss);

        let rtt_history = self.mask_rtt.entry(mask_id.to_string()).or_default();
        Self::push_bounded(rtt_history, rtt_ms);
    }

    /// Reports whether `mask_id`'s average loss or RTT exceeds its threshold.
    /// Metrics with fewer than 10 samples, and unknown masks, never trigger.
    pub fn is_anomalous(&self, mask_id: &str) -> bool {
        let loss_avg = Self::settled_mean(self.mask_packet_loss.get(mask_id).map(Vec::as_slice));
        let loss_high =
            matches!(loss_avg, Some(avg) if avg > self.baseline_loss * Self::LOSS_FACTOR);

        loss_high || {
            let rtt_avg = Self::settled_mean(self.mask_rtt.get(mask_id).map(Vec::as_slice));
            matches!(rtt_avg, Some(avg) if avg > self.baseline_rtt * Self::RTT_FACTOR)
        }
    }

    /// Pushes `sample` and drops the oldest entry if the history is over capacity.
    fn push_bounded(history: &mut Vec<f64>, sample: f64) {
        history.push(sample);
        if history.len() > Self::HISTORY_LEN {
            history.remove(0);
        }
    }

    /// Arithmetic mean of `history`, or `None` if it is absent or too short.
    fn settled_mean(history: Option<&[f64]>) -> Option<f64> {
        let samples = history.filter(|h| h.len() >= Self::MIN_SAMPLES)?;
        Some(samples.iter().sum::<f64>() / samples.len() as f64)
    }
}
390
/// Scores live session traffic against per-mask "baked" encoders and tracks
/// per-mask network telemetry for anomaly detection.
pub struct NeuralResonanceModule {
    /// Thresholds and model settings.
    config: NeuralConfig,

    /// One encoder per registered mask, keyed by `mask_id`.
    encoders: HashMap<String, BakedMaskEncoder>,

    /// Rolling traffic statistics per 16-byte session id. A concurrent map
    /// because `record_traffic` and the other session methods mutate it through `&self`.
    session_stats: dashmap::DashMap<[u8; 16], TrafficStats>,

    /// Packet-loss / RTT history consulted by `is_mask_anomalous`.
    anomaly_detector: AnomalyDetector,

    /// Set by `load_model`; until then `check_resonance` returns a `Skip` result.
    loaded: bool,
}
413
/// Outcome of a single resonance check.
#[derive(Debug, Clone)]
pub struct ResonanceResult {
    /// Mean squared reconstruction error (`0.0` for results built by `skip`).
    pub mse: f32,
    /// Classification of the session's traffic.
    pub status: ResonanceStatus,
    /// Optional human-readable explanation, e.g. why a check was skipped.
    pub message: Option<String>,
}

impl ResonanceResult {
    /// Builds a `Skip` result with zero error, carrying `msg` as its explanation.
    fn skip(msg: &str) -> Self {
        let message = Some(String::from(msg));
        Self {
            status: ResonanceStatus::Skip,
            mse: 0.0,
            message,
        }
    }
}

/// Classification produced by a resonance check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResonanceStatus {
    /// Error at or below the warning threshold.
    Healthy,
    /// Error above the warning threshold but not above the compromised one.
    Warning,
    /// Error above the compromised threshold.
    Compromised,
    /// No assessment was made (e.g. model not loaded, no stats, unknown mask).
    Skip,
}
440
441impl NeuralResonanceModule {
442 pub fn new(config: NeuralConfig) -> Result<Self, String> {
444 Ok(Self {
445 config,
446 encoders: HashMap::new(),
447 session_stats: dashmap::DashMap::new(),
448 anomaly_detector: AnomalyDetector::new(),
449 loaded: false,
450 })
451 }
452
453 pub fn load_model(&mut self) -> Result<(), String> {
455 self.loaded = true;
456 info!(
457 "Baked Mask Encoder ready (hidden={}, ~{}KB per mask)",
458 self.config.hidden_size,
459 (FEAT_DIM * self.config.hidden_size * 2 + self.config.hidden_size + FEAT_DIM) * 4
460 / 1024
461 );
462 Ok(())
463 }
464
465 pub fn register_mask(&mut self, mask: &MaskProfile) -> Result<(), String> {
467 if mask.signature_vector.len() < FEAT_DIM {
468 return Err(format!(
469 "Mask '{}' signature_vector too short: {} < {}",
470 mask.mask_id,
471 mask.signature_vector.len(),
472 FEAT_DIM
473 ));
474 }
475 let encoder =
476 BakedMaskEncoder::from_signature(&mask.signature_vector, self.config.hidden_size);
477 debug!(
478 "Baked encoder for mask '{}' ({} bytes)",
479 mask.mask_id,
480 encoder.memory_bytes()
481 );
482 self.encoders.insert(mask.mask_id.clone(), encoder);
483 Ok(())
484 }
485
486 pub fn record_traffic(
488 &self,
489 session_id: [u8; 16],
490 packet_size: u16,
491 iat_ms: f64,
492 entropy: f64,
493 ) {
494 if let Some(mut stats) = self.session_stats.get_mut(&session_id) {
495 stats.add_packet(packet_size, iat_ms, entropy);
496 } else {
497 let mut stats = TrafficStats::new();
498 stats.add_packet(packet_size, iat_ms, entropy);
499 self.session_stats.insert(session_id, stats);
500 }
501 }
502
503 pub fn check_resonance(
508 &self,
509 session_id: [u8; 16],
510 mask_id: &str,
511 ) -> Result<ResonanceResult, String> {
512 if !self.loaded {
513 return Ok(ResonanceResult::skip("Model not loaded"));
514 }
515
516 let Some(stats) = self.session_stats.get(&session_id) else {
517 return Ok(ResonanceResult::skip("No traffic stats"));
518 };
519
520 let Some(encoder) = self.encoders.get(mask_id) else {
521 return Ok(ResonanceResult::skip("Mask encoder not found"));
522 };
523
524 let features = encode_features(&stats);
525 let mse = encoder.reconstruction_error(&features);
526
527 let status = if mse > self.config.compromised_threshold {
528 ResonanceStatus::Compromised
529 } else if mse > self.config.warning_threshold {
530 ResonanceStatus::Warning
531 } else {
532 ResonanceStatus::Healthy
533 };
534
535 Ok(ResonanceResult {
536 mse,
537 status,
538 message: None,
539 })
540 }
541
542 pub fn record_telemetry(&mut self, mask_id: &str, packet_loss: f64, rtt_ms: f64) {
544 if self.config.enable_anomaly_detection {
545 self.anomaly_detector
546 .record_metrics(mask_id, packet_loss, rtt_ms);
547 }
548 }
549
550 pub fn is_mask_anomalous(&self, mask_id: &str) -> bool {
552 self.anomaly_detector.is_anomalous(mask_id)
553 }
554
555 pub fn get_or_create_stats(&self, session_id: [u8; 16]) -> TrafficStats {
557 self.session_stats
558 .entry(session_id)
559 .or_insert_with(TrafficStats::new)
560 .clone()
561 }
562
563 pub fn cleanup_stats(&self, session_id: [u8; 16]) {
565 self.session_stats.remove(&session_id);
566 }
567
568 pub fn total_memory_bytes(&self) -> usize {
570 self.encoders.values().map(|e| e.memory_bytes()).sum()
571 }
572
573 pub fn encoder_count(&self) -> usize {
575 self.encoders.len()
576 }
577}