1use std::io::{self, Read, Write};
2use std::time::Instant;
3
4use crate::{DecodeError, Message};
5
/// Buffer sizing and timeout settings for a [`MessageStream`].
#[derive(Clone, Copy, Debug)]
pub struct StreamConfig {
    /// Capacity reserved for the receive reassembly buffer the first time a
    /// partial message has to be carried over, and the floor below which
    /// `shrink_buffers` will not shrink that buffer.
    pub rx_buf_min_size: usize,
    /// Initial capacity of the transmit buffer and its shrink floor.
    pub tx_buf_min_size: usize,
    /// Once more than this many bytes are waiting to be sent, `queue_message`
    /// refuses further messages (the check happens before encoding, so one
    /// message may push the buffer past this size).
    pub tx_buf_max_size: usize,
    /// How long queued data may wait, with no successful write in the
    /// meantime, before `is_write_stale` reports the stream as stale.
    pub stream_write_timeout: std::time::Duration,
    /// Connection-establishment timeout. Not read by the code in this file;
    /// presumably enforced by whoever opens the connection — TODO confirm.
    pub stream_connect_timeout: std::time::Duration,
}

impl Default for StreamConfig {
    /// 32 KiB receive/transmit minimums, a 1 MiB transmit cap, a 30 second
    /// write timeout and a 5 second connect timeout.
    fn default() -> Self {
        const KIB: usize = 1024;
        let seconds = std::time::Duration::from_secs;

        Self {
            rx_buf_min_size: 32 * KIB,
            tx_buf_min_size: 32 * KIB,
            tx_buf_max_size: 1024 * KIB,
            stream_write_timeout: seconds(30),
            stream_connect_timeout: seconds(5),
        }
    }
}
39
40#[derive(Debug)]
43pub struct MessageStream<T: Read + Write> {
44 config: StreamConfig,
46 stream: T,
48 rx_msg_buf: Vec<u8>,
50 tx_msg_buf: Vec<u8>,
52 tx_queue_points: queue_points::Queue,
54 ready: bool,
56 last_write: Instant,
58}
59
/// Why `MessageStream::read` stopped reading.
#[derive(Debug)]
pub enum ReadError {
    /// `Message::decode` rejected the received bytes.
    MalformedMessage,
    /// A read returned zero bytes: the peer closed the stream.
    EndOfStream,
    /// The underlying stream reported an I/O error.
    Error(io::Error),
}

impl std::fmt::Display for ReadError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ReadError::MalformedMessage => f.write_str("malformed message"),
            ReadError::EndOfStream => f.write_str("end of stream"),
            // The wrapped error is exposed through `source()` instead of being
            // repeated here, following the std error-reporting convention.
            ReadError::Error(_) => f.write_str("I/O error while reading stream"),
        }
    }
}

impl std::error::Error for ReadError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ReadError::Error(err) => Some(err),
            ReadError::MalformedMessage | ReadError::EndOfStream => None,
        }
    }
}

