use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use futures::stream::StreamExt;
use parking_lot::Mutex;
use tokio::sync::Notify;
use crate::source::Source;
/// Terminal operations ("sinks") that drive a [`Source`] to completion.
///
/// Every method consumes its source by converting it with `Source::into_boxed` and then
/// driving the resulting stream with `futures::StreamExt` combinators.
pub struct Sink;

impl Sink {
    /// Folds every element of `source` into an accumulator with a synchronous closure.
    ///
    /// Starts from `init`; for an empty source `init` is returned unchanged.
    pub async fn fold<T, Acc, F>(source: Source<T>, init: Acc, mut f: F) -> Acc
    where
        T: Send + 'static,
        Acc: Send + 'static,
        F: FnMut(Acc, T) -> Acc + Send + 'static,
    {
        source
            .into_boxed()
            .fold(init, move |acc, x| futures::future::ready(f(acc, x)))
            .await
    }

    /// Like [`Sink::fold`], but the step function returns a future that is awaited
    /// before the next element is processed.
    pub async fn fold_async<T, Acc, F, Fut>(source: Source<T>, init: Acc, f: F) -> Acc
    where
        T: Send + 'static,
        Acc: Send + 'static,
        F: FnMut(Acc, T) -> Fut + Send + 'static,
        Fut: Future<Output = Acc> + Send + 'static,
    {
        source.into_boxed().fold(init, f).await
    }

    /// Collects every element of `source`, in order, into a `Vec`.
    pub async fn collect<T>(source: Source<T>) -> Vec<T>
    where
        T: Send + 'static,
    {
        source.into_boxed().collect().await
    }

    /// Returns the first element of `source`, or `None` if it is empty.
    ///
    /// Only one element is pulled; the remainder of the source is dropped unconsumed.
    pub async fn first<T>(source: Source<T>) -> Option<T>
    where
        T: Send + 'static,
    {
        source.into_boxed().next().await
    }

    /// Drains `source` and returns its final element, or `None` if it is empty.
    pub async fn last<T>(source: Source<T>) -> Option<T>
    where
        T: Send + 'static,
    {
        // Each element simply replaces the previous accumulator.
        Self::fold(source, None, |_, x| Some(x)).await
    }

    /// Adds up every element of `source`, starting from `T::default()`.
    ///
    /// An empty source yields `T::default()`.
    pub async fn sum<T>(source: Source<T>) -> T
    where
        T: Send + Default + std::ops::Add<Output = T> + 'static,
    {
        Self::fold(source, T::default(), |acc, x| acc + x).await
    }

    /// Drains `source` and returns how many elements it produced.
    pub async fn count<T>(source: Source<T>) -> u64
    where
        T: Send + 'static,
    {
        Self::fold(source, 0u64, |acc, _| acc + 1).await
    }

    /// Calls `f` on every element of `source`, sequentially and in order.
    pub async fn for_each<T, F>(source: Source<T>, mut f: F)
    where
        T: Send + 'static,
        F: FnMut(T) + Send + 'static,
    {
        source
            .into_boxed()
            .for_each(move |x| {
                f(x);
                futures::future::ready(())
            })
            .await
    }

    /// Runs the asynchronous `f` on every element of `source` with at most
    /// `parallelism` futures in flight at once.
    ///
    /// A `parallelism` of `0` is treated as `1`. This matters because
    /// `for_each_concurrent` would otherwise interpret `0` as "unlimited".
    pub async fn for_each_async<T, F, Fut>(source: Source<T>, parallelism: usize, f: F)
    where
        T: Send + 'static,
        F: FnMut(T) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let limit = parallelism.max(1);
        source.into_boxed().for_each_concurrent(limit, f).await
    }

    /// Drains `source`, discarding every element.
    pub async fn ignore<T: Send + 'static>(source: Source<T>) {
        source.into_boxed().for_each(|_| futures::future::ready(())).await
    }

    /// Forwards every element of `source` into `tx`.
    ///
    /// Stops early, leaving the rest of the source unconsumed, as soon as a send fails
    /// (i.e. the receiving half has been closed).
    pub async fn to_sender<T>(source: Source<T>, tx: tokio::sync::mpsc::UnboundedSender<T>)
    where
        T: Send + 'static,
    {
        let mut stream = source.into_boxed();
        while let Some(v) = stream.next().await {
            if tx.send(v).is_err() {
                break;
            }
        }
    }

    /// Spawns a tokio task that drains `source` into an in-memory buffer and returns a
    /// [`SinkQueue`] from which the elements can be pulled one at a time.
    ///
    /// The buffer is unbounded: the producer never waits for the consumer.
    /// Must be called from within a tokio runtime (it uses `tokio::spawn`).
    pub fn queue<T>(source: Source<T>) -> SinkQueue<T>
    where
        T: Send + 'static,
    {
        let buf: Arc<Mutex<SinkQueueState<T>>> = Arc::new(Mutex::new(SinkQueueState::default()));
        let notify = Arc::new(Notify::new());
        let producer_buf = Arc::clone(&buf);
        let producer_notify = Arc::clone(&notify);
        let handle = tokio::spawn(async move {
            let mut stream = source.into_boxed();
            while let Some(v) = stream.next().await {
                producer_buf.lock().items.push_back(v);
                // Stores a permit if no consumer is currently waiting, so the push is not lost.
                producer_notify.notify_one();
            }
            producer_buf.lock().complete = true;
            // `notify_waiters` wakes every consumer already registered, but stores no permit.
            // The extra `notify_one` leaves a permit so that a consumer which checked the
            // buffer just before completion, but has not started waiting yet, still wakes.
            producer_notify.notify_waiters();
            producer_notify.notify_one();
        });
        SinkQueue { buf, notify, _handle: handle }
    }

    /// Pulls the next element from `q`, giving up after `t`.
    ///
    /// Returns `None` both when the timeout elapses and when the queue is exhausted.
    pub async fn pull_with_timeout<T: Send + 'static>(q: &SinkQueue<T>, t: Duration) -> Option<T> {
        tokio::time::timeout(t, q.pull()).await.ok().flatten()
    }
}
/// State shared between the producer task spawned by `Sink::queue` and `SinkQueue::pull`.
struct SinkQueueState<T> {
    /// Elements produced but not yet pulled, in arrival order.
    items: std::collections::VecDeque<T>,
    /// Set by the producer once the source is exhausted; no further items are pushed after it.
    complete: bool,
}

