nex_socket/tcp/
async_impl.rs1use crate::tcp::TcpConfig;
2use socket2::{Domain, Protocol, Socket, Type as SockType};
3use std::io;
4use std::net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream};
5use std::time::Duration;
6use tokio::net::{TcpListener, TcpStream};
7
8#[derive(Debug)]
10pub struct AsyncTcpSocket {
11 socket: Socket,
12}
13
14impl AsyncTcpSocket {
15 pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
17 let socket = Socket::new(config.domain, config.sock_type, Some(Protocol::TCP))?;
18
19 if let Some(flag) = config.reuseaddr {
20 socket.set_reuse_address(flag)?;
21 }
22 if let Some(flag) = config.nodelay {
23 socket.set_nodelay(flag)?;
24 }
25 if let Some(ttl) = config.ttl {
26 socket.set_ttl(ttl)?;
27 }
28
29 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
30 if let Some(iface) = &config.bind_device {
31 socket.bind_device(Some(iface.as_bytes()))?;
32 }
33
34 if let Some(addr) = config.bind_addr {
35 socket.bind(&addr.into())?;
36 }
37
38 socket.set_nonblocking(true)?;
39
40 Ok(Self { socket })
41 }
42
43 pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
45 let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
46 socket.set_nonblocking(true)?;
47 Ok(Self { socket })
48 }
49
50 pub fn v4_stream() -> io::Result<Self> {
52 Self::new(Domain::IPV4, SockType::STREAM)
53 }
54
55 pub fn v6_stream() -> io::Result<Self> {
57 Self::new(Domain::IPV6, SockType::STREAM)
58 }
59
60 pub fn raw_v4() -> io::Result<Self> {
62 Self::new(Domain::IPV4, SockType::RAW)
63 }
64
65 pub fn raw_v6() -> io::Result<Self> {
67 Self::new(Domain::IPV6, SockType::RAW)
68 }
69
70 pub async fn connect(self, target: SocketAddr) -> io::Result<TcpStream> {
72 match self.socket.connect(&target.into()) {
74 Ok(_) => {
75 let std_stream: StdTcpStream = self.socket.into();
77 return TcpStream::from_std(std_stream);
78 }
79 Err(e)
80 if e.kind() == io::ErrorKind::WouldBlock
81 || e.raw_os_error() == Some(libc::EINPROGRESS) =>
82 {
83 let std_stream: StdTcpStream = self.socket.into();
85 let stream = TcpStream::from_std(std_stream)?;
86 stream.writable().await?;
87
88 if let Some(err) = stream.take_error()? {
90 return Err(err);
91 }
92
93 return Ok(stream);
94 }
95 Err(e) => {
96 println!("Failed to connect: {}", e);
97 return Err(e);
98 }
99 }
100 }
101
102 pub async fn connect_timeout(
104 self,
105 target: SocketAddr,
106 timeout: Duration,
107 ) -> io::Result<TcpStream> {
108 match tokio::time::timeout(timeout, self.connect(target)).await {
109 Ok(result) => result,
110 Err(_) => Err(io::Error::new(
111 io::ErrorKind::TimedOut,
112 "connection timed out",
113 )),
114 }
115 }
116
117 pub fn listen(self, backlog: i32) -> io::Result<TcpListener> {
119 self.socket.listen(backlog)?;
120
121 let std_listener: StdTcpListener = self.socket.into();
122 TcpListener::from_std(std_listener)
123 }
124
125 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
127 self.socket.send_to(buf, &target.into())
128 }
129
130 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
132 let buf_maybe = unsafe {
134 std::slice::from_raw_parts_mut(
135 buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
136 buf.len(),
137 )
138 };
139
140 let (n, addr) = self.socket.recv_from(buf_maybe)?;
141 let addr = addr
142 .as_socket()
143 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
144
145 Ok((n, addr))
146 }
147
148 pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
151 self.socket.set_reuse_address(on)
152 }
153
154 pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
155 self.socket.set_nodelay(on)
156 }
157
158 pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
159 self.socket.set_linger(dur)
160 }
161
162 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
163 self.socket.set_ttl(ttl)
164 }
165
166 pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
167 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
168 return self.socket.bind_device(Some(iface.as_bytes()));
169
170 #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
171 {
172 let _ = iface;
173 Err(io::Error::new(
174 io::ErrorKind::Unsupported,
175 "bind_device not supported on this OS",
176 ))
177 }
178 }
179
180 pub fn local_addr(&self) -> io::Result<SocketAddr> {
182 self.socket
183 .local_addr()?
184 .as_socket()
185 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed to get socket address"))
186 }
187
188 pub fn into_tokio_stream(self) -> io::Result<TcpStream> {
190 let std_stream: StdTcpStream = self.socket.into();
191 TcpStream::from_std(std_stream)
192 }
193
194 #[cfg(unix)]
195 pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
196 use std::os::fd::AsRawFd;
197 self.socket.as_raw_fd()
198 }
199
200 #[cfg(windows)]
201 pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
202 use std::os::windows::io::AsRawSocket;
203 self.socket.as_raw_socket()
204 }
205}