1use std::{
2 collections::HashMap,
3 fmt::Debug,
4 io,
5 net::{SocketAddr, ToSocketAddrs},
6 ptr::null_mut,
7 sync::{
8 atomic::{AtomicPtr, Ordering},
9 Arc,
10 },
11 task::Poll,
12};
13
14use hala_future::batching::FutureBatcher;
15use hala_io::{
16 context::{io_context, RawIoContext},
17 would_block, Cmd, Description, Driver, Handle, Interest, OpenFlags,
18};
19
/// Completed result of one batched receive operation, tagged with the
/// socket handle it ran on so `recv_from` can re-arm that socket.
struct BatchRead {
    // Socket the read completed (or failed) on.
    handle: Handle,
    // Bytes received plus the (remote -> local) path, or the I/O error.
    result: io::Result<(usize, PathInfo)>,
}
27
/// Completed result of one batched send operation, tagged with the
/// socket handle it ran on so `send_to` can re-arm that socket.
struct BatchWrite {
    // Socket the write completed (or failed) on.
    handle: Handle,
    // Bytes sent plus the (local -> remote) path, or the I/O error.
    result: io::Result<(usize, PathInfo)>,
}
35
/// A `(from, to)` pair of socket addresses identifying one datagram path
/// through the group.
///
/// For receives, `from` is the remote peer and `to` is the local socket that
/// accepted the datagram; for sends, `from` is the local socket and `to` is
/// the remote destination (see `push_batch_read` / `push_batch_write`).
///
/// `Hash` is derived (both fields are `SocketAddr`, which is `Hash`) so a
/// `PathInfo` can key a `HashMap`/`HashSet` of per-path state.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct PathInfo {
    /// Source address of the datagram.
    pub from: SocketAddr,
    /// Destination address of the datagram.
    pub to: SocketAddr,
}
44
impl Debug for PathInfo {
    // Single-line, human-readable form; this exact text also ends up in the
    // `NotFound` error message built by `send_to_on_path`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "path_info, from={:?}, to={:?}", self.from, self.to)
    }
}
50
/// A group of UDP sockets bound to several local addresses, read from and
/// written to as a single unit through batched futures.
pub struct UdpGroup {
    // Driver used for all fd operations; cloned from the ambient io context.
    driver: Driver,
    // local address -> socket handle, one entry per bound socket.
    fds: HashMap<SocketAddr, Handle>,
    // socket handle -> local address (reverse of `fds`).
    laddrs: HashMap<Handle, SocketAddr>,
    // One pending read queued per socket; `recv_from` awaits the first to
    // complete.
    batching_reader: FutureBatcher<BatchRead>,
    // One pending write queued per socket; `send_to` awaits the first to
    // complete.
    batching_writer: FutureBatcher<BatchWrite>,
    // Shared cell holding a boxed pointer to the current `recv_from` buffer;
    // null while no receive is in flight. The box owns only the pointer,
    // never the buffer bytes.
    batching_read_buf: Arc<AtomicPtr<*mut [u8]>>,
    // Shared cell holding a boxed (payload pointer, destination) pair for
    // the current `send_to`; null while no send is in flight.
    batching_write_buf: Arc<AtomicPtr<(*const [u8], SocketAddr)>>,
}
68
69impl Drop for UdpGroup {
70 fn drop(&mut self) {
71 for fd in self.fds.iter().map(|(_, fd)| *fd) {
72 self.driver.fd_close(fd).unwrap()
73 }
74 }
75}
76
77impl UdpGroup {
78 pub fn bind<S: ToSocketAddrs>(laddrs: S) -> io::Result<Self> {
80 let io_context = io_context();
81
82 let mut fds = HashMap::new();
83 let mut addrs = HashMap::new();
84
85 for addr in laddrs.to_socket_addrs()? {
86 let fd = io_context
87 .driver()
88 .fd_open(Description::UdpSocket, OpenFlags::Bind(&[addr]))?;
89
90 match io_context.driver().fd_cntl(
91 io_context.poller(),
92 Cmd::Register {
93 source: fd,
94 interests: Interest::Readable | Interest::Writable,
95 },
96 ) {
97 Err(err) => {
98 _ = io_context.driver().fd_close(fd);
99 return Err(err);
100 }
101 _ => {}
102 }
103
104 let laddr = io_context
105 .driver()
106 .fd_cntl(fd, Cmd::LocalAddr)?
107 .try_into_sockaddr()?;
108
109 fds.insert(laddr, fd);
110 addrs.insert(fd, laddr);
111 }
112
113 let group = UdpGroup {
114 driver: io_context.driver().clone(),
115 fds,
116 laddrs: addrs,
117 batching_read_buf: Default::default(),
118 batching_reader: Default::default(),
119 batching_write_buf: Default::default(),
120 batching_writer: Default::default(),
121 };
122
123 group.init_push_batch_ops();
124
125 Ok(group)
126 }
127
128 fn init_push_batch_ops(&self) {
129 for fd in self.fds.iter().map(|(_, fd)| *fd) {
130 self.push_batch_read(fd);
131 self.push_batch_write(fd);
132 }
133 }
134
135 fn laddr_to_handle(&self, laddr: SocketAddr) -> Option<Handle> {
137 self.fds.get(&laddr).map(|fd| *fd)
138 }
139
140 fn handle_to_laddr(&self, handle: Handle) -> Option<SocketAddr> {
142 self.laddrs.get(&handle).map(|fd| *fd)
143 }
144
    /// Queues a receive operation for `handle` on the shared batching reader.
    ///
    /// The queued closure claims the buffer-pointer cell installed by
    /// `recv_from`, issues a non-blocking `Cmd::RecvFrom`, and either
    /// completes with a `BatchRead` or restores the cell and returns
    /// `Pending` when the socket would block.
    fn push_batch_read(&self, handle: Handle) {
        let driver = self.driver.clone();

        let batching_read_buf = self.batching_read_buf.clone();

        let laddr = self
            .handle_to_laddr(handle)
            .expect("The mapping handle -> address not found.");

        self.batching_reader.push_fn(move |cx| {
            // `recv_from` must have installed the caller's buffer pointer
            // before awaiting; a null cell here is a usage bug.
            let buf = batching_read_buf.load(Ordering::Acquire);

            assert!(
                buf != null_mut(),
                "set batching_read_buf before calling batching_reader await."
            );

            // Claim exclusive ownership of the cell by nulling it out. The
            // CAS can only fail if another poll raced us, which the API
            // forbids (a single `recv_from` in flight at a time).
            batching_read_buf
                .compare_exchange(buf, null_mut(), Ordering::AcqRel, Ordering::Relaxed)
                .expect("Only one poll read ops should be executing at a time");

            // SAFETY(review): `buf` came from `Box::into_raw` in `recv_from`
            // and points at a `*mut [u8]` referencing the caller's buffer —
            // assumed kept alive by the caller across the await; confirm.
            let buf_ref = unsafe { &mut **buf };

            let cmd_resp = driver.fd_cntl(
                handle,
                Cmd::RecvFrom {
                    waker: cx.waker().clone(),
                    buf: buf_ref,
                },
            );

            let cmd_resp = match cmd_resp {
                Ok(cmd_resp) => {
                    // Completed: free the pointer cell (not the buffer).
                    _ = unsafe { Box::from_raw(buf) };
                    cmd_resp
                }
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
                    // Not ready: put the cell back so the retry after the
                    // waker fires reuses the same buffer.
                    batching_read_buf
                        .compare_exchange(null_mut(), buf, Ordering::AcqRel, Ordering::Relaxed)
                        .expect("Only one poll read ops should be executing at a time");

                    return Poll::Pending;
                }
                Err(err) => {
                    // Hard error: free the cell and surface the error.
                    _ = unsafe { Box::from_raw(buf) };
                    return Poll::Ready(BatchRead {
                        handle,
                        result: Err(err),
                    });
                }
            };

            let (read_size, raddr) = cmd_resp.try_into_recv_from().unwrap();

            log::trace!("batch_read ready");

            // Path is remote peer -> this socket's local address.
            return Poll::Ready(BatchRead {
                handle,
                result: Ok((
                    read_size,
                    PathInfo {
                        from: raddr,
                        to: laddr,
                    },
                )),
            });
        });
    }
