datafusion_functions_aggregate/
bit_and_or_xor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines `BitAnd`, `BitOr`, `BitXor` and `BitXor DISTINCT` aggregate accumulators
19
20use std::any::Any;
21use std::collections::HashSet;
22use std::fmt::{Display, Formatter};
23use std::hash::{DefaultHasher, Hash, Hasher};
24use std::mem::{size_of, size_of_val};
25
26use ahash::RandomState;
27use arrow::array::{downcast_integer, Array, ArrayRef, AsArray};
28use arrow::datatypes::{
29    ArrowNativeType, ArrowNumericType, DataType, Field, FieldRef, Int16Type, Int32Type,
30    Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
31};
32
33use datafusion_common::cast::as_list_array;
34use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
35use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
36use datafusion_expr::type_coercion::aggregates::INTEGERS;
37use datafusion_expr::utils::format_state_name;
38use datafusion_expr::{
39    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
40    Signature, Volatility,
41};
42
43use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
44use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
45use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign};
46use std::sync::LazyLock;
47
/// Builds a boxed `PrimitiveGroupsAccumulator` for the Arrow primitive type `$t`
/// whose per-row update step is the bitwise operation selected by `$opr`.
///
/// `$dt` is the data type handed to the accumulator. AND overrides the starting
/// value with all-ones (`!0`) so the first row passes through unchanged; OR and
/// XOR keep the accumulator's own starting value (presumably zero — confirm in
/// `PrimitiveGroupsAccumulator`).
macro_rules! group_accumulator_helper {
    ($t:ty, $dt:expr, $opr:expr) => {{
        let accumulator: Box<dyn GroupsAccumulator> = match $opr {
            BitwiseOperationType::And => Box::new(
                PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitand_assign(y))
                    .with_starting_value(!0),
            ),
            BitwiseOperationType::Or => Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
                $dt,
                |x, y| x.bitor_assign(y),
            )),
            BitwiseOperationType::Xor => Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
                $dt,
                |x, y| x.bitxor_assign(y),
            )),
        };
        Ok(accumulator)
    }};
}
66
/// Builds a boxed scalar accumulator for the Arrow primitive type `$t`.
///
/// `$opr` is a [BitwiseOperationType] and `$is_distinct` a `bool`. Only XOR
/// consults `$is_distinct`. AND and OR are idempotent, so duplicate inputs
/// cannot change their result and they ignore it.
macro_rules! accumulator_helper {
    ($t:ty, $opr:expr, $is_distinct: expr) => {
        match (&$opr, $is_distinct) {
            (BitwiseOperationType::And, _) => Ok(Box::<BitAndAccumulator<$t>>::default()),
            (BitwiseOperationType::Or, _) => Ok(Box::<BitOrAccumulator<$t>>::default()),
            (BitwiseOperationType::Xor, true) => {
                Ok(Box::<DistinctBitXorAccumulator<$t>>::default())
            }
            (BitwiseOperationType::Xor, false) => {
                Ok(Box::<BitXorAccumulator<$t>>::default())
            }
        }
    };
}
83
/// Dispatches on the return type in `$args` to build the scalar accumulator for
/// a bitwise aggregate.
///
/// AND, OR and XOR only support the 8/16/32/64-bit signed and unsigned integer
/// types. Any other return type produces a "not implemented" error.
///
/// `args` is [AccumulatorArgs]
/// `opr` is [BitwiseOperationType]
/// `is_distinct` is a boolean indicating whether the aggregate is `DISTINCT`.
macro_rules! downcast_bitwise_accumulator {
    ($args:ident, $opr:expr, $is_distinct: expr) => {
        match $args.return_field.data_type() {
            DataType::Int8 => accumulator_helper!(Int8Type, $opr, $is_distinct),
            DataType::Int16 => accumulator_helper!(Int16Type, $opr, $is_distinct),
            DataType::Int32 => accumulator_helper!(Int32Type, $opr, $is_distinct),
            DataType::Int64 => accumulator_helper!(Int64Type, $opr, $is_distinct),
            DataType::UInt8 => accumulator_helper!(UInt8Type, $opr, $is_distinct),
            DataType::UInt16 => accumulator_helper!(UInt16Type, $opr, $is_distinct),
            DataType::UInt32 => accumulator_helper!(UInt32Type, $opr, $is_distinct),
            DataType::UInt64 => accumulator_helper!(UInt64Type, $opr, $is_distinct),
            other => {
                // Format the operation value itself (e.g. `Xor`). `stringify!`
                // would only yield the caller's source text, such as
                // `self.operation`.
                not_impl_err!(
                    "{} not supported for {}: {}",
                    $opr,
                    $args.name,
                    other
                )
            }
        }
    };
}
111
/// Declares a bitwise aggregate UDAF: the user-facing expression function and
/// the underlying UDAF accessor.
///
/// `EXPR_FN` is the identifier of the generated expression function; its
/// stringified form is also used as the UDAF name.
/// `AGGREGATE_UDF_FN` is the identifier of the generated UDAF accessor.
/// `OPR_TYPE` is an expression evaluating to the [BitwiseOperationType] to apply.
/// `DOCUMENTATION` is the documentation attached to the UDAF.
macro_rules! make_bitwise_udaf_expr_and_func {
    ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $OPR_TYPE:expr, $DOCUMENTATION:expr) => {
        make_udaf_expr!(
            $EXPR_FN,
            expr_x,
            // Explicit spaces: `concat!` joins its pieces with no separator.
            concat!(
                "Returns the bitwise ",
                stringify!($OPR_TYPE),
                " of a group of values"
            ),
            $AGGREGATE_UDF_FN
        );
        create_func!(
            $EXPR_FN,
            $AGGREGATE_UDF_FN,
            BitwiseOperation::new($OPR_TYPE, stringify!($EXPR_FN), $DOCUMENTATION)
        );
    };
}
137
138static BIT_AND_DOC: LazyLock<Documentation> = LazyLock::new(|| {
139    Documentation::builder(
140        DOC_SECTION_GENERAL,
141        "Computes the bitwise AND of all non-null input values.",
142        "bit_and(expression)",
143    )
144    .with_standard_argument("expression", Some("Integer"))
145    .build()
146});
147
148fn get_bit_and_doc() -> &'static Documentation {
149    &BIT_AND_DOC
150}
151
152static BIT_OR_DOC: LazyLock<Documentation> = LazyLock::new(|| {
153    Documentation::builder(
154        DOC_SECTION_GENERAL,
155        "Computes the bitwise OR of all non-null input values.",
156        "bit_or(expression)",
157    )
158    .with_standard_argument("expression", Some("Integer"))
159    .build()
160});
161
162fn get_bit_or_doc() -> &'static Documentation {
163    &BIT_OR_DOC
164}
165
166static BIT_XOR_DOC: LazyLock<Documentation> = LazyLock::new(|| {
167    Documentation::builder(
168        DOC_SECTION_GENERAL,
169        "Computes the bitwise exclusive OR of all non-null input values.",
170        "bit_xor(expression)",
171    )
172    .with_standard_argument("expression", Some("Integer"))
173    .build()
174});
175
176fn get_bit_xor_doc() -> &'static Documentation {
177    &BIT_XOR_DOC
178}
179
180make_bitwise_udaf_expr_and_func!(
181    bit_and,
182    bit_and_udaf,
183    BitwiseOperationType::And,
184    get_bit_and_doc()
185);
186make_bitwise_udaf_expr_and_func!(
187    bit_or,
188    bit_or_udaf,
189    BitwiseOperationType::Or,
190    get_bit_or_doc()
191);
192make_bitwise_udaf_expr_and_func!(
193    bit_xor,
194    bit_xor_udaf,
195    BitwiseOperationType::Xor,
196    get_bit_xor_doc()
197);
198
/// The bitwise operator a bitwise aggregate applies to its inputs.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
enum BitwiseOperationType {
    And,
    Or,
    Xor,
}

