Skip to main content

olympipe_rs/
pipeline.rs

1use crossbeam_channel::{bounded, unbounded, Receiver, RecvTimeoutError, Sender};
2use std::sync::{Arc, Mutex};
3use std::thread::{self, JoinHandle};
4use std::time::Duration;
5
6const DEFAULT_BUFFER: usize = 4;
7
/// A parallel data pipeline backed by bounded crossbeam channels and worker threads.
///
/// Each stage method consumes the pipeline and returns a new one whose type
/// reflects the output of that stage.  All worker threads are tracked in a
/// shared `Arc<Mutex<Vec<JoinHandle>>>` so that calling one of the terminal
/// methods (`collect`, `for_each`, `wait_for_completion`) joins every thread
/// spawned by the whole graph.
pub struct Pipeline<T: Send + 'static> {
    /// Receiving end of the channel that carries the output of the most
    /// recently added stage.
    pub(crate) receiver: Receiver<T>,
    /// Registry of worker-thread join handles.  Stages clone this `Arc`, so
    /// the pipelines derived from one another push into the same list.
    pub(crate) handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
    /// Capacity used for the bounded channels created by later stages; it is
    /// copied unchanged into every pipeline derived from this one.
    pub(crate) buffer_size: usize,
}
20
21// ─── constructor ──────────────────────────────────────────────────────────────
22
23impl<T: Send + 'static> Pipeline<T> {
24    /// Create a pipeline from any iterator.  A feeder thread is spawned
25    /// immediately to push items into the first bounded channel.
26    pub fn new(iter: impl IntoIterator<Item = T> + Send + 'static) -> Self {
27        let (tx, rx) = bounded(DEFAULT_BUFFER);
28        let handles: Arc<Mutex<Vec<JoinHandle<()>>>> = Arc::new(Mutex::new(Vec::new()));
29        let h = thread::spawn(move || {
30            for item in iter {
31                if tx.send(item).is_err() {
32                    break;
33                }
34            }
35        });
36        handles.lock().unwrap().push(h);
37        Pipeline {
38            receiver: rx,
39            handles,
40            buffer_size: DEFAULT_BUFFER,
41        }
42    }
43
44    /// Override the inter-stage channel capacity (default: 4, matching olympipe).
45    pub fn with_buffer(mut self, size: usize) -> Self {
46        self.buffer_size = size;
47        self
48    }
49}
50
51// ─── internal helpers ─────────────────────────────────────────────────────────
52
53impl<T: Send + 'static> Pipeline<T> {
54    /// Spawn `count` worker threads that each read from `self.receiver` and
55    /// write to a fresh bounded channel.  `worker` must be `Clone` so it can
56    /// be duplicated per thread.  Use `spawn_single` for stages that run
57    /// exactly one thread and whose closure cannot be cloned.
58    fn spawn_stage<U, W>(self, count: usize, worker: W) -> Pipeline<U>
59    where
60        U: Send + 'static,
61        W: Fn(Receiver<T>, Sender<U>) + Send + Clone + 'static,
62    {
63        let (tx, rx) = bounded::<U>(self.buffer_size);
64        let handles = Arc::clone(&self.handles);
65
66        for _ in 0..count {
67            let w = worker.clone();
68            let in_rx = self.receiver.clone();
69            let out_tx = tx.clone();
70            let h = thread::spawn(move || {
71                let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
72                    w(in_rx, out_tx);
73                }));
74                if let Err(payload) = result {
75                    let msg = payload
76                        .downcast_ref::<&str>()
77                        .copied()
78                        .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
79                        .unwrap_or("(unknown panic)");
80                    eprintln!("[olympipe-rs] worker panic: {msg}");
81                }
82            });
83            handles.lock().unwrap().push(h);
84        }
85
86        Pipeline {
87            receiver: rx,
88            handles,
89            buffer_size: self.buffer_size,
90        }
91    }
92
93    /// Like `spawn_stage` but for a single worker thread whose closure does
94    /// not need to implement `Clone`.
95    fn spawn_single<U, W>(self, worker: W) -> Pipeline<U>
96    where
97        U: Send + 'static,
98        W: FnOnce(Receiver<T>, Sender<U>) + Send + 'static,
99    {
100        let (tx, rx) = bounded::<U>(self.buffer_size);
101        let handles = Arc::clone(&self.handles);
102        let in_rx = self.receiver;
103
104        let h = thread::spawn(move || {
105            let result =
106                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| worker(in_rx, tx)));
107            if let Err(payload) = result {
108                let msg = payload
109                    .downcast_ref::<&str>()
110                    .copied()
111                    .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
112                    .unwrap_or("(unknown panic)");
113                eprintln!("[olympipe-rs] worker panic: {msg}");
114            }
115        });
116        handles.lock().unwrap().push(h);
117
118        Pipeline {
119            receiver: rx,
120            handles,
121            buffer_size: self.buffer_size,
122        }
123    }
124}
125
126// ─── stage methods ────────────────────────────────────────────────────────────
127
128impl<T: Send + 'static> Pipeline<T> {
129    /// Map each item through `f`.  `count` worker threads run concurrently,
130    /// sharing the input channel (MPMC) and writing to a shared output channel.
131    pub fn task<U, F>(self, f: F, count: usize) -> Pipeline<U>
132    where
133        U: Send + 'static,
134        F: Fn(T) -> U + Send + Clone + 'static,
135    {
136        let count = count.max(1);
137        self.spawn_stage(count, move |rx, tx| {
138            for item in rx {
139                let out = f(item);
140                if tx.send(out).is_err() {
141                    break;
142                }
143            }
144        })
145    }
146
147    /// Map each item through `f` which may fail.  On error, `on_error` is
148    /// called with a clone of the original item and the error; if it returns
149    /// `Some(v)`, `v` is forwarded downstream; `None` skips the item.
150    pub fn task_or<U, E, F, H>(self, f: F, on_error: H) -> Pipeline<U>
151    where
152        T: Clone,
153        U: Send + 'static,
154        E: 'static,
155        F: Fn(T) -> Result<U, E> + Send + 'static,
156        H: Fn(T, E) -> Option<U> + Send + 'static,
157    {
158        self.spawn_single(move |rx, tx| {
159            for item in rx {
160                let cloned = item.clone();
161                match f(item) {
162                    Ok(out) => {
163                        if tx.send(out).is_err() {
164                            break;
165                        }
166                    }
167                    Err(e) => {
168                        if on_error(cloned, e)
169                            .map(|fallback| tx.send(fallback).is_err())
170                            .unwrap_or(false)
171                        {
172                            break;
173                        }
174                    }
175                }
176            }
177        })
178    }
179
180    /// Keep only items for which `f` returns `true`.
181    pub fn filter<F>(self, f: F) -> Pipeline<T>
182    where
183        F: Fn(&T) -> bool + Send + 'static,
184    {
185        self.spawn_single(move |rx, tx| {
186            for item in rx {
187                if f(&item) && tx.send(item).is_err() {
188                    break;
189                }
190            }
191        })
192    }
193
194    /// Group items into `Vec<T>` of at most `size` elements.
195    /// The last (potentially incomplete) batch is always emitted.
196    pub fn batch(self, size: usize) -> Pipeline<Vec<T>> {
197        assert!(size >= 1, "batch size must be >= 1");
198        self.spawn_single(move |rx, tx| {
199            let mut buf: Vec<T> = Vec::with_capacity(size);
200            for item in rx {
201                buf.push(item);
202                if buf.len() >= size {
203                    let full = std::mem::replace(&mut buf, Vec::with_capacity(size));
204                    if tx.send(full).is_err() {
205                        return;
206                    }
207                }
208            }
209            if !buf.is_empty() {
210                let _ = tx.send(buf);
211            }
212        })
213    }
214
215    /// Collect items arriving within `window` of each other into a `Vec<T>`.
216    /// A new batch begins whenever the inter-item gap exceeds `window`, or when
217    /// the upstream channel disconnects.
218    pub fn temporal_batch(self, window: Duration) -> Pipeline<Vec<T>> {
219        self.spawn_single(move |rx, tx| {
220            loop {
221                // Wait for the first item of a new batch (no timeout — block).
222                let first = match rx.recv() {
223                    Ok(item) => item,
224                    Err(_) => break, // upstream done, no pending batch
225                };
226                let mut batch = vec![first];
227
228                // Drain additional items within `window`.
229                loop {
230                    match rx.recv_timeout(window) {
231                        Ok(item) => batch.push(item),
232                        Err(RecvTimeoutError::Timeout) => break,
233                        Err(RecvTimeoutError::Disconnected) => {
234                            let _ = tx.send(batch);
235                            return;
236                        }
237                    }
238                }
239
240                if tx.send(batch).is_err() {
241                    return;
242                }
243            }
244        })
245    }
246
247    /// Apply `f` to each item and forward every element yielded by the
248    /// returned iterator (flatMap).
249    pub fn explode<U, I, F>(self, f: F) -> Pipeline<U>
250    where
251        U: Send + 'static,
252        I: IntoIterator<Item = U>,
253        F: Fn(T) -> I + Send + 'static,
254    {
255        self.spawn_single(move |rx, tx| {
256            for item in rx {
257                for out in f(item) {
258                    if tx.send(out).is_err() {
259                        return;
260                    }
261                }
262            }
263        })
264    }
265
266    /// Route each item to one or both of two output pipelines according to `f`.
267    ///
268    /// Both returned pipelines share the same handle registry, so a single
269    /// terminal call on either one will join all threads.
270    ///
271    /// The two output channels are **unbounded** so that the router thread can
272    /// fully drain its input even when the two branches are consumed
273    /// sequentially (calling `collect()` on one branch at a time). If you add
274    /// further bounded stages after `split`, back-pressure is restored there.
275    pub fn split<A, B, F>(self, f: F) -> (Pipeline<A>, Pipeline<B>)
276    where
277        A: Send + 'static,
278        B: Send + 'static,
279        F: Fn(T) -> (Option<A>, Option<B>) + Send + 'static,
280    {
281        let buf = self.buffer_size;
282        let (tx_a, rx_a) = unbounded::<A>();
283        let (tx_b, rx_b) = unbounded::<B>();
284        let handles = Arc::clone(&self.handles);
285
286        let h = thread::spawn(move || {
287            for item in self.receiver {
288                let (a, b) = f(item);
289                if let Some(v) = a {
290                    // If the A-branch is closed, keep feeding B.
291                    let _ = tx_a.send(v);
292                }
293                if let Some(v) = b {
294                    let _ = tx_b.send(v);
295                }
296            }
297        });
298        handles.lock().unwrap().push(h);
299
300        (
301            Pipeline {
302                receiver: rx_a,
303                handles: Arc::clone(&handles),
304                buffer_size: buf,
305            },
306            Pipeline {
307                receiver: rx_b,
308                handles,
309                buffer_size: buf,
310            },
311        )
312    }
313
314    /// Merge `self` and all pipelines in `others` into a single output stream.
315    /// One forwarder thread is spawned per source; all handles are merged into
316    /// the returned pipeline's handle registry.
317    pub fn gather(self, others: Vec<Pipeline<T>>) -> Pipeline<T> {
318        let buf = self.buffer_size;
319        let (tx, rx) = bounded::<T>(buf);
320        let handles = Arc::clone(&self.handles);
321
322        // Forward self.
323        let tx0 = tx.clone();
324        let self_rx = self.receiver;
325        let h = thread::spawn(move || {
326            for item in self_rx {
327                if tx0.send(item).is_err() {
328                    break;
329                }
330            }
331        });
332        handles.lock().unwrap().push(h);
333
334        // Forward each other pipeline, merging their handles first.
335        for other in others {
336            let mut other_handles = other.handles.lock().unwrap();
337            let drained: Vec<_> = other_handles.drain(..).collect();
338            drop(other_handles);
339            handles.lock().unwrap().extend(drained);
340
341            let tx_n = tx.clone();
342            let other_rx = other.receiver;
343            let h = thread::spawn(move || {
344                for item in other_rx {
345                    if tx_n.send(item).is_err() {
346                        break;
347                    }
348                }
349            });
350            handles.lock().unwrap().push(h);
351        }
352
353        Pipeline {
354            receiver: rx,
355            handles,
356            buffer_size: buf,
357        }
358    }
359
360    /// Fold all items into a single accumulated value and emit it as the sole
361    /// output item.
362    pub fn reduce<U, F>(self, init: U, f: F) -> Pipeline<U>
363    where
364        U: Send + 'static,
365        F: Fn(U, T) -> U + Send + 'static,
366    {
367        self.spawn_single(move |rx, tx| {
368            let mut acc = init;
369            for item in rx {
370                acc = f(acc, item);
371            }
372            let _ = tx.send(acc);
373        })
374    }
375
376    /// Pass at most `n` items downstream, then stop.
377    pub fn limit(self, n: usize) -> Pipeline<T> {
378        self.spawn_single(move |rx, tx| {
379            for (i, item) in rx.iter().enumerate() {
380                if tx.send(item).is_err() {
381                    break;
382                }
383                if i + 1 >= n {
384                    break;
385                }
386            }
387        })
388    }
389
390    /// Error (drop the sender, propagating disconnection downstream) if no
391    /// item arrives within `duration` between two consecutive items.
392    pub fn timeout(self, duration: Duration) -> Pipeline<T> {
393        self.spawn_single(move |rx, tx| {
394            // First item: block indefinitely.
395            match rx.recv() {
396                Err(_) => return,
397                Ok(first) => {
398                    if tx.send(first).is_err() {
399                        return;
400                    }
401                }
402            }
403            loop {
404                match rx.recv_timeout(duration) {
405                    Ok(item) => {
406                        if tx.send(item).is_err() {
407                            break;
408                        }
409                    }
410                    Err(RecvTimeoutError::Timeout) => {
411                        eprintln!(
412                            "[olympipe-rs] timeout: no item received within {:?}",
413                            duration
414                        );
415                        break;
416                    }
417                    Err(RecvTimeoutError::Disconnected) => break,
418                }
419            }
420        })
421    }
422
423    /// Print each item to stdout and forward it unchanged.
424    pub fn debug(self) -> Pipeline<T>
425    where
426        T: std::fmt::Debug,
427    {
428        self.spawn_single(move |rx, tx| {
429            for item in rx {
430                println!("[olympipe-rs] {:?}", item);
431                if tx.send(item).is_err() {
432                    break;
433                }
434            }
435        })
436    }
437}
438
439// ─── terminal methods ─────────────────────────────────────────────────────────
440
441impl<T: Send + 'static> Pipeline<T> {
442    /// Collect all output items into a `Vec`, then join every worker thread.
443    pub fn collect(self) -> Vec<T> {
444        let items: Vec<T> = self.receiver.into_iter().collect();
445        let mut handles = self.handles.lock().unwrap();
446        for h in handles.drain(..) {
447            let _ = h.join();
448        }
449        items
450    }
451
452    /// Apply `f` to each output item, then join every worker thread.
453    pub fn for_each<F>(self, mut f: F)
454    where
455        F: FnMut(T),
456    {
457        for item in self.receiver {
458            f(item);
459        }
460        let mut handles = self.handles.lock().unwrap();
461        for h in handles.drain(..) {
462            let _ = h.join();
463        }
464    }
465
466    /// Drain and discard all output, then join every worker thread.
467    pub fn wait_for_completion(self) {
468        for _ in self.receiver {}
469        let mut handles = self.handles.lock().unwrap();
470        for h in handles.drain(..) {
471            let _ = h.join();
472        }
473    }
474}