Skip to main content

reifydb_routine/function/math/
round.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use num_traits::ToPrimitive;
5use reifydb_core::value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns};
6use reifydb_type::value::{
7	container::number::NumberContainer,
8	decimal::Decimal,
9	r#type::{Type, input_types::InputTypes},
10};
11
12use crate::routine::{Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError};
13
14pub struct Round {
15	info: RoutineInfo,
16}
17
18impl Default for Round {
19	fn default() -> Self {
20		Self::new()
21	}
22}
23
24impl Round {
25	pub fn new() -> Self {
26		Self {
27			info: RoutineInfo::new("math::round"),
28		}
29	}
30}
31
32impl<'a> Routine<FunctionContext<'a>> for Round {
33	fn info(&self) -> &RoutineInfo {
34		&self.info
35	}
36
37	fn return_type(&self, input_types: &[Type]) -> Type {
38		input_types.first().cloned().unwrap_or(Type::Float8)
39	}
40
41	fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
42		if args.is_empty() {
43			return Err(RoutineError::FunctionArityMismatch {
44				function: ctx.fragment.clone(),
45				expected: 1,
46				actual: 0,
47			});
48		}
49
50		let value_column = &args[0];
51		let precision_column = args.get(1);
52
53		let (val_data, val_bitvec) = value_column.unwrap_option();
54		let row_count = val_data.len();
55
56		// Helper to get precision value at row index
57		let get_precision = |row_idx: usize| -> i32 {
58			if let Some(prec_col) = precision_column {
59				let (p_data, _) = prec_col.data().unwrap_option();
60				match p_data {
61					ColumnBuffer::Int4(prec_container) => {
62						prec_container.get(row_idx).copied().unwrap_or(0)
63					}
64					ColumnBuffer::Int1(prec_container) => {
65						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
66					}
67					ColumnBuffer::Int2(prec_container) => {
68						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
69					}
70					ColumnBuffer::Int8(prec_container) => {
71						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
72					}
73					ColumnBuffer::Int16(prec_container) => {
74						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
75					}
76					ColumnBuffer::Uint1(prec_container) => {
77						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
78					}
79					ColumnBuffer::Uint2(prec_container) => {
80						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
81					}
82					ColumnBuffer::Uint4(prec_container) => {
83						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
84					}
85					ColumnBuffer::Uint8(prec_container) => {
86						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
87					}
88					ColumnBuffer::Uint16(prec_container) => {
89						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
90					}
91					_ => 0,
92				}
93			} else {
94				0
95			}
96		};
97
98		let result_data = match val_data {
99			ColumnBuffer::Float4(container) => {
100				let mut result = Vec::with_capacity(row_count);
101				let mut bitvec = Vec::with_capacity(row_count);
102				for i in 0..row_count {
103					if let Some(&value) = container.get(i) {
104						let precision = get_precision(i);
105						let multiplier = 10_f32.powi(precision);
106						let rounded = (value * multiplier).round() / multiplier;
107						result.push(rounded);
108						bitvec.push(true);
109					} else {
110						result.push(0.0);
111						bitvec.push(false);
112					}
113				}
114				ColumnBuffer::float4_with_bitvec(result, bitvec)
115			}
116			ColumnBuffer::Float8(container) => {
117				let mut result = Vec::with_capacity(row_count);
118				let mut bitvec = Vec::with_capacity(row_count);
119				for i in 0..row_count {
120					if let Some(&value) = container.get(i) {
121						let precision = get_precision(i);
122						let multiplier = 10_f64.powi(precision);
123						let rounded = (value * multiplier).round() / multiplier;
124						result.push(rounded);
125						bitvec.push(true);
126					} else {
127						result.push(0.0);
128						bitvec.push(false);
129					}
130				}
131				ColumnBuffer::float8_with_bitvec(result, bitvec)
132			}
133			ColumnBuffer::Decimal {
134				container,
135				precision,
136				scale,
137			} => {
138				let mut result = Vec::with_capacity(row_count);
139				let mut bitvec = Vec::with_capacity(row_count);
140				for i in 0..row_count {
141					if let Some(value) = container.get(i) {
142						let prec = get_precision(i);
143						let f_val = value.0.to_f64().unwrap_or(0.0);
144						let multiplier = 10_f64.powi(prec);
145						let rounded = (f_val * multiplier).round() / multiplier;
146						result.push(Decimal::from(rounded));
147						bitvec.push(true);
148					} else {
149						result.push(Decimal::default());
150						bitvec.push(false);
151					}
152				}
153				ColumnBuffer::Decimal {
154					container: NumberContainer::new(result),
155					precision: *precision,
156					scale: *scale,
157				}
158			}
159			other if other.get_type().is_number() => val_data.clone(),
160			other => {
161				return Err(RoutineError::FunctionInvalidArgumentType {
162					function: ctx.fragment.clone(),
163					argument_index: 0,
164					expected: InputTypes::numeric().expected_at(0).to_vec(),
165					actual: other.get_type(),
166				});
167			}
168		};
169
170		let final_data = if let Some(bv) = val_bitvec {
171			ColumnBuffer::Option {
172				inner: Box::new(result_data),
173				bitvec: bv.clone(),
174			}
175		} else {
176			result_data
177		};
178
179		Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), final_data)]))
180	}
181}
182
183impl Function for Round {
184	fn kinds(&self) -> &[FunctionKind] {
185		&[FunctionKind::Scalar]
186	}
187}