nex_socket/tcp/async_impl.rs

1use crate::tcp::TcpConfig;
2use socket2::{Domain, Protocol, Socket, Type as SockType};
3use std::io;
4use std::net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream};
5use std::time::Duration;
6use tokio::net::{TcpListener, TcpStream};
7
8/// Asynchronous TCP socket built on top of Tokio.
9#[derive(Debug)]
10pub struct AsyncTcpSocket {
11    socket: Socket,
12}
13
14impl AsyncTcpSocket {
15    /// Create a socket from the given configuration without connecting.
16    pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
17        let socket = Socket::new(
18            config.socket_family.to_domain(),
19            config.socket_type.to_sock_type(),
20            Some(Protocol::TCP),
21        )?;
22
23        socket.set_nonblocking(true)?;
24
25        // Set socket options based on configuration
26        if let Some(flag) = config.reuseaddr {
27            socket.set_reuse_address(flag)?;
28        }
29        if let Some(flag) = config.nodelay {
30            socket.set_nodelay(flag)?;
31        }
32        if let Some(ttl) = config.ttl {
33            socket.set_ttl(ttl)?;
34        }
35        if let Some(hoplimit) = config.hoplimit {
36            socket.set_unicast_hops_v6(hoplimit)?;
37        }
38        if let Some(keepalive) = config.keepalive {
39            socket.set_keepalive(keepalive)?;
40        }
41        if let Some(timeout) = config.read_timeout {
42            socket.set_read_timeout(Some(timeout))?;
43        }
44        if let Some(timeout) = config.write_timeout {
45            socket.set_write_timeout(Some(timeout))?;
46        }
47
48        // Linux: optional interface name
49        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
50        if let Some(iface) = &config.bind_device {
51            socket.bind_device(Some(iface.as_bytes()))?;
52        }
53
54        // bind to the specified address if provided
55        if let Some(addr) = config.bind_addr {
56            socket.bind(&addr.into())?;
57        }
58
59        Ok(Self { socket })
60    }
61
62    /// Create a socket of arbitrary type (STREAM or RAW).
63    pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
64        let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
65        socket.set_nonblocking(true)?;
66        Ok(Self { socket })
67    }
68
69    /// Convenience constructor for an IPv4 STREAM socket.
70    pub fn v4_stream() -> io::Result<Self> {
71        Self::new(Domain::IPV4, SockType::STREAM)
72    }
73
74    /// Convenience constructor for an IPv6 STREAM socket.
75    pub fn v6_stream() -> io::Result<Self> {
76        Self::new(Domain::IPV6, SockType::STREAM)
77    }
78
79    /// IPv4 RAW TCP. Requires administrator privileges.
80    pub fn raw_v4() -> io::Result<Self> {
81        Self::new(Domain::IPV4, SockType::RAW)
82    }
83
84    /// IPv6 RAW TCP. Requires administrator privileges.
85    pub fn raw_v6() -> io::Result<Self> {
86        Self::new(Domain::IPV6, SockType::RAW)
87    }
88
89    /// Connect to the target asynchronously.
90    pub async fn connect(self, target: SocketAddr) -> io::Result<TcpStream> {
91        // call connect
92        match self.socket.connect(&target.into()) {
93            Ok(_) => {
94                // connection completed immediately (rare case)
95                let std_stream: StdTcpStream = self.socket.into();
96                return TcpStream::from_std(std_stream);
97            }
98            Err(e)
99                if e.kind() == io::ErrorKind::WouldBlock
100                    || e.raw_os_error() == Some(libc::EINPROGRESS) =>
101            {
102                // wait until writable
103                let std_stream: StdTcpStream = self.socket.into();
104                let stream = TcpStream::from_std(std_stream)?;
105                stream.writable().await?;
106
107                // check the final connection state with SO_ERROR
108                if let Some(err) = stream.take_error()? {
109                    return Err(err);
110                }
111
112                return Ok(stream);
113            }
114            Err(e) => {
115                return Err(e);
116            }
117        }
118    }
119
120    /// Connect with a timeout to the target address.
121    pub async fn connect_timeout(
122        self,
123        target: SocketAddr,
124        timeout: Duration,
125    ) -> io::Result<TcpStream> {
126        match tokio::time::timeout(timeout, self.connect(target)).await {
127            Ok(result) => result,
128            Err(_) => Err(io::Error::new(
129                io::ErrorKind::TimedOut,
130                "connection timed out",
131            )),
132        }
133    }
134
135    /// Start listening for incoming connections.
136    pub fn listen(self, backlog: i32) -> io::Result<TcpListener> {
137        self.socket.listen(backlog)?;
138
139        let std_listener: StdTcpListener = self.socket.into();
140        TcpListener::from_std(std_listener)
141    }
142
143    /// Send a raw TCP packet. Requires `SockType::RAW`.
144    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
145        self.socket.send_to(buf, &target.into())
146    }
147
148    /// Receive a raw TCP packet. Requires `SockType::RAW`.
149    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
150        // Safety: `MaybeUninit<u8>` has the same memory layout as `u8`.
151        let buf_maybe = unsafe {
152            std::slice::from_raw_parts_mut(
153                buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
154                buf.len(),
155            )
156        };
157
158        let (n, addr) = self.socket.recv_from(buf_maybe)?;
159        let addr = addr
160            .as_socket()
161            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
162
163        Ok((n, addr))
164    }
165
166    /// Shutdown the socket.
167    pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
168        self.socket.shutdown(how)
169    }
170
171    /// Set reuse address option.
172    pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
173        self.socket.set_reuse_address(on)
174    }
175
176    /// Set no delay option for TCP.
177    pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
178        self.socket.set_nodelay(on)
179    }
180
181    /// Set linger option for the socket.
182    pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
183        self.socket.set_linger(dur)
184    }
185
186    /// Set the time-to-live for IPv4 packets.
187    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
188        self.socket.set_ttl(ttl)
189    }
190
191    /// Set the hop limit for IPv6 packets.
192    pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
193        self.socket.set_unicast_hops_v6(hops)
194    }
195
196    /// Set the keepalive option for the socket.
197    pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
198        self.socket.set_keepalive(on)
199    }
200
201    /// Set the bind device for the socket (Linux specific).
202    pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
203        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
204        return self.socket.bind_device(Some(iface.as_bytes()));
205
206        #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
207        {
208            let _ = iface;
209            Err(io::Error::new(
210                io::ErrorKind::Unsupported,
211                "bind_device not supported on this OS",
212            ))
213        }
214    }
215
216    /// Retrieve the local address of the socket.
217    pub fn local_addr(&self) -> io::Result<SocketAddr> {
218        self.socket
219            .local_addr()?
220            .as_socket()
221            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed to get socket address"))
222    }
223
224    /// Convert the internal socket into a Tokio `TcpStream`.
225    pub fn into_tokio_stream(self) -> io::Result<TcpStream> {
226        let std_stream: StdTcpStream = self.socket.into();
227        TcpStream::from_std(std_stream)
228    }
229
230    /// Extract the RAW file descriptor for Unix.
231    #[cfg(unix)]
232    pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
233        use std::os::fd::AsRawFd;
234        self.socket.as_raw_fd()
235    }
236
237    /// Extract the RAW socket handle for Windows.
238    #[cfg(windows)]
239    pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
240        use std::os::windows::io::AsRawSocket;
241        self.socket.as_raw_socket()
242    }
243}