use std::{net::SocketAddr, sync::Arc, time::Duration};
use ipnetwork::IpNetwork;
use x25519_dalek::{PublicKey, StaticSecret};
use crate::device::Error;
use crate::device::{Connection, Device, DeviceState, DeviceTransports, Peer, Reconfigure};
/// Read-only accessor over a borrowed device state.
///
/// Handed to the closure passed to `Device::read`; `DeviceWrite` also creates one
/// internally to share the read-side getters.
pub struct DeviceRead<'a, T: DeviceTransports> {
    /// The state being inspected (for `Device::read`, borrowed from the held read guard).
    device: &'a DeviceState<T>,
}
/// Runtime statistics for one peer, as collected by `DeviceRead::peers`.
#[derive(Debug)]
pub struct Stats {
    /// Value of the peer's `time_since_last_handshake()`; presumably `None` when no
    /// handshake has completed yet.
    pub last_handshake: Option<Duration>,
    /// Received byte count as reported by the peer tunnel's `stats()`.
    pub rx_bytes: usize,
    /// Transmitted byte count as reported by the peer tunnel's `stats()`.
    pub tx_bytes: usize,
    /// DAITA overhead counters; `None` when the peer's `daita()` returns `None`.
    #[cfg(feature = "daita")]
    pub daita: Option<DaitaStats>,
}
/// DAITA padding and decoy byte counters for one peer, copied from the counters
/// returned by the peer's `daita_overhead()`.
#[cfg(feature = "daita")]
#[derive(Debug)]
pub struct DaitaStats {
    /// Padding bytes sent.
    pub tx_padding_bytes: usize,
    /// Padding bytes received.
    pub rx_padding_bytes: usize,
    /// Bytes sent in decoy packets.
    pub tx_decoy_packet_bytes: usize,
    /// Bytes received in decoy packets.
    pub rx_decoy_packet_bytes: usize,
}
/// A snapshot of one peer's configuration together with its runtime statistics.
///
/// `#[non_exhaustive]`: fields may be added later, and the type cannot be built with a
/// struct literal outside this crate.
#[derive(Debug)]
#[non_exhaustive]
pub struct PeerStats {
    /// The peer's configuration at the time of the snapshot.
    pub peer: Peer,
    /// The peer's traffic and handshake statistics.
    pub stats: Stats,
}
/// Mutable accessor over the device state, handed to the closure passed to `Device::write`.
pub struct DeviceWrite<'a, T: DeviceTransports> {
    /// The state being modified; changes are applied to it directly.
    device: &'a mut DeviceState<T>,
    /// Accumulated (via `|=`) from every change that may require the connection to be set
    /// up again; `Device::write` acts on it after the closure returns.
    reconfigure: Reconfigure,
    /// Private key that `Device::write` applies after the closure returns.
    /// NOTE(review): nothing in this file ever assigns `Some` here — the
    /// `set_private_key` method applies the key immediately instead — so this deferred
    /// path currently looks unused.
    set_private_key: Option<StaticSecret>,
}
/// A pending change to an optional setting: either leave it alone, or overwrite it.
#[derive(Default)]
enum Update<T> {
    /// Keep the current value untouched (the default).
    #[default]
    Ignore,
    /// Replace the current value with the contained `Option`, which may be `None`.
    Set(Option<T>),
}

