use std::fmt::Display;
use std::fmt::Formatter;
use itertools::Itertools;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_utils::aliases::hash_map::HashMap;
use crate::dtype::DType;
use crate::dtype::FieldName;
use crate::dtype::FieldNames;
use crate::dtype::Nullability;
use crate::dtype::StructFields;
use crate::expr::Expression;
use crate::expr::analysis::Annotation;
use crate::expr::analysis::AnnotationFn;
use crate::expr::analysis::Annotations;
use crate::expr::analysis::descendent_annotations;
use crate::expr::get_item;
use crate::expr::pack;
use crate::expr::root;
use crate::expr::traversal::NodeExt;
use crate::expr::traversal::NodeRewriter;
use crate::expr::traversal::Transformed;
use crate::expr::traversal::TraversalOrder;
/// Splits `expr` into one packed sub-expression per annotation plus a root expression that
/// re-assembles the final result from those partitions.
///
/// Every sub-tree that the splitter finds to carry exactly one annotation is pulled out of
/// `expr`. All sub-trees sharing an annotation are packed into a single non-nullable struct
/// expression whose fields are named `<annotation>_<idx>`. The returned root expression
/// reads those fields from a struct scope that has one field per partition, named after the
/// partition's annotation.
///
/// The partitions are produced in the iteration order of the splitter's hash map; use
/// [`PartitionedExpr::find_partition`] or `partition_names` rather than relying on positions.
///
/// # Errors
///
/// Returns an error if rewriting `expr` fails, or if optimizing or computing the return
/// dtype of a partition (against `scope`) or of the root (against the partition scope) fails.
pub fn partition<A: AnnotationFn>(
    expr: Expression,
    scope: &DType,
    annotate_fn: A,
) -> VortexResult<PartitionedExpr<A::Annotation>>
where
    A::Annotation: Display,
    FieldName: From<A::Annotation>,
{
    let annotations = descendent_annotations(&expr, annotate_fn);

    // `annotations` borrows from `expr`, so the rewrite has to operate on a clone.
    let mut splitter = StructFieldExpressionSplitter::<A::Annotation>::new(&annotations);
    let rewritten_root = expr.clone().rewrite(&mut splitter)?.value;

    let partition_count = splitter.sub_expressions.len();
    let mut partitions = Vec::with_capacity(partition_count);
    let mut partition_annotations = Vec::with_capacity(partition_count);
    let mut partition_dtypes = Vec::with_capacity(partition_count);

    for (annotation, sub_exprs) in splitter.sub_expressions {
        // Pack every sub-expression extracted for this annotation into a single struct.
        let fields = sub_exprs.into_iter().enumerate().map(|(idx, sub_expr)| {
            let name = StructFieldExpressionSplitter::field_name(&annotation, idx);
            (name, sub_expr)
        });
        let packed = pack(fields, Nullability::NonNullable).optimize_recursive(scope)?;
        let packed_dtype = packed.return_dtype(scope)?;

        partitions.push(packed);
        partition_annotations.push(annotation);
        partition_dtypes.push(packed_dtype);
    }

    let partition_names: FieldNames = partition_annotations
        .iter()
        .cloned()
        .map(FieldName::from)
        .collect();

    // The root expression is evaluated against a struct holding one field per partition.
    let root_scope = DType::Struct(
        StructFields::new(partition_names.clone(), partition_dtypes.clone()),
        Nullability::NonNullable,
    );
    let optimized_root = rewritten_root.optimize_recursive(&root_scope)?;

    Ok(PartitionedExpr {
        root: optimized_root,
        partitions: partitions.into_boxed_slice(),
        partition_names,
        partition_dtypes: partition_dtypes.into_boxed_slice(),
        partition_annotations: partition_annotations.into_boxed_slice(),
    })
}
/// The result of [`partition`]: a root expression plus the per-annotation partitions it
/// reads from.
///
/// `partitions`, `partition_names`, `partition_dtypes` and `partition_annotations` are
/// filled in lock-step by [`partition`], so entries at the same index describe the same
/// partition.
#[derive(Debug)]
pub struct PartitionedExpr<A> {
/// The rewritten expression; extracted sub-trees are replaced by
/// `get_item("<annotation>_<idx>", get_item(<annotation>, root()))`.
pub root: Expression,
/// One packed struct expression per annotation, with fields named `<annotation>_<idx>`.
pub partitions: Box<[Expression]>,
/// The field name of each partition, obtained via `FieldName::from(annotation)`.
pub partition_names: FieldNames,
/// The return dtype of each partition, computed against the scope passed to [`partition`].
pub partition_dtypes: Box<[DType]>,
/// The annotation that each partition was extracted for.
pub partition_annotations: Box<[A]>,
}
impl<A: Display> Display for PartitionedExpr<A> {
    /// Renders as `root: <root> {<name>: <partition>, ...}`, pairing each partition name
    /// with the partition expression at the same index.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let rendered_partitions = self
            .partition_names
            .iter()
            .zip(self.partitions.iter())
            .map(|(name, partition)| format!("{name}: {partition}"))
            .join(", ");

