1use crate::backend::Backend;
31use crate::connection::Connection;
32use crate::copy::{
33 BulkMode, CopyFormat, IfExists, backend_needs_explicit_commit, insert_batch, quote_identifier,
34};
35use crate::error::SqlError;
36use crate::transaction::{begin_transaction, commit_transaction, rollback_transaction};
37use crate::value::{ColumnInfo, Row};
38
/// Number of rows sent per batch when [`WriteOptions::batch_size`] is left at
/// its default or set to `0`.
pub const DEFAULT_WRITE_BATCH: usize = 1000;
42
/// How incoming rows are applied to the destination table.
///
/// The default is [`WriteMode::Insert`]. Each variant is translated into an
/// `IfExists` policy before being handed to the batch inserter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum WriteMode {
    /// Plain insert; maps to `IfExists::Append` and needs no conflict key.
    #[default]
    Insert,
    /// Maps to `IfExists::Skip`; requires [`WriteOptions::key_columns`].
    // NOTE(review): presumably rows whose key already exists are left
    // untouched — the conflict handling itself lives in `insert_batch`.
    Skip,
    /// Maps to `IfExists::Upsert`; requires [`WriteOptions::key_columns`].
    /// The SQLite test in this file shows an existing row being overwritten
    /// by the incoming row with the same key.
    Upsert,
}
65
66impl WriteMode {
67 fn if_exists(self) -> IfExists {
69 match self {
70 WriteMode::Insert => IfExists::Append,
71 WriteMode::Skip => IfExists::Skip,
72 WriteMode::Upsert => IfExists::Upsert,
73 }
74 }
75
76 fn needs_key(self) -> bool {
78 matches!(self, WriteMode::Skip | WriteMode::Upsert)
79 }
80}
81
/// Knobs controlling a single [`write_rows`] call.
pub struct WriteOptions {
    /// Conflict behaviour applied to every batch.
    pub mode: WriteMode,
    /// Maximum number of rows per batch; `0` falls back to
    /// [`DEFAULT_WRITE_BATCH`].
    pub batch_size: usize,
    /// Conflict-key column names. Must be non-empty for [`WriteMode::Skip`]
    /// and [`WriteMode::Upsert`], and every entry must name one of the
    /// destination columns passed to [`write_rows`].
    pub key_columns: Vec<String>,
    /// Forwarded to the batch inserter for whole-batch writes. Per-row
    /// probing under `isolate_failures` always uses `BulkMode::Off`.
    pub bulk_mode: BulkMode,
    /// Forwarded unchanged to the batch inserter.
    pub copy_format: CopyFormat,
    /// When `true`, [`write_rows`] tries to open a transaction around the
    /// whole write. If one is opened, the first failing batch rolls
    /// everything back and the call returns an error.
    pub atomic: bool,
    /// When `true` (and no atomic transaction is open), a failed batch is
    /// retried one row at a time so only the offending rows are rejected.
    pub isolate_failures: bool,
    /// Forwarded unchanged to the batch inserter.
    pub verbose: bool,
}
117
118impl Default for WriteOptions {
119 fn default() -> Self {
120 Self {
121 mode: WriteMode::default(),
122 batch_size: DEFAULT_WRITE_BATCH,
123 key_columns: Vec::new(),
124 bulk_mode: BulkMode::Off,
125 copy_format: CopyFormat::Text,
126 atomic: false,
127 isolate_failures: false,
128 verbose: false,
129 }
130 }
131}
132
133impl std::fmt::Debug for WriteOptions {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 f.debug_struct("WriteOptions")
136 .field("mode", &self.mode)
137 .field("batch_size", &self.batch_size)
138 .field("key_columns", &self.key_columns)
139 .field("bulk_mode", &self.bulk_mode)
140 .field("copy_format", &self.copy_format)
141 .field("atomic", &self.atomic)
142 .field("isolate_failures", &self.isolate_failures)
143 .field("verbose", &self.verbose)
144 .finish()
145 }
146}
147
/// Coarse result of writing one batch.
// NOTE(review): nothing in the visible part of this file constructs or
// matches on this enum — confirm its intended consumers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BatchOutcome {
    /// The batch was written.
    Written,
    /// The batch was rejected.
    Rejected,
}
157
/// A batch that failed as a unit and was not retried row by row.
#[derive(Debug, Clone)]
pub struct RejectedBatch {
    /// Zero-based position of the batch within the write.
    pub batch_index: usize,
    /// Zero-based index, within the whole input, of the batch's first row.
    pub start_row: u64,
    /// Number of rows the batch contained.
    pub row_count: usize,
    /// Display text of the error returned by the batch insert.
    pub error: String,
}
171
/// A single row that failed when its batch was retried one row at a time.
#[derive(Debug, Clone)]
pub struct RejectedRow {
    /// Zero-based index of the row within the whole input.
    pub row_index: u64,
    /// Display text of the error returned by the single-row insert.
    pub error: String,
}
179
/// Summary of what a [`write_rows`] call attempted and what actually landed.
#[derive(Debug, Clone, Default)]
pub struct WriteReport {
    /// Rows pulled from the input and handed to the inserter; counted before
    /// the insert is attempted.
    pub rows_attempted: u64,
    /// Rows from batches that succeeded as a whole, plus rows that succeeded
    /// individually during per-row probing.
    pub rows_written: u64,
    /// Number of batches whose insert succeeded as a whole. Batches that were
    /// probed row by row are not counted here.
    pub batches_committed: usize,
    /// Batches rejected as a unit.
    pub rejected_batches: Vec<RejectedBatch>,
    /// Individual rows rejected during per-row probing.
    pub rejected_rows: Vec<RejectedRow>,
}
200
201impl WriteReport {
202 #[must_use]
204 pub fn is_complete(&self) -> bool {
205 self.rejected_batches.is_empty() && self.rejected_rows.is_empty()
206 }
207}
208
209pub fn write_rows<I>(
228 dst: &mut dyn Connection,
229 backend: Backend,
230 table: &str,
231 columns: &[ColumnInfo],
232 rows: I,
233 opts: &WriteOptions,
234) -> Result<WriteReport, SqlError>
235where
236 I: IntoIterator<Item = Row>,
237{
238 if opts.mode.needs_key() && opts.key_columns.is_empty() {
239 return Err(SqlError::QueryFailed(format!(
240 "{:?} write mode requires key_columns (conflict key); none supplied",
241 opts.mode
242 )));
243 }
244 for key in &opts.key_columns {
247 if !columns.iter().any(|c| &c.name == key) {
248 return Err(SqlError::QueryFailed(format!(
249 "key column {key:?} is not among the destination columns"
250 )));
251 }
252 }
253
254 let batch_size = if opts.batch_size == 0 {
255 DEFAULT_WRITE_BATCH
256 } else {
257 opts.batch_size
258 };
259 let if_exists = opts.mode.if_exists();
260 let quoted_table = quote_identifier(table, backend);
261 let cols_clause = columns
262 .iter()
263 .map(|c| quote_identifier(&c.name, backend))
264 .collect::<Vec<_>>()
265 .join(", ");
266
267 let mut report = WriteReport::default();
268
269 let atomic_opened = if opts.atomic {
270 #[cfg(feature = "mssql")]
273 if matches!(backend, Backend::MsSql) {
274 let _ = dst.execute("SET XACT_ABORT ON");
275 }
276 begin_transaction(dst, backend)
277 } else {
278 false
279 };
280
281 let mut iter = rows.into_iter();
282 let mut batch: Vec<Row> = Vec::with_capacity(batch_size);
283 let mut batch_index = 0usize;
284 let mut next_row: u64 = 0;
285 let mut atomic_failure: Option<SqlError> = None;
286
287 loop {
288 batch.clear();
289 for _ in 0..batch_size {
290 match iter.next() {
291 Some(row) => batch.push(row),
292 None => break,
293 }
294 }
295 if batch.is_empty() {
296 break;
297 }
298 let start_row = next_row;
299 let n = batch.len();
300 report.rows_attempted += n as u64;
301 next_row += n as u64;
302
303 match insert_batch(
304 dst,
305 table,
306 columns,
307 &opts.key_columns,
308 "ed_table,
309 &cols_clause,
310 &batch,
311 backend,
312 if_exists,
313 opts.bulk_mode,
314 opts.copy_format,
315 opts.verbose,
316 ) {
317 Ok(()) => {
318 report.rows_written += n as u64;
319 report.batches_committed += 1;
320 }
321 Err(err) => {
322 if atomic_opened {
323 record_batch_rejection(&mut report, batch_index, start_row, n, &err);
326 atomic_failure = Some(err);
327 break;
328 }
329 if opts.isolate_failures {
330 let written = probe_rows(
333 dst,
334 table,
335 columns,
336 &opts.key_columns,
337 "ed_table,
338 &cols_clause,
339 &batch,
340 backend,
341 if_exists,
342 opts.copy_format,
343 opts.verbose,
344 start_row,
345 &mut report,
346 );
347 report.rows_written += written;
348 } else {
349 record_batch_rejection(&mut report, batch_index, start_row, n, &err);
350 }
351 }
352 }
353 batch_index += 1;
354 }
355
356 if atomic_opened {
357 if let Some(err) = atomic_failure {
358 let _ = rollback_transaction(dst, backend);
359 report.rows_written = 0;
361 report.batches_committed = 0;
362 return Err(SqlError::QueryFailed(format!(
363 "atomic write rolled back after batch {} failed: {err}",
364 report.rejected_batches.last().map_or(0, |b| b.batch_index)
365 )));
366 }
367 commit_transaction(dst, backend)?;
368 let _ = backend_needs_explicit_commit(backend);
372 }
373
374 Ok(report)
375}
376
377fn record_batch_rejection(
379 report: &mut WriteReport,
380 batch_index: usize,
381 start_row: u64,
382 row_count: usize,
383 err: &SqlError,
384) {
385 report.rejected_batches.push(RejectedBatch {
386 batch_index,
387 start_row,
388 row_count,
389 error: err.to_string(),
390 });
391}
392
393#[allow(clippy::too_many_arguments)]
396fn probe_rows(
397 dst: &mut dyn Connection,
398 table: &str,
399 columns: &[ColumnInfo],
400 key_columns: &[String],
401 quoted_table: &str,
402 cols_clause: &str,
403 batch: &[Row],
404 backend: Backend,
405 if_exists: IfExists,
406 copy_format: CopyFormat,
407 verbose: bool,
408 start_row: u64,
409 report: &mut WriteReport,
410) -> u64 {
411 let mut written = 0u64;
412 for (offset, row) in batch.iter().enumerate() {
413 let single = std::slice::from_ref(row);
414 match insert_batch(
417 dst,
418 table,
419 columns,
420 key_columns,
421 quoted_table,
422 cols_clause,
423 single,
424 backend,
425 if_exists,
426 BulkMode::Off,
427 copy_format,
428 verbose,
429 ) {
430 Ok(()) => written += 1,
431 Err(err) => report.rejected_rows.push(RejectedRow {
432 row_index: start_row + offset as u64,
433 error: err.to_string(),
434 }),
435 }
436 }
437 written
438}
439
#[cfg(all(test, feature = "sqlite"))]
mod tests {
    use super::*;
    use crate::connection::ConnectOptions;
    use crate::url::DatabaseUrl;
    use crate::value::{TypeHint, Value};
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Counter mixed into scratch file names so every test gets its own
    /// database file.
    static CTR: AtomicU64 = AtomicU64::new(0);

