Skip to main content

reifydb_function/math/scalar/
truncate.rs

// SPDX-License-Identifier: AGPL-3.0-or-later
// Copyright (c) 2025 ReifyDB

4use num_traits::ToPrimitive;
5use reifydb_core::value::column::data::ColumnData;
6use reifydb_type::value::{container::number::NumberContainer, decimal::Decimal, r#type::Type};
7
8use crate::{ScalarFunction, ScalarFunctionContext, error::ScalarFunctionError, propagate_options};
/// Scalar function that truncates numeric values toward zero (drops the fractional part).
///
/// This is a stateless unit struct; construct it with [`Truncate::new`] or
/// [`Default::default`].
#[derive(Debug, Default, Clone, Copy)]
pub struct Truncate;

impl Truncate {
	/// Creates a new `Truncate` function instance.
	pub const fn new() -> Self {
		Self
	}
}

18impl ScalarFunction for Truncate {
19	fn scalar(&self, ctx: ScalarFunctionContext) -> crate::error::ScalarFunctionResult<ColumnData> {
20		if let Some(result) = propagate_options(self, &ctx) {
21			return result;
22		}
23		let columns = ctx.columns;
24		let row_count = ctx.row_count;
25
26		if columns.len() != 1 {
27			return Err(ScalarFunctionError::ArityMismatch {
28				function: ctx.fragment.clone(),
29				expected: 1,
30				actual: columns.len(),
31			});
32		}
33
34		let column = columns.get(0).unwrap();
35
36		match column.data() {
37			ColumnData::Float4(container) => {
38				let mut data = Vec::with_capacity(row_count);
39				let mut bitvec = Vec::with_capacity(row_count);
40				for i in 0..row_count {
41					if let Some(&value) = container.get(i) {
42						data.push(value.trunc());
43						bitvec.push(true);
44					} else {
45						data.push(0.0);
46						bitvec.push(false);
47					}
48				}
49				Ok(ColumnData::float4_with_bitvec(data, bitvec))
50			}
51			ColumnData::Float8(container) => {
52				let mut data = Vec::with_capacity(row_count);
53				let mut bitvec = Vec::with_capacity(row_count);
54				for i in 0..row_count {
55					if let Some(&value) = container.get(i) {
56						data.push(value.trunc());
57						bitvec.push(true);
58					} else {
59						data.push(0.0);
60						bitvec.push(false);
61					}
62				}
63				Ok(ColumnData::float8_with_bitvec(data, bitvec))
64			}
65			ColumnData::Decimal {
66				container,
67				precision,
68				scale,
69			} => {
70				let mut data = Vec::with_capacity(row_count);
71				for i in 0..row_count {
72					if let Some(value) = container.get(i) {
73						let f = value.0.to_f64().unwrap_or(0.0);
74						data.push(Decimal::from(f.trunc()));
75					} else {
76						data.push(Decimal::default());
77					}
78				}
79				Ok(ColumnData::Decimal {
80					container: NumberContainer::new(data),
81					precision: *precision,
82					scale: *scale,
83				})
84			}
85			other if other.get_type().is_number() => Ok(column.data().clone()),
86			other => Err(ScalarFunctionError::InvalidArgumentType {
87				function: ctx.fragment.clone(),
88				argument_index: 0,
89				expected: vec![
90					Type::Int1,
91					Type::Int2,
92					Type::Int4,
93					Type::Int8,
94					Type::Int16,
95					Type::Uint1,
96					Type::Uint2,
97					Type::Uint4,
98					Type::Uint8,
99					Type::Uint16,
100					Type::Float4,
101					Type::Float8,
102					Type::Int,
103					Type::Uint,
104					Type::Decimal,
105				],
106				actual: other.get_type(),
107			}),
108		}
109	}
110
111	fn return_type(&self, input_types: &[Type]) -> Type {
112		input_types[0].clone()
113	}
114}