1use std::future::Future;
9use std::sync::Arc;
10use std::time::Duration;
11
12use futures::stream::StreamExt;
13use parking_lot::Mutex;
14use tokio::sync::Notify;
15
16use crate::source::Source;
17
/// Namespace for terminal stream operations: each associated function consumes a
/// `Source` and drives it to a result (a fold, a collection, a side effect, or a
/// pull-based [`SinkQueue`]).
#[derive(Debug, Clone, Copy, Default)]
pub struct Sink;
19
20impl Sink {
21 pub async fn fold<T, Acc, F>(source: Source<T>, init: Acc, mut f: F) -> Acc
23 where
24 T: Send + 'static,
25 Acc: Send + 'static,
26 F: FnMut(Acc, T) -> Acc + Send + 'static,
27 {
28 source.into_boxed().fold(init, move |acc, x| futures::future::ready(f(acc, x))).await
29 }
30
31 pub async fn fold_async<T, Acc, F, Fut>(source: Source<T>, init: Acc, f: F) -> Acc
33 where
34 T: Send + 'static,
35 Acc: Send + 'static,
36 F: FnMut(Acc, T) -> Fut + Send + 'static,
37 Fut: Future<Output = Acc> + Send + 'static,
38 {
39 source.into_boxed().fold(init, f).await
40 }
41
42 pub async fn collect<T>(source: Source<T>) -> Vec<T>
44 where
45 T: Send + 'static,
46 {
47 source.into_boxed().collect().await
48 }
49
50 pub async fn first<T>(source: Source<T>) -> Option<T>
52 where
53 T: Send + 'static,
54 {
55 source.into_boxed().next().await
56 }
57
58 pub async fn last<T>(source: Source<T>) -> Option<T>
60 where
61 T: Send + 'static,
62 {
63 source.into_boxed().fold(None, |_, x| async move { Some(x) }).await
64 }
65
66 pub async fn sum<T>(source: Source<T>) -> T
68 where
69 T: Send + Default + std::ops::Add<Output = T> + 'static,
70 {
71 let init: T = T::default();
72 Self::fold(source, init, |acc, x| acc + x).await
73 }
74
75 pub async fn count<T>(source: Source<T>) -> u64
77 where
78 T: Send + 'static,
79 {
80 Self::fold(source, 0u64, |acc, _| acc + 1).await
81 }
82
83 pub async fn for_each<T, F>(source: Source<T>, mut f: F)
84 where
85 T: Send + 'static,
86 F: FnMut(T) + Send + 'static,
87 {
88 source
89 .into_boxed()
90 .for_each(move |x| {
91 f(x);
92 futures::future::ready(())
93 })
94 .await
95 }
96
97 pub async fn for_each_async<T, F, Fut>(source: Source<T>, parallelism: usize, f: F)
99 where
100 T: Send + 'static,
101 F: FnMut(T) -> Fut + Send + 'static,
102 Fut: Future<Output = ()> + Send + 'static,
103 {
104 let p = parallelism.max(1);
105 source.into_boxed().for_each_concurrent(p, f).await
106 }
107
108 pub async fn ignore<T: Send + 'static>(source: Source<T>) {
110 source.into_boxed().for_each(|_| futures::future::ready(())).await
111 }
112
113 pub async fn to_sender<T>(source: Source<T>, tx: tokio::sync::mpsc::UnboundedSender<T>)
116 where
117 T: Send + 'static,
118 {
119 let mut stream = source.into_boxed();
120 while let Some(v) = stream.next().await {
121 if tx.send(v).is_err() {
122 break;
123 }
124 }
125 }
126
127 pub fn queue<T>(source: Source<T>) -> SinkQueue<T>
131 where
132 T: Send + 'static,
133 {
134 let buf: Arc<Mutex<SinkQueueState<T>>> = Arc::new(Mutex::new(SinkQueueState::default()));
135 let notify = Arc::new(Notify::new());
136 let buf_t = Arc::clone(&buf);
137 let notify_t = Arc::clone(¬ify);
138 let handle = tokio::spawn(async move {
139 let mut stream = source.into_boxed();
140 while let Some(v) = stream.next().await {
141 buf_t.lock().items.push_back(v);
142 notify_t.notify_one();
143 }
144 buf_t.lock().complete = true;
145 notify_t.notify_waiters();
146 });
147 SinkQueue { buf, notify, _handle: handle }
148 }
149
150 pub async fn pull_with_timeout<T: Send + 'static>(q: &SinkQueue<T>, t: Duration) -> Option<T> {
152 tokio::time::timeout(t, q.pull()).await.ok().flatten()
153 }
154}
155
/// Buffer shared between a [`SinkQueue`] and the task that fills it.
struct SinkQueueState<T> {
    /// Set by the producer after the source stream has been exhausted.
    complete: bool,
    /// Elements received from the source but not yet pulled, oldest at the front.
    items: std::collections::VecDeque<T>,
}
160
161impl<T> Default for SinkQueueState<T> {
162 fn default() -> Self {
163 Self { items: std::collections::VecDeque::new(), complete: false }
164 }
165}
166
167pub struct SinkQueue<T> {
168 buf: Arc<Mutex<SinkQueueState<T>>>,
169 notify: Arc<Notify>,
170 _handle: tokio::task::JoinHandle<()>,
171}
172
173impl<T: Send + 'static> SinkQueue<T> {
174 pub async fn pull(&self) -> Option<T> {
177 loop {
178 {
179 let mut guard = self.buf.lock();
180 if let Some(v) = guard.items.pop_front() {
181 return Some(v);
182 }
183 if guard.complete {
184 return None;
185 }
186 }
187 self.notify.notified().await;
188 }
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn first_last_sum_count() {
        let head = Sink::first(Source::from_iter(vec![1, 2, 3])).await;
        assert_eq!(head, Some(1));

        let tail = Sink::last(Source::from_iter(vec![1, 2, 3])).await;
        assert_eq!(tail, Some(3));

        let total = Sink::sum(Source::from_iter(1..=10_i32)).await;
        assert_eq!(total, 55);

        let n = Sink::count(Source::from_iter(0..42_u64)).await;
        assert_eq!(n, 42);
    }

    #[tokio::test]
    async fn for_each_async_runs_all_tasks() {
        // Fully-qualified std types: `Mutex` from `super` is parking_lot's.
        let total = std::sync::Arc::new(std::sync::Mutex::new(0i32));
        let shared = std::sync::Arc::clone(&total);
        Sink::for_each_async(Source::from_iter(1..=5), 2, move |v| {
            let slot = std::sync::Arc::clone(&shared);
            async move {
                *slot.lock().unwrap() += v;
            }
        })
        .await;
        assert_eq!(*total.lock().unwrap(), 15);
    }

    #[tokio::test]
    async fn sink_queue_pulls_until_complete() {
        let q = Sink::queue(Source::from_iter(vec![10, 20, 30]));
        for expected in [10, 20, 30] {
            assert_eq!(q.pull().await, Some(expected));
        }
        assert_eq!(q.pull().await, None);
    }
}