use std::sync::atomic::{AtomicU64, Ordering};
use rand::RngExt;
use rand::SeedableRng;
use rand::rngs::SmallRng;
use rand::seq::SliceRandom;
use super::{BackpressureState, Scheduler};
use crate::unified_pipeline::base::{ActiveSteps, PipelineStep};
/// Monotonic counter used to derive a distinct RNG seed for each scheduler instance.
static SEED_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Epsilon-greedy scheduler that learns, per worker thread, which pipeline
/// steps it tends to complete successfully and prioritizes those steps.
pub struct LearnedAffinityScheduler {
    // Identity of the owning worker thread.
    thread_id: usize,
    // Total number of worker threads in the pipeline.
    num_threads: usize,
    // Per-step count of successful outcomes, indexed by `PipelineStep::index()`.
    successes: [u64; 9],
    // Per-step count of recorded attempts (successful or not), same indexing.
    attempts: [u64; 9],
    // Base probability of choosing a random step ordering (exploration).
    exploration_rate: f64,
    // Multiplicative factor applied to the exploration rate per decay period.
    exploration_decay: f64,
    // Floor below which the decayed exploration rate never drops.
    min_exploration_rate: f64,
    // Total outcomes recorded; drives the exploration-rate decay schedule.
    total_attempts: u64,
    // Per-instance RNG for the explore/exploit coin flip and shuffles.
    rng: SmallRng,
    // Scratch buffer whose prefix is returned as a slice by `get_priorities`.
    priority_buffer: [PipelineStep; 9],
    // Which steps are enabled for this pipeline configuration.
    active_steps: ActiveSteps,
}
impl LearnedAffinityScheduler {
    /// Initial probability of taking a random (exploration) ordering.
    const DEFAULT_EXPLORATION_RATE: f64 = 0.3;
    /// Multiplicative decay applied to the exploration rate per 1000 attempts.
    const DEFAULT_EXPLORATION_DECAY: f64 = 0.95;
    /// Floor below which the exploration rate never drops.
    const DEFAULT_MIN_EXPLORATION: f64 = 0.05;

    /// Creates a scheduler for `thread_id` of `num_threads`, restricted to
    /// `active_steps`.
    ///
    /// Each instance receives a unique RNG seed so sibling schedulers do not
    /// produce correlated explore/shuffle sequences.
    #[must_use]
    pub fn new(thread_id: usize, num_threads: usize, active_steps: ActiveSteps) -> Self {
        // Pack the globally-unique counter into the high 32 bits and the
        // thread id into the low 32 bits before mixing. The previous
        // `counter + thread_id` scheme could collide (e.g. counter=1/tid=0 vs
        // counter=0/tid=1), handing two schedulers identical RNG streams.
        let unique = SEED_COUNTER
            .fetch_add(1, Ordering::Relaxed)
            .wrapping_shl(32)
            | (thread_id as u64 & 0xFFFF_FFFF);
        // SplitMix64-style multiplicative mix; the odd constant makes this a
        // bijection on u64, so distinct inputs yield distinct seeds.
        let seed = unique.wrapping_mul(0x9E37_79B9_7F4A_7C15);
        Self {
            thread_id,
            num_threads,
            successes: [0; 9],
            attempts: [0; 9],
            exploration_rate: Self::DEFAULT_EXPLORATION_RATE,
            exploration_decay: Self::DEFAULT_EXPLORATION_DECAY,
            min_exploration_rate: Self::DEFAULT_MIN_EXPLORATION,
            total_attempts: 0,
            rng: SmallRng::seed_from_u64(seed),
            priority_buffer: PipelineStep::all(),
            active_steps,
        }
    }

