peerlink/
message_stream.rs

1use std::io::{self, Read, Write};
2use std::num::NonZeroU8;
3use std::time::Instant;
4
5use crate::{DecodeError, Message};
6
7/// Stream related configuration parameters.
8#[derive(Debug, Clone, Copy)]
9pub struct StreamConfig {
10    /// Defines the minimum size of the per-connection receive buffer used for message reassembly.
11    /// Low values will cause more frequent reallocation while high values will reallocate less at
12    /// the expense of more memory usage. A buffer never shrinks to less than this value.
13    pub rx_buf_min_size: usize,
14
15    /// Defines the maximum size of the per-connection receive buffer used for message reassembly.
16    /// This defaults to [`Message::MAX_SIZE`]. This number is very important in DoS mitigation
17    /// because it prevents a malicious sender from filling up a connection's receive buffer with
18    /// endless junk data.
19    pub rx_buf_max_size: MaxMessageSizeMultiple,
20
21    /// Defines the minimum size of the per-connection send buffer. Low values will cause more
22    /// frequent reallocation while high values will reallocate less at the expense of more memory
23    /// usage. A buffer never shrinks to less than this value.
24    pub tx_buf_min_size: usize,
25
26    /// Defines the maximum size of the per-connection send buffer. Once the send buffer is full, it
27    /// is not possible to queue new messages for sending until some capacity is available. A send
28    /// buffer becomes full when sending messages faster than the remote peer is reading. This value
29    /// is important for outbound backpressure control. Defaults to 2 x [`Message::MAX_SIZE`].
30    pub tx_buf_max_size: MaxMessageSizeMultiple,
31
32    /// The duration after which a peer is disconnected if it fails to read our data.
33    pub tx_timeout: std::time::Duration,
34
35    /// The duration after which a connect attempt is abandoned. Applies only to non-blocking
36    /// connect attempts. Blocking ones performed in custom connectors ignore this value.
37    pub connect_timeout: std::time::Duration,
38}
39
40/// A deferred size expressed as a multiple of the protocol's maximum message size
41/// ([`Message::MAX_SIZE`]).
42///
43/// This type does not store a byte length. Instead, it stores a factor that is multiplied by the
44/// maximum message size defined by the protocol to obtain a concrete size at evaluation time.
45#[derive(Debug, Clone, Copy)]
46pub struct MaxMessageSizeMultiple(pub NonZeroU8);
47
48impl MaxMessageSizeMultiple {
49    pub fn compute<M: Message>(&self) -> usize {
50        (self.0.get() as usize) * M::MAX_SIZE
51    }
52}
53
54impl Default for StreamConfig {
55    fn default() -> Self {
56        Self {
57            rx_buf_min_size: 4 * 1024,
58            rx_buf_max_size: MaxMessageSizeMultiple(NonZeroU8::new(1).unwrap()),
59            tx_buf_min_size: 4 * 1024,
60            tx_buf_max_size: MaxMessageSizeMultiple(NonZeroU8::new(2).unwrap()),
61            tx_timeout: std::time::Duration::from_secs(30),
62            connect_timeout: std::time::Duration::from_secs(5),
63        }
64    }
65}
66
67/// Wraps read and write parts of a peer connection and does message buffering and assembly
68/// on the read side and message serialization and flushing on the write side.
69#[derive(Debug)]
70pub struct MessageStream<T: Read + Write> {
71    /// Configuration parameters for the stream.
72    config: StreamConfig,
73    /// The read+write stream underlying the connection.
74    stream: T,
75    /// Buffer used for message reconstruction.
76    rx_msg_buf: Vec<u8>,
77    /// Buffer used for sending.
78    tx_msg_buf: Vec<u8>,
79    /// The list of queue points for outgoing messages.
80    tx_queue_points: queue_points::Queue,
81    /// Cached readyness.
82    ready: bool,
83    /// Last successful write time.
84    last_write: Instant,
85}
86
/// Reasons why reading from a message stream can fail.
#[derive(Debug)]
pub enum ReadError {
    /// The received bytes do not form a valid message.
    MalformedMessage,
    /// The stream has been closed: no more data will arrive.
    EndOfStream,
    /// The underlying stream reported an I/O error.
    Error(io::Error),
}
96
// Some notes:
// The read strategy is as follows: every stream gets a fair chance to read. The most a stream can
// read in one round is the size of the shared receive buffer, or the amount of space remaining in
// the stream's own receive buffer for incomplete messages, whichever is lower. This way we ensure
// that we never get overwhelmed by an aggressive sender.
102
103impl<T: Read + Write> MessageStream<T> {
104    /// Creates a new [`Stream`] instance.
105    pub fn new(stream: T, config: StreamConfig) -> Self {
106        Self {
107            stream,
108            rx_msg_buf: Vec::new(),
109            tx_msg_buf: Vec::new(),
110            tx_queue_points: Default::default(),
111            ready: false,
112            last_write: Instant::now(),
113            config,
114        }
115    }
116
117    /// Reads data from a stream and then attempts to decode messages from the data.
118    /// Decoded messages are passed into the provided closure.
119    /// Encountering an error means that the stream must be discarded.
120    /// Attempts to read from a connection and decode messages in a fair manner.
121    ///
122    /// Returns whether there is more work available.
123    #[must_use]
124    pub fn read<M: Message, F: Fn(M, usize)>(
125        &mut self,
126        rx_buf: &mut [u8],
127        on_msg: F,
128    ) -> Result<bool, ReadError> {
129        let preexisting = !self.rx_msg_buf.is_empty();
130
131        // limit ourselves to reading some fixed number of bytes
132        let max_buf_size = self.config.rx_buf_max_size.compute::<M>();
133        let limit = (max_buf_size - self.rx_msg_buf.len()).min(rx_buf.len());
134
135        let (total_read, read_result) = {
136            let buffer = &mut rx_buf[..limit];
137            let mut total_read: usize = 0;
138
139            let result = loop {
140                match self.stream.read(&mut buffer[total_read..]) {
141                    // there is maybe more to read but our buffer was already full
142                    Ok(0) if buffer.len() == 0 => break Ok(true),
143                    // buffer was not full and we simply reached the end of stream
144                    Ok(0) => break Err(ReadError::EndOfStream),
145                    // regular nonzero read
146                    Ok(read @ 1..) => {
147                        total_read += read;
148                        if total_read == buffer.len() {
149                            // exceeded, maybe there is more but we need to move on
150                            break Ok(true);
151                        }
152                    }
153                    Err(err) if err.kind() == io::ErrorKind::WouldBlock => break Ok(false),
154                    Err(err) => break Err(ReadError::Error(err)),
155                }
156            };
157
158            (total_read, result)
159        };
160
161        // for now we always decode as much as we can so we don't care about this result
162        let _decode_has_more = if !preexisting {
163            let (consumed, result) = decode_from_buffer(&mut rx_buf[..total_read], on_msg)?;
164            if consumed < total_read {
165                self.rx_msg_buf
166                    .extend_from_slice(&rx_buf[consumed..total_read]);
167            }
168            result
169        } else {
170            self.rx_msg_buf.extend_from_slice(&rx_buf[..total_read]);
171            let (consumed, result) = decode_from_buffer(&mut &mut self.rx_msg_buf[..], on_msg)?;
172            self.rx_msg_buf.drain(..consumed);
173            result
174        };
175
176        read_result
177    }
178
179    /// Writes out as many bytes from the send buffer as possible, until blocking would start.
180    ///
181    /// Returns whether more data remains queued for writing.
182    #[must_use]
183    pub fn write(&mut self, now: Instant) -> io::Result<bool> {
184        if !self.has_queued_data() {
185            return Ok(false);
186        }
187
188        loop {
189            match self.try_write(now) {
190                Ok(written) => {
191                    let has_more = self.has_queued_data();
192                    log::trace!("wrote out {written} bytes, has more: {}", has_more);
193
194                    if !has_more {
195                        break Ok(false);
196                    }
197                }
198
199                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
200                    log::trace!("write would block");
201                    break Ok(self.has_queued_data());
202                }
203
204                Err(err) => break Err(err),
205            }
206        }
207    }
208    /// Queues a message for sending. This method simply encodes the message and places it into
209    /// the internal buffer. [`write`] must be called when the socket is writeable in order to flush.
210    ///
211    /// Returns `true` if the write buffer contains enough space to accept the message, or `false`
212    /// if the buffer is full and the message cannot be queued at this time.
213    ///
214    /// Note that if a message has a [`Message::size_hint`] implementation, the size hint is used to
215    /// determine whether the message will be accepted into the send buffer. If a message does
216    /// not have a size hint, two scenarios exist:
217    ///   - The buffer is not full -- the message is encoded and placed into the buffer even if
218    ///        that will exceed its maximum size and push it past the configured limits.
219    ///   - The buffer is full -- the message will not be encoded and queued.
220    #[must_use]
221    pub fn queue_message<M: Message>(&mut self, message: &M) -> bool {
222        let size_hint = message.size_hint().unwrap_or_default();
223        if size_hint + self.tx_msg_buf.len() >= self.config.tx_buf_max_size.compute::<M>() {
224            false
225        } else {
226            let encoded = message.encode(&mut self.tx_msg_buf);
227            self.tx_queue_points.append(encoded);
228            true
229        }
230    }
231
232    /// Returns whether the stream is stale on the write side, i.e. the data is not leaving the
233    /// send buffer in a timely manner.
234    pub fn is_write_stale(&self, now: Instant) -> bool {
235        self.tx_queue_points.first().is_some_and(|t| {
236            let timeout = self.config.tx_timeout;
237            (now - t > self.config.tx_timeout) && (now - self.last_write > timeout)
238        })
239    }
240
241    /// Resizes and shrinks the capacity of internal send and receive buffers by 1/4, until the
242    /// floor is reached. This helps maintain memory usage at sane levels since keeping permanent
243    /// large receive buffers (e.g. after receiving a large message) would eventually exhaust
244    /// available memory on less powerful devices when managing many peers.
245    pub fn shrink_buffers(&mut self) {
246        fn shrink(v: &mut Vec<u8>, min: usize) {
247            if v.capacity() > min {
248                let shrink_to = 3 * (v.capacity() / 4);
249                v.shrink_to(min.max(shrink_to));
250            }
251        }
252        shrink(&mut self.rx_msg_buf, self.config.rx_buf_min_size);
253        shrink(&mut self.tx_msg_buf, self.config.tx_buf_min_size);
254        self.tx_queue_points.shrink();
255    }
256
257    /// Takes some bytes from the local send buffer and sends them. Removes successfully sent bytes
258    /// from the buffer. Returns the number of bytes sent.
259    fn try_write(&mut self, now: Instant) -> io::Result<usize> {
260        let written = self.stream.write(&self.tx_msg_buf)?;
261        self.tx_msg_buf.drain(..written);
262        self.stream.flush()?;
263        self.last_write = now;
264        self.tx_queue_points.mark_write(written);
265        Ok(written)
266    }
267
268    /// Returns whether the send buffer has more data to write.
269    #[inline(always)]
270    pub fn has_queued_data(&self) -> bool {
271        !self.tx_msg_buf.is_empty()
272    }
273}
274
275fn decode_from_buffer<M: Message, F: Fn(M, usize)>(
276    buffer: &mut [u8],
277    on_msg: F,
278) -> Result<(usize, bool), ReadError> {
279    let mut cursor: usize = 0;
280    loop {
281        match M::decode(&buffer[cursor..]) {
282            Ok((message, consumed)) => {
283                cursor += consumed;
284                on_msg(message, consumed);
285            }
286            Err(DecodeError::NotEnoughData) => {
287                break Ok((cursor, false)); // not ready in the next round as far as we know
288            }
289            Err(DecodeError::MalformedMessage) => {
290                break Err(ReadError::MalformedMessage);
291            }
292        }
293    }
294}
295
296impl MessageStream<mio::net::TcpStream> {
297    /// Returns `true` if the underlying stream is ready. Otherwise it tests readiness and
298    /// caches the result.
299    pub fn is_ready(&mut self) -> bool {
300        if !self.ready {
301            self.ready = self.stream.peer_addr().is_ok();
302        }
303        self.ready
304    }
305    /// Shuts down the underlying stream.
306    pub fn shutdown(self) -> io::Result<()> {
307        self.stream.shutdown(std::net::Shutdown::Both)
308    }
309
310    pub fn take_error(&self) -> Option<io::Error> {
311        self.stream.take_error().ok().flatten()
312    }
313
314    /// Returns the underlying stream as a Mio event source.
315    pub fn as_source(&mut self) -> &mut impl mio::event::Source {
316        &mut self.stream
317    }
318}
319
/// Provides a collection that tracks points in time where a message of certain size was queued.
/// This allows the consumer to track how long ago a message was attempted to be sent out and how
/// many bytes are yet to be sent.
mod queue_points {
    use std::collections::VecDeque;
    use std::time::Instant;

