use core::sync::atomic::{AtomicU8, Ordering};
/// Lifecycle states tracked by a `StateCell`.
///
/// The enum is `#[repr(u8)]` with explicit discriminants so a state can be
/// stored as a single byte in an `AtomicU8` and decoded back with `from_u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum State {
    /// Starting state of a fresh cell; also the state `reset` returns to.
    Uninitialized = 0,
    /// Reached from `Uninitialized` via `initialize`, or from `Running` via `stop`.
    Initialized = 1,
    /// Reached from `Initialized` via `start`; the only state that allows I/O.
    Running = 2,
}

impl State {
    /// Reports whether I/O is permitted while in this state.
    ///
    /// Only [`State::Running`] permits I/O. The match is exhaustive on purpose,
    /// so adding a variant forces a decision here.
    #[inline]
    #[must_use]
    pub const fn allows_io(self) -> bool {
        match self {
            State::Running => true,
            State::Uninitialized | State::Initialized => false,
        }
    }

    /// Decodes a byte that was produced by `State as u8`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not one of the discriminants declared on [`State`].
    /// Every write performed by `StateCell` in this file stores `State as u8`,
    /// so reaching the panic means an internal invariant was broken.
    #[inline]
    const fn from_u8(value: u8) -> Self {
        // Derive the patterns from the declared discriminants so the decoder
        // cannot drift out of sync with the enum definition.
        const UNINITIALIZED: u8 = State::Uninitialized as u8;
        const INITIALIZED: u8 = State::Initialized as u8;
        const RUNNING: u8 = State::Running as u8;

        match value {
            UNINITIALIZED => State::Uninitialized,
            INITIALIZED => State::Initialized,
            RUNNING => State::Running,
            _ => panic!("StateCell stored an invalid State byte"),
        }
    }
}
/// Error returned when a `StateCell` transition is rejected because the cell
/// was not in the required source state.
///
/// The cell is left untouched when this error is produced; `actual` records
/// the state that was observed instead of `expected`.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub struct TransitionError {
    /// State the cell had to be in for the transition to proceed.
    pub expected: State,
    /// State the transition would have moved the cell to.
    pub attempted: State,
    /// State the cell was actually observed in when the transition failed.
    pub actual: State,
}

// A public error type should be printable in user-facing messages and logs.
// `core::fmt` keeps this usable without `std`, matching the file's `core::` imports.
impl core::fmt::Display for TransitionError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(
            f,
            "invalid state transition to {:?}: expected {:?}, found {:?}",
            self.attempted, self.expected, self.actual
        )
    }
}
/// Atomic holder for a [`State`] that can be shared between threads.
///
/// The state is kept as its `u8` discriminant inside an [`AtomicU8`]. Every
/// write in this type stores `State as u8`, so every read can decode it back.
#[derive(Debug)]
pub struct StateCell {
    state: AtomicU8,
}

impl StateCell {
    /// Creates a cell in [`State::Uninitialized`].
    ///
    /// Because this is a `const fn`, the cell can initialize a `static`.
    #[inline]
    #[must_use]
    pub const fn new() -> Self {
        let initial = State::Uninitialized as u8;
        Self {
            state: AtomicU8::new(initial),
        }
    }

    /// Returns the current state, read with `Acquire` ordering.
    #[inline]
    #[must_use]
    pub fn load(&self) -> State {
        let raw = self.state.load(Ordering::Acquire);
        State::from_u8(raw)
    }

    /// Returns `true` when the cell currently holds [`State::Running`].
    #[inline]
    #[must_use]
    pub fn is_running(&self) -> bool {
        matches!(self.load(), State::Running)
    }

    /// Moves the cell from `Uninitialized` to `Initialized`.
    ///
    /// # Errors
    ///
    /// Returns a [`TransitionError`] if the cell was not `Uninitialized`. The
    /// cell is left unchanged in that case.
    #[inline]
    pub fn initialize(&self) -> Result<(), TransitionError> {
        self.transition(State::Uninitialized, State::Initialized)
    }

    /// Moves the cell from `Initialized` to `Running`.
    ///
    /// # Errors
    ///
    /// Returns a [`TransitionError`] if the cell was not `Initialized`. The
    /// cell is left unchanged in that case.
    #[inline]
    pub fn start(&self) -> Result<(), TransitionError> {
        self.transition(State::Initialized, State::Running)
    }

    /// Moves the cell from `Running` back to `Initialized`.
    ///
    /// # Errors
    ///
    /// Returns a [`TransitionError`] if the cell was not `Running`. The cell
    /// is left unchanged in that case.
    #[inline]
    pub fn stop(&self) -> Result<(), TransitionError> {
        self.transition(State::Running, State::Initialized)
    }

    /// Unconditionally returns the cell to [`State::Uninitialized`].
    ///
    /// Returns the state the cell held immediately before the reset. The swap
    /// uses `AcqRel` ordering.
    #[inline]
    pub fn reset(&self) -> State {
        let previous = self
            .state
            .swap(State::Uninitialized as u8, Ordering::AcqRel);
        State::from_u8(previous)
    }

    /// Atomically replaces `source` with `target` using a single strong
    /// compare-exchange: `AcqRel` on success, `Acquire` on failure.
    ///
    /// On failure, the observed byte is decoded into `TransitionError::actual`.
    fn transition(&self, source: State, target: State) -> Result<(), TransitionError> {
        self.state
            .compare_exchange(
                source as u8,
                target as u8,
                Ordering::AcqRel,
                Ordering::Acquire,
            )
            .map(|_previous| ())
            .map_err(|observed| TransitionError {
                expected: source,
                attempted: target,
                actual: State::from_u8(observed),
            })
    }
}

