1use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use tokio::sync::mpsc;
9use tokio::task::JoinHandle;
10use tokio::time::Instant;
11
/// Lifetime counters for a [`Batcher`].
///
/// All fields are atomics so the background flush task and producer
/// threads can update them concurrently without a lock.
#[derive(Debug, Default)]
pub struct BatcherStats {
    /// Items successfully delivered by the flush callback.
    pub sent: AtomicUsize,
    /// Items discarded by `send` because the queue was full (or closed).
    pub dropped: AtomicUsize,
    /// Items handed to the flush callback that then returned an error.
    pub failed: AtomicUsize,
}

impl BatcherStats {
    /// Returns `(sent, dropped, failed)` as a point-in-time snapshot.
    ///
    /// Each counter is loaded independently with relaxed ordering, so the
    /// triple is not guaranteed to be mutually consistent under load.
    pub fn snapshot(&self) -> (usize, usize, usize) {
        let sent = self.sent.load(Ordering::Relaxed);
        let dropped = self.dropped.load(Ordering::Relaxed);
        let failed = self.failed.load(Ordering::Relaxed);
        (sent, dropped, failed)
    }
}
33
/// Tuning knobs for a [`Batcher`].
#[derive(Debug, Clone, Copy)]
pub struct BatcherConfig {
    /// Flush as soon as this many items are buffered.
    pub max_batch: usize,
    /// Flush whatever is buffered once this much time has elapsed.
    pub flush_interval: Duration,
    /// Bounded-queue depth; `send` drops items while it is full.
    pub queue_capacity: usize,
}

impl Default for BatcherConfig {
    /// Defaults: batches of 100, flushed at least once per second,
    /// with room for 10 000 queued items.
    fn default() -> Self {
        BatcherConfig {
            max_batch: 100,
            flush_interval: Duration::from_secs(1),
            queue_capacity: 10_000,
        }
    }
}
54
/// Fan-in batching queue: producers call `send` cheaply; a background
/// task groups items into batches and hands each batch to a flush callback.
pub struct Batcher<T: Send + 'static> {
    // Producer side of the bounded channel feeding the flush task.
    tx: mpsc::Sender<T>,
    // Shared counters; also handed out to callers via `stats()`.
    stats: Arc<BatcherStats>,
    // Background flush task; taken (and awaited) by `shutdown`.
    handle: Option<JoinHandle<()>>,
}
61
62impl<T: Send + 'static> Batcher<T> {
63 pub fn spawn<F, Fut>(cfg: BatcherConfig, flush: F) -> Self
67 where
68 F: Fn(Vec<T>) -> Fut + Send + Sync + 'static,
69 Fut: std::future::Future<Output = Result<(), crate::TraceError>> + Send,
70 {
71 let (tx, mut rx) = mpsc::channel::<T>(cfg.queue_capacity);
72 let stats = Arc::new(BatcherStats::default());
73 let stats_for_task = stats.clone();
74 let handle = tokio::spawn(async move {
75 let mut buf: Vec<T> = Vec::with_capacity(cfg.max_batch);
76 let mut deadline = Instant::now() + cfg.flush_interval;
77 loop {
78 tokio::select! {
79 biased;
80 item = rx.recv() => match item {
81 Some(x) => {
82 buf.push(x);
83 if buf.len() >= cfg.max_batch {
84 Self::do_flush(&mut buf, &flush, &stats_for_task).await;
85 deadline = Instant::now() + cfg.flush_interval;
86 }
87 }
88 None => {
89 if !buf.is_empty() {
90 Self::do_flush(&mut buf, &flush, &stats_for_task).await;
91 }
92 break;
93 }
94 },
95 _ = tokio::time::sleep_until(deadline) => {
96 if !buf.is_empty() {
97 Self::do_flush(&mut buf, &flush, &stats_for_task).await;
98 }
99 deadline = Instant::now() + cfg.flush_interval;
100 }
101 }
102 }
103 });
104 Self {
105 tx,
106 stats,
107 handle: Some(handle),
108 }
109 }
110
111 async fn do_flush<F, Fut>(buf: &mut Vec<T>, flush: &F, stats: &Arc<BatcherStats>)
112 where
113 F: Fn(Vec<T>) -> Fut,
114 Fut: std::future::Future<Output = Result<(), crate::TraceError>>,
115 {
116 let count = buf.len();
117 let batch = std::mem::take(buf);
118 match flush(batch).await {
119 Ok(()) => {
120 stats.sent.fetch_add(count, Ordering::Relaxed);
121 }
122 Err(e) => {
123 stats.failed.fetch_add(count, Ordering::Relaxed);
124 tracing::warn!(error = %e, dropped = count, "trace batcher flush failed");
125 }
126 }
127 }
128
129 pub fn send(&self, item: T) {
132 if let Err(_e) = self.tx.try_send(item) {
133 self.stats.dropped.fetch_add(1, Ordering::Relaxed);
134 }
135 }
136
137 pub fn stats(&self) -> Arc<BatcherStats> {
139 self.stats.clone()
140 }
141
142 pub async fn shutdown(mut self) {
145 drop(self.tx);
146 if let Some(h) = self.handle.take() {
147 let _ = h.await;
148 }
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use tokio::time::sleep;

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn flushes_when_batch_full() {
        let seen = Arc::new(Mutex::new(Vec::<u32>::new()));
        let sink = Arc::clone(&seen);
        let batcher = Batcher::spawn(
            BatcherConfig {
                max_batch: 3,
                flush_interval: Duration::from_secs(60),
                queue_capacity: 100,
            },
            move |batch: Vec<u32>| {
                let sink = sink.clone();
                async move {
                    sink.lock().unwrap().extend(batch);
                    Ok(())
                }
            },
        );
        for n in 1..=3 {
            batcher.send(n);
        }
        // Let the flush task run; the paused clock auto-advances the sleep.
        tokio::task::yield_now().await;
        sleep(Duration::from_millis(1)).await;
        assert_eq!(*seen.lock().unwrap(), vec![1, 2, 3]);
        batcher.shutdown().await;
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn flushes_on_interval_with_partial_batch() {
        let seen = Arc::new(Mutex::new(Vec::<u32>::new()));
        let sink = Arc::clone(&seen);
        let cfg = BatcherConfig {
            max_batch: 100,
            flush_interval: Duration::from_millis(100),
            queue_capacity: 100,
        };
        let batcher = Batcher::spawn(cfg, move |batch: Vec<u32>| {
            let sink = sink.clone();
            async move {
                sink.lock().unwrap().extend(batch);
                Ok(())
            }
        });
        batcher.send(7);
        // 150 ms of paused time steps past the 100 ms flush deadline.
        sleep(Duration::from_millis(150)).await;
        assert_eq!(*seen.lock().unwrap(), vec![7]);
        batcher.shutdown().await;
    }

    #[tokio::test(flavor = "current_thread")]
    async fn drops_count_when_queue_full() {
        let cfg = BatcherConfig {
            max_batch: 1,
            flush_interval: Duration::from_secs(60),
            queue_capacity: 1,
        };
        // A flush callback that never finishes in test time keeps the
        // queue backed up, forcing `send` to drop.
        let batcher = Batcher::spawn(cfg, |_batch: Vec<u32>| async move {
            tokio::time::sleep(Duration::from_secs(60)).await;
            Ok(())
        });
        for i in 0..50 {
            batcher.send(i);
        }
        let (_, dropped, _) = batcher.stats().snapshot();
        assert!(dropped > 0, "expected some drops, got {dropped}");
    }
}
225}