datafusion_spark/function/math/
expm1.rs1use crate::function::error_utils::unsupported_data_type_exec_err;
19use arrow::array::{ArrayRef, AsArray};
20use arrow::datatypes::{DataType, Float64Type};
21use datafusion_common::utils::take_function_args;
22use datafusion_common::{Result, ScalarValue};
23use datafusion_expr::{
24 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
25};
26use std::any::Any;
27use std::sync::Arc;
28
29#[derive(Debug, PartialEq, Eq, Hash)]
31pub struct SparkExpm1 {
32 signature: Signature,
33}
34
35impl Default for SparkExpm1 {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl SparkExpm1 {
42 pub fn new() -> Self {
43 Self {
44 signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
45 }
46 }
47}
48
49impl ScalarUDFImpl for SparkExpm1 {
50 fn as_any(&self) -> &dyn Any {
51 self
52 }
53
54 fn name(&self) -> &str {
55 "expm1"
56 }
57
58 fn signature(&self) -> &Signature {
59 &self.signature
60 }
61
62 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
63 Ok(DataType::Float64)
64 }
65
66 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
67 let [arg] = take_function_args(self.name(), args.args)?;
68 match arg {
69 ColumnarValue::Scalar(ScalarValue::Float64(value)) => Ok(
70 ColumnarValue::Scalar(ScalarValue::Float64(value.map(|x| x.exp_m1()))),
71 ),
72 ColumnarValue::Array(array) => match array.data_type() {
73 DataType::Float64 => Ok(ColumnarValue::Array(Arc::new(
74 array
75 .as_primitive::<Float64Type>()
76 .unary::<_, Float64Type>(|x| x.exp_m1()),
77 )
78 as ArrayRef)),
79 other => Err(unsupported_data_type_exec_err(
80 "expm1",
81 format!("{}", DataType::Float64).as_str(),
82 other,
83 )),
84 },
85 other => Err(unsupported_data_type_exec_err(
86 "expm1",
87 format!("{}", DataType::Float64).as_str(),
88 &other.data_type(),
89 )),
90 }
91 }
92}