datafusion_spark/function/aggregate/
collect.rs1use arrow::array::ArrayRef;
19use arrow::datatypes::{DataType, Field, FieldRef};
20use datafusion_common::utils::SingleRowListArrayBuilder;
21use datafusion_common::{Result, ScalarValue};
22use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
23use datafusion_expr::utils::format_state_name;
24use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
25use datafusion_functions_aggregate::array_agg::{
26 ArrayAggAccumulator, DistinctArrayAggAccumulator,
27};
28use std::sync::Arc;
29
30#[derive(Debug, PartialEq, Eq, Hash)]
38pub struct SparkCollectList {
39 signature: Signature,
40}
41
42impl Default for SparkCollectList {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl SparkCollectList {
49 pub fn new() -> Self {
50 Self {
51 signature: Signature::any(1, Volatility::Immutable),
52 }
53 }
54}
55
56impl AggregateUDFImpl for SparkCollectList {
57 fn name(&self) -> &str {
58 "collect_list"
59 }
60
61 fn signature(&self) -> &Signature {
62 &self.signature
63 }
64
65 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
66 Ok(DataType::List(Arc::new(Field::new_list_field(
67 arg_types[0].clone(),
68 true,
69 ))))
70 }
71
72 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
73 Ok(vec![
74 Field::new_list(
75 format_state_name(args.name, "collect_list"),
76 Field::new_list_field(args.input_fields[0].data_type().clone(), true),
77 true,
78 )
79 .into(),
80 ])
81 }
82
83 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
84 let field = &acc_args.expr_fields[0];
85 let data_type = field.data_type().clone();
86 let ignore_nulls = true;
87 Ok(Box::new(NullToEmptyListAccumulator::new(
88 ArrayAggAccumulator::try_new(&data_type, ignore_nulls)?,
89 data_type,
90 )))
91 }
92}
93
94#[derive(Debug, PartialEq, Eq, Hash)]
96pub struct SparkCollectSet {
97 signature: Signature,
98}
99
100impl Default for SparkCollectSet {
101 fn default() -> Self {
102 Self::new()
103 }
104}
105
106impl SparkCollectSet {
107 pub fn new() -> Self {
108 Self {
109 signature: Signature::any(1, Volatility::Immutable),
110 }
111 }
112}
113
114impl AggregateUDFImpl for SparkCollectSet {
115 fn name(&self) -> &str {
116 "collect_set"
117 }
118
119 fn signature(&self) -> &Signature {
120 &self.signature
121 }
122
123 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
124 Ok(DataType::List(Arc::new(Field::new_list_field(
125 arg_types[0].clone(),
126 true,
127 ))))
128 }
129
130 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
131 Ok(vec![
132 Field::new_list(
133 format_state_name(args.name, "collect_set"),
134 Field::new_list_field(args.input_fields[0].data_type().clone(), true),
135 true,
136 )
137 .into(),
138 ])
139 }
140
141 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
142 let field = &acc_args.expr_fields[0];
143 let data_type = field.data_type().clone();
144 let ignore_nulls = true;
145 Ok(Box::new(NullToEmptyListAccumulator::new(
146 DistinctArrayAggAccumulator::try_new(&data_type, None, ignore_nulls)?,
147 data_type,
148 )))
149 }
150}
151
152#[derive(Debug)]
155struct NullToEmptyListAccumulator<T: Accumulator> {
156 inner: T,
157 data_type: DataType,
158}
159
160impl<T: Accumulator> NullToEmptyListAccumulator<T> {
161 pub fn new(inner: T, data_type: DataType) -> Self {
162 Self { inner, data_type }
163 }
164}
165
166impl<T: Accumulator> Accumulator for NullToEmptyListAccumulator<T> {
167 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
168 self.inner.update_batch(values)
169 }
170
171 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
172 self.inner.merge_batch(states)
173 }
174
175 fn state(&mut self) -> Result<Vec<ScalarValue>> {
176 self.inner.state()
177 }
178
179 fn evaluate(&mut self) -> Result<ScalarValue> {
180 let result = self.inner.evaluate()?;
181 if result.is_null() {
182 let empty_array = arrow::array::new_empty_array(&self.data_type);
183 Ok(SingleRowListArrayBuilder::new(empty_array).build_list_scalar())
184 } else {
185 Ok(result)
186 }
187 }
188
189 fn size(&self) -> usize {
190 self.inner.size() + self.data_type.size()
191 }
192}