Skip to main content

datafusion_functions/math/
factorial.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{ArrayRef, AsArray, Int64Array};
19use std::any::Any;
20use std::sync::Arc;
21
22use arrow::datatypes::DataType::Int64;
23use arrow::datatypes::{DataType, Int64Type};
24
25use datafusion_common::{
26    Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
27};
28use datafusion_expr::{
29    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30    Volatility,
31};
32use datafusion_macros::user_doc;
33
34#[user_doc(
35    doc_section(label = "Math Functions"),
36    description = "Factorial. Returns 1 if value is less than 2.",
37    syntax_example = "factorial(numeric_expression)",
38    sql_example = r#"```sql
39> SELECT factorial(5);
40+---------------+
41| factorial(5)  |
42+---------------+
43| 120           |
44+---------------+
45```"#,
46    standard_argument(name = "numeric_expression", prefix = "Numeric")
47)]
48#[derive(Debug, PartialEq, Eq, Hash)]
49pub struct FactorialFunc {
50    signature: Signature,
51}
52
53impl Default for FactorialFunc {
54    fn default() -> Self {
55        FactorialFunc::new()
56    }
57}
58
59impl FactorialFunc {
60    pub fn new() -> Self {
61        Self {
62            signature: Signature::uniform(1, vec![Int64], Volatility::Immutable),
63        }
64    }
65}
66
67impl ScalarUDFImpl for FactorialFunc {
68    fn as_any(&self) -> &dyn Any {
69        self
70    }
71
72    fn name(&self) -> &str {
73        "factorial"
74    }
75
76    fn signature(&self) -> &Signature {
77        &self.signature
78    }
79
80    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
81        Ok(Int64)
82    }
83
84    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
85        let [arg] = take_function_args(self.name(), args.args)?;
86
87        match arg {
88            ColumnarValue::Scalar(scalar) => {
89                if scalar.is_null() {
90                    return Ok(ColumnarValue::Scalar(ScalarValue::Int64(None)));
91                }
92
93                match scalar {
94                    ScalarValue::Int64(Some(v)) => {
95                        let result = compute_factorial(v)?;
96                        Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result))))
97                    }
98                    _ => {
99                        internal_err!(
100                            "Unexpected data type {:?} for function factorial",
101                            scalar.data_type()
102                        )
103                    }
104                }
105            }
106            ColumnarValue::Array(array) => match array.data_type() {
107                Int64 => {
108                    let result: Int64Array = array
109                        .as_primitive::<Int64Type>()
110                        .try_unary(compute_factorial)?;
111                    Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
112                }
113                other => {
114                    internal_err!("Unexpected data type {other:?} for function factorial")
115                }
116            },
117        }
118    }
119
120    fn documentation(&self) -> Option<&Documentation> {
121        self.doc()
122    }
123}
124
/// Precomputed factorials, indexed by `n`: `FACTORIALS[n] == n!` for `0..=20`.
///
/// 20! is the last factorial that fits in an `i64`. If the return type
/// changes, this constant needs to be updated accordingly.
const FACTORIALS: [i64; 21] = [
    1,                         // 0!
    1,                         // 1!
    2,                         // 2!
    6,                         // 3!
    24,                        // 4!
    120,                       // 5!
    720,                       // 6!
    5_040,                     // 7!
    40_320,                    // 8!
    362_880,                   // 9!
    3_628_800,                 // 10!
    39_916_800,                // 11!
    479_001_600,               // 12!
    6_227_020_800,             // 13!
    87_178_291_200,            // 14!
    1_307_674_368_000,         // 15!
    20_922_789_888_000,        // 16!
    355_687_428_096_000,       // 17!
    6_402_373_705_728_000,     // 18!
    121_645_100_408_832_000,   // 19!
    2_432_902_008_176_640_000, // 20!
];
148
149fn compute_factorial(n: i64) -> Result<i64> {
150    if n < 0 {
151        Ok(1)
152    } else if n < FACTORIALS.len() as i64 {
153        Ok(FACTORIALS[n as usize])
154    } else {
155        exec_err!("Overflow happened on FACTORIAL({n})")
156    }
157}