1use serde::{Deserialize, Serialize};
34
35use crate::error::{QueryError, QueryResult};
36use crate::sql::DatabaseType;
37
38#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
40pub struct Cte {
41 pub name: String,
43 pub columns: Vec<String>,
45 pub query: String,
47 pub recursive: bool,
49 pub materialized: Option<Materialized>,
51 pub search: Option<SearchClause>,
53 pub cycle: Option<CycleClause>,
55}
56
57#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
59pub enum Materialized {
60 Yes,
62 No,
64}
65
66#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
68pub struct SearchClause {
69 pub method: SearchMethod,
71 pub columns: Vec<String>,
73 pub set_column: String,
75}
76
77#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
79pub enum SearchMethod {
80 BreadthFirst,
82 DepthFirst,
84}
85
86#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
88pub struct CycleClause {
89 pub columns: Vec<String>,
91 pub set_column: String,
93 pub using_column: String,
95 pub mark_value: Option<String>,
97 pub default_value: Option<String>,
99}
100
101impl Cte {
102 pub fn new(name: impl Into<String>) -> Self {
104 Self {
105 name: name.into(),
106 columns: Vec::new(),
107 query: String::new(),
108 recursive: false,
109 materialized: None,
110 search: None,
111 cycle: None,
112 }
113 }
114
115 pub fn builder(name: impl Into<String>) -> CteBuilder {
117 CteBuilder::new(name)
118 }
119
120 pub fn columns<I, S>(mut self, columns: I) -> Self
122 where
123 I: IntoIterator<Item = S>,
124 S: Into<String>,
125 {
126 self.columns = columns.into_iter().map(Into::into).collect();
127 self
128 }
129
130 pub fn as_query(mut self, query: impl Into<String>) -> Self {
132 self.query = query.into();
133 self
134 }
135
136 pub fn recursive(mut self) -> Self {
138 self.recursive = true;
139 self
140 }
141
142 pub fn materialized(mut self, mat: Materialized) -> Self {
144 self.materialized = Some(mat);
145 self
146 }
147
148 pub fn to_sql(&self, db_type: DatabaseType) -> String {
150 let mut sql = self.name.clone();
151
152 if !self.columns.is_empty() {
154 sql.push_str(" (");
155 sql.push_str(&self.columns.join(", "));
156 sql.push(')');
157 }
158
159 sql.push_str(" AS ");
160
161 if db_type == DatabaseType::PostgreSQL {
163 if let Some(mat) = self.materialized {
164 match mat {
165 Materialized::Yes => sql.push_str("MATERIALIZED "),
166 Materialized::No => sql.push_str("NOT MATERIALIZED "),
167 }
168 }
169 }
170
171 sql.push('(');
172 sql.push_str(&self.query);
173 sql.push(')');
174
175 if db_type == DatabaseType::PostgreSQL {
177 if let Some(ref search) = self.search {
178 sql.push_str(" SEARCH ");
179 sql.push_str(match search.method {
180 SearchMethod::BreadthFirst => "BREADTH FIRST BY ",
181 SearchMethod::DepthFirst => "DEPTH FIRST BY ",
182 });
183 sql.push_str(&search.columns.join(", "));
184 sql.push_str(" SET ");
185 sql.push_str(&search.set_column);
186 }
187
188 if let Some(ref cycle) = self.cycle {
189 sql.push_str(" CYCLE ");
190 sql.push_str(&cycle.columns.join(", "));
191 sql.push_str(" SET ");
192 sql.push_str(&cycle.set_column);
193 if let (Some(mark), Some(default)) = (&cycle.mark_value, &cycle.default_value) {
194 sql.push_str(" TO ");
195 sql.push_str(mark);
196 sql.push_str(" DEFAULT ");
197 sql.push_str(default);
198 }
199 sql.push_str(" USING ");
200 sql.push_str(&cycle.using_column);
201 }
202 }
203
204 sql
205 }
206}
207
208#[derive(Debug, Clone)]
210pub struct CteBuilder {
211 name: String,
212 columns: Vec<String>,
213 query: Option<String>,
214 recursive: bool,
215 materialized: Option<Materialized>,
216 search: Option<SearchClause>,
217 cycle: Option<CycleClause>,
218}
219
220impl CteBuilder {
221 pub fn new(name: impl Into<String>) -> Self {
223 Self {
224 name: name.into(),
225 columns: Vec::new(),
226 query: None,
227 recursive: false,
228 materialized: None,
229 search: None,
230 cycle: None,
231 }
232 }
233
234 pub fn columns<I, S>(mut self, columns: I) -> Self
236 where
237 I: IntoIterator<Item = S>,
238 S: Into<String>,
239 {
240 self.columns = columns.into_iter().map(Into::into).collect();
241 self
242 }
243
244 pub fn as_query(mut self, query: impl Into<String>) -> Self {
246 self.query = Some(query.into());
247 self
248 }
249
250 pub fn recursive(mut self) -> Self {
252 self.recursive = true;
253 self
254 }
255
256 pub fn materialized(mut self) -> Self {
258 self.materialized = Some(Materialized::Yes);
259 self
260 }
261
262 pub fn not_materialized(mut self) -> Self {
264 self.materialized = Some(Materialized::No);
265 self
266 }
267
268 pub fn search_breadth_first<I, S>(mut self, columns: I, set_column: impl Into<String>) -> Self
270 where
271 I: IntoIterator<Item = S>,
272 S: Into<String>,
273 {
274 self.search = Some(SearchClause {
275 method: SearchMethod::BreadthFirst,
276 columns: columns.into_iter().map(Into::into).collect(),
277 set_column: set_column.into(),
278 });
279 self
280 }
281
282 pub fn search_depth_first<I, S>(mut self, columns: I, set_column: impl Into<String>) -> Self
284 where
285 I: IntoIterator<Item = S>,
286 S: Into<String>,
287 {
288 self.search = Some(SearchClause {
289 method: SearchMethod::DepthFirst,
290 columns: columns.into_iter().map(Into::into).collect(),
291 set_column: set_column.into(),
292 });
293 self
294 }
295
296 pub fn cycle<I, S>(
298 mut self,
299 columns: I,
300 set_column: impl Into<String>,
301 using_column: impl Into<String>,
302 ) -> Self
303 where
304 I: IntoIterator<Item = S>,
305 S: Into<String>,
306 {
307 self.cycle = Some(CycleClause {
308 columns: columns.into_iter().map(Into::into).collect(),
309 set_column: set_column.into(),
310 using_column: using_column.into(),
311 mark_value: None,
312 default_value: None,
313 });
314 self
315 }
316
317 pub fn build(self) -> QueryResult<Cte> {
319 let query = self.query.ok_or_else(|| {
320 QueryError::invalid_input("query", "CTE requires a query (use as_query())")
321 })?;
322
323 Ok(Cte {
324 name: self.name,
325 columns: self.columns,
326 query,
327 recursive: self.recursive,
328 materialized: self.materialized,
329 search: self.search,
330 cycle: self.cycle,
331 })
332 }
333}
334
335#[derive(Debug, Clone, Default, Serialize, Deserialize)]
337pub struct WithClause {
338 pub ctes: Vec<Cte>,
340 pub recursive: bool,
342 pub main_query: Option<String>,
344}
345
346impl WithClause {
347 pub fn new() -> Self {
349 Self::default()
350 }
351
352 pub fn cte(mut self, cte: Cte) -> Self {
354 if cte.recursive {
355 self.recursive = true;
356 }
357 self.ctes.push(cte);
358 self
359 }
360
361 pub fn ctes<I>(mut self, ctes: I) -> Self
363 where
364 I: IntoIterator<Item = Cte>,
365 {
366 for cte in ctes {
367 self = self.cte(cte);
368 }
369 self
370 }
371
372 pub fn main_query(mut self, query: impl Into<String>) -> Self {
374 self.main_query = Some(query.into());
375 self
376 }
377
378 pub fn select(self, columns: impl Into<String>) -> WithQueryBuilder {
380 WithQueryBuilder {
381 with_clause: self,
382 select: columns.into(),
383 from: None,
384 where_clause: None,
385 order_by: None,
386 limit: None,
387 }
388 }
389
390 pub fn to_sql(&self, db_type: DatabaseType) -> QueryResult<String> {
392 if self.ctes.is_empty() {
393 return Err(QueryError::invalid_input(
394 "ctes",
395 "WITH clause requires at least one CTE",
396 ));
397 }
398
399 let mut sql = String::with_capacity(256);
400
401 sql.push_str("WITH ");
402 if self.recursive {
403 sql.push_str("RECURSIVE ");
404 }
405
406 let cte_sqls: Vec<String> = self.ctes.iter().map(|c| c.to_sql(db_type)).collect();
407 sql.push_str(&cte_sqls.join(", "));
408
409 if let Some(ref main) = self.main_query {
410 sql.push(' ');
411 sql.push_str(main);
412 }
413
414 Ok(sql)
415 }
416}
417
418#[derive(Debug, Clone)]
420pub struct WithQueryBuilder {
421 with_clause: WithClause,
422 select: String,
423 from: Option<String>,
424 where_clause: Option<String>,
425 order_by: Option<String>,
426 limit: Option<u64>,
427}
428
429impl WithQueryBuilder {
430 pub fn from(mut self, table: impl Into<String>) -> Self {
432 self.from = Some(table.into());
433 self
434 }
435
436 pub fn where_clause(mut self, condition: impl Into<String>) -> Self {
438 self.where_clause = Some(condition.into());
439 self
440 }
441
442 pub fn order_by(mut self, order: impl Into<String>) -> Self {
444 self.order_by = Some(order.into());
445 self
446 }
447
448 pub fn limit(mut self, limit: u64) -> Self {
450 self.limit = Some(limit);
451 self
452 }
453
454 pub fn build(mut self, db_type: DatabaseType) -> QueryResult<String> {
456 let mut main = format!("SELECT {}", self.select);
458
459 if let Some(from) = self.from {
460 main.push_str(" FROM ");
461 main.push_str(&from);
462 }
463
464 if let Some(where_clause) = self.where_clause {
465 main.push_str(" WHERE ");
466 main.push_str(&where_clause);
467 }
468
469 let has_order_by = self.order_by.is_some();
470 if let Some(order) = self.order_by {
471 main.push_str(" ORDER BY ");
472 main.push_str(&order);
473 }
474
475 if let Some(limit) = self.limit {
476 match db_type {
477 DatabaseType::MSSQL => {
478 if has_order_by {
480 main.push_str(&format!(" OFFSET 0 ROWS FETCH NEXT {} ROWS ONLY", limit));
481 } else {
482 main = main.replacen("SELECT ", &format!("SELECT TOP {} ", limit), 1);
484 }
485 }
486 _ => {
487 main.push_str(&format!(" LIMIT {}", limit));
488 }
489 }
490 }
491
492 self.with_clause.main_query = Some(main);
493 self.with_clause.to_sql(db_type)
494 }
495}
496
497pub mod patterns {
499 use super::*;
500
501 pub fn tree_traversal(
503 cte_name: &str,
504 table: &str,
505 id_col: &str,
506 parent_col: &str,
507 root_condition: &str,
508 ) -> Cte {
509 let base_query = format!(
510 "SELECT {id}, {parent}, 1 AS depth FROM {table} WHERE {root}",
511 id = id_col,
512 parent = parent_col,
513 table = table,
514 root = root_condition
515 );
516
517 let recursive_query = format!(
518 "SELECT t.{id}, t.{parent}, c.depth + 1 FROM {table} t \
519 INNER JOIN {cte} c ON t.{parent} = c.{id}",
520 id = id_col,
521 parent = parent_col,
522 table = table,
523 cte = cte_name
524 );
525
526 Cte::new(cte_name)
527 .columns([id_col, parent_col, "depth"])
528 .as_query(format!("{} UNION ALL {}", base_query, recursive_query))
529 .recursive()
530 }
531
532 pub fn graph_path(
534 cte_name: &str,
535 edges_table: &str,
536 from_col: &str,
537 to_col: &str,
538 start_node: &str,
539 ) -> Cte {
540 let base_query = format!(
541 "SELECT {from_col}, {to_col}, ARRAY[{from_col}] AS path, 1 AS length \
542 FROM {table} WHERE {from_col} = {start}",
543 from_col = from_col,
544 to_col = to_col,
545 table = edges_table,
546 start = start_node
547 );
548
549 let recursive_query = format!(
550 "SELECT e.{from_col}, e.{to_col}, p.path || e.{to_col}, p.length + 1 \
551 FROM {table} e \
552 INNER JOIN {cte} p ON e.{from_col} = p.{to_col} \
553 WHERE NOT e.{to_col} = ANY(p.path)",
554 from_col = from_col,
555 to_col = to_col,
556 table = edges_table,
557 cte = cte_name
558 );
559
560 Cte::new(cte_name)
561 .columns([from_col, to_col, "path", "length"])
562 .as_query(format!("{} UNION ALL {}", base_query, recursive_query))
563 .recursive()
564 }
565
566 pub fn paginated(cte_name: &str, query: &str, order_by: &str) -> Cte {
568 let paginated_query = format!(
569 "SELECT *, ROW_NUMBER() OVER (ORDER BY {}) AS row_num FROM ({})",
570 order_by, query
571 );
572
573 Cte::new(cte_name).as_query(paginated_query)
574 }
575
576 pub fn running_total(
578 cte_name: &str,
579 table: &str,
580 value_col: &str,
581 order_col: &str,
582 partition_col: Option<&str>,
583 ) -> Cte {
584 let partition = partition_col
585 .map(|p| format!("PARTITION BY {} ", p))
586 .unwrap_or_default();
587
588 let query = format!(
589 "SELECT *, SUM({value}) OVER ({partition}ORDER BY {order}) AS running_total \
590 FROM {table}",
591 value = value_col,
592 partition = partition,
593 order = order_col,
594 table = table
595 );
596
597 Cte::new(cte_name).as_query(query)
598 }
599}
600
601pub mod mongodb {
603 use serde::{Deserialize, Serialize};
604 use serde_json::Value as JsonValue;
605
606 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
608 pub struct Lookup {
609 pub from: String,
611 pub local_field: Option<String>,
613 pub foreign_field: Option<String>,
615 pub as_field: String,
617 pub pipeline: Option<Vec<JsonValue>>,
619 pub let_vars: Option<serde_json::Map<String, JsonValue>>,
621 }
622
623 impl Lookup {
624 pub fn simple(
626 from: impl Into<String>,
627 local: impl Into<String>,
628 foreign: impl Into<String>,
629 as_field: impl Into<String>,
630 ) -> Self {
631 Self {
632 from: from.into(),
633 local_field: Some(local.into()),
634 foreign_field: Some(foreign.into()),
635 as_field: as_field.into(),
636 pipeline: None,
637 let_vars: None,
638 }
639 }
640
641 pub fn with_pipeline(
643 from: impl Into<String>,
644 as_field: impl Into<String>,
645 ) -> LookupBuilder {
646 LookupBuilder {
647 from: from.into(),
648 as_field: as_field.into(),
649 pipeline: Vec::new(),
650 let_vars: serde_json::Map::new(),
651 }
652 }
653
654 pub fn to_bson(&self) -> JsonValue {
656 let mut lookup = serde_json::Map::new();
657 lookup.insert("from".to_string(), JsonValue::String(self.from.clone()));
658
659 if let (Some(local), Some(foreign)) = (&self.local_field, &self.foreign_field) {
660 lookup.insert("localField".to_string(), JsonValue::String(local.clone()));
661 lookup.insert(
662 "foreignField".to_string(),
663 JsonValue::String(foreign.clone()),
664 );
665 }
666
667 lookup.insert("as".to_string(), JsonValue::String(self.as_field.clone()));
668
669 if let Some(ref pipeline) = self.pipeline {
670 lookup.insert("pipeline".to_string(), JsonValue::Array(pipeline.clone()));
671 }
672
673 if let Some(ref vars) = self.let_vars {
674 if !vars.is_empty() {
675 lookup.insert("let".to_string(), JsonValue::Object(vars.clone()));
676 }
677 }
678
679 serde_json::json!({ "$lookup": lookup })
680 }
681 }
682
683 #[derive(Debug, Clone)]
685 pub struct LookupBuilder {
686 from: String,
687 as_field: String,
688 pipeline: Vec<JsonValue>,
689 let_vars: serde_json::Map<String, JsonValue>,
690 }
691
692 impl LookupBuilder {
693 pub fn let_var(mut self, name: impl Into<String>, expr: impl Into<String>) -> Self {
695 self.let_vars
696 .insert(name.into(), JsonValue::String(format!("${}", expr.into())));
697 self
698 }
699
700 pub fn match_expr(mut self, expr: JsonValue) -> Self {
702 self.pipeline
703 .push(serde_json::json!({ "$match": { "$expr": expr } }));
704 self
705 }
706
707 pub fn stage(mut self, stage: JsonValue) -> Self {
709 self.pipeline.push(stage);
710 self
711 }
712
713 pub fn project(mut self, fields: JsonValue) -> Self {
715 self.pipeline
716 .push(serde_json::json!({ "$project": fields }));
717 self
718 }
719
720 pub fn limit(mut self, n: u64) -> Self {
722 self.pipeline.push(serde_json::json!({ "$limit": n }));
723 self
724 }
725
726 pub fn sort(mut self, fields: JsonValue) -> Self {
728 self.pipeline.push(serde_json::json!({ "$sort": fields }));
729 self
730 }
731
732 pub fn build(self) -> Lookup {
734 Lookup {
735 from: self.from,
736 local_field: None,
737 foreign_field: None,
738 as_field: self.as_field,
739 pipeline: if self.pipeline.is_empty() {
740 None
741 } else {
742 Some(self.pipeline)
743 },
744 let_vars: if self.let_vars.is_empty() {
745 None
746 } else {
747 Some(self.let_vars)
748 },
749 }
750 }
751 }
752
753 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
755 pub struct GraphLookup {
756 pub from: String,
758 pub start_with: String,
760 pub connect_from_field: String,
762 pub connect_to_field: String,
764 pub as_field: String,
766 pub max_depth: Option<u32>,
768 pub depth_field: Option<String>,
770 pub restrict_search_with_match: Option<JsonValue>,
772 }
773
774 impl GraphLookup {
775 pub fn new(
777 from: impl Into<String>,
778 start_with: impl Into<String>,
779 connect_from: impl Into<String>,
780 connect_to: impl Into<String>,
781 as_field: impl Into<String>,
782 ) -> Self {
783 Self {
784 from: from.into(),
785 start_with: start_with.into(),
786 connect_from_field: connect_from.into(),
787 connect_to_field: connect_to.into(),
788 as_field: as_field.into(),
789 max_depth: None,
790 depth_field: None,
791 restrict_search_with_match: None,
792 }
793 }
794
795 pub fn max_depth(mut self, depth: u32) -> Self {
797 self.max_depth = Some(depth);
798 self
799 }
800
801 pub fn depth_field(mut self, field: impl Into<String>) -> Self {
803 self.depth_field = Some(field.into());
804 self
805 }
806
807 pub fn restrict_search(mut self, filter: JsonValue) -> Self {
809 self.restrict_search_with_match = Some(filter);
810 self
811 }
812
813 pub fn to_bson(&self) -> JsonValue {
815 let mut graph = serde_json::Map::new();
816 graph.insert("from".to_string(), JsonValue::String(self.from.clone()));
817 graph.insert(
818 "startWith".to_string(),
819 JsonValue::String(format!("${}", self.start_with)),
820 );
821 graph.insert(
822 "connectFromField".to_string(),
823 JsonValue::String(self.connect_from_field.clone()),
824 );
825 graph.insert(
826 "connectToField".to_string(),
827 JsonValue::String(self.connect_to_field.clone()),
828 );
829 graph.insert("as".to_string(), JsonValue::String(self.as_field.clone()));
830
831 if let Some(max) = self.max_depth {
832 graph.insert("maxDepth".to_string(), JsonValue::Number(max.into()));
833 }
834
835 if let Some(ref field) = self.depth_field {
836 graph.insert("depthField".to_string(), JsonValue::String(field.clone()));
837 }
838
839 if let Some(ref filter) = self.restrict_search_with_match {
840 graph.insert("restrictSearchWithMatch".to_string(), filter.clone());
841 }
842
843 serde_json::json!({ "$graphLookup": graph })
844 }
845 }
846
847 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
849 pub struct UnionWith {
850 pub coll: String,
852 pub pipeline: Option<Vec<JsonValue>>,
854 }
855
856 impl UnionWith {
857 pub fn collection(coll: impl Into<String>) -> Self {
859 Self {
860 coll: coll.into(),
861 pipeline: None,
862 }
863 }
864
865 pub fn with_pipeline(coll: impl Into<String>, pipeline: Vec<JsonValue>) -> Self {
867 Self {
868 coll: coll.into(),
869 pipeline: Some(pipeline),
870 }
871 }
872
873 pub fn to_bson(&self) -> JsonValue {
875 if let Some(ref pipeline) = self.pipeline {
876 serde_json::json!({
877 "$unionWith": {
878 "coll": self.coll,
879 "pipeline": pipeline
880 }
881 })
882 } else {
883 serde_json::json!({ "$unionWith": self.coll })
884 }
885 }
886 }
887
888 pub fn lookup(from: &str, local: &str, foreign: &str, as_field: &str) -> Lookup {
890 Lookup::simple(from, local, foreign, as_field)
891 }
892
893 pub fn lookup_pipeline(from: &str, as_field: &str) -> LookupBuilder {
895 Lookup::with_pipeline(from, as_field)
896 }
897
898 pub fn graph_lookup(
900 from: &str,
901 start_with: &str,
902 connect_from: &str,
903 connect_to: &str,
904 as_field: &str,
905 ) -> GraphLookup {
906 GraphLookup::new(from, start_with, connect_from, connect_to, as_field)
907 }
908}
909
#[cfg(test)]
mod tests {
    use super::*;

