use std::{future::Future, sync::Arc};
use tokio::sync::mpsc::unbounded_channel;
/// A source of values that can be drained asynchronously through a callback.
///
/// The method is a plain (non-`async`) trait method returning `impl Future`,
/// so no `async_trait` boxing is involved.
pub trait FutureConsumer {
    /// The value handed to the callback for each produced element.
    type Item;

    /// Consumes `self` and returns a future that feeds produced items to `func`.
    ///
    /// Implementations are expected to resolve the returned future once no
    /// further items will be passed to `func`.
    fn fut_consume(self, func: impl FnMut(Self::Item) + Send) -> impl Future<Output = ()>;
}
/// Blanket implementation for any iterator whose items are `Send + 'static`
/// futures with `Send` outputs.
///
/// Every future is handed to [`tokio::spawn`] immediately, when `fut_consume`
/// is called, not when the returned future is first polled. All of them
/// therefore start running concurrently right away. Each output is pushed
/// into an unbounded channel. The returned future passes outputs to `func` in
/// the order the futures *complete*, which may differ from the order the
/// iterator yielded them. It resolves once every spawned task has finished
/// and dropped its sender.
///
/// Because the channel is unbounded, completed outputs accumulate in memory
/// if `func` is slower than the producers.
///
/// # Panics
///
/// Panics if the iterator yields at least one future and this is called
/// outside the context of a Tokio runtime (a requirement of [`tokio::spawn`]).
///
/// If one of the futures panics, its task ends without sending a value, so
/// that output is silently missing from what `func` receives.
impl<I, Fut> FutureConsumer for I
where
    I: Iterator<Item = Fut>,
    Fut: Future + Send + 'static,
    Fut::Output: Send,
{
    type Item = Fut::Output;

    fn fut_consume(self, mut func: impl FnMut(Self::Item) + Send) -> impl Future<Output = ()> {
        let mut rx = {
            let (tx, rx) = unbounded_channel::<Self::Item>();
            let tx = Arc::new(tx);
            for fut in self {
                let tx = Arc::clone(&tx);
                tokio::spawn(async move {
                    let data = fut.await;
                    // The receiver is gone if the caller dropped the returned
                    // future early (timeout, `select!`, a panic in `func`, or
                    // never awaiting it). Nobody wants the value any more, so
                    // discard it instead of panicking inside a detached task.
                    let _ = tx.send(data);
                });
            }
            // The last sender handle outside the tasks is dropped here, so the
            // channel closes as soon as every task has finished.
            rx
        };
        async move {
            while let Some(data) = rx.recv().await {
                func(data);
            }
        }
    }
}
#[cfg(test)]
mod test {
    use std::time::Instant;

    use futures::future::join_all;
    use tokio::time::{Duration, sleep};

    use super::FutureConsumer;

    /// Delays (in ms) after which the timing-test futures resolve.
    const DELAYS_MS: [u64; 2] = [100, 200];
    /// Simulated per-item processing cost (in ms). It must stay below the gap
    /// between consecutive delays so that processing overlaps the next wait.
    const WORK_MS: u64 = 50;

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn available() {
        let mut seen = Vec::new();
        (0..10)
            .map(|item| async move { item * 2 })
            .fut_consume(|item| seen.push(item))
            .await;
        // Outputs arrive in completion order, so sort before comparing.
        seen.sort_unstable();
        assert_eq!(seen, (0..10).map(|item| item * 2).collect::<Vec<_>>());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn time_check() {
        // `Instant` is monotonic; `SystemTime` can step backwards and make
        // `duration_since(..).unwrap()` panic.
        let start = Instant::now();
        DELAYS_MS
            .iter()
            .copied()
            .map(|ms| async move {
                sleep(Duration::from_millis(ms)).await;
                ms
            })
            .fut_consume(|_| {
                // Deliberately blocking: models CPU-bound work done in the callback.
                std::thread::sleep(Duration::from_millis(WORK_MS));
            })
            .await;
        let streamed = start.elapsed();

        let start = Instant::now();
        let data = join_all(DELAYS_MS.iter().copied().map(|ms| async move {
            sleep(Duration::from_millis(ms)).await;
            ms
        }))
        .await;
        for _ in &data {
            sleep(Duration::from_millis(WORK_MS)).await;
        }
        let batched = start.elapsed();

        assert!(
            streamed < batched,
            "streamed {streamed:?} should beat batched {batched:?}"
        );
    }
}