bevy_slinet/protocols/
udp.rs1use std::future::Future;
4use std::io;
5use std::io::ErrorKind;
6use std::net::SocketAddr;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use async_trait::async_trait;
12use dashmap::DashMap;
13use futures::task::AtomicWaker;
14use tokio::net::UdpSocket;
15use tokio::sync::Mutex;
16
17use crate::protocol::{
18 ClientStream, Listener, NetworkStream, ReadStream, ServerStream, WriteStream,
19};
20use crate::Protocol;
21
22const BUFFER_SIZE: usize = u16::MAX as usize;
23
24pub struct UdpProtocol;
26
27#[async_trait]
28impl Protocol for UdpProtocol {
29 type Listener = UdpNetworkListener;
30 type ServerStream = UdpServerStream;
31 type ClientStream = UdpClientStream;
32
33 async fn bind(addr: SocketAddr) -> std::io::Result<Self::Listener> {
34 Ok(UdpNetworkListener {
35 socket: Arc::new(UdpSocket::bind(addr).await?),
36 tasks: DashMap::new(),
37 })
38 }
39}
40
41struct Inner {
42 waker: AtomicWaker,
43 bytes: Mutex<Vec<u8>>,
44}
45
46#[derive(Clone)]
47struct UdpRead(Arc<Inner>);
48
49pub struct UdpNetworkListener {
51 socket: Arc<UdpSocket>,
52 tasks: DashMap<SocketAddr, UdpRead>,
53}
54
55#[async_trait]
56impl Listener for UdpNetworkListener {
57 type Stream = UdpServerStream;
58
59 async fn accept(&self) -> std::io::Result<UdpServerStream> {
60 let mut buf = [0; BUFFER_SIZE];
61 loop {
62 let (bytes, address) = self.socket.recv_from(&mut buf).await?;
63 let bytes = &buf[..bytes];
64 if let Some(task) = self.tasks.get(&address) {
65 {
66 let mut task_bytes = task.0.bytes.lock().await;
67 task_bytes.extend(bytes);
68 }
69 task.0.waker.wake();
70 } else {
71 let new_task = UdpRead(Arc::new(Inner {
72 waker: AtomicWaker::new(),
73 bytes: Mutex::new(Vec::new()),
74 }));
75 self.tasks.insert(address, new_task.clone());
76 return Ok(UdpServerStream {
77 task: new_task,
78 peer_addr: address,
79 socket: Arc::clone(&self.socket),
80 });
81 }
82 }
83 }
84
85 fn address(&self) -> SocketAddr {
86 self.socket.local_addr().unwrap()
87 }
88
89 fn handle_disconnection(&self, peer_addr: SocketAddr) {
90 self.tasks.remove(&peer_addr);
91 }
92}
93
94pub struct UdpServerStream {
96 task: UdpRead,
97 peer_addr: SocketAddr,
98 socket: Arc<UdpSocket>,
99}
100
101#[async_trait]
102impl NetworkStream for UdpServerStream {
103 type ReadHalf = UdpServerReadHalf;
104 type WriteHalf = UdpServerWriteHalf;
105
106 async fn into_split(self) -> io::Result<(Self::ReadHalf, Self::WriteHalf)> {
107 let peer_addr = self.peer_addr();
108 Ok((
109 UdpServerReadHalf(self.task.clone()),
110 UdpServerWriteHalf {
111 peer_addr,
112 socket: self.socket,
113 },
114 ))
115 }
116
117 fn peer_addr(&self) -> SocketAddr {
118 self.peer_addr
119 }
120
121 fn local_addr(&self) -> SocketAddr {
122 self.socket.local_addr().unwrap()
123 }
124}
125
126pub struct UdpServerReadHalf(UdpRead);
128
129#[async_trait]
130impl ReadStream for UdpServerReadHalf {
131 fn read_exact<'life0, 'life1, 'async_trait>(
132 &'life0 mut self,
133 buffer: &'life1 mut [u8],
134 ) -> Pin<Box<dyn Future<Output = Result<(), std::io::Error>> + std::marker::Send + 'async_trait>>
135 where
136 'life0: 'async_trait,
137 'life1: 'async_trait,
138 Self: 'async_trait,
139 {
140 Box::pin(UdpReadTask {
141 read: self.0.clone(),
142 buffer,
143 })
144 }
145}
146
147pub struct UdpReadTask<'a> {
150 read: UdpRead,
151 buffer: &'a mut [u8],
152}
153
154impl Future for UdpReadTask<'_> {
155 type Output = io::Result<()>;
156
157 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158 let UdpReadTask { read, buffer } = &mut *self;
159
160 if let Ok(mut bytes) = read.0.bytes.try_lock() {
161 if bytes.len() >= buffer.len() {
162 buffer.copy_from_slice(&bytes[..buffer.len()]);
163 *bytes = bytes[buffer.len()..].to_vec();
164 Poll::Ready(Ok(()))
165 } else {
166 read.0.waker.register(cx.waker());
167 Poll::Pending
168 }
169 } else {
170 read.0.waker.register(cx.waker());
171 Poll::Pending
172 }
173 }
174}
175
176pub struct UdpServerWriteHalf {
178 peer_addr: SocketAddr,
179 socket: Arc<UdpSocket>,
180}
181
182#[async_trait]
183impl WriteStream for UdpServerWriteHalf {
184 async fn write_all(&mut self, buffer: &[u8]) -> std::io::Result<()> {
185 self.socket
186 .send_to(buffer, self.peer_addr)
187 .await
188 .and_then(|i| assert_all(i, buffer))
189 }
190}
191
192impl ServerStream for UdpServerStream {}
193
194pub struct UdpClientStream {
196 socket: UdpSocket,
197 peer_addr: SocketAddr,
198}
199
200#[async_trait]
201impl NetworkStream for UdpClientStream {
202 type ReadHalf = UdpClientReadHalf;
203 type WriteHalf = UdpClientWriteHalf;
204
205 async fn into_split(mut self) -> io::Result<(Self::ReadHalf, Self::WriteHalf)> {
206 let std_socket = self.socket.into_std()?;
207 let std_socket2 = std_socket.try_clone()?;
208 let read_socket = UdpSocket::from_std(std_socket)?;
209 let write_socket = UdpSocket::from_std(std_socket2)?;
210 let write = UdpClientWriteHalf {
211 socket: write_socket,
212 };
213 let read = UdpClientReadHalf {
214 socket: read_socket,
215 buffer: Vec::new(),
216 };
217 Ok((read, write))
218 }
219
220 fn peer_addr(&self) -> SocketAddr {
221 self.peer_addr }
223
224 fn local_addr(&self) -> SocketAddr {
225 self.socket.local_addr().unwrap()
226 }
227}
228
229#[async_trait]
230impl ClientStream for UdpClientStream {
231 async fn connect(addr: SocketAddr) -> std::io::Result<Self>
232 where
233 Self: Sized,
234 {
235 let socket = UdpSocket::bind("127.0.0.1:0").await?;
236 socket.connect(addr).await?;
237
238 let std_socket = socket.into_std().unwrap();
240 let peer_addr = std_socket.peer_addr().unwrap();
241 let socket = UdpSocket::from_std(std_socket).unwrap();
242
243 socket.send(&[]).await?;
245 socket.send(&[]).await?;
246 Ok(UdpClientStream { socket, peer_addr })
247 }
248}
249
250pub struct UdpClientReadHalf {
252 socket: UdpSocket,
253 buffer: Vec<u8>,
254}
255
256#[async_trait]
257impl ReadStream for UdpClientReadHalf {
258 async fn read_exact(&mut self, buffer: &mut [u8]) -> std::io::Result<()> {
259 loop {
260 if self.buffer.len() >= buffer.len() {
261 buffer.copy_from_slice(&self.buffer[..buffer.len()]);
262 self.buffer = self.buffer[buffer.len()..].to_vec();
263 return Ok(());
264 }
265 let mut buf = [0; BUFFER_SIZE];
266 let read = self.socket.recv(&mut buf).await?;
267 self.buffer.extend(&buf[..read]);
268 }
269 }
270}
271
272pub struct UdpClientWriteHalf {
274 socket: UdpSocket,
275}
276
277#[async_trait]
278impl WriteStream for UdpClientWriteHalf {
279 async fn write_all(&mut self, buffer: &[u8]) -> std::io::Result<()> {
280 self.socket
281 .send(buffer)
282 .await
283 .and_then(|i| assert_all(i, buffer))
284 }
285}
286
/// Checks that a send call transmitted the whole buffer.
///
/// `i` is the byte count reported by the send and `buf` is what was meant to
/// be sent. A short (or over-long) count yields an `ErrorKind::BrokenPipe` error.
fn assert_all(i: usize, buf: &[u8]) -> io::Result<()> {
    if i != buf.len() {
        return Err(ErrorKind::BrokenPipe.into());
    }
    Ok(())
}