Skip to main content

reifydb_column/
predicate.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::{array::Column, buffer::ColumnBuffer, mask::RowMask};
5use reifydb_type::{Result, value::Value};
6
7use crate::{
8	compute::{self, CompareOp},
9	error::ColumnError,
10	selection::Selection,
11	snapshot::{ColumnBlock, ColumnChunks},
12};
13
/// Names a column that a [`Predicate`] reads.
///
/// The name is looked up in the block schema when the predicate is
/// evaluated; an unknown name surfaces as an evaluation error.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ColRef(pub String);

impl From<String> for ColRef {
	/// Takes ownership of an already-owned column name.
	fn from(name: String) -> Self {
		ColRef(name)
	}
}

impl From<&str> for ColRef {
	/// Copies a borrowed column name into a new `ColRef`.
	fn from(name: &str) -> Self {
		ColRef::from(name.to_owned())
	}
}
28
29#[derive(Clone, Debug)]
30pub enum Predicate {
31	Eq(ColRef, Value),
32	Ne(ColRef, Value),
33	Lt(ColRef, Value),
34	LtEq(ColRef, Value),
35	Gt(ColRef, Value),
36	GtEq(ColRef, Value),
37	In(ColRef, Vec<Value>),
38	IsNone(ColRef),
39	IsNotNone(ColRef),
40	And(Vec<Predicate>),
41	Or(Vec<Predicate>),
42	Not(Box<Predicate>),
43}
44
45// Evaluate a `Predicate` over a `ColumnBlock`, producing a `Selection` that
46// callers can feed to `compute::filter`. Per-column iteration walks each chunk
47// independently and dispatches through `compute::compare`, so encoding-specific
48// kernels stay live - canonicalization happens only on the bool result we have
49// to inspect bit-by-bit, not on the input columns.
50pub fn evaluate(block: &ColumnBlock, predicate: &Predicate) -> Result<Selection> {
51	let len = block.len();
52	let mask = evaluate_mask(block, predicate, len)?;
53	Ok(mask_to_selection(mask))
54}
55
56fn evaluate_mask(block: &ColumnBlock, predicate: &Predicate, len: usize) -> Result<RowMask> {
57	match predicate {
58		Predicate::Eq(col, v) => compare_mask(block, col, v, CompareOp::Eq),
59		Predicate::Ne(col, v) => compare_mask(block, col, v, CompareOp::Ne),
60		Predicate::Lt(col, v) => compare_mask(block, col, v, CompareOp::Lt),
61		Predicate::LtEq(col, v) => compare_mask(block, col, v, CompareOp::LtEq),
62		Predicate::Gt(col, v) => compare_mask(block, col, v, CompareOp::Gt),
63		Predicate::GtEq(col, v) => compare_mask(block, col, v, CompareOp::GtEq),
64		Predicate::In(col, values) => {
65			let mut acc = RowMask::none_set(len);
66			for v in values {
67				acc = acc.or(&compare_mask(block, col, v, CompareOp::Eq)?);
68			}
69			Ok(acc)
70		}
71		Predicate::IsNone(col) => Ok(is_none_mask(column(block, col)?)),
72		Predicate::IsNotNone(col) => Ok(is_none_mask(column(block, col)?).not()),
73		Predicate::And(clauses) => {
74			let mut acc = RowMask::all_set(len);
75			for c in clauses {
76				acc = acc.and(&evaluate_mask(block, c, len)?);
77			}
78			Ok(acc)
79		}
80		Predicate::Or(clauses) => {
81			let mut acc = RowMask::none_set(len);
82			for c in clauses {
83				acc = acc.or(&evaluate_mask(block, c, len)?);
84			}
85			Ok(acc)
86		}
87		Predicate::Not(inner) => Ok(evaluate_mask(block, inner, len)?.not()),
88	}
89}
90
91fn compare_mask(block: &ColumnBlock, col: &ColRef, rhs: &Value, op: CompareOp) -> Result<RowMask> {
92	let ch = column(block, col)?;
93	if ch.chunks.is_empty() {
94		return Ok(RowMask::none_set(0));
95	}
96	let mut parts = Vec::with_capacity(ch.chunks.len());
97	for chunk in &ch.chunks {
98		// `compute::compare` routes through encoding-specific specialization, so
99		// compressed encodings can run the comparison without canonicalizing.
100		let result = compute::compare(chunk, rhs, op)?;
101		parts.push(bool_array_to_mask(&result)?);
102	}
103	Ok(RowMask::concat(&parts))
104}
105
106fn is_none_mask(ch: &ColumnChunks) -> RowMask {
107	let total = ch.len();
108	let mut mask = RowMask::none_set(total);
109	let mut row_offset = 0;
110	for chunk in &ch.chunks {
111		if let Some(nones) = chunk.nones() {
112			for i in 0..chunk.len() {
113				if nones.is_none(i) {
114					mask.set(row_offset + i, true);
115				}
116			}
117		}
118		row_offset += chunk.len();
119	}
120	mask
121}
122
123fn column<'a>(block: &'a ColumnBlock, col: &ColRef) -> Result<&'a ColumnChunks> {
124	block.column_by_name(&col.0).map(|(_, ch)| ch).ok_or_else(|| {
125		ColumnError::ColumnNotInSchema {
126			operation: "predicate::evaluate",
127			name: col.0.clone(),
128		}
129		.into()
130	})
131}
132
133// Convert a bool canonical `Column` to a `RowMask`. None-valued rows count as
134// "not selected" - three-valued-logic collapses to a two-valued mask at the
135// `Selection` boundary.
136fn bool_array_to_mask(array: &Column) -> Result<RowMask> {
137	let canon = array.to_canonical()?;
138	if !matches!(canon.buffer, ColumnBuffer::Bool(_)) {
139		return Err(ColumnError::PredicateCompareNotBool.into());
140	}
141	let len = canon.len();
142	let mut mask = RowMask::none_set(len);
143	let nones = canon.nones.as_ref();
144	for i in 0..len {
145		let is_true = matches!(canon.buffer.get_value(i), Value::Boolean(true));
146		if is_true && !nones.map(|n| n.is_none(i)).unwrap_or(false) {
147			mask.set(i, true);
148		}
149	}
150	Ok(mask)
151}
152
153fn mask_to_selection(mask: RowMask) -> Selection {
154	let kept = mask.popcount();
155	if kept == 0 {
156		Selection::None_
157	} else if kept == mask.len() {
158		Selection::All
159	} else {
160		Selection::Mask(mask)
161	}
162}
163
#[cfg(test)]
mod tests {
	use std::sync::Arc;

