1use std::collections::HashMap;
7use std::time::Duration;
8
9use bytes::Bytes;
10use tokio::sync::oneshot;
11
12use super::persistence::EntryRepository;
13use super::unix_timestamp_millis;
14
/// Maximum number of ack waiters, and separately of message waiters, that may be
/// registered for a single entry id. Further subscriptions of that kind are rejected
/// with [`SubscribeError::WaiterLimitReached`].
const MAX_WAITERS_PER_ENTRY: usize = 10;
18
/// A stored message payload together with its optional content type.
#[derive(Clone, Debug)]
pub struct Message {
    /// Raw message body bytes.
    pub body: Bytes,
    /// Content type supplied with the message, if any (for example `text/plain`).
    pub content_type: Option<String>,
}
29
/// Pending subscribers registered for a single entry id.
struct Waiters {
    /// Senders completed when the entry is acked; dropped (closing the channel) when a new
    /// message is stored under the same id.
    ack_waiters: Vec<oneshot::Sender<()>>,
    /// Senders that each receive a copy of the next message stored under the id.
    message_waiters: Vec<oneshot::Sender<Message>>,
}
38
39impl Waiters {
40 fn new() -> Self {
41 Self {
42 ack_waiters: Vec::new(),
43 message_waiters: Vec::new(),
44 }
45 }
46
47 fn is_empty(&self) -> bool {
48 self.ack_waiters.is_empty() && self.message_waiters.is_empty()
49 }
50
51 fn add_ack_waiter(&mut self, sender: oneshot::Sender<()>) -> bool {
53 if self.ack_waiters.len() >= MAX_WAITERS_PER_ENTRY {
54 return false;
55 }
56 self.ack_waiters.push(sender);
57 true
58 }
59
60 fn add_message_waiter(&mut self, sender: oneshot::Sender<Message>) -> bool {
62 if self.message_waiters.len() >= MAX_WAITERS_PER_ENTRY {
63 return false;
64 }
65 self.message_waiters.push(sender);
66 true
67 }
68
69 fn notify_message_waiters(&mut self, message: &Message) {
71 for waiter in self.ack_waiters.drain(..) {
72 drop(waiter);
74 }
75 for waiter in self.message_waiters.drain(..) {
76 let _ = waiter.send(message.clone());
77 }
78 }
79
80 fn notify_ack_waiters(&mut self) {
82 for waiter in self.ack_waiters.drain(..) {
83 let _ = waiter.send(());
84 }
85 }
86}
87
/// Message store that lets callers wait for a message to arrive or to be acknowledged.
///
/// Entries (message body, content type, expiry timestamp, ack flag) live in `repository`.
/// In-process subscribers are tracked per entry id in `waiters`, including ids that have
/// no entry yet.
pub struct WaitingList {
    repository: EntryRepository,
    waiters: HashMap<String, Waiters>,
}
100
/// Reasons a subscription request can be refused.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SubscribeError {
    /// The id already has `MAX_WAITERS_PER_ENTRY` waiters of the requested kind.
    WaiterLimitReached,
    /// No present, unexpired entry exists for the id (returned by `subscribe_ack`).
    NotFound,
}
109
/// Outcome of [`WaitingList::get_or_subscribe`].
pub enum GetOrSubscribeResult {
    /// A live message was already stored and is returned immediately.
    Message(Message),
    /// No message was available. The receiver resolves when a message is stored under the
    /// id, or yields an error if the waiting list drops the sender without sending.
    Waiting(oneshot::Receiver<Message>),
}
117
118impl WaitingList {
119 pub fn new(repository: EntryRepository) -> Self {
121 Self {
122 repository,
123 waiters: HashMap::new(),
124 }
125 }
126
127 pub fn store(&mut self, id: String, message: Message, ttl: Duration) -> anyhow::Result<()> {
135 let expires_at = unix_timestamp_millis() + ttl.as_millis() as i64;
136
137 if let Some(waiters) = self.waiters.get_mut(&id) {
139 waiters.notify_message_waiters(&message);
140 if waiters.is_empty() {
141 self.waiters.remove(&id);
142 }
143 }
144
145 self.repository.insert(
147 &id,
148 &message.body,
149 message.content_type.as_deref(),
150 expires_at,
151 )
152 }
153
154 pub fn ack(&mut self, id: &str) -> bool {
158 let entry = match self.repository.get(id) {
160 Ok(Some(entry)) if !Self::is_expired(entry.expires_at) => entry,
161 _ => return false,
162 };
163
164 if entry.acked {
166 return true;
167 }
168
169 if let Err(e) = self.repository.ack(id) {
171 tracing::error!(?e, id, "Failed to ack entry in repository");
172 return false;
173 }
174
175 if let Some(waiters) = self.waiters.get_mut(id) {
177 waiters.notify_ack_waiters();
178 if waiters.is_empty() {
179 self.waiters.remove(id);
180 }
181 }
182
183 true
184 }
185
186 pub fn is_acked(&self, id: &str) -> Option<bool> {
190 match self.repository.get(id) {
191 Ok(Some(entry)) if !Self::is_expired(entry.expires_at) => Some(entry.acked),
192 _ => None,
193 }
194 }
195
196 pub fn subscribe_ack(&mut self, id: &str) -> Result<oneshot::Receiver<()>, SubscribeError> {
201 let entry = match self.repository.get(id) {
203 Ok(Some(entry)) if !Self::is_expired(entry.expires_at) => entry,
204 _ => return Err(SubscribeError::NotFound),
205 };
206
207 if entry.acked {
209 let (tx, rx) = oneshot::channel();
210 let _ = tx.send(());
211 return Ok(rx);
212 }
213
214 let (tx, rx) = oneshot::channel();
215 let waiters = self
216 .waiters
217 .entry(id.to_string())
218 .or_insert_with(Waiters::new);
219 if waiters.add_ack_waiter(tx) {
220 Ok(rx)
221 } else {
222 Err(SubscribeError::WaiterLimitReached)
223 }
224 }
225
226 pub fn get_or_subscribe(&mut self, id: &str) -> Result<GetOrSubscribeResult, SubscribeError> {
234 if let Ok(Some(entry)) = self.repository.get(id) {
236 if !Self::is_expired(entry.expires_at) {
237 if let Some(body) = entry.message_body {
238 return Ok(GetOrSubscribeResult::Message(Message {
239 body: Bytes::from(body),
240 content_type: entry.content_type,
241 }));
242 }
243 }
244 }
245
246 let (tx, rx) = oneshot::channel();
248 let waiters = self
249 .waiters
250 .entry(id.to_string())
251 .or_insert_with(Waiters::new);
252
253 if waiters.add_message_waiter(tx) {
254 Ok(GetOrSubscribeResult::Waiting(rx))
255 } else {
256 Err(SubscribeError::WaiterLimitReached)
257 }
258 }
259
260 pub fn cleanup_expired(&mut self) -> usize {
262 let expired_keys: Vec<String> = self
265 .waiters
266 .keys()
267 .filter(|id| {
268 match self.repository.get(id) {
269 Ok(Some(entry)) => Self::is_expired(entry.expires_at),
270 _ => false, }
272 })
273 .cloned()
274 .collect();
275
276 let count = match self.repository.cleanup_expired() {
278 Ok(c) => c,
279 Err(e) => {
280 tracing::error!(?e, "Failed to cleanup expired entries");
281 0
282 }
283 };
284
285 for waiters in self.waiters.values_mut() {
287 waiters.ack_waiters.retain(|s| !s.is_closed());
288 waiters.message_waiters.retain(|s| !s.is_closed());
289 }
290
291 self.waiters.retain(|_, w| !w.is_empty());
293
294 for key in expired_keys {
296 self.waiters.remove(&key);
297 }
298
299 count
300 }
301
302 fn is_expired(expires_at: i64) -> bool {
304 unix_timestamp_millis() >= expires_at
305 }
306}
307
#[cfg(test)]
impl WaitingList {
    /// Builds a list over a repository created with no location argument (`None`),
    /// presumably in-memory storage, capped at `max_entries` entries.
    ///
    /// # Panics
    ///
    /// Panics if the repository cannot be created.
    pub fn new_in_memory(max_entries: usize) -> Self {
        Self::new(
            EntryRepository::new(None, max_entries)
                .expect("Failed to create in-memory repository"),
        )
    }

