netstack_smoltcp/
udp.rs

1use std::{
2    net::SocketAddr,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use etherparse::PacketBuilder;
8use futures::{ready, Sink, SinkExt, Stream};
9use smoltcp::wire::UdpPacket;
10use tokio::sync::mpsc::{Receiver, Sender};
11use tokio_util::sync::PollSender;
12use tracing::{error, trace};
13
14use crate::packet::{AnyIpPktFrame, IpPacket};
15
/// One UDP datagram together with its addressing, laid out as
/// `(payload, source, destination)`.
///
/// [`ReadHalf`] fills the two addresses from the received packet's source and
/// destination; for traffic read from the stack these are the local and remote
/// endpoints respectively. [`WriteHalf`] uses them verbatim as the source and
/// destination of the packet it builds.
pub type UdpMsg = (Vec<u8>, SocketAddr, SocketAddr);
21
22pub struct UdpSocket {
23    udp_rx: Receiver<AnyIpPktFrame>,
24    stack_tx: PollSender<AnyIpPktFrame>,
25}
26
27impl UdpSocket {
28    pub(super) fn new(udp_rx: Receiver<AnyIpPktFrame>, stack_tx: Sender<AnyIpPktFrame>) -> Self {
29        Self {
30            udp_rx,
31            stack_tx: PollSender::new(stack_tx),
32        }
33    }
34
35    pub fn split(self) -> (ReadHalf, WriteHalf) {
36        (
37            ReadHalf {
38                udp_rx: self.udp_rx,
39            },
40            WriteHalf {
41                stack_tx: self.stack_tx,
42            },
43        )
44    }
45}
46
47pub struct ReadHalf {
48    udp_rx: Receiver<AnyIpPktFrame>,
49}
50
51pub struct WriteHalf {
52    stack_tx: PollSender<AnyIpPktFrame>,
53}
54
55impl Stream for ReadHalf {
56    type Item = UdpMsg;
57
58    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
59        self.udp_rx.poll_recv(cx).map(|item| {
60            item.and_then(|frame| {
61                let packet = match IpPacket::new_checked(frame.as_slice()) {
62                    Ok(p) => p,
63                    Err(err) => {
64                        error!("invalid IP packet: {}", err);
65                        return None;
66                    }
67                };
68
69                let src_ip = packet.src_addr();
70                let dst_ip = packet.dst_addr();
71                let payload = packet.payload();
72
73                let packet = match UdpPacket::new_checked(payload) {
74                    Ok(p) => p,
75                    Err(err) => {
76                        error!("invalid err: {err}, src_ip: {src_ip}, dst_ip: {dst_ip}, payload: {payload:?}");
77                        return None;
78                    }
79                };
80                let src_port = packet.src_port();
81                let dst_port = packet.dst_port();
82
83                let src_addr = SocketAddr::new(src_ip, src_port);
84                let dst_addr = SocketAddr::new(dst_ip, dst_port);
85
86                trace!("created UDP socket for {} <-> {}", src_addr, dst_addr);
87
88                Some((packet.payload().to_vec(), src_addr, dst_addr))
89            })
90        })
91    }
92}
93
94impl Sink<UdpMsg> for WriteHalf {
95    type Error = std::io::Error;
96
97    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
98        match ready!(self.stack_tx.poll_ready_unpin(cx)) {
99            Ok(()) => Poll::Ready(Ok(())),
100            Err(err) => Poll::Ready(Err(std::io::Error::other(err))),
101        }
102    }
103
104    fn start_send(mut self: Pin<&mut Self>, item: UdpMsg) -> Result<(), Self::Error> {
105        use std::io::{Error, ErrorKind::InvalidData};
106        let (data, src_addr, dst_addr) = item;
107
108        if data.is_empty() {
109            return Ok(());
110        }
111
112        let builder = match (src_addr, dst_addr) {
113            (SocketAddr::V4(src), SocketAddr::V4(dst)) => {
114                PacketBuilder::ipv4(src.ip().octets(), dst.ip().octets(), 20)
115                    .udp(src_addr.port(), dst_addr.port())
116            }
117            (SocketAddr::V6(src), SocketAddr::V6(dst)) => {
118                PacketBuilder::ipv6(src.ip().octets(), dst.ip().octets(), 20)
119                    .udp(src_addr.port(), dst_addr.port())
120            }
121            _ => {
122                return Err(Error::new(InvalidData, "src or destination type unmatch"));
123            }
124        };
125
126        let mut ip_packet_writer = Vec::with_capacity(builder.size(data.len()));
127        builder
128            .write(&mut ip_packet_writer, &data)
129            .map_err(|err| Error::other(format!("PacketBuilder::write: {err}")))?;
130
131        match self.stack_tx.start_send_unpin(ip_packet_writer.clone()) {
132            Ok(()) => Ok(()),
133            Err(err) => Err(Error::other(format!("send error: {err}"))),
134        }
135    }
136
137    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
138        use std::io::Error;
139        match ready!(self.stack_tx.poll_flush_unpin(cx)) {
140            Ok(()) => Poll::Ready(Ok(())),
141            Err(err) => Poll::Ready(Err(Error::other(format!("flush error: {err}")))),
142        }
143    }
144
145    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
146        use std::io::Error;
147        match ready!(self.stack_tx.poll_close_unpin(cx)) {
148            Ok(()) => Poll::Ready(Ok(())),
149            Err(err) => Poll::Ready(Err(Error::other(format!("close error: {err}")))),
150        }
151    }
152}