reifydb_routine/function/math/
round.rs1use num_traits::ToPrimitive;
5use reifydb_core::value::column::{Column, columns::Columns, data::ColumnData};
6use reifydb_type::value::{
7 container::number::NumberContainer,
8 decimal::Decimal,
9 r#type::{Type, input_types::InputTypes},
10};
11
12use crate::function::{Function, FunctionCapability, FunctionContext, FunctionInfo, error::FunctionError};
13
14pub struct Round {
15 info: FunctionInfo,
16}
17
18impl Default for Round {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl Round {
25 pub fn new() -> Self {
26 Self {
27 info: FunctionInfo::new("math::round"),
28 }
29 }
30}
31
32impl Function for Round {
33 fn info(&self) -> &FunctionInfo {
34 &self.info
35 }
36
37 fn capabilities(&self) -> &[FunctionCapability] {
38 &[FunctionCapability::Scalar]
39 }
40
41 fn return_type(&self, input_types: &[Type]) -> Type {
42 input_types.first().cloned().unwrap_or(Type::Float8)
43 }
44
45 fn execute(&self, ctx: &FunctionContext, args: &Columns) -> Result<Columns, FunctionError> {
46 if args.is_empty() {
47 return Err(FunctionError::ArityMismatch {
48 function: ctx.fragment.clone(),
49 expected: 1,
50 actual: 0,
51 });
52 }
53
54 let value_column = &args[0];
55 let precision_column = args.get(1);
56
57 let (val_data, val_bitvec) = value_column.data().unwrap_option();
58 let row_count = val_data.len();
59
60 let get_precision = |row_idx: usize| -> i32 {
62 if let Some(prec_col) = precision_column {
63 let (p_data, _) = prec_col.data().unwrap_option();
64 match p_data {
65 ColumnData::Int4(prec_container) => {
66 prec_container.get(row_idx).copied().unwrap_or(0)
67 }
68 ColumnData::Int1(prec_container) => {
69 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
70 }
71 ColumnData::Int2(prec_container) => {
72 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
73 }
74 ColumnData::Int8(prec_container) => {
75 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
76 }
77 ColumnData::Int16(prec_container) => {
78 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
79 }
80 ColumnData::Uint1(prec_container) => {
81 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
82 }
83 ColumnData::Uint2(prec_container) => {
84 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
85 }
86 ColumnData::Uint4(prec_container) => {
87 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
88 }
89 ColumnData::Uint8(prec_container) => {
90 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
91 }
92 ColumnData::Uint16(prec_container) => {
93 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
94 }
95 _ => 0,
96 }
97 } else {
98 0
99 }
100 };
101
102 let result_data = match val_data {
103 ColumnData::Float4(container) => {
104 let mut result = Vec::with_capacity(row_count);
105 let mut bitvec = Vec::with_capacity(row_count);
106 for i in 0..row_count {
107 if let Some(&value) = container.get(i) {
108 let precision = get_precision(i);
109 let multiplier = 10_f32.powi(precision);
110 let rounded = (value * multiplier).round() / multiplier;
111 result.push(rounded);
112 bitvec.push(true);
113 } else {
114 result.push(0.0);
115 bitvec.push(false);
116 }
117 }
118 ColumnData::float4_with_bitvec(result, bitvec)
119 }
120 ColumnData::Float8(container) => {
121 let mut result = Vec::with_capacity(row_count);
122 let mut bitvec = Vec::with_capacity(row_count);
123 for i in 0..row_count {
124 if let Some(&value) = container.get(i) {
125 let precision = get_precision(i);
126 let multiplier = 10_f64.powi(precision);
127 let rounded = (value * multiplier).round() / multiplier;
128 result.push(rounded);
129 bitvec.push(true);
130 } else {
131 result.push(0.0);
132 bitvec.push(false);
133 }
134 }
135 ColumnData::float8_with_bitvec(result, bitvec)
136 }
137 ColumnData::Decimal {
138 container,
139 precision,
140 scale,
141 } => {
142 let mut result = Vec::with_capacity(row_count);
143 let mut bitvec = Vec::with_capacity(row_count);
144 for i in 0..row_count {
145 if let Some(value) = container.get(i) {
146 let prec = get_precision(i);
147 let f_val = value.0.to_f64().unwrap_or(0.0);
148 let multiplier = 10_f64.powi(prec);
149 let rounded = (f_val * multiplier).round() / multiplier;
150 result.push(Decimal::from(rounded));
151 bitvec.push(true);
152 } else {
153 result.push(Decimal::default());
154 bitvec.push(false);
155 }
156 }
157 ColumnData::Decimal {
158 container: NumberContainer::new(result),
159 precision: *precision,
160 scale: *scale,
161 }
162 }
163 other if other.get_type().is_number() => val_data.clone(),
164 other => {
165 return Err(FunctionError::InvalidArgumentType {
166 function: ctx.fragment.clone(),
167 argument_index: 0,
168 expected: InputTypes::numeric().expected_at(0).to_vec(),
169 actual: other.get_type(),
170 });
171 }
172 };
173
174 let final_data = if let Some(bv) = val_bitvec {
175 ColumnData::Option {
176 inner: Box::new(result_data),
177 bitvec: bv.clone(),
178 }
179 } else {
180 result_data
181 };
182
183 Ok(Columns::new(vec![Column::new(ctx.fragment.clone(), final_data)]))
184 }
185}