Skip to main content

reddb_server/storage/query/executors/
cte.rs

1//! CTE (Common Table Expression) executor.
2//!
3//! Implements the `WITH … AS (…) SELECT …` SQL standard form. CTEs
4//! are materialised as temporary result sets that the main query can
5//! reference by name. Recursive CTEs (`WITH RECURSIVE`) use the
6//! classic iterative-fixpoint shape:
7//!
8//! 1. Execute the non-recursive (base) part once.
9//! 2. Repeat: execute the recursive part with the previous iteration's
10//!    rows visible as the CTE.
11//! 3. Stop when no new rows are produced (or guard limits trip).
12//!
13//! Today only non-recursive CTEs are wired through the runtime —
14//! `inline_ctes` rejects `WITH RECURSIVE` with a clear error. The
15//! iterative-fixpoint code in `CteExecutor` is reachable only via
16//! direct unit tests and is the basis for the future recursive wire-
17//! up (#41 follow-up).
18//!
19//! # Example
20//!
21//! ```ignore
22//! WITH active AS (
23//!     SELECT id, name FROM users WHERE status = 'active'
24//! )
25//! SELECT * FROM active
26//! ```
27
28use std::collections::{HashMap, HashSet};
29
30use super::super::ast::{CteDefinition, QueryExpr, QueryWithCte};
31use super::super::unified::{ExecutionError, UnifiedRecord, UnifiedResult};
32use crate::storage::schema::Value;
33
/// Upper bound on fixpoint passes a recursive CTE may run before it is
/// treated as non-terminating.
const MAX_RECURSION_DEPTH: usize = 1_000;

