1use reifydb_core::value::column::{buffer::ColumnBuffer, data::Column, mask::RowMask};
5use reifydb_type::{Result, value::Value};
6
7use crate::{
8 compute::{self, CompareOp},
9 error::ColumnError,
10 selection::Selection,
11 snapshot::{ColumnBlock, ColumnChunks},
12};
13
/// Reference to a column by name, as used inside a [`Predicate`].
///
/// This is a plain newtype over the column name; the wrapped string is what
/// gets looked up in the block when a predicate is evaluated.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColRef(pub String);

impl From<String> for ColRef {
    /// Wraps an already-owned name without copying it.
    fn from(name: String) -> Self {
        ColRef(name)
    }
}

impl From<&str> for ColRef {
    /// Copies the borrowed name into a new owned `ColRef`.
    fn from(name: &str) -> Self {
        ColRef(name.to_owned())
    }
}
28
29#[derive(Clone, Debug)]
30pub enum Predicate {
31 Eq(ColRef, Value),
32 Ne(ColRef, Value),
33 Lt(ColRef, Value),
34 LtEq(ColRef, Value),
35 Gt(ColRef, Value),
36 GtEq(ColRef, Value),
37 In(ColRef, Vec<Value>),
38 IsNone(ColRef),
39 IsNotNone(ColRef),
40 And(Vec<Predicate>),
41 Or(Vec<Predicate>),
42 Not(Box<Predicate>),
43}
44
45pub fn evaluate(block: &ColumnBlock, predicate: &Predicate) -> Result<Selection> {
46 let len = block.len();
47 let mask = evaluate_mask(block, predicate, len)?;
48 Ok(mask_to_selection(mask))
49}
50
51fn evaluate_mask(block: &ColumnBlock, predicate: &Predicate, len: usize) -> Result<RowMask> {
52 match predicate {
53 Predicate::Eq(col, v) => compare_mask(block, col, v, CompareOp::Eq),
54 Predicate::Ne(col, v) => compare_mask(block, col, v, CompareOp::Ne),
55 Predicate::Lt(col, v) => compare_mask(block, col, v, CompareOp::Lt),
56 Predicate::LtEq(col, v) => compare_mask(block, col, v, CompareOp::LtEq),
57 Predicate::Gt(col, v) => compare_mask(block, col, v, CompareOp::Gt),
58 Predicate::GtEq(col, v) => compare_mask(block, col, v, CompareOp::GtEq),
59 Predicate::In(col, values) => {
60 let mut acc = RowMask::none_set(len);
61 for v in values {
62 acc = acc.or(&compare_mask(block, col, v, CompareOp::Eq)?);
63 }
64 Ok(acc)
65 }
66 Predicate::IsNone(col) => Ok(is_none_mask(column(block, col)?)),
67 Predicate::IsNotNone(col) => Ok(is_none_mask(column(block, col)?).not()),
68 Predicate::And(clauses) => {
69 let mut acc = RowMask::all_set(len);
70 for c in clauses {
71 acc = acc.and(&evaluate_mask(block, c, len)?);
72 }
73 Ok(acc)
74 }
75 Predicate::Or(clauses) => {
76 let mut acc = RowMask::none_set(len);
77 for c in clauses {
78 acc = acc.or(&evaluate_mask(block, c, len)?);
79 }
80 Ok(acc)
81 }
82 Predicate::Not(inner) => Ok(evaluate_mask(block, inner, len)?.not()),
83 }
84}
85
86fn compare_mask(block: &ColumnBlock, col: &ColRef, rhs: &Value, op: CompareOp) -> Result<RowMask> {
87 let ch = column(block, col)?;
88 if ch.chunks.is_empty() {
89 return Ok(RowMask::none_set(0));
90 }
91 let mut parts = Vec::with_capacity(ch.chunks.len());
92 for chunk in &ch.chunks {
93 let result = compute::compare(chunk, rhs, op)?;
94 parts.push(bool_array_to_mask(&result)?);
95 }
96 Ok(RowMask::concat(&parts))
97}
98
99fn is_none_mask(ch: &ColumnChunks) -> RowMask {
100 let total = ch.len();
101 let mut mask = RowMask::none_set(total);
102 let mut row_offset = 0;
103 for chunk in &ch.chunks {
104 if let Some(nones) = chunk.nones() {
105 for i in 0..chunk.len() {
106 if nones.is_none(i) {
107 mask.set(row_offset + i, true);
108 }
109 }
110 }
111 row_offset += chunk.len();
112 }
113 mask
114}
115
116fn column<'a>(block: &'a ColumnBlock, col: &ColRef) -> Result<&'a ColumnChunks> {
117 block.column_by_name(&col.0).map(|(_, ch)| ch).ok_or_else(|| {
118 ColumnError::ColumnNotInSchema {
119 operation: "predicate::evaluate",
120 name: col.0.clone(),
121 }
122 .into()
123 })
124}
125
126fn bool_array_to_mask(array: &Column) -> Result<RowMask> {
127 let canon = array.to_canonical()?;
128 if !matches!(canon.buffer, ColumnBuffer::Bool(_)) {
129 return Err(ColumnError::PredicateCompareNotBool.into());
130 }
131 let len = canon.len();
132 let mut mask = RowMask::none_set(len);
133 let nones = canon.nones.as_ref();
134 for i in 0..len {
135 let is_true = matches!(canon.buffer.get_value(i), Value::Boolean(true));
136 if is_true && !nones.map(|n| n.is_none(i)).unwrap_or(false) {
137 mask.set(i, true);
138 }
139 }
140 Ok(mask)
141}
142
143fn mask_to_selection(mask: RowMask) -> Selection {
144 let kept = mask.popcount();
145 if kept == 0 {
146 Selection::None_
147 } else if kept == mask.len() {
148 Selection::All
149 } else {
150 Selection::Mask(mask)
151 }
152}
153
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use reifydb_core::value::column::{
        buffer::ColumnBuffer,
        data::{Column, canonical::Canonical},
    };
    use reifydb_type::value::r#type::Type;

    use super::*;

    /// Unwraps the mask of a partial selection; panics on `All` / `None_`.
    fn expect_mask(selection: Selection) -> RowMask {
        let Selection::Mask(mask) = selection else {
            panic!("expected Mask selection");
        };
        mask
    }

    /// Wraps a column buffer into a single canonical chunk.
    fn canonical_chunk(buffer: &ColumnBuffer) -> Column {
        Column::from_canonical(Canonical::from_column_buffer(buffer).unwrap())
    }

    /// Builds one nullable Int4 chunk; `None` entries become none rows.
    fn nullable_int4_chunk(values: &[Option<i32>]) -> Column {
        let mut buffer = ColumnBuffer::int4_with_capacity(values.len());
        for value in values {
            match value {
                Some(v) => {
                    buffer.push::<i32>(*v);
                }
                None => {
                    buffer.push_none();
                }
            }
        }
        canonical_chunk(&buffer)
    }

    /// Five-row block with a non-nullable Int4 `id` and Boolean `flag` column,
    /// each stored as a single chunk.
    fn mkblock(rows: [(i32, bool); 5]) -> ColumnBlock {
        let ids = ColumnBuffer::int4(rows.map(|(v, _)| v).to_vec());
        let flags = ColumnBuffer::bool(rows.map(|(_, v)| v).to_vec());
        let id_col = ColumnChunks::single(Type::Int4, false, canonical_chunk(&ids));
        let flag_col = ColumnChunks::single(Type::Boolean, false, canonical_chunk(&flags));
        let schema = Arc::new(vec![
            ("id".to_string(), Type::Int4, false),
            ("flag".to_string(), Type::Boolean, false),
        ]);
        ColumnBlock::new(schema, vec![id_col, flag_col])
    }

    /// Non-nullable Int4 column made of one chunk per slice in `parts`.
    fn int4_chunked(parts: &[&[i32]]) -> ColumnChunks {
        let chunks = parts
            .iter()
            .map(|p| canonical_chunk(&ColumnBuffer::int4(p.to_vec())))
            .collect();
        ColumnChunks::new(Type::Int4, false, chunks)
    }

    /// Block with a single multi-chunk Int4 `id` column.
    fn mkblock_chunked(id_parts: &[&[i32]]) -> ColumnBlock {
        let id_col = int4_chunked(id_parts);
        let schema = Arc::new(vec![("id".to_string(), Type::Int4, false)]);
        ColumnBlock::new(schema, vec![id_col])
    }

    #[test]
    fn evaluate_eq_produces_mask() {
        let t = mkblock([(1, true), (2, false), (3, true), (2, true), (5, false)]);
        let p = Predicate::Eq(ColRef::from("id"), Value::Int4(2));
        let m = expect_mask(evaluate(&t, &p).unwrap());
        assert_eq!(m.popcount(), 2);
        assert!(m.get(1));
        assert!(m.get(3));
    }

    #[test]
    fn evaluate_all_collapses_to_selection_all() {
        let t = mkblock([(1, true), (2, true), (3, true), (4, true), (5, true)]);
        let p = Predicate::GtEq(ColRef::from("id"), Value::Int4(0));
        assert!(matches!(evaluate(&t, &p).unwrap(), Selection::All));
    }

    #[test]
    fn evaluate_none_collapses_to_selection_none() {
        let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
        let p = Predicate::Lt(ColRef::from("id"), Value::Int4(0));
        assert!(matches!(evaluate(&t, &p).unwrap(), Selection::None_));
    }

    #[test]
    fn evaluate_and_combines_with_intersection() {
        let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
        let p = Predicate::And(vec![
            Predicate::Gt(ColRef::from("id"), Value::Int4(1)),
            Predicate::Eq(ColRef::from("flag"), Value::Boolean(true)),
        ]);
        let m = expect_mask(evaluate(&t, &p).unwrap());
        assert_eq!(m.popcount(), 2);
        assert!(m.get(2));
        assert!(m.get(4));
    }

    #[test]
    fn evaluate_in_matches_any_value() {
        let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
        let p = Predicate::In(ColRef::from("id"), vec![Value::Int4(2), Value::Int4(5)]);
        let m = expect_mask(evaluate(&t, &p).unwrap());
        assert_eq!(m.popcount(), 2);
        assert!(m.get(1));
        assert!(m.get(4));
    }

    #[test]
    fn evaluate_unknown_column_is_an_error() {
        let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
        let p = Predicate::Eq(ColRef::from("missing"), Value::Int4(1));
        assert!(evaluate(&t, &p).is_err());
    }

    #[test]
    fn evaluate_is_none_on_nullable_column() {
        let id_col = ColumnChunks::single(
            Type::Int4,
            true,
            nullable_int4_chunk(&[Some(10), None, Some(30), None]),
        );
        let schema = Arc::new(vec![("id".to_string(), Type::Int4, true)]);
        let t = ColumnBlock::new(schema, vec![id_col]);

        let m = expect_mask(evaluate(&t, &Predicate::IsNone(ColRef::from("id"))).unwrap());
        assert_eq!(m.popcount(), 2);
        assert!(m.get(1));
        assert!(m.get(3));
    }

    #[test]
    fn evaluate_eq_over_multi_chunk_column() {
        let t = mkblock_chunked(&[&[1, 2, 3], &[2, 4, 2], &[5, 2]]);
        let p = Predicate::Eq(ColRef::from("id"), Value::Int4(2));
        let m = expect_mask(evaluate(&t, &p).unwrap());
        assert_eq!(m.len(), 8);
        assert_eq!(m.popcount(), 4);
        assert!(m.get(1));
        assert!(m.get(3));
        assert!(m.get(5));
        assert!(m.get(7));
    }

    #[test]
    fn evaluate_and_or_across_multi_chunk_columns() {
        let id_col = int4_chunked(&[&[1, 2, 3], &[4, 5, 6]]);
        let other_col = int4_chunked(&[&[10, 20, 10], &[20, 10, 20]]);
        let schema = Arc::new(vec![
            ("id".to_string(), Type::Int4, false),
            ("other".to_string(), Type::Int4, false),
        ]);
        let t = ColumnBlock::new(schema, vec![id_col, other_col]);

        // id > 2 keeps rows 2..=5; other == 20 keeps rows 1, 3, 5.
        let both = Predicate::And(vec![
            Predicate::Gt(ColRef::from("id"), Value::Int4(2)),
            Predicate::Eq(ColRef::from("other"), Value::Int4(20)),
        ]);
        let m = expect_mask(evaluate(&t, &both).unwrap());
        assert_eq!(m.len(), 6);
        assert_eq!(m.popcount(), 2);
        assert!(m.get(3));
        assert!(m.get(5));

        // id < 2 keeps row 0; other == 20 keeps rows 1, 3, 5.
        let either = Predicate::Or(vec![
            Predicate::Lt(ColRef::from("id"), Value::Int4(2)),
            Predicate::Eq(ColRef::from("other"), Value::Int4(20)),
        ]);
        let m = expect_mask(evaluate(&t, &either).unwrap());
        assert_eq!(m.len(), 6);
        assert_eq!(m.popcount(), 4);
        assert!(m.get(0));
        assert!(m.get(1));
        assert!(m.get(3));
        assert!(m.get(5));
        assert!(!m.get(2));
        assert!(!m.get(4));
    }

    #[test]
    fn evaluate_is_none_across_multi_chunk_nullable() {
        let chunks = vec![
            nullable_int4_chunk(&[Some(10), None, Some(30)]),
            nullable_int4_chunk(&[Some(40), None, Some(60)]),
        ];
        let id_col = ColumnChunks::new(Type::Int4, true, chunks);
        let schema = Arc::new(vec![("id".to_string(), Type::Int4, true)]);
        let t = ColumnBlock::new(schema, vec![id_col]);

        let m = expect_mask(evaluate(&t, &Predicate::IsNone(ColRef::from("id"))).unwrap());
        assert_eq!(m.len(), 6);
        assert_eq!(m.popcount(), 2);
        assert!(m.get(1));
        assert!(m.get(4));
    }

    #[test]
    fn evaluate_in_across_multi_chunk_column() {
        let t = mkblock_chunked(&[&[1, 2], &[3, 4], &[5, 6]]);
        let p = Predicate::In(ColRef::from("id"), vec![Value::Int4(2), Value::Int4(5)]);
        let m = expect_mask(evaluate(&t, &p).unwrap());
        assert_eq!(m.len(), 6);
        assert_eq!(m.popcount(), 2);
        assert!(m.get(1));
        assert!(m.get(4));
    }
}