use atomr_streams::{Sink, Source};
/// Adapts a Tokio unbounded MPSC receiver into a streaming [`Source`].
///
/// This delegates directly to `Source::from_receiver`; element and completion
/// semantics are whatever that constructor provides.
///
/// NOTE(review): the resulting source presumably ends once every sender paired
/// with `rx` has been dropped — confirm against the `atomr_streams` docs.
pub fn source_from_unbounded<T: Send + 'static>(
    rx: tokio::sync::mpsc::UnboundedReceiver<T>,
) -> Source<T> {
    Source::<T>::from_receiver(rx)
}
/// Maps every element of `source` through the asynchronous function `f`.
///
/// `parallelism` is passed to `Source::map_async` after being clamped to at
/// least 1, so a `parallelism` of `0` behaves the same as `1`.
///
/// Nothing in this function is GPU-specific. The name presumably reflects the
/// intended kind of work done inside `f`.
///
/// NOTE(review): `map_async` presumably caps how many `f` futures run at once.
/// It may or may not preserve input order. Confirm in the `atomr_streams` docs
/// before relying on output ordering.
pub fn gpu_stage<I, O, F, Fut>(source: Source<I>, parallelism: usize, f: F) -> Source<O>
where
    I: Send + 'static,
    O: Send + 'static,
    F: FnMut(I) -> Fut + Send + 'static,
    Fut: std::future::Future<Output = O> + Send + 'static,
{
    // Never hand a width of zero to `map_async`.
    let width = std::cmp::max(parallelism, 1);
    source.map_async(width, f)
}
/// Feeds everything received on `rx` through the async `stage` and gathers the
/// outputs into a `Vec`.
///
/// The steps are:
/// 1. Wrap `rx` with `source_from_unbounded`.
/// 2. Apply `stage` via `gpu_stage` using `parallelism`, which is clamped to at least 1.
/// 3. Drain the result with `Sink::collect`.
///
/// NOTE(review): the returned future presumably resolves only after every sender
/// for `rx` has been dropped. Output order is whatever `map_async` yields, so
/// callers needing a specific order should sort the result.
pub async fn run_collect<I, O, F, Fut>(
    rx: tokio::sync::mpsc::UnboundedReceiver<I>,
    parallelism: usize,
    stage: F,
) -> Vec<O>
where
    I: Send + 'static,
    O: Send + 'static,
    F: FnMut(I) -> Fut + Send + 'static,
    Fut: std::future::Future<Output = O> + Send + 'static,
{
    let input = source_from_unbounded(rx);
    let staged = gpu_stage(input, parallelism, stage);
    Sink::collect(staged).await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sends 1..=5 through a ×10 async stage with parallelism 4 and checks that
    /// all five outputs come back.
    ///
    /// The result is sorted before comparison, so the check does not depend on
    /// output order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn unbounded_round_trips_through_async_stage() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<u32>();
        (1..=5).for_each(|n| tx.send(n).unwrap());
        // Drop the only sender so the channel can close.
        // Presumably required for `run_collect` to finish.
        drop(tx);

        let mut collected: Vec<u32> = run_collect(rx, 4, |x| async move { x * 10 }).await;
        collected.sort_unstable();

        let expected: Vec<u32> = (1..=5).map(|n| n * 10).collect();
        assert_eq!(collected, expected);
    }
}