// Source path: hermes_core/structures/vector/ivf/cluster.rs

//! Generic cluster data storage for IVF indexes
//!
//! Provides a unified storage structure for cluster data that works
//! with any quantization method (RaBitQ, PQ, etc.).
5
6use serde::{Deserialize, Serialize};
7
/// Trait for quantized vector codes.
///
/// Implementors must be `Clone + Send + Sync` and must be able to report how
/// many bytes a single code occupies.
pub trait QuantizedCode: Clone + Send + Sync {
    /// Returns the size of this code in bytes.
    fn size_bytes(&self) -> usize;
}
13
14/// Generic cluster data storage
15///
16/// Stores document IDs and quantized codes for vectors in a cluster.
17/// Can optionally store raw vectors for re-ranking.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ClusterData<C: Clone> {
20    /// Document IDs (local to segment)
21    pub doc_ids: Vec<u32>,
22    /// Quantized vector codes
23    pub codes: Vec<C>,
24    /// Raw vectors for re-ranking (optional)
25    pub raw_vectors: Option<Vec<Vec<f32>>>,
26}
27
28impl<C: Clone> Default for ClusterData<C> {
29    fn default() -> Self {
30        Self::new()
31    }
32}
33
34impl<C: Clone> ClusterData<C> {
35    pub fn new() -> Self {
36        Self {
37            doc_ids: Vec::new(),
38            codes: Vec::new(),
39            raw_vectors: None,
40        }
41    }
42
43    pub fn with_capacity(capacity: usize) -> Self {
44        Self {
45            doc_ids: Vec::with_capacity(capacity),
46            codes: Vec::with_capacity(capacity),
47            raw_vectors: None,
48        }
49    }
50
51    pub fn len(&self) -> usize {
52        self.doc_ids.len()
53    }
54
55    pub fn is_empty(&self) -> bool {
56        self.doc_ids.is_empty()
57    }
58
59    /// Add a vector to the cluster
60    pub fn add(&mut self, doc_id: u32, code: C, raw_vector: Option<Vec<f32>>) {
61        self.doc_ids.push(doc_id);
62        self.codes.push(code);
63
64        if let Some(raw) = raw_vector {
65            self.raw_vectors.get_or_insert_with(Vec::new).push(raw);
66        }
67    }
68
69    /// Append another cluster's data (for merging segments)
70    pub fn append(&mut self, other: &ClusterData<C>, doc_id_offset: u32) {
71        for &doc_id in &other.doc_ids {
72            self.doc_ids.push(doc_id + doc_id_offset);
73        }
74        self.codes.extend(other.codes.iter().cloned());
75
76        if let Some(ref other_raw) = other.raw_vectors {
77            let raw = self.raw_vectors.get_or_insert_with(Vec::new);
78            raw.extend(other_raw.iter().cloned());
79        }
80    }
81
82    /// Get iterator over (doc_id, code) pairs
83    pub fn iter(&self) -> impl Iterator<Item = (u32, &C)> {
84        self.doc_ids.iter().copied().zip(self.codes.iter())
85    }
86
87    /// Get iterator over (doc_id, code, optional raw_vector) tuples
88    pub fn iter_with_raw(&self) -> impl Iterator<Item = (u32, &C, Option<&Vec<f32>>)> {
89        let raw_iter = self.raw_vectors.as_ref();
90        self.doc_ids
91            .iter()
92            .copied()
93            .zip(self.codes.iter())
94            .enumerate()
95            .map(move |(i, (doc_id, code))| {
96                let raw = raw_iter.and_then(|r| r.get(i));
97                (doc_id, code, raw)
98            })
99    }
100
101    /// Clear all data
102    pub fn clear(&mut self) {
103        self.doc_ids.clear();
104        self.codes.clear();
105        if let Some(ref mut raw) = self.raw_vectors {
106            raw.clear();
107        }
108    }
109
110    /// Reserve capacity
111    pub fn reserve(&mut self, additional: usize) {
112        self.doc_ids.reserve(additional);
113        self.codes.reserve(additional);
114        if let Some(ref mut raw) = self.raw_vectors {
115            raw.reserve(additional);
116        }
117    }
118}
119
120impl<C: Clone + QuantizedCode> ClusterData<C> {
121    /// Memory usage in bytes
122    pub fn size_bytes(&self) -> usize {
123        let doc_ids_size = self.doc_ids.len() * 4;
124        let codes_size: usize = self.codes.iter().map(|c| c.size_bytes()).sum();
125        let raw_size = self
126            .raw_vectors
127            .as_ref()
128            .map(|vecs| vecs.iter().map(|v| v.len() * 4).sum())
129            .unwrap_or(0);
130
131        doc_ids_size + codes_size + raw_size
132    }
133}
134
135/// Storage for multiple clusters (HashMap wrapper with utilities)
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct ClusterStorage<C: Clone> {
138    /// Cluster data indexed by cluster ID
139    pub clusters: std::collections::HashMap<u32, ClusterData<C>>,
140    /// Total number of vectors across all clusters
141    pub total_vectors: usize,
142}
143
144impl<C: Clone> Default for ClusterStorage<C> {
145    fn default() -> Self {
146        Self::new()
147    }
148}
149
150impl<C: Clone> ClusterStorage<C> {
151    pub fn new() -> Self {
152        Self {
153            clusters: std::collections::HashMap::new(),
154            total_vectors: 0,
155        }
156    }
157
158    pub fn with_capacity(num_clusters: usize) -> Self {
159        Self {
160            clusters: std::collections::HashMap::with_capacity(num_clusters),
161            total_vectors: 0,
162        }
163    }
164
165    /// Add a vector to a cluster
166    pub fn add(&mut self, cluster_id: u32, doc_id: u32, code: C, raw_vector: Option<Vec<f32>>) {
167        self.clusters
168            .entry(cluster_id)
169            .or_default()
170            .add(doc_id, code, raw_vector);
171        self.total_vectors += 1;
172    }
173
174    /// Get cluster data
175    pub fn get(&self, cluster_id: u32) -> Option<&ClusterData<C>> {
176        self.clusters.get(&cluster_id)
177    }
178
179    /// Get mutable cluster data
180    pub fn get_mut(&mut self, cluster_id: u32) -> Option<&mut ClusterData<C>> {
181        self.clusters.get_mut(&cluster_id)
182    }
183
184    /// Get or create cluster data
185    pub fn get_or_create(&mut self, cluster_id: u32) -> &mut ClusterData<C> {
186        self.clusters.entry(cluster_id).or_default()
187    }
188
189    /// Number of non-empty clusters
190    pub fn num_clusters(&self) -> usize {
191        self.clusters.len()
192    }
193
194    /// Total number of vectors
195    pub fn len(&self) -> usize {
196        self.total_vectors
197    }
198
199    pub fn is_empty(&self) -> bool {
200        self.total_vectors == 0
201    }
202
203    /// Iterate over all clusters
204    pub fn iter(&self) -> impl Iterator<Item = (u32, &ClusterData<C>)> {
205        self.clusters.iter().map(|(&id, data)| (id, data))
206    }
207
208    /// Merge another storage into this one
209    pub fn merge(&mut self, other: &ClusterStorage<C>, doc_id_offset: u32) {
210        for (&cluster_id, other_data) in &other.clusters {
211            self.clusters
212                .entry(cluster_id)
213                .or_default()
214                .append(other_data, doc_id_offset);
215        }
216        self.total_vectors += other.total_vectors;
217    }
218
219    /// Clear all clusters
220    pub fn clear(&mut self) {
221        self.clusters.clear();
222        self.total_vectors = 0;
223    }
224}
225
226impl<C: Clone + QuantizedCode> ClusterStorage<C> {
227    /// Total memory usage in bytes
228    pub fn size_bytes(&self) -> usize {
229        self.clusters.values().map(|c| c.size_bytes()).sum()
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal code type: the payload bytes are the code, and their count is
    // the reported size.
    #[derive(Clone, Debug)]
    struct TestCode(Vec<u8>);

    impl QuantizedCode for TestCode {
        fn size_bytes(&self) -> usize {
            self.0.len()
        }
    }

    #[test]
    fn test_cluster_data_basic() {
        let mut cluster: ClusterData<TestCode> = ClusterData::new();

        cluster.add(0, TestCode(vec![1, 2, 3]), None);
        cluster.add(1, TestCode(vec![4, 5, 6]), None);

        assert_eq!(cluster.len(), 2);
        assert!(!cluster.is_empty());
    }

    #[test]
    fn test_cluster_data_with_raw() {
        let mut cluster: ClusterData<TestCode> = ClusterData::new();

        cluster.add(0, TestCode(vec![1]), Some(vec![1.0, 2.0, 3.0]));
        cluster.add(1, TestCode(vec![2]), Some(vec![4.0, 5.0, 6.0]));

        assert!(cluster.raw_vectors.is_some());
        assert_eq!(cluster.raw_vectors.as_ref().unwrap().len(), 2);
    }

    #[test]
    fn test_cluster_data_iter_with_raw() {
        let mut cluster: ClusterData<TestCode> = ClusterData::new();

        cluster.add(7, TestCode(vec![1]), Some(vec![1.0, 2.0]));
        cluster.add(9, TestCode(vec![2]), Some(vec![3.0, 4.0]));

        // Each entry must come back paired with the raw vector it was added with.
        let rows: Vec<(u32, Vec<u8>, Option<Vec<f32>>)> = cluster
            .iter_with_raw()
            .map(|(doc_id, code, raw)| (doc_id, code.0.clone(), raw.cloned()))
            .collect();

        assert_eq!(
            rows,
            vec![
                (7, vec![1], Some(vec![1.0, 2.0])),
                (9, vec![2], Some(vec![3.0, 4.0])),
            ]
        );
    }

    #[test]
    fn test_cluster_data_append() {
        let mut cluster1: ClusterData<TestCode> = ClusterData::new();
        cluster1.add(0, TestCode(vec![1]), None);
        cluster1.add(1, TestCode(vec![2]), None);

        let mut cluster2: ClusterData<TestCode> = ClusterData::new();
        cluster2.add(0, TestCode(vec![3]), None);
        cluster2.add(1, TestCode(vec![4]), None);

        cluster1.append(&cluster2, 100);

        assert_eq!(cluster1.len(), 4);
        assert_eq!(cluster1.doc_ids, vec![0, 1, 100, 101]);
    }

    #[test]
    fn test_cluster_storage() {
        let mut storage: ClusterStorage<TestCode> = ClusterStorage::new();

        storage.add(0, 10, TestCode(vec![1]), None);
        storage.add(0, 11, TestCode(vec![2]), None);
        storage.add(1, 20, TestCode(vec![3]), None);

        assert_eq!(storage.num_clusters(), 2);
        assert_eq!(storage.len(), 3);
        assert_eq!(storage.get(0).unwrap().len(), 2);
        assert_eq!(storage.get(1).unwrap().len(), 1);
    }

    #[test]
    fn test_cluster_storage_merge() {
        let mut storage1: ClusterStorage<TestCode> = ClusterStorage::new();
        storage1.add(0, 0, TestCode(vec![1]), None);

        let mut storage2: ClusterStorage<TestCode> = ClusterStorage::new();
        storage2.add(0, 0, TestCode(vec![2]), None);
        storage2.add(1, 0, TestCode(vec![3]), None);

        storage1.merge(&storage2, 100);

        assert_eq!(storage1.len(), 3);
        assert_eq!(storage1.num_clusters(), 2);
        assert_eq!(storage1.get(0).unwrap().len(), 2);
        assert_eq!(storage1.get(0).unwrap().doc_ids, vec![0, 100]);
        // A cluster that only existed in the merged-in storage is created
        // here with its doc IDs shifted.
        assert_eq!(storage1.get(1).unwrap().doc_ids, vec![100]);
    }

    #[test]
    fn test_size_bytes() {
        let mut storage: ClusterStorage<TestCode> = ClusterStorage::new();

        storage.add(0, 1, TestCode(vec![0; 3]), Some(vec![0.0; 2]));
        storage.add(1, 2, TestCode(vec![0; 5]), None);

        // Cluster 0: 4 (doc id) + 3 (code) + 2 * 4 (raw components) = 15.
        assert_eq!(storage.get(0).unwrap().size_bytes(), 15);
        // Cluster 1: 4 (doc id) + 5 (code) = 9.
        assert_eq!(storage.get(1).unwrap().size_bytes(), 9);
        assert_eq!(storage.size_bytes(), 24);
    }
}
314}