    // A bare CTE renders its name and body.
    #[test]
    fn test_simple_cte() {
        let rendered = Cte::new("active_users")
            .as_query("SELECT * FROM users WHERE active = true")
            .to_sql(DatabaseType::PostgreSQL);

        for fragment in ["active_users AS", "SELECT * FROM users"] {
            assert!(rendered.contains(fragment));
        }
    }

    // An explicit column list is rendered between the name and `AS`.
    #[test]
    fn test_cte_with_columns() {
        let rendered = Cte::new("user_stats")
            .columns(["id", "name", "total"])
            .as_query("SELECT id, name, COUNT(*) FROM orders GROUP BY user_id")
            .to_sql(DatabaseType::PostgreSQL);

        assert!(rendered.contains("user_stats (id, name, total) AS"));
    }

    // `recursive()` sets the flag on the CTE.
    #[test]
    fn test_recursive_cte() {
        let body = "SELECT id, name, manager_id, 1 FROM employees WHERE manager_id IS NULL \
                    UNION ALL \
                    SELECT e.id, e.name, e.manager_id, s.depth + 1 \
                    FROM employees e JOIN subordinates s ON e.manager_id = s.id";

        let subordinates = Cte::new("subordinates")
            .columns(["id", "name", "manager_id", "depth"])
            .as_query(body)
            .recursive();

        assert!(subordinates.recursive);
    }

