datafusion_functions/math/
trunc.rs1use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::make_scalar_function;
22
23use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
24use arrow::datatypes::DataType::{Float32, Float64};
25use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type};
26use datafusion_common::ScalarValue::Int64;
27use datafusion_common::{exec_err, Result};
28use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
29use datafusion_expr::TypeSignature::Exact;
30use datafusion_expr::{
31 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
32 Volatility,
33};
34use datafusion_macros::user_doc;
35
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Truncates a number to a whole number or truncated to the specified decimal places.",
    syntax_example = "trunc(numeric_expression[, decimal_places])",
    standard_argument(name = "numeric_expression", prefix = "Numeric"),
    argument(
        name = "decimal_places",
        description = r#"Optional. The number of decimal places to
truncate to. Defaults to 0 (truncate to a whole number). If
`decimal_places` is a positive integer, truncates digits to the
right of the decimal point. If `decimal_places` is a negative
integer, replaces digits to the left of the decimal point with `0`."#
    ),
    sql_example = r#"
```sql
> SELECT trunc(42.738);
+----------------+
| trunc(42.738) |
+----------------+
| 42 |
+----------------+
```"#
)]
/// Scalar UDF implementing SQL `trunc(numeric_expression[, decimal_places])`.
///
/// The struct only carries the accepted [`Signature`]; evaluation is provided
/// by this type's [`ScalarUDFImpl`] implementation.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TruncFunc {
    /// Accepted argument type combinations, built by `TruncFunc::new`.
    signature: Signature,
}
63
64impl Default for TruncFunc {
65 fn default() -> Self {
66 TruncFunc::new()
67 }
68}
69
70impl TruncFunc {
71 pub fn new() -> Self {
72 use DataType::*;
73 Self {
74 signature: Signature::one_of(
80 vec![
81 Exact(vec![Float32, Int64]),
82 Exact(vec![Float64, Int64]),
83 Exact(vec![Float64]),
84 Exact(vec![Float32]),
85 ],
86 Volatility::Immutable,
87 ),
88 }
89 }
90}
91
92impl ScalarUDFImpl for TruncFunc {
93 fn as_any(&self) -> &dyn Any {
94 self
95 }
96
97 fn name(&self) -> &str {
98 "trunc"
99 }
100
101 fn signature(&self) -> &Signature {
102 &self.signature
103 }
104
105 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
106 match arg_types[0] {
107 Float32 => Ok(Float32),
108 _ => Ok(Float64),
109 }
110 }
111
112 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
113 make_scalar_function(trunc, vec![])(&args.args)
114 }
115
116 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
117 let value = &input[0];
119 let precision = input.get(1);
120
121 if precision
122 .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
123 .unwrap_or(true)
124 {
125 Ok(value.sort_properties)
126 } else {
127 Ok(SortProperties::Unordered)
128 }
129 }
130
131 fn documentation(&self) -> Option<&Documentation> {
132 self.doc()
133 }
134}
135
136fn trunc(args: &[ArrayRef]) -> Result<ArrayRef> {
138 if args.len() != 1 && args.len() != 2 {
139 return exec_err!(
140 "truncate function requires one or two arguments, got {}",
141 args.len()
142 );
143 }
144
145 let num = &args[0];
148 let precision = if args.len() == 1 {
149 ColumnarValue::Scalar(Int64(Some(0)))
150 } else {
151 ColumnarValue::Array(Arc::clone(&args[1]))
152 };
153
154 match num.data_type() {
155 Float64 => match precision {
156 ColumnarValue::Scalar(Int64(Some(0))) => {
157 Ok(Arc::new(
158 args[0]
159 .as_primitive::<Float64Type>()
160 .unary::<_, Float64Type>(|x: f64| {
161 if x == 0_f64 {
162 0_f64
163 } else {
164 x.trunc()
165 }
166 }),
167 ) as ArrayRef)
168 }
169 ColumnarValue::Array(precision) => {
170 let num_array = num.as_primitive::<Float64Type>();
171 let precision_array = precision.as_primitive::<Int64Type>();
172 let result: PrimitiveArray<Float64Type> =
173 arrow::compute::binary(num_array, precision_array, |x, y| {
174 compute_truncate64(x, y)
175 })?;
176
177 Ok(Arc::new(result) as ArrayRef)
178 }
179 _ => exec_err!("trunc function requires a scalar or array for precision"),
180 },
181 Float32 => match precision {
182 ColumnarValue::Scalar(Int64(Some(0))) => {
183 Ok(Arc::new(
184 args[0]
185 .as_primitive::<Float32Type>()
186 .unary::<_, Float32Type>(|x: f32| {
187 if x == 0_f32 {
188 0_f32
189 } else {
190 x.trunc()
191 }
192 }),
193 ) as ArrayRef)
194 }
195 ColumnarValue::Array(precision) => {
196 let num_array = num.as_primitive::<Float32Type>();
197 let precision_array = precision.as_primitive::<Int64Type>();
198 let result: PrimitiveArray<Float32Type> =
199 arrow::compute::binary(num_array, precision_array, |x, y| {
200 compute_truncate32(x, y)
201 })?;
202
203 Ok(Arc::new(result) as ArrayRef)
204 }
205 _ => exec_err!("trunc function requires a scalar or array for precision"),
206 },
207 other => exec_err!("Unsupported data type {other:?} for function trunc"),
208 }
209}
210
/// Converts an `i64` precision into the `i32` exponent accepted by `powi`.
///
/// Out-of-range precisions saturate at `i32::MIN` / `i32::MAX` instead of
/// wrapping; a plain `as i32` would turn e.g. `2^32 + 2` into `2`.
fn precision_exponent(precision: i64) -> i32 {
    precision.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
}

