1use std::collections::HashMap;
21use crate::filter::FilterValue;
22use crate::sql::{DatabaseType, FastSqlBuilder, QueryCapacity};
23
24#[derive(Debug, Clone)]
26pub struct Batch {
27 operations: Vec<BatchOperation>,
29}
30
31impl Batch {
32 pub fn new() -> Self {
34 Self {
35 operations: Vec::new(),
36 }
37 }
38
39 pub fn with_capacity(capacity: usize) -> Self {
41 Self {
42 operations: Vec::with_capacity(capacity),
43 }
44 }
45
46 pub fn add(&mut self, op: BatchOperation) {
48 self.operations.push(op);
49 }
50
51 pub fn operations(&self) -> &[BatchOperation] {
53 &self.operations
54 }
55
56 pub fn len(&self) -> usize {
58 self.operations.len()
59 }
60
61 pub fn is_empty(&self) -> bool {
63 self.operations.is_empty()
64 }
65
66 pub fn to_combined_sql(&self, db_type: DatabaseType) -> Option<(String, Vec<FilterValue>)> {
70 if self.operations.is_empty() {
71 return None;
72 }
73
74 let mut inserts: HashMap<&str, Vec<&BatchOperation>> = HashMap::new();
76 let mut other_ops = Vec::new();
77
78 for op in &self.operations {
79 match op {
80 BatchOperation::Insert { table, .. } => {
81 inserts.entry(table.as_str()).or_default().push(op);
82 }
83 _ => other_ops.push(op),
84 }
85 }
86
87 if !other_ops.is_empty() || inserts.len() > 1 {
89 return None;
90 }
91
92 if let Some((table, ops)) = inserts.into_iter().next() {
94 return self.combine_inserts(table, &ops, db_type);
95 }
96
97 None
98 }
99
100 fn combine_inserts(
102 &self,
103 table: &str,
104 ops: &[&BatchOperation],
105 db_type: DatabaseType,
106 ) -> Option<(String, Vec<FilterValue>)> {
107 if ops.is_empty() {
108 return None;
109 }
110
111 let first_columns: Vec<&str> = match &ops[0] {
113 BatchOperation::Insert { data, .. } => data.keys().map(String::as_str).collect(),
114 _ => return None,
115 };
116
117 for op in ops.iter().skip(1) {
119 if let BatchOperation::Insert { data, .. } = op {
120 let cols: Vec<&str> = data.keys().map(String::as_str).collect();
121 if cols.len() != first_columns.len() {
122 return None;
123 }
124 }
125 }
126
127 let cols_per_row = first_columns.len();
129 let total_params = cols_per_row * ops.len();
130
131 let mut builder = FastSqlBuilder::with_capacity(
132 db_type,
133 QueryCapacity::Custom(64 + total_params * 8),
134 );
135
136 builder.push_str("INSERT INTO ");
137 builder.push_str(table);
138 builder.push_str(" (");
139
140 for (i, col) in first_columns.iter().enumerate() {
141 if i > 0 {
142 builder.push_str(", ");
143 }
144 builder.push_str(col);
145 }
146
147 builder.push_str(") VALUES ");
148
149 let mut all_params = Vec::with_capacity(total_params);
150
151 for (row_idx, op) in ops.iter().enumerate() {
152 if row_idx > 0 {
153 builder.push_str(", ");
154 }
155 builder.push_char('(');
156
157 if let BatchOperation::Insert { data, .. } = op {
158 for (col_idx, col) in first_columns.iter().enumerate() {
159 if col_idx > 0 {
160 builder.push_str(", ");
161 }
162 builder.bind(data.get(*col).cloned().unwrap_or(FilterValue::Null));
163 if let Some(val) = data.get(*col) {
164 all_params.push(val.clone());
165 } else {
166 all_params.push(FilterValue::Null);
167 }
168 }
169 }
170
171 builder.push_char(')');
172 }
173
174 Some(builder.build())
175 }
176}
177
178impl Default for Batch {
179 fn default() -> Self {
180 Self::new()
181 }
182}
183
184#[derive(Debug, Clone)]
186pub enum BatchOperation {
187 Insert {
189 table: String,
191 data: HashMap<String, FilterValue>,
193 },
194 Update {
196 table: String,
198 filter: HashMap<String, FilterValue>,
200 data: HashMap<String, FilterValue>,
202 },
203 Delete {
205 table: String,
207 filter: HashMap<String, FilterValue>,
209 },
210 Raw {
212 sql: String,
214 params: Vec<FilterValue>,
216 },
217}
218
219impl BatchOperation {
220 pub fn insert(table: impl Into<String>, data: HashMap<String, FilterValue>) -> Self {
222 Self::Insert {
223 table: table.into(),
224 data,
225 }
226 }
227
228 pub fn update(
230 table: impl Into<String>,
231 filter: HashMap<String, FilterValue>,
232 data: HashMap<String, FilterValue>,
233 ) -> Self {
234 Self::Update {
235 table: table.into(),
236 filter,
237 data,
238 }
239 }
240
241 pub fn delete(table: impl Into<String>, filter: HashMap<String, FilterValue>) -> Self {
243 Self::Delete {
244 table: table.into(),
245 filter,
246 }
247 }
248
249 pub fn raw(sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
251 Self::Raw {
252 sql: sql.into(),
253 params,
254 }
255 }
256}
257
258#[derive(Debug, Default)]
260pub struct BatchBuilder {
261 batch: Batch,
262}
263
264impl BatchBuilder {
265 pub fn new() -> Self {
267 Self {
268 batch: Batch::new(),
269 }
270 }
271
272 pub fn with_capacity(capacity: usize) -> Self {
274 Self {
275 batch: Batch::with_capacity(capacity),
276 }
277 }
278
279 pub fn insert(mut self, table: impl Into<String>, data: HashMap<String, FilterValue>) -> Self {
281 self.batch.add(BatchOperation::insert(table, data));
282 self
283 }
284
285 pub fn update(
287 mut self,
288 table: impl Into<String>,
289 filter: HashMap<String, FilterValue>,
290 data: HashMap<String, FilterValue>,
291 ) -> Self {
292 self.batch.add(BatchOperation::update(table, filter, data));
293 self
294 }
295
296 pub fn delete(mut self, table: impl Into<String>, filter: HashMap<String, FilterValue>) -> Self {
298 self.batch.add(BatchOperation::delete(table, filter));
299 self
300 }
301
302 pub fn raw(mut self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
304 self.batch.add(BatchOperation::raw(sql, params));
305 self
306 }
307
308 pub fn build(self) -> Batch {
310 self.batch
311 }
312}
313
/// Aggregate outcome of executing a batch.
#[derive(Debug, Clone)]
pub struct BatchResult {
    /// Per-operation outcomes.
    pub results: Vec<OperationResult>,
    /// Sum of `rows_affected` over `results`.
    ///
    /// Computed once by `new`; it is not recomputed if `results` is mutated later.
    pub total_affected: u64,
}

