1use std::sync::atomic::{AtomicU64, AtomicBool, Ordering};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use dashmap::DashMap;
10use tokio::sync::oneshot;
11use serde::{Deserialize, Serialize};
12
/// Key of a table's pending batch: the table name exactly as passed to `add`.
pub type TableId = String;

/// Opaque, copyable identifier attached to every `BatchTicket`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct BatchTicketId(u64);
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct BatchConfig {
23 pub enabled: bool,
25 pub max_batch_size: usize,
27 pub max_wait_ms: u64,
29 pub max_batch_bytes: usize,
31 pub auto_flush: bool,
33 pub batch_tables: Vec<String>,
35}
36
37impl Default for BatchConfig {
38 fn default() -> Self {
39 Self {
40 enabled: true,
41 max_batch_size: 1000,
42 max_wait_ms: 10,
43 max_batch_bytes: 16 * 1024 * 1024, auto_flush: true,
45 batch_tables: Vec::new(), }
47 }
48}
49
50#[derive(Debug)]
52pub struct InsertRequest {
53 pub table: String,
55 pub columns: Vec<String>,
57 pub values: Vec<Vec<String>>,
59 pub original_sql: String,
61 pub submitted_at: Instant,
63 response_tx: Option<oneshot::Sender<BatchResult>>,
65}
66
67#[derive(Debug, Clone)]
69pub struct BatchResult {
70 pub ticket_id: BatchTicketId,
72 pub rows_inserted: u64,
74 pub success: bool,
76 pub error: Option<String>,
78 pub wait_time: Duration,
80 pub execution_time: Duration,
82}
83
84pub struct BatchTicket {
86 id: BatchTicketId,
87 rx: oneshot::Receiver<BatchResult>,
88}
89
90impl BatchTicket {
91 pub async fn wait(self) -> Result<BatchResult, BatchError> {
93 self.rx.await.map_err(|_| BatchError::ChannelClosed)
94 }
95
96 pub async fn wait_timeout(self, timeout: Duration) -> Result<BatchResult, BatchError> {
98 tokio::time::timeout(timeout, self.rx)
99 .await
100 .map_err(|_| BatchError::Timeout)?
101 .map_err(|_| BatchError::ChannelClosed)
102 }
103
104 pub fn id(&self) -> BatchTicketId {
106 self.id
107 }
108}
109
/// Errors produced by the insert batcher and its tickets.
#[derive(Debug, Clone)]
pub enum BatchError {
    /// Batching is switched off, or the table is not listed in `batch_tables`.
    Disabled,
    /// The batch cannot accept more requests.
    /// NOTE(review): not constructed anywhere in this file.
    BatchFull,
    /// Waiting for a result exceeded the caller-supplied deadline.
    Timeout,
    /// The result channel closed before a result was delivered.
    ChannelClosed,
    /// Batch execution failed; also returned by `add` after shutdown.
    ExecutionFailed(String),
}

