Skip to main content

remote/
streams.rs

1use bytes::Buf;
2use futures::SinkExt;
3use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
5use tokio::net::TcpStream;
6use tracing::instrument;
7
8/// Framed send stream for length-delimited messages.
9///
10/// Generic over the underlying writer type - works with TCP, TLS, or any AsyncWrite.
11#[derive(Debug)]
12pub struct SendStream<W = OwnedWriteHalf> {
13    framed: tokio_util::codec::FramedWrite<W, tokio_util::codec::LengthDelimitedCodec>,
14}
15
16impl<W: AsyncWrite + Unpin> SendStream<W> {
17    pub fn new(stream: W) -> Self {
18        let framed = tokio_util::codec::FramedWrite::new(
19            stream,
20            tokio_util::codec::LengthDelimitedCodec::new(),
21        );
22        Self { framed }
23    }
24
25    pub async fn send_batch_message<T: serde::Serialize>(&mut self, obj: &T) -> anyhow::Result<()> {
26        let bytes = bitcode::serialize(obj).map_err(anyhow::Error::from)?;
27        self.framed.send(bytes::Bytes::from(bytes)).await?;
28        Ok(())
29    }
30
31    pub async fn send_control_message<T: serde::Serialize>(
32        &mut self,
33        obj: &T,
34    ) -> anyhow::Result<()> {
35        self.send_batch_message(obj).await?;
36        self.framed.flush().await?;
37        Ok(())
38    }
39
40    /// Sends an object followed by data from a buffered reader.
41    ///
42    /// This method uses `copy_buf` which avoids internal buffer allocation by using
43    /// the reader's existing buffer. Wrap your reader in `BufReader::with_capacity(size, reader)`
44    /// to control the buffer size.
45    #[instrument(level = "trace", skip(self, obj, reader))]
46    pub async fn send_message_with_data_buffered<T: serde::Serialize, R: AsyncBufRead + Unpin>(
47        &mut self,
48        obj: &T,
49        reader: &mut R,
50    ) -> anyhow::Result<u64> {
51        self.send_batch_message(obj).await?;
52        let data_stream = self.framed.get_mut();
53        let bytes_copied = tokio::io::copy_buf(reader, data_stream).await?;
54        Ok(bytes_copied)
55    }
56
57    pub async fn close(&mut self) -> anyhow::Result<()> {
58        self.framed.close().await?;
59        Ok(())
60    }
61}
62
63pub type SharedSendStream<W = OwnedWriteHalf> = std::sync::Arc<tokio::sync::Mutex<SendStream<W>>>;
64
65/// Type alias for boxed write stream (supports both TLS and plain TCP)
66pub type BoxedWrite = Box<dyn AsyncWrite + Unpin + Send>;
67/// Type alias for boxed read stream (supports both TLS and plain TCP)
68pub type BoxedRead = Box<dyn AsyncRead + Unpin + Send>;
69/// Send stream over boxed writer
70pub type BoxedSendStream = SendStream<BoxedWrite>;
71/// Recv stream over boxed reader
72pub type BoxedRecvStream = RecvStream<BoxedRead>;
73/// Shared send stream over boxed writer
74pub type BoxedSharedSendStream = SharedSendStream<BoxedWrite>;
75
76/// Framed receive stream for length-delimited messages.
77///
78/// Generic over the underlying reader type - works with TCP, TLS, or any AsyncRead.
79#[derive(Debug)]
80pub struct RecvStream<R = OwnedReadHalf> {
81    framed: tokio_util::codec::FramedRead<R, tokio_util::codec::LengthDelimitedCodec>,
82}
83
84impl<R: AsyncRead + Unpin> RecvStream<R> {
85    pub fn new(stream: R) -> Self {
86        let framed = tokio_util::codec::FramedRead::new(
87            stream,
88            tokio_util::codec::LengthDelimitedCodec::new(),
89        );
90        Self { framed }
91    }
92
93    pub async fn recv_object<T: serde::de::DeserializeOwned>(
94        &mut self,
95    ) -> anyhow::Result<Option<T>> {
96        if let Some(frame) = futures::StreamExt::next(&mut self.framed).await {
97            let bytes = frame?;
98            let obj = bitcode::deserialize(&bytes).map_err(anyhow::Error::from)?;
99            Ok(Some(obj))
100        } else {
101            Ok(None)
102        }
103    }
104
105    /// Copies data to a writer using the default buffer size (8 KiB).
106    ///
107    /// For better performance with large files, use [`Self::copy_to_buffered`] instead.
108    #[instrument(level = "trace", skip(self, writer))]
109    pub async fn copy_to<W: tokio::io::AsyncWrite + Unpin>(
110        &mut self,
111        writer: &mut W,
112    ) -> anyhow::Result<u64> {
113        let read_buffer = self.framed.read_buffer();
114        let buffer_size = read_buffer.len() as u64;
115        writer.write_all(read_buffer).await?;
116        let data_stream = self.framed.get_mut();
117        let stream_bytes = tokio::io::copy(data_stream, writer).await?;
118        Ok(buffer_size + stream_bytes)
119    }
120
121    /// Copies data to a writer using a custom buffer size.
122    ///
123    /// Uses a buffered reader around the TCP stream with the specified capacity.
124    /// This avoids the default 8 KiB buffer in `tokio::io::copy` and can significantly
125    /// improve throughput on high-bandwidth networks.
126    #[instrument(level = "trace", skip(self, writer))]
127    pub async fn copy_to_buffered<W: tokio::io::AsyncWrite + Unpin>(
128        &mut self,
129        writer: &mut W,
130        buffer_size: usize,
131    ) -> anyhow::Result<u64> {
132        let read_buffer = self.framed.read_buffer();
133        let buffered_bytes = read_buffer.len() as u64;
134        writer.write_all(read_buffer).await?;
135        let data_stream = self.framed.get_mut();
136        // wrap the TCP recv stream in a BufReader to control the buffer size
137        let mut buffered_stream = tokio::io::BufReader::with_capacity(buffer_size, data_stream);
138        let stream_bytes = tokio::io::copy_buf(&mut buffered_stream, writer).await?;
139        Ok(buffered_bytes + stream_bytes)
140    }
141
142    /// Copies exactly `size` bytes to a writer using a custom buffer size.
143    ///
144    /// Unlike [`Self::copy_to_buffered`], this does NOT read until EOF. It reads
145    /// exactly the specified number of bytes, leaving the stream open for
146    /// reading subsequent messages.
147    #[instrument(level = "trace", skip(self, writer))]
148    pub async fn copy_exact_to_buffered<W: tokio::io::AsyncWrite + Unpin>(
149        &mut self,
150        writer: &mut W,
151        size: u64,
152        buffer_size: usize,
153    ) -> anyhow::Result<u64> {
154        if size == 0 {
155            return Ok(0);
156        }
157        // first drain any buffered data from the framed reader
158        let read_buffer = self.framed.read_buffer_mut();
159        let buffered = (read_buffer.len() as u64).min(size);
160        if buffered > 0 {
161            writer.write_all(&read_buffer[..buffered as usize]).await?;
162            read_buffer.advance(buffered as usize);
163        }
164        let remaining = size - buffered;
165        if remaining == 0 {
166            return Ok(size);
167        }
168        // read exactly `remaining` bytes from the underlying stream
169        let data_stream = self.framed.get_mut();
170        let mut limited = data_stream.take(remaining);
171        let mut buf = vec![0u8; buffer_size.min(remaining as usize)];
172        let mut total_copied = buffered;
173        loop {
174            let bytes_to_read = buf.len().min((size - total_copied) as usize);
175            if bytes_to_read == 0 {
176                break;
177            }
178            let n = limited.read(&mut buf[..bytes_to_read]).await?;
179            if n == 0 {
180                break;
181            }
182            writer.write_all(&buf[..n]).await?;
183            total_copied += n as u64;
184        }
185        if total_copied != size {
186            anyhow::bail!(
187                "unexpected EOF: expected {} bytes, got {}",
188                size,
189                total_copied
190            );
191        }
192        Ok(size)
193    }
194
195    pub async fn close(&mut self) {
196        // for TCP, we just let the stream drop - no special cleanup needed
197    }
198}
199
200/// Connection wrapper for control channel (bidirectional TCP connection)
201#[derive(Debug)]
202pub struct ControlConnection {
203    send: SendStream,
204    recv: RecvStream,
205}
206
207impl ControlConnection {
208    /// Create a control connection from a TCP stream
209    pub fn new(stream: TcpStream) -> Self {
210        let (read_half, write_half) = stream.into_split();
211        Self {
212            send: SendStream::new(write_half),
213            recv: RecvStream::new(read_half),
214        }
215    }
216
217    /// Split into send and recv halves for independent use
218    pub fn into_split(self) -> (SharedSendStream, RecvStream) {
219        (
220            std::sync::Arc::new(tokio::sync::Mutex::new(self.send)),
221            self.recv,
222        )
223    }
224
225    /// Get mutable access to send stream
226    pub fn send_mut(&mut self) -> &mut SendStream {
227        &mut self.send
228    }
229
230    /// Get mutable access to recv stream
231    pub fn recv_mut(&mut self) -> &mut RecvStream {
232        &mut self.recv
233    }
234}