Skip to main content

liminal/routing/
table.rs

1use std::collections::BTreeMap;
2use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
3
4use crate::routing::{FieldAccessor, Predicate, evaluate};
5
/// Subscriber identity stored with a routing subscription.
///
/// A newtype over `String`; equality and ordering are derived from the
/// wrapped string.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct SubscriberId(String);

impl SubscriberId {
    /// Builds a subscriber identity from any value convertible into a `String`.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        let owned: String = id.into();
        Self(owned)
    }

    /// Borrows the subscriber identity as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
23
24/// Active predicate subscription for a channel.
25#[derive(Clone, Debug, PartialEq)]
26pub struct Subscription {
27    /// Subscriber identity that owns this subscription.
28    pub subscriber: SubscriberId,
29    /// Predicate that must match for this subscription to receive a message.
30    pub predicate: Predicate,
31}
32
33impl Subscription {
34    /// Creates a subscription from a subscriber identity and predicate.
35    #[must_use]
36    pub const fn new(subscriber: SubscriberId, predicate: Predicate) -> Self {
37        Self {
38            subscriber,
39            predicate,
40        }
41    }
42}
43
44/// Concurrent routing table mapping channels to active predicate subscriptions.
45#[derive(Clone, Debug)]
46pub struct RoutingTable {
47    inner: Arc<TableInner>,
48}
49
50impl RoutingTable {
51    /// Creates an empty routing table.
52    #[must_use]
53    pub fn new() -> Self {
54        Self {
55            inner: Arc::new(TableInner::new()),
56        }
57    }
58
59    /// Returns true when no channel has any active subscriptions.
60    #[must_use]
61    pub fn is_empty(&self) -> bool {
62        read_table_state(&self.inner.state).channels.is_empty()
63    }
64
65    /// Registers `subscription` on `channel` and returns its shared table handle.
66    #[must_use]
67    pub fn register(
68        &self,
69        channel: impl Into<String>,
70        subscription: Subscription,
71    ) -> Arc<Subscription> {
72        let channel = channel.into();
73        let subscription = Arc::new(subscription);
74        let mut state = write_table_state(&self.inner.state);
75        let mut subscriptions = state
76            .channels
77            .get(channel.as_str())
78            .map_or_else(Vec::new, |snapshot| {
79                snapshot.iter().cloned().collect::<Vec<_>>()
80            });
81
82        subscriptions.push(Arc::clone(&subscription));
83        state
84            .channels
85            .insert(channel, Arc::from(subscriptions.into_boxed_slice()));
86
87        subscription
88    }
89
90    /// Removes subscriptions for `subscriber` on `channel`.
91    #[must_use]
92    pub fn remove(&self, channel: &str, subscriber: &SubscriberId) -> bool {
93        let mut state = write_table_state(&self.inner.state);
94        let Some(snapshot) = state.channels.get(channel).cloned() else {
95            return false;
96        };
97
98        let mut removed = false;
99        let mut subscriptions = Vec::with_capacity(snapshot.len());
100        for subscription in snapshot.iter() {
101            if subscription.subscriber == *subscriber {
102                removed = true;
103            } else {
104                subscriptions.push(Arc::clone(subscription));
105            }
106        }
107
108        if removed {
109            if subscriptions.is_empty() {
110                state.channels.remove(channel);
111            } else {
112                state.channels.insert(
113                    channel.to_owned(),
114                    Arc::from(subscriptions.into_boxed_slice()),
115                );
116            }
117        }
118
119        removed
120    }
121
122    /// Resolves all subscriptions on `channel` whose predicates match `accessor`.
123    #[must_use]
124    pub fn resolve(&self, channel: &str, accessor: &dyn FieldAccessor) -> Vec<Arc<Subscription>> {
125        let snapshot = {
126            let state = read_table_state(&self.inner.state);
127            state.channels.get(channel).cloned()
128        };
129
130        let Some(subscriptions) = snapshot else {
131            return Vec::new();
132        };
133
134        subscriptions
135            .iter()
136            .filter(|subscription| evaluate(&subscription.predicate, accessor))
137            .cloned()
138            .collect()
139    }
140}
141
142impl Default for RoutingTable {
143    fn default() -> Self {
144        Self::new()
145    }
146}
147
148#[derive(Debug)]
149struct TableInner {
150    state: RwLock<TableState>,
151}
152
153impl TableInner {
154    fn new() -> Self {
155        Self {
156            state: RwLock::new(TableState::default()),
157        }
158    }
159}
160
161#[derive(Debug, Default)]
162struct TableState {
163    channels: BTreeMap<String, Arc<[Arc<Subscription>]>>,
164}
165
166fn read_table_state(lock: &RwLock<TableState>) -> RwLockReadGuard<'_, TableState> {
167    match lock.read() {
168        Ok(guard) => guard,
169        Err(poisoned) => poisoned.into_inner(),
170    }
171}
172
173fn write_table_state(lock: &RwLock<TableState>) -> RwLockWriteGuard<'_, TableState> {
174    match lock.write() {
175        Ok(guard) => guard,
176        Err(poisoned) => poisoned.into_inner(),
177    }
178}
179
#[cfg(test)]
mod tests {
    use std::sync::mpsc::{Receiver, SyncSender, sync_channel};
    use std::thread;
    use std::time::Duration;

