use super::{BackpressureState, Scheduler};
use crate::unified_pipeline::base::{ActiveSteps, PipelineStep};
/// Scheduler that ranks pipeline steps by an upper-confidence-bound score.
///
/// Each step is scored from its observed success rate plus an exploration
/// bonus that shrinks as the step accumulates attempts (see `ucb_score`).
/// All per-step arrays hold nine entries and are indexed by
/// `PipelineStep::index()`.
pub struct UCBScheduler {
    // --- Identity -------------------------------------------------------
    /// Value reported by `Scheduler::thread_id`.
    thread_id: usize,
    /// Value reported by `Scheduler::num_threads`.
    num_threads: usize,

    // --- Bandit statistics ----------------------------------------------
    /// Multiplier applied to the exploration bonus in `ucb_score`.
    exploration_c: f64,
    /// Number of outcomes recorded across all steps.
    total_attempts: u64,
    /// Number of outcomes recorded per step.
    attempts: [u64; 9],
    /// Number of successful outcomes recorded per step.
    successes: [u64; 9],

    // --- Output state ---------------------------------------------------
    /// Scratch buffer rewritten on every `get_priorities` call; the returned
    /// slice borrows from it.
    priority_buffer: [PipelineStep; 9],
    /// Step set used to filter `priority_buffer` before it is returned.
    active_steps: ActiveSteps,
}
impl UCBScheduler {
    /// Default exploration multiplier (approximately the square root of 2).
    const DEFAULT_EXPLORATION_C: f64 = 1.414;

    /// Builds a scheduler with zeroed statistics and the default exploration
    /// multiplier.
    ///
    /// `thread_id` and `num_threads` are stored as-is and echoed back through
    /// the `Scheduler` accessors; `active_steps` is used to filter the
    /// priority list produced by `get_priorities`.
    #[must_use]
    pub fn new(thread_id: usize, num_threads: usize, active_steps: ActiveSteps) -> Self {
        let zeroed_counts = [0_u64; 9];
        let initial_order = PipelineStep::all();
        Self {
            thread_id,
            num_threads,
            total_attempts: 0,
            attempts: zeroed_counts,
            successes: zeroed_counts,
            exploration_c: Self::DEFAULT_EXPLORATION_C,
            priority_buffer: initial_order,
            active_steps,
        }
    }

    /// Computes the upper-confidence-bound score for the step at `step_idx`.
    ///
    /// A step with no recorded attempts scores `f64::INFINITY`, so it outranks
    /// every step that has been tried. Otherwise the score is
    /// `successes / attempts + c * sqrt(ln(total_attempts) / attempts)`.
    ///
    /// # Panics
    ///
    /// Panics if `step_idx` is not a valid index into the nine-entry
    /// per-step arrays.
    // Counters are converted to f64 for the arithmetic; the precision-loss
    // lint is silenced for those casts.
    #[allow(clippy::cast_precision_loss)]
    fn ucb_score(&self, step_idx: usize) -> f64 {
        match self.attempts[step_idx] {
            // Never tried: rank ahead of everything that has been.
            0 => f64::INFINITY,
            tried => {
                let pulls = tried as f64;
                let success_rate = self.successes[step_idx] as f64 / pulls;
                // `max(1)` keeps the logarithm's argument at least 1, so the
                // value under the square root is never negative.
                let log_total = (self.total_attempts.max(1) as f64).ln();
                let bonus = self.exploration_c * (log_total / pulls).sqrt();
                success_rate + bonus
            }
        }
    }
}
impl Scheduler for UCBScheduler {
    /// Returns the pipeline steps ordered from highest to lowest UCB score.
    ///
    /// The backpressure state is not consulted. The full ranking is written
    /// into `priority_buffer`, then handed to `ActiveSteps::filter_in_place`;
    /// the first `n` entries (where `n` is that call's return value) are
    /// returned.
    // NOTE(review): presumably `filter_in_place` compacts the active steps to
    // the front of the buffer while preserving order — confirm against its
    // definition.
    fn get_priorities(&mut self, _backpressure: BackpressureState) -> &[PipelineStep] {
        // Look the step table up once rather than once per output slot.
        let all_steps = PipelineStep::all();

        let mut scores: [(f64, usize); 9] = [(0.0, 0); 9];
        for (i, score) in scores.iter_mut().enumerate() {
            *score = (self.ucb_score(i), i);
        }

        // Descending by score. `total_cmp` is a total order over every f64,
        // so the comparator stays consistent even if a score were ever NaN
        // (`sort_by` may panic on comparators that are not a total order).
        // The sort is stable, so tied steps keep ascending index order.
        scores.sort_by(|a, b| b.0.total_cmp(&a.0));

        for (slot, &(_, step_idx)) in self.priority_buffer.iter_mut().zip(scores.iter()) {
            *slot = all_steps[step_idx];
        }

        let n = self.active_steps.filter_in_place(&mut self.priority_buffer);
        &self.priority_buffer[..n]
    }

