use std::collections::HashMap;
use serde_json::Value;
/// Tuning knobs shared by the `insert_all*` and `upsert_all*` functions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InsertAllConfig {
    /// Maximum number of records per batch. Must be non-zero; zero is rejected
    /// with [`InsertAllError::InvalidBatchSize`].
    pub batch_size: usize,
    /// Name of the record field read as an explicit integer id. Defaults to `"id"`.
    pub id_field: String,
}

impl Default for InsertAllConfig {
    /// Batches of 1000 records, with ids read from the `"id"` field.
    fn default() -> Self {
        let id_field = String::from("id");
        Self {
            batch_size: 1_000,
            id_field,
        }
    }
}
/// Outcome of [`insert_all`] / [`insert_all_with_config`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InsertResult {
    /// Number of records passed in; every input record is counted as inserted.
    pub inserted_count: usize,
    /// One id per input record, in input order: the record's own integer id field
    /// when present, otherwise an assigned id.
    pub inserted_ids: Vec<i64>,
    /// Number of batches the records were split into (0 for empty input).
    pub batches: usize,
}
/// Outcome of [`upsert_all`] / [`upsert_all_with_config`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UpsertResult {
    /// Number of records treated as new rows; always equals `inserted_ids.len()`.
    pub inserted_count: usize,
    /// Number of records that matched an earlier record's `unique_by` value;
    /// always equals `updated_ids.len()`.
    pub updated_count: usize,
    /// Id of each inserted record, in input order.
    pub inserted_ids: Vec<i64>,
    /// For each matching record, in input order, the id of the earlier inserted
    /// record it matched. May contain repeats when several records match the same one.
    pub updated_ids: Vec<i64>,
    /// Number of batches all input records (inserts and updates together) were split into.
    pub batches: usize,
}
/// Input-validation errors returned by the `insert_all*` and `upsert_all*` functions.
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
pub enum InsertAllError {
    /// `InsertAllConfig::batch_size` was zero.
    #[error("batch size must be greater than zero")]
    InvalidBatchSize,
    /// The `unique_by` field name given to an upsert was the empty string.
    #[error("unique_by must not be empty")]
    EmptyUniqueBy,
}
/// Runs [`insert_all_with_config`] with [`InsertAllConfig::default`] settings.
///
/// # Errors
///
/// Propagates any error from [`insert_all_with_config`]. The default config has a
/// non-zero batch size, so no error is expected in practice.
pub fn insert_all(records: &[HashMap<String, Value>]) -> Result<InsertResult, InsertAllError> {
    let defaults = InsertAllConfig::default();
    insert_all_with_config(records, &defaults)
}
/// Builds an [`InsertResult`] for `records` split into batches of `config.batch_size`.
///
/// Each record keeps the integer found in its `config.id_field` field; records
/// without one get an assigned id (see `collect_ids`). Every record counts as inserted.
///
/// # Errors
///
/// Returns [`InsertAllError::InvalidBatchSize`] when `config.batch_size` is zero.
pub fn insert_all_with_config(
    records: &[HashMap<String, Value>],
    config: &InsertAllConfig,
) -> Result<InsertResult, InsertAllError> {
    match config.batch_size {
        0 => Err(InsertAllError::InvalidBatchSize),
        size => Ok(InsertResult {
            inserted_count: records.len(),
            inserted_ids: collect_ids(records, &config.id_field),
            batches: batch_count(records.len(), size),
        }),
    }
}
/// Runs [`upsert_all_with_config`] with [`InsertAllConfig::default`] settings.
///
/// # Errors
///
/// Returns [`InsertAllError::EmptyUniqueBy`] when `unique_by` is empty.
pub fn upsert_all(
    records: &[HashMap<String, Value>],
    unique_by: &str,
) -> Result<UpsertResult, InsertAllError> {
    let defaults = InsertAllConfig::default();
    upsert_all_with_config(records, unique_by, &defaults)
}
/// Splits `records` into inserts and updates, deduplicating on the `unique_by` field.
///
/// The first record carrying a given `unique_by` value is inserted. Every later record
/// with an equal value is reported as an update of that first record's id. Values are
/// compared by their serialized JSON form.
///
/// Records that lack the `unique_by` field entirely never match anything, so they are
/// always inserted. An explicit JSON `null` is an ordinary value here, so two records
/// with `null` in that field do deduplicate.
///
/// Inserted rows get their id the same way `collect_ids` assigns ids:
/// - an integer in `config.id_field` is used as-is;
/// - otherwise a counter value is used. The counter starts at 1 and is raised to one
///   past every id handed out, explicit or assigned.
///
/// NOTE(review): explicit ids are not checked against ids already handed out, so mixing
/// explicit and assigned ids can still yield duplicates. Confirm whether that should
/// be rejected.
///
/// # Errors
///
/// * [`InsertAllError::InvalidBatchSize`] if `config.batch_size` is zero.
/// * [`InsertAllError::EmptyUniqueBy`] if `unique_by` is empty.
pub fn upsert_all_with_config(
    records: &[HashMap<String, Value>],
    unique_by: &str,
    config: &InsertAllConfig,
) -> Result<UpsertResult, InsertAllError> {
    if config.batch_size == 0 {
        return Err(InsertAllError::InvalidBatchSize);
    }
    if unique_by.is_empty() {
        return Err(InsertAllError::EmptyUniqueBy);
    }
    let mut ids_by_unique_value = HashMap::<String, i64>::new();
    let mut inserted_ids = Vec::new();
    let mut updated_ids = Vec::new();
    let mut next_id = 1_i64;
    for record in records {
        // `None` means the record has no `unique_by` field. Such a record cannot
        // collide with anything, so it is never looked up or remembered. (A
        // synthetic placeholder key could collide with another missing-field row.)
        let unique_key = record.get(unique_by).map(unique_value_key);
        let existing_id = unique_key
            .as_ref()
            .and_then(|key| ids_by_unique_value.get(key))
            .copied();
        if let Some(existing_id) = existing_id {
            updated_ids.push(existing_id);
            continue;
        }
        let id = record
            .get(&config.id_field)
            .and_then(Value::as_i64)
            .unwrap_or(next_id);
        // Advance past explicit ids as well, so later assigned ids continue after
        // them, matching `collect_ids`.
        next_id = next_id.max(id.saturating_add(1));
        if let Some(key) = unique_key {
            ids_by_unique_value.insert(key, id);
        }
        inserted_ids.push(id);
    }
    Ok(UpsertResult {
        inserted_count: inserted_ids.len(),
        updated_count: updated_ids.len(),
        inserted_ids,
        updated_ids,
        batches: batch_count(records.len(), config.batch_size),
    })
}
/// Returns one id per record, in input order.
///
/// A record whose `id_field` holds an integer keeps that integer. Any other record
/// gets the current counter value. The counter starts at 1 and is raised to one past
/// every id produced, explicit or assigned.
///
/// NOTE(review): the counter saturates at `i64::MAX`, and an explicit id below the
/// counter is kept unchanged, so duplicate ids are possible. Confirm callers don't
/// rely on uniqueness.
fn collect_ids(records: &[HashMap<String, Value>], id_field: &str) -> Vec<i64> {
    let mut ids = Vec::with_capacity(records.len());
    let mut candidate = 1_i64;
    for record in records {
        let id = match record.get(id_field).and_then(Value::as_i64) {
            Some(explicit) => explicit,
            None => candidate,
        };
        candidate = candidate.max(id.saturating_add(1));
        ids.push(id);
    }
    ids
}
/// Number of batches of at most `batch_size` items needed to hold `total` items
/// (ceiling division).
///
/// `batch_size` must be non-zero whenever `total` is non-zero, otherwise the division
/// panics. The public entry points reject a zero batch size before calling this.
fn batch_count(total: usize, batch_size: usize) -> usize {
    match total {
        0 => 0,
        // `(n - 1) / size + 1` instead of `(n + size - 1) / size`: it cannot overflow.
        n => (n - 1) / batch_size + 1,
    }
}
/// Builds the string key used to compare `unique_by` values.
///
/// The key is the JSON serialization of `value`, so values of different JSON types
/// get different keys (e.g. the string `"1"` vs the number `1`). If serialization
/// fails, the key falls back to `"null"`.
fn unique_value_key(value: &Value) -> String {
    serde_json::to_string(value).unwrap_or_else(|_| "null".to_owned())
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use serde_json::{json, Value};

    use super::{InsertAllConfig, InsertAllError, insert_all, insert_all_with_config, upsert_all};

    /// Builds a record with an `email` field and, when given, an integer `id` field.
    fn record(id: Option<i64>, email: &str) -> HashMap<String, Value> {
        let mut fields = HashMap::new();
        fields.insert("email".to_owned(), json!(email));
        if let Some(id) = id {
            fields.insert("id".to_owned(), json!(id));
        }
        fields
    }

    /// Default config with only the batch size overridden.
    fn config_with_batch_size(batch_size: usize) -> InsertAllConfig {
        InsertAllConfig {
            batch_size,
            ..InsertAllConfig::default()
        }
    }

    #[test]
    fn insert_all_returns_inserted_count_and_ids() {
        let records = [record(Some(10), "a@example.com"), record(None, "b@example.com")];
        let result = insert_all(&records).expect("insert should succeed");
        assert_eq!(result.inserted_count, 2);
        assert_eq!(result.inserted_ids, vec![10, 11]);
        assert_eq!(result.batches, 1);
    }

    #[test]
    fn insert_all_respects_batch_size() {
        let records = [record(None, "a"), record(None, "b"), record(None, "c")];
        let config = config_with_batch_size(2);
        let result = insert_all_with_config(&records, &config).expect("insert should succeed");
        assert_eq!(result.batches, 2);
    }

    #[test]
    fn insert_all_rejects_zero_batch_size() {
        let config = config_with_batch_size(0);
        let outcome = insert_all_with_config(&[record(None, "a")], &config);
        assert_eq!(outcome, Err(InsertAllError::InvalidBatchSize));
    }

    #[test]
    fn upsert_all_inserts_unique_rows_and_updates_duplicates() {
        let records = [
            record(Some(1), "a@example.com"),
            record(Some(2), "b@example.com"),
            record(Some(3), "a@example.com"),
        ];
        let result = upsert_all(&records, "email").expect("upsert should succeed");
        assert_eq!(result.inserted_count, 2);
        assert_eq!(result.updated_count, 1);
        assert_eq!(result.inserted_ids, vec![1, 2]);
        assert_eq!(result.updated_ids, vec![1]);
    }

    #[test]
    fn upsert_all_requires_unique_by() {
        let outcome = upsert_all(&[record(None, "a")], "");
        assert_eq!(outcome, Err(InsertAllError::EmptyUniqueBy));
    }

    #[test]
    fn upsert_all_handles_missing_unique_field_as_distinct_rows() {
        let records: [HashMap<String, Value>; 2] = Default::default();
        let result = upsert_all(&records, "email").expect("upsert should succeed");
        assert_eq!(result.inserted_count, 2);
        assert_eq!(result.updated_count, 0);
    }
}