reifydb_engine/vm/volcano/
filter.rs1use std::{mem, sync::Arc};
5
6use reifydb_catalog::catalog::Catalog;
7use reifydb_core::{
8 interface::resolved::ResolvedShape,
9 value::column::{buffer::ColumnBuffer, columns::Columns, headers::ColumnHeaders},
10};
11use reifydb_extension::transform::{Transform, context::TransformContext};
12use reifydb_rql::expression::{Expression, name::display_label};
13use reifydb_transaction::transaction::Transaction;
14use reifydb_type::{util::bitvec::BitVec, value::constraint::Constraint};
15use tracing::instrument;
16
17use super::NoopNode;
18use crate::{
19 Result,
20 expression::{
21 compile::{CompiledExpr, compile_expression},
22 context::{CompileContext, EvalContext},
23 },
24 vm::volcano::{
25 query::{QueryContext, QueryNode},
26 udf::{UdfEvalNode, strip_udf_columns},
27 },
28};
29
30pub(crate) struct FilterNode {
31 input: Box<dyn QueryNode>,
32 expressions: Vec<Expression>,
33 udf_names: Vec<String>,
34 context: Option<(Arc<QueryContext>, Vec<CompiledExpr>)>,
35}
36
37impl FilterNode {
38 pub fn new(input: Box<dyn QueryNode>, expressions: Vec<Expression>) -> Self {
39 Self {
40 input,
41 expressions,
42 udf_names: Vec::new(),
43 context: None,
44 }
45 }
46}
47
48impl QueryNode for FilterNode {
49 #[instrument(level = "trace", skip_all, name = "volcano::filter::initialize")]
50 fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
51 let (input, expressions, udf_names) = UdfEvalNode::wrap_if_needed(
52 mem::replace(&mut self.input, Box::new(NoopNode)),
53 &self.expressions,
54 &ctx.symbols,
55 );
56 self.input = input;
57 self.expressions = expressions;
58 self.udf_names = udf_names;
59
60 let compile_ctx = CompileContext {
61 symbols: &ctx.symbols,
62 };
63 let compiled = self
64 .expressions
65 .iter()
66 .map(|e| compile_expression(&compile_ctx, e).expect("compile"))
67 .collect();
68 self.context = Some((Arc::new(ctx.clone()), compiled));
69 self.input.initialize(rx, ctx)?;
70 Ok(())
71 }
72
73 #[instrument(level = "trace", skip_all, name = "volcano::filter::next")]
74 fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
75 debug_assert!(self.context.is_some(), "FilterNode::next() called before initialize()");
76 let (stored_ctx, _) = self.context.as_ref().unwrap();
77 let stored_ctx = stored_ctx.clone();
78
79 loop {
80 match self.input.next(rx, ctx)? {
81 Some(columns) => {
82 let transform_ctx = TransformContext {
83 routines: &ctx.services.routines,
84 runtime_context: &stored_ctx.services.runtime_context,
85 params: &stored_ctx.params,
86 };
87 let mut columns = self.apply(&transform_ctx, columns)?;
88 if columns.row_count() > 0 {
89 strip_udf_columns(&mut columns, &self.udf_names);
90 return Ok(Some(columns));
91 }
92 }
93 None => return Ok(None),
94 }
95 }
96 }
97
98 fn headers(&self) -> Option<ColumnHeaders> {
99 self.input.headers()
100 }
101}
102
103impl Transform for FilterNode {
104 fn apply(&self, ctx: &TransformContext, input: Columns) -> Result<Columns> {
105 let (stored_ctx, compiled) =
106 self.context.as_ref().expect("FilterNode::apply() called before initialize()");
107
108 let session = EvalContext::from_transform(ctx, stored_ctx);
109 let mut columns = input;
110 let mut row_count = columns.row_count();
111
112 for compiled_expr in compiled {
113 if row_count == 0 {
114 break;
115 }
116
117 let exec_ctx = session.with_eval(columns.clone(), row_count);
118
119 let result = compiled_expr.execute(&exec_ctx)?;
120
121 let filter_mask = match result.data() {
122 ColumnBuffer::Bool(container) => {
123 let mut mask = BitVec::repeat(row_count, false);
124 for i in 0..row_count {
125 if i < container.len() {
126 let valid = container.is_defined(i);
127 let filter_result = container.data().get(i);
128 mask.set(i, valid & filter_result);
129 }
130 }
131 mask
132 }
133 ColumnBuffer::Option {
134 inner,
135 bitvec,
136 } => match inner.as_ref() {
137 ColumnBuffer::Bool(container) => {
138 let mut mask = BitVec::repeat(row_count, false);
139 for i in 0..row_count {
140 let defined = i < bitvec.len() && bitvec.get(i);
141 let valid = defined && container.is_defined(i);
142 let value = valid && container.data().get(i);
143 mask.set(i, value);
144 }
145 mask
146 }
147 _ => panic!("filter expression must evaluate to a boolean column"),
148 },
149 _ => panic!("filter expression must evaluate to a boolean column"),
150 };
151
152 columns.filter(&filter_mask)?;
153 row_count = columns.row_count();
154 }
155
156 Ok(columns)
157 }
158}
159
160pub(crate) fn resolve_is_variant_tags(
161 expr: &mut Expression,
162 source: &ResolvedShape,
163 catalog: &Catalog,
164 rx: &mut Transaction<'_>,
165) -> Result<()> {
166 match expr {
167 Expression::IsVariant(e) => {
168 let col_name = match e.expression.as_ref() {
169 Expression::Column(c) => c.0.name.text().to_string(),
170 other => display_label(other).text().to_string(),
171 };
172
173 let tag_col_name = format!("{}_tag", col_name);
174 let columns = source.columns();
175 if let Some(tag_col) = columns.iter().find(|c| c.name == tag_col_name)
176 && let Some(Constraint::SumType(id)) = tag_col.constraint.constraint()
177 {
178 let def = catalog.get_sumtype(rx, *id)?;
179 let variant_name = e.variant_name.text().to_lowercase();
180 if let Some(variant) =
181 def.variants.iter().find(|v| v.name.to_lowercase() == variant_name)
182 {
183 e.tag = Some(variant.tag);
184 }
185 }
186 resolve_is_variant_tags(&mut e.expression, source, catalog, rx)?;
187 }
188 Expression::And(e) => {
189 resolve_is_variant_tags(&mut e.left, source, catalog, rx)?;
190 resolve_is_variant_tags(&mut e.right, source, catalog, rx)?;
191 }
192 Expression::Or(e) => {
193 resolve_is_variant_tags(&mut e.left, source, catalog, rx)?;
194 resolve_is_variant_tags(&mut e.right, source, catalog, rx)?;
195 }
196 Expression::Equal(e) => {
197 resolve_is_variant_tags(&mut e.left, source, catalog, rx)?;
198 resolve_is_variant_tags(&mut e.right, source, catalog, rx)?;
199 }
200 Expression::NotEqual(e) => {
201 resolve_is_variant_tags(&mut e.left, source, catalog, rx)?;
202 resolve_is_variant_tags(&mut e.right, source, catalog, rx)?;
203 }
204 Expression::Prefix(e) => {
205 resolve_is_variant_tags(&mut e.expression, source, catalog, rx)?;
206 }
207 Expression::If(e) => {
208 resolve_is_variant_tags(&mut e.condition, source, catalog, rx)?;
209 resolve_is_variant_tags(&mut e.then_expr, source, catalog, rx)?;
210 for else_if in &mut e.else_ifs {
211 resolve_is_variant_tags(&mut else_if.condition, source, catalog, rx)?;
212 resolve_is_variant_tags(&mut else_if.then_expr, source, catalog, rx)?;
213 }
214 if let Some(else_expr) = &mut e.else_expr {
215 resolve_is_variant_tags(else_expr, source, catalog, rx)?;
216 }
217 }
218 Expression::Alias(e) => {
219 resolve_is_variant_tags(&mut e.expression, source, catalog, rx)?;
220 }
221 _ => {}
222 }
223 Ok(())
224}