    /// Records one attempt for `step`, counting it as a success when
    /// `success` is true.
    ///
    /// The contention flag is ignored by this scheduler.
    fn record_outcome(&mut self, step: PipelineStep, success: bool, _was_contention: bool) {
        let idx = step.index();
        self.total_attempts += 1;
        self.attempts[idx] += 1;
        if success {
            self.successes[idx] += 1;
        }
    }

    /// Returns the thread id supplied at construction.
    fn thread_id(&self) -> usize {
        self.thread_id
    }

    /// Returns the thread count supplied at construction.
    fn num_threads(&self) -> usize {
        self.num_threads
    }

    /// Returns the step set used to filter the priority list.
    fn active_steps(&self) -> &ActiveSteps {
        &self.active_steps
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for `ActiveSteps::all()`.
    fn all() -> ActiveSteps {
        ActiveSteps::all()
    }

    /// A step with recorded attempts must rank behind steps that have none,
    /// because untried steps score `f64::INFINITY`.
    #[test]
    fn test_unexplored_steps_prioritized() {
        let mut scheduler = UCBScheduler::new(0, 8, all());
        for _ in 0..100 {
            scheduler.record_outcome(PipelineStep::Read, true, false);
        }
        let bp = BackpressureState::default();
        let priorities = scheduler.get_priorities(bp);
        let read_pos = priorities
            .iter()
            .position(|&s| s == PipelineStep::Read)
            .expect("PipelineStep::Read should be present in priorities");
        assert!(read_pos > 0, "Well-explored Read should not be first");
    }

    /// With equal attempt counts, the step that always succeeded must score
    /// higher than the step that always failed.
    #[test]
    fn test_high_success_rate_preferred() {
        let mut scheduler = UCBScheduler::new(0, 8, all());
        for _ in 0..100 {
            scheduler.record_outcome(PipelineStep::Read, true, false);
        }
        for _ in 0..100 {
            scheduler.record_outcome(PipelineStep::Write, false, false);
        }
        // Resolve indices through `index()` — the same mapping
        // `record_outcome` uses — rather than hard-coding enum positions.
        let read_score = scheduler.ucb_score(PipelineStep::Read.index());
        let write_score = scheduler.ucb_score(PipelineStep::Write.index());
        assert!(read_score > write_score);
    }

    /// With the same success rate (5/10 vs 50/100), the step with fewer
    /// attempts must score higher because its exploration bonus is larger.
    #[test]
    fn test_exploration_decreases_with_attempts() {
        let mut scheduler = UCBScheduler::new(0, 8, all());
        scheduler.total_attempts = 1000;
        scheduler.attempts[0] = 10;
        scheduler.attempts[1] = 100;
        scheduler.successes[0] = 5;
        scheduler.successes[1] = 50;
        let score_few = scheduler.ucb_score(0);
        let score_many = scheduler.ucb_score(1);
        assert!(score_few > score_many);
    }
}