    /// Number of entries currently held by the repository; 0 if counting fails.
    pub fn len(&self) -> usize {
        match self.repository.count() {
            Ok(total) => total,
            Err(_) => 0,
        }
    }

    /// Returns `true` when [`WaitingList::len`] reports no entries.
    pub fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
327
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    /// Builds a `text/plain` message with the given body.
    fn make_message(body: &str) -> Message {
        Message {
            body: Bytes::from(body.to_string()),
            content_type: Some("text/plain".to_string()),
        }
    }

    /// Creates a list over an in-memory repository capped at 100 entries.
    fn create_test_list() -> WaitingList {
        WaitingList::new_in_memory(100)
    }

    // Only the entry whose short TTL has elapsed is purged from the repository.
    #[tokio::test]
    async fn cleanup_expired_removes_expired_entries() {
        let mut list = create_test_list();
        let short_ttl = Duration::from_millis(10);
        let long_ttl = Duration::from_secs(60);

        list.store("expires-soon".to_string(), make_message("a"), short_ttl)
            .unwrap();
        list.store("stays-alive".to_string(), make_message("b"), long_ttl)
            .unwrap();

        assert_eq!(list.len(), 2);

        // Sleep past the 10 ms TTL so only "expires-soon" is past its deadline.
        sleep(Duration::from_millis(20)).await;

        let removed = list.cleanup_expired();

        assert_eq!(removed, 1);
        assert_eq!(list.len(), 1);
    }

