1use std::any::{type_name, Any};
2use std::collections::HashMap;
3use std::fmt::{Display, Formatter};
4use std::future::Future;
5use std::num::NonZeroUsize;
6use std::panic;
7use std::pin::Pin;
8use std::sync::Arc;
9
10use tokio::select;
11use tokio::sync::mpsc::Sender;
12use tokio::sync::Mutex;
13use tokio::task::{yield_now, JoinError, JoinSet};
14
15pub use builder::*;
16use io::{PipeReader, PipeWriter};
17use sync::Synchronizer;
18
19mod builder;
20mod io;
21mod sync;
22mod workers;
23
/// Task limit used by the multi-task (and therefore the `Default`) options.
const DEFAULT_MAX_TASK_COUNT: usize = 100;
/// Pipe buffer size used by every default [`WorkerOptions`] constructor.
const DEFAULT_READER_BUFFER_SIZE: usize = 30;

/// Caller-supplied tuning values for a stage worker.
///
/// Both fields have to be non-zero for the options to pass the crate's
/// internal validation (`ValidWorkerOptions::try_from`).
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct WorkerOptions {
    /// Buffer size of the pipe; validated as the reader buffer size.
    pub pipe_buffer_size: usize,

    /// Upper bound on the worker's task count (presumably tasks running at
    /// the same time; confirm against the worker implementation).
    pub max_task_count: usize,
}

impl Default for WorkerOptions {
    /// Produces the multi-task defaults: a pipe buffer of
    /// `DEFAULT_READER_BUFFER_SIZE` and a task limit of `DEFAULT_MAX_TASK_COUNT`.
    fn default() -> Self {
        Self {
            pipe_buffer_size: DEFAULT_READER_BUFFER_SIZE,
            max_task_count: DEFAULT_MAX_TASK_COUNT,
        }
    }
}

impl WorkerOptions {
    /// Default options restricted to a single task; the pipe buffer size
    /// stays at its default value.
    pub fn default_single_task() -> Self {
        Self {
            max_task_count: 1,
            ..Self::default_multi_task()
        }
    }

    /// Default options for a worker that may run up to
    /// `DEFAULT_MAX_TASK_COUNT` tasks. Identical to [`WorkerOptions::default`].
    pub fn default_multi_task() -> Self {
        Self::default()
    }
}
85
86#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
87struct ValidWorkerOptions {
88 unbounded_buffer: bool,
89 reader_buffer_size: NonZeroUsize,
90 max_task_count: NonZeroUsize,
91}
92
93impl TryFrom<WorkerOptions> for ValidWorkerOptions {
94 type Error = String;
95
96 fn try_from(value: WorkerOptions) -> Result<Self, Self::Error> {
97 Ok(Self {
98 unbounded_buffer: false,
99 reader_buffer_size: NonZeroUsize::new(value.pipe_buffer_size)
100 .ok_or("reader buffer size must not be zero")?,
101 max_task_count: NonZeroUsize::new(value.max_task_count)
102 .ok_or("max task count must not be zero")?,
103 })
104 }
105}
106
/// Type-erased, sendable value. The `'static` bound is implied for
/// `Box<dyn Trait>`, so it is not spelled out.
pub type BoxedAnySend = Box<dyn Any + Send>;

/// Boxed producer callback: takes no input and returns a [`TaskFuture`] on
/// every call.
type ProducerFn = Box<dyn FnMut() -> TaskFuture + Send>;

/// Boxed stage callback: turns one erased input value into a [`TaskFuture`].
type TaskFn = Box<dyn Fn(BoxedAnySend) -> TaskFuture + Send + Sync>;

/// Boxed conversion from one erased value into a list of erased values.
type IterCastFn = Box<dyn Fn(BoxedAnySend) -> Vec<BoxedAnySend> + Send + Sync>;

/// Pinned, boxed future returned by stage callbacks. It resolves to an
/// optional list with one optional erased value per slot (presumably one slot
/// per output pipe; confirm against the worker code).
type TaskFuture =
    Pin<Box<dyn Future<Output = Option<Vec<Option<BoxedAnySend>>>> + Send>>;