    use super::{RoutingTable, SubscriberId, Subscription};
    use crate::routing::{
        ComparisonOp, FieldAccessor, FieldPath, FieldValue, FieldValueRef, Predicate,
    };

    /// Accessor exposing exactly one field with a fixed value.
    #[derive(Debug)]
    struct StaticAccessor {
        field: &'static str,
        value: FieldValueRef<'static>,
    }

    impl StaticAccessor {
        const fn new(field: &'static str, value: FieldValueRef<'static>) -> Self {
            Self { field, value }
        }
    }

    impl FieldAccessor for StaticAccessor {
        fn field(&self, path: &FieldPath) -> Option<FieldValueRef<'_>> {
            path.segments().eq([self.field]).then_some(self.value)
        }
    }

    /// Accessor that parks the calling thread when the `gate` field is read.
    ///
    /// On a `gate` lookup it signals `entered`, then waits on `release` before
    /// answering `true`. Either channel hanging up makes the lookup return
    /// `None` instead of blocking forever.
    struct BlockingAccessor {
        entered: SyncSender<()>,
        release: Receiver<()>,
    }

    impl std::fmt::Debug for BlockingAccessor {
        fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            formatter
                .debug_struct("BlockingAccessor")
                .finish_non_exhaustive()
        }
    }

    impl BlockingAccessor {
        const fn new(entered: SyncSender<()>, release: Receiver<()>) -> Self {
            Self { entered, release }
        }
    }

    impl FieldAccessor for BlockingAccessor {
        fn field(&self, path: &FieldPath) -> Option<FieldValueRef<'_>> {
            if path.segments().eq(["gate"]) {
                if self.entered.send(()).is_err() {
                    return None;
                }
                if self.release.recv().is_err() {
                    return None;
                }

                Some(FieldValueRef::Boolean(true))
            } else {
                None
            }
        }
    }

    /// A resolver thread parked inside `BlockingAccessor::field` while
    /// resolving the `orders` channel.
    struct ParkedResolver {
        /// Joins to the number of subscriptions the parked `resolve` matched.
        handle: thread::JoinHandle<usize>,
        /// Sending on this lets the parked lookup finish.
        release: SyncSender<()>,
        /// Held only so the accessor's `entered` channel stays connected for
        /// the whole test.
        _entered: Receiver<()>,
    }

    impl ParkedResolver {
        /// Spawns a thread resolving `orders` on a clone of `table` and returns
        /// once that thread has signalled that it is inside the accessor.
        fn spawn(table: &RoutingTable) -> Self {
            let resolver_table = table.clone();
            let (entered_sender, entered_receiver) = sync_channel(0);
            let (release_sender, release_receiver) = sync_channel(0);

            let handle = thread::spawn(move || {
                let accessor = BlockingAccessor::new(entered_sender, release_receiver);
                resolver_table.resolve("orders", &accessor).len()
            });

            assert!(entered_receiver.recv().is_ok());

            Self {
                handle,
                release: release_sender,
                _entered: entered_receiver,
            }
        }
    }

    fn assert_send_sync<T: Send + Sync>() {}

    fn amount_greater_than(value: i64) -> Predicate {
        Predicate::Comparison {
            field: FieldPath::new("amount"),
            op: ComparisonOp::Gt,
            value: FieldValue::Integer(value),
        }
    }

    fn gate_predicate() -> Predicate {
        Predicate::Comparison {
            field: FieldPath::new("gate"),
            op: ComparisonOp::Eq,
            value: FieldValue::Boolean(true),
        }
    }

    fn subscription(subscriber: &str, predicate: Predicate) -> Subscription {
        Subscription::new(SubscriberId::new(subscriber), predicate)
    }

    #[test]
    fn new_table_is_empty() {
        let table = RoutingTable::new();

        assert!(table.is_empty());
    }

    #[test]
    fn resolve_returns_matching_subscription() {
        let table = RoutingTable::new();
        let registered = table.register(
            "orders",
            subscription("billing", amount_greater_than(1_000)),
        );
        let accessor = StaticAccessor::new("amount", FieldValueRef::Integer(1_500));

        let matches = table.resolve("orders", &accessor);

        assert_eq!(matches, vec![registered]);
    }

    #[test]
    fn resolve_returns_empty_when_no_subscription_matches() {
        let table = RoutingTable::new();
        let _ = table.register(
            "orders",
            subscription("billing", amount_greater_than(1_000)),
        );
        let accessor = StaticAccessor::new("amount", FieldValueRef::Integer(500));

        assert!(table.resolve("orders", &accessor).is_empty());
    }

    #[test]
    fn multiple_subscriptions_on_channel_are_evaluated_independently() {
        let table = RoutingTable::new();
        let low = table.register("orders", subscription("low", amount_greater_than(100)));
        let high = table.register("orders", subscription("high", amount_greater_than(1_000)));
        let accessor = StaticAccessor::new("amount", FieldValueRef::Integer(500));

        let matches = table.resolve("orders", &accessor);

        assert_eq!(matches, vec![low]);
        assert_ne!(matches, vec![high]);
    }

    #[test]
    fn routing_table_is_send_and_sync() {
        assert_send_sync::<RoutingTable>();
    }

    #[test]
    fn register_does_not_block_active_resolve() {
        let table = RoutingTable::new();
        let _ = table.register("orders", subscription("initial", gate_predicate()));
        let parked = ParkedResolver::spawn(&table);

        let registered = table.register("orders", subscription("new", gate_predicate()));
        assert_eq!(registered.subscriber.as_str(), "new");
        assert!(parked.release.send(()).is_ok());
        // The parked resolver still sees the snapshot taken before `register`.
        assert!(matches!(parked.handle.join(), Ok(1)));

        let accessor = StaticAccessor::new("gate", FieldValueRef::Boolean(true));
        assert_eq!(table.resolve("orders", &accessor).len(), 2);
    }

    #[test]
    fn remove_does_not_block_active_resolve() {
        let table = RoutingTable::new();
        let subscriber = SubscriberId::new("initial");
        let _ = table.register(
            "orders",
            Subscription::new(subscriber.clone(), gate_predicate()),
        );
        let parked = ParkedResolver::spawn(&table);

        assert!(table.remove("orders", &subscriber));
        assert!(parked.release.send(()).is_ok());
        // The parked resolver still sees the snapshot taken before `remove`.
        assert!(matches!(parked.handle.join(), Ok(1)));

        let accessor = StaticAccessor::new("gate", FieldValueRef::Boolean(true));
        assert!(table.resolve("orders", &accessor).is_empty());
    }

    #[test]
    fn register_during_active_resolve_preserves_state() {
        let table = RoutingTable::new();
        let _ = table.register("orders", subscription("initial", gate_predicate()));
        let parked = ParkedResolver::spawn(&table);

        let _ = table.register("orders", subscription("new", gate_predicate()));
        assert!(parked.release.send(()).is_ok());
        assert!(matches!(parked.handle.join(), Ok(1)));

        let accessor = StaticAccessor::new("gate", FieldValueRef::Boolean(true));
        let matches = table.resolve("orders", &accessor);

        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].subscriber.as_str(), "initial");
        assert_eq!(matches[1].subscriber.as_str(), "new");
    }

    #[test]
    fn resolve_nonexistent_channel_returns_empty_set() {
        let table = RoutingTable::new();
        let accessor = StaticAccessor::new("amount", FieldValueRef::Integer(1_500));

        assert!(table.resolve("nonexistent-channel", &accessor).is_empty());
    }

    #[test]
    fn removing_last_subscription_makes_channel_resolve_empty() {
        let table = RoutingTable::new();
        let subscriber = SubscriberId::new("billing");
        let _ = table.register(
            "orders",
            Subscription::new(subscriber.clone(), amount_greater_than(1_000)),
        );
        let accessor = StaticAccessor::new("amount", FieldValueRef::Integer(1_500));

        assert!(table.remove("orders", &subscriber));
        assert!(table.resolve("orders", &accessor).is_empty());
    }

    #[test]
    fn updater_completion_is_observed_before_resolve_release() {
        let table = RoutingTable::new();
        let _ = table.register("orders", subscription("initial", gate_predicate()));
        let parked = ParkedResolver::spawn(&table);
        let updater_table = table;
        let (updated_sender, updated_receiver) = sync_channel(0);

        let updater = thread::spawn(move || {
            let _ = updater_table.register("orders", subscription("new", gate_predicate()));
            updated_sender.send(())
        });
        // The updater must finish while the resolver is still parked; the
        // timeout turns a blocked `register` into a failure instead of a hang.
        assert!(
            updated_receiver
                .recv_timeout(Duration::from_secs(1))
                .is_ok()
        );
        assert!(parked.release.send(()).is_ok());
        assert!(matches!(parked.handle.join(), Ok(1)));
        assert!(matches!(updater.join(), Ok(Ok(()))));
    }
}
444}