lsph/map/mod.rs
1mod nn;
2mod table;
3
4use crate::{
5 error::*,
6 geometry::{distance::*, Point},
7 hasher::*,
8 map::{nn::*, table::*},
9 models::Model,
10};
11use core::{fmt::Debug, iter::Sum, mem};
12use num_traits::{
13 cast::{AsPrimitive, FromPrimitive},
14 float::Float,
15};
16use std::collections::BinaryHeap;
17
/// Number of buckets allocated when the table is grown from empty (see `resize`).
const INITIAL_NBUCKETS: usize = 1;
20
21/// LearnedHashMap takes a model instead of an hasher for hashing indexes in the table.
22///
23/// Default Model for the LearndedHashMap is Linear regression.
24/// In order to build a ordered HashMap, we need to make sure that the model is **monotonic**.
25#[derive(Debug, Clone)]
26pub struct LearnedHashMap<M, F> {
27 hasher: LearnedHasher<M>,
28 table: Table<Point<F>>,
29 items: usize,
30}
31
32/// Default for the LearndedHashMap.
33impl<M, F> Default for LearnedHashMap<M, F>
34where
35 F: Float,
36 M: Model<F = F> + Default,
37{
38 fn default() -> Self {
39 Self {
40 hasher: LearnedHasher::<M>::new(),
41 table: Table::new(),
42 items: 0,
43 }
44 }
45}
46
47impl<M, F> LearnedHashMap<M, F>
48where
49 F: Float + Default + AsPrimitive<u64> + FromPrimitive + Debug + Sum,
50 M: Model<F = F> + Default + Clone,
51{
52 /// Returns a default LearnedHashMap with Model and Float type.
53 ///
54 /// # Examples
55 ///
56 /// ```
57 /// use lsph::{LearnedHashMap, LinearModel};
58 /// let map = LearnedHashMap::<LinearModel<f64>, f64>::new();
59 /// ```
60 #[inline]
61 pub fn new() -> Self {
62 Self::default()
63 }
64
65 /// Returns a default LearnedHashMap with Model and Float type.
66 ///
67 /// # Arguments
68 /// * `hasher` - A LearnedHasher with model
69 ///
70 /// # Examples
71 ///
72 /// ```
73 /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
74 /// let map = LearnedHashMap::<LinearModel<f64>, f64>::with_hasher(LearnedHasher::new());
75 /// ```
76 #[inline]
77 pub fn with_hasher(hasher: LearnedHasher<M>) -> Self {
78 Self {
79 hasher,
80 table: Table::new(),
81 items: 0,
82 }
83 }
84
85 /// Returns a default LearnedHashMap with Model and Float type.
86 ///
87 /// # Arguments
88 /// * `capacity` - A predefined capacity size for the LearnedHashMap
89 ///
90 /// # Examples
91 ///
92 /// ```
93 /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
94 /// let map = LearnedHashMap::<LinearModel<f64>, f64>::with_capacity(10usize);
95 /// ```
96 #[inline]
97 pub fn with_capacity(capacity: usize) -> Self {
98 Self {
99 hasher: Default::default(),
100 table: Table::with_capacity(capacity),
101 items: 0,
102 }
103 }
104
105 /// Returns a default LearnedHashMap with Model and Float type
106 ///
107 /// # Arguments
108 /// * `data` - A Vec<[F; 2]> of 2d points for the map
109 ///
110 /// # Examples
111 ///
112 /// ```
113 /// use lsph::{LearnedHashMap, LinearModel};
114 /// let data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
115 /// let map = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&data);
116 /// ```
117 #[inline]
118 pub fn with_data(data: &[[F; 2]]) -> Result<(Self, Vec<Point<F>>), Error> {
119 use crate::helper::convert_to_points;
120 let mut map = LearnedHashMap::with_capacity(data.len());
121 let mut ps = convert_to_points(data).unwrap();
122 match map.batch_insert(&mut ps) {
123 Ok(()) => Ok((map, ps)),
124 Err(err) => Err(err),
125 }
126 }
127
128 /// Returns Option<Point<F>> with given point data.
129 ///
130 /// # Arguments
131 /// * `p` - A array slice containing two points for querying
132 ///
133 /// # Examples
134 ///
135 /// ```
136 /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
137 /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
138 /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
139 ///
140 /// assert_eq!(map.get(&[1., 1.]).is_some(), true);
141 /// ```
142 #[inline]
143 pub fn get(&mut self, p: &[F; 2]) -> Option<&Point<F>> {
144 let hash = make_hash_point(&mut self.hasher, p) as usize;
145 if hash > self.table.capacity() {
146 return None;
147 }
148 self.find_by_hash(hash, p)
149 }
150
151 /// Returns Option<Point<F>> by hash index, if it exists in the map.
152 ///
153 /// # Arguments
154 /// * `hash` - An usize hash value
155 /// * `p` - A array slice containing two points for querying
156 ///
157 /// # Examples
158 ///
159 /// ```
160 /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
161 /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
162 /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
163 ///
164 /// assert_eq!(map.find_by_hash(0, &[1., 1.]).is_some(), true);
165 /// assert_eq!(map.find_by_hash(1, &[1., 1.]).is_none(), true);
166 /// ```
167 pub fn find_by_hash(&self, hash: usize, p: &[F; 2]) -> Option<&Point<F>> {
168 self.table[hash]
169 .iter()
170 .find(|&ep| ep.x == p[0] && ep.y == p[1])
171 }
172
173 /// Returns bool.
174 ///
175 /// # Arguments
176 /// * `p` - A array slice containing two points for querying
177 ///
178 /// # Examples
179 ///
180 /// ```
181 /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
182 /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
183 /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
184 ///
185 /// assert_eq!(map.contains_points(&[1., 1.]), true);
186 /// assert_eq!(map.contains_points(&[0., 1.]), false);
187 /// ```
188 #[inline]
189 pub fn contains_points(&mut self, p: &[F; 2]) -> bool {
190 self.get(p).is_some()
191 }
192
193 /// Returns Option<Point<F>> if the map contains a point and successful remove it from the map.
194 ///
195 /// # Arguments
196 /// * `p` - A Point data
197 ///
198 /// # Examples
199 ///
200 /// ```
201 /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
202 /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
203 /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
204 ///
205 /// let p = points[0];
206 /// assert_eq!(map.remove(&p).unwrap(), p);
207 /// ```
208 #[inline]
209 pub fn remove(&mut self, p: &Point<F>) -> Option<Point<F>> {
210 let hash = make_hash_point(&mut self.hasher, &[p.x, p.y]);
211 self.items -= 1;
212 self.table.remove_entry(hash, *p)
213 }
214
215 /// Returns usize length.
216 ///
217 /// # Examples
218 ///
219 /// ```
220 /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
221 /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
222 /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
223 ///
224 /// assert_eq!(map.len(), 4);
225 /// ```
226 #[inline]
227 pub fn len(&self) -> usize {
228 self.table.len()
229 }
230
231 /// Returns usize number of items.
232 ///
233 /// # Examples
234 ///
235 /// ```
236 /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
237 /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
238 /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
239 ///
240 /// assert_eq!(map.items(), 4);
241 /// ```
242 #[inline]
243 pub fn items(&self) -> usize {
244 self.items
245 }
246
247 /// Returns bool if the map is empty.
248 ///
249 /// # Examples
250 ///
251 /// ```
252 /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
253 /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
254 /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
255 ///
256 /// assert_eq!(map.is_empty(), false);
257 /// ```
258 #[inline]
259 pub fn is_empty(&self) -> bool {
260 self.items == 0
261 }
262
263 /// Resize the map if needed, it will initialize the map to the INITIAL_NBUCKETS, otherwise it will double the capacity if table is not empty.
264 fn resize(&mut self) {
265 let target_size = match self.table.len() {
266 0 => INITIAL_NBUCKETS,
267 n => 2 * n,
268 };
269 self.resize_with_capacity(target_size);
270 }
271
272 /// Resize the map if needed, it will resize the map to desired capacity.
273 #[inline]
274 fn resize_with_capacity(&mut self, target_size: usize) {
275 let mut new_table = Table::with_capacity(target_size);
276 new_table.extend((0..target_size).map(|_| Bucket::new()));
277
278 for p in self.table.iter_mut().flat_map(|bucket| bucket.drain(..)) {
279 let hash = make_hash_point(&mut self.hasher, &[p.x, p.y]) as usize;
280 new_table[hash].push(p);
281 }
282
283 self.table = new_table;
284 }
285
286 /// Rehash the map.
287 #[inline]
288 fn rehash(&mut self) -> Result<(), Error> {
289 let mut old_data = Vec::with_capacity(self.items());
290 for p in self.table.iter_mut().flat_map(|bucket| bucket.drain(..)) {
291 old_data.push(p);
292 }
293 self.batch_insert(&mut old_data)
294 }
295
296 /// Inner function for insert a single point into the map
297 #[inline]
298 fn insert_inner(&mut self, p: Point<F>) -> Option<Point<F>> {
299 // Resize if the table is empty or 3/4 size of the table is full
300 if self.table.is_empty() || self.items() > 3 * self.table.len() / 4 {
301 self.resize();
302 }
303 let hash = make_hash_point::<M, F>(&mut self.hasher, &[p.x, p.y]);
304 self.insert_with_axis(p, hash)
305 }
306
307 /// Sequencial insert a point into the map.
308 ///
309 /// # Arguments
310 /// * `p` - A Point<F> with float number
311 ///
312 /// # Examples
313 ///
314 /// ```
315 /// use lsph::{LearnedHashMap, LinearModel, Point};
316 /// let a: Point<f64> = Point::new(0., 1.);
317 /// let b: Point<f64> = Point::new(1., 0.);
318
319 /// let mut map = LearnedHashMap::<LinearModel<f64>, f64>::new();
320 /// map.insert(a);
321 /// map.insert(b);
322
323 /// assert_eq!(map.items(), 2);
324 /// assert_eq!(map.get(&[0., 1.]).unwrap(), &a);
325 /// assert_eq!(map.get(&[1., 0.]).unwrap(), &b);
326 /// ```
327 pub fn insert(&mut self, p: Point<F>) -> Option<Point<F>> {
328 // Resize if the table is empty or 3/4 size of the table is full
329 if self.table.is_empty() || self.items() > 3 * self.table.len() / 4 {
330 self.resize();
331 }
332
333 let hash = make_hash_point::<M, F>(&mut self.hasher, &[p.x, p.y]);
334 // resize if hash index is larger or equal to the table capacity
335 if hash >= self.table.capacity() as u64 {
336 self.resize_with_capacity(hash as usize * 2);
337 self.insert_with_axis(p, hash);
338 match self.rehash() {
339 Ok(_) => None,
340 Err(err) => {
341 eprintln!("{:?}", err);
342 None
343 }
344 }
345 } else {
346 self.insert_with_axis(p, hash)
347 }
348 }
349
350 /// Insert a point into the map along the given axis.
351 ///
352 /// # Arguments
353 /// * `p_value` - A float number represent the key of a 2d point
354 #[inline]
355 fn insert_with_axis(&mut self, p: Point<F>, hash: u64) -> Option<Point<F>> {
356 let mut insert_index = 0;
357 let bucket_index = self.table.bucket(hash);
358 let bucket = &mut self.table[bucket_index];
359 if self.hasher.sort_by_x() {
360 // Get index from the hasher
361 for ep in bucket.iter_mut() {
362 if ep == &mut p.clone() {
363 return Some(mem::replace(ep, p));
364 }
365 if ep.y < p.y() {
366 insert_index += 1;
367 }
368 }
369 } else {
370 for ep in bucket.iter_mut() {
371 if ep == &mut p.clone() {
372 return Some(mem::replace(ep, p));
373 }
374 if ep.x < p.x() {
375 insert_index += 1;
376 }
377 }
378 }
379 bucket.insert(insert_index, p);
380 self.items += 1;
381 None
382 }
383
384 /// Fit the input data into the model of the hasher. Returns Error if error occurred during
385 /// model fitting.
386 ///
387 /// # Arguments
388 ///
389 /// * `xs` - A list of tuple of floating number
390 /// * `ys` - A list of tuple of floating number
391 #[inline]
392 pub fn model_fit(&mut self, xs: &[F], ys: &[F]) -> Result<(), Error> {
393 self.hasher.model.fit(xs, ys)
394 }
395
396 /// Fit the input data into the model of the hasher. Returns Error if error occurred during
397 /// model fitting.
398 ///
399 /// # Arguments
400 /// * `data` - A list of tuple of floating number
401 #[inline]
402 pub fn model_fit_tuple(&mut self, data: &[(F, F)]) -> Result<(), Error> {
403 self.hasher.model.fit_tuple(data)
404 }
405
406 /// Inner function for batch insert
407 #[inline]
408 fn batch_insert_inner(&mut self, ps: &[Point<F>]) {
409 // Allocate table capacity before insert
410 let n = ps.len();
411 self.resize_with_capacity(n);
412 for p in ps.iter() {
413 self.insert_inner(*p);
414 }
415 }
416
417 /// Batch insert a batch of 2d data into the map.
418 ///
419 /// # Arguments
420 /// * `ps` - A list of point number
421 ///
422 /// # Examples
423 ///
424 /// ```
425 /// use lsph::{LearnedHashMap, LinearModel};
426 /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
427 /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
428 ///
429 /// assert_eq!(map.get(&[1., 1.]).is_some(), true);
430 /// ```
431 #[inline]
432 pub fn batch_insert(&mut self, ps: &mut [Point<F>]) -> Result<(), Error> {
433 // Select suitable axis for training
434 use crate::geometry::Axis;
435 use crate::models::Trainer;
436
437 // Loading data into trainer
438 if let Ok(trainer) = Trainer::with_points(ps) {
439 trainer.train(&mut self.hasher.model).unwrap();
440 let axis = trainer.axis();
441 match axis {
442 Axis::X => self.hasher.set_sort_by_x(true),
443 _ => self.hasher.set_sort_by_x(false),
444 };
445
446 // Fit the data into model
447 self.model_fit(trainer.train_x(), trainer.train_y())
448 .unwrap();
449 // Batch insert into the map
450 self.batch_insert_inner(ps);
451 }
452 Ok(())
453 }
454
455 /// Range search finds all points for a given 2d range
456 /// Returns all the points within the given range
457 /// ```text
458 /// | top right
459 /// | .-----------*
460 /// | | . . |
461 /// | | . . . |
462 /// | | . |
463 /// bottom left *-----------.
464 /// |
465 /// | | |
466 /// |________v___________v________
467 /// left right
468 /// hash hash
469 /// ```
470 /// # Arguments
471 ///
472 /// * `bottom_left` - A tuple containing a pair of points that represent the bottom left of the
473 /// range.
474 ///
475 /// * `top_right` - A tuple containing a pair of points that represent the top right of the
476 /// range.
477 #[inline]
478 pub fn range_search(
479 &mut self,
480 bottom_left: &[F; 2],
481 top_right: &[F; 2],
482 ) -> Option<Vec<Point<F>>> {
483 let mut right_hash = make_hash_point(&mut self.hasher, top_right) as usize;
484 if right_hash > self.table.capacity() {
485 right_hash = self.table.capacity() as usize - 1;
486 }
487 let left_hash = make_hash_point(&mut self.hasher, bottom_left) as usize;
488 if left_hash > self.table.capacity() || left_hash > right_hash {
489 return None;
490 }
491 let mut result: Vec<Point<F>> = Vec::new();
492 for i in left_hash..=right_hash {
493 let bucket = &self.table[i];
494 for item in bucket.iter() {
495 if item.x >= bottom_left[0]
496 && item.x <= top_right[0]
497 && item.y >= bottom_left[1]
498 && item.y <= top_right[1]
499 {
500 result.push(*item);
501 }
502 }
503 }
504 if result.is_empty() {
505 return None;
506 }
507 Some(result)
508 }
509
510 /// Returns Option<Vec<Point<F>>> if points are found in the map with given range
511 ///
512 /// # Arguments
513 /// * `query_point` - A Point data for querying
514 /// * `radius` - A radius value
515 ///
516 /// # Examples
517 ///
518 /// ```
519 /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
520 /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
521 /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
522 /// assert_eq!(map.range_search(&[0., 0.], &[3., 3.]).is_some(), true);
523 /// ```
524 #[inline]
525 pub fn radius_range(&mut self, query_point: &[F; 2], radius: F) -> Option<Vec<Point<F>>> {
526 self.range_search(
527 &[query_point[0] - radius, query_point[1] - radius],
528 &[query_point[0] + radius, query_point[1] + radius],
529 )
530 }
531
532 /// Find the local minimum distance between query points and cadidates neighbors, then store
533 /// the cadidates neighbors in the min_heap.
534 ///
535 ///
536 /// # Arguments
537 /// * `heap` - mutable borrow of an BinaryHeap
538 /// * `local_hash` - A hash index of local bucket
539 /// * `query_point` - A Point data
540 /// * `min_d` - minimum distance
541 /// * `nearest_neighbor` - mutable borrow of an point data, which is the nearest neighbor at
542 /// search index bucket
543 #[inline]
544 fn local_min_heap(
545 &self,
546 heap: &mut BinaryHeap<NearestNeighborState<F>>,
547 local_hash: u64,
548 query_point: &[F; 2],
549 min_d: &mut F,
550 nearest_neighbor: &mut Point<F>,
551 ) {
552 let bucket = &self.table[local_hash as usize];
553 if !bucket.is_empty() {
554 for p in bucket.iter() {
555 let d = Euclidean::distance(query_point, &[p.x, p.y]);
556 heap.push(NearestNeighborState {
557 distance: d,
558 point: *p,
559 });
560 }
561 }
562 match heap.pop() {
563 Some(v) => {
564 let local_min_d = v.distance;
565 // Update the nearest neighbour and minimum distance
566 if &local_min_d < min_d {
567 *nearest_neighbor = v.point;
568 *min_d = local_min_d;
569 }
570 }
571 None => (),
572 }
573 }
574
575 /// Calculates the horizontal distance between query_point and bucket at index with given hash.
576 ///
577 /// # Arguments
578 /// * `hash` - A hash index of the bucket
579 /// * `query_point` - A Point data
580 #[inline]
581 fn horizontal_distance(&mut self, query_point: &[F; 2], hash: u64) -> F {
582 let x = unhash(&mut self.hasher, hash);
583 match self.hasher.sort_by_x() {
584 true => Euclidean::distance(&[query_point[0], F::zero()], &[x, F::zero()]),
585 false => Euclidean::distance(&[query_point[1], F::zero()], &[x, F::zero()]),
586 }
587 }
588
589 /// Nearest neighbor search for the closest point for given query point
590 /// Returns the closest point
591 ///```text
592 /// |
593 /// | .
594 /// | . |
595 /// | |. | * . <- nearest neighbor
596 /// | || | | .|
597 /// | expand <--------> expand
598 /// | left | right
599 /// | |
600 /// |_______________v_____________
601 /// query
602 /// point
603 ///```
604 /// # Arguments
605 ///
606 /// * `query_point` - A tuple containing a pair of points for querying
607 ///
608 /// # Examples
609 ///
610 /// ```
611 /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
612 /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
613 /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
614 /// assert_eq!(map.nearest_neighbor(&[2., 1.]).is_some(), true);
615 /// ```
616 #[inline]
617 pub fn nearest_neighbor(&mut self, query_point: &[F; 2]) -> Option<Point<F>> {
618 let mut hash = make_hash_point(&mut self.hasher, query_point);
619 let max_capacity = self.table.capacity() as u64;
620
621 // if hash out of max bound, still search right most bucket
622 if hash > max_capacity {
623 hash = max_capacity - 1;
624 }
625
626 let mut heap = BinaryHeap::new();
627 let mut min_d = F::max_value();
628 let mut nearest_neighbor = Point::default();
629
630 // Searching at current hash index
631 self.local_min_heap(
632 &mut heap,
633 hash,
634 query_point,
635 &mut min_d,
636 &mut nearest_neighbor,
637 );
638
639 // Measure left horizontal distance from current bucket to left hash bucket
640 // left hash must >= 0
641 let mut left_hash = hash.saturating_sub(1);
642 // Unhash the left_hash, then calculate the vertical distance between
643 // left hash point and query point
644 let mut left_hash_d = self.horizontal_distance(query_point, left_hash);
645
646 // Iterate over left
647 while left_hash_d < min_d {
648 self.local_min_heap(
649 &mut heap,
650 left_hash,
651 query_point,
652 &mut min_d,
653 &mut nearest_neighbor,
654 );
655
656 // break before update
657 if left_hash == 0 {
658 break;
659 }
660
661 // Update next right side bucket distance
662 left_hash = left_hash.saturating_sub(1);
663 left_hash_d = self.horizontal_distance(query_point, left_hash);
664 }
665
666 // Measure right vertical distance from current bucket to right hash bucket
667 let mut right_hash = hash + 1;
668 // Unhash the right_hash, then calculate the vertical distance between
669 // right hash point and query point
670 let mut right_hash_d = self.horizontal_distance(query_point, right_hash);
671
672 // Iterate over right
673 while right_hash_d < min_d {
674 self.local_min_heap(
675 &mut heap,
676 right_hash,
677 query_point,
678 &mut min_d,
679 &mut nearest_neighbor,
680 );
681
682 // Move to next right bucket
683 right_hash += 1;
684
685 // break after update
686 if right_hash == self.table.capacity() as u64 {
687 break;
688 }
689 // Update next right side bucket distance
690 right_hash_d = self.horizontal_distance(query_point, right_hash);
691 }
692
693 Some(nearest_neighbor)
694 }
695}
696
697pub struct Iter<'a, M, F>
698where
699 F: Float,
700 M: Model<F = F> + Default + Clone,
701{
702 map: &'a LearnedHashMap<M, F>,
703 bucket: usize,
704 at: usize,
705}
706
707impl<'a, M, F> Iterator for Iter<'a, M, F>
708where
709 F: Float,
710 M: Model<F = F> + Default + Clone,
711{
712 type Item = &'a Point<F>;
713 fn next(&mut self) -> Option<Self::Item> {
714 loop {
715 match self.map.table.get(self.bucket) {
716 Some(bucket) => {
717 match bucket.get(self.at) {
718 Some(p) => {
719 // move along self.at and self.bucket
720 self.at += 1;
721 break Some(p);
722 }
723 None => {
724 self.bucket += 1;
725 self.at = 0;
726 continue;
727 }
728 }
729 }
730 None => break None,
731 }
732 }
733 }
734}
735
736impl<'a, M, F> IntoIterator for &'a LearnedHashMap<M, F>
737where
738 F: Float,
739 M: Model<F = F> + Default + Clone,
740{
741 type Item = &'a Point<F>;
742 type IntoIter = Iter<'a, M, F>;
743 fn into_iter(self) -> Self::IntoIter {
744 Iter {
745 map: self,
746 bucket: 0,
747 at: 0,
748 }
749 }
750}
751
752pub struct IntoIter<M, F>
753where
754 F: Float,
755 M: Model<F = F> + Default + Clone,
756{
757 map: LearnedHashMap<M, F>,
758 bucket: usize,
759}
760
761impl<M, F> Iterator for IntoIter<M, F>
762where
763 F: Float,
764 M: Model<F = F> + Default + Clone,
765{
766 type Item = Point<F>;
767 fn next(&mut self) -> Option<Self::Item> {
768 loop {
769 match self.map.table.get_mut(self.bucket) {
770 Some(bucket) => match bucket.pop() {
771 Some(x) => break Some(x),
772 None => {
773 self.bucket += 1;
774 continue;
775 }
776 },
777 None => break None,
778 }
779 }
780 }
781}
782
783impl<M, F> IntoIterator for LearnedHashMap<M, F>
784where
785 F: Float,
786 M: Model<F = F> + Default + Clone,
787{
788 type Item = Point<F>;
789 type IntoIter = IntoIter<M, F>;
790 fn into_iter(self) -> Self::IntoIter {
791 IntoIter {
792 map: self,
793 bucket: 0,
794 }
795 }
796}
797
#[cfg(test)]
mod tests {
    use super::*;
    use crate::geometry::Point;
    use crate::models::LinearModel;
    use crate::test_utilities::*;

