use crate::AtomicU128;
use core::{sync::atomic::Ordering, time::Duration};
/// A [`Duration`] that can be shared between threads and updated atomically.
///
/// The value is stored in a single inner `AtomicU128` using the packed layout
/// produced by `encode_duration` (whole seconds shifted left by 32 bits, plus
/// the sub-second nanoseconds in the low 32 bits). Each method forwards to the
/// corresponding operation on that inner atomic, converting at the boundary.
#[repr(transparent)]
pub struct AtomicDuration(AtomicU128);
impl core::fmt::Debug for AtomicDuration {
    /// Renders as `AtomicDuration(<duration>)`, using a `SeqCst` snapshot of
    /// the currently stored value.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let snapshot = self.load(Ordering::SeqCst);
        let mut tuple = f.debug_tuple("AtomicDuration");
        tuple.field(&snapshot);
        tuple.finish()
    }
}
impl Default for AtomicDuration {
#[cfg_attr(not(tarpaulin), inline(always))]
fn default() -> Self {
Self::new(Duration::ZERO)
}
}
impl From<Duration> for AtomicDuration {
#[cfg_attr(not(tarpaulin), inline(always))]
fn from(duration: Duration) -> Self {
Self::new(duration)
}
}
impl AtomicDuration {
    /// Creates a new `AtomicDuration` holding `duration`.
    ///
    /// This is a `const fn`, so it may be used in const contexts.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn new(duration: Duration) -> Self {
        let packed = encode_duration(duration);
        Self(AtomicU128::new(packed))
    }

    /// Reads the stored duration.
    ///
    /// `ordering` is passed unchanged to the inner `AtomicU128::load`.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn load(&self, ordering: Ordering) -> Duration {
        let packed = self.0.load(ordering);
        decode_duration(packed)
    }

    /// Overwrites the stored duration with `val`.
    ///
    /// `ordering` is passed unchanged to the inner `AtomicU128::store`.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn store(&self, val: Duration, ordering: Ordering) {
        let packed = encode_duration(val);
        self.0.store(packed, ordering);
    }

    /// Stores `val` and returns the duration that was previously stored.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn swap(&self, val: Duration, ordering: Ordering) -> Duration {
        let previous = self.0.swap(encode_duration(val), ordering);
        decode_duration(previous)
    }

    /// Replaces the stored value with `new` if it currently equals `current`.
    ///
    /// Returns `Ok(previous)` on success and `Err(actual)` with the value that
    /// was observed on failure. Because the packed encoding is canonical,
    /// comparing packed words is equivalent to comparing durations. Assuming
    /// the inner `AtomicU128` follows standard-library semantics, this may fail
    /// spuriously, so call it inside a retry loop.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn compare_exchange_weak(
        &self,
        current: Duration,
        new: Duration,
        success: Ordering,
        failure: Ordering,
    ) -> Result<Duration, Duration> {
        let expected = encode_duration(current);
        let desired = encode_duration(new);
        match self.0.compare_exchange_weak(expected, desired, success, failure) {
            Ok(previous) => Ok(decode_duration(previous)),
            Err(actual) => Err(decode_duration(actual)),
        }
    }

    /// Replaces the stored value with `new` if it currently equals `current`.
    ///
    /// Returns `Ok(previous)` on success and `Err(actual)` with the value that
    /// was observed on failure.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn compare_exchange(
        &self,
        current: Duration,
        new: Duration,
        success: Ordering,
        failure: Ordering,
    ) -> Result<Duration, Duration> {
        let expected = encode_duration(current);
        let desired = encode_duration(new);
        match self.0.compare_exchange(expected, desired, success, failure) {
            Ok(previous) => Ok(decode_duration(previous)),
            Err(actual) => Err(decode_duration(actual)),
        }
    }

    /// Applies `f` to the stored duration via the inner atomic's `fetch_update`.
    ///
    /// `f` receives the decoded current value. Returning `Some(next)` requests
    /// that `next` be stored, and returning `None` aborts the update. The result
    /// carries the decoded previous value: `Ok` if an update was stored, `Err`
    /// if `f` returned `None`.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn fetch_update<F>(
        &self,
        set_order: Ordering,
        fetch_order: Ordering,
        mut f: F,
    ) -> Result<Duration, Duration>
    where
        F: FnMut(Duration) -> Option<Duration>,
    {
        let outcome = self.0.fetch_update(set_order, fetch_order, |packed| {
            f(decode_duration(packed)).map(encode_duration)
        });
        match outcome {
            Ok(previous) => Ok(decode_duration(previous)),
            Err(previous) => Err(decode_duration(previous)),
        }
    }

    /// Consumes the atomic and returns the contained duration.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn into_inner(self) -> Duration {
        let packed = self.0.into_inner();
        decode_duration(packed)
    }

    /// Reports whether the underlying `AtomicU128` is lock-free on this target.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn is_lock_free() -> bool {
        AtomicU128::is_lock_free()
    }
}
/// Packs a [`Duration`] into a `u128`.
///
/// Whole seconds occupy bits 32..96. The sub-second nanoseconds, which are
/// always below `1_000_000_000 < 2^32`, occupy bits 0..32. Distinct durations
/// always produce distinct encodings.
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn encode_duration(duration: Duration) -> u128 {
    let high = (duration.as_secs() as u128) << 32;
    let low = duration.subsec_nanos() as u128;
    // `high` has its low 32 bits clear and `low < 2^32`, so OR is the same as ADD.
    high | low
}
/// Unpacks a `u128` produced by `encode_duration` back into a [`Duration`].
///
/// Canonical encodings round-trip exactly. Non-canonical input is normalised
/// rather than rejected:
/// - a nanosecond field of `1_000_000_000` or more carries its whole seconds
///   into the seconds field;
/// - any input whose seconds do not fit in a `u64` saturates to
///   `Duration::new(u64::MAX, 999_999_999)`. This covers inputs with bits above
///   bit 95 set, and inputs where the carry overflows.
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn decode_duration(encoded: u128) -> Duration {
    // Largest representable `Duration`; the result for any out-of-range input.
    const SATURATED: Duration = Duration::new(u64::MAX, 999_999_999);

    let raw_secs = encoded >> 32;
    // A plain `as u64` would silently drop bits 96..128 and wrap around,
    // so detect out-of-range seconds and saturate explicitly.
    if raw_secs > u64::MAX as u128 {
        return SATURATED;
    }
    let seconds = raw_secs as u64;
    let raw_nanos = (encoded & 0xFFFF_FFFF) as u32;
    // Carry whole seconds out of an over-full nanosecond field.
    let extra_secs = (raw_nanos / 1_000_000_000) as u64;
    let nanos = raw_nanos % 1_000_000_000;
    match seconds.checked_add(extra_secs) {
        Some(secs) => Duration::new(secs, nanos),
        None => SATURATED,
    }
}
#[cfg(feature = "serde")]
const _: () = {
    use serde::{Deserialize, Serialize};

    impl Serialize for AtomicDuration {
        /// Serializes a `SeqCst` snapshot of the stored value as a plain `Duration`.
        fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            let snapshot = self.load(Ordering::SeqCst);
            Serialize::serialize(&snapshot, serializer)
        }
    }

    impl<'de> Deserialize<'de> for AtomicDuration {
        /// Deserializes a plain `Duration` and wraps it in a new `AtomicDuration`.
        fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
            Duration::deserialize(deserializer).map(Self::new)
        }
    }
};
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_atomic_duration_new() {
        let duration = Duration::from_secs(5);
        let atomic_duration = AtomicDuration::new(duration);
        assert_eq!(atomic_duration.load(Ordering::SeqCst), duration);
    }

    #[test]
    fn test_atomic_duration_load() {
        let duration = Duration::new(10, 10);
        let atomic_duration = AtomicDuration::new(duration);
        assert_eq!(atomic_duration.load(Ordering::SeqCst), duration);
    }

    #[test]
    fn test_atomic_duration_store() {
        let initial_duration = Duration::from_secs(3);
        let new_duration = Duration::from_secs(7);
        let atomic_duration = AtomicDuration::new(initial_duration);
        atomic_duration.store(new_duration, Ordering::SeqCst);
        assert_eq!(atomic_duration.load(Ordering::SeqCst), new_duration);
    }

    #[test]
    fn test_atomic_duration_swap() {
        let initial_duration = Duration::from_secs(2);
        let new_duration = Duration::from_secs(8);
        let atomic_duration = AtomicDuration::new(initial_duration);
        let prev_duration = atomic_duration.swap(new_duration, Ordering::SeqCst);
        assert_eq!(prev_duration, initial_duration);
        assert_eq!(atomic_duration.load(Ordering::SeqCst), new_duration);
    }

    #[test]
    fn test_atomic_duration_compare_exchange_weak() {
        let initial_duration = Duration::from_secs(4);
        let atomic_duration = AtomicDuration::new(initial_duration);
        // `compare_exchange_weak` may fail spuriously. Such a failure reports the
        // unchanged current value, so retry exactly while that is what we observe.
        let result = loop {
            let attempt = atomic_duration.compare_exchange_weak(
                initial_duration,
                Duration::from_secs(6),
                Ordering::SeqCst,
                Ordering::SeqCst,
            );
            if attempt != Err(initial_duration) {
                break attempt;
            }
        };
        assert_eq!(result, Ok(initial_duration));
        assert_eq!(
            atomic_duration.load(Ordering::SeqCst),
            Duration::from_secs(6)
        );

        // The stored value no longer matches, so this must fail and report it.
        let result = atomic_duration.compare_exchange_weak(
            initial_duration,
            Duration::from_secs(7),
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
        assert_eq!(result, Err(Duration::from_secs(6)));
    }

    #[test]
    fn test_atomic_duration_compare_exchange() {
        let initial_duration = Duration::from_secs(1);
        let atomic_duration = AtomicDuration::new(initial_duration);
        let result = atomic_duration.compare_exchange(
            initial_duration,
            Duration::from_secs(5),
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
        assert_eq!(result, Ok(initial_duration));
        assert_eq!(
            atomic_duration.load(Ordering::SeqCst),
            Duration::from_secs(5)
        );

        let result = atomic_duration.compare_exchange(
            initial_duration,
            Duration::from_secs(6),
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
        assert_eq!(result, Err(Duration::from_secs(5)));
    }

    #[test]
    fn test_atomic_duration_fetch_update() {
        let initial_duration = Duration::from_secs(4);
        let atomic_duration = AtomicDuration::new(initial_duration);
        let result = atomic_duration.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |d| {
            Some(d + Duration::from_secs(2))
        });
        assert_eq!(result, Ok(initial_duration));
        assert_eq!(
            atomic_duration.load(Ordering::SeqCst),
            Duration::from_secs(6)
        );
    }

    #[test]
    fn test_atomic_duration_fetch_update_abort() {
        // Returning `None` aborts the update: the current value comes back as
        // `Err` and the stored value is left untouched.
        let initial_duration = Duration::from_secs(9);
        let atomic_duration = AtomicDuration::new(initial_duration);
        let result = atomic_duration.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |_| None);
        assert_eq!(result, Err(initial_duration));
        assert_eq!(atomic_duration.load(Ordering::SeqCst), initial_duration);
    }

    #[test]
    fn test_atomic_duration_into_inner() {
        let duration = Duration::from_secs(3);
        let atomic_duration = AtomicDuration::new(duration);
        assert_eq!(atomic_duration.into_inner(), duration);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_atomic_duration_thread_safety() {
        use std::sync::Arc;
        use std::thread;

        let atomic_duration = Arc::new(AtomicDuration::new(Duration::from_secs(0)));
        let mut handles = vec![];
        for _ in 0..10 {
            let atomic_clone = Arc::clone(&atomic_duration);
            let handle = thread::spawn(move || {
                for _ in 0..100 {
                    // Classic CAS increment loop: retry until our update wins.
                    loop {
                        let current = atomic_clone.load(Ordering::SeqCst);
                        let new = current + Duration::from_millis(1);
                        match atomic_clone.compare_exchange_weak(
                            current,
                            new,
                            Ordering::SeqCst,
                            Ordering::SeqCst,
                        ) {
                            Ok(_) => break,
                            Err(_) => continue,
                        }
                    }
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        let expected_duration = Duration::from_millis(10 * 100);
        assert_eq!(atomic_duration.load(Ordering::SeqCst), expected_duration);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_atomic_duration_debug() {
        let duration = Duration::new(1, 500_000_000);
        let atomic_duration = AtomicDuration::new(duration);
        let debug_str = format!("{:?}", atomic_duration);
        // Pin the full output: the tuple name plus the inner `Duration`'s Debug form.
        assert_eq!(debug_str, "AtomicDuration(1.5s)");
    }

    #[test]
    fn test_atomic_duration_default() {
        let atomic_duration = AtomicDuration::default();
        assert_eq!(atomic_duration.load(Ordering::SeqCst), Duration::ZERO);
    }

    #[test]
    fn test_atomic_duration_from() {
        let duration = Duration::from_secs(42);
        let atomic_duration = AtomicDuration::from(duration);
        assert_eq!(atomic_duration.load(Ordering::SeqCst), duration);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_atomic_duration_serde() {
        use serde::{Deserialize, Serialize};

        #[derive(Serialize, Deserialize)]
        struct Test {
            duration: AtomicDuration,
        }

        let test = Test {
            duration: AtomicDuration::new(Duration::from_secs(5)),
        };
        let serialized = serde_json::to_string(&test).unwrap();
        let deserialized: Test = serde_json::from_str(&serialized).unwrap();
        assert_eq!(
            deserialized.duration.load(Ordering::SeqCst),
            Duration::from_secs(5)
        );
    }

    #[test]
    fn encode_duration_layout() {
        // Seconds live above bit 32, sub-second nanoseconds in the low 32 bits.
        assert_eq!(encode_duration(Duration::ZERO), 0);
        assert_eq!(encode_duration(Duration::new(3, 7)), (3u128 << 32) | 7);
    }

    #[test]
    fn decode_duration_roundtrip() {
        let cases = [
            Duration::ZERO,
            Duration::from_secs(1),
            Duration::new(123_456_789, 999_999_999),
            Duration::new(u64::MAX, 999_999_999),
        ];
        for d in cases {
            assert_eq!(decode_duration(encode_duration(d)), d);
        }
    }

    #[test]
    fn decode_duration_saturates_on_non_canonical_input() {
        let max = decode_duration(u128::MAX);
        assert_eq!(max, Duration::new(u64::MAX, 999_999_999));
        let d = decode_duration(2_000_000_000u128);
        assert_eq!(d, Duration::new(2, 0));
    }
}