use crate::filter::FilterValue;
use crate::sql::{DatabaseType, FastSqlBuilder, QueryCapacity};
use std::collections::HashMap;
/// An ordered collection of write operations that are intended to be
/// executed together (and possibly collapsed into a single statement).
#[derive(Debug, Clone)]
pub struct Batch {
    // Operations in insertion order.
    operations: Vec<BatchOperation>,
}
impl Batch {
    /// Creates an empty batch.
    pub fn new() -> Self {
        Self {
            operations: Vec::new(),
        }
    }

    /// Creates an empty batch with room reserved for `capacity` operations.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            operations: Vec::with_capacity(capacity),
        }
    }

    /// Appends an operation to the batch.
    pub fn add(&mut self, op: BatchOperation) {
        self.operations.push(op);
    }

    /// Returns the queued operations in insertion order.
    pub fn operations(&self) -> &[BatchOperation] {
        &self.operations
    }

    /// Number of queued operations.
    pub fn len(&self) -> usize {
        self.operations.len()
    }

    /// True when no operations are queued.
    pub fn is_empty(&self) -> bool {
        self.operations.is_empty()
    }

    /// Attempts to collapse the whole batch into a single SQL statement.
    ///
    /// Only batches consisting solely of `Insert` operations that all target
    /// the same table (with identical column sets) can be combined into one
    /// multi-row `INSERT ... VALUES (...), (...)`. Returns `None` when the
    /// batch is empty or cannot be combined; callers should then execute the
    /// operations individually.
    pub fn to_combined_sql(&self, db_type: DatabaseType) -> Option<(String, Vec<FilterValue>)> {
        if self.operations.is_empty() {
            return None;
        }
        // Bucket inserts by table; any non-insert operation disqualifies combining.
        let mut inserts: HashMap<&str, Vec<&BatchOperation>> = HashMap::new();
        let mut other_ops = Vec::new();
        for op in &self.operations {
            match op {
                BatchOperation::Insert { table, .. } => {
                    inserts.entry(table.as_str()).or_default().push(op);
                }
                _ => other_ops.push(op),
            }
        }
        if !other_ops.is_empty() || inserts.len() > 1 {
            return None;
        }
        if let Some((table, ops)) = inserts.into_iter().next() {
            return self.combine_inserts(table, &ops, db_type);
        }
        None
    }

    /// Builds a single multi-row INSERT for `ops`, which must all be
    /// `Insert` operations targeting `table`.
    ///
    /// Returns `None` when the rows do not share an identical column set —
    /// combining such rows would silently bind NULL for missing columns and
    /// drop the extras.
    fn combine_inserts(
        &self,
        table: &str,
        ops: &[&BatchOperation],
        db_type: DatabaseType,
    ) -> Option<(String, Vec<FilterValue>)> {
        if ops.is_empty() {
            return None;
        }
        let mut first_columns: Vec<&str> = match &ops[0] {
            BatchOperation::Insert { data, .. } => data.keys().map(String::as_str).collect(),
            _ => return None,
        };
        // HashMap key order is arbitrary; sort so the generated SQL is
        // deterministic across runs (also friendlier to statement caches).
        first_columns.sort_unstable();
        // Every subsequent row must carry exactly the same column set as the
        // first row. The previous length-only check let rows with *different*
        // columns of the same count slip through.
        for op in ops.iter().skip(1) {
            match op {
                BatchOperation::Insert { data, .. } => {
                    if data.len() != first_columns.len()
                        || !first_columns.iter().all(|col| data.contains_key(*col))
                    {
                        return None;
                    }
                }
                _ => return None,
            }
        }
        let cols_per_row = first_columns.len();
        let total_params = cols_per_row * ops.len();
        let mut builder =
            FastSqlBuilder::with_capacity(db_type, QueryCapacity::Custom(64 + total_params * 8));
        builder.push_str("INSERT INTO ");
        builder.push_str(table);
        builder.push_str(" (");
        for (i, col) in first_columns.iter().enumerate() {
            if i > 0 {
                builder.push_str(", ");
            }
            builder.push_str(col);
        }
        builder.push_str(") VALUES ");
        for (row_idx, op) in ops.iter().enumerate() {
            if row_idx > 0 {
                builder.push_str(", ");
            }
            builder.push_char('(');
            if let BatchOperation::Insert { data, .. } = op {
                for (col_idx, col) in first_columns.iter().enumerate() {
                    if col_idx > 0 {
                        builder.push_str(", ");
                    }
                    // Column presence was verified above; Null is a defensive
                    // fallback only. Params travel with the builder via `bind`,
                    // so no separate param vec is needed — the old `all_params`
                    // was built and then discarded.
                    builder.bind(data.get(*col).cloned().unwrap_or(FilterValue::Null));
                }
            }
            builder.push_char(')');
        }
        // `build` yields (sql, bound params).
        Some(builder.build())
    }
}
impl Default for Batch {
fn default() -> Self {
Self::new()
}
}
/// A single write operation queued in a [`Batch`].
#[derive(Debug, Clone)]
pub enum BatchOperation {
    /// Insert one row of `data` (column name -> value) into `table`.
    Insert {
        table: String,
        data: HashMap<String, FilterValue>,
    },
    /// Update rows in `table` matching `filter` with the values in `data`.
    Update {
        table: String,
        filter: HashMap<String, FilterValue>,
        data: HashMap<String, FilterValue>,
    },
    /// Delete rows in `table` matching `filter`.
    Delete {
        table: String,
        filter: HashMap<String, FilterValue>,
    },
    /// A raw SQL statement with positional bound `params`.
    Raw {
        sql: String,
        params: Vec<FilterValue>,
    },
}
impl BatchOperation {
    /// Builds an insert of `data` into `table`.
    pub fn insert(table: impl Into<String>, data: HashMap<String, FilterValue>) -> Self {
        let table = table.into();
        Self::Insert { table, data }
    }

