1use std::io::{self, Read, Write};
2use std::num::NonZeroU8;
3use std::time::Instant;
4
5use crate::{DecodeError, Message};
6
7#[derive(Debug, Clone, Copy)]
9pub struct StreamConfig {
10 pub rx_buf_min_size: usize,
14
15 pub rx_buf_max_size: MaxMessageSizeMultiple,
20
21 pub tx_buf_min_size: usize,
25
26 pub tx_buf_max_size: MaxMessageSizeMultiple,
31
32 pub tx_timeout: std::time::Duration,
34
35 pub connect_timeout: std::time::Duration,
38}
39
/// A buffer-size limit expressed as a non-zero multiple of a message type's
/// `Message::MAX_SIZE`; turned into a byte count by
/// [`MaxMessageSizeMultiple::compute`].
#[derive(Clone, Copy, Debug)]
pub struct MaxMessageSizeMultiple(pub NonZeroU8);
47
48impl MaxMessageSizeMultiple {
49 pub fn compute<M: Message>(&self) -> usize {
50 (self.0.get() as usize) * M::MAX_SIZE
51 }
52}
53
54impl Default for StreamConfig {
55 fn default() -> Self {
56 Self {
57 rx_buf_min_size: 4 * 1024,
58 rx_buf_max_size: MaxMessageSizeMultiple(NonZeroU8::new(1).unwrap()),
59 tx_buf_min_size: 4 * 1024,
60 tx_buf_max_size: MaxMessageSizeMultiple(NonZeroU8::new(2).unwrap()),
61 tx_timeout: std::time::Duration::from_secs(30),
62 connect_timeout: std::time::Duration::from_secs(5),
63 }
64 }
65}
66
67#[derive(Debug)]
70pub struct MessageStream<T: Read + Write> {
71 config: StreamConfig,
73 stream: T,
75 rx_msg_buf: Vec<u8>,
77 tx_msg_buf: Vec<u8>,
79 tx_queue_points: queue_points::Queue,
81 ready: bool,
83 last_write: Instant,
85}
86
/// Failure modes of `MessageStream::read`.
#[derive(Debug)]
pub enum ReadError {
    /// Received data cannot form a valid message.
    MalformedMessage,
    /// The peer closed the stream (a read returned zero bytes).
    EndOfStream,
    /// Any other I/O error reported by the underlying stream.
    Error(io::Error),
}

impl std::fmt::Display for ReadError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MalformedMessage => f.write_str("malformed message"),
            Self::EndOfStream => f.write_str("end of stream"),
            // The wrapped error is exposed through `source()`, so it is not
            // repeated here (per the std error-reporting convention).
            Self::Error(_) => f.write_str("I/O error"),
        }
    }
}

