1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, AsArray};
22use arrow::compute::{DecimalCast, rescale_decimal};
23use arrow::datatypes::{
24 ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
25 Decimal256Type, DecimalType, Float32Type, Float64Type,
26};
27use datafusion_common::{Result, ScalarValue, exec_err};
28use datafusion_expr::interval_arithmetic::Interval;
29use datafusion_expr::preimage::PreimageResult;
30use datafusion_expr::simplify::SimplifyContext;
31use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
32use datafusion_expr::{
33 Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl,
34 Signature, TypeSignature, TypeSignatureClass, Volatility,
35};
36use datafusion_macros::user_doc;
37use num_traits::{CheckedAdd, Float, One};
38
39use super::decimal::{apply_decimal_op, floor_decimal_value};
40
41#[user_doc(
42 doc_section(label = "Math Functions"),
43 description = "Returns the nearest integer less than or equal to a number.",
44 syntax_example = "floor(numeric_expression)",
45 standard_argument(name = "numeric_expression", prefix = "Numeric"),
46 sql_example = r#"```sql
47> SELECT floor(3.14);
48+-------------+
49| floor(3.14) |
50+-------------+
51| 3.0 |
52+-------------+
53```"#
54)]
55#[derive(Debug, PartialEq, Eq, Hash)]
56pub struct FloorFunc {
57 signature: Signature,
58}
59
60impl Default for FloorFunc {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl FloorFunc {
67 pub fn new() -> Self {
68 let decimal_sig = Coercion::new_exact(TypeSignatureClass::Decimal);
69 Self {
70 signature: Signature::one_of(
71 vec![
72 TypeSignature::Coercible(vec![decimal_sig]),
73 TypeSignature::Uniform(1, vec![DataType::Float64, DataType::Float32]),
74 ],
75 Volatility::Immutable,
76 ),
77 }
78 }
79}
80
/// Computes preimage bounds for a literal and wraps both ends back into the
/// `ScalarValue` variant the literal came from.
///
/// Each arm evaluates to `Option<(ScalarValue, ScalarValue)>`:
/// * `float:`   delegates to `float_preimage_bounds`
/// * `int:`     delegates to `int_preimage_bounds`
/// * `decimal:` delegates to `decimal_preimage_bounds::<$decimal_type>` and
///   re-attaches the literal's precision and scale to both bounds
macro_rules! preimage_bounds {
    (float: $variant:ident, $value:expr) => {{
        let wrap = |bound| ScalarValue::$variant(Some(bound));
        float_preimage_bounds($value).map(|(lower, upper)| (wrap(lower), wrap(upper)))
    }};

    (int: $variant:ident, $value:expr) => {{
        let wrap = |bound| ScalarValue::$variant(Some(bound));
        int_preimage_bounds($value).map(|(lower, upper)| (wrap(lower), wrap(upper)))
    }};

    (decimal: $variant:ident, $decimal_type:ty, $value:expr, $precision:expr, $scale:expr) => {{
        // Evaluate precision/scale once and reuse them for the call and both bounds.
        let (p, s) = ($precision, $scale);
        let wrap = |bound| ScalarValue::$variant(Some(bound), p, s);
        decimal_preimage_bounds::<$decimal_type>($value, p, s)
            .map(|(lower, upper)| (wrap(lower), wrap(upper)))
    }};
}
116
117impl ScalarUDFImpl for FloorFunc {
118 fn as_any(&self) -> &dyn Any {
119 self
120 }
121
122 fn name(&self) -> &str {
123 "floor"
124 }
125
126 fn signature(&self) -> &Signature {
127 &self.signature
128 }
129
130 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
131 match &arg_types[0] {
132 DataType::Null => Ok(DataType::Float64),
133 other => Ok(other.clone()),
134 }
135 }
136
137 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
138 let arg = &args.args[0];
139
140 if let ColumnarValue::Scalar(scalar) = arg {
142 match scalar {
143 ScalarValue::Float64(v) => {
144 return Ok(ColumnarValue::Scalar(ScalarValue::Float64(
145 v.map(f64::floor),
146 )));
147 }
148 ScalarValue::Float32(v) => {
149 return Ok(ColumnarValue::Scalar(ScalarValue::Float32(
150 v.map(f32::floor),
151 )));
152 }
153 ScalarValue::Null => {
154 return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
155 }
156 _ => {}
159 }
160 }
161
162 let is_scalar = matches!(arg, ColumnarValue::Scalar(_));
164
165 let value = arg.to_array(args.number_rows)?;
167
168 let result: ArrayRef = match value.data_type() {
169 DataType::Float64 => Arc::new(
170 value
171 .as_primitive::<Float64Type>()
172 .unary::<_, Float64Type>(f64::floor),
173 ),
174 DataType::Float32 => Arc::new(
175 value
176 .as_primitive::<Float32Type>()
177 .unary::<_, Float32Type>(f32::floor),
178 ),
179 DataType::Null => {
180 return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
181 }
182 DataType::Decimal32(precision, scale) => {
183 apply_decimal_op::<Decimal32Type, _>(
184 &value,
185 *precision,
186 *scale,
187 self.name(),
188 floor_decimal_value,
189 )?
190 }
191 DataType::Decimal64(precision, scale) => {
192 apply_decimal_op::<Decimal64Type, _>(
193 &value,
194 *precision,
195 *scale,
196 self.name(),
197 floor_decimal_value,
198 )?
199 }
200 DataType::Decimal128(precision, scale) => {
201 apply_decimal_op::<Decimal128Type, _>(
202 &value,
203 *precision,
204 *scale,
205 self.name(),
206 floor_decimal_value,
207 )?
208 }
209 DataType::Decimal256(precision, scale) => {
210 apply_decimal_op::<Decimal256Type, _>(
211 &value,
212 *precision,
213 *scale,
214 self.name(),
215 floor_decimal_value,
216 )?
217 }
218 other => {
219 return exec_err!(
220 "Unsupported data type {other:?} for function {}",
221 self.name()
222 );
223 }
224 };
225
226 if is_scalar {
228 ScalarValue::try_from_array(&result, 0).map(ColumnarValue::Scalar)
229 } else {
230 Ok(ColumnarValue::Array(result))
231 }
232 }
233
234 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
235 Ok(input[0].sort_properties)
236 }
237
238 fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
239 let data_type = inputs[0].data_type();
240 Interval::make_unbounded(&data_type)
241 }
242
243 fn preimage(
251 &self,
252 args: &[Expr],
253 lit_expr: &Expr,
254 _info: &SimplifyContext,
255 ) -> Result<PreimageResult> {
256 debug_assert!(args.len() == 1, "floor() takes exactly one argument");
258
259 let arg = args[0].clone();
260
261 let Expr::Literal(lit_value, _) = lit_expr else {
263 return Ok(PreimageResult::None);
264 };
265
266 let Some((lower, upper)) = (match lit_value {
268 ScalarValue::Float64(Some(n)) => preimage_bounds!(float: Float64, *n),
270 ScalarValue::Float32(Some(n)) => preimage_bounds!(float: Float32, *n),
271
272 ScalarValue::Int8(Some(n)) => preimage_bounds!(int: Int8, *n),
276 ScalarValue::Int16(Some(n)) => preimage_bounds!(int: Int16, *n),
277 ScalarValue::Int32(Some(n)) => preimage_bounds!(int: Int32, *n),
278 ScalarValue::Int64(Some(n)) => preimage_bounds!(int: Int64, *n),
279
280 ScalarValue::Decimal32(Some(n), precision, scale) => {
285 preimage_bounds!(decimal: Decimal32, Decimal32Type, *n, *precision, *scale)
286 }
287 ScalarValue::Decimal64(Some(n), precision, scale) => {
288 preimage_bounds!(decimal: Decimal64, Decimal64Type, *n, *precision, *scale)
289 }
290 ScalarValue::Decimal128(Some(n), precision, scale) => {
291 preimage_bounds!(decimal: Decimal128, Decimal128Type, *n, *precision, *scale)
292 }
293 ScalarValue::Decimal256(Some(n), precision, scale) => {
294 preimage_bounds!(decimal: Decimal256, Decimal256Type, *n, *precision, *scale)
295 }
296
297 _ => None,
299 }) else {
300 return Ok(PreimageResult::None);
301 };
302
303 Ok(PreimageResult::Range {
304 expr: arg,
305 interval: Box::new(Interval::try_new(lower, upper)?),
306 })
307 }
308
309 fn documentation(&self) -> Option<&Documentation> {
310 self.doc()
311 }
312}
313
314fn float_preimage_bounds<F: Float>(n: F) -> Option<(F, F)> {
323 let one = F::one();
324 if !n.is_finite() {
326 return None;
327 }
328 if n.fract() != F::zero() {
330 return None;
331 }
332 if n + one <= n {
334 return None;
335 }
336 Some((n, n + one))
337}
338
339fn int_preimage_bounds<I: CheckedAdd + One + Copy>(n: I) -> Option<(I, I)> {
343 let upper = n.checked_add(&I::one())?;
344 Some((n, upper))
345}
346
347fn decimal_preimage_bounds<D: DecimalType>(
353 value: D::Native,
354 precision: u8,
355 scale: i8,
356) -> Option<(D::Native, D::Native)>
357where
358 D::Native: DecimalCast + ArrowNativeTypeOp + std::ops::Rem<Output = D::Native>,
359{
360 let one_scaled: D::Native = rescale_decimal::<D, D>(
363 D::Native::ONE, 1, 0, precision, scale, )?;
369
370 if scale > 0 && value % one_scaled != D::Native::ZERO {
373 return None;
374 }
375
376 let upper = value.add_checked(one_scaled).ok()?;
382
383 Some((value, upper))
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_buffer::i256;
    use datafusion_expr::col;

    /// Evaluates the preimage of `floor(x) = input` with a default context.
    fn floor_preimage(input: &ScalarValue) -> PreimageResult {
        let literal = Expr::Literal(input.clone(), None);
        FloorFunc::new()
            .preimage(&[col("x")], &literal, &SimplifyContext::default())
            .unwrap()
    }

    /// Asserts that the preimage is a range over `x` with the given bounds.
    fn assert_preimage_range(
        input: ScalarValue,
        expected_lower: ScalarValue,
        expected_upper: ScalarValue,
    ) {
        let PreimageResult::Range { expr, interval } = floor_preimage(&input) else {
            panic!("Expected Range, got None for input {input:?}")
        };

        assert_eq!(expr, col("x"));
        assert_eq!(interval.lower().clone(), expected_lower);
        assert_eq!(interval.upper().clone(), expected_upper);
    }

    /// Asserts that no preimage is produced for `input`.
    fn assert_preimage_none(input: ScalarValue) {
        let is_none = matches!(floor_preimage(&input), PreimageResult::None);
        assert!(is_none, "Expected None for input {input:?}");
    }

    /// Checks `(literal, upper)` pairs; the expected lower bound is the literal itself.
    fn assert_ranges<const N: usize>(cases: [(ScalarValue, ScalarValue); N]) {
        for (literal, upper) in cases {
            assert_preimage_range(literal.clone(), literal, upper);
        }
    }

    /// Checks that every literal yields no preimage.
    fn assert_all_none<const N: usize>(literals: [ScalarValue; N]) {
        for literal in literals {
            assert_preimage_none(literal);
        }
    }

    #[test]
    fn test_floor_preimage_valid_cases() {
        assert_ranges([
            (
                ScalarValue::Float64(Some(100.0)),
                ScalarValue::Float64(Some(101.0)),
            ),
            (
                ScalarValue::Float32(Some(50.0)),
                ScalarValue::Float32(Some(51.0)),
            ),
            (ScalarValue::Int64(Some(42)), ScalarValue::Int64(Some(43))),
            (ScalarValue::Int32(Some(100)), ScalarValue::Int32(Some(101))),
            (
                ScalarValue::Float64(Some(-5.0)),
                ScalarValue::Float64(Some(-4.0)),
            ),
            (
                ScalarValue::Float64(Some(0.0)),
                ScalarValue::Float64(Some(1.0)),
            ),
        ]);
    }

    #[test]
    fn test_floor_preimage_non_integer_float() {
        assert_all_none([
            ScalarValue::Float64(Some(1.3)),
            ScalarValue::Float64(Some(-2.5)),
            ScalarValue::Float32(Some(3.7)),
        ]);
    }

    #[test]
    fn test_floor_preimage_integer_overflow() {
        assert_all_none([
            ScalarValue::Int64(Some(i64::MAX)),
            ScalarValue::Int32(Some(i32::MAX)),
            ScalarValue::Int16(Some(i16::MAX)),
            ScalarValue::Int8(Some(i8::MAX)),
        ]);
    }

    #[test]
    fn test_floor_preimage_float_edge_cases() {
        assert_all_none([
            ScalarValue::Float64(Some(f64::INFINITY)),
            ScalarValue::Float64(Some(f64::NEG_INFINITY)),
            ScalarValue::Float64(Some(f64::NAN)),
            ScalarValue::Float64(Some(f64::MAX)),
            ScalarValue::Float32(Some(f32::INFINITY)),
            ScalarValue::Float32(Some(f32::NEG_INFINITY)),
            ScalarValue::Float32(Some(f32::NAN)),
            ScalarValue::Float32(Some(f32::MAX)),
        ]);
    }

    #[test]
    fn test_floor_preimage_null_values() {
        assert_all_none([
            ScalarValue::Float64(None),
            ScalarValue::Float32(None),
            ScalarValue::Int64(None),
        ]);
    }

    #[test]
    fn test_floor_preimage_decimal_valid_cases() {
        // Scale-2 literals: the raw value 10000 is 100.00, and one unit is 100.
        let d32 = |raw: i32| ScalarValue::Decimal32(Some(raw), 9, 2);
        let d64 = |raw: i64| ScalarValue::Decimal64(Some(raw), 18, 2);
        let d128 = |raw: i128| ScalarValue::Decimal128(Some(raw), 38, 2);
        let d256 = |raw: i32| ScalarValue::Decimal256(Some(i256::from(raw)), 76, 2);

        assert_ranges([
            (d32(10000), d32(10100)),
            (d32(5000), d32(5100)),
            (d32(-500), d32(-400)),
            (d32(0), d32(100)),
        ]);

        // Scale 0: one unit is 1.
        assert_ranges([(
            ScalarValue::Decimal32(Some(42), 9, 0),
            ScalarValue::Decimal32(Some(43), 9, 0),
        )]);

        assert_ranges([
            (d64(10000), d64(10100)),
            (d64(-500), d64(-400)),
            (d64(0), d64(100)),
        ]);

        assert_ranges([
            (d128(10000), d128(10100)),
            (d128(-500), d128(-400)),
            (d128(0), d128(100)),
        ]);

        assert_ranges([
            (d256(10000), d256(10100)),
            (d256(-500), d256(-400)),
            (ScalarValue::Decimal256(Some(i256::ZERO), 76, 2), d256(100)),
        ]);
    }

    #[test]
    fn test_floor_preimage_decimal_non_integer() {
        assert_all_none([
            ScalarValue::Decimal32(Some(130), 9, 2),
            ScalarValue::Decimal32(Some(-250), 9, 2),
            ScalarValue::Decimal32(Some(370), 9, 2),
            ScalarValue::Decimal32(Some(1), 9, 2),
            ScalarValue::Decimal64(Some(130), 18, 2),
            ScalarValue::Decimal64(Some(-250), 18, 2),
            ScalarValue::Decimal128(Some(130), 38, 2),
            ScalarValue::Decimal128(Some(-250), 38, 2),
            ScalarValue::Decimal256(Some(i256::from(130)), 76, 2),
            ScalarValue::Decimal256(Some(i256::from(-250)), 76, 2),
            ScalarValue::Decimal32(Some(i32::MAX - 50), 10, 2),
            ScalarValue::Decimal64(Some(i64::MAX - 50), 19, 2),
        ]);
    }

    #[test]
    fn test_floor_preimage_decimal_overflow() {
        assert_all_none([
            ScalarValue::Decimal32(Some(i32::MAX), 10, 0),
            ScalarValue::Decimal64(Some(i64::MAX), 19, 0),
        ]);
    }

    #[test]
    fn test_floor_preimage_decimal_edge_cases() {
        let d32 = |raw: i32| ScalarValue::Decimal32(Some(raw), 9, 2);

        let safe_max_aligned_32 = 999_999_900;
        let min_aligned_32 = -999_999_900;

        assert_ranges([
            (d32(safe_max_aligned_32), d32(safe_max_aligned_32 + 100)),
            (d32(min_aligned_32), d32(min_aligned_32 + 100)),
        ]);
    }

    #[test]
    fn test_floor_preimage_decimal_null() {
        assert_all_none([
            ScalarValue::Decimal32(None, 9, 2),
            ScalarValue::Decimal64(None, 18, 2),
            ScalarValue::Decimal128(None, 38, 2),
            ScalarValue::Decimal256(None, 76, 2),
        ]);
    }
}