Skip to main content

atomic_timer/
lib.rs

1#![ doc = include_str!( concat!( env!( "CARGO_MANIFEST_DIR" ), "/", "README.md" ) ) ]
2#![deny(missing_docs)]
3use std::{
4    sync::atomic::{AtomicBool, AtomicI64, Ordering},
5    time::Duration,
6};
7
8use bma_ts::Monotonic;
9
/// Atomic timer
///
/// All state is kept in atomics, so a single instance can be shared between threads
/// (for example behind an `Arc`) and queried or reset through `&self`.
pub struct AtomicTimer {
    /// Clock hook returning the current monotonic time in nanoseconds.
    monotonic_fn: fn() -> i64,
    /// Monotonic timestamp (nanoseconds) at which the current period began.
    start: AtomicI64,
    /// Length of one timer period, in nanoseconds.
    duration: AtomicI64,
    /// Whether the current expiration may still be claimed by `permit_handle_expiration`.
    permit_handle_expiration: AtomicBool,
}
17
18impl std::fmt::Debug for AtomicTimer {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        f.debug_struct("AtomicTimer")
21            .field("duration", &self.duration())
22            .field("elapsed", &self.elapsed())
23            .field("remaining", &self.remaining())
24            .field("expired", &self.expired())
25            .finish()
26    }
27}
28
29fn monotonic_ns() -> i64 {
30    i64::try_from(Monotonic::now().as_nanos()).expect("Monotonic time is too large")
31}
32
33impl AtomicTimer {
34    #[allow(dead_code)]
35    fn construct(duration: i64, elapsed: i64, phe: bool, monotonic_fn: fn() -> i64) -> Self {
36        AtomicTimer {
37            duration: AtomicI64::new(duration),
38            start: AtomicI64::new(monotonic_fn() - elapsed),
39            monotonic_fn,
40            permit_handle_expiration: AtomicBool::new(phe),
41        }
42    }
43    /// Create a new atomic timer
44    ///
45    /// # Panics
46    ///
47    /// Panics if the duration is too large (in nanos greater than `i64::MAX`)
48    pub fn new(duration: Duration) -> Self {
49        Self::construct(
50            duration
51                .as_nanos()
52                .try_into()
53                .expect("Duration is too large"),
54            0,
55            true,
56            monotonic_ns,
57        )
58    }
59    /// Create a new atomic timer expired
60    ///
61    /// # Panics
62    ///
63    /// Panics if the duration is too large (in nanos greater than `i64::MAX`)
64    pub fn new_expired(duration: Duration) -> Self {
65        Self::construct(
66            duration
67                .as_nanos()
68                .try_into()
69                .expect("Duration is too large"),
70            duration
71                .as_nanos()
72                .try_into()
73                .expect("Duration is too large"),
74            true,
75            monotonic_ns,
76        )
77    }
78    /// Get the duration of the timer
79    ///
80    /// # Panics
81    ///
82    /// Panics if the duration is negative
83    #[inline]
84    pub fn duration(&self) -> Duration {
85        Duration::from_nanos(self.duration.load(Ordering::SeqCst).try_into().unwrap())
86    }
87    /// Similar to reset if expired but does not reset the timer. As the timer is checked for
88    /// expiration, a tiny datarace may occur despite it passes the tests well. As soon as the
89    /// timer is reset with any method, the flag is reset as well. If used in multi-threaded
90    /// environment, "true" is returned to a single worker only. After, the flag is reset.
91    #[inline]
92    pub fn permit_handle_expiration(&self) -> bool {
93        self.permit_handle_expiration
94            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
95                (v && self.expired()).then_some(false)
96            })
97            .is_ok()
98    }
99    /// Reset the timer to a new duration
100    pub fn reset_to_duration(&self, duration: Duration) {
101        self.set_duration(duration);
102        self.reset();
103    }
104    /// Change the duration of the timer
105    ///
106    /// # Panics
107    ///
108    /// Panics if the duration in nanos is larger than `i64::MAX`
109    pub fn set_duration(&self, duration: Duration) {
110        self.duration
111            .store(duration.as_nanos().try_into().unwrap(), Ordering::SeqCst);
112    }
113    /// Reset the timer
114    #[inline]
115    pub fn reset(&self) {
116        self.permit_handle_expiration.store(true, Ordering::SeqCst);
117        self.start.store((self.monotonic_fn)(), Ordering::SeqCst);
118    }
119    /// Focibly expire the timer
120    #[inline]
121    pub fn expire_now(&self) {
122        self.start.store(
123            (self.monotonic_fn)() - self.duration.load(Ordering::SeqCst),
124            Ordering::SeqCst,
125        );
126    }
127    /// Reset the timer if it has expired, returns true if reset. If used in multi-threaded
128    /// environment, "true" is returned to a single worker only.
129    #[inline]
130    pub fn reset_if_expired(&self) -> bool {
131        let now = (self.monotonic_fn)();
132        self.start
133            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |start| {
134                self.permit_handle_expiration.store(true, Ordering::SeqCst);
135                (now.saturating_sub(start) >= self.duration.load(Ordering::SeqCst)).then_some(now)
136            })
137            .is_ok()
138    }
139    /// Get the elapsed time
140    ///
141    /// In case if negative elapsed, returns `Duration::ZERO`
142    #[inline]
143    pub fn elapsed(&self) -> Duration {
144        Duration::from_nanos(
145            (self.monotonic_fn)()
146                .saturating_sub(self.start.load(Ordering::SeqCst))
147                .try_into()
148                .unwrap_or_default(),
149        )
150    }
151    /// Get the remaining time
152    ///
153    /// In case if negative remaining, returns `Duration::ZERO`
154    #[inline]
155    pub fn remaining(&self) -> Duration {
156        let elapsed = self.elapsed_ns();
157        if elapsed >= self.duration.load(Ordering::SeqCst) {
158            Duration::ZERO
159        } else {
160            Duration::from_nanos(
161                (self.duration.load(Ordering::SeqCst) - elapsed)
162                    .try_into()
163                    .unwrap_or_default(),
164            )
165        }
166    }
167    #[inline]
168    fn elapsed_ns(&self) -> i64 {
169        (self.monotonic_fn)().saturating_sub(self.start.load(Ordering::SeqCst))
170    }
171    /// Check if the timer has expired
172    #[inline]
173    pub fn expired(&self) -> bool {
174        self.elapsed_ns() >= self.duration.load(Ordering::SeqCst)
175    }
176}
177
#[cfg(feature = "serde")]
mod ser {
    use super::{monotonic_ns, AtomicTimer};
    use serde::{Deserialize, Deserializer, Serialize, Serializer};
    use std::sync::atomic::Ordering;

