datafusion_functions/math/
trunc.rs1use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::make_scalar_function;
22
23use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
24use arrow::datatypes::DataType::{Float32, Float64};
25use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type};
26use datafusion_common::ScalarValue::Int64;
27use datafusion_common::{Result, exec_err};
28use datafusion_expr::TypeSignature::Exact;
29use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
30use datafusion_expr::{
31 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
32 Volatility,
33};
34use datafusion_macros::user_doc;
35
36#[user_doc(
37 doc_section(label = "Math Functions"),
38 description = "Truncates a number to a whole number or truncated to the specified decimal places.",
39 syntax_example = "trunc(numeric_expression[, decimal_places])",
40 standard_argument(name = "numeric_expression", prefix = "Numeric"),
41 argument(
42 name = "decimal_places",
43 description = r#"Optional. The number of decimal places to
44 truncate to. Defaults to 0 (truncate to a whole number). If
45 `decimal_places` is a positive integer, truncates digits to the
46 right of the decimal point. If `decimal_places` is a negative
47 integer, replaces digits to the left of the decimal point with `0`."#
48 ),
49 sql_example = r#"
50 ```sql
51 > SELECT trunc(42.738);
52 +----------------+
53 | trunc(42.738) |
54 +----------------+
55 | 42 |
56 +----------------+
57 ```"#
58)]
// The `trunc` scalar UDF. Its only state is the accepted type signature
// built in `TruncFunc::new`; user-facing docs come from the `#[user_doc]`
// attribute on this item.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TruncFunc {
    // Accepted argument lists: a Float32/Float64 value, optionally followed
    // by an Int64 precision (see `TruncFunc::new`).
    signature: Signature,
}
63
64impl Default for TruncFunc {
65 fn default() -> Self {
66 TruncFunc::new()
67 }
68}
69
70impl TruncFunc {
71 pub fn new() -> Self {
72 use DataType::*;
73 Self {
74 signature: Signature::one_of(
80 vec![
81 Exact(vec![Float32, Int64]),
82 Exact(vec![Float64, Int64]),
83 Exact(vec![Float64]),
84 Exact(vec![Float32]),
85 ],
86 Volatility::Immutable,
87 ),
88 }
89 }
90}
91
92impl ScalarUDFImpl for TruncFunc {
93 fn as_any(&self) -> &dyn Any {
94 self
95 }
96
97 fn name(&self) -> &str {
98 "trunc"
99 }
100
101 fn signature(&self) -> &Signature {
102 &self.signature
103 }
104
105 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
106 match arg_types[0] {
107 Float32 => Ok(Float32),
108 _ => Ok(Float64),
109 }
110 }
111
112 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
113 make_scalar_function(trunc, vec![])(&args.args)
114 }
115
116 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
117 let value = &input[0];
119 let precision = input.get(1);
120
121 if precision
122 .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
123 .unwrap_or(true)
124 {
125 Ok(value.sort_properties)
126 } else {
127 Ok(SortProperties::Unordered)
128 }
129 }
130
131 fn documentation(&self) -> Option<&Documentation> {
132 self.doc()
133 }
134}
135
136fn trunc(args: &[ArrayRef]) -> Result<ArrayRef> {
138 if args.len() != 1 && args.len() != 2 {
139 return exec_err!(
140 "truncate function requires one or two arguments, got {}",
141 args.len()
142 );
143 }
144
145 let num = &args[0];
148 let precision = if args.len() == 1 {
149 ColumnarValue::Scalar(Int64(Some(0)))
150 } else {
151 ColumnarValue::Array(Arc::clone(&args[1]))
152 };
153
154 match num.data_type() {
155 Float64 => match precision {
156 ColumnarValue::Scalar(Int64(Some(0))) => {
157 Ok(Arc::new(
158 args[0]
159 .as_primitive::<Float64Type>()
160 .unary::<_, Float64Type>(|x: f64| {
161 if x == 0_f64 { 0_f64 } else { x.trunc() }
162 }),
163 ) as ArrayRef)
164 }
165 ColumnarValue::Array(precision) => {
166 let num_array = num.as_primitive::<Float64Type>();
167 let precision_array = precision.as_primitive::<Int64Type>();
168 let result: PrimitiveArray<Float64Type> =
169 arrow::compute::binary(num_array, precision_array, |x, y| {
170 compute_truncate64(x, y)
171 })?;
172
173 Ok(Arc::new(result) as ArrayRef)
174 }
175 _ => exec_err!("trunc function requires a scalar or array for precision"),
176 },
177 Float32 => match precision {
178 ColumnarValue::Scalar(Int64(Some(0))) => {
179 Ok(Arc::new(
180 args[0]
181 .as_primitive::<Float32Type>()
182 .unary::<_, Float32Type>(|x: f32| {
183 if x == 0_f32 { 0_f32 } else { x.trunc() }
184 }),
185 ) as ArrayRef)
186 }
187 ColumnarValue::Array(precision) => {
188 let num_array = num.as_primitive::<Float32Type>();
189 let precision_array = precision.as_primitive::<Int64Type>();
190 let result: PrimitiveArray<Float32Type> =
191 arrow::compute::binary(num_array, precision_array, |x, y| {
192 compute_truncate32(x, y)
193 })?;
194
195 Ok(Arc::new(result) as ArrayRef)
196 }
197 _ => exec_err!("trunc function requires a scalar or array for precision"),
198 },
199 other => exec_err!("Unsupported data type {other:?} for function trunc"),
200 }
201}
202
203fn compute_truncate32(x: f32, y: i64) -> f32 {
204 let factor = 10.0_f32.powi(y as i32);
205 (x * factor).round() / factor
206}
207
208fn compute_truncate64(x: f64, y: i64) -> f64 {
209 let factor = 10.0_f64.powi(y as i32);
210 (x * factor).round() / factor
211}
212
213#[cfg(test)]
214mod test {
215 use std::sync::Arc;
216
217 use crate::math::trunc::trunc;
218
219 use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
220 use datafusion_common::cast::{as_float32_array, as_float64_array};
221
222 #[test]
223 fn test_truncate_32() {
224 let args: Vec<ArrayRef> = vec![
225 Arc::new(Float32Array::from(vec![
226 15.0,
227 1_234.267_8,
228 1_233.123_4,
229 3.312_979_2,
230 -21.123_4,
231 ])),
232 Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])),
233 ];
234
235 let result = trunc(&args).expect("failed to initialize function truncate");
236 let floats =
237 as_float32_array(&result).expect("failed to initialize function truncate");
238
239 assert_eq!(floats.len(), 5);
240 assert_eq!(floats.value(0), 15.0);
241 assert_eq!(floats.value(1), 1_234.268);
242 assert_eq!(floats.value(2), 1_233.12);
243 assert_eq!(floats.value(3), 3.312_98);
244 assert_eq!(floats.value(4), -21.123_4);
245 }
246
247 #[test]
248 fn test_truncate_64() {
249 let args: Vec<ArrayRef> = vec![
250 Arc::new(Float64Array::from(vec![
251 5.0,
252 234.267_812_176,
253 123.123_456_789,
254 123.312_979_313_2,
255 -321.123_1,
256 ])),
257 Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])),
258 ];
259
260 let result = trunc(&args).expect("failed to initialize function truncate");
261 let floats =
262 as_float64_array(&result).expect("failed to initialize function truncate");
263
264 assert_eq!(floats.len(), 5);
265 assert_eq!(floats.value(0), 5.0);
266 assert_eq!(floats.value(1), 234.268);
267 assert_eq!(floats.value(2), 123.12);
268 assert_eq!(floats.value(3), 123.312_98);
269 assert_eq!(floats.value(4), -321.123_1);
270 }
271
272 #[test]
273 fn test_truncate_64_one_arg() {
274 let args: Vec<ArrayRef> = vec![Arc::new(Float64Array::from(vec![
275 5.0,
276 234.267_812,
277 123.123_45,
278 123.312_979_313_2,
279 -321.123,
280 ]))];
281
282 let result = trunc(&args).expect("failed to initialize function truncate");
283 let floats =
284 as_float64_array(&result).expect("failed to initialize function truncate");
285
286 assert_eq!(floats.len(), 5);
287 assert_eq!(floats.value(0), 5.0);
288 assert_eq!(floats.value(1), 234.0);
289 assert_eq!(floats.value(2), 123.0);
290 assert_eq!(floats.value(3), 123.0);
291 assert_eq!(floats.value(4), -321.0);
292 }
293}