Skip to main content

datafusion_spark/function/math/
factorial.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{Array, Int64Array};
22use arrow::datatypes::DataType;
23use arrow::datatypes::DataType::{Int32, Int64};
24use datafusion_common::cast::as_int32_array;
25use datafusion_common::{
26    DataFusionError, Result, ScalarValue, exec_err, utils::take_function_args,
27};
28use datafusion_expr::Signature;
29use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
30
31/// <https://spark.apache.org/docs/latest/api/sql/index.html#factorial>
32#[derive(Debug, PartialEq, Eq, Hash)]
33pub struct SparkFactorial {
34    signature: Signature,
35    aliases: Vec<String>,
36}
37
38impl Default for SparkFactorial {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl SparkFactorial {
45    pub fn new() -> Self {
46        Self {
47            signature: Signature::exact(vec![Int32], Volatility::Immutable),
48            aliases: vec![],
49        }
50    }
51}
52
53impl ScalarUDFImpl for SparkFactorial {
54    fn as_any(&self) -> &dyn Any {
55        self
56    }
57
58    fn name(&self) -> &str {
59        "factorial"
60    }
61
62    fn signature(&self) -> &Signature {
63        &self.signature
64    }
65
66    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
67        Ok(Int64)
68    }
69
70    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
71        spark_factorial(&args.args)
72    }
73
74    fn aliases(&self) -> &[String] {
75        &self.aliases
76    }
77}
78
/// Lookup table where entry `n` holds `n!`, for `n` in `0..=20`.
///
/// 20! is the largest factorial representable in an `i64` (21! overflows),
/// which is why the table stops at 21 entries. The table is computed during
/// constant evaluation, so an arithmetic overflow here would be a compile
/// error rather than a silently wrong value.
const FACTORIALS: [i64; 21] = {
    let mut table = [1_i64; 21];
    let mut n = 1;
    while n < 21 {
        table[n] = table[n - 1] * n as i64;
        n += 1;
    }
    table
};
102
103pub fn spark_factorial(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
104    let [arg] = take_function_args("factorial", args)?;
105
106    match arg {
107        ColumnarValue::Scalar(ScalarValue::Int32(value)) => {
108            let result = compute_factorial(*value);
109            Ok(ColumnarValue::Scalar(ScalarValue::Int64(result)))
110        }
111        ColumnarValue::Scalar(other) => {
112            exec_err!("`factorial` got an unexpected scalar type: {}", other)
113        }
114        ColumnarValue::Array(array) => match array.data_type() {
115            Int32 => {
116                let array = as_int32_array(array)?;
117
118                let result: Int64Array = array.iter().map(compute_factorial).collect();
119
120                Ok(ColumnarValue::Array(Arc::new(result)))
121            }
122            other => {
123                exec_err!("`factorial` got an unexpected argument type: {}", other)
124            }
125        },
126    }
127}
128
129#[inline]
130fn compute_factorial(num: Option<i32>) -> Option<i64> {
131    num.filter(|&v| (0..=20).contains(&v))
132        .map(|v| FACTORIALS[v as usize])
133}
134
#[cfg(test)]
mod test {
    use crate::function::math::factorial::spark_factorial;
    use arrow::array::{Int32Array, Int64Array};
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::as_int64_array;
    use datafusion_expr::ColumnarValue;
    use std::sync::Arc;

    /// Runs `factorial` on one Int32 scalar and returns the Int64 payload,
    /// panicking if the result is not an Int64 scalar.
    fn scalar_factorial(value: Option<i32>) -> Option<i64> {
        let args = ColumnarValue::Scalar(ScalarValue::Int32(value));
        match spark_factorial(&[args]).unwrap() {
            ColumnarValue::Scalar(ScalarValue::Int64(val)) => val,
            other => panic!("Expected Int64 scalar, got {other:?}"),
        }
    }

    #[test]
    fn test_spark_factorial_array() {
        let input = Int32Array::from(vec![
            Some(-1),
            Some(0),
            Some(1),
            Some(2),
            Some(4),
            Some(20),
            Some(21),
            None,
        ]);

        let args = ColumnarValue::Array(Arc::new(input));
        let result = spark_factorial(&[args]).unwrap();
        let result = match result {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array"),
        };

        let actual = as_int64_array(&result).unwrap();
        let expected = Int64Array::from(vec![
            None,
            Some(1),
            Some(1),
            Some(2),
            Some(24),
            Some(2432902008176640000),
            None,
            None,
        ]);

        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_spark_factorial_scalar() {
        assert_eq!(scalar_factorial(Some(5)), Some(120));
        assert_eq!(scalar_factorial(Some(0)), Some(1));
        assert_eq!(scalar_factorial(Some(20)), Some(2432902008176640000));
    }

    /// Out-of-range and null scalars must map to a null result, not an error.
    #[test]
    fn test_spark_factorial_scalar_out_of_range_or_null() {
        assert_eq!(scalar_factorial(Some(-1)), None);
        assert_eq!(scalar_factorial(Some(21)), None);
        assert_eq!(scalar_factorial(Some(i32::MIN)), None);
        assert_eq!(scalar_factorial(Some(i32::MAX)), None);
        assert_eq!(scalar_factorial(None), None);
    }

    /// Non-Int32 inputs hit the explicit error branches for scalars and arrays.
    #[test]
    fn test_spark_factorial_rejects_unexpected_types() {
        let scalar = ColumnarValue::Scalar(ScalarValue::Int64(Some(5)));
        assert!(spark_factorial(&[scalar]).is_err());

        let array = ColumnarValue::Array(Arc::new(Int64Array::from(vec![Some(5)])));
        assert!(spark_factorial(&[array]).is_err());
    }

    /// `factorial` takes exactly one argument.
    #[test]
    fn test_spark_factorial_rejects_wrong_arity() {
        assert!(spark_factorial(&[]).is_err());

        let arg = ColumnarValue::Scalar(ScalarValue::Int32(Some(3)));
        assert!(spark_factorial(&[arg.clone(), arg]).is_err());
    }
}