hala_udp/
group.rs

1use std::{
2    collections::HashMap,
3    fmt::Debug,
4    io,
5    net::{SocketAddr, ToSocketAddrs},
6    ptr::null_mut,
7    sync::{
8        atomic::{AtomicPtr, Ordering},
9        Arc,
10    },
11    task::Poll,
12};
13
14use hala_future::batching::FutureBatcher;
15use hala_io::{
16    context::{io_context, RawIoContext},
17    would_block, Cmd, Description, Driver, Handle, Interest, OpenFlags,
18};
19
20/// The return type of batching read.
21struct BatchRead {
22    /// ready udp socket handle.
23    handle: Handle,
24    /// the result of reading poll on `handle`
25    result: io::Result<(usize, PathInfo)>,
26}
27
28/// The return type of batching write.
29struct BatchWrite {
30    /// ready udp socket handle.
31    handle: Handle,
32    /// the result of writing poll on `handle`
33    result: io::Result<(usize, PathInfo)>,
34}
35
/// The path information (source and destination endpoints) of a transferred udp packet.
///
/// `Hash` is derived so a path can be used directly as a map/set key.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct PathInfo {
    /// The endpoint the packet was sent from.
    pub from: SocketAddr,
    /// The endpoint the packet was sent to.
    pub to: SocketAddr,
}
44
45impl Debug for PathInfo {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        write!(f, "path_info, from={:?}, to={:?}", self.from, self.to)
48    }
49}
50
51/// An utility socket type which handle a group udp sockets's read/write ops.
52pub struct UdpGroup {
53    /// hala io driver instance.
54    driver: Driver,
55    /// fds of this udp group.
56    fds: HashMap<SocketAddr, Handle>,
57    /// mapping handle to address.
58    laddrs: HashMap<Handle, SocketAddr>,
59    /// The batch reader for udp socket handles.
60    batching_reader: FutureBatcher<BatchRead>,
61    /// The batch writer for udp socket handles.
62    batching_writer: FutureBatcher<BatchWrite>,
63    /// The out buf for batch read.
64    batching_read_buf: Arc<AtomicPtr<*mut [u8]>>,
65    /// The in buf for batch write.
66    batching_write_buf: Arc<AtomicPtr<(*const [u8], SocketAddr)>>,
67}
68
69impl Drop for UdpGroup {
70    fn drop(&mut self) {
71        for fd in self.fds.iter().map(|(_, fd)| *fd) {
72            self.driver.fd_close(fd).unwrap()
73        }
74    }
75}
76
77impl UdpGroup {
78    /// Bind udp group on providing addresses group.
79    pub fn bind<S: ToSocketAddrs>(laddrs: S) -> io::Result<Self> {
80        let io_context = io_context();
81
82        let mut fds = HashMap::new();
83        let mut addrs = HashMap::new();
84
85        for addr in laddrs.to_socket_addrs()? {
86            let fd = io_context
87                .driver()
88                .fd_open(Description::UdpSocket, OpenFlags::Bind(&[addr]))?;
89
90            match io_context.driver().fd_cntl(
91                io_context.poller(),
92                Cmd::Register {
93                    source: fd,
94                    interests: Interest::Readable | Interest::Writable,
95                },
96            ) {
97                Err(err) => {
98                    _ = io_context.driver().fd_close(fd);
99                    return Err(err);
100                }
101                _ => {}
102            }
103
104            let laddr = io_context
105                .driver()
106                .fd_cntl(fd, Cmd::LocalAddr)?
107                .try_into_sockaddr()?;
108
109            fds.insert(laddr, fd);
110            addrs.insert(fd, laddr);
111        }
112
113        let group = UdpGroup {
114            driver: io_context.driver().clone(),
115            fds,
116            laddrs: addrs,
117            batching_read_buf: Default::default(),
118            batching_reader: Default::default(),
119            batching_write_buf: Default::default(),
120            batching_writer: Default::default(),
121        };
122
123        group.init_push_batch_ops();
124
125        Ok(group)
126    }
127
128    fn init_push_batch_ops(&self) {
129        for fd in self.fds.iter().map(|(_, fd)| *fd) {
130            self.push_batch_read(fd);
131            self.push_batch_write(fd);
132        }
133    }
134
135    /// mapping laddr to udp socket handle. returns `None` if the mapping is not found.
136    fn laddr_to_handle(&self, laddr: SocketAddr) -> Option<Handle> {
137        self.fds.get(&laddr).map(|fd| *fd)
138    }
139
140    /// mapping udp socket handle to laddr. returns `None` if the mapping is not found.
141    fn handle_to_laddr(&self, handle: Handle) -> Option<SocketAddr> {
142        self.laddrs.get(&handle).map(|fd| *fd)
143    }
144
145    /// Create new batch op for udp reading.
146    fn push_batch_read(&self, handle: Handle) {
147        let driver = self.driver.clone();
148
149        let batching_read_buf = self.batching_read_buf.clone();
150
151        let laddr = self
152            .handle_to_laddr(handle)
153            .expect("The mapping handle -> address not found.");
154
155        self.batching_reader.push_fn(move |cx| {
156            let buf = batching_read_buf.load(Ordering::Acquire);
157
158            assert!(
159                buf != null_mut(),
160                "set batching_read_buf before calling batching_reader await."
161            );
162
163            batching_read_buf
164                .compare_exchange(buf, null_mut(), Ordering::AcqRel, Ordering::Relaxed)
165                .expect("Only one poll read ops should be executing at a time");
166
167            let buf_ref = unsafe { &mut **buf };
168
169            let cmd_resp = driver.fd_cntl(
170                handle,
171                Cmd::RecvFrom {
172                    waker: cx.waker().clone(),
173                    buf: buf_ref,
174                },
175            );
176
177            let cmd_resp = match cmd_resp {
178                Ok(cmd_resp) => {
179                    _ = unsafe { Box::from_raw(buf) };
180                    cmd_resp
181                }
182                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
183                    batching_read_buf
184                        .compare_exchange(null_mut(), buf, Ordering::AcqRel, Ordering::Relaxed)
185                        .expect("Only one poll read ops should be executing at a time");
186
187                    return Poll::Pending;
188                }
189                Err(err) => {
190                    _ = unsafe { Box::from_raw(buf) };
191                    return Poll::Ready(BatchRead {
192                        handle,
193                        result: Err(err),
194                    });
195                }
196            };
197
198            let (read_size, raddr) = cmd_resp.try_into_recv_from().unwrap();
199
200            log::trace!("batch_read ready");
201
202            return Poll::Ready(BatchRead {
203                handle,
204                result: Ok((
205                    read_size,
206                    PathInfo {
207                        from: raddr,
208                        to: laddr,
209                    },
210                )),
211            });
212        });
213    }
214
215    /// Create new batch op for udp writing.
216    fn push_batch_write(&self, handle: Handle) {
217        let driver = self.driver.clone();
218
219        let batching_write_buf = self.batching_write_buf.clone();
220
221        let laddr = self
222            .handle_to_laddr(handle)
223            .expect("The mapping handle -> address not found.");
224
225        self.batching_writer.push_fn(move |cx| {
226            let buf = batching_write_buf.load(Ordering::Acquire);
227
228            assert!(
229                buf != null_mut(),
230                "set batching_write_buf before calling batching_writer await."
231            );
232
233            batching_write_buf
234                .compare_exchange(buf, null_mut(), Ordering::AcqRel, Ordering::Relaxed)
235                .expect("Only one poll write ops should be executing at a time");
236
237            let (buf_ref, raddr) = unsafe { &mut *buf };
238
239            let raddr = raddr.clone();
240
241            let buf_ref = unsafe { &**buf_ref };
242
243            let cmd_resp = driver.fd_cntl(
244                handle,
245                Cmd::SendTo {
246                    waker: cx.waker().clone(),
247                    buf: buf_ref,
248                    raddr,
249                },
250            );
251
252            let cmd_resp = match cmd_resp {
253                Ok(cmd_resp) => {
254                    _ = unsafe { Box::from_raw(buf) };
255
256                    cmd_resp
257                }
258                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
259                    batching_write_buf
260                        .compare_exchange(null_mut(), buf, Ordering::AcqRel, Ordering::Relaxed)
261                        .expect("Only one poll write ops should be executing at a time");
262
263                    return Poll::Pending;
264                }
265                Err(err) => {
266                    _ = unsafe { Box::from_raw(buf) };
267                    return Poll::Ready(BatchWrite {
268                        handle,
269                        result: Err(err),
270                    });
271                }
272            };
273
274            let read_size = cmd_resp.try_into_datalen().unwrap();
275
276            return Poll::Ready(BatchWrite {
277                handle,
278                result: Ok((
279                    read_size,
280                    PathInfo {
281                        from: laddr,
282                        to: raddr,
283                    },
284                )),
285            });
286        });
287    }
288
289    /// Try recv an udp packet and write it into `buf`.
290    ///
291    /// If successful, the received packet length and data transfer [`path`](PathInfo) information is returned.
292    ///
293    /// *Restrictions*: concurrently calls to recv_from are not allowed!!!, the later calling will override earlier calling's out buf.
294    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, PathInfo)> {
295        let ptr = Box::into_raw(Box::new(buf as *mut [u8]));
296
297        let old_ptr = self.batching_read_buf.swap(ptr, Ordering::AcqRel);
298
299        if old_ptr != null_mut() {
300            _ = unsafe { Box::from_raw(old_ptr) };
301        }
302
303        let batch_read = self
304            .batching_reader
305            .wait()
306            .await
307            .expect("No one call closed");
308
309        self.push_batch_read(batch_read.handle);
310
311        batch_read.result
312    }
313
314    /// Try send an udp packet to peer.
315    ///
316    /// If successful, the sent packet length and data transfer [`path`](PathInfo) information is returned.
317    ///
318    /// *Restrictions*: concurrently calls to send_to are not allowed!!!, the later calling will override earlier calling's in buf.
319    pub async fn send_to(&self, buf: &[u8], raddr: SocketAddr) -> io::Result<(usize, PathInfo)> {
320        let ptr = Box::into_raw(Box::new((buf as *const [u8], raddr)));
321
322        let old_ptr = self.batching_write_buf.swap(ptr, Ordering::AcqRel);
323
324        if old_ptr != null_mut() {
325            _ = unsafe { Box::from_raw(old_ptr) };
326        }
327
328        let batch_write = self
329            .batching_writer
330            .wait()
331            .await
332            .expect("No one call closed");
333
334        self.push_batch_write(batch_write.handle);
335
336        batch_write.result
337    }
338
339    /// Try send an udp packet to peer over the given [`path`](PathInfo)
340    pub async fn send_to_on_path(&self, buf: &[u8], path_info: PathInfo) -> io::Result<usize> {
341        let fd = self.laddr_to_handle(path_info.from).ok_or(io::Error::new(
342            io::ErrorKind::NotFound,
343            format!("path info not found, {:?}", path_info),
344        ))?;
345
346        let r = would_block(|cx| {
347            self.driver
348                .fd_cntl(
349                    fd,
350                    Cmd::SendTo {
351                        waker: cx.waker().clone(),
352                        buf,
353                        raddr: path_info.to,
354                    },
355                )?
356                .try_into_datalen()
357        })
358        .await;
359
360        r
361    }
362
363    /// Return the local bound socket addresses iterator.
364    pub fn local_addrs(&self) -> impl Iterator<Item = &SocketAddr> {
365        self.laddrs.values()
366    }
367}
368
#[cfg(test)]
mod tests {

