Skip to main content

reifydb_engine/expression/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::cmp::Ordering;
5
6use reifydb_core::value::column::{ColumnWithName, buffer::ColumnBuffer};
7use reifydb_type::{
8	error::Diagnostic,
9	fragment::Fragment,
10	return_error,
11	value::{
12		container::{
13			blob::BlobContainer, bool::BoolContainer, identity_id::IdentityIdContainer,
14			number::NumberContainer, temporal::TemporalContainer, utf8::Utf8Container, uuid::UuidContainer,
15		},
16		decimal::Decimal,
17		int::Int,
18		is::{IsNumber, IsTemporal, IsUuid},
19		number::{compare::partial_cmp, promote::Promote},
20		r#type::Type,
21		uint::Uint,
22	},
23};
24
25use super::option::binary_op_unwrap_option;
26use crate::Result;
27
/// Expands to a single `match ($left, $right)` over two `ColumnBuffer`s that handles every
/// numeric × numeric pairing by calling `compare_number::<Op, L, R>`, followed verbatim by the
/// caller-supplied `$extra` arms (which therefore must supply the final catch-all).
///
/// The numeric variants are the twelve fixed-width tuple variants (`Float4` … `Uint16`) plus the
/// struct variants `Int`, `Uint` and `Decimal`, giving 15 × 15 = 225 generated arms. The cross
/// product is built by push-down accumulation: each recursive step removes one left-hand variant
/// from a work list and appends its 15 arms to the accumulator. Tuple and struct left-hand
/// variants are consumed by separate stages because their patterns differ in shape; within each
/// stage the bindings (`l`, `r`) and the arm bodies that use them are written in the same rule so
/// that macro hygiene lets them refer to each other.
///
/// Every generated arm `return`s `Ok(..)`, so the expansion must sit inside a function or closure
/// whose `Result` has `ColumnWithName` as its `Ok` type, with a type named `Op: CompareOp` in scope.
macro_rules! dispatch_compare {
	// Entry point: seed both work lists and start with an empty accumulator.
	(
		$left:expr, $right:expr;
		$fragment:expr;
		$($extra:tt)*
	) => {
		dispatch_compare!(@tuple_rows
			($left, $right) ($fragment)
			[
				(Float4, f32) (Float8, f64)
				(Int1, i8) (Int2, i16) (Int4, i32) (Int8, i64) (Int16, i128)
				(Uint1, u8) (Uint2, u16) (Uint4, u32) (Uint8, u64) (Uint16, u128)
			]
			[(Int, Int) (Uint, Uint) (Decimal, Decimal)]
			{$($extra)*}
			{}
		)
	};

	// Tuple stage: the left operand is matched as `ColumnBuffer::$lv(l)`.
	(@tuple_rows
		($left:expr, $right:expr) ($fragment:expr)
		[($lv:ident, $lt:ty) $($tuple_rest:tt)*]
		[$($struct_rows:tt)*]
		{$($extra:tt)*}
		{$($arms:tt)*}
	) => {
		dispatch_compare!(@tuple_rows
			($left, $right) ($fragment)
			[$($tuple_rest)*]
			[$($struct_rows)*]
			{$($extra)*}
			{
				$($arms)*
				(ColumnBuffer::$lv(l), ColumnBuffer::Float4(r)) => return Ok(compare_number::<Op, $lt, f32>(l, r, $fragment)),
				(ColumnBuffer::$lv(l), ColumnBuffer::Float8(r)) => return Ok(compare_number::<Op, $lt, f64>(l, r, $fragment)),
				(ColumnBuffer::$lv(l), ColumnBuffer::Int1(r)) => return Ok(compare_number::<Op, $lt, i8>(l, r, $fragment)),
				(ColumnBuffer::$lv(l), ColumnBuffer::Int2(r)) => return Ok(compare_number::<Op, $lt, i16>(l, r, $fragment)),
				(ColumnBuffer::$lv(l), ColumnBuffer::Int4(r)) => return Ok(compare_number::<Op, $lt, i32>(l, r, $fragment)),
				(ColumnBuffer::$lv(l), ColumnBuffer::Int8(r)) => return Ok(compare_number::<Op, $lt, i64>(l, r, $fragment)),
				(ColumnBuffer::$lv(l), ColumnBuffer::Int16(r)) => return Ok(compare_number::<Op, $lt, i128>(l, r, $fragment)),
				(ColumnBuffer::$lv(l), ColumnBuffer::Uint1(r)) => return Ok(compare_number::<Op, $lt, u8>(l, r, $fragment)),
				(ColumnBuffer::$lv(l), ColumnBuffer::Uint2(r)) => return Ok(compare_number::<Op, $lt, u16>(l, r, $fragment)),
				(ColumnBuffer::$lv(l), ColumnBuffer::Uint4(r)) => return Ok(compare_number::<Op, $lt, u32>(l, r, $fragment)),
				(ColumnBuffer::$lv(l), ColumnBuffer::Uint8(r)) => return Ok(compare_number::<Op, $lt, u64>(l, r, $fragment)),
				(ColumnBuffer::$lv(l), ColumnBuffer::Uint16(r)) => return Ok(compare_number::<Op, $lt, u128>(l, r, $fragment)),
				(ColumnBuffer::$lv(l), ColumnBuffer::Int { container: r, .. }) => return Ok(compare_number::<Op, $lt, Int>(l, r, $fragment)),
				(ColumnBuffer::$lv(l), ColumnBuffer::Uint { container: r, .. }) => return Ok(compare_number::<Op, $lt, Uint>(l, r, $fragment)),
				(ColumnBuffer::$lv(l), ColumnBuffer::Decimal { container: r, .. }) => return Ok(compare_number::<Op, $lt, Decimal>(l, r, $fragment)),
			}
		)
	};

	// Tuple rows exhausted: hand the accumulator over to the struct stage.
	(@tuple_rows
		($left:expr, $right:expr) ($fragment:expr)
		[]
		[$($struct_rows:tt)*]
		{$($extra:tt)*}
		{$($arms:tt)*}
	) => {
		dispatch_compare!(@struct_rows
			($left, $right) ($fragment)
			[$($struct_rows)*]
			{$($extra)*}
			{$($arms)*}
		)
	};

	// Struct stage: the left operand is matched as `ColumnBuffer::$lv { container: l, .. }`.
	(@struct_rows
		($left:expr, $right:expr) ($fragment:expr)
		[($lv:ident, $lt:ty) $($struct_rest:tt)*]
		{$($extra:tt)*}
		{$($arms:tt)*}
	) => {
		dispatch_compare!(@struct_rows
			($left, $right) ($fragment)
			[$($struct_rest)*]
			{$($extra)*}
			{
				$($arms)*
				(ColumnBuffer::$lv { container: l, .. }, ColumnBuffer::Float4(r)) => return Ok(compare_number::<Op, $lt, f32>(l, r, $fragment)),
				(ColumnBuffer::$lv { container: l, .. }, ColumnBuffer::Float8(r)) => return Ok(compare_number::<Op, $lt, f64>(l, r, $fragment)),
				(ColumnBuffer::$lv { container: l, .. }, ColumnBuffer::Int1(r)) => return Ok(compare_number::<Op, $lt, i8>(l, r, $fragment)),
				(ColumnBuffer::$lv { container: l, .. }, ColumnBuffer::Int2(r)) => return Ok(compare_number::<Op, $lt, i16>(l, r, $fragment)),
				(ColumnBuffer::$lv { container: l, .. }, ColumnBuffer::Int4(r)) => return Ok(compare_number::<Op, $lt, i32>(l, r, $fragment)),
				(ColumnBuffer::$lv { container: l, .. }, ColumnBuffer::Int8(r)) => return Ok(compare_number::<Op, $lt, i64>(l, r, $fragment)),
				(ColumnBuffer::$lv { container: l, .. }, ColumnBuffer::Int16(r)) => return Ok(compare_number::<Op, $lt, i128>(l, r, $fragment)),
				(ColumnBuffer::$lv { container: l, .. }, ColumnBuffer::Uint1(r)) => return Ok(compare_number::<Op, $lt, u8>(l, r, $fragment)),
				(ColumnBuffer::$lv { container: l, .. }, ColumnBuffer::Uint2(r)) => return Ok(compare_number::<Op, $lt, u16>(l, r, $fragment)),
				(ColumnBuffer::$lv { container: l, .. }, ColumnBuffer::Uint4(r)) => return Ok(compare_number::<Op, $lt, u32>(l, r, $fragment)),
				(ColumnBuffer::$lv { container: l, .. }, ColumnBuffer::Uint8(r)) => return Ok(compare_number::<Op, $lt, u64>(l, r, $fragment)),
				(ColumnBuffer::$lv { container: l, .. }, ColumnBuffer::Uint16(r)) => return Ok(compare_number::<Op, $lt, u128>(l, r, $fragment)),
				(ColumnBuffer::$lv { container: l, .. }, ColumnBuffer::Int { container: r, .. }) => return Ok(compare_number::<Op, $lt, Int>(l, r, $fragment)),
				(ColumnBuffer::$lv { container: l, .. }, ColumnBuffer::Uint { container: r, .. }) => return Ok(compare_number::<Op, $lt, Uint>(l, r, $fragment)),
				(ColumnBuffer::$lv { container: l, .. }, ColumnBuffer::Decimal { container: r, .. }) => return Ok(compare_number::<Op, $lt, Decimal>(l, r, $fragment)),
			}
		)
	};

	// Both work lists exhausted: emit the match with all numeric arms, then the caller's arms.
	(@struct_rows
		($left:expr, $right:expr) ($fragment:expr)
		[]
		{$($extra:tt)*}
		{$($arms:tt)*}
	) => {
		match ($left, $right) {
			$($arms)*
			$($extra)*
		}
	};
}
144
/// A comparison operator chosen at compile time.
///
/// Each operator is a zero-sized marker type, and the `compare_*` helpers are generic over
/// `Op: CompareOp`. As a result, every operator gets its own monomorphized loop instead of
/// branching on an operator value per row.
pub(crate) trait CompareOp {
	/// Maps the result of a `partial_cmp` call to this operator's boolean outcome.
	///
	/// `None` stands for operands that `partial_cmp` could not order.
	fn compare_ordering(ordering: Option<Ordering>) -> bool;

