#![cfg_attr(all(test, feature = "bench"), feature(test))]
#[cfg(all(test, feature = "bench"))]
extern crate test;
#[cfg(test)]
mod tests;
#[cfg(target_os = "linux")]
pub use linux::Once;
#[cfg(not(target_os = "linux"))]
pub use std::sync::Once;
#[cfg(target_os = "linux")]
mod linux {
    use core::sync::atomic::Ordering;
    use linux_futex::{Futex, Private};

    /// One-time initialization primitive built directly on a process-private
    /// Linux futex.
    ///
    /// It provides the same `new` / `call_once` / `is_completed` surface as
    /// `std::sync::Once`, which the crate re-exports on non-Linux targets.
    /// The futex word always holds one of the state constants below.
    pub struct Once(Futex<Private>);

    /// No closure has started yet.
    const INCOMPLETE: i32 = 0;
    /// A closure ran to completion; every later `call_once` is a no-op.
    const COMPLETE: i32 = 1;
    /// A closure panicked; every later `call_once` panics.
    const POISONED: i32 = 2;
    /// A closure is running and no thread has announced itself as a waiter.
    const RUNNING_NO_WAIT: i32 = 3;
    /// A closure is running and at least one thread may be sleeping on the
    /// futex, so the running thread must issue a wake when it finishes.
    const RUNNING_WAITING: i32 = 4;

    impl Once {
        /// Creates a new `Once` in the incomplete state.
        pub const fn new() -> Self {
            Once(Futex::new(INCOMPLETE))
        }

        /// Runs `f` if and only if no previous call on this `Once` has completed.
        ///
        /// If another thread is currently running its closure, this call blocks
        /// until that closure finishes. The completing thread stores `COMPLETE`
        /// with `AcqRel` ordering and readers load with `Acquire`, so the
        /// closure's effects are visible once this returns.
        ///
        /// # Panics
        ///
        /// Panics if the `Once` is poisoned (a previous closure panicked). This
        /// includes the case where this thread was blocked waiting on the
        /// closure that panicked. If `f` itself panics, the panic propagates
        /// and the `Once` becomes poisoned.
        pub fn call_once<F: FnOnce()>(&self, f: F) {
            // Fast path: a single Acquire load once initialization is done.
            let state = self.0.value.load(Ordering::Acquire);
            if state == COMPLETE {
                return;
            }
            // Type-erase the closure so the cold slow path is compiled once,
            // not once per closure type.
            let mut f = Some(f);
            self.internal_call_once(state, &mut || f.take().expect("closure called more than once")())
        }

        /// Slow path of `call_once`.
        ///
        /// Drives the state machine, starting from the already-observed
        /// `state`, until this thread has run `f`, observed `COMPLETE`, or
        /// panicked on `POISONED`.
        #[cold]
        fn internal_call_once(&self, mut state: i32, f: &mut dyn FnMut()) {
            /// Drop guard that stores the final state even if `f` unwinds.
            ///
            /// It starts out as `POISONED`. The success path overwrites that with
            /// `COMPLETE` just before the guard is dropped.
            struct PanicChecker<'a> {
                futex: &'a Futex<Private>,
                value_to_write: i32,
            }

            impl Drop for PanicChecker<'_> {
                fn drop(&mut self) {
                    // The Release half of AcqRel publishes the closure's effects.
                    // A wake is only needed if some thread announced itself as a
                    // waiter by moving the state to RUNNING_WAITING.
                    if self.futex.value.swap(self.value_to_write, Ordering::AcqRel) == RUNNING_WAITING {
                        self.futex.wake(i32::MAX);
                    }
                }
            }

            loop {
                match state {
                    INCOMPLETE => {
                        // A spurious failure of the weak CAS yields INCOMPLETE
                        // again and simply re-enters this arm.
                        if let Err(old) = self.0.value.compare_exchange_weak(
                            INCOMPLETE,
                            RUNNING_NO_WAIT,
                            Ordering::Acquire,
                            Ordering::Acquire,
                        ) {
                            state = old;
                            continue;
                        }
                        let mut guard = PanicChecker { futex: &self.0, value_to_write: POISONED };
                        f();
                        guard.value_to_write = COMPLETE;
                        // `guard` is dropped on return, storing COMPLETE and
                        // waking any waiters.
                        return;
                    }
                    COMPLETE => return,
                    POISONED => panic!("Once instance has previously been poisoned"),
                    RUNNING_NO_WAIT => {
                        // Announce that a waiter exists so the running thread
                        // wakes us. If the state moved on instead, re-dispatch on
                        // whatever we observed.
                        state = match self.0.value.compare_exchange(
                            RUNNING_NO_WAIT,
                            RUNNING_WAITING,
                            Ordering::AcqRel,
                            Ordering::Acquire,
                        ) {
                            Ok(_) => RUNNING_WAITING,
                            Err(old) => old,
                        };
                    }
                    RUNNING_WAITING => {
                        // Sleep only while the futex word still holds
                        // RUNNING_WAITING. Any error from `wait` (e.g. the value
                        // already changed) is fine because the state is re-read
                        // below.
                        let _ = self.0.wait(RUNNING_WAITING);
                        // Re-dispatch instead of returning. If the closure
                        // panicked we must observe POISONED and panic too, rather
                        // than report success.
                        state = self.0.value.load(Ordering::Acquire);
                    }
                    // Only the five constants above are ever stored.
                    other => unreachable!("invalid Once state: {}", other),
                }
            }
        }

        /// Returns `true` if a closure passed to `call_once` has run to
        /// completion. A poisoned `Once` reports `false`.
        #[inline]
        pub fn is_completed(&self) -> bool {
            self.0.value.load(Ordering::Acquire) == COMPLETE
        }
    }

    // `std::sync::Once` (used on other targets) implements Debug. Provide one
    // here too so user types deriving Debug compile on every target.
    impl core::fmt::Debug for Once {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            f.pad("Once { .. }")
        }
    }
}
#[cfg(test)]
mod our_tests {
    //! Behavioural tests for this crate's `Once`, plus benchmarks behind the
    //! `bench` feature (nightly only) that compare it with `std::sync::Once`.

