1use std::{
2 io::{Error, ErrorKind, Read, Write},
3 net::Shutdown,
4 ops::Deref,
5 sync::RwLock,
6 task::Poll,
7};
8
9use mio::{event::Source, Interest, Token};
10use rasi::net::register_network_driver;
11
12use crate::{reactor::global_reactor, token::TokenSequence, utils::would_block};
13
14#[derive(Debug)]
16pub(crate) struct MioSocket<S: Source> {
17 pub(crate) token: Token,
19 pub(crate) socket: S,
21}
22
23impl<S: Source> From<(Token, S)> for MioSocket<S> {
24 fn from(value: (Token, S)) -> Self {
25 Self {
26 token: value.0,
27 socket: value.1,
28 }
29 }
30}
31
32impl<S: Source> Deref for MioSocket<S> {
33 type Target = S;
34 fn deref(&self) -> &Self::Target {
35 &self.socket
36 }
37}
38
39impl<S: Source> Drop for MioSocket<S> {
40 fn drop(&mut self) {
41 if global_reactor().deregister(&mut self.socket).is_err() {}
42 }
43}
44
45type MioTcpListener = MioSocket<mio::net::TcpListener>;
46
47impl rasi::net::syscall::DriverTcpListener for MioTcpListener {
48 fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
49 self.socket.local_addr()
50 }
51
52 fn ttl(&self) -> std::io::Result<u32> {
53 self.socket.ttl()
54 }
55
56 fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
57 self.socket.set_ttl(ttl)
58 }
59
60 fn poll_next(
61 &self,
62 cx: &mut std::task::Context<'_>,
63 ) -> std::task::Poll<std::io::Result<(rasi::net::TcpStream, std::net::SocketAddr)>> {
64 would_block(
65 self.token,
66 cx.waker().clone(),
67 Interest::READABLE,
68 || match self.socket.accept() {
69 Ok((mut stream, raddr)) => {
70 let token = Token::next();
71
72 global_reactor().register(
73 &mut stream,
74 token,
75 Interest::READABLE.add(Interest::WRITABLE),
76 )?;
77
78 Ok((
79 MioTcpStream {
80 token,
81 socket: stream,
82 }
83 .into(),
84 raddr,
85 ))
86 }
87 Err(err) => Err(err),
88 },
89 )
90 }
91}
92
93type MioTcpStream = MioSocket<mio::net::TcpStream>;
94
95impl rasi::net::syscall::DriverTcpStream for MioTcpStream {
96 fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
97 self.socket.local_addr()
98 }
99
100 fn peer_addr(&self) -> std::io::Result<std::net::SocketAddr> {
101 self.socket.peer_addr()
102 }
103
104 fn ttl(&self) -> std::io::Result<u32> {
105 self.socket.ttl()
106 }
107
108 fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
109 self.socket.set_ttl(ttl)
110 }
111
112 fn nodelay(&self) -> std::io::Result<bool> {
113 self.socket.nodelay()
114 }
115
116 fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
117 self.socket.set_nodelay(nodelay)
118 }
119
120 fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
121 self.socket.shutdown(how)
122 }
123
124 fn poll_read(
125 &self,
126 cx: &mut std::task::Context<'_>,
127 buf: &mut [u8],
128 ) -> std::task::Poll<std::io::Result<usize>> {
129 would_block(self.token, cx.waker().clone(), Interest::READABLE, || {
130 self.deref().read(buf)
131 })
132 }
133
134 fn poll_write(
135 &self,
136 cx: &mut std::task::Context<'_>,
137 buf: &[u8],
138 ) -> std::task::Poll<std::io::Result<usize>> {
139 would_block(self.token, cx.waker().clone(), Interest::WRITABLE, || {
140 self.deref().write(buf)
141 })
142 }
143
144 fn poll_ready(&self, cx: &mut std::task::Context<'_>) -> std::task::Poll<std::io::Result<()>> {
145 would_block(self.token, cx.waker().clone(), Interest::WRITABLE, || {
146 log::trace!("tcp_connect, poll_ready {:?}", self.token);
147
148 if let Err(err) = self.deref().take_error() {
149 return Err(err);
150 }
151
152 match self.deref().peer_addr() {
153 Ok(_) => {
154 return Ok(());
155 }
156 Err(err)
157 if err.kind() == ErrorKind::NotConnected
158 || err.raw_os_error() == Some(libc::EINPROGRESS) =>
159 {
160 return Err(std::io::Error::new(std::io::ErrorKind::WouldBlock, ""));
161 }
162 Err(err) => {
163 return Err(err);
164 }
165 }
166 })
167 }
168}
169
170struct MioUdpSocket {
171 mio_socket: MioSocket<mio::net::UdpSocket>,
172 shutdown: RwLock<(bool, bool)>,
173}
174
175impl rasi::net::syscall::DriverUdpSocket for MioUdpSocket {
176 fn local_addr(&self) -> std::io::Result<std::net::SocketAddr> {
177 self.mio_socket.socket.local_addr()
178 }
179
180 fn peer_addr(&self) -> std::io::Result<std::net::SocketAddr> {
181 self.mio_socket.socket.peer_addr()
182 }
183
184 fn ttl(&self) -> std::io::Result<u32> {
185 self.mio_socket.socket.ttl()
186 }
187
188 fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
189 self.mio_socket.socket.set_ttl(ttl)
190 }
191
192 fn join_multicast_v4(
193 &self,
194 multiaddr: &std::net::Ipv4Addr,
195 interface: &std::net::Ipv4Addr,
196 ) -> std::io::Result<()> {
197 self.mio_socket
198 .socket
199 .join_multicast_v4(multiaddr, interface)
200 }
201
202 fn join_multicast_v6(
203 &self,
204 multiaddr: &std::net::Ipv6Addr,
205 interface: u32,
206 ) -> std::io::Result<()> {
207 self.mio_socket
208 .socket
209 .join_multicast_v6(multiaddr, interface)
210 }
211
212 fn leave_multicast_v4(
213 &self,
214 multiaddr: &std::net::Ipv4Addr,
215 interface: &std::net::Ipv4Addr,
216 ) -> std::io::Result<()> {
217 self.mio_socket
218 .socket
219 .leave_multicast_v4(multiaddr, interface)
220 }
221
222 fn leave_multicast_v6(
223 &self,
224 multiaddr: &std::net::Ipv6Addr,
225 interface: u32,
226 ) -> std::io::Result<()> {
227 self.mio_socket
228 .socket
229 .leave_multicast_v6(multiaddr, interface)
230 }
231
232 fn set_broadcast(&self, on: bool) -> std::io::Result<()> {
233 self.mio_socket.socket.set_broadcast(on)
234 }
235
236 fn broadcast(&self) -> std::io::Result<bool> {
237 self.mio_socket.socket.broadcast()
238 }
239
240 fn set_multicast_loop_v4(&self, on: bool) -> std::io::Result<()> {
244 self.mio_socket.socket.set_multicast_loop_v4(on)
245 }
246
247 fn set_multicast_loop_v6(&self, on: bool) -> std::io::Result<()> {
251 self.mio_socket.socket.set_multicast_loop_v6(on)
252 }
253
254 fn multicast_loop_v4(&self) -> std::io::Result<bool> {
256 self.mio_socket.socket.multicast_loop_v4()
257 }
258
259 fn multicast_loop_v6(&self) -> std::io::Result<bool> {
261 self.mio_socket.socket.multicast_loop_v6()
262 }
263
264 fn poll_recv_from(
265 &self,
266 cx: &mut std::task::Context<'_>,
267 buf: &mut [u8],
268 ) -> Poll<std::io::Result<(usize, std::net::SocketAddr)>> {
269 let shutdown = self.shutdown.read().unwrap();
270
271 if shutdown.0 {
272 return Poll::Ready(Err(Error::new(
273 ErrorKind::BrokenPipe,
274 "UdpSocket read shutdown.",
275 )));
276 }
277
278 would_block(
279 self.mio_socket.token,
280 cx.waker().clone(),
281 Interest::READABLE,
282 || self.mio_socket.socket.recv_from(buf),
283 )
284 }
285
286 fn poll_send_to(
287 &self,
288 cx: &mut std::task::Context<'_>,
289 buf: &[u8],
290 peer: std::net::SocketAddr,
291 ) -> Poll<std::io::Result<usize>> {
292 let shutdown = self.shutdown.read().unwrap();
293 if shutdown.1 {
294 return Poll::Ready(Err(Error::new(
295 ErrorKind::BrokenPipe,
296 "UdpSocket write shutdown.",
297 )));
298 }
299
300 would_block(
301 self.mio_socket.token,
302 cx.waker().clone(),
303 Interest::WRITABLE,
304 || self.mio_socket.socket.send_to(buf, peer),
305 )
306 }
307
308 fn shutdown(&self, how: Shutdown) -> std::io::Result<()> {
315 let mut locker = self.shutdown.write().unwrap();
316
317 match how {
318 Shutdown::Read => {
319 locker.0 = true;
320
321 global_reactor().notify(self.mio_socket.token, Interest::READABLE);
322 }
323 Shutdown::Write => {
324 locker.1 = true;
325 global_reactor().notify(self.mio_socket.token, Interest::WRITABLE);
326 }
327 Shutdown::Both => {
328 locker.0 = true;
329 locker.1 = true;
330 global_reactor().notify(
331 self.mio_socket.token,
332 Interest::WRITABLE.add(Interest::READABLE),
333 );
334 }
335 }
336
337 Ok(())
338 }
339}
340
341#[cfg(unix)]
342type MioUnixListener = MioSocket<mio::net::UnixListener>;
343
344#[cfg(unix)]
345impl rasi::net::syscall::unix::DriverUnixListener for MioUnixListener {
346 fn local_addr(&self) -> std::io::Result<std::os::unix::net::SocketAddr> {
347 self.socket.local_addr()
348 }
349
350 fn poll_next(
351 &self,
352 cx: &mut std::task::Context<'_>,
353 ) -> Poll<std::io::Result<(rasi::net::unix::UnixStream, std::os::unix::net::SocketAddr)>> {
354 would_block(
355 self.token,
356 cx.waker().clone(),
357 Interest::READABLE,
358 || match self.socket.accept() {
359 Ok((mut stream, raddr)) => {
360 let token = Token::next();
361
362 global_reactor().register(
363 &mut stream,
364 token,
365 Interest::READABLE.add(Interest::WRITABLE),
366 )?;
367
368 Ok((
369 MioUnixStream {
370 token,
371 socket: stream,
372 }
373 .into(),
374 raddr,
375 ))
376 }
377 Err(err) => Err(err),
378 },
379 )
380 }
381}
382
383#[cfg(unix)]
384type MioUnixStream = MioSocket<mio::net::UnixStream>;
385
386#[cfg(unix)]
387impl rasi::net::syscall::unix::DriverUnixStream for MioUnixStream {
388 fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
389 self.socket.shutdown(how)
390 }
391
392 fn poll_read(
393 &self,
394 cx: &mut std::task::Context<'_>,
395 buf: &mut [u8],
396 ) -> Poll<std::io::Result<usize>> {
397 would_block(self.token, cx.waker().clone(), Interest::READABLE, || {
398 self.deref().read(buf)
399 })
400 }
401
402 fn poll_write(
403 &self,
404 cx: &mut std::task::Context<'_>,
405 buf: &[u8],
406 ) -> Poll<std::io::Result<usize>> {
407 would_block(self.token, cx.waker().clone(), Interest::WRITABLE, || {
408 self.deref().write(buf)
409 })
410 }
411
412 fn poll_ready(&self, _cx: &mut std::task::Context<'_>) -> Poll<std::io::Result<()>> {
413 Poll::Ready(Ok(()))
414 }
415
416 fn local_addr(&self) -> std::io::Result<std::os::unix::net::SocketAddr> {
417 self.socket.local_addr()
418 }
419
420 fn peer_addr(&self) -> std::io::Result<std::os::unix::net::SocketAddr> {
421 self.socket.peer_addr()
422 }
423}
424
425struct MioNetworkDriver;
426
427impl MioNetworkDriver {
428 fn tcp_listener_from_std_socket(
429 &self,
430 std_socket: std::net::TcpListener,
431 ) -> std::io::Result<rasi::net::TcpListener> {
432 let mut socket = mio::net::TcpListener::from_std(std_socket);
433
434 let token = Token::next();
435
436 global_reactor().register(
437 &mut socket,
438 token,
439 Interest::READABLE.add(Interest::WRITABLE),
440 )?;
441
442 Ok(MioTcpListener { token, socket }.into())
443 }
444
445 fn tcp_stream_from_std_socket(
446 &self,
447 std_socket: std::net::TcpStream,
448 ) -> std::io::Result<rasi::net::TcpStream> {
449 let mut socket = mio::net::TcpStream::from_std(std_socket);
450
451 let token = Token::next();
452
453 global_reactor().register(
454 &mut socket,
455 token,
456 Interest::READABLE.add(Interest::WRITABLE),
457 )?;
458
459 return Ok(MioTcpStream { token, socket }.into());
460 }
461
462 fn udp_socket_from_std_socket(
463 &self,
464 std_socket: std::net::UdpSocket,
465 ) -> std::io::Result<rasi::net::UdpSocket> {
466 let mut socket = mio::net::UdpSocket::from_std(std_socket);
467 let token = Token::next();
468
469 global_reactor().register(
470 &mut socket,
471 token,
472 Interest::READABLE.add(Interest::WRITABLE),
473 )?;
474
475 Ok(MioUdpSocket {
476 mio_socket: MioSocket { socket, token },
477 shutdown: RwLock::new((false, false)),
478 }
479 .into())
480 }
481}
482
483impl rasi::net::syscall::Driver for MioNetworkDriver {
484 fn tcp_listen(
485 &self,
486 laddrs: &[std::net::SocketAddr],
487 ) -> std::io::Result<rasi::net::TcpListener> {
488 let std_socket = std::net::TcpListener::bind(laddrs)?;
489
490 std_socket.set_nonblocking(true)?;
491
492 self.tcp_listener_from_std_socket(std_socket)
493 }
494
495 #[cfg(unix)]
496 unsafe fn tcp_listener_from_raw_fd(
497 &self,
498 fd: std::os::fd::RawFd,
499 ) -> std::io::Result<rasi::net::TcpListener> {
500 use std::os::fd::FromRawFd;
501
502 let std_socket = std::net::TcpListener::from_raw_fd(fd);
503
504 std_socket.set_nonblocking(true)?;
505
506 self.tcp_listener_from_std_socket(std_socket)
507 }
508
509 #[cfg(windows)]
510 unsafe fn tcp_listener_from_raw_socket(
511 &self,
512 socket: std::os::windows::io::RawSocket,
513 ) -> std::io::Result<rasi::net::TcpListener> {
514 use std::os::windows::io::FromRawSocket;
515
516 let std_socket = std::net::TcpListener::from_raw_socket(socket);
517
518 std_socket.set_nonblocking(true)?;
519
520 self.tcp_listener_from_std_socket(std_socket)
521 }
522
523 fn tcp_connect(&self, raddr: &std::net::SocketAddr) -> std::io::Result<rasi::net::TcpStream> {
524 log::trace!("tcp_connect, raddr={}", raddr);
525
526 let mut socket = mio::net::TcpStream::connect(raddr.clone())?;
527
528 let token = Token::next();
529
530 global_reactor().register(
531 &mut socket,
532 token,
533 Interest::READABLE.add(Interest::WRITABLE),
534 )?;
535
536 return Ok(MioTcpStream { token, socket }.into());
537 }
538
539 #[cfg(unix)]
540 unsafe fn tcp_stream_from_raw_fd(
541 &self,
542 fd: std::os::fd::RawFd,
543 ) -> std::io::Result<rasi::net::TcpStream> {
544 use std::os::fd::FromRawFd;
545
546 let std_socket = std::net::TcpStream::from_raw_fd(fd);
547
548 std_socket.set_nonblocking(true)?;
549
550 self.tcp_stream_from_std_socket(std_socket)
551 }
552
553 #[cfg(windows)]
554 unsafe fn tcp_stream_from_raw_socket(
555 &self,
556 socket: std::os::windows::io::RawSocket,
557 ) -> std::io::Result<rasi::net::TcpStream> {
558 use std::os::windows::io::FromRawSocket;
559
560 let std_socket = std::net::TcpStream::from_raw_socket(socket);
561
562 std_socket.set_nonblocking(true)?;
563
564 self.tcp_stream_from_std_socket(std_socket)
565 }
566
567 fn udp_bind(&self, laddrs: &[std::net::SocketAddr]) -> std::io::Result<rasi::net::UdpSocket> {
568 let std_socket = std::net::UdpSocket::bind(laddrs)?;
569
570 std_socket.set_nonblocking(true)?;
571
572 self.udp_socket_from_std_socket(std_socket)
573 }
574
575 #[cfg(unix)]
576 unsafe fn udp_from_raw_fd(
577 &self,
578 fd: std::os::fd::RawFd,
579 ) -> std::io::Result<rasi::net::UdpSocket> {
580 use std::os::fd::FromRawFd;
581
582 let std_socket = std::net::UdpSocket::from_raw_fd(fd);
583
584 std_socket.set_nonblocking(true)?;
585
586 self.udp_socket_from_std_socket(std_socket)
587 }
588
589 #[cfg(windows)]
590 unsafe fn udp_from_raw_socket(
591 &self,
592 socket: std::os::windows::io::RawSocket,
593 ) -> std::io::Result<rasi::net::UdpSocket> {
594 use std::os::windows::io::FromRawSocket;
595
596 let std_socket = std::net::UdpSocket::from_raw_socket(socket);
597
598 std_socket.set_nonblocking(true)?;
599
600 self.udp_socket_from_std_socket(std_socket)
601 }
602
603 #[cfg(unix)]
604 fn unix_listen(
605 &self,
606 path: &std::path::Path,
607 ) -> std::io::Result<rasi::net::unix::UnixListener> {
608 let mut socket = mio::net::UnixListener::bind(path)?;
609
610 let token = Token::next();
611
612 global_reactor().register(
613 &mut socket,
614 token,
615 Interest::READABLE.add(Interest::WRITABLE),
616 )?;
617
618 Ok(MioUnixListener { token, socket }.into())
619 }
620
621 #[cfg(unix)]
622 fn unix_connect(&self, path: &std::path::Path) -> std::io::Result<rasi::net::unix::UnixStream> {
623 let mut socket = mio::net::UnixStream::connect(path)?;
624
625 let token = Token::next();
626
627 global_reactor().register(
628 &mut socket,
629 token,
630 Interest::READABLE.add(Interest::WRITABLE),
631 )?;
632
633 Ok(MioUnixStream { token, socket }.into())
634 }
635}
636
637pub fn register_mio_network() {
641 register_network_driver(MioNetworkDriver)
642}
643
#[cfg(test)]
mod tests {
    use rasi_spec::network::run_network_spec;

    use super::*;

    /// Driver instance handed to the spec runners. It is a `static` so that a
    /// `'static` reference is available (NOTE(review): presumably required by
    /// the runners' signatures).
    static NETWORK_DRIVER: MioNetworkDriver = MioNetworkDriver;

    /// Runs rasi's network conformance spec against the mio driver, plus the
    /// IPC spec on Unix targets.
    #[futures_test::test]
    async fn test_network() {
        run_network_spec(&NETWORK_DRIVER).await;

        #[cfg(unix)]
        rasi_spec::ipc::run_ipc_spec(&NETWORK_DRIVER).await;
    }
}