use crate::error::Result;
use crate::io::sums::{ReaderStream, SharedReader};
use async_stream::stream;
use futures_util::Stream;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt, BufReader};
use tokio::sync::mpsc;
/// Fans the bytes of an [`AsyncRead`] out to any number of subscriber streams.
///
/// Each subscriber gets its own bounded `mpsc` channel. Reading is driven by
/// `send_to_end`, which forwards every chunk it reads to every registered
/// subscriber.
#[derive(Debug)]
pub struct ChannelReader<R> {
    /// The wrapped reader, buffered.
    inner: BufReader<R>,
    /// Sending halves of the subscriber channels registered since the last
    /// `send_to_end` call (which removes them).
    txs: Vec<mpsc::Sender<Arc<[u8]>>>,
    /// Buffer size, in chunks, used when creating each subscriber channel.
    capacity: usize,
}
impl<R> ChannelReader<R>
where
    R: AsyncRead + Unpin,
{
    /// Maximum number of bytes requested from the inner reader per read in
    /// [`send_to_end`](Self::send_to_end); each successful read is broadcast
    /// as one chunk.
    const CHUNK_SIZE: usize = 1000;

    /// Wraps `inner` in a [`BufReader`], with no subscribers yet.
    ///
    /// `capacity` is the number of chunks each subscriber channel can buffer
    /// before [`send_to_end`](Self::send_to_end) has to wait for that
    /// subscriber to receive.
    pub fn new(inner: R, capacity: usize) -> Self {
        Self {
            inner: BufReader::new(inner),
            txs: vec![],
            capacity,
        }
    }

    /// Consumes the reader and returns the buffered inner reader.
    pub fn into_inner(self) -> BufReader<R> {
        self.inner
    }

    /// Registers a new subscriber and returns a stream of the chunks that the
    /// next call to [`send_to_end`](Self::send_to_end) reads.
    ///
    /// The stream ends once `send_to_end` returns (successfully or not),
    /// because that drops the sending half of every registered channel.
    /// A `capacity` of 0 is treated as 1, because a zero-sized bounded tokio
    /// channel is rejected with a panic.
    pub fn subscribe_stream(&mut self) -> impl Stream<Item = Result<Arc<[u8]>>> + 'static {
        let (tx, mut rx) = mpsc::channel(self.capacity.max(1));
        self.txs.push(tx);
        stream! {
            while let Some(buf) = rx.recv().await {
                yield Ok(buf);
            }
        }
    }

    /// Reads the inner reader to EOF, sending every chunk to all subscribers
    /// registered so far, and returns the total number of bytes read.
    ///
    /// The current subscribers are detached up front. Their senders are dropped
    /// when this returns, which ends their streams. Subscribers registered
    /// afterwards receive nothing from this call.
    ///
    /// Chunks are sent to one subscriber at a time. A subscriber whose channel
    /// is full therefore stalls the broadcast until it receives.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - reading from the inner reader fails;
    /// - sending to a subscriber fails (e.g. its receiving stream was dropped);
    /// - the byte count cannot be converted to `u64`.
    pub async fn send_to_end(&mut self) -> Result<u64> {
        // Own the senders so they are dropped, closing every subscriber's
        // stream, as soon as this function finishes or is cancelled.
        let txs = std::mem::take(&mut self.txs);
        // One scratch buffer for the whole read; each chunk is copied out into
        // its own shared `Arc` below, so reusing it is safe.
        let mut buf = vec![0u8; Self::CHUNK_SIZE];
        let mut size: usize = 0;
        loop {
            let n = self.inner.read(&mut buf).await?;
            if n == 0 {
                break;
            }
            size += n;
            let chunk: Arc<[u8]> = Arc::from(&buf[..n]);
            for tx in &txs {
                tx.send(Arc::clone(&chunk)).await?;
            }
        }
        Ok(u64::try_from(size)?)
    }
}
#[async_trait::async_trait]
impl<R> SharedReader for ChannelReader<R>
where
R: AsyncRead + Unpin + Send,
{
async fn read_chunks(&mut self) -> Result<u64> {
self.send_to_end().await
}
fn as_stream(&mut self) -> ReaderStream {
Box::pin(self.subscribe_stream())
}
}
#[cfg(test)]
pub(crate) mod test {
    use super::*;
    use crate::test::TestFileBuilder;
    use anyhow::Result;
    use futures_util::StreamExt;
    use rand::Rng;
    use std::io::Cursor;

    /// A stream subscribed before `read_chunks` must yield exactly the bytes
    /// of the source, in order.
    #[tokio::test]
    async fn test_stream() -> Result<()> {
        let mut rng = TestFileBuilder::new()?.with_constant_seed().into_rng();
        let mut expected = vec![0; 100000];
        rng.fill_bytes(&mut expected);

        let mut reader = channel_reader(Cursor::new(expected.clone())).await;
        let mut chunks = reader.as_stream();
        reader.read_chunks().await?;

        let mut actual = Vec::with_capacity(expected.len());
        while let Some(chunk) = chunks.next().await {
            actual.extend_from_slice(&chunk?);
        }
        assert_eq!(actual, expected);
        Ok(())
    }

    /// Builds a `ChannelReader` over `inner` whose channels are large enough
    /// that, for test-sized inputs, `read_chunks` can run to EOF before any
    /// stream is drained.
    pub(crate) async fn channel_reader<R>(inner: R) -> ChannelReader<R>
    where
        R: AsyncRead + Unpin,
    {
        const CAPACITY: usize = (1 << 30) + 1;
        ChannelReader::new(inner, CAPACITY)
    }
}