hamming_heap/fixed_heap.rs
/// This keeps the nearest `cap` items at all times.
///
/// Items are stored in one bucket per hamming distance, so inserting is a plain `Vec::push` into the
/// bucket for that distance and evicting is a `Vec::pop` from the worst non-empty bucket.
///
/// This heap is not intended to be popped. Instead, this maintains the best `cap` items, and then when you are
/// done adding items, you may fill a slice or iterate over the results. Theoretically, this could also allow
/// popping elements in constant time, but that would incur a performance penalty for the highly specialized
/// purpose this serves. This is specifically tailored for doing hamming space nearest neighbor searches.
///
/// To use this you will need to call `set_distances` (or construct with `new_distances`) and
/// `set_capacity` before pushing. `set_distances` should be passed the number of possible distances.
/// Please keep in mind that the number of possible hamming distances between two `n` bit numbers
/// is `n + 1`. An example would be:
///
/// ```
/// assert_eq!((0u128 ^ !0).count_ones(), 128);
/// ```
///
/// So make sure you use `n + 1` as your `distances` or else you may encounter a runtime panic.
///
/// ```
/// use hamming_heap::FixedHammingHeap;
/// let mut candidates = FixedHammingHeap::new_distances(129);
/// candidates.set_capacity(3);
/// candidates.push((0u128 ^ !0u128).count_ones(), ());
/// ```
#[derive(Clone, Debug)]
pub struct FixedHammingHeap<T> {
    /// Maximum number of items retained.
    cap: usize,
    /// Number of items currently stored across all buckets.
    size: usize,
    /// Largest distance currently stored once `size == cap`; otherwise the maximum possible distance.
    worst: u32,
    /// One bucket per distance: `distances[d]` holds the items pushed with distance `d`.
    distances: Vec<Vec<T>>,
}

impl<T> FixedHammingHeap<T> {
    /// Creates a heap with no distances and a capacity of zero.
    ///
    /// Call `set_distances` and `set_capacity` before pushing items.
    pub fn new() -> Self {
        Self::default()
    }

    /// Automatically initializes self with `distances` distances.
    pub fn new_distances(distances: usize) -> Self {
        let mut s = Self::new();
        s.set_distances(distances);
        s
    }

    /// This sets the capacity of the queue to `cap`, meaning that adding items to the queue will eject the worst ones
    /// if they are better once `cap` is reached. If the capacity is lowered, this removes the worst elements to
    /// keep `size == cap`.
    ///
    /// # Panics
    ///
    /// Panics if `cap` is `0`.
    pub fn set_capacity(&mut self, cap: usize) {
        assert_ne!(cap, 0);
        // Trim before storing the new capacity: `set_len` uses the old capacity to decide
        // which buckets can still contain items.
        self.set_len(cap);
        self.cap = cap;
        // After the capacity is changed, if the size now equals the capacity we need to update the worst
        // because it must actually be set to the worst item. Otherwise it stays at the maximum distance.
        self.worst = self.max_distance();
        if self.size == self.cap {
            self.update_worst();
        }
    }

    /// This removes the worst elements until only `len` remain. If `len` is not lower than the current
    /// number of elements, this does nothing. If the len is lowered, this will unconditionally allow
    /// insertions until `cap` is reached.
    pub fn set_len(&mut self, len: usize) {
        if len == 0 {
            self.reset();
        } else if len < self.size {
            let end = self.end();
            let mut remaining = self.size - len;
            // Walk from the worst bucket towards the best so that the nearest items survive.
            for bucket in self.distances[..=end].iter_mut().rev() {
                if bucket.len() >= remaining {
                    // This bucket has enough items: drop the excess from it and stop.
                    bucket.truncate(bucket.len() - remaining);
                    break;
                }
                // Not enough here, so empty the bucket and continue with the next better one.
                remaining -= bucket.len();
                bucket.clear();
            }
            // The heap is now below capacity, so every distance must be accepted again.
            self.worst = self.max_distance();
            self.size = len;
        }
    }

    /// Gets the `len` or `size` of the heap.
    pub fn len(&self) -> usize {
        self.size
    }

