Skip to main content

reifydb_type/value/type/
promote.rs

// SPDX-License-Identifier: MIT
// Copyright (c) 2025 ReifyDB
3
4use std::cmp::min;
5
6use Type::*;
7
8use crate::value::r#type::Type;
9
10impl Type {
11	/// Promote two Types to a common supertype, similar to Postgres
12	/// expression evaluation.
13	pub fn promote(left: Type, right: Type) -> Type {
14		if matches!(left, Option(_)) || matches!(right, Option(_)) {
15			return left;
16		}
17
18		if left == Utf8 || right == Utf8 {
19			return Utf8;
20		}
21
22		if left == Boolean || right == Boolean {
23			return Boolean;
24		}
25
26		if left == Float8 || right == Float8 {
27			return Float8;
28		}
29
30		if left == Float4 || right == Float4 {
31			return Float8;
32		}
33
34		let signed_order = [Int1, Int2, Int4, Int8, Int16];
35		let unsigned_order = [Uint1, Uint2, Uint4, Uint8, Uint16];
36
37		let is_signed = |k: &Type| signed_order.contains(k);
38		let is_unsigned = |k: &Type| unsigned_order.contains(k);
39
40		let rank = |k: &Type| match k {
41			Int1 | Uint1 => 0,
42			Int2 | Uint2 => 1,
43			Int4 | Uint4 => 2,
44			Int8 | Uint8 => 3,
45			Int16 | Uint16 => 4,
46			_ => usize::MAX,
47		};
48
49		if is_signed(&left) && is_signed(&right) {
50			return signed_order[min(rank(&left).max(rank(&right)), 3) + 1].clone();
51		}
52
53		if is_unsigned(&left) && is_unsigned(&right) {
54			return unsigned_order[min(rank(&left).max(rank(&right)), 3) + 1].clone();
55		}
56
57		if (is_signed(&left) && is_unsigned(&right)) || (is_unsigned(&left) && is_signed(&right)) {
58			return match rank(&left).max(rank(&right)) + 1 {
59				0 => Int1,
60				1 => Int2,
61				2 => Int4,
62				3 => Int8,
63				4 => Int16,
64				_ => Int16,
65			};
66		}
67
68		left
69	}
70}
71
#[cfg(test)]
pub mod tests {
	use Type::*;

	use crate::value::r#type::Type;

	/// Right-hand operands every test sweeps over. The position of an entry
	/// here is the column its expected result occupies in `assert_row`.
	fn right_operands() -> [Type; 14] {
		[
			Boolean, Float4, Float8, // boolean and floats
			Int1, Int2, Int4, Int8, Int16, // signed integers
			Utf8,  // text
			Uint1, Uint2, Uint4, Uint8, Uint16, // unsigned integers
		]
	}

	/// Asserts that `Type::promote(left, right)` equals the matching entry of
	/// `expected` for every `right` in `right_operands()`, in column order.
	fn assert_row(left: Type, expected: [Type; 14]) {
		for (right, want) in right_operands().iter().zip(expected.iter()) {
			assert_eq!(Type::promote(left.clone(), right.clone()), *want);
		}
	}

	#[test]
	fn test_promote_bool() {
		assert_row(
			Boolean,
			[
				Boolean, Boolean, Boolean, // boolean and floats
				Boolean, Boolean, Boolean, Boolean, Boolean, // signed integers
				Utf8,    // text
				Boolean, Boolean, Boolean, Boolean, Boolean, // unsigned integers
			],
		);
	}

	#[test]
	fn test_promote_float4() {
		assert_row(
			Float4,
			[
				Boolean, Float8, Float8, // boolean and floats
				Float8, Float8, Float8, Float8, Float8, // signed integers
				Utf8,   // text
				Float8, Float8, Float8, Float8, Float8, // unsigned integers
			],
		);
	}

