use std::marker::PhantomData;
use crate::{
prefix::mask_from_prefix_len,
Prefix,
{
node::{
child_bit as node_child_bit, child_cover_mask_for_bit, data_cover_mask_for_bit,
data_lpm_mask,
},
AsView,
},
};
use super::{iter::ViewIter, TrieView};
/// A view over `left` that hides every entry of `left` covered by an entry of `right`.
///
/// A `left` entry is hidden when `right` contains a prefix that contains it — an identical
/// prefix counts as covering. For example, `10.0.0.0/23` in `right` hides `10.0.0.0/24` in
/// `left` but keeps the less specific `10.0.0.0/22`. Values are always taken from `left`.
#[derive(Clone)]
pub struct CoveringDifferenceView<'a, L, R> {
    // The view whose entries (and values) are exposed.
    left: L,
    // The right-hand view, kept at the same depth as `left` or deeper. `None` when no
    // right-hand node needs to be consulted below this point (disjoint, no matching path,
    // fully covered, or already handed down to a child by `get_child`).
    right: Option<R>,
    // Data bits of `left`'s current node that are covered by `right` and must be hidden.
    cov_data: u32,
    // Child bits of `left`'s current node whose whole subtree is covered by `right`.
    cov_child: u32,
    // Ties the view to the `'a` lifetime of the trie it borrows from.
    _phantom: PhantomData<&'a ()>,
}
impl<'a, L, R> CoveringDifferenceView<'a, L, R>
where
    L: TrieView<'a>,
    R: TrieView<'a, P = L::P>,
{
    /// Creates a view of `left` with every entry covered by `right` removed.
    ///
    /// `right` is first aligned to `left`'s node depth with `align_right`. The coverage
    /// masks for `left`'s current node are then derived from the aligned view.
    pub(crate) fn new(left: L, right: R) -> Self {
        let (aligned, fully_covered) = align_right(&left, right);
        let (cov_data, cov_child) = r_coverage_masks(&left, aligned.as_ref(), fully_covered);
        Self {
            cov_data,
            cov_child,
            right: aligned,
            left,
            _phantom: PhantomData,
        }
    }
}
/// Aligns `right` with the node `left` currently points at.
///
/// Returns `(aligned, covered)`:
/// - `(None, false)` when the two views disagree on their common prefix bits, or when
///   `right` has no child on the path down to `left`'s node. In both cases `right` cannot
///   affect `left`.
/// - `(None, true)` when, while walking a shallower `right` down toward `left`, a node of
///   `right` has a data bit selected by `data_lpm_mask`. Presumably these are the entries
///   that are prefixes of `left.key()` (TODO confirm), so all of `left` is covered.
/// - `(Some(view), false)` otherwise, where `view.depth() >= left.depth()`.
fn align_right<'a, L, R>(left: &L, mut right: R) -> (Option<R>, bool)
where
    L: TrieView<'a>,
    R: TrieView<'a, P = L::P>,
{
    let shared_len = left.prefix_len().min(right.prefix_len());
    let shared_mask = mask_from_prefix_len(shared_len as u8);
    if left.key() & shared_mask != right.key() & shared_mask {
        // Disjoint prefixes: nothing in `right` can cover anything in `left`.
        return (None, false);
    }

    // Already at (or below) `left`'s node: no walking needed.
    if right.depth() >= left.depth() {
        return (Some(right), false);
    }

    // `right` is shallower: descend along `left.key()` until it reaches `left`'s depth.
    loop {
        let on_path = data_lpm_mask(right.depth(), left.key(), left.depth());
        if right.data_bitmap() & on_path != 0 {
            return (None, true);
        }
        if right.depth() == left.depth() {
            return (Some(right), false);
        }
        let next = node_child_bit(right.depth(), left.key());
        if (right.child_bitmap() >> next) & 1 == 0 {
            return (None, false);
        }
        // SAFETY: `next` was just checked to be set in `right.child_bitmap()`.
        right = unsafe { right.get_child(next) };
    }
}
/// Computes `(cov_data, cov_child)` for `left`'s current node.
///
/// - `covered == true`: everything is hidden. All 31 data bits (`0x7FFF_FFFF`) and all
///   32 child bits (`0xFFFF_FFFF`) are masked.
/// - No `right`, or a `right` at a different depth than `left`: nothing is masked here.
/// - Otherwise, every data bit set in `right` contributes the data and child bits it covers,
///   as reported by `data_cover_mask_for_bit` and `child_cover_mask_for_bit`.
fn r_coverage_masks<'a, L, R>(left: &L, right: Option<&R>, covered: bool) -> (u32, u32)
where
    L: TrieView<'a>,
    R: TrieView<'a, P = L::P>,
{
    if covered {
        return (0x7FFF_FFFF, 0xFFFF_FFFF);
    }
    match right {
        Some(r) if r.depth() == left.depth() => {
            let r_data = r.data_bitmap();
            (0..u32::BITS)
                .filter(|&bit| (r_data >> bit) & 1 != 0)
                .fold((0u32, 0u32), |(data, child), bit| {
                    (
                        data | data_cover_mask_for_bit(bit),
                        child | child_cover_mask_for_bit(bit),
                    )
                })
        }
        _ => (0, 0),
    }
}
impl<'a, L, R> TrieView<'a> for CoveringDifferenceView<'a, L, R>
where
L: TrieView<'a>,
R: TrieView<'a, P = L::P>,
{
type P = L::P;
type T = L::T;
#[inline]
fn depth(&self) -> u32 {
self.left.depth()
}
#[inline]
fn key(&self) -> <L::P as Prefix>::R {
self.left.key()
}
#[inline]
fn prefix_len(&self) -> u32 {
self.left.prefix_len()
}
#[inline]
fn data_bitmap(&self) -> u32 {
self.left.data_bitmap() & !self.cov_data
}
#[inline]
fn child_bitmap(&self) -> u32 {
self.left.child_bitmap() & !self.cov_child
}
#[inline]
unsafe fn get_data(&mut self, data_bit: u32) -> L::T {
self.left.get_data(data_bit)
}
unsafe fn get_child(&mut self, child_bit: u32) -> Self {
let l_child = self.left.get_child(child_bit);
let r_child = match &mut self.right {
None => None,
Some(r) => {
if r.depth() == self.left.depth() {
if (r.child_bitmap() >> child_bit) & 1 == 1 {
Some(r.get_child(child_bit))
} else {
None
}
} else {
let toward_r = node_child_bit(self.left.depth(), r.key());
if child_bit == toward_r {
Some(self.right.take().unwrap())
} else {
None
}
}
}
};
let (cov_data, cov_child) = r_coverage_masks(&l_child, r_child.as_ref(), false);
CoveringDifferenceView {
left: l_child,
right: r_child,
cov_data,
cov_child,
_phantom: PhantomData,
}
}
unsafe fn reposition(&mut self, key: <L::P as Prefix>::R, prefix_len: u32) {
let left_depth = self.left.depth();
unsafe {
self.left.reposition(key, prefix_len);
if let Some(r) = self.right.as_mut() {
if r.depth() == left_depth {
r.reposition(key, prefix_len)
}
}
}
}
}
/// Consuming a covering-difference view yields the `(prefix, value)` pairs of `left`
/// that are not covered by `right`, by delegating to the view's `iter()`.
impl<'a, L, R> IntoIterator for CoveringDifferenceView<'a, L, R>
where
    L: TrieView<'a>,
    R: TrieView<'a, P = L::P>,
{
    type Item = (L::P, L::T);
    type IntoIter = ViewIter<'a, Self>;

    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}
