1use std::fmt::{Debug, Display};
23use std::hash::Hash;
24use std::io;
25use std::net::{Shutdown, TcpStream, ToSocketAddrs};
26use std::os::fd::IntoRawFd;
27use std::os::unix::io::AsRawFd;
28use std::time::Duration;
29
30use cyphernet::addr::{Addr, InetHost, NetAddr};
31
32pub trait Address: Addr + Send + Clone + Eq + Hash + Debug + Display {}
33impl<T> Address for T where T: Addr + Send + Clone + Eq + Hash + Debug + Display {}
34
/// A byte stream that supports `Read` and `Write` and can be sent across threads.
///
/// Marker trait with no methods of its own; implementors opt in with an explicit `impl`.
pub trait NetStream: io::Read + io::Write + Send {}
36
37pub trait AsConnection {
38 type Connection: NetConnection;
39 fn as_connection(&self) -> &Self::Connection;
40}
41
42pub trait NetConnection: NetStream + AsRawFd + Debug {
44 type Addr: Address;
45
46 fn connect_blocking(addr: Self::Addr, timeout: Duration) -> io::Result<Self>
47 where Self: Sized;
48
49 #[cfg(feature = "nonblocking")]
50 fn connect_nonblocking(addr: Self::Addr, timeout: Duration) -> io::Result<Self>
51 where Self: Sized;
52
53 #[cfg(feature = "nonblocking")]
54 fn connect_reusable_nonblocking(
55 local_addr: Self::Addr,
56 remote_addr: Self::Addr,
57 ) -> io::Result<Self>
58 where
59 Self: Sized;
60
61 fn shutdown(&mut self, how: Shutdown) -> io::Result<()>;
62
63 fn remote_addr(&self) -> io::Result<Self::Addr>;
64 fn local_addr(&self) -> io::Result<Self::Addr>;
65
66 #[cfg(feature = "nonblocking")]
67 fn set_tcp_keepalive(&mut self, keepalive: &socket2::TcpKeepalive) -> io::Result<()>;
68 fn set_read_timeout(&mut self, dur: Option<Duration>) -> io::Result<()>;
69 fn set_write_timeout(&mut self, dur: Option<Duration>) -> io::Result<()>;
70 fn read_timeout(&self) -> io::Result<Option<Duration>>;
71 fn write_timeout(&self) -> io::Result<Option<Duration>>;
72
73 fn peek(&self, buf: &mut [u8]) -> io::Result<usize>;
74
75 fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()>;
76 fn nodelay(&self) -> io::Result<bool>;
77 fn set_ttl(&mut self, ttl: u32) -> io::Result<()>;
78 fn ttl(&self) -> io::Result<u32>;
79 fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>;
80
81 fn try_clone(&self) -> io::Result<Self>
82 where Self: Sized;
83 fn take_error(&self) -> io::Result<Option<io::Error>>;
84}
85
86impl NetStream for TcpStream {}
87impl NetConnection for TcpStream {
88 type Addr = NetAddr<InetHost>;
89
90 fn connect_blocking(addr: Self::Addr, timeout: Duration) -> io::Result<Self> {
91 let socket_addr = addr.to_socket_addrs()?.next().ok_or(io::ErrorKind::AddrNotAvailable)?;
92 TcpStream::connect_timeout(&socket_addr, timeout)
93 }
94
95 #[cfg(feature = "nonblocking")]
96 fn connect_nonblocking(addr: Self::Addr, timeout: Duration) -> io::Result<Self> {
97 Ok(socket2::Socket::connect_nonblocking(addr, timeout)?.into())
98 }
99
100 #[cfg(feature = "nonblocking")]
101 fn connect_reusable_nonblocking(
102 local_addr: Self::Addr,
103 remote_addr: Self::Addr,
104 ) -> io::Result<Self> {
105 Ok(socket2::Socket::connect_reusable_nonblocking(local_addr, remote_addr)?.into())
106 }
107
108 fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { TcpStream::shutdown(self, how) }
109
110 fn remote_addr(&self) -> io::Result<Self::Addr> { Ok(TcpStream::peer_addr(self)?.into()) }
111
112 fn local_addr(&self) -> io::Result<Self::Addr> { Ok(TcpStream::local_addr(self)?.into()) }
113
114 #[cfg(feature = "nonblocking")]
115 fn set_tcp_keepalive(&mut self, keepalive: &socket2::TcpKeepalive) -> io::Result<()> {
116 use std::os::fd::FromRawFd;
117 let socket = unsafe { socket2::Socket::from_raw_fd(self.as_raw_fd()) };
118 socket.set_tcp_keepalive(keepalive)?;
119 let _ = socket.into_raw_fd(); Ok(())
121 }
122 fn set_read_timeout(&mut self, dur: Option<Duration>) -> io::Result<()> {
123 TcpStream::set_read_timeout(self, dur)
124 }
125 fn set_write_timeout(&mut self, dur: Option<Duration>) -> io::Result<()> {
126 TcpStream::set_write_timeout(self, dur)
127 }
128 fn read_timeout(&self) -> io::Result<Option<Duration>> { TcpStream::read_timeout(self) }
129 fn write_timeout(&self) -> io::Result<Option<Duration>> { TcpStream::write_timeout(self) }
130
131 fn peek(&self, buf: &mut [u8]) -> io::Result<usize> { TcpStream::peek(self, buf) }
132
133 fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> {
134 TcpStream::set_nodelay(self, nodelay)
135 }
136 fn nodelay(&self) -> io::Result<bool> { TcpStream::nodelay(self) }
137 fn set_ttl(&mut self, ttl: u32) -> io::Result<()> { TcpStream::set_ttl(self, ttl) }
138 fn ttl(&self) -> io::Result<u32> { TcpStream::ttl(self) }
139 fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
140 TcpStream::set_nonblocking(self, nonblocking)
141 }
142
143 fn try_clone(&self) -> io::Result<TcpStream> { TcpStream::try_clone(self) }
144 fn take_error(&self) -> io::Result<Option<io::Error>> { TcpStream::take_error(self) }
145}
146
#[cfg(feature = "socket2")]
impl NetStream for socket2::Socket {}

