Skip to main content

arrow_select/
zip.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`zip`]: Combine values from two arrays based on boolean mask
19
20use crate::filter::{SlicesIterator, prep_null_mask_filter};
21use arrow_array::cast::AsArray;
22use arrow_array::types::{
23    BinaryType, BinaryViewType, ByteArrayType, ByteViewType, LargeBinaryType, LargeUtf8Type,
24    StringViewType, Utf8Type,
25};
26use arrow_array::*;
27use arrow_buffer::{
28    BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, OffsetBufferBuilder,
29    ScalarBuffer, ToByteSlice,
30};
31use arrow_data::transform::MutableArrayData;
32use arrow_data::{ArrayData, ByteView};
33use arrow_schema::{ArrowError, DataType};
34use std::fmt::{Debug, Formatter};
35use std::hash::Hash;
36use std::marker::PhantomData;
37use std::ops::Not;
38use std::sync::Arc;
39
40/// Zip two arrays by some boolean mask.
41///
42/// - Where `mask` is `true`, values of `truthy` are taken
43/// - Where `mask` is `false` or `NULL`, values of `falsy` are taken
44///
45/// # Example: `zip` two arrays
46/// ```
47/// # use std::sync::Arc;
48/// # use arrow_array::{ArrayRef, BooleanArray, Int32Array};
49/// # use arrow_select::zip::zip;
50/// // mask: [true, true, false, NULL, true]
51/// let mask = BooleanArray::from(vec![
52///   Some(true), Some(true), Some(false), None, Some(true)
53/// ]);
54/// // truthy array: [1, NULL, 3, 4, 5]
55/// let truthy = Int32Array::from(vec![
56///   Some(1), None, Some(3), Some(4), Some(5)
57/// ]);
58/// // falsy array: [10, 20, 30, 40, 50]
59/// let falsy = Int32Array::from(vec![
60///   Some(10), Some(20), Some(30), Some(40), Some(50)
61/// ]);
62/// // zip with this mask select the first, second and last value from `truthy`
63/// // and the third and fourth value from `falsy`
64/// let result = zip(&mask, &truthy, &falsy).unwrap();
65/// // Expected: [1, NULL, 30, 40, 5]
66/// let expected: ArrayRef = Arc::new(Int32Array::from(vec![
67///   Some(1), None, Some(30), Some(40), Some(5)
68/// ]));
69/// assert_eq!(&result, &expected);
70/// ```
71///
72/// # Example: `zip` and array with a scalar
73///
74/// Use `zip` to replace certain values in an array with a scalar
75///
76/// ```
77/// # use std::sync::Arc;
78/// # use arrow_array::{ArrayRef, BooleanArray, Int32Array};
79/// # use arrow_select::zip::zip;
80/// // mask: [true, true, false, NULL, true]
81/// let mask = BooleanArray::from(vec![
82///   Some(true), Some(true), Some(false), None, Some(true)
83/// ]);
84/// //  array: [1, NULL, 3, 4, 5]
85/// let arr = Int32Array::from(vec![
86///   Some(1), None, Some(3), Some(4), Some(5)
87/// ]);
88/// // scalar: 42
89/// let scalar = Int32Array::new_scalar(42);
90/// // zip the array with the  mask select the first, second and last value from `arr`
91/// // and fill the third and fourth value with the scalar 42
92/// let result = zip(&mask, &arr, &scalar).unwrap();
93/// // Expected: [1, NULL, 42, 42, 5]
94/// let expected: ArrayRef = Arc::new(Int32Array::from(vec![
95///   Some(1), None, Some(42), Some(42), Some(5)
96/// ]));
97/// assert_eq!(&result, &expected);
98/// ```
99pub fn zip(
100    mask: &BooleanArray,
101    truthy: &dyn Datum,
102    falsy: &dyn Datum,
103) -> Result<ArrayRef, ArrowError> {
104    let (truthy_array, truthy_is_scalar) = truthy.get();
105    let (falsy_array, falsy_is_scalar) = falsy.get();
106
107    if falsy_is_scalar && truthy_is_scalar {
108        let zipper = ScalarZipper::try_new(truthy, falsy)?;
109        return zipper.zip_impl.create_output(mask);
110    }
111
112    let truthy = truthy_array;
113    let falsy = falsy_array;
114
115    if truthy.data_type() != falsy.data_type() {
116        return Err(ArrowError::InvalidArgumentError(
117            "arguments need to have the same data type".into(),
118        ));
119    }
120
121    if truthy_is_scalar && truthy.len() != 1 {
122        return Err(ArrowError::InvalidArgumentError(
123            "scalar arrays must have 1 element".into(),
124        ));
125    }
126    if !truthy_is_scalar && truthy.len() != mask.len() {
127        return Err(ArrowError::InvalidArgumentError(
128            "all arrays should have the same length".into(),
129        ));
130    }
131    if falsy_is_scalar && falsy.len() != 1 {
132        return Err(ArrowError::InvalidArgumentError(
133            "scalar arrays must have 1 element".into(),
134        ));
135    }
136    if !falsy_is_scalar && falsy.len() != mask.len() {
137        return Err(ArrowError::InvalidArgumentError(
138            "all arrays should have the same length".into(),
139        ));
140    }
141
142    let falsy = falsy.to_data();
143    let truthy = truthy.to_data();
144
145    zip_impl(mask, &truthy, truthy_is_scalar, &falsy, falsy_is_scalar)
146}
147
148fn zip_impl(
149    mask: &BooleanArray,
150    truthy: &ArrayData,
151    truthy_is_scalar: bool,
152    falsy: &ArrayData,
153    falsy_is_scalar: bool,
154) -> Result<ArrayRef, ArrowError> {
155    let mut mutable = MutableArrayData::new(vec![truthy, falsy], false, truthy.len());
156
157    // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to
158    // fill with falsy values
159
160    // keep track of how much is filled
161    let mut filled = 0;
162
163    let mask_buffer = maybe_prep_null_mask_filter(mask);
164    SlicesIterator::from(&mask_buffer).for_each(|(start, end)| {
165        // the gap needs to be filled with falsy values
166        if start > filled {
167            if falsy_is_scalar {
168                for _ in filled..start {
169                    // Copy the first item from the 'falsy' array into the output buffer.
170                    mutable.extend(1, 0, 1);
171                }
172            } else {
173                mutable.extend(1, filled, start);
174            }
175        }
176        // fill with truthy values
177        if truthy_is_scalar {
178            for _ in start..end {
179                // Copy the first item from the 'truthy' array into the output buffer.
180                mutable.extend(0, 0, 1);
181            }
182        } else {
183            mutable.extend(0, start, end);
184        }
185        filled = end;
186    });
187    // the remaining part is falsy
188    if filled < mask.len() {
189        if falsy_is_scalar {
190            for _ in filled..mask.len() {
191                // Copy the first item from the 'falsy' array into the output buffer.
192                mutable.extend(1, 0, 1);
193            }
194        } else {
195            mutable.extend(1, filled, mask.len());
196        }
197    }
198
199    let data = mutable.freeze();
200    Ok(make_array(data))
201}
202
203/// Zipper for 2 scalars
204///
205/// Useful for using in `IF <expr> THEN <scalar> ELSE <scalar> END` expressions
206///
207/// # Example
208/// ```
209/// # use std::sync::Arc;
210/// # use arrow_array::{ArrayRef, BooleanArray, Int32Array, Scalar, cast::AsArray, types::Int32Type};
211///
212/// # use arrow_select::zip::ScalarZipper;
213/// let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
214/// let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1));
215/// let zipper = ScalarZipper::try_new(&scalar_truthy, &scalar_falsy).unwrap();
216///
217/// // Later when we have a boolean mask
218/// let mask = BooleanArray::from(vec![true, false, true, false, true]);
219/// let result = zipper.zip(&mask).unwrap();
220/// let actual = result.as_primitive::<Int32Type>();
221/// let expected = Int32Array::from(vec![Some(42), Some(123), Some(42), Some(123), Some(42)]);
222/// ```
223///
224#[derive(Debug, Clone)]
225pub struct ScalarZipper {
226    zip_impl: Arc<dyn ZipImpl>,
227}
228
229impl ScalarZipper {
230    /// Try to create a new ScalarZipper from two scalar Datum
231    ///
232    /// # Errors
233    /// returns error if:
234    /// - the two Datum have different data types
235    /// - either Datum is not a scalar (or has more than 1 element)
236    ///
237    pub fn try_new(truthy: &dyn Datum, falsy: &dyn Datum) -> Result<Self, ArrowError> {
238        let (truthy, truthy_is_scalar) = truthy.get();
239        let (falsy, falsy_is_scalar) = falsy.get();
240
241        if truthy.data_type() != falsy.data_type() {
242            return Err(ArrowError::InvalidArgumentError(
243                "arguments need to have the same data type".into(),
244            ));
245        }
246
247        if !truthy_is_scalar {
248            return Err(ArrowError::InvalidArgumentError(
249                "only scalar arrays are supported".into(),
250            ));
251        }
252
253        if !falsy_is_scalar {
254            return Err(ArrowError::InvalidArgumentError(
255                "only scalar arrays are supported".into(),
256            ));
257        }
258
259        if truthy.len() != 1 {
260            return Err(ArrowError::InvalidArgumentError(
261                "scalar arrays must have 1 element".into(),
262            ));
263        }
264        if falsy.len() != 1 {
265            return Err(ArrowError::InvalidArgumentError(
266                "scalar arrays must have 1 element".into(),
267            ));
268        }
269
270        macro_rules! primitive_size_helper {
271            ($t:ty) => {
272                Arc::new(PrimitiveScalarImpl::<$t>::new(truthy, falsy)) as Arc<dyn ZipImpl>
273            };
274        }
275
276        let zip_impl = downcast_primitive! {
277            truthy.data_type() => (primitive_size_helper),
278            DataType::Utf8 => {
279                Arc::new(BytesScalarImpl::<Utf8Type>::new(truthy, falsy)) as Arc<dyn ZipImpl>
280            },
281            DataType::LargeUtf8 => {
282                Arc::new(BytesScalarImpl::<LargeUtf8Type>::new(truthy, falsy)) as Arc<dyn ZipImpl>
283            },
284            DataType::Binary => {
285                Arc::new(BytesScalarImpl::<BinaryType>::new(truthy, falsy)) as Arc<dyn ZipImpl>
286            },
287            DataType::LargeBinary => {
288                Arc::new(BytesScalarImpl::<LargeBinaryType>::new(truthy, falsy)) as Arc<dyn ZipImpl>
289            },
290            DataType::Utf8View => {
291                Arc::new(ByteViewScalarImpl::<StringViewType>::new(truthy, falsy)) as Arc<dyn ZipImpl>
292            },
293            DataType::BinaryView => {
294                Arc::new(ByteViewScalarImpl::<BinaryViewType>::new(truthy, falsy)) as Arc<dyn ZipImpl>
295            },
296            _ => {
297                Arc::new(FallbackImpl::new(truthy, falsy)) as Arc<dyn ZipImpl>
298            },
299        };
300
301        Ok(Self { zip_impl })
302    }
303
304    /// Creating output array based on input boolean array and the two scalar values the zipper was created with
305    /// See struct level documentation for examples.
306    pub fn zip(&self, mask: &BooleanArray) -> Result<ArrayRef, ArrowError> {
307        self.zip_impl.create_output(mask)
308    }
309}
310
311/// Impl for creating output array based on a mask
312trait ZipImpl: Debug + Send + Sync {
313    /// Creating output array based on input boolean array
314    fn create_output(&self, input: &BooleanArray) -> Result<ArrayRef, ArrowError>;
315}
316
317#[derive(Debug, PartialEq)]
318struct FallbackImpl {
319    truthy: ArrayData,
320    falsy: ArrayData,
321}
322
323impl FallbackImpl {
324    fn new(left: &dyn Array, right: &dyn Array) -> Self {
325        Self {
326            truthy: left.to_data(),
327            falsy: right.to_data(),
328        }
329    }
330}
331
332impl ZipImpl for FallbackImpl {
333    fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
334        zip_impl(predicate, &self.truthy, true, &self.falsy, true)
335    }
336}
337
338struct PrimitiveScalarImpl<T: ArrowPrimitiveType> {
339    data_type: DataType,
340    truthy: Option<T::Native>,
341    falsy: Option<T::Native>,
342}
343
344impl<T: ArrowPrimitiveType> Debug for PrimitiveScalarImpl<T> {
345    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
346        f.debug_struct("PrimitiveScalarImpl")
347            .field("data_type", &self.data_type)
348            .field("truthy", &self.truthy)
349            .field("falsy", &self.falsy)
350            .finish()
351    }
352}
353
354impl<T: ArrowPrimitiveType> PrimitiveScalarImpl<T> {
355    fn new(truthy: &dyn Array, falsy: &dyn Array) -> Self {
356        Self {
357            data_type: truthy.data_type().clone(),
358            truthy: Self::get_value_from_scalar(truthy),
359            falsy: Self::get_value_from_scalar(falsy),
360        }
361    }
362
363    fn get_value_from_scalar(scalar: &dyn Array) -> Option<T::Native> {
364        if scalar.is_null(0) {
365            None
366        } else {
367            let value = scalar.as_primitive::<T>().value(0);
368
369            Some(value)
370        }
371    }
372
373    /// return an output array that has
374    /// `value` in all locations where predicate is true
375    /// `null` otherwise
376    fn get_scalar_and_null_buffer_for_single_non_nullable(
377        predicate: BooleanBuffer,
378        value: T::Native,
379    ) -> (Vec<T::Native>, Option<NullBuffer>) {
380        let result_len = predicate.len();
381        let nulls = NullBuffer::new(predicate);
382        let scalars = vec![value; result_len];
383
384        (scalars, Some(nulls))
385    }
386}
387
388impl<T: ArrowPrimitiveType> ZipImpl for PrimitiveScalarImpl<T> {
389    fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
390        let result_len = predicate.len();
391        // Nulls are treated as false
392        let predicate = maybe_prep_null_mask_filter(predicate);
393
394        let (scalars, nulls): (Vec<T::Native>, Option<NullBuffer>) = match (self.truthy, self.falsy)
395        {
396            (Some(truthy_val), Some(falsy_val)) => {
397                let scalars: Vec<T::Native> = predicate
398                    .iter()
399                    .map(|b| if b { truthy_val } else { falsy_val })
400                    .collect();
401
402                (scalars, None)
403            }
404            (Some(truthy_val), None) => {
405                // If a value is true we need the TRUTHY and the null buffer will have 1 (meaning not null)
406                // If a value is false we need the FALSY and the null buffer will have 0 (meaning null)
407
408                Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, truthy_val)
409            }
410            (None, Some(falsy_val)) => {
411                // Flipping the boolean buffer as we want the opposite of the TRUE case
412                //
413                // if the condition is true we want null so we need to NOT the value so we get 0 (meaning null)
414                // if the condition is false we want the FALSY value so we need to NOT the value so we get 1 (meaning not null)
415                let predicate = predicate.not();
416
417                Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, falsy_val)
418            }
419            (None, None) => {
420                // All values are null
421                let nulls = NullBuffer::new_null(result_len);
422                let scalars = vec![T::default_value(); result_len];
423
424                (scalars, Some(nulls))
425            }
426        };
427
428        let scalars = ScalarBuffer::<T::Native>::from(scalars);
429        let output = PrimitiveArray::<T>::try_new(scalars, nulls)?;
430
431        // Keep decimal precisions, scales or timestamps timezones
432        let output = output.with_data_type(self.data_type.clone());
433
434        Ok(Arc::new(output))
435    }
436}
437
438#[derive(PartialEq, Hash)]
439struct BytesScalarImpl<T: ByteArrayType> {
440    truthy: Option<Vec<u8>>,
441    falsy: Option<Vec<u8>>,
442    phantom: PhantomData<T>,
443}
444
445impl<T: ByteArrayType> Debug for BytesScalarImpl<T> {
446    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
447        f.debug_struct("BytesScalarImpl")
448            .field("truthy", &self.truthy)
449            .field("falsy", &self.falsy)
450            .finish()
451    }
452}
453
454impl<T: ByteArrayType> BytesScalarImpl<T> {
455    fn new(truthy_value: &dyn Array, falsy_value: &dyn Array) -> Self {
456        Self {
457            truthy: Self::get_value_from_scalar(truthy_value),
458            falsy: Self::get_value_from_scalar(falsy_value),
459            phantom: PhantomData,
460        }
461    }
462
463    fn get_value_from_scalar(scalar: &dyn Array) -> Option<Vec<u8>> {
464        if scalar.is_null(0) {
465            None
466        } else {
467            let bytes: &[u8] = scalar.as_bytes::<T>().value(0).as_ref();
468
469            Some(bytes.to_vec())
470        }
471    }
472
473    /// return an output array that has
474    /// `value` in all locations where predicate is true
475    /// `null` otherwise
476    fn get_scalar_and_null_buffer_for_single_non_nullable(
477        predicate: BooleanBuffer,
478        value: &[u8],
479    ) -> (Buffer, OffsetBuffer<T::Offset>, Option<NullBuffer>) {
480        let value_length = value.len();
481
482        let number_of_true = predicate.count_set_bits();
483
484        // Fast path for all nulls
485        if number_of_true == 0 {
486            // All values are null
487            let nulls = NullBuffer::new_null(predicate.len());
488
489            return (
490                // Empty bytes
491                Buffer::from(&[]),
492                // All nulls so all lengths are 0
493                OffsetBuffer::<T::Offset>::new_zeroed(predicate.len()),
494                Some(nulls),
495            );
496        }
497
498        let offsets = OffsetBuffer::<T::Offset>::from_lengths(
499            predicate.iter().map(|b| if b { value_length } else { 0 }),
500        );
501
502        let mut bytes = MutableBuffer::with_capacity(0);
503        bytes.repeat_slice_n_times(value, number_of_true);
504
505        let bytes = Buffer::from(bytes);
506
507        // If a value is true we need the TRUTHY and the null buffer will have 1 (meaning not null)
508        // If a value is false we need the FALSY and the null buffer will have 0 (meaning null)
509        let nulls = NullBuffer::new(predicate);
510
511        (bytes, offsets, Some(nulls))
512    }
513
514    /// Create a [`Buffer`] where `value` slice is repeated `number_of_values` times
515    /// and [`OffsetBuffer`] where there are `number_of_values` lengths, and all equals to `value` length
516    fn get_bytes_and_offset_for_all_same_value(
517        number_of_values: usize,
518        value: &[u8],
519    ) -> (Buffer, OffsetBuffer<T::Offset>) {
520        let value_length = value.len();
521
522        let offsets =
523            OffsetBuffer::<T::Offset>::from_repeated_length(value_length, number_of_values);
524
525        let mut bytes = MutableBuffer::with_capacity(0);
526        bytes.repeat_slice_n_times(value, number_of_values);
527        let bytes = Buffer::from(bytes);
528
529        (bytes, offsets)
530    }
531
532    fn create_output_on_non_nulls(
533        predicate: &BooleanBuffer,
534        truthy_val: &[u8],
535        falsy_val: &[u8],
536    ) -> (Buffer, OffsetBuffer<<T as ByteArrayType>::Offset>) {
537        let true_count = predicate.count_set_bits();
538
539        match true_count {
540            0 => {
541                // All values are falsy
542
543                let (bytes, offsets) =
544                    Self::get_bytes_and_offset_for_all_same_value(predicate.len(), falsy_val);
545
546                return (bytes, offsets);
547            }
548            n if n == predicate.len() => {
549                // All values are truthy
550                let (bytes, offsets) =
551                    Self::get_bytes_and_offset_for_all_same_value(predicate.len(), truthy_val);
552
553                return (bytes, offsets);
554            }
555
556            _ => {
557                // Fallback
558            }
559        }
560
561        let total_number_of_bytes =
562            true_count * truthy_val.len() + (predicate.len() - true_count) * falsy_val.len();
563        let mut mutable = MutableBuffer::with_capacity(total_number_of_bytes);
564        let mut offset_buffer_builder = OffsetBufferBuilder::<T::Offset>::new(predicate.len());
565
566        // keep track of how much is filled
567        let mut filled = 0;
568
569        let truthy_len = truthy_val.len();
570        let falsy_len = falsy_val.len();
571
572        SlicesIterator::from(predicate).for_each(|(start, end)| {
573            // the gap needs to be filled with falsy values
574            if start > filled {
575                let false_repeat_count = start - filled;
576                // Push false value `repeat_count` times
577                mutable.repeat_slice_n_times(falsy_val, false_repeat_count);
578
579                for _ in 0..false_repeat_count {
580                    offset_buffer_builder.push_length(falsy_len)
581                }
582            }
583
584            let true_repeat_count = end - start;
585            // fill with truthy values
586            mutable.repeat_slice_n_times(truthy_val, true_repeat_count);
587
588            for _ in 0..true_repeat_count {
589                offset_buffer_builder.push_length(truthy_len)
590            }
591            filled = end;
592        });
593        // the remaining part is falsy
594        if filled < predicate.len() {
595            let false_repeat_count = predicate.len() - filled;
596            // Copy the first item from the 'falsy' array into the output buffer.
597            mutable.repeat_slice_n_times(falsy_val, false_repeat_count);
598
599            for _ in 0..false_repeat_count {
600                offset_buffer_builder.push_length(falsy_len)
601            }
602        }
603
604        (mutable.into(), offset_buffer_builder.finish())
605    }
606}
607
608impl<T: ByteArrayType> ZipImpl for BytesScalarImpl<T> {
609    fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
610        let result_len = predicate.len();
611        // Nulls are treated as false
612        let predicate = maybe_prep_null_mask_filter(predicate);
613
614        let (bytes, offsets, nulls): (Buffer, OffsetBuffer<T::Offset>, Option<NullBuffer>) =
615            match (self.truthy.as_deref(), self.falsy.as_deref()) {
616                (Some(truthy_val), Some(falsy_val)) => {
617                    let (bytes, offsets) =
618                        Self::create_output_on_non_nulls(&predicate, truthy_val, falsy_val);
619
620                    (bytes, offsets, None)
621                }
622                (Some(truthy_val), None) => {
623                    Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, truthy_val)
624                }
625                (None, Some(falsy_val)) => {
626                    // Flipping the boolean buffer as we want the opposite of the TRUE case
627                    //
628                    // if the condition is true we want null so we need to NOT the value so we get 0 (meaning null)
629                    // if the condition is false we want the FALSE value so we need to NOT the value so we get 1 (meaning not null)
630                    let predicate = predicate.not();
631                    Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, falsy_val)
632                }
633                (None, None) => {
634                    // All values are null
635                    let nulls = NullBuffer::new_null(result_len);
636
637                    (
638                        // Empty bytes
639                        Buffer::from(&[]),
640                        // All nulls so all lengths are 0
641                        OffsetBuffer::<T::Offset>::new_zeroed(predicate.len()),
642                        Some(nulls),
643                    )
644                }
645            };
646
647        let output = unsafe {
648            // Safety: the values are based on valid inputs
649            // and `try_new` is expensive for strings as it validate that the input is valid utf8
650            GenericByteArray::<T>::new_unchecked(offsets, bytes, nulls)
651        };
652
653        Ok(Arc::new(output))
654    }
655}
656
657fn maybe_prep_null_mask_filter(predicate: &BooleanArray) -> BooleanBuffer {
658    // Nulls are treated as false
659    if predicate.null_count() == 0 {
660        predicate.values().clone()
661    } else {
662        let cleaned = prep_null_mask_filter(predicate);
663        let (boolean_buffer, _) = cleaned.into_parts();
664        boolean_buffer
665    }
666}
667
668struct ByteViewScalarImpl<T: ByteViewType> {
669    truthy_view: Option<u128>,
670    truthy_buffers: Vec<Buffer>,
671    falsy_view: Option<u128>,
672    falsy_buffers: Vec<Buffer>,
673    phantom: PhantomData<T>,
674}
675
676impl<T: ByteViewType> ByteViewScalarImpl<T> {
677    fn new(truthy: &dyn Array, falsy: &dyn Array) -> Self {
678        let (truthy_view, truthy_buffers) = Self::get_value_from_scalar(truthy);
679        let (falsy_view, falsy_buffers) = Self::get_value_from_scalar(falsy);
680        Self {
681            truthy_view,
682            truthy_buffers,
683            falsy_view,
684            falsy_buffers,
685            phantom: PhantomData,
686        }
687    }
688
689    fn get_value_from_scalar(scalar: &dyn Array) -> (Option<u128>, Vec<Buffer>) {
690        if scalar.is_null(0) {
691            (None, vec![])
692        } else {
693            let (views, buffers, _) = scalar.as_byte_view::<T>().clone().into_parts();
694            (views.first().copied(), buffers)
695        }
696    }
697
698    fn get_views_for_single_non_nullable(
699        predicate: BooleanBuffer,
700        value: u128,
701        buffers: Vec<Buffer>,
702    ) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) {
703        let number_of_true = predicate.count_set_bits();
704        let number_of_values = predicate.len();
705
706        // Fast path for all nulls
707        if number_of_true == 0 {
708            // All values are null
709            return (
710                vec![0; number_of_values].into(),
711                vec![],
712                Some(NullBuffer::new_null(number_of_values)),
713            );
714        }
715        let bytes = vec![value; number_of_values];
716
717        // If value is true and we want to handle the TRUTHY case, the null buffer will have 1 (meaning not null)
718        // If value is false and we want to handle the FALSY case, the null buffer will have 0 (meaning null)
719        let nulls = NullBuffer::new(predicate);
720        (bytes.into(), buffers, Some(nulls))
721    }
722
723    fn get_views_for_non_nullable(
724        predicate: BooleanBuffer,
725        result_len: usize,
726        truthy_view: u128,
727        truthy_buffers: Vec<Buffer>,
728        falsy_view: u128,
729        falsy_buffers: Vec<Buffer>,
730    ) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) {
731        let true_count = predicate.count_set_bits();
732        match true_count {
733            0 => {
734                // all values are falsy
735                (vec![falsy_view; result_len].into(), falsy_buffers, None)
736            }
737            n if n == predicate.len() => {
738                // all values are truthy
739                (vec![truthy_view; result_len].into(), truthy_buffers, None)
740            }
741            _ => {
742                let true_count = predicate.count_set_bits();
743                let mut buffers: Vec<Buffer> = truthy_buffers.to_vec();
744
745                // If the falsy buffers are empty, we can use the falsy view as it is, because the value
746                // is completely inlined. Otherwise, we have non-inlined values in the buffer, and we need
747                // to recalculate the falsy view
748                let view_falsy = if falsy_buffers.is_empty() {
749                    falsy_view
750                } else {
751                    let byte_view_falsy = ByteView::from(falsy_view);
752                    let new_index_falsy_buffers =
753                        buffers.len() as u32 + byte_view_falsy.buffer_index;
754                    buffers.extend(falsy_buffers);
755                    let byte_view_falsy =
756                        byte_view_falsy.with_buffer_index(new_index_falsy_buffers);
757                    byte_view_falsy.as_u128()
758                };
759
760                let total_number_of_bytes = true_count * 16 + (predicate.len() - true_count) * 16;
761                let mut mutable = MutableBuffer::new(total_number_of_bytes);
762                let mut filled = 0;
763
764                SlicesIterator::from(&predicate).for_each(|(start, end)| {
765                    if start > filled {
766                        let false_repeat_count = start - filled;
767                        mutable
768                            .repeat_slice_n_times(view_falsy.to_byte_slice(), false_repeat_count);
769                    }
770                    let true_repeat_count = end - start;
771                    mutable.repeat_slice_n_times(truthy_view.to_byte_slice(), true_repeat_count);
772                    filled = end;
773                });
774
775                if filled < predicate.len() {
776                    let false_repeat_count = predicate.len() - filled;
777                    mutable.repeat_slice_n_times(view_falsy.to_byte_slice(), false_repeat_count);
778                }
779
780                let bytes = Buffer::from(mutable);
781                (bytes.into(), buffers, None)
782            }
783        }
784    }
785}
786
787impl<T: ByteViewType> Debug for ByteViewScalarImpl<T> {
788    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
789        f.debug_struct("ByteViewScalarImpl")
790            .field("truthy", &self.truthy_view)
791            .field("falsy", &self.falsy_view)
792            .finish()
793    }
794}
795
796impl<T: ByteViewType> ZipImpl for ByteViewScalarImpl<T> {
797    fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
798        let result_len = predicate.len();
799        // Nulls are treated as false
800        let predicate = maybe_prep_null_mask_filter(predicate);
801
802        let (views, buffers, nulls) = match (self.truthy_view, self.falsy_view) {
803            (Some(truthy), Some(falsy)) => Self::get_views_for_non_nullable(
804                predicate,
805                result_len,
806                truthy,
807                self.truthy_buffers.clone(),
808                falsy,
809                self.falsy_buffers.clone(),
810            ),
811            (Some(truthy), None) => Self::get_views_for_single_non_nullable(
812                predicate,
813                truthy,
814                self.truthy_buffers.clone(),
815            ),
816            (None, Some(falsy)) => {
817                let predicate = predicate.not();
818                Self::get_views_for_single_non_nullable(
819                    predicate,
820                    falsy,
821                    self.falsy_buffers.clone(),
822                )
823            }
824            (None, None) => {
825                // All values are null
826                (
827                    vec![0; result_len].into(),
828                    vec![],
829                    Some(NullBuffer::new_null(result_len)),
830                )
831            }
832        };
833
834        let result = unsafe { GenericByteViewArray::<T>::new_unchecked(views, buffers, nulls) };
835        Ok(Arc::new(result))
836    }
837}
838
#[cfg(test)]
mod test {
    // Tests for `zip` and `ScalarZipper`. They cover array/array, array/scalar and
    // scalar/scalar inputs for primitive, byte, byte-view and fallback (list) types,
    // including NULL scalars and masks whose NULL entries must be treated as `false`.
    use super::*;
    use arrow_array::types::Int32Type;

