nex_socket/tcp/
async_impl.rs

1use crate::tcp::TcpConfig;
2use socket2::{Domain, Protocol, Socket, Type as SockType};
3use std::io;
4use std::net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream};
5use std::time::Duration;
6use tokio::net::{TcpListener, TcpStream};
7
8/// Asynchronous TCP socket built on top of Tokio.
9#[derive(Debug)]
10pub struct AsyncTcpSocket {
11    socket: Socket,
12}
13
14impl AsyncTcpSocket {
15    /// Create a socket from the given configuration without connecting.
16    pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
17        let socket = Socket::new(
18            config.socket_family.to_domain(),
19            config.socket_type.to_sock_type(),
20            Some(Protocol::TCP),
21        )?;
22
23        socket.set_nonblocking(true)?;
24
25        // Set socket options based on configuration
26        if let Some(flag) = config.reuseaddr {
27            socket.set_reuse_address(flag)?;
28        }
29        if let Some(flag) = config.nodelay {
30            socket.set_nodelay(flag)?;
31        }
32        if let Some(ttl) = config.ttl {
33            socket.set_ttl(ttl)?;
34        }
35        if let Some(hoplimit) = config.hoplimit {
36            socket.set_unicast_hops_v6(hoplimit)?;
37        }
38        if let Some(keepalive) = config.keepalive {
39            socket.set_keepalive(keepalive)?;
40        }
41        if let Some(timeout) = config.read_timeout {
42            socket.set_read_timeout(Some(timeout))?;
43        }
44        if let Some(timeout) = config.write_timeout {
45            socket.set_write_timeout(Some(timeout))?;
46        }
47
48        // Linux: optional interface name
49        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
50        if let Some(iface) = &config.bind_device {
51            socket.bind_device(Some(iface.as_bytes()))?;
52        }
53
54        // bind to the specified address if provided
55        if let Some(addr) = config.bind_addr {
56            socket.bind(&addr.into())?;
57        }
58
59        Ok(Self { socket })
60    }
61
62    /// Create a socket of arbitrary type (STREAM or RAW).
63    pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
64        let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
65        socket.set_nonblocking(true)?;
66        Ok(Self { socket })
67    }
68
69    /// Convenience constructor for an IPv4 STREAM socket.
70    pub fn v4_stream() -> io::Result<Self> {
71        Self::new(Domain::IPV4, SockType::STREAM)
72    }
73
74    /// Convenience constructor for an IPv6 STREAM socket.
75    pub fn v6_stream() -> io::Result<Self> {
76        Self::new(Domain::IPV6, SockType::STREAM)
77    }
78
79    /// IPv4 RAW TCP. Requires administrator privileges.
80    pub fn raw_v4() -> io::Result<Self> {
81        Self::new(Domain::IPV4, SockType::RAW)
82    }
83
84    /// IPv6 RAW TCP. Requires administrator privileges.
85    pub fn raw_v6() -> io::Result<Self> {
86        Self::new(Domain::IPV6, SockType::RAW)
87    }
88
89    /// Connect to the target asynchronously.
90    pub async fn connect(self, target: SocketAddr) -> io::Result<TcpStream> {
91        // call connect
92        match self.socket.connect(&target.into()) {
93            Ok(_) => {
94                // connection completed immediately (rare case)
95                let std_stream: StdTcpStream = self.socket.into();
96                return TcpStream::from_std(std_stream);
97            }
98            Err(e)
99                if e.kind() == io::ErrorKind::WouldBlock
100                    || e.raw_os_error() == Some(libc::EINPROGRESS) =>
101            {
102                // wait until writable
103                let std_stream: StdTcpStream = self.socket.into();
104                let stream = TcpStream::from_std(std_stream)?;
105                stream.writable().await?;
106
107                // check the final connection state with SO_ERROR
108                if let Some(err) = stream.take_error()? {
109                    return Err(err);
110                }
111
112                return Ok(stream);
113            }
114            Err(e) => {
115                println!("Failed to connect: {}", e);
116                return Err(e);
117            }
118        }
119    }
120
121    /// Connect with a timeout to the target address.
122    pub async fn connect_timeout(
123        self,
124        target: SocketAddr,
125        timeout: Duration,
126    ) -> io::Result<TcpStream> {
127        match tokio::time::timeout(timeout, self.connect(target)).await {
128            Ok(result) => result,
129            Err(_) => Err(io::Error::new(
130                io::ErrorKind::TimedOut,
131                "connection timed out",
132            )),
133        }
134    }
135
136    /// Start listening for incoming connections.
137    pub fn listen(self, backlog: i32) -> io::Result<TcpListener> {
138        self.socket.listen(backlog)?;
139
140        let std_listener: StdTcpListener = self.socket.into();
141        TcpListener::from_std(std_listener)
142    }
143
144    /// Send a raw TCP packet. Requires `SockType::RAW`.
145    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
146        self.socket.send_to(buf, &target.into())
147    }
148
149    /// Receive a raw TCP packet. Requires `SockType::RAW`.
150    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
151        // Safety: `MaybeUninit<u8>` has the same memory layout as `u8`.
152        let buf_maybe = unsafe {
153            std::slice::from_raw_parts_mut(
154                buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
155                buf.len(),
156            )
157        };
158
159        let (n, addr) = self.socket.recv_from(buf_maybe)?;
160        let addr = addr
161            .as_socket()
162            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
163
164        Ok((n, addr))
165    }
166
167    /// Shutdown the socket.
168    pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
169        self.socket.shutdown(how)
170    }
171
172    /// Set reuse address option.
173    pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
174        self.socket.set_reuse_address(on)
175    }
176
177    /// Set no delay option for TCP.
178    pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
179        self.socket.set_nodelay(on)
180    }
181
182    /// Set linger option for the socket.
183    pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
184        self.socket.set_linger(dur)
185    }
186
187    /// Set the time-to-live for IPv4 packets.
188    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
189        self.socket.set_ttl(ttl)
190    }
191
192    /// Set the hop limit for IPv6 packets.
193    pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
194        self.socket.set_unicast_hops_v6(hops)
195    }
196
197    /// Set the keepalive option for the socket.
198    pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
199        self.socket.set_keepalive(on)
200    }
201
202    /// Set the bind device for the socket (Linux specific).
203    pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
204        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
205        return self.socket.bind_device(Some(iface.as_bytes()));
206
207        #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
208        {
209            let _ = iface;
210            Err(io::Error::new(
211                io::ErrorKind::Unsupported,
212                "bind_device not supported on this OS",
213            ))
214        }
215    }
216
217    /// Retrieve the local address of the socket.
218    pub fn local_addr(&self) -> io::Result<SocketAddr> {
219        self.socket
220            .local_addr()?
221            .as_socket()
222            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed to get socket address"))
223    }
224
225    /// Convert the internal socket into a Tokio `TcpStream`.
226    pub fn into_tokio_stream(self) -> io::Result<TcpStream> {
227        let std_stream: StdTcpStream = self.socket.into();
228        TcpStream::from_std(std_stream)
229    }
230
231    /// Extract the RAW file descriptor for Unix.
232    #[cfg(unix)]
233    pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
234        use std::os::fd::AsRawFd;
235        self.socket.as_raw_fd()
236    }
237
238    /// Extract the RAW socket handle for Windows.
239    #[cfg(windows)]
240    pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
241        use std::os::windows::io::AsRawSocket;
242        self.socket.as_raw_socket()
243    }
244}