use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
/// A fixed-size set of 256 flags, one per `u8` index, stored in four atomic
/// 64-bit words so that it can be read and modified through a shared `&self`.
pub struct PreemptMask {
    // Flag `i` lives in `bits[i / 64]` at bit position `i % 64`.
    bits: [AtomicU64; 4],
}

impl std::fmt::Debug for PreemptMask {
    // Prints the raw value of each word. Each word is loaded separately with
    // `Relaxed`, so under concurrent writes the output is not a single
    // consistent snapshot of the whole mask.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PreemptMask")
            .field("bits[0]", &self.bits[0].load(Ordering::Relaxed))
            .field("bits[1]", &self.bits[1].load(Ordering::Relaxed))
            .field("bits[2]", &self.bits[2].load(Ordering::Relaxed))
            .field("bits[3]", &self.bits[3].load(Ordering::Relaxed))
            .finish()
    }
}

impl Default for PreemptMask {
    /// Equivalent to [`PreemptMask::new`]: every flag cleared.
    fn default() -> Self {
        Self::new()
    }
}

impl PreemptMask {
    /// Creates a mask with every flag cleared.
    pub fn new() -> Self {
        Self {
            bits: [
                AtomicU64::new(0),
                AtomicU64::new(0),
                AtomicU64::new(0),
                AtomicU64::new(0),
            ],
        }
    }

    /// Maps flag index `i` to `(word index, single-bit mask within that word)`.
    ///
    /// A `u8` is at most 255, so the word index is always in `0..4` and the
    /// indexing in `set`/`contains` cannot go out of bounds.
    #[inline]
    fn locate(i: u8) -> (usize, u64) {
        let i = usize::from(i);
        (i / 64, 1u64 << (i % 64))
    }

    /// Sets flag `i`, leaving every other flag unchanged.
    ///
    /// Uses `Release`, pairing with the `Acquire` load in [`contains`](Self::contains).
    #[inline]
    pub fn set(&self, i: u8) {
        let (word, bit) = Self::locate(i);
        self.bits[word].fetch_or(bit, Ordering::Release);
    }

    /// Clears every flag.
    ///
    /// Each word is zeroed by its own store, so a concurrent reader may observe
    /// some words already cleared and others not yet cleared.
    #[inline]
    pub fn clear(&self) {
        for word in &self.bits {
            word.store(0, Ordering::Release);
        }
    }

    /// Returns `true` if flag `i` is currently set.
    #[inline]
    pub fn contains(&self, i: u8) -> bool {
        let (word, bit) = Self::locate(i);
        (self.bits[word].load(Ordering::Acquire) & bit) != 0
    }
}
/// Preemption bookkeeping: a "preempt requested" flag plus a [`PreemptMask`]
/// holding the queue indices most recently passed to
/// [`update_mask`](PreemptState::update_mask) as higher-priority queues.
pub struct PreemptState {
    mask: PreemptMask,
    preempt: AtomicBool,
}

impl std::fmt::Debug for PreemptState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PreemptState")
            .field("mask", &self.mask)
            .field("preempt", &self.preempt.load(Ordering::Relaxed))
            .finish()
    }
}

impl Default for PreemptState {
    /// Equivalent to [`PreemptState::new`]: empty mask, no preemption requested.
    fn default() -> Self {
        Self::new()
    }
}

impl PreemptState {
    /// Creates a state with an empty mask and no pending preemption request.
    pub fn new() -> Self {
        Self {
            mask: PreemptMask::new(),
            preempt: AtomicBool::new(false),
        }
    }

    /// Returns `true` if a preemption has been requested and not yet cleared.
    #[inline]
    pub fn check(&self) -> bool {
        self.preempt.load(Ordering::Acquire)
    }

    /// Raises the preemption flag (`Release`, paired with the `Acquire` in `check`).
    #[inline]
    pub fn request_preempt(&self) {
        self.preempt.store(true, Ordering::Release);
    }

