datafusion_spark/function/math/factorial.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{Array, Int64Array};
22use arrow::datatypes::DataType;
23use arrow::datatypes::DataType::{Int32, Int64};
24use datafusion_common::cast::as_int32_array;
25use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
26use datafusion_expr::Signature;
27use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
28
/// Spark-compatible `factorial(expr)` scalar UDF.
///
/// Accepts exactly one `Int32` argument and returns an `Int64`; the
/// per-value semantics live in `spark_factorial`.
///
/// <https://spark.apache.org/docs/latest/api/sql/index.html#factorial>
#[derive(Debug)]
pub struct SparkFactorial {
    // Accepted argument types; built in `new` as an exact `Int32` signature.
    signature: Signature,
    // Alternative names for the function; `new` leaves this empty.
    aliases: Vec<String>,
}
35
36impl Default for SparkFactorial {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42impl SparkFactorial {
43    pub fn new() -> Self {
44        Self {
45            signature: Signature::exact(vec![Int32], Volatility::Immutable),
46            aliases: vec![],
47        }
48    }
49}
50
51impl ScalarUDFImpl for SparkFactorial {
52    fn as_any(&self) -> &dyn Any {
53        self
54    }
55
56    fn name(&self) -> &str {
57        "factorial"
58    }
59
60    fn signature(&self) -> &Signature {
61        &self.signature
62    }
63
64    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
65        Ok(Int64)
66    }
67
68    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
69        spark_factorial(&args.args)
70    }
71
72    fn aliases(&self) -> &[String] {
73        &self.aliases
74    }
75}
76
/// Lookup table where `FACTORIALS[n] == n!` for `n` in `0..=20`.
///
/// `20!` is the largest factorial that fits in an `i64`, so the table has 21
/// entries. It is filled at compile time; an overflow here would be a
/// compile-time error rather than a silent wrap.
const FACTORIALS: [i64; 21] = {
    let mut table = [1_i64; 21];
    let mut n = 1;
    while n < table.len() {
        table[n] = table[n - 1] * n as i64;
        n += 1;
    }
    table
};
100
101pub fn spark_factorial(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
102    if args.len() != 1 {
103        return Err(DataFusionError::Internal(
104            "`factorial` expects exactly one argument".to_string(),
105        ));
106    }
107
108    match &args[0] {
109        ColumnarValue::Scalar(ScalarValue::Int32(value)) => {
110            let result = compute_factorial(*value);
111            Ok(ColumnarValue::Scalar(ScalarValue::Int64(result)))
112        }
113        ColumnarValue::Scalar(other) => {
114            exec_err!("`factorial` got an unexpected scalar type: {:?}", other)
115        }
116        ColumnarValue::Array(array) => match array.data_type() {
117            Int32 => {
118                let array = as_int32_array(array)?;
119
120                let result: Int64Array = array.iter().map(compute_factorial).collect();
121
122                Ok(ColumnarValue::Array(Arc::new(result)))
123            }
124            other => {
125                exec_err!("`factorial` got an unexpected argument type: {:?}", other)
126            }
127        },
128    }
129}
130
131#[inline]
132fn compute_factorial(num: Option<i32>) -> Option<i64> {
133    num.filter(|&v| (0..=20).contains(&v))
134        .map(|v| FACTORIALS[v as usize])
135}
136
#[cfg(test)]
mod test {
    use crate::function::math::factorial::spark_factorial;
    use arrow::array::{Int32Array, Int64Array};
    use datafusion_common::cast::as_int64_array;
    use datafusion_common::ScalarValue;
    use datafusion_expr::ColumnarValue;
    use std::sync::Arc;

    /// Invokes `spark_factorial` on one `Int32` scalar and unwraps the
    /// `Int64` scalar it must return.
    fn scalar_factorial(value: Option<i32>) -> Option<i64> {
        let args = [ColumnarValue::Scalar(ScalarValue::Int32(value))];
        match spark_factorial(&args).unwrap() {
            ColumnarValue::Scalar(ScalarValue::Int64(result)) => result,
            other => panic!("Expected Int64 scalar, got {other:?}"),
        }
    }

    #[test]
    fn test_spark_factorial_array() {
        let input = Int32Array::from(vec![
            Some(-1),
            Some(0),
            Some(1),
            Some(2),
            Some(4),
            Some(20),
            Some(21),
            None,
        ]);

        let args = ColumnarValue::Array(Arc::new(input));
        let result = spark_factorial(&[args]).unwrap();
        let result = match result {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array"),
        };

        let actual = as_int64_array(&result).unwrap();
        let expected = Int64Array::from(vec![
            None,
            Some(1),
            Some(1),
            Some(2),
            Some(24),
            Some(2432902008176640000),
            None,
            None,
        ]);

        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_spark_factorial_scalar() {
        assert_eq!(scalar_factorial(Some(5)), Some(120));
    }

    #[test]
    fn test_spark_factorial_scalar_bounds_and_null() {
        assert_eq!(scalar_factorial(Some(0)), Some(1));
        assert_eq!(scalar_factorial(Some(20)), Some(2432902008176640000));
        // Outside [0, 20] the result is NULL rather than an error.
        assert_eq!(scalar_factorial(Some(-1)), None);
        assert_eq!(scalar_factorial(Some(21)), None);
        assert_eq!(scalar_factorial(None), None);
    }

    #[test]
    fn test_spark_factorial_rejects_wrong_arity() {
        assert!(spark_factorial(&[]).is_err());
        let one = ColumnarValue::Scalar(ScalarValue::Int32(Some(1)));
        assert!(spark_factorial(&[one.clone(), one]).is_err());
    }

    #[test]
    fn test_spark_factorial_rejects_non_int32_input() {
        let scalar = ColumnarValue::Scalar(ScalarValue::Int64(Some(3)));
        assert!(spark_factorial(&[scalar]).is_err());
        let array = ColumnarValue::Array(Arc::new(Int64Array::from(vec![3_i64])));
        assert!(spark_factorial(&[array]).is_err());
    }
}