prax_query/
batch.rs

1//! Batch query execution for combining multiple operations.
2//!
3//! This module provides utilities for executing multiple queries in a single
4//! database round-trip, improving performance for bulk operations.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use prax_query::batch::BatchBuilder;
10//!
11//! let batch = BatchBuilder::new()
//!     .insert("users", user1_data)
//!     .insert("users", user2_data)
//!     .insert("users", user3_data)
15//!     .build();
16//!
17//! let results = engine.execute_batch(batch).await?;
18//! ```
19
20use crate::filter::FilterValue;
21use crate::sql::{DatabaseType, FastSqlBuilder, QueryCapacity};
22use std::collections::HashMap;
23
24/// A batch of operations to execute together.
25#[derive(Debug, Clone)]
26pub struct Batch {
27    /// The operations in the batch.
28    operations: Vec<BatchOperation>,
29}
30
31impl Batch {
32    /// Create a new empty batch.
33    pub fn new() -> Self {
34        Self {
35            operations: Vec::new(),
36        }
37    }
38
39    /// Create a batch with pre-allocated capacity.
40    pub fn with_capacity(capacity: usize) -> Self {
41        Self {
42            operations: Vec::with_capacity(capacity),
43        }
44    }
45
46    /// Add an operation to the batch.
47    pub fn add(&mut self, op: BatchOperation) {
48        self.operations.push(op);
49    }
50
51    /// Get the operations in the batch.
52    pub fn operations(&self) -> &[BatchOperation] {
53        &self.operations
54    }
55
56    /// Get the number of operations.
57    pub fn len(&self) -> usize {
58        self.operations.len()
59    }
60
61    /// Check if the batch is empty.
62    pub fn is_empty(&self) -> bool {
63        self.operations.is_empty()
64    }
65
66    /// Convert the batch to a single SQL statement for databases that support it.
67    ///
68    /// This combines multiple INSERT statements into a single multi-row INSERT.
69    pub fn to_combined_sql(&self, db_type: DatabaseType) -> Option<(String, Vec<FilterValue>)> {
70        if self.operations.is_empty() {
71            return None;
72        }
73
74        // Group operations by type and table
75        let mut inserts: HashMap<&str, Vec<&BatchOperation>> = HashMap::new();
76        let mut other_ops = Vec::new();
77
78        for op in &self.operations {
79            match op {
80                BatchOperation::Insert { table, .. } => {
81                    inserts.entry(table.as_str()).or_default().push(op);
82                }
83                _ => other_ops.push(op),
84            }
85        }
86
87        // If we have non-insert operations or multiple tables, can't combine
88        if !other_ops.is_empty() || inserts.len() > 1 {
89            return None;
90        }
91
92        // Combine inserts for a single table
93        if let Some((table, ops)) = inserts.into_iter().next() {
94            return self.combine_inserts(table, &ops, db_type);
95        }
96
97        None
98    }
99
100    /// Combine multiple INSERT operations into a single multi-row INSERT.
101    fn combine_inserts(
102        &self,
103        table: &str,
104        ops: &[&BatchOperation],
105        db_type: DatabaseType,
106    ) -> Option<(String, Vec<FilterValue>)> {
107        if ops.is_empty() {
108            return None;
109        }
110
111        // Get columns from first insert
112        let first_columns: Vec<&str> = match &ops[0] {
113            BatchOperation::Insert { data, .. } => data.keys().map(String::as_str).collect(),
114            _ => return None,
115        };
116
117        // Verify all inserts have the same columns
118        for op in ops.iter().skip(1) {
119            if let BatchOperation::Insert { data, .. } = op {
120                let cols: Vec<&str> = data.keys().map(String::as_str).collect();
121                if cols.len() != first_columns.len() {
122                    return None;
123                }
124            }
125        }
126
127        // Build combined INSERT
128        let cols_per_row = first_columns.len();
129        let total_params = cols_per_row * ops.len();
130
131        let mut builder =
132            FastSqlBuilder::with_capacity(db_type, QueryCapacity::Custom(64 + total_params * 8));
133
134        builder.push_str("INSERT INTO ");
135        builder.push_str(table);
136        builder.push_str(" (");
137
138        for (i, col) in first_columns.iter().enumerate() {
139            if i > 0 {
140                builder.push_str(", ");
141            }
142            builder.push_str(col);
143        }
144
145        builder.push_str(") VALUES ");
146
147        let mut all_params = Vec::with_capacity(total_params);
148
149        for (row_idx, op) in ops.iter().enumerate() {
150            if row_idx > 0 {
151                builder.push_str(", ");
152            }
153            builder.push_char('(');
154
155            if let BatchOperation::Insert { data, .. } = op {
156                for (col_idx, col) in first_columns.iter().enumerate() {
157                    if col_idx > 0 {
158                        builder.push_str(", ");
159                    }
160                    builder.bind(data.get(*col).cloned().unwrap_or(FilterValue::Null));
161                    if let Some(val) = data.get(*col) {
162                        all_params.push(val.clone());
163                    } else {
164                        all_params.push(FilterValue::Null);
165                    }
166                }
167            }
168
169            builder.push_char(')');
170        }
171
172        Some(builder.build())
173    }
174}
175
176impl Default for Batch {
177    fn default() -> Self {
178        Self::new()
179    }
180}
181
182/// A single operation in a batch.
183#[derive(Debug, Clone)]
184pub enum BatchOperation {
185    /// An INSERT operation.
186    Insert {
187        /// The table name.
188        table: String,
189        /// The data to insert.
190        data: HashMap<String, FilterValue>,
191    },
192    /// An UPDATE operation.
193    Update {
194        /// The table name.
195        table: String,
196        /// The filter for which rows to update.
197        filter: HashMap<String, FilterValue>,
198        /// The data to update.
199        data: HashMap<String, FilterValue>,
200    },
201    /// A DELETE operation.
202    Delete {
203        /// The table name.
204        table: String,
205        /// The filter for which rows to delete.
206        filter: HashMap<String, FilterValue>,
207    },
208    /// A raw SQL operation.
209    Raw {
210        /// The SQL query.
211        sql: String,
212        /// The parameters.
213        params: Vec<FilterValue>,
214    },
215}
216
217impl BatchOperation {
218    /// Create an INSERT operation.
219    pub fn insert(table: impl Into<String>, data: HashMap<String, FilterValue>) -> Self {
220        Self::Insert {
221            table: table.into(),
222            data,
223        }
224    }
225
226    /// Create an UPDATE operation.
227    pub fn update(
228        table: impl Into<String>,
229        filter: HashMap<String, FilterValue>,
230        data: HashMap<String, FilterValue>,
231    ) -> Self {
232        Self::Update {
233            table: table.into(),
234            filter,
235            data,
236        }
237    }
238
239    /// Create a DELETE operation.
240    pub fn delete(table: impl Into<String>, filter: HashMap<String, FilterValue>) -> Self {
241        Self::Delete {
242            table: table.into(),
243            filter,
244        }
245    }
246
247    /// Create a raw SQL operation.
248    pub fn raw(sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
249        Self::Raw {
250            sql: sql.into(),
251            params,
252        }
253    }
254}
255
256/// Builder for creating batches fluently.
257#[derive(Debug, Default)]
258pub struct BatchBuilder {
259    batch: Batch,
260}
261
262impl BatchBuilder {
263    /// Create a new batch builder.
264    pub fn new() -> Self {
265        Self {
266            batch: Batch::new(),
267        }
268    }
269
270    /// Create a builder with pre-allocated capacity.
271    pub fn with_capacity(capacity: usize) -> Self {
272        Self {
273            batch: Batch::with_capacity(capacity),
274        }
275    }
276
277    /// Add an INSERT operation.
278    pub fn insert(mut self, table: impl Into<String>, data: HashMap<String, FilterValue>) -> Self {
279        self.batch.add(BatchOperation::insert(table, data));
280        self
281    }
282
283    /// Add an UPDATE operation.
284    pub fn update(
285        mut self,
286        table: impl Into<String>,
287        filter: HashMap<String, FilterValue>,
288        data: HashMap<String, FilterValue>,
289    ) -> Self {
290        self.batch.add(BatchOperation::update(table, filter, data));
291        self
292    }
293
294    /// Add a DELETE operation.
295    pub fn delete(
296        mut self,
297        table: impl Into<String>,
298        filter: HashMap<String, FilterValue>,
299    ) -> Self {
300        self.batch.add(BatchOperation::delete(table, filter));
301        self
302    }
303
304    /// Add a raw SQL operation.
305    pub fn raw(mut self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
306        self.batch.add(BatchOperation::raw(sql, params));
307        self
308    }
309
310    /// Build the batch.
311    pub fn build(self) -> Batch {
312        self.batch
313    }
314}
315
/// Aggregate outcome of executing a batch.
#[derive(Debug, Clone)]
pub struct BatchResult {
    /// One outcome per executed operation.
    pub results: Vec<OperationResult>,
    /// Sum of `rows_affected` over every entry in `results`.
    pub total_affected: u64,
}

