1use dashmap::DashMap;
7use serde::{Deserialize, Serialize};
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::sync::oneshot;
12
/// Key of a table's pending batch: the table name exactly as passed to `InsertBatcher::add`.
pub type TableId = String;

/// Copyable identifier handed out for each queued insert.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct BatchTicketId(u64);
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct BatchConfig {
23 pub enabled: bool,
25 pub max_batch_size: usize,
27 pub max_wait_ms: u64,
29 pub max_batch_bytes: usize,
31 pub auto_flush: bool,
33 pub batch_tables: Vec<String>,
35}
36
37impl Default for BatchConfig {
38 fn default() -> Self {
39 Self {
40 enabled: true,
41 max_batch_size: 1000,
42 max_wait_ms: 10,
43 max_batch_bytes: 16 * 1024 * 1024, auto_flush: true,
45 batch_tables: Vec::new(), }
47 }
48}
49
50#[derive(Debug)]
52pub struct InsertRequest {
53 pub table: String,
55 pub columns: Vec<String>,
57 pub values: Vec<Vec<String>>,
59 pub original_sql: String,
61 pub submitted_at: Instant,
63 response_tx: Option<oneshot::Sender<BatchResult>>,
65}
66
67#[derive(Debug, Clone)]
69pub struct BatchResult {
70 pub ticket_id: BatchTicketId,
72 pub rows_inserted: u64,
74 pub success: bool,
76 pub error: Option<String>,
78 pub wait_time: Duration,
80 pub execution_time: Duration,
82}
83
84pub struct BatchTicket {
86 id: BatchTicketId,
87 rx: oneshot::Receiver<BatchResult>,
88}
89
90impl BatchTicket {
91 pub async fn wait(self) -> Result<BatchResult, BatchError> {
93 self.rx.await.map_err(|_| BatchError::ChannelClosed)
94 }
95
96 pub async fn wait_timeout(self, timeout: Duration) -> Result<BatchResult, BatchError> {
98 tokio::time::timeout(timeout, self.rx)
99 .await
100 .map_err(|_| BatchError::Timeout)?
101 .map_err(|_| BatchError::ChannelClosed)
102 }
103
104 pub fn id(&self) -> BatchTicketId {
106 self.id
107 }
108}
109
/// Errors produced by the insert batcher and by waiting on its tickets.
#[derive(Clone, Debug)]
pub enum BatchError {
    /// Batching is turned off, or the table is not in the configured allow-list.
    Disabled,
    /// Batch is full.
    /// NOTE(review): not constructed anywhere in this module.
    BatchFull,
    /// No result arrived within the caller's timeout.
    Timeout,
    /// The result channel closed before a result was delivered.
    ChannelClosed,
    /// Execution failed, or the batcher has shut down; carries a description.
    ExecutionFailed(String),
}

impl std::fmt::Display for BatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            BatchError::Disabled => "Batching is disabled",
            BatchError::BatchFull => "Batch is full",
            BatchError::Timeout => "Batch timeout",
            BatchError::ChannelClosed => "Channel closed",
            BatchError::ExecutionFailed(detail) => {
                return write!(f, "Execution failed: {}", detail);
            }
        };
        f.write_str(message)
    }
}

