// nex_socket/tcp/async_impl.rs

1use crate::tcp::TcpConfig;
2use socket2::{Domain, Protocol, Socket, Type as SockType};
3use std::io;
4use std::net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream};
5use std::time::Duration;
6use tokio::net::{TcpListener, TcpStream};
7
8/// Asynchronous TCP socket built on top of Tokio.
9#[derive(Debug)]
10pub struct AsyncTcpSocket {
11    socket: Socket,
12}
13
14impl AsyncTcpSocket {
15    /// Create a socket from the given configuration without connecting.
16    pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
17        let socket = Socket::new(
18            config.socket_family.to_domain(),
19            config.socket_type.to_sock_type(),
20            Some(Protocol::TCP),
21        )?;
22
23        socket.set_nonblocking(true)?;
24
25        // Set socket options based on configuration
26        if let Some(flag) = config.reuseaddr {
27            socket.set_reuse_address(flag)?;
28        }
29        #[cfg(any(
30            target_os = "android",
31            target_os = "dragonfly",
32            target_os = "freebsd",
33            target_os = "fuchsia",
34            target_os = "ios",
35            target_os = "linux",
36            target_os = "macos",
37            target_os = "netbsd",
38            target_os = "openbsd",
39            target_os = "tvos",
40            target_os = "visionos",
41            target_os = "watchos"
42        ))]
43        if let Some(flag) = config.reuseport {
44            socket.set_reuse_port(flag)?;
45        }
46        if let Some(flag) = config.nodelay {
47            socket.set_nodelay(flag)?;
48        }
49        if let Some(ttl) = config.ttl {
50            socket.set_ttl(ttl)?;
51        }
52        if let Some(hoplimit) = config.hoplimit {
53            socket.set_unicast_hops_v6(hoplimit)?;
54        }
55        if let Some(keepalive) = config.keepalive {
56            socket.set_keepalive(keepalive)?;
57        }
58        if let Some(timeout) = config.read_timeout {
59            socket.set_read_timeout(Some(timeout))?;
60        }
61        if let Some(timeout) = config.write_timeout {
62            socket.set_write_timeout(Some(timeout))?;
63        }
64        if let Some(size) = config.recv_buffer_size {
65            socket.set_recv_buffer_size(size)?;
66        }
67        if let Some(size) = config.send_buffer_size {
68            socket.set_send_buffer_size(size)?;
69        }
70        if let Some(tos) = config.tos {
71            socket.set_tos(tos)?;
72        }
73        #[cfg(any(
74            target_os = "android",
75            target_os = "dragonfly",
76            target_os = "freebsd",
77            target_os = "fuchsia",
78            target_os = "ios",
79            target_os = "linux",
80            target_os = "macos",
81            target_os = "netbsd",
82            target_os = "openbsd",
83            target_os = "tvos",
84            target_os = "visionos",
85            target_os = "watchos"
86        ))]
87        if let Some(tclass) = config.tclass_v6 {
88            socket.set_tclass_v6(tclass)?;
89        }
90        if let Some(only_v6) = config.only_v6 {
91            socket.set_only_v6(only_v6)?;
92        }
93
94        // Linux: optional interface name
95        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
96        if let Some(iface) = &config.bind_device {
97            socket.bind_device(Some(iface.as_bytes()))?;
98        }
99
100        // bind to the specified address if provided
101        if let Some(addr) = config.bind_addr {
102            socket.bind(&addr.into())?;
103        }
104
105        Ok(Self { socket })
106    }
107
108    /// Create a socket of arbitrary type (STREAM or RAW).
109    pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
110        let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
111        socket.set_nonblocking(true)?;
112        Ok(Self { socket })
113    }
114
115    /// Convenience constructor for an IPv4 STREAM socket.
116    pub fn v4_stream() -> io::Result<Self> {
117        Self::new(Domain::IPV4, SockType::STREAM)
118    }
119
120    /// Convenience constructor for an IPv6 STREAM socket.
121    pub fn v6_stream() -> io::Result<Self> {
122        Self::new(Domain::IPV6, SockType::STREAM)
123    }
124
125    /// IPv4 RAW TCP. Requires administrator privileges.
126    pub fn raw_v4() -> io::Result<Self> {
127        Self::new(Domain::IPV4, SockType::RAW)
128    }
129
130    /// IPv6 RAW TCP. Requires administrator privileges.
131    pub fn raw_v6() -> io::Result<Self> {
132        Self::new(Domain::IPV6, SockType::RAW)
133    }
134
135    /// Connect to the target asynchronously.
136    pub async fn connect(self, target: SocketAddr) -> io::Result<TcpStream> {
137        // call connect
138        match self.socket.connect(&target.into()) {
139            Ok(_) => {
140                // connection completed immediately (rare case)
141                let std_stream: StdTcpStream = self.socket.into();
142                return TcpStream::from_std(std_stream);
143            }
144            Err(e)
145                if e.kind() == io::ErrorKind::WouldBlock
146                    || e.raw_os_error() == Some(libc::EINPROGRESS) =>
147            {
148                // wait until writable
149                let std_stream: StdTcpStream = self.socket.into();
150                let stream = TcpStream::from_std(std_stream)?;
151                stream.writable().await?;
152
153                // check the final connection state with SO_ERROR
154                if let Some(err) = stream.take_error()? {
155                    return Err(err);
156                }
157
158                return Ok(stream);
159            }
160            Err(e) => {
161                return Err(e);
162            }
163        }
164    }
165
166    /// Connect with a timeout to the target address.
167    pub async fn connect_timeout(
168        self,
169        target: SocketAddr,
170        timeout: Duration,
171    ) -> io::Result<TcpStream> {
172        match tokio::time::timeout(timeout, self.connect(target)).await {
173            Ok(result) => result,
174            Err(_) => Err(io::Error::new(
175                io::ErrorKind::TimedOut,
176                "connection timed out",
177            )),
178        }
179    }
180
181    /// Start listening for incoming connections.
182    pub fn listen(self, backlog: i32) -> io::Result<TcpListener> {
183        self.socket.listen(backlog)?;
184
185        let std_listener: StdTcpListener = self.socket.into();
186        TcpListener::from_std(std_listener)
187    }
188
189    /// Send a raw TCP packet. Requires `SockType::RAW`.
190    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
191        self.socket.send_to(buf, &target.into())
192    }
193
194    /// Receive a raw TCP packet. Requires `SockType::RAW`.
195    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
196        // Safety: `MaybeUninit<u8>` has the same memory layout as `u8`.
197        let buf_maybe = unsafe {
198            std::slice::from_raw_parts_mut(
199                buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
200                buf.len(),
201            )
202        };
203
204        let (n, addr) = self.socket.recv_from(buf_maybe)?;
205        let addr = addr
206            .as_socket()
207            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
208
209        Ok((n, addr))
210    }
211
212    /// Shutdown the socket.
213    pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
214        self.socket.shutdown(how)
215    }
216
217    /// Set reuse address option.
218    pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
219        self.socket.set_reuse_address(on)
220    }
221
222    /// Get reuse address option.
223    pub fn reuseaddr(&self) -> io::Result<bool> {
224        self.socket.reuse_address()
225    }
226
227    /// Set port reuse option where supported.
228    #[cfg(any(
229        target_os = "android",
230        target_os = "dragonfly",
231        target_os = "freebsd",
232        target_os = "fuchsia",
233        target_os = "ios",
234        target_os = "linux",
235        target_os = "macos",
236        target_os = "netbsd",
237        target_os = "openbsd",
238        target_os = "tvos",
239        target_os = "visionos",
240        target_os = "watchos"
241    ))]
242    pub fn set_reuseport(&self, on: bool) -> io::Result<()> {
243        self.socket.set_reuse_port(on)
244    }
245
246    /// Get port reuse option where supported.
247    #[cfg(any(
248        target_os = "android",
249        target_os = "dragonfly",
250        target_os = "freebsd",
251        target_os = "fuchsia",
252        target_os = "ios",
253        target_os = "linux",
254        target_os = "macos",
255        target_os = "netbsd",
256        target_os = "openbsd",
257        target_os = "tvos",
258        target_os = "visionos",
259        target_os = "watchos"
260    ))]
261    pub fn reuseport(&self) -> io::Result<bool> {
262        self.socket.reuse_port()
263    }
264
265    /// Set no delay option for TCP.
266    pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
267        self.socket.set_nodelay(on)
268    }
269
270    /// Get no delay option for TCP.
271    pub fn nodelay(&self) -> io::Result<bool> {
272        self.socket.nodelay()
273    }
274
275    /// Set linger option for the socket.
276    pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
277        self.socket.set_linger(dur)
278    }
279
280    /// Set the time-to-live for IPv4 packets.
281    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
282        self.socket.set_ttl(ttl)
283    }
284
285    /// Get the time-to-live for IPv4 packets.
286    pub fn ttl(&self) -> io::Result<u32> {
287        self.socket.ttl()
288    }
289
290    /// Set the hop limit for IPv6 packets.
291    pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
292        self.socket.set_unicast_hops_v6(hops)
293    }
294
295    /// Get the hop limit for IPv6 packets.
296    pub fn hoplimit(&self) -> io::Result<u32> {
297        self.socket.unicast_hops_v6()
298    }
299
300    /// Set the keepalive option for the socket.
301    pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
302        self.socket.set_keepalive(on)
303    }
304
305    /// Get the keepalive option for the socket.
306    pub fn keepalive(&self) -> io::Result<bool> {
307        self.socket.keepalive()
308    }
309
310    /// Set the receive buffer size.
311    pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
312        self.socket.set_recv_buffer_size(size)
313    }
314
315    /// Get the receive buffer size.
316    pub fn recv_buffer_size(&self) -> io::Result<usize> {
317        self.socket.recv_buffer_size()
318    }
319
320    /// Set the send buffer size.
321    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
322        self.socket.set_send_buffer_size(size)
323    }
324
325    /// Get the send buffer size.
326    pub fn send_buffer_size(&self) -> io::Result<usize> {
327        self.socket.send_buffer_size()
328    }
329
330    /// Set IPv4 TOS / DSCP.
331    pub fn set_tos(&self, tos: u32) -> io::Result<()> {
332        self.socket.set_tos(tos)
333    }
334
335    /// Get IPv4 TOS / DSCP.
336    pub fn tos(&self) -> io::Result<u32> {
337        self.socket.tos()
338    }
339
340    /// Set IPv6 traffic class where supported.
341    #[cfg(any(
342        target_os = "android",
343        target_os = "dragonfly",
344        target_os = "freebsd",
345        target_os = "fuchsia",
346        target_os = "ios",
347        target_os = "linux",
348        target_os = "macos",
349        target_os = "netbsd",
350        target_os = "openbsd",
351        target_os = "tvos",
352        target_os = "visionos",
353        target_os = "watchos"
354    ))]
355    pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> {
356        self.socket.set_tclass_v6(tclass)
357    }
358
359    /// Get IPv6 traffic class where supported.
360    #[cfg(any(
361        target_os = "android",
362        target_os = "dragonfly",
363        target_os = "freebsd",
364        target_os = "fuchsia",
365        target_os = "ios",
366        target_os = "linux",
367        target_os = "macos",
368        target_os = "netbsd",
369        target_os = "openbsd",
370        target_os = "tvos",
371        target_os = "visionos",
372        target_os = "watchos"
373    ))]
374    pub fn tclass_v6(&self) -> io::Result<u32> {
375        self.socket.tclass_v6()
376    }
377
378    /// Set whether this socket is IPv6 only.
379    pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
380        self.socket.set_only_v6(only_v6)
381    }
382
383    /// Get whether this socket is IPv6 only.
384    pub fn only_v6(&self) -> io::Result<bool> {
385        self.socket.only_v6()
386    }
387
388    /// Set the bind device for the socket (Linux specific).
389    pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
390        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
391        return self.socket.bind_device(Some(iface.as_bytes()));
392
393        #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
394        {
395            let _ = iface;
396            Err(io::Error::new(
397                io::ErrorKind::Unsupported,
398                "bind_device is not supported on this platform",
399            ))
400        }
401    }
402
403    /// Retrieve the local address of the socket.
404    pub fn local_addr(&self) -> io::Result<SocketAddr> {
405        self.socket
406            .local_addr()?
407            .as_socket()
408            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "failed to retrieve local address"))
409    }
410
411    /// Convert the internal socket into a Tokio `TcpStream`.
412    pub fn into_tokio_stream(self) -> io::Result<TcpStream> {
413        let std_stream: StdTcpStream = self.socket.into();
414        TcpStream::from_std(std_stream)
415    }
416
417    /// Construct from a raw `socket2::Socket`.
418    pub fn from_socket(socket: Socket) -> Self {
419        Self { socket }
420    }
421
422    /// Borrow the inner `socket2::Socket`.
423    pub fn socket(&self) -> &Socket {
424        &self.socket
425    }
426
427    /// Consume and return the inner `socket2::Socket`.
428    pub fn into_socket(self) -> Socket {
429        self.socket
430    }
431
432    /// Extract the RAW file descriptor for Unix.
433    #[cfg(unix)]
434    pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
435        use std::os::fd::AsRawFd;
436        self.socket.as_raw_fd()
437    }
438
439    /// Extract the RAW socket handle for Windows.
440    #[cfg(windows)]
441    pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
442        use std::os::windows::io::AsRawSocket;
443        self.socket.as_raw_socket()
444    }
445}