// citadel_vector/vendored/prism/point.rs
/// Columnar store of fixed-dimension `f32` points with per-point `u32` attributes.
///
/// Point vectors are kept back to back in a single flat buffer, and every
/// attribute is kept in its own column, so point `i` is described by
/// `vectors[i * dim..(i + 1) * dim]` together with `attrs[j][i]` for each
/// attribute index `j`.
pub struct PointStore {
    /// Flat vector data; point `i` occupies `vectors[i * dim..(i + 1) * dim]`.
    pub vectors: Vec<f32>,
    /// Number of `f32` components per point.
    pub dim: usize,
    /// Number of points currently stored.
    pub len: usize,
    /// Attribute columns: `attrs[j][i]` is the value of attribute `j` for point `i`.
    pub attrs: Vec<Vec<u32>>,
}

impl PointStore {
    /// Creates an empty store for `dim`-dimensional points with `k` attribute columns.
    pub fn new(dim: usize, k: usize) -> Self {
        Self {
            vectors: Vec::new(),
            dim,
            len: 0,
            attrs: vec![Vec::new(); k],
        }
    }

    /// Assembles a store from an existing flat vector buffer and attribute columns.
    ///
    /// The point count is `vectors.len() / dim`. When `dim` is 0 the vectors carry
    /// no information about the count, so it is taken from the length of the first
    /// attribute column instead (0 when there are no columns); this avoids the
    /// division-by-zero panic such input would otherwise cause.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `vectors.len()` is not exactly `len * dim` or if
    /// any attribute column does not hold exactly `len` values. Release builds do
    /// not check this.
    pub fn from_parts(vectors: Vec<f32>, dim: usize, attrs: Vec<Vec<u32>>) -> Self {
        let len = if dim == 0 {
            attrs.first().map_or(0, |column| column.len())
        } else {
            vectors.len() / dim
        };
        debug_assert_eq!(vectors.len(), len * dim);
        for column in &attrs {
            debug_assert_eq!(column.len(), len);
        }
        Self {
            vectors,
            dim,
            len,
            attrs,
        }
    }

    /// Returns the number of attribute columns.
    pub fn k(&self) -> usize {
        self.attrs.len()
    }

    /// Returns the `dim` components of point `id`.
    ///
    /// # Panics
    ///
    /// Panics if the components of `id` would lie outside the vector buffer.
    #[inline]
    pub fn vector(&self, id: u32) -> &[f32] {
        let start = id as usize * self.dim;
        &self.vectors[start..start + self.dim]
    }

    /// Returns the value of attribute `j` for point `id`.
    ///
    /// # Panics
    ///
    /// Panics if `j` is not a valid column index or `id` is out of range for
    /// that column.
    #[inline]
    pub fn attr(&self, id: u32, j: usize) -> u32 {
        self.attrs[j][id as usize]
    }

    /// Appends one point and returns its id, which is the point count before the call.
    ///
    /// `vector` supplies the point's components and `attr_values[j]` is appended
    /// to attribute column `j`.
    ///
    /// # Panics
    ///
    /// Panics if the store already holds more points than a `u32` id can address;
    /// the check runs before anything is modified, so the store is left untouched.
    /// In debug builds, also panics if `vector.len() != dim` or if
    /// `attr_values.len()` differs from the number of attribute columns.
    pub fn push(&mut self, vector: &[f32], attr_values: &[u32]) -> u32 {
        debug_assert_eq!(vector.len(), self.dim);
        debug_assert_eq!(attr_values.len(), self.attrs.len());
        // A plain `as u32` cast would wrap around silently and hand out an id
        // that already belongs to another point.
        assert!(
            self.len <= u32::MAX as usize,
            "PointStore is full: point ids must fit in a u32"
        );
        let id = self.len as u32;
        self.vectors.extend_from_slice(vector);
        for (j, &value) in attr_values.iter().enumerate() {
            self.attrs[j].push(value);
        }
        self.len += 1;
        id
    }

    /// Counts the distinct values present in attribute column `j`.
    ///
    /// # Panics
    ///
    /// Panics if `j` is not a valid column index.
    pub fn cardinality(&self, j: usize) -> usize {
        self.attrs[j]
            .iter()
            .collect::<std::collections::HashSet<_>>()
            .len()
    }
}
81
82impl Drop for PointStore {
83 fn drop(&mut self) {
84 use zeroize::Zeroize;
87 self.vectors.zeroize();
88 }
89}
90
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushing two points yields sequential ids, and both the vectors and the
    /// per-point attributes can be read back afterwards.
    #[test]
    fn test_point_store_basic() {
        let mut store = PointStore::new(3, 2);

        let first = store.push(&[1.0, 2.0, 3.0], &[0, 1]);
        let second = store.push(&[4.0, 5.0, 6.0], &[1, 0]);
        assert_eq!(first, 0);
        assert_eq!(second, 1);
        assert_eq!(store.len, 2);

        assert_eq!(store.vector(0), &[1.0, 2.0, 3.0]);
        assert_eq!(store.vector(1), &[4.0, 5.0, 6.0]);

        // Each entry is (point id, attribute index, expected value).
        for &(id, j, expected) in &[(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0)] {
            assert_eq!(store.attr(id, j), expected);
        }
    }

    /// A store assembled from raw parts reports the derived point count, the
    /// number of attribute columns and the distinct values in a column.
    #[test]
    fn test_from_parts() {
        let store = PointStore::from_parts(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            3,
            vec![vec![0, 1], vec![1, 0]],
        );

        assert_eq!(store.len, 2);
        assert_eq!(store.k(), 2);
        assert_eq!(store.cardinality(0), 2);
    }
}