        write!(f, "root: {} {{{}}}", self.root, rendered_partitions)
    }
}
impl<A: Annotation> PartitionedExpr<A>
where
    FieldName: From<A>,
{
    /// Returns the partition expression whose name equals `FieldName::from(id)`, or `None`
    /// when no partition carries that name.
    ///
    /// The lookup is a linear scan over `partition_names`; the first match wins.
    pub fn find_partition(&self, id: &A) -> Option<&Expression> {
        let wanted = FieldName::from(id.clone());
        let idx = self
            .partition_names
            .iter()
            .position(|field| field == wanted)?;
        Some(&self.partitions[idx])
    }
}
/// Rewriter state used by [`partition`] to pull annotated sub-expressions out of an
/// expression tree.
#[derive(Debug)]
struct StructFieldExpressionSplitter<'a, A: Annotation> {
/// Lookup from an expression node to the annotations recorded for it.
annotations: &'a Annotations<'a, A>,
/// Sub-expressions extracted so far, grouped by annotation. The position of a
/// sub-expression within its vector is the `idx` used in its `<annotation>_<idx>` field name.
sub_expressions: HashMap<A, Vec<Expression>>,
}
impl<'a, A: Annotation + Display> StructFieldExpressionSplitter<'a, A> {
    /// Creates a splitter over `annotations` with no sub-expressions extracted yet.
    fn new(annotations: &'a Annotations<'a, A>) -> Self {
        let sub_expressions = HashMap::new();
        Self {
            annotations,
            sub_expressions,
        }
    }

