use std::collections::{HashMap, VecDeque};
use crate::emulation::{HeapRef, ThreadId};
/// Signal state of an emulated event object: a boolean flag plus the
/// reset policy that decides how a successful wait consumes the signal.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct EventState {
    /// Whether the event is currently signaled.
    pub signaled: bool,
    /// `true`: stays signaled until explicitly reset (manual-reset);
    /// `false`: the signal is cleared when a single wait succeeds (auto-reset).
    pub manual_reset: bool,
}
impl EventState {
#[must_use]
pub fn manual_reset(signaled: bool) -> Self {
Self {
signaled,
manual_reset: true,
}
}
#[must_use]
pub fn auto_reset(signaled: bool) -> Self {
Self {
signaled,
manual_reset: false,
}
}
}
/// Ownership state of a re-entrant monitor: the holding thread (if any)
/// and how many times that thread has entered without exiting.
#[derive(Clone, Debug, Default)]
pub struct MonitorState {
    /// Thread currently inside the monitor, or `None` when it is free.
    pub owner: Option<ThreadId>,
    /// Nested entry count held by `owner`; 0 whenever the monitor is free.
    pub recursion_count: u32,
}
impl MonitorState {
    /// `true` when no thread currently holds the monitor.
    #[must_use]
    pub fn is_free(&self) -> bool {
        self.owner.is_none()
    }

    /// `true` when `thread_id` is the current holder.
    #[must_use]
    pub fn is_owner(&self, thread_id: ThreadId) -> bool {
        self.owner == Some(thread_id)
    }

    /// Attempts to enter the monitor for `thread_id`.
    ///
    /// Succeeds when the monitor is free (count becomes 1) or when the
    /// caller already holds it (count is incremented). Fails when another
    /// thread holds it.
    pub fn try_enter(&mut self, thread_id: ThreadId) -> bool {
        if let Some(holder) = self.owner {
            if holder != thread_id {
                return false;
            }
            self.recursion_count += 1;
        } else {
            self.owner = Some(thread_id);
            self.recursion_count = 1;
        }
        true
    }

    /// Exits one recursion level held by `thread_id`.
    ///
    /// Returns `Ok(true)` when the monitor became free, `Ok(false)` when
    /// the caller still holds it recursively.
    ///
    /// # Errors
    ///
    /// [`SyncError::NotLocked`] when no thread holds the monitor;
    /// [`SyncError::NotOwner`] when a different thread holds it.
    pub fn exit(&mut self, thread_id: ThreadId) -> Result<bool, SyncError> {
        let holder = self.owner.ok_or(SyncError::NotLocked)?;
        if holder != thread_id {
            return Err(SyncError::NotOwner);
        }
        self.recursion_count -= 1;
        let released = self.recursion_count == 0;
        if released {
            self.owner = None;
        }
        Ok(released)
    }
}
/// Ownership state of a re-entrant mutex, including whether a previous
/// owner abandoned it (terminated without releasing).
#[derive(Clone, Debug, Default)]
pub struct MutexState {
    /// Thread currently holding the mutex, or `None` when free.
    pub owner: Option<ThreadId>,
    /// Nested acquisition count held by `owner`; 0 whenever free.
    pub recursion_count: u32,
    /// Set when the owner was torn down without releasing; cleared on the
    /// next fresh acquisition.
    pub abandoned: bool,
}
impl MutexState {
    /// Attempts to acquire the mutex for `thread_id`.
    ///
    /// A fresh acquisition clears any abandoned flag; a re-entrant
    /// acquisition by the current owner bumps the recursion count.
    /// Fails when another thread holds the mutex.
    pub fn try_acquire(&mut self, thread_id: ThreadId) -> bool {
        if let Some(holder) = self.owner {
            if holder != thread_id {
                return false;
            }
            self.recursion_count += 1;
        } else {
            self.owner = Some(thread_id);
            self.recursion_count = 1;
            self.abandoned = false;
        }
        true
    }

