1use std::{
4 collections::HashMap,
5 io,
6 net::{SocketAddr, ToSocketAddrs},
7 sync::Arc,
8 task::Poll,
9};
10
11use bytes::BytesMut;
12use rand::seq::IteratorRandom;
13use rasi::{
14 executor::spawn,
15 net::UdpSocket,
16 syscall::{global_network, Network},
17};
18
19use futures::{SinkExt, Stream, StreamExt};
20
21use crate::utils::ReadBuf;
22
/// The pair of addresses a datagram travels between.
///
/// For a received datagram `from` is the remote peer and `to` is the local
/// socket that received it; [`PathInfo::reverse`] produces the path for a reply.
///
/// The type is a small plain value (two [`SocketAddr`]s), so it is `Copy` and
/// can be compared and used as a hash-map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PathInfo {
    /// Address the datagram is sent from.
    pub from: SocketAddr,
    /// Address the datagram is sent to.
    pub to: SocketAddr,
}

impl PathInfo {
    /// Returns the opposite direction of this path: `from` and `to` swapped.
    pub fn reverse(self) -> Self {
        let Self { from, to } = self;

        Self { from: to, to: from }
    }
}
42
43pub struct UdpGroup {
45 sockets: HashMap<SocketAddr, Arc<UdpSocket>>,
47 max_recv_buf_len: u16,
49}
50
51impl UdpGroup {
52 pub async fn bind<A: ToSocketAddrs>(laddrs: A) -> io::Result<Self> {
71 Self::bind_with(laddrs, global_network()).await
72 }
73
74 pub async fn bind_with<A: ToSocketAddrs>(
77 laddrs: A,
78 syscall: &'static dyn Network,
79 ) -> io::Result<Self> {
80 let laddrs = laddrs.to_socket_addrs()?.collect::<Vec<_>>();
81
82 let mut sockets = HashMap::new();
83
84 for laddr in laddrs {
85 let socket = Arc::new(UdpSocket::bind_with([laddr].as_slice(), syscall).await?);
86
87 let laddr = socket.local_addr()?;
89
90 sockets.insert(laddr, socket);
91 }
92
93 Ok(Self {
94 sockets,
95 max_recv_buf_len: 2048,
96 })
97 }
98
99 pub fn with_max_recv_buf_len(mut self, len: u16) -> Self {
116 assert!(len > 0, "sets max_recv_buf_len to zero");
117 self.max_recv_buf_len = len;
118
119 self
120 }
121
122 pub fn split(self) -> (Sender, Receiver) {
137 let sockets = self.sockets.values().cloned().collect::<Vec<_>>();
138
139 let (sender, receiver) = futures::channel::mpsc::channel(0);
140
141 for socket in sockets {
142 spawn(Self::recv_loop(
143 socket,
144 sender.clone(),
145 self.max_recv_buf_len as usize,
146 ));
147 }
148
149 (Sender::new(self.sockets), Receiver::new(receiver))
150 }
151
152 pub fn local_addrs(&self) -> impl Iterator<Item = &SocketAddr> {
169 self.sockets.keys()
170 }
171
172 async fn recv_loop(
173 socket: Arc<UdpSocket>,
174 mut sender: futures::channel::mpsc::Sender<UdpGroupData>,
175 max_recv_buf_len: usize,
176 ) {
177 let laddr = socket.local_addr().unwrap();
178
179 loop {
180 let mut read_buf = ReadBuf::with_capacity(max_recv_buf_len);
181
182 match socket.recv_from(read_buf.chunk_mut()).await {
183 Ok((read_size, raddr)) => {
184 log::trace!(
185 "UdpGroup recv_from, raddr={:?}, read_size={}",
186 raddr,
187 read_size
188 );
189
190 let data = UdpGroupData {
191 result: Ok((read_buf.into_bytes_mut(Some(read_size)), raddr)),
192 to: laddr,
193 };
194
195 if sender.send(data).await.is_err() {
196 log::trace!("socket({:?}) in udp group, stop recv loop", laddr);
197 }
198 }
199 Err(err) => {
200 log::error!(
201 "socket({:?}) in udp group, shutdown with error: {}",
202 laddr,
203 err
204 );
205 }
206 }
207 }
208 }
209}
210
211struct UdpGroupData {
212 result: io::Result<(BytesMut, SocketAddr)>,
213 to: SocketAddr,
215}
216
217pub struct Receiver {
219 inner: futures::channel::mpsc::Receiver<UdpGroupData>,
220}
221
222impl Receiver {
223 fn new(inner: futures::channel::mpsc::Receiver<UdpGroupData>) -> Self {
224 Self { inner }
225 }
226}
227
228impl Stream for Receiver {
229 type Item = io::Result<(BytesMut, PathInfo)>;
230
231 fn poll_next(
232 mut self: std::pin::Pin<&mut Self>,
233 cx: &mut std::task::Context<'_>,
234 ) -> std::task::Poll<Option<Self::Item>> {
235 match self.inner.poll_next_unpin(cx) {
236 Poll::Ready(None) => Poll::Ready(None),
237 Poll::Ready(Some(udp_group_data)) => {
238 Poll::Ready(Some(udp_group_data.result.map(|(buf, raddr)| {
239 (
240 buf,
241 PathInfo {
242 from: raddr,
243 to: udp_group_data.to,
244 },
245 )
246 })))
247 }
248 Poll::Pending => Poll::Pending,
249 }
250 }
251}
252
253pub struct Sender {
255 sockets: Arc<HashMap<SocketAddr, Arc<UdpSocket>>>,
256}
257
258impl Clone for Sender {
259 fn clone(&self) -> Self {
260 Self {
261 sockets: self.sockets.clone(),
262 }
263 }
264}
265
266impl Sender {
267 fn new(sockets: HashMap<SocketAddr, Arc<UdpSocket>>) -> Self {
268 Self {
269 sockets: Arc::new(sockets),
270 }
271 }
272
273 pub async fn send_to(&self, buf: &[u8], raddr: SocketAddr) -> io::Result<usize> {
277 let socket = self
278 .sockets
279 .values()
280 .choose(&mut rand::thread_rng())
281 .unwrap()
282 .clone();
283
284 socket.send_to(buf, raddr).await
285 }
286
287 pub async fn send_to_on_path(&self, buf: &[u8], path_info: PathInfo) -> io::Result<usize> {
291 if let Some(socket) = self.sockets.get(&path_info.from) {
292 socket.send_to(buf, path_info.to).await
293 } else {
294 Err(io::Error::new(
295 io::ErrorKind::NotFound,
296 format!("Socket bound to {:?} is not in the group.", path_info.from),
297 ))
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    use std::sync::OnceLock;

    use bytes::Bytes;
    use futures::TryStreamExt;
    use rasi_default::{executor::register_futures_executor_with_pool_size, net::MioNetwork};

    use super::*;

    /// Network implementation shared by all tests; initialised on first use.
    static INIT: OnceLock<Box<dyn rasi::syscall::Network>> = OnceLock::new();

    /// Returns the shared network, registering the executor the first time.
    fn get_syscall() -> &'static dyn rasi::syscall::Network {
        let network = INIT.get_or_init(|| {
            register_futures_executor_with_pool_size(10).unwrap();
            Box::new(MioNetwork::default())
        });

        network.as_ref()
    }

    /// A client group sends a datagram to a random server socket and the
    /// server echoes it back along the reversed path.
    #[futures_test::test]
    async fn test_udp_group_echo() {
        const PAYLOAD: &[u8] = b"hello world";

        let syscall = get_syscall();

        // Four ephemeral-port sockets per group.
        let addrs = vec!["127.0.0.1:0".parse::<SocketAddr>().unwrap(); 4];

        let client = UdpGroup::bind_with(addrs.as_slice(), syscall)
            .await
            .unwrap();
        let (client_sender, mut client_receiver) = client.split();

        let server = UdpGroup::bind_with(addrs.as_slice(), syscall)
            .await
            .unwrap();

        // Pick the target before `split` consumes the group.
        let random_raddr = server
            .local_addrs()
            .copied()
            .choose(&mut rand::thread_rng())
            .unwrap();

        let (server_sender, mut server_receiver) = server.split();

        // client -> server
        client_sender.send_to(PAYLOAD, random_raddr).await.unwrap();

        let (request, path_info) = server_receiver.try_next().await.unwrap().unwrap();

        assert_eq!(request.freeze(), Bytes::from_static(PAYLOAD));

        // server -> client, back along the path the request arrived on
        server_sender
            .send_to_on_path(PAYLOAD, path_info.reverse())
            .await
            .unwrap();

        let (response, _) = client_receiver.try_next().await.unwrap().unwrap();

        assert_eq!(response.freeze(), Bytes::from_static(PAYLOAD));
    }
}