    // The materialization hint shows up for PostgreSQL.
    #[test]
    fn test_materialized_cte() {
        let rendered = Cte::new("expensive_query")
            .as_query("SELECT * FROM big_table WHERE complex_condition")
            .materialized(Materialized::Yes)
            .to_sql(DatabaseType::PostgreSQL);

        assert!(rendered.contains("MATERIALIZED"));
    }

    // Several CTEs plus a main query render into one statement.
    #[test]
    fn test_with_clause() {
        let first = Cte::new("cte1").as_query("SELECT 1");
        let second = Cte::new("cte2").as_query("SELECT 2");

        let rendered = WithClause::new()
            .cte(first)
            .cte(second)
            .main_query("SELECT * FROM cte1, cte2")
            .to_sql(DatabaseType::PostgreSQL)
            .unwrap();

        assert!(rendered.starts_with("WITH "));
        for fragment in ["cte1 AS", "cte2 AS", "SELECT * FROM cte1, cte2"] {
            assert!(rendered.contains(fragment));
        }
    }

    // A recursive CTE makes the whole clause `WITH RECURSIVE`.
    #[test]
    fn test_recursive_with_clause() {
        let numbers = Cte::new("numbers")
            .as_query("SELECT 1 AS n UNION ALL SELECT n + 1 FROM numbers WHERE n < 10")
            .recursive();

        let rendered = WithClause::new()
            .cte(numbers)
            .main_query("SELECT * FROM numbers")
            .to_sql(DatabaseType::PostgreSQL)
            .unwrap();

        assert!(rendered.starts_with("WITH RECURSIVE"));
    }

