hermes_core/structures/vector/ivf/
cluster.rs1use serde::{Deserialize, Serialize};
7
/// A per-entry code stored in a cluster that can report its own payload size.
///
/// The `Clone` supertrait lets codes be duplicated (for example when one cluster is appended
/// to another); `Send + Sync` let implementors move or be shared across threads.
pub trait QuantizedCode: Clone + Send + Sync {
    /// Number of payload bytes this code occupies.
    fn size_bytes(&self) -> usize;
}
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ClusterData<C: Clone> {
20 pub doc_ids: Vec<u32>,
22 pub ordinals: Vec<u16>,
25 pub codes: Vec<C>,
27 pub raw_vectors: Option<Vec<Vec<f32>>>,
29}
30
31impl<C: Clone> Default for ClusterData<C> {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl<C: Clone> ClusterData<C> {
38 pub fn new() -> Self {
39 Self {
40 doc_ids: Vec::new(),
41 ordinals: Vec::new(),
42 codes: Vec::new(),
43 raw_vectors: None,
44 }
45 }
46
47 pub fn with_capacity(capacity: usize) -> Self {
48 Self {
49 doc_ids: Vec::with_capacity(capacity),
50 ordinals: Vec::with_capacity(capacity),
51 codes: Vec::with_capacity(capacity),
52 raw_vectors: None,
53 }
54 }
55
56 pub fn len(&self) -> usize {
57 self.doc_ids.len()
58 }
59
60 pub fn is_empty(&self) -> bool {
61 self.doc_ids.is_empty()
62 }
63
64 pub fn add(&mut self, doc_id: u32, ordinal: u16, code: C, raw_vector: Option<Vec<f32>>) {
66 self.doc_ids.push(doc_id);
67 self.ordinals.push(ordinal);
68 self.codes.push(code);
69
70 if let Some(raw) = raw_vector {
71 self.raw_vectors.get_or_insert_with(Vec::new).push(raw);
72 }
73 }
74
75 pub fn append(&mut self, other: &ClusterData<C>, doc_id_offset: u32) {
77 for &doc_id in &other.doc_ids {
78 self.doc_ids.push(doc_id + doc_id_offset);
79 }
80 self.ordinals.extend(other.ordinals.iter().copied());
81 self.codes.extend(other.codes.iter().cloned());
82
83 if let Some(ref other_raw) = other.raw_vectors {
84 let raw = self.raw_vectors.get_or_insert_with(Vec::new);
85 raw.extend(other_raw.iter().cloned());
86 }
87 }
88
89 pub fn iter(&self) -> impl Iterator<Item = (u32, u16, &C)> {
91 self.doc_ids
92 .iter()
93 .copied()
94 .zip(self.ordinals.iter().copied())
95 .zip(self.codes.iter())
96 .map(|((doc_id, ordinal), code)| (doc_id, ordinal, code))
97 }
98
99 pub fn iter_with_raw(&self) -> impl Iterator<Item = (u32, u16, &C, Option<&Vec<f32>>)> {
101 let raw_iter = self.raw_vectors.as_ref();
102 self.doc_ids
103 .iter()
104 .copied()
105 .zip(self.ordinals.iter().copied())
106 .zip(self.codes.iter())
107 .enumerate()
108 .map(move |(i, ((doc_id, ordinal), code))| {
109 let raw = raw_iter.and_then(|r| r.get(i));
110 (doc_id, ordinal, code, raw)
111 })
112 }
113
114 pub fn clear(&mut self) {
116 self.doc_ids.clear();
117 self.ordinals.clear();
118 self.codes.clear();
119 if let Some(ref mut raw) = self.raw_vectors {
120 raw.clear();
121 }
122 }
123
124 pub fn reserve(&mut self, additional: usize) {
126 self.doc_ids.reserve(additional);
127 self.ordinals.reserve(additional);
128 self.codes.reserve(additional);
129 if let Some(ref mut raw) = self.raw_vectors {
130 raw.reserve(additional);
131 }
132 }
133}
134
135impl<C: Clone + QuantizedCode> ClusterData<C> {
136 pub fn size_bytes(&self) -> usize {
138 use std::mem::size_of;
139 let doc_ids_size = self.doc_ids.len() * size_of::<u32>();
140 let ordinals_size = self.ordinals.len() * size_of::<u16>();
141 let codes_size: usize = self.codes.iter().map(|c| c.size_bytes()).sum();
142 let raw_size = self
143 .raw_vectors
144 .as_ref()
145 .map(|vecs| vecs.iter().map(|v| v.len() * size_of::<f32>()).sum())
146 .unwrap_or(0);
147
148 doc_ids_size + ordinals_size + codes_size + raw_size
149 }
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct ClusterStorage<C: Clone> {
155 pub clusters: std::collections::HashMap<u32, ClusterData<C>>,
157 pub total_vectors: usize,
159}
160
161impl<C: Clone> Default for ClusterStorage<C> {
162 fn default() -> Self {
163 Self::new()
164 }
165}
166
167impl<C: Clone> ClusterStorage<C> {
168 pub fn new() -> Self {
169 Self {
170 clusters: std::collections::HashMap::new(),
171 total_vectors: 0,
172 }
173 }
174
175 pub fn with_capacity(num_clusters: usize) -> Self {
176 Self {
177 clusters: std::collections::HashMap::with_capacity(num_clusters),
178 total_vectors: 0,
179 }
180 }
181
182 pub fn add(
184 &mut self,
185 cluster_id: u32,
186 doc_id: u32,
187 ordinal: u16,
188 code: C,
189 raw_vector: Option<Vec<f32>>,
190 ) {
191 self.clusters
192 .entry(cluster_id)
193 .or_default()
194 .add(doc_id, ordinal, code, raw_vector);
195 self.total_vectors += 1;
196 }
197
198 pub fn get(&self, cluster_id: u32) -> Option<&ClusterData<C>> {
200 self.clusters.get(&cluster_id)
201 }
202
203 pub fn get_mut(&mut self, cluster_id: u32) -> Option<&mut ClusterData<C>> {
205 self.clusters.get_mut(&cluster_id)
206 }
207
208 pub fn get_or_create(&mut self, cluster_id: u32) -> &mut ClusterData<C> {
210 self.clusters.entry(cluster_id).or_default()
211 }
212
213 pub fn num_clusters(&self) -> usize {
215 self.clusters.len()
216 }
217
218 pub fn len(&self) -> usize {
220 self.total_vectors
221 }
222
223 pub fn is_empty(&self) -> bool {
224 self.total_vectors == 0
225 }
226
227 pub fn iter(&self) -> impl Iterator<Item = (u32, &ClusterData<C>)> {
229 self.clusters.iter().map(|(&id, data)| (id, data))
230 }
231
232 pub fn merge(&mut self, other: &ClusterStorage<C>, doc_id_offset: u32) {
234 for (&cluster_id, other_data) in &other.clusters {
235 self.clusters
236 .entry(cluster_id)
237 .or_default()
238 .append(other_data, doc_id_offset);
239 }
240 self.total_vectors += other.total_vectors;
241 }
242
243 pub fn clear(&mut self) {
245 self.clusters.clear();
246 self.total_vectors = 0;
247 }
248}
249
250impl<C: Clone + QuantizedCode> ClusterStorage<C> {
251 pub fn size_bytes(&self) -> usize {
253 self.clusters.values().map(|c| c.size_bytes()).sum()
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal code type whose reported size is the number of bytes it holds.
    #[derive(Clone, Debug)]
    struct TestCode(Vec<u8>);

    impl QuantizedCode for TestCode {
        fn size_bytes(&self) -> usize {
            self.0.len()
        }
    }

    #[test]
    fn test_cluster_data_basic() {
        let mut cluster: ClusterData<TestCode> = ClusterData::new();

        cluster.add(0, 0, TestCode(vec![1, 2, 3]), None);
        cluster.add(1, 0, TestCode(vec![4, 5, 6]), None);

        assert_eq!(cluster.len(), 2);
        assert!(!cluster.is_empty());
    }

    #[test]
    fn test_cluster_data_with_raw() {
        let mut cluster: ClusterData<TestCode> = ClusterData::new();

        cluster.add(0, 0, TestCode(vec![1]), Some(vec![1.0, 2.0, 3.0]));
        cluster.add(1, 0, TestCode(vec![2]), Some(vec![4.0, 5.0, 6.0]));

        assert!(cluster.raw_vectors.is_some());
        assert_eq!(cluster.raw_vectors.as_ref().unwrap().len(), 2);
    }

    #[test]
    fn test_cluster_data_with_ordinals() {
        let mut cluster: ClusterData<TestCode> = ClusterData::new();

        // Same document, several ordinals.
        cluster.add(0, 0, TestCode(vec![1]), None);
        cluster.add(0, 1, TestCode(vec![2]), None);
        cluster.add(0, 2, TestCode(vec![3]), None);

        assert_eq!(cluster.len(), 3);
        assert_eq!(cluster.ordinals, vec![0, 1, 2]);
    }

    #[test]
    fn test_cluster_data_append() {
        let mut cluster1: ClusterData<TestCode> = ClusterData::new();
        cluster1.add(0, 0, TestCode(vec![1]), None);
        cluster1.add(1, 0, TestCode(vec![2]), None);

        let mut cluster2: ClusterData<TestCode> = ClusterData::new();
        cluster2.add(0, 0, TestCode(vec![3]), None);
        cluster2.add(1, 0, TestCode(vec![4]), None);

        cluster1.append(&cluster2, 100);

        assert_eq!(cluster1.len(), 4);
        assert_eq!(cluster1.doc_ids, vec![0, 1, 100, 101]);
    }

    #[test]
    fn test_cluster_data_append_with_raw() {
        let mut cluster1: ClusterData<TestCode> = ClusterData::new();
        cluster1.add(0, 0, TestCode(vec![1]), Some(vec![1.0]));

        let mut cluster2: ClusterData<TestCode> = ClusterData::new();
        cluster2.add(3, 0, TestCode(vec![2]), Some(vec![2.0]));

        cluster1.append(&cluster2, 10);

        assert_eq!(cluster1.doc_ids, vec![0, 13]);
        assert_eq!(cluster1.raw_vectors, Some(vec![vec![1.0], vec![2.0]]));
    }

    #[test]
    fn test_cluster_data_iter_order() {
        let mut cluster: ClusterData<TestCode> = ClusterData::new();
        cluster.add(7, 1, TestCode(vec![1]), None);
        cluster.add(9, 2, TestCode(vec![2, 3]), None);

        let seen: Vec<(u32, u16, Vec<u8>)> = cluster
            .iter()
            .map(|(doc_id, ordinal, code)| (doc_id, ordinal, code.0.clone()))
            .collect();
        assert_eq!(seen, vec![(7, 1, vec![1]), (9, 2, vec![2, 3])]);
    }

    #[test]
    fn test_cluster_data_iter_with_raw() {
        let mut with_raw: ClusterData<TestCode> = ClusterData::new();
        with_raw.add(0, 0, TestCode(vec![1]), Some(vec![1.0, 2.0]));
        with_raw.add(1, 0, TestCode(vec![2]), Some(vec![3.0, 4.0]));

        let raws: Vec<Option<Vec<f32>>> = with_raw
            .iter_with_raw()
            .map(|(_, _, _, raw)| raw.cloned())
            .collect();
        assert_eq!(raws, vec![Some(vec![1.0, 2.0]), Some(vec![3.0, 4.0])]);

        let mut without_raw: ClusterData<TestCode> = ClusterData::new();
        without_raw.add(0, 0, TestCode(vec![1]), None);
        assert!(without_raw.iter_with_raw().all(|(_, _, _, raw)| raw.is_none()));
    }

    #[test]
    fn test_cluster_data_clear() {
        let mut cluster: ClusterData<TestCode> = ClusterData::new();
        cluster.add(0, 0, TestCode(vec![1]), Some(vec![1.0]));
        cluster.add(1, 0, TestCode(vec![2]), None);

        cluster.clear();

        assert_eq!(cluster.len(), 0);
        assert!(cluster.is_empty());
        assert!(cluster.ordinals.is_empty());
        assert!(cluster.codes.is_empty());
    }

    #[test]
    fn test_cluster_data_size_bytes() {
        let mut cluster: ClusterData<TestCode> = ClusterData::new();
        cluster.add(0, 0, TestCode(vec![1, 2, 3]), Some(vec![1.0, 2.0]));
        cluster.add(1, 0, TestCode(vec![4]), Some(vec![3.0, 4.0]));

        // 2 doc ids * 4 + 2 ordinals * 2 + code bytes (3 + 1) + 4 floats * 4
        assert_eq!(cluster.size_bytes(), 8 + 4 + 4 + 16);
    }

    #[test]
    fn test_cluster_storage() {
        let mut storage: ClusterStorage<TestCode> = ClusterStorage::new();

        storage.add(0, 10, 0, TestCode(vec![1]), None);
        storage.add(0, 11, 0, TestCode(vec![2]), None);
        storage.add(1, 20, 0, TestCode(vec![3]), None);

        assert_eq!(storage.num_clusters(), 2);
        assert_eq!(storage.len(), 3);
        assert_eq!(storage.get(0).unwrap().len(), 2);
        assert_eq!(storage.get(1).unwrap().len(), 1);
    }

    #[test]
    fn test_cluster_storage_iter() {
        let mut storage: ClusterStorage<TestCode> = ClusterStorage::new();
        storage.add(3, 1, 0, TestCode(vec![1]), None);
        storage.add(3, 2, 0, TestCode(vec![2]), None);
        storage.add(8, 3, 0, TestCode(vec![3]), None);

        let mut sizes: Vec<(u32, usize)> = storage
            .iter()
            .map(|(cluster_id, data)| (cluster_id, data.len()))
            .collect();
        sizes.sort_unstable();
        assert_eq!(sizes, vec![(3, 2), (8, 1)]);
    }

    #[test]
    fn test_cluster_storage_size_bytes_and_clear() {
        let mut storage: ClusterStorage<TestCode> = ClusterStorage::new();
        storage.add(0, 1, 0, TestCode(vec![1, 2]), None);
        storage.add(1, 2, 0, TestCode(vec![3]), None);

        // Per entry: 4 (doc id) + 2 (ordinal) + code bytes.
        assert_eq!(storage.size_bytes(), (4 + 2 + 2) + (4 + 2 + 1));

        storage.clear();
        assert_eq!(storage.num_clusters(), 0);
        assert!(storage.is_empty());
    }

    #[test]
    fn test_cluster_storage_merge() {
        let mut storage1: ClusterStorage<TestCode> = ClusterStorage::new();
        storage1.add(0, 0, 0, TestCode(vec![1]), None);

        let mut storage2: ClusterStorage<TestCode> = ClusterStorage::new();
        storage2.add(0, 0, 0, TestCode(vec![2]), None);
        storage2.add(1, 0, 0, TestCode(vec![3]), None);

        storage1.merge(&storage2, 100);

        assert_eq!(storage1.len(), 3);
        assert_eq!(storage1.get(0).unwrap().len(), 2);
        assert_eq!(storage1.get(0).unwrap().doc_ids, vec![0, 100]);
    }
}