datafusion_spark/function/math/
factorial.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{Array, Int64Array};
22use arrow::datatypes::DataType;
23use arrow::datatypes::DataType::{Int32, Int64};
24use datafusion_common::cast::as_int32_array;
25use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue};
26use datafusion_expr::Signature;
27use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
28
29#[derive(Debug, PartialEq, Eq, Hash)]
31pub struct SparkFactorial {
32 signature: Signature,
33 aliases: Vec<String>,
34}
35
36impl Default for SparkFactorial {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42impl SparkFactorial {
43 pub fn new() -> Self {
44 Self {
45 signature: Signature::exact(vec![Int32], Volatility::Immutable),
46 aliases: vec![],
47 }
48 }
49}
50
51impl ScalarUDFImpl for SparkFactorial {
52 fn as_any(&self) -> &dyn Any {
53 self
54 }
55
56 fn name(&self) -> &str {
57 "factorial"
58 }
59
60 fn signature(&self) -> &Signature {
61 &self.signature
62 }
63
64 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
65 Ok(Int64)
66 }
67
68 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
69 spark_factorial(&args.args)
70 }
71
72 fn aliases(&self) -> &[String] {
73 &self.aliases
74 }
75}
76
/// Lookup table where entry `n` holds `n!`, for `n` in `0..=20`.
///
/// `20!` is the largest factorial representable in an `i64`, which is why the
/// table stops at 21 entries. The table is computed at compile time; any
/// overflow would be a compile error rather than a silent wrap.
const FACTORIALS: [i64; 21] = {
    let mut table = [1_i64; 21];
    let mut n = 1;
    while n < table.len() {
        table[n] = table[n - 1] * n as i64;
        n += 1;
    }
    table
};
100
101pub fn spark_factorial(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
102 if args.len() != 1 {
103 return internal_err!("`factorial` expects exactly one argument");
104 }
105
106 match &args[0] {
107 ColumnarValue::Scalar(ScalarValue::Int32(value)) => {
108 let result = compute_factorial(*value);
109 Ok(ColumnarValue::Scalar(ScalarValue::Int64(result)))
110 }
111 ColumnarValue::Scalar(other) => {
112 exec_err!("`factorial` got an unexpected scalar type: {}", other)
113 }
114 ColumnarValue::Array(array) => match array.data_type() {
115 Int32 => {
116 let array = as_int32_array(array)?;
117
118 let result: Int64Array = array.iter().map(compute_factorial).collect();
119
120 Ok(ColumnarValue::Array(Arc::new(result)))
121 }
122 other => {
123 exec_err!("`factorial` got an unexpected argument type: {}", other)
124 }
125 },
126 }
127}
128
129#[inline]
130fn compute_factorial(num: Option<i32>) -> Option<i64> {
131 num.filter(|&v| (0..=20).contains(&v))
132 .map(|v| FACTORIALS[v as usize])
133}
134
#[cfg(test)]
mod test {
    use crate::function::math::factorial::spark_factorial;
    use arrow::array::{Int32Array, Int64Array};
    use datafusion_common::cast::as_int64_array;
    use datafusion_common::ScalarValue;
    use datafusion_expr::ColumnarValue;
    use std::sync::Arc;

    /// Runs `factorial` on one `Int32` scalar and returns the `Int64` payload.
    fn factorial_of_scalar(input: Option<i32>) -> Option<i64> {
        let args = [ColumnarValue::Scalar(ScalarValue::Int32(input))];
        match spark_factorial(&args).unwrap() {
            ColumnarValue::Scalar(ScalarValue::Int64(value)) => value,
            _ => panic!("Expected Int64 scalar"),
        }
    }

    #[test]
    fn test_spark_factorial_array() {
        let input = Int32Array::from(vec![
            Some(-1),
            Some(0),
            Some(1),
            Some(2),
            Some(4),
            Some(20),
            Some(21),
            None,
        ]);

        let args = ColumnarValue::Array(Arc::new(input));
        let result = spark_factorial(&[args]).unwrap();
        let result = match result {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array"),
        };

        let actual = as_int64_array(&result).unwrap();
        let expected = Int64Array::from(vec![
            None,
            Some(1),
            Some(1),
            Some(2),
            Some(24),
            Some(2432902008176640000),
            None,
            None,
        ]);

        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_spark_factorial_scalar() {
        assert_eq!(factorial_of_scalar(Some(5)), Some(120));
    }

    #[test]
    fn test_spark_factorial_scalar_null_and_out_of_range() {
        // Nulls propagate, and inputs outside 0..=20 yield null rather than an error.
        assert_eq!(factorial_of_scalar(None), None);
        assert_eq!(factorial_of_scalar(Some(-1)), None);
        assert_eq!(factorial_of_scalar(Some(21)), None);
        assert_eq!(factorial_of_scalar(Some(i32::MIN)), None);
        assert_eq!(factorial_of_scalar(Some(i32::MAX)), None);
        assert_eq!(factorial_of_scalar(Some(0)), Some(1));
        assert_eq!(factorial_of_scalar(Some(20)), Some(2432902008176640000));
    }

    #[test]
    fn test_spark_factorial_rejects_wrong_arity() {
        assert!(spark_factorial(&[]).is_err());

        let first = ColumnarValue::Scalar(ScalarValue::Int32(Some(1)));
        let second = ColumnarValue::Scalar(ScalarValue::Int32(Some(2)));
        assert!(spark_factorial(&[first, second]).is_err());
    }

    #[test]
    fn test_spark_factorial_rejects_unexpected_types() {
        let scalar = ColumnarValue::Scalar(ScalarValue::Int64(Some(3)));
        assert!(spark_factorial(&[scalar]).is_err());

        let array = ColumnarValue::Array(Arc::new(Int64Array::from(vec![3_i64])));
        assert!(spark_factorial(&[array]).is_err());
    }
}