// aivpn_server/client_db.rs

//! Client Database
//!
//! Manages registered VPN clients with pre-shared keys, static IPs,
//! and per-client statistics. Persisted to a JSON file, which is
//! reloaded automatically when it changes on disk.
5
6use std::net::Ipv4Addr;
7use std::path::{Path, PathBuf};
8
9use chrono::{DateTime, Utc};
10use parking_lot::{Mutex, RwLock};
11use rand::RngCore;
12use serde::{Deserialize, Serialize};
13use tracing::{info, warn};
14
15use aivpn_common::error::{Error, Result};
16use aivpn_common::network_config::VpnNetworkConfig;
17
18/// Client configuration and credentials
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ClientConfig {
21    /// Unique client ID (UUID-like hex string)
22    pub id: String,
23    /// Human-readable name
24    pub name: String,
25    /// Pre-shared key (32 bytes, base64-encoded in JSON).
26    /// SECURITY: never return `ClientConfig` directly from API handlers — use `ClientResponse`
27    /// instead, which explicitly excludes this field.
28    #[serde(with = "base64_bytes")]
29    pub psk: [u8; 32],
30    /// Assigned static VPN IP
31    pub vpn_ip: Ipv4Addr,
32    /// Whether client is enabled
33    pub enabled: bool,
34    /// Creation timestamp
35    pub created_at: DateTime<Utc>,
36    /// Traffic and connection statistics
37    pub stats: ClientStats,
38}
39
40/// Per-client traffic statistics
41#[derive(Debug, Clone, Default, Serialize, Deserialize)]
42pub struct ClientStats {
43    pub bytes_in: u64,
44    pub bytes_out: u64,
45    pub last_connected: Option<DateTime<Utc>>,
46    pub total_connections: u64,
47    pub last_handshake: Option<DateTime<Utc>>,
48}
49
50/// Persistent client database
51#[derive(Debug, Clone, Serialize, Deserialize)]
52struct ClientDbFile {
53    clients: Vec<ClientConfig>,
54    /// Next host offset within the configured VPN subnet to assign.
55    #[serde(default = "default_next_host_offset", alias = "next_octet")]
56    next_host_offset: u32,
57}
58
59fn default_next_host_offset() -> u32 {
60    2
61}
62
63impl Default for ClientDbFile {
64    fn default() -> Self {
65        Self {
66            clients: Vec::new(),
67            next_host_offset: default_next_host_offset(),
68        }
69    }
70}
71
72/// Thread-safe client database with file persistence
73pub struct ClientDatabase {
74    data: RwLock<ClientDbFile>,
75    file_path: PathBuf,
76    network_config: VpnNetworkConfig,
77    last_mtime: Mutex<Option<std::time::SystemTime>>,
78}
79
80impl ClientDatabase {
81    /// Load or create client database from file
82    pub fn load(file_path: &Path, network_config: VpnNetworkConfig) -> Result<Self> {
83        network_config.validate()?;
84        let data = if file_path.exists() {
85            let content = std::fs::read_to_string(file_path)
86                .map_err(|e| Error::Session(format!("Failed to read client DB: {}", e)))?;
87            serde_json::from_str(&content)
88                .map_err(|e| Error::Session(format!("Failed to parse client DB: {}", e)))?
89        } else {
90            ClientDbFile::default()
91        };
92
93        let last_mtime = Mutex::new(std::fs::metadata(file_path).and_then(|m| m.modified()).ok());
94
95        Ok(Self {
96            data: RwLock::new(data),
97            file_path: file_path.to_path_buf(),
98            network_config,
99            last_mtime,
100        })
101    }
102
103    /// Save database to file
104    pub fn save(&self) -> Result<()> {
105        let data = self.data.read();
106        let content = serde_json::to_string_pretty(&*data)
107            .map_err(|e| Error::Session(format!("Failed to serialize client DB: {}", e)))?;
108
109        // Write atomically via temp file
110        let tmp_path = self.file_path.with_extension("tmp");
111        std::fs::write(&tmp_path, &content)
112            .map_err(|e| Error::Session(format!("Failed to write client DB: {}", e)))?;
113        std::fs::rename(&tmp_path, &self.file_path)
114            .map_err(|e| Error::Session(format!("Failed to rename client DB: {}", e)))?;
115
116        // Refresh cached mtime so reload_if_changed ignores our own write
117        if let Ok(mtime) = std::fs::metadata(&self.file_path).and_then(|m| m.modified()) {
118            *self.last_mtime.lock() = Some(mtime);
119        }
120
121        Ok(())
122    }
123
124    /// Add a new client, returns the generated config
125    pub fn add_client(&self, name: &str) -> Result<ClientConfig> {
126        let mut data = self.data.write();
127
128        // Check name uniqueness
129        if data.clients.iter().any(|c| c.name == name) {
130            return Err(Error::Session(format!("Client '{}' already exists", name)));
131        }
132
133        // Allocate VPN IP
134        let vpn_ip = self.allocate_vpn_ip(&mut data)?;
135
136        // Generate random ID and PSK
137        let mut id_bytes = [0u8; 8];
138        let mut psk = [0u8; 32];
139        chacha20poly1305::aead::OsRng.fill_bytes(&mut id_bytes);
140        chacha20poly1305::aead::OsRng.fill_bytes(&mut psk);
141
142        let id = id_bytes
143            .iter()
144            .map(|b| format!("{:02x}", b))
145            .collect::<String>();
146
147        let client = ClientConfig {
148            id,
149            name: name.to_string(),
150            psk,
151            vpn_ip,
152            enabled: true,
153            created_at: Utc::now(),
154            stats: ClientStats::default(),
155        };
156
157        data.clients.push(client.clone());
158        drop(data);
159
160        self.save()?;
161        Ok(client)
162    }
163
164    pub fn network_config(&self) -> VpnNetworkConfig {
165        self.network_config
166    }
167
168    /// Remove a client by ID
169    pub fn remove_client(&self, client_id: &str) -> Result<()> {
170        let mut data = self.data.write();
171        let before = data.clients.len();
172        data.clients.retain(|c| c.id != client_id);
173        if data.clients.len() == before {
174            return Err(Error::Session(format!("Client '{}' not found", client_id)));
175        }
176        drop(data);
177        self.save()?;
178        Ok(())
179    }
180
181    /// Get all clients
182    pub fn list_clients(&self) -> Vec<ClientConfig> {
183        self.data.read().clients.clone()
184    }
185
186    /// Find client by PSK (used during handshake to identify the connecting client)
187    pub fn find_by_psk(&self, psk: &[u8; 32]) -> Option<ClientConfig> {
188        let data = self.data.read();
189        data.clients
190            .iter()
191            .find(|c| c.enabled && subtle::ConstantTimeEq::ct_eq(&c.psk[..], &psk[..]).into())
192            .cloned()
193    }
194
195    /// Find client by VPN IP
196    pub fn find_by_vpn_ip(&self, ip: &Ipv4Addr) -> Option<ClientConfig> {
197        let data = self.data.read();
198        data.clients.iter().find(|c| c.vpn_ip == *ip).cloned()
199    }
200
201    /// Find client by ID
202    pub fn find_by_id(&self, id: &str) -> Option<ClientConfig> {
203        let data = self.data.read();
204        data.clients.iter().find(|c| c.id == id).cloned()
205    }
206
207    /// Update client stats (called from gateway on traffic)
208    pub fn record_handshake(&self, client_id: &str) {
209        let mut data = self.data.write();
210        if let Some(client) = data.clients.iter_mut().find(|c| c.id == client_id) {
211            client.stats.total_connections += 1;
212            client.stats.last_handshake = Some(Utc::now());
213            client.stats.last_connected = Some(Utc::now());
214        }
215    }
216
217    /// Update traffic counters
218    pub fn record_traffic(&self, client_id: &str, bytes_in: u64, bytes_out: u64) {
219        let mut data = self.data.write();
220        if let Some(client) = data.clients.iter_mut().find(|c| c.id == client_id) {
221            client.stats.bytes_in += bytes_in;
222            client.stats.bytes_out += bytes_out;
223            client.stats.last_connected = Some(Utc::now());
224        }
225    }
226
227    /// Persist stats periodically (called from a background task)
228    pub fn flush_stats(&self) {
229        if let Err(e) = self.save() {
230            warn!("Failed to flush client stats: {}", e);
231        }
232    }
233
234    /// Reload client database from disk if the file has changed.
235    /// Preserves in-memory traffic stats for existing clients.
236    /// Returns true if the client configuration changed.
237    pub fn reload_if_changed(&self) -> bool {
238        let metadata = match std::fs::metadata(&self.file_path) {
239            Ok(m) => m,
240            Err(_) => return false,
241        };
242
243        let current_mtime = metadata.modified().ok();
244        {
245            let last = self.last_mtime.lock();
246            if *last == current_mtime {
247                return false;
248            }
249        }
250
251        match self.reload_from_disk() {
252            Ok(changed) => {
253                *self.last_mtime.lock() = current_mtime;
254                if changed {
255                    info!(
256                        "Client database reloaded from disk ({} clients)",
257                        self.list_clients().len()
258                    );
259                }
260                changed
261            }
262            Err(e) => {
263                warn!("Failed to reload client DB: {}", e);
264                false
265            }
266        }
267    }
268
269    /// Internal: reload from disk, merging with in-memory stats.
270    /// Returns Ok(true) if data changed, Ok(false) if unchanged.
271    fn reload_from_disk(&self) -> Result<bool> {
272        let content = std::fs::read_to_string(&self.file_path)
273            .map_err(|e| Error::Session(format!("Failed to read client DB for reload: {}", e)))?;
274        let new_data: ClientDbFile = serde_json::from_str(&content)
275            .map_err(|e| Error::Session(format!("Failed to parse client DB for reload: {}", e)))?;
276
277        let mut data = self.data.write();
278
279        // Check if anything actually changed in the client configuration.
280        // PSK must be part of the signature so secret rotation takes effect
281        // without requiring a full server restart.
282        let old_sig: std::collections::HashSet<(String, String, [u8; 32], Ipv4Addr, bool)> = data
283            .clients
284            .iter()
285            .map(|c| (c.id.clone(), c.name.clone(), c.psk, c.vpn_ip, c.enabled))
286            .collect();
287        let new_sig: std::collections::HashSet<(String, String, [u8; 32], Ipv4Addr, bool)> =
288            new_data
289                .clients
290                .iter()
291                .map(|c| (c.id.clone(), c.name.clone(), c.psk, c.vpn_ip, c.enabled))
292                .collect();
293        let changed = old_sig != new_sig;
294
295        if !changed {
296            return Ok(false);
297        }
298
299        // Build a map of existing stats by client ID
300        let mut stats_map: std::collections::HashMap<String, ClientStats> =
301            std::collections::HashMap::new();
302        for client in &data.clients {
303            stats_map.insert(client.id.clone(), client.stats.clone());
304        }
305
306        // Replace clients list, preserving stats for existing clients
307        data.clients = new_data
308            .clients
309            .into_iter()
310            .map(|mut c| {
311                if let Some(saved_stats) = stats_map.get(&c.id) {
312                    c.stats = saved_stats.clone();
313                }
314                c
315            })
316            .collect();
317        data.next_host_offset = new_data.next_host_offset;
318
319        Ok(true)
320    }
321
322    fn allocate_vpn_ip(&self, data: &mut ClientDbFile) -> Result<Ipv4Addr> {
323        let max_host_offset = self.network_config.max_host_offset();
324        if max_host_offset < 1 {
325            return Err(Error::Session(
326                "Configured VPN subnet has no usable host addresses".into(),
327            ));
328        }
329
330        let mut candidate_offset = if data.next_host_offset == 0 {
331            default_next_host_offset()
332        } else {
333            data.next_host_offset
334        };
335
336        for _ in 0..max_host_offset {
337            if let Some(candidate_ip) = self.network_config.ip_for_host_offset(candidate_offset) {
338                let already_used = data
339                    .clients
340                    .iter()
341                    .any(|client| client.vpn_ip == candidate_ip);
342                if candidate_ip != self.network_config.server_vpn_ip && !already_used {
343                    data.next_host_offset = if candidate_offset >= max_host_offset {
344                        1
345                    } else {
346                        candidate_offset + 1
347                    };
348                    return Ok(candidate_ip);
349                }
350            }
351
352            candidate_offset = if candidate_offset >= max_host_offset {
353                1
354            } else {
355                candidate_offset + 1
356            };
357        }
358
359        Err(Error::Session(
360            "No more VPN IPs available in configured subnet".into(),
361        ))
362    }
363}
364
365/// Custom serde module for [u8; 32] as base64
366mod base64_bytes {
367    use serde::{Deserialize, Deserializer, Serialize, Serializer};
368
369    pub fn serialize<S: Serializer>(
370        bytes: &[u8; 32],
371        serializer: S,
372    ) -> std::result::Result<S::Ok, S::Error> {
373        use base64::Engine;
374        let b64 = base64::engine::general_purpose::STANDARD.encode(bytes);
375        b64.serialize(serializer)
376    }
377
378    pub fn deserialize<'de, D: Deserializer<'de>>(
379        deserializer: D,
380    ) -> std::result::Result<[u8; 32], D::Error> {
381        use base64::Engine;
382        let s = String::deserialize(deserializer)?;
383        let bytes = base64::engine::general_purpose::STANDARD
384            .decode(&s)
385            .map_err(serde::de::Error::custom)?;
386        if bytes.len() != 32 {
387            return Err(serde::de::Error::custom(format!(
388                "PSK must be 32 bytes, got {}",
389                bytes.len()
390            )));
391        }
392        let mut arr = [0u8; 32];
393        arr.copy_from_slice(&bytes);
394        Ok(arr)
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    fn test_network_config() -> VpnNetworkConfig {
        VpnNetworkConfig {
            server_vpn_ip: Ipv4Addr::new(10, 99, 0, 1),
            prefix_len: 24,
            mtu: 1400,
        }
    }

    #[test]
    fn reload_if_changed_applies_psk_rotation() {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("clients.json");
        let db = ClientDatabase::load(&db_path, test_network_config()).unwrap();

        let client = db.add_client("alice").unwrap();
        let old_psk = client.psk;

        db.record_traffic(&client.id, 111, 222);

        let mut on_disk: ClientDbFile =
            serde_json::from_str(&std::fs::read_to_string(&db_path).unwrap()).unwrap();
        let new_psk = [0xAB; 32];
        on_disk.clients[0].psk = new_psk;

        let original_mtime = std::fs::metadata(&db_path).unwrap().modified().unwrap();
        let updated_json = serde_json::to_string_pretty(&on_disk).unwrap();
        let mut mtime_changed = false;
        for _ in 0..20 {
            std::fs::write(&db_path, &updated_json).unwrap();
            let new_mtime = std::fs::metadata(&db_path).unwrap().modified().unwrap();
            if new_mtime != original_mtime {
                mtime_changed = true;
                break;
            }
            std::thread::sleep(Duration::from_millis(60));
        }
        assert!(
            mtime_changed,
            "test setup failed to advance client DB mtime"
        );

        assert!(db.reload_if_changed(), "PSK rotation must trigger reload");
        assert!(
            db.find_by_psk(&old_psk).is_none(),
            "old PSK must stop authenticating after reload"
        );

        let reloaded = db
            .find_by_psk(&new_psk)
            .expect("new PSK must authenticate after reload");
        assert_eq!(reloaded.id, client.id);
        assert_eq!(reloaded.stats.bytes_in, 111);
        assert_eq!(reloaded.stats.bytes_out, 222);
    }

    #[test]
    fn add_client_rejects_duplicate_names() {
        let dir = tempfile::tempdir().unwrap();
        let db = ClientDatabase::load(&dir.path().join("clients.json"), test_network_config())
            .unwrap();

        db.add_client("bob").unwrap();
        assert!(db.add_client("bob").is_err(), "duplicate names must be rejected");
        assert_eq!(db.list_clients().len(), 1);
    }

    #[test]
    fn allocated_ips_are_unique_and_never_the_server_ip() {
        let dir = tempfile::tempdir().unwrap();
        let db = ClientDatabase::load(&dir.path().join("clients.json"), test_network_config())
            .unwrap();
        let server_ip = test_network_config().server_vpn_ip;

        let clients: Vec<ClientConfig> = ["a", "b", "c"]
            .iter()
            .map(|name| db.add_client(name).unwrap())
            .collect();

        for (i, client) in clients.iter().enumerate() {
            assert_ne!(client.vpn_ip, server_ip, "server IP must never be handed out");
            assert_eq!(db.find_by_vpn_ip(&client.vpn_ip).unwrap().id, client.id);
            for other in &clients[i + 1..] {
                assert_ne!(client.vpn_ip, other.vpn_ip, "client IPs must be unique");
            }
        }
    }

    #[test]
    fn remove_client_is_persisted_and_rejects_unknown_ids() {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("clients.json");
        let db = ClientDatabase::load(&db_path, test_network_config()).unwrap();

        let client = db.add_client("carol").unwrap();
        db.remove_client(&client.id).unwrap();
        assert!(db.find_by_id(&client.id).is_none());
        assert!(
            db.remove_client(&client.id).is_err(),
            "removing an unknown ID must fail"
        );

        let reopened = ClientDatabase::load(&db_path, test_network_config()).unwrap();
        assert!(reopened.list_clients().is_empty(), "removal must be persisted");
    }

    #[test]
    fn flushed_stats_survive_reload() {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("clients.json");
        let db = ClientDatabase::load(&db_path, test_network_config()).unwrap();

        let client = db.add_client("dave").unwrap();
        db.record_handshake(&client.id);
        db.record_traffic(&client.id, 5, 7);
        db.flush_stats();

        let reopened = ClientDatabase::load(&db_path, test_network_config()).unwrap();
        let restored = reopened
            .find_by_psk(&client.psk)
            .expect("persisted client must authenticate after reload");
        assert_eq!(restored.id, client.id);
        assert_eq!(restored.stats.bytes_in, 5);
        assert_eq!(restored.stats.bytes_out, 7);
        assert_eq!(restored.stats.total_connections, 1);
    }
}