    /// Connects to a new SQLite database file under the system temp
    /// directory.
    ///
    /// Returns the connection together with the file path so the test can
    /// delete the file when it is done.
    fn fresh_sqlite() -> (Box<dyn Connection>, std::path::PathBuf) {
        let pid = std::process::id();
        let n = CTR.fetch_add(1, Ordering::SeqCst);
        let mut path = std::env::temp_dir();
        path.push(format!("ferrule-write-test-{pid}-{n}.db"));
        // Remove any leftover file with the same name before connecting.
        let _ = std::fs::remove_file(&path);
        let dsn = format!("sqlite://{}", path.display());
        let url = DatabaseUrl::parse(&dsn).unwrap();
        let conn = crate::connect(&url, &ConnectOptions::default(), None).unwrap();
        (conn, path)
    }

    /// Builds a nullable column description with an unspecific type hint.
    fn col(name: &str) -> ColumnInfo {
        let name = name.to_string();
        ColumnInfo {
            name,
            type_hint: TypeHint::Other,
            nullable: true,
        }
    }

    /// 2500 rows in batches of 100 all land, in exactly 25 batches.
    #[test]
    fn write_rows_round_trip_in_bounded_batches() {
        let (mut conn, path) = fresh_sqlite();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)")
            .unwrap();
        let columns = vec![col("id"), col("name")];
        let mut rows: Vec<Row> = Vec::new();
        for i in 1..=2500 {
            rows.push(vec![Value::Int64(i), Value::String(format!("n{i}"))]);
        }
        let opts = WriteOptions {
            batch_size: 100,
            ..WriteOptions::default()
        };