    /// A single queue point given a point in time and the remaining number of bytes.
    #[derive(Debug)]
    struct Point {
        /// When the message was queued.
        time: Instant,
        /// How many bytes of the message are still waiting to be written.
        left: usize,
    }

    /// A list of queue points, from oldest to newest.
    #[derive(Debug, Default)]
    pub struct Queue(VecDeque<Point>);

    impl Queue {
        /// Handles a write event: accounts `n_written` bytes against the queue points, oldest
        /// first, and removes every point that has been written out completely.
        ///
        /// # Panics
        ///
        /// Panics if `n_written` exceeds the total number of bytes tracked by the queue.
        pub fn mark_write(&mut self, n_written: usize) {
            let mut unaccounted = n_written;

            while unaccounted > 0 {
                let Some(oldest) = self.0.front_mut() else {
                    break;
                };

                if unaccounted < oldest.left {
                    // the oldest message was only partially written
                    oldest.left -= unaccounted;
                    unaccounted = 0;
                } else {
                    // the oldest message is fully out, carry the rest over to the next one
                    unaccounted -= oldest.left;
                    self.0.pop_front();
                }
            }

            assert_eq!(unaccounted, 0);
        }

        /// Appends a new queue point of a certain size, stamped with the current instant.
        ///
        /// A size of zero is ignored: such a point has no bytes that a write could complete, so
        /// it would stay at the front of the queue and make the stream look stalled.
        pub fn append(&mut self, size: usize) {
            if size == 0 {
                return;
            }

            self.0.push_back(Point {
                time: Instant::now(),
                left: size,
            })
        }