    #[test]
    fn test_zip_kernel_one() {
        let a = Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)]);
        let b = Int32Array::from(vec![None, Some(3), Some(6), Some(7), Some(3)]);
        let mask = BooleanArray::from(vec![true, true, false, false, true]);
        let out = zip(&mask, &a, &b).unwrap();
        let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
        let expected = Int32Array::from(vec![Some(5), None, Some(6), Some(7), Some(1)]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_two() {
        let a = Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)]);
        let b = Int32Array::from(vec![None, Some(3), Some(6), Some(7), Some(3)]);
        let mask = BooleanArray::from(vec![false, false, true, true, false]);
        let out = zip(&mask, &a, &b).unwrap();
        let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
        let expected = Int32Array::from(vec![None, Some(3), Some(7), None, Some(3)]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_falsy_1() {
        let a = Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)]);

        let fallback = Scalar::new(Int32Array::from_value(42, 1));

        let mask = BooleanArray::from(vec![true, true, false, false, true]);
        let out = zip(&mask, &a, &fallback).unwrap();
        let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
        let expected = Int32Array::from(vec![Some(5), None, Some(42), Some(42), Some(1)]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_falsy_2() {
        let a = Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)]);

        let fallback = Scalar::new(Int32Array::from_value(42, 1));

        let mask = BooleanArray::from(vec![false, false, true, true, false]);
        let out = zip(&mask, &a, &fallback).unwrap();
        let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
        let expected = Int32Array::from(vec![Some(42), Some(42), Some(7), None, Some(42)]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_truthy_1() {
        let a = Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)]);

        let fallback = Scalar::new(Int32Array::from_value(42, 1));

        let mask = BooleanArray::from(vec![true, true, false, false, true]);
        let out = zip(&mask, &fallback, &a).unwrap();
        let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
        let expected = Int32Array::from(vec![Some(42), Some(42), Some(7), None, Some(42)]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_truthy_2() {
        let a = Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)]);

        let fallback = Scalar::new(Int32Array::from_value(42, 1));

        let mask = BooleanArray::from(vec![false, false, true, true, false]);
        let out = zip(&mask, &fallback, &a).unwrap();
        let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
        let expected = Int32Array::from(vec![Some(5), None, Some(42), Some(42), Some(1)]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_both_mask_ends_with_true() {
        let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
        let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1));

        let mask = BooleanArray::from(vec![true, true, false, false, true]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
        let expected = Int32Array::from(vec![Some(42), Some(42), Some(123), Some(123), Some(42)]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_both_mask_ends_with_false() {
        let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
        let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1));

        let mask = BooleanArray::from(vec![true, true, false, true, false, false]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
        let expected = Int32Array::from(vec![
            Some(42),
            Some(42),
            Some(123),
            Some(42),
            Some(123),
            Some(123),
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_primitive_scalar_none_1() {
        let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
        let scalar_falsy = Scalar::new(Int32Array::new_null(1));

        let mask = BooleanArray::from(vec![true, true, false, false, true]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
        let expected = Int32Array::from(vec![Some(42), Some(42), None, None, Some(42)]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_primitive_scalar_none_2() {
        let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
        let scalar_falsy = Scalar::new(Int32Array::new_null(1));

        let mask = BooleanArray::from(vec![false, false, true, true, false]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
        let expected = Int32Array::from(vec![None, None, Some(42), Some(42), None]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_primitive_scalar_both_null() {
        let scalar_truthy = Scalar::new(Int32Array::new_null(1));
        let scalar_falsy = Scalar::new(Int32Array::new_null(1));

        let mask = BooleanArray::from(vec![false, false, true, true, false]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
        let expected = Int32Array::from(vec![None, None, None, None, None]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_primitive_array_with_nulls_is_mask_should_be_treated_as_false() {
        let truthy = Int32Array::from_iter_values(vec![1, 2, 3, 4, 5, 6]);
        let falsy = Int32Array::from_iter_values(vec![7, 8, 9, 10, 11, 12]);

        let mask = {
            let booleans = BooleanBuffer::from(vec![true, true, false, true, false, false]);
            let nulls = NullBuffer::from(vec![
                true, true, true,
                false, // null treated as false even though in the original mask it was true
                true, true,
            ]);
            BooleanArray::new(booleans, Some(nulls))
        };
        let out = zip(&mask, &truthy, &falsy).unwrap();
        let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
        let expected = Int32Array::from(vec![
            Some(1),
            Some(2),
            Some(9),
            Some(10), // true in mask but null
            Some(11),
            Some(12),
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_primitive_scalar_with_boolean_array_mask_with_nulls_should_be_treated_as_false()
     {
        let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
        let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1));

        let mask = {
            let booleans = BooleanBuffer::from(vec![true, true, false, true, false, false]);
            let nulls = NullBuffer::from(vec![
                true, true, true,
                false, // null treated as false even though in the original mask it was true
                true, true,
            ]);
            BooleanArray::new(booleans, Some(nulls))
        };
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
        let expected = Int32Array::from(vec![
            Some(42),
            Some(42),
            Some(123),
            Some(123), // true in mask but null
            Some(123),
            Some(123),
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_string_array_with_nulls_is_mask_should_be_treated_as_false() {
        let truthy = StringArray::from_iter_values(vec!["1", "2", "3", "4", "5", "6"]);
        let falsy = StringArray::from_iter_values(vec!["7", "8", "9", "10", "11", "12"]);

        let mask = {
            let booleans = BooleanBuffer::from(vec![true, true, false, true, false, false]);
            let nulls = NullBuffer::from(vec![
                true, true, true,
                false, // null treated as false even though in the original mask it was true
                true, true,
            ]);
            BooleanArray::new(booleans, Some(nulls))
        };
        let out = zip(&mask, &truthy, &falsy).unwrap();
        let actual = out.as_string::<i32>();
        let expected = StringArray::from_iter_values(vec![
            "1", "2", "9", "10", // true in mask but null
            "11", "12",
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_large_string_scalar_with_boolean_array_mask_with_nulls_should_be_treated_as_false()
     {
        let scalar_truthy = Scalar::new(LargeStringArray::from_iter_values(["test"]));
        let scalar_falsy = Scalar::new(LargeStringArray::from_iter_values(["something else"]));

        let mask = {
            let booleans = BooleanBuffer::from(vec![true, true, false, true, false, false]);
            let nulls = NullBuffer::from(vec![
                true, true, true,
                false, // null treated as false even though in the original mask it was true
                true, true,
            ]);
            BooleanArray::new(booleans, Some(nulls))
        };
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_any().downcast_ref::<LargeStringArray>().unwrap();
        let expected = LargeStringArray::from_iter(vec![
            Some("test"),
            Some("test"),
            Some("something else"),
            Some("something else"), // true in mask but null
            Some("something else"),
            Some("something else"),
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_bytes_scalar_none_1() {
        let scalar_truthy = Scalar::new(StringArray::from_iter_values(["hello"]));
        let scalar_falsy = Scalar::new(StringArray::new_null(1));

        let mask = BooleanArray::from(vec![true, true, false, false, true]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_any().downcast_ref::<StringArray>().unwrap();
        let expected = StringArray::from_iter(vec![
            Some("hello"),
            Some("hello"),
            None,
            None,
            Some("hello"),
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_bytes_scalar_none_2() {
        let scalar_truthy = Scalar::new(StringArray::new_null(1));
        let scalar_falsy = Scalar::new(StringArray::from_iter_values(["hello"]));

        let mask = BooleanArray::from(vec![true, true, false, false, true]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_any().downcast_ref::<StringArray>().unwrap();
        let expected = StringArray::from_iter(vec![None, None, Some("hello"), Some("hello"), None]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_bytes_scalar_both() {
        let scalar_truthy = Scalar::new(StringArray::from_iter_values(["test"]));
        let scalar_falsy = Scalar::new(StringArray::from_iter_values(["something else"]));

        // mask ends with false
        let mask = BooleanArray::from(vec![true, true, false, true, false, false]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_any().downcast_ref::<StringArray>().unwrap();
        let expected = StringArray::from_iter(vec![
            Some("test"),
            Some("test"),
            Some("something else"),
            Some("test"),
            Some("something else"),
            Some("something else"),
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_scalar_bytes_only_taking_one_side() {
        let mask_len = 5;
        let all_true_mask = BooleanArray::from(vec![true; mask_len]);
        let all_false_mask = BooleanArray::from(vec![false; mask_len]);

        let null_scalar = Scalar::new(StringArray::new_null(1));
        let non_null_scalar_1 = Scalar::new(StringArray::from_iter_values(["test"]));
        let non_null_scalar_2 = Scalar::new(StringArray::from_iter_values(["something else"]));

        {
            // 1. Test where left is null and right is non-null
            //    and mask is all true
            let out = zip(&all_true_mask, &null_scalar, &non_null_scalar_1).unwrap();
            let actual = out.as_string::<i32>();
            let expected = StringArray::from_iter(std::iter::repeat_n(None::<&str>, mask_len));
            assert_eq!(actual, &expected);
        }

        {
            // 2. Test where left is null and right is non-null
            //    and mask is all false
            let out = zip(&all_false_mask, &null_scalar, &non_null_scalar_1).unwrap();
            let actual = out.as_string::<i32>();
            let expected = StringArray::from_iter(std::iter::repeat_n(Some("test"), mask_len));
            assert_eq!(actual, &expected);
        }

        {
            // 3. Test where left is non-null and right is null
            //    and mask is all true
            let out = zip(&all_true_mask, &non_null_scalar_1, &null_scalar).unwrap();
            let actual = out.as_string::<i32>();
            let expected = StringArray::from_iter(std::iter::repeat_n(Some("test"), mask_len));
            assert_eq!(actual, &expected);
        }

        {
            // 4. Test where left is non-null and right is null
            //    and mask is all false
            let out = zip(&all_false_mask, &non_null_scalar_1, &null_scalar).unwrap();
            let actual = out.as_string::<i32>();
            let expected = StringArray::from_iter(std::iter::repeat_n(None::<&str>, mask_len));
            assert_eq!(actual, &expected);
        }

        {
            // 5. Test where both left and right are not null
            //    and mask is all true
            let out = zip(&all_true_mask, &non_null_scalar_1, &non_null_scalar_2).unwrap();
            let actual = out.as_string::<i32>();
            let expected = StringArray::from_iter(std::iter::repeat_n(Some("test"), mask_len));
            assert_eq!(actual, &expected);
        }

        {
            // 6. Test where both left and right are not null
            //    and mask is all false
            let out = zip(&all_false_mask, &non_null_scalar_1, &non_null_scalar_2).unwrap();
            let actual = out.as_string::<i32>();
            let expected =
                StringArray::from_iter(std::iter::repeat_n(Some("something else"), mask_len));
            assert_eq!(actual, &expected);
        }

        {
            // 7. Test where both left and right are null
            //    and mask is random
            let mask = BooleanArray::from(vec![true, false, true, false, true]);
            let out = zip(&mask, &null_scalar, &null_scalar).unwrap();
            let actual = out.as_string::<i32>();
            let expected = StringArray::from_iter(std::iter::repeat_n(None::<&str>, mask_len));
            assert_eq!(actual, &expected);
        }
    }

    #[test]
    fn test_scalar_zipper() {
        let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
        let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1));

        let mask = BooleanArray::from(vec![false, false, true, true, false]);

        let scalar_zipper = ScalarZipper::try_new(&scalar_truthy, &scalar_falsy).unwrap();
        let out = scalar_zipper.zip(&mask).unwrap();
        let actual = out.as_primitive::<Int32Type>();
        let expected = Int32Array::from(vec![Some(123), Some(123), Some(42), Some(42), Some(123)]);
        assert_eq!(actual, &expected);

        // test with different mask length as well
        let mask = BooleanArray::from(vec![true, false, true]);
        let out = scalar_zipper.zip(&mask).unwrap();
        let actual = out.as_primitive::<Int32Type>();
        let expected = Int32Array::from(vec![Some(42), Some(123), Some(42)]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_strings() {
        let scalar_truthy = Scalar::new(StringArray::from(vec!["hello"]));
        let scalar_falsy = Scalar::new(StringArray::from(vec!["world"]));

        let mask = BooleanArray::from(vec![true, false, true, false, true]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_string::<i32>();
        let expected = StringArray::from(vec![
            Some("hello"),
            Some("world"),
            Some("hello"),
            Some("world"),
            Some("hello"),
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_binary() {
        let truthy_bytes: &[u8] = b"\xFF\xFE\xFD";
        let falsy_bytes: &[u8] = b"world";
        let scalar_truthy = Scalar::new(BinaryArray::from_iter_values(
            // Non valid UTF8 bytes
            vec![truthy_bytes],
        ));
        let scalar_falsy = Scalar::new(BinaryArray::from_iter_values(vec![falsy_bytes]));

        let mask = BooleanArray::from(vec![true, false, true, false, true]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_binary::<i32>();
        let expected = BinaryArray::from(vec![
            Some(truthy_bytes),
            Some(falsy_bytes),
            Some(truthy_bytes),
            Some(falsy_bytes),
            Some(truthy_bytes),
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_large_binary() {
        let truthy_bytes: &[u8] = b"hey";
        let falsy_bytes: &[u8] = b"world";
        let scalar_truthy = Scalar::new(LargeBinaryArray::from_iter_values(vec![truthy_bytes]));
        let scalar_falsy = Scalar::new(LargeBinaryArray::from_iter_values(vec![falsy_bytes]));

        let mask = BooleanArray::from(vec![true, false, true, false, true]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_binary::<i64>();
        let expected = LargeBinaryArray::from(vec![
            Some(truthy_bytes),
            Some(falsy_bytes),
            Some(truthy_bytes),
            Some(falsy_bytes),
            Some(truthy_bytes),
        ]);
        assert_eq!(actual, &expected);
    }

    // Test to ensure that the precision and scale are kept when zipping Decimal128 data
    #[test]
    fn test_zip_decimal_with_custom_precision_and_scale() {
        let arr = Decimal128Array::from_iter_values([12345, 456, 7890, -123223423432432])
            .with_precision_and_scale(20, 2)
            .unwrap();

        let arr: ArrayRef = Arc::new(arr);

        let scalar_1 = Scalar::new(arr.slice(0, 1));
        let scalar_2 = Scalar::new(arr.slice(1, 1));
        let null_scalar = Scalar::new(new_null_array(arr.data_type(), 1));
        let array_1: ArrayRef = arr.slice(0, 2);
        let array_2: ArrayRef = arr.slice(2, 2);

        test_zip_output_data_types_for_input(scalar_1, scalar_2, null_scalar, array_1, array_2);
    }

    // Test to ensure that the timezone is kept when zipping TimestampArray data
    #[test]
    fn test_zip_timestamp_with_timezone() {
        let arr = TimestampSecondArray::from(vec![0, 1000, 2000, 4000])
            .with_timezone("+01:00".to_string());

        let arr: ArrayRef = Arc::new(arr);

        let scalar_1 = Scalar::new(arr.slice(0, 1));
        let scalar_2 = Scalar::new(arr.slice(1, 1));
        let null_scalar = Scalar::new(new_null_array(arr.data_type(), 1));
        let array_1: ArrayRef = arr.slice(0, 2);
        let array_2: ArrayRef = arr.slice(2, 2);

        test_zip_output_data_types_for_input(scalar_1, scalar_2, null_scalar, array_1, array_2);
    }

    // Runs `test_zip_output_data_type` over every scalar/array pairing built from the
    // given inputs, which must all share one data type.
    fn test_zip_output_data_types_for_input(
        scalar_1: Scalar<ArrayRef>,
        scalar_2: Scalar<ArrayRef>,
        null_scalar: Scalar<ArrayRef>,
        array_1: ArrayRef,
        array_2: ArrayRef,
    ) {
        // non null Scalar vs non null Scalar
        test_zip_output_data_type(&scalar_1, &scalar_2, 10);

        // null Scalar vs non-null Scalar (and vice versa)
        test_zip_output_data_type(&null_scalar, &scalar_1, 10);
        test_zip_output_data_type(&scalar_1, &null_scalar, 10);

        // non-null Scalar and array (and vice versa)
        test_zip_output_data_type(&array_1.as_ref(), &scalar_1, array_1.len());
        test_zip_output_data_type(&scalar_1, &array_1.as_ref(), array_1.len());

        // Array and null scalar (and vice versa)
        test_zip_output_data_type(&array_1.as_ref(), &null_scalar, array_1.len());

        test_zip_output_data_type(&null_scalar, &array_1.as_ref(), array_1.len());

        // Both arrays
        test_zip_output_data_type(&array_1.as_ref(), &array_2.as_ref(), array_1.len());
    }

    // Asserts that zipping `truthy` and `falsy` under all-true, all-false and alternating
    // masks of `mask_length` rows keeps the inputs' (shared) data type.
    fn test_zip_output_data_type(truthy: &dyn Datum, falsy: &dyn Datum, mask_length: usize) {
        let expected_data_type = truthy.get().0.data_type().clone();
        assert_eq!(&expected_data_type, falsy.get().0.data_type());

        // Try different masks to test different paths
        let mask_all_true = BooleanArray::from(vec![true; mask_length]);
        let mask_all_false = BooleanArray::from(vec![false; mask_length]);
        let mask_some_true_and_false =
            BooleanArray::from((0..mask_length).map(|i| i % 2 == 0).collect::<Vec<bool>>());

        for mask in [&mask_all_true, &mask_all_false, &mask_some_true_and_false] {
            let out = zip(mask, truthy, falsy).unwrap();
            assert_eq!(out.data_type(), &expected_data_type);
        }
    }

    #[test]
    fn zip_scalar_fallback_impl() {
        let truthy_list_item_scalar = Some(vec![Some(1), None, Some(3)]);
        let truthy_list_array_scalar =
            Scalar::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                truthy_list_item_scalar.clone(),
            ]));
        let falsy_list_item_scalar = Some(vec![None, Some(2), Some(4)]);
        let falsy_list_array_scalar =
            Scalar::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                falsy_list_item_scalar.clone(),
            ]));
        let mask = BooleanArray::from(vec![true, false, true, false, false, true, false]);
        let out = zip(&mask, &truthy_list_array_scalar, &falsy_list_array_scalar).unwrap();
        let actual = out.as_list::<i32>();

        let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            truthy_list_item_scalar.clone(),
            falsy_list_item_scalar.clone(),
            truthy_list_item_scalar.clone(),
            falsy_list_item_scalar.clone(),
            falsy_list_item_scalar.clone(),
            truthy_list_item_scalar.clone(),
            falsy_list_item_scalar.clone(),
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_strings_array_view() {
        let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"]));
        let scalar_falsy = Scalar::new(StringViewArray::from(vec!["world"]));

        let mask = BooleanArray::from(vec![true, false, true, false]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_string_view();
        let expected = StringViewArray::from(vec![
            Some("hello"),
            Some("world"),
            Some("hello"),
            Some("world"),
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_binary_array_view() {
        let scalar_truthy = Scalar::new(BinaryViewArray::from_iter_values(vec![b"hello"]));
        let scalar_falsy = Scalar::new(BinaryViewArray::from_iter_values(vec![b"world"]));

        let mask = BooleanArray::from(vec![true, false]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_byte_view();
        let expected = BinaryViewArray::from_iter_values(vec![b"hello", b"world"]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_strings_array_view_with_nulls() {
        let scalar_truthy = Scalar::new(StringViewArray::from_iter_values(["hello"]));
        let scalar_falsy = Scalar::new(StringViewArray::new_null(1));

        let mask = BooleanArray::from(vec![true, true, false, false, true]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_any().downcast_ref::<StringViewArray>().unwrap();
        let expected = StringViewArray::from_iter(vec![
            Some("hello"),
            Some("hello"),
            None,
            None,
            Some("hello"),
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_strings_array_view_all_true_null() {
        let scalar_truthy = Scalar::new(StringViewArray::new_null(1));
        let scalar_falsy = Scalar::new(StringViewArray::new_null(1));
        let mask = BooleanArray::from(vec![true, true]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_any().downcast_ref::<StringViewArray>().unwrap();
        let expected = StringViewArray::from_iter(vec![None::<String>, None]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_strings_array_view_all_false_null() {
        let scalar_truthy = Scalar::new(StringViewArray::new_null(1));
        let scalar_falsy = Scalar::new(StringViewArray::new_null(1));
        let mask = BooleanArray::from(vec![false, false]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_any().downcast_ref::<StringViewArray>().unwrap();
        let expected = StringViewArray::from_iter(vec![None::<String>, None]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_string_array_view_all_true() {
        let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"]));
        let scalar_falsy = Scalar::new(StringViewArray::from(vec!["world"]));

        let mask = BooleanArray::from(vec![true, true]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_string_view();
        let expected = StringViewArray::from(vec![Some("hello"), Some("hello")]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_string_array_view_all_false() {
        let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"]));
        let scalar_falsy = Scalar::new(StringViewArray::from(vec!["world"]));

        let mask = BooleanArray::from(vec![false, false]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_string_view();
        let expected = StringViewArray::from(vec![Some("world"), Some("world")]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_strings_large_strings() {
        let scalar_truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));
        let scalar_falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"]));

        let mask = BooleanArray::from(vec![true, false]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_string_view();
        let expected = StringViewArray::from(vec![
            Some("longer than 12 bytes"),
            Some("another longer than 12 bytes"),
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_strings_array_view_large_short_strings() {
        let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"]));
        let scalar_falsy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));

        let mask = BooleanArray::from(vec![true, false, true, false]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_string_view();
        let expected = StringViewArray::from(vec![
            Some("hello"),
            Some("longer than 12 bytes"),
            Some("hello"),
            Some("longer than 12 bytes"),
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_strings_array_view_large_all_true() {
        let scalar_truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));
        let scalar_falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"]));

        let mask = BooleanArray::from(vec![true, true]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_string_view();
        let expected = StringViewArray::from(vec![
            Some("longer than 12 bytes"),
            Some("longer than 12 bytes"),
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_strings_array_view_large_all_false() {
        let scalar_truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));
        let scalar_falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"]));

        let mask = BooleanArray::from(vec![false, false]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_string_view();
        let expected = StringViewArray::from(vec![
            Some("another longer than 12 bytes"),
            Some("another longer than 12 bytes"),
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_strings_array_view_null_truthy() {
        // Only the falsy scalar is non-null: this is the byte-view path that
        // inverts the mask internally before building the output.
        let scalar_truthy = Scalar::new(StringViewArray::new_null(1));
        let scalar_falsy = Scalar::new(StringViewArray::from_iter_values(["hello"]));

        let mask = BooleanArray::from(vec![true, true, false, false, true]);
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_string_view();
        let expected =
            StringViewArray::from_iter(vec![None, None, Some("hello"), Some("hello"), None]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_string_view_mask_with_nulls_should_be_treated_as_false() {
        let scalar_truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));
        let scalar_falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"]));

        let mask = {
            let booleans = BooleanBuffer::from(vec![true, true, false, true, false, false]);
            let nulls = NullBuffer::from(vec![
                true, true, true,
                false, // null treated as false even though in the original mask it was true
                true, true,
            ]);
            BooleanArray::new(booleans, Some(nulls))
        };
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_string_view();
        let expected = StringViewArray::from(vec![
            Some("longer than 12 bytes"),
            Some("longer than 12 bytes"),
            Some("another longer than 12 bytes"),
            Some("another longer than 12 bytes"), // true in mask but null
            Some("another longer than 12 bytes"),
            Some("another longer than 12 bytes"),
        ]);
        assert_eq!(actual, &expected);
    }

    #[test]
    fn test_zip_kernel_scalar_string_view_null_truthy_mask_with_nulls() {
        // NULL mask entries must read as `false` even after the mask is inverted
        // for the "only falsy is non-null" path.
        let scalar_truthy = Scalar::new(StringViewArray::new_null(1));
        let scalar_falsy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));

        let mask = {
            let booleans = BooleanBuffer::from(vec![true, true, false, true, false, false]);
            let nulls = NullBuffer::from(vec![
                true, true, true,
                false, // null treated as false even though in the original mask it was true
                true, true,
            ]);
            BooleanArray::new(booleans, Some(nulls))
        };
        let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
        let actual = out.as_string_view();
        let expected = StringViewArray::from_iter(vec![
            None,
            None,
            Some("longer than 12 bytes"),
            Some("longer than 12 bytes"), // true in mask but null
            Some("longer than 12 bytes"),
            Some("longer than 12 bytes"),
        ]);
        assert_eq!(actual, &expected);
    }
}