1use bytes::Buf;
2use futures::SinkExt;
3use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
5use tokio::net::TcpStream;
6use tracing::instrument;
7
8#[derive(Debug)]
12pub struct SendStream<W = OwnedWriteHalf> {
13 framed: tokio_util::codec::FramedWrite<W, tokio_util::codec::LengthDelimitedCodec>,
14}
15
16impl<W: AsyncWrite + Unpin> SendStream<W> {
17 pub fn new(stream: W) -> Self {
18 let framed = tokio_util::codec::FramedWrite::new(
19 stream,
20 tokio_util::codec::LengthDelimitedCodec::new(),
21 );
22 Self { framed }
23 }
24
25 pub async fn send_batch_message<T: serde::Serialize>(&mut self, obj: &T) -> anyhow::Result<()> {
26 let bytes = bitcode::serialize(obj).map_err(anyhow::Error::from)?;
27 self.framed.send(bytes::Bytes::from(bytes)).await?;
28 Ok(())
29 }
30
31 pub async fn send_control_message<T: serde::Serialize>(
32 &mut self,
33 obj: &T,
34 ) -> anyhow::Result<()> {
35 self.send_batch_message(obj).await?;
36 self.framed.flush().await?;
37 Ok(())
38 }
39
40 #[instrument(level = "trace", skip(self, obj, reader))]
46 pub async fn send_message_with_data_buffered<T: serde::Serialize, R: AsyncBufRead + Unpin>(
47 &mut self,
48 obj: &T,
49 reader: &mut R,
50 ) -> anyhow::Result<u64> {
51 self.send_batch_message(obj).await?;
52 let data_stream = self.framed.get_mut();
53 let bytes_copied = tokio::io::copy_buf(reader, data_stream).await?;
54 Ok(bytes_copied)
55 }
56
57 pub async fn close(&mut self) -> anyhow::Result<()> {
58 self.framed.close().await?;
59 Ok(())
60 }
61}
62
63pub type SharedSendStream<W = OwnedWriteHalf> = std::sync::Arc<tokio::sync::Mutex<SendStream<W>>>;
64
65pub type BoxedWrite = Box<dyn AsyncWrite + Unpin + Send>;
67pub type BoxedRead = Box<dyn AsyncRead + Unpin + Send>;
69pub type BoxedSendStream = SendStream<BoxedWrite>;
71pub type BoxedRecvStream = RecvStream<BoxedRead>;
73pub type BoxedSharedSendStream = SharedSendStream<BoxedWrite>;
75
76#[derive(Debug)]
80pub struct RecvStream<R = OwnedReadHalf> {
81 framed: tokio_util::codec::FramedRead<R, tokio_util::codec::LengthDelimitedCodec>,
82}
83
84impl<R: AsyncRead + Unpin> RecvStream<R> {
85 pub fn new(stream: R) -> Self {
86 let framed = tokio_util::codec::FramedRead::new(
87 stream,
88 tokio_util::codec::LengthDelimitedCodec::new(),
89 );
90 Self { framed }
91 }
92
93 pub async fn recv_object<T: serde::de::DeserializeOwned>(
94 &mut self,
95 ) -> anyhow::Result<Option<T>> {
96 if let Some(frame) = futures::StreamExt::next(&mut self.framed).await {
97 let bytes = frame?;
98 let obj = bitcode::deserialize(&bytes).map_err(anyhow::Error::from)?;
99 Ok(Some(obj))
100 } else {
101 Ok(None)
102 }
103 }
104
105 #[instrument(level = "trace", skip(self, writer))]
109 pub async fn copy_to<W: tokio::io::AsyncWrite + Unpin>(
110 &mut self,
111 writer: &mut W,
112 ) -> anyhow::Result<u64> {
113 let read_buffer = self.framed.read_buffer();
114 let buffer_size = read_buffer.len() as u64;
115 writer.write_all(read_buffer).await?;
116 let data_stream = self.framed.get_mut();
117 let stream_bytes = tokio::io::copy(data_stream, writer).await?;
118 Ok(buffer_size + stream_bytes)
119 }
120
121 #[instrument(level = "trace", skip(self, writer))]
127 pub async fn copy_to_buffered<W: tokio::io::AsyncWrite + Unpin>(
128 &mut self,
129 writer: &mut W,
130 buffer_size: usize,
131 ) -> anyhow::Result<u64> {
132 let read_buffer = self.framed.read_buffer();
133 let buffered_bytes = read_buffer.len() as u64;
134 writer.write_all(read_buffer).await?;
135 let data_stream = self.framed.get_mut();
136 let mut buffered_stream = tokio::io::BufReader::with_capacity(buffer_size, data_stream);
138 let stream_bytes = tokio::io::copy_buf(&mut buffered_stream, writer).await?;
139 Ok(buffered_bytes + stream_bytes)
140 }
141
142 #[instrument(level = "trace", skip(self, writer))]
148 pub async fn copy_exact_to_buffered<W: tokio::io::AsyncWrite + Unpin>(
149 &mut self,
150 writer: &mut W,
151 size: u64,
152 buffer_size: usize,
153 ) -> anyhow::Result<u64> {
154 if size == 0 {
155 return Ok(0);
156 }
157 let read_buffer = self.framed.read_buffer_mut();
159 let buffered = (read_buffer.len() as u64).min(size);
160 if buffered > 0 {
161 writer.write_all(&read_buffer[..buffered as usize]).await?;
162 read_buffer.advance(buffered as usize);
163 }
164 let remaining = size - buffered;
165 if remaining == 0 {
166 return Ok(size);
167 }
168 let data_stream = self.framed.get_mut();
170 let mut limited = data_stream.take(remaining);
171 let mut buf = vec![0u8; buffer_size.min(remaining as usize)];
172 let mut total_copied = buffered;
173 loop {
174 let bytes_to_read = buf.len().min((size - total_copied) as usize);
175 if bytes_to_read == 0 {
176 break;
177 }
178 let n = limited.read(&mut buf[..bytes_to_read]).await?;
179 if n == 0 {
180 break;
181 }
182 writer.write_all(&buf[..n]).await?;
183 total_copied += n as u64;
184 }
185 if total_copied != size {
186 anyhow::bail!(
187 "unexpected EOF: expected {} bytes, got {}",
188 size,
189 total_copied
190 );
191 }
192 Ok(size)
193 }
194
195 pub async fn close(&mut self) {
196 }
198}
199
200#[derive(Debug)]
202pub struct ControlConnection {
203 send: SendStream,
204 recv: RecvStream,
205}
206
207impl ControlConnection {
208 pub fn new(stream: TcpStream) -> Self {
210 let (read_half, write_half) = stream.into_split();
211 Self {
212 send: SendStream::new(write_half),
213 recv: RecvStream::new(read_half),
214 }
215 }
216
217 pub fn into_split(self) -> (SharedSendStream, RecvStream) {
219 (
220 std::sync::Arc::new(tokio::sync::Mutex::new(self.send)),
221 self.recv,
222 )
223 }
224
225 pub fn send_mut(&mut self) -> &mut SendStream {
227 &mut self.send
228 }
229
230 pub fn recv_mut(&mut self) -> &mut RecvStream {
232 &mut self.recv
233 }
234}