use crate::Result;
use bytes::{buf::Reader, Buf, Bytes};
use futures::{Stream, StreamExt};
use std::{io::Read, pin::Pin};
use tokio::io::AsyncWrite;
pub(crate) fn stream_as_reader<S>(stream: S) -> impl Read
where
S: Stream<Item = Result<Bytes>> + Send + Sync + 'static,
{
let handle = tokio::runtime::Handle::current();
TryStreamReader {
buffer: None,
stream: Box::pin(stream),
handle,
}
}
/// Wraps an async writer so that synchronous code can write to it through
/// [`std::io::Write`].
///
/// This is a thin wrapper over `tokio_util::io::SyncIoBridge::new`. The bridge blocks
/// the calling thread on each operation, so it is meant to be used from a blocking
/// context (e.g. `tokio::task::spawn_blocking`) rather than on an async task.
///
/// # Panics
/// Per the tokio-util documentation, `SyncIoBridge::new` panics when called outside
/// the context of a Tokio runtime.
pub(crate) fn async_write_as_writer<W>(writer: W) -> tokio_util::io::SyncIoBridge<W>
where
    W: AsyncWrite + Unpin + 'static,
{
    tokio_util::io::SyncIoBridge::<W>::new(writer)
}
/// Blocking [`Read`] adapter over a fallible async stream of byte chunks.
///
/// Constructed by `stream_as_reader`; the `Read` impl below consumes one chunk at a
/// time, keeping any unread tail of the current chunk in `buffer`.
struct TryStreamReader {
    /// Unread remainder of the most recently received chunk; only kept while it still
    /// holds bytes, otherwise `None`.
    buffer: Option<Reader<Bytes>>,
    /// Type-erased source of chunks, boxed and pinned so it can be polled via `next()`.
    stream: Pin<Box<dyn Stream<Item = Result<Bytes>> + Send + Sync>>,
    /// Runtime handle whose `block_on` drives `stream` from synchronous `read` calls.
    handle: tokio::runtime::Handle,
}
impl Read for TryStreamReader {
    /// Copies bytes from the current chunk into `buf`, pulling the next chunk from the
    /// stream (blocking the calling thread) once the current one is exhausted.
    ///
    /// Returns `Ok(0)` only when `buf` is empty or the stream has ended. Empty chunks
    /// yielded by the stream are skipped instead of being reported as end-of-file.
    ///
    /// # Errors
    /// A stream error is converted into an [`std::io::Error`] of kind
    /// [`std::io::ErrorKind::Other`] wrapping the original error.
    ///
    /// # Panics
    /// `Handle::block_on` panics when called from within an asynchronous execution
    /// context, so this must be called from a blocking thread (e.g. `spawn_blocking`).
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // A zero-length read cannot make progress; the `Read` contract allows `Ok(0)`
        // here, and answering immediately avoids blocking on the stream for nothing.
        if buf.is_empty() {
            return Ok(0);
        }
        loop {
            // Serve from the leftover tail of the previous chunk first.
            if let Some(mut buffer) = self.buffer.take() {
                if buffer.get_ref().has_remaining() {
                    let bytes_read = buffer.read(buf)?;
                    // Only keep the chunk around while it still has unread bytes.
                    if buffer.get_ref().has_remaining() {
                        self.buffer = Some(buffer);
                    }
                    return Ok(bytes_read);
                }
            }
            match self.handle.block_on(self.stream.next()) {
                None => return Ok(0),
                Some(Err(e)) => {
                    return Err(std::io::Error::new(std::io::ErrorKind::Other, e));
                }
                Some(Ok(bytes)) => {
                    // An empty chunk must not surface as `Ok(0)`: callers would treat it
                    // as EOF and silently drop the rest of the stream. Stash non-empty
                    // chunks and loop back to read from them; skip empty ones.
                    if bytes.has_remaining() {
                        self.buffer = Some(bytes.reader());
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;
    use std::io::Cursor;

    /// Round-trips random data through `stream_as_reader`. The producer splits the
    /// payload into randomly sized chunks and the consumer reads with randomly sized
    /// buffers. Every byte must arrive in order, followed by end-of-file.
    #[tokio::test]
    async fn read_from_stream() {
        const TEST_DATA_SIZE: usize = 10_000_000;
        const MAX_READ_SIZE: usize = TEST_DATA_SIZE / 10;
        const MIN_READ_SIZE: usize = 1;

        let mut rng = rand::thread_rng();
        let mut expected = vec![0u8; TEST_DATA_SIZE];
        rng.fill(&mut expected[..]);

        // Carve the payload into `Bytes` pieces of random length; the final piece may
        // be shorter than requested.
        let mut source = Cursor::new(expected.clone());
        let pieces: Vec<Bytes> = std::iter::from_fn(|| {
            if source.position() >= TEST_DATA_SIZE as u64 {
                return None;
            }
            let mut piece = vec![0u8; rng.gen_range(MIN_READ_SIZE..MAX_READ_SIZE)];
            let filled = source.read(&mut piece[..]).unwrap();
            piece.truncate(filled);
            Some(Bytes::from(piece))
        })
        .collect();

        let stream = futures::stream::iter(
            pieces
                .into_iter()
                .map(|piece| -> Result<Bytes> { Ok(piece) }),
        );
        let mut reader = stream_as_reader(stream);

        // The reader blocks on the runtime, so drive it from a blocking thread.
        let consumer = tokio::task::spawn_blocking(move || {
            let mut rng = rand::thread_rng();
            let mut received = Vec::with_capacity(TEST_DATA_SIZE);
            while received.len() < TEST_DATA_SIZE {
                let mut scratch = vec![0u8; rng.gen_range(MIN_READ_SIZE..MAX_READ_SIZE)];
                let count = reader.read(&mut scratch[..]).unwrap();
                assert!(count > 0);
                received.extend_from_slice(&scratch[..count]);
            }
            // With the stream drained, any further read must report end-of-file.
            let mut probe = vec![0u8; 100];
            assert_eq!(0, reader.read(&mut probe[..]).unwrap());
            received
        });
        let received = consumer.await.unwrap();

        assert_eq!(expected, received);
    }
}