use std::collections::HashSet;
use crate::union_zipper::{Lattice, LatticeMeet, ValueMergeStrategy};
use crate::zipper::{DictZipper, ValuedDictZipper};
/// Zipper that walks several dictionary zippers in lockstep and exposes only
/// what all of them have in common.
///
/// `Z` is the underlying zipper type. `S` is the strategy object stored
/// alongside the zippers; it defaults to [`LatticeMeet`] and is consulted when
/// values from the individual dictionaries are merged.
#[derive(Clone, Debug)]
pub struct IntersectionZipper<Z: DictZipper, S = LatticeMeet> {
/// One slot per source zipper, in the order the zippers were supplied.
/// NOTE(review): the constructors and `descend` in this file only ever
/// produce `Some` slots; confirm whether `None` is reachable elsewhere.
zippers: Vec<Option<Z>>,
/// Labels followed from the root to reach the current position.
path: Vec<Z::Unit>,
/// Strategy used to combine the per-dictionary values.
strategy: S,
}
impl<Z: DictZipper> IntersectionZipper<Z, LatticeMeet> {
    /// Builds an intersection over `zippers` using the default
    /// [`LatticeMeet`] strategy.
    ///
    /// The resulting zipper starts with an empty path and keeps the zippers
    /// in the order given.
    pub fn new(zippers: Vec<Z>) -> Self {
        let zippers = zippers.into_iter().map(Some).collect();
        Self {
            zippers,
            path: Vec::new(),
            strategy: LatticeMeet,
        }
    }
}
impl<Z: DictZipper, S: Clone + Send + Sync> IntersectionZipper<Z, S> {
    /// Builds an intersection over `zippers` that uses `strategy` to combine
    /// values. The resulting zipper starts with an empty path.
    pub fn with_strategy(zippers: Vec<Z>, strategy: S) -> Self {
        let zippers = zippers.into_iter().map(Some).collect();
        Self {
            zippers,
            path: Vec::new(),
            strategy,
        }
    }

    /// Total number of dictionary slots, whether or not they are populated.
    pub fn dictionary_count(&self) -> usize {
        self.zippers.len()
    }

    /// Number of dictionary slots that currently hold a zipper.
    pub fn active_dictionary_count(&self) -> usize {
        self.zippers.iter().flatten().count()
    }

    /// Returns an iterator over the final positions reachable from this
    /// zipper. The iterator works on a clone, so `self` is left untouched.
    pub fn iter(&self) -> IntersectionIterator<Z, S> {
        let root = self.clone();
        IntersectionIterator::new(root)
    }
}
impl<Z: DictZipper, S: Clone + Send + Sync> DictZipper for IntersectionZipper<Z, S> {
    type Unit = Z::Unit;

    /// Returns `true` only when there is at least one dictionary, every slot
    /// is populated, and every underlying zipper is final at this position.
    fn is_final(&self) -> bool {
        !self.zippers.is_empty()
            && self
                .zippers
                .iter()
                .all(|slot| matches!(slot, Some(z) if z.is_final()))
    }

    /// Follows `label` in every underlying zipper.
    ///
    /// Returns `None` as soon as one dictionary cannot follow `label`, or when
    /// the intersection was built from zero dictionaries. The latter keeps
    /// `descend` consistent with `is_final` and `children`, which both treat
    /// an empty intersection as containing nothing.
    fn descend(&self, label: Self::Unit) -> Option<Self> {
        if self.zippers.is_empty() {
            return None;
        }
        // Collecting into `Option<Vec<_>>` short-circuits on the first
        // dictionary that has no edge for `label`.
        let zippers = self
            .zippers
            .iter()
            .map(|slot| slot.as_ref().and_then(|z| z.descend(label)).map(Some))
            .collect::<Option<Vec<Option<Z>>>>()?;
        let mut path = self.path.clone();
        path.push(label);
        Some(Self {
            zippers,
            path,
            strategy: self.strategy.clone(),
        })
    }