214
    /// Queues a send operation for `handle` on the shared batching writer.
    ///
    /// Mirror of `push_batch_read`: the queued closure claims the
    /// (payload, destination) cell installed by `send_to`, issues a
    /// non-blocking `Cmd::SendTo`, and either completes with a `BatchWrite`
    /// or restores the cell and returns `Pending` when the socket would
    /// block.
    fn push_batch_write(&self, handle: Handle) {
        let driver = self.driver.clone();

        let batching_write_buf = self.batching_write_buf.clone();

        let laddr = self
            .handle_to_laddr(handle)
            .expect("The mapping handle -> address not found.");

        self.batching_writer.push_fn(move |cx| {
            // `send_to` must have installed the (payload, destination) pair
            // before awaiting; a null cell here is a usage bug.
            let buf = batching_write_buf.load(Ordering::Acquire);

            assert!(
                buf != null_mut(),
                "set batching_write_buf before calling batching_writer await."
            );

            // Claim exclusive ownership of the cell by nulling it out; only
            // one `send_to` may be in flight at a time.
            batching_write_buf
                .compare_exchange(buf, null_mut(), Ordering::AcqRel, Ordering::Relaxed)
                .expect("Only one poll write ops should be executing at a time");

            // SAFETY(review): `buf` came from `Box::into_raw` in `send_to`;
            // the payload pointer targets the caller's slice — assumed kept
            // alive across the await; confirm.
            let (buf_ref, raddr) = unsafe { &mut *buf };

            // Copy the destination out before the cell can be freed below.
            let raddr = raddr.clone();

            let buf_ref = unsafe { &**buf_ref };

            let cmd_resp = driver.fd_cntl(
                handle,
                Cmd::SendTo {
                    waker: cx.waker().clone(),
                    buf: buf_ref,
                    raddr,
                },
            );

            let cmd_resp = match cmd_resp {
                Ok(cmd_resp) => {
                    // Completed: free the pointer cell (not the payload).
                    _ = unsafe { Box::from_raw(buf) };

                    cmd_resp
                }
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
                    // Not ready: restore the cell so the retry after wakeup
                    // sends the same payload to the same destination.
                    batching_write_buf
                        .compare_exchange(null_mut(), buf, Ordering::AcqRel, Ordering::Relaxed)
                        .expect("Only one poll write ops should be executing at a time");

                    return Poll::Pending;
                }
                Err(err) => {
                    // Hard error: free the cell and surface the error.
                    _ = unsafe { Box::from_raw(buf) };
                    return Poll::Ready(BatchWrite {
                        handle,
                        result: Err(err),
                    });
                }
            };

            // Bytes actually sent (name kept parallel to the read path).
            let read_size = cmd_resp.try_into_datalen().unwrap();

            // Path is this socket's local address -> remote destination.
            return Poll::Ready(BatchWrite {
                handle,
                result: Ok((
                    read_size,
                    PathInfo {
                        from: laddr,
                        to: raddr,
                    },
                )),
            });
        });
    }
