1use std::any::Any;
21use std::collections::HashSet;
22use std::fmt::{Display, Formatter};
23use std::mem::{size_of, size_of_val};
24
25use ahash::RandomState;
26use arrow::array::{downcast_integer, Array, ArrayRef, AsArray};
27use arrow::datatypes::{
28 ArrowNativeType, ArrowNumericType, DataType, Field, Int16Type, Int32Type, Int64Type,
29 Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
30};
31
32use datafusion_common::cast::as_list_array;
33use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
34use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
35use datafusion_expr::type_coercion::aggregates::INTEGERS;
36use datafusion_expr::utils::format_state_name;
37use datafusion_expr::{
38 Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
39 Signature, Volatility,
40};
41
42use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
43use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
44use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign};
45use std::sync::LazyLock;
46
/// Expands to a `match` on `$opr` that builds a boxed
/// [`PrimitiveGroupsAccumulator`] over the primitive type `$t` with data type
/// `$dt`, combining values with the in-place bitwise operator selected by the
/// [`BitwiseOperationType`] variant.
macro_rules! group_accumulator_helper {
    ($t:ty, $dt:expr, $opr:expr) => {
        match $opr {
            BitwiseOperationType::And => {
                // `!0` has every bit set, so AND-ing it with the first value
                // yields that value unchanged.
                let grouped_and =
                    PrimitiveGroupsAccumulator::<$t, _>::new($dt, |acc, val| {
                        acc.bitand_assign(val)
                    })
                    .with_starting_value(!0);
                Ok(Box::new(grouped_and))
            }
            BitwiseOperationType::Or => {
                let grouped_or =
                    PrimitiveGroupsAccumulator::<$t, _>::new($dt, |acc, val| {
                        acc.bitor_assign(val)
                    });
                Ok(Box::new(grouped_or))
            }
            BitwiseOperationType::Xor => {
                let grouped_xor =
                    PrimitiveGroupsAccumulator::<$t, _>::new($dt, |acc, val| {
                        acc.bitxor_assign(val)
                    });
                Ok(Box::new(grouped_xor))
            }
        }
    };
}
65
/// Expands to a `match` on `$opr` that yields a boxed, default-constructed
/// row accumulator for the primitive type `$t`. For XOR, `$is_distinct`
/// selects between the de-duplicating and the plain accumulator.
macro_rules! accumulator_helper {
    ($t:ty, $opr:expr, $is_distinct: expr) => {
        match $opr {
            BitwiseOperationType::And => {
                Ok(Box::new(BitAndAccumulator::<$t>::default()))
            }
            BitwiseOperationType::Or => {
                Ok(Box::new(BitOrAccumulator::<$t>::default()))
            }
            BitwiseOperationType::Xor if $is_distinct => {
                Ok(Box::new(DistinctBitXorAccumulator::<$t>::default()))
            }
            BitwiseOperationType::Xor => {
                Ok(Box::new(BitXorAccumulator::<$t>::default()))
            }
        }
    };
}
82
/// Dispatches on `$args.return_type` and expands to the matching
/// `accumulator_helper!` invocation for each supported integer type.
///
/// Any other data type yields a "not implemented" error naming the operation
/// (via its `Display` impl), `$args.name`, and the offending type.
macro_rules! downcast_bitwise_accumulator {
    ($args:ident, $opr:expr, $is_distinct: expr) => {
        match $args.return_type {
            DataType::Int8 => accumulator_helper!(Int8Type, $opr, $is_distinct),
            DataType::Int16 => accumulator_helper!(Int16Type, $opr, $is_distinct),
            DataType::Int32 => accumulator_helper!(Int32Type, $opr, $is_distinct),
            DataType::Int64 => accumulator_helper!(Int64Type, $opr, $is_distinct),
            DataType::UInt8 => accumulator_helper!(UInt8Type, $opr, $is_distinct),
            DataType::UInt16 => accumulator_helper!(UInt16Type, $opr, $is_distinct),
            DataType::UInt32 => accumulator_helper!(UInt32Type, $opr, $is_distinct),
            DataType::UInt64 => accumulator_helper!(UInt64Type, $opr, $is_distinct),
            _ => {
                // Format the operation value itself: `stringify!($opr)` would
                // embed the source text of the expression (e.g.
                // `self.operation`) in the message instead of the operation.
                not_impl_err!(
                    "{} not supported for {}: {}",
                    $opr,
                    $args.name,
                    $args.return_type
                )
            }
        }
    };
}
110
/// Declares one bitwise aggregate: forwards `$EXPR_FN` and
/// `$AGGREGATE_UDF_FN` to `make_udaf_expr!` and `create_func!`, backing the
/// function with a [`BitwiseOperation`] built from `$OPR_TYPE`, the
/// stringified `$EXPR_FN` as its name, and `$DOCUMENTATION`.
macro_rules! make_bitwise_udaf_expr_and_func {
    ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $OPR_TYPE:expr, $DOCUMENTATION:expr) => {
        make_udaf_expr!(
            $EXPR_FN,
            expr_x,
            // `concat!` inserts no separators, so the spaces around the
            // stringified operation must be part of the literals.
            concat!(
                "Returns the bitwise ",
                stringify!($OPR_TYPE),
                " of a group of values"
            ),
            $AGGREGATE_UDF_FN
        );
        create_func!(
            $EXPR_FN,
            $AGGREGATE_UDF_FN,
            BitwiseOperation::new($OPR_TYPE, stringify!($EXPR_FN), $DOCUMENTATION)
        );
    };
}
136
137static BIT_AND_DOC: LazyLock<Documentation> = LazyLock::new(|| {
138 Documentation::builder(
139 DOC_SECTION_GENERAL,
140 "Computes the bitwise AND of all non-null input values.",
141 "bit_and(expression)",
142 )
143 .with_standard_argument("expression", Some("Integer"))
144 .build()
145});
146
147fn get_bit_and_doc() -> &'static Documentation {
148 &BIT_AND_DOC
149}
150
151static BIT_OR_DOC: LazyLock<Documentation> = LazyLock::new(|| {
152 Documentation::builder(
153 DOC_SECTION_GENERAL,
154 "Computes the bitwise OR of all non-null input values.",
155 "bit_or(expression)",
156 )
157 .with_standard_argument("expression", Some("Integer"))
158 .build()
159});
160
161fn get_bit_or_doc() -> &'static Documentation {
162 &BIT_OR_DOC
163}
164
165static BIT_XOR_DOC: LazyLock<Documentation> = LazyLock::new(|| {
166 Documentation::builder(
167 DOC_SECTION_GENERAL,
168 "Computes the bitwise exclusive OR of all non-null input values.",
169 "bit_xor(expression)",
170 )
171 .with_standard_argument("expression", Some("Integer"))
172 .build()
173});
174
175fn get_bit_xor_doc() -> &'static Documentation {
176 &BIT_XOR_DOC
177}
178
// `bit_and`: passes the names `bit_and` / `bit_and_udaf` to the helper macro,
// backed by a `BitwiseOperation` configured for `And` and the `bit_and` docs.
make_bitwise_udaf_expr_and_func!(
    bit_and,
    bit_and_udaf,
    BitwiseOperationType::And,
    get_bit_and_doc()
);
// `bit_or`: same wiring with the `Or` operation and the `bit_or` docs.
make_bitwise_udaf_expr_and_func!(
    bit_or,
    bit_or_udaf,
    BitwiseOperationType::Or,
    get_bit_or_doc()
);
// `bit_xor`: same wiring with the `Xor` operation and the `bit_xor` docs.
make_bitwise_udaf_expr_and_func!(
    bit_xor,
    bit_xor_udaf,
    BitwiseOperationType::Xor,
    get_bit_xor_doc()
);
197
/// The bitwise reduction performed by a [`BitwiseOperation`] aggregate.
#[derive(Debug, Clone, Eq, PartialEq)]
enum BitwiseOperationType {
    /// Bitwise AND.
    And,
    /// Bitwise OR.
    Or,
    /// Bitwise exclusive OR.
    Xor,
}

