// http_relay/http_relay/waiting_list.rs

//! Coordination layer for pub-sub message delivery with persistent storage.
//!
//! Combines SQLite persistence (via EntryRepository) with in-memory waiter channels.
//! Entries are persisted; the oneshot channels of waiting requests cannot be persisted
//! and live only in memory.

6use std::collections::HashMap;
7use std::time::Duration;
8
9use bytes::Bytes;
10use tokio::sync::oneshot;
11
12use super::persistence::EntryRepository;
13use super::unix_timestamp_millis;
/// Bounds concurrent long-poll connections per entry to prevent memory exhaustion.
///
/// The limit applies separately to ACK waiters and to message waiters, so a single id
/// can hold up to `2 * MAX_WAITERS_PER_ENTRY` oneshot senders. For example, 10k ids with
/// both kinds of waiters at the limit is ~200k oneshot channels (not ~100k).
const MAX_WAITERS_PER_ENTRY: usize = 10;

19/// Payload stored in the waiting list, cloned to multiple subscribers on delivery.
20///
21/// Uses `Bytes` for zero-copy cloning when fanning out to concurrent waiters.
22#[derive(Clone, Debug)]
23pub struct Message {
24    /// The message body bytes.
25    pub body: Bytes,
26    /// Optional MIME content type (e.g., "application/json").
27    pub content_type: Option<String>,
28}
29
30/// In-memory waiters for a single entry.
31/// These cannot be persisted (oneshot channels are runtime-specific).
32struct Waiters {
33    /// Subscribers waiting to be notified when ACK occurs.
34    ack_waiters: Vec<oneshot::Sender<()>>,
35    /// Subscribers waiting to receive a message when it arrives.
36    message_waiters: Vec<oneshot::Sender<Message>>,
37}
38
39impl Waiters {
40    fn new() -> Self {
41        Self {
42            ack_waiters: Vec::new(),
43            message_waiters: Vec::new(),
44        }
45    }
46
47    fn is_empty(&self) -> bool {
48        self.ack_waiters.is_empty() && self.message_waiters.is_empty()
49    }
50
51    /// Adds an ACK waiter. Returns false if limit reached.
52    fn add_ack_waiter(&mut self, sender: oneshot::Sender<()>) -> bool {
53        if self.ack_waiters.len() >= MAX_WAITERS_PER_ENTRY {
54            return false;
55        }
56        self.ack_waiters.push(sender);
57        true
58    }
59
60    /// Adds a message waiter. Returns false if limit reached.
61    fn add_message_waiter(&mut self, sender: oneshot::Sender<Message>) -> bool {
62        if self.message_waiters.len() >= MAX_WAITERS_PER_ENTRY {
63            return false;
64        }
65        self.message_waiters.push(sender);
66        true
67    }
68
69    /// Notifies all message waiters and clears the list.
70    fn notify_message_waiters(&mut self, message: &Message) {
71        for waiter in self.ack_waiters.drain(..) {
72            // Drop stale ack waiters when message is overwritten
73            drop(waiter);
74        }
75        for waiter in self.message_waiters.drain(..) {
76            let _ = waiter.send(message.clone());
77        }
78    }
79
80    /// Notifies all ACK waiters and clears the list.
81    fn notify_ack_waiters(&mut self) {
82        for waiter in self.ack_waiters.drain(..) {
83            let _ = waiter.send(());
84        }
85    }
86}
87
88/// Coordination layer for pub-sub message delivery.
89///
90/// Combines persistent storage (EntryRepository) with in-memory waiter channels.
91/// Supports two patterns:
92/// 1. **Inbox**: producer stores, consumer polls or long-polls for message
93/// 2. **Link**: producer stores, waits for ACK; consumer fetches and ACKs
94///
95/// Memory is bounded by repository's max_entries and per-entry waiter limits.
96pub struct WaitingList {
97    repository: EntryRepository,
98    waiters: HashMap<String, Waiters>,
99}
100
/// Reasons a subscription on a waiting-list entry can be refused.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SubscribeError {
    /// The entry already has the maximum number of waiters of the requested kind
    /// (see `MAX_WAITERS_PER_ENTRY`).
    WaiterLimitReached,
    /// No live entry exists: it was never stored, or it has expired.
    NotFound,
}

