Skip to main content

reifydb_routine/function/math/
round.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use num_traits::ToPrimitive;
5use reifydb_core::value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns};
6use reifydb_type::value::{
7	container::number::NumberContainer,
8	decimal::Decimal,
9	r#type::{Type, input_types::InputTypes},
10};
11
12use crate::routine::{Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError};
13
14pub struct Round {
15	info: RoutineInfo,
16}
17
18impl Default for Round {
19	fn default() -> Self {
20		Self::new()
21	}
22}
23
24impl Round {
25	pub fn new() -> Self {
26		Self {
27			info: RoutineInfo::new("math::round"),
28		}
29	}
30}
31
32impl<'a> Routine<FunctionContext<'a>> for Round {
33	fn info(&self) -> &RoutineInfo {
34		&self.info
35	}
36
37	fn return_type(&self, input_types: &[Type]) -> Type {
38		input_types.first().cloned().unwrap_or(Type::Float8)
39	}
40
41	fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
42		if args.is_empty() {
43			return Err(RoutineError::FunctionArityMismatch {
44				function: ctx.fragment.clone(),
45				expected: 1,
46				actual: 0,
47			});
48		}
49
50		let value_column = &args[0];
51		let precision_column = args.get(1);
52
53		let (val_data, val_bitvec) = value_column.unwrap_option();
54		let row_count = val_data.len();
55
56		let get_precision = |row_idx: usize| -> i32 {
57			if let Some(prec_col) = precision_column {
58				let (p_data, _) = prec_col.data().unwrap_option();
59				match p_data {
60					ColumnBuffer::Int4(prec_container) => {
61						prec_container.get(row_idx).copied().unwrap_or(0)
62					}
63					ColumnBuffer::Int1(prec_container) => {
64						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
65					}
66					ColumnBuffer::Int2(prec_container) => {
67						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
68					}
69					ColumnBuffer::Int8(prec_container) => {
70						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
71					}
72					ColumnBuffer::Int16(prec_container) => {
73						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
74					}
75					ColumnBuffer::Uint1(prec_container) => {
76						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
77					}
78					ColumnBuffer::Uint2(prec_container) => {
79						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
80					}
81					ColumnBuffer::Uint4(prec_container) => {
82						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
83					}
84					ColumnBuffer::Uint8(prec_container) => {
85						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
86					}
87					ColumnBuffer::Uint16(prec_container) => {
88						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
89					}
90					_ => 0,
91				}
92			} else {
93				0
94			}
95		};
96
97		let result_data = match val_data {
98			ColumnBuffer::Float4(container) => {
99				let mut result = Vec::with_capacity(row_count);
100				let mut bitvec = Vec::with_capacity(row_count);
101				for i in 0..row_count {
102					if let Some(&value) = container.get(i) {
103						let precision = get_precision(i);
104						let multiplier = 10_f32.powi(precision);
105						let rounded = (value * multiplier).round() / multiplier;
106						result.push(rounded);
107						bitvec.push(true);
108					} else {
109						result.push(0.0);
110						bitvec.push(false);
111					}
112				}
113				ColumnBuffer::float4_with_bitvec(result, bitvec)
114			}
115			ColumnBuffer::Float8(container) => {
116				let mut result = Vec::with_capacity(row_count);
117				let mut bitvec = Vec::with_capacity(row_count);
118				for i in 0..row_count {
119					if let Some(&value) = container.get(i) {
120						let precision = get_precision(i);
121						let multiplier = 10_f64.powi(precision);
122						let rounded = (value * multiplier).round() / multiplier;
123						result.push(rounded);
124						bitvec.push(true);
125					} else {
126						result.push(0.0);
127						bitvec.push(false);
128					}
129				}
130				ColumnBuffer::float8_with_bitvec(result, bitvec)
131			}
132			ColumnBuffer::Decimal {
133				container,
134				precision,
135				scale,
136			} => {
137				let mut result = Vec::with_capacity(row_count);
138				let mut bitvec = Vec::with_capacity(row_count);
139				for i in 0..row_count {
140					if let Some(value) = container.get(i) {
141						let prec = get_precision(i);
142						let f_val = value.0.to_f64().unwrap_or(0.0);
143						let multiplier = 10_f64.powi(prec);
144						let rounded = (f_val * multiplier).round() / multiplier;
145						result.push(Decimal::from(rounded));
146						bitvec.push(true);
147					} else {
148						result.push(Decimal::default());
149						bitvec.push(false);
150					}
151				}
152				ColumnBuffer::Decimal {
153					container: NumberContainer::new(result),
154					precision: *precision,
155					scale: *scale,
156				}
157			}
158			other if other.get_type().is_number() => val_data.clone(),
159			other => {
160				return Err(RoutineError::FunctionInvalidArgumentType {
161					function: ctx.fragment.clone(),
162					argument_index: 0,
163					expected: InputTypes::numeric().expected_at(0).to_vec(),
164					actual: other.get_type(),
165				});
166			}
167		};
168
169		let final_data = if let Some(bv) = val_bitvec {
170			ColumnBuffer::Option {
171				inner: Box::new(result_data),
172				bitvec: bv.clone(),
173			}
174		} else {
175			result_data
176		};
177
178		Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), final_data)]))
179	}
180}
181
182impl Function for Round {
183	fn kinds(&self) -> &[FunctionKind] {
184		&[FunctionKind::Scalar]
185	}
186}