rasi_ext/net/
udp_group.rs

1//! Utility to batch poll a set of [`udp sockets`](rasi::net::UdpSocket)
2//!
3use std::{
4    collections::HashMap,
5    io,
6    net::{SocketAddr, ToSocketAddrs},
7    sync::Arc,
8    task::Poll,
9};
10
11use bytes::BytesMut;
12use rand::seq::IteratorRandom;
13use rasi::{
14    executor::spawn,
15    net::UdpSocket,
16    syscall::{global_network, Network},
17};
18
19use futures::{SinkExt, Stream, StreamExt};
20
21use crate::utils::ReadBuf;
22
/// Udp data transfer metadata: the path of a datagram travelling
/// from the address `from` to the address `to`.
///
/// The type is a plain pair of socket addresses, so it can be freely copied,
/// compared and used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PathInfo {
    /// The address the data is sent from.
    pub from: SocketAddr,
    /// The destination address the data is sent to.
    pub to: SocketAddr,
}

impl PathInfo {
    /// Swaps the ***from*** field and the ***to*** field of this object.
    ///
    /// Useful for answering a received datagram: the reversed path of the
    /// request is the path of the reply.
    pub fn reverse(self) -> Self {
        Self {
            from: self.to,
            to: self.from,
        }
    }
}
42
43/// A configuration for batch poll a set of [`udp socket`](rasi::net::UdpSocket)s
44pub struct UdpGroup {
45    /// Inner socket map that mapping local_addr => socket.
46    sockets: HashMap<SocketAddr, Arc<UdpSocket>>,
47    /// The max buf length for batch reading.
48    max_recv_buf_len: u16,
49}
50
51impl UdpGroup {
52    /// Use global registered syscall interface [`Network`] to create a UDP socket group from the given address.
53    ///
54    /// Binding with a port number of 0 will request that the OS assigns a port to this socket. The
55    /// port allocated can be queried via the [`local_addrs`](Self::local_addrs) method.
56    ///
57    /// [`local_addr`]: #method.local_addr
58    ///
59    /// # Examples
60    ///
61    /// ```no_run
62    /// # fn main() -> std::io::Result<()> { futures::executor::block_on(async {
63    /// #
64    /// use rasi_ext::net::udp_group::UdpGroup;
65    ///
66    /// let socket = UdpGroup::bind("127.0.0.1:0").await?;
67    /// #
68    /// # Ok(()) }) }
69    /// ```
70    pub async fn bind<A: ToSocketAddrs>(laddrs: A) -> io::Result<Self> {
71        Self::bind_with(laddrs, global_network()).await
72    }
73
74    /// Use custom syscall interface [`Network`] to create a UDP socket group from the given address.
75    /// [*Read more*](Self::bind)
76    pub async fn bind_with<A: ToSocketAddrs>(
77        laddrs: A,
78        syscall: &'static dyn Network,
79    ) -> io::Result<Self> {
80        let laddrs = laddrs.to_socket_addrs()?.collect::<Vec<_>>();
81
82        let mut sockets = HashMap::new();
83
84        for laddr in laddrs {
85            let socket = Arc::new(UdpSocket::bind_with([laddr].as_slice(), syscall).await?);
86
87            // get the port allocated by OS.
88            let laddr = socket.local_addr()?;
89
90            sockets.insert(laddr, socket);
91        }
92
93        Ok(Self {
94            sockets,
95            max_recv_buf_len: 2048,
96        })
97    }
98
99    /// Sets the capacity of batch reading buf to specific `len`.
100    ///
101    /// ***assert***: the specific `len` must > 0 && < 65535.
102    ///
103    /// # Examples
104    ///
105    /// ```no_run
106    /// # fn main() -> std::io::Result<()> { futures::executor::block_on(async {
107    /// #
108    /// use rasi_ext::net::udp_group::UdpGroup;
109    ///
110    /// let config = UdpGroup::bind("127.0.0.1:0").await?;
111    /// let config = config.with_max_recv_buf_len(1024);
112    /// #
113    /// # Ok(()) }) }
114    /// ```
115    pub fn with_max_recv_buf_len(mut self, len: u16) -> Self {
116        assert!(len > 0, "sets max_recv_buf_len to zero");
117        self.max_recv_buf_len = len;
118
119        self
120    }
121
122    /// Helper method for splitting `UdpGroup` object into two halves.
123    ///
124    /// The two halves returned implement the [Sink](futures::Sink) and [Stream] traits, respectively.
125    ///
126    /// # Examples
127    ///
128    /// ```no_run
129    /// # fn main() -> std::io::Result<()> { futures::executor::block_on(async {
130    /// #
131    /// use rasi_ext::net::udp_group::UdpGroup;
132    ///
133    /// let (sx,rx) = UdpGroup::bind("127.0.0.1:0").await?.split();
134    /// #
135    /// # Ok(()) }) }
136    pub fn split(self) -> (Sender, Receiver) {
137        let sockets = self.sockets.values().cloned().collect::<Vec<_>>();
138
139        let (sender, receiver) = futures::channel::mpsc::channel(0);
140
141        for socket in sockets {
142            spawn(Self::recv_loop(
143                socket,
144                sender.clone(),
145                self.max_recv_buf_len as usize,
146            ));
147        }
148
149        (Sender::new(self.sockets), Receiver::new(receiver))
150    }
151
152    /// Returns the local addresses iterator that this udp group are bound to.
153    ///
154    /// This can be useful, for example, to identify when binding to port 0 which port was assigned
155    /// by the OS.
156    ///
157    /// # Examples
158    ///
159    /// ```no_run
160    /// # fn main() -> std::io::Result<()> { futures::executor::block_on(async {
161    /// #
162    /// use rasi_ext::net::udp_group::UdpGroup;
163    ///
164    /// let udp_group = UdpGroup::bind("127.0.0.1:0").await?;
165    /// let laddrs = udp_group.local_addrs().collect::<Vec<_>>();
166    /// #
167    /// # Ok(()) }) }
168    pub fn local_addrs(&self) -> impl Iterator<Item = &SocketAddr> {
169        self.sockets.keys()
170    }
171
172    async fn recv_loop(
173        socket: Arc<UdpSocket>,
174        mut sender: futures::channel::mpsc::Sender<UdpGroupData>,
175        max_recv_buf_len: usize,
176    ) {
177        let laddr = socket.local_addr().unwrap();
178
179        loop {
180            let mut read_buf = ReadBuf::with_capacity(max_recv_buf_len);
181
182            match socket.recv_from(read_buf.chunk_mut()).await {
183                Ok((read_size, raddr)) => {
184                    log::trace!(
185                        "UdpGroup recv_from, raddr={:?}, read_size={}",
186                        raddr,
187                        read_size
188                    );
189
190                    let data = UdpGroupData {
191                        result: Ok((read_buf.into_bytes_mut(Some(read_size)), raddr)),
192                        to: laddr,
193                    };
194
195                    if sender.send(data).await.is_err() {
196                        log::trace!("socket({:?}) in udp group, stop recv loop", laddr);
197                    }
198                }
199                Err(err) => {
200                    log::error!(
201                        "socket({:?}) in udp group, shutdown with error: {}",
202                        laddr,
203                        err
204                    );
205                }
206            }
207        }
208    }
209}
210
211struct UdpGroupData {
212    result: io::Result<(BytesMut, SocketAddr)>,
213    /// the receiver socket's local address.
214    to: SocketAddr,
215}
216
217/// Data is received from the peers via this [`UdpGroup`] receiver [`stream`](Stream).
218pub struct Receiver {
219    inner: futures::channel::mpsc::Receiver<UdpGroupData>,
220}
221
222impl Receiver {
223    fn new(inner: futures::channel::mpsc::Receiver<UdpGroupData>) -> Self {
224        Self { inner }
225    }
226}
227
228impl Stream for Receiver {
229    type Item = io::Result<(BytesMut, PathInfo)>;
230
231    fn poll_next(
232        mut self: std::pin::Pin<&mut Self>,
233        cx: &mut std::task::Context<'_>,
234    ) -> std::task::Poll<Option<Self::Item>> {
235        match self.inner.poll_next_unpin(cx) {
236            Poll::Ready(None) => Poll::Ready(None),
237            Poll::Ready(Some(udp_group_data)) => {
238                Poll::Ready(Some(udp_group_data.result.map(|(buf, raddr)| {
239                    (
240                        buf,
241                        PathInfo {
242                            from: raddr,
243                            to: udp_group_data.to,
244                        },
245                    )
246                })))
247            }
248            Poll::Pending => Poll::Pending,
249        }
250    }
251}
252
253/// Data is sent to the peers via this [`UdpGroup`] sender
254pub struct Sender {
255    sockets: Arc<HashMap<SocketAddr, Arc<UdpSocket>>>,
256}
257
258impl Clone for Sender {
259    fn clone(&self) -> Self {
260        Self {
261            sockets: self.sockets.clone(),
262        }
263    }
264}
265
266impl Sender {
267    fn new(sockets: HashMap<SocketAddr, Arc<UdpSocket>>) -> Self {
268        Self {
269            sockets: Arc::new(sockets),
270        }
271    }
272
273    /// Sends data on the random socket in this group to the given address.
274    ///
275    /// On success, returns the number of bytes written.
276    pub async fn send_to(&self, buf: &[u8], raddr: SocketAddr) -> io::Result<usize> {
277        let socket = self
278            .sockets
279            .values()
280            .choose(&mut rand::thread_rng())
281            .unwrap()
282            .clone();
283
284        socket.send_to(buf, raddr).await
285    }
286
287    /// Sends data on the [`PathInfo`] to the given address.
288    ///
289    /// On success, returns the number of bytes written.
290    pub async fn send_to_on_path(&self, buf: &[u8], path_info: PathInfo) -> io::Result<usize> {
291        if let Some(socket) = self.sockets.get(&path_info.from) {
292            socket.send_to(buf, path_info.to).await
293        } else {
294            Err(io::Error::new(
295                io::ErrorKind::NotFound,
296                format!("Socket bound to {:?} is not in the group.", path_info.from),
297            ))
298        }
299    }
300}
301
#[cfg(test)]
mod tests {

