Skip to main content

hermes_core/structures/vector/ivf/
cluster.rs

1//! Generic cluster data storage for IVF indexes
2//!
3//! Provides a unified storage structure for cluster data that works
4//! with any quantization method (RaBitQ, PQ, etc.).
5
6use serde::{Deserialize, Serialize};
7
/// A quantized representation of one vector, as stored inside an IVF cluster.
///
/// Implementors must be `Clone + Send + Sync` so codes can be copied between
/// clusters and shared across threads.
pub trait QuantizedCode
where
    Self: Clone + Send + Sync,
{
    /// Number of bytes occupied by this code's payload.
    fn size_bytes(&self) -> usize;
}
13
14/// Generic cluster data storage
15///
16/// Stores document IDs, element ordinals, and quantized codes for vectors in a cluster.
17/// Reranking uses raw vectors from the document store, not from the index.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ClusterData<C: Clone> {
20    /// Document IDs (local to segment)
21    pub doc_ids: Vec<u32>,
22    /// Element ordinals for multi-valued fields (0 for single-valued)
23    /// Stored as u16 to support up to 65535 values per document per field
24    pub ordinals: Vec<u16>,
25    /// Quantized vector codes
26    pub codes: Vec<C>,
27}
28
29impl<C: Clone> Default for ClusterData<C> {
30    fn default() -> Self {
31        Self::new()
32    }
33}
34
35impl<C: Clone> ClusterData<C> {
36    pub fn new() -> Self {
37        Self {
38            doc_ids: Vec::new(),
39            ordinals: Vec::new(),
40            codes: Vec::new(),
41        }
42    }
43
44    pub fn with_capacity(capacity: usize) -> Self {
45        Self {
46            doc_ids: Vec::with_capacity(capacity),
47            ordinals: Vec::with_capacity(capacity),
48            codes: Vec::with_capacity(capacity),
49        }
50    }
51
52    pub fn len(&self) -> usize {
53        self.doc_ids.len()
54    }
55
56    pub fn is_empty(&self) -> bool {
57        self.doc_ids.is_empty()
58    }
59
60    /// Add a vector to the cluster
61    pub fn add(&mut self, doc_id: u32, ordinal: u16, code: C) {
62        self.doc_ids.push(doc_id);
63        self.ordinals.push(ordinal);
64        self.codes.push(code);
65    }
66
67    /// Append another cluster's data (for merging segments)
68    pub fn append(&mut self, other: &ClusterData<C>, doc_id_offset: u32) {
69        for &doc_id in &other.doc_ids {
70            self.doc_ids.push(doc_id + doc_id_offset);
71        }
72        self.ordinals.extend(other.ordinals.iter().copied());
73        self.codes.extend(other.codes.iter().cloned());
74    }
75
76    /// Get iterator over (doc_id, ordinal, code) tuples
77    pub fn iter(&self) -> impl Iterator<Item = (u32, u16, &C)> {
78        self.doc_ids
79            .iter()
80            .copied()
81            .zip(self.ordinals.iter().copied())
82            .zip(self.codes.iter())
83            .map(|((doc_id, ordinal), code)| (doc_id, ordinal, code))
84    }
85
86    /// Clear all data
87    pub fn clear(&mut self) {
88        self.doc_ids.clear();
89        self.ordinals.clear();
90        self.codes.clear();
91    }
92
93    /// Reserve capacity
94    pub fn reserve(&mut self, additional: usize) {
95        self.doc_ids.reserve(additional);
96        self.ordinals.reserve(additional);
97        self.codes.reserve(additional);
98    }
99}
100
101impl<C: Clone + QuantizedCode> ClusterData<C> {
102    /// Memory usage in bytes
103    pub fn size_bytes(&self) -> usize {
104        use std::mem::size_of;
105        let doc_ids_size = self.doc_ids.len() * size_of::<u32>();
106        let ordinals_size = self.ordinals.len() * size_of::<u16>();
107        let codes_size: usize = self.codes.iter().map(|c| c.size_bytes()).sum();
108
109        doc_ids_size + ordinals_size + codes_size
110    }
111}
112
113/// Storage for multiple clusters (HashMap wrapper with utilities)
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct ClusterStorage<C: Clone> {
116    /// Cluster data indexed by cluster ID
117    pub clusters: std::collections::HashMap<u32, ClusterData<C>>,
118    /// Total number of vectors across all clusters
119    pub total_vectors: usize,
120}
121
122impl<C: Clone> Default for ClusterStorage<C> {
123    fn default() -> Self {
124        Self::new()
125    }
126}
127
128impl<C: Clone> ClusterStorage<C> {
129    pub fn new() -> Self {
130        Self {
131            clusters: std::collections::HashMap::new(),
132            total_vectors: 0,
133        }
134    }
135
136    pub fn with_capacity(num_clusters: usize) -> Self {
137        Self {
138            clusters: std::collections::HashMap::with_capacity(num_clusters),
139            total_vectors: 0,
140        }
141    }
142
143    /// Add a vector to a cluster
144    pub fn add(&mut self, cluster_id: u32, doc_id: u32, ordinal: u16, code: C) {
145        self.clusters
146            .entry(cluster_id)
147            .or_default()
148            .add(doc_id, ordinal, code);
149        self.total_vectors += 1;
150    }
151
152    /// Get cluster data
153    pub fn get(&self, cluster_id: u32) -> Option<&ClusterData<C>> {
154        self.clusters.get(&cluster_id)
155    }
156
157    /// Get mutable cluster data
158    pub fn get_mut(&mut self, cluster_id: u32) -> Option<&mut ClusterData<C>> {
159        self.clusters.get_mut(&cluster_id)
160    }
161
162    /// Get or create cluster data
163    pub fn get_or_create(&mut self, cluster_id: u32) -> &mut ClusterData<C> {
164        self.clusters.entry(cluster_id).or_default()
165    }
166
167    /// Number of non-empty clusters
168    pub fn num_clusters(&self) -> usize {
169        self.clusters.len()
170    }
171
172    /// Total number of vectors
173    pub fn len(&self) -> usize {
174        self.total_vectors
175    }
176
177    pub fn is_empty(&self) -> bool {
178        self.total_vectors == 0
179    }
180
181    /// Iterate over all clusters
182    pub fn iter(&self) -> impl Iterator<Item = (u32, &ClusterData<C>)> {
183        self.clusters.iter().map(|(&id, data)| (id, data))
184    }
185
186    /// Merge another storage into this one
187    pub fn merge(&mut self, other: &ClusterStorage<C>, doc_id_offset: u32) {
188        for (&cluster_id, other_data) in &other.clusters {
189            self.clusters
190                .entry(cluster_id)
191                .or_default()
192                .append(other_data, doc_id_offset);
193        }
194        self.total_vectors += other.total_vectors;
195    }
196
197    /// Clear all clusters
198    pub fn clear(&mut self) {
199        self.clusters.clear();
200        self.total_vectors = 0;
201    }
202}
203
204impl<C: Clone + QuantizedCode> ClusterStorage<C> {
205    /// Total memory usage in bytes
206    pub fn size_bytes(&self) -> usize {
207        self.clusters.values().map(|c| c.size_bytes()).sum()
208    }
209
210    /// Estimated memory usage in bytes (alias for size_bytes)
211    pub fn estimated_memory_bytes(&self) -> usize {
212        self.size_bytes()
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal code type whose payload size is its byte length.
    #[derive(Clone, Debug)]
    struct TestCode(Vec<u8>);

    impl QuantizedCode for TestCode {
        fn size_bytes(&self) -> usize {
            self.0.len()
        }
    }

    #[test]
    fn test_cluster_data_basic() {
        let mut cluster: ClusterData<TestCode> = ClusterData::new();

        cluster.add(0, 0, TestCode(vec![1, 2, 3]));
        cluster.add(1, 0, TestCode(vec![4, 5, 6]));

        assert_eq!(cluster.len(), 2);
        assert!(!cluster.is_empty());
    }

    #[test]
    fn test_cluster_data_with_ordinals() {
        let mut cluster: ClusterData<TestCode> = ClusterData::new();

        // Multi-valued field: doc 0 has 3 vectors
        cluster.add(0, 0, TestCode(vec![1]));
        cluster.add(0, 1, TestCode(vec![2]));
        cluster.add(0, 2, TestCode(vec![3]));

        assert_eq!(cluster.len(), 3);
        assert_eq!(cluster.ordinals, vec![0, 1, 2]);
    }

    #[test]
    fn test_cluster_data_append() {
        let mut cluster1: ClusterData<TestCode> = ClusterData::new();
        cluster1.add(0, 0, TestCode(vec![1]));
        cluster1.add(1, 0, TestCode(vec![2]));

        let mut cluster2: ClusterData<TestCode> = ClusterData::new();
        cluster2.add(0, 0, TestCode(vec![3]));
        cluster2.add(1, 0, TestCode(vec![4]));

        cluster1.append(&cluster2, 100);

        assert_eq!(cluster1.len(), 4);
        assert_eq!(cluster1.doc_ids, vec![0, 1, 100, 101]);
    }

    #[test]
    fn test_cluster_data_iter() {
        let mut cluster: ClusterData<TestCode> = ClusterData::with_capacity(2);
        assert!(cluster.is_empty());
        cluster.add(5, 1, TestCode(vec![1]));
        cluster.add(6, 0, TestCode(vec![2, 3]));

        let items: Vec<(u32, u16, Vec<u8>)> = cluster
            .iter()
            .map(|(doc_id, ordinal, code)| (doc_id, ordinal, code.0.clone()))
            .collect();
        assert_eq!(items, vec![(5, 1, vec![1]), (6, 0, vec![2, 3])]);
    }

    #[test]
    fn test_cluster_data_size_bytes_and_clear() {
        let mut cluster: ClusterData<TestCode> = ClusterData::new();
        cluster.add(0, 0, TestCode(vec![1, 2, 3]));
        cluster.add(1, 0, TestCode(vec![4]));
        // 2 doc IDs * 4 bytes + 2 ordinals * 2 bytes + (3 + 1) code bytes
        assert_eq!(cluster.size_bytes(), 16);

        cluster.clear();
        assert!(cluster.is_empty());
        assert_eq!(cluster.size_bytes(), 0);
    }

    #[test]
    fn test_cluster_storage() {
        let mut storage: ClusterStorage<TestCode> = ClusterStorage::new();

        storage.add(0, 10, 0, TestCode(vec![1]));
        storage.add(0, 11, 0, TestCode(vec![2]));
        storage.add(1, 20, 0, TestCode(vec![3]));

        assert_eq!(storage.num_clusters(), 2);
        assert_eq!(storage.len(), 3);
        assert_eq!(storage.get(0).unwrap().len(), 2);
        assert_eq!(storage.get(1).unwrap().len(), 1);
    }

    #[test]
    fn test_cluster_storage_merge() {
        let mut storage1: ClusterStorage<TestCode> = ClusterStorage::new();
        storage1.add(0, 0, 0, TestCode(vec![1]));

        let mut storage2: ClusterStorage<TestCode> = ClusterStorage::new();
        storage2.add(0, 0, 0, TestCode(vec![2]));
        storage2.add(1, 0, 0, TestCode(vec![3]));

        storage1.merge(&storage2, 100);

        assert_eq!(storage1.len(), 3);
        assert_eq!(storage1.get(0).unwrap().len(), 2);
        assert_eq!(storage1.get(0).unwrap().doc_ids, vec![0, 100]);
    }

    #[test]
    fn test_cluster_storage_size_bytes_and_clear() {
        let mut storage: ClusterStorage<TestCode> = ClusterStorage::with_capacity(2);
        storage.add(0, 1, 0, TestCode(vec![1, 2]));
        storage.add(1, 2, 0, TestCode(vec![3]));
        // cluster 0: 4 + 2 + 2 = 8; cluster 1: 4 + 2 + 1 = 7
        assert_eq!(storage.size_bytes(), 15);
        assert_eq!(storage.estimated_memory_bytes(), 15);

        storage.clear();
        assert!(storage.is_empty());
        assert_eq!(storage.num_clusters(), 0);
    }
}