reifydb_routine/function/math/
round.rs1use num_traits::ToPrimitive;
5use reifydb_core::value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns};
6use reifydb_type::value::{
7 container::number::NumberContainer,
8 decimal::Decimal,
9 r#type::{Type, input_types::InputTypes},
10};
11
12use crate::routine::{Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError};
13
14pub struct Round {
15 info: RoutineInfo,
16}
17
18impl Default for Round {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl Round {
25 pub fn new() -> Self {
26 Self {
27 info: RoutineInfo::new("math::round"),
28 }
29 }
30}
31
32impl<'a> Routine<FunctionContext<'a>> for Round {
33 fn info(&self) -> &RoutineInfo {
34 &self.info
35 }
36
37 fn return_type(&self, input_types: &[Type]) -> Type {
38 input_types.first().cloned().unwrap_or(Type::Float8)
39 }
40
41 fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
42 if args.is_empty() {
43 return Err(RoutineError::FunctionArityMismatch {
44 function: ctx.fragment.clone(),
45 expected: 1,
46 actual: 0,
47 });
48 }
49
50 let value_column = &args[0];
51 let precision_column = args.get(1);
52
53 let (val_data, val_bitvec) = value_column.unwrap_option();
54 let row_count = val_data.len();
55
56 let get_precision = |row_idx: usize| -> i32 {
57 if let Some(prec_col) = precision_column {
58 let (p_data, _) = prec_col.data().unwrap_option();
59 match p_data {
60 ColumnBuffer::Int4(prec_container) => {
61 prec_container.get(row_idx).copied().unwrap_or(0)
62 }
63 ColumnBuffer::Int1(prec_container) => {
64 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
65 }
66 ColumnBuffer::Int2(prec_container) => {
67 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
68 }
69 ColumnBuffer::Int8(prec_container) => {
70 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
71 }
72 ColumnBuffer::Int16(prec_container) => {
73 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
74 }
75 ColumnBuffer::Uint1(prec_container) => {
76 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
77 }
78 ColumnBuffer::Uint2(prec_container) => {
79 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
80 }
81 ColumnBuffer::Uint4(prec_container) => {
82 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
83 }
84 ColumnBuffer::Uint8(prec_container) => {
85 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
86 }
87 ColumnBuffer::Uint16(prec_container) => {
88 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
89 }
90 _ => 0,
91 }
92 } else {
93 0
94 }
95 };
96
97 let result_data = match val_data {
98 ColumnBuffer::Float4(container) => {
99 let mut result = Vec::with_capacity(row_count);
100 let mut bitvec = Vec::with_capacity(row_count);
101 for i in 0..row_count {
102 if let Some(&value) = container.get(i) {
103 let precision = get_precision(i);
104 let multiplier = 10_f32.powi(precision);
105 let rounded = (value * multiplier).round() / multiplier;
106 result.push(rounded);
107 bitvec.push(true);
108 } else {
109 result.push(0.0);
110 bitvec.push(false);
111 }
112 }
113 ColumnBuffer::float4_with_bitvec(result, bitvec)
114 }
115 ColumnBuffer::Float8(container) => {
116 let mut result = Vec::with_capacity(row_count);
117 let mut bitvec = Vec::with_capacity(row_count);
118 for i in 0..row_count {
119 if let Some(&value) = container.get(i) {
120 let precision = get_precision(i);
121 let multiplier = 10_f64.powi(precision);
122 let rounded = (value * multiplier).round() / multiplier;
123 result.push(rounded);
124 bitvec.push(true);
125 } else {
126 result.push(0.0);
127 bitvec.push(false);
128 }
129 }
130 ColumnBuffer::float8_with_bitvec(result, bitvec)
131 }
132 ColumnBuffer::Decimal {
133 container,
134 precision,
135 scale,
136 } => {
137 let mut result = Vec::with_capacity(row_count);
138 let mut bitvec = Vec::with_capacity(row_count);
139 for i in 0..row_count {
140 if let Some(value) = container.get(i) {
141 let prec = get_precision(i);
142 let f_val = value.0.to_f64().unwrap_or(0.0);
143 let multiplier = 10_f64.powi(prec);
144 let rounded = (f_val * multiplier).round() / multiplier;
145 result.push(Decimal::from(rounded));
146 bitvec.push(true);
147 } else {
148 result.push(Decimal::default());
149 bitvec.push(false);
150 }
151 }
152 ColumnBuffer::Decimal {
153 container: NumberContainer::new(result),
154 precision: *precision,
155 scale: *scale,
156 }
157 }
158 other if other.get_type().is_number() => val_data.clone(),
159 other => {
160 return Err(RoutineError::FunctionInvalidArgumentType {
161 function: ctx.fragment.clone(),
162 argument_index: 0,
163 expected: InputTypes::numeric().expected_at(0).to_vec(),
164 actual: other.get_type(),
165 });
166 }
167 };
168
169 let final_data = if let Some(bv) = val_bitvec {
170 ColumnBuffer::Option {
171 inner: Box::new(result_data),
172 bitvec: bv.clone(),
173 }
174 } else {
175 result_data
176 };
177
178 Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), final_data)]))
179 }
180}
181
182impl Function for Round {
183 fn kinds(&self) -> &[FunctionKind] {
184 &[FunctionKind::Scalar]
185 }
186}