    // The query builder assembles every optional part in order.
    #[test]
    fn test_with_query_builder() {
        let active = Cte::new("active").as_query("SELECT * FROM users WHERE active = true");

        let rendered = WithClause::new()
            .cte(active)
            .select("*")
            .from("active")
            .where_clause("role = 'admin'")
            .order_by("name")
            .limit(10)
            .build(DatabaseType::PostgreSQL)
            .unwrap();

        let expected_fragments = [
            "WITH active AS",
            "SELECT *",
            "FROM active",
            "WHERE role = 'admin'",
            "ORDER BY name",
            "LIMIT 10",
        ];
        for fragment in expected_fragments {
            assert!(rendered.contains(fragment));
        }
    }

    // MSSQL with ORDER BY uses OFFSET/FETCH instead of LIMIT.
    #[test]
    fn test_mssql_limit() {
        let data = Cte::new("data").as_query("SELECT * FROM table1");

        let rendered = WithClause::new()
            .cte(data)
            .select("*")
            .from("data")
            .order_by("id")
            .limit(10)
            .build(DatabaseType::MSSQL)
            .unwrap();

        assert!(rendered.contains("OFFSET 0 ROWS FETCH NEXT 10 ROWS ONLY"));
    }

    // The builder carries name, columns and hint through to the CTE.
    #[test]
    fn test_cte_builder() {
        let built = CteBuilder::new("stats")
            .columns(["a", "b"])
            .as_query("SELECT 1, 2")
            .materialized()
            .build()
            .unwrap();

        assert_eq!(built.name, "stats");
        assert_eq!(built.columns, vec!["a", "b"]);
        assert_eq!(built.materialized, Some(Materialized::Yes));
    }