impl Default for StateCell {
    /// Equivalent to [`StateCell::new`]: the cell starts `Uninitialized`.
    fn default() -> Self {
        StateCell::new()
    }
}
// Unit tests for the `State` / `StateCell` lifecycle, including a contended
// `initialize` race that checks the compare-exchange admits exactly one winner.
#[cfg(test)]
mod tests {
    use super::*;
    use static_assertions::assert_impl_all;

    // Compile-time check that a `StateCell` can be shared across threads.
    assert_impl_all!(StateCell: Send, Sync);

    #[test]
    fn new_cell_starts_uninitialized() {
        let s = StateCell::new();
        assert_eq!(s.load(), State::Uninitialized);
        assert!(!s.is_running());
    }

    #[test]
    fn new_is_usable_in_static() {
        // `StateCell::new` is a `const fn`, so it can back a global.
        static GLOBAL: StateCell = StateCell::new();
        assert_eq!(GLOBAL.load(), State::Uninitialized);
    }

    #[test]
    fn default_matches_new() {
        assert_eq!(StateCell::default().load(), State::Uninitialized);
    }

    #[test]
    fn initialize_uninitialized_to_initialized() {
        let s = StateCell::new();
        assert!(s.initialize().is_ok());
        assert_eq!(s.load(), State::Initialized);
    }

    #[test]
    fn initialize_when_already_initialized_errors() {
        let s = StateCell::new();
        s.initialize().unwrap();
        let err = s.initialize().unwrap_err();
        assert_eq!(err.expected, State::Uninitialized);
        assert_eq!(err.attempted, State::Initialized);
        assert_eq!(err.actual, State::Initialized);
    }

    #[test]
    fn start_requires_initialized() {
        let s = StateCell::new();
        let err = s.start().unwrap_err();
        assert_eq!(err.expected, State::Initialized);
        assert_eq!(err.attempted, State::Running);
        assert_eq!(err.actual, State::Uninitialized);
        s.initialize().unwrap();
        assert!(s.start().is_ok());
        assert!(s.is_running());
        assert!(s.load().allows_io());
    }

    #[test]
    fn start_when_already_running_errors() {
        let s = StateCell::new();
        s.initialize().unwrap();
        s.start().unwrap();
        let err = s.start().unwrap_err();
        assert_eq!(err.actual, State::Running);
    }

    #[test]
    fn stop_requires_running() {
        let s = StateCell::new();
        let err = s.stop().unwrap_err();
        assert_eq!(err.expected, State::Running);
        assert_eq!(err.attempted, State::Initialized);
        assert_eq!(err.actual, State::Uninitialized);
    }

    #[test]
    fn stop_returns_to_initialized() {
        let s = StateCell::new();
        s.initialize().unwrap();
        s.start().unwrap();
        s.stop().unwrap();
        assert_eq!(s.load(), State::Initialized);
    }

    #[test]
    fn reset_from_any_state() {
        let s = StateCell::new();
        assert_eq!(s.reset(), State::Uninitialized);
        s.initialize().unwrap();
        assert_eq!(s.reset(), State::Initialized);
        assert_eq!(s.load(), State::Uninitialized);
        s.initialize().unwrap();
        s.start().unwrap();
        assert_eq!(s.reset(), State::Running);
        assert_eq!(s.load(), State::Uninitialized);
    }

    #[test]
    fn full_lifecycle_round_trip() {
        let s = StateCell::new();
        s.initialize().unwrap();
        s.start().unwrap();
        s.stop().unwrap();
        s.start().unwrap();
        s.stop().unwrap();
        let prior = s.reset();
        assert_eq!(prior, State::Initialized);
        assert_eq!(s.load(), State::Uninitialized);
    }

    #[test]
    fn allows_io_only_in_running() {
        assert!(!State::Uninitialized.allows_io());
        assert!(!State::Initialized.allows_io());
        assert!(State::Running.allows_io());
    }

    #[test]
    fn concurrent_initialize_has_exactly_one_winner() {
        use std::sync::{Arc, Barrier};
        use std::thread;

        const THREADS: usize = 8;
        let s = Arc::new(StateCell::new());
        // Without a start gate each thread may finish `initialize` before the
        // next one is even spawned, so the race would never happen. The barrier
        // releases all threads together so their compare-exchanges overlap.
        let gate = Arc::new(Barrier::new(THREADS));

        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let s = Arc::clone(&s);
                let gate = Arc::clone(&gate);
                thread::spawn(move || {
                    gate.wait();
                    s.initialize()
                })
            })
            .collect();

        let results: Vec<Result<(), TransitionError>> = handles
            .into_iter()
            .map(|h| h.join().unwrap())
            .collect();

        assert_eq!(results.iter().filter(|r| r.is_ok()).count(), 1);
        // Every loser must report that it observed the winner's write.
        for err in results.iter().filter_map(|r| r.as_ref().err()) {
            assert_eq!(err.expected, State::Uninitialized);
            assert_eq!(err.attempted, State::Initialized);
            assert_eq!(err.actual, State::Initialized);
        }
        assert_eq!(s.load(), State::Initialized);
    }
}