nex_socket/tcp/
async_impl.rs

1use crate::tcp::TcpConfig;
2use socket2::{Domain, Protocol, Socket, Type as SockType};
3use std::io;
4use std::net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream};
5use std::time::Duration;
6use tokio::net::{TcpListener, TcpStream};
7
8/// Asynchronous TCP socket built on top of Tokio.
9#[derive(Debug)]
10pub struct AsyncTcpSocket {
11    socket: Socket,
12}
13
14impl AsyncTcpSocket {
15    /// Create a socket from the given configuration without connecting.
16    pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
17        let socket = Socket::new(config.domain, config.sock_type, Some(Protocol::TCP))?;
18
19        if let Some(flag) = config.reuseaddr {
20            socket.set_reuse_address(flag)?;
21        }
22        if let Some(flag) = config.nodelay {
23            socket.set_nodelay(flag)?;
24        }
25        if let Some(ttl) = config.ttl {
26            socket.set_ttl(ttl)?;
27        }
28
29        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
30        if let Some(iface) = &config.bind_device {
31            socket.bind_device(Some(iface.as_bytes()))?;
32        }
33
34        if let Some(addr) = config.bind_addr {
35            socket.bind(&addr.into())?;
36        }
37
38        socket.set_nonblocking(true)?;
39
40        Ok(Self { socket })
41    }
42
43    /// Create a socket of arbitrary type (STREAM or RAW).
44    pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
45        let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
46        socket.set_nonblocking(true)?;
47        Ok(Self { socket })
48    }
49
50    /// Convenience constructor for an IPv4 STREAM socket.
51    pub fn v4_stream() -> io::Result<Self> {
52        Self::new(Domain::IPV4, SockType::STREAM)
53    }
54
55    /// Convenience constructor for an IPv6 STREAM socket.
56    pub fn v6_stream() -> io::Result<Self> {
57        Self::new(Domain::IPV6, SockType::STREAM)
58    }
59
60    /// IPv4 RAW TCP. Requires administrator privileges.
61    pub fn raw_v4() -> io::Result<Self> {
62        Self::new(Domain::IPV4, SockType::RAW)
63    }
64
65    /// IPv6 RAW TCP. Requires administrator privileges.
66    pub fn raw_v6() -> io::Result<Self> {
67        Self::new(Domain::IPV6, SockType::RAW)
68    }
69
70    /// Connect to the target asynchronously.
71    pub async fn connect(self, target: SocketAddr) -> io::Result<TcpStream> {
72        // call connect
73        match self.socket.connect(&target.into()) {
74            Ok(_) => {
75                // connection completed immediately (rare case)
76                let std_stream: StdTcpStream = self.socket.into();
77                return TcpStream::from_std(std_stream);
78            }
79            Err(e)
80                if e.kind() == io::ErrorKind::WouldBlock
81                    || e.raw_os_error() == Some(libc::EINPROGRESS) =>
82            {
83                // wait until writable
84                let std_stream: StdTcpStream = self.socket.into();
85                let stream = TcpStream::from_std(std_stream)?;
86                stream.writable().await?;
87
88                // check the final connection state with SO_ERROR
89                if let Some(err) = stream.take_error()? {
90                    return Err(err);
91                }
92
93                return Ok(stream);
94            }
95            Err(e) => {
96                println!("Failed to connect: {}", e);
97                return Err(e);
98            }
99        }
100    }
101
102    /// Connect with a timeout to the target address.
103    pub async fn connect_timeout(
104        self,
105        target: SocketAddr,
106        timeout: Duration,
107    ) -> io::Result<TcpStream> {
108        match tokio::time::timeout(timeout, self.connect(target)).await {
109            Ok(result) => result,
110            Err(_) => Err(io::Error::new(
111                io::ErrorKind::TimedOut,
112                "connection timed out",
113            )),
114        }
115    }
116
117    /// Start listening for incoming connections.
118    pub fn listen(self, backlog: i32) -> io::Result<TcpListener> {
119        self.socket.listen(backlog)?;
120
121        let std_listener: StdTcpListener = self.socket.into();
122        TcpListener::from_std(std_listener)
123    }
124
125    /// Send a raw TCP packet. Requires `SockType::RAW`.
126    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
127        self.socket.send_to(buf, &target.into())
128    }
129
130    /// Receive a raw TCP packet. Requires `SockType::RAW`.
131    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
132        // Safety: `MaybeUninit<u8>` has the same memory layout as `u8`.
133        let buf_maybe = unsafe {
134            std::slice::from_raw_parts_mut(
135                buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
136                buf.len(),
137            )
138        };
139
140        let (n, addr) = self.socket.recv_from(buf_maybe)?;
141        let addr = addr
142            .as_socket()
143            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
144
145        Ok((n, addr))
146    }
147
148    // --- option helpers ---
149
150    pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
151        self.socket.set_reuse_address(on)
152    }
153
154    pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
155        self.socket.set_nodelay(on)
156    }
157
158    pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
159        self.socket.set_linger(dur)
160    }
161
162    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
163        self.socket.set_ttl(ttl)
164    }
165
166    pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
167        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
168        return self.socket.bind_device(Some(iface.as_bytes()));
169
170        #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
171        {
172            let _ = iface;
173            Err(io::Error::new(
174                io::ErrorKind::Unsupported,
175                "bind_device not supported on this OS",
176            ))
177        }
178    }
179
180    /// Retrieve the local address of the socket.
181    pub fn local_addr(&self) -> io::Result<SocketAddr> {
182        self.socket
183            .local_addr()?
184            .as_socket()
185            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed to get socket address"))
186    }
187
188    /// Convert the internal socket into a Tokio `TcpStream`.
189    pub fn into_tokio_stream(self) -> io::Result<TcpStream> {
190        let std_stream: StdTcpStream = self.socket.into();
191        TcpStream::from_std(std_stream)
192    }
193
194    #[cfg(unix)]
195    pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
196        use std::os::fd::AsRawFd;
197        self.socket.as_raw_fd()
198    }
199
200    #[cfg(windows)]
201    pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
202        use std::os::windows::io::AsRawSocket;
203        self.socket.as_raw_socket()
204    }
205}