// peerlink/message_stream.rs

1use std::io::{self, Read, Write};
2use std::time::Instant;
3
4use crate::{DecodeError, Message};
5
/// Buffering and timeout parameters used by a message stream.
#[derive(Clone, Copy, Debug)]
pub struct StreamConfig {
    /// Lower bound for the capacity of the buffer that reassembles partially received messages.
    /// Small values trade memory for more frequent reallocation; large values reallocate less
    /// often but keep more memory around.
    pub rx_buf_min_size: usize,
    /// Lower bound for the capacity of the buffer holding serialized outbound data. Small values
    /// trade memory for more frequent reallocation; large values reallocate less often but keep
    /// more memory around.
    pub tx_buf_min_size: usize,
    /// Upper bound on the send buffer. While the buffer is full, new messages cannot be queued
    /// until some of its contents have been written out. The buffer fills up when messages are
    /// produced faster than the remote peer reads them.
    pub tx_buf_max_size: usize,
    /// How long a peer may fail to read the data sent to it before it is disconnected.
    pub stream_write_timeout: std::time::Duration,
    /// How long a connect attempt may take before it is abandoned. Only non-blocking connect
    /// attempts honour this; blocking attempts made by custom connectors ignore it.
    pub stream_connect_timeout: std::time::Duration,
}
27
28impl Default for StreamConfig {
29    fn default() -> Self {
30        Self {
31            rx_buf_min_size: 32 * 1024,
32            tx_buf_min_size: 32 * 1024,
33            tx_buf_max_size: 1024 * 1024,
34            stream_write_timeout: std::time::Duration::from_secs(30),
35            stream_connect_timeout: std::time::Duration::from_secs(5),
36        }
37    }
38}
39
40/// Wraps read and write parts of a peer connection and does message buffering and assembly
41/// on the read side and message serialization and flushing on the write side.
42#[derive(Debug)]
43pub struct MessageStream<T: Read + Write> {
44    /// Configuration parameters for the stream.
45    config: StreamConfig,
46    /// The read+write stream underlying the connection.
47    stream: T,
48    /// Buffer used for message reconstruction.
49    rx_msg_buf: Vec<u8>,
50    /// Buffer used for sending.
51    tx_msg_buf: Vec<u8>,
52    /// The list of queue points for outgoing messages.
53    tx_queue_points: queue_points::Queue,
54    /// Cached readyness.
55    ready: bool,
56    /// Last successful write time.
57    last_write: Instant,
58}
59
/// Reasons a [`MessageStream::read`] call can fail. Any of them means the stream should be
/// discarded.
#[derive(Debug)]
pub enum ReadError {
    /// A malformed message was received.
    MalformedMessage,
    /// End of stream has been reached (closed stream).
    EndOfStream,
    /// The stream produced an I/O error.
    Error(io::Error),
}

impl std::fmt::Display for ReadError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MalformedMessage => f.write_str("malformed message received"),
            Self::EndOfStream => f.write_str("end of stream reached"),
            Self::Error(err) => write!(f, "I/O error: {err}"),
        }
    }
}

