message_io/adapters/
tcp.rs

1pub use socket2::{TcpKeepalive};
2
3use crate::network::adapter::{
4    Resource, Remote, Local, Adapter, SendStatus, AcceptedType, ReadStatus, ConnectionInfo,
5    ListeningInfo, PendingStatus,
6};
7use crate::network::{RemoteAddr, Readiness, TransportConnect, TransportListen};
8
9use mio::net::{TcpListener, TcpStream};
10use mio::event::{Source};
11
12use socket2::{Socket, Domain, Type, Protocol};
13
14use std::net::{SocketAddr};
15#[cfg(unix)]
16use std::ffi::{CString};
17use std::io::{self, ErrorKind, Read, Write};
18#[cfg(target_os = "macos")]
19use std::num::NonZeroU32;
20use std::mem::{forget, MaybeUninit};
21use std::os::raw::c_int;
22#[cfg(target_os = "windows")]
23use std::os::windows::io::{FromRawSocket, AsRawSocket};
24#[cfg(not(target_os = "windows"))]
25use std::os::{fd::AsRawFd, unix::io::FromRawFd};
26
/// Size of the internal reading buffer.
///
/// Each generated [`crate::network::NetEvent::Message`] carries a chunk of at most this many
/// bytes.
pub const INPUT_BUFFER_SIZE: usize = u16::MAX as usize; // 2^16 - 1
31
/// Backlog requested when a listener calls `listen`: the maximum length of its queue of
/// pending (not yet accepted) connections.
pub const LISTENER_BACKLOG: c_int = 1 << 10;
34
35#[derive(Clone, Debug, Default)]
36pub struct TcpConnectConfig {
37    bind_device: Option<String>,
38    source_address: Option<SocketAddr>,
39    keepalive: Option<TcpKeepalive>,
40}
41
42impl TcpConnectConfig {
43    /// Bind the TCP connection to a specific interface, identified by its name. This option works
44    /// in Unix, on other systems, it will be ignored.
45    pub fn with_bind_device(mut self, device: String) -> Self {
46        self.bind_device = Some(device);
47        self
48    }
49
50    /// Enables TCP keepalive settings on the socket.
51    pub fn with_keepalive(mut self, keepalive: TcpKeepalive) -> Self {
52        self.keepalive = Some(keepalive);
53        self
54    }
55
56    /// Specify the source address and port.
57    pub fn with_source_address(mut self, source_address: SocketAddr) -> Self {
58        self.source_address = Some(source_address);
59        self
60    }
61}
62
63#[derive(Clone, Debug, Default)]
64pub struct TcpListenConfig {
65    bind_device: Option<String>,
66    keepalive: Option<TcpKeepalive>,
67}
68
69impl TcpListenConfig {
70    /// Bind the TCP listener to a specific interface, identified by its name. This option works in
71    /// Unix, on other systems, it will be ignored.
72    pub fn with_bind_device(mut self, device: String) -> Self {
73        self.bind_device = Some(device);
74        self
75    }
76
77    /// Enables TCP keepalive settings on client connection sockets.
78    pub fn with_keepalive(mut self, keepalive: TcpKeepalive) -> Self {
79        self.keepalive = Some(keepalive);
80        self
81    }
82}
83
84pub(crate) struct TcpAdapter;
85impl Adapter for TcpAdapter {
86    type Remote = RemoteResource;
87    type Local = LocalResource;
88}
89
90pub(crate) struct RemoteResource {
91    stream: TcpStream,
92    keepalive: Option<TcpKeepalive>,
93}
94
95impl Resource for RemoteResource {
96    fn source(&mut self) -> &mut dyn Source {
97        &mut self.stream
98    }
99}
100
101impl Remote for RemoteResource {
102    fn connect_with(
103        config: TransportConnect,
104        remote_addr: RemoteAddr,
105    ) -> io::Result<ConnectionInfo<Self>> {
106        let config = match config {
107            TransportConnect::Tcp(config) => config,
108            _ => panic!("Internal error: Got wrong config"),
109        };
110        let peer_addr = *remote_addr.socket_addr();
111
112        let socket = Socket::new(
113            match peer_addr {
114                SocketAddr::V4 { .. } => Domain::IPV4,
115                SocketAddr::V6 { .. } => Domain::IPV6,
116            },
117            Type::STREAM,
118            Some(Protocol::TCP),
119        )?;
120        socket.set_nonblocking(true)?;
121
122        if let Some(source_address) = config.source_address {
123            socket.bind(&source_address.into())?;
124        }
125
126        #[cfg(unix)]
127        if let Some(bind_device) = config.bind_device {
128            let device = CString::new(bind_device)?;
129
130            #[cfg(not(target_os = "macos"))]
131            socket.bind_device(Some(device.as_bytes()))?;
132
133            #[cfg(target_os = "macos")]
134            match NonZeroU32::new(unsafe { libc::if_nametoindex(device.as_ptr()) }) {
135                Some(index) => socket.bind_device_by_index_v4(Some(index))?,
136                None => {
137                    return Err(io::Error::new(
138                        ErrorKind::NotFound,
139                        "Bind device interface not found",
140                    ))
141                }
142            }
143        }
144
145        match socket.connect(&peer_addr.into()) {
146            #[cfg(unix)]
147            Err(e) if e.raw_os_error() != Some(libc::EINPROGRESS) => return Err(e),
148            #[cfg(windows)]
149            Err(e) if e.kind() != io::ErrorKind::WouldBlock => return Err(e),
150            _ => {}
151        }
152
153        let stream = TcpStream::from_std(socket.into());
154        let local_addr = stream.local_addr()?;
155        Ok(ConnectionInfo {
156            remote: Self { stream, keepalive: config.keepalive },
157            local_addr,
158            peer_addr,
159        })
160    }
161
162    fn receive(&self, mut process_data: impl FnMut(&[u8])) -> ReadStatus {
163        let buffer: MaybeUninit<[u8; INPUT_BUFFER_SIZE]> = MaybeUninit::uninit();
164        let mut input_buffer = unsafe { buffer.assume_init() }; // Avoid to initialize the array
165
166        loop {
167            let mut stream = &self.stream;
168            match stream.read(&mut input_buffer) {
169                Ok(0) => break ReadStatus::Disconnected,
170                Ok(size) => process_data(&input_buffer[..size]),
171                Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
172                Err(ref err) if err.kind() == ErrorKind::WouldBlock => {
173                    break ReadStatus::WaitNextEvent
174                }
175                Err(ref err) if err.kind() == ErrorKind::ConnectionReset => {
176                    break ReadStatus::Disconnected
177                }
178                Err(err) => {
179                    log::error!("TCP receive error: {}", err);
180                    break ReadStatus::Disconnected; // should not happen
181                }
182            }
183        }
184    }
185
186    fn send(&self, data: &[u8]) -> SendStatus {
187        // TODO: The current implementation implies an active waiting,
188        // improve it using POLLIN instead to avoid active waiting.
189        // Note: Despite the fear that an active waiting could generate,
190        // this only occurs in the case when the receiver is full because reads slower that it sends.
191        let mut total_bytes_sent = 0;
192        loop {
193            let mut stream = &self.stream;
194            match stream.write(&data[total_bytes_sent..]) {
195                Ok(bytes_sent) => {
196                    total_bytes_sent += bytes_sent;
197                    if total_bytes_sent == data.len() {
198                        break SendStatus::Sent;
199                    }
200                }
201                Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue,
202
203                // Others errors are considered fatal for the connection.
204                // a Event::Disconnection will be generated later.
205                Err(err) => {
206                    log::error!("TCP receive error: {}", err);
207                    break SendStatus::ResourceNotFound; // should not happen
208                }
209            }
210        }
211    }
212
213    fn pending(&self, _readiness: Readiness) -> PendingStatus {
214        let status = check_stream_ready(&self.stream);
215
216        if status == PendingStatus::Ready {
217            if let Some(keepalive) = &self.keepalive {
218                #[cfg(target_os = "windows")]
219                let socket = unsafe { Socket::from_raw_socket(self.stream.as_raw_socket()) };
220                #[cfg(not(target_os = "windows"))]
221                let socket = unsafe { Socket::from_raw_fd(self.stream.as_raw_fd()) };
222
223                if let Err(e) = socket.set_tcp_keepalive(keepalive) {
224                    log::warn!("TCP set keepalive error: {}", e);
225                }
226
227                // Don't drop so the underlying socket is not closed.
228                forget(socket);
229            }
230        }
231
232        status
233    }
234}
235
236/// Check if a TcpStream can be considered connected.
237pub fn check_stream_ready(stream: &TcpStream) -> PendingStatus {
238    // A multiplatform non-blocking way to determine if the TCP stream is connected:
239    // Extracted from: https://github.com/tokio-rs/mio/issues/1486
240    if let Ok(Some(_)) = stream.take_error() {
241        return PendingStatus::Disconnected;
242    }
243    match stream.peer_addr() {
244        Ok(_) => PendingStatus::Ready,
245        Err(err) if err.kind() == io::ErrorKind::NotConnected => PendingStatus::Incomplete,
246        Err(err) if err.kind() == io::ErrorKind::InvalidInput => PendingStatus::Incomplete,
247        Err(_) => PendingStatus::Disconnected,
248    }
249}
250
251pub(crate) struct LocalResource {
252    listener: TcpListener,
253    keepalive: Option<TcpKeepalive>,
254}
255
256impl Resource for LocalResource {
257    fn source(&mut self) -> &mut dyn Source {
258        &mut self.listener
259    }
260}
261
262impl Local for LocalResource {
263    type Remote = RemoteResource;
264
265    fn listen_with(config: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
266        let config = match config {
267            TransportListen::Tcp(config) => config,
268            _ => panic!("Internal error: Got wrong config"),
269        };
270
271        let socket = Socket::new(
272            match addr {
273                SocketAddr::V4 { .. } => Domain::IPV4,
274                SocketAddr::V6 { .. } => Domain::IPV6,
275            },
276            Type::STREAM,
277            Some(Protocol::TCP),
278        )?;
279        socket.set_nonblocking(true)?;
280        socket.set_reuse_address(true)?;
281
282        #[cfg(unix)]
283        if let Some(bind_device) = config.bind_device {
284            let device = CString::new(bind_device)?;
285
286            #[cfg(not(target_os = "macos"))]
287            socket.bind_device(Some(device.as_bytes()))?;
288
289            #[cfg(target_os = "macos")]
290            match NonZeroU32::new(unsafe { libc::if_nametoindex(device.as_ptr()) }) {
291                Some(index) => socket.bind_device_by_index_v4(Some(index))?,
292                None => {
293                    return Err(io::Error::new(
294                        ErrorKind::NotFound,
295                        "Bind device interface not found",
296                    ))
297                }
298            }
299        }
300
301        socket.bind(&addr.into())?;
302        socket.listen(LISTENER_BACKLOG)?;
303
304        let listener = TcpListener::from_std(socket.into());
305
306        let local_addr = listener.local_addr().unwrap();
307        Ok(ListeningInfo {
308            local: { LocalResource { listener, keepalive: config.keepalive } },
309            local_addr,
310        })
311    }
312
313    fn accept(&self, mut accept_remote: impl FnMut(AcceptedType<'_, Self::Remote>)) {
314        loop {
315            match self.listener.accept() {
316                Ok((stream, addr)) => accept_remote(AcceptedType::Remote(
317                    addr,
318                    RemoteResource { stream, keepalive: self.keepalive.clone() },
319                )),
320                Err(ref err) if err.kind() == ErrorKind::WouldBlock => break,
321                Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
322                Err(err) => break log::error!("TCP accept error: {}", err), // Should not happen
323            }
324        }
325    }
326}