/// A covering-difference view is already a view, so converting it is the identity.
impl<'a, L, R> AsView<'a> for CoveringDifferenceView<'a, L, R>
where
    L: TrieView<'a>,
    R: TrieView<'a, P = L::P>,
{
    type P = L::P;
    type View = Self;

    fn view(self) -> Self::View {
        self
    }
}
// Unit tests for `CoveringDifferenceView`. Prefixes are `(u32, u8)` (repr, length) pairs.
// Each test builds a left map `a` and a right map `b` and checks which entries of `a`
// survive `covering_difference`.
#[cfg(test)]
mod tests {
use crate::{
Prefix,
{
trieview::{AsView, TrieView},
PrefixMap,
},
};
type P = (u32, u8);
// Builds a prefix from its raw representation and length.
fn p(repr: u32, len: u8) -> P {
P::from_repr_len(repr, len)
}
// Builds a map from `(repr, len, value)` triples.
fn map_from(entries: &[(u32, u8, i32)]) -> PrefixMap<P, i32> {
let mut m = PrefixMap::new();
for &(repr, len, val) in entries {
m.insert(p(repr, len), val);
}
m
}
// Copies values out of a borrowing iterator so results can be compared by value.
fn collect<'a>(iter: impl Iterator<Item = (P, &'a i32)>) -> Vec<(P, i32)> {
iter.map(|(p, v)| (p, *v)).collect()
}
// `b`'s 10.0.0.0/23 hides 10.0.0.0/24, but neither the less specific 10.0.0.0/22
// nor 10.0.2.0/23.
#[test]
fn covering_diff_basic() {
let a = map_from(&[
(0x0a000000, 22, 1), (0x0a000000, 24, 2), (0x0a000200, 23, 3), ]);
let b = map_from(&[(0x0a000000, 23, 99)]); let got = collect(a.view().covering_difference(b.view()).into_iter());
assert_eq!(got, vec![(p(0x0a000000, 22), 1), (p(0x0a000200, 23), 3),]);
}
// A right side disjoint from `a` hides nothing.
#[test]
fn covering_diff_no_overlap() {
let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
let b = map_from(&[(0x0b000000, 8, 99)]);
let got = collect(a.view().covering_difference(b.view()).into_iter());
assert_eq!(got, vec![(p(0x0a000000, 8), 1), (p(0x0a010000, 16), 2),]);
}
// An identical prefix in `b` counts as covering, so 10.1.0.0/16 is hidden.
#[test]
fn covering_diff_exact_match_excluded() {
let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2), (0x0a020000, 16, 3)]);
let b = map_from(&[(0x0a010000, 16, 99)]);
let got = collect(a.view().covering_difference(b.view()).into_iter());
assert_eq!(got, vec![(p(0x0a000000, 8), 1), (p(0x0a020000, 16), 3),]);
}
// `b`'s 10.0.0.0/8 covers every entry of `a`, so the view yields nothing.
#[test]
fn covering_diff_r_covers_everything() {
let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
let b = map_from(&[(0x0a000000, 8, 99)]);
assert!(a
.view()
.covering_difference(b.view())
.into_iter()
.next()
.is_none());
}
// An empty right side leaves `a` unchanged.
#[test]
fn covering_diff_r_empty() {
let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
let b: PrefixMap<P, i32> = PrefixMap::new();
let got = collect(a.view().covering_difference(b.view()).into_iter());
assert_eq!(got, vec![(p(0x0a000000, 8), 1), (p(0x0a010000, 16), 2),]);
}
// An empty left side yields nothing, whatever `b` contains.
#[test]
fn covering_diff_l_empty() {
let a: PrefixMap<P, i32> = PrefixMap::new();
let b = map_from(&[(0x0a000000, 8, 99)]);
assert!(a
.view()
.covering_difference(b.view())
.into_iter()
.next()
.is_none());
}
// 10.1.0.0/16 in `b` hides itself and the more specific 10.1.1.0/24. It does not hide
// the less specific 10.0.0.0/8 or the siblings.
#[test]
fn covering_diff_partial_coverage() {
let a = map_from(&[
(0x0a000000, 8, 1),
(0x0a010000, 16, 2),
(0x0a010100, 24, 3), (0x0a020000, 16, 4),
(0x0b000000, 8, 5),
]);
let b = map_from(&[(0x0a010000, 16, 99)]);
let got = collect(a.view().covering_difference(b.view()).into_iter());
assert_eq!(
got,
vec![
(p(0x0a000000, 8), 1),
(p(0x0a020000, 16), 4),
(p(0x0b000000, 8), 5),
]
);
}
// Two covering prefixes in `b` hide whole subtrees of `a` (all of 10.0.0.0/8 and
// 11.1.0.0/16). Unrelated entries and 11.0.0.0/8 itself remain.
#[test]
fn covering_diff_large_same_depth() {
let a = map_from(&[
(0x01000000, 8, 1),
(0x0a000000, 8, 10),
(0x0a000000, 16, 11),
(0x0a010000, 16, 12),
(0x0a010100, 24, 13),
(0x0a020000, 16, 14),
(0x0b000000, 8, 20),
(0x0b010000, 16, 21),
(0x64000000, 8, 100),
]);
let b = map_from(&[
(0x0a000000, 8, 99), (0x0b010000, 16, 99), ]);
let got = collect(a.view().covering_difference(b.view()).into_iter());
assert_eq!(
got,
vec![
(p(0x01000000, 8), 1),
(p(0x0b000000, 8), 20),
(p(0x64000000, 8), 100),
]
);
}
// `find` on the difference view narrows it to 10.0.0.0/8, and iterating the result
// still respects `b`'s coverage.
#[test]
fn covering_diff_find_then_iter() {
let a = map_from(&[
(0x0a000000, 8, 1),
(0x0a010000, 16, 2),
(0x0a010100, 24, 3), (0x0a020000, 16, 4),
]);
let b = map_from(&[(0x0a010000, 16, 99)]);
let sub = a
.view()
.covering_difference(b.view())
.find(&p(0x0a000000, 8));
assert!(sub.is_some());
let got = collect(sub.unwrap().into_iter());
assert_eq!(got, vec![(p(0x0a000000, 8), 1), (p(0x0a020000, 16), 4),]);
}
// With 10.1.1.0/24 hidden by `b`, the longest match for 10.1.1.128/25 falls back to
// 10.1.0.0/16. The returned value is a mutable reference into `a`, so the update is
// visible in `a` afterwards.
#[test]
fn covering_diff_mut_find_lpm_value_does_not_require_clone() {
let mut a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2), (0x0a010100, 24, 3)]);
let b = map_from(&[(0x0a010100, 24, 30)]);
let got = (&mut a)
.view()
.covering_difference(b.view())
.find_lpm_value(&p(0x0a010180, 25))
.map(|(prefix, value)| {
*value += 10;
(prefix, *value)
});
assert_eq!(got, Some((p(0x0a010000, 16), 12)));
assert_eq!(a.get(&p(0x0a010000, 16)), Some(&12));
}
// Left is a sub-view of `a` at 10.0.0.0/8 and right is `b`'s root view. `b`'s
// 10.0.0.0/8 covers the whole left view.
#[test]
fn covering_diff_right_shallower_covers() {
let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
let b = map_from(&[(0x0a000000, 8, 99)]);
let a_sub = a.view_at(&p(0x0a000000, 8)).unwrap();
let b_root = b.view();
assert!(a_sub
.covering_difference(b_root)
.into_iter()
.next()
.is_none());
}
// Same shapes as above, but `b` lies outside the left sub-view, so nothing is hidden.
#[test]
fn covering_diff_right_shallower_no_cover() {
let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
let b = map_from(&[(0x0b000000, 8, 99)]);
let a_sub = a.view_at(&p(0x0a000000, 8)).unwrap();
let b_root = b.view();
let got = collect(a_sub.covering_difference(b_root).into_iter());
assert_eq!(got, vec![(p(0x0a000000, 8), 1), (p(0x0a010000, 16), 2),]);
}
// The right root view only hides the part of the left sub-view under 10.1.0.0/16.
#[test]
fn covering_diff_right_shallower_partial() {
let a = map_from(&[
(0x0a000000, 8, 1),
(0x0a010000, 16, 2), (0x0a020000, 16, 3),
]);
let b = map_from(&[(0x0a010000, 16, 99), (0x0b000000, 8, 99)]);
let a_sub = a.view_at(&p(0x0a000000, 8)).unwrap();
let b_root = b.view();
let got = collect(a_sub.covering_difference(b_root).into_iter());
assert_eq!(got, vec![(p(0x0a000000, 8), 1), (p(0x0a020000, 16), 3),]);
}
// Left is `a`'s root view and right is a sub-view of `b` at 10.1.0.0/16. Only `a`'s
// 10.1.0.0/16 is hidden.
#[test]
fn covering_diff_left_shallower_right_deeper() {
let a = map_from(&[
(0x09000000, 8, 1),
(0x0a000000, 8, 2),
(0x0a010000, 16, 3), (0x0b000000, 8, 4),
]);
let b = map_from(&[(0x0a010000, 16, 99)]);
let a_root = a.view();
let b_sub = b.view_at(&p(0x0a010000, 16)).unwrap();
let got = collect(a_root.covering_difference(b_sub).into_iter());
assert_eq!(
got,
vec![
(p(0x09000000, 8), 1),
(p(0x0a000000, 8), 2), (p(0x0b000000, 8), 4),
]
);
}
// The covering difference with `b` leaves 10.0.0.0/8 and 11.0.0.0/8. The following
// `difference` with `c` then also removes 10.0.0.0/8.
#[test]
fn covering_diff_composed_with_difference() {
let a = map_from(&[
(0x0a000000, 8, 1),
(0x0a010000, 16, 2),
(0x0a010100, 24, 3), (0x0b000000, 8, 4),
]);
let b = map_from(&[(0x0a010000, 16, 99)]);
let c = map_from(&[(0x0a000000, 8, 99)]);
let got = collect(
a.view()
.covering_difference(b.view())
.difference(c.view())
.into_iter(),
);
assert_eq!(got, vec![(p(0x0b000000, 8), 4)]);
}
// Intersecting the covering difference with `c` pairs each surviving `a` value with
// `c`'s value for the same prefix.
#[test]
fn covering_diff_composed_with_intersection() {
let a = map_from(&[
(0x0a000000, 8, 1),
(0x0a010000, 16, 2),
(0x0a010100, 24, 3), (0x0b000000, 8, 4),
]);
let b = map_from(&[(0x0a010000, 16, 99)]);
let c = map_from(&[(0x0a000000, 8, 100), (0x0b000000, 8, 200)]);
let got: Vec<_> = a
.view()
.covering_difference(b.view())
.intersection(c.view())
.unwrap()
.into_iter()
.map(|(p, (l, r))| (p, *l, *r))
.collect();
assert_eq!(
got,
vec![(p(0x0a000000, 8), 1, 100), (p(0x0b000000, 8), 4, 200),]
);
}
// The view can be consumed directly by a `for` loop via `IntoIterator`.
#[test]
fn covering_diff_into_iter_for_loop() {
let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
let b = map_from(&[(0x0a010000, 16, 99)]);
let mut count = 0;
for _ in a.view().covering_difference(b.view()) {
count += 1;
}
assert_eq!(count, 1); }
// Right is `b` restricted to 0.0.0.0/1, so `b`'s /0 entry is outside the right view and
// must not hide anything. Only /2 is hidden.
#[test]
fn view_into_right_child() {
let a = map_from(&[(0x00000000, 0, 0), (0x00000000, 1, 1), (0x00000000, 2, 2)]);
let b = map_from(&[(0x00000000, 0, 0), (0x00000000, 2, 2)]);
let b_view = b.view_at(&p(0x00000000, 1)).unwrap();
let got = a
.view()
.covering_difference(b_view)
.iter()
.collect::<Vec<_>>();
let want = vec![(p(0x00000000, 0), &0), (p(0x00000000, 1), &1)];
assert_eq!(got, want);
}
// Left is `a` restricted to 0.0.0.0/1. `b`'s /2 hides only /2 within it.
#[test]
fn view_into_left_child() {
let a = map_from(&[(0x00000000, 0, 0), (0x00000000, 1, 1), (0x00000000, 2, 2)]);
let b = map_from(&[(0x00000000, 2, 2)]);
let a_view = a.view_at(&p(0x00000000, 1)).unwrap();
let got = a_view
.covering_difference(b.view())
.iter()
.collect::<Vec<_>>();
let want = vec![(p(0x00000000, 1), &1)];
assert_eq!(got, want);
}
// Same as `view_into_right_child`, with /5 and /6 prefixes instead of /1 and /2.
#[test]
fn view_into_right_child_deep() {
let a = map_from(&[(0x00000000, 0, 0), (0x00000000, 5, 5), (0x00000000, 6, 6)]);
let b = map_from(&[(0x00000000, 0, 0), (0x00000000, 6, 6)]);
let b_view = b.view_at(&p(0x00000000, 5)).unwrap();
let got = a
.view()
.covering_difference(b_view)
.iter()
.collect::<Vec<_>>();
let want = vec![(p(0x00000000, 0), &0), (p(0x00000000, 5), &5)];
assert_eq!(got, want);
}
// Same as `view_into_left_child`, with /5 and /6 prefixes instead of /1 and /2.
#[test]
fn view_into_left_child_deep() {
let a = map_from(&[(0x00000000, 0, 0), (0x00000000, 5, 5), (0x00000000, 6, 6)]);
let b = map_from(&[(0x00000000, 6, 6)]);
let a_view = a.view_at(&p(0x00000000, 5)).unwrap();
let got = a_view
.covering_difference(b.view())
.iter()
.collect::<Vec<_>>();
let want = vec![(p(0x00000000, 5), &5)];
assert_eq!(got, want);
}
}