datafusion_functions_aggregate/
bit_and_or_xor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines `BitAnd`, `BitOr`, `BitXor` and `BitXor DISTINCT` aggregate accumulators
19
20use std::any::Any;
21use std::collections::HashSet;
22use std::fmt::{Display, Formatter};
23use std::hash::Hash;
24use std::mem::{size_of, size_of_val};
25
26use ahash::RandomState;
27use arrow::array::{downcast_integer, Array, ArrayRef, AsArray};
28use arrow::datatypes::{
29    ArrowNativeType, ArrowNumericType, DataType, Field, FieldRef, Int16Type, Int32Type,
30    Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
31};
32
33use datafusion_common::cast::as_list_array;
34use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
35use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
36use datafusion_expr::type_coercion::aggregates::INTEGERS;
37use datafusion_expr::utils::format_state_name;
38use datafusion_expr::{
39    Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF,
40    Signature, Volatility,
41};
42
43use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
44use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
45use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign};
46use std::sync::LazyLock;
47
/// Builds a boxed [`PrimitiveGroupsAccumulator`] for one Arrow integer type.
///
/// Arguments, in order: the Arrow primitive type, an expression for the
/// accumulator's `DataType`, and an expression yielding the
/// [`BitwiseOperationType`] to apply. This is internal plumbing and is not
/// meant to be invoked directly by users.
macro_rules! group_accumulator_helper {
    ($native:ty, $data_type:expr, $operation:expr) => {
        match $operation {
            BitwiseOperationType::And => {
                // AND starts from all bits set (`!0`) instead of the default start value.
                let grouped = PrimitiveGroupsAccumulator::<$native, _>::new(
                    $data_type,
                    |acc, rhs| acc.bitand_assign(rhs),
                )
                .with_starting_value(!0);
                Ok(Box::new(grouped))
            }
            BitwiseOperationType::Or => {
                let grouped = PrimitiveGroupsAccumulator::<$native, _>::new(
                    $data_type,
                    |acc, rhs| acc.bitor_assign(rhs),
                );
                Ok(Box::new(grouped))
            }
            BitwiseOperationType::Xor => {
                let grouped = PrimitiveGroupsAccumulator::<$native, _>::new(
                    $data_type,
                    |acc, rhs| acc.bitxor_assign(rhs),
                );
                Ok(Box::new(grouped))
            }
        }
    };
}
66
/// Picks the row-wise accumulator for a single Arrow integer type.
///
/// Accepts `(ArrowPrimitiveType, BitwiseOperationType, bool)`. The boolean is
/// the `DISTINCT` flag; it only changes the choice for XOR.
macro_rules! accumulator_helper {
    ($native:ty, $operation:expr, $distinct:expr) => {
        match $operation {
            BitwiseOperationType::And => Ok(Box::<BitAndAccumulator<$native>>::default()),
            BitwiseOperationType::Or => Ok(Box::<BitOrAccumulator<$native>>::default()),
            BitwiseOperationType::Xor if $distinct => {
                Ok(Box::<DistinctBitXorAccumulator<$native>>::default())
            }
            BitwiseOperationType::Xor => Ok(Box::<BitXorAccumulator<$native>>::default()),
        }
    };
}
83
/// AND, OR and XOR only support a subset of numeric types (the 8/16/32/64-bit
/// signed and unsigned integers). This macro dispatches on the return field's
/// data type and builds the matching accumulator via `accumulator_helper!`.
///
/// `args` is [AccumulatorArgs]
/// `opr` is [BitwiseOperationType]
/// `is_distinct` is boolean value indicating whether the operation is distinct or not.
///
/// Any other data type produces a "not implemented" error naming the operation,
/// the aggregate's name and the offending type.
macro_rules! downcast_bitwise_accumulator {
    ($args:ident, $opr:expr, $is_distinct: expr) => {
        match $args.return_field.data_type() {
            DataType::Int8 => accumulator_helper!(Int8Type, $opr, $is_distinct),
            DataType::Int16 => accumulator_helper!(Int16Type, $opr, $is_distinct),
            DataType::Int32 => accumulator_helper!(Int32Type, $opr, $is_distinct),
            DataType::Int64 => accumulator_helper!(Int64Type, $opr, $is_distinct),
            DataType::UInt8 => accumulator_helper!(UInt8Type, $opr, $is_distinct),
            DataType::UInt16 => accumulator_helper!(UInt16Type, $opr, $is_distinct),
            DataType::UInt32 => accumulator_helper!(UInt32Type, $opr, $is_distinct),
            DataType::UInt64 => accumulator_helper!(UInt64Type, $opr, $is_distinct),
            _ => {
                // Format the operation's value (And/Or/Xor) through `Display`.
                // `stringify!($opr)` would only echo the caller's expression
                // text, e.g. "self.operation".
                not_impl_err!(
                    "{} not supported for {}: {}",
                    $opr,
                    $args.name,
                    $args.return_field.data_type()
                )
            }
        }
    };
}
111
/// Simplifies the creation of User-Defined Aggregate Functions (UDAFs) for performing bitwise operations in a declarative manner.
///
/// `EXPR_FN` identifier used to name the generated expression function.
/// `AGGREGATE_UDF_FN` is an identifier used to name the underlying UDAF function.
/// `OPR_TYPE` is an expression that evaluates to the type of bitwise operation to be performed.
/// `DOCUMENTATION` documentation for the UDAF
///
/// The description handed to `make_udaf_expr!` is assembled with `concat!`, so
/// the literal pieces carry their own separating spaces around the stringified
/// operation.
macro_rules! make_bitwise_udaf_expr_and_func {
    ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $OPR_TYPE:expr, $DOCUMENTATION:expr) => {
        make_udaf_expr!(
            $EXPR_FN,
            expr_x,
            concat!(
                "Returns the bitwise ",
                stringify!($OPR_TYPE),
                " of a group of values"
            ),
            $AGGREGATE_UDF_FN
        );
        create_func!(
            $EXPR_FN,
            $AGGREGATE_UDF_FN,
            BitwiseOperation::new($OPR_TYPE, stringify!($EXPR_FN), $DOCUMENTATION)
        );
    };
}
137
138static BIT_AND_DOC: LazyLock<Documentation> = LazyLock::new(|| {
139    Documentation::builder(
140        DOC_SECTION_GENERAL,
141        "Computes the bitwise AND of all non-null input values.",
142        "bit_and(expression)",
143    )
144    .with_standard_argument("expression", Some("Integer"))
145    .build()
146});
147
148fn get_bit_and_doc() -> &'static Documentation {
149    &BIT_AND_DOC
150}
151
152static BIT_OR_DOC: LazyLock<Documentation> = LazyLock::new(|| {
153    Documentation::builder(
154        DOC_SECTION_GENERAL,
155        "Computes the bitwise OR of all non-null input values.",
156        "bit_or(expression)",
157    )
158    .with_standard_argument("expression", Some("Integer"))
159    .build()
160});
161
162fn get_bit_or_doc() -> &'static Documentation {
163    &BIT_OR_DOC
164}
165
166static BIT_XOR_DOC: LazyLock<Documentation> = LazyLock::new(|| {
167    Documentation::builder(
168        DOC_SECTION_GENERAL,
169        "Computes the bitwise exclusive OR of all non-null input values.",
170        "bit_xor(expression)",
171    )
172    .with_standard_argument("expression", Some("Integer"))
173    .build()
174});
175
176fn get_bit_xor_doc() -> &'static Documentation {
177    &BIT_XOR_DOC
178}
179
// Instantiate the three bitwise aggregates. Each invocation passes, in order:
// the expression-function name, the UDAF accessor name, the operation to
// perform, and the documentation to attach.

