Skip to main content

graft_client/runtime/storage/changeset.rs

1use std::{
2    collections::{HashMap, HashSet},
3    hash::Hash,
4    sync::{
5        Arc,
6        atomic::{AtomicU64, Ordering},
7    },
8};
9
10use crossbeam::channel::{Receiver, Sender, TrySendError, bounded};
11use parking_lot::{Mutex, RwLock};
12
13type InnerSet<K> = Arc<RwLock<HashMap<K, AtomicU64>>>;
14
15pub struct ChangeSet<K> {
16    next_version: AtomicU64,
17    subscribers: Mutex<Vec<(Option<K>, Sender<()>)>>,
18    set: InnerSet<K>,
19}
20
21impl<K> Default for ChangeSet<K> {
22    fn default() -> Self {
23        Self {
24            next_version: AtomicU64::new(0),
25            subscribers: Default::default(),
26            set: Default::default(),
27        }
28    }
29}
30
31impl<K: Eq + Hash + Clone> ChangeSet<K> {
32    pub fn version(&self) -> u64 {
33        self.next_version.load(Ordering::SeqCst)
34    }
35
36    fn next_version(&self) -> u64 {
37        self.next_version.fetch_add(1, Ordering::SeqCst)
38    }
39
40    fn notify(&self, key: &K) {
41        let mut subscribers = self.subscribers.lock();
42        subscribers.retain(|(k, s)| {
43            if k.as_ref().is_none_or(|k| k == key) {
44                match s.try_send(()) {
45                    Ok(()) => true,
46                    Err(TrySendError::Full(())) => true,
47                    Err(TrySendError::Disconnected(())) => false,
48                }
49            } else {
50                true
51            }
52        });
53    }
54
55    /// Inserts a key into the set returning true if the set already contained the key
56    pub fn insert(&self, key: K) -> bool {
57        let version = self.next_version();
58        let existed = self
59            .set
60            .write()
61            .insert(key.clone(), AtomicU64::new(version))
62            .is_some();
63        self.notify(&key);
64        existed
65    }
66
67    /// Removes a key from the set
68    pub fn remove(&self, key: &K) {
69        self.set.write().remove(key);
70        self.notify(key);
71    }
72
73    /// Marks a key as changed
74    pub fn mark_changed(&self, key: &K) {
75        // optimistically assume the key exists
76        let found = {
77            if let Some(val) = self.set.read().get(key) {
78                val.store(self.next_version(), Ordering::SeqCst);
79                self.notify(key);
80                true
81            } else {
82                false
83            }
84        };
85
86        // fallback to inserting the key
87        if !found {
88            self.insert(key.clone());
89        }
90    }
91
92    pub fn subscribe(&self, key: K) -> Receiver<()> {
93        let (tx, rx) = bounded(1);
94        self.subscribers.lock().push((Some(key), tx));
95        rx
96    }
97
98    pub fn subscribe_all(&self) -> SetSubscriber<K> {
99        let (tx, rx) = bounded(1);
100        self.subscribers.lock().push((None, tx));
101        SetSubscriber {
102            rx,
103            version: self.version(),
104            set: self.set.clone(),
105        }
106    }
107}
108
109pub struct SetSubscriber<K> {
110    version: u64,
111    rx: Receiver<()>,
112    set: InnerSet<K>,
113}
114
115impl<K: Clone + Eq + Hash> SetSubscriber<K> {
116    /// returns a receiver that will be notified when the set changes
117    pub fn ready(&self) -> &Receiver<()> {
118        &self.rx
119    }
120
121    /// returns a set of changed keys since the last time this function returned
122    /// a non-empty set
123    pub fn changed(&mut self) -> HashSet<K> {
124        let set = self.set.read();
125        let mut max_version = self.version;
126        let set: HashSet<K> = set
127            .iter()
128            .filter_map(|(k, v)| {
129                let version = v.load(Ordering::SeqCst);
130                max_version = max_version.max(version);
131                (version >= self.version).then_some(k.clone())
132            })
133            .collect();
134
135        if !set.is_empty() {
136            self.version = max_version;
137        }
138        set
139    }
140}