use std::{io, time::Duration};
use bytes::Bytes;
use futures::SinkExt;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use super::{ClientError, runtime::ClientStream};
use crate::serializer::Serializer;
/// Optional settings for a streaming send.
///
/// Both settings start out unset in the `Default` value. They are supplied
/// through the consuming `with_*` builders and read back through the accessors
/// of the same name.
#[derive(Clone, Copy, Debug, Default)]
pub struct SendStreamingConfig {
    /// Requested upper bound on body bytes per frame, if one was set.
    chunk_size: Option<usize>,
    /// Time limit for the streaming send, if one was set.
    timeout: Option<Duration>,
}

impl SendStreamingConfig {
    /// Returns a copy of this configuration with the requested chunk size set
    /// to `size`; the timeout setting is carried over unchanged.
    #[must_use]
    pub fn with_chunk_size(self, size: usize) -> Self {
        Self {
            chunk_size: Some(size),
            ..self
        }
    }

    /// Returns a copy of this configuration with the timeout set to
    /// `duration`; the chunk size setting is carried over unchanged.
    #[must_use]
    pub fn with_timeout(self, duration: Duration) -> Self {
        Self {
            timeout: Some(duration),
            ..self
        }
    }

    /// The requested chunk size, or `None` when none was configured.
    #[must_use]
    pub const fn chunk_size(&self) -> Option<usize> {
        self.chunk_size
    }

    /// The configured timeout, or `None` when none was configured.
    #[must_use]
    pub const fn timeout(&self) -> Option<Duration> {
        self.timeout
    }
}
/// Result summary of a streaming send.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SendStreamingOutcome {
    /// Number of frames recorded for the operation.
    frames_sent: u64,
}

impl SendStreamingOutcome {
    /// Creates an outcome that reports `frames_sent` frames.
    #[must_use]
    pub const fn new(frames_sent: u64) -> Self {
        Self { frames_sent }
    }

