1use std::future::Future;
9use std::sync::Arc;
10use std::time::Duration;
11
12use futures::stream::StreamExt;
13use parking_lot::Mutex;
14use tokio::sync::Notify;
15
16use crate::source::Source;
17
/// Namespace for terminal operations that consume a [`Source`] and drive it to completion.
///
/// `Sink` carries no state; every operation is an associated function, called as
/// e.g. `Sink::collect(source)`.
#[derive(Debug, Clone, Copy)]
pub struct Sink;
19
20impl Sink {
21 pub async fn fold<T, Acc, F>(source: Source<T>, init: Acc, mut f: F) -> Acc
23 where
24 T: Send + 'static,
25 Acc: Send + 'static,
26 F: FnMut(Acc, T) -> Acc + Send + 'static,
27 {
28 source.into_boxed().fold(init, move |acc, x| futures::future::ready(f(acc, x))).await
29 }
30
31 pub async fn fold_async<T, Acc, F, Fut>(source: Source<T>, init: Acc, f: F) -> Acc
33 where
34 T: Send + 'static,
35 Acc: Send + 'static,
36 F: FnMut(Acc, T) -> Fut + Send + 'static,
37 Fut: Future<Output = Acc> + Send + 'static,
38 {
39 source.into_boxed().fold(init, f).await
40 }
41
42 pub async fn collect<T>(source: Source<T>) -> Vec<T>
44 where
45 T: Send + 'static,
46 {
47 source.into_boxed().collect().await
48 }
49
50 pub async fn first<T>(source: Source<T>) -> Option<T>
51 where
52 T: Send + 'static,
53 {
54 source.into_boxed().next().await
55 }
56
57 pub async fn last<T>(source: Source<T>) -> Option<T>
58 where
59 T: Send + 'static,
60 {
61 source.into_boxed().fold(None, |_, x| async move { Some(x) }).await
62 }
63
64 pub async fn sum<T>(source: Source<T>) -> T
65 where
66 T: Send + Default + std::ops::Add<Output = T> + 'static,
67 {
68 let init: T = T::default();
69 Self::fold(source, init, |acc, x| acc + x).await
70 }
71
72 pub async fn count<T>(source: Source<T>) -> u64
73 where
74 T: Send + 'static,
75 {
76 Self::fold(source, 0u64, |acc, _| acc + 1).await
77 }
78
79 pub async fn for_each<T, F>(source: Source<T>, mut f: F)
80 where
81 T: Send + 'static,
82 F: FnMut(T) + Send + 'static,
83 {
84 source
85 .into_boxed()
86 .for_each(move |x| {
87 f(x);
88 futures::future::ready(())
89 })
90 .await
91 }
92
93 pub async fn for_each_async<T, F, Fut>(source: Source<T>, parallelism: usize, f: F)
94 where
95 T: Send + 'static,
96 F: FnMut(T) -> Fut + Send + 'static,
97 Fut: Future<Output = ()> + Send + 'static,
98 {
99 let p = parallelism.max(1);
100 source.into_boxed().for_each_concurrent(p, f).await
101 }
102
103 pub async fn ignore<T: Send + 'static>(source: Source<T>) {
104 source.into_boxed().for_each(|_| futures::future::ready(())).await
105 }
106
107 pub async fn to_sender<T>(source: Source<T>, tx: tokio::sync::mpsc::UnboundedSender<T>)
110 where
111 T: Send + 'static,
112 {
113 let mut stream = source.into_boxed();
114 while let Some(v) = stream.next().await {
115 if tx.send(v).is_err() {
116 break;
117 }
118 }
119 }
120
121 pub fn queue<T>(source: Source<T>) -> SinkQueue<T>
125 where
126 T: Send + 'static,
127 {
128 let buf: Arc<Mutex<SinkQueueState<T>>> = Arc::new(Mutex::new(SinkQueueState::default()));
129 let notify = Arc::new(Notify::new());
130 let buf_t = Arc::clone(&buf);
131 let notify_t = Arc::clone(¬ify);
132 let handle = tokio::spawn(async move {
133 let mut stream = source.into_boxed();
134 while let Some(v) = stream.next().await {
135 buf_t.lock().items.push_back(v);
136 notify_t.notify_one();
137 }
138 buf_t.lock().complete = true;
139 notify_t.notify_waiters();
140 });
141 SinkQueue { buf, notify, _handle: handle }
142 }
143
144 pub async fn pull_with_timeout<T: Send + 'static>(q: &SinkQueue<T>, t: Duration) -> Option<T> {
146 tokio::time::timeout(t, q.pull()).await.ok().flatten()
147 }
148}
149
/// Buffer shared between a [`SinkQueue`] and the background task that fills it.
struct SinkQueueState<T> {
    /// Becomes `true` once the source is exhausted; nothing is pushed afterwards.
    complete: bool,
    /// Items produced but not yet pulled, oldest at the front.
    items: std::collections::VecDeque<T>,
}
154
155impl<T> Default for SinkQueueState<T> {
156 fn default() -> Self {
157 Self { items: std::collections::VecDeque::new(), complete: false }
158 }
159}
160
161pub struct SinkQueue<T> {
162 buf: Arc<Mutex<SinkQueueState<T>>>,
163 notify: Arc<Notify>,
164 _handle: tokio::task::JoinHandle<()>,
165}
166
167impl<T: Send + 'static> SinkQueue<T> {
168 pub async fn pull(&self) -> Option<T> {
171 loop {
172 {
173 let mut guard = self.buf.lock();
174 if let Some(v) = guard.items.pop_front() {
175 return Some(v);
176 }
177 if guard.complete {
178 return None;
179 }
180 }
181 self.notify.notified().await;
182 }
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn first_last_sum_count() {
        assert_eq!(Sink::first(Source::from_iter(vec![1, 2, 3])).await, Some(1));
        assert_eq!(Sink::last(Source::from_iter(vec![1, 2, 3])).await, Some(3));
        assert_eq!(Sink::sum(Source::from_iter(1..=10_i32)).await, 55);
        assert_eq!(Sink::count(Source::from_iter(0..42_u64)).await, 42);
    }

    #[tokio::test]
    async fn empty_source_edge_cases() {
        assert_eq!(Sink::first(Source::from_iter(Vec::<i32>::new())).await, None);
        assert_eq!(Sink::last(Source::from_iter(Vec::<i32>::new())).await, None);
        assert_eq!(Sink::sum(Source::from_iter(Vec::<i32>::new())).await, 0);
        assert_eq!(Sink::count(Source::from_iter(Vec::<i32>::new())).await, 0);
        assert!(Sink::collect(Source::from_iter(Vec::<i32>::new())).await.is_empty());
    }

    #[tokio::test]
    async fn fold_variants_and_collect() {
        let joined = Sink::fold(Source::from_iter(vec![1_i32, 2, 3]), String::new(), |mut acc, x| {
            acc.push_str(&x.to_string());
            acc
        })
        .await;
        assert_eq!(joined, "123");

        let product =
            Sink::fold_async(Source::from_iter(1..=4_i32), 1, |acc, x| async move { acc * x })
                .await;
        assert_eq!(product, 24);

        assert_eq!(Sink::collect(Source::from_iter(vec![3, 1, 2])).await, vec![3, 1, 2]);
    }

    #[tokio::test]
    async fn for_each_visits_items_in_order() {
        let seen = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
        let seen_c = std::sync::Arc::clone(&seen);
        Sink::for_each(Source::from_iter(vec![1, 2, 3]), move |v| {
            seen_c.lock().unwrap().push(v);
        })
        .await;
        assert_eq!(*seen.lock().unwrap(), vec![1, 2, 3]);

        // `ignore` only has to drive the source to completion without hanging.
        Sink::ignore(Source::from_iter(vec![1, 2, 3])).await;
    }

    #[tokio::test]
    async fn for_each_async_runs_all_tasks() {
        let sum = std::sync::Arc::new(std::sync::Mutex::new(0i32));
        let sum_c = sum.clone();
        Sink::for_each_async(Source::from_iter(1..=5), 2, move |v| {
            let sum_c = sum_c.clone();
            async move {
                *sum_c.lock().unwrap() += v;
            }
        })
        .await;
        assert_eq!(*sum.lock().unwrap(), 15);
    }

    #[tokio::test]
    async fn to_sender_forwards_every_item() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        Sink::to_sender(Source::from_iter(vec![1, 2, 3]), tx).await;
        // `tx` was moved into `to_sender` and dropped there, so `recv` ends with `None`.
        let mut got = Vec::new();
        while let Some(v) = rx.recv().await {
            got.push(v);
        }
        assert_eq!(got, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn sink_queue_pulls_until_complete() {
        let q = Sink::queue(Source::from_iter(vec![10, 20, 30]));
        assert_eq!(q.pull().await, Some(10));
        assert_eq!(q.pull().await, Some(20));
        assert_eq!(q.pull().await, Some(30));
        assert_eq!(q.pull().await, None);
    }

    #[tokio::test]
    async fn sink_queue_on_empty_source_reports_completion() {
        let q = Sink::queue(Source::from_iter(Vec::<i32>::new()));
        assert_eq!(q.pull().await, None);
        // Once exhausted, the queue stays exhausted.
        assert_eq!(q.pull().await, None);
    }

    #[tokio::test]
    async fn pull_with_timeout_returns_items_then_none() {
        let q = Sink::queue(Source::from_iter(vec![7]));
        let t = std::time::Duration::from_secs(5);
        assert_eq!(Sink::pull_with_timeout(&q, t).await, Some(7));
        assert_eq!(Sink::pull_with_timeout(&q, t).await, None);
    }
}