Skip to main content

heliosdb_proxy/
batch.rs

1//! INSERT Batching for HeliosProxy
2//!
3//! Batches multiple INSERT statements into combined bulk operations for
4//! improved throughput. Reduces round-trips and enables lock-free bulk ingestion.
5
6use dashmap::DashMap;
7use serde::{Deserialize, Serialize};
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::sync::oneshot;
12
13/// Table identifier
14pub type TableId = String;
15
16/// Batch ticket ID
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
18pub struct BatchTicketId(u64);
19
20/// Batch configuration
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct BatchConfig {
23    /// Enable INSERT batching
24    pub enabled: bool,
25    /// Maximum batch size (number of rows)
26    pub max_batch_size: usize,
27    /// Maximum batch wait time (ms) before flushing
28    pub max_wait_ms: u64,
29    /// Maximum memory per batch (bytes)
30    pub max_batch_bytes: usize,
31    /// Enable automatic flushing
32    pub auto_flush: bool,
33    /// Tables to batch (empty = all tables)
34    pub batch_tables: Vec<String>,
35}
36
37impl Default for BatchConfig {
38    fn default() -> Self {
39        Self {
40            enabled: true,
41            max_batch_size: 1000,
42            max_wait_ms: 10,
43            max_batch_bytes: 16 * 1024 * 1024, // 16MB
44            auto_flush: true,
45            batch_tables: Vec::new(), // Batch all tables by default
46        }
47    }
48}
49
50/// An individual INSERT request
51#[derive(Debug)]
52pub struct InsertRequest {
53    /// Table name
54    pub table: String,
55    /// Column names
56    pub columns: Vec<String>,
57    /// Row values (each inner vec is a row)
58    pub values: Vec<Vec<String>>,
59    /// Original SQL (for fallback)
60    pub original_sql: String,
61    /// Request timestamp
62    pub submitted_at: Instant,
63    /// Response channel
64    response_tx: Option<oneshot::Sender<BatchResult>>,
65}
66
67/// Result of a batch operation
68#[derive(Debug, Clone)]
69pub struct BatchResult {
70    /// Ticket ID
71    pub ticket_id: BatchTicketId,
72    /// Number of rows inserted
73    pub rows_inserted: u64,
74    /// Whether the batch succeeded
75    pub success: bool,
76    /// Error message if failed
77    pub error: Option<String>,
78    /// Time spent waiting in batch
79    pub wait_time: Duration,
80    /// Execution time
81    pub execution_time: Duration,
82}
83
84/// Ticket for awaiting batch completion
85pub struct BatchTicket {
86    id: BatchTicketId,
87    rx: oneshot::Receiver<BatchResult>,
88}
89
90impl BatchTicket {
91    /// Wait for the batch to complete
92    pub async fn wait(self) -> Result<BatchResult, BatchError> {
93        self.rx.await.map_err(|_| BatchError::ChannelClosed)
94    }
95
96    /// Wait with timeout
97    pub async fn wait_timeout(self, timeout: Duration) -> Result<BatchResult, BatchError> {
98        tokio::time::timeout(timeout, self.rx)
99            .await
100            .map_err(|_| BatchError::Timeout)?
101            .map_err(|_| BatchError::ChannelClosed)
102    }
103
104    /// Get the ticket ID
105    pub fn id(&self) -> BatchTicketId {
106        self.id
107    }
108}
109
/// Errors reported by the INSERT batcher and its tickets.
#[derive(Debug, Clone)]
pub enum BatchError {
    /// Batching is switched off, or the table is not on the batch allow-list.
    Disabled,
    /// Batch is full.
    BatchFull,
    /// A ticket's `wait_timeout` elapsed before the batch finished.
    Timeout,
    /// The result channel closed before a result arrived.
    ChannelClosed,
    /// Execution failed (also used once the batcher has shut down); carries
    /// the reason.
    ExecutionFailed(String),
}