    /// Five points shared by the batch-insert tests.
    fn batch_points() -> Vec<Point<f64>> {
        vec![
            Point::new(1., 1.),
            Point::new(3., 1.),
            Point::new(2., 1.),
            Point::new(3., 2.),
            Point::new(5., 1.),
        ]
    }

    #[test]
    fn insert() {
        let a: Point<f64> = Point::new(0., 1.);
        let b: Point<f64> = Point::new(1., 0.);

        let mut map = LearnedHashMap::<LinearModel<f64>, f64>::new();
        map.insert(a);
        map.insert(b);

        assert_eq!(map.items(), 2);
        assert_eq!(map.get(&[0., 1.]).unwrap(), &a);
        assert_eq!(map.get(&[1., 0.]).unwrap(), &b);
    }

    #[test]
    fn insert_repeated() {
        let mut map = LearnedHashMap::<LinearModel<f64>, f64>::new();
        let a: Point<f64> = Point::new(0., 1.);
        let b: Point<f64> = Point::new(1., 0.);

        assert_eq!(map.insert(a), None);
        assert_eq!(map.items(), 1);

        assert_eq!(map.insert(b), None);
        assert_eq!(map.items(), 2);
    }

    #[test]
    fn with_data() {
        let data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
        let (mut map, _points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&data).unwrap();
        assert!(map.get(&[1., 1.]).is_some());
    }

