Skip to main content

datafusion_functions/math/
ceil.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{ArrayRef, AsArray};
21use arrow::datatypes::{
22    DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float32Type,
23    Float64Type,
24};
25use datafusion_common::{Result, ScalarValue, exec_err};
26use datafusion_expr::interval_arithmetic::Interval;
27use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
28use datafusion_expr::{
29    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30    TypeSignature, TypeSignatureClass, Volatility,
31};
32use datafusion_macros::user_doc;
33
34use super::decimal::{apply_decimal_op, ceil_decimal_value};
35
// The `ceil` scalar UDF.
//
// The `user_doc` attribute below carries the user-facing description, syntax and
// SQL example. `ScalarUDFImpl::documentation` returns `self.doc()`, which is not
// defined in this file and is presumably generated by this attribute.
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Returns the nearest integer greater than or equal to a number.",
    syntax_example = "ceil(numeric_expression)",
    standard_argument(name = "numeric_expression", prefix = "Numeric"),
    sql_example = r#"```sql
> SELECT ceil(3.14);
+------------+
| ceil(3.14) |
+------------+
| 4.0        |
+------------+
```"#
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct CeilFunc {
    // Accepted argument types; built once in `CeilFunc::new` and returned by
    // `ScalarUDFImpl::signature`.
    signature: Signature,
}
54
55impl Default for CeilFunc {
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61impl CeilFunc {
62    pub fn new() -> Self {
63        let decimal_sig = Coercion::new_exact(TypeSignatureClass::Decimal);
64        Self {
65            signature: Signature::one_of(
66                vec![
67                    TypeSignature::Coercible(vec![decimal_sig]),
68                    TypeSignature::Uniform(1, vec![DataType::Float64, DataType::Float32]),
69                ],
70                Volatility::Immutable,
71            ),
72        }
73    }
74}
75
76impl ScalarUDFImpl for CeilFunc {
77    fn name(&self) -> &str {
78        "ceil"
79    }
80
81    fn signature(&self) -> &Signature {
82        &self.signature
83    }
84
85    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
86        match &arg_types[0] {
87            DataType::Null => Ok(DataType::Float64),
88            other => Ok(other.clone()),
89        }
90    }
91
92    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
93        let arg = &args.args[0];
94
95        // Scalar fast path for float types - avoid array conversion overhead entirely
96        if let ColumnarValue::Scalar(scalar) = arg {
97            match scalar {
98                ScalarValue::Float64(v) => {
99                    return Ok(ColumnarValue::Scalar(ScalarValue::Float64(
100                        v.map(f64::ceil),
101                    )));
102                }
103                ScalarValue::Float32(v) => {
104                    return Ok(ColumnarValue::Scalar(ScalarValue::Float32(
105                        v.map(f32::ceil),
106                    )));
107                }
108                ScalarValue::Null => {
109                    return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
110                }
111                // For decimals: convert to array of size 1, process, then extract scalar
112                // This ensures we don't expand the array while reusing overflow validation
113                _ => {}
114            }
115        }
116
117        // Track if input was a scalar to convert back at the end
118        let is_scalar = matches!(arg, ColumnarValue::Scalar(_));
119
120        // Array path (also handles decimal scalars converted to size-1 arrays)
121        let value = arg.to_array(args.number_rows)?;
122
123        let result: ArrayRef = match value.data_type() {
124            DataType::Float64 => Arc::new(
125                value
126                    .as_primitive::<Float64Type>()
127                    .unary::<_, Float64Type>(f64::ceil),
128            ),
129            DataType::Float32 => Arc::new(
130                value
131                    .as_primitive::<Float32Type>()
132                    .unary::<_, Float32Type>(f32::ceil),
133            ),
134            DataType::Null => {
135                return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
136            }
137            DataType::Decimal32(precision, scale) => {
138                apply_decimal_op::<Decimal32Type, _>(
139                    &value,
140                    *precision,
141                    *scale,
142                    self.name(),
143                    ceil_decimal_value,
144                )?
145            }
146            DataType::Decimal64(precision, scale) => {
147                apply_decimal_op::<Decimal64Type, _>(
148                    &value,
149                    *precision,
150                    *scale,
151                    self.name(),
152                    ceil_decimal_value,
153                )?
154            }
155            DataType::Decimal128(precision, scale) => {
156                apply_decimal_op::<Decimal128Type, _>(
157                    &value,
158                    *precision,
159                    *scale,
160                    self.name(),
161                    ceil_decimal_value,
162                )?
163            }
164            DataType::Decimal256(precision, scale) => {
165                apply_decimal_op::<Decimal256Type, _>(
166                    &value,
167                    *precision,
168                    *scale,
169                    self.name(),
170                    ceil_decimal_value,
171                )?
172            }
173            other => {
174                return exec_err!(
175                    "Unsupported data type {other:?} for function {}",
176                    self.name()
177                );
178            }
179        };
180
181        // If input was a scalar, convert result back to scalar
182        if is_scalar {
183            ScalarValue::try_from_array(&result, 0).map(ColumnarValue::Scalar)
184        } else {
185            Ok(ColumnarValue::Array(result))
186        }
187    }
188
189    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
190        Ok(input[0].sort_properties)
191    }
192
193    fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
194        let data_type = inputs[0].data_type();
195        Interval::make_unbounded(&data_type)
196    }
197
198    fn documentation(&self) -> Option<&Documentation> {
199        self.doc()
200    }
201}