1use crossbeam_channel::{bounded, unbounded, Receiver, RecvTimeoutError, Sender};
2use std::sync::{Arc, Mutex};
3use std::thread::{self, JoinHandle};
4use std::time::Duration;
5
/// Capacity of the bounded channel created by `Pipeline::new`; also the initial
/// `buffer_size` of a freshly created pipeline.
const DEFAULT_BUFFER: usize = 4;
7
/// A chain of processing stages, each running on its own thread(s) and connected
/// to the next one through a channel.
///
/// A `Pipeline<T>` value represents the *output end* of the chain built so far:
/// items of type `T` can be read from `receiver`.
pub struct Pipeline<T: Send + 'static> {
    /// Receiving end of the channel that carries the output of the latest stage.
    pub(crate) receiver: Receiver<T>,
    /// Join handles of the threads spawned so far. The list sits behind an
    /// `Arc<Mutex<..>>` because pipelines derived from this one share it.
    pub(crate) handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
    /// Capacity used when later stages create their bounded output channel.
    pub(crate) buffer_size: usize,
}
20
21impl<T: Send + 'static> Pipeline<T> {
24 pub fn new(iter: impl IntoIterator<Item = T> + Send + 'static) -> Self {
27 let (tx, rx) = bounded(DEFAULT_BUFFER);
28 let handles: Arc<Mutex<Vec<JoinHandle<()>>>> = Arc::new(Mutex::new(Vec::new()));
29 let h = thread::spawn(move || {
30 for item in iter {
31 if tx.send(item).is_err() {
32 break;
33 }
34 }
35 });
36 handles.lock().unwrap().push(h);
37 Pipeline {
38 receiver: rx,
39 handles,
40 buffer_size: DEFAULT_BUFFER,
41 }
42 }
43
44 pub fn with_buffer(mut self, size: usize) -> Self {
46 self.buffer_size = size;
47 self
48 }
49}
50
51impl<T: Send + 'static> Pipeline<T> {
54 fn spawn_stage<U, W>(self, count: usize, worker: W) -> Pipeline<U>
59 where
60 U: Send + 'static,
61 W: Fn(Receiver<T>, Sender<U>) + Send + Clone + 'static,
62 {
63 let (tx, rx) = bounded::<U>(self.buffer_size);
64 let handles = Arc::clone(&self.handles);
65
66 for _ in 0..count {
67 let w = worker.clone();
68 let in_rx = self.receiver.clone();
69 let out_tx = tx.clone();
70 let h = thread::spawn(move || {
71 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
72 w(in_rx, out_tx);
73 }));
74 if let Err(payload) = result {
75 let msg = payload
76 .downcast_ref::<&str>()
77 .copied()
78 .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
79 .unwrap_or("(unknown panic)");
80 eprintln!("[olympipe-rs] worker panic: {msg}");
81 }
82 });
83 handles.lock().unwrap().push(h);
84 }
85
86 Pipeline {
87 receiver: rx,
88 handles,
89 buffer_size: self.buffer_size,
90 }
91 }
92
93 fn spawn_single<U, W>(self, worker: W) -> Pipeline<U>
96 where
97 U: Send + 'static,
98 W: FnOnce(Receiver<T>, Sender<U>) + Send + 'static,
99 {
100 let (tx, rx) = bounded::<U>(self.buffer_size);
101 let handles = Arc::clone(&self.handles);
102 let in_rx = self.receiver;
103
104 let h = thread::spawn(move || {
105 let result =
106 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| worker(in_rx, tx)));
107 if let Err(payload) = result {
108 let msg = payload
109 .downcast_ref::<&str>()
110 .copied()
111 .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
112 .unwrap_or("(unknown panic)");
113 eprintln!("[olympipe-rs] worker panic: {msg}");
114 }
115 });
116 handles.lock().unwrap().push(h);
117
118 Pipeline {
119 receiver: rx,
120 handles,
121 buffer_size: self.buffer_size,
122 }
123 }
124}
125
126impl<T: Send + 'static> Pipeline<T> {
129 pub fn task<U, F>(self, f: F, count: usize) -> Pipeline<U>
132 where
133 U: Send + 'static,
134 F: Fn(T) -> U + Send + Clone + 'static,
135 {
136 let count = count.max(1);
137 self.spawn_stage(count, move |rx, tx| {
138 for item in rx {
139 let out = f(item);
140 if tx.send(out).is_err() {
141 break;
142 }
143 }
144 })
145 }
146
147 pub fn task_or<U, E, F, H>(self, f: F, on_error: H) -> Pipeline<U>
151 where
152 T: Clone,
153 U: Send + 'static,
154 E: 'static,
155 F: Fn(T) -> Result<U, E> + Send + 'static,
156 H: Fn(T, E) -> Option<U> + Send + 'static,
157 {
158 self.spawn_single(move |rx, tx| {
159 for item in rx {
160 let cloned = item.clone();
161 match f(item) {
162 Ok(out) => {
163 if tx.send(out).is_err() {
164 break;
165 }
166 }
167 Err(e) => {
168 if on_error(cloned, e)
169 .map(|fallback| tx.send(fallback).is_err())
170 .unwrap_or(false)
171 {
172 break;
173 }
174 }
175 }
176 }
177 })
178 }
179
180 pub fn filter<F>(self, f: F) -> Pipeline<T>
182 where
183 F: Fn(&T) -> bool + Send + 'static,
184 {
185 self.spawn_single(move |rx, tx| {
186 for item in rx {
187 if f(&item) && tx.send(item).is_err() {
188 break;
189 }
190 }
191 })
192 }
193
194 pub fn batch(self, size: usize) -> Pipeline<Vec<T>> {
197 assert!(size >= 1, "batch size must be >= 1");
198 self.spawn_single(move |rx, tx| {
199 let mut buf: Vec<T> = Vec::with_capacity(size);
200 for item in rx {
201 buf.push(item);
202 if buf.len() >= size {
203 let full = std::mem::replace(&mut buf, Vec::with_capacity(size));
204 if tx.send(full).is_err() {
205 return;
206 }
207 }
208 }
209 if !buf.is_empty() {
210 let _ = tx.send(buf);
211 }
212 })
213 }
214
215 pub fn temporal_batch(self, window: Duration) -> Pipeline<Vec<T>> {
219 self.spawn_single(move |rx, tx| {
220 loop {
221 let first = match rx.recv() {
223 Ok(item) => item,
224 Err(_) => break, };
226 let mut batch = vec![first];
227
228 loop {
230 match rx.recv_timeout(window) {
231 Ok(item) => batch.push(item),
232 Err(RecvTimeoutError::Timeout) => break,
233 Err(RecvTimeoutError::Disconnected) => {
234 let _ = tx.send(batch);
235 return;
236 }
237 }
238 }
239
240 if tx.send(batch).is_err() {
241 return;
242 }
243 }
244 })
245 }
246
247 pub fn explode<U, I, F>(self, f: F) -> Pipeline<U>
250 where
251 U: Send + 'static,
252 I: IntoIterator<Item = U>,
253 F: Fn(T) -> I + Send + 'static,
254 {
255 self.spawn_single(move |rx, tx| {
256 for item in rx {
257 for out in f(item) {
258 if tx.send(out).is_err() {
259 return;
260 }
261 }
262 }
263 })
264 }
265
266 pub fn split<A, B, F>(self, f: F) -> (Pipeline<A>, Pipeline<B>)
276 where
277 A: Send + 'static,
278 B: Send + 'static,
279 F: Fn(T) -> (Option<A>, Option<B>) + Send + 'static,
280 {
281 let buf = self.buffer_size;
282 let (tx_a, rx_a) = unbounded::<A>();
283 let (tx_b, rx_b) = unbounded::<B>();
284 let handles = Arc::clone(&self.handles);
285
286 let h = thread::spawn(move || {
287 for item in self.receiver {
288 let (a, b) = f(item);
289 if let Some(v) = a {
290 let _ = tx_a.send(v);
292 }
293 if let Some(v) = b {
294 let _ = tx_b.send(v);
295 }
296 }
297 });
298 handles.lock().unwrap().push(h);
299
300 (
301 Pipeline {
302 receiver: rx_a,
303 handles: Arc::clone(&handles),
304 buffer_size: buf,
305 },
306 Pipeline {
307 receiver: rx_b,
308 handles,
309 buffer_size: buf,
310 },
311 )
312 }
313
314 pub fn gather(self, others: Vec<Pipeline<T>>) -> Pipeline<T> {
318 let buf = self.buffer_size;
319 let (tx, rx) = bounded::<T>(buf);
320 let handles = Arc::clone(&self.handles);
321
322 let tx0 = tx.clone();
324 let self_rx = self.receiver;
325 let h = thread::spawn(move || {
326 for item in self_rx {
327 if tx0.send(item).is_err() {
328 break;
329 }
330 }
331 });
332 handles.lock().unwrap().push(h);
333
334 for other in others {
336 let mut other_handles = other.handles.lock().unwrap();
337 let drained: Vec<_> = other_handles.drain(..).collect();
338 drop(other_handles);
339 handles.lock().unwrap().extend(drained);
340
341 let tx_n = tx.clone();
342 let other_rx = other.receiver;
343 let h = thread::spawn(move || {
344 for item in other_rx {
345 if tx_n.send(item).is_err() {
346 break;
347 }
348 }
349 });
350 handles.lock().unwrap().push(h);
351 }
352
353 Pipeline {
354 receiver: rx,
355 handles,
356 buffer_size: buf,
357 }
358 }
359
360 pub fn reduce<U, F>(self, init: U, f: F) -> Pipeline<U>
363 where
364 U: Send + 'static,
365 F: Fn(U, T) -> U + Send + 'static,
366 {
367 self.spawn_single(move |rx, tx| {
368 let mut acc = init;
369 for item in rx {
370 acc = f(acc, item);
371 }
372 let _ = tx.send(acc);
373 })
374 }
375
376 pub fn limit(self, n: usize) -> Pipeline<T> {
378 self.spawn_single(move |rx, tx| {
379 for (i, item) in rx.iter().enumerate() {
380 if tx.send(item).is_err() {
381 break;
382 }
383 if i + 1 >= n {
384 break;
385 }
386 }
387 })
388 }
389
390 pub fn timeout(self, duration: Duration) -> Pipeline<T> {
393 self.spawn_single(move |rx, tx| {
394 match rx.recv() {
396 Err(_) => return,
397 Ok(first) => {
398 if tx.send(first).is_err() {
399 return;
400 }
401 }
402 }
403 loop {
404 match rx.recv_timeout(duration) {
405 Ok(item) => {
406 if tx.send(item).is_err() {
407 break;
408 }
409 }
410 Err(RecvTimeoutError::Timeout) => {
411 eprintln!(
412 "[olympipe-rs] timeout: no item received within {:?}",
413 duration
414 );
415 break;
416 }
417 Err(RecvTimeoutError::Disconnected) => break,
418 }
419 }
420 })
421 }
422
423 pub fn debug(self) -> Pipeline<T>
425 where
426 T: std::fmt::Debug,
427 {
428 self.spawn_single(move |rx, tx| {
429 for item in rx {
430 println!("[olympipe-rs] {:?}", item);
431 if tx.send(item).is_err() {
432 break;
433 }
434 }
435 })
436 }
437}
438
439impl<T: Send + 'static> Pipeline<T> {
442 pub fn collect(self) -> Vec<T> {
444 let items: Vec<T> = self.receiver.into_iter().collect();
445 let mut handles = self.handles.lock().unwrap();
446 for h in handles.drain(..) {
447 let _ = h.join();
448 }
449 items
450 }
451
452 pub fn for_each<F>(self, mut f: F)
454 where
455 F: FnMut(T),
456 {
457 for item in self.receiver {
458 f(item);
459 }
460 let mut handles = self.handles.lock().unwrap();
461 for h in handles.drain(..) {
462 let _ = h.join();
463 }
464 }
465
466 pub fn wait_for_completion(self) {
468 for _ in self.receiver {}
469 let mut handles = self.handles.lock().unwrap();
470 for h in handles.drain(..) {
471 let _ = h.join();
472 }
473 }
474}