impl std::error::Error for ReadError {
    /// Exposes the wrapped I/O error so error chains can be walked.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Error(err) => Some(err),
            Self::MalformedMessage | Self::EndOfStream => None,
        }
    }
}
69
70impl<T: Read + Write> MessageStream<T> {
71    pub fn new(stream: T, config: StreamConfig) -> Self {
72        Self {
73            stream,
74            rx_msg_buf: Vec::new(),
75            tx_msg_buf: Vec::with_capacity(config.tx_buf_min_size),
76            tx_queue_points: Default::default(),
77            ready: false,
78            last_write: Instant::now(),
79            config,
80        }
81    }
82
83    /// Receives as many messages as possible, either until reads would start blocking, or until
84    /// an error is encountered. Encountering an error means that the stream should be discarded.
85    pub fn read<M: Message, F: Fn(M)>(
86        &mut self,
87        rx_buf: &mut [u8],
88        on_msg: F,
89    ) -> Result<(), ReadError> {
90        'read: loop {
91            match self.stream.read(rx_buf).map(|read| &rx_buf[..read]) {
92                Ok(&[]) => break 'read Err(ReadError::EndOfStream),
93
94                Ok(received) => {
95                    if !self.rx_msg_buf.is_empty() {
96                        self.rx_msg_buf.extend_from_slice(received);
97                        'decode: loop {
98                            if !self.rx_msg_buf.is_empty() {
99                                match M::decode(&self.rx_msg_buf) {
100                                    Ok((message, consumed)) => {
101                                        self.rx_msg_buf.drain(..consumed);
102                                        on_msg(message);
103                                    }
104                                    Err(DecodeError::NotEnoughData) => break 'decode,
105                                    Err(DecodeError::MalformedMessage) => {
106                                        break 'read Err(ReadError::MalformedMessage)
107                                    }
108                                }
109                            } else {
110                                break 'decode;
111                            }
112                        }
113                    } else {
114                        let mut next_from = 0;
115                        'decode: loop {
116                            let next = &received[next_from..];
117                            if !next.is_empty() {
118                                match M::decode(next) {
119                                    Ok((message, consumed)) => {
120                                        on_msg(message);
121                                        next_from += consumed;
122                                    }
123                                    Err(DecodeError::NotEnoughData) => {
124                                        if self.rx_msg_buf.capacity() == 0 {
125                                            self.rx_msg_buf
126                                                .reserve_exact(self.config.rx_buf_min_size);
127                                        }
128                                        self.rx_msg_buf.extend_from_slice(next);
129                                        break 'decode;
130                                    }
131                                    Err(DecodeError::MalformedMessage) => {
132                                        break 'read Err(ReadError::MalformedMessage);
133                                    }
134                                }
135                            } else {
136                                break 'decode;
137                            }
138                        }
139                    }
140                }
141
142                Err(err) if err.kind() == io::ErrorKind::WouldBlock => break 'read Ok(()),
143
144                Err(err) => break 'read Err(ReadError::Error(err)),
145            }
146        }
147    }
148
149    /// Writes out as many bytes from the send buffer as possible, until blocking would start.
150    pub fn write(&mut self, now: Instant) -> io::Result<()> {
151        if !self.has_queued_data() {
152            return Ok(());
153        }
154
155        loop {
156            match self.attempt_write(now) {
157                Ok(written) => {
158                    let has_more = self.has_queued_data();
159                    log::trace!("wrote out {written} bytes, has more: {}", has_more);
160
161                    if !has_more {
162                        break Ok(());
163                    }
164                }
165
166                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
167                    log::trace!("write would block");
168                    break Ok(());
169                }
170
171                Err(err) => break Err(err),
172            }
173        }
174    }
175    /// Queues a message for sending. This method simply serializes the message and places it into
176    /// the internal buffer. `write` must be called when the socket is writeable in order to flush.
177    ///
178    /// Returns `true` if the write buffer contains enough space to accept the message, or `false`
179    /// if the buffer is full and the message cannot be queued at this time.
180    ///
181    /// Note: this will fail only if the buffer is full prior to even attempting to queue. A buffer
182    /// that is close to full will not reject a message, even if queueing might exceed the
183    /// configured limits.
184    #[must_use]
185    pub fn queue_message<M: Message>(&mut self, message: &M) -> bool {
186        if self.tx_msg_buf.len() <= self.config.tx_buf_max_size {
187            let encoded = message.encode(&mut self.tx_msg_buf);
188            self.tx_queue_points.append(encoded);
189            true
190        } else {
191            false
192        }
193    }
194
195    /// Returns whether the stream is stale on the write side, i.e. the data is not leaving the
196    /// send buffer in a timely manner.
197    pub fn is_write_stale(&self, now: Instant) -> bool {
198        match self.tx_queue_points.first() {
199            Some(t) => {
200                let timeout = self.config.stream_write_timeout;
201                (now - t > timeout) && (now - self.last_write > timeout)
202            }
203            None => false,
204        }
205    }
206
207    /// Resizes and shrinks the capacity of internal send and receive buffers by 1/3, until the
208    /// floor is reached. This helps maintain memory usage at sane levels since keeping permanent
209    /// large receive buffers (e.g. after receiving a large message) would eventually exhaust
210    /// available memory on less powerful devices when managing many peers.
211    pub fn shrink_buffers(&mut self) {
212        fn shrink(v: &mut Vec<u8>, min: usize) {
213            if v.capacity() > min {
214                let shrink_to = (2 * v.capacity()) / 3;
215                v.shrink_to(min.max(shrink_to));
216            }
217        }
218        shrink(&mut self.rx_msg_buf, self.config.rx_buf_min_size);
219        shrink(&mut self.tx_msg_buf, self.config.tx_buf_min_size);
220        self.tx_queue_points.shrink();
221    }
222
223    /// Determines the interest set wanted by the connection.
224    pub fn interest(&self) -> mio::Interest {
225        if self.has_queued_data() {
226            mio::Interest::READABLE | mio::Interest::WRITABLE
227        } else {
228            mio::Interest::READABLE
229        }
230    }
231
232    /// Takes some bytes from the local send buffer and sends them. Removes successfully sent bytes
233    /// from the buffer. Returns the number of bytes sent.
234    fn attempt_write(&mut self, now: Instant) -> io::Result<usize> {
235        let written = self.stream.write(&self.tx_msg_buf)?;
236        self.tx_msg_buf.drain(..written);
237        self.stream.flush()?;
238        self.last_write = now;
239        self.tx_queue_points.handle_write(written);
240        Ok(written)
241    }
242
243    /// Returns whether the send buffer has more data to write.
244    #[inline(always)]
245    fn has_queued_data(&self) -> bool {
246        !self.tx_msg_buf.is_empty()
247    }
248}
249
250impl MessageStream<mio::net::TcpStream> {
251    /// Returns `true` if the underlying stream is ready. Otherwise it tests readyness and
252    /// caches the result.
253    pub fn is_ready(&mut self) -> bool {
254        if !self.ready {
255            self.ready = self.stream.peer_addr().is_ok();
256        }
257        self.ready
258    }
259    /// Shuts down the underlying stream.
260    pub fn shutdown(self) -> io::Result<()> {
261        self.stream.shutdown(std::net::Shutdown::Both)
262    }
263
264    /// Returns the underlying stream as a Mio event source.
265    pub fn as_source(&mut self) -> &mut impl mio::event::Source {
266        &mut self.stream
267    }
268}
269
/// Provides a collection that tracks points in time where a message of certain size was queued.
/// This allows the consumer to track how long ago a message was attempted to be sent out and how
/// many bytes are yet to be sent.
mod queue_points {
    use std::collections::VecDeque;
    use std::time::Instant;