impl BatchResult {
    /// Wraps `results` and precomputes `total_affected`.
    pub fn new(results: Vec<OperationResult>) -> Self {
        let total_affected = results
            .iter()
            .fold(0u64, |running, outcome| running + outcome.rows_affected);
        Self {
            total_affected,
            results,
        }
    }

    /// Number of per-operation results.
    pub fn len(&self) -> usize {
        self.results.len()
    }

    /// `true` when there are no per-operation results.
    pub fn is_empty(&self) -> bool {
        self.results.is_empty()
    }

    /// `true` when no result is a failure (vacuously `true` for an empty result set).
    pub fn all_succeeded(&self) -> bool {
        !self.results.iter().any(|outcome| !outcome.success)
    }
}

/// Outcome of a single operation.
#[derive(Debug, Clone)]
pub struct OperationResult {
    /// Whether the operation succeeded.
    pub success: bool,
    /// Rows affected; `failure` always sets this to 0.
    pub rows_affected: u64,
    /// Error message; set by `failure`, `None` from `success`.
    pub error: Option<String>,
}

impl OperationResult {
    /// A successful outcome that affected `rows_affected` rows.
    pub fn success(rows_affected: u64) -> Self {
        Self {
            error: None,
            rows_affected,
            success: true,
        }
    }

