use crate::backend::Backend;
use crate::connection::Connection;
use crate::copy::{
BulkMode, CopyFormat, IfExists, backend_needs_explicit_commit, insert_batch, quote_identifier,
};
use crate::error::SqlError;
use crate::transaction::{begin_transaction, commit_transaction, rollback_transaction};
use crate::value::{ColumnInfo, Row};
/// Rows per batch used by `WriteOptions::default()`, and the value `write_rows`
/// falls back to when `WriteOptions::batch_size` is 0.
pub const DEFAULT_WRITE_BATCH: usize = 1000;
/// How `write_rows` should treat the rows it writes. Each variant is translated
/// into an `IfExists` policy by `WriteMode::if_exists`; the conflict handling
/// itself happens inside `insert_batch`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum WriteMode {
/// Default mode. Maps to `IfExists::Append`; needs no key columns.
#[default]
Insert,
/// Maps to `IfExists::Skip`; `write_rows` refuses to run without key columns.
/// Presumably rows whose key already exists are left alone — confirm in `insert_batch`.
Skip,
/// Maps to `IfExists::Upsert`; `write_rows` refuses to run without key columns.
Upsert,
}
impl WriteMode {
    /// Translates the write mode into the `IfExists` policy that is handed to
    /// `insert_batch`.
    fn if_exists(self) -> IfExists {
        match self {
            Self::Insert => IfExists::Append,
            Self::Skip => IfExists::Skip,
            Self::Upsert => IfExists::Upsert,
        }
    }

    /// Reports whether this mode resolves conflicts and therefore needs a
    /// conflict key (`WriteOptions::key_columns`) to be supplied.
    fn needs_key(self) -> bool {
        match self {
            Self::Insert => false,
            Self::Skip | Self::Upsert => true,
        }
    }
}
/// Options controlling how `write_rows` batches rows, resolves conflicts and
/// reacts to failing batches.
pub struct WriteOptions {
/// Conflict policy. `Skip` and `Upsert` require `key_columns` to be non-empty.
pub mode: WriteMode,
/// Maximum number of rows per batch; 0 makes `write_rows` use `DEFAULT_WRITE_BATCH`.
pub batch_size: usize,
/// Conflict-key column names. Every entry must equal the name of one of the
/// destination columns passed to `write_rows`, whatever the mode.
pub key_columns: Vec<String>,
/// Bulk mode forwarded to `insert_batch` for whole batches. Per-row retries
/// (see `isolate_failures`) always use `BulkMode::Off`.
pub bulk_mode: BulkMode,
/// Copy format forwarded to `insert_batch`.
pub copy_format: CopyFormat,
/// When true, `write_rows` tries to open one transaction around the whole
/// write; if it was opened, the first failing batch rolls everything back and
/// `write_rows` returns an error.
pub atomic: bool,
/// When true (and no atomic transaction is open), a failing batch is retried
/// one row at a time so that only the rows that still fail are rejected.
pub isolate_failures: bool,
/// Verbosity flag forwarded to `insert_batch`.
pub verbose: bool,
}
impl Default for WriteOptions {
fn default() -> Self {
Self {
mode: WriteMode::default(),
batch_size: DEFAULT_WRITE_BATCH,
key_columns: Vec::new(),
bulk_mode: BulkMode::Off,
copy_format: CopyFormat::Text,
atomic: false,
isolate_failures: false,
verbose: false,
}
}
}
impl std::fmt::Debug for WriteOptions {
    /// Formats every field, in declaration order, as a `WriteOptions { .. }` debug struct.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct without
        // listing it here becomes a compile error.
        let Self {
            mode,
            batch_size,
            key_columns,
            bulk_mode,
            copy_format,
            atomic,
            isolate_failures,
            verbose,
        } = self;

