Skip to main content

reifydb_engine/expression/
compare.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use std::cmp::Ordering;
5
6use reifydb_core::value::column::{Column, data::ColumnData};
7use reifydb_type::{
8	error::Diagnostic,
9	fragment::Fragment,
10	return_error,
11	value::{
12		container::{
13			blob::BlobContainer, bool::BoolContainer, number::NumberContainer, temporal::TemporalContainer,
14			utf8::Utf8Container, uuid::UuidContainer,
15		},
16		decimal::Decimal,
17		int::Int,
18		is::{IsNumber, IsTemporal, IsUuid},
19		number::{compare::partial_cmp, promote::Promote},
20		r#type::Type,
21		uint::Uint,
22	},
23};
24
25use super::option::binary_op_unwrap_option;
26use crate::Result;
27
/// Generates a complete match expression dispatching all numeric type pairs for comparison.
/// Uses push-down accumulation to build the cross-product of type arms.
///
/// Invocation shape: `dispatch_compare!(left_data, right_data; fragment; <extra match arms>)`.
///
/// Expansion-site requirements (the macro body names these unqualified, so they are
/// resolved where the macro is invoked):
/// * a type `Op` — the comparison operator passed as the first generic argument of every
///   `compare_number` call;
/// * a function `compare_number`;
/// * an enclosing function or closure returning a `Result`, because every generated arm
///   executes `return Ok(...)`.
///
/// `$fragment` is substituted into every generated arm and passed by value to
/// `compare_number`, so the arm that matches consumes it.
///
/// The numeric arms cover 15 × 15 = 225 pairs: the 12 fixed-width variants plus the
/// arbitrary-precision `Int`, `Uint` and `Decimal` variants on each side. The caller's
/// extra arms are emitted after all of them, so a trailing catch-all (`_`) in the extra arms
/// cannot shadow any numeric pairing and keeps the generated `match` exhaustive.
macro_rules! dispatch_compare {
	// Entry point
	// Seeds the recursion with the list of fixed-width `(ColumnData variant, element type)`
	// pairs, stashes the caller's extra arms unchanged, and starts with an empty accumulator.
	(
		$left:expr, $right:expr;
		$fragment:expr;
		$($extra:tt)*
	) => {
		dispatch_compare!(@rows
			($left, $right) ($fragment)
			[(Float4, f32) (Float8, f64) (Int1, i8) (Int2, i16) (Int4, i32) (Int8, i64) (Int16, i128) (Uint1, u8) (Uint2, u16) (Uint4, u32) (Uint8, u64) (Uint16, u128)]
			{$($extra)*}
			{}
		)
	};

	// Recursive: process one fixed-left type pair, generating all 15 right-side arms
	// Pops the head `($L, $Lt)` off the list and appends one arm per right-hand numeric
	// variant to the accumulator, then recurses on the tail. This push-down recursion is
	// used because macro_rules! cannot directly express a cross product of two lists.
	// Right-hand arbitrary-precision variants are struct-like; only `container` is bound.
	(@rows
		($left:expr, $right:expr) ($fragment:expr)
		[($L:ident, $Lt:ty) $($rest:tt)*]
		{$($extra:tt)*}
		{$($acc:tt)*}
	) => {
		dispatch_compare!(@rows
			($left, $right) ($fragment)
			[$($rest)*]
			{$($extra)*}
			{
				$($acc)*
				(ColumnData::$L(l), ColumnData::Float4(r)) => { return Ok(compare_number::<Op, $Lt, f32>(l, r, $fragment)); },
				(ColumnData::$L(l), ColumnData::Float8(r)) => { return Ok(compare_number::<Op, $Lt, f64>(l, r, $fragment)); },
				(ColumnData::$L(l), ColumnData::Int1(r)) => { return Ok(compare_number::<Op, $Lt, i8>(l, r, $fragment)); },
				(ColumnData::$L(l), ColumnData::Int2(r)) => { return Ok(compare_number::<Op, $Lt, i16>(l, r, $fragment)); },
				(ColumnData::$L(l), ColumnData::Int4(r)) => { return Ok(compare_number::<Op, $Lt, i32>(l, r, $fragment)); },
				(ColumnData::$L(l), ColumnData::Int8(r)) => { return Ok(compare_number::<Op, $Lt, i64>(l, r, $fragment)); },
				(ColumnData::$L(l), ColumnData::Int16(r)) => { return Ok(compare_number::<Op, $Lt, i128>(l, r, $fragment)); },
				(ColumnData::$L(l), ColumnData::Uint1(r)) => { return Ok(compare_number::<Op, $Lt, u8>(l, r, $fragment)); },
				(ColumnData::$L(l), ColumnData::Uint2(r)) => { return Ok(compare_number::<Op, $Lt, u16>(l, r, $fragment)); },
				(ColumnData::$L(l), ColumnData::Uint4(r)) => { return Ok(compare_number::<Op, $Lt, u32>(l, r, $fragment)); },
				(ColumnData::$L(l), ColumnData::Uint8(r)) => { return Ok(compare_number::<Op, $Lt, u64>(l, r, $fragment)); },
				(ColumnData::$L(l), ColumnData::Uint16(r)) => { return Ok(compare_number::<Op, $Lt, u128>(l, r, $fragment)); },
				(ColumnData::$L(l), ColumnData::Int { container: r, .. }) => { return Ok(compare_number::<Op, $Lt, Int>(l, r, $fragment)); },
				(ColumnData::$L(l), ColumnData::Uint { container: r, .. }) => { return Ok(compare_number::<Op, $Lt, Uint>(l, r, $fragment)); },
				(ColumnData::$L(l), ColumnData::Decimal { container: r, .. }) => { return Ok(compare_number::<Op, $Lt, Decimal>(l, r, $fragment)); },
			}
		)
	};

	// Base case: all fixed-left types processed, emit the match with arb-left arms
	// Emits the final `match`: accumulated fixed-width-left arms first, then the
	// arbitrary-precision left sides (`Int`, `Uint`, `Decimal`), then the caller's extra arms.
	(@rows
		($left:expr, $right:expr) ($fragment:expr)
		[]
		{$($extra:tt)*}
		{$($acc:tt)*}
	) => {
		match ($left, $right) {
			// Fixed × all (12 × 15 = 180 arms)
			$($acc)*

			// Int × all (15 arms)
			(ColumnData::Int { container: l, .. }, ColumnData::Float4(r)) => { return Ok(compare_number::<Op, Int, f32>(l, r, $fragment)); },
			(ColumnData::Int { container: l, .. }, ColumnData::Float8(r)) => { return Ok(compare_number::<Op, Int, f64>(l, r, $fragment)); },
			(ColumnData::Int { container: l, .. }, ColumnData::Int1(r)) => { return Ok(compare_number::<Op, Int, i8>(l, r, $fragment)); },
			(ColumnData::Int { container: l, .. }, ColumnData::Int2(r)) => { return Ok(compare_number::<Op, Int, i16>(l, r, $fragment)); },
			(ColumnData::Int { container: l, .. }, ColumnData::Int4(r)) => { return Ok(compare_number::<Op, Int, i32>(l, r, $fragment)); },
			(ColumnData::Int { container: l, .. }, ColumnData::Int8(r)) => { return Ok(compare_number::<Op, Int, i64>(l, r, $fragment)); },
			(ColumnData::Int { container: l, .. }, ColumnData::Int16(r)) => { return Ok(compare_number::<Op, Int, i128>(l, r, $fragment)); },
			(ColumnData::Int { container: l, .. }, ColumnData::Uint1(r)) => { return Ok(compare_number::<Op, Int, u8>(l, r, $fragment)); },
			(ColumnData::Int { container: l, .. }, ColumnData::Uint2(r)) => { return Ok(compare_number::<Op, Int, u16>(l, r, $fragment)); },
			(ColumnData::Int { container: l, .. }, ColumnData::Uint4(r)) => { return Ok(compare_number::<Op, Int, u32>(l, r, $fragment)); },
			(ColumnData::Int { container: l, .. }, ColumnData::Uint8(r)) => { return Ok(compare_number::<Op, Int, u64>(l, r, $fragment)); },
			(ColumnData::Int { container: l, .. }, ColumnData::Uint16(r)) => { return Ok(compare_number::<Op, Int, u128>(l, r, $fragment)); },
			(ColumnData::Int { container: l, .. }, ColumnData::Int { container: r, .. }) => { return Ok(compare_number::<Op, Int, Int>(l, r, $fragment)); },
			(ColumnData::Int { container: l, .. }, ColumnData::Uint { container: r, .. }) => { return Ok(compare_number::<Op, Int, Uint>(l, r, $fragment)); },
			(ColumnData::Int { container: l, .. }, ColumnData::Decimal { container: r, .. }) => { return Ok(compare_number::<Op, Int, Decimal>(l, r, $fragment)); },

			// Uint × all (15 arms)
			(ColumnData::Uint { container: l, .. }, ColumnData::Float4(r)) => { return Ok(compare_number::<Op, Uint, f32>(l, r, $fragment)); },
			(ColumnData::Uint { container: l, .. }, ColumnData::Float8(r)) => { return Ok(compare_number::<Op, Uint, f64>(l, r, $fragment)); },
			(ColumnData::Uint { container: l, .. }, ColumnData::Int1(r)) => { return Ok(compare_number::<Op, Uint, i8>(l, r, $fragment)); },
			(ColumnData::Uint { container: l, .. }, ColumnData::Int2(r)) => { return Ok(compare_number::<Op, Uint, i16>(l, r, $fragment)); },
			(ColumnData::Uint { container: l, .. }, ColumnData::Int4(r)) => { return Ok(compare_number::<Op, Uint, i32>(l, r, $fragment)); },
			(ColumnData::Uint { container: l, .. }, ColumnData::Int8(r)) => { return Ok(compare_number::<Op, Uint, i64>(l, r, $fragment)); },
			(ColumnData::Uint { container: l, .. }, ColumnData::Int16(r)) => { return Ok(compare_number::<Op, Uint, i128>(l, r, $fragment)); },
			(ColumnData::Uint { container: l, .. }, ColumnData::Uint1(r)) => { return Ok(compare_number::<Op, Uint, u8>(l, r, $fragment)); },
			(ColumnData::Uint { container: l, .. }, ColumnData::Uint2(r)) => { return Ok(compare_number::<Op, Uint, u16>(l, r, $fragment)); },
			(ColumnData::Uint { container: l, .. }, ColumnData::Uint4(r)) => { return Ok(compare_number::<Op, Uint, u32>(l, r, $fragment)); },
			(ColumnData::Uint { container: l, .. }, ColumnData::Uint8(r)) => { return Ok(compare_number::<Op, Uint, u64>(l, r, $fragment)); },
			(ColumnData::Uint { container: l, .. }, ColumnData::Uint16(r)) => { return Ok(compare_number::<Op, Uint, u128>(l, r, $fragment)); },
			(ColumnData::Uint { container: l, .. }, ColumnData::Int { container: r, .. }) => { return Ok(compare_number::<Op, Uint, Int>(l, r, $fragment)); },
			(ColumnData::Uint { container: l, .. }, ColumnData::Uint { container: r, .. }) => { return Ok(compare_number::<Op, Uint, Uint>(l, r, $fragment)); },
			(ColumnData::Uint { container: l, .. }, ColumnData::Decimal { container: r, .. }) => { return Ok(compare_number::<Op, Uint, Decimal>(l, r, $fragment)); },

			// Decimal × all (15 arms)
			(ColumnData::Decimal { container: l, .. }, ColumnData::Float4(r)) => { return Ok(compare_number::<Op, Decimal, f32>(l, r, $fragment)); },
			(ColumnData::Decimal { container: l, .. }, ColumnData::Float8(r)) => { return Ok(compare_number::<Op, Decimal, f64>(l, r, $fragment)); },
			(ColumnData::Decimal { container: l, .. }, ColumnData::Int1(r)) => { return Ok(compare_number::<Op, Decimal, i8>(l, r, $fragment)); },
			(ColumnData::Decimal { container: l, .. }, ColumnData::Int2(r)) => { return Ok(compare_number::<Op, Decimal, i16>(l, r, $fragment)); },
			(ColumnData::Decimal { container: l, .. }, ColumnData::Int4(r)) => { return Ok(compare_number::<Op, Decimal, i32>(l, r, $fragment)); },
			(ColumnData::Decimal { container: l, .. }, ColumnData::Int8(r)) => { return Ok(compare_number::<Op, Decimal, i64>(l, r, $fragment)); },
			(ColumnData::Decimal { container: l, .. }, ColumnData::Int16(r)) => { return Ok(compare_number::<Op, Decimal, i128>(l, r, $fragment)); },
			(ColumnData::Decimal { container: l, .. }, ColumnData::Uint1(r)) => { return Ok(compare_number::<Op, Decimal, u8>(l, r, $fragment)); },
			(ColumnData::Decimal { container: l, .. }, ColumnData::Uint2(r)) => { return Ok(compare_number::<Op, Decimal, u16>(l, r, $fragment)); },
			(ColumnData::Decimal { container: l, .. }, ColumnData::Uint4(r)) => { return Ok(compare_number::<Op, Decimal, u32>(l, r, $fragment)); },
			(ColumnData::Decimal { container: l, .. }, ColumnData::Uint8(r)) => { return Ok(compare_number::<Op, Decimal, u64>(l, r, $fragment)); },
			(ColumnData::Decimal { container: l, .. }, ColumnData::Uint16(r)) => { return Ok(compare_number::<Op, Decimal, u128>(l, r, $fragment)); },
			(ColumnData::Decimal { container: l, .. }, ColumnData::Int { container: r, .. }) => { return Ok(compare_number::<Op, Decimal, Int>(l, r, $fragment)); },
			(ColumnData::Decimal { container: l, .. }, ColumnData::Uint { container: r, .. }) => { return Ok(compare_number::<Op, Decimal, Uint>(l, r, $fragment)); },
			(ColumnData::Decimal { container: l, .. }, ColumnData::Decimal { container: r, .. }) => { return Ok(compare_number::<Op, Decimal, Decimal>(l, r, $fragment)); },

			// Additional arms
			// Emitted last so a catch-all among them cannot make the numeric arms unreachable.
			$($extra)*
		}
	};
}
144
/// A comparison operator chosen at compile time.
///
/// Operators are zero-sized marker types. The column helpers in this module are generic
/// over `Op: CompareOp`, so each operator gets its own monomorphized loop and selecting
/// the operator costs nothing per row.
pub(crate) trait CompareOp {
	/// Turns the outcome of a partial comparison into this operator's verdict.
	///
	/// `None` means the pair has no defined order (with std `PartialOrd` on floats, that
	/// is the NaN case).
	fn compare_ordering(ordering: Option<Ordering>) -> bool;

