use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use x25519_dalek::StaticSecret;
use crate::device::Error;
use crate::noise::index_table::IndexTable;
#[cfg(feature = "tun")]
use crate::tun::tun_async_device::TunDevice;
use crate::{
device::{Device, DeviceState, allowed_ips::AllowedIps, peer::Peer, uapi::UapiServer},
task::Task,
tun::{IpRecv, IpSend},
udp::{UdpTransportFactory, socket::UdpSocketFactory},
};
use super::Connection;
/// Placeholder type used as a [`DeviceBuilder`] type parameter for a slot
/// (UDP factory, tunnel sender or tunnel receiver) that has not been set yet.
pub struct Nul;
/// Type-state builder for a [`Device`].
///
/// The three type parameters track which transports have been supplied so far:
/// each one starts out as [`Nul`] and is replaced by a concrete type through the
/// `with_udp*` / `with_ip*` methods. `build` is only available once `Udp`
/// implements [`UdpTransportFactory`], `TunTx` implements [`IpSend`] and `TunRx`
/// implements [`IpRecv`].
pub struct DeviceBuilder<Udp, TunTx, TunRx> {
/// UDP transport factory handed to the device state, or [`Nul`] while unset.
udp: Udp,
/// Sending half of the IP (tunnel) channel, or [`Nul`] while unset.
tun_tx: TunTx,
/// Receiving half of the IP (tunnel) channel, or [`Nul`] while unset.
tun_rx: TunRx,
/// Private key passed to the device state's `set_key` during `build`, if any.
private_key: Option<StaticSecret>,
/// Listen port copied into the device state; the builder starts with `0`.
port: u16,
/// Optional UAPI server; when present, `build` spawns a `"uapi"` task for it.
uapi: Option<UapiServer>,
/// Peers added to the device state during `build`.
peers: Vec<Peer>,
/// Optional index table; `build` falls back to `IndexTable::from_os_rng` when `None`.
index_table: Option<IndexTable>,
/// Firewall mark copied into the device state (Linux only).
#[cfg(target_os = "linux")]
fwmark: Option<u32>,
}
impl Default for DeviceBuilder<Nul, Nul, Nul> {
fn default() -> Self {
Self::new()
}
}
impl DeviceBuilder<Nul, Nul, Nul> {
pub const fn new() -> Self {
Self {
udp: Nul,
tun_tx: Nul,
tun_rx: Nul,
private_key: None,
uapi: None,
port: 0,
peers: Vec::new(),
index_table: None,
#[cfg(target_os = "linux")]
fwmark: None,
}
}
}
impl<X, Y> DeviceBuilder<Nul, X, Y> {
    /// Selects [`UdpSocketFactory`] as the UDP transport factory.
    ///
    /// Shorthand for `with_udp(UdpSocketFactory)`.
    pub fn with_default_udp(self) -> DeviceBuilder<UdpSocketFactory, X, Y> {
        Self::with_udp(self, UdpSocketFactory)
    }

    /// Sets the UDP transport factory, carrying every other setting over
    /// unchanged.
    ///
    /// Only callable while the UDP slot is still [`Nul`].
    pub fn with_udp<Udp: UdpTransportFactory>(self, udp: Udp) -> DeviceBuilder<Udp, X, Y> {
        // Take the old builder apart; its `Nul` UDP placeholder is discarded.
        let DeviceBuilder {
            udp: _,
            tun_tx,
            tun_rx,
            private_key,
            port,
            uapi,
            peers,
            index_table,
            #[cfg(target_os = "linux")]
            fwmark,
        } = self;

        DeviceBuilder {
            udp,
            tun_tx,
            tun_rx,
            private_key,
            port,
            uapi,
            peers,
            index_table,
            #[cfg(target_os = "linux")]
            fwmark,
        }
    }
}
impl<X> DeviceBuilder<X, Nul, Nul> {
    /// Creates a [`TunDevice`] from `tun_name` and uses it for both the sending
    /// and the receiving side of the IP channel.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `TunDevice::from_name`.
    #[cfg(feature = "tun")]
    pub fn create_tun(
        self,
        tun_name: &str,
    ) -> Result<DeviceBuilder<X, TunDevice, TunDevice>, Error> {
        Ok(self.with_ip(TunDevice::from_name(tun_name)?))
    }

    /// Sets a single value as both the sender and the receiver of IP packets.
    ///
    /// `ip` is cloned: the clone becomes the sending half and the original the
    /// receiving half.
    ///
    #[cfg_attr(feature = "tun", doc = "This is normally a [`TunDevice`], ")]
    #[cfg_attr(
        not(feature = "tun"),
        doc = "This is normally a `TunDevice` (requires feature `tun`), "
    )]
    /// but any type implementing [`IpSend`], [`IpRecv`] and [`Clone`] is accepted.
    pub fn with_ip<Ip: IpSend + IpRecv + Clone>(self, ip: Ip) -> DeviceBuilder<X, Ip, Ip> {
        let ip_tx = ip.clone();
        let ip_rx = ip;
        self.with_ip_pair(ip_tx, ip_rx)
    }

    /// Sets separate sender (`ip_tx`) and receiver (`ip_rx`) halves for IP
    /// packets, carrying every other setting over unchanged.
    pub fn with_ip_pair<IpTx: IpSend, IpRx: IpRecv>(
        self,
        ip_tx: IpTx,
        ip_rx: IpRx,
    ) -> DeviceBuilder<X, IpTx, IpRx> {
        // Take the old builder apart; its two `Nul` tunnel placeholders are discarded.
        let DeviceBuilder {
            udp,
            tun_tx: _,
            tun_rx: _,
            private_key,
            port,
            uapi,
            peers,
            index_table,
            #[cfg(target_os = "linux")]
            fwmark,
        } = self;

        DeviceBuilder {
            udp,
            tun_tx: ip_tx,
            tun_rx: ip_rx,
            private_key,
            port,
            uapi,
            peers,
            index_table,
            #[cfg(target_os = "linux")]
            fwmark,
        }
    }
}
impl<X, Y, Z> DeviceBuilder<X, Y, Z> {
    /// Sets the private key, replacing any previously set key.
    pub fn with_private_key(self, private_key: StaticSecret) -> Self {
        Self {
            private_key: Some(private_key),
            ..self
        }
    }