impl std::fmt::Display for BatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed messages share one write path; only the failure carries data.
        let message = match self {
            BatchError::Disabled => "Batching is disabled",
            BatchError::BatchFull => "Batch is full",
            BatchError::Timeout => "Batch timeout",
            BatchError::ChannelClosed => "Channel closed",
            BatchError::ExecutionFailed(reason) => {
                return write!(f, "Execution failed: {}", reason);
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for BatchError {}
138
139#[derive(Debug, Clone, Default, Serialize, Deserialize)]
141pub struct BatchStats {
142 pub inserts_received: u64,
144 pub rows_received: u64,
146 pub batches_flushed: u64,
148 pub rows_inserted: u64,
150 pub avg_batch_size: f64,
152 pub avg_wait_time_ms: f64,
154 pub avg_execution_time_ms: f64,
156 pub size_triggered_flushes: u64,
158 pub time_triggered_flushes: u64,
160}
161
162struct PendingBatch {
164 requests: Vec<InsertRequest>,
166 row_count: usize,
168 byte_count: usize,
170 first_submitted: Instant,
172}
173
174impl PendingBatch {
175 fn new() -> Self {
176 Self {
177 requests: Vec::with_capacity(100),
178 row_count: 0,
179 byte_count: 0,
180 first_submitted: Instant::now(),
181 }
182 }
183
184 fn add(&mut self, request: InsertRequest) {
185 let row_count = request.values.len();
186 let byte_estimate = request.original_sql.len();
187
188 if self.requests.is_empty() {
189 self.first_submitted = request.submitted_at;
190 }
191
192 self.row_count += row_count;
193 self.byte_count += byte_estimate;
194 self.requests.push(request);
195 }
196
197 fn is_empty(&self) -> bool {
198 self.requests.is_empty()
199 }
200
201 fn should_flush(&self, config: &BatchConfig) -> bool {
202 self.row_count >= config.max_batch_size ||
203 self.byte_count >= config.max_batch_bytes ||
204 self.first_submitted.elapsed().as_millis() as u64 >= config.max_wait_ms
205 }
206
207 fn drain(&mut self) -> (Vec<InsertRequest>, usize) {
208 let row_count = self.row_count;
209 self.row_count = 0;
210 self.byte_count = 0;
211 (std::mem::take(&mut self.requests), row_count)
212 }
213}
214
215pub struct InsertBatcher {
219 config: BatchConfig,
221 pending: DashMap<TableId, PendingBatch>,
223 next_ticket_id: AtomicU64,
225 stats: Arc<parking_lot::RwLock<BatchStats>>,
227 shutdown: AtomicBool,
229}
230
231impl InsertBatcher {
232 pub fn new(config: BatchConfig) -> Self {
234 Self {
235 config,
236 pending: DashMap::new(),
237 next_ticket_id: AtomicU64::new(1),
238 stats: Arc::new(parking_lot::RwLock::new(BatchStats::default())),
239 shutdown: AtomicBool::new(false),
240 }
241 }
242
243 pub fn add(
245 &self,
246 table: String,
247 columns: Vec<String>,
248 values: Vec<Vec<String>>,
249 original_sql: String,
250 ) -> Result<BatchTicket, BatchError> {
251 if !self.config.enabled {
252 return Err(BatchError::Disabled);
253 }
254
255 if self.shutdown.load(Ordering::Relaxed) {
256 return Err(BatchError::ExecutionFailed("Batcher shutdown".to_string()));
257 }
258
259 if !self.config.batch_tables.is_empty() &&
261 !self.config.batch_tables.contains(&table)
262 {
263 return Err(BatchError::Disabled);
264 }
265
266 let ticket_id = BatchTicketId(self.next_ticket_id.fetch_add(1, Ordering::Relaxed));
267 let (tx, rx) = oneshot::channel();
268
269 let row_count = values.len();
270
271 let request = InsertRequest {
272 table: table.clone(),
273 columns,
274 values,
275 original_sql,
276 submitted_at: Instant::now(),
277 response_tx: Some(tx),
278 };
279
280 {
282 let mut stats = self.stats.write();
283 stats.inserts_received += 1;
284 stats.rows_received += row_count as u64;
285 }
286
287 let should_flush = {
289 let mut batch = self.pending.entry(table.clone()).or_insert_with(PendingBatch::new);
290 batch.add(request);
291 batch.should_flush(&self.config)
292 };
293
294 if should_flush {
296 self.flush_batch(&table);
297 }
298
299 Ok(BatchTicket { id: ticket_id, rx })
300 }
301
302 pub fn flush_batch(&self, table: &str) {
304 if let Some((_, mut batch)) = self.pending.remove(table) {
305 if batch.is_empty() {
306 return;
307 }
308
309 let (requests, row_count) = batch.drain();
310 let execution_start = Instant::now();
311
312 let _combined_sql = self.combine_inserts(&requests);
314
315 let success = true; let error: Option<String> = None;
319
320 let execution_time = execution_start.elapsed();
321
322 {
324 let mut stats = self.stats.write();
325 stats.batches_flushed += 1;
326 stats.rows_inserted += row_count as u64;
327
328 if stats.batches_flushed == 1 {
330 stats.avg_batch_size = row_count as f64;
331 } else {
332 stats.avg_batch_size = stats.avg_batch_size * 0.9 + row_count as f64 * 0.1;
333 }
334
335 let exec_ms = execution_time.as_millis() as f64;
337 if stats.batches_flushed == 1 {
338 stats.avg_execution_time_ms = exec_ms;
339 } else {
340 stats.avg_execution_time_ms = stats.avg_execution_time_ms * 0.9 + exec_ms * 0.1;
341 }
342 }
343
344 for mut req in requests {
346 let wait_time = req.submitted_at.elapsed() - execution_time;
347
348 if let Some(tx) = req.response_tx.take() {
349 let _ = tx.send(BatchResult {
350 ticket_id: BatchTicketId(0), rows_inserted: req.values.len() as u64,
352 success,
353 error: error.clone(),
354 wait_time,
355 execution_time,
356 });
357 }
358 }
359 }
360 }
361
362 fn combine_inserts(&self, requests: &[InsertRequest]) -> String {
364 if requests.is_empty() {
365 return String::new();
366 }
367
368 let first = &requests[0];
369 let table = &first.table;
370 let columns = &first.columns;
371
372 let mut sql = format!(
373 "INSERT INTO {} ({}) VALUES ",
374 table,
375 columns.join(", ")
376 );
377
378 let mut value_parts: Vec<String> = Vec::new();
379
380 for req in requests {
381 for row in &req.values {
382 value_parts.push(format!("({})", row.join(", ")));
383 }
384 }
385
386 sql.push_str(&value_parts.join(", "));
387
388 sql
389 }
390
391 pub fn flush_all(&self) {
393 let tables: Vec<TableId> = self.pending.iter().map(|r| r.key().clone()).collect();
394 for table in tables {
395 self.flush_batch(&table);
396 }
397 }
398
399 pub fn batch_size(&self, table: &str) -> usize {
401 self.pending
402 .get(table)
403 .map(|b| b.row_count)
404 .unwrap_or(0)
405 }
406
407 pub fn stats(&self) -> BatchStats {
409 self.stats.read().clone()
410 }
411
412 pub fn shutdown(&self) {
414 self.shutdown.store(true, Ordering::Release);
415 self.flush_all();
416 }
417
418 pub fn start_auto_flush(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
420 let interval = Duration::from_millis(self.config.max_wait_ms);
421
422 tokio::spawn(async move {
423 let mut interval_timer = tokio::time::interval(interval);
424
425 loop {
426 interval_timer.tick().await;
427
428 if self.shutdown.load(Ordering::Relaxed) {
429 break;
430 }
431
432 let tables: Vec<TableId> = self.pending
434 .iter()
435 .filter(|r| {
436 r.first_submitted.elapsed().as_millis() as u64 >= self.config.max_wait_ms
437 })
438 .map(|r| r.key().clone())
439 .collect();
440
441 for table in tables {
442 self.flush_batch(&table);
443 self.stats.write().time_triggered_flushes += 1;
444 }
445 }
446 })
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_batch_add() {
        let batcher = InsertBatcher::new(BatchConfig::default());

        // Keep the ticket alive so the request's receiver is not dropped.
        let _ticket = batcher
            .add(
                "users".to_string(),
                vec!["id".to_string(), "name".to_string()],
                vec![vec!["1".to_string(), "'Alice'".to_string()]],
                "INSERT INTO users (id, name) VALUES (1, 'Alice')".to_string(),
            )
            .unwrap();

        assert_eq!(batcher.batch_size("users"), 1);
    }

    #[tokio::test]
    async fn test_batch_flush_on_size() {
        let config = BatchConfig {
            max_batch_size: 2,
            ..Default::default()
        };
        let batcher = InsertBatcher::new(config);

        batcher
            .add(
                "users".to_string(),
                vec!["id".to_string()],
                vec![vec!["1".to_string()]],
                "INSERT INTO users VALUES (1)".to_string(),
            )
            .unwrap();

        assert_eq!(batcher.batch_size("users"), 1);

        batcher
            .add(
                "users".to_string(),
                vec!["id".to_string()],
                vec![vec!["2".to_string()]],
                "INSERT INTO users VALUES (2)".to_string(),
            )
            .unwrap();

        // Reaching `max_batch_size` flushes synchronously inside `add`.
        assert_eq!(batcher.batch_size("users"), 0);
    }

    #[test]
    fn test_combine_inserts() {
        let batcher = InsertBatcher::new(BatchConfig::default());

        let requests = vec![
            InsertRequest {
                table: "users".to_string(),
                columns: vec!["id".to_string(), "name".to_string()],
                values: vec![vec!["1".to_string(), "'Alice'".to_string()]],
                original_sql: String::new(),
                submitted_at: Instant::now(),
                response_tx: None,
            },
            InsertRequest {
                table: "users".to_string(),
                columns: vec!["id".to_string(), "name".to_string()],
                values: vec![vec!["2".to_string(), "'Bob'".to_string()]],
                original_sql: String::new(),
                submitted_at: Instant::now(),
                response_tx: None,
            },
        ];

        let combined = batcher.combine_inserts(&requests);
        assert!(combined.contains("INSERT INTO users"));
        assert!(combined.contains("(1, 'Alice')"));
        assert!(combined.contains("(2, 'Bob')"));
    }

    #[test]
    fn test_batch_stats() {
        let batcher = InsertBatcher::new(BatchConfig::default());

        batcher
            .add(
                "users".to_string(),
                vec!["id".to_string()],
                vec![vec!["1".to_string()], vec!["2".to_string()]],
                "INSERT INTO users VALUES (1), (2)".to_string(),
            )
            .unwrap();

        let stats = batcher.stats();
        assert_eq!(stats.inserts_received, 1);
        assert_eq!(stats.rows_received, 2);
    }

    #[test]
    fn test_disabled_batcher_rejects_inserts() {
        let batcher = InsertBatcher::new(BatchConfig {
            enabled: false,
            ..Default::default()
        });

        let result = batcher.add(
            "users".to_string(),
            vec!["id".to_string()],
            vec![vec!["1".to_string()]],
            "INSERT INTO users VALUES (1)".to_string(),
        );

        assert!(matches!(result, Err(BatchError::Disabled)));
    }

    #[test]
    fn test_batch_tables_filter() {
        let batcher = InsertBatcher::new(BatchConfig {
            batch_tables: vec!["users".to_string()],
            ..Default::default()
        });

        let rejected = batcher.add(
            "orders".to_string(),
            vec!["id".to_string()],
            vec![vec!["1".to_string()]],
            "INSERT INTO orders VALUES (1)".to_string(),
        );
        assert!(matches!(rejected, Err(BatchError::Disabled)));

        let accepted = batcher.add(
            "users".to_string(),
            vec!["id".to_string()],
            vec![vec!["1".to_string()]],
            "INSERT INTO users VALUES (1)".to_string(),
        );
        assert!(accepted.is_ok());
    }

    #[tokio::test]
    async fn test_flush_all_resolves_tickets() {
        let batcher = InsertBatcher::new(BatchConfig::default());

        let ticket = batcher
            .add(
                "users".to_string(),
                vec!["id".to_string()],
                vec![vec!["1".to_string()], vec!["2".to_string()]],
                "INSERT INTO users VALUES (1), (2)".to_string(),
            )
            .unwrap();

        batcher.flush_all();

        let result = ticket.wait().await.unwrap();
        assert!(result.success);
        assert_eq!(result.rows_inserted, 2);
        assert_eq!(batcher.batch_size("users"), 0);
        assert_eq!(batcher.stats().rows_inserted, 2);
    }

    #[test]
    fn test_shutdown_rejects_new_inserts() {
        let batcher = InsertBatcher::new(BatchConfig::default());
        batcher.shutdown();

        let result = batcher.add(
            "users".to_string(),
            vec!["id".to_string()],
            vec![vec!["1".to_string()]],
            "INSERT INTO users VALUES (1)".to_string(),
        );

        assert!(matches!(result, Err(BatchError::ExecutionFailed(_))));
    }
}