1pub use socket2::{TcpKeepalive};
2
3use crate::network::adapter::{
4 Resource, Remote, Local, Adapter, SendStatus, AcceptedType, ReadStatus, ConnectionInfo,
5 ListeningInfo, PendingStatus,
6};
7use crate::network::{RemoteAddr, Readiness, TransportConnect, TransportListen};
8
9use mio::net::{TcpListener, TcpStream};
10use mio::event::{Source};
11
12use socket2::{Socket, Domain, Type, Protocol};
13
14use std::net::{SocketAddr};
15#[cfg(unix)]
16use std::ffi::{CString};
17use std::io::{self, ErrorKind, Read, Write};
18#[cfg(target_os = "macos")]
19use std::num::NonZeroU32;
20use std::mem::{forget, MaybeUninit};
21use std::os::raw::c_int;
22#[cfg(target_os = "windows")]
23use std::os::windows::io::{FromRawSocket, AsRawSocket};
24#[cfg(not(target_os = "windows"))]
25use std::os::{fd::AsRawFd, unix::io::FromRawFd};
26
/// Size in bytes of the buffer used for each read from a TCP stream
/// (`u16::MAX`, i.e. 65535). A single chunk handed to the `receive`
/// callback is never larger than this.
pub const INPUT_BUFFER_SIZE: usize = u16::MAX as usize;

/// Backlog value passed to `listen` when a TCP listener is created.
pub const LISTENER_BACKLOG: c_int = 1024;
34
35#[derive(Clone, Debug, Default)]
36pub struct TcpConnectConfig {
37 bind_device: Option<String>,
38 source_address: Option<SocketAddr>,
39 keepalive: Option<TcpKeepalive>,
40}
41
42impl TcpConnectConfig {
43 pub fn with_bind_device(mut self, device: String) -> Self {
46 self.bind_device = Some(device);
47 self
48 }
49
50 pub fn with_keepalive(mut self, keepalive: TcpKeepalive) -> Self {
52 self.keepalive = Some(keepalive);
53 self
54 }
55
56 pub fn with_source_address(mut self, source_address: SocketAddr) -> Self {
58 self.source_address = Some(source_address);
59 self
60 }
61}
62
63#[derive(Clone, Debug, Default)]
64pub struct TcpListenConfig {
65 bind_device: Option<String>,
66 keepalive: Option<TcpKeepalive>,
67}
68
69impl TcpListenConfig {
70 pub fn with_bind_device(mut self, device: String) -> Self {
73 self.bind_device = Some(device);
74 self
75 }
76
77 pub fn with_keepalive(mut self, keepalive: TcpKeepalive) -> Self {
79 self.keepalive = Some(keepalive);
80 self
81 }
82}
83
/// Unit type that selects the TCP resource types for the [`Adapter`] trait.
pub(crate) struct TcpAdapter;
impl Adapter for TcpAdapter {
    // Connection endpoints are backed by a mio `TcpStream`.
    type Remote = RemoteResource;
    // Listening endpoints are backed by a mio `TcpListener`.
    type Local = LocalResource;
}
89
/// One end of a TCP connection, either initiated locally or accepted
/// by a listener.
pub(crate) struct RemoteResource {
    /// The underlying mio TCP stream.
    stream: TcpStream,
    /// Keepalive parameters to apply to the stream when it becomes ready,
    /// or `None` to leave the socket's keepalive settings untouched.
    keepalive: Option<TcpKeepalive>,
}

