datafusion_functions/math/
trunc.rs1use std::sync::Arc;
19
20use crate::utils::make_scalar_function;
21
22use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
23use arrow::datatypes::DataType::{Float32, Float64};
24use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type};
25use datafusion_common::ScalarValue::Int64;
26use datafusion_common::{Result, ScalarValue, exec_err};
27use datafusion_expr::TypeSignature::Exact;
28use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
29use datafusion_expr::{
30 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
31 Volatility,
32};
33use datafusion_macros::user_doc;
34
// `trunc` scalar UDF. The `user_doc` attribute below carries the user-facing
// documentation; it presumably generates the `doc()` accessor that
// `ScalarUDFImpl::documentation` returns further down in this file — confirm
// against the `datafusion_macros::user_doc` implementation.
// NOTE(review): the description's wording ("or truncated to") could read
// "or to"; changing it alters the generated user docs, so it is left as-is.
35#[user_doc(
36 doc_section(label = "Math Functions"),
37 description = "Truncates a number to a whole number or truncated to the specified decimal places.",
38 syntax_example = "trunc(numeric_expression[, decimal_places])",
39 standard_argument(name = "numeric_expression", prefix = "Numeric"),
40 argument(
41 name = "decimal_places",
42 description = r#"Optional. The number of decimal places to
43 truncate to. Defaults to 0 (truncate to a whole number). If
44 `decimal_places` is a positive integer, truncates digits to the
45 right of the decimal point. If `decimal_places` is a negative
46 integer, replaces digits to the left of the decimal point with `0`."#
47 ),
48 sql_example = r#"
49 ```sql
50 > SELECT trunc(42.738);
51 +----------------+
52 | trunc(42.738) |
53 +----------------+
54 | 42 |
55 +----------------+
56 ```"#
57)]
// The only state is the accepted-argument `Signature`, built in `TruncFunc::new`.
58#[derive(Debug, PartialEq, Eq, Hash)]
59pub struct TruncFunc {
// Accepted argument types (Float32/Float64 value, optional Int64 precision).
60 signature: Signature,
61}
62
63impl Default for TruncFunc {
64 fn default() -> Self {
65 TruncFunc::new()
66 }
67}
68
69impl TruncFunc {
70 pub fn new() -> Self {
71 use DataType::*;
72 Self {
73 signature: Signature::one_of(
79 vec![
80 Exact(vec![Float32, Int64]),
81 Exact(vec![Float64, Int64]),
82 Exact(vec![Float64]),
83 Exact(vec![Float32]),
84 ],
85 Volatility::Immutable,
86 ),
87 }
88 }
89}
90
91impl ScalarUDFImpl for TruncFunc {
92 fn name(&self) -> &str {
93 "trunc"
94 }
95
96 fn signature(&self) -> &Signature {
97 &self.signature
98 }
99
100 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
101 match arg_types[0] {
102 Float32 => Ok(Float32),
103 _ => Ok(Float64),
104 }
105 }
106
107 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
108 let precision = match args.args.get(1) {
110 Some(ColumnarValue::Scalar(Int64(Some(p)))) => Some(*p),
111 Some(ColumnarValue::Scalar(Int64(None))) => None, Some(ColumnarValue::Array(_)) => {
113 return make_scalar_function(trunc, vec![])(&args.args);
115 }
116 None => Some(0), Some(cv) => {
118 return exec_err!(
119 "trunc function requires precision to be Int64, got {:?}",
120 cv.data_type()
121 );
122 }
123 };
124
125 match (&args.args[0], precision) {
127 (ColumnarValue::Scalar(sv), _) if sv.is_null() => {
129 ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None)
130 }
131 (_, None) => {
132 ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None)
133 }
134 (ColumnarValue::Scalar(ScalarValue::Float64(Some(v))), Some(p)) => Ok(
136 ColumnarValue::Scalar(ScalarValue::Float64(Some(if p == 0 {
137 v.trunc()
138 } else {
139 compute_truncate64(*v, p)
140 }))),
141 ),
142 (ColumnarValue::Scalar(ScalarValue::Float32(Some(v))), Some(p)) => Ok(
143 ColumnarValue::Scalar(ScalarValue::Float32(Some(if p == 0 {
144 v.trunc()
145 } else {
146 compute_truncate32(*v, p)
147 }))),
148 ),
149 _ => make_scalar_function(trunc, vec![])(&args.args),
151 }
152 }
153
154 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
155 let value = &input[0];
157 let precision = input.get(1);
158
159 if precision
160 .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
161 .unwrap_or(true)
162 {
163 Ok(value.sort_properties)
164 } else {
165 Ok(SortProperties::Unordered)
166 }
167 }
168
169 fn documentation(&self) -> Option<&Documentation> {
170 self.doc()
171 }
172}
173
174fn trunc(args: &[ArrayRef]) -> Result<ArrayRef> {
176 if args.len() != 1 && args.len() != 2 {
177 return exec_err!(
178 "truncate function requires one or two arguments, got {}",
179 args.len()
180 );
181 }
182
183 let num = &args[0];
186 let precision = if args.len() == 1 {
187 ColumnarValue::Scalar(Int64(Some(0)))
188 } else {
189 ColumnarValue::Array(Arc::clone(&args[1]))
190 };
191
192 match num.data_type() {
193 Float64 => match precision {
194 ColumnarValue::Scalar(Int64(Some(0))) => {
195 Ok(Arc::new(
196 args[0]
197 .as_primitive::<Float64Type>()
198 .unary::<_, Float64Type>(|x: f64| {
199 if x == 0_f64 { 0_f64 } else { x.trunc() }
200 }),
201 ) as ArrayRef)
202 }
203 ColumnarValue::Array(precision) => {
204 let num_array = num.as_primitive::<Float64Type>();
205 let precision_array = precision.as_primitive::<Int64Type>();
206 let result: PrimitiveArray<Float64Type> =
207 arrow::compute::binary(num_array, precision_array, |x, y| {
208 compute_truncate64(x, y)
209 })?;
210
211 Ok(Arc::new(result) as ArrayRef)
212 }
213 _ => exec_err!("trunc function requires a scalar or array for precision"),
214 },
215 Float32 => match precision {
216 ColumnarValue::Scalar(Int64(Some(0))) => {
217 Ok(Arc::new(
218 args[0]
219 .as_primitive::<Float32Type>()
220 .unary::<_, Float32Type>(|x: f32| {
221 if x == 0_f32 { 0_f32 } else { x.trunc() }
222 }),
223 ) as ArrayRef)
224 }
225 ColumnarValue::Array(precision) => {
226 let num_array = num.as_primitive::<Float32Type>();
227 let precision_array = precision.as_primitive::<Int64Type>();
228 let result: PrimitiveArray<Float32Type> =
229 arrow::compute::binary(num_array, precision_array, |x, y| {
230 compute_truncate32(x, y)
231 })?;
232
233 Ok(Arc::new(result) as ArrayRef)
234 }
235 _ => exec_err!("trunc function requires a scalar or array for precision"),
236 },
237 other => exec_err!("Unsupported data type {other:?} for function trunc"),
238 }
239}
240
/// Truncates `x` toward zero to `y` decimal places (`f32` variant).
///
/// A positive `y` keeps that many digits after the decimal point, and `0`
/// truncates to a whole number. A negative `y` zeroes digits to the left of
/// the decimal point. The core formula is `(x * 10^y).trunc() / 10^y`; the
/// guards cover the inputs where that formula breaks down:
///
/// * `NaN` and `±inf` are returned unchanged.
/// * `y` outside the `i32` range saturates instead of wrapping (a plain
///   `y as i32` turns e.g. `2^32 + 2` into `2`).
/// * If `10^y` underflows to zero (very negative `y`), the result is a zero
///   carrying the sign of `x`, instead of `0 / 0 = NaN`.
/// * If `x * 10^y` overflows (large `x` and/or large `y`), `x` is returned
///   unchanged instead of `inf`/`NaN`. When `10^y` itself is finite, such
///   an `x` has no digits beyond that precision to drop.
fn compute_truncate32(x: f32, y: i64) -> f32 {
    if !x.is_finite() {
        return x;
    }
    // Any |y| beyond the i32 range already over/underflows 10^y, so saturating is exact.
    let exponent = i32::try_from(y).unwrap_or(if y < 0 { i32::MIN } else { i32::MAX });
    let factor = 10.0_f32.powi(exponent);
    if factor == 0.0 {
        return 0.0_f32.copysign(x);
    }
    let scaled = x * factor;
    if !scaled.is_finite() {
        return x;
    }
    scaled.trunc() / factor
}
245
/// Truncates `x` toward zero to `y` decimal places (`f64` variant).
///
/// A positive `y` keeps that many digits after the decimal point, and `0`
/// truncates to a whole number. A negative `y` zeroes digits to the left of
/// the decimal point. The core formula is `(x * 10^y).trunc() / 10^y`; the
/// guards cover the inputs where that formula breaks down:
///
/// * `NaN` and `±inf` are returned unchanged.
/// * `y` outside the `i32` range saturates instead of wrapping (a plain
///   `y as i32` turns e.g. `2^32 + 2` into `2`).
/// * If `10^y` underflows to zero (very negative `y`), the result is a zero
///   carrying the sign of `x`, instead of `0 / 0 = NaN`.
/// * If `x * 10^y` overflows (large `x` and/or large `y`), `x` is returned
///   unchanged instead of `inf`/`NaN`. When `10^y` itself is finite, such
///   an `x` has no digits beyond that precision to drop.
fn compute_truncate64(x: f64, y: i64) -> f64 {
    if !x.is_finite() {
        return x;
    }
    // Any |y| beyond the i32 range already over/underflows 10^y, so saturating is exact.
    let exponent = i32::try_from(y).unwrap_or(if y < 0 { i32::MIN } else { i32::MAX });
    let factor = 10.0_f64.powi(exponent);
    if factor == 0.0 {
        return 0.0_f64.copysign(x);
    }
    let scaled = x * factor;
    if !scaled.is_finite() {
        return x;
    }
    scaled.trunc() / factor
}
250
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::math::trunc::trunc;

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
    use datafusion_common::cast::{as_float32_array, as_float64_array};

    /// Two-argument `Float32` input: each value is truncated to the precision
    /// in the matching row of the `Int64` column.
    #[test]
    fn test_truncate_32() {
        let values =
            Float32Array::from(vec![15.0, 1_234.267_8, 1_233.123_4, 3.312_979_2, -21.123_4]);
        let digits = Int64Array::from(vec![0, 3, 2, 5, 6]);
        let args: Vec<ArrayRef> = vec![Arc::new(values), Arc::new(digits)];

        let output = trunc(&args).expect("failed to initialize function truncate");
        let actual =
            as_float32_array(&output).expect("failed to initialize function truncate");

        let expected: [f32; 5] = [15.0, 1_234.267, 1_233.12, 3.312_97, -21.123_4];
        assert_eq!(actual.len(), expected.len());
        for (row, want) in expected.into_iter().enumerate() {
            assert_eq!(actual.value(row), want, "row {row}");
        }
    }

    /// Two-argument `Float64` input with per-row precisions.
    #[test]
    fn test_truncate_64() {
        let values = Float64Array::from(vec![
            5.0,
            234.267_812_176,
            123.123_456_789,
            123.312_979_313_2,
            -321.123_1,
        ]);
        let digits = Int64Array::from(vec![0, 3, 2, 5, 6]);
        let args: Vec<ArrayRef> = vec![Arc::new(values), Arc::new(digits)];

        let output = trunc(&args).expect("failed to initialize function truncate");
        let actual =
            as_float64_array(&output).expect("failed to initialize function truncate");

        let expected: [f64; 5] = [5.0, 234.267, 123.12, 123.312_97, -321.123_1];
        assert_eq!(actual.len(), expected.len());
        for (row, want) in expected.into_iter().enumerate() {
            assert_eq!(actual.value(row), want, "row {row}");
        }
    }

    /// Single-argument `Float64` input: every value is truncated to a whole number.
    #[test]
    fn test_truncate_64_one_arg() {
        let values = Float64Array::from(vec![
            5.0,
            234.267_812,
            123.123_45,
            123.312_979_313_2,
            -321.123,
        ]);
        let args: Vec<ArrayRef> = vec![Arc::new(values)];

        let output = trunc(&args).expect("failed to initialize function truncate");
        let actual =
            as_float64_array(&output).expect("failed to initialize function truncate");

        let expected: [f64; 5] = [5.0, 234.0, 123.0, 123.0, -321.0];
        assert_eq!(actual.len(), expected.len());
        for (row, want) in expected.into_iter().enumerate() {
            assert_eq!(actual.value(row), want, "row {row}");
        }
    }
}