reifydb_engine/vm/volcano/
filter.rs1use std::{mem, sync::Arc};
5
6use reifydb_catalog::catalog::Catalog;
7use reifydb_core::{
8 interface::{catalog::dictionary::Dictionary, resolved::ResolvedShape},
9 value::{
10 batch::lazy::LazyBatch,
11 column::{buffer::ColumnBuffer, columns::Columns, headers::ColumnHeaders},
12 },
13};
14use reifydb_extension::transform::{Transform, context::TransformContext};
15use reifydb_rql::expression::{Expression, name::display_label};
16use reifydb_transaction::transaction::Transaction;
17use reifydb_type::{util::bitvec::BitVec, value::constraint::Constraint};
18use tracing::instrument;
19
20use super::{NoopNode, decode_dictionary_columns};
21use crate::{
22 Result,
23 expression::{
24 compile::{CompiledExpr, compile_expression},
25 context::{CompileContext, EvalContext},
26 },
27 vm::volcano::{
28 query::{QueryContext, QueryNode},
29 udf::{UdfEvalNode, strip_udf_columns},
30 },
31};
32
33pub(crate) struct FilterNode {
34 input: Box<dyn QueryNode>,
35 expressions: Vec<Expression>,
36 udf_names: Vec<String>,
37 context: Option<(Arc<QueryContext>, Vec<CompiledExpr>)>,
38}
39
40impl FilterNode {
41 pub fn new(input: Box<dyn QueryNode>, expressions: Vec<Expression>) -> Self {
42 Self {
43 input,
44 expressions,
45 udf_names: Vec::new(),
46 context: None,
47 }
48 }
49}
50
51impl QueryNode for FilterNode {
52 #[instrument(level = "trace", skip_all, name = "volcano::filter::initialize")]
53 fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
54 let (input, expressions, udf_names) = UdfEvalNode::wrap_if_needed(
55 mem::replace(&mut self.input, Box::new(NoopNode)),
56 &self.expressions,
57 &ctx.symbols,
58 );
59 self.input = input;
60 self.expressions = expressions;
61 self.udf_names = udf_names;
62
63 let compile_ctx = CompileContext {
64 symbols: &ctx.symbols,
65 };
66 let compiled = self
67 .expressions
68 .iter()
69 .map(|e| compile_expression(&compile_ctx, e).expect("compile"))
70 .collect();
71 self.context = Some((Arc::new(ctx.clone()), compiled));
72 self.input.initialize(rx, ctx)?;
73 Ok(())
74 }
75
76 #[instrument(level = "trace", skip_all, name = "volcano::filter::next")]
77 fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
78 debug_assert!(self.context.is_some(), "FilterNode::next() called before initialize()");
79 let (stored_ctx, compiled) = self.context.as_ref().unwrap();
80
81 loop {
82 if let Some(mut lazy_batch) = self.input.next_lazy(rx, ctx)? {
84 let filter_result =
86 self.evaluate_filter_on_lazy(&lazy_batch, stored_ctx, compiled, rx)?;
87
88 if let Some(filter_mask) = filter_result {
89 lazy_batch.apply_filter(&filter_mask);
90 }
91
92 if lazy_batch.valid_row_count() == 0 {
93 continue; }
95
96 let dictionaries: Vec<Option<Dictionary>> =
98 lazy_batch.column_metas().iter().map(|m| m.dictionary.clone()).collect();
99
100 let mut columns = lazy_batch.into_columns();
102
103 decode_dictionary_columns(&mut columns, &dictionaries, rx)?;
105
106 strip_udf_columns(&mut columns, &self.udf_names);
107 return Ok(Some(columns));
108 }
109
110 if let Some(columns) = self.input.next(rx, ctx)? {
112 let transform_ctx = TransformContext {
113 routines: &ctx.services.routines,
114 runtime_context: &stored_ctx.services.runtime_context,
115 params: &stored_ctx.params,
116 };
117 let mut columns = self.apply(&transform_ctx, columns)?;
118 if columns.row_count() > 0 {
119 strip_udf_columns(&mut columns, &self.udf_names);
120 return Ok(Some(columns));
121 }
122 } else {
123 return Ok(None);
125 }
126 }
127 }
128
129 fn headers(&self) -> Option<ColumnHeaders> {
130 self.input.headers()
131 }
132}
133
134impl Transform for FilterNode {
135 fn apply(&self, ctx: &TransformContext, input: Columns) -> Result<Columns> {
136 let (stored_ctx, compiled) =
137 self.context.as_ref().expect("FilterNode::apply() called before initialize()");
138
139 let session = EvalContext::from_transform(ctx, stored_ctx);
140 let mut columns = input;
141 let mut row_count = columns.row_count();
142
143 for compiled_expr in compiled {
144 if row_count == 0 {
145 break;
146 }
147
148 let exec_ctx = session.with_eval(columns.clone(), row_count);
149
150 let result = compiled_expr.execute(&exec_ctx)?;
151
152 let filter_mask = match result.data() {
153 ColumnBuffer::Bool(container) => {
154 let mut mask = BitVec::repeat(row_count, false);
155 for i in 0..row_count {
156 if i < container.len() {
157 let valid = container.is_defined(i);
158 let filter_result = container.data().get(i);
159 mask.set(i, valid & filter_result);
160 }
161 }
162 mask
163 }
164 ColumnBuffer::Option {
165 inner,
166 bitvec,
167 } => match inner.as_ref() {
168 ColumnBuffer::Bool(container) => {
169 let mut mask = BitVec::repeat(row_count, false);
170 for i in 0..row_count {
171 let defined = i < bitvec.len() && bitvec.get(i);
172 let valid = defined && container.is_defined(i);
173 let value = valid && container.data().get(i);
174 mask.set(i, value);
175 }
176 mask
177 }
178 _ => panic!("filter expression must evaluate to a boolean column"),
179 },
180 _ => panic!("filter expression must evaluate to a boolean column"),
181 };
182
183 columns.filter(&filter_mask)?;
184 row_count = columns.row_count();
185 }
186
187 Ok(columns)
188 }
189}
190
191impl FilterNode {
192 fn evaluate_filter_on_lazy<'a>(
195 &self,
196 lazy_batch: &LazyBatch,
197 ctx: &QueryContext,
198 compiled: &[CompiledExpr],
199 rx: &mut Transaction<'a>,
200 ) -> Result<Option<BitVec>> {
201 let dictionaries: Vec<Option<Dictionary>> =
204 lazy_batch.column_metas().iter().map(|m| m.dictionary.clone()).collect();
205 let mut columns = lazy_batch.clone().into_columns();
206 decode_dictionary_columns(&mut columns, &dictionaries, rx)?;
207 let row_count = columns.row_count();
208
209 if row_count == 0 {
210 return Ok(Some(BitVec::empty()));
211 }
212
213 let session = EvalContext::from_query(ctx);
214 let mut mask = BitVec::repeat(row_count, true);
215
216 for compiled_expr in compiled {
217 let exec_ctx = session.with_eval(columns.clone(), row_count);
218
219 let result = compiled_expr.execute(&exec_ctx)?;
220
221 match result.data() {
222 ColumnBuffer::Bool(container) => {
223 for i in 0..row_count {
224 if mask.get(i) {
225 let valid = container.is_defined(i);
226 let filter_result = container.data().get(i);
227 mask.set(i, valid & filter_result);
228 }
229 }
230 }
231 ColumnBuffer::Option {
232 inner,
233 bitvec,
234 } => match inner.as_ref() {
235 ColumnBuffer::Bool(container) => {
236 for i in 0..row_count {
237 if mask.get(i) {
238 let defined = i < bitvec.len() && bitvec.get(i);
239 let valid = defined && container.is_defined(i);
240 let value = valid && container.data().get(i);
241 mask.set(i, value);
242 }
243 }
244 }
245 _ => panic!("filter expression must evaluate to a boolean column"),
246 },
247 _ => panic!("filter expression must evaluate to a boolean column"),
248 }
249 }
250
251 Ok(Some(mask))
252 }
253}
254
255pub(crate) fn resolve_is_variant_tags(
256 expr: &mut Expression,
257 source: &ResolvedShape,
258 catalog: &Catalog,
259 rx: &mut Transaction<'_>,
260) -> Result<()> {
261 match expr {
262 Expression::IsVariant(e) => {
263 let col_name = match e.expression.as_ref() {
264 Expression::Column(c) => c.0.name.text().to_string(),
265 other => display_label(other).text().to_string(),
266 };
267
268 let tag_col_name = format!("{}_tag", col_name);
269 let columns = source.columns();
270 if let Some(tag_col) = columns.iter().find(|c| c.name == tag_col_name)
271 && let Some(Constraint::SumType(id)) = tag_col.constraint.constraint()
272 {
273 let def = catalog.get_sumtype(rx, *id)?;
274 let variant_name = e.variant_name.text().to_lowercase();
275 if let Some(variant) =
276 def.variants.iter().find(|v| v.name.to_lowercase() == variant_name)
277 {
278 e.tag = Some(variant.tag);
279 }
280 }
281 resolve_is_variant_tags(&mut e.expression, source, catalog, rx)?;
282 }
283 Expression::And(e) => {
284 resolve_is_variant_tags(&mut e.left, source, catalog, rx)?;
285 resolve_is_variant_tags(&mut e.right, source, catalog, rx)?;
286 }
287 Expression::Or(e) => {
288 resolve_is_variant_tags(&mut e.left, source, catalog, rx)?;
289 resolve_is_variant_tags(&mut e.right, source, catalog, rx)?;
290 }
291 Expression::Equal(e) => {
292 resolve_is_variant_tags(&mut e.left, source, catalog, rx)?;
293 resolve_is_variant_tags(&mut e.right, source, catalog, rx)?;
294 }
295 Expression::NotEqual(e) => {
296 resolve_is_variant_tags(&mut e.left, source, catalog, rx)?;
297 resolve_is_variant_tags(&mut e.right, source, catalog, rx)?;
298 }
299 Expression::Prefix(e) => {
300 resolve_is_variant_tags(&mut e.expression, source, catalog, rx)?;
301 }
302 Expression::If(e) => {
303 resolve_is_variant_tags(&mut e.condition, source, catalog, rx)?;
304 resolve_is_variant_tags(&mut e.then_expr, source, catalog, rx)?;
305 for else_if in &mut e.else_ifs {
306 resolve_is_variant_tags(&mut else_if.condition, source, catalog, rx)?;
307 resolve_is_variant_tags(&mut else_if.then_expr, source, catalog, rx)?;
308 }
309 if let Some(else_expr) = &mut e.else_expr {
310 resolve_is_variant_tags(else_expr, source, catalog, rx)?;
311 }
312 }
313 Expression::Alias(e) => {
314 resolve_is_variant_tags(&mut e.expression, source, catalog, rx)?;
315 }
316 _ => {}
317 }
318 Ok(())
319}