datafusion_functions_aggregate_common/aggregate/groups_accumulator/
prim_op.rs1use std::mem::size_of;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray};
22use arrow::buffer::NullBuffer;
23use arrow::compute;
24use arrow::datatypes::ArrowPrimitiveType;
25use arrow::datatypes::DataType;
26use datafusion_common::{internal_datafusion_err, DataFusionError, Result};
27use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
28
29use super::accumulate::NullState;
30
31#[derive(Debug)]
41pub struct PrimitiveGroupsAccumulator<T, F>
42where
43 T: ArrowPrimitiveType + Send,
44 F: Fn(&mut T::Native, T::Native) + Send + Sync,
45{
46 values: Vec<T::Native>,
48
49 data_type: DataType,
51
52 starting_value: T::Native,
54
55 null_state: NullState,
57
58 prim_fn: F,
60}
61
62impl<T, F> PrimitiveGroupsAccumulator<T, F>
63where
64 T: ArrowPrimitiveType + Send,
65 F: Fn(&mut T::Native, T::Native) + Send + Sync,
66{
67 pub fn new(data_type: &DataType, prim_fn: F) -> Self {
68 Self {
69 values: vec![],
70 data_type: data_type.clone(),
71 null_state: NullState::new(),
72 starting_value: T::default_value(),
73 prim_fn,
74 }
75 }
76
77 pub fn with_starting_value(mut self, starting_value: T::Native) -> Self {
79 self.starting_value = starting_value;
80 self
81 }
82}
83
84impl<T, F> GroupsAccumulator for PrimitiveGroupsAccumulator<T, F>
85where
86 T: ArrowPrimitiveType + Send,
87 F: Fn(&mut T::Native, T::Native) + Send + Sync,
88{
89 fn update_batch(
90 &mut self,
91 values: &[ArrayRef],
92 group_indices: &[usize],
93 opt_filter: Option<&BooleanArray>,
94 total_num_groups: usize,
95 ) -> Result<()> {
96 assert_eq!(values.len(), 1, "single argument to update_batch");
97 let values = values[0].as_primitive::<T>();
98
99 self.values.resize(total_num_groups, self.starting_value);
101
102 self.null_state.accumulate(
104 group_indices,
105 values,
106 opt_filter,
107 total_num_groups,
108 |group_index, new_value| {
109 let value = &mut self.values[group_index];
110 (self.prim_fn)(value, new_value);
111 },
112 );
113
114 Ok(())
115 }
116
117 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
118 let values = emit_to.take_needed(&mut self.values);
119 let nulls = self.null_state.build(emit_to);
120 let values = PrimitiveArray::<T>::new(values.into(), Some(nulls)) .with_data_type(self.data_type.clone());
122 Ok(Arc::new(values))
123 }
124
125 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
126 self.evaluate(emit_to).map(|arr| vec![arr])
127 }
128
129 fn merge_batch(
130 &mut self,
131 values: &[ArrayRef],
132 group_indices: &[usize],
133 opt_filter: Option<&BooleanArray>,
134 total_num_groups: usize,
135 ) -> Result<()> {
136 self.update_batch(values, group_indices, opt_filter, total_num_groups)
138 }
139
140 fn convert_to_state(
146 &self,
147 values: &[ArrayRef],
148 opt_filter: Option<&BooleanArray>,
149 ) -> Result<Vec<ArrayRef>> {
150 let values = values[0].as_primitive::<T>().clone();
151
152 let initial_state =
154 PrimitiveArray::<T>::from_value(self.starting_value, values.len());
155
156 let values = match opt_filter {
158 None => values,
159 Some(filter) => {
160 let (filter_values, filter_nulls) = filter.clone().into_parts();
161 let filter_bool = match filter_nulls {
163 Some(filter_nulls) => filter_nulls.inner() & &filter_values,
164 None => filter_values,
165 };
166 let filter_nulls = NullBuffer::from(filter_bool);
167
168 let (dt, values_buf, original_nulls) = values.into_parts();
171 let nulls_buf =
172 NullBuffer::union(original_nulls.as_ref(), Some(&filter_nulls));
173 PrimitiveArray::<T>::new(values_buf, nulls_buf).with_data_type(dt)
174 }
175 };
176
177 let state_values = compute::binary_mut(initial_state, &values, |mut x, y| {
178 (self.prim_fn)(&mut x, y);
179 x
180 });
181 let state_values = state_values
182 .map_err(|_| {
183 internal_datafusion_err!(
184 "initial_values underlying buffer must not be shared"
185 )
186 })?
187 .map_err(DataFusionError::from)?
188 .with_data_type(self.data_type.clone());
189
190 Ok(vec![Arc::new(state_values)])
191 }
192
193 fn supports_convert_to_state(&self) -> bool {
194 true
195 }
196
197 fn size(&self) -> usize {
198 self.values.capacity() * size_of::<T::Native>() + self.null_state.size()
199 }
200}