impl std::error::Error for ReadError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Error(err) => Some(err),
            Self::MalformedMessage | Self::EndOfStream => None,
        }
    }
}
96
97impl<T: Read + Write> MessageStream<T> {
104 pub fn new(stream: T, config: StreamConfig) -> Self {
106 Self {
107 stream,
108 rx_msg_buf: Vec::new(),
109 tx_msg_buf: Vec::new(),
110 tx_queue_points: Default::default(),
111 ready: false,
112 last_write: Instant::now(),
113 config,
114 }
115 }
116
117 #[must_use]
124 pub fn read<M: Message, F: Fn(M, usize)>(
125 &mut self,
126 rx_buf: &mut [u8],
127 on_msg: F,
128 ) -> Result<bool, ReadError> {
129 let preexisting = !self.rx_msg_buf.is_empty();
130
131 let max_buf_size = self.config.rx_buf_max_size.compute::<M>();
133 let limit = (max_buf_size - self.rx_msg_buf.len()).min(rx_buf.len());
134
135 let (total_read, read_result) = {
136 let buffer = &mut rx_buf[..limit];
137 let mut total_read: usize = 0;
138
139 let result = loop {
140 match self.stream.read(&mut buffer[total_read..]) {
141 Ok(0) if buffer.len() == 0 => break Ok(true),
143 Ok(0) => break Err(ReadError::EndOfStream),
145 Ok(read @ 1..) => {
147 total_read += read;
148 if total_read == buffer.len() {
149 break Ok(true);
151 }
152 }
153 Err(err) if err.kind() == io::ErrorKind::WouldBlock => break Ok(false),
154 Err(err) => break Err(ReadError::Error(err)),
155 }
156 };
157
158 (total_read, result)
159 };
160
161 let _decode_has_more = if !preexisting {
163 let (consumed, result) = decode_from_buffer(&mut rx_buf[..total_read], on_msg)?;
164 if consumed < total_read {
165 self.rx_msg_buf
166 .extend_from_slice(&rx_buf[consumed..total_read]);
167 }
168 result
169 } else {
170 self.rx_msg_buf.extend_from_slice(&rx_buf[..total_read]);
171 let (consumed, result) = decode_from_buffer(&mut &mut self.rx_msg_buf[..], on_msg)?;
172 self.rx_msg_buf.drain(..consumed);
173 result
174 };
175
176 read_result
177 }
178
179 #[must_use]
183 pub fn write(&mut self, now: Instant) -> io::Result<bool> {
184 if !self.has_queued_data() {
185 return Ok(false);
186 }
187
188 loop {
189 match self.try_write(now) {
190 Ok(written) => {
191 let has_more = self.has_queued_data();
192 log::trace!("wrote out {written} bytes, has more: {}", has_more);
193
194 if !has_more {
195 break Ok(false);
196 }
197 }
198
199 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
200 log::trace!("write would block");
201 break Ok(self.has_queued_data());
202 }
203
204 Err(err) => break Err(err),
205 }
206 }
207 }
208 #[must_use]
221 pub fn queue_message<M: Message>(&mut self, message: &M) -> bool {
222 let size_hint = message.size_hint().unwrap_or_default();
223 if size_hint + self.tx_msg_buf.len() >= self.config.tx_buf_max_size.compute::<M>() {
224 false
225 } else {
226 let encoded = message.encode(&mut self.tx_msg_buf);
227 self.tx_queue_points.append(encoded);
228 true
229 }
230 }
231
232 pub fn is_write_stale(&self, now: Instant) -> bool {
235 self.tx_queue_points.first().is_some_and(|t| {
236 let timeout = self.config.tx_timeout;
237 (now - t > self.config.tx_timeout) && (now - self.last_write > timeout)
238 })
239 }
240
241 pub fn shrink_buffers(&mut self) {
246 fn shrink(v: &mut Vec<u8>, min: usize) {
247 if v.capacity() > min {
248 let shrink_to = 3 * (v.capacity() / 4);
249 v.shrink_to(min.max(shrink_to));
250 }
251 }
252 shrink(&mut self.rx_msg_buf, self.config.rx_buf_min_size);
253 shrink(&mut self.tx_msg_buf, self.config.tx_buf_min_size);
254 self.tx_queue_points.shrink();
255 }
256
257 fn try_write(&mut self, now: Instant) -> io::Result<usize> {
260 let written = self.stream.write(&self.tx_msg_buf)?;
261 self.tx_msg_buf.drain(..written);
262 self.stream.flush()?;
263 self.last_write = now;
264 self.tx_queue_points.mark_write(written);
265 Ok(written)
266 }
267
268 #[inline(always)]
270 pub fn has_queued_data(&self) -> bool {
271 !self.tx_msg_buf.is_empty()
272 }
273}
274
275fn decode_from_buffer<M: Message, F: Fn(M, usize)>(
276 buffer: &mut [u8],
277 on_msg: F,
278) -> Result<(usize, bool), ReadError> {
279 let mut cursor: usize = 0;
280 loop {
281 match M::decode(&buffer[cursor..]) {
282 Ok((message, consumed)) => {
283 cursor += consumed;
284 on_msg(message, consumed);
285 }
286 Err(DecodeError::NotEnoughData) => {
287 break Ok((cursor, false)); }
289 Err(DecodeError::MalformedMessage) => {
290 break Err(ReadError::MalformedMessage);
291 }
292 }
293 }
294}
295
296impl MessageStream<mio::net::TcpStream> {
297 pub fn is_ready(&mut self) -> bool {
300 if !self.ready {
301 self.ready = self.stream.peer_addr().is_ok();
302 }
303 self.ready
304 }
305 pub fn shutdown(self) -> io::Result<()> {
307 self.stream.shutdown(std::net::Shutdown::Both)
308 }
309
310 pub fn take_error(&self) -> Option<io::Error> {
311 self.stream.take_error().ok().flatten()
312 }
313
314 pub fn as_source(&mut self) -> &mut impl mio::event::Source {
316 &mut self.stream
317 }
318}
319
/// Tracks, for each queued outgoing message, when it was enqueued and how many
/// of its bytes are still unsent, so the age of the oldest pending message can
/// be checked.
mod queue_points {
    use std::collections::VecDeque;
    use std::time::Instant;

