nex_socket/tcp/
async_impl.rs

1use crate::tcp::TcpConfig;
2use socket2::{Domain, Protocol, Socket, Type as SockType};
3use std::io;
4use std::net::{SocketAddr, TcpStream as StdTcpStream, TcpListener as StdTcpListener};
5use std::time::Duration;
6use tokio::net::{TcpListener, TcpStream};
7
8/// Asynchronous TCP socket built on top of Tokio.
9#[derive(Debug)]
10pub struct AsyncTcpSocket {
11    socket: Socket,
12}
13
14impl AsyncTcpSocket {
15    /// Create a socket from the given configuration without connecting.
16    pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
17        let socket = Socket::new(config.domain, config.sock_type, Some(Protocol::TCP))?;
18
19        if let Some(flag) = config.reuseaddr {
20            socket.set_reuse_address(flag)?;
21        }
22        if let Some(flag) = config.nodelay {
23            socket.set_nodelay(flag)?;
24        }
25        if let Some(ttl) = config.ttl {
26            socket.set_ttl(ttl)?;
27        }
28
29        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
30        if let Some(iface) = &config.bind_device {
31            socket.bind_device(Some(iface.as_bytes()))?;
32        }
33
34        if let Some(addr) = config.bind_addr {
35            socket.bind(&addr.into())?;
36        }
37
38        socket.set_nonblocking(true)?;
39
40        Ok(Self { socket })
41    }
42
43    /// Create a socket of arbitrary type (STREAM or RAW).
44    pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
45        let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
46        socket.set_nonblocking(true)?;
47        Ok(Self { socket })
48    }
49
50    /// Convenience constructor for an IPv4 STREAM socket.
51    pub fn v4_stream() -> io::Result<Self> {
52        Self::new(Domain::IPV4, SockType::STREAM)
53    }
54
55    /// Convenience constructor for an IPv6 STREAM socket.
56    pub fn v6_stream() -> io::Result<Self> {
57        Self::new(Domain::IPV6, SockType::STREAM)
58    }
59
60    /// IPv4 RAW TCP. Requires administrator privileges.
61    pub fn raw_v4() -> io::Result<Self> {
62        Self::new(Domain::IPV4, SockType::RAW)
63    }
64
65    /// IPv6 RAW TCP. Requires administrator privileges.
66    pub fn raw_v6() -> io::Result<Self> {
67        Self::new(Domain::IPV6, SockType::RAW)
68    }
69
70    /// Connect to the target asynchronously.
71    pub async fn connect(self, target: SocketAddr) -> io::Result<TcpStream> {
72        // call connect
73        match self.socket.connect(&target.into()) {
74            Ok(_) => {
75                // connection completed immediately (rare case)
76                let std_stream: StdTcpStream = self.socket.into();
77                return TcpStream::from_std(std_stream);
78            }
79            Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(libc::EINPROGRESS) => {
80                // wait until writable
81                let std_stream: StdTcpStream = self.socket.into();
82                let stream = TcpStream::from_std(std_stream)?;
83                stream.writable().await?;
84
85                // check the final connection state with SO_ERROR
86                if let Some(err) = stream.take_error()? {
87                    return Err(err);
88                }
89
90                return Ok(stream);
91            }
92            Err(e) => {
93                println!("Failed to connect: {}", e);
94                return Err(e);
95            }
96        }
97    }
98
99    /// Connect with a timeout to the target address.
100    pub async fn connect_timeout(self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
101        match tokio::time::timeout(timeout, self.connect(target)).await {
102            Ok(result) => result,
103            Err(_) => Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out")),
104        }
105    }
106
107    /// Start listening for incoming connections.
108    pub fn listen(self, backlog: i32) -> io::Result<TcpListener> {
109        self.socket.listen(backlog)?;
110
111        let std_listener: StdTcpListener = self.socket.into();
112        TcpListener::from_std(std_listener)
113    }
114
115    /// Send a raw TCP packet. Requires `SockType::RAW`.
116    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
117        self.socket.send_to(buf, &target.into())
118    }
119
120    /// Receive a raw TCP packet. Requires `SockType::RAW`.
121    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
122        // Safety: `MaybeUninit<u8>` has the same memory layout as `u8`.
123        let buf_maybe = unsafe {
124            std::slice::from_raw_parts_mut(
125                buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
126                buf.len(),
127            )
128        };
129
130        let (n, addr) = self.socket.recv_from(buf_maybe)?;
131        let addr = addr.as_socket().ok_or_else(|| {
132            io::Error::new(io::ErrorKind::InvalidData, "invalid address format")
133        })?;
134
135        Ok((n, addr))
136    }
137
138    // --- option helpers ---
139
140    pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
141        self.socket.set_reuse_address(on)
142    }
143
144    pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
145        self.socket.set_nodelay(on)
146    }
147
148    pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
149        self.socket.set_linger(dur)
150    }
151
152    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
153        self.socket.set_ttl(ttl)
154    }
155
156    pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
157        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
158        return self.socket.bind_device(Some(iface.as_bytes()));
159
160        #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
161        {
162            let _ = iface;
163            Err(io::Error::new(io::ErrorKind::Unsupported, "bind_device not supported on this OS"))
164        }
165    }
166
167    /// Retrieve the local address of the socket.
168    pub fn local_addr(&self) -> io::Result<SocketAddr> {
169        self.socket.local_addr()?.as_socket().ok_or_else(|| {
170            io::Error::new(io::ErrorKind::Other, "Failed to get socket address")
171        })
172    }
173
174    /// Convert the internal socket into a Tokio `TcpStream`.
175    pub fn into_tokio_stream(self) -> io::Result<TcpStream> {
176        let std_stream: StdTcpStream = self.socket.into();
177        TcpStream::from_std(std_stream)
178    }
179
180    #[cfg(unix)]
181    pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
182        use std::os::fd::AsRawFd;
183        self.socket.as_raw_fd()
184    }
185
186    #[cfg(windows)]
187    pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
188        use std::os::windows::io::AsRawSocket;
189        self.socket.as_raw_socket()
190    }
191}