tokens/
default.rs

1use crate::{Callback, ChangeToken, Registration};
2use std::{
3    any::Any,
4    sync::{
5        atomic::{AtomicBool, Ordering},
6        Arc, RwLock, Weak,
7    },
8};
9
/// Type-erased signature of a callback stored by [`DefaultChangeToken`].
type CallbackFn = dyn Fn(Option<Arc<dyn Any>>) + Send + Sync;

/// Represents a default [`ChangeToken`](crate::ChangeToken) that may change zero or more times.
///
/// Callbacks are held through [`Weak`] references, each paired with the optional
/// state it is invoked with.
#[derive(Default)]
pub struct DefaultChangeToken {
    // when `true`, the token keeps reporting a change after its first notification
    once: bool,
    // raised while a notification is dispatched; stays raised for `once` tokens
    changed: AtomicBool,
    // weakly-held callbacks together with the state each one receives
    callbacks: RwLock<Vec<(Weak<CallbackFn>, Option<Arc<dyn Any>>)>>,
}
22
23impl DefaultChangeToken {
24    pub(crate) fn once() -> Self {
25        Self {
26            once: true,
27            ..Default::default()
28        }
29    }
30
31    /// Initializes a new default change token.
32    pub fn new() -> Self {
33        Self::default()
34    }
35
36    /// Notifies any registered callbacks of a change.
37    pub fn notify(&self) {
38        let result = self
39            .changed
40            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst);
41
42        if let Ok(notified) = result {
43            if !notified {
44                // acquire a read-lock and capture any callbacks that are still alive.
45                // do NOT invoke the callback with the read-lock held. the callback might
46                // register a new callback on the same token which will result in a deadlock.
47                // invoking the callbacks after the read-lock is released ensures that won't happen.
48                let callbacks: Vec<_> = self
49                    .callbacks
50                    .read()
51                    .unwrap()
52                    .iter()
53                    .filter_map(|r| r.0.upgrade().map(|c| (c, r.1.clone())))
54                    .collect();
55
56                for (callback, state) in callbacks {
57                    callback(state);
58                }
59
60                self.changed
61                    .compare_exchange(true, self.once, Ordering::SeqCst, Ordering::SeqCst)
62                    .ok();
63            }
64        }
65    }
66}
67
68impl ChangeToken for DefaultChangeToken {
69    fn changed(&self) -> bool {
70        // this is uninteresting and unusable in sync contexts. the value
71        // will be true, invoke callbacks, and then likely revert to false
72        // before it can be observed. it 'might' be useful in an async context,
73        // but a callback is the most practical way a change would be observed
74        self.changed.load(Ordering::SeqCst)
75    }
76
77    fn register(&self, callback: Callback, state: Option<Arc<dyn Any>>) -> Registration {
78        let mut callbacks = self.callbacks.write().unwrap();
79
80        // writes are much infrequent and we already need to escalate
81        // to a write-lock, so do the trimming of any dead callbacks now
82        if !callbacks.is_empty() {
83            for i in (0..callbacks.len()).rev() {
84                if callbacks[i].0.upgrade().is_none() {
85                    callbacks.remove(i);
86                }
87            }
88        }
89
90        let source: Arc<dyn Fn(Option<Arc<dyn Any>>) + Send + Sync> = Arc::from(callback);
91
92        callbacks.push((Arc::downgrade(&source), state));
93        Registration::new(source)
94    }
95}
96
97unsafe impl Send for DefaultChangeToken {}
98unsafe impl Sync for DefaultChangeToken {}
99
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        atomic::{AtomicU8, Ordering},
        Arc,
    };

    /// Creates a callback that increments the [`AtomicU8`] supplied as its state.
    fn increment_counter() -> Callback {
        Box::new(|state| {
            state
                .unwrap()
                .downcast_ref::<AtomicU8>()
                .unwrap()
                .fetch_add(1, Ordering::SeqCst);
        })
    }

    #[test]
    fn default_change_token_should_be_unchanged() {
        // arrange
        let token = DefaultChangeToken::default();

        // act
        let changed = token.changed();

        // assert
        assert!(!changed);
    }

    #[test]
    fn default_change_token_should_invoke_callback() {
        // arrange
        let counter = Arc::new(AtomicU8::default());
        let token = DefaultChangeToken::default();
        let _registration = token.register(increment_counter(), Some(counter.clone()));

        // act
        token.notify();

        // assert
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn default_change_token_should_invoke_callback_multiple_times() {
        // arrange
        let counter = Arc::new(AtomicU8::default());
        let token = DefaultChangeToken::default();
        let _registration = token.register(increment_counter(), Some(counter.clone()));
        token.notify();

        // act
        token.notify();

        // assert
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn default_change_token_should_revert_to_unchanged_after_notify() {
        // arrange
        let token = DefaultChangeToken::default();

        // act
        token.notify();

        // assert
        assert!(!token.changed());
    }

    #[test]
    fn once_change_token_should_invoke_callback_only_once() {
        // arrange
        let counter = Arc::new(AtomicU8::default());
        let token = DefaultChangeToken::once();
        let _registration = token.register(increment_counter(), Some(counter.clone()));
        token.notify();

        // act
        token.notify();

        // assert
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert!(token.changed());
    }

    #[test]
    fn default_change_token_should_not_invoke_dropped_callback() {
        // arrange
        let counter = Arc::new(AtomicU8::default());
        let token = DefaultChangeToken::default();
        let registration = token.register(increment_counter(), Some(counter.clone()));
        drop(registration);

        // act
        token.notify();

        // assert
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}