    #[test]
    fn fit_batch_insert() {
        let mut data = batch_points();
        let mut map = LearnedHashMap::<LinearModel<f64>, f64>::new();
        map.batch_insert(&mut data).unwrap();

        assert_delta!(1.02272, map.hasher.model.coefficient, 0.00001);
        assert_delta!(-0.86363, map.hasher.model.intercept, 0.00001);
        assert_eq!(Some(&Point::new(1., 1.)), map.get(&[1., 1.]));
        assert_eq!(Some(&Point::new(3., 1.)), map.get(&[3., 1.]));
        assert_eq!(Some(&Point::new(5., 1.)), map.get(&[5., 1.]));

        assert_eq!(None, map.get(&[5., 2.]));
        assert_eq!(None, map.get(&[2., 2.]));
        assert_eq!(None, map.get(&[50., 10.]));
        assert_eq!(None, map.get(&[500., 100.]));
    }

    #[test]
    fn insert_after_batch_insert() {
        let mut data = batch_points();
        let mut map = LearnedHashMap::<LinearModel<f64>, f64>::new();
        map.batch_insert(&mut data).unwrap();

        let a: Point<f64> = Point::new(10., 10.);
        map.insert(a);
        assert_eq!(Some(&a), map.get(&[10., 10.]));

        let b: Point<f64> = Point::new(100., 100.);
        map.insert(b);
        assert_eq!(Some(&b), map.get(&[100., 100.]));
        assert_eq!(None, map.get(&[100., 101.]));
    }