    /// Success ratio for the step at `step_idx`; returns a neutral 0.5 prior
    /// before any attempts have been recorded for that step.
    #[expect(
        clippy::cast_precision_loss,
        reason = "affinity ratio doesn't need full u64 precision"
    )]
    fn affinity(&self, step_idx: usize) -> f64 {
        if self.attempts[step_idx] == 0 {
            0.5
        } else {
            self.successes[step_idx] as f64 / self.attempts[step_idx] as f64
        }
    }

    /// Current epsilon: the base rate decayed once per 1000 recorded
    /// attempts, clamped below at `min_exploration_rate`.
    #[expect(
        clippy::cast_possible_truncation,
        reason = "decay_periods won't exceed i32 range in any practical run"
    )]
    fn current_exploration_rate(&self) -> f64 {
        let decay_periods = self.total_attempts / 1000;
        let decayed = self.exploration_rate * self.exploration_decay.powi(decay_periods as i32);
        decayed.max(self.min_exploration_rate)
    }
}
impl Scheduler for LearnedAffinityScheduler {
    /// Returns pipeline steps ordered by preference for this thread.
    ///
    /// With probability epsilon the ordering is a uniform shuffle
    /// (exploration); otherwise steps are sorted by descending learned
    /// success ratio (exploitation). The buffer is then filtered down to the
    /// currently active steps and its live prefix returned.
    fn get_priorities(&mut self, _backpressure: BackpressureState) -> &[PipelineStep] {
        let explore_rate = self.current_exploration_rate();
        if self.rng.random::<f64>() < explore_rate {
            // Explore: uniformly random order over every step.
            self.priority_buffer = PipelineStep::all();
            self.priority_buffer.shuffle(&mut self.rng);
        } else {
            // Exploit: rank step indices by learned affinity, best first.
            let mut ranked: [(f64, usize); 9] = [(0.0, 0); 9];
            for (i, slot) in ranked.iter_mut().enumerate() {
                *slot = (self.affinity(i), i);
            }
            // Affinities are finite ratios (never NaN), so `total_cmp` gives
            // the same descending order as the old partial_cmp/unwrap_or
            // dance; `sort_unstable` avoids allocation and stability is
            // irrelevant for distinct indices.
            ranked.sort_unstable_by(|a, b| b.0.total_cmp(&a.0));
            // Hoist the loop-invariant step table (was rebuilt 9x in-loop).
            let all_steps = PipelineStep::all();
            for (priority, (_, step_idx)) in ranked.iter().enumerate() {
                self.priority_buffer[priority] = all_steps[*step_idx];
            }
        }
        let n = self.active_steps.filter_in_place(&mut self.priority_buffer);
        &self.priority_buffer[..n]
    }

    /// Records the outcome of attempting `step`, feeding the affinity stats
    /// and the exploration-decay clock. Contention is currently ignored.
    fn record_outcome(&mut self, step: PipelineStep, success: bool, _was_contention: bool) {
        let idx = step.index();
        self.total_attempts += 1;
        self.attempts[idx] += 1;
        if success {
            self.successes[idx] += 1;
        }
    }

    fn thread_id(&self) -> usize {
        self.thread_id
    }

    fn num_threads(&self) -> usize {
        self.num_threads
    }

    fn active_steps(&self) -> &ActiveSteps {
        &self.active_steps
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_initial_exploration_rate() {
        let scheduler = LearnedAffinityScheduler::new(0, 8, ActiveSteps::all());
        assert!((scheduler.exploration_rate - 0.3).abs() < 0.001);
    }

    #[test]
    fn test_exploration_decay() {
        let mut scheduler = LearnedAffinityScheduler::new(0, 8, ActiveSteps::all());
        let initial_rate = scheduler.current_exploration_rate();
        // 2000 attempts = two full decay periods.
        scheduler.total_attempts = 2000;
        let decayed_rate = scheduler.current_exploration_rate();
        assert!(decayed_rate < initial_rate);
        assert!(decayed_rate >= scheduler.min_exploration_rate);
    }

    #[test]
    fn test_affinity_learning() {
        let mut scheduler = LearnedAffinityScheduler::new(0, 8, ActiveSteps::all());
        for _ in 0..100 {
            scheduler.record_outcome(PipelineStep::Read, true, false);
        }
        for _ in 0..100 {
            scheduler.record_outcome(PipelineStep::Write, false, false);
        }
        // Look the slots up via `index()` (as `record_outcome` does) rather
        // than hard-coding 0 and 8, so the test survives enum reordering.
        assert!(scheduler.affinity(PipelineStep::Read.index()) > 0.9);
        assert!(scheduler.affinity(PipelineStep::Write.index()) < 0.1);
    }

    #[test]
    fn test_minimum_exploration_rate() {
        let mut scheduler = LearnedAffinityScheduler::new(0, 8, ActiveSteps::all());
        // Far past every decay period: rate must have bottomed out.
        scheduler.total_attempts = 1_000_000;
        let rate = scheduler.current_exploration_rate();
        assert!((rate - scheduler.min_exploration_rate).abs() < 0.001);
    }
}