/// Lets `BatchError` flow through `?` into `Box<dyn Error>` and similar.
impl std::error::Error for BatchError {}
138
139#[derive(Debug, Clone, Default, Serialize, Deserialize)]
141pub struct BatchStats {
142 pub inserts_received: u64,
144 pub rows_received: u64,
146 pub batches_flushed: u64,
148 pub rows_inserted: u64,
150 pub avg_batch_size: f64,
152 pub avg_wait_time_ms: f64,
154 pub avg_execution_time_ms: f64,
156 pub size_triggered_flushes: u64,
158 pub time_triggered_flushes: u64,
160 pub flush_failures: u64,
163}
164
165struct PendingBatch {
167 requests: Vec<InsertRequest>,
169 row_count: usize,
171 byte_count: usize,
173 first_submitted: Instant,
175}
176
177impl PendingBatch {
178 fn new() -> Self {
179 Self {
180 requests: Vec::with_capacity(100),
181 row_count: 0,
182 byte_count: 0,
183 first_submitted: Instant::now(),
184 }
185 }
186
187 fn add(&mut self, request: InsertRequest) {
188 let row_count = request.values.len();
189 let byte_estimate = request.original_sql.len();
190
191 if self.requests.is_empty() {
192 self.first_submitted = request.submitted_at;
193 }
194
195 self.row_count += row_count;
196 self.byte_count += byte_estimate;
197 self.requests.push(request);
198 }
199
200 fn is_empty(&self) -> bool {
201 self.requests.is_empty()
202 }
203
204 fn should_flush(&self, config: &BatchConfig) -> bool {
205 self.row_count >= config.max_batch_size
206 || self.byte_count >= config.max_batch_bytes
207 || self.first_submitted.elapsed().as_millis() as u64 >= config.max_wait_ms
208 }
209
210 fn drain(&mut self) -> (Vec<InsertRequest>, usize) {
211 let row_count = self.row_count;
212 self.row_count = 0;
213 self.byte_count = 0;
214 (std::mem::take(&mut self.requests), row_count)
215 }
216}
217
218pub struct InsertBatcher {
222 config: BatchConfig,
224 pending: DashMap<TableId, PendingBatch>,
226 next_ticket_id: AtomicU64,
228 stats: Arc<parking_lot::RwLock<BatchStats>>,
230 shutdown: AtomicBool,
232 backend: Option<crate::backend::BackendConfig>,
237}
238
239struct SealedBatch {
242 requests: Vec<InsertRequest>,
243 row_count: usize,
244 sql: String,
245}
246
247impl InsertBatcher {
248 pub fn new(config: BatchConfig) -> Self {
250 Self {
251 config,
252 pending: DashMap::new(),
253 next_ticket_id: AtomicU64::new(1),
254 stats: Arc::new(parking_lot::RwLock::new(BatchStats::default())),
255 shutdown: AtomicBool::new(false),
256 backend: None,
257 }
258 }
259
260 pub fn with_backend(mut self, backend: crate::backend::BackendConfig) -> Self {
263 self.backend = Some(backend);
264 self
265 }
266
267 pub fn add(
274 self: &Arc<Self>,
275 table: String,
276 columns: Vec<String>,
277 values: Vec<Vec<String>>,
278 original_sql: String,
279 ) -> Result<BatchTicket, BatchError> {
280 if !self.config.enabled {
281 return Err(BatchError::Disabled);
282 }
283
284 if self.shutdown.load(Ordering::Relaxed) {
285 return Err(BatchError::ExecutionFailed("Batcher shutdown".to_string()));
286 }
287
288 if !self.config.batch_tables.is_empty() && !self.config.batch_tables.contains(&table) {
290 return Err(BatchError::Disabled);
291 }
292
293 let ticket_id = BatchTicketId(self.next_ticket_id.fetch_add(1, Ordering::Relaxed));
294 let (tx, rx) = oneshot::channel();
295
296 let row_count = values.len();
297
298 let request = InsertRequest {
299 table: table.clone(),
300 columns,
301 values,
302 original_sql,
303 submitted_at: Instant::now(),
304 response_tx: Some(tx),
305 };
306
307 {
309 let mut stats = self.stats.write();
310 stats.inserts_received += 1;
311 stats.rows_received += row_count as u64;
312 }
313
314 let should_flush = {
316 let mut batch = self
317 .pending
318 .entry(table.clone())
319 .or_insert_with(PendingBatch::new);
320 batch.add(request);
321 batch.should_flush(&self.config)
322 };
323
324 if should_flush {
328 if let Some(sealed) = self.seal(&table) {
329 let me = Arc::clone(self);
330 tokio::spawn(async move {
331 me.execute_sealed(sealed).await;
332 });
333 }
334 }
335
336 Ok(BatchTicket { id: ticket_id, rx })
337 }
338
339 fn seal(&self, table: &str) -> Option<SealedBatch> {
344 let (_, mut batch) = self.pending.remove(table)?;
345 if batch.is_empty() {
346 return None;
347 }
348 let (requests, row_count) = batch.drain();
349 let sql = self.combine_inserts(&requests);
350 Some(SealedBatch {
351 requests,
352 row_count,
353 sql,
354 })
355 }
356
357 async fn execute_sealed(&self, sealed: SealedBatch) {
363 let SealedBatch {
364 requests,
365 row_count,
366 sql,
367 } = sealed;
368 let execution_start = Instant::now();
369
370 let (success, error) = match &self.backend {
372 Some(cfg) => match crate::backend::BackendClient::connect(cfg).await {
373 Ok(mut client) => {
374 let outcome = client.execute(&sql).await;
375 client.close().await;
376 match outcome {
377 Ok(_tag) => (true, None),
378 Err(e) => (false, Some(format!("execute: {}", e))),
379 }
380 }
381 Err(e) => (false, Some(format!("connect: {}", e))),
382 },
383 None => (false, Some("no backend configured".to_string())),
384 };
385
386 let execution_time = execution_start.elapsed();
387
388 {
390 let mut stats = self.stats.write();
391 stats.batches_flushed += 1;
392 if success {
393 stats.rows_inserted += row_count as u64;
394 } else {
395 stats.flush_failures += 1;
396 }
397
398 if stats.batches_flushed == 1 {
399 stats.avg_batch_size = row_count as f64;
400 } else {
401 stats.avg_batch_size = stats.avg_batch_size * 0.9 + row_count as f64 * 0.1;
402 }
403
404 let exec_ms = execution_time.as_millis() as f64;
405 if stats.batches_flushed == 1 {
406 stats.avg_execution_time_ms = exec_ms;
407 } else {
408 stats.avg_execution_time_ms = stats.avg_execution_time_ms * 0.9 + exec_ms * 0.1;
409 }
410 }
411
412 for mut req in requests {
414 let wait_time = req
415 .submitted_at
416 .elapsed()
417 .checked_sub(execution_time)
418 .unwrap_or_default();
419
420 if let Some(tx) = req.response_tx.take() {
421 let _ = tx.send(BatchResult {
422 ticket_id: BatchTicketId(0), rows_inserted: if success { req.values.len() as u64 } else { 0 },
424 success,
425 error: error.clone(),
426 wait_time,
427 execution_time,
428 });
429 }
430 }
431 }
432
433 pub async fn flush_batch(&self, table: &str) {
436 if let Some(sealed) = self.seal(table) {
437 self.execute_sealed(sealed).await;
438 }
439 }
440
441 fn combine_inserts(&self, requests: &[InsertRequest]) -> String {
443 if requests.is_empty() {
444 return String::new();
445 }
446
447 let first = &requests[0];
448 let table = &first.table;
449 let columns = &first.columns;
450
451 let mut sql = format!("INSERT INTO {} ({}) VALUES ", table, columns.join(", "));
452
453 let mut value_parts: Vec<String> = Vec::new();
454
455 for req in requests {
456 for row in &req.values {
457 value_parts.push(format!("({})", row.join(", ")));
458 }
459 }
460
461 sql.push_str(&value_parts.join(", "));
462
463 sql
464 }
465
466 pub async fn flush_all(&self) {
468 let tables: Vec<TableId> = self.pending.iter().map(|r| r.key().clone()).collect();
469 for table in tables {
470 self.flush_batch(&table).await;
471 }
472 }
473
474 pub fn batch_size(&self, table: &str) -> usize {
476 self.pending.get(table).map(|b| b.row_count).unwrap_or(0)
477 }
478
479 pub fn stats(&self) -> BatchStats {
481 self.stats.read().clone()
482 }
483
484 pub async fn shutdown(&self) {
487 self.shutdown.store(true, Ordering::Release);
488 self.flush_all().await;
489 }
490
491 pub fn start_auto_flush(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
493 let interval = Duration::from_millis(self.config.max_wait_ms);
494
495 tokio::spawn(async move {
496 let mut interval_timer = tokio::time::interval(interval);
497
498 loop {
499 interval_timer.tick().await;
500
501 if self.shutdown.load(Ordering::Relaxed) {
502 break;
503 }
504
505 let tables: Vec<TableId> = self
507 .pending
508 .iter()
509 .filter(|r| {
510 r.first_submitted.elapsed().as_millis() as u64 >= self.config.max_wait_ms
511 })
512 .map(|r| r.key().clone())
513 .collect();
514
515 for table in tables {
516 self.flush_batch(&table).await;
517 self.stats.write().time_triggered_flushes += 1;
518 }
519 }
520 })
521 }
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// A single insert below every flush threshold stays pending for its table.
    #[tokio::test]
    async fn test_batch_add() {
        let batcher = Arc::new(InsertBatcher::new(BatchConfig::default()));

        let _ticket = batcher
            .add(
                "users".to_string(),
                vec!["id".to_string(), "name".to_string()],
                vec![vec!["1".to_string(), "'Alice'".to_string()]],
                "INSERT INTO users (id, name) VALUES (1, 'Alice')".to_string(),
            )
            .unwrap();

        assert_eq!(batcher.batch_size("users"), 1);
    }

    /// Reaching `max_batch_size` seals the batch synchronously inside `add`, so the
    /// pending row count drops to zero. Execution runs on a spawned task that this
    /// test does not await; with no backend configured it is not expected to succeed.
    #[tokio::test]
    async fn test_batch_flush_on_size() {
        let config = BatchConfig {
            max_batch_size: 2,
            ..Default::default()
        };
        let batcher = Arc::new(InsertBatcher::new(config));

        // First row: below the limit, stays pending.
        batcher
            .add(
                "users".to_string(),
                vec!["id".to_string()],
                vec![vec!["1".to_string()]],
                "INSERT INTO users VALUES (1)".to_string(),
            )
            .unwrap();

        assert_eq!(batcher.batch_size("users"), 1);

        // Second row: reaches max_batch_size, so the batch is sealed and removed.
        batcher
            .add(
                "users".to_string(),
                vec!["id".to_string()],
                vec![vec!["2".to_string()]],
                "INSERT INTO users VALUES (2)".to_string(),
            )
            .unwrap();

        assert_eq!(batcher.batch_size("users"), 0);
    }

    /// Two single-row requests with the same columns render into one multi-row INSERT.
    #[test]
    fn test_combine_inserts() {
        let batcher = InsertBatcher::new(BatchConfig::default());

        let requests = vec![
            InsertRequest {
                table: "users".to_string(),
                columns: vec!["id".to_string(), "name".to_string()],
                values: vec![vec!["1".to_string(), "'Alice'".to_string()]],
                original_sql: String::new(),
                submitted_at: Instant::now(),
                response_tx: None,
            },
            InsertRequest {
                table: "users".to_string(),
                columns: vec!["id".to_string(), "name".to_string()],
                values: vec![vec!["2".to_string(), "'Bob'".to_string()]],
                original_sql: String::new(),
                submitted_at: Instant::now(),
                response_tx: None,
            },
        ];

        let combined = batcher.combine_inserts(&requests);
        assert!(combined.contains("INSERT INTO users"));
        assert!(combined.contains("(1, 'Alice')"));
        assert!(combined.contains("(2, 'Bob')"));
    }

    /// `add` counts one insert and every row it carries. This is a plain #[test]
    /// (no Tokio runtime), which works only because the default limits keep this
    /// insert from triggering a spawned flush.
    #[test]
    fn test_batch_stats() {
        let batcher = Arc::new(InsertBatcher::new(BatchConfig::default()));

        batcher
            .add(
                "users".to_string(),
                vec!["id".to_string()],
                vec![vec!["1".to_string()], vec!["2".to_string()]],
                "INSERT INTO users VALUES (1), (2)".to_string(),
            )
            .unwrap();

        let stats = batcher.stats();
        assert_eq!(stats.inserts_received, 1);
        assert_eq!(stats.rows_received, 2);
    }

    /// End-to-end check against a real server.
    ///
    /// Skipped unless `HELIOS_LIVE_PG` is set to `host:port`. Credentials and database
    /// come from `HELIOS_LIVE_USER` / `HELIOS_LIVE_PASS` / `HELIOS_LIVE_DB`, falling back
    /// to the defaults below. The test recreates `batch_probe`, batches two rows, flushes
    /// explicitly, verifies the row count, then drops the table.
    #[tokio::test]
    async fn flush_executes_against_live_backend() {
        use crate::backend::{tls::default_client_config, BackendClient, BackendConfig, TlsMode};

        let addr = match std::env::var("HELIOS_LIVE_PG") {
            Ok(a) if !a.is_empty() => a,
            _ => {
                eprintln!("skipping flush_executes_against_live_backend: set HELIOS_LIVE_PG");
                return;
            }
        };
        let (host, port_s) = addr.rsplit_once(':').unwrap();
        let port: u16 = port_s.parse().unwrap();
        let user = std::env::var("HELIOS_LIVE_USER").unwrap_or_else(|_| "bench".into());
        let pass = std::env::var("HELIOS_LIVE_PASS").unwrap_or_else(|_| "benchpass".into());
        let db = std::env::var("HELIOS_LIVE_DB").unwrap_or_else(|_| "benchdb".into());

        let cfg = BackendConfig {
            host: host.to_string(),
            port,
            user,
            password: Some(pass),
            database: Some(db),
            application_name: Some("helios-batch-test".into()),
            tls_mode: TlsMode::Disable,
            connect_timeout: Duration::from_secs(5),
            query_timeout: Duration::from_secs(5),
            tls_config: default_client_config(),
        };

        // Start from a fresh, empty probe table.
        let mut seed = BackendClient::connect(&cfg).await.expect("connect seed");
        seed.execute("DROP TABLE IF EXISTS batch_probe").await.unwrap();
        seed.execute("CREATE TABLE batch_probe(id int, total numeric)")
            .await
            .unwrap();
        seed.close().await;

        // Size and wait limits far above what this test reaches, so nothing flushes
        // until the explicit `flush_batch` call below.
        let batcher = Arc::new(
            InsertBatcher::new(BatchConfig {
                max_batch_size: 1000,
                max_wait_ms: 60_000,
                ..Default::default()
            })
            .with_backend(cfg.clone()),
        );
        batcher
            .add(
                "batch_probe".to_string(),
                vec!["id".to_string(), "total".to_string()],
                vec![
                    vec!["1".to_string(), "99.99".to_string()],
                    vec!["2".to_string(), "12.50".to_string()],
                ],
                String::new(),
            )
            .unwrap();
        batcher.flush_batch("batch_probe").await;

        // Count what landed, then clean up (best effort) before asserting.
        let mut verify = BackendClient::connect(&cfg).await.expect("connect verify");
        let n = verify
            .query_scalar("SELECT count(*) AS n FROM batch_probe")
            .await
            .unwrap()
            .as_i64("n")
            .unwrap()
            .unwrap_or(0);
        let _ = verify.execute("DROP TABLE IF EXISTS batch_probe").await;
        verify.close().await;

        assert_eq!(n, 2, "expected 2 batched rows to land, found {}", n);
        assert_eq!(batcher.stats().rows_inserted, 2);
        assert_eq!(batcher.stats().flush_failures, 0);
    }
}