reddb_server/storage/query/executors/
cte.rs1use std::collections::{HashMap, HashSet};
29
30use super::super::ast::{CteDefinition, QueryExpr, QueryWithCte};
31use super::super::unified::{ExecutionError, UnifiedRecord, UnifiedResult};
32use crate::storage::schema::Value;
33
/// Upper bound on the number of iterations a recursive CTE may run before
/// evaluation is aborted with an error.
const MAX_RECURSION_DEPTH: usize = 1_000;

/// Upper bound on the number of rows a recursive CTE may accumulate before
/// evaluation is aborted with an error.
const MAX_RECURSIVE_ROWS: usize = 100_000;
39
40#[derive(Debug, Clone, Default)]
42pub struct CteContext {
43 tables: HashMap<String, UnifiedResult>,
45 evaluating: HashSet<String>,
47 stats: CteStats,
49}
50
51impl CteContext {
52 pub fn new() -> Self {
54 Self::default()
55 }
56
57 pub fn get(&self, name: &str) -> Option<&UnifiedResult> {
59 self.tables.get(name)
60 }
61
62 pub fn store(&mut self, name: String, result: UnifiedResult) {
64 self.tables.insert(name, result);
65 }
66
67 pub fn is_evaluating(&self, name: &str) -> bool {
69 self.evaluating.contains(name)
70 }
71
72 pub fn start_evaluating(&mut self, name: &str) {
74 self.evaluating.insert(name.to_string());
75 }
76
77 pub fn done_evaluating(&mut self, name: &str) {
79 self.evaluating.remove(name);
80 }
81
82 pub fn stats(&self) -> &CteStats {
84 &self.stats
85 }
86}
87
/// Counters describing the work done while executing CTEs.
#[derive(Clone, Debug, Default)]
pub struct CteStats {
    /// Number of CTE definitions that were materialized.
    pub ctes_executed: usize,
    /// Total number of loop iterations run by recursive CTEs.
    pub recursive_iterations: usize,
    /// Sum of the row counts of all materialized CTEs.
    pub rows_produced: usize,
    /// Wall-clock execution time, in microseconds.
    pub exec_time_us: u64,
}
100
/// Executes queries that carry a WITH clause by delegating every individual
/// query to a caller-supplied callback.
///
/// The callback bound (`Fn(&QueryExpr, &CteContext) -> Result<..>`) lives on
/// the `impl` block that actually calls it rather than on the type itself, so
/// code that merely names or stores a `CteExecutor<F>` does not have to
/// repeat the bound.
pub struct CteExecutor<F> {
    /// Callback that runs one query against the current CTE context.
    execute_fn: F,
}
109
110impl<F> CteExecutor<F>
111where
112 F: Fn(&QueryExpr, &CteContext) -> Result<UnifiedResult, ExecutionError>,
113{
114 pub fn new(execute_fn: F) -> Self {
116 Self { execute_fn }
117 }
118
119 pub fn execute(&self, query: &QueryWithCte) -> Result<UnifiedResult, ExecutionError> {
121 let start = std::time::Instant::now();
122 let mut ctx = CteContext::new();
123
124 if let Some(ref with_clause) = query.with_clause {
126 for cte in &with_clause.ctes {
127 self.materialize_cte(cte, &mut ctx)?;
128 }
129 }
130
131 let result = (self.execute_fn)(&query.query, &ctx)?;
133
134 ctx.stats.exec_time_us = start.elapsed().as_micros() as u64;
135 Ok(result)
136 }
137
138 fn materialize_cte(
140 &self,
141 cte: &CteDefinition,
142 ctx: &mut CteContext,
143 ) -> Result<(), ExecutionError> {
144 if ctx.is_evaluating(&cte.name) {
145 return Err(ExecutionError::new(format!(
146 "Circular CTE reference: {}",
147 cte.name
148 )));
149 }
150
151 if ctx.get(&cte.name).is_some() {
153 return Ok(());
154 }
155
156 ctx.start_evaluating(&cte.name);
157
158 let result = if cte.recursive {
159 self.execute_recursive_cte(cte, ctx)?
160 } else {
161 let result = (self.execute_fn)(&cte.query, ctx)?;
163 self.project_columns(&result, &cte.columns)
164 };
165
166 ctx.stats.ctes_executed += 1;
167 ctx.stats.rows_produced += result.len();
168 ctx.store(cte.name.clone(), result);
169 ctx.done_evaluating(&cte.name);
170
171 Ok(())
172 }
173
174 fn execute_recursive_cte(
176 &self,
177 cte: &CteDefinition,
178 ctx: &mut CteContext,
179 ) -> Result<UnifiedResult, ExecutionError> {
180 let mut all_results = UnifiedResult::with_columns(cte.columns.clone());
197 let mut working_table = UnifiedResult::with_columns(cte.columns.clone());
198 let mut seen_rows: HashSet<u64> = HashSet::new();
199 let mut iteration = 0;
200
201 let initial = (self.execute_fn)(&cte.query, ctx)?;
203 let initial = self.project_columns(&initial, &cte.columns);
204
205 for record in &initial.records {
206 let hash = self.hash_record(record);
207 if seen_rows.insert(hash) {
208 working_table.push(record.clone());
209 all_results.push(record.clone());
210 }
211 }
212
213 ctx.store(cte.name.clone(), working_table.clone());
215
216 while !working_table.is_empty() && iteration < MAX_RECURSION_DEPTH {
218 iteration += 1;
219 ctx.stats.recursive_iterations += 1;
220
221 if all_results.len() > MAX_RECURSIVE_ROWS {
222 return Err(ExecutionError::new(format!(
223 "Recursive CTE '{}' exceeded maximum rows ({})",
224 cte.name, MAX_RECURSIVE_ROWS
225 )));
226 }
227
228 let new_results = (self.execute_fn)(&cte.query, ctx)?;
230 let new_results = self.project_columns(&new_results, &cte.columns);
231
232 let mut new_working_table = UnifiedResult::with_columns(cte.columns.clone());
234 for record in &new_results.records {
235 let hash = self.hash_record(record);
236 if seen_rows.insert(hash) {
237 new_working_table.push(record.clone());
238 all_results.push(record.clone());
239 }
240 }
241
242 working_table = new_working_table;
243
244 ctx.store(cte.name.clone(), all_results.clone());
246 }
247
248 if iteration >= MAX_RECURSION_DEPTH && !working_table.is_empty() {
249 return Err(ExecutionError::new(format!(
250 "Recursive CTE '{}' exceeded maximum recursion depth ({})",
251 cte.name, MAX_RECURSION_DEPTH
252 )));
253 }
254
255 Ok(all_results)
256 }
257
258 fn project_columns(&self, result: &UnifiedResult, columns: &[String]) -> UnifiedResult {
260 if columns.is_empty() {
261 return result.clone();
262 }
263
264 let mut projected = UnifiedResult::with_columns(columns.to_vec());
265
266 for record in &result.records {
267 let mut new_record = UnifiedRecord::new();
268
269 for (i, col) in columns.iter().enumerate() {
271 let value = result
273 .columns
274 .get(i)
275 .and_then(|orig_col| record.get(orig_col))
276 .cloned()
277 .or_else(|| record.get(col).cloned())
278 .unwrap_or(Value::Null);
279
280 new_record.set(col, value);
281 }
282
283 projected.push(new_record);
284 }
285
286 projected
287 }
288
289 fn hash_record(&self, record: &UnifiedRecord) -> u64 {
291 use std::collections::hash_map::DefaultHasher;
292 use std::hash::{Hash, Hasher};
293
294 let mut hasher = DefaultHasher::new();
295
296 let mut keys = record.column_names();
298 keys.sort();
299
300 for key in &keys {
301 (**key).hash(&mut hasher);
302 if let Some(value) = record.get(key) {
303 Self::hash_value(value, &mut hasher);
304 }
305 }
306
307 hasher.finish()
308 }
309
310 fn hash_value(value: &Value, hasher: &mut impl std::hash::Hasher) {
312 use std::hash::Hash;
313
314 match value {
315 Value::Null => 0u8.hash(hasher),
316 Value::Boolean(b) => {
317 1u8.hash(hasher);
318 b.hash(hasher);
319 }
320 Value::Integer(i) => {
321 2u8.hash(hasher);
322 i.hash(hasher);
323 }
324 Value::UnsignedInteger(u) => {
325 3u8.hash(hasher);
326 u.hash(hasher);
327 }
328 Value::Float(f) => {
329 4u8.hash(hasher);
330 f.to_bits().hash(hasher);
331 }
332 Value::Text(s) => {
333 5u8.hash(hasher);
334 s.hash(hasher);
335 }
336 Value::Blob(b) => {
337 6u8.hash(hasher);
338 b.hash(hasher);
339 }
340 Value::Timestamp(t) => {
341 7u8.hash(hasher);
342 t.hash(hasher);
343 }
344 Value::Duration(d) => {
345 8u8.hash(hasher);
346 d.hash(hasher);
347 }
348 Value::IpAddr(addr) => {
349 9u8.hash(hasher);
350 match addr {
351 std::net::IpAddr::V4(v4) => v4.octets().hash(hasher),
352 std::net::IpAddr::V6(v6) => v6.octets().hash(hasher),
353 }
354 }
355 Value::MacAddr(mac) => {
356 10u8.hash(hasher);
357 mac.hash(hasher);
358 }
359 Value::Vector(v) => {
360 11u8.hash(hasher);
361 v.len().hash(hasher);
362 for f in v {
363 f.to_bits().hash(hasher);
364 }
365 }
366 Value::Json(j) => {
367 12u8.hash(hasher);
368 j.hash(hasher);
369 }
370 Value::Uuid(u) => {
371 13u8.hash(hasher);
372 u.hash(hasher);
373 }
374 Value::NodeRef(n) => {
375 14u8.hash(hasher);
376 n.hash(hasher);
377 }
378 Value::EdgeRef(e) => {
379 15u8.hash(hasher);
380 e.hash(hasher);
381 }
382 Value::VectorRef(coll, id) => {
383 16u8.hash(hasher);
384 coll.hash(hasher);
385 id.hash(hasher);
386 }
387 Value::RowRef(table, id) => {
388 17u8.hash(hasher);
389 table.hash(hasher);
390 id.hash(hasher);
391 }
392 Value::Color(rgb) => {
393 18u8.hash(hasher);
394 rgb.hash(hasher);
395 }
396 Value::Email(s) => {
397 19u8.hash(hasher);
398 s.hash(hasher);
399 }
400 Value::Url(s) => {
401 20u8.hash(hasher);
402 s.hash(hasher);
403 }
404 Value::Phone(n) => {
405 21u8.hash(hasher);
406 n.hash(hasher);
407 }
408 Value::Semver(v) => {
409 22u8.hash(hasher);
410 v.hash(hasher);
411 }
412 Value::Cidr(ip, prefix) => {
413 23u8.hash(hasher);
414 ip.hash(hasher);
415 prefix.hash(hasher);
416 }
417 Value::Date(d) => {
418 24u8.hash(hasher);
419 d.hash(hasher);
420 }
421 Value::Time(t) => {
422 25u8.hash(hasher);
423 t.hash(hasher);
424 }
425 Value::Decimal(v) => {
426 26u8.hash(hasher);
427 v.hash(hasher);
428 }
429 Value::EnumValue(i) => {
430 27u8.hash(hasher);
431 i.hash(hasher);
432 }
433 Value::Array(elems) => {
434 28u8.hash(hasher);
435 elems.len().hash(hasher);
436 for elem in elems {
437 Self::hash_value(elem, hasher);
438 }
439 }
440 Value::TimestampMs(v) => {
441 29u8.hash(hasher);
442 v.hash(hasher);
443 }
444 Value::Ipv4(v) => {
445 30u8.hash(hasher);
446 v.hash(hasher);
447 }
448 Value::Ipv6(bytes) => {
449 31u8.hash(hasher);
450 bytes.hash(hasher);
451 }
452 Value::Subnet(ip, mask) => {
453 32u8.hash(hasher);
454 ip.hash(hasher);
455 mask.hash(hasher);
456 }
457 Value::Port(v) => {
458 33u8.hash(hasher);
459 v.hash(hasher);
460 }
461 Value::Latitude(v) => {
462 34u8.hash(hasher);
463 v.hash(hasher);
464 }
465 Value::Longitude(v) => {
466 35u8.hash(hasher);
467 v.hash(hasher);
468 }
469 Value::GeoPoint(lat, lon) => {
470 36u8.hash(hasher);
471 lat.hash(hasher);
472 lon.hash(hasher);
473 }
474 Value::Country2(c) => {
475 37u8.hash(hasher);
476 c.hash(hasher);
477 }
478 Value::Country3(c) => {
479 38u8.hash(hasher);
480 c.hash(hasher);
481 }
482 Value::Lang2(c) => {
483 39u8.hash(hasher);
484 c.hash(hasher);
485 }
486 Value::Lang5(c) => {
487 40u8.hash(hasher);
488 c.hash(hasher);
489 }
490 Value::Currency(c) => {
491 41u8.hash(hasher);
492 c.hash(hasher);
493 }
494 Value::AssetCode(code) => {
495 50u8.hash(hasher);
496 code.hash(hasher);
497 }
498 Value::Money {
499 asset_code,
500 minor_units,
501 scale,
502 } => {
503 51u8.hash(hasher);
504 asset_code.hash(hasher);
505 minor_units.hash(hasher);
506 scale.hash(hasher);
507 }
508 Value::ColorAlpha(rgba) => {
509 42u8.hash(hasher);
510 rgba.hash(hasher);
511 }
512 Value::BigInt(v) => {
513 43u8.hash(hasher);
514 v.hash(hasher);
515 }
516 Value::KeyRef(col, key) => {
517 44u8.hash(hasher);
518 col.hash(hasher);
519 key.hash(hasher);
520 }
521 Value::DocRef(col, id) => {
522 45u8.hash(hasher);
523 col.hash(hasher);
524 id.hash(hasher);
525 }
526 Value::TableRef(name) => {
527 46u8.hash(hasher);
528 name.hash(hasher);
529 }
530 Value::PageRef(page_id) => {
531 47u8.hash(hasher);
532 page_id.hash(hasher);
533 }
534 Value::Secret(bytes) => {
535 48u8.hash(hasher);
536 bytes.hash(hasher);
537 }
538 Value::Password(hash) => {
539 49u8.hash(hasher);
540 hash.hash(hasher);
541 }
542 }
543 }
544}
545
546pub fn split_union_parts(query: &QueryExpr) -> Option<(QueryExpr, QueryExpr)> {
548 let _ = query;
551 None
552}
553
554pub fn inline_ctes(query: QueryWithCte) -> Result<QueryExpr, ExecutionError> {
573 let Some(with_clause) = query.with_clause else {
574 return Ok(query.query);
575 };
576 if with_clause.has_recursive {
577 return Err(ExecutionError::new(
578 "WITH RECURSIVE is not yet supported by the executor; \
579 non-recursive WITH clauses run today, recursive support \
580 is tracked separately"
581 .to_string(),
582 ));
583 }
584
585 let mut resolved: HashMap<String, QueryExpr> = HashMap::new();
589 for cte in &with_clause.ctes {
590 let mut body = (*cte.query).clone();
591 rewrite(&mut body, &resolved);
592 resolved.insert(cte.name.clone(), body);
593 }
594
595 let mut outer = query.query;
596 rewrite(&mut outer, &resolved);
597 Ok(outer)
598}
599
600fn rewrite(expr: &mut QueryExpr, ctes: &HashMap<String, QueryExpr>) {
610 use super::super::ast::TableSource;
611 match expr {
612 QueryExpr::Table(tq) => {
613 let lookup_name = match &tq.source {
614 Some(TableSource::Subquery(_)) => None,
615 Some(TableSource::Name(n)) => Some(n.clone()),
616 None => Some(tq.table.clone()),
617 };
618
619 if let Some(name) = lookup_name {
620 if let Some(body) = ctes.get(&name) {
621 let outer_has_constraints = tq.filter.is_some()
622 || tq.where_expr.is_some()
623 || tq.limit.is_some()
624 || tq.offset.is_some()
625 || !tq.columns.is_empty()
626 || !tq.select_items.is_empty()
627 || !tq.group_by.is_empty()
628 || !tq.order_by.is_empty();
629
630 if outer_has_constraints {
631 tq.source = Some(TableSource::Subquery(Box::new(body.clone())));
637 tq.table = format!("__cte_{name}");
638 } else {
639 *expr = body.clone();
643 }
644 return;
645 }
646 }
647
648 if let Some(TableSource::Subquery(body)) = tq.source.as_mut() {
649 rewrite(body, ctes);
650 }
651 }
652 QueryExpr::Join(jq) => {
653 rewrite(&mut jq.left, ctes);
654 rewrite(&mut jq.right, ctes);
655 }
656 _ => {}
657 }
658}
659
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::query::ast::CteQueryBuilder;
    use crate::storage::query::WithClause;

    /// Callback that ignores its inputs and yields an empty result.
    fn mock_execute(
        _query: &QueryExpr,
        _ctx: &CteContext,
    ) -> Result<UnifiedResult, ExecutionError> {
        Ok(UnifiedResult::empty())
    }

    #[test]
    fn test_cte_context() {
        let mut ctx = CteContext::new();

        assert!(ctx.get("test").is_none());
        assert!(!ctx.is_evaluating("test"));

        let result = UnifiedResult::with_columns(vec!["col1".to_string()]);
        ctx.store("test".to_string(), result);
        assert!(ctx.get("test").is_some());

        ctx.start_evaluating("other");
        assert!(ctx.is_evaluating("other"));
        ctx.done_evaluating("other");
        assert!(!ctx.is_evaluating("other"));
    }

    #[test]
    fn test_simple_cte_execution() {
        let executor = CteExecutor::new(|_query, _ctx| {
            let mut result = UnifiedResult::with_columns(vec!["id".to_string()]);
            let mut record = UnifiedRecord::new();
            record.set("id", Value::Integer(1));
            result.push(record);
            Ok(result)
        });

        let cte = CteDefinition {
            name: "test_cte".to_string(),
            columns: vec!["id".to_string()],
            query: Box::new(QueryExpr::table("dummy").build()),
            recursive: false,
        };

        let with_clause = WithClause::new().add(cte);
        let query = QueryWithCte::with_ctes(with_clause, QueryExpr::table("test_cte").build());

        // The callback yields exactly one row for the main query.
        let result = executor.execute(&query);
        assert_eq!(result.map(|r| r.len()).ok(), Some(1));
    }

    #[test]
    fn test_cte_builder() {
        let query = CteQueryBuilder::new()
            .cte_with_columns(
                "nums",
                vec!["n".to_string()],
                QueryExpr::table("numbers").build(),
            )
            .build(QueryExpr::table("nums").build());

        assert!(query.with_clause.is_some());
        let with_clause = query.with_clause.unwrap();
        assert_eq!(with_clause.ctes.len(), 1);
        assert_eq!(with_clause.ctes[0].name, "nums");
    }

    #[test]
    fn test_recursive_cte_builder() {
        let query = CteQueryBuilder::new()
            .recursive_cte("paths", QueryExpr::table("connections").build())
            .build(QueryExpr::table("paths").build());

        assert!(query.with_clause.is_some());
        let with_clause = query.with_clause.unwrap();
        assert!(with_clause.has_recursive);
        assert!(with_clause.ctes[0].recursive);
    }

    #[test]
    fn test_circular_reference_detection() {
        let executor = CteExecutor::new(mock_execute);
        let cte = CteDefinition {
            name: "cte_a".to_string(),
            columns: vec![],
            query: Box::new(QueryExpr::table("dummy").build()),
            recursive: false,
        };

        let mut ctx = CteContext::new();
        ctx.start_evaluating("cte_a");
        assert!(ctx.is_evaluating("cte_a"));

        // Materializing a CTE that is already in flight must be rejected
        // and must not store anything under its name.
        assert!(executor.materialize_cte(&cte, &mut ctx).is_err());
        assert!(ctx.get("cte_a").is_none());
    }

    #[test]
    fn test_cte_stats() {
        let ctx = CteContext::new();
        let stats = ctx.stats();

        assert_eq!(stats.ctes_executed, 0);
        assert_eq!(stats.recursive_iterations, 0);
        assert_eq!(stats.rows_produced, 0);
    }

    #[test]
    fn test_hash_record() {
        let executor = CteExecutor::new(mock_execute);

        let mut record1 = UnifiedRecord::new();
        record1.set("id", Value::Integer(1));
        record1.set("name", Value::text("test".to_string()));

        let mut record2 = UnifiedRecord::new();
        record2.set("id", Value::Integer(1));
        record2.set("name", Value::text("test".to_string()));

        let mut record3 = UnifiedRecord::new();
        record3.set("id", Value::Integer(2));
        record3.set("name", Value::text("test".to_string()));

        assert_eq!(
            executor.hash_record(&record1),
            executor.hash_record(&record2)
        );

        assert_ne!(
            executor.hash_record(&record1),
            executor.hash_record(&record3)
        );
    }

    #[test]
    fn test_hash_various_value_types() {
        let executor = CteExecutor::new(mock_execute);

        // Builds the same mixed-type record, varying only the integer field.
        let build = |int_val| {
            let mut record = UnifiedRecord::new();
            record.set("null_val", Value::Null);
            record.set("bool_val", Value::Boolean(true));
            record.set("int_val", Value::Integer(int_val));
            record.set("float_val", Value::Float(2.5));
            record.set("text_val", Value::text("hello".to_string()));
            record.set("blob_val", Value::Blob(vec![1, 2, 3]));
            record.set("timestamp_val", Value::Timestamp(1234567890));
            record.set("duration_val", Value::Duration(5000));
            record
        };

        // Equal content hashes equally; changing one field changes the hash.
        let baseline = executor.hash_record(&build(42));
        assert_eq!(baseline, executor.hash_record(&build(42)));
        assert_ne!(baseline, executor.hash_record(&build(43)));

        // The variant tag separates equal payloads of different types.
        let mut signed = UnifiedRecord::new();
        signed.set("v", Value::Integer(1));
        let mut unsigned = UnifiedRecord::new();
        unsigned.set("v", Value::UnsignedInteger(1));
        assert_ne!(
            executor.hash_record(&signed),
            executor.hash_record(&unsigned)
        );
    }

    #[test]
    fn test_project_columns() {
        let executor = CteExecutor::new(mock_execute);

        let mut original =
            UnifiedResult::with_columns(vec!["a".to_string(), "b".to_string(), "c".to_string()]);

        let mut record = UnifiedRecord::new();
        record.set("a", Value::Integer(1));
        record.set("b", Value::Integer(2));
        record.set("c", Value::Integer(3));
        original.push(record);

        let projected = executor.project_columns(&original, &["x".to_string(), "y".to_string()]);

        assert_eq!(projected.columns, vec!["x", "y"]);
        assert_eq!(projected.len(), 1);

        // Values are carried over by position: a -> x, b -> y.
        for row in &projected.records {
            assert!(matches!(row.get("x"), Some(Value::Integer(1))));
            assert!(matches!(row.get("y"), Some(Value::Integer(2))));
        }
    }

    #[test]
    fn test_empty_columns_projection() {
        let executor = CteExecutor::new(mock_execute);

        let original = UnifiedResult::with_columns(vec!["a".to_string()]);

        let projected = executor.project_columns(&original, &[]);
        assert_eq!(projected.columns, original.columns);
    }

    #[test]
    fn test_cte_with_multiple_definitions() {
        let executor = CteExecutor::new(|query, ctx| match query {
            QueryExpr::Table(t) if t.table == "base" => {
                let mut result = UnifiedResult::with_columns(vec!["id".to_string()]);
                let mut record = UnifiedRecord::new();
                record.set("id", Value::Integer(1));
                result.push(record);
                Ok(result)
            }
            QueryExpr::Table(t) if t.table == "cte1" => Ok(ctx
                .get("cte1")
                .cloned()
                .unwrap_or_else(UnifiedResult::empty)),
            _ => Ok(UnifiedResult::empty()),
        });

        let cte1 = CteDefinition {
            name: "cte1".to_string(),
            columns: vec!["id".to_string()],
            query: Box::new(QueryExpr::table("base").build()),
            recursive: false,
        };

        let cte2 = CteDefinition {
            name: "cte2".to_string(),
            columns: vec!["id".to_string()],
            query: Box::new(QueryExpr::table("cte1").build()),
            recursive: false,
        };

        let with_clause = WithClause::new().add(cte1).add(cte2);
        let query = QueryWithCte::with_ctes(with_clause, QueryExpr::table("cte2").build());

        let result = executor.execute(&query);
        assert!(result.is_ok());
    }

    #[test]
    fn test_recursive_cte_reaches_fixpoint() {
        // "seed" yields n = 1 while no "counter" table exists yet; afterwards
        // it yields n + 1 for every visible row with n < 5.
        let executor = CteExecutor::new(|query, ctx| match query {
            QueryExpr::Table(t) if t.table == "seed" => {
                let mut next = UnifiedResult::with_columns(vec!["n".to_string()]);
                match ctx.get("counter") {
                    None => {
                        let mut row = UnifiedRecord::new();
                        row.set("n", Value::Integer(1));
                        next.push(row);
                    }
                    Some(previous) => {
                        for row in &previous.records {
                            if let Some(Value::Integer(n)) = row.get("n") {
                                if *n < 5 {
                                    let mut successor = UnifiedRecord::new();
                                    successor.set("n", Value::Integer(*n + 1));
                                    next.push(successor);
                                }
                            }
                        }
                    }
                }
                Ok(next)
            }
            QueryExpr::Table(t) if t.table == "counter" => Ok(ctx
                .get("counter")
                .cloned()
                .unwrap_or_else(UnifiedResult::empty)),
            _ => Ok(UnifiedResult::empty()),
        });

        let cte = CteDefinition {
            name: "counter".to_string(),
            columns: vec!["n".to_string()],
            query: Box::new(QueryExpr::table("seed").build()),
            recursive: true,
        };

        let with_clause = WithClause::new().add(cte);
        let query = QueryWithCte::with_ctes(with_clause, QueryExpr::table("counter").build());

        // Rows 1..=5 are produced exactly once each.
        let result = executor.execute(&query);
        assert_eq!(result.map(|r| r.len()).ok(), Some(5));
    }
}