    /// Number of frames recorded in this outcome.
    #[must_use]
    pub const fn frames_sent(&self) -> u64 {
        let Self { frames_sent } = *self;
        frames_sent
    }
}
impl<S, T, C> super::WireframeClient<S, T, C>
where
    S: Serializer + Send + Sync,
    T: ClientStream,
{
    /// Streams the contents of `body_reader` to the peer as a series of frames.
    ///
    /// Every frame consists of `frame_header` followed by the bytes returned
    /// by one read from `body_reader`. The read buffer size is chosen by
    /// `effective_chunk_size` from the header length, the codec's maximum
    /// frame length and `config`. A read may return fewer bytes than the
    /// buffer holds, so frames can carry less than the chunk size. Reading
    /// stops at end of input; an empty reader produces zero frames.
    ///
    /// When `config` carries a timeout, the whole send loop runs under
    /// `tokio::time::timeout`. If the limit elapses, the underlying stream is
    /// shut down (the result of the shutdown is ignored) and a `TimedOut` I/O
    /// error is returned.
    ///
    /// # Errors
    ///
    /// The client's error hook is invoked before any error is returned.
    /// Errors arise when:
    /// - the header leaves no room for body data, or the effective chunk size
    ///   is zero (`InvalidInput`);
    /// - reading from `body_reader` fails;
    /// - sending a frame fails;
    /// - the configured timeout elapses (`TimedOut`).
    pub async fn send_streaming<R: AsyncRead + Unpin>(
        &mut self,
        frame_header: &[u8],
        body_reader: R,
        config: SendStreamingConfig,
    ) -> Result<SendStreamingOutcome, ClientError> {
        let chunk_size = match effective_chunk_size(
            frame_header.len(),
            self.codec_config.max_frame_length_value(),
            &config,
        ) {
            Ok(size) => size,
            Err(err) => {
                self.invoke_error_hook(&err).await;
                return Err(err);
            }
        };

        match config.timeout {
            Some(duration) => match tokio::time::timeout(
                duration,
                self.send_streaming_inner(frame_header, body_reader, chunk_size),
            )
            .await
            {
                // The inner loop has already reported its own errors to the hook.
                Ok(result) => result,
                Err(_elapsed) => {
                    // The send future was dropped part-way through; close the
                    // stream best-effort. NOTE(review): presumably this is
                    // because a partially written frame may be on the wire —
                    // confirm.
                    let _ = self.framed.get_mut().shutdown().await;
                    let err = ClientError::from(io::Error::new(
                        io::ErrorKind::TimedOut,
                        "streaming send timed out",
                    ));
                    self.invoke_error_hook(&err).await;
                    Err(err)
                }
            },
            None => {
                self.send_streaming_inner(frame_header, body_reader, chunk_size)
                    .await
            }
        }
    }

    /// Read/send loop behind `send_streaming`, without timeout handling.
    ///
    /// Repeatedly reads up to `chunk_size` bytes from `body_reader`, prefixes
    /// them with `frame_header` and sends the result as one frame. Returns the
    /// number of frames sent once the reader reports end of input. Every
    /// error is passed to the error hook before it is returned.
    async fn send_streaming_inner<R: AsyncRead + Unpin>(
        &mut self,
        frame_header: &[u8],
        mut body_reader: R,
        chunk_size: usize,
    ) -> Result<SendStreamingOutcome, ClientError> {
        let header_len = frame_header.len();
        // A single read buffer is reused for every chunk.
        let mut buf = vec![0u8; chunk_size];
        let mut frames_sent: u64 = 0;
        loop {
            let n = match read_chunk(&mut body_reader, &mut buf).await {
                ReadChunk::Eof => break,
                ReadChunk::Bytes(n) => n,
                ReadChunk::Err(e) => {
                    self.invoke_error_hook(&e).await;
                    return Err(e);
                }
            };
            // Checked slicing guards against a reader reporting more bytes
            // than the buffer holds. This path reports to the error hook like
            // every other failure in this loop.
            let checked = buf.get(..n).ok_or_else(|| {
                ClientError::from(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "read returned more bytes than buffer length",
                ))
            });
            let chunk = match checked {
                Ok(chunk) => chunk,
                Err(err) => {
                    self.invoke_error_hook(&err).await;
                    return Err(err);
                }
            };
            // Each frame needs its own owned allocation because it is handed
            // to the sink as `Bytes`.
            let mut frame = Vec::with_capacity(header_len + n);
            frame.extend_from_slice(frame_header);
            frame.extend_from_slice(chunk);
            if let Err(e) = self.framed.send(Bytes::from(frame)).await {
                let err = ClientError::from(e);
                self.invoke_error_hook(&err).await;
                return Err(err);
            }
            frames_sent += 1;
        }
        Ok(SendStreamingOutcome { frames_sent })
    }
}
/// Outcome of a single read attempt made by `read_chunk`.
enum ReadChunk {
/// The reader returned zero bytes; the send loop treats this as end of input.
Eof,
/// The reader returned this many bytes (always non-zero).
Bytes(usize),
/// The read failed; carries the I/O error converted into a `ClientError`.
Err(ClientError),
}
/// Performs one read from `reader` into `buf` and classifies the result.
///
/// A zero-byte read becomes `ReadChunk::Eof`, a non-zero read becomes
/// `ReadChunk::Bytes` with the byte count, and an I/O failure becomes
/// `ReadChunk::Err` after conversion into a `ClientError`.
async fn read_chunk<R: AsyncRead + Unpin>(reader: &mut R, buf: &mut [u8]) -> ReadChunk {
    let outcome = reader.read(buf).await;
    match outcome {
        Err(io_err) => ReadChunk::Err(ClientError::from(io_err)),
        Ok(0) => ReadChunk::Eof,
        Ok(filled) => ReadChunk::Bytes(filled),
    }
}
fn effective_chunk_size(
header_len: usize,
max_frame_length: usize,
config: &SendStreamingConfig,
) -> Result<usize, ClientError> {
if header_len >= max_frame_length {
return Err(ClientError::from(io::Error::new(
io::ErrorKind::InvalidInput,
concat!(
"frame header length meets or exceeds max frame length; ",
"no room for body data",
),
)));
}
let available = max_frame_length - header_len;
let size = match config.chunk_size {
Some(requested) => requested.min(available),
None => available,
};
if size == 0 {
return Err(ClientError::from(io::Error::new(
io::ErrorKind::InvalidInput,
"chunk size must be greater than zero",
)));
}
Ok(size)
}