1use std::any::Any;
21use std::collections::HashSet;
22use std::fmt::{Display, Formatter};
23use std::hash::{DefaultHasher, Hash, Hasher};
24use std::mem::{size_of, size_of_val};
25
26use ahash::RandomState;
27use arrow::array::{downcast_integer, Array, ArrayRef, AsArray};
28use arrow::datatypes::{
29 ArrowNativeType, ArrowNumericType, DataType, Field, FieldRef, Int16Type, Int32Type,
30 Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
31};
32
33use datafusion_common::cast::as_list_array;
34use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
35use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
36use datafusion_expr::type_coercion::aggregates::INTEGERS;
37use datafusion_expr::utils::format_state_name;
38use datafusion_expr::{
39 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
40 Signature, Volatility,
41};
42
43use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
44use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
45use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign};
46use std::sync::LazyLock;
47
/// Expands to an `Ok(Box<PrimitiveGroupsAccumulator<$t, _>>)` expression for the
/// bitwise reduction selected by `$opr`.
///
/// * `$t`   - Arrow primitive type (e.g. `Int32Type`) the accumulator operates on.
/// * `$dt`  - data type handed to `PrimitiveGroupsAccumulator::new`.
/// * `$opr` - a `BitwiseOperationType` value (the call site passes a reference).
///
/// Each arm builds a differently-typed accumulator (distinct closure types); the
/// caller's return type is what coerces them all to `Box<dyn GroupsAccumulator>`.
macro_rules! group_accumulator_helper {
    ($t:ty, $dt:expr, $opr:expr) => {
        match $opr {
            // AND starts from all bits set (`!0`) so the first value folded in
            // is kept unchanged; a zero start would collapse every AND to zero.
            BitwiseOperationType::And => Ok(Box::new(
                PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitand_assign(y))
                    .with_starting_value(!0),
            )),
            // OR and XOR do not override the starting value.
            // NOTE(review): presumably the accumulator's default (zero), which is
            // the identity for both — confirm in PrimitiveGroupsAccumulator.
            BitwiseOperationType::Or => Ok(Box::new(
                PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitor_assign(y)),
            )),
            BitwiseOperationType::Xor => Ok(Box::new(
                PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitxor_assign(y)),
            )),
        }
    };
}
66
/// Expands to an `Ok(Box<..Accumulator<$t>>)` expression for the row-based
/// accumulator matching `$opr`.
///
/// Only XOR distinguishes `DISTINCT`: AND and OR are idempotent, so repeated
/// values cannot change their result and the plain accumulators suffice.
macro_rules! accumulator_helper {
    ($t:ty, $opr:expr, $is_distinct: expr) => {
        match $opr {
            BitwiseOperationType::And => Ok(Box::<BitAndAccumulator<$t>>::default()),
            BitwiseOperationType::Or => Ok(Box::<BitOrAccumulator<$t>>::default()),
            // `bit_xor(DISTINCT ..)` must de-duplicate before XOR-ing.
            BitwiseOperationType::Xor if $is_distinct => {
                Ok(Box::<DistinctBitXorAccumulator<$t>>::default())
            }
            BitwiseOperationType::Xor => Ok(Box::<BitXorAccumulator<$t>>::default()),
        }
    };
}
83
/// Dispatches on the aggregate's return type and expands to the matching
/// row-based accumulator (via `accumulator_helper!`).
///
/// * `$args`        - the `AccumulatorArgs` identifier (reads `return_field`, `name`).
/// * `$opr`         - the `BitwiseOperationType` to apply; must implement `Display`.
/// * `$is_distinct` - whether the aggregate was called with `DISTINCT`.
///
/// Non-integer return types produce a "not implemented" error.
macro_rules! downcast_bitwise_accumulator {
    ($args:ident, $opr:expr, $is_distinct: expr) => {
        match $args.return_field.data_type() {
            DataType::Int8 => accumulator_helper!(Int8Type, $opr, $is_distinct),
            DataType::Int16 => accumulator_helper!(Int16Type, $opr, $is_distinct),
            DataType::Int32 => accumulator_helper!(Int32Type, $opr, $is_distinct),
            DataType::Int64 => accumulator_helper!(Int64Type, $opr, $is_distinct),
            DataType::UInt8 => accumulator_helper!(UInt8Type, $opr, $is_distinct),
            DataType::UInt16 => accumulator_helper!(UInt16Type, $opr, $is_distinct),
            DataType::UInt32 => accumulator_helper!(UInt32Type, $opr, $is_distinct),
            DataType::UInt64 => accumulator_helper!(UInt64Type, $opr, $is_distinct),
            _ => {
                // Format the operation's value (`And`/`Or`/`Xor`); `stringify!`
                // would print the caller's tokens, e.g. `self.operation`.
                not_impl_err!(
                    "{} not supported for {}: {}",
                    $opr,
                    $args.name,
                    $args.return_field.data_type()
                )
            }
        }
    };
}
111
/// Generates the fluent expression function `$EXPR_FN` and the UDF accessor
/// `$AGGREGATE_UDF_FN` for one bitwise aggregate.
///
/// * `$EXPR_FN`          - expression-builder name; also used as the SQL function name.
/// * `$AGGREGATE_UDF_FN` - name of the generated accessor for the `AggregateUDF`.
/// * `$OPR_TYPE`         - the `BitwiseOperationType` variant to apply.
/// * `$DOCUMENTATION`    - `&'static Documentation` attached to the UDF.
macro_rules! make_bitwise_udaf_expr_and_func {
    ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $OPR_TYPE:expr, $DOCUMENTATION:expr) => {
        make_udaf_expr!(
            $EXPR_FN,
            expr_x,
            // The surrounding spaces keep the generated doc readable; without
            // them the words run into the stringified operation.
            concat!(
                "Returns the bitwise ",
                stringify!($OPR_TYPE),
                " of a group of values"
            ),
            $AGGREGATE_UDF_FN
        );
        create_func!(
            $EXPR_FN,
            $AGGREGATE_UDF_FN,
            BitwiseOperation::new($OPR_TYPE, stringify!($EXPR_FN), $DOCUMENTATION)
        );
    };
}
137
138static BIT_AND_DOC: LazyLock<Documentation> = LazyLock::new(|| {
139 Documentation::builder(
140 DOC_SECTION_GENERAL,
141 "Computes the bitwise AND of all non-null input values.",
142 "bit_and(expression)",
143 )
144 .with_standard_argument("expression", Some("Integer"))
145 .build()
146});
147
148fn get_bit_and_doc() -> &'static Documentation {
149 &BIT_AND_DOC
150}
151
152static BIT_OR_DOC: LazyLock<Documentation> = LazyLock::new(|| {
153 Documentation::builder(
154 DOC_SECTION_GENERAL,
155 "Computes the bitwise OR of all non-null input values.",
156 "bit_or(expression)",
157 )
158 .with_standard_argument("expression", Some("Integer"))
159 .build()
160});
161
162fn get_bit_or_doc() -> &'static Documentation {
163 &BIT_OR_DOC
164}
165
166static BIT_XOR_DOC: LazyLock<Documentation> = LazyLock::new(|| {
167 Documentation::builder(
168 DOC_SECTION_GENERAL,
169 "Computes the bitwise exclusive OR of all non-null input values.",
170 "bit_xor(expression)",
171 )
172 .with_standard_argument("expression", Some("Integer"))
173 .build()
174});
175
176fn get_bit_xor_doc() -> &'static Documentation {
177 &BIT_XOR_DOC
178}
179
// `bit_and(expr)` expression builder plus the `bit_and_udaf` accessor, backed by
// a `BitwiseOperation` performing AND.
make_bitwise_udaf_expr_and_func!(
    bit_and,
    bit_and_udaf,
    BitwiseOperationType::And,
    get_bit_and_doc()
);
// `bit_or(expr)` expression builder plus the `bit_or_udaf` accessor (OR).
make_bitwise_udaf_expr_and_func!(
    bit_or,
    bit_or_udaf,
    BitwiseOperationType::Or,
    get_bit_or_doc()
);
// `bit_xor(expr)` expression builder plus the `bit_xor_udaf` accessor (XOR).
make_bitwise_udaf_expr_and_func!(
    bit_xor,
    bit_xor_udaf,
    BitwiseOperationType::Xor,
    get_bit_xor_doc()
);
198
/// Which bitwise reduction a bitwise aggregate applies to its input values.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
enum BitwiseOperationType {
    /// Bitwise AND of all values.
    And,
    /// Bitwise OR of all values.
    Or,
    /// Bitwise exclusive OR of all values.
    Xor,
}

