// nex_socket/tcp/async_impl.rs
use std::io;
use std::mem::MaybeUninit;
use std::net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream};
use std::time::Duration;

use socket2::{Domain, Protocol, Socket, Type as SockType};
use tokio::net::{TcpListener, TcpStream};

use crate::tcp::TcpConfig;

8#[derive(Debug)]
10pub struct AsyncTcpSocket {
11 socket: Socket,
12}
13
14impl AsyncTcpSocket {
15 pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
17 let socket = Socket::new(
18 config.socket_family.to_domain(),
19 config.socket_type.to_sock_type(),
20 Some(Protocol::TCP),
21 )?;
22
23 socket.set_nonblocking(true)?;
24
25 if let Some(flag) = config.reuseaddr {
27 socket.set_reuse_address(flag)?;
28 }
29 if let Some(flag) = config.nodelay {
30 socket.set_nodelay(flag)?;
31 }
32 if let Some(ttl) = config.ttl {
33 socket.set_ttl(ttl)?;
34 }
35 if let Some(hoplimit) = config.hoplimit {
36 socket.set_unicast_hops_v6(hoplimit)?;
37 }
38 if let Some(keepalive) = config.keepalive {
39 socket.set_keepalive(keepalive)?;
40 }
41 if let Some(timeout) = config.read_timeout {
42 socket.set_read_timeout(Some(timeout))?;
43 }
44 if let Some(timeout) = config.write_timeout {
45 socket.set_write_timeout(Some(timeout))?;
46 }
47
48 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
50 if let Some(iface) = &config.bind_device {
51 socket.bind_device(Some(iface.as_bytes()))?;
52 }
53
54 if let Some(addr) = config.bind_addr {
56 socket.bind(&addr.into())?;
57 }
58
59 Ok(Self { socket })
60 }
61
62 pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
64 let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
65 socket.set_nonblocking(true)?;
66 Ok(Self { socket })
67 }
68
69 pub fn v4_stream() -> io::Result<Self> {
71 Self::new(Domain::IPV4, SockType::STREAM)
72 }
73
74 pub fn v6_stream() -> io::Result<Self> {
76 Self::new(Domain::IPV6, SockType::STREAM)
77 }
78
79 pub fn raw_v4() -> io::Result<Self> {
81 Self::new(Domain::IPV4, SockType::RAW)
82 }
83
84 pub fn raw_v6() -> io::Result<Self> {
86 Self::new(Domain::IPV6, SockType::RAW)
87 }
88
89 pub async fn connect(self, target: SocketAddr) -> io::Result<TcpStream> {
91 match self.socket.connect(&target.into()) {
93 Ok(_) => {
94 let std_stream: StdTcpStream = self.socket.into();
96 return TcpStream::from_std(std_stream);
97 }
98 Err(e)
99 if e.kind() == io::ErrorKind::WouldBlock
100 || e.raw_os_error() == Some(libc::EINPROGRESS) =>
101 {
102 let std_stream: StdTcpStream = self.socket.into();
104 let stream = TcpStream::from_std(std_stream)?;
105 stream.writable().await?;
106
107 if let Some(err) = stream.take_error()? {
109 return Err(err);
110 }
111
112 return Ok(stream);
113 }
114 Err(e) => {
115 println!("Failed to connect: {}", e);
116 return Err(e);
117 }
118 }
119 }
120
121 pub async fn connect_timeout(
123 self,
124 target: SocketAddr,
125 timeout: Duration,
126 ) -> io::Result<TcpStream> {
127 match tokio::time::timeout(timeout, self.connect(target)).await {
128 Ok(result) => result,
129 Err(_) => Err(io::Error::new(
130 io::ErrorKind::TimedOut,
131 "connection timed out",
132 )),
133 }
134 }
135
136 pub fn listen(self, backlog: i32) -> io::Result<TcpListener> {
138 self.socket.listen(backlog)?;
139
140 let std_listener: StdTcpListener = self.socket.into();
141 TcpListener::from_std(std_listener)
142 }
143
144 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
146 self.socket.send_to(buf, &target.into())
147 }
148
149 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
151 let buf_maybe = unsafe {
153 std::slice::from_raw_parts_mut(
154 buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
155 buf.len(),
156 )
157 };
158
159 let (n, addr) = self.socket.recv_from(buf_maybe)?;
160 let addr = addr
161 .as_socket()
162 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
163
164 Ok((n, addr))
165 }
166
167 pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
169 self.socket.shutdown(how)
170 }
171
172 pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
174 self.socket.set_reuse_address(on)
175 }
176
177 pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
179 self.socket.set_nodelay(on)
180 }
181
182 pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
184 self.socket.set_linger(dur)
185 }
186
187 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
189 self.socket.set_ttl(ttl)
190 }
191
192 pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
194 self.socket.set_unicast_hops_v6(hops)
195 }
196
197 pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
199 self.socket.set_keepalive(on)
200 }
201
202 pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
204 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
205 return self.socket.bind_device(Some(iface.as_bytes()));
206
207 #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
208 {
209 let _ = iface;
210 Err(io::Error::new(
211 io::ErrorKind::Unsupported,
212 "bind_device not supported on this OS",
213 ))
214 }
215 }
216
217 pub fn local_addr(&self) -> io::Result<SocketAddr> {
219 self.socket
220 .local_addr()?
221 .as_socket()
222 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed to get socket address"))
223 }
224
225 pub fn into_tokio_stream(self) -> io::Result<TcpStream> {
227 let std_stream: StdTcpStream = self.socket.into();
228 TcpStream::from_std(std_stream)
229 }
230
231 #[cfg(unix)]
233 pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
234 use std::os::fd::AsRawFd;
235 self.socket.as_raw_fd()
236 }
237
238 #[cfg(windows)]
240 pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
241 use std::os::windows::io::AsRawSocket;
242 self.socket.as_raw_socket()
243 }
244}