ombrac_client/connection/
stream.rs

1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use bytes::Bytes;
6use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
7
8/// A wrapper around a stream that ensures buffered data is read first.
9///
10/// This wrapper is used to ensure that any data remaining in the Framed codec's
11/// read buffer is consumed before reading from the underlying stream. This prevents
12/// data loss when transitioning from framed message-based communication to raw
13/// stream communication.
14pub struct BufferedStream<S> {
15    stream: S,
16    buffer: Bytes,
17    buffer_pos: usize,
18}
19
20impl<S> BufferedStream<S> {
21    /// Creates a new `BufferedStream` with optional initial buffer data.
22    ///
23    /// If `buffer` is provided and non-empty, it will be read first before
24    /// any data is read from the underlying `stream`.
25    pub fn new(stream: S, buffer: Bytes) -> Self {
26        Self {
27            stream,
28            buffer,
29            buffer_pos: 0,
30        }
31    }
32
33    /// Creates a new `BufferedStream` without any initial buffer.
34    ///
35    /// This is equivalent to `BufferedStream::new(stream, Bytes::new())`.
36    pub fn without_buffer(stream: S) -> Self {
37        Self::new(stream, Bytes::new())
38    }
39}
40
41impl<S: AsyncRead + Unpin> AsyncRead for BufferedStream<S> {
42    fn poll_read(
43        mut self: Pin<&mut Self>,
44        cx: &mut Context<'_>,
45        buf: &mut ReadBuf<'_>,
46    ) -> Poll<io::Result<()>> {
47        // First, try to read from the buffer if there's any remaining data
48        if self.buffer_pos < self.buffer.len() {
49            let remaining = &self.buffer[self.buffer_pos..];
50            let to_copy = remaining.len().min(buf.remaining());
51
52            if to_copy > 0 {
53                buf.put_slice(&remaining[..to_copy]);
54                self.buffer_pos += to_copy;
55            }
56
57            // If we've consumed all buffer data, we can drop it to free memory
58            if self.buffer_pos >= self.buffer.len() {
59                self.buffer = Bytes::new();
60                self.buffer_pos = 0;
61            }
62
63            return Poll::Ready(Ok(()));
64        }
65
66        // Buffer is exhausted, read from the underlying stream
67        Pin::new(&mut self.stream).poll_read(cx, buf)
68    }
69}
70
71impl<S: AsyncWrite + Unpin> AsyncWrite for BufferedStream<S> {
72    fn poll_write(
73        mut self: Pin<&mut Self>,
74        cx: &mut Context<'_>,
75        buf: &[u8],
76    ) -> Poll<io::Result<usize>> {
77        Pin::new(&mut self.stream).poll_write(cx, buf)
78    }
79
80    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
81        Pin::new(&mut self.stream).poll_flush(cx)
82    }
83
84    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
85        Pin::new(&mut self.stream).poll_shutdown(cx)
86    }
87}