1use serde::{Deserialize, Serialize};
34
35use crate::error::{QueryError, QueryResult};
36use crate::sql::DatabaseType;
37
/// A single common table expression: `name [(columns)] AS [hint] (query) [SEARCH ...] [CYCLE ...]`.
///
/// Rendered by [`Cte::to_sql`]; several CTEs are combined into one statement by [`WithClause`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Cte {
    /// Name the CTE is referenced by from the main query.
    pub name: String,
    /// Optional explicit column list, rendered as ` (a, b, ...)` after the name when non-empty.
    pub columns: Vec<String>,
    /// Body query, emitted verbatim between parentheses (not escaped or validated).
    pub query: String,
    /// Whether the CTE references itself. `Cte::to_sql` does not render this;
    /// [`WithClause::cte`] reads it to decide on the `RECURSIVE` keyword.
    pub recursive: bool,
    /// Materialization hint; only rendered for PostgreSQL.
    pub materialized: Option<Materialized>,
    /// `SEARCH` ordering clause; only rendered for PostgreSQL.
    pub search: Option<SearchClause>,
    /// `CYCLE` detection clause; only rendered for PostgreSQL.
    pub cycle: Option<CycleClause>,
}
56
/// Materialization hint placed between `AS` and the CTE body (PostgreSQL only).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Materialized {
    /// Renders `MATERIALIZED`.
    Yes,
    /// Renders `NOT MATERIALIZED`.
    No,
}
65
/// `SEARCH { BREADTH | DEPTH } FIRST BY columns SET set_column` clause of a recursive CTE
/// (rendered for PostgreSQL only).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SearchClause {
    /// Traversal order.
    pub method: SearchMethod,
    /// Columns listed after `FIRST BY`, joined with `, `.
    pub columns: Vec<String>,
    /// Name of the output column listed after `SET`.
    pub set_column: String,
}
76
/// Traversal order used by a [`SearchClause`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SearchMethod {
    /// Renders `BREADTH FIRST BY`.
    BreadthFirst,
    /// Renders `DEPTH FIRST BY`.
    DepthFirst,
}
85
/// `CYCLE columns SET set_column [TO mark DEFAULT default] USING using_column` clause of a
/// recursive CTE (rendered for PostgreSQL only).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CycleClause {
    /// Columns compared to detect a revisit, joined with `, `.
    pub columns: Vec<String>,
    /// Name of the cycle-mark output column listed after `SET`.
    pub set_column: String,
    /// Name of the path-tracking column listed after `USING`.
    pub using_column: String,
    /// Value emitted after `TO`; only rendered when `default_value` is also set.
    pub mark_value: Option<String>,
    /// Value emitted after `DEFAULT`; only rendered when `mark_value` is also set.
    pub default_value: Option<String>,
}
100
101impl Cte {
102 pub fn new(name: impl Into<String>) -> Self {
104 Self {
105 name: name.into(),
106 columns: Vec::new(),
107 query: String::new(),
108 recursive: false,
109 materialized: None,
110 search: None,
111 cycle: None,
112 }
113 }
114
115 pub fn builder(name: impl Into<String>) -> CteBuilder {
117 CteBuilder::new(name)
118 }
119
120 pub fn columns<I, S>(mut self, columns: I) -> Self
122 where
123 I: IntoIterator<Item = S>,
124 S: Into<String>,
125 {
126 self.columns = columns.into_iter().map(Into::into).collect();
127 self
128 }
129
130 pub fn as_query(mut self, query: impl Into<String>) -> Self {
132 self.query = query.into();
133 self
134 }
135
136 pub fn recursive(mut self) -> Self {
138 self.recursive = true;
139 self
140 }
141
142 pub fn materialized(mut self, mat: Materialized) -> Self {
144 self.materialized = Some(mat);
145 self
146 }
147
148 pub fn to_sql(&self, db_type: DatabaseType) -> String {
150 let mut sql = self.name.clone();
151
152 if !self.columns.is_empty() {
154 sql.push_str(" (");
155 sql.push_str(&self.columns.join(", "));
156 sql.push(')');
157 }
158
159 sql.push_str(" AS ");
160
161 if db_type == DatabaseType::PostgreSQL {
163 if let Some(mat) = self.materialized {
164 match mat {
165 Materialized::Yes => sql.push_str("MATERIALIZED "),
166 Materialized::No => sql.push_str("NOT MATERIALIZED "),
167 }
168 }
169 }
170
171 sql.push('(');
172 sql.push_str(&self.query);
173 sql.push(')');
174
175 if db_type == DatabaseType::PostgreSQL {
177 if let Some(ref search) = self.search {
178 sql.push_str(" SEARCH ");
179 sql.push_str(match search.method {
180 SearchMethod::BreadthFirst => "BREADTH FIRST BY ",
181 SearchMethod::DepthFirst => "DEPTH FIRST BY ",
182 });
183 sql.push_str(&search.columns.join(", "));
184 sql.push_str(" SET ");
185 sql.push_str(&search.set_column);
186 }
187
188 if let Some(ref cycle) = self.cycle {
189 sql.push_str(" CYCLE ");
190 sql.push_str(&cycle.columns.join(", "));
191 sql.push_str(" SET ");
192 sql.push_str(&cycle.set_column);
193 if let (Some(mark), Some(default)) = (&cycle.mark_value, &cycle.default_value) {
194 sql.push_str(" TO ");
195 sql.push_str(mark);
196 sql.push_str(" DEFAULT ");
197 sql.push_str(default);
198 }
199 sql.push_str(" USING ");
200 sql.push_str(&cycle.using_column);
201 }
202 }
203
204 sql
205 }
206}
207
/// Step-by-step builder for [`Cte`]. It also exposes the PostgreSQL `SEARCH`/`CYCLE`
/// options, and [`CteBuilder::build`] checks that a query was provided.
#[derive(Debug, Clone)]
pub struct CteBuilder {
    name: String,
    columns: Vec<String>,
    // `None` until `as_query` is called; `build` rejects a missing query.
    query: Option<String>,
    recursive: bool,
    materialized: Option<Materialized>,
    search: Option<SearchClause>,
    cycle: Option<CycleClause>,
}
219
220impl CteBuilder {
221 pub fn new(name: impl Into<String>) -> Self {
223 Self {
224 name: name.into(),
225 columns: Vec::new(),
226 query: None,
227 recursive: false,
228 materialized: None,
229 search: None,
230 cycle: None,
231 }
232 }
233
234 pub fn columns<I, S>(mut self, columns: I) -> Self
236 where
237 I: IntoIterator<Item = S>,
238 S: Into<String>,
239 {
240 self.columns = columns.into_iter().map(Into::into).collect();
241 self
242 }
243
244 pub fn as_query(mut self, query: impl Into<String>) -> Self {
246 self.query = Some(query.into());
247 self
248 }
249
250 pub fn recursive(mut self) -> Self {
252 self.recursive = true;
253 self
254 }
255
256 pub fn materialized(mut self) -> Self {
258 self.materialized = Some(Materialized::Yes);
259 self
260 }
261
262 pub fn not_materialized(mut self) -> Self {
264 self.materialized = Some(Materialized::No);
265 self
266 }
267
268 pub fn search_breadth_first<I, S>(mut self, columns: I, set_column: impl Into<String>) -> Self
270 where
271 I: IntoIterator<Item = S>,
272 S: Into<String>,
273 {
274 self.search = Some(SearchClause {
275 method: SearchMethod::BreadthFirst,
276 columns: columns.into_iter().map(Into::into).collect(),
277 set_column: set_column.into(),
278 });
279 self
280 }
281
282 pub fn search_depth_first<I, S>(mut self, columns: I, set_column: impl Into<String>) -> Self
284 where
285 I: IntoIterator<Item = S>,
286 S: Into<String>,
287 {
288 self.search = Some(SearchClause {
289 method: SearchMethod::DepthFirst,
290 columns: columns.into_iter().map(Into::into).collect(),
291 set_column: set_column.into(),
292 });
293 self
294 }
295
296 pub fn cycle<I, S>(
298 mut self,
299 columns: I,
300 set_column: impl Into<String>,
301 using_column: impl Into<String>,
302 ) -> Self
303 where
304 I: IntoIterator<Item = S>,
305 S: Into<String>,
306 {
307 self.cycle = Some(CycleClause {
308 columns: columns.into_iter().map(Into::into).collect(),
309 set_column: set_column.into(),
310 using_column: using_column.into(),
311 mark_value: None,
312 default_value: None,
313 });
314 self
315 }
316
317 pub fn build(self) -> QueryResult<Cte> {
319 let query = self.query.ok_or_else(|| {
320 QueryError::invalid_input("query", "CTE requires a query (use as_query())")
321 })?;
322
323 Ok(Cte {
324 name: self.name,
325 columns: self.columns,
326 query,
327 recursive: self.recursive,
328 materialized: self.materialized,
329 search: self.search,
330 cycle: self.cycle,
331 })
332 }
333}
334
/// A complete `WITH` clause: one or more CTEs, optionally followed by the main query.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WithClause {
    /// CTEs in declaration order, rendered comma-separated.
    pub ctes: Vec<Cte>,
    /// Requests the `RECURSIVE` keyword; set automatically by [`WithClause::cte`] when a
    /// recursive CTE is added.
    pub recursive: bool,
    /// Statement appended after the CTE list (after a single space), emitted verbatim.
    pub main_query: Option<String>,
}
345
346impl WithClause {
347 pub fn new() -> Self {
349 Self::default()
350 }
351
352 pub fn cte(mut self, cte: Cte) -> Self {
354 if cte.recursive {
355 self.recursive = true;
356 }
357 self.ctes.push(cte);
358 self
359 }
360
361 pub fn ctes<I>(mut self, ctes: I) -> Self
363 where
364 I: IntoIterator<Item = Cte>,
365 {
366 for cte in ctes {
367 self = self.cte(cte);
368 }
369 self
370 }
371
372 pub fn main_query(mut self, query: impl Into<String>) -> Self {
374 self.main_query = Some(query.into());
375 self
376 }
377
378 pub fn select(self, columns: impl Into<String>) -> WithQueryBuilder {
380 WithQueryBuilder {
381 with_clause: self,
382 select: columns.into(),
383 from: None,
384 where_clause: None,
385 order_by: None,
386 limit: None,
387 }
388 }
389
390 pub fn to_sql(&self, db_type: DatabaseType) -> QueryResult<String> {
392 if self.ctes.is_empty() {
393 return Err(QueryError::invalid_input("ctes", "WITH clause requires at least one CTE"));
394 }
395
396 let mut sql = String::with_capacity(256);
397
398 sql.push_str("WITH ");
399 if self.recursive {
400 sql.push_str("RECURSIVE ");
401 }
402
403 let cte_sqls: Vec<String> = self.ctes.iter().map(|c| c.to_sql(db_type)).collect();
404 sql.push_str(&cte_sqls.join(", "));
405
406 if let Some(ref main) = self.main_query {
407 sql.push(' ');
408 sql.push_str(main);
409 }
410
411 Ok(sql)
412 }
413}
414
/// Builder for the main `SELECT` that follows a [`WithClause`], created by
/// [`WithClause::select`]. Each clause's text is inserted verbatim.
#[derive(Debug, Clone)]
pub struct WithQueryBuilder {
    // The CTE list the final statement is prefixed with.
    with_clause: WithClause,
    // Select list, emitted right after `SELECT`.
    select: String,
    from: Option<String>,
    where_clause: Option<String>,
    order_by: Option<String>,
    // Row cap; rendered per dialect (`LIMIT`, or `TOP` / `OFFSET ... FETCH NEXT` for SQL Server).
    limit: Option<u64>,
}
425
426impl WithQueryBuilder {
427 pub fn from(mut self, table: impl Into<String>) -> Self {
429 self.from = Some(table.into());
430 self
431 }
432
433 pub fn where_clause(mut self, condition: impl Into<String>) -> Self {
435 self.where_clause = Some(condition.into());
436 self
437 }
438
439 pub fn order_by(mut self, order: impl Into<String>) -> Self {
441 self.order_by = Some(order.into());
442 self
443 }
444
445 pub fn limit(mut self, limit: u64) -> Self {
447 self.limit = Some(limit);
448 self
449 }
450
451 pub fn build(mut self, db_type: DatabaseType) -> QueryResult<String> {
453 let mut main = format!("SELECT {}", self.select);
455
456 if let Some(from) = self.from {
457 main.push_str(" FROM ");
458 main.push_str(&from);
459 }
460
461 if let Some(where_clause) = self.where_clause {
462 main.push_str(" WHERE ");
463 main.push_str(&where_clause);
464 }
465
466 let has_order_by = self.order_by.is_some();
467 if let Some(order) = self.order_by {
468 main.push_str(" ORDER BY ");
469 main.push_str(&order);
470 }
471
472 if let Some(limit) = self.limit {
473 match db_type {
474 DatabaseType::MSSQL => {
475 if has_order_by {
477 main.push_str(&format!(" OFFSET 0 ROWS FETCH NEXT {} ROWS ONLY", limit));
478 } else {
479 main = main.replacen("SELECT ", &format!("SELECT TOP {} ", limit), 1);
481 }
482 }
483 _ => {
484 main.push_str(&format!(" LIMIT {}", limit));
485 }
486 }
487 }
488
489 self.with_clause.main_query = Some(main);
490 self.with_clause.to_sql(db_type)
491 }
492}
493
494pub mod patterns {
496 use super::*;
497
498 pub fn tree_traversal(
500 cte_name: &str,
501 table: &str,
502 id_col: &str,
503 parent_col: &str,
504 root_condition: &str,
505 ) -> Cte {
506 let base_query = format!(
507 "SELECT {id}, {parent}, 1 AS depth FROM {table} WHERE {root}",
508 id = id_col,
509 parent = parent_col,
510 table = table,
511 root = root_condition
512 );
513
514 let recursive_query = format!(
515 "SELECT t.{id}, t.{parent}, c.depth + 1 FROM {table} t \
516 INNER JOIN {cte} c ON t.{parent} = c.{id}",
517 id = id_col,
518 parent = parent_col,
519 table = table,
520 cte = cte_name
521 );
522
523 Cte::new(cte_name)
524 .columns([id_col, parent_col, "depth"])
525 .as_query(format!("{} UNION ALL {}", base_query, recursive_query))
526 .recursive()
527 }
528
529 pub fn graph_path(
531 cte_name: &str,
532 edges_table: &str,
533 from_col: &str,
534 to_col: &str,
535 start_node: &str,
536 ) -> Cte {
537 let base_query = format!(
538 "SELECT {from_col}, {to_col}, ARRAY[{from_col}] AS path, 1 AS length \
539 FROM {table} WHERE {from_col} = {start}",
540 from_col = from_col,
541 to_col = to_col,
542 table = edges_table,
543 start = start_node
544 );
545
546 let recursive_query = format!(
547 "SELECT e.{from_col}, e.{to_col}, p.path || e.{to_col}, p.length + 1 \
548 FROM {table} e \
549 INNER JOIN {cte} p ON e.{from_col} = p.{to_col} \
550 WHERE NOT e.{to_col} = ANY(p.path)",
551 from_col = from_col,
552 to_col = to_col,
553 table = edges_table,
554 cte = cte_name
555 );
556
557 Cte::new(cte_name)
558 .columns([from_col, to_col, "path", "length"])
559 .as_query(format!("{} UNION ALL {}", base_query, recursive_query))
560 .recursive()
561 }
562
563 pub fn paginated(
565 cte_name: &str,
566 query: &str,
567 order_by: &str,
568 ) -> Cte {
569 let paginated_query = format!(
570 "SELECT *, ROW_NUMBER() OVER (ORDER BY {}) AS row_num FROM ({})",
571 order_by, query
572 );
573
574 Cte::new(cte_name).as_query(paginated_query)
575 }
576
577 pub fn running_total(
579 cte_name: &str,
580 table: &str,
581 value_col: &str,
582 order_col: &str,
583 partition_col: Option<&str>,
584 ) -> Cte {
585 let partition = partition_col
586 .map(|p| format!("PARTITION BY {} ", p))
587 .unwrap_or_default();
588
589 let query = format!(
590 "SELECT *, SUM({value}) OVER ({partition}ORDER BY {order}) AS running_total \
591 FROM {table}",
592 value = value_col,
593 partition = partition,
594 order = order_col,
595 table = table
596 );
597
598 Cte::new(cte_name).as_query(query)
599 }
600}
601
602pub mod mongodb {
604 use serde::{Deserialize, Serialize};
605 use serde_json::Value as JsonValue;
606
607 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
609 pub struct Lookup {
610 pub from: String,
612 pub local_field: Option<String>,
614 pub foreign_field: Option<String>,
616 pub as_field: String,
618 pub pipeline: Option<Vec<JsonValue>>,
620 pub let_vars: Option<serde_json::Map<String, JsonValue>>,
622 }
623
624 impl Lookup {
625 pub fn simple(from: impl Into<String>, local: impl Into<String>, foreign: impl Into<String>, as_field: impl Into<String>) -> Self {
627 Self {
628 from: from.into(),
629 local_field: Some(local.into()),
630 foreign_field: Some(foreign.into()),
631 as_field: as_field.into(),
632 pipeline: None,
633 let_vars: None,
634 }
635 }
636
637 pub fn with_pipeline(from: impl Into<String>, as_field: impl Into<String>) -> LookupBuilder {
639 LookupBuilder {
640 from: from.into(),
641 as_field: as_field.into(),
642 pipeline: Vec::new(),
643 let_vars: serde_json::Map::new(),
644 }
645 }
646
647 pub fn to_bson(&self) -> JsonValue {
649 let mut lookup = serde_json::Map::new();
650 lookup.insert("from".to_string(), JsonValue::String(self.from.clone()));
651
652 if let (Some(local), Some(foreign)) = (&self.local_field, &self.foreign_field) {
653 lookup.insert("localField".to_string(), JsonValue::String(local.clone()));
654 lookup.insert("foreignField".to_string(), JsonValue::String(foreign.clone()));
655 }
656
657 lookup.insert("as".to_string(), JsonValue::String(self.as_field.clone()));
658
659 if let Some(ref pipeline) = self.pipeline {
660 lookup.insert("pipeline".to_string(), JsonValue::Array(pipeline.clone()));
661 }
662
663 if let Some(ref vars) = self.let_vars {
664 if !vars.is_empty() {
665 lookup.insert("let".to_string(), JsonValue::Object(vars.clone()));
666 }
667 }
668
669 serde_json::json!({ "$lookup": lookup })
670 }
671 }
672
673 #[derive(Debug, Clone)]
675 pub struct LookupBuilder {
676 from: String,
677 as_field: String,
678 pipeline: Vec<JsonValue>,
679 let_vars: serde_json::Map<String, JsonValue>,
680 }
681
682 impl LookupBuilder {
683 pub fn let_var(mut self, name: impl Into<String>, expr: impl Into<String>) -> Self {
685 self.let_vars.insert(
686 name.into(),
687 JsonValue::String(format!("${}", expr.into())),
688 );
689 self
690 }
691
692 pub fn match_expr(mut self, expr: JsonValue) -> Self {
694 self.pipeline.push(serde_json::json!({ "$match": { "$expr": expr } }));
695 self
696 }
697
698 pub fn stage(mut self, stage: JsonValue) -> Self {
700 self.pipeline.push(stage);
701 self
702 }
703
704 pub fn project(mut self, fields: JsonValue) -> Self {
706 self.pipeline.push(serde_json::json!({ "$project": fields }));
707 self
708 }
709
710 pub fn limit(mut self, n: u64) -> Self {
712 self.pipeline.push(serde_json::json!({ "$limit": n }));
713 self
714 }
715
716 pub fn sort(mut self, fields: JsonValue) -> Self {
718 self.pipeline.push(serde_json::json!({ "$sort": fields }));
719 self
720 }
721
722 pub fn build(self) -> Lookup {
724 Lookup {
725 from: self.from,
726 local_field: None,
727 foreign_field: None,
728 as_field: self.as_field,
729 pipeline: if self.pipeline.is_empty() {
730 None
731 } else {
732 Some(self.pipeline)
733 },
734 let_vars: if self.let_vars.is_empty() {
735 None
736 } else {
737 Some(self.let_vars)
738 },
739 }
740 }
741 }
742
743 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
745 pub struct GraphLookup {
746 pub from: String,
748 pub start_with: String,
750 pub connect_from_field: String,
752 pub connect_to_field: String,
754 pub as_field: String,
756 pub max_depth: Option<u32>,
758 pub depth_field: Option<String>,
760 pub restrict_search_with_match: Option<JsonValue>,
762 }
763
764 impl GraphLookup {
765 pub fn new(
767 from: impl Into<String>,
768 start_with: impl Into<String>,
769 connect_from: impl Into<String>,
770 connect_to: impl Into<String>,
771 as_field: impl Into<String>,
772 ) -> Self {
773 Self {
774 from: from.into(),
775 start_with: start_with.into(),
776 connect_from_field: connect_from.into(),
777 connect_to_field: connect_to.into(),
778 as_field: as_field.into(),
779 max_depth: None,
780 depth_field: None,
781 restrict_search_with_match: None,
782 }
783 }
784
785 pub fn max_depth(mut self, depth: u32) -> Self {
787 self.max_depth = Some(depth);
788 self
789 }
790
791 pub fn depth_field(mut self, field: impl Into<String>) -> Self {
793 self.depth_field = Some(field.into());
794 self
795 }
796
797 pub fn restrict_search(mut self, filter: JsonValue) -> Self {
799 self.restrict_search_with_match = Some(filter);
800 self
801 }
802
803 pub fn to_bson(&self) -> JsonValue {
805 let mut graph = serde_json::Map::new();
806 graph.insert("from".to_string(), JsonValue::String(self.from.clone()));
807 graph.insert("startWith".to_string(), JsonValue::String(format!("${}", self.start_with)));
808 graph.insert("connectFromField".to_string(), JsonValue::String(self.connect_from_field.clone()));
809 graph.insert("connectToField".to_string(), JsonValue::String(self.connect_to_field.clone()));
810 graph.insert("as".to_string(), JsonValue::String(self.as_field.clone()));
811
812 if let Some(max) = self.max_depth {
813 graph.insert("maxDepth".to_string(), JsonValue::Number(max.into()));
814 }
815
816 if let Some(ref field) = self.depth_field {
817 graph.insert("depthField".to_string(), JsonValue::String(field.clone()));
818 }
819
820 if let Some(ref filter) = self.restrict_search_with_match {
821 graph.insert("restrictSearchWithMatch".to_string(), filter.clone());
822 }
823
824 serde_json::json!({ "$graphLookup": graph })
825 }
826 }
827
828 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
830 pub struct UnionWith {
831 pub coll: String,
833 pub pipeline: Option<Vec<JsonValue>>,
835 }
836
837 impl UnionWith {
838 pub fn collection(coll: impl Into<String>) -> Self {
840 Self {
841 coll: coll.into(),
842 pipeline: None,
843 }
844 }
845
846 pub fn with_pipeline(coll: impl Into<String>, pipeline: Vec<JsonValue>) -> Self {
848 Self {
849 coll: coll.into(),
850 pipeline: Some(pipeline),
851 }
852 }
853
854 pub fn to_bson(&self) -> JsonValue {
856 if let Some(ref pipeline) = self.pipeline {
857 serde_json::json!({
858 "$unionWith": {
859 "coll": self.coll,
860 "pipeline": pipeline
861 }
862 })
863 } else {
864 serde_json::json!({ "$unionWith": self.coll })
865 }
866 }
867 }
868
869 pub fn lookup(from: &str, local: &str, foreign: &str, as_field: &str) -> Lookup {
871 Lookup::simple(from, local, foreign, as_field)
872 }
873
874 pub fn lookup_pipeline(from: &str, as_field: &str) -> LookupBuilder {
876 Lookup::with_pipeline(from, as_field)
877 }
878
879 pub fn graph_lookup(
881 from: &str,
882 start_with: &str,
883 connect_from: &str,
884 connect_to: &str,
885 as_field: &str,
886 ) -> GraphLookup {
887 GraphLookup::new(from, start_with, connect_from, connect_to, as_field)
888 }
889}
890
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_cte() {
        let cte = Cte::new("active_users").as_query("SELECT * FROM users WHERE active = true");

        let sql = cte.to_sql(DatabaseType::PostgreSQL);
        assert!(sql.contains("active_users AS"));
        assert!(sql.contains("SELECT * FROM users"));
    }

    #[test]
    fn test_cte_with_columns() {
        let cte = Cte::new("user_stats")
            .columns(["id", "name", "total"])
            .as_query("SELECT id, name, COUNT(*) FROM orders GROUP BY user_id");

        let sql = cte.to_sql(DatabaseType::PostgreSQL);
        assert!(sql.contains("user_stats (id, name, total) AS"));
    }

    #[test]
    fn test_recursive_cte() {
        let cte = Cte::new("subordinates")
            .columns(["id", "name", "manager_id", "depth"])
            .as_query(
                "SELECT id, name, manager_id, 1 FROM employees WHERE manager_id IS NULL \
                 UNION ALL \
                 SELECT e.id, e.name, e.manager_id, s.depth + 1 \
                 FROM employees e JOIN subordinates s ON e.manager_id = s.id",
            )
            .recursive();

        assert!(cte.recursive);
    }

    #[test]
    fn test_materialized_cte() {
        let cte = Cte::new("expensive_query")
            .as_query("SELECT * FROM big_table WHERE complex_condition")
            .materialized(Materialized::Yes);

        let sql = cte.to_sql(DatabaseType::PostgreSQL);
        // A bare `contains("MATERIALIZED")` would also accept `NOT MATERIALIZED`.
        assert!(sql.contains("AS MATERIALIZED ("));
        assert!(!sql.contains("NOT MATERIALIZED"));
    }

    #[test]
    fn test_not_materialized_cte() {
        let cte = Cte::new("q")
            .as_query("SELECT 1")
            .materialized(Materialized::No);

        assert_eq!(
            cte.to_sql(DatabaseType::PostgreSQL),
            "q AS NOT MATERIALIZED (SELECT 1)"
        );
    }

    #[test]
    fn test_materialized_hint_is_postgres_only() {
        let cte = Cte::new("q")
            .as_query("SELECT 1")
            .materialized(Materialized::Yes);

        assert_eq!(cte.to_sql(DatabaseType::MSSQL), "q AS (SELECT 1)");
    }

    #[test]
    fn test_search_and_cycle_clauses() {
        let cte = CteBuilder::new("tree")
            .as_query("SELECT 1")
            .recursive()
            .search_depth_first(["id"], "ord")
            .cycle(["id"], "is_cycle", "path")
            .build()
            .unwrap();

        assert_eq!(
            cte.to_sql(DatabaseType::PostgreSQL),
            "tree AS (SELECT 1) SEARCH DEPTH FIRST BY id SET ord CYCLE id SET is_cycle USING path"
        );
        // SEARCH/CYCLE are PostgreSQL syntax and are dropped for other dialects.
        assert_eq!(cte.to_sql(DatabaseType::MSSQL), "tree AS (SELECT 1)");
    }

    #[test]
    fn test_with_clause() {
        let cte1 = Cte::new("cte1").as_query("SELECT 1");
        let cte2 = Cte::new("cte2").as_query("SELECT 2");

        let with = WithClause::new()
            .cte(cte1)
            .cte(cte2)
            .main_query("SELECT * FROM cte1, cte2");

        let sql = with.to_sql(DatabaseType::PostgreSQL).unwrap();
        assert!(sql.starts_with("WITH "));
        assert!(sql.contains("cte1 AS"));
        assert!(sql.contains("cte2 AS"));
        assert!(sql.contains("SELECT * FROM cte1, cte2"));
    }

    #[test]
    fn test_empty_with_clause_is_error() {
        assert!(WithClause::new().to_sql(DatabaseType::PostgreSQL).is_err());
    }

    #[test]
    fn test_recursive_with_clause() {
        let cte = Cte::new("numbers")
            .as_query("SELECT 1 AS n UNION ALL SELECT n + 1 FROM numbers WHERE n < 10")
            .recursive();

        let with = WithClause::new().cte(cte).main_query("SELECT * FROM numbers");

        let sql = with.to_sql(DatabaseType::PostgreSQL).unwrap();
        assert!(sql.starts_with("WITH RECURSIVE"));
    }

    #[test]
    fn test_with_query_builder() {
        let cte = Cte::new("active").as_query("SELECT * FROM users WHERE active = true");

        let sql = WithClause::new()
            .cte(cte)
            .select("*")
            .from("active")
            .where_clause("role = 'admin'")
            .order_by("name")
            .limit(10)
            .build(DatabaseType::PostgreSQL)
            .unwrap();

        assert!(sql.contains("WITH active AS"));
        assert!(sql.contains("SELECT *"));
        assert!(sql.contains("FROM active"));
        assert!(sql.contains("WHERE role = 'admin'"));
        assert!(sql.contains("ORDER BY name"));
        assert!(sql.contains("LIMIT 10"));
    }

    #[test]
    fn test_mssql_limit() {
        let cte = Cte::new("data").as_query("SELECT * FROM table1");

        let sql = WithClause::new()
            .cte(cte)
            .select("*")
            .from("data")
            .order_by("id")
            .limit(10)
            .build(DatabaseType::MSSQL)
            .unwrap();

        assert!(sql.contains("OFFSET 0 ROWS FETCH NEXT 10 ROWS ONLY"));
    }

    #[test]
    fn test_mssql_limit_without_order_by_uses_top() {
        let cte = Cte::new("data").as_query("SELECT * FROM table1");

        let sql = WithClause::new()
            .cte(cte)
            .select("*")
            .from("data")
            .limit(5)
            .build(DatabaseType::MSSQL)
            .unwrap();

        assert!(sql.contains("SELECT TOP 5 * FROM data"));
        assert!(!sql.contains("LIMIT"));
    }

    #[test]
    fn test_cte_builder() {
        let cte = CteBuilder::new("stats")
            .columns(["a", "b"])
            .as_query("SELECT 1, 2")
            .materialized()
            .build()
            .unwrap();

        assert_eq!(cte.name, "stats");
        assert_eq!(cte.columns, vec!["a", "b"]);
        assert_eq!(cte.materialized, Some(Materialized::Yes));
    }

    #[test]
    fn test_cte_builder_requires_query() {
        assert!(CteBuilder::new("stats").build().is_err());
    }

    mod pattern_tests {
        use super::super::patterns::*;

        #[test]
        fn test_tree_traversal_pattern() {
            let cte = tree_traversal(
                "org_tree",
                "employees",
                "id",
                "manager_id",
                "manager_id IS NULL",
            );

            assert!(cte.recursive);
            assert!(cte.query.contains("UNION ALL"));
            assert!(cte.query.contains("depth + 1"));
        }

        #[test]
        fn test_running_total_pattern() {
            let cte = running_total(
                "account_balance",
                "transactions",
                "amount",
                "transaction_date",
                Some("account_id"),
            );

            assert!(cte.query.contains("SUM(amount)"));
            assert!(cte.query.contains("PARTITION BY account_id"));
            assert!(cte.query.contains("running_total"));
        }

        #[test]
        fn test_running_total_without_partition() {
            let cte = running_total("totals", "transactions", "amount", "transaction_date", None);

            assert!(cte.query.contains("OVER (ORDER BY transaction_date)"));
            assert!(!cte.query.contains("PARTITION BY"));
        }
    }

    mod mongodb_tests {
        use super::super::mongodb::*;

        #[test]
        fn test_simple_lookup() {
            let lookup = Lookup::simple("orders", "user_id", "_id", "user_orders");
            let bson = lookup.to_bson();

            assert_eq!(bson["$lookup"]["from"], "orders");
            assert_eq!(bson["$lookup"]["localField"], "user_id");
            assert_eq!(bson["$lookup"]["foreignField"], "_id");
            assert_eq!(bson["$lookup"]["as"], "user_orders");
        }

        #[test]
        fn test_lookup_with_pipeline() {
            let lookup = Lookup::with_pipeline("inventory", "stock_items")
                .let_var("order_item", "item")
                .match_expr(serde_json::json!({
                    "$eq": ["$sku", "$$order_item"]
                }))
                .project(serde_json::json!({ "inStock": 1 }))
                .build();

            let bson = lookup.to_bson();
            assert!(bson["$lookup"]["pipeline"].is_array());
            assert!(bson["$lookup"]["let"].is_object());
            assert_eq!(bson["$lookup"]["let"]["order_item"], "$item");
        }

        #[test]
        fn test_graph_lookup() {
            let lookup = GraphLookup::new(
                "employees",
                "reportsTo",
                "reportsTo",
                "name",
                "reportingHierarchy",
            )
            .max_depth(5)
            .depth_field("level");

            let bson = lookup.to_bson();
            assert_eq!(bson["$graphLookup"]["from"], "employees");
            assert_eq!(bson["$graphLookup"]["startWith"], "$reportsTo");
            assert_eq!(bson["$graphLookup"]["maxDepth"], 5);
            assert_eq!(bson["$graphLookup"]["depthField"], "level");
        }

        #[test]
        fn test_union_with() {
            let union = UnionWith::collection("archived_orders");
            let bson = union.to_bson();

            assert_eq!(bson["$unionWith"], "archived_orders");
        }

        #[test]
        fn test_union_with_pipeline() {
            let union = UnionWith::with_pipeline(
                "archive",
                vec![serde_json::json!({ "$match": { "year": 2023 } })],
            );
            let bson = union.to_bson();

            assert!(bson["$unionWith"]["pipeline"].is_array());
            assert_eq!(bson["$unionWith"]["coll"], "archive");
        }
    }
}
1125