1use std::any::Any;
21use std::collections::HashSet;
22use std::fmt::{Display, Formatter};
23use std::hash::Hash;
24use std::mem::{size_of, size_of_val};
25
26use ahash::RandomState;
27use arrow::array::{downcast_integer, Array, ArrayRef, AsArray};
28use arrow::datatypes::{
29 ArrowNativeType, ArrowNumericType, DataType, Field, FieldRef, Int16Type, Int32Type,
30 Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
31};
32
33use datafusion_common::cast::as_list_array;
34use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
35use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
36use datafusion_expr::type_coercion::aggregates::INTEGERS;
37use datafusion_expr::utils::format_state_name;
38use datafusion_expr::{
39 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
40 Signature, Volatility,
41};
42
43use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
44use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
45use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign};
46use std::sync::LazyLock;
47
/// Builds the vectorized [`GroupsAccumulator`] for one bitwise operation over a
/// single Arrow integer type.
///
/// * `$native` - the Arrow primitive type (e.g. `Int32Type`).
/// * `$data_type` - the `&DataType` passed to `PrimitiveGroupsAccumulator::new`.
/// * `$op` - the `BitwiseOperationType` to compute.
///
/// Invoked through `downcast_integer!` from `create_groups_accumulator`.
macro_rules! group_accumulator_helper {
    ($native:ty, $data_type:expr, $op:expr) => {
        match $op {
            BitwiseOperationType::And => {
                // AND's identity has every bit set, so each group starts at `!0`.
                let groups = PrimitiveGroupsAccumulator::<$native, _>::new(
                    $data_type,
                    |acc, v| acc.bitand_assign(v),
                )
                .with_starting_value(!0);
                Ok(Box::new(groups))
            }
            BitwiseOperationType::Or => {
                // Relies on PrimitiveGroupsAccumulator's default starting value
                // (assumed to be zero, OR's identity) — confirm in prim_op.
                let groups = PrimitiveGroupsAccumulator::<$native, _>::new(
                    $data_type,
                    |acc, v| acc.bitor_assign(v),
                );
                Ok(Box::new(groups))
            }
            BitwiseOperationType::Xor => {
                // Same default-starting-value assumption as OR (zero is XOR's identity).
                let groups = PrimitiveGroupsAccumulator::<$native, _>::new(
                    $data_type,
                    |acc, v| acc.bitxor_assign(v),
                );
                Ok(Box::new(groups))
            }
        }
    };
}
66
/// Builds the row-based [`Accumulator`] for one bitwise operation over a single
/// Arrow integer type.
///
/// * `$native` - the Arrow primitive type (e.g. `UInt16Type`).
/// * `$op` - the `BitwiseOperationType` to compute.
/// * `$distinct` - whether `DISTINCT` was requested.
///
/// Only XOR has a dedicated DISTINCT accumulator. AND and OR are idempotent, so
/// repeated values do not change their result and the plain accumulators are used.
macro_rules! accumulator_helper {
    ($native:ty, $op:expr, $distinct:expr) => {
        match $op {
            BitwiseOperationType::And => Ok(Box::<BitAndAccumulator<$native>>::default()),
            BitwiseOperationType::Or => Ok(Box::<BitOrAccumulator<$native>>::default()),
            BitwiseOperationType::Xor if $distinct => {
                Ok(Box::<DistinctBitXorAccumulator<$native>>::default())
            }
            BitwiseOperationType::Xor => Ok(Box::<BitXorAccumulator<$native>>::default()),
        }
    };
}
83
/// Creates the row-based [`Accumulator`] for a bitwise aggregate by dispatching on
/// the aggregate's return type (one of the eight Arrow integer types).
///
/// * `$args` - an `AccumulatorArgs` identifier. Its `return_field` selects the
///   concrete Arrow type, and its `name` appears in the error message.
/// * `$opr` - the `BitwiseOperationType` to compute.
/// * `$is_distinct` - whether the aggregate was called with `DISTINCT`.
///
/// Any other return type produces a `not_impl_err!` error naming the operation.
macro_rules! downcast_bitwise_accumulator {
    ($args:ident, $opr:expr, $is_distinct: expr) => {
        match $args.return_field.data_type() {
            DataType::Int8 => accumulator_helper!(Int8Type, $opr, $is_distinct),
            DataType::Int16 => accumulator_helper!(Int16Type, $opr, $is_distinct),
            DataType::Int32 => accumulator_helper!(Int32Type, $opr, $is_distinct),
            DataType::Int64 => accumulator_helper!(Int64Type, $opr, $is_distinct),
            DataType::UInt8 => accumulator_helper!(UInt8Type, $opr, $is_distinct),
            DataType::UInt16 => accumulator_helper!(UInt16Type, $opr, $is_distinct),
            DataType::UInt32 => accumulator_helper!(UInt32Type, $opr, $is_distinct),
            DataType::UInt64 => accumulator_helper!(UInt64Type, $opr, $is_distinct),
            _ => {
                // Format the operation's *value* via its `Display` impl.
                // `stringify!($opr)` would print the caller's source text
                // (e.g. "self.operation") instead of "And"/"Or"/"Xor".
                not_impl_err!(
                    "{} not supported for {}: {}",
                    $opr,
                    $args.name,
                    $args.return_field.data_type()
                )
            }
        }
    };
}
111
/// Declares one bitwise aggregate.
///
/// Forwards to the crate's `make_udaf_expr!` and `create_func!` macros (defined
/// elsewhere) to generate the expression function `$EXPR_FN` and the UDAF
/// accessor `$AGGREGATE_UDF_FN`. Both are backed by a `BitwiseOperation` whose
/// name is `stringify!($EXPR_FN)`.
///
/// * `$EXPR_FN` - expression-function identifier (e.g. `bit_and`).
/// * `$AGGREGATE_UDF_FN` - UDAF accessor identifier (e.g. `bit_and_udaf`).
/// * `$OPR_TYPE` - the `BitwiseOperationType` variant.
/// * `$DOCUMENTATION` - a `&'static Documentation` for the function.
macro_rules! make_bitwise_udaf_expr_and_func {
    ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $OPR_TYPE:expr, $DOCUMENTATION:expr) => {
        make_udaf_expr!(
            $EXPR_FN,
            expr_x,
            // The surrounding spaces keep the generated doc text readable.
            // Without them the pieces run together
            // ("Returns the bitwiseBitwiseOperationType::Andof ...").
            concat!(
                "Returns the bitwise ",
                stringify!($OPR_TYPE),
                " of a group of values"
            ),
            $AGGREGATE_UDF_FN
        );
        create_func!(
            $EXPR_FN,
            $AGGREGATE_UDF_FN,
            BitwiseOperation::new($OPR_TYPE, stringify!($EXPR_FN), $DOCUMENTATION)
        );
    };
}
137
138static BIT_AND_DOC: LazyLock<Documentation> = LazyLock::new(|| {
139 Documentation::builder(
140 DOC_SECTION_GENERAL,
141 "Computes the bitwise AND of all non-null input values.",
142 "bit_and(expression)",
143 )
144 .with_standard_argument("expression", Some("Integer"))
145 .build()
146});
147
148fn get_bit_and_doc() -> &'static Documentation {
149 &BIT_AND_DOC
150}
151
152static BIT_OR_DOC: LazyLock<Documentation> = LazyLock::new(|| {
153 Documentation::builder(
154 DOC_SECTION_GENERAL,
155 "Computes the bitwise OR of all non-null input values.",
156 "bit_or(expression)",
157 )
158 .with_standard_argument("expression", Some("Integer"))
159 .build()
160});
161
162fn get_bit_or_doc() -> &'static Documentation {
163 &BIT_OR_DOC
164}
165
166static BIT_XOR_DOC: LazyLock<Documentation> = LazyLock::new(|| {
167 Documentation::builder(
168 DOC_SECTION_GENERAL,
169 "Computes the bitwise exclusive OR of all non-null input values.",
170 "bit_xor(expression)",
171 )
172 .with_standard_argument("expression", Some("Integer"))
173 .build()
174});
175
176fn get_bit_xor_doc() -> &'static Documentation {
177 &BIT_XOR_DOC
178}
179
// Declares `bit_and` (expression function) and `bit_and_udaf` (UDAF accessor),
// computing a bitwise AND and documented by `get_bit_and_doc()`.
make_bitwise_udaf_expr_and_func!(
    bit_and,
    bit_and_udaf,
    BitwiseOperationType::And,
    get_bit_and_doc()
);
// Declares `bit_or` / `bit_or_udaf`, computing a bitwise OR.
make_bitwise_udaf_expr_and_func!(
    bit_or,
    bit_or_udaf,
    BitwiseOperationType::Or,
    get_bit_or_doc()
);
// Declares `bit_xor` / `bit_xor_udaf`, computing a bitwise XOR
// (the only one with a dedicated DISTINCT accumulator).
make_bitwise_udaf_expr_and_func!(
    bit_xor,
    bit_xor_udaf,
    BitwiseOperationType::Xor,
    get_bit_xor_doc()
);
198
/// The bitwise reduction performed by a `BitwiseOperation` aggregate.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum BitwiseOperationType {
    /// Bitwise AND (`bit_and`).
    And,
    /// Bitwise OR (`bit_or`).
    Or,
    /// Bitwise exclusive OR (`bit_xor`).
    Xor,
}

