use crate::{
packet::{IpHeader, NetworkPacket, TransportHeader},
PacketReceiver, PacketSender, TTL,
};
use async_io::Timer;
use bytes::BufMut as _;
use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header, UdpHeader};
use futures_lite::{AsyncRead, AsyncWrite, StreamExt};
use std::{future::Future, net::SocketAddr, pin::Pin, time::Duration};
#[derive(Debug)]
pub struct IpStackUdpStream {
src_addr: SocketAddr,
dst_addr: SocketAddr,
stream_sender: PacketSender,
stream_receiver: Pin<Box<PacketReceiver>>,
pkt_sender: PacketSender,
first_payload: Option<Vec<u8>>,
timeout: Pin<Box<Timer>>,
udp_timeout: Duration,
mtu: u16,
}
impl IpStackUdpStream {
pub fn new(
src_addr: SocketAddr,
dst_addr: SocketAddr,
payload: Vec<u8>,
pkt_sender: PacketSender,
mtu: u16,
udp_timeout: Duration,
) -> Self {
let (stream_sender, stream_receiver) = async_channel::unbounded::<NetworkPacket>();
let deadline = std::time::Instant::now() + udp_timeout;
IpStackUdpStream {
src_addr,
dst_addr,
stream_sender,
stream_receiver: Box::pin(stream_receiver),
pkt_sender,
first_payload: Some(payload),
timeout: Box::pin(Timer::at(deadline)),
udp_timeout,
mtu,
}
}
pub(crate) fn stream_sender(&self) -> PacketSender {
self.stream_sender.clone()
}
fn create_rev_packet(&self, ttl: u8, mut payload: Vec<u8>) -> anyhow::Result<NetworkPacket> {
const UHS: usize = 8; match (self.dst_addr.ip(), self.src_addr.ip()) {
(std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => {
let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::UDP, dst.octets(), src.octets())?;
let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16);
payload.truncate(line_buffer as usize);
ip_h.set_payload_len(payload.len() + UHS)?;
let udp_header = UdpHeader::with_ipv4_checksum(
self.dst_addr.port(),
self.src_addr.port(),
&ip_h,
&payload,
)?;
Ok(NetworkPacket {
ip: IpHeader::Ipv4(ip_h),
transport: TransportHeader::Udp(udp_header),
payload,
})
}
(std::net::IpAddr::V6(dst), std::net::IpAddr::V6(src)) => {
let mut ip_h = Ipv6Header {
traffic_class: 0,
flow_label: Ipv6FlowLabel::ZERO,
payload_length: 0,
next_header: IpNumber::UDP,
hop_limit: ttl,
source: dst.octets(),
destination: src.octets(),
};
let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16);
payload.truncate(line_buffer as usize);
ip_h.payload_length = (payload.len() + UHS) as u16;
let udp_header = UdpHeader::with_ipv6_checksum(
self.dst_addr.port(),
self.src_addr.port(),
&ip_h,
&payload,
)?;
Ok(NetworkPacket {
ip: IpHeader::Ipv6(ip_h),
transport: TransportHeader::Udp(udp_header),
payload,
})
}
_ => unreachable!(),
}
}
pub fn local_addr(&self) -> SocketAddr {
self.src_addr
}
pub fn peer_addr(&self) -> SocketAddr {
self.dst_addr
}
fn reset_timeout(&mut self) {
let deadline = std::time::Instant::now() + self.udp_timeout;
self.timeout.as_mut().set_at(deadline);
}
}
impl AsyncRead for IpStackUdpStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
mut buf: &mut [u8],
) -> std::task::Poll<std::io::Result<usize>> {
if let Some(p) = self.first_payload.take() {
buf.put_slice(&p);
return std::task::Poll::Ready(Ok(p.len()));
}
if matches!(self.timeout.as_mut().poll(cx), std::task::Poll::Ready(_)) {
return std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::TimedOut)));
}
self.reset_timeout();
match self.stream_receiver.poll_next(cx) {
std::task::Poll::Ready(Some(p)) => {
buf.put_slice(&p.payload);
std::task::Poll::Ready(Ok(p.payload.len()))
}
std::task::Poll::Ready(None) => std::task::Poll::Ready(Ok(0)),
std::task::Poll::Pending => std::task::Poll::Pending,
}
}
}
impl AsyncWrite for IpStackUdpStream {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
self.reset_timeout();
let packet = self
.create_rev_packet(TTL, buf.to_vec())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
let payload_len = packet.payload.len();
self.pkt_sender
.try_send(packet)
.or(Err(std::io::ErrorKind::UnexpectedEof))?;
std::task::Poll::Ready(Ok(payload_len))
}
fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::task::Poll::Ready(Ok(()))
}
fn poll_close(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::task::Poll::Ready(Ok(()))
}
}