    use std::sync::OnceLock;

    use bytes::Bytes;
    use futures::TryStreamExt;
    use rasi_default::{executor::register_futures_executor_with_pool_size, net::MioNetwork};

    use super::*;

    /// Payload echoed between the client group and the server group.
    const PAYLOAD: &[u8] = b"hello world";

    /// Network syscall instance shared by the tests of this module.
    static INIT: OnceLock<Box<dyn rasi::syscall::Network>> = OnceLock::new();

    /// Returns the shared network syscall; the first call also registers the
    /// futures executor with a pool size of 10.
    fn get_syscall() -> &'static dyn rasi::syscall::Network {
        let network = INIT.get_or_init(|| {
            register_futures_executor_with_pool_size(10).unwrap();
            Box::new(MioNetwork::default())
        });

        network.as_ref()
    }

    /// A client group sends a datagram to a random server socket; the server answers on
    /// the reversed path and the client must receive the same payload back.
    #[futures_test::test]
    async fn test_udp_group_echo() {
        let syscall = get_syscall();

        // Four sockets per group, each on an OS-assigned loopback port.
        let any_loopback: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let bind_addrs = vec![any_loopback; 4];

        let client = UdpGroup::bind_with(bind_addrs.as_slice(), syscall)
            .await
            .unwrap();
        let (client_sender, mut client_receiver) = client.split();

        let server = UdpGroup::bind_with(bind_addrs.as_slice(), syscall)
            .await
            .unwrap();
        let server_addrs: Vec<SocketAddr> = server.local_addrs().copied().collect();
        let (server_sender, mut server_receiver) = server.split();

        let target = server_addrs
            .iter()
            .choose(&mut rand::thread_rng())
            .copied()
            .unwrap();

        client_sender.send_to(PAYLOAD, target).await.unwrap();

        let (request, request_path) = server_receiver.try_next().await.unwrap().unwrap();

        assert_eq!(request.freeze(), Bytes::from_static(PAYLOAD));

        server_sender
            .send_to_on_path(PAYLOAD, request_path.reverse())
            .await
            .unwrap();

        let (reply, _) = client_receiver.try_next().await.unwrap().unwrap();

        assert_eq!(reply.freeze(), Bytes::from_static(PAYLOAD));
    }
}