Skip to main content

datafusion_functions/math/
factorial.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{ArrayRef, AsArray, Int64Array};
19use std::sync::Arc;
20
21use arrow::datatypes::DataType::Int64;
22use arrow::datatypes::{DataType, Int64Type};
23
24use datafusion_common::{
25    Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
26};
27use datafusion_expr::{
28    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
29    Volatility,
30};
31use datafusion_macros::user_doc;
32
33#[user_doc(
34    doc_section(label = "Math Functions"),
35    description = "Factorial of a non-negative integer. Errors if the argument is negative or the result overflows.",
36    syntax_example = "factorial(numeric_expression)",
37    sql_example = r#"```sql
38> SELECT factorial(5);
39+---------------+
40| factorial(5)  |
41+---------------+
42| 120           |
43+---------------+
44```"#,
45    standard_argument(name = "numeric_expression", prefix = "Numeric")
46)]
47#[derive(Debug, PartialEq, Eq, Hash)]
48pub struct FactorialFunc {
49    signature: Signature,
50}
51
52impl Default for FactorialFunc {
53    fn default() -> Self {
54        FactorialFunc::new()
55    }
56}
57
58impl FactorialFunc {
59    pub fn new() -> Self {
60        Self {
61            signature: Signature::uniform(1, vec![Int64], Volatility::Immutable),
62        }
63    }
64}
65
66impl ScalarUDFImpl for FactorialFunc {
67    fn name(&self) -> &str {
68        "factorial"
69    }
70
71    fn signature(&self) -> &Signature {
72        &self.signature
73    }
74
75    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
76        Ok(Int64)
77    }
78
79    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
80        let [arg] = take_function_args(self.name(), args.args)?;
81
82        match arg {
83            ColumnarValue::Scalar(scalar) => {
84                if scalar.is_null() {
85                    return Ok(ColumnarValue::Scalar(ScalarValue::Int64(None)));
86                }
87
88                match scalar {
89                    ScalarValue::Int64(Some(v)) => {
90                        let result = compute_factorial(v)?;
91                        Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result))))
92                    }
93                    _ => {
94                        internal_err!(
95                            "Unexpected data type {:?} for function factorial",
96                            scalar.data_type()
97                        )
98                    }
99                }
100            }
101            ColumnarValue::Array(array) => match array.data_type() {
102                Int64 => {
103                    let result: Int64Array = array
104                        .as_primitive::<Int64Type>()
105                        .try_unary(compute_factorial)?;
106                    Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
107                }
108                other => {
109                    internal_err!("Unexpected data type {other:?} for function factorial")
110                }
111            },
112        }
113    }
114
115    fn documentation(&self) -> Option<&Documentation> {
116        self.doc()
117    }
118}
119
/// Lookup table with `FACTORIALS[n] == n!` for every `n` whose factorial fits
/// in an `i64` (`20!` is the largest; `21!` exceeds `i64::MAX`).
///
/// The table is computed at compile time. An overflowing multiplication here
/// aborts const evaluation, so the table can never silently hold a wrapped
/// value. If the function's return type changes, update the element type and
/// the length together.
const FACTORIALS: [i64; 21] = {
    let mut table = [1_i64; 21];
    // `for` is not usable in const context, hence the manual `while` loop.
    let mut n = 1;
    while n < table.len() {
        table[n] = table[n - 1] * n as i64;
        n += 1;
    }
    table
};
143
144fn compute_factorial(n: i64) -> Result<i64> {
145    if n < 0 {
146        exec_err!("factorial of a negative number is undefined")
147    } else if n < FACTORIALS.len() as i64 {
148        Ok(FACTORIALS[n as usize])
149    } else {
150        exec_err!("Overflow happened on FACTORIAL({n})")
151    }
152}