1use std::collections::VecDeque;
41use std::time::{Duration, Instant};
42
43use crate::filter::FilterValue;
44use crate::sql::DatabaseType;
45
/// Tuning knobs for executing a batch of queries as a pipeline.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Maximum number of queries placed in a single batch.
    pub max_batch_size: usize,
    /// Upper bound on total pipeline execution time.
    pub execution_timeout: Duration,
    /// Wrap the whole pipeline in a transaction when true.
    pub use_transaction: bool,
    /// Roll back the transaction when any query fails.
    pub rollback_on_error: bool,
    /// Maximum pipeline depth.
    /// NOTE(review): set by the presets but never read in this module —
    /// confirm its consumer before relying on it.
    pub max_depth: usize,
    /// Gather `PipelineStats` during execution when true.
    pub collect_stats: bool,
}
62
63impl Default for PipelineConfig {
64 fn default() -> Self {
65 Self {
66 max_batch_size: 1000,
67 execution_timeout: Duration::from_secs(60),
68 use_transaction: false,
69 rollback_on_error: true,
70 max_depth: 100,
71 collect_stats: true,
72 }
73 }
74}
75
76impl PipelineConfig {
77 #[must_use]
79 pub fn for_bulk_inserts() -> Self {
80 Self {
81 max_batch_size: 5000,
82 execution_timeout: Duration::from_secs(300),
83 use_transaction: true,
84 rollback_on_error: true,
85 max_depth: 500,
86 collect_stats: true,
87 }
88 }
89
90 #[must_use]
92 pub fn for_bulk_updates() -> Self {
93 Self {
94 max_batch_size: 1000,
95 execution_timeout: Duration::from_secs(180),
96 use_transaction: true,
97 rollback_on_error: true,
98 max_depth: 200,
99 collect_stats: true,
100 }
101 }
102
103 #[must_use]
105 pub fn for_mixed_operations() -> Self {
106 Self {
107 max_batch_size: 500,
108 execution_timeout: Duration::from_secs(120),
109 use_transaction: true,
110 rollback_on_error: true,
111 max_depth: 100,
112 collect_stats: true,
113 }
114 }
115
116 #[must_use]
118 pub fn with_max_batch_size(mut self, size: usize) -> Self {
119 self.max_batch_size = size.max(1);
120 self
121 }
122
123 #[must_use]
125 pub fn with_timeout(mut self, timeout: Duration) -> Self {
126 self.execution_timeout = timeout;
127 self
128 }
129
130 #[must_use]
132 pub fn with_transaction(mut self, use_tx: bool) -> Self {
133 self.use_transaction = use_tx;
134 self
135 }
136}
137
/// Error describing the failure of one query during pipeline execution.
#[derive(Debug, Clone)]
pub struct PipelineError {
    /// Position of the failing query in the pipeline's submission order.
    pub query_index: usize,
    /// Human-readable failure description.
    pub message: String,
    /// True when the failure was a timeout rather than a query error.
    pub is_timeout: bool,
    /// SQL text of the failing query, when available.
    pub sql: Option<String>,
}
150
151impl std::fmt::Display for PipelineError {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 write!(f, "Pipeline query {} failed: {}", self.query_index, self.message)
154 }
155}
156
157impl std::error::Error for PipelineError {}
158
/// Outcome of a single query within a pipeline run.
#[derive(Debug, Clone)]
pub enum QueryResult {
    /// A row-returning query completed; only the row count is retained.
    Rows {
        count: usize,
    },
    /// A non-row statement completed, affecting this many rows.
    Executed {
        rows_affected: u64,
    },
    /// The query failed with the given message.
    Error {
        message: String,
    },
}
178
179impl QueryResult {
180 pub fn is_success(&self) -> bool {
182 !matches!(self, Self::Error { .. })
183 }
184
185 pub fn rows_affected(&self) -> Option<u64> {
187 match self {
188 Self::Executed { rows_affected } => Some(*rows_affected),
189 _ => None,
190 }
191 }
192}
193
/// Aggregate outcome of executing a whole pipeline.
#[derive(Debug)]
pub struct PipelineResult {
    /// Per-query outcomes, in submission order.
    pub results: Vec<QueryResult>,
    /// Sum of rows affected across all executed statements.
    pub total_affected: u64,
    /// Total rows returned by row-producing queries.
    pub total_returned: u64,
    /// Execution statistics for this run.
    pub stats: PipelineStats,
}
206
207impl PipelineResult {
208 pub fn all_succeeded(&self) -> bool {
210 self.results.iter().all(|r| r.is_success())
211 }
212
213 pub fn first_error(&self) -> Option<&str> {
215 self.results.iter().find_map(|r| {
216 if let QueryResult::Error { message } = r {
217 Some(message.as_str())
218 } else {
219 None
220 }
221 })
222 }
223
224 pub fn success_count(&self) -> usize {
226 self.results.iter().filter(|r| r.is_success()).count()
227 }
228
229 pub fn error_count(&self) -> usize {
231 self.results.iter().filter(|r| !r.is_success()).count()
232 }
233}
234
/// Metrics gathered while a pipeline executes.
#[derive(Debug, Clone, Default)]
pub struct PipelineStats {
    /// Total number of queries submitted.
    pub total_queries: usize,
    /// Queries that completed successfully.
    pub successful: usize,
    /// Queries that failed.
    pub failed: usize,
    /// Wall-clock time for the whole run.
    pub total_duration: Duration,
    /// Time spent waiting on the backend.
    /// NOTE(review): the simulated executor reports this equal to
    /// `total_duration` — confirm intended semantics for real executors.
    pub wait_time: Duration,
    /// Number of batches the queries were split into.
    pub batches_used: usize,
    /// Average number of queries per batch.
    pub avg_batch_size: f64,
}
253
/// A single parameterized SQL statement queued in a pipeline.
#[derive(Debug, Clone)]
pub struct PipelineQuery {
    /// SQL text, with placeholders in the target database's syntax.
    pub sql: String,
    /// Positional parameter values bound to the placeholders.
    pub params: Vec<FilterValue>,
    /// True when the statement is expected to return rows (e.g. SELECT).
    pub expects_rows: bool,
    /// Optional caller-supplied label; not consumed anywhere in this module.
    pub id: Option<String>,
}
266
267impl PipelineQuery {
268 pub fn new(sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
270 Self {
271 sql: sql.into(),
272 params,
273 expects_rows: true,
274 id: None,
275 }
276 }
277
278 pub fn execute(sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
280 Self {
281 sql: sql.into(),
282 params,
283 expects_rows: false,
284 id: None,
285 }
286 }
287
288 #[must_use]
290 pub fn with_id(mut self, id: impl Into<String>) -> Self {
291 self.id = Some(id.into());
292 self
293 }
294}
295
/// An ordered queue of parameterized queries plus the configuration and
/// database dialect used to render them as batch or transaction SQL.
#[derive(Debug)]
pub struct QueryPipeline {
    /// Execution tuning (batch size, timeout, transaction behavior).
    config: PipelineConfig,
    /// Queued queries in submission order.
    queries: VecDeque<PipelineQuery>,
    /// Target dialect; drives placeholder renumbering and BEGIN/COMMIT syntax.
    db_type: DatabaseType,
}
303
impl QueryPipeline {
    /// Create an empty pipeline with the given configuration.
    /// The dialect defaults to PostgreSQL; override with `for_database`.
    pub fn new(config: PipelineConfig) -> Self {
        Self {
            config,
            queries: VecDeque::new(),
            db_type: DatabaseType::PostgreSQL,
        }
    }