110/// Result of get_or_subscribe operation.
111pub enum GetOrSubscribeResult {
112    /// Message was immediately available.
113    Message(Message),
114    /// No message yet - receiver will fire when message arrives.
115    Waiting(oneshot::Receiver<Message>),
116}
117
118impl WaitingList {
119    /// Creates a WaitingList with the given repository.
120    pub fn new(repository: EntryRepository) -> Self {
121        Self {
122            repository,
123            waiters: HashMap::new(),
124        }
125    }
126
127    /// Stores a message, immediately notifying any long-polling subscribers.
128    ///
129    /// On overwrite: message_waiters receive the new message; stale ack_waiters
130    /// (from previous message) are dropped (their receivers get `RecvError`).
131    ///
132    /// Returns an error if persistence fails. Waiters are notified regardless
133    /// of persistence success (message is delivered but may not survive restart).
134    pub fn store(&mut self, id: String, message: Message, ttl: Duration) -> anyhow::Result<()> {
135        let expires_at = unix_timestamp_millis() + ttl.as_millis() as i64;
136
137        // Notify any waiting subscribers before persisting
138        if let Some(waiters) = self.waiters.get_mut(&id) {
139            waiters.notify_message_waiters(&message);
140            if waiters.is_empty() {
141                self.waiters.remove(&id);
142            }
143        }
144
145        // Persist to SQLite
146        self.repository.insert(
147            &id,
148            &message.body,
149            message.content_type.as_deref(),
150            expires_at,
151        )
152    }
153
154    /// Marks entry as acknowledged, clearing message payload and notifying waiters.
155    ///
156    /// Returns false if entry missing or expired (caller should return 404).
157    pub fn ack(&mut self, id: &str) -> bool {
158        // Check if entry exists and is not expired
159        let entry = match self.repository.get(id) {
160            Ok(Some(entry)) if !Self::is_expired(entry.expires_at) => entry,
161            _ => return false,
162        };
163
164        // Already acked? Return true but don't notify again
165        if entry.acked {
166            return true;
167        }
168
169        // Persist the ACK
170        if let Err(e) = self.repository.ack(id) {
171            tracing::error!(?e, id, "Failed to ack entry in repository");
172            return false;
173        }
174
175        // Notify waiters
176        if let Some(waiters) = self.waiters.get_mut(id) {
177            waiters.notify_ack_waiters();
178            if waiters.is_empty() {
179                self.waiters.remove(id);
180            }
181        }
182
183        true
184    }
185
186    /// Returns whether an entry has been acknowledged.
187    ///
188    /// Returns `None` if entry doesn't exist or is expired, `Some(bool)` otherwise.
189    pub fn is_acked(&self, id: &str) -> Option<bool> {
190        match self.repository.get(id) {
191            Ok(Some(entry)) if !Self::is_expired(entry.expires_at) => Some(entry.acked),
192            _ => None,
193        }
194    }
195
196    /// Subscribes to receive notification when this entry is ACKed.
197    ///
198    /// Returns `Err(SubscribeError::NotFound)` if entry missing/expired,
199    /// `Err(SubscribeError::WaiterLimitReached)` if waiter limit reached.
200    pub fn subscribe_ack(&mut self, id: &str) -> Result<oneshot::Receiver<()>, SubscribeError> {
201        // Check if entry exists and is not expired
202        let entry = match self.repository.get(id) {
203            Ok(Some(entry)) if !Self::is_expired(entry.expires_at) => entry,
204            _ => return Err(SubscribeError::NotFound),
205        };
206
207        // If already acked, create a channel and immediately notify
208        if entry.acked {
209            let (tx, rx) = oneshot::channel();
210            let _ = tx.send(());
211            return Ok(rx);
212        }
213
214        let (tx, rx) = oneshot::channel();
215        let waiters = self
216            .waiters
217            .entry(id.to_string())
218            .or_insert_with(Waiters::new);
219        if waiters.add_ack_waiter(tx) {
220            Ok(rx)
221        } else {
222            Err(SubscribeError::WaiterLimitReached)
223        }
224    }
225
226    /// Atomically checks for message and subscribes if not present.
227    ///
228    /// Prevents TOCTOU race: without this, a separate "check then subscribe" would
229    /// miss messages arriving between the two operations.
230    ///
231    /// Creates a waiter entry if no message exists, allowing waiters to register
232    /// before any message is stored (consumer arrives before producer).
233    pub fn get_or_subscribe(&mut self, id: &str) -> Result<GetOrSubscribeResult, SubscribeError> {
234        // Check if entry exists with a message
235        if let Ok(Some(entry)) = self.repository.get(id) {
236            if !Self::is_expired(entry.expires_at) {
237                if let Some(body) = entry.message_body {
238                    return Ok(GetOrSubscribeResult::Message(Message {
239                        body: Bytes::from(body),
240                        content_type: entry.content_type,
241                    }));
242                }
243            }
244        }
245
246        // No message - subscribe for notification
247        let (tx, rx) = oneshot::channel();
248        let waiters = self
249            .waiters
250            .entry(id.to_string())
251            .or_insert_with(Waiters::new);
252
253        if waiters.add_message_waiter(tx) {
254            Ok(GetOrSubscribeResult::Waiting(rx))
255        } else {
256            Err(SubscribeError::WaiterLimitReached)
257        }
258    }
259
260    /// Removes expired entries and cleans up stale waiters.
261    pub fn cleanup_expired(&mut self) -> usize {
262        // Identify expired entries BEFORE deleting them from repository
263        // (so we can clean up their waiters)
264        let expired_keys: Vec<String> = self
265            .waiters
266            .keys()
267            .filter(|id| {
268                match self.repository.get(id) {
269                    Ok(Some(entry)) => Self::is_expired(entry.expires_at),
270                    _ => false, // No entry yet or error - keep waiters
271                }
272            })
273            .cloned()
274            .collect();
275
276        // Delete expired entries from repository
277        let count = match self.repository.cleanup_expired() {
278            Ok(c) => c,
279            Err(e) => {
280                tracing::error!(?e, "Failed to cleanup expired entries");
281                0
282            }
283        };
284
285        // Clean up closed senders (receivers dropped due to timeout)
286        for waiters in self.waiters.values_mut() {
287            waiters.ack_waiters.retain(|s| !s.is_closed());
288            waiters.message_waiters.retain(|s| !s.is_closed());
289        }
290
291        // Remove empty waiter entries
292        self.waiters.retain(|_, w| !w.is_empty());
293
294        // Remove waiters for entries that were expired
295        for key in expired_keys {
296            self.waiters.remove(&key);
297        }
298
299        count
300    }
301
302    /// Checks if an entry has expired based on its expires_at timestamp (in milliseconds).
303    fn is_expired(expires_at: i64) -> bool {
304        unix_timestamp_millis() >= expires_at
305    }
306}
307
#[cfg(test)]
impl WaitingList {
    /// Test-only constructor backed by an in-memory repository, passing
    /// `max_entries` through to `EntryRepository::new`.
    ///
    /// # Panics
    ///
    /// Panics if the in-memory repository cannot be created.
    pub fn new_in_memory(max_entries: usize) -> Self {
        Self::new(
            EntryRepository::new(None, max_entries)
                .expect("Failed to create in-memory repository"),
        )
    }