    /// Builds the field name under which the `idx`-th sub-expression of `annotation` is
    /// stored in its partition: the annotation's `Display` output, an underscore, then `idx`.
    fn field_name(annotation: &A, idx: usize) -> FieldName {
        let name = format!("{annotation}_{idx}");
        name.into()
    }
}
impl<A: Annotation + Display> NodeRewriter for StructFieldExpressionSplitter<'_, A>
where
    FieldName: From<A>,
{
    type NodeTy = Expression;

    /// Extracts `node` into its annotation's partition when it carries exactly one annotation.
    ///
    /// The node is appended to `sub_expressions` for that annotation and replaced by
    /// `get_item("<annotation>_<idx>", get_item(<annotation>, root()))`, flagged as changed
    /// with `TraversalOrder::Skip` (presumably so the extracted sub-tree is not descended
    /// into; confirm against the traversal implementation). Nodes with no recorded
    /// annotations, or with a number of annotations other than one, are returned via
    /// `Transformed::no`.
    fn visit_down(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
        // Guard: only nodes that belong to exactly one annotation can be extracted.
        let annotations = match self.annotations.get(&node) {
            Some(found) if found.len() == 1 => found,
            _ => return Ok(Transformed::no(node)),
        };
        let annotation = annotations
            .iter()
            .next()
            .vortex_expect("expected one field");

        // The index of the node within its annotation's bucket doubles as its field suffix.
        let bucket = self.sub_expressions.entry(annotation.clone()).or_default();
        let idx = bucket.len();
        bucket.push(node.clone());

        let field = Self::field_name(annotation, idx);
        let partition_ref = get_item(FieldName::from(annotation.clone()), root());

        Ok(Transformed {
            value: get_item(field, partition_ref),
            changed: true,
            order: TraversalOrder::Skip,
        })
    }

    /// Leaves every node untouched on the way back up; all rewriting happens in `visit_down`.
    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<Transformed<Self::NodeTy>> {
        Ok(Transformed::no(node))
    }
}
// Unit tests for `partition`, using the free-field annotator over a struct scope with
// fields `a` (a nested struct with `x` and `y`), `b` and `c`.
#[cfg(test)]
mod tests {
use rstest::fixture;
use rstest::rstest;
use super::*;
use crate::dtype::DType;
use crate::dtype::Nullability::NonNullable;
use crate::dtype::PType::I32;
use crate::dtype::StructFields;
use crate::expr::analysis::make_free_field_annotator;
use crate::expr::and;
use crate::expr::col;
use crate::expr::get_item;
use crate::expr::lit;
use crate::expr::merge;
use crate::expr::pack;
use crate::expr::root;
use crate::expr::transform::replace::replace_root_fields;
// Shared scope: a non-nullable struct { a: struct { x: I32, y: I32 }, b: I32, c: I32 }.
#[fixture]
fn dtype() -> DType {
DType::Struct(
StructFields::from_iter([
(
"a",
DType::Struct(
StructFields::from_iter([("x", I32.into()), ("y", DType::from(I32))]),
NonNullable,
),
),
("b", I32.into()),
("c", I32.into()),
]),
NonNullable,
)
}
// A bare `root()` yields no partitions and is returned unchanged; once it is expanded with
// `replace_root_fields`, there is one partition per field of the scope.
#[rstest]
fn test_expr_top_level_ref(dtype: DType) {
let fields = dtype.as_struct_fields_opt().unwrap();
let expr = root();
let partitioned =
partition(expr.clone(), &dtype, make_free_field_annotator(fields)).unwrap();
assert_eq!(partitioned.partitions.len(), 0);
assert_eq!(&partitioned.root, &root());
let expr = replace_root_fields(expr, fields);
let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap();
assert_eq!(partitioned.partitions.len(), fields.names().len());
}
// A nested access `a.y` is extracted whole, so the root becomes a read of field `a_0`
// from partition `a`.
#[rstest]
fn test_expr_top_level_ref_get_item_and_split(dtype: DType) {
let fields = dtype.as_struct_fields_opt().unwrap();
let expr = get_item("y", get_item("a", root()));
let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap();
assert_eq!(&partitioned.root, &get_item("a_0", get_item("a", root())));
}
// Two accesses into `a` inside a pack end up in the same partition as `a_0` and `a_1`,
// in the order they appear in the pack.
#[rstest]
fn test_expr_top_level_ref_get_item_and_split_pack(dtype: DType) {
let fields = dtype.as_struct_fields_opt().unwrap();
let expr = pack(
[
("x", get_item("x", get_item("a", root()))),
("y", get_item("y", get_item("a", root()))),
("c", get_item("c", root())),
],
NonNullable,
);
let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap();
let split_a = partitioned.find_partition(&"a".into()).unwrap();
assert_eq!(
&split_a.optimize_recursive(&dtype).unwrap(),
&pack(
[
("a_0", get_item("x", get_item("a", root()))),
("a_1", get_item("y", get_item("a", root())))
],
NonNullable
)
);
}
// Combining `a.y` with a literal touches only field `a`, so a single partition is produced.
#[rstest]
fn test_expr_top_level_ref_get_item_add(dtype: DType) {
let fields = dtype.as_struct_fields_opt().unwrap();
let expr = and(get_item("y", get_item("a", root())), lit(1));
let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap();
assert_eq!(partitioned.partitions.len(), 1);
}
// Combining `a.y` with `b` touches two fields, so each operand lands in its own partition.
#[rstest]
fn test_expr_top_level_ref_get_item_add_cannot_split(dtype: DType) {
let fields = dtype.as_struct_fields_opt().unwrap();
let expr = and(get_item("y", get_item("a", root())), get_item("b", root()));
let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap();
assert_eq!(partitioned.partitions.len(), 2);
}
// Merging column `a` with a pack of column `b` yields partitions `a` and `b`; the optimized
// root is a pack that reads `x`, `y` and `b` back out of those partitions.
#[rstest]
fn test_expr_merge(dtype: DType) {
let fields = dtype.as_struct_fields_opt().unwrap();
let expr = merge([col("a"), pack([("b", col("b"))], NonNullable)]);
let partitioned = partition(expr, &dtype, make_free_field_annotator(fields)).unwrap();
let expected = pack(
[
("x", get_item("x", get_item("a_0", col("a")))),
("y", get_item("y", get_item("a_0", col("a")))),
("b", get_item("b", get_item("b_0", col("b")))),
],
NonNullable,
);
assert_eq!(
&partitioned.root, &expected,
"{} {}",
partitioned.root, expected
);
assert_eq!(partitioned.partitions.len(), 2);
let part_a = partitioned.find_partition(&"a".into()).unwrap();
let expected_a = pack([("a_0", col("a"))], NonNullable);
assert_eq!(part_a, &expected_a, "{part_a} {expected_a}");
let part_b = partitioned.find_partition(&"b".into()).unwrap();
let expected_b = pack([("b_0", pack([("b", col("b"))], NonNullable))], NonNullable);
assert_eq!(part_b, &expected_b, "{part_b} {expected_b}");
}
}