use std::path::PathBuf;
use std::sync::Arc;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use tracing::{error, info, warn};
use aivpn_common::error::Result;
use aivpn_common::mask::MaskProfile;
use crate::gateway::MaskCatalog;
/// Success-rate floor: `MaskStore::record_failure` deactivates a mask whose
/// `success_rate` falls strictly below this value (0.80 == 80%).
const DEACTIVATION_THRESHOLD: f32 = 0.80;
/// A mask is only eligible for deactivation once its `times_used` counter is
/// strictly greater than this value, so a few early failures cannot disable it.
const MIN_USAGES_FOR_DEACTIVATION: u64 = 100;
/// Usage statistics and bookkeeping for one stored mask.
///
/// Persisted by `MaskStore` as pretty-printed JSON in `<mask_id>.stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MaskStats {
/// Identifier of the mask; `MaskStore::add_mask` uses it as the map key and
/// as the file stem of the persisted files.
pub mask_id: String,
/// Total recorded uses. Incremented by both `record_usage` and
/// `record_failure`, so failures are included in this count.
pub times_used: u64,
/// Number of recorded failures (incremented by `record_failure` only).
pub times_failed: u64,
/// `1.0 - times_failed / times_used`, recomputed on every recorded event;
/// `1.0` when the stats are synthesized for a mask without a stats file.
pub success_rate: f32,
/// Confidence score supplied by whoever created the mask; only logged by
/// this module. NOTE(review): scale/range is not visible here — confirm.
pub confidence: f32,
/// `false` once the mask has been deactivated by `record_failure`; inactive
/// masks are not registered in the catalog when loaded from disk.
pub is_active: bool,
/// Free-form origin tag; set to `"loaded"` when stats are synthesized for a
/// profile found on disk without a readable stats file.
pub created_by: String,
/// Creation timestamp; `0` for synthesized stats.
/// NOTE(review): presumably Unix seconds like `last_used` — confirm with producers.
pub created_at: u64,
/// Unix timestamp (seconds) of the last `record_usage` call, if any.
/// `record_failure` does not update it.
pub last_used: Option<u64>,
}
/// A stored mask: the profile itself together with its usage statistics.
///
/// On disk the two halves live in separate files: the profile in
/// `<mask_id>.json` and the stats in `<mask_id>.stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MaskEntry {
/// The mask profile that gets registered in the gateway catalog.
pub profile: MaskProfile,
/// Usage statistics and activation state for `profile`.
pub stats: MaskStats,
}
/// In-memory registry of masks, mirrored to files under `storage_dir` and
/// kept in sync with the gateway's `MaskCatalog`.
pub struct MaskStore {
// All known masks (active and deactivated), keyed by mask id.
masks: DashMap<String, MaskEntry>,
// Catalog that masks are registered in / removed from as they are added,
// loaded, deactivated or deleted.
catalog: Arc<MaskCatalog>,
// Directory holding `<mask_id>.json` (profile) and `<mask_id>.stats` files.
storage_dir: PathBuf,
}
impl MaskStore {
    /// Creates a store backed by `storage_dir` and immediately loads every
    /// mask already persisted there.
    ///
    /// Loading is best-effort: problems are logged and never fail
    /// construction. Active masks found on disk are registered in `catalog`.
    pub fn new(catalog: Arc<MaskCatalog>, storage_dir: PathBuf) -> Self {
        let store = Self {
            masks: DashMap::new(),
            catalog,
            storage_dir,
        };
        store.load_from_disk();
        store
    }

    /// Persists `entry`, registers its profile in the catalog and stores it
    /// in memory under `entry.stats.mask_id`, replacing any previous entry
    /// with the same id.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`; persistence failures are logged
    /// rather than propagated.
    pub fn add_mask(&self, entry: MaskEntry) -> Result<()> {
        let mask_id = entry.stats.mask_id.clone();
        info!(
            "Storing mask '{}' (confidence: {:.2})",
            mask_id, entry.stats.confidence
        );
        self.save_to_disk(&mask_id, &entry);
        self.catalog.register_mask(entry.profile.clone());
        self.masks.insert(mask_id, entry);
        Ok(())
    }

    /// Registers the profile of an already stored mask in the catalog.
    ///
    /// An unknown `mask_id` is not an error: nothing happens and `Ok(())`
    /// is returned. The mask's `is_active` flag is not consulted.
    pub fn register_in_catalog(&self, mask_id: &str) -> Result<()> {
        if let Some(entry) = self.masks.get(mask_id) {
            self.catalog.register_mask(entry.value().profile.clone());
        }
        Ok(())
    }