    /// Number of entries currently held by the repository (0 if counting fails).
    pub fn len(&self) -> usize {
        self.repository.count().unwrap_or_default()
    }

    /// Whether the repository holds no entries (also true if counting fails).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    /// TTL for entries that should expire during a test.
    const SHORT_TTL: Duration = Duration::from_millis(10);
    /// Sleep used by expiry tests; comfortably longer than `SHORT_TTL`.
    const EXPIRY_WAIT: Duration = Duration::from_millis(20);
    /// TTL long enough that entries never expire during a test run.
    const LONG_TTL: Duration = Duration::from_secs(60);

    /// Builds a `text/plain` message with the given body.
    fn make_message(body: &str) -> Message {
        Message {
            body: Bytes::from(body.to_owned()),
            content_type: Some(String::from("text/plain")),
        }
    }

    /// Fresh list backed by an in-memory repository.
    fn create_test_list() -> WaitingList {
        WaitingList::new_in_memory(100)
    }

    /// Unwraps a `Waiting` outcome, panicking with `why` if a message came back instead.
    fn into_receiver(outcome: GetOrSubscribeResult, why: &str) -> oneshot::Receiver<Message> {
        match outcome {
            GetOrSubscribeResult::Waiting(receiver) => receiver,
            GetOrSubscribeResult::Message(_) => panic!("{}", why),
        }
    }

    // ========== cleanup_expired tests ==========