/// Upper bound on the number of rows a recursive CTE may accumulate across
/// all of its fixpoint passes.
const MAX_RECURSIVE_ROWS: usize = 100_000;
39
40/// CTE execution context holding materialized CTE results
41#[derive(Debug, Clone, Default)]
42pub struct CteContext {
43    /// Materialized CTE results by name
44    tables: HashMap<String, UnifiedResult>,
45    /// Track which CTEs are currently being evaluated (for cycle detection)
46    evaluating: HashSet<String>,
47    /// Statistics
48    stats: CteStats,
49}
50
51impl CteContext {
52    /// Create a new CTE context
53    pub fn new() -> Self {
54        Self::default()
55    }
56
57    /// Get a materialized CTE result by name
58    pub fn get(&self, name: &str) -> Option<&UnifiedResult> {
59        self.tables.get(name)
60    }
61
62    /// Store a materialized CTE result
63    pub fn store(&mut self, name: String, result: UnifiedResult) {
64        self.tables.insert(name, result);
65    }
66
67    /// Check if a CTE is being evaluated (for recursion detection)
68    pub fn is_evaluating(&self, name: &str) -> bool {
69        self.evaluating.contains(name)
70    }
71
72    /// Mark a CTE as being evaluated
73    pub fn start_evaluating(&mut self, name: &str) {
74        self.evaluating.insert(name.to_string());
75    }
76
77    /// Mark a CTE as done evaluating
78    pub fn done_evaluating(&mut self, name: &str) {
79        self.evaluating.remove(name);
80    }
81
82    /// Get execution statistics
83    pub fn stats(&self) -> &CteStats {
84        &self.stats
85    }
86}
87
/// Counters describing the work done while materialising CTEs.
#[derive(Default, Clone, Debug)]
pub struct CteStats {
    /// How many CTE definitions were materialised.
    pub ctes_executed: usize,
    /// How many fixpoint passes recursive CTEs ran, summed over all CTEs.
    pub recursive_iterations: usize,
    /// Sum of the row counts of every materialised CTE.
    pub rows_produced: usize,
    /// Wall-clock execution time, in microseconds.
    pub exec_time_us: u64,
}
100
101/// CTE Executor
102pub struct CteExecutor<F>
103where
104    F: Fn(&QueryExpr, &CteContext) -> Result<UnifiedResult, ExecutionError>,
105{
106    /// Function to execute a query with CTE context
107    execute_fn: F,
108}
109
110impl<F> CteExecutor<F>
111where
112    F: Fn(&QueryExpr, &CteContext) -> Result<UnifiedResult, ExecutionError>,
113{
114    /// Create a new CTE executor
115    pub fn new(execute_fn: F) -> Self {
116        Self { execute_fn }
117    }
118
119    /// Execute a query with CTEs
120    pub fn execute(&self, query: &QueryWithCte) -> Result<UnifiedResult, ExecutionError> {
121        let start = std::time::Instant::now();
122        let mut ctx = CteContext::new();
123
124        // Materialize all CTEs in order
125        if let Some(ref with_clause) = query.with_clause {
126            for cte in &with_clause.ctes {
127                self.materialize_cte(cte, &mut ctx)?;
128            }
129        }
130
131        // Execute the main query with CTE context
132        let result = (self.execute_fn)(&query.query, &ctx)?;
133
134        ctx.stats.exec_time_us = start.elapsed().as_micros() as u64;
135        Ok(result)
136    }
137
138    /// Materialize a single CTE
139    fn materialize_cte(
140        &self,
141        cte: &CteDefinition,
142        ctx: &mut CteContext,
143    ) -> Result<(), ExecutionError> {
144        if ctx.is_evaluating(&cte.name) {
145            return Err(ExecutionError::new(format!(
146                "Circular CTE reference: {}",
147                cte.name
148            )));
149        }
150
151        // Check if already materialized
152        if ctx.get(&cte.name).is_some() {
153            return Ok(());
154        }
155
156        ctx.start_evaluating(&cte.name);
157
158        let result = if cte.recursive {
159            self.execute_recursive_cte(cte, ctx)?
160        } else {
161            // Simple CTE: execute once
162            let result = (self.execute_fn)(&cte.query, ctx)?;
163            self.project_columns(&result, &cte.columns)
164        };
165
166        ctx.stats.ctes_executed += 1;
167        ctx.stats.rows_produced += result.len();
168        ctx.store(cte.name.clone(), result);
169        ctx.done_evaluating(&cte.name);
170
171        Ok(())
172    }
173
174    /// Execute a recursive CTE using iterative fixpoint
175    fn execute_recursive_cte(
176        &self,
177        cte: &CteDefinition,
178        ctx: &mut CteContext,
179    ) -> Result<UnifiedResult, ExecutionError> {
180        // For recursive CTEs, we need to handle UNION ALL structure
181        // The query should be: base_query UNION ALL recursive_query
182        //
183        // Algorithm:
184        // 1. Execute base query -> working_table
185        // 2. result_table = working_table
186        // 3. While working_table not empty:
187        //    a. Execute recursive query with CTE = working_table
188        //    b. new_rows = result - already_seen
189        //    c. working_table = new_rows
190        //    d. result_table += new_rows
191        // 4. Return result_table
192
193        // For simplicity in first implementation, we execute the full query
194        // iteratively, building up the working table
195
196        let mut all_results = UnifiedResult::with_columns(cte.columns.clone());
197        let mut working_table = UnifiedResult::with_columns(cte.columns.clone());
198        let mut seen_rows: HashSet<u64> = HashSet::new();
199        let mut iteration = 0;
200
201        // First iteration: execute the full query (base case)
202        let initial = (self.execute_fn)(&cte.query, ctx)?;
203        let initial = self.project_columns(&initial, &cte.columns);
204
205        for record in &initial.records {
206            let hash = self.hash_record(record);
207            if seen_rows.insert(hash) {
208                working_table.push(record.clone());
209                all_results.push(record.clone());
210            }
211        }
212
213        // Store initial results so recursive references can see them
214        ctx.store(cte.name.clone(), working_table.clone());
215
216        // Iterate until fixpoint
217        while !working_table.is_empty() && iteration < MAX_RECURSION_DEPTH {
218            iteration += 1;
219            ctx.stats.recursive_iterations += 1;
220
221            if all_results.len() > MAX_RECURSIVE_ROWS {
222                return Err(ExecutionError::new(format!(
223                    "Recursive CTE '{}' exceeded maximum rows ({})",
224                    cte.name, MAX_RECURSIVE_ROWS
225                )));
226            }
227
228            // Execute query with current CTE contents
229            let new_results = (self.execute_fn)(&cte.query, ctx)?;
230            let new_results = self.project_columns(&new_results, &cte.columns);
231
232            // Find genuinely new rows
233            let mut new_working_table = UnifiedResult::with_columns(cte.columns.clone());
234            for record in &new_results.records {
235                let hash = self.hash_record(record);
236                if seen_rows.insert(hash) {
237                    new_working_table.push(record.clone());
238                    all_results.push(record.clone());
239                }
240            }
241
242            working_table = new_working_table;
243
244            // Update CTE table for next iteration
245            ctx.store(cte.name.clone(), all_results.clone());
246        }
247
248        if iteration >= MAX_RECURSION_DEPTH && !working_table.is_empty() {
249            return Err(ExecutionError::new(format!(
250                "Recursive CTE '{}' exceeded maximum recursion depth ({})",
251                cte.name, MAX_RECURSION_DEPTH
252            )));
253        }
254
255        Ok(all_results)
256    }
257
258    /// Project result columns according to CTE column list
259    fn project_columns(&self, result: &UnifiedResult, columns: &[String]) -> UnifiedResult {
260        if columns.is_empty() {
261            return result.clone();
262        }
263
264        let mut projected = UnifiedResult::with_columns(columns.to_vec());
265
266        for record in &result.records {
267            let mut new_record = UnifiedRecord::new();
268
269            // Map result columns to CTE columns
270            for (i, col) in columns.iter().enumerate() {
271                // Try to find value by position first, then by name
272                let value = result
273                    .columns
274                    .get(i)
275                    .and_then(|orig_col| record.get(orig_col))
276                    .cloned()
277                    .or_else(|| record.get(col).cloned())
278                    .unwrap_or(Value::Null);
279
280                new_record.set(col, value);
281            }
282
283            projected.push(new_record);
284        }
285
286        projected
287    }
288
289    /// Hash a record for deduplication
290    fn hash_record(&self, record: &UnifiedRecord) -> u64 {
291        use std::collections::hash_map::DefaultHasher;
292        use std::hash::{Hash, Hasher};
293
294        let mut hasher = DefaultHasher::new();
295
296        // Hash all values in deterministic order
297        let mut keys = record.column_names();
298        keys.sort();
299
300        for key in &keys {
301            (**key).hash(&mut hasher);
302            if let Some(value) = record.get(key) {
303                Self::hash_value(value, &mut hasher);
304            }
305        }
306
307        hasher.finish()
308    }
309
310    /// Hash a Value for deduplication
311    fn hash_value(value: &Value, hasher: &mut impl std::hash::Hasher) {
312        use std::hash::Hash;
313
314        match value {
315            Value::Null => 0u8.hash(hasher),
316            Value::Boolean(b) => {
317                1u8.hash(hasher);
318                b.hash(hasher);
319            }
320            Value::Integer(i) => {
321                2u8.hash(hasher);
322                i.hash(hasher);
323            }
324            Value::UnsignedInteger(u) => {
325                3u8.hash(hasher);
326                u.hash(hasher);
327            }
328            Value::Float(f) => {
329                4u8.hash(hasher);
330                f.to_bits().hash(hasher);
331            }
332            Value::Text(s) => {
333                5u8.hash(hasher);
334                s.hash(hasher);
335            }
336            Value::Blob(b) => {
337                6u8.hash(hasher);
338                b.hash(hasher);
339            }
340            Value::Timestamp(t) => {
341                7u8.hash(hasher);
342                t.hash(hasher);
343            }
344            Value::Duration(d) => {
345                8u8.hash(hasher);
346                d.hash(hasher);
347            }
348            Value::IpAddr(addr) => {
349                9u8.hash(hasher);
350                match addr {
351                    std::net::IpAddr::V4(v4) => v4.octets().hash(hasher),
352                    std::net::IpAddr::V6(v6) => v6.octets().hash(hasher),
353                }
354            }
355            Value::MacAddr(mac) => {
356                10u8.hash(hasher);
357                mac.hash(hasher);
358            }
359            Value::Vector(v) => {
360                11u8.hash(hasher);
361                v.len().hash(hasher);
362                for f in v {
363                    f.to_bits().hash(hasher);
364                }
365            }
366            Value::Json(j) => {
367                12u8.hash(hasher);
368                j.hash(hasher);
369            }
370            Value::Uuid(u) => {
371                13u8.hash(hasher);
372                u.hash(hasher);
373            }
374            Value::NodeRef(n) => {
375                14u8.hash(hasher);
376                n.hash(hasher);
377            }
378            Value::EdgeRef(e) => {
379                15u8.hash(hasher);
380                e.hash(hasher);
381            }
382            Value::VectorRef(coll, id) => {
383                16u8.hash(hasher);
384                coll.hash(hasher);
385                id.hash(hasher);
386            }
387            Value::RowRef(table, id) => {
388                17u8.hash(hasher);
389                table.hash(hasher);
390                id.hash(hasher);
391            }
392            Value::Color(rgb) => {
393                18u8.hash(hasher);
394                rgb.hash(hasher);
395            }
396            Value::Email(s) => {
397                19u8.hash(hasher);
398                s.hash(hasher);
399            }
400            Value::Url(s) => {
401                20u8.hash(hasher);
402                s.hash(hasher);
403            }
404            Value::Phone(n) => {
405                21u8.hash(hasher);
406                n.hash(hasher);
407            }
408            Value::Semver(v) => {
409                22u8.hash(hasher);
410                v.hash(hasher);
411            }
412            Value::Cidr(ip, prefix) => {
413                23u8.hash(hasher);
414                ip.hash(hasher);
415                prefix.hash(hasher);
416            }
417            Value::Date(d) => {
418                24u8.hash(hasher);
419                d.hash(hasher);
420            }
421            Value::Time(t) => {
422                25u8.hash(hasher);
423                t.hash(hasher);
424            }
425            Value::Decimal(v) => {
426                26u8.hash(hasher);
427                v.hash(hasher);
428            }
429            Value::EnumValue(i) => {
430                27u8.hash(hasher);
431                i.hash(hasher);
432            }
433            Value::Array(elems) => {
434                28u8.hash(hasher);
435                elems.len().hash(hasher);
436                for elem in elems {
437                    Self::hash_value(elem, hasher);
438                }
439            }
440            Value::TimestampMs(v) => {
441                29u8.hash(hasher);
442                v.hash(hasher);
443            }
444            Value::Ipv4(v) => {
445                30u8.hash(hasher);
446                v.hash(hasher);
447            }
448            Value::Ipv6(bytes) => {
449                31u8.hash(hasher);
450                bytes.hash(hasher);
451            }
452            Value::Subnet(ip, mask) => {
453                32u8.hash(hasher);
454                ip.hash(hasher);
455                mask.hash(hasher);
456            }
457            Value::Port(v) => {
458                33u8.hash(hasher);
459                v.hash(hasher);
460            }
461            Value::Latitude(v) => {
462                34u8.hash(hasher);
463                v.hash(hasher);
464            }
465            Value::Longitude(v) => {
466                35u8.hash(hasher);
467                v.hash(hasher);
468            }
469            Value::GeoPoint(lat, lon) => {
470                36u8.hash(hasher);
471                lat.hash(hasher);
472                lon.hash(hasher);
473            }
474            Value::Country2(c) => {
475                37u8.hash(hasher);
476                c.hash(hasher);
477            }
478            Value::Country3(c) => {
479                38u8.hash(hasher);
480                c.hash(hasher);
481            }
482            Value::Lang2(c) => {
483                39u8.hash(hasher);
484                c.hash(hasher);
485            }
486            Value::Lang5(c) => {
487                40u8.hash(hasher);
488                c.hash(hasher);
489            }
490            Value::Currency(c) => {
491                41u8.hash(hasher);
492                c.hash(hasher);
493            }
494            Value::AssetCode(code) => {
495                50u8.hash(hasher);
496                code.hash(hasher);
497            }
498            Value::Money {
499                asset_code,
500                minor_units,
501                scale,
502            } => {
503                51u8.hash(hasher);
504                asset_code.hash(hasher);
505                minor_units.hash(hasher);
506                scale.hash(hasher);
507            }
508            Value::ColorAlpha(rgba) => {
509                42u8.hash(hasher);
510                rgba.hash(hasher);
511            }
512            Value::BigInt(v) => {
513                43u8.hash(hasher);
514                v.hash(hasher);
515            }
516            Value::KeyRef(col, key) => {
517                44u8.hash(hasher);
518                col.hash(hasher);
519                key.hash(hasher);
520            }
521            Value::DocRef(col, id) => {
522                45u8.hash(hasher);
523                col.hash(hasher);
524                id.hash(hasher);
525            }
526            Value::TableRef(name) => {
527                46u8.hash(hasher);
528                name.hash(hasher);
529            }
530            Value::PageRef(page_id) => {
531                47u8.hash(hasher);
532                page_id.hash(hasher);
533            }
534            Value::Secret(bytes) => {
535                48u8.hash(hasher);
536                bytes.hash(hasher);
537            }
538            Value::Password(hash) => {
539                49u8.hash(hasher);
540                hash.hash(hasher);
541            }
542        }
543    }
544}
545
546/// Helper to parse UNION structure for recursive CTEs
547pub fn split_union_parts(query: &QueryExpr) -> Option<(QueryExpr, QueryExpr)> {
548    // UNION support is not represented in the current AST; recursive queries execute
549    // the full body expression each iteration.
550    let _ = query;
551    None
552}
553
554// ─────────────────────────────────────────────────────────────────────
555// CTE inlining (#41) — non-recursive
556//
557// Rewrites a `QueryWithCte` into a plain `QueryExpr` by walking the
558// AST and substituting every `TableSource::Name(name)` (or legacy
559// `TableQuery.table` field) that matches a CTE name with
560// `TableSource::Subquery(cte.query)`. After this pass the runtime's
561// existing subquery-in-FROM machinery executes the result with no
562// CTE-specific dispatch needed.
563//
564// Recursive CTEs are rejected up-front — the iterative fixpoint
565// strategy is implemented in `CteExecutor` but is not wired into the
566// runtime yet (separate slice).
567// ─────────────────────────────────────────────────────────────────────
568
569/// Inline a `QueryWithCte`'s WITH clause into its inner query. Returns
570/// the rewritten `QueryExpr` ready for dispatch. Recursive CTEs are
571/// rejected with a clear error.
572pub fn inline_ctes(query: QueryWithCte) -> Result<QueryExpr, ExecutionError> {
573    let Some(with_clause) = query.with_clause else {
574        return Ok(query.query);
575    };
576    if with_clause.has_recursive {
577        return Err(ExecutionError::new(
578            "WITH RECURSIVE is not yet supported by the executor; \
579             non-recursive WITH clauses run today, recursive support \
580             is tracked separately"
581                .to_string(),
582        ));
583    }
584
585    // Inline each CTE into its successors first so chained CTEs
586    // (`WITH a AS (...), b AS (... a ...)`) end up with fully resolved
587    // bodies before they're substituted into the outer query.
588    let mut resolved: HashMap<String, QueryExpr> = HashMap::new();
589    for cte in &with_clause.ctes {
590        let mut body = (*cte.query).clone();
591        rewrite(&mut body, &resolved);
592        resolved.insert(cte.name.clone(), body);
593    }
594
595    let mut outer = query.query;
596    rewrite(&mut outer, &resolved);
597    Ok(outer)
598}
599
600/// Walk a `QueryExpr` and replace any table reference whose name
601/// matches a key in `ctes` with the inlined CTE body. Recurses
602/// through `Join` and nested `Subquery` sources so CTE refs inside
603/// JOINs and subqueries resolve too. Mirrors the view-rewrite
604/// convention: when the outer table reference carries filter / limit
605/// / offset constraints we wrap the body in a `Subquery` to preserve
606/// them; otherwise we replace the whole `Table` node verbatim with
607/// the CTE body so dispatchers that key off `QueryExpr::Table` (like
608/// the JOIN executor) see the right shape.
609fn rewrite(expr: &mut QueryExpr, ctes: &HashMap<String, QueryExpr>) {
610    use super::super::ast::TableSource;
611    match expr {
612        QueryExpr::Table(tq) => {
613            let lookup_name = match &tq.source {
614                Some(TableSource::Subquery(_)) => None,
615                Some(TableSource::Name(n)) => Some(n.clone()),
616                // Table-valued functions are not CTE references (issue #795);
617                // the inline-graph form likewise references no CTE (issue #799).
618                Some(TableSource::Function { .. } | TableSource::InlineGraphFunction { .. }) => {
619                    None
620                }
621                None => Some(tq.table.clone()),
622            };
623
624            if let Some(name) = lookup_name {
625                if let Some(body) = ctes.get(&name) {
626                    let outer_has_constraints = tq.filter.is_some()
627                        || tq.where_expr.is_some()
628                        || tq.limit.is_some()
629                        || tq.offset.is_some()
630                        || !tq.columns.is_empty()
631                        || !tq.select_items.is_empty()
632                        || !tq.group_by.is_empty()
633                        || !tq.order_by.is_empty();
634
635                    if outer_has_constraints {
636                        // Outer ref carries projections / filters /
637                        // limits — keep those by wrapping the body in
638                        // a subquery source. Sentinel name so legacy
639                        // `table` consumers can't resolve it against
640                        // the real schema.
641                        tq.source = Some(TableSource::Subquery(Box::new(body.clone())));
642                        tq.table = format!("__cte_{name}");
643                    } else {
644                        // Bare `FROM cte` (possibly with alias) —
645                        // replace verbatim so JOIN / dispatch paths
646                        // see the CTE body's natural shape.
647                        *expr = body.clone();
648                    }
649                    return;
650                }
651            }
652
653            if let Some(TableSource::Subquery(body)) = tq.source.as_mut() {
654                rewrite(body, ctes);
655            }
656        }
657        QueryExpr::Join(jq) => {
658            rewrite(&mut jq.left, ctes);
659            rewrite(&mut jq.right, ctes);
660        }
661        _ => {}
662    }
663}
664
665// ============================================================================
666// Tests
667// ============================================================================
668
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::query::ast::CteQueryBuilder;
    use crate::storage::query::WithClause;

    /// Executor callback that ignores its input and returns no rows.
    fn mock_execute(
        _query: &QueryExpr,
        _ctx: &CteContext,
    ) -> Result<UnifiedResult, ExecutionError> {
        Ok(UnifiedResult::empty())
    }

    /// Build a one-row result with a single column holding `value`.
    fn single_row(column: &str, value: Value) -> UnifiedResult {
        let mut result = UnifiedResult::with_columns(vec![column.to_string()]);
        let mut record = UnifiedRecord::new();
        record.set(column, value);
        result.push(record);
        result
    }

    #[test]
    fn test_cte_context() {
        let mut ctx = CteContext::new();

        // Empty context
        assert!(ctx.get("test").is_none());
        assert!(!ctx.is_evaluating("test"));

        // Storing results
        let result = UnifiedResult::with_columns(vec!["col1".to_string()]);
        ctx.store("test".to_string(), result);
        assert!(ctx.get("test").is_some());

        // Evaluation tracking
        ctx.start_evaluating("other");
        assert!(ctx.is_evaluating("other"));
        ctx.done_evaluating("other");
        assert!(!ctx.is_evaluating("other"));
    }

    #[test]
    fn test_simple_cte_execution() {
        let executor = CteExecutor::new(|_query, _ctx| Ok(single_row("id", Value::Integer(1))));

        let cte = CteDefinition {
            name: "test_cte".to_string(),
            columns: vec!["id".to_string()],
            query: Box::new(QueryExpr::table("dummy").build()),
            recursive: false,
        };

        let with_clause = WithClause::new().add(cte);
        let query = QueryWithCte::with_ctes(with_clause, QueryExpr::table("test_cte").build());

        let result = executor
            .execute(&query)
            .expect("simple CTE should execute");
        // The main query goes through the same mock, which yields one row.
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_cte_builder() {
        let query = CteQueryBuilder::new()
            .cte_with_columns(
                "nums",
                vec!["n".to_string()],
                QueryExpr::table("numbers").build(),
            )
            .build(QueryExpr::table("nums").build());

        let with_clause = query.with_clause.expect("builder should set a WITH clause");
        assert_eq!(with_clause.ctes.len(), 1);
        assert_eq!(with_clause.ctes[0].name, "nums");
    }

    #[test]
    fn test_recursive_cte_builder() {
        let query = CteQueryBuilder::new()
            .recursive_cte("paths", QueryExpr::table("connections").build())
            .build(QueryExpr::table("paths").build());

        let with_clause = query.with_clause.expect("builder should set a WITH clause");
        assert!(with_clause.has_recursive);
        assert!(with_clause.ctes[0].recursive);
    }

    #[test]
    fn test_circular_reference_detection() {
        let mut ctx = CteContext::new();
        ctx.start_evaluating("cte_a");

        // A CTE marked in progress is reported as such until it is cleared.
        assert!(ctx.is_evaluating("cte_a"));
        assert!(!ctx.is_evaluating("cte_b"));
        ctx.done_evaluating("cte_a");
        assert!(!ctx.is_evaluating("cte_a"));
    }

    #[test]
    fn test_cte_stats() {
        let ctx = CteContext::new();
        let stats = ctx.stats();

        assert_eq!(stats.ctes_executed, 0);
        assert_eq!(stats.recursive_iterations, 0);
        assert_eq!(stats.rows_produced, 0);
    }

    #[test]
    fn test_hash_record() {
        let executor = CteExecutor::new(mock_execute);

        let mut record1 = UnifiedRecord::new();
        record1.set("id", Value::Integer(1));
        record1.set("name", Value::text("test".to_string()));

        let mut record2 = UnifiedRecord::new();
        record2.set("id", Value::Integer(1));
        record2.set("name", Value::text("test".to_string()));

        let mut record3 = UnifiedRecord::new();
        record3.set("id", Value::Integer(2));
        record3.set("name", Value::text("test".to_string()));

        // Same content should have same hash
        assert_eq!(
            executor.hash_record(&record1),
            executor.hash_record(&record2)
        );

        // Different content should have different hash
        assert_ne!(
            executor.hash_record(&record1),
            executor.hash_record(&record3)
        );
    }

    #[test]
    fn test_hash_various_value_types() {
        let executor = CteExecutor::new(mock_execute);

        let build = |int_val| {
            let mut record = UnifiedRecord::new();
            record.set("null_val", Value::Null);
            record.set("bool_val", Value::Boolean(true));
            record.set("int_val", Value::Integer(int_val));
            record.set("float_val", Value::Float(2.5));
            record.set("text_val", Value::text("hello".to_string()));
            record.set("blob_val", Value::Blob(vec![1, 2, 3]));
            record.set("timestamp_val", Value::Timestamp(1234567890));
            record.set("duration_val", Value::Duration(5000));
            record
        };

        let record = build(42);
        // Hashing a mixed record is deterministic across identical records...
        assert_eq!(
            executor.hash_record(&record),
            executor.hash_record(&build(42))
        );
        // ...and sensitive to a change in any single value.
        assert_ne!(
            executor.hash_record(&record),
            executor.hash_record(&build(43))
        );
    }

    #[test]
    fn test_project_columns() {
        let executor = CteExecutor::new(mock_execute);

        let mut original =
            UnifiedResult::with_columns(vec!["a".to_string(), "b".to_string(), "c".to_string()]);

        let mut record = UnifiedRecord::new();
        record.set("a", Value::Integer(1));
        record.set("b", Value::Integer(2));
        record.set("c", Value::Integer(3));
        original.push(record);

        // Project to different column names (positional mapping).
        let projected = executor.project_columns(&original, &["x".to_string(), "y".to_string()]);

        assert_eq!(projected.columns, vec!["x", "y"]);
        assert_eq!(projected.len(), 1);
        assert!(matches!(projected.records[0].get("x"), Some(Value::Integer(1))));
        assert!(matches!(projected.records[0].get("y"), Some(Value::Integer(2))));
    }

    #[test]
    fn test_empty_columns_projection() {
        let executor = CteExecutor::new(mock_execute);

        let original = UnifiedResult::with_columns(vec!["a".to_string()]);

        // Empty columns should return original
        let projected = executor.project_columns(&original, &[]);
        assert_eq!(projected.columns, original.columns);
    }

    #[test]
    fn test_cte_with_multiple_definitions() {
        let executor = CteExecutor::new(|query, ctx| match query {
            QueryExpr::Table(t) if t.table == "base" => Ok(single_row("id", Value::Integer(1))),
            // cte2's body reads cte1, which must already be materialised.
            QueryExpr::Table(t) if t.table == "cte1" => Ok(ctx
                .get("cte1")
                .cloned()
                .unwrap_or_else(UnifiedResult::empty)),
            _ => Ok(UnifiedResult::empty()),
        });

        let cte1 = CteDefinition {
            name: "cte1".to_string(),
            columns: vec!["id".to_string()],
            query: Box::new(QueryExpr::table("base").build()),
            recursive: false,
        };

        let cte2 = CteDefinition {
            name: "cte2".to_string(),
            columns: vec!["id".to_string()],
            query: Box::new(QueryExpr::table("cte1").build()),
            recursive: false,
        };

        let with_clause = WithClause::new().add(cte1).add(cte2);
        let query = QueryWithCte::with_ctes(with_clause, QueryExpr::table("cte2").build());

        assert!(executor.execute(&query).is_ok());
    }

    #[test]
    fn test_recursive_cte_reaches_fixpoint() {
        // The body ("seed") always emits n = 1, plus n + 1 for every visible
        // row of `nums` below 3, so the fixpoint is {1, 2, 3}. The main query
        // ("nums") returns the materialised CTE as-is.
        let executor = CteExecutor::new(|query, ctx| match query {
            QueryExpr::Table(t) if t.table == "nums" => Ok(ctx
                .get("nums")
                .cloned()
                .unwrap_or_else(UnifiedResult::empty)),
            _ => {
                let mut values = vec![1];
                if let Some(previous) = ctx.get("nums") {
                    for record in &previous.records {
                        if let Some(Value::Integer(n)) = record.get("n") {
                            if *n < 3 {
                                values.push(*n + 1);
                            }
                        }
                    }
                }
                let mut result = UnifiedResult::with_columns(vec!["n".to_string()]);
                for v in values {
                    let mut record = UnifiedRecord::new();
                    record.set("n", Value::Integer(v));
                    result.push(record);
                }
                Ok(result)
            }
        });

        let cte = CteDefinition {
            name: "nums".to_string(),
            columns: vec!["n".to_string()],
            query: Box::new(QueryExpr::table("seed").build()),
            recursive: true,
        };

        let with_clause = WithClause::new().add(cte);
        let query = QueryWithCte::with_ctes(with_clause, QueryExpr::table("nums").build());

        let result = executor
            .execute(&query)
            .expect("recursive CTE should converge");
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_inline_ctes_rejects_recursive() {
        let query = CteQueryBuilder::new()
            .recursive_cte("paths", QueryExpr::table("connections").build())
            .build(QueryExpr::table("paths").build());

        assert!(inline_ctes(query).is_err());
    }

    #[test]
    fn test_inline_ctes_accepts_plain_cte() {
        let query = CteQueryBuilder::new()
            .cte_with_columns(
                "nums",
                vec!["n".to_string()],
                QueryExpr::table("numbers").build(),
            )
            .build(QueryExpr::table("nums").build());

        assert!(inline_ctes(query).is_ok());
    }
}
899}