nex_socket/tcp/
async_impl.rs1use crate::tcp::TcpConfig;
2use socket2::{Domain, Protocol, Socket, Type as SockType};
3use std::io;
4use std::net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream};
5use std::time::Duration;
6use tokio::net::{TcpListener, TcpStream};
7
8#[derive(Debug)]
10pub struct AsyncTcpSocket {
11 socket: Socket,
12}
13
14impl AsyncTcpSocket {
15 pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
17 let socket = Socket::new(
18 config.socket_family.to_domain(),
19 config.socket_type.to_sock_type(),
20 Some(Protocol::TCP),
21 )?;
22
23 socket.set_nonblocking(true)?;
24
25 if let Some(flag) = config.reuseaddr {
27 socket.set_reuse_address(flag)?;
28 }
29 if let Some(flag) = config.nodelay {
30 socket.set_nodelay(flag)?;
31 }
32 if let Some(ttl) = config.ttl {
33 socket.set_ttl(ttl)?;
34 }
35 if let Some(hoplimit) = config.hoplimit {
36 socket.set_unicast_hops_v6(hoplimit)?;
37 }
38 if let Some(keepalive) = config.keepalive {
39 socket.set_keepalive(keepalive)?;
40 }
41 if let Some(timeout) = config.read_timeout {
42 socket.set_read_timeout(Some(timeout))?;
43 }
44 if let Some(timeout) = config.write_timeout {
45 socket.set_write_timeout(Some(timeout))?;
46 }
47
48 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
50 if let Some(iface) = &config.bind_device {
51 socket.bind_device(Some(iface.as_bytes()))?;
52 }
53
54 if let Some(addr) = config.bind_addr {
56 socket.bind(&addr.into())?;
57 }
58
59 Ok(Self { socket })
60 }
61
62 pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
64 let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
65 socket.set_nonblocking(true)?;
66 Ok(Self { socket })
67 }
68
69 pub fn v4_stream() -> io::Result<Self> {
71 Self::new(Domain::IPV4, SockType::STREAM)
72 }
73
74 pub fn v6_stream() -> io::Result<Self> {
76 Self::new(Domain::IPV6, SockType::STREAM)
77 }
78
79 pub fn raw_v4() -> io::Result<Self> {
81 Self::new(Domain::IPV4, SockType::RAW)
82 }
83
84 pub fn raw_v6() -> io::Result<Self> {
86 Self::new(Domain::IPV6, SockType::RAW)
87 }
88
89 pub async fn connect(self, target: SocketAddr) -> io::Result<TcpStream> {
91 match self.socket.connect(&target.into()) {
93 Ok(_) => {
94 let std_stream: StdTcpStream = self.socket.into();
96 return TcpStream::from_std(std_stream);
97 }
98 Err(e)
99 if e.kind() == io::ErrorKind::WouldBlock
100 || e.raw_os_error() == Some(libc::EINPROGRESS) =>
101 {
102 let std_stream: StdTcpStream = self.socket.into();
104 let stream = TcpStream::from_std(std_stream)?;
105 stream.writable().await?;
106
107 if let Some(err) = stream.take_error()? {
109 return Err(err);
110 }
111
112 return Ok(stream);
113 }
114 Err(e) => {
115 return Err(e);
116 }
117 }
118 }
119
120 pub async fn connect_timeout(
122 self,
123 target: SocketAddr,
124 timeout: Duration,
125 ) -> io::Result<TcpStream> {
126 match tokio::time::timeout(timeout, self.connect(target)).await {
127 Ok(result) => result,
128 Err(_) => Err(io::Error::new(
129 io::ErrorKind::TimedOut,
130 "connection timed out",
131 )),
132 }
133 }
134
135 pub fn listen(self, backlog: i32) -> io::Result<TcpListener> {
137 self.socket.listen(backlog)?;
138
139 let std_listener: StdTcpListener = self.socket.into();
140 TcpListener::from_std(std_listener)
141 }
142
143 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
145 self.socket.send_to(buf, &target.into())
146 }
147
148 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
150 let buf_maybe = unsafe {
152 std::slice::from_raw_parts_mut(
153 buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
154 buf.len(),
155 )
156 };
157
158 let (n, addr) = self.socket.recv_from(buf_maybe)?;
159 let addr = addr
160 .as_socket()
161 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
162
163 Ok((n, addr))
164 }
165
166 pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
168 self.socket.shutdown(how)
169 }
170
171 pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
173 self.socket.set_reuse_address(on)
174 }
175
176 pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
178 self.socket.set_nodelay(on)
179 }
180
181 pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
183 self.socket.set_linger(dur)
184 }
185
186 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
188 self.socket.set_ttl(ttl)
189 }
190
191 pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
193 self.socket.set_unicast_hops_v6(hops)
194 }
195
196 pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
198 self.socket.set_keepalive(on)
199 }
200
201 pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
203 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
204 return self.socket.bind_device(Some(iface.as_bytes()));
205
206 #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
207 {
208 let _ = iface;
209 Err(io::Error::new(
210 io::ErrorKind::Unsupported,
211 "bind_device not supported on this OS",
212 ))
213 }
214 }
215
216 pub fn local_addr(&self) -> io::Result<SocketAddr> {
218 self.socket
219 .local_addr()?
220 .as_socket()
221 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed to get socket address"))
222 }
223
224 pub fn into_tokio_stream(self) -> io::Result<TcpStream> {
226 let std_stream: StdTcpStream = self.socket.into();
227 TcpStream::from_std(std_stream)
228 }
229
230 #[cfg(unix)]
232 pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
233 use std::os::fd::AsRawFd;
234 self.socket.as_raw_fd()
235 }
236
237 #[cfg(windows)]
239 pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
240 use std::os::windows::io::AsRawSocket;
241 self.socket.as_raw_socket()
242 }
243}