use crate::state::{TaskRecord, TaskState};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
/// An entry in the scheduler's ready heap.
///
/// Ordered so that `BinaryHeap` (a max-heap) pops the highest `priority` first
/// and, among equal priorities, the entry with the smallest `seq` first.
/// `id` takes no part in equality or ordering.
struct Prioritized {
    /// Larger values are popped sooner.
    priority: u8,
    /// Insertion counter used as a first-in-first-out tie-breaker.
    seq: u64,
    /// Id of the task this entry refers to.
    id: String,
}

impl PartialEq for Prioritized {
    fn eq(&self, other: &Self) -> bool {
        // Compare exactly the fields `Ord::cmp` uses so `Eq` and `Ord` agree.
        (self.priority, self.seq) == (other.priority, other.seq)
    }
}

impl Eq for Prioritized {}

impl Ord for Prioritized {
    fn cmp(&self, other: &Self) -> Ordering {
        match self.priority.cmp(&other.priority) {
            // Reversed on purpose: a lower `seq` must compare as "greater" so the
            // max-heap yields earlier-inserted entries first.
            Ordering::Equal => other.seq.cmp(&self.seq),
            unequal => unequal,
        }
    }
}

impl PartialOrd for Prioritized {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Tracks task records and a priority-ordered queue of tasks marked ready.
#[derive(Default)]
pub struct Scheduler {
    /// Every known task record, keyed by task id.
    records: HashMap<String, TaskRecord>,
    /// Queued ready entries. Entries can become stale when a task's state is
    /// changed without going through the queue; they are skipped on dequeue.
    ready: BinaryHeap<Prioritized>,
    /// Next insertion counter handed to a queued entry.
    seq: u64,
}
impl Scheduler {
    /// Creates an empty scheduler with no records and an empty ready queue.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a scheduler over an existing set of task records.
    ///
    /// The ready queue starts empty. Records passed in are not enqueued, even if
    /// their state is already `TaskState::Ready`, so `next_ready` will not return
    /// them until `mark_ready` queues them.
    /// NOTE(review): re-queuing such a record via `mark_ready` only works if
    /// `TaskRecord::transition` accepts Ready -> Ready — confirm if restoring
    /// persisted schedulers is a supported use case.
    pub fn with_records(records: HashMap<String, TaskRecord>) -> Self {
        Scheduler {
            records,
            ready: BinaryHeap::new(),
            seq: 0,
        }
    }

    /// Returns the record for `id`, inserting `TaskRecord::new(id)` if none exists.
    pub fn ensure_record(&mut self, id: &str) -> &mut TaskRecord {
        self.records
            .entry(id.to_string())
            .or_insert_with(|| TaskRecord::new(id))
    }

    /// Returns the record for `id`, if one exists.
    pub fn record(&self, id: &str) -> Option<&TaskRecord> {
        self.records.get(id)
    }

    /// Returns all records, keyed by task id.
    pub fn records(&self) -> &HashMap<String, TaskRecord> {
        &self.records
    }

    /// Returns mutable access to all records.
    ///
    /// State changes made through this map do not touch the ready queue. A queued
    /// task moved out of `Ready` here is skipped by `next_ready` and `has_ready`.
    pub fn records_mut(&mut self) -> &mut HashMap<String, TaskRecord> {
        &mut self.records
    }

    /// Returns the current state of `id`, or `None` if the task is unknown.
    pub fn state(&self, id: &str) -> Option<TaskState> {
        self.records.get(id).map(|r| r.state)
    }

    /// Requests a state change for `id` via `TaskRecord::transition`.
    ///
    /// Returns `false` when `id` is unknown. Otherwise it returns whatever
    /// `TaskRecord::transition` returns (presumably whether the change was
    /// accepted — confirm against `crate::state`). The ready queue is not
    /// modified, so moving a queued task out of `Ready` leaves a stale entry that
    /// `next_ready` discards later.
    pub fn transition(&mut self, id: &str, state: TaskState) -> bool {
        match self.records.get_mut(id) {
            Some(r) => r.transition(state),
            None => false,
        }
    }

    /// Moves `id` (creating its record if needed) to `TaskState::Ready`.
    ///
    /// If the transition reports success, the task is enqueued with `priority`.
    /// Higher `priority` values are dequeued first. Equal priorities are dequeued
    /// in the order they were marked ready.
    /// NOTE(review): if an older queued entry for the same id is still in the heap
    /// when the task becomes Ready again, `next_ready` may return it at the older
    /// entry's priority — confirm whether callers care.
    pub fn mark_ready(&mut self, id: &str, priority: u8) {
        // The mutable borrow from `ensure_record` ends with this expression, so
        // `self` can be mutated again below.
        if !self.ensure_record(id).transition(TaskState::Ready) {
            return;
        }
        let seq = self.seq;
        self.seq += 1;
        self.ready.push(Prioritized {
            priority,
            seq,
            id: id.to_string(),
        });
    }

    /// Dequeues the highest-priority task that is still in `TaskState::Ready`.
    ///
    /// Stale entries (whose record is missing or no longer `Ready`) are popped and
    /// discarded along the way. Returns `None` once the queue holds no live entry.
    pub fn next_ready(&mut self) -> Option<String> {
        while let Some(entry) = self.ready.pop() {
            if self.state(&entry.id) == Some(TaskState::Ready) {
                return Some(entry.id);
            }
        }
        None
    }

    /// Returns `true` if `next_ready` would currently return a task.
    ///
    /// The queue may contain stale entries, which are discarded only lazily by
    /// `next_ready`. This method therefore scans the queued entries (linear in
    /// their count) instead of trusting the queue's length. Otherwise it could
    /// report `true` while `next_ready` returns `None`.
    pub fn has_ready(&self) -> bool {
        self.ready
            .iter()
            .any(|entry| self.state(&entry.id) == Some(TaskState::Ready))
    }

    /// Returns `true` if every record's state is terminal.
    ///
    /// This is vacuously `true` when there are no records.
    pub fn all_terminal(&self) -> bool {
        self.records.values().all(|r| r.state.is_terminal())
    }

    /// Counts the records currently in `state`.
    pub fn count_in(&self, state: TaskState) -> usize {
        self.records.values().filter(|r| r.state == state).count()
    }
}