use crate::filter::{SlicesIterator, prep_null_mask_filter};
use crate::zip::zip;
use arrow_array::{Array, ArrayRef, BooleanArray, Datum, make_array, new_empty_array};
use arrow_data::ArrayData;
use arrow_data::transform::MutableArrayData;
use arrow_schema::ArrowError;
/// Selects which value array an output row of [`merge_n`] is read from.
///
/// Implementors must be `Copy` and comparable for equality: `merge_n` copies
/// entries out of the index slice and compares them to find runs of rows that
/// select the same array.
pub trait MergeIndex: Copy + PartialEq + Eq {
    /// Returns the position of the selected value array, or `None` when the
    /// row selects no array (`merge_n` then emits a null for that row).
    fn index(&self) -> Option<usize>;
}
/// A plain `usize` always designates a value array, so the lookup never
/// yields `None`.
impl MergeIndex for usize {
    fn index(&self) -> Option<usize> {
        let position = *self;
        Some(position)
    }
}
/// An `Option<usize>` is passed through unchanged: `Some(i)` designates value
/// array `i`, `None` designates no array.
impl MergeIndex for Option<usize> {
    fn index(&self) -> Option<usize> {
        self.as_ref().copied()
    }
}
/// Interleaves rows from several arrays of the same data type into one array.
///
/// `indices` has one entry per output row. An entry that resolves to
/// `Some(i)` takes the next not-yet-used row of `values[i]` (rows of each
/// value array are consumed front to back); an entry that resolves to `None`
/// produces a null.
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] when `values` is empty or when
/// the value arrays do not all share the data type of `values[0]`.
///
/// # Panics
///
/// Debug builds assert that every resolved index is smaller than
/// `values.len()`.
pub fn merge_n(values: &[&dyn Array], indices: &[impl MergeIndex]) -> Result<ArrayRef, ArrowError> {
    let Some((first, others)) = values.split_first() else {
        return Err(ArrowError::InvalidArgumentError(
            "merge_n requires at least one value array".to_string(),
        ));
    };

    // Every array must agree with the first one; report the first offender.
    let data_type = first.data_type();
    if let Some(mismatch) = others
        .iter()
        .copied()
        .find(|candidate| candidate.data_type() != data_type)
    {
        return Err(ArrowError::InvalidArgumentError(format!(
            "It is not possible to merge arrays of different data types ({} and {})",
            data_type,
            mismatch.data_type()
        )));
    }

    if indices.is_empty() {
        return Ok(new_empty_array(data_type));
    }

    #[cfg(debug_assertions)]
    {
        let array_count = values.len();
        for index in indices.iter().filter_map(|ix| ix.index()) {
            assert!(
                index < array_count,
                "Index out of bounds: {} >= {}",
                index,
                array_count
            );
        }
    }

    let sources: Vec<ArrayData> = values.iter().map(|array| array.to_data()).collect();
    let mut builder = MutableArrayData::new(sources.iter().collect(), true, indices.len());

    // `consumed[i]` is the number of rows already copied out of `values[i]`.
    let mut consumed = vec![0usize; values.len() + 1];

    // Walk `indices` one run of equal entries at a time so that each run is
    // appended with a single call.
    let mut run_start = 0;
    while run_start < indices.len() {
        let selector = indices[run_start];
        let run_len = 1 + indices[run_start + 1..]
            .iter()
            .take_while(|&&next| next == selector)
            .count();

        match selector.index() {
            Some(source) => {
                let from = consumed[source];
                let to = from + run_len;
                builder.extend(source, from, to);
                consumed[source] = to;
            }
            None => builder.extend_nulls(run_len),
        }

        run_start += run_len;
    }

    Ok(make_array(builder.freeze()))
}
pub fn merge(
mask: &BooleanArray,
truthy: &dyn Datum,
falsy: &dyn Datum,
) -> Result<ArrayRef, ArrowError> {
let (truthy_array, truthy_is_scalar) = truthy.get();
let (falsy_array, falsy_is_scalar) = falsy.get();
if truthy_is_scalar && falsy_is_scalar {
return zip(mask, truthy, falsy);
}
if truthy_array.data_type() != falsy_array.data_type() {
return Err(ArrowError::InvalidArgumentError(
"arguments need to have the same data type".into(),
));
}
if truthy_is_scalar && truthy_array.len() != 1 {
return Err(ArrowError::InvalidArgumentError(
"scalar arrays must have 1 element".into(),
));
}
if falsy_is_scalar && falsy_array.len() != 1 {
return Err(ArrowError::InvalidArgumentError(
"scalar arrays must have 1 element".into(),
));
}
let falsy = falsy_array.to_data();
let truthy = truthy_array.to_data();
let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, mask.len());
let mut filled = 0;
let mut falsy_offset = 0;
let mut truthy_offset = 0;
let mask_buffer = match mask.null_count() {
0 => mask.values().clone(),
_ => prep_null_mask_filter(mask).into_parts().0,
};
SlicesIterator::from(&mask_buffer).for_each(|(start, end)| {
if start > filled {
if falsy_is_scalar {
for _ in filled..start {
mutable.extend(1, 0, 1);
}
} else {
let falsy_length = start - filled;
let falsy_end = falsy_offset + falsy_length;
mutable.extend(1, falsy_offset, falsy_end);
falsy_offset = falsy_end;
}
}
if truthy_is_scalar {
for _ in start..end {
mutable.extend(0, 0, 1);
}
} else {
let truthy_length = end - start;
let truthy_end = truthy_offset + truthy_length;
mutable.extend(0, truthy_offset, truthy_end);
truthy_offset = truthy_end;
}
filled = end;
});
if filled < mask.len() {
if falsy_is_scalar {
for _ in filled..mask.len() {
mutable.extend(1, 0, 1);
}
} else {
let falsy_length = mask.len() - filled;
let falsy_end = falsy_offset + falsy_length;
mutable.extend(1, falsy_offset, falsy_end);
}
}
let data = mutable.freeze();
Ok(make_array(data))
}
#[cfg(test)]
mod tests {
    use crate::merge::{MergeIndex, merge, merge_n};
    use arrow_array::cast::AsArray;
    use arrow_array::{
        Array, ArrayRef, BooleanArray, Datum, Int32Array, Scalar, StringArray, UInt64Array,
    };
    use arrow_schema::ArrowError::InvalidArgumentError;