	use reifydb_core::value::column::{
		array::{Column, canonical::Canonical},
		buffer::ColumnBuffer,
	};
	use reifydb_type::value::r#type::Type;

	use super::*;

	/// Builds a 5-row block with a non-nullable `id: Int4` and `flag: Boolean`.
	fn mkblock(rows: [(i32, bool); 5]) -> ColumnBlock {
		let ids = ColumnBuffer::int4(rows.map(|(v, _)| v).to_vec());
		let flags = ColumnBuffer::bool(rows.map(|(_, v)| v).to_vec());
		let id_col = ColumnChunks::single(
			Type::Int4,
			false,
			Column::from_canonical(Canonical::from_column_buffer(&ids).unwrap()),
		);
		let flag_col = ColumnChunks::single(
			Type::Boolean,
			false,
			Column::from_canonical(Canonical::from_column_buffer(&flags).unwrap()),
		);
		let schema = Arc::new(vec![
			("id".to_string(), Type::Int4, false),
			("flag".to_string(), Type::Boolean, false),
		]);
		ColumnBlock::new(schema, vec![id_col, flag_col])
	}

	#[test]
	fn evaluate_eq_produces_mask() {
		let t = mkblock([(1, true), (2, false), (3, true), (2, true), (5, false)]);
		let p = Predicate::Eq(ColRef::from("id"), Value::Int4(2));
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.popcount(), 2);
		assert!(m.get(1));
		assert!(m.get(3));
	}

	#[test]
	fn evaluate_all_collapses_to_selection_all() {
		let t = mkblock([(1, true), (2, true), (3, true), (4, true), (5, true)]);
		let p = Predicate::GtEq(ColRef::from("id"), Value::Int4(0));
		assert!(matches!(evaluate(&t, &p).unwrap(), Selection::All));
	}

	#[test]
	fn evaluate_none_collapses_to_selection_none() {
		let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
		let p = Predicate::Lt(ColRef::from("id"), Value::Int4(0));
		assert!(matches!(evaluate(&t, &p).unwrap(), Selection::None_));
	}

	#[test]
	fn evaluate_and_combines_with_intersection() {
		let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
		let p = Predicate::And(vec![
			Predicate::Gt(ColRef::from("id"), Value::Int4(1)),
			Predicate::Eq(ColRef::from("flag"), Value::Boolean(true)),
		]);
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.popcount(), 2);
		assert!(m.get(2));
		assert!(m.get(4));
	}

	#[test]
	fn evaluate_in_matches_any_value() {
		let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
		let p = Predicate::In(ColRef::from("id"), vec![Value::Int4(2), Value::Int4(5)]);
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.popcount(), 2);
		assert!(m.get(1));
		assert!(m.get(4));
	}

	#[test]
	fn evaluate_is_none_on_nullable_column() {
		let mut nullable_ids = ColumnBuffer::int4_with_capacity(4);
		nullable_ids.push::<i32>(10);
		nullable_ids.push_none();
		nullable_ids.push::<i32>(30);
		nullable_ids.push_none();
		let id_col = ColumnChunks::single(
			Type::Int4,
			true,
			Column::from_canonical(Canonical::from_column_buffer(&nullable_ids).unwrap()),
		);
		let schema = Arc::new(vec![("id".to_string(), Type::Int4, true)]);
		let t = ColumnBlock::new(schema, vec![id_col]);

		let Selection::Mask(m) = evaluate(&t, &Predicate::IsNone(ColRef::from("id"))).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.popcount(), 2);
		assert!(m.get(1));
		assert!(m.get(3));
	}

	/// Builds a non-nullable Int4 column with one chunk per slice in `parts`.
	fn int4_chunked(parts: &[&[i32]]) -> ColumnChunks {
		let chunks = parts
			.iter()
			.map(|p| {
				Column::from_canonical(
					Canonical::from_column_buffer(&ColumnBuffer::int4(p.to_vec())).unwrap(),
				)
			})
			.collect();
		ColumnChunks::new(Type::Int4, false, chunks)
	}

	fn mkblock_chunked(id_parts: &[&[i32]]) -> ColumnBlock {
		let id_col = int4_chunked(id_parts);
		let schema = Arc::new(vec![("id".to_string(), Type::Int4, false)]);
		ColumnBlock::new(schema, vec![id_col])
	}

	#[test]
	fn evaluate_eq_over_multi_chunk_column() {
		// id chunks: [1, 2, 3] | [2, 4, 2] | [5, 2]. Looking for id == 2.
		let t = mkblock_chunked(&[&[1, 2, 3], &[2, 4, 2], &[5, 2]]);
		let p = Predicate::Eq(ColRef::from("id"), Value::Int4(2));
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.len(), 8);
		assert_eq!(m.popcount(), 4);
		assert!(m.get(1));
		assert!(m.get(3));
		assert!(m.get(5));
		assert!(m.get(7));
	}

	#[test]
	fn evaluate_and_or_across_multi_chunk_columns() {
		// Both columns are 2 chunks of length 3. AND/OR must align across chunk boundaries.
		let id_col = int4_chunked(&[&[1, 2, 3], &[4, 5, 6]]);
		let other_col = int4_chunked(&[&[10, 20, 10], &[20, 10, 20]]);
		let schema =
			Arc::new(vec![("id".to_string(), Type::Int4, false), ("other".to_string(), Type::Int4, false)]);
		let t = ColumnBlock::new(schema, vec![id_col, other_col]);

		let p = Predicate::And(vec![
			Predicate::Gt(ColRef::from("id"), Value::Int4(2)),
			Predicate::Eq(ColRef::from("other"), Value::Int4(20)),
		]);
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		// id > 2 → rows 2,3,4,5; other == 20 → rows 1,3,5. Intersection: rows 3, 5.
		assert_eq!(m.len(), 6);
		assert_eq!(m.popcount(), 2);
		assert!(m.get(3));
		assert!(m.get(5));
	}

	#[test]
	fn evaluate_is_none_across_multi_chunk_nullable() {
		// Two nullable chunks; nones at row 1 of each chunk → block rows 1 and 4.
		let mut a = ColumnBuffer::int4_with_capacity(3);
		a.push::<i32>(10);
		a.push_none();
		a.push::<i32>(30);
		let mut b = ColumnBuffer::int4_with_capacity(3);
		b.push::<i32>(40);
		b.push_none();
		b.push::<i32>(60);
		let chunks = vec![
			Column::from_canonical(Canonical::from_column_buffer(&a).unwrap()),
			Column::from_canonical(Canonical::from_column_buffer(&b).unwrap()),
		];
		let id_col = ColumnChunks::new(Type::Int4, true, chunks);
		let schema = Arc::new(vec![("id".to_string(), Type::Int4, true)]);
		let t = ColumnBlock::new(schema, vec![id_col]);

		let Selection::Mask(m) = evaluate(&t, &Predicate::IsNone(ColRef::from("id"))).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.len(), 6);
		assert_eq!(m.popcount(), 2);
		assert!(m.get(1));
		assert!(m.get(4));
	}

	#[test]
	fn evaluate_in_across_multi_chunk_column() {
		let t = mkblock_chunked(&[&[1, 2], &[3, 4], &[5, 6]]);
		let p = Predicate::In(ColRef::from("id"), vec![Value::Int4(2), Value::Int4(5)]);
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.len(), 6);
		assert_eq!(m.popcount(), 2);
		assert!(m.get(1));
		assert!(m.get(4));
	}

	#[test]
	fn evaluate_ne_excludes_matching_rows() {
		let t = mkblock([(1, true), (2, false), (3, true), (2, true), (5, false)]);
		let p = Predicate::Ne(ColRef::from("id"), Value::Int4(2));
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.popcount(), 3);
		assert!(m.get(0));
		assert!(m.get(2));
		assert!(m.get(4));
	}

	#[test]
	fn evaluate_lteq_is_inclusive() {
		let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
		let p = Predicate::LtEq(ColRef::from("id"), Value::Int4(2));
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.popcount(), 2);
		assert!(m.get(0));
		assert!(m.get(1));
	}

	#[test]
	fn evaluate_or_combines_with_union() {
		let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
		let p = Predicate::Or(vec![
			Predicate::Eq(ColRef::from("id"), Value::Int4(1)),
			Predicate::Eq(ColRef::from("flag"), Value::Boolean(false)),
		]);
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		// id == 1 → row 0; flag == false → rows 1, 3. Union: rows 0, 1, 3.
		assert_eq!(m.popcount(), 3);
		assert!(m.get(0));
		assert!(m.get(1));
		assert!(m.get(3));
	}

	#[test]
	fn evaluate_not_inverts_predicate_on_non_nullable_column() {
		let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
		let p = Predicate::Not(Box::new(Predicate::Gt(ColRef::from("id"), Value::Int4(2))));
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.popcount(), 2);
		assert!(m.get(0));
		assert!(m.get(1));
	}

	#[test]
	fn evaluate_is_not_none_on_nullable_column() {
		let mut nullable_ids = ColumnBuffer::int4_with_capacity(4);
		nullable_ids.push::<i32>(10);
		nullable_ids.push_none();
		nullable_ids.push::<i32>(30);
		nullable_ids.push_none();
		let id_col = ColumnChunks::single(
			Type::Int4,
			true,
			Column::from_canonical(Canonical::from_column_buffer(&nullable_ids).unwrap()),
		);
		let schema = Arc::new(vec![("id".to_string(), Type::Int4, true)]);
		let t = ColumnBlock::new(schema, vec![id_col]);

		let Selection::Mask(m) = evaluate(&t, &Predicate::IsNotNone(ColRef::from("id"))).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.popcount(), 2);
		assert!(m.get(0));
		assert!(m.get(2));
	}

	#[test]
	fn evaluate_empty_clause_lists() {
		let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
		assert!(matches!(evaluate(&t, &Predicate::In(ColRef::from("id"), vec![])).unwrap(), Selection::None_));
		assert!(matches!(evaluate(&t, &Predicate::And(vec![])).unwrap(), Selection::All));
		assert!(matches!(evaluate(&t, &Predicate::Or(vec![])).unwrap(), Selection::None_));
	}

	#[test]
	fn evaluate_unknown_column_is_error() {
		let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
		assert!(evaluate(&t, &Predicate::Eq(ColRef::from("missing"), Value::Int4(1))).is_err());
		assert!(evaluate(&t, &Predicate::IsNone(ColRef::from("missing"))).is_err());
	}
}