reifydb_function/math/scalar/
truncate.rs1use num_traits::ToPrimitive;
5use reifydb_core::value::column::data::ColumnData;
6use reifydb_type::value::{container::number::NumberContainer, decimal::Decimal, r#type::Type};
7
8use crate::{ScalarFunction, ScalarFunctionContext, error::ScalarFunctionError, propagate_options};
9
/// Scalar function that truncates numeric values toward zero.
///
/// The type is a stateless unit struct; all behaviour lives in its
/// `ScalarFunction` implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct Truncate;

impl Truncate {
    /// Creates a new `Truncate` function instance.
    ///
    /// Equivalent to `Truncate::default()`.
    pub fn new() -> Self {
        Self
    }
}
17
18impl ScalarFunction for Truncate {
19 fn scalar(&self, ctx: ScalarFunctionContext) -> crate::error::ScalarFunctionResult<ColumnData> {
20 if let Some(result) = propagate_options(self, &ctx) {
21 return result;
22 }
23 let columns = ctx.columns;
24 let row_count = ctx.row_count;
25
26 if columns.len() != 1 {
27 return Err(ScalarFunctionError::ArityMismatch {
28 function: ctx.fragment.clone(),
29 expected: 1,
30 actual: columns.len(),
31 });
32 }
33
34 let column = columns.get(0).unwrap();
35
36 match column.data() {
37 ColumnData::Float4(container) => {
38 let mut data = Vec::with_capacity(row_count);
39 let mut bitvec = Vec::with_capacity(row_count);
40 for i in 0..row_count {
41 if let Some(&value) = container.get(i) {
42 data.push(value.trunc());
43 bitvec.push(true);
44 } else {
45 data.push(0.0);
46 bitvec.push(false);
47 }
48 }
49 Ok(ColumnData::float4_with_bitvec(data, bitvec))
50 }
51 ColumnData::Float8(container) => {
52 let mut data = Vec::with_capacity(row_count);
53 let mut bitvec = Vec::with_capacity(row_count);
54 for i in 0..row_count {
55 if let Some(&value) = container.get(i) {
56 data.push(value.trunc());
57 bitvec.push(true);
58 } else {
59 data.push(0.0);
60 bitvec.push(false);
61 }
62 }
63 Ok(ColumnData::float8_with_bitvec(data, bitvec))
64 }
65 ColumnData::Decimal {
66 container,
67 precision,
68 scale,
69 } => {
70 let mut data = Vec::with_capacity(row_count);
71 for i in 0..row_count {
72 if let Some(value) = container.get(i) {
73 let f = value.0.to_f64().unwrap_or(0.0);
74 data.push(Decimal::from(f.trunc()));
75 } else {
76 data.push(Decimal::default());
77 }
78 }
79 Ok(ColumnData::Decimal {
80 container: NumberContainer::new(data),
81 precision: *precision,
82 scale: *scale,
83 })
84 }
85 other if other.get_type().is_number() => Ok(column.data().clone()),
86 other => Err(ScalarFunctionError::InvalidArgumentType {
87 function: ctx.fragment.clone(),
88 argument_index: 0,
89 expected: vec![
90 Type::Int1,
91 Type::Int2,
92 Type::Int4,
93 Type::Int8,
94 Type::Int16,
95 Type::Uint1,
96 Type::Uint2,
97 Type::Uint4,
98 Type::Uint8,
99 Type::Uint16,
100 Type::Float4,
101 Type::Float8,
102 Type::Int,
103 Type::Uint,
104 Type::Decimal,
105 ],
106 actual: other.get_type(),
107 }),
108 }
109 }
110
111 fn return_type(&self, input_types: &[Type]) -> Type {
112 input_types[0].clone()
113 }
114}