use clap::ValueEnum;
use super::base::{ActiveSteps, PipelineStep};
#[doc(hidden)]
pub mod backpressure_proportional;
#[doc(hidden)]
pub mod balanced_chase;
#[doc(hidden)]
pub mod balanced_chase_drain;
#[doc(hidden)]
pub mod chase_bottleneck;
#[doc(hidden)]
pub mod epsilon_greedy;
#[doc(hidden)]
pub mod fixed_priority;
#[doc(hidden)]
pub mod hybrid_adaptive;
#[doc(hidden)]
pub mod learned_affinity;
#[doc(hidden)]
pub mod optimized_chase;
#[doc(hidden)]
pub mod sticky_work_stealing;
#[doc(hidden)]
pub mod thompson_sampling;
#[doc(hidden)]
pub mod thompson_with_priors;
#[doc(hidden)]
pub mod two_phase;
#[doc(hidden)]
pub mod ucb;
#[doc(hidden)]
pub use backpressure_proportional::BackpressureProportionalScheduler;
#[doc(hidden)]
pub use balanced_chase::BalancedChaseScheduler;
#[doc(hidden)]
pub use balanced_chase_drain::BalancedChaseDrainScheduler;
#[doc(hidden)]
pub use chase_bottleneck::ChaseBottleneckScheduler;
#[doc(hidden)]
pub use epsilon_greedy::EpsilonGreedyScheduler;
#[doc(hidden)]
pub use fixed_priority::FixedPriorityScheduler;
#[doc(hidden)]
pub use hybrid_adaptive::HybridAdaptiveScheduler;
#[doc(hidden)]
pub use learned_affinity::LearnedAffinityScheduler;
#[doc(hidden)]
pub use optimized_chase::OptimizedChaseScheduler;
#[doc(hidden)]
pub use sticky_work_stealing::StickyWorkStealingScheduler;
#[doc(hidden)]
pub use thompson_sampling::ThompsonSamplingScheduler;
#[doc(hidden)]
pub use thompson_with_priors::ThompsonWithPriorsScheduler;
#[doc(hidden)]
pub use two_phase::TwoPhaseScheduler;
#[doc(hidden)]
pub use ucb::UCBScheduler;
/// Selects which [`Scheduler`] implementation [`create_scheduler`] builds.
///
/// The enum derives clap's `ValueEnum`; each variant's `#[value(name = "...")]` attribute is the
/// kebab-case name clap accepts for it. Every variant maps to the identically named
/// `*Scheduler` type in [`create_scheduler`]. [`SchedulerStrategy::BalancedChaseDrain`] is the
/// `Default`.
// Notes on variants are plain `//` comments on purpose: `ValueEnum` turns `///` comments on
// variants into per-value help text, which would change the generated CLI help.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum SchedulerStrategy {
    #[value(name = "fixed-priority")]
    FixedPriority,
    #[value(name = "chase-bottleneck")]
    ChaseBottleneck,
    #[value(name = "thompson-sampling")]
    ThompsonSampling,
    #[value(name = "ucb")]
    UCB,
    #[value(name = "epsilon-greedy")]
    EpsilonGreedy,
    #[value(name = "thompson-with-priors")]
    ThompsonWithPriors,
    #[value(name = "hybrid-adaptive")]
    HybridAdaptive,
    #[value(name = "backpressure-proportional")]
    BackpressureProportional,
    #[value(name = "two-phase")]
    TwoPhase,
    #[value(name = "sticky-work-stealing")]
    StickyWorkStealing,
    #[value(name = "learned-affinity")]
    LearnedAffinity,
    #[value(name = "optimized-chase")]
    OptimizedChase,
    #[value(name = "balanced-chase")]
    BalancedChase,
    // The strategy used when none is chosen explicitly (`SchedulerStrategy::default()`).
    #[default]
    #[value(name = "balanced-chase-drain")]
    BalancedChaseDrain,
}
/// A two-valued direction marker.
// NOTE(review): nothing in this part of the file reads `Direction`; presumably the scheduler
// submodules use it for the direction in which they walk the pipeline steps — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction {
    Forward,
    Backward,
}
/// Snapshot of boolean pressure flags handed by value to [`Scheduler::get_priorities`].
///
/// The type is `Copy`, so callers can build one per scheduling decision and pass it along
/// cheaply. The `Default` value has every flag `false` except `memory_drained`, which is `true`.
// NOTE(review): the per-field meanings below are inferred from the field names only; the code
// that sets and interprets them is outside this file — confirm before relying on them.
#[derive(Debug, Clone, Copy)]
#[allow(clippy::struct_excessive_bools)]
pub struct BackpressureState {
    /// Output-side "high" flag; defaults to `false`.
    // Presumably: the output queue is above its high-water mark.
    pub output_high: bool,
    /// Input-side "low" flag; defaults to `false`.
    // Presumably: the input queue is running low.
    pub input_low: bool,
    /// Read-finished flag; defaults to `false`.
    // Presumably: the read step has consumed all of its input.
    pub read_done: bool,
    /// Memory "high" flag; defaults to `false`.
    // Presumably: tracked memory use is above its limit.
    pub memory_high: bool,
    /// Memory "drained" flag; defaults to `true`.
    // Presumably: memory use has fallen back below the level at which pressure is released.
    pub memory_drained: bool,
}
impl Default for BackpressureState {
    /// Returns the initial state: every flag cleared except `memory_drained`, which starts set.
    fn default() -> Self {
        // The four flags that start cleared share one initialiser; `memory_drained` is the
        // single flag that starts out `true`, which is why `Default` cannot be derived.
        let [output_high, input_low, read_done, memory_high] = [false; 4];
        Self { output_high, input_low, read_done, memory_high, memory_drained: true }
    }
}
/// Per-thread scheduling policy.
///
/// Implementors provide the five required methods. The provided methods layer a static
/// assignment of *exclusive* steps (those for which `PipelineStep::is_exclusive` is true) to
/// threads on top of them: once there are at least four threads, each exclusive step is
/// attempted only by the thread that owns it.
pub trait Scheduler: Send {
    /// Returns the steps this thread should consider, given the current backpressure flags.
    // NOTE(review): the name suggests the slice is ordered most-preferred first — confirm
    // against the implementors.
    fn get_priorities(&mut self, backpressure: BackpressureState) -> &[PipelineStep];

