datafusion_functions_aggregate/
bit_and_or_xor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines `BitAnd`, `BitOr`, `BitXor` and `BitXor DISTINCT` aggregate accumulators
19
20use std::any::Any;
21use std::collections::HashSet;
22use std::fmt::{Display, Formatter};
23use std::mem::{size_of, size_of_val};
24
25use ahash::RandomState;
26use arrow::array::{downcast_integer, Array, ArrayRef, AsArray};
27use arrow::datatypes::{
28    ArrowNativeType, ArrowNumericType, DataType, Field, Int16Type, Int32Type, Int64Type,
29    Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
30};
31
32use datafusion_common::cast::as_list_array;
33use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
34use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
35use datafusion_expr::type_coercion::aggregates::INTEGERS;
36use datafusion_expr::utils::format_state_name;
37use datafusion_expr::{
38    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
39    Signature, Volatility,
40};
41
42use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
43use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
44use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign};
45use std::sync::LazyLock;
46
/// Builds a [`PrimitiveGroupsAccumulator`] for the Arrow primitive type `$t` that folds
/// each group's values with the bitwise operation selected by `$opr`.
///
/// `$dt` is the data type handed to the accumulator. AND starts every group from the
/// all-ones value (`!0`), which is the identity for AND; OR and XOR do not override the
/// accumulator's starting value. Mostly an internal helper.
macro_rules! group_accumulator_helper {
    ($t:ty, $dt:expr, $opr:expr) => {
        match $opr {
            BitwiseOperationType::And => {
                let and_acc =
                    PrimitiveGroupsAccumulator::<$t, _>::new($dt, |acc, v| acc.bitand_assign(v))
                        .with_starting_value(!0);
                Ok(Box::new(and_acc))
            }
            BitwiseOperationType::Or => {
                let or_acc =
                    PrimitiveGroupsAccumulator::<$t, _>::new($dt, |acc, v| acc.bitor_assign(v));
                Ok(Box::new(or_acc))
            }
            BitwiseOperationType::Xor => {
                let xor_acc =
                    PrimitiveGroupsAccumulator::<$t, _>::new($dt, |acc, v| acc.bitxor_assign(v));
                Ok(Box::new(xor_acc))
            }
        }
    };
}
65
/// Creates the row-based bitwise [`Accumulator`] for the Arrow primitive type `$t`.
///
/// Arguments: (ArrowPrimitiveType, BitwiseOperationType, bool). The `$is_distinct` flag
/// only changes the outcome for XOR, where `true` selects `DistinctBitXorAccumulator`.
/// AND and OR always use their plain accumulators.
macro_rules! accumulator_helper {
    ($t:ty, $opr:expr, $is_distinct: expr) => {
        match (&$opr, $is_distinct) {
            (BitwiseOperationType::And, _) => Ok(Box::<BitAndAccumulator<$t>>::default()),
            (BitwiseOperationType::Or, _) => Ok(Box::<BitOrAccumulator<$t>>::default()),
            (BitwiseOperationType::Xor, true) => {
                Ok(Box::<DistinctBitXorAccumulator<$t>>::default())
            }
            (BitwiseOperationType::Xor, false) => {
                Ok(Box::<BitXorAccumulator<$t>>::default())
            }
        }
    };
}
82
/// Creates the row-based accumulator for a bitwise aggregate, dispatching on the
/// aggregate's return type.
///
/// AND, OR and XOR only support the signed and unsigned 8/16/32/64-bit integer types
/// matched below. Any other return type yields a "not implemented" error.
///
/// `args` is [AccumulatorArgs]
/// `opr` is [BitwiseOperationType]
/// `is_distinct` is a boolean indicating whether the aggregate is DISTINCT.
macro_rules! downcast_bitwise_accumulator {
    ($args:ident, $opr:expr, $is_distinct: expr) => {
        match $args.return_type {
            DataType::Int8 => accumulator_helper!(Int8Type, $opr, $is_distinct),
            DataType::Int16 => accumulator_helper!(Int16Type, $opr, $is_distinct),
            DataType::Int32 => accumulator_helper!(Int32Type, $opr, $is_distinct),
            DataType::Int64 => accumulator_helper!(Int64Type, $opr, $is_distinct),
            DataType::UInt8 => accumulator_helper!(UInt8Type, $opr, $is_distinct),
            DataType::UInt16 => accumulator_helper!(UInt16Type, $opr, $is_distinct),
            DataType::UInt32 => accumulator_helper!(UInt32Type, $opr, $is_distinct),
            DataType::UInt64 => accumulator_helper!(UInt64Type, $opr, $is_distinct),
            _ => {
                // Format the operation's *value* (e.g. `Xor`) via its `Display` impl.
                // `stringify!($opr)` would embed the caller's source text
                // (`self.operation`) in the message instead of the operation name.
                not_impl_err!(
                    "{} not supported for {}: {}",
                    $opr,
                    $args.name,
                    $args.return_type
                )
            }
        }
    };
}
110
/// Simplifies the creation of User-Defined Aggregate Functions (UDAFs) for performing
/// bitwise operations in a declarative manner.
///
/// `EXPR_FN` is the identifier used to name the generated expression function; its
/// stringified form is also the UDAF's name.
/// `AGGREGATE_UDF_FN` is an identifier used to name the underlying UDAF function.
/// `OPR_TYPE` is an expression that evaluates to the type of bitwise operation to be performed.
/// `DOCUMENTATION` is the documentation for the UDAF.
macro_rules! make_bitwise_udaf_expr_and_func {
    ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $OPR_TYPE:expr, $DOCUMENTATION:expr) => {
        make_udaf_expr!(
            $EXPR_FN,
            expr_x,
            // `concat!` joins its literals verbatim, so the separating spaces must be
            // part of the literals themselves.
            concat!(
                "Returns the bitwise ",
                stringify!($OPR_TYPE),
                " of a group of values"
            ),
            $AGGREGATE_UDF_FN
        );
        create_func!(
            $EXPR_FN,
            $AGGREGATE_UDF_FN,
            BitwiseOperation::new($OPR_TYPE, stringify!($EXPR_FN), $DOCUMENTATION)
        );
    };
}
136
137static BIT_AND_DOC: LazyLock<Documentation> = LazyLock::new(|| {
138    Documentation::builder(
139        DOC_SECTION_GENERAL,
140        "Computes the bitwise AND of all non-null input values.",
141        "bit_and(expression)",
142    )
143    .with_standard_argument("expression", Some("Integer"))
144    .build()
145});
146
147fn get_bit_and_doc() -> &'static Documentation {
148    &BIT_AND_DOC
149}
150
151static BIT_OR_DOC: LazyLock<Documentation> = LazyLock::new(|| {
152    Documentation::builder(
153        DOC_SECTION_GENERAL,
154        "Computes the bitwise OR of all non-null input values.",
155        "bit_or(expression)",
156    )
157    .with_standard_argument("expression", Some("Integer"))
158    .build()
159});
160
161fn get_bit_or_doc() -> &'static Documentation {
162    &BIT_OR_DOC
163}
164
165static BIT_XOR_DOC: LazyLock<Documentation> = LazyLock::new(|| {
166    Documentation::builder(
167        DOC_SECTION_GENERAL,
168        "Computes the bitwise exclusive OR of all non-null input values.",
169        "bit_xor(expression)",
170    )
171    .with_standard_argument("expression", Some("Integer"))
172    .build()
173});
174
175fn get_bit_xor_doc() -> &'static Documentation {
176    &BIT_XOR_DOC
177}
178
// Define the three bitwise aggregates. Each invocation generates an expression
// function (`bit_and`, `bit_or`, `bit_xor`) and the underlying UDAF function
// (`bit_and_udaf`, ...). Both are backed by a `BitwiseOperation` built from the
// given operation type and documentation.
make_bitwise_udaf_expr_and_func!(
    bit_and,
    bit_and_udaf,
    BitwiseOperationType::And,
    get_bit_and_doc()
);
make_bitwise_udaf_expr_and_func!(
    bit_or,
    bit_or_udaf,
    BitwiseOperationType::Or,
    get_bit_or_doc()
);
make_bitwise_udaf_expr_and_func!(
    bit_xor,
    bit_xor_udaf,
    BitwiseOperationType::Xor,
    get_bit_xor_doc()
);
197
/// Which bitwise operation a bitwise aggregate performs.
#[derive(Debug, Clone, Eq, PartialEq)]
enum BitwiseOperationType {
    And,
    Or,
    Xor,
}