    use hala_future::executor::future_spawn;
    use hala_io::test::io_test;
    use rand::{seq::SliceRandom, thread_rng};

    use super::*;

    /// Number of request rounds each test performs.
    const ROUNDS: usize = 10000;

    /// Binds a one-socket server group and then a one-socket client group on loopback,
    /// returning both groups plus the server's bound addresses.
    fn bind_pair() -> (UdpGroup, UdpGroup, Vec<SocketAddr>) {
        let loopback: [SocketAddr; 1] = ["127.0.0.1:0".parse().unwrap()];

        let server = UdpGroup::bind(&loopback[..]).unwrap();
        let client = UdpGroup::bind(&loopback[..]).unwrap();

        let server_addrs: Vec<SocketAddr> = server.local_addrs().copied().collect();

        (server, client, server_addrs)
    }

    #[hala_test::test(io_test)]
    async fn test_send() {
        let (server_group, client_group, raddrs) = bind_pair();

        // Echo server: bounce every datagram back along the reversed path it came in on.
        future_spawn(async move {
            loop {
                let mut inbox = vec![0; 1024];

                let (len, inbound) = server_group.recv_from(&mut inbox).await.unwrap();

                let reply_path = PathInfo {
                    from: inbound.to,
                    to: inbound.from,
                };

                server_group
                    .send_to_on_path(&inbox[..len], reply_path)
                    .await
                    .unwrap();
            }
        });

        for round in 0..ROUNDS {
            let target = *raddrs.choose(&mut thread_rng()).unwrap();

            let payload = format!("hello world {}", round);

            let (sent, sent_path) = client_group
                .send_to(payload.as_bytes(), target)
                .await
                .unwrap();

            let mut inbox = vec![0; 1024];

            let (received, echo_path) = client_group.recv_from(&mut inbox).await.unwrap();

            assert_eq!(received, sent);

            assert_eq!(echo_path.from, sent_path.to);
            assert_eq!(echo_path.to, sent_path.from);
        }
    }

    #[hala_test::test(io_test)]
    async fn test_sequence_send_recv() {
        let (server_group, client_group, raddrs) = bind_pair();

        for round in 0..ROUNDS {
            let target = *raddrs.choose(&mut thread_rng()).unwrap();

            let payload = format!("hello world {}", round);

            let (sent, sent_path) = client_group
                .send_to(payload.as_bytes(), target)
                .await
                .unwrap();

            let mut inbox = vec![0; 1024];

            let (received, recv_path) = server_group.recv_from(&mut inbox).await.unwrap();

            assert_eq!(received, sent);
            assert_eq!(recv_path, sent_path);
        }
    }
}