use itertools::Itertools as _;
use vortex_array::arrays::ChunkedArray;
use vortex_array::{ArrayRef, DeserializeMetadata, EmptyMetadata, IntoArray};
use vortex_dtype::DType;
use vortex_error::{VortexResult, vortex_bail};
use crate::display::{DisplayAs, DisplayFormat};
use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
// Invokes the crate's `vtable!` macro for `Concat`.
// NOTE(review): presumably this is what declares the `ConcatVTable` type implemented below
// (no other definition of it is visible in this file) — confirm against the macro definition.
vtable!(Concat);
/// Expression that evaluates each of its child expressions and concatenates the resulting
/// arrays, in order, into a single chunked array.
///
/// Evaluation and return-dtype computation both require at least one child and require
/// every child to produce the same dtype.
// NOTE(review): both `PartialEq` and `Hash` are derived here, so this `allow` looks
// unnecessary — confirm before removing.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ConcatExpr {
/// Child expressions, in the order their results are concatenated.
values: Vec<ExprRef>,
}
/// Unit struct used as the `Encoding` associated type of the concat expression vtable.
pub struct ConcatExprEncoding;
impl VTable for ConcatVTable {
    type Expr = ConcatExpr;
    type Encoding = ConcatExprEncoding;
    type Metadata = EmptyMetadata;

    /// Returns the expression id, `"concat"`.
    fn id(_encoding: &Self::Encoding) -> ExprId {
        ExprId::new_ref("concat")
    }

