use crate::runtime::task::{Task, TaskId, DEFAULT_INLINE_TASKS};
use crate::scheduler::data::random::RandomDataSource;
use crate::scheduler::data::DataSource;
use crate::scheduler::{Schedule, Scheduler};
use rand::rngs::OsRng;
use rand::seq::{index::sample, SliceRandom};
use rand::{Rng, RngCore, SeedableRng};
use rand_pcg::Pcg64Mcg;
use std::collections::{HashMap, HashSet};
/// A randomized, priority-based scheduler in the style of PCT.
///
/// Every task holds a distinct priority (a smaller value means a higher priority). At each
/// scheduling decision the runnable task with the highest priority runs. At a small number of
/// randomly sampled "change points", the running task is demoted to the lowest priority.
#[derive(Debug)]
pub struct PctScheduler {
    // Number of executions after which `new_execution` returns `None`.
    max_iterations: usize,
    // Depth bound: each execution uses at most `max_depth - 1` change points.
    max_depth: usize,
    // Number of executions started so far.
    iterations: usize,
    // Current priority of every known task; the smallest value is scheduled first.
    priorities: HashMap<TaskId, usize>,
    // Next unused priority value. It is larger than every assigned value, so handing it out
    // gives a task the lowest priority.
    next_priority: usize,
    // Step indices (each in `1..max_steps`) at which the running task is demoted.
    change_points: Vec<usize>,
    // Largest number of multi-runnable scheduling decisions seen in any execution so far.
    max_steps: usize,
    // Number of scheduling decisions with more than one runnable task in the current execution.
    steps: usize,
    // Drives priority shuffling, new-task priority swaps, and change-point sampling.
    rng: Pcg64Mcg,
    // Source of values returned by `next_u64`; reinitialized at the start of each execution.
    data_source: RandomDataSource,
}
impl PctScheduler {
    /// Creates a scheduler that runs at most `max_iterations` executions, each with at most
    /// `max_depth - 1` priority change points, seeded from the operating system's RNG.
    ///
    /// If the `SHUTTLE_RANDOM_SEED` environment variable is set, its value is used as the seed
    /// instead (see [`PctScheduler::new_from_seed`]).
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`PctScheduler::new_from_seed`].
    pub fn new(max_depth: usize, max_iterations: usize) -> Self {
        Self::new_from_seed(OsRng.next_u64(), max_depth, max_iterations)
    }

    /// Creates a scheduler that runs at most `max_iterations` executions, each with at most
    /// `max_depth - 1` priority change points, seeded with `seed`.
    ///
    /// When the `SHUTTLE_RANDOM_SEED` environment variable is set, its value overrides `seed`.
    /// The resulting seed initializes both the scheduling RNG and the random data source.
    ///
    /// # Panics
    ///
    /// Panics if `max_depth` is 0. Also panics if `SHUTTLE_RANDOM_SEED` is set to a value that
    /// is not valid Unicode or does not parse as a `u64`.
    pub fn new_from_seed(seed: u64, max_depth: usize, max_iterations: usize) -> Self {
        assert!(max_depth > 0);

        let seed = match std::env::var("SHUTTLE_RANDOM_SEED") {
            Ok(s) => match s.parse::<u64>() {
                Ok(seed) => {
                    tracing::info!(
                        "Initializing PctScheduler with the seed provided by SHUTTLE_RANDOM_SEED: {}",
                        seed
                    );
                    seed
                }
                Err(err) => {
                    panic!("The seed provided by SHUTTLE_RANDOM_SEED is not a valid u64: {}", err)
                }
            },
            // The variable is set but unreadable. Falling back to `seed` here would silently
            // ignore a seed the user explicitly tried to pin, so fail loudly instead.
            Err(std::env::VarError::NotUnicode(raw)) => {
                panic!("The seed provided by SHUTTLE_RANDOM_SEED is not valid Unicode: {:?}", raw)
            }
            Err(std::env::VarError::NotPresent) => seed,
        };

        let rng = Pcg64Mcg::seed_from_u64(seed);
        Self {
            max_iterations,
            max_depth,
            iterations: 0,
            // Pre-register the first tasks with distinct priorities equal to their ids.
            priorities: (0..DEFAULT_INLINE_TASKS).map(|i| (TaskId::from(i), i)).collect(),
            next_priority: DEFAULT_INLINE_TASKS,
            change_points: vec![],
            max_steps: 0,
            steps: 0,
            rng,
            data_source: RandomDataSource::initialize(seed),
        }
    }
}
impl Scheduler for PctScheduler {
    /// Starts a new execution, or returns `None` once `max_iterations` executions have run.
    ///
    /// The first execution has no change points; it only measures how many scheduling
    /// decisions with more than one runnable task the test makes (`max_steps`). Every later
    /// execution does two things:
    /// - assigns all known tasks a fresh random permutation of priorities;
    /// - samples up to `max_depth - 1` distinct change points from `1..max_steps`.
    ///
    /// # Panics
    ///
    /// On any execution after the first, panics if no earlier execution ever had more than
    /// one runnable task at a scheduling decision.
    fn new_execution(&mut self) -> Option<Schedule> {
        if self.iterations >= self.max_iterations {
            return None;
        }

        self.steps = 0;

        if self.iterations > 0 {
            assert!(self.max_steps > 0, "test closure did not exercise any concurrency");

            // Every task must hold a distinct priority. Check the priority *values*: the
            // `(TaskId, priority)` pairs of a map are always unique because its keys are,
            // so hashing the pairs could never detect a duplicate priority.
            debug_assert_eq!(
                self.priorities.values().collect::<HashSet<_>>().len(),
                self.priorities.len()
            );

            // Give every known task (ids `0..len`) a uniformly random, distinct priority.
            let mut shuffled = (0..self.priorities.len()).collect::<Vec<_>>();
            shuffled.shuffle(&mut self.rng);
            for (i, priority) in shuffled.into_iter().enumerate() {
                let old = self.priorities.insert(TaskId::from(i), priority);
                debug_assert!(old.is_some(), "priority queue invariant");
            }
            self.next_priority = self.priorities.len();

            // Draw distinct indices from `0..max_steps - 1` and shift them into
            // `1..max_steps`, so step 0 is never a change point.
            let num_points = (self.max_depth - 1).min(self.max_steps - 1);
            self.change_points = sample(&mut self.rng, self.max_steps - 1, num_points)
                .iter()
                .map(|v| v + 1)
                .collect::<Vec<_>>();
        }

        self.iterations += 1;
        Some(Schedule::new(self.data_source.reinitialize()))
    }

