prax_query/
batch.rs

//! Batch query execution for combining multiple operations.
//!
//! This module provides utilities for executing multiple queries in a single
//! database round-trip, improving performance for bulk operations.
//!
//! # Example
//!
//! ```rust,ignore
//! use prax_query::batch::BatchBuilder;
//!
//! // Each `*_data` value is an owned `HashMap<String, FilterValue>`;
//! // `insert` takes it by value.
//! let batch = BatchBuilder::new()
//!     .insert("users", user1_data)
//!     .insert("users", user2_data)
//!     .insert("users", user3_data)
//!     .build();
//!
//! let results = engine.execute_batch(batch).await?;
//! ```
19
20use std::collections::HashMap;
21use crate::filter::FilterValue;
22use crate::sql::{DatabaseType, FastSqlBuilder, QueryCapacity};
23
24/// A batch of operations to execute together.
25#[derive(Debug, Clone)]
26pub struct Batch {
27    /// The operations in the batch.
28    operations: Vec<BatchOperation>,
29}
30
31impl Batch {
32    /// Create a new empty batch.
33    pub fn new() -> Self {
34        Self {
35            operations: Vec::new(),
36        }
37    }
38
39    /// Create a batch with pre-allocated capacity.
40    pub fn with_capacity(capacity: usize) -> Self {
41        Self {
42            operations: Vec::with_capacity(capacity),
43        }
44    }
45
46    /// Add an operation to the batch.
47    pub fn add(&mut self, op: BatchOperation) {
48        self.operations.push(op);
49    }
50
51    /// Get the operations in the batch.
52    pub fn operations(&self) -> &[BatchOperation] {
53        &self.operations
54    }
55
56    /// Get the number of operations.
57    pub fn len(&self) -> usize {
58        self.operations.len()
59    }
60
61    /// Check if the batch is empty.
62    pub fn is_empty(&self) -> bool {
63        self.operations.is_empty()
64    }
65
66    /// Convert the batch to a single SQL statement for databases that support it.
67    ///
68    /// This combines multiple INSERT statements into a single multi-row INSERT.
69    pub fn to_combined_sql(&self, db_type: DatabaseType) -> Option<(String, Vec<FilterValue>)> {
70        if self.operations.is_empty() {
71            return None;
72        }
73
74        // Group operations by type and table
75        let mut inserts: HashMap<&str, Vec<&BatchOperation>> = HashMap::new();
76        let mut other_ops = Vec::new();
77
78        for op in &self.operations {
79            match op {
80                BatchOperation::Insert { table, .. } => {
81                    inserts.entry(table.as_str()).or_default().push(op);
82                }
83                _ => other_ops.push(op),
84            }
85        }
86
87        // If we have non-insert operations or multiple tables, can't combine
88        if !other_ops.is_empty() || inserts.len() > 1 {
89            return None;
90        }
91
92        // Combine inserts for a single table
93        if let Some((table, ops)) = inserts.into_iter().next() {
94            return self.combine_inserts(table, &ops, db_type);
95        }
96
97        None
98    }
99
100    /// Combine multiple INSERT operations into a single multi-row INSERT.
101    fn combine_inserts(
102        &self,
103        table: &str,
104        ops: &[&BatchOperation],
105        db_type: DatabaseType,
106    ) -> Option<(String, Vec<FilterValue>)> {
107        if ops.is_empty() {
108            return None;
109        }
110
111        // Get columns from first insert
112        let first_columns: Vec<&str> = match &ops[0] {
113            BatchOperation::Insert { data, .. } => data.keys().map(String::as_str).collect(),
114            _ => return None,
115        };
116
117        // Verify all inserts have the same columns
118        for op in ops.iter().skip(1) {
119            if let BatchOperation::Insert { data, .. } = op {
120                let cols: Vec<&str> = data.keys().map(String::as_str).collect();
121                if cols.len() != first_columns.len() {
122                    return None;
123                }
124            }
125        }
126
127        // Build combined INSERT
128        let cols_per_row = first_columns.len();
129        let total_params = cols_per_row * ops.len();
130
131        let mut builder = FastSqlBuilder::with_capacity(
132            db_type,
133            QueryCapacity::Custom(64 + total_params * 8),
134        );
135
136        builder.push_str("INSERT INTO ");
137        builder.push_str(table);
138        builder.push_str(" (");
139
140        for (i, col) in first_columns.iter().enumerate() {
141            if i > 0 {
142                builder.push_str(", ");
143            }
144            builder.push_str(col);
145        }
146
147        builder.push_str(") VALUES ");
148
149        let mut all_params = Vec::with_capacity(total_params);
150
151        for (row_idx, op) in ops.iter().enumerate() {
152            if row_idx > 0 {
153                builder.push_str(", ");
154            }
155            builder.push_char('(');
156
157            if let BatchOperation::Insert { data, .. } = op {
158                for (col_idx, col) in first_columns.iter().enumerate() {
159                    if col_idx > 0 {
160                        builder.push_str(", ");
161                    }
162                    builder.bind(data.get(*col).cloned().unwrap_or(FilterValue::Null));
163                    if let Some(val) = data.get(*col) {
164                        all_params.push(val.clone());
165                    } else {
166                        all_params.push(FilterValue::Null);
167                    }
168                }
169            }
170
171            builder.push_char(')');
172        }
173
174        Some(builder.build())
175    }
176}
177
178impl Default for Batch {
179    fn default() -> Self {
180        Self::new()
181    }
182}
183
184/// A single operation in a batch.
185#[derive(Debug, Clone)]
186pub enum BatchOperation {
187    /// An INSERT operation.
188    Insert {
189        /// The table name.
190        table: String,
191        /// The data to insert.
192        data: HashMap<String, FilterValue>,
193    },
194    /// An UPDATE operation.
195    Update {
196        /// The table name.
197        table: String,
198        /// The filter for which rows to update.
199        filter: HashMap<String, FilterValue>,
200        /// The data to update.
201        data: HashMap<String, FilterValue>,
202    },
203    /// A DELETE operation.
204    Delete {
205        /// The table name.
206        table: String,
207        /// The filter for which rows to delete.
208        filter: HashMap<String, FilterValue>,
209    },
210    /// A raw SQL operation.
211    Raw {
212        /// The SQL query.
213        sql: String,
214        /// The parameters.
215        params: Vec<FilterValue>,
216    },
217}
218
219impl BatchOperation {
220    /// Create an INSERT operation.
221    pub fn insert(table: impl Into<String>, data: HashMap<String, FilterValue>) -> Self {
222        Self::Insert {
223            table: table.into(),
224            data,
225        }
226    }
227
228    /// Create an UPDATE operation.
229    pub fn update(
230        table: impl Into<String>,
231        filter: HashMap<String, FilterValue>,
232        data: HashMap<String, FilterValue>,
233    ) -> Self {
234        Self::Update {
235            table: table.into(),
236            filter,
237            data,
238        }
239    }
240
241    /// Create a DELETE operation.
242    pub fn delete(table: impl Into<String>, filter: HashMap<String, FilterValue>) -> Self {
243        Self::Delete {
244            table: table.into(),
245            filter,
246        }
247    }
248
249    /// Create a raw SQL operation.
250    pub fn raw(sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
251        Self::Raw {
252            sql: sql.into(),
253            params,
254        }
255    }
256}
257
258/// Builder for creating batches fluently.
259#[derive(Debug, Default)]
260pub struct BatchBuilder {
261    batch: Batch,
262}
263
264impl BatchBuilder {
265    /// Create a new batch builder.
266    pub fn new() -> Self {
267        Self {
268            batch: Batch::new(),
269        }
270    }
271
272    /// Create a builder with pre-allocated capacity.
273    pub fn with_capacity(capacity: usize) -> Self {
274        Self {
275            batch: Batch::with_capacity(capacity),
276        }
277    }
278
279    /// Add an INSERT operation.
280    pub fn insert(mut self, table: impl Into<String>, data: HashMap<String, FilterValue>) -> Self {
281        self.batch.add(BatchOperation::insert(table, data));
282        self
283    }
284
285    /// Add an UPDATE operation.
286    pub fn update(
287        mut self,
288        table: impl Into<String>,
289        filter: HashMap<String, FilterValue>,
290        data: HashMap<String, FilterValue>,
291    ) -> Self {
292        self.batch.add(BatchOperation::update(table, filter, data));
293        self
294    }
295
296    /// Add a DELETE operation.
297    pub fn delete(mut self, table: impl Into<String>, filter: HashMap<String, FilterValue>) -> Self {
298        self.batch.add(BatchOperation::delete(table, filter));
299        self
300    }
301
302    /// Add a raw SQL operation.
303    pub fn raw(mut self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
304        self.batch.add(BatchOperation::raw(sql, params));
305        self
306    }
307
308    /// Build the batch.
309    pub fn build(self) -> Batch {
310        self.batch
311    }
312}
313
314/// Result of a batch execution.
315#[derive(Debug, Clone)]
316pub struct BatchResult {
317    /// Results for each operation.
318    pub results: Vec<OperationResult>,
319    /// Total rows affected across all operations.
320    pub total_affected: u64,
321}
322
323impl BatchResult {
324    /// Create a new batch result.
325    pub fn new(results: Vec<OperationResult>) -> Self {
326        let total_affected = results.iter().map(|r| r.rows_affected).sum();
327        Self {
328            results,
329            total_affected,
330        }
331    }
332
333    /// Get the number of operations.
334    pub fn len(&self) -> usize {
335        self.results.len()
336    }
337
338    /// Check if empty.
339    pub fn is_empty(&self) -> bool {
340        self.results.is_empty()
341    }
342
343    /// Check if all operations succeeded.
344    pub fn all_succeeded(&self) -> bool {
345        self.results.iter().all(|r| r.success)
346    }
347}
348
/// Outcome of one operation within a batch.
#[derive(Debug, Clone)]
pub struct OperationResult {
    /// `true` when the operation succeeded.
    pub success: bool,
    /// Number of rows the operation affected.
    pub rows_affected: u64,
    /// Error message for a failed operation; `None` otherwise.
    pub error: Option<String>,
}

