reifydb_routine/function/math/
round.rs1use num_traits::ToPrimitive;
5use reifydb_core::value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns};
6use reifydb_type::value::{
7 container::number::NumberContainer,
8 decimal::Decimal,
9 r#type::{Type, input_types::InputTypes},
10};
11
12use crate::routine::{Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError};
13
14pub struct Round {
15 info: RoutineInfo,
16}
17
18impl Default for Round {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl Round {
25 pub fn new() -> Self {
26 Self {
27 info: RoutineInfo::new("math::round"),
28 }
29 }
30}
31
32impl<'a> Routine<FunctionContext<'a>> for Round {
33 fn info(&self) -> &RoutineInfo {
34 &self.info
35 }
36
37 fn return_type(&self, input_types: &[Type]) -> Type {
38 input_types.first().cloned().unwrap_or(Type::Float8)
39 }
40
41 fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
42 if args.is_empty() {
43 return Err(RoutineError::FunctionArityMismatch {
44 function: ctx.fragment.clone(),
45 expected: 1,
46 actual: 0,
47 });
48 }
49
50 let value_column = &args[0];
51 let precision_column = args.get(1);
52
53 let (val_data, val_bitvec) = value_column.unwrap_option();
54 let row_count = val_data.len();
55
56 let get_precision = |row_idx: usize| -> i32 {
58 if let Some(prec_col) = precision_column {
59 let (p_data, _) = prec_col.data().unwrap_option();
60 match p_data {
61 ColumnBuffer::Int4(prec_container) => {
62 prec_container.get(row_idx).copied().unwrap_or(0)
63 }
64 ColumnBuffer::Int1(prec_container) => {
65 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
66 }
67 ColumnBuffer::Int2(prec_container) => {
68 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
69 }
70 ColumnBuffer::Int8(prec_container) => {
71 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
72 }
73 ColumnBuffer::Int16(prec_container) => {
74 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
75 }
76 ColumnBuffer::Uint1(prec_container) => {
77 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
78 }
79 ColumnBuffer::Uint2(prec_container) => {
80 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
81 }
82 ColumnBuffer::Uint4(prec_container) => {
83 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
84 }
85 ColumnBuffer::Uint8(prec_container) => {
86 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
87 }
88 ColumnBuffer::Uint16(prec_container) => {
89 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
90 }
91 _ => 0,
92 }
93 } else {
94 0
95 }
96 };
97
98 let result_data = match val_data {
99 ColumnBuffer::Float4(container) => {
100 let mut result = Vec::with_capacity(row_count);
101 let mut bitvec = Vec::with_capacity(row_count);
102 for i in 0..row_count {
103 if let Some(&value) = container.get(i) {
104 let precision = get_precision(i);
105 let multiplier = 10_f32.powi(precision);
106 let rounded = (value * multiplier).round() / multiplier;
107 result.push(rounded);
108 bitvec.push(true);
109 } else {
110 result.push(0.0);
111 bitvec.push(false);
112 }
113 }
114 ColumnBuffer::float4_with_bitvec(result, bitvec)
115 }
116 ColumnBuffer::Float8(container) => {
117 let mut result = Vec::with_capacity(row_count);
118 let mut bitvec = Vec::with_capacity(row_count);
119 for i in 0..row_count {
120 if let Some(&value) = container.get(i) {
121 let precision = get_precision(i);
122 let multiplier = 10_f64.powi(precision);
123 let rounded = (value * multiplier).round() / multiplier;
124 result.push(rounded);
125 bitvec.push(true);
126 } else {
127 result.push(0.0);
128 bitvec.push(false);
129 }
130 }
131 ColumnBuffer::float8_with_bitvec(result, bitvec)
132 }
133 ColumnBuffer::Decimal {
134 container,
135 precision,
136 scale,
137 } => {
138 let mut result = Vec::with_capacity(row_count);
139 let mut bitvec = Vec::with_capacity(row_count);
140 for i in 0..row_count {
141 if let Some(value) = container.get(i) {
142 let prec = get_precision(i);
143 let f_val = value.0.to_f64().unwrap_or(0.0);
144 let multiplier = 10_f64.powi(prec);
145 let rounded = (f_val * multiplier).round() / multiplier;
146 result.push(Decimal::from(rounded));
147 bitvec.push(true);
148 } else {
149 result.push(Decimal::default());
150 bitvec.push(false);
151 }
152 }
153 ColumnBuffer::Decimal {
154 container: NumberContainer::new(result),
155 precision: *precision,
156 scale: *scale,
157 }
158 }
159 other if other.get_type().is_number() => val_data.clone(),
160 other => {
161 return Err(RoutineError::FunctionInvalidArgumentType {
162 function: ctx.fragment.clone(),
163 argument_index: 0,
164 expected: InputTypes::numeric().expected_at(0).to_vec(),
165 actual: other.get_type(),
166 });
167 }
168 };
169
170 let final_data = if let Some(bv) = val_bitvec {
171 ColumnBuffer::Option {
172 inner: Box::new(result_data),
173 bitvec: bv.clone(),
174 }
175 } else {
176 result_data
177 };
178
179 Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), final_data)]))
180 }
181}
182
183impl Function for Round {
184 fn kinds(&self) -> &[FunctionKind] {
185 &[FunctionKind::Scalar]
186 }
187}