use arrow::{
array::{Array, ArrayRef, Date32Array, Date64Array, NullArray},
compute::{CastOptions, kernels, max, min},
datatypes::{DataType, Field},
util::pretty::pretty_format_columns,
};
use datafusion_common::internal_datafusion_err;
use datafusion_common::{
Result, ScalarValue,
format::DEFAULT_CAST_OPTIONS,
internal_err,
scalar::{date_to_timestamp_multiplier, ensure_timestamp_in_bounds},
};
use std::fmt;
use std::sync::Arc;
/// Either a whole Arrow array or a single scalar value.
///
/// A `Scalar` stands in for a column whose rows all share one value; the
/// `to_array` / `into_array` family of methods expands it to a requested
/// number of rows when an actual array is needed.
#[derive(Clone, Debug)]
pub enum ColumnarValue {
/// An Arrow array.
Array(ArrayRef),
/// A single value that has not been expanded into an array.
Scalar(ScalarValue),
}
impl From<ArrayRef> for ColumnarValue {
fn from(value: ArrayRef) -> Self {
ColumnarValue::Array(value)
}
}
impl From<ScalarValue> for ColumnarValue {
fn from(value: ScalarValue) -> Self {
ColumnarValue::Scalar(value)
}
}
impl ColumnarValue {
pub fn data_type(&self) -> DataType {
match self {
ColumnarValue::Array(array_value) => array_value.data_type().clone(),
ColumnarValue::Scalar(scalar_value) => scalar_value.data_type(),
}
}
pub fn into_array(self, num_rows: usize) -> Result<ArrayRef> {
Ok(match self {
ColumnarValue::Array(array) => array,
ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows)?,
})
}
pub fn into_array_of_size(self, num_rows: usize) -> Result<ArrayRef> {
match self {
ColumnarValue::Array(array) => {
if array.len() == num_rows {
Ok(array)
} else {
internal_err!(
"Array length {} does not match expected length {}",
array.len(),
num_rows
)
}
}
ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows),
}
}
pub fn to_array(&self, num_rows: usize) -> Result<ArrayRef> {
Ok(match self {
ColumnarValue::Array(array) => Arc::clone(array),
ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows)?,
})
}
pub fn to_array_of_size(&self, num_rows: usize) -> Result<ArrayRef> {
match self {
ColumnarValue::Array(array) => {
if array.len() == num_rows {
Ok(Arc::clone(array))
} else {
internal_err!(
"Array length {} does not match expected length {}",
array.len(),
num_rows
)
}
}
ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows),
}
}
pub fn create_null_array(num_rows: usize) -> Self {
ColumnarValue::Array(Arc::new(NullArray::new(num_rows)))
}
pub fn values_to_arrays(args: &[ColumnarValue]) -> Result<Vec<ArrayRef>> {
if args.is_empty() {
return Ok(vec![]);
}
let mut array_len = None;
for arg in args {
array_len = match (arg, array_len) {
(ColumnarValue::Array(a), None) => Some(a.len()),
(ColumnarValue::Array(a), Some(array_len)) => {
if array_len == a.len() {
Some(array_len)
} else {
return internal_err!(
"Arguments has mixed length. Expected length: {array_len}, found length: {}",
a.len()
);
}
}
(ColumnarValue::Scalar(_), array_len) => array_len,
}
}
let inferred_length = array_len.unwrap_or(1);
let args = args
.iter()
.map(|arg| arg.to_array(inferred_length))
.collect::<Result<Vec<_>>>()?;
Ok(args)
}
pub fn cast_to(
&self,
cast_type: &DataType,
cast_options: Option<&CastOptions<'static>>,
) -> Result<ColumnarValue> {
let cast_options = cast_options.cloned().unwrap_or(DEFAULT_CAST_OPTIONS);
match self {
ColumnarValue::Array(array) => {
let casted = cast_array_by_name(array, cast_type, &cast_options)?;
Ok(ColumnarValue::Array(casted))
}
ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(
scalar.cast_to_with_options(cast_type, &cast_options)?,
)),
}
}
}
/// Casts `array` to `cast_type`.
///
/// * If the array already has the target type, a new reference to it is
///   returned without casting.
/// * Struct targets are delegated to
///   `datafusion_common::nested_struct::cast_column`.
/// * Every other target goes through Arrow's `cast_with_options`, after the
///   date-to-timestamp bounds check.
fn cast_array_by_name(
    array: &ArrayRef,
    cast_type: &DataType,
    cast_options: &CastOptions<'static>,
) -> Result<ArrayRef> {
    if array.data_type() == cast_type {
        return Ok(Arc::clone(array));
    }

    if let DataType::Struct(_) = cast_type {
        // `cast_column` takes a target `Field`, so wrap the struct type in a
        // nullable placeholder field named "_".
        let target_field = Field::new("_", cast_type.clone(), true);
        return datafusion_common::nested_struct::cast_column(
            array,
            &target_field,
            cast_options,
        );
    }

    ensure_date_array_timestamp_bounds(array, cast_type)?;
    let casted = kernels::cast::cast_with_options(array, cast_type, cast_options)?;
    Ok(casted)
}
/// Checks the extreme values of a `Date32`/`Date64` array before it is cast
/// to `cast_type`.
///
/// The check is skipped (returning `Ok(())`) when
/// `date_to_timestamp_multiplier` yields no multiplier for the type pair,
/// when the multiplier is at most 1, or when the source is not a date array.
///
/// Only the minimum and maximum of the array are passed to
/// `ensure_timestamp_in_bounds`, rather than every element.
///
/// # Errors
/// Returns an internal error if the array's concrete type does not match its
/// declared date type, or the error from `ensure_timestamp_in_bounds`.
fn ensure_date_array_timestamp_bounds(
    array: &ArrayRef,
    cast_type: &DataType,
) -> Result<()> {
    let source_type = array.data_type().clone();

    let multiplier = match date_to_timestamp_multiplier(&source_type, cast_type) {
        None => return Ok(()),
        // NOTE(review): presumably a multiplier <= 1 cannot push a value out
        // of range, hence no check is needed — confirm.
        Some(factor) if factor <= 1 => return Ok(()),
        Some(factor) => factor,
    };

    // Widen both extremes to i64 so the two date types share one check path.
    let (lowest, highest): (Option<i64>, Option<i64>) = match &source_type {
        DataType::Date32 => {
            let Some(dates) = array.as_any().downcast_ref::<Date32Array>() else {
                return Err(internal_datafusion_err!(
                    "Expected Date32Array but found {}",
                    array.data_type()
                ));
            };
            (min(dates).map(i64::from), max(dates).map(i64::from))
        }
        DataType::Date64 => {
            let Some(dates) = array.as_any().downcast_ref::<Date64Array>() else {
                return Err(internal_datafusion_err!(
                    "Expected Date64Array but found {}",
                    array.data_type()
                ));
            };
            (min(dates), max(dates))
        }
        _ => return Ok(()),
    };

    // Minimum first, then maximum; an absent value (e.g. no non-null rows)
    // is simply skipped.
    for bound in lowest.into_iter().chain(highest) {
        ensure_timestamp_in_bounds(bound, multiplier, &source_type, cast_type)?;
    }
    Ok(())
}
/// Renders the value as a one-column text table via `pretty_format_columns`.
///
/// The column header names the variant. A scalar is first expanded to a
/// single-row array. If expansion or table formatting fails, a fixed error
/// message is written instead of returning a formatting error.
impl fmt::Display for ColumnarValue {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Decide the header and the column to print before formatting.
        let (title, column) = match self {
            Self::Array(array) => ("ColumnarValue(ArrayRef)", Arc::clone(array)),
            Self::Scalar(_) => match self.to_array(1) {
                Ok(array) => ("ColumnarValue(ScalarValue)", array),
                Err(_) => return write!(f, "Error formatting columnar value"),
            },
        };