/// [`NetConnection`] for a raw `socket2::Socket`.
///
/// Blocking connects reuse the `TcpStream` implementation. The non-blocking constructors
/// create the socket themselves, switch it to non-blocking mode and issue `connect`
/// without waiting for the handshake to finish.
#[cfg(feature = "socket2")]
impl NetConnection for socket2::Socket {
    type Addr = NetAddr<InetHost>;

    /// Connects through `TcpStream::connect_blocking` and converts the stream into a
    /// `socket2::Socket`.
    fn connect_blocking(addr: Self::Addr, timeout: Duration) -> io::Result<Self> {
        TcpStream::connect_blocking(addr, timeout).map(socket2::Socket::from)
    }

    /// Opens a non-blocking stream socket and starts connecting to the first address
    /// `addr` resolves to. `_timeout` is ignored because the call never waits.
    #[cfg(feature = "nonblocking")]
    fn connect_nonblocking(addr: Self::Addr, _timeout: Duration) -> io::Result<Self> {
        let addr = addr.to_socket_addrs()?.next().ok_or(io::ErrorKind::AddrNotAvailable)?;
        let socket =
            socket2::Socket::new(socket2::Domain::for_address(addr), socket2::Type::STREAM, None)?;
        socket.set_nonblocking(true)?;
        socket2_start_connect(&socket, addr)?;
        Ok(socket)
    }

    /// Like `connect_nonblocking`, but first binds the socket to `local_addr` with
    /// `SO_REUSEADDR` (and `SO_REUSEPORT` on Unix targets other than Solaris/illumos).
    ///
    /// # Errors
    /// `InvalidInput` if `local_addr` resolves to nothing, `AddrNotAvailable` if
    /// `remote_addr` resolves to nothing, plus any socket, bind or connect error.
    // Gated like the trait declaration: without the gate, enabling `socket2` without
    // `nonblocking` defines a method the trait does not have and fails to compile.
    #[cfg(feature = "nonblocking")]
    fn connect_reusable_nonblocking(
        local_addr: Self::Addr,
        remote_addr: Self::Addr,
    ) -> io::Result<Self> {
        let local_addr = local_addr.to_socket_addrs()?.next().ok_or(io::ErrorKind::InvalidInput)?;
        let remote_addr =
            remote_addr.to_socket_addrs()?.next().ok_or(io::ErrorKind::AddrNotAvailable)?;
        let socket = socket2::Socket::new(
            socket2::Domain::for_address(local_addr),
            socket2::Type::STREAM,
            None,
        )?;
        socket.set_nonblocking(true)?;
        socket.set_reuse_address(true)?;
        #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
        {
            socket.set_reuse_port(true)?;
        }
        socket2::Socket::bind(&socket, &local_addr.into())?;
        socket2_start_connect(&socket, remote_addr)?;
        Ok(socket)
    }

    fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { socket2::Socket::shutdown(self, how) }

    /// Peer address, or `NotFound` if the peer address is not an IP socket address.
    fn remote_addr(&self) -> io::Result<Self::Addr> {
        let addr = socket2::Socket::peer_addr(self)?
            .as_socket()
            .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?;
        Ok(addr.into())
    }

    /// Local address, or `NotFound` if the local address is not an IP socket address.
    fn local_addr(&self) -> io::Result<Self::Addr> {
        let addr = socket2::Socket::local_addr(self)?
            .as_socket()
            .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?;
        Ok(addr.into())
    }

    #[cfg(feature = "nonblocking")]
    fn set_tcp_keepalive(&mut self, keepalive: &socket2::TcpKeepalive) -> io::Result<()> {
        socket2::Socket::set_tcp_keepalive(self, keepalive)
    }

    fn set_read_timeout(&mut self, dur: Option<Duration>) -> io::Result<()> {
        socket2::Socket::set_read_timeout(self, dur)
    }

    fn set_write_timeout(&mut self, dur: Option<Duration>) -> io::Result<()> {
        socket2::Socket::set_write_timeout(self, dur)
    }

    fn read_timeout(&self) -> io::Result<Option<Duration>> { socket2::Socket::read_timeout(self) }

    fn write_timeout(&self) -> io::Result<Option<Duration>> { socket2::Socket::write_timeout(self) }

    /// Copies pending data into `buf` without consuming it, returning the byte count.
    fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
        use std::mem::MaybeUninit;

        // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`, and socket2 documents
        // that its receive functions never write uninitialised bytes into the buffer, so
        // `buf` stays fully initialised. Receiving in place avoids a scratch allocation
        // and a byte-by-byte copy on every call.
        let buf = unsafe { &mut *(buf as *mut [u8] as *mut [MaybeUninit<u8>]) };
        socket2::Socket::peek(self, buf)
    }

    fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> {
        socket2::Socket::set_nodelay(self, nodelay)
    }

    fn nodelay(&self) -> io::Result<bool> { socket2::Socket::nodelay(self) }

    fn set_ttl(&mut self, ttl: u32) -> io::Result<()> { socket2::Socket::set_ttl(self, ttl) }

    fn ttl(&self) -> io::Result<u32> { socket2::Socket::ttl(self) }

    fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
        socket2::Socket::set_nonblocking(self, nonblocking)
    }

    fn try_clone(&self) -> io::Result<Self> { socket2::Socket::try_clone(self) }

    fn take_error(&self) -> io::Result<Option<io::Error>> { socket2::Socket::take_error(self) }
}

/// Issues `connect` to `addr` on `socket` (already in non-blocking mode) and classifies
/// the outcome.
///
/// Returns `Ok(())` when the connection completed immediately, is in progress
/// (`EINPROGRESS`), or the call reported `WouldBlock`. Returns `AlreadyExists` for
/// `EALREADY`, and passes any other error through unchanged.
#[cfg(all(feature = "socket2", feature = "nonblocking"))]
fn socket2_start_connect(socket: &socket2::Socket, addr: std::net::SocketAddr) -> io::Result<()> {
    match socket2::Socket::connect(socket, &addr.into()) {
        Ok(()) => {
            #[cfg(feature = "log")]
            log::debug!(target: "netservices", "Connected to {}", addr);
        }
        Err(e) if e.raw_os_error() == Some(libc::EINPROGRESS) => {
            #[cfg(feature = "log")]
            log::debug!(target: "netservices", "Connecting to {} in a non-blocking way", addr);
        }
        Err(e) if e.raw_os_error() == Some(libc::EALREADY) => {
            // EALREADY: a previous connection attempt on this socket has not completed yet.
            #[cfg(feature = "log")]
            log::error!(target: "netservices", "Can't connect to {}: a connection attempt is already in progress", addr);
            return Err(io::Error::from(io::ErrorKind::AlreadyExists));
        }
        Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
            // NOTE(review): this logs an error yet still reports success to the caller —
            // confirm that treating `WouldBlock` as "in progress" is intended.
            #[cfg(feature = "log")]
            log::error!(target: "netservices", "Can't connect to {} in a non-blocking way", addr);
        }
        Err(e) => {
            #[cfg(feature = "log")]
            log::debug!(target: "netservices", "Error connecting to {}: {}", addr, e);
            return Err(e);
        }
    }
    Ok(())
}