    /// Releases one recursion level held by `thread_id`.
    ///
    /// Returns `Ok(true)` when the mutex became free, `Ok(false)` when the
    /// caller still holds it recursively.
    ///
    /// # Errors
    ///
    /// [`SyncError::NotLocked`] when no thread holds the mutex;
    /// [`SyncError::NotOwner`] when a different thread holds it.
    pub fn release(&mut self, thread_id: ThreadId) -> Result<bool, SyncError> {
        let holder = self.owner.ok_or(SyncError::NotLocked)?;
        if holder != thread_id {
            return Err(SyncError::NotOwner);
        }
        self.recursion_count -= 1;
        let freed = self.recursion_count == 0;
        if freed {
            self.owner = None;
        }
        Ok(freed)
    }

    /// Marks the mutex abandoned and clears ownership so another thread
    /// may acquire it.
    pub fn abandon(&mut self) {
        self.owner = None;
        self.recursion_count = 0;
        self.abandoned = true;
    }
}
/// Counting semaphore: available permits plus the ceiling the count may
/// never exceed.
#[derive(Clone, Debug)]
pub struct SemaphoreState {
    /// Permits currently available.
    pub count: u32,
    /// Upper bound on `count`; releases past this limit are rejected.
    pub max_count: u32,
}
impl SemaphoreState {
    /// Creates a semaphore with `initial_count` permits, clamped to
    /// `max_count`.
    #[must_use]
    pub fn new(initial_count: u32, max_count: u32) -> Self {
        Self {
            count: initial_count.min(max_count),
            max_count,
        }
    }

    /// Takes one permit if any are available; returns `false` otherwise.
    pub fn try_acquire(&mut self) -> bool {
        if self.count > 0 {
            self.count -= 1;
            true
        } else {
            false
        }
    }

    /// Returns `count` permits to the semaphore.
    ///
    /// On success, returns the permit count held *before* the release.
    ///
    /// # Errors
    ///
    /// [`SyncError::SemaphoreOverflow`] when the release would push the
    /// count past `max_count` — or past `u32::MAX`. (`checked_add` is
    /// used instead of `saturating_add`: saturation silently capped the
    /// count at `u32::MAX`, which masked genuine overflow whenever
    /// `max_count == u32::MAX` and lost the released permits.)
    pub fn release(&mut self, count: u32) -> Result<u32, SyncError> {
        let new_count = self
            .count
            .checked_add(count)
            .ok_or(SyncError::SemaphoreOverflow)?;
        if new_count > self.max_count {
            return Err(SyncError::SemaphoreOverflow);
        }
        let previous = self.count;
        self.count = new_count;
        Ok(previous)
    }
}
/// Errors produced by synchronization operations.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SyncError {
    /// The calling thread does not own the lock it tried to operate on.
    NotOwner,
    /// The lock is not held by any thread.
    NotLocked,
    /// A semaphore release would push the count past its maximum.
    SemaphoreOverflow,
    /// The referenced synchronization object does not exist.
    ObjectNotFound,
    /// A deadlock was detected.
    Deadlock,
    /// A wait operation timed out.
    Timeout,
}
impl std::fmt::Display for SyncError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SyncError::NotOwner => write!(f, "thread does not own the lock"),
SyncError::NotLocked => write!(f, "lock is not held"),
SyncError::SemaphoreOverflow => write!(f, "semaphore count would overflow"),
SyncError::ObjectNotFound => write!(f, "synchronization object not found"),
SyncError::Deadlock => write!(f, "deadlock detected"),
SyncError::Timeout => write!(f, "wait operation timed out"),
}
}
}
// SyncError carries no underlying source error, so the default trait
// method implementations suffice.
impl std::error::Error for SyncError {}
/// Central registry of emulated synchronization objects, keyed by heap
/// reference, together with the per-object FIFO queues of blocked threads.
#[derive(Debug, Default)]
pub struct SyncState {
    /// Monitor (re-entrant lock) state per object.
    monitors: HashMap<HeapRef, MonitorState>,
    /// Threads parked in `monitor_wait`, awaiting a pulse.
    monitor_wait_queues: HashMap<HeapRef, VecDeque<ThreadId>>,
    /// Threads already pulsed, eligible to re-acquire the monitor.
    monitor_pulse_queues: HashMap<HeapRef, VecDeque<ThreadId>>,
    /// Event signal state per object.
    events: HashMap<HeapRef, EventState>,
    /// Threads blocked waiting for an event to be set.
    event_wait_queues: HashMap<HeapRef, VecDeque<ThreadId>>,
    /// Mutex ownership state per object.
    mutexes: HashMap<HeapRef, MutexState>,
    /// Threads blocked waiting to acquire a mutex.
    mutex_wait_queues: HashMap<HeapRef, VecDeque<ThreadId>>,
    /// Semaphore counters per object.
    semaphores: HashMap<HeapRef, SemaphoreState>,
    /// Threads blocked waiting for a semaphore permit.
    semaphore_wait_queues: HashMap<HeapRef, VecDeque<ThreadId>>,
}
impl SyncState {
    /// Creates an empty registry with no synchronization objects.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Tries to enter the monitor for `obj`, creating it on first use.
    /// Re-entrant entry by the current owner also succeeds.
    pub fn monitor_try_enter(&mut self, obj: HeapRef, thread_id: ThreadId) -> bool {
        let state = self.monitors.entry(obj).or_default();
        state.try_enter(thread_id)
    }