    /// Feeds back the result of attempting `step`.
    ///
    /// `success` and `was_contention` are passed through to the implementor unchanged.
    // NOTE(review): presumably `was_contention` marks a failure caused by another thread
    // holding the step rather than by lack of work — confirm against the callers.
    fn record_outcome(&mut self, step: PipelineStep, success: bool, was_contention: bool);

    /// Index of the thread this scheduler instance belongs to.
    fn thread_id(&self) -> usize;

    /// Total number of threads sharing the pipeline.
    fn num_threads(&self) -> usize;

    /// The set of steps that are active for this run.
    fn active_steps(&self) -> &ActiveSteps;

    /// Returns the exclusive step statically assigned to this thread, if any.
    ///
    /// With fewer than four threads, or when there are no active exclusive steps, nothing is
    /// owned. Otherwise the steps returned by `ActiveSteps::exclusive_steps` are dealt out from
    /// both ends of the thread range:
    ///
    /// * thread `0` owns the first exclusive step;
    /// * the last thread owns the last exclusive step (when there is more than one);
    /// * the remaining interior steps are split in two — the front half (rounded up) goes to
    ///   threads `1, 2, ...` in order, the back half to the threads just below the last one,
    ///   mirrored.
    ///
    /// For example, with exclusive steps `[a, b, c, d]` and eight threads, thread 0 owns `a`,
    /// thread 1 owns `b`, thread 6 owns `c`, thread 7 owns `d`, and threads 2-5 own nothing.
    // NOTE(review): when there are more exclusive steps than threads can cover, some interior
    // steps end up with no owner, and `should_attempt_step` then rejects them on every thread.
    fn exclusive_step_owned(&self) -> Option<PipelineStep> {
        let num_threads = self.num_threads();
        let thread_id = self.thread_id();
        // Below four threads there is no ownership; every thread may run every step.
        if num_threads < 4 {
            return None;
        }

        let exclusive = self.active_steps().exclusive_steps();
        // `None` here means there are no exclusive steps to hand out.
        let last_index = exclusive.len().checked_sub(1)?;
        let last_thread = num_threads - 1;

        if thread_id == 0 {
            return Some(exclusive[0]);
        }
        if thread_id == last_thread && last_index > 0 {
            return Some(exclusive[last_index]);
        }

        // Steps strictly between the first and the last one.
        let num_interior = exclusive.len().saturating_sub(2);
        let front_count = num_interior.div_ceil(2);
        if (1..=front_count).contains(&thread_id) {
            return Some(exclusive[thread_id]);
        }

        let back_count = num_interior - front_count;
        // Distance from the last thread. A `thread_id` past the last thread has no such
        // distance; bail out rather than underflow (debug panic / release wrap-around).
        let back_thread_offset = last_thread.checked_sub(thread_id)?;
        if (1..=back_count).contains(&back_thread_offset) {
            return Some(exclusive[last_index - back_thread_offset]);
        }
        None
    }

