bitstring_trees/
iter.rs

1//! Iterators over bit string prefixes
2#![allow(clippy::bool_comparison)]
3
4use bitstring::BitString;
5
6/// Generate the smallest (ordered) list of prefixes covering first..=last
7// could also derive `Copy`, but iterators probably shouldn't be `Copy`?
8#[derive(Clone, Debug)]
9pub struct IterInclusive<K> {
10	// cover all values between `first 0*` and `last 1*`
11
12	// if iterator done: shared_len > first.len(), otherwise:
13	// * first[..shared_len] == last[..shared_len]
14	// * either:
15	//   - shared_len == first.len() == last.len()
16	//   - first[shared_len] == 0, no trailing "0"s after that in first and
17	//     last[shared_len] == 1, no trailing "1"s after that in last
18	first: K,          // no trailing "0"s (from shared_len+1 on)
19	last: K,           // no trailing "1"s (from shared_len+1 on)
20	shared_len: usize, // if longer than first: iterator done
21}
22
23impl<K> IterInclusive<K>
24where
25	K: BitString + Clone,
26{
27	fn empty() -> Self {
28		Self {
29			first: K::null(),
30			last: K::null(),
31			shared_len: 1,
32		}
33	}
34
35	fn all() -> Self {
36		Self {
37			first: K::null(),
38			last: K::null(),
39			shared_len: 0,
40		}
41	}
42}
43
44impl<K> Default for IterInclusive<K>
45where
46	K: BitString + Clone,
47{
48	fn default() -> Self {
49		Self::empty()
50	}
51}
52
53impl<K> Iterator for IterInclusive<K>
54where
55	K: BitString + Clone,
56{
57	type Item = K;
58
59	fn next(&mut self) -> Option<Self::Item> {
60		let first_len = self.first.len();
61		if self.shared_len > first_len {
62			return None;
63		}
64
65		// without shared prefix (of length shared_len) we should have one
66		// of these scenarios (first, last):
67		// 1. (""(0*), ""(1*)) -> yield final ""
68		// 2. ("0.*01*"(0*), "1.*"(1*)) -> yield first, -> "increment" first to "0.*1" (don't care about last)
69		// 3. ("01*"(0*), "10*"(1*)) -> yield first, -> set first = last
70		// 4. ("01*"(0*), "10*|1.*"(1*)) -> yield first, -> set first to "10*|0" from last (flipped bit after "|")
71
72		if first_len == self.shared_len {
73			let last_len = self.last.len();
74			if last_len == self.shared_len {
75				// scenario 1: yield final shared prefix
76				// mark as done
77				self.shared_len = !0;
78				self.first.clip(0);
79				return Some(self.last.clone());
80			} else {
81				debug_assert!(
82					last_len == self.shared_len,
83					"first was shared prefix, but last was longer"
84				);
85				return None; // invalid state
86			}
87		}
88		// scenario 2-4
89		let result = self.first.clone();
90		// increment first; drop all trailing "1"s, then flip trailing "0" to "1"
91		for pos in (self.shared_len + 1..first_len).rev() {
92			if false == self.first.get(pos) {
93				// scenario 2
94				self.first.clip(pos + 1); // drop trailing "1"s
95				self.first.flip(pos); // flip trailing "0" to "1"
96				return Some(result);
97			}
98		}
99
100		// scenario 3-4
101		if true == self.first.get(self.shared_len) {
102			debug_assert!(
103				!self.first.get(self.shared_len),
104				"first should have a '0' after shared prefix"
105			);
106			return None; // invalid state
107		}
108		if false == self.last.get(self.shared_len) {
109			debug_assert!(
110				self.last.get(self.shared_len),
111				"last should have a '1' after shared prefix"
112			);
113			return None; // invalid state
114		}
115
116		// copy first "1" and then as many "0"s as possible; flip next 1 if present, otherwise cut
117		self.first = self.last.clone();
118		let check_from = self.shared_len + 1; // skip leading "1"
119		self.shared_len = self.last.len(); // in case we don't find another "1" - take all (scenario 3)
120		for pos in check_from..self.shared_len {
121			if self.first.get(pos) {
122				// scenario 4
123				self.first.clip(pos + 1);
124				self.first.flip(pos);
125				self.shared_len = pos;
126				break;
127			}
128		}
129
130		Some(result)
131	}
132}
133
134/// Generate the smallest (ordered) list of prefixes covering first..=last
135///
136/// Generate smallest ordered list of prefixes to cover all
137/// values `v` with `start 0*` <= `end 1*`.
138///
139/// E.g. for IP addresses this results in the smallest list of CIDR
140/// blocks exactly covering a range.
141pub fn iter_inclusive<K>(mut first: K, mut last: K) -> IterInclusive<K>
142where
143	K: BitString + Clone,
144{
145	// trailing "0"s in `first` and trailing "1"s in `last` are semantically
146	// not important; but establish certain invariants during iteration.
147	// Also see struct notes.
148
149	// clip trailing "0"s from first
150	let mut first_len = first.len();
151	while first_len > 0 && false == first.get(first_len - 1) {
152		first_len -= 1;
153	}
154	first.clip(first_len);
155
156	// clip trailing "1"s from last
157	let mut last_len = last.len();
158	while last_len > 0 && true == last.get(last_len - 1) {
159		last_len -= 1;
160	}
161	last.clip(last_len);
162
163	let mut shared_len = first.shared_prefix_len(&last);
164
165	if shared_len == first_len {
166		// first is a prefix of last; include further "0"s from last into shared prefix
167		while shared_len < last_len && false == last.get(shared_len) {
168			shared_len += 1;
169		}
170		// copy "0"s to first
171		first = last.clone();
172
173		if shared_len == last_len {
174			// first == last, yield once
175			first.clip(shared_len);
176		} else {
177			// last continues with a "1...", make sure first continues with a "0"
178			first.clip(shared_len + 1);
179			first.flip(shared_len);
180		}
181	} else if shared_len == last_len {
182		// last is a prefix of first; include further "1"s from first into shared prefix
183		while shared_len < first_len && true == first.get(shared_len) {
184			shared_len += 1;
185		}
186		// copy "1"s to last
187		last = first.clone();
188
189		if shared_len == first_len {
190			// last == first, yield once
191			last.clip(shared_len);
192		} else {
193			// first continues with a "0...", make sure last continues with a "1"
194			last.clip(shared_len + 1);
195			last.flip(shared_len);
196		}
197	} else if first.get(shared_len) > last.get(shared_len) {
198		// wrong order: yield nothing
199		return IterInclusive::empty();
200	}
201	IterInclusive {
202		first,
203		last,
204		shared_len,
205	}
206}
207
208/// Generate smallest set of prefixes covering values between other prefixes
209///
210/// See [`iter_between`].
211#[derive(Clone, Debug)]
212pub struct IterBetween<K> {
213	range: IterInclusive<K>,
214}
215
216impl<K> Default for IterBetween<K>
217where
218	K: BitString + Clone,
219{
220	fn default() -> Self {
221		Self {
222			range: IterInclusive::empty(),
223		}
224	}
225}
226
227impl<K> Iterator for IterBetween<K>
228where
229	K: BitString + Clone,
230{
231	type Item = K;
232
233	#[inline]
234	fn next(&mut self) -> Option<Self::Item> {
235		self.range.next()
236	}
237}
238
239fn increment<K>(key: &mut K) -> bool
240where
241	K: BitString + Clone,
242{
243	// clip trailing "1"s, flip (then) trailing "0" to "1"
244	for pos in (0..key.len()).rev() {
245		if false == key.get(pos) {
246			key.clip(pos + 1);
247			key.flip(pos);
248			return true;
249		}
250	}
251	// only found "1"s (possibly empty key)
252	false
253}
254
255fn decrement<K>(key: &mut K) -> bool
256where
257	K: BitString + Clone,
258{
259	// clip trailing "0"s, flip (then) trailing "1" to "0"
260	for pos in (0..key.len()).rev() {
261		if true == key.get(pos) {
262			key.clip(pos + 1);
263			key.flip(pos);
264			return true;
265		}
266	}
267	// only found "0"s (possibly empty key)
268	false
269}
270
271/// Generate smallest set of prefixes covering values between `start` and `end`
272///
273/// Pass `None` to cover all values before or after a prefix, or simply all values.
274pub fn iter_between<K>(mut after: Option<K>, mut before: Option<K>) -> IterBetween<K>
275where
276	K: BitString + Clone,
277{
278	if let Some(start) = after.as_mut() {
279		if !increment(start) {
280			return IterBetween {
281				range: IterInclusive::empty(),
282			};
283		}
284	}
285	if let Some(end) = before.as_mut() {
286		if !decrement(end) {
287			return IterBetween {
288				range: IterInclusive::empty(),
289			};
290		}
291	}
292
293	let range = match (after, before) {
294		(Some(first), Some(last)) => iter_inclusive(first, last),
295		(Some(first), None) => {
296			let mut last = first.clone();
297			// find first "0" and flip to "1" (and clip). no "0" -> only "1"s, yield first == last
298			for pos in 0..last.len() {
299				if false == last.get(pos) {
300					last.clip(pos + 1);
301					last.flip(pos);
302					break;
303				}
304			}
305			iter_inclusive(first, last)
306		},
307		(None, Some(last)) => {
308			let mut first = last.clone();
309			// find first "1" and flip to "0" (and clip). no "1" -> only "0"s, yield first == last
310			for pos in 0..first.len() {
311				if true == first.get(pos) {
312					first.clip(pos + 1);
313					first.flip(pos);
314					break;
315				}
316			}
317			iter_inclusive(first, last)
318		},
319		(None, None) => IterInclusive::all(),
320	};
321	IterBetween { range }
322}
323
324#[cfg(test)]
325mod tests {
326	use super::iter_inclusive;
327	use alloc::vec::Vec;
328	use bitstring::BitLengthString;
329	use core::net::{
330		Ipv4Addr,
331		Ipv6Addr,
332	};
333
334	type Ipv4Cidr = BitLengthString<Ipv4Addr>;
335	type Ipv6Cidr = BitLengthString<Ipv6Addr>;
336
337	fn c4(a: &str, net: usize) -> Ipv4Cidr {
338		Ipv4Cidr::new(a.parse().unwrap(), net)
339	}
340
341	fn c6(a: &str, net: usize) -> Ipv6Cidr {
342		Ipv6Cidr::new(a.parse().unwrap(), net)
343	}
344
345	#[test]
346	fn testv4_1() {
347		assert_eq!(
348			iter_inclusive(c4("192.168.0.6", 32), c4("192.168.0.6", 32)).collect::<Vec<_>>(),
349			alloc::vec![c4("192.168.0.6", 32),]
350		);
351	}
352
353	#[test]
354	fn testv6_1() {
355		assert_eq!(
356			iter_inclusive(c6("::f0:4", 128), c6("::f0:10", 128)).collect::<Vec<_>>(),
357			alloc::vec![c6("::f0:4", 126), c6("::f0:8", 125), c6("::f0:10", 128),]
358		);
359	}
360}