Skip to main content

reifydb_engine/vm/volcano/
filter.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{mem, sync::Arc};
5
6use reifydb_catalog::catalog::Catalog;
7use reifydb_core::{
8	interface::{catalog::dictionary::Dictionary, resolved::ResolvedShape},
9	value::{
10		batch::lazy::LazyBatch,
11		column::{buffer::ColumnBuffer, columns::Columns, headers::ColumnHeaders},
12	},
13};
14use reifydb_extension::transform::{Transform, context::TransformContext};
15use reifydb_rql::expression::{Expression, name::display_label};
16use reifydb_transaction::transaction::Transaction;
17use reifydb_type::{util::bitvec::BitVec, value::constraint::Constraint};
18use tracing::instrument;
19
20use super::{NoopNode, decode_dictionary_columns};
21use crate::{
22	Result,
23	expression::{
24		compile::{CompiledExpr, compile_expression},
25		context::{CompileContext, EvalContext},
26	},
27	vm::volcano::{
28		query::{QueryContext, QueryNode},
29		udf::{UdfEvalNode, strip_udf_columns},
30	},
31};
32
33pub(crate) struct FilterNode {
34	input: Box<dyn QueryNode>,
35	expressions: Vec<Expression>,
36	udf_names: Vec<String>,
37	context: Option<(Arc<QueryContext>, Vec<CompiledExpr>)>,
38}
39
40impl FilterNode {
41	pub fn new(input: Box<dyn QueryNode>, expressions: Vec<Expression>) -> Self {
42		Self {
43			input,
44			expressions,
45			udf_names: Vec::new(),
46			context: None,
47		}
48	}
49}
50
51impl QueryNode for FilterNode {
52	#[instrument(level = "trace", skip_all, name = "volcano::filter::initialize")]
53	fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
54		let (input, expressions, udf_names) = UdfEvalNode::wrap_if_needed(
55			mem::replace(&mut self.input, Box::new(NoopNode)),
56			&self.expressions,
57			&ctx.symbols,
58		);
59		self.input = input;
60		self.expressions = expressions;
61		self.udf_names = udf_names;
62
63		let compile_ctx = CompileContext {
64			symbols: &ctx.symbols,
65		};
66		let compiled = self
67			.expressions
68			.iter()
69			.map(|e| compile_expression(&compile_ctx, e).expect("compile"))
70			.collect();
71		self.context = Some((Arc::new(ctx.clone()), compiled));
72		self.input.initialize(rx, ctx)?;
73		Ok(())
74	}
75
76	#[instrument(level = "trace", skip_all, name = "volcano::filter::next")]
77	fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
78		debug_assert!(self.context.is_some(), "FilterNode::next() called before initialize()");
79		let (stored_ctx, compiled) = self.context.as_ref().unwrap();
80
81		loop {
82			// Try lazy path first
83			if let Some(mut lazy_batch) = self.input.next_lazy(rx, ctx)? {
84				// Evaluate filter on lazy batch
85				let filter_result =
86					self.evaluate_filter_on_lazy(&lazy_batch, stored_ctx, compiled, rx)?;
87
88				if let Some(filter_mask) = filter_result {
89					lazy_batch.apply_filter(&filter_mask);
90				}
91
92				if lazy_batch.valid_row_count() == 0 {
93					continue; // Skip to next batch
94				}
95
96				// Save dictionary metadata before consuming the lazy batch
97				let dictionaries: Vec<Option<Dictionary>> =
98					lazy_batch.column_metas().iter().map(|m| m.dictionary.clone()).collect();
99
100				// Materialize surviving rows
101				let mut columns = lazy_batch.into_columns();
102
103				// Decode dictionary columns back to actual values
104				decode_dictionary_columns(&mut columns, &dictionaries, rx)?;
105
106				strip_udf_columns(&mut columns, &self.udf_names);
107				return Ok(Some(columns));
108			}
109
110			// Fall back to materialized path
111			if let Some(columns) = self.input.next(rx, ctx)? {
112				let transform_ctx = TransformContext {
113					routines: &ctx.services.routines,
114					runtime_context: &stored_ctx.services.runtime_context,
115					params: &stored_ctx.params,
116				};
117				let mut columns = self.apply(&transform_ctx, columns)?;
118				if columns.row_count() > 0 {
119					strip_udf_columns(&mut columns, &self.udf_names);
120					return Ok(Some(columns));
121				}
122			} else {
123				// No more batches
124				return Ok(None);
125			}
126		}
127	}
128
129	fn headers(&self) -> Option<ColumnHeaders> {
130		self.input.headers()
131	}
132}
133
134impl Transform for FilterNode {
135	fn apply(&self, ctx: &TransformContext, input: Columns) -> Result<Columns> {
136		let (stored_ctx, compiled) =
137			self.context.as_ref().expect("FilterNode::apply() called before initialize()");
138
139		let session = EvalContext::from_transform(ctx, stored_ctx);
140		let mut columns = input;
141		let mut row_count = columns.row_count();
142
143		for compiled_expr in compiled {
144			if row_count == 0 {
145				break;
146			}
147
148			let exec_ctx = session.with_eval(columns.clone(), row_count);
149
150			let result = compiled_expr.execute(&exec_ctx)?;
151
152			let filter_mask = match result.data() {
153				ColumnBuffer::Bool(container) => {
154					let mut mask = BitVec::repeat(row_count, false);
155					for i in 0..row_count {
156						if i < container.len() {
157							let valid = container.is_defined(i);
158							let filter_result = container.data().get(i);
159							mask.set(i, valid & filter_result);
160						}
161					}
162					mask
163				}
164				ColumnBuffer::Option {
165					inner,
166					bitvec,
167				} => match inner.as_ref() {
168					ColumnBuffer::Bool(container) => {
169						let mut mask = BitVec::repeat(row_count, false);
170						for i in 0..row_count {
171							let defined = i < bitvec.len() && bitvec.get(i);
172							let valid = defined && container.is_defined(i);
173							let value = valid && container.data().get(i);
174							mask.set(i, value);
175						}
176						mask
177					}
178					_ => panic!("filter expression must evaluate to a boolean column"),
179				},
180				_ => panic!("filter expression must evaluate to a boolean column"),
181			};
182
183			columns.filter(&filter_mask)?;
184			row_count = columns.row_count();
185		}
186
187		Ok(columns)
188	}
189}
190
191impl FilterNode {
192	/// Evaluate filter expressions on a lazy batch using column-oriented evaluation.
193	/// Returns a filter mask indicating which rows pass all filter expressions.
194	fn evaluate_filter_on_lazy<'a>(
195		&self,
196		lazy_batch: &LazyBatch,
197		ctx: &QueryContext,
198		compiled: &[CompiledExpr],
199		rx: &mut Transaction<'a>,
200	) -> Result<Option<BitVec>> {
201		// Materialize to columns for column-oriented evaluation,
202		// then decode dictionary columns so filters can compare actual values.
203		let dictionaries: Vec<Option<Dictionary>> =
204			lazy_batch.column_metas().iter().map(|m| m.dictionary.clone()).collect();
205		let mut columns = lazy_batch.clone().into_columns();
206		decode_dictionary_columns(&mut columns, &dictionaries, rx)?;
207		let row_count = columns.row_count();
208
209		if row_count == 0 {
210			return Ok(Some(BitVec::empty()));
211		}
212
213		let session = EvalContext::from_query(ctx);
214		let mut mask = BitVec::repeat(row_count, true);
215
216		for compiled_expr in compiled {
217			let exec_ctx = session.with_eval(columns.clone(), row_count);
218
219			let result = compiled_expr.execute(&exec_ctx)?;
220
221			match result.data() {
222				ColumnBuffer::Bool(container) => {
223					for i in 0..row_count {
224						if mask.get(i) {
225							let valid = container.is_defined(i);
226							let filter_result = container.data().get(i);
227							mask.set(i, valid & filter_result);
228						}
229					}
230				}
231				ColumnBuffer::Option {
232					inner,
233					bitvec,
234				} => match inner.as_ref() {
235					ColumnBuffer::Bool(container) => {
236						for i in 0..row_count {
237							if mask.get(i) {
238								let defined = i < bitvec.len() && bitvec.get(i);
239								let valid = defined && container.is_defined(i);
240								let value = valid && container.data().get(i);
241								mask.set(i, value);
242							}
243						}
244					}
245					_ => panic!("filter expression must evaluate to a boolean column"),
246				},
247				_ => panic!("filter expression must evaluate to a boolean column"),
248			}
249		}
250
251		Ok(Some(mask))
252	}
253}
254
255pub(crate) fn resolve_is_variant_tags(
256	expr: &mut Expression,
257	source: &ResolvedShape,
258	catalog: &Catalog,
259	rx: &mut Transaction<'_>,
260) -> Result<()> {
261	match expr {
262		Expression::IsVariant(e) => {
263			let col_name = match e.expression.as_ref() {
264				Expression::Column(c) => c.0.name.text().to_string(),
265				other => display_label(other).text().to_string(),
266			};
267
268			let tag_col_name = format!("{}_tag", col_name);
269			let columns = source.columns();
270			if let Some(tag_col) = columns.iter().find(|c| c.name == tag_col_name)
271				&& let Some(Constraint::SumType(id)) = tag_col.constraint.constraint()
272			{
273				let def = catalog.get_sumtype(rx, *id)?;
274				let variant_name = e.variant_name.text().to_lowercase();
275				if let Some(variant) =
276					def.variants.iter().find(|v| v.name.to_lowercase() == variant_name)
277				{
278					e.tag = Some(variant.tag);
279				}
280			}
281			resolve_is_variant_tags(&mut e.expression, source, catalog, rx)?;
282		}
283		Expression::And(e) => {
284			resolve_is_variant_tags(&mut e.left, source, catalog, rx)?;
285			resolve_is_variant_tags(&mut e.right, source, catalog, rx)?;
286		}
287		Expression::Or(e) => {
288			resolve_is_variant_tags(&mut e.left, source, catalog, rx)?;
289			resolve_is_variant_tags(&mut e.right, source, catalog, rx)?;
290		}
291		Expression::Equal(e) => {
292			resolve_is_variant_tags(&mut e.left, source, catalog, rx)?;
293			resolve_is_variant_tags(&mut e.right, source, catalog, rx)?;
294		}
295		Expression::NotEqual(e) => {
296			resolve_is_variant_tags(&mut e.left, source, catalog, rx)?;
297			resolve_is_variant_tags(&mut e.right, source, catalog, rx)?;
298		}
299		Expression::Prefix(e) => {
300			resolve_is_variant_tags(&mut e.expression, source, catalog, rx)?;
301		}
302		Expression::If(e) => {
303			resolve_is_variant_tags(&mut e.condition, source, catalog, rx)?;
304			resolve_is_variant_tags(&mut e.then_expr, source, catalog, rx)?;
305			for else_if in &mut e.else_ifs {
306				resolve_is_variant_tags(&mut else_if.condition, source, catalog, rx)?;
307				resolve_is_variant_tags(&mut else_if.then_expr, source, catalog, rx)?;
308			}
309			if let Some(else_expr) = &mut e.else_expr {
310				resolve_is_variant_tags(else_expr, source, catalog, rx)?;
311			}
312		}
313		Expression::Alias(e) => {
314			resolve_is_variant_tags(&mut e.expression, source, catalog, rx)?;
315		}
316		_ => {}
317	}
318	Ok(())
319}