datafusion_functions/math/
trunc.rs1use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::make_scalar_function;
22
23use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
24use arrow::datatypes::DataType::{Float32, Float64};
25use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type};
26use datafusion_common::ScalarValue::Int64;
27use datafusion_common::{Result, ScalarValue, exec_err};
28use datafusion_expr::TypeSignature::Exact;
29use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
30use datafusion_expr::{
31 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
32 Volatility,
33};
34use datafusion_macros::user_doc;
35
// SQL reference documentation for `trunc` (section "Math Functions").
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Truncates a number to a whole number or truncated to the specified decimal places.",
    syntax_example = "trunc(numeric_expression[, decimal_places])",
    standard_argument(name = "numeric_expression", prefix = "Numeric"),
    argument(
        name = "decimal_places",
        description = r#"Optional. The number of decimal places to
  truncate to. Defaults to 0 (truncate to a whole number). If
  `decimal_places` is a positive integer, truncates digits to the
  right of the decimal point. If `decimal_places` is a negative
  integer, replaces digits to the left of the decimal point with `0`."#
    ),
    sql_example = r#"
  ```sql
  > SELECT trunc(42.738);
  +----------------+
  | trunc(42.738)  |
  +----------------+
  | 42             |
  +----------------+
  ```"#
)]
/// Scalar UDF implementing `trunc(value[, decimal_places])` for `Float32`
/// and `Float64` values.
///
/// The only state is the function's [`Signature`].
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TruncFunc {
    /// Accepted argument lists; built by `TruncFunc::new`.
    signature: Signature,
}
63
64impl Default for TruncFunc {
65 fn default() -> Self {
66 TruncFunc::new()
67 }
68}
69
70impl TruncFunc {
71 pub fn new() -> Self {
72 use DataType::*;
73 Self {
74 signature: Signature::one_of(
80 vec![
81 Exact(vec![Float32, Int64]),
82 Exact(vec![Float64, Int64]),
83 Exact(vec![Float64]),
84 Exact(vec![Float32]),
85 ],
86 Volatility::Immutable,
87 ),
88 }
89 }
90}
91
92impl ScalarUDFImpl for TruncFunc {
93 fn as_any(&self) -> &dyn Any {
94 self
95 }
96
97 fn name(&self) -> &str {
98 "trunc"
99 }
100
101 fn signature(&self) -> &Signature {
102 &self.signature
103 }
104
105 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
106 match arg_types[0] {
107 Float32 => Ok(Float32),
108 _ => Ok(Float64),
109 }
110 }
111
112 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
113 let precision = match args.args.get(1) {
115 Some(ColumnarValue::Scalar(Int64(Some(p)))) => Some(*p),
116 Some(ColumnarValue::Scalar(Int64(None))) => None, Some(ColumnarValue::Array(_)) => {
118 return make_scalar_function(trunc, vec![])(&args.args);
120 }
121 None => Some(0), Some(cv) => {
123 return exec_err!(
124 "trunc function requires precision to be Int64, got {:?}",
125 cv.data_type()
126 );
127 }
128 };
129
130 match (&args.args[0], precision) {
132 (ColumnarValue::Scalar(sv), _) if sv.is_null() => {
134 ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None)
135 }
136 (_, None) => {
137 ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None)
138 }
139 (ColumnarValue::Scalar(ScalarValue::Float64(Some(v))), Some(p)) => Ok(
141 ColumnarValue::Scalar(ScalarValue::Float64(Some(if p == 0 {
142 v.trunc()
143 } else {
144 compute_truncate64(*v, p)
145 }))),
146 ),
147 (ColumnarValue::Scalar(ScalarValue::Float32(Some(v))), Some(p)) => Ok(
148 ColumnarValue::Scalar(ScalarValue::Float32(Some(if p == 0 {
149 v.trunc()
150 } else {
151 compute_truncate32(*v, p)
152 }))),
153 ),
154 _ => make_scalar_function(trunc, vec![])(&args.args),
156 }
157 }
158
159 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
160 let value = &input[0];
162 let precision = input.get(1);
163
164 if precision
165 .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
166 .unwrap_or(true)
167 {
168 Ok(value.sort_properties)
169 } else {
170 Ok(SortProperties::Unordered)
171 }
172 }
173
174 fn documentation(&self) -> Option<&Documentation> {
175 self.doc()
176 }
177}
178
179fn trunc(args: &[ArrayRef]) -> Result<ArrayRef> {
181 if args.len() != 1 && args.len() != 2 {
182 return exec_err!(
183 "truncate function requires one or two arguments, got {}",
184 args.len()
185 );
186 }
187
188 let num = &args[0];
191 let precision = if args.len() == 1 {
192 ColumnarValue::Scalar(Int64(Some(0)))
193 } else {
194 ColumnarValue::Array(Arc::clone(&args[1]))
195 };
196
197 match num.data_type() {
198 Float64 => match precision {
199 ColumnarValue::Scalar(Int64(Some(0))) => {
200 Ok(Arc::new(
201 args[0]
202 .as_primitive::<Float64Type>()
203 .unary::<_, Float64Type>(|x: f64| {
204 if x == 0_f64 { 0_f64 } else { x.trunc() }
205 }),
206 ) as ArrayRef)
207 }
208 ColumnarValue::Array(precision) => {
209 let num_array = num.as_primitive::<Float64Type>();
210 let precision_array = precision.as_primitive::<Int64Type>();
211 let result: PrimitiveArray<Float64Type> =
212 arrow::compute::binary(num_array, precision_array, |x, y| {
213 compute_truncate64(x, y)
214 })?;
215
216 Ok(Arc::new(result) as ArrayRef)
217 }
218 _ => exec_err!("trunc function requires a scalar or array for precision"),
219 },
220 Float32 => match precision {
221 ColumnarValue::Scalar(Int64(Some(0))) => {
222 Ok(Arc::new(
223 args[0]
224 .as_primitive::<Float32Type>()
225 .unary::<_, Float32Type>(|x: f32| {
226 if x == 0_f32 { 0_f32 } else { x.trunc() }
227 }),
228 ) as ArrayRef)
229 }
230 ColumnarValue::Array(precision) => {
231 let num_array = num.as_primitive::<Float32Type>();
232 let precision_array = precision.as_primitive::<Int64Type>();
233 let result: PrimitiveArray<Float32Type> =
234 arrow::compute::binary(num_array, precision_array, |x, y| {
235 compute_truncate32(x, y)
236 })?;
237
238 Ok(Arc::new(result) as ArrayRef)
239 }
240 _ => exec_err!("trunc function requires a scalar or array for precision"),
241 },
242 other => exec_err!("Unsupported data type {other:?} for function trunc"),
243 }
244}
245
/// Truncates `x` toward zero, keeping `y` digits after the decimal point.
///
/// A negative `y` zeroes digits to the left of the decimal point instead
/// (e.g. `y = -2` truncates to hundreds).
///
/// The naive `(x * 10^y).trunc() / 10^y` formula has several edge cases.
/// This function handles them explicitly:
/// * `y` outside the `i32` range saturates instead of wrapping. A plain
///   `y as i32` would turn `2^32 + 2` into `2`.
/// * If `10^y` or `x * 10^y` overflows, `x` is returned unchanged instead of
///   `inf`/`NaN`. At such precisions `f32` has essentially no digits left to
///   drop, except for values extremely close to zero.
/// * If `10^y` underflows to zero, the result is a zero with the sign of `x`
///   instead of `0 / 0 = NaN`.
/// * `NaN` and `±inf` are returned as-is.
fn compute_truncate32(x: f32, y: i64) -> f32 {
    if !x.is_finite() {
        return x;
    }
    let exponent = i32::try_from(y).unwrap_or(if y < 0 { i32::MIN } else { i32::MAX });
    let factor = 10.0_f32.powi(exponent);
    if factor == 0.0 {
        // Every finite f32 is smaller than 10^-exponent here.
        return 0.0_f32.copysign(x);
    }
    let scaled = x * factor;
    if !scaled.is_finite() {
        return x;
    }
    scaled.trunc() / factor
}
250
/// Truncates `x` toward zero, keeping `y` digits after the decimal point.
///
/// A negative `y` zeroes digits to the left of the decimal point instead
/// (e.g. `y = -2` truncates to hundreds).
///
/// The naive `(x * 10^y).trunc() / 10^y` formula has several edge cases.
/// This function handles them explicitly:
/// * `y` outside the `i32` range saturates instead of wrapping. A plain
///   `y as i32` would turn `2^32 + 2` into `2`.
/// * If `10^y` or `x * 10^y` overflows, `x` is returned unchanged instead of
///   `inf`/`NaN`. At such precisions `f64` has essentially no digits left to
///   drop, except for values extremely close to zero.
/// * If `10^y` underflows to zero, the result is a zero with the sign of `x`
///   instead of `0 / 0 = NaN`.
/// * `NaN` and `±inf` are returned as-is.
fn compute_truncate64(x: f64, y: i64) -> f64 {
    if !x.is_finite() {
        return x;
    }
    let exponent = i32::try_from(y).unwrap_or(if y < 0 { i32::MIN } else { i32::MAX });
    let factor = 10.0_f64.powi(exponent);
    if factor == 0.0 {
        // Every finite f64 is smaller than 10^-exponent here.
        return 0.0_f64.copysign(x);
    }
    let scaled = x * factor;
    if !scaled.is_finite() {
        return x;
    }
    scaled.trunc() / factor
}
255
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::math::trunc::trunc;

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
    use datafusion_common::cast::{as_float32_array, as_float64_array};

    /// Per-row precision over a `Float32` column.
    #[test]
    fn test_truncate_32() {
        let values = Float32Array::from(vec![
            15.0,
            1_234.267_8,
            1_233.123_4,
            3.312_979_2,
            -21.123_4,
        ]);
        let precisions = Int64Array::from(vec![0, 3, 2, 5, 6]);
        let args: Vec<ArrayRef> = vec![Arc::new(values), Arc::new(precisions)];

        let output = trunc(&args).expect("failed to initialize function truncate");
        let actual =
            as_float32_array(&output).expect("failed to initialize function truncate");

        let expected: [f32; 5] = [15.0, 1_234.267, 1_233.12, 3.312_97, -21.123_4];
        assert_eq!(actual.len(), expected.len());
        for (row, want) in expected.into_iter().enumerate() {
            assert_eq!(actual.value(row), want);
        }
    }

    /// Per-row precision over a `Float64` column.
    #[test]
    fn test_truncate_64() {
        let values = Float64Array::from(vec![
            5.0,
            234.267_812_176,
            123.123_456_789,
            123.312_979_313_2,
            -321.123_1,
        ]);
        let precisions = Int64Array::from(vec![0, 3, 2, 5, 6]);
        let args: Vec<ArrayRef> = vec![Arc::new(values), Arc::new(precisions)];

        let output = trunc(&args).expect("failed to initialize function truncate");
        let actual =
            as_float64_array(&output).expect("failed to initialize function truncate");

        let expected: [f64; 5] = [5.0, 234.267, 123.12, 123.312_97, -321.123_1];
        assert_eq!(actual.len(), expected.len());
        for (row, want) in expected.into_iter().enumerate() {
            assert_eq!(actual.value(row), want);
        }
    }

    /// Without a precision argument every value is cut to a whole number.
    #[test]
    fn test_truncate_64_one_arg() {
        let values = Float64Array::from(vec![
            5.0,
            234.267_812,
            123.123_45,
            123.312_979_313_2,
            -321.123,
        ]);
        let args: Vec<ArrayRef> = vec![Arc::new(values)];

        let output = trunc(&args).expect("failed to initialize function truncate");
        let actual =
            as_float64_array(&output).expect("failed to initialize function truncate");

        let expected: [f64; 5] = [5.0, 234.0, 123.0, 123.0, -321.0];
        assert_eq!(actual.len(), expected.len());
        for (row, want) in expected.into_iter().enumerate() {
            assert_eq!(actual.value(row), want);
        }
    }
}