impl<T> From<Option<T>> for Update<T> {
    /// Wraps `value` as an explicit overwrite.
    fn from(value: Option<T>) -> Self {
        Self::Set(value)
    }
}
/// Pending edits to an existing peer.
///
/// The closure passed to `modify_peer` fills this in, and the edits are applied once that
/// closure returns. Fields left at their default (`Update::Ignore`, `false`, empty) do not
/// touch the peer.
#[derive(Default)]
#[non_exhaustive]
pub struct PeerMut {
    /// New preshared key, if it should change.
    preshared_key: Update<[u8; 32]>,
    /// New endpoint address, if it should change.
    endpoint: Update<SocketAddr>,
    /// New persistent-keepalive setting, if it should change.
    keepalive: Update<u16>,
    /// When set, all existing allowed IPs are removed before `add_allowed_ips` are inserted.
    clear_allowed_ips: bool,
    /// Networks to insert into the peer's allowed IPs.
    add_allowed_ips: Vec<IpNetwork>,
}
impl PeerMut {
pub fn set_preshared_key(&mut self, preshared_key: Option<[u8; 32]>) {
self.preshared_key = preshared_key.into();
}
pub fn set_endpoint(&mut self, addr: Option<SocketAddr>) {
self.endpoint = addr.into();
}
pub fn set_keepalive(&mut self, keepalive: Option<u16>) {
self.keepalive = keepalive.into();
}
pub fn clear_allowed_ips(&mut self) {
self.clear_allowed_ips = true;
}
pub fn add_allowed_ip(&mut self, allowed_ip: impl Into<IpNetwork>) {
self.add_allowed_ips.push(allowed_ip.into());
}
pub fn add_allowed_ips(&mut self, allowed_ips: impl IntoIterator<Item = impl Into<IpNetwork>>) {
self.add_allowed_ips
.extend(allowed_ips.into_iter().map(Into::into));
}
}
impl<T: DeviceTransports> DeviceRead<'_, T> {
    /// Returns the device's static private key, or `None` if no key pair is configured.
    pub fn private_key(&self) -> Option<&StaticSecret> {
        let key_pair = self.device.key_pair.as_ref()?;
        Some(&key_pair.0)
    }

    /// Returns the configured listen port.
    pub fn listen_port(&self) -> u16 {
        self.device.port
    }

    /// Returns the configured fwmark, if any (Linux only).
    #[cfg(target_os = "linux")]
    pub fn fwmark(&self) -> Option<u32> {
        self.device.fwmark
    }

    /// Returns a configuration snapshot plus statistics for every peer.
    ///
    /// Peers are locked one at a time, each only while its data is copied out.
    pub async fn peers(&self) -> Vec<PeerStats> {
        let mut snapshots = Vec::new();
        for (public_key, peer_lock) in self.device.peers.iter() {
            let guard = peer_lock.lock().await;

            #[cfg(feature = "daita")]
            let daita_settings = guard.daita_settings().cloned();
            #[cfg(feature = "daita")]
            let daita_stats = guard.daita().map(|daita| {
                let overhead = daita.daita_overhead();
                // The only counter here that is atomic; the rest are plain copies.
                let tx_decoy_packet_bytes = overhead
                    .tx_decoy_packet_bytes
                    .load(std::sync::atomic::Ordering::SeqCst);
                DaitaStats {
                    tx_padding_bytes: overhead.tx_padding_bytes,
                    rx_padding_bytes: overhead.rx_padding_bytes,
                    tx_decoy_packet_bytes,
                    rx_decoy_packet_bytes: overhead.rx_decoy_packet_bytes,
                }
            });

            let (_, tx_bytes, rx_bytes, ..) = guard.tunnel.stats();
            let last_handshake = guard.time_since_last_handshake();

            snapshots.push(PeerStats {
                peer: Peer {
                    public_key: *public_key,
                    preshared_key: guard.preshared_key,
                    allowed_ips: guard.allowed_ips.iter().map(|(_, net)| net).collect(),
                    endpoint: guard.endpoint.addr,
                    keepalive: guard.tunnel.persistent_keepalive(),
                    #[cfg(feature = "daita")]
                    daita_settings,
                },
                stats: Stats {
                    last_handshake,
                    rx_bytes,
                    tx_bytes,
                    #[cfg(feature = "daita")]
                    daita: daita_stats,
                },
            });
        }
        snapshots
    }
}
impl<T: DeviceTransports> DeviceWrite<'_, T> {
    /// Sets the device's static private key.
    ///
    /// The key is passed to `DeviceState::set_key` right away. If that reports that a
    /// reconfiguration is needed, `Device::write` sets up a new connection after the write
    /// closure returns.
    pub async fn set_private_key(&mut self, private_key: StaticSecret) {
        self.reconfigure |= self.device.set_key(private_key).await;
    }

    /// Removes every peer from the device.
    ///
    /// Returns the count reported by `DeviceState::clear_peers` (presumably the number of
    /// peers removed).
    pub fn clear_peers(&mut self) -> usize {
        self.device.clear_peers()
    }

    /// Adds `peer` to the device.
    ///
    /// Returns `false`, changing nothing, if a peer with the same public key already
    /// exists; otherwise adds it and returns `true`.
    pub fn add_peer(&mut self, peer: Peer) -> bool {
        if self.device.peers.contains_key(&peer.public_key) {
            return false;
        }
        self.device.add_peer(peer);
        true
    }

    /// Adds all `peers` to the device, or none of them.
    ///
    /// Returns `false`, changing nothing, if any public key in the batch already exists on
    /// the device or appears more than once within the batch itself. Otherwise adds every
    /// peer and returns `true`.
    pub fn add_peers(&mut self, peers: impl IntoIterator<Item = Peer>) -> bool {
        let peers: Vec<Peer> = peers.into_iter().collect();

        // Duplicates inside the batch must be rejected too. Otherwise the second copy of a
        // key reaches `DeviceState::add_peer` for a key that is by then already present,
        // which the single-peer `add_peer` above never allows.
        let mut batch_keys = std::collections::HashSet::with_capacity(peers.len());
        let conflict = peers.iter().any(|peer| {
            self.device.peers.contains_key(&peer.public_key)
                || !batch_keys.insert(peer.public_key)
        });
        if conflict {
            return false;
        }

        for peer in peers {
            self.device.add_peer(peer);
        }
        true
    }

    /// Adds `peer`. If a peer with the same public key already exists, overwrites that
    /// peer's settings via [`Self::update_peer`] instead.
    pub async fn add_or_update_peer(&mut self, peer: Peer) {
        if self.device.peers.contains_key(&peer.public_key) {
            self.update_peer(peer).await;
        } else {
            self.add_peer(peer);
        }
    }

    /// Calls [`Self::add_or_update_peer`] for each peer, in order.
    pub async fn add_or_update_peers(&mut self, peers: impl IntoIterator<Item = Peer>) {
        for peer in peers {
            self.add_or_update_peer(peer).await;
        }
    }

    /// Overwrites the existing peer with `peer.public_key` using `peer`'s settings.
    ///
    /// The peer's allowed IPs are replaced wholesale, and its endpoint, keepalive and
    /// preshared key are set to `peer`'s values (including `None`). Returns `false` if no
    /// such peer exists.
    ///
    /// NOTE(review): with the `daita` feature, `peer.daita_settings` is not applied here
    /// (`PeerMut` has no DAITA field), so the existing peer keeps its previous DAITA
    /// settings. Confirm whether that is intended.
    pub async fn update_peer(&mut self, peer: Peer) -> bool {
        self.modify_peer(&peer.public_key, |peer_mut| {
            peer_mut.clear_allowed_ips();
            peer_mut.add_allowed_ips(peer.allowed_ips);
            peer_mut.set_endpoint(peer.endpoint);
            peer_mut.set_keepalive(peer.keepalive);
            peer_mut.set_preshared_key(peer.preshared_key);
        })
        .await
    }

    /// Applies the edits recorded by `f` to the existing peer identified by `public_key`.
    ///
    /// `f` receives a default [`PeerMut`]; only the settings it explicitly sets change.
    /// Returns `false` (without calling `f`) if no such peer exists, `true` otherwise.
    ///
    /// The peer's lock is held while `f` runs and while the edits are applied. A request to
    /// clear the allowed IPs is applied before any additions, regardless of the order in
    /// which `f` called the `PeerMut` methods. Afterwards, every `peers_by_ip` entry that
    /// points at this peer is removed, and one entry is re-inserted for each of the peer's
    /// current allowed IPs.
    pub async fn modify_peer(
        &mut self,
        public_key: &PublicKey,
        f: impl FnOnce(&mut PeerMut),
    ) -> bool {
        let Some(existing_peer) = self.device.peers.get(public_key) else {
            return false;
        };
        let existing_peer_arc = Arc::clone(existing_peer);
        let mut existing_peer = existing_peer_arc.lock().await;

        let mut peer_mut = PeerMut::default();
        f(&mut peer_mut);
        let PeerMut {
            preshared_key,
            clear_allowed_ips,
            add_allowed_ips,
            endpoint,
            keepalive,
        } = peer_mut;

        if let Update::Set(preshared_key) = preshared_key {
            existing_peer.preshared_key = preshared_key;
        }
        if let Update::Set(keepalive) = keepalive {
            existing_peer.tunnel.set_persistent_keepalive(keepalive);
        }
        if let Update::Set(addr) = endpoint {
            existing_peer.endpoint.addr = addr;
        }
        if clear_allowed_ips {
            existing_peer.allowed_ips.clear();
        }
        for allowed_ip in add_allowed_ips {
            existing_peer
                .allowed_ips
                .insert(allowed_ip.network(), allowed_ip.prefix(), ());
        }

        // Resync the device-level network -> peer table with the peer's (possibly changed)
        // allowed IPs: drop every entry that points at this peer, then re-add the current
        // set.
        let mut remove_list = vec![];
        for (peer, ip_network) in self.device.peers_by_ip.iter() {
            if Arc::ptr_eq(&existing_peer_arc, peer) {
                remove_list.push(ip_network);
            }
        }
        for network in remove_list {
            self.device.peers_by_ip.remove_network(network);
        }
        for (_, allowed_ip) in existing_peer.allowed_ips.iter() {
            self.device.peers_by_ip.insert(
                allowed_ip.network(),
                allowed_ip.prefix(),
                Arc::clone(&existing_peer_arc),
            );
        }
        true
    }

    /// Removes the peer with `public_key`; returns whether such a peer existed.
    pub async fn remove_peer(&mut self, public_key: &PublicKey) -> bool {
        self.device.remove_peer(public_key).await.is_some()
    }

    /// Sets the listen port and records whether the connection must be set up again.
    pub fn set_listen_port(&mut self, port: u16) {
        self.reconfigure |= self.device.set_port(port);
    }

    /// Sets the fwmark (Linux only).
    ///
    /// # Errors
    ///
    /// Propagates any error from `DeviceState::set_fwmark`.
    #[cfg(target_os = "linux")]
    pub fn set_fwmark(&mut self, mark: u32) -> Result<(), Error> {
        self.device.set_fwmark(mark)
    }

    /// Returns the device's static private key, or `None` if no key pair is configured.
    // Reads the key directly rather than through `as_configurator()`: the reference
    // returned by `DeviceRead::private_key` borrows the temporary `DeviceRead`, so it could
    // not be returned from here.
    pub fn private_key(&self) -> Option<&StaticSecret> {
        self.device.key_pair.as_ref().map(|kp| &kp.0)
    }

    /// Returns the configured listen port.
    pub fn listen_port(&self) -> u16 {
        self.as_configurator().listen_port()
    }

    /// Returns the configured fwmark, if any (Linux only).
    #[cfg(target_os = "linux")]
    pub fn fwmark(&self) -> Option<u32> {
        self.as_configurator().fwmark()
    }

    /// Returns a configuration snapshot plus statistics for every peer.
    pub async fn peers(&self) -> Vec<PeerStats> {
        self.as_configurator().peers().await
    }

    /// Borrows the same state as a [`DeviceRead`], so the read-side getters can be reused.
    fn as_configurator(&self) -> DeviceRead<'_, T> {
        DeviceRead {
            device: self.device,
        }
    }
}
impl<T: DeviceTransports> Device<T> {
    /// Runs `f` with read-only access to the device configuration and returns its result.
    ///
    /// The read guard on `self.inner` is held until `f` completes.
    pub async fn read<X>(&self, f: impl AsyncFnOnce(&DeviceRead<T>) -> X) -> X {
        let state = self.inner.read().await;
        let configurator = DeviceRead { device: &state };
        f(&configurator).await
    }

