use std::collections::{BinaryHeap, HashMap};
use crate::{
emulation::{
thread::{EmulationThread, ThreadPriority, ThreadState, WaitReason},
EmValue, HeapRef, ThreadId,
},
Result,
};
/// Status reported by the scheduler to whoever drives emulation.
///
/// Within this file only [`ThreadScheduler::check_state`] constructs values, and it
/// yields `AllCompleted`, `Deadlock` or `Continue`. The remaining variants are
/// presumably produced by the execution loop elsewhere — NOTE(review): confirm there.
#[derive(Clone, Debug)]
pub enum SchedulerOutcome {
    /// Every tracked thread reports completed.
    AllCompleted,
    /// Execution stopped at a limit; `executed` carries an instruction count
    /// (assumed total executed — confirm with the producer).
    LimitReached { executed: u64 },
    /// The named thread used up its time slice (assumption based on the name).
    QuantumExhausted { thread_id: ThreadId },
    /// No thread is runnable and these threads are all waiting.
    Deadlock { waiting_threads: Vec<ThreadId> },
    /// The named thread finished, optionally producing a return value.
    ThreadCompleted { thread_id: ThreadId, return_value: Option<EmValue> },
    /// The named thread faulted; `exception` describes the fault.
    ThreadFaulted { thread_id: ThreadId, exception: String },
    /// The named thread was aborted.
    ThreadAborted { thread_id: ThreadId },
    /// Nothing terminal happened; scheduling can proceed.
    Continue,
}
/// One pending slot in the scheduler's ready queue.
///
/// `BinaryHeap` is a max-heap, so the ordering below pops the entry with the
/// highest `priority` first and, among equal priorities, the one with the
/// smallest `sequence` (the one enqueued earliest), giving FIFO round-robin
/// within a priority level.
#[derive(Clone, Copy, Debug)]
struct ScheduleEntry {
    /// Priority the thread had when this entry was enqueued.
    priority: ThreadPriority,
    /// Thread this entry refers to.
    thread_id: ThreadId,
    /// Monotonic enqueue counter used as the FIFO tie-breaker.
    sequence: u64,
}

impl Ord for ScheduleEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.priority
            .cmp(&other.priority)
            // Reversed so that the *older* (smaller) sequence is the "greater" entry.
            .then_with(|| other.sequence.cmp(&self.sequence))
    }
}

impl PartialOrd for ScheduleEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

// Equality is defined through `Ord` so the two can never disagree, as the `Ord`
// contract requires. The previous derived `PartialEq` also compared `thread_id`,
// which `cmp` ignores. `ThreadScheduler::enqueue_ready` stamps every entry with a
// fresh `sequence`, so comparing `(priority, sequence)` already identifies an entry.
impl PartialEq for ScheduleEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for ScheduleEntry {}
/// Priority-ordered, quantum-sliced scheduler for emulated threads.
///
/// All threads are owned by `threads`; runnable ones are referenced by entries in
/// `ready_queue`. The queue can retain entries for threads that are no longer
/// ready (for example after `abort_thread`, which does not touch the queue);
/// `select_next` skips such entries when it pops them.
#[derive(Debug)]
pub struct ThreadScheduler {
    // Every tracked thread, keyed by id (completed ones stay until `collect_completed`).
    threads: HashMap<ThreadId, EmulationThread>,
    // Pending runnable entries; highest priority first, FIFO within a priority.
    ready_queue: BinaryHeap<ScheduleEntry>,
    // Thread holding the current time slice, if any (popped out of `ready_queue` when chosen).
    current: Option<ThreadId>,
    // Instructions a thread may run per slice before `select_next` rotates it out.
    quantum: usize,
    // Instructions recorded in the current slice; reset whenever the slice ends.
    quantum_used: usize,
    // Instructions recorded across all threads since creation.
    total_instructions: u64,
    // Next enqueue sequence number, used as the FIFO tie-breaker in `ScheduleEntry`.
    next_sequence: u64,
    // Next id returned by `allocate_thread_id`.
    next_thread_id: u32,
}
impl ThreadScheduler {
    /// Creates an empty scheduler whose time slices are `quantum` instructions long.
    ///
    /// [`allocate_thread_id`](Self::allocate_thread_id) starts handing out ids at 2.
    /// NOTE(review): id 1 is presumably reserved for the main thread — confirm with
    /// the code that constructs it.
    #[must_use]
    pub fn new(quantum: usize) -> Self {
        Self {
            threads: HashMap::new(),
            ready_queue: BinaryHeap::new(),
            current: None,
            quantum,
            quantum_used: 0,
            total_instructions: 0,
            next_sequence: 0,
            next_thread_id: 2,
        }
    }

