1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
use crate::Constellation;
use generic_array::{ArrayLength, GenericArray};
use std::sync::RwLock;
/// A linear-scan [`Constellation`]: every point is kept as a fixed-size
/// `N`-element `f32` array, and the whole collection sits behind a
/// read/write lock so it can be shared through `&self`.
#[derive(Default)]
pub struct SimpleConstellation<N>
where
    N: ArrayLength<f32>,
{
    /// All points added so far, appended in insertion order.
    points: RwLock<Vec<GenericArray<f32, N>>>,
}
impl<N: ArrayLength<f32>> Constellation for SimpleConstellation<N> {
    /// Appends `points` to the constellation.
    ///
    /// The batch is all-or-nothing. Every point is validated before anything
    /// is stored, so a rejected batch leaves the constellation unchanged.
    ///
    /// # Panics
    ///
    /// Panics with "Incorrect length" if any point does not have exactly `N`
    /// coordinates. Also panics if the internal lock has been poisoned.
    fn add_points(&self, points: Vec<Vec<f32>>) {
        // Convert (and length-check) before taking the write lock. A panic
        // while the guard is held would poison the lock, breaking every later
        // call. It would also leave the batch half-inserted.
        let converted: Vec<GenericArray<f32, N>> = points
            .into_iter()
            .map(|p| GenericArray::<f32, N>::from_exact_iter(p).expect("Incorrect length"))
            .collect();
        self.points
            .write()
            .expect("Error getting write lock")
            .extend(converted);
    }

    /// Returns every stored point whose Euclidean distance from `point` is at
    /// most `within`.
    ///
    /// Each item is `(distance, coordinates)`, in insertion order. Matches
    /// are collected eagerly, so the read lock is released before the
    /// iterator is returned.
    ///
    /// # Panics
    ///
    /// Panics with "Incorrect length" if `point` does not have exactly `N`
    /// coordinates. Also panics if the internal lock has been poisoned.
    fn find(&self, point: Vec<f32>, within: f32) -> Box<dyn Iterator<Item = (f32, Vec<f32>)>> {
        let target = GenericArray::<f32, N>::from_exact_iter(point).expect("Incorrect length");
        let matches: Vec<(f32, Vec<f32>)> = self
            .points
            .read()
            .expect("Error unwrapping points")
            .iter()
            .filter_map(|p| {
                let distance = p
                    .iter()
                    .zip(&target)
                    .map(|(a, b)| {
                        let diff = a - b;
                        diff * diff
                    })
                    .sum::<f32>()
                    .sqrt();
                if distance <= within {
                    Some((distance, p.as_slice().to_vec()))
                } else {
                    None
                }
            })
            .collect();
        Box::new(matches.into_iter())
    }

    /// Number of points currently stored.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock has been poisoned.
    fn count(&self) -> usize {
        self.points.read().expect("Error getting read lock").len()
    }

    /// Number of coordinates per point (the type-level length `N`).
    fn dimensions(&self) -> usize {
        N::to_usize()
    }

    /// Bytes occupied by the stored point payload:
    /// `size_of::<GenericArray<f32, N>>()` multiplied by the point count.
    ///
    /// Spare `Vec` capacity and the lock itself are not included.
    fn memory_size(&self) -> usize {
        std::mem::size_of::<GenericArray<f32, N>>() * self.count()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use typenum::{U4, U8};

    #[test]
    fn test_len() {
        let constellation = SimpleConstellation::<U4>::default();
        assert_eq!(constellation.count(), 0);
    }

    #[test]
    fn test_dimensions() {
        let constellation = SimpleConstellation::<U4>::default();
        assert_eq!(constellation.dimensions(), 4);
    }

    #[test]
    fn test_mem_size() {
        let constellation1 = SimpleConstellation::<U8>::default();
        constellation1.add_points(vec![vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]);
        // One point of eight f32 coordinates = 8 * 4 bytes.
        assert_eq!(constellation1.memory_size(), 32);
    }

    #[test]
    fn test_add_multiple() {
        let constellation = SimpleConstellation::<U4>::default();
        let points: Vec<_> = vec![vec![1.0, 1.0, 1.0, 1.0], vec![1.0, 1.0, 1.0, 1.0]];
        constellation.add_points(points);
        assert_eq!(constellation.count(), 2);
    }

    #[test]
    fn test_query() {
        let constellation = SimpleConstellation::<U4>::default();
        constellation.add_points(vec![vec![2.0, 2.0, 2.0, 2.0]]);
        let iterator = constellation.find(vec![1.0, 1.0, 1.0, 1.0], 10.);
        let items: Vec<(f32, Vec<f32>)> = iterator.collect();
        assert_eq!(items, vec![(2.0, vec![2.0, 2.0, 2.0, 2.0])]);
    }

    #[test]
    fn test_query_filters_by_radius() {
        let constellation = SimpleConstellation::<U4>::default();
        constellation.add_points(vec![vec![1.0, 1.0, 1.0, 1.0], vec![5.0, 5.0, 5.0, 5.0]]);
        // The second point is at distance 8, outside the radius of 1.
        let items: Vec<(f32, Vec<f32>)> =
            constellation.find(vec![1.0, 1.0, 1.0, 1.0], 1.0).collect();
        assert_eq!(items, vec![(0.0, vec![1.0, 1.0, 1.0, 1.0])]);
    }

    #[test]
    fn test_query_missing() {
        let constellation = SimpleConstellation::<U4>::default();
        constellation.add_points(vec![vec![2., 2., 2., 2.]]);
        let iterator = constellation.find(vec![1., 1., 1., 1.], 0.99);
        let items: Vec<(f32, Vec<f32>)> = iterator.collect();
        assert!(items.is_empty());
    }

    #[test]
    #[should_panic(expected = "Incorrect length")]
    fn test_add_wrong_length_panics() {
        let constellation = SimpleConstellation::<U4>::default();
        constellation.add_points(vec![vec![1.0, 2.0, 3.0]]);
    }

    #[test]
    #[should_panic(expected = "Incorrect length")]
    fn test_find_wrong_length_panics() {
        let constellation = SimpleConstellation::<U4>::default();
        let _ = constellation.find(vec![1.0], 1.0);
    }
}