    // With nothing past its deadline, cleanup removes nothing.
    #[tokio::test]
    async fn cleanup_expired_returns_zero_when_nothing_expired() {
        let mut list = create_test_list();
        let ttl = Duration::from_secs(60);

        list.store("a".to_string(), make_message("a"), ttl).unwrap();
        list.store("b".to_string(), make_message("b"), ttl).unwrap();

        let removed = list.cleanup_expired();

        assert_eq!(removed, 0);
        assert_eq!(list.len(), 2);
    }

    // A message waiter registered before the store receives the stored message.
    // NOTE(review): despite the name, no earlier entry exists for "id1" here, so this
    // exercises the first store rather than an overwrite.
    #[tokio::test]
    async fn store_notifies_message_waiters_on_overwrite() {
        let mut list = create_test_list();
        let ttl = Duration::from_secs(60);

        let result = list.get_or_subscribe("id1").expect("should succeed");
        let rx = match result {
            GetOrSubscribeResult::Waiting(rx) => rx,
            GetOrSubscribeResult::Message(_) => panic!("expected waiting"),
        };

        list.store("id1".to_string(), make_message("overwrite"), ttl)
            .unwrap();

        let received = rx.await.expect("should receive message");
        assert_eq!(received.body, Bytes::from("overwrite"));
    }

    // Storing a new message under an id drops its pending ack waiters, so their
    // receivers resolve with an error instead of an ack.
    #[tokio::test]
    async fn store_drops_ack_waiters_on_overwrite() {
        let mut list = create_test_list();
        let ttl = Duration::from_secs(60);

        list.store("id1".to_string(), make_message("first"), ttl)
            .unwrap();
        let ack_rx = list.subscribe_ack("id1").expect("should subscribe");

        list.store("id1".to_string(), make_message("second"), ttl)
            .unwrap();

        let result = ack_rx.await;
        assert!(result.is_err(), "old ack waiter should be dropped");
    }

