Skip to main content

prax_query/
cte.rs

1//! Common Table Expressions (CTEs) support.
2//!
3//! This module provides types for building CTEs (WITH clauses) across
4//! different database backends.
5//!
6//! # Supported Features
7//!
8//! | Feature          | PostgreSQL | MySQL | SQLite | MSSQL | MongoDB        |
9//! |------------------|------------|-------|--------|-------|----------------|
10//! | Non-recursive    | ✅         | ✅    | ✅     | ✅    | ❌ ($lookup)   |
//! | Recursive        | ✅         | ✅    | ✅     | ✅    | ❌ ($graphLookup) |
12//! | Materialized     | ✅         | ❌    | ❌     | ❌    | ❌             |
13//! | Pipeline stages  | ❌         | ❌    | ❌     | ❌    | ✅ $lookup     |
14//!
15//! # Example Usage
16//!
17//! ```rust,ignore
//! use prax_query::cte::{Cte, WithClause};
//! use prax_query::sql::DatabaseType;
//!
//! // Simple CTE
//! let cte = Cte::new("active_users")
//!     .columns(["id", "name", "email"])
//!     .as_query("SELECT * FROM users WHERE active = true");
//!
//! // Build the full query with the CTE. `build` takes the target database
//! // and returns a `QueryResult<String>`.
//! let query = WithClause::new()
//!     .cte(cte)
//!     .select("*")
//!     .from("active_users")
//!     .build(DatabaseType::PostgreSQL)?;
31//! ```
32
33use serde::{Deserialize, Serialize};
34
35use crate::error::{QueryError, QueryResult};
36use crate::sql::DatabaseType;
37
/// A Common Table Expression (CTE) definition.
///
/// [`Cte::to_sql`] renders it as `name [(columns)] AS [hint] (query) [SEARCH ...] [CYCLE ...]`.
/// The materialization hint and the `SEARCH`/`CYCLE` clauses are only emitted for
/// PostgreSQL. Every string field is inserted into the SQL verbatim (no quoting or escaping).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Cte {
    /// Name of the CTE (used in FROM clause).
    pub name: String,
    /// Optional column aliases; no column list is rendered when this is empty.
    pub columns: Vec<String>,
    /// The query that defines the CTE; it is wrapped in parentheses when rendered.
    pub query: String,
    /// Whether this is a recursive CTE. [`WithClause::cte`] marks the whole WITH clause as
    /// recursive when a CTE with this flag set is added.
    pub recursive: bool,
    /// PostgreSQL: MATERIALIZED / NOT MATERIALIZED hint; ignored for other databases.
    pub materialized: Option<Materialized>,
    /// Search clause for recursive CTEs (PostgreSQL); ignored for other databases.
    pub search: Option<SearchClause>,
    /// Cycle detection for recursive CTEs (PostgreSQL); ignored for other databases.
    pub cycle: Option<CycleClause>,
}
56
/// Materialization hint for CTEs (PostgreSQL only).
///
/// Rendered between `AS` and the parenthesised CTE body; other databases ignore it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Materialized {
    /// Force materialization (renders `MATERIALIZED`).
    Yes,
    /// Prevent materialization and inline the CTE (renders `NOT MATERIALIZED`).
    No,
}
65
/// Search clause for recursive CTEs.
///
/// Rendered (PostgreSQL only) after the CTE body as
/// `SEARCH {BREADTH|DEPTH} FIRST BY <columns> SET <set_column>`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SearchClause {
    /// Search method (breadth-first or depth-first).
    pub method: SearchMethod,
    /// Columns to search by, joined with `", "` after `BY`.
    pub columns: Vec<String>,
    /// Column that receives the computed search sequence (rendered after `SET`).
    pub set_column: String,
}
76
/// Search method for recursive CTEs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SearchMethod {
    /// Breadth-first search (renders `BREADTH FIRST BY`).
    BreadthFirst,
    /// Depth-first search (renders `DEPTH FIRST BY`).
    DepthFirst,
}
85
/// Cycle detection for recursive CTEs.
///
/// Rendered (PostgreSQL only) after the CTE body as
/// `CYCLE <columns> SET <set_column> [TO <mark> DEFAULT <default>] USING <using_column>`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CycleClause {
    /// Columns to check for cycles, joined with `", "`.
    pub columns: Vec<String>,
    /// Column that marks whether a cycle was detected (rendered after `SET`).
    pub set_column: String,
    /// Column that stores the visited path (rendered after `USING`).
    pub using_column: String,
    /// Value when a cycle is detected. It is inserted verbatim, and only rendered when
    /// `default_value` is also set.
    pub mark_value: Option<String>,
    /// Value when no cycle is detected. It is inserted verbatim, and only rendered when
    /// `mark_value` is also set.
    pub default_value: Option<String>,
}
100
101impl Cte {
102    /// Create a new CTE with the given name.
103    pub fn new(name: impl Into<String>) -> Self {
104        Self {
105            name: name.into(),
106            columns: Vec::new(),
107            query: String::new(),
108            recursive: false,
109            materialized: None,
110            search: None,
111            cycle: None,
112        }
113    }
114
115    /// Create a new CTE builder.
116    pub fn builder(name: impl Into<String>) -> CteBuilder {
117        CteBuilder::new(name)
118    }
119
120    /// Set the column aliases.
121    pub fn columns<I, S>(mut self, columns: I) -> Self
122    where
123        I: IntoIterator<Item = S>,
124        S: Into<String>,
125    {
126        self.columns = columns.into_iter().map(Into::into).collect();
127        self
128    }
129
130    /// Set the query that defines this CTE.
131    pub fn as_query(mut self, query: impl Into<String>) -> Self {
132        self.query = query.into();
133        self
134    }
135
136    /// Mark this as a recursive CTE.
137    pub fn recursive(mut self) -> Self {
138        self.recursive = true;
139        self
140    }
141
142    /// Set materialization hint (PostgreSQL only).
143    pub fn materialized(mut self, mat: Materialized) -> Self {
144        self.materialized = Some(mat);
145        self
146    }
147
148    /// Generate the CTE definition SQL.
149    pub fn to_sql(&self, db_type: DatabaseType) -> String {
150        let mut sql = self.name.clone();
151
152        // Column aliases
153        if !self.columns.is_empty() {
154            sql.push_str(" (");
155            sql.push_str(&self.columns.join(", "));
156            sql.push(')');
157        }
158
159        sql.push_str(" AS ");
160
161        // Materialization hint (PostgreSQL only)
162        if db_type == DatabaseType::PostgreSQL {
163            if let Some(mat) = self.materialized {
164                match mat {
165                    Materialized::Yes => sql.push_str("MATERIALIZED "),
166                    Materialized::No => sql.push_str("NOT MATERIALIZED "),
167                }
168            }
169        }
170
171        sql.push('(');
172        sql.push_str(&self.query);
173        sql.push(')');
174
175        // Search clause (PostgreSQL only)
176        if db_type == DatabaseType::PostgreSQL {
177            if let Some(ref search) = self.search {
178                sql.push_str(" SEARCH ");
179                sql.push_str(match search.method {
180                    SearchMethod::BreadthFirst => "BREADTH FIRST BY ",
181                    SearchMethod::DepthFirst => "DEPTH FIRST BY ",
182                });
183                sql.push_str(&search.columns.join(", "));
184                sql.push_str(" SET ");
185                sql.push_str(&search.set_column);
186            }
187
188            if let Some(ref cycle) = self.cycle {
189                sql.push_str(" CYCLE ");
190                sql.push_str(&cycle.columns.join(", "));
191                sql.push_str(" SET ");
192                sql.push_str(&cycle.set_column);
193                if let (Some(mark), Some(default)) = (&cycle.mark_value, &cycle.default_value) {
194                    sql.push_str(" TO ");
195                    sql.push_str(mark);
196                    sql.push_str(" DEFAULT ");
197                    sql.push_str(default);
198                }
199                sql.push_str(" USING ");
200                sql.push_str(&cycle.using_column);
201            }
202        }
203
204        sql
205    }
206}
207
/// Builder for CTEs.
///
/// Unlike the chainable setters on [`Cte`], [`CteBuilder::build`] validates the result and
/// returns a `QueryResult`.
#[derive(Debug, Clone)]
pub struct CteBuilder {
    // CTE name, copied into `Cte::name`.
    name: String,
    // Column aliases; empty means no column list.
    columns: Vec<String>,
    // Defining query; `None` until `as_query` is called, and `build` rejects `None`.
    query: Option<String>,
    // Whether the CTE is recursive.
    recursive: bool,
    // PostgreSQL materialization hint.
    materialized: Option<Materialized>,
    // PostgreSQL SEARCH clause.
    search: Option<SearchClause>,
    // PostgreSQL CYCLE clause.
    cycle: Option<CycleClause>,
}
219
220impl CteBuilder {
221    /// Create a new CTE builder.
222    pub fn new(name: impl Into<String>) -> Self {
223        Self {
224            name: name.into(),
225            columns: Vec::new(),
226            query: None,
227            recursive: false,
228            materialized: None,
229            search: None,
230            cycle: None,
231        }
232    }
233
234    /// Set the column aliases.
235    pub fn columns<I, S>(mut self, columns: I) -> Self
236    where
237        I: IntoIterator<Item = S>,
238        S: Into<String>,
239    {
240        self.columns = columns.into_iter().map(Into::into).collect();
241        self
242    }
243
244    /// Set the query that defines this CTE.
245    pub fn as_query(mut self, query: impl Into<String>) -> Self {
246        self.query = Some(query.into());
247        self
248    }
249
250    /// Mark this as a recursive CTE.
251    pub fn recursive(mut self) -> Self {
252        self.recursive = true;
253        self
254    }
255
256    /// Set materialization hint (PostgreSQL only).
257    pub fn materialized(mut self) -> Self {
258        self.materialized = Some(Materialized::Yes);
259        self
260    }
261
262    /// Prevent materialization (PostgreSQL only).
263    pub fn not_materialized(mut self) -> Self {
264        self.materialized = Some(Materialized::No);
265        self
266    }
267
268    /// Add breadth-first search (PostgreSQL only).
269    pub fn search_breadth_first<I, S>(mut self, columns: I, set_column: impl Into<String>) -> Self
270    where
271        I: IntoIterator<Item = S>,
272        S: Into<String>,
273    {
274        self.search = Some(SearchClause {
275            method: SearchMethod::BreadthFirst,
276            columns: columns.into_iter().map(Into::into).collect(),
277            set_column: set_column.into(),
278        });
279        self
280    }
281
282    /// Add depth-first search (PostgreSQL only).
283    pub fn search_depth_first<I, S>(mut self, columns: I, set_column: impl Into<String>) -> Self
284    where
285        I: IntoIterator<Item = S>,
286        S: Into<String>,
287    {
288        self.search = Some(SearchClause {
289            method: SearchMethod::DepthFirst,
290            columns: columns.into_iter().map(Into::into).collect(),
291            set_column: set_column.into(),
292        });
293        self
294    }
295
296    /// Add cycle detection (PostgreSQL only).
297    pub fn cycle<I, S>(
298        mut self,
299        columns: I,
300        set_column: impl Into<String>,
301        using_column: impl Into<String>,
302    ) -> Self
303    where
304        I: IntoIterator<Item = S>,
305        S: Into<String>,
306    {
307        self.cycle = Some(CycleClause {
308            columns: columns.into_iter().map(Into::into).collect(),
309            set_column: set_column.into(),
310            using_column: using_column.into(),
311            mark_value: None,
312            default_value: None,
313        });
314        self
315    }
316
317    /// Build the CTE.
318    pub fn build(self) -> QueryResult<Cte> {
319        let query = self.query.ok_or_else(|| {
320            QueryError::invalid_input("query", "CTE requires a query (use as_query())")
321        })?;
322
323        Ok(Cte {
324            name: self.name,
325            columns: self.columns,
326            query,
327            recursive: self.recursive,
328            materialized: self.materialized,
329            search: self.search,
330            cycle: self.cycle,
331        })
332    }
333}
334
/// A WITH clause containing one or more CTEs, optionally followed by the main query.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WithClause {
    /// The CTEs in this WITH clause, rendered in order and separated by `", "`.
    pub ctes: Vec<Cte>,
    /// Whether any CTE is recursive. Set by [`WithClause::cte`] when a recursive CTE is added.
    pub recursive: bool,
    /// The main query that uses the CTEs, appended after the CTE list.
    pub main_query: Option<String>,
}
345
346impl WithClause {
347    /// Create a new empty WITH clause.
348    pub fn new() -> Self {
349        Self::default()
350    }
351
352    /// Add a CTE to this WITH clause.
353    pub fn cte(mut self, cte: Cte) -> Self {
354        if cte.recursive {
355            self.recursive = true;
356        }
357        self.ctes.push(cte);
358        self
359    }
360
361    /// Add multiple CTEs.
362    pub fn ctes<I>(mut self, ctes: I) -> Self
363    where
364        I: IntoIterator<Item = Cte>,
365    {
366        for cte in ctes {
367            self = self.cte(cte);
368        }
369        self
370    }
371
372    /// Set the main query.
373    pub fn main_query(mut self, query: impl Into<String>) -> Self {
374        self.main_query = Some(query.into());
375        self
376    }
377
378    /// Convenience: SELECT from a CTE.
379    pub fn select(self, columns: impl Into<String>) -> WithQueryBuilder {
380        WithQueryBuilder {
381            with_clause: self,
382            select: columns.into(),
383            from: None,
384            where_clause: None,
385            order_by: None,
386            limit: None,
387        }
388    }
389
390    /// Generate the full SQL.
391    pub fn to_sql(&self, db_type: DatabaseType) -> QueryResult<String> {
392        if self.ctes.is_empty() {
393            return Err(QueryError::invalid_input(
394                "ctes",
395                "WITH clause requires at least one CTE",
396            ));
397        }
398
399        let mut sql = String::with_capacity(256);
400
401        sql.push_str("WITH ");
402        if self.recursive {
403            sql.push_str("RECURSIVE ");
404        }
405
406        let cte_sqls: Vec<String> = self.ctes.iter().map(|c| c.to_sql(db_type)).collect();
407        sql.push_str(&cte_sqls.join(", "));
408
409        if let Some(ref main) = self.main_query {
410            sql.push(' ');
411            sql.push_str(main);
412        }
413
414        Ok(sql)
415    }
416}
417
/// Builder for queries using WITH clause.
///
/// Created by [`WithClause::select`]. All clause strings are inserted into the SQL verbatim.
#[derive(Debug, Clone)]
pub struct WithQueryBuilder {
    // The WITH clause whose `main_query` is filled in by `build`.
    with_clause: WithClause,
    // Select list rendered after `SELECT`.
    select: String,
    // Optional FROM target.
    from: Option<String>,
    // Optional WHERE condition.
    where_clause: Option<String>,
    // Optional ORDER BY list; its presence also decides the SQL Server limit syntax.
    order_by: Option<String>,
    // Optional row limit (LIMIT, TOP, or OFFSET/FETCH depending on the database).
    limit: Option<u64>,
}
428
429impl WithQueryBuilder {
430    /// Set the FROM clause.
431    pub fn from(mut self, table: impl Into<String>) -> Self {
432        self.from = Some(table.into());
433        self
434    }
435
436    /// Set the WHERE clause.
437    pub fn where_clause(mut self, condition: impl Into<String>) -> Self {
438        self.where_clause = Some(condition.into());
439        self
440    }
441
442    /// Set ORDER BY.
443    pub fn order_by(mut self, order: impl Into<String>) -> Self {
444        self.order_by = Some(order.into());
445        self
446    }
447
448    /// Set LIMIT.
449    pub fn limit(mut self, limit: u64) -> Self {
450        self.limit = Some(limit);
451        self
452    }
453
454    /// Build the full SQL query.
455    pub fn build(mut self, db_type: DatabaseType) -> QueryResult<String> {
456        // Build main query
457        let mut main = format!("SELECT {}", self.select);
458
459        if let Some(from) = self.from {
460            main.push_str(" FROM ");
461            main.push_str(&from);
462        }
463
464        if let Some(where_clause) = self.where_clause {
465            main.push_str(" WHERE ");
466            main.push_str(&where_clause);
467        }
468
469        let has_order_by = self.order_by.is_some();
470        if let Some(order) = self.order_by {
471            main.push_str(" ORDER BY ");
472            main.push_str(&order);
473        }
474
475        if let Some(limit) = self.limit {
476            match db_type {
477                DatabaseType::MSSQL => {
478                    // MSSQL uses TOP or OFFSET FETCH
479                    if has_order_by {
480                        main.push_str(&format!(" OFFSET 0 ROWS FETCH NEXT {} ROWS ONLY", limit));
481                    } else {
482                        // Need to inject TOP after SELECT
483                        main = main.replacen("SELECT ", &format!("SELECT TOP {} ", limit), 1);
484                    }
485                }
486                _ => {
487                    main.push_str(&format!(" LIMIT {}", limit));
488                }
489            }
490        }
491
492        self.with_clause.main_query = Some(main);
493        self.with_clause.to_sql(db_type)
494    }
495}
496
497/// Helper functions for common CTE patterns.
498pub mod patterns {
499    use super::*;
500
501    /// Create a recursive CTE for tree traversal (parent-child hierarchy).
502    pub fn tree_traversal(
503        cte_name: &str,
504        table: &str,
505        id_col: &str,
506        parent_col: &str,
507        root_condition: &str,
508    ) -> Cte {
509        let base_query = format!(
510            "SELECT {id}, {parent}, 1 AS depth FROM {table} WHERE {root}",
511            id = id_col,
512            parent = parent_col,
513            table = table,
514            root = root_condition
515        );
516
517        let recursive_query = format!(
518            "SELECT t.{id}, t.{parent}, c.depth + 1 FROM {table} t \
519             INNER JOIN {cte} c ON t.{parent} = c.{id}",
520            id = id_col,
521            parent = parent_col,
522            table = table,
523            cte = cte_name
524        );
525
526        Cte::new(cte_name)
527            .columns([id_col, parent_col, "depth"])
528            .as_query(format!("{} UNION ALL {}", base_query, recursive_query))
529            .recursive()
530    }
531
532    /// Create a recursive CTE for graph path finding.
533    pub fn graph_path(
534        cte_name: &str,
535        edges_table: &str,
536        from_col: &str,
537        to_col: &str,
538        start_node: &str,
539    ) -> Cte {
540        let base_query = format!(
541            "SELECT {from_col}, {to_col}, ARRAY[{from_col}] AS path, 1 AS length \
542             FROM {table} WHERE {from_col} = {start}",
543            from_col = from_col,
544            to_col = to_col,
545            table = edges_table,
546            start = start_node
547        );
548
549        let recursive_query = format!(
550            "SELECT e.{from_col}, e.{to_col}, p.path || e.{to_col}, p.length + 1 \
551             FROM {table} e \
552             INNER JOIN {cte} p ON e.{from_col} = p.{to_col} \
553             WHERE NOT e.{to_col} = ANY(p.path)",
554            from_col = from_col,
555            to_col = to_col,
556            table = edges_table,
557            cte = cte_name
558        );
559
560        Cte::new(cte_name)
561            .columns([from_col, to_col, "path", "length"])
562            .as_query(format!("{} UNION ALL {}", base_query, recursive_query))
563            .recursive()
564    }
565
566    /// Create a CTE for pagination (row numbering).
567    pub fn paginated(cte_name: &str, query: &str, order_by: &str) -> Cte {
568        let paginated_query = format!(
569            "SELECT *, ROW_NUMBER() OVER (ORDER BY {}) AS row_num FROM ({})",
570            order_by, query
571        );
572
573        Cte::new(cte_name).as_query(paginated_query)
574    }
575
576    /// Create a CTE for running totals.
577    pub fn running_total(
578        cte_name: &str,
579        table: &str,
580        value_col: &str,
581        order_col: &str,
582        partition_col: Option<&str>,
583    ) -> Cte {
584        let partition = partition_col
585            .map(|p| format!("PARTITION BY {} ", p))
586            .unwrap_or_default();
587
588        let query = format!(
589            "SELECT *, SUM({value}) OVER ({partition}ORDER BY {order}) AS running_total \
590             FROM {table}",
591            value = value_col,
592            partition = partition,
593            order = order_col,
594            table = table
595        );
596
597        Cte::new(cte_name).as_query(query)
598    }
599}
600
601/// MongoDB $lookup pipeline support (CTE equivalent).
602pub mod mongodb {
603    use serde::{Deserialize, Serialize};
604    use serde_json::Value as JsonValue;
605
606    /// A $lookup stage for MongoDB aggregation pipelines.
607    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
608    pub struct Lookup {
609        /// The foreign collection.
610        pub from: String,
611        /// Local field to match.
612        pub local_field: Option<String>,
613        /// Foreign field to match.
614        pub foreign_field: Option<String>,
615        /// Output array field name.
616        pub as_field: String,
617        /// Pipeline to run on matched documents.
618        pub pipeline: Option<Vec<JsonValue>>,
619        /// Variables to pass to pipeline.
620        pub let_vars: Option<serde_json::Map<String, JsonValue>>,
621    }
622
623    impl Lookup {
624        /// Create a simple $lookup (equality match).
625        pub fn simple(
626            from: impl Into<String>,
627            local: impl Into<String>,
628            foreign: impl Into<String>,
629            as_field: impl Into<String>,
630        ) -> Self {
631            Self {
632                from: from.into(),
633                local_field: Some(local.into()),
634                foreign_field: Some(foreign.into()),
635                as_field: as_field.into(),
636                pipeline: None,
637                let_vars: None,
638            }
639        }
640
641        /// Create a $lookup with pipeline (subquery).
642        pub fn with_pipeline(
643            from: impl Into<String>,
644            as_field: impl Into<String>,
645        ) -> LookupBuilder {
646            LookupBuilder {
647                from: from.into(),
648                as_field: as_field.into(),
649                pipeline: Vec::new(),
650                let_vars: serde_json::Map::new(),
651            }
652        }
653
654        /// Convert to BSON document.
655        pub fn to_bson(&self) -> JsonValue {
656            let mut lookup = serde_json::Map::new();
657            lookup.insert("from".to_string(), JsonValue::String(self.from.clone()));
658
659            if let (Some(local), Some(foreign)) = (&self.local_field, &self.foreign_field) {
660                lookup.insert("localField".to_string(), JsonValue::String(local.clone()));
661                lookup.insert(
662                    "foreignField".to_string(),
663                    JsonValue::String(foreign.clone()),
664                );
665            }
666
667            lookup.insert("as".to_string(), JsonValue::String(self.as_field.clone()));
668
669            if let Some(ref pipeline) = self.pipeline {
670                lookup.insert("pipeline".to_string(), JsonValue::Array(pipeline.clone()));
671            }
672
673            if let Some(ref vars) = self.let_vars {
674                if !vars.is_empty() {
675                    lookup.insert("let".to_string(), JsonValue::Object(vars.clone()));
676                }
677            }
678
679            serde_json::json!({ "$lookup": lookup })
680        }
681    }
682
683    /// Builder for $lookup with pipeline.
684    #[derive(Debug, Clone)]
685    pub struct LookupBuilder {
686        from: String,
687        as_field: String,
688        pipeline: Vec<JsonValue>,
689        let_vars: serde_json::Map<String, JsonValue>,
690    }
691
692    impl LookupBuilder {
693        /// Add a variable for the pipeline.
694        pub fn let_var(mut self, name: impl Into<String>, expr: impl Into<String>) -> Self {
695            self.let_vars
696                .insert(name.into(), JsonValue::String(format!("${}", expr.into())));
697            self
698        }
699
700        /// Add a $match stage to the pipeline.
701        pub fn match_expr(mut self, expr: JsonValue) -> Self {
702            self.pipeline
703                .push(serde_json::json!({ "$match": { "$expr": expr } }));
704            self
705        }
706
707        /// Add a raw stage to the pipeline.
708        pub fn stage(mut self, stage: JsonValue) -> Self {
709            self.pipeline.push(stage);
710            self
711        }
712
713        /// Add a $project stage.
714        pub fn project(mut self, fields: JsonValue) -> Self {
715            self.pipeline
716                .push(serde_json::json!({ "$project": fields }));
717            self
718        }
719
720        /// Add a $limit stage.
721        pub fn limit(mut self, n: u64) -> Self {
722            self.pipeline.push(serde_json::json!({ "$limit": n }));
723            self
724        }
725
726        /// Add a $sort stage.
727        pub fn sort(mut self, fields: JsonValue) -> Self {
728            self.pipeline.push(serde_json::json!({ "$sort": fields }));
729            self
730        }
731
732        /// Build the $lookup.
733        pub fn build(self) -> Lookup {
734            Lookup {
735                from: self.from,
736                local_field: None,
737                foreign_field: None,
738                as_field: self.as_field,
739                pipeline: if self.pipeline.is_empty() {
740                    None
741                } else {
742                    Some(self.pipeline)
743                },
744                let_vars: if self.let_vars.is_empty() {
745                    None
746                } else {
747                    Some(self.let_vars)
748                },
749            }
750        }
751    }
752
753    /// A $graphLookup stage for recursive lookups.
754    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
755    pub struct GraphLookup {
756        /// The collection to search.
757        pub from: String,
758        /// Starting value expression.
759        pub start_with: String,
760        /// Field to connect from.
761        pub connect_from_field: String,
762        /// Field to connect to.
763        pub connect_to_field: String,
764        /// Output array field.
765        pub as_field: String,
766        /// Maximum recursion depth.
767        pub max_depth: Option<u32>,
768        /// Name for depth field.
769        pub depth_field: Option<String>,
770        /// Filter to apply at each level.
771        pub restrict_search_with_match: Option<JsonValue>,
772    }
773
774    impl GraphLookup {
775        /// Create a new $graphLookup.
776        pub fn new(
777            from: impl Into<String>,
778            start_with: impl Into<String>,
779            connect_from: impl Into<String>,
780            connect_to: impl Into<String>,
781            as_field: impl Into<String>,
782        ) -> Self {
783            Self {
784                from: from.into(),
785                start_with: start_with.into(),
786                connect_from_field: connect_from.into(),
787                connect_to_field: connect_to.into(),
788                as_field: as_field.into(),
789                max_depth: None,
790                depth_field: None,
791                restrict_search_with_match: None,
792            }
793        }
794
795        /// Set maximum recursion depth.
796        pub fn max_depth(mut self, depth: u32) -> Self {
797            self.max_depth = Some(depth);
798            self
799        }
800
801        /// Add a depth field to results.
802        pub fn depth_field(mut self, field: impl Into<String>) -> Self {
803            self.depth_field = Some(field.into());
804            self
805        }
806
807        /// Add a filter for each recursion level.
808        pub fn restrict_search(mut self, filter: JsonValue) -> Self {
809            self.restrict_search_with_match = Some(filter);
810            self
811        }
812
813        /// Convert to BSON document.
814        pub fn to_bson(&self) -> JsonValue {
815            let mut graph = serde_json::Map::new();
816            graph.insert("from".to_string(), JsonValue::String(self.from.clone()));
817            graph.insert(
818                "startWith".to_string(),
819                JsonValue::String(format!("${}", self.start_with)),
820            );
821            graph.insert(
822                "connectFromField".to_string(),
823                JsonValue::String(self.connect_from_field.clone()),
824            );
825            graph.insert(
826                "connectToField".to_string(),
827                JsonValue::String(self.connect_to_field.clone()),
828            );
829            graph.insert("as".to_string(), JsonValue::String(self.as_field.clone()));
830
831            if let Some(max) = self.max_depth {
832                graph.insert("maxDepth".to_string(), JsonValue::Number(max.into()));
833            }
834
835            if let Some(ref field) = self.depth_field {
836                graph.insert("depthField".to_string(), JsonValue::String(field.clone()));
837            }
838
839            if let Some(ref filter) = self.restrict_search_with_match {
840                graph.insert("restrictSearchWithMatch".to_string(), filter.clone());
841            }
842
843            serde_json::json!({ "$graphLookup": graph })
844        }
845    }
846
847    /// A $unionWith stage (similar to UNION ALL).
848    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
849    pub struct UnionWith {
850        /// Collection to union with.
851        pub coll: String,
852        /// Optional pipeline to apply before union.
853        pub pipeline: Option<Vec<JsonValue>>,
854    }
855
856    impl UnionWith {
857        /// Create a simple union with a collection.
858        pub fn collection(coll: impl Into<String>) -> Self {
859            Self {
860                coll: coll.into(),
861                pipeline: None,
862            }
863        }
864
865        /// Create a union with a pipeline.
866        pub fn with_pipeline(coll: impl Into<String>, pipeline: Vec<JsonValue>) -> Self {
867            Self {
868                coll: coll.into(),
869                pipeline: Some(pipeline),
870            }
871        }
872
873        /// Convert to BSON document.
874        pub fn to_bson(&self) -> JsonValue {
875            if let Some(ref pipeline) = self.pipeline {
876                serde_json::json!({
877                    "$unionWith": {
878                        "coll": self.coll,
879                        "pipeline": pipeline
880                    }
881                })
882            } else {
883                serde_json::json!({ "$unionWith": self.coll })
884            }
885        }
886    }
887
888    /// Helper to create a simple lookup.
889    pub fn lookup(from: &str, local: &str, foreign: &str, as_field: &str) -> Lookup {
890        Lookup::simple(from, local, foreign, as_field)
891    }
892
893    /// Helper to create a lookup with pipeline.
894    pub fn lookup_pipeline(from: &str, as_field: &str) -> LookupBuilder {
895        Lookup::with_pipeline(from, as_field)
896    }
897
898    /// Helper to create a graph lookup.
899    pub fn graph_lookup(
900        from: &str,
901        start_with: &str,
902        connect_from: &str,
903        connect_to: &str,
904        as_field: &str,
905    ) -> GraphLookup {
906        GraphLookup::new(from, start_with, connect_from, connect_to, as_field)
907    }
908}
909
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_cte() {
        let cte = Cte::new("active_users").as_query("SELECT * FROM users WHERE active = true");

        let sql = cte.to_sql(DatabaseType::PostgreSQL);
        assert!(sql.contains("active_users AS"));
        assert!(sql.contains("SELECT * FROM users"));
    }

    #[test]
    fn test_cte_with_columns() {
        // The inner query is kept valid SQL (every non-aggregated column is
        // grouped) so the example is not misleading; only the header is checked.
        let cte = Cte::new("user_stats")
            .columns(["id", "name", "total"])
            .as_query("SELECT user_id, name, COUNT(*) FROM orders GROUP BY user_id, name");

        let sql = cte.to_sql(DatabaseType::PostgreSQL);
        assert!(sql.contains("user_stats (id, name, total) AS"));
    }

    #[test]
    fn test_recursive_cte() {
        let cte = Cte::new("subordinates")
            .columns(["id", "name", "manager_id", "depth"])
            .as_query(
                "SELECT id, name, manager_id, 1 FROM employees WHERE manager_id IS NULL \
                 UNION ALL \
                 SELECT e.id, e.name, e.manager_id, s.depth + 1 \
                 FROM employees e JOIN subordinates s ON e.manager_id = s.id",
            )
            .recursive();

        assert!(cte.recursive);
        assert_eq!(cte.columns.len(), 4);

        // `RECURSIVE` is written on the WITH keyword, so render the whole
        // statement to verify the flag actually reaches the generated SQL.
        let sql = WithClause::new()
            .cte(cte)
            .main_query("SELECT * FROM subordinates")
            .to_sql(DatabaseType::PostgreSQL)
            .unwrap();
        assert!(sql.starts_with("WITH RECURSIVE"));
        assert!(sql.contains("subordinates (id, name, manager_id, depth) AS"));
    }

    #[test]
    fn test_materialized_cte() {
        let cte = Cte::new("expensive_query")
            .as_query("SELECT * FROM big_table WHERE complex_condition")
            .materialized(Materialized::Yes);

        let sql = cte.to_sql(DatabaseType::PostgreSQL);
        assert!(sql.contains("MATERIALIZED"));
        // A bare `contains("MATERIALIZED")` would also match the opposite hint.
        assert!(!sql.contains("NOT MATERIALIZED"));
    }

    #[test]
    fn test_with_clause() {
        let cte1 = Cte::new("cte1").as_query("SELECT 1");
        let cte2 = Cte::new("cte2").as_query("SELECT 2");

        let with = WithClause::new()
            .cte(cte1)
            .cte(cte2)
            .main_query("SELECT * FROM cte1, cte2");

        let sql = with.to_sql(DatabaseType::PostgreSQL).unwrap();
        assert!(sql.starts_with("WITH "));
        assert!(sql.contains("cte1 AS"));
        assert!(sql.contains("cte2 AS"));
        assert!(sql.contains("SELECT * FROM cte1, cte2"));
    }

    #[test]
    fn test_recursive_with_clause() {
        let cte = Cte::new("numbers")
            .as_query("SELECT 1 AS n UNION ALL SELECT n + 1 FROM numbers WHERE n < 10")
            .recursive();

        let with = WithClause::new()
            .cte(cte)
            .main_query("SELECT * FROM numbers");

        let sql = with.to_sql(DatabaseType::PostgreSQL).unwrap();
        assert!(sql.starts_with("WITH RECURSIVE"));
    }

    #[test]
    fn test_with_query_builder() {
        let cte = Cte::new("active").as_query("SELECT * FROM users WHERE active = true");

        let sql = WithClause::new()
            .cte(cte)
            .select("*")
            .from("active")
            .where_clause("role = 'admin'")
            .order_by("name")
            .limit(10)
            .build(DatabaseType::PostgreSQL)
            .unwrap();

        assert!(sql.contains("WITH active AS"));
        assert!(sql.contains("SELECT *"));
        assert!(sql.contains("FROM active"));
        assert!(sql.contains("WHERE role = 'admin'"));
        assert!(sql.contains("ORDER BY name"));
        assert!(sql.contains("LIMIT 10"));
    }

    #[test]
    fn test_mssql_limit() {
        let cte = Cte::new("data").as_query("SELECT * FROM table1");

        let sql = WithClause::new()
            .cte(cte)
            .select("*")
            .from("data")
            .order_by("id")
            .limit(10)
            .build(DatabaseType::MSSQL)
            .unwrap();

        // T-SQL has no LIMIT; row limiting must be expressed as OFFSET/FETCH,
        // which in turn requires the ORDER BY to be kept.
        assert!(sql.contains("ORDER BY id"));
        assert!(sql.contains("OFFSET 0 ROWS FETCH NEXT 10 ROWS ONLY"));
        assert!(!sql.contains("LIMIT"));
    }

    #[test]
    fn test_cte_builder() {
        let cte = CteBuilder::new("stats")
            .columns(["a", "b"])
            .as_query("SELECT 1, 2")
            .materialized()
            .build()
            .unwrap();

        assert_eq!(cte.name, "stats");
        assert_eq!(cte.columns, vec!["a", "b"]);
        assert_eq!(cte.materialized, Some(Materialized::Yes));
    }

    mod pattern_tests {
        use super::super::patterns::*;

        #[test]
        fn test_tree_traversal_pattern() {
            let cte = tree_traversal(
                "org_tree",
                "employees",
                "id",
                "manager_id",
                "manager_id IS NULL",
            );

            assert!(cte.recursive);
            assert!(cte.query.contains("UNION ALL"));
            assert!(cte.query.contains("depth + 1"));
        }

        #[test]
        fn test_running_total_pattern() {
            let cte = running_total(
                "account_balance",
                "transactions",
                "amount",
                "transaction_date",
                Some("account_id"),
            );

            assert!(cte.query.contains("SUM(amount)"));
            assert!(cte.query.contains("PARTITION BY account_id"));
            assert!(cte.query.contains("running_total"));
        }
    }

    mod mongodb_tests {
        use super::super::mongodb::*;

        #[test]
        fn test_simple_lookup() {
            let lookup = Lookup::simple("orders", "user_id", "_id", "user_orders");
            let bson = lookup.to_bson();

            assert_eq!(bson["$lookup"]["from"], "orders");
            assert_eq!(bson["$lookup"]["localField"], "user_id");
            assert_eq!(bson["$lookup"]["foreignField"], "_id");
            assert_eq!(bson["$lookup"]["as"], "user_orders");
        }

        #[test]
        fn test_lookup_with_pipeline() {
            let lookup = Lookup::with_pipeline("inventory", "stock_items")
                .let_var("order_item", "item")
                .match_expr(serde_json::json!({
                    "$eq": ["$sku", "$$order_item"]
                }))
                .project(serde_json::json!({ "inStock": 1 }))
                .build();

            let bson = lookup.to_bson();
            // The pipeline form must still name the source collection and the
            // output field, just like the simple form.
            assert_eq!(bson["$lookup"]["from"], "inventory");
            assert_eq!(bson["$lookup"]["as"], "stock_items");
            assert!(bson["$lookup"]["pipeline"].is_array());
            assert!(bson["$lookup"]["let"].is_object());
        }

        #[test]
        fn test_graph_lookup() {
            let lookup = GraphLookup::new(
                "employees",
                "reportsTo",
                "reportsTo",
                "name",
                "reportingHierarchy",
            )
            .max_depth(5)
            .depth_field("level");

            let bson = lookup.to_bson();
            assert_eq!(bson["$graphLookup"]["from"], "employees");
            assert_eq!(bson["$graphLookup"]["maxDepth"], 5);
            assert_eq!(bson["$graphLookup"]["depthField"], "level");
        }

        #[test]
        fn test_union_with() {
            let union = UnionWith::collection("archived_orders");
            let bson = union.to_bson();

            assert_eq!(bson["$unionWith"], "archived_orders");
        }

        #[test]
        fn test_union_with_pipeline() {
            let union = UnionWith::with_pipeline(
                "archive",
                vec![serde_json::json!({ "$match": { "year": 2023 } })],
            );
            let bson = union.to_bson();

            assert!(bson["$unionWith"]["pipeline"].is_array());
        }
    }
}