// cognis_trace/batch.rs
1//! `Batcher<T>` — bounded queue + background flush task. Each exporter
2//! gets its own batcher so a slow backend doesn't block the bridge.
3
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use tokio::sync::mpsc;
9use tokio::task::JoinHandle;
10use tokio::time::Instant;
11
/// Shared statistics counted across the batcher's lifetime.
///
/// All counters are monotonically increasing and updated with `Relaxed`
/// atomics, so readers get an approximate, eventually-consistent view.
#[derive(Debug, Default)]
pub struct BatcherStats {
    /// Items successfully sent to the flush callback.
    pub sent: AtomicUsize,
    /// Items dropped because the queue was full.
    pub dropped: AtomicUsize,
    /// Items whose flush callback returned an error (counted per item,
    /// not per callback — see `do_flush`).
    pub failed: AtomicUsize,
}
22
23impl BatcherStats {
24    /// Returns `(sent, dropped, failed)`.
25    pub fn snapshot(&self) -> (usize, usize, usize) {
26        (
27            self.sent.load(Ordering::Relaxed),
28            self.dropped.load(Ordering::Relaxed),
29            self.failed.load(Ordering::Relaxed),
30        )
31    }
32}
33
/// Configuration for a `Batcher`.
#[derive(Debug, Clone, Copy)]
pub struct BatcherConfig {
    /// Max items per flush; reaching this size triggers an immediate flush.
    pub max_batch: usize,
    /// Max wait before flushing a partial batch.
    pub flush_interval: Duration,
    /// Max items the queue holds; further `send`s drop their item.
    pub queue_capacity: usize,
}
44
45impl Default for BatcherConfig {
46    fn default() -> Self {
47        Self {
48            max_batch: 100,
49            flush_interval: Duration::from_secs(1),
50            queue_capacity: 10_000,
51        }
52    }
53}
54
/// Bounded queue + background flush task. Generic over the item type.
pub struct Batcher<T: Send + 'static> {
    /// Producer side of the bounded channel; `send` uses `try_send`, so
    /// enqueueing never blocks.
    tx: mpsc::Sender<T>,
    /// Lifetime counters shared with the background flush task.
    stats: Arc<BatcherStats>,
    /// Join handle for the flush task; taken (and awaited) by `shutdown`.
    handle: Option<JoinHandle<()>>,
}
61
62impl<T: Send + 'static> Batcher<T> {
63    /// Spawn a batcher. `flush` receives a non-empty `Vec<T>` whenever a
64    /// batch is ready (size or time). Errors from `flush` increment
65    /// `stats.failed`.
66    pub fn spawn<F, Fut>(cfg: BatcherConfig, flush: F) -> Self
67    where
68        F: Fn(Vec<T>) -> Fut + Send + Sync + 'static,
69        Fut: std::future::Future<Output = Result<(), crate::TraceError>> + Send,
70    {
71        let (tx, mut rx) = mpsc::channel::<T>(cfg.queue_capacity);
72        let stats = Arc::new(BatcherStats::default());
73        let stats_for_task = stats.clone();
74        let handle = tokio::spawn(async move {
75            let mut buf: Vec<T> = Vec::with_capacity(cfg.max_batch);
76            let mut deadline = Instant::now() + cfg.flush_interval;
77            loop {
78                tokio::select! {
79                    biased;
80                    item = rx.recv() => match item {
81                        Some(x) => {
82                            buf.push(x);
83                            if buf.len() >= cfg.max_batch {
84                                Self::do_flush(&mut buf, &flush, &stats_for_task).await;
85                                deadline = Instant::now() + cfg.flush_interval;
86                            }
87                        }
88                        None => {
89                            if !buf.is_empty() {
90                                Self::do_flush(&mut buf, &flush, &stats_for_task).await;
91                            }
92                            break;
93                        }
94                    },
95                    _ = tokio::time::sleep_until(deadline) => {
96                        if !buf.is_empty() {
97                            Self::do_flush(&mut buf, &flush, &stats_for_task).await;
98                        }
99                        deadline = Instant::now() + cfg.flush_interval;
100                    }
101                }
102            }
103        });
104        Self {
105            tx,
106            stats,
107            handle: Some(handle),
108        }
109    }
110
111    async fn do_flush<F, Fut>(buf: &mut Vec<T>, flush: &F, stats: &Arc<BatcherStats>)
112    where
113        F: Fn(Vec<T>) -> Fut,
114        Fut: std::future::Future<Output = Result<(), crate::TraceError>>,
115    {
116        let count = buf.len();
117        let batch = std::mem::take(buf);
118        match flush(batch).await {
119            Ok(()) => {
120                stats.sent.fetch_add(count, Ordering::Relaxed);
121            }
122            Err(e) => {
123                stats.failed.fetch_add(count, Ordering::Relaxed);
124                tracing::warn!(error = %e, dropped = count, "trace batcher flush failed");
125            }
126        }
127    }
128
129    /// Non-blocking enqueue. Drops on overflow and increments the dropped
130    /// counter.
131    pub fn send(&self, item: T) {
132        if let Err(_e) = self.tx.try_send(item) {
133            self.stats.dropped.fetch_add(1, Ordering::Relaxed);
134        }
135    }
136
137    /// Stats handle (clone-cheap; counters are atomic).
138    pub fn stats(&self) -> Arc<BatcherStats> {
139        self.stats.clone()
140    }
141
142    /// Drop the sender and wait for the background task to drain remaining
143    /// items. Use on graceful shutdown.
144    pub async fn shutdown(mut self) {
145        drop(self.tx);
146        if let Some(h) = self.handle.take() {
147            let _ = h.await;
148        }
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use tokio::time::sleep;

    // `start_paused`: tokio auto-advances the mock clock when every task is
    // idle, so the sleeps below complete instantly in real time.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn flushes_when_batch_full() {
        let collected = Arc::new(Mutex::new(Vec::<u32>::new()));
        let c2 = collected.clone();
        let cfg = BatcherConfig {
            max_batch: 3,
            // Long interval so only the size trigger can cause the flush.
            flush_interval: Duration::from_secs(60),
            queue_capacity: 100,
        };
        let b = Batcher::spawn(cfg, move |batch: Vec<u32>| {
            let c = c2.clone();
            async move {
                c.lock().unwrap().extend(batch);
                Ok(())
            }
        });
        b.send(1);
        b.send(2);
        b.send(3); // third item reaches max_batch and triggers a flush
        // Yield + a tiny (mock-time) sleep so the background task gets
        // polled and runs the flush before we assert.
        tokio::task::yield_now().await;
        sleep(Duration::from_millis(1)).await;
        assert_eq!(*collected.lock().unwrap(), vec![1, 2, 3]);
        b.shutdown().await;
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn flushes_on_interval_with_partial_batch() {
        let collected = Arc::new(Mutex::new(Vec::<u32>::new()));
        let c2 = collected.clone();
        let cfg = BatcherConfig {
            // Large batch size so only the timer can cause the flush.
            max_batch: 100,
            flush_interval: Duration::from_millis(100),
            queue_capacity: 100,
        };
        let b = Batcher::spawn(cfg, move |batch: Vec<u32>| {
            let c = c2.clone();
            async move {
                c.lock().unwrap().extend(batch);
                Ok(())
            }
        });
        b.send(7);
        // Sleeping past the interval (mock time) lets the deadline fire and
        // flush the 1-item partial batch.
        sleep(Duration::from_millis(150)).await;
        assert_eq!(*collected.lock().unwrap(), vec![7]);
        b.shutdown().await;
    }

    #[tokio::test(flavor = "current_thread")]
    async fn drops_count_when_queue_full() {
        let cfg = BatcherConfig {
            max_batch: 1,
            flush_interval: Duration::from_secs(60),
            queue_capacity: 1,
        };
        let b = Batcher::spawn(cfg, |_batch: Vec<u32>| async move {
            tokio::time::sleep(Duration::from_secs(60)).await;
            Ok(())
        });
        // queue_capacity is 1 and this test never awaits before asserting,
        // so on a current_thread runtime the flush task has not been polled:
        // the first send fills the channel and the remaining 49 are dropped.
        // NOTE(review): even if the task were polled and parked in the slow
        // 60 s flush, at most a couple more items would fit — drops would
        // still occur, so the assertion is robust either way.
        for i in 0..50 {
            b.send(i);
        }
        let (_, dropped, _) = b.stats().snapshot();
        assert!(dropped > 0, "expected some drops, got {dropped}");
    }
}