use std::net::Ipv4Addr;
use std::path::{Path, PathBuf};
use chrono::{DateTime, Utc};
use parking_lot::{Mutex, RwLock};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
use aivpn_common::error::{Error, Result};
use aivpn_common::network_config::VpnNetworkConfig;
/// A single registered VPN client as stored in the client database.
///
/// `Debug` is implemented by hand (instead of derived) so that the pre-shared
/// key never ends up in log output produced with `{:?}`.
#[derive(Clone, Serialize, Deserialize)]
pub struct ClientConfig {
    /// Identifier generated by `ClientDatabase::add_client` (8 random bytes
    /// rendered as lowercase hex).
    pub id: String,
    /// Name supplied by the caller of `add_client`; `add_client` rejects duplicates.
    pub name: String,
    /// 32-byte pre-shared key, (de)serialized as a base64 string.
    #[serde(with = "base64_bytes")]
    pub psk: [u8; 32],
    /// Address assigned to this client inside the VPN subnet.
    pub vpn_ip: Ipv4Addr,
    /// Disabled clients are skipped by `ClientDatabase::find_by_psk`.
    pub enabled: bool,
    /// Time at which the client record was created.
    pub created_at: DateTime<Utc>,
    /// Usage counters and timestamps for this client.
    pub stats: ClientStats,
}

