Skip to main content

reifydb_type/value/type/
promote.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::cmp::min;
5
6use Type::*;
7
8use crate::value::r#type::Type;
9
10impl Type {
11	/// Promote two Types to a common supertype, similar to Postgres
12	/// expression evaluation.
13	pub fn promote(left: Type, right: Type) -> Type {
14		if matches!(left, Any) || matches!(right, Any) {
15			return Any;
16		}
17
18		if matches!(left, Option(_)) || matches!(right, Option(_)) {
19			return left;
20		}
21
22		if left == Utf8 || right == Utf8 {
23			return Utf8;
24		}
25
26		if left == Boolean || right == Boolean {
27			return Boolean;
28		}
29
30		if left == Float8 || right == Float8 {
31			return Float8;
32		}
33
34		if left == Float4 || right == Float4 {
35			return Float8;
36		}
37
38		let signed_order = [Int1, Int2, Int4, Int8, Int16];
39		let unsigned_order = [Uint1, Uint2, Uint4, Uint8, Uint16];
40
41		let is_signed = |k: &Type| signed_order.contains(k);
42		let is_unsigned = |k: &Type| unsigned_order.contains(k);
43
44		let rank = |k: &Type| match k {
45			Int1 | Uint1 => 0,
46			Int2 | Uint2 => 1,
47			Int4 | Uint4 => 2,
48			Int8 | Uint8 => 3,
49			Int16 | Uint16 => 4,
50			_ => usize::MAX,
51		};
52
53		if is_signed(&left) && is_signed(&right) {
54			return signed_order[min(rank(&left).max(rank(&right)), 3) + 1].clone();
55		}
56
57		if is_unsigned(&left) && is_unsigned(&right) {
58			return unsigned_order[min(rank(&left).max(rank(&right)), 3) + 1].clone();
59		}
60
61		if (is_signed(&left) && is_unsigned(&right)) || (is_unsigned(&left) && is_signed(&right)) {
62			return match rank(&left).max(rank(&right)) + 1 {
63				0 => Int1,
64				1 => Int2,
65				2 => Int4,
66				3 => Int8,
67				4 => Int16,
68				_ => Int16,
69			};
70		}
71
72		left
73	}
74}
75
#[cfg(test)]
pub mod tests {
	use Type::*;

	use crate::value::r#type::Type;

	/// Asserts that every `(left, right, expected)` triple promotes as listed,
	/// naming the offending operands when an assertion fails.
	fn assert_promotes(cases: &[(Type, Type, Type)]) {
		for (left, right, expected) in cases {
			assert_eq!(
				Type::promote(left.clone(), right.clone()),
				*expected,
				"promote({:?}, {:?})",
				left,
				right
			);
		}
	}

	#[test]
	fn test_promote_bool() {
		let cases = [
			(Boolean, Boolean, Boolean),
			(Boolean, Float4, Boolean),
			(Boolean, Float8, Boolean),
			(Boolean, Int1, Boolean),
			(Boolean, Int2, Boolean),
			(Boolean, Int4, Boolean),
			(Boolean, Int8, Boolean),
			(Boolean, Int16, Boolean),
			(Boolean, Utf8, Utf8),
			(Boolean, Uint1, Boolean),
			(Boolean, Uint2, Boolean),
			(Boolean, Uint4, Boolean),
			(Boolean, Uint8, Boolean),
			(Boolean, Uint16, Boolean),
		];
		assert_promotes(&cases);
	}

	#[test]
	fn test_promote_float4() {
		let cases = [
			(Float4, Boolean, Boolean),
			(Float4, Float4, Float8),
			(Float4, Float8, Float8),
			(Float4, Int1, Float8),
			(Float4, Int2, Float8),
			(Float4, Int4, Float8),
			(Float4, Int8, Float8),
			(Float4, Int16, Float8),
			(Float4, Utf8, Utf8),
			(Float4, Uint1, Float8),
			(Float4, Uint2, Float8),
			(Float4, Uint4, Float8),
			(Float4, Uint8, Float8),
			(Float4, Uint16, Float8),
		];
		assert_promotes(&cases);
	}

	#[test]
	fn test_promote_float8() {
		let cases = [
			(Float8, Boolean, Boolean),
			(Float8, Float4, Float8),
			(Float8, Float8, Float8),
			(Float8, Int1, Float8),
			(Float8, Int2, Float8),
			(Float8, Int4, Float8),
			(Float8, Int8, Float8),
			(Float8, Int16, Float8),
			(Float8, Utf8, Utf8),
			(Float8, Uint1, Float8),
			(Float8, Uint2, Float8),
			(Float8, Uint4, Float8),
			(Float8, Uint8, Float8),
			(Float8, Uint16, Float8),
		];
		assert_promotes(&cases);
	}