    // The (MAX_WAITERS_PER_ENTRY + 1)-th ack subscription is rejected. Each receiver is
    // dropped at the end of its loop iteration, but the closed senders still occupy slots
    // until `cleanup_expired` prunes them.
    #[test]
    fn subscribe_ack_returns_limit_error() {
        let mut list = create_test_list();
        let ttl = Duration::from_secs(60);

        list.store("id1".to_string(), make_message("test"), ttl)
            .unwrap();

        for _ in 0..MAX_WAITERS_PER_ENTRY {
            let result = list.subscribe_ack("id1");
            assert!(result.is_ok());
        }

        let result = list.subscribe_ack("id1");
        assert!(
            matches!(result, Err(SubscribeError::WaiterLimitReached)),
            "expected WaiterLimitReached error"
        );
    }

    // Same limit for message waiters on an id that has no entry yet.
    #[test]
    fn get_or_subscribe_returns_limit_error() {
        let mut list = create_test_list();

        for _ in 0..MAX_WAITERS_PER_ENTRY {
            let result = list.get_or_subscribe("id1");
            assert!(result.is_ok());
        }

        let result = list.get_or_subscribe("id1");
        assert!(
            matches!(result, Err(SubscribeError::WaiterLimitReached)),
            "expected WaiterLimitReached error"
        );
    }

    // A freshly stored entry reports as present but not acked.
    #[test]
    fn is_acked_false_before_ack() {
        let mut list = create_test_list();
        let ttl = Duration::from_secs(60);

        list.store("id1".to_string(), make_message("test"), ttl)
            .unwrap();

        assert_eq!(list.is_acked("id1"), Some(false));
    }

    // Acking a live entry succeeds and is reflected by `is_acked`.
    #[test]
    fn is_acked_true_after_ack() {
        let mut list = create_test_list();
        let ttl = Duration::from_secs(60);

        list.store("id1".to_string(), make_message("test"), ttl)
            .unwrap();
        let ack_result = list.ack("id1");

        assert!(ack_result, "ack should succeed");
        assert_eq!(list.is_acked("id1"), Some(true));
    }

    // An ack waiter is signalled when the entry is acked.
    #[tokio::test]
    async fn ack_notifies_waiters() {
        let mut list = create_test_list();
        let ttl = Duration::from_secs(60);

        list.store("id1".to_string(), make_message("test"), ttl)
            .unwrap();
        let rx = list.subscribe_ack("id1").expect("should subscribe");

        list.ack("id1");

        let result = rx.await;
        assert!(result.is_ok(), "ack waiter should receive notification");
    }

    // Expired entries can neither be acked nor queried, even before cleanup runs.
    #[tokio::test]
    async fn ack_fails_for_expired_entry() {
        let mut list = create_test_list();
        let short_ttl = Duration::from_millis(10);

        list.store("id1".to_string(), make_message("test"), short_ttl)
            .unwrap();

        sleep(Duration::from_millis(20)).await;

        assert!(!list.ack("id1"), "ack should fail for expired entry");
        assert_eq!(
            list.is_acked("id1"),
            None,
            "is_acked should be None for expired"
        );
    }

    // Unknown ids can neither be acked nor queried.
    #[test]
    fn ack_fails_for_nonexistent_entry() {
        let mut list = create_test_list();

        assert!(!list.ack("nonexistent"));
        assert_eq!(list.is_acked("nonexistent"), None);
    }

    // A live stored message is returned directly instead of registering a waiter.
    #[test]
    fn get_or_subscribe_returns_existing_message() {
        let mut list = create_test_list();
        let ttl = Duration::from_secs(60);

        list.store("id1".to_string(), make_message("existing"), ttl)
            .unwrap();

        let result = list.get_or_subscribe("id1").expect("should succeed");

        match result {
            GetOrSubscribeResult::Message(msg) => {
                assert_eq!(msg.body, Bytes::from("existing"));
            }
            GetOrSubscribeResult::Waiting(_) => {
                panic!("should return message, not waiting");
            }
        }
    }