	/// Applies this operator to two booleans.
	///
	/// The default answers `None`, meaning "not defined on booleans"; only `Equal` and
	/// `NotEqual` override it.
	fn compare_bool(_lhs: bool, _rhs: bool) -> Option<bool> {
		None
	}
}

/// Holds when both sides compare equal.
pub(crate) struct Equal;

impl CompareOp for Equal {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		ordering.is_some_and(Ordering::is_eq)
	}

	#[inline]
	fn compare_bool(lhs: bool, rhs: bool) -> Option<bool> {
		Some(lhs == rhs)
	}
}

/// Holds unless both sides compare equal; an unordered pair therefore counts as "not equal".
pub(crate) struct NotEqual;

impl CompareOp for NotEqual {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		!ordering.is_some_and(Ordering::is_eq)
	}

	#[inline]
	fn compare_bool(lhs: bool, rhs: bool) -> Option<bool> {
		Some(lhs != rhs)
	}
}

/// Holds when the left side orders strictly after the right; unordered pairs never hold.
pub(crate) struct GreaterThan;

impl CompareOp for GreaterThan {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		ordering.is_some_and(Ordering::is_gt)
	}
}

/// Holds when the left side orders after or equal to the right; unordered pairs never hold.
pub(crate) struct GreaterThanEqual;