    /// Exits one recursion level of the monitor for `obj`; `Ok(true)`
    /// means the monitor became free.
    ///
    /// # Errors
    ///
    /// `NotLocked` when the monitor does not exist or is not held;
    /// `NotOwner` when another thread holds it.
    pub fn monitor_exit(&mut self, obj: HeapRef, thread_id: ThreadId) -> Result<bool, SyncError> {
        let state = self.monitors.get_mut(&obj).ok_or(SyncError::NotLocked)?;
        state.exit(thread_id)
    }

    /// Begins a wait on the monitor for `obj`: the lock is fully released
    /// (regardless of recursion depth) and `thread_id` is appended to the
    /// wait queue. Returns the recursion count that was held so the caller
    /// can restore it once the thread re-acquires the monitor
    /// (restoration is not performed here).
    ///
    /// # Errors
    ///
    /// `NotLocked` when the monitor does not exist; `NotOwner` when the
    /// caller does not hold it.
    pub fn monitor_wait(&mut self, obj: HeapRef, thread_id: ThreadId) -> Result<u32, SyncError> {
        let state = self.monitors.get_mut(&obj).ok_or(SyncError::NotLocked)?;
        if !state.is_owner(thread_id) {
            return Err(SyncError::NotOwner);
        }
        let saved_count = state.recursion_count;
        // Drop ownership entirely: the waiter gives up the lock while it
        // waits for a pulse.
        state.owner = None;
        state.recursion_count = 0;
        self.monitor_wait_queues
            .entry(obj)
            .or_default()
            .push_back(thread_id);
        Ok(saved_count)
    }

    /// Moves at most one waiting thread (FIFO) from the wait queue to the
    /// pulse queue, making it eligible to re-acquire the monitor.
    ///
    /// # Errors
    ///
    /// `NotLocked` when the monitor does not exist; `NotOwner` when the
    /// caller does not hold it.
    pub fn monitor_pulse(&mut self, obj: HeapRef, thread_id: ThreadId) -> Result<(), SyncError> {
        let state = self.monitors.get(&obj).ok_or(SyncError::NotLocked)?;
        if !state.is_owner(thread_id) {
            return Err(SyncError::NotOwner);
        }
        if let Some(wait_queue) = self.monitor_wait_queues.get_mut(&obj) {
            if let Some(waiting_thread) = wait_queue.pop_front() {
                self.monitor_pulse_queues
                    .entry(obj)
                    .or_default()
                    .push_back(waiting_thread);
            }
        }
        Ok(())
    }

    /// Moves every waiting thread to the pulse queue, preserving FIFO
    /// order.
    ///
    /// # Errors
    ///
    /// `NotLocked` when the monitor does not exist; `NotOwner` when the
    /// caller does not hold it.
    pub fn monitor_pulse_all(
        &mut self,
        obj: HeapRef,
        thread_id: ThreadId,
    ) -> Result<(), SyncError> {
        let state = self.monitors.get(&obj).ok_or(SyncError::NotLocked)?;
        if !state.is_owner(thread_id) {
            return Err(SyncError::NotOwner);
        }
        if let Some(mut wait_queue) = self.monitor_wait_queues.remove(&obj) {
            let pulse_queue = self.monitor_pulse_queues.entry(obj).or_default();
            pulse_queue.append(&mut wait_queue);
        }
        Ok(())
    }