    /// Checks if the heap is empty.
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }

    /// Clear the queue while maintaining the allocated memory.
    ///
    /// # Panics
    ///
    /// Panics if no distances have been configured yet.
    pub fn clear(&mut self) {
        assert_ne!(
            self.distances.len(),
            0,
            "you must call set_distances() before calling clear()"
        );
        self.reset();
    }

    /// Set number of distances. Also clears the heap.
    ///
    /// This does not preserve the allocated memory, so don't call this on each search.
    ///
    /// If you have a 128-bit number, keep in mind that it has `129` distances because
    /// `128` is one of the possible distances.
    pub fn set_distances(&mut self, distances: usize) {
        self.distances.clear();
        self.distances.resize_with(distances, Vec::new);
        self.worst = self.max_distance();
        self.size = 0;
    }

    /// Add a feature to the search.
    ///
    /// Returns true if it was added.
    ///
    /// # Panics
    ///
    /// Panics if the item is accepted and `distance` is not smaller than the number of
    /// distances configured.
    pub fn push(&mut self, distance: u32, item: T) -> bool {
        if self.size != self.cap {
            self.distances[distance as usize].push(item);
            self.size += 1;
            // Once full, `worst` has to track the real worst stored distance.
            if self.size == self.cap {
                self.update_worst();
            }
            true
        } else {
            // SAFETY: `push_at_cap` performs no memory-unsafe operation; its only precondition
            // is that the heap is at capacity, which this branch has just established.
            unsafe { self.push_at_cap(distance, item) }
        }
    }

    /// Fill a slice with the `top` elements and return the part of the slice written.
    ///
    /// Elements are written in best-to-worst bucket order; at most `min(s.len(), self.len())`
    /// elements are written.
    pub fn fill_slice<'a>(&self, s: &'a mut [T]) -> &'a mut [T]
    where
        T: Clone,
    {
        let total_fill = s.len().min(self.size);
        let items = self.distances[..=self.end()]
            .iter()
            .flatten()
            .take(total_fill);
        for (slot, item) in s.iter_mut().zip(items) {
            slot.clone_from(item);
        }
        &mut s[..total_fill]
    }

    /// Gets the worst distance in the queue currently.
    ///
    /// This is initialized to max (which is the worst possible distance) until `cap` elements have been inserted.
    pub fn worst(&self) -> u32 {
        self.worst
    }

    /// Returns true if the cap has been reached.
    pub fn at_cap(&self) -> bool {
        self.size == self.cap
    }

    /// Iterate over the entire queue in best-to-worse order, yielding `(distance, item)` pairs.
    pub fn iter(&self) -> impl Iterator<Item = (u32, &T)> {
        self.distances[..=self.end()]
            .iter()
            .enumerate()
            .flat_map(|(distance, v)| v.iter().map(move |item| (distance as u32, item)))
    }

    /// Iterate mutably over the entire queue in best-to-worse order, yielding `(distance, item)` pairs.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = (u32, &mut T)> {
        let end = self.end();
        self.distances[..=end]
            .iter_mut()
            .enumerate()
            .flat_map(|(distance, v)| v.iter_mut().map(move |item| (distance as u32, item)))
    }

    /// Add a feature to the search with the precondition we are already at the cap.
    ///
    /// Returns true if the item was added, which happens only when `distance` is strictly better
    /// than the current worst distance; in that case one item at the worst distance is evicted.
    ///
    /// # Safety
    ///
    /// This function cannot cause undefined behavior, but it can be used incorrectly.
    /// This should only be called after `at_cap()` has been called and returned true; otherwise the
    /// stored item count is not updated for the inserted item.
    /// This shouldn't be used unless you profile and actually find that the branch predictor is having
    /// issues with the if statement in `push()`.
    pub unsafe fn push_at_cap(&mut self, distance: u32, item: T) -> bool {
        // We stop searching once we have enough features under the search distance,
        // so if this is true it will always get added to the heap.
        if distance < self.worst {
            self.distances[distance as usize].push(item);
            self.remove_worst();
            true
        } else {
            false
        }
    }

    /// Largest representable distance, i.e. the index of the last bucket.
    fn max_distance(&self) -> u32 {
        // Subtract before narrowing so that 2^32 buckets still map to `u32::MAX`, and saturate
        // so that a heap without any distances yields 0 instead of overflowing.
        self.distances.len().saturating_sub(1) as u32
    }

    /// Gets the smallest known inclusive end of the datastructure.
    fn end(&self) -> usize {
        if self.at_cap() {
            self.worst as usize
        } else {
            self.max_distance() as usize
        }
    }

    /// Empties every bucket that may hold items and marks the heap as empty.
    fn reset(&mut self) {
        let end = self.end();
        for bucket in &mut self.distances[..=end] {
            bucket.clear();
        }
        self.size = 0;
        self.worst = self.max_distance();
    }

    /// Lowers `worst` to the highest non-empty bucket at or below its current value.
    fn update_worst(&mut self) {
        // If there is nothing left, it gets reset to max.
        let fallback = self.max_distance();
        self.worst = self.distances[..=self.worst as usize]
            .iter()
            .rposition(|bucket| !bucket.is_empty())
            .map(|ix| ix as u32)
            .unwrap_or(fallback);
    }

    /// Remove the worst item and update the worst distance.
    fn remove_worst(&mut self) {
        self.distances[self.worst as usize].pop();
        self.update_worst();
    }
}

impl<T> Default for FixedHammingHeap<T> {
    /// An empty heap with no distances and zero capacity.
    ///
    /// Implemented by hand because deriving `Default` would add a `T: Default` bound.
    fn default() -> Self {
        Self {
            cap: 0,
            size: 0,
            worst: 0,
            distances: Vec::new(),
        }
    }
}
247
#[cfg(test)]
#[test]
fn test_fixed_heap() {
    let mut candidates: FixedHammingHeap<u32> = FixedHammingHeap::new();
    candidates.set_distances(11);
    candidates.set_capacity(3);

    // Each row is (distance, item, whether the heap is expected to accept the push).
    let pushes: [(u32, u32, bool); 12] = [
        (5, 0, true),
        (4, 1, true),
        (3, 2, true),
        (6, 3, false),
        (7, 4, false),
        (2, 5, true),
        (3, 6, true),
        (10, 7, false),
        (6, 8, false),
        (4, 9, false),
        (1, 10, true),
        (2, 11, true),
    ];
    for &(distance, item, accepted) in &pushes {
        // Carry the row along in the comparison so a failure shows which push misbehaved.
        assert_eq!(
            (distance, item, candidates.push(distance, item)),
            (distance, item, accepted)
        );
    }

    let mut nearest = [0; 3];
    candidates.fill_slice(&mut nearest);
    // The last two items share a bucket, so normalise their order before comparing.
    nearest[1..].sort_unstable();
    assert_eq!(nearest, [10, 5, 11]);
}