1use std::net::Ipv4Addr;
7use std::path::{Path, PathBuf};
8
9use chrono::{DateTime, Utc};
10use parking_lot::{Mutex, RwLock};
11use rand::RngCore;
12use serde::{Deserialize, Serialize};
13use tracing::{info, warn};
14
15use aivpn_common::error::{Error, Result};
16use aivpn_common::network_config::VpnNetworkConfig;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ClientConfig {
21 pub id: String,
23 pub name: String,
25 #[serde(with = "base64_bytes")]
29 pub psk: [u8; 32],
30 pub vpn_ip: Ipv4Addr,
32 pub enabled: bool,
34 pub created_at: DateTime<Utc>,
36 pub stats: ClientStats,
38}
39
40#[derive(Debug, Clone, Default, Serialize, Deserialize)]
42pub struct ClientStats {
43 pub bytes_in: u64,
44 pub bytes_out: u64,
45 pub last_connected: Option<DateTime<Utc>>,
46 pub total_connections: u64,
47 pub last_handshake: Option<DateTime<Utc>>,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52struct ClientDbFile {
53 clients: Vec<ClientConfig>,
54 #[serde(default = "default_next_host_offset", alias = "next_octet")]
56 next_host_offset: u32,
57}
58
/// Host offset the allocator starts from when none is recorded.
///
/// NOTE(review): offset 1 is presumably reserved for the server address
/// (the tests use `x.x.x.1` as `server_vpn_ip`) — confirm against callers.
const fn default_next_host_offset() -> u32 {
    2
}
62
63impl Default for ClientDbFile {
64 fn default() -> Self {
65 Self {
66 clients: Vec::new(),
67 next_host_offset: default_next_host_offset(),
68 }
69 }
70}
71
72pub struct ClientDatabase {
74 data: RwLock<ClientDbFile>,
75 file_path: PathBuf,
76 network_config: VpnNetworkConfig,
77 last_mtime: Mutex<Option<std::time::SystemTime>>,
78}
79
80impl ClientDatabase {
81 pub fn load(file_path: &Path, network_config: VpnNetworkConfig) -> Result<Self> {
83 network_config.validate()?;
84 let data = if file_path.exists() {
85 let content = std::fs::read_to_string(file_path)
86 .map_err(|e| Error::Session(format!("Failed to read client DB: {}", e)))?;
87 serde_json::from_str(&content)
88 .map_err(|e| Error::Session(format!("Failed to parse client DB: {}", e)))?
89 } else {
90 ClientDbFile::default()
91 };
92
93 let last_mtime = Mutex::new(std::fs::metadata(file_path).and_then(|m| m.modified()).ok());
94
95 Ok(Self {
96 data: RwLock::new(data),
97 file_path: file_path.to_path_buf(),
98 network_config,
99 last_mtime,
100 })
101 }
102
103 pub fn save(&self) -> Result<()> {
105 let data = self.data.read();
106 let content = serde_json::to_string_pretty(&*data)
107 .map_err(|e| Error::Session(format!("Failed to serialize client DB: {}", e)))?;
108
109 let tmp_path = self.file_path.with_extension("tmp");
111 std::fs::write(&tmp_path, &content)
112 .map_err(|e| Error::Session(format!("Failed to write client DB: {}", e)))?;
113 std::fs::rename(&tmp_path, &self.file_path)
114 .map_err(|e| Error::Session(format!("Failed to rename client DB: {}", e)))?;
115
116 if let Ok(mtime) = std::fs::metadata(&self.file_path).and_then(|m| m.modified()) {
118 *self.last_mtime.lock() = Some(mtime);
119 }
120
121 Ok(())
122 }
123
124 pub fn add_client(&self, name: &str) -> Result<ClientConfig> {
126 let mut data = self.data.write();
127
128 if data.clients.iter().any(|c| c.name == name) {
130 return Err(Error::Session(format!("Client '{}' already exists", name)));
131 }
132
133 let vpn_ip = self.allocate_vpn_ip(&mut data)?;
135
136 let mut id_bytes = [0u8; 8];
138 let mut psk = [0u8; 32];
139 chacha20poly1305::aead::OsRng.fill_bytes(&mut id_bytes);
140 chacha20poly1305::aead::OsRng.fill_bytes(&mut psk);
141
142 let id = id_bytes
143 .iter()
144 .map(|b| format!("{:02x}", b))
145 .collect::<String>();
146
147 let client = ClientConfig {
148 id,
149 name: name.to_string(),
150 psk,
151 vpn_ip,
152 enabled: true,
153 created_at: Utc::now(),
154 stats: ClientStats::default(),
155 };
156
157 data.clients.push(client.clone());
158 drop(data);
159
160 self.save()?;
161 Ok(client)
162 }
163
164 pub fn network_config(&self) -> VpnNetworkConfig {
165 self.network_config
166 }
167
168 pub fn remove_client(&self, client_id: &str) -> Result<()> {
170 let mut data = self.data.write();
171 let before = data.clients.len();
172 data.clients.retain(|c| c.id != client_id);
173 if data.clients.len() == before {
174 return Err(Error::Session(format!("Client '{}' not found", client_id)));
175 }
176 drop(data);
177 self.save()?;
178 Ok(())
179 }
180
181 pub fn list_clients(&self) -> Vec<ClientConfig> {
183 self.data.read().clients.clone()
184 }
185
186 pub fn find_by_psk(&self, psk: &[u8; 32]) -> Option<ClientConfig> {
188 let data = self.data.read();
189 data.clients
190 .iter()
191 .find(|c| c.enabled && subtle::ConstantTimeEq::ct_eq(&c.psk[..], &psk[..]).into())
192 .cloned()
193 }
194
195 pub fn find_by_vpn_ip(&self, ip: &Ipv4Addr) -> Option<ClientConfig> {
197 let data = self.data.read();
198 data.clients.iter().find(|c| c.vpn_ip == *ip).cloned()
199 }
200
201 pub fn find_by_id(&self, id: &str) -> Option<ClientConfig> {
203 let data = self.data.read();
204 data.clients.iter().find(|c| c.id == id).cloned()
205 }
206
207 pub fn record_handshake(&self, client_id: &str) {
209 let mut data = self.data.write();
210 if let Some(client) = data.clients.iter_mut().find(|c| c.id == client_id) {
211 client.stats.total_connections += 1;
212 client.stats.last_handshake = Some(Utc::now());
213 client.stats.last_connected = Some(Utc::now());
214 }
215 }
216
217 pub fn record_traffic(&self, client_id: &str, bytes_in: u64, bytes_out: u64) {
219 let mut data = self.data.write();
220 if let Some(client) = data.clients.iter_mut().find(|c| c.id == client_id) {
221 client.stats.bytes_in += bytes_in;
222 client.stats.bytes_out += bytes_out;
223 client.stats.last_connected = Some(Utc::now());
224 }
225 }
226
227 pub fn flush_stats(&self) {
229 if let Err(e) = self.save() {
230 warn!("Failed to flush client stats: {}", e);
231 }
232 }
233
234 pub fn reload_if_changed(&self) -> bool {
238 let metadata = match std::fs::metadata(&self.file_path) {
239 Ok(m) => m,
240 Err(_) => return false,
241 };
242
243 let current_mtime = metadata.modified().ok();
244 {
245 let last = self.last_mtime.lock();
246 if *last == current_mtime {
247 return false;
248 }
249 }
250
251 match self.reload_from_disk() {
252 Ok(changed) => {
253 *self.last_mtime.lock() = current_mtime;
254 if changed {
255 info!(
256 "Client database reloaded from disk ({} clients)",
257 self.list_clients().len()
258 );
259 }
260 changed
261 }
262 Err(e) => {
263 warn!("Failed to reload client DB: {}", e);
264 false
265 }
266 }
267 }
268
269 fn reload_from_disk(&self) -> Result<bool> {
272 let content = std::fs::read_to_string(&self.file_path)
273 .map_err(|e| Error::Session(format!("Failed to read client DB for reload: {}", e)))?;
274 let new_data: ClientDbFile = serde_json::from_str(&content)
275 .map_err(|e| Error::Session(format!("Failed to parse client DB for reload: {}", e)))?;
276
277 let mut data = self.data.write();
278
279 let old_sig: std::collections::HashSet<(String, String, [u8; 32], Ipv4Addr, bool)> = data
283 .clients
284 .iter()
285 .map(|c| (c.id.clone(), c.name.clone(), c.psk, c.vpn_ip, c.enabled))
286 .collect();
287 let new_sig: std::collections::HashSet<(String, String, [u8; 32], Ipv4Addr, bool)> =
288 new_data
289 .clients
290 .iter()
291 .map(|c| (c.id.clone(), c.name.clone(), c.psk, c.vpn_ip, c.enabled))
292 .collect();
293 let changed = old_sig != new_sig;
294
295 if !changed {
296 return Ok(false);
297 }
298
299 let mut stats_map: std::collections::HashMap<String, ClientStats> =
301 std::collections::HashMap::new();
302 for client in &data.clients {
303 stats_map.insert(client.id.clone(), client.stats.clone());
304 }
305
306 data.clients = new_data
308 .clients
309 .into_iter()
310 .map(|mut c| {
311 if let Some(saved_stats) = stats_map.get(&c.id) {
312 c.stats = saved_stats.clone();
313 }
314 c
315 })
316 .collect();
317 data.next_host_offset = new_data.next_host_offset;
318
319 Ok(true)
320 }
321
322 fn allocate_vpn_ip(&self, data: &mut ClientDbFile) -> Result<Ipv4Addr> {
323 let max_host_offset = self.network_config.max_host_offset();
324 if max_host_offset < 1 {
325 return Err(Error::Session(
326 "Configured VPN subnet has no usable host addresses".into(),
327 ));
328 }
329
330 let mut candidate_offset = if data.next_host_offset == 0 {
331 default_next_host_offset()
332 } else {
333 data.next_host_offset
334 };
335
336 for _ in 0..max_host_offset {
337 if let Some(candidate_ip) = self.network_config.ip_for_host_offset(candidate_offset) {
338 let already_used = data
339 .clients
340 .iter()
341 .any(|client| client.vpn_ip == candidate_ip);
342 if candidate_ip != self.network_config.server_vpn_ip && !already_used {
343 data.next_host_offset = if candidate_offset >= max_host_offset {
344 1
345 } else {
346 candidate_offset + 1
347 };
348 return Ok(candidate_ip);
349 }
350 }
351
352 candidate_offset = if candidate_offset >= max_host_offset {
353 1
354 } else {
355 candidate_offset + 1
356 };
357 }
358
359 Err(Error::Session(
360 "No more VPN IPs available in configured subnet".into(),
361 ))
362 }
363}
364
365mod base64_bytes {
367 use serde::{Deserialize, Deserializer, Serialize, Serializer};
368
369 pub fn serialize<S: Serializer>(
370 bytes: &[u8; 32],
371 serializer: S,
372 ) -> std::result::Result<S::Ok, S::Error> {
373 use base64::Engine;
374 let b64 = base64::engine::general_purpose::STANDARD.encode(bytes);
375 b64.serialize(serializer)
376 }
377
378 pub fn deserialize<'de, D: Deserializer<'de>>(
379 deserializer: D,
380 ) -> std::result::Result<[u8; 32], D::Error> {
381 use base64::Engine;
382 let s = String::deserialize(deserializer)?;
383 let bytes = base64::engine::general_purpose::STANDARD
384 .decode(&s)
385 .map_err(serde::de::Error::custom)?;
386 if bytes.len() != 32 {
387 return Err(serde::de::Error::custom(format!(
388 "PSK must be 32 bytes, got {}",
389 bytes.len()
390 )));
391 }
392 let mut arr = [0u8; 32];
393 arr.copy_from_slice(&bytes);
394 Ok(arr)
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// /24 VPN network with the server at 10.99.0.1 and MTU 1400.
    fn test_network_config() -> VpnNetworkConfig {
        VpnNetworkConfig {
            server_vpn_ip: Ipv4Addr::new(10, 99, 0, 1),
            prefix_len: 24,
            mtu: 1400,
        }
    }

    /// Rotating a client's PSK by editing the JSON file directly must be
    /// picked up by `reload_if_changed`: the old key stops authenticating,
    /// the new key works for the same client id, and the in-memory traffic
    /// counters survive the reload.
    #[test]
    fn reload_if_changed_applies_psk_rotation() {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("clients.json");
        let db = ClientDatabase::load(&db_path, test_network_config()).unwrap();

        // add_client saves, so the file now exists with one client.
        let client = db.add_client("alice").unwrap();
        let old_psk = client.psk;

        // record_traffic does not save: these counters exist only in memory
        // and must be carried over by the reload.
        db.record_traffic(&client.id, 111, 222);

        // Simulate an external edit: rotate the PSK of the only client on disk.
        let mut on_disk: ClientDbFile =
            serde_json::from_str(&std::fs::read_to_string(&db_path).unwrap()).unwrap();
        let new_psk = [0xAB; 32];
        on_disk.clients[0].psk = new_psk;

        // reload_if_changed only reacts when the file's mtime differs from the
        // recorded one. Rewrite (sleeping between attempts) until the timestamp
        // moves — presumably to tolerate coarse filesystem mtime resolution.
        let original_mtime = std::fs::metadata(&db_path).unwrap().modified().unwrap();
        let updated_json = serde_json::to_string_pretty(&on_disk).unwrap();
        let mut mtime_changed = false;
        for _ in 0..20 {
            std::fs::write(&db_path, &updated_json).unwrap();
            let new_mtime = std::fs::metadata(&db_path).unwrap().modified().unwrap();
            if new_mtime != original_mtime {
                mtime_changed = true;
                break;
            }
            std::thread::sleep(Duration::from_millis(60));
        }
        assert!(
            mtime_changed,
            "test setup failed to advance client DB mtime"
        );

        assert!(db.reload_if_changed(), "PSK rotation must trigger reload");
        assert!(
            db.find_by_psk(&old_psk).is_none(),
            "old PSK must stop authenticating after reload"
        );

        let reloaded = db
            .find_by_psk(&new_psk)
            .expect("new PSK must authenticate after reload");
        assert_eq!(reloaded.id, client.id);
        // Stats come from memory, not from the (zeroed) on-disk copy.
        assert_eq!(reloaded.stats.bytes_in, 111);
        assert_eq!(reloaded.stats.bytes_out, 222);
    }
}