    /// One queued message: its enqueue time and its count of unsent bytes.
    #[derive(Debug)]
    struct Point {
        time: Instant,
        left: usize,
    }

    /// FIFO of [`Point`]s, oldest first, in the same order as the bytes in
    /// the transmit buffer.
    #[derive(Debug, Default)]
    pub struct Queue(VecDeque<Point>);

    impl Queue {
        /// Accounts for `n_written` bytes having been sent from the front of
        /// the transmit buffer.
        ///
        /// Every point at the front that is now fully sent is removed,
        /// including zero-length points. Removing those avoids leaving a
        /// phantom "oldest message" behind, which would keep `first()`
        /// returning `Some` with nothing actually pending.
        ///
        /// # Panics
        /// If `n_written` exceeds the total number of unsent bytes tracked.
        pub fn mark_write(&mut self, n_written: usize) {
            let mut remaining = n_written;

            while let Some(front) = self.0.front_mut() {
                let sent = remaining.min(front.left);
                front.left -= sent;
                remaining -= sent;

                if front.left > 0 {
                    break;
                }
                self.0.pop_front();
            }

            assert_eq!(remaining, 0);
        }

        /// Records a newly queued message of `size` bytes, stamped with
        /// `Instant::now()`.
        pub fn append(&mut self, size: usize) {
            self.0.push_back(Point {
                time: Instant::now(),
                left: size,
            })
        }

        /// Enqueue time of the oldest message that still has unsent bytes, or
        /// `None` if nothing is tracked.
        pub fn first(&self) -> Option<Instant> {
            self.0.front().map(|p| p.time)
        }

        /// Shrinks the backing storage toward 3/4 of its capacity, never
        /// below 8 slots.
        pub fn shrink(&mut self) {
            if self.0.capacity() > 8 {
                self.0.shrink_to(8.max(3 * (self.0.capacity() / 4)));
            }
        }
    }

    #[cfg(test)]
    #[test]
    fn queue_behavior() {
        let mut queue = Queue::default();

        queue.append(10);
        queue.append(20);
        queue.append(30);

        assert_eq!(queue.0[0].left, 10);
        assert_eq!(queue.0[1].left, 20);
        assert_eq!(queue.0[2].left, 30);

        queue.mark_write(5);
        assert_eq!(queue.0[0].left, 5);
        assert_eq!(queue.0[1].left, 20);
        assert_eq!(queue.0[2].left, 30);

        queue.mark_write(5);
        assert_eq!(queue.0[0].left, 20);
        assert_eq!(queue.0[1].left, 30);

        queue.mark_write(25);
        assert_eq!(queue.0[0].left, 25);
        assert_eq!(queue.0.len(), 1);

        queue.mark_write(25);
        assert!(queue.first().is_none());
    }

    #[cfg(test)]
    #[test]
    fn zero_length_points_are_released() {
        let mut queue = Queue::default();
        queue.append(5);
        queue.append(0);

        queue.mark_write(5);
        assert!(queue.first().is_none());
    }
}
413
#[cfg(test)]
mod test {
    use std::cell::RefCell;
    use std::io::Cursor;

    use super::*;

    /// Fixed-size test message: a `u64` encoded as 8 little-endian bytes.
    #[derive(Debug, Eq, PartialEq)]
    struct Ping(u64);

    impl Message for Ping {
        const MAX_SIZE: usize = 8;

        fn encode(&self, dest: &mut impl std::io::Write) -> usize {
            let bytes = self.0.to_le_bytes();
            dest.write(&bytes).unwrap()
        }

