Skip to main content

reifydb_function/math/scalar/
round.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use num_traits::ToPrimitive;
5use reifydb_core::value::column::data::ColumnData;
6use reifydb_type::value::r#type::Type;
7
8use crate::{ScalarFunction, ScalarFunctionContext, error::ScalarFunctionError, propagate_options};
9
/// Scalar function implementing `round(value [, precision])`.
///
/// The type is a stateless unit struct. `Default` is derived rather than
/// hand-written because a manual impl that just returns `Self {}` is exactly
/// what clippy's `derivable_impls` lint flags.
#[derive(Debug, Default, Clone, Copy)]
pub struct Round;

impl Round {
	/// Creates the (stateless) `round` scalar function.
	pub const fn new() -> Self {
		Self
	}
}
23
24impl ScalarFunction for Round {
25	fn scalar(&self, ctx: ScalarFunctionContext) -> crate::error::ScalarFunctionResult<ColumnData> {
26		if let Some(result) = propagate_options(self, &ctx) {
27			return result;
28		}
29		let columns = ctx.columns;
30		let row_count = ctx.row_count;
31
32		// Validate at least 1 argument
33		if columns.is_empty() {
34			return Err(ScalarFunctionError::ArityMismatch {
35				function: ctx.fragment.clone(),
36				expected: 1,
37				actual: 0,
38			});
39		}
40
41		let value_column = columns.first().unwrap();
42
43		// Get precision column if provided (default to 0)
44		let precision_column = columns.get(1);
45
46		// Helper to get precision value at row index
47		let get_precision = |row_idx: usize| -> i32 {
48			if let Some(prec_col) = precision_column {
49				match prec_col.data() {
50					ColumnData::Int4(prec_container) => {
51						prec_container.get(row_idx).copied().unwrap_or(0)
52					}
53					ColumnData::Int1(prec_container) => {
54						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
55					}
56					ColumnData::Int2(prec_container) => {
57						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
58					}
59					ColumnData::Int8(prec_container) => {
60						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
61					}
62					ColumnData::Int16(prec_container) => {
63						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
64					}
65					ColumnData::Uint1(prec_container) => {
66						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
67					}
68					ColumnData::Uint2(prec_container) => {
69						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
70					}
71					ColumnData::Uint4(prec_container) => {
72						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
73					}
74					ColumnData::Uint8(prec_container) => {
75						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
76					}
77					ColumnData::Uint16(prec_container) => {
78						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
79					}
80					_ => 0,
81				}
82			} else {
83				0
84			}
85		};
86
87		match value_column.data() {
88			ColumnData::Float4(container) => {
89				let mut result = Vec::with_capacity(row_count);
90				let mut bitvec = Vec::with_capacity(row_count);
91
92				for row_idx in 0..row_count {
93					if let Some(&value) = container.get(row_idx) {
94						let precision = get_precision(row_idx);
95						let multiplier = 10_f32.powi(precision);
96						let rounded = (value * multiplier).round() / multiplier;
97						result.push(rounded);
98						bitvec.push(true);
99					} else {
100						result.push(0.0);
101						bitvec.push(false);
102					}
103				}
104
105				Ok(ColumnData::float4_with_bitvec(result, bitvec))
106			}
107			ColumnData::Float8(container) => {
108				let mut result = Vec::with_capacity(row_count);
109				let mut bitvec = Vec::with_capacity(row_count);
110
111				for row_idx in 0..row_count {
112					if let Some(&value) = container.get(row_idx) {
113						let precision = get_precision(row_idx);
114						let multiplier = 10_f64.powi(precision);
115						let rounded = (value * multiplier).round() / multiplier;
116						result.push(rounded);
117						bitvec.push(true);
118					} else {
119						result.push(0.0);
120						bitvec.push(false);
121					}
122				}
123
124				Ok(ColumnData::float8_with_bitvec(result, bitvec))
125			}
126			// Integer types: round is essentially identity (already whole numbers)
127			ColumnData::Int1(container) => {
128				let mut result = Vec::with_capacity(row_count);
129				let mut bitvec = Vec::with_capacity(row_count);
130
131				for row_idx in 0..row_count {
132					if let Some(&value) = container.get(row_idx) {
133						result.push(value);
134						bitvec.push(true);
135					} else {
136						result.push(0);
137						bitvec.push(false);
138					}
139				}
140
141				Ok(ColumnData::int1_with_bitvec(result, bitvec))
142			}
143			ColumnData::Int2(container) => {
144				let mut result = Vec::with_capacity(row_count);
145				let mut bitvec = Vec::with_capacity(row_count);
146
147				for row_idx in 0..row_count {
148					if let Some(&value) = container.get(row_idx) {
149						result.push(value);
150						bitvec.push(true);
151					} else {
152						result.push(0);
153						bitvec.push(false);
154					}
155				}
156
157				Ok(ColumnData::int2_with_bitvec(result, bitvec))
158			}
159			ColumnData::Int4(container) => {
160				let mut result = Vec::with_capacity(row_count);
161				let mut bitvec = Vec::with_capacity(row_count);
162
163				for row_idx in 0..row_count {
164					if let Some(&value) = container.get(row_idx) {
165						result.push(value);
166						bitvec.push(true);
167					} else {
168						result.push(0);
169						bitvec.push(false);
170					}
171				}
172
173				Ok(ColumnData::int4_with_bitvec(result, bitvec))
174			}
175			ColumnData::Int8(container) => {
176				let mut result = Vec::with_capacity(row_count);
177				let mut bitvec = Vec::with_capacity(row_count);
178
179				for row_idx in 0..row_count {
180					if let Some(&value) = container.get(row_idx) {
181						result.push(value);
182						bitvec.push(true);
183					} else {
184						result.push(0);
185						bitvec.push(false);
186					}
187				}
188
189				Ok(ColumnData::int8_with_bitvec(result, bitvec))
190			}
191			ColumnData::Int16(container) => {
192				let mut result = Vec::with_capacity(row_count);
193				let mut bitvec = Vec::with_capacity(row_count);
194
195				for row_idx in 0..row_count {
196					if let Some(&value) = container.get(row_idx) {
197						result.push(value);
198						bitvec.push(true);
199					} else {
200						result.push(0);
201						bitvec.push(false);
202					}
203				}
204
205				Ok(ColumnData::int16_with_bitvec(result, bitvec))
206			}
207			ColumnData::Uint1(container) => {
208				let mut result = Vec::with_capacity(row_count);
209				let mut bitvec = Vec::with_capacity(row_count);
210
211				for row_idx in 0..row_count {
212					if let Some(&value) = container.get(row_idx) {
213						result.push(value);
214						bitvec.push(true);
215					} else {
216						result.push(0);
217						bitvec.push(false);
218					}
219				}
220
221				Ok(ColumnData::uint1_with_bitvec(result, bitvec))
222			}
223			ColumnData::Uint2(container) => {
224				let mut result = Vec::with_capacity(row_count);
225				let mut bitvec = Vec::with_capacity(row_count);
226
227				for row_idx in 0..row_count {
228					if let Some(&value) = container.get(row_idx) {
229						result.push(value);
230						bitvec.push(true);
231					} else {
232						result.push(0);
233						bitvec.push(false);
234					}
235				}
236
237				Ok(ColumnData::uint2_with_bitvec(result, bitvec))
238			}
239			ColumnData::Uint4(container) => {
240				let mut result = Vec::with_capacity(row_count);
241				let mut bitvec = Vec::with_capacity(row_count);
242
243				for row_idx in 0..row_count {
244					if let Some(&value) = container.get(row_idx) {
245						result.push(value);
246						bitvec.push(true);
247					} else {
248						result.push(0);
249						bitvec.push(false);
250					}
251				}
252
253				Ok(ColumnData::uint4_with_bitvec(result, bitvec))
254			}
255			ColumnData::Uint8(container) => {
256				let mut result = Vec::with_capacity(row_count);
257				let mut bitvec = Vec::with_capacity(row_count);
258
259				for row_idx in 0..row_count {
260					if let Some(&value) = container.get(row_idx) {
261						result.push(value);
262						bitvec.push(true);
263					} else {
264						result.push(0);
265						bitvec.push(false);
266					}
267				}
268
269				Ok(ColumnData::uint8_with_bitvec(result, bitvec))
270			}
271			ColumnData::Uint16(container) => {
272				let mut result = Vec::with_capacity(row_count);
273				let mut bitvec = Vec::with_capacity(row_count);
274
275				for row_idx in 0..row_count {
276					if let Some(&value) = container.get(row_idx) {
277						result.push(value);
278						bitvec.push(true);
279					} else {
280						result.push(0);
281						bitvec.push(false);
282					}
283				}
284
285				Ok(ColumnData::uint16_with_bitvec(result, bitvec))
286			}
287			ColumnData::Int {
288				container,
289				max_bytes,
290			} => {
291				use reifydb_type::value::int::Int;
292				let mut result = Vec::with_capacity(row_count);
293				let mut bitvec = Vec::with_capacity(row_count);
294
295				for row_idx in 0..row_count {
296					if let Some(value) = container.get(row_idx) {
297						result.push(value.clone());
298						bitvec.push(true);
299					} else {
300						result.push(Int::default());
301						bitvec.push(false);
302					}
303				}
304
305				Ok(ColumnData::Int {
306					container: reifydb_type::value::container::number::NumberContainer::new(result),
307					max_bytes: *max_bytes,
308				})
309			}
310			ColumnData::Uint {
311				container,
312				max_bytes,
313			} => {
314				use reifydb_type::value::uint::Uint;
315				let mut result = Vec::with_capacity(row_count);
316				let mut bitvec = Vec::with_capacity(row_count);
317
318				for row_idx in 0..row_count {
319					if let Some(value) = container.get(row_idx) {
320						result.push(value.clone());
321						bitvec.push(true);
322					} else {
323						result.push(Uint::default());
324						bitvec.push(false);
325					}
326				}
327
328				Ok(ColumnData::Uint {
329					container: reifydb_type::value::container::number::NumberContainer::new(result),
330					max_bytes: *max_bytes,
331				})
332			}
333			ColumnData::Decimal {
334				container,
335				precision,
336				scale,
337			} => {
338				use reifydb_type::value::decimal::Decimal;
339				let mut result = Vec::with_capacity(row_count);
340				let mut bitvec = Vec::with_capacity(row_count);
341
342				for row_idx in 0..row_count {
343					if let Some(value) = container.get(row_idx) {
344						let prec = get_precision(row_idx);
345						let f_val = value.0.to_f64().unwrap_or(0.0);
346						let multiplier = 10_f64.powi(prec);
347						let rounded = (f_val * multiplier).round() / multiplier;
348						result.push(Decimal::from(rounded));
349						bitvec.push(true);
350					} else {
351						result.push(Decimal::default());
352						bitvec.push(false);
353					}
354				}
355
356				Ok(ColumnData::Decimal {
357					container: reifydb_type::value::container::number::NumberContainer::new(result),
358					precision: *precision,
359					scale: *scale,
360				})
361			}
362			other => Err(ScalarFunctionError::InvalidArgumentType {
363				function: ctx.fragment.clone(),
364				argument_index: 0,
365				expected: vec![
366					Type::Int1,
367					Type::Int2,
368					Type::Int4,
369					Type::Int8,
370					Type::Int16,
371					Type::Uint1,
372					Type::Uint2,
373					Type::Uint4,
374					Type::Uint8,
375					Type::Uint16,
376					Type::Float4,
377					Type::Float8,
378					Type::Int,
379					Type::Uint,
380					Type::Decimal,
381				],
382				actual: other.get_type(),
383			}),
384		}
385	}
386
387	fn return_type(&self, input_types: &[Type]) -> Type {
388		input_types[0].clone()
389	}
390}