    /// Builds an update of rows in `table` matching `filter` with `data`.
    pub fn update(
        table: impl Into<String>,
        filter: HashMap<String, FilterValue>,
        data: HashMap<String, FilterValue>,
    ) -> Self {
        let table = table.into();
        Self::Update {
            table,
            filter,
            data,
        }
    }

    /// Builds a delete of rows in `table` matching `filter`.
    pub fn delete(table: impl Into<String>, filter: HashMap<String, FilterValue>) -> Self {
        let table = table.into();
        Self::Delete { table, filter }
    }

    /// Builds a raw SQL operation with positional `params`.
    pub fn raw(sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
        let sql = sql.into();
        Self::Raw { sql, params }
    }
}
/// Fluent, chainable builder that accumulates operations into a [`Batch`].
#[derive(Debug, Default)]
pub struct BatchBuilder {
    // Batch under construction.
    batch: Batch,
}
impl BatchBuilder {
pub fn new() -> Self {
Self {
batch: Batch::new(),
}
}
pub fn with_capacity(capacity: usize) -> Self {
Self {
batch: Batch::with_capacity(capacity),
}
}
pub fn insert(mut self, table: impl Into<String>, data: HashMap<String, FilterValue>) -> Self {
self.batch.add(BatchOperation::insert(table, data));
self
}
pub fn update(
mut self,
table: impl Into<String>,
filter: HashMap<String, FilterValue>,
data: HashMap<String, FilterValue>,
) -> Self {
self.batch.add(BatchOperation::update(table, filter, data));
self
}
pub fn delete(
mut self,
table: impl Into<String>,
filter: HashMap<String, FilterValue>,
) -> Self {
self.batch.add(BatchOperation::delete(table, filter));
self
}
pub fn raw(mut self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
self.batch.add(BatchOperation::raw(sql, params));
self
}
pub fn build(self) -> Batch {
self.batch
}
}
/// Aggregated outcome of executing a [`Batch`].
#[derive(Debug, Clone)]
pub struct BatchResult {
    // One result per executed operation, in batch order.
    pub results: Vec<OperationResult>,
    // Sum of `rows_affected` over all results (failures contribute 0).
    pub total_affected: u64,
}
impl BatchResult {
    /// Wraps per-operation results, pre-computing the total of affected rows.
    pub fn new(results: Vec<OperationResult>) -> Self {
        let total_affected: u64 = results.iter().map(|r| r.rows_affected).sum();
        Self {
            results,
            total_affected,
        }
    }

    /// Number of operation results.
    pub fn len(&self) -> usize {
        self.results.len()
    }

    /// True when there are no results.
    pub fn is_empty(&self) -> bool {
        self.results.is_empty()
    }