impl Display for BitwiseOperationType {
    /// Writes the variant name (`And`, `Or` or `Xor`), the same text the
    /// derived `Debug` impl produces.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let variant_name = match self {
            BitwiseOperationType::And => "And",
            BitwiseOperationType::Or => "Or",
            BitwiseOperationType::Xor => "Xor",
        };
        f.write_str(variant_name)
    }
}
211
212#[derive(Debug)]
214struct BitwiseOperation {
215 signature: Signature,
216 operation: BitwiseOperationType,
218 func_name: &'static str,
219 documentation: &'static Documentation,
220}
221
222impl BitwiseOperation {
223 pub fn new(
224 operator: BitwiseOperationType,
225 func_name: &'static str,
226 documentation: &'static Documentation,
227 ) -> Self {
228 Self {
229 operation: operator,
230 signature: Signature::uniform(1, INTEGERS.to_vec(), Volatility::Immutable),
231 func_name,
232 documentation,
233 }
234 }
235}
236
237impl AggregateUDFImpl for BitwiseOperation {
238 fn as_any(&self) -> &dyn Any {
239 self
240 }
241
242 fn name(&self) -> &str {
243 self.func_name
244 }
245
246 fn signature(&self) -> &Signature {
247 &self.signature
248 }
249
250 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
251 let arg_type = &arg_types[0];
252 if !arg_type.is_integer() {
253 return exec_err!(
254 "[return_type] {} not supported for {}",
255 self.name(),
256 arg_type
257 );
258 }
259 Ok(arg_type.clone())
260 }
261
262 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
263 downcast_bitwise_accumulator!(acc_args, self.operation, acc_args.is_distinct)
264 }
265
266 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
267 if self.operation == BitwiseOperationType::Xor && args.is_distinct {
268 Ok(vec![Field::new_list(
269 format_state_name(
270 args.name,
271 format!("{} distinct", self.name()).as_str(),
272 ),
273 Field::new_list_field(args.return_type.clone(), true),
275 false,
276 )])
277 } else {
278 Ok(vec![Field::new(
279 format_state_name(args.name, self.name()),
280 args.return_type.clone(),
281 true,
282 )])
283 }
284 }
285
286 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
287 true
288 }
289
290 fn create_groups_accumulator(
291 &self,
292 args: AccumulatorArgs,
293 ) -> Result<Box<dyn GroupsAccumulator>> {
294 let data_type = args.return_type;
295 let operation = &self.operation;
296 downcast_integer! {
297 data_type => (group_accumulator_helper, data_type, operation),
298 _ => not_impl_err!(
299 "GroupsAccumulator not supported for {} with {}",
300 self.name(),
301 data_type
302 ),
303 }
304 }
305
306 fn reverse_expr(&self) -> ReversedUDAF {
307 ReversedUDAF::Identical
308 }
309
310 fn documentation(&self) -> Option<&Documentation> {
311 Some(self.documentation)
312 }
313}
314
315struct BitAndAccumulator<T: ArrowNumericType> {
316 value: Option<T::Native>,
317}
318
319impl<T: ArrowNumericType> std::fmt::Debug for BitAndAccumulator<T> {
320 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
321 write!(f, "BitAndAccumulator({})", T::DATA_TYPE)
322 }
323}
324
325impl<T: ArrowNumericType> Default for BitAndAccumulator<T> {
326 fn default() -> Self {
327 Self { value: None }
328 }
329}
330
331impl<T: ArrowNumericType> Accumulator for BitAndAccumulator<T>
332where
333 T::Native: std::ops::BitAnd<Output = T::Native>,
334{
335 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
336 if let Some(x) = arrow::compute::bit_and(values[0].as_primitive::<T>()) {
337 let v = self.value.get_or_insert(x);
338 *v = *v & x;
339 }
340 Ok(())
341 }
342
343 fn evaluate(&mut self) -> Result<ScalarValue> {
344 ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
345 }
346
347 fn size(&self) -> usize {
348 size_of_val(self)
349 }
350
351 fn state(&mut self) -> Result<Vec<ScalarValue>> {
352 Ok(vec![self.evaluate()?])
353 }
354
355 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
356 self.update_batch(states)
357 }
358}
359
360struct BitOrAccumulator<T: ArrowNumericType> {
361 value: Option<T::Native>,
362}
363
364impl<T: ArrowNumericType> std::fmt::Debug for BitOrAccumulator<T> {
365 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
366 write!(f, "BitOrAccumulator({})", T::DATA_TYPE)
367 }
368}
369
370impl<T: ArrowNumericType> Default for BitOrAccumulator<T> {
371 fn default() -> Self {
372 Self { value: None }
373 }
374}
375
376impl<T: ArrowNumericType> Accumulator for BitOrAccumulator<T>
377where
378 T::Native: std::ops::BitOr<Output = T::Native>,
379{
380 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
381 if let Some(x) = arrow::compute::bit_or(values[0].as_primitive::<T>()) {
382 let v = self.value.get_or_insert(T::Native::usize_as(0));
383 *v = *v | x;
384 }
385 Ok(())
386 }
387
388 fn evaluate(&mut self) -> Result<ScalarValue> {
389 ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
390 }
391
392 fn size(&self) -> usize {
393 size_of_val(self)
394 }
395
396 fn state(&mut self) -> Result<Vec<ScalarValue>> {
397 Ok(vec![self.evaluate()?])
398 }
399
400 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
401 self.update_batch(states)
402 }
403}
404
405struct BitXorAccumulator<T: ArrowNumericType> {
406 value: Option<T::Native>,
407}
408
409impl<T: ArrowNumericType> std::fmt::Debug for BitXorAccumulator<T> {
410 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
411 write!(f, "BitXorAccumulator({})", T::DATA_TYPE)
412 }
413}
414
415impl<T: ArrowNumericType> Default for BitXorAccumulator<T> {
416 fn default() -> Self {
417 Self { value: None }
418 }
419}
420
421impl<T: ArrowNumericType> Accumulator for BitXorAccumulator<T>
422where
423 T::Native: std::ops::BitXor<Output = T::Native>,
424{
425 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
426 if let Some(x) = arrow::compute::bit_xor(values[0].as_primitive::<T>()) {
427 let v = self.value.get_or_insert(T::Native::usize_as(0));
428 *v = *v ^ x;
429 }
430 Ok(())
431 }
432
433 fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
434 self.update_batch(values)
436 }
437
438 fn supports_retract_batch(&self) -> bool {
439 true
440 }
441
442 fn evaluate(&mut self) -> Result<ScalarValue> {
443 ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
444 }
445
446 fn size(&self) -> usize {
447 size_of_val(self)
448 }
449
450 fn state(&mut self) -> Result<Vec<ScalarValue>> {
451 Ok(vec![self.evaluate()?])
452 }
453
454 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
455 self.update_batch(states)
456 }
457}
458
459struct DistinctBitXorAccumulator<T: ArrowNumericType> {
460 values: HashSet<T::Native, RandomState>,
461}
462
463impl<T: ArrowNumericType> std::fmt::Debug for DistinctBitXorAccumulator<T> {
464 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
465 write!(f, "DistinctBitXorAccumulator({})", T::DATA_TYPE)
466 }
467}
468
469impl<T: ArrowNumericType> Default for DistinctBitXorAccumulator<T> {
470 fn default() -> Self {
471 Self {
472 values: HashSet::default(),
473 }
474 }
475}
476
477impl<T: ArrowNumericType> Accumulator for DistinctBitXorAccumulator<T>
478where
479 T::Native: std::ops::BitXor<Output = T::Native> + std::hash::Hash + Eq,
480{
481 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
482 if values.is_empty() {
483 return Ok(());
484 }
485
486 let array = values[0].as_primitive::<T>();
487 match array.nulls().filter(|x| x.null_count() > 0) {
488 Some(n) => {
489 for idx in n.valid_indices() {
490 self.values.insert(array.value(idx));
491 }
492 }
493 None => array.values().iter().for_each(|x| {
494 self.values.insert(*x);
495 }),
496 }
497 Ok(())
498 }
499
500 fn evaluate(&mut self) -> Result<ScalarValue> {
501 let mut acc = T::Native::usize_as(0);
502 for distinct_value in self.values.iter() {
503 acc = acc ^ *distinct_value;
504 }
505 let v = (!self.values.is_empty()).then_some(acc);
506 ScalarValue::new_primitive::<T>(v, &T::DATA_TYPE)
507 }
508
509 fn size(&self) -> usize {
510 size_of_val(self) + self.values.capacity() * size_of::<T::Native>()
511 }
512
513 fn state(&mut self) -> Result<Vec<ScalarValue>> {
514 let state_out = {
517 let values = self
518 .values
519 .iter()
520 .map(|x| ScalarValue::new_primitive::<T>(Some(*x), &T::DATA_TYPE))
521 .collect::<Result<Vec<_>>>()?;
522
523 let arr = ScalarValue::new_list_nullable(&values, &T::DATA_TYPE);
524 vec![ScalarValue::List(arr)]
525 };
526 Ok(state_out)
527 }
528
529 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
530 if let Some(state) = states.first() {
531 let list_arr = as_list_array(state)?;
532 for arr in list_arr.iter().flatten() {
533 self.update_batch(&[arr])?;
534 }
535 }
536 Ok(())
537 }
538}
539
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, UInt64Array};
    use arrow::datatypes::UInt64Type;
    use datafusion_common::ScalarValue;

    use crate::bit_and_or_xor::BitXorAccumulator;
    use datafusion_expr::Accumulator;

    /// Updating with `[1, 2]` gives `1 ^ 2 == 3`; retracting `[1]` XORs it
    /// back out, leaving `2`.
    #[test]
    fn test_bit_xor_accumulator() {
        let to_batch = |vals: Vec<u64>| {
            Arc::new(vals.into_iter().collect::<UInt64Array>()) as ArrayRef
        };
        let added = [to_batch(vec![1, 2])];
        let retracted = [to_batch(vec![1])];

        let mut accumulator = BitXorAccumulator::<UInt64Type> { value: None };

        accumulator.update_batch(&added).unwrap();
        let after_update = accumulator.evaluate().unwrap();
        assert_eq!(after_update, ScalarValue::UInt64(Some(3)));

        accumulator.retract_batch(&retracted).unwrap();
        let after_retract = accumulator.evaluate().unwrap();
        assert_eq!(after_retract, ScalarValue::UInt64(Some(2)));
    }
}