    /// Select the SQL dialect used when rendering batch/transaction SQL.
    #[must_use]
    pub fn for_database(mut self, db_type: DatabaseType) -> Self {
        self.db_type = db_type;
        self
    }

    /// Queue a row-returning query (builder style).
    #[must_use]
    pub fn add_query(mut self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
        self.queries.push_back(PipelineQuery::new(sql, params));
        self
    }

    /// Queue a non-row statement (builder style).
    #[must_use]
    pub fn add_execute(mut self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
        self.queries.push_back(PipelineQuery::execute(sql, params));
        self
    }

    /// Alias for `add_execute`, for readable call sites.
    #[must_use]
    pub fn add_insert(self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
        self.add_execute(sql, params)
    }

    /// Alias for `add_execute`.
    #[must_use]
    pub fn add_update(self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
        self.add_execute(sql, params)
    }

    /// Alias for `add_execute`.
    #[must_use]
    pub fn add_delete(self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
        self.add_execute(sql, params)
    }

    /// Alias for `add_query`.
    #[must_use]
    pub fn add_select(self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
        self.add_query(sql, params)
    }

    /// Queue a prebuilt query (non-builder style).
    pub fn push(&mut self, query: PipelineQuery) {
        self.queries.push_back(query);
    }

    /// Number of queued queries.
    pub fn len(&self) -> usize {
        self.queries.len()
    }

