datafusion_spark/function/math/
factorial.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{Array, Int64Array};
22use arrow::datatypes::DataType;
23use arrow::datatypes::DataType::{Int32, Int64};
24use datafusion_common::cast::as_int32_array;
25use datafusion_common::{
26 DataFusionError, Result, ScalarValue, exec_err, utils::take_function_args,
27};
28use datafusion_expr::Signature;
29use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
30
31#[derive(Debug, PartialEq, Eq, Hash)]
33pub struct SparkFactorial {
34 signature: Signature,
35 aliases: Vec<String>,
36}
37
38impl Default for SparkFactorial {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl SparkFactorial {
45 pub fn new() -> Self {
46 Self {
47 signature: Signature::exact(vec![Int32], Volatility::Immutable),
48 aliases: vec![],
49 }
50 }
51}
52
53impl ScalarUDFImpl for SparkFactorial {
54 fn as_any(&self) -> &dyn Any {
55 self
56 }
57
58 fn name(&self) -> &str {
59 "factorial"
60 }
61
62 fn signature(&self) -> &Signature {
63 &self.signature
64 }
65
66 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
67 Ok(Int64)
68 }
69
70 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
71 spark_factorial(&args.args)
72 }
73
74 fn aliases(&self) -> &[String] {
75 &self.aliases
76 }
77}
78
/// Lookup table in which entry `n` holds `n!`, for `n` in `0..=20`.
///
/// The table is evaluated at compile time. An arithmetic overflow while building it
/// would be rejected by the compiler rather than silently wrapping.
const FACTORIALS: [i64; 21] = {
    // Seeding every slot with 1 already yields the correct value for 0!.
    let mut table = [1_i64; 21];
    let mut n = 1;
    while n < table.len() {
        // n! = (n - 1)! * n
        table[n] = table[n - 1] * n as i64;
        n += 1;
    }
    table
};
102
103pub fn spark_factorial(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
104 let [arg] = take_function_args("factorial", args)?;
105
106 match arg {
107 ColumnarValue::Scalar(ScalarValue::Int32(value)) => {
108 let result = compute_factorial(*value);
109 Ok(ColumnarValue::Scalar(ScalarValue::Int64(result)))
110 }
111 ColumnarValue::Scalar(other) => {
112 exec_err!("`factorial` got an unexpected scalar type: {}", other)
113 }
114 ColumnarValue::Array(array) => match array.data_type() {
115 Int32 => {
116 let array = as_int32_array(array)?;
117
118 let result: Int64Array = array.iter().map(compute_factorial).collect();
119
120 Ok(ColumnarValue::Array(Arc::new(result)))
121 }
122 other => {
123 exec_err!("`factorial` got an unexpected argument type: {}", other)
124 }
125 },
126 }
127}
128
129#[inline]
130fn compute_factorial(num: Option<i32>) -> Option<i64> {
131 num.filter(|&v| (0..=20).contains(&v))
132 .map(|v| FACTORIALS[v as usize])
133}
134
#[cfg(test)]
mod test {
    use crate::function::math::factorial::spark_factorial;
    use arrow::array::{Int32Array, Int64Array};
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::as_int64_array;
    use datafusion_expr::ColumnarValue;
    use std::sync::Arc;

    /// Array input: negative values, values above 20 and nulls map to null, while
    /// values in `0..=20` map to their factorial.
    #[test]
    fn test_spark_factorial_array() {
        // (input, expected output) pairs, kept side by side for readability.
        let cases: Vec<(Option<i32>, Option<i64>)> = vec![
            (Some(-1), None),
            (Some(0), Some(1)),
            (Some(1), Some(1)),
            (Some(2), Some(2)),
            (Some(4), Some(24)),
            (Some(20), Some(2432902008176640000)),
            (Some(21), None),
            (None, None),
        ];
        let inputs: Vec<Option<i32>> = cases.iter().map(|&(arg, _)| arg).collect();
        let outputs: Vec<Option<i64>> = cases.iter().map(|&(_, want)| want).collect();

        let args = ColumnarValue::Array(Arc::new(Int32Array::from(inputs)));
        let ColumnarValue::Array(result) = spark_factorial(&[args]).unwrap() else {
            panic!("Expected array");
        };

        let actual = as_int64_array(&result).unwrap();
        let expected = Int64Array::from(outputs);

        assert_eq!(actual, &expected);
    }

    /// Scalar input: `factorial(5)` evaluates to the `Int64` scalar 120.
    #[test]
    fn test_spark_factorial_scalar() {
        let args = ColumnarValue::Scalar(ScalarValue::Int32(Some(5)));

        let ColumnarValue::Scalar(ScalarValue::Int64(actual)) =
            spark_factorial(&[args]).unwrap()
        else {
            panic!("Expected scalar");
        };

        assert_eq!(actual, Some(120_i64));
    }
}