        /// Returns the creation instant of the first queue point, if any.
        pub fn first(&self) -> Option<Instant> {
            self.0.front().map(|point| point.time)
        }

        /// Shrinks the capacity of the queue by 1/4, floored at 8.
        pub fn shrink(&mut self) {
            let capacity = self.0.capacity();
            if capacity > 8 {
                self.0.shrink_to(8.max(3 * (capacity / 4)));
            }
        }
    }

    #[cfg(test)]
    #[test]
    fn queue_behavior() {
        let mut queue = Queue::default();

        queue.append(10);
        queue.append(20);
        queue.append(30);

        assert_eq!(queue.0[0].left, 10);
        assert_eq!(queue.0[1].left, 20);
        assert_eq!(queue.0[2].left, 30);

        queue.mark_write(5);
        assert_eq!(queue.0[0].left, 5);
        assert_eq!(queue.0[1].left, 20);
        assert_eq!(queue.0[2].left, 30);

        queue.mark_write(5);
        assert_eq!(queue.0[0].left, 20);
        assert_eq!(queue.0[1].left, 30);

        queue.mark_write(25);
        assert_eq!(queue.0[0].left, 25);
        assert_eq!(queue.0.len(), 1);

        queue.mark_write(25);
        assert!(queue.first().is_none());
    }