154
/// Description of one pipeline stage: the boxed callback (or cast) it runs
/// together with the names of the pipes it is attached to.
/// NOTE(review): construction and consumption of `Stage` happen outside this
/// view (presumably in `builder` / `workers`); the notes below only restate
/// what the field types show.
155enum Stage {
    /// Stage without an input pipe: it carries only writer pipe names.
156 Producer {
        /// Argument-less callback that returns a `TaskFuture` on each call.
157 function: ProducerFn,
        /// Names of the pipes this stage writes to.
158 pipes: ProducerPipeNames,
159 },
160
    /// Stage that maps values from one reader pipe through `function`.
161 Regular {
        /// Callback receiving one type-erased input value.
162 function: TaskFn,
        /// Name of the reader pipe plus the names of the writer pipes.
163 pipes: TaskPipeNames,
        /// User-supplied, not yet validated worker options.
164 options: WorkerOptions,
165 },
166
    /// Stage that expands one erased value into several via `caster`.
167 Iterator {
        /// Which iterator-style operation this stage performs.
168 stage_type: IterStageType,
        /// Converts one erased value into a `Vec` of erased values.
169 caster: IterCastFn,
        /// Name of the reader pipe plus the names of the writer pipes.
170 pipes: TaskPipeNames,
        /// User-supplied, not yet validated worker options.
171 options: WorkerOptions,
172 },
173}
174
/// Pipe names attached to a producer stage, which only has outputs.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct ProducerPipeNames {
    /// Names of the pipes the producer writes to.
    writers: Vec<String>,
}
179
/// Pipe names attached to a stage that has one input and any number of outputs.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct TaskPipeNames {
    /// Name of the pipe the stage reads from.
    reader: String,
    /// Names of the pipes the stage writes to.
    writers: Vec<String>,
}
185
/// Kind of operation performed by an iterator stage.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum IterStageType {
    /// The only operation defined so far.
    Flatten,
}
190
191#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
192struct PipeConfig {
193 name: String,
194 options: ValidWorkerOptions,
195}
196
/// Both ends of a named pipe carrying values of type `T`.
197#[derive(Debug)]
198struct Pipe<T> {
    /// Read end. `find_reader` moves it out with `Option::take`, so it is
    /// `None` once a reader has been handed out.
199 reader: Option<PipeReader<T>>,
    /// Write end. `find_writer` hands out clones of it.
202 writer: PipeWriter<T>,
203}
204
/// Control message sent to a stage worker over its signal channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum StageWorkerSignal {
    /// Sent by `Pipeline::wait` to every signal channel once the synchronizer
    /// reports completion.
    Terminate,
}