    /// Yields `(label, child)` for every label that all active dictionaries
    /// share at this position.
    ///
    /// Labels are always emitted in the order of their `Debug` rendering
    /// (string order, not numeric order), regardless of how many dictionaries
    /// take part, so the result is deterministic.
    fn children(&self) -> impl Iterator<Item = (Self::Unit, Self)> {
        let mut label_sets = self.zippers.iter().flatten().map(|z| {
            z.children()
                .map(|(label, _)| label)
                .collect::<HashSet<Z::Unit>>()
        });

        // Start from the first dictionary's labels and narrow in place.
        let mut common = label_sets.next().unwrap_or_default();
        for set in label_sets {
            if common.is_empty() {
                break;
            }
            common.retain(|label| set.contains(label));
        }

        let mut labels: Vec<Z::Unit> = common.into_iter().collect();
        // One formatted key per label instead of two per comparison.
        labels.sort_by_cached_key(|label| format!("{:?}", label));

        // The returned iterator owns its own copy so it does not borrow `self`.
        let parent = self.clone();
        labels
            .into_iter()
            .filter_map(move |label| parent.descend(label).map(|child| (label, child)))
    }

    /// Returns a copy of the labels followed from the root to this position.
    fn path(&self) -> Vec<Self::Unit> {
        self.path.clone()
    }
}
impl<Z: ValuedDictZipper, S: ValueMergeStrategy<Z::Value> + Clone + Send + Sync> ValuedDictZipper
    for IntersectionZipper<Z, S>
where
    Z::Value: Lattice,
{
    type Value = Z::Value;

    /// Returns the merged value at this position, or `None` when the position
    /// is not final in the intersection.
    ///
    /// Values are gathered from the populated slots in slot order and folded
    /// left-to-right with the configured strategy; zippers that report no
    /// value are skipped.
    fn value(&self) -> Option<Self::Value> {
        if self.is_final() {
            self.zippers
                .iter()
                .flatten()
                .filter_map(|z| z.value())
                .reduce(|merged, next| self.strategy.merge(merged, next))
        } else {
            None
        }
    }
}
/// Stack-driven iterator over the final positions of an
/// [`IntersectionZipper`], yielding each path together with its zipper.
pub struct IntersectionIterator<Z: DictZipper, S = LatticeMeet> {
/// Positions still waiting to be visited; the most recently pushed one is
/// visited next.
stack: Vec<IntersectionZipper<Z, S>>,
/// Paths that have already been yielded, used to suppress duplicates.
seen: HashSet<Vec<Z::Unit>>,
}
impl<Z: DictZipper, S: Clone + Send + Sync> IntersectionIterator<Z, S> {
    /// Creates an iterator whose traversal starts at `zipper`.
    fn new(zipper: IntersectionZipper<Z, S>) -> Self {
        // Initial room reserved for pending positions.
        const INITIAL_STACK_CAPACITY: usize = 16;

        let mut pending = Vec::with_capacity(INITIAL_STACK_CAPACITY);
        pending.push(zipper);
        Self {
            stack: pending,
            seen: HashSet::new(),
        }
    }
}
impl<Z: DictZipper, S: Clone + Send + Sync> Iterator for IntersectionIterator<Z, S> {
    type Item = (Vec<Z::Unit>, IntersectionZipper<Z, S>);

    /// Pops positions off the stack until one is final and has not been
    /// yielded before; returns `None` once the stack is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let node = self.stack.pop()?;

            // Children are queued before the node itself is inspected.
            self.stack
                .extend(node.children().map(|(_label, child)| child));

            if !node.is_final() {
                continue;
            }
            let path = node.path();
            if self.seen.insert(path.clone()) {
                return Some((path, node));
            }
        }
    }
}
/// Iterator adapter that yields `(path, value)` pairs instead of
/// `(path, zipper)` pairs.
pub struct ValuedIntersectionIterator<Z: ValuedDictZipper, S> {
/// Underlying path iterator whose zippers are queried for their value.
inner: IntersectionIterator<Z, S>,
}
impl<Z: ValuedDictZipper, S: ValueMergeStrategy<Z::Value> + Clone + Send + Sync>
    ValuedIntersectionIterator<Z, S>
where
    Z::Value: Lattice,
{
    /// Creates a valued iterator whose traversal starts at `zipper`.
    pub fn new(zipper: IntersectionZipper<Z, S>) -> Self {
        let inner = IntersectionIterator::new(zipper);
        Self { inner }
    }
}
impl<Z: ValuedDictZipper, S: ValueMergeStrategy<Z::Value> + Clone + Send + Sync> Iterator
    for ValuedIntersectionIterator<Z, S>