        let mut out = f.debug_struct("WriteOptions");
        out.field("mode", mode);
        out.field("batch_size", batch_size);
        out.field("key_columns", key_columns);
        out.field("bulk_mode", bulk_mode);
        out.field("copy_format", copy_format);
        out.field("atomic", atomic);
        out.field("isolate_failures", isolate_failures);
        out.field("verbose", verbose);
        out.finish()
    }
}
/// Two-state outcome of a batch: written or rejected.
// NOTE(review): nothing else visible in this file references this enum —
// confirm it is consumed by callers before relying on it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BatchOutcome {
/// The batch was written.
Written,
/// The batch was rejected.
Rejected,
}
/// A batch that `write_rows` could not write as a whole (recorded by
/// `record_batch_rejection`).
#[derive(Debug, Clone)]
pub struct RejectedBatch {
/// 0-based position of the batch within the write.
pub batch_index: usize,
/// 0-based index, within the input row stream, of the first row of the batch.
pub start_row: u64,
/// Number of rows the batch contained.
pub row_count: usize,
/// The batch's error rendered with `to_string()`.
pub error: String,
}
/// A single row that still failed when a rejected batch was retried row by row
/// (recorded by `probe_rows`).
#[derive(Debug, Clone)]
pub struct RejectedRow {
/// 0-based index of the row within the input row stream.
pub row_index: u64,
/// The row's error rendered with `to_string()`.
pub error: String,
}
/// Summary returned by `write_rows`: how much was attempted, how much landed,
/// and what was rejected.
#[derive(Debug, Clone, Default)]
pub struct WriteReport {
/// Rows pulled from the input and handed to a batch.
pub rows_attempted: u64,
/// Rows in batches that succeeded, plus rows that succeeded in a row-by-row retry.
pub rows_written: u64,
/// Batches whose `insert_batch` call succeeded as a whole (batches rescued by
/// a row-by-row retry are not counted).
pub batches_committed: usize,
/// Batches rejected as a whole.
pub rejected_batches: Vec<RejectedBatch>,
/// Individual rows rejected during row-by-row retries.
pub rejected_rows: Vec<RejectedRow>,
}
impl WriteReport {
    /// Returns `true` when nothing was rejected, i.e. both the rejected-batch
    /// and the rejected-row lists are empty.
    #[must_use]
    pub fn is_complete(&self) -> bool {
        let any_rejection = !self.rejected_batches.is_empty() || !self.rejected_rows.is_empty();
        !any_rejection
    }
}
pub fn write_rows<I>(
dst: &mut dyn Connection,
backend: Backend,
table: &str,
columns: &[ColumnInfo],
rows: I,
opts: &WriteOptions,
) -> Result<WriteReport, SqlError>
where
I: IntoIterator<Item = Row>,
{
if opts.mode.needs_key() && opts.key_columns.is_empty() {
return Err(SqlError::QueryFailed(format!(
"{:?} write mode requires key_columns (conflict key); none supplied",
opts.mode
)));
}
for key in &opts.key_columns {
if !columns.iter().any(|c| &c.name == key) {
return Err(SqlError::QueryFailed(format!(
"key column {key:?} is not among the destination columns"
)));
}
}
let batch_size = if opts.batch_size == 0 {
DEFAULT_WRITE_BATCH
} else {
opts.batch_size
};
let if_exists = opts.mode.if_exists();
let quoted_table = quote_identifier(table, backend);
let cols_clause = columns
.iter()
.map(|c| quote_identifier(&c.name, backend))
.collect::<Vec<_>>()
.join(", ");
let mut report = WriteReport::default();
let atomic_opened = if opts.atomic {
#[cfg(feature = "mssql")]
if matches!(backend, Backend::MsSql) {
let _ = dst.execute("SET XACT_ABORT ON");
}
begin_transaction(dst, backend)
} else {
false
};
let mut iter = rows.into_iter();
let mut batch: Vec<Row> = Vec::with_capacity(batch_size);
let mut batch_index = 0usize;
let mut next_row: u64 = 0;
let mut atomic_failure: Option<SqlError> = None;
loop {
batch.clear();
for _ in 0..batch_size {
match iter.next() {
Some(row) => batch.push(row),
None => break,
}
}
if batch.is_empty() {
break;
}
let start_row = next_row;
let n = batch.len();
report.rows_attempted += n as u64;
next_row += n as u64;
match insert_batch(
dst,
table,
columns,
&opts.key_columns,
"ed_table,
&cols_clause,
&batch,
backend,
if_exists,
opts.bulk_mode,
opts.copy_format,
opts.verbose,
) {
Ok(()) => {
report.rows_written += n as u64;
report.batches_committed += 1;
}
Err(err) => {
if atomic_opened {
record_batch_rejection(&mut report, batch_index, start_row, n, &err);
atomic_failure = Some(err);
break;
}
if opts.isolate_failures {
let written = probe_rows(
dst,
table,
columns,
&opts.key_columns,
"ed_table,
&cols_clause,
&batch,
backend,
if_exists,
opts.copy_format,
opts.verbose,
start_row,
&mut report,
);
report.rows_written += written;
} else {
record_batch_rejection(&mut report, batch_index, start_row, n, &err);
}
}
}
batch_index += 1;
}
if atomic_opened {
if let Some(err) = atomic_failure {
let _ = rollback_transaction(dst, backend);
report.rows_written = 0;
report.batches_committed = 0;
return Err(SqlError::QueryFailed(format!(
"atomic write rolled back after batch {} failed: {err}",
report.rejected_batches.last().map_or(0, |b| b.batch_index)
)));
}
commit_transaction(dst, backend)?;
let _ = backend_needs_explicit_commit(backend);
}
Ok(report)
}
/// Appends an entry for a batch that failed as a whole to
/// `report.rejected_batches`, storing the error as its display text.
fn record_batch_rejection(
    report: &mut WriteReport,
    batch_index: usize,
    start_row: u64,
    row_count: usize,
    err: &SqlError,
) {
    let error = err.to_string();
    let rejection = RejectedBatch {
        batch_index,
        start_row,
        row_count,
        error,
    };
    report.rejected_batches.push(rejection);
}
/// Retries the rows of a failed batch one at a time, always with
/// `BulkMode::Off`, and returns how many of them were written.
///
/// Each row that still fails is appended to `report.rejected_rows` with its
/// index in the input stream (`start_row` plus its offset within `batch`) and
/// the error text. The function does not touch `report.rows_written`; the
/// caller adds the returned count.
// The argument list mirrors that of `insert_batch`, hence the lint exemption.
#[allow(clippy::too_many_arguments)]
fn probe_rows(
    dst: &mut dyn Connection,
    table: &str,
    columns: &[ColumnInfo],
    key_columns: &[String],
    quoted_table: &str,
    cols_clause: &str,
    batch: &[Row],
    backend: Backend,
    if_exists: IfExists,
    copy_format: CopyFormat,
    verbose: bool,
    start_row: u64,
    report: &mut WriteReport,
) -> u64 {
    let mut landed: u64 = 0;
    for (position, candidate) in batch.iter().enumerate() {
        // A one-element slice lets the regular batch path handle a single row.
        let attempt = insert_batch(
            dst,
            table,
            columns,
            key_columns,
            quoted_table,
            cols_clause,
            std::slice::from_ref(candidate),
            backend,
            if_exists,
            BulkMode::Off,
            copy_format,
            verbose,
        );
        if let Err(failure) = attempt {
            report.rejected_rows.push(RejectedRow {
                row_index: start_row + position as u64,
                error: failure.to_string(),
            });
            continue;
        }
        landed += 1;
    }
    landed
}
#[cfg(all(test, feature = "sqlite"))]
mod tests {
    use super::*;
    use crate::connection::ConnectOptions;
    use crate::url::DatabaseUrl;
    use crate::value::{TypeHint, Value};
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Makes the scratch database file names unique within one test process.
    static CTR: AtomicU64 = AtomicU64::new(0);