	#[test]
	fn test_promote_float8() {
		assert_row(
			Float8,
			[
				Boolean, Float8, Float8, // boolean and floats
				Float8, Float8, Float8, Float8, Float8, // signed integers
				Utf8,   // text
				Float8, Float8, Float8, Float8, Float8, // unsigned integers
			],
		);
	}

	#[test]
	fn test_promote_int1() {
		assert_row(
			Int1,
			[
				Boolean, Float8, Float8, // boolean and floats
				Int2, Int4, Int8, Int16, Int16, // signed integers
				Utf8, // text
				Int2, Int4, Int8, Int16, Int16, // unsigned integers
			],
		);
	}

	#[test]
	fn test_promote_int2() {
		assert_row(
			Int2,
			[
				Boolean, Float8, Float8, // boolean and floats
				Int4, Int4, Int8, Int16, Int16, // signed integers
				Utf8, // text
				Int4, Int4, Int8, Int16, Int16, // unsigned integers
			],
		);
	}

	#[test]
	fn test_promote_int4() {
		assert_row(
			Int4,
			[
				Boolean, Float8, Float8, // boolean and floats
				Int8, Int8, Int8, Int16, Int16, // signed integers
				Utf8, // text
				Int8, Int8, Int8, Int16, Int16, // unsigned integers
			],
		);
	}

	#[test]
	fn test_promote_int8() {
		assert_row(
			Int8,
			[
				Boolean, Float8, Float8, // boolean and floats
				Int16, Int16, Int16, Int16, Int16, // signed integers
				Utf8,  // text
				Int16, Int16, Int16, Int16, Int16, // unsigned integers
			],
		);
	}

	#[test]
	fn test_promote_int16() {
		assert_row(
			Int16,
			[
				Boolean, Float8, Float8, // boolean and floats
				Int16, Int16, Int16, Int16, Int16, // signed integers
				Utf8,  // text
				Int16, Int16, Int16, Int16, Int16, // unsigned integers
			],
		);
	}

	#[test]
	fn test_promote_string() {
		// Utf8 on the left wins against every operand.
		assert_row(
			Utf8,
			[
				Utf8, Utf8, Utf8, // boolean and floats
				Utf8, Utf8, Utf8, Utf8, Utf8, // signed integers
				Utf8, // text
				Utf8, Utf8, Utf8, Utf8, Utf8, // unsigned integers
			],
		);
	}

	#[test]
	fn test_promote_uint1() {
		assert_row(
			Uint1,
			[
				Boolean, Float8, Float8, // boolean and floats
				Int2, Int4, Int8, Int16, Int16, // signed integers
				Utf8, // text
				Uint2, Uint4, Uint8, Uint16, Uint16, // unsigned integers
			],
		);
	}

	#[test]
	fn test_promote_uint2() {
		assert_row(
			Uint2,
			[
				Boolean, Float8, Float8, // boolean and floats
				Int4, Int4, Int8, Int16, Int16, // signed integers
				Utf8, // text
				Uint4, Uint4, Uint8, Uint16, Uint16, // unsigned integers
			],
		);
	}

	#[test]
	fn test_promote_uint4() {
		assert_row(
			Uint4,
			[
				Boolean, Float8, Float8, // boolean and floats
				Int8, Int8, Int8, Int16, Int16, // signed integers
				Utf8, // text
				Uint8, Uint8, Uint8, Uint16, Uint16, // unsigned integers
			],
		);
	}

	#[test]
	fn test_promote_uint8() {
		assert_row(
			Uint8,
			[
				Boolean, Float8, Float8, // boolean and floats
				Int16, Int16, Int16, Int16, Int16, // signed integers
				Utf8,  // text
				Uint16, Uint16, Uint16, Uint16, Uint16, // unsigned integers
			],
		);
	}

	#[test]
	fn test_promote_uint16() {
		assert_row(
			Uint16,
			[
				Boolean, Float8, Float8, // boolean and floats
				Int16, Int16, Int16, Int16, Int16, // signed integers
				Utf8,  // text
				Uint16, Uint16, Uint16, Uint16, Uint16, // unsigned integers
			],
		);
	}
}