1use std::sync::Arc;
19
20use arrow::array::{ArrayRef, AsArray};
21use arrow::compute::{DecimalCast, rescale_decimal};
22use arrow::datatypes::{
23 ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
24 Decimal256Type, DecimalType, Float32Type, Float64Type,
25};
26use datafusion_common::{Result, ScalarValue, exec_err};
27use datafusion_expr::interval_arithmetic::Interval;
28use datafusion_expr::preimage::PreimageResult;
29use datafusion_expr::simplify::SimplifyContext;
30use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
31use datafusion_expr::{
32 Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl,
33 Signature, TypeSignature, TypeSignatureClass, Volatility,
34};
35use datafusion_macros::user_doc;
36use num_traits::{CheckedAdd, Float, One};
37
38use super::decimal::{apply_decimal_op, floor_decimal_value};
39
/// Scalar UDF implementing the SQL `floor` function.
///
/// The only state is the argument [`Signature`], which is built once in
/// `FloorFunc::new` and handed out by reference afterwards.
// NOTE(review): the `user_doc` attribute presumably generates the `doc()`
// method that `documentation()` returns — confirm against `datafusion_macros`.
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Returns the nearest integer less than or equal to a number.",
    syntax_example = "floor(numeric_expression)",
    standard_argument(name = "numeric_expression", prefix = "Numeric"),
    sql_example = r#"```sql