    /// Runs `f` with mutable access to the device configuration and returns its result.
    ///
    /// Changes made through the [`DeviceWrite`] are applied to the device state as they
    /// happen. If any of them requires it, a new `Connection` is set up after `f` returns
    /// and stored in the state.
    ///
    /// # Errors
    ///
    /// Returns the error from `Connection::set_up` if reconfiguration fails. The changes
    /// made by `f` remain applied in that case; only `f`'s return value is lost.
    pub async fn write<X>(
        &self,
        f: impl AsyncFnOnce(&mut DeviceWrite<T>) -> X,
    ) -> Result<X, Error> {
        let mut state = self.inner.write().await;
        let mut configurator = DeviceWrite {
            device: &mut state,
            reconfigure: Reconfigure::No,
            set_private_key: None,
        };
        let t = f(&mut configurator).await;
        if let Some(private_key) = configurator.set_private_key {
            configurator.reconfigure |= configurator.device.set_key(private_key).await;
        }
        if let Reconfigure::Yes = configurator.reconfigure {
            // Release the write guard before setting up the connection; `set_up` is handed
            // `self.inner` itself (presumably so it can access the state).
            // NOTE(review): another writer can run between this drop and the re-lock below.
            drop(state);
            let con = Connection::set_up(self.inner.clone()).await?;
            let mut state = self.inner.write().await;
            state.connection = Some(con);
        }
        Ok(t)
    }

