use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
/// Shared, cheaply clonable tracker of recorded progress and active protected zones.
///
/// All clones share the same underlying atomic state (the field is an `Arc`), so
/// progress recorded through one clone is observed through every other clone.
#[derive(Clone, Debug)]
pub struct Progress(Arc<AtomicU32>);
/// Logical state of a [`Progress`] tracker, stored in a single `u32`.
///
/// Encoding (see the `u32` conversions below): `NoUpdate` is `0`, `Updated` is `1`,
/// and `ProtectedZone(level)` is `2 + level`.
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
enum ProgressState {
    /// No progress has been recorded since the last activity check.
    NoUpdate,
    /// Progress has been recorded since the last activity check.
    Updated,
    /// At least one protected zone is active. The payload counts additional nested
    /// zones: `ProtectedZone(0)` means exactly one guard is alive.
    ProtectedZone(u32),
}

// Implemented as `From` rather than `Into`: the standard blanket impl still provides
// `state.into()`, and callers can additionally write `u32::from(state)`.
impl From<ProgressState> for u32 {
    fn from(state: ProgressState) -> u32 {
        match state {
            ProgressState::NoUpdate => 0,
            ProgressState::Updated => 1,
            // NOTE(review): overflows if `level > u32::MAX - 2`; nesting that deep is
            // not expected in practice.
            ProgressState::ProtectedZone(level) => 2 + level,
        }
    }
}

// Decoding is total: every raw value maps to a state, and values `>= 2` are
// protected zones.
impl From<u32> for ProgressState {
    fn from(raw: u32) -> Self {
        match raw {
            0 => ProgressState::NoUpdate,
            1 => ProgressState::Updated,
            other => ProgressState::ProtectedZone(other - 2),
        }
    }
}
impl Default for Progress {
fn default() -> Progress {
Progress(Arc::new(AtomicU32::new(ProgressState::Updated.into())))
}
}
impl Progress {
pub fn record_progress(&self) {
self.0
.fetch_max(ProgressState::Updated.into(), Ordering::Relaxed);
}
pub fn protect_zone(&self) -> ProtectedZoneGuard {
loop {
let previous_state: ProgressState = self.0.load(Ordering::SeqCst).into();
let new_state = match previous_state {
ProgressState::NoUpdate | ProgressState::Updated => ProgressState::ProtectedZone(0),
ProgressState::ProtectedZone(level) => ProgressState::ProtectedZone(level + 1),
};
if self
.0
.compare_exchange(
previous_state.into(),
new_state.into(),
Ordering::SeqCst,
Ordering::SeqCst,
)
.is_ok()
{
return ProtectedZoneGuard(self.0.clone());
}
}
}
pub fn registered_activity_since_last_call(&self) -> bool {
let previous_state: ProgressState = self
.0
.compare_exchange(
ProgressState::Updated.into(),
ProgressState::NoUpdate.into(),
Ordering::Relaxed,
Ordering::Relaxed,
)
.unwrap_or_else(|previous_value| previous_value)
.into();
previous_state != ProgressState::NoUpdate
}
}
/// Guard returned by `Progress::protect_zone`; leaving scope exits one protected-zone
/// level.
///
/// Discarding the guard ends the protection immediately, which is almost certainly a
/// bug. The type is therefore `#[must_use]`.
#[must_use = "the protected zone ends as soon as this guard is dropped; bind it to a variable"]
pub struct ProtectedZoneGuard(Arc<AtomicU32>);

impl Drop for ProtectedZoneGuard {
    /// Decrements the shared state by one protected-zone level.
    ///
    /// # Panics
    ///
    /// Panics if the state observed before decrementing was not a protected zone. That
    /// would mean guards and `protect_zone` calls are out of balance.
    fn drop(&mut self) {
        let previous_state: ProgressState = self.0.fetch_sub(1, Ordering::SeqCst).into();
        assert!(
            matches!(previous_state, ProgressState::ProtectedZone(_)),
            "ProtectedZoneGuard dropped outside of a protected zone (state before drop: {:?})",
            previous_state
        );
    }
}
#[cfg(test)]
mod tests {
    use super::Progress;

    /// The activity flag is reported once per `record_progress` and then consumed.
    #[test]
    fn test_progress() {
        let progress = Progress::default();
        assert!(progress.registered_activity_since_last_call());
        progress.record_progress();
        assert!(progress.registered_activity_since_last_call());
        assert!(!progress.registered_activity_since_last_call());
    }

    /// Inside a protected zone every check reports activity; leaving the zone counts
    /// as a single update.
    #[test]
    fn test_progress_protect_zone() {
        let progress = Progress::default();
        assert!(progress.registered_activity_since_last_call());
        progress.record_progress();
        assert!(progress.registered_activity_since_last_call());
        {
            let _protect_guard = progress.protect_zone();
            assert!(progress.registered_activity_since_last_call());
            assert!(progress.registered_activity_since_last_call());
        }
        assert!(progress.registered_activity_since_last_call());
        assert!(!progress.registered_activity_since_last_call());
    }

    /// Nested protected zones keep protection until the last guard is dropped.
    #[test]
    fn test_progress_several_protect_zone() {
        let progress = Progress::default();
        assert!(progress.registered_activity_since_last_call());
        progress.record_progress();
        assert!(progress.registered_activity_since_last_call());
        let first_protect_guard = progress.protect_zone();
        let second_protect_guard = progress.protect_zone();
        assert!(progress.registered_activity_since_last_call());
        assert!(progress.registered_activity_since_last_call());
        std::mem::drop(first_protect_guard);
        assert!(progress.registered_activity_since_last_call());
        assert!(progress.registered_activity_since_last_call());
        std::mem::drop(second_protect_guard);
        assert!(progress.registered_activity_since_last_call());
        assert!(!progress.registered_activity_since_last_call());
    }

    /// Recording progress while protected must not demote the protected state, and a
    /// zone entered from `NoUpdate` still reports a single update on exit.
    #[test]
    fn test_record_progress_inside_protected_zone() {
        let progress = Progress::default();
        assert!(progress.registered_activity_since_last_call());
        assert!(!progress.registered_activity_since_last_call());
        {
            let _protect_guard = progress.protect_zone();
            progress.record_progress();
            assert!(progress.registered_activity_since_last_call());
            assert!(progress.registered_activity_since_last_call());
        }
        assert!(progress.registered_activity_since_last_call());
        assert!(!progress.registered_activity_since_last_call());
    }

    /// Clones observe and consume the same shared state.
    #[test]
    fn test_cloned_progress_shares_state() {
        let progress = Progress::default();
        let cloned = progress.clone();
        assert!(cloned.registered_activity_since_last_call());
        assert!(!progress.registered_activity_since_last_call());
        cloned.record_progress();
        assert!(progress.registered_activity_since_last_call());
        assert!(!cloned.registered_activity_since_last_call());
    }
}