datafusion_spark/function/math/
expm1.rs1use crate::function::error_utils::unsupported_data_type_exec_err;
19use arrow::array::{ArrayRef, AsArray};
20use arrow::datatypes::{DataType, Float64Type};
21use datafusion_common::utils::take_function_args;
22use datafusion_common::{Result, ScalarValue};
23use datafusion_expr::{
24 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
25};
26use std::sync::Arc;
27
28#[derive(Debug, PartialEq, Eq, Hash)]
30pub struct SparkExpm1 {
31 signature: Signature,
32}
33
34impl Default for SparkExpm1 {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl SparkExpm1 {
41 pub fn new() -> Self {
42 Self {
43 signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
44 }
45 }
46}
47
48impl ScalarUDFImpl for SparkExpm1 {
49 fn name(&self) -> &str {
50 "expm1"
51 }
52
53 fn signature(&self) -> &Signature {
54 &self.signature
55 }
56
57 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
58 Ok(DataType::Float64)
59 }
60
61 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
62 let [arg] = take_function_args(self.name(), args.args)?;
63 match arg {
64 ColumnarValue::Scalar(ScalarValue::Float64(value)) => Ok(
65 ColumnarValue::Scalar(ScalarValue::Float64(value.map(|x| x.exp_m1()))),
66 ),
67 ColumnarValue::Array(array) => match array.data_type() {
68 DataType::Float64 => Ok(ColumnarValue::Array(Arc::new(
69 array
70 .as_primitive::<Float64Type>()
71 .unary::<_, Float64Type>(|x| x.exp_m1()),
72 )
73 as ArrayRef)),
74 other => Err(unsupported_data_type_exec_err(
75 "expm1",
76 format!("{}", DataType::Float64).as_str(),
77 other,
78 )),
79 },
80 other => Err(unsupported_data_type_exec_err(
81 "expm1",
82 format!("{}", DataType::Float64).as_str(),
83 &other.data_type(),
84 )),
85 }
86 }
87}