impl Resource for RemoteResource {
    /// Returns the stream as the mio event [`Source`] of this resource.
    fn source(&mut self) -> &mut dyn Source {
        &mut self.stream
    }
}
100
101impl Remote for RemoteResource {
102 fn connect_with(
103 config: TransportConnect,
104 remote_addr: RemoteAddr,
105 ) -> io::Result<ConnectionInfo<Self>> {
106 let config = match config {
107 TransportConnect::Tcp(config) => config,
108 _ => panic!("Internal error: Got wrong config"),
109 };
110 let peer_addr = *remote_addr.socket_addr();
111
112 let socket = Socket::new(
113 match peer_addr {
114 SocketAddr::V4 { .. } => Domain::IPV4,
115 SocketAddr::V6 { .. } => Domain::IPV6,
116 },
117 Type::STREAM,
118 Some(Protocol::TCP),
119 )?;
120 socket.set_nonblocking(true)?;
121
122 if let Some(source_address) = config.source_address {
123 socket.bind(&source_address.into())?;
124 }
125
126 #[cfg(unix)]
127 if let Some(bind_device) = config.bind_device {
128 let device = CString::new(bind_device)?;
129
130 #[cfg(not(target_os = "macos"))]
131 socket.bind_device(Some(device.as_bytes()))?;
132
133 #[cfg(target_os = "macos")]
134 match NonZeroU32::new(unsafe { libc::if_nametoindex(device.as_ptr()) }) {
135 Some(index) => socket.bind_device_by_index_v4(Some(index))?,
136 None => {
137 return Err(io::Error::new(
138 ErrorKind::NotFound,
139 "Bind device interface not found",
140 ))
141 }
142 }
143 }
144
145 match socket.connect(&peer_addr.into()) {
146 #[cfg(unix)]
147 Err(e) if e.raw_os_error() != Some(libc::EINPROGRESS) => return Err(e),
148 #[cfg(windows)]
149 Err(e) if e.kind() != io::ErrorKind::WouldBlock => return Err(e),
150 _ => {}
151 }
152
153 let stream = TcpStream::from_std(socket.into());
154 let local_addr = stream.local_addr()?;
155 Ok(ConnectionInfo {
156 remote: Self { stream, keepalive: config.keepalive },
157 local_addr,
158 peer_addr,
159 })
160 }
161
162 fn receive(&self, mut process_data: impl FnMut(&[u8])) -> ReadStatus {
163 let buffer: MaybeUninit<[u8; INPUT_BUFFER_SIZE]> = MaybeUninit::uninit();
164 let mut input_buffer = unsafe { buffer.assume_init() }; loop {
167 let mut stream = &self.stream;
168 match stream.read(&mut input_buffer) {
169 Ok(0) => break ReadStatus::Disconnected,
170 Ok(size) => process_data(&input_buffer[..size]),
171 Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
172 Err(ref err) if err.kind() == ErrorKind::WouldBlock => {
173 break ReadStatus::WaitNextEvent
174 }
175 Err(ref err) if err.kind() == ErrorKind::ConnectionReset => {
176 break ReadStatus::Disconnected
177 }
178 Err(err) => {
179 log::error!("TCP receive error: {}", err);
180 break ReadStatus::Disconnected; }
182 }
183 }
184 }
185
186 fn send(&self, data: &[u8]) -> SendStatus {
187 let mut total_bytes_sent = 0;
192 loop {
193 let mut stream = &self.stream;
194 match stream.write(&data[total_bytes_sent..]) {
195 Ok(bytes_sent) => {
196 total_bytes_sent += bytes_sent;
197 if total_bytes_sent == data.len() {
198 break SendStatus::Sent;
199 }
200 }
201 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue,
202
203 Err(err) => {
206 log::error!("TCP receive error: {}", err);
207 break SendStatus::ResourceNotFound; }
209 }
210 }
211 }
212
213 fn pending(&self, _readiness: Readiness) -> PendingStatus {
214 let status = check_stream_ready(&self.stream);
215
216 if status == PendingStatus::Ready {
217 if let Some(keepalive) = &self.keepalive {
218 #[cfg(target_os = "windows")]
219 let socket = unsafe { Socket::from_raw_socket(self.stream.as_raw_socket()) };
220 #[cfg(not(target_os = "windows"))]
221 let socket = unsafe { Socket::from_raw_fd(self.stream.as_raw_fd()) };
222
223 if let Err(e) = socket.set_tcp_keepalive(keepalive) {
224 log::warn!("TCP set keepalive error: {}", e);
225 }
226
227 forget(socket);
229 }
230 }
231
232 status
233 }
234}
235
236pub fn check_stream_ready(stream: &TcpStream) -> PendingStatus {
238 if let Ok(Some(_)) = stream.take_error() {
241 return PendingStatus::Disconnected;
242 }
243 match stream.peer_addr() {
244 Ok(_) => PendingStatus::Ready,
245 Err(err) if err.kind() == io::ErrorKind::NotConnected => PendingStatus::Incomplete,
246 Err(err) if err.kind() == io::ErrorKind::InvalidInput => PendingStatus::Incomplete,
247 Err(_) => PendingStatus::Disconnected,
248 }
249}
250
/// A listening TCP endpoint that produces [`RemoteResource`]s on accept.
pub(crate) struct LocalResource {
    /// The underlying mio TCP listener.
    listener: TcpListener,
    /// Keepalive parameters cloned into every accepted connection,
    /// or `None` if keepalive was not configured.
    keepalive: Option<TcpKeepalive>,
}

impl Resource for LocalResource {
    /// Returns the listener as the mio event [`Source`] of this resource.
    fn source(&mut self) -> &mut dyn Source {
        &mut self.listener
    }
}
261
262impl Local for LocalResource {
263 type Remote = RemoteResource;
264
265 fn listen_with(config: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
266 let config = match config {
267 TransportListen::Tcp(config) => config,
268 _ => panic!("Internal error: Got wrong config"),
269 };
270
271 let socket = Socket::new(
272 match addr {
273 SocketAddr::V4 { .. } => Domain::IPV4,
274 SocketAddr::V6 { .. } => Domain::IPV6,
275 },
276 Type::STREAM,
277 Some(Protocol::TCP),
278 )?;
279 socket.set_nonblocking(true)?;
280 socket.set_reuse_address(true)?;
281
282 #[cfg(unix)]
283 if let Some(bind_device) = config.bind_device {
284 let device = CString::new(bind_device)?;
285
286 #[cfg(not(target_os = "macos"))]
287 socket.bind_device(Some(device.as_bytes()))?;
288
289 #[cfg(target_os = "macos")]
290 match NonZeroU32::new(unsafe { libc::if_nametoindex(device.as_ptr()) }) {
291 Some(index) => socket.bind_device_by_index_v4(Some(index))?,
292 None => {
293 return Err(io::Error::new(
294 ErrorKind::NotFound,
295 "Bind device interface not found",
296 ))
297 }
298 }
299 }
300
301 socket.bind(&addr.into())?;
302 socket.listen(LISTENER_BACKLOG)?;
303
304 let listener = TcpListener::from_std(socket.into());
305
306 let local_addr = listener.local_addr().unwrap();
307 Ok(ListeningInfo {
308 local: { LocalResource { listener, keepalive: config.keepalive } },
309 local_addr,
310 })
311 }
312
313 fn accept(&self, mut accept_remote: impl FnMut(AcceptedType<'_, Self::Remote>)) {
314 loop {
315 match self.listener.accept() {
316 Ok((stream, addr)) => accept_remote(AcceptedType::Remote(
317 addr,
318 RemoteResource { stream, keepalive: self.keepalive.clone() },
319 )),
320 Err(ref err) if err.kind() == ErrorKind::WouldBlock => break,
321 Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
322 Err(err) => break log::error!("TCP accept error: {}", err), }
324 }
325 }
326}