	#[test]
	fn test_promote_int1() {
		let cases = [
			(Int1, Boolean, Boolean),
			(Int1, Float4, Float8),
			(Int1, Float8, Float8),
			(Int1, Int1, Int2),
			(Int1, Int2, Int4),
			(Int1, Int4, Int8),
			(Int1, Int8, Int16),
			(Int1, Int16, Int16),
			(Int1, Utf8, Utf8),
			(Int1, Uint1, Int2),
			(Int1, Uint2, Int4),
			(Int1, Uint4, Int8),
			(Int1, Uint8, Int16),
			(Int1, Uint16, Int16),
		];
		assert_promotes(&cases);
	}

	#[test]
	fn test_promote_int2() {
		let cases = [
			(Int2, Boolean, Boolean),
			(Int2, Float4, Float8),
			(Int2, Float8, Float8),
			(Int2, Int1, Int4),
			(Int2, Int2, Int4),
			(Int2, Int4, Int8),
			(Int2, Int8, Int16),
			(Int2, Int16, Int16),
			(Int2, Utf8, Utf8),
			(Int2, Uint1, Int4),
			(Int2, Uint2, Int4),
			(Int2, Uint4, Int8),
			(Int2, Uint8, Int16),
			(Int2, Uint16, Int16),
		];
		assert_promotes(&cases);
	}

	#[test]
	fn test_promote_int4() {
		let cases = [
			(Int4, Boolean, Boolean),
			(Int4, Float4, Float8),
			(Int4, Float8, Float8),
			(Int4, Int1, Int8),
			(Int4, Int2, Int8),
			(Int4, Int4, Int8),
			(Int4, Int8, Int16),
			(Int4, Int16, Int16),
			(Int4, Utf8, Utf8),
			(Int4, Uint1, Int8),
			(Int4, Uint2, Int8),
			(Int4, Uint4, Int8),
			(Int4, Uint8, Int16),
			(Int4, Uint16, Int16),
		];
		assert_promotes(&cases);
	}

	#[test]
	fn test_promote_int8() {
		let cases = [
			(Int8, Boolean, Boolean),
			(Int8, Float4, Float8),
			(Int8, Float8, Float8),
			(Int8, Int1, Int16),
			(Int8, Int2, Int16),
			(Int8, Int4, Int16),
			(Int8, Int8, Int16),
			(Int8, Int16, Int16),
			(Int8, Utf8, Utf8),
			(Int8, Uint1, Int16),
			(Int8, Uint2, Int16),
			(Int8, Uint4, Int16),
			(Int8, Uint8, Int16),
			(Int8, Uint16, Int16),
		];
		assert_promotes(&cases);
	}

	#[test]
	fn test_promote_int16() {
		let cases = [
			(Int16, Boolean, Boolean),
			(Int16, Float4, Float8),
			(Int16, Float8, Float8),
			(Int16, Int1, Int16),
			(Int16, Int2, Int16),
			(Int16, Int4, Int16),
			(Int16, Int8, Int16),
			(Int16, Int16, Int16),
			(Int16, Utf8, Utf8),
			(Int16, Uint1, Int16),
			(Int16, Uint2, Int16),
			(Int16, Uint4, Int16),
			(Int16, Uint8, Int16),
			(Int16, Uint16, Int16),
		];
		assert_promotes(&cases);
	}

	#[test]
	fn test_promote_string() {
		let kinds = [
			Boolean, Float4, Float8, Int1, Int2, Int4, Int8, Int16, Utf8, Uint1, Uint2, Uint4, Uint8,
			Uint16,
		];
		for ty in kinds {
			// Utf8 must win regardless of which side it appears on.
			assert_eq!(Type::promote(Utf8, ty.clone()), Utf8, "Utf8 on left with {:?}", ty);
			assert_eq!(Type::promote(ty.clone(), Utf8), Utf8, "Utf8 on right with {:?}", ty);
		}
	}

	#[test]
	fn test_promote_uint1() {
		let cases = [
			(Uint1, Boolean, Boolean),
			(Uint1, Float4, Float8),
			(Uint1, Float8, Float8),
			(Uint1, Int1, Int2),
			(Uint1, Int2, Int4),
			(Uint1, Int4, Int8),
			(Uint1, Int8, Int16),
			(Uint1, Int16, Int16),
			(Uint1, Utf8, Utf8),
			(Uint1, Uint1, Uint2),
			(Uint1, Uint2, Uint4),
			(Uint1, Uint4, Uint8),
			(Uint1, Uint8, Uint16),
			(Uint1, Uint16, Uint16),
		];
		assert_promotes(&cases);
	}