impl Display for BitwiseOperationType {
    /// Writes the bare variant name (`And`, `Or` or `Xor`), the same text as
    /// the derived `Debug` output; width/fill flags are not applied.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::And => "And",
            Self::Or => "Or",
            Self::Xor => "Xor",
        };
        f.write_str(label)
    }
}
212
213#[derive(Debug)]
215struct BitwiseOperation {
216 signature: Signature,
217 operation: BitwiseOperationType,
219 func_name: &'static str,
220 documentation: &'static Documentation,
221}
222
223impl BitwiseOperation {
224 pub fn new(
225 operator: BitwiseOperationType,
226 func_name: &'static str,
227 documentation: &'static Documentation,
228 ) -> Self {
229 Self {
230 operation: operator,
231 signature: Signature::uniform(1, INTEGERS.to_vec(), Volatility::Immutable),
232 func_name,
233 documentation,
234 }
235 }
236}
237
238impl AggregateUDFImpl for BitwiseOperation {
239 fn as_any(&self) -> &dyn Any {
240 self
241 }
242
243 fn name(&self) -> &str {
244 self.func_name
245 }
246
247 fn signature(&self) -> &Signature {
248 &self.signature
249 }
250
251 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
252 let arg_type = &arg_types[0];
253 if !arg_type.is_integer() {
254 return exec_err!(
255 "[return_type] {} not supported for {}",
256 self.name(),
257 arg_type
258 );
259 }
260 Ok(arg_type.clone())
261 }
262
263 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
264 downcast_bitwise_accumulator!(acc_args, self.operation, acc_args.is_distinct)
265 }
266
267 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
268 if self.operation == BitwiseOperationType::Xor && args.is_distinct {
269 Ok(vec![Field::new_list(
270 format_state_name(
271 args.name,
272 format!("{} distinct", self.name()).as_str(),
273 ),
274 Field::new_list_field(args.return_type().clone(), true),
276 false,
277 )
278 .into()])
279 } else {
280 Ok(vec![Field::new(
281 format_state_name(args.name, self.name()),
282 args.return_field.data_type().clone(),
283 true,
284 )
285 .into()])
286 }
287 }
288
289 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
290 true
291 }
292
293 fn create_groups_accumulator(
294 &self,
295 args: AccumulatorArgs,
296 ) -> Result<Box<dyn GroupsAccumulator>> {
297 let data_type = args.return_field.data_type();
298 let operation = &self.operation;
299 downcast_integer! {
300 data_type => (group_accumulator_helper, data_type, operation),
301 _ => not_impl_err!(
302 "GroupsAccumulator not supported for {} with {}",
303 self.name(),
304 data_type
305 ),
306 }
307 }
308
309 fn reverse_expr(&self) -> ReversedUDAF {
310 ReversedUDAF::Identical
311 }
312
313 fn documentation(&self) -> Option<&Documentation> {
314 Some(self.documentation)
315 }
316
317 fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
318 let Some(other) = other.as_any().downcast_ref::<Self>() else {
319 return false;
320 };
321 let Self {
322 signature,
323 operation,
324 func_name,
325 documentation,
326 } = self;
327 signature == &other.signature
328 && operation == &other.operation
329 && func_name == &other.func_name
330 && documentation == &other.documentation
331 }
332
333 fn hash_value(&self) -> u64 {
334 let Self {
335 signature,
336 operation,
337 func_name,
338 documentation,
339 } = self;
340 let mut hasher = DefaultHasher::new();
341 std::any::type_name::<Self>().hash(&mut hasher);
342 signature.hash(&mut hasher);
343 operation.hash(&mut hasher);
344 func_name.hash(&mut hasher);
345 documentation.hash(&mut hasher);
346 hasher.finish()
347 }
348}
349
350struct BitAndAccumulator<T: ArrowNumericType> {
351 value: Option<T::Native>,
352}
353
354impl<T: ArrowNumericType> std::fmt::Debug for BitAndAccumulator<T> {
355 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
356 write!(f, "BitAndAccumulator({})", T::DATA_TYPE)
357 }
358}
359
360impl<T: ArrowNumericType> Default for BitAndAccumulator<T> {
361 fn default() -> Self {
362 Self { value: None }
363 }
364}
365
366impl<T: ArrowNumericType> Accumulator for BitAndAccumulator<T>
367where
368 T::Native: std::ops::BitAnd<Output = T::Native>,
369{
370 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
371 if let Some(x) = arrow::compute::bit_and(values[0].as_primitive::<T>()) {
372 let v = self.value.get_or_insert(x);
373 *v = *v & x;
374 }
375 Ok(())
376 }
377
378 fn evaluate(&mut self) -> Result<ScalarValue> {
379 ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
380 }
381
382 fn size(&self) -> usize {
383 size_of_val(self)
384 }
385
386 fn state(&mut self) -> Result<Vec<ScalarValue>> {
387 Ok(vec![self.evaluate()?])
388 }
389
390 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
391 self.update_batch(states)
392 }
393}
394
395struct BitOrAccumulator<T: ArrowNumericType> {
396 value: Option<T::Native>,
397}
398
399impl<T: ArrowNumericType> std::fmt::Debug for BitOrAccumulator<T> {
400 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
401 write!(f, "BitOrAccumulator({})", T::DATA_TYPE)
402 }
403}
404
405impl<T: ArrowNumericType> Default for BitOrAccumulator<T> {
406 fn default() -> Self {
407 Self { value: None }
408 }
409}
410
411impl<T: ArrowNumericType> Accumulator for BitOrAccumulator<T>
412where
413 T::Native: std::ops::BitOr<Output = T::Native>,
414{
415 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
416 if let Some(x) = arrow::compute::bit_or(values[0].as_primitive::<T>()) {
417 let v = self.value.get_or_insert(T::Native::usize_as(0));
418 *v = *v | x;
419 }
420 Ok(())
421 }
422
423 fn evaluate(&mut self) -> Result<ScalarValue> {
424 ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
425 }
426
427 fn size(&self) -> usize {
428 size_of_val(self)
429 }
430
431 fn state(&mut self) -> Result<Vec<ScalarValue>> {
432 Ok(vec![self.evaluate()?])
433 }
434
435 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
436 self.update_batch(states)
437 }
438}
439
440struct BitXorAccumulator<T: ArrowNumericType> {
441 value: Option<T::Native>,
442}
443
444impl<T: ArrowNumericType> std::fmt::Debug for BitXorAccumulator<T> {
445 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
446 write!(f, "BitXorAccumulator({})", T::DATA_TYPE)
447 }
448}
449
450impl<T: ArrowNumericType> Default for BitXorAccumulator<T> {
451 fn default() -> Self {
452 Self { value: None }
453 }
454}
455
456impl<T: ArrowNumericType> Accumulator for BitXorAccumulator<T>
457where
458 T::Native: std::ops::BitXor<Output = T::Native>,
459{
460 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
461 if let Some(x) = arrow::compute::bit_xor(values[0].as_primitive::<T>()) {
462 let v = self.value.get_or_insert(T::Native::usize_as(0));
463 *v = *v ^ x;
464 }
465 Ok(())
466 }
467
468 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
469 self.update_batch(values)
471 }
472
473 fn supports_retract_batch(&self) -> bool {
474 true
475 }
476
477 fn evaluate(&mut self) -> Result<ScalarValue> {
478 ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
479 }
480
481 fn size(&self) -> usize {
482 size_of_val(self)
483 }
484
485 fn state(&mut self) -> Result<Vec<ScalarValue>> {
486 Ok(vec![self.evaluate()?])
487 }
488
489 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
490 self.update_batch(states)
491 }
492}
493
494struct DistinctBitXorAccumulator<T: ArrowNumericType> {
495 values: HashSet<T::Native, RandomState>,
496}
497
498impl<T: ArrowNumericType> std::fmt::Debug for DistinctBitXorAccumulator<T> {
499 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
500 write!(f, "DistinctBitXorAccumulator({})", T::DATA_TYPE)
501 }
502}
503
504impl<T: ArrowNumericType> Default for DistinctBitXorAccumulator<T> {
505 fn default() -> Self {
506 Self {
507 values: HashSet::default(),
508 }
509 }
510}
511
512impl<T: ArrowNumericType> Accumulator for DistinctBitXorAccumulator<T>
513where
514 T::Native: std::ops::BitXor<Output = T::Native> + Hash + Eq,
515{
516 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
517 if values.is_empty() {
518 return Ok(());
519 }
520
521 let array = values[0].as_primitive::<T>();
522 match array.nulls().filter(|x| x.null_count() > 0) {
523 Some(n) => {
524 for idx in n.valid_indices() {
525 self.values.insert(array.value(idx));
526 }
527 }
528 None => array.values().iter().for_each(|x| {
529 self.values.insert(*x);
530 }),
531 }
532 Ok(())
533 }
534
535 fn evaluate(&mut self) -> Result<ScalarValue> {
536 let mut acc = T::Native::usize_as(0);
537 for distinct_value in self.values.iter() {
538 acc = acc ^ *distinct_value;
539 }
540 let v = (!self.values.is_empty()).then_some(acc);
541 ScalarValue::new_primitive::<T>(v, &T::DATA_TYPE)
542 }
543
544 fn size(&self) -> usize {
545 size_of_val(self) + self.values.capacity() * size_of::<T::Native>()
546 }
547
548 fn state(&mut self) -> Result<Vec<ScalarValue>> {
549 let state_out = {
552 let values = self
553 .values
554 .iter()
555 .map(|x| ScalarValue::new_primitive::<T>(Some(*x), &T::DATA_TYPE))
556 .collect::<Result<Vec<_>>>()?;
557
558 let arr = ScalarValue::new_list_nullable(&values, &T::DATA_TYPE);
559 vec![ScalarValue::List(arr)]
560 };
561 Ok(state_out)
562 }
563
564 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
565 if let Some(state) = states.first() {
566 let list_arr = as_list_array(state)?;
567 for arr in list_arr.iter().flatten() {
568 self.update_batch(&[arr])?;
569 }
570 }
571 Ok(())
572 }
573}
574
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, UInt64Array};
    use arrow::datatypes::UInt64Type;
    use datafusion_common::ScalarValue;

    use crate::bit_and_or_xor::BitXorAccumulator;
    use datafusion_expr::Accumulator;

    /// XOR is its own inverse, so retracting a batch must undo exactly that
    /// batch's contribution.
    #[test]
    fn test_bit_xor_accumulator() {
        let added: ArrayRef = Arc::new(UInt64Array::from(vec![1_u64, 2]));
        let retracted: ArrayRef = Arc::new(UInt64Array::from(vec![1_u64]));
        let mut accumulator = BitXorAccumulator::<UInt64Type> { value: None };

        // 1 ^ 2 == 3
        accumulator.update_batch(&[added]).unwrap();
        assert_eq!(
            accumulator.evaluate().unwrap(),
            ScalarValue::UInt64(Some(3))
        );

        // 3 ^ 1 == 2
        accumulator.retract_batch(&[retracted]).unwrap();
        assert_eq!(
            accumulator.evaluate().unwrap(),
            ScalarValue::UInt64(Some(2))
        );
    }
}