1use std::collections::VecDeque;
41use std::time::{Duration, Instant};
42
43use crate::filter::FilterValue;
44use crate::sql::DatabaseType;
45
/// Tuning knobs controlling how a pipeline is batched and executed.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Maximum number of queries grouped into a single batch.
    pub max_batch_size: usize,
    /// Upper bound on wall-clock time for the whole pipeline run.
    pub execution_timeout: Duration,
    /// Whether execution should be wrapped in a transaction.
    pub use_transaction: bool,
    /// Whether a failing query should roll the transaction back.
    pub rollback_on_error: bool,
    /// Maximum pipeline depth (not enforced by the types in this file).
    pub max_depth: usize,
    /// Whether `PipelineStats` should be gathered during execution.
    pub collect_stats: bool,
}

impl Default for PipelineConfig {
    /// Conservative defaults: 1000-query batches, one-minute timeout,
    /// no transaction wrapping, rollback-on-error and stats enabled.
    fn default() -> Self {
        Self {
            use_transaction: false,
            rollback_on_error: true,
            collect_stats: true,
            max_batch_size: 1000,
            max_depth: 100,
            execution_timeout: Duration::from_secs(60),
        }
    }
}

impl PipelineConfig {
    /// Preset tuned for large INSERT-heavy workloads: big batches,
    /// generous timeout, transactional execution.
    #[must_use]
    pub fn for_bulk_inserts() -> Self {
        Self {
            max_batch_size: 5000,
            execution_timeout: Duration::from_secs(300),
            use_transaction: true,
            max_depth: 500,
            ..Self::default()
        }
    }

    /// Preset tuned for UPDATE-heavy workloads.
    #[must_use]
    pub fn for_bulk_updates() -> Self {
        Self {
            execution_timeout: Duration::from_secs(180),
            use_transaction: true,
            max_depth: 200,
            ..Self::default()
        }
    }

    /// Preset for mixed read/write batches: smaller batches, shorter timeout.
    #[must_use]
    pub fn for_mixed_operations() -> Self {
        Self {
            max_batch_size: 500,
            execution_timeout: Duration::from_secs(120),
            use_transaction: true,
            ..Self::default()
        }
    }

    /// Builder: overrides the batch size, clamped to at least 1.
    #[must_use]
    pub fn with_max_batch_size(mut self, size: usize) -> Self {
        self.max_batch_size = size.max(1);
        self
    }

    /// Builder: overrides the execution timeout.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.execution_timeout = timeout;
        self
    }

    /// Builder: toggles transactional execution.
    #[must_use]
    pub fn with_transaction(mut self, use_tx: bool) -> Self {
        self.use_transaction = use_tx;
        self
    }
}
137
/// Error describing a failed query within a pipeline run.
#[derive(Debug, Clone)]
pub struct PipelineError {
    /// Position of the failing query within the pipeline.
    pub query_index: usize,
    /// Human-readable failure description.
    pub message: String,
    /// True when the failure was caused by hitting the execution timeout.
    pub is_timeout: bool,
    /// The offending SQL text, when available.
    pub sql: Option<String>,
}

impl std::fmt::Display for PipelineError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { query_index, message, .. } = self;
        write!(f, "Pipeline query {} failed: {}", query_index, message)
    }
}

impl std::error::Error for PipelineError {}
162
/// Outcome of a single query within a pipeline.
#[derive(Debug, Clone)]
pub enum QueryResult {
    /// A row-returning query completed; `count` rows came back.
    Rows {
        count: usize,
    },
    /// A write statement completed, affecting `rows_affected` rows.
    Executed {
        rows_affected: u64,
    },
    /// The query failed with the given message.
    Error {
        message: String,
    },
}

impl QueryResult {
    /// True for any non-error outcome.
    pub fn is_success(&self) -> bool {
        match self {
            Self::Rows { .. } | Self::Executed { .. } => true,
            Self::Error { .. } => false,
        }
    }

