message_io/adapters/
framed_tcp.rs1pub use socket2::{TcpKeepalive};
2
3use crate::network::adapter::{
4 Resource, Remote, Local, Adapter, SendStatus, AcceptedType, ReadStatus, ConnectionInfo,
5 ListeningInfo, PendingStatus,
6};
7use crate::network::{RemoteAddr, Readiness, TransportConnect, TransportListen};
8use crate::util::encoding::{self, Decoder, MAX_ENCODED_SIZE};
9
10use mio::net::{TcpListener, TcpStream};
11use mio::event::{Source};
12
13use socket2::{Socket};
14
15use std::net::{SocketAddr};
16use std::io::{self, ErrorKind, Read, Write};
17use std::cell::{RefCell};
18use std::mem::{forget, MaybeUninit};
19#[cfg(target_os = "windows")]
20use std::os::windows::io::{FromRawSocket, AsRawSocket};
21#[cfg(not(target_os = "windows"))]
22use std::os::{fd::AsRawFd, unix::io::FromRawFd};
23
24const INPUT_BUFFER_SIZE: usize = u16::MAX as usize; #[derive(Clone, Debug, Default)]
27pub struct FramedTcpConnectConfig {
28 keepalive: Option<TcpKeepalive>,
29}
30
31impl FramedTcpConnectConfig {
32 pub fn with_keepalive(mut self, keepalive: TcpKeepalive) -> Self {
34 self.keepalive = Some(keepalive);
35 self
36 }
37}
38
39#[derive(Clone, Debug, Default)]
40pub struct FramedTcpListenConfig {
41 keepalive: Option<TcpKeepalive>,
42}
43
44impl FramedTcpListenConfig {
45 pub fn with_keepalive(mut self, keepalive: TcpKeepalive) -> Self {
47 self.keepalive = Some(keepalive);
48 self
49 }
50}
51
52pub(crate) struct FramedTcpAdapter;
53impl Adapter for FramedTcpAdapter {
54 type Remote = RemoteResource;
55 type Local = LocalResource;
56}
57
/// A connected framed-TCP stream plus the state needed to read frames from it.
pub(crate) struct RemoteResource {
    stream: TcpStream,
    // Decoder state kept across `receive` calls (presumably to hold partially received
    // frames). Wrapped in `RefCell` because `receive` only has `&self`.
    decoder: RefCell<Decoder>,
    // Keepalive applied in `pending` once the stream reports ready; `None` leaves it untouched.
    keepalive: Option<TcpKeepalive>,
}