    #[tokio::test]
    async fn cleanup_expired_removes_expired_entries() {
        let mut list = create_test_list();
        list.store(String::from("expires-soon"), make_message("a"), SHORT_TTL)
            .unwrap();
        list.store(String::from("stays-alive"), make_message("b"), LONG_TTL)
            .unwrap();
        assert_eq!(list.len(), 2);

        sleep(EXPIRY_WAIT).await;

        assert_eq!(list.cleanup_expired(), 1);
        assert_eq!(list.len(), 1);
    }

    #[tokio::test]
    async fn cleanup_expired_returns_zero_when_nothing_expired() {
        let mut list = create_test_list();
        for key in ["a", "b"].iter().copied() {
            list.store(key.to_string(), make_message(key), LONG_TTL)
                .unwrap();
        }

        assert_eq!(list.cleanup_expired(), 0);
        assert_eq!(list.len(), 2);
    }

    // ========== Message overwrite tests ==========

    #[tokio::test]
    async fn store_notifies_message_waiters_on_overwrite() {
        let mut list = create_test_list();
        let outcome = list.get_or_subscribe("id1").expect("should succeed");
        let receiver = into_receiver(outcome, "expected waiting");

        // Storing must wake the subscriber registered above.
        list.store(String::from("id1"), make_message("overwrite"), LONG_TTL)
            .unwrap();

        let delivered = receiver.await.expect("should receive message");
        assert_eq!(delivered.body, Bytes::from("overwrite"));
    }

    #[tokio::test]
    async fn store_drops_ack_waiters_on_overwrite() {
        let mut list = create_test_list();
        list.store(String::from("id1"), make_message("first"), LONG_TTL)
            .unwrap();
        let ack_receiver = list.subscribe_ack("id1").expect("should subscribe");

        list.store(String::from("id1"), make_message("second"), LONG_TTL)
            .unwrap();

        // The ACK waiter belonged to the replaced message, so its sender is gone.
        assert!(ack_receiver.await.is_err(), "old ack waiter should be dropped");
    }

    // ========== Waiter limit tests ==========

    #[test]
    fn subscribe_ack_returns_limit_error() {
        let mut list = create_test_list();
        list.store(String::from("id1"), make_message("test"), LONG_TTL)
            .unwrap();

        let all_accepted =
            (0..MAX_WAITERS_PER_ENTRY).all(|_| list.subscribe_ack("id1").is_ok());
        assert!(all_accepted);

        let overflow = list.subscribe_ack("id1");
        assert!(
            matches!(overflow, Err(SubscribeError::WaiterLimitReached)),
            "expected WaiterLimitReached error"
        );
    }

    #[test]
    fn get_or_subscribe_returns_limit_error() {
        let mut list = create_test_list();

        // Nothing stored yet, so every call registers a message waiter.
        let all_accepted =
            (0..MAX_WAITERS_PER_ENTRY).all(|_| list.get_or_subscribe("id1").is_ok());
        assert!(all_accepted);

        let overflow = list.get_or_subscribe("id1");
        assert!(
            matches!(overflow, Err(SubscribeError::WaiterLimitReached)),
            "expected WaiterLimitReached error"
        );
    }

    // ========== State transition tests ==========

    #[test]
    fn is_acked_false_before_ack() {
        let mut list = create_test_list();
        list.store(String::from("id1"), make_message("test"), LONG_TTL)
            .unwrap();

        assert_eq!(list.is_acked("id1"), Some(false));
    }

    #[test]
    fn is_acked_true_after_ack() {
        let mut list = create_test_list();
        list.store(String::from("id1"), make_message("test"), LONG_TTL)
            .unwrap();

        let acked = list.ack("id1");

        assert!(acked, "ack should succeed");
        assert_eq!(list.is_acked("id1"), Some(true));
    }

    #[tokio::test]
    async fn ack_notifies_waiters() {
        let mut list = create_test_list();
        list.store(String::from("id1"), make_message("test"), LONG_TTL)
            .unwrap();
        let ack_receiver = list.subscribe_ack("id1").expect("should subscribe");

        list.ack("id1");

        assert!(
            ack_receiver.await.is_ok(),
            "ack waiter should receive notification"
        );
    }