    use super::Once;
    use std::panic::{catch_unwind, AssertUnwindSafe};
    use std::sync::{Arc, atomic::{AtomicUsize, Ordering::Relaxed}};
    #[cfg(feature = "bench")]
    use test::Bencher;

    /// Number of threads racing on a single `Once` in each contended iteration.
    #[cfg(feature = "bench")]
    const CONTENDED_THREADS: usize = 5;
    /// Nanoseconds (1 ms) the winning closure sleeps. This keeps the `Once`
    /// running long enough that the other threads are likely to block on it.
    #[cfg(feature = "bench")]
    const CONTENDED_WAIT: u64 = 1_000_000;

    /// The closure runs on the first call and is skipped on every later call.
    #[test]
    fn basic() {
        let mut ran = false;
        let once = Once::new();
        once.call_once(|| ran = true);
        assert!(ran);
        ran = false;
        once.call_once(|| ran = true);
        assert!(!ran);
    }

    /// Two threads racing on one `Once` run the closure exactly once in total.
    #[test]
    fn multithreaded() {
        let once = Arc::new((Once::new(), AtomicUsize::new(0)));
        let once_cloned = Arc::clone(&once);
        let handle = std::thread::spawn(move || once_cloned.0.call_once(|| { once_cloned.1.fetch_add(1, Relaxed); }));
        once.0.call_once(|| { once.1.fetch_add(1, Relaxed); });
        handle.join().expect("failed to join thread");
        assert_eq!(once.1.load(Relaxed), 1);
    }

    /// A panicking closure poisons the `Once`. It never reports completion, and
    /// later calls panic without running their closure.
    #[test]
    fn poisoned() {
        let once = Once::new();
        let first = catch_unwind(AssertUnwindSafe(|| once.call_once(|| panic!("initializer failed"))));
        assert!(first.is_err());
        assert!(!once.is_completed());

        let mut ran = false;
        let second = catch_unwind(AssertUnwindSafe(|| once.call_once(|| ran = true)));
        assert!(second.is_err());
        assert!(!ran);
    }

    /// Uncontended cost of creating and running a fresh `std::sync::Once`.
    #[cfg(feature = "bench")]
    #[bench]
    #[cfg_attr(miri, ignore)]
    fn measure_std_trivial(bencher: &mut Bencher) {
        bencher.iter(|| {
            let mut ran = false;
            let once = std::sync::Once::new();
            once.call_once(|| ran = true);
            assert!(ran);
        })
    }

    /// Uncontended cost of creating and running a fresh `Once` from this crate.
    #[cfg(feature = "bench")]
    #[bench]
    #[cfg_attr(miri, ignore)]
    fn measure_linux_trivial(bencher: &mut Bencher) {
        bencher.iter(|| {
            let mut ran = false;
            let once = Once::new();
            once.call_once(|| ran = true);
            assert!(ran);
        })
    }

    /// `CONTENDED_THREADS` threads, released together by a barrier, race on one
    /// `std::sync::Once` whose closure sleeps for `CONTENDED_WAIT` ns.
    #[cfg(feature = "bench")]
    #[bench]
    #[cfg_attr(miri, ignore)]
    fn measure_std_contended(bencher: &mut Bencher) {
        let barrier = Arc::new(std::sync::Barrier::new(CONTENDED_THREADS));
        bencher.iter(|| {
            let once = Arc::new(std::sync::Once::new());
            let threads = (0..CONTENDED_THREADS)
                .map(|_| {
                    let cloned_once = Arc::clone(&once);
                    let cloned_barrier = Arc::clone(&barrier);
                    std::thread::spawn(move || {
                        cloned_barrier.wait();
                        cloned_once.call_once(|| std::thread::sleep(std::time::Duration::from_nanos(CONTENDED_WAIT)))
                    })
                })
                .collect::<Vec<_>>();
            threads
                .into_iter()
                .map(|thread| thread.join().map(drop))
                .collect::<Result<(), _>>()
                .expect("Failed to join");
        })
    }

    /// Same workload as `measure_std_contended`, using this crate's `Once`.
    #[cfg(feature = "bench")]
    #[bench]
    #[cfg_attr(miri, ignore)]
    fn measure_linux_contended(bencher: &mut Bencher) {
        let barrier = Arc::new(std::sync::Barrier::new(CONTENDED_THREADS));
        bencher.iter(|| {
            let once = Arc::new(Once::new());
            let threads = (0..CONTENDED_THREADS)
                .map(|_| {
                    let cloned = Arc::clone(&once);
                    let cloned_barrier = Arc::clone(&barrier);
                    std::thread::spawn(move || {
                        cloned_barrier.wait();
                        cloned.call_once(|| std::thread::sleep(std::time::Duration::from_nanos(CONTENDED_WAIT)))
                    })
                })
                .collect::<Vec<_>>();
            threads
                .into_iter()
                .map(|thread| thread.join().map(drop))
                .collect::<Result<(), _>>()
                .expect("Failed to join");
        })
    }
}