use bytes::Buf;
use futures::SinkExt;
use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use tracing::instrument;
#[derive(Debug)]
pub struct SendStream<W = OwnedWriteHalf> {
framed: tokio_util::codec::FramedWrite<W, tokio_util::codec::LengthDelimitedCodec>,
}
impl<W: AsyncWrite + Unpin> SendStream<W> {
pub fn new(stream: W) -> Self {
let framed = tokio_util::codec::FramedWrite::new(
stream,
tokio_util::codec::LengthDelimitedCodec::new(),
);
Self { framed }
}
pub async fn send_batch_message<T: serde::Serialize>(&mut self, obj: &T) -> anyhow::Result<()> {
let bytes = bitcode::serialize(obj).map_err(anyhow::Error::from)?;
self.framed.send(bytes::Bytes::from(bytes)).await?;
Ok(())
}
pub async fn send_control_message<T: serde::Serialize>(
&mut self,
obj: &T,
) -> anyhow::Result<()> {
self.send_batch_message(obj).await?;
self.framed.flush().await?;
Ok(())
}
#[instrument(level = "trace", skip(self, obj, reader))]
pub async fn send_message_with_data_buffered<T: serde::Serialize, R: AsyncBufRead + Unpin>(
&mut self,
obj: &T,
reader: &mut R,
) -> anyhow::Result<u64> {
self.send_batch_message(obj).await?;
let data_stream = self.framed.get_mut();
let bytes_copied = tokio::io::copy_buf(reader, data_stream).await?;
Ok(bytes_copied)
}
pub async fn close(&mut self) -> anyhow::Result<()> {
self.framed.close().await?;
Ok(())
}
}
/// Type-erased writer used by [`BoxedSendStream`] and [`BoxedSharedSendStream`].
pub type BoxedWrite = Box<dyn AsyncWrite + Unpin + Send>;
/// Type-erased reader used by [`BoxedRecvStream`].
pub type BoxedRead = Box<dyn AsyncRead + Unpin + Send>;

/// A [`SendStream`] wrapped in `Arc<tokio::sync::Mutex<_>>` so it can be shared.
pub type SharedSendStream<Inner = OwnedWriteHalf> =
    std::sync::Arc<tokio::sync::Mutex<SendStream<Inner>>>;

/// [`SendStream`] over a boxed, type-erased writer.
pub type BoxedSendStream = SendStream<BoxedWrite>;
/// [`SharedSendStream`] over a boxed, type-erased writer.
pub type BoxedSharedSendStream = SharedSendStream<BoxedWrite>;
/// [`RecvStream`] over a boxed, type-erased reader.
pub type BoxedRecvStream = RecvStream<BoxedRead>;
#[derive(Debug)]
pub struct RecvStream<R = OwnedReadHalf> {
framed: tokio_util::codec::FramedRead<R, tokio_util::codec::LengthDelimitedCodec>,
}
impl<R: AsyncRead + Unpin> RecvStream<R> {
pub fn new(stream: R) -> Self {
let framed = tokio_util::codec::FramedRead::new(
stream,
tokio_util::codec::LengthDelimitedCodec::new(),
);
Self { framed }
}
pub async fn recv_object<T: serde::de::DeserializeOwned>(
&mut self,
) -> anyhow::Result<Option<T>> {
if let Some(frame) = futures::StreamExt::next(&mut self.framed).await {
let bytes = frame?;
let obj = bitcode::deserialize(&bytes).map_err(anyhow::Error::from)?;
Ok(Some(obj))
} else {
Ok(None)
}
}
#[instrument(level = "trace", skip(self, writer))]
pub async fn copy_to<W: tokio::io::AsyncWrite + Unpin>(
&mut self,
writer: &mut W,
) -> anyhow::Result<u64> {
let read_buffer = self.framed.read_buffer();
let buffer_size = read_buffer.len() as u64;
writer.write_all(read_buffer).await?;
let data_stream = self.framed.get_mut();
let stream_bytes = tokio::io::copy(data_stream, writer).await?;
Ok(buffer_size + stream_bytes)
}
#[instrument(level = "trace", skip(self, writer))]
pub async fn copy_to_buffered<W: tokio::io::AsyncWrite + Unpin>(
&mut self,
writer: &mut W,
buffer_size: usize,
) -> anyhow::Result<u64> {
let read_buffer = self.framed.read_buffer();
let buffered_bytes = read_buffer.len() as u64;
writer.write_all(read_buffer).await?;
let data_stream = self.framed.get_mut();
let mut buffered_stream = tokio::io::BufReader::with_capacity(buffer_size, data_stream);
let stream_bytes = tokio::io::copy_buf(&mut buffered_stream, writer).await?;
Ok(buffered_bytes + stream_bytes)
}
#[instrument(level = "trace", skip(self, writer))]
pub async fn copy_exact_to_buffered<W: tokio::io::AsyncWrite + Unpin>(
&mut self,
writer: &mut W,
size: u64,
buffer_size: usize,
) -> anyhow::Result<u64> {
if size == 0 {
return Ok(0);
}
let read_buffer = self.framed.read_buffer_mut();
let buffered = (read_buffer.len() as u64).min(size);
if buffered > 0 {
writer.write_all(&read_buffer[..buffered as usize]).await?;
read_buffer.advance(buffered as usize);
}
let remaining = size - buffered;
if remaining == 0 {
return Ok(size);
}
let data_stream = self.framed.get_mut();
let mut limited = data_stream.take(remaining);
let mut buf = vec![0u8; buffer_size.min(remaining as usize)];
let mut total_copied = buffered;
loop {
let bytes_to_read = buf.len().min((size - total_copied) as usize);
if bytes_to_read == 0 {
break;
}
let n = limited.read(&mut buf[..bytes_to_read]).await?;
if n == 0 {
break;
}
writer.write_all(&buf[..n]).await?;
total_copied += n as u64;
}
if total_copied != size {
anyhow::bail!(
"unexpected EOF: expected {} bytes, got {}",
size,
total_copied
);
}
Ok(size)
}
pub async fn close(&mut self) {
}
}
/// A TCP connection carrying framed messages in both directions, held as separate send and
/// receive halves.
#[derive(Debug)]
pub struct ControlConnection {
    send: SendStream,
    recv: RecvStream,
}

impl ControlConnection {
    /// Splits `stream` into owned read/write halves and wraps each in its framed stream.
    pub fn new(stream: TcpStream) -> Self {
        let (reader, writer) = stream.into_split();
        let recv = RecvStream::new(reader);
        let send = SendStream::new(writer);
        Self { send, recv }
    }

    /// Consumes the connection and returns both halves.
    ///
    /// The send half is wrapped in `Arc<tokio::sync::Mutex<_>>` (see [`SharedSendStream`]).
    /// The receive half is returned by value.
    pub fn into_split(self) -> (SharedSendStream, RecvStream) {
        let Self { send, recv } = self;
        let shared_send: SharedSendStream = std::sync::Arc::new(tokio::sync::Mutex::new(send));
        (shared_send, recv)
    }

    /// Mutable access to the send half.
    pub fn send_mut(&mut self) -> &mut SendStream {
        let Self { send, .. } = self;
        send
    }

    /// Mutable access to the receive half.
    pub fn recv_mut(&mut self) -> &mut RecvStream {
        let Self { recv, .. } = self;
        recv
    }
}