impl Display for BitwiseOperationType {
    /// Writes the bare variant name (`And`, `Or` or `Xor`), the same text the
    /// derived `Debug` produces.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::And => "And",
            Self::Or => "Or",
            Self::Xor => "Xor",
        };
        f.write_str(label)
    }
}
212
213/// [BitwiseOperation] struct encapsulates information about a bitwise operation.
214#[derive(Debug)]
215struct BitwiseOperation {
216    signature: Signature,
217    /// `operation` indicates the type of bitwise operation to be performed.
218    operation: BitwiseOperationType,
219    func_name: &'static str,
220    documentation: &'static Documentation,
221}
222
223impl BitwiseOperation {
224    pub fn new(
225        operator: BitwiseOperationType,
226        func_name: &'static str,
227        documentation: &'static Documentation,
228    ) -> Self {
229        Self {
230            operation: operator,
231            signature: Signature::uniform(1, INTEGERS.to_vec(), Volatility::Immutable),
232            func_name,
233            documentation,
234        }
235    }
236}
237
238impl AggregateUDFImpl for BitwiseOperation {
239    fn as_any(&self) -> &dyn Any {
240        self
241    }
242
243    fn name(&self) -> &str {
244        self.func_name
245    }
246
247    fn signature(&self) -> &Signature {
248        &self.signature
249    }
250
251    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
252        let arg_type = &arg_types[0];
253        if !arg_type.is_integer() {
254            return exec_err!(
255                "[return_type] {} not supported for {}",
256                self.name(),
257                arg_type
258            );
259        }
260        Ok(arg_type.clone())
261    }
262
263    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
264        downcast_bitwise_accumulator!(acc_args, self.operation, acc_args.is_distinct)
265    }
266
267    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
268        if self.operation == BitwiseOperationType::Xor && args.is_distinct {
269            Ok(vec![Field::new_list(
270                format_state_name(
271                    args.name,
272                    format!("{} distinct", self.name()).as_str(),
273                ),
274                // See COMMENTS.md to understand why nullable is set to true
275                Field::new_list_field(args.return_type().clone(), true),
276                false,
277            )
278            .into()])
279        } else {
280            Ok(vec![Field::new(
281                format_state_name(args.name, self.name()),
282                args.return_field.data_type().clone(),
283                true,
284            )
285            .into()])
286        }
287    }
288
289    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
290        true
291    }
292
293    fn create_groups_accumulator(
294        &self,
295        args: AccumulatorArgs,
296    ) -> Result<Box<dyn GroupsAccumulator>> {
297        let data_type = args.return_field.data_type();
298        let operation = &self.operation;
299        downcast_integer! {
300            data_type => (group_accumulator_helper, data_type, operation),
301            _ => not_impl_err!(
302                "GroupsAccumulator not supported for {} with {}",
303                self.name(),
304                data_type
305            ),
306        }
307    }
308
309    fn reverse_expr(&self) -> ReversedUDAF {
310        ReversedUDAF::Identical
311    }
312
313    fn documentation(&self) -> Option<&Documentation> {
314        Some(self.documentation)
315    }
316
317    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
318        let Some(other) = other.as_any().downcast_ref::<Self>() else {
319            return false;
320        };
321        let Self {
322            signature,
323            operation,
324            func_name,
325            documentation,
326        } = self;
327        signature == &other.signature
328            && operation == &other.operation
329            && func_name == &other.func_name
330            && documentation == &other.documentation
331    }
332
333    fn hash_value(&self) -> u64 {
334        let Self {
335            signature,
336            operation,
337            func_name,
338            documentation,
339        } = self;
340        let mut hasher = DefaultHasher::new();
341        std::any::type_name::<Self>().hash(&mut hasher);
342        signature.hash(&mut hasher);
343        operation.hash(&mut hasher);
344        func_name.hash(&mut hasher);
345        documentation.hash(&mut hasher);
346        hasher.finish()
347    }
348}
349
350struct BitAndAccumulator<T: ArrowNumericType> {
351    value: Option<T::Native>,
352}
353
354impl<T: ArrowNumericType> std::fmt::Debug for BitAndAccumulator<T> {
355    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
356        write!(f, "BitAndAccumulator({})", T::DATA_TYPE)
357    }
358}
359
360impl<T: ArrowNumericType> Default for BitAndAccumulator<T> {
361    fn default() -> Self {
362        Self { value: None }
363    }
364}
365
366impl<T: ArrowNumericType> Accumulator for BitAndAccumulator<T>
367where
368    T::Native: std::ops::BitAnd<Output = T::Native>,
369{
370    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
371        if let Some(x) = arrow::compute::bit_and(values[0].as_primitive::<T>()) {
372            let v = self.value.get_or_insert(x);
373            *v = *v & x;
374        }
375        Ok(())
376    }
377
378    fn evaluate(&mut self) -> Result<ScalarValue> {
379        ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
380    }
381
382    fn size(&self) -> usize {
383        size_of_val(self)
384    }
385
386    fn state(&mut self) -> Result<Vec<ScalarValue>> {
387        Ok(vec![self.evaluate()?])
388    }
389
390    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
391        self.update_batch(states)
392    }
393}
394
395struct BitOrAccumulator<T: ArrowNumericType> {
396    value: Option<T::Native>,
397}
398
399impl<T: ArrowNumericType> std::fmt::Debug for BitOrAccumulator<T> {
400    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
401        write!(f, "BitOrAccumulator({})", T::DATA_TYPE)
402    }
403}
404
405impl<T: ArrowNumericType> Default for BitOrAccumulator<T> {
406    fn default() -> Self {
407        Self { value: None }
408    }
409}
410
411impl<T: ArrowNumericType> Accumulator for BitOrAccumulator<T>
412where
413    T::Native: std::ops::BitOr<Output = T::Native>,
414{
415    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
416        if let Some(x) = arrow::compute::bit_or(values[0].as_primitive::<T>()) {
417            let v = self.value.get_or_insert(T::Native::usize_as(0));
418            *v = *v | x;
419        }
420        Ok(())
421    }
422
423    fn evaluate(&mut self) -> Result<ScalarValue> {
424        ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
425    }
426
427    fn size(&self) -> usize {
428        size_of_val(self)
429    }
430
431    fn state(&mut self) -> Result<Vec<ScalarValue>> {
432        Ok(vec![self.evaluate()?])
433    }
434
435    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
436        self.update_batch(states)
437    }
438}
439
440struct BitXorAccumulator<T: ArrowNumericType> {
441    value: Option<T::Native>,
442}
443
444impl<T: ArrowNumericType> std::fmt::Debug for BitXorAccumulator<T> {
445    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
446        write!(f, "BitXorAccumulator({})", T::DATA_TYPE)
447    }
448}
449
450impl<T: ArrowNumericType> Default for BitXorAccumulator<T> {
451    fn default() -> Self {
452        Self { value: None }
453    }
454}
455
456impl<T: ArrowNumericType> Accumulator for BitXorAccumulator<T>
457where
458    T::Native: std::ops::BitXor<Output = T::Native>,
459{
460    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
461        if let Some(x) = arrow::compute::bit_xor(values[0].as_primitive::<T>()) {
462            let v = self.value.get_or_insert(T::Native::usize_as(0));
463            *v = *v ^ x;
464        }
465        Ok(())
466    }
467
468    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
469        // XOR is it's own inverse
470        self.update_batch(values)
471    }
472
473    fn supports_retract_batch(&self) -> bool {
474        true
475    }
476
477    fn evaluate(&mut self) -> Result<ScalarValue> {
478        ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
479    }
480
481    fn size(&self) -> usize {
482        size_of_val(self)
483    }
484
485    fn state(&mut self) -> Result<Vec<ScalarValue>> {
486        Ok(vec![self.evaluate()?])
487    }
488
489    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
490        self.update_batch(states)
491    }
492}
493
494struct DistinctBitXorAccumulator<T: ArrowNumericType> {
495    values: HashSet<T::Native, RandomState>,
496}
497
498impl<T: ArrowNumericType> std::fmt::Debug for DistinctBitXorAccumulator<T> {
499    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
500        write!(f, "DistinctBitXorAccumulator({})", T::DATA_TYPE)
501    }
502}
503
504impl<T: ArrowNumericType> Default for DistinctBitXorAccumulator<T> {
505    fn default() -> Self {
506        Self {
507            values: HashSet::default(),
508        }
509    }
510}
511
512impl<T: ArrowNumericType> Accumulator for DistinctBitXorAccumulator<T>
513where
514    T::Native: std::ops::BitXor<Output = T::Native> + Hash + Eq,
515{
516    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
517        if values.is_empty() {
518            return Ok(());
519        }
520
521        let array = values[0].as_primitive::<T>();
522        match array.nulls().filter(|x| x.null_count() > 0) {
523            Some(n) => {
524                for idx in n.valid_indices() {
525                    self.values.insert(array.value(idx));
526                }
527            }
528            None => array.values().iter().for_each(|x| {
529                self.values.insert(*x);
530            }),
531        }
532        Ok(())
533    }
534
535    fn evaluate(&mut self) -> Result<ScalarValue> {
536        let mut acc = T::Native::usize_as(0);
537        for distinct_value in self.values.iter() {
538            acc = acc ^ *distinct_value;
539        }
540        let v = (!self.values.is_empty()).then_some(acc);
541        ScalarValue::new_primitive::<T>(v, &T::DATA_TYPE)
542    }
543
544    fn size(&self) -> usize {
545        size_of_val(self) + self.values.capacity() * size_of::<T::Native>()
546    }
547
548    fn state(&mut self) -> Result<Vec<ScalarValue>> {
549        // 1. Stores aggregate state in `ScalarValue::List`
550        // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
551        let state_out = {
552            let values = self
553                .values
554                .iter()
555                .map(|x| ScalarValue::new_primitive::<T>(Some(*x), &T::DATA_TYPE))
556                .collect::<Result<Vec<_>>>()?;
557
558            let arr = ScalarValue::new_list_nullable(&values, &T::DATA_TYPE);
559            vec![ScalarValue::List(arr)]
560        };
561        Ok(state_out)
562    }
563
564    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
565        if let Some(state) = states.first() {
566            let list_arr = as_list_array(state)?;
567            for arr in list_arr.iter().flatten() {
568                self.update_batch(&[arr])?;
569            }
570        }
571        Ok(())
572    }
573}
574
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, UInt64Array};
    use arrow::datatypes::UInt64Type;
    use datafusion_common::ScalarValue;

    use crate::bit_and_or_xor::BitXorAccumulator;
    use datafusion_expr::Accumulator;

    /// Wraps `values` in a single `UInt64Array` batch with no nulls.
    fn u64_batch(values: &[u64]) -> ArrayRef {
        Arc::new(UInt64Array::from(values.to_vec()))
    }

    #[test]
    fn test_bit_xor_accumulator() {
        let mut acc = BitXorAccumulator::<UInt64Type> { value: None };

        // 1 ^ 2 == 3
        acc.update_batch(&[u64_batch(&[1, 2])]).unwrap();
        assert_eq!(acc.evaluate().unwrap(), ScalarValue::UInt64(Some(3)));

        // Retracting 1 cancels its contribution: 3 ^ 1 == 2
        acc.retract_batch(&[u64_batch(&[1])]).unwrap();
        assert_eq!(acc.evaluate().unwrap(), ScalarValue::UInt64(Some(2)));
    }
}