use std::fmt::Debug;
use std::iter::ExactSizeIterator;
use crate::*;
/// An eclass: an identified collection of nodes of type `L` together with an
/// attached value of type `D`.
///
/// The type is `#[non_exhaustive]`, so code outside this crate cannot build it
/// with a struct literal or match on it exhaustively.
#[non_exhaustive]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
pub struct EClass<L, D> {
/// Identifier of this eclass; it is what `assert_unique_leaves` prints in
/// its failure message.
pub id: Id,
/// The nodes held by this eclass.
///
/// `for_each_matching_node` debug-asserts that this vector is in strictly
/// ascending order once it holds 50 or more nodes.
pub nodes: Vec<L>,
/// Value of type `D` attached to this eclass.
// NOTE(review): presumably per-class analysis data — confirm against the
// places where `D` is instantiated.
pub data: D,
/// Ids exposed read-only through `parents()`.
// NOTE(review): presumably these identify the parents of this eclass; what
// exactly each id refers to is not visible here — confirm.
pub(crate) parents: Vec<Id>,
}
impl<L, D> EClass<L, D> {
    /// Returns `true` if this eclass holds no nodes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the number of nodes held by this eclass.
    pub fn len(&self) -> usize {
        self.nodes.as_slice().len()
    }

    /// Iterates over the nodes of this eclass by reference, in stored order.
    pub fn iter(&self) -> impl ExactSizeIterator<Item = &L> {
        let stored = self.nodes.as_slice();
        stored.iter()
    }

    /// Iterates over the ids recorded in `parents`, yielding each by value.
    pub fn parents(&self) -> impl ExactSizeIterator<Item = Id> + '_ {
        let recorded = self.parents.as_slice();
        recorded.iter().copied()
    }
}
impl<L: Language, D> EClass<L, D> {
    /// Iterates over the nodes of this eclass for which `is_leaf()` is true.
    pub fn leaves(&self) -> impl Iterator<Item = &L> {
        self.nodes.iter().filter(|candidate| candidate.is_leaf())
    }

    /// Asserts that every leaf node in this eclass compares equal to the
    /// first leaf.
    ///
    /// Does nothing when the eclass has no leaves.
    ///
    /// # Panics
    ///
    /// Panics if two leaves differ; the message lists the eclass id and the
    /// set of distinct leaves.
    pub fn assert_unique_leaves(&self)
    where
        L: Language,
    {
        let mut remaining = self.leaves();
        let reference = match remaining.next() {
            Some(leaf) => leaf,
            // No leaves at all: nothing to compare.
            None => return,
        };
        let all_same = remaining.all(|leaf| leaf == reference);
        assert!(
            all_same,
            "Different leaves in eclass {}: {:?}",
            self.id,
            self.leaves().collect::<crate::util::HashSet<_>>()
        );
    }

    /// Calls `f` on every node `n` of this eclass for which `node.matches(n)`
    /// holds, stopping at and returning the first `Err` produced by `f`.
    ///
    /// Eclasses with fewer than 50 nodes are scanned linearly. Larger ones
    /// are narrowed first: a binary search locates `node`, and only the
    /// contiguous run of nodes sharing `node`'s discriminant is tested. That
    /// path relies on `self.nodes` being sorted; debug builds assert strict
    /// ascending order, that `node.all(|id| id == Id::from(0))` holds, and
    /// that the narrowed scan finds as many matches as a full linear scan.
    pub fn for_each_matching_node<Err>(
        &self,
        node: &L,
        mut f: impl FnMut(&L) -> Result<(), Err>,
    ) -> Result<(), Err>
    where
        L: Language,
    {
        // Small eclass: a plain linear scan.
        if self.nodes.len() < 50 {
            return self
                .nodes
                .iter()
                .filter(|n| node.matches(n))
                .try_for_each(f);
        }

        debug_assert!(node.all(|id| id == Id::from(0)));
        debug_assert!(self.nodes.windows(2).all(|w| w[0] < w[1]));

        // Where `node` is, or would be inserted, in the sorted node list.
        let landing = match self.nodes.binary_search(node) {
            Ok(pos) | Err(pos) => pos,
        };
        let discrim = node.discriminant();

        // Step back over the nodes just before `landing` that share the
        // discriminant, so `start` is the first node of that run.
        let same_op_before = self.nodes[..landing]
            .iter()
            .rev()
            .take_while(|n| n.discriminant() == discrim)
            .count();
        let start = landing - same_op_before;

        let same_op_run = self.nodes[start..]
            .iter()
            .take_while(|&n| n.discriminant() == discrim);
        let mut matching = same_op_run.filter(|n| node.matches(n));

        // Cross-check the narrowed scan against a full linear scan.
        debug_assert_eq!(
            matching.clone().count(),
            self.nodes.iter().filter(|n| node.matches(n)).count(),
            "matching node {:?}\nstart={}\n{:?} != {:?}\nnodes: {:?}",
            node,
            start,
            matching.clone().collect::<HashSet<_>>(),
            self.nodes
                .iter()
                .filter(|n| node.matches(n))
                .collect::<HashSet<_>>(),
            self.nodes
        );

        matching.try_for_each(&mut f)
    }
}