use tokio::{sync::watch, time::sleep};
use tun::AbstractDevice;
use crate::{
packet::{Ip, Packet, PacketBufPool},
task::Task,
tun::{IpRecv, IpSend, MtuWatcher},
};
use std::{convert::Infallible, io, iter, sync::Arc, time::Duration};
/// Errors produced while opening or querying a [`TunDevice`].
///
/// NOTE(review): each message interpolates the wrapped error (`{0}`) *and* marks it as
/// `#[source]`. Reporters that walk the `source()` chain (e.g. `anyhow`'s `{:#}`) will
/// therefore print the underlying `tun::Error` twice. Consider dropping `{0}` from the
/// messages or removing `#[source]`.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// Creating the device failed (returned by `TunDevice::from_name` when
    /// `tun::create_as_async` errors).
    #[error("Failed to open TUN device: {0}")]
    OpenTun(#[source] tun::Error),
    /// Querying the interface name failed (returned by `TunDevice::name`).
    #[error("Failed to get TUN device name: {0}")]
    GetTunName(#[source] tun::Error),
    /// The device uses a feature this wrapper does not handle. The payload names the
    /// feature; in this file it is only `"packet_information"`, checked on Linux.
    #[error("Unsupported TUN feature: {0}")]
    UnsupportedFeature(String),
    /// Reading the device MTU failed. In this file, only the initial MTU read in
    /// `TunDevice::from_tun_device` maps to this variant. Failures in the background
    /// monitor are not reported.
    #[error("Failed to get TUN device MTU: {0}")]
    GetMtu(#[source] tun::Error),
}
/// Async TUN device handle that sends and receives IP packets and exposes an MTU watcher.
///
/// Cloning is cheap: both fields are `Arc`s. Clones share the same underlying device and
/// the same MTU-monitoring state.
#[derive(Clone)]
pub struct TunDevice {
    /// The underlying async TUN device, shared between clones. The background MTU
    /// monitor holds only a `Weak` reference to it, so the monitor does not keep the
    /// device alive.
    tun: Arc<tun::AsyncDevice>,
    /// MTU watcher plus the handle of the task that feeds it, shared between clones.
    state: Arc<TunDeviceState>,
}
/// State shared by every clone of a [`TunDevice`].
struct TunDeviceState {
    /// Observes MTU values published by the background monitor task. It is built from
    /// the receiving half of a `tokio::sync::watch` channel.
    mtu: MtuWatcher,
    /// Handle to the background MTU-polling task. It is never read. It is stored
    /// presumably so the task stops when the last `TunDevice` clone is dropped.
    /// NOTE(review): confirm that `crate::task::Task` aborts its future on drop.
    _mtu_monitor: Task,
}
impl TunDevice {
pub fn from_name(name: &str) -> Result<Self, Error> {
let mut tun_config = tun::Configuration::default();
if cfg!(not(target_os = "macos")) || name != "utun" {
tun_config.tun_name(name);
}
#[cfg(target_os = "macos")]
tun_config.platform_config(|p| {
p.enable_routing(false);
});
let tun = tun::create_as_async(&tun_config).map_err(Error::OpenTun)?;
let tun = TunDevice::from_tun_device(tun)?;
Ok(tun)
}
pub fn from_tun_device(tun: tun::AsyncDevice) -> Result<Self, Error> {
#[cfg(target_os = "linux")]
if tun.packet_information() {
return Err(Error::UnsupportedFeature("packet_information".to_string()));
}
let mtu = tun.mtu().map_err(Error::GetMtu)?;
let (tx, rx) = watch::channel(mtu);
let tun = Arc::new(tun);
let tun_weak = Arc::downgrade(&tun);
let watch_task = async move || -> Option<Infallible> {
let mut mtu = mtu;
loop {
sleep(Duration::from_secs(3)).await;
let tun = tun_weak.upgrade()?;
let new = tun.mtu().ok()?;
if new != mtu {
mtu = new;
tx.send(mtu).ok()?;
}
}
};
let mtu_monitor = Task::spawn("tun_mtu_monitor", async move {
watch_task().await;
});
Ok(Self {
tun,
state: Arc::new(TunDeviceState {
mtu: rx.into(),
_mtu_monitor: mtu_monitor,
}),
})
}
pub fn name(&self) -> Result<String, Error> {
self.tun.tun_name().map_err(Error::GetTunName)
}
}
impl IpSend for TunDevice {
    /// Writes one IP packet to the TUN device.
    ///
    /// The byte count returned by the device write is discarded. Only I/O errors are
    /// propagated to the caller.
    async fn send(&mut self, packet: Packet<Ip>) -> io::Result<()> {
        let bytes = packet.into_bytes();
        self.tun.send(&bytes).await.map(|_written| ())
    }
}
impl IpRecv for TunDevice {
async fn recv<'a>(
&'a mut self,
pool: &mut PacketBufPool,
) -> io::Result<impl Iterator<Item = Packet<Ip>> + 'a> {
let mut packet = pool.get();
let n = self.tun.recv(&mut packet).await?;
packet.truncate(n);
match packet.try_into_ip() {
Ok(packet) => Ok(iter::once(packet)),
Err(e) => Err(io::Error::other(e.to_string())),
}
}
fn mtu(&self) -> MtuWatcher {
self.state.mtu.clone()
}
}