    /// True when nothing is queued.
    pub fn is_empty(&self) -> bool {
        self.queries.is_empty()
    }

    /// Read-only view of the queued queries, in submission order.
    pub fn queries(&self) -> &VecDeque<PipelineQuery> {
        &self.queries
    }

    /// Render all queries as one multi-statement SQL string plus a single
    /// combined parameter list.
    ///
    /// Returns `None` when the pipeline is empty or the dialect is neither
    /// PostgreSQL nor MySQL. For PostgreSQL, each statement's `$N`
    /// placeholders are shifted by the count of parameters emitted so far so
    /// the combined list binds positionally; MySQL `?` placeholders are
    /// already positional, so plain concatenation keeps ordering correct.
    ///
    /// NOTE(review): a PostgreSQL statement containing `$N` placeholders but
    /// an empty `params` vec is not renumbered — presumably such queries do
    /// not occur; confirm with callers.
    pub fn to_batch_sql(&self) -> Option<(String, Vec<FilterValue>)> {
        if self.queries.is_empty() {
            return None;
        }

        // Only these dialects are supported for combined-batch rendering.
        match self.db_type {
            DatabaseType::PostgreSQL | DatabaseType::MySQL => {}
            _ => return None,
        }

        let mut combined = String::new();
        let mut all_params = Vec::new();
        let mut param_offset = 0;

        for query in &self.queries {
            // Statement separator; skipped before the first statement.
            if !combined.is_empty() {
                combined.push_str(";\n");
            }

            if self.db_type == DatabaseType::PostgreSQL && !query.params.is_empty() {
                // Shift $N placeholders past the params already emitted.
                let renumbered = renumber_params(&query.sql, param_offset);
                combined.push_str(&renumbered);
                param_offset += query.params.len();
            } else {
                combined.push_str(&query.sql);
            }

            all_params.extend(query.params.clone());
        }

        Some((combined, all_params))
    }

    /// Split the queued queries into chunks of at most
    /// `config.max_batch_size`, consuming the pipeline.
    ///
    /// NOTE(review): `chunks` panics on a chunk size of 0, so a config with
    /// `max_batch_size == 0` (possible via direct struct construction)
    /// panics here — the builder clamps to 1, direct construction does not.
    pub fn into_batches(self) -> Vec<Vec<PipelineQuery>> {
        let batch_size = self.config.max_batch_size;
        let queries: Vec<_> = self.queries.into_iter().collect();

        queries.chunks(batch_size).map(|c| c.to_vec()).collect()
    }

    /// Render the queries as a BEGIN … COMMIT sequence, one `(sql, params)`
    /// pair per statement. Placeholders are kept as written (no renumbering)
    /// because each statement is executed separately.
    pub fn to_transaction_sql(&self) -> Vec<(String, Vec<FilterValue>)> {
        let mut statements = Vec::new();

        statements.push((self.begin_transaction_sql().to_string(), Vec::new()));

        for query in &self.queries {
            statements.push((query.sql.clone(), query.params.clone()));
        }

        statements.push((self.commit_sql().to_string(), Vec::new()));

        statements
    }

