nex_socket/tcp/
async_impl.rs1use crate::tcp::TcpConfig;
2use socket2::{Domain, Protocol, Socket, Type as SockType};
3use std::io;
4use std::net::{SocketAddr, TcpStream as StdTcpStream, TcpListener as StdTcpListener};
5use std::time::Duration;
6use tokio::net::{TcpListener, TcpStream};
7
8#[derive(Debug)]
10pub struct AsyncTcpSocket {
11 socket: Socket,
12}
13
14impl AsyncTcpSocket {
15 pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
17 let socket = Socket::new(config.domain, config.sock_type, Some(Protocol::TCP))?;
18
19 if let Some(flag) = config.reuseaddr {
20 socket.set_reuse_address(flag)?;
21 }
22 if let Some(flag) = config.nodelay {
23 socket.set_nodelay(flag)?;
24 }
25 if let Some(ttl) = config.ttl {
26 socket.set_ttl(ttl)?;
27 }
28
29 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
30 if let Some(iface) = &config.bind_device {
31 socket.bind_device(Some(iface.as_bytes()))?;
32 }
33
34 if let Some(addr) = config.bind_addr {
35 socket.bind(&addr.into())?;
36 }
37
38 socket.set_nonblocking(true)?;
39
40 Ok(Self { socket })
41 }
42
43 pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
45 let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
46 socket.set_nonblocking(true)?;
47 Ok(Self { socket })
48 }
49
50 pub fn v4_stream() -> io::Result<Self> {
52 Self::new(Domain::IPV4, SockType::STREAM)
53 }
54
55 pub fn v6_stream() -> io::Result<Self> {
57 Self::new(Domain::IPV6, SockType::STREAM)
58 }
59
60 pub fn raw_v4() -> io::Result<Self> {
62 Self::new(Domain::IPV4, SockType::RAW)
63 }
64
65 pub fn raw_v6() -> io::Result<Self> {
67 Self::new(Domain::IPV6, SockType::RAW)
68 }
69
70 pub async fn connect(self, target: SocketAddr) -> io::Result<TcpStream> {
72 match self.socket.connect(&target.into()) {
74 Ok(_) => {
75 let std_stream: StdTcpStream = self.socket.into();
77 return TcpStream::from_std(std_stream);
78 }
79 Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(libc::EINPROGRESS) => {
80 let std_stream: StdTcpStream = self.socket.into();
82 let stream = TcpStream::from_std(std_stream)?;
83 stream.writable().await?;
84
85 if let Some(err) = stream.take_error()? {
87 return Err(err);
88 }
89
90 return Ok(stream);
91 }
92 Err(e) => {
93 println!("Failed to connect: {}", e);
94 return Err(e);
95 }
96 }
97 }
98
99 pub async fn connect_timeout(self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
101 match tokio::time::timeout(timeout, self.connect(target)).await {
102 Ok(result) => result,
103 Err(_) => Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out")),
104 }
105 }
106
107 pub fn listen(self, backlog: i32) -> io::Result<TcpListener> {
109 self.socket.listen(backlog)?;
110
111 let std_listener: StdTcpListener = self.socket.into();
112 TcpListener::from_std(std_listener)
113 }
114
115 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
117 self.socket.send_to(buf, &target.into())
118 }
119
120 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
122 let buf_maybe = unsafe {
124 std::slice::from_raw_parts_mut(
125 buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
126 buf.len(),
127 )
128 };
129
130 let (n, addr) = self.socket.recv_from(buf_maybe)?;
131 let addr = addr.as_socket().ok_or_else(|| {
132 io::Error::new(io::ErrorKind::InvalidData, "invalid address format")
133 })?;
134
135 Ok((n, addr))
136 }
137
138 pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
141 self.socket.set_reuse_address(on)
142 }
143
144 pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
145 self.socket.set_nodelay(on)
146 }
147
148 pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
149 self.socket.set_linger(dur)
150 }
151
152 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
153 self.socket.set_ttl(ttl)
154 }
155
156 pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
157 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
158 return self.socket.bind_device(Some(iface.as_bytes()));
159
160 #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
161 {
162 let _ = iface;
163 Err(io::Error::new(io::ErrorKind::Unsupported, "bind_device not supported on this OS"))
164 }
165 }
166
167 pub fn local_addr(&self) -> io::Result<SocketAddr> {
169 self.socket.local_addr()?.as_socket().ok_or_else(|| {
170 io::Error::new(io::ErrorKind::Other, "Failed to get socket address")
171 })
172 }
173
174 pub fn into_tokio_stream(self) -> io::Result<TcpStream> {
176 let std_stream: StdTcpStream = self.socket.into();
177 TcpStream::from_std(std_stream)
178 }
179
180 #[cfg(unix)]
181 pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
182 use std::os::fd::AsRawFd;
183 self.socket.as_raw_fd()
184 }
185
186 #[cfg(windows)]
187 pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
188 use std::os::windows::io::AsRawSocket;
189 self.socket.as_raw_socket()
190 }
191}