// `bit_and`: bitwise AND aggregate.
make_bitwise_udaf_expr_and_func!(
    bit_and,
    bit_and_udaf,
    BitwiseOperationType::And,
    get_bit_and_doc()
);
// `bit_or`: bitwise OR aggregate.
make_bitwise_udaf_expr_and_func!(
    bit_or,
    bit_or_udaf,
    BitwiseOperationType::Or,
    get_bit_or_doc()
);
// `bit_xor`: bitwise XOR aggregate.
make_bitwise_udaf_expr_and_func!(
    bit_xor,
    bit_xor_udaf,
    BitwiseOperationType::Xor,
    get_bit_xor_doc()
);
198
/// The bitwise aggregate operations implemented in this module.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
enum BitwiseOperationType {
    /// Bitwise AND.
    And,
    /// Bitwise OR.
    Or,
    /// Bitwise exclusive OR.
    Xor,
}

impl Display for BitwiseOperationType {
    /// Writes the variant name (`And`, `Or` or `Xor`), i.e. the same text as
    /// the derived `Debug` output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(self, f)
    }
}
212
213/// [BitwiseOperation] struct encapsulates information about a bitwise operation.
214#[derive(Debug, PartialEq, Eq, Hash)]
215struct BitwiseOperation {
216    signature: Signature,
217    /// `operation` indicates the type of bitwise operation to be performed.
218    operation: BitwiseOperationType,
219    func_name: &'static str,
220    documentation: &'static Documentation,
221}
222
223impl BitwiseOperation {
224    pub fn new(
225        operator: BitwiseOperationType,
226        func_name: &'static str,
227        documentation: &'static Documentation,
228    ) -> Self {
229        Self {
230            operation: operator,
231            signature: Signature::uniform(1, INTEGERS.to_vec(), Volatility::Immutable),
232            func_name,
233            documentation,
234        }
235    }
236}
237
238impl AggregateUDFImpl for BitwiseOperation {
239    fn as_any(&self) -> &dyn Any {
240        self
241    }
242
243    fn name(&self) -> &str {
244        self.func_name
245    }
246
247    fn signature(&self) -> &Signature {
248        &self.signature
249    }
250
251    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
252        let arg_type = &arg_types[0];
253        if !arg_type.is_integer() {
254            return exec_err!(
255                "[return_type] {} not supported for {}",
256                self.name(),
257                arg_type
258            );
259        }
260        Ok(arg_type.clone())
261    }
262
263    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
264        downcast_bitwise_accumulator!(acc_args, self.operation, acc_args.is_distinct)
265    }
266
267    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
268        if self.operation == BitwiseOperationType::Xor && args.is_distinct {
269            Ok(vec![Field::new_list(
270                format_state_name(
271                    args.name,
272                    format!("{} distinct", self.name()).as_str(),
273                ),
274                // See COMMENTS.md to understand why nullable is set to true
275                Field::new_list_field(args.return_type().clone(), true),
276                false,
277            )
278            .into()])
279        } else {
280            Ok(vec![Field::new(
281                format_state_name(args.name, self.name()),
282                args.return_field.data_type().clone(),
283                true,
284            )
285            .into()])
286        }
287    }
288
289    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
290        true
291    }
292
293    fn create_groups_accumulator(
294        &self,
295        args: AccumulatorArgs,
296    ) -> Result<Box<dyn GroupsAccumulator>> {
297        let data_type = args.return_field.data_type();
298        let operation = &self.operation;
299        downcast_integer! {
300            data_type => (group_accumulator_helper, data_type, operation),
301            _ => not_impl_err!(
302                "GroupsAccumulator not supported for {} with {}",
303                self.name(),
304                data_type
305            ),
306        }
307    }
308
309    fn reverse_expr(&self) -> ReversedUDAF {
310        ReversedUDAF::Identical
311    }
312
313    fn documentation(&self) -> Option<&Documentation> {
314        Some(self.documentation)
315    }
316}
317
318struct BitAndAccumulator<T: ArrowNumericType> {
319    value: Option<T::Native>,
320}
321
322impl<T: ArrowNumericType> std::fmt::Debug for BitAndAccumulator<T> {
323    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
324        write!(f, "BitAndAccumulator({})", T::DATA_TYPE)
325    }
326}
327
328impl<T: ArrowNumericType> Default for BitAndAccumulator<T> {
329    fn default() -> Self {
330        Self { value: None }
331    }
332}
333
334impl<T: ArrowNumericType> Accumulator for BitAndAccumulator<T>
335where
336    T::Native: std::ops::BitAnd<Output = T::Native>,
337{
338    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
339        if let Some(x) = arrow::compute::bit_and(values[0].as_primitive::<T>()) {
340            let v = self.value.get_or_insert(x);
341            *v = *v & x;
342        }
343        Ok(())
344    }
345
346    fn evaluate(&mut self) -> Result<ScalarValue> {
347        ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
348    }
349
350    fn size(&self) -> usize {
351        size_of_val(self)
352    }
353
354    fn state(&mut self) -> Result<Vec<ScalarValue>> {
355        Ok(vec![self.evaluate()?])
356    }
357
358    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
359        self.update_batch(states)
360    }
361}
362
363struct BitOrAccumulator<T: ArrowNumericType> {
364    value: Option<T::Native>,
365}
366
367impl<T: ArrowNumericType> std::fmt::Debug for BitOrAccumulator<T> {
368    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
369        write!(f, "BitOrAccumulator({})", T::DATA_TYPE)
370    }
371}
372
373impl<T: ArrowNumericType> Default for BitOrAccumulator<T> {
374    fn default() -> Self {
375        Self { value: None }
376    }
377}
378
379impl<T: ArrowNumericType> Accumulator for BitOrAccumulator<T>
380where
381    T::Native: std::ops::BitOr<Output = T::Native>,
382{
383    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
384        if let Some(x) = arrow::compute::bit_or(values[0].as_primitive::<T>()) {
385            let v = self.value.get_or_insert(T::Native::usize_as(0));
386            *v = *v | x;
387        }
388        Ok(())
389    }
390
391    fn evaluate(&mut self) -> Result<ScalarValue> {
392        ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
393    }
394
395    fn size(&self) -> usize {
396        size_of_val(self)
397    }
398
399    fn state(&mut self) -> Result<Vec<ScalarValue>> {
400        Ok(vec![self.evaluate()?])
401    }
402
403    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
404        self.update_batch(states)
405    }
406}
407
408struct BitXorAccumulator<T: ArrowNumericType> {
409    value: Option<T::Native>,
410}
411
412impl<T: ArrowNumericType> std::fmt::Debug for BitXorAccumulator<T> {
413    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
414        write!(f, "BitXorAccumulator({})", T::DATA_TYPE)
415    }
416}
417
418impl<T: ArrowNumericType> Default for BitXorAccumulator<T> {
419    fn default() -> Self {
420        Self { value: None }
421    }
422}
423
424impl<T: ArrowNumericType> Accumulator for BitXorAccumulator<T>
425where
426    T::Native: std::ops::BitXor<Output = T::Native>,
427{
428    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
429        if let Some(x) = arrow::compute::bit_xor(values[0].as_primitive::<T>()) {
430            let v = self.value.get_or_insert(T::Native::usize_as(0));
431            *v = *v ^ x;
432        }
433        Ok(())
434    }
435
436    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
437        // XOR is it's own inverse
438        self.update_batch(values)
439    }
440
441    fn supports_retract_batch(&self) -> bool {
442        true
443    }
444
445    fn evaluate(&mut self) -> Result<ScalarValue> {
446        ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
447    }
448
449    fn size(&self) -> usize {
450        size_of_val(self)
451    }
452
453    fn state(&mut self) -> Result<Vec<ScalarValue>> {
454        Ok(vec![self.evaluate()?])
455    }
456
457    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
458        self.update_batch(states)
459    }
460}
461
462struct DistinctBitXorAccumulator<T: ArrowNumericType> {
463    values: HashSet<T::Native, RandomState>,
464}
465
466impl<T: ArrowNumericType> std::fmt::Debug for DistinctBitXorAccumulator<T> {
467    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
468        write!(f, "DistinctBitXorAccumulator({})", T::DATA_TYPE)
469    }
470}
471
472impl<T: ArrowNumericType> Default for DistinctBitXorAccumulator<T> {
473    fn default() -> Self {
474        Self {
475            values: HashSet::default(),
476        }
477    }
478}
479
480impl<T: ArrowNumericType> Accumulator for DistinctBitXorAccumulator<T>
481where
482    T::Native: std::ops::BitXor<Output = T::Native> + Hash + Eq,
483{
484    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
485        if values.is_empty() {
486            return Ok(());
487        }
488
489        let array = values[0].as_primitive::<T>();
490        match array.nulls().filter(|x| x.null_count() > 0) {
491            Some(n) => {
492                for idx in n.valid_indices() {
493                    self.values.insert(array.value(idx));
494                }
495            }
496            None => array.values().iter().for_each(|x| {
497                self.values.insert(*x);
498            }),
499        }
500        Ok(())
501    }
502
503    fn evaluate(&mut self) -> Result<ScalarValue> {
504        let mut acc = T::Native::usize_as(0);
505        for distinct_value in self.values.iter() {
506            acc = acc ^ *distinct_value;
507        }
508        let v = (!self.values.is_empty()).then_some(acc);
509        ScalarValue::new_primitive::<T>(v, &T::DATA_TYPE)
510    }
511
512    fn size(&self) -> usize {
513        size_of_val(self) + self.values.capacity() * size_of::<T::Native>()
514    }
515
516    fn state(&mut self) -> Result<Vec<ScalarValue>> {
517        // 1. Stores aggregate state in `ScalarValue::List`
518        // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
519        let state_out = {
520            let values = self
521                .values
522                .iter()
523                .map(|x| ScalarValue::new_primitive::<T>(Some(*x), &T::DATA_TYPE))
524                .collect::<Result<Vec<_>>>()?;
525
526            let arr = ScalarValue::new_list_nullable(&values, &T::DATA_TYPE);
527            vec![ScalarValue::List(arr)]
528        };
529        Ok(state_out)
530    }
531
532    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
533        if let Some(state) = states.first() {
534            let list_arr = as_list_array(state)?;
535            for arr in list_arr.iter().flatten() {
536                self.update_batch(&[arr])?;
537            }
538        }
539        Ok(())
540    }
541}
542
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, UInt64Array};
    use arrow::datatypes::UInt64Type;
    use datafusion_common::ScalarValue;

    use crate::bit_and_or_xor::BitXorAccumulator;
    use datafusion_expr::Accumulator;

    /// Wraps `values` as the single input column of a batch.
    fn uint64_batch(values: Vec<u64>) -> [ArrayRef; 1] {
        [Arc::new(UInt64Array::from(values)) as ArrayRef]
    }

    /// Updating and then retracting values must both go through XOR.
    #[test]
    fn test_bit_xor_accumulator() {
        let mut accumulator = BitXorAccumulator::<UInt64Type> { value: None };
        let added = uint64_batch(vec![1, 2]);
        let retracted = uint64_batch(vec![1]);

        // 1 ^ 2 = 3
        accumulator.update_batch(&added).unwrap();
        let after_update = accumulator.evaluate().unwrap();
        assert_eq!(after_update, ScalarValue::UInt64(Some(3)));

        // Retracting 1 XORs it in again: 3 ^ 1 = 2
        accumulator.retract_batch(&retracted).unwrap();
        let after_retract = accumulator.evaluate().unwrap();
        assert_eq!(after_retract, ScalarValue::UInt64(Some(2)));
    }
}