1use reifydb_core::value::column::{array::Column, buffer::ColumnBuffer, mask::RowMask};
5use reifydb_type::{Result, value::Value};
6
7use crate::{
8 compute::{self, CompareOp},
9 error::ColumnError,
10 selection::Selection,
11 snapshot::{ColumnBlock, ColumnChunks},
12};
13
/// Name of a column referenced by a [`Predicate`].
///
/// The wrapped string is the column name that is looked up in the block's
/// schema when the predicate is evaluated.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ColRef(pub String);

impl From<&str> for ColRef {
    /// Copies the borrowed name into an owned `ColRef`.
    fn from(name: &str) -> Self {
        ColRef(name.to_owned())
    }
}

impl From<String> for ColRef {
    /// Wraps an already-owned name without copying it.
    fn from(name: String) -> Self {
        ColRef(name)
    }
}
28
29#[derive(Clone, Debug)]
30pub enum Predicate {
31 Eq(ColRef, Value),
32 Ne(ColRef, Value),
33 Lt(ColRef, Value),
34 LtEq(ColRef, Value),
35 Gt(ColRef, Value),
36 GtEq(ColRef, Value),
37 In(ColRef, Vec<Value>),
38 IsNone(ColRef),
39 IsNotNone(ColRef),
40 And(Vec<Predicate>),
41 Or(Vec<Predicate>),
42 Not(Box<Predicate>),
43}
44
45pub fn evaluate(block: &ColumnBlock, predicate: &Predicate) -> Result<Selection> {
51 let len = block.len();
52 let mask = evaluate_mask(block, predicate, len)?;
53 Ok(mask_to_selection(mask))
54}
55
56fn evaluate_mask(block: &ColumnBlock, predicate: &Predicate, len: usize) -> Result<RowMask> {
57 match predicate {
58 Predicate::Eq(col, v) => compare_mask(block, col, v, CompareOp::Eq),
59 Predicate::Ne(col, v) => compare_mask(block, col, v, CompareOp::Ne),
60 Predicate::Lt(col, v) => compare_mask(block, col, v, CompareOp::Lt),
61 Predicate::LtEq(col, v) => compare_mask(block, col, v, CompareOp::LtEq),
62 Predicate::Gt(col, v) => compare_mask(block, col, v, CompareOp::Gt),
63 Predicate::GtEq(col, v) => compare_mask(block, col, v, CompareOp::GtEq),
64 Predicate::In(col, values) => {
65 let mut acc = RowMask::none_set(len);
66 for v in values {
67 acc = acc.or(&compare_mask(block, col, v, CompareOp::Eq)?);
68 }
69 Ok(acc)
70 }
71 Predicate::IsNone(col) => Ok(is_none_mask(column(block, col)?)),
72 Predicate::IsNotNone(col) => Ok(is_none_mask(column(block, col)?).not()),
73 Predicate::And(clauses) => {
74 let mut acc = RowMask::all_set(len);
75 for c in clauses {
76 acc = acc.and(&evaluate_mask(block, c, len)?);
77 }
78 Ok(acc)
79 }
80 Predicate::Or(clauses) => {
81 let mut acc = RowMask::none_set(len);
82 for c in clauses {
83 acc = acc.or(&evaluate_mask(block, c, len)?);
84 }
85 Ok(acc)
86 }
87 Predicate::Not(inner) => Ok(evaluate_mask(block, inner, len)?.not()),
88 }
89}
90
91fn compare_mask(block: &ColumnBlock, col: &ColRef, rhs: &Value, op: CompareOp) -> Result<RowMask> {
92 let ch = column(block, col)?;
93 if ch.chunks.is_empty() {
94 return Ok(RowMask::none_set(0));
95 }
96 let mut parts = Vec::with_capacity(ch.chunks.len());
97 for chunk in &ch.chunks {
98 let result = compute::compare(chunk, rhs, op)?;
101 parts.push(bool_array_to_mask(&result)?);
102 }
103 Ok(RowMask::concat(&parts))
104}
105
106fn is_none_mask(ch: &ColumnChunks) -> RowMask {
107 let total = ch.len();
108 let mut mask = RowMask::none_set(total);
109 let mut row_offset = 0;
110 for chunk in &ch.chunks {
111 if let Some(nones) = chunk.nones() {
112 for i in 0..chunk.len() {
113 if nones.is_none(i) {
114 mask.set(row_offset + i, true);
115 }
116 }
117 }
118 row_offset += chunk.len();
119 }
120 mask
121}
122
123fn column<'a>(block: &'a ColumnBlock, col: &ColRef) -> Result<&'a ColumnChunks> {
124 block.column_by_name(&col.0).map(|(_, ch)| ch).ok_or_else(|| {
125 ColumnError::ColumnNotInSchema {
126 operation: "predicate::evaluate",
127 name: col.0.clone(),
128 }
129 .into()
130 })
131}
132
133fn bool_array_to_mask(array: &Column) -> Result<RowMask> {
137 let canon = array.to_canonical()?;
138 if !matches!(canon.buffer, ColumnBuffer::Bool(_)) {
139 return Err(ColumnError::PredicateCompareNotBool.into());
140 }
141 let len = canon.len();
142 let mut mask = RowMask::none_set(len);
143 let nones = canon.nones.as_ref();
144 for i in 0..len {
145 let is_true = matches!(canon.buffer.get_value(i), Value::Boolean(true));
146 if is_true && !nones.map(|n| n.is_none(i)).unwrap_or(false) {
147 mask.set(i, true);
148 }
149 }
150 Ok(mask)
151}
152
153fn mask_to_selection(mask: RowMask) -> Selection {
154 let kept = mask.popcount();
155 if kept == 0 {
156 Selection::None_
157 } else if kept == mask.len() {
158 Selection::All
159 } else {
160 Selection::Mask(mask)
161 }
162}
163
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use reifydb_core::value::column::{
        array::{Column, canonical::Canonical},
        buffer::ColumnBuffer,
    };
    use reifydb_type::value::r#type::Type;

    use super::*;

    /// Builds a five-row block with a non-nullable Int4 `id` column and a
    /// non-nullable Boolean `flag` column, each stored as a single chunk.
    fn mkblock(rows: [(i32, bool); 5]) -> ColumnBlock {
        let ids = ColumnBuffer::int4(rows.map(|(v, _)| v).to_vec());
        let flags = ColumnBuffer::bool(rows.map(|(_, v)| v).to_vec());
        let id_col = ColumnChunks::single(
            Type::Int4,
            false,
            Column::from_canonical(Canonical::from_column_buffer(&ids).unwrap()),
        );
        let flag_col = ColumnChunks::single(
            Type::Boolean,
            false,
            Column::from_canonical(Canonical::from_column_buffer(&flags).unwrap()),
        );
        let schema = Arc::new(vec![
            ("id".to_string(), Type::Int4, false),
            ("flag".to_string(), Type::Boolean, false),
        ]);
        ColumnBlock::new(schema, vec![id_col, flag_col])
    }

    /// Unwraps a `Selection::Mask`, panicking for any other selection.
    fn expect_mask(selection: Selection) -> RowMask {
        match selection {
            Selection::Mask(mask) => mask,
            _ => panic!("expected Mask selection"),
        }
    }

    /// Builds one Int4 chunk in which `None` entries are pushed as none rows.
    fn nullable_int4_chunk(values: &[Option<i32>]) -> Column {
        let mut buffer = ColumnBuffer::int4_with_capacity(values.len());
        for value in values {
            match value {
                Some(v) => {
                    buffer.push::<i32>(*v);
                }
                None => {
                    buffer.push_none();
                }
            }
        }
        Column::from_canonical(Canonical::from_column_buffer(&buffer).unwrap())
    }

    /// Builds a block with a single nullable Int4 `id` column made of `chunks`.
    fn nullable_id_block(chunks: Vec<Column>) -> ColumnBlock {
        let id_col = ColumnChunks::new(Type::Int4, true, chunks);
        let schema = Arc::new(vec![("id".to_string(), Type::Int4, true)]);
        ColumnBlock::new(schema, vec![id_col])
    }

    #[test]
    fn evaluate_eq_produces_mask() {
        let t = mkblock([(1, true), (2, false), (3, true), (2, true), (5, false)]);
        let p = Predicate::Eq(ColRef::from("id"), Value::Int4(2));
        let m = expect_mask(evaluate(&t, &p).unwrap());
        assert_eq!(m.popcount(), 2);
        assert!(m.get(1));
        assert!(m.get(3));
    }

    #[test]
    fn evaluate_all_collapses_to_selection_all() {
        let t = mkblock([(1, true), (2, true), (3, true), (4, true), (5, true)]);
        let p = Predicate::GtEq(ColRef::from("id"), Value::Int4(0));
        assert!(matches!(evaluate(&t, &p).unwrap(), Selection::All));
    }

    #[test]
    fn evaluate_none_collapses_to_selection_none() {
        let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
        let p = Predicate::Lt(ColRef::from("id"), Value::Int4(0));
        assert!(matches!(evaluate(&t, &p).unwrap(), Selection::None_));
    }

    #[test]
    fn evaluate_and_combines_with_intersection() {
        let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
        let p = Predicate::And(vec![
            Predicate::Gt(ColRef::from("id"), Value::Int4(1)),
            Predicate::Eq(ColRef::from("flag"), Value::Boolean(true)),
        ]);
        let m = expect_mask(evaluate(&t, &p).unwrap());
        assert_eq!(m.popcount(), 2);
        assert!(m.get(2));
        assert!(m.get(4));
    }

    #[test]
    fn evaluate_not_inverts_inner_mask() {
        let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
        let p = Predicate::Not(Box::new(Predicate::Eq(ColRef::from("id"), Value::Int4(2))));
        let m = expect_mask(evaluate(&t, &p).unwrap());
        assert_eq!(m.popcount(), 4);
        assert!(m.get(0));
        assert!(!m.get(1));
        assert!(m.get(2));
        assert!(m.get(3));
        assert!(m.get(4));
    }

    #[test]
    fn evaluate_in_matches_any_value() {
        let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
        let p = Predicate::In(ColRef::from("id"), vec![Value::Int4(2), Value::Int4(5)]);
        let m = expect_mask(evaluate(&t, &p).unwrap());
        assert_eq!(m.popcount(), 2);
        assert!(m.get(1));
        assert!(m.get(4));
    }

    #[test]
    fn evaluate_empty_lists_use_identity_masks() {
        let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
        // An empty IN list and an empty OR match nothing; an empty AND matches everything.
        let empty_in = Predicate::In(ColRef::from("id"), vec![]);
        assert!(matches!(evaluate(&t, &empty_in).unwrap(), Selection::None_));
        assert!(matches!(evaluate(&t, &Predicate::Or(vec![])).unwrap(), Selection::None_));
        assert!(matches!(evaluate(&t, &Predicate::And(vec![])).unwrap(), Selection::All));
    }

    #[test]
    fn evaluate_unknown_column_is_an_error() {
        let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
        let cmp = Predicate::Eq(ColRef::from("missing"), Value::Int4(1));
        assert!(evaluate(&t, &cmp).is_err());
        assert!(evaluate(&t, &Predicate::IsNone(ColRef::from("missing"))).is_err());
        assert!(evaluate(&t, &Predicate::IsNotNone(ColRef::from("missing"))).is_err());
    }

    #[test]
    fn evaluate_is_none_on_nullable_column() {
        let t = nullable_id_block(vec![nullable_int4_chunk(&[Some(10), None, Some(30), None])]);

        let m = expect_mask(evaluate(&t, &Predicate::IsNone(ColRef::from("id"))).unwrap());
        assert_eq!(m.popcount(), 2);
        assert!(m.get(1));
        assert!(m.get(3));
    }

    #[test]
    fn evaluate_is_not_none_on_nullable_column() {
        let t = nullable_id_block(vec![nullable_int4_chunk(&[Some(10), None, Some(30), None])]);

        let m = expect_mask(evaluate(&t, &Predicate::IsNotNone(ColRef::from("id"))).unwrap());
        assert_eq!(m.popcount(), 2);
        assert!(m.get(0));
        assert!(m.get(2));
    }

    /// Builds a non-nullable Int4 column with one chunk per slice in `parts`.
    fn int4_chunked(parts: &[&[i32]]) -> ColumnChunks {
        let chunks = parts
            .iter()
            .map(|p| {
                Column::from_canonical(
                    Canonical::from_column_buffer(&ColumnBuffer::int4(p.to_vec())).unwrap(),
                )
            })
            .collect();
        ColumnChunks::new(Type::Int4, false, chunks)
    }

    /// Builds a block whose only column is a multi-chunk Int4 `id` column.
    fn mkblock_chunked(id_parts: &[&[i32]]) -> ColumnBlock {
        let id_col = int4_chunked(id_parts);
        let schema = Arc::new(vec![("id".to_string(), Type::Int4, false)]);
        ColumnBlock::new(schema, vec![id_col])
    }

    #[test]
    fn evaluate_eq_over_multi_chunk_column() {
        let t = mkblock_chunked(&[&[1, 2, 3], &[2, 4, 2], &[5, 2]]);
        let p = Predicate::Eq(ColRef::from("id"), Value::Int4(2));
        let m = expect_mask(evaluate(&t, &p).unwrap());
        assert_eq!(m.len(), 8);
        assert_eq!(m.popcount(), 4);
        assert!(m.get(1));
        assert!(m.get(3));
        assert!(m.get(5));
        assert!(m.get(7));
    }

    #[test]
    fn evaluate_and_or_across_multi_chunk_columns() {
        let id_col = int4_chunked(&[&[1, 2, 3], &[4, 5, 6]]);
        let other_col = int4_chunked(&[&[10, 20, 10], &[20, 10, 20]]);
        let schema = Arc::new(vec![
            ("id".to_string(), Type::Int4, false),
            ("other".to_string(), Type::Int4, false),
        ]);
        let t = ColumnBlock::new(schema, vec![id_col, other_col]);

        let clauses = vec![
            Predicate::Gt(ColRef::from("id"), Value::Int4(2)),
            Predicate::Eq(ColRef::from("other"), Value::Int4(20)),
        ];

        // Intersection: id > 2 holds for rows 2..=5, other == 20 for rows 1, 3, 5.
        let m = expect_mask(evaluate(&t, &Predicate::And(clauses.clone())).unwrap());
        assert_eq!(m.len(), 6);
        assert_eq!(m.popcount(), 2);
        assert!(m.get(3));
        assert!(m.get(5));

        // Union of the same clauses: every row except row 0.
        let m = expect_mask(evaluate(&t, &Predicate::Or(clauses)).unwrap());
        assert_eq!(m.len(), 6);
        assert_eq!(m.popcount(), 5);
        assert!(!m.get(0));
        assert!(m.get(1));
        assert!(m.get(2));
        assert!(m.get(3));
        assert!(m.get(4));
        assert!(m.get(5));
    }

    #[test]
    fn evaluate_is_none_across_multi_chunk_nullable() {
        let t = nullable_id_block(vec![
            nullable_int4_chunk(&[Some(10), None, Some(30)]),
            nullable_int4_chunk(&[Some(40), None, Some(60)]),
        ]);

        let m = expect_mask(evaluate(&t, &Predicate::IsNone(ColRef::from("id"))).unwrap());
        assert_eq!(m.len(), 6);
        assert_eq!(m.popcount(), 2);
        assert!(m.get(1));
        assert!(m.get(4));
    }

    #[test]
    fn evaluate_in_across_multi_chunk_column() {
        let t = mkblock_chunked(&[&[1, 2], &[3, 4], &[5, 6]]);
        let p = Predicate::In(ColRef::from("id"), vec![Value::Int4(2), Value::Int4(5)]);
        let m = expect_mask(evaluate(&t, &p).unwrap());
        assert_eq!(m.len(), 6);
        assert_eq!(m.popcount(), 2);
        assert!(m.get(1));
        assert!(m.get(4));
    }
}