1use std::marker::PhantomData;
11
12use crate::{prefix::mask_from_prefix_len, AsView, Prefix};
13
14use super::{TrieView, ViewIter};
15
/// A view that pairs two underlying views, `left` and `right`, and exposes
/// them together as a single view.
///
/// The lifetime `'a` is not used by any stored value directly; it is carried
/// only through the zero-sized `_phantom` marker.
#[derive(Clone)]
pub struct IntersectionView<'a, L, R> {
    /// Left-hand operand of the intersection.
    left: L,
    /// Right-hand operand of the intersection.
    right: R,
    /// Zero-sized marker tying the otherwise unused lifetime `'a` to the type.
    _phantom: PhantomData<&'a ()>,
}
29
30impl<'a, L, R> IntersectionView<'a, L, R>
31where
32 L: TrieView<'a>,
33 R: TrieView<'a, P = L::P>,
34{
35 pub(crate) fn new(left: L, right: R) -> Option<Self> {
41 let (left, right) = align(left, right)?;
42 Some(Self {
43 left,
44 right,
45 _phantom: PhantomData,
46 })
47 }
48}
49
50fn align<'a, L, R>(left: L, right: R) -> Option<(L, R)>
54where
55 L: TrieView<'a>,
56 R: TrieView<'a, P = L::P>,
57{
58 let min_prefix_len = left.prefix_len().min(right.prefix_len());
60 let mask = mask_from_prefix_len(min_prefix_len as u8);
61 if left.key() & mask != right.key() & mask {
62 return None; }
64
65 if left.depth() < right.depth() {
67 let left = left.navigate_to(right.key(), right.prefix_len())?;
68 Some((left, right))
69 } else if right.depth() < left.depth() {
70 let right = right.navigate_to(left.key(), left.prefix_len())?;
71 Some((left, right))
72 } else if left.prefix_len() < right.prefix_len() {
73 let left = left.navigate_to(right.key(), right.prefix_len())?;
74 Some((left, right))
75 } else if right.prefix_len() < left.prefix_len() {
76 let right = right.navigate_to(left.key(), left.prefix_len())?;
77 Some((left, right))
78 } else {
79 Some((left, right))
80 }
81}
82
83impl<'a, L, R> TrieView<'a> for IntersectionView<'a, L, R>
84where
85 L: TrieView<'a>,
86 R: TrieView<'a, P = L::P>,
87{
88 type P = L::P;
89 type T = (L::T, R::T);
90
91 #[inline]
92 fn depth(&self) -> u32 {
93 self.left.depth()
94 }
95
96 #[inline]
97 fn key(&self) -> <L::P as Prefix>::R {
98 self.left.key()
99 }
100
101 #[inline]
102 fn prefix_len(&self) -> u32 {
103 self.left.prefix_len()
104 }
105
106 #[inline]
108 fn data_bitmap(&self) -> u32 {
109 self.left.data_bitmap() & self.right.data_bitmap()
110 }
111
112 #[inline]
114 fn child_bitmap(&self) -> u32 {
115 self.left.child_bitmap() & self.right.child_bitmap()
116 }
117
118 #[inline]
119 unsafe fn get_data(&mut self, data_bit: u32) -> (L::T, R::T) {
120 unsafe { (self.left.get_data(data_bit), self.right.get_data(data_bit)) }
122 }
123
124 #[inline]
125 unsafe fn get_child(&mut self, child_bit: u32) -> Self {
126 unsafe {
128 Self {
129 left: self.left.get_child(child_bit),
130 right: self.right.get_child(child_bit),
131 _phantom: PhantomData,
132 }
133 }
134 }
135
136 #[inline]
137 unsafe fn reposition(&mut self, key: <L::P as Prefix>::R, prefix_len: u32) {
138 unsafe {
140 self.left.reposition(key, prefix_len);
141 self.right.reposition(key, prefix_len);
142 }
143 }
144}
145
146impl<'a, L, R> IntoIterator for IntersectionView<'a, L, R>
147where
148 L: TrieView<'a>,
149 R: TrieView<'a, P = L::P>,
150{
151 type Item = (L::P, (L::T, R::T));
152 type IntoIter = ViewIter<'a, IntersectionView<'a, L, R>>;
153
154 fn into_iter(self) -> Self::IntoIter {
155 self.iter()
156 }
157}
158
159impl<'a, L, R> AsView<'a> for IntersectionView<'a, L, R>
160where
161 L: TrieView<'a>,
162 R: TrieView<'a, P = L::P>,
163{
164 type P = L::P;
165 type View = Self;
166
167 fn view(self) -> Self {
168 self
169 }
170}
171
#[cfg(test)]
mod tests {
    use crate::{
        trieview::{AsView, TrieView},
        Prefix, PrefixMap,
    };

    /// Prefix type used by every test in this module.
    type P = (u32, u8);