    /// True when every operation reported success.
    pub fn all_succeeded(&self) -> bool {
        !self.results.iter().any(|r| !r.success)
    }
}
/// Outcome of a single operation within a batch.
#[derive(Debug, Clone)]
pub struct OperationResult {
    // Whether the operation completed without error.
    pub success: bool,
    // Rows changed by the operation; 0 on failure.
    pub rows_affected: u64,
    // Error message when `success` is false, `None` otherwise.
    pub error: Option<String>,
}
impl OperationResult {
    /// A successful operation that affected `rows_affected` rows.
    pub fn success(rows_affected: u64) -> Self {
        Self {
            success: true,
            rows_affected,
            error: None,
        }
    }

    /// A failed operation; `rows_affected` is reported as zero.
    pub fn failure(error: impl Into<String>) -> Self {
        let error = Some(error.into());
        Self {
            success: false,
            rows_affected: 0,
            error,
        }
    }
}
/// An ordered list of SQL queries to be sent together in one round trip.
#[derive(Debug, Clone)]
pub struct Pipeline {
    // Queries in submission order.
    queries: Vec<PipelineQuery>,
}
impl Pipeline {
    /// Creates an empty pipeline.
    pub fn new() -> Self {
        Self {
            queries: Vec::new(),
        }
    }

    /// Creates an empty pipeline reserving room for `capacity` queries.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            queries: Vec::with_capacity(capacity),
        }
    }

    /// Appends a query whose result rows the caller intends to read.
    pub fn push(&mut self, sql: impl Into<String>, params: Vec<FilterValue>) {
        self.enqueue(sql.into(), params, true);
    }

    /// Appends a statement executed for its side effects only.
    pub fn push_execute(&mut self, sql: impl Into<String>, params: Vec<FilterValue>) {
        self.enqueue(sql.into(), params, false);
    }

    /// The queued queries, in insertion order.
    pub fn queries(&self) -> &[PipelineQuery] {
        &self.queries
    }

    /// Number of queued queries.
    pub fn len(&self) -> usize {
        self.queries.len()
    }

    /// True when nothing is queued.
    pub fn is_empty(&self) -> bool {
        self.queries.is_empty()
    }

    // Records one query together with its row-expectation flag.
    fn enqueue(&mut self, sql: String, params: Vec<FilterValue>, expect_rows: bool) {
        self.queries.push(PipelineQuery {
            sql,
            params,
            expect_rows,
        });
    }
}
impl Default for Pipeline {
fn default() -> Self {
Self::new()
}
}
/// One query within a [`Pipeline`].
#[derive(Debug, Clone)]
pub struct PipelineQuery {
    // The SQL text to execute.
    pub sql: String,
    // Positional bound parameters.
    pub params: Vec<FilterValue>,
    // True when the caller will read result rows; false for execute-only.
    pub expect_rows: bool,
}
/// Fluent, chainable builder that accumulates queries into a [`Pipeline`].
#[derive(Debug, Clone)]
pub struct PipelineBuilder {
    // Pipeline under construction.
    pipeline: Pipeline,
}
impl PipelineBuilder {
    /// Starts with an empty pipeline.
    pub fn new() -> Self {
        let pipeline = Pipeline::new();
        Self { pipeline }
    }

    /// Starts with a pipeline reserving room for `capacity` queries.
    pub fn with_capacity(capacity: usize) -> Self {
        let pipeline = Pipeline::with_capacity(capacity);
        Self { pipeline }
    }

    /// Adds a row-returning query.
    pub fn query(mut self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
        self.pipeline.push(sql, params);
        self
    }

    /// Adds a side-effect-only statement.
    pub fn execute(mut self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
        self.pipeline.push_execute(sql, params);
        self
    }

