use crate::hash_utils::combine_hashes;
use crate::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
TreeNodeVisitor,
};
use crate::Result;
use indexmap::IndexMap;
use std::collections::HashMap;
use std::hash::{BuildHasher, Hash, Hasher, RandomState};
use std::marker::PhantomData;
use std::sync::Arc;
/// Hashes a single tree node into a [`Hasher`].
///
/// `CSEVisitor` combines the resulting hash with the identifiers of the node's
/// children, so an implementation only needs to feed in the node's own content.
pub trait HashNode {
    /// Writes this node's own hashable content into `state`.
    fn hash_node<H: Hasher>(&self, state: &mut H);
}

/// An `Arc` hashes exactly like the value it points to.
impl<T: HashNode + ?Sized> HashNode for Arc<T> {
    fn hash_node<H: Hasher>(&self, state: &mut H) {
        // Deref-coerce to the pointee and delegate to its implementation.
        let inner: &T = self;
        inner.hash_node(state);
    }
}
/// Reports whether a node's children may be combined order-insensitively.
pub trait Normalizeable {
    /// When `true`, `CSEVisitor` sorts the child identifiers by hash before
    /// combining them into this node's identifier, so reordering the children
    /// does not change the node's hash.
    fn can_normalize(&self) -> bool;
}

/// Equality "up to normalization", used by `Identifier` to confirm that two
/// nodes with equal hashes really denote the same common node.
pub trait NormalizeEq: Eq + Normalizeable {
    /// Returns `true` if `self` and `other` are equal after normalization.
    /// The exact normalization rules are up to the implementor.
    fn normalize_eq(&self, other: &Self) -> bool;
}
/// Identifies a subtree by a precomputed structural hash plus a reference to
/// the subtree's root node.
///
/// Two identifiers are equal when their hashes match and the referenced nodes
/// are `normalize_eq`; hashing an identifier uses only the stored hash.
#[derive(Debug, Eq)]
struct Identifier<'n, N: NormalizeEq> {
    // Hash of the node combined with the hashes of its children (see `CSEVisitor::f_up`).
    hash: u64,
    // The subtree root this identifier was built for; borrowed from the input trees.
    node: &'n N,
}
// Manual impls avoid the `N: Clone`/`N: Copy` bounds a derive would add;
// the struct only holds a `u64` and a shared reference, both `Copy`.
impl<N: NormalizeEq> Clone for Identifier<'_, N> {
    fn clone(&self) -> Self {
        *self
    }
}
impl<N: NormalizeEq> Copy for Identifier<'_, N> {}
impl<N: NormalizeEq> Hash for Identifier<'_, N> {
    /// Feeds only the precomputed structural hash into `state`; the referenced
    /// node is never hashed again.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self { hash, .. } = *self;
        state.write_u64(hash);
    }
}
impl<N: NormalizeEq> PartialEq for Identifier<'_, N> {
    /// Compares the cheap hashes first and only runs `normalize_eq` on the
    /// referenced nodes when the hashes agree.
    fn eq(&self, other: &Self) -> bool {
        if self.hash != other.hash {
            return false;
        }
        self.node.normalize_eq(other.node)
    }
}
impl<'n, N> Identifier<'n, N>
where
    N: HashNode + NormalizeEq,
{
    /// Builds an identifier for `node` from its own `hash_node` output only,
    /// using a hasher seeded by `random_state`. Children are folded in later
    /// via [`Identifier::combine`].
    fn new(node: &'n N, random_state: &RandomState) -> Self {
        let mut hasher = random_state.build_hasher();
        node.hash_node(&mut hasher);
        Self {
            hash: hasher.finish(),
            node,
        }
    }

    /// Mixes `other`'s hash (when present) into this identifier's hash with
    /// `combine_hashes`. The node reference of `self` is kept unchanged.
    fn combine(self, other: Option<Self>) -> Self {
        match other {
            Some(other_id) => Self {
                hash: combine_hashes(self.hash, other_id.hash),
                ..self
            },
            None => self,
        }
    }
}
/// Per-tree array indexed by each node's pre-order position. Each entry holds
/// the node's post-order index and its identifier, which is `None` when the
/// node (or a descendant) is invalid or the node is ignored by the controller.
type IdArray<'n, N> = Vec<(usize, Option<Identifier<'n, N>>)>;
/// How often an identifier has been seen, and in which kind of position.
#[derive(PartialEq, Eq)]
enum NodeEvaluation {
    /// Seen exactly once, in an unconditionally evaluated position.
    SurelyOnce,
    /// Seen one or more times, but only in conditionally evaluated positions.
    ConditionallyAtLeastOnce,
    /// Seen more than once with at least one unconditional occurrence; nodes
    /// in this state are replaced by `CSERewriter`.
    Common,
}
/// Occurrence classification of every recorded identifier.
type NodeStats<'n, N> = HashMap<Identifier<'n, N>, NodeEvaluation>;
/// Common nodes found during rewriting, mapped to `(original node, alias)`;
/// `IndexMap` keeps them in insertion order.
type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>;
/// `(normal, conditional)` children as returned by
/// `CSEController::conditional_children`.
type ChildrenList<N> = (Vec<N>, Vec<N>);
/// Node-type–specific policy that drives [`CSE`].
pub trait CSEController {
    /// The tree node type being deduplicated.
    type Node;

