datafusion_functions/math/
nanvl.rs1use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::make_scalar_function;
22
23use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array};
24use arrow::datatypes::DataType::{Float32, Float64};
25use arrow::datatypes::{DataType, Float32Type, Float64Type};
26use datafusion_common::{exec_err, DataFusionError, Result};
27use datafusion_expr::TypeSignature::Exact;
28use datafusion_expr::{
29 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30 Volatility,
31};
32use datafusion_macros::user_doc;
33
#[user_doc(
    doc_section(label = "Math Functions"),
    description = r#"Returns the first argument if it's not _NaN_.
Returns the second argument otherwise."#,
    syntax_example = "nanvl(expression_x, expression_y)",
    sql_example = r#"```sql
> SELECT nanvl(0, 5);
+------------+
| nanvl(0,5) |
+------------+
| 0          |
+------------+
```"#,
    argument(
        name = "expression_x",
        description = "Numeric expression to return if it's not _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators."
    ),
    argument(
        name = "expression_y",
        description = "Numeric expression to return if the first expression is _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators."
    )
)]
/// Scalar UDF implementing SQL `nanvl(x, y)`: row by row, yields `x` unless it
/// is NaN, in which case `y` is yielded.
///
/// The SQL-facing documentation (description, syntax, example, arguments) is
/// declared in the `#[user_doc]` attribute above; the accepted argument types
/// are fixed when the signature is built in [`NanvlFunc::new`].
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct NanvlFunc {
    // Built once in `new`: two arguments, either both Float32 or both Float64.
    signature: Signature,
}
60
61impl Default for NanvlFunc {
62 fn default() -> Self {
63 NanvlFunc::new()
64 }
65}
66
67impl NanvlFunc {
68 pub fn new() -> Self {
69 use DataType::*;
70 Self {
71 signature: Signature::one_of(
72 vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
73 Volatility::Immutable,
74 ),
75 }
76 }
77}
78
79impl ScalarUDFImpl for NanvlFunc {
80 fn as_any(&self) -> &dyn Any {
81 self
82 }
83
84 fn name(&self) -> &str {
85 "nanvl"
86 }
87
88 fn signature(&self) -> &Signature {
89 &self.signature
90 }
91
92 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
93 match &arg_types[0] {
94 Float32 => Ok(Float32),
95 _ => Ok(Float64),
96 }
97 }
98
99 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
100 make_scalar_function(nanvl, vec![])(&args.args)
101 }
102
103 fn documentation(&self) -> Option<&Documentation> {
104 self.doc()
105 }
106}
107
108fn nanvl(args: &[ArrayRef]) -> Result<ArrayRef> {
110 match args[0].data_type() {
111 Float64 => {
112 let compute_nanvl = |x: f64, y: f64| {
113 if x.is_nan() {
114 y
115 } else {
116 x
117 }
118 };
119
120 let x = args[0].as_primitive() as &Float64Array;
121 let y = args[1].as_primitive() as &Float64Array;
122 arrow::compute::binary::<_, _, _, Float64Type>(x, y, compute_nanvl)
123 .map(|res| Arc::new(res) as _)
124 .map_err(DataFusionError::from)
125 }
126 Float32 => {
127 let compute_nanvl = |x: f32, y: f32| {
128 if x.is_nan() {
129 y
130 } else {
131 x
132 }
133 };
134
135 let x = args[0].as_primitive() as &Float32Array;
136 let y = args[1].as_primitive() as &Float32Array;
137 arrow::compute::binary::<_, _, _, Float32Type>(x, y, compute_nanvl)
138 .map(|res| Arc::new(res) as _)
139 .map_err(DataFusionError::from)
140 }
141 other => exec_err!("Unsupported data type {other:?} for function nanvl"),
142 }
143}
144
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::math::nanvl::nanvl;

    use arrow::array::{Array, ArrayRef, Float32Array, Float64Array};
    use datafusion_common::cast::{as_float32_array, as_float64_array};

    /// A NaN in `x` is replaced by the matching `y`; a NaN in `y` only appears
    /// in the output when `x` is NaN too.
    #[test]
    fn test_nanvl_f64() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN])),
            Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN])),
        ];

        let result = nanvl(&args).expect("nanvl failed on Float64 input");
        let floats =
            as_float64_array(&result).expect("nanvl did not return a Float64Array");

        assert_eq!(floats.len(), 4);
        assert_eq!(floats.value(0), 1.0);
        assert_eq!(floats.value(1), 6.0);
        assert_eq!(floats.value(2), 3.0);
        assert!(floats.value(3).is_nan());
    }

    /// Same cases as `test_nanvl_f64`, for `Float32` input.
    #[test]
    fn test_nanvl_f32() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN])),
            Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN])),
        ];

        let result = nanvl(&args).expect("nanvl failed on Float32 input");
        let floats =
            as_float32_array(&result).expect("nanvl did not return a Float32Array");

        assert_eq!(floats.len(), 4);
        assert_eq!(floats.value(0), 1.0);
        assert_eq!(floats.value(1), 6.0);
        assert_eq!(floats.value(2), 3.0);
        assert!(floats.value(3).is_nan());
    }

    /// A null in either input yields a null output row, even when `x` is a
    /// regular (non-NaN) value.
    #[test]
    fn test_nanvl_nulls_propagate() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![Some(1.0), None, Some(f64::NAN)])),
            Arc::new(Float64Array::from(vec![None, Some(2.0), Some(3.0)])),
        ];

        let result = nanvl(&args).expect("nanvl failed on Float64 input");
        let floats =
            as_float64_array(&result).expect("nanvl did not return a Float64Array");

        assert_eq!(floats.len(), 3);
        assert!(floats.is_null(0));
        assert!(floats.is_null(1));
        assert_eq!(floats.value(2), 3.0);
    }
}