use std::future::Future;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio::time::{MissedTickBehavior, interval};
use tokio_util::sync::CancellationToken;
use tracing::warn;
use super::error::{DrainError, SinkError};
/// Configuration for a `BackgroundSink` / its actor task.
#[derive(Debug, Clone)]
pub struct BackgroundSinkConfig {
/// Bounded channel capacity; a full queue triggers the `overflow` policy.
pub queue_capacity: usize,
/// A batch is written to the drain once it reaches this many messages.
pub batch_size: usize,
/// A partial batch is flushed to the drain at least this often.
pub flush_interval: Duration,
/// What producers experience when the queue is full.
pub overflow: Overflow,
/// Prefix for emitted metric names; `None` disables metric emission.
pub metric_prefix: Option<&'static str>,
}
impl Default for BackgroundSinkConfig {
fn default() -> Self {
Self {
queue_capacity: 10_000,
batch_size: 256,
flush_interval: Duration::from_millis(100),
overflow: Overflow::Drop,
metric_prefix: None,
}
}
}
/// Policy applied when the sink's queue is full.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Overflow {
/// `try_push` fails fast with an overflow error and the message is
/// counted as dropped.
Drop,
/// Producers are expected to use `push_blocking` and wait for space;
/// in this mode `try_push` always returns an overflow error.
Block,
}
/// Internal channel protocol between producers and the actor task.
enum SinkMsg<T> {
/// A payload to be batched and eventually written to the drain.
Data(T),
/// Flush barrier: the actor writes its current batch, then acknowledges
/// via the oneshot so `flush()` can return.
Barrier(oneshot::Sender<()>),
}
/// Destination the background actor writes batches of `T` into.
pub trait SinkDrain<T: Send>: Send + 'static {
/// Writes one batch. Errors are logged and counted by the sink, not
/// retried or propagated to producers.
fn write_batch(&mut self, batch: Vec<T>)
-> impl Future<Output = Result<(), DrainError>> + Send;
/// Called exactly once during shutdown, after the final batch has been
/// written. Defaults to a no-op.
fn close(&mut self) -> impl Future<Output = Result<(), DrainError>> + Send {
std::future::ready(Ok(()))
}
}
/// Cheaply cloneable producer handle for a background batching sink.
///
/// Messages are enqueued onto a bounded channel and written to a
/// [`SinkDrain`] by a dedicated actor task (see `actor_loop`).
#[derive(Debug, Clone)]
pub struct BackgroundSink<T: Send + 'static> {
// Producer side of the bounded actor channel.
tx: mpsc::Sender<SinkMsg<T>>,
// Messages rejected because the queue was full (Drop mode only).
dropped: Arc<AtomicU64>,
// Approximate count of enqueued-but-not-yet-written messages.
pending: Arc<AtomicUsize>,
// Overflow policy, copied out of the config at spawn time.
overflow: Overflow,
// Metric name prefix; `None` disables metric emission.
metric_prefix: Option<&'static str>,
}
/// Owner of the actor task's join handle; lets callers await a clean exit
/// after cancelling the shutdown token (or dropping all senders).
pub struct BackgroundSinkHandle {
join: JoinHandle<()>,
}
impl BackgroundSinkHandle {
/// Waits for the actor task to finish.
pub async fn join(self) -> Result<(), tokio::task::JoinError> {
self.join.await
}
}
impl<T: Send + 'static> BackgroundSink<T> {
pub fn spawn<D: SinkDrain<T>>(
drain: D,
config: BackgroundSinkConfig,
shutdown: CancellationToken,
) -> (Self, BackgroundSinkHandle) {
let (tx, rx) = mpsc::channel(config.queue_capacity);
let dropped = Arc::new(AtomicU64::new(0));
let pending = Arc::new(AtomicUsize::new(0));
let metric_prefix = config.metric_prefix;
let overflow = config.overflow;
if let Some(prefix) = metric_prefix {
metrics::describe_counter!(
format!("{prefix}_pushed_total"),
"Messages successfully enqueued by the background sink"
);
metrics::describe_counter!(
format!("{prefix}_dropped_total"),
"Messages dropped due to queue overflow"
);
metrics::describe_counter!(
format!("{prefix}_writes_total"),
"Batch writes attempted by the drain"
);
metrics::describe_counter!(
format!("{prefix}_write_errors_total"),
"Batch writes that returned an error"
);
metrics::describe_gauge!(
format!("{prefix}_pending"),
"Current background sink queue depth"
);
}
let actor_pending = Arc::clone(&pending);
let join = tokio::spawn(actor_loop(
rx,
drain,
config,
shutdown,
actor_pending,
metric_prefix,
));
(
Self {
tx,
dropped,
pending,
overflow,
metric_prefix,
},
BackgroundSinkHandle { join },
)
}
pub fn try_push(&self, msg: T) -> Result<(), SinkError> {
match self.overflow {
Overflow::Drop => match self.tx.try_send(SinkMsg::Data(msg)) {
Ok(()) => {
self.pending.fetch_add(1, Ordering::Relaxed);
if let Some(p) = self.metric_prefix {
metrics::counter!(format!("{p}_pushed_total")).increment(1);
}
Ok(())
}
Err(mpsc::error::TrySendError::Full(_)) => {
self.dropped.fetch_add(1, Ordering::Relaxed);
if let Some(p) = self.metric_prefix {
metrics::counter!(format!("{p}_dropped_total")).increment(1);
}
Err(SinkError::Overflow)
}
Err(mpsc::error::TrySendError::Closed(_)) => Err(SinkError::Closed),
},
Overflow::Block => Err(SinkError::Overflow),
}
}
pub async fn push_blocking(&self, msg: T) -> Result<(), SinkError> {
self.tx
.send(SinkMsg::Data(msg))
.await
.map_err(|_| SinkError::Closed)?;
self.pending.fetch_add(1, Ordering::Relaxed);
if let Some(p) = self.metric_prefix {
metrics::counter!(format!("{p}_pushed_total")).increment(1);
}
Ok(())
}
pub async fn flush(&self) -> Result<(), SinkError> {
let (ack_tx, ack_rx) = oneshot::channel();
self.tx
.send(SinkMsg::Barrier(ack_tx))
.await
.map_err(|_| SinkError::Closed)?;
ack_rx.await.map_err(|_| SinkError::Closed)
}
#[must_use]
pub fn dropped(&self) -> u64 {
self.dropped.load(Ordering::Relaxed)
}
#[must_use]
pub fn pending(&self) -> usize {
self.pending.load(Ordering::Relaxed)
}
}
/// Actor task: batches incoming messages and writes them to the drain.
///
/// A batch is written when it reaches `config.batch_size`, when
/// `config.flush_interval` elapses, when a barrier arrives, and on shutdown.
/// Shutdown (cancellation, or all senders dropped) drains everything still
/// buffered, then calls `drain.close()` exactly once and exits.
async fn actor_loop<T, D>(
    mut rx: mpsc::Receiver<SinkMsg<T>>,
    mut drain: D,
    config: BackgroundSinkConfig,
    shutdown: CancellationToken,
    pending: Arc<AtomicUsize>,
    metric_prefix: Option<&'static str>,
) where
    T: Send + 'static,
    D: SinkDrain<T>,
{
    let mut batch: Vec<T> = Vec::with_capacity(config.batch_size);
    let mut tick = interval(config.flush_interval);
    tick.set_missed_tick_behavior(MissedTickBehavior::Delay);
    // The first tick of an `interval` completes immediately; consume it so
    // the timer arm doesn't trigger a flush before any work has arrived.
    tick.tick().await;
    loop {
        tokio::select! {
            // `biased` so shutdown wins over data when both are ready.
            biased;
            () = shutdown.cancelled() => {
                // Stop accepting new messages: after `close()` producers get
                // `Closed` instead of silently enqueueing into a queue that
                // is about to be abandoned, and the `try_recv` loop below is
                // guaranteed to observe everything that was buffered.
                rx.close();
                if !batch.is_empty() {
                    write_batch_with_metrics(
                        &mut drain, std::mem::take(&mut batch),
                        &pending, metric_prefix,
                    ).await;
                }
                while let Ok(msg) = rx.try_recv() {
                    match msg {
                        SinkMsg::Data(t) => batch.push(t),
                        // Barriers are acknowledged immediately during
                        // shutdown; the remaining data is still written
                        // below before the drain is closed.
                        SinkMsg::Barrier(ack) => { let _ = ack.send(()); }
                    }
                    if batch.len() >= config.batch_size {
                        write_batch_with_metrics(
                            &mut drain, std::mem::take(&mut batch),
                            &pending, metric_prefix,
                        ).await;
                    }
                }
                if !batch.is_empty() {
                    write_batch_with_metrics(
                        &mut drain, std::mem::take(&mut batch),
                        &pending, metric_prefix,
                    ).await;
                }
                if let Err(e) = drain.close().await {
                    warn!(error = %e, "sink drain close failed");
                }
                return;
            }
            msg = rx.recv() => match msg {
                Some(SinkMsg::Data(t)) => {
                    batch.push(t);
                    if batch.len() >= config.batch_size {
                        write_batch_with_metrics(
                            &mut drain, std::mem::take(&mut batch),
                            &pending, metric_prefix,
                        ).await;
                    }
                }
                Some(SinkMsg::Barrier(ack)) => {
                    // Flush everything received before the barrier, then ack
                    // so the waiting `flush()` call can return.
                    if !batch.is_empty() {
                        write_batch_with_metrics(
                            &mut drain, std::mem::take(&mut batch),
                            &pending, metric_prefix,
                        ).await;
                    }
                    let _ = ack.send(());
                }
                None => {
                    // All senders dropped: final flush, close the drain, exit.
                    if !batch.is_empty() {
                        write_batch_with_metrics(
                            &mut drain, std::mem::take(&mut batch),
                            &pending, metric_prefix,
                        ).await;
                    }
                    if let Err(e) = drain.close().await {
                        warn!(error = %e, "sink drain close failed");
                    }
                    return;
                }
            },
            _ = tick.tick() => {
                // Periodic flush of a partial batch so messages never sit
                // in memory longer than `flush_interval`.
                if !batch.is_empty() {
                    write_batch_with_metrics(
                        &mut drain, std::mem::take(&mut batch),
                        &pending, metric_prefix,
                    ).await;
                }
            }
        }
    }
}
/// Writes one batch to the drain, maintaining the pending gauge and write
/// counters. Drain errors are logged and counted, never propagated — the
/// sink is deliberately fire-and-forget.
async fn write_batch_with_metrics<T, D: SinkDrain<T>>(
    drain: &mut D,
    batch: Vec<T>,
    pending: &AtomicUsize,
    metric_prefix: Option<&'static str>,
) where
    T: Send,
{
    let count = batch.len();
    // Nothing to account for or write; avoids a pointless drain call.
    if count == 0 {
        return;
    }
    // Saturating decrement: producers increment `pending` around their send,
    // so a racing (or rolled-back) increment must never be able to wrap this
    // counter below zero via a plain `fetch_sub`.
    let _ = pending.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |p| {
        Some(p.saturating_sub(count))
    });
    if let Some(p) = metric_prefix {
        // `writes_total` counts attempts; failures are tracked separately.
        metrics::counter!(format!("{p}_writes_total")).increment(1);
        metrics::gauge!(format!("{p}_pending")).set(pending.load(Ordering::Relaxed) as f64);
    }
    if let Err(e) = drain.write_batch(batch).await {
        warn!(error = %e, count, "sink drain write_batch failed");
        if let Some(p) = metric_prefix {
            metrics::counter!(format!("{p}_write_errors_total")).increment(1);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::Instant;
    use tokio::sync::Notify;
    use tokio_util::sync::CancellationToken;
    /// Drain that just counts how many messages it has received.
    struct CountingDrain {
        count: Arc<AtomicU64>,
    }
    impl SinkDrain<u32> for CountingDrain {
        async fn write_batch(&mut self, batch: Vec<u32>) -> Result<(), DrainError> {
            self.count.fetch_add(batch.len() as u64, Ordering::SeqCst);
            Ok(())
        }
    }
    /// Drain that blocks every write until `release` is notified; used to
    /// keep the actor busy so the queue fills up.
    struct ThrottledDrain {
        release: Arc<Notify>,
        count: Arc<AtomicU64>,
    }
    impl SinkDrain<u32> for ThrottledDrain {
        async fn write_batch(&mut self, batch: Vec<u32>) -> Result<(), DrainError> {
            self.release.notified().await;
            self.count.fetch_add(batch.len() as u64, Ordering::SeqCst);
            Ok(())
        }
    }
    fn fast_config() -> BackgroundSinkConfig {
        BackgroundSinkConfig {
            queue_capacity: 1024,
            batch_size: 16,
            flush_interval: Duration::from_millis(20),
            overflow: Overflow::Drop,
            metric_prefix: None,
        }
    }
    #[tokio::test]
    async fn try_push_succeeds_when_queue_has_space() {
        let count = Arc::new(AtomicU64::new(0));
        let shutdown = CancellationToken::new();
        let (sink, _handle) = BackgroundSink::spawn(
            CountingDrain {
                count: count.clone(),
            },
            fast_config(),
            shutdown.clone(),
        );
        for i in 0..10 {
            sink.try_push(i).expect("queue has space");
        }
        sink.flush().await.expect("flush ok");
        assert_eq!(count.load(Ordering::SeqCst), 10);
        shutdown.cancel();
    }
    #[tokio::test]
    async fn try_push_returns_overflow_when_full() {
        let count = Arc::new(AtomicU64::new(0));
        let release = Arc::new(Notify::new());
        let shutdown = CancellationToken::new();
        let cfg = BackgroundSinkConfig {
            queue_capacity: 4,
            batch_size: 16,
            // `Duration::from_mins` is unstable (`duration_constructors`);
            // spell it in seconds so the test builds on stable Rust.
            flush_interval: Duration::from_secs(60),
            overflow: Overflow::Drop,
            metric_prefix: None,
        };
        let (sink, _handle) = BackgroundSink::spawn(
            ThrottledDrain {
                release: release.clone(),
                count: count.clone(),
            },
            cfg,
            shutdown.clone(),
        );
        let mut accepted: u64 = 0;
        let mut overflowed: u64 = 0;
        for i in 0..20 {
            match sink.try_push(i) {
                Ok(()) => accepted += 1,
                Err(SinkError::Overflow) => overflowed += 1,
                Err(e) => panic!("unexpected error: {e}"),
            }
        }
        assert!(overflowed > 0, "expected at least one overflow");
        assert_eq!(sink.dropped(), overflowed);
        assert!(accepted >= 4, "should accept at least queue_capacity");
        shutdown.cancel();
        release.notify_waiters();
    }
    #[tokio::test]
    async fn try_push_in_block_mode_always_errors() {
        let count = Arc::new(AtomicU64::new(0));
        let shutdown = CancellationToken::new();
        let cfg = BackgroundSinkConfig {
            overflow: Overflow::Block,
            ..fast_config()
        };
        let (sink, _handle) = BackgroundSink::spawn(
            CountingDrain {
                count: count.clone(),
            },
            cfg,
            shutdown.clone(),
        );
        match sink.try_push(1) {
            Err(SinkError::Overflow) => {}
            other => panic!("expected Overflow, got {other:?}"),
        }
        sink.push_blocking(1)
            .await
            .expect("push_blocking ok in Block mode");
        sink.flush().await.expect("flush ok");
        assert_eq!(count.load(Ordering::SeqCst), 1);
        shutdown.cancel();
    }
    #[tokio::test]
    async fn flush_waits_for_pre_flush_messages() {
        let count = Arc::new(AtomicU64::new(0));
        let shutdown = CancellationToken::new();
        let (sink, _handle) = BackgroundSink::spawn(
            CountingDrain {
                count: count.clone(),
            },
            fast_config(),
            shutdown.clone(),
        );
        for i in 0..100 {
            sink.try_push(i).expect("queue has space");
        }
        sink.flush().await.expect("flush ok");
        assert_eq!(count.load(Ordering::SeqCst), 100);
        shutdown.cancel();
    }
    #[tokio::test]
    async fn shutdown_drains_remaining_queue() {
        let count = Arc::new(AtomicU64::new(0));
        let shutdown = CancellationToken::new();
        let (sink, handle) = BackgroundSink::spawn(
            CountingDrain {
                count: count.clone(),
            },
            fast_config(),
            shutdown.clone(),
        );
        for i in 0..50 {
            sink.try_push(i).expect("queue has space");
        }
        shutdown.cancel();
        handle.join().await.expect("clean exit");
        assert_eq!(count.load(Ordering::SeqCst), 50);
    }
    #[tokio::test]
    async fn dropped_counter_reflects_overflow_count() {
        let count = Arc::new(AtomicU64::new(0));
        let release = Arc::new(Notify::new());
        let shutdown = CancellationToken::new();
        let cfg = BackgroundSinkConfig {
            queue_capacity: 2,
            batch_size: 16,
            // Stable replacement for the unstable `Duration::from_mins(1)`.
            flush_interval: Duration::from_secs(60),
            overflow: Overflow::Drop,
            metric_prefix: None,
        };
        let (sink, _handle) = BackgroundSink::spawn(
            ThrottledDrain {
                release: release.clone(),
                count: count.clone(),
            },
            cfg,
            shutdown.clone(),
        );
        for i in 0..100 {
            let _ = sink.try_push(i);
        }
        assert!(sink.dropped() >= 90, "dropped={}", sink.dropped());
        shutdown.cancel();
        release.notify_waiters();
    }
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn try_push_stays_fast_under_load() {
        let count = Arc::new(AtomicU64::new(0));
        let release = Arc::new(Notify::new());
        let shutdown = CancellationToken::new();
        let cfg = BackgroundSinkConfig {
            queue_capacity: 100_000,
            batch_size: 256,
            // Stable replacement for the unstable `Duration::from_mins(1)`.
            flush_interval: Duration::from_secs(60),
            overflow: Overflow::Drop,
            metric_prefix: None,
        };
        let (sink, _handle) = BackgroundSink::spawn(
            ThrottledDrain {
                release: release.clone(),
                count: count.clone(),
            },
            cfg,
            shutdown.clone(),
        );
        let start = Instant::now();
        for i in 0..10_000_u32 {
            sink.try_push(i).expect("queue has space");
        }
        let elapsed = start.elapsed();
        let avg_us = elapsed.as_micros() as f64 / 10_000.0;
        assert!(
            avg_us < 50.0,
            "try_push p_avg = {avg_us}µs (expected <50µs under load with throttled drain)",
        );
        shutdown.cancel();
        release.notify_waiters();
    }
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn many_concurrent_producers_dont_block_each_other() {
        let count = Arc::new(AtomicU64::new(0));
        let shutdown = CancellationToken::new();
        let cfg = BackgroundSinkConfig {
            queue_capacity: 100_000,
            batch_size: 1024,
            flush_interval: Duration::from_millis(5),
            overflow: Overflow::Drop,
            metric_prefix: None,
        };
        let (sink, _handle) = BackgroundSink::spawn(
            CountingDrain {
                count: count.clone(),
            },
            cfg,
            shutdown.clone(),
        );
        let mut handles = Vec::new();
        for _ in 0..8 {
            let s = sink.clone();
            handles.push(tokio::spawn(async move {
                for i in 0..1_000_u32 {
                    s.try_push(i).expect("queue has space");
                }
            }));
        }
        for h in handles {
            h.await.expect("producer exit");
        }
        sink.flush().await.expect("flush ok");
        assert_eq!(count.load(Ordering::SeqCst), 8_000);
        shutdown.cancel();
    }
    #[tokio::test]
    async fn flush_completes_quickly_when_queue_is_already_empty() {
        let count = Arc::new(AtomicU64::new(0));
        let shutdown = CancellationToken::new();
        let (sink, _handle) = BackgroundSink::spawn(
            CountingDrain {
                count: count.clone(),
            },
            fast_config(),
            shutdown.clone(),
        );
        let start = Instant::now();
        sink.flush().await.expect("flush ok");
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_millis(50),
            "empty flush took {elapsed:?} (expected <50ms)",
        );
        shutdown.cancel();
    }
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn slow_drain_doesnt_block_consumer() {
        struct SlowDrain {
            count: Arc<AtomicU64>,
        }
        impl SinkDrain<u32> for SlowDrain {
            async fn write_batch(&mut self, batch: Vec<u32>) -> Result<(), DrainError> {
                tokio::time::sleep(Duration::from_millis(50)).await;
                self.count.fetch_add(batch.len() as u64, Ordering::SeqCst);
                Ok(())
            }
        }
        let count = Arc::new(AtomicU64::new(0));
        let shutdown = CancellationToken::new();
        let cfg = BackgroundSinkConfig {
            queue_capacity: 10_000,
            batch_size: 16,
            flush_interval: Duration::from_millis(5),
            overflow: Overflow::Drop,
            metric_prefix: None,
        };
        let (sink, _handle) = BackgroundSink::spawn(
            SlowDrain {
                count: count.clone(),
            },
            cfg,
            shutdown.clone(),
        );
        let mut max_us: u128 = 0;
        for i in 0..200_u32 {
            let t0 = Instant::now();
            sink.try_push(i).expect("queue has space");
            let elapsed_us = t0.elapsed().as_micros();
            if elapsed_us > max_us {
                max_us = elapsed_us;
            }
        }
        assert!(
            max_us < 5_000,
            "max try_push latency was {max_us}µs — slow drain leaked back to consumer",
        );
        shutdown.cancel();
    }
}