    #[tokio::test]
    async fn ack_fails_for_expired_entry() {
        let mut list = create_test_list();
        list.store(String::from("id1"), make_message("test"), SHORT_TTL)
            .unwrap();

        sleep(EXPIRY_WAIT).await;

        assert!(!list.ack("id1"), "ack should fail for expired entry");
        assert_eq!(
            list.is_acked("id1"),
            None,
            "is_acked should be None for expired"
        );
    }

    #[test]
    fn ack_fails_for_nonexistent_entry() {
        let mut list = create_test_list();

        assert!(!list.ack("nonexistent"));
        assert_eq!(list.is_acked("nonexistent"), None);
    }

    // ========== get_or_subscribe atomicity tests ==========

    #[test]
    fn get_or_subscribe_returns_existing_message() {
        let mut list = create_test_list();
        list.store(String::from("id1"), make_message("existing"), LONG_TTL)
            .unwrap();

        let message = match list.get_or_subscribe("id1").expect("should succeed") {
            GetOrSubscribeResult::Message(message) => message,
            GetOrSubscribeResult::Waiting(_) => {
                panic!("should return message, not waiting")
            }
        };

        assert_eq!(message.body, Bytes::from("existing"));
    }

    #[test]
    fn get_or_subscribe_returns_receiver_when_no_entry() {
        let mut list = create_test_list();

        let outcome = list.get_or_subscribe("id1").expect("should succeed");
        let _receiver = into_receiver(outcome, "should return waiting, not message");

        // A waiter slot must exist for the id even though nothing was stored.
        assert!(list.waiters.contains_key("id1"));
    }

    #[tokio::test]
    async fn get_or_subscribe_receiver_gets_message_when_stored() {
        let mut list = create_test_list();
        let outcome = list.get_or_subscribe("id1").expect("should succeed");
        let receiver = into_receiver(outcome, "should be waiting");

        list.store(String::from("id1"), make_message("arrived"), LONG_TTL)
            .unwrap();

        let delivered = receiver.await.expect("should receive message");
        assert_eq!(delivered.body, Bytes::from("arrived"));
    }

    #[tokio::test]
    async fn get_or_subscribe_ignores_expired_message() {
        let mut list = create_test_list();
        list.store(String::from("id1"), make_message("expired"), SHORT_TTL)
            .unwrap();

        sleep(EXPIRY_WAIT).await;

        let outcome = list.get_or_subscribe("id1").expect("should succeed");
        assert!(
            matches!(outcome, GetOrSubscribeResult::Waiting(_)),
            "should not return expired message"
        );
    }

    // ========== Cleanup behavior tests ==========

    #[tokio::test]
    async fn cleanup_does_not_remove_waiters_without_entry() {
        // Regression test: consumer arrives before producer, cleanup must not evict waiter
        let mut list = create_test_list();

        // The consumer subscribes while the repository has no entry for the id.
        let outcome = list
            .get_or_subscribe("consumer-first")
            .expect("should succeed");
        let receiver = into_receiver(outcome, "should be waiting");

        list.cleanup_expired();

        // The producer shows up afterwards; the waiter must still be registered.
        list.store(
            String::from("consumer-first"),
            make_message("delayed"),
            LONG_TTL,
        )
        .unwrap();

        let delivered = receiver
            .await
            .expect("waiter should not have been removed by cleanup");
        assert_eq!(delivered.body, Bytes::from("delayed"));
    }

    #[tokio::test]
    async fn cleanup_removes_waiters_for_expired_entries() {
        let mut list = create_test_list();
        list.store(String::from("will-expire"), make_message("test"), SHORT_TTL)
            .unwrap();
        let ack_receiver = list.subscribe_ack("will-expire").expect("should subscribe");

        sleep(EXPIRY_WAIT).await;

        // Expired entry is purged and its ACK waiter's sender goes with it.
        list.cleanup_expired();

        assert!(
            ack_receiver.await.is_err(),
            "waiter should be removed for expired entry"
        );
    }

    #[tokio::test]
    async fn cleanup_removes_closed_senders() {
        let mut list = create_test_list();

        // Dropping the receiver right away mimics a long-poll that timed out.
        let outcome = list
            .get_or_subscribe("dropped-receiver")
            .expect("should succeed");
        drop(into_receiver(outcome, "should be waiting"));

        assert!(list.waiters.contains_key("dropped-receiver"));

        list.cleanup_expired();

        // With no live senders left, the whole waiter slot disappears.
        assert!(!list.waiters.contains_key("dropped-receiver"));
    }
}