reifydb_function/math/scalar/
round.rs1use num_traits::ToPrimitive;
5use reifydb_core::value::column::data::ColumnData;
6use reifydb_type::value::{container::number::NumberContainer, decimal::Decimal, int::Int, r#type::Type, uint::Uint};
7
8use crate::{
9 ScalarFunction, ScalarFunctionContext,
10 error::{ScalarFunctionError, ScalarFunctionResult},
11 propagate_options,
12};
13
/// The `round` scalar function: rounds a numeric value to a given number of decimal places.
///
/// The type carries no state; `Default` is derived rather than hand-written.
#[derive(Debug, Default, Clone, Copy)]
pub struct Round;

impl Round {
	/// Creates the function instance; equivalent to `Round::default()`.
	pub const fn new() -> Self {
		Self
	}
}
27
28impl ScalarFunction for Round {
29 fn scalar(&self, ctx: ScalarFunctionContext) -> ScalarFunctionResult<ColumnData> {
30 if let Some(result) = propagate_options(self, &ctx) {
31 return result;
32 }
33 let columns = ctx.columns;
34 let row_count = ctx.row_count;
35
36 if columns.is_empty() {
38 return Err(ScalarFunctionError::ArityMismatch {
39 function: ctx.fragment.clone(),
40 expected: 1,
41 actual: 0,
42 });
43 }
44
45 let value_column = columns.first().unwrap();
46
47 let precision_column = columns.get(1);
49
50 let get_precision = |row_idx: usize| -> i32 {
52 if let Some(prec_col) = precision_column {
53 match prec_col.data() {
54 ColumnData::Int4(prec_container) => {
55 prec_container.get(row_idx).copied().unwrap_or(0)
56 }
57 ColumnData::Int1(prec_container) => {
58 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
59 }
60 ColumnData::Int2(prec_container) => {
61 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
62 }
63 ColumnData::Int8(prec_container) => {
64 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
65 }
66 ColumnData::Int16(prec_container) => {
67 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
68 }
69 ColumnData::Uint1(prec_container) => {
70 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
71 }
72 ColumnData::Uint2(prec_container) => {
73 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
74 }
75 ColumnData::Uint4(prec_container) => {
76 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
77 }
78 ColumnData::Uint8(prec_container) => {
79 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
80 }
81 ColumnData::Uint16(prec_container) => {
82 prec_container.get(row_idx).map(|&v| v as i32).unwrap_or(0)
83 }
84 _ => 0,
85 }
86 } else {
87 0
88 }
89 };
90
91 match value_column.data() {
92 ColumnData::Float4(container) => {
93 let mut result = Vec::with_capacity(row_count);
94 let mut bitvec = Vec::with_capacity(row_count);
95
96 for row_idx in 0..row_count {
97 if let Some(&value) = container.get(row_idx) {
98 let precision = get_precision(row_idx);
99 let multiplier = 10_f32.powi(precision);
100 let rounded = (value * multiplier).round() / multiplier;
101 result.push(rounded);
102 bitvec.push(true);
103 } else {
104 result.push(0.0);
105 bitvec.push(false);
106 }
107 }
108
109 Ok(ColumnData::float4_with_bitvec(result, bitvec))
110 }
111 ColumnData::Float8(container) => {
112 let mut result = Vec::with_capacity(row_count);
113 let mut bitvec = Vec::with_capacity(row_count);
114
115 for row_idx in 0..row_count {
116 if let Some(&value) = container.get(row_idx) {
117 let precision = get_precision(row_idx);
118 let multiplier = 10_f64.powi(precision);
119 let rounded = (value * multiplier).round() / multiplier;
120 result.push(rounded);
121 bitvec.push(true);
122 } else {
123 result.push(0.0);
124 bitvec.push(false);
125 }
126 }
127
128 Ok(ColumnData::float8_with_bitvec(result, bitvec))
129 }
130 ColumnData::Int1(container) => {
132 let mut result = Vec::with_capacity(row_count);
133 let mut bitvec = Vec::with_capacity(row_count);
134
135 for row_idx in 0..row_count {
136 if let Some(&value) = container.get(row_idx) {
137 result.push(value);
138 bitvec.push(true);
139 } else {
140 result.push(0);
141 bitvec.push(false);
142 }
143 }
144
145 Ok(ColumnData::int1_with_bitvec(result, bitvec))
146 }
147 ColumnData::Int2(container) => {
148 let mut result = Vec::with_capacity(row_count);
149 let mut bitvec = Vec::with_capacity(row_count);
150
151 for row_idx in 0..row_count {
152 if let Some(&value) = container.get(row_idx) {
153 result.push(value);
154 bitvec.push(true);
155 } else {
156 result.push(0);
157 bitvec.push(false);
158 }
159 }
160
161 Ok(ColumnData::int2_with_bitvec(result, bitvec))
162 }
163 ColumnData::Int4(container) => {
164 let mut result = Vec::with_capacity(row_count);
165 let mut bitvec = Vec::with_capacity(row_count);
166
167 for row_idx in 0..row_count {
168 if let Some(&value) = container.get(row_idx) {
169 result.push(value);
170 bitvec.push(true);
171 } else {
172 result.push(0);
173 bitvec.push(false);
174 }
175 }
176
177 Ok(ColumnData::int4_with_bitvec(result, bitvec))
178 }
179 ColumnData::Int8(container) => {
180 let mut result = Vec::with_capacity(row_count);
181 let mut bitvec = Vec::with_capacity(row_count);
182
183 for row_idx in 0..row_count {
184 if let Some(&value) = container.get(row_idx) {
185 result.push(value);
186 bitvec.push(true);
187 } else {
188 result.push(0);
189 bitvec.push(false);
190 }
191 }
192
193 Ok(ColumnData::int8_with_bitvec(result, bitvec))
194 }
195 ColumnData::Int16(container) => {
196 let mut result = Vec::with_capacity(row_count);
197 let mut bitvec = Vec::with_capacity(row_count);
198
199 for row_idx in 0..row_count {
200 if let Some(&value) = container.get(row_idx) {
201 result.push(value);
202 bitvec.push(true);
203 } else {
204 result.push(0);
205 bitvec.push(false);
206 }
207 }
208
209 Ok(ColumnData::int16_with_bitvec(result, bitvec))
210 }
211 ColumnData::Uint1(container) => {
212 let mut result = Vec::with_capacity(row_count);
213 let mut bitvec = Vec::with_capacity(row_count);
214
215 for row_idx in 0..row_count {
216 if let Some(&value) = container.get(row_idx) {
217 result.push(value);
218 bitvec.push(true);
219 } else {
220 result.push(0);
221 bitvec.push(false);
222 }
223 }
224
225 Ok(ColumnData::uint1_with_bitvec(result, bitvec))
226 }
227 ColumnData::Uint2(container) => {
228 let mut result = Vec::with_capacity(row_count);
229 let mut bitvec = Vec::with_capacity(row_count);
230
231 for row_idx in 0..row_count {
232 if let Some(&value) = container.get(row_idx) {
233 result.push(value);
234 bitvec.push(true);
235 } else {
236 result.push(0);
237 bitvec.push(false);
238 }
239 }
240
241 Ok(ColumnData::uint2_with_bitvec(result, bitvec))
242 }
243 ColumnData::Uint4(container) => {
244 let mut result = Vec::with_capacity(row_count);
245 let mut bitvec = Vec::with_capacity(row_count);
246
247 for row_idx in 0..row_count {
248 if let Some(&value) = container.get(row_idx) {
249 result.push(value);
250 bitvec.push(true);
251 } else {
252 result.push(0);
253 bitvec.push(false);
254 }
255 }
256
257 Ok(ColumnData::uint4_with_bitvec(result, bitvec))
258 }
259 ColumnData::Uint8(container) => {
260 let mut result = Vec::with_capacity(row_count);
261 let mut bitvec = Vec::with_capacity(row_count);
262
263 for row_idx in 0..row_count {
264 if let Some(&value) = container.get(row_idx) {
265 result.push(value);
266 bitvec.push(true);
267 } else {
268 result.push(0);
269 bitvec.push(false);
270 }
271 }
272
273 Ok(ColumnData::uint8_with_bitvec(result, bitvec))
274 }
275 ColumnData::Uint16(container) => {
276 let mut result = Vec::with_capacity(row_count);
277 let mut bitvec = Vec::with_capacity(row_count);
278
279 for row_idx in 0..row_count {
280 if let Some(&value) = container.get(row_idx) {
281 result.push(value);
282 bitvec.push(true);
283 } else {
284 result.push(0);
285 bitvec.push(false);
286 }
287 }
288
289 Ok(ColumnData::uint16_with_bitvec(result, bitvec))
290 }
291 ColumnData::Int {
292 container,
293 max_bytes,
294 } => {
295 let mut result = Vec::with_capacity(row_count);
296 let mut bitvec = Vec::with_capacity(row_count);
297
298 for row_idx in 0..row_count {
299 if let Some(value) = container.get(row_idx) {
300 result.push(value.clone());
301 bitvec.push(true);
302 } else {
303 result.push(Int::default());
304 bitvec.push(false);
305 }
306 }
307
308 Ok(ColumnData::Int {
309 container: NumberContainer::new(result),
310 max_bytes: *max_bytes,
311 })
312 }
313 ColumnData::Uint {
314 container,
315 max_bytes,
316 } => {
317 let mut result = Vec::with_capacity(row_count);
318 let mut bitvec = Vec::with_capacity(row_count);
319
320 for row_idx in 0..row_count {
321 if let Some(value) = container.get(row_idx) {
322 result.push(value.clone());
323 bitvec.push(true);
324 } else {
325 result.push(Uint::default());
326 bitvec.push(false);
327 }
328 }
329
330 Ok(ColumnData::Uint {
331 container: NumberContainer::new(result),
332 max_bytes: *max_bytes,
333 })
334 }
335 ColumnData::Decimal {
336 container,
337 precision,
338 scale,
339 } => {
340 let mut result = Vec::with_capacity(row_count);
341 let mut bitvec = Vec::with_capacity(row_count);
342
343 for row_idx in 0..row_count {
344 if let Some(value) = container.get(row_idx) {
345 let prec = get_precision(row_idx);
346 let f_val = value.0.to_f64().unwrap_or(0.0);
347 let multiplier = 10_f64.powi(prec);
348 let rounded = (f_val * multiplier).round() / multiplier;
349 result.push(Decimal::from(rounded));
350 bitvec.push(true);
351 } else {
352 result.push(Decimal::default());
353 bitvec.push(false);
354 }
355 }
356
357 Ok(ColumnData::Decimal {
358 container: NumberContainer::new(result),
359 precision: *precision,
360 scale: *scale,
361 })
362 }
363 other => Err(ScalarFunctionError::InvalidArgumentType {
364 function: ctx.fragment.clone(),
365 argument_index: 0,
366 expected: vec![
367 Type::Int1,
368 Type::Int2,
369 Type::Int4,
370 Type::Int8,
371 Type::Int16,
372 Type::Uint1,
373 Type::Uint2,
374 Type::Uint4,
375 Type::Uint8,
376 Type::Uint16,
377 Type::Float4,
378 Type::Float8,
379 Type::Int,
380 Type::Uint,
381 Type::Decimal,
382 ],
383 actual: other.get_type(),
384 }),
385 }
386 }
387
388 fn return_type(&self, input_types: &[Type]) -> Type {
389 input_types[0].clone()
390 }
391}