/// Truncates `x` toward zero, keeping `y` digits after the decimal point.
///
/// A negative `y` zeroes digits to the left of the decimal point instead,
/// e.g. `compute_truncate32(1234.5, -2) == 1200.0`.
///
/// Edge cases:
/// * If `x * 10^y` is not finite, `x` is returned unchanged instead of the
///   `inf` / `NaN` the naive formula would produce. This covers NaN or
///   infinite `x`, and `y` large enough that the scaling overflows.
/// * If `10^y` underflows to zero (very negative `y`), the scaled value is
///   returned. That is a zero carrying the sign of `x`, instead of `0 / 0 = NaN`.
fn compute_truncate32(x: f32, y: i64) -> f32 {
    let factor = 10.0_f32.powi(precision_exponent(y));
    let scaled = x * factor;
    if !scaled.is_finite() {
        return x;
    }
    if factor == 0.0 {
        return scaled;
    }
    // `trunc`, not `round`: trunc(1.999, 2) must be 1.99, not 2.0.
    scaled.trunc() / factor
}

/// Truncates `x` toward zero, keeping `y` digits after the decimal point.
///
/// A negative `y` zeroes digits to the left of the decimal point instead,
/// e.g. `compute_truncate64(1234.5, -2) == 1200.0`.
///
/// Edge cases:
/// * If `x * 10^y` is not finite, `x` is returned unchanged instead of the
///   `inf` / `NaN` the naive formula would produce. This covers NaN or
///   infinite `x`, and `y` large enough that the scaling overflows.
/// * If `10^y` underflows to zero (very negative `y`), the scaled value is
///   returned. That is a zero carrying the sign of `x`, instead of `0 / 0 = NaN`.
fn compute_truncate64(x: f64, y: i64) -> f64 {
    let factor = 10.0_f64.powi(precision_exponent(y));
    let scaled = x * factor;
    if !scaled.is_finite() {
        return x;
    }
    if factor == 0.0 {
        return scaled;
    }
    // `trunc`, not `round`: trunc(1.999, 2) must be 1.99, not 2.0.
    scaled.trunc() / factor
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::math::trunc::{compute_truncate32, compute_truncate64, trunc};

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
    use datafusion_common::cast::{as_float32_array, as_float64_array};

    #[test]
    fn test_truncate_32() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![
                15.0,
                1_234.267_8,
                1_233.123_4,
                3.312_979_2,
                -21.123_4,
            ])),
            Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])),
        ];

        let result = trunc(&args).expect("failed to initialize function truncate");
        let floats =
            as_float32_array(&result).expect("failed to initialize function truncate");

        // Digits beyond the requested precision are dropped, never rounded up.
        assert_eq!(floats.len(), 5);
        assert_eq!(floats.value(0), 15.0);
        assert_eq!(floats.value(1), 1_234.267);
        assert_eq!(floats.value(2), 1_233.12);
        assert_eq!(floats.value(3), 3.312_97);
        assert_eq!(floats.value(4), -21.123_4);
    }

    #[test]
    fn test_truncate_64() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![
                5.0,
                234.267_812_176,
                123.123_456_789,
                123.312_979_313_2,
                -321.123_1,
            ])),
            Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])),
        ];

        let result = trunc(&args).expect("failed to initialize function truncate");
        let floats =
            as_float64_array(&result).expect("failed to initialize function truncate");

        // Digits beyond the requested precision are dropped, never rounded up.
        assert_eq!(floats.len(), 5);
        assert_eq!(floats.value(0), 5.0);
        assert_eq!(floats.value(1), 234.267);
        assert_eq!(floats.value(2), 123.12);
        assert_eq!(floats.value(3), 123.312_97);
        assert_eq!(floats.value(4), -321.123_1);
    }

    #[test]
    fn test_truncate_64_one_arg() {
        let args: Vec<ArrayRef> = vec![Arc::new(Float64Array::from(vec![
            5.0,
            234.267_812,
            123.123_45,
            123.312_979_313_2,
            -321.123,
        ]))];

        let result = trunc(&args).expect("failed to initialize function truncate");
        let floats =
            as_float64_array(&result).expect("failed to initialize function truncate");

        assert_eq!(floats.len(), 5);
        assert_eq!(floats.value(0), 5.0);
        assert_eq!(floats.value(1), 234.0);
        assert_eq!(floats.value(2), 123.0);
        assert_eq!(floats.value(3), 123.0);
        assert_eq!(floats.value(4), -321.0);
    }

    #[test]
    fn test_truncate_64_negative_precision() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![1_234.5, -1_234.5, 99.9])),
            Arc::new(Int64Array::from(vec![-2, -2, -3])),
        ];

        let result = trunc(&args).expect("failed to initialize function truncate");
        let floats =
            as_float64_array(&result).expect("failed to initialize function truncate");

        assert_eq!(floats.len(), 3);
        assert_eq!(floats.value(0), 1_200.0);
        assert_eq!(floats.value(1), -1_200.0);
        assert_eq!(floats.value(2), 0.0);
    }

    #[test]
    fn test_compute_truncate_does_not_round() {
        assert_eq!(compute_truncate64(1.999, 2), 1.99);
        assert_eq!(compute_truncate64(-1.999, 2), -1.99);
        assert_eq!(compute_truncate32(1.999, 2), 1.99);
        assert_eq!(compute_truncate32(-1.999, 2), -1.99);
        assert_eq!(compute_truncate32(1_234.5, -2), 1_200.0);
    }

    #[test]
    fn test_compute_truncate_extreme_precision() {
        // Scaling overflows: the value is returned unchanged instead of NaN/inf.
        assert_eq!(compute_truncate64(1.5, 400), 1.5);
        assert_eq!(compute_truncate32(1.5, 39), 1.5);
        assert_eq!(compute_truncate64(1e300, 10), 1e300);
        // Precision outside the i32 range saturates instead of wrapping to 2.
        assert_eq!(compute_truncate64(1.2345, (1_i64 << 32) + 2), 1.2345);
        // Truncating far left of the decimal point yields zero, not NaN.
        assert_eq!(compute_truncate64(123.0, -400), 0.0);
        assert_eq!(compute_truncate32(123.0, -50), 0.0);
    }
}