Skip to main content

reifydb_engine/expression/
compare.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use std::cmp::Ordering;
5
6use reifydb_core::value::column::{Column, data::ColumnData};
7use reifydb_type::{
8	error::diagnostic::Diagnostic,
9	fragment::Fragment,
10	return_error,
11	value::{
12		container::{
13			blob::BlobContainer, bool::BoolContainer, number::NumberContainer, temporal::TemporalContainer,
14			utf8::Utf8Container, uuid::UuidContainer,
15		},
16		decimal::Decimal,
17		int::Int,
18		is::{IsNumber, IsTemporal, IsUuid},
19		number::{compare::partial_cmp, promote::Promote},
20		r#type::Type,
21		uint::Uint,
22	},
23};
24
/// Expands to a single `match` over a `(left, right)` pair of `ColumnData` references that routes every
/// numeric type combination (15 left kinds x 15 right kinds) to `compare_number`.
///
/// Invocation shape: `dispatch_compare!(left, right; fragment; <extra match arms>)`.
/// The caller-supplied arms are appended after the generated numeric arms, so they are expected to end with
/// a catch-all. Every generated arm `return`s `Ok(..)` from the enclosing function or closure, and refers to
/// a type named `Op`, which therefore has to be in scope at the call site.
///
/// A macro call cannot expand to individual match arms, so the arms are collected by push-down accumulation:
/// each `@expand` step peels one left-hand operand off the work list and appends its 15 arms to the
/// accumulator; the final step wraps the accumulator in the `match`.
macro_rules! dispatch_compare {
	// Entry point: seed the work list with every numeric left-hand operand.
	//
	// Each entry is `{ (<pattern binding the container>) <binder ident> <element type> }`. The binder is
	// passed next to its pattern on purpose: both tokens are written in this same expansion, so the
	// identifier used in the arm body resolves to the one bound by the pattern under macro hygiene.
	(
		$left:expr, $right:expr;
		$fragment:expr;
		$($extra:tt)*
	) => {
		dispatch_compare!(@expand
			($left, $right) ($fragment)
			[
				{ (ColumnData::Float4(l)) l f32 }
				{ (ColumnData::Float8(l)) l f64 }
				{ (ColumnData::Int1(l)) l i8 }
				{ (ColumnData::Int2(l)) l i16 }
				{ (ColumnData::Int4(l)) l i32 }
				{ (ColumnData::Int8(l)) l i64 }
				{ (ColumnData::Int16(l)) l i128 }
				{ (ColumnData::Uint1(l)) l u8 }
				{ (ColumnData::Uint2(l)) l u16 }
				{ (ColumnData::Uint4(l)) l u32 }
				{ (ColumnData::Uint8(l)) l u64 }
				{ (ColumnData::Uint16(l)) l u128 }
				{ (ColumnData::Int { container: l, .. }) l Int }
				{ (ColumnData::Uint { container: l, .. }) l Uint }
				{ (ColumnData::Decimal { container: l, .. }) l Decimal }
			]
			{$($extra)*}
			{}
		)
	};

	// Recursive step: take the first pending left operand and emit its 15 right-hand arms.
	(@expand
		($left:expr, $right:expr) ($fragment:expr)
		[ { ($($lhs_pat:tt)*) $lhs_bind:ident $lhs_ty:ty } $($pending:tt)* ]
		{$($extra:tt)*}
		{$($arms:tt)*}
	) => {
		dispatch_compare!(@expand
			($left, $right) ($fragment)
			[$($pending)*]
			{$($extra)*}
			{
				$($arms)*
				($($lhs_pat)*, ColumnData::Float4(r)) => return Ok(compare_number::<Op, $lhs_ty, f32>($lhs_bind, r, $fragment)),
				($($lhs_pat)*, ColumnData::Float8(r)) => return Ok(compare_number::<Op, $lhs_ty, f64>($lhs_bind, r, $fragment)),
				($($lhs_pat)*, ColumnData::Int1(r)) => return Ok(compare_number::<Op, $lhs_ty, i8>($lhs_bind, r, $fragment)),
				($($lhs_pat)*, ColumnData::Int2(r)) => return Ok(compare_number::<Op, $lhs_ty, i16>($lhs_bind, r, $fragment)),
				($($lhs_pat)*, ColumnData::Int4(r)) => return Ok(compare_number::<Op, $lhs_ty, i32>($lhs_bind, r, $fragment)),
				($($lhs_pat)*, ColumnData::Int8(r)) => return Ok(compare_number::<Op, $lhs_ty, i64>($lhs_bind, r, $fragment)),
				($($lhs_pat)*, ColumnData::Int16(r)) => return Ok(compare_number::<Op, $lhs_ty, i128>($lhs_bind, r, $fragment)),
				($($lhs_pat)*, ColumnData::Uint1(r)) => return Ok(compare_number::<Op, $lhs_ty, u8>($lhs_bind, r, $fragment)),
				($($lhs_pat)*, ColumnData::Uint2(r)) => return Ok(compare_number::<Op, $lhs_ty, u16>($lhs_bind, r, $fragment)),
				($($lhs_pat)*, ColumnData::Uint4(r)) => return Ok(compare_number::<Op, $lhs_ty, u32>($lhs_bind, r, $fragment)),
				($($lhs_pat)*, ColumnData::Uint8(r)) => return Ok(compare_number::<Op, $lhs_ty, u64>($lhs_bind, r, $fragment)),
				($($lhs_pat)*, ColumnData::Uint16(r)) => return Ok(compare_number::<Op, $lhs_ty, u128>($lhs_bind, r, $fragment)),
				($($lhs_pat)*, ColumnData::Int { container: r, .. }) => return Ok(compare_number::<Op, $lhs_ty, Int>($lhs_bind, r, $fragment)),
				($($lhs_pat)*, ColumnData::Uint { container: r, .. }) => return Ok(compare_number::<Op, $lhs_ty, Uint>($lhs_bind, r, $fragment)),
				($($lhs_pat)*, ColumnData::Decimal { container: r, .. }) => return Ok(compare_number::<Op, $lhs_ty, Decimal>($lhs_bind, r, $fragment)),
			}
		)
	};

	// Final step: the work list is empty, so emit the match itself.
	(@expand
		($left:expr, $right:expr) ($fragment:expr)
		[]
		{$($extra:tt)*}
		{$($arms:tt)*}
	) => {
		match ($left, $right) {
			// Numeric x numeric (15 x 15 = 225 arms)
			$($arms)*

			// Caller-supplied arms
			$($extra)*
		}
	};
}
141
/// Strategy trait describing a single comparison operator.
///
/// The implementors below are field-less marker types and the methods take no receiver, so generic code
/// instantiated with them is monomorphized per operator.
pub(crate) trait CompareOp {
	/// Turns the (possibly absent) ordering of two values into the operator's boolean verdict.
	fn compare_ordering(ordering: Option<Ordering>) -> bool;

