1use std::future::Future;
44use std::sync::Arc;
45use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
46use std::time::Duration;
47
48use tokio::sync::{mpsc, oneshot};
49use tokio::task::JoinHandle;
50use tokio::time::{MissedTickBehavior, interval};
51use tokio_util::sync::CancellationToken;
52use tracing::warn;
53
54use super::error::{DrainError, SinkError};
55
56#[derive(Debug, Clone)]
58pub struct BackgroundSinkConfig {
59 pub queue_capacity: usize,
62
63 pub batch_size: usize,
66
67 pub flush_interval: Duration,
70
71 pub overflow: Overflow,
74
75 pub metric_prefix: Option<&'static str>,
87}
88
89impl Default for BackgroundSinkConfig {
90 fn default() -> Self {
91 Self {
92 queue_capacity: 10_000,
93 batch_size: 256,
94 flush_interval: Duration::from_millis(100),
95 overflow: Overflow::Drop,
96 metric_prefix: None,
97 }
98 }
99}
100
/// Policy applied by [`BackgroundSink::try_push`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Overflow {
    /// `try_push` enqueues without waiting. When the queue is full the message
    /// is discarded, the dropped counter is bumped, and `SinkError::Overflow`
    /// is returned.
    Drop,
    /// `try_push` is disabled and always returns `SinkError::Overflow`.
    /// Producers are expected to use `push_blocking`, which waits for queue space.
    Block,
}
115
116enum SinkMsg<T> {
123 Data(T),
124 Barrier(oneshot::Sender<Result<(), DrainError>>),
125}
126
127pub trait SinkDrain<T: Send>: Send + 'static {
141 fn write_batch(&mut self, batch: Vec<T>)
150 -> impl Future<Output = Result<(), DrainError>> + Send;
151
152 fn flush_durable(&mut self) -> impl Future<Output = Result<(), DrainError>> + Send {
163 std::future::ready(Ok(()))
164 }
165
166 fn close(&mut self) -> impl Future<Output = Result<(), DrainError>> + Send {
170 std::future::ready(Ok(()))
171 }
172}
173
174#[derive(Debug, Clone)]
177pub struct BackgroundSink<T: Send + 'static> {
178 tx: mpsc::Sender<SinkMsg<T>>,
179 dropped: Arc<AtomicU64>,
180 pending: Arc<AtomicUsize>,
181 overflow: Overflow,
182 metric_prefix: Option<&'static str>,
183}
184
185pub struct BackgroundSinkHandle {
191 join: JoinHandle<()>,
192}
193
194impl BackgroundSinkHandle {
195 pub async fn join(self) -> Result<(), tokio::task::JoinError> {
200 self.join.await
201 }
202}
203
204impl<T: Send + 'static> BackgroundSink<T> {
205 pub fn spawn<D: SinkDrain<T>>(
208 drain: D,
209 config: BackgroundSinkConfig,
210 shutdown: CancellationToken,
211 ) -> (Self, BackgroundSinkHandle) {
212 let (tx, rx) = mpsc::channel(config.queue_capacity);
213 let dropped = Arc::new(AtomicU64::new(0));
214 let pending = Arc::new(AtomicUsize::new(0));
215 let metric_prefix = config.metric_prefix;
216 let overflow = config.overflow;
217
218 if let Some(prefix) = metric_prefix {
219 metrics::describe_counter!(
220 format!("{prefix}_pushed_total"),
221 "Messages successfully enqueued by the background sink"
222 );
223 metrics::describe_counter!(
224 format!("{prefix}_dropped_total"),
225 "Messages dropped due to queue overflow"
226 );
227 metrics::describe_counter!(
228 format!("{prefix}_writes_total"),
229 "Batch writes attempted by the drain"
230 );
231 metrics::describe_counter!(
232 format!("{prefix}_write_errors_total"),
233 "Batch writes that returned an error"
234 );
235 metrics::describe_gauge!(
236 format!("{prefix}_pending"),
237 "Current background sink queue depth"
238 );
239 }
240
241 let actor_pending = Arc::clone(&pending);
242 let join = tokio::spawn(actor_loop(
243 rx,
244 drain,
245 config,
246 shutdown,
247 actor_pending,
248 metric_prefix,
249 ));
250
251 (
252 Self {
253 tx,
254 dropped,
255 pending,
256 overflow,
257 metric_prefix,
258 },
259 BackgroundSinkHandle { join },
260 )
261 }
262
263 pub fn try_push(&self, msg: T) -> Result<(), SinkError> {
272 match self.overflow {
273 Overflow::Drop => {
274 self.pending.fetch_add(1, Ordering::Relaxed);
281 match self.tx.try_send(SinkMsg::Data(msg)) {
282 Ok(()) => {
283 if let Some(p) = self.metric_prefix {
284 metrics::counter!(format!("{p}_pushed_total")).increment(1);
285 }
286 Ok(())
287 }
288 Err(mpsc::error::TrySendError::Full(_)) => {
289 self.pending.fetch_sub(1, Ordering::Relaxed);
291 self.dropped.fetch_add(1, Ordering::Relaxed);
292 if let Some(p) = self.metric_prefix {
293 metrics::counter!(format!("{p}_dropped_total")).increment(1);
294 }
295 Err(SinkError::Overflow)
296 }
297 Err(mpsc::error::TrySendError::Closed(_)) => {
298 self.pending.fetch_sub(1, Ordering::Relaxed);
299 Err(SinkError::Closed)
300 }
301 }
302 }
303 Overflow::Block => Err(SinkError::Overflow),
304 }
305 }
306
307 pub async fn push_blocking(&self, msg: T) -> Result<(), SinkError> {
310 self.pending.fetch_add(1, Ordering::Relaxed);
314 if self.tx.send(SinkMsg::Data(msg)).await.is_err() {
315 self.pending.fetch_sub(1, Ordering::Relaxed);
316 return Err(SinkError::Closed);
317 }
318 if let Some(p) = self.metric_prefix {
319 metrics::counter!(format!("{p}_pushed_total")).increment(1);
320 }
321 Ok(())
322 }
323
324 pub async fn flush(&self) -> Result<(), SinkError> {
333 let (ack_tx, ack_rx) = oneshot::channel();
334 self.tx
335 .send(SinkMsg::Barrier(ack_tx))
336 .await
337 .map_err(|_| SinkError::Closed)?;
338 ack_rx
341 .await
342 .map_err(|_| SinkError::Closed)?
343 .map_err(SinkError::Drain)
344 }
345
346 #[must_use]
348 pub fn dropped(&self) -> u64 {
349 self.dropped.load(Ordering::Relaxed)
350 }
351
352 #[must_use]
354 pub fn pending(&self) -> usize {
355 self.pending.load(Ordering::Relaxed)
356 }
357}
358
359async fn actor_loop<T, D>(
360 mut rx: mpsc::Receiver<SinkMsg<T>>,
361 mut drain: D,
362 config: BackgroundSinkConfig,
363 shutdown: CancellationToken,
364 pending: Arc<AtomicUsize>,
365 metric_prefix: Option<&'static str>,
366) where
367 T: Send + 'static,
368 D: SinkDrain<T>,
369{
370 let mut batch: Vec<T> = Vec::with_capacity(config.batch_size);
371 let mut tick = interval(config.flush_interval);
372 tick.set_missed_tick_behavior(MissedTickBehavior::Delay);
373 tick.tick().await;
375
376 loop {
377 tokio::select! {
378 biased;
379
380 () = shutdown.cancelled() => {
381 rx.close();
385
386 if !batch.is_empty() {
390 let _ = write_batch_with_metrics(
391 &mut drain, std::mem::take(&mut batch),
392 &pending, metric_prefix,
393 ).await;
394 }
395 while let Ok(msg) = rx.try_recv() {
396 match msg {
397 SinkMsg::Data(t) => {
398 batch.push(t);
399 if batch.len() >= config.batch_size {
400 let _ = write_batch_with_metrics(
401 &mut drain, std::mem::take(&mut batch),
402 &pending, metric_prefix,
403 ).await;
404 }
405 }
406 SinkMsg::Barrier(ack) => {
407 let result = barrier_drain(
411 &mut drain, std::mem::take(&mut batch),
412 &pending, metric_prefix,
413 ).await;
414 let _ = ack.send(result);
415 }
416 }
417 }
418 if !batch.is_empty() {
419 let _ = write_batch_with_metrics(
420 &mut drain, std::mem::take(&mut batch),
421 &pending, metric_prefix,
422 ).await;
423 }
424 if let Err(e) = drain.close().await {
425 warn!(error = %e, "sink drain close failed");
426 }
427 return;
428 }
429
430 msg = rx.recv() => match msg {
431 Some(SinkMsg::Data(t)) => {
432 batch.push(t);
433 if batch.len() >= config.batch_size {
434 let _ = write_batch_with_metrics(
435 &mut drain, std::mem::take(&mut batch),
436 &pending, metric_prefix,
437 ).await;
438 }
439 }
440 Some(SinkMsg::Barrier(ack)) => {
441 let result = barrier_drain(
442 &mut drain, std::mem::take(&mut batch),
443 &pending, metric_prefix,
444 ).await;
445 let _ = ack.send(result);
446 }
447 None => {
448 if !batch.is_empty() {
449 let _ = write_batch_with_metrics(
450 &mut drain, std::mem::take(&mut batch),
451 &pending, metric_prefix,
452 ).await;
453 }
454 if let Err(e) = drain.close().await {
455 warn!(error = %e, "sink drain close failed");
456 }
457 return;
458 }
459 },
460
461 _ = tick.tick() => {
462 if !batch.is_empty() {
463 let _ = write_batch_with_metrics(
464 &mut drain, std::mem::take(&mut batch),
465 &pending, metric_prefix,
466 ).await;
467 }
468 }
469 }
470 }
471}
472
473async fn barrier_drain<T, D: SinkDrain<T>>(
477 drain: &mut D,
478 batch: Vec<T>,
479 pending: &AtomicUsize,
480 metric_prefix: Option<&'static str>,
481) -> Result<(), DrainError>
482where
483 T: Send,
484{
485 let write_result = if batch.is_empty() {
486 Ok(())
487 } else {
488 write_batch_with_metrics(drain, batch, pending, metric_prefix).await
489 };
490 let durable_result = drain.flush_durable().await;
491 write_result.and(durable_result)
492}
493
494async fn write_batch_with_metrics<T, D: SinkDrain<T>>(
495 drain: &mut D,
496 batch: Vec<T>,
497 pending: &AtomicUsize,
498 metric_prefix: Option<&'static str>,
499) -> Result<(), DrainError>
500where
501 T: Send,
502{
503 let count = batch.len();
504 pending.fetch_sub(count, Ordering::Relaxed);
505 if let Some(p) = metric_prefix {
506 metrics::counter!(format!("{p}_writes_total")).increment(1);
507 metrics::gauge!(format!("{p}_pending")).set(pending.load(Ordering::Relaxed) as f64);
508 }
509 match drain.write_batch(batch).await {
510 Ok(()) => Ok(()),
511 Err(e) => {
512 warn!(error = %e, count, "sink drain write_batch failed");
513 if let Some(p) = metric_prefix {
514 metrics::counter!(format!("{p}_write_errors_total")).increment(1);
515 }
516 Err(e)
517 }
518 }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::Instant;

    use tokio::sync::Notify;
    use tokio_util::sync::CancellationToken;

    /// Test drain that records how many items it has been handed and always succeeds.
    struct CountingDrain {
        count: Arc<AtomicU64>,
    }

    impl SinkDrain<u32> for CountingDrain {
        async fn write_batch(&mut self, batch: Vec<u32>) -> Result<(), DrainError> {
            self.count.fetch_add(batch.len() as u64, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Test drain whose writes each wait for a `release` notification before
    /// counting the batch, so a write can be held open indefinitely.
    struct ThrottledDrain {
        release: Arc<Notify>,
        count: Arc<AtomicU64>,
    }

    impl SinkDrain<u32> for ThrottledDrain {
        async fn write_batch(&mut self, batch: Vec<u32>) -> Result<(), DrainError> {
            self.release.notified().await;
            self.count.fetch_add(batch.len() as u64, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Small queue/batch sizes and a 20 ms timer so the tests run quickly.
    fn fast_config() -> BackgroundSinkConfig {
        BackgroundSinkConfig {
            queue_capacity: 1024,
            batch_size: 16,
            flush_interval: Duration::from_millis(20),
            overflow: Overflow::Drop,
            metric_prefix: None,
        }
    }

    // Pushes below capacity succeed, and `flush` makes all of them visible to the drain.
    #[tokio::test]
    async fn try_push_succeeds_when_queue_has_space() {
        let count = Arc::new(AtomicU64::new(0));
        let shutdown = CancellationToken::new();
        let (sink, _handle) = BackgroundSink::spawn(
            CountingDrain {
                count: count.clone(),
            },
            fast_config(),
            shutdown.clone(),
        );

        for i in 0..10 {
            sink.try_push(i).expect("queue has space");
        }
        sink.flush().await.expect("flush ok");
        assert_eq!(count.load(Ordering::SeqCst), 10);
        shutdown.cancel();
    }

    // With a 4-slot queue, a throttled drain and a 1-minute timer, some of 20
    // back-to-back pushes must overflow. Every overflow must be reflected in `dropped()`.
    #[tokio::test]
    async fn try_push_returns_overflow_when_full() {
        let count = Arc::new(AtomicU64::new(0));
        let release = Arc::new(Notify::new());
        let shutdown = CancellationToken::new();
        let cfg = BackgroundSinkConfig {
            queue_capacity: 4,
            batch_size: 16,
            flush_interval: Duration::from_mins(1),
            overflow: Overflow::Drop,
            metric_prefix: None,
        };
        let (sink, _handle) = BackgroundSink::spawn(
            ThrottledDrain {
                release: release.clone(),
                count: count.clone(),
            },
            cfg,
            shutdown.clone(),
        );

        let mut accepted: u64 = 0;
        let mut overflowed: u64 = 0;
        for i in 0..20 {
            match sink.try_push(i) {
                Ok(()) => accepted += 1,
                Err(SinkError::Overflow) => overflowed += 1,
                Err(e) => panic!("unexpected error: {e}"),
            }
        }
        assert!(overflowed > 0, "expected at least one overflow");
        assert_eq!(sink.dropped(), overflowed);
        assert!(accepted >= 4, "should accept at least queue_capacity");
        let _ = (accepted, count);
        shutdown.cancel();
        // `notify_waiters` only wakes writes that are already waiting. The handle
        // is not joined, so the test does not depend on the actor actually exiting.
        release.notify_waiters();
    }

    // In `Overflow::Block` mode, `try_push` is rejected outright, while
    // `push_blocking` still delivers.
    #[tokio::test]
    async fn try_push_in_block_mode_always_errors() {
        let count = Arc::new(AtomicU64::new(0));
        let shutdown = CancellationToken::new();
        let cfg = BackgroundSinkConfig {
            overflow: Overflow::Block,
            ..fast_config()
        };
        let (sink, _handle) = BackgroundSink::spawn(
            CountingDrain {
                count: count.clone(),
            },
            cfg,
            shutdown.clone(),
        );
        match sink.try_push(1) {
            Err(SinkError::Overflow) => {}
            other => panic!("expected Overflow, got {other:?}"),
        }
        sink.push_blocking(1)
            .await
            .expect("push_blocking ok in Block mode");
        sink.flush().await.expect("flush ok");
        assert_eq!(count.load(Ordering::SeqCst), 1);
        shutdown.cancel();
    }

    // 100 pushes span several full 16-item batches plus a partial one. Once
    // `flush` returns, all of them must have reached the drain.
    #[tokio::test]
    async fn flush_waits_for_pre_flush_messages() {
        let count = Arc::new(AtomicU64::new(0));
        let shutdown = CancellationToken::new();
        let (sink, _handle) = BackgroundSink::spawn(
            CountingDrain {
                count: count.clone(),
            },
            fast_config(),
            shutdown.clone(),
        );

        for i in 0..100 {
            sink.try_push(i).expect("queue has space");
        }
        sink.flush().await.expect("flush ok");
        assert_eq!(count.load(Ordering::SeqCst), 100);
        shutdown.cancel();
    }

    // Cancelling the token must write everything already queued before the actor exits.
    #[tokio::test]
    async fn shutdown_drains_remaining_queue() {
        let count = Arc::new(AtomicU64::new(0));
        let shutdown = CancellationToken::new();
        let (sink, handle) = BackgroundSink::spawn(
            CountingDrain {
                count: count.clone(),
            },
            fast_config(),
            shutdown.clone(),
        );
        for i in 0..50 {
            sink.try_push(i).expect("queue has space");
        }
        shutdown.cancel();
        handle.join().await.expect("clean exit");
        assert_eq!(count.load(Ordering::SeqCst), 50);
    }

    // With only 2 queue slots, most of 100 back-to-back pushes must be dropped.
    // The >= 90 bound leaves slack for messages the actor may pull off the queue
    // in the meantime.
    #[tokio::test]
    async fn dropped_counter_reflects_overflow_count() {
        let count = Arc::new(AtomicU64::new(0));
        let release = Arc::new(Notify::new());
        let shutdown = CancellationToken::new();
        let cfg = BackgroundSinkConfig {
            queue_capacity: 2,
            batch_size: 16,
            flush_interval: Duration::from_mins(1),
            overflow: Overflow::Drop,
            metric_prefix: None,
        };
        let (sink, _handle) = BackgroundSink::spawn(
            ThrottledDrain {
                release: release.clone(),
                count: count.clone(),
            },
            cfg,
            shutdown.clone(),
        );
        for i in 0..100 {
            let _ = sink.try_push(i);
        }
        assert!(sink.dropped() >= 90, "dropped={}", sink.dropped());
        shutdown.cancel();
        release.notify_waiters();
    }

    // Average `try_push` latency must stay small even while the drain's writes are
    // blocked. NOTE(review): this is a wall-clock bound and may be sensitive on
    // overloaded CI hosts.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn try_push_stays_fast_under_load() {
        let count = Arc::new(AtomicU64::new(0));
        let release = Arc::new(Notify::new());
        let shutdown = CancellationToken::new();
        let cfg = BackgroundSinkConfig {
            queue_capacity: 100_000,
            batch_size: 256,
            flush_interval: Duration::from_mins(1),
            overflow: Overflow::Drop,
            metric_prefix: None,
        };
        let (sink, _handle) = BackgroundSink::spawn(
            ThrottledDrain {
                release: release.clone(),
                count: count.clone(),
            },
            cfg,
            shutdown.clone(),
        );

        let start = Instant::now();
        // 10 000 pushes stay well under the 100 000-slot capacity.
        for i in 0..10_000_u32 {
            sink.try_push(i).expect("queue has space");
        }
        let elapsed = start.elapsed();
        let avg_us = elapsed.as_micros() as f64 / 10_000.0;
        assert!(
            avg_us < 50.0,
            "try_push p_avg = {avg_us}µs (expected <50µs under load with throttled drain)",
        );
        shutdown.cancel();
        release.notify_waiters();
    }

    // Eight producer tasks push 1 000 items each concurrently. All 8 000 must
    // reach the drain by the time `flush` returns.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn many_concurrent_producers_dont_block_each_other() {
        let count = Arc::new(AtomicU64::new(0));
        let shutdown = CancellationToken::new();
        let cfg = BackgroundSinkConfig {
            queue_capacity: 100_000,
            batch_size: 1024,
            flush_interval: Duration::from_millis(5),
            overflow: Overflow::Drop,
            metric_prefix: None,
        };
        let (sink, _handle) = BackgroundSink::spawn(
            CountingDrain {
                count: count.clone(),
            },
            cfg,
            shutdown.clone(),
        );

        let mut handles = Vec::new();
        for _ in 0..8 {
            let s = sink.clone();
            handles.push(tokio::spawn(async move {
                for i in 0..1_000_u32 {
                    s.try_push(i).expect("queue has space");
                }
            }));
        }
        for h in handles {
            h.await.expect("producer exit");
        }
        sink.flush().await.expect("flush ok");
        assert_eq!(count.load(Ordering::SeqCst), 8_000);
        shutdown.cancel();
    }

    // A barrier on an idle sink should be answered well within 50 ms.
    #[tokio::test]
    async fn flush_completes_quickly_when_queue_is_already_empty() {
        let count = Arc::new(AtomicU64::new(0));
        let shutdown = CancellationToken::new();
        let (sink, _handle) = BackgroundSink::spawn(
            CountingDrain {
                count: count.clone(),
            },
            fast_config(),
            shutdown.clone(),
        );
        let start = Instant::now();
        sink.flush().await.expect("flush ok");
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_millis(50),
            "empty flush took {elapsed:?} (expected <50ms)",
        );
        shutdown.cancel();
    }

    // Each write sleeps 50 ms. `try_push` must still stay non-blocking (worst
    // case < 5 ms) because the writes run on the actor task, not the producer.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn slow_drain_doesnt_block_consumer() {
        struct SlowDrain {
            count: Arc<AtomicU64>,
        }
        impl SinkDrain<u32> for SlowDrain {
            async fn write_batch(&mut self, batch: Vec<u32>) -> Result<(), DrainError> {
                tokio::time::sleep(Duration::from_millis(50)).await;
                self.count.fetch_add(batch.len() as u64, Ordering::SeqCst);
                Ok(())
            }
        }

        let count = Arc::new(AtomicU64::new(0));
        let shutdown = CancellationToken::new();
        let cfg = BackgroundSinkConfig {
            queue_capacity: 10_000,
            batch_size: 16,
            flush_interval: Duration::from_millis(5),
            overflow: Overflow::Drop,
            metric_prefix: None,
        };
        let (sink, _handle) = BackgroundSink::spawn(
            SlowDrain {
                count: count.clone(),
            },
            cfg,
            shutdown.clone(),
        );

        let mut max_us: u128 = 0;
        for i in 0..200_u32 {
            let t0 = Instant::now();
            sink.try_push(i).expect("queue has space");
            let elapsed_us = t0.elapsed().as_micros();
            if elapsed_us > max_us {
                max_us = elapsed_us;
            }
        }
        assert!(
            max_us < 5_000,
            "max try_push latency was {max_us}µs -- slow drain leaked back to consumer",
        );
        shutdown.cancel();
    }

    // A failing `write_batch` for the batch written on behalf of the barrier is
    // reported to `flush` as `SinkError::Drain`. This relies on the pushed item
    // still being in the actor's partial batch (batch_size is 16) when the
    // barrier is processed.
    #[tokio::test]
    async fn flush_surfaces_drain_write_failure() {
        struct FailingDrain;
        impl SinkDrain<u32> for FailingDrain {
            async fn write_batch(&mut self, _batch: Vec<u32>) -> Result<(), DrainError> {
                Err(DrainError::Io(std::io::Error::other("simulated")))
            }
        }
        let shutdown = CancellationToken::new();
        let (sink, _handle) = BackgroundSink::spawn(FailingDrain, fast_config(), shutdown.clone());
        sink.try_push(1).unwrap();
        let err = sink.flush().await.unwrap_err();
        assert!(matches!(err, SinkError::Drain(_)), "got: {err:?}");
        shutdown.cancel();
    }

    // A `flush_durable` error is reported to `flush` even when the write itself succeeded.
    #[tokio::test]
    async fn flush_surfaces_flush_durable_failure() {
        struct WriteOkFlushFail;
        impl SinkDrain<u32> for WriteOkFlushFail {
            async fn write_batch(&mut self, _batch: Vec<u32>) -> Result<(), DrainError> {
                Ok(())
            }
            async fn flush_durable(&mut self) -> Result<(), DrainError> {
                Err(DrainError::Io(std::io::Error::other("durable-fail")))
            }
        }
        let shutdown = CancellationToken::new();
        let (sink, _handle) =
            BackgroundSink::spawn(WriteOkFlushFail, fast_config(), shutdown.clone());
        sink.try_push(1).unwrap();
        let err = sink.flush().await.unwrap_err();
        assert!(matches!(err, SinkError::Drain(_)), "got: {err:?}");
        shutdown.cancel();
    }
}