    /// Returns an encoding reference built from the `ConcatExprEncoding` unit value.
    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
        let encoding = ConcatExprEncoding.as_ref();
        ExprEncodingRef::new_ref(encoding)
    }

    /// Concat carries no metadata of its own, so this is always `Some(EmptyMetadata)`.
    fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
        Some(EmptyMetadata)
    }

    /// Returns references to the child expressions in concatenation order.
    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
        expr.values.iter().collect::<Vec<_>>()
    }

    /// Builds a new concat expression whose children are exactly `children`.
    ///
    /// The previous expression is not consulted and the child count is not validated here.
    fn with_children(_expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
        let rebuilt = ConcatExpr { values: children };
        Ok(rebuilt)
    }

    /// Reconstructs a concat expression from its children.
    ///
    /// The encoding and metadata arguments are ignored, and the child count is not
    /// validated here.
    fn build(
        _encoding: &Self::Encoding,
        _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
        children: Vec<ExprRef>,
    ) -> VortexResult<Self::Expr> {
        let built = ConcatExpr { values: children };
        Ok(built)
    }

    /// Evaluates every child against `scope` and returns the results as one chunked array,
    /// one chunk per child.
    ///
    /// # Errors
    ///
    /// Fails if there are no children, if any child fails to evaluate (the first failure is
    /// returned and later children are not evaluated), if the evaluated arrays do not all
    /// share the dtype of the first one, or if `ChunkedArray::try_new` rejects the chunks.
    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
        if expr.values.is_empty() {
            vortex_bail!("Concat expression must have at least one child");
        }

        // Evaluate children in order, stopping at the first error.
        let mut chunks = Vec::with_capacity(expr.values.len());
        for child in &expr.values {
            chunks.push(child.unchecked_evaluate(scope)?);
        }

        // The first chunk defines the dtype that every other chunk has to match.
        let dtype = chunks[0].dtype().clone();
        if let Some(mismatch) = chunks.iter().skip(1).find(|chunk| chunk.dtype() != &dtype) {
            vortex_bail!(
                "All arrays in concat must have the same dtype, expected {:?} but got {:?}",
                dtype,
                mismatch.dtype()
            );
        }

        let chunked = ChunkedArray::try_new(chunks, dtype)?;
        Ok(chunked.into_array())
    }

    /// Computes the dtype produced by the concat expression for the given scope dtype.
    ///
    /// The result is the return dtype of the first child; every other child must return
    /// that same dtype.
    ///
    /// # Errors
    ///
    /// Fails if there are no children, if any child's `return_dtype` fails, or if a child's
    /// dtype differs from the first child's.
    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
        let Some((first, rest)) = expr.values.split_first() else {
            vortex_bail!("Concat expression must have at least one child");
        };

        let dtype = first.return_dtype(scope)?;
        for other in rest {
            let child_dtype = other.return_dtype(scope)?;
            if child_dtype != dtype {
                vortex_bail!(
                    "All expressions in concat must return the same dtype, expected {:?} but got {:?}",
                    dtype,
                    child_dtype
                );
            }
        }

        Ok(dtype)
    }
}
impl ConcatExpr {
pub fn new(values: Vec<ExprRef>) -> Self {
ConcatExpr { values }
}
pub fn new_expr(values: Vec<ExprRef>) -> ExprRef {
Self::new(values).into_expr()
}
pub fn values(&self) -> &[ExprRef] {
&self.values
}
}
/// Builds a concat expression from any iterable whose items convert into `ExprRef`.
///
/// Items are converted and kept in iteration order. The element count is not checked here.
pub fn concat(elements: impl IntoIterator<Item = impl Into<ExprRef>>) -> ExprRef {
    let values: Vec<ExprRef> = elements.into_iter().map(Into::into).collect();
    ConcatExpr::new_expr(values)
}
impl DisplayAs for ConcatExpr {
    /// Formats the expression for the requested display format.
    ///
    /// * `Tree` writes just the node label `Concat`.
    /// * `Compact` writes `concat(...)` with the children's `Display` output joined by `", "`.
    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match df {
            DisplayFormat::Tree => f.write_str("Concat"),
            DisplayFormat::Compact => {
                let children = self.values.iter().format(", ");
                write!(f, "concat({})", children)
            }
        }
    }
}
// `ConcatExpr` implements `AnalysisExpr` without overriding any of the trait's items.
impl AnalysisExpr for ConcatExpr {}
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ChunkedVTable, StructArray};
    use vortex_array::{ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::{ConcatExpr, Scope, col, concat, lit, root};

    /// Struct array fixture with two columns: `a = [1, 2, 3]` and `b = [4, 5, 6]`.
    fn test_array() -> ArrayRef {
        StructArray::from_fields(&[
            ("a", buffer![1, 2, 3].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array()
    }

    /// Three literal children against a one-row scope yield three single-element chunks.
    #[test]
    pub fn test_concat_literals() {
        let children = vec![
            lit(Scalar::from(1i32)),
            lit(Scalar::from(2i32)),
            lit(Scalar::from(3i32)),
        ];
        let expr = ConcatExpr::new(children);

        let scope = Scope::new(buffer![0i32].into_array());
        let result = expr.evaluate(&scope).unwrap();

        let chunks = result.as_::<ChunkedVTable>();
        assert_eq!(chunks.nchunks(), 3);
        assert_eq!(chunks.len(), 3);

        let flat = chunks.to_canonical().into_array().to_primitive();
        assert_eq!(flat.as_slice::<i32>(), &[1, 2, 3]);
    }

    /// Column children are concatenated in order, and a column may appear more than once.
    #[test]
    pub fn test_concat_columns() {
        let expr = ConcatExpr::new(vec![col("a"), col("b"), col("a")]);

        let scope = Scope::new(test_array());
        let result = expr.evaluate(&scope).unwrap();

        let chunks = result.as_::<ChunkedVTable>();
        assert_eq!(chunks.nchunks(), 3);
        assert_eq!(chunks.len(), 9);

        let flat = chunks.to_canonical().into_array().to_primitive();
        assert_eq!(flat.as_slice::<i32>(), &[1, 2, 3, 4, 5, 6, 1, 2, 3]);
    }

    /// A literal between two columns contributes one value per scope row.
    #[test]
    pub fn test_concat_mixed() {
        let children = vec![col("a"), lit(Scalar::from(99i32)), col("b")];
        let expr = ConcatExpr::new(children);

        let scope = Scope::new(test_array());
        let result = expr.evaluate(&scope).unwrap();

        let chunks = result.as_::<ChunkedVTable>();
        assert_eq!(chunks.nchunks(), 3);
        assert_eq!(chunks.len(), 9);

        let flat = chunks.to_canonical().into_array().to_primitive();
        assert_eq!(flat.as_slice::<i32>(), &[1, 2, 3, 99, 99, 99, 4, 5, 6]);
    }

    /// Children of different dtypes (i32 vs i64) make evaluation fail.
    #[test]
    pub fn test_concat_dtype_mismatch() {
        let children = vec![lit(Scalar::from(1i32)), lit(Scalar::from(2i64))];
        let expr = ConcatExpr::new(children);

        let scope = Scope::new(test_array());
        assert!(expr.evaluate(&scope).is_err());
    }

    /// The return dtype of two i32 literals is non-nullable i32.
    #[test]
    pub fn test_return_dtype() {
        let children = vec![lit(Scalar::from(1i32)), lit(Scalar::from(2i32))];
        let expr = ConcatExpr::new(children);

        let scope_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let dtype = expr.return_dtype(&scope_dtype).unwrap();

        assert_eq!(
            dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    /// The compact string form lists the children separated by `", "`.
    #[test]
    pub fn test_display() {
        let expr = concat([col("a"), col("b"), col("c")]);
        assert_eq!(expr.to_string(), "concat($.a, $.b, $.c)");
    }

    /// Concatenating the root with itself repeats the scope array twice.
    #[test]
    pub fn test_concat_with_root() {
        let expr = concat([root(), root()]);

        let input = buffer![1, 2, 3].into_array();
        let result = expr.evaluate(&Scope::new(input)).unwrap();

        let chunks = result.as_::<ChunkedVTable>();
        assert_eq!(chunks.nchunks(), 2);
        assert_eq!(chunks.len(), 6);

        let flat = chunks.to_canonical().into_array().to_primitive();
        assert_eq!(flat.as_slice::<i32>(), &[1, 2, 3, 1, 2, 3]);
    }
}