Skip to main content

reddb_server/storage/query/executors/
cte.rs

1//! CTE (Common Table Expression) executor.
2//!
3//! Implements the `WITH … AS (…) SELECT …` SQL standard form. CTEs
4//! are materialised as temporary result sets that the main query can
5//! reference by name. Recursive CTEs (`WITH RECURSIVE`) use the
6//! classic iterative-fixpoint shape:
7//!
8//! 1. Execute the non-recursive (base) part once.
9//! 2. Repeat: execute the recursive part with the previous iteration's
10//!    rows visible as the CTE.
11//! 3. Stop when no new rows are produced (or guard limits trip).
12//!
13//! Today only non-recursive CTEs are wired through the runtime —
14//! `inline_ctes` rejects `WITH RECURSIVE` with a clear error. The
15//! iterative-fixpoint code in `CteExecutor` is reachable only via
16//! direct unit tests and is the basis for the future recursive wire-
17//! up (#41 follow-up).
18//!
19//! # Example
20//!
21//! ```ignore
22//! WITH active AS (
23//!     SELECT id, name FROM users WHERE status = 'active'
24//! )
25//! SELECT * FROM active
26//! ```
27
28use std::collections::{HashMap, HashSet};
29
30use super::super::ast::{
31    CteDefinition, Expr, Filter, Projection, QueryExpr, QueryWithCte, SelectItem, TableSource,
32};
33use super::super::unified::{ExecutionError, UnifiedRecord, UnifiedResult};
34use crate::storage::schema::Value;
35
/// Upper bound on the number of fixpoint iterations a single recursive
/// CTE may run. Reaching it while new rows are still being produced is
/// reported as an error instead of looping forever.
const MAX_RECURSION_DEPTH: usize = 1000;