    /// Creates a scheduler with a 1000-instruction quantum.
    #[must_use]
    pub fn with_default_quantum() -> Self {
        Self::new(1000)
    }

    /// Returns the number of instructions a thread may run per time slice.
    #[must_use]
    pub fn quantum(&self) -> usize {
        self.quantum
    }

    /// Changes the slice length used by all subsequent quantum checks.
    pub fn set_quantum(&mut self, quantum: usize) {
        self.quantum = quantum;
    }

    /// Returns the number of instructions recorded across all threads.
    #[must_use]
    pub fn total_instructions(&self) -> u64 {
        self.total_instructions
    }

    /// Registers the main thread and queues it at [`ThreadPriority::Normal`].
    ///
    /// Unlike [`spawn`](Self::spawn), the thread's own priority is ignored for this
    /// first enqueue; later re-enqueues use `thread.priority()`.
    pub fn add_main_thread(&mut self, thread: EmulationThread) {
        let id = thread.id();
        self.threads.insert(id, thread);
        self.enqueue_ready(id, ThreadPriority::Normal);
    }

    /// Registers `thread`, queues it at its own priority and returns its id.
    ///
    /// A thread whose id is already tracked replaces the previous one.
    pub fn spawn(&mut self, thread: EmulationThread) -> ThreadId {
        let id = thread.id();
        let priority = thread.priority();
        self.threads.insert(id, thread);
        self.enqueue_ready(id, priority);
        id
    }

    /// Reserves a fresh thread id; ids increase monotonically starting from 2.
    pub fn allocate_thread_id(&mut self) -> ThreadId {
        let id = ThreadId::new(self.next_thread_id);
        self.next_thread_id += 1;
        id
    }

