Skip to main content

dag_executor/dag/
scheduler.rs

1//! Priority-aware scheduling and task-state bookkeeping.
2
3use crate::state::{TaskRecord, TaskState};
4use std::cmp::Ordering;
5use std::collections::{BinaryHeap, HashMap};
6
/// One entry in the scheduler's ready heap.
///
/// Entries order by `priority` (higher first) and, within the same priority,
/// by `seq` (lower — i.e. enqueued earlier — first). `id` takes no part in
/// comparisons or equality.
struct Prioritized {
    /// Id of the task this entry refers to.
    id: String,
    /// Scheduling priority; larger values are served first.
    priority: u8,
    /// Insertion order; breaks priority ties FIFO.
    seq: u64,
}

impl Ord for Prioritized {
    fn cmp(&self, other: &Self) -> Ordering {
        // `BinaryHeap` is a max-heap, so "greater" pops first. Priority
        // compares naturally; the sequence comparison is flipped so the entry
        // enqueued earlier (smaller `seq`) counts as the greater one.
        match self.priority.cmp(&other.priority) {
            Ordering::Equal => other.seq.cmp(&self.seq),
            unequal => unequal,
        }
    }
}

impl PartialOrd for Prioritized {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for Prioritized {
    fn eq(&self, other: &Self) -> bool {
        // Expressed through `cmp` so equality can never drift out of sync with
        // the ordering: equal means same priority and same seq (`id` ignored).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Prioritized {}
36
37/// Tracks the state of every task and serves the next ready task by priority.
38///
39/// The scheduler owns the authoritative [`TaskRecord`] map. The executor pushes
40/// tasks in as they become runnable and pulls the highest-priority one out.
41#[derive(Default)]
42pub struct Scheduler {
43    records: HashMap<String, TaskRecord>,
44    ready: BinaryHeap<Prioritized>,
45    seq: u64,
46}
47
48impl Scheduler {
49    /// Create an empty scheduler.
50    pub fn new() -> Self {
51        Self::default()
52    }
53
54    /// Seed the scheduler with previously persisted records (for recovery).
55    pub fn with_records(records: HashMap<String, TaskRecord>) -> Self {
56        Scheduler {
57            records,
58            ready: BinaryHeap::new(),
59            seq: 0,
60        }
61    }
62
63    /// Ensure a record exists for `id`, creating a `Pending` one if not.
64    pub fn ensure_record(&mut self, id: &str) -> &mut TaskRecord {
65        self.records
66            .entry(id.to_string())
67            .or_insert_with(|| TaskRecord::new(id))
68    }
69
70    /// Immutable access to a record.
71    pub fn record(&self, id: &str) -> Option<&TaskRecord> {
72        self.records.get(id)
73    }
74
75    /// The full record map.
76    pub fn records(&self) -> &HashMap<String, TaskRecord> {
77        &self.records
78    }
79
80    /// Mutable access to the full record map.
81    pub fn records_mut(&mut self) -> &mut HashMap<String, TaskRecord> {
82        &mut self.records
83    }
84
85    /// Current state of `id`, if known.
86    pub fn state(&self, id: &str) -> Option<TaskState> {
87        self.records.get(id).map(|r| r.state)
88    }
89
90    /// Apply a state transition, returning whether it was legal/applied.
91    pub fn transition(&mut self, id: &str, state: TaskState) -> bool {
92        match self.records.get_mut(id) {
93            Some(r) => r.transition(state),
94            None => false,
95        }
96    }
97
98    /// Mark `id` ready and enqueue it at `priority`.
99    pub fn mark_ready(&mut self, id: &str, priority: u8) {
100        let applied = {
101            let record = self.ensure_record(id);
102            record.transition(TaskState::Ready)
103        };
104        if applied {
105            let seq = self.seq;
106            self.seq += 1;
107            self.ready.push(Prioritized {
108                priority,
109                seq,
110                id: id.to_string(),
111            });
112        }
113    }
114
115    /// Pop the highest-priority task that is still in the `Ready` state.
116    ///
117    /// Stale heap entries (whose record has since moved on) are discarded.
118    pub fn next_ready(&mut self) -> Option<String> {
119        while let Some(entry) = self.ready.pop() {
120            if self.state(&entry.id) == Some(TaskState::Ready) {
121                return Some(entry.id);
122            }
123        }
124        None
125    }
126
127    /// Whether any task is queued ready to run.
128    pub fn has_ready(&self) -> bool {
129        // May overcount stale entries, but is only used as a cheap hint.
130        !self.ready.is_empty()
131    }
132
133    /// Whether every known task has reached a terminal state.
134    pub fn all_terminal(&self) -> bool {
135        self.records.values().all(|r| r.state.is_terminal())
136    }
137
138    /// Count records currently in `state`.
139    pub fn count_in(&self, state: TaskState) -> usize {
140        self.records.values().filter(|r| r.state == state).count()
141    }
142}