use crate::runtime::task::{Task, TaskId};
use crate::scheduler::data::random::RandomDataSource;
use crate::scheduler::data::DataSource;
use crate::scheduler::{Schedule, Scheduler};
use crate::seed_from_env;
use rand::rngs::OsRng;
use rand::seq::SliceRandom;
use rand::{RngCore, SeedableRng};
use rand_pcg::Pcg64Mcg;
use std::collections::{HashMap, HashSet};
use tracing::{trace, warn};
/// Opaque hash of a task signature; the key for per-signature event counts.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct SignatureHash(u64);

impl From<u64> for SignatureHash {
    /// Wraps a raw signature hash value.
    fn from(raw: u64) -> Self {
        Self(raw)
    }
}
/// Lifecycle phase of the URW scheduler (estimate first, then schedule with weights).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum UrwSchedulerState {
    /// No execution has been started yet.
    PreEstimation,
    /// The current execution picks tasks uniformly at random and records how often each
    /// task signature is scheduled.
    Estimating,
    /// Per-signature estimates are finalized; tasks are picked by weighted (URW) choice.
    Initialized,
}
/// Randomized scheduler that weights runnable tasks by their estimated remaining scheduling
/// events ("URW"; NOTE(review): expansion presumably "uniform random walk" — confirm).
///
/// The first execution schedules uniformly at random and records how many times each task
/// signature was scheduled. Every later execution uses those counts, with each parent's
/// count including its descendants', as selection weights.
#[derive(Debug)]
pub struct UrwRandomScheduler {
    // Total number of executions to produce, including the estimation execution.
    max_iterations: usize,
    // RNG for scheduling choices; re-seeded from `data_source` at the start of each execution.
    rng: Pcg64Mcg,
    // Number of executions started so far.
    iterations: usize,
    // Supplies per-execution seeds and the values returned by `next_u64`.
    data_source: RandomDataSource,
    // Remaining estimated events per task, indexed by numeric task id.
    // `None` until estimates are initialized, and again once the iteration budget is exhausted.
    task_event_counts: Option<Vec<usize>>,
    // Per-signature scheduling counts observed during estimation.
    // After initialization, each count also includes its descendants' counts.
    signature_event_counts: HashMap<SignatureHash, usize>,
    // Smallest per-signature count after initialization; fallback weight for unseen signatures.
    min_event_count: usize,
    // (parent signature, child signature) pairs, in order of each child's first observation.
    signature_parents: Vec<(SignatureHash, SignatureHash)>,
    // Current phase of the estimate-then-schedule lifecycle.
    state: UrwSchedulerState,
}
impl UrwRandomScheduler {
    /// Creates a scheduler that runs at most `max_iterations` executions, seeded from the
    /// operating-system RNG.
    ///
    /// `max_iterations` includes the initial estimation execution.
    pub fn new(max_iterations: usize) -> Self {
        Self::new_from_seed(OsRng.next_u64(), max_iterations)
    }

    /// Creates a scheduler that runs at most `max_iterations` executions, deriving its
    /// randomness from `seed`.
    ///
    /// The first execution is an estimation run; the remaining executions use the estimated
    /// per-signature event counts as scheduling weights.
    ///
    /// NOTE(review): `seed` is first passed through `seed_from_env`, which presumably lets an
    /// environment variable override it — confirm.
    pub fn new_from_seed(seed: u64, max_iterations: usize) -> Self {
        let seed = seed_from_env(seed);
        let rng = Pcg64Mcg::seed_from_u64(seed);
        Self {
            max_iterations,
            rng,
            iterations: 0,
            data_source: RandomDataSource::initialize(seed),
            task_event_counts: None,
            signature_event_counts: HashMap::new(),
            min_event_count: usize::MAX,
            signature_parents: Vec::new(),
            state: UrwSchedulerState::PreEstimation,
        }
    }

    /// Turns the counts observed during the estimation execution into URW weights and moves
    /// the scheduler to the `Initialized` state.
    ///
    /// Each child signature's count is added to its parent's count. Pairs are processed in
    /// reverse order of first observation. Assuming a child is always first observed after its
    /// parent, this lets a descendant's count propagate to every observed ancestor. The smallest
    /// resulting count becomes the fallback weight for signatures never seen during estimation.
    ///
    /// # Panics
    ///
    /// Panics if the scheduler is not in the `Estimating` state, or if the estimation
    /// execution recorded no scheduling events.
    fn initialize_estimates_from_observed_counts(&mut self) {
        assert_eq!(self.state, UrwSchedulerState::Estimating);
        trace!("Finished estimation of event counts for URW");
        trace!(
            "Estimated event counts for URW (pre-parent subsumption): {:?}",
            self.signature_event_counts
        );
        // During estimation a (parent, child) pair is pushed exactly when a signature is first
        // counted. So the child signatures must be distinct, and there must be one per counted
        // signature. (The previous check compared a map's key count with its own length, which
        // can never fail.)
        debug_assert_eq!(
            self.signature_parents
                .iter()
                .map(|(_, child)| child)
                .collect::<HashSet<_>>()
                .len(),
            self.signature_parents.len(),
            "each signature should be recorded as a child exactly once"
        );
        debug_assert_eq!(self.signature_parents.len(), self.signature_event_counts.len());

        for (parent_sig, child_sig) in self.signature_parents.iter().rev() {
            let child_ct = *self
                .signature_event_counts
                .get(child_sig)
                .expect("every recorded child signature has an event count");
            // Only parents that were themselves observed get credited; unseen parents are skipped.
            if let Some(parent_ct) = self.signature_event_counts.get_mut(parent_sig) {
                *parent_ct += child_ct;
            }
        }
        self.min_event_count = self
            .signature_event_counts
            .values()
            .copied()
            .min()
            .expect("URW estimation execution recorded no scheduling events");
        trace!(
            "Estimated event counts for URW (post-parent subsumption): {:?}",
            self.signature_event_counts
        );
        self.task_event_counts = Some(Vec::new());
        self.state = UrwSchedulerState::Initialized;
    }

