Skip to main content

heliosdb_proxy/
batch.rs

//! INSERT batching for HeliosProxy.
//!
//! Batches multiple INSERT statements into combined bulk operations for
//! improved throughput, reducing round-trips and enabling bulk ingestion.
5
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use tokio::sync::oneshot;
12
/// Key under which pending INSERTs are grouped: the table name exactly as it
/// was passed to `InsertBatcher::add`.
pub type TableId = String;
15
/// Opaque identifier attached to every [`BatchTicket`].
///
/// Values come from a per-batcher counter; `0` is used as a placeholder in
/// [`BatchResult::ticket_id`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct BatchTicketId(u64);
19
20/// Batch configuration
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct BatchConfig {
23    /// Enable INSERT batching
24    pub enabled: bool,
25    /// Maximum batch size (number of rows)
26    pub max_batch_size: usize,
27    /// Maximum batch wait time (ms) before flushing
28    pub max_wait_ms: u64,
29    /// Maximum memory per batch (bytes)
30    pub max_batch_bytes: usize,
31    /// Enable automatic flushing
32    pub auto_flush: bool,
33    /// Tables to batch (empty = all tables)
34    pub batch_tables: Vec<String>,
35}
36
37impl Default for BatchConfig {
38    fn default() -> Self {
39        Self {
40            enabled: true,
41            max_batch_size: 1000,
42            max_wait_ms: 10,
43            max_batch_bytes: 16 * 1024 * 1024, // 16MB
44            auto_flush: true,
45            batch_tables: Vec::new(), // Batch all tables by default
46        }
47    }
48}
49
50/// An individual INSERT request
51#[derive(Debug)]
52pub struct InsertRequest {
53    /// Table name
54    pub table: String,
55    /// Column names
56    pub columns: Vec<String>,
57    /// Row values (each inner vec is a row)
58    pub values: Vec<Vec<String>>,
59    /// Original SQL (for fallback)
60    pub original_sql: String,
61    /// Request timestamp
62    pub submitted_at: Instant,
63    /// Response channel
64    response_tx: Option<oneshot::Sender<BatchResult>>,
65}
66
67/// Result of a batch operation
68#[derive(Debug, Clone)]
69pub struct BatchResult {
70    /// Ticket ID
71    pub ticket_id: BatchTicketId,
72    /// Number of rows inserted
73    pub rows_inserted: u64,
74    /// Whether the batch succeeded
75    pub success: bool,
76    /// Error message if failed
77    pub error: Option<String>,
78    /// Time spent waiting in batch
79    pub wait_time: Duration,
80    /// Execution time
81    pub execution_time: Duration,
82}
83
84/// Ticket for awaiting batch completion
85pub struct BatchTicket {
86    id: BatchTicketId,
87    rx: oneshot::Receiver<BatchResult>,
88}
89
90impl BatchTicket {
91    /// Wait for the batch to complete
92    pub async fn wait(self) -> Result<BatchResult, BatchError> {
93        self.rx.await.map_err(|_| BatchError::ChannelClosed)
94    }
95
96    /// Wait with timeout
97    pub async fn wait_timeout(self, timeout: Duration) -> Result<BatchResult, BatchError> {
98        tokio::time::timeout(timeout, self.rx)
99            .await
100            .map_err(|_| BatchError::Timeout)?
101            .map_err(|_| BatchError::ChannelClosed)
102    }
103
104    /// Get the ticket ID
105    pub fn id(&self) -> BatchTicketId {
106        self.id
107    }
108}
109
/// Failures reported by `InsertBatcher` and [`BatchTicket`].
#[derive(Clone, Debug)]
pub enum BatchError {
    /// Batching is switched off, or the table is not in the configured allow-list.
    Disabled,
    /// The batch cannot accept more data.
    BatchFull,
    /// No result arrived before the caller's deadline.
    Timeout,
    /// The result channel was closed before a result was delivered.
    ChannelClosed,
    /// Executing the batch failed; carries a human-readable reason.
    ExecutionFailed(String),
}