	#[test]
	fn test_promote_uint2() {
		let cases = [
			(Uint2, Boolean, Boolean),
			(Uint2, Float4, Float8),
			(Uint2, Float8, Float8),
			(Uint2, Int1, Int4),
			(Uint2, Int2, Int4),
			(Uint2, Int4, Int8),
			(Uint2, Int8, Int16),
			(Uint2, Int16, Int16),
			(Uint2, Utf8, Utf8),
			(Uint2, Uint1, Uint4),
			(Uint2, Uint2, Uint4),
			(Uint2, Uint4, Uint8),
			(Uint2, Uint8, Uint16),
			(Uint2, Uint16, Uint16),
		];
		assert_promotes(&cases);
	}

	#[test]
	fn test_promote_uint4() {
		let cases = [
			(Uint4, Boolean, Boolean),
			(Uint4, Float4, Float8),
			(Uint4, Float8, Float8),
			(Uint4, Int1, Int8),
			(Uint4, Int2, Int8),
			(Uint4, Int4, Int8),
			(Uint4, Int8, Int16),
			(Uint4, Int16, Int16),
			(Uint4, Utf8, Utf8),
			(Uint4, Uint1, Uint8),
			(Uint4, Uint2, Uint8),
			(Uint4, Uint4, Uint8),
			(Uint4, Uint8, Uint16),
			(Uint4, Uint16, Uint16),
		];
		assert_promotes(&cases);
	}

	#[test]
	fn test_promote_uint8() {
		let cases = [
			(Uint8, Boolean, Boolean),
			(Uint8, Float4, Float8),
			(Uint8, Float8, Float8),
			(Uint8, Int1, Int16),
			(Uint8, Int2, Int16),
			(Uint8, Int4, Int16),
			(Uint8, Int8, Int16),
			(Uint8, Int16, Int16),
			(Uint8, Utf8, Utf8),
			(Uint8, Uint1, Uint16),
			(Uint8, Uint2, Uint16),
			(Uint8, Uint4, Uint16),
			(Uint8, Uint8, Uint16),
			(Uint8, Uint16, Uint16),
		];
		assert_promotes(&cases);
	}

	#[test]
	fn test_promote_uint16() {
		let cases = [
			(Uint16, Boolean, Boolean),
			(Uint16, Float4, Float8),
			(Uint16, Float8, Float8),
			(Uint16, Int1, Int16),
			(Uint16, Int2, Int16),
			(Uint16, Int4, Int16),
			(Uint16, Int8, Int16),
			(Uint16, Int16, Int16),
			(Uint16, Utf8, Utf8),
			(Uint16, Uint1, Uint16),
			(Uint16, Uint2, Uint16),
			(Uint16, Uint4, Uint16),
			(Uint16, Uint8, Uint16),
			(Uint16, Uint16, Uint16),
		];
		assert_promotes(&cases);
	}

	#[test]
	fn test_promote_scalar_is_commutative() {
		let kinds = [
			Boolean, Float4, Float8, Int1, Int2, Int4, Int8, Int16, Utf8, Uint1, Uint2, Uint4, Uint8,
			Uint16,
		];
		for left in &kinds {
			for right in &kinds {
				assert_eq!(
					Type::promote(left.clone(), right.clone()),
					Type::promote(right.clone(), left.clone()),
					"promote is not commutative for {:?} and {:?}",
					left,
					right
				);
			}
		}
	}

	#[test]
	fn test_promote_any_is_absorbing() {
		let kinds = [
			Boolean,
			Float4,
			Float8,
			Int1,
			Int2,
			Int4,
			Int8,
			Int16,
			Uint1,
			Uint2,
			Uint4,
			Uint8,
			Uint16,
			Utf8,
			Date,
			DateTime,
			Time,
			Duration,
			Uuid4,
			Uuid7,
			Blob,
			IdentityId,
			DictionaryId,
			Int,
			Uint,
			Decimal,
			Any,
		];
		for ty in kinds {
			assert_eq!(Type::promote(Any, ty.clone()), Any, "Any on left with {:?}", ty);
			assert_eq!(Type::promote(ty.clone(), Any), Any, "Any on right with {:?}", ty);
		}
		assert_eq!(Type::promote(Any, Option(Box::new(Int4))), Any);
		assert_eq!(Type::promote(Option(Box::new(Int4)), Any), Any);
	}
}