use super::{BackpressureState, Direction, Scheduler};
use crate::unified_pipeline::base::{ActiveSteps, PipelineStep};
/// Per-worker scheduler that keeps the Serialize/Compress pair near the top of
/// every priority list while "chasing" work forward or backward through the
/// Process/Decode/Decompress steps.
pub struct BalancedChaseScheduler {
    /// Index of this worker, as passed to `new`.
    thread_id: usize,
    /// Total number of workers; `new` asserts this is non-zero.
    num_threads: usize,
    /// Step this worker is currently chasing; updated by `record_outcome`.
    current_step: PipelineStep,
    /// Chase direction; switched by `get_priorities` based on backpressure.
    direction: Direction,
    /// Scratch storage rebuilt on every `get_priorities` call; a filtered
    /// prefix of it is returned to the caller.
    priority_buffer: [PipelineStep; 9],
    /// Optional role assigned at construction; always placed first in the
    /// priority list, and a success on it pivots the worker to Compress.
    exclusive_role: Option<PipelineStep>,
    /// Filter applied to the priority list before it is returned.
    active_steps: ActiveSteps,
}
impl BalancedChaseScheduler {
    /// Creates the scheduler for worker `thread_id` out of `num_threads` workers.
    ///
    /// The starting step and optional exclusive role come from
    /// `super::balanced_chase_determine_role`; the chase starts in
    /// [`Direction::Forward`].
    ///
    /// # Panics
    ///
    /// Panics if `num_threads` is zero.
    #[must_use]
    pub fn new(thread_id: usize, num_threads: usize, active_steps: ActiveSteps) -> Self {
        assert!(num_threads > 0, "num_threads must be > 0");
        let (current_step, exclusive_role) = Self::determine_role(thread_id, num_threads);
        Self {
            thread_id,
            num_threads,
            current_step,
            direction: Direction::Forward,
            priority_buffer: PipelineStep::all(),
            exclusive_role,
            active_steps,
        }
    }

    /// Returns `(initial current step, optional exclusive role)` for a worker.
    ///
    /// Thin wrapper over `super::balanced_chase_determine_role`.
    fn determine_role(
        thread_id: usize,
        num_threads: usize,
    ) -> (PipelineStep, Option<PipelineStep>) {
        super::balanced_chase_determine_role(thread_id, num_threads)
    }

    /// Rebuilds `priority_buffer` in place, in this order:
    ///
    /// 1. the exclusive role, if any;
    /// 2. Serialize and Compress (Compress first when `bp.output_high`);
    /// 3. `current_step`;
    /// 4. Process/Decode/Decompress, ordered by `direction`;
    /// 5. Read/FindBoundaries/Group/Write.
    ///
    /// Each step keeps only its first (highest-priority) position, and at most
    /// `priority_buffer.len()` steps are written. Slots past the deduplicated
    /// candidates, if any, keep their previous contents.
    fn build_priorities(&mut self, bp: BackpressureState) {
        use PipelineStep::{
            Compress, Decode, Decompress, FindBoundaries, Group, Process, Read, Serialize, Write,
        };

        // Output pressure means the compressor is the one to drain first.
        let bottleneck = if bp.output_high {
            [Compress, Serialize]
        } else {
            [Serialize, Compress]
        };
        let parallel_order: &[PipelineStep] = match self.direction {
            Direction::Forward => &[Process, Decode, Decompress],
            Direction::Backward => &[Decompress, Decode, Process],
        };
        let exclusive_steps = [Read, FindBoundaries, Group, Write];

        let candidates = self
            .exclusive_role
            .into_iter()
            .chain(bottleneck)
            .chain(std::iter::once(self.current_step))
            .chain(parallel_order.iter().copied())
            .chain(exclusive_steps);

        // Deduplicate straight into the fixed buffer: this runs on every
        // scheduling decision, so avoid a per-call heap allocation.
        let capacity = self.priority_buffer.len();
        let mut len = 0;
        for step in candidates {
            if len == capacity {
                break;
            }
            if !self.priority_buffer[..len].contains(&step) {
                self.priority_buffer[len] = step;
                len += 1;
            }
        }
    }
}
impl Scheduler for BalancedChaseScheduler {
    /// Updates the chase direction from `bp`, rebuilds the priority order and
    /// returns the prefix that survives `ActiveSteps` filtering.
    ///
    /// Output pressure forces [`Direction::Forward`]. Otherwise, low input while
    /// reading is not finished forces [`Direction::Backward`]. In every other
    /// case the previous direction is kept.
    fn get_priorities(&mut self, bp: BackpressureState) -> &[PipelineStep] {
        let starved = bp.input_low && !bp.read_done;
        match (bp.output_high, starved) {
            (true, _) => self.direction = Direction::Forward,
            (false, true) => self.direction = Direction::Backward,
            (false, false) => {}
        }

        self.build_priorities(bp);
        let kept = self.active_steps.filter_in_place(&mut self.priority_buffer);
        &self.priority_buffer[..kept]
    }

