Skip to main content

datafusion_functions_nested/
array_normalize.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_normalize function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{
22    Array, ArrayRef, Float64Array, GenericListArray, NullBufferBuilder,
23    OffsetBufferBuilder, OffsetSizeTrait,
24};
25use arrow::datatypes::{
26    DataType,
27    DataType::{FixedSizeList, LargeList, List, Null},
28    Field,
29};
30use datafusion_common::cast::{as_float64_array, as_generic_list_array};
31use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
32use datafusion_common::{Result, internal_err, plan_err, utils::take_function_args};
33use datafusion_expr::{
34    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
35    Volatility,
36};
37use datafusion_macros::user_doc;
38use std::sync::Arc;
39
// Registers `ArrayNormalize` through the `make_udf_expr_and_func!` macro, passing the
// expression-function name (`array_normalize`), its single argument (`array`), a short
// description, and the UDF accessor name (`array_normalize_udf`).
// NOTE(review): the macro is defined elsewhere; presumably it generates the
// `array_normalize(array)` expression builder and the `array_normalize_udf()` accessor —
// confirm against the macro definition.
make_udf_expr_and_func!(
    ArrayNormalize,
    array_normalize,
    array,
    "returns the L2-normalized vector for a numeric array.",
    array_normalize_udf
);
47
48#[user_doc(
49    doc_section(label = "Array Functions"),
50    description = "Returns the L2-normalized vector for the input numeric array, computed as `array[i] / sqrt(sum(array[i]^2))` per element. Returns NULL if the input is NULL, contains NULL elements, or has zero magnitude (all elements are zero). Returns an empty array for an empty input array.",
51    syntax_example = "array_normalize(array)",
52    sql_example = r#"```sql
53> select array_normalize([3.0, 4.0]);
54+-----------------------------+
55| array_normalize(List([3.0,4.0])) |
56+-----------------------------+
57| [0.6, 0.8]                  |
58+-----------------------------+
59```"#,
60    argument(
61        name = "array",
62        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
63    )
64)]
65#[derive(Debug, PartialEq, Eq, Hash)]
66pub struct ArrayNormalize {
67    signature: Signature,
68    aliases: Vec<String>,
69}
70
71impl Default for ArrayNormalize {
72    fn default() -> Self {
73        Self::new()
74    }
75}
76
77impl ArrayNormalize {
78    pub fn new() -> Self {
79        Self {
80            signature: Signature::user_defined(Volatility::Immutable),
81            aliases: vec!["list_normalize".to_string()],
82        }
83    }
84}
85
86impl ScalarUDFImpl for ArrayNormalize {
87    fn name(&self) -> &str {
88        "array_normalize"
89    }
90
91    fn signature(&self) -> &Signature {
92        &self.signature
93    }
94
95    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
96        // After `coerce_types`, `arg_types[0]` is one of List(Float64) or LargeList(Float64).
97        Ok(arg_types[0].clone())
98    }
99
100    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
101        let [arg_type] = take_function_args(self.name(), arg_types)?;
102        let coercion = Some(&ListCoercion::FixedSizedListToList);
103
104        if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
105            return plan_err!("{} does not support type {arg_type}", self.name());
106        }
107
108        let coerced = if matches!(arg_type, Null) {
109            List(Arc::new(Field::new_list_field(DataType::Float64, true)))
110        } else {
111            coerced_type_with_base_type_only(arg_type, &DataType::Float64, coercion)
112        };
113
114        Ok(vec![coerced])
115    }
116
117    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
118        make_scalar_function(array_normalize_inner)(&args.args)
119    }
120
121    fn aliases(&self) -> &[String] {
122        &self.aliases
123    }
124
125    fn documentation(&self) -> Option<&Documentation> {
126        self.doc()
127    }
128}
129
130fn array_normalize_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
131    let [array] = take_function_args("array_normalize", args)?;
132    match array.data_type() {
133        List(_) => general_array_normalize::<i32>(args),
134        LargeList(_) => general_array_normalize::<i64>(args),
135        arg_type => internal_err!(
136            "array_normalize received unexpected type after coercion: {arg_type}"
137        ),
138    }
139}
140
141fn general_array_normalize<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
142    let list_array = as_generic_list_array::<O>(&arrays[0])?;
143    let values = as_float64_array(list_array.values())?;
144    let offsets = list_array.value_offsets();
145
146    let mut new_values: Vec<f64> = Vec::with_capacity(values.len());
147    let mut new_offsets = OffsetBufferBuilder::<O>::new(list_array.len());
148    let mut nulls = NullBufferBuilder::new(list_array.len());
149
150    for row in 0..list_array.len() {
151        if list_array.is_null(row) {
152            nulls.append_null();
153            new_offsets.push_length(0);
154            continue;
155        }
156
157        let start = offsets[row].as_usize();
158        let end = offsets[row + 1].as_usize();
159        let len = end - start;
160
161        let slice = values.slice(start, len);
162        if slice.null_count() != 0 {
163            nulls.append_null();
164            new_offsets.push_length(0);
165            continue;
166        }
167
168        let vals = slice.values();
169
170        // Empty array: return empty array (no normalization needed, no division by zero risk)
171        if len == 0 {
172            nulls.append_non_null();
173            new_offsets.push_length(0);
174            continue;
175        }
176
177        // Compute squared magnitude.
178        let mut sq_sum = 0.0;
179        for i in 0..len {
180            sq_sum += vals[i] * vals[i];
181        }
182
183        // Zero magnitude: undefined normalization. Emit NULL row.
184        if sq_sum == 0.0 {
185            nulls.append_null();
186            new_offsets.push_length(0);
187            continue;
188        }
189
190        let mag = sq_sum.sqrt();
191        for i in 0..len {
192            new_values.push(vals[i] / mag);
193        }
194        nulls.append_non_null();
195        new_offsets.push_length(len);
196    }
197
198    let values_array = Arc::new(Float64Array::from(new_values));
199    let field = Arc::new(Field::new_list_field(DataType::Float64, true));
200
201    Ok(Arc::new(GenericListArray::<O>::try_new(
202        field,
203        new_offsets.finish(),
204        values_array,
205        nulls.finish(),
206    )?))
207}