message_io/adapters/
framed_tcp.rs

1pub use socket2::{TcpKeepalive};
2
3use crate::network::adapter::{
4    Resource, Remote, Local, Adapter, SendStatus, AcceptedType, ReadStatus, ConnectionInfo,
5    ListeningInfo, PendingStatus,
6};
7use crate::network::{RemoteAddr, Readiness, TransportConnect, TransportListen};
8use crate::util::encoding::{self, Decoder, MAX_ENCODED_SIZE};
9
10use mio::net::{TcpListener, TcpStream};
11use mio::event::{Source};
12
13use socket2::{Socket};
14
15use std::net::{SocketAddr};
16use std::io::{self, ErrorKind, Read, Write};
17use std::cell::{RefCell};
18use std::mem::{forget, MaybeUninit};
19#[cfg(target_os = "windows")]
20use std::os::windows::io::{FromRawSocket, AsRawSocket};
21#[cfg(not(target_os = "windows"))]
22use std::os::{fd::AsRawFd, unix::io::FromRawFd};
23
/// Capacity, in bytes, of the buffer each `read` call fills in `RemoteResource::receive`
/// (2^16 - 1).
const INPUT_BUFFER_SIZE: usize = (1 << 16) - 1;
25
26#[derive(Clone, Debug, Default)]
27pub struct FramedTcpConnectConfig {
28    keepalive: Option<TcpKeepalive>,
29}
30
31impl FramedTcpConnectConfig {
32    /// Enables TCP keepalive settings on the socket.
33    pub fn with_keepalive(mut self, keepalive: TcpKeepalive) -> Self {
34        self.keepalive = Some(keepalive);
35        self
36    }
37}
38
39#[derive(Clone, Debug, Default)]
40pub struct FramedTcpListenConfig {
41    keepalive: Option<TcpKeepalive>,
42}
43
44impl FramedTcpListenConfig {
45    /// Enables TCP keepalive settings on client connection sockets.
46    pub fn with_keepalive(mut self, keepalive: TcpKeepalive) -> Self {
47        self.keepalive = Some(keepalive);
48        self
49    }
50}
51
52pub(crate) struct FramedTcpAdapter;
53impl Adapter for FramedTcpAdapter {
54    type Remote = RemoteResource;
55    type Local = LocalResource;
56}
57
58pub(crate) struct RemoteResource {
59    stream: TcpStream,
60    decoder: RefCell<Decoder>,
61    keepalive: Option<TcpKeepalive>,
62}
63
64// SAFETY:
65// That RefCell<Decoder> can be used with Sync because the decoder is only used in the read_event,
66// that will be called always from the same thread. This way, we save the cost of a Mutex.
67unsafe impl Sync for RemoteResource {}
68
69impl RemoteResource {
70    fn new(stream: TcpStream, keepalive: Option<TcpKeepalive>) -> Self {
71        Self { stream, decoder: RefCell::new(Decoder::default()), keepalive }
72    }
73}
74
75impl Resource for RemoteResource {
76    fn source(&mut self) -> &mut dyn Source {
77        &mut self.stream
78    }
79}
80
81impl Remote for RemoteResource {
82    fn connect_with(
83        config: TransportConnect,
84        remote_addr: RemoteAddr,
85    ) -> io::Result<ConnectionInfo<Self>> {
86        let config = match config {
87            TransportConnect::FramedTcp(config) => config,
88            _ => panic!("Internal error: Got wrong config"),
89        };
90        let peer_addr = *remote_addr.socket_addr();
91        let stream = TcpStream::connect(peer_addr)?;
92        let local_addr = stream.local_addr()?;
93        Ok(ConnectionInfo {
94            remote: RemoteResource::new(stream, config.keepalive),
95            local_addr,
96            peer_addr,
97        })
98    }
99
100    fn receive(&self, mut process_data: impl FnMut(&[u8])) -> ReadStatus {
101        let buffer: MaybeUninit<[u8; INPUT_BUFFER_SIZE]> = MaybeUninit::uninit();
102        let mut input_buffer = unsafe { buffer.assume_init() }; // Avoid to initialize the array
103
104        loop {
105            let mut stream = &self.stream;
106            match stream.read(&mut input_buffer) {
107                Ok(0) => break ReadStatus::Disconnected,
108                Ok(size) => {
109                    let data = &input_buffer[..size];
110                    log::trace!("Decoding {} bytes", data.len());
111                    self.decoder.borrow_mut().decode(data, |decoded_data| {
112                        process_data(decoded_data);
113                    });
114                }
115                Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
116                Err(ref err) if err.kind() == ErrorKind::WouldBlock => {
117                    break ReadStatus::WaitNextEvent
118                }
119                Err(ref err) if err.kind() == ErrorKind::ConnectionReset => {
120                    break ReadStatus::Disconnected
121                }
122                Err(err) => {
123                    log::error!("TCP receive error: {}", err);
124                    break ReadStatus::Disconnected; // should not happen
125                }
126            }
127        }
128    }
129
130    fn send(&self, data: &[u8]) -> SendStatus {
131        let mut buf = [0; MAX_ENCODED_SIZE]; // used to avoid a heap allocation
132        let encoded_size = encoding::encode_size(data, &mut buf);
133
134        let mut total_bytes_sent = 0;
135        let total_bytes = encoded_size.len() + data.len();
136        loop {
137            let data_to_send = match total_bytes_sent < encoded_size.len() {
138                true => &encoded_size[total_bytes_sent..],
139                false => &data[total_bytes_sent - encoded_size.len()..],
140            };
141
142            let mut stream = &self.stream;
143            match stream.write(data_to_send) {
144                Ok(bytes_sent) => {
145                    total_bytes_sent += bytes_sent;
146                    if total_bytes_sent == total_bytes {
147                        break SendStatus::Sent;
148                    }
149                }
150                Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue,
151                Err(err) => {
152                    log::error!("TCP receive error: {}", err);
153                    break SendStatus::ResourceNotFound; // should not happen
154                }
155            }
156        }
157    }
158
159    fn pending(&self, _readiness: Readiness) -> PendingStatus {
160        let status = super::tcp::check_stream_ready(&self.stream);
161
162        if status == PendingStatus::Ready {
163            if let Some(keepalive) = &self.keepalive {
164                #[cfg(target_os = "windows")]
165                let socket = unsafe { Socket::from_raw_socket(self.stream.as_raw_socket()) };
166                #[cfg(not(target_os = "windows"))]
167                let socket = unsafe { Socket::from_raw_fd(self.stream.as_raw_fd()) };
168
169                if let Err(e) = socket.set_tcp_keepalive(keepalive) {
170                    log::warn!("TCP set keepalive error: {}", e);
171                }
172
173                // Don't drop so the underlying socket is not closed.
174                forget(socket);
175            }
176        }
177
178        status
179    }
180}
181
182pub(crate) struct LocalResource {
183    listener: TcpListener,
184    keepalive: Option<TcpKeepalive>,
185}
186
187impl Resource for LocalResource {
188    fn source(&mut self) -> &mut dyn Source {
189        &mut self.listener
190    }
191}
192
193impl Local for LocalResource {
194    type Remote = RemoteResource;
195
196    fn listen_with(config: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
197        let config = match config {
198            TransportListen::FramedTcp(config) => config,
199            _ => panic!("Internal error: Got wrong config"),
200        };
201        let listener = TcpListener::bind(addr)?;
202        let local_addr = listener.local_addr().unwrap();
203        Ok(ListeningInfo {
204            local: { LocalResource { listener, keepalive: config.keepalive } },
205            local_addr,
206        })
207    }
208
209    fn accept(&self, mut accept_remote: impl FnMut(AcceptedType<'_, Self::Remote>)) {
210        loop {
211            match self.listener.accept() {
212                Ok((stream, addr)) => accept_remote(AcceptedType::Remote(
213                    addr,
214                    RemoteResource::new(stream, self.keepalive.clone()),
215                )),
216                Err(ref err) if err.kind() == ErrorKind::WouldBlock => break,
217                Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
218                Err(err) => break log::error!("TCP accept error: {}", err), // Should not happen
219            }
220        }
221    }
222}