where
    Z::Value: Lattice,
{
    type Item = (Vec<Z::Unit>, Z::Value);

    /// Advances the inner iterator until it produces a position that has a
    /// value, and returns that path with its value. Positions without a value
    /// are skipped; `None` is returned once the inner iterator is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        self.inner
            .find_map(|(path, zipper)| zipper.value().map(|value| (path, value)))
    }
}
/// Extension methods for building an [`IntersectionZipper`] directly from a
/// zipper value.
pub trait IntersectionZipperExt: DictZipper + Sized {
    /// Intersects `self` with one other zipper using the default strategy.
    fn intersection_with(self, other: Self) -> IntersectionZipper<Self> {
        let pair = vec![self, other];
        IntersectionZipper::new(pair)
    }

    /// Intersects `self` with every zipper in `others` using the default
    /// strategy. `self` occupies the first slot, followed by `others` in
    /// iteration order.
    fn intersection_all(self, others: impl IntoIterator<Item = Self>) -> IntersectionZipper<Self> {
        let zippers: Vec<Self> = std::iter::once(self).chain(others).collect();
        IntersectionZipper::new(zippers)
    }
}

/// Every `DictZipper` gets the intersection helpers.
impl<Z: DictZipper> IntersectionZipperExt for Z {}
/// Extension methods for building a valued [`IntersectionZipper`] with an
/// explicit merge strategy.
pub trait ValuedIntersectionZipperExt: ValuedDictZipper + Sized {
    /// Intersects `self` with `other`, combining their values with `strategy`.
    fn intersection_with_strategy<S: ValueMergeStrategy<Self::Value> + Clone + Send + Sync>(
        self,
        other: Self,
        strategy: S,
    ) -> IntersectionZipper<Self, S> {
        let pair = vec![self, other];
        IntersectionZipper::with_strategy(pair, strategy)
    }
}

/// Every `ValuedDictZipper` gets the strategy-aware intersection helper.
impl<Z: ValuedDictZipper> ValuedIntersectionZipperExt for Z {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::double_array_trie::DoubleArrayTrie;
    use crate::double_array_trie_zipper::DoubleArrayTrieZipper;
    use crate::union_zipper::LatticeJoin;

    /// Enumerates every term of `zipper` as a UTF-8 string, sorted ascending
    /// so assertions do not depend on traversal order.
    fn sorted_terms<Z: DictZipper<Unit = u8>>(zipper: &IntersectionZipper<Z>) -> Vec<String> {
        let mut terms: Vec<String> = zipper
            .iter()
            .map(|(path, _)| String::from_utf8(path).unwrap())
            .collect();
        terms.sort();
        terms
    }

    /// Descends from `start` along each byte of `word`; `None` as soon as one
    /// step is missing (or when `word` is empty).
    fn walk<Z: DictZipper<Unit = u8>>(start: &Z, word: &[u8]) -> Option<Z> {
        let (&first, rest) = word.split_first()?;
        rest.iter()
            .try_fold(start.descend(first)?, |node, &byte| node.descend(byte))
    }

    /// Only the terms present in both dictionaries are enumerated.
    #[test]
    fn test_intersection_basic() {
        let dict1 = DoubleArrayTrie::from_terms(vec!["cat", "dog", "fish"].iter());
        let dict2 = DoubleArrayTrie::from_terms(vec!["cat", "fish", "bird"].iter());
        let intersection = IntersectionZipper::new(vec![
            DoubleArrayTrieZipper::new_from_dict(&dict1),
            DoubleArrayTrieZipper::new_from_dict(&dict2),
        ]);

        assert_eq!(sorted_terms(&intersection), vec!["cat", "fish"]);
    }

    /// Dictionaries with no shared term intersect to nothing.
    #[test]
    fn test_intersection_disjoint() {
        let dict1 = DoubleArrayTrie::from_terms(vec!["cat", "dog"].iter());
        let dict2 = DoubleArrayTrie::from_terms(vec!["fish", "bird"].iter());
        let left = DoubleArrayTrieZipper::new_from_dict(&dict1);
        let right = DoubleArrayTrieZipper::new_from_dict(&dict2);

        let intersection = left.intersection_with(right);

        assert_eq!(intersection.iter().count(), 0);
    }