    /// Wire representation of an [`AtomicTimer`].
    ///
    /// The raw monotonic `start` is not serialized. Instead the time already elapsed is
    /// stored and re-anchored against the local monotonic clock on deserialization. All
    /// values are nanoseconds; `phe` is the `permit_handle_expiration` flag.
    #[derive(Serialize, Deserialize)]
    struct SerializedTimer {
        duration: i64,
        elapsed: i64,
        phe: bool,
    }

    impl Serialize for AtomicTimer {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let s = SerializedTimer {
                duration: self.duration.load(Ordering::SeqCst),
                elapsed: self.elapsed_ns(),
                phe: self.permit_handle_expiration.load(Ordering::SeqCst),
            };
            s.serialize(serializer)
        }
    }

    impl<'de> Deserialize<'de> for AtomicTimer {
        /// Restores a timer so that it has the serialized elapsed time "now".
        ///
        /// # Errors
        ///
        /// Fails if the input is malformed or the duration is negative. A negative duration
        /// would otherwise make `AtomicTimer::duration` (and `Debug`) panic later.
        fn deserialize<D>(deserializer: D) -> Result<AtomicTimer, D::Error>
        where
            D: Deserializer<'de>,
        {
            let s = SerializedTimer::deserialize(deserializer)?;
            if s.duration < 0 {
                return Err(serde::de::Error::custom(format!(
                    "timer duration must not be negative, got {} ns",
                    s.duration
                )));
            }
            Ok(AtomicTimer::construct(
                s.duration,
                s.elapsed,
                s.phe,
                monotonic_ns,
            ))
        }
    }
}
220
#[cfg(test)]
mod test {
    use super::AtomicTimer;
    use std::{
        sync::{Arc, Barrier},
        thread,
        time::Duration,
    };

    /// Returns `true` when `a` lies within a `window` centred on `b`, i.e. inside
    /// `[b - window/2, b + window/2]`.
    ///
    /// The bounds saturate, so a `b` smaller than half the window does not panic on
    /// `Duration` underflow.
    pub(crate) fn in_time_window(a: Duration, b: Duration, window: Duration) -> bool {
        let diff = window / 2;
        let min = b.saturating_sub(diff);
        let max = b.saturating_add(diff);
        a >= min && a <= max
    }

    #[test]
    fn test_reset() {
        let timer = AtomicTimer::new(Duration::from_secs(5));
        thread::sleep(Duration::from_secs(1));
        timer.reset();
        assert!(timer.elapsed() < Duration::from_millis(100));
    }

    #[test]
    fn test_expire_now() {
        let timer = AtomicTimer::new(Duration::from_secs(5));
        assert!(!timer.expired());
        assert!(in_time_window(
            timer.remaining(),
            Duration::from_secs(5),
            Duration::from_millis(100)
        ));
        timer.expire_now();
        assert!(timer.expired());
    }