    /// Records one successful use of `mask_id`: bumps the usage counter,
    /// refreshes the success rate and `last_used`, and persists the stats.
    ///
    /// Unknown ids are ignored.
    pub fn record_usage(&self, mask_id: &str) {
        if let Some(mut entry) = self.masks.get_mut(mask_id) {
            entry.stats.times_used += 1;
            Self::refresh_success_rate(&mut entry.stats);
            entry.stats.last_used = Some(current_unix_secs());
            // Persist while the entry is still borrowed so the file matches
            // the in-memory state just computed.
            self.save_stats_to_disk(mask_id, &entry.stats);
        }
    }

    /// Records one failed use of `mask_id` and deactivates the mask when its
    /// success rate is below `DEACTIVATION_THRESHOLD` after more than
    /// `MIN_USAGES_FOR_DEACTIVATION` uses.
    ///
    /// A failure counts as a use too, so both counters are incremented.
    /// Unknown ids are ignored.
    pub fn record_failure(&self, mask_id: &str) {
        if let Some(mut entry) = self.masks.get_mut(mask_id) {
            entry.stats.times_used += 1;
            entry.stats.times_failed += 1;
            Self::refresh_success_rate(&mut entry.stats);

            if entry.stats.success_rate < DEACTIVATION_THRESHOLD
                && entry.stats.times_used > MIN_USAGES_FOR_DEACTIVATION
            {
                let newly_deactivated = entry.stats.is_active;
                entry.stats.is_active = false;
                // Always make sure the catalog no longer offers the mask,
                // even if it was re-registered after an earlier deactivation.
                self.catalog.remove_mask(mask_id);
                // Warn only on the active -> inactive transition; later
                // failures of an already deactivated mask would otherwise
                // repeat this warning on every call.
                if newly_deactivated {
                    warn!(
                        "Mask '{}' deactivated: success={:.1}% ({}/{} failures)",
                        mask_id,
                        entry.stats.success_rate * 100.0,
                        entry.stats.times_failed,
                        entry.stats.times_used
                    );
                }
            }
            self.save_stats_to_disk(mask_id, &entry.stats);
        }
    }

    /// Returns a snapshot (clones) of every stored mask, active or not.
    pub fn list_masks(&self) -> Vec<MaskEntry> {
        self.masks.iter().map(|e| e.value().clone()).collect()
    }

    /// Returns a clone of the entry stored under `mask_id`, if any.
    pub fn get_mask(&self, mask_id: &str) -> Option<MaskEntry> {
        self.masks.get(mask_id).map(|e| e.value().clone())
    }

    /// Removes `mask_id` from memory, from the catalog and from disk.
    ///
    /// File removal is best-effort: a missing file is fine, any other I/O
    /// error is logged as a warning instead of being silently dropped.
    pub fn delete_mask(&self, mask_id: &str) {
        self.masks.remove(mask_id);
        self.catalog.remove_mask(mask_id);
        for extension in &["json", "stats"] {
            let path = self.storage_dir.join(format!("{}.{}", mask_id, extension));
            match std::fs::remove_file(&path) {
                Ok(()) => {}
                // Nothing to delete (e.g. the mask was never persisted).
                Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
                Err(e) => warn!("Failed to remove {}: {}", path.display(), e),
            }
        }
        info!("Deleted mask '{}'", mask_id);
    }

    /// Serializes the profile of `mask_id` with MessagePack and logs a
    /// broadcast message.
    ///
    /// NOTE(review): the serialized bytes are not sent anywhere in this
    /// function; it currently only validates serialization and logs.
    ///
    /// # Errors
    ///
    /// Returns `Error::Serialization` if the profile cannot be encoded.
    /// Unknown ids are ignored and yield `Ok(())`.
    pub async fn broadcast_mask_update(&self, mask_id: &str) -> Result<()> {
        if let Some(entry) = self.masks.get(mask_id) {
            let _profile_data = rmp_serde::to_vec(&entry.value().profile)
                .map_err(|e| aivpn_common::error::Error::Serialization(e.to_string()))?;
            info!("Broadcast mask '{}' to all clients", mask_id);
        }
        Ok(())
    }

    /// Recomputes `success_rate` as `1 - times_failed / times_used`,
    /// falling back to `1.0` while nothing has been recorded yet.
    fn refresh_success_rate(stats: &mut MaskStats) {
        stats.success_rate = if stats.times_used > 0 {
            1.0 - stats.times_failed as f32 / stats.times_used as f32
        } else {
            1.0
        };
    }

    /// Writes `stats` as pretty JSON to `<storage_dir>/<mask_id>.stats`.
    /// Failures are logged, never returned.
    fn save_stats_to_disk(&self, mask_id: &str, stats: &MaskStats) {
        // If this fails, the write below fails as well and reports the error.
        let _ = std::fs::create_dir_all(&self.storage_dir);
        let stats_path = self.storage_dir.join(format!("{}.stats", mask_id));
        match serde_json::to_string_pretty(stats) {
            Ok(json) => {
                if let Err(e) = std::fs::write(&stats_path, json) {
                    error!("Failed to save mask stats {}: {}", mask_id, e);
                }
            }
            Err(e) => error!("Failed to serialize mask stats {}: {}", mask_id, e),
        }
    }