    /// Moves `current_step` after an attempt at `step`.
    ///
    /// On success the worker sticks with `step`, except that succeeding at its
    /// exclusive role pivots it to `PipelineStep::Compress`. On failure it steps
    /// through `PipelineStep::all()` by index in the current direction. It falls
    /// back to `Compress` (forward) or `Serialize` (backward) at the index limits
    /// below. `_was_contention` is not used.
    fn record_outcome(&mut self, step: PipelineStep, success: bool, _was_contention: bool) {
        let next = if success {
            match self.exclusive_role {
                Some(role) if role == step => PipelineStep::Compress,
                _ => step,
            }
        } else {
            let idx = self.current_step.index();
            match self.direction {
                Direction::Forward if idx < 7 => PipelineStep::all()[idx + 1],
                Direction::Forward => PipelineStep::Compress,
                // NOTE(review): indices 6 and 7 never chase backward; presumably
                // they are Serialize/Compress in `PipelineStep::all()` — confirm.
                Direction::Backward if idx > 1 && idx != 7 && idx != 6 => {
                    PipelineStep::all()[idx - 1]
                }
                Direction::Backward => PipelineStep::Serialize,
            }
        };
        self.current_step = next;
    }

    /// Index of this worker.
    fn thread_id(&self) -> usize {
        self.thread_id
    }

    /// Total number of workers this scheduler was created with.
    fn num_threads(&self) -> usize {
        self.num_threads
    }

    /// Filter applied to every returned priority list.
    fn active_steps(&self) -> &ActiveSteps {
        &self.active_steps
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every pipeline step enabled.
    fn all() -> ActiveSteps {
        ActiveSteps::all()
    }

    #[test]
    fn test_reader_starts_on_compress() {
        let scheduler = BalancedChaseScheduler::new(0, 8, all());
        assert_eq!(scheduler.current_step, PipelineStep::Compress);
        assert_eq!(scheduler.exclusive_role, Some(PipelineStep::Read));
    }

    #[test]
    fn test_writer_starts_on_compress() {
        let scheduler = BalancedChaseScheduler::new(7, 8, all());
        assert_eq!(scheduler.current_step, PipelineStep::Compress);
        assert_eq!(scheduler.exclusive_role, Some(PipelineStep::Write));
    }

    #[test]
    fn test_exclusive_role_first_in_priorities() {
        let mut scheduler = BalancedChaseScheduler::new(0, 8, all());
        let bp = BackpressureState::default();
        let priorities = scheduler.get_priorities(bp);
        assert_eq!(priorities[0], PipelineStep::Read);
        assert!(
            priorities[1] == PipelineStep::Serialize || priorities[1] == PipelineStep::Compress
        );
    }

    #[test]
    fn test_pivot_to_compress_after_exclusive() {
        let mut scheduler = BalancedChaseScheduler::new(0, 8, all());
        scheduler.record_outcome(PipelineStep::Read, true, false);
        assert_eq!(scheduler.current_step, PipelineStep::Compress);
    }

    #[test]
    fn test_success_on_shared_step_keeps_that_step() {
        let mut scheduler = BalancedChaseScheduler::new(3, 8, all());
        scheduler.record_outcome(PipelineStep::Decode, true, false);
        assert_eq!(scheduler.current_step, PipelineStep::Decode);
    }

    #[test]
    fn test_middle_thread_no_exclusive_role() {
        let scheduler = BalancedChaseScheduler::new(3, 8, all());
        assert!(scheduler.exclusive_role.is_none());
    }

    #[test]
    fn test_bottleneck_always_in_top_priorities() {
        let mut scheduler = BalancedChaseScheduler::new(3, 8, all());
        let bp = BackpressureState::default();
        let priorities = scheduler.get_priorities(bp);
        let compress_pos = priorities.iter().position(|&s| s == PipelineStep::Compress);
        let serialize_pos = priorities.iter().position(|&s| s == PipelineStep::Serialize);
        assert!(compress_pos.expect("compress position should be Some") < 3);
        assert!(serialize_pos.expect("serialize position should be Some") < 3);
    }

    #[test]
    fn test_output_high_puts_compress_before_serialize() {
        // Thread 3 of 8 has no exclusive role, so the bottleneck pair leads.
        let mut scheduler = BalancedChaseScheduler::new(3, 8, all());
        let bp = BackpressureState {
            output_high: true,
            ..BackpressureState::default()
        };
        let priorities = scheduler.get_priorities(bp);
        assert_eq!(priorities[0], PipelineStep::Compress);
        assert_eq!(priorities[1], PipelineStep::Serialize);
    }

    #[test]
    fn test_priorities_have_no_duplicates() {
        let mut scheduler = BalancedChaseScheduler::new(0, 8, all());
        let priorities = scheduler.get_priorities(BackpressureState::default());
        for (i, step) in priorities.iter().enumerate() {
            assert!(
                !priorities[i + 1..].contains(step),
                "duplicate step in priorities"
            );
        }
    }

    #[test]
    fn test_direction_follows_backpressure() {
        let mut scheduler = BalancedChaseScheduler::new(3, 8, all());

        let starved = BackpressureState {
            input_low: true,
            read_done: false,
            ..BackpressureState::default()
        };
        let _ = scheduler.get_priorities(starved);
        assert!(matches!(scheduler.direction, Direction::Backward));

        let backed_up = BackpressureState {
            output_high: true,
            ..BackpressureState::default()
        };
        let _ = scheduler.get_priorities(backed_up);
        assert!(matches!(scheduler.direction, Direction::Forward));
    }

    #[test]
    #[should_panic(expected = "num_threads must be > 0")]
    fn test_zero_threads_panics() {
        let _ = BalancedChaseScheduler::new(0, 0, all());
    }
}