use core::{
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
task::{Context, Poll},
};
use net2::unix::UnixUdpBuilderExt;
use rustls::{ClientConfig, CommonState, ServerConfig, pki_types::ServerName};
use tokio_rustls::{TlsAcceptor, TlsConnector};
#[cfg(feature = "std")]
use std::sync::Arc;
#[cfg(not(feature = "std"))]
use alloc::sync::Arc;
use crate::{
device::Device,
io::{
IoImpl, KnownFunctionName, Result, TcpListenerImpl, TcpStreamImpl, TlsStreamImpl,
UdpSocketImpl,
},
};
/// [`IoImpl`] backend built on tokio (sockets, timers, local task spawning) and
/// tokio-rustls (TLS).
///
/// This is a stateless unit struct; all behaviour comes from its trait
/// implementations.
#[derive(Debug)]
pub struct TokioIoImpl;
impl
    IoImpl<
        tokio::net::UdpSocket,
        tokio::net::TcpStream,
        tokio::net::TcpListener,
        tokio_rustls::TlsStream<tokio::net::TcpStream>,
    > for TokioIoImpl
{
    /// Binds a plain UDP socket to `addr`.
    async fn bind_udp(&self, addr: SocketAddr) -> Result<tokio::net::UdpSocket> {
        tokio::net::UdpSocket::bind(addr).await.map_err(Into::into)
    }

    /// Binds an IPv6 UDP socket to `addr` with address and port reuse enabled.
    ///
    /// The socket is switched to non-blocking mode before being handed to tokio:
    /// `tokio::net::UdpSocket::from_std` leaves that responsibility to the caller.
    async fn bind_udp_reuse_v6(&self, addr: SocketAddr) -> Result<tokio::net::UdpSocket> {
        let raw_socket = net2::UdpBuilder::new_v6()?
            .reuse_address(true)?
            .reuse_port(true)?
            .bind(addr)?;
        raw_socket.set_nonblocking(true)?;
        tokio::net::UdpSocket::from_std(raw_socket).map_err(Into::into)
    }

    /// Binds an IPv4 UDP socket to `addr` with address and port reuse enabled and
    /// joins a multicast group.
    ///
    /// `multicast_addr` is passed to `join_multicast_v4` as `(group, interface)`.
    async fn bind_udp_reuse_multicast_v4(
        &self,
        addr: SocketAddr,
        multicast_addr: (Ipv4Addr, Ipv4Addr),
    ) -> Result<tokio::net::UdpSocket> {
        let raw_socket = net2::UdpBuilder::new_v4()?
            .reuse_address(true)?
            .reuse_port(true)?
            .bind(addr)?;
        raw_socket.join_multicast_v4(&multicast_addr.0, &multicast_addr.1)?;
        // Required before `from_std`, which does not change the blocking mode itself.
        raw_socket.set_nonblocking(true)?;
        tokio::net::UdpSocket::from_std(raw_socket).map_err(Into::into)
    }

    /// Starts listening for TCP connections on `addr`.
    async fn listen_tcp(&self, addr: SocketAddr) -> Result<tokio::net::TcpListener> {
        tokio::net::TcpListener::bind(addr)
            .await
            .map_err(Into::into)
    }

    /// Opens a TCP connection to `addr`.
    async fn connect_tcp(&self, addr: SocketAddr) -> Result<tokio::net::TcpStream> {
        tokio::net::TcpStream::connect(addr)
            .await
            .map_err(Into::into)
    }

    /// Performs the server side of a TLS handshake over `stream` using `config`.
    ///
    /// # Errors
    /// Returns the handshake's I/O error converted into the crate error type.
    async fn accept_server_tls(
        &self,
        config: ServerConfig,
        stream: tokio::net::TcpStream,
    ) -> Result<tokio_rustls::TlsStream<tokio::net::TcpStream>> {
        TlsAcceptor::from(Arc::new(config))
            .accept(stream)
            .await
            .map(tokio_rustls::TlsStream::Server)
            .map_err(Into::into)
    }

    /// Performs the client side of a TLS handshake with `server_name` over `stream`.
    ///
    /// # Errors
    /// Returns the handshake's I/O error converted into the crate error type.
    async fn connect_client_tls(
        &self,
        config: ClientConfig,
        server_name: ServerName<'static>,
        stream: tokio::net::TcpStream,
    ) -> Result<tokio_rustls::TlsStream<tokio::net::TcpStream>> {
        TlsConnector::from(Arc::new(config))
            .connect(server_name, stream)
            .await
            .map(tokio_rustls::TlsStream::Client)
            .map_err(Into::into)
    }

    /// Returns one IPv4 and one IPv6 address of this host, taken from interfaces that
    /// are not loopback and are operationally up.
    ///
    /// The scan stops as soon as both families have been seen. The address of the
    /// family found first is the last one of that family seen before the other
    /// family appeared.
    /// NOTE(review): confirm whether "first address per family" was the intent.
    ///
    /// If the interface list cannot be read, `(None, None)` is returned instead of
    /// panicking.
    async fn get_host_addresses(&self) -> (Option<Ipv4Addr>, Option<Ipv6Addr>) {
        let Ok(interfaces) = if_addrs::get_if_addrs() else {
            return (None, None);
        };
        let mut ipv4_addr = None;
        let mut ipv6_addr = None;
        for interface in interfaces {
            if interface.is_loopback() || !interface.is_oper_up() {
                continue;
            }
            match interface.ip() {
                IpAddr::V4(v4) => {
                    if ipv6_addr.is_some() {
                        return (Some(v4), ipv6_addr);
                    }
                    ipv4_addr = Some(v4);
                }
                IpAddr::V6(v6) => {
                    if ipv4_addr.is_some() {
                        return (ipv4_addr, Some(v6));
                    }
                    ipv6_addr = Some(v6);
                }
            }
        }
        (ipv4_addr, ipv6_addr)
    }

    /// Suspends the current task for `duration` without blocking the thread.
    async fn sleep(&self, duration: std::time::Duration) {
        tokio::time::sleep(duration).await;
    }

    /// Spawns the task identified by `name` on the current `LocalSet`.
    ///
    /// `tokio::task::spawn_local` panics unless it is called from within a `LocalSet`,
    /// such as the one driven by [`start`](Self::start). The returned `JoinHandle` is
    /// dropped, so the task runs detached.
    fn spawn(
        &self,
        name: KnownFunctionName<tokio::net::TcpStream>,
        device: Arc<
            Device<
                Self,
                tokio::net::UdpSocket,
                tokio::net::TcpStream,
                tokio::net::TcpListener,
                tokio_rustls::TlsStream<tokio::net::TcpStream>,
            >,
        >,
    ) {
        match name {
            KnownFunctionName::SetupUdp => {
                tokio::task::spawn_local(crate::transport::udp::setup_udp(device))
            }
            KnownFunctionName::SetupMdns => {
                tokio::task::spawn_local(crate::transport::mdns::setup_mdns(device))
            }
            KnownFunctionName::PerTcpStream(stream) => {
                tokio::task::spawn_local(crate::transport::tcp::per_tcp_stream(stream, device))
            }
        };
    }

    /// Builds a current-thread tokio runtime and blocks the calling thread until the
    /// TCP setup task, and every local task it spawns, has finished.
    ///
    /// # Panics
    /// Panics if the tokio runtime cannot be built.
    fn start(
        &self,
        device: Arc<
            Device<
                Self,
                tokio::net::UdpSocket,
                tokio::net::TcpStream,
                tokio::net::TcpListener,
                tokio_rustls::TlsStream<tokio::net::TcpStream>,
            >,
        >,
    ) {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build the current-thread tokio runtime");
        // Driving the `LocalSet` with `block_on` gives every task inside it the context
        // that `tokio::task::spawn_local` needs. `LocalSet::spawn_local` itself does not
        // need an entered context, so no `enter()` guard is used here.
        let set = tokio::task::LocalSet::new();
        set.spawn_local(async { crate::transport::tcp::setup_tcp(device).await });
        rt.block_on(set);
    }

    /// Returns the current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// A system clock set before 1970 yields `0` instead of panicking.
    async fn get_current_timestamp(&self) -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }
}
/// Adapts tokio's UDP socket to the crate's `UdpSocketImpl` abstraction.
///
/// Each method forwards to tokio's inherent method of the same name, which is
/// named explicitly by path, and converts the I/O error into the crate error type.
impl UdpSocketImpl for tokio::net::UdpSocket {
    fn set_broadcast(&self, on: bool) -> Result<()> {
        let outcome = tokio::net::UdpSocket::set_broadcast(self, on);
        outcome.map_err(Into::into)
    }