    #[cfg(test)]
    #[test]
    fn zero_sized_points_are_ignored() {
        let mut queue = Queue::default();

        queue.append(0);
        assert!(queue.first().is_none());

        queue.append(4);
        queue.append(0);
        assert_eq!(queue.0.len(), 1);

        queue.mark_write(0);
        assert_eq!(queue.0[0].left, 4);

        queue.mark_write(4);
        assert!(queue.first().is_none());
    }
}
413
#[cfg(test)]
mod test {
    use std::cell::RefCell;
    use std::io::Cursor;

    use super::*;

    /// Minimal fixed-size test message: a `u64` encoded as 8 little-endian bytes.
    #[derive(Debug, Eq, PartialEq)]
    struct Ping(u64);

    impl Message for Ping {
        const MAX_SIZE: usize = 8;

        fn encode(&self, dest: &mut impl std::io::Write) -> usize {
            let bytes = self.0.to_le_bytes();
            dest.write(&bytes).unwrap()
        }

        fn decode(buffer: &[u8]) -> Result<(Self, usize), DecodeError> {
            match buffer.get(..8) {
                Some(raw) => Ok((Ping(u64::from_le_bytes(raw.try_into().unwrap())), 8)),
                None => Err(DecodeError::NotEnoughData),
            }
        }
    }

