use crate::array::{
Array, ArrayAccessor, ArrayData, ArrayIter, ArrayRef, BufferBuilder, DictionaryArray,
PrimitiveArray,
};
use crate::buffer::Buffer;
use crate::compute::util::combine_option_bitmap;
use crate::datatypes::{ArrowNumericType, ArrowPrimitiveType};
use crate::downcast_dictionary_array;
use crate::error::{ArrowError, Result};
use crate::util::bit_iterator::try_for_each_valid_idx;
use arrow_buffer::MutableBuffer;
use std::sync::Arc;
/// Assembles a [`PrimitiveArray`] of type `O` from an already populated value
/// buffer and an optional validity buffer.
///
/// The resulting array has `len` elements, the given `null_count`, an offset
/// of zero, `buffer` as its single data buffer and no child data.
///
/// # Safety
///
/// All arguments are handed to [`ArrayData::new_unchecked`] without any
/// validation. The caller is responsible for making sure that `buffer`,
/// `null_buffer`, `len` and `null_count` are consistent with each other and
/// with `O::DATA_TYPE`.
#[inline]
unsafe fn build_primitive_array<O: ArrowPrimitiveType>(
    len: usize,
    buffer: Buffer,
    null_count: usize,
    null_buffer: Option<Buffer>,
) -> PrimitiveArray<O> {
    let buffers = vec![buffer];
    let child_data = vec![];
    let data = ArrayData::new_unchecked(
        O::DATA_TYPE,
        len,
        Some(null_count),
        null_buffer,
        0,
        buffers,
        child_data,
    );
    PrimitiveArray::from(data)
}
pub fn unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> PrimitiveArray<O>
where
I: ArrowPrimitiveType,
O: ArrowPrimitiveType,
F: Fn(I::Native) -> O::Native,
{
array.unary(op)
}
pub fn try_unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> Result<PrimitiveArray<O>>
where
I: ArrowPrimitiveType,
O: ArrowPrimitiveType,
F: Fn(I::Native) -> Result<O::Native>,
{
array.try_unary(op)
}
fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
where
K: ArrowNumericType,
T: ArrowPrimitiveType,
F: Fn(T::Native) -> T::Native,
{
let dict_values = array.values().as_any().downcast_ref().unwrap();
let values = unary::<T, F, T>(dict_values, op).into_data();
let data = array.data().clone().into_builder().child_data(vec![values]);
let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into();
Ok(Arc::new(new_dict))
}
fn try_unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
where
K: ArrowNumericType,
T: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<T::Native>,
{
if array.value_type() != T::DATA_TYPE {
return Err(ArrowError::CastError(format!(
"Cannot perform the unary operation on dictionary array of value type {}",
array.value_type()
)));
}
let dict_values = array.values().as_any().downcast_ref().unwrap();
let values = try_unary::<T, F, T>(dict_values, op)?.into_data();
let data = array.data().clone().into_builder().child_data(vec![values]);
let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into();
Ok(Arc::new(new_dict))
}
/// Applies the infallible function `op` to a dynamically typed array.
///
/// Dictionary arrays (of any key type handled by `downcast_dictionary_array!`)
/// are forwarded to `unary_dict`. Any other array must have exactly the data
/// type `T::DATA_TYPE`, in which case it is processed as a `PrimitiveArray<T>`.
///
/// # Errors
///
/// Returns [`ArrowError::NotYetImplemented`] for a non-dictionary array whose
/// data type is not `T::DATA_TYPE`, and propagates errors from `unary_dict`.
pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native) -> T::Native,
{
    downcast_dictionary_array! {
        array => unary_dict::<_, F, T>(array, op),
        other => {
            if other != &T::DATA_TYPE {
                Err(ArrowError::NotYetImplemented(format!(
                    "Cannot perform unary operation on array of type {}",
                    other
                )))
            } else {
                // The data type was just checked, so the downcast cannot fail.
                let primitive = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
                let mapped: ArrayRef = Arc::new(unary::<T, F, T>(primitive, op));
                Ok(mapped)
            }
        }
    }
}
/// Applies the fallible function `op` to a dynamically typed array.
///
/// Dictionary arrays whose values have the data type `T::DATA_TYPE` are
/// forwarded to `try_unary_dict`. Any other array must itself have the data
/// type `T::DATA_TYPE`, in which case it is processed as a `PrimitiveArray<T>`.
///
/// # Errors
///
/// Returns [`ArrowError::NotYetImplemented`] when the array (or the values of
/// a dictionary array) is not of type `T::DATA_TYPE`, and propagates any error
/// returned by `op` through `try_unary` / `try_unary_dict`.
pub fn try_unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native) -> Result<T::Native>,
{
    downcast_dictionary_array! {
        array => if array.values().data_type() != &T::DATA_TYPE {
            Err(ArrowError::NotYetImplemented(format!(
                "Cannot perform unary operation on dictionary array of type {}",
                array.data_type()
            )))
        } else {
            try_unary_dict::<_, F, T>(array, op)
        },
        other => {
            if other != &T::DATA_TYPE {
                Err(ArrowError::NotYetImplemented(format!(
                    "Cannot perform unary operation on array of type {}",
                    other
                )))
            } else {
                // The data type was just checked, so the downcast cannot fail.
                let primitive = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
                let mapped: ArrayRef = Arc::new(try_unary::<T, F, T>(primitive, op)?);
                Ok(mapped)
            }
        }
    }
}
/// Combines two primitive arrays element by element with the infallible
/// function `op`, producing a new [`PrimitiveArray`] of type `O`.
///
/// `op` is evaluated for every position, including positions that end up null
/// in the output; the validity of the result comes from
/// `combine_option_bitmap` applied to both inputs.
///
/// # Errors
///
/// Returns [`ArrowError::ComputeError`] when `a` and `b` differ in length.
///
/// # Panics
///
/// Panics if `combine_option_bitmap` returns an error.
pub fn binary<A, B, F, O>(
    a: &PrimitiveArray<A>,
    b: &PrimitiveArray<B>,
    op: F,
) -> Result<PrimitiveArray<O>>
where
    A: ArrowPrimitiveType,
    B: ArrowPrimitiveType,
    O: ArrowPrimitiveType,
    F: Fn(A::Native, B::Native) -> O::Native,
{
    let len = a.len();
    if len != b.len() {
        return Err(ArrowError::ComputeError(
            "Cannot perform binary operation on arrays of different length".to_string(),
        ));
    }
    if a.is_empty() {
        let empty = ArrayData::new_empty(&O::DATA_TYPE);
        return Ok(PrimitiveArray::from(empty));
    }

    let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap();
    let null_count = match &null_buffer {
        Some(validity) => len - validity.count_set_bits(),
        None => 0,
    };

    let lhs = a.values();
    let rhs = b.values();
    let results = lhs.iter().zip(rhs.iter()).map(|(&l, &r)| op(l, r));

    // SAFETY: zipping two slice iterators yields an iterator that reports its
    // exact length, which is what `from_trusted_len_iter` requires.
    let buffer = unsafe { Buffer::from_trusted_len_iter(results) };

    // NOTE(review): relies on `buffer` holding `len` values and on `null_count`
    // matching `null_buffer`, both derived above from the same inputs.
    Ok(unsafe { build_primitive_array(len, buffer, null_count, null_buffer) })
}
/// Combines two arrays element by element with the fallible function `op`,
/// producing a new [`PrimitiveArray`] of type `O`.
///
/// When neither input has nulls the work is delegated to
/// `try_binary_no_nulls`. Otherwise the output buffer is zero-initialised and
/// `op` is only invoked for the indices that `try_for_each_valid_idx` visits
/// for the combined validity buffer; all other slots keep their zero value.
///
/// # Errors
///
/// Returns [`ArrowError::ComputeError`] when `a` and `b` differ in length, and
/// propagates the first error returned by `op`.
///
/// # Panics
///
/// Panics if `combine_option_bitmap` returns an error.
pub fn try_binary<A: ArrayAccessor, B: ArrayAccessor, F, O>(
    a: A,
    b: B,
    op: F,
) -> Result<PrimitiveArray<O>>
where
    O: ArrowPrimitiveType,
    F: Fn(A::Item, B::Item) -> Result<O::Native>,
{
    let len = a.len();
    if len != b.len() {
        return Err(ArrowError::ComputeError(
            "Cannot perform a binary operation on arrays of different length".to_string(),
        ));
    }
    if a.is_empty() {
        let empty = ArrayData::new_empty(&O::DATA_TYPE);
        return Ok(PrimitiveArray::from(empty));
    }

    // Fast path: with no nulls on either side every index is valid.
    if a.null_count() == 0 && b.null_count() == 0 {
        return try_binary_no_nulls(len, a, b, op);
    }

    let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap();
    let null_count = match &null_buffer {
        Some(validity) => len - validity.count_set_bits(),
        None => 0,
    };

    // Pre-fill the output with zeroes so that slots which are never written
    // still hold a defined value.
    let mut builder = BufferBuilder::<O::Native>::new(len);
    builder.append_n_zeroed(len);
    let out = builder.as_slice_mut();

    try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| {
        // NOTE(review): the unchecked accesses assume `idx < len`, i.e. that
        // `try_for_each_valid_idx` only yields indices inside `0..len`.
        let value = unsafe { op(a.value_unchecked(idx), b.value_unchecked(idx)) }?;
        unsafe { *out.get_unchecked_mut(idx) = value };
        Ok::<_, ArrowError>(())
    })?;

    let values = builder.finish();
    Ok(unsafe { build_primitive_array(len, values, null_count, null_buffer) })
}
/// Variant of `try_binary` for inputs without nulls: applies `op` to every
/// index in `0..len` and stores the results in a freshly allocated buffer.
///
/// The caller is expected to pass `len` equal to the length of both `a` and
/// `b`; the unchecked element accesses below depend on it.
///
/// # Errors
///
/// Propagates the first error returned by `op`.
#[inline(never)]
fn try_binary_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
    len: usize,
    a: A,
    b: B,
    op: F,
) -> Result<PrimitiveArray<O>>
where
    O: ArrowPrimitiveType,
    F: Fn(A::Item, B::Item) -> Result<O::Native>,
{
    // Reserve room for all `len` values up front so `push_unchecked` never has
    // to grow the buffer.
    let capacity = len * O::get_byte_width();
    let mut values = MutableBuffer::new(capacity);

    (0..len).try_for_each(|idx| -> Result<()> {
        // NOTE(review): assumes `idx < len <= a.len(), b.len()` as arranged by
        // the caller.
        let item = unsafe { op(a.value_unchecked(idx), b.value_unchecked(idx)) }?;
        // At most `len` items are pushed into a buffer created with room for
        // `len` items.
        unsafe { values.push_unchecked(item) };
        Ok(())
    })?;

    Ok(unsafe { build_primitive_array(len, values.into(), 0, None) })
}
/// Variant of `binary_opt` for inputs without nulls: applies `op` to every
/// index in `0..len` and collects the optional results into a
/// [`PrimitiveArray`], where a `None` from `op` becomes a null slot.
///
/// The caller is expected to pass `len` equal to the length of both `a` and
/// `b`; the unchecked element accesses below depend on it.
///
/// This function currently never returns `Err`; the `Result` return type is
/// kept for its caller.
#[inline(never)]
fn try_binary_opt_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
    len: usize,
    a: A,
    b: B,
    op: F,
) -> Result<PrimitiveArray<O>>
where
    O: ArrowPrimitiveType,
    F: Fn(A::Item, B::Item) -> Option<O::Native>,
{
    // Collect straight from an exact-size iterator instead of staging the
    // results in a `Vec` with an arbitrary fixed capacity: this avoids the
    // intermediate allocation and lets the collector see the real length.
    let results = (0..len).map(|idx| {
        // NOTE(review): assumes `idx < len <= a.len(), b.len()` as arranged by
        // the caller.
        unsafe { op(a.value_unchecked(idx), b.value_unchecked(idx)) }
    });
    Ok(results.collect())
}
/// Combines two arrays element by element with `op`, where `op` may decline to
/// produce a value by returning `None`.
///
/// A position is null in the output when either input is null there or when
/// `op` returns `None` for it. Inputs without any nulls are handled by
/// `try_binary_opt_no_nulls`.
///
/// # Errors
///
/// Returns [`ArrowError::ComputeError`] when `a` and `b` differ in length.
pub(crate) fn binary_opt<A: ArrayAccessor + Array, B: ArrayAccessor + Array, F, O>(
    a: A,
    b: B,
    op: F,
) -> Result<PrimitiveArray<O>>
where
    O: ArrowPrimitiveType,
    F: Fn(A::Item, B::Item) -> Option<O::Native>,
{
    let len = a.len();
    if len != b.len() {
        return Err(ArrowError::ComputeError(
            "Cannot perform binary operation on arrays of different length".to_string(),
        ));
    }
    if a.is_empty() {
        let empty = ArrayData::new_empty(&O::DATA_TYPE);
        return Ok(PrimitiveArray::from(empty));
    }

    // Fast path: no validity checks are needed when both sides are null-free.
    if a.null_count() == 0 && b.null_count() == 0 {
        return try_binary_opt_no_nulls(len, a, b, op);
    }

    let results = ArrayIter::new(a)
        .zip(ArrayIter::new(b))
        .map(|pair| match pair {
            (Some(lhs), Some(rhs)) => op(lhs, rhs),
            // A null on either side yields a null result without calling `op`.
            _ => None,
        });
    Ok(results.collect())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::{as_primitive_array, Float64Array, PrimitiveDictionaryBuilder};
    use crate::datatypes::{Float64Type, Int32Type, Int8Type};

    /// Builds an `Int8`-keyed dictionary array of `Int32` values, appending the
    /// items in order (`None` appends a null).
    fn int32_dictionary(items: &[Option<i32>]) -> DictionaryArray<Int8Type> {
        let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
        for item in items {
            match item {
                Some(value) => {
                    builder.append(*value).unwrap();
                }
                None => {
                    builder.append_null();
                }
            }
        }
        builder.finish()
    }

    /// Views a dynamically typed result as an `Int8`-keyed dictionary array.
    fn as_int8_dictionary(array: &ArrayRef) -> &DictionaryArray<Int8Type> {
        array
            .as_any()
            .downcast_ref::<DictionaryArray<Int8Type>>()
            .unwrap()
    }

    /// `unary` and `unary_dyn` must respect the offset of a sliced input.
    #[test]
    fn test_unary_f64_slice() {
        let input =
            Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]);
        let sliced = input.slice(1, 4);
        let sliced: &Float64Array = as_primitive_array(&sliced);

        let rounded: Float64Array = unary(sliced, |n| n.round());
        let expected_rounded = Float64Array::from(vec![None, Some(7.0), None, Some(7.0)]);
        assert_eq!(rounded, expected_rounded);

        let incremented = unary_dyn::<_, Float64Type>(sliced, |n| n + 1.0).unwrap();
        let expected_incremented =
            Float64Array::from(vec![None, Some(7.8), None, Some(8.2)]);
        assert_eq!(
            incremented.as_any().downcast_ref::<Float64Array>().unwrap(),
            &expected_incremented
        );
    }

    /// `unary_dict` and `unary_dyn` must map dictionary values and keep nulls.
    #[test]
    fn test_unary_dict_and_unary_dyn() {
        let dictionary_array =
            int32_dictionary(&[Some(5), Some(6), Some(7), Some(8), None, Some(9)]);
        let expected =
            int32_dictionary(&[Some(6), Some(7), Some(8), Some(9), None, Some(10)]);

        let via_dict = unary_dict::<_, _, Int32Type>(&dictionary_array, |n| n + 1).unwrap();
        assert_eq!(as_int8_dictionary(&via_dict), &expected);

        let via_dyn = unary_dyn::<_, Int32Type>(&dictionary_array, |n| n + 1).unwrap();
        assert_eq!(as_int8_dictionary(&via_dyn), &expected);
    }
}