impl BatchResult {
    /// Wrap per-operation outcomes, computing `total_affected` from them.
    pub fn new(results: Vec<OperationResult>) -> Self {
        let mut total_affected: u64 = 0;
        for outcome in &results {
            total_affected += outcome.rows_affected;
        }
        BatchResult {
            results,
            total_affected,
        }
    }

    /// Number of recorded operation outcomes.
    pub fn len(&self) -> usize {
        self.results.len()
    }

    /// `true` when no outcomes were recorded.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// `true` unless at least one outcome is a failure (vacuously `true` when empty).
    pub fn all_succeeded(&self) -> bool {
        !self.results.iter().any(|outcome| !outcome.success)
    }
}

/// Outcome of one operation within a batch.
#[derive(Debug, Clone)]
pub struct OperationResult {
    /// `true` if the operation completed without error.
    pub success: bool,
    /// Rows touched by the operation (always 0 for failures built via `failure`).
    pub rows_affected: u64,
    /// Failure description; `None` on success.
    pub error: Option<String>,
}

impl OperationResult {
    /// Outcome for an operation that succeeded and touched `rows_affected` rows.
    pub fn success(rows_affected: u64) -> Self {
        OperationResult {
            success: true,
            rows_affected,
            error: None,
        }
    }

    /// Outcome for an operation that failed with the given message.
    pub fn failure(error: impl Into<String>) -> Self {
        let message: String = error.into();
        OperationResult {
            success: false,
            rows_affected: 0,
            error: Some(message),
        }
    }
}
381
382// ============================================================================
383// Pipeline Execution
384// ============================================================================
385
386/// A query pipeline for executing multiple queries efficiently.
387///
388/// Pipelines combine multiple queries and execute them with minimal
389/// round-trips to the database. This is especially useful for:
390///
391/// - Fetching a parent record and its relations
392/// - Performing multiple inserts in sequence
393/// - Complex transactions with multiple operations
394///
395/// # Example
396///
397/// ```rust,ignore
398/// use prax_query::batch::Pipeline;
399///
400/// let pipeline = Pipeline::new()
401///     .query("SELECT * FROM users WHERE id = $1", vec![id.into()])
402///     .query("SELECT * FROM posts WHERE author_id = $1", vec![id.into()])
403///     .build();
404///
405/// let results = engine.execute_pipeline(pipeline).await?;
406/// ```
407#[derive(Debug, Clone)]
408pub struct Pipeline {
409    /// Queries in the pipeline.
410    queries: Vec<PipelineQuery>,
411}
412
413impl Pipeline {
414    /// Create a new empty pipeline.
415    pub fn new() -> Self {
416        Self {
417            queries: Vec::new(),
418        }
419    }
420
421    /// Create a pipeline with pre-allocated capacity.
422    pub fn with_capacity(capacity: usize) -> Self {
423        Self {
424            queries: Vec::with_capacity(capacity),
425        }
426    }
427
428    /// Add a query to the pipeline.
429    pub fn push(&mut self, sql: impl Into<String>, params: Vec<FilterValue>) {
430        self.queries.push(PipelineQuery {
431            sql: sql.into(),
432            params,
433            expect_rows: true,
434        });
435    }
436
437    /// Add an execute-only query (no result rows expected).
438    pub fn push_execute(&mut self, sql: impl Into<String>, params: Vec<FilterValue>) {
439        self.queries.push(PipelineQuery {
440            sql: sql.into(),
441            params,
442            expect_rows: false,
443        });
444    }
445
446    /// Get the queries.
447    pub fn queries(&self) -> &[PipelineQuery] {
448        &self.queries
449    }
450
451    /// Get the number of queries.
452    pub fn len(&self) -> usize {
453        self.queries.len()
454    }
455
456    /// Check if empty.
457    pub fn is_empty(&self) -> bool {
458        self.queries.is_empty()
459    }
460}
461
462impl Default for Pipeline {
463    fn default() -> Self {
464        Self::new()
465    }
466}
467
468/// A single query in a pipeline.
469#[derive(Debug, Clone)]
470pub struct PipelineQuery {
471    /// The SQL query.
472    pub sql: String,
473    /// Query parameters.
474    pub params: Vec<FilterValue>,
475    /// Whether this query returns rows.
476    pub expect_rows: bool,
477}
478
479/// Builder for creating pipelines.
480#[derive(Debug, Clone)]
481pub struct PipelineBuilder {
482    pipeline: Pipeline,
483}
484
485impl PipelineBuilder {
486    /// Create a new pipeline builder.
487    pub fn new() -> Self {
488        Self {
489            pipeline: Pipeline::new(),
490        }
491    }
492
493    /// Create a builder with pre-allocated capacity.
494    pub fn with_capacity(capacity: usize) -> Self {
495        Self {
496            pipeline: Pipeline::with_capacity(capacity),
497        }
498    }
499
500    /// Add a SELECT query.
501    pub fn query(mut self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
502        self.pipeline.push(sql, params);
503        self
504    }
505
506    /// Add an execute-only query (INSERT/UPDATE/DELETE).
507    pub fn execute(mut self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
508        self.pipeline.push_execute(sql, params);
509        self
510    }
511
512    /// Build the pipeline.
513    pub fn build(self) -> Pipeline {
514        self.pipeline
515    }
516}
517
518impl Default for PipelineBuilder {
519    fn default() -> Self {
520        Self::new()
521    }
522}
523
/// Outcomes of running a pipeline, one entry per query.
#[derive(Debug)]
pub struct PipelineResult {
    /// Per-query outcomes.
    pub query_results: Vec<QueryResult>,
}

