Skip to main content

reifydb_column/
predicate.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::{buffer::ColumnBuffer, data::Column, mask::RowMask};
5use reifydb_type::{Result, value::Value};
6
7use crate::{
8	compute::{self, CompareOp},
9	error::ColumnError,
10	selection::Selection,
11	snapshot::{ColumnBlock, ColumnChunks},
12};
13
/// A reference to a column by name.
///
/// When a [`Predicate`] is evaluated, the wrapped string is looked up in the
/// block's schema via `column_by_name`; an unknown name makes evaluation fail.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ColRef(pub String);

impl From<String> for ColRef {
	/// Takes ownership of an already-owned column name.
	fn from(name: String) -> Self {
		ColRef(name)
	}
}

impl From<&str> for ColRef {
	/// Copies a borrowed column name into a new owned `ColRef`.
	fn from(name: &str) -> Self {
		ColRef(String::from(name))
	}
}
28
29#[derive(Clone, Debug)]
30pub enum Predicate {
31	Eq(ColRef, Value),
32	Ne(ColRef, Value),
33	Lt(ColRef, Value),
34	LtEq(ColRef, Value),
35	Gt(ColRef, Value),
36	GtEq(ColRef, Value),
37	In(ColRef, Vec<Value>),
38	IsNone(ColRef),
39	IsNotNone(ColRef),
40	And(Vec<Predicate>),
41	Or(Vec<Predicate>),
42	Not(Box<Predicate>),
43}
44
45pub fn evaluate(block: &ColumnBlock, predicate: &Predicate) -> Result<Selection> {
46	let len = block.len();
47	let mask = evaluate_mask(block, predicate, len)?;
48	Ok(mask_to_selection(mask))
49}
50
51fn evaluate_mask(block: &ColumnBlock, predicate: &Predicate, len: usize) -> Result<RowMask> {
52	match predicate {
53		Predicate::Eq(col, v) => compare_mask(block, col, v, CompareOp::Eq),
54		Predicate::Ne(col, v) => compare_mask(block, col, v, CompareOp::Ne),
55		Predicate::Lt(col, v) => compare_mask(block, col, v, CompareOp::Lt),
56		Predicate::LtEq(col, v) => compare_mask(block, col, v, CompareOp::LtEq),
57		Predicate::Gt(col, v) => compare_mask(block, col, v, CompareOp::Gt),
58		Predicate::GtEq(col, v) => compare_mask(block, col, v, CompareOp::GtEq),
59		Predicate::In(col, values) => {
60			let mut acc = RowMask::none_set(len);
61			for v in values {
62				acc = acc.or(&compare_mask(block, col, v, CompareOp::Eq)?);
63			}
64			Ok(acc)
65		}
66		Predicate::IsNone(col) => Ok(is_none_mask(column(block, col)?)),
67		Predicate::IsNotNone(col) => Ok(is_none_mask(column(block, col)?).not()),
68		Predicate::And(clauses) => {
69			let mut acc = RowMask::all_set(len);
70			for c in clauses {
71				acc = acc.and(&evaluate_mask(block, c, len)?);
72			}
73			Ok(acc)
74		}
75		Predicate::Or(clauses) => {
76			let mut acc = RowMask::none_set(len);
77			for c in clauses {
78				acc = acc.or(&evaluate_mask(block, c, len)?);
79			}
80			Ok(acc)
81		}
82		Predicate::Not(inner) => Ok(evaluate_mask(block, inner, len)?.not()),
83	}
84}
85
86fn compare_mask(block: &ColumnBlock, col: &ColRef, rhs: &Value, op: CompareOp) -> Result<RowMask> {
87	let ch = column(block, col)?;
88	if ch.chunks.is_empty() {
89		return Ok(RowMask::none_set(0));
90	}
91	let mut parts = Vec::with_capacity(ch.chunks.len());
92	for chunk in &ch.chunks {
93		let result = compute::compare(chunk, rhs, op)?;
94		parts.push(bool_array_to_mask(&result)?);
95	}
96	Ok(RowMask::concat(&parts))
97}
98
99fn is_none_mask(ch: &ColumnChunks) -> RowMask {
100	let total = ch.len();
101	let mut mask = RowMask::none_set(total);
102	let mut row_offset = 0;
103	for chunk in &ch.chunks {
104		if let Some(nones) = chunk.nones() {
105			for i in 0..chunk.len() {
106				if nones.is_none(i) {
107					mask.set(row_offset + i, true);
108				}
109			}
110		}
111		row_offset += chunk.len();
112	}
113	mask
114}
115
116fn column<'a>(block: &'a ColumnBlock, col: &ColRef) -> Result<&'a ColumnChunks> {
117	block.column_by_name(&col.0).map(|(_, ch)| ch).ok_or_else(|| {
118		ColumnError::ColumnNotInSchema {
119			operation: "predicate::evaluate",
120			name: col.0.clone(),
121		}
122		.into()
123	})
124}
125
126fn bool_array_to_mask(array: &Column) -> Result<RowMask> {
127	let canon = array.to_canonical()?;
128	if !matches!(canon.buffer, ColumnBuffer::Bool(_)) {
129		return Err(ColumnError::PredicateCompareNotBool.into());
130	}
131	let len = canon.len();
132	let mut mask = RowMask::none_set(len);
133	let nones = canon.nones.as_ref();
134	for i in 0..len {
135		let is_true = matches!(canon.buffer.get_value(i), Value::Boolean(true));
136		if is_true && !nones.map(|n| n.is_none(i)).unwrap_or(false) {
137			mask.set(i, true);
138		}
139	}
140	Ok(mask)
141}
142
143fn mask_to_selection(mask: RowMask) -> Selection {
144	let kept = mask.popcount();
145	if kept == 0 {
146		Selection::None_
147	} else if kept == mask.len() {
148		Selection::All
149	} else {
150		Selection::Mask(mask)
151	}
152}
153
#[cfg(test)]
mod tests {
	use std::sync::Arc;