    // With no entry, a waiter slot is created for the id.
    #[test]
    fn get_or_subscribe_returns_receiver_when_no_entry() {
        let mut list = create_test_list();

        let result = list.get_or_subscribe("id1").expect("should succeed");

        match result {
            GetOrSubscribeResult::Message(_) => {
                panic!("should return waiting, not message");
            }
            GetOrSubscribeResult::Waiting(_) => {
                assert!(list.waiters.contains_key("id1"));
            }
        }
    }

    // A waiter registered before any entry exists receives the message once stored.
    #[tokio::test]
    async fn get_or_subscribe_receiver_gets_message_when_stored() {
        let mut list = create_test_list();
        let ttl = Duration::from_secs(60);

        let result = list.get_or_subscribe("id1").expect("should succeed");
        let rx = match result {
            GetOrSubscribeResult::Waiting(rx) => rx,
            GetOrSubscribeResult::Message(_) => panic!("should be waiting"),
        };

        list.store("id1".to_string(), make_message("arrived"), ttl)
            .unwrap();

        let msg = rx.await.expect("should receive message");
        assert_eq!(msg.body, Bytes::from("arrived"));
    }

    // An expired message is treated as absent: the caller is given a waiter instead.
    #[tokio::test]
    async fn get_or_subscribe_ignores_expired_message() {
        let mut list = create_test_list();
        let short_ttl = Duration::from_millis(10);

        list.store("id1".to_string(), make_message("expired"), short_ttl)
            .unwrap();

        sleep(Duration::from_millis(20)).await;

        let result = list.get_or_subscribe("id1").expect("should succeed");

        match result {
            GetOrSubscribeResult::Message(_) => {
                panic!("should not return expired message");
            }
            GetOrSubscribeResult::Waiting(_) => {
                // Expected: no live message, so a waiter was registered.
            }
        }
    }

    // Cleanup must keep waiters for ids that have no repository entry yet
    // (consumer-first flow), so the later store still reaches them.
    #[tokio::test]
    async fn cleanup_does_not_remove_waiters_without_entry() {
        let mut list = create_test_list();
        let ttl = Duration::from_secs(60);

        let result = list
            .get_or_subscribe("consumer-first")
            .expect("should succeed");
        let rx = match result {
            GetOrSubscribeResult::Waiting(rx) => rx,
            GetOrSubscribeResult::Message(_) => panic!("should be waiting"),
        };

        list.cleanup_expired();

        list.store("consumer-first".to_string(), make_message("delayed"), ttl)
            .unwrap();

        let msg = rx
            .await
            .expect("waiter should not have been removed by cleanup");
        assert_eq!(msg.body, Bytes::from("delayed"));
    }

    // Ack waiters on an entry that expired are dropped by cleanup, so their receivers
    // resolve with an error.
    #[tokio::test]
    async fn cleanup_removes_waiters_for_expired_entries() {
        let mut list = create_test_list();
        let short_ttl = Duration::from_millis(10);

        list.store("will-expire".to_string(), make_message("test"), short_ttl)
            .unwrap();
        let ack_rx = list.subscribe_ack("will-expire").expect("should subscribe");

        sleep(Duration::from_millis(20)).await;

        list.cleanup_expired();

        assert!(
            ack_rx.await.is_err(),
            "waiter should be removed for expired entry"
        );
    }

    // Waiters whose receiver was dropped are pruned, and the emptied slot is removed.
    #[tokio::test]
    async fn cleanup_removes_closed_senders() {
        let mut list = create_test_list();

        let result = list
            .get_or_subscribe("dropped-receiver")
            .expect("should succeed");
        match result {
            GetOrSubscribeResult::Waiting(rx) => drop(rx),
            GetOrSubscribeResult::Message(_) => panic!("should be waiting"),
        };

        assert!(list.waiters.contains_key("dropped-receiver"));

        list.cleanup_expired();

        assert!(!list.waiters.contains_key("dropped-receiver"));
    }
}