    /// A failed outcome carrying `error`, with zero affected rows.
    pub fn failure(error: impl Into<String>) -> Self {
        let message: String = error.into();
        Self {
            error: Some(message),
            rows_affected: 0,
            success: false,
        }
    }
}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a row map from `(column, value)` pairs.
    fn row(pairs: &[(&str, FilterValue)]) -> HashMap<String, FilterValue> {
        pairs
            .iter()
            .map(|(k, v)| ((*k).to_string(), v.clone()))
            .collect()
    }

    #[test]
    fn test_batch_builder() {
        let batch = BatchBuilder::new()
            .insert("users", row(&[("name", FilterValue::String("Alice".into()))]))
            .insert("users", row(&[("name", FilterValue::String("Bob".into()))]))
            .build();

        assert_eq!(batch.len(), 2);
        assert!(batch
            .operations()
            .iter()
            .all(|op| matches!(op, BatchOperation::Insert { table, .. } if table == "users")));
    }

    #[test]
    fn test_builder_preserves_operation_order() {
        let batch = BatchBuilder::with_capacity(4)
            .insert("users", HashMap::new())
            .update("users", HashMap::new(), HashMap::new())
            .delete("users", HashMap::new())
            .raw("SELECT 1", Vec::new())
            .build();

        let ops = batch.operations();
        assert_eq!(ops.len(), 4);
        assert!(matches!(ops[0], BatchOperation::Insert { .. }));
        assert!(matches!(ops[1], BatchOperation::Update { .. }));
        assert!(matches!(ops[2], BatchOperation::Delete { .. }));
        assert!(matches!(ops[3], BatchOperation::Raw { .. }));
    }

    #[test]
    fn test_default_batch_is_empty() {
        let batch = Batch::default();
        assert!(batch.is_empty());
        assert_eq!(batch.len(), 0);
    }

    #[test]
    fn test_combine_inserts_postgres() {
        let data1 = row(&[
            ("name", FilterValue::String("Alice".into())),
            ("age", FilterValue::Int(30)),
        ]);
        let data2 = row(&[
            ("name", FilterValue::String("Bob".into())),
            ("age", FilterValue::Int(25)),
        ]);

        let batch = BatchBuilder::new()
            .insert("users", data1)
            .insert("users", data2)
            .build();

        let result = batch.to_combined_sql(DatabaseType::PostgreSQL);
        assert!(result.is_some());

        let (sql, _) = result.unwrap();
        assert!(sql.starts_with("INSERT INTO users"));
        assert!(sql.contains("VALUES"));
        assert!(sql.contains("name"));
        assert!(sql.contains("age"));
    }

    #[test]
    fn test_empty_batch_is_not_combined() {
        assert!(Batch::new().to_combined_sql(DatabaseType::PostgreSQL).is_none());
    }

    #[test]
    fn test_mixed_operations_are_not_combined() {
        let batch = BatchBuilder::new()
            .insert("users", row(&[("name", FilterValue::String("Alice".into()))]))
            .delete("users", row(&[("name", FilterValue::String("Bob".into()))]))
            .build();

        assert!(batch.to_combined_sql(DatabaseType::PostgreSQL).is_none());
    }

    #[test]
    fn test_inserts_into_different_tables_are_not_combined() {
        let batch = BatchBuilder::new()
            .insert("users", row(&[("name", FilterValue::String("Alice".into()))]))
            .insert("posts", row(&[("name", FilterValue::String("Hello".into()))]))
            .build();

        assert!(batch.to_combined_sql(DatabaseType::PostgreSQL).is_none());
    }

    #[test]
    fn test_batch_result() {
        let results = vec![
            OperationResult::success(1),
            OperationResult::success(1),
            OperationResult::success(1),
        ];

        let batch_result = BatchResult::new(results);
        assert_eq!(batch_result.total_affected, 3);
        assert!(batch_result.all_succeeded());
    }

    #[test]
    fn test_batch_result_with_failure() {
        let results = vec![
            OperationResult::success(1),
            OperationResult::failure("constraint violation"),
            OperationResult::success(1),
        ];

        let batch_result = BatchResult::new(results);
        assert_eq!(batch_result.total_affected, 2);
        assert!(!batch_result.all_succeeded());
    }
}
449