        let report = write_rows(&mut *conn, Backend::Sqlite, "t", &columns, rows, &opts).unwrap();

        assert_eq!(report.rows_attempted, 2500);
        assert_eq!(report.rows_written, 2500);
        assert_eq!(report.batches_committed, 25);
        assert!(report.is_complete());

        let counted = conn.query("SELECT COUNT(*) FROM t").unwrap();
        assert!(matches!(counted.rows[0][0], Value::Int64(2500)));
        let _ = std::fs::remove_file(&path);
    }

    /// Without isolation, a batch containing a duplicate key is rejected as a
    /// unit while the clean batch still lands.
    #[test]
    fn write_rows_rejects_failing_batch_structurally() {
        let (mut conn, path) = fresh_sqlite();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();
        // id 5 already exists, so the second batch (5..=8) collides.
        conn.execute("INSERT INTO t VALUES (5)").unwrap();
        let columns = vec![col("id")];
        let mut rows: Vec<Row> = Vec::new();
        for i in 1..=8 {
            rows.push(vec![Value::Int64(i)]);
        }
        let opts = WriteOptions {
            batch_size: 4,
            ..WriteOptions::default()
        };

        let report = write_rows(&mut *conn, Backend::Sqlite, "t", &columns, rows, &opts).unwrap();

        assert_eq!(report.rows_attempted, 8);
        assert_eq!(report.rows_written, 4, "only the clean batch landed");
        assert_eq!(report.batches_committed, 1);
        assert_eq!(report.rejected_batches.len(), 1);
        let rejected = &report.rejected_batches[0];
        assert_eq!(rejected.batch_index, 1);
        assert_eq!(rejected.start_row, 4);
        assert_eq!(rejected.row_count, 4);
        assert!(!report.is_complete());
        let _ = std::fs::remove_file(&path);
    }

    /// With isolation, only the colliding row is rejected and the rest of its
    /// batch is written.
    #[test]
    fn write_rows_isolates_offending_row() {
        let (mut conn, path) = fresh_sqlite();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();
        // id 3 already exists, so the third input row collides.
        conn.execute("INSERT INTO t VALUES (3)").unwrap();
        let columns = vec![col("id")];
        let mut rows: Vec<Row> = Vec::new();
        for i in 1..=4 {
            rows.push(vec![Value::Int64(i)]);
        }
        let opts = WriteOptions {
            batch_size: 10,
            isolate_failures: true,
            ..WriteOptions::default()
        };

        let report = write_rows(&mut *conn, Backend::Sqlite, "t", &columns, rows, &opts).unwrap();

        assert_eq!(report.rows_written, 3, "1,2,4 landed; 3 rejected");
        assert!(report.rejected_batches.is_empty());
        assert_eq!(report.rejected_rows.len(), 1);
        assert_eq!(
            report.rejected_rows[0].row_index, 2,
            "0-based index of id=3"
        );
        let counted = conn.query("SELECT COUNT(*) FROM t").unwrap();
        assert!(matches!(counted.rows[0][0], Value::Int64(4)));
        let _ = std::fs::remove_file(&path);
    }

    /// In atomic mode a failing batch undoes the batches written before it.
    #[test]
    fn write_rows_atomic_rolls_back_on_failure() {
        let (mut conn, path) = fresh_sqlite();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();
        // id 7 already exists, so the second batch (7, 8) collides.
        conn.execute("INSERT INTO t VALUES (7)").unwrap();
        let columns = vec![col("id")];
        let mut rows: Vec<Row> = Vec::new();
        for i in [1, 2, 7, 8] {
            rows.push(vec![Value::Int64(i)]);
        }
        let opts = WriteOptions {
            batch_size: 2,
            atomic: true,
            ..WriteOptions::default()
        };

        let err = write_rows(&mut *conn, Backend::Sqlite, "t", &columns, rows, &opts)
            .expect_err("atomic write must surface the failure");

        assert!(matches!(err, SqlError::QueryFailed(_)));
        let counted = conn.query("SELECT COUNT(*) FROM t").unwrap();
        assert!(
            matches!(counted.rows[0][0], Value::Int64(1)),
            "atomic rollback left only the pre-existing row"
        );
        let _ = std::fs::remove_file(&path);
    }

    /// Upsert replaces the value of a row whose key already exists.
    #[test]
    fn write_rows_upsert_overwrites_by_key() {
        let (mut conn, path) = fresh_sqlite();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT)")
            .unwrap();
        conn.execute("INSERT INTO t VALUES (1, 'old')").unwrap();
        let columns = vec![col("id"), col("v")];
        let replacement: Row = vec![Value::Int64(1), Value::String("new".into())];
        let addition: Row = vec![Value::Int64(2), Value::String("two".into())];
        let rows: Vec<Row> = vec![replacement, addition];
        let opts = WriteOptions {
            mode: WriteMode::Upsert,
            key_columns: vec!["id".into()],
            ..WriteOptions::default()
        };

        let report = write_rows(&mut *conn, Backend::Sqlite, "t", &columns, rows, &opts).unwrap();

        assert!(report.is_complete());
        let stored = conn.query("SELECT v FROM t WHERE id = 1").unwrap();
        assert!(matches!(&stored.rows[0][0], Value::String(s) if s == "new"));
        let _ = std::fs::remove_file(&path);
    }

    /// A conflict-aware mode with no key columns is rejected up front.
    #[test]
    fn write_rows_conflict_mode_requires_key() {
        let (mut conn, path) = fresh_sqlite();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();
        let columns = vec![col("id")];
        let opts = WriteOptions {
            mode: WriteMode::Skip,
            ..WriteOptions::default()
        };
        let rows = vec![vec![Value::Int64(1)]];

        let outcome = write_rows(&mut *conn, Backend::Sqlite, "t", &columns, rows, &opts);

        let err = outcome.expect_err("skip without key must fail fast");
        assert!(matches!(err, SqlError::QueryFailed(_)));
        let _ = std::fs::remove_file(&path);
    }

    /// A key column that is not a destination column is rejected up front.
    #[test]
    fn write_rows_unknown_key_column_fails_fast() {
        let (mut conn, path) = fresh_sqlite();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();
        let columns = vec![col("id")];
        let opts = WriteOptions {
            mode: WriteMode::Upsert,
            key_columns: vec!["nonexistent".into()],
            ..WriteOptions::default()
        };
        let rows = vec![vec![Value::Int64(1)]];

        let outcome = write_rows(&mut *conn, Backend::Sqlite, "t", &columns, rows, &opts);

        let err = outcome.expect_err("unknown key column must fail fast");
        assert!(matches!(err, SqlError::QueryFailed(_)));
        let _ = std::fs::remove_file(&path);
    }
}
663}