use std::hint::spin_loop;
use std::thread::yield_now;
use std::sync::atomic::{AtomicU32, Ordering};
use std::thread;
use std::time::Duration;
/// Adaptive spin-wait helper with an escalating back-off.
///
/// Every call to [`SpinWait::spin_once`] increments an internal counter and
/// picks a waiting strategy from it: a CPU spin hint for the first few calls,
/// short exponentially growing sleeps after that, and a scheduler yield once
/// the counter reaches the configured yield threshold.
///
/// All methods take `&self`. The counter is an atomic accessed with
/// `Ordering::Relaxed`, so it does not order any other memory accesses.
#[derive(Debug)]
pub struct SpinWait {
    /// Number of `spin_once` calls since construction or the last `reset`.
    count: AtomicU32,
    /// Counter value from which `spin_once` yields the thread.
    yield_threshold: u32,
}

impl SpinWait {
    /// Yield threshold used by [`SpinWait::new`].
    const DEFAULT_YIELD_THRESHOLD: u32 = 10;
    /// Counts below this value only emit a spin-loop hint.
    const BUSY_SPIN_LIMIT: u32 = 4;
    /// Largest exponent used for the sleep back-off: 2^20 ns is about 1 ms.
    const MAX_SLEEP_SHIFT: u32 = 20;

    /// Creates a spinner that starts yielding after 10 spins.
    pub fn new() -> Self {
        Self::with_threshold(Self::DEFAULT_YIELD_THRESHOLD)
    }

    /// Creates a spinner that starts yielding once `yield_threshold` spins
    /// have been performed. A threshold of `0` yields on every spin.
    pub fn with_threshold(yield_threshold: u32) -> Self {
        Self {
            count: AtomicU32::new(0),
            yield_threshold,
        }
    }

    /// Performs one back-off step and increments the spin counter.
    ///
    /// With `n` being the counter value after the increment:
    ///
    /// * `n >= yield_threshold`, or only one CPU is available: yield the
    ///   thread to the scheduler;
    /// * otherwise `n < 4`: emit a spin-loop hint;
    /// * otherwise: sleep for `2^n` nanoseconds, capped at `2^20` ns.
    ///
    /// The counter wraps around on overflow.
    pub fn spin_once(&self) {
        // `fetch_add` returns the previous value. Deriving the new count from
        // it keeps every decision below on one consistent snapshot, even when
        // another thread spins on the same instance concurrently.
        let count = self.count.fetch_add(1, Ordering::Relaxed).wrapping_add(1);

        if count >= self.yield_threshold || Self::is_single_cpu() {
            yield_now();
        } else if count < Self::BUSY_SPIN_LIMIT {
            spin_loop();
        } else {
            thread::sleep(Self::backoff_duration(count));
        }
    }

    /// Returns the number of spins since construction or the last reset.
    pub fn count(&self) -> u32 {
        self.count.load(Ordering::Relaxed)
    }

    /// Returns `true` once the spin counter has reached the yield threshold.
    pub fn next_spin_will_yield(&self) -> bool {
        self.count() >= self.yield_threshold
    }

    /// Resets the spin counter to zero, restarting the back-off sequence.
    pub fn reset(&self) {
        self.count.store(0, Ordering::Relaxed);
    }

    /// Calls [`SpinWait::spin_once`] repeatedly until `condition` returns
    /// `true`.
    ///
    /// `condition` is evaluated before every spin, so no spin is performed if
    /// it already holds. Stateful (`FnMut`) closures are accepted.
    pub fn spin_until<F>(&self, mut condition: F)
    where
        F: FnMut() -> bool,
    {
        while !condition() {
            self.spin_once();
        }
    }

    /// Sleep length for the given spin count: `2^count` ns with the exponent
    /// capped at `MAX_SLEEP_SHIFT`.
    ///
    /// The cap keeps the shift below the 64-bit width (shifting a `u64` by 64
    /// or more overflows) and bounds a single sleep to about one millisecond.
    fn backoff_duration(count: u32) -> Duration {
        let shift = count.min(Self::MAX_SLEEP_SHIFT);
        Duration::from_nanos(1u64 << shift)
    }

    /// Reports whether only a single CPU is available to this process.
    ///
    /// The probe runs once and its answer is cached in a static, because this
    /// is consulted on the hot path of `spin_once`.
    fn is_single_cpu() -> bool {
        const UNKNOWN: u32 = 0;
        const SINGLE: u32 = 1;
        const MULTI: u32 = 2;
        static CPU_STATE: AtomicU32 = AtomicU32::new(UNKNOWN);

        match CPU_STATE.load(Ordering::Relaxed) {
            SINGLE => true,
            MULTI => false,
            _ => {
                // If the parallelism cannot be determined, assume one CPU:
                // yielding is always safe, whereas busy-waiting on a single
                // CPU only delays the thread being waited for.
                let single = thread::available_parallelism().map_or(true, |n| n.get() == 1);
                // Racing threads compute the same answer, so a plain store
                // is sufficient here.
                CPU_STATE.store(if single { SINGLE } else { MULTI }, Ordering::Relaxed);
                single
            }
        }
    }
}

impl Default for SpinWait {
    /// Equivalent to [`SpinWait::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    /// A single spin moves the counter from zero to one.
    #[test]
    fn test_spin_once_increments_count() {
        let waiter = SpinWait::new();
        assert_eq!(waiter.count(), 0);

        waiter.spin_once();

        assert_eq!(waiter.count(), 1);
    }

    /// `reset` brings the counter back to zero after some spins.
    #[test]
    fn test_reset_clears_count() {
        let waiter = SpinWait::new();
        waiter.spin_once();
        waiter.spin_once();
        assert_eq!(waiter.count(), 2);

        waiter.reset();

        assert_eq!(waiter.count(), 0);
    }

    /// The yield prediction stays false for every spin below the threshold
    /// and becomes true once the threshold has been reached.
    #[test]
    fn test_next_spin_will_yield() {
        let waiter = SpinWait::new();
        let mut remaining = waiter.yield_threshold;

        while remaining > 0 {
            assert!(!waiter.next_spin_will_yield());
            waiter.spin_once();
            remaining -= 1;
        }

        assert!(waiter.next_spin_will_yield());
    }

    /// `spin_until` returns only after another thread has set the flag.
    #[test]
    fn test_spin_until() {
        let ready = Arc::new(AtomicBool::new(false));
        let waiter = SpinWait::new();

        let setter_flag = Arc::clone(&ready);
        let setter = thread::spawn(move || {
            thread::sleep(std::time::Duration::from_millis(10));
            setter_flag.store(true, Ordering::Relaxed);
        });

        waiter.spin_until(|| ready.load(Ordering::Relaxed));

        assert!(ready.load(Ordering::Relaxed));
        setter.join().unwrap();
    }
}