datafusion_spark/function/aggregate/
collect.rs1use arrow::array::ArrayRef;
19use arrow::datatypes::{DataType, Field, FieldRef};
20use datafusion_common::utils::SingleRowListArrayBuilder;
21use datafusion_common::{Result, ScalarValue};
22use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
23use datafusion_expr::utils::format_state_name;
24use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
25use datafusion_functions_aggregate::array_agg::{
26 ArrayAggAccumulator, DistinctArrayAggAccumulator,
27};
28use std::{any::Any, sync::Arc};
29
30#[derive(Debug, PartialEq, Eq, Hash)]
38pub struct SparkCollectList {
39 signature: Signature,
40}
41
42impl Default for SparkCollectList {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl SparkCollectList {
49 pub fn new() -> Self {
50 Self {
51 signature: Signature::any(1, Volatility::Immutable),
52 }
53 }
54}
55
56impl AggregateUDFImpl for SparkCollectList {
57 fn as_any(&self) -> &dyn Any {
58 self
59 }
60
61 fn name(&self) -> &str {
62 "collect_list"
63 }
64
65 fn signature(&self) -> &Signature {
66 &self.signature
67 }
68
69 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
70 Ok(DataType::List(Arc::new(Field::new_list_field(
71 arg_types[0].clone(),
72 true,
73 ))))
74 }
75
76 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
77 Ok(vec![
78 Field::new_list(
79 format_state_name(args.name, "collect_list"),
80 Field::new_list_field(args.input_fields[0].data_type().clone(), true),
81 true,
82 )
83 .into(),
84 ])
85 }
86
87 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
88 let field = &acc_args.expr_fields[0];
89 let data_type = field.data_type().clone();
90 let ignore_nulls = true;
91 Ok(Box::new(NullToEmptyListAccumulator::new(
92 ArrayAggAccumulator::try_new(&data_type, ignore_nulls)?,
93 data_type,
94 )))
95 }
96}
97
98#[derive(Debug, PartialEq, Eq, Hash)]
100pub struct SparkCollectSet {
101 signature: Signature,
102}
103
104impl Default for SparkCollectSet {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
110impl SparkCollectSet {
111 pub fn new() -> Self {
112 Self {
113 signature: Signature::any(1, Volatility::Immutable),
114 }
115 }
116}
117
118impl AggregateUDFImpl for SparkCollectSet {
119 fn as_any(&self) -> &dyn Any {
120 self
121 }
122
123 fn name(&self) -> &str {
124 "collect_set"
125 }
126
127 fn signature(&self) -> &Signature {
128 &self.signature
129 }
130
131 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
132 Ok(DataType::List(Arc::new(Field::new_list_field(
133 arg_types[0].clone(),
134 true,
135 ))))
136 }
137
138 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
139 Ok(vec![
140 Field::new_list(
141 format_state_name(args.name, "collect_set"),
142 Field::new_list_field(args.input_fields[0].data_type().clone(), true),
143 true,
144 )
145 .into(),
146 ])
147 }
148
149 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
150 let field = &acc_args.expr_fields[0];
151 let data_type = field.data_type().clone();
152 let ignore_nulls = true;
153 Ok(Box::new(NullToEmptyListAccumulator::new(
154 DistinctArrayAggAccumulator::try_new(&data_type, None, ignore_nulls)?,
155 data_type,
156 )))
157 }
158}
159
160#[derive(Debug)]
163struct NullToEmptyListAccumulator<T: Accumulator> {
164 inner: T,
165 data_type: DataType,
166}
167
168impl<T: Accumulator> NullToEmptyListAccumulator<T> {
169 pub fn new(inner: T, data_type: DataType) -> Self {
170 Self { inner, data_type }
171 }
172}
173
174impl<T: Accumulator> Accumulator for NullToEmptyListAccumulator<T> {
175 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
176 self.inner.update_batch(values)
177 }
178
179 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
180 self.inner.merge_batch(states)
181 }
182
183 fn state(&mut self) -> Result<Vec<ScalarValue>> {
184 self.inner.state()
185 }
186
187 fn evaluate(&mut self) -> Result<ScalarValue> {
188 let result = self.inner.evaluate()?;
189 if result.is_null() {
190 let empty_array = arrow::array::new_empty_array(&self.data_type);
191 Ok(SingleRowListArrayBuilder::new(empty_array).build_list_scalar())
192 } else {
193 Ok(result)
194 }
195 }
196
197 fn size(&self) -> usize {
198 self.inner.size() + self.data_type.size()
199 }
200}