/// Outcome of a single query within a pipeline.
#[derive(Debug)]
pub enum QueryResult {
    /// The query produced result rows.
    Rows {
        /// How many rows came back.
        count: usize,
    },
    /// The statement ran without producing rows.
    Executed {
        /// How many rows the statement touched.
        rows_affected: u64,
    },
    /// The query failed.
    Error {
        /// Failure description.
        message: String,
    },
}

impl PipelineResult {
    /// Wrap the per-query outcomes of a pipeline run.
    pub fn new(query_results: Vec<QueryResult>) -> Self {
        PipelineResult { query_results }
    }

    /// `true` when no query reported [`QueryResult::Error`] (vacuously `true` when empty).
    pub fn all_succeeded(&self) -> bool {
        self.first_error().is_none()
    }

    /// Message of the earliest failed query, or `None` if every query succeeded.
    pub fn first_error(&self) -> Option<&str> {
        for outcome in &self.query_results {
            if let QueryResult::Error { message } = outcome {
                return Some(message.as_str());
            }
        }
        None
    }
}
575
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a one-column row `{ "name": name }`.
    fn name_row(name: &'static str) -> HashMap<String, FilterValue> {
        let mut row = HashMap::new();
        row.insert("name".to_string(), FilterValue::String(name.into()));
        row
    }

    #[test]
    fn test_batch_builder() {
        let batch = BatchBuilder::new()
            .insert("users", name_row("Alice"))
            .insert("users", name_row("Bob"))
            .build();

        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }

    #[test]
    fn test_combine_inserts_postgres() {
        let mut data1 = HashMap::new();
        data1.insert("name".to_string(), FilterValue::String("Alice".into()));
        data1.insert("age".to_string(), FilterValue::Int(30));

        let mut data2 = HashMap::new();
        data2.insert("name".to_string(), FilterValue::String("Bob".into()));
        data2.insert("age".to_string(), FilterValue::Int(25));

        let batch = BatchBuilder::new()
            .insert("users", data1)
            .insert("users", data2)
            .build();

        let result = batch.to_combined_sql(DatabaseType::PostgreSQL);
        assert!(result.is_some());

        let (sql, params) = result.unwrap();
        assert!(sql.starts_with("INSERT INTO users"));
        assert!(sql.contains("VALUES"));
        // Two rows x two columns, each value bound exactly once.
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_combined_sql_empty_batch() {
        assert!(Batch::new()
            .to_combined_sql(DatabaseType::PostgreSQL)
            .is_none());
    }

    #[test]
    fn test_combined_sql_rejects_non_insert_operations() {
        let batch = BatchBuilder::new()
            .insert("users", name_row("Alice"))
            .delete("users", name_row("Bob"))
            .build();

        assert!(batch.to_combined_sql(DatabaseType::PostgreSQL).is_none());
    }

    #[test]
    fn test_combined_sql_rejects_multiple_tables() {
        let batch = BatchBuilder::new()
            .insert("users", name_row("Alice"))
            .insert("admins", name_row("Bob"))
            .build();

        assert!(batch.to_combined_sql(DatabaseType::PostgreSQL).is_none());
    }

    #[test]
    fn test_batch_result() {
        let results = vec![
            OperationResult::success(1),
            OperationResult::success(1),
            OperationResult::success(1),
        ];

        let batch_result = BatchResult::new(results);
        assert_eq!(batch_result.total_affected, 3);
        assert!(batch_result.all_succeeded());
    }

    #[test]
    fn test_batch_result_with_failure() {
        let results = vec![
            OperationResult::success(1),
            OperationResult::failure("constraint violation"),
            OperationResult::success(1),
        ];

        let batch_result = BatchResult::new(results);
        assert_eq!(batch_result.total_affected, 2);
        assert!(!batch_result.all_succeeded());
    }

    #[test]
    fn test_pipeline_builder_marks_expected_rows() {
        let pipeline = PipelineBuilder::new()
            .query("SELECT * FROM users WHERE id = $1", vec![FilterValue::Int(1)])
            .execute("DELETE FROM sessions", Vec::new())
            .build();

        assert_eq!(pipeline.len(), 2);
        assert!(pipeline.queries()[0].expect_rows);
        assert!(!pipeline.queries()[1].expect_rows);
    }

    #[test]
    fn test_pipeline_result_first_error() {
        let result = PipelineResult::new(vec![
            QueryResult::Rows { count: 1 },
            QueryResult::Error {
                message: "syntax error".to_string(),
            },
            QueryResult::Error {
                message: "later".to_string(),
            },
        ]);

        assert!(!result.all_succeeded());
        assert_eq!(result.first_error(), Some("syntax error"));
    }
}