impl From<io::Error> for ReadError {
    fn from(err: io::Error) -> Self {
        ReadError::Error(err)
    }
}
69
70impl<T: Read + Write> MessageStream<T> {
71 pub fn new(stream: T, config: StreamConfig) -> Self {
72 Self {
73 stream,
74 rx_msg_buf: Vec::new(),
75 tx_msg_buf: Vec::with_capacity(config.tx_buf_min_size),
76 tx_queue_points: Default::default(),
77 ready: false,
78 last_write: Instant::now(),
79 config,
80 }
81 }
82
83 pub fn read<M: Message, F: Fn(M)>(
86 &mut self,
87 rx_buf: &mut [u8],
88 on_msg: F,
89 ) -> Result<(), ReadError> {
90 'read: loop {
91 match self.stream.read(rx_buf).map(|read| &rx_buf[..read]) {
92 Ok(&[]) => break 'read Err(ReadError::EndOfStream),
93
94 Ok(received) => {
95 if !self.rx_msg_buf.is_empty() {
96 self.rx_msg_buf.extend_from_slice(received);
97 'decode: loop {
98 if !self.rx_msg_buf.is_empty() {
99 match M::decode(&self.rx_msg_buf) {
100 Ok((message, consumed)) => {
101 self.rx_msg_buf.drain(..consumed);
102 on_msg(message);
103 }
104 Err(DecodeError::NotEnoughData) => break 'decode,
105 Err(DecodeError::MalformedMessage) => {
106 break 'read Err(ReadError::MalformedMessage)
107 }
108 }
109 } else {
110 break 'decode;
111 }
112 }
113 } else {
114 let mut next_from = 0;
115 'decode: loop {
116 let next = &received[next_from..];
117 if !next.is_empty() {
118 match M::decode(next) {
119 Ok((message, consumed)) => {
120 on_msg(message);
121 next_from += consumed;
122 }
123 Err(DecodeError::NotEnoughData) => {
124 if self.rx_msg_buf.capacity() == 0 {
125 self.rx_msg_buf
126 .reserve_exact(self.config.rx_buf_min_size);
127 }
128 self.rx_msg_buf.extend_from_slice(next);
129 break 'decode;
130 }
131 Err(DecodeError::MalformedMessage) => {
132 break 'read Err(ReadError::MalformedMessage);
133 }
134 }
135 } else {
136 break 'decode;
137 }
138 }
139 }
140 }
141
142 Err(err) if err.kind() == io::ErrorKind::WouldBlock => break 'read Ok(()),
143
144 Err(err) => break 'read Err(ReadError::Error(err)),
145 }
146 }
147 }
148
149 pub fn write(&mut self, now: Instant) -> io::Result<()> {
151 if !self.has_queued_data() {
152 return Ok(());
153 }
154
155 loop {
156 match self.attempt_write(now) {
157 Ok(written) => {
158 let has_more = self.has_queued_data();
159 log::trace!("wrote out {written} bytes, has more: {}", has_more);
160
161 if !has_more {
162 break Ok(());
163 }
164 }
165
166 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
167 log::trace!("write would block");
168 break Ok(());
169 }
170
171 Err(err) => break Err(err),
172 }
173 }
174 }
175 #[must_use]
185 pub fn queue_message<M: Message>(&mut self, message: &M) -> bool {
186 if self.tx_msg_buf.len() <= self.config.tx_buf_max_size {
187 let encoded = message.encode(&mut self.tx_msg_buf);
188 self.tx_queue_points.append(encoded);
189 true
190 } else {
191 false
192 }
193 }
194
195 pub fn is_write_stale(&self, now: Instant) -> bool {
198 match self.tx_queue_points.first() {
199 Some(t) => {
200 let timeout = self.config.stream_write_timeout;
201 (now - t > timeout) && (now - self.last_write > timeout)
202 }
203 None => false,
204 }
205 }
206
207 pub fn shrink_buffers(&mut self) {
212 fn shrink(v: &mut Vec<u8>, min: usize) {
213 if v.capacity() > min {
214 let shrink_to = (2 * v.capacity()) / 3;
215 v.shrink_to(min.max(shrink_to));
216 }
217 }
218 shrink(&mut self.rx_msg_buf, self.config.rx_buf_min_size);
219 shrink(&mut self.tx_msg_buf, self.config.tx_buf_min_size);
220 self.tx_queue_points.shrink();
221 }
222
223 pub fn interest(&self) -> mio::Interest {
225 if self.has_queued_data() {
226 mio::Interest::READABLE | mio::Interest::WRITABLE
227 } else {
228 mio::Interest::READABLE
229 }
230 }
231
232 fn attempt_write(&mut self, now: Instant) -> io::Result<usize> {
235 let written = self.stream.write(&self.tx_msg_buf)?;
236 self.tx_msg_buf.drain(..written);
237 self.stream.flush()?;
238 self.last_write = now;
239 self.tx_queue_points.handle_write(written);
240 Ok(written)
241 }
242
243 #[inline(always)]
245 fn has_queued_data(&self) -> bool {
246 !self.tx_msg_buf.is_empty()
247 }
248}
249
250impl MessageStream<mio::net::TcpStream> {
251 pub fn is_ready(&mut self) -> bool {
254 if !self.ready {
255 self.ready = self.stream.peer_addr().is_ok();
256 }
257 self.ready
258 }
259 pub fn shutdown(self) -> io::Result<()> {
261 self.stream.shutdown(std::net::Shutdown::Both)
262 }
263
264 pub fn as_source(&mut self) -> &mut impl mio::event::Source {
266 &mut self.stream
267 }
268}
269
mod queue_points {
    //! Tracks, for each message queued for transmission, when it was queued
    //! and how many of its bytes are still unwritten, oldest first.

    use std::time::Instant;

    /// One queued message.
    #[derive(Debug)]
    struct Point {
        /// When the message was queued.
        time: Instant,
        /// How many of its bytes have not been written yet.
        left: usize,
    }

    /// Queued messages in the order their bytes will be written.
    #[derive(Debug, Default)]
    pub struct Queue(Vec<Point>);

    impl Queue {
        /// Accounts for `n_written` bytes written from the front of the
        /// outgoing data, removing every message that is now fully written.
        ///
        /// # Panics
        ///
        /// Panics if `n_written` exceeds the number of bytes still
        /// outstanding, which means the queue no longer matches the data it
        /// describes.
        pub fn handle_write(&mut self, n_written: usize) {
            let mut n_bytes_left = n_written;
            let mut n_pop = 0;

            for q in &mut self.0 {
                let q_written = n_bytes_left.min(q.left);
                n_bytes_left -= q_written;
                q.left -= q_written;

                if q.left == 0 {
                    n_pop += 1;
                }

                if n_bytes_left == 0 {
                    break;
                }
            }

            assert_eq!(n_bytes_left, 0, "more bytes written than were queued");
            self.0.drain(..n_pop);
        }