impl std::fmt::Debug for ClientConfig {
    /// Formats every field except `psk`, which is replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientConfig")
            .field("id", &self.id)
            .field("name", &self.name)
            // Secret key material must not leak through diagnostics.
            .field("psk", &"<redacted>")
            .field("vpn_ip", &self.vpn_ip)
            .field("enabled", &self.enabled)
            .field("created_at", &self.created_at)
            .field("stats", &self.stats)
            .finish()
    }
}
/// Usage counters and timestamps tracked for one client.
///
/// The `Default` value (zero counters, no timestamps) is what `add_client`
/// assigns to a freshly created client.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ClientStats {
/// Running total of the `bytes_in` values passed to `record_traffic`.
pub bytes_in: u64,
/// Running total of the `bytes_out` values passed to `record_traffic`.
pub bytes_out: u64,
/// Time of the most recent `record_handshake` or `record_traffic` call.
pub last_connected: Option<DateTime<Utc>>,
/// Number of `record_handshake` calls seen for this client.
pub total_connections: u64,
/// Time of the most recent `record_handshake` call.
pub last_handshake: Option<DateTime<Utc>>,
}
/// Serialized shape of the client database; `ClientDatabase` reads and writes
/// this structure as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ClientDbFile {
/// All client records held by the database.
clients: Vec<ClientConfig>,
// A missing key falls back to `default_next_host_offset()`; the key name
// `next_octet` is accepted as well (presumably a legacy name — confirm).
#[serde(default = "default_next_host_offset", alias = "next_octet")]
/// Host offset at which `allocate_vpn_ip` starts its next search.
next_host_offset: u32,
}
/// Host offset used when the database file does not carry one.
///
/// NOTE(review): presumably 2 because offset 1 is expected to be the server's
/// own address — confirm against `VpnNetworkConfig`.
fn default_next_host_offset() -> u32 {
    const INITIAL_HOST_OFFSET: u32 = 2;
    INITIAL_HOST_OFFSET
}
impl Default for ClientDbFile {
fn default() -> Self {
Self {
clients: Vec::new(),
next_host_offset: default_next_host_offset(),
}
}
}
/// In-memory client registry backed by a JSON file on disk.
pub struct ClientDatabase {
/// Current database contents, guarded by a reader/writer lock.
data: RwLock<ClientDbFile>,
/// Location of the JSON file that `load`, `save` and the reload path use.
file_path: PathBuf,
/// Network settings consulted when allocating client addresses.
network_config: VpnNetworkConfig,
/// Modification time of the file as last recorded by `load`, `save` or
/// `reload_if_changed`; compared against the file's current mtime to detect
/// changes. `None` when no mtime could be read.
last_mtime: Mutex<Option<std::time::SystemTime>>,
}
impl ClientDatabase {
pub fn load(file_path: &Path, network_config: VpnNetworkConfig) -> Result<Self> {
network_config.validate()?;
let data = if file_path.exists() {
let content = std::fs::read_to_string(file_path)
.map_err(|e| Error::Session(format!("Failed to read client DB: {}", e)))?;
serde_json::from_str(&content)
.map_err(|e| Error::Session(format!("Failed to parse client DB: {}", e)))?
} else {
ClientDbFile::default()
};
let last_mtime = Mutex::new(std::fs::metadata(file_path).and_then(|m| m.modified()).ok());
Ok(Self {
data: RwLock::new(data),
file_path: file_path.to_path_buf(),
network_config,
last_mtime,
})
}
/// Writes the current database to disk.
///
/// The JSON is written to a sibling file with the `tmp` extension and then
/// renamed over the real file, so readers never observe a partial write.
///
/// # Errors
/// Returns an error when serialization, the temporary write, or the rename fails.
pub fn save(&self) -> Result<()> {
    // Holding this mutex for the whole operation serializes concurrent saves:
    // they share one temporary path, and without ordering an older snapshot
    // could be renamed over a newer one.
    let mut last_mtime = self.last_mtime.lock();

    // Take the snapshot under the read lock, then release it so that stats
    // updates are not blocked while we perform disk I/O.
    let data = self.data.read();
    let serialized = serde_json::to_string_pretty(&*data);
    drop(data);
    let content = serialized
        .map_err(|e| Error::Session(format!("Failed to serialize client DB: {}", e)))?;

    let tmp_path = self.file_path.with_extension("tmp");
    std::fs::write(&tmp_path, &content)
        .map_err(|e| Error::Session(format!("Failed to write client DB: {}", e)))?;
    // Record the mtime of the file we just wrote, before the rename. Sampling
    // the destination afterwards could pick up a file that was replaced
    // externally in the meantime and hide that change from `reload_if_changed`.
    let written_mtime = std::fs::metadata(&tmp_path).and_then(|m| m.modified());
    std::fs::rename(&tmp_path, &self.file_path)
        .map_err(|e| Error::Session(format!("Failed to rename client DB: {}", e)))?;
    if let Ok(mtime) = written_mtime {
        *last_mtime = Some(mtime);
    }
    Ok(())
}
/// Creates a client named `name` with a fresh random id, a random 32-byte PSK
/// and a newly allocated VPN address, then persists the database.
///
/// # Errors
/// Returns an error when a client with the same name already exists, when no
/// address can be allocated, or when saving fails. If saving fails the new
/// client is removed from memory again, so the database stays consistent with
/// what the caller was told and the call can be retried.
pub fn add_client(&self, name: &str) -> Result<ClientConfig> {
    let client = {
        let mut data = self.data.write();
        if data.clients.iter().any(|c| c.name == name) {
            return Err(Error::Session(format!("Client '{}' already exists", name)));
        }
        let vpn_ip = self.allocate_vpn_ip(&mut data)?;
        let mut id_bytes = [0u8; 8];
        let mut psk = [0u8; 32];
        chacha20poly1305::aead::OsRng.fill_bytes(&mut id_bytes);
        chacha20poly1305::aead::OsRng.fill_bytes(&mut psk);
        let id = id_bytes
            .iter()
            .map(|b| format!("{:02x}", b))
            .collect::<String>();
        let client = ClientConfig {
            id,
            name: name.to_string(),
            psk,
            vpn_ip,
            enabled: true,
            created_at: Utc::now(),
            stats: ClientStats::default(),
        };
        data.clients.push(client.clone());
        client
        // The write lock is released here, before `save` takes its own locks.
    };
    if let Err(e) = self.save() {
        // Roll back: otherwise a PSK that was never handed to anyone would
        // stay valid in memory and a retry would fail with "already exists".
        self.data.write().clients.retain(|c| c.id != client.id);
        return Err(e);
    }
    Ok(client)
}
pub fn network_config(&self) -> VpnNetworkConfig {
self.network_config
}
/// Deletes every record whose id equals `client_id` and persists the database.
///
/// # Errors
/// Returns an error when no such client exists or when saving fails.
pub fn remove_client(&self, client_id: &str) -> Result<()> {
    {
        let mut guard = self.data.write();
        let known = guard.clients.iter().any(|entry| entry.id == client_id);
        if !known {
            return Err(Error::Session(format!("Client '{}' not found", client_id)));
        }
        guard.clients.retain(|entry| entry.id != client_id);
        // Scope end releases the write lock before `save` locks again.
    }
    self.save()
}
/// Returns a cloned snapshot of all client records.
pub fn list_clients(&self) -> Vec<ClientConfig> {
    let guard = self.data.read();
    guard.clients.to_vec()
}
/// Looks up the first enabled client whose PSK equals `psk`.
///
/// Each key comparison goes through `subtle::ConstantTimeEq`. Disabled clients
/// never match.
pub fn find_by_psk(&self, psk: &[u8; 32]) -> Option<ClientConfig> {
    let guard = self.data.read();
    for candidate in guard.clients.iter() {
        if !candidate.enabled {
            continue;
        }
        let keys_equal: bool =
            subtle::ConstantTimeEq::ct_eq(&candidate.psk[..], &psk[..]).into();
        if keys_equal {
            return Some(candidate.clone());
        }
    }
    None
}
/// Looks up the first client (enabled or not) that owns the VPN address `ip`.
pub fn find_by_vpn_ip(&self, ip: &Ipv4Addr) -> Option<ClientConfig> {
    let wanted = *ip;
    let guard = self.data.read();
    for entry in guard.clients.iter() {
        if entry.vpn_ip == wanted {
            return Some(entry.clone());
        }
    }
    None
}
/// Looks up the first client (enabled or not) whose id equals `id`.
pub fn find_by_id(&self, id: &str) -> Option<ClientConfig> {
    let guard = self.data.read();
    let index = guard.clients.iter().position(|entry| entry.id == id);
    index.map(|i| guard.clients[i].clone())
}
/// Notes a successful handshake for `client_id`: bumps the connection counter
/// and stamps both `last_handshake` and `last_connected`.
///
/// Unknown ids are ignored. The change is in-memory only until the next save.
pub fn record_handshake(&self, client_id: &str) {
    // One timestamp for both fields so they cannot disagree for a single event.
    let now = Utc::now();
    let mut data = self.data.write();
    if let Some(client) = data.clients.iter_mut().find(|c| c.id == client_id) {
        // Saturate rather than overflow: counters are loaded from disk and a
        // bad value must not panic (debug) or wrap (release).
        client.stats.total_connections = client.stats.total_connections.saturating_add(1);
        client.stats.last_handshake = Some(now);
        client.stats.last_connected = Some(now);
    }
}
/// Adds `bytes_in` / `bytes_out` to the counters of `client_id` and refreshes
/// `last_connected`.
///
/// Unknown ids are ignored. The change is in-memory only until the next save.
pub fn record_traffic(&self, client_id: &str, bytes_in: u64, bytes_out: u64) {
    let mut data = self.data.write();
    if let Some(client) = data.clients.iter_mut().find(|c| c.id == client_id) {
        // Saturating arithmetic: a plain `+=` would panic in debug builds and
        // silently wrap in release builds once a counter reaches `u64::MAX`.
        client.stats.bytes_in = client.stats.bytes_in.saturating_add(bytes_in);
        client.stats.bytes_out = client.stats.bytes_out.saturating_add(bytes_out);
        client.stats.last_connected = Some(Utc::now());
    }
}
/// Best-effort persist: saves the database and only logs a warning on failure.
pub fn flush_stats(&self) {
    match self.save() {
        Ok(()) => {}
        Err(err) => warn!("Failed to flush client stats: {}", err),
    }
}
/// Re-reads the database file when its modification time differs from the one
/// recorded last.
///
/// Returns `true` only when the reload changed the set of clients. Returns
/// `false` when the file's metadata cannot be read, when the mtime is unchanged,
/// when the reload fails (logged as a warning), or when the file's clients match
/// what is already in memory.
pub fn reload_if_changed(&self) -> bool {
    let observed_mtime = match std::fs::metadata(&self.file_path) {
        Err(_) => return false,
        Ok(meta) => meta.modified().ok(),
    };
    // The guard is a temporary, so the mutex is released right after the compare.
    let unchanged = *self.last_mtime.lock() == observed_mtime;
    if unchanged {
        return false;
    }
    let reloaded = match self.reload_from_disk() {
        Ok(flag) => flag,
        Err(e) => {
            warn!("Failed to reload client DB: {}", e);
            return false;
        }
    };
    // Remember the mtime even when nothing changed, so the same file is not
    // parsed again on the next call.
    *self.last_mtime.lock() = observed_mtime;
    if reloaded {
        info!(
            "Client database reloaded from disk ({} clients)",
            self.data.read().clients.len()
        );
    }
    reloaded
}
fn reload_from_disk(&self) -> Result<bool> {
let content = std::fs::read_to_string(&self.file_path)
.map_err(|e| Error::Session(format!("Failed to read client DB for reload: {}", e)))?;
let new_data: ClientDbFile = serde_json::from_str(&content)
.map_err(|e| Error::Session(format!("Failed to parse client DB for reload: {}", e)))?;
let mut data = self.data.write();
let old_sig: std::collections::HashSet<(String, String, [u8; 32], Ipv4Addr, bool)> = data
.clients
.iter()
.map(|c| (c.id.clone(), c.name.clone(), c.psk, c.vpn_ip, c.enabled))
.collect();
let new_sig: std::collections::HashSet<(String, String, [u8; 32], Ipv4Addr, bool)> =
new_data
.clients
.iter()
.map(|c| (c.id.clone(), c.name.clone(), c.psk, c.vpn_ip, c.enabled))
.collect();
let changed = old_sig != new_sig;
if !changed {
return Ok(false);
}
let mut stats_map: std::collections::HashMap<String, ClientStats> =
std::collections::HashMap::new();
for client in &data.clients {
stats_map.insert(client.id.clone(), client.stats.clone());
}
data.clients = new_data
.clients
.into_iter()
.map(|mut c| {
if let Some(saved_stats) = stats_map.get(&c.id) {
c.stats = saved_stats.clone();
}
c
})
.collect();
data.next_host_offset = new_data.next_host_offset;
Ok(true)
}
/// Picks the next free client address in the configured subnet.
///
/// Probing starts at `data.next_host_offset` (or the default when that is 0),
/// wraps from `max_host_offset` back to 1, and skips the server's own address
/// and addresses already assigned to a client. On success `next_host_offset`
/// is advanced past the returned address.
///
/// # Errors
/// Returns an error when the subnet has no usable host offsets or when every
/// offset is taken.
fn allocate_vpn_ip(&self, data: &mut ClientDbFile) -> Result<Ipv4Addr> {
    let max_host_offset = self.network_config.max_host_offset();
    if max_host_offset < 1 {
        return Err(Error::Session(
            "Configured VPN subnet has no usable host addresses".into(),
        ));
    }

    // Build the lookup once instead of scanning all clients per candidate.
    let in_use: std::collections::HashSet<Ipv4Addr> =
        data.clients.iter().map(|client| client.vpn_ip).collect();

    // Next offset in the ring 1..=max_host_offset.
    let advance = |offset: u32| {
        if offset >= max_host_offset {
            1
        } else {
            offset + 1
        }
    };

    let mut candidate_offset = match data.next_host_offset {
        0 => default_next_host_offset(),
        offset => offset,
    };
    // A stored (or default) offset can lie beyond the current subnet, e.g.
    // after the prefix was narrowed. Starting out of range would burn one of
    // the `max_host_offset` iterations, leaving one address unprobed.
    if candidate_offset > max_host_offset {
        candidate_offset = 1;
    }

    for _ in 0..max_host_offset {
        if let Some(candidate_ip) = self.network_config.ip_for_host_offset(candidate_offset) {
            let is_server = candidate_ip == self.network_config.server_vpn_ip;
            if !is_server && !in_use.contains(&candidate_ip) {
                data.next_host_offset = advance(candidate_offset);
                return Ok(candidate_ip);
            }
        }
        candidate_offset = advance(candidate_offset);
    }
    Err(Error::Session(
        "No more VPN IPs available in configured subnet".into(),
    ))
}
}
mod base64_bytes {
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize<S: Serializer>(
bytes: &[u8; 32],
serializer: S,
) -> std::result::Result<S::Ok, S::Error> {
use base64::Engine;
let b64 = base64::engine::general_purpose::STANDARD.encode(bytes);
b64.serialize(serializer)
}
pub fn deserialize<'de, D: Deserializer<'de>>(
deserializer: D,
) -> std::result::Result<[u8; 32], D::Error> {
use base64::Engine;
let s = String::deserialize(deserializer)?;
let bytes = base64::engine::general_purpose::STANDARD
.decode(&s)
.map_err(serde::de::Error::custom)?;
if bytes.len() != 32 {
return Err(serde::de::Error::custom(format!(
"PSK must be 32 bytes, got {}",
bytes.len()
)));
}
let mut arr = [0u8; 32];
arr.copy_from_slice(&bytes);
Ok(arr)
}
}
// Unit tests for `ClientDatabase`; they work against a database file inside a
// temporary directory.
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
/// Network settings shared by the tests: server at 10.99.0.1 with a /24 prefix.
fn test_network_config() -> VpnNetworkConfig {
VpnNetworkConfig {
server_vpn_ip: Ipv4Addr::new(10, 99, 0, 1),
prefix_len: 24,
mtu: 1400,
}
}
/// Rewriting a client's PSK in the file on disk must be picked up by
/// `reload_if_changed`: the old key stops matching, the new key matches, and
/// traffic counters recorded in memory survive the reload.
#[test]
fn reload_if_changed_applies_psk_rotation() {
let dir = tempfile::tempdir().unwrap();
let db_path = dir.path().join("clients.json");
let db = ClientDatabase::load(&db_path, test_network_config()).unwrap();
// `add_client` saves, so the file exists on disk from here on.
let client = db.add_client("alice").unwrap();
let old_psk = client.psk;
// Recorded in memory only; the file still holds zeroed stats.
db.record_traffic(&client.id, 111, 222);
let mut on_disk: ClientDbFile =
serde_json::from_str(&std::fs::read_to_string(&db_path).unwrap()).unwrap();
let new_psk = [0xAB; 32];
on_disk.clients[0].psk = new_psk;
let original_mtime = std::fs::metadata(&db_path).unwrap().modified().unwrap();
let updated_json = serde_json::to_string_pretty(&on_disk).unwrap();
// Rewrite the file until its mtime differs from the original one; a single
// write may land within the filesystem's timestamp granularity.
let mut mtime_changed = false;
for _ in 0..20 {
std::fs::write(&db_path, &updated_json).unwrap();
let new_mtime = std::fs::metadata(&db_path).unwrap().modified().unwrap();
if new_mtime != original_mtime {
mtime_changed = true;
break;
}
std::thread::sleep(Duration::from_millis(60));
}
assert!(
mtime_changed,
"test setup failed to advance client DB mtime"
);
assert!(db.reload_if_changed(), "PSK rotation must trigger reload");
assert!(
db.find_by_psk(&old_psk).is_none(),
"old PSK must stop authenticating after reload"
);
let reloaded = db
.find_by_psk(&new_psk)
.expect("new PSK must authenticate after reload");
assert_eq!(reloaded.id, client.id);
// The in-memory counters must win over the zeroed stats read from the file.
assert_eq!(reloaded.stats.bytes_in, 111);
assert_eq!(reloaded.stats.bytes_out, 222);
}
}