288
    /// Receives the next datagram arriving on any socket in the group.
    ///
    /// Returns the number of bytes read and the (remote -> local) path it
    /// arrived on. Only one `recv_from` may be in flight at a time: the
    /// buffer pointer is published through a single shared cell, and the
    /// batch closures assert that assumption.
    pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, PathInfo)> {
        // Publish a heap-boxed pointer to the caller's buffer for the batch
        // closures to claim. The box owns only the pointer, not the bytes.
        let ptr = Box::into_raw(Box::new(buf as *mut [u8]));

        let old_ptr = self.batching_read_buf.swap(ptr, Ordering::AcqRel);

        // A stale cell may still be installed (presumably from an earlier
        // call that never got consumed — verify); free it to avoid a leak.
        if old_ptr != null_mut() {
            _ = unsafe { Box::from_raw(old_ptr) };
        }

        // Wait until any one of the queued per-socket reads completes.
        let batch_read = self
            .batching_reader
            .wait()
            .await
            .expect("No one call closed");

        // Re-arm the socket that just completed so it keeps participating.
        self.push_batch_read(batch_read.handle);

        batch_read.result
    }
313
    /// Sends `buf` to `raddr` from whichever socket in the group becomes
    /// writable first.
    ///
    /// Returns the number of bytes sent and the (local -> remote) path that
    /// was used. Only one `send_to` may be in flight at a time: the payload
    /// pointer is published through a single shared cell, and the batch
    /// closures assert that assumption.
    pub async fn send_to(&self, buf: &[u8], raddr: SocketAddr) -> io::Result<(usize, PathInfo)> {
        // Publish a heap-boxed (payload pointer, destination) pair for the
        // batch closures to claim. The box owns only the pair, not the bytes.
        let ptr = Box::into_raw(Box::new((buf as *const [u8], raddr)));

        let old_ptr = self.batching_write_buf.swap(ptr, Ordering::AcqRel);

        // A stale cell may still be installed (presumably from an earlier
        // call that never got consumed — verify); free it to avoid a leak.
        if old_ptr != null_mut() {
            _ = unsafe { Box::from_raw(old_ptr) };
        }

        // Wait until any one of the queued per-socket writes completes.
        let batch_write = self
            .batching_writer
            .wait()
            .await
            .expect("No one call closed");

        // Re-arm the socket that just completed so it keeps participating.
        self.push_batch_write(batch_write.handle);

        batch_write.result
    }
