use std::collections::HashSet;
use std::collections::VecDeque;
use std::future::Future;
use std::iter::FromIterator;
use std::mem;
use std::pin::Pin;
use std::rc::Rc;

use downcast_rs::{impl_downcast, Downcast};
use futures::stream::{FuturesUnordered, StreamExt};
65
/// Result reported by a job that completed without error.
///
/// Deriving `Debug` lets callers use `{:?}`, `unwrap_err` and `expect_err` on
/// `Result<Outcome, E>` values.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Outcome {
    /// The job ran and succeeded.
    Success,
}
72
/// Errors produced while planning or running a job graph.
#[derive(Debug, PartialEq, Eq)]
pub enum Error<E> {
    /// A job was found to depend (directly or transitively) on itself.
    Cycle,

    /// One or more jobs returned an error while executing; holds each of
    /// those job errors.
    Failed(Vec<E>),

    /// Carries a user-supplied error, presumably returned from an
    /// `IntoJob::plan` implementation — nothing in this module constructs it.
    Plan(E),
}

impl<E: std::fmt::Display> std::fmt::Display for Error<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::Cycle => write!(f, "dependency cycle detected"),
            Error::Failed(errs) => {
                write!(f, "{} job(s) failed", errs.len())?;
                for (i, err) in errs.iter().enumerate() {
                    // First error is introduced with ": ", the rest are "; "-separated.
                    let sep = if i == 0 { ": " } else { "; " };
                    write!(f, "{}{}", sep, err)?;
                }
                Ok(())
            }
            Error::Plan(err) => write!(f, "planning failed: {}", err),
        }
    }
}