	use reifydb_core::value::column::{
		buffer::ColumnBuffer,
		data::{Column, canonical::Canonical},
	};
	use reifydb_type::value::r#type::Type;

	use super::*;

	/// Five-row block with a non-nullable `id: Int4` and `flag: Boolean` column.
	fn mkblock(rows: [(i32, bool); 5]) -> ColumnBlock {
		let ids = ColumnBuffer::int4(rows.map(|(v, _)| v).to_vec());
		let flags = ColumnBuffer::bool(rows.map(|(_, v)| v).to_vec());
		let id_col = ColumnChunks::single(
			Type::Int4,
			false,
			Column::from_canonical(Canonical::from_column_buffer(&ids).unwrap()),
		);
		let flag_col = ColumnChunks::single(
			Type::Boolean,
			false,
			Column::from_canonical(Canonical::from_column_buffer(&flags).unwrap()),
		);
		let schema = Arc::new(vec![
			("id".to_string(), Type::Int4, false),
			("flag".to_string(), Type::Boolean, false),
		]);
		ColumnBlock::new(schema, vec![id_col, flag_col])
	}

	/// Single nullable `id` column: [10, none, 30, none].
	fn nullable_id_block() -> ColumnBlock {
		let mut nullable_ids = ColumnBuffer::int4_with_capacity(4);
		nullable_ids.push::<i32>(10);
		nullable_ids.push_none();
		nullable_ids.push::<i32>(30);
		nullable_ids.push_none();
		let id_col = ColumnChunks::single(
			Type::Int4,
			true,
			Column::from_canonical(Canonical::from_column_buffer(&nullable_ids).unwrap()),
		);
		let schema = Arc::new(vec![("id".to_string(), Type::Int4, true)]);
		ColumnBlock::new(schema, vec![id_col])
	}