338
339 pub async fn send_to_on_path(&self, buf: &[u8], path_info: PathInfo) -> io::Result<usize> {
341 let fd = self.laddr_to_handle(path_info.from).ok_or(io::Error::new(
342 io::ErrorKind::NotFound,
343 format!("path info not found, {:?}", path_info),
344 ))?;
345
346 let r = would_block(|cx| {
347 self.driver
348 .fd_cntl(
349 fd,
350 Cmd::SendTo {
351 waker: cx.waker().clone(),
352 buf,
353 raddr: path_info.to,
354 },
355 )?
356 .try_into_datalen()
357 })
358 .await;
359
360 r
361 }
362
    /// Iterates over the local addresses this group is bound to.
    pub fn local_addrs(&self) -> impl Iterator<Item = &SocketAddr> {
        self.laddrs.values()
    }
367}
368
369#[cfg(test)]
370mod tests {
371
372 use hala_future::executor::future_spawn;
373 use hala_io::test::io_test;
374 use rand::{seq::SliceRandom, thread_rng};
375
376 use super::*;
377
    /// Echo round-trip test: a spawned server task receives each datagram
    /// and echoes it back along the reversed path; the client asserts that
    /// sizes match and the reply path mirrors the send path.
    #[hala_test::test(io_test)]
    async fn test_send() {
        // Bind both groups to an OS-assigned port on localhost.
        let laddrs = vec!["127.0.0.1:0".parse().unwrap(); 1];

        let server_group = UdpGroup::bind(laddrs.as_slice()).unwrap();

        let client_group = UdpGroup::bind(laddrs.as_slice()).unwrap();

        let raddrs = server_group
            .local_addrs()
            .map(|addr| *addr)
            .collect::<Vec<_>>();

        let loops = 10000;

        // Server task: echo forever; dropped when the test finishes.
        future_spawn(async move {
            loop {
                let mut buf = vec![0; 1024];

                let (recv_size, recv_path_info) = server_group.recv_from(&mut buf).await.unwrap();

                // Reply on the reversed path so it reaches the sender.
                server_group
                    .send_to_on_path(
                        &buf[..recv_size],
                        PathInfo {
                            from: recv_path_info.to,
                            to: recv_path_info.from,
                        },
                    )
                    .await
                    .unwrap();
            }
        });

        for i in 0..loops {
            // Pick a random server address each iteration.
            let raddr = raddrs.choose(&mut thread_rng()).unwrap();

            let data = format!("hello world {}", i);

            let (send_size, send_path_info) =
                client_group.send_to(data.as_bytes(), *raddr).await.unwrap();

            let mut buf = vec![0; 1024];

            let (read_size, path_info) = client_group.recv_from(&mut buf).await.unwrap();

            assert_eq!(read_size, send_size);

            // The echo must come back along the mirror of the send path.
            assert_eq!(path_info.from, send_path_info.to);
            assert_eq!(path_info.to, send_path_info.from);
        }
    }
430
    /// Sequential variant: send from the client, then receive on the server
    /// in the same task, asserting the server-observed path equals the
    /// path the client reported for the send.
    #[hala_test::test(io_test)]
    async fn test_sequence_send_recv() {
        // Bind both groups to an OS-assigned port on localhost.
        let laddrs = vec!["127.0.0.1:0".parse().unwrap(); 1];

        let server_group = UdpGroup::bind(laddrs.as_slice()).unwrap();

        let client_group = UdpGroup::bind(laddrs.as_slice()).unwrap();

        let raddrs = server_group
            .local_addrs()
            .map(|addr| *addr)
            .collect::<Vec<_>>();

        let loops = 10000;

        for i in 0..loops {
            let raddr = raddrs.choose(&mut thread_rng()).unwrap();

            let data = format!("hello world {}", i);

            let (send_size, send_path_info) =
                client_group.send_to(data.as_bytes(), *raddr).await.unwrap();

            let mut buf = vec![0; 1024];

            let (read_size, path_info) = server_group.recv_from(&mut buf).await.unwrap();

            assert_eq!(read_size, send_size);
            // Server sees the same (from, to) the client reported.
            assert_eq!(path_info, send_path_info);
        }
    }
462}