    /// A single queue point: when a message was queued and how many of its bytes are unsent.
    #[derive(Debug)]
    struct Point {
        time: Instant,
        left: usize,
    }

    /// Queue points ordered from oldest (front) to newest (back).
    ///
    /// A `VecDeque` makes removing fully sent points from the front cost O(1) each, instead of
    /// shifting every remaining point as `Vec::drain(..n)` does.
    #[derive(Debug, Default)]
    pub struct Queue(VecDeque<Point>);

    impl Queue {
        /// Accounts for `n_written` bytes having been sent, consuming the oldest points first.
        ///
        /// Every point at the front whose bytes have all been sent is removed. This includes
        /// zero-length points, so they cannot linger and report a stale queue time.
        ///
        /// # Panics
        ///
        /// Panics if `n_written` exceeds the number of bytes still tracked. That would mean the
        /// queue and the send buffer it mirrors have gone out of sync.
        pub fn handle_write(&mut self, n_written: usize) {
            let mut n_bytes_left = n_written;

            while let Some(point) = self.0.front_mut() {
                let taken = n_bytes_left.min(point.left);
                point.left -= taken;
                n_bytes_left -= taken;

                if point.left > 0 {
                    // The oldest point is only partially sent; everything behind it is untouched.
                    break;
                }
                self.0.pop_front();
            }

            assert_eq!(n_bytes_left, 0, "more bytes written than were queued");
        }

        /// Appends a new queue point of a certain size, stamped with the current time.
        pub fn append(&mut self, size: usize) {
            self.0.push_back(Point {
                time: Instant::now(),
                left: size,
            });
        }

        /// Returns the creation instant of the oldest queue point, if any.
        pub fn first(&self) -> Option<Instant> {
            self.0.front().map(|p| p.time)
        }

        /// Shrinks the capacity of the queue by 1/3, floored at 8 or the current length.
        pub fn shrink(&mut self) {
            self.0.shrink_to(8.max((2 * self.0.capacity()) / 3));
        }
    }

    #[cfg(test)]
    #[test]
    fn queue_behavior() {
        let mut queue = Queue::default();

        queue.append(10);
        queue.append(20);
        queue.append(30);

        assert_eq!(queue.0[0].left, 10);
        assert_eq!(queue.0[1].left, 20);
        assert_eq!(queue.0[2].left, 30);

        queue.handle_write(5);
        assert_eq!(queue.0[0].left, 5);
        assert_eq!(queue.0[1].left, 20);
        assert_eq!(queue.0[2].left, 30);

        queue.handle_write(5);
        assert_eq!(queue.0[0].left, 20);
        assert_eq!(queue.0[1].left, 30);

        queue.handle_write(25);
        assert_eq!(queue.0[0].left, 25);
        assert_eq!(queue.0.len(), 1);

        queue.handle_write(25);
        assert!(queue.first().is_none());

        // A zero-length point behind a completed one must not be left behind.
        queue.append(4);
        queue.append(0);
        queue.handle_write(4);
        assert!(queue.first().is_none());
    }
}
360
#[cfg(test)]
mod test {
    use std::cell::RefCell;
    use std::io::Cursor;

    use super::*;

    /// Fixed-size test message carrying a little-endian `u64`.
    #[derive(Debug, Eq, PartialEq)]
    struct Ping(u64);

    impl Message for Ping {
        fn encode(&self, dest: &mut impl std::io::Write) -> usize {
            let bytes = self.0.to_le_bytes();
            // `write` may legally perform a partial write; `write_all` guarantees the whole
            // encoding reaches `dest`, so returning `bytes.len()` is accurate.
            dest.write_all(&bytes).unwrap();
            bytes.len()
        }