    /// Rows affected by a write, or `None` for reads and errors.
    pub fn rows_affected(&self) -> Option<u64> {
        if let Self::Executed { rows_affected } = self {
            Some(*rows_affected)
        } else {
            None
        }
    }
}
197
198#[derive(Debug)]
200pub struct PipelineResult {
201 pub results: Vec<QueryResult>,
203 pub total_affected: u64,
205 pub total_returned: u64,
207 pub stats: PipelineStats,
209}
210
211impl PipelineResult {
212 pub fn all_succeeded(&self) -> bool {
214 self.results.iter().all(|r| r.is_success())
215 }
216
217 pub fn first_error(&self) -> Option<&str> {
219 self.results.iter().find_map(|r| {
220 if let QueryResult::Error { message } = r {
221 Some(message.as_str())
222 } else {
223 None
224 }
225 })
226 }
227
228 pub fn success_count(&self) -> usize {
230 self.results.iter().filter(|r| r.is_success()).count()
231 }
232
233 pub fn error_count(&self) -> usize {
235 self.results.iter().filter(|r| !r.is_success()).count()
236 }
237}
238
/// Aggregate statistics collected while executing a pipeline.
#[derive(Debug, Clone, Default)]
pub struct PipelineStats {
    /// Total number of queries submitted.
    pub total_queries: usize,
    /// Queries that completed without error.
    pub successful: usize,
    /// Queries that failed.
    pub failed: usize,
    /// Wall-clock time for the whole pipeline run.
    pub total_duration: Duration,
    /// Time spent waiting; `SimulatedExecutor` sets this equal to
    /// `total_duration` — real executors may report something finer-grained.
    pub wait_time: Duration,
    /// Number of batches the queries were grouped into.
    pub batches_used: usize,
    /// Mean number of queries per batch.
    pub avg_batch_size: f64,
}
257
/// A single SQL statement queued in a `QueryPipeline`.
#[derive(Debug, Clone)]
pub struct PipelineQuery {
    /// The SQL text, possibly containing dialect-specific placeholders.
    pub sql: String,
    /// Bind parameters, in placeholder order.
    pub params: Vec<FilterValue>,
    /// True when the statement is expected to return rows (SELECT-like).
    pub expects_rows: bool,
    /// Optional caller-assigned identifier (set via `with_id`).
    pub id: Option<String>,
}
270
271impl PipelineQuery {
272 pub fn new(sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
274 Self {
275 sql: sql.into(),
276 params,
277 expects_rows: true,
278 id: None,
279 }
280 }
281
282 pub fn execute(sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
284 Self {
285 sql: sql.into(),
286 params,
287 expects_rows: false,
288 id: None,
289 }
290 }
291
292 #[must_use]
294 pub fn with_id(mut self, id: impl Into<String>) -> Self {
295 self.id = Some(id.into());
296 self
297 }
298}
299
/// An ordered collection of queries to be executed together.
#[derive(Debug)]
pub struct QueryPipeline {
    /// Execution tuning (batch size, timeout, transaction settings).
    config: PipelineConfig,
    /// Queued queries, in submission order.
    queries: VecDeque<PipelineQuery>,
    /// Target SQL dialect; `new` defaults this to PostgreSQL.
    db_type: DatabaseType,
}
307
308impl QueryPipeline {
309 pub fn new(config: PipelineConfig) -> Self {
311 Self {
312 config,
313 queries: VecDeque::new(),
314 db_type: DatabaseType::PostgreSQL,
315 }
316 }
317
318 #[must_use]
320 pub fn for_database(mut self, db_type: DatabaseType) -> Self {
321 self.db_type = db_type;
322 self
323 }
324
325 #[must_use]
327 pub fn add_query(mut self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
328 self.queries.push_back(PipelineQuery::new(sql, params));
329 self
330 }
331
332 #[must_use]
334 pub fn add_execute(mut self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
335 self.queries.push_back(PipelineQuery::execute(sql, params));
336 self
337 }
338
339 #[must_use]
341 pub fn add_insert(self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
342 self.add_execute(sql, params)
343 }
344
345 #[must_use]
347 pub fn add_update(self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
348 self.add_execute(sql, params)
349 }
350
351 #[must_use]
353 pub fn add_delete(self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
354 self.add_execute(sql, params)
355 }
356
357 #[must_use]
359 pub fn add_select(self, sql: impl Into<String>, params: Vec<FilterValue>) -> Self {
360 self.add_query(sql, params)
361 }
362
363 pub fn push(&mut self, query: PipelineQuery) {
365 self.queries.push_back(query);
366 }
367
368 pub fn len(&self) -> usize {
370 self.queries.len()
371 }
372
373 pub fn is_empty(&self) -> bool {
375 self.queries.is_empty()
376 }
377
378 pub fn queries(&self) -> &VecDeque<PipelineQuery> {
380 &self.queries
381 }
382
383 pub fn to_batch_sql(&self) -> Option<(String, Vec<FilterValue>)> {
387 if self.queries.is_empty() {
388 return None;
389 }
390
391 match self.db_type {
393 DatabaseType::PostgreSQL | DatabaseType::MySQL => {}
394 _ => return None,
395 }
396
397 let mut combined = String::new();
398 let mut all_params = Vec::new();
399 let mut param_offset = 0;
400
401 for query in &self.queries {
402 if !combined.is_empty() {
403 combined.push_str(";\n");
404 }
405
406 if self.db_type == DatabaseType::PostgreSQL && !query.params.is_empty() {
408 let renumbered = renumber_params(&query.sql, param_offset);
409 combined.push_str(&renumbered);
410 param_offset += query.params.len();
411 } else {
412 combined.push_str(&query.sql);
413 }
414
415 all_params.extend(query.params.clone());
416 }
417
418 Some((combined, all_params))
419 }
420
421 pub fn into_batches(self) -> Vec<Vec<PipelineQuery>> {
423 let batch_size = self.config.max_batch_size;
424 let queries: Vec<_> = self.queries.into_iter().collect();
425
426 queries.chunks(batch_size).map(|c| c.to_vec()).collect()
427 }
428
429 pub fn to_transaction_sql(&self) -> Vec<(String, Vec<FilterValue>)> {
431 let mut statements = Vec::new();
432
433 statements.push((self.begin_transaction_sql().to_string(), Vec::new()));
435
436 for query in &self.queries {
438 statements.push((query.sql.clone(), query.params.clone()));
439 }
440
441 statements.push((self.commit_sql().to_string(), Vec::new()));
443
444 statements
445 }
446
447 fn begin_transaction_sql(&self) -> &'static str {
449 match self.db_type {
450 DatabaseType::PostgreSQL => "BEGIN",
451 DatabaseType::MySQL => "START TRANSACTION",
452 DatabaseType::SQLite => "BEGIN TRANSACTION",
453 DatabaseType::MSSQL => "BEGIN TRANSACTION",
454 }
455 }
456
457 fn commit_sql(&self) -> &'static str {
459 "COMMIT"
460 }
461
462 #[allow(dead_code)]
464 fn rollback_sql(&self) -> &'static str {
465 "ROLLBACK"
466 }
467}
468
/// Shifts every PostgreSQL-style `$n` placeholder in `sql` by `offset`
/// (e.g. `$1` with offset 5 becomes `$6`).
///
/// A `$` not followed by digits, or followed by a number too large to parse
/// as `usize`, is copied through unchanged.
fn renumber_params(sql: &str, offset: usize) -> String {
    let mut out = String::with_capacity(sql.len() + 10);
    let mut iter = sql.chars().peekable();

    while let Some(ch) = iter.next() {
        out.push(ch);
        if ch != '$' {
            continue;
        }

        // Collect the run of digits immediately after the '$'.
        let mut digits = String::new();
        while matches!(iter.peek(), Some(c) if c.is_ascii_digit()) {
            // Peek just confirmed a digit is present.
            digits.push(iter.next().expect("peeked digit"));
        }

        match digits.parse::<usize>() {
            Ok(n) => out.push_str(&(n + offset).to_string()),
            // Empty or unparseable run: emit it verbatim.
            Err(_) => out.push_str(&digits),
        }
    }

    out
}
501
/// Builder that turns many rows into batched multi-row INSERT statements.
#[derive(Debug)]
pub struct BulkInsertPipeline {
    /// Target table name (interpolated verbatim into the SQL).
    table: String,
    /// Column names; every row must supply exactly this many values.
    columns: Vec<String>,
    /// Buffered rows awaiting statement generation.
    rows: Vec<Vec<FilterValue>>,
    /// Target SQL dialect; `new` defaults this to PostgreSQL.
    db_type: DatabaseType,
    /// Maximum rows per generated INSERT statement.
    batch_size: usize,
}
511
512impl BulkInsertPipeline {
513 pub fn new(table: impl Into<String>, columns: Vec<String>) -> Self {
515 Self {
516 table: table.into(),
517 columns,
518 rows: Vec::new(),
519 db_type: DatabaseType::PostgreSQL,
520 batch_size: 1000,
521 }
522 }
523
524 #[must_use]
526 pub fn for_database(mut self, db_type: DatabaseType) -> Self {
527 self.db_type = db_type;
528 self
529 }
530
531 #[must_use]
533 pub fn with_batch_size(mut self, size: usize) -> Self {
534 self.batch_size = size.max(1);
535 self
536 }
537
538 pub fn add_row(&mut self, values: Vec<FilterValue>) {
540 assert_eq!(
541 values.len(),
542 self.columns.len(),
543 "Row has {} values, expected {}",
544 values.len(),
545 self.columns.len()
546 );
547 self.rows.push(values);
548 }
549
550 pub fn add_rows(&mut self, rows: impl IntoIterator<Item = Vec<FilterValue>>) {
552 for row in rows {
553 self.add_row(row);
554 }
555 }
556
557 pub fn len(&self) -> usize {
559 self.rows.len()
560 }
561
562 pub fn is_empty(&self) -> bool {
564 self.rows.is_empty()
565 }
566
567 pub fn to_insert_statements(&self) -> Vec<(String, Vec<FilterValue>)> {
569 if self.rows.is_empty() {
570 return Vec::new();
571 }
572
573 let mut statements = Vec::new();
574
575 for chunk in self.rows.chunks(self.batch_size) {
576 let (sql, params) = self.build_multi_insert(chunk);
577 statements.push((sql, params));
578 }
579
580 statements
581 }
582
583 fn build_multi_insert(&self, rows: &[Vec<FilterValue>]) -> (String, Vec<FilterValue>) {
584 let cols_str = self.columns.join(", ");
585 let mut sql = format!("INSERT INTO {} ({}) VALUES ", self.table, cols_str);
586 let mut params = Vec::with_capacity(rows.len() * self.columns.len());
587 let mut param_idx = 1;
588
589 for (row_idx, row) in rows.iter().enumerate() {
590 if row_idx > 0 {
591 sql.push_str(", ");
592 }
593 sql.push('(');
594
595 for (col_idx, value) in row.iter().enumerate() {
596 if col_idx > 0 {
597 sql.push_str(", ");
598 }
599
600 match self.db_type {
601 DatabaseType::PostgreSQL => {
602 sql.push_str(&format!("${}", param_idx));
603 }
604 DatabaseType::MySQL | DatabaseType::SQLite => {
605 sql.push('?');
606 }
607 DatabaseType::MSSQL => {
608 sql.push_str(&format!("@p{}", param_idx));
609 }
610 }
611
612 params.push(value.clone());
613 param_idx += 1;
614 }
615
616 sql.push(')');
617 }
618
619 (sql, params)
620 }
621
622 pub fn to_pipeline(self) -> QueryPipeline {
624 let statements = self.to_insert_statements();
625 let mut pipeline =
626 QueryPipeline::new(PipelineConfig::for_bulk_inserts()).for_database(self.db_type);
627
628 for (sql, params) in statements {
629 pipeline = pipeline.add_insert(sql, params);
630 }
631
632 pipeline
633 }
634}
635
/// Builder that emits one parameterized UPDATE statement per queued update.
#[derive(Debug)]
pub struct BulkUpdatePipeline {
    /// Target table name (interpolated verbatim into the SQL).
    table: String,
    /// Queued updates, emitted in insertion order.
    updates: Vec<BulkUpdate>,
    /// Target SQL dialect; `new` defaults this to PostgreSQL.
    db_type: DatabaseType,
}

