use std::collections::{HashMap, HashSet};
use super::super::ast::{CteDefinition, QueryExpr, QueryWithCte};
use super::super::unified::{ExecutionError, UnifiedRecord, UnifiedResult};
use crate::storage::schema::Value;
/// Iteration limit for a single recursive CTE; evaluation fails once it is
/// reached while new rows are still being produced.
const MAX_RECURSION_DEPTH: usize = 1_000;
/// Row-count limit for the accumulated result of a single recursive CTE.
const MAX_RECURSIVE_ROWS: usize = 100_000;
#[derive(Debug, Clone, Default)]
pub struct CteContext {
tables: HashMap<String, UnifiedResult>,
evaluating: HashSet<String>,
stats: CteStats,
}
impl CteContext {
pub fn new() -> Self {
Self::default()
}
pub fn get(&self, name: &str) -> Option<&UnifiedResult> {
self.tables.get(name)
}
pub fn store(&mut self, name: String, result: UnifiedResult) {
self.tables.insert(name, result);
}
pub fn is_evaluating(&self, name: &str) -> bool {
self.evaluating.contains(name)
}
pub fn start_evaluating(&mut self, name: &str) {
self.evaluating.insert(name.to_string());
}
pub fn done_evaluating(&mut self, name: &str) {
self.evaluating.remove(name);
}
pub fn stats(&self) -> &CteStats {
&self.stats
}
}
/// Counters describing the work performed while evaluating a query's CTEs.
#[derive(Debug, Clone, Default)]
pub struct CteStats {
    /// Number of CTEs that were materialized.
    pub ctes_executed: usize,
    /// Loop iterations performed across all recursive CTEs.
    pub recursive_iterations: usize,
    /// Sum of the row counts of every materialized CTE.
    pub rows_produced: usize,
    /// Elapsed wall-clock time of the whole query evaluation, in microseconds.
    pub exec_time_us: u64,
}
/// Evaluates queries that carry a `WITH` clause. Each CTE is materialized
/// through the caller-supplied `execute_fn`, and then the main query is
/// evaluated with those tables available.
pub struct CteExecutor<F: Fn(&QueryExpr, &CteContext) -> Result<UnifiedResult, ExecutionError>> {
    /// Callback that evaluates one query against the current CTE context.
    execute_fn: F,
}
impl<F> CteExecutor<F>
where
F: Fn(&QueryExpr, &CteContext) -> Result<UnifiedResult, ExecutionError>,
{
pub fn new(execute_fn: F) -> Self {
Self { execute_fn }
}
pub fn execute(&self, query: &QueryWithCte) -> Result<UnifiedResult, ExecutionError> {
let start = std::time::Instant::now();
let mut ctx = CteContext::new();
if let Some(ref with_clause) = query.with_clause {
for cte in &with_clause.ctes {
self.materialize_cte(cte, &mut ctx)?;
}
}
let result = (self.execute_fn)(&query.query, &ctx)?;
ctx.stats.exec_time_us = start.elapsed().as_micros() as u64;
Ok(result)
}
fn materialize_cte(
&self,
cte: &CteDefinition,
ctx: &mut CteContext,
) -> Result<(), ExecutionError> {
if ctx.is_evaluating(&cte.name) {
return Err(ExecutionError::new(format!(
"Circular CTE reference: {}",
cte.name
)));
}
if ctx.get(&cte.name).is_some() {
return Ok(());
}
ctx.start_evaluating(&cte.name);
let result = if cte.recursive {
self.execute_recursive_cte(cte, ctx)?
} else {
let result = (self.execute_fn)(&cte.query, ctx)?;
self.project_columns(&result, &cte.columns)
};
ctx.stats.ctes_executed += 1;
ctx.stats.rows_produced += result.len();
ctx.store(cte.name.clone(), result);
ctx.done_evaluating(&cte.name);
Ok(())
}
fn execute_recursive_cte(
&self,
cte: &CteDefinition,
ctx: &mut CteContext,
) -> Result<UnifiedResult, ExecutionError> {
let mut all_results = UnifiedResult::with_columns(cte.columns.clone());
let mut working_table = UnifiedResult::with_columns(cte.columns.clone());
let mut seen_rows: HashSet<u64> = HashSet::new();
let mut iteration = 0;
let initial = (self.execute_fn)(&cte.query, ctx)?;
let initial = self.project_columns(&initial, &cte.columns);
for record in &initial.records {
let hash = self.hash_record(record);
if seen_rows.insert(hash) {
working_table.push(record.clone());
all_results.push(record.clone());
}
}
ctx.store(cte.name.clone(), working_table.clone());
while !working_table.is_empty() && iteration < MAX_RECURSION_DEPTH {
iteration += 1;
ctx.stats.recursive_iterations += 1;
if all_results.len() > MAX_RECURSIVE_ROWS {
return Err(ExecutionError::new(format!(
"Recursive CTE '{}' exceeded maximum rows ({})",
cte.name, MAX_RECURSIVE_ROWS
)));
}
let new_results = (self.execute_fn)(&cte.query, ctx)?;
let new_results = self.project_columns(&new_results, &cte.columns);
let mut new_working_table = UnifiedResult::with_columns(cte.columns.clone());
for record in &new_results.records {
let hash = self.hash_record(record);
if seen_rows.insert(hash) {
new_working_table.push(record.clone());
all_results.push(record.clone());
}
}
working_table = new_working_table;
ctx.store(cte.name.clone(), all_results.clone());
}
if iteration >= MAX_RECURSION_DEPTH && !working_table.is_empty() {
return Err(ExecutionError::new(format!(
"Recursive CTE '{}' exceeded maximum recursion depth ({})",
cte.name, MAX_RECURSION_DEPTH
)));
}
Ok(all_results)
}
fn project_columns(&self, result: &UnifiedResult, columns: &[String]) -> UnifiedResult {
if columns.is_empty() {
return result.clone();
}
let mut projected = UnifiedResult::with_columns(columns.to_vec());
for record in &result.records {
let mut new_record = UnifiedRecord::new();
for (i, col) in columns.iter().enumerate() {
let value = result
.columns
.get(i)
.and_then(|orig_col| record.get(orig_col))
.cloned()
.or_else(|| record.get(col).cloned())
.unwrap_or(Value::Null);
new_record.set(col, value);
}
projected.push(new_record);
}
projected
}
fn hash_record(&self, record: &UnifiedRecord) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
let mut keys = record.column_names();
keys.sort();
for key in &keys {
(**key).hash(&mut hasher);
if let Some(value) = record.get(key) {
Self::hash_value(value, &mut hasher);
}
}
hasher.finish()
}
fn hash_value(value: &Value, hasher: &mut impl std::hash::Hasher) {
use std::hash::Hash;
match value {
Value::Null => 0u8.hash(hasher),
Value::Boolean(b) => {
1u8.hash(hasher);
b.hash(hasher);
}
Value::Integer(i) => {
2u8.hash(hasher);
i.hash(hasher);
}
Value::UnsignedInteger(u) => {
3u8.hash(hasher);
u.hash(hasher);
}
Value::Float(f) => {
4u8.hash(hasher);
f.to_bits().hash(hasher);
}
Value::Text(s) => {
5u8.hash(hasher);
s.hash(hasher);
}
Value::Blob(b) => {
6u8.hash(hasher);
b.hash(hasher);
}
Value::Timestamp(t) => {
7u8.hash(hasher);
t.hash(hasher);
}
Value::Duration(d) => {
8u8.hash(hasher);
d.hash(hasher);
}
Value::IpAddr(addr) => {
9u8.hash(hasher);
match addr {
std::net::IpAddr::V4(v4) => v4.octets().hash(hasher),
std::net::IpAddr::V6(v6) => v6.octets().hash(hasher),
}
}
Value::MacAddr(mac) => {
10u8.hash(hasher);
mac.hash(hasher);
}
Value::Vector(v) => {
11u8.hash(hasher);
v.len().hash(hasher);
for f in v {
f.to_bits().hash(hasher);
}
}
Value::Json(j) => {
12u8.hash(hasher);
j.hash(hasher);
}
Value::Uuid(u) => {
13u8.hash(hasher);
u.hash(hasher);
}
Value::NodeRef(n) => {
14u8.hash(hasher);
n.hash(hasher);
}
Value::EdgeRef(e) => {
15u8.hash(hasher);
e.hash(hasher);
}
Value::VectorRef(coll, id) => {
16u8.hash(hasher);
coll.hash(hasher);
id.hash(hasher);
}
Value::RowRef(table, id) => {
17u8.hash(hasher);
table.hash(hasher);
id.hash(hasher);
}
Value::Color(rgb) => {
18u8.hash(hasher);
rgb.hash(hasher);
}
Value::Email(s) => {
19u8.hash(hasher);
s.hash(hasher);
}
Value::Url(s) => {
20u8.hash(hasher);
s.hash(hasher);
}
Value::Phone(n) => {
21u8.hash(hasher);
n.hash(hasher);
}
Value::Semver(v) => {
22u8.hash(hasher);
v.hash(hasher);
}
Value::Cidr(ip, prefix) => {
23u8.hash(hasher);
ip.hash(hasher);
prefix.hash(hasher);
}
Value::Date(d) => {
24u8.hash(hasher);
d.hash(hasher);
}
Value::Time(t) => {
25u8.hash(hasher);
t.hash(hasher);
}
Value::Decimal(v) => {
26u8.hash(hasher);
v.hash(hasher);
}
Value::EnumValue(i) => {
27u8.hash(hasher);
i.hash(hasher);
}
Value::Array(elems) => {
28u8.hash(hasher);
elems.len().hash(hasher);
for elem in elems {
Self::hash_value(elem, hasher);
}
}
Value::TimestampMs(v) => {
29u8.hash(hasher);
v.hash(hasher);
}
Value::Ipv4(v) => {
30u8.hash(hasher);
v.hash(hasher);
}
Value::Ipv6(bytes) => {
31u8.hash(hasher);
bytes.hash(hasher);
}
Value::Subnet(ip, mask) => {
32u8.hash(hasher);
ip.hash(hasher);
mask.hash(hasher);
}
Value::Port(v) => {
33u8.hash(hasher);
v.hash(hasher);
}
Value::Latitude(v) => {
34u8.hash(hasher);
v.hash(hasher);
}
Value::Longitude(v) => {
35u8.hash(hasher);
v.hash(hasher);
}
Value::GeoPoint(lat, lon) => {
36u8.hash(hasher);
lat.hash(hasher);
lon.hash(hasher);
}
Value::Country2(c) => {
37u8.hash(hasher);
c.hash(hasher);
}
Value::Country3(c) => {
38u8.hash(hasher);
c.hash(hasher);
}
Value::Lang2(c) => {
39u8.hash(hasher);
c.hash(hasher);
}
Value::Lang5(c) => {
40u8.hash(hasher);
c.hash(hasher);
}
Value::Currency(c) => {
41u8.hash(hasher);
c.hash(hasher);
}
Value::AssetCode(code) => {
50u8.hash(hasher);
code.hash(hasher);
}
Value::Money {
asset_code,
minor_units,
scale,
} => {
51u8.hash(hasher);
asset_code.hash(hasher);
minor_units.hash(hasher);
scale.hash(hasher);
}
Value::ColorAlpha(rgba) => {
42u8.hash(hasher);
rgba.hash(hasher);
}
Value::BigInt(v) => {
43u8.hash(hasher);
v.hash(hasher);
}
Value::KeyRef(col, key) => {
44u8.hash(hasher);
col.hash(hasher);
key.hash(hasher);
}
Value::DocRef(col, id) => {
45u8.hash(hasher);
col.hash(hasher);
id.hash(hasher);
}
Value::TableRef(name) => {
46u8.hash(hasher);
name.hash(hasher);
}
Value::PageRef(page_id) => {
47u8.hash(hasher);
page_id.hash(hasher);
}
Value::Secret(bytes) => {
48u8.hash(hasher);
bytes.hash(hasher);
}
Value::Password(hash) => {
49u8.hash(hasher);
hash.hash(hasher);
}
}
}
}
/// Stub: performs no splitting and always returns `None`.
///
/// NOTE(review): the name suggests it is meant to split a `UNION` query into
/// two parts (e.g. anchor and recursive term). That is not implemented;
/// confirm the intended contract before relying on it.
pub fn split_union_parts(query: &QueryExpr) -> Option<(QueryExpr, QueryExpr)> {
    // The argument is intentionally unused until splitting is implemented.
    let _: &QueryExpr = query;
    Option::None
}
/// Rewrites a query with a non-recursive `WITH` clause into a plain
/// [`QueryExpr`] by substituting each CTE body wherever its name is referenced.
///
/// CTEs are resolved in declaration order, so a CTE body may refer to CTEs
/// declared before it. A query without a `WITH` clause is returned unchanged.
///
/// NOTE(review): `cte.columns` (the optional column-alias list) is not read
/// here, so column aliases declared on a CTE are not applied when inlining.
///
/// # Errors
/// Returns an error when the `WITH` clause is recursive.
pub fn inline_ctes(query: QueryWithCte) -> Result<QueryExpr, ExecutionError> {
    let with_clause = match query.with_clause {
        Some(clause) => clause,
        None => return Ok(query.query),
    };
    if with_clause.has_recursive {
        return Err(ExecutionError::new(
            "WITH RECURSIVE is not yet supported by the executor; \
non-recursive WITH clauses run today, recursive support \
is tracked separately"
                .to_string(),
        ));
    }
    // Each body is rewritten against the CTEs declared before it, then becomes
    // available to the ones that follow.
    let resolved: HashMap<String, QueryExpr> =
        with_clause
            .ctes
            .iter()
            .fold(HashMap::new(), |mut known, cte| {
                let mut body = (*cte.query).clone();
                rewrite(&mut body, &known);
                known.insert(cte.name.clone(), body);
                known
            });
    let mut main_query = query.query;
    rewrite(&mut main_query, &resolved);
    Ok(main_query)
}
/// Replaces, in place, every table reference in `expr` that names a CTE in
/// `ctes` with that CTE's body.
///
/// Table references with no outer constraints (filters, projections, grouping,
/// ordering, limit/offset) are replaced by the body wholesale. Otherwise the
/// body becomes a subquery source and the table is renamed `__cte_<name>`, so
/// the outer constraints stay in place. Join sides and existing subquery
/// sources are traversed; all other expression kinds are left untouched.
fn rewrite(expr: &mut QueryExpr, ctes: &HashMap<String, QueryExpr>) {
    use super::super::ast::TableSource;

    match expr {
        QueryExpr::Join(join) => {
            rewrite(&mut join.left, ctes);
            rewrite(&mut join.right, ctes);
        }
        QueryExpr::Table(table_query) => {
            let referenced = match &table_query.source {
                None => Some(table_query.table.clone()),
                Some(TableSource::Name(name)) => Some(name.clone()),
                Some(TableSource::Subquery(_)) => None,
            };
            let cte_hit =
                referenced.and_then(|name| ctes.get(&name).map(|body| (name, body)));
            match cte_hit {
                Some((name, body)) => {
                    let has_outer_shape = !table_query.columns.is_empty()
                        || !table_query.select_items.is_empty()
                        || !table_query.group_by.is_empty()
                        || !table_query.order_by.is_empty()
                        || table_query.filter.is_some()
                        || table_query.where_expr.is_some()
                        || table_query.limit.is_some()
                        || table_query.offset.is_some();
                    if has_outer_shape {
                        // Keep the outer query intact and read from the CTE body.
                        table_query.source =
                            Some(TableSource::Subquery(Box::new(body.clone())));
                        table_query.table = format!("__cte_{name}");
                    } else {
                        *expr = body.clone();
                    }
                }
                None => {
                    if let Some(TableSource::Subquery(inner)) = table_query.source.as_mut() {
                        rewrite(inner, ctes);
                    }
                }
            }
        }
        _ => {}
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::query::ast::CteQueryBuilder;
    use crate::storage::query::WithClause;

    /// Evaluation callback that ignores its input and returns an empty result.
    fn mock_execute(
        _query: &QueryExpr,
        _ctx: &CteContext,
    ) -> Result<UnifiedResult, ExecutionError> {
        Ok(UnifiedResult::empty())
    }

    /// Builds a single-column record `{ column: value }`.
    fn record_with(column: &str, value: Value) -> UnifiedRecord {
        let mut record = UnifiedRecord::new();
        record.set(column, value);
        record
    }

    #[test]
    fn test_cte_context() {
        let mut ctx = CteContext::new();
        assert!(ctx.get("test").is_none());
        assert!(!ctx.is_evaluating("test"));
        let result = UnifiedResult::with_columns(vec!["col1".to_string()]);
        ctx.store("test".to_string(), result);
        assert!(ctx.get("test").is_some());
        ctx.start_evaluating("other");
        assert!(ctx.is_evaluating("other"));
        ctx.done_evaluating("other");
        assert!(!ctx.is_evaluating("other"));
    }

    #[test]
    fn test_simple_cte_execution() {
        let executor = CteExecutor::new(|_query, _ctx| {
            let mut result = UnifiedResult::with_columns(vec!["id".to_string()]);
            result.push(record_with("id", Value::Integer(1)));
            Ok(result)
        });
        let cte = CteDefinition {
            name: "test_cte".to_string(),
            columns: vec!["id".to_string()],
            query: Box::new(QueryExpr::table("dummy").build()),
            recursive: false,
        };
        let with_clause = WithClause::new().add(cte);
        let query = QueryWithCte::with_ctes(with_clause, QueryExpr::table("test_cte").build());
        let Ok(result) = executor.execute(&query) else {
            panic!("simple CTE execution failed");
        };
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_cte_builder() {
        let query = CteQueryBuilder::new()
            .cte_with_columns(
                "nums",
                vec!["n".to_string()],
                QueryExpr::table("numbers").build(),
            )
            .build(QueryExpr::table("nums").build());
        assert!(query.with_clause.is_some());
        let with_clause = query.with_clause.unwrap();
        assert_eq!(with_clause.ctes.len(), 1);
        assert_eq!(with_clause.ctes[0].name, "nums");
    }

    #[test]
    fn test_recursive_cte_builder() {
        let query = CteQueryBuilder::new()
            .recursive_cte("paths", QueryExpr::table("connections").build())
            .build(QueryExpr::table("paths").build());
        assert!(query.with_clause.is_some());
        let with_clause = query.with_clause.unwrap();
        assert!(with_clause.has_recursive);
        assert!(with_clause.ctes[0].recursive);
    }

    #[test]
    fn test_circular_reference_detection() {
        let executor = CteExecutor::new(mock_execute);
        let cte = CteDefinition {
            name: "cte_a".to_string(),
            columns: Vec::new(),
            query: Box::new(QueryExpr::table("dummy").build()),
            recursive: false,
        };
        let mut ctx = CteContext::new();
        ctx.start_evaluating("cte_a");
        assert!(ctx.is_evaluating("cte_a"));
        // Re-entering a CTE that is still being evaluated must be rejected...
        assert!(executor.materialize_cte(&cte, &mut ctx).is_err());
        // ...and must not register any table under that name.
        assert!(ctx.get("cte_a").is_none());
    }

    #[test]
    fn test_cte_stats() {
        let ctx = CteContext::new();
        let stats = ctx.stats();
        assert_eq!(stats.ctes_executed, 0);
        assert_eq!(stats.recursive_iterations, 0);
        assert_eq!(stats.rows_produced, 0);
    }

    #[test]
    fn test_hash_record() {
        let executor = CteExecutor::new(mock_execute);
        let mut record1 = UnifiedRecord::new();
        record1.set("id", Value::Integer(1));
        record1.set("name", Value::text("test".to_string()));
        let mut record2 = UnifiedRecord::new();
        record2.set("id", Value::Integer(1));
        record2.set("name", Value::text("test".to_string()));
        let mut record3 = UnifiedRecord::new();
        record3.set("id", Value::Integer(2));
        record3.set("name", Value::text("test".to_string()));
        assert_eq!(
            executor.hash_record(&record1),
            executor.hash_record(&record2)
        );
        assert_ne!(
            executor.hash_record(&record1),
            executor.hash_record(&record3)
        );
    }

    #[test]
    fn test_hash_various_value_types() {
        let executor = CteExecutor::new(mock_execute);
        let mut record = UnifiedRecord::new();
        record.set("null_val", Value::Null);
        record.set("bool_val", Value::Boolean(true));
        record.set("int_val", Value::Integer(42));
        record.set("float_val", Value::Float(2.5));
        record.set("text_val", Value::text("hello".to_string()));
        record.set("blob_val", Value::Blob(vec![1, 2, 3]));
        record.set("timestamp_val", Value::Timestamp(1234567890));
        record.set("duration_val", Value::Duration(5000));
        let hash = executor.hash_record(&record);
        // Hashing must be deterministic for equal content.
        assert_eq!(hash, executor.hash_record(&record.clone()));
    }

    #[test]
    fn test_hash_distinguishes_variants_and_floats() {
        let executor = CteExecutor::new(mock_execute);
        // The per-variant tag separates Null from a zero integer.
        assert_ne!(
            executor.hash_record(&record_with("v", Value::Null)),
            executor.hash_record(&record_with("v", Value::Integer(0)))
        );
        assert_ne!(
            executor.hash_record(&record_with("v", Value::Float(2.5))),
            executor.hash_record(&record_with("v", Value::Float(-2.5)))
        );
    }

    #[test]
    fn test_project_columns() {
        let executor = CteExecutor::new(mock_execute);
        let mut original =
            UnifiedResult::with_columns(vec!["a".to_string(), "b".to_string(), "c".to_string()]);
        let mut record = UnifiedRecord::new();
        record.set("a", Value::Integer(1));
        record.set("b", Value::Integer(2));
        record.set("c", Value::Integer(3));
        original.push(record);
        let projected = executor.project_columns(&original, &["x".to_string(), "y".to_string()]);
        assert_eq!(projected.columns, vec!["x", "y"]);
        assert_eq!(projected.len(), 1);
        // Values are mapped by position: x <- a, y <- b.
        assert!(matches!(projected.records[0].get("x"), Some(Value::Integer(1))));
        assert!(matches!(projected.records[0].get("y"), Some(Value::Integer(2))));
    }

    #[test]
    fn test_empty_columns_projection() {
        let executor = CteExecutor::new(mock_execute);
        let original = UnifiedResult::with_columns(vec!["a".to_string()]);
        let projected = executor.project_columns(&original, &[]);
        assert_eq!(projected.columns, original.columns);
    }

    #[test]
    fn test_cte_with_multiple_definitions() {
        let executor = CteExecutor::new(|query, ctx| {
            match query {
                QueryExpr::Table(t) if t.table == "base" => {
                    let mut result = UnifiedResult::with_columns(vec!["id".to_string()]);
                    result.push(record_with("id", Value::Integer(1)));
                    Ok(result)
                }
                QueryExpr::Table(t) if t.table == "cte1" => {
                    if ctx.get("cte1").is_some() {
                        Ok(ctx.get("cte1").unwrap().clone())
                    } else {
                        Ok(UnifiedResult::empty())
                    }
                }
                _ => Ok(UnifiedResult::empty()),
            }
        });
        let cte1 = CteDefinition {
            name: "cte1".to_string(),
            columns: vec!["id".to_string()],
            query: Box::new(QueryExpr::table("base").build()),
            recursive: false,
        };
        let cte2 = CteDefinition {
            name: "cte2".to_string(),
            columns: vec!["id".to_string()],
            query: Box::new(QueryExpr::table("cte1").build()),
            recursive: false,
        };
        let with_clause = WithClause::new().add(cte1).add(cte2);
        let query = QueryWithCte::with_ctes(with_clause, QueryExpr::table("cte2").build());
        let result = executor.execute(&query);
        assert!(result.is_ok());
    }

    #[test]
    fn test_recursive_cte_execution_reaches_fixpoint() {
        // Models: WITH RECURSIVE nums(n) AS (SELECT 1 UNION SELECT n + 1 FROM nums WHERE n < 5)
        let executor = CteExecutor::new(|_query, ctx| {
            let mut result = UnifiedResult::with_columns(vec!["n".to_string()]);
            result.push(record_with("n", Value::Integer(1)));
            if let Some(previous) = ctx.get("nums") {
                for record in &previous.records {
                    if let Some(Value::Integer(n)) = record.get("n") {
                        if *n < 5 {
                            result.push(record_with("n", Value::Integer(*n + 1)));
                        }
                    }
                }
            }
            Ok(result)
        });
        let cte = CteDefinition {
            name: "nums".to_string(),
            columns: vec!["n".to_string()],
            query: Box::new(QueryExpr::table("nums").build()),
            recursive: true,
        };
        let query = QueryWithCte::with_ctes(
            WithClause::new().add(cte),
            QueryExpr::table("nums").build(),
        );
        let Ok(output) = executor.execute(&query) else {
            panic!("recursive CTE execution failed");
        };
        let mut values: Vec<_> = output
            .records
            .iter()
            .filter_map(|record| match record.get("n") {
                Some(Value::Integer(n)) => Some(*n),
                _ => None,
            })
            .collect();
        values.sort_unstable();
        assert_eq!(values, vec![1, 2, 3, 4, 5]);
    }
}