datafusion_spark/function/math/
factorial.rs1use std::sync::Arc;
19
20use arrow::array::{Array, Int64Array};
21use arrow::datatypes::DataType;
22use arrow::datatypes::DataType::{Int32, Int64};
23use datafusion_common::cast::as_int32_array;
24use datafusion_common::{
25 DataFusionError, Result, ScalarValue, exec_err, utils::take_function_args,
26};
27use datafusion_expr::Signature;
28use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
29
30#[derive(Debug, PartialEq, Eq, Hash)]
32pub struct SparkFactorial {
33 signature: Signature,
34 aliases: Vec<String>,
35}
36
37impl Default for SparkFactorial {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl SparkFactorial {
44 pub fn new() -> Self {
45 Self {
46 signature: Signature::exact(vec![Int32], Volatility::Immutable),
47 aliases: vec![],
48 }
49 }
50}
51
52impl ScalarUDFImpl for SparkFactorial {
53 fn name(&self) -> &str {
54 "factorial"
55 }
56
57 fn signature(&self) -> &Signature {
58 &self.signature
59 }
60
61 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
62 Ok(Int64)
63 }
64
65 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
66 spark_factorial(&args.args)
67 }
68
69 fn aliases(&self) -> &[String] {
70 &self.aliases
71 }
72}
73
/// Lookup table where entry `n` holds `n!` for `n` in `0..=20`.
///
/// `20!` is the largest factorial representable in an `i64` (`21!` exceeds
/// `i64::MAX`), so the table stops there. It is filled at compile time from
/// the recurrence `n! = n * (n - 1)!`; any overflow would be a compile error.
const FACTORIALS: [i64; 21] = {
    let mut table = [1_i64; 21];
    let mut n = 1;
    while n < 21 {
        table[n] = table[n - 1] * n as i64;
        n += 1;
    }
    table
};
97
98pub fn spark_factorial(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
99 let [arg] = take_function_args("factorial", args)?;
100
101 match arg {
102 ColumnarValue::Scalar(ScalarValue::Int32(value)) => {
103 let result = compute_factorial(*value);
104 Ok(ColumnarValue::Scalar(ScalarValue::Int64(result)))
105 }
106 ColumnarValue::Scalar(other) => {
107 exec_err!("`factorial` got an unexpected scalar type: {}", other)
108 }
109 ColumnarValue::Array(array) => match array.data_type() {
110 Int32 => {
111 let array = as_int32_array(array)?;
112
113 let result: Int64Array = array.iter().map(compute_factorial).collect();
114
115 Ok(ColumnarValue::Array(Arc::new(result)))
116 }
117 other => {
118 exec_err!("`factorial` got an unexpected argument type: {}", other)
119 }
120 },
121 }
122}
123
124#[inline]
125fn compute_factorial(num: Option<i32>) -> Option<i64> {
126 num.filter(|&v| (0..=20).contains(&v))
127 .map(|v| FACTORIALS[v as usize])
128}
129
#[cfg(test)]
mod test {
    use crate::function::math::factorial::spark_factorial;
    use arrow::array::{Int32Array, Int64Array};
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::as_int64_array;
    use datafusion_expr::ColumnarValue;
    use std::sync::Arc;

    /// Runs `spark_factorial` on one scalar and unwraps the `Int64` payload.
    /// Panics if the call errors or returns anything other than an
    /// `Int64` scalar.
    fn eval_scalar(input: ScalarValue) -> Option<i64> {
        match spark_factorial(&[ColumnarValue::Scalar(input)]).unwrap() {
            ColumnarValue::Scalar(ScalarValue::Int64(val)) => val,
            other => panic!("Expected Int64 scalar, got {other:?}"),
        }
    }

    #[test]
    fn test_spark_factorial_array() {
        let input = Int32Array::from(vec![
            Some(-1),
            Some(0),
            Some(1),
            Some(2),
            Some(4),
            Some(20),
            Some(21),
            None,
        ]);

        let args = ColumnarValue::Array(Arc::new(input));
        let result = spark_factorial(&[args]).unwrap();
        let result = match result {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array"),
        };

        let actual = as_int64_array(&result).unwrap();
        let expected = Int64Array::from(vec![
            None,
            Some(1),
            Some(1),
            Some(2),
            Some(24),
            Some(2432902008176640000),
            None,
            None,
        ]);

        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_spark_factorial_scalar() {
        assert_eq!(eval_scalar(ScalarValue::Int32(Some(5))), Some(120));
    }

    #[test]
    fn test_spark_factorial_scalar_boundaries_and_null() {
        // NULL in, NULL out.
        assert_eq!(eval_scalar(ScalarValue::Int32(None)), None);
        // Inputs outside 0..=20 also yield NULL.
        assert_eq!(eval_scalar(ScalarValue::Int32(Some(-1))), None);
        assert_eq!(eval_scalar(ScalarValue::Int32(Some(21))), None);
        assert_eq!(eval_scalar(ScalarValue::Int32(Some(i32::MIN))), None);
        assert_eq!(eval_scalar(ScalarValue::Int32(Some(i32::MAX))), None);
        // Both ends of the supported range.
        assert_eq!(eval_scalar(ScalarValue::Int32(Some(0))), Some(1));
        assert_eq!(
            eval_scalar(ScalarValue::Int32(Some(20))),
            Some(2432902008176640000)
        );
    }

    #[test]
    fn test_spark_factorial_rejects_unexpected_types() {
        let scalar = ColumnarValue::Scalar(ScalarValue::Int64(Some(5)));
        assert!(spark_factorial(&[scalar]).is_err());

        let array = ColumnarValue::Array(Arc::new(Int64Array::from(vec![Some(5)])));
        assert!(spark_factorial(&[array]).is_err());
    }

    #[test]
    fn test_spark_factorial_rejects_wrong_arg_count() {
        assert!(spark_factorial(&[]).is_err());

        let a = ColumnarValue::Scalar(ScalarValue::Int32(Some(1)));
        let b = ColumnarValue::Scalar(ScalarValue::Int32(Some(2)));
        assert!(spark_factorial(&[a, b]).is_err());
    }
}