1use std::any::Any;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use parking_lot::RwLock;
14
15use crate::traits::CrdtMerge;
16
/// Replication level requested for a write.
///
/// Every variant except `Local` carries a `timeout`, which is what
/// `WriteConsistency::timeout` reports (presumably the wait budget for collecting
/// acknowledgements; enforcement is not in this file).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum WriteConsistency {
    /// Satisfied by a single acknowledgement; has no timeout.
    Local,
    /// Requires an acknowledgement from every node in the cluster.
    All {
        /// Time budget for the write.
        timeout: Duration,
    },
    /// Requires acknowledgements from a strict majority (`size / 2 + 1`) of the cluster.
    Majority {
        /// Time budget for the write.
        timeout: Duration,
    },
    /// Requires `n` acknowledgements, capped at the cluster size.
    From {
        /// Requested number of acknowledgements.
        n: usize,
        /// Time budget for the write.
        timeout: Duration,
    },
}
29
/// Replication level requested for a read.
///
/// Every variant except `Local` carries a `timeout`, which is what
/// `ReadConsistency::timeout` reports (presumably the wait budget for collecting
/// replies; enforcement is not in this file).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ReadConsistency {
    /// Satisfied by a single reply; has no timeout.
    Local,
    /// Requires a reply from every node in the cluster.
    All {
        /// Time budget for the read.
        timeout: Duration,
    },
    /// Requires replies from a strict majority (`size / 2 + 1`) of the cluster.
    Majority {
        /// Time budget for the read.
        timeout: Duration,
    },
    /// Requires `n` replies, capped at the cluster size.
    From {
        /// Requested number of replies.
        n: usize,
        /// Time budget for the read.
        timeout: Duration,
    },
}
38
39impl WriteConsistency {
40 pub fn required_acks(self, cluster_size: usize) -> usize {
43 match self {
44 Self::Local => 1,
45 Self::All { .. } => cluster_size.max(1),
46 Self::Majority { .. } => (cluster_size / 2) + 1,
47 Self::From { n, .. } => n.min(cluster_size.max(1)),
48 }
49 }
50
51 pub fn timeout(self) -> Option<Duration> {
52 match self {
53 Self::Local => None,
54 Self::All { timeout } | Self::Majority { timeout } | Self::From { timeout, .. } => Some(timeout),
55 }
56 }
57}
58
59impl ReadConsistency {
60 pub fn required_replies(self, cluster_size: usize) -> usize {
61 match self {
62 Self::Local => 1,
63 Self::All { .. } => cluster_size.max(1),
64 Self::Majority { .. } => (cluster_size / 2) + 1,
65 Self::From { n, .. } => n.min(cluster_size.max(1)),
66 }
67 }
68
69 pub fn timeout(self) -> Option<Duration> {
70 match self {
71 Self::Local => None,
72 Self::All { timeout } | Self::Majority { timeout } | Self::From { timeout, .. } => Some(timeout),
73 }
74 }
75}
76
77type Entry = Box<dyn Any + Send + Sync>;
78type SubscriberId = u64;
79type Notifier = Box<dyn Fn(&str) + Send + Sync + 'static>;
80
81pub struct Replicator {
82 store: RwLock<HashMap<String, Entry>>,
83 subscribers: RwLock<HashMap<String, Vec<(SubscriberId, Notifier)>>>,
84 next_sub_id: AtomicU64,
85}
86
87impl Default for Replicator {
88 fn default() -> Self {
89 Self {
90 store: RwLock::new(HashMap::new()),
91 subscribers: RwLock::new(HashMap::new()),
92 next_sub_id: AtomicU64::new(0),
93 }
94 }
95}
96
97impl Replicator {
98 pub fn new() -> Arc<Self> {
99 Arc::new(Self::default())
100 }
101
102 pub fn update<T>(&self, key: &str, value: T)
103 where
104 T: CrdtMerge + Send + Sync + 'static,
105 {
106 {
107 let mut map = self.store.write();
108 match map.get_mut(key) {
109 Some(existing) => {
110 if let Some(current) = existing.downcast_mut::<T>() {
111 current.merge(&value);
112 } else {
113 map.insert(key.to_string(), Box::new(value));
114 }
115 }
116 None => {
117 map.insert(key.to_string(), Box::new(value));
118 }
119 }
120 }
121 self.notify(key);
122 }
123
124 pub fn subscribe<F>(self: &Arc<Self>, key: impl Into<String>, notifier: F) -> SubscriptionToken
129 where
130 F: Fn(&str) + Send + Sync + 'static,
131 {
132 let key = key.into();
133 let id = self.next_sub_id.fetch_add(1, Ordering::Relaxed);
134 self.subscribers.write().entry(key.clone()).or_default().push((id, Box::new(notifier)));
135 SubscriptionToken { id, key, replicator: Arc::downgrade(self) }
136 }
137
138 pub fn notify(&self, key: &str) {
141 let subs = self.subscribers.read();
142 if let Some(list) = subs.get(key) {
143 for (_, cb) in list {
144 cb(key);
145 }
146 }
147 }
148
149 pub(crate) fn unsubscribe_by_id(&self, key: &str, id: SubscriberId) {
152 let mut g = self.subscribers.write();
153 if let Some(list) = g.get_mut(key) {
154 list.retain(|(i, _)| *i != id);
155 if list.is_empty() {
156 g.remove(key);
157 }
158 }
159 }
160
161 pub fn subscriber_count(&self, key: &str) -> usize {
163 self.subscribers.read().get(key).map(|v| v.len()).unwrap_or(0)
164 }
165
166 pub fn get<T>(&self, key: &str) -> Option<T>
167 where
168 T: CrdtMerge + Clone + Send + Sync + 'static,
169 {
170 self.store.read().get(key).and_then(|e| e.downcast_ref::<T>().cloned())
171 }
172
173 pub fn delete(&self, key: &str) {
174 self.store.write().remove(key);
175 self.notify(key);
176 }
177
178 pub fn keys(&self) -> Vec<String> {
181 let mut ks: Vec<String> = self.store.read().keys().cloned().collect();
182 ks.sort();
183 ks
184 }
185}
186
187pub struct SubscriptionToken {
189 id: SubscriberId,
190 key: String,
191 replicator: std::sync::Weak<Replicator>,
192}
193
194impl Drop for SubscriptionToken {
195 fn drop(&mut self) {
196 if let Some(r) = self.replicator.upgrade() {
197 r.unsubscribe_by_id(&self.key, self.id);
198 }
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::GCounter;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Subscribes a callback to `key` that counts how many times it fires.
    ///
    /// Returns the shared hit counter together with the live subscription token.
    fn counting_subscription(r: &Arc<Replicator>, key: &str) -> (Arc<AtomicU32>, SubscriptionToken) {
        let hits = Arc::new(AtomicU32::new(0));
        let sink = Arc::clone(&hits);
        let token = r.subscribe(key, move |_| {
            sink.fetch_add(1, Ordering::SeqCst);
        });
        (hits, token)
    }

    #[test]
    fn update_merges_into_existing_value() {
        let r = Replicator::new();
        let mut first = GCounter::new();
        first.increment("n1", 1);
        let mut second = GCounter::new();
        second.increment("n2", 5);
        r.update("count", first);
        r.update("count", second);
        let merged: GCounter = r.get("count").expect("count was written above");
        assert_eq!(merged.value(), 6);
    }

    #[test]
    fn subscribe_fires_on_update() {
        let r = Replicator::new();
        let (hits, _token) = counting_subscription(&r, "k");
        let mut c = GCounter::new();
        c.increment("a", 1);
        for _ in 0..2 {
            r.update("k", c.clone());
        }
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn subscribe_fires_on_delete() {
        let r = Replicator::new();
        let (hits, _token) = counting_subscription(&r, "k");
        r.update("k", GCounter::new());
        r.delete("k");
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn drop_token_unsubscribes() {
        let r = Replicator::new();
        let (hits, token) = counting_subscription(&r, "k");
        assert_eq!(r.subscriber_count("k"), 1);
        drop(token);
        assert_eq!(r.subscriber_count("k"), 0);
        r.update("k", GCounter::new());
        assert_eq!(hits.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn write_consistency_majority_math() {
        let w = WriteConsistency::Majority { timeout: Duration::from_secs(1) };
        for &(cluster, expected) in &[(1, 1), (3, 2), (5, 3), (6, 4)] {
            assert_eq!(w.required_acks(cluster), expected);
        }
    }

    #[test]
    fn write_consistency_all_uses_cluster_size() {
        let w = WriteConsistency::All { timeout: Duration::from_secs(1) };
        assert_eq!(w.required_acks(7), 7);
        assert_eq!(w.required_acks(0), 1);
    }

    #[test]
    fn read_consistency_from_clamps_to_cluster_size() {
        let rc = ReadConsistency::From { n: 99, timeout: Duration::from_secs(1) };
        assert_eq!(rc.required_replies(3), 3);
    }

    #[test]
    fn local_consistency_has_no_timeout() {
        assert_eq!(WriteConsistency::Local.timeout(), None);
        assert_eq!(ReadConsistency::Local.timeout(), None);
    }

    #[test]
    fn subscribe_only_fires_for_matching_key() {
        let r = Replicator::new();
        let (hits, _token) = counting_subscription(&r, "a");
        r.update("a", GCounter::new());
        r.update("b", GCounter::new());
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }
}