    /// Chooses the runnable task with the highest priority (the smallest priority value).
    ///
    /// Tasks with ids not seen before are first registered. Each one swaps priorities with a
    /// randomly chosen task, or keeps the lowest priority for itself. When more than one task
    /// is runnable, the decision counts as a step. At a change point, or when the current task
    /// is yielding, the current task is demoted to the lowest priority before the choice.
    ///
    /// # Panics
    ///
    /// Panics if `runnable` is empty.
    fn next_task(&mut self, runnable: &[&Task], current: Option<TaskId>, is_yielding: bool) -> Option<TaskId> {
        // Register newly created tasks. Swapping a fresh priority into a randomly chosen
        // slot keeps every priority distinct.
        let max_known_task = self.priorities.len();
        let max_new_task = usize::from(
            runnable
                .iter()
                .map(|t| t.id())
                .max()
                .expect("next_task requires at least one runnable task"),
        );
        for new_task_id in max_known_task..1 + max_new_task {
            let new_task_id = TaskId::from(new_task_id);
            // Known ids are `0..len` and the new id equals `len`. Drawing from `1..=len` therefore
            // gives the new task a chance to keep the lowest priority for itself.
            // NOTE(review): task 0 can never be chosen as the swap target — confirm intended.
            let target_task_id = TaskId::from(self.rng.gen_range(0..self.priorities.len()) + 1);
            let new_task_priority = if target_task_id == new_task_id {
                self.next_priority
            } else {
                self.priorities
                    .insert(target_task_id, self.next_priority)
                    .expect("priority queue invariant")
            };
            let old = self.priorities.insert(new_task_id, new_task_priority);
            debug_assert!(old.is_none(), "priority queue invariant");
            self.next_priority += 1;
        }

        // Only decisions with a real choice count as steps.
        if runnable.len() > 1 {
            if self.change_points.contains(&self.steps) || is_yielding {
                // Demote the running task: `next_priority` exceeds every assigned priority.
                let current = current.expect("self.steps > 0 should mean a task has run");
                let old = self.priorities.insert(current, self.next_priority);
                debug_assert!(old.is_some(), "priority queue invariant");
                self.next_priority += 1;
            }
            self.steps += 1;
            self.max_steps = self.max_steps.max(self.steps);
        }

        Some(
            runnable
                .iter()
                .min_by_key(|t| self.priorities.get(&t.id()))
                .expect("priority queue invariant")
                .id(),
        )
    }

    /// Returns the next value from this execution's random data source.
    fn next_u64(&mut self) -> u64 {
        self.data_source.next_u64()
    }
}