    /// Writes the profile to `<storage_dir>/<mask_id>.json` and the stats to
    /// the matching `.stats` file. Failures are logged, never returned.
    fn save_to_disk(&self, mask_id: &str, entry: &MaskEntry) {
        // If this fails, the write below fails as well and reports the error.
        let _ = std::fs::create_dir_all(&self.storage_dir);
        let json_path = self.storage_dir.join(format!("{}.json", mask_id));
        match serde_json::to_string_pretty(&entry.profile) {
            Ok(json) => {
                if let Err(e) = std::fs::write(&json_path, json) {
                    error!("Failed to save mask profile {}: {}", mask_id, e);
                }
            }
            Err(e) => error!("Failed to serialize mask profile {}: {}", mask_id, e),
        }
        self.save_stats_to_disk(mask_id, &entry.stats);
    }

    /// Loads every `<id>.json` profile found in the storage directory,
    /// pairs it with its stats and registers active masks in the catalog.
    ///
    /// A missing directory is not an error. Unreadable or unparseable files
    /// are skipped with a warning so one bad file cannot block the rest.
    fn load_from_disk(&self) {
        let dir = &self.storage_dir;
        if !dir.exists() {
            return;
        }
        let entries = match std::fs::read_dir(dir) {
            Ok(entries) => entries,
            Err(e) => {
                warn!("Failed to read mask storage dir {}: {}", dir.display(), e);
                return;
            }
        };
        for entry in entries.flatten() {
            let path = entry.path();
            if path.extension().and_then(|ext| ext.to_str()) != Some("json") {
                continue;
            }
            // The file stem is the mask id; skip names that are empty or not
            // valid UTF-8.
            let mask_id = match path.file_stem().and_then(|stem| stem.to_str()) {
                Some(stem) if !stem.is_empty() => stem.to_string(),
                _ => continue,
            };
            let profile = match Self::load_profile(&path) {
                Some(profile) => profile,
                None => continue,
            };
            let stats = self.load_stats(&mask_id);
            info!(
                "Loaded mask '{}' from disk (success: {:.1}%)",
                mask_id,
                stats.success_rate * 100.0
            );
            // Deactivated masks stay known to the store but are not offered
            // through the catalog.
            if stats.is_active {
                self.catalog.register_mask(profile.clone());
            }
            self.masks.insert(mask_id, MaskEntry { profile, stats });
        }
    }

    /// Reads and parses one profile file, logging a warning and returning
    /// `None` when it cannot be read or decoded.
    fn load_profile(path: &std::path::Path) -> Option<MaskProfile> {
        let json = match std::fs::read_to_string(path) {
            Ok(json) => json,
            Err(e) => {
                warn!("Failed to read mask profile {}: {}", path.display(), e);
                return None;
            }
        };
        match serde_json::from_str::<MaskProfile>(&json) {
            Ok(profile) => Some(profile),
            Err(e) => {
                warn!("Failed to parse mask profile {}: {}", path.display(), e);
                None
            }
        }
    }

    /// Loads `<mask_id>.stats`, falling back to fresh default stats when the
    /// file is missing, unreadable or corrupt.
    ///
    /// A missing file is expected and silent. Other failures are logged,
    /// because the fallback resets the counters and marks the mask active.
    fn load_stats(&self, mask_id: &str) -> MaskStats {
        let stats_path = self.storage_dir.join(format!("{}.stats", mask_id));
        let parsed = match std::fs::read_to_string(&stats_path) {
            Ok(json) => match serde_json::from_str::<MaskStats>(&json) {
                Ok(stats) => Some(stats),
                Err(e) => {
                    warn!("Failed to parse mask stats {}: {}", mask_id, e);
                    None
                }
            },
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => None,
            Err(e) => {
                warn!("Failed to read mask stats {}: {}", mask_id, e);
                None
            }
        };
        // Build the fallback lazily so the id is only cloned when needed.
        parsed.unwrap_or_else(|| MaskStats {
            mask_id: mask_id.to_string(),
            times_used: 0,
            times_failed: 0,
            success_rate: 1.0,
            confidence: 0.0,
            is_active: true,
            created_by: "loaded".into(),
            created_at: 0,
            last_used: None,
        })
    }
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, `0` is returned
/// instead of panicking.
fn current_unix_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        // Clock is set before 1970: treat as the epoch itself.
        Err(_) => 0,
    }
}