/// Upper bound on the rows accumulated across all iterations of a single
/// recursive CTE. Exceeding it is reported as an error.
const MAX_RECURSIVE_ROWS: usize = 100_000;
41
42/// CTE execution context holding materialized CTE results
43#[derive(Debug, Clone, Default)]
44pub struct CteContext {
45    /// Materialized CTE results by name
46    tables: HashMap<String, UnifiedResult>,
47    /// Track which CTEs are currently being evaluated (for cycle detection)
48    evaluating: HashSet<String>,
49    /// Statistics
50    stats: CteStats,
51}
52
53impl CteContext {
54    /// Create a new CTE context
55    pub fn new() -> Self {
56        Self::default()
57    }
58
59    /// Get a materialized CTE result by name
60    pub fn get(&self, name: &str) -> Option<&UnifiedResult> {
61        self.tables.get(name)
62    }
63
64    /// Store a materialized CTE result
65    pub fn store(&mut self, name: String, result: UnifiedResult) {
66        self.tables.insert(name, result);
67    }
68
69    /// Check if a CTE is being evaluated (for recursion detection)
70    pub fn is_evaluating(&self, name: &str) -> bool {
71        self.evaluating.contains(name)
72    }
73
74    /// Mark a CTE as being evaluated
75    pub fn start_evaluating(&mut self, name: &str) {
76        self.evaluating.insert(name.to_string());
77    }
78
79    /// Mark a CTE as done evaluating
80    pub fn done_evaluating(&mut self, name: &str) {
81        self.evaluating.remove(name);
82    }
83
84    /// Get execution statistics
85    pub fn stats(&self) -> &CteStats {
86        &self.stats
87    }
88}
89
/// Counters describing one CTE execution.
///
/// All fields start at zero (`Default`). The type is comparable by value
/// so tests and callers can assert on a whole snapshot at once.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CteStats {
    /// Number of CTEs that were materialised.
    pub ctes_executed: usize,
    /// Number of fixpoint iterations run for recursive CTEs.
    pub recursive_iterations: usize,
    /// Total number of rows produced by materialised CTEs.
    pub rows_produced: usize,
    /// Wall-clock execution time in microseconds.
    pub exec_time_us: u64,
}
102
/// Drives CTE materialisation on top of a caller-supplied query runner.
///
/// The executor does not know how to run a `QueryExpr` itself; it hands
/// every CTE body and the main query to `execute_fn` together with the
/// current [`CteContext`], so the runner can resolve references to CTEs
/// that have already been materialised.
pub struct CteExecutor<F>
where
    F: Fn(&QueryExpr, &CteContext) -> Result<UnifiedResult, ExecutionError>,
{
    /// Callback that executes one query expression against a CTE context.
    execute_fn: F,
}
111
112impl<F> CteExecutor<F>
113where
114    F: Fn(&QueryExpr, &CteContext) -> Result<UnifiedResult, ExecutionError>,
115{
116    /// Create a new CTE executor
117    pub fn new(execute_fn: F) -> Self {
118        Self { execute_fn }
119    }
120
121    /// Execute a query with CTEs
122    pub fn execute(&self, query: &QueryWithCte) -> Result<UnifiedResult, ExecutionError> {
123        let start = std::time::Instant::now();
124        let mut ctx = CteContext::new();
125
126        // Materialize all CTEs in order
127        if let Some(ref with_clause) = query.with_clause {
128            for cte in &with_clause.ctes {
129                self.materialize_cte(cte, &mut ctx)?;
130            }
131        }
132
133        // Execute the main query with CTE context
134        let result = (self.execute_fn)(&query.query, &ctx)?;
135
136        ctx.stats.exec_time_us = start.elapsed().as_micros() as u64;
137        Ok(result)
138    }
139
140    /// Materialize a single CTE
141    fn materialize_cte(
142        &self,
143        cte: &CteDefinition,
144        ctx: &mut CteContext,
145    ) -> Result<(), ExecutionError> {
146        if ctx.is_evaluating(&cte.name) {
147            return Err(ExecutionError::new(format!(
148                "Circular CTE reference: {}",
149                cte.name
150            )));
151        }
152
153        // Check if already materialized
154        if ctx.get(&cte.name).is_some() {
155            return Ok(());
156        }
157
158        ctx.start_evaluating(&cte.name);
159
160        let result = if cte.recursive {
161            self.execute_recursive_cte(cte, ctx)?
162        } else {
163            // Simple CTE: execute once
164            let result = (self.execute_fn)(&cte.query, ctx)?;
165            self.project_columns(&result, &cte.columns)
166        };
167
168        ctx.stats.ctes_executed += 1;
169        ctx.stats.rows_produced += result.len();
170        ctx.store(cte.name.clone(), result);
171        ctx.done_evaluating(&cte.name);
172
173        Ok(())
174    }
175
176    /// Execute a recursive CTE using iterative fixpoint
177    fn execute_recursive_cte(
178        &self,
179        cte: &CteDefinition,
180        ctx: &mut CteContext,
181    ) -> Result<UnifiedResult, ExecutionError> {
182        // For recursive CTEs, we need to handle UNION ALL structure
183        // The query should be: base_query UNION ALL recursive_query
184        //
185        // Algorithm:
186        // 1. Execute base query -> working_table
187        // 2. result_table = working_table
188        // 3. While working_table not empty:
189        //    a. Execute recursive query with CTE = working_table
190        //    b. new_rows = result - already_seen
191        //    c. working_table = new_rows
192        //    d. result_table += new_rows
193        // 4. Return result_table
194
195        // For simplicity in first implementation, we execute the full query
196        // iteratively, building up the working table
197
198        let mut all_results = UnifiedResult::with_columns(cte.columns.clone());
199        let mut working_table = UnifiedResult::with_columns(cte.columns.clone());
200        let mut seen_rows: HashSet<u64> = HashSet::new();
201        let mut iteration = 0;
202
203        // First iteration: execute the full query (base case)
204        let initial = (self.execute_fn)(&cte.query, ctx)?;
205        let initial = self.project_columns(&initial, &cte.columns);
206
207        for record in &initial.records {
208            let hash = self.hash_record(record);
209            if seen_rows.insert(hash) {
210                working_table.push(record.clone());
211                all_results.push(record.clone());
212            }
213        }
214
215        // Store initial results so recursive references can see them
216        ctx.store(cte.name.clone(), working_table.clone());
217
218        // Iterate until fixpoint
219        while !working_table.is_empty() && iteration < MAX_RECURSION_DEPTH {
220            iteration += 1;
221            ctx.stats.recursive_iterations += 1;
222
223            if all_results.len() > MAX_RECURSIVE_ROWS {
224                return Err(ExecutionError::new(format!(
225                    "Recursive CTE '{}' exceeded maximum rows ({})",
226                    cte.name, MAX_RECURSIVE_ROWS
227                )));
228            }
229
230            // Execute query with current CTE contents
231            let new_results = (self.execute_fn)(&cte.query, ctx)?;
232            let new_results = self.project_columns(&new_results, &cte.columns);
233
234            // Find genuinely new rows
235            let mut new_working_table = UnifiedResult::with_columns(cte.columns.clone());
236            for record in &new_results.records {
237                let hash = self.hash_record(record);
238                if seen_rows.insert(hash) {
239                    new_working_table.push(record.clone());
240                    all_results.push(record.clone());
241                }
242            }
243
244            working_table = new_working_table;
245
246            // Update CTE table for next iteration
247            ctx.store(cte.name.clone(), all_results.clone());
248        }
249
250        if iteration >= MAX_RECURSION_DEPTH && !working_table.is_empty() {
251            return Err(ExecutionError::new(format!(
252                "Recursive CTE '{}' exceeded maximum recursion depth ({})",
253                cte.name, MAX_RECURSION_DEPTH
254            )));
255        }
256
257        Ok(all_results)
258    }
259
260    /// Project result columns according to CTE column list
261    fn project_columns(&self, result: &UnifiedResult, columns: &[String]) -> UnifiedResult {
262        if columns.is_empty() {
263            return result.clone();
264        }
265
266        let mut projected = UnifiedResult::with_columns(columns.to_vec());
267
268        for record in &result.records {
269            let mut new_record = UnifiedRecord::new();
270
271            // Map result columns to CTE columns
272            for (i, col) in columns.iter().enumerate() {
273                // Try to find value by position first, then by name
274                let value = result
275                    .columns
276                    .get(i)
277                    .and_then(|orig_col| record.get(orig_col))
278                    .cloned()
279                    .or_else(|| record.get(col).cloned())
280                    .unwrap_or(Value::Null);
281
282                new_record.set(col, value);
283            }
284
285            projected.push(new_record);
286        }
287
288        projected
289    }
290
291    /// Hash a record for deduplication
292    fn hash_record(&self, record: &UnifiedRecord) -> u64 {
293        use std::collections::hash_map::DefaultHasher;
294        use std::hash::{Hash, Hasher};
295
296        let mut hasher = DefaultHasher::new();
297
298        // Hash all values in deterministic order
299        let mut keys = record.column_names();
300        keys.sort();
301
302        for key in &keys {
303            (**key).hash(&mut hasher);
304            if let Some(value) = record.get(key) {
305                Self::hash_value(value, &mut hasher);
306            }
307        }
308
309        hasher.finish()
310    }
311
312    /// Hash a Value for deduplication
313    fn hash_value(value: &Value, hasher: &mut impl std::hash::Hasher) {
314        use std::hash::Hash;
315
316        match value {
317            Value::Null => 0u8.hash(hasher),
318            Value::Boolean(b) => {
319                1u8.hash(hasher);
320                b.hash(hasher);
321            }
322            Value::Integer(i) => {
323                2u8.hash(hasher);
324                i.hash(hasher);
325            }
326            Value::UnsignedInteger(u) => {
327                3u8.hash(hasher);
328                u.hash(hasher);
329            }
330            Value::Float(f) => {
331                4u8.hash(hasher);
332                f.to_bits().hash(hasher);
333            }
334            Value::Text(s) => {
335                5u8.hash(hasher);
336                s.hash(hasher);
337            }
338            Value::Blob(b) => {
339                6u8.hash(hasher);
340                b.hash(hasher);
341            }
342            Value::Timestamp(t) => {
343                7u8.hash(hasher);
344                t.hash(hasher);
345            }
346            Value::Duration(d) => {
347                8u8.hash(hasher);
348                d.hash(hasher);
349            }
350            Value::IpAddr(addr) => {
351                9u8.hash(hasher);
352                match addr {
353                    std::net::IpAddr::V4(v4) => v4.octets().hash(hasher),
354                    std::net::IpAddr::V6(v6) => v6.octets().hash(hasher),
355                }
356            }
357            Value::MacAddr(mac) => {
358                10u8.hash(hasher);
359                mac.hash(hasher);
360            }
361            Value::Vector(v) => {
362                11u8.hash(hasher);
363                v.len().hash(hasher);
364                for f in v {
365                    f.to_bits().hash(hasher);
366                }
367            }
368            Value::Json(j) => {
369                12u8.hash(hasher);
370                j.hash(hasher);
371            }
372            Value::Uuid(u) => {
373                13u8.hash(hasher);
374                u.hash(hasher);
375            }
376            Value::NodeRef(n) => {
377                14u8.hash(hasher);
378                n.hash(hasher);
379            }
380            Value::EdgeRef(e) => {
381                15u8.hash(hasher);
382                e.hash(hasher);
383            }
384            Value::VectorRef(coll, id) => {
385                16u8.hash(hasher);
386                coll.hash(hasher);
387                id.hash(hasher);
388            }
389            Value::RowRef(table, id) => {
390                17u8.hash(hasher);
391                table.hash(hasher);
392                id.hash(hasher);
393            }
394            Value::Color(rgb) => {
395                18u8.hash(hasher);
396                rgb.hash(hasher);
397            }
398            Value::Email(s) => {
399                19u8.hash(hasher);
400                s.hash(hasher);
401            }
402            Value::Url(s) => {
403                20u8.hash(hasher);
404                s.hash(hasher);
405            }
406            Value::Phone(n) => {
407                21u8.hash(hasher);
408                n.hash(hasher);
409            }
410            Value::Semver(v) => {
411                22u8.hash(hasher);
412                v.hash(hasher);
413            }
414            Value::Cidr(ip, prefix) => {
415                23u8.hash(hasher);
416                ip.hash(hasher);
417                prefix.hash(hasher);
418            }
419            Value::Date(d) => {
420                24u8.hash(hasher);
421                d.hash(hasher);
422            }
423            Value::Time(t) => {
424                25u8.hash(hasher);
425                t.hash(hasher);
426            }
427            Value::Decimal(v) => {
428                26u8.hash(hasher);
429                v.hash(hasher);
430            }
431            Value::EnumValue(i) => {
432                27u8.hash(hasher);
433                i.hash(hasher);
434            }
435            Value::Array(elems) => {
436                28u8.hash(hasher);
437                elems.len().hash(hasher);
438                for elem in elems {
439                    Self::hash_value(elem, hasher);
440                }
441            }
442            Value::TimestampMs(v) => {
443                29u8.hash(hasher);
444                v.hash(hasher);
445            }
446            Value::Ipv4(v) => {
447                30u8.hash(hasher);
448                v.hash(hasher);
449            }
450            Value::Ipv6(bytes) => {
451                31u8.hash(hasher);
452                bytes.hash(hasher);
453            }
454            Value::Subnet(ip, mask) => {
455                32u8.hash(hasher);
456                ip.hash(hasher);
457                mask.hash(hasher);
458            }
459            Value::Port(v) => {
460                33u8.hash(hasher);
461                v.hash(hasher);
462            }
463            Value::Latitude(v) => {
464                34u8.hash(hasher);
465                v.hash(hasher);
466            }
467            Value::Longitude(v) => {
468                35u8.hash(hasher);
469                v.hash(hasher);
470            }
471            Value::GeoPoint(lat, lon) => {
472                36u8.hash(hasher);
473                lat.hash(hasher);
474                lon.hash(hasher);
475            }
476            Value::Country2(c) => {
477                37u8.hash(hasher);
478                c.hash(hasher);
479            }
480            Value::Country3(c) => {
481                38u8.hash(hasher);
482                c.hash(hasher);
483            }
484            Value::Lang2(c) => {
485                39u8.hash(hasher);
486                c.hash(hasher);
487            }
488            Value::Lang5(c) => {
489                40u8.hash(hasher);
490                c.hash(hasher);
491            }
492            Value::Currency(c) => {
493                41u8.hash(hasher);
494                c.hash(hasher);
495            }
496            Value::AssetCode(code) => {
497                50u8.hash(hasher);
498                code.hash(hasher);
499            }
500            Value::Money {
501                asset_code,
502                minor_units,
503                scale,
504            } => {
505                51u8.hash(hasher);
506                asset_code.hash(hasher);
507                minor_units.hash(hasher);
508                scale.hash(hasher);
509            }
510            Value::ColorAlpha(rgba) => {
511                42u8.hash(hasher);
512                rgba.hash(hasher);
513            }
514            Value::BigInt(v) => {
515                43u8.hash(hasher);
516                v.hash(hasher);
517            }
518            Value::KeyRef(col, key) => {
519                44u8.hash(hasher);
520                col.hash(hasher);
521                key.hash(hasher);
522            }
523            Value::DocRef(col, id) => {
524                45u8.hash(hasher);
525                col.hash(hasher);
526                id.hash(hasher);
527            }
528            Value::TableRef(name) => {
529                46u8.hash(hasher);
530                name.hash(hasher);
531            }
532            Value::PageRef(page_id) => {
533                47u8.hash(hasher);
534                page_id.hash(hasher);
535            }
536            Value::Secret(bytes) => {
537                48u8.hash(hasher);
538                bytes.hash(hasher);
539            }
540            Value::Password(hash) => {
541                49u8.hash(hasher);
542                hash.hash(hasher);
543            }
544        }
545    }
546}
547
548/// Helper to parse UNION structure for recursive CTEs
549pub fn split_union_parts(query: &QueryExpr) -> Option<(QueryExpr, QueryExpr)> {
550    // UNION support is not represented in the current AST; recursive queries execute
551    // the full body expression each iteration.
552    let _ = query;
553    None
554}
555
556// ─────────────────────────────────────────────────────────────────────
557// CTE inlining (#41) — non-recursive
558//
559// Rewrites a `QueryWithCte` into a plain `QueryExpr` by walking the
560// AST and substituting every `TableSource::Name(name)` (or legacy
561// `TableQuery.table` field) that matches a CTE name with
562// `TableSource::Subquery(cte.query)`. After this pass the runtime's
563// existing subquery-in-FROM machinery executes the result with no
564// CTE-specific dispatch needed.
565//
566// Recursive CTEs are rejected up-front — the iterative fixpoint
567// strategy is implemented in `CteExecutor` but is not wired into the
568// runtime yet (separate slice).
569// ─────────────────────────────────────────────────────────────────────
570
571/// Inline a `QueryWithCte`'s WITH clause into its inner query. Returns
572/// the rewritten `QueryExpr` ready for dispatch. Recursive CTEs are
573/// rejected with a clear error.
574pub fn inline_ctes(query: QueryWithCte) -> Result<QueryExpr, ExecutionError> {
575    let Some(with_clause) = query.with_clause else {
576        return Ok(query.query);
577    };
578    if with_clause.has_recursive {
579        return Err(ExecutionError::new(
580            "WITH RECURSIVE is not yet supported by the executor; \
581             non-recursive WITH clauses run today, recursive support \
582             is tracked separately"
583                .to_string(),
584        ));
585    }
586
587    // Inline each CTE into its successors first so chained CTEs
588    // (`WITH a AS (...), b AS (... a ...)`) end up with fully resolved
589    // bodies before they're substituted into the outer query.
590    let mut resolved: HashMap<String, QueryExpr> = HashMap::new();
591    for cte in &with_clause.ctes {
592        let mut body = (*cte.query).clone();
593        rewrite(&mut body, &resolved);
594        resolved.insert(cte.name.clone(), body);
595    }
596
597    let mut outer = query.query;
598    rewrite(&mut outer, &resolved);
599    Ok(outer)
600}
601
602/// Walk a `QueryExpr` and replace any table reference whose name
603/// matches a key in `ctes` with the inlined CTE body. Recurses
604/// through `Join` and nested `Subquery` sources so CTE refs inside
605/// JOINs and subqueries resolve too. Mirrors the view-rewrite
606/// convention: when the outer table reference carries filter / limit
607/// / offset constraints we wrap the body in a `Subquery` to preserve
608/// them; otherwise we replace the whole `Table` node verbatim with
609/// the CTE body so dispatchers that key off `QueryExpr::Table` (like
610/// the JOIN executor) see the right shape.
611fn rewrite(expr: &mut QueryExpr, ctes: &HashMap<String, QueryExpr>) {
612    match expr {
613        QueryExpr::Table(tq) => {
614            let lookup_name = match &tq.source {
615                Some(TableSource::Subquery(_)) => None,
616                Some(TableSource::Name(n)) => Some(n.clone()),
617                // Table-valued functions are not CTE references (issue #795);
618                // the inline-graph form likewise references no CTE (issue #799).
619                Some(TableSource::Function { .. } | TableSource::InlineGraphFunction { .. }) => {
620                    None
621                }
622                None => Some(tq.table.clone()),
623            };
624
625            if let Some(name) = lookup_name {
626                if let Some(body) = ctes.get(&name) {
627                    let outer_has_constraints = tq.filter.is_some()
628                        || tq.where_expr.is_some()
629                        || tq.limit.is_some()
630                        || tq.offset.is_some()
631                        || !tq.columns.is_empty()
632                        || !tq.select_items.is_empty()
633                        || !tq.group_by.is_empty()
634                        || !tq.order_by.is_empty();
635
636                    if outer_has_constraints {
637                        // Outer ref carries projections / filters /
638                        // limits — keep those by wrapping the body in
639                        // a subquery source. Sentinel name so legacy
640                        // `table` consumers can't resolve it against
641                        // the real schema.
642                        tq.source = Some(TableSource::Subquery(Box::new(body.clone())));
643                        tq.table = format!("__cte_{name}");
644                    } else {
645                        // Bare `FROM cte` (possibly with alias) —
646                        // replace verbatim so JOIN / dispatch paths
647                        // see the CTE body's natural shape.
648                        *expr = body.clone();
649                    }
650                    return;
651                }
652            }
653
654            match tq.source.as_mut() {
655                Some(TableSource::Subquery(body)) => rewrite(body, ctes),
656                Some(TableSource::InlineGraphFunction { nodes, edges, .. }) => {
657                    rewrite(nodes, ctes);
658                    rewrite(edges, ctes);
659                }
660                _ => {}
661            }
662            rewrite_table_query_parts(tq, ctes);
663        }
664        QueryExpr::Join(jq) => {
665            rewrite(&mut jq.left, ctes);
666            rewrite(&mut jq.right, ctes);
667            if let Some(filter) = jq.filter.as_mut() {
668                rewrite_filter(filter, ctes);
669            }
670            for item in &mut jq.order_by {
671                if let Some(expr) = item.expr.as_mut() {
672                    rewrite_expr(expr, ctes);
673                }
674            }
675            for item in &mut jq.return_items {
676                rewrite_select_item(item, ctes);
677            }
678            for projection in &mut jq.return_ {
679                rewrite_projection(projection, ctes);
680            }
681        }
682        QueryExpr::Graph(gq) => {
683            if let Some(filter) = gq.filter.as_mut() {
684                rewrite_filter(filter, ctes);
685            }
686            for projection in &mut gq.return_ {
687                rewrite_projection(projection, ctes);
688            }
689        }
690        _ => {}
691    }
692}
693
694fn rewrite_table_query_parts(
695    tq: &mut super::super::ast::TableQuery,
696    ctes: &HashMap<String, QueryExpr>,
697) {
698    for item in &mut tq.select_items {
699        rewrite_select_item(item, ctes);
700    }
701    for projection in &mut tq.columns {
702        rewrite_projection(projection, ctes);
703    }
704    if let Some(expr) = tq.where_expr.as_mut() {
705        rewrite_expr(expr, ctes);
706    }
707    if let Some(filter) = tq.filter.as_mut() {
708        rewrite_filter(filter, ctes);
709    }
710    for expr in &mut tq.group_by_exprs {
711        rewrite_expr(expr, ctes);
712    }
713    if let Some(expr) = tq.having_expr.as_mut() {
714        rewrite_expr(expr, ctes);
715    }
716    if let Some(filter) = tq.having.as_mut() {
717        rewrite_filter(filter, ctes);
718    }
719    for item in &mut tq.order_by {
720        if let Some(expr) = item.expr.as_mut() {
721            rewrite_expr(expr, ctes);
722        }
723    }
724}
725
726fn rewrite_select_item(item: &mut SelectItem, ctes: &HashMap<String, QueryExpr>) {
727    if let SelectItem::Expr { expr, .. } = item {
728        rewrite_expr(expr, ctes);
729    }
730}
731
732fn rewrite_projection(projection: &mut Projection, ctes: &HashMap<String, QueryExpr>) {
733    match projection {
734        Projection::Function(_, args) => {
735            for arg in args {
736                rewrite_projection(arg, ctes);
737            }
738        }
739        Projection::Expression(filter, _) => rewrite_filter(filter, ctes),
740        Projection::Window { args, window, .. } => {
741            for arg in args {
742                rewrite_projection(arg, ctes);
743            }
744            for expr in &mut window.partition_by {
745                rewrite_expr(expr, ctes);
746            }
747            for item in &mut window.order_by {
748                rewrite_expr(&mut item.expr, ctes);
749            }
750        }
751        Projection::All
752        | Projection::Column(_)
753        | Projection::Alias(_, _)
754        | Projection::Field(_, _) => {}
755    }
756}
757
758fn rewrite_filter(filter: &mut Filter, ctes: &HashMap<String, QueryExpr>) {
759    match filter {
760        Filter::CompareExpr { lhs, rhs, .. } => {
761            rewrite_expr(lhs, ctes);
762            rewrite_expr(rhs, ctes);
763        }
764        Filter::And(left, right) | Filter::Or(left, right) => {
765            rewrite_filter(left, ctes);
766            rewrite_filter(right, ctes);
767        }
768        Filter::Not(inner) => rewrite_filter(inner, ctes),
769        Filter::Compare { .. }
770        | Filter::CompareFields { .. }
771        | Filter::IsNull(_)
772        | Filter::IsNotNull(_)
773        | Filter::In { .. }
774        | Filter::Between { .. }
775        | Filter::Like { .. }
776        | Filter::StartsWith { .. }
777        | Filter::EndsWith { .. }
778        | Filter::Contains { .. } => {}
779    }
780}
781
782fn rewrite_expr(expr: &mut Expr, ctes: &HashMap<String, QueryExpr>) {
783    match expr {
784        Expr::BinaryOp { lhs, rhs, .. } => {
785            rewrite_expr(lhs, ctes);
786            rewrite_expr(rhs, ctes);
787        }
788        Expr::UnaryOp { operand, .. } => rewrite_expr(operand, ctes),
789        Expr::Cast { inner, .. } => rewrite_expr(inner, ctes),
790        Expr::FunctionCall { args, .. } => {
791            for arg in args {
792                rewrite_expr(arg, ctes);
793            }
794        }
795        Expr::Case {
796            branches, else_, ..
797        } => {
798            for (condition, value) in branches {
799                rewrite_expr(condition, ctes);
800                rewrite_expr(value, ctes);
801            }
802            if let Some(value) = else_ {
803                rewrite_expr(value, ctes);
804            }
805        }
806        Expr::IsNull { operand, .. } => rewrite_expr(operand, ctes),
807        Expr::InList { target, values, .. } => {
808            rewrite_expr(target, ctes);
809            for value in values {
810                rewrite_expr(value, ctes);
811            }
812        }
813        Expr::Between {
814            target, low, high, ..
815        } => {
816            rewrite_expr(target, ctes);
817            rewrite_expr(low, ctes);
818            rewrite_expr(high, ctes);
819        }
820        Expr::Subquery { query, .. } => rewrite(&mut query.query, ctes),
821        Expr::WindowFunctionCall { args, window, .. } => {
822            for arg in args {
823                rewrite_expr(arg, ctes);
824            }
825            for expr in &mut window.partition_by {
826                rewrite_expr(expr, ctes);
827            }
828            for item in &mut window.order_by {
829                rewrite_expr(&mut item.expr, ctes);
830            }
831        }
832        Expr::Literal { .. } | Expr::Column { .. } | Expr::Parameter { .. } => {}
833    }
834}
835
836// ============================================================================
837// Tests
838// ============================================================================
839
840#[cfg(test)]
841mod tests {
842    use super::*;
843    use crate::storage::query::ast::CteQueryBuilder;
844    use crate::storage::query::WithClause;
845
846    fn mock_execute(
847        _query: &QueryExpr,
848        _ctx: &CteContext,
849    ) -> Result<UnifiedResult, ExecutionError> {
850        // Simple mock that returns empty result
851        Ok(UnifiedResult::empty())
852    }
853
854    #[test]
855    fn test_cte_context() {
856        let mut ctx = CteContext::new();
857
858        // Test empty context
859        assert!(ctx.get("test").is_none());
860        assert!(!ctx.is_evaluating("test"));
861
862        // Test storing results
863        let result = UnifiedResult::with_columns(vec!["col1".to_string()]);
864        ctx.store("test".to_string(), result);
865        assert!(ctx.get("test").is_some());
866
867        // Test evaluation tracking
868        ctx.start_evaluating("other");
869        assert!(ctx.is_evaluating("other"));
870        ctx.done_evaluating("other");
871        assert!(!ctx.is_evaluating("other"));
872    }
873
874    #[test]
875    fn test_simple_cte_execution() {
876        let executor = CteExecutor::new(|_query, _ctx| {
877            let mut result = UnifiedResult::with_columns(vec!["id".to_string()]);
878            let mut record = UnifiedRecord::new();
879            record.set("id", Value::Integer(1));
880            result.push(record);
881            Ok(result)
882        });
883
884        // Create a simple CTE query
885        let cte = CteDefinition {
886            name: "test_cte".to_string(),
887            columns: vec!["id".to_string()],
888            query: Box::new(QueryExpr::table("dummy").build()),
889            recursive: false,
890        };
891
892        let with_clause = WithClause::new().add(cte);
893        let query = QueryWithCte::with_ctes(with_clause, QueryExpr::table("test_cte").build());
894
895        let result = executor.execute(&query);
896        assert!(result.is_ok());
897    }
898
899    #[test]
900    fn test_cte_builder() {
901        let query = CteQueryBuilder::new()
902            .cte_with_columns(
903                "nums",
904                vec!["n".to_string()],
905                QueryExpr::table("numbers").build(),
906            )
907            .build(QueryExpr::table("nums").build());
908
909        assert!(query.with_clause.is_some());
910        let with_clause = query.with_clause.unwrap();
911        assert_eq!(with_clause.ctes.len(), 1);
912        assert_eq!(with_clause.ctes[0].name, "nums");
913    }
914
915    #[test]
916    fn test_recursive_cte_builder() {
917        let query = CteQueryBuilder::new()
918            .recursive_cte("paths", QueryExpr::table("connections").build())
919            .build(QueryExpr::table("paths").build());
920
921        assert!(query.with_clause.is_some());
922        let with_clause = query.with_clause.unwrap();
923        assert!(with_clause.has_recursive);
924        assert!(with_clause.ctes[0].recursive);
925    }
926
927    #[test]
928    fn test_circular_reference_detection() {
929        let mut ctx = CteContext::new();
930        ctx.start_evaluating("cte_a");
931
932        // Simulate trying to evaluate cte_a while it's being evaluated
933        assert!(ctx.is_evaluating("cte_a"));
934    }
935
936    #[test]
937    fn test_cte_stats() {
938        let ctx = CteContext::new();
939        let stats = ctx.stats();
940
941        assert_eq!(stats.ctes_executed, 0);
942        assert_eq!(stats.recursive_iterations, 0);
943        assert_eq!(stats.rows_produced, 0);
944    }
945
946    #[test]
947    fn test_hash_record() {
948        let executor = CteExecutor::new(mock_execute);
949
950        let mut record1 = UnifiedRecord::new();
951        record1.set("id", Value::Integer(1));
952        record1.set("name", Value::text("test".to_string()));
953
954        let mut record2 = UnifiedRecord::new();
955        record2.set("id", Value::Integer(1));
956        record2.set("name", Value::text("test".to_string()));
957
958        let mut record3 = UnifiedRecord::new();
959        record3.set("id", Value::Integer(2));
960        record3.set("name", Value::text("test".to_string()));
961
962        // Same content should have same hash
963        assert_eq!(
964            executor.hash_record(&record1),
965            executor.hash_record(&record2)
966        );
967
968        // Different content should have different hash
969        assert_ne!(
970            executor.hash_record(&record1),
971            executor.hash_record(&record3)
972        );
973    }
974
975    #[test]
976    fn test_hash_various_value_types() {
977        let executor = CteExecutor::new(mock_execute);
978
979        // Test hashing different value types
980        let mut record = UnifiedRecord::new();
981        record.set("null_val", Value::Null);
982        record.set("bool_val", Value::Boolean(true));
983        record.set("int_val", Value::Integer(42));
984        record.set("float_val", Value::Float(2.5));
985        record.set("text_val", Value::text("hello".to_string()));
986        record.set("blob_val", Value::Blob(vec![1, 2, 3]));
987        record.set("timestamp_val", Value::Timestamp(1234567890));
988        record.set("duration_val", Value::Duration(5000));
989
990        // Should not panic
991        let hash = executor.hash_record(&record);
992        assert!(hash > 0);
993    }
994
995    #[test]
996    fn test_project_columns() {
997        let executor = CteExecutor::new(mock_execute);
998
999        let mut original =
1000            UnifiedResult::with_columns(vec!["a".to_string(), "b".to_string(), "c".to_string()]);
1001
1002        let mut record = UnifiedRecord::new();
1003        record.set("a", Value::Integer(1));
1004        record.set("b", Value::Integer(2));
1005        record.set("c", Value::Integer(3));
1006        original.push(record);
1007
1008        // Project to different column names
1009        let projected = executor.project_columns(&original, &["x".to_string(), "y".to_string()]);
1010
1011        assert_eq!(projected.columns, vec!["x", "y"]);
1012        assert_eq!(projected.len(), 1);
1013    }
1014
1015    #[test]
1016    fn test_empty_columns_projection() {
1017        let executor = CteExecutor::new(mock_execute);
1018
1019        let original = UnifiedResult::with_columns(vec!["a".to_string()]);
1020
1021        // Empty columns should return original
1022        let projected = executor.project_columns(&original, &[]);
1023        assert_eq!(projected.columns, original.columns);
1024    }
1025
1026    #[test]
1027    fn test_cte_with_multiple_definitions() {
1028        let executor = CteExecutor::new(|query, ctx| {
1029            // Return different results based on which CTE is being queried
1030            match query {
1031                QueryExpr::Table(t) if t.table == "base" => {
1032                    let mut result = UnifiedResult::with_columns(vec!["id".to_string()]);
1033                    let mut record = UnifiedRecord::new();
1034                    record.set("id", Value::Integer(1));
1035                    result.push(record);
1036                    Ok(result)
1037                }
1038                QueryExpr::Table(t) if t.table == "cte1" => {
1039                    // Should be able to see cte1 in context
1040                    if ctx.get("cte1").is_some() {
1041                        Ok(ctx.get("cte1").unwrap().clone())
1042                    } else {
1043                        Ok(UnifiedResult::empty())
1044                    }
1045                }
1046                _ => Ok(UnifiedResult::empty()),
1047            }
1048        });
1049
1050        let cte1 = CteDefinition {
1051            name: "cte1".to_string(),
1052            columns: vec!["id".to_string()],
1053            query: Box::new(QueryExpr::table("base").build()),
1054            recursive: false,
1055        };
1056
1057        let cte2 = CteDefinition {
1058            name: "cte2".to_string(),
1059            columns: vec!["id".to_string()],
1060            query: Box::new(QueryExpr::table("cte1").build()),
1061            recursive: false,
1062        };
1063
1064        let with_clause = WithClause::new().add(cte1).add(cte2);
1065        let query = QueryWithCte::with_ctes(with_clause, QueryExpr::table("cte2").build());
1066
1067        let result = executor.execute(&query);
1068        assert!(result.is_ok());
1069    }
1070}