	/// Applies the operator to two booleans, or returns `None` when the operator has no boolean
	/// meaning. The default rejects booleans; only the equality operators override it.
	fn compare_bool(_lhs: bool, _rhs: bool) -> Option<bool> {
		None
	}
}

/// `=`: true only when the operands compare equal.
pub(crate) struct Equal;

impl CompareOp for Equal {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		matches!(ordering, Some(Ordering::Equal))
	}

	#[inline]
	fn compare_bool(lhs: bool, rhs: bool) -> Option<bool> {
		Some(lhs == rhs)
	}
}

/// `!=`: true unless the operands compare equal, so unordered operands count as "not equal".
pub(crate) struct NotEqual;

impl CompareOp for NotEqual {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		!matches!(ordering, Some(Ordering::Equal))
	}

	#[inline]
	fn compare_bool(lhs: bool, rhs: bool) -> Option<bool> {
		Some(lhs != rhs)
	}
}

/// `>`: true only when the left operand orders strictly after the right one.
pub(crate) struct GreaterThan;

impl CompareOp for GreaterThan {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		matches!(ordering, Some(Ordering::Greater))
	}
}

/// `>=`: true when the operands are ordered and the left one is greater or equal.
pub(crate) struct GreaterThanEqual;

impl CompareOp for GreaterThanEqual {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		matches!(ordering, Some(ord) if ord.is_ge())
	}
}