    /// One-byte merge index used by the `merge_n` tests; `u8::MAX` stands for
    /// "no array".
    #[derive(PartialEq, Eq, Copy, Clone)]
    struct CompactMergeIndex {
        index: u8,
    }

    impl MergeIndex for CompactMergeIndex {
        fn index(&self) -> Option<usize> {
            match self.index {
                u8::MAX => None,
                position => Some(position as usize),
            }
        }
    }

    /// Shorthand constructor for [`CompactMergeIndex`].
    fn ix(index: u8) -> CompactMergeIndex {
        CompactMergeIndex { index }
    }

    /// Checks that `merged` is a string array whose rows match `expected`:
    /// `Some(text)` rows must be valid and equal to `text`, `None` rows must
    /// be null.
    fn assert_string_rows(merged: &ArrayRef, expected: &[Option<&str>]) {
        let merged = merged.as_string::<i32>();
        assert_eq!(merged.len(), expected.len());
        for (row, want) in expected.iter().enumerate() {
            match want {
                Some(text) => {
                    assert!(merged.is_valid(row));
                    assert_eq!(merged.value(row), *text);
                }
                None => assert!(!merged.is_valid(row)),
            }
        }
    }

    // Two arrays, mask without nulls, mask ends on a true run.
    #[test]
    fn test_merge() {
        let truthy = StringArray::from(vec![Some("A"), Some("B"), Some("E"), None]);
        let falsy = StringArray::from(vec![Some("C"), Some("D")]);
        let mask = BooleanArray::from(vec![true, false, true, false, true, true]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_eq!(merged.len(), mask.len());
        assert_string_rows(
            &merged,
            &[Some("A"), Some("C"), Some("B"), Some("D"), Some("E"), None],
        );
    }

    // Null mask entries read from the falsy input.
    #[test]
    fn test_merge_null_is_false() {
        let truthy = StringArray::from(vec![Some("A"), Some("B"), Some("E"), None]);
        let falsy = StringArray::from(vec![Some("C"), Some("D")]);
        let mask = BooleanArray::from(vec![
            Some(true),
            None,
            Some(true),
            None,
            Some(true),
            Some(true),
        ]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_eq!(merged.len(), mask.len());
        assert_string_rows(
            &merged,
            &[Some("A"), Some("C"), Some("B"), Some("D"), Some("E"), None],
        );
    }

    // Mask ends on a false run, so the tail is filled from the falsy input.
    #[test]
    fn test_merge_false_tail() {
        let truthy = StringArray::from(vec![Some("A"), Some("B"), Some("E"), None]);
        let falsy = StringArray::from(vec![Some("C"), Some("D"), None, Some("F")]);
        let mask = BooleanArray::from(vec![true, false, true, false, true, true, false, false]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_eq!(merged.len(), mask.len());
        assert_string_rows(
            &merged,
            &[
                Some("A"),
                Some("C"),
                Some("B"),
                Some("D"),
                Some("E"),
                None,
                None,
                Some("F"),
            ],
        );
    }

    // Both inputs are scalars.
    #[test]
    fn test_merge_scalars() {
        let truthy = Scalar::new(StringArray::from(vec![Some("A")]));
        let falsy = Scalar::new(StringArray::from(vec![Some("B")]));
        let mask = BooleanArray::from(vec![true, false, false, true]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_eq!(merged.len(), mask.len());
        assert_string_rows(&merged, &[Some("A"), Some("B"), Some("B"), Some("A")]);
    }

    // Scalar truthy input, array falsy input.
    #[test]
    fn test_merge_scalar_and_array() {
        let truthy = Scalar::new(StringArray::from(vec![Some("A")]));
        let falsy = StringArray::from(vec![Some("B"), Some("C")]);
        let mask = BooleanArray::from(vec![true, false, false, true]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_eq!(merged.len(), mask.len());
        assert_string_rows(&merged, &[Some("A"), Some("B"), Some("C"), Some("A")]);
    }

    // Array truthy input, scalar falsy input, mask ending on a false run.
    #[test]
    fn test_merge_array_and_scalar() {
        let truthy = StringArray::from(vec![Some("B"), Some("C")]);
        let falsy = Scalar::new(StringArray::from(vec![Some("A")]));
        let mask = BooleanArray::from(vec![true, false, false, true, false, false]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_eq!(merged.len(), mask.len());
        assert_string_rows(
            &merged,
            &[
                Some("B"),
                Some("A"),
                Some("A"),
                Some("C"),
                Some("A"),
                Some("A"),
            ],
        );
    }

    // An empty mask yields an empty result.
    #[test]
    fn test_merge_empty_mask() {
        let truthy = StringArray::from(vec![Some("A")]);
        let falsy = StringArray::from(vec![Some("B")]);
        let mask = BooleanArray::from(Vec::<bool>::new());

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_eq!(merged.len(), 0);
    }

    /// A `Datum` that reports itself as a scalar regardless of how many
    /// elements the wrapped array holds.
    #[derive(Debug, Copy, Clone)]
    pub struct UnsafeScalar<T: Array>(T);

    impl<T: Array> Datum for UnsafeScalar<T> {
        fn get(&self) -> (&dyn Array, bool) {
            (&self.0, true)
        }
    }

    // A truthy "scalar" holding two elements is rejected.
    #[test]
    fn test_merge_invalid_truthy_scalar() {
        let truthy = UnsafeScalar(StringArray::from(vec![Some("A"), Some("C")]));
        let falsy = StringArray::from(vec![Some("B"), Some("D")]);
        let mask = BooleanArray::from(vec![true, false, true, false]);

        let outcome = merge(&mask, &truthy, &falsy);

        assert!(matches!(outcome, Err(InvalidArgumentError(_))));
    }

    // A falsy "scalar" holding two elements is rejected.
    #[test]
    fn test_merge_invalid_falsy_scalar() {
        let truthy = StringArray::from(vec![Some("A"), Some("C")]);
        let falsy = UnsafeScalar(StringArray::from(vec![Some("B"), Some("D")]));
        let mask = BooleanArray::from(vec![true, false, true, false]);

        let outcome = merge(&mask, &truthy, &falsy);

        assert!(matches!(outcome, Err(InvalidArgumentError(_))));
    }

    // Inputs with different data types are rejected.
    #[test]
    fn test_merge_incompatible_arrays() {
        let truthy = StringArray::from(vec![Some("A"), Some("B")]);
        let falsy = Int32Array::from(vec![1, 2]);
        let mask = BooleanArray::from(vec![true, false, true, false]);

        let outcome = merge(&mask, &truthy, &falsy);

        assert!(matches!(outcome, Err(InvalidArgumentError(_))));
    }

    // Three arrays interleaved, including "no array" entries and null rows.
    #[test]
    fn test_merge_n() {
        let a1 = StringArray::from(vec![Some("A")]);
        let a2 = StringArray::from(vec![Some("B"), None, None]);
        let a3 = StringArray::from(vec![Some("C"), Some("D")]);
        let arrays: [&dyn Array; 3] = [&a1, &a2, &a3];
        let indices = vec![
            ix(u8::MAX),
            ix(1),
            ix(0),
            ix(u8::MAX),
            ix(2),
            ix(2),
            ix(1),
            ix(1),
        ];

        let merged = merge_n(&arrays, &indices).unwrap();

        assert_eq!(merged.len(), indices.len());
        assert_string_rows(
            &merged,
            &[
                None,
                Some("B"),
                Some("A"),
                None,
                Some("C"),
                Some("D"),
                None,
                None,
            ],
        );
    }

    // An index past the end of the value list must panic.
    #[test]
    #[should_panic]
    fn test_merge_n_invalid_indices() {
        let a1 = StringArray::from(vec![Some("A")]);
        let arrays: [&dyn Array; 1] = [&a1];
        let indices = vec![ix(99)];

        let _ = merge_n(&arrays, &indices);
    }

    // No indices yields an empty result.
    #[test]
    fn test_merge_n_empty_indices() {
        let a1 = StringArray::from(vec![Some("A")]);
        let a2 = StringArray::from(vec![Some("B"), None, None]);
        let a3 = StringArray::from(vec![Some("C"), Some("D")]);
        let arrays: [&dyn Array; 3] = [&a1, &a2, &a3];
        let indices = Vec::<CompactMergeIndex>::new();

        let merged = merge_n(&arrays, &indices).unwrap();

        assert_eq!(merged.len(), indices.len());
    }

    // No value arrays at all is an error.
    #[test]
    fn test_merge_n_empty_values() {
        let arrays = Vec::<&dyn Array>::new();
        let indices = Vec::<CompactMergeIndex>::new();

        let outcome = merge_n(&arrays, &indices);

        assert!(matches!(outcome, Err(InvalidArgumentError(_))));
    }

    // Value arrays of different data types are rejected, even with no indices.
    #[test]
    fn test_merge_n_incompatible_arrays() {
        let a1 = StringArray::from(vec![Some("A")]);
        let a2 = Int32Array::from(vec![1, 2, 3]);
        let a3 = UInt64Array::from(vec![42, 314]);
        let arrays: [&dyn Array; 3] = [&a1, &a2, &a3];
        let indices = Vec::<CompactMergeIndex>::new();

        let outcome = merge_n(&arrays, &indices);

        assert!(matches!(outcome, Err(InvalidArgumentError(_))));
    }
}