    /// Sets the device's static private key; see [`DeviceWrite::set_private_key`].
    pub async fn set_private_key(&self, private_key: StaticSecret) -> Result<(), Error> {
        self.write(async |device| {
            device.set_private_key(private_key).await;
        })
        .await
    }

    /// Removes every peer; see [`DeviceWrite::clear_peers`].
    pub async fn clear_peers(&self) -> Result<usize, Error> {
        self.write(async |device| device.clear_peers()).await
    }

    /// Adds a peer unless its public key already exists; see [`DeviceWrite::add_peer`].
    pub async fn add_peer(&self, peer: Peer) -> Result<bool, Error> {
        self.write(async |device| device.add_peer(peer)).await
    }

    /// Returns a configuration snapshot plus statistics for every peer.
    pub async fn peers(&self) -> Vec<PeerStats> {
        self.read(async |device| device.peers().await).await
    }

    /// Adds all peers or none of them; see [`DeviceWrite::add_peers`].
    pub async fn add_peers(&self, peers: impl IntoIterator<Item = Peer>) -> Result<bool, Error> {
        self.write(async |device| device.add_peers(peers)).await
    }

    /// Adds or updates a peer; see [`DeviceWrite::add_or_update_peer`].
    pub async fn add_or_update_peer(&self, peer: Peer) -> Result<(), Error> {
        self.write(async |device| device.add_or_update_peer(peer).await)
            .await
    }