    /// Dialect-specific transaction opener.
    fn begin_transaction_sql(&self) -> &'static str {
        match self.db_type {
            DatabaseType::PostgreSQL => "BEGIN",
            DatabaseType::MySQL => "START TRANSACTION",
            DatabaseType::SQLite => "BEGIN TRANSACTION",
            DatabaseType::MSSQL => "BEGIN TRANSACTION",
        }
    }

    /// COMMIT is uniform across the supported dialects.
    fn commit_sql(&self) -> &'static str {
        "COMMIT"
    }

    /// Counterpart to `commit_sql`; kept for executors that roll back.
    #[allow(dead_code)]
    fn rollback_sql(&self) -> &'static str {
        "ROLLBACK"
    }
}
464
/// Shift every `$N` placeholder in `sql` by `offset` (e.g. `$1` → `$6` for
/// offset 5). A `$` not followed by digits is copied through unchanged.
///
/// Note: this is a plain character scan — `$N` sequences inside string
/// literals would also be shifted.
fn renumber_params(sql: &str, offset: usize) -> String {
    let mut out = String::with_capacity(sql.len() + 10);
    let mut iter = sql.chars().peekable();

    while let Some(ch) = iter.next() {
        if ch != '$' {
            out.push(ch);
            continue;
        }

        // Collect the digit run directly following '$'.
        let mut digits = String::new();
        while matches!(iter.peek(), Some(d) if d.is_ascii_digit()) {
            digits.push(iter.next().unwrap());
        }

        out.push('$');
        match digits.parse::<usize>() {
            Ok(n) => out.push_str(&(n + offset).to_string()),
            Err(_) => out.push_str(&digits), // bare '$' — leave as-is
        }
    }

    out
}
497
/// Builder that accumulates rows and emits multi-row INSERT statements.
#[derive(Debug)]
pub struct BulkInsertPipeline {
    /// Target table name — interpolated verbatim into SQL, not escaped.
    table: String,
    /// Column names; every row must match this length.
    columns: Vec<String>,
    /// Accumulated row values, in insertion order.
    rows: Vec<Vec<FilterValue>>,
    /// Target dialect for parameter placeholder syntax.
    db_type: DatabaseType,
    /// Maximum rows per generated INSERT statement.
    batch_size: usize,
}
507
508impl BulkInsertPipeline {
509 pub fn new(table: impl Into<String>, columns: Vec<String>) -> Self {
511 Self {
512 table: table.into(),
513 columns,
514 rows: Vec::new(),
515 db_type: DatabaseType::PostgreSQL,
516 batch_size: 1000,
517 }
518 }
519
520 #[must_use]
522 pub fn for_database(mut self, db_type: DatabaseType) -> Self {
523 self.db_type = db_type;
524 self
525 }
526
527 #[must_use]
529 pub fn with_batch_size(mut self, size: usize) -> Self {
530 self.batch_size = size.max(1);
531 self
532 }
533
534 pub fn add_row(&mut self, values: Vec<FilterValue>) {
536 assert_eq!(
537 values.len(),
538 self.columns.len(),
539 "Row has {} values, expected {}",
540 values.len(),
541 self.columns.len()
542 );
543 self.rows.push(values);
544 }
545
546 pub fn add_rows(&mut self, rows: impl IntoIterator<Item = Vec<FilterValue>>) {
548 for row in rows {
549 self.add_row(row);
550 }
551 }
552
553 pub fn len(&self) -> usize {
555 self.rows.len()
556 }
557
558 pub fn is_empty(&self) -> bool {
560 self.rows.is_empty()
561 }
562
563 pub fn to_insert_statements(&self) -> Vec<(String, Vec<FilterValue>)> {
565 if self.rows.is_empty() {
566 return Vec::new();
567 }
568
569 let mut statements = Vec::new();
570
571 for chunk in self.rows.chunks(self.batch_size) {
572 let (sql, params) = self.build_multi_insert(chunk);
573 statements.push((sql, params));
574 }
575
576 statements
577 }
578
579 fn build_multi_insert(&self, rows: &[Vec<FilterValue>]) -> (String, Vec<FilterValue>) {
580 let cols_str = self.columns.join(", ");
581 let mut sql = format!("INSERT INTO {} ({}) VALUES ", self.table, cols_str);
582 let mut params = Vec::with_capacity(rows.len() * self.columns.len());
583 let mut param_idx = 1;
584
585 for (row_idx, row) in rows.iter().enumerate() {
586 if row_idx > 0 {
587 sql.push_str(", ");
588 }
589 sql.push('(');
590
591 for (col_idx, value) in row.iter().enumerate() {
592 if col_idx > 0 {
593 sql.push_str(", ");
594 }
595
596 match self.db_type {
597 DatabaseType::PostgreSQL => {
598 sql.push_str(&format!("${}", param_idx));
599 }
600 DatabaseType::MySQL | DatabaseType::SQLite => {
601 sql.push('?');
602 }
603 DatabaseType::MSSQL => {
604 sql.push_str(&format!("@p{}", param_idx));
605 }
606 }
607
608 params.push(value.clone());
609 param_idx += 1;
610 }
611
612 sql.push(')');
613 }
614
615 (sql, params)
616 }
617
618 pub fn to_pipeline(self) -> QueryPipeline {
620 let statements = self.to_insert_statements();
621 let mut pipeline = QueryPipeline::new(PipelineConfig::for_bulk_inserts())
622 .for_database(self.db_type);
623
624 for (sql, params) in statements {
625 pipeline = pipeline.add_insert(sql, params);
626 }
627
628 pipeline
629 }
630}
631
/// Builder that accumulates per-row UPDATE statements for a single table.
#[derive(Debug)]
pub struct BulkUpdatePipeline {
    /// Target table name — interpolated verbatim into SQL, not escaped.
    table: String,
    /// One entry per UPDATE statement to generate.
    updates: Vec<BulkUpdate>,
    /// Target dialect for parameter placeholder syntax.
    db_type: DatabaseType,
}

