use alloc::sync::Arc;
use core::{
pin::Pin,
task::{Context, Poll},
};
use bytes::BytesMut;
use netcore::{
smoltcp,
smoltcp::{phy::DeviceCapabilities, time::Instant},
};
/// TUN-backed device built on a synchronous `tun_rs::SyncDevice`.
///
/// Implements `smoltcp::phy::Device`; the underlying device is held in an
/// `Arc` so that transmit tokens can carry their own handle to it.
pub struct TunRsDevice {
/// Shared handle to the underlying synchronous TUN device.
tun: Arc<tun_rs::SyncDevice>,
/// MTU used to size receive buffers and reported in the device capabilities.
mtu: u16,
}
/// TUN-backed device built on an asynchronous `tun_rs::AsyncDevice`.
///
/// Implements `smoltcp::phy::Device` and `netcore::AsyncWakeDevice`; the
/// underlying device is held in an `Arc` so that transmit tokens can carry
/// their own handle to it.
pub struct TunRsDeviceAsync {
/// Shared handle to the underlying asynchronous TUN device.
tun: Arc<tun_rs::AsyncDevice>,
/// MTU used to size receive buffers and reported in the device capabilities.
mtu: u16,
}
impl TunRsDevice {
    /// Builds a new synchronous TUN device configured with the given MTU.
    ///
    /// On Unix targets the freshly built device is switched to non-blocking
    /// mode before it is wrapped.
    ///
    /// # Errors
    ///
    /// Returns the I/O error produced while building the device or, on Unix,
    /// while enabling non-blocking mode.
    pub fn new(mtu: u16) -> std::io::Result<Self> {
        let builder = tun_rs::DeviceBuilder::new().mtu(mtu as _);
        let device = builder.build_sync()?;

        #[cfg(unix)]
        device.set_nonblocking(true)?;

        let tun = Arc::new(device);
        Ok(Self { tun, mtu })
    }

    /// Borrows the underlying synchronous TUN device.
    pub fn inner(&self) -> &tun_rs::SyncDevice {
        &*self.tun
    }
}
impl From<tun_rs::SyncDevice> for TunRsDevice {
fn from(value: tun_rs::SyncDevice) -> Self {
let mtu = value.mtu().unwrap();
Self {
tun: Arc::new(value),
mtu,
}
}
}
impl TunRsDeviceAsync {
    /// Builds a new asynchronous TUN device configured with the given MTU.
    ///
    /// # Errors
    ///
    /// Returns the I/O error produced while building the device.
    pub fn new(mtu: u16) -> std::io::Result<Self> {
        let builder = tun_rs::DeviceBuilder::new().mtu(mtu as _);
        let device = builder.build_async()?;
        let tun = Arc::new(device);
        Ok(Self { tun, mtu })
    }

    /// Borrows the underlying asynchronous TUN device.
    pub fn inner(&self) -> &tun_rs::AsyncDevice {
        &*self.tun
    }
}
impl From<tun_rs::AsyncDevice> for TunRsDeviceAsync {
    /// Wraps an already-built asynchronous TUN device, reading the MTU back
    /// from the device itself.
    ///
    /// # Panics
    ///
    /// Panics if the MTU cannot be queried from the device.
    fn from(value: tun_rs::AsyncDevice) -> Self {
        // Fields are evaluated in written order, so the MTU is queried
        // before `value` is moved into the `Arc`.
        Self {
            mtu: value.mtu().unwrap(),
            tun: Arc::new(value),
        }
    }
}
impl netcore::AsyncWakeDevice for TunRsDeviceAsync {
    /// Polls the TUN device for read readiness.
    ///
    /// A fresh `readable()` future is created, polled exactly once with `cx`,
    /// and then dropped. Whatever value the future resolves to (including an
    /// error, if any) is discarded and reported as plain readiness.
    ///
    /// NOTE(review): because the future is dropped after a single poll,
    /// whether the waker registered through `cx` survives that drop depends
    /// on the tun_rs / runtime implementation. Confirm wake-ups are not lost
    /// when this returns `Poll::Pending`.
    fn poll_rx<'cx>(self: Pin<&mut Self>, cx: &mut Context<'cx>) -> Poll<()> {
        let readable = core::pin::pin!(self.tun.readable());
        match readable.poll(cx) {
            Poll::Ready(_) => Poll::Ready(()),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Reports the device as always ready for transmission.
    fn poll_tx<'cx>(self: Pin<&mut Self>, _cx: &mut Context<'cx>) -> Poll<()> {
        Poll::Ready(())
    }
}
impl smoltcp::phy::Device for TunRsDevice {
type RxToken<'a>
= RxToken
where
Self: 'a;
type TxToken<'a>
= TxToken
where
Self: 'a;
fn receive(&mut self, timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
let tx = self.transmit(timestamp)?;
let mut buf = BytesMut::zeroed(self.mtu as _);
#[cfg(windows)]
let n = self.tun.try_recv(&mut buf).ok()?;
#[cfg(unix)]
let n = self.tun.recv(&mut buf).ok()?;
buf.truncate(n);
Some((RxToken(buf.freeze()), tx))
}
fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
Some(TxToken(self.tun.clone()))
}
fn capabilities(&self) -> DeviceCapabilities {
let mut caps = DeviceCapabilities::default();
caps.max_transmission_unit = self.mtu as _;
caps
}
}
impl smoltcp::phy::Device for TunRsDeviceAsync {
type RxToken<'a>
= RxToken
where
Self: 'a;
type TxToken<'a>
= TxTokenAsync
where
Self: 'a;
fn receive(&mut self, timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
let tx = self.transmit(timestamp)?;
let mut buf = BytesMut::zeroed(self.mtu as _);
let n = self.tun.try_recv(&mut buf).ok()?;
buf.truncate(n);
Some((RxToken(buf.freeze()), tx))
}
fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
Some(TxTokenAsync(self.tun.clone()))
}
fn capabilities(&self) -> DeviceCapabilities {
let mut caps = DeviceCapabilities::default();
caps.max_transmission_unit = self.mtu as _;
caps
}
}
/// Receive token owning the bytes of one packet read from the TUN device.
pub struct RxToken(bytes::Bytes);

impl smoltcp::phy::RxToken for RxToken {
    /// Passes the received packet bytes to `f` and returns its result.
    fn consume<R, F>(self, f: F) -> R
    where
        F: FnOnce(&[u8]) -> R,
    {
        let RxToken(packet) = self;
        f(&packet[..])
    }
}
pub struct TxToken(Arc<tun_rs::SyncDevice>);
impl smoltcp::phy::TxToken for TxToken {
fn consume<R, F>(self, len: usize, f: F) -> R
where
F: FnOnce(&mut [u8]) -> R,
{
let mut b = BytesMut::zeroed(len);
let ret = f(&mut b);
if let Err(e) = self.0.send(&b) {
tracing::error!(error = %e, "writing to tun");
}
ret
}
}
pub struct TxTokenAsync(Arc<tun_rs::AsyncDevice>);
impl smoltcp::phy::TxToken for TxTokenAsync {
fn consume<R, F>(self, len: usize, f: F) -> R
where
F: FnOnce(&mut [u8]) -> R,
{
let mut b = BytesMut::zeroed(len);
let ret = f(&mut b);
if let Err(e) = self.0.try_send(&b) {
tracing::error!(error = %e, "writing to tun");
}
ret
}
}