	/// Compares two booleans directly. The default returns `None`, meaning the operator does not define a
	/// boolean comparison.
	fn compare_bool(_l: bool, _r: bool) -> Option<bool> {
		None
	}
}

/// Marker for `==`.
pub(crate) struct Equal;
/// Marker for `!=`.
pub(crate) struct NotEqual;
/// Marker for `>`.
pub(crate) struct GreaterThan;
/// Marker for `>=`.
pub(crate) struct GreaterThanEqual;
/// Marker for `<`.
pub(crate) struct LessThan;
/// Marker for `<=`.
pub(crate) struct LessThanEqual;

impl CompareOp for Equal {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		ordering.is_some_and(Ordering::is_eq)
	}

	#[inline]
	fn compare_bool(lhs: bool, rhs: bool) -> Option<bool> {
		Some(lhs == rhs)
	}
}

impl CompareOp for NotEqual {
	/// An absent ordering counts as "not equal".
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		!ordering.is_some_and(Ordering::is_eq)
	}

	#[inline]
	fn compare_bool(lhs: bool, rhs: bool) -> Option<bool> {
		Some(lhs != rhs)
	}
}

impl CompareOp for GreaterThan {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		ordering.is_some_and(Ordering::is_gt)
	}
}

impl CompareOp for GreaterThanEqual {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		ordering.is_some_and(Ordering::is_ge)
	}
}

impl CompareOp for LessThan {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		ordering.is_some_and(Ordering::is_lt)
	}
}