        /// Records a newly queued message of `size` bytes, timestamped now.
        ///
        /// Zero-length messages are not recorded. With nothing left to write,
        /// such a point would only be removed when a later `handle_write`
        /// call happened to sweep past it. Until then, `first` would keep
        /// reporting its timestamp as pending data.
        pub fn append(&mut self, size: usize) {
            if size == 0 {
                return;
            }
            self.0.push(Point {
                time: Instant::now(),
                left: size,
            });
        }

        /// When the oldest message that still has unwritten bytes was queued.
        pub fn first(&self) -> Option<Instant> {
            self.0.first().map(|p| p.time)
        }

        /// Releases spare capacity: shrinks to two thirds of the current
        /// capacity, but keeps room for at least 8 points (and never goes
        /// below the current length).
        pub fn shrink(&mut self) {
            self.0.shrink_to((2 * self.0.capacity() / 3).max(8));
        }
    }

    #[cfg(test)]
    #[test]
    fn queue_behavior() {
        let mut queue = Queue::default();

        queue.append(10);
        queue.append(20);
        queue.append(30);

        assert_eq!(queue.0[0].left, 10);
        assert_eq!(queue.0[1].left, 20);
        assert_eq!(queue.0[2].left, 30);

        queue.handle_write(5);
        assert_eq!(queue.0[0].left, 5);
        assert_eq!(queue.0[1].left, 20);
        assert_eq!(queue.0[2].left, 30);

        queue.handle_write(5);
        assert_eq!(queue.0[0].left, 20);
        assert_eq!(queue.0[1].left, 30);

        queue.handle_write(25);
        assert_eq!(queue.0[0].left, 25);
        assert_eq!(queue.0.len(), 1);

        queue.handle_write(25);
        assert!(queue.first().is_none());

        // An empty message leaves nothing pending.
        queue.append(0);
        assert!(queue.first().is_none());
    }
}
360
#[cfg(test)]
mod test {
    use std::cell::RefCell;
    use std::io::Cursor;

    use super::*;

    /// Leading byte of every valid [`Tagged`] message.
    const TAG: u8 = 0x7E;

    /// Fixed-size test message: a `u64` encoded as 8 little-endian bytes.
    #[derive(Debug, Eq, PartialEq)]
    struct Ping(u64);

    impl Message for Ping {
        fn encode(&self, dest: &mut impl std::io::Write) -> usize {
            let bytes = self.0.to_le_bytes();
            // `write` may legally accept only part of the slice; `write_all` cannot.
            dest.write_all(&bytes).unwrap();
            bytes.len()
        }

        fn decode(buffer: &[u8]) -> Result<(Self, usize), DecodeError> {
            if buffer.len() >= 8 {
                Ok((Ping(u64::from_le_bytes(buffer[..8].try_into().unwrap())), 8))
            } else {
                Err(DecodeError::NotEnoughData)
            }
        }
    }

    /// Two-byte test message `[TAG, value]`; any other leading byte is malformed.
    #[derive(Debug, Eq, PartialEq)]
    struct Tagged(u8);

    impl Message for Tagged {
        fn encode(&self, dest: &mut impl std::io::Write) -> usize {
            dest.write_all(&[TAG, self.0]).unwrap();
            2
        }

        fn decode(buffer: &[u8]) -> Result<(Self, usize), DecodeError> {
            match buffer {
                [TAG, value, ..] => Ok((Tagged(*value), 2)),
                [] | [TAG] => Err(DecodeError::NotEnoughData),
                _ => Err(DecodeError::MalformedMessage),
            }
        }
    }

    /// In-memory stream that hands out at most `chunk` bytes per `read`, so
    /// messages straddle read boundaries. Writes are accepted and discarded.
    struct Trickle {
        data: Vec<u8>,
        pos: usize,
        chunk: usize,
    }