// Written by hand instead of `#[derive(Default)]`: the derive would add a `T: Default`
// bound, and queued element types are not required to implement `Default`.
impl<T> Default for SinkQueueState<T> {
    fn default() -> Self {
        let items = std::collections::VecDeque::default();
        Self {
            complete: false,
            items,
        }
    }
}
/// Pull-based handle over a source that a background task is draining into a buffer.
///
/// Elements are appended to `buf` by the producer task and removed by the consumer.
/// `notify` is signalled by the producer after each push and on completion.
pub struct SinkQueue<T> {
    /// Buffered elements plus the producer's completion flag.
    buf: Arc<Mutex<SinkQueueState<T>>>,
    /// Wakes consumers waiting for a new element or for completion.
    notify: Arc<Notify>,
    /// The producer task; aborted when the queue is dropped (see the `Drop` impl).
    _handle: tokio::task::JoinHandle<()>,
}

impl<T> Drop for SinkQueue<T> {
    /// Aborts the producer task.
    ///
    /// Dropping a `JoinHandle` only detaches its task, so without this the producer would
    /// keep draining the source (indefinitely, for an endless source) into a buffer that
    /// nothing can read any more.
    fn drop(&mut self) {
        self._handle.abort();
    }
}
impl<T: Send + 'static> SinkQueue<T> {
    /// Waits for and returns the next buffered element.
    ///
    /// Returns `None` once the producer has marked the queue complete and every buffered
    /// element has been pulled. An element is removed from the buffer only in the same
    /// poll that returns it, so dropping this future early (e.g. on timeout) loses nothing.
    pub async fn pull(&self) -> Option<T> {
        loop {
            // Register interest BEFORE inspecting the buffer. A `Notified` future receives
            // `notify_waiters()` wakeups from the moment it is created, and `notify_one()`
            // stores a permit when nobody is waiting. A push or completion that lands
            // between the check below and the `.await` therefore cannot be missed.
            let notified = self.notify.notified();
            {
                // The guard lives only in this scope, so it is never held across `.await`.
                let mut state = self.buf.lock();
                if let Some(v) = state.items.pop_front() {
                    return Some(v);
                }
                if state.complete {
                    return None;
                }
            }
            notified.await;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the simple reducing sinks on small finite sources.
    #[tokio::test]
    async fn first_last_sum_count() {
        let head = Sink::first(Source::from_iter(vec![1, 2, 3])).await;
        assert_eq!(head, Some(1));

        let tail = Sink::last(Source::from_iter(vec![1, 2, 3])).await;
        assert_eq!(tail, Some(3));

        let total = Sink::sum(Source::from_iter(1..=10_i32)).await;
        assert_eq!(total, 55);

        let n = Sink::count(Source::from_iter(0..42_u64)).await;
        assert_eq!(n, 42);
    }

    /// Every element must be processed even when tasks run concurrently.
    #[tokio::test]
    async fn for_each_async_runs_all_tasks() {
        let total = std::sync::Arc::new(std::sync::Mutex::new(0i32));
        let shared = std::sync::Arc::clone(&total);
        Sink::for_each_async(Source::from_iter(1..=5), 2, move |v| {
            let slot = std::sync::Arc::clone(&shared);
            async move {
                *slot.lock().unwrap() += v;
            }
        })
        .await;
        assert_eq!(*total.lock().unwrap(), 15);
    }

    /// The queue yields elements in order, then `None` once the source is exhausted.
    #[tokio::test]
    async fn sink_queue_pulls_until_complete() {
        let queue = Sink::queue(Source::from_iter(vec![10, 20, 30]));
        for expected in [10, 20, 30] {
            assert_eq!(queue.pull().await, Some(expected));
        }
        assert_eq!(queue.pull().await, None);
    }
}