    fn poll_recv(&self, cx: &mut Context, buf: &mut [u8]) -> Poll<Result<()>> {
        // NOTE(review): the number of bytes tokio writes into `buf` is not reported
        // back through this signature; callers cannot tell how much was received.
        let mut read_buf = tokio::io::ReadBuf::new(buf);
        tokio::net::UdpSocket::poll_recv(self, cx, &mut read_buf).map_err(Into::into)
    }

    async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
        let received = tokio::net::UdpSocket::recv_from(self, buf).await;
        received.map_err(Into::into)
    }

    async fn send_to(&mut self, buf: &[u8], addr: SocketAddr) -> Result<usize> {
        // The path must be explicit: with a `&mut self` receiver, `self.send_to(..)`
        // would select this trait method and recurse.
        let sent = tokio::net::UdpSocket::send_to(self, buf, addr).await;
        sent.map_err(Into::into)
    }
}
/// Adapts tokio's TCP stream to the crate's `TcpStreamImpl` abstraction,
/// converting I/O errors into the crate error type.
impl TcpStreamImpl for tokio::net::TcpStream {
    /// Reads into `buf` and returns the number of bytes read.
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        let count = tokio::io::AsyncReadExt::read(self, buf).await;
        count.map_err(Into::into)
    }

    /// Waits until the stream is ready for writing.
    async fn writable(&self) -> Result<()> {
        let readiness = tokio::net::TcpStream::writable(self).await;
        readiness.map_err(Into::into)
    }

    /// Writes the whole of `src`.
    async fn write_all(&mut self, src: &[u8]) -> Result<()> {
        let written = tokio::io::AsyncWriteExt::write_all(self, src).await;
        written.map_err(Into::into)
    }
}
/// Adapts a tokio-rustls TLS stream over TCP to the crate's `TlsStreamImpl`
/// abstraction, converting I/O errors into the crate error type.
impl TlsStreamImpl for tokio_rustls::TlsStream<tokio::net::TcpStream> {
    /// Exposes the rustls connection state shared by client and server sides.
    fn get_common_state(&self) -> &CommonState {
        let (_tcp, common_state) = self.get_ref();
        common_state
    }

    /// Reads decrypted data into `buf` and returns the number of bytes read.
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        let count = tokio::io::AsyncReadExt::read(self, buf).await;
        count.map_err(Into::into)
    }

    /// Writes the whole of `src` through the TLS session.
    async fn write_all(&mut self, src: &[u8]) -> Result<()> {
        let written = tokio::io::AsyncWriteExt::write_all(self, src).await;
        written.map_err(Into::into)
    }
}
/// Adapts tokio's TCP listener to the crate's `TcpListenerImpl` abstraction.
impl TcpListenerImpl<tokio::net::TcpStream> for tokio::net::TcpListener {
    /// Waits for the next inbound connection and returns its stream, discarding the
    /// peer address.
    async fn accept(&self) -> Result<tokio::net::TcpStream> {
        let (stream, _peer) = tokio::net::TcpListener::accept(self).await?;
        Ok(stream)
    }
}