    /// Drains and returns the threads pulsed on `obj`, in pulse order.
    /// Returns an empty vector when nothing was pulsed.
    pub fn monitor_get_pulsed(&mut self, obj: HeapRef) -> Vec<ThreadId> {
        self.monitor_pulse_queues
            .remove(&obj)
            .map(|q| q.into_iter().collect())
            .unwrap_or_default()
    }

    /// Returns `true` when `thread_id` currently owns the monitor for
    /// `obj` (and the monitor exists).
    #[must_use]
    pub fn monitor_is_owner(&self, obj: HeapRef, thread_id: ThreadId) -> bool {
        self.monitors
            .get(&obj)
            .is_some_and(|s| s.is_owner(thread_id))
    }

    /// Registers (or replaces) the event for `obj` with the given reset
    /// policy and initial signal state.
    pub fn event_create(&mut self, obj: HeapRef, manual_reset: bool, initial_state: bool) {
        self.events.insert(
            obj,
            EventState {
                signaled: initial_state,
                manual_reset,
            },
        );
    }

    /// Polls the event: `Some(true)` when signaled (an auto-reset event is
    /// cleared as a side effect), `Some(false)` when not signaled, `None`
    /// when no event exists for `obj`.
    pub fn event_try_wait(&mut self, obj: HeapRef) -> Option<bool> {
        let state = self.events.get_mut(&obj)?;
        if state.signaled {
            if !state.manual_reset {
                state.signaled = false;
            }
            Some(true)
        } else {
            Some(false)
        }
    }

    /// Enqueues `thread_id` to be woken when the event is set.
    pub fn event_add_waiter(&mut self, obj: HeapRef, thread_id: ThreadId) {
        self.event_wait_queues
            .entry(obj)
            .or_default()
            .push_back(thread_id);
    }

    /// Signals the event and returns the threads to wake: all waiters for
    /// a manual-reset event, at most one (FIFO) for an auto-reset event.
    /// Returns an empty vector when no event exists for `obj`.
    pub fn event_set(&mut self, obj: HeapRef) -> Vec<ThreadId> {
        if let Some(state) = self.events.get_mut(&obj) {
            state.signaled = true;
            if state.manual_reset {
                // Manual-reset: the signal stays up, so every waiter runs.
                self.event_wait_queues
                    .remove(&obj)
                    .map(|q| q.into_iter().collect())
                    .unwrap_or_default()
            } else {
                // Auto-reset: release a single waiter.
                self.event_wait_queues
                    .get_mut(&obj)
                    .and_then(VecDeque::pop_front)
                    .into_iter()
                    .collect()
            }
        } else {
            Vec::new()
        }
    }

    /// Clears the event's signal. Returns `false` when no event exists
    /// for `obj`.
    pub fn event_reset(&mut self, obj: HeapRef) -> bool {
        if let Some(state) = self.events.get_mut(&obj) {
            state.signaled = false;
            true
        } else {
            false
        }
    }

    /// Tries to acquire the mutex for `obj`, creating it on first use.
    /// Re-entrant acquisition by the current owner also succeeds.
    pub fn mutex_try_acquire(&mut self, obj: HeapRef, thread_id: ThreadId) -> bool {
        let state = self.mutexes.entry(obj).or_default();
        state.try_acquire(thread_id)
    }

    /// Releases one recursion level of the mutex for `obj`; `Ok(true)`
    /// means the mutex became free.
    ///
    /// # Errors
    ///
    /// `NotLocked` when the mutex does not exist or is not held;
    /// `NotOwner` when another thread holds it.
    pub fn mutex_release(&mut self, obj: HeapRef, thread_id: ThreadId) -> Result<bool, SyncError> {
        let state = self.mutexes.get_mut(&obj).ok_or(SyncError::NotLocked)?;
        state.release(thread_id)
    }

    /// Enqueues `thread_id` to wait for the mutex.
    pub fn mutex_add_waiter(&mut self, obj: HeapRef, thread_id: ThreadId) {
        self.mutex_wait_queues
            .entry(obj)
            .or_default()
            .push_back(thread_id);
    }

