use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_utils::aliases::hash_map::HashMap;
use crate::expr::Expression;
use crate::expr::traversal::NodeExt;
use crate::expr::traversal::NodeVisitor;
use crate::expr::traversal::TraversalOrder;
/// Computes a label for every node of the expression tree rooted at `expr`.
///
/// A node's label starts out as `self_label(node)` and is then combined with the
/// label of each of its children, in `children()` order, by calling
/// `merge_child(accumulated, &child_label)`. The value left after the last child is
/// the node's final label.
///
/// The returned map is keyed by references to the nodes borrowed from `expr`.
///
/// # Panics
///
/// The traversal result is unwrapped with `vortex_expect`; the visitor used here
/// never returns an error itself.
pub fn label_tree<L: Clone>(
    expr: &Expression,
    self_label: impl Fn(&Expression) -> L,
    mut merge_child: impl FnMut(L, &L) -> L,
) -> HashMap<&Expression, L> {
    // The visitor only borrows the merge closure, so it stays owned by this frame.
    let mut labeler = LabelingVisitor {
        merge_child: &mut merge_child,
        self_label,
        labels: Default::default(),
    };

    let traversal = expr.accept(&mut labeler);
    traversal.vortex_expect("LabelingVisitor is infallible");

    labeler.labels
}
/// Visitor state used by [`label_tree`] to accumulate one label per expression node.
struct LabelingVisitor<'expr, 'merge, L, SelfFn, MergeFn>
where
    SelfFn: Fn(&Expression) -> L,
    MergeFn: FnMut(L, &L) -> L,
{
    /// Final labels of the nodes that have been finished so far.
    labels: HashMap<&'expr Expression, L>,
    /// Produces a node's initial label, before any child label is merged in.
    self_label: SelfFn,
    /// Combines the accumulated label of a node with the label of one child.
    merge_child: &'merge mut MergeFn,
}
impl<'a, 'b, L: Clone, F, G> NodeVisitor<'a> for LabelingVisitor<'a, 'b, L, F, G>
where
    F: Fn(&Expression) -> L,
    G: FnMut(L, &L) -> L,
{
    type NodeTy = Expression;

    /// Nothing to do on the way down; labels are produced in `visit_up`.
    fn visit_down(&mut self, _node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> {
        Ok(TraversalOrder::Continue)
    }

    /// Computes and stores the label for `node`.
    ///
    /// Starts from `self_label(node)` and merges in the stored label of every child,
    /// in `children()` order. Every child must already have an entry in `labels`,
    /// i.e. this relies on the traversal calling `visit_up` on children before their
    /// parent; a missing entry trips the `vortex_expect` below.
    fn visit_up(&mut self, node: &'a Expression) -> VortexResult<TraversalOrder> {
        let mut label = (self.self_label)(node);

        for child in node.children().iter() {
            let child_label = self
                .labels
                .get(child)
                .vortex_expect("child must have label");
            label = (self.merge_child)(label, child_label);
        }

        self.labels.insert(node, label);
        Ok(TraversalOrder::Continue)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::lit;

    /// Labels each node with the height of the subtree rooted at it.
    #[test]
    fn test_tree_depth() {
        let root = eq(col("col1"), lit(5));

        let depth_by_node = label_tree(
            &root,
            // A node on its own has depth 1.
            |_node| 1,
            // A parent is at least one level deeper than each of its children.
            |self_depth, child_depth| self_depth.max(*child_depth + 1),
        );

        // NOTE(review): a depth of 3 implies `col("col1")` contributes two levels
        // (presumably a field access over a root node) - confirm against `col`.
        assert_eq!(depth_by_node.get(&root), Some(&3));
    }

    /// Labels each node with the number of nodes in the subtree rooted at it.
    #[test]
    fn test_node_count() {
        let root = eq(col("col1"), lit(5));

        let count_by_node = label_tree(
            &root,
            // Every node counts itself once.
            |_node| 1,
            // Then adds the size of each child subtree.
            |self_count, child_count| self_count + *child_count,
        );

        assert_eq!(count_by_node.get(&root), Some(&4));
    }
}