datafusion_spark/function/math/
factorial.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{Array, Int64Array};
22use arrow::datatypes::DataType;
23use arrow::datatypes::DataType::{Int32, Int64};
24use datafusion_common::cast::as_int32_array;
25use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
26use datafusion_expr::Signature;
27use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
28
/// Scalar UDF `factorial(n)`.
///
/// Takes a single `Int32` argument (see `SparkFactorial::new`) and returns
/// `n!` as `Int64`, or null when `n` is null or outside `0..=20`.
#[derive(Debug)]
pub struct SparkFactorial {
    /// Accepted argument types; built by `SparkFactorial::new`.
    signature: Signature,
    /// Alternative names for the function; `new` registers none.
    aliases: Vec<String>,
}
35
36impl Default for SparkFactorial {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42impl SparkFactorial {
43 pub fn new() -> Self {
44 Self {
45 signature: Signature::exact(vec![Int32], Volatility::Immutable),
46 aliases: vec![],
47 }
48 }
49}
50
51impl ScalarUDFImpl for SparkFactorial {
52 fn as_any(&self) -> &dyn Any {
53 self
54 }
55
56 fn name(&self) -> &str {
57 "factorial"
58 }
59
60 fn signature(&self) -> &Signature {
61 &self.signature
62 }
63
64 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
65 Ok(Int64)
66 }
67
68 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
69 spark_factorial(&args.args)
70 }
71
72 fn aliases(&self) -> &[String] {
73 &self.aliases
74 }
75}
76
/// Lookup table with `FACTORIALS[n] == n!` for `n` in `0..=20`.
///
/// 20! is the largest factorial representable as `i64` (21! ≈ 5.1e19 exceeds
/// `i64::MAX` ≈ 9.2e18), so the table has 21 entries. It is computed during
/// constant evaluation, where an arithmetic overflow is a compile error
/// rather than a silently wrong value.
const FACTORIALS: [i64; 21] = {
    let mut table = [1_i64; 21];
    // table[0] = 0! = 1 is already set; fill the rest with n! = (n-1)! * n.
    let mut n = 1;
    while n < 21 {
        table[n] = table[n - 1] * n as i64;
        n += 1;
    }
    table
};
100
101pub fn spark_factorial(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
102 if args.len() != 1 {
103 return Err(DataFusionError::Internal(
104 "`factorial` expects exactly one argument".to_string(),
105 ));
106 }
107
108 match &args[0] {
109 ColumnarValue::Scalar(ScalarValue::Int32(value)) => {
110 let result = compute_factorial(*value);
111 Ok(ColumnarValue::Scalar(ScalarValue::Int64(result)))
112 }
113 ColumnarValue::Scalar(other) => {
114 exec_err!("`factorial` got an unexpected scalar type: {:?}", other)
115 }
116 ColumnarValue::Array(array) => match array.data_type() {
117 Int32 => {
118 let array = as_int32_array(array)?;
119
120 let result: Int64Array = array.iter().map(compute_factorial).collect();
121
122 Ok(ColumnarValue::Array(Arc::new(result)))
123 }
124 other => {
125 exec_err!("`factorial` got an unexpected argument type: {:?}", other)
126 }
127 },
128 }
129}
130
131#[inline]
132fn compute_factorial(num: Option<i32>) -> Option<i64> {
133 num.filter(|&v| (0..=20).contains(&v))
134 .map(|v| FACTORIALS[v as usize])
135}
136
#[cfg(test)]
mod test {
    use crate::function::math::factorial::spark_factorial;
    use arrow::array::{Int32Array, Int64Array};
    use datafusion_common::cast::as_int64_array;
    use datafusion_common::ScalarValue;
    use datafusion_expr::ColumnarValue;
    use std::sync::Arc;

    /// Runs `spark_factorial` on a single scalar and returns the inner
    /// `Int64` value, panicking if the result is not an `Int64` scalar.
    fn eval_scalar(input: ScalarValue) -> Option<i64> {
        match spark_factorial(&[ColumnarValue::Scalar(input)]).unwrap() {
            ColumnarValue::Scalar(ScalarValue::Int64(val)) => val,
            _ => panic!("Expected Int64 scalar"),
        }
    }

    #[test]
    fn test_spark_factorial_array() {
        let input = Int32Array::from(vec![
            Some(-1),
            Some(0),
            Some(1),
            Some(2),
            Some(4),
            Some(20),
            Some(21),
            None,
        ]);

        let args = ColumnarValue::Array(Arc::new(input));
        let result = spark_factorial(&[args]).unwrap();
        let result = match result {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array"),
        };

        let actual = as_int64_array(&result).unwrap();
        let expected = Int64Array::from(vec![
            None,
            Some(1),
            Some(1),
            Some(2),
            Some(24),
            Some(2432902008176640000),
            None,
            None,
        ]);

        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_spark_factorial_scalar() {
        assert_eq!(eval_scalar(ScalarValue::Int32(Some(5))), Some(120));
    }

    #[test]
    fn test_spark_factorial_scalar_null_and_out_of_range() {
        // Null input propagates as null.
        assert_eq!(eval_scalar(ScalarValue::Int32(None)), None);
        // Negative input and inputs whose factorial overflows i64 yield null.
        assert_eq!(eval_scalar(ScalarValue::Int32(Some(-1))), None);
        assert_eq!(eval_scalar(ScalarValue::Int32(Some(21))), None);
        // Upper boundary of the supported range.
        assert_eq!(
            eval_scalar(ScalarValue::Int32(Some(20))),
            Some(2432902008176640000)
        );
        // Lower boundary of the supported range.
        assert_eq!(eval_scalar(ScalarValue::Int32(Some(0))), Some(1));
    }

    #[test]
    fn test_spark_factorial_rejects_bad_arguments() {
        // Wrong arity: zero and two arguments.
        assert!(spark_factorial(&[]).is_err());
        let two = [
            ColumnarValue::Scalar(ScalarValue::Int32(Some(1))),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
        ];
        assert!(spark_factorial(&two).is_err());

        // Wrong scalar type.
        let wrong_scalar = ColumnarValue::Scalar(ScalarValue::Int64(Some(5)));
        assert!(spark_factorial(&[wrong_scalar]).is_err());

        // Wrong array type.
        let wrong_array = ColumnarValue::Array(Arc::new(Int64Array::from(vec![Some(5_i64)])));
        assert!(spark_factorial(&[wrong_array]).is_err());
    }
}