    /// Pops the next thread (FIFO) waiting on the mutex, if any.
    pub fn mutex_next_waiter(&mut self, obj: HeapRef) -> Option<ThreadId> {
        self.mutex_wait_queues
            .get_mut(&obj)
            .and_then(VecDeque::pop_front)
    }

    /// Marks every mutex owned by `thread_id` as abandoned and returns the
    /// affected objects; used when a thread terminates while still holding
    /// mutexes.
    pub fn mutex_abandon_for_thread(&mut self, thread_id: ThreadId) -> Vec<HeapRef> {
        let mut abandoned = Vec::new();
        for (obj, state) in &mut self.mutexes {
            if state.owner == Some(thread_id) {
                state.abandon();
                abandoned.push(*obj);
            }
        }
        abandoned
    }

    /// Registers (or replaces) the semaphore for `obj` with the given
    /// initial and maximum permit counts.
    pub fn semaphore_create(&mut self, obj: HeapRef, initial_count: u32, max_count: u32) {
        self.semaphores
            .insert(obj, SemaphoreState::new(initial_count, max_count));
    }

    /// Tries to take one permit: `Some(acquired)` when the semaphore
    /// exists, `None` otherwise.
    #[allow(clippy::redundant_closure_for_method_calls)]
    pub fn semaphore_try_acquire(&mut self, obj: HeapRef) -> Option<bool> {
        self.semaphores.get_mut(&obj).map(|s| s.try_acquire())
    }

    /// Releases `count` permits, returning the previous count together
    /// with the threads to wake (at most `count`, FIFO).
    ///
    /// # Errors
    ///
    /// `ObjectNotFound` when no semaphore exists for `obj`;
    /// `SemaphoreOverflow` when the release would exceed the maximum.
    pub fn semaphore_release(
        &mut self,
        obj: HeapRef,
        count: u32,
    ) -> Result<(u32, Vec<ThreadId>), SyncError> {
        let state = self
            .semaphores
            .get_mut(&obj)
            .ok_or(SyncError::ObjectNotFound)?;
        let previous = state.release(count)?;
        let mut woken = Vec::new();
        if let Some(queue) = self.semaphore_wait_queues.get_mut(&obj) {
            // Wake no more threads than permits released (or waiters
            // queued, whichever is fewer).
            for _ in 0..count.min(u32::try_from(queue.len()).unwrap_or(u32::MAX)) {
                if let Some(thread_id) = queue.pop_front() {
                    woken.push(thread_id);
                }
            }
        }
        Ok((previous, woken))
    }

    /// Enqueues `thread_id` to wait for a semaphore permit.
    pub fn semaphore_add_waiter(&mut self, obj: HeapRef, thread_id: ThreadId) {
        self.semaphore_wait_queues
            .entry(obj)
            .or_default()
            .push_back(thread_id);
    }