/// One UPDATE: SET assignments plus equality-only WHERE conditions.
#[derive(Debug, Clone)]
struct BulkUpdate {
    /// `(column, value)` pairs for the SET clause.
    set: Vec<(String, FilterValue)>,
    /// `(column, value)` pairs ANDed together in the WHERE clause.
    where_clause: Vec<(String, FilterValue)>,
}
645
646impl BulkUpdatePipeline {
647 pub fn new(table: impl Into<String>) -> Self {
649 Self {
650 table: table.into(),
651 updates: Vec::new(),
652 db_type: DatabaseType::PostgreSQL,
653 }
654 }
655
656 #[must_use]
658 pub fn for_database(mut self, db_type: DatabaseType) -> Self {
659 self.db_type = db_type;
660 self
661 }
662
663 pub fn add_update(
665 &mut self,
666 set: Vec<(String, FilterValue)>,
667 where_clause: Vec<(String, FilterValue)>,
668 ) {
669 self.updates.push(BulkUpdate { set, where_clause });
670 }
671
672 pub fn len(&self) -> usize {
674 self.updates.len()
675 }
676
677 pub fn is_empty(&self) -> bool {
679 self.updates.is_empty()
680 }
681
682 pub fn to_update_statements(&self) -> Vec<(String, Vec<FilterValue>)> {
684 self.updates
685 .iter()
686 .map(|update| self.build_update(update))
687 .collect()
688 }
689
690 fn build_update(&self, update: &BulkUpdate) -> (String, Vec<FilterValue>) {
691 let mut sql = format!("UPDATE {} SET ", self.table);
692 let mut params = Vec::new();
693 let mut param_idx = 1;
694
695 for (idx, (col, val)) in update.set.iter().enumerate() {
697 if idx > 0 {
698 sql.push_str(", ");
699 }
700
701 match self.db_type {
702 DatabaseType::PostgreSQL => {
703 sql.push_str(&format!("{} = ${}", col, param_idx));
704 }
705 DatabaseType::MySQL | DatabaseType::SQLite => {
706 sql.push_str(&format!("{} = ?", col));
707 }
708 DatabaseType::MSSQL => {
709 sql.push_str(&format!("{} = @p{}", col, param_idx));
710 }
711 }
712
713 params.push(val.clone());
714 param_idx += 1;
715 }
716
717 if !update.where_clause.is_empty() {
719 sql.push_str(" WHERE ");
720
721 for (idx, (col, val)) in update.where_clause.iter().enumerate() {
722 if idx > 0 {
723 sql.push_str(" AND ");
724 }
725
726 match self.db_type {
727 DatabaseType::PostgreSQL => {
728 sql.push_str(&format!("{} = ${}", col, param_idx));
729 }
730 DatabaseType::MySQL | DatabaseType::SQLite => {
731 sql.push_str(&format!("{} = ?", col));
732 }
733 DatabaseType::MSSQL => {
734 sql.push_str(&format!("{} = @p{}", col, param_idx));
735 }
736 }
737
738 params.push(val.clone());
739 param_idx += 1;
740 }
741 }
742
743 (sql, params)
744 }
745
746 pub fn to_pipeline(self) -> QueryPipeline {
748 let statements = self.to_update_statements();
749 let mut pipeline = QueryPipeline::new(PipelineConfig::for_bulk_updates())
750 .for_database(self.db_type);
751
752 for (sql, params) in statements {
753 pipeline = pipeline.add_update(sql, params);
754 }
755
756 pipeline
757 }
758}
759
/// Abstraction over backends able to run a whole [`QueryPipeline`].
// `async fn` in traits leaves the returned future's auto-traits implicit;
// acceptable for in-crate use, hence the allow.
#[allow(async_fn_in_trait)]
pub trait PipelineExecutor {
    /// Execute every query in `pipeline`, returning aggregate results or a
    /// pipeline-level error.
    async fn execute_pipeline(&self, pipeline: &QueryPipeline) -> Result<PipelineResult, PipelineError>;
}

