datafusion_spark/function/math/
expm1.rs1use crate::function::error_utils::{
19 invalid_arg_count_exec_err, unsupported_data_type_exec_err,
20};
21use arrow::array::{ArrayRef, AsArray};
22use arrow::datatypes::{DataType, Float64Type};
23use datafusion_common::{Result, ScalarValue};
24use datafusion_expr::{
25 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
26};
27use std::any::Any;
28use std::sync::Arc;
29
30#[derive(Debug)]
32pub struct SparkExpm1 {
33 signature: Signature,
34 aliases: Vec<String>,
35}
36
37impl Default for SparkExpm1 {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl SparkExpm1 {
44 pub fn new() -> Self {
45 Self {
46 signature: Signature::user_defined(Volatility::Immutable),
47 aliases: vec![],
48 }
49 }
50}
51
52impl ScalarUDFImpl for SparkExpm1 {
53 fn as_any(&self) -> &dyn Any {
54 self
55 }
56
57 fn name(&self) -> &str {
58 "expm1"
59 }
60
61 fn signature(&self) -> &Signature {
62 &self.signature
63 }
64
65 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
66 Ok(DataType::Float64)
67 }
68
69 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
70 if args.args.len() != 1 {
71 return Err(invalid_arg_count_exec_err("expm1", (1, 1), args.args.len()));
72 }
73 match &args.args[0] {
74 ColumnarValue::Scalar(ScalarValue::Float64(value)) => Ok(
75 ColumnarValue::Scalar(ScalarValue::Float64(value.map(|x| x.exp_m1()))),
76 ),
77 ColumnarValue::Array(array) => match array.data_type() {
78 DataType::Float64 => Ok(ColumnarValue::Array(Arc::new(
79 array
80 .as_primitive::<Float64Type>()
81 .unary::<_, Float64Type>(|x| x.exp_m1()),
82 )
83 as ArrayRef)),
84 other => Err(unsupported_data_type_exec_err(
85 "expm1",
86 format!("{}", DataType::Float64).as_str(),
87 other,
88 )),
89 },
90 other => Err(unsupported_data_type_exec_err(
91 "expm1",
92 format!("{}", DataType::Float64).as_str(),
93 &other.data_type(),
94 )),
95 }
96 }
97
98 fn aliases(&self) -> &[String] {
99 &self.aliases
100 }
101
102 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
103 if arg_types.len() != 1 {
104 return Err(invalid_arg_count_exec_err("expm1", (1, 1), arg_types.len()));
105 }
106 if arg_types[0].is_numeric() {
107 Ok(vec![DataType::Float64])
108 } else {
109 Err(unsupported_data_type_exec_err(
110 "expm1",
111 "Numeric Type",
112 &arg_types[0],
113 ))
114 }
115 }
116}
117
#[cfg(test)]
mod tests {
    use crate::function::math::expm1::SparkExpm1;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::{Array, Float64Array};
    use arrow::datatypes::DataType::Float64;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    /// Invokes `SparkExpm1` on a single `Float64` scalar (`$INPUT` is an
    /// `Option<f64>`) and checks the outcome against `$EXPECTED`.
    macro_rules! test_expm1_float64_invoke {
        ($INPUT:expr, $EXPECTED:expr) => {
            test_scalar_function!(
                SparkExpm1::new(),
                vec![ColumnarValue::Scalar(ScalarValue::Float64($INPUT))],
                $EXPECTED,
                f64,
                Float64,
                Float64Array
            );
        };
    }

    #[test]
    fn test_expm1_invoke() -> Result<()> {
        test_expm1_float64_invoke!(Some(0f64), Ok(Some(0.0f64)));
        Ok(())
    }

    /// A null scalar input must produce a null result rather than an error.
    #[test]
    fn test_expm1_invoke_null() -> Result<()> {
        test_expm1_float64_invoke!(None, Ok(None));
        Ok(())
    }

    /// Non-zero finite input and the infinite limits of `e^x - 1`.
    #[test]
    fn test_expm1_invoke_non_zero_and_limits() -> Result<()> {
        test_expm1_float64_invoke!(Some(1f64), Ok(Some(1f64.exp_m1())));
        test_expm1_float64_invoke!(Some(f64::NEG_INFINITY), Ok(Some(-1.0f64)));
        test_expm1_float64_invoke!(Some(f64::INFINITY), Ok(Some(f64::INFINITY)));
        Ok(())
    }
}