use std::any::Any;
use std::sync::Arc;
use arrow::array::{ArrayRef, AsArray};
use arrow::datatypes::{
DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float32Type,
Float64Type,
};
use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
TypeSignature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;
use super::decimal::{apply_decimal_op, floor_decimal_value};
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Returns the nearest integer less than or equal to a number.",
    syntax_example = "floor(numeric_expression)",
    standard_argument(name = "numeric_expression", prefix = "Numeric"),
    sql_example = r#"```sql
> SELECT floor(3.14);
+-------------+
| floor(3.14) |
+-------------+
| 3.0         |
+-------------+
```"#
)]
/// Scalar UDF implementing SQL `floor(numeric_expression)`.
///
/// Holds only the function's [`Signature`]: one argument, either a decimal
/// or a `Float64`/`Float32` value (see [`FloorFunc::new`]). The result keeps
/// the argument's data type.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct FloorFunc {
    /// Accepted argument types and volatility reported to the planner.
    signature: Signature,
}
impl Default for FloorFunc {
fn default() -> Self {
Self::new()
}
}
impl FloorFunc {
    /// Creates the `floor` UDF.
    ///
    /// The signature accepts one of two argument shapes:
    /// - a single decimal, matched via `Coercion::new_exact` on
    ///   `TypeSignatureClass::Decimal`;
    /// - a single `Float64` or `Float32` value (`TypeSignature::Uniform(1, ..)`).
    ///
    /// The function is marked `Volatility::Immutable`.
    pub fn new() -> Self {
        let decimal_only = TypeSignature::Coercible(vec![Coercion::new_exact(
            TypeSignatureClass::Decimal,
        )]);
        let float_only = TypeSignature::Uniform(1, vec![DataType::Float64, DataType::Float32]);

        let signature = Signature::one_of(vec![decimal_only, float_only], Volatility::Immutable);
        Self { signature }
    }
}
impl ScalarUDFImpl for FloorFunc {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"floor"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
match &arg_types[0] {
DataType::Null => Ok(DataType::Float64),
other => Ok(other.clone()),
}
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let value = &args[0];
let result: ArrayRef = match value.data_type() {
DataType::Float64 => Arc::new(
value
.as_primitive::<Float64Type>()
.unary::<_, Float64Type>(f64::floor),
),
DataType::Float32 => Arc::new(
value
.as_primitive::<Float32Type>()
.unary::<_, Float32Type>(f32::floor),
),
DataType::Null => {
return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
}
DataType::Decimal32(precision, scale) => {
apply_decimal_op::<Decimal32Type, _>(
value,
*precision,
*scale,
self.name(),
floor_decimal_value,
)?
}
DataType::Decimal64(precision, scale) => {
apply_decimal_op::<Decimal64Type, _>(
value,
*precision,
*scale,
self.name(),
floor_decimal_value,
)?
}
DataType::Decimal128(precision, scale) => {
apply_decimal_op::<Decimal128Type, _>(
value,
*precision,
*scale,
self.name(),
floor_decimal_value,
)?
}
DataType::Decimal256(precision, scale) => {
apply_decimal_op::<Decimal256Type, _>(
value,
*precision,
*scale,
self.name(),
floor_decimal_value,
)?
}
other => {
return exec_err!(
"Unsupported data type {other:?} for function {}",
self.name()
);
}
};
Ok(ColumnarValue::Array(result))
}
fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
Ok(input[0].sort_properties)
}
fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
let data_type = inputs[0].data_type();
Interval::make_unbounded(&data_type)
}
fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}