// ombrac_client/connection/stream.rs
use std::io::{self, IoSlice};
use std::pin::Pin;
use std::task::{Context, Poll};

use bytes::Bytes;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

8pub struct BufferedStream<S> {
15 stream: S,
16 buffer: Bytes,
17 buffer_pos: usize,
18}
19
20impl<S> BufferedStream<S> {
21 pub fn new(stream: S, buffer: Bytes) -> Self {
26 Self {
27 stream,
28 buffer,
29 buffer_pos: 0,
30 }
31 }
32
33 pub fn without_buffer(stream: S) -> Self {
37 Self::new(stream, Bytes::new())
38 }
39}
40
41impl<S: AsyncRead + Unpin> AsyncRead for BufferedStream<S> {
42 fn poll_read(
43 mut self: Pin<&mut Self>,
44 cx: &mut Context<'_>,
45 buf: &mut ReadBuf<'_>,
46 ) -> Poll<io::Result<()>> {
47 if self.buffer_pos < self.buffer.len() {
49 let remaining = &self.buffer[self.buffer_pos..];
50 let to_copy = remaining.len().min(buf.remaining());
51
52 if to_copy > 0 {
53 buf.put_slice(&remaining[..to_copy]);
54 self.buffer_pos += to_copy;
55 }
56
57 if self.buffer_pos >= self.buffer.len() {
59 self.buffer = Bytes::new();
60 self.buffer_pos = 0;
61 }
62
63 return Poll::Ready(Ok(()));
64 }
65
66 Pin::new(&mut self.stream).poll_read(cx, buf)
68 }
69}
70
71impl<S: AsyncWrite + Unpin> AsyncWrite for BufferedStream<S> {
72 fn poll_write(
73 mut self: Pin<&mut Self>,
74 cx: &mut Context<'_>,
75 buf: &[u8],
76 ) -> Poll<io::Result<usize>> {
77 Pin::new(&mut self.stream).poll_write(cx, buf)
78 }
79
80 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
81 Pin::new(&mut self.stream).poll_flush(cx)
82 }
83
84 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
85 Pin::new(&mut self.stream).poll_shutdown(cx)
86 }
87}