prax_query/
cte.rs

1//! Common Table Expressions (CTEs) support.
2//!
3//! This module provides types for building CTEs (WITH clauses) across
4//! different database backends.
5//!
6//! # Supported Features
7//!
8//! | Feature          | PostgreSQL | MySQL | SQLite | MSSQL | MongoDB        |
9//! |------------------|------------|-------|--------|-------|----------------|
10//! | Non-recursive    | ✅         | ✅    | ✅     | ✅    | ❌ ($lookup)   |
//! | Recursive        | ✅         | ✅    | ✅     | ✅    | ❌ ($graphLookup) |
12//! | Materialized     | ✅         | ❌    | ❌     | ❌    | ❌             |
13//! | Pipeline stages  | ❌         | ❌    | ❌     | ❌    | ✅ $lookup     |
14//!
15//! # Example Usage
16//!
17//! ```rust,ignore
//! use prax_query::cte::{Cte, WithClause};
//! use prax_query::sql::DatabaseType;
//!
//! // Simple CTE
//! let cte = Cte::new("active_users")
//!     .columns(["id", "name", "email"])
//!     .as_query("SELECT * FROM users WHERE active = true");
//!
//! // Build the full query with the CTE. `build` takes the target dialect
//! // and returns a `QueryResult<String>`.
//! let query = WithClause::new()
//!     .cte(cte)
//!     .select("*")
//!     .from("active_users")
//!     .build(DatabaseType::PostgreSQL)?;
31//! ```
32
33use serde::{Deserialize, Serialize};
34
35use crate::error::{QueryError, QueryResult};
36use crate::sql::DatabaseType;
37
38/// A Common Table Expression (CTE) definition.
39#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
40pub struct Cte {
41    /// Name of the CTE (used in FROM clause).
42    pub name: String,
43    /// Optional column aliases.
44    pub columns: Vec<String>,
45    /// The query that defines the CTE.
46    pub query: String,
47    /// Whether this is a recursive CTE.
48    pub recursive: bool,
49    /// PostgreSQL: MATERIALIZED / NOT MATERIALIZED hint.
50    pub materialized: Option<Materialized>,
51    /// Search clause for recursive CTEs (PostgreSQL).
52    pub search: Option<SearchClause>,
53    /// Cycle detection for recursive CTEs (PostgreSQL).
54    pub cycle: Option<CycleClause>,
55}
56
57/// Materialization hint for CTEs (PostgreSQL only).
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
59pub enum Materialized {
60    /// Force materialization.
61    Yes,
62    /// Prevent materialization (inline the CTE).
63    No,
64}
65
66/// Search clause for recursive CTEs.
67#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
68pub struct SearchClause {
69    /// Search method.
70    pub method: SearchMethod,
71    /// Columns to search by.
72    pub columns: Vec<String>,
73    /// Column to store the search sequence.
74    pub set_column: String,
75}
76
77/// Search method for recursive CTEs.
78#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
79pub enum SearchMethod {
80    /// Breadth-first search.
81    BreadthFirst,
82    /// Depth-first search.
83    DepthFirst,
84}
85
86/// Cycle detection for recursive CTEs.
87#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
88pub struct CycleClause {
89    /// Columns to check for cycles.
90    pub columns: Vec<String>,
91    /// Column to mark cycle detection.
92    pub set_column: String,
93    /// Column to store the path.
94    pub using_column: String,
95    /// Value when cycle is detected.
96    pub mark_value: Option<String>,
97    /// Value when no cycle.
98    pub default_value: Option<String>,
99}
100
101impl Cte {
102    /// Create a new CTE with the given name.
103    pub fn new(name: impl Into<String>) -> Self {
104        Self {
105            name: name.into(),
106            columns: Vec::new(),
107            query: String::new(),
108            recursive: false,
109            materialized: None,
110            search: None,
111            cycle: None,
112        }
113    }
114
115    /// Create a new CTE builder.
116    pub fn builder(name: impl Into<String>) -> CteBuilder {
117        CteBuilder::new(name)
118    }
119
120    /// Set the column aliases.
121    pub fn columns<I, S>(mut self, columns: I) -> Self
122    where
123        I: IntoIterator<Item = S>,
124        S: Into<String>,
125    {
126        self.columns = columns.into_iter().map(Into::into).collect();
127        self
128    }
129
130    /// Set the query that defines this CTE.
131    pub fn as_query(mut self, query: impl Into<String>) -> Self {
132        self.query = query.into();
133        self
134    }
135
136    /// Mark this as a recursive CTE.
137    pub fn recursive(mut self) -> Self {
138        self.recursive = true;
139        self
140    }
141
142    /// Set materialization hint (PostgreSQL only).
143    pub fn materialized(mut self, mat: Materialized) -> Self {
144        self.materialized = Some(mat);
145        self
146    }
147
148    /// Generate the CTE definition SQL.
149    pub fn to_sql(&self, db_type: DatabaseType) -> String {
150        let mut sql = self.name.clone();
151
152        // Column aliases
153        if !self.columns.is_empty() {
154            sql.push_str(" (");
155            sql.push_str(&self.columns.join(", "));
156            sql.push(')');
157        }
158
159        sql.push_str(" AS ");
160
161        // Materialization hint (PostgreSQL only)
162        if db_type == DatabaseType::PostgreSQL {
163            if let Some(mat) = self.materialized {
164                match mat {
165                    Materialized::Yes => sql.push_str("MATERIALIZED "),
166                    Materialized::No => sql.push_str("NOT MATERIALIZED "),
167                }
168            }
169        }
170
171        sql.push('(');
172        sql.push_str(&self.query);
173        sql.push(')');
174
175        // Search clause (PostgreSQL only)
176        if db_type == DatabaseType::PostgreSQL {
177            if let Some(ref search) = self.search {
178                sql.push_str(" SEARCH ");
179                sql.push_str(match search.method {
180                    SearchMethod::BreadthFirst => "BREADTH FIRST BY ",
181                    SearchMethod::DepthFirst => "DEPTH FIRST BY ",
182                });
183                sql.push_str(&search.columns.join(", "));
184                sql.push_str(" SET ");
185                sql.push_str(&search.set_column);
186            }
187
188            if let Some(ref cycle) = self.cycle {
189                sql.push_str(" CYCLE ");
190                sql.push_str(&cycle.columns.join(", "));
191                sql.push_str(" SET ");
192                sql.push_str(&cycle.set_column);
193                if let (Some(mark), Some(default)) = (&cycle.mark_value, &cycle.default_value) {
194                    sql.push_str(" TO ");
195                    sql.push_str(mark);
196                    sql.push_str(" DEFAULT ");
197                    sql.push_str(default);
198                }
199                sql.push_str(" USING ");
200                sql.push_str(&cycle.using_column);
201            }
202        }
203
204        sql
205    }
206}
207
208/// Builder for CTEs.
209#[derive(Debug, Clone)]
210pub struct CteBuilder {
211    name: String,
212    columns: Vec<String>,
213    query: Option<String>,
214    recursive: bool,
215    materialized: Option<Materialized>,
216    search: Option<SearchClause>,
217    cycle: Option<CycleClause>,
218}
219
220impl CteBuilder {
221    /// Create a new CTE builder.
222    pub fn new(name: impl Into<String>) -> Self {
223        Self {
224            name: name.into(),
225            columns: Vec::new(),
226            query: None,
227            recursive: false,
228            materialized: None,
229            search: None,
230            cycle: None,
231        }
232    }
233
234    /// Set the column aliases.
235    pub fn columns<I, S>(mut self, columns: I) -> Self
236    where
237        I: IntoIterator<Item = S>,
238        S: Into<String>,
239    {
240        self.columns = columns.into_iter().map(Into::into).collect();
241        self
242    }
243
244    /// Set the query that defines this CTE.
245    pub fn as_query(mut self, query: impl Into<String>) -> Self {
246        self.query = Some(query.into());
247        self
248    }
249
250    /// Mark this as a recursive CTE.
251    pub fn recursive(mut self) -> Self {
252        self.recursive = true;
253        self
254    }
255
256    /// Set materialization hint (PostgreSQL only).
257    pub fn materialized(mut self) -> Self {
258        self.materialized = Some(Materialized::Yes);
259        self
260    }
261
262    /// Prevent materialization (PostgreSQL only).
263    pub fn not_materialized(mut self) -> Self {
264        self.materialized = Some(Materialized::No);
265        self
266    }
267
268    /// Add breadth-first search (PostgreSQL only).
269    pub fn search_breadth_first<I, S>(mut self, columns: I, set_column: impl Into<String>) -> Self
270    where
271        I: IntoIterator<Item = S>,
272        S: Into<String>,
273    {
274        self.search = Some(SearchClause {
275            method: SearchMethod::BreadthFirst,
276            columns: columns.into_iter().map(Into::into).collect(),
277            set_column: set_column.into(),
278        });
279        self
280    }
281
282    /// Add depth-first search (PostgreSQL only).
283    pub fn search_depth_first<I, S>(mut self, columns: I, set_column: impl Into<String>) -> Self
284    where
285        I: IntoIterator<Item = S>,
286        S: Into<String>,
287    {
288        self.search = Some(SearchClause {
289            method: SearchMethod::DepthFirst,
290            columns: columns.into_iter().map(Into::into).collect(),
291            set_column: set_column.into(),
292        });
293        self
294    }
295
296    /// Add cycle detection (PostgreSQL only).
297    pub fn cycle<I, S>(
298        mut self,
299        columns: I,
300        set_column: impl Into<String>,
301        using_column: impl Into<String>,
302    ) -> Self
303    where
304        I: IntoIterator<Item = S>,
305        S: Into<String>,
306    {
307        self.cycle = Some(CycleClause {
308            columns: columns.into_iter().map(Into::into).collect(),
309            set_column: set_column.into(),
310            using_column: using_column.into(),
311            mark_value: None,
312            default_value: None,
313        });
314        self
315    }
316
317    /// Build the CTE.
318    pub fn build(self) -> QueryResult<Cte> {
319        let query = self.query.ok_or_else(|| {
320            QueryError::invalid_input("query", "CTE requires a query (use as_query())")
321        })?;
322
323        Ok(Cte {
324            name: self.name,
325            columns: self.columns,
326            query,
327            recursive: self.recursive,
328            materialized: self.materialized,
329            search: self.search,
330            cycle: self.cycle,
331        })
332    }
333}
334
335/// A WITH clause containing one or more CTEs.
336#[derive(Debug, Clone, Default, Serialize, Deserialize)]
337pub struct WithClause {
338    /// The CTEs in this WITH clause.
339    pub ctes: Vec<Cte>,
340    /// Whether any CTE is recursive.
341    pub recursive: bool,
342    /// The main query that uses the CTEs.
343    pub main_query: Option<String>,
344}
345
346impl WithClause {
347    /// Create a new empty WITH clause.
348    pub fn new() -> Self {
349        Self::default()
350    }
351
352    /// Add a CTE to this WITH clause.
353    pub fn cte(mut self, cte: Cte) -> Self {
354        if cte.recursive {
355            self.recursive = true;
356        }
357        self.ctes.push(cte);
358        self
359    }
360
361    /// Add multiple CTEs.
362    pub fn ctes<I>(mut self, ctes: I) -> Self
363    where
364        I: IntoIterator<Item = Cte>,
365    {
366        for cte in ctes {
367            self = self.cte(cte);
368        }
369        self
370    }
371
372    /// Set the main query.
373    pub fn main_query(mut self, query: impl Into<String>) -> Self {
374        self.main_query = Some(query.into());
375        self
376    }
377
378    /// Convenience: SELECT from a CTE.
379    pub fn select(self, columns: impl Into<String>) -> WithQueryBuilder {
380        WithQueryBuilder {
381            with_clause: self,
382            select: columns.into(),
383            from: None,
384            where_clause: None,
385            order_by: None,
386            limit: None,
387        }
388    }
389
390    /// Generate the full SQL.
391    pub fn to_sql(&self, db_type: DatabaseType) -> QueryResult<String> {
392        if self.ctes.is_empty() {
393            return Err(QueryError::invalid_input("ctes", "WITH clause requires at least one CTE"));
394        }
395
396        let mut sql = String::with_capacity(256);
397
398        sql.push_str("WITH ");
399        if self.recursive {
400            sql.push_str("RECURSIVE ");
401        }
402
403        let cte_sqls: Vec<String> = self.ctes.iter().map(|c| c.to_sql(db_type)).collect();
404        sql.push_str(&cte_sqls.join(", "));
405
406        if let Some(ref main) = self.main_query {
407            sql.push(' ');
408            sql.push_str(main);
409        }
410
411        Ok(sql)
412    }
413}
414
415/// Builder for queries using WITH clause.
416#[derive(Debug, Clone)]
417pub struct WithQueryBuilder {
418    with_clause: WithClause,
419    select: String,
420    from: Option<String>,
421    where_clause: Option<String>,
422    order_by: Option<String>,
423    limit: Option<u64>,
424}
425
426impl WithQueryBuilder {
427    /// Set the FROM clause.
428    pub fn from(mut self, table: impl Into<String>) -> Self {
429        self.from = Some(table.into());
430        self
431    }
432
433    /// Set the WHERE clause.
434    pub fn where_clause(mut self, condition: impl Into<String>) -> Self {
435        self.where_clause = Some(condition.into());
436        self
437    }
438
439    /// Set ORDER BY.
440    pub fn order_by(mut self, order: impl Into<String>) -> Self {
441        self.order_by = Some(order.into());
442        self
443    }
444
445    /// Set LIMIT.
446    pub fn limit(mut self, limit: u64) -> Self {
447        self.limit = Some(limit);
448        self
449    }
450
451    /// Build the full SQL query.
452    pub fn build(mut self, db_type: DatabaseType) -> QueryResult<String> {
453        // Build main query
454        let mut main = format!("SELECT {}", self.select);
455
456        if let Some(from) = self.from {
457            main.push_str(" FROM ");
458            main.push_str(&from);
459        }
460
461        if let Some(where_clause) = self.where_clause {
462            main.push_str(" WHERE ");
463            main.push_str(&where_clause);
464        }
465
466        let has_order_by = self.order_by.is_some();
467        if let Some(order) = self.order_by {
468            main.push_str(" ORDER BY ");
469            main.push_str(&order);
470        }
471
472        if let Some(limit) = self.limit {
473            match db_type {
474                DatabaseType::MSSQL => {
475                    // MSSQL uses TOP or OFFSET FETCH
476                    if has_order_by {
477                        main.push_str(&format!(" OFFSET 0 ROWS FETCH NEXT {} ROWS ONLY", limit));
478                    } else {
479                        // Need to inject TOP after SELECT
480                        main = main.replacen("SELECT ", &format!("SELECT TOP {} ", limit), 1);
481                    }
482                }
483                _ => {
484                    main.push_str(&format!(" LIMIT {}", limit));
485                }
486            }
487        }
488
489        self.with_clause.main_query = Some(main);
490        self.with_clause.to_sql(db_type)
491    }
492}
493
494/// Helper functions for common CTE patterns.
495pub mod patterns {
496    use super::*;
497
498    /// Create a recursive CTE for tree traversal (parent-child hierarchy).
499    pub fn tree_traversal(
500        cte_name: &str,
501        table: &str,
502        id_col: &str,
503        parent_col: &str,
504        root_condition: &str,
505    ) -> Cte {
506        let base_query = format!(
507            "SELECT {id}, {parent}, 1 AS depth FROM {table} WHERE {root}",
508            id = id_col,
509            parent = parent_col,
510            table = table,
511            root = root_condition
512        );
513
514        let recursive_query = format!(
515            "SELECT t.{id}, t.{parent}, c.depth + 1 FROM {table} t \
516             INNER JOIN {cte} c ON t.{parent} = c.{id}",
517            id = id_col,
518            parent = parent_col,
519            table = table,
520            cte = cte_name
521        );
522
523        Cte::new(cte_name)
524            .columns([id_col, parent_col, "depth"])
525            .as_query(format!("{} UNION ALL {}", base_query, recursive_query))
526            .recursive()
527    }
528
529    /// Create a recursive CTE for graph path finding.
530    pub fn graph_path(
531        cte_name: &str,
532        edges_table: &str,
533        from_col: &str,
534        to_col: &str,
535        start_node: &str,
536    ) -> Cte {
537        let base_query = format!(
538            "SELECT {from_col}, {to_col}, ARRAY[{from_col}] AS path, 1 AS length \
539             FROM {table} WHERE {from_col} = {start}",
540            from_col = from_col,
541            to_col = to_col,
542            table = edges_table,
543            start = start_node
544        );
545
546        let recursive_query = format!(
547            "SELECT e.{from_col}, e.{to_col}, p.path || e.{to_col}, p.length + 1 \
548             FROM {table} e \
549             INNER JOIN {cte} p ON e.{from_col} = p.{to_col} \
550             WHERE NOT e.{to_col} = ANY(p.path)",
551            from_col = from_col,
552            to_col = to_col,
553            table = edges_table,
554            cte = cte_name
555        );
556
557        Cte::new(cte_name)
558            .columns([from_col, to_col, "path", "length"])
559            .as_query(format!("{} UNION ALL {}", base_query, recursive_query))
560            .recursive()
561    }
562
563    /// Create a CTE for pagination (row numbering).
564    pub fn paginated(
565        cte_name: &str,
566        query: &str,
567        order_by: &str,
568    ) -> Cte {
569        let paginated_query = format!(
570            "SELECT *, ROW_NUMBER() OVER (ORDER BY {}) AS row_num FROM ({})",
571            order_by, query
572        );
573
574        Cte::new(cte_name).as_query(paginated_query)
575    }
576
577    /// Create a CTE for running totals.
578    pub fn running_total(
579        cte_name: &str,
580        table: &str,
581        value_col: &str,
582        order_col: &str,
583        partition_col: Option<&str>,
584    ) -> Cte {
585        let partition = partition_col
586            .map(|p| format!("PARTITION BY {} ", p))
587            .unwrap_or_default();
588
589        let query = format!(
590            "SELECT *, SUM({value}) OVER ({partition}ORDER BY {order}) AS running_total \
591             FROM {table}",
592            value = value_col,
593            partition = partition,
594            order = order_col,
595            table = table
596        );
597
598        Cte::new(cte_name).as_query(query)
599    }
600}
601
602/// MongoDB $lookup pipeline support (CTE equivalent).
603pub mod mongodb {
604    use serde::{Deserialize, Serialize};
605    use serde_json::Value as JsonValue;
606
607    /// A $lookup stage for MongoDB aggregation pipelines.
608    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
609    pub struct Lookup {
610        /// The foreign collection.
611        pub from: String,
612        /// Local field to match.
613        pub local_field: Option<String>,
614        /// Foreign field to match.
615        pub foreign_field: Option<String>,
616        /// Output array field name.
617        pub as_field: String,
618        /// Pipeline to run on matched documents.
619        pub pipeline: Option<Vec<JsonValue>>,
620        /// Variables to pass to pipeline.
621        pub let_vars: Option<serde_json::Map<String, JsonValue>>,
622    }
623
624    impl Lookup {
625        /// Create a simple $lookup (equality match).
626        pub fn simple(from: impl Into<String>, local: impl Into<String>, foreign: impl Into<String>, as_field: impl Into<String>) -> Self {
627            Self {
628                from: from.into(),
629                local_field: Some(local.into()),
630                foreign_field: Some(foreign.into()),
631                as_field: as_field.into(),
632                pipeline: None,
633                let_vars: None,
634            }
635        }
636
637        /// Create a $lookup with pipeline (subquery).
638        pub fn with_pipeline(from: impl Into<String>, as_field: impl Into<String>) -> LookupBuilder {
639            LookupBuilder {
640                from: from.into(),
641                as_field: as_field.into(),
642                pipeline: Vec::new(),
643                let_vars: serde_json::Map::new(),
644            }
645        }
646
647        /// Convert to BSON document.
648        pub fn to_bson(&self) -> JsonValue {
649            let mut lookup = serde_json::Map::new();
650            lookup.insert("from".to_string(), JsonValue::String(self.from.clone()));
651
652            if let (Some(local), Some(foreign)) = (&self.local_field, &self.foreign_field) {
653                lookup.insert("localField".to_string(), JsonValue::String(local.clone()));
654                lookup.insert("foreignField".to_string(), JsonValue::String(foreign.clone()));
655            }
656
657            lookup.insert("as".to_string(), JsonValue::String(self.as_field.clone()));
658
659            if let Some(ref pipeline) = self.pipeline {
660                lookup.insert("pipeline".to_string(), JsonValue::Array(pipeline.clone()));
661            }
662
663            if let Some(ref vars) = self.let_vars {
664                if !vars.is_empty() {
665                    lookup.insert("let".to_string(), JsonValue::Object(vars.clone()));
666                }
667            }
668
669            serde_json::json!({ "$lookup": lookup })
670        }
671    }
672
673    /// Builder for $lookup with pipeline.
674    #[derive(Debug, Clone)]
675    pub struct LookupBuilder {
676        from: String,
677        as_field: String,
678        pipeline: Vec<JsonValue>,
679        let_vars: serde_json::Map<String, JsonValue>,
680    }
681
682    impl LookupBuilder {
683        /// Add a variable for the pipeline.
684        pub fn let_var(mut self, name: impl Into<String>, expr: impl Into<String>) -> Self {
685            self.let_vars.insert(
686                name.into(),
687                JsonValue::String(format!("${}", expr.into())),
688            );
689            self
690        }
691
692        /// Add a $match stage to the pipeline.
693        pub fn match_expr(mut self, expr: JsonValue) -> Self {
694            self.pipeline.push(serde_json::json!({ "$match": { "$expr": expr } }));
695            self
696        }
697
698        /// Add a raw stage to the pipeline.
699        pub fn stage(mut self, stage: JsonValue) -> Self {
700            self.pipeline.push(stage);
701            self
702        }
703
704        /// Add a $project stage.
705        pub fn project(mut self, fields: JsonValue) -> Self {
706            self.pipeline.push(serde_json::json!({ "$project": fields }));
707            self
708        }
709
710        /// Add a $limit stage.
711        pub fn limit(mut self, n: u64) -> Self {
712            self.pipeline.push(serde_json::json!({ "$limit": n }));
713            self
714        }
715
716        /// Add a $sort stage.
717        pub fn sort(mut self, fields: JsonValue) -> Self {
718            self.pipeline.push(serde_json::json!({ "$sort": fields }));
719            self
720        }
721
722        /// Build the $lookup.
723        pub fn build(self) -> Lookup {
724            Lookup {
725                from: self.from,
726                local_field: None,
727                foreign_field: None,
728                as_field: self.as_field,
729                pipeline: if self.pipeline.is_empty() {
730                    None
731                } else {
732                    Some(self.pipeline)
733                },
734                let_vars: if self.let_vars.is_empty() {
735                    None
736                } else {
737                    Some(self.let_vars)
738                },
739            }
740        }
741    }
742
743    /// A $graphLookup stage for recursive lookups.
744    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
745    pub struct GraphLookup {
746        /// The collection to search.
747        pub from: String,
748        /// Starting value expression.
749        pub start_with: String,
750        /// Field to connect from.
751        pub connect_from_field: String,
752        /// Field to connect to.
753        pub connect_to_field: String,
754        /// Output array field.
755        pub as_field: String,
756        /// Maximum recursion depth.
757        pub max_depth: Option<u32>,
758        /// Name for depth field.
759        pub depth_field: Option<String>,
760        /// Filter to apply at each level.
761        pub restrict_search_with_match: Option<JsonValue>,
762    }
763
764    impl GraphLookup {
765        /// Create a new $graphLookup.
766        pub fn new(
767            from: impl Into<String>,
768            start_with: impl Into<String>,
769            connect_from: impl Into<String>,
770            connect_to: impl Into<String>,
771            as_field: impl Into<String>,
772        ) -> Self {
773            Self {
774                from: from.into(),
775                start_with: start_with.into(),
776                connect_from_field: connect_from.into(),
777                connect_to_field: connect_to.into(),
778                as_field: as_field.into(),
779                max_depth: None,
780                depth_field: None,
781                restrict_search_with_match: None,
782            }
783        }
784
785        /// Set maximum recursion depth.
786        pub fn max_depth(mut self, depth: u32) -> Self {
787            self.max_depth = Some(depth);
788            self
789        }
790
791        /// Add a depth field to results.
792        pub fn depth_field(mut self, field: impl Into<String>) -> Self {
793            self.depth_field = Some(field.into());
794            self
795        }
796
797        /// Add a filter for each recursion level.
798        pub fn restrict_search(mut self, filter: JsonValue) -> Self {
799            self.restrict_search_with_match = Some(filter);
800            self
801        }
802
803        /// Convert to BSON document.
804        pub fn to_bson(&self) -> JsonValue {
805            let mut graph = serde_json::Map::new();
806            graph.insert("from".to_string(), JsonValue::String(self.from.clone()));
807            graph.insert("startWith".to_string(), JsonValue::String(format!("${}", self.start_with)));
808            graph.insert("connectFromField".to_string(), JsonValue::String(self.connect_from_field.clone()));
809            graph.insert("connectToField".to_string(), JsonValue::String(self.connect_to_field.clone()));
810            graph.insert("as".to_string(), JsonValue::String(self.as_field.clone()));
811
812            if let Some(max) = self.max_depth {
813                graph.insert("maxDepth".to_string(), JsonValue::Number(max.into()));
814            }
815
816            if let Some(ref field) = self.depth_field {
817                graph.insert("depthField".to_string(), JsonValue::String(field.clone()));
818            }
819
820            if let Some(ref filter) = self.restrict_search_with_match {
821                graph.insert("restrictSearchWithMatch".to_string(), filter.clone());
822            }
823
824            serde_json::json!({ "$graphLookup": graph })
825        }
826    }
827
828    /// A $unionWith stage (similar to UNION ALL).
829    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
830    pub struct UnionWith {
831        /// Collection to union with.
832        pub coll: String,
833        /// Optional pipeline to apply before union.
834        pub pipeline: Option<Vec<JsonValue>>,
835    }
836
837    impl UnionWith {
838        /// Create a simple union with a collection.
839        pub fn collection(coll: impl Into<String>) -> Self {
840            Self {
841                coll: coll.into(),
842                pipeline: None,
843            }
844        }
845
846        /// Create a union with a pipeline.
847        pub fn with_pipeline(coll: impl Into<String>, pipeline: Vec<JsonValue>) -> Self {
848            Self {
849                coll: coll.into(),
850                pipeline: Some(pipeline),
851            }
852        }
853
854        /// Convert to BSON document.
855        pub fn to_bson(&self) -> JsonValue {
856            if let Some(ref pipeline) = self.pipeline {
857                serde_json::json!({
858                    "$unionWith": {
859                        "coll": self.coll,
860                        "pipeline": pipeline
861                    }
862                })
863            } else {
864                serde_json::json!({ "$unionWith": self.coll })
865            }
866        }
867    }
868
869    /// Helper to create a simple lookup.
870    pub fn lookup(from: &str, local: &str, foreign: &str, as_field: &str) -> Lookup {
871        Lookup::simple(from, local, foreign, as_field)
872    }
873
874    /// Helper to create a lookup with pipeline.
875    pub fn lookup_pipeline(from: &str, as_field: &str) -> LookupBuilder {
876        Lookup::with_pipeline(from, as_field)
877    }
878
879    /// Helper to create a graph lookup.
880    pub fn graph_lookup(
881        from: &str,
882        start_with: &str,
883        connect_from: &str,
884        connect_to: &str,
885        as_field: &str,
886    ) -> GraphLookup {
887        GraphLookup::new(from, start_with, connect_from, connect_to, as_field)
888    }
889}
890
891#[cfg(test)]
892mod tests {
893    use super::*;
894
895    #[test]
896    fn test_simple_cte() {
897        let cte = Cte::new("active_users")
898            .as_query("SELECT * FROM users WHERE active = true");
899
900        let sql = cte.to_sql(DatabaseType::PostgreSQL);
901        assert!(sql.contains("active_users AS"));
902        assert!(sql.contains("SELECT * FROM users"));
903    }
904
905    #[test]
906    fn test_cte_with_columns() {
907        let cte = Cte::new("user_stats")
908            .columns(["id", "name", "total"])
909            .as_query("SELECT id, name, COUNT(*) FROM orders GROUP BY user_id");
910
911        let sql = cte.to_sql(DatabaseType::PostgreSQL);
912        assert!(sql.contains("user_stats (id, name, total) AS"));
913    }
914
915    #[test]
916    fn test_recursive_cte() {
917        let cte = Cte::new("subordinates")
918            .columns(["id", "name", "manager_id", "depth"])
919            .as_query(
920                "SELECT id, name, manager_id, 1 FROM employees WHERE manager_id IS NULL \
921                 UNION ALL \
922                 SELECT e.id, e.name, e.manager_id, s.depth + 1 \
923                 FROM employees e JOIN subordinates s ON e.manager_id = s.id"
924            )
925            .recursive();
926
927        assert!(cte.recursive);
928    }
929
930    #[test]
931    fn test_materialized_cte() {
932        let cte = Cte::new("expensive_query")
933            .as_query("SELECT * FROM big_table WHERE complex_condition")
934            .materialized(Materialized::Yes);
935
936        let sql = cte.to_sql(DatabaseType::PostgreSQL);
937        assert!(sql.contains("MATERIALIZED"));
938    }
939
940    #[test]
941    fn test_with_clause() {
942        let cte1 = Cte::new("cte1").as_query("SELECT 1");
943        let cte2 = Cte::new("cte2").as_query("SELECT 2");
944
945        let with = WithClause::new()
946            .cte(cte1)
947            .cte(cte2)
948            .main_query("SELECT * FROM cte1, cte2");
949
950        let sql = with.to_sql(DatabaseType::PostgreSQL).unwrap();
951        assert!(sql.starts_with("WITH "));
952        assert!(sql.contains("cte1 AS"));
953        assert!(sql.contains("cte2 AS"));
954        assert!(sql.contains("SELECT * FROM cte1, cte2"));
955    }
956
957    #[test]
958    fn test_recursive_with_clause() {
959        let cte = Cte::new("numbers")
960            .as_query("SELECT 1 AS n UNION ALL SELECT n + 1 FROM numbers WHERE n < 10")
961            .recursive();
962
963        let with = WithClause::new()
964            .cte(cte)
965            .main_query("SELECT * FROM numbers");
966
967        let sql = with.to_sql(DatabaseType::PostgreSQL).unwrap();
968        assert!(sql.starts_with("WITH RECURSIVE"));
969    }
970
971    #[test]
972    fn test_with_query_builder() {
973        let cte = Cte::new("active")
974            .as_query("SELECT * FROM users WHERE active = true");
975
976        let sql = WithClause::new()
977            .cte(cte)
978            .select("*")
979            .from("active")
980            .where_clause("role = 'admin'")
981            .order_by("name")
982            .limit(10)
983            .build(DatabaseType::PostgreSQL)
984            .unwrap();
985
986        assert!(sql.contains("WITH active AS"));
987        assert!(sql.contains("SELECT *"));
988        assert!(sql.contains("FROM active"));
989        assert!(sql.contains("WHERE role = 'admin'"));
990        assert!(sql.contains("ORDER BY name"));
991        assert!(sql.contains("LIMIT 10"));
992    }
993
994    #[test]
995    fn test_mssql_limit() {
996        let cte = Cte::new("data").as_query("SELECT * FROM table1");
997
998        let sql = WithClause::new()
999            .cte(cte)
1000            .select("*")
1001            .from("data")
1002            .order_by("id")
1003            .limit(10)
1004            .build(DatabaseType::MSSQL)
1005            .unwrap();
1006
1007        assert!(sql.contains("OFFSET 0 ROWS FETCH NEXT 10 ROWS ONLY"));
1008    }
1009
1010    #[test]
1011    fn test_cte_builder() {
1012        let cte = CteBuilder::new("stats")
1013            .columns(["a", "b"])
1014            .as_query("SELECT 1, 2")
1015            .materialized()
1016            .build()
1017            .unwrap();
1018
1019        assert_eq!(cte.name, "stats");
1020        assert_eq!(cte.columns, vec!["a", "b"]);
1021        assert_eq!(cte.materialized, Some(Materialized::Yes));
1022    }
1023
1024    mod pattern_tests {
1025        use super::super::patterns::*;
1026
1027        #[test]
1028        fn test_tree_traversal_pattern() {
1029            let cte = tree_traversal(
1030                "org_tree",
1031                "employees",
1032                "id",
1033                "manager_id",
1034                "manager_id IS NULL"
1035            );
1036
1037            assert!(cte.recursive);
1038            assert!(cte.query.contains("UNION ALL"));
1039            assert!(cte.query.contains("depth + 1"));
1040        }
1041
1042        #[test]
1043        fn test_running_total_pattern() {
1044            let cte = running_total(
1045                "account_balance",
1046                "transactions",
1047                "amount",
1048                "transaction_date",
1049                Some("account_id")
1050            );
1051
1052            assert!(cte.query.contains("SUM(amount)"));
1053            assert!(cte.query.contains("PARTITION BY account_id"));
1054            assert!(cte.query.contains("running_total"));
1055        }
1056    }
1057
1058    mod mongodb_tests {
1059        use super::super::mongodb::*;
1060
1061        #[test]
1062        fn test_simple_lookup() {
1063            let lookup = Lookup::simple("orders", "user_id", "_id", "user_orders");
1064            let bson = lookup.to_bson();
1065
1066            assert_eq!(bson["$lookup"]["from"], "orders");
1067            assert_eq!(bson["$lookup"]["localField"], "user_id");
1068            assert_eq!(bson["$lookup"]["foreignField"], "_id");
1069            assert_eq!(bson["$lookup"]["as"], "user_orders");
1070        }
1071
1072        #[test]
1073        fn test_lookup_with_pipeline() {
1074            let lookup = Lookup::with_pipeline("inventory", "stock_items")
1075                .let_var("order_item", "item")
1076                .match_expr(serde_json::json!({
1077                    "$eq": ["$sku", "$$order_item"]
1078                }))
1079                .project(serde_json::json!({ "inStock": 1 }))
1080                .build();
1081
1082            let bson = lookup.to_bson();
1083            assert!(bson["$lookup"]["pipeline"].is_array());
1084            assert!(bson["$lookup"]["let"].is_object());
1085        }
1086
1087        #[test]
1088        fn test_graph_lookup() {
1089            let lookup = GraphLookup::new(
1090                "employees",
1091                "reportsTo",
1092                "reportsTo",
1093                "name",
1094                "reportingHierarchy"
1095            )
1096            .max_depth(5)
1097            .depth_field("level");
1098
1099            let bson = lookup.to_bson();
1100            assert_eq!(bson["$graphLookup"]["from"], "employees");
1101            assert_eq!(bson["$graphLookup"]["maxDepth"], 5);
1102            assert_eq!(bson["$graphLookup"]["depthField"], "level");
1103        }
1104
1105        #[test]
1106        fn test_union_with() {
1107            let union = UnionWith::collection("archived_orders");
1108            let bson = union.to_bson();
1109
1110            assert_eq!(bson["$unionWith"], "archived_orders");
1111        }
1112
1113        #[test]
1114        fn test_union_with_pipeline() {
1115            let union = UnionWith::with_pipeline(
1116                "archive",
1117                vec![serde_json::json!({ "$match": { "year": 2023 } })]
1118            );
1119            let bson = union.to_bson();
1120
1121            assert!(bson["$unionWith"]["pipeline"].is_array());
1122        }
1123    }
1124}
1125