impl std::fmt::Display for BatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Disabled => "Batching is disabled",
            Self::BatchFull => "Batch is full",
            Self::Timeout => "Batch timeout",
            Self::ChannelClosed => "Channel closed",
            Self::ExecutionFailed(reason) => return write!(f, "Execution failed: {}", reason),
        };
        f.write_str(message)
    }
}

impl std::error::Error for BatchError {}
138
139/// Statistics for batch operations
140#[derive(Debug, Clone, Default, Serialize, Deserialize)]
141pub struct BatchStats {
142    /// Total inserts received
143    pub inserts_received: u64,
144    /// Total rows received
145    pub rows_received: u64,
146    /// Total batches flushed
147    pub batches_flushed: u64,
148    /// Total rows inserted
149    pub rows_inserted: u64,
150    /// Average batch size
151    pub avg_batch_size: f64,
152    /// Average wait time (ms)
153    pub avg_wait_time_ms: f64,
154    /// Average execution time (ms)
155    pub avg_execution_time_ms: f64,
156    /// Batches flushed due to size limit
157    pub size_triggered_flushes: u64,
158    /// Batches flushed due to time limit
159    pub time_triggered_flushes: u64,
160}
161
162/// Pending batch for a table
163struct PendingBatch {
164    /// INSERT requests in this batch
165    requests: Vec<InsertRequest>,
166    /// Total rows in batch
167    row_count: usize,
168    /// Total bytes in batch (estimated)
169    byte_count: usize,
170    /// First request timestamp
171    first_submitted: Instant,
172}
173
174impl PendingBatch {
175    fn new() -> Self {
176        Self {
177            requests: Vec::with_capacity(100),
178            row_count: 0,
179            byte_count: 0,
180            first_submitted: Instant::now(),
181        }
182    }
183
184    fn add(&mut self, request: InsertRequest) {
185        let row_count = request.values.len();
186        let byte_estimate = request.original_sql.len();
187
188        if self.requests.is_empty() {
189            self.first_submitted = request.submitted_at;
190        }
191
192        self.row_count += row_count;
193        self.byte_count += byte_estimate;
194        self.requests.push(request);
195    }
196
197    fn is_empty(&self) -> bool {
198        self.requests.is_empty()
199    }
200
201    fn should_flush(&self, config: &BatchConfig) -> bool {
202        self.row_count >= config.max_batch_size ||
203        self.byte_count >= config.max_batch_bytes ||
204        self.first_submitted.elapsed().as_millis() as u64 >= config.max_wait_ms
205    }
206
207    fn drain(&mut self) -> (Vec<InsertRequest>, usize) {
208        let row_count = self.row_count;
209        self.row_count = 0;
210        self.byte_count = 0;
211        (std::mem::take(&mut self.requests), row_count)
212    }
213}
214
215/// INSERT Batcher
216///
217/// Batches INSERT statements for improved throughput.
218pub struct InsertBatcher {
219    /// Configuration
220    config: BatchConfig,
221    /// Pending batches per table
222    pending: DashMap<TableId, PendingBatch>,
223    /// Next ticket ID
224    next_ticket_id: AtomicU64,
225    /// Statistics
226    stats: Arc<parking_lot::RwLock<BatchStats>>,
227    /// Shutdown flag
228    shutdown: AtomicBool,
229}
230
231impl InsertBatcher {
232    /// Create a new INSERT batcher
233    pub fn new(config: BatchConfig) -> Self {
234        Self {
235            config,
236            pending: DashMap::new(),
237            next_ticket_id: AtomicU64::new(1),
238            stats: Arc::new(parking_lot::RwLock::new(BatchStats::default())),
239            shutdown: AtomicBool::new(false),
240        }
241    }
242
243    /// Add an INSERT to the batch
244    pub fn add(
245        &self,
246        table: String,
247        columns: Vec<String>,
248        values: Vec<Vec<String>>,
249        original_sql: String,
250    ) -> Result<BatchTicket, BatchError> {
251        if !self.config.enabled {
252            return Err(BatchError::Disabled);
253        }
254
255        if self.shutdown.load(Ordering::Relaxed) {
256            return Err(BatchError::ExecutionFailed("Batcher shutdown".to_string()));
257        }
258
259        // Check if table should be batched
260        if !self.config.batch_tables.is_empty() &&
261           !self.config.batch_tables.contains(&table)
262        {
263            return Err(BatchError::Disabled);
264        }
265
266        let ticket_id = BatchTicketId(self.next_ticket_id.fetch_add(1, Ordering::Relaxed));
267        let (tx, rx) = oneshot::channel();
268
269        let row_count = values.len();
270
271        let request = InsertRequest {
272            table: table.clone(),
273            columns,
274            values,
275            original_sql,
276            submitted_at: Instant::now(),
277            response_tx: Some(tx),
278        };
279
280        // Update statistics
281        {
282            let mut stats = self.stats.write();
283            stats.inserts_received += 1;
284            stats.rows_received += row_count as u64;
285        }
286
287        // Add to pending batch
288        let should_flush = {
289            let mut batch = self.pending.entry(table.clone()).or_insert_with(PendingBatch::new);
290            batch.add(request);
291            batch.should_flush(&self.config)
292        };
293
294        // Trigger flush if needed
295        if should_flush {
296            self.flush_batch(&table);
297        }
298
299        Ok(BatchTicket { id: ticket_id, rx })
300    }
301
302    /// Flush a batch for a table
303    pub fn flush_batch(&self, table: &str) {
304        if let Some((_, mut batch)) = self.pending.remove(table) {
305            if batch.is_empty() {
306                return;
307            }
308
309            let (requests, row_count) = batch.drain();
310            let execution_start = Instant::now();
311
312            // Combine into a single bulk INSERT
313            let _combined_sql = self.combine_inserts(&requests);
314
315            // Execute the combined INSERT
316            // In production, this would call the backend
317            let success = true; // Placeholder
318            let error: Option<String> = None;
319
320            let execution_time = execution_start.elapsed();
321
322            // Update statistics
323            {
324                let mut stats = self.stats.write();
325                stats.batches_flushed += 1;
326                stats.rows_inserted += row_count as u64;
327
328                // Update average batch size
329                if stats.batches_flushed == 1 {
330                    stats.avg_batch_size = row_count as f64;
331                } else {
332                    stats.avg_batch_size = stats.avg_batch_size * 0.9 + row_count as f64 * 0.1;
333                }
334
335                // Update average execution time
336                let exec_ms = execution_time.as_millis() as f64;
337                if stats.batches_flushed == 1 {
338                    stats.avg_execution_time_ms = exec_ms;
339                } else {
340                    stats.avg_execution_time_ms = stats.avg_execution_time_ms * 0.9 + exec_ms * 0.1;
341                }
342            }
343
344            // Send responses to all waiting requests
345            for mut req in requests {
346                let wait_time = req.submitted_at.elapsed() - execution_time;
347
348                if let Some(tx) = req.response_tx.take() {
349                    let _ = tx.send(BatchResult {
350                        ticket_id: BatchTicketId(0), // Individual tickets not tracked
351                        rows_inserted: req.values.len() as u64,
352                        success,
353                        error: error.clone(),
354                        wait_time,
355                        execution_time,
356                    });
357                }
358            }
359        }
360    }
361
362    /// Combine multiple INSERT requests into a single SQL statement
363    fn combine_inserts(&self, requests: &[InsertRequest]) -> String {
364        if requests.is_empty() {
365            return String::new();
366        }
367
368        let first = &requests[0];
369        let table = &first.table;
370        let columns = &first.columns;
371
372        let mut sql = format!(
373            "INSERT INTO {} ({}) VALUES ",
374            table,
375            columns.join(", ")
376        );
377
378        let mut value_parts: Vec<String> = Vec::new();
379
380        for req in requests {
381            for row in &req.values {
382                value_parts.push(format!("({})", row.join(", ")));
383            }
384        }
385
386        sql.push_str(&value_parts.join(", "));
387
388        sql
389    }
390
391    /// Flush all pending batches
392    pub fn flush_all(&self) {
393        let tables: Vec<TableId> = self.pending.iter().map(|r| r.key().clone()).collect();
394        for table in tables {
395            self.flush_batch(&table);
396        }
397    }
398
399    /// Get the current batch size for a table
400    pub fn batch_size(&self, table: &str) -> usize {
401        self.pending
402            .get(table)
403            .map(|b| b.row_count)
404            .unwrap_or(0)
405    }
406
407    /// Get statistics snapshot
408    pub fn stats(&self) -> BatchStats {
409        self.stats.read().clone()
410    }
411
412    /// Shutdown the batcher
413    pub fn shutdown(&self) {
414        self.shutdown.store(true, Ordering::Release);
415        self.flush_all();
416    }
417
418    /// Start auto-flush background task
419    pub fn start_auto_flush(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
420        let interval = Duration::from_millis(self.config.max_wait_ms);
421
422        tokio::spawn(async move {
423            let mut interval_timer = tokio::time::interval(interval);
424
425            loop {
426                interval_timer.tick().await;
427
428                if self.shutdown.load(Ordering::Relaxed) {
429                    break;
430                }
431
432                // Check each batch for timeout
433                let tables: Vec<TableId> = self.pending
434                    .iter()
435                    .filter(|r| {
436                        r.first_submitted.elapsed().as_millis() as u64 >= self.config.max_wait_ms
437                    })
438                    .map(|r| r.key().clone())
439                    .collect();
440
441                for table in tables {
442                    self.flush_batch(&table);
443                    self.stats.write().time_triggered_flushes += 1;
444                }
445            }
446        })
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Queues a one-row, one-column INSERT into `users`.
    fn add_user(batcher: &InsertBatcher, id: u32) -> Result<BatchTicket, BatchError> {
        batcher.add(
            "users".to_string(),
            vec!["id".to_string()],
            vec![vec![id.to_string()]],
            format!("INSERT INTO users VALUES ({})", id),
        )
    }

    #[tokio::test]
    async fn test_batch_add() {
        let batcher = InsertBatcher::new(BatchConfig::default());

        let ticket = batcher.add(
            "users".to_string(),
            vec!["id".to_string(), "name".to_string()],
            vec![vec!["1".to_string(), "'Alice'".to_string()]],
            "INSERT INTO users (id, name) VALUES (1, 'Alice')".to_string(),
        ).unwrap();

        // The first ticket handed out by a fresh batcher gets id 1.
        assert_eq!(ticket.id(), BatchTicketId(1));
        assert_eq!(batcher.batch_size("users"), 1);
    }

    #[tokio::test]
    async fn test_batch_flush_on_size() {
        let config = BatchConfig {
            max_batch_size: 2,
            ..Default::default()
        };
        let batcher = InsertBatcher::new(config);

        // Add first INSERT
        batcher.add(
            "users".to_string(),
            vec!["id".to_string()],
            vec![vec!["1".to_string()]],
            "INSERT INTO users VALUES (1)".to_string(),
        ).unwrap();

        assert_eq!(batcher.batch_size("users"), 1);

        // Add second INSERT - should trigger flush
        batcher.add(
            "users".to_string(),
            vec!["id".to_string()],
            vec![vec!["2".to_string()]],
            "INSERT INTO users VALUES (2)".to_string(),
        ).unwrap();

        // Batch should be flushed
        assert_eq!(batcher.batch_size("users"), 0);
    }

    #[tokio::test]
    async fn test_size_flush_resolves_tickets() {
        let config = BatchConfig {
            max_batch_size: 2,
            ..Default::default()
        };
        let batcher = InsertBatcher::new(config);

        let first = add_user(&batcher, 1).unwrap();
        let second = add_user(&batcher, 2).unwrap();

        // Timeouts keep the test from hanging if a ticket were never resolved.
        let first = first.wait_timeout(Duration::from_secs(1)).await.unwrap();
        let second = second.wait_timeout(Duration::from_secs(1)).await.unwrap();
        assert!(first.success && second.success);
        assert!(first.error.is_none());
        assert_eq!(first.rows_inserted, 1);
        assert_eq!(second.rows_inserted, 1);

        let stats = batcher.stats();
        assert_eq!(stats.batches_flushed, 1);
        assert_eq!(stats.rows_inserted, 2);
    }

    #[tokio::test]
    async fn test_flush_all_drains_pending() {
        let batcher = InsertBatcher::new(BatchConfig::default());
        let ticket = add_user(&batcher, 7).unwrap();
        assert_eq!(batcher.batch_size("users"), 1);

        batcher.flush_all();

        assert_eq!(batcher.batch_size("users"), 0);
        let result = ticket.wait_timeout(Duration::from_secs(1)).await.unwrap();
        assert!(result.success);
        assert_eq!(result.rows_inserted, 1);
    }

    #[test]
    fn test_disabled_batching_rejects_inserts() {
        let config = BatchConfig {
            enabled: false,
            ..Default::default()
        };
        let batcher = InsertBatcher::new(config);

        assert!(matches!(add_user(&batcher, 1), Err(BatchError::Disabled)));
        assert_eq!(batcher.stats().inserts_received, 0);
    }

    #[test]
    fn test_unlisted_table_is_not_batched() {
        let config = BatchConfig {
            batch_tables: vec!["orders".to_string()],
            ..Default::default()
        };
        let batcher = InsertBatcher::new(config);

        assert!(matches!(add_user(&batcher, 1), Err(BatchError::Disabled)));
        assert_eq!(batcher.batch_size("users"), 0);
    }

    #[test]
    fn test_shutdown_rejects_new_inserts() {
        let batcher = InsertBatcher::new(BatchConfig::default());
        batcher.shutdown();

        assert!(matches!(
            add_user(&batcher, 1),
            Err(BatchError::ExecutionFailed(_))
        ));
    }

    #[test]
    fn test_ticket_ids_are_sequential() {
        let batcher = InsertBatcher::new(BatchConfig::default());
        let a = add_user(&batcher, 1).unwrap();
        let b = add_user(&batcher, 2).unwrap();

        assert_eq!(a.id(), BatchTicketId(1));
        assert_eq!(b.id(), BatchTicketId(2));
    }

    #[test]
    fn test_combine_inserts() {
        let batcher = InsertBatcher::new(BatchConfig::default());

        let requests = vec![
            InsertRequest {
                table: "users".to_string(),
                columns: vec!["id".to_string(), "name".to_string()],
                values: vec![vec!["1".to_string(), "'Alice'".to_string()]],
                original_sql: String::new(),
                submitted_at: Instant::now(),
                response_tx: None,
            },
            InsertRequest {
                table: "users".to_string(),
                columns: vec!["id".to_string(), "name".to_string()],
                values: vec![vec!["2".to_string(), "'Bob'".to_string()]],
                original_sql: String::new(),
                submitted_at: Instant::now(),
                response_tx: None,
            },
        ];

        let combined = batcher.combine_inserts(&requests);
        assert!(combined.contains("INSERT INTO users"));
        assert!(combined.contains("(1, 'Alice')"));
        assert!(combined.contains("(2, 'Bob')"));
    }

    #[test]
    fn test_batch_stats() {
        let batcher = InsertBatcher::new(BatchConfig::default());

        batcher.add(
            "users".to_string(),
            vec!["id".to_string()],
            vec![vec!["1".to_string()], vec!["2".to_string()]],
            "INSERT INTO users VALUES (1), (2)".to_string(),
        ).unwrap();

        let stats = batcher.stats();
        assert_eq!(stats.inserts_received, 1);
        assert_eq!(stats.rows_received, 2);
    }
}
542}