1use crate::parser::ParsedValue;
7use ahash::AHashMap;
8use anyhow::Result;
9use duckdb::Connection;
10
11use super::ImportStats;
12
13pub const MAX_ROWS_PER_BATCH: usize = 10_000;
15
16#[derive(Debug)]
18pub struct InsertBatch {
19 pub table: String,
21 pub columns: Option<Vec<String>>,
23 pub rows: Vec<Vec<ParsedValue>>,
25 pub statements: Vec<String>,
27 pub rows_per_statement: Vec<usize>,
29}
30
31impl InsertBatch {
32 pub fn new(table: String, columns: Option<Vec<String>>) -> Self {
34 Self {
35 table,
36 columns,
37 rows: Vec::new(),
38 statements: Vec::new(),
39 rows_per_statement: Vec::new(),
40 }
41 }
42
43 pub fn should_flush(&self) -> bool {
45 self.rows.len() >= MAX_ROWS_PER_BATCH
46 }
47
48 pub fn row_count(&self) -> usize {
50 self.rows.len()
51 }
52
53 pub fn clear(&mut self) {
55 self.rows.clear();
56 self.statements.clear();
57 self.rows_per_statement.clear();
58 }
59}
60
61type BatchKey = (String, Option<Vec<String>>);
65
66pub struct BatchManager {
68 batches: AHashMap<BatchKey, InsertBatch>,
70 max_rows_per_batch: usize,
72}
73
74impl BatchManager {
75 pub fn new(max_rows_per_batch: usize) -> Self {
77 Self {
78 batches: AHashMap::new(),
79 max_rows_per_batch,
80 }
81 }
82
83 pub fn queue_insert(
85 &mut self,
86 table: &str,
87 columns: Option<Vec<String>>,
88 rows: Vec<Vec<ParsedValue>>,
89 original_sql: String,
90 ) -> Option<InsertBatch> {
91 let row_count = rows.len();
92 let key = (table.to_string(), columns.clone());
93
94 let batch = self
95 .batches
96 .entry(key)
97 .or_insert_with(|| InsertBatch::new(table.to_string(), columns));
98
99 batch.rows.extend(rows);
100 batch.statements.push(original_sql);
101 batch.rows_per_statement.push(row_count);
102
103 if batch.rows.len() >= self.max_rows_per_batch {
105 let key = (table.to_string(), batch.columns.clone());
107 self.batches.remove(&key)
108 } else {
109 None
110 }
111 }
112
113 pub fn get_ready_batches(&mut self) -> Vec<InsertBatch> {
115 let mut ready = Vec::new();
116 let mut to_remove = Vec::new();
117
118 for (key, batch) in &self.batches {
119 if batch.rows.len() >= self.max_rows_per_batch {
120 to_remove.push(key.clone());
121 }
122 }
123
124 for key in to_remove {
125 if let Some(batch) = self.batches.remove(&key) {
126 ready.push(batch);
127 }
128 }
129
130 ready
131 }
132
133 pub fn drain_all(&mut self) -> Vec<InsertBatch> {
135 self.batches.drain().map(|(_, batch)| batch).collect()
136 }
137
138 pub fn has_pending(&self) -> bool {
140 !self.batches.is_empty()
141 }
142}
143
144fn format_value_for_sql(value: &ParsedValue) -> String {
146 match value {
147 ParsedValue::Null => "NULL".to_string(),
148 ParsedValue::Integer(n) => n.to_string(),
149 ParsedValue::BigInteger(n) => n.to_string(),
150 ParsedValue::String { value } => {
151 let escaped = value.replace('\'', "''");
153 format!("'{}'", escaped)
154 }
155 ParsedValue::Hex(bytes) => {
156 let hex: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
158 format!("x'{}'", hex)
159 }
160 ParsedValue::Other(raw) => {
161 let s = String::from_utf8_lossy(raw);
162 if s.parse::<f64>().is_ok() {
164 s.to_string()
165 } else {
166 let escaped = s.replace('\'', "''");
168 format!("'{}'", escaped)
169 }
170 }
171 }
172}
173
174fn generate_batch_insert(
176 table: &str,
177 columns: &Option<Vec<String>>,
178 rows: &[Vec<ParsedValue>],
179) -> String {
180 if rows.is_empty() {
181 return String::new();
182 }
183
184 let mut sql = format!("INSERT INTO \"{}\"", table);
185
186 if let Some(cols) = columns {
188 sql.push_str(" (");
189 for (i, col) in cols.iter().enumerate() {
190 if i > 0 {
191 sql.push_str(", ");
192 }
193 sql.push('"');
194 sql.push_str(col);
195 sql.push('"');
196 }
197 sql.push(')');
198 }
199
200 sql.push_str(" VALUES\n");
201
202 for (i, row) in rows.iter().enumerate() {
203 if i > 0 {
204 sql.push_str(",\n");
205 }
206 sql.push('(');
207 for (j, value) in row.iter().enumerate() {
208 if j > 0 {
209 sql.push_str(", ");
210 }
211 sql.push_str(&format_value_for_sql(value));
212 }
213 sql.push(')');
214 }
215 sql.push(';');
216
217 sql
218}
219
220pub fn flush_batch(
222 conn: &Connection,
223 batch: &mut InsertBatch,
224 stats: &mut ImportStats,
225 failed_tables: &mut std::collections::HashSet<String>,
226) -> Result<()> {
227 if batch.rows.is_empty() {
228 return Ok(());
229 }
230
231 if failed_tables.contains(&batch.table) {
233 batch.clear();
234 return Ok(());
235 }
236
237 match try_batch_insert(conn, batch, stats) {
239 Ok(true) => {
240 batch.clear();
242 Ok(())
243 }
244 Ok(false) => {
245 failed_tables.insert(batch.table.clone());
247 batch.clear();
248 Ok(())
249 }
250 Err(_) => {
251 fallback_execute(conn, batch, stats)?;
254 batch.clear();
255 Ok(())
256 }
257 }
258}
259
260fn try_batch_insert(
263 conn: &Connection,
264 batch: &InsertBatch,
265 stats: &mut ImportStats,
266) -> Result<bool> {
267 let batch_sql = generate_batch_insert(&batch.table, &batch.columns, &batch.rows);
269 if batch_sql.is_empty() {
270 return Ok(true);
271 }
272
273 match conn.execute(&batch_sql, []) {
275 Ok(_) => {
276 stats.insert_statements += batch.statements.len();
277 stats.rows_inserted += batch.rows.len() as u64;
278 Ok(true)
279 }
280 Err(e) => {
281 let err_str = e.to_string();
282 if err_str.contains("does not exist") || err_str.contains("not found") {
284 return Ok(false);
285 }
286 Err(e.into())
287 }
288 }
289}
290
291fn fallback_execute(conn: &Connection, batch: &InsertBatch, stats: &mut ImportStats) -> Result<()> {
293 for stmt in &batch.statements {
294 match conn.execute(stmt, []) {
295 Ok(_) => {
296 stats.insert_statements += 1;
297 stats.rows_inserted += count_insert_rows(stmt);
298 }
299 Err(e) => {
300 if stats.warnings.len() < 100 {
301 stats.warnings.push(format!(
302 "Failed INSERT for {} in fallback: {}",
303 batch.table, e
304 ));
305 }
306 stats.statements_skipped += 1;
307 }
308 }
309 }
310 Ok(())
311}
312
/// Counts the row tuples in a single `INSERT ... VALUES` statement.
///
/// The `VALUES` keyword is found case-insensitively, only as a whole word, and only outside
/// quoted text, so names such as `myvalues` or `"values"` are not mistaken for it. After the
/// keyword, every `(` opened at nesting depth 0 starts a new row; parentheses inside nested
/// expressions or quoted text are ignored.
///
/// Quoting rules: `'...'`, `"..."` and `` `...` `` regions are skipped. Inside `'...'` and
/// `"..."` a backslash escapes the following character (MySQL-dump style). A doubled quote
/// (`''`) needs no special case, because it closes the region and immediately reopens it.
///
/// Returns `1` when no `VALUES` keyword is present (e.g. `INSERT ... SELECT`), because the row
/// count cannot be determined from the text alone.
fn count_insert_rows(sql: &str) -> u64 {
    const KEYWORD: &[u8] = b"VALUES";

    /// Bytes that may belong to an unquoted identifier. Bytes >= 0x80 are parts of
    /// multi-byte UTF-8 characters and are treated as identifier bytes.
    fn is_word_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || b >= 0x80
    }

    /// True when a whole-word, ASCII-case-insensitive `VALUES` starts at byte `at`.
    fn keyword_at(bytes: &[u8], at: usize) -> bool {
        let end = at + KEYWORD.len();
        matches!(bytes.get(at..end), Some(word) if word.eq_ignore_ascii_case(KEYWORD))
            && (at == 0 || !is_word_byte(bytes[at - 1]))
            && !matches!(bytes.get(end), Some(&next) if is_word_byte(next))
    }

    // Every delimiter examined is ASCII, so scanning raw bytes is UTF-8 safe: bytes of
    // multi-byte characters are all >= 0x80 and never match. Unlike `to_uppercase()`, which
    // can change byte lengths, this never slices the string at a shifted offset.
    let bytes = sql.as_bytes();
    let mut quote: Option<u8> = None;
    let mut escaped = false;
    let mut after_values = false;
    let mut depth: u32 = 0;
    let mut count = 0u64;
    let mut i = 0;

    while i < bytes.len() {
        let b = bytes[i];

        if let Some(q) = quote {
            if escaped {
                escaped = false;
            } else if b == b'\\' && q != b'`' {
                escaped = true;
            } else if b == q {
                quote = None;
            }
            i += 1;
            continue;
        }

        if !after_values && keyword_at(bytes, i) {
            after_values = true;
            i += KEYWORD.len();
            continue;
        }

        match b {
            b'\'' | b'"' | b'`' => quote = Some(b),
            b'(' if after_values => {
                if depth == 0 {
                    count += 1;
                }
                depth += 1;
            }
            b')' if after_values => depth = depth.saturating_sub(1),
            _ => {}
        }
        i += 1;
    }

    if after_values {
        count
    } else {
        1
    }
}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a parsed string value.
    fn text(s: &str) -> ParsedValue {
        ParsedValue::String {
            value: s.to_string(),
        }
    }

    #[test]
    fn test_batch_manager_queue() {
        let mut manager = BatchManager::new(100);

        let queued = manager.queue_insert(
            "users",
            None,
            vec![vec![ParsedValue::Integer(1), text("test")]],
            "INSERT INTO users VALUES (1, 'test')".to_string(),
        );

        // A single row is far below the threshold of 100, so it stays buffered.
        assert!(queued.is_none());
        assert!(manager.has_pending());
    }

    #[test]
    fn test_batch_manager_flush_threshold() {
        let mut manager = BatchManager::new(2);

        manager.queue_insert(
            "test",
            None,
            vec![vec![ParsedValue::Integer(1)]],
            "SQL1".to_string(),
        );
        let full = manager
            .queue_insert(
                "test",
                None,
                vec![vec![ParsedValue::Integer(2)], vec![ParsedValue::Integer(3)]],
                "SQL2".to_string(),
            )
            .expect("three queued rows must reach the threshold of two");

        assert_eq!(full.row_count(), 3);
    }

    #[test]
    fn test_count_insert_rows() {
        let cases = [
            ("INSERT INTO t VALUES (1)", 1),
            ("INSERT INTO t VALUES (1), (2), (3)", 3),
            ("INSERT INTO t VALUES (1, 'a(b)'), (2, 'c')", 2),
        ];
        for (sql, expected) in cases {
            assert_eq!(count_insert_rows(sql), expected, "sql: {}", sql);
        }
    }

    #[test]
    fn test_generate_batch_insert_with_columns() {
        let rows = vec![
            vec![text("alice"), ParsedValue::Integer(1)],
            vec![text("bob"), ParsedValue::Integer(2)],
        ];
        let columns = Some(vec!["name".to_string(), "id".to_string()]);

        let sql = generate_batch_insert("users", &columns, &rows);

        assert!(sql.contains("INSERT INTO \"users\" (\"name\", \"id\") VALUES"));
        for literal in ["'alice'", "'bob'"] {
            assert!(sql.contains(literal));
        }
    }

    #[test]
    fn test_generate_batch_insert_without_columns() {
        let rows = vec![vec![ParsedValue::Integer(1), text("test")]];

        assert_eq!(
            generate_batch_insert("test", &None, &rows),
            "INSERT INTO \"test\" VALUES\n(1, 'test');"
        );
    }
}