        fn decode(buffer: &[u8]) -> Result<(Self, usize), DecodeError> {
            match buffer.get(..8) {
                Some(head) => {
                    let raw: [u8; 8] = head.try_into().unwrap();
                    Ok((Ping(u64::from_le_bytes(raw)), 8))
                }
                None => Err(DecodeError::NotEnoughData),
            }
        }
    }

    /// Builds an `on_msg` callback that checks every message is 8 bytes long
    /// and appends it to `sink`.
    fn record(sink: &RefCell<Vec<Ping>>) -> impl Fn(Ping, usize) + '_ {
        move |message, size| {
            assert_eq!(size, 8);
            sink.borrow_mut().push(message);
        }
    }

    #[test]
    fn reassemble_message_whole_reads() {
        let mut buf = [0; 1024];
        let mut cursor = Cursor::new(Vec::<u8>::new());
        for n in 0..2u64 {
            Ping(n).encode(&mut cursor);
        }
        cursor.set_position(0);

        let mut conn = MessageStream::new(&mut cursor, StreamConfig::default());
        let received: RefCell<Vec<Ping>> = RefCell::default();

        conn.read::<Ping, _>(&mut buf, record(&received)).unwrap();
        assert_eq!(received.borrow()[0], Ping(0));

        conn.read::<Ping, _>(&mut buf, record(&received)).unwrap();
        assert_eq!(received.borrow()[1], Ping(1));

        let outcome = conn.read::<Ping, _>(&mut buf, record(&received));
        assert!(matches!(outcome, Err(ReadError::EndOfStream)));
        assert_eq!(conn.stream.position(), 16);
        assert!(conn.rx_msg_buf.is_empty());
    }

    #[test]
    fn reassemble_message_partial_reads() {
        let mut buf = [0; 8];
        let mut cursor = Cursor::new(Vec::<u8>::new());
        let mut conn = MessageStream::new(&mut cursor, StreamConfig::default());

        let mut serialized = Vec::<u8>::new();
        for value in [u64::MAX - 1, u64::MAX] {
            Ping(value).encode(&mut serialized);
        }
        let (head, tail) = serialized.split_at(4);

        let received: RefCell<Vec<Ping>> = RefCell::default();

        // Only half of the first message is on the wire: nothing decodes yet.
        conn.stream.get_mut().extend_from_slice(head);
        let _ = conn.read::<Ping, _>(&mut buf, record(&received));
        assert!(received.borrow().is_empty());
        assert_eq!(conn.rx_msg_buf.len(), 4);

        conn.stream.get_mut().extend_from_slice(tail);
        let _ = conn.read::<Ping, _>(&mut buf, record(&received));
        assert_eq!(received.borrow()[0], Ping(u64::MAX - 1));

        let _ = conn.read::<Ping, _>(&mut buf, record(&received));
        assert_eq!(received.borrow()[1], Ping(u64::MAX));
    }

    #[test]
    fn send_message() {
        let mut wire = Cursor::new(Vec::<u8>::new());
        let config = StreamConfig {
            tx_buf_max_size: MaxMessageSizeMultiple(3.try_into().unwrap()),
            ..Default::default()
        };
        let mut connection = MessageStream::new(&mut wire, config);

        for n in 0..3u64 {
            assert!(connection.queue_message(&Ping(n)));
        }

        let expected = connection.tx_msg_buf.clone();
        connection.write(Instant::now()).unwrap();
        assert_eq!(wire.position(), 24);
        assert_eq!(wire.into_inner(), expected);
    }

    #[test]
    fn send_message_buf_full() {
        let mut wire = Cursor::new(Vec::<u8>::new());
        let mut connection = MessageStream::new(
            &mut wire,
            StreamConfig {
                tx_buf_min_size: 1,
                tx_buf_max_size: MaxMessageSizeMultiple(1.try_into().unwrap()),
                ..Default::default()
            },
        );

        assert!(connection.queue_message(&Ping(0)));
        // The transmit buffer holds exactly one message, so a second is refused.
        assert!(!connection.queue_message(&Ping(1)));

        let expected = connection.tx_msg_buf.clone();
        connection.write(Instant::now()).unwrap();
        assert_eq!(wire.position(), expected.len() as u64);
        assert_eq!(wire.into_inner(), expected);
    }
}