use std::sync::Arc;
use aivpn_common::protocol::ControlPayload;
use aivpn_common::recording::*;
use aivpn_server::gateway::MaskCatalog;
use aivpn_server::mask_store::{MaskEntry, MaskStats, MaskStore};
use aivpn_server::recording::{RecordingManager, RecordingStopOutcome};
/// Builds `count` synthetic packets shaped roughly like a video-call flow.
///
/// * Direction: an index divisible by 3 gives `Uplink`, anything else gives
///   `Downlink`. This is deterministic, so about one third of packets are uplink.
/// * Size: with probability 0.3 a small packet in `80..150` bytes, otherwise a
///   large one in `800..1300` bytes.
/// * Inter-arrival time: ~20 ms (±5) for packets over 500 bytes, ~100 ms (±30)
///   otherwise.
/// * Timestamps accumulate from a base of 1e12 ns.
/// * Entropy: 7.2 ± 0.3.
/// * Header prefix: 16 bytes `C0 00 00 01 C0 C0 …`.
///
/// Sizes, timings and entropy come from `rand::thread_rng()`. Only the direction
/// pattern and the header are reproducible between runs.
fn generate_video_call_packets(count: usize) -> Vec<PacketMetadata> {
    use rand::Rng;

    const BASE_TIMESTAMP_NS: u64 = 1_000_000_000_000;

    let mut rng = rand::thread_rng();
    let mut clock_ns = BASE_TIMESTAMP_NS;

    (0..count)
        .map(|idx| {
            let direction = match idx % 3 {
                0 => Direction::Uplink,
                _ => Direction::Downlink,
            };

            let small_packet = rng.gen_bool(0.3);
            let size: u16 = if small_packet {
                rng.gen_range(80..150)
            } else {
                rng.gen_range(800..1300)
            };

            // Packets over 500 bytes are spaced ~20 ms apart; smaller ones ~100 ms.
            let iat_ms: f64 = if size > 500 {
                20.0 + rng.gen_range(-5.0..5.0)
            } else {
                100.0 + rng.gen_range(-30.0..30.0)
            };
            clock_ns += (iat_ms * 1_000_000.0) as u64;

            let entropy: f32 = 7.2 + rng.gen_range(-0.3..0.3);

            // Bytes 1..4 become 00 00 01; every other byte stays 0xC0.
            let mut header_prefix = vec![0xC0u8; 16];
            header_prefix[1..4].copy_from_slice(&[0x00, 0x00, 0x01]);

            PacketMetadata {
                direction,
                size,
                iat_ms,
                entropy,
                header_prefix,
                timestamp_ns: clock_ns,
            }
        })
        .collect()
}
/// A fresh `RecordingSession` starts empty and keeps its service name.
///
/// Recording 100 synthetic packets updates the packet count, the stored
/// packets, both direction counters and the mean entropy.
#[test]
fn test_recording_session_basic() {
    let mut session = RecordingSession::new([1u8; 16], "test_service".into(), "admin".into());
    assert_eq!(session.total_packets, 0);
    assert_eq!(session.service, "test_service");

    let packets = generate_video_call_packets(100);
    for pkt in packets.iter() {
        session.record(pkt.clone());
    }

    let stats = &session.running_stats;
    assert_eq!(session.total_packets, 100);
    assert_eq!(session.packets.len(), 100);
    assert!(stats.uplink_count > 0);
    assert!(stats.downlink_count > 0);
    assert!(stats.mean_entropy() > 6.0);

    println!(
        "✅ RecordingSession: {} packets, entropy={:.2}, uplink={}, downlink={}",
        session.total_packets,
        stats.mean_entropy(),
        stats.uplink_count,
        stats.downlink_count,
    );
}
/// Feeds 1000 more packets than `MAX_RECORDING_PACKETS` into a session.
///
/// The stored buffer must stay at the cap while `total_packets` still counts
/// every packet that was seen.
#[test]
fn test_recording_session_cap() {
    let fed = MAX_RECORDING_PACKETS + 1000;
    let mut session = RecordingSession::new([2u8; 16], "cap_test".into(), "admin".into());

    let packets = generate_video_call_packets(fed);
    for pkt in packets.iter() {
        session.record(pkt.clone());
    }

    let stored = session.packets.len();
    assert_eq!(stored, MAX_RECORDING_PACKETS);
    assert_eq!(session.total_packets, fed as u64);

    println!(
        "✅ RecordingSession cap: stored={}, total={}",
        stored, session.total_packets
    );
}
/// Checks that `RunningStats::update` counts every one of 1000 packets as either
/// uplink or downlink, and that the mean entropy stays above 6.5.
#[test]
fn test_running_stats_incremental() {
    let mut stats = RunningStats::default();
    let packets = generate_video_call_packets(1000);
    for p in &packets {
        stats.update(p);
    }

    // assert_eq! prints both operands on failure, unlike assert!(a == b).
    assert_eq!(stats.uplink_count + stats.downlink_count, 1000);
    // The generator always emits both directions, so neither counter may be zero.
    assert!(stats.uplink_count > 0);
    assert!(stats.downlink_count > 0);
    assert!(stats.mean_entropy() > 6.5);

    println!(
        "✅ RunningStats: up={}, down={}, entropy={:.2}",
        stats.uplink_count,
        stats.downlink_count,
        stats.mean_entropy()
    );
}
/// Checks that every recording-related `ControlPayload` variant survives an
/// `encode` → `decode` round trip with its fields intact.
#[test]
fn test_recording_control_roundtrip() {
    /// Encodes `payload` and immediately decodes the bytes, unwrapping both steps.
    fn roundtrip(payload: ControlPayload) -> ControlPayload {
        let wire = payload.encode().unwrap();
        ControlPayload::decode(&wire).unwrap()
    }

    let decoded = roundtrip(ControlPayload::RecordingStart {
        service: "yandex_telemost".into(),
    });
    if let ControlPayload::RecordingStart { service } = decoded {
        assert_eq!(service, "yandex_telemost");
    } else {
        panic!("Expected RecordingStart");
    }

    let session_id = [0xABu8; 16];

    let decoded = roundtrip(ControlPayload::RecordingAck {
        session_id,
        status: "started".into(),
    });
    if let ControlPayload::RecordingAck {
        session_id: sid,
        status,
    } = decoded
    {
        assert_eq!(sid, session_id);
        assert_eq!(status, "started");
    } else {
        panic!("Expected RecordingAck");
    }

    let decoded = roundtrip(ControlPayload::RecordingStop { session_id });
    if let ControlPayload::RecordingStop { session_id: sid } = decoded {
        assert_eq!(sid, session_id);
    } else {
        panic!("Expected RecordingStop");
    }

    let decoded = roundtrip(ControlPayload::RecordingComplete {
        service: "zoom".into(),
        mask_id: "auto_zoom_v1".into(),
        confidence: 0.87,
    });
    if let ControlPayload::RecordingComplete {
        service,
        mask_id,
        confidence,
    } = decoded
    {
        assert_eq!(service, "zoom");
        assert_eq!(mask_id, "auto_zoom_v1");
        // Confidence is a float, so compare with a tolerance.
        assert!((confidence - 0.87).abs() < 0.001);
    } else {
        panic!("Expected RecordingComplete");
    }

    let decoded = roundtrip(ControlPayload::RecordingFailed {
        reason: "Too few packets".into(),
    });
    if let ControlPayload::RecordingFailed { reason } = decoded {
        assert_eq!(reason, "Too few packets");
    } else {
        panic!("Expected RecordingFailed");
    }

    println!("✅ All 5 Recording ControlPayload variants encode/decode correctly");
}
/// Drives a `RecordingManager` through start → record → status → stop.
///
/// After 100 packets, `stop` is expected (per the assertion below) to return
/// `RecordingStopOutcome::Incomplete`. The mask-store directory is wiped before
/// and after the test, as the sibling pipeline tests do, so repeated runs start
/// from the same state and leave no artifacts behind.
#[test]
fn test_recording_manager_lifecycle() {
    let storage_dir = std::path::PathBuf::from("/tmp/aivpn-test-masks");
    // Remove leftovers from an earlier (possibly aborted) run before the store is built.
    let _ = std::fs::remove_dir_all(&storage_dir);

    let catalog = Arc::new(MaskCatalog::new());
    let store = Arc::new(MaskStore::new(catalog, storage_dir.clone()));
    let manager = RecordingManager::new(store);
    let session_id = [3u8; 16];

    assert!(!manager.is_recording(&session_id));
    manager.start(session_id, "test_service".into(), "admin".into());
    assert!(manager.is_recording(&session_id));

    let packets = generate_video_call_packets(100);
    for p in &packets {
        manager.record_packet(session_id, p.clone());
    }

    let status = manager
        .status(&session_id)
        .expect("an active recording should report its status");
    assert_eq!(status.total_packets, 100);
    assert_eq!(status.service, "test_service");
    println!(
        "✅ RecordingManager: service='{}', packets={}, up={}, down={}",
        status.service, status.total_packets, status.uplink_count, status.downlink_count,
    );

    let result = manager.stop(session_id);
    assert!(matches!(result, RecordingStopOutcome::Incomplete(_)));
    assert!(!manager.is_recording(&session_id));

    // Best-effort cleanup; a missing directory is not an error here.
    let _ = std::fs::remove_dir_all(&storage_dir);
    println!("✅ RecordingManager lifecycle works correctly");
}
/// Round-trips a mask through `MaskStore`: add, list, get, record a usage and a
/// failure (checking the derived success rate), then delete.
///
/// The storage directory is wiped before the store is created as well as at
/// the end. A previous run that panicked mid-test would otherwise leave state
/// behind.
#[test]
fn test_mask_store_crud() {
    let storage_dir = std::path::PathBuf::from("/tmp/aivpn-test-masks-crud");

    let catalog = Arc::new(MaskCatalog::new());
    let initial_count = catalog.available_count();

    // The store might read masks persisted by an earlier run (TODO confirm);
    // start from an empty directory so counts and usage stats are predictable.
    let _ = std::fs::remove_dir_all(&storage_dir);
    let store = MaskStore::new(catalog.clone(), storage_dir.clone());

    let mask_id = "test_mask_001".to_string();
    let mut profile = aivpn_common::mask::preset_masks::quic_https_v2();
    profile.mask_id = mask_id.clone();
    let entry = MaskEntry {
        profile,
        stats: MaskStats {
            mask_id: mask_id.clone(),
            times_used: 0,
            times_failed: 0,
            success_rate: 1.0,
            confidence: 0.85,
            is_active: true,
            created_by: "test".into(),
            created_at: 1000,
            last_used: None,
        },
    };

    store.add_mask(entry).unwrap();
    assert!(catalog.available_count() > initial_count);

    let masks = store.list_masks();
    assert!(!masks.is_empty());
    assert!(masks.iter().any(|m| m.stats.mask_id == mask_id));

    let got = store
        .get_mask(&mask_id)
        .expect("mask should be retrievable right after add_mask");
    assert_eq!(got.stats.confidence, 0.85);

    // One successful use: times_used goes to 1 and the success rate stays at 100%.
    store.record_usage(&mask_id);
    let got = store.get_mask(&mask_id).unwrap();
    assert_eq!(got.stats.times_used, 1);
    assert_eq!(got.stats.success_rate, 1.0);

    // A failure also counts as a use: 1 failure out of 2 uses gives 50%, and the
    // mask is still expected to be active.
    store.record_failure(&mask_id);
    let got = store.get_mask(&mask_id).unwrap();
    assert_eq!(got.stats.times_used, 2);
    assert_eq!(got.stats.times_failed, 1);
    assert!((got.stats.success_rate - 0.5).abs() < 0.01);
    assert!(got.stats.is_active);

    store.delete_mask(&mask_id);
    assert!(store.get_mask(&mask_id).is_none());

    let _ = std::fs::remove_dir_all(&storage_dir);
    println!("✅ MaskStore CRUD operations work correctly");
}
/// Runs `mask_gen::generate_and_store_mask` on 2000 synthetic packets and checks
/// the stored mask.
///
/// The mask must have:
/// * a non-empty id, header template and FSM state list;
/// * a positive confidence and the active flag set;
/// * a catalog entry;
/// * size/IAT distributions that can be sampled.
///
/// Diagnostics are printed on both the success and the failure path. The
/// storage directory is cleared before the run and again once the result is
/// known.
#[tokio::test]
async fn test_full_mask_generation_pipeline() {
    let catalog = Arc::new(MaskCatalog::new());
    let storage_dir = std::path::PathBuf::from("/tmp/aivpn-test-mask-gen");
    let _ = std::fs::remove_dir_all(&storage_dir);
    let store = Arc::new(MaskStore::new(Arc::clone(&catalog), storage_dir.clone()));

    let packets = generate_video_call_packets(2000);
    let uplink_total = packets
        .iter()
        .filter(|p| p.direction == Direction::Uplink)
        .count();
    let downlink_total = packets
        .iter()
        .filter(|p| p.direction == Direction::Downlink)
        .count();
    println!("📊 Test data: {} packets", packets.len());
    println!(" Uplink: {}, Downlink: {}", uplink_total, downlink_total);

    let result =
        aivpn_server::mask_gen::generate_and_store_mask("video_call_test", &packets, &store).await;

    match &result {
        Ok(mask_id) => {
            println!("✅ Mask generated: '{}'", mask_id);
            let entry = store
                .get_mask(mask_id)
                .expect("Mask should be in store after generation");
            let profile = &entry.profile;
            let stats = &entry.stats;

            println!(" mask_id: {}", profile.mask_id);
            println!(" spoof_protocol: {:?}", profile.spoof_protocol);
            println!(" header_template_len: {}", profile.header_template.len());
            println!(" fsm_states: {}", profile.fsm_states.len());
            println!(" size_dist_type: {:?}", profile.size_distribution.dist_type);
            println!(" iat_dist_type: {:?}", profile.iat_distribution.dist_type);
            println!(" confidence: {:.2}", stats.confidence);
            println!(" is_active: {}", stats.is_active);

            assert!(!profile.mask_id.is_empty());
            assert!(!profile.header_template.is_empty());
            assert!(!profile.fsm_states.is_empty());
            assert!(stats.confidence > 0.0);
            assert!(stats.is_active);
            assert!(catalog.available_count() >= 1);

            // The generated distributions must be usable for sampling.
            let mut rng = rand::thread_rng();
            let size = profile.size_distribution.sample(&mut rng);
            let iat = profile.iat_distribution.sample(&mut rng);
            println!(" sample size: {}", size);
            println!(" sample iat: {:.2}ms", iat);
            assert!(size > 0, "Sampled size should be positive");
            assert!(iat >= 0.0, "Sampled IAT should be non-negative");
        }
        Err(e) => {
            println!("❌ Mask generation failed: {}", e);
            println!(" Uplink packets: {} (need >= 100)", uplink_total);
            println!(" Total packets: {}", packets.len());
        }
    }

    let _ = std::fs::remove_dir_all(&storage_dir);
    assert!(
        result.is_ok(),
        "Full mask generation pipeline should succeed: {:?}",
        result.err()
    );
}
/// End-to-end flow, in three steps:
///
/// 1. Record 3000 packets through a `RecordingManager`.
/// 2. Stop the session. The stop outcome is only printed, not asserted.
/// 3. Feed the same packets to `mask_gen::generate_and_store_mask` and check
///    that the resulting mask is in the store.
#[tokio::test]
async fn test_end_to_end_recording() {
    let catalog = Arc::new(MaskCatalog::new());
    let storage_dir = std::path::PathBuf::from("/tmp/aivpn-test-e2e");
    let _ = std::fs::remove_dir_all(&storage_dir);
    let store = Arc::new(MaskStore::new(Arc::clone(&catalog), storage_dir.clone()));
    let manager = RecordingManager::new(Arc::clone(&store));

    let session_id = [0xE2u8; 16];
    manager.start(session_id, "e2e_test_service".into(), "admin".into());
    assert!(manager.is_recording(&session_id));

    let packets = generate_video_call_packets(3000);
    for pkt in packets.iter() {
        manager.record_packet(session_id, pkt.clone());
    }

    let status = manager.status(&session_id).unwrap();
    println!(
        "📊 E2E recording status: {} packets, {}s",
        status.total_packets, status.duration_secs
    );

    let stopped = manager.stop(session_id);
    assert!(!manager.is_recording(&session_id));
    let outcome = if matches!(stopped, RecordingStopOutcome::Incomplete(_)) {
        "Incomplete"
    } else {
        "Other"
    };
    println!(
        " Stop result: {} (expected Incomplete due to short duration)",
        outcome
    );

    let result =
        aivpn_server::mask_gen::generate_and_store_mask("e2e_test_service", &packets, &store).await;
    assert!(
        result.is_ok(),
        "E2E pipeline should succeed: {:?}",
        result.err()
    );
    let mask_id = result.unwrap();
    println!("✅ E2E mask generated: '{}'", mask_id);
    assert!(store.get_mask(&mask_id).is_some());

    let _ = std::fs::remove_dir_all(&storage_dir);
    println!("✅ End-to-End recording test passed");
}