use crate::array::{Array, ArrayData, ArrayRef, DictionaryArray, PrimitiveArray};
use crate::buffer::Buffer;
use crate::datatypes::{
ArrowNumericType, ArrowPrimitiveType, DataType, Int16Type, Int32Type, Int64Type,
Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use crate::error::{ArrowError, Result};
use std::sync::Arc;
/// Assembles the [`ArrayData`] for a primitive array of type `O` that has the
/// same length and validity as `array`, using `buffer` as its only value
/// buffer.
///
/// The validity bitmap (if any) is re-sliced starting at `array.offset()`, so
/// the resulting data always has an offset of `0`. The null count is passed as
/// `None` (presumably so that `ArrayData` derives it from the bitmap — confirm
/// against `ArrayData::new_unchecked`).
#[inline]
fn into_primitive_array_data<I: ArrowPrimitiveType, O: ArrowPrimitiveType>(
    array: &PrimitiveArray<I>,
    buffer: Buffer,
) -> ArrayData {
    let len = array.len();
    let validity = array
        .data_ref()
        .null_buffer()
        .map(|bits| bits.bit_slice(array.offset(), len));

    // SAFETY: nothing is validated here. This relies on the caller handing in
    // a `buffer` that holds `array.len()` values of `O::Native`; the validity
    // bitmap above is sliced to exactly that length.
    unsafe {
        ArrayData::new_unchecked(
            O::DATA_TYPE,
            len,
            None,
            validity,
            0,
            vec![buffer],
            vec![],
        )
    }
}
/// Applies the infallible function `op` to every value of `array` and returns
/// the results as a new primitive array of type `O`.
///
/// The validity of the input is carried over to the output unchanged. Note
/// that `op` is invoked for every slot returned by `array.values()`, without
/// consulting the validity bitmap, so it is also called on whatever value is
/// stored under a null slot.
pub fn unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> PrimitiveArray<O>
where
    I: ArrowPrimitiveType,
    O: ArrowPrimitiveType,
    F: Fn(I::Native) -> O::Native,
{
    let mapped = array.values().iter().copied().map(op);
    // SAFETY: `mapped` is a `Map` over a `Copied` slice iterator; both adaptors
    // forward the slice iterator's exact `size_hint`, so the reported upper
    // bound is the true length.
    let out_buffer = unsafe { Buffer::from_trusted_len_iter(mapped) };
    PrimitiveArray::<O>::from(into_primitive_array_data::<_, O>(array, out_buffer))
}
#[allow(clippy::redundant_closure)]
fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
where
K: ArrowNumericType,
T: ArrowPrimitiveType,
F: Fn(T::Native) -> T::Native,
{
let dict_values = array
.values()
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
.unwrap();
let values = dict_values
.iter()
.map(|v| v.map(|value| op(value)))
.collect::<PrimitiveArray<T>>();
let keys = array.keys();
let mut data = ArrayData::builder(array.data_type().clone())
.len(keys.len())
.add_buffer(keys.data().buffers()[0].clone())
.add_child_data(values.data().clone());
match keys.data().null_buffer() {
Some(buffer) if keys.data().null_count() > 0 => {
data = data
.null_bit_buffer(Some(buffer.clone()))
.null_count(keys.data().null_count());
}
_ => data = data.null_count(0),
}
let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into();
Ok(Arc::new(new_dict))
}
/// Applies `op` to a dynamically typed array whose values are of primitive
/// type `T`.
///
/// * For a dictionary array with an integer key type, `op` is applied to the
///   dictionary values via `unary_dict` and a dictionary array with the same
///   key type is returned.
/// * For any other array, the input is expected to be a `PrimitiveArray<T>`
///   and the result is a `PrimitiveArray<T>` produced by `unary`.
///
/// # Errors
///
/// Returns [`ArrowError::NotYetImplemented`] when the dictionary key type is
/// not one of the eight integer types, or when a non-dictionary `array` is not
/// a `PrimitiveArray<T>`.
pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native) -> T::Native,
{
    match array.data_type() {
        // The `unwrap`s in this arm are invariants rather than input checks:
        // each downcast targets the key type that was just matched on.
        DataType::Dictionary(key_type, _) => match key_type.as_ref() {
            DataType::Int8 => unary_dict::<_, F, T>(
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<Int8Type>>()
                    .unwrap(),
                op,
            ),
            DataType::Int16 => unary_dict::<_, F, T>(
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<Int16Type>>()
                    .unwrap(),
                op,
            ),
            DataType::Int32 => unary_dict::<_, F, T>(
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<Int32Type>>()
                    .unwrap(),
                op,
            ),
            DataType::Int64 => unary_dict::<_, F, T>(
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<Int64Type>>()
                    .unwrap(),
                op,
            ),
            DataType::UInt8 => unary_dict::<_, F, T>(
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<UInt8Type>>()
                    .unwrap(),
                op,
            ),
            DataType::UInt16 => unary_dict::<_, F, T>(
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<UInt16Type>>()
                    .unwrap(),
                op,
            ),
            DataType::UInt32 => unary_dict::<_, F, T>(
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<UInt32Type>>()
                    .unwrap(),
                op,
            ),
            DataType::UInt64 => unary_dict::<_, F, T>(
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<UInt64Type>>()
                    .unwrap(),
                op,
            ),
            t => Err(ArrowError::NotYetImplemented(format!(
                "Cannot perform unary operation on dictionary array of key type {}.",
                t
            ))),
        },
        _ => {
            // The caller chooses `T` independently of the runtime type of
            // `array`, so a mismatch is an input error, not a broken invariant.
            let primitive = array
                .as_any()
                .downcast_ref::<PrimitiveArray<T>>()
                .ok_or_else(|| {
                    ArrowError::NotYetImplemented(format!(
                        "Cannot perform unary operation of type {} on array of type {}",
                        T::DATA_TYPE,
                        array.data_type()
                    ))
                })?;
            Ok(Arc::new(unary::<T, F, T>(primitive, op)))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::{
        as_primitive_array, Float64Array, PrimitiveBuilder, PrimitiveDictionaryBuilder,
    };
    use crate::datatypes::{Float64Type, Int32Type, Int8Type};

    /// Builds an `Int8`-keyed dictionary of `Int32` values from `items`,
    /// appending a null for every `None`.
    fn int32_dictionary(items: &[Option<i32>]) -> DictionaryArray<Int8Type> {
        let mut builder = PrimitiveDictionaryBuilder::new(
            PrimitiveBuilder::<Int8Type>::new(3),
            PrimitiveBuilder::<Int32Type>::new(2),
        );
        for item in items {
            match item {
                Some(value) => {
                    builder.append(*value).unwrap();
                }
                None => {
                    builder.append_null();
                }
            }
        }
        builder.finish()
    }

    /// Downcasts `actual` to an `Int8`-keyed dictionary and compares it with
    /// `expected`.
    fn assert_dictionary_eq(actual: &ArrayRef, expected: &DictionaryArray<Int8Type>) {
        let actual = actual
            .as_any()
            .downcast_ref::<DictionaryArray<Int8Type>>()
            .unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_unary_f64_slice() {
        let source =
            Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]);
        // Drop the first element so the kernels see a non-zero offset.
        let sliced = source.slice(1, 4);
        let sliced: &Float64Array = as_primitive_array(&sliced);

        let rounded: Float64Array = unary(sliced, |n| n.round());
        let expected_rounded = Float64Array::from(vec![None, Some(7.0), None, Some(7.0)]);
        assert_eq!(rounded, expected_rounded);

        let incremented = unary_dyn::<_, Float64Type>(sliced, |n| n + 1.0).unwrap();
        let expected_incremented =
            Float64Array::from(vec![None, Some(7.8), None, Some(8.2)]);
        assert_eq!(
            incremented.as_any().downcast_ref::<Float64Array>().unwrap(),
            &expected_incremented
        );
    }

    #[test]
    fn test_unary_dict_and_unary_dyn() {
        let dictionary_array =
            int32_dictionary(&[Some(5), Some(6), Some(7), Some(8), None, Some(9)]);
        let expected =
            int32_dictionary(&[Some(6), Some(7), Some(8), Some(9), None, Some(10)]);

        let via_dict =
            unary_dict::<_, _, Int32Type>(&dictionary_array, |n| n + 1).unwrap();
        assert_dictionary_eq(&via_dict, &expected);

        let via_dyn = unary_dyn::<_, Int32Type>(&dictionary_array, |n| n + 1).unwrap();
        assert_dictionary_eq(&via_dyn, &expected);
    }
}