    /// Consumes the builder, yielding the pipeline.
    pub fn build(self) -> Pipeline {
        self.pipeline
    }
}
impl Default for PipelineBuilder {
fn default() -> Self {
Self::new()
}
}
/// Aggregated outcome of executing a [`Pipeline`].
#[derive(Debug)]
pub struct PipelineResult {
    // One result per executed query, in pipeline order.
    pub query_results: Vec<QueryResult>,
}
/// Outcome of a single query within a [`PipelineResult`].
#[derive(Debug)]
pub enum QueryResult {
    /// The query returned rows; only the row count is retained here.
    Rows {
        count: usize,
    },
    /// The statement executed without returning rows.
    Executed {
        rows_affected: u64,
    },
    /// The query failed with the given error message.
    Error {
        message: String,
    },
}
impl PipelineResult {
pub fn new(query_results: Vec<QueryResult>) -> Self {
Self { query_results }
}
pub fn all_succeeded(&self) -> bool {
self.query_results
.iter()
.all(|r| !matches!(r, QueryResult::Error { .. }))
}
pub fn first_error(&self) -> Option<&str> {
self.query_results.iter().find_map(|r| {
if let QueryResult::Error { message } = r {
Some(message.as_str())
} else {
None
}
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder accumulates operations in insertion order.
    #[test]
    fn test_batch_builder() {
        let mut data1 = HashMap::new();
        data1.insert("name".to_string(), FilterValue::String("Alice".into()));
        let mut data2 = HashMap::new();
        data2.insert("name".to_string(), FilterValue::String("Bob".into()));
        let batch = BatchBuilder::new()
            .insert("users", data1)
            .insert("users", data2)
            .build();
        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }

    /// Same-table inserts with matching columns combine into one statement.
    #[test]
    fn test_combine_inserts_postgres() {
        let mut data1 = HashMap::new();
        data1.insert("name".to_string(), FilterValue::String("Alice".into()));
        data1.insert("age".to_string(), FilterValue::Int(30));
        let mut data2 = HashMap::new();
        data2.insert("name".to_string(), FilterValue::String("Bob".into()));
        data2.insert("age".to_string(), FilterValue::Int(25));
        let batch = BatchBuilder::new()
            .insert("users", data1)
            .insert("users", data2)
            .build();
        let result = batch.to_combined_sql(DatabaseType::PostgreSQL);
        assert!(result.is_some());
        let (sql, _) = result.unwrap();
        assert!(sql.starts_with("INSERT INTO users"));
        assert!(sql.contains("VALUES"));
    }

    /// An empty batch can never be combined.
    #[test]
    fn test_combined_sql_empty_batch() {
        let batch = Batch::new();
        assert!(batch.to_combined_sql(DatabaseType::PostgreSQL).is_none());
    }

    /// Mixing inserts with other operation kinds disables combining.
    #[test]
    fn test_combined_sql_rejects_mixed_operations() {
        let mut data = HashMap::new();
        data.insert("name".to_string(), FilterValue::String("Alice".into()));
        let mut filter = HashMap::new();
        filter.insert("id".to_string(), FilterValue::Int(1));
        let batch = BatchBuilder::new()
            .insert("users", data)
            .delete("users", filter)
            .build();
        assert!(batch.to_combined_sql(DatabaseType::PostgreSQL).is_none());
    }

    /// Inserts targeting different tables cannot share one statement.
    #[test]
    fn test_combined_sql_rejects_multiple_tables() {
        let mut data1 = HashMap::new();
        data1.insert("name".to_string(), FilterValue::String("Alice".into()));
        let mut data2 = HashMap::new();
        data2.insert("name".to_string(), FilterValue::String("Bob".into()));
        let batch = BatchBuilder::new()
            .insert("users", data1)
            .insert("accounts", data2)
            .build();
        assert!(batch.to_combined_sql(DatabaseType::PostgreSQL).is_none());
    }

    /// total_affected sums rows across successful operations.
    #[test]
    fn test_batch_result() {
        let results = vec![
            OperationResult::success(1),
            OperationResult::success(1),
            OperationResult::success(1),
        ];
        let batch_result = BatchResult::new(results);
        assert_eq!(batch_result.total_affected, 3);
        assert!(batch_result.all_succeeded());
    }

    /// Failures contribute zero rows and flip all_succeeded.
    #[test]
    fn test_batch_result_with_failure() {
        let results = vec![
            OperationResult::success(1),
            OperationResult::failure("constraint violation"),
            OperationResult::success(1),
        ];
        let batch_result = BatchResult::new(results);
        assert_eq!(batch_result.total_affected, 2);
        assert!(!batch_result.all_succeeded());
    }

    /// The pipeline builder preserves order and the expect_rows flag.
    #[test]
    fn test_pipeline_builder() {
        let pipeline = PipelineBuilder::new()
            .query("SELECT 1", vec![])
            .execute("DELETE FROM t", vec![])
            .build();
        assert_eq!(pipeline.len(), 2);
        assert!(pipeline.queries()[0].expect_rows);
        assert!(!pipeline.queries()[1].expect_rows);
    }

    /// first_error surfaces the earliest failing query's message.
    #[test]
    fn test_pipeline_result_errors() {
        let result = PipelineResult::new(vec![
            QueryResult::Rows { count: 1 },
            QueryResult::Error {
                message: "boom".to_string(),
            },
            QueryResult::Executed { rows_affected: 2 },
        ]);
        assert!(!result.all_succeeded());
        assert_eq!(result.first_error(), Some("boom"));
    }
}