/// `<`: true only when the left operand orders strictly before the right one.
pub(crate) struct LessThan;

impl CompareOp for LessThan {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		matches!(ordering, Some(Ordering::Less))
	}
}

/// `<=`: true when the operands are ordered and the left one is less or equal.
pub(crate) struct LessThanEqual;

impl CompareOp for LessThanEqual {
	#[inline]
	fn compare_ordering(ordering: Option<Ordering>) -> bool {
		matches!(ordering, Some(ord) if ord.is_le())
	}
}
209
210#[inline]
211fn compare_number<Op: CompareOp, L, R>(
212	l: &NumberContainer<L>,
213	r: &NumberContainer<R>,
214	fragment: Fragment,
215) -> ColumnWithName
216where
217	L: Promote<R> + IsNumber,
218	R: IsNumber,
219	<L as Promote<R>>::Output: IsNumber,
220{
221	debug_assert_eq!(l.len(), r.len());
222
223	let data: Vec<bool> =
224		l.data().iter()
225			.zip(r.data().iter())
226			.map(|(l_val, r_val)| Op::compare_ordering(partial_cmp(l_val, r_val)))
227			.collect();
228
229	ColumnWithName::new(Fragment::internal(fragment.text()), ColumnBuffer::bool(data))
230}
231
232#[inline]
233fn compare_temporal<Op: CompareOp, T>(
234	l: &TemporalContainer<T>,
235	r: &TemporalContainer<T>,
236	fragment: Fragment,
237) -> ColumnWithName
238where
239	T: IsTemporal + Copy + PartialOrd,
240{
241	debug_assert_eq!(l.len(), r.len());
242
243	let data: Vec<bool> =
244		l.data().iter()
245			.zip(r.data().iter())
246			.map(|(l_val, r_val)| Op::compare_ordering(l_val.partial_cmp(r_val)))
247			.collect();
248
249	ColumnWithName::new(Fragment::internal(fragment.text()), ColumnBuffer::bool(data))
250}
251
252#[inline]
253fn compare_uuid<Op: CompareOp, T>(l: &UuidContainer<T>, r: &UuidContainer<T>, fragment: Fragment) -> ColumnWithName
254where
255	T: IsUuid + PartialOrd,
256{
257	debug_assert_eq!(l.len(), r.len());
258
259	let data: Vec<bool> =
260		l.data().iter()
261			.zip(r.data().iter())
262			.map(|(l_val, r_val)| Op::compare_ordering(l_val.partial_cmp(r_val)))
263			.collect();
264
265	ColumnWithName::new(Fragment::internal(fragment.text()), ColumnBuffer::bool(data))
266}
267
268#[inline]
269fn compare_identity_id<Op: CompareOp>(
270	l: &IdentityIdContainer,
271	r: &IdentityIdContainer,
272	fragment: Fragment,
273) -> ColumnWithName {
274	debug_assert_eq!(l.len(), r.len());
275
276	let data: Vec<bool> =
277		l.iter().zip(r.iter()).map(|(l_val, r_val)| Op::compare_ordering(l_val.partial_cmp(&r_val))).collect();
278
279	ColumnWithName::new(Fragment::internal(fragment.text()), ColumnBuffer::bool(data))
280}
281
282#[inline]
283fn compare_blob<Op: CompareOp>(l: &BlobContainer, r: &BlobContainer, fragment: Fragment) -> ColumnWithName {
284	debug_assert_eq!(l.len(), r.len());
285
286	let data: Vec<bool> = l
287		.iter_bytes()
288		.zip(r.iter_bytes())
289		.map(|(l_val, r_val)| Op::compare_ordering(l_val.partial_cmp(r_val)))
290		.collect();
291
292	ColumnWithName::new(Fragment::internal(fragment.text()), ColumnBuffer::bool(data))
293}
294
295#[inline]
296fn compare_utf8<Op: CompareOp>(l: &Utf8Container, r: &Utf8Container, fragment: Fragment) -> ColumnWithName {
297	debug_assert_eq!(l.len(), r.len());
298
299	let data: Vec<bool> = l
300		.iter_str()
301		.zip(r.iter_str())
302		.map(|(l_val, r_val)| Op::compare_ordering(l_val.partial_cmp(r_val)))
303		.collect();
304
305	ColumnWithName::new(Fragment::internal(fragment.text()), ColumnBuffer::bool(data))
306}
307
308#[inline]
309fn compare_bool<Op: CompareOp>(l: &BoolContainer, r: &BoolContainer, fragment: Fragment) -> Option<ColumnWithName> {
310	debug_assert_eq!(l.len(), r.len());
311
312	let data: Vec<bool> =
313		l.data().iter()
314			.zip(r.data().iter())
315			.filter_map(|(l_val, r_val)| Op::compare_bool(l_val, r_val))
316			.collect();
317
318	if data.len() == l.len() {
319		Some(ColumnWithName::new(Fragment::internal(fragment.text()), ColumnBuffer::bool(data)))
320	} else {
321		None
322	}
323}
324
325pub(crate) fn compare_columns<Op: CompareOp>(
326	left: &ColumnWithName,
327	right: &ColumnWithName,
328	fragment: Fragment,
329	error_fn: impl FnOnce(Fragment, Type, Type) -> Diagnostic,
330) -> Result<ColumnWithName> {
331	binary_op_unwrap_option(left, right, fragment.clone(), |left, right| {
332		dispatch_compare!(
333			&left.data(), &right.data();
334			fragment;
335
336			(ColumnBuffer::Bool(l), ColumnBuffer::Bool(r)) => {
337				if let Some(col) = compare_bool::<Op>(l, r, fragment.clone()) {
338					return Ok(col);
339				}
340				return_error!(error_fn(fragment, left.get_type(), right.get_type()))
341			}
342
343			(ColumnBuffer::Date(l), ColumnBuffer::Date(r)) => {
344				Ok(compare_temporal::<Op, _>(l, r, fragment))
345			},
346			(ColumnBuffer::DateTime(l), ColumnBuffer::DateTime(r)) => {
347				Ok(compare_temporal::<Op, _>(l, r, fragment))
348			},
349			(ColumnBuffer::Time(l), ColumnBuffer::Time(r)) => {
350				Ok(compare_temporal::<Op, _>(l, r, fragment))
351			},
352			(ColumnBuffer::Duration(l), ColumnBuffer::Duration(r)) => {
353				Ok(compare_temporal::<Op, _>(l, r, fragment))
354			},
355
356			(
357				ColumnBuffer::Utf8 {
358					container: l,
359					..
360				},
361				ColumnBuffer::Utf8 {
362					container: r,
363					..
364				},
365			) => {
366				Ok(compare_utf8::<Op>(l, r, fragment))
367			},
368
369			(ColumnBuffer::Uuid4(l), ColumnBuffer::Uuid4(r)) => {
370				Ok(compare_uuid::<Op, _>(l, r, fragment))
371			},
372			(ColumnBuffer::Uuid7(l), ColumnBuffer::Uuid7(r)) => {
373				Ok(compare_uuid::<Op, _>(l, r, fragment))
374			},
375			(ColumnBuffer::IdentityId(l), ColumnBuffer::IdentityId(r)) => {
376				Ok(compare_identity_id::<Op>(l, r, fragment))
377			},
378			(
379				ColumnBuffer::Blob {
380					container: l,
381					..
382				},
383				ColumnBuffer::Blob {
384					container: r,
385					..
386				},
387			) => {
388				Ok(compare_blob::<Op>(l, r, fragment))
389			},
390
391			_ => {
392				return_error!(error_fn(fragment, left.get_type(), right.get_type()))
393			},
394		)
395	})
396}