// SAFETY: `RefCell` is not `Sync`, so this impl is only sound if the decoder is never borrowed
// from two threads at the same time. In this file the decoder is only touched by `receive`
// (through `borrow_mut`), so the requirement is that `receive` is never called concurrently on
// the same resource.
// NOTE(review): that single-thread guarantee is not enforced here; presumably the network poll
// loop calls `receive` from one thread only — confirm before relying on it.
unsafe impl Sync for RemoteResource {}
68
69impl RemoteResource {
70 fn new(stream: TcpStream, keepalive: Option<TcpKeepalive>) -> Self {
71 Self { stream, decoder: RefCell::new(Decoder::default()), keepalive }
72 }
73}
74
75impl Resource for RemoteResource {
76 fn source(&mut self) -> &mut dyn Source {
77 &mut self.stream
78 }
79}
80
81impl Remote for RemoteResource {
82 fn connect_with(
83 config: TransportConnect,
84 remote_addr: RemoteAddr,
85 ) -> io::Result<ConnectionInfo<Self>> {
86 let config = match config {
87 TransportConnect::FramedTcp(config) => config,
88 _ => panic!("Internal error: Got wrong config"),
89 };
90 let peer_addr = *remote_addr.socket_addr();
91 let stream = TcpStream::connect(peer_addr)?;
92 let local_addr = stream.local_addr()?;
93 Ok(ConnectionInfo {
94 remote: RemoteResource::new(stream, config.keepalive),
95 local_addr,
96 peer_addr,
97 })
98 }
99
100 fn receive(&self, mut process_data: impl FnMut(&[u8])) -> ReadStatus {
101 let buffer: MaybeUninit<[u8; INPUT_BUFFER_SIZE]> = MaybeUninit::uninit();
102 let mut input_buffer = unsafe { buffer.assume_init() }; loop {
105 let mut stream = &self.stream;
106 match stream.read(&mut input_buffer) {
107 Ok(0) => break ReadStatus::Disconnected,
108 Ok(size) => {
109 let data = &input_buffer[..size];
110 log::trace!("Decoding {} bytes", data.len());
111 self.decoder.borrow_mut().decode(data, |decoded_data| {
112 process_data(decoded_data);
113 });
114 }
115 Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
116 Err(ref err) if err.kind() == ErrorKind::WouldBlock => {
117 break ReadStatus::WaitNextEvent
118 }
119 Err(ref err) if err.kind() == ErrorKind::ConnectionReset => {
120 break ReadStatus::Disconnected
121 }
122 Err(err) => {
123 log::error!("TCP receive error: {}", err);
124 break ReadStatus::Disconnected; }
126 }
127 }
128 }
129
130 fn send(&self, data: &[u8]) -> SendStatus {
131 let mut buf = [0; MAX_ENCODED_SIZE]; let encoded_size = encoding::encode_size(data, &mut buf);
133
134 let mut total_bytes_sent = 0;
135 let total_bytes = encoded_size.len() + data.len();
136 loop {
137 let data_to_send = match total_bytes_sent < encoded_size.len() {
138 true => &encoded_size[total_bytes_sent..],
139 false => &data[total_bytes_sent - encoded_size.len()..],
140 };
141
142 let mut stream = &self.stream;
143 match stream.write(data_to_send) {
144 Ok(bytes_sent) => {
145 total_bytes_sent += bytes_sent;
146 if total_bytes_sent == total_bytes {
147 break SendStatus::Sent;
148 }
149 }
150 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue,
151 Err(err) => {
152 log::error!("TCP receive error: {}", err);
153 break SendStatus::ResourceNotFound; }
155 }
156 }
157 }
158
159 fn pending(&self, _readiness: Readiness) -> PendingStatus {
160 let status = super::tcp::check_stream_ready(&self.stream);
161
162 if status == PendingStatus::Ready {
163 if let Some(keepalive) = &self.keepalive {
164 #[cfg(target_os = "windows")]
165 let socket = unsafe { Socket::from_raw_socket(self.stream.as_raw_socket()) };
166 #[cfg(not(target_os = "windows"))]
167 let socket = unsafe { Socket::from_raw_fd(self.stream.as_raw_fd()) };
168
169 if let Err(e) = socket.set_tcp_keepalive(keepalive) {
170 log::warn!("TCP set keepalive error: {}", e);
171 }
172
173 forget(socket);
175 }
176 }
177
178 status
179 }
180}
181
182pub(crate) struct LocalResource {
183 listener: TcpListener,
184 keepalive: Option<TcpKeepalive>,
185}
186
187impl Resource for LocalResource {
188 fn source(&mut self) -> &mut dyn Source {
189 &mut self.listener
190 }
191}
192
193impl Local for LocalResource {
194 type Remote = RemoteResource;
195
196 fn listen_with(config: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
197 let config = match config {
198 TransportListen::FramedTcp(config) => config,
199 _ => panic!("Internal error: Got wrong config"),
200 };
201 let listener = TcpListener::bind(addr)?;
202 let local_addr = listener.local_addr().unwrap();
203 Ok(ListeningInfo {
204 local: { LocalResource { listener, keepalive: config.keepalive } },
205 local_addr,
206 })
207 }
208
209 fn accept(&self, mut accept_remote: impl FnMut(AcceptedType<'_, Self::Remote>)) {
210 loop {
211 match self.listener.accept() {
212 Ok((stream, addr)) => accept_remote(AcceptedType::Remote(
213 addr,
214 RemoteResource::new(stream, self.keepalive.clone()),
215 )),
216 Err(ref err) if err.kind() == ErrorKind::WouldBlock => break,
217 Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
218 Err(err) => break log::error!("TCP accept error: {}", err), }
220 }
221 }
222}