Skip to main content

nex_socket/tcp/
async_impl.rs

1use crate::tcp::TcpConfig;
2use socket2::{Domain, Protocol, Socket, Type as SockType};
3use std::io;
4use std::net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream};
5use std::time::Duration;
6use tokio::net::{TcpListener, TcpStream};
7
8/// Asynchronous TCP socket built on top of Tokio.
9#[derive(Debug)]
10pub struct AsyncTcpSocket {
11    socket: Socket,
12}
13
14impl AsyncTcpSocket {
15    /// Create a socket from the given configuration without connecting.
16    pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
17        config.validate()?;
18
19        let socket = Socket::new(
20            config.socket_family.to_domain(),
21            config.socket_type.to_sock_type(),
22            Some(Protocol::TCP),
23        )?;
24
25        socket.set_nonblocking(true)?;
26
27        // Set socket options based on configuration
28        if let Some(flag) = config.reuseaddr {
29            socket.set_reuse_address(flag)?;
30        }
31        #[cfg(any(
32            target_os = "android",
33            target_os = "dragonfly",
34            target_os = "freebsd",
35            target_os = "fuchsia",
36            target_os = "ios",
37            target_os = "linux",
38            target_os = "macos",
39            target_os = "netbsd",
40            target_os = "openbsd",
41            target_os = "tvos",
42            target_os = "visionos",
43            target_os = "watchos"
44        ))]
45        if let Some(flag) = config.reuseport {
46            socket.set_reuse_port(flag)?;
47        }
48        if let Some(flag) = config.nodelay {
49            socket.set_nodelay(flag)?;
50        }
51        if let Some(ttl) = config.ttl {
52            socket.set_ttl(ttl)?;
53        }
54        if let Some(hoplimit) = config.hoplimit {
55            socket.set_unicast_hops_v6(hoplimit)?;
56        }
57        if let Some(keepalive) = config.keepalive {
58            socket.set_keepalive(keepalive)?;
59        }
60        if let Some(timeout) = config.read_timeout {
61            socket.set_read_timeout(Some(timeout))?;
62        }
63        if let Some(timeout) = config.write_timeout {
64            socket.set_write_timeout(Some(timeout))?;
65        }
66        if let Some(size) = config.recv_buffer_size {
67            socket.set_recv_buffer_size(size)?;
68        }
69        if let Some(size) = config.send_buffer_size {
70            socket.set_send_buffer_size(size)?;
71        }
72        if let Some(tos) = config.tos {
73            socket.set_tos(tos)?;
74        }
75        #[cfg(any(
76            target_os = "android",
77            target_os = "dragonfly",
78            target_os = "freebsd",
79            target_os = "fuchsia",
80            target_os = "ios",
81            target_os = "linux",
82            target_os = "macos",
83            target_os = "netbsd",
84            target_os = "openbsd",
85            target_os = "tvos",
86            target_os = "visionos",
87            target_os = "watchos"
88        ))]
89        if let Some(tclass) = config.tclass_v6 {
90            socket.set_tclass_v6(tclass)?;
91        }
92        if let Some(only_v6) = config.only_v6 {
93            socket.set_only_v6(only_v6)?;
94        }
95
96        // Linux: optional interface name
97        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
98        if let Some(iface) = &config.bind_device {
99            socket.bind_device(Some(iface.as_bytes()))?;
100        }
101
102        // bind to the specified address if provided
103        if let Some(addr) = config.bind_addr {
104            socket.bind(&addr.into())?;
105        }
106
107        Ok(Self { socket })
108    }
109
110    /// Create a socket of arbitrary type (STREAM or RAW).
111    pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
112        let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
113        socket.set_nonblocking(true)?;
114        Ok(Self { socket })
115    }
116
117    /// Convenience constructor for an IPv4 STREAM socket.
118    pub fn v4_stream() -> io::Result<Self> {
119        Self::new(Domain::IPV4, SockType::STREAM)
120    }
121
122    /// Convenience constructor for an IPv6 STREAM socket.
123    pub fn v6_stream() -> io::Result<Self> {
124        Self::new(Domain::IPV6, SockType::STREAM)
125    }
126
127    /// IPv4 RAW TCP. Requires administrator privileges.
128    pub fn raw_v4() -> io::Result<Self> {
129        Self::new(Domain::IPV4, SockType::RAW)
130    }
131
132    /// IPv6 RAW TCP. Requires administrator privileges.
133    pub fn raw_v6() -> io::Result<Self> {
134        Self::new(Domain::IPV6, SockType::RAW)
135    }
136
137    /// Connect to the target asynchronously.
138    pub async fn connect(self, target: SocketAddr) -> io::Result<TcpStream> {
139        // call connect
140        match self.socket.connect(&target.into()) {
141            Ok(_) => {
142                // connection completed immediately (rare case)
143                let std_stream: StdTcpStream = self.socket.into();
144                return TcpStream::from_std(std_stream);
145            }
146            Err(e)
147                if e.kind() == io::ErrorKind::WouldBlock
148                    || e.raw_os_error() == Some(libc::EINPROGRESS) =>
149            {
150                // wait until writable
151                let std_stream: StdTcpStream = self.socket.into();
152                let stream = TcpStream::from_std(std_stream)?;
153                stream.writable().await?;
154
155                // check the final connection state with SO_ERROR
156                if let Some(err) = stream.take_error()? {
157                    return Err(err);
158                }
159
160                return Ok(stream);
161            }
162            Err(e) => {
163                return Err(e);
164            }
165        }
166    }
167
168    /// Connect with a timeout to the target address.
169    pub async fn connect_timeout(
170        self,
171        target: SocketAddr,
172        timeout: Duration,
173    ) -> io::Result<TcpStream> {
174        match tokio::time::timeout(timeout, self.connect(target)).await {
175            Ok(result) => result,
176            Err(_) => Err(io::Error::new(
177                io::ErrorKind::TimedOut,
178                "connection timed out",
179            )),
180        }
181    }
182
183    /// Start listening for incoming connections.
184    pub fn listen(self, backlog: i32) -> io::Result<TcpListener> {
185        self.socket.listen(backlog)?;
186
187        let std_listener: StdTcpListener = self.socket.into();
188        TcpListener::from_std(std_listener)
189    }
190
191    /// Send a raw TCP packet. Requires `SockType::RAW`.
192    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
193        self.socket.send_to(buf, &target.into())
194    }
195
196    /// Receive a raw TCP packet. Requires `SockType::RAW`.
197    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
198        // Safety: `MaybeUninit<u8>` has the same memory layout as `u8`.
199        let buf_maybe = unsafe {
200            std::slice::from_raw_parts_mut(
201                buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
202                buf.len(),
203            )
204        };
205
206        let (n, addr) = self.socket.recv_from(buf_maybe)?;
207        let addr = addr
208            .as_socket()
209            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
210
211        Ok((n, addr))
212    }
213
214    /// Shutdown the socket.
215    pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
216        self.socket.shutdown(how)
217    }
218
219    /// Set reuse address option.
220    pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
221        self.socket.set_reuse_address(on)
222    }
223
224    /// Get reuse address option.
225    pub fn reuseaddr(&self) -> io::Result<bool> {
226        self.socket.reuse_address()
227    }
228
229    /// Set port reuse option where supported.
230    #[cfg(any(
231        target_os = "android",
232        target_os = "dragonfly",
233        target_os = "freebsd",
234        target_os = "fuchsia",
235        target_os = "ios",
236        target_os = "linux",
237        target_os = "macos",
238        target_os = "netbsd",
239        target_os = "openbsd",
240        target_os = "tvos",
241        target_os = "visionos",
242        target_os = "watchos"
243    ))]
244    pub fn set_reuseport(&self, on: bool) -> io::Result<()> {
245        self.socket.set_reuse_port(on)
246    }
247
248    /// Get port reuse option where supported.
249    #[cfg(any(
250        target_os = "android",
251        target_os = "dragonfly",
252        target_os = "freebsd",
253        target_os = "fuchsia",
254        target_os = "ios",
255        target_os = "linux",
256        target_os = "macos",
257        target_os = "netbsd",
258        target_os = "openbsd",
259        target_os = "tvos",
260        target_os = "visionos",
261        target_os = "watchos"
262    ))]
263    pub fn reuseport(&self) -> io::Result<bool> {
264        self.socket.reuse_port()
265    }
266
267    /// Set no delay option for TCP.
268    pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
269        self.socket.set_nodelay(on)
270    }
271
272    /// Get no delay option for TCP.
273    pub fn nodelay(&self) -> io::Result<bool> {
274        self.socket.nodelay()
275    }
276
277    /// Set linger option for the socket.
278    pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
279        self.socket.set_linger(dur)
280    }
281
282    /// Set the time-to-live for IPv4 packets.
283    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
284        self.socket.set_ttl(ttl)
285    }
286
287    /// Get the time-to-live for IPv4 packets.
288    pub fn ttl(&self) -> io::Result<u32> {
289        self.socket.ttl()
290    }
291
292    /// Set the hop limit for IPv6 packets.
293    pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
294        self.socket.set_unicast_hops_v6(hops)
295    }
296
297    /// Get the hop limit for IPv6 packets.
298    pub fn hoplimit(&self) -> io::Result<u32> {
299        self.socket.unicast_hops_v6()
300    }
301
302    /// Set the keepalive option for the socket.
303    pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
304        self.socket.set_keepalive(on)
305    }
306
307    /// Get the keepalive option for the socket.
308    pub fn keepalive(&self) -> io::Result<bool> {
309        self.socket.keepalive()
310    }
311
312    /// Set the receive buffer size.
313    pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
314        self.socket.set_recv_buffer_size(size)
315    }
316
317    /// Get the receive buffer size.
318    pub fn recv_buffer_size(&self) -> io::Result<usize> {
319        self.socket.recv_buffer_size()
320    }
321
322    /// Set the send buffer size.
323    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
324        self.socket.set_send_buffer_size(size)
325    }
326
327    /// Get the send buffer size.
328    pub fn send_buffer_size(&self) -> io::Result<usize> {
329        self.socket.send_buffer_size()
330    }
331
332    /// Set IPv4 TOS / DSCP.
333    pub fn set_tos(&self, tos: u32) -> io::Result<()> {
334        self.socket.set_tos(tos)
335    }
336
337    /// Get IPv4 TOS / DSCP.
338    pub fn tos(&self) -> io::Result<u32> {
339        self.socket.tos()
340    }
341
342    /// Set IPv6 traffic class where supported.
343    #[cfg(any(
344        target_os = "android",
345        target_os = "dragonfly",
346        target_os = "freebsd",
347        target_os = "fuchsia",
348        target_os = "ios",
349        target_os = "linux",
350        target_os = "macos",
351        target_os = "netbsd",
352        target_os = "openbsd",
353        target_os = "tvos",
354        target_os = "visionos",
355        target_os = "watchos"
356    ))]
357    pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> {
358        self.socket.set_tclass_v6(tclass)
359    }
360
361    /// Get IPv6 traffic class where supported.
362    #[cfg(any(
363        target_os = "android",
364        target_os = "dragonfly",
365        target_os = "freebsd",
366        target_os = "fuchsia",
367        target_os = "ios",
368        target_os = "linux",
369        target_os = "macos",
370        target_os = "netbsd",
371        target_os = "openbsd",
372        target_os = "tvos",
373        target_os = "visionos",
374        target_os = "watchos"
375    ))]
376    pub fn tclass_v6(&self) -> io::Result<u32> {
377        self.socket.tclass_v6()
378    }
379
380    /// Set whether this socket is IPv6 only.
381    pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
382        self.socket.set_only_v6(only_v6)
383    }
384
385    /// Get whether this socket is IPv6 only.
386    pub fn only_v6(&self) -> io::Result<bool> {
387        self.socket.only_v6()
388    }
389
390    /// Set the bind device for the socket (Linux specific).
391    pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
392        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
393        return self.socket.bind_device(Some(iface.as_bytes()));
394
395        #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
396        {
397            let _ = iface;
398            Err(io::Error::new(
399                io::ErrorKind::Unsupported,
400                "bind_device is not supported on this platform",
401            ))
402        }
403    }
404
405    /// Retrieve the local address of the socket.
406    pub fn local_addr(&self) -> io::Result<SocketAddr> {
407        self.socket
408            .local_addr()?
409            .as_socket()
410            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "failed to retrieve local address"))
411    }
412
413    /// Convert the internal socket into a Tokio `TcpStream`.
414    pub fn into_tokio_stream(self) -> io::Result<TcpStream> {
415        let std_stream: StdTcpStream = self.socket.into();
416        TcpStream::from_std(std_stream)
417    }
418
419    /// Construct from a raw `socket2::Socket`.
420    pub fn from_socket(socket: Socket) -> Self {
421        Self { socket }
422    }
423
424    /// Borrow the inner `socket2::Socket`.
425    pub fn socket(&self) -> &Socket {
426        &self.socket
427    }
428
429    /// Consume and return the inner `socket2::Socket`.
430    pub fn into_socket(self) -> Socket {
431        self.socket
432    }
433
434    /// Extract the RAW file descriptor for Unix.
435    #[cfg(unix)]
436    pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
437        use std::os::fd::AsRawFd;
438        self.socket.as_raw_fd()
439    }
440
441    /// Extract the RAW socket handle for Windows.
442    #[cfg(windows)]
443    pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
444        use std::os::windows::io::AsRawSocket;
445        self.socket.as_raw_socket()
446    }
447}