rasi_mio/
net.rs

1use std::{
2    io::{Error, ErrorKind, Read, Write},
3    net::Shutdown,
4    ops::Deref,
5    sync::RwLock,
6    task::Poll,
7};
8
9use mio::{event::Source, Interest, Token};
10use rasi::net::register_network_driver;
11
12use crate::{reactor::global_reactor, token::TokenSequence, utils::would_block};
13
14/// A wrapper of mio event source.
15#[derive(Debug)]
16pub(crate) struct MioSocket<S: Source> {
17    /// Associcated token.
18    pub(crate) token: Token,
19    /// net source type.
20    pub(crate) socket: S,
21}
22
23impl<S: Source> From<(Token, S)> for MioSocket<S> {
24    fn from(value: (Token, S)) -> Self {
25        Self {
26            token: value.0,
27            socket: value.1,
28        }
29    }
30}
31
32impl<S: Source> Deref for MioSocket<S> {
33    type Target = S;
34    fn deref(&self) -> &Self::Target {
35        &self.socket
36    }
37}
38
39impl<S: Source> Drop for MioSocket<S> {
40    fn drop(&mut self) {
41        if global_reactor().deregister(&mut self.socket).is_err() {}
42    }
43}
44
45type MioTcpListener = MioSocket<mio::net::TcpListener>;
46
47impl rasi::net::syscall::DriverTcpListener for MioTcpListener {
48    fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
49        self.socket.local_addr()
50    }
51
52    fn ttl(&self) -> std::io::Result<u32> {
53        self.socket.ttl()
54    }
55
56    fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
57        self.socket.set_ttl(ttl)
58    }
59
60    fn poll_next(
61        &self,
62        cx: &mut std::task::Context<'_>,
63    ) -> std::task::Poll<std::io::Result<(rasi::net::TcpStream, std::net::SocketAddr)>> {
64        would_block(
65            self.token,
66            cx.waker().clone(),
67            Interest::READABLE,
68            || match self.socket.accept() {
69                Ok((mut stream, raddr)) => {
70                    let token = Token::next();
71
72                    global_reactor().register(
73                        &mut stream,
74                        token,
75                        Interest::READABLE.add(Interest::WRITABLE),
76                    )?;
77
78                    Ok((
79                        MioTcpStream {
80                            token,
81                            socket: stream,
82                        }
83                        .into(),
84                        raddr,
85                    ))
86                }
87                Err(err) => Err(err),
88            },
89        )
90    }
91}
92
93type MioTcpStream = MioSocket<mio::net::TcpStream>;
94
95impl rasi::net::syscall::DriverTcpStream for MioTcpStream {
96    fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
97        self.socket.local_addr()
98    }
99
100    fn peer_addr(&self) -> std::io::Result<std::net::SocketAddr> {
101        self.socket.peer_addr()
102    }
103
104    fn ttl(&self) -> std::io::Result<u32> {
105        self.socket.ttl()
106    }
107
108    fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
109        self.socket.set_ttl(ttl)
110    }
111
112    fn nodelay(&self) -> std::io::Result<bool> {
113        self.socket.nodelay()
114    }
115
116    fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
117        self.socket.set_nodelay(nodelay)
118    }
119
120    fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
121        self.socket.shutdown(how)
122    }
123
124    fn poll_read(
125        &self,
126        cx: &mut std::task::Context<'_>,
127        buf: &mut [u8],
128    ) -> std::task::Poll<std::io::Result<usize>> {
129        would_block(self.token, cx.waker().clone(), Interest::READABLE, || {
130            self.deref().read(buf)
131        })
132    }
133
134    fn poll_write(
135        &self,
136        cx: &mut std::task::Context<'_>,
137        buf: &[u8],
138    ) -> std::task::Poll<std::io::Result<usize>> {
139        would_block(self.token, cx.waker().clone(), Interest::WRITABLE, || {
140            self.deref().write(buf)
141        })
142    }
143
144    fn poll_ready(&self, cx: &mut std::task::Context<'_>) -> std::task::Poll<std::io::Result<()>> {
145        would_block(self.token, cx.waker().clone(), Interest::WRITABLE, || {
146            log::trace!("tcp_connect, poll_ready {:?}", self.token);
147
148            if let Err(err) = self.deref().take_error() {
149                return Err(err);
150            }
151
152            match self.deref().peer_addr() {
153                Ok(_) => {
154                    return Ok(());
155                }
156                Err(err)
157                    if err.kind() == ErrorKind::NotConnected
158                        || err.raw_os_error() == Some(libc::EINPROGRESS) =>
159                {
160                    return Err(std::io::Error::new(std::io::ErrorKind::WouldBlock, ""));
161                }
162                Err(err) => {
163                    return Err(err);
164                }
165            }
166        })
167    }
168}
169
170struct MioUdpSocket {
171    mio_socket: MioSocket<mio::net::UdpSocket>,
172    shutdown: RwLock<(bool, bool)>,
173}
174
175impl rasi::net::syscall::DriverUdpSocket for MioUdpSocket {
176    fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
177        self.mio_socket.socket.local_addr()
178    }
179
180    fn peer_addr(&self) -> std::io::Result<std::net::SocketAddr> {
181        self.mio_socket.socket.peer_addr()
182    }
183
184    fn ttl(&self) -> std::io::Result<u32> {
185        self.mio_socket.socket.ttl()
186    }
187
188    fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
189        self.mio_socket.socket.set_ttl(ttl)
190    }
191
192    fn join_multicast_v4(
193        &self,
194        multiaddr: &std::net::Ipv4Addr,
195        interface: &std::net::Ipv4Addr,
196    ) -> std::io::Result<()> {
197        self.mio_socket
198            .socket
199            .join_multicast_v4(multiaddr, interface)
200    }
201
202    fn join_multicast_v6(
203        &self,
204        multiaddr: &std::net::Ipv6Addr,
205        interface: u32,
206    ) -> std::io::Result<()> {
207        self.mio_socket
208            .socket
209            .join_multicast_v6(multiaddr, interface)
210    }
211
212    fn leave_multicast_v4(
213        &self,
214        multiaddr: &std::net::Ipv4Addr,
215        interface: &std::net::Ipv4Addr,
216    ) -> std::io::Result<()> {
217        self.mio_socket
218            .socket
219            .leave_multicast_v4(multiaddr, interface)
220    }
221
222    fn leave_multicast_v6(
223        &self,
224        multiaddr: &std::net::Ipv6Addr,
225        interface: u32,
226    ) -> std::io::Result<()> {
227        self.mio_socket
228            .socket
229            .leave_multicast_v6(multiaddr, interface)
230    }
231
232    fn set_broadcast(&self, on: bool) -> std::io::Result<()> {
233        self.mio_socket.socket.set_broadcast(on)
234    }
235
236    fn broadcast(&self) -> std::io::Result<bool> {
237        self.mio_socket.socket.broadcast()
238    }
239
240    /// Sets the value of the IP_MULTICAST_LOOP option for this socket.
241    ///
242    /// If enabled, multicast packets will be looped back to the local socket. Note that this might not have any effect on IPv6 sockets.
243    fn set_multicast_loop_v4(&self, on: bool) -> std::io::Result<()> {
244        self.mio_socket.socket.set_multicast_loop_v4(on)
245    }
246
247    /// Sets the value of the IPV6_MULTICAST_LOOP option for this socket.
248    ///
249    /// Controls whether this socket sees the multicast packets it sends itself. Note that this might not have any affect on IPv4 sockets.
250    fn set_multicast_loop_v6(&self, on: bool) -> std::io::Result<()> {
251        self.mio_socket.socket.set_multicast_loop_v6(on)
252    }
253
254    /// Gets the value of the IP_MULTICAST_LOOP option for this socket.
255    fn multicast_loop_v4(&self) -> std::io::Result<bool> {
256        self.mio_socket.socket.multicast_loop_v4()
257    }
258
259    /// Gets the value of the IPV6_MULTICAST_LOOP option for this socket.
260    fn multicast_loop_v6(&self) -> std::io::Result<bool> {
261        self.mio_socket.socket.multicast_loop_v6()
262    }
263
264    fn poll_recv_from(
265        &self,
266        cx: &mut std::task::Context<'_>,
267        buf: &mut [u8],
268    ) -> Poll<std::io::Result<(usize, std::net::SocketAddr)>> {
269        let shutdown = self.shutdown.read().unwrap();
270
271        if shutdown.0 {
272            return Poll::Ready(Err(Error::new(
273                ErrorKind::BrokenPipe,
274                "UdpSocket read shutdown.",
275            )));
276        }
277
278        would_block(
279            self.mio_socket.token,
280            cx.waker().clone(),
281            Interest::READABLE,
282            || self.mio_socket.socket.recv_from(buf),
283        )
284    }
285
286    fn poll_send_to(
287        &self,
288        cx: &mut std::task::Context<'_>,
289        buf: &[u8],
290        peer: std::net::SocketAddr,
291    ) -> Poll<std::io::Result<usize>> {
292        let shutdown = self.shutdown.read().unwrap();
293        if shutdown.1 {
294            return Poll::Ready(Err(Error::new(
295                ErrorKind::BrokenPipe,
296                "UdpSocket write shutdown.",
297            )));
298        }
299
300        would_block(
301            self.mio_socket.token,
302            cx.waker().clone(),
303            Interest::WRITABLE,
304            || self.mio_socket.socket.send_to(buf, peer),
305        )
306    }
307
308    /// Shuts down the read, write, or both halves of this connection.
309    ///
310    /// This method will cause all pending and future I/O on the specified portions to return
311    /// immediately with an appropriate value (see the documentation of [`Shutdown`]).
312    ///
313    /// [`Shutdown`]: https://doc.rust-lang.org/std/net/enum.Shutdown.html
314    fn shutdown(&self, how: Shutdown) -> std::io::Result<()> {
315        let mut locker = self.shutdown.write().unwrap();
316
317        match how {
318            Shutdown::Read => {
319                locker.0 = true;
320
321                global_reactor().notify(self.mio_socket.token, Interest::READABLE);
322            }
323            Shutdown::Write => {
324                locker.1 = true;
325                global_reactor().notify(self.mio_socket.token, Interest::WRITABLE);
326            }
327            Shutdown::Both => {
328                locker.0 = true;
329                locker.1 = true;
330                global_reactor().notify(
331                    self.mio_socket.token,
332                    Interest::WRITABLE.add(Interest::READABLE),
333                );
334            }
335        }
336
337        Ok(())
338    }
339}
340
341#[cfg(unix)]
342type MioUnixListener = MioSocket<mio::net::UnixListener>;
343
344#[cfg(unix)]
345impl rasi::net::syscall::unix::DriverUnixListener for MioUnixListener {
346    fn local_addr(&self) -> std::io::Result<std::os::unix::net::SocketAddr> {
347        self.socket.local_addr()
348    }
349
350    fn poll_next(
351        &self,
352        cx: &mut std::task::Context<'_>,
353    ) -> Poll<std::io::Result<(rasi::net::unix::UnixStream, std::os::unix::net::SocketAddr)>> {
354        would_block(
355            self.token,
356            cx.waker().clone(),
357            Interest::READABLE,
358            || match self.socket.accept() {
359                Ok((mut stream, raddr)) => {
360                    let token = Token::next();
361
362                    global_reactor().register(
363                        &mut stream,
364                        token,
365                        Interest::READABLE.add(Interest::WRITABLE),
366                    )?;
367
368                    Ok((
369                        MioUnixStream {
370                            token,
371                            socket: stream,
372                        }
373                        .into(),
374                        raddr,
375                    ))
376                }
377                Err(err) => Err(err),
378            },
379        )
380    }
381}
382
383#[cfg(unix)]
384type MioUnixStream = MioSocket<mio::net::UnixStream>;
385
386#[cfg(unix)]
387impl rasi::net::syscall::unix::DriverUnixStream for MioUnixStream {
388    fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
389        self.socket.shutdown(how)
390    }
391
392    fn poll_read(
393        &self,
394        cx: &mut std::task::Context<'_>,
395        buf: &mut [u8],
396    ) -> Poll<std::io::Result<usize>> {
397        would_block(self.token, cx.waker().clone(), Interest::READABLE, || {
398            self.deref().read(buf)
399        })
400    }
401
402    fn poll_write(
403        &self,
404        cx: &mut std::task::Context<'_>,
405        buf: &[u8],
406    ) -> Poll<std::io::Result<usize>> {
407        would_block(self.token, cx.waker().clone(), Interest::WRITABLE, || {
408            self.deref().write(buf)
409        })
410    }
411
412    fn poll_ready(&self, _cx: &mut std::task::Context<'_>) -> Poll<std::io::Result<()>> {
413        Poll::Ready(Ok(()))
414    }
415
416    fn local_addr(&self) -> std::io::Result<std::os::unix::net::SocketAddr> {
417        self.socket.local_addr()
418    }
419
420    fn peer_addr(&self) -> std::io::Result<std::os::unix::net::SocketAddr> {
421        self.socket.peer_addr()
422    }
423}
424
425struct MioNetworkDriver;
426
427impl MioNetworkDriver {
428    fn tcp_listener_from_std_socket(
429        &self,
430        std_socket: std::net::TcpListener,
431    ) -> std::io::Result<rasi::net::TcpListener> {
432        let mut socket = mio::net::TcpListener::from_std(std_socket);
433
434        let token = Token::next();
435
436        global_reactor().register(
437            &mut socket,
438            token,
439            Interest::READABLE.add(Interest::WRITABLE),
440        )?;
441
442        Ok(MioTcpListener { token, socket }.into())
443    }
444
445    fn tcp_stream_from_std_socket(
446        &self,
447        std_socket: std::net::TcpStream,
448    ) -> std::io::Result<rasi::net::TcpStream> {
449        let mut socket = mio::net::TcpStream::from_std(std_socket);
450
451        let token = Token::next();
452
453        global_reactor().register(
454            &mut socket,
455            token,
456            Interest::READABLE.add(Interest::WRITABLE),
457        )?;
458
459        return Ok(MioTcpStream { token, socket }.into());
460    }
461
462    fn udp_socket_from_std_socket(
463        &self,
464        std_socket: std::net::UdpSocket,
465    ) -> std::io::Result<rasi::net::UdpSocket> {
466        let mut socket = mio::net::UdpSocket::from_std(std_socket);
467        let token = Token::next();
468
469        global_reactor().register(
470            &mut socket,
471            token,
472            Interest::READABLE.add(Interest::WRITABLE),
473        )?;
474
475        Ok(MioUdpSocket {
476            mio_socket: MioSocket { socket, token },
477            shutdown: RwLock::new((false, false)),
478        }
479        .into())
480    }
481}
482
483impl rasi::net::syscall::Driver for MioNetworkDriver {
484    fn tcp_listen(
485        &self,
486        laddrs: &[std::net::SocketAddr],
487    ) -> std::io::Result<rasi::net::TcpListener> {
488        let std_socket = std::net::TcpListener::bind(laddrs)?;
489
490        std_socket.set_nonblocking(true)?;
491
492        self.tcp_listener_from_std_socket(std_socket)
493    }
494
495    #[cfg(unix)]
496    unsafe fn tcp_listener_from_raw_fd(
497        &self,
498        fd: std::os::fd::RawFd,
499    ) -> std::io::Result<rasi::net::TcpListener> {
500        use std::os::fd::FromRawFd;
501
502        let std_socket = std::net::TcpListener::from_raw_fd(fd);
503
504        std_socket.set_nonblocking(true)?;
505
506        self.tcp_listener_from_std_socket(std_socket)
507    }
508
509    #[cfg(windows)]
510    unsafe fn tcp_listener_from_raw_socket(
511        &self,
512        socket: std::os::windows::io::RawSocket,
513    ) -> std::io::Result<rasi::net::TcpListener> {
514        use std::os::windows::io::FromRawSocket;
515
516        let std_socket = std::net::TcpListener::from_raw_socket(socket);
517
518        std_socket.set_nonblocking(true)?;
519
520        self.tcp_listener_from_std_socket(std_socket)
521    }
522
523    fn tcp_connect(&self, raddr: &std::net::SocketAddr) -> std::io::Result<rasi::net::TcpStream> {
524        log::trace!("tcp_connect, raddr={}", raddr);
525
526        let mut socket = mio::net::TcpStream::connect(raddr.clone())?;
527
528        let token = Token::next();
529
530        global_reactor().register(
531            &mut socket,
532            token,
533            Interest::READABLE.add(Interest::WRITABLE),
534        )?;
535
536        return Ok(MioTcpStream { token, socket }.into());
537    }
538
539    #[cfg(unix)]
540    unsafe fn tcp_stream_from_raw_fd(
541        &self,
542        fd: std::os::fd::RawFd,
543    ) -> std::io::Result<rasi::net::TcpStream> {
544        use std::os::fd::FromRawFd;
545
546        let std_socket = std::net::TcpStream::from_raw_fd(fd);
547
548        std_socket.set_nonblocking(true)?;
549
550        self.tcp_stream_from_std_socket(std_socket)
551    }
552
553    #[cfg(windows)]
554    unsafe fn tcp_stream_from_raw_socket(
555        &self,
556        socket: std::os::windows::io::RawSocket,
557    ) -> std::io::Result<rasi::net::TcpStream> {
558        use std::os::windows::io::FromRawSocket;
559
560        let std_socket = std::net::TcpStream::from_raw_socket(socket);
561
562        std_socket.set_nonblocking(true)?;
563
564        self.tcp_stream_from_std_socket(std_socket)
565    }
566
567    fn udp_bind(&self, laddrs: &[std::net::SocketAddr]) -> std::io::Result<rasi::net::UdpSocket> {
568        let std_socket = std::net::UdpSocket::bind(laddrs)?;
569
570        std_socket.set_nonblocking(true)?;
571
572        self.udp_socket_from_std_socket(std_socket)
573    }
574
575    #[cfg(unix)]
576    unsafe fn udp_from_raw_fd(
577        &self,
578        fd: std::os::fd::RawFd,
579    ) -> std::io::Result<rasi::net::UdpSocket> {
580        use std::os::fd::FromRawFd;
581
582        let std_socket = std::net::UdpSocket::from_raw_fd(fd);
583
584        std_socket.set_nonblocking(true)?;
585
586        self.udp_socket_from_std_socket(std_socket)
587    }
588
589    #[cfg(windows)]
590    unsafe fn udp_from_raw_socket(
591        &self,
592        socket: std::os::windows::io::RawSocket,
593    ) -> std::io::Result<rasi::net::UdpSocket> {
594        use std::os::windows::io::FromRawSocket;
595
596        let std_socket = std::net::UdpSocket::from_raw_socket(socket);
597
598        std_socket.set_nonblocking(true)?;
599
600        self.udp_socket_from_std_socket(std_socket)
601    }
602
603    #[cfg(unix)]
604    fn unix_listen(
605        &self,
606        path: &std::path::Path,
607    ) -> std::io::Result<rasi::net::unix::UnixListener> {
608        let mut socket = mio::net::UnixListener::bind(path)?;
609
610        let token = Token::next();
611
612        global_reactor().register(
613            &mut socket,
614            token,
615            Interest::READABLE.add(Interest::WRITABLE),
616        )?;
617
618        Ok(MioUnixListener { token, socket }.into())
619    }
620
621    #[cfg(unix)]
622    fn unix_connect(&self, path: &std::path::Path) -> std::io::Result<rasi::net::unix::UnixStream> {
623        let mut socket = mio::net::UnixStream::connect(path)?;
624
625        let token = Token::next();
626
627        global_reactor().register(
628            &mut socket,
629            token,
630            Interest::READABLE.add(Interest::WRITABLE),
631        )?;
632
633        Ok(MioUnixStream { token, socket }.into())
634    }
635}
636
637/// This function using [`register_network_driver`] to register the `MioNetwork` to global registry.
638///
639/// So you may not call this function twice, otherwise will cause a panic. [`read more`](`register_network_driver`)
640pub fn register_mio_network() {
641    register_network_driver(MioNetworkDriver)
642}
643
#[cfg(test)]
mod tests {
    use rasi_spec::network::run_network_spec;

    use super::*;

    /// Driver instance exercised by the spec runners.
    static DRIVER: MioNetworkDriver = MioNetworkDriver;

    /// Runs the rasi network spec (and, on unix, the IPC spec) against the mio driver.
    #[futures_test::test]
    async fn test_network() {
        run_network_spec(&DRIVER).await;

        #[cfg(unix)]
        rasi_spec::ipc::run_ipc_spec(&DRIVER).await;
    }
}