/// Renders the bare variant name (`And`, `Or`, `Xor`), matching the derived `Debug` output.
impl Display for BitwiseOperationType {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            BitwiseOperationType::And => "And",
            BitwiseOperationType::Or => "Or",
            BitwiseOperationType::Xor => "Xor",
        };
        f.write_str(label)
    }
}
212
213#[derive(Debug, PartialEq, Eq, Hash)]
215struct BitwiseOperation {
216 signature: Signature,
217 operation: BitwiseOperationType,
219 func_name: &'static str,
220 documentation: &'static Documentation,
221}
222
223impl BitwiseOperation {
224 pub fn new(
225 operator: BitwiseOperationType,
226 func_name: &'static str,
227 documentation: &'static Documentation,
228 ) -> Self {
229 Self {
230 operation: operator,
231 signature: Signature::uniform(1, INTEGERS.to_vec(), Volatility::Immutable),
232 func_name,
233 documentation,
234 }
235 }
236}
237
238impl AggregateUDFImpl for BitwiseOperation {
239 fn as_any(&self) -> &dyn Any {
240 self
241 }
242
243 fn name(&self) -> &str {
244 self.func_name
245 }
246
247 fn signature(&self) -> &Signature {
248 &self.signature
249 }
250
251 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
252 let arg_type = &arg_types[0];
253 if !arg_type.is_integer() {
254 return exec_err!(
255 "[return_type] {} not supported for {}",
256 self.name(),
257 arg_type
258 );
259 }
260 Ok(arg_type.clone())
261 }
262
263 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
264 downcast_bitwise_accumulator!(acc_args, self.operation, acc_args.is_distinct)
265 }
266
267 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
268 if self.operation == BitwiseOperationType::Xor && args.is_distinct {
269 Ok(vec![Field::new_list(
270 format_state_name(
271 args.name,
272 format!("{} distinct", self.name()).as_str(),
273 ),
274 Field::new_list_field(args.return_type().clone(), true),
276 false,
277 )
278 .into()])
279 } else {
280 Ok(vec![Field::new(
281 format_state_name(args.name, self.name()),
282 args.return_field.data_type().clone(),
283 true,
284 )
285 .into()])
286 }
287 }
288
289 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
290 true
291 }
292
293 fn create_groups_accumulator(
294 &self,
295 args: AccumulatorArgs,
296 ) -> Result<Box<dyn GroupsAccumulator>> {
297 let data_type = args.return_field.data_type();
298 let operation = &self.operation;
299 downcast_integer! {
300 data_type => (group_accumulator_helper, data_type, operation),
301 _ => not_impl_err!(
302 "GroupsAccumulator not supported for {} with {}",
303 self.name(),
304 data_type
305 ),
306 }
307 }
308
309 fn reverse_expr(&self) -> ReversedUDAF {
310 ReversedUDAF::Identical
311 }
312
313 fn documentation(&self) -> Option<&Documentation> {
314 Some(self.documentation)
315 }
316}
317
318struct BitAndAccumulator<T: ArrowNumericType> {
319 value: Option<T::Native>,
320}
321
322impl<T: ArrowNumericType> std::fmt::Debug for BitAndAccumulator<T> {
323 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
324 write!(f, "BitAndAccumulator({})", T::DATA_TYPE)
325 }
326}
327
328impl<T: ArrowNumericType> Default for BitAndAccumulator<T> {
329 fn default() -> Self {
330 Self { value: None }
331 }
332}
333
334impl<T: ArrowNumericType> Accumulator for BitAndAccumulator<T>
335where
336 T::Native: std::ops::BitAnd<Output = T::Native>,
337{
338 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
339 if let Some(x) = arrow::compute::bit_and(values[0].as_primitive::<T>()) {
340 let v = self.value.get_or_insert(x);
341 *v = *v & x;
342 }
343 Ok(())
344 }
345
346 fn evaluate(&mut self) -> Result<ScalarValue> {
347 ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
348 }
349
350 fn size(&self) -> usize {
351 size_of_val(self)
352 }
353
354 fn state(&mut self) -> Result<Vec<ScalarValue>> {
355 Ok(vec![self.evaluate()?])
356 }
357
358 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
359 self.update_batch(states)
360 }
361}
362
363struct BitOrAccumulator<T: ArrowNumericType> {
364 value: Option<T::Native>,
365}
366
367impl<T: ArrowNumericType> std::fmt::Debug for BitOrAccumulator<T> {
368 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
369 write!(f, "BitOrAccumulator({})", T::DATA_TYPE)
370 }
371}
372
373impl<T: ArrowNumericType> Default for BitOrAccumulator<T> {
374 fn default() -> Self {
375 Self { value: None }
376 }
377}
378
379impl<T: ArrowNumericType> Accumulator for BitOrAccumulator<T>
380where
381 T::Native: std::ops::BitOr<Output = T::Native>,
382{
383 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
384 if let Some(x) = arrow::compute::bit_or(values[0].as_primitive::<T>()) {
385 let v = self.value.get_or_insert(T::Native::usize_as(0));
386 *v = *v | x;
387 }
388 Ok(())
389 }
390
391 fn evaluate(&mut self) -> Result<ScalarValue> {
392 ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
393 }
394
395 fn size(&self) -> usize {
396 size_of_val(self)
397 }
398
399 fn state(&mut self) -> Result<Vec<ScalarValue>> {
400 Ok(vec![self.evaluate()?])
401 }
402
403 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
404 self.update_batch(states)
405 }
406}
407
408struct BitXorAccumulator<T: ArrowNumericType> {
409 value: Option<T::Native>,
410}
411
412impl<T: ArrowNumericType> std::fmt::Debug for BitXorAccumulator<T> {
413 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
414 write!(f, "BitXorAccumulator({})", T::DATA_TYPE)
415 }
416}
417
418impl<T: ArrowNumericType> Default for BitXorAccumulator<T> {
419 fn default() -> Self {
420 Self { value: None }
421 }
422}
423
424impl<T: ArrowNumericType> Accumulator for BitXorAccumulator<T>
425where
426 T::Native: std::ops::BitXor<Output = T::Native>,
427{
428 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
429 if let Some(x) = arrow::compute::bit_xor(values[0].as_primitive::<T>()) {
430 let v = self.value.get_or_insert(T::Native::usize_as(0));
431 *v = *v ^ x;
432 }
433 Ok(())
434 }
435
436 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
437 self.update_batch(values)
439 }
440
441 fn supports_retract_batch(&self) -> bool {
442 true
443 }
444
445 fn evaluate(&mut self) -> Result<ScalarValue> {
446 ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
447 }
448
449 fn size(&self) -> usize {
450 size_of_val(self)
451 }
452
453 fn state(&mut self) -> Result<Vec<ScalarValue>> {
454 Ok(vec![self.evaluate()?])
455 }
456
457 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
458 self.update_batch(states)
459 }
460}
461
462struct DistinctBitXorAccumulator<T: ArrowNumericType> {
463 values: HashSet<T::Native, RandomState>,
464}
465
466impl<T: ArrowNumericType> std::fmt::Debug for DistinctBitXorAccumulator<T> {
467 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
468 write!(f, "DistinctBitXorAccumulator({})", T::DATA_TYPE)
469 }
470}
471
472impl<T: ArrowNumericType> Default for DistinctBitXorAccumulator<T> {
473 fn default() -> Self {
474 Self {
475 values: HashSet::default(),
476 }
477 }
478}
479
480impl<T: ArrowNumericType> Accumulator for DistinctBitXorAccumulator<T>
481where
482 T::Native: std::ops::BitXor<Output = T::Native> + Hash + Eq,
483{
484 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
485 if values.is_empty() {
486 return Ok(());
487 }
488
489 let array = values[0].as_primitive::<T>();
490 match array.nulls().filter(|x| x.null_count() > 0) {
491 Some(n) => {
492 for idx in n.valid_indices() {
493 self.values.insert(array.value(idx));
494 }
495 }
496 None => array.values().iter().for_each(|x| {
497 self.values.insert(*x);
498 }),
499 }
500 Ok(())
501 }
502
503 fn evaluate(&mut self) -> Result<ScalarValue> {
504 let mut acc = T::Native::usize_as(0);
505 for distinct_value in self.values.iter() {
506 acc = acc ^ *distinct_value;
507 }
508 let v = (!self.values.is_empty()).then_some(acc);
509 ScalarValue::new_primitive::<T>(v, &T::DATA_TYPE)
510 }
511
512 fn size(&self) -> usize {
513 size_of_val(self) + self.values.capacity() * size_of::<T::Native>()
514 }
515
516 fn state(&mut self) -> Result<Vec<ScalarValue>> {
517 let state_out = {
520 let values = self
521 .values
522 .iter()
523 .map(|x| ScalarValue::new_primitive::<T>(Some(*x), &T::DATA_TYPE))
524 .collect::<Result<Vec<_>>>()?;
525
526 let arr = ScalarValue::new_list_nullable(&values, &T::DATA_TYPE);
527 vec![ScalarValue::List(arr)]
528 };
529 Ok(state_out)
530 }
531
532 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
533 if let Some(state) = states.first() {
534 let list_arr = as_list_array(state)?;
535 for arr in list_arr.iter().flatten() {
536 self.update_batch(&[arr])?;
537 }
538 }
539 Ok(())
540 }
541}
542
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, UInt64Array};
    use arrow::datatypes::UInt64Type;
    use datafusion_common::ScalarValue;

    use crate::bit_and_or_xor::BitXorAccumulator;
    use datafusion_expr::Accumulator;

    /// Wraps `u64` values in a non-null `UInt64Array` batch.
    fn u64_batch(values: &[u64]) -> ArrayRef {
        Arc::new(values.iter().copied().collect::<UInt64Array>())
    }

    #[test]
    fn test_bit_xor_accumulator() {
        let mut accumulator = BitXorAccumulator::<UInt64Type> { value: None };

        // 1 ^ 2 == 3
        accumulator.update_batch(&[u64_batch(&[1, 2])]).unwrap();
        assert_eq!(
            accumulator.evaluate().unwrap(),
            ScalarValue::UInt64(Some(3))
        );

        // Retracting 1 XORs it back out: 3 ^ 1 == 2
        accumulator.retract_batch(&[u64_batch(&[1])]).unwrap();
        assert_eq!(
            accumulator.evaluate().unwrap(),
            ScalarValue::UInt64(Some(2))
        );
    }
}
579}