Skip to main content

reifydb_engine/vm/volcano/
filter.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{mem, sync::Arc};
5
6use reifydb_catalog::catalog::Catalog;
7use reifydb_core::{
8	interface::resolved::ResolvedShape,
9	value::column::{buffer::ColumnBuffer, columns::Columns, headers::ColumnHeaders},
10};
11use reifydb_extension::transform::{Transform, context::TransformContext};
12use reifydb_rql::expression::{Expression, name::display_label};
13use reifydb_transaction::transaction::Transaction;
14use reifydb_type::{util::bitvec::BitVec, value::constraint::Constraint};
15use tracing::instrument;
16
17use super::NoopNode;
18use crate::{
19	Result,
20	expression::{
21		compile::{CompiledExpr, compile_expression},
22		context::{CompileContext, EvalContext},
23	},
24	vm::volcano::{
25		query::{QueryContext, QueryNode},
26		udf::{UdfEvalNode, strip_udf_columns},
27	},
28};
29
30pub(crate) struct FilterNode {
31	input: Box<dyn QueryNode>,
32	expressions: Vec<Expression>,
33	udf_names: Vec<String>,
34	context: Option<(Arc<QueryContext>, Vec<CompiledExpr>)>,
35}
36
37impl FilterNode {
38	pub fn new(input: Box<dyn QueryNode>, expressions: Vec<Expression>) -> Self {
39		Self {
40			input,
41			expressions,
42			udf_names: Vec::new(),
43			context: None,
44		}
45	}
46}
47
48impl QueryNode for FilterNode {
49	#[instrument(level = "trace", skip_all, name = "volcano::filter::initialize")]
50	fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
51		let (input, expressions, udf_names) = UdfEvalNode::wrap_if_needed(
52			mem::replace(&mut self.input, Box::new(NoopNode)),
53			&self.expressions,
54			&ctx.symbols,
55		);
56		self.input = input;
57		self.expressions = expressions;
58		self.udf_names = udf_names;
59
60		let compile_ctx = CompileContext {
61			symbols: &ctx.symbols,
62		};
63		let compiled = self
64			.expressions
65			.iter()
66			.map(|e| compile_expression(&compile_ctx, e).expect("compile"))
67			.collect();
68		self.context = Some((Arc::new(ctx.clone()), compiled));
69		self.input.initialize(rx, ctx)?;
70		Ok(())
71	}
72
73	#[instrument(level = "trace", skip_all, name = "volcano::filter::next")]
74	fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
75		debug_assert!(self.context.is_some(), "FilterNode::next() called before initialize()");
76		let (stored_ctx, _) = self.context.as_ref().unwrap();
77		let stored_ctx = stored_ctx.clone();
78
79		loop {
80			match self.input.next(rx, ctx)? {
81				Some(columns) => {
82					let transform_ctx = TransformContext {
83						routines: &ctx.services.routines,
84						runtime_context: &stored_ctx.services.runtime_context,
85						params: &stored_ctx.params,
86					};
87					let mut columns = self.apply(&transform_ctx, columns)?;
88					if columns.row_count() > 0 {
89						strip_udf_columns(&mut columns, &self.udf_names);
90						return Ok(Some(columns));
91					}
92				}
93				None => return Ok(None),
94			}
95		}
96	}
97
98	fn headers(&self) -> Option<ColumnHeaders> {
99		self.input.headers()
100	}
101}
102
103impl Transform for FilterNode {
104	fn apply(&self, ctx: &TransformContext, input: Columns) -> Result<Columns> {
105		let (stored_ctx, compiled) =
106			self.context.as_ref().expect("FilterNode::apply() called before initialize()");
107
108		let session = EvalContext::from_transform(ctx, stored_ctx);
109		let mut columns = input;
110		let mut row_count = columns.row_count();
111
112		for compiled_expr in compiled {
113			if row_count == 0 {
114				break;
115			}
116
117			let exec_ctx = session.with_eval(columns.clone(), row_count);
118
119			let result = compiled_expr.execute(&exec_ctx)?;
120
121			let filter_mask = match result.data() {
122				ColumnBuffer::Bool(container) => {
123					let mut mask = BitVec::repeat(row_count, false);
124					for i in 0..row_count {
125						if i < container.len() {
126							let valid = container.is_defined(i);
127							let filter_result = container.data().get(i);
128							mask.set(i, valid & filter_result);
129						}
130					}
131					mask
132				}
133				ColumnBuffer::Option {
134					inner,
135					bitvec,
136				} => match inner.as_ref() {
137					ColumnBuffer::Bool(container) => {
138						let mut mask = BitVec::repeat(row_count, false);
139						for i in 0..row_count {
140							let defined = i < bitvec.len() && bitvec.get(i);
141							let valid = defined && container.is_defined(i);
142							let value = valid && container.data().get(i);
143							mask.set(i, value);
144						}
145						mask
146					}
147					_ => panic!("filter expression must evaluate to a boolean column"),
148				},
149				_ => panic!("filter expression must evaluate to a boolean column"),
150			};
151
152			columns.filter(&filter_mask)?;
153			row_count = columns.row_count();
154		}
155
156		Ok(columns)
157	}
158}
159
160pub(crate) fn resolve_is_variant_tags(
161	expr: &mut Expression,
162	source: &ResolvedShape,
163	catalog: &Catalog,
164	rx: &mut Transaction<'_>,
165) -> Result<()> {
166	match expr {
167		Expression::IsVariant(e) => {
168			let col_name = match e.expression.as_ref() {
169				Expression::Column(c) => c.0.name.text().to_string(),
170				other => display_label(other).text().to_string(),
171			};
172
173			let tag_col_name = format!("{}_tag", col_name);
174			let columns = source.columns();
175			if let Some(tag_col) = columns.iter().find(|c| c.name == tag_col_name)
176				&& let Some(Constraint::SumType(id)) = tag_col.constraint.constraint()
177			{
178				let def = catalog.get_sumtype(rx, *id)?;
179				let variant_name = e.variant_name.text().to_lowercase();
180				if let Some(variant) =
181					def.variants.iter().find(|v| v.name.to_lowercase() == variant_name)
182				{
183					e.tag = Some(variant.tag);
184				}
185			}
186			resolve_is_variant_tags(&mut e.expression, source, catalog, rx)?;
187		}
188		Expression::And(e) => {
189			resolve_is_variant_tags(&mut e.left, source, catalog, rx)?;
190			resolve_is_variant_tags(&mut e.right, source, catalog, rx)?;
191		}
192		Expression::Or(e) => {
193			resolve_is_variant_tags(&mut e.left, source, catalog, rx)?;
194			resolve_is_variant_tags(&mut e.right, source, catalog, rx)?;
195		}
196		Expression::Equal(e) => {
197			resolve_is_variant_tags(&mut e.left, source, catalog, rx)?;
198			resolve_is_variant_tags(&mut e.right, source, catalog, rx)?;
199		}
200		Expression::NotEqual(e) => {
201			resolve_is_variant_tags(&mut e.left, source, catalog, rx)?;
202			resolve_is_variant_tags(&mut e.right, source, catalog, rx)?;
203		}
204		Expression::Prefix(e) => {
205			resolve_is_variant_tags(&mut e.expression, source, catalog, rx)?;
206		}
207		Expression::If(e) => {
208			resolve_is_variant_tags(&mut e.condition, source, catalog, rx)?;
209			resolve_is_variant_tags(&mut e.then_expr, source, catalog, rx)?;
210			for else_if in &mut e.else_ifs {
211				resolve_is_variant_tags(&mut else_if.condition, source, catalog, rx)?;
212				resolve_is_variant_tags(&mut else_if.then_expr, source, catalog, rx)?;
213			}
214			if let Some(else_expr) = &mut e.else_expr {
215				resolve_is_variant_tags(else_expr, source, catalog, rx)?;
216			}
217		}
218		Expression::Alias(e) => {
219			resolve_is_variant_tags(&mut e.expression, source, catalog, rx)?;
220		}
221		_ => {}
222	}
223	Ok(())
224}