impl CompareOp for GreaterThanEqual {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		ordering.is_some_and(Ordering::is_ge)
	}
}

/// Holds when the left side orders strictly before the right; unordered pairs never hold.
pub(crate) struct LessThan;

impl CompareOp for LessThan {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		ordering.is_some_and(Ordering::is_lt)
	}
}

/// Holds when the left side orders before or equal to the right; unordered pairs never hold.
pub(crate) struct LessThanEqual;

impl CompareOp for LessThanEqual {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		ordering.is_some_and(Ordering::is_le)
	}
}
209
210#[inline]
211fn compare_number<Op: CompareOp, L, R>(l: &NumberContainer<L>, r: &NumberContainer<R>, fragment: Fragment) -> Column
212where
213	L: Promote<R> + IsNumber,
214	R: IsNumber,
215	<L as Promote<R>>::Output: IsNumber,
216{
217	debug_assert_eq!(l.len(), r.len());
218
219	let data: Vec<bool> =
220		l.data().iter()
221			.zip(r.data().iter())
222			.map(|(l_val, r_val)| Op::compare_ordering(partial_cmp(l_val, r_val)))
223			.collect();
224
225	Column {
226		name: Fragment::internal(fragment.text()),
227		data: ColumnData::bool(data),
228	}
229}
230
231#[inline]
232fn compare_temporal<Op: CompareOp, T>(l: &TemporalContainer<T>, r: &TemporalContainer<T>, fragment: Fragment) -> Column
233where
234	T: IsTemporal + Copy + PartialOrd,
235{
236	debug_assert_eq!(l.len(), r.len());
237
238	let data: Vec<bool> =
239		l.data().iter()
240			.zip(r.data().iter())
241			.map(|(l_val, r_val)| Op::compare_ordering(l_val.partial_cmp(r_val)))
242			.collect();
243
244	Column {
245		name: Fragment::internal(fragment.text()),
246		data: ColumnData::bool(data),
247	}
248}
249
250#[inline]
251fn compare_uuid<Op: CompareOp, T>(l: &UuidContainer<T>, r: &UuidContainer<T>, fragment: Fragment) -> Column
252where
253	T: IsUuid + PartialOrd,
254{
255	debug_assert_eq!(l.len(), r.len());
256
257	let data: Vec<bool> =
258		l.data().iter()
259			.zip(r.data().iter())
260			.map(|(l_val, r_val)| Op::compare_ordering(l_val.partial_cmp(r_val)))
261			.collect();
262
263	Column {
264		name: Fragment::internal(fragment.text()),
265		data: ColumnData::bool(data),
266	}
267}
268
269#[inline]
270fn compare_blob<Op: CompareOp>(l: &BlobContainer, r: &BlobContainer, fragment: Fragment) -> Column {
271	debug_assert_eq!(l.len(), r.len());
272
273	let data: Vec<bool> =
274		l.data().iter()
275			.zip(r.data().iter())
276			.map(|(l_val, r_val)| Op::compare_ordering(l_val.partial_cmp(r_val)))
277			.collect();
278
279	Column {
280		name: Fragment::internal(fragment.text()),
281		data: ColumnData::bool(data),
282	}
283}
284
285#[inline]
286fn compare_utf8<Op: CompareOp>(l: &Utf8Container, r: &Utf8Container, fragment: Fragment) -> Column {
287	debug_assert_eq!(l.len(), r.len());
288
289	let data: Vec<bool> =
290		l.data().iter()
291			.zip(r.data().iter())
292			.map(|(l_val, r_val)| Op::compare_ordering(l_val.partial_cmp(r_val)))
293			.collect();
294
295	Column {
296		name: Fragment::internal(fragment.text()),
297		data: ColumnData::bool(data),
298	}
299}
300
301#[inline]
302fn compare_bool<Op: CompareOp>(l: &BoolContainer, r: &BoolContainer, fragment: Fragment) -> Option<Column> {
303	debug_assert_eq!(l.len(), r.len());
304
305	let data: Vec<bool> =
306		l.data().iter()
307			.zip(r.data().iter())
308			.filter_map(|(l_val, r_val)| Op::compare_bool(l_val, r_val))
309			.collect();
310
311	if data.len() == l.len() {
312		Some(Column {
313			name: Fragment::internal(fragment.text()),
314			data: ColumnData::bool(data),
315		})
316	} else {
317		None
318	}
319}
320
321pub(crate) fn compare_columns<Op: CompareOp>(
322	left: &Column,
323	right: &Column,
324	fragment: Fragment,
325	error_fn: impl FnOnce(Fragment, Type, Type) -> Diagnostic,
326) -> Result<Column> {
327	binary_op_unwrap_option(left, right, fragment.clone(), |left, right| {
328		dispatch_compare!(
329			&left.data(), &right.data();
330			fragment;
331
332			(ColumnData::Bool(l), ColumnData::Bool(r)) => {
333				if let Some(col) = compare_bool::<Op>(l, r, fragment.clone()) {
334					return Ok(col);
335				}
336				return_error!(error_fn(fragment, left.get_type(), right.get_type()))
337			}
338
339			(ColumnData::Date(l), ColumnData::Date(r)) => {
340				return Ok(compare_temporal::<Op, _>(l, r, fragment));
341			},
342			(ColumnData::DateTime(l), ColumnData::DateTime(r)) => {
343				return Ok(compare_temporal::<Op, _>(l, r, fragment));
344			},
345			(ColumnData::Time(l), ColumnData::Time(r)) => {
346				return Ok(compare_temporal::<Op, _>(l, r, fragment));
347			},
348			(ColumnData::Duration(l), ColumnData::Duration(r)) => {
349				return Ok(compare_temporal::<Op, _>(l, r, fragment));
350			},
351
352			(
353				ColumnData::Utf8 {
354					container: l,
355					..
356				},
357				ColumnData::Utf8 {
358					container: r,
359					..
360				},
361			) => {
362				return Ok(compare_utf8::<Op>(l, r, fragment));
363			},
364
365			(ColumnData::Uuid4(l), ColumnData::Uuid4(r)) => {
366				return Ok(compare_uuid::<Op, _>(l, r, fragment));
367			},
368			(ColumnData::Uuid7(l), ColumnData::Uuid7(r)) => {
369				return Ok(compare_uuid::<Op, _>(l, r, fragment));
370			},
371			(
372				ColumnData::Blob {
373					container: l,
374					..
375				},
376				ColumnData::Blob {
377					container: r,
378					..
379				},
380			) => {
381				return Ok(compare_blob::<Op>(l, r, fragment));
382			},
383
384			_ => {
385				return_error!(error_fn(fragment, left.get_type(), right.get_type()))
386			},
387		)
388	})
389}