    /// Builds a message callback that checks the reported size and records the message.
    fn collect_into(sink: &RefCell<Vec<Ping>>) -> impl Fn(Ping, usize) + '_ {
        move |message, size| {
            assert_eq!(size, 8);
            sink.borrow_mut().push(message);
        }
    }

    /// Two whole messages are available up front; each read round delivers exactly one of them
    /// and the third round reports the end of the stream.
    #[test]
    fn reassemble_message_whole_reads() {
        let mut buf = [0; 1024];
        let mut cursor = Cursor::new(Vec::new());

        Ping(0).encode(&mut cursor);
        Ping(1).encode(&mut cursor);
        cursor.set_position(0);

        let mut conn = MessageStream::new(&mut cursor, StreamConfig::default());
        let received: RefCell<Vec<Ping>> = Default::default();

        conn.read::<Ping, _>(&mut buf, collect_into(&received))
            .unwrap();
        assert_eq!(received.borrow()[0], Ping(0));

        conn.read::<Ping, _>(&mut buf, collect_into(&received))
            .unwrap();
        assert_eq!(received.borrow()[1], Ping(1));

        let err = conn.read::<Ping, _>(&mut buf, collect_into(&received));
        assert!(matches!(err, Err(ReadError::EndOfStream)));
        assert_eq!(conn.stream.position(), 16);
        assert!(conn.rx_msg_buf.is_empty());
    }

    /// A message arriving in pieces is held back until it is complete.
    #[test]
    fn reassemble_message_partial_reads() {
        let mut buf = [0; 8];
        let mut cursor = Cursor::new(Vec::new());
        let mut conn = MessageStream::new(&mut cursor, StreamConfig::default());

        let mut serialized = Vec::new();
        Ping(u64::MAX - 1).encode(&mut serialized);
        Ping(u64::MAX).encode(&mut serialized);
        let (head, tail) = serialized.split_at(4);

        let received: RefCell<Vec<Ping>> = Default::default();

        // only half of the first message is on the wire
        conn.stream.get_mut().extend_from_slice(head);
        let _ = conn.read::<Ping, _>(&mut buf, collect_into(&received));
        assert!(received.borrow().is_empty());
        assert_eq!(conn.rx_msg_buf.len(), 4);

        // the rest arrives: the first message completes in this round ...
        conn.stream.get_mut().extend_from_slice(tail);
        let _ = conn.read::<Ping, _>(&mut buf, collect_into(&received));
        assert_eq!(received.borrow()[0], Ping(u64::MAX - 1));

        // ... and the second one in the next
        let _ = conn.read::<Ping, _>(&mut buf, collect_into(&received));
        assert_eq!(received.borrow()[1], Ping(u64::MAX));
    }

    /// Queued messages reach the wire unchanged and in order.
    #[test]
    fn send_message() {
        let mut wire = Cursor::new(Vec::<u8>::new());
        let config = StreamConfig {
            tx_buf_max_size: MaxMessageSizeMultiple(3.try_into().unwrap()),
            ..Default::default()
        };
        let mut connection = MessageStream::new(&mut wire, config);

        for value in 0..3 {
            assert!(connection.queue_message(&Ping(value)));
        }

        let expected = connection.tx_msg_buf.clone();
        connection.write(Instant::now()).unwrap();
        assert_eq!(wire.position(), 24);
        assert_eq!(wire.into_inner(), expected);
    }

    /// A full send buffer rejects further messages but still flushes what it holds.
    #[test]
    fn send_message_buf_full() {
        let mut wire = Cursor::new(Vec::<u8>::new());
        let config = StreamConfig {
            tx_buf_min_size: 1,
            tx_buf_max_size: MaxMessageSizeMultiple(1.try_into().unwrap()),
            ..Default::default()
        };
        let mut connection = MessageStream::new(&mut wire, config);

        assert!(connection.queue_message(&Ping(0)));
        assert!(!connection.queue_message(&Ping(1)));

        let expected = connection.tx_msg_buf.clone();
        let expected_len = expected.len();
        connection.write(Instant::now()).unwrap();
        assert_eq!(wire.position(), expected_len as u64);
        assert_eq!(wire.into_inner(), expected);
    }
}