1use std::collections::BTreeMap;
2use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
3
4use crate::routing::{FieldAccessor, Predicate, evaluate};
5
/// Opaque identifier naming the subscriber that owns a [`Subscription`].
///
/// Equality and ordering are those of the wrapped string.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct SubscriberId(String);

impl SubscriberId {
    /// Builds an identifier from anything convertible into an owned `String`.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        let raw: String = id.into();
        Self(raw)
    }

    /// Borrows the identifier as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
23
24#[derive(Clone, Debug, PartialEq)]
26pub struct Subscription {
27 pub subscriber: SubscriberId,
29 pub predicate: Predicate,
31}
32
33impl Subscription {
34 #[must_use]
36 pub const fn new(subscriber: SubscriberId, predicate: Predicate) -> Self {
37 Self {
38 subscriber,
39 predicate,
40 }
41 }
42}
43
44#[derive(Clone, Debug)]
46pub struct RoutingTable {
47 inner: Arc<TableInner>,
48}
49
50impl RoutingTable {
51 #[must_use]
53 pub fn new() -> Self {
54 Self {
55 inner: Arc::new(TableInner::new()),
56 }
57 }
58
59 #[must_use]
61 pub fn is_empty(&self) -> bool {
62 read_table_state(&self.inner.state).channels.is_empty()
63 }
64
65 #[must_use]
67 pub fn register(
68 &self,
69 channel: impl Into<String>,
70 subscription: Subscription,
71 ) -> Arc<Subscription> {
72 let channel = channel.into();
73 let subscription = Arc::new(subscription);
74 let mut state = write_table_state(&self.inner.state);
75 let mut subscriptions = state
76 .channels
77 .get(channel.as_str())
78 .map_or_else(Vec::new, |snapshot| {
79 snapshot.iter().cloned().collect::<Vec<_>>()
80 });
81
82 subscriptions.push(Arc::clone(&subscription));
83 state
84 .channels
85 .insert(channel, Arc::from(subscriptions.into_boxed_slice()));
86
87 subscription
88 }
89
90 #[must_use]
92 pub fn remove(&self, channel: &str, subscriber: &SubscriberId) -> bool {
93 let mut state = write_table_state(&self.inner.state);
94 let Some(snapshot) = state.channels.get(channel).cloned() else {
95 return false;
96 };
97
98 let mut removed = false;
99 let mut subscriptions = Vec::with_capacity(snapshot.len());
100 for subscription in snapshot.iter() {
101 if subscription.subscriber == *subscriber {
102 removed = true;
103 } else {
104 subscriptions.push(Arc::clone(subscription));
105 }
106 }
107
108 if removed {
109 if subscriptions.is_empty() {
110 state.channels.remove(channel);
111 } else {
112 state.channels.insert(
113 channel.to_owned(),
114 Arc::from(subscriptions.into_boxed_slice()),
115 );
116 }
117 }
118
119 removed
120 }
121
122 #[must_use]
124 pub fn resolve(&self, channel: &str, accessor: &dyn FieldAccessor) -> Vec<Arc<Subscription>> {
125 let snapshot = {
126 let state = read_table_state(&self.inner.state);
127 state.channels.get(channel).cloned()
128 };
129
130 let Some(subscriptions) = snapshot else {
131 return Vec::new();
132 };
133
134 subscriptions
135 .iter()
136 .filter(|subscription| evaluate(&subscription.predicate, accessor))
137 .cloned()
138 .collect()
139 }
140}
141
142impl Default for RoutingTable {
143 fn default() -> Self {
144 Self::new()
145 }
146}
147
148#[derive(Debug)]
149struct TableInner {
150 state: RwLock<TableState>,
151}
152
153impl TableInner {
154 fn new() -> Self {
155 Self {
156 state: RwLock::new(TableState::default()),
157 }
158 }
159}
160
161#[derive(Debug, Default)]
162struct TableState {
163 channels: BTreeMap<String, Arc<[Arc<Subscription>]>>,
164}
165
166fn read_table_state(lock: &RwLock<TableState>) -> RwLockReadGuard<'_, TableState> {
167 match lock.read() {
168 Ok(guard) => guard,
169 Err(poisoned) => poisoned.into_inner(),
170 }
171}
172
173fn write_table_state(lock: &RwLock<TableState>) -> RwLockWriteGuard<'_, TableState> {
174 match lock.write() {
175 Ok(guard) => guard,
176 Err(poisoned) => poisoned.into_inner(),
177 }
178}
179
180#[cfg(test)]
181mod tests {
182 use std::sync::mpsc::{Receiver, SyncSender, sync_channel};
183 use std::thread;
184 use std::time::Duration;
185
186 use super::{RoutingTable, SubscriberId, Subscription};
187 use crate::routing::{
188 ComparisonOp, FieldAccessor, FieldPath, FieldValue, FieldValueRef, Predicate,
189 };
190
191 #[derive(Debug)]
192 struct StaticAccessor {
193 field: &'static str,
194 value: FieldValueRef<'static>,
195 }
196
197 impl StaticAccessor {
198 const fn new(field: &'static str, value: FieldValueRef<'static>) -> Self {
199 Self { field, value }
200 }
201 }
202
203 impl FieldAccessor for StaticAccessor {
204 fn field(&self, path: &FieldPath) -> Option<FieldValueRef<'_>> {
205 path.segments().eq([self.field]).then_some(self.value)
206 }
207 }
208
209 struct BlockingAccessor {
210 entered: SyncSender<()>,
211 release: Receiver<()>,
212 }
213
214 impl std::fmt::Debug for BlockingAccessor {
215 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 formatter
217 .debug_struct("BlockingAccessor")
218 .finish_non_exhaustive()
219 }
220 }
221
222 impl BlockingAccessor {
223 const fn new(entered: SyncSender<()>, release: Receiver<()>) -> Self {
224 Self { entered, release }
225 }
226 }
227
228 impl FieldAccessor for BlockingAccessor {
229 fn field(&self, path: &FieldPath) -> Option<FieldValueRef<'_>> {
230 if path.segments().eq(["gate"]) {
231 if self.entered.send(()).is_err() {
232 return None;
233 }
234 if self.release.recv().is_err() {
235 return None;
236 }
237
238 Some(FieldValueRef::Boolean(true))
239 } else {
240 None
241 }
242 }
243 }
244
/// Compiles only when `T` is both `Send` and `Sync`; does nothing at runtime.
fn assert_send_sync<T>()
where
    T: Send + Sync,
{
}
246
247 fn amount_greater_than(value: i64) -> Predicate {
248 Predicate::Comparison {
249 field: FieldPath::new("amount"),
250 op: ComparisonOp::Gt,
251 value: FieldValue::Integer(value),
252 }
253 }
254
255 fn gate_predicate() -> Predicate {
256 Predicate::Comparison {
257 field: FieldPath::new("gate"),
258 op: ComparisonOp::Eq,
259 value: FieldValue::Boolean(true),
260 }
261 }
262
263 fn subscription(subscriber: &str, predicate: Predicate) -> Subscription {
264 Subscription::new(SubscriberId::new(subscriber), predicate)
265 }
266
/// A freshly constructed table reports itself as empty.
#[test]
fn new_table_is_empty() {
    assert!(RoutingTable::new().is_empty());
}
273
/// Resolving with a value that satisfies the predicate returns the very
/// handle that `register` handed out.
#[test]
fn resolve_returns_matching_subscription() {
    let table = RoutingTable::new();
    let billing = subscription("billing", amount_greater_than(1_000));
    let registered = table.register("orders", billing);

    let accessor = StaticAccessor::new("amount", FieldValueRef::Integer(1_500));
    let resolved = table.resolve("orders", &accessor);

    assert_eq!(resolved, vec![registered]);
}
287
/// A value that fails the only predicate on the channel resolves to nothing.
#[test]
fn resolve_returns_empty_when_no_subscription_matches() {
    let table = RoutingTable::new();
    let billing = subscription("billing", amount_greater_than(1_000));
    let _ = table.register("orders", billing);

    let accessor = StaticAccessor::new("amount", FieldValueRef::Integer(500));
    let resolved = table.resolve("orders", &accessor);

    assert!(resolved.is_empty());
}
299
/// Two subscriptions on one channel are filtered by their own predicates:
/// the value 500 passes the `> 100` predicate but not the `> 1_000` one.
#[test]
fn multiple_subscriptions_on_channel_are_evaluated_independently() {
    let table = RoutingTable::new();
    let low_threshold = subscription("low", amount_greater_than(100));
    let high_threshold = subscription("high", amount_greater_than(1_000));
    let low = table.register("orders", low_threshold);
    let high = table.register("orders", high_threshold);

    let accessor = StaticAccessor::new("amount", FieldValueRef::Integer(500));
    let matches = table.resolve("orders", &accessor);

    assert_eq!(matches, vec![low]);
    assert_ne!(matches, vec![high]);
}
312
/// Compile-time check that `RoutingTable` satisfies `Send + Sync`.
#[test]
fn routing_table_is_send_and_sync() {
    assert_send_sync::<RoutingTable>();
}
317
/// `register` completes while another thread is parked in the middle of
/// `resolve`, and the in-flight resolution still sees only the old entry.
#[test]
fn register_does_not_block_active_resolve() {
    let table = RoutingTable::new();
    let _ = table.register("orders", subscription("initial", gate_predicate()));
    let (entered_sender, entered_receiver) = sync_channel(0);
    let (release_sender, release_receiver) = sync_channel(0);

    let resolver = {
        let table = table.clone();
        thread::spawn(move || {
            let accessor = BlockingAccessor::new(entered_sender, release_receiver);
            table.resolve("orders", &accessor).len()
        })
    };

    // From here on the resolver is parked inside predicate evaluation.
    assert!(entered_receiver.recv().is_ok());
    let added = table.register("orders", subscription("new", gate_predicate()));
    assert_eq!(added.subscriber.as_str(), "new");

    assert!(release_sender.send(()).is_ok());
    assert!(matches!(resolver.join(), Ok(1)));

    // A resolution started afterwards observes both subscriptions.
    let accessor = StaticAccessor::new("gate", FieldValueRef::Boolean(true));
    let resolved = table.resolve("orders", &accessor);
    assert_eq!(resolved.len(), 2);
}
340
/// `remove` completes while another thread is parked in the middle of
/// `resolve`, and the in-flight resolution still returns the removed entry.
#[test]
fn remove_does_not_block_active_resolve() {
    let table = RoutingTable::new();
    let subscriber = SubscriberId::new("initial");
    let initial = Subscription::new(subscriber.clone(), gate_predicate());
    let _ = table.register("orders", initial);
    let (entered_sender, entered_receiver) = sync_channel(0);
    let (release_sender, release_receiver) = sync_channel(0);

    let resolver = {
        let table = table.clone();
        thread::spawn(move || {
            let accessor = BlockingAccessor::new(entered_sender, release_receiver);
            table.resolve("orders", &accessor).len()
        })
    };

    // From here on the resolver is parked inside predicate evaluation.
    assert!(entered_receiver.recv().is_ok());
    assert!(table.remove("orders", &subscriber));

    assert!(release_sender.send(()).is_ok());
    assert!(matches!(resolver.join(), Ok(1)));

    // A resolution started afterwards no longer finds the subscription.
    let accessor = StaticAccessor::new("gate", FieldValueRef::Boolean(true));
    let resolved = table.resolve("orders", &accessor);
    assert!(resolved.is_empty());
}
366
/// A registration made during an in-flight resolution leaves the table with
/// both subscriptions, in registration order.
#[test]
fn register_during_active_resolve_preserves_state() {
    let table = RoutingTable::new();
    let _ = table.register("orders", subscription("initial", gate_predicate()));
    let (entered_sender, entered_receiver) = sync_channel(0);
    let (release_sender, release_receiver) = sync_channel(0);

    let resolver = {
        let table = table.clone();
        thread::spawn(move || {
            let accessor = BlockingAccessor::new(entered_sender, release_receiver);
            table.resolve("orders", &accessor).len()
        })
    };

    // From here on the resolver is parked inside predicate evaluation.
    assert!(entered_receiver.recv().is_ok());
    let _ = table.register("orders", subscription("new", gate_predicate()));
    assert!(release_sender.send(()).is_ok());
    assert!(matches!(resolver.join(), Ok(1)));

    let accessor = StaticAccessor::new("gate", FieldValueRef::Boolean(true));
    let matches = table.resolve("orders", &accessor);
    let names: Vec<&str> = matches
        .iter()
        .map(|entry| entry.subscriber.as_str())
        .collect();

    assert_eq!(names, ["initial", "new"]);
}
392
/// Resolving a channel that was never registered yields no subscriptions.
#[test]
fn resolve_nonexistent_channel_returns_empty_set() {
    let table = RoutingTable::new();
    let accessor = StaticAccessor::new("amount", FieldValueRef::Integer(1_500));

    let resolved = table.resolve("nonexistent-channel", &accessor);

    assert!(resolved.is_empty());
}
400
/// After the only subscription on a channel is removed, a value that used to
/// match resolves to nothing.
#[test]
fn removing_last_subscription_makes_channel_resolve_empty() {
    let table = RoutingTable::new();
    let subscriber = SubscriberId::new("billing");
    let billing = Subscription::new(subscriber.clone(), amount_greater_than(1_000));
    let _ = table.register("orders", billing);
    let accessor = StaticAccessor::new("amount", FieldValueRef::Integer(1_500));

    let removed = table.remove("orders", &subscriber);
    assert!(removed);

    let resolved = table.resolve("orders", &accessor);
    assert!(resolved.is_empty());
}
414
/// A registration running on its own thread finishes within one second while
/// a resolver is still parked, i.e. before the resolver is released.
#[test]
fn updater_completion_is_observed_before_resolve_release() {
    let table = RoutingTable::new();
    let _ = table.register("orders", subscription("initial", gate_predicate()));
    let resolver_table = table.clone();
    let (entered_sender, entered_receiver) = sync_channel(0);
    let (release_sender, release_receiver) = sync_channel(0);
    let (updated_sender, updated_receiver) = sync_channel(0);

    let resolver = thread::spawn(move || {
        let accessor = BlockingAccessor::new(entered_sender, release_receiver);
        resolver_table.resolve("orders", &accessor).len()
    });

    // From here on the resolver is parked inside predicate evaluation.
    assert!(entered_receiver.recv().is_ok());

    // The original handle moves into the updater thread.
    let updater = thread::spawn(move || {
        let _ = table.register("orders", subscription("new", gate_predicate()));
        updated_sender.send(())
    });

    // The timeout turns a blocked writer into a failure instead of a hang.
    let update_signal = updated_receiver.recv_timeout(Duration::from_secs(1));
    assert!(update_signal.is_ok());

    assert!(release_sender.send(()).is_ok());
    assert!(matches!(resolver.join(), Ok(1)));
    assert!(matches!(updater.join(), Ok(Ok(()))));
}
444}