    /// Deletes the scratch database file when dropped, so the file is cleaned
    /// up even when a test panics on a failed assertion or `unwrap`.
    struct TempDb(std::path::PathBuf);

    impl Drop for TempDb {
        fn drop(&mut self) {
            // Best effort: the file may never have been created.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Opens a connection to a brand-new SQLite database in the temp directory.
    ///
    /// The guard comes first in the tuple on purpose: bind it as
    /// `let (_db, mut conn) = fresh_sqlite();` so that the connection is
    /// dropped before the guard removes the file. Do not bind the guard to a
    /// bare `_`, which would drop it (and delete the file) immediately.
    fn fresh_sqlite() -> (TempDb, Box<dyn Connection>) {
        let pid = std::process::id();
        let n = CTR.fetch_add(1, Ordering::SeqCst);
        let db = TempDb(std::env::temp_dir().join(format!("ferrule-write-test-{pid}-{n}.db")));
        // Clear out any stale file left behind by an earlier process with the same pid.
        let _ = std::fs::remove_file(&db.0);
        let url = DatabaseUrl::parse(&format!("sqlite://{}", db.0.display())).unwrap();
        let conn = crate::connect(&url, &ConnectOptions::default(), None).unwrap();
        (db, conn)
    }

    /// Builds a nullable destination column with an unspecific type hint.
    fn col(name: &str) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            type_hint: TypeHint::Other,
            nullable: true,
        }
    }

