datafusion_functions/math/
nanvl.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::make_scalar_function;
22
23use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array};
24use arrow::datatypes::DataType::{Float32, Float64};
25use arrow::datatypes::{DataType, Float32Type, Float64Type};
26use datafusion_common::{DataFusionError, Result, exec_err};
27use datafusion_expr::TypeSignature::Exact;
28use datafusion_expr::{
29    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30    Volatility,
31};
32use datafusion_macros::user_doc;
33
34#[user_doc(
35    doc_section(label = "Math Functions"),
36    description = r#"Returns the first argument if it's not _NaN_.
37Returns the second argument otherwise."#,
38    syntax_example = "nanvl(expression_x, expression_y)",
39    sql_example = r#"```sql
40> SELECT nanvl(0, 5);
41+------------+
42| nanvl(0,5) |
43+------------+
44| 0          |
45+------------+
46```"#,
47    argument(
48        name = "expression_x",
49        description = "Numeric expression to return if it's not _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators."
50    ),
51    argument(
52        name = "expression_y",
53        description = "Numeric expression to return if the first expression is _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators."
54    )
55)]
56#[derive(Debug, PartialEq, Eq, Hash)]
57pub struct NanvlFunc {
58    signature: Signature,
59}
60
61impl Default for NanvlFunc {
62    fn default() -> Self {
63        NanvlFunc::new()
64    }
65}
66
67impl NanvlFunc {
68    pub fn new() -> Self {
69        use DataType::*;
70        Self {
71            signature: Signature::one_of(
72                vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
73                Volatility::Immutable,
74            ),
75        }
76    }
77}
78
79impl ScalarUDFImpl for NanvlFunc {
80    fn as_any(&self) -> &dyn Any {
81        self
82    }
83
84    fn name(&self) -> &str {
85        "nanvl"
86    }
87
88    fn signature(&self) -> &Signature {
89        &self.signature
90    }
91
92    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
93        match &arg_types[0] {
94            Float32 => Ok(Float32),
95            _ => Ok(Float64),
96        }
97    }
98
99    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
100        make_scalar_function(nanvl, vec![])(&args.args)
101    }
102
103    fn documentation(&self) -> Option<&Documentation> {
104        self.doc()
105    }
106}
107
108/// Nanvl SQL function
109fn nanvl(args: &[ArrayRef]) -> Result<ArrayRef> {
110    match args[0].data_type() {
111        Float64 => {
112            let compute_nanvl = |x: f64, y: f64| {
113                if x.is_nan() { y } else { x }
114            };
115
116            let x = args[0].as_primitive() as &Float64Array;
117            let y = args[1].as_primitive() as &Float64Array;
118            arrow::compute::binary::<_, _, _, Float64Type>(x, y, compute_nanvl)
119                .map(|res| Arc::new(res) as _)
120                .map_err(DataFusionError::from)
121        }
122        Float32 => {
123            let compute_nanvl = |x: f32, y: f32| {
124                if x.is_nan() { y } else { x }
125            };
126
127            let x = args[0].as_primitive() as &Float32Array;
128            let y = args[1].as_primitive() as &Float32Array;
129            arrow::compute::binary::<_, _, _, Float32Type>(x, y, compute_nanvl)
130                .map(|res| Arc::new(res) as _)
131                .map_err(DataFusionError::from)
132        }
133        other => exec_err!("Unsupported data type {other:?} for function nanvl"),
134    }
135}
136
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::math::nanvl::nanvl;

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int32Array};
    use datafusion_common::cast::{as_float32_array, as_float64_array};

    #[test]
    fn test_nanvl_f64() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN])), // x
            Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN])), // y
        ];

        let result = nanvl(&args).expect("nanvl failed on Float64 input");
        let floats =
            as_float64_array(&result).expect("nanvl should return a Float64 array");

        assert_eq!(floats.len(), 4);
        // x is not NaN -> x
        assert_eq!(floats.value(0), 1.0);
        // x is NaN -> y
        assert_eq!(floats.value(1), 6.0);
        // x is not NaN -> x, even though y is NaN
        assert_eq!(floats.value(2), 3.0);
        // both NaN -> y, which is NaN
        assert!(floats.value(3).is_nan());
    }

    #[test]
    fn test_nanvl_f32() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN])), // x
            Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN])), // y
        ];

        let result = nanvl(&args).expect("nanvl failed on Float32 input");
        let floats =
            as_float32_array(&result).expect("nanvl should return a Float32 array");

        assert_eq!(floats.len(), 4);
        assert_eq!(floats.value(0), 1.0);
        assert_eq!(floats.value(1), 6.0);
        assert_eq!(floats.value(2), 3.0);
        assert!(floats.value(3).is_nan());
    }

    #[test]
    fn test_nanvl_unsupported_type() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(Int32Array::from(vec![3, 4])),
        ];

        let err = nanvl(&args).expect_err("nanvl should reject Int32 input");
        assert!(
            err.to_string()
                .contains("Unsupported data type Int32 for function nanvl"),
            "unexpected error: {err}"
        );
    }
}