	#[test]
	fn evaluate_eq_produces_mask() {
		let t = mkblock([(1, true), (2, false), (3, true), (2, true), (5, false)]);
		let p = Predicate::Eq(ColRef::from("id"), Value::Int4(2));
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.popcount(), 2);
		assert!(m.get(1));
		assert!(m.get(3));
	}

	#[test]
	fn evaluate_ne_excludes_matching_rows() {
		let t = mkblock([(1, true), (2, false), (3, true), (2, true), (5, false)]);
		let p = Predicate::Ne(ColRef::from("id"), Value::Int4(2));
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.popcount(), 3);
		assert!(m.get(0));
		assert!(m.get(2));
		assert!(m.get(4));
	}

	#[test]
	fn evaluate_all_collapses_to_selection_all() {
		let t = mkblock([(1, true), (2, true), (3, true), (4, true), (5, true)]);
		let p = Predicate::GtEq(ColRef::from("id"), Value::Int4(0));
		assert!(matches!(evaluate(&t, &p).unwrap(), Selection::All));
	}

	#[test]
	fn evaluate_none_collapses_to_selection_none() {
		let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
		let p = Predicate::Lt(ColRef::from("id"), Value::Int4(0));
		assert!(matches!(evaluate(&t, &p).unwrap(), Selection::None_));
	}

	#[test]
	fn evaluate_and_combines_with_intersection() {
		let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
		let p = Predicate::And(vec![
			Predicate::Gt(ColRef::from("id"), Value::Int4(1)),
			Predicate::Eq(ColRef::from("flag"), Value::Boolean(true)),
		]);
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.popcount(), 2);
		assert!(m.get(2));
		assert!(m.get(4));
	}

	#[test]
	fn evaluate_or_combines_with_union() {
		let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
		// id < 2 → row 0; id == 4 → row 3. Union: rows 0, 3.
		let p = Predicate::Or(vec![
			Predicate::Lt(ColRef::from("id"), Value::Int4(2)),
			Predicate::Eq(ColRef::from("id"), Value::Int4(4)),
		]);
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.popcount(), 2);
		assert!(m.get(0));
		assert!(m.get(3));
	}

	#[test]
	fn evaluate_not_inverts_inner_predicate() {
		let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
		// flag == true → rows 0, 2, 4; negated → rows 1, 3.
		let p = Predicate::Not(Box::new(Predicate::Eq(ColRef::from("flag"), Value::Boolean(true))));
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.popcount(), 2);
		assert!(m.get(1));
		assert!(m.get(3));
	}

	#[test]
	fn evaluate_empty_combinators_use_identity_masks() {
		let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
		assert!(matches!(evaluate(&t, &Predicate::And(vec![])).unwrap(), Selection::All));
		assert!(matches!(evaluate(&t, &Predicate::Or(vec![])).unwrap(), Selection::None_));
		assert!(matches!(
			evaluate(&t, &Predicate::In(ColRef::from("id"), vec![])).unwrap(),
			Selection::None_
		));
	}

	#[test]
	fn evaluate_unknown_column_is_an_error() {
		let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
		assert!(evaluate(&t, &Predicate::Eq(ColRef::from("missing"), Value::Int4(1))).is_err());
		assert!(evaluate(&t, &Predicate::IsNone(ColRef::from("missing"))).is_err());
	}

	#[test]
	fn evaluate_in_matches_any_value() {
		let t = mkblock([(1, true), (2, false), (3, true), (4, false), (5, true)]);
		let p = Predicate::In(ColRef::from("id"), vec![Value::Int4(2), Value::Int4(5)]);
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.popcount(), 2);
		assert!(m.get(1));
		assert!(m.get(4));
	}

	#[test]
	fn evaluate_is_none_on_nullable_column() {
		let t = nullable_id_block();

		let Selection::Mask(m) = evaluate(&t, &Predicate::IsNone(ColRef::from("id"))).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.popcount(), 2);
		assert!(m.get(1));
		assert!(m.get(3));
	}

	#[test]
	fn evaluate_is_not_none_on_nullable_column() {
		let t = nullable_id_block();

		let Selection::Mask(m) = evaluate(&t, &Predicate::IsNotNone(ColRef::from("id"))).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.popcount(), 2);
		assert!(m.get(0));
		assert!(m.get(2));
	}

	fn int4_chunked(parts: &[&[i32]]) -> ColumnChunks {
		let chunks = parts
			.iter()
			.map(|p| {
				Column::from_canonical(
					Canonical::from_column_buffer(&ColumnBuffer::int4(p.to_vec())).unwrap(),
				)
			})
			.collect();
		ColumnChunks::new(Type::Int4, false, chunks)
	}

	fn mkblock_chunked(id_parts: &[&[i32]]) -> ColumnBlock {
		let id_col = int4_chunked(id_parts);
		let schema = Arc::new(vec![("id".to_string(), Type::Int4, false)]);
		ColumnBlock::new(schema, vec![id_col])
	}

	#[test]
	fn evaluate_eq_over_multi_chunk_column() {
		// id chunks: [1, 2, 3] | [2, 4, 2] | [5, 2]. Looking for id == 2.
		let t = mkblock_chunked(&[&[1, 2, 3], &[2, 4, 2], &[5, 2]]);
		let p = Predicate::Eq(ColRef::from("id"), Value::Int4(2));
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.len(), 8);
		assert_eq!(m.popcount(), 4);
		assert!(m.get(1));
		assert!(m.get(3));
		assert!(m.get(5));
		assert!(m.get(7));
	}

	#[test]
	fn evaluate_and_or_across_multi_chunk_columns() {
		// Both columns are 2 chunks of length 3. AND/OR must align across chunk boundaries.
		let id_col = int4_chunked(&[&[1, 2, 3], &[4, 5, 6]]);
		let other_col = int4_chunked(&[&[10, 20, 10], &[20, 10, 20]]);
		let schema =
			Arc::new(vec![("id".to_string(), Type::Int4, false), ("other".to_string(), Type::Int4, false)]);
		let t = ColumnBlock::new(schema, vec![id_col, other_col]);

		let p = Predicate::And(vec![
			Predicate::Gt(ColRef::from("id"), Value::Int4(2)),
			Predicate::Eq(ColRef::from("other"), Value::Int4(20)),
		]);
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		// id > 2 → rows 2,3,4,5; other == 20 → rows 1,3,5. Intersection: rows 3, 5.
		assert_eq!(m.len(), 6);
		assert_eq!(m.popcount(), 2);
		assert!(m.get(3));
		assert!(m.get(5));
	}

	#[test]
	fn evaluate_is_none_across_multi_chunk_nullable() {
		// Two nullable chunks; nones at row 1 of each chunk → block rows 1 and 4.
		let mut a = ColumnBuffer::int4_with_capacity(3);
		a.push::<i32>(10);
		a.push_none();
		a.push::<i32>(30);
		let mut b = ColumnBuffer::int4_with_capacity(3);
		b.push::<i32>(40);
		b.push_none();
		b.push::<i32>(60);
		let chunks = vec![
			Column::from_canonical(Canonical::from_column_buffer(&a).unwrap()),
			Column::from_canonical(Canonical::from_column_buffer(&b).unwrap()),
		];
		let id_col = ColumnChunks::new(Type::Int4, true, chunks);
		let schema = Arc::new(vec![("id".to_string(), Type::Int4, true)]);
		let t = ColumnBlock::new(schema, vec![id_col]);

		let Selection::Mask(m) = evaluate(&t, &Predicate::IsNone(ColRef::from("id"))).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.len(), 6);
		assert_eq!(m.popcount(), 2);
		assert!(m.get(1));
		assert!(m.get(4));
	}

	#[test]
	fn evaluate_in_across_multi_chunk_column() {
		let t = mkblock_chunked(&[&[1, 2], &[3, 4], &[5, 6]]);
		let p = Predicate::In(ColRef::from("id"), vec![Value::Int4(2), Value::Int4(5)]);
		let Selection::Mask(m) = evaluate(&t, &p).unwrap() else {
			panic!("expected Mask selection");
		};
		assert_eq!(m.len(), 6);
		assert_eq!(m.popcount(), 2);
		assert!(m.get(1));
		assert!(m.get(4));
	}
}