Skip to main content

datafusion_spark/function/collection/
size.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
use arrow::array::{Array, ArrayRef, AsArray, Int32Array};
use arrow::compute::kernels::length::length as arrow_length;
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::{Result, exec_err, plan_err};
use datafusion_expr::{
    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs,
    ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_functions::utils::make_scalar_function;
use std::sync::Arc;
29/// Spark-compatible `size` function.
30///
31/// Returns the number of elements in an array or the number of key-value pairs in a map.
32/// Returns -1 for null input (Spark behavior).
33#[derive(Debug, PartialEq, Eq, Hash)]
34pub struct SparkSize {
35    signature: Signature,
36}
38impl Default for SparkSize {
39    fn default() -> Self {
40        Self::new()
41    }
42}
44impl SparkSize {
45    pub fn new() -> Self {
46        Self {
47            signature: Signature::one_of(
48                vec![
49                    // Array Type
50                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
51                        arguments: vec![ArrayFunctionArgument::Array],
52                        array_coercion: None,
53                    }),
54                    // Map Type
55                    TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray),
56                ],
57                Volatility::Immutable,
58            ),
59        }
60    }
61}
63impl ScalarUDFImpl for SparkSize {
64    fn name(&self) -> &str {
65        "size"
66    }
67
68    fn signature(&self) -> &Signature {
69        &self.signature
70    }
71
72    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
73        Ok(DataType::Int32)
74    }
75
76    fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<FieldRef> {
77        // nullable=false for legacy behavior (NULL -> -1); set to input nullability for null-on-null
78        Ok(Arc::new(Field::new(self.name(), DataType::Int32, false)))
79    }
80
81    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
82        make_scalar_function(spark_size_inner, vec![])(&args.args)
83    }
84}
86fn spark_size_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
87    let array = &args[0];
88
89    match array.data_type() {
90        DataType::List(_) => {
91            if array.null_count() == 0 {
92                Ok(arrow_length(array)?)
93            } else {
94                let list_array = array.as_list::<i32>();
95                let lengths: Vec<i32> = list_array
96                    .offsets()
97                    .lengths()
98                    .enumerate()
99                    .map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 })
100                    .collect();
101                Ok(Arc::new(Int32Array::from(lengths)))
102            }
103        }
104        DataType::FixedSizeList(_, size) => {
105            if array.null_count() == 0 {
106                Ok(arrow_length(array)?)
107            } else {
108                let length: Vec<i32> = (0..array.len())
109                    .map(|i| if array.is_null(i) { -1 } else { *size })
110                    .collect();
111                Ok(Arc::new(Int32Array::from(length)))
112            }
113        }
114        DataType::LargeList(_) => {
115            // Arrow length kernel returns Int64 for LargeList
116            let list_array = array.as_list::<i64>();
117            if array.null_count() == 0 {
118                let lengths: Vec<i32> = list_array
119                    .offsets()
120                    .lengths()
121                    .map(|len| len as i32)
122                    .collect();
123                Ok(Arc::new(Int32Array::from(lengths)))
124            } else {
125                let lengths: Vec<i32> = list_array
126                    .offsets()
127                    .lengths()
128                    .enumerate()
129                    .map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 })
130                    .collect();
131                Ok(Arc::new(Int32Array::from(lengths)))
132            }
133        }
134        DataType::Map(_, _) => {
135            let map_array = array.as_map();
136            let length: Vec<i32> = if array.null_count() == 0 {
137                map_array
138                    .offsets()
139                    .lengths()
140                    .map(|len| len as i32)
141                    .collect()
142            } else {
143                map_array
144                    .offsets()
145                    .lengths()
146                    .enumerate()
147                    .map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 })
148                    .collect()
149            };
150            Ok(Arc::new(Int32Array::from(length)))
151        }
152        DataType::Null => Ok(Arc::new(Int32Array::from(vec![-1; array.len()]))),
153        dt => {
154            plan_err!("size function does not support type: {}", dt)
155        }
156    }
157}