    /// Removes a terminated thread from all bookkeeping: abandons any
    /// mutexes it owned and purges it from every wait/pulse queue.
    pub fn cleanup_thread(&mut self, thread_id: ThreadId) {
        self.mutex_abandon_for_thread(thread_id);
        for queue in self.monitor_wait_queues.values_mut() {
            queue.retain(|&id| id != thread_id);
        }
        for queue in self.monitor_pulse_queues.values_mut() {
            queue.retain(|&id| id != thread_id);
        }
        for queue in self.event_wait_queues.values_mut() {
            queue.retain(|&id| id != thread_id);
        }
        for queue in self.mutex_wait_queues.values_mut() {
            queue.retain(|&id| id != thread_id);
        }
        for queue in self.semaphore_wait_queues.values_mut() {
            queue.retain(|&id| id != thread_id);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Monitor supports re-entrant entry by the owner, rejects other
    // threads, and only frees once the recursion count drops to zero.
    #[test]
    fn test_monitor_enter_exit() {
        let mut state = MonitorState::default();
        let thread1 = ThreadId::new(1);
        let thread2 = ThreadId::new(2);
        assert!(state.try_enter(thread1));
        assert!(state.is_owner(thread1));
        assert!(!state.try_enter(thread2));
        assert!(state.try_enter(thread1));
        assert_eq!(state.recursion_count, 2);
        assert!(!state.exit(thread1).unwrap());
        assert!(state.is_owner(thread1));
        assert!(state.exit(thread1).unwrap());
        assert!(state.is_free());
        assert!(state.try_enter(thread2));
    }

    // Exiting an unheld monitor is NotLocked; exiting someone else's is
    // NotOwner.
    #[test]
    fn test_monitor_exit_errors() {
        let mut state = MonitorState::default();
        let thread1 = ThreadId::new(1);
        let thread2 = ThreadId::new(2);
        assert_eq!(state.exit(thread1), Err(SyncError::NotLocked));
        state.try_enter(thread1);
        assert_eq!(state.exit(thread2), Err(SyncError::NotOwner));
    }

    // Semaphore permits are consumed and restored; release reports the
    // prior count and rejects exceeding the maximum.
    #[test]
    fn test_semaphore() {
        let mut state = SemaphoreState::new(2, 5);
        assert!(state.try_acquire());
        assert!(state.try_acquire());
        assert!(!state.try_acquire());
        assert_eq!(state.release(1).unwrap(), 0);
        assert!(state.try_acquire());
        assert_eq!(state.release(10), Err(SyncError::SemaphoreOverflow));
    }

    // Manual-reset events stay signaled across waits until reset.
    #[test]
    fn test_event_manual_reset() {
        let mut sync = SyncState::new();
        let obj = HeapRef::new(1);
        sync.event_create(obj, true, false);
        assert_eq!(sync.event_try_wait(obj), Some(false));
        sync.event_set(obj);
        assert_eq!(sync.event_try_wait(obj), Some(true));
        assert_eq!(sync.event_try_wait(obj), Some(true));
        sync.event_reset(obj);
        assert_eq!(sync.event_try_wait(obj), Some(false));
    }

    // Auto-reset events are consumed by a single successful wait.
    #[test]
    fn test_event_auto_reset() {
        let mut sync = SyncState::new();
        let obj = HeapRef::new(1);
        sync.event_create(obj, false, true);
        assert_eq!(sync.event_try_wait(obj), Some(true));
        assert_eq!(sync.event_try_wait(obj), Some(false));
    }

    // Registry-level monitor: mutual exclusion between threads and
    // handoff after exit.
    #[test]
    fn test_sync_state_monitor() {
        let mut sync = SyncState::new();
        let obj = HeapRef::new(1);
        let thread1 = ThreadId::new(1);
        let thread2 = ThreadId::new(2);
        assert!(sync.monitor_try_enter(obj, thread1));
        assert!(!sync.monitor_try_enter(obj, thread2));
        assert!(sync.monitor_is_owner(obj, thread1));
        sync.monitor_exit(obj, thread1).unwrap();
        assert!(sync.monitor_try_enter(obj, thread2));
    }

    // wait releases the monitor so another thread can enter and pulse the
    // waiter back out via the pulse queue.
    #[test]
    fn test_sync_state_monitor_wait_pulse() {
        let mut sync = SyncState::new();
        let obj = HeapRef::new(1);
        let thread1 = ThreadId::new(1);
        let thread2 = ThreadId::new(2);
        sync.monitor_try_enter(obj, thread1);
        sync.monitor_wait(obj, thread1).unwrap();
        sync.monitor_try_enter(obj, thread2);
        sync.monitor_pulse(obj, thread2).unwrap();
        let pulsed = sync.monitor_get_pulsed(obj);
        assert_eq!(pulsed, vec![thread1]);
    }

    // Thread teardown abandons owned mutexes and purges the thread from
    // wait queues.
    #[test]
    fn test_cleanup_thread() {
        let mut sync = SyncState::new();
        let obj = HeapRef::new(1);
        let thread1 = ThreadId::new(1);
        sync.mutex_try_acquire(obj, thread1);
        sync.monitor_wait_queues
            .entry(HeapRef::new(2))
            .or_default()
            .push_back(thread1);
        sync.cleanup_thread(thread1);
        assert!(sync.mutexes.get(&obj).unwrap().abandoned);
        assert!(sync
            .monitor_wait_queues
            .get(&HeapRef::new(2))
            .map(|q| q.is_empty())
            .unwrap_or(true));
    }
}