use rand::SeedableRng;
use rand::rngs::SmallRng;
use rand_distr::{Beta, Distribution};
use std::sync::atomic::{AtomicU64, Ordering};
use super::{BackpressureState, Scheduler};
use crate::unified_pipeline::base::{ActiveSteps, PipelineStep};
/// Process-wide counter that is incremented once for every
/// `ThompsonSamplingScheduler` constructed; the value fetched by the
/// constructor is folded into that scheduler's RNG seed.
static SEED_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Scheduler that orders pipeline steps by Thompson sampling.
///
/// Every one of the nine pipeline steps carries a Beta(alpha, beta) belief.
/// Each call to `get_priorities` draws one sample per step and ranks the
/// steps by the drawn value, highest first; `record_outcome` updates the
/// belief of the step that was just attempted.
pub struct ThompsonSamplingScheduler {
/// Identifier of the thread this scheduler was created for; reported by
/// `Scheduler::thread_id` and mixed into the RNG seed at construction.
thread_id: usize,
/// Thread count passed in at construction; reported by
/// `Scheduler::num_threads`.
num_threads: usize,
/// Per-step Beta `alpha` parameters, indexed by `PipelineStep::index()`.
/// Start at 1.0 and grow by 1.0 for every recorded success.
alphas: [f64; 9],
/// Per-step Beta `beta` parameters, indexed by `PipelineStep::index()`.
/// Start at 1.0 and grow by 1.0 for every recorded failure.
betas: [f64; 9],
/// Random source used for the Beta draws.
rng: SmallRng,
/// Scratch array rewritten on every `get_priorities` call; the returned
/// slice borrows a prefix of it.
priority_buffer: [PipelineStep; 9],
/// Step set used to filter `priority_buffer` before it is returned.
active_steps: ActiveSteps,
}
impl ThompsonSamplingScheduler {
    /// Creates a scheduler with a uniform Beta(1, 1) prior on every step.
    ///
    /// * `thread_id` - identifier of the owning thread; stored and mixed into
    ///   the RNG seed.
    /// * `num_threads` - thread count, stored and reported back unchanged.
    /// * `active_steps` - step set applied to every priority list produced.
    ///
    /// Each scheduler gets its own RNG seed derived from a process-wide
    /// counter, so two schedulers do not replay the same sample stream.
    #[must_use]
    pub fn new(thread_id: usize, num_threads: usize, active_steps: ActiveSteps) -> Self {
        // 64-bit golden-ratio constant (the SplitMix64 increment). It is odd,
        // so multiplying by it modulo 2^64 is a bijection.
        const SEED_MULTIPLIER: u64 = 0x9E37_79B9_7F4A_7C15;

        // Only the uniqueness of the fetched value matters, which an atomic
        // read-modify-write guarantees under any memory ordering.
        let ticket = SEED_COUNTER.fetch_add(1, Ordering::Relaxed);

        // Spread the unique ticket over the 64-bit space FIRST, then add the
        // thread id. Adding the thread id before the multiply would let
        // different (ticket, thread_id) pairs with the same sum — e.g. (0, 1)
        // and (1, 0) — collapse onto one seed and share an RNG stream.
        // Consecutive tickets land far apart after the multiply, so a small
        // thread id offset cannot bridge the gap between them.
        let seed = ticket
            .wrapping_mul(SEED_MULTIPLIER)
            .wrapping_add(thread_id as u64);

        Self {
            thread_id,
            num_threads,
            alphas: [1.0; 9],
            betas: [1.0; 9],
            rng: SmallRng::seed_from_u64(seed),
            priority_buffer: PipelineStep::all(),
            active_steps,
        }
    }

