Skip to main content

datafusion_functions/math/
ceil.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, AsArray};
22use arrow::datatypes::{
23    DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float32Type,
24    Float64Type,
25};
26use datafusion_common::{Result, ScalarValue, exec_err};
27use datafusion_expr::interval_arithmetic::Interval;
28use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
29use datafusion_expr::{
30    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
31    TypeSignature, TypeSignatureClass, Volatility,
32};
33use datafusion_macros::user_doc;
34
35use super::decimal::{apply_decimal_op, ceil_decimal_value};
36
// Scalar UDF implementing SQL `ceil` for float and decimal inputs.
// The `user_doc` attribute holds the user-facing documentation; the
// `ScalarUDFImpl::documentation` impl returns it through `self.doc()`
// (presumably generated by the macro — confirm in `datafusion_macros`).
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Returns the nearest integer greater than or equal to a number.",
    syntax_example = "ceil(numeric_expression)",
    standard_argument(name = "numeric_expression", prefix = "Numeric"),
    sql_example = r#"```sql
> SELECT ceil(3.14);
+------------+
| ceil(3.14) |
+------------+
| 4.0        |
+------------+
```"#
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct CeilFunc {
    // Accepted argument types (one decimal, or one Float64/Float32),
    // constructed in `CeilFunc::new`.
    signature: Signature,
}
55
56impl Default for CeilFunc {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl CeilFunc {
63    pub fn new() -> Self {
64        let decimal_sig = Coercion::new_exact(TypeSignatureClass::Decimal);
65        Self {
66            signature: Signature::one_of(
67                vec![
68                    TypeSignature::Coercible(vec![decimal_sig]),
69                    TypeSignature::Uniform(1, vec![DataType::Float64, DataType::Float32]),
70                ],
71                Volatility::Immutable,
72            ),
73        }
74    }
75}
76
77impl ScalarUDFImpl for CeilFunc {
78    fn as_any(&self) -> &dyn Any {
79        self
80    }
81
82    fn name(&self) -> &str {
83        "ceil"
84    }
85
86    fn signature(&self) -> &Signature {
87        &self.signature
88    }
89
90    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
91        match &arg_types[0] {
92            DataType::Null => Ok(DataType::Float64),
93            other => Ok(other.clone()),
94        }
95    }
96
97    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
98        let arg = &args.args[0];
99
100        // Scalar fast path for float types - avoid array conversion overhead entirely
101        if let ColumnarValue::Scalar(scalar) = arg {
102            match scalar {
103                ScalarValue::Float64(v) => {
104                    return Ok(ColumnarValue::Scalar(ScalarValue::Float64(
105                        v.map(f64::ceil),
106                    )));
107                }
108                ScalarValue::Float32(v) => {
109                    return Ok(ColumnarValue::Scalar(ScalarValue::Float32(
110                        v.map(f32::ceil),
111                    )));
112                }
113                ScalarValue::Null => {
114                    return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
115                }
116                // For decimals: convert to array of size 1, process, then extract scalar
117                // This ensures we don't expand the array while reusing overflow validation
118                _ => {}
119            }
120        }
121
122        // Track if input was a scalar to convert back at the end
123        let is_scalar = matches!(arg, ColumnarValue::Scalar(_));
124
125        // Array path (also handles decimal scalars converted to size-1 arrays)
126        let value = arg.to_array(args.number_rows)?;
127
128        let result: ArrayRef = match value.data_type() {
129            DataType::Float64 => Arc::new(
130                value
131                    .as_primitive::<Float64Type>()
132                    .unary::<_, Float64Type>(f64::ceil),
133            ),
134            DataType::Float32 => Arc::new(
135                value
136                    .as_primitive::<Float32Type>()
137                    .unary::<_, Float32Type>(f32::ceil),
138            ),
139            DataType::Null => {
140                return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
141            }
142            DataType::Decimal32(precision, scale) => {
143                apply_decimal_op::<Decimal32Type, _>(
144                    &value,
145                    *precision,
146                    *scale,
147                    self.name(),
148                    ceil_decimal_value,
149                )?
150            }
151            DataType::Decimal64(precision, scale) => {
152                apply_decimal_op::<Decimal64Type, _>(
153                    &value,
154                    *precision,
155                    *scale,
156                    self.name(),
157                    ceil_decimal_value,
158                )?
159            }
160            DataType::Decimal128(precision, scale) => {
161                apply_decimal_op::<Decimal128Type, _>(
162                    &value,
163                    *precision,
164                    *scale,
165                    self.name(),
166                    ceil_decimal_value,
167                )?
168            }
169            DataType::Decimal256(precision, scale) => {
170                apply_decimal_op::<Decimal256Type, _>(
171                    &value,
172                    *precision,
173                    *scale,
174                    self.name(),
175                    ceil_decimal_value,
176                )?
177            }
178            other => {
179                return exec_err!(
180                    "Unsupported data type {other:?} for function {}",
181                    self.name()
182                );
183            }
184        };
185
186        // If input was a scalar, convert result back to scalar
187        if is_scalar {
188            ScalarValue::try_from_array(&result, 0).map(ColumnarValue::Scalar)
189        } else {
190            Ok(ColumnarValue::Array(result))
191        }
192    }
193
194    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
195        Ok(input[0].sort_properties)
196    }
197
198    fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
199        let data_type = inputs[0].data_type();
200        Interval::make_unbounded(&data_type)
201    }
202
203    fn documentation(&self) -> Option<&Documentation> {
204        self.doc()
205    }
206}