//! Mask Store — Storage and Rating System for Auto-Generated Masks
//!
//! Stores `MaskProfile` + `MaskStats` pairs, automatically deactivating a mask
//! when its success rate drops below a threshold, and persists them to disk.

use std::path::{Path, PathBuf};
use std::sync::Arc;

use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use tracing::{error, info, warn};

use aivpn_common::error::Result;
use aivpn_common::mask::MaskProfile;

use crate::gateway::MaskCatalog;

/// Success-rate floor: a mask whose success rate falls below this value
/// becomes eligible for automatic deactivation.
const DEACTIVATION_THRESHOLD: f32 = 0.8;

/// Usage count that gates automatic deactivation: masks with fewer recorded
/// uses than this are never auto-deactivated, however many of them failed.
const MIN_USAGES_FOR_DEACTIVATION: u64 = 100;

24/// Mask statistics for rating system
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct MaskStats {
27    pub mask_id: String,
28    pub times_used: u64,
29    pub times_failed: u64,
30    pub success_rate: f32,
31    pub confidence: f32,
32    pub is_active: bool,
33    pub created_by: String,
34    pub created_at: u64,
35    pub last_used: Option<u64>,
36}
37
38/// Combined mask profile + statistics
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct MaskEntry {
41    pub profile: MaskProfile,
42    pub stats: MaskStats,
43}
44
45/// Mask store with rating system and disk persistence
46pub struct MaskStore {
47    /// All masks (mask_id → MaskEntry)
48    masks: DashMap<String, MaskEntry>,
49    /// Reference to the gateway's mask catalog for registration
50    catalog: Arc<MaskCatalog>,
51    /// Storage directory for mask files
52    storage_dir: PathBuf,
53}
54
55impl MaskStore {
56    /// Create a new mask store
57    pub fn new(catalog: Arc<MaskCatalog>, storage_dir: PathBuf) -> Self {
58        let store = Self {
59            masks: DashMap::new(),
60            catalog,
61            storage_dir,
62        };
63        // Load masks only from disk — no hardcoded presets
64        store.load_from_disk();
65        store
66    }
67
68    /// Add a new mask entry
69    pub fn add_mask(&self, entry: MaskEntry) -> Result<()> {
70        let mask_id = entry.stats.mask_id.clone();
71        info!(
72            "Storing mask '{}' (confidence: {:.2})",
73            mask_id, entry.stats.confidence
74        );
75
76        // Save to disk
77        self.save_to_disk(&mask_id, &entry);
78
79        // Register in catalog for neural resonance
80        self.catalog.register_mask(entry.profile.clone());
81
82        // Insert into in-memory store
83        self.masks.insert(mask_id, entry);
84        Ok(())
85    }
86
87    /// Register mask in the gateway catalog
88    pub fn register_in_catalog(&self, mask_id: &str) -> Result<()> {
89        if let Some(entry) = self.masks.get(mask_id) {
90            self.catalog.register_mask(entry.value().profile.clone());
91        }
92        Ok(())
93    }
94
95    /// Record successful usage of a mask
96    pub fn record_usage(&self, mask_id: &str) {
97        if let Some(mut entry) = self.masks.get_mut(mask_id) {
98            entry.stats.times_used += 1;
99            entry.stats.success_rate = if entry.stats.times_used > 0 {
100                1.0 - entry.stats.times_failed as f32 / entry.stats.times_used as f32
101            } else {
102                1.0
103            };
104            entry.stats.last_used = Some(current_unix_secs());
105            self.save_stats_to_disk(mask_id, &entry.stats);
106        }
107    }
108
109    /// Record a failure (DPI block detected)
110    pub fn record_failure(&self, mask_id: &str) {
111        if let Some(mut entry) = self.masks.get_mut(mask_id) {
112            entry.stats.times_used += 1;
113            entry.stats.times_failed += 1;
114            entry.stats.success_rate = if entry.stats.times_used > 0 {
115                1.0 - entry.stats.times_failed as f32 / entry.stats.times_used as f32
116            } else {
117                1.0
118            };
119
120            // Auto-deactivation check
121            if entry.stats.success_rate < DEACTIVATION_THRESHOLD
122                && entry.stats.times_used > MIN_USAGES_FOR_DEACTIVATION
123            {
124                entry.stats.is_active = false;
125                self.catalog.remove_mask(mask_id);
126                warn!(
127                    "Mask '{}' deactivated: success={:.1}% ({}/{} failures)",
128                    mask_id,
129                    entry.stats.success_rate * 100.0,
130                    entry.stats.times_failed,
131                    entry.stats.times_used
132                );
133            }
134            self.save_stats_to_disk(mask_id, &entry.stats);
135        }
136    }
137
138    /// List all masks with their stats
139    pub fn list_masks(&self) -> Vec<MaskEntry> {
140        self.masks.iter().map(|e| e.value().clone()).collect()
141    }
142
143    /// Get a specific mask entry
144    pub fn get_mask(&self, mask_id: &str) -> Option<MaskEntry> {
145        self.masks.get(mask_id).map(|e| e.value().clone())
146    }
147
148    /// Delete a mask
149    pub fn delete_mask(&self, mask_id: &str) {
150        self.masks.remove(mask_id);
151        self.catalog.remove_mask(mask_id);
152        // Remove disk files
153        let json_path = self.storage_dir.join(format!("{}.json", mask_id));
154        let stats_path = self.storage_dir.join(format!("{}.stats", mask_id));
155        let _ = std::fs::remove_file(&json_path);
156        let _ = std::fs::remove_file(&stats_path);
157        info!("Deleted mask '{}'", mask_id);
158    }
159
160    /// Broadcast mask update to all connected clients (placeholder)
161    pub async fn broadcast_mask_update(&self, mask_id: &str) -> Result<()> {
162        if let Some(entry) = self.masks.get(mask_id) {
163            // Serialize mask profile for distribution
164            let _profile_data = rmp_serde::to_vec(&entry.value().profile)
165                .map_err(|e| aivpn_common::error::Error::Serialization(e.to_string()))?;
166            // TODO: broadcast to all active sessions via ControlPayload::MaskUpdate
167            info!("Broadcast mask '{}' to all clients", mask_id);
168        }
169        Ok(())
170    }
171
172    fn save_stats_to_disk(&self, mask_id: &str, stats: &MaskStats) {
173        let _ = std::fs::create_dir_all(&self.storage_dir);
174        let stats_path = self.storage_dir.join(format!("{}.stats", mask_id));
175        match serde_json::to_string_pretty(stats) {
176            Ok(json) => {
177                if let Err(e) = std::fs::write(&stats_path, json) {
178                    error!("Failed to save mask stats {}: {}", mask_id, e);
179                }
180            }
181            Err(e) => error!("Failed to serialize mask stats {}: {}", mask_id, e),
182        }
183    }
184
185    /// Save mask entry to disk
186    fn save_to_disk(&self, mask_id: &str, entry: &MaskEntry) {
187        let _ = std::fs::create_dir_all(&self.storage_dir);
188
189        let json_path = self.storage_dir.join(format!("{}.json", mask_id));
190        match serde_json::to_string_pretty(&entry.profile) {
191            Ok(json) => {
192                if let Err(e) = std::fs::write(&json_path, json) {
193                    error!("Failed to save mask profile {}: {}", mask_id, e);
194                }
195            }
196            Err(e) => error!("Failed to serialize mask profile {}: {}", mask_id, e),
197        }
198
199        self.save_stats_to_disk(mask_id, &entry.stats);
200    }
201
202    /// Load masks from disk on startup
203    fn load_from_disk(&self) {
204        let dir = &self.storage_dir;
205        if !dir.exists() {
206            return;
207        }
208
209        let entries = match std::fs::read_dir(dir) {
210            Ok(e) => e,
211            Err(_) => return,
212        };
213
214        for entry in entries.flatten() {
215            let path = entry.path();
216            if path.extension().and_then(|e| e.to_str()) == Some("json") {
217                let mask_id = path
218                    .file_stem()
219                    .and_then(|s| s.to_str())
220                    .unwrap_or("")
221                    .to_string();
222
223                if mask_id.is_empty() {
224                    continue;
225                }
226
227                // Load profile
228                let profile: MaskProfile = match std::fs::read_to_string(&path)
229                    .ok()
230                    .and_then(|json| serde_json::from_str(&json).ok())
231                {
232                    Some(p) => p,
233                    None => continue,
234                };
235
236                // Load stats
237                let stats_path = dir.join(format!("{}.stats", mask_id));
238                let stats: MaskStats = std::fs::read_to_string(&stats_path)
239                    .ok()
240                    .and_then(|json| serde_json::from_str(&json).ok())
241                    .unwrap_or(MaskStats {
242                        mask_id: mask_id.clone(),
243                        times_used: 0,
244                        times_failed: 0,
245                        success_rate: 1.0,
246                        confidence: 0.0,
247                        is_active: true,
248                        created_by: "loaded".into(),
249                        created_at: 0,
250                        last_used: None,
251                    });
252
253                info!(
254                    "Loaded mask '{}' from disk (success: {:.1}%)",
255                    mask_id,
256                    stats.success_rate * 100.0
257                );
258
259                // Register only active masks in the live catalog
260                if stats.is_active {
261                    self.catalog.register_mask(profile.clone());
262                }
263
264                self.masks.insert(mask_id, MaskEntry { profile, stats });
265            }
266        }
267    }
268}
269
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before 1970, this yields `0` instead of
/// panicking.
fn current_unix_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}