impl Display for StageWorkerSignal {
    /// Writes the variant name verbatim; width and padding flags on the
    /// caller's formatter are not applied.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Terminate => "Terminate",
        })
    }
}
222
/// A built pipeline: the spawned producer and worker tasks plus the handles
/// `wait` uses to detect completion and to tell workers to terminate.
223#[derive(Debug)]
332pub struct Pipeline {
    /// Shared completion tracker; `wait` polls `completed()` on it.
333 synchronizer: Arc<Synchronizer>,
    /// Join handles of the producer tasks.
334 producers: JoinSet<()>,
    /// Join handles of the worker tasks.
335 workers: JoinSet<()>,
    /// One sender per signal channel; `wait` sends `Terminate` on each of them.
336 signal_txs: Vec<Sender<StageWorkerSignal>>,
337}
338
339impl Pipeline {
    /// Returns a `PipelineBuilder` in its default state.
340 pub fn builder() -> PipelineBuilder {
342 PipelineBuilder::default()
343 }
344
    /// Consumes the pipeline and drives its tasks until it is finished.
    ///
    /// Runs in two phases, each a `select!` that ends as soon as one branch
    /// completes:
    /// 1. until every producer task has been joined, or every worker task has
    ///    been joined, whichever happens first;
    /// 2. until every worker task has been joined, or the synchronizer reports
    ///    completion and `Terminate` has been sent on every signal channel.
    ///
    /// # Panics
    /// Re-raises the panic of any joined task (via `check_join_result`), and
    /// panics with "failed to send done signal" if a signal cannot be sent.
345 pub async fn wait(mut self) {
        // The worker set is polled by two different futures (one per phase),
        // so it is moved behind a shared async mutex.
358 let workers_to_progress = Arc::new(Mutex::new(self.workers));
359 let workers_to_finish = workers_to_progress.clone();
360
        // Joins producers one by one until the set is empty.
361 let wait_for_producers = async {
362 while let Some(result) = self.producers.join_next().await {
363 check_join_result(result);
364 }
365 };
        // Factory for a future that joins workers until the set is empty;
        // called once per phase with its own `Arc` handle.
366 let wait_for_workers = |workers: Arc<Mutex<JoinSet<()>>>| async move {
367 while let Some(result) = workers.lock().await.join_next().await {
368 check_join_result(result);
369 }
370 };
        // Polls the synchronizer, yielding to the scheduler between checks,
        // then tells every worker to terminate.
371 let check_sync_completed = async move {
372 while !self.synchronizer.completed() {
373 yield_now().await
374 }
375
376 for tx in self.signal_txs {
377 tx.send(StageWorkerSignal::Terminate)
378 .await
379 .expect("failed to send done signal")
380 }
381 };
382
        // Phase 1. The worker branch is only enabled when the worker set is
        // non-empty; an empty set would make `join_next` return `None` at once
        // and end this phase before the producers are done. The `if`
        // precondition still uses `workers_to_progress` although the branch
        // expression moves it, so this relies on `select!` evaluating
        // preconditions before the branch expressions.
383 select! {
386 _ = wait_for_producers => {},
387 _ = wait_for_workers(workers_to_progress),
388 if !workers_to_progress.lock().await.is_empty() => {},
389 }
390
        // Phase 2. Whichever branch loses is dropped without finishing.
391 select! {
393 _ = wait_for_workers(workers_to_finish) => {},
394 _ = check_sync_completed => {},
395 }
396 }
397}
398
399fn find_reader(
400 name: &str,
401 pipes: &mut HashMap<String, Pipe<BoxedAnySend>>,
402) -> Result<PipeReader<BoxedAnySend>, String> {
403 Ok(pipes
404 .get_mut(name)
405 .unwrap_or_else(|| panic!("no pipe with name '{}' found", name))
406 .reader
407 .take()
408 .ok_or("reader was already used")?)
409}
410
411fn find_writer(
412 name: &str,
413 pipes: &HashMap<String, Pipe<BoxedAnySend>>,
414) -> Result<PipeWriter<BoxedAnySend>, String> {
415 Ok(pipes
416 .get(name)
417 .ok_or(format!("pipeline has open-ended pipe: '{}'", name))?
418 .writer
419 .clone())
420}
421
422fn find_writers(
423 names: &[String],
424 pipes: &HashMap<String, Pipe<BoxedAnySend>>,
425) -> Result<Vec<PipeWriter<BoxedAnySend>>, String> {
426 let mut writers = Vec::new();
427 for name in names {
428 writers.push(find_writer(name, pipes)?);
429 }
430 Ok(writers)
431}
432
433async fn write_results<O>(writers: &[PipeWriter<O>], results: Vec<Option<O>>) {
434 if results.len() != writers.len() {
435 panic!("len(results) != len(writers)");
436 }
437
438 for (result, writer) in results.into_iter().zip(writers) {
439 if let Some(result) = result {
440 writer.write(result).await;
441 }
442 }
443}
444
445fn downcast_from_pipe<T: 'static>(value: BoxedAnySend, pipe_name: &str) -> Box<T> {
446 value.downcast::<T>().unwrap_or_else(|_| {
447 panic!(
448 "failed to downcast input value to {} from pipe '{}'",
449 type_name::<T>(),
450 pipe_name,
451 )
452 })
453}
454
455fn check_join_result<T>(result: Result<T, JoinError>) {
456 if let Err(e) = result {
457 if e.is_panic() {
458 panic::resume_unwind(e.into_panic())
459 }
460 }
461}
462
// Unit tests for the free helper functions above and for terminate signalling.
463#[cfg(test)]
464mod tests {
465 use std::collections::hash_map::RandomState;
466 use std::collections::HashSet;
467 use std::sync::Arc;
468
469 use tokio::select;
470
471 use super::*;
472
    // Builds `$count` pipe writers that share one `Synchronizer`, each backed
    // by its own capacity-1 mpsc channel, and returns `(writers, receivers)`.
    // The one-argument form uses `()` as the channel item type.
473 macro_rules! pipe_writers {
474 ($count:expr) => {{
475 pipe_writers!($count, ())
476 }};
477
478 ($count:expr, $ch:ty) => {{
479 let sync = Arc::new(Synchronizer::default());
480 let mut writers = Vec::new();
481 let mut rxs = Vec::new();
482 for _ in 0..$count {
483 let id = ulid::Ulid::new().to_string();
484 let (tx, rx) = tokio::sync::mpsc::channel::<$ch>(1);
485 rxs.push(rx);
486 writers.push(PipeWriter::new(format!("{id}"), sync.clone(), tx));
487 }
488 (writers, rxs)
489 }};
490 }
491
    // Builds an `(id, Pipe)` pair that can be fed to `HashMap::from`.
    // `reader = false` leaves `Pipe::reader` as `None`, which is the state a
    // pipe is in after its reader has been taken.
492 macro_rules! pipe {
493 ($id:expr, reader=$reader:literal) => {{
494 let id: String = $id.into();
495 let sync = Arc::new(Synchronizer::default());
496 let (tx, rx) = tokio::sync::mpsc::channel(1);
497 let pipe = Pipe {
498 writer: PipeWriter::new(id.clone(), sync.clone(), tx),
499 reader: $reader.then_some(PipeReader::new(id.clone(), sync, rx)),
500 };
501 (id, pipe)
502 }};
503 }
504
    // A pipe that still holds its reader yields that reader.
505 #[test]
506 fn test_find_reader() {
507 let pipe_id = "Pipe";
508 let mut pipes = HashMap::from([pipe!("Pipe", reader = true)]);
509
510 let reader = find_reader(pipe_id, &mut pipes);
511 assert!(reader.is_ok());
512 assert_eq!(reader.unwrap().get_pipe_id(), pipe_id);
513 }
514
    // Despite the name, this covers a missing *pipe*: the map is empty, so
    // `find_reader` panics instead of returning an error.
515 #[test]
516 #[should_panic]
517 fn test_find_reader_panics_on_no_reader() {
518 let _ = find_reader("Pipe", &mut HashMap::from([]));
519 }
520
    // A pipe whose reader slot is empty is reported as already used.
521 #[test]
522 fn test_find_reader_already_used() {
523 let mut pipes = HashMap::from([pipe!("NoReader", reader = false)]);
524
525 let reader = find_reader("NoReader", &mut pipes);
526 assert!(reader.is_err());
527 assert_eq!(reader.unwrap_err(), "reader was already used".to_string());
528 }
529
    // An existing pipe yields a writer carrying the same pipe id.
530 #[test]
531 fn test_find_writer() {
532 let pipe_id = "Pipe";
533 let pipes = HashMap::from([pipe!(pipe_id, reader = true)]);
534
535 let writer = find_writer(pipe_id, &pipes);
536 assert!(writer.is_ok());
537 assert_eq!(writer.unwrap().get_pipe_id(), pipe_id);
538 }
539
    // An unknown pipe name is an error (not a panic) for writers.
540 #[test]
541 fn test_find_writer_open_ended() {
542 let pipes = HashMap::from([]);
543
544 let writer = find_writer("Pipe", &pipes);
545 assert!(writer.is_err());
546 assert_eq!(writer.unwrap_err(), "pipeline has open-ended pipe: 'Pipe'");
547 }
548
    // Writers are found whether or not the pipe still holds a reader.
549 #[test]
550 fn test_find_writers() {
551 let pipes = HashMap::from([
552 pipe!("One", reader = true),
553 pipe!("Two", reader = false),
554 pipe!("Three", reader = true),
555 ]);
556
557 let pipe_ids = vec!["Two".to_string(), "Three".to_string()];
558 let writers = find_writers(&pipe_ids, &pipes);
559 assert!(writers.is_ok());
560
561 let mut pipe_ids = HashSet::<String, RandomState>::from_iter(pipe_ids);
562 let writers = writers.unwrap();
563 assert_eq!(writers.len(), 2);
564
        // Each returned writer must match a distinct requested id.
565 for writer in writers {
566 let id = writer.get_pipe_id();
567 assert!(pipe_ids.remove(id), "missing ID");
568 }
569 }
570
    // The first unknown name ("Four") aborts the lookup with its error.
571 #[test]
572 fn test_find_writers_open_ended() {
573 let pipes = HashMap::from([
574 pipe!("One", reader = true),
575 pipe!("Two", reader = false),
576 pipe!("Three", reader = true),
577 ]);
578
579 let pipe_ids = vec!["Two".to_string(), "Three".to_string(), "Four".to_string()];
580 let writers = find_writers(&pipe_ids, &pipes);
581 assert!(writers.is_err());
582 assert_eq!(writers.unwrap_err(), "pipeline has open-ended pipe: 'Four'");
583 }
584
    // `Some` results reach the channel at the same index; `None` sends nothing.
    // Note that `txs` actually holds the channel receivers.
585 #[tokio::test]
586 #[cfg_attr(miri, ignore)]
587 async fn test_write_results() {
588 let (writers, mut txs) = pipe_writers!(3, usize);
589 let results = vec![Some(0), None, Some(2)];
590
591 write_results(&writers, results).await;
592
593 assert_eq!(txs.get_mut(0).unwrap().try_recv(), Ok(0));
594 assert!(txs.get_mut(1).unwrap().try_recv().is_err());
595 assert_eq!(txs.get_mut(2).unwrap().try_recv(), Ok(2));
596 }
597
    // Five writers but three results: the length check must panic.
598 #[tokio::test]
599 #[should_panic]
600 #[cfg_attr(miri, ignore)]
601 async fn test_write_results_panics_on_result_count_mismatch() {
602 let (writers, _txs) = pipe_writers!(5, i32);
603 let results = vec![Some(1), None, None];
604
605 write_results(&writers, results).await;
606 }
607
    // Downcasting to the stored type returns the boxed value.
608 #[test]
609 fn test_downcast_from_pipe() {
610 let value = Box::new(3i8) as BoxedAnySend;
611
612 let casted = downcast_from_pipe::<i8>(value, "some_pipe");
613
614 assert_eq!(casted, Box::new(3i8));
615 }
616
    // Downcasting to a different type panics with the type and pipe name.
617 #[test]
618 #[should_panic(expected = "failed to downcast input value to i32 from pipe 'some_pipe'")]
619 fn test_downcast_from_pipe_fails() {
620 let value = Box::new(3i8) as BoxedAnySend;
621
622 downcast_from_pipe::<i32>(value, "some_pipe");
623 }
624
625 #[test]
626 fn test_check_join_result_does_nothing_on_ok() {
627 check_join_result(Ok(3usize));
628 }
629
    // A panic inside a spawned task is re-raised on the joining side.
630 #[tokio::test]
631 #[should_panic]
632 async fn test_check_join_result_propagates_panic() {
633 let mut joins = JoinSet::new();
634 joins.spawn(async { panic!("aaaahhhhh") });
635
636 check_join_result(joins.join_next().await.unwrap())
637 }
638
    // Starts a consumer that reports it is running, then sleeps for 5s and
    // panics; the test sends `Terminate` as soon as the report arrives.
    // NOTE(review): once the `rx.recv()` branch finishes, `select!` drops the
    // `pipeline.wait()` future, so the test returns without observing whether
    // the worker actually stopped. Confirm this is the intended coverage.
639 #[tokio::test]
640 async fn test_stage_receives_signal_terminate() {
641 let (tx, mut rx) = tokio::sync::mpsc::channel(1);
642
643 let pipeline = Pipeline::builder()
644 .with_inputs("pipe", vec![()])
645 .with_consumer(
646 "pipe",
647 WorkerOptions::default_single_task(),
648 move |_: ()| {
649 let tx = tx.clone();
650 async move {
651 tx.send(()).await.unwrap();
652 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
653 panic!("worker did not terminate!");
654 }
655 },
656 )
657 .build()
658 .unwrap();
659
        // Clone the sender first: `wait` consumes the pipeline.
660 let signaller = pipeline.signal_txs.first().unwrap().clone();
661 select! {
662 _ = pipeline.wait() => {},
663 _ = rx.recv() => {
664 signaller.send(StageWorkerSignal::Terminate).await.unwrap();
665 }
666 }
667 }
668}
668}