/// Lets `Error<E>` flow into `Box<dyn std::error::Error>` and similar handlers.
impl<E: std::fmt::Debug + std::fmt::Display> std::error::Error for Error<E> {}
85
86pub type Job<C, E> = Box<dyn FnOnce(C) -> Pin<Box<dyn Future<Output = Result<Outcome, E>>>>>;
103
104pub trait IntoJob<C, E>: Downcast {
118 fn into_job(&self) -> Job<C, E>;
120
121 fn plan(&self, plan: &mut PlanBuilder<C, E>) -> Result<(), Error<E>> {
123 #![allow(unused_variables)]
126
127 Ok(())
128 }
129}
130
131impl_downcast!(IntoJob<C, E>);
132
133struct PlanBuilderEntry<C, E> {
135 job: Rc<dyn IntoJob<C, E>>,
136 dependencies: HashSet<usize>,
137 dependents: HashSet<usize>,
138}
139
140pub struct PlanBuilder<C, E> {
149 jobs: Vec<PlanBuilderEntry<C, E>>,
150 ancestors: HashSet<usize>,
151 current_parent: usize,
152 ready: Vec<usize>,
153}
154
155impl<C: 'static, E: 'static> PlanBuilder<C, E> {
156 fn index_of<J: IntoJob<C, E> + PartialEq>(&self, job: &J) -> Option<usize> {
158 for (idx, entry) in self.jobs.iter().enumerate() {
159 if let Some(existing_job) = entry.job.downcast_ref::<J>() {
160 if job == existing_job {
161 return Some(idx);
162 }
163 }
164 }
165
166 None
167 }
168
169 pub fn add_dependency<J: IntoJob<C, E> + PartialEq>(&mut self, job: J) -> Result<(), Error<E>> {
171 if let Some(idx) = self.index_of(&job) {
184 if self.ancestors.contains(&idx) {
185 return Err(Error::Cycle);
186 }
187
188 self.jobs[idx].dependents.insert(self.current_parent);
189 self.jobs[self.current_parent].dependencies.insert(idx);
190 return Ok(());
191 }
192
193 let idx = self.jobs.len();
197 let job = Rc::new(job);
198 self.jobs.push(PlanBuilderEntry {
199 job: job.clone(),
200 dependencies: HashSet::new(),
201 dependents: HashSet::from_iter(vec![self.current_parent]),
202 });
203 self.jobs[self.current_parent].dependencies.insert(idx);
204
205 self.ancestors.insert(idx);
206 let prev_parent = mem::replace(&mut self.current_parent, idx);
207 job.plan(self)?;
208 self.current_parent = prev_parent;
209 self.ancestors.remove(&idx);
210
211 if self.jobs[idx].dependencies.is_empty() {
212 self.ready.push(idx);
213 }
214
215 Ok(())
216 }
217}
218
219enum State<C, E> {
221 Pending(Job<C, E>),
222 Running,
223 Success(Outcome),
224 Failed(E),
225}
226
227impl<C, E> State<C, E> {
228 fn success(&self) -> bool {
230 match self {
231 State::Success(_) => true,
232 _ => false,
233 }
234 }
235}
236
237struct PlanEntry<C, E> {
239 state: State<C, E>,
240 dependencies: HashSet<usize>,
241 dependents: HashSet<usize>,
242}
243
244struct Plan<C, E> {
246 jobs: Vec<PlanEntry<C, E>>,
247 ready: Vec<usize>,
248}
249
250impl<C, E> Plan<C, E> {
251 fn new<J: IntoJob<C, E>>(job: J) -> Result<Self, Error<E>> {
253 let job = Rc::new(job);
254
255 let mut builder = PlanBuilder {
256 jobs: vec![PlanBuilderEntry {
257 job: job.clone(),
258 dependencies: HashSet::new(),
259 dependents: HashSet::new(),
260 }],
261 ancestors: HashSet::from_iter(vec![0]),
262 current_parent: 0,
263 ready: vec![],
264 };
265
266 job.plan(&mut builder)?;
267 if builder.jobs[0].dependencies.is_empty() {
268 builder.ready.push(0);
269 }
270
271 Ok(Self {
272 jobs: builder
273 .jobs
274 .drain(..)
275 .map(|e| PlanEntry {
276 state: State::Pending(e.job.into_job()),
277 dependencies: e.dependencies,
278 dependents: e.dependents,
279 })
280 .collect(),
281 ready: builder.ready,
282 })
283 }
284
285 fn next_job(&mut self) -> Option<(Job<C, E>, usize)> {
287 if self.ready.len() == 0 {
288 return None;
289 }
290
291 let idx = self.ready.remove(0);
292 let state = mem::replace(&mut self.jobs[idx].state, State::Running);
293
294 if let State::Pending(job) = state {
295 Some((job, idx))
296 } else {
297 panic!("unexpected job status")
298 }
299 }
300
301 fn mark_complete(&mut self, job_idx: usize, res: Result<Outcome, E>) {
304 self.jobs[job_idx].state = match res {
305 Ok(outcome) => State::Success(outcome),
306 Err(err) => State::Failed(err),
307 };
308
309 for dep_idx in &self.jobs[job_idx].dependents {
310 let is_ready = self.jobs[*dep_idx]
311 .dependencies
312 .iter()
313 .all(|i| self.jobs[*i].state.success());
314 if is_ready {
315 self.ready.push(*dep_idx);
316 }
317 }
318 }
319}
320
/// Runs a job and its dependency graph, keeping at most `max_jobs` jobs in
/// flight at once.
pub struct Scheduler<'a, C> {
    /// Upper bound on concurrently running jobs.
    max_jobs: usize,
    /// Invoked once per launched job to create the context passed into it.
    ctx_factory: Box<dyn FnMut() -> C + 'a>,
}
328
329impl Scheduler<'static, ()> {
330 pub fn new() -> Self {
332 let max_jobs = num_cpus::get();
333 let ctx_factory = Box::new(|| ());
334 Self { max_jobs, ctx_factory }
335 }
336}
337
338impl<'a, C> Scheduler<'a, C> {
339 pub fn with_factory<F>(factory: F) -> Self
342 where
343 F: FnMut() -> C + 'a
344 {
345 let max_jobs = num_cpus::get();
346 let ctx_factory = Box::new(factory);
347 Self { max_jobs, ctx_factory }
348 }
349
350 pub fn max_jobs(&mut self, jobs: usize) -> &mut Self {
357 if jobs == 0 {
358 panic!("max_jobs must be greater than zero")
359 }
360 self.max_jobs = jobs;
361 self
362 }
363
364 pub async fn run<E, J: IntoJob<C, E>>(&mut self, job: J) -> Result<(), Error<E>> {
366 let mut plan = Plan::new(job)?;
367 let mut pool = FuturesUnordered::new();
368
369 loop {
370 while pool.len() < self.max_jobs {
373 if let Some((job, idx)) = plan.next_job() {
374 let ctx = (self.ctx_factory)();
375 pool.push(async move {
376 let res = job(ctx).await;
377 (idx, res)
378 })
379 } else {
380 break;
381 }
382 }
383
384 if pool.len() == 0 {
385 break;
388 }
389
390 if let Some((idx, res)) = pool.next().await {
391 plan.mark_complete(idx, res);
392 } else {
393 panic!("job pool unexpectedly empty");
394 }
395 }
396
397 let mut errs = vec![];
398 for job in plan.jobs {
399 if let State::Failed(err) = job.state {
400 errs.push(err);
401 }
402 }
403
404 if errs.len() > 0 {
405 Err(Error::Failed(errs))
406 } else {
407 Ok(())
408 }
409 }
410}
411
412#[cfg(test)]
413mod tests {
414
415 use std::time::{Duration, Instant};
416
417 use async_std::sync::Mutex;
418 use async_std::task;
419
420 use super::*;
421
422 type JobGraph = Rc<Vec<(bool, Vec<usize>)>>;
423
424 type JobTrace = Rc<Mutex<Vec<usize>>>;
425
426 struct TestPlan {
427 graph: Vec<(bool, Vec<usize>)>,
428 max_jobs: Option<usize>,
429 }
430
431 struct TestJob {
432 index: usize,
433 graph: JobGraph,
434 success: bool,
435 }
436
437 impl IntoJob<JobTrace, usize> for TestJob {
438 fn plan(&self, plan: &mut PlanBuilder<JobTrace, usize>) -> Result<(), Error<usize>> {
439 for index in &self.graph[self.index].1 {
440 plan.add_dependency(TestJob {
441 index: *index,
442 graph: self.graph.clone(),
443 success: self.graph[*index].0,
444 })?;
445 }
446
447 Ok(())
448 }
449
450 fn into_job(&self) -> Job<JobTrace, usize> {
451 let success = self.success;
452 let index = self.index;
453 Box::new(move |trace| {
454 Box::pin(async move {
455 trace.lock().await.push(index);
456 if success {
457 Ok(Outcome::Success)
458 } else {
459 Err(index)
460 }
461 })
462 })
463 }
464 }
465
466 impl PartialEq for TestJob {
467 fn eq(&self, other: &Self) -> bool {
468 self.index == other.index
469 }
470 }
471
472 impl TestPlan {
473 fn new(graph: Vec<(bool, Vec<usize>)>) -> Self {
474 Self {
475 graph,
476 max_jobs: None,
477 }
478 }
479
480 async fn trace(self) -> (Vec<Option<usize>>, Option<Error<usize>>) {
481 let graph = Rc::new(self.graph);
482 let job = TestJob {
483 index: 0,
484 graph: graph.clone(),
485 success: graph[0].0,
486 };
487
488 let trace = Rc::new(Mutex::new(vec![]));
489 let mut sched = Scheduler::with_factory(|| trace.clone());
490 if let Some(max_jobs) = self.max_jobs {
491 sched.max_jobs(max_jobs);
492 }
493
494 let err = sched.run(job).await.err();
495
496 let mut results = vec![None; graph.len()];
497
498 for (finished_idx, job_idx) in trace.lock().await.iter().enumerate() {
499 assert!(results[*job_idx].is_none());
501
502 results[*job_idx] = Some(finished_idx);
503 }
504
505 (results, err)
506 }
507 }
508
509 #[async_std::test]
510 async fn single_job() {
511 let (trace, err) = TestPlan::new(vec![(true, vec![])]).trace().await;
512
513 assert!(err.is_none());
514 assert_eq!(trace[0], Some(0));
515 }
516
517 #[async_std::test]
518 async fn single_job_fails() {
519 let (trace, err) = TestPlan::new(vec![(false, vec![])]).trace().await;
520
521 assert_eq!(err, Some(Error::Failed(vec![0])));
522 assert_eq!(trace[0], Some(0));
523 }
524
525 #[async_std::test]
526 async fn single_dep() {
527 let (trace, err) = TestPlan::new(vec![(true, vec![1]), (true, vec![])])
528 .trace()
529 .await;
530
531 assert!(err.is_none());
532 assert_eq!(trace[0], Some(1));
533 assert_eq!(trace[1], Some(0));
534 }
535
536 #[async_std::test]
537 async fn single_dep_fails() {
538 let (trace, err) = TestPlan::new(vec![(true, vec![1]), (false, vec![])])
539 .trace()
540 .await;
541
542 assert_eq!(err, Some(Error::Failed(vec![1])));
543 assert_eq!(trace[0], None);
544 assert_eq!(trace[1], Some(0));
545 }
546
547 #[async_std::test]
548 async fn single_dep_root_fails() {
549 let (trace, err) = TestPlan::new(vec![(false, vec![1]), (true, vec![])])
550 .trace()
551 .await;
552
553 assert_eq!(err, Some(Error::Failed(vec![0])));
554 assert_eq!(trace[0], Some(1));
555 assert_eq!(trace[1], Some(0));
556 }
557
558 #[async_std::test]
559 async fn two_deps() {
560 let (trace, err) = TestPlan::new(vec![(true, vec![1, 2]), (true, vec![]), (true, vec![])])
561 .trace()
562 .await;
563
564 assert!(err.is_none());
565 assert_eq!(trace[0], Some(2));
566 assert!(matches!(trace[1], Some(x) if x < 2));
567 assert!(matches!(trace[2], Some(x) if x < 2));
568 }
569
570 #[async_std::test]
571 async fn two_deps_one_fails() {
572 let (trace, err) = TestPlan::new(vec![(true, vec![1, 2]), (true, vec![]), (false, vec![])])
573 .trace()
574 .await;
575
576 assert_eq!(err, Some(Error::Failed(vec![2])));
577 assert_eq!(trace[0], None);
578 assert!(trace[2].is_some());
580 }
581
582 #[async_std::test]
583 async fn single_trans_dep() {
584 let (trace, err) = TestPlan::new(vec![(true, vec![1]), (true, vec![2]), (true, vec![])])
585 .trace()
586 .await;
587
588 assert!(err.is_none());
589 assert_eq!(trace[0], Some(2));
590 assert_eq!(trace[1], Some(1));
591 assert_eq!(trace[2], Some(0));
592 }
593
594 #[async_std::test]
595 async fn single_trans_dep_fails() {
596 let (trace, err) = TestPlan::new(vec![(true, vec![1]), (true, vec![2]), (false, vec![])])
597 .trace()
598 .await;
599
600 assert_eq!(err, Some(Error::Failed(vec![2])));
601 assert_eq!(trace[0], None);
602 assert_eq!(trace[1], None);
603 assert_eq!(trace[2], Some(0));
604 }
605
606 #[async_std::test]
607 async fn single_trans_dep_direct_dep_fails() {
608 let (trace, err) = TestPlan::new(vec![(true, vec![1]), (false, vec![2]), (true, vec![])])
609 .trace()
610 .await;
611
612 assert_eq!(err, Some(Error::Failed(vec![1])));
613 assert_eq!(trace[0], None);
614 assert_eq!(trace[1], Some(1));
615 assert_eq!(trace[2], Some(0));
616 }
617
618 #[async_std::test]
619 async fn two_deps_single_trans_dep() {
620 let (trace, err) = TestPlan::new(vec![
621 (true, vec![1, 3]),
622 (true, vec![2]),
623 (true, vec![]),
624 (true, vec![]),
625 ])
626 .trace()
627 .await;
628
629 assert!(err.is_none());
630 assert_eq!(trace[0], Some(3));
631 assert!(matches!(trace[3], Some(x) if x < 3));
632
633 let order_of_1 = trace[1].unwrap();
634 let order_of_2 = trace[2].unwrap();
635 assert!(order_of_1 > order_of_2);
636 assert!(order_of_1 < 3);
637 }
638
639 #[async_std::test]
640 async fn two_deps_each_with_trans_dep() {
641 let (trace, err) = TestPlan::new(vec![
642 (true, vec![1, 3]),
643 (true, vec![2]),
644 (true, vec![]),
645 (true, vec![4]),
646 (true, vec![]),
647 ])
648 .trace()
649 .await;
650
651 assert!(err.is_none());
652 assert_eq!(trace[0], Some(4));
653
654 let order_of_1 = trace[1].unwrap();
655 let order_of_2 = trace[2].unwrap();
656 assert!(order_of_1 < 4);
657 assert!(order_of_2 < 4);
658 assert!(order_of_1 > order_of_2);
659
660 let order_of_3 = trace[3].unwrap();
661 let order_of_4 = trace[4].unwrap();
662 assert!(order_of_3 < 4);
663 assert!(order_of_4 < 4);
664 assert!(order_of_3 > order_of_4);
665 }
666
667 #[async_std::test]
668 async fn three_deps() {
669 let (trace, err) = TestPlan::new(vec![
670 (true, vec![1, 2, 3]),
671 (true, vec![]),
672 (true, vec![]),
673 (true, vec![]),
674 ])
675 .trace()
676 .await;
677
678 assert!(err.is_none());
679 assert_eq!(trace[0], Some(3));
680 assert!(matches!(trace[1], Some(x) if x < 3));
681 assert!(matches!(trace[2], Some(x) if x < 3));
682 assert!(matches!(trace[3], Some(x) if x < 3));
683 }
684
685 #[async_std::test]
686 async fn diamond() {
687 let (trace, err) = TestPlan::new(vec![
688 (true, vec![2, 3]),
689 (true, vec![]),
690 (true, vec![1]),
691 (true, vec![1]),
692 ])
693 .trace()
694 .await;
695
696 assert!(err.is_none());
697 assert_eq!(trace[0], Some(3));
698 assert_eq!(trace[1], Some(0));
699
700 let order_of_2 = trace[2].unwrap();
701 let order_of_3 = trace[3].unwrap();
702 assert!(order_of_2 > 0);
703 assert!(order_of_2 < 3);
704 assert!(order_of_3 > 0);
705 assert!(order_of_3 < 3);
706 }
707
708 #[async_std::test]
709 async fn diamond_with_extra_trans_deps() {
710 let (trace, err) = TestPlan::new(vec![
711 (true, vec![2, 3]),
712 (true, vec![4]),
713 (true, vec![1, 5]),
714 (true, vec![1, 6]),
715 (true, vec![]),
716 (true, vec![]),
717 (true, vec![]),
718 ])
719 .trace()
720 .await;
721
722 assert!(err.is_none());
723 assert_eq!(trace[0], Some(6));
724
725 let order_of_2 = trace[2].unwrap();
726 assert!(order_of_2 < 6);
727
728 let order_of_3 = trace[3].unwrap();
729 assert!(order_of_3 < 6);
730
731 let order_of_1 = trace[1].unwrap();
732 assert!(order_of_1 < order_of_2);
733 assert!(order_of_1 < order_of_3);
734
735 let order_of_4 = trace[4].unwrap();
736 assert!(order_of_4 < order_of_1);
737
738 let order_of_5 = trace[5].unwrap();
739 assert!(order_of_5 < order_of_2);
740
741 let order_of_6 = trace[6].unwrap();
742 assert!(order_of_6 < order_of_3);
743 }
744
745 #[async_std::test]
746 async fn simple_cycle() {
747 let (trace, err) = TestPlan::new(vec![(true, vec![1]), (true, vec![0])])
748 .trace()
749 .await;
750
751 assert_eq!(err, Some(Error::Cycle));
752 for job in trace {
753 assert_eq!(job, None);
754 }
755 }
756
757 #[async_std::test]
758 async fn complex_cycle() {
759 let (trace, err) = TestPlan::new(vec![
760 (true, vec![1, 2]),
761 (true, vec![3]),
762 (true, vec![1]),
763 (true, vec![2]),
764 ])
765 .trace()
766 .await;
767
768 assert_eq!(err, Some(Error::Cycle));
769 for job in trace {
770 assert_eq!(job, None);
771 }
772 }
773
774 #[async_std::test]
775 async fn concurrent_execution() {
776 #[derive(PartialEq)]
787 struct SleepJob(Duration);
788 impl IntoJob<(), ()> for SleepJob {
789 fn into_job(&self) -> Job<(), ()> {
790 let dur = self.0;
791 Box::new(move |_| Box::pin(async move {
792 task::sleep(dur).await;
793 Ok(Outcome::Success)
794 }))
795 }
796 }
797
798 struct PseudoJob;
799 impl IntoJob<(), ()> for PseudoJob {
800 fn plan(&self, plan: &mut PlanBuilder<(), ()>) -> Result<(), Error<()>> {
801 plan.add_dependency(SleepJob(Duration::from_millis(60)))?;
802 plan.add_dependency(SleepJob(Duration::from_millis(80)))?;
803 Ok(())
804 }
805 fn into_job(&self) -> Job<(), ()> {
806 Box::new(|_| Box::pin(async {
807 Ok(Outcome::Success)
808 }))
809 }
810 }
811
812 let mut sched = Scheduler::new();
813 sched.max_jobs(2);
814
815 let start = Instant::now();
816 let res = sched.run(PseudoJob).await;
817 let end = Instant::now();
818
819 assert!(res.is_ok());
820 assert!(end - start < Duration::from_millis(140));
821 }
822}