    impl Read for Trickle {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            let rest = &self.data[self.pos..];
            let n = rest.len().min(buf.len()).min(self.chunk);
            buf[..n].copy_from_slice(&rest[..n]);
            self.pos += n;
            Ok(n)
        }
    }

    impl Write for Trickle {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    #[test]
    fn reassemble_message_whole_reads() {
        let mut buf = [0; 1024];
        let mut cursor = Cursor::new(Vec::new());

        Ping(0).encode(&mut cursor);
        Ping(1).encode(&mut cursor);
        cursor.set_position(0);

        let mut conn = MessageStream::new(&mut cursor, StreamConfig::default());

        let received: RefCell<Vec<Ping>> = Default::default();
        let err = conn.read(&mut buf, |message| {
            received.borrow_mut().push(message);
        });

        assert_eq!(received.borrow()[0], Ping(0));
        assert_eq!(received.borrow()[1], Ping(1));
        assert!(matches!(err, Err(ReadError::EndOfStream)));
        assert_eq!(conn.stream.position(), 16);
        assert!(conn.rx_msg_buf.is_empty());
    }

    #[test]
    fn reassemble_message_partial_reads() {
        let mut buf = [0; 1024];
        let mut cursor = Cursor::new(Vec::new());
        let mut conn = MessageStream::new(&mut cursor, StreamConfig::default());
        let mut serialized = Vec::new();
        Ping(u64::MAX - 1).encode(&mut serialized);
        Ping(u64::MAX).encode(&mut serialized);

        let received: RefCell<Vec<Ping>> = Default::default();

        conn.stream.get_mut().extend_from_slice(&serialized[..4]);
        let _ = conn.read(&mut buf, |message| {
            received.borrow_mut().push(message);
        });
        assert!(received.borrow().is_empty());
        assert_eq!(conn.rx_msg_buf.len(), 4);

        conn.stream.get_mut().extend_from_slice(&serialized[4..]);
        let _ = conn.read(&mut buf, |message| {
            received.borrow_mut().push(message);
        });
        assert_eq!(received.borrow()[0], Ping(u64::MAX - 1));
        assert_eq!(received.borrow()[1], Ping(u64::MAX));
    }

    #[test]
    fn reassemble_messages_across_small_reads() {
        let mut data = Vec::new();
        for n in 0..3 {
            Ping(n).encode(&mut data);
        }

        for chunk in [1, 3, 5, 8, 13] {
            let stream = Trickle {
                data: data.clone(),
                pos: 0,
                chunk,
            };
            let mut conn = MessageStream::new(stream, StreamConfig::default());
            let received: RefCell<Vec<Ping>> = Default::default();
            let mut buf = [0; 1024];

            let res = conn.read(&mut buf, |message| received.borrow_mut().push(message));

            assert!(matches!(res, Err(ReadError::EndOfStream)), "chunk {chunk}");
            assert_eq!(*received.borrow(), [Ping(0), Ping(1), Ping(2)], "chunk {chunk}");
            assert!(conn.rx_msg_buf.is_empty(), "chunk {chunk}");
        }
    }

    #[test]
    fn malformed_message_stops_reading() {
        // The first read yields `[TAG, 1, TAG]`, leaving a partial message that
        // is completed from the reassembly buffer before the bad byte is hit.
        let data = vec![TAG, 1, TAG, 2, 0x00];
        let stream = Trickle {
            data,
            pos: 0,
            chunk: 3,
        };
        let mut conn = MessageStream::new(stream, StreamConfig::default());
        let received: RefCell<Vec<Tagged>> = Default::default();
        let mut buf = [0; 1024];

        let res = conn.read(&mut buf, |message| received.borrow_mut().push(message));

        assert!(matches!(res, Err(ReadError::MalformedMessage)));
        assert_eq!(*received.borrow(), [Tagged(1), Tagged(2)]);
        // Messages already delivered are not kept around for redelivery.
        assert_eq!(conn.rx_msg_buf, [0x00]);
    }

    #[test]
    fn send_message() {
        let mut wire = Cursor::new(Vec::<u8>::new());
        let mut connection = MessageStream::new(&mut wire, StreamConfig::default());

        assert!(connection.queue_message(&Ping(0)));
        assert!(connection.queue_message(&Ping(1)));
        assert!(connection.queue_message(&Ping(2)));

        let cloned_buffer = connection.tx_msg_buf.clone();
        connection.write(Instant::now()).unwrap();
        assert_eq!(wire.position(), 24);
        assert_eq!(wire.into_inner(), cloned_buffer);
    }

    #[test]
    fn send_message_buf_full() {
        let mut wire = Cursor::new(Vec::<u8>::new());
        let config = StreamConfig {
            tx_buf_min_size: 1,
            tx_buf_max_size: 7,
            ..Default::default()
        };
        let mut connection = MessageStream::new(&mut wire, config);

        assert!(connection.queue_message(&Ping(0)));
        assert!(!connection.queue_message(&Ping(1)));

        let buffer_len = connection.tx_msg_buf.len();
        let cloned_buffer = connection.tx_msg_buf.clone();
        connection.write(Instant::now()).unwrap();
        assert_eq!(wire.position(), buffer_len as u64);
        assert_eq!(wire.into_inner(), cloned_buffer);
    }
}