impl std::fmt::Display for BatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed messages are written directly; only the variant carrying a
        // reason needs formatting.
        let message = match self {
            BatchError::Disabled => "Batching is disabled",
            BatchError::BatchFull => "Batch is full",
            BatchError::Timeout => "Batch timeout",
            BatchError::ChannelClosed => "Channel closed",
            BatchError::ExecutionFailed(reason) => {
                return write!(f, "Execution failed: {}", reason);
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for BatchError {}
138
139/// Statistics for batch operations
140#[derive(Debug, Clone, Default, Serialize, Deserialize)]
141pub struct BatchStats {
142    /// Total inserts received
143    pub inserts_received: u64,
144    /// Total rows received
145    pub rows_received: u64,
146    /// Total batches flushed
147    pub batches_flushed: u64,
148    /// Total rows inserted
149    pub rows_inserted: u64,
150    /// Average batch size
151    pub avg_batch_size: f64,
152    /// Average wait time (ms)
153    pub avg_wait_time_ms: f64,
154    /// Average execution time (ms)
155    pub avg_execution_time_ms: f64,
156    /// Batches flushed due to size limit
157    pub size_triggered_flushes: u64,
158    /// Batches flushed due to time limit
159    pub time_triggered_flushes: u64,
160    /// Batches whose combined INSERT failed to execute (no backend,
161    /// connect error, or SQL error).
162    pub flush_failures: u64,
163}
164
165/// Pending batch for a table
166struct PendingBatch {
167    /// INSERT requests in this batch
168    requests: Vec<InsertRequest>,
169    /// Total rows in batch
170    row_count: usize,
171    /// Total bytes in batch (estimated)
172    byte_count: usize,
173    /// First request timestamp
174    first_submitted: Instant,
175}
176
177impl PendingBatch {
178    fn new() -> Self {
179        Self {
180            requests: Vec::with_capacity(100),
181            row_count: 0,
182            byte_count: 0,
183            first_submitted: Instant::now(),
184        }
185    }
186
187    fn add(&mut self, request: InsertRequest) {
188        let row_count = request.values.len();
189        let byte_estimate = request.original_sql.len();
190
191        if self.requests.is_empty() {
192            self.first_submitted = request.submitted_at;
193        }
194
195        self.row_count += row_count;
196        self.byte_count += byte_estimate;
197        self.requests.push(request);
198    }
199
200    fn is_empty(&self) -> bool {
201        self.requests.is_empty()
202    }
203
204    fn should_flush(&self, config: &BatchConfig) -> bool {
205        self.row_count >= config.max_batch_size
206            || self.byte_count >= config.max_batch_bytes
207            || self.first_submitted.elapsed().as_millis() as u64 >= config.max_wait_ms
208    }
209
210    fn drain(&mut self) -> (Vec<InsertRequest>, usize) {
211        let row_count = self.row_count;
212        self.row_count = 0;
213        self.byte_count = 0;
214        (std::mem::take(&mut self.requests), row_count)
215    }
216}
217
218/// INSERT Batcher
219///
220/// Batches INSERT statements for improved throughput.
221pub struct InsertBatcher {
222    /// Configuration
223    config: BatchConfig,
224    /// Pending batches per table
225    pending: DashMap<TableId, PendingBatch>,
226    /// Next ticket ID
227    next_ticket_id: AtomicU64,
228    /// Statistics
229    stats: Arc<parking_lot::RwLock<BatchStats>>,
230    /// Shutdown flag
231    shutdown: AtomicBool,
232    /// Backend the combined INSERTs actually execute against. When
233    /// `None` the batcher still batches/seals but cannot execute — it
234    /// reports an honest failure to waiters rather than claiming a
235    /// silent success.
236    backend: Option<crate::backend::BackendConfig>,
237}
238
239/// A batch sealed out of `pending` and ready to execute: the drained
240/// requests, the row count, and the combined bulk-INSERT SQL.
241struct SealedBatch {
242    requests: Vec<InsertRequest>,
243    row_count: usize,
244    sql: String,
245}
246
247impl InsertBatcher {
248    /// Create a new INSERT batcher
249    pub fn new(config: BatchConfig) -> Self {
250        Self {
251            config,
252            pending: DashMap::new(),
253            next_ticket_id: AtomicU64::new(1),
254            stats: Arc::new(parking_lot::RwLock::new(BatchStats::default())),
255            shutdown: AtomicBool::new(false),
256            backend: None,
257        }
258    }
259
260    /// Attach the backend the combined INSERTs execute against. Without
261    /// it the batcher seals and combines but cannot run the SQL.
262    pub fn with_backend(mut self, backend: crate::backend::BackendConfig) -> Self {
263        self.backend = Some(backend);
264        self
265    }
266
267    /// Add an INSERT to the batch.
268    ///
269    /// Takes `&Arc<Self>` so a size-triggered flush can seal the batch
270    /// synchronously (callers observing `batch_size` see it drop
271    /// immediately) and then execute the combined INSERT for real on a
272    /// spawned task. Must be called from within a Tokio runtime.
273    pub fn add(
274        self: &Arc<Self>,
275        table: String,
276        columns: Vec<String>,
277        values: Vec<Vec<String>>,
278        original_sql: String,
279    ) -> Result<BatchTicket, BatchError> {
280        if !self.config.enabled {
281            return Err(BatchError::Disabled);
282        }
283
284        if self.shutdown.load(Ordering::Relaxed) {
285            return Err(BatchError::ExecutionFailed("Batcher shutdown".to_string()));
286        }
287
288        // Check if table should be batched
289        if !self.config.batch_tables.is_empty() && !self.config.batch_tables.contains(&table) {
290            return Err(BatchError::Disabled);
291        }
292
293        let ticket_id = BatchTicketId(self.next_ticket_id.fetch_add(1, Ordering::Relaxed));
294        let (tx, rx) = oneshot::channel();
295
296        let row_count = values.len();
297
298        let request = InsertRequest {
299            table: table.clone(),
300            columns,
301            values,
302            original_sql,
303            submitted_at: Instant::now(),
304            response_tx: Some(tx),
305        };
306
307        // Update statistics
308        {
309            let mut stats = self.stats.write();
310            stats.inserts_received += 1;
311            stats.rows_received += row_count as u64;
312        }
313
314        // Add to pending batch
315        let should_flush = {
316            let mut batch = self
317                .pending
318                .entry(table.clone())
319                .or_insert_with(PendingBatch::new);
320            batch.add(request);
321            batch.should_flush(&self.config)
322        };
323
324        // Trigger flush if needed. Seal synchronously so `batch_size`
325        // reflects the flush immediately, then execute the combined
326        // INSERT for real on a spawned task.
327        if should_flush {
328            if let Some(sealed) = self.seal(&table) {
329                let me = Arc::clone(self);
330                tokio::spawn(async move {
331                    me.execute_sealed(sealed).await;
332                });
333            }
334        }
335
336        Ok(BatchTicket { id: ticket_id, rx })
337    }
338
339    /// Seal a table's pending batch: remove it from `pending`, drain it,
340    /// and combine the rows into one bulk INSERT. Synchronous, so
341    /// observers of `batch_size` see the flush immediately. Returns
342    /// `None` if there was nothing pending.
343    fn seal(&self, table: &str) -> Option<SealedBatch> {
344        let (_, mut batch) = self.pending.remove(table)?;
345        if batch.is_empty() {
346            return None;
347        }
348        let (requests, row_count) = batch.drain();
349        let sql = self.combine_inserts(&requests);
350        Some(SealedBatch {
351            requests,
352            row_count,
353            sql,
354        })
355    }
356
357    /// Execute a sealed batch's combined INSERT against the backend and
358    /// notify every waiting request with the real outcome. When no
359    /// backend is configured (or the connection/execution fails) the
360    /// waiters receive `success = false` with the reason — never a
361    /// fabricated success.
362    async fn execute_sealed(&self, sealed: SealedBatch) {
363        let SealedBatch {
364            requests,
365            row_count,
366            sql,
367        } = sealed;
368        let execution_start = Instant::now();
369
370        // Actually run the combined INSERT.
371        let (success, error) = match &self.backend {
372            Some(cfg) => match crate::backend::BackendClient::connect(cfg).await {
373                Ok(mut client) => {
374                    let outcome = client.execute(&sql).await;
375                    client.close().await;
376                    match outcome {
377                        Ok(_tag) => (true, None),
378                        Err(e) => (false, Some(format!("execute: {}", e))),
379                    }
380                }
381                Err(e) => (false, Some(format!("connect: {}", e))),
382            },
383            None => (false, Some("no backend configured".to_string())),
384        };
385
386        let execution_time = execution_start.elapsed();
387
388        // Update statistics (only count rows that actually landed).
389        {
390            let mut stats = self.stats.write();
391            stats.batches_flushed += 1;
392            if success {
393                stats.rows_inserted += row_count as u64;
394            } else {
395                stats.flush_failures += 1;
396            }
397
398            if stats.batches_flushed == 1 {
399                stats.avg_batch_size = row_count as f64;
400            } else {
401                stats.avg_batch_size = stats.avg_batch_size * 0.9 + row_count as f64 * 0.1;
402            }
403
404            let exec_ms = execution_time.as_millis() as f64;
405            if stats.batches_flushed == 1 {
406                stats.avg_execution_time_ms = exec_ms;
407            } else {
408                stats.avg_execution_time_ms = stats.avg_execution_time_ms * 0.9 + exec_ms * 0.1;
409            }
410        }
411
412        // Send responses to all waiting requests.
413        for mut req in requests {
414            let wait_time = req
415                .submitted_at
416                .elapsed()
417                .checked_sub(execution_time)
418                .unwrap_or_default();
419
420            if let Some(tx) = req.response_tx.take() {
421                let _ = tx.send(BatchResult {
422                    ticket_id: BatchTicketId(0), // Individual tickets not tracked
423                    rows_inserted: if success { req.values.len() as u64 } else { 0 },
424                    success,
425                    error: error.clone(),
426                    wait_time,
427                    execution_time,
428                });
429            }
430        }
431    }
432
433    /// Flush a single table's batch: seal it and execute the combined
434    /// INSERT against the backend.
435    pub async fn flush_batch(&self, table: &str) {
436        if let Some(sealed) = self.seal(table) {
437            self.execute_sealed(sealed).await;
438        }
439    }
440
441    /// Combine multiple INSERT requests into a single SQL statement
442    fn combine_inserts(&self, requests: &[InsertRequest]) -> String {
443        if requests.is_empty() {
444            return String::new();
445        }
446
447        let first = &requests[0];
448        let table = &first.table;
449        let columns = &first.columns;
450
451        let mut sql = format!("INSERT INTO {} ({}) VALUES ", table, columns.join(", "));
452
453        let mut value_parts: Vec<String> = Vec::new();
454
455        for req in requests {
456            for row in &req.values {
457                value_parts.push(format!("({})", row.join(", ")));
458            }
459        }
460
461        sql.push_str(&value_parts.join(", "));
462
463        sql
464    }
465
466    /// Flush all pending batches, executing each combined INSERT.
467    pub async fn flush_all(&self) {
468        let tables: Vec<TableId> = self.pending.iter().map(|r| r.key().clone()).collect();
469        for table in tables {
470            self.flush_batch(&table).await;
471        }
472    }
473
474    /// Get the current batch size for a table
475    pub fn batch_size(&self, table: &str) -> usize {
476        self.pending.get(table).map(|b| b.row_count).unwrap_or(0)
477    }
478
479    /// Get statistics snapshot
480    pub fn stats(&self) -> BatchStats {
481        self.stats.read().clone()
482    }
483
484    /// Shutdown the batcher: stop accepting work and flush whatever is
485    /// pending (executing each combined INSERT).
486    pub async fn shutdown(&self) {
487        self.shutdown.store(true, Ordering::Release);
488        self.flush_all().await;
489    }
490
491    /// Start auto-flush background task
492    pub fn start_auto_flush(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
493        let interval = Duration::from_millis(self.config.max_wait_ms);
494
495        tokio::spawn(async move {
496            let mut interval_timer = tokio::time::interval(interval);
497
498            loop {
499                interval_timer.tick().await;
500
501                if self.shutdown.load(Ordering::Relaxed) {
502                    break;
503                }
504
505                // Check each batch for timeout
506                let tables: Vec<TableId> = self
507                    .pending
508                    .iter()
509                    .filter(|r| {
510                        r.first_submitted.elapsed().as_millis() as u64 >= self.config.max_wait_ms
511                    })
512                    .map(|r| r.key().clone())
513                    .collect();
514
515                for table in tables {
516                    self.flush_batch(&table).await;
517                    self.stats.write().time_triggered_flushes += 1;
518                }
519            }
520        })
521    }
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Config whose time limit never fires during a test, so only explicit
    /// flushes (or the size limit) move rows out of `pending`.
    fn slow_config() -> BatchConfig {
        BatchConfig {
            max_wait_ms: 60_000,
            ..Default::default()
        }
    }

    fn add_one(batcher: &Arc<InsertBatcher>, table: &str) -> Result<BatchTicket, BatchError> {
        batcher.add(
            table.to_string(),
            vec!["id".to_string()],
            vec![vec!["1".to_string()]],
            format!("INSERT INTO {} VALUES (1)", table),
        )
    }

    #[tokio::test]
    async fn test_batch_add() {
        let batcher = Arc::new(InsertBatcher::new(BatchConfig::default()));

        let _ticket = batcher
            .add(
                "users".to_string(),
                vec!["id".to_string(), "name".to_string()],
                vec![vec!["1".to_string(), "'Alice'".to_string()]],
                "INSERT INTO users (id, name) VALUES (1, 'Alice')".to_string(),
            )
            .unwrap();

        assert_eq!(batcher.batch_size("users"), 1);
    }

    #[tokio::test]
    async fn test_batch_flush_on_size() {
        let config = BatchConfig {
            max_batch_size: 2,
            ..Default::default()
        };
        let batcher = Arc::new(InsertBatcher::new(config));

        // Add first INSERT
        batcher
            .add(
                "users".to_string(),
                vec!["id".to_string()],
                vec![vec!["1".to_string()]],
                "INSERT INTO users VALUES (1)".to_string(),
            )
            .unwrap();

        assert_eq!(batcher.batch_size("users"), 1);

        // Add second INSERT - should trigger flush
        batcher
            .add(
                "users".to_string(),
                vec!["id".to_string()],
                vec![vec!["2".to_string()]],
                "INSERT INTO users VALUES (2)".to_string(),
            )
            .unwrap();

        // Batch should be flushed
        assert_eq!(batcher.batch_size("users"), 0);
    }

    #[test]
    fn test_combine_inserts() {
        let batcher = InsertBatcher::new(BatchConfig::default());

        let requests = vec![
            InsertRequest {
                table: "users".to_string(),
                columns: vec!["id".to_string(), "name".to_string()],
                values: vec![vec!["1".to_string(), "'Alice'".to_string()]],
                original_sql: String::new(),
                submitted_at: Instant::now(),
                response_tx: None,
            },
            InsertRequest {
                table: "users".to_string(),
                columns: vec!["id".to_string(), "name".to_string()],
                values: vec![vec!["2".to_string(), "'Bob'".to_string()]],
                original_sql: String::new(),
                submitted_at: Instant::now(),
                response_tx: None,
            },
        ];

        let combined = batcher.combine_inserts(&requests);
        assert!(combined.contains("INSERT INTO users"));
        assert!(combined.contains("(1, 'Alice')"));
        assert!(combined.contains("(2, 'Bob')"));
    }

    #[test]
    fn test_batch_stats() {
        // Default config won't trigger a size/time flush for 2 rows, so no
        // task is spawned and this stays a plain (non-async) test.
        let batcher = Arc::new(InsertBatcher::new(BatchConfig::default()));

        batcher
            .add(
                "users".to_string(),
                vec!["id".to_string()],
                vec![vec!["1".to_string()], vec!["2".to_string()]],
                "INSERT INTO users VALUES (1), (2)".to_string(),
            )
            .unwrap();

        let stats = batcher.stats();
        assert_eq!(stats.inserts_received, 1);
        assert_eq!(stats.rows_received, 2);
    }

    /// A disabled batcher rejects every INSERT and queues nothing.
    #[tokio::test]
    async fn disabled_config_rejects_inserts() {
        let batcher = Arc::new(InsertBatcher::new(BatchConfig {
            enabled: false,
            ..slow_config()
        }));

        assert!(matches!(add_one(&batcher, "users"), Err(BatchError::Disabled)));
        assert_eq!(batcher.batch_size("users"), 0);
    }

    /// Only allow-listed tables are batched when `batch_tables` is non-empty.
    #[tokio::test]
    async fn tables_outside_allow_list_are_rejected() {
        let batcher = Arc::new(InsertBatcher::new(BatchConfig {
            batch_tables: vec!["orders".to_string()],
            ..slow_config()
        }));

        assert!(matches!(add_one(&batcher, "users"), Err(BatchError::Disabled)));
        assert!(add_one(&batcher, "orders").is_ok());
        assert_eq!(batcher.batch_size("orders"), 1);
    }

    /// Without a backend a flush reports an honest failure to the waiter and
    /// in the stats rather than pretending the rows landed.
    #[tokio::test]
    async fn flush_without_backend_reports_failure() {
        let batcher = Arc::new(InsertBatcher::new(slow_config()));

        let ticket = add_one(&batcher, "users").unwrap();
        batcher.flush_batch("users").await;

        let result = ticket.wait().await.unwrap();
        assert!(!result.success);
        assert_eq!(result.rows_inserted, 0);
        assert_eq!(result.error.as_deref(), Some("no backend configured"));

        let stats = batcher.stats();
        assert_eq!(stats.batches_flushed, 1);
        assert_eq!(stats.flush_failures, 1);
        assert_eq!(stats.rows_inserted, 0);
        assert_eq!(batcher.batch_size("users"), 0);
    }

    /// Shutdown flushes pending work (resolving its tickets) and rejects
    /// anything submitted afterwards.
    #[tokio::test]
    async fn shutdown_flushes_pending_and_rejects_new_work() {
        let batcher = Arc::new(InsertBatcher::new(slow_config()));

        let ticket = add_one(&batcher, "users").unwrap();
        batcher.shutdown().await;

        assert_eq!(batcher.batch_size("users"), 0);
        assert!(ticket.wait().await.is_ok());
        assert!(matches!(
            add_one(&batcher, "users"),
            Err(BatchError::ExecutionFailed(_))
        ));
    }

    /// Live proof that a flushed batch's combined INSERT actually
    /// executes against PostgreSQL and the rows land. Gated on
    /// `HELIOS_LIVE_PG` (`host:port`, e.g. `127.0.0.1:25433`); skips when
    /// unset so CI without a backend stays green. Before this change the
    /// flush discarded the SQL and faked success — this test would then
    /// find zero rows.
    #[tokio::test]
    async fn flush_executes_against_live_backend() {
        use crate::backend::{tls::default_client_config, BackendClient, BackendConfig, TlsMode};

        let addr = match std::env::var("HELIOS_LIVE_PG") {
            Ok(a) if !a.is_empty() => a,
            _ => {
                eprintln!("skipping flush_executes_against_live_backend: set HELIOS_LIVE_PG");
                return;
            }
        };
        let (host, port_s) = addr.rsplit_once(':').unwrap();
        let port: u16 = port_s.parse().unwrap();
        let user = std::env::var("HELIOS_LIVE_USER").unwrap_or_else(|_| "bench".into());
        let pass = std::env::var("HELIOS_LIVE_PASS").unwrap_or_else(|_| "benchpass".into());
        let db = std::env::var("HELIOS_LIVE_DB").unwrap_or_else(|_| "benchdb".into());

        let cfg = BackendConfig {
            host: host.to_string(),
            port,
            user,
            password: Some(pass),
            database: Some(db),
            application_name: Some("helios-batch-test".into()),
            tls_mode: TlsMode::Disable,
            connect_timeout: Duration::from_secs(5),
            query_timeout: Duration::from_secs(5),
            tls_config: default_client_config(),
        };

        // Seed a clean probe table.
        let mut seed = BackendClient::connect(&cfg).await.expect("connect seed");
        seed.execute("DROP TABLE IF EXISTS batch_probe").await.unwrap();
        seed.execute("CREATE TABLE batch_probe(id int, total numeric)")
            .await
            .unwrap();
        seed.close().await;

        // Large batch size so nothing auto-flushes; we flush explicitly
        // and await the real execution.
        let batcher = Arc::new(
            InsertBatcher::new(BatchConfig {
                max_batch_size: 1000,
                max_wait_ms: 60_000,
                ..Default::default()
            })
            .with_backend(cfg.clone()),
        );
        batcher
            .add(
                "batch_probe".to_string(),
                vec!["id".to_string(), "total".to_string()],
                vec![
                    vec!["1".to_string(), "99.99".to_string()],
                    vec!["2".to_string(), "12.50".to_string()],
                ],
                String::new(),
            )
            .unwrap();
        batcher.flush_batch("batch_probe").await;

        // The two rows must actually be in PostgreSQL now.
        let mut verify = BackendClient::connect(&cfg).await.expect("connect verify");
        let n = verify
            .query_scalar("SELECT count(*) AS n FROM batch_probe")
            .await
            .unwrap()
            .as_i64("n")
            .unwrap()
            .unwrap_or(0);
        let _ = verify.execute("DROP TABLE IF EXISTS batch_probe").await;
        verify.close().await;

        assert_eq!(n, 2, "expected 2 batched rows to land, found {}", n);
        assert_eq!(batcher.stats().rows_inserted, 2);
        assert_eq!(batcher.stats().flush_failures, 0);
    }
}
710}