datafusion_functions_nested/
inner_product.rs1use crate::utils::make_scalar_function;
21use arrow::array::{Array, ArrayRef, Float64Array, OffsetSizeTrait};
22use arrow::datatypes::{
23 DataType,
24 DataType::{FixedSizeList, LargeList, List, Null},
25 Field,
26};
27use datafusion_common::cast::{as_float64_array, as_generic_list_array};
28use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
29use datafusion_common::{
30 Result, exec_err, internal_err, plan_err, utils::take_function_args,
31};
32use datafusion_expr::{
33 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34 Volatility,
35};
36use datafusion_macros::user_doc;
37use std::sync::Arc;
38
// Registers `InnerProduct` with the crate's UDF helper macro. Judging by the
// arguments, this presumably generates an `inner_product(array1, array2)` expression
// helper (documented with the string below) and an `inner_product_udf` accessor.
// NOTE(review): confirm against the `make_udf_expr_and_func!` definition.
make_udf_expr_and_func!(
    InnerProduct,
    inner_product,
    array1 array2,
    "returns the inner product (dot product) of two numeric arrays.",
    inner_product_udf
);
46
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns the inner product (dot product) of two input arrays of equal length, computed as `sum(array1[i] * array2[i])`. Returns NULL if either array is NULL or contains NULL elements. Returns 0.0 for two empty arrays.",
    syntax_example = "inner_product(array1, array2)",
    sql_example = r#"```sql
> select inner_product([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);
+-------------------------------------------------------+
| inner_product(List([1.0,2.0,3.0]),List([4.0,5.0,6.0])) |
+-------------------------------------------------------+
| 32.0 |
+-------------------------------------------------------+
```"#,
    argument(
        name = "array1",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "array2",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    )
)]
/// Scalar UDF implementing `inner_product` (also callable as `dot_product`).
///
/// Argument types are validated and coerced to lists of `Float64` in
/// `ScalarUDFImpl::coerce_types`. The function always returns `Float64`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct InnerProduct {
    /// User-defined signature; concrete argument checking happens in `coerce_types`.
    signature: Signature,
    /// Additional SQL names for this function (`new` sets this to `["dot_product"]`).
    aliases: Vec<String>,
}
73
74impl Default for InnerProduct {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl InnerProduct {
81 pub fn new() -> Self {
82 Self {
83 signature: Signature::user_defined(Volatility::Immutable),
84 aliases: vec!["dot_product".to_string()],
85 }
86 }
87}
88
89impl ScalarUDFImpl for InnerProduct {
90 fn name(&self) -> &str {
91 "inner_product"
92 }
93
94 fn signature(&self) -> &Signature {
95 &self.signature
96 }
97
98 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
99 Ok(DataType::Float64)
100 }
101
102 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
103 let [_, _] = take_function_args(self.name(), arg_types)?;
104 let coercion = Some(&ListCoercion::FixedSizedListToList);
105
106 for arg_type in arg_types {
107 if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
108 return plan_err!("{} does not support type {arg_type}", self.name());
109 }
110 }
111
112 let any_large_list = arg_types.iter().any(|t| matches!(t, LargeList(_)));
116
117 let coerced = arg_types
118 .iter()
119 .map(|arg_type| {
120 if matches!(arg_type, Null) {
121 let field = Arc::new(Field::new_list_field(DataType::Float64, true));
122 return if any_large_list {
123 LargeList(field)
124 } else {
125 List(field)
126 };
127 }
128 let coerced = coerced_type_with_base_type_only(
129 arg_type,
130 &DataType::Float64,
131 coercion,
132 );
133 match coerced {
134 List(field) if any_large_list => LargeList(field),
135 other => other,
136 }
137 })
138 .collect();
139
140 Ok(coerced)
141 }
142
143 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
144 make_scalar_function(inner_product_inner)(&args.args)
145 }
146
147 fn aliases(&self) -> &[String] {
148 &self.aliases
149 }
150
151 fn documentation(&self) -> Option<&Documentation> {
152 self.doc()
153 }
154}
155
156fn inner_product_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
157 let [array1, array2] = take_function_args("inner_product", args)?;
158 match (array1.data_type(), array2.data_type()) {
159 (List(_), List(_)) => general_inner_product::<i32>(args),
160 (LargeList(_), LargeList(_)) => general_inner_product::<i64>(args),
161 (arg_type1, arg_type2) => internal_err!(
162 "inner_product received unexpected types after coercion: {arg_type1} and {arg_type2}"
163 ),
164 }
165}
166
167fn general_inner_product<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
168 let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
169 let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
170
171 let values1 = as_float64_array(list_array1.values())?;
172 let values2 = as_float64_array(list_array2.values())?;
173 let offsets1 = list_array1.value_offsets();
174 let offsets2 = list_array2.value_offsets();
175
176 let mut builder = Float64Array::builder(list_array1.len());
177 for row in 0..list_array1.len() {
178 if list_array1.is_null(row) || list_array2.is_null(row) {
179 builder.append_null();
180 continue;
181 }
182
183 let start1 = offsets1[row].as_usize();
184 let end1 = offsets1[row + 1].as_usize();
185 let start2 = offsets2[row].as_usize();
186 let end2 = offsets2[row + 1].as_usize();
187 let len1 = end1 - start1;
188 let len2 = end2 - start2;
189
190 if len1 != len2 {
191 return exec_err!(
192 "inner_product requires both list inputs to have the same length, got {len1} and {len2}"
193 );
194 }
195
196 let slice1 = values1.slice(start1, len1);
197 let slice2 = values2.slice(start2, len2);
198 if slice1.null_count() != 0 || slice2.null_count() != 0 {
199 builder.append_null();
200 continue;
201 }
202
203 let vals1 = slice1.values();
204 let vals2 = slice2.values();
205
206 let mut dot = 0.0;
207 for i in 0..len1 {
208 dot += vals1[i] * vals2[i];
209 }
210 builder.append_value(dot);
211 }
212
213 Ok(Arc::new(builder.finish()) as ArrayRef)
214}