hermes_core/structures/vector/ivf/
cluster.rs1use serde::{Deserialize, Serialize};
7
/// Contract for code values stored in cluster lists.
///
/// Implementors must be cloneable and safe to share or send across threads
/// (`Clone + Send + Sync`), and must be able to report their own size so that
/// containers can total up their storage footprint.
pub trait QuantizedCode: Clone + Send + Sync {
    /// Returns the size of this code in bytes.
    fn size_bytes(&self) -> usize;
}
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ClusterData<C: Clone> {
20 pub doc_ids: Vec<u32>,
22 pub codes: Vec<C>,
24 pub raw_vectors: Option<Vec<Vec<f32>>>,
26}
27
28impl<C: Clone> Default for ClusterData<C> {
29 fn default() -> Self {
30 Self::new()
31 }
32}
33
34impl<C: Clone> ClusterData<C> {
35 pub fn new() -> Self {
36 Self {
37 doc_ids: Vec::new(),
38 codes: Vec::new(),
39 raw_vectors: None,
40 }
41 }
42
43 pub fn with_capacity(capacity: usize) -> Self {
44 Self {
45 doc_ids: Vec::with_capacity(capacity),
46 codes: Vec::with_capacity(capacity),
47 raw_vectors: None,
48 }
49 }
50
51 pub fn len(&self) -> usize {
52 self.doc_ids.len()
53 }
54
55 pub fn is_empty(&self) -> bool {
56 self.doc_ids.is_empty()
57 }
58
59 pub fn add(&mut self, doc_id: u32, code: C, raw_vector: Option<Vec<f32>>) {
61 self.doc_ids.push(doc_id);
62 self.codes.push(code);
63
64 if let Some(raw) = raw_vector {
65 self.raw_vectors.get_or_insert_with(Vec::new).push(raw);
66 }
67 }
68
69 pub fn append(&mut self, other: &ClusterData<C>, doc_id_offset: u32) {
71 for &doc_id in &other.doc_ids {
72 self.doc_ids.push(doc_id + doc_id_offset);
73 }
74 self.codes.extend(other.codes.iter().cloned());
75
76 if let Some(ref other_raw) = other.raw_vectors {
77 let raw = self.raw_vectors.get_or_insert_with(Vec::new);
78 raw.extend(other_raw.iter().cloned());
79 }
80 }
81
82 pub fn iter(&self) -> impl Iterator<Item = (u32, &C)> {
84 self.doc_ids.iter().copied().zip(self.codes.iter())
85 }
86
87 pub fn iter_with_raw(&self) -> impl Iterator<Item = (u32, &C, Option<&Vec<f32>>)> {
89 let raw_iter = self.raw_vectors.as_ref();
90 self.doc_ids
91 .iter()
92 .copied()
93 .zip(self.codes.iter())
94 .enumerate()
95 .map(move |(i, (doc_id, code))| {
96 let raw = raw_iter.and_then(|r| r.get(i));
97 (doc_id, code, raw)
98 })
99 }
100
101 pub fn clear(&mut self) {
103 self.doc_ids.clear();
104 self.codes.clear();
105 if let Some(ref mut raw) = self.raw_vectors {
106 raw.clear();
107 }
108 }
109
110 pub fn reserve(&mut self, additional: usize) {
112 self.doc_ids.reserve(additional);
113 self.codes.reserve(additional);
114 if let Some(ref mut raw) = self.raw_vectors {
115 raw.reserve(additional);
116 }
117 }
118}
119
120impl<C: Clone + QuantizedCode> ClusterData<C> {
121 pub fn size_bytes(&self) -> usize {
123 let doc_ids_size = self.doc_ids.len() * 4;
124 let codes_size: usize = self.codes.iter().map(|c| c.size_bytes()).sum();
125 let raw_size = self
126 .raw_vectors
127 .as_ref()
128 .map(|vecs| vecs.iter().map(|v| v.len() * 4).sum())
129 .unwrap_or(0);
130
131 doc_ids_size + codes_size + raw_size
132 }
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct ClusterStorage<C: Clone> {
138 pub clusters: std::collections::HashMap<u32, ClusterData<C>>,
140 pub total_vectors: usize,
142}
143
144impl<C: Clone> Default for ClusterStorage<C> {
145 fn default() -> Self {
146 Self::new()
147 }
148}
149
150impl<C: Clone> ClusterStorage<C> {
151 pub fn new() -> Self {
152 Self {
153 clusters: std::collections::HashMap::new(),
154 total_vectors: 0,
155 }
156 }
157
158 pub fn with_capacity(num_clusters: usize) -> Self {
159 Self {
160 clusters: std::collections::HashMap::with_capacity(num_clusters),
161 total_vectors: 0,
162 }
163 }
164
165 pub fn add(&mut self, cluster_id: u32, doc_id: u32, code: C, raw_vector: Option<Vec<f32>>) {
167 self.clusters
168 .entry(cluster_id)
169 .or_default()
170 .add(doc_id, code, raw_vector);
171 self.total_vectors += 1;
172 }
173
174 pub fn get(&self, cluster_id: u32) -> Option<&ClusterData<C>> {
176 self.clusters.get(&cluster_id)
177 }
178
179 pub fn get_mut(&mut self, cluster_id: u32) -> Option<&mut ClusterData<C>> {
181 self.clusters.get_mut(&cluster_id)
182 }
183
184 pub fn get_or_create(&mut self, cluster_id: u32) -> &mut ClusterData<C> {
186 self.clusters.entry(cluster_id).or_default()
187 }
188
189 pub fn num_clusters(&self) -> usize {
191 self.clusters.len()
192 }
193
194 pub fn len(&self) -> usize {
196 self.total_vectors
197 }
198
199 pub fn is_empty(&self) -> bool {
200 self.total_vectors == 0
201 }
202
203 pub fn iter(&self) -> impl Iterator<Item = (u32, &ClusterData<C>)> {
205 self.clusters.iter().map(|(&id, data)| (id, data))
206 }
207
208 pub fn merge(&mut self, other: &ClusterStorage<C>, doc_id_offset: u32) {
210 for (&cluster_id, other_data) in &other.clusters {
211 self.clusters
212 .entry(cluster_id)
213 .or_default()
214 .append(other_data, doc_id_offset);
215 }
216 self.total_vectors += other.total_vectors;
217 }
218
219 pub fn clear(&mut self) {
221 self.clusters.clear();
222 self.total_vectors = 0;
223 }
224}
225
226impl<C: Clone + QuantizedCode> ClusterStorage<C> {
227 pub fn size_bytes(&self) -> usize {
229 self.clusters.values().map(|c| c.size_bytes()).sum()
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal code type for tests: a byte buffer whose size is its length.
    #[derive(Clone, Debug)]
    struct TestCode(Vec<u8>);

    impl QuantizedCode for TestCode {
        fn size_bytes(&self) -> usize {
            self.0.len()
        }
    }

    /// Adding entries without raw vectors updates the length.
    #[test]
    fn test_cluster_data_basic() {
        let mut data = ClusterData::<TestCode>::new();

        data.add(0, TestCode(vec![1, 2, 3]), None);
        data.add(1, TestCode(vec![4, 5, 6]), None);

        assert_eq!(data.len(), 2);
        assert!(!data.is_empty());
    }

    /// Supplying raw vectors creates and fills the raw-vector list.
    #[test]
    fn test_cluster_data_with_raw() {
        let mut data = ClusterData::<TestCode>::new();

        data.add(0, TestCode(vec![1]), Some(vec![1.0, 2.0, 3.0]));
        data.add(1, TestCode(vec![2]), Some(vec![4.0, 5.0, 6.0]));

        assert!(data.raw_vectors.is_some());
        assert_eq!(data.raw_vectors.as_ref().unwrap().len(), 2);
    }

    /// Appending shifts the incoming doc ids by the given offset.
    #[test]
    fn test_cluster_data_append() {
        let mut target = ClusterData::<TestCode>::new();
        target.add(0, TestCode(vec![1]), None);
        target.add(1, TestCode(vec![2]), None);

        let mut incoming = ClusterData::<TestCode>::new();
        incoming.add(0, TestCode(vec![3]), None);
        incoming.add(1, TestCode(vec![4]), None);

        target.append(&incoming, 100);

        assert_eq!(target.len(), 4);
        assert_eq!(target.doc_ids, vec![0, 1, 100, 101]);
    }

    /// Entries are routed to their cluster and counted globally.
    #[test]
    fn test_cluster_storage() {
        let mut store = ClusterStorage::<TestCode>::new();

        store.add(0, 10, TestCode(vec![1]), None);
        store.add(0, 11, TestCode(vec![2]), None);
        store.add(1, 20, TestCode(vec![3]), None);

        assert_eq!(store.num_clusters(), 2);
        assert_eq!(store.len(), 3);
        assert_eq!(store.get(0).unwrap().len(), 2);
        assert_eq!(store.get(1).unwrap().len(), 1);
    }

    /// Merging combines clusters with the same id and offsets incoming doc ids.
    #[test]
    fn test_cluster_storage_merge() {
        let mut target = ClusterStorage::<TestCode>::new();
        target.add(0, 0, TestCode(vec![1]), None);

        let mut incoming = ClusterStorage::<TestCode>::new();
        incoming.add(0, 0, TestCode(vec![2]), None);
        incoming.add(1, 0, TestCode(vec![3]), None);

        target.merge(&incoming, 100);

        assert_eq!(target.len(), 3);
        assert_eq!(target.get(0).unwrap().len(), 2);
        assert_eq!(target.get(0).unwrap().doc_ids, vec![0, 100]);
    }
}