hermes_core/structures/vector/ivf/
cluster.rs1use serde::{Deserialize, Serialize};
7
/// A quantized (compressed) code for one vector stored in an IVF cluster.
///
/// Implementors must be `Clone + Send + Sync`.
pub trait QuantizedCode
where
    Self: Clone + Send + Sync,
{
    /// Size of this code in bytes; summed by the cluster `size_bytes` methods.
    fn size_bytes(&self) -> usize;
}

14#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ClusterData<C: Clone> {
20 pub doc_ids: Vec<u32>,
22 pub ordinals: Vec<u16>,
25 pub codes: Vec<C>,
27}
28
29impl<C: Clone> Default for ClusterData<C> {
30 fn default() -> Self {
31 Self::new()
32 }
33}
34
35impl<C: Clone> ClusterData<C> {
36 pub fn new() -> Self {
37 Self {
38 doc_ids: Vec::new(),
39 ordinals: Vec::new(),
40 codes: Vec::new(),
41 }
42 }
43
44 pub fn with_capacity(capacity: usize) -> Self {
45 Self {
46 doc_ids: Vec::with_capacity(capacity),
47 ordinals: Vec::with_capacity(capacity),
48 codes: Vec::with_capacity(capacity),
49 }
50 }
51
52 pub fn len(&self) -> usize {
53 self.doc_ids.len()
54 }
55
56 pub fn is_empty(&self) -> bool {
57 self.doc_ids.is_empty()
58 }
59
60 pub fn add(&mut self, doc_id: u32, ordinal: u16, code: C) {
62 self.doc_ids.push(doc_id);
63 self.ordinals.push(ordinal);
64 self.codes.push(code);
65 }
66
67 pub fn append(&mut self, other: &ClusterData<C>, doc_id_offset: u32) {
69 for &doc_id in &other.doc_ids {
70 self.doc_ids.push(doc_id + doc_id_offset);
71 }
72 self.ordinals.extend(other.ordinals.iter().copied());
73 self.codes.extend(other.codes.iter().cloned());
74 }
75
76 pub fn iter(&self) -> impl Iterator<Item = (u32, u16, &C)> {
78 self.doc_ids
79 .iter()
80 .copied()
81 .zip(self.ordinals.iter().copied())
82 .zip(self.codes.iter())
83 .map(|((doc_id, ordinal), code)| (doc_id, ordinal, code))
84 }
85
86 pub fn clear(&mut self) {
88 self.doc_ids.clear();
89 self.ordinals.clear();
90 self.codes.clear();
91 }
92
93 pub fn reserve(&mut self, additional: usize) {
95 self.doc_ids.reserve(additional);
96 self.ordinals.reserve(additional);
97 self.codes.reserve(additional);
98 }
99}
100
101impl<C: Clone + QuantizedCode> ClusterData<C> {
102 pub fn size_bytes(&self) -> usize {
104 use std::mem::size_of;
105 let doc_ids_size = self.doc_ids.len() * size_of::<u32>();
106 let ordinals_size = self.ordinals.len() * size_of::<u16>();
107 let codes_size: usize = self.codes.iter().map(|c| c.size_bytes()).sum();
108
109 doc_ids_size + ordinals_size + codes_size
110 }
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct ClusterStorage<C: Clone> {
116 pub clusters: std::collections::HashMap<u32, ClusterData<C>>,
118 pub total_vectors: usize,
120}
121
122impl<C: Clone> Default for ClusterStorage<C> {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
128impl<C: Clone> ClusterStorage<C> {
129 pub fn new() -> Self {
130 Self {
131 clusters: std::collections::HashMap::new(),
132 total_vectors: 0,
133 }
134 }
135
136 pub fn with_capacity(num_clusters: usize) -> Self {
137 Self {
138 clusters: std::collections::HashMap::with_capacity(num_clusters),
139 total_vectors: 0,
140 }
141 }
142
143 pub fn add(&mut self, cluster_id: u32, doc_id: u32, ordinal: u16, code: C) {
145 self.clusters
146 .entry(cluster_id)
147 .or_default()
148 .add(doc_id, ordinal, code);
149 self.total_vectors += 1;
150 }
151
152 pub fn get(&self, cluster_id: u32) -> Option<&ClusterData<C>> {
154 self.clusters.get(&cluster_id)
155 }
156
157 pub fn get_mut(&mut self, cluster_id: u32) -> Option<&mut ClusterData<C>> {
159 self.clusters.get_mut(&cluster_id)
160 }
161
162 pub fn get_or_create(&mut self, cluster_id: u32) -> &mut ClusterData<C> {
164 self.clusters.entry(cluster_id).or_default()
165 }
166
167 pub fn num_clusters(&self) -> usize {
169 self.clusters.len()
170 }
171
172 pub fn len(&self) -> usize {
174 self.total_vectors
175 }
176
177 pub fn is_empty(&self) -> bool {
178 self.total_vectors == 0
179 }
180
181 pub fn iter(&self) -> impl Iterator<Item = (u32, &ClusterData<C>)> {
183 self.clusters.iter().map(|(&id, data)| (id, data))
184 }
185
186 pub fn merge(&mut self, other: &ClusterStorage<C>, doc_id_offset: u32) {
188 for (&cluster_id, other_data) in &other.clusters {
189 self.clusters
190 .entry(cluster_id)
191 .or_default()
192 .append(other_data, doc_id_offset);
193 }
194 self.total_vectors += other.total_vectors;
195 }
196
197 pub fn clear(&mut self) {
199 self.clusters.clear();
200 self.total_vectors = 0;
201 }
202}
203
204impl<C: Clone + QuantizedCode> ClusterStorage<C> {
205 pub fn size_bytes(&self) -> usize {
207 self.clusters.values().map(|c| c.size_bytes()).sum()
208 }
209
210 pub fn estimated_memory_bytes(&self) -> usize {
212 self.size_bytes()
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only code type whose reported size is its byte length.
    #[derive(Clone, Debug)]
    struct TestCode(Vec<u8>);

    impl QuantizedCode for TestCode {
        fn size_bytes(&self) -> usize {
            self.0.len()
        }
    }

    /// Builds a single-byte `TestCode`.
    fn code(byte: u8) -> TestCode {
        TestCode(vec![byte])
    }

    #[test]
    fn test_cluster_data_basic() {
        let mut cluster = ClusterData::<TestCode>::new();
        cluster.add(0, 0, TestCode(vec![1, 2, 3]));
        cluster.add(1, 0, TestCode(vec![4, 5, 6]));

        assert!(!cluster.is_empty());
        assert_eq!(cluster.len(), 2);
    }

    #[test]
    fn test_cluster_data_with_ordinals() {
        // One document contributing three vectors, distinguished by ordinal.
        let mut cluster = ClusterData::<TestCode>::new();
        for (ordinal, byte) in (0u16..).zip(1..=3u8) {
            cluster.add(0, ordinal, code(byte));
        }

        assert_eq!(cluster.len(), 3);
        assert_eq!(cluster.ordinals, vec![0, 1, 2]);
    }

    #[test]
    fn test_cluster_data_append() {
        let mut first = ClusterData::<TestCode>::new();
        first.add(0, 0, code(1));
        first.add(1, 0, code(2));

        let mut second = ClusterData::<TestCode>::new();
        second.add(0, 0, code(3));
        second.add(1, 0, code(4));

        first.append(&second, 100);

        assert_eq!(first.len(), 4);
        assert_eq!(first.doc_ids, vec![0, 1, 100, 101]);
    }

    #[test]
    fn test_cluster_storage() {
        let mut storage = ClusterStorage::<TestCode>::new();
        let entries = [(0u32, 10u32, 1u8), (0, 11, 2), (1, 20, 3)];
        for &(cluster_id, doc_id, byte) in entries.iter() {
            storage.add(cluster_id, doc_id, 0, code(byte));
        }

        assert_eq!(storage.num_clusters(), 2);
        assert_eq!(storage.len(), 3);
        assert_eq!(storage.get(0).map(|c| c.len()), Some(2));
        assert_eq!(storage.get(1).map(|c| c.len()), Some(1));
    }

    #[test]
    fn test_cluster_storage_merge() {
        let mut target = ClusterStorage::<TestCode>::new();
        target.add(0, 0, 0, code(1));

        let mut source = ClusterStorage::<TestCode>::new();
        source.add(0, 0, 0, code(2));
        source.add(1, 0, 0, code(3));

        target.merge(&source, 100);

        assert_eq!(target.len(), 3);
        let cluster0 = target.get(0).expect("cluster 0 must exist after merge");
        assert_eq!(cluster0.len(), 2);
        assert_eq!(cluster0.doc_ids, vec![0, 100]);
    }
}