        match pretty_format_columns(title, &[column]) {
            Ok(table) => write!(f, "{table}"),
            Err(_) => write!(f, "Error formatting columnar value"),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{
        array::{Date64Array, Int32Array, StructArray},
        datatypes::{Field, Fields, TimeUnit},
    };

    /// `into_array_of_size` returns matching arrays unchanged, expands
    /// scalars, and rejects arrays of the wrong length.
    #[test]
    fn into_array_of_size() {
        // Array case: length matches, so the same array comes back.
        let arr = make_array(1, 3);
        let arr_columnar_value = ColumnarValue::Array(Arc::clone(&arr));
        assert_eq!(&arr_columnar_value.into_array_of_size(3).unwrap(), &arr);

        // Scalar case: the value is repeated for every requested row.
        let scalar_columnar_value = ColumnarValue::Scalar(ScalarValue::Int32(Some(42)));
        let expected_array = make_array(42, 100);
        assert_eq!(
            &scalar_columnar_value.into_array_of_size(100).unwrap(),
            &expected_array
        );

        // Array case: length mismatch is an internal error.
        let arr = make_array(1, 3);
        let arr_columnar_value = ColumnarValue::Array(Arc::clone(&arr));
        let result = arr_columnar_value.into_array_of_size(5);
        let err = result.unwrap_err();
        assert!(
            err.to_string().starts_with(
                "Internal error: Array length 3 does not match expected length 5"
            ),
            "Found: {err}"
        );
    }

    /// `values_to_arrays` expands scalars to the length of the array
    /// arguments, wherever the scalars appear in the argument list.
    #[test]
    fn values_to_arrays() {
        let cases = vec![
            // No arguments.
            TestCase {
                input: vec![],
                expected: vec![],
            },
            // A single array.
            TestCase {
                input: vec![ColumnarValue::Array(make_array(1, 3))],
                expected: vec![make_array(1, 3)],
            },
            // Two arrays of equal length.
            TestCase {
                input: vec![
                    ColumnarValue::Array(make_array(1, 3)),
                    ColumnarValue::Array(make_array(2, 3)),
                ],
                expected: vec![make_array(1, 3), make_array(2, 3)],
            },
            // Array followed by scalar: the scalar is expanded.
            TestCase {
                input: vec![
                    ColumnarValue::Array(make_array(1, 3)),
                    ColumnarValue::Scalar(ScalarValue::Int32(Some(100))),
                ],
                expected: vec![make_array(1, 3), make_array(100, 3)],
            },
            // Scalar followed by array: the scalar is expanded.
            TestCase {
                input: vec![
                    ColumnarValue::Scalar(ScalarValue::Int32(Some(100))),
                    ColumnarValue::Array(make_array(1, 3)),
                ],
                expected: vec![make_array(100, 3), make_array(1, 3)],
            },
            // Scalars on both sides of an array: both are expanded.
            TestCase {
                input: vec![
                    ColumnarValue::Scalar(ScalarValue::Int32(Some(100))),
                    ColumnarValue::Array(make_array(1, 3)),
                    ColumnarValue::Scalar(ScalarValue::Int32(Some(200))),
                ],
                expected: vec![
                    make_array(100, 3),
                    make_array(1, 3),
                    make_array(200, 3),
                ],
            },
        ];
        for case in cases {
            case.run();
        }
    }

    /// Arrays of different lengths are rejected.
    #[test]
    #[should_panic(
        expected = "Arguments has mixed length. Expected length: 3, found length: 4"
    )]
    fn values_to_arrays_mixed_length() {
        ColumnarValue::values_to_arrays(&[
            ColumnarValue::Array(make_array(1, 3)),
            ColumnarValue::Array(make_array(2, 4)),
        ])
        .unwrap();
    }

    /// A scalar between two mismatched arrays does not hide the mismatch.
    #[test]
    #[should_panic(
        expected = "Arguments has mixed length. Expected length: 3, found length: 7"
    )]
    fn values_to_arrays_mixed_length_and_scalar() {
        ColumnarValue::values_to_arrays(&[
            ColumnarValue::Array(make_array(1, 3)),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(100))),
            ColumnarValue::Array(make_array(2, 7)),
        ])
        .unwrap();
    }

    /// One `values_to_arrays` scenario: the arguments and the arrays they
    /// must convert to.
    struct TestCase {
        input: Vec<ColumnarValue>,
        expected: Vec<ArrayRef>,
    }

    impl TestCase {
        /// Runs `values_to_arrays` on `input` and compares with `expected`.
        fn run(self) {
            let Self { input, expected } = self;
            assert_eq!(
                ColumnarValue::values_to_arrays(&input).unwrap(),
                expected,
                "\ninput: {input:?}\nexpected: {expected:?}"
            );
        }
    }

    /// Builds an `Int32` array holding `len` copies of `val`.
    fn make_array(val: i32, len: usize) -> ArrayRef {
        Arc::new(Int32Array::from(vec![val; len]))
    }

    /// A scalar is displayed as a one-row table.
    #[test]
    fn test_display_scalar() {
        let column = ColumnarValue::from(ScalarValue::from("foo"));
        // Each data cell is padded to the width of the header cell, so every
        // line of the table has the same length.
        assert_eq!(
            column.to_string(),
            concat!(
                "+----------------------------+\n",
                "| ColumnarValue(ScalarValue) |\n",
                "+----------------------------+\n",
                "| foo                        |\n",
                "+----------------------------+"
            )
        );
    }

    /// An array is displayed as a table with one row per element.
    #[test]
    fn test_display_array() {
        let array: ArrayRef = Arc::new(Int32Array::from_iter_values(vec![1, 2, 3]));
        let column = ColumnarValue::from(array);
        // Each data cell is padded to the width of the header cell, so every
        // line of the table has the same length.
        assert_eq!(
            column.to_string(),
            concat!(
                "+-------------------------+\n",
                "| ColumnarValue(ArrayRef) |\n",
                "+-------------------------+\n",
                "| 1                       |\n",
                "| 2                       |\n",
                "| 3                       |\n",
                "+-------------------------+"
            )
        );
    }

    /// Struct casts match children by field name, not by position.
    #[test]
    fn cast_struct_by_field_name() {
        // Source order is (b, a); target order is (a, b).
        let source_fields = Fields::from(vec![
            Field::new("b", DataType::Int32, true),
            Field::new("a", DataType::Int32, true),
        ]);
        let target_fields = Fields::from(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ]);
        let struct_array = StructArray::new(
            source_fields,
            vec![
                Arc::new(Int32Array::from(vec![Some(3)])),
                Arc::new(Int32Array::from(vec![Some(4)])),
            ],
            None,
        );
        let value = ColumnarValue::Array(Arc::new(struct_array));
        let casted = value
            .cast_to(&DataType::Struct(target_fields.clone()), None)
            .expect("struct cast should succeed");
        let ColumnarValue::Array(arr) = casted else {
            panic!("expected array after cast");
        };
        let struct_array = arr
            .as_any()
            .downcast_ref::<StructArray>()
            .expect("expected StructArray");
        let field_a = struct_array
            .column_by_name("a")
            .expect("expected field a in cast result");
        let field_b = struct_array
            .column_by_name("b")
            .expect("expected field b in cast result");
        // Source "a" held 4 and source "b" held 3; the values follow the names.
        assert_eq!(
            field_a
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("expected Int32 array")
                .value(0),
            4
        );
        assert_eq!(
            field_b
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("expected Int32 array")
                .value(0),
            3
        );
    }

    /// A target field that is absent from the source is filled with nulls.
    #[test]
    fn cast_struct_missing_field_inserts_nulls() {
        let source_fields = Fields::from(vec![Field::new("a", DataType::Int32, true)]);
        let target_fields = Fields::from(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ]);
        let struct_array = StructArray::new(
            source_fields,
            vec![Arc::new(Int32Array::from(vec![Some(5)]))],
            None,
        );
        let value = ColumnarValue::Array(Arc::new(struct_array));
        let casted = value
            .cast_to(&DataType::Struct(target_fields.clone()), None)
            .expect("struct cast should succeed");
        let ColumnarValue::Array(arr) = casted else {
            panic!("expected array after cast");
        };
        let struct_array = arr
            .as_any()
            .downcast_ref::<StructArray>()
            .expect("expected StructArray");
        let field_b = struct_array
            .column_by_name("b")
            .expect("expected missing field to be added");
        assert!(field_b.is_null(0));
    }

    /// A `Date64` value too large to express in nanoseconds is reported as
    /// an error by the array cast.
    #[test]
    fn cast_date64_array_to_timestamp_overflow() {
        // One past the largest millisecond count that still fits in an i64
        // after multiplying by 1_000_000 (milliseconds -> nanoseconds).
        let overflow_value = i64::MAX / 1_000_000 + 1;
        let array: ArrayRef = Arc::new(Date64Array::from(vec![Some(overflow_value)]));
        let value = ColumnarValue::Array(array);
        let result =
            value.cast_to(&DataType::Timestamp(TimeUnit::Nanosecond, None), None);
        let err = result.expect_err("expected overflow to be detected");
        assert!(
            err.to_string()
                .contains("converted value exceeds the representable i64 range"),
            "unexpected error: {err}"
        );
    }
}