impl Display for BitwiseOperationType {
    /// Writes the bare variant name (`And`, `Or` or `Xor`), the same text `Debug` prints.
    /// Width/fill flags from the caller are not applied.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::And => "And",
            Self::Or => "Or",
            Self::Xor => "Xor",
        };
        f.write_str(label)
    }
}
211
212/// [BitwiseOperation] struct encapsulates information about a bitwise operation.
213#[derive(Debug)]
214struct BitwiseOperation {
215    signature: Signature,
216    /// `operation` indicates the type of bitwise operation to be performed.
217    operation: BitwiseOperationType,
218    func_name: &'static str,
219    documentation: &'static Documentation,
220}
221
222impl BitwiseOperation {
223    pub fn new(
224        operator: BitwiseOperationType,
225        func_name: &'static str,
226        documentation: &'static Documentation,
227    ) -> Self {
228        Self {
229            operation: operator,
230            signature: Signature::uniform(1, INTEGERS.to_vec(), Volatility::Immutable),
231            func_name,
232            documentation,
233        }
234    }
235}
236
237impl AggregateUDFImpl for BitwiseOperation {
238    fn as_any(&self) -> &dyn Any {
239        self
240    }
241
242    fn name(&self) -> &str {
243        self.func_name
244    }
245
246    fn signature(&self) -> &Signature {
247        &self.signature
248    }
249
250    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
251        let arg_type = &arg_types[0];
252        if !arg_type.is_integer() {
253            return exec_err!(
254                "[return_type] {} not supported for {}",
255                self.name(),
256                arg_type
257            );
258        }
259        Ok(arg_type.clone())
260    }
261
262    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
263        downcast_bitwise_accumulator!(acc_args, self.operation, acc_args.is_distinct)
264    }
265
266    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
267        if self.operation == BitwiseOperationType::Xor && args.is_distinct {
268            Ok(vec![Field::new_list(
269                format_state_name(
270                    args.name,
271                    format!("{} distinct", self.name()).as_str(),
272                ),
273                // See COMMENTS.md to understand why nullable is set to true
274                Field::new_list_field(args.return_type.clone(), true),
275                false,
276            )])
277        } else {
278            Ok(vec![Field::new(
279                format_state_name(args.name, self.name()),
280                args.return_type.clone(),
281                true,
282            )])
283        }
284    }
285
286    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
287        true
288    }
289
290    fn create_groups_accumulator(
291        &self,
292        args: AccumulatorArgs,
293    ) -> Result<Box<dyn GroupsAccumulator>> {
294        let data_type = args.return_type;
295        let operation = &self.operation;
296        downcast_integer! {
297            data_type => (group_accumulator_helper, data_type, operation),
298            _ => not_impl_err!(
299                "GroupsAccumulator not supported for {} with {}",
300                self.name(),
301                data_type
302            ),
303        }
304    }
305
306    fn reverse_expr(&self) -> ReversedUDAF {
307        ReversedUDAF::Identical
308    }
309
310    fn documentation(&self) -> Option<&Documentation> {
311        Some(self.documentation)
312    }
313}
314
315struct BitAndAccumulator<T: ArrowNumericType> {
316    value: Option<T::Native>,
317}
318
319impl<T: ArrowNumericType> std::fmt::Debug for BitAndAccumulator<T> {
320    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
321        write!(f, "BitAndAccumulator({})", T::DATA_TYPE)
322    }
323}
324
325impl<T: ArrowNumericType> Default for BitAndAccumulator<T> {
326    fn default() -> Self {
327        Self { value: None }
328    }
329}
330
331impl<T: ArrowNumericType> Accumulator for BitAndAccumulator<T>
332where
333    T::Native: std::ops::BitAnd<Output = T::Native>,
334{
335    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
336        if let Some(x) = arrow::compute::bit_and(values[0].as_primitive::<T>()) {
337            let v = self.value.get_or_insert(x);
338            *v = *v & x;
339        }
340        Ok(())
341    }
342
343    fn evaluate(&mut self) -> Result<ScalarValue> {
344        ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
345    }
346
347    fn size(&self) -> usize {
348        size_of_val(self)
349    }
350
351    fn state(&mut self) -> Result<Vec<ScalarValue>> {
352        Ok(vec![self.evaluate()?])
353    }
354
355    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
356        self.update_batch(states)
357    }
358}
359
360struct BitOrAccumulator<T: ArrowNumericType> {
361    value: Option<T::Native>,
362}
363
364impl<T: ArrowNumericType> std::fmt::Debug for BitOrAccumulator<T> {
365    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
366        write!(f, "BitOrAccumulator({})", T::DATA_TYPE)
367    }
368}
369
370impl<T: ArrowNumericType> Default for BitOrAccumulator<T> {
371    fn default() -> Self {
372        Self { value: None }
373    }
374}
375
376impl<T: ArrowNumericType> Accumulator for BitOrAccumulator<T>
377where
378    T::Native: std::ops::BitOr<Output = T::Native>,
379{
380    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
381        if let Some(x) = arrow::compute::bit_or(values[0].as_primitive::<T>()) {
382            let v = self.value.get_or_insert(T::Native::usize_as(0));
383            *v = *v | x;
384        }
385        Ok(())
386    }
387
388    fn evaluate(&mut self) -> Result<ScalarValue> {
389        ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
390    }
391
392    fn size(&self) -> usize {
393        size_of_val(self)
394    }
395
396    fn state(&mut self) -> Result<Vec<ScalarValue>> {
397        Ok(vec![self.evaluate()?])
398    }
399
400    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
401        self.update_batch(states)
402    }
403}
404
405struct BitXorAccumulator<T: ArrowNumericType> {
406    value: Option<T::Native>,
407}
408
409impl<T: ArrowNumericType> std::fmt::Debug for BitXorAccumulator<T> {
410    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
411        write!(f, "BitXorAccumulator({})", T::DATA_TYPE)
412    }
413}
414
415impl<T: ArrowNumericType> Default for BitXorAccumulator<T> {
416    fn default() -> Self {
417        Self { value: None }
418    }
419}
420
421impl<T: ArrowNumericType> Accumulator for BitXorAccumulator<T>
422where
423    T::Native: std::ops::BitXor<Output = T::Native>,
424{
425    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
426        if let Some(x) = arrow::compute::bit_xor(values[0].as_primitive::<T>()) {
427            let v = self.value.get_or_insert(T::Native::usize_as(0));
428            *v = *v ^ x;
429        }
430        Ok(())
431    }
432
433    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
434        // XOR is it's own inverse
435        self.update_batch(values)
436    }
437
438    fn supports_retract_batch(&self) -> bool {
439        true
440    }
441
442    fn evaluate(&mut self) -> Result<ScalarValue> {
443        ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
444    }
445
446    fn size(&self) -> usize {
447        size_of_val(self)
448    }
449
450    fn state(&mut self) -> Result<Vec<ScalarValue>> {
451        Ok(vec![self.evaluate()?])
452    }
453
454    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
455        self.update_batch(states)
456    }
457}
458
459struct DistinctBitXorAccumulator<T: ArrowNumericType> {
460    values: HashSet<T::Native, RandomState>,
461}
462
463impl<T: ArrowNumericType> std::fmt::Debug for DistinctBitXorAccumulator<T> {
464    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
465        write!(f, "DistinctBitXorAccumulator({})", T::DATA_TYPE)
466    }
467}
468
469impl<T: ArrowNumericType> Default for DistinctBitXorAccumulator<T> {
470    fn default() -> Self {
471        Self {
472            values: HashSet::default(),
473        }
474    }
475}
476
477impl<T: ArrowNumericType> Accumulator for DistinctBitXorAccumulator<T>
478where
479    T::Native: std::ops::BitXor<Output = T::Native> + std::hash::Hash + Eq,
480{
481    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
482        if values.is_empty() {
483            return Ok(());
484        }
485
486        let array = values[0].as_primitive::<T>();
487        match array.nulls().filter(|x| x.null_count() > 0) {
488            Some(n) => {
489                for idx in n.valid_indices() {
490                    self.values.insert(array.value(idx));
491                }
492            }
493            None => array.values().iter().for_each(|x| {
494                self.values.insert(*x);
495            }),
496        }
497        Ok(())
498    }
499
500    fn evaluate(&mut self) -> Result<ScalarValue> {
501        let mut acc = T::Native::usize_as(0);
502        for distinct_value in self.values.iter() {
503            acc = acc ^ *distinct_value;
504        }
505        let v = (!self.values.is_empty()).then_some(acc);
506        ScalarValue::new_primitive::<T>(v, &T::DATA_TYPE)
507    }
508
509    fn size(&self) -> usize {
510        size_of_val(self) + self.values.capacity() * size_of::<T::Native>()
511    }
512
513    fn state(&mut self) -> Result<Vec<ScalarValue>> {
514        // 1. Stores aggregate state in `ScalarValue::List`
515        // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
516        let state_out = {
517            let values = self
518                .values
519                .iter()
520                .map(|x| ScalarValue::new_primitive::<T>(Some(*x), &T::DATA_TYPE))
521                .collect::<Result<Vec<_>>>()?;
522
523            let arr = ScalarValue::new_list_nullable(&values, &T::DATA_TYPE);
524            vec![ScalarValue::List(arr)]
525        };
526        Ok(state_out)
527    }
528
529    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
530        if let Some(state) = states.first() {
531            let list_arr = as_list_array(state)?;
532            for arr in list_arr.iter().flatten() {
533                self.update_batch(&[arr])?;
534            }
535        }
536        Ok(())
537    }
538}
539
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, UInt64Array};
    use arrow::datatypes::UInt64Type;
    use datafusion_common::ScalarValue;

    use crate::bit_and_or_xor::{
        BitAndAccumulator, BitOrAccumulator, BitXorAccumulator, DistinctBitXorAccumulator,
    };
    use datafusion_expr::Accumulator;

    /// Wraps optional `u64` values into a single-column batch.
    fn u64_batch(values: Vec<Option<u64>>) -> Vec<ArrayRef> {
        vec![Arc::new(UInt64Array::from(values)) as ArrayRef]
    }

    #[test]
    fn test_bit_xor_accumulator() {
        let mut accumulator = BitXorAccumulator::<UInt64Type> { value: None };
        let batches: Vec<_> = vec![vec![1, 2], vec![1]]
            .into_iter()
            .map(|b| Arc::new(b.into_iter().collect::<UInt64Array>()) as ArrayRef)
            .collect();

        let added = &[Arc::clone(&batches[0])];
        let retracted = &[Arc::clone(&batches[1])];

        // 1 ^ 2 = 3
        accumulator.update_batch(added).unwrap();
        assert_eq!(
            accumulator.evaluate().unwrap(),
            ScalarValue::UInt64(Some(3))
        );

        // Removing [1]: 3 ^ 1 = 2
        accumulator.retract_batch(retracted).unwrap();
        assert_eq!(
            accumulator.evaluate().unwrap(),
            ScalarValue::UInt64(Some(2))
        );
    }

    #[test]
    fn test_bit_and_accumulator_ignores_nulls() {
        let mut acc = BitAndAccumulator::<UInt64Type>::default();
        // No input yet: the result is NULL.
        assert_eq!(acc.evaluate().unwrap(), ScalarValue::UInt64(None));

        acc.update_batch(&u64_batch(vec![Some(0b1100), None, Some(0b1010)]))
            .unwrap();
        assert_eq!(acc.evaluate().unwrap(), ScalarValue::UInt64(Some(0b1000)));
    }

    #[test]
    fn test_bit_or_accumulator_merges_state() {
        let mut acc = BitOrAccumulator::<UInt64Type>::default();
        acc.update_batch(&u64_batch(vec![Some(1), None, Some(4)]))
            .unwrap();
        acc.merge_batch(&u64_batch(vec![Some(8)])).unwrap();
        // 1 | 4 | 8 = 13
        assert_eq!(acc.evaluate().unwrap(), ScalarValue::UInt64(Some(13)));
    }

    #[test]
    fn test_distinct_bit_xor_accumulator() {
        let mut acc = DistinctBitXorAccumulator::<UInt64Type>::default();
        // Duplicates count once: 1 ^ 2 = 3 (a non-distinct XOR would give 2).
        acc.update_batch(&u64_batch(vec![Some(1), Some(1), Some(2), None]))
            .unwrap();
        assert_eq!(acc.evaluate().unwrap(), ScalarValue::UInt64(Some(3)));

        // Round-trip the state into a second accumulator that already holds {2, 5}.
        let state: Vec<ArrayRef> = acc
            .state()
            .unwrap()
            .iter()
            .map(|s| s.to_array().unwrap())
            .collect();
        let mut other = DistinctBitXorAccumulator::<UInt64Type>::default();
        other
            .update_batch(&u64_batch(vec![Some(2), Some(5)]))
            .unwrap();
        other.merge_batch(&state).unwrap();
        // Distinct values {1, 2, 5}: 1 ^ 2 ^ 5 = 6.
        assert_eq!(other.evaluate().unwrap(), ScalarValue::UInt64(Some(6)));
    }
}