// bevy_slinet/protocols/udp.rs

//! UDP protocol implementation based on [`tokio::net`]. Enable it with the `protocol_udp` feature.
2
3use std::future::Future;
4use std::io;
5use std::io::ErrorKind;
6use std::net::SocketAddr;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use async_trait::async_trait;
12use dashmap::DashMap;
13use futures::task::AtomicWaker;
14use tokio::net::UdpSocket;
15use tokio::sync::Mutex;
16
17use crate::protocol::{
18    ClientStream, Listener, NetworkStream, ReadStream, ServerStream, WriteStream,
19};
20use crate::Protocol;
21
/// Size of the scratch buffer handed to a single `recv`/`recv_from` call.
///
/// The UDP header's length field is 16 bits wide, so no datagram payload can be larger
/// than this; one buffer of this size always holds a whole datagram.
const BUFFER_SIZE: usize = u16::MAX as usize;
23
24/// UDP protocol.
25pub struct UdpProtocol;
26
27#[async_trait]
28impl Protocol for UdpProtocol {
29    type Listener = UdpNetworkListener;
30    type ServerStream = UdpServerStream;
31    type ClientStream = UdpClientStream;
32
33    async fn bind(addr: SocketAddr) -> std::io::Result<Self::Listener> {
34        Ok(UdpNetworkListener {
35            socket: Arc::new(UdpSocket::bind(addr).await?),
36            tasks: DashMap::new(),
37        })
38    }
39}
40
41struct Inner {
42    waker: AtomicWaker,
43    bytes: Mutex<Vec<u8>>,
44}
45
46#[derive(Clone)]
47struct UdpRead(Arc<Inner>);
48
49/// A UDP listener.
50pub struct UdpNetworkListener {
51    socket: Arc<UdpSocket>,
52    tasks: DashMap<SocketAddr, UdpRead>,
53}
54
55#[async_trait]
56impl Listener for UdpNetworkListener {
57    type Stream = UdpServerStream;
58
59    async fn accept(&self) -> std::io::Result<UdpServerStream> {
60        let mut buf = [0; BUFFER_SIZE];
61        loop {
62            let (bytes, address) = self.socket.recv_from(&mut buf).await?;
63            let bytes = &buf[..bytes];
64            if let Some(task) = self.tasks.get(&address) {
65                {
66                    let mut task_bytes = task.0.bytes.lock().await;
67                    task_bytes.extend(bytes);
68                }
69                task.0.waker.wake();
70            } else {
71                let new_task = UdpRead(Arc::new(Inner {
72                    waker: AtomicWaker::new(),
73                    bytes: Mutex::new(Vec::new()),
74                }));
75                self.tasks.insert(address, new_task.clone());
76                return Ok(UdpServerStream {
77                    task: new_task,
78                    peer_addr: address,
79                    socket: Arc::clone(&self.socket),
80                });
81            }
82        }
83    }
84
85    fn address(&self) -> SocketAddr {
86        self.socket.local_addr().unwrap()
87    }
88
89    fn handle_disconnection(&self, peer_addr: SocketAddr) {
90        self.tasks.remove(&peer_addr);
91    }
92}
93
94/// A UDP server stream that contains cached bytes and a task waker.
95pub struct UdpServerStream {
96    task: UdpRead,
97    peer_addr: SocketAddr,
98    socket: Arc<UdpSocket>,
99}
100
101#[async_trait]
102impl NetworkStream for UdpServerStream {
103    type ReadHalf = UdpServerReadHalf;
104    type WriteHalf = UdpServerWriteHalf;
105
106    async fn into_split(self) -> io::Result<(Self::ReadHalf, Self::WriteHalf)> {
107        let peer_addr = self.peer_addr();
108        Ok((
109            UdpServerReadHalf(self.task.clone()),
110            UdpServerWriteHalf {
111                peer_addr,
112                socket: self.socket,
113            },
114        ))
115    }
116
117    fn peer_addr(&self) -> SocketAddr {
118        self.peer_addr
119    }
120
121    fn local_addr(&self) -> SocketAddr {
122        self.socket.local_addr().unwrap()
123    }
124}
125
126/// The read half of [`UdpServerStream`].
127pub struct UdpServerReadHalf(UdpRead);
128
129#[async_trait]
130impl ReadStream for UdpServerReadHalf {
131    fn read_exact<'life0, 'life1, 'async_trait>(
132        &'life0 mut self,
133        buffer: &'life1 mut [u8],
134    ) -> Pin<Box<dyn Future<Output = Result<(), std::io::Error>> + std::marker::Send + 'async_trait>>
135    where
136        'life0: 'async_trait,
137        'life1: 'async_trait,
138        Self: 'async_trait,
139    {
140        Box::pin(UdpReadTask {
141            read: self.0.clone(),
142            buffer,
143        })
144    }
145}
146
147/// A future that tries to read bytes from cache, and receives additional bytes if needed.
148/// [`UdpSocket::recv`] discards bytes that are not needed and there's no way to save them without buffering.
149pub struct UdpReadTask<'a> {
150    read: UdpRead,
151    buffer: &'a mut [u8],
152}
153
154impl Future for UdpReadTask<'_> {
155    type Output = io::Result<()>;
156
157    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158        let UdpReadTask { read, buffer } = &mut *self;
159
160        if let Ok(mut bytes) = read.0.bytes.try_lock() {
161            if bytes.len() >= buffer.len() {
162                buffer.copy_from_slice(&bytes[..buffer.len()]);
163                *bytes = bytes[buffer.len()..].to_vec();
164                Poll::Ready(Ok(()))
165            } else {
166                read.0.waker.register(cx.waker());
167                Poll::Pending
168            }
169        } else {
170            read.0.waker.register(cx.waker());
171            Poll::Pending
172        }
173    }
174}
175
176/// A write half of [`UdpServerStream`];
177pub struct UdpServerWriteHalf {
178    peer_addr: SocketAddr,
179    socket: Arc<UdpSocket>,
180}
181
182#[async_trait]
183impl WriteStream for UdpServerWriteHalf {
184    async fn write_all(&mut self, buffer: &[u8]) -> std::io::Result<()> {
185        self.socket
186            .send_to(buffer, self.peer_addr)
187            .await
188            .and_then(|i| assert_all(i, buffer))
189    }
190}
191
192impl ServerStream for UdpServerStream {}
193
194/// A UDP client stream.
195pub struct UdpClientStream {
196    socket: UdpSocket,
197    peer_addr: SocketAddr,
198}
199
200#[async_trait]
201impl NetworkStream for UdpClientStream {
202    type ReadHalf = UdpClientReadHalf;
203    type WriteHalf = UdpClientWriteHalf;
204
205    async fn into_split(mut self) -> io::Result<(Self::ReadHalf, Self::WriteHalf)> {
206        let std_socket = self.socket.into_std()?;
207        let std_socket2 = std_socket.try_clone()?;
208        let read_socket = UdpSocket::from_std(std_socket)?;
209        let write_socket = UdpSocket::from_std(std_socket2)?;
210        let write = UdpClientWriteHalf {
211            socket: write_socket,
212        };
213        let read = UdpClientReadHalf {
214            socket: read_socket,
215            buffer: Vec::new(),
216        };
217        Ok((read, write))
218    }
219
220    fn peer_addr(&self) -> SocketAddr {
221        self.peer_addr // self.0.peer_addr().unwrap(). Tokio added it in https://github.com/tokio-rs/tokio/pull/4362 and then reverted in https://github.com/tokio-rs/tokio/pull/4392
222    }
223
224    fn local_addr(&self) -> SocketAddr {
225        self.socket.local_addr().unwrap()
226    }
227}
228
229#[async_trait]
230impl ClientStream for UdpClientStream {
231    async fn connect(addr: SocketAddr) -> std::io::Result<Self>
232    where
233        Self: Sized,
234    {
235        let socket = UdpSocket::bind("127.0.0.1:0").await?;
236        socket.connect(addr).await?;
237
238        // TODO remove this
239        let std_socket = socket.into_std().unwrap();
240        let peer_addr = std_socket.peer_addr().unwrap();
241        let socket = UdpSocket::from_std(std_socket).unwrap();
242
243        // socket.connect and socket.send is not enough to handle ConnectionRefused, but 2 sends is
244        socket.send(&[]).await?;
245        socket.send(&[]).await?;
246        Ok(UdpClientStream { socket, peer_addr })
247    }
248}
249
250/// A read half of [`UdpClientStream`].
251pub struct UdpClientReadHalf {
252    socket: UdpSocket,
253    buffer: Vec<u8>,
254}
255
256#[async_trait]
257impl ReadStream for UdpClientReadHalf {
258    async fn read_exact(&mut self, buffer: &mut [u8]) -> std::io::Result<()> {
259        loop {
260            if self.buffer.len() >= buffer.len() {
261                buffer.copy_from_slice(&self.buffer[..buffer.len()]);
262                self.buffer = self.buffer[buffer.len()..].to_vec();
263                return Ok(());
264            }
265            let mut buf = [0; BUFFER_SIZE];
266            let read = self.socket.recv(&mut buf).await?;
267            self.buffer.extend(&buf[..read]);
268        }
269    }
270}
271
272/// A write half of [`UdpClientStream`].
273pub struct UdpClientWriteHalf {
274    socket: UdpSocket,
275}
276
277#[async_trait]
278impl WriteStream for UdpClientWriteHalf {
279    async fn write_all(&mut self, buffer: &[u8]) -> std::io::Result<()> {
280        self.socket
281            .send(buffer)
282            .await
283            .and_then(|i| assert_all(i, buffer))
284    }
285}
286
/// Checks that a send transmitted the whole buffer.
///
/// `i` is the byte count reported by the socket. Returns `Ok(())` when it equals
/// `buf.len()`; otherwise returns an [`ErrorKind::BrokenPipe`] error (it never panics).
fn assert_all(i: usize, buf: &[u8]) -> io::Result<()> {
    if i != buf.len() {
        return Err(ErrorKind::BrokenPipe.into());
    }
    Ok(())
}