use std::any::Any;
use std::sync::Arc;
use arrow::array::{ArrayRef, AsArray};
use arrow::compute::{DecimalCast, rescale_decimal};
use arrow::datatypes::{
ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
Decimal256Type, DecimalType, Float32Type, Float64Type,
};
use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::preimage::PreimageResult;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl,
Signature, TypeSignature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;
use num_traits::{CheckedAdd, Float, One};
use super::decimal::{apply_decimal_op, floor_decimal_value};
#[user_doc(
doc_section(label = "Math Functions"),
description = "Returns the nearest integer less than or equal to a number.",
syntax_example = "floor(numeric_expression)",
standard_argument(name = "numeric_expression", prefix = "Numeric"),
sql_example = r#"```sql
> SELECT floor(3.14);
+-------------+
| floor(3.14) |
+-------------+
| 3.0 |
+-------------+
```"#
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct FloorFunc {
    /// Accepted argument types, built by `FloorFunc::new` (decimal or
    /// `Float64`/`Float32`, immutable volatility).
    signature: Signature,
}
impl Default for FloorFunc {
fn default() -> Self {
Self::new()
}
}
impl FloorFunc {
    /// Creates the `floor` scalar UDF.
    ///
    /// The signature accepts a single argument that is either a decimal
    /// (through an exact `Decimal`-class coercion) or one of
    /// `Float64`/`Float32`. The function is immutable: equal inputs always
    /// give equal outputs.
    pub fn new() -> Self {
        let accepted_forms = vec![
            TypeSignature::Coercible(vec![Coercion::new_exact(
                TypeSignatureClass::Decimal,
            )]),
            TypeSignature::Uniform(1, vec![DataType::Float64, DataType::Float32]),
        ];
        let signature = Signature::one_of(accepted_forms, Volatility::Immutable);
        Self { signature }
    }
}
/// Computes `floor` preimage bounds for a literal and wraps both ends in the
/// matching `ScalarValue` variant.
///
/// - `float: Variant, value` uses `float_preimage_bounds`.
/// - `int: Variant, value` uses `int_preimage_bounds`.
/// - `decimal: Variant, DecimalType, value, precision, scale` uses
///   `decimal_preimage_bounds` and carries `precision`/`scale` into both
///   bounds.
///
/// Evaluates to `Option<(ScalarValue, ScalarValue)>`.
macro_rules! preimage_bounds {
    (float: $variant:ident, $value:expr) => {
        preimage_bounds!(@scalar $variant, float_preimage_bounds($value))
    };
    (int: $variant:ident, $value:expr) => {
        preimage_bounds!(@scalar $variant, int_preimage_bounds($value))
    };
    (decimal: $variant:ident, $decimal_type:ty, $value:expr, $precision:expr, $scale:expr) => {
        decimal_preimage_bounds::<$decimal_type>($value, $precision, $scale).map(
            |(low_native, high_native)| {
                let low = ScalarValue::$variant(Some(low_native), $precision, $scale);
                let high = ScalarValue::$variant(Some(high_native), $precision, $scale);
                (low, high)
            },
        )
    };
    // Internal helper: wrap a raw `(lo, hi)` pair in a single-payload variant.
    (@scalar $variant:ident, $bounds:expr) => {
        $bounds.map(|(low_native, high_native)| {
            let low = ScalarValue::$variant(Some(low_native));
            let high = ScalarValue::$variant(Some(high_native));
            (low, high)
        })
    };
}
impl ScalarUDFImpl for FloorFunc {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "floor"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// `floor` keeps its argument's type; an untyped `NULL` argument is
    /// reported as `Float64`.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        match &arg_types[0] {
            DataType::Null => Ok(DataType::Float64),
            other => Ok(other.clone()),
        }
    }

    /// Applies `floor` to the single argument.
    ///
    /// - Float and `NULL` scalars are handled directly.
    /// - Arrays and decimal scalars go through the per-type array kernels.
    /// - A scalar input always yields a scalar output.
    ///
    /// # Errors
    /// Returns an execution error for unsupported input types, and propagates
    /// errors from the decimal kernel.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let arg = &args.args[0];

        // Fast path: float and NULL scalars need no array materialisation.
        if let ColumnarValue::Scalar(scalar) = arg {
            match scalar {
                ScalarValue::Float64(v) => {
                    return Ok(ColumnarValue::Scalar(ScalarValue::Float64(
                        v.map(f64::floor),
                    )));
                }
                ScalarValue::Float32(v) => {
                    return Ok(ColumnarValue::Scalar(ScalarValue::Float32(
                        v.map(f32::floor),
                    )));
                }
                ScalarValue::Null => {
                    return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
                }
                _ => {}
            }
        }

        let is_scalar = matches!(arg, ColumnarValue::Scalar(_));
        // A scalar's result does not depend on the batch size, and only index
        // 0 is read back below. A one-element array avoids flooring
        // `number_rows` identical copies. It also avoids reading index 0 of an
        // empty array when `number_rows == 0`. Array inputs ignore the size.
        let num_rows = if is_scalar { 1 } else { args.number_rows };
        let value = arg.to_array(num_rows)?;

        let result: ArrayRef = match value.data_type() {
            DataType::Float64 => Arc::new(
                value
                    .as_primitive::<Float64Type>()
                    .unary::<_, Float64Type>(f64::floor),
            ),
            DataType::Float32 => Arc::new(
                value
                    .as_primitive::<Float32Type>()
                    .unary::<_, Float32Type>(f32::floor),
            ),
            DataType::Null => {
                return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
            }
            DataType::Decimal32(precision, scale) => {
                apply_decimal_op::<Decimal32Type, _>(
                    &value,
                    *precision,
                    *scale,
                    self.name(),
                    floor_decimal_value,
                )?
            }
            DataType::Decimal64(precision, scale) => {
                apply_decimal_op::<Decimal64Type, _>(
                    &value,
                    *precision,
                    *scale,
                    self.name(),
                    floor_decimal_value,
                )?
            }
            DataType::Decimal128(precision, scale) => {
                apply_decimal_op::<Decimal128Type, _>(
                    &value,
                    *precision,
                    *scale,
                    self.name(),
                    floor_decimal_value,
                )?
            }
            DataType::Decimal256(precision, scale) => {
                apply_decimal_op::<Decimal256Type, _>(
                    &value,
                    *precision,
                    *scale,
                    self.name(),
                    floor_decimal_value,
                )?
            }
            other => {
                return exec_err!(
                    "Unsupported data type {other:?} for function {}",
                    self.name()
                );
            }
        };

        if is_scalar {
            ScalarValue::try_from_array(&result, 0).map(ColumnarValue::Scalar)
        } else {
            Ok(ColumnarValue::Array(result))
        }
    }

    /// `floor` is non-decreasing, so the output keeps the input's ordering.
    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
        Ok(input[0].sort_properties)
    }

    /// Conservatively reports an unbounded interval of the input's type.
    fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
        let data_type = inputs[0].data_type();
        Interval::make_unbounded(&data_type)
    }

    /// Rewrites `floor(arg) = lit` into a range on `arg`.
    ///
    /// `floor(x) == n` holds exactly when `n <= x < n + 1`. The returned
    /// bounds are `n` and `n + 1`, where "1" is one whole unit at the
    /// literal's scale for decimals. The caller presumably treats the upper
    /// bound as exclusive.
    ///
    /// Returns `PreimageResult::None` in these cases:
    /// - the literal is NULL or of an unsupported type;
    /// - the literal has a fractional part;
    /// - the upper bound cannot be computed.
    fn preimage(
        &self,
        args: &[Expr],
        lit_expr: &Expr,
        _info: &SimplifyContext,
    ) -> Result<PreimageResult> {
        debug_assert!(args.len() == 1, "floor() takes exactly one argument");

        let Expr::Literal(lit_value, _) = lit_expr else {
            return Ok(PreimageResult::None);
        };

        let Some((lower, upper)) = (match lit_value {
            ScalarValue::Float64(Some(n)) => preimage_bounds!(float: Float64, *n),
            ScalarValue::Float32(Some(n)) => preimage_bounds!(float: Float32, *n),
            ScalarValue::Int8(Some(n)) => preimage_bounds!(int: Int8, *n),
            ScalarValue::Int16(Some(n)) => preimage_bounds!(int: Int16, *n),
            ScalarValue::Int32(Some(n)) => preimage_bounds!(int: Int32, *n),
            ScalarValue::Int64(Some(n)) => preimage_bounds!(int: Int64, *n),
            ScalarValue::Decimal32(Some(n), precision, scale) => {
                preimage_bounds!(decimal: Decimal32, Decimal32Type, *n, *precision, *scale)
            }
            ScalarValue::Decimal64(Some(n), precision, scale) => {
                preimage_bounds!(decimal: Decimal64, Decimal64Type, *n, *precision, *scale)
            }
            ScalarValue::Decimal128(Some(n), precision, scale) => {
                preimage_bounds!(decimal: Decimal128, Decimal128Type, *n, *precision, *scale)
            }
            ScalarValue::Decimal256(Some(n), precision, scale) => {
                preimage_bounds!(decimal: Decimal256, Decimal256Type, *n, *precision, *scale)
            }
            _ => None,
        }) else {
            return Ok(PreimageResult::None);
        };

        // Clone the argument only once a range is actually produced.
        Ok(PreimageResult::Range {
            expr: args[0].clone(),
            interval: Box::new(Interval::try_new(lower, upper)?),
        })
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
/// Returns `(n, n + 1)`, the bounds of float inputs whose floor is `n`.
///
/// Returns `None` in these cases:
/// - `n` is NaN or infinite;
/// - `n` has a fractional part, so no input floors to it;
/// - `n` is so large that `n + 1` is not greater than `n`.
fn float_preimage_bounds<F: Float>(n: F) -> Option<(F, F)> {
    let is_whole = n.is_finite() && n.fract() == F::zero();
    if !is_whole {
        return None;
    }
    let next = n + F::one();
    // At large magnitudes adding one rounds back to `n`; that range would be empty.
    (next > n).then_some((n, next))
}
/// Returns `(n, n + 1)`, the bounds of integer inputs whose floor is `n`.
///
/// Returns `None` if `n + 1` overflows `I`.
fn int_preimage_bounds<I: CheckedAdd + One + Copy>(n: I) -> Option<(I, I)> {
    n.checked_add(&I::one()).map(|next| (n, next))
}
/// Computes the native-representation bounds of decimal inputs whose floor
/// equals `value`, for a decimal with the given `precision` and `scale`.
///
/// Returns `Some((value, value + step))`, where `step` encodes one whole
/// unit. For `scale > 0` the step is `10^scale`. For `scale <= 0` every
/// representable value is already an integer, so `floor` is the identity and
/// the step is a single native unit: only `value` itself floors to `value`.
///
/// Returns `None` in these cases:
/// - `rescale_decimal` rejects `precision`/`scale` for `D`, or they cannot
///   encode one whole unit;
/// - `value` has a non-zero fractional part, so no input floors to it;
/// - `value + step` overflows `D::Native`.
fn decimal_preimage_bounds<D: DecimalType>(
    value: D::Native,
    precision: u8,
    scale: i8,
) -> Option<(D::Native, D::Native)>
where
    D::Native: DecimalCast + ArrowNativeTypeOp + std::ops::Rem<Output = D::Native>,
{
    // Encode the integer 1 (precision 1, scale 0) at the target
    // precision/scale. This also validates `precision`/`scale` for `D`.
    let one_scaled: D::Native =
        rescale_decimal::<D, D>(D::Native::ONE, 1, 0, precision, scale)?;

    if scale <= 0 {
        // A non-positive scale means every value is an integer. Rescaling 1
        // to a negative scale cannot represent it exactly (it rounds to 0),
        // which would produce an empty `[value, value)` range. Step by one
        // native unit instead.
        let upper = value.add_checked(D::Native::ONE).ok()?;
        return Some((value, upper));
    }

    // `floor` only produces whole numbers, so a literal with a fractional
    // part has no preimage.
    if value % one_scaled != D::Native::ZERO {
        return None;
    }

    let upper = value.add_checked(one_scaled).ok()?;
    Some((value, upper))
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_buffer::i256;
    use datafusion_expr::col;

    /// Asserts that `floor(x) = input` produces a range on column `x` with the
    /// given lower and upper bounds.
    fn assert_preimage_range(
        input: ScalarValue,
        expected_lower: ScalarValue,
        expected_upper: ScalarValue,
    ) {
        let floor_func = FloorFunc::new();
        let args = vec![col("x")];
        let lit_expr = Expr::Literal(input.clone(), None);
        let info = SimplifyContext::default();
        let result = floor_func.preimage(&args, &lit_expr, &info).unwrap();
        match result {
            PreimageResult::Range { expr, interval } => {
                assert_eq!(expr, col("x"));
                assert_eq!(interval.lower().clone(), expected_lower);
                assert_eq!(interval.upper().clone(), expected_upper);
            }
            PreimageResult::None => {
                panic!("Expected Range, got None for input {input:?}")
            }
        }
    }

    /// Asserts that `floor(x) = input` produces no preimage.
    fn assert_preimage_none(input: ScalarValue) {
        let floor_func = FloorFunc::new();
        let args = vec![col("x")];
        let lit_expr = Expr::Literal(input.clone(), None);
        let info = SimplifyContext::default();
        let result = floor_func.preimage(&args, &lit_expr, &info).unwrap();
        assert!(
            matches!(result, PreimageResult::None),
            "Expected None for input {input:?}"
        );
    }

    #[test]
    fn test_floor_preimage_valid_cases() {
        assert_preimage_range(
            ScalarValue::Float64(Some(100.0)),
            ScalarValue::Float64(Some(100.0)),
            ScalarValue::Float64(Some(101.0)),
        );
        assert_preimage_range(
            ScalarValue::Float32(Some(50.0)),
            ScalarValue::Float32(Some(50.0)),
            ScalarValue::Float32(Some(51.0)),
        );
        assert_preimage_range(
            ScalarValue::Int64(Some(42)),
            ScalarValue::Int64(Some(42)),
            ScalarValue::Int64(Some(43)),
        );
        assert_preimage_range(
            ScalarValue::Int32(Some(100)),
            ScalarValue::Int32(Some(100)),
            ScalarValue::Int32(Some(101)),
        );
        assert_preimage_range(
            ScalarValue::Float64(Some(-5.0)),
            ScalarValue::Float64(Some(-5.0)),
            ScalarValue::Float64(Some(-4.0)),
        );
        assert_preimage_range(
            ScalarValue::Float64(Some(0.0)),
            ScalarValue::Float64(Some(0.0)),
            ScalarValue::Float64(Some(1.0)),
        );
    }

    #[test]
    fn test_floor_preimage_non_integer_float() {
        assert_preimage_none(ScalarValue::Float64(Some(1.3)));
        assert_preimage_none(ScalarValue::Float64(Some(-2.5)));
        assert_preimage_none(ScalarValue::Float32(Some(3.7)));
    }

    #[test]
    fn test_floor_preimage_integer_overflow() {
        assert_preimage_none(ScalarValue::Int64(Some(i64::MAX)));
        assert_preimage_none(ScalarValue::Int32(Some(i32::MAX)));
        assert_preimage_none(ScalarValue::Int16(Some(i16::MAX)));
        assert_preimage_none(ScalarValue::Int8(Some(i8::MAX)));
    }

    #[test]
    fn test_floor_preimage_float_edge_cases() {
        assert_preimage_none(ScalarValue::Float64(Some(f64::INFINITY)));
        assert_preimage_none(ScalarValue::Float64(Some(f64::NEG_INFINITY)));
        assert_preimage_none(ScalarValue::Float64(Some(f64::NAN)));
        assert_preimage_none(ScalarValue::Float64(Some(f64::MAX)));
        assert_preimage_none(ScalarValue::Float32(Some(f32::INFINITY)));
        assert_preimage_none(ScalarValue::Float32(Some(f32::NEG_INFINITY)));
        assert_preimage_none(ScalarValue::Float32(Some(f32::NAN)));
        assert_preimage_none(ScalarValue::Float32(Some(f32::MAX)));
    }

    #[test]
    fn test_floor_preimage_null_values() {
        assert_preimage_none(ScalarValue::Float64(None));
        assert_preimage_none(ScalarValue::Float32(None));
        assert_preimage_none(ScalarValue::Int64(None));
    }

    #[test]
    fn test_floor_preimage_decimal_valid_cases() {
        // Decimal32, scale 2: 100.00 -> [100.00, 101.00)
        assert_preimage_range(
            ScalarValue::Decimal32(Some(10000), 9, 2),
            ScalarValue::Decimal32(Some(10000), 9, 2),
            ScalarValue::Decimal32(Some(10100), 9, 2),
        );
        assert_preimage_range(
            ScalarValue::Decimal32(Some(5000), 9, 2),
            ScalarValue::Decimal32(Some(5000), 9, 2),
            ScalarValue::Decimal32(Some(5100), 9, 2),
        );
        assert_preimage_range(
            ScalarValue::Decimal32(Some(-500), 9, 2),
            ScalarValue::Decimal32(Some(-500), 9, 2),
            ScalarValue::Decimal32(Some(-400), 9, 2),
        );
        assert_preimage_range(
            ScalarValue::Decimal32(Some(0), 9, 2),
            ScalarValue::Decimal32(Some(0), 9, 2),
            ScalarValue::Decimal32(Some(100), 9, 2),
        );
        // Scale 0: the step is a single native unit.
        assert_preimage_range(
            ScalarValue::Decimal32(Some(42), 9, 0),
            ScalarValue::Decimal32(Some(42), 9, 0),
            ScalarValue::Decimal32(Some(43), 9, 0),
        );
        assert_preimage_range(
            ScalarValue::Decimal64(Some(10000), 18, 2),
            ScalarValue::Decimal64(Some(10000), 18, 2),
            ScalarValue::Decimal64(Some(10100), 18, 2),
        );
        assert_preimage_range(
            ScalarValue::Decimal64(Some(-500), 18, 2),
            ScalarValue::Decimal64(Some(-500), 18, 2),
            ScalarValue::Decimal64(Some(-400), 18, 2),
        );
        assert_preimage_range(
            ScalarValue::Decimal64(Some(0), 18, 2),
            ScalarValue::Decimal64(Some(0), 18, 2),
            ScalarValue::Decimal64(Some(100), 18, 2),
        );
        assert_preimage_range(
            ScalarValue::Decimal128(Some(10000), 38, 2),
            ScalarValue::Decimal128(Some(10000), 38, 2),
            ScalarValue::Decimal128(Some(10100), 38, 2),
        );
        assert_preimage_range(
            ScalarValue::Decimal128(Some(-500), 38, 2),
            ScalarValue::Decimal128(Some(-500), 38, 2),
            ScalarValue::Decimal128(Some(-400), 38, 2),
        );
        assert_preimage_range(
            ScalarValue::Decimal128(Some(0), 38, 2),
            ScalarValue::Decimal128(Some(0), 38, 2),
            ScalarValue::Decimal128(Some(100), 38, 2),
        );
        assert_preimage_range(
            ScalarValue::Decimal256(Some(i256::from(10000)), 76, 2),
            ScalarValue::Decimal256(Some(i256::from(10000)), 76, 2),
            ScalarValue::Decimal256(Some(i256::from(10100)), 76, 2),
        );
        assert_preimage_range(
            ScalarValue::Decimal256(Some(i256::from(-500)), 76, 2),
            ScalarValue::Decimal256(Some(i256::from(-500)), 76, 2),
            ScalarValue::Decimal256(Some(i256::from(-400)), 76, 2),
        );
        assert_preimage_range(
            ScalarValue::Decimal256(Some(i256::ZERO), 76, 2),
            ScalarValue::Decimal256(Some(i256::ZERO), 76, 2),
            ScalarValue::Decimal256(Some(i256::from(100)), 76, 2),
        );
    }

    #[test]
    fn test_floor_preimage_decimal_non_integer() {
        // Values with a fractional part at scale 2 (1.30, -2.50, 3.70, 0.01).
        assert_preimage_none(ScalarValue::Decimal32(Some(130), 9, 2));
        assert_preimage_none(ScalarValue::Decimal32(Some(-250), 9, 2));
        assert_preimage_none(ScalarValue::Decimal32(Some(370), 9, 2));
        assert_preimage_none(ScalarValue::Decimal32(Some(1), 9, 2));
        assert_preimage_none(ScalarValue::Decimal64(Some(130), 18, 2));
        assert_preimage_none(ScalarValue::Decimal64(Some(-250), 18, 2));
        assert_preimage_none(ScalarValue::Decimal128(Some(130), 38, 2));
        assert_preimage_none(ScalarValue::Decimal128(Some(-250), 38, 2));
        assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(130)), 76, 2));
        assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(-250)), 76, 2));
        // Near the native maximum with a valid precision, so the fractional
        // check (not precision validation) is what rejects these.
        assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX - 50), 9, 2));
        assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX - 50), 18, 2));
    }

    #[test]
    fn test_floor_preimage_decimal_invalid_precision() {
        // Precision exceeds the type's maximum (9 for Decimal32, 18 for
        // Decimal64), so these are rejected before any arithmetic happens.
        assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX - 50), 10, 2));
        assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX - 50), 19, 2));
        assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX), 10, 0));
        assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX), 19, 0));
    }

    #[test]
    fn test_floor_preimage_decimal_overflow() {
        // Valid precision/scale, but adding one whole unit overflows the
        // native integer.
        assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX), 9, 0));
        assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX), 18, 0));
        assert_preimage_none(ScalarValue::Decimal128(Some(i128::MAX), 38, 0));
        assert_preimage_none(ScalarValue::Decimal256(Some(i256::MAX), 76, 0));
        // 21474836.00 is a whole number, but + 1.00 overflows i32.
        assert_preimage_none(ScalarValue::Decimal32(Some(2_147_483_600), 9, 2));
    }

    #[test]
    fn test_floor_preimage_decimal_edge_cases() {
        let safe_max_aligned_32 = 999_999_900;
        assert_preimage_range(
            ScalarValue::Decimal32(Some(safe_max_aligned_32), 9, 2),
            ScalarValue::Decimal32(Some(safe_max_aligned_32), 9, 2),
            ScalarValue::Decimal32(Some(safe_max_aligned_32 + 100), 9, 2),
        );
        let min_aligned_32 = -999_999_900;
        assert_preimage_range(
            ScalarValue::Decimal32(Some(min_aligned_32), 9, 2),
            ScalarValue::Decimal32(Some(min_aligned_32), 9, 2),
            ScalarValue::Decimal32(Some(min_aligned_32 + 100), 9, 2),
        );
    }

    #[test]
    fn test_floor_preimage_decimal_null() {
        assert_preimage_none(ScalarValue::Decimal32(None, 9, 2));
        assert_preimage_none(ScalarValue::Decimal64(None, 18, 2));
        assert_preimage_none(ScalarValue::Decimal128(None, 38, 2));
        assert_preimage_none(ScalarValue::Decimal256(None, 76, 2));
    }
}