    /// Returns the number of tracked threads, including finished ones that have
    /// not yet been removed by [`collect_completed`](Self::collect_completed).
    #[must_use]
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }

    /// Returns how many runnable slots exist: the current thread (if any) plus
    /// every queued entry whose thread is still tracked and reports ready.
    ///
    /// Stale entries are not counted. These are left behind when a queued thread is
    /// aborted, blocked through `get_thread_mut`, or collected.
    #[must_use]
    pub fn ready_count(&self) -> usize {
        let live_queued = self
            .ready_queue
            .iter()
            .filter(|entry| self.is_thread_ready(entry.thread_id))
            .count();
        live_queued + usize::from(self.current.is_some())
    }

    /// Returns the id of the thread holding the current time slice.
    #[must_use]
    pub fn current_thread_id(&self) -> Option<ThreadId> {
        self.current
    }

    /// Looks up a tracked thread by id.
    #[must_use]
    pub fn get_thread(&self, id: ThreadId) -> Option<&EmulationThread> {
        self.threads.get(&id)
    }

    /// Looks up a tracked thread by id for mutation.
    pub fn get_thread_mut(&mut self, id: ThreadId) -> Option<&mut EmulationThread> {
        self.threads.get_mut(&id)
    }

    /// Returns the thread holding the current time slice, if it is still tracked.
    #[must_use]
    pub fn current_thread(&self) -> Option<&EmulationThread> {
        self.current.and_then(|id| self.threads.get(&id))
    }

    /// Mutable variant of [`current_thread`](Self::current_thread).
    pub fn current_thread_mut(&mut self) -> Option<&mut EmulationThread> {
        self.current.and_then(|id| self.threads.get_mut(&id))
    }

    /// Returns `true` when every tracked thread reports completed (vacuously
    /// `true` when no threads are tracked).
    #[must_use]
    pub fn all_completed(&self) -> bool {
        self.threads.values().all(EmulationThread::is_completed)
    }

    /// Returns `true` if a thread holds the slice or a queued entry still refers
    /// to a tracked, ready thread.
    ///
    /// Stale queue entries are ignored so that [`check_state`](Self::check_state)
    /// detects a deadlock even when such leftovers remain in the heap.
    #[must_use]
    pub fn has_ready_threads(&self) -> bool {
        self.current.is_some()
            || self
                .ready_queue
                .iter()
                .any(|entry| self.is_thread_ready(entry.thread_id))
    }

    /// Pushes `id` onto the ready queue behind all earlier entries of the same priority.
    fn enqueue_ready(&mut self, id: ThreadId, priority: ThreadPriority) {
        let sequence = self.next_sequence;
        self.next_sequence += 1;
        self.ready_queue.push(ScheduleEntry {
            priority,
            thread_id: id,
            sequence,
        });
    }

    /// Returns `true` when `id` names a tracked thread that reports ready.
    fn is_thread_ready(&self, id: ThreadId) -> bool {
        matches!(self.threads.get(&id), Some(thread) if thread.is_ready())
    }

    /// Chooses the thread that should run next and makes it current.
    ///
    /// The current thread keeps the slice while it is ready and has quantum left.
    /// Otherwise it is re-enqueued (if still ready), the quantum counter resets, and
    /// the best ready entry is popped. Entries for threads that are gone or no
    /// longer ready are discarded along the way. Returns `None` when nothing is
    /// runnable.
    pub fn select_next(&mut self) -> Option<ThreadId> {
        if let Some(current_id) = self.current {
            if self.is_thread_ready(current_id) && self.quantum_used < self.quantum {
                return Some(current_id);
            }
        }

        if let Some(previous) = self.current.take() {
            // Read the priority first so the shared borrow of `threads` ends
            // before `enqueue_ready` borrows `self` mutably.
            let requeue_priority = self
                .threads
                .get(&previous)
                .filter(|thread| thread.is_ready())
                .map(EmulationThread::priority);
            if let Some(priority) = requeue_priority {
                self.enqueue_ready(previous, priority);
            }
        }

        self.quantum_used = 0;
        while let Some(entry) = self.ready_queue.pop() {
            if self.is_thread_ready(entry.thread_id) {
                self.current = Some(entry.thread_id);
                return Some(entry.thread_id);
            }
        }
        None
    }

    /// Accounts one executed instruction to the scheduler and the current thread.
    ///
    /// Returns `true` once the current slice has used at least `quantum` instructions.
    pub fn record_instruction(&mut self) -> bool {
        self.total_instructions += 1;
        self.quantum_used += 1;
        if let Some(id) = self.current {
            if let Some(thread) = self.threads.get_mut(&id) {
                thread.increment_instructions();
            }
        }
        self.quantum_used >= self.quantum
    }

    /// Ends the current slice early.
    ///
    /// The next [`select_next`](Self::select_next) then re-enqueues the current
    /// thread (if still ready) behind its equal-priority peers.
    pub fn yield_current(&mut self) {
        self.quantum_used = self.quantum;
    }

    /// Puts the current thread into `Waiting(reason)` and releases the slice.
    ///
    /// Does nothing when no thread is current.
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`; the `Result` is part of the public signature.
    pub fn block_current(&mut self, reason: WaitReason) -> Result<()> {
        if let Some(id) = self.current {
            if let Some(thread) = self.threads.get_mut(&id) {
                thread.set_state(ThreadState::Waiting(reason));
            }
            self.current = None;
            self.quantum_used = 0;
        }
        Ok(())
    }

    /// Wakes every waiting thread whose wait reason satisfies `condition`.
    ///
    /// Woken threads become ready and are enqueued at their own priority.
    pub fn wake(&mut self, condition: &WakeCondition) {
        let to_wake: Vec<ThreadId> = self
            .threads
            .iter()
            .filter(|(_, thread)| {
                matches!(
                    thread.state(),
                    ThreadState::Waiting(reason) if condition.matches(&reason)
                )
            })
            .map(|(id, _)| *id)
            .collect();
        for id in to_wake {
            self.wake_thread(id);
        }
    }

    /// Makes thread `id` ready and enqueues it, if it is currently waiting.
    ///
    /// Threads in any other state, and unknown ids, are left untouched.
    pub fn wake_thread(&mut self, id: ThreadId) {
        let woken_priority = match self.threads.get_mut(&id) {
            Some(thread) if matches!(thread.state(), ThreadState::Waiting(_)) => {
                thread.set_state(ThreadState::Ready);
                Some(thread.priority())
            }
            _ => None,
        };
        if let Some(priority) = woken_priority {
            self.enqueue_ready(id, priority);
        }
    }

    /// Records `return_value` on the current thread and releases the slice.
    pub fn complete_current(&mut self, return_value: Option<EmValue>) {
        if let Some(id) = self.current.take() {
            if let Some(thread) = self.threads.get_mut(&id) {
                thread.set_return_value(return_value);
            }
        }
        self.quantum_used = 0;
    }

    /// Marks the current thread as faulted and releases the slice.
    pub fn fault_current(&mut self) {
        if let Some(id) = self.current.take() {
            if let Some(thread) = self.threads.get_mut(&id) {
                thread.fault();
            }
        }
        self.quantum_used = 0;
    }

    /// Aborts thread `id`, releasing the slice if it was the current thread.
    ///
    /// Any queued entry for the thread stays in the heap and is skipped once the
    /// thread no longer reports ready.
    pub fn abort_thread(&mut self, id: ThreadId) {
        if let Some(thread) = self.threads.get_mut(&id) {
            thread.abort();
        }
        if self.current == Some(id) {
            self.current = None;
            self.quantum_used = 0;
        }
    }

    /// Removes every completed thread from the scheduler and returns them.
    pub fn collect_completed(&mut self) -> Vec<EmulationThread> {
        let completed_ids: Vec<ThreadId> = self
            .threads
            .iter()
            .filter(|(_, t)| t.is_completed())
            .map(|(id, _)| *id)
            .collect();
        completed_ids
            .into_iter()
            .filter_map(|id| self.threads.remove(&id))
            .collect()
    }

    /// Summarises the scheduler's global state.
    ///
    /// Returns the following:
    /// - `AllCompleted` when every thread is done.
    /// - `Deadlock` listing the waiting threads when nothing is runnable but some
    ///   threads wait.
    /// - `Continue` otherwise.
    #[must_use]
    pub fn check_state(&self) -> SchedulerOutcome {
        if self.all_completed() {
            return SchedulerOutcome::AllCompleted;
        }
        if !self.has_ready_threads() {
            let waiting: Vec<ThreadId> = self
                .threads
                .iter()
                .filter(|(_, t)| matches!(t.state(), ThreadState::Waiting(_)))
                .map(|(id, _)| *id)
                .collect();
            if !waiting.is_empty() {
                return SchedulerOutcome::Deadlock {
                    waiting_threads: waiting,
                };
            }
        }
        SchedulerOutcome::Continue
    }

    /// Iterates over all tracked threads in arbitrary order.
    pub fn threads(&self) -> impl Iterator<Item = (&ThreadId, &EmulationThread)> {
        self.threads.iter()
    }

    /// Mutable variant of [`threads`](Self::threads).
    pub fn threads_mut(&mut self) -> impl Iterator<Item = (&ThreadId, &mut EmulationThread)> {
        self.threads.iter_mut()
    }
}
/// Event passed to [`ThreadScheduler::wake`] to release waiting threads.
///
/// A thread is woken when its [`WaitReason`] satisfies
/// [`WakeCondition::matches`] for the given condition.
#[derive(Clone, Debug)]
pub enum WakeCondition {
    /// Releases threads waiting on the monitor at this heap reference.
    Monitor(HeapRef),
    /// Releases threads waiting on the event at this heap reference.
    Event(HeapRef),
    /// Releases threads waiting on this thread id.
    Thread(ThreadId),
    /// Releases sleepers whose `until_instruction` is at or before `current_instruction`.
    SleepElapsed { current_instruction: u64 },
    /// Releases threads waiting on the mutex at this heap reference.
    Mutex(HeapRef),
    /// Releases threads waiting on the semaphore at this heap reference.
    Semaphore(HeapRef),
    /// Releases every waiting thread regardless of its reason.
    All,
}
impl WakeCondition {
    /// Reports whether this condition releases a thread waiting for `reason`.
    ///
    /// `All` matches everything. Handle-based conditions match only the same kind
    /// of wait on an equal handle. `Thread` matches a wait on the same id.
    /// `SleepElapsed` matches a sleep whose deadline has been reached. Every other
    /// pairing is `false`.
    #[must_use]
    pub fn matches(&self, reason: &WaitReason) -> bool {
        if let WakeCondition::All = self {
            return true;
        }
        match reason {
            WaitReason::Monitor(waited) => {
                matches!(self, WakeCondition::Monitor(signalled) if signalled == waited)
            }
            WaitReason::Event(waited) => {
                matches!(self, WakeCondition::Event(signalled) if signalled == waited)
            }
            WaitReason::Mutex(waited) => {
                matches!(self, WakeCondition::Mutex(signalled) if signalled == waited)
            }
            WaitReason::Semaphore(waited) => {
                matches!(self, WakeCondition::Semaphore(signalled) if signalled == waited)
            }
            WaitReason::Thread(waited) => {
                matches!(self, WakeCondition::Thread(finished) if finished == waited)
            }
            WaitReason::Sleep { until_instruction } => matches!(
                self,
                WakeCondition::SleepElapsed { current_instruction }
                    if current_instruction >= until_instruction
            ),
            _ => false,
        }
    }
}
impl Default for ThreadScheduler {
fn default() -> Self {
Self::with_default_quantum()
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::emulation::{
        capture::CaptureContext,
        thread::{
            EmulationThread, SchedulerOutcome, ThreadPriority, ThreadScheduler, WaitReason,
            WakeCondition,
        },
        AddressSpace, EmValue, HeapRef, SharedFakeObjects, ThreadId,
    };

    /// Builds a thread with the given id over its own fresh address space.
    fn create_test_thread(id: u32) -> EmulationThread {
        let space = Arc::new(AddressSpace::new());
        let capture = Arc::new(CaptureContext::new());
        let fake_objects = SharedFakeObjects::new(space.managed_heap());
        EmulationThread::new(ThreadId::new(id), space, capture, None, fake_objects)
    }

    /// Creates a scheduler with `quantum` and spawns one fresh thread per id, in order.
    fn scheduler_with_threads(quantum: usize, ids: &[u32]) -> ThreadScheduler {
        let mut scheduler = ThreadScheduler::new(quantum);
        for &id in ids {
            scheduler.spawn(create_test_thread(id));
        }
        scheduler
    }

    #[test]
    fn test_scheduler_creation() {
        let fresh = ThreadScheduler::new(500);
        assert_eq!(fresh.quantum(), 500);
        assert_eq!(fresh.thread_count(), 0);
        assert_eq!(fresh.total_instructions(), 0);
    }

    #[test]
    fn test_add_main_thread() {
        let mut scheduler = ThreadScheduler::new(100);
        scheduler.add_main_thread(create_test_thread(1));
        assert_eq!(scheduler.thread_count(), 1);
        assert!(scheduler.has_ready_threads());
    }

    #[test]
    fn test_spawn_thread() {
        let mut scheduler = ThreadScheduler::new(100);
        let spawned = scheduler.spawn(create_test_thread(1));
        assert_eq!(spawned, ThreadId::new(1));
        assert_eq!(scheduler.thread_count(), 1);
    }

    #[test]
    fn test_select_next() {
        let mut scheduler = scheduler_with_threads(100, &[1, 2]);
        let first = scheduler.select_next();
        assert!(first.is_some());
        // With quantum remaining, the running thread keeps its slice.
        assert_eq!(scheduler.select_next(), first);
    }

    #[test]
    fn test_quantum_exhaustion() {
        let mut scheduler = scheduler_with_threads(3, &[1]);
        scheduler.select_next();
        let exhausted: Vec<bool> = (0..3).map(|_| scheduler.record_instruction()).collect();
        assert_eq!(exhausted, [false, false, true]);
        assert_eq!(scheduler.total_instructions(), 3);
    }

    #[test]
    fn test_thread_priority() {
        let mut scheduler = ThreadScheduler::new(100);
        let space = Arc::new(AddressSpace::new());
        let capture = Arc::new(CaptureContext::new());
        let fake_objects = SharedFakeObjects::new(space.managed_heap());
        let make_thread = |id: u32, priority: ThreadPriority| {
            let mut thread = EmulationThread::new(
                ThreadId::new(id),
                Arc::clone(&space),
                Arc::clone(&capture),
                None,
                fake_objects.clone(),
            );
            thread.set_priority(priority);
            thread
        };
        scheduler.spawn(make_thread(1, ThreadPriority::Lowest));
        scheduler.spawn(make_thread(2, ThreadPriority::Highest));
        // The higher-priority thread wins even though it was enqueued second.
        assert_eq!(scheduler.select_next(), Some(ThreadId::new(2)));
    }

    #[test]
    fn test_complete_current() {
        let mut scheduler = scheduler_with_threads(100, &[1]);
        scheduler.select_next();
        scheduler.complete_current(Some(EmValue::I32(42)));
        let finished = scheduler
            .get_thread(ThreadId::new(1))
            .expect("completed thread is still tracked until collected");
        assert!(finished.is_completed());
        assert_eq!(finished.return_value(), Some(&EmValue::I32(42)));
    }

    #[test]
    fn test_all_completed() {
        let mut scheduler = scheduler_with_threads(100, &[1]);
        assert!(!scheduler.all_completed());
        scheduler.select_next();
        scheduler.complete_current(None);
        assert!(scheduler.all_completed());
    }

    #[test]
    fn test_yield_current() {
        let mut scheduler = scheduler_with_threads(100, &[1, 2]);
        let before_yield = scheduler.select_next();
        scheduler.yield_current();
        let after_yield = scheduler.select_next();
        assert_ne!(before_yield, after_yield);
    }

    #[test]
    fn test_wake_sleeping_threads() {
        let mut scheduler = scheduler_with_threads(100, &[1]);
        scheduler.select_next();
        scheduler
            .block_current(WaitReason::Sleep {
                until_instruction: 100,
            })
            .expect("block_current on a running thread succeeds");
        assert!(!scheduler.has_ready_threads());
        scheduler.wake(&WakeCondition::SleepElapsed {
            current_instruction: 100,
        });
        assert!(scheduler.has_ready_threads());
    }

    #[test]
    fn test_deadlock_detection() {
        let mut scheduler = scheduler_with_threads(100, &[1]);
        scheduler.select_next();
        let monitor = HeapRef::new(1);
        scheduler
            .block_current(WaitReason::Monitor(monitor))
            .expect("block_current on a running thread succeeds");
        assert!(matches!(
            scheduler.check_state(),
            SchedulerOutcome::Deadlock { .. }
        ));
    }
}