Skip to main content

hermes_core/structures/vector/ivf/
cluster.rs

1//! Generic cluster data storage for IVF indexes
2//!
3//! Provides a unified storage structure for cluster data that works
4//! with any quantization method (RaBitQ, PQ, etc.).
5
6use serde::{Deserialize, Serialize};
7
/// A quantized representation of one vector (for example a RaBitQ or PQ code).
///
/// Codes must be clonable because cluster storage copies them when clusters
/// are merged, and shareable across threads (`Send + Sync`).
pub trait QuantizedCode: Send + Sync + Clone {
    /// Number of bytes this code occupies, used for memory accounting.
    fn size_bytes(&self) -> usize;
}
13
14/// Generic cluster data storage
15///
16/// Stores document IDs, element ordinals, and quantized codes for vectors in a cluster.
17/// Can optionally store raw vectors for re-ranking.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ClusterData<C: Clone> {
20    /// Document IDs (local to segment)
21    pub doc_ids: Vec<u32>,
22    /// Element ordinals for multi-valued fields (0 for single-valued)
23    /// Stored as u16 to support up to 65535 values per document per field
24    pub ordinals: Vec<u16>,
25    /// Quantized vector codes
26    pub codes: Vec<C>,
27    /// Raw vectors for re-ranking (optional)
28    pub raw_vectors: Option<Vec<Vec<f32>>>,
29}
30
31impl<C: Clone> Default for ClusterData<C> {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl<C: Clone> ClusterData<C> {
38    pub fn new() -> Self {
39        Self {
40            doc_ids: Vec::new(),
41            ordinals: Vec::new(),
42            codes: Vec::new(),
43            raw_vectors: None,
44        }
45    }
46
47    pub fn with_capacity(capacity: usize) -> Self {
48        Self {
49            doc_ids: Vec::with_capacity(capacity),
50            ordinals: Vec::with_capacity(capacity),
51            codes: Vec::with_capacity(capacity),
52            raw_vectors: None,
53        }
54    }
55
56    pub fn len(&self) -> usize {
57        self.doc_ids.len()
58    }
59
60    pub fn is_empty(&self) -> bool {
61        self.doc_ids.is_empty()
62    }
63
64    /// Add a vector to the cluster
65    pub fn add(&mut self, doc_id: u32, ordinal: u16, code: C, raw_vector: Option<Vec<f32>>) {
66        self.doc_ids.push(doc_id);
67        self.ordinals.push(ordinal);
68        self.codes.push(code);
69
70        if let Some(raw) = raw_vector {
71            self.raw_vectors.get_or_insert_with(Vec::new).push(raw);
72        }
73    }
74
75    /// Append another cluster's data (for merging segments)
76    pub fn append(&mut self, other: &ClusterData<C>, doc_id_offset: u32) {
77        for &doc_id in &other.doc_ids {
78            self.doc_ids.push(doc_id + doc_id_offset);
79        }
80        self.ordinals.extend(other.ordinals.iter().copied());
81        self.codes.extend(other.codes.iter().cloned());
82
83        if let Some(ref other_raw) = other.raw_vectors {
84            let raw = self.raw_vectors.get_or_insert_with(Vec::new);
85            raw.extend(other_raw.iter().cloned());
86        }
87    }
88
89    /// Get iterator over (doc_id, ordinal, code) tuples
90    pub fn iter(&self) -> impl Iterator<Item = (u32, u16, &C)> {
91        self.doc_ids
92            .iter()
93            .copied()
94            .zip(self.ordinals.iter().copied())
95            .zip(self.codes.iter())
96            .map(|((doc_id, ordinal), code)| (doc_id, ordinal, code))
97    }
98
99    /// Get iterator over (doc_id, ordinal, code, optional raw_vector) tuples
100    pub fn iter_with_raw(&self) -> impl Iterator<Item = (u32, u16, &C, Option<&Vec<f32>>)> {
101        let raw_iter = self.raw_vectors.as_ref();
102        self.doc_ids
103            .iter()
104            .copied()
105            .zip(self.ordinals.iter().copied())
106            .zip(self.codes.iter())
107            .enumerate()
108            .map(move |(i, ((doc_id, ordinal), code))| {
109                let raw = raw_iter.and_then(|r| r.get(i));
110                (doc_id, ordinal, code, raw)
111            })
112    }
113
114    /// Clear all data
115    pub fn clear(&mut self) {
116        self.doc_ids.clear();
117        self.ordinals.clear();
118        self.codes.clear();
119        if let Some(ref mut raw) = self.raw_vectors {
120            raw.clear();
121        }
122    }
123
124    /// Reserve capacity
125    pub fn reserve(&mut self, additional: usize) {
126        self.doc_ids.reserve(additional);
127        self.ordinals.reserve(additional);
128        self.codes.reserve(additional);
129        if let Some(ref mut raw) = self.raw_vectors {
130            raw.reserve(additional);
131        }
132    }
133}
134
135impl<C: Clone + QuantizedCode> ClusterData<C> {
136    /// Memory usage in bytes
137    pub fn size_bytes(&self) -> usize {
138        use std::mem::size_of;
139        let doc_ids_size = self.doc_ids.len() * size_of::<u32>();
140        let ordinals_size = self.ordinals.len() * size_of::<u16>();
141        let codes_size: usize = self.codes.iter().map(|c| c.size_bytes()).sum();
142        let raw_size = self
143            .raw_vectors
144            .as_ref()
145            .map(|vecs| vecs.iter().map(|v| v.len() * size_of::<f32>()).sum())
146            .unwrap_or(0);
147
148        doc_ids_size + ordinals_size + codes_size + raw_size
149    }
150}
151
152/// Storage for multiple clusters (HashMap wrapper with utilities)
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct ClusterStorage<C: Clone> {
155    /// Cluster data indexed by cluster ID
156    pub clusters: std::collections::HashMap<u32, ClusterData<C>>,
157    /// Total number of vectors across all clusters
158    pub total_vectors: usize,
159}
160
161impl<C: Clone> Default for ClusterStorage<C> {
162    fn default() -> Self {
163        Self::new()
164    }
165}
166
167impl<C: Clone> ClusterStorage<C> {
168    pub fn new() -> Self {
169        Self {
170            clusters: std::collections::HashMap::new(),
171            total_vectors: 0,
172        }
173    }
174
175    pub fn with_capacity(num_clusters: usize) -> Self {
176        Self {
177            clusters: std::collections::HashMap::with_capacity(num_clusters),
178            total_vectors: 0,
179        }
180    }
181
182    /// Add a vector to a cluster
183    pub fn add(
184        &mut self,
185        cluster_id: u32,
186        doc_id: u32,
187        ordinal: u16,
188        code: C,
189        raw_vector: Option<Vec<f32>>,
190    ) {
191        self.clusters
192            .entry(cluster_id)
193            .or_default()
194            .add(doc_id, ordinal, code, raw_vector);
195        self.total_vectors += 1;
196    }
197
198    /// Get cluster data
199    pub fn get(&self, cluster_id: u32) -> Option<&ClusterData<C>> {
200        self.clusters.get(&cluster_id)
201    }
202
203    /// Get mutable cluster data
204    pub fn get_mut(&mut self, cluster_id: u32) -> Option<&mut ClusterData<C>> {
205        self.clusters.get_mut(&cluster_id)
206    }
207
208    /// Get or create cluster data
209    pub fn get_or_create(&mut self, cluster_id: u32) -> &mut ClusterData<C> {
210        self.clusters.entry(cluster_id).or_default()
211    }
212
213    /// Number of non-empty clusters
214    pub fn num_clusters(&self) -> usize {
215        self.clusters.len()
216    }
217
218    /// Total number of vectors
219    pub fn len(&self) -> usize {
220        self.total_vectors
221    }
222
223    pub fn is_empty(&self) -> bool {
224        self.total_vectors == 0
225    }
226
227    /// Iterate over all clusters
228    pub fn iter(&self) -> impl Iterator<Item = (u32, &ClusterData<C>)> {
229        self.clusters.iter().map(|(&id, data)| (id, data))
230    }
231
232    /// Merge another storage into this one
233    pub fn merge(&mut self, other: &ClusterStorage<C>, doc_id_offset: u32) {
234        for (&cluster_id, other_data) in &other.clusters {
235            self.clusters
236                .entry(cluster_id)
237                .or_default()
238                .append(other_data, doc_id_offset);
239        }
240        self.total_vectors += other.total_vectors;
241    }
242
243    /// Clear all clusters
244    pub fn clear(&mut self) {
245        self.clusters.clear();
246        self.total_vectors = 0;
247    }
248}
249
250impl<C: Clone + QuantizedCode> ClusterStorage<C> {
251    /// Total memory usage in bytes
252    pub fn size_bytes(&self) -> usize {
253        self.clusters.values().map(|c| c.size_bytes()).sum()
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Toy quantized code for tests: its reported size is its byte length.
    #[derive(Clone, Debug)]
    struct TestCode(Vec<u8>);

    impl QuantizedCode for TestCode {
        fn size_bytes(&self) -> usize {
            let TestCode(bytes) = self;
            bytes.len()
        }
    }

    /// Builds a single-valued cluster where entry `i` has doc id `i` and a
    /// one-byte code `codes[i]`.
    fn single_valued(codes: &[u8]) -> ClusterData<TestCode> {
        let mut data = ClusterData::new();
        for (doc, &byte) in codes.iter().enumerate() {
            data.add(doc as u32, 0, TestCode(vec![byte]), None);
        }
        data
    }

    #[test]
    fn test_cluster_data_basic() {
        let mut data = ClusterData::<TestCode>::new();
        data.add(0, 0, TestCode(vec![1, 2, 3]), None);
        data.add(1, 0, TestCode(vec![4, 5, 6]), None);

        assert!(!data.is_empty());
        assert_eq!(data.len(), 2);
    }

    #[test]
    fn test_cluster_data_with_raw() {
        let mut data = ClusterData::<TestCode>::new();
        data.add(0, 0, TestCode(vec![1]), Some(vec![1.0, 2.0, 3.0]));
        data.add(1, 0, TestCode(vec![2]), Some(vec![4.0, 5.0, 6.0]));

        let raw = data
            .raw_vectors
            .as_ref()
            .expect("raw vectors were supplied for every entry");
        assert_eq!(raw.len(), 2);
    }

    #[test]
    fn test_cluster_data_with_ordinals() {
        let mut data = ClusterData::<TestCode>::new();

        // A multi-valued field: document 0 contributes three vectors.
        for ordinal in 0u16..3 {
            data.add(0, ordinal, TestCode(vec![ordinal as u8 + 1]), None);
        }

        assert_eq!(data.len(), 3);
        assert_eq!(data.ordinals, vec![0, 1, 2]);
    }

    #[test]
    fn test_cluster_data_append() {
        let mut left = single_valued(&[1, 2]);
        let right = single_valued(&[3, 4]);

        left.append(&right, 100);

        assert_eq!(left.len(), 4);
        assert_eq!(left.doc_ids, vec![0, 1, 100, 101]);
    }

    #[test]
    fn test_cluster_storage() {
        let mut storage = ClusterStorage::<TestCode>::new();
        let inserts = [(0u32, 10u32, 1u8), (0, 11, 2), (1, 20, 3)];
        for &(cluster, doc, byte) in inserts.iter() {
            storage.add(cluster, doc, 0, TestCode(vec![byte]), None);
        }

        assert_eq!(storage.num_clusters(), 2);
        assert_eq!(storage.len(), 3);
        assert_eq!(storage.get(0).map(ClusterData::len), Some(2));
        assert_eq!(storage.get(1).map(ClusterData::len), Some(1));
    }

    #[test]
    fn test_cluster_storage_merge() {
        let mut target = ClusterStorage::<TestCode>::new();
        target.add(0, 0, 0, TestCode(vec![1]), None);

        let mut source = ClusterStorage::<TestCode>::new();
        source.add(0, 0, 0, TestCode(vec![2]), None);
        source.add(1, 0, 0, TestCode(vec![3]), None);

        target.merge(&source, 100);

        assert_eq!(target.len(), 3);
        let merged = target.get(0).expect("cluster 0 exists after merge");
        assert_eq!(merged.len(), 2);
        assert_eq!(merged.doc_ids, vec![0, 100]);
    }
}