Skip to main content

datafusion_functions_nested/
inner_product.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`ScalarUDFImpl`] definitions for inner_product function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{Array, ArrayRef, Float64Array, OffsetSizeTrait};
22use arrow::datatypes::{
23    DataType,
24    DataType::{FixedSizeList, LargeList, List, Null},
25    Field,
26};
27use datafusion_common::cast::{as_float64_array, as_generic_list_array};
28use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
29use datafusion_common::{
30    Result, exec_err, internal_err, plan_err, utils::take_function_args,
31};
32use datafusion_expr::{
33    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34    Volatility,
35};
36use datafusion_macros::user_doc;
37use std::sync::Arc;
38
// Wires the `InnerProduct` UDF into the crate's expression API: the arguments
// name the UDF type, the expression-builder function (`inner_product`), its
// parameters (`array1`, `array2`), the builder's doc string, and the UDF
// accessor (`inner_product_udf`).
// NOTE(review): the exact items generated are defined by the crate-local
// `make_udf_expr_and_func!` macro — confirm there before relying on details.
make_udf_expr_and_func!(
    InnerProduct,
    inner_product,
    array1 array2,
    "returns the inner product (dot product) of two numeric arrays.",
    inner_product_udf
);
46
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns the inner product (dot product) of two input arrays of equal length, computed as `sum(array1[i] * array2[i])`. Returns NULL if either array is NULL or contains NULL elements. Returns 0.0 for two empty arrays.",
    syntax_example = "inner_product(array1, array2)",
    sql_example = r#"```sql
> select inner_product([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);
+-------------------------------------------------------+
| inner_product(List([1.0,2.0,3.0]),List([4.0,5.0,6.0])) |
+-------------------------------------------------------+
| 32.0                                                  |
+-------------------------------------------------------+
```"#,
    argument(
        name = "array1",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "array2",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    )
)]
/// Scalar UDF computing the inner (dot) product of two numeric arrays.
///
/// Exposed under the name `inner_product`, with `dot_product` as an alias.
/// The result is always `Float64`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct InnerProduct {
    /// Planning signature. It is user-defined, so argument types are resolved
    /// by `ScalarUDFImpl::coerce_types`.
    signature: Signature,
    /// Alternative SQL names under which this function is also reachable.
    aliases: Vec<String>,
}
73
74impl Default for InnerProduct {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl InnerProduct {
81    pub fn new() -> Self {
82        Self {
83            signature: Signature::user_defined(Volatility::Immutable),
84            aliases: vec!["dot_product".to_string()],
85        }
86    }
87}
88
89impl ScalarUDFImpl for InnerProduct {
90    fn name(&self) -> &str {
91        "inner_product"
92    }
93
94    fn signature(&self) -> &Signature {
95        &self.signature
96    }
97
98    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
99        Ok(DataType::Float64)
100    }
101
102    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
103        let [_, _] = take_function_args(self.name(), arg_types)?;
104        let coercion = Some(&ListCoercion::FixedSizedListToList);
105
106        for arg_type in arg_types {
107            if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
108                return plan_err!("{} does not support type {arg_type}", self.name());
109            }
110        }
111
112        // If any input is `LargeList`, both sides must be widened to `LargeList`
113        // so the runtime dispatch in `inner_product_inner` sees a homogeneous
114        // pair. Follows the pattern in `ArrayConcat::coerce_types`.
115        let any_large_list = arg_types.iter().any(|t| matches!(t, LargeList(_)));
116
117        let coerced = arg_types
118            .iter()
119            .map(|arg_type| {
120                if matches!(arg_type, Null) {
121                    let field = Arc::new(Field::new_list_field(DataType::Float64, true));
122                    return if any_large_list {
123                        LargeList(field)
124                    } else {
125                        List(field)
126                    };
127                }
128                let coerced = coerced_type_with_base_type_only(
129                    arg_type,
130                    &DataType::Float64,
131                    coercion,
132                );
133                match coerced {
134                    List(field) if any_large_list => LargeList(field),
135                    other => other,
136                }
137            })
138            .collect();
139
140        Ok(coerced)
141    }
142
143    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
144        make_scalar_function(inner_product_inner)(&args.args)
145    }
146
147    fn aliases(&self) -> &[String] {
148        &self.aliases
149    }
150
151    fn documentation(&self) -> Option<&Documentation> {
152        self.doc()
153    }
154}
155
156fn inner_product_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
157    let [array1, array2] = take_function_args("inner_product", args)?;
158    match (array1.data_type(), array2.data_type()) {
159        (List(_), List(_)) => general_inner_product::<i32>(args),
160        (LargeList(_), LargeList(_)) => general_inner_product::<i64>(args),
161        (arg_type1, arg_type2) => internal_err!(
162            "inner_product received unexpected types after coercion: {arg_type1} and {arg_type2}"
163        ),
164    }
165}
166
167fn general_inner_product<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
168    let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
169    let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
170
171    let values1 = as_float64_array(list_array1.values())?;
172    let values2 = as_float64_array(list_array2.values())?;
173    let offsets1 = list_array1.value_offsets();
174    let offsets2 = list_array2.value_offsets();
175
176    let mut builder = Float64Array::builder(list_array1.len());
177    for row in 0..list_array1.len() {
178        if list_array1.is_null(row) || list_array2.is_null(row) {
179            builder.append_null();
180            continue;
181        }
182
183        let start1 = offsets1[row].as_usize();
184        let end1 = offsets1[row + 1].as_usize();
185        let start2 = offsets2[row].as_usize();
186        let end2 = offsets2[row + 1].as_usize();
187        let len1 = end1 - start1;
188        let len2 = end2 - start2;
189
190        if len1 != len2 {
191            return exec_err!(
192                "inner_product requires both list inputs to have the same length, got {len1} and {len2}"
193            );
194        }
195
196        let slice1 = values1.slice(start1, len1);
197        let slice2 = values2.slice(start2, len2);
198        if slice1.null_count() != 0 || slice2.null_count() != 0 {
199            builder.append_null();
200            continue;
201        }
202
203        let vals1 = slice1.values();
204        let vals2 = slice2.values();
205
206        let mut dot = 0.0;
207        for i in 0..len1 {
208            dot += vals1[i] * vals2[i];
209        }
210        builder.append_value(dot);
211    }
212
213    Ok(Arc::new(builder.finish()) as ArrayRef)
214}