    /// Picks the next task by weighted random choice, using each runnable task's estimated
    /// remaining events as its weight.
    ///
    /// A task seen for the first time gets its signature's estimated count, or
    /// `min_event_count` if the signature is unknown. That count is then deducted from its
    /// parent's remaining budget. The chosen task's count is decremented by one. All counts are
    /// clamped to at least 1, so every runnable task stays selectable.
    ///
    /// # Panics
    ///
    /// Panics if estimates have not been initialized, if `runnable` is empty, or if a task id
    /// appears before every smaller id has been seen.
    fn next_task_urw(&mut self, runnable: &[&Task]) -> Option<TaskId> {
        let task_event_counts = self
            .task_event_counts
            .as_mut()
            .expect("URW task event counts must be initialized before scheduling");
        for t in runnable {
            let tid: usize = get_tid(t);
            if tid == task_event_counts.len() {
                // First time this task is runnable: seed its budget from its signature's estimate.
                let child_events = self
                    .signature_event_counts
                    .get(&t.signature.signature_hash().into())
                    .copied()
                    .unwrap_or_else(|| {
                        warn!(
                            "No event count for spawn of task with signature {}",
                            t.signature.signature_hash()
                        );
                        self.min_event_count
                    });
                task_event_counts.push(child_events);
                // The parent's estimate already includes this child's events; move them out
                // of the parent's remaining budget.
                if let Some(ptid) = t.parent_task_id() {
                    let ptid: usize = ptid.into();
                    task_event_counts[ptid] =
                        task_event_counts[ptid].saturating_sub(child_events).max(1);
                }
            } else if tid > task_event_counts.len() {
                panic!("TID's expected to be spawned in ascending order in increments of 1");
            }
            assert!(task_event_counts[tid] >= 1);
        }
        let next_tid = runnable
            .choose_weighted(&mut self.rng, |t| task_event_counts[get_tid(t)])
            .expect("runnable must be non-empty and all URW weights are at least 1")
            .id();
        let next_tid_usize: usize = next_tid.into();
        // Consume one event from the chosen task, keeping it selectable in later steps.
        task_event_counts[next_tid_usize] =
            task_event_counts[next_tid_usize].saturating_sub(1).max(1);
        trace!("URW remaining event counts: {:?}", task_event_counts);
        Some(next_tid)
    }
}
impl Scheduler for UrwRandomScheduler {
    /// Starts the next execution, or returns `None` once `max_iterations` executions have begun.
    ///
    /// The first call begins the estimation execution. The second call converts the observed
    /// counts into URW weights. Each later call resets the per-task counts. Every execution gets
    /// a fresh seed from the data source, and that seed is recorded in the returned `Schedule`.
    fn new_execution(&mut self) -> Option<Schedule> {
        let budget_exhausted = self.iterations >= self.max_iterations;
        if budget_exhausted {
            // No further executions will be produced; drop the estimation state.
            self.signature_event_counts.clear();
            self.signature_parents.clear();
            self.task_event_counts = None;
            return None;
        }

        match self.state {
            UrwSchedulerState::Initialized => {
                self.task_event_counts.as_mut().unwrap().clear();
            }
            UrwSchedulerState::Estimating => {
                self.initialize_estimates_from_observed_counts();
            }
            UrwSchedulerState::PreEstimation => {
                self.state = UrwSchedulerState::Estimating;
            }
        }
        self.iterations += 1;

        let fresh_seed = self.data_source.reinitialize();
        self.rng = Pcg64Mcg::seed_from_u64(fresh_seed);
        Some(Schedule::new(fresh_seed))
    }

    /// Chooses the next task to run.
    ///
    /// During estimation, the choice is uniform and each scheduling of a signature is counted.
    /// The first time a signature is seen, its (parent, child) signature pair is recorded.
    /// After initialization, the weighted URW choice is used instead.
    fn next_task(
        &mut self,
        runnable: &[&Task],
        _current: Option<TaskId>,
        _is_yielding: bool,
    ) -> Option<TaskId> {
        match self.state {
            UrwSchedulerState::Initialized => self.next_task_urw(runnable),
            UrwSchedulerState::Estimating => {
                let chosen = *runnable.choose(&mut self.rng).unwrap();
                let parents = &mut self.signature_parents;
                *self
                    .signature_event_counts
                    .entry(chosen.signature.signature_hash().into())
                    .or_insert_with(|| {
                        // First sighting of this signature: remember who spawned it.
                        parents.push((
                            chosen.signature.parent_signature_hash().into(),
                            chosen.signature.signature_hash().into(),
                        ));
                        0
                    }) += 1;
                Some(chosen.id())
            }
            UrwSchedulerState::PreEstimation => unreachable!(),
        }
    }

    /// Returns the next value from the data source.
    fn next_u64(&mut self) -> u64 {
        self.data_source.next_u64()
    }
}
/// Converts a task's id to a `usize`, for indexing per-task vectors.
#[inline]
fn get_tid(task: &Task) -> usize {
    Into::<usize>::into(task.id())
}