    /// 2500 rows in batches of 100 all land, in exactly 25 batches.
    #[test]
    fn write_rows_round_trip_in_bounded_batches() {
        let (_db, mut conn) = fresh_sqlite();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)")
            .unwrap();
        let columns = vec![col("id"), col("name")];
        let rows: Vec<Row> = (1..=2500)
            .map(|i| vec![Value::Int64(i), Value::String(format!("n{i}"))])
            .collect();
        let opts = WriteOptions {
            batch_size: 100,
            ..Default::default()
        };
        let report = write_rows(&mut *conn, Backend::Sqlite, "t", &columns, rows, &opts).unwrap();
        assert_eq!(report.rows_attempted, 2500);
        assert_eq!(report.rows_written, 2500);
        assert_eq!(report.batches_committed, 25);
        assert!(report.is_complete());
        let back = conn.query("SELECT COUNT(*) FROM t").unwrap();
        assert!(matches!(back.rows[0][0], Value::Int64(2500)));
    }

    /// A batch that collides with an existing key is rejected as a whole.
    #[test]
    fn write_rows_rejects_failing_batch_structurally() {
        let (_db, mut conn) = fresh_sqlite();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();
        conn.execute("INSERT INTO t VALUES (5)").unwrap();
        let columns = vec![col("id")];
        let rows: Vec<Row> = (1..=8).map(|i| vec![Value::Int64(i)]).collect();
        let opts = WriteOptions {
            batch_size: 4,
            ..Default::default()
        };
        let report = write_rows(&mut *conn, Backend::Sqlite, "t", &columns, rows, &opts).unwrap();
        assert_eq!(report.rows_attempted, 8);
        assert_eq!(report.rows_written, 4, "only the clean batch landed");
        assert_eq!(report.batches_committed, 1);
        assert_eq!(report.rejected_batches.len(), 1);
        let rej = &report.rejected_batches[0];
        assert_eq!(rej.batch_index, 1);
        assert_eq!(rej.start_row, 4);
        assert_eq!(rej.row_count, 4);
        assert!(!report.is_complete());
    }

    /// With `isolate_failures`, only the colliding row is rejected.
    #[test]
    fn write_rows_isolates_offending_row() {
        let (_db, mut conn) = fresh_sqlite();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();
        conn.execute("INSERT INTO t VALUES (3)").unwrap();
        let columns = vec![col("id")];
        let rows: Vec<Row> = (1..=4).map(|i| vec![Value::Int64(i)]).collect();
        let opts = WriteOptions {
            batch_size: 10,
            isolate_failures: true,
            ..Default::default()
        };
        let report = write_rows(&mut *conn, Backend::Sqlite, "t", &columns, rows, &opts).unwrap();
        assert_eq!(report.rows_written, 3, "1,2,4 landed; 3 rejected");
        assert_eq!(report.rejected_batches.len(), 0);
        assert_eq!(report.rejected_rows.len(), 1);
        assert_eq!(
            report.rejected_rows[0].row_index, 2,
            "0-based index of id=3"
        );
        let back = conn.query("SELECT COUNT(*) FROM t").unwrap();
        assert!(matches!(back.rows[0][0], Value::Int64(4)));
    }

    /// An atomic write that fails in its second batch leaves the table untouched.
    #[test]
    fn write_rows_atomic_rolls_back_on_failure() {
        let (_db, mut conn) = fresh_sqlite();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();
        conn.execute("INSERT INTO t VALUES (7)").unwrap();
        let columns = vec![col("id")];
        let rows: Vec<Row> = vec![1, 2, 7, 8]
            .into_iter()
            .map(|i| vec![Value::Int64(i)])
            .collect();
        let opts = WriteOptions {
            batch_size: 2,
            atomic: true,
            ..Default::default()
        };
        let err = write_rows(&mut *conn, Backend::Sqlite, "t", &columns, rows, &opts)
            .expect_err("atomic write must surface the failure");
        assert!(matches!(err, SqlError::QueryFailed(_)));
        let back = conn.query("SELECT COUNT(*) FROM t").unwrap();
        assert!(
            matches!(back.rows[0][0], Value::Int64(1)),
            "atomic rollback left only the pre-existing row"
        );
    }

    /// Upsert replaces the value of a row whose key already exists.
    #[test]
    fn write_rows_upsert_overwrites_by_key() {
        let (_db, mut conn) = fresh_sqlite();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT)")
            .unwrap();
        conn.execute("INSERT INTO t VALUES (1, 'old')").unwrap();
        let columns = vec![col("id"), col("v")];
        let rows: Vec<Row> = vec![
            vec![Value::Int64(1), Value::String("new".into())],
            vec![Value::Int64(2), Value::String("two".into())],
        ];
        let opts = WriteOptions {
            mode: WriteMode::Upsert,
            key_columns: vec!["id".into()],
            ..Default::default()
        };
        let report = write_rows(&mut *conn, Backend::Sqlite, "t", &columns, rows, &opts).unwrap();
        assert!(report.is_complete());
        let v1 = conn.query("SELECT v FROM t WHERE id = 1").unwrap();
        assert!(matches!(&v1.rows[0][0], Value::String(s) if s == "new"));
    }

    /// Conflict-aware modes are refused up front when no key columns are given.
    #[test]
    fn write_rows_conflict_mode_requires_key() {
        let (_db, mut conn) = fresh_sqlite();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();
        let columns = vec![col("id")];
        let opts = WriteOptions {
            mode: WriteMode::Skip,
            ..Default::default()
        };
        let err = write_rows(
            &mut *conn,
            Backend::Sqlite,
            "t",
            &columns,
            vec![vec![Value::Int64(1)]],
            &opts,
        )
        .expect_err("skip without key must fail fast");
        assert!(matches!(err, SqlError::QueryFailed(_)));
    }

    /// A key column that is not a destination column is refused up front.
    #[test]
    fn write_rows_unknown_key_column_fails_fast() {
        let (_db, mut conn) = fresh_sqlite();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();
        let columns = vec![col("id")];
        let opts = WriteOptions {
            mode: WriteMode::Upsert,
            key_columns: vec!["nonexistent".into()],
            ..Default::default()
        };
        let err = write_rows(
            &mut *conn,
            Backend::Sqlite,
            "t",
            &columns,
            vec![vec![Value::Int64(1)]],
            &opts,
        )
        .expect_err("unknown key column must fail fast");
        assert!(matches!(err, SqlError::QueryFailed(_)));
    }
}