    /// Optionally splits `node`'s children into `(normal, conditional)` lists.
    ///
    /// When `Some` is returned (outside an already-conditional subtree), the
    /// visitor traverses the `normal` children as always evaluated and the
    /// `conditional` children as possibly not evaluated. Return `None` to use
    /// the default traversal.
    fn conditional_children(node: &Self::Node) -> Option<ChildrenList<&Self::Node>>;

    /// Whether `node` itself may be extracted. A node is eligible only if it
    /// and every node in its subtree are valid.
    fn is_valid(node: &Self::Node) -> bool;

    /// Whether a valid `node` should nevertheless not be recorded. Unlike
    /// `is_valid`, ignoring a node does not make its ancestors ineligible.
    fn is_ignored(&self, node: &Self::Node) -> bool;

    /// Produces the alias for a newly discovered common node.
    fn generate_alias(&self) -> String;

    /// Builds the replacement for an occurrence of a common node referred to
    /// by `alias`.
    fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node;

    /// Hook called at the start of every rewriter `f_down`. Does nothing by default.
    fn rewrite_f_down(&mut self, _node: &Self::Node) {}

    /// Hook called from every rewriter `f_up`. Does nothing by default.
    fn rewrite_f_up(&mut self, _node: &Self::Node) {}
}
/// Result of [`CSE::extract_common_nodes`].
#[derive(Debug)]
pub enum FoundCommonNodes<N> {
    /// No node occurred often enough to be extracted; the input is returned as is.
    No { original_nodes_list: Vec<Vec<N>> },
    /// At least one common node was found and extracted.
    Yes {
        /// Each extracted original node with its alias, in insertion order.
        common_nodes: Vec<(N, String)>,
        /// The input lists with common nodes replaced by `CSEController::rewrite`.
        new_nodes_list: Vec<Vec<N>>,
        /// The untouched input lists.
        original_nodes_list: Vec<Vec<N>>,
    },
}
/// Visitor that assigns an [`Identifier`] to every node of one tree and
/// records how often each identifier occurs.
struct CSEVisitor<'a, 'n, N, C>
where
    N: NormalizeEq,
    C: CSEController<Node = N>,
{
    // Occurrence classification per identifier; one map is shared by all trees
    // visited during a single `CSE::extract_common_nodes` call.
    node_stats: &'a mut NodeStats<'n, N>,
    // One `(post-order index, identifier)` entry per node, pushed in pre-order.
    id_array: &'a mut IdArray<'n, N>,
    // Enter marks and finished-child records used to combine child ids in `f_up`.
    visit_stack: Vec<VisitRecord<'n, N>>,
    // Next pre-order index (advanced in `f_down`).
    down_index: usize,
    // Next post-order index (advanced in `f_up`).
    up_index: usize,
    // Seed for the per-node hashers.
    random_state: &'a RandomState,
    // Set when this visitor promotes some identifier to `NodeEvaluation::Common`.
    found_common: bool,
    // True while visiting children returned as conditional by the controller.
    conditional: bool,
    controller: &'a C,
}
/// Entry on `CSEVisitor::visit_stack`.
enum VisitRecord<'n, N>
where
    N: NormalizeEq,
{
    /// Pushed by `f_down`: the node's pre-order index. Marks where the records
    /// of the node's children start.
    EnterMark(usize),
    /// Pushed by `f_up` for a finished node: its identifier and whether the
    /// node and its whole subtree are valid.
    NodeItem(Identifier<'n, N>, bool),
}
impl<'n, N, C> CSEVisitor<'_, 'n, N, C>
where
    N: TreeNode + HashNode + NormalizeEq,
    C: CSEController<Node = N>,
{
    /// Pops every `NodeItem` pushed since the most recent `EnterMark`, then the
    /// `EnterMark` itself, and folds the popped identifiers into one.
    ///
    /// `f_up` calls this for the node it is finishing. Each child's `f_up` has
    /// already replaced its own records with a single `NodeItem`, so the items
    /// popped here correspond to the node's direct children.
    ///
    /// When `can_normalize` is true, the child identifiers are sorted by hash
    /// before combining, so the combined hash does not depend on child order.
    ///
    /// Returns `(down_index, combined_child_id, all_children_valid)`:
    /// - `down_index` is the pre-order index stored in the `EnterMark`.
    /// - `combined_child_id` is `None` when the node had no children.
    ///
    /// # Panics
    ///
    /// Panics if the stack empties before an `EnterMark` is found. That would
    /// mean the `f_down`/`f_up` calls were unbalanced.
    fn pop_enter_mark(
        &mut self,
        can_normalize: bool,
    ) -> (usize, Option<Identifier<'n, N>>, bool) {
        let mut node_ids: Vec<Identifier<'n, N>> = vec![];
        let mut is_valid = true;
        while let Some(item) = self.visit_stack.pop() {
            match item {
                VisitRecord::EnterMark(down_index) => {
                    if can_normalize {
                        node_ids.sort_by_key(|i| i.hash);
                    }
                    let node_id = node_ids
                        .into_iter()
                        .fold(None, |accum, item| Some(item.combine(accum)));
                    return (down_index, node_id, is_valid);
                }
                VisitRecord::NodeItem(sub_node_id, sub_node_is_valid) => {
                    node_ids.push(sub_node_id);
                    is_valid &= sub_node_is_valid;
                }
            }
        }
        unreachable!("EnterMark should be paired with NodeItem");
    }
}
impl<'n, N, C> TreeNodeVisitor<'n> for CSEVisitor<'_, 'n, N, C>
where
    N: TreeNode + HashNode + NormalizeEq,
    C: CSEController<Node = N>,
{
    type Node = N;

    /// Pre-order step: reserves this node's slot in `id_array` and pushes an
    /// `EnterMark` that delimits its children's records on `visit_stack`.
    ///
    /// If not already inside a conditional subtree and the controller splits
    /// the node's children, this method visits those children itself: the
    /// normal ones first, then the conditional ones with `self.conditional`
    /// set. It then returns `TreeNodeRecursion::Jump`, which is expected to skip
    /// the default traversal of the children so they are not visited twice.
    fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
        // Placeholder; `f_up` fills in the post-order index and, if eligible,
        // the identifier.
        self.id_array.push((0, None));
        self.visit_stack
            .push(VisitRecord::EnterMark(self.down_index));
        self.down_index += 1;
        Ok(if self.conditional {
            // Already inside a conditional subtree: every descendant is recorded
            // as conditional, so no further splitting is needed.
            TreeNodeRecursion::Continue
        } else {
            match C::conditional_children(node) {
                Some((normal, conditional)) => {
                    // Always-evaluated children keep the unconditional state.
                    normal
                        .into_iter()
                        .try_for_each(|n| n.visit(self).map(|_| ()))?;
                    self.conditional = true;
                    conditional
                        .into_iter()
                        .try_for_each(|n| n.visit(self).map(|_| ()))?;
                    // Safe to reset to `false`: this branch is only entered when
                    // `conditional` was `false` on entry.
                    self.conditional = false;
                    TreeNodeRecursion::Jump
                }
                _ => TreeNodeRecursion::Continue,
            }
        })
    }

    /// Post-order step: builds this node's identifier from its own hash plus
    /// the combined identifiers of its children, and stores its post-order
    /// index in `id_array`.
    ///
    /// If the node and its whole subtree are valid and the controller does not
    /// ignore it, the identifier is also stored in `id_array`, and its entry in
    /// `node_stats` is updated.
    ///
    /// A `NodeItem` is always pushed, even for ignored nodes, so the parent can
    /// still combine this node's identifier into its own.
    fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
        let (down_index, sub_node_id, sub_node_is_valid) =
            self.pop_enter_mark(node.can_normalize());
        let node_id = Identifier::new(node, self.random_state).combine(sub_node_id);
        // Invalidity propagates upward; ignoring (below) does not.
        let is_valid = C::is_valid(node) && sub_node_is_valid;
        self.id_array[down_index].0 = self.up_index;
        if is_valid && !self.controller.is_ignored(node) {
            self.id_array[down_index].1 = Some(node_id);
            // Repeat sighting. The condition parses as
            // `SurelyOnce || (ConditionallyAtLeastOnce && !conditional)`:
            // the node becomes common unless both the earlier sightings and
            // this one are conditional. `Common` entries stay as they are.
            self.node_stats
                .entry(node_id)
                .and_modify(|evaluation| {
                    if *evaluation == NodeEvaluation::SurelyOnce
                        || *evaluation == NodeEvaluation::ConditionallyAtLeastOnce
                            && !self.conditional
                    {
                        *evaluation = NodeEvaluation::Common;
                        self.found_common = true;
                    }
                })
                // First sighting: classify by the current position.
                .or_insert_with(|| {
                    if self.conditional {
                        NodeEvaluation::ConditionallyAtLeastOnce
                    } else {
                        NodeEvaluation::SurelyOnce
                    }
                });
        }
        self.visit_stack
            .push(VisitRecord::NodeItem(node_id, is_valid));
        self.up_index += 1;
        Ok(TreeNodeRecursion::Continue)
    }
}
/// Rewriter that replaces common nodes. It is driven by the `id_array` a
/// [`CSEVisitor`] produced for the same tree and by the shared `node_stats`.
struct CSERewriter<'a, 'n, N, C>
where
    N: NormalizeEq,
    C: CSEController<Node = N>,
{
    node_stats: &'a NodeStats<'n, N>,
    // Pre-order `(post-order index, identifier)` entries of the tree being rewritten.
    id_array: &'a IdArray<'n, N>,
    // Common nodes found so far with their aliases; shared across all rewritten trees.
    common_nodes: &'a mut CommonNodes<'n, N>,
    // Position in `id_array` of the next node visited, in pre-order.
    down_index: usize,
    controller: &'a mut C,
}
impl<N, C> TreeNodeRewriter for CSERewriter<'_, '_, N, C>
where
    N: TreeNode + NormalizeEq,
    C: CSEController<Node = N>,
{
    type Node = N;

    /// Replaces every node whose identifier was classified
    /// `NodeEvaluation::Common` with the controller's rewrite, and skips the
    /// replaced node's descendants.
    ///
    /// `id_array` is consumed through `down_index`. This assumes the rewrite
    /// traversal visits nodes in the same pre-order as the `CSEVisitor` that
    /// built the array.
    ///
    /// The first occurrence of a common identifier gets a fresh alias, and the
    /// original node is stored in `common_nodes`. Later occurrences with an
    /// equal identifier (equal hash and `normalize_eq`, though not necessarily
    /// structurally identical) reuse that alias.
    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
        self.controller.rewrite_f_down(&node);
        let (up_index, node_id) = self.id_array[self.down_index];
        self.down_index += 1;
        if let Some(node_id) = node_id {
            // The visitor inserts into `node_stats` whenever it records an
            // identifier in `id_array`, so this lookup cannot miss.
            let evaluation = self
                .node_stats
                .get(&node_id)
                .expect("every identifier in id_array has an entry in node_stats");
            if *evaluation == NodeEvaluation::Common {
                // Descendants directly follow this node in pre-order and all
                // finished (post-order) before it, so skip every entry whose
                // post-order index is below this node's.
                while self.down_index < self.id_array.len()
                    && self.id_array[self.down_index].0 < up_index
                {
                    self.down_index += 1;
                }
                let rewritten = if let Some((_, alias)) = self.common_nodes.get(&node_id)
                {
                    self.controller.rewrite(&node, alias)
                } else {
                    let node_alias = self.controller.generate_alias();
                    let rewritten = self.controller.rewrite(&node, &node_alias);
                    self.common_nodes.insert(node_id, (node, node_alias));
                    rewritten
                };
                // `Jump`: the subtree was replaced, so its children are not rewritten.
                return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump));
            }
        }
        Ok(Transformed::no(node))
    }

    /// Notifies the controller and leaves the node unchanged.
    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
        self.controller.rewrite_f_up(&node);
        Ok(Transformed::no(node))
    }
}
/// Common-subexpression eliminator over trees of `N`, driven by a
/// [`CSEController`].
pub struct CSE<N, C: CSEController<Node = N>> {
    // Seed for the hashers that build node identifiers.
    random_state: RandomState,
    phantom_data: PhantomData<N>,
    // Decides node eligibility and how common nodes are rewritten.
    controller: C,
}
impl<N, C> CSE<N, C>
where
    N: TreeNode + HashNode + Clone + NormalizeEq,
    C: CSEController<Node = N>,
{
    /// Creates an eliminator driven by `controller`, hashing nodes with a
    /// freshly created [`RandomState`].
    pub fn new(controller: C) -> Self {
        let random_state = RandomState::new();
        Self {
            random_state,
            phantom_data: PhantomData,
            controller,
        }
    }

    /// Visits `node` with a [`CSEVisitor`]. This appends one entry per node to
    /// `id_array` (in pre-order) and updates the shared counts in `node_stats`.
    ///
    /// Returns `true` if this visit promoted some identifier to
    /// `NodeEvaluation::Common`.
    fn node_to_id_array<'n>(
        &self,
        node: &'n N,
        node_stats: &mut NodeStats<'n, N>,
        id_array: &mut IdArray<'n, N>,
    ) -> Result<bool> {
        let mut visitor = CSEVisitor {
            node_stats,
            id_array,
            visit_stack: Vec::new(),
            down_index: 0,
            up_index: 0,
            random_state: &self.random_state,
            found_common: false,
            conditional: false,
            controller: &self.controller,
        };
        node.visit(&mut visitor)?;
        Ok(visitor.found_common)
    }

    /// Builds one id array per node in `nodes`. All arrays share `node_stats`.
    ///
    /// Returns whether any visit found a common node, together with the arrays
    /// in the same order as `nodes`. Stops at the first error.
    fn to_arrays<'n>(
        &self,
        nodes: &'n [N],
        node_stats: &mut NodeStats<'n, N>,
    ) -> Result<(bool, Vec<IdArray<'n, N>>)> {
        let mut any_common = false;
        let mut id_arrays = Vec::with_capacity(nodes.len());
        for node in nodes {
            let mut id_array = vec![];
            any_common |= self.node_to_id_array(node, node_stats, &mut id_array)?;
            id_arrays.push(id_array);
        }
        Ok((any_common, id_arrays))
    }

    /// Rewrites `node` with a [`CSERewriter`] over its `id_array`. An empty
    /// array leaves the node untouched.
    fn replace_common_node<'n>(
        &mut self,
        node: N,
        id_array: &IdArray<'n, N>,
        node_stats: &NodeStats<'n, N>,
        common_nodes: &mut CommonNodes<'n, N>,
    ) -> Result<N> {
        let outcome = if id_array.is_empty() {
            Ok(Transformed::no(node))
        } else {
            let mut rewriter = CSERewriter {
                node_stats,
                id_array,
                common_nodes,
                down_index: 0,
                controller: &mut self.controller,
            };
            node.rewrite(&mut rewriter)
        };
        outcome.data()
    }

    /// Rewrites every node of every list, pairing each node with its id array
    /// by position. All rewrites share one `common_nodes` map. Stops at the
    /// first error.
    fn rewrite_nodes_list<'n>(
        &mut self,
        nodes_list: Vec<Vec<N>>,
        arrays_list: &[Vec<IdArray<'n, N>>],
        node_stats: &NodeStats<'n, N>,
        common_nodes: &mut CommonNodes<'n, N>,
    ) -> Result<Vec<Vec<N>>> {
        let mut rewritten_list = Vec::with_capacity(nodes_list.len());
        for (nodes, id_arrays) in nodes_list.into_iter().zip(arrays_list) {
            let mut rewritten = Vec::with_capacity(nodes.len());
            for (node, id_array) in nodes.into_iter().zip(id_arrays) {
                rewritten.push(self.replace_common_node(
                    node,
                    id_array,
                    node_stats,
                    common_nodes,
                )?);
            }
            rewritten_list.push(rewritten);
        }
        Ok(rewritten_list)
    }

    /// Finds nodes that occur more than once across all of `nodes_list` and
    /// replaces them using the controller.
    ///
    /// Returns [`FoundCommonNodes::No`] with the untouched input when nothing
    /// qualifies. Otherwise returns [`FoundCommonNodes::Yes`] with the extracted
    /// nodes and aliases, the rewritten lists, and the original lists.
    ///
    /// # Panics
    ///
    /// Panics if a common node was detected but the rewrite extracted nothing.
    pub fn extract_common_nodes(
        &mut self,
        nodes_list: Vec<Vec<N>>,
    ) -> Result<FoundCommonNodes<N>> {
        let mut found_common = false;
        let mut node_stats = NodeStats::new();
        let mut id_arrays_list = Vec::with_capacity(nodes_list.len());
        for nodes in &nodes_list {
            let (list_has_common, id_arrays) = self.to_arrays(nodes, &mut node_stats)?;
            found_common |= list_has_common;
            id_arrays_list.push(id_arrays);
        }

        if !found_common {
            return Ok(FoundCommonNodes::No {
                original_nodes_list: nodes_list,
            });
        }

        let mut common_nodes = CommonNodes::new();
        // Identifiers borrow the original nodes, so the rewrite consumes a
        // clone and `nodes_list` stays intact for `original_nodes_list`.
        let new_nodes_list = self.rewrite_nodes_list(
            nodes_list.clone(),
            &id_arrays_list,
            &node_stats,
            &mut common_nodes,
        )?;
        assert!(!common_nodes.is_empty());
        Ok(FoundCommonNodes::Yes {
            common_nodes: common_nodes.into_values().collect(),
            new_nodes_list,
            original_nodes_list: nodes_list,
        })
    }
}
#[cfg(test)]
mod test {
    use crate::alias::AliasGenerator;
    use crate::cse::{
        CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeEq,
        Normalizeable, CSE,
    };
    use crate::tree_node::tests::TestTreeNode;
    use crate::Result;
    use std::collections::HashSet;
    use std::hash::{Hash, Hasher};