/// One UPDATE: SET assignments plus AND-joined equality WHERE conditions.
#[derive(Debug, Clone)]
struct BulkUpdate {
    /// `(column, value)` pairs for the SET clause.
    set: Vec<(String, FilterValue)>,
    /// `(column, value)` equality pairs for the WHERE clause.
    where_clause: Vec<(String, FilterValue)>,
}
649
650impl BulkUpdatePipeline {
651 pub fn new(table: impl Into<String>) -> Self {
653 Self {
654 table: table.into(),
655 updates: Vec::new(),
656 db_type: DatabaseType::PostgreSQL,
657 }
658 }
659
660 #[must_use]
662 pub fn for_database(mut self, db_type: DatabaseType) -> Self {
663 self.db_type = db_type;
664 self
665 }
666
667 pub fn add_update(
669 &mut self,
670 set: Vec<(String, FilterValue)>,
671 where_clause: Vec<(String, FilterValue)>,
672 ) {
673 self.updates.push(BulkUpdate { set, where_clause });
674 }
675
676 pub fn len(&self) -> usize {
678 self.updates.len()
679 }
680
681 pub fn is_empty(&self) -> bool {
683 self.updates.is_empty()
684 }
685
686 pub fn to_update_statements(&self) -> Vec<(String, Vec<FilterValue>)> {
688 self.updates
689 .iter()
690 .map(|update| self.build_update(update))
691 .collect()
692 }
693
694 fn build_update(&self, update: &BulkUpdate) -> (String, Vec<FilterValue>) {
695 let mut sql = format!("UPDATE {} SET ", self.table);
696 let mut params = Vec::new();
697 let mut param_idx = 1;
698
699 for (idx, (col, val)) in update.set.iter().enumerate() {
701 if idx > 0 {
702 sql.push_str(", ");
703 }
704
705 match self.db_type {
706 DatabaseType::PostgreSQL => {
707 sql.push_str(&format!("{} = ${}", col, param_idx));
708 }
709 DatabaseType::MySQL | DatabaseType::SQLite => {
710 sql.push_str(&format!("{} = ?", col));
711 }
712 DatabaseType::MSSQL => {
713 sql.push_str(&format!("{} = @p{}", col, param_idx));
714 }
715 }
716
717 params.push(val.clone());
718 param_idx += 1;
719 }
720
721 if !update.where_clause.is_empty() {
723 sql.push_str(" WHERE ");
724
725 for (idx, (col, val)) in update.where_clause.iter().enumerate() {
726 if idx > 0 {
727 sql.push_str(" AND ");
728 }
729
730 match self.db_type {
731 DatabaseType::PostgreSQL => {
732 sql.push_str(&format!("{} = ${}", col, param_idx));
733 }
734 DatabaseType::MySQL | DatabaseType::SQLite => {
735 sql.push_str(&format!("{} = ?", col));
736 }
737 DatabaseType::MSSQL => {
738 sql.push_str(&format!("{} = @p{}", col, param_idx));
739 }
740 }
741
742 params.push(val.clone());
743 param_idx += 1;
744 }
745 }
746
747 (sql, params)
748 }
749
750 pub fn to_pipeline(self) -> QueryPipeline {
752 let statements = self.to_update_statements();
753 let mut pipeline =
754 QueryPipeline::new(PipelineConfig::for_bulk_updates()).for_database(self.db_type);
755
756 for (sql, params) in statements {
757 pipeline = pipeline.add_update(sql, params);
758 }
759
760 pipeline
761 }
762}
763
/// Abstraction over backends capable of executing a `QueryPipeline`.
///
/// `async_fn_in_trait` is intentionally allowed; note that the returned
/// future's auto-traits (e.g. `Send`) depend on each implementation.
#[allow(async_fn_in_trait)]
pub trait PipelineExecutor {
    /// Executes every query in `pipeline`, returning aggregate results or a
    /// `PipelineError` describing the failure.
    async fn execute_pipeline(
        &self,
        pipeline: &QueryPipeline,
    ) -> Result<PipelineResult, PipelineError>;
}
775
/// Test double that fakes query execution with artificial latency and a
/// configurable failure probability.
pub struct SimulatedExecutor {
    /// Base latency; each query sleeps `latency / 10` (see `execute`).
    latency: Duration,
    /// Target probability that a query is reported as failed.
    error_rate: f64,
}
781
782impl SimulatedExecutor {
783 pub fn new(latency: Duration, error_rate: f64) -> Self {
785 Self {
786 latency,
787 error_rate,
788 }
789 }
790
791 pub async fn execute(&self, pipeline: &QueryPipeline) -> PipelineResult {
793 let start = Instant::now();
794 let mut results = Vec::new();
795 let mut total_affected = 0u64;
796 let mut successful = 0;
797 let mut failed = 0;
798
799 for _query in pipeline.queries() {
801 tokio::time::sleep(self.latency / 10).await;
803
804 if rand_like_error(self.error_rate) {
806 results.push(QueryResult::Error {
807 message: "Simulated error".to_string(),
808 });
809 failed += 1;
810 } else {
811 let affected = 1;
812 total_affected += affected;
813 results.push(QueryResult::Executed {
814 rows_affected: affected,
815 });
816 successful += 1;
817 }
818 }
819
820 let total_duration = start.elapsed();
821 let batches_used = (pipeline.len() + 999) / 1000;
822
823 PipelineResult {
824 results,
825 total_affected,
826 total_returned: 0,
827 stats: PipelineStats {
828 total_queries: pipeline.len(),
829 successful,
830 failed,
831 total_duration,
832 wait_time: total_duration,
833 batches_used,
834 avg_batch_size: pipeline.len() as f64 / batches_used.max(1) as f64,
835 },
836 }
837 }
838}
839
/// Cheap pseudo-random error injector used by `SimulatedExecutor`.
///
/// Derives a value in `[0, 1)` from the sub-second part of the system clock
/// and reports an error when it falls below `rate`. A `rate` of 0.0 never
/// reports an error; a `rate` >= 1.0 always does.
///
/// Bug fix: `subsec_nanos()` is always below 1_000_000_000, so the previous
/// normalization by `u32::MAX` (~4.29e9) capped the effective probability at
/// roughly 0.23 no matter how large `rate` was. Normalizing by 1e9 makes the
/// derived value span the full `[0, 1)` range so `rate` behaves as intended.
fn rand_like_error(rate: f64) -> bool {
    use std::time::SystemTime;
    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .expect("system clock is before UNIX_EPOCH")
        .subsec_nanos();
    // subsec_nanos() is in [0, 1e9), so this ratio is in [0, 1).
    (f64::from(nanos) / 1_000_000_000.0) < rate
}
849
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pipeline_builder() {
        let built = QueryPipeline::new(PipelineConfig::default())
            .add_insert(
                "INSERT INTO users (name) VALUES ($1)",
                vec![FilterValue::String("Alice".into())],
            )
            .add_insert(
                "INSERT INTO users (name) VALUES ($1)",
                vec![FilterValue::String("Bob".into())],
            );

        assert_eq!(built.len(), 2);
    }

    #[test]
    fn test_bulk_insert_pipeline() {
        let mut bulk =
            BulkInsertPipeline::new("users", vec!["name".into(), "age".into()]).with_batch_size(2);

        bulk.add_rows(vec![
            vec![FilterValue::String("Alice".into()), FilterValue::Int(30)],
            vec![FilterValue::String("Bob".into()), FilterValue::Int(25)],
            vec![FilterValue::String("Charlie".into()), FilterValue::Int(35)],
        ]);

        let stmts = bulk.to_insert_statements();

        // Three rows at batch size 2 -> one full batch plus a remainder.
        assert_eq!(stmts.len(), 2);

        let (sql1, params1) = &stmts[0];
        assert!(sql1.contains("VALUES"));
        assert_eq!(params1.len(), 4);

        let (sql2, params2) = &stmts[1];
        assert!(sql2.contains("VALUES"));
        assert_eq!(params2.len(), 2);
    }

    #[test]
    fn test_bulk_update_pipeline() {
        let mut bulk = BulkUpdatePipeline::new("users");
        bulk.add_update(
            vec![("name".into(), FilterValue::String("Updated".into()))],
            vec![("id".into(), FilterValue::Int(1))],
        );

        let stmts = bulk.to_update_statements();
        assert_eq!(stmts.len(), 1);

        let (sql, params) = &stmts[0];
        assert!(sql.contains("UPDATE users SET"));
        assert!(sql.contains("WHERE"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_renumber_params() {
        let renumbered =
            renumber_params("SELECT * FROM users WHERE id = $1 AND name = $2", 5);
        assert_eq!(
            renumbered,
            "SELECT * FROM users WHERE id = $6 AND name = $7"
        );
    }

    #[test]
    fn test_batch_sql() {
        let built = QueryPipeline::new(PipelineConfig::default())
            .for_database(DatabaseType::PostgreSQL)
            .add_query("SELECT 1", vec![])
            .add_query("SELECT 2", vec![]);

        let batch = built.to_batch_sql();
        assert!(batch.is_some());

        let (sql, _) = batch.unwrap();
        assert!(sql.contains("SELECT 1"));
        assert!(sql.contains("SELECT 2"));
    }

    #[test]
    fn test_transaction_sql() {
        let built = QueryPipeline::new(PipelineConfig::default())
            .for_database(DatabaseType::PostgreSQL)
            .add_insert("INSERT INTO users VALUES ($1)", vec![FilterValue::Int(1)]);

        let stmts = built.to_transaction_sql();

        assert_eq!(stmts.len(), 3);
        assert_eq!(stmts[0].0, "BEGIN");
        assert!(stmts[1].0.contains("INSERT"));
        assert_eq!(stmts[2].0, "COMMIT");
    }

    #[tokio::test]
    async fn test_simulated_executor() {
        let executor = SimulatedExecutor::new(Duration::from_millis(1), 0.0);

        let built = QueryPipeline::new(PipelineConfig::default())
            .add_insert("INSERT INTO users VALUES ($1)", vec![FilterValue::Int(1)])
            .add_insert("INSERT INTO users VALUES ($1)", vec![FilterValue::Int(2)]);

        let result = executor.execute(&built).await;

        assert!(result.all_succeeded());
        assert_eq!(result.stats.total_queries, 2);
        assert_eq!(result.total_affected, 2);
    }
}