> SELECT floor(3.14);
+-------------+
| floor(3.14) |
+-------------+
| 3.0 |
+-------------+
```"#
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct FloorFunc {
    /// Accepted argument types and volatility of the function.
    signature: Signature,
}
58
59impl Default for FloorFunc {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl FloorFunc {
66 pub fn new() -> Self {
67 let decimal_sig = Coercion::new_exact(TypeSignatureClass::Decimal);
68 Self {
69 signature: Signature::one_of(
70 vec![
71 TypeSignature::Coercible(vec![decimal_sig]),
72 TypeSignature::Uniform(1, vec![DataType::Float64, DataType::Float32]),
73 ],
74 Volatility::Immutable,
75 ),
76 }
77 }
78}
79
/// Calls the bounds helper matching the literal's kind and wraps the
/// resulting `(lower, upper)` pair back into `ScalarValue`s of the same
/// variant. Evaluates to `Option<(ScalarValue, ScalarValue)>`.
macro_rules! preimage_bounds {
    // Floating-point literals: delegate to `float_preimage_bounds`.
    (float: $variant:ident, $value:expr) => {{
        let bounds = float_preimage_bounds($value);
        bounds.map(|(lower, upper)| {
            (
                ScalarValue::$variant(Some(lower)),
                ScalarValue::$variant(Some(upper)),
            )
        })
    }};

    // Integer literals: delegate to `int_preimage_bounds`.
    (int: $variant:ident, $value:expr) => {{
        let bounds = int_preimage_bounds($value);
        bounds.map(|(lower, upper)| {
            (
                ScalarValue::$variant(Some(lower)),
                ScalarValue::$variant(Some(upper)),
            )
        })
    }};

    // Decimal literals: delegate to `decimal_preimage_bounds` and re-attach
    // the literal's precision and scale to both bounds.
    (decimal: $variant:ident, $decimal_type:ty, $value:expr, $precision:expr, $scale:expr) => {{
        let bounds = decimal_preimage_bounds::<$decimal_type>($value, $precision, $scale);
        bounds.map(|(lower, upper)| {
            (
                ScalarValue::$variant(Some(lower), $precision, $scale),
                ScalarValue::$variant(Some(upper), $precision, $scale),
            )
        })
    }};
}
115
116impl ScalarUDFImpl for FloorFunc {
117 fn name(&self) -> &str {
118 "floor"
119 }
120
121 fn signature(&self) -> &Signature {
122 &self.signature
123 }
124
125 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
126 match &arg_types[0] {
127 DataType::Null => Ok(DataType::Float64),
128 other => Ok(other.clone()),
129 }
130 }
131
132 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
133 let arg = &args.args[0];
134
135 if let ColumnarValue::Scalar(scalar) = arg {
137 match scalar {
138 ScalarValue::Float64(v) => {
139 return Ok(ColumnarValue::Scalar(ScalarValue::Float64(
140 v.map(f64::floor),
141 )));
142 }
143 ScalarValue::Float32(v) => {
144 return Ok(ColumnarValue::Scalar(ScalarValue::Float32(
145 v.map(f32::floor),
146 )));
147 }
148 ScalarValue::Null => {
149 return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
150 }
151 _ => {}
154 }
155 }
156
157 let is_scalar = matches!(arg, ColumnarValue::Scalar(_));
159
160 let value = arg.to_array(args.number_rows)?;
162
163 let result: ArrayRef = match value.data_type() {
164 DataType::Float64 => Arc::new(
165 value
166 .as_primitive::<Float64Type>()
167 .unary::<_, Float64Type>(f64::floor),
168 ),
169 DataType::Float32 => Arc::new(
170 value
171 .as_primitive::<Float32Type>()
172 .unary::<_, Float32Type>(f32::floor),
173 ),
174 DataType::Null => {
175 return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
176 }
177 DataType::Decimal32(precision, scale) => {
178 apply_decimal_op::<Decimal32Type, _>(
179 &value,
180 *precision,
181 *scale,
182 self.name(),
183 floor_decimal_value,
184 )?
185 }
186 DataType::Decimal64(precision, scale) => {
187 apply_decimal_op::<Decimal64Type, _>(
188 &value,
189 *precision,
190 *scale,
191 self.name(),
192 floor_decimal_value,
193 )?
194 }
195 DataType::Decimal128(precision, scale) => {
196 apply_decimal_op::<Decimal128Type, _>(
197 &value,
198 *precision,
199 *scale,
200 self.name(),
201 floor_decimal_value,
202 )?
203 }
204 DataType::Decimal256(precision, scale) => {
205 apply_decimal_op::<Decimal256Type, _>(
206 &value,
207 *precision,
208 *scale,
209 self.name(),
210 floor_decimal_value,
211 )?
212 }
213 other => {
214 return exec_err!(
215 "Unsupported data type {other:?} for function {}",
216 self.name()
217 );
218 }
219 };
220
221 if is_scalar {
223 ScalarValue::try_from_array(&result, 0).map(ColumnarValue::Scalar)
224 } else {
225 Ok(ColumnarValue::Array(result))
226 }
227 }
228
229 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
230 Ok(input[0].sort_properties)
231 }
232
233 fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
234 let data_type = inputs[0].data_type();
235 Interval::make_unbounded(&data_type)
236 }
237
238 fn preimage(
246 &self,
247 args: &[Expr],
248 lit_expr: &Expr,
249 _info: &SimplifyContext,
250 ) -> Result<PreimageResult> {
251 debug_assert!(args.len() == 1, "floor() takes exactly one argument");
253
254 let arg = args[0].clone();
255
256 let Expr::Literal(lit_value, _) = lit_expr else {
258 return Ok(PreimageResult::None);
259 };
260
261 let Some((lower, upper)) = (match lit_value {
263 ScalarValue::Float64(Some(n)) => preimage_bounds!(float: Float64, *n),
265 ScalarValue::Float32(Some(n)) => preimage_bounds!(float: Float32, *n),
266
267 ScalarValue::Int8(Some(n)) => preimage_bounds!(int: Int8, *n),
271 ScalarValue::Int16(Some(n)) => preimage_bounds!(int: Int16, *n),
272 ScalarValue::Int32(Some(n)) => preimage_bounds!(int: Int32, *n),
273 ScalarValue::Int64(Some(n)) => preimage_bounds!(int: Int64, *n),
274
275 ScalarValue::Decimal32(Some(n), precision, scale) => {
280 preimage_bounds!(decimal: Decimal32, Decimal32Type, *n, *precision, *scale)
281 }
282 ScalarValue::Decimal64(Some(n), precision, scale) => {
283 preimage_bounds!(decimal: Decimal64, Decimal64Type, *n, *precision, *scale)
284 }
285 ScalarValue::Decimal128(Some(n), precision, scale) => {
286 preimage_bounds!(decimal: Decimal128, Decimal128Type, *n, *precision, *scale)
287 }
288 ScalarValue::Decimal256(Some(n), precision, scale) => {
289 preimage_bounds!(decimal: Decimal256, Decimal256Type, *n, *precision, *scale)
290 }
291
292 _ => None,
294 }) else {
295 return Ok(PreimageResult::None);
296 };
297
298 Ok(PreimageResult::Range {
299 expr: arg,
300 interval: Box::new(Interval::try_new(lower, upper)?),
301 })
302 }
303
304 fn documentation(&self) -> Option<&Documentation> {
305 self.doc()
306 }
307}
308
309fn float_preimage_bounds<F: Float>(n: F) -> Option<(F, F)> {
318 let one = F::one();
319 if !n.is_finite() {
321 return None;
322 }
323 if n.fract() != F::zero() {
325 return None;
326 }
327 if n + one <= n {
329 return None;
330 }
331 Some((n, n + one))
332}
333
334fn int_preimage_bounds<I: CheckedAdd + One + Copy>(n: I) -> Option<(I, I)> {
338 let upper = n.checked_add(&I::one())?;
339 Some((n, upper))
340}
341
342fn decimal_preimage_bounds<D: DecimalType>(
348 value: D::Native,
349 precision: u8,
350 scale: i8,
351) -> Option<(D::Native, D::Native)>
352where
353 D::Native: DecimalCast + ArrowNativeTypeOp + std::ops::Rem<Output = D::Native>,
354{
355 let one_scaled: D::Native = rescale_decimal::<D, D>(
358 D::Native::ONE, 1, 0, precision, scale, )?;
364
365 if scale > 0 && value % one_scaled != D::Native::ZERO {
368 return None;
369 }
370
371 let upper = value.add_checked(one_scaled).ok()?;
377
378 Some((value, upper))
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_buffer::i256;
    use datafusion_expr::col;

    /// Evaluates the preimage of `floor(x) = literal` with default context.
    fn floor_preimage_of(literal: &ScalarValue) -> PreimageResult {
        let rhs = Expr::Literal(literal.clone(), None);
        let context = SimplifyContext::default();
        FloorFunc::new()
            .preimage(&[col("x")], &rhs, &context)
            .unwrap()
    }

    /// Asserts that the preimage is a range over `x` with the given bounds.
    fn assert_preimage_range(
        input: ScalarValue,
        expected_lower: ScalarValue,
        expected_upper: ScalarValue,
    ) {
        let PreimageResult::Range { expr, interval } = floor_preimage_of(&input) else {
            panic!("Expected Range, got None for input {input:?}")
        };
        assert_eq!(expr, col("x"));
        assert_eq!(interval.lower().clone(), expected_lower);
        assert_eq!(interval.upper().clone(), expected_upper);
    }

    /// Asserts that no preimage is produced for `input`.
    fn assert_preimage_none(input: ScalarValue) {
        let outcome = floor_preimage_of(&input);
        assert!(
            matches!(outcome, PreimageResult::None),
            "Expected None for input {input:?}"
        );
    }

    #[test]
    fn test_floor_preimage_valid_cases() {
        // (literal, expected lower, expected upper)
        let cases = [
            (
                ScalarValue::Float64(Some(100.0)),
                ScalarValue::Float64(Some(100.0)),
                ScalarValue::Float64(Some(101.0)),
            ),
            (
                ScalarValue::Float32(Some(50.0)),
                ScalarValue::Float32(Some(50.0)),
                ScalarValue::Float32(Some(51.0)),
            ),
            (
                ScalarValue::Int64(Some(42)),
                ScalarValue::Int64(Some(42)),
                ScalarValue::Int64(Some(43)),
            ),
            (
                ScalarValue::Int32(Some(100)),
                ScalarValue::Int32(Some(100)),
                ScalarValue::Int32(Some(101)),
            ),
            (
                ScalarValue::Float64(Some(-5.0)),
                ScalarValue::Float64(Some(-5.0)),
                ScalarValue::Float64(Some(-4.0)),
            ),
            (
                ScalarValue::Float64(Some(0.0)),
                ScalarValue::Float64(Some(0.0)),
                ScalarValue::Float64(Some(1.0)),
            ),
        ];

        for (input, lower, upper) in cases {
            assert_preimage_range(input, lower, upper);
        }
    }

    #[test]
    fn test_floor_preimage_non_integer_float() {
        for input in [
            ScalarValue::Float64(Some(1.3)),
            ScalarValue::Float64(Some(-2.5)),
            ScalarValue::Float32(Some(3.7)),
        ] {
            assert_preimage_none(input);
        }
    }

    #[test]
    fn test_floor_preimage_integer_overflow() {
        // The maximum of each integer width has no representable `N + 1`.
        for input in [
            ScalarValue::Int64(Some(i64::MAX)),
            ScalarValue::Int32(Some(i32::MAX)),
            ScalarValue::Int16(Some(i16::MAX)),
            ScalarValue::Int8(Some(i8::MAX)),
        ] {
            assert_preimage_none(input);
        }
    }

    #[test]
    fn test_floor_preimage_float_edge_cases() {
        for input in [
            ScalarValue::Float64(Some(f64::INFINITY)),
            ScalarValue::Float64(Some(f64::NEG_INFINITY)),
            ScalarValue::Float64(Some(f64::NAN)),
            ScalarValue::Float64(Some(f64::MAX)),
            ScalarValue::Float32(Some(f32::INFINITY)),
            ScalarValue::Float32(Some(f32::NEG_INFINITY)),
            ScalarValue::Float32(Some(f32::NAN)),
            ScalarValue::Float32(Some(f32::MAX)),
        ] {
            assert_preimage_none(input);
        }
    }

    #[test]
    fn test_floor_preimage_null_values() {
        for input in [
            ScalarValue::Float64(None),
            ScalarValue::Float32(None),
            ScalarValue::Int64(None),
        ] {
            assert_preimage_none(input);
        }
    }

    #[test]
    fn test_floor_preimage_decimal_valid_cases() {
        // Constructors fixing the precision used by each decimal width.
        let dec32 = |raw: i32, scale: i8| ScalarValue::Decimal32(Some(raw), 9, scale);
        let dec64 = |raw: i64| ScalarValue::Decimal64(Some(raw), 18, 2);
        let dec128 = |raw: i128| ScalarValue::Decimal128(Some(raw), 38, 2);
        let dec256 = |raw: i32| ScalarValue::Decimal256(Some(i256::from(raw)), 76, 2);

        // Decimal32 at scale 2: raw values are hundredths, so "+1" is +100.
        for (raw, raw_upper) in [(10000, 10100), (5000, 5100), (-500, -400), (0, 100)] {
            assert_preimage_range(dec32(raw, 2), dec32(raw, 2), dec32(raw_upper, 2));
        }

        // Decimal32 at scale 0: "+1" is +1.
        assert_preimage_range(dec32(42, 0), dec32(42, 0), dec32(43, 0));

        // Decimal64
        for (raw, raw_upper) in [(10000, 10100), (-500, -400), (0, 100)] {
            assert_preimage_range(dec64(raw), dec64(raw), dec64(raw_upper));
        }

        // Decimal128
        for (raw, raw_upper) in [(10000, 10100), (-500, -400), (0, 100)] {
            assert_preimage_range(dec128(raw), dec128(raw), dec128(raw_upper));
        }

        // Decimal256
        for (raw, raw_upper) in [(10000, 10100), (-500, -400), (0, 100)] {
            assert_preimage_range(dec256(raw), dec256(raw), dec256(raw_upper));
        }
    }

    #[test]
    fn test_floor_preimage_decimal_non_integer() {
        for input in [
            ScalarValue::Decimal32(Some(130), 9, 2),
            ScalarValue::Decimal32(Some(-250), 9, 2),
            ScalarValue::Decimal32(Some(370), 9, 2),
            ScalarValue::Decimal32(Some(1), 9, 2),
            ScalarValue::Decimal64(Some(130), 18, 2),
            ScalarValue::Decimal64(Some(-250), 18, 2),
            ScalarValue::Decimal128(Some(130), 38, 2),
            ScalarValue::Decimal128(Some(-250), 38, 2),
            ScalarValue::Decimal256(Some(i256::from(130)), 76, 2),
            ScalarValue::Decimal256(Some(i256::from(-250)), 76, 2),
            ScalarValue::Decimal32(Some(i32::MAX - 50), 10, 2),
            ScalarValue::Decimal64(Some(i64::MAX - 50), 19, 2),
        ] {
            assert_preimage_none(input);
        }
    }

    #[test]
    fn test_floor_preimage_decimal_overflow() {
        for input in [
            ScalarValue::Decimal32(Some(i32::MAX), 10, 0),
            ScalarValue::Decimal64(Some(i64::MAX), 19, 0),
        ] {
            assert_preimage_none(input);
        }
    }

    #[test]
    fn test_floor_preimage_decimal_edge_cases() {
        // Largest and smallest whole-number raw values used for (9, 2).
        let safe_max_aligned_32 = 999_999_900;
        let min_aligned_32 = -999_999_900;

        for raw in [safe_max_aligned_32, min_aligned_32] {
            assert_preimage_range(
                ScalarValue::Decimal32(Some(raw), 9, 2),
                ScalarValue::Decimal32(Some(raw), 9, 2),
                ScalarValue::Decimal32(Some(raw + 100), 9, 2),
            );
        }
    }

    #[test]
    fn test_floor_preimage_decimal_null() {
        for input in [
            ScalarValue::Decimal32(None, 9, 2),
            ScalarValue::Decimal64(None, 18, 2),
            ScalarValue::Decimal128(None, 38, 2),
            ScalarValue::Decimal256(None, 76, 2),
        ] {
            assert_preimage_none(input);
        }
    }
}