Skip to main content

reifydb_function/math/scalar/
round.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use num_traits::ToPrimitive;
5use reifydb_core::value::column::data::ColumnData;
6use reifydb_type::value::{container::number::NumberContainer, decimal::Decimal, int::Int, r#type::Type, uint::Uint};
7
8use crate::{
9	ScalarFunction, ScalarFunctionContext,
10	error::{ScalarFunctionError, ScalarFunctionResult},
11	propagate_options,
12};
13
/// The scalar `round` function.
///
/// This is a stateless marker type: all of the rounding behaviour lives in its
/// `ScalarFunction` implementation.
#[derive(Default)]
pub struct Round;

impl Round {
	/// Creates a `Round` function instance (identical to `Round::default()`).
	pub fn new() -> Self {
		Round
	}
}
27
28impl ScalarFunction for Round {
29	fn scalar(&self, ctx: ScalarFunctionContext) -> ScalarFunctionResult<ColumnData> {
30		if let Some(result) = propagate_options(self, &ctx) {
31			return result;
32		}
33		let columns = ctx.columns;
34		let row_count = ctx.row_count;
35
36		// Validate at least 1 argument
37		if columns.is_empty() {
38			return Err(ScalarFunctionError::ArityMismatch {
39				function: ctx.fragment.clone(),
40				expected: 1,
41				actual: 0,
42			});
43		}
44
45		let value_column = columns.first().unwrap();
46
47		// Get precision column if provided (default to 0)
48		let precision_column = columns.get(1);
49
50		// Helper to get precision value at row index
51		let get_precision = |row_idx: usize| -> i32 {
52			if let Some(prec_col) = precision_column {
53				match prec_col.data() {
54					ColumnData::Int4(prec_container) => {
55						prec_container.get(row_idx).copied().unwrap_or(0)
56					}
57					ColumnData::Int1(prec_container) => {
58						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
59					}
60					ColumnData::Int2(prec_container) => {
61						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
62					}
63					ColumnData::Int8(prec_container) => {
64						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
65					}
66					ColumnData::Int16(prec_container) => {
67						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
68					}
69					ColumnData::Uint1(prec_container) => {
70						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
71					}
72					ColumnData::Uint2(prec_container) => {
73						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
74					}
75					ColumnData::Uint4(prec_container) => {
76						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
77					}
78					ColumnData::Uint8(prec_container) => {
79						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
80					}
81					ColumnData::Uint16(prec_container) => {
82						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
83					}
84					_ => 0,
85				}
86			} else {
87				0
88			}
89		};
90
91		match value_column.data() {
92			ColumnData::Float4(container) => {
93				let mut result = Vec::with_capacity(row_count);
94				let mut bitvec = Vec::with_capacity(row_count);
95
96				for row_idx in 0..row_count {
97					if let Some(&value) = container.get(row_idx) {
98						let precision = get_precision(row_idx);
99						let multiplier = 10_f32.powi(precision);
100						let rounded = (value * multiplier).round() / multiplier;
101						result.push(rounded);
102						bitvec.push(true);
103					} else {
104						result.push(0.0);
105						bitvec.push(false);
106					}
107				}
108
109				Ok(ColumnData::float4_with_bitvec(result, bitvec))
110			}
111			ColumnData::Float8(container) => {
112				let mut result = Vec::with_capacity(row_count);
113				let mut bitvec = Vec::with_capacity(row_count);
114
115				for row_idx in 0..row_count {
116					if let Some(&value) = container.get(row_idx) {
117						let precision = get_precision(row_idx);
118						let multiplier = 10_f64.powi(precision);
119						let rounded = (value * multiplier).round() / multiplier;
120						result.push(rounded);
121						bitvec.push(true);
122					} else {
123						result.push(0.0);
124						bitvec.push(false);
125					}
126				}
127
128				Ok(ColumnData::float8_with_bitvec(result, bitvec))
129			}
130			// Integer types: round is essentially identity (already whole numbers)
131			ColumnData::Int1(container) => {
132				let mut result = Vec::with_capacity(row_count);
133				let mut bitvec = Vec::with_capacity(row_count);
134
135				for row_idx in 0..row_count {
136					if let Some(&value) = container.get(row_idx) {
137						result.push(value);
138						bitvec.push(true);
139					} else {
140						result.push(0);
141						bitvec.push(false);
142					}
143				}
144
145				Ok(ColumnData::int1_with_bitvec(result, bitvec))
146			}
147			ColumnData::Int2(container) => {
148				let mut result = Vec::with_capacity(row_count);
149				let mut bitvec = Vec::with_capacity(row_count);
150
151				for row_idx in 0..row_count {
152					if let Some(&value) = container.get(row_idx) {
153						result.push(value);
154						bitvec.push(true);
155					} else {
156						result.push(0);
157						bitvec.push(false);
158					}
159				}
160
161				Ok(ColumnData::int2_with_bitvec(result, bitvec))
162			}
163			ColumnData::Int4(container) => {
164				let mut result = Vec::with_capacity(row_count);
165				let mut bitvec = Vec::with_capacity(row_count);
166
167				for row_idx in 0..row_count {
168					if let Some(&value) = container.get(row_idx) {
169						result.push(value);
170						bitvec.push(true);
171					} else {
172						result.push(0);
173						bitvec.push(false);
174					}
175				}
176
177				Ok(ColumnData::int4_with_bitvec(result, bitvec))
178			}
179			ColumnData::Int8(container) => {
180				let mut result = Vec::with_capacity(row_count);
181				let mut bitvec = Vec::with_capacity(row_count);
182
183				for row_idx in 0..row_count {
184					if let Some(&value) = container.get(row_idx) {
185						result.push(value);
186						bitvec.push(true);
187					} else {
188						result.push(0);
189						bitvec.push(false);
190					}
191				}
192
193				Ok(ColumnData::int8_with_bitvec(result, bitvec))
194			}
195			ColumnData::Int16(container) => {
196				let mut result = Vec::with_capacity(row_count);
197				let mut bitvec = Vec::with_capacity(row_count);
198
199				for row_idx in 0..row_count {
200					if let Some(&value) = container.get(row_idx) {
201						result.push(value);
202						bitvec.push(true);
203					} else {
204						result.push(0);
205						bitvec.push(false);
206					}
207				}
208
209				Ok(ColumnData::int16_with_bitvec(result, bitvec))
210			}
211			ColumnData::Uint1(container) => {
212				let mut result = Vec::with_capacity(row_count);
213				let mut bitvec = Vec::with_capacity(row_count);
214
215				for row_idx in 0..row_count {
216					if let Some(&value) = container.get(row_idx) {
217						result.push(value);
218						bitvec.push(true);
219					} else {
220						result.push(0);
221						bitvec.push(false);
222					}
223				}
224
225				Ok(ColumnData::uint1_with_bitvec(result, bitvec))
226			}
227			ColumnData::Uint2(container) => {
228				let mut result = Vec::with_capacity(row_count);
229				let mut bitvec = Vec::with_capacity(row_count);
230
231				for row_idx in 0..row_count {
232					if let Some(&value) = container.get(row_idx) {
233						result.push(value);
234						bitvec.push(true);
235					} else {
236						result.push(0);
237						bitvec.push(false);
238					}
239				}
240
241				Ok(ColumnData::uint2_with_bitvec(result, bitvec))
242			}
243			ColumnData::Uint4(container) => {
244				let mut result = Vec::with_capacity(row_count);
245				let mut bitvec = Vec::with_capacity(row_count);
246
247				for row_idx in 0..row_count {
248					if let Some(&value) = container.get(row_idx) {
249						result.push(value);
250						bitvec.push(true);
251					} else {
252						result.push(0);
253						bitvec.push(false);
254					}
255				}
256
257				Ok(ColumnData::uint4_with_bitvec(result, bitvec))
258			}
259			ColumnData::Uint8(container) => {
260				let mut result = Vec::with_capacity(row_count);
261				let mut bitvec = Vec::with_capacity(row_count);
262
263				for row_idx in 0..row_count {
264					if let Some(&value) = container.get(row_idx) {
265						result.push(value);
266						bitvec.push(true);
267					} else {
268						result.push(0);
269						bitvec.push(false);
270					}
271				}
272
273				Ok(ColumnData::uint8_with_bitvec(result, bitvec))
274			}
275			ColumnData::Uint16(container) => {
276				let mut result = Vec::with_capacity(row_count);
277				let mut bitvec = Vec::with_capacity(row_count);
278
279				for row_idx in 0..row_count {
280					if let Some(&value) = container.get(row_idx) {
281						result.push(value);
282						bitvec.push(true);
283					} else {
284						result.push(0);
285						bitvec.push(false);
286					}
287				}
288
289				Ok(ColumnData::uint16_with_bitvec(result, bitvec))
290			}
291			ColumnData::Int {
292				container,
293				max_bytes,
294			} => {
295				let mut result = Vec::with_capacity(row_count);
296				let mut bitvec = Vec::with_capacity(row_count);
297
298				for row_idx in 0..row_count {
299					if let Some(value) = container.get(row_idx) {
300						result.push(value.clone());
301						bitvec.push(true);
302					} else {
303						result.push(Int::default());
304						bitvec.push(false);
305					}
306				}
307
308				Ok(ColumnData::Int {
309					container: NumberContainer::new(result),
310					max_bytes: *max_bytes,
311				})
312			}
313			ColumnData::Uint {
314				container,
315				max_bytes,
316			} => {
317				let mut result = Vec::with_capacity(row_count);
318				let mut bitvec = Vec::with_capacity(row_count);
319
320				for row_idx in 0..row_count {
321					if let Some(value) = container.get(row_idx) {
322						result.push(value.clone());
323						bitvec.push(true);
324					} else {
325						result.push(Uint::default());
326						bitvec.push(false);
327					}
328				}
329
330				Ok(ColumnData::Uint {
331					container: NumberContainer::new(result),
332					max_bytes: *max_bytes,
333				})
334			}
335			ColumnData::Decimal {
336				container,
337				precision,
338				scale,
339			} => {
340				let mut result = Vec::with_capacity(row_count);
341				let mut bitvec = Vec::with_capacity(row_count);
342
343				for row_idx in 0..row_count {
344					if let Some(value) = container.get(row_idx) {
345						let prec = get_precision(row_idx);
346						let f_val = value.0.to_f64().unwrap_or(0.0);
347						let multiplier = 10_f64.powi(prec);
348						let rounded = (f_val * multiplier).round() / multiplier;
349						result.push(Decimal::from(rounded));
350						bitvec.push(true);
351					} else {
352						result.push(Decimal::default());
353						bitvec.push(false);
354					}
355				}
356
357				Ok(ColumnData::Decimal {
358					container: NumberContainer::new(result),
359					precision: *precision,
360					scale: *scale,
361				})
362			}
363			other => Err(ScalarFunctionError::InvalidArgumentType {
364				function: ctx.fragment.clone(),
365				argument_index: 0,
366				expected: vec![
367					Type::Int1,
368					Type::Int2,
369					Type::Int4,
370					Type::Int8,
371					Type::Int16,
372					Type::Uint1,
373					Type::Uint2,
374					Type::Uint4,
375					Type::Uint8,
376					Type::Uint16,
377					Type::Float4,
378					Type::Float8,
379					Type::Int,
380					Type::Uint,
381					Type::Decimal,
382				],
383				actual: other.get_type(),
384			}),
385		}
386	}
387
388	fn return_type(&self, input_types: &[Type]) -> Type {
389		input_types[0].clone()
390	}
391}