impl CompareOp for LessThanEqual {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		ordering.is_some_and(Ordering::is_le)
	}
}
206
207#[inline]
208fn compare_number<Op: CompareOp, L, R>(l: &NumberContainer<L>, r: &NumberContainer<R>, fragment: Fragment) -> Column
209where
210	L: Promote<R> + IsNumber,
211	R: IsNumber,
212	<L as Promote<R>>::Output: IsNumber,
213{
214	debug_assert_eq!(l.len(), r.len());
215
216	let data: Vec<bool> =
217		l.data().iter()
218			.zip(r.data().iter())
219			.map(|(l_val, r_val)| Op::compare_ordering(partial_cmp(l_val, r_val)))
220			.collect();
221
222	Column {
223		name: Fragment::internal(fragment.text()),
224		data: ColumnData::bool(data),
225	}
226}
227
228#[inline]
229fn compare_temporal<Op: CompareOp, T>(l: &TemporalContainer<T>, r: &TemporalContainer<T>, fragment: Fragment) -> Column
230where
231	T: IsTemporal + Copy + PartialOrd,
232{
233	debug_assert_eq!(l.len(), r.len());
234
235	let data: Vec<bool> =
236		l.data().iter()
237			.zip(r.data().iter())
238			.map(|(l_val, r_val)| Op::compare_ordering(l_val.partial_cmp(r_val)))
239			.collect();
240
241	Column {
242		name: Fragment::internal(fragment.text()),
243		data: ColumnData::bool(data),
244	}
245}
246
247#[inline]
248fn compare_uuid<Op: CompareOp, T>(l: &UuidContainer<T>, r: &UuidContainer<T>, fragment: Fragment) -> Column
249where
250	T: IsUuid + PartialOrd,
251{
252	debug_assert_eq!(l.len(), r.len());
253
254	let data: Vec<bool> =
255		l.data().iter()
256			.zip(r.data().iter())
257			.map(|(l_val, r_val)| Op::compare_ordering(l_val.partial_cmp(r_val)))
258			.collect();
259
260	Column {
261		name: Fragment::internal(fragment.text()),
262		data: ColumnData::bool(data),
263	}
264}
265
266#[inline]
267fn compare_blob<Op: CompareOp>(l: &BlobContainer, r: &BlobContainer, fragment: Fragment) -> Column {
268	debug_assert_eq!(l.len(), r.len());
269
270	let data: Vec<bool> =
271		l.data().iter()
272			.zip(r.data().iter())
273			.map(|(l_val, r_val)| Op::compare_ordering(l_val.partial_cmp(r_val)))
274			.collect();
275
276	Column {
277		name: Fragment::internal(fragment.text()),
278		data: ColumnData::bool(data),
279	}
280}
281
282#[inline]
283fn compare_utf8<Op: CompareOp>(l: &Utf8Container, r: &Utf8Container, fragment: Fragment) -> Column {
284	debug_assert_eq!(l.len(), r.len());
285
286	let data: Vec<bool> =
287		l.data().iter()
288			.zip(r.data().iter())
289			.map(|(l_val, r_val)| Op::compare_ordering(l_val.partial_cmp(r_val)))
290			.collect();
291
292	Column {
293		name: Fragment::internal(fragment.text()),
294		data: ColumnData::bool(data),
295	}
296}
297
298#[inline]
299fn compare_bool<Op: CompareOp>(l: &BoolContainer, r: &BoolContainer, fragment: Fragment) -> Option<Column> {
300	debug_assert_eq!(l.len(), r.len());
301
302	let data: Vec<bool> =
303		l.data().iter()
304			.zip(r.data().iter())
305			.filter_map(|(l_val, r_val)| Op::compare_bool(l_val, r_val))
306			.collect();
307
308	if data.len() == l.len() {
309		Some(Column {
310			name: Fragment::internal(fragment.text()),
311			data: ColumnData::bool(data),
312		})
313	} else {
314		None
315	}
316}
317
318pub(crate) fn compare_columns<Op: CompareOp>(
319	left: &Column,
320	right: &Column,
321	fragment: Fragment,
322	error_fn: impl FnOnce(Fragment, Type, Type) -> Diagnostic,
323) -> crate::Result<Column> {
324	super::option::binary_op_unwrap_option(left, right, fragment.clone(), |left, right| {
325		dispatch_compare!(
326			&left.data(), &right.data();
327			fragment;
328
329			(ColumnData::Bool(l), ColumnData::Bool(r)) => {
330				if let Some(col) = compare_bool::<Op>(l, r, fragment.clone()) {
331					return Ok(col);
332				}
333				return_error!(error_fn(fragment, left.get_type(), right.get_type()))
334			}
335
336			(ColumnData::Date(l), ColumnData::Date(r)) => {
337				return Ok(compare_temporal::<Op, _>(l, r, fragment));
338			},
339			(ColumnData::DateTime(l), ColumnData::DateTime(r)) => {
340				return Ok(compare_temporal::<Op, _>(l, r, fragment));
341			},
342			(ColumnData::Time(l), ColumnData::Time(r)) => {
343				return Ok(compare_temporal::<Op, _>(l, r, fragment));
344			},
345			(ColumnData::Duration(l), ColumnData::Duration(r)) => {
346				return Ok(compare_temporal::<Op, _>(l, r, fragment));
347			},
348
349			(
350				ColumnData::Utf8 {
351					container: l,
352					..
353				},
354				ColumnData::Utf8 {
355					container: r,
356					..
357				},
358			) => {
359				return Ok(compare_utf8::<Op>(l, r, fragment));
360			},
361
362			(ColumnData::Uuid4(l), ColumnData::Uuid4(r)) => {
363				return Ok(compare_uuid::<Op, _>(l, r, fragment));
364			},
365			(ColumnData::Uuid7(l), ColumnData::Uuid7(r)) => {
366				return Ok(compare_uuid::<Op, _>(l, r, fragment));
367			},
368			(
369				ColumnData::Blob {
370					container: l,
371					..
372				},
373				ColumnData::Blob {
374					container: r,
375					..
376				},
377			) => {
378				return Ok(compare_blob::<Op>(l, r, fragment));
379			},
380
381			_ => {
382				return_error!(error_fn(fragment, left.get_type(), right.get_type()))
383			},
384		)
385	})
386}