    mod pattern_tests {
        use super::super::patterns::*;

        // The tree pattern is recursive and increments the depth.
        #[test]
        fn test_tree_traversal_pattern() {
            let tree = tree_traversal(
                "org_tree",
                "employees",
                "id",
                "manager_id",
                "manager_id IS NULL",
            );

            assert!(tree.recursive);
            for fragment in ["UNION ALL", "depth + 1"] {
                assert!(tree.query.contains(fragment));
            }
        }

        // The running-total pattern sums per partition.
        #[test]
        fn test_running_total_pattern() {
            let balance = running_total(
                "account_balance",
                "transactions",
                "amount",
                "transaction_date",
                Some("account_id"),
            );

            for fragment in ["SUM(amount)", "PARTITION BY account_id", "running_total"] {
                assert!(balance.query.contains(fragment));
            }
        }
    }

    mod mongodb_tests {
        use super::super::mongodb::*;

        // A simple lookup renders all four keys.
        #[test]
        fn test_simple_lookup() {
            let rendered = Lookup::simple("orders", "user_id", "_id", "user_orders").to_bson();
            let stage = &rendered["$lookup"];

            assert_eq!(stage["from"], "orders");
            assert_eq!(stage["localField"], "user_id");
            assert_eq!(stage["foreignField"], "_id");
            assert_eq!(stage["as"], "user_orders");
        }

