reddb_server/storage/query/executors/
cte.rs1use std::collections::{HashMap, HashSet};
29
30use super::super::ast::{CteDefinition, QueryExpr, QueryWithCte};
31use super::super::unified::{ExecutionError, UnifiedRecord, UnifiedResult};
32use crate::storage::schema::Value;
33
/// Upper bound on the number of fixpoint iterations a single recursive CTE may
/// run; `execute_recursive_cte` reports an error when it is reached with rows
/// still pending.
const MAX_RECURSION_DEPTH: usize = 1000;

/// Upper bound on the number of rows a single recursive CTE may accumulate;
/// `execute_recursive_cte` reports an error when it is exceeded.
const MAX_RECURSIVE_ROWS: usize = 100_000;
39
/// Scratch state shared between the CTE executor and the query callback while
/// the CTEs of one query are being materialized.
#[derive(Debug, Clone, Default)]
pub struct CteContext {
    /// Materialized CTE results, keyed by CTE name.
    tables: HashMap<String, UnifiedResult>,
    /// Names of CTEs whose evaluation has started but not finished; consulted
    /// to detect circular references.
    evaluating: HashSet<String>,
    /// Counters accumulated while CTEs are executed.
    stats: CteStats,
}
50
51impl CteContext {
52 pub fn new() -> Self {
54 Self::default()
55 }
56
57 pub fn get(&self, name: &str) -> Option<&UnifiedResult> {
59 self.tables.get(name)
60 }
61
62 pub fn store(&mut self, name: String, result: UnifiedResult) {
64 self.tables.insert(name, result);
65 }
66
67 pub fn is_evaluating(&self, name: &str) -> bool {
69 self.evaluating.contains(name)
70 }
71
72 pub fn start_evaluating(&mut self, name: &str) {
74 self.evaluating.insert(name.to_string());
75 }
76
77 pub fn done_evaluating(&mut self, name: &str) {
79 self.evaluating.remove(name);
80 }
81
82 pub fn stats(&self) -> &CteStats {
84 &self.stats
85 }
86}
87
/// Counters describing the work done while executing CTEs.
#[derive(Debug, Clone, Default)]
pub struct CteStats {
    /// Number of CTEs that were materialized.
    pub ctes_executed: usize,
    /// Total number of loop iterations run across recursive CTEs.
    pub recursive_iterations: usize,
    /// Sum of the row counts of all materialized CTE results.
    pub rows_produced: usize,
    /// Elapsed execution time in microseconds, as recorded by the executor.
    pub exec_time_us: u64,
}
100
/// Executes queries that carry a WITH clause by materializing each CTE and
/// then delegating the actual query evaluation to a caller-supplied callback.
///
/// `F` receives the query to run plus the current [`CteContext`] and returns
/// either the result rows or an [`ExecutionError`].
pub struct CteExecutor<F>
where
    F: Fn(&QueryExpr, &CteContext) -> Result<UnifiedResult, ExecutionError>,
{
    /// Callback used to evaluate both CTE bodies and the main query.
    execute_fn: F,
}
109
110impl<F> CteExecutor<F>
111where
112 F: Fn(&QueryExpr, &CteContext) -> Result<UnifiedResult, ExecutionError>,
113{
114 pub fn new(execute_fn: F) -> Self {
116 Self { execute_fn }
117 }
118
119 pub fn execute(&self, query: &QueryWithCte) -> Result<UnifiedResult, ExecutionError> {
121 let start = std::time::Instant::now();
122 let mut ctx = CteContext::new();
123
124 if let Some(ref with_clause) = query.with_clause {
126 for cte in &with_clause.ctes {
127 self.materialize_cte(cte, &mut ctx)?;
128 }
129 }
130
131 let result = (self.execute_fn)(&query.query, &ctx)?;
133
134 ctx.stats.exec_time_us = start.elapsed().as_micros() as u64;
135 Ok(result)
136 }
137
138 fn materialize_cte(
140 &self,
141 cte: &CteDefinition,
142 ctx: &mut CteContext,
143 ) -> Result<(), ExecutionError> {
144 if ctx.is_evaluating(&cte.name) {
145 return Err(ExecutionError::new(format!(
146 "Circular CTE reference: {}",
147 cte.name
148 )));
149 }
150
151 if ctx.get(&cte.name).is_some() {
153 return Ok(());
154 }
155
156 ctx.start_evaluating(&cte.name);
157
158 let result = if cte.recursive {
159 self.execute_recursive_cte(cte, ctx)?
160 } else {
161 let result = (self.execute_fn)(&cte.query, ctx)?;
163 self.project_columns(&result, &cte.columns)
164 };
165
166 ctx.stats.ctes_executed += 1;
167 ctx.stats.rows_produced += result.len();
168 ctx.store(cte.name.clone(), result);
169 ctx.done_evaluating(&cte.name);
170
171 Ok(())
172 }
173
174 fn execute_recursive_cte(
176 &self,
177 cte: &CteDefinition,
178 ctx: &mut CteContext,
179 ) -> Result<UnifiedResult, ExecutionError> {
180 let mut all_results = UnifiedResult::with_columns(cte.columns.clone());
197 let mut working_table = UnifiedResult::with_columns(cte.columns.clone());
198 let mut seen_rows: HashSet<u64> = HashSet::new();
199 let mut iteration = 0;
200
201 let initial = (self.execute_fn)(&cte.query, ctx)?;
203 let initial = self.project_columns(&initial, &cte.columns);
204
205 for record in &initial.records {
206 let hash = self.hash_record(record);
207 if seen_rows.insert(hash) {
208 working_table.push(record.clone());
209 all_results.push(record.clone());
210 }
211 }
212
213 ctx.store(cte.name.clone(), working_table.clone());
215
216 while !working_table.is_empty() && iteration < MAX_RECURSION_DEPTH {
218 iteration += 1;
219 ctx.stats.recursive_iterations += 1;
220
221 if all_results.len() > MAX_RECURSIVE_ROWS {
222 return Err(ExecutionError::new(format!(
223 "Recursive CTE '{}' exceeded maximum rows ({})",
224 cte.name, MAX_RECURSIVE_ROWS
225 )));
226 }
227
228 let new_results = (self.execute_fn)(&cte.query, ctx)?;
230 let new_results = self.project_columns(&new_results, &cte.columns);
231
232 let mut new_working_table = UnifiedResult::with_columns(cte.columns.clone());
234 for record in &new_results.records {
235 let hash = self.hash_record(record);
236 if seen_rows.insert(hash) {
237 new_working_table.push(record.clone());
238 all_results.push(record.clone());
239 }
240 }
241
242 working_table = new_working_table;
243
244 ctx.store(cte.name.clone(), all_results.clone());
246 }
247
248 if iteration >= MAX_RECURSION_DEPTH && !working_table.is_empty() {
249 return Err(ExecutionError::new(format!(
250 "Recursive CTE '{}' exceeded maximum recursion depth ({})",
251 cte.name, MAX_RECURSION_DEPTH
252 )));
253 }
254
255 Ok(all_results)
256 }
257
258 fn project_columns(&self, result: &UnifiedResult, columns: &[String]) -> UnifiedResult {
260 if columns.is_empty() {
261 return result.clone();
262 }
263
264 let mut projected = UnifiedResult::with_columns(columns.to_vec());
265
266 for record in &result.records {
267 let mut new_record = UnifiedRecord::new();
268
269 for (i, col) in columns.iter().enumerate() {
271 let value = result
273 .columns
274 .get(i)
275 .and_then(|orig_col| record.get(orig_col))
276 .cloned()
277 .or_else(|| record.get(col).cloned())
278 .unwrap_or(Value::Null);
279
280 new_record.set(col, value);
281 }
282
283 projected.push(new_record);
284 }
285
286 projected
287 }
288
289 fn hash_record(&self, record: &UnifiedRecord) -> u64 {
291 use std::collections::hash_map::DefaultHasher;
292 use std::hash::{Hash, Hasher};
293
294 let mut hasher = DefaultHasher::new();
295
296 let mut keys = record.column_names();
298 keys.sort();
299
300 for key in &keys {
301 (**key).hash(&mut hasher);
302 if let Some(value) = record.get(key) {
303 Self::hash_value(value, &mut hasher);
304 }
305 }
306
307 hasher.finish()
308 }
309
/// Feeds `value` into `hasher`.
///
/// Every variant first writes a hand-assigned `u8` tag and then its payload,
/// so equal payloads held by different variants do not produce the same byte
/// stream. The tags are unique but neither contiguous nor in source order
/// (`AssetCode` and `Money` use 50 and 51). The match has no wildcard arm, so
/// adding a `Value` variant forces this function to be updated.
fn hash_value(value: &Value, hasher: &mut impl std::hash::Hasher) {
    use std::hash::Hash;

    match value {
        Value::Null => 0u8.hash(hasher),
        Value::Boolean(b) => {
            1u8.hash(hasher);
            b.hash(hasher);
        }
        Value::Integer(i) => {
            2u8.hash(hasher);
            i.hash(hasher);
        }
        Value::UnsignedInteger(u) => {
            3u8.hash(hasher);
            u.hash(hasher);
        }
        Value::Float(f) => {
            4u8.hash(hasher);
            // Floats are hashed by raw bit pattern, so values with different
            // bit patterns (e.g. 0.0 and -0.0) hash differently.
            f.to_bits().hash(hasher);
        }
        Value::Text(s) => {
            5u8.hash(hasher);
            s.hash(hasher);
        }
        Value::Blob(b) => {
            6u8.hash(hasher);
            b.hash(hasher);
        }
        Value::Timestamp(t) => {
            7u8.hash(hasher);
            t.hash(hasher);
        }
        Value::Duration(d) => {
            8u8.hash(hasher);
            d.hash(hasher);
        }
        Value::IpAddr(addr) => {
            9u8.hash(hasher);
            // Only the address octets are hashed; the V4/V6 distinction is
            // carried by the differing octet array lengths.
            match addr {
                std::net::IpAddr::V4(v4) => v4.octets().hash(hasher),
                std::net::IpAddr::V6(v6) => v6.octets().hash(hasher),
            }
        }
        Value::MacAddr(mac) => {
            10u8.hash(hasher);
            mac.hash(hasher);
        }
        Value::Vector(v) => {
            11u8.hash(hasher);
            // Length prefix, then each component by bit pattern.
            v.len().hash(hasher);
            for f in v {
                f.to_bits().hash(hasher);
            }
        }
        Value::Json(j) => {
            12u8.hash(hasher);
            j.hash(hasher);
        }
        Value::Uuid(u) => {
            13u8.hash(hasher);
            u.hash(hasher);
        }
        Value::NodeRef(n) => {
            14u8.hash(hasher);
            n.hash(hasher);
        }
        Value::EdgeRef(e) => {
            15u8.hash(hasher);
            e.hash(hasher);
        }
        Value::VectorRef(coll, id) => {
            16u8.hash(hasher);
            coll.hash(hasher);
            id.hash(hasher);
        }
        Value::RowRef(table, id) => {
            17u8.hash(hasher);
            table.hash(hasher);
            id.hash(hasher);
        }
        Value::Color(rgb) => {
            18u8.hash(hasher);
            rgb.hash(hasher);
        }
        Value::Email(s) => {
            19u8.hash(hasher);
            s.hash(hasher);
        }
        Value::Url(s) => {
            20u8.hash(hasher);
            s.hash(hasher);
        }
        Value::Phone(n) => {
            21u8.hash(hasher);
            n.hash(hasher);
        }
        Value::Semver(v) => {
            22u8.hash(hasher);
            v.hash(hasher);
        }
        Value::Cidr(ip, prefix) => {
            23u8.hash(hasher);
            ip.hash(hasher);
            prefix.hash(hasher);
        }
        Value::Date(d) => {
            24u8.hash(hasher);
            d.hash(hasher);
        }
        Value::Time(t) => {
            25u8.hash(hasher);
            t.hash(hasher);
        }
        Value::Decimal(v) => {
            26u8.hash(hasher);
            v.hash(hasher);
        }
        Value::EnumValue(i) => {
            27u8.hash(hasher);
            i.hash(hasher);
        }
        Value::Array(elems) => {
            28u8.hash(hasher);
            // Length prefix, then each element recursively (with its own tag).
            elems.len().hash(hasher);
            for elem in elems {
                Self::hash_value(elem, hasher);
            }
        }
        Value::TimestampMs(v) => {
            29u8.hash(hasher);
            v.hash(hasher);
        }
        Value::Ipv4(v) => {
            30u8.hash(hasher);
            v.hash(hasher);
        }
        Value::Ipv6(bytes) => {
            31u8.hash(hasher);
            bytes.hash(hasher);
        }
        Value::Subnet(ip, mask) => {
            32u8.hash(hasher);
            ip.hash(hasher);
            mask.hash(hasher);
        }
        Value::Port(v) => {
            33u8.hash(hasher);
            v.hash(hasher);
        }
        Value::Latitude(v) => {
            34u8.hash(hasher);
            v.hash(hasher);
        }
        Value::Longitude(v) => {
            35u8.hash(hasher);
            v.hash(hasher);
        }
        Value::GeoPoint(lat, lon) => {
            36u8.hash(hasher);
            lat.hash(hasher);
            lon.hash(hasher);
        }
        Value::Country2(c) => {
            37u8.hash(hasher);
            c.hash(hasher);
        }
        Value::Country3(c) => {
            38u8.hash(hasher);
            c.hash(hasher);
        }
        Value::Lang2(c) => {
            39u8.hash(hasher);
            c.hash(hasher);
        }
        Value::Lang5(c) => {
            40u8.hash(hasher);
            c.hash(hasher);
        }
        Value::Currency(c) => {
            41u8.hash(hasher);
            c.hash(hasher);
        }
        // Tags 50 and 51 are intentionally outside the 42..=49 run used below.
        Value::AssetCode(code) => {
            50u8.hash(hasher);
            code.hash(hasher);
        }
        Value::Money {
            asset_code,
            minor_units,
            scale,
        } => {
            51u8.hash(hasher);
            asset_code.hash(hasher);
            minor_units.hash(hasher);
            scale.hash(hasher);
        }
        Value::ColorAlpha(rgba) => {
            42u8.hash(hasher);
            rgba.hash(hasher);
        }
        Value::BigInt(v) => {
            43u8.hash(hasher);
            v.hash(hasher);
        }
        Value::KeyRef(col, key) => {
            44u8.hash(hasher);
            col.hash(hasher);
            key.hash(hasher);
        }
        Value::DocRef(col, id) => {
            45u8.hash(hasher);
            col.hash(hasher);
            id.hash(hasher);
        }
        Value::TableRef(name) => {
            46u8.hash(hasher);
            name.hash(hasher);
        }
        Value::PageRef(page_id) => {
            47u8.hash(hasher);
            page_id.hash(hasher);
        }
        Value::Secret(bytes) => {
            48u8.hash(hasher);
            bytes.hash(hasher);
        }
        Value::Password(hash) => {
            49u8.hash(hasher);
            hash.hash(hasher);
        }
    }
}
544}
545
546pub fn split_union_parts(query: &QueryExpr) -> Option<(QueryExpr, QueryExpr)> {
548 let _ = query;
551 None
552}
553
554pub fn inline_ctes(query: QueryWithCte) -> Result<QueryExpr, ExecutionError> {
573 let Some(with_clause) = query.with_clause else {
574 return Ok(query.query);
575 };
576 if with_clause.has_recursive {
577 return Err(ExecutionError::new(
578 "WITH RECURSIVE is not yet supported by the executor; \
579 non-recursive WITH clauses run today, recursive support \
580 is tracked separately"
581 .to_string(),
582 ));
583 }
584
585 let mut resolved: HashMap<String, QueryExpr> = HashMap::new();
589 for cte in &with_clause.ctes {
590 let mut body = (*cte.query).clone();
591 rewrite(&mut body, &resolved);
592 resolved.insert(cte.name.clone(), body);
593 }
594
595 let mut outer = query.query;
596 rewrite(&mut outer, &resolved);
597 Ok(outer)
598}
599
600fn rewrite(expr: &mut QueryExpr, ctes: &HashMap<String, QueryExpr>) {
610 use super::super::ast::TableSource;
611 match expr {
612 QueryExpr::Table(tq) => {
613 let lookup_name = match &tq.source {
614 Some(TableSource::Subquery(_)) => None,
615 Some(TableSource::Name(n)) => Some(n.clone()),
616 Some(TableSource::Function { .. } | TableSource::InlineGraphFunction { .. }) => {
619 None
620 }
621 None => Some(tq.table.clone()),
622 };
623
624 if let Some(name) = lookup_name {
625 if let Some(body) = ctes.get(&name) {
626 let outer_has_constraints = tq.filter.is_some()
627 || tq.where_expr.is_some()
628 || tq.limit.is_some()
629 || tq.offset.is_some()
630 || !tq.columns.is_empty()
631 || !tq.select_items.is_empty()
632 || !tq.group_by.is_empty()
633 || !tq.order_by.is_empty();
634
635 if outer_has_constraints {
636 tq.source = Some(TableSource::Subquery(Box::new(body.clone())));
642 tq.table = format!("__cte_{name}");
643 } else {
644 *expr = body.clone();
648 }
649 return;
650 }
651 }
652
653 if let Some(TableSource::Subquery(body)) = tq.source.as_mut() {
654 rewrite(body, ctes);
655 }
656 }
657 QueryExpr::Join(jq) => {
658 rewrite(&mut jq.left, ctes);
659 rewrite(&mut jq.right, ctes);
660 }
661 _ => {}
662 }
663}
664
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::query::ast::CteQueryBuilder;
    use crate::storage::query::WithClause;

    /// Executor callback that ignores its inputs and yields an empty result.
    fn mock_execute(
        _query: &QueryExpr,
        _ctx: &CteContext,
    ) -> Result<UnifiedResult, ExecutionError> {
        Ok(UnifiedResult::empty())
    }

    /// Storing tables and toggling the in-progress marker round-trips.
    #[test]
    fn test_cte_context() {
        let mut ctx = CteContext::new();

        assert!(ctx.get("test").is_none());
        assert!(!ctx.is_evaluating("test"));

        let result = UnifiedResult::with_columns(vec!["col1".to_string()]);
        ctx.store("test".to_string(), result);
        assert!(ctx.get("test").is_some());

        ctx.start_evaluating("other");
        assert!(ctx.is_evaluating("other"));
        ctx.done_evaluating("other");
        assert!(!ctx.is_evaluating("other"));
    }

    /// A single non-recursive CTE followed by a main query executes without error.
    #[test]
    fn test_simple_cte_execution() {
        let executor = CteExecutor::new(|_query, _ctx| {
            let mut result = UnifiedResult::with_columns(vec!["id".to_string()]);
            let mut record = UnifiedRecord::new();
            record.set("id", Value::Integer(1));
            result.push(record);
            Ok(result)
        });

        let cte = CteDefinition {
            name: "test_cte".to_string(),
            columns: vec!["id".to_string()],
            query: Box::new(QueryExpr::table("dummy").build()),
            recursive: false,
        };

        let with_clause = WithClause::new().add(cte);
        let query = QueryWithCte::with_ctes(with_clause, QueryExpr::table("test_cte").build());

        let result = executor.execute(&query);
        assert!(result.is_ok());
    }

    /// The builder records the CTE under the requested name.
    #[test]
    fn test_cte_builder() {
        let query = CteQueryBuilder::new()
            .cte_with_columns(
                "nums",
                vec!["n".to_string()],
                QueryExpr::table("numbers").build(),
            )
            .build(QueryExpr::table("nums").build());

        assert!(query.with_clause.is_some());
        let with_clause = query.with_clause.unwrap();
        assert_eq!(with_clause.ctes.len(), 1);
        assert_eq!(with_clause.ctes[0].name, "nums");
    }

    /// The builder flags both the clause and the definition as recursive.
    #[test]
    fn test_recursive_cte_builder() {
        let query = CteQueryBuilder::new()
            .recursive_cte("paths", QueryExpr::table("connections").build())
            .build(QueryExpr::table("paths").build());

        assert!(query.with_clause.is_some());
        let with_clause = query.with_clause.unwrap();
        assert!(with_clause.has_recursive);
        assert!(with_clause.ctes[0].recursive);
    }

    /// A CTE that is already marked in-progress is rejected by the executor.
    #[test]
    fn test_circular_reference_detection() {
        let mut ctx = CteContext::new();
        ctx.start_evaluating("cte_a");

        assert!(ctx.is_evaluating("cte_a"));

        // Exercise the executor's guard, not only the bookkeeping set.
        let executor = CteExecutor::new(mock_execute);
        let cte = CteDefinition {
            name: "cte_a".to_string(),
            columns: vec![],
            query: Box::new(QueryExpr::table("dummy").build()),
            recursive: false,
        };
        assert!(executor.materialize_cte(&cte, &mut ctx).is_err());
    }

    /// A fresh context starts with zeroed counters.
    #[test]
    fn test_cte_stats() {
        let ctx = CteContext::new();
        let stats = ctx.stats();

        assert_eq!(stats.ctes_executed, 0);
        assert_eq!(stats.recursive_iterations, 0);
        assert_eq!(stats.rows_produced, 0);
    }

    /// Equal records hash equally; records differing in one value do not.
    #[test]
    fn test_hash_record() {
        let executor = CteExecutor::new(mock_execute);

        let mut record1 = UnifiedRecord::new();
        record1.set("id", Value::Integer(1));
        record1.set("name", Value::text("test".to_string()));

        let mut record2 = UnifiedRecord::new();
        record2.set("id", Value::Integer(1));
        record2.set("name", Value::text("test".to_string()));

        let mut record3 = UnifiedRecord::new();
        record3.set("id", Value::Integer(2));
        record3.set("name", Value::text("test".to_string()));

        assert_eq!(
            executor.hash_record(&record1),
            executor.hash_record(&record2)
        );

        assert_ne!(
            executor.hash_record(&record1),
            executor.hash_record(&record3)
        );
    }

    /// Hashing a record with mixed value types is independent of the order in
    /// which the fields were set, and sensitive to a changed value.
    #[test]
    fn test_hash_various_value_types() {
        let executor = CteExecutor::new(mock_execute);

        let fields = vec![
            ("null_val", Value::Null),
            ("bool_val", Value::Boolean(true)),
            ("int_val", Value::Integer(42)),
            ("float_val", Value::Float(2.5)),
            ("text_val", Value::text("hello".to_string())),
            ("blob_val", Value::Blob(vec![1, 2, 3])),
            ("timestamp_val", Value::Timestamp(1234567890)),
            ("duration_val", Value::Duration(5000)),
        ];

        let mut forward = UnifiedRecord::new();
        for (name, value) in fields.iter() {
            forward.set(*name, value.clone());
        }

        let mut backward = UnifiedRecord::new();
        for (name, value) in fields.iter().rev() {
            backward.set(*name, value.clone());
        }

        // `hash_record` sorts column names, so insertion order must not matter.
        assert_eq!(
            executor.hash_record(&forward),
            executor.hash_record(&backward)
        );

        let mut changed = UnifiedRecord::new();
        for (name, value) in fields.iter() {
            changed.set(*name, value.clone());
        }
        changed.set("int_val", Value::Integer(43));

        assert_ne!(
            executor.hash_record(&forward),
            executor.hash_record(&changed)
        );
    }

    /// Projection renames columns positionally and keeps the row count.
    #[test]
    fn test_project_columns() {
        let executor = CteExecutor::new(mock_execute);

        let mut original =
            UnifiedResult::with_columns(vec!["a".to_string(), "b".to_string(), "c".to_string()]);

        let mut record = UnifiedRecord::new();
        record.set("a", Value::Integer(1));
        record.set("b", Value::Integer(2));
        record.set("c", Value::Integer(3));
        original.push(record);

        let projected = executor.project_columns(&original, &["x".to_string(), "y".to_string()]);

        assert_eq!(projected.columns, vec!["x", "y"]);
        assert_eq!(projected.len(), 1);

        // "x" takes the first source column, "y" the second.
        for row in &projected.records {
            assert!(matches!(row.get("x"), Some(Value::Integer(1))));
            assert!(matches!(row.get("y"), Some(Value::Integer(2))));
        }
    }

    /// An empty target column list leaves the result unchanged.
    #[test]
    fn test_empty_columns_projection() {
        let executor = CteExecutor::new(mock_execute);

        let original = UnifiedResult::with_columns(vec!["a".to_string()]);

        let projected = executor.project_columns(&original, &[]);
        assert_eq!(projected.columns, original.columns);
    }

    /// A later CTE can read an earlier one through the context.
    #[test]
    fn test_cte_with_multiple_definitions() {
        let executor = CteExecutor::new(|query, ctx| match query {
            QueryExpr::Table(t) if t.table == "base" => {
                let mut result = UnifiedResult::with_columns(vec!["id".to_string()]);
                let mut record = UnifiedRecord::new();
                record.set("id", Value::Integer(1));
                result.push(record);
                Ok(result)
            }
            QueryExpr::Table(t) if t.table == "cte1" => Ok(ctx
                .get("cte1")
                .cloned()
                .unwrap_or_else(UnifiedResult::empty)),
            _ => Ok(UnifiedResult::empty()),
        });

        let cte1 = CteDefinition {
            name: "cte1".to_string(),
            columns: vec!["id".to_string()],
            query: Box::new(QueryExpr::table("base").build()),
            recursive: false,
        };

        let cte2 = CteDefinition {
            name: "cte2".to_string(),
            columns: vec!["id".to_string()],
            query: Box::new(QueryExpr::table("cte1").build()),
            recursive: false,
        };

        let with_clause = WithClause::new().add(cte1).add(cte2);
        let query = QueryWithCte::with_ctes(with_clause, QueryExpr::table("cte2").build());

        let result = executor.execute(&query);
        assert!(result.is_ok());
    }
}