    // Prefix handed to the alias generator for common-node aliases.
    const CSE_PREFIX: &str = "__common_node";

    // Selects which nodes `TestTreeNodeCSEController::is_ignored` skips.
    #[derive(Clone, Copy)]
    pub enum TestTreeNodeMask {
        // Ignore leaves and `avg`/`sum` aggregate nodes.
        Normal,
        // Ignore only leaves, so aggregates become eligible.
        NormalAndAggregates,
    }

    // Minimal controller for `TestTreeNode<String>` trees.
    pub struct TestTreeNodeCSEController<'a> {
        alias_generator: &'a AliasGenerator,
        mask: TestTreeNodeMask,
    }

    impl<'a> TestTreeNodeCSEController<'a> {
        fn new(alias_generator: &'a AliasGenerator, mask: TestTreeNodeMask) -> Self {
            Self {
                alias_generator,
                mask,
            }
        }
    }

    impl CSEController for TestTreeNodeCSEController<'_> {
        type Node = TestTreeNode<String>;

        // Test trees have no conditionally evaluated children.
        fn conditional_children(
            _: &Self::Node,
        ) -> Option<(Vec<&Self::Node>, Vec<&Self::Node>)> {
            None
        }

        // Every node is valid; eligibility is controlled only by `is_ignored`.
        fn is_valid(_node: &Self::Node) -> bool {
            true
        }

        fn is_ignored(&self, node: &Self::Node) -> bool {
            let is_leaf = node.is_leaf();
            let is_aggr = node.data == "avg" || node.data == "sum";
            match self.mask {
                TestTreeNodeMask::Normal => is_leaf || is_aggr,
                TestTreeNodeMask::NormalAndAggregates => is_leaf,
            }
        }

        fn generate_alias(&self) -> String {
            self.alias_generator.next(CSE_PREFIX)
        }

        // Replaces a common node with a leaf that records its data and alias.
        fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
            TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias))
        }
    }

    // Hashes only the node's own data; the visitor combines children's ids.
    impl HashNode for TestTreeNode<String> {
        fn hash_node<H: Hasher>(&self, state: &mut H) {
            self.data.hash(state);
        }
    }

    // Child order is significant in these tests, so normalization is disabled
    // and `normalize_eq` is plain equality.
    impl Normalizeable for TestTreeNode<String> {
        fn can_normalize(&self) -> bool {
            false
        }
    }

    impl NormalizeEq for TestTreeNode<String> {
        fn normalize_eq(&self, other: &Self) -> bool {
            self == other
        }
    }

    // Checks the `id_array` that `node_to_id_array` builds for this tree:
    //
    //   *                 pre 0, post 8
    //   ├── -             pre 1, post 6
    //   │   ├── sum       pre 2, post 3
    //   │   │   └── +     pre 3, post 2
    //   │   │       ├── a pre 4, post 0
    //   │   │       └── 1 pre 5, post 1
    //   │   └── avg       pre 6, post 5
    //   │       └── c     pre 7, post 4
    //   └── 2             pre 8, post 7
    //
    // Entries are listed in pre-order, and the first tuple element is the
    // post-order index. The check runs under both masks.
    #[test]
    fn id_array_visitor() -> Result<()> {
        let alias_generator = AliasGenerator::new();
        let eliminator = CSE::new(TestTreeNodeCSEController::new(
            &alias_generator,
            TestTreeNodeMask::Normal,
        ));
        let a_plus_1 = TestTreeNode::new(
            vec![
                TestTreeNode::new_leaf("a".to_string()),
                TestTreeNode::new_leaf("1".to_string()),
            ],
            "+".to_string(),
        );
        let avg_c = TestTreeNode::new(
            vec![TestTreeNode::new_leaf("c".to_string())],
            "avg".to_string(),
        );
        let sum_a_plus_1 = TestTreeNode::new(vec![a_plus_1], "sum".to_string());
        let sum_a_plus_1_minus_avg_c =
            TestTreeNode::new(vec![sum_a_plus_1, avg_c], "-".to_string());
        let root = TestTreeNode::new(
            vec![
                sum_a_plus_1_minus_avg_c,
                TestTreeNode::new_leaf("2".to_string()),
            ],
            "*".to_string(),
        );
        // Re-borrow the subtrees from `root`: identifiers reference the nodes
        // inside the tree that was visited.
        let [sum_a_plus_1_minus_avg_c, _] = root.children.as_slice() else {
            panic!("Cannot extract subtree references")
        };
        let [sum_a_plus_1, avg_c] = sum_a_plus_1_minus_avg_c.children.as_slice() else {
            panic!("Cannot extract subtree references")
        };
        let [a_plus_1] = sum_a_plus_1.children.as_slice() else {
            panic!("Cannot extract subtree references")
        };
        let mut id_array = vec![];
        eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?;

        // Hash values depend on the randomly seeded `RandomState`. This helper
        // zeroes every recorded hash so the expected arrays can use `hash: 0`,
        // and returns the set of distinct hashes that were seen.
        fn collect_hashes(
            id_array: &mut IdArray<'_, TestTreeNode<String>>,
        ) -> HashSet<u64> {
            id_array
                .iter_mut()
                .flat_map(|(_, id_option)| {
                    id_option.as_mut().map(|node_id| {
                        let hash = node_id.hash;
                        node_id.hash = 0;
                        hash
                    })
                })
                .collect::<HashSet<_>>()
        }

        // `Normal` mask: leaves and `sum`/`avg` are ignored, so only `*`, `-`
        // and `+` get identifiers.
        let hashes = collect_hashes(&mut id_array);
        assert_eq!(hashes.len(), 3);
        let expected = vec![
            (
                8,
                Some(Identifier {
                    hash: 0,
                    node: &root,
                }),
            ),
            (
                6,
                Some(Identifier {
                    hash: 0,
                    node: sum_a_plus_1_minus_avg_c,
                }),
            ),
            (3, None),
            (
                2,
                Some(Identifier {
                    hash: 0,
                    node: a_plus_1,
                }),
            ),
            (0, None),
            (1, None),
            (5, None),
            (4, None),
            (7, None),
        ];
        assert_eq!(expected, id_array);

        // `NormalAndAggregates` mask: only leaves are ignored, so `sum` and
        // `avg` also get identifiers.
        let eliminator = CSE::new(TestTreeNodeCSEController::new(
            &alias_generator,
            TestTreeNodeMask::NormalAndAggregates,
        ));
        let mut id_array = vec![];
        eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?;
        let hashes = collect_hashes(&mut id_array);
        assert_eq!(hashes.len(), 5);
        let expected = vec![
            (
                8,
                Some(Identifier {
                    hash: 0,
                    node: &root,
                }),
            ),
            (
                6,
                Some(Identifier {
                    hash: 0,
                    node: sum_a_plus_1_minus_avg_c,
                }),
            ),
            (
                3,
                Some(Identifier {
                    hash: 0,
                    node: sum_a_plus_1,
                }),
            ),
            (
                2,
                Some(Identifier {
                    hash: 0,
                    node: a_plus_1,
                }),
            ),
            (0, None),
            (1, None),
            (
                5,
                Some(Identifier {
                    hash: 0,
                    node: avg_c,
                }),
            ),
            (4, None),
            (7, None),
        ];
        assert_eq!(expected, id_array);
        Ok(())
    }
}