    /// Lowers the preemption flag.
    #[inline]
    pub fn clear_preempt(&self) {
        self.preempt.store(false, Ordering::Release);
    }

    /// Returns `true` if queue `qidx` is in the current higher-priority mask.
    ///
    /// The mask can only represent indices `0..=255`. Larger indices are never
    /// in it, so this returns `false` for them instead of reporting the bit of
    /// some unrelated queue.
    #[inline]
    pub fn would_preempt(&self, qidx: usize) -> bool {
        // `u8::try_from` instead of `as u8`: a cast would wrap, e.g. 256 -> 0.
        match u8::try_from(qidx) {
            Ok(i) => self.mask.contains(i),
            Err(_) => false,
        }
    }

    /// Replaces the mask with exactly the queue indices yielded by
    /// `higher_priority_queues`.
    ///
    /// Indices above 255 cannot be represented and are skipped rather than
    /// truncated onto another queue's bit. The new mask is built locally and
    /// then published with one `Release` store per word. A concurrent reader
    /// therefore never sees a word transiently zeroed mid-update. Words are
    /// still published independently, so a reader racing with this call may
    /// see some old words and some new ones.
    #[inline]
    pub fn update_mask<I: Iterator<Item = usize>>(&self, higher_priority_queues: I) {
        // Same layout as `PreemptMask::bits`: four 64-bit words, bit `i % 64` of word `i / 64`.
        let mut words = [0u64; 4];
        for qidx in higher_priority_queues {
            if let Ok(i) = u8::try_from(qidx) {
                let i = usize::from(i);
                words[i / 64] |= 1u64 << (i % 64);
            }
        }
        for (slot, value) in self.mask.bits.iter().zip(words.iter()) {
            slot.store(*value, Ordering::Release);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_preemption_mask_basic() {
        let mask = PreemptMask::new();
        assert!(!mask.contains(0));
        assert!(!mask.contains(100));
        assert!(!mask.contains(255));
        mask.set(0);
        mask.set(63);
        mask.set(64);
        mask.set(127);
        mask.set(128);
        mask.set(255);
        assert!(mask.contains(0));
        assert!(mask.contains(63));
        assert!(mask.contains(64));
        assert!(mask.contains(127));
        assert!(mask.contains(128));
        assert!(mask.contains(255));
        assert!(!mask.contains(1));
        assert!(!mask.contains(100));
        assert!(!mask.contains(200));
        mask.clear();
        assert!(!mask.contains(0));
        assert!(!mask.contains(255));
    }

    // The Debug impl prints each backing word's raw value.
    #[test]
    fn test_preemption_mask_debug_shows_words() {
        let mask = PreemptMask::new();
        mask.set(1);
        mask.set(64);
        let text = format!("{:?}", mask);
        assert!(text.contains("bits[0]: 2"));
        assert!(text.contains("bits[1]: 1"));
        assert!(text.contains("bits[3]: 0"));
    }

    #[test]
    fn test_preemption_state_check_clear() {
        let state = PreemptState::new();
        assert!(!state.check());
        state.request_preempt();
        assert!(state.check());
        state.clear_preempt();
        assert!(!state.check());
    }

    // `update_mask` must replace, not accumulate, the previous set of queues.
    #[test]
    fn test_update_mask_replaces_previous_contents() {
        let state = PreemptState::new();
        assert!(!state.would_preempt(0));
        state.update_mask(vec![1usize, 64, 255].into_iter());
        assert!(state.would_preempt(1));
        assert!(state.would_preempt(64));
        assert!(state.would_preempt(255));
        assert!(!state.would_preempt(2));
        state.update_mask(vec![2usize].into_iter());
        assert!(state.would_preempt(2));
        assert!(!state.would_preempt(1));
        assert!(!state.would_preempt(64));
        assert!(!state.would_preempt(255));
        state.update_mask(std::iter::empty());
        assert!(!state.would_preempt(2));
    }
}