impl OperationResult {
    /// Build the outcome of a successful operation that affected
    /// `rows_affected` rows. The `error` field is left as `None`.
    pub fn success(rows_affected: u64) -> Self {
        OperationResult {
            error: None,
            rows_affected,
            success: true,
        }
    }

    /// Build the outcome of a failed operation described by `error`.
    ///
    /// The affected-row count is set to `0`.
    pub fn failure(error: impl Into<String>) -> Self {
        let message: String = error.into();
        OperationResult {
            error: Some(message),
            rows_affected: 0,
            success: false,
        }
    }
}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a row payload from `(column, value)` pairs.
    fn row(pairs: Vec<(&str, FilterValue)>) -> HashMap<String, FilterValue> {
        pairs
            .into_iter()
            .map(|(column, value)| (column.to_string(), value))
            .collect()
    }

    #[test]
    fn test_batch_builder() {
        let batch = BatchBuilder::new()
            .insert("users", row(vec![("name", FilterValue::String("Alice".into()))]))
            .insert("users", row(vec![("name", FilterValue::String("Bob".into()))]))
            .build();

        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
        assert_eq!(batch.operations().len(), 2);
    }

    #[test]
    fn test_combine_inserts_postgres() {
        let batch = BatchBuilder::new()
            .insert(
                "users",
                row(vec![
                    ("name", FilterValue::String("Alice".into())),
                    ("age", FilterValue::Int(30)),
                ]),
            )
            .insert(
                "users",
                row(vec![
                    ("name", FilterValue::String("Bob".into())),
                    ("age", FilterValue::Int(25)),
                ]),
            )
            .build();

        let result = batch.to_combined_sql(DatabaseType::PostgreSQL);
        assert!(result.is_some());

        let (sql, _) = result.unwrap();
        assert!(sql.starts_with("INSERT INTO users"));
        assert!(sql.contains("VALUES"));
        // Column order is not asserted; only that both columns are listed.
        assert!(sql.contains("name"));
        assert!(sql.contains("age"));
    }

    #[test]
    fn test_empty_batch_not_combined() {
        let batch = Batch::new();
        assert!(batch.is_empty());
        assert!(batch.to_combined_sql(DatabaseType::PostgreSQL).is_none());
    }

    #[test]
    fn test_mixed_operations_not_combined() {
        let batch = BatchBuilder::new()
            .insert("users", row(vec![("age", FilterValue::Int(30))]))
            .delete("users", row(vec![("age", FilterValue::Int(25))]))
            .build();

        assert!(batch.to_combined_sql(DatabaseType::PostgreSQL).is_none());
    }

    #[test]
    fn test_multiple_tables_not_combined() {
        let batch = BatchBuilder::new()
            .insert("users", row(vec![("age", FilterValue::Int(30))]))
            .insert("accounts", row(vec![("age", FilterValue::Int(25))]))
            .build();

        assert!(batch.to_combined_sql(DatabaseType::PostgreSQL).is_none());
    }

    #[test]
    fn test_different_column_counts_not_combined() {
        let batch = BatchBuilder::new()
            .insert(
                "users",
                row(vec![
                    ("name", FilterValue::String("Alice".into())),
                    ("age", FilterValue::Int(30)),
                ]),
            )
            .insert("users", row(vec![("name", FilterValue::String("Bob".into()))]))
            .build();

        assert!(batch.to_combined_sql(DatabaseType::PostgreSQL).is_none());
    }

    #[test]
    fn test_batch_result() {
        let results = vec![
            OperationResult::success(1),
            OperationResult::success(1),
            OperationResult::success(1),
        ];

        let batch_result = BatchResult::new(results);
        assert_eq!(batch_result.len(), 3);
        assert!(!batch_result.is_empty());
        assert_eq!(batch_result.total_affected, 3);
        assert!(batch_result.all_succeeded());
    }

    #[test]
    fn test_batch_result_with_failure() {
        let results = vec![
            OperationResult::success(1),
            OperationResult::failure("constraint violation"),
            OperationResult::success(1),
        ];

        let batch_result = BatchResult::new(results);
        assert_eq!(batch_result.total_affected, 2);
        assert!(!batch_result.all_succeeded());
        assert_eq!(
            batch_result.results[1].error,
            Some("constraint violation".to_string())
        );
    }

    #[test]
    fn test_empty_batch_result() {
        let batch_result = BatchResult::new(Vec::new());
        assert!(batch_result.is_empty());
        assert_eq!(batch_result.total_affected, 0);
        assert!(batch_result.all_succeeded());
    }
}
449