    #[test]
    fn range_search() {
        let mut data: Vec<Point<f64>> = vec![
            Point::new(1., 1.),
            Point::new(2., 2.),
            Point::new(3., 3.),
            Point::new(4., 4.),
            Point::new(5., 5.),
        ];
        let mut map = LearnedHashMap::<LinearModel<f64>, f64>::new();
        map.batch_insert(&mut data).unwrap();

        let found: Vec<Point<f64>> =
            vec![Point::new(1., 1.), Point::new(2., 2.), Point::new(3., 3.)];
        assert_eq!(Some(found), map.range_search(&[1., 1.], &[3.5, 3.]));

        let found: Vec<Point<f64>> = vec![Point::new(1., 1.)];
        assert_eq!(Some(found), map.range_search(&[1., 1.], &[3., 1.]));
        assert_eq!(None, map.range_search(&[4., 2.], &[5., 3.]));
    }

    #[test]
    fn test_nearest_neighbor() {
        let points = create_random_point_type_points(1000, SEED_1);
        let mut map = LearnedHashMap::<LinearModel<f64>, f64>::new();
        map.batch_insert(&mut points.clone()).unwrap();

        let sample_points = create_random_point_type_points(100, SEED_2);
        for sample_point in &sample_points {
            // Brute-force reference: first point with the strictly smallest distance.
            let mut nearest = None;
            let mut closest_dist = f64::INFINITY;
            for point in &points {
                let new_dist = Euclidean::distance_point(point, sample_point);
                if new_dist < closest_dist {
                    closest_dist = new_dist;
                    nearest = Some(point);
                }
            }
            let map_nearest = map
                .nearest_neighbor(&[sample_point.x, sample_point.y])
                .unwrap();
            assert_eq!(nearest.unwrap(), &map_nearest);
        }
    }
}