use crate::{Error, Result};
use fluid_let::fluid_let;
use std::sync::atomic::{AtomicIsize, Ordering};
/// Process-wide balance of live guards.
///
/// Each guard from `enforce_safe_logging` adds 1 and each guard from
/// `disable_safe_logging` adds -1 (see `GuardKind::increment`); dropping a
/// guard subtracts its contribution again. Because a guard of one kind is
/// refused while guards of the other kind exist (see `GuardKind::check`), a
/// positive value means only enforcing guards are alive, a negative value
/// means only disabling guards are alive, and zero means there are none.
static LOGGING_STATE: AtomicIsize = AtomicIsize::new(0);
// Per-thread flag that `with_safe_logging_suppressed` sets to `true` for the
// duration of its closure; while it is `true`, `unsafe_logging_enabled`
// reports `true` on that thread only.
fluid_let!(
static SAFE_LOGGING_SUPPRESSED_IN_THREAD: bool
);
/// Return `true` if unsafe logging is currently enabled on the calling thread.
///
/// This is the case when at least one guard from `disable_safe_logging` is
/// alive anywhere in the process (the global counter is negative), or when the
/// calling thread is currently inside `with_safe_logging_suppressed`.
pub(crate) fn unsafe_logging_enabled() -> bool {
    // The counter is only compared against zero here; no other data in this
    // module is published through it, so a relaxed load is used.
    if LOGGING_STATE.load(Ordering::Relaxed) < 0 {
        return true;
    }
    // Only consulted when no disabling guard is alive (short-circuit).
    SAFE_LOGGING_SUPPRESSED_IN_THREAD.get(|flag| flag.copied().unwrap_or(false))
}
/// Run `func` with safe logging suppressed on the current thread, returning
/// whatever `func` returns.
///
/// While `func` is running, the crate's internal `unsafe_logging_enabled`
/// check reports `true` on this thread, even if a guard from
/// [`enforce_safe_logging`] is alive. Other threads are not affected.
pub fn with_safe_logging_suppressed<F, V>(func: F) -> V
where
    F: FnOnce() -> V,
{
    // The flag is scoped to the call of `func` on this thread only.
    let suppressed = true;
    SAFE_LOGGING_SUPPRESSED_IN_THREAD.set(suppressed, func)
}
/// Which direction a `Guard` pushes the process-wide logging state.
#[derive(Clone, Copy, Debug)]
enum GuardKind {
    /// Produced by `enforce_safe_logging`; refused while any `Unsafe` guard is alive.
    Safe,
    /// Produced by `disable_safe_logging`; refused while any `Safe` guard is alive.
    Unsafe,
}
/// Token returned by [`enforce_safe_logging`] or [`disable_safe_logging`].
///
/// Its effect on the global logging state lasts exactly as long as the value
/// is alive; dropping it undoes that effect.
#[must_use = "If you drop the guard immediately, it won't do anything."]
#[derive(Debug)]
pub struct Guard {
    /// Whether this guard enforces or disables safe logging; determines the
    /// amount it added to the global counter.
    kind: GuardKind,
}
impl GuardKind {
fn check(&self, val: isize) -> Result<()> {
match self {
GuardKind::Safe => {
if val < 0 {
return Err(Error::AlreadyUnsafe);
}
}
GuardKind::Unsafe => {
if val > 0 {
return Err(Error::AlreadySafe);
}
}
}
Ok(())
}
fn increment(&self) -> isize {
match self {
GuardKind::Safe => 1,
GuardKind::Unsafe => -1,
}
}
}
impl Guard {
fn new(kind: GuardKind) -> Result<Self> {
let inc = kind.increment();
loop {
let old_val = LOGGING_STATE.load(Ordering::SeqCst);
kind.check(old_val)?;
let new_val = match old_val.checked_add(inc) {
Some(v) => v,
None => return Err(Error::Overflow),
};
if let Ok(v) =
LOGGING_STATE.compare_exchange(old_val, new_val, Ordering::SeqCst, Ordering::SeqCst)
{
debug_assert_eq!(v, old_val);
return Ok(Self { kind });
}
}
}
}
impl Drop for Guard {
    /// Release this guard by removing its contribution from the global counter.
    fn drop(&mut self) {
        // `Guard::new` added `increment()`; subtracting the same amount
        // restores the counter to what it would be without this guard.
        let _previous = LOGGING_STATE.fetch_sub(self.kind.increment(), Ordering::SeqCst);
    }
}
/// Enforce safe logging until the returned [`Guard`] is dropped.
///
/// While the guard is alive, [`disable_safe_logging`] is refused. Any number
/// of enforcing guards may coexist. Note that [`with_safe_logging_suppressed`]
/// can still suppress safe logging on a single thread.
///
/// # Errors
///
/// Returns [`Error::AlreadyUnsafe`] if a guard from [`disable_safe_logging`]
/// is currently alive, or [`Error::Overflow`] if the guard counter would
/// overflow.
pub fn enforce_safe_logging() -> Result<Guard> {
    let kind = GuardKind::Safe;
    Guard::new(kind)
}
/// Disable safe logging process-wide until the returned [`Guard`] is dropped.
///
/// While the guard is alive, [`enforce_safe_logging`] is refused. Any number
/// of disabling guards may coexist.
///
/// # Errors
///
/// Returns [`Error::AlreadySafe`] if a guard from [`enforce_safe_logging`]
/// is currently alive, or [`Error::Overflow`] if the guard counter would
/// overflow.
pub fn disable_safe_logging() -> Result<Guard> {
    let kind = GuardKind::Unsafe;
    Guard::new(kind)
}
#[cfg(test)]
mod test {
#![allow(clippy::unwrap_used)]
use super::*;
use serial_test::serial;
// Every test here reads or mutates the process-wide LOGGING_STATE, so each one
// is marked #[serial] to keep them from running concurrently with each other.
// Guards of the same kind stack; a guard of the opposite kind is refused
// until every existing guard has been dropped.
#[test]
#[serial]
fn guards() {
assert!(!unsafe_logging_enabled());
let g1 = enforce_safe_logging().unwrap();
let g2 = enforce_safe_logging().unwrap();
assert!(!unsafe_logging_enabled());
let e = disable_safe_logging();
assert!(matches!(e, Err(Error::AlreadySafe)));
assert!(!unsafe_logging_enabled());
// Both enforcing guards must be gone before a disabling guard is allowed.
drop(g1);
drop(g2);
// Bound to `_g3`/`_g4` (not `_`) so the guards live until the end of the test.
let _g3 = disable_safe_logging().unwrap();
assert!(unsafe_logging_enabled());
let e = enforce_safe_logging();
assert!(matches!(e, Err(Error::AlreadyUnsafe)));
assert!(unsafe_logging_enabled());
let _g4 = disable_safe_logging().unwrap();
assert!(unsafe_logging_enabled());
}
// with_safe_logging_suppressed enables unsafe logging only inside its closure,
// even while an enforcing guard is alive, and is a no-op when a disabling
// guard already enables it.
#[test]
#[serial]
fn suppress() {
{
let _g = enforce_safe_logging().unwrap();
with_safe_logging_suppressed(|| assert!(unsafe_logging_enabled()));
assert!(!unsafe_logging_enabled());
}
{
assert!(!unsafe_logging_enabled());
with_safe_logging_suppressed(|| assert!(unsafe_logging_enabled()));
assert!(!unsafe_logging_enabled());
}
{
let _g = disable_safe_logging().unwrap();
assert!(unsafe_logging_enabled());
with_safe_logging_suppressed(|| assert!(unsafe_logging_enabled()));
}
}
// Two threads race to acquire guards of opposite kinds. Whichever thread
// holds a guard must see the corresponding state and must be unable to
// acquire the opposite kind while holding it.
#[test]
#[serial]
fn interfere_1() {
use std::thread::{spawn, yield_now};
let thread1 = spawn(|| {
for _ in 0..10_000 {
if let Ok(_g) = enforce_safe_logging() {
assert!(!unsafe_logging_enabled());
yield_now();
assert!(disable_safe_logging().is_err());
}
yield_now();
}
});
let thread2 = spawn(|| {
for _ in 0..10_000 {
if let Ok(_g) = disable_safe_logging() {
assert!(unsafe_logging_enabled());
yield_now();
assert!(enforce_safe_logging().is_err());
}
yield_now();
}
});
thread1.join().unwrap();
thread2.join().unwrap();
}
// Two threads repeatedly take disabling guards concurrently; guards of the
// same kind must never conflict, so every acquisition succeeds.
#[test]
#[serial]
fn interfere_2() {
use std::thread::{spawn, yield_now};
let thread1 = spawn(|| {
for _ in 0..10_000 {
let g = disable_safe_logging().unwrap();
assert!(unsafe_logging_enabled());
yield_now();
drop(g);
yield_now();
}
});
let thread2 = spawn(|| {
for _ in 0..10_000 {
let g = disable_safe_logging().unwrap();
assert!(unsafe_logging_enabled());
yield_now();
drop(g);
yield_now();
}
});
thread1.join().unwrap();
thread2.join().unwrap();
}
// Suppression is per-thread: thread2's use of with_safe_logging_suppressed
// must never make unsafe logging visible on thread1.
#[test]
#[serial]
fn interfere_3() {
use std::thread::{spawn, yield_now};
let thread1 = spawn(|| {
for _ in 0..10_000 {
assert!(!unsafe_logging_enabled());
yield_now();
}
});
let thread2 = spawn(|| {
for _ in 0..10_000 {
assert!(!unsafe_logging_enabled());
with_safe_logging_suppressed(|| {
assert!(unsafe_logging_enabled());
yield_now();
});
}
});
thread1.join().unwrap();
thread2.join().unwrap();
}
}