    /// Intersecting identical dictionaries returns every term.
    #[test]
    fn test_intersection_identical() {
        let dict1 = DoubleArrayTrie::from_terms(vec!["cat", "dog"].iter());
        let dict2 = DoubleArrayTrie::from_terms(vec!["cat", "dog"].iter());
        let left = DoubleArrayTrieZipper::new_from_dict(&dict1);
        let right = DoubleArrayTrieZipper::new_from_dict(&dict2);

        let intersection = left.intersection_with(right);

        assert_eq!(sorted_terms(&intersection), vec!["cat", "dog"]);
    }

    /// An empty dictionary makes the whole intersection empty.
    #[test]
    fn test_intersection_empty_dict() {
        let dict1 = DoubleArrayTrie::from_terms(vec!["cat", "dog"].iter());
        let dict2: DoubleArrayTrie = DoubleArrayTrie::new();
        let left = DoubleArrayTrieZipper::new_from_dict(&dict1);
        let right = DoubleArrayTrieZipper::new_from_dict(&dict2);

        let intersection = left.intersection_with(right);

        assert_eq!(intersection.iter().count(), 0);
    }

    /// A term must be present in all three dictionaries to survive.
    #[test]
    fn test_intersection_three_dicts() {
        let dict1 = DoubleArrayTrie::from_terms(vec!["cat", "dog", "fish", "bird"].iter());
        let dict2 = DoubleArrayTrie::from_terms(vec!["cat", "fish", "bird", "horse"].iter());
        let dict3 = DoubleArrayTrie::from_terms(vec!["cat", "bird", "snake"].iter());
        let first = DoubleArrayTrieZipper::new_from_dict(&dict1);
        let second = DoubleArrayTrieZipper::new_from_dict(&dict2);
        let third = DoubleArrayTrieZipper::new_from_dict(&dict3);

        let intersection = first.intersection_all(vec![second, third]);

        assert_eq!(sorted_terms(&intersection), vec!["bird", "cat"]);
    }

    /// Descending succeeds only along paths that exist in every dictionary.
    #[test]
    fn test_intersection_descend() {
        let dict1 = DoubleArrayTrie::from_terms(vec!["cat", "car"].iter());
        let dict2 = DoubleArrayTrie::from_terms(vec!["cat", "can"].iter());
        let left = DoubleArrayTrieZipper::new_from_dict(&dict1);
        let right = DoubleArrayTrieZipper::new_from_dict(&dict2);
        let intersection = left.intersection_with(right);

        let cat = walk(&intersection, b"cat");
        assert!(cat.is_some());
        let cat = cat.unwrap();
        assert!(cat.is_final());
        assert_eq!(cat.path(), b"cat".to_vec());

        let car = walk(&intersection, b"car");
        assert!(car.is_none());
    }

    /// A position is final only when it is final in every dictionary.
    #[test]
    fn test_intersection_is_final() {
        let dict1 = DoubleArrayTrie::from_terms(vec!["cat", "catch"].iter());
        let dict2 = DoubleArrayTrie::from_terms(vec!["catch"].iter());
        let left = DoubleArrayTrieZipper::new_from_dict(&dict1);
        let right = DoubleArrayTrieZipper::new_from_dict(&dict2);
        let intersection = left.intersection_with(right);

        let cat = walk(&intersection, b"cat");
        assert!(cat.is_some());
        let cat = cat.unwrap();
        assert!(!cat.is_final());

        let catch = walk(&cat, b"ch");
        assert!(catch.is_some());
        assert!(catch.unwrap().is_final());
    }

    /// `children` reports only the labels shared by both dictionaries.
    #[test]
    fn test_intersection_children() {
        let dict1 = DoubleArrayTrie::from_terms(vec!["ab", "ac", "ad"].iter());
        let dict2 = DoubleArrayTrie::from_terms(vec!["ab", "ac", "ae"].iter());
        let left = DoubleArrayTrieZipper::new_from_dict(&dict1);
        let right = DoubleArrayTrieZipper::new_from_dict(&dict2);
        let intersection = left.intersection_with(right);

        let a = intersection.descend(b'a').unwrap();
        let mut labels: Vec<u8> = a.children().map(|(label, _)| label).collect();
        labels.sort();

        assert_eq!(labels, vec![b'b', b'c']);
    }

