// sql_splitter/duckdb/batch.rs
//! Batch manager for DuckDB bulk loading.
//!
//! This module provides efficient batched insertion of rows into DuckDB by
//! coalescing parsed rows into single multi-row `INSERT ... VALUES` statements
//! instead of executing each original INSERT statement individually.

6use crate::parser::ParsedValue;
7use ahash::AHashMap;
8use anyhow::Result;
9use duckdb::Connection;
10
11use super::ImportStats;
12
/// Maximum rows to accumulate per batch before flushing.
///
/// Used by [`InsertBatch::should_flush`]. Note that `BatchManager` compares
/// against its own configurable `max_rows_per_batch` field instead.
pub const MAX_ROWS_PER_BATCH: usize = 10_000;
15
/// A batch of rows accumulated for a single target table.
///
/// Rows from multiple original INSERT statements are coalesced here. The
/// original statement text is retained so a failed bulk insert can be
/// retried statement-by-statement (see `fallback_execute`).
#[derive(Debug)]
pub struct InsertBatch {
    /// Target table name
    pub table: String,
    /// Column list if explicitly specified in the original INSERT
    pub columns: Option<Vec<String>>,
    /// Accumulated rows (each row is a Vec of ParsedValue)
    pub rows: Vec<Vec<ParsedValue>>,
    /// Original SQL statements, kept for fallback execution
    pub statements: Vec<String>,
    /// Number of rows contributed by each statement (parallel to `statements`)
    pub rows_per_statement: Vec<usize>,
}
30
31impl InsertBatch {
32    /// Create a new batch for a table
33    pub fn new(table: String, columns: Option<Vec<String>>) -> Self {
34        Self {
35            table,
36            columns,
37            rows: Vec::new(),
38            statements: Vec::new(),
39            rows_per_statement: Vec::new(),
40        }
41    }
42
43    /// Check if batch is ready to flush
44    pub fn should_flush(&self) -> bool {
45        self.rows.len() >= MAX_ROWS_PER_BATCH
46    }
47
48    /// Total number of rows in batch
49    pub fn row_count(&self) -> usize {
50        self.rows.len()
51    }
52
53    /// Clear the batch
54    pub fn clear(&mut self) {
55        self.rows.clear();
56        self.statements.clear();
57        self.rows_per_statement.clear();
58    }
59}
60
/// Batch key: (table_name, column_layout).
///
/// Keeping the column list in the key allows distinguishing between INSERTs
/// that target the same table with different column orderings — those cannot
/// be merged into one multi-row VALUES list.
type BatchKey = (String, Option<Vec<String>>);
65
/// Manages batched INSERT operations for multiple tables.
pub struct BatchManager {
    /// Active batches keyed by (table, columns)
    batches: AHashMap<BatchKey, InsertBatch>,
    /// Flush threshold: a batch is handed back to the caller once it holds
    /// at least this many rows
    max_rows_per_batch: usize,
}
73
74impl BatchManager {
75    /// Create a new batch manager
76    pub fn new(max_rows_per_batch: usize) -> Self {
77        Self {
78            batches: AHashMap::new(),
79            max_rows_per_batch,
80        }
81    }
82
83    /// Queue rows for insertion, returning a batch if it's ready to flush
84    pub fn queue_insert(
85        &mut self,
86        table: &str,
87        columns: Option<Vec<String>>,
88        rows: Vec<Vec<ParsedValue>>,
89        original_sql: String,
90    ) -> Option<InsertBatch> {
91        let row_count = rows.len();
92        let key = (table.to_string(), columns.clone());
93
94        let batch = self
95            .batches
96            .entry(key)
97            .or_insert_with(|| InsertBatch::new(table.to_string(), columns));
98
99        batch.rows.extend(rows);
100        batch.statements.push(original_sql);
101        batch.rows_per_statement.push(row_count);
102
103        // Check if we need to flush
104        if batch.rows.len() >= self.max_rows_per_batch {
105            // Take the batch out and return it
106            let key = (table.to_string(), batch.columns.clone());
107            self.batches.remove(&key)
108        } else {
109            None
110        }
111    }
112
113    /// Get any batches that are ready to flush
114    pub fn get_ready_batches(&mut self) -> Vec<InsertBatch> {
115        let mut ready = Vec::new();
116        let mut to_remove = Vec::new();
117
118        for (key, batch) in &self.batches {
119            if batch.rows.len() >= self.max_rows_per_batch {
120                to_remove.push(key.clone());
121            }
122        }
123
124        for key in to_remove {
125            if let Some(batch) = self.batches.remove(&key) {
126                ready.push(batch);
127            }
128        }
129
130        ready
131    }
132
133    /// Flush all remaining batches
134    pub fn drain_all(&mut self) -> Vec<InsertBatch> {
135        self.batches.drain().map(|(_, batch)| batch).collect()
136    }
137
138    /// Check if there are any pending batches
139    pub fn has_pending(&self) -> bool {
140        !self.batches.is_empty()
141    }
142}
143
144/// Format a ParsedValue for SQL insertion
145fn format_value_for_sql(value: &ParsedValue) -> String {
146    match value {
147        ParsedValue::Null => "NULL".to_string(),
148        ParsedValue::Integer(n) => n.to_string(),
149        ParsedValue::BigInteger(n) => n.to_string(),
150        ParsedValue::String { value } => {
151            // Escape single quotes by doubling them (SQL standard)
152            let escaped = value.replace('\'', "''");
153            format!("'{}'", escaped)
154        }
155        ParsedValue::Hex(bytes) => {
156            // Convert to hex string for DuckDB
157            let hex: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
158            format!("x'{}'", hex)
159        }
160        ParsedValue::Other(raw) => {
161            let s = String::from_utf8_lossy(raw);
162            // Try to parse as float
163            if s.parse::<f64>().is_ok() {
164                s.to_string()
165            } else {
166                // Treat as text
167                let escaped = s.replace('\'', "''");
168                format!("'{}'", escaped)
169            }
170        }
171    }
172}
173
174/// Generate a batched INSERT statement from parsed values
175fn generate_batch_insert(
176    table: &str,
177    columns: &Option<Vec<String>>,
178    rows: &[Vec<ParsedValue>],
179) -> String {
180    if rows.is_empty() {
181        return String::new();
182    }
183
184    let mut sql = format!("INSERT INTO \"{}\"", table);
185
186    // Add column list if specified
187    if let Some(cols) = columns {
188        sql.push_str(" (");
189        for (i, col) in cols.iter().enumerate() {
190            if i > 0 {
191                sql.push_str(", ");
192            }
193            sql.push('"');
194            sql.push_str(col);
195            sql.push('"');
196        }
197        sql.push(')');
198    }
199
200    sql.push_str(" VALUES\n");
201
202    for (i, row) in rows.iter().enumerate() {
203        if i > 0 {
204            sql.push_str(",\n");
205        }
206        sql.push('(');
207        for (j, value) in row.iter().enumerate() {
208            if j > 0 {
209                sql.push_str(", ");
210            }
211            sql.push_str(&format_value_for_sql(value));
212        }
213        sql.push(')');
214    }
215    sql.push(';');
216
217    sql
218}
219
/// Flush a batch by executing one multi-row INSERT, falling back to replaying
/// the original statements one-by-one if the bulk insert fails.
///
/// NOTE(review): despite the module docs mentioning the Appender API, this
/// path generates a batched `INSERT ... VALUES` statement via
/// `try_batch_insert` — it does not use an Appender.
///
/// Tables found to be missing are recorded in `failed_tables`, and later
/// batches for those tables are dropped without touching the database.
/// The batch is cleared in every outcome.
pub fn flush_batch(
    conn: &Connection,
    batch: &mut InsertBatch,
    stats: &mut ImportStats,
    failed_tables: &mut std::collections::HashSet<String>,
) -> Result<()> {
    // Nothing accumulated: nothing to do.
    if batch.rows.is_empty() {
        return Ok(());
    }

    // Skip tables we already know don't exist.
    if failed_tables.contains(&batch.table) {
        batch.clear();
        return Ok(());
    }

    // Try the fast path with a single batched INSERT statement.
    match try_batch_insert(conn, batch, stats) {
        Ok(true) => {
            // Success via batched INSERT.
            batch.clear();
            Ok(())
        }
        Ok(false) => {
            // Table doesn't exist (or similar non-recoverable error):
            // remember it so future batches for this table are skipped.
            failed_tables.insert(batch.table.clone());
            batch.clear();
            Ok(())
        }
        Err(_) => {
            // Batched INSERT failed (constraint violation, type mismatch, etc.).
            // Fall back to per-statement execution of the original SQL.
            fallback_execute(conn, batch, stats)?;
            batch.clear();
            Ok(())
        }
    }
}
259
260/// Try to insert using batched SQL execution, returns Ok(true) on success,
261/// Ok(false) if table doesn't exist, Err on constraint/type errors
262fn try_batch_insert(
263    conn: &Connection,
264    batch: &InsertBatch,
265    stats: &mut ImportStats,
266) -> Result<bool> {
267    // Generate a single batched INSERT statement
268    let batch_sql = generate_batch_insert(&batch.table, &batch.columns, &batch.rows);
269    if batch_sql.is_empty() {
270        return Ok(true);
271    }
272
273    // Execute the batched INSERT (within the loader's transaction context)
274    match conn.execute(&batch_sql, []) {
275        Ok(_) => {
276            stats.insert_statements += batch.statements.len();
277            stats.rows_inserted += batch.rows.len() as u64;
278            Ok(true)
279        }
280        Err(e) => {
281            let err_str = e.to_string();
282            // Check if it's a "table not found" error
283            if err_str.contains("does not exist") || err_str.contains("not found") {
284                return Ok(false);
285            }
286            Err(e.into())
287        }
288    }
289}
290
291/// Fallback: execute original SQL statements one by one
292fn fallback_execute(conn: &Connection, batch: &InsertBatch, stats: &mut ImportStats) -> Result<()> {
293    for stmt in &batch.statements {
294        match conn.execute(stmt, []) {
295            Ok(_) => {
296                stats.insert_statements += 1;
297                stats.rows_inserted += count_insert_rows(stmt);
298            }
299            Err(e) => {
300                if stats.warnings.len() < 100 {
301                    stats.warnings.push(format!(
302                        "Failed INSERT for {} in fallback: {}",
303                        batch.table, e
304                    ));
305                }
306                stats.statements_skipped += 1;
307            }
308        }
309    }
310    Ok(())
311}
312
/// Count rows in an INSERT statement (simple heuristic).
///
/// Counts top-level parenthesized tuples after the (case-insensitive)
/// `VALUES` keyword, skipping parentheses inside single-quoted strings.
/// Returns 1 when no `VALUES` clause is found.
///
/// Fixes over the previous version:
/// - `VALUES` is located on the raw bytes with an ASCII case-insensitive
///   scan. Searching `sql.to_uppercase()` and slicing the original string
///   with that offset was unsound: `to_uppercase` can change byte length
///   for non-ASCII input (e.g. 'ß' → "SS"), mis-slicing or panicking on a
///   char boundary.
/// - A quote preceded by an even number of backslashes (e.g. `'a\\'`) now
///   correctly terminates the string; the old `prev_char != '\\'` check
///   treated the escaped backslash as escaping the quote.
fn count_insert_rows(sql: &str) -> u64 {
    // Find "VALUES" case-insensitively without allocating a recased copy;
    // byte offsets are valid because the needle is pure ASCII.
    let values_pos = sql
        .as_bytes()
        .windows(6)
        .position(|w| w.eq_ignore_ascii_case(b"VALUES"));

    let pos = match values_pos {
        Some(p) => p,
        // No VALUES clause: assume a single-row statement.
        None => return 1,
    };

    // pos + 6 is a char boundary since "VALUES" is ASCII.
    let after_values = &sql[pos + 6..];
    let mut count = 0u64;
    let mut depth: i32 = 0;
    let mut in_string = false;
    // Consecutive backslashes immediately before the current char; a quote
    // is escaped only when preceded by an ODD number of them.
    let mut backslash_run = 0usize;

    for c in after_values.chars() {
        if c == '\'' && backslash_run % 2 == 0 {
            in_string = !in_string;
        }
        if !in_string {
            if c == '(' {
                if depth == 0 {
                    // A new top-level tuple begins: one more row.
                    count += 1;
                }
                depth += 1;
            } else if c == ')' {
                // saturating_sub tolerates unbalanced input.
                depth = depth.saturating_sub(1);
            }
        }
        backslash_run = if c == '\\' { backslash_run + 1 } else { 0 };
    }
    count
}
343
#[cfg(test)]
mod tests {
    use super::*;