    #[test]
    fn test_reset_if_expired() {
        let timer = AtomicTimer::new(Duration::from_secs(1));
        assert!(!timer.reset_if_expired());
        thread::sleep(Duration::from_millis(1100));
        assert!(timer.expired());
        assert!(timer.reset_if_expired());
    }

    #[test]
    fn test_reset_if_expired_no_datarace() {
        let n = 1000;
        let timer = Arc::new(AtomicTimer::new(Duration::from_millis(100)));
        thread::sleep(Duration::from_millis(200));
        assert!(timer.expired());
        let barrier = Arc::new(Barrier::new(n));
        let (tx, rx) = std::sync::mpsc::channel::<bool>();
        let mut result = Vec::with_capacity(n);
        for _ in 0..n {
            let timer = timer.clone();
            let barrier = barrier.clone();
            let tx = tx.clone();
            thread::spawn(move || {
                barrier.wait();
                tx.send(timer.reset_if_expired()).unwrap();
            });
        }
        drop(tx);
        while let Ok(v) = rx.recv() {
            result.push(v);
        }
        assert_eq!(result.len(), n);
        assert_eq!(result.into_iter().filter(|&v| v).count(), 1);
    }

    #[test]
    fn test_permit_handle_expiration() {
        let timer = AtomicTimer::new(Duration::from_secs(1));
        assert!(!timer.permit_handle_expiration());
        thread::sleep(Duration::from_millis(1100));
        assert!(timer.expired());
        assert!(timer.permit_handle_expiration());
        assert!(!timer.permit_handle_expiration());
        timer.reset();
        thread::sleep(Duration::from_millis(1100));
        timer.reset();
        assert!(!timer.permit_handle_expiration());
    }

    #[test]
    fn test_permit_handle_expiration_no_datarace() {
        let n = 1000;
        let timer = Arc::new(AtomicTimer::new(Duration::from_millis(100)));
        thread::sleep(Duration::from_millis(200));
        assert!(timer.expired());
        let barrier = Arc::new(Barrier::new(n));
        let (tx, rx) = std::sync::mpsc::channel::<bool>();
        let mut result = Vec::with_capacity(n);
        for _ in 0..n {
            let timer = timer.clone();
            let barrier = barrier.clone();
            let tx = tx.clone();
            thread::spawn(move || {
                barrier.wait();
                tx.send(timer.permit_handle_expiration()).unwrap();
            });
        }
        drop(tx);
        while let Ok(v) = rx.recv() {
            result.push(v);
        }
        assert_eq!(result.len(), n);
        assert_eq!(result.into_iter().filter(|&v| v).count(), 1);
    }
}
333
#[cfg(feature = "serde")]
#[cfg(test)]
mod test_serialization {
    use super::test::in_time_window;
    use super::AtomicTimer;
    use std::{sync::atomic::Ordering, thread, time::Duration};

    /// Runs a 5 s timer for ~1 s, then round-trips it through JSON and checks the result.
    ///
    /// When `clock` is given, the deserialized timer is re-built on top of that clock hook,
    /// which simulates restoring the timer on a machine whose monotonic clock differs. The
    /// check is that the elapsed time is still ~1 s (within a 100 ms window).
    fn check_round_trip(clock: Option<fn() -> i64>) {
        let original = AtomicTimer::new(Duration::from_secs(5));
        thread::sleep(Duration::from_secs(1));
        let json = serde_json::to_string(&original).unwrap();
        let restored: AtomicTimer = serde_json::from_str(&json).unwrap();
        let probe = match clock {
            None => restored,
            Some(clock) => AtomicTimer::construct(
                restored.duration().as_nanos().try_into().unwrap(),
                restored.elapsed_ns(),
                restored.permit_handle_expiration.load(Ordering::SeqCst),
                clock,
            ),
        };
        assert!(in_time_window(
            probe.elapsed(),
            Duration::from_secs(1),
            Duration::from_millis(100)
        ));
    }

    #[test]
    fn test_serialize_deserialize() {
        check_round_trip(None);
    }

    #[test]
    fn test_serialize_deserialize_monotonic_goes_forward() {
        fn clock_ahead() -> i64 {
            super::monotonic_ns() + 10_000 * 1_000_000_000
        }
        check_round_trip(Some(clock_ahead));
    }

    #[test]
    fn test_serialize_deserialize_monotonic_goes_backward() {
        fn clock_behind() -> i64 {
            super::monotonic_ns() - 10_000 * 1_000_000_000
        }
        check_round_trip(Some(clock_behind));
    }

    #[test]
    fn test_serialize_deserialize_monotonic_goes_zero() {
        fn clock_zero() -> i64 {
            0
        }
        check_round_trip(Some(clock_zero));
    }
}