datafusion_functions_nested/
array_normalize.rs1use crate::utils::make_scalar_function;
21use arrow::array::{
22 Array, ArrayRef, Float64Array, GenericListArray, NullBufferBuilder,
23 OffsetBufferBuilder, OffsetSizeTrait,
24};
25use arrow::datatypes::{
26 DataType,
27 DataType::{FixedSizeList, LargeList, List, Null},
28 Field,
29};
30use datafusion_common::cast::{as_float64_array, as_generic_list_array};
31use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
32use datafusion_common::{Result, internal_err, plan_err, utils::take_function_args};
33use datafusion_expr::{
34 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
35 Volatility,
36};
37use datafusion_macros::user_doc;
38use std::sync::Arc;
39
// Wires `ArrayNormalize` into the crate's expression API via the
// `make_udf_expr_and_func!` macro. Presumably this produces an
// `array_normalize(array)` expression builder and an `array_normalize_udf()`
// accessor (names taken from the macro arguments) — confirm against the
// macro definition in this crate.
make_udf_expr_and_func!(
    ArrayNormalize,
    array_normalize,
    array,
    "returns the L2-normalized vector for a numeric array.",
    array_normalize_udf
);
47
48#[user_doc(
49 doc_section(label = "Array Functions"),
50 description = "Returns the L2-normalized vector for the input numeric array, computed as `array[i] / sqrt(sum(array[i]^2))` per element. Returns NULL if the input is NULL, contains NULL elements, or has zero magnitude (all elements are zero). Returns an empty array for an empty input array.",
51 syntax_example = "array_normalize(array)",
52 sql_example = r#"```sql
53> select array_normalize([3.0, 4.0]);
54+-----------------------------+
55| array_normalize(List([3.0,4.0])) |
56+-----------------------------+
57| [0.6, 0.8] |
58+-----------------------------+
59```"#,
60 argument(
61 name = "array",
62 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
63 )
64)]
65#[derive(Debug, PartialEq, Eq, Hash)]
66pub struct ArrayNormalize {
67 signature: Signature,
68 aliases: Vec<String>,
69}
70
71impl Default for ArrayNormalize {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77impl ArrayNormalize {
78 pub fn new() -> Self {
79 Self {
80 signature: Signature::user_defined(Volatility::Immutable),
81 aliases: vec!["list_normalize".to_string()],
82 }
83 }
84}
85
86impl ScalarUDFImpl for ArrayNormalize {
87 fn name(&self) -> &str {
88 "array_normalize"
89 }
90
91 fn signature(&self) -> &Signature {
92 &self.signature
93 }
94
95 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
96 Ok(arg_types[0].clone())
98 }
99
100 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
101 let [arg_type] = take_function_args(self.name(), arg_types)?;
102 let coercion = Some(&ListCoercion::FixedSizedListToList);
103
104 if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
105 return plan_err!("{} does not support type {arg_type}", self.name());
106 }
107
108 let coerced = if matches!(arg_type, Null) {
109 List(Arc::new(Field::new_list_field(DataType::Float64, true)))
110 } else {
111 coerced_type_with_base_type_only(arg_type, &DataType::Float64, coercion)
112 };
113
114 Ok(vec![coerced])
115 }
116
117 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
118 make_scalar_function(array_normalize_inner)(&args.args)
119 }
120
121 fn aliases(&self) -> &[String] {
122 &self.aliases
123 }
124
125 fn documentation(&self) -> Option<&Documentation> {
126 self.doc()
127 }
128}
129
130fn array_normalize_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
131 let [array] = take_function_args("array_normalize", args)?;
132 match array.data_type() {
133 List(_) => general_array_normalize::<i32>(args),
134 LargeList(_) => general_array_normalize::<i64>(args),
135 arg_type => internal_err!(
136 "array_normalize received unexpected type after coercion: {arg_type}"
137 ),
138 }
139}
140
141fn general_array_normalize<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
142 let list_array = as_generic_list_array::<O>(&arrays[0])?;
143 let values = as_float64_array(list_array.values())?;
144 let offsets = list_array.value_offsets();
145
146 let mut new_values: Vec<f64> = Vec::with_capacity(values.len());
147 let mut new_offsets = OffsetBufferBuilder::<O>::new(list_array.len());
148 let mut nulls = NullBufferBuilder::new(list_array.len());
149
150 for row in 0..list_array.len() {
151 if list_array.is_null(row) {
152 nulls.append_null();
153 new_offsets.push_length(0);
154 continue;
155 }
156
157 let start = offsets[row].as_usize();
158 let end = offsets[row + 1].as_usize();
159 let len = end - start;
160
161 let slice = values.slice(start, len);
162 if slice.null_count() != 0 {
163 nulls.append_null();
164 new_offsets.push_length(0);
165 continue;
166 }
167
168 let vals = slice.values();
169
170 if len == 0 {
172 nulls.append_non_null();
173 new_offsets.push_length(0);
174 continue;
175 }
176
177 let mut sq_sum = 0.0;
179 for i in 0..len {
180 sq_sum += vals[i] * vals[i];
181 }
182
183 if sq_sum == 0.0 {
185 nulls.append_null();
186 new_offsets.push_length(0);
187 continue;
188 }
189
190 let mag = sq_sum.sqrt();
191 for i in 0..len {
192 new_values.push(vals[i] / mag);
193 }
194 nulls.append_non_null();
195 new_offsets.push_length(len);
196 }
197
198 let values_array = Arc::new(Float64Array::from(new_values));
199 let field = Arc::new(Field::new_list_field(DataType::Float64, true));
200
201 Ok(Arc::new(GenericListArray::<O>::try_new(
202 field,
203 new_offsets.finish(),
204 values_array,
205 nulls.finish(),
206 )?))
207}