    /// Draws one sample from Beta(`alpha`, `beta`) using this scheduler's RNG.
    ///
    /// Both parameters are clamped to `[0.001, 10000.0]` before the
    /// distribution is built. If the distribution still cannot be
    /// constructed, the mean `alpha / (alpha + beta)` of the clamped
    /// parameters is returned instead of a random draw.
    ///
    /// NOTE(review): the two parameters are clamped independently, so once
    /// either one exceeds the upper bound the ratio between them is no longer
    /// preserved in the sampled distribution.
    fn sample_beta(&mut self, alpha: f64, beta: f64) -> f64 {
        const MIN_SHAPE: f64 = 0.001;
        const MAX_SHAPE: f64 = 10000.0;

        let alpha = alpha.clamp(MIN_SHAPE, MAX_SHAPE);
        let beta = beta.clamp(MIN_SHAPE, MAX_SHAPE);

        match Beta::new(alpha, beta) {
            Ok(dist) => dist.sample(&mut self.rng),
            // Deterministic fallback: the distribution's mean.
            Err(_) => alpha / (alpha + beta),
        }
    }
}
impl Scheduler for ThompsonSamplingScheduler {
    /// Returns the pipeline steps ranked by a fresh Thompson sample.
    ///
    /// One Beta draw is taken per step (in step-index order), the steps are
    /// sorted by their draw in descending order, and the result is passed
    /// through `active_steps.filter_in_place`. The returned slice is the
    /// first `n` entries of the internal buffer, where `n` is the value that
    /// call returns (presumably the number of active steps kept at the
    /// front — confirm against `ActiveSteps`).
    ///
    /// The backpressure state is not consulted by this scheduler.
    fn get_priorities(&mut self, _backpressure: BackpressureState) -> &[PipelineStep] {
        let mut samples: [(f64, usize); 9] = [(0.0, 0); 9];
        for (step_idx, slot) in samples.iter_mut().enumerate() {
            let draw = self.sample_beta(self.alphas[step_idx], self.betas[step_idx]);
            *slot = (draw, step_idx);
        }

        // Descending by sampled value. `total_cmp` is a total order even if a
        // draw were NaN, whereas treating incomparable values as `Equal`
        // breaks transitivity, which `sort_by` is allowed to punish with a
        // panic or an unspecified ordering.
        samples.sort_by(|a, b| b.0.total_cmp(&a.0));

        // Look the step table up once instead of once per slot.
        let all_steps = PipelineStep::all();
        for (slot, &(_, step_idx)) in self.priority_buffer.iter_mut().zip(samples.iter()) {
            *slot = all_steps[step_idx];
        }

        let n = self.active_steps.filter_in_place(&mut self.priority_buffer);
        &self.priority_buffer[..n]
    }

    /// Updates the belief for `step`: a success adds 1.0 to its alpha, a
    /// failure adds 1.0 to its beta. The contention flag is ignored.
    fn record_outcome(&mut self, step: PipelineStep, success: bool, _was_contention: bool) {
        let idx = step.index();
        if success {
            self.alphas[idx] += 1.0;
        } else {
            self.betas[idx] += 1.0;
        }
    }

    /// Returns the thread id supplied at construction.
    fn thread_id(&self) -> usize {
        self.thread_id
    }

    /// Returns the thread count supplied at construction.
    fn num_threads(&self) -> usize {
        self.num_threads
    }

    /// Returns the step set supplied at construction.
    fn active_steps(&self) -> &ActiveSteps {
        &self.active_steps
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parameter-array slot the assertions below use for `PipelineStep::Read`.
    const READ_SLOT: usize = 0;
    /// Parameter-array slot the assertions below use for `PipelineStep::Write`.
    const WRITE_SLOT: usize = 8;

    /// Builds the scheduler every test starts from: thread 0 of 8, all steps active.
    fn fresh_scheduler() -> ThompsonSamplingScheduler {
        ThompsonSamplingScheduler::new(0, 8, ActiveSteps::all())
    }

    /// True when `actual` is within `f64::EPSILON` of `expected`.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < f64::EPSILON
    }

    /// A new scheduler starts from Beta(1, 1).
    #[test]
    fn test_initial_uniform_prior() {
        let sched = fresh_scheduler();
        assert!(approx_eq(sched.alphas[READ_SLOT], 1.0));
        assert!(approx_eq(sched.betas[READ_SLOT], 1.0));
    }

    /// A success bumps alpha and leaves beta alone.
    #[test]
    fn test_update_on_success() {
        let mut sched = fresh_scheduler();
        sched.record_outcome(PipelineStep::Read, true, false);
        assert!(approx_eq(sched.alphas[READ_SLOT], 2.0));
        assert!(approx_eq(sched.betas[READ_SLOT], 1.0));
    }

    /// A failure bumps beta and leaves alpha alone.
    #[test]
    fn test_update_on_failure() {
        let mut sched = fresh_scheduler();
        sched.record_outcome(PipelineStep::Read, false, false);
        assert!(approx_eq(sched.alphas[READ_SLOT], 1.0));
        assert!(approx_eq(sched.betas[READ_SLOT], 2.0));
    }

    /// With every step active, the priority list contains all nine steps.
    #[test]
    fn test_get_priorities_returns_all_steps() {
        let mut sched = fresh_scheduler();
        let ranked = sched.get_priorities(BackpressureState::default());
        assert_eq!(ranked.len(), 9);
    }

    /// Repeated outcomes accumulate in the matching parameter.
    #[test]
    fn test_learned_preference() {
        let mut sched = fresh_scheduler();
        (0..100).for_each(|_| sched.record_outcome(PipelineStep::Read, true, false));
        (0..100).for_each(|_| sched.record_outcome(PipelineStep::Write, false, false));
        assert!(sched.alphas[READ_SLOT] > 50.0);
        assert!(sched.betas[WRITE_SLOT] > 50.0);
    }
}