    /// With the default strategy the smaller of the two values is reported.
    #[test]
    fn test_valued_intersection_lattice_meet() {
        let dict1 =
            DoubleArrayTrie::from_terms_with_values(vec![("cat", 85u32), ("dog", 50)].into_iter());
        let dict2 =
            DoubleArrayTrie::from_terms_with_values(vec![("cat", 92u32), ("dog", 60)].into_iter());
        let intersection = IntersectionZipper::new(vec![
            DoubleArrayTrieZipper::new_from_dict(&dict1),
            DoubleArrayTrieZipper::new_from_dict(&dict2),
        ]);

        let cat = walk(&intersection, b"cat").unwrap();
        assert_eq!(cat.value(), Some(85));

        let dog = walk(&intersection, b"dog").unwrap();
        assert_eq!(dog.value(), Some(50));
    }

    /// With `LatticeJoin` the larger of the two values is reported.
    #[test]
    fn test_valued_intersection_lattice_join() {
        let dict1 =
            DoubleArrayTrie::from_terms_with_values(vec![("cat", 85u32), ("dog", 50)].into_iter());
        let dict2 =
            DoubleArrayTrie::from_terms_with_values(vec![("cat", 92u32), ("dog", 60)].into_iter());
        let intersection = IntersectionZipper::with_strategy(
            vec![
                DoubleArrayTrieZipper::new_from_dict(&dict1),
                DoubleArrayTrieZipper::new_from_dict(&dict2),
            ],
            LatticeJoin,
        );

        let cat = walk(&intersection, b"cat").unwrap();
        assert_eq!(cat.value(), Some(92));
    }

    /// Set-valued entries are merged to their common elements.
    #[test]
    fn test_valued_intersection_hashset() {
        let dict1 = DoubleArrayTrie::from_terms_with_values(
            vec![("key", HashSet::from([1, 2, 3]))].into_iter(),
        );
        let dict2 = DoubleArrayTrie::from_terms_with_values(
            vec![("key", HashSet::from([2, 3, 4]))].into_iter(),
        );
        let intersection = IntersectionZipper::new(vec![
            DoubleArrayTrieZipper::new_from_dict(&dict1),
            DoubleArrayTrieZipper::new_from_dict(&dict2),
        ]);

        let key = walk(&intersection, b"key").unwrap();

        assert_eq!(key.value(), Some(HashSet::from([2, 3])));
    }

    /// Both counters report the number of dictionaries supplied.
    #[test]
    fn test_dictionary_count() {
        let dict1 = DoubleArrayTrie::from_terms(vec!["cat"].iter());
        let dict2 = DoubleArrayTrie::from_terms(vec!["cat"].iter());
        let dict3 = DoubleArrayTrie::from_terms(vec!["cat"].iter());
        let first = DoubleArrayTrieZipper::new_from_dict(&dict1);
        let second = DoubleArrayTrieZipper::new_from_dict(&dict2);
        let third = DoubleArrayTrieZipper::new_from_dict(&dict3);

        let intersection = first.intersection_all(vec![second, third]);

        assert_eq!(intersection.dictionary_count(), 3);
        assert_eq!(intersection.active_dictionary_count(), 3);
    }

    /// Shared prefixes do not leak terms that exist in only one dictionary.
    #[test]
    fn test_intersection_preserves_prefix_structure() {
        let dict1 = DoubleArrayTrie::from_terms(vec!["apple", "application", "apply"].iter());
        let dict2 = DoubleArrayTrie::from_terms(vec!["apple", "apply", "apt"].iter());
        let left = DoubleArrayTrieZipper::new_from_dict(&dict1);
        let right = DoubleArrayTrieZipper::new_from_dict(&dict2);

        let intersection = left.intersection_with(right);

        assert_eq!(sorted_terms(&intersection), vec!["apple", "apply"]);
    }
}