1use std::{
2 net::SocketAddr,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use etherparse::PacketBuilder;
8use futures::{ready, Sink, SinkExt, Stream};
9use smoltcp::wire::UdpPacket;
10use tokio::sync::mpsc::{Receiver, Sender};
11use tokio_util::sync::PollSender;
12use tracing::{error, trace};
13
14use crate::packet::{AnyIpPktFrame, IpPacket};
15
/// A single UDP datagram crossing the stack boundary.
///
/// Elements, in order:
/// 0. the UDP payload bytes,
/// 1. the source socket address,
/// 2. the destination socket address.
pub type UdpMsg = (Vec<u8>, SocketAddr, SocketAddr);
21
22pub struct UdpSocket {
23 udp_rx: Receiver<AnyIpPktFrame>,
24 stack_tx: PollSender<AnyIpPktFrame>,
25}
26
27impl UdpSocket {
28 pub(super) fn new(udp_rx: Receiver<AnyIpPktFrame>, stack_tx: Sender<AnyIpPktFrame>) -> Self {
29 Self {
30 udp_rx,
31 stack_tx: PollSender::new(stack_tx),
32 }
33 }
34
35 pub fn split(self) -> (ReadHalf, WriteHalf) {
36 (
37 ReadHalf {
38 udp_rx: self.udp_rx,
39 },
40 WriteHalf {
41 stack_tx: self.stack_tx,
42 },
43 )
44 }
45}
46
47pub struct ReadHalf {
48 udp_rx: Receiver<AnyIpPktFrame>,
49}
50
51pub struct WriteHalf {
52 stack_tx: PollSender<AnyIpPktFrame>,
53}
54
55impl Stream for ReadHalf {
56 type Item = UdpMsg;
57
58 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
59 self.udp_rx.poll_recv(cx).map(|item| {
60 item.and_then(|frame| {
61 let packet = match IpPacket::new_checked(frame.as_slice()) {
62 Ok(p) => p,
63 Err(err) => {
64 error!("invalid IP packet: {}", err);
65 return None;
66 }
67 };
68
69 let src_ip = packet.src_addr();
70 let dst_ip = packet.dst_addr();
71 let payload = packet.payload();
72
73 let packet = match UdpPacket::new_checked(payload) {
74 Ok(p) => p,
75 Err(err) => {
76 error!("invalid err: {err}, src_ip: {src_ip}, dst_ip: {dst_ip}, payload: {payload:?}");
77 return None;
78 }
79 };
80 let src_port = packet.src_port();
81 let dst_port = packet.dst_port();
82
83 let src_addr = SocketAddr::new(src_ip, src_port);
84 let dst_addr = SocketAddr::new(dst_ip, dst_port);
85
86 trace!("created UDP socket for {} <-> {}", src_addr, dst_addr);
87
88 Some((packet.payload().to_vec(), src_addr, dst_addr))
89 })
90 })
91 }
92}
93
94impl Sink<UdpMsg> for WriteHalf {
95 type Error = std::io::Error;
96
97 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
98 match ready!(self.stack_tx.poll_ready_unpin(cx)) {
99 Ok(()) => Poll::Ready(Ok(())),
100 Err(err) => Poll::Ready(Err(std::io::Error::other(err))),
101 }
102 }
103
104 fn start_send(mut self: Pin<&mut Self>, item: UdpMsg) -> Result<(), Self::Error> {
105 use std::io::{Error, ErrorKind::InvalidData};
106 let (data, src_addr, dst_addr) = item;
107
108 if data.is_empty() {
109 return Ok(());
110 }
111
112 let builder = match (src_addr, dst_addr) {
113 (SocketAddr::V4(src), SocketAddr::V4(dst)) => {
114 PacketBuilder::ipv4(src.ip().octets(), dst.ip().octets(), 20)
115 .udp(src_addr.port(), dst_addr.port())
116 }
117 (SocketAddr::V6(src), SocketAddr::V6(dst)) => {
118 PacketBuilder::ipv6(src.ip().octets(), dst.ip().octets(), 20)
119 .udp(src_addr.port(), dst_addr.port())
120 }
121 _ => {
122 return Err(Error::new(InvalidData, "src or destination type unmatch"));
123 }
124 };
125
126 let mut ip_packet_writer = Vec::with_capacity(builder.size(data.len()));
127 builder
128 .write(&mut ip_packet_writer, &data)
129 .map_err(|err| Error::other(format!("PacketBuilder::write: {err}")))?;
130
131 match self.stack_tx.start_send_unpin(ip_packet_writer.clone()) {
132 Ok(()) => Ok(()),
133 Err(err) => Err(Error::other(format!("send error: {err}"))),
134 }
135 }
136
137 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
138 use std::io::Error;
139 match ready!(self.stack_tx.poll_flush_unpin(cx)) {
140 Ok(()) => Poll::Ready(Ok(())),
141 Err(err) => Poll::Ready(Err(Error::other(format!("flush error: {err}")))),
142 }
143 }
144
145 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
146 use std::io::Error;
147 match ready!(self.stack_tx.poll_close_unpin(cx)) {
148 Ok(()) => Poll::Ready(Ok(())),
149 Err(err) => Poll::Ready(Err(Error::other(format!("close error: {err}")))),
150 }
151 }
152}