async_pipes/pipeline/mod.rs

use std::any::{type_name, Any};
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::future::Future;
use std::num::NonZeroUsize;
use std::panic;
use std::pin::Pin;
use std::sync::Arc;

use tokio::select;
use tokio::sync::mpsc::Sender;
use tokio::sync::Mutex;
use tokio::task::{yield_now, JoinError, JoinSet};

pub use builder::*;
use io::{PipeReader, PipeWriter};
use sync::Synchronizer;

mod builder;
mod io;
mod sync;
mod workers;

const DEFAULT_MAX_TASK_COUNT: usize = 100;
const DEFAULT_READER_BUFFER_SIZE: usize = 30;

/// Options that can be passed to methods in the [PipelineBuilder] when defining stages.
///
/// This implements [Default], which makes it easier to specify options when defining stages.
/// By default, each worker is allowed a maximum of 100 concurrent tasks and the buffer of each
/// pipe holds up to 30 items.
///
/// # Examples
///
/// ```
/// use async_pipes::{Pipeline, WorkerOptions};
///
/// #[tokio::main]
/// async fn main() {
///     let pipeline = Pipeline::builder()
///         .with_inputs("Pipe", vec![()])
///         .with_consumer("Pipe", WorkerOptions::default(), |_: ()| async move {
///             /* ... */
///         })
///         .build();
///
///     assert!(pipeline.is_ok());
/// }
/// ```
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct WorkerOptions {
    /// The maximum number of items allowed per pipe before stages have to wait to write
    /// more data to the pipe.
    pub pipe_buffer_size: usize,

    /// The maximum number of tasks that a worker can run concurrently. Once this number
    /// is reached in a worker, the worker polls for task completions before spawning more.
    pub max_task_count: usize,
}

impl Default for WorkerOptions {
    fn default() -> Self {
        Self::default_multi_task()
    }
}

impl WorkerOptions {
    /// Like the [Default] implementation, but specifies `1` for [WorkerOptions::max_task_count].
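    ///
    /// # Examples
    ///
    /// A minimal sketch of the resulting options; every field other than
    /// [WorkerOptions::max_task_count] matches the [Default] values:
    ///
    /// ```
    /// use async_pipes::WorkerOptions;
    ///
    /// let opts = WorkerOptions::default_single_task();
    /// assert_eq!(opts.max_task_count, 1);
    /// assert_eq!(opts.pipe_buffer_size, WorkerOptions::default().pipe_buffer_size);
    /// ```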
    pub fn default_single_task() -> Self {
        Self {
            max_task_count: 1,
            ..Default::default()
        }
    }

    /// Specifies `100` for [WorkerOptions::max_task_count] and `30` for
    /// [WorkerOptions::pipe_buffer_size].
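    ///
    /// # Examples
    ///
    /// Since the [Default] implementation delegates to this, a quick sketch of the equivalence:
    ///
    /// ```
    /// use async_pipes::WorkerOptions;
    ///
    /// assert_eq!(WorkerOptions::default_multi_task(), WorkerOptions::default());
    /// ```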
    pub fn default_multi_task() -> Self {
        Self {
            max_task_count: DEFAULT_MAX_TASK_COUNT,
            pipe_buffer_size: DEFAULT_READER_BUFFER_SIZE,
        }
    }
}

/// Validated counterpart of [WorkerOptions] in which the user-provided sizes have been checked
/// to be non-zero.
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
struct ValidWorkerOptions {
    unbounded_buffer: bool,
    reader_buffer_size: NonZeroUsize,
    max_task_count: NonZeroUsize,
}

impl TryFrom<WorkerOptions> for ValidWorkerOptions {
    type Error = String;

    fn try_from(value: WorkerOptions) -> Result<Self, Self::Error> {
        Ok(Self {
            unbounded_buffer: false,
            reader_buffer_size: NonZeroUsize::new(value.pipe_buffer_size)
                .ok_or("reader buffer size must not be zero")?,
            max_task_count: NonZeroUsize::new(value.max_task_count)
                .ok_or("max task count must not be zero")?,
        })
    }
}

/// A [Box] that can hold any value that is [Send].
///
/// Values sent through pipes are trait objects of this type.
///
/// This type is publicly exposed as it's needed when building a pipeline stage with multiple
/// outputs. Since each output could have a different type, it's more feasible to define the
/// outputs using dynamic dispatch rather than static dispatch.
///
/// # Examples
///
/// Here's an example of a closure representing the task function given to the pipeline builder
/// when creating a "branching" stage. Three outputs are returned, each of a different type.
/// ```
/// use async_pipes::branch;
///
/// #[tokio::main]
/// async fn main() {
///     let task = |value: String| async move {
///         let length: usize = value.len();
///         let excited: String = format!("{}!", value);
///         let odd_length: bool = length % 2 == 1;
///
///         Some(branch![length, excited, odd_length])
///     };
///
///     // E.g.:
///     // ...
///     // .with_branching_stage("pipe_in", vec!["pipe_len", "pipe_excited", "pipe_odd"], <task>)
///     // ...
///
///     let mut results = task("hello".to_string()).await.unwrap();
///
///     let length = results.remove(0).unwrap().downcast::<usize>().unwrap();
///     let excited = results.remove(0).unwrap().downcast::<String>().unwrap();
///     let odd_length = results.remove(0).unwrap().downcast::<bool>().unwrap();
///
///     assert_eq!(*length, 5usize);
///     assert_eq!(*excited, "hello!".to_string());
///     assert!(*odd_length);
/// }
/// ```
pub type BoxedAnySend = Box<dyn Any + Send + 'static>;

/// A producer closure: takes no input and yields an optional set of outputs.
type ProducerFn = Box<dyn FnMut() -> TaskFuture + Send + 'static>;
/// A task closure for regular stages: takes a single input and yields an optional set of outputs.
type TaskFn = Box<dyn Fn(BoxedAnySend) -> TaskFuture + Send + Sync + 'static>;
/// Casts a single boxed value into a list of boxed values for iterator-based stages.
type IterCastFn = Box<dyn Fn(BoxedAnySend) -> Vec<BoxedAnySend> + Send + Sync + 'static>;
/// The future returned by producer and task functions. A `None` output signals completion, and
/// each inner [Option] decides whether anything is written to the corresponding output pipe.
type TaskFuture = Pin<Box<dyn Future<Output = Option<Vec<Option<BoxedAnySend>>>> + Send + 'static>>;

/// The definition of a single stage in the pipeline.
enum Stage {
    /// A stage that only writes data into the pipeline.
    Producer {
        function: ProducerFn,
        pipes: ProducerPipeNames,
    },

    /// A stage that reads from one pipe and optionally writes to others.
    Regular {
        function: TaskFn,
        pipes: TaskPipeNames,
        options: WorkerOptions,
    },

    /// A stage that expands a single input into multiple outputs (e.g. flattening).
    Iterator {
        stage_type: IterStageType,
        caster: IterCastFn,
        pipes: TaskPipeNames,
        options: WorkerOptions,
    },
}

#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
struct ProducerPipeNames {
    writers: Vec<String>,
}

#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
struct TaskPipeNames {
    reader: String,
    writers: Vec<String>,
}

#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
enum IterStageType {
    Flatten,
}

#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
struct PipeConfig {
    name: String,
    options: ValidWorkerOptions,
}

#[derive(Debug)]
struct Pipe<T> {
    /// An [Option] is used here so the reader can be "taken" when a stage is wired up to it,
    /// allowing only one reader per pipe.
    reader: Option<PipeReader<T>>,
    writer: PipeWriter<T>,
}

/// Signals sent to stage workers.
///
/// Useful for interrupting a worker's normal flow of execution from the outside.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
enum StageWorkerSignal {
    /// Tells a stage worker to finish immediately without waiting for its remaining tasks to end.
    Terminate,
}

impl Display for StageWorkerSignal {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let signal = match self {
            Self::Terminate => "Terminate",
        };
        write!(f, "{signal}")
    }
}

/// A pipeline provides the infrastructure for managing a set of workers that run user-defined
/// "tasks" on data going through the pipes.
///
/// # Examples
///
/// Creating a single producer and a single consumer.
/// ```
/// use std::sync::Arc;
/// use std::sync::atomic::AtomicUsize;
/// use std::sync::atomic::Ordering::{Acquire, SeqCst};
/// use tokio::sync::Mutex;
/// use async_pipes::{Pipeline, WorkerOptions};
///
/// #[tokio::main]
/// async fn main() {
///     let count = Arc::new(Mutex::new(0usize));
///
///     let sum = Arc::new(AtomicUsize::new(0));
///     let task_sum = sum.clone();
///
///     Pipeline::builder()
///         // Produce values 1 through 10
///         .with_producer("data", move || {
///             let count = count.clone();
///             async move {
///                 let mut count = count.lock().await;
///                 if *count < 10 {
///                     *count += 1;
///                     Some(*count)
///                 } else {
///                     None
///                 }
///             }
///         })
///         .with_consumer("data", WorkerOptions::default_single_task(), move |value: usize| {
///             let sum = task_sum.clone();
///             async move {
///                 sum.fetch_add(value, SeqCst);
///             }
///         })
///         .build()
///         .expect("failed to build pipeline")
///         .wait()
///         .await;
///
///     assert_eq!(sum.load(Acquire), 55);
/// }
/// ```
///
/// Creating a branching producer and a consumer for each branch.
/// ```
/// use std::sync::Arc;
/// use std::sync::atomic::{AtomicUsize, Ordering};
/// use std::sync::atomic::Ordering::Acquire;
/// use tokio::sync::Mutex;
/// use async_pipes::{branch, NoOutput, Pipeline, WorkerOptions};
///
/// #[tokio::main]
/// async fn main() {
///     let count = Arc::new(Mutex::new(0usize));
///
///     let odds_sum = Arc::new(AtomicUsize::new(0));
///     let task_odds_sum = odds_sum.clone();
///
///     let evens_sum = Arc::new(AtomicUsize::new(0));
///     let task_evens_sum = evens_sum.clone();
///
///     Pipeline::builder()
///         .with_branching_producer(vec!["evens", "odds"], move || {
///             let c = count.clone();
///             async move {
///                 let mut c = c.lock().await;
///                 if *c >= 10 {
///                     return None;
///                 }
///                 *c += 1;
///
///                 let result = if *c % 2 == 0 {
///                     branch![*c, NoOutput]
///                 } else {
///                     branch![NoOutput, *c]
///                 };
///                 Some(result)
///             }
///         })
///         .with_consumer("odds", WorkerOptions::default_single_task(), move |n: usize| {
///             let odds_sum = task_odds_sum.clone();
///             async move {
///                 odds_sum.fetch_add(n, Ordering::SeqCst);
///             }
///         })
///         .with_consumer("evens", WorkerOptions::default_single_task(), move |n: usize| {
///             let evens_sum = task_evens_sum.clone();
///             async move {
///                 evens_sum.fetch_add(n, Ordering::SeqCst);
///             }
///         })
///         .build()
///         .expect("failed to build pipeline!")
///         .wait()
///         .await;
///
///     assert_eq!(odds_sum.load(Acquire), 25);
///     assert_eq!(evens_sum.load(Acquire), 30);
/// }
/// ```
#[derive(Debug)]
pub struct Pipeline {
    synchronizer: Arc<Synchronizer>,
    producers: JoinSet<()>,
    workers: JoinSet<()>,
    signal_txs: Vec<Sender<StageWorkerSignal>>,
}

impl Pipeline {
    /// Create a new pipeline builder.
    pub fn builder() -> PipelineBuilder {
        PipelineBuilder::default()
    }

    /// Wait for the pipeline to complete.
    ///
    /// Once the pipeline is complete, a termination signal is sent to all the workers.
    ///
    /// A pipeline progresses to completion by doing the following:
    ///   1. Wait for all "producers" to complete while also progressing stage workers
    ///   2. Wait for either all the stage workers to complete, or wait for the internal
    ///      synchronizer to notify of completion (i.e. there's no more data flowing through the
    ///      pipeline)
    ///
    /// Step 1 implies that if the producers never finish, the pipeline will run forever. See
    /// [PipelineBuilder::with_producer] for more info.
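    ///
    /// # Examples
    ///
    /// A minimal sketch of waiting on a pipeline that consumes a fixed set of inputs:
    ///
    /// ```
    /// use async_pipes::{Pipeline, WorkerOptions};
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     Pipeline::builder()
    ///         .with_inputs("pipe", vec![1usize, 2, 3])
    ///         .with_consumer("pipe", WorkerOptions::default(), |n: usize| async move {
    ///             println!("received {n}");
    ///         })
    ///         .build()
    ///         .expect("failed to build pipeline")
    ///         .wait()
    ///         .await;
    /// }
    /// ```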
    pub async fn wait(mut self) {
        let workers_to_progress = Arc::new(Mutex::new(self.workers));
        let workers_to_finish = workers_to_progress.clone();

        let wait_for_producers = async {
            while let Some(result) = self.producers.join_next().await {
                check_join_result(result);
            }
        };
        let wait_for_workers = |workers: Arc<Mutex<JoinSet<()>>>| async move {
            while let Some(result) = workers.lock().await.join_next().await {
                check_join_result(result);
            }
        };
        let check_sync_completed = async move {
            while !self.synchronizer.completed() {
                yield_now().await
            }

            for tx in self.signal_txs {
                tx.send(StageWorkerSignal::Terminate)
                    .await
                    .expect("failed to send done signal")
            }
        };

        // Effectively, make progress until all producers are done.
        // We do this to prevent the synchronizer from causing the pipeline to shut down too early.
        select! {
            _ = wait_for_producers => {},
            _ = wait_for_workers(workers_to_progress),
                if !workers_to_progress.lock().await.is_empty() => {},
        }

        // If either the synchronizer determines we're done or all the workers have completed,
        // the pipeline is done.
        select! {
            _ = wait_for_workers(workers_to_finish) => {},
            _ = check_sync_completed => {},
        }
    }
}

/// Find the reader for the named pipe, taking it so each pipe has at most one reader.
fn find_reader(
    name: &str,
    pipes: &mut HashMap<String, Pipe<BoxedAnySend>>,
) -> Result<PipeReader<BoxedAnySend>, String> {
    Ok(pipes
        .get_mut(name)
        .unwrap_or_else(|| panic!("no pipe with name '{}' found", name))
        .reader
        .take()
        .ok_or("reader was already used")?)
}

/// Find a writer for the named pipe, cloning it so a pipe can have many writers.
fn find_writer(
    name: &str,
    pipes: &HashMap<String, Pipe<BoxedAnySend>>,
) -> Result<PipeWriter<BoxedAnySend>, String> {
    Ok(pipes
        .get(name)
        .ok_or_else(|| format!("pipeline has open-ended pipe: '{}'", name))?
        .writer
        .clone())
}

/// Find a writer for each named pipe, failing if any name has no corresponding pipe.
fn find_writers(
    names: &[String],
    pipes: &HashMap<String, Pipe<BoxedAnySend>>,
) -> Result<Vec<PipeWriter<BoxedAnySend>>, String> {
    let mut writers = Vec::new();
    for name in names {
        writers.push(find_writer(name, pipes)?);
    }
    Ok(writers)
}

/// Write each present result to its corresponding writer, skipping [None] results.
async fn write_results<O>(writers: &[PipeWriter<O>], results: Vec<Option<O>>) {
    if results.len() != writers.len() {
        panic!("len(results) != len(writers)");
    }

    for (result, writer) in results.into_iter().zip(writers) {
        if let Some(result) = result {
            writer.write(result).await;
        }
    }
}

/// Downcast a value received from a pipe, panicking with a descriptive message on failure.
fn downcast_from_pipe<T: 'static>(value: BoxedAnySend, pipe_name: &str) -> Box<T> {
    value.downcast::<T>().unwrap_or_else(|_| {
        panic!(
            "failed to downcast input value to {} from pipe '{}'",
            type_name::<T>(),
            pipe_name,
        )
    })
}

/// Resume unwinding if a joined task panicked; ignore all other outcomes.
fn check_join_result<T>(result: Result<T, JoinError>) {
    if let Err(e) = result {
        if e.is_panic() {
            panic::resume_unwind(e.into_panic())
        }
    }
}

#[cfg(test)]
mod tests {
    use std::collections::hash_map::RandomState;
    use std::collections::HashSet;
    use std::sync::Arc;

    use tokio::select;

    use super::*;

    macro_rules! pipe_writers {
        ($count:expr) => {{
            pipe_writers!($count, ())
        }};

        ($count:expr, $ch:ty) => {{
            let sync = Arc::new(Synchronizer::default());
            let mut writers = Vec::new();
            let mut rxs = Vec::new();
            for _ in 0..$count {
                let id = ulid::Ulid::new().to_string();
                let (tx, rx) = tokio::sync::mpsc::channel::<$ch>(1);
                rxs.push(rx);
                writers.push(PipeWriter::new(id, sync.clone(), tx));
            }
            (writers, rxs)
        }};
    }

    macro_rules! pipe {
        ($id:expr, reader=$reader:literal) => {{
            let id: String = $id.into();
            let sync = Arc::new(Synchronizer::default());
            let (tx, rx) = tokio::sync::mpsc::channel(1);
            let pipe = Pipe {
                writer: PipeWriter::new(id.clone(), sync.clone(), tx),
                reader: $reader.then_some(PipeReader::new(id.clone(), sync, rx)),
            };
            (id, pipe)
        }};
    }

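    /// An added check covering the zero-value validation performed by
    /// `TryFrom<WorkerOptions> for ValidWorkerOptions`.
    #[test]
    fn test_worker_options_validation_rejects_zero() {
        let zero_buffer = WorkerOptions {
            pipe_buffer_size: 0,
            max_task_count: 1,
        };
        assert!(ValidWorkerOptions::try_from(zero_buffer).is_err());

        let zero_tasks = WorkerOptions {
            pipe_buffer_size: 1,
            max_task_count: 0,
        };
        assert!(ValidWorkerOptions::try_from(zero_tasks).is_err());
    }
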
    #[test]
    fn test_find_reader() {
        let pipe_id = "Pipe";
        let mut pipes = HashMap::from([pipe!("Pipe", reader = true)]);

        let reader = find_reader(pipe_id, &mut pipes);
        assert!(reader.is_ok());
        assert_eq!(reader.unwrap().get_pipe_id(), pipe_id);
    }

    #[test]
    #[should_panic]
    fn test_find_reader_panics_on_no_reader() {
        let _ = find_reader("Pipe", &mut HashMap::new());
    }

    #[test]
    fn test_find_reader_already_used() {
        let mut pipes = HashMap::from([pipe!("NoReader", reader = false)]);

        let reader = find_reader("NoReader", &mut pipes);
        assert!(reader.is_err());
        assert_eq!(reader.unwrap_err(), "reader was already used".to_string());
    }

    #[test]
    fn test_find_writer() {
        let pipe_id = "Pipe";
        let pipes = HashMap::from([pipe!(pipe_id, reader = true)]);

        let writer = find_writer(pipe_id, &pipes);
        assert!(writer.is_ok());
        assert_eq!(writer.unwrap().get_pipe_id(), pipe_id);
    }

    #[test]
    fn test_find_writer_open_ended() {
        let pipes = HashMap::new();

        let writer = find_writer("Pipe", &pipes);
        assert!(writer.is_err());
        assert_eq!(writer.unwrap_err(), "pipeline has open-ended pipe: 'Pipe'");
    }

    #[test]
    fn test_find_writers() {
        let pipes = HashMap::from([
            pipe!("One", reader = true),
            pipe!("Two", reader = false),
            pipe!("Three", reader = true),
        ]);

        let pipe_ids = vec!["Two".to_string(), "Three".to_string()];
        let writers = find_writers(&pipe_ids, &pipes);
        assert!(writers.is_ok());

        let mut pipe_ids = HashSet::<String, RandomState>::from_iter(pipe_ids);
        let writers = writers.unwrap();
        assert_eq!(writers.len(), 2);

        for writer in writers {
            let id = writer.get_pipe_id();
            assert!(pipe_ids.remove(id), "missing ID");
        }
    }

    #[test]
    fn test_find_writers_open_ended() {
        let pipes = HashMap::from([
            pipe!("One", reader = true),
            pipe!("Two", reader = false),
            pipe!("Three", reader = true),
        ]);

        let pipe_ids = vec!["Two".to_string(), "Three".to_string(), "Four".to_string()];
        let writers = find_writers(&pipe_ids, &pipes);
        assert!(writers.is_err());
        assert_eq!(writers.unwrap_err(), "pipeline has open-ended pipe: 'Four'");
    }

    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_write_results() {
        let (writers, mut rxs) = pipe_writers!(3, usize);
        let results = vec![Some(0), None, Some(2)];

        write_results(&writers, results).await;

        assert_eq!(rxs.get_mut(0).unwrap().try_recv(), Ok(0));
        assert!(rxs.get_mut(1).unwrap().try_recv().is_err());
        assert_eq!(rxs.get_mut(2).unwrap().try_recv(), Ok(2));
    }

    #[tokio::test]
    #[should_panic]
    #[cfg_attr(miri, ignore)]
    async fn test_write_results_panics_on_result_count_mismatch() {
        let (writers, _rxs) = pipe_writers!(5, i32);
        let results = vec![Some(1), None, None];

        write_results(&writers, results).await;
    }

    #[test]
    fn test_downcast_from_pipe() {
        let value = Box::new(3i8) as BoxedAnySend;

        let casted = downcast_from_pipe::<i8>(value, "some_pipe");

        assert_eq!(casted, Box::new(3i8));
    }

    #[test]
    #[should_panic(expected = "failed to downcast input value to i32 from pipe 'some_pipe'")]
    fn test_downcast_from_pipe_fails() {
        let value = Box::new(3i8) as BoxedAnySend;

        downcast_from_pipe::<i32>(value, "some_pipe");
    }

    #[test]
    fn test_check_join_result_does_nothing_on_ok() {
        check_join_result(Ok(3usize));
    }

    #[tokio::test]
    #[should_panic]
    async fn test_check_join_result_propagates_panic() {
        let mut joins = JoinSet::new();
        joins.spawn(async { panic!("aaaahhhhh") });

        check_join_result(joins.join_next().await.unwrap())
    }

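    /// An added check that the [Display] impl renders the variant name.
    #[test]
    fn test_stage_worker_signal_display() {
        assert_eq!(StageWorkerSignal::Terminate.to_string(), "Terminate");
    }
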
    #[tokio::test]
    async fn test_stage_receives_signal_terminate() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);

        let pipeline = Pipeline::builder()
            .with_inputs("pipe", vec![()])
            .with_consumer(
                "pipe",
                WorkerOptions::default_single_task(),
                move |_: ()| {
                    let tx = tx.clone();
                    async move {
                        tx.send(()).await.unwrap();
                        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
                        panic!("worker did not terminate!");
                    }
                },
            )
            .build()
            .unwrap();

        let signaller = pipeline.signal_txs.first().unwrap().clone();
        select! {
            _ = pipeline.wait() => {},
            _ = rx.recv() => {
                signaller.send(StageWorkerSignal::Terminate).await.unwrap();
            }
        }
    }
}