    // Queuing below the threshold must keep the batch pending.
    #[test]
    fn test_batch_manager_queue() {
        let mut mgr = BatchManager::new(100);

        let rows = vec![vec![
            ParsedValue::Integer(1),
            ParsedValue::String {
                value: "test".to_string(),
            },
        ]];

        let result = mgr.queue_insert(
            "users",
            None,
            rows,
            "INSERT INTO users VALUES (1, 'test')".to_string(),
        );
        assert!(result.is_none()); // Not ready yet
        assert!(mgr.has_pending());
    }

    // Crossing the threshold must hand the whole accumulated batch back.
    #[test]
    fn test_batch_manager_flush_threshold() {
        let mut mgr = BatchManager::new(2);

        let rows1 = vec![vec![ParsedValue::Integer(1)]];
        let rows2 = vec![vec![ParsedValue::Integer(2)], vec![ParsedValue::Integer(3)]];

        mgr.queue_insert("test", None, rows1, "SQL1".to_string());
        let result = mgr.queue_insert("test", None, rows2, "SQL2".to_string());

        assert!(result.is_some());
        let batch = result.unwrap();
        assert_eq!(batch.row_count(), 3);
    }

    // Row counting: single row, multi-row, and parens inside string literals.
    #[test]
    fn test_count_insert_rows() {
        assert_eq!(count_insert_rows("INSERT INTO t VALUES (1)"), 1);
        assert_eq!(count_insert_rows("INSERT INTO t VALUES (1), (2), (3)"), 3);
        assert_eq!(
            count_insert_rows("INSERT INTO t VALUES (1, 'a(b)'), (2, 'c')"),
            2
        );
    }

    // Generated SQL must quote identifiers and list columns in order.
    #[test]
    fn test_generate_batch_insert_with_columns() {
        let rows = vec![
            vec![
                ParsedValue::String {
                    value: "alice".to_string(),
                },
                ParsedValue::Integer(1),
            ],
            vec![
                ParsedValue::String {
                    value: "bob".to_string(),
                },
                ParsedValue::Integer(2),
            ],
        ];
        let columns = Some(vec!["name".to_string(), "id".to_string()]);
        let sql = generate_batch_insert("users", &columns, &rows);
        assert!(sql.contains("INSERT INTO \"users\" (\"name\", \"id\") VALUES"));
        assert!(sql.contains("'alice'"));
        assert!(sql.contains("'bob'"));
    }

    // Without a column list, only the table name and VALUES are emitted.
    #[test]
    fn test_generate_batch_insert_without_columns() {
        let rows = vec![vec![
            ParsedValue::Integer(1),
            ParsedValue::String {
                value: "test".to_string(),
            },
        ]];
        let sql = generate_batch_insert("test", &None, &rows);
        assert_eq!(sql, "INSERT INTO \"test\" VALUES\n(1, 'test');");
    }
}