/// Test/bench executor that fakes latency and failures without a database.
/// NOTE(review): does not implement [`PipelineExecutor`]; its inherent
/// `execute` returns `PipelineResult` directly — confirm that is intended.
pub struct SimulatedExecutor {
    /// Latency basis; each simulated query sleeps `latency / 10`.
    latency: Duration,
    /// Probability in `[0, 1]` that a query is reported as failed.
    error_rate: f64,
}
774
775impl SimulatedExecutor {
776 pub fn new(latency: Duration, error_rate: f64) -> Self {
778 Self { latency, error_rate }
779 }
780
781 pub async fn execute(&self, pipeline: &QueryPipeline) -> PipelineResult {
783 let start = Instant::now();
784 let mut results = Vec::new();
785 let mut total_affected = 0u64;
786 let mut successful = 0;
787 let mut failed = 0;
788
789 for _query in pipeline.queries() {
791 tokio::time::sleep(self.latency / 10).await;
793
794 if rand_like_error(self.error_rate) {
796 results.push(QueryResult::Error {
797 message: "Simulated error".to_string(),
798 });
799 failed += 1;
800 } else {
801 let affected = 1;
802 total_affected += affected;
803 results.push(QueryResult::Executed {
804 rows_affected: affected,
805 });
806 successful += 1;
807 }
808 }
809
810 let total_duration = start.elapsed();
811 let batches_used = (pipeline.len() + 999) / 1000;
812
813 PipelineResult {
814 results,
815 total_affected,
816 total_returned: 0,
817 stats: PipelineStats {
818 total_queries: pipeline.len(),
819 successful,
820 failed,
821 total_duration,
822 wait_time: total_duration,
823 batches_used,
824 avg_batch_size: pipeline.len() as f64 / batches_used.max(1) as f64,
825 },
826 }
827 }
828}
829
/// Cheap pseudo-random failure source: maps the current sub-second
/// nanosecond count onto `[0, 1)` and compares it against `rate`.
///
/// Not a real RNG — adequate only for simulation. `rate <= 0.0` never
/// fires; `rate >= 1.0` always fires.
fn rand_like_error(rate: f64) -> bool {
    use std::time::SystemTime;
    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap()
        .subsec_nanos();
    // Fix: subsec_nanos() ranges over 0..1_000_000_000, not 0..u32::MAX.
    // Dividing by u32::MAX compressed the ratio into ~[0, 0.233), which
    // badly skewed the effective error probability.
    (f64::from(nanos) / 1_000_000_000.0) < rate
}
839
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pipeline_builder() {
        let pipeline = QueryPipeline::new(PipelineConfig::default())
            .add_insert(
                "INSERT INTO users (name) VALUES ($1)",
                vec![FilterValue::String("Alice".into())],
            )
            .add_insert(
                "INSERT INTO users (name) VALUES ($1)",
                vec![FilterValue::String("Bob".into())],
            );

        assert_eq!(pipeline.len(), 2);
    }

    #[test]
    fn test_bulk_insert_pipeline() {
        let mut bulk = BulkInsertPipeline::new("users", vec!["name".into(), "age".into()])
            .with_batch_size(2);

        bulk.add_rows([
            vec![FilterValue::String("Alice".into()), FilterValue::Int(30)],
            vec![FilterValue::String("Bob".into()), FilterValue::Int(25)],
            vec![FilterValue::String("Charlie".into()), FilterValue::Int(35)],
        ]);

        let statements = bulk.to_insert_statements();
        assert_eq!(statements.len(), 2);

        // First batch: two full rows of two columns each.
        let (sql1, params1) = &statements[0];
        assert!(sql1.contains("VALUES"));
        assert_eq!(params1.len(), 4);

        // Second batch: the single leftover row.
        let (sql2, params2) = &statements[1];
        assert!(sql2.contains("VALUES"));
        assert_eq!(params2.len(), 2);
    }

    #[test]
    fn test_bulk_update_pipeline() {
        let mut bulk = BulkUpdatePipeline::new("users");
        bulk.add_update(
            vec![("name".into(), FilterValue::String("Updated".into()))],
            vec![("id".into(), FilterValue::Int(1))],
        );

        let statements = bulk.to_update_statements();
        assert_eq!(statements.len(), 1);

        let (sql, params) = &statements[0];
        assert!(sql.contains("UPDATE users SET"));
        assert!(sql.contains("WHERE"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_renumber_params() {
        let shifted = renumber_params("SELECT * FROM users WHERE id = $1 AND name = $2", 5);
        assert_eq!(shifted, "SELECT * FROM users WHERE id = $6 AND name = $7");
    }

    #[test]
    fn test_batch_sql() {
        let pipeline = QueryPipeline::new(PipelineConfig::default())
            .for_database(DatabaseType::PostgreSQL)
            .add_query("SELECT 1", vec![])
            .add_query("SELECT 2", vec![]);

        let batch = pipeline.to_batch_sql();
        assert!(batch.is_some());

        let (combined, _params) = batch.unwrap();
        assert!(combined.contains("SELECT 1"));
        assert!(combined.contains("SELECT 2"));
    }

    #[test]
    fn test_transaction_sql() {
        let pipeline = QueryPipeline::new(PipelineConfig::default())
            .for_database(DatabaseType::PostgreSQL)
            .add_insert("INSERT INTO users VALUES ($1)", vec![FilterValue::Int(1)]);

        let statements = pipeline.to_transaction_sql();

        // BEGIN + one statement + COMMIT.
        assert_eq!(statements.len(), 3);
        assert_eq!(statements[0].0, "BEGIN");
        assert!(statements[1].0.contains("INSERT"));
        assert_eq!(statements[2].0, "COMMIT");
    }

    #[tokio::test]
    async fn test_simulated_executor() {
        let executor = SimulatedExecutor::new(Duration::from_millis(1), 0.0);

        let pipeline = QueryPipeline::new(PipelineConfig::default())
            .add_insert("INSERT INTO users VALUES ($1)", vec![FilterValue::Int(1)])
            .add_insert("INSERT INTO users VALUES ($1)", vec![FilterValue::Int(2)]);

        let result = executor.execute(&pipeline).await;

        assert!(result.all_succeeded());
        assert_eq!(result.stats.total_queries, 2);
        assert_eq!(result.total_affected, 2);
    }
}
947