use tokio::sync::mpsc;
/// Reason a [`Pipe`] stage could not make progress.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PipeError {
    /// The upstream channel is closed and holds no more items.
    InboundGone,
    /// The downstream channel's receiver has been dropped, so nothing can be forwarded.
    OutboundGone,
}

impl std::fmt::Display for PipeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            PipeError::InboundGone => "inbound channel closed",
            PipeError::OutboundGone => "outbound channel closed",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for PipeError {}
/// A single stage of a channel-based pipeline.
///
/// The stage receives items on `rx`, keeps the most recently received item in
/// `contents`, and hands it on through `tx` before it accepts the next one.
#[derive(Debug)]
pub struct Pipe<T: std::fmt::Debug + Send + Sync + 'static> {
    /// Upstream end: items arrive here.
    rx: mpsc::Receiver<T>,
    /// Downstream end: held items are forwarded here.
    tx: mpsc::Sender<T>,
    /// The item currently held by this stage, if any.
    contents: Option<T>,
}
impl<T> Pipe<T>
where
    T: std::fmt::Debug + Send + Sync + 'static,
{
    /// Creates a stage that reads from `rx`, writes to `tx`, and initially holds `contents`.
    pub fn new(rx: mpsc::Receiver<T>, tx: mpsc::Sender<T>, contents: Option<T>) -> Self {
        Self { rx, tx, contents }
    }

    /// Builds `length` stages chained by bounded channels and leaves both ends open.
    ///
    /// Returns `(head, stages, tail)`. Items sent on `head` pass through every stage in
    /// order and arrive on `tail`. All `length + 1` channels share one buffer size:
    /// `total_capacity / length`, with a minimum of 1. `total_capacity` defaults to
    /// `length * 20`, saturating at `usize::MAX`.
    ///
    /// When `length == 0`, no stages are created and `head` feeds `tail` directly.
    pub fn unterminated_pipeline(
        length: usize,
        total_capacity: Option<usize>,
    ) -> (mpsc::Sender<T>, Vec<Pipe<T>>, mpsc::Receiver<T>) {
        // saturating_mul: a huge `length` must not overflow (which panics in debug builds).
        let total_capacity = total_capacity.unwrap_or(length.saturating_mul(20));
        // `length.max(1)` avoids a divide-by-zero panic for an empty pipeline.
        // The outer `max(1)` is needed because tokio's `channel(0)` panics.
        let buffer = (total_capacity / length.max(1)).max(1);

        let (head, mut upstream_rx) = mpsc::channel::<T>(buffer);
        let mut stages = Vec::with_capacity(length);
        for _ in 0..length {
            let (stage_tx, stage_rx) = mpsc::channel::<T>(buffer);
            // The new stage reads from the previous channel and writes into the fresh one.
            // The fresh channel's receiver then becomes the upstream for the next stage.
            let inbound = std::mem::replace(&mut upstream_rx, stage_rx);
            stages.push(Pipe::new(inbound, stage_tx, None));
        }
        (head, stages, upstream_rx)
    }

    /// Builds `length` stages like [`Pipe::unterminated_pipeline`] and drains the tail.
    ///
    /// A spawned task receives and discards every item that reaches the end of the
    /// pipeline, until the last channel closes.
    ///
    /// # Panics
    /// Panics if called outside a Tokio runtime, because it uses `tokio::spawn`.
    pub fn pipeline(
        length: usize,
        total_capacity: Option<usize>,
    ) -> (mpsc::Sender<T>, Vec<Pipe<T>>) {
        let (head, stages, mut tail) = Self::unterminated_pipeline(length, total_capacity);
        tokio::spawn(async move { while tail.recv().await.is_some() {} });
        (head, stages)
    }

    /// Removes the held item and returns it. The item will not be forwarded downstream.
    pub fn take(&mut self) -> Option<T> {
        self.contents.take()
    }

    /// Borrows the held item, if any.
    pub fn read(&self) -> Option<&T> {
        self.contents.as_ref()
    }

    /// Returns an owned copy of the held item, if any. The stage keeps its item.
    pub fn to_owned(&self) -> Option<<T as ToOwned>::Owned>
    where
        T: ToOwned,
    {
        self.contents.as_ref().map(ToOwned::to_owned)
    }

    /// Moves the held item (if any) into the outbound channel, waiting for capacity.
    ///
    /// Capacity is reserved before the item is taken out of the stage. So if the
    /// outbound side is closed, or this future is dropped while waiting, the item
    /// stays in the stage.
    async fn send(&mut self) -> Result<(), PipeError> {
        if self.contents.is_none() {
            return Ok(());
        }
        let permit = self
            .tx
            .reserve()
            .await
            .map_err(|_| PipeError::OutboundGone)?;
        if let Some(item) = self.contents.take() {
            permit.send(item);
        }
        Ok(())
    }

    /// Forwards the held item (if any) downstream, then waits for the next inbound item.
    ///
    /// On success, the new item is stored in the stage and a reference to it is returned.
    ///
    /// # Errors
    /// - [`PipeError::OutboundGone`] if an item was held but the downstream receiver is
    ///   gone. The item stays in the stage.
    /// - [`PipeError::InboundGone`] once the upstream channel is closed and empty.
    pub async fn next(&mut self) -> Result<&T, PipeError> {
        self.send().await?;
        let incoming = self.rx.recv().await.ok_or(PipeError::InboundGone)?;
        let stored: &T = self.contents.insert(incoming);
        Ok(stored)
    }

    /// Turns this stage into a pass-through that forwards every item unchanged.
    ///
    /// # Panics
    /// Panics if called outside a Tokio runtime.
    pub fn nop(self) {
        self.for_each(|_| {});
    }

    /// Spawns a task that calls `f` on each item and then forwards the item downstream.
    ///
    /// The task ends when [`Pipe::next`] fails, i.e. when either neighbouring channel
    /// closes. The stage is then dropped.
    ///
    /// # Panics
    /// Panics if called outside a Tokio runtime.
    pub fn for_each<Func>(mut self, f: Func)
    where
        Func: Fn(&T) + Send + 'static,
    {
        tokio::spawn(async move {
            while let Ok(item) = self.next().await {
                f(item);
            }
        });
    }

    /// Like [`Pipe::for_each`], but awaits the future returned by `f` before moving on.
    ///
    /// Items are processed strictly one at a time. The future's output is discarded.
    ///
    /// # Panics
    /// Panics if called outside a Tokio runtime.
    pub fn for_each_async<Func, Fut, Out>(mut self, f: Func)
    where
        Func: Fn(&T) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Out> + Send + 'static,
    {
        tokio::spawn(async move {
            while let Ok(item) = self.next().await {
                f(item).await;
            }
        });
    }
}
impl<T> Drop for Pipe<T>
where
    T: std::fmt::Debug + Send + Sync + 'static,
{
    /// Makes one non-blocking attempt to pass a still-held item downstream.
    fn drop(&mut self) {
        if let Some(leftover) = self.contents.take() {
            // `drop` cannot await, so only `try_send` is possible here. If the outbound
            // channel is full or closed, the item is deliberately discarded.
            let _ = self.tx.try_send(leftover);
        }
    }
}