    /// Shorthand for building a prefix from its representation and length.
    fn p(repr: u32, len: u8) -> P {
        P::from_repr_len(repr, len)
    }

    /// Builds a map from `(repr, len, value)` triples, inserted in slice order.
    fn map_from(entries: &[(u32, u8, i32)]) -> PrefixMap<P, i32> {
        let mut map = PrefixMap::new();
        for (repr, len, value) in entries.iter().copied() {
            map.insert(p(repr, len), value);
        }
        map
    }

    #[test]
    fn intersection_basic() {
        // The /8 and /16 under 0x0a are present in both maps; 0x0b/8 and 0x0c/8
        // each exist on one side only and must not show up.
        let lhs = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2), (0x0b000000, 8, 9)]);
        let rhs = map_from(&[
            (0x0a000000, 8, 10),
            (0x0a010000, 16, 20),
            (0x0c000000, 8, 99),
        ]);

        let isect = lhs.view().intersection(rhs.view()).unwrap();
        let got: Vec<_> = isect
            .into_iter()
            .map(|(prefix, (l, r))| (prefix, (*l, *r)))
            .collect();

        let expected = vec![(p(0x0a000000, 8), (1, 10)), (p(0x0a010000, 16), (2, 20))];
        assert_eq!(got, expected);
    }

    #[test]
    fn intersection_no_common_entries() {
        // Intersecting two whole-map views succeeds (the `unwrap` must not
        // panic), but the maps share no prefix, so nothing is yielded.
        let lhs = map_from(&[(0x0a000000, 8, 1)]);
        let rhs = map_from(&[(0x0b000000, 8, 2)]);

        let isect = lhs.view().intersection(rhs.view()).unwrap();
        let first = isect.into_iter().next();
        assert!(first.is_none());
    }

    #[test]
    fn intersection_disjoint_subviews_is_none() {
        // Sub-views rooted at different /8 prefixes cannot be intersected at all.
        let lhs = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
        let rhs = map_from(&[(0x0b000000, 8, 10), (0x0b010000, 16, 20)]);

        let lhs_view = lhs.view_at(&p(0x0a000000, 8)).unwrap();
        let rhs_view = rhs.view_at(&p(0x0b000000, 8)).unwrap();

        let isect = lhs_view.intersection(rhs_view);
        assert!(isect.is_none());
    }

    #[test]
    fn intersection_into_iter_for_loop() {
        // The view is used directly as the `for` loop source.
        let lhs = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2)]);
        let rhs = map_from(&[(0x0a000000, 8, 10), (0x0a010000, 16, 20)]);

        let mut seen = 0;
        if let Some(isect) = lhs.view().intersection(rhs.view()) {
            for (_prefix, (_l, _r)) in isect {
                seen += 1;
            }
        }
        assert_eq!(seen, 2);
    }

    #[test]
    fn intersection_composed() {
        // Intersecting an intersection with a third view: only 0x0a/8 is in all
        // three maps. The item value nests as `((a, b), c)`.
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2), (0x0b000000, 8, 3)]);
        let b = map_from(&[
            (0x0a000000, 8, 10),
            (0x0a010000, 16, 20),
            (0x0c000000, 8, 30),
        ]);
        let c = map_from(&[(0x0a000000, 8, 100), (0x0b000000, 8, 200)]);

        let ab = a.view().intersection(b.view()).unwrap();
        let abc = ab.intersection(c.view()).unwrap();

        let got: Vec<_> = abc
            .into_iter()
            .map(|(prefix, ((from_a, _from_b), from_c))| (prefix, (*from_a, *from_c)))
            .collect();
        assert_eq!(got, vec![(p(0x0a000000, 8), (1, 100))]);
    }

    #[test]
    fn intersection_find_then_iter() {
        // After `find` on 0x0a01/16, iteration is limited to the shared entries
        // at or below that prefix.
        let a = map_from(&[
            (0x0a000000, 8, 1),
            (0x0a010000, 16, 2),
            (0x0a010100, 24, 3),
            (0x0a020000, 16, 4),
            (0x0b000000, 8, 5),
        ]);
        let b = map_from(&[
            (0x0a000000, 8, 10),
            (0x0a010000, 16, 20),
            (0x0a010100, 24, 30),
            (0x0a030000, 16, 40),
            (0x0c000000, 8, 50),
        ]);

        let isect = a.view().intersection(b.view()).unwrap();
        let subtree = isect.find(&p(0x0a010000, 16)).unwrap();

        let sub: Vec<_> = subtree
            .into_iter()
            .map(|(prefix, (l, r))| (prefix, (*l, *r)))
            .collect();
        let expected = vec![(p(0x0a010000, 16), (2, 20)), (p(0x0a010100, 24), (3, 30))];
        assert_eq!(sub, expected);
    }

    #[test]
    fn intersection_find_exact_and_value() {
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2), (0x0a010100, 24, 3)]);
        let b = map_from(&[
            (0x0a000000, 8, 10),
            (0x0a010000, 16, 20),
            (0x0a020000, 16, 40),
        ]);

        let isect = a.view().intersection(b.view()).unwrap();

        // Present in both maps: the exact lookup succeeds and carries both values.
        let hit = isect.clone().find_exact(&p(0x0a010000, 16)).unwrap();
        let (l, r) = hit.value().unwrap();
        assert_eq!((*l, *r), (2, 20));

        // Present only in `a`: the exact lookup on the intersection fails.
        let miss = isect.find_exact(&p(0x0a010100, 24));
        assert!(miss.is_none());
    }

    #[test]
    fn intersection_mut_find_lpm_value_does_not_require_clone() {
        let mut a = map_from(&[(0x0a000000, 8, 1), (0x0a010100, 24, 3)]);
        let b = map_from(&[(0x0a000000, 8, 10), (0x0a010100, 24, 30)]);

        // The left side is a mutable view, so the value found by the
        // longest-prefix match can be updated in place through the intersection.
        let got = (&mut a)
            .view()
            .intersection(b.view())
            .unwrap()
            .find_lpm_value(&p(0x0a010180, 25))
            .map(|(prefix, (lhs_val, rhs_val))| {
                *lhs_val += *rhs_val;
                (prefix, *lhs_val, *rhs_val)
            });

        assert_eq!(got, Some((p(0x0a010100, 24), 33, 30)));
        // The update must be visible in the underlying map.
        assert_eq!(a.get(&p(0x0a010100, 24)), Some(&33));
    }

    #[test]
    fn intersection_iter_from_inclusive() {
        let a = map_from(&[
            (0x0a000000, 8, 1),
            (0x0a010000, 16, 2),
            (0x0a020000, 16, 3),
            (0x0a030000, 16, 4),
        ]);
        let b = map_from(&[
            (0x0a000000, 8, 10),
            (0x0a020000, 16, 30),
            (0x0a030000, 16, 40),
        ]);

        let isect = a.view().intersection(b.view()).unwrap();

        // Inclusive start: the start prefix itself is part of the output.
        let start = p(0x0a020000, 16);
        let from: Vec<_> = isect
            .iter_from(&start, true)
            .map(|(prefix, (l, r))| (prefix, (*l, *r)))
            .collect();
        let expected = vec![(p(0x0a020000, 16), (3, 30)), (p(0x0a030000, 16), (4, 40))];
        assert_eq!(from, expected);
    }

    #[test]
    fn intersection_iter_from_exclusive() {
        let a = map_from(&[(0x0a000000, 8, 1), (0x0a010000, 16, 2), (0x0a020000, 16, 3)]);
        let b = map_from(&[
            (0x0a000000, 8, 10),
            (0x0a010000, 16, 20),
            (0x0a020000, 16, 30),
        ]);

        let isect = a.view().intersection(b.view()).unwrap();

        // Exclusive start: the start prefix itself is skipped.
        let start = p(0x0a000000, 8);
        let from: Vec<_> = isect
            .iter_from(&start, false)
            .map(|(prefix, (l, r))| (prefix, (*l, *r)))
            .collect();
        let expected = vec![(p(0x0a010000, 16), (2, 20)), (p(0x0a020000, 16), (3, 30))];
        assert_eq!(from, expected);
    }

    #[test]
    fn intersection_iter_from_subview() {
        let a = map_from(&[
            (0x0a000000, 8, 1),
            (0x0a020000, 16, 2),
            (0x0a030000, 16, 3),
            (0x0b000000, 8, 4),
        ]);
        let b = map_from(&[
            (0x0a000000, 8, 10),
            (0x0a020000, 16, 20),
            (0x0a030000, 16, 30),
        ]);

        // Both sides are restricted to the same /15 before being intersected.
        let root = p(0x0a020000, 15);
        let isect = a
            .view_at(&root)
            .unwrap()
            .intersection(b.view_at(&root).unwrap())
            .unwrap();

        // Full iteration of the sub-view: only the two /16 entries appear.
        let all: Vec<_> = isect
            .clone()
            .iter()
            .map(|(prefix, (l, r))| (prefix, (*l, *r)))
            .collect();
        let expected_all = vec![(p(0x0a020000, 16), (2, 20)), (p(0x0a030000, 16), (3, 30))];
        assert_eq!(all, expected_all);

        // Exclusive start inside the sub-view: only the following entry remains.
        let from: Vec<_> = isect
            .iter_from(&p(0x0a020000, 16), false)
            .map(|(prefix, (l, r))| (prefix, (*l, *r)))
            .collect();
        assert_eq!(from, vec![(p(0x0a030000, 16), (3, 30))]);
    }
}