datafusion_functions/math/
nanvl.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::make_scalar_function;
22
23use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array};
24use arrow::datatypes::DataType::{Float32, Float64};
25use arrow::datatypes::{DataType, Float32Type, Float64Type};
26use datafusion_common::{exec_err, DataFusionError, Result};
27use datafusion_expr::TypeSignature::Exact;
28use datafusion_expr::{
29    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30    Volatility,
31};
32use datafusion_macros::user_doc;
33
#[user_doc(
    doc_section(label = "Math Functions"),
    description = r#"Returns the first argument if it's not _NaN_.
Returns the second argument otherwise."#,
    syntax_example = "nanvl(expression_x, expression_y)",
    sql_example = r#"```sql
> SELECT nanvl(0, 5);
+------------+
| nanvl(0,5) |
+------------+
| 0          |
+------------+
```"#,
    argument(
        name = "expression_x",
        description = "Numeric expression to return if it's not _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators."
    ),
    argument(
        name = "expression_y",
        description = "Numeric expression to return if the first expression is _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// Scalar UDF backing the `nanvl` SQL function: yields the first argument
/// unless it is NaN, in which case the second argument is yielded instead.
///
/// The user-facing documentation is declared in the `user_doc` attribute on
/// this struct.
pub struct NanvlFunc {
    /// Argument type combinations (and volatility) this function accepts.
    signature: Signature,
}
60
61impl Default for NanvlFunc {
62    fn default() -> Self {
63        NanvlFunc::new()
64    }
65}
66
67impl NanvlFunc {
68    pub fn new() -> Self {
69        use DataType::*;
70        Self {
71            signature: Signature::one_of(
72                vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
73                Volatility::Immutable,
74            ),
75        }
76    }
77}
78
79impl ScalarUDFImpl for NanvlFunc {
80    fn as_any(&self) -> &dyn Any {
81        self
82    }
83
84    fn name(&self) -> &str {
85        "nanvl"
86    }
87
88    fn signature(&self) -> &Signature {
89        &self.signature
90    }
91
92    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
93        match &arg_types[0] {
94            Float32 => Ok(Float32),
95            _ => Ok(Float64),
96        }
97    }
98
99    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
100        make_scalar_function(nanvl, vec![])(&args.args)
101    }
102
103    fn documentation(&self) -> Option<&Documentation> {
104        self.doc()
105    }
106}
107
108/// Nanvl SQL function
109fn nanvl(args: &[ArrayRef]) -> Result<ArrayRef> {
110    match args[0].data_type() {
111        Float64 => {
112            let compute_nanvl = |x: f64, y: f64| {
113                if x.is_nan() {
114                    y
115                } else {
116                    x
117                }
118            };
119
120            let x = args[0].as_primitive() as &Float64Array;
121            let y = args[1].as_primitive() as &Float64Array;
122            arrow::compute::binary::<_, _, _, Float64Type>(x, y, compute_nanvl)
123                .map(|res| Arc::new(res) as _)
124                .map_err(DataFusionError::from)
125        }
126        Float32 => {
127            let compute_nanvl = |x: f32, y: f32| {
128                if x.is_nan() {
129                    y
130                } else {
131                    x
132                }
133            };
134
135            let x = args[0].as_primitive() as &Float32Array;
136            let y = args[1].as_primitive() as &Float32Array;
137            arrow::compute::binary::<_, _, _, Float32Type>(x, y, compute_nanvl)
138                .map(|res| Arc::new(res) as _)
139                .map_err(DataFusionError::from)
140        }
141        other => exec_err!("Unsupported data type {other:?} for function nanvl"),
142    }
143}
144
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::math::nanvl::nanvl;

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int32Array};
    use datafusion_common::cast::{as_float32_array, as_float64_array};

    /// `Float64` inputs: non-NaN `x` is kept, NaN `x` is replaced by `y`
    /// (which may itself be NaN).
    #[test]
    fn test_nanvl_f64() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN])), // x
            Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN])), // y
        ];

        let result = nanvl(&args).expect("failed to evaluate function nanvl");
        let floats = as_float64_array(&result)
            .expect("failed to downcast nanvl result to Float64Array");

        assert_eq!(floats.len(), 4);
        assert_eq!(floats.value(0), 1.0);
        assert_eq!(floats.value(1), 6.0);
        assert_eq!(floats.value(2), 3.0);
        assert!(floats.value(3).is_nan());
    }

    /// `Float32` inputs: same expectations as the `Float64` case.
    #[test]
    fn test_nanvl_f32() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN])), // x
            Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN])), // y
        ];

        let result = nanvl(&args).expect("failed to evaluate function nanvl");
        let floats = as_float32_array(&result)
            .expect("failed to downcast nanvl result to Float32Array");

        assert_eq!(floats.len(), 4);
        assert_eq!(floats.value(0), 1.0);
        assert_eq!(floats.value(1), 6.0);
        assert_eq!(floats.value(2), 3.0);
        assert!(floats.value(3).is_nan());
    }

    /// A non-floating-point first argument is rejected with an error rather
    /// than being evaluated.
    #[test]
    fn test_nanvl_unsupported_type() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Int32Array::from(vec![1, 2])), // x
            Arc::new(Int32Array::from(vec![3, 4])), // y
        ];

        let err = nanvl(&args).expect_err("nanvl should reject Int32 inputs");
        assert!(err.to_string().contains("Unsupported data type"));
    }
}