        // A pipeline lookup renders `pipeline` and `let`.
        #[test]
        fn test_lookup_with_pipeline() {
            let rendered = Lookup::with_pipeline("inventory", "stock_items")
                .let_var("order_item", "item")
                .match_expr(serde_json::json!({
                    "$eq": ["$sku", "$$order_item"]
                }))
                .project(serde_json::json!({ "inStock": 1 }))
                .build()
                .to_bson();
            let stage = &rendered["$lookup"];

            assert!(stage["pipeline"].is_array());
            assert!(stage["let"].is_object());
        }

        // Optional graph-lookup settings are rendered when set.
        #[test]
        fn test_graph_lookup() {
            let rendered = GraphLookup::new(
                "employees",
                "reportsTo",
                "reportsTo",
                "name",
                "reportingHierarchy",
            )
            .max_depth(5)
            .depth_field("level")
            .to_bson();
            let stage = &rendered["$graphLookup"];

            assert_eq!(stage["from"], "employees");
            assert_eq!(stage["maxDepth"], 5);
            assert_eq!(stage["depthField"], "level");
        }

        // Without a pipeline the short string form is used.
        #[test]
        fn test_union_with() {
            let rendered = UnionWith::collection("archived_orders").to_bson();

            assert_eq!(rendered["$unionWith"], "archived_orders");
        }

        // With a pipeline the object form is used.
        #[test]
        fn test_union_with_pipeline() {
            let stages = vec![serde_json::json!({ "$match": { "year": 2023 } })];
            let rendered = UnionWith::with_pipeline("archive", stages).to_bson();

            assert!(rendered["$unionWith"]["pipeline"].is_array());
        }
    }
}