        fn decode(buffer: &[u8]) -> Result<(Self, usize), DecodeError> {
            if buffer.len() >= 8 {
                Ok((Ping(u64::from_le_bytes(buffer[..8].try_into().unwrap())), 8))
            } else {
                Err(DecodeError::NotEnoughData)
            }
        }
    }

    #[test]
    fn reassemble_message_whole_reads() {
        let mut buf = [0; 1024];
        let mut cursor = Cursor::new(Vec::new());

        Ping(0).encode(&mut cursor);
        Ping(1).encode(&mut cursor);
        cursor.set_position(0);

        let mut conn = MessageStream::new(&mut cursor, StreamConfig::default());

        let received: RefCell<Vec<Ping>> = Default::default();
        let err = conn.read(&mut buf, |message| {
            received.borrow_mut().push(message);
        });

        assert_eq!(received.borrow()[0], Ping(0));
        assert_eq!(received.borrow()[1], Ping(1));
        assert!(matches!(err, Err(ReadError::EndOfStream)));
        assert_eq!(conn.stream.position(), 16);
        assert!(conn.rx_msg_buf.is_empty());
    }

    #[test]
    fn reassemble_message_partial_reads() {
        let mut buf = [0; 1024];
        let mut cursor = Cursor::new(Vec::new());
        let mut conn = MessageStream::new(&mut cursor, StreamConfig::default());
        let mut serialized = Vec::new();
        Ping(u64::MAX - 1).encode(&mut serialized);
        Ping(u64::MAX).encode(&mut serialized);

        let received: RefCell<Vec<Ping>> = Default::default();

        conn.stream.get_mut().extend_from_slice(&serialized[..4]);
        let _ = conn.read(&mut buf, |message| {
            received.borrow_mut().push(message);
        });
        assert!(received.borrow().is_empty());
        assert_eq!(conn.rx_msg_buf.len(), 4);

        conn.stream.get_mut().extend_from_slice(&serialized[4..]);
        let _ = conn.read(&mut buf, |message| {
            received.borrow_mut().push(message);
        });
        assert_eq!(received.borrow()[0], Ping(u64::MAX - 1));
        assert_eq!(received.borrow()[1], Ping(u64::MAX));
    }

    #[test]
    fn reassemble_many_messages_after_partial_read() {
        let mut buf = [0; 1024];
        let mut cursor = Cursor::new(Vec::new());
        let mut conn = MessageStream::new(&mut cursor, StreamConfig::default());
        let mut serialized = Vec::new();
        for i in 0..4 {
            Ping(i).encode(&mut serialized);
        }

        let received: RefCell<Vec<Ping>> = Default::default();

        // The first read leaves a partial message in the reassembly buffer.
        conn.stream.get_mut().extend_from_slice(&serialized[..3]);
        let _ = conn.read(&mut buf, |message| received.borrow_mut().push(message));
        assert!(received.borrow().is_empty());
        assert_eq!(conn.rx_msg_buf.len(), 3);

        // The second read completes it and carries three more whole messages.
        conn.stream.get_mut().extend_from_slice(&serialized[3..]);
        let _ = conn.read(&mut buf, |message| received.borrow_mut().push(message));
        assert_eq!(*received.borrow(), vec![Ping(0), Ping(1), Ping(2), Ping(3)]);
        assert!(conn.rx_msg_buf.is_empty());
    }

    #[test]
    fn send_message() {
        let mut wire = Cursor::new(Vec::<u8>::new());
        let mut connection = MessageStream::new(&mut wire, StreamConfig::default());

        assert!(connection.queue_message(&Ping(0)));
        assert!(connection.queue_message(&Ping(1)));
        assert!(connection.queue_message(&Ping(2)));

        let cloned_buffer = connection.tx_msg_buf.clone();
        connection.write(Instant::now()).unwrap();
        assert_eq!(wire.position(), 24);
        assert_eq!(wire.into_inner(), cloned_buffer);
    }

    #[test]
    fn send_message_buf_full() {
        let mut wire = Cursor::new(Vec::<u8>::new());
        let config = StreamConfig {
            tx_buf_min_size: 1,
            tx_buf_max_size: 7,
            ..Default::default()
        };
        let mut connection = MessageStream::new(&mut wire, config);

        assert!(connection.queue_message(&Ping(0)));
        assert!(!connection.queue_message(&Ping(1)));

        let buffer_len = connection.tx_msg_buf.len();
        let cloned_buffer = connection.tx_msg_buf.clone();
        connection.write(Instant::now()).unwrap();
        assert_eq!(wire.position(), buffer_len as u64);
        assert_eq!(wire.into_inner(), cloned_buffer);
    }
}