// reifydb_routine/function/math/round.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2025 ReifyDB
3
4use num_traits::ToPrimitive;
5use reifydb_core::value::column::{Column, columns::Columns, data::ColumnData};
6use reifydb_type::value::{
7	container::number::NumberContainer,
8	decimal::Decimal,
9	r#type::{Type, input_types::InputTypes},
10};
11
12use crate::function::{Function, FunctionCapability, FunctionContext, FunctionInfo, error::FunctionError};
13
14pub struct Round {
15	info: FunctionInfo,
16}
17
18impl Default for Round {
19	fn default() -> Self {
20		Self::new()
21	}
22}
23
24impl Round {
25	pub fn new() -> Self {
26		Self {
27			info: FunctionInfo::new("math::round"),
28		}
29	}
30}
31
32impl Function for Round {
33	fn info(&self) -> &FunctionInfo {
34		&self.info
35	}
36
37	fn capabilities(&self) -> &[FunctionCapability] {
38		&[FunctionCapability::Scalar]
39	}
40
41	fn return_type(&self, input_types: &[Type]) -> Type {
42		input_types.first().cloned().unwrap_or(Type::Float8)
43	}
44
45	fn execute(&self, ctx: &FunctionContext, args: &Columns) -> Result<Columns, FunctionError> {
46		if args.is_empty() {
47			return Err(FunctionError::ArityMismatch {
48				function: ctx.fragment.clone(),
49				expected: 1,
50				actual: 0,
51			});
52		}
53
54		let value_column = &args[0];
55		let precision_column = args.get(1);
56
57		let (val_data, val_bitvec) = value_column.data().unwrap_option();
58		let row_count = val_data.len();
59
60		// Helper to get precision value at row index
61		let get_precision = |row_idx: usize| -> i32 {
62			if let Some(prec_col) = precision_column {
63				let (p_data, _) = prec_col.data().unwrap_option();
64				match p_data {
65					ColumnData::Int4(prec_container) => {
66						prec_container.get(row_idx).copied().unwrap_or(0)
67					}
68					ColumnData::Int1(prec_container) => {
69						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
70					}
71					ColumnData::Int2(prec_container) => {
72						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
73					}
74					ColumnData::Int8(prec_container) => {
75						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
76					}
77					ColumnData::Int16(prec_container) => {
78						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
79					}
80					ColumnData::Uint1(prec_container) => {
81						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
82					}
83					ColumnData::Uint2(prec_container) => {
84						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
85					}
86					ColumnData::Uint4(prec_container) => {
87						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
88					}
89					ColumnData::Uint8(prec_container) => {
90						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
91					}
92					ColumnData::Uint16(prec_container) => {
93						prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
94					}
95					_ => 0,
96				}
97			} else {
98				0
99			}
100		};
101
102		let result_data = match val_data {
103			ColumnData::Float4(container) => {
104				let mut result = Vec::with_capacity(row_count);
105				let mut bitvec = Vec::with_capacity(row_count);
106				for i in 0..row_count {
107					if let Some(&value) = container.get(i) {
108						let precision = get_precision(i);
109						let multiplier = 10_f32.powi(precision);
110						let rounded = (value * multiplier).round() / multiplier;
111						result.push(rounded);
112						bitvec.push(true);
113					} else {
114						result.push(0.0);
115						bitvec.push(false);
116					}
117				}
118				ColumnData::float4_with_bitvec(result, bitvec)
119			}
120			ColumnData::Float8(container) => {
121				let mut result = Vec::with_capacity(row_count);
122				let mut bitvec = Vec::with_capacity(row_count);
123				for i in 0..row_count {
124					if let Some(&value) = container.get(i) {
125						let precision = get_precision(i);
126						let multiplier = 10_f64.powi(precision);
127						let rounded = (value * multiplier).round() / multiplier;
128						result.push(rounded);
129						bitvec.push(true);
130					} else {
131						result.push(0.0);
132						bitvec.push(false);
133					}
134				}
135				ColumnData::float8_with_bitvec(result, bitvec)
136			}
137			ColumnData::Decimal {
138				container,
139				precision,
140				scale,
141			} => {
142				let mut result = Vec::with_capacity(row_count);
143				let mut bitvec = Vec::with_capacity(row_count);
144				for i in 0..row_count {
145					if let Some(value) = container.get(i) {
146						let prec = get_precision(i);
147						let f_val = value.0.to_f64().unwrap_or(0.0);
148						let multiplier = 10_f64.powi(prec);
149						let rounded = (f_val * multiplier).round() / multiplier;
150						result.push(Decimal::from(rounded));
151						bitvec.push(true);
152					} else {
153						result.push(Decimal::default());
154						bitvec.push(false);
155					}
156				}
157				ColumnData::Decimal {
158					container: NumberContainer::new(result),
159					precision: *precision,
160					scale: *scale,
161				}
162			}
163			other if other.get_type().is_number() => val_data.clone(),
164			other => {
165				return Err(FunctionError::InvalidArgumentType {
166					function: ctx.fragment.clone(),
167					argument_index: 0,
168					expected: InputTypes::numeric().expected_at(0).to_vec(),
169					actual: other.get_type(),
170				});
171			}
172		};
173
174		let final_data = if let Some(bv) = val_bitvec {
175			ColumnData::Option {
176				inner: Box::new(result_data),
177				bitvec: bv.clone(),
178			}
179		} else {
180			result_data
181		};
182
183		Ok(Columns::new(vec![Column::new(ctx.fragment.clone(), final_data)]))
184	}
185}