    /// Adds or updates each peer in turn; see [`DeviceWrite::add_or_update_peers`].
    ///
    /// Takes `&self` like every other wrapper here: [`Self::write`] needs only shared
    /// access.
    pub async fn add_or_update_peers(
        &self,
        peers: impl IntoIterator<Item = Peer>,
    ) -> Result<(), Error> {
        self.write(async |device| device.add_or_update_peers(peers).await)
            .await
    }

    /// Overwrites an existing peer's settings; see [`DeviceWrite::update_peer`].
    pub async fn update_peer(&self, peer: Peer) -> Result<bool, Error> {
        self.write(async |device| device.update_peer(peer).await)
            .await
    }

    /// Applies the edits recorded by `f` to an existing peer; see
    /// [`DeviceWrite::modify_peer`].
    ///
    /// Takes `&self` like every other wrapper here: [`Self::write`] needs only shared
    /// access.
    pub async fn modify_peer(
        &self,
        public_key: &PublicKey,
        f: impl FnOnce(&mut PeerMut),
    ) -> Result<bool, Error> {
        self.write(async |device| device.modify_peer(public_key, f).await)
            .await
    }

    /// Removes a peer; see [`DeviceWrite::remove_peer`].
    pub async fn remove_peer(&self, public_key: &PublicKey) -> Result<bool, Error> {
        self.write(async |device| device.remove_peer(public_key).await)
            .await
    }

    /// Sets the listen port, setting up a new connection if needed.
    pub async fn set_listen_port(&self, port: u16) -> Result<(), Error> {
        self.write(async |device| device.set_listen_port(port))
            .await
    }

    /// Sets the fwmark (Linux only).
    ///
    /// # Errors
    ///
    /// Returns either the reconfiguration error from [`Self::write`] or the error from
    /// [`DeviceWrite::set_fwmark`].
    #[cfg(target_os = "linux")]
    pub async fn set_fwmark(&self, mark: u32) -> Result<(), Error> {
        self.write(async |device| device.set_fwmark(mark)).await?
    }
}