    /// Returns whether this thread may attempt `step`.
    ///
    /// Non-exclusive steps are always allowed, and so is everything when there are fewer than
    /// four threads. Otherwise an exclusive step is allowed only on the thread that owns it
    /// according to [`Scheduler::exclusive_step_owned`].
    fn should_attempt_step(&self, step: PipelineStep) -> bool {
        if !step.is_exclusive() || self.num_threads() < 4 {
            return true;
        }
        self.exclusive_step_owned().is_some_and(|owned| owned == step)
    }

    /// Like [`Scheduler::should_attempt_step`], but `drain_mode == true` lifts the ownership
    /// restriction so that any thread may attempt any step.
    fn should_attempt_step_with_drain(&self, step: PipelineStep, drain_mode: bool) -> bool {
        if !step.is_exclusive() || drain_mode || self.num_threads() < 4 {
            return true;
        }
        self.exclusive_step_owned().is_some_and(|owned| owned == step)
    }
}
/// Maps a thread index to a pair of steps for the balanced-chase family of schedulers.
///
/// The rules are tried in this order (the first match wins):
///
/// | thread                                   | result                                     |
/// |------------------------------------------|--------------------------------------------|
/// | `0`                                      | `(Compress, Some(Read))`                   |
/// | last thread (`num_threads > 1`)          | `(Compress, Some(Write))`                  |
/// | `1` (`num_threads > 2`)                  | `(FindBoundaries, Some(FindBoundaries))`   |
/// | second-to-last thread (`num_threads > 3`)| `(Group, Some(Group))`                     |
/// | any other even thread                    | `(Compress, None)`                         |
/// | any other odd thread                     | `(Serialize, None)`                        |
///
/// The `num_threads > N` guards are evaluated *before* the matching `num_threads - N`
/// subtraction, so inputs with `thread_id >= num_threads` (including `num_threads == 0`) fall
/// through to the even/odd rule instead of underflowing.
// NOTE(review): presumably the first element is the thread's preferred step and the second the
// exclusive step it is responsible for — confirm against the callers in the submodules.
fn balanced_chase_determine_role(
    thread_id: usize,
    num_threads: usize,
) -> (PipelineStep, Option<PipelineStep>) {
    use PipelineStep::{Compress, FindBoundaries, Group, Read, Serialize, Write};
    match thread_id {
        0 => (Compress, Some(Read)),
        id if num_threads > 1 && id == num_threads - 1 => (Compress, Some(Write)),
        1 if num_threads > 2 => (FindBoundaries, Some(FindBoundaries)),
        id if num_threads > 3 && id == num_threads - 2 => (Group, Some(Group)),
        id if id.is_multiple_of(2) => (Compress, None),
        _ => (Serialize, None),
    }
}
/// Builds the boxed [`Scheduler`] implementation selected by `strategy`.
///
/// Every scheduler type is constructed the same way, `new(thread_id, num_threads, active_steps)`,
/// so the three arguments are forwarded unchanged to whichever type the strategy names.
#[must_use]
pub fn create_scheduler(
    strategy: SchedulerStrategy,
    thread_id: usize,
    num_threads: usize,
    active_steps: ActiveSteps,
) -> Box<dyn Scheduler> {
    // All constructors share one signature, so a single expansion serves every arm. Only the
    // arm that is taken expands at run time, so `active_steps` is moved exactly once.
    macro_rules! boxed {
        ($scheduler:ident) => {
            Box::new($scheduler::new(thread_id, num_threads, active_steps))
        };
    }

    // Deliberately exhaustive (no `_` arm): adding a strategy without wiring it up here is a
    // compile error.
    match strategy {
        SchedulerStrategy::FixedPriority => boxed!(FixedPriorityScheduler),
        SchedulerStrategy::ChaseBottleneck => boxed!(ChaseBottleneckScheduler),
        SchedulerStrategy::ThompsonSampling => boxed!(ThompsonSamplingScheduler),
        SchedulerStrategy::UCB => boxed!(UCBScheduler),
        SchedulerStrategy::EpsilonGreedy => boxed!(EpsilonGreedyScheduler),
        SchedulerStrategy::ThompsonWithPriors => boxed!(ThompsonWithPriorsScheduler),
        SchedulerStrategy::HybridAdaptive => boxed!(HybridAdaptiveScheduler),
        SchedulerStrategy::BackpressureProportional => boxed!(BackpressureProportionalScheduler),
        SchedulerStrategy::TwoPhase => boxed!(TwoPhaseScheduler),
        SchedulerStrategy::StickyWorkStealing => boxed!(StickyWorkStealingScheduler),
        SchedulerStrategy::LearnedAffinity => boxed!(LearnedAffinityScheduler),
        SchedulerStrategy::OptimizedChase => boxed!(OptimizedChaseScheduler),
        SchedulerStrategy::BalancedChase => boxed!(BalancedChaseScheduler),
        SchedulerStrategy::BalancedChaseDrain => boxed!(BalancedChaseDrainScheduler),
    }
}
// Unit tests for strategy selection and for the exclusive-step ownership rules provided by the
// `Scheduler` trait's default methods.
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    // Shorthand for `ActiveSteps::all()`, used by every test below.
    fn all() -> ActiveSteps {
        ActiveSteps::all()
    }
    // The default strategy is `BalancedChaseDrain`.
    #[test]
    fn test_scheduler_strategy_default() {
        assert_eq!(SchedulerStrategy::default(), SchedulerStrategy::BalancedChaseDrain);
    }
    // Every strategy can be constructed, and the resulting scheduler reports the thread id it
    // was created with.
    #[rstest]
    #[case::fixed_priority(SchedulerStrategy::FixedPriority)]
    #[case::chase_bottleneck(SchedulerStrategy::ChaseBottleneck)]
    #[case::thompson_sampling(SchedulerStrategy::ThompsonSampling)]
    #[case::ucb(SchedulerStrategy::UCB)]
    #[case::epsilon_greedy(SchedulerStrategy::EpsilonGreedy)]
    #[case::thompson_with_priors(SchedulerStrategy::ThompsonWithPriors)]
    #[case::hybrid_adaptive(SchedulerStrategy::HybridAdaptive)]
    #[case::backpressure_proportional(SchedulerStrategy::BackpressureProportional)]
    #[case::two_phase(SchedulerStrategy::TwoPhase)]
    #[case::sticky_work_stealing(SchedulerStrategy::StickyWorkStealing)]
    #[case::learned_affinity(SchedulerStrategy::LearnedAffinity)]
    #[case::optimized_chase(SchedulerStrategy::OptimizedChase)]
    #[case::balanced_chase(SchedulerStrategy::BalancedChase)]
    #[case::balanced_chase_drain(SchedulerStrategy::BalancedChaseDrain)]
    fn test_create_scheduler(#[case] strategy: SchedulerStrategy) {
        let scheduler = create_scheduler(strategy, 0, 8, all());
        assert_eq!(scheduler.thread_id(), 0);
    }
    // With fewer than four threads nothing is owned, so every thread may attempt exclusive
    // steps.
    #[test]
    fn test_exclusive_step_ownership_small_thread_counts() {
        // Two threads.
        let s0 = create_scheduler(SchedulerStrategy::BalancedChase, 0, 2, all());
        let s1 = create_scheduler(SchedulerStrategy::BalancedChase, 1, 2, all());
        assert!(s0.exclusive_step_owned().is_none());
        assert!(s1.exclusive_step_owned().is_none());
        assert!(s0.should_attempt_step(PipelineStep::Group));
        assert!(s1.should_attempt_step(PipelineStep::Group));
        // Three threads: still below the ownership threshold.
        let s0 = create_scheduler(SchedulerStrategy::BalancedChase, 0, 3, all());
        let s1 = create_scheduler(SchedulerStrategy::BalancedChase, 1, 3, all());
        let s2 = create_scheduler(SchedulerStrategy::BalancedChase, 2, 3, all());
        assert!(s0.should_attempt_step(PipelineStep::FindBoundaries));
        assert!(s1.should_attempt_step(PipelineStep::Group));
        assert!(s2.should_attempt_step(PipelineStep::Write));
    }
    // With eight threads the assertions expect the exclusive steps to be owned from both ends:
    // threads 0 and 1 at the front, threads 6 and 7 at the back, nothing in between.
    #[test]
    fn test_exclusive_step_ownership_eight_threads() {
        for thread_id in 0..8 {
            let scheduler = create_scheduler(SchedulerStrategy::BalancedChase, thread_id, 8, all());
            let expected_ownership = match thread_id {
                0 => Some(PipelineStep::Read),
                1 => Some(PipelineStep::FindBoundaries),
                6 => Some(PipelineStep::Group), 7 => Some(PipelineStep::Write), _ => None,
            };
            assert_eq!(
                scheduler.exclusive_step_owned(),
                expected_ownership,
                "Thread {thread_id} ownership mismatch"
            );
        }
    }
    // The five steps checked here are expected to be allowed on a thread that owns nothing,
    // i.e. they are treated as non-exclusive.
    #[test]
    fn test_should_attempt_step_parallel_always_allowed() {
        let scheduler = create_scheduler(SchedulerStrategy::BalancedChase, 3, 8, all());
        assert!(scheduler.should_attempt_step(PipelineStep::Decompress));
        assert!(scheduler.should_attempt_step(PipelineStep::Decode));
        assert!(scheduler.should_attempt_step(PipelineStep::Process));
        assert!(scheduler.should_attempt_step(PipelineStep::Serialize));
        assert!(scheduler.should_attempt_step(PipelineStep::Compress));
    }
    // An exclusive step may be attempted only by its owner.
    #[test]
    fn test_should_attempt_step_exclusive_only_owner() {
        // Thread 3 of 8 owns nothing, so every exclusive step is refused.
        let t3 = create_scheduler(SchedulerStrategy::BalancedChase, 3, 8, all());
        assert!(!t3.should_attempt_step(PipelineStep::Read));
        assert!(!t3.should_attempt_step(PipelineStep::FindBoundaries));
        assert!(!t3.should_attempt_step(PipelineStep::Group));
        assert!(!t3.should_attempt_step(PipelineStep::Write));
        // Thread 6 of 8 owns `Group` and nothing else.
        let t6 = create_scheduler(SchedulerStrategy::BalancedChase, 6, 8, all());
        assert!(!t6.should_attempt_step(PipelineStep::Read));
        assert!(!t6.should_attempt_step(PipelineStep::FindBoundaries));
        assert!(t6.should_attempt_step(PipelineStep::Group)); assert!(!t6.should_attempt_step(PipelineStep::Write));
    }
    // Every strategy is expected to report the same ownership for threads 6 and 3 of 8.
    #[test]
    fn test_exclusive_ownership_all_strategies() {
        let strategies = [
            SchedulerStrategy::FixedPriority,
            SchedulerStrategy::ChaseBottleneck,
            SchedulerStrategy::BalancedChase,
            SchedulerStrategy::OptimizedChase,
            SchedulerStrategy::ThompsonSampling,
            SchedulerStrategy::UCB,
            SchedulerStrategy::EpsilonGreedy,
            SchedulerStrategy::ThompsonWithPriors,
            SchedulerStrategy::HybridAdaptive,
            SchedulerStrategy::BackpressureProportional,
            SchedulerStrategy::TwoPhase,
            SchedulerStrategy::StickyWorkStealing,
            SchedulerStrategy::LearnedAffinity,
            SchedulerStrategy::BalancedChaseDrain,
        ];
        for strategy in strategies {
            let t6 = create_scheduler(strategy, 6, 8, all());
            assert_eq!(
                t6.exclusive_step_owned(),
                Some(PipelineStep::Group),
                "{strategy:?} T6 should own Group"
            );
            let t3 = create_scheduler(strategy, 3, 8, all());
            assert_eq!(t3.exclusive_step_owned(), None, "{strategy:?} T3 should own nothing");
        }
    }
    // Four threads is the smallest count at which ownership applies: each thread is expected
    // to own exactly one exclusive step, and non-exclusive work stays open to all of them.
    #[test]
    fn test_four_thread_edge_case() {
        let t0 = create_scheduler(SchedulerStrategy::BalancedChase, 0, 4, all());
        let t1 = create_scheduler(SchedulerStrategy::BalancedChase, 1, 4, all());
        let t2 = create_scheduler(SchedulerStrategy::BalancedChase, 2, 4, all());
        let t3 = create_scheduler(SchedulerStrategy::BalancedChase, 3, 4, all());
        assert_eq!(t0.exclusive_step_owned(), Some(PipelineStep::Read));
        assert_eq!(t1.exclusive_step_owned(), Some(PipelineStep::FindBoundaries));
        assert_eq!(t2.exclusive_step_owned(), Some(PipelineStep::Group));
        assert_eq!(t3.exclusive_step_owned(), Some(PipelineStep::Write));
        assert!(t0.should_attempt_step(PipelineStep::Process));
        assert!(t1.should_attempt_step(PipelineStep::Process));
        assert!(t2.should_attempt_step(PipelineStep::Process));
        assert!(t3.should_attempt_step(PipelineStep::Process));
    }
}