    /// Sets the UAPI server, replacing any previously set one.
    pub fn with_uapi(self, uapi: UapiServer) -> Self {
        Self {
            uapi: Some(uapi),
            ..self
        }
    }

    /// Appends a single peer to the list of peers.
    pub fn with_peer(mut self, peer: Peer) -> Self {
        self.peers.push(peer);
        self
    }

    /// Appends every peer yielded by `peers` to the list of peers.
    pub fn with_peers(mut self, peers: impl IntoIterator<Item = Peer>) -> Self {
        self.peers.extend(peers);
        self
    }

    /// Sets the listen port stored in the builder.
    // Kept in mutate-and-return form so the function stays usable as `const fn`.
    pub const fn with_listen_port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Sets the index table, replacing any previously set one.
    pub fn with_index_table(self, index_table: IndexTable) -> Self {
        Self {
            index_table: Some(index_table),
            ..self
        }
    }

    /// Sets the fwmark stored in the builder (Linux only).
    // Kept in mutate-and-return form so the function stays usable as `const fn`.
    #[cfg(target_os = "linux")]
    pub const fn with_fwmark(mut self, fwmark: u32) -> Self {
        self.fwmark = Some(fwmark);
        self
    }
}
impl<Udp: UdpTransportFactory, TunTx: IpSend, TunRx: IpRecv> DeviceBuilder<Udp, TunTx, TunRx> {
    /// Consumes the builder and assembles a [`Device`].
    ///
    /// In order, this:
    /// 1. creates the device state from the configured transports, port,
    ///    fwmark and index table (a fresh `IndexTable::from_os_rng()` when none
    ///    was supplied);
    /// 2. applies the private key, if one was set;
    /// 3. adds all configured peers;
    /// 4. spawns the `"uapi"` task, if a UAPI server was set;
    /// 5. sets up the connection, but only when at least one peer was configured.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `Connection::set_up`.
    ///
    /// # Panics
    ///
    /// Panics with `"lock is not taken"` if the freshly created state lock
    /// cannot be acquired while storing the UAPI task.
    ///
    #[cfg_attr(feature = "daita", doc = "Errors if DAITA initialization fails.")]
    pub async fn build(self) -> Result<Device<(Udp, TunTx, TunRx)>, Error> {
        let Self {
            udp,
            tun_tx,
            tun_rx,
            private_key,
            port,
            uapi,
            peers,
            index_table,
            #[cfg(target_os = "linux")]
            fwmark,
        } = self;
        // The fwmark field only exists on Linux; elsewhere the state gets `None`.
        #[cfg(not(target_os = "linux"))]
        let fwmark = None;

        let mut device_state = DeviceState {
            api: None,
            udp_factory: udp,
            tun_tx: Arc::new(Mutex::new(tun_tx)),
            // The MTU must be read before `tun_rx` is moved into its mutex below.
            tun_rx_mtu: tun_rx.mtu(),
            tun_rx: Arc::new(Mutex::new(tun_rx)),
            fwmark,
            key_pair: None,
            index_table: index_table.unwrap_or_else(IndexTable::from_os_rng),
            peers: Default::default(),
            peers_by_idx: parking_lot::Mutex::new(Default::default()),
            peers_by_ip: AllowedIps::new(),
            rate_limiter: None,
            port,
            connection: None,
        };

        if let Some(key) = private_key {
            // NOTE(review): the value returned by `set_key` is discarded here;
            // confirm that nothing it reports needs to be surfaced to the caller.
            let _ = device_state.set_key(key).await;
        }

        // Remember whether a connection is needed before the peer list is consumed.
        let needs_connection = !peers.is_empty();
        peers.into_iter().for_each(|peer| {
            device_state.add_peer(peer);
        });

        let inner = Arc::new(RwLock::new(device_state));

        if let Some(uapi_server) = uapi {
            // The task only receives a weak handle to the shared state.
            let api_task = Task::spawn(
                "uapi",
                DeviceState::handle_api(Arc::downgrade(&inner), uapi_server),
            );
            inner.try_write().expect("lock is not taken").api = Some(api_task);
        }

        if needs_connection {
            let connection = Connection::set_up(Arc::clone(&inner)).await?;
            let mut guard = inner.write().await;
            guard.connection = Some(connection);
        }

        Ok(Device { inner })
    }
}