hermes_core/structures/
ivf_rabitq.rs

//! IVF-RaBitQ: Inverted File Index with RaBitQ quantization
//!
//! Two-level index for billion-scale vector search:
//! - Level 1: Coarse quantizer (k-means centroids)
//! - Level 2: RaBitQ binary codes per cluster
//!
//! Key feature: Segments sharing the same coarse centroids can be merged
//! in O(1) by concatenating cluster data - no re-quantization needed.
9
10use std::collections::HashMap;
11use std::io::{self, Cursor, Read, Write};
12use std::path::Path;
13
14use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
15use rand::prelude::*;
16use serde::{Deserialize, Serialize};
17
18use super::rabitq::QuantizedVector;
19
/// Magic number identifying a coarse-centroids file.
///
/// The value is written little-endian, so the first four bytes on disk read
/// "CVCH" (Coarse Vector Centroids Hermes).
const CENTROIDS_MAGIC: u32 = u32::from_le_bytes(*b"CVCH");

/// Magic number reserved for the IVF-RaBitQ segment file.
///
/// NOTE(review): this was labelled "IVFQ", but the bytes 0x49 0x56 0x56 0x51
/// spell "IVVQ" ('F' would be 0x46). The value is left untouched because files
/// may already carry it - confirm the intended tag before it is put to use.
#[allow(dead_code)]
const IVF_MAGIC: u32 = 0x49565651;
26
27/// Coarse centroids for IVF - trained once, shared across all segments
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct CoarseCentroids {
30    /// Number of clusters
31    pub num_clusters: u32,
32    /// Vector dimension
33    pub dim: usize,
34    /// Centroids stored as flat array (num_clusters × dim)
35    pub centroids: Vec<f32>,
36    /// Version for compatibility checking during merge
37    pub version: u64,
38}
39
40impl CoarseCentroids {
41    /// Train coarse centroids using k-means algorithm
42    ///
43    /// Uses kmeans crate with SIMD acceleration (native feature).
44    #[cfg(feature = "native")]
45    pub fn train(vectors: &[Vec<f32>], num_clusters: usize, max_iters: usize, _seed: u64) -> Self {
46        use kmeans::{EuclideanDistance, KMeans, KMeansConfig};
47
48        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
49        assert!(num_clusters > 0, "Need at least 1 cluster");
50
51        let actual_clusters = num_clusters.min(vectors.len());
52        let dim = vectors[0].len();
53
54        // Flatten vectors for kmeans crate (expects flat slice)
55        let samples: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
56
57        // Run k-means with k-means++ initialization
58        // KMeans<f32, 8, _> uses 8-lane SIMD (AVX256)
59        let kmean: KMeans<f32, 8, _> = KMeans::new(&samples, vectors.len(), dim, EuclideanDistance);
60        let result = kmean.kmeans_lloyd(
61            actual_clusters,
62            max_iters,
63            KMeans::init_kmeanplusplus,
64            &KMeansConfig::default(),
65        );
66
67        // Extract centroids from StrideBuffer to flat Vec
68        let centroids: Vec<f32> = result
69            .centroids
70            .iter()
71            .flat_map(|c| c.iter().copied())
72            .collect();
73
74        let version = std::time::SystemTime::now()
75            .duration_since(std::time::UNIX_EPOCH)
76            .unwrap_or_default()
77            .as_millis() as u64;
78
79        Self {
80            num_clusters: actual_clusters as u32,
81            dim,
82            centroids,
83            version,
84        }
85    }
86
87    /// Fallback k-means for non-native builds (WASM)
88    #[cfg(not(feature = "native"))]
89    pub fn train(vectors: &[Vec<f32>], num_clusters: usize, max_iters: usize, seed: u64) -> Self {
90        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
91        assert!(num_clusters > 0, "Need at least 1 cluster");
92
93        let actual_clusters = num_clusters.min(vectors.len());
94        let dim = vectors[0].len();
95        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
96
97        // Simple random initialization
98        let mut indices: Vec<usize> = (0..vectors.len()).collect();
99        indices.shuffle(&mut rng);
100
101        let mut centroids: Vec<f32> = indices[..actual_clusters]
102            .iter()
103            .flat_map(|&i| vectors[i].iter().copied())
104            .collect();
105
106        // K-means iterations
107        for _ in 0..max_iters {
108            let assignments: Vec<usize> = vectors
109                .iter()
110                .map(|v| Self::find_nearest_centroid_idx(v, &centroids, dim))
111                .collect();
112
113            let mut new_centroids = vec![0.0f32; actual_clusters * dim];
114            let mut counts = vec![0usize; actual_clusters];
115
116            for (vec_idx, &cluster_id) in assignments.iter().enumerate() {
117                counts[cluster_id] += 1;
118                let offset = cluster_id * dim;
119                for (i, &val) in vectors[vec_idx].iter().enumerate() {
120                    new_centroids[offset + i] += val;
121                }
122            }
123
124            for cluster_id in 0..actual_clusters {
125                if counts[cluster_id] > 0 {
126                    let offset = cluster_id * dim;
127                    for i in 0..dim {
128                        new_centroids[offset + i] /= counts[cluster_id] as f32;
129                    }
130                }
131            }
132
133            centroids = new_centroids;
134        }
135
136        let version = std::time::SystemTime::now()
137            .duration_since(std::time::UNIX_EPOCH)
138            .unwrap_or_default()
139            .as_millis() as u64;
140
141        Self {
142            num_clusters: actual_clusters as u32,
143            dim,
144            centroids,
145            version,
146        }
147    }
148
149    /// K-means++ initialization for better starting centroids
150    #[allow(dead_code)]
151    fn kmeans_plusplus_init(
152        vectors: &[Vec<f32>],
153        num_clusters: usize,
154        rng: &mut impl Rng,
155    ) -> Vec<f32> {
156        let dim = vectors[0].len();
157        let mut centroids = Vec::with_capacity(num_clusters * dim);
158
159        // First centroid: random
160        let first_idx = rng.random_range(0..vectors.len());
161        centroids.extend_from_slice(&vectors[first_idx]);
162
163        // Remaining centroids: weighted by distance to nearest existing centroid
164        for _ in 1..num_clusters {
165            let mut distances: Vec<f32> = vectors
166                .iter()
167                .map(|v| {
168                    let mut min_dist = f32::MAX;
169                    for c in 0..(centroids.len() / dim) {
170                        let offset = c * dim;
171                        let dist: f32 = v
172                            .iter()
173                            .zip(&centroids[offset..offset + dim])
174                            .map(|(&a, &b)| (a - b) * (a - b))
175                            .sum();
176                        min_dist = min_dist.min(dist);
177                    }
178                    min_dist
179                })
180                .collect();
181
182            // Normalize to probabilities
183            let total: f32 = distances.iter().sum();
184            if total > 0.0 {
185                for d in &mut distances {
186                    *d /= total;
187                }
188            }
189
190            // Sample proportional to distance squared
191            let r: f32 = rng.random();
192            let mut cumsum = 0.0;
193            let mut chosen_idx = 0;
194            for (i, &d) in distances.iter().enumerate() {
195                cumsum += d;
196                if cumsum >= r {
197                    chosen_idx = i;
198                    break;
199                }
200            }
201
202            centroids.extend_from_slice(&vectors[chosen_idx]);
203        }
204
205        centroids
206    }
207
208    /// Find nearest centroid index for a vector
209    fn find_nearest_centroid_idx(vector: &[f32], centroids: &[f32], dim: usize) -> usize {
210        let num_clusters = centroids.len() / dim;
211        let mut best_idx = 0;
212        let mut best_dist = f32::MAX;
213
214        for c in 0..num_clusters {
215            let offset = c * dim;
216            let dist: f32 = vector
217                .iter()
218                .zip(&centroids[offset..offset + dim])
219                .map(|(&a, &b)| (a - b) * (a - b))
220                .sum();
221
222            if dist < best_dist {
223                best_dist = dist;
224                best_idx = c;
225            }
226        }
227
228        best_idx
229    }
230
231    /// Find nearest cluster for a query vector
232    pub fn find_nearest(&self, vector: &[f32]) -> u32 {
233        Self::find_nearest_centroid_idx(vector, &self.centroids, self.dim) as u32
234    }
235
236    /// Find k nearest clusters for a query vector
237    pub fn find_k_nearest(&self, vector: &[f32], k: usize) -> Vec<u32> {
238        let mut distances: Vec<(u32, f32)> = (0..self.num_clusters)
239            .map(|c| {
240                let offset = c as usize * self.dim;
241                let dist: f32 = vector
242                    .iter()
243                    .zip(&self.centroids[offset..offset + self.dim])
244                    .map(|(&a, &b)| (a - b) * (a - b))
245                    .sum();
246                (c, dist)
247            })
248            .collect();
249
250        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
251        distances.truncate(k);
252        distances.into_iter().map(|(c, _)| c).collect()
253    }
254
255    /// Get centroid for a cluster
256    pub fn get_centroid(&self, cluster_id: u32) -> &[f32] {
257        let offset = cluster_id as usize * self.dim;
258        &self.centroids[offset..offset + self.dim]
259    }
260
261    /// Save to binary file
262    pub fn save(&self, path: &Path) -> io::Result<()> {
263        let mut file = std::fs::File::create(path)?;
264        self.write_to(&mut file)
265    }
266
267    /// Write to any writer
268    pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
269        writer.write_u32::<LittleEndian>(CENTROIDS_MAGIC)?;
270        writer.write_u32::<LittleEndian>(1)?; // version
271        writer.write_u64::<LittleEndian>(self.version)?;
272        writer.write_u32::<LittleEndian>(self.num_clusters)?;
273        writer.write_u32::<LittleEndian>(self.dim as u32)?;
274
275        for &val in &self.centroids {
276            writer.write_f32::<LittleEndian>(val)?;
277        }
278
279        Ok(())
280    }
281
282    /// Load from binary file
283    pub fn load(path: &Path) -> io::Result<Self> {
284        let data = std::fs::read(path)?;
285        Self::read_from(&mut Cursor::new(data))
286    }
287
288    /// Read from any reader
289    pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Self> {
290        let magic = reader.read_u32::<LittleEndian>()?;
291        if magic != CENTROIDS_MAGIC {
292            return Err(io::Error::new(
293                io::ErrorKind::InvalidData,
294                "Invalid centroids file magic",
295            ));
296        }
297
298        let _file_version = reader.read_u32::<LittleEndian>()?;
299        let version = reader.read_u64::<LittleEndian>()?;
300        let num_clusters = reader.read_u32::<LittleEndian>()?;
301        let dim = reader.read_u32::<LittleEndian>()? as usize;
302
303        let mut centroids = vec![0.0f32; num_clusters as usize * dim];
304        for val in &mut centroids {
305            *val = reader.read_f32::<LittleEndian>()?;
306        }
307
308        Ok(Self {
309            num_clusters,
310            dim,
311            centroids,
312            version,
313        })
314    }
315
316    /// Serialize to bytes
317    pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
318        let mut buf = Vec::new();
319        self.write_to(&mut buf)?;
320        Ok(buf)
321    }
322
323    /// Deserialize from bytes
324    pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
325        Self::read_from(&mut Cursor::new(data))
326    }
327}
328
329/// Data for a single cluster within a segment
330#[derive(Debug, Clone, Default, Serialize, Deserialize)]
331pub struct ClusterData {
332    /// Document IDs (local to segment)
333    pub doc_ids: Vec<u32>,
334    /// Binary quantized vectors
335    pub binary_codes: Vec<QuantizedVector>,
336    /// Raw vectors for re-ranking (optional)
337    pub raw_vectors: Option<Vec<Vec<f32>>>,
338}
339
340impl ClusterData {
341    pub fn new() -> Self {
342        Self::default()
343    }
344
345    pub fn len(&self) -> usize {
346        self.doc_ids.len()
347    }
348
349    pub fn is_empty(&self) -> bool {
350        self.doc_ids.is_empty()
351    }
352
353    /// Append another cluster's data (for merging)
354    pub fn append(&mut self, other: &ClusterData, doc_id_offset: u32) {
355        for &doc_id in &other.doc_ids {
356            self.doc_ids.push(doc_id + doc_id_offset);
357        }
358        self.binary_codes.extend(other.binary_codes.iter().cloned());
359
360        if let Some(ref other_raw) = other.raw_vectors {
361            let raw = self.raw_vectors.get_or_insert_with(Vec::new);
362            raw.extend(other_raw.iter().cloned());
363        }
364    }
365}
366
367/// IVF-RaBitQ index configuration
368#[derive(Debug, Clone, Serialize, Deserialize)]
369pub struct IVFConfig {
370    /// Vector dimension
371    pub dim: usize,
372    /// Random seed for reproducible transforms
373    pub seed: u64,
374    /// Number of bits for query quantization (usually 4)
375    pub query_bits: u8,
376    /// Store raw vectors for re-ranking
377    pub store_raw: bool,
378    /// Number of clusters to probe during search
379    pub default_nprobe: usize,
380}
381
382impl IVFConfig {
383    pub fn new(dim: usize) -> Self {
384        Self {
385            dim,
386            seed: 42,
387            query_bits: 4,
388            store_raw: true,
389            default_nprobe: 32,
390        }
391    }
392}
393
394/// IVF-RaBitQ index for a single segment
395#[derive(Debug, Clone, Serialize, Deserialize)]
396pub struct IVFRaBitQIndex {
397    /// Configuration
398    pub config: IVFConfig,
399    /// Version of coarse centroids used (for merge compatibility)
400    pub centroids_version: u64,
401    /// Random signs for transform (+1 or -1)
402    pub random_signs: Vec<i8>,
403    /// Random permutation for transform
404    pub random_perm: Vec<u32>,
405    /// Cluster data (sparse - only populated clusters)
406    pub clusters: HashMap<u32, ClusterData>,
407    /// Total number of vectors indexed
408    pub num_vectors: usize,
409}
410
411impl IVFRaBitQIndex {
412    /// Create a new empty IVF index
413    pub fn new(config: IVFConfig, centroids_version: u64) -> Self {
414        let dim = config.dim;
415        let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
416
417        // Generate random signs
418        let random_signs: Vec<i8> = (0..dim)
419            .map(|_| if rng.random::<bool>() { 1 } else { -1 })
420            .collect();
421
422        // Generate random permutation
423        let mut random_perm: Vec<u32> = (0..dim as u32).collect();
424        for i in (1..dim).rev() {
425            let j = rng.random_range(0..=i);
426            random_perm.swap(i, j);
427        }
428
429        Self {
430            config,
431            centroids_version,
432            random_signs,
433            random_perm,
434            clusters: HashMap::new(),
435            num_vectors: 0,
436        }
437    }
438
439    /// Build index from vectors using provided coarse centroids
440    pub fn build(
441        config: IVFConfig,
442        coarse_centroids: &CoarseCentroids,
443        vectors: &[Vec<f32>],
444        doc_ids: Option<&[u32]>,
445    ) -> Self {
446        let mut index = Self::new(config.clone(), coarse_centroids.version);
447
448        for (i, vector) in vectors.iter().enumerate() {
449            let doc_id = doc_ids.map(|ids| ids[i]).unwrap_or(i as u32);
450            index.add_vector(coarse_centroids, doc_id, vector);
451        }
452
453        index
454    }
455
456    /// Add a single vector to the index
457    pub fn add_vector(&mut self, coarse_centroids: &CoarseCentroids, doc_id: u32, vector: &[f32]) {
458        // Find nearest cluster
459        let cluster_id = coarse_centroids.find_nearest(vector);
460
461        // Get cluster centroid
462        let centroid = coarse_centroids.get_centroid(cluster_id);
463
464        // Quantize relative to cluster centroid
465        let binary_code = self.quantize_vector(vector, centroid);
466
467        // Store in cluster
468        let cluster = self.clusters.entry(cluster_id).or_default();
469        cluster.doc_ids.push(doc_id);
470        cluster.binary_codes.push(binary_code);
471
472        if self.config.store_raw {
473            cluster
474                .raw_vectors
475                .get_or_insert_with(Vec::new)
476                .push(vector.to_vec());
477        }
478
479        self.num_vectors += 1;
480    }
481
482    /// Quantize a vector relative to a centroid
483    fn quantize_vector(&self, raw: &[f32], centroid: &[f32]) -> QuantizedVector {
484        let dim = self.config.dim;
485
486        // Subtract centroid and compute norm
487        let mut centered: Vec<f32> = raw.iter().zip(centroid).map(|(&v, &c)| v - c).collect();
488
489        let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
490        let dist_to_centroid = norm;
491
492        // Normalize
493        if norm > 1e-10 {
494            for x in &mut centered {
495                *x /= norm;
496            }
497        }
498
499        // Apply random transform
500        let transformed: Vec<f32> = (0..dim)
501            .map(|i| {
502                let src_idx = self.random_perm[i] as usize;
503                centered[src_idx] * self.random_signs[src_idx] as f32
504            })
505            .collect();
506
507        // Binary quantize
508        let num_bytes = dim.div_ceil(8);
509        let mut bits = vec![0u8; num_bytes];
510        let mut popcount = 0u32;
511
512        for i in 0..dim {
513            if transformed[i] >= 0.0 {
514                bits[i / 8] |= 1 << (i % 8);
515                popcount += 1;
516            }
517        }
518
519        // Compute self dot product
520        let scale = 1.0 / (dim as f32).sqrt();
521        let mut self_dot = 0.0f32;
522        for i in 0..dim {
523            let o_bar_i = if (bits[i / 8] >> (i % 8)) & 1 == 1 {
524                scale
525            } else {
526                -scale
527            };
528            self_dot += transformed[i] * o_bar_i;
529        }
530
531        QuantizedVector {
532            bits,
533            dist_to_centroid,
534            self_dot,
535            popcount,
536        }
537    }
538
539    /// Search for k nearest neighbors
540    pub fn search(
541        &self,
542        coarse_centroids: &CoarseCentroids,
543        query: &[f32],
544        k: usize,
545        nprobe: usize,
546    ) -> Vec<(u32, f32)> {
547        // Find nprobe nearest clusters
548        let nearest_clusters = coarse_centroids.find_k_nearest(query, nprobe);
549
550        // Collect candidates from all probed clusters
551        let mut candidates: Vec<(u32, f32)> = Vec::new();
552
553        for cluster_id in nearest_clusters {
554            if let Some(cluster) = self.clusters.get(&cluster_id) {
555                let centroid = coarse_centroids.get_centroid(cluster_id);
556                let prepared = self.prepare_query(query, centroid);
557
558                for (i, binary_code) in cluster.binary_codes.iter().enumerate() {
559                    let dist = self.estimate_distance(&prepared, binary_code);
560                    candidates.push((cluster.doc_ids[i], dist));
561                }
562            }
563        }
564
565        // Sort by estimated distance
566        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
567
568        // Re-rank top candidates with exact distance if raw vectors available
569        let rerank_count = (k * 3).min(candidates.len());
570        if rerank_count > 0 {
571            let mut reranked: Vec<(u32, f32)> = Vec::with_capacity(rerank_count);
572
573            for &(doc_id, _) in candidates.iter().take(rerank_count) {
574                // Find the vector in clusters
575                for cluster in self.clusters.values() {
576                    if let Some(pos) = cluster.doc_ids.iter().position(|&d| d == doc_id) {
577                        if let Some(ref raw_vecs) = cluster.raw_vectors {
578                            let raw_vec = &raw_vecs[pos];
579                            let dist: f32 = query
580                                .iter()
581                                .zip(raw_vec.iter())
582                                .map(|(&a, &b)| (a - b).powi(2))
583                                .sum();
584                            reranked.push((doc_id, dist));
585                        }
586                        break;
587                    }
588                }
589            }
590
591            if !reranked.is_empty() {
592                reranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
593                reranked.truncate(k);
594                return reranked;
595            }
596        }
597
598        // Sort by distance
599        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
600        candidates.truncate(k);
601        candidates
602    }
603
604    /// Prepare query for fast distance estimation (matches RaBitQ algorithm)
605    fn prepare_query(&self, raw_query: &[f32], centroid: &[f32]) -> PreparedQuery {
606        let dim = self.config.dim;
607
608        // Subtract centroid
609        let mut centered: Vec<f32> = raw_query
610            .iter()
611            .zip(centroid)
612            .map(|(&v, &c)| v - c)
613            .collect();
614
615        let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
616        let dist_to_centroid = norm;
617
618        // Normalize
619        if norm > 1e-10 {
620            for x in &mut centered {
621                *x /= norm;
622            }
623        }
624
625        // Apply random transform
626        let transformed: Vec<f32> = (0..dim)
627            .map(|i| {
628                let src_idx = self.random_perm[i] as usize;
629                centered[src_idx] * self.random_signs[src_idx] as f32
630            })
631            .collect();
632
633        // Scalar quantize to 4-bit (same as RaBitQ)
634        let min_val = transformed.iter().cloned().fold(f32::INFINITY, f32::min);
635        let max_val = transformed
636            .iter()
637            .cloned()
638            .fold(f32::NEG_INFINITY, f32::max);
639        let lower = min_val;
640        let width = if max_val > min_val {
641            max_val - min_val
642        } else {
643            1.0
644        };
645
646        // Quantize to 0-15 range
647        let quantized_vals: Vec<u8> = transformed
648            .iter()
649            .map(|&x| {
650                let normalized = (x - lower) / width;
651                (normalized * 15.0).round().clamp(0.0, 15.0) as u8
652            })
653            .collect();
654
655        // Compute sum of quantized values
656        let sum: u32 = quantized_vals.iter().map(|&x| x as u32).sum();
657
658        // Build LUTs for fast dot product
659        let num_luts = dim.div_ceil(4);
660        let mut luts = vec![[0u16; 16]; num_luts];
661
662        for (lut_idx, lut) in luts.iter_mut().enumerate() {
663            let base_dim = lut_idx * 4;
664            for pattern in 0u8..16 {
665                let mut dot = 0u16;
666                for bit in 0..4 {
667                    let dim_idx = base_dim + bit;
668                    if dim_idx < dim && (pattern >> bit) & 1 == 1 {
669                        dot += quantized_vals[dim_idx] as u16;
670                    }
671                }
672                lut[pattern as usize] = dot;
673            }
674        }
675
676        PreparedQuery {
677            dist_to_centroid,
678            lower,
679            width,
680            sum,
681            luts,
682        }
683    }
684
685    /// Estimate distance using LUT-based dot product (matches RaBitQ algorithm)
686    fn estimate_distance(&self, query: &PreparedQuery, vec: &QuantizedVector) -> f32 {
687        let dim = self.config.dim;
688
689        // LUT-based dot product
690        let mut dot_sum = 0u32;
691        for (lut_idx, lut) in query.luts.iter().enumerate() {
692            let base_bit = lut_idx * 4;
693            let byte_idx = base_bit / 8;
694            let bit_offset = base_bit % 8;
695
696            let byte = vec.bits.get(byte_idx).copied().unwrap_or(0);
697            let next_byte = vec.bits.get(byte_idx + 1).copied().unwrap_or(0);
698
699            let pattern = if bit_offset <= 4 {
700                (byte >> bit_offset) & 0x0F
701            } else {
702                ((byte >> bit_offset) | (next_byte << (8 - bit_offset))) & 0x0F
703            };
704
705            dot_sum += lut[pattern as usize] as u32;
706        }
707
708        // Dequantize using RaBitQ formula
709        let scale = 1.0 / (dim as f32).sqrt();
710
711        // sum_positive = sum of q[i] where bit[i] = 1
712        // = popcount * lower + (dot_sum * width / 15)
713        let sum_positive = vec.popcount as f32 * query.lower + dot_sum as f32 * query.width / 15.0;
714
715        // sum_all = D * lower + sum_q * width / 15
716        let sum_all = dim as f32 * query.lower + query.sum as f32 * query.width / 15.0;
717
718        // <q, o_bar> = scale * (2 * sum_positive - sum_all)
719        let q_obar_dot = scale * (2.0 * sum_positive - sum_all);
720
721        // Estimate <q, o> using the corrective factor <o, o_bar>
722        let q_o_estimate = if vec.self_dot.abs() > 1e-6 {
723            q_obar_dot / vec.self_dot
724        } else {
725            q_obar_dot
726        };
727
728        // Clamp to valid range
729        let q_o_clamped = q_o_estimate.clamp(-1.0, 1.0);
730
731        // Distance formula: ||o - q||^2 = ||o||^2 + ||q||^2 - 2*||o||*||q||*<o,q>
732        let dist_sq = vec.dist_to_centroid * vec.dist_to_centroid
733            + query.dist_to_centroid * query.dist_to_centroid
734            - 2.0 * vec.dist_to_centroid * query.dist_to_centroid * q_o_clamped;
735
736        dist_sq.max(0.0)
737    }
738
739    /// Merge multiple IVF indexes (O(1) per cluster - just concatenate)
740    pub fn merge(
741        indexes: &[&IVFRaBitQIndex],
742        doc_id_offsets: &[u32],
743    ) -> Result<Self, &'static str> {
744        if indexes.is_empty() {
745            return Err("No indexes to merge");
746        }
747
748        // Verify all indexes use same centroids version
749        let version = indexes[0].centroids_version;
750        for idx in indexes.iter().skip(1) {
751            if idx.centroids_version != version {
752                return Err("Cannot merge indexes with different centroid versions");
753            }
754        }
755
756        let config = indexes[0].config.clone();
757        let mut merged = Self::new(config, version);
758
759        // Merge clusters
760        for (seg_idx, index) in indexes.iter().enumerate() {
761            let offset = doc_id_offsets[seg_idx];
762
763            for (&cluster_id, cluster_data) in &index.clusters {
764                let merged_cluster = merged.clusters.entry(cluster_id).or_default();
765
766                merged_cluster.append(cluster_data, offset);
767            }
768
769            merged.num_vectors += index.num_vectors;
770        }
771
772        Ok(merged)
773    }
774
775    /// Get number of populated clusters
776    pub fn num_clusters(&self) -> usize {
777        self.clusters.len()
778    }
779
780    /// Get total number of vectors
781    pub fn len(&self) -> usize {
782        self.num_vectors
783    }
784
785    pub fn is_empty(&self) -> bool {
786        self.num_vectors == 0
787    }
788}
789
/// Query state precomputed once per probed cluster for fast distance
/// estimation.
struct PreparedQuery {
    /// Norm of the query's residual relative to the cluster centroid.
    dist_to_centroid: f32,
    /// Smallest component of the transformed query (quantization origin).
    lower: f32,
    /// Component range of the transformed query (1.0 when degenerate).
    width: f32,
    /// Sum of all quantized query components.
    sum: u32,
    /// One 16-entry lookup table per group of four dimensions.
    luts: Vec<[u16; 16]>,
}
799
#[cfg(test)]
mod tests {
    use super::*;

    /// Draw one `dim`-dimensional vector of random `f32` components.
    fn random_vector(rng: &mut rand::rngs::StdRng, dim: usize) -> Vec<f32> {
        (0..dim).map(|_| rng.random::<f32>()).collect()
    }

    /// Draw `count` random vectors, one after another, from the same generator.
    fn random_vectors(rng: &mut rand::rngs::StdRng, count: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..count).map(|_| random_vector(&mut *rng, dim)).collect()
    }

    /// Training yields the requested number of centroids of the right size.
    #[test]
    fn test_coarse_centroids_train() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let vectors = random_vectors(&mut rng, 1000, 64);

        let centroids = CoarseCentroids::train(&vectors, 16, 10, 42);

        assert_eq!(centroids.num_clusters, 16);
        assert_eq!(centroids.dim, 64);
        assert_eq!(centroids.centroids.len(), 16 * 64);
    }

    /// Serializing and deserializing preserves the centroid data.
    #[test]
    fn test_coarse_centroids_save_load() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let vectors = random_vectors(&mut rng, 100, 32);

        let centroids = CoarseCentroids::train(&vectors, 8, 5, 42);
        let bytes = centroids.to_bytes().unwrap();
        let loaded = CoarseCentroids::from_bytes(&bytes).unwrap();

        assert_eq!(centroids.num_clusters, loaded.num_clusters);
        assert_eq!(centroids.dim, loaded.dim);
        assert_eq!(centroids.centroids, loaded.centroids);
    }

    /// A built index holds every vector and answers a top-10 query.
    #[test]
    fn test_ivf_build_and_search() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let dim = 64;

        let vectors = random_vectors(&mut rng, 1000, dim);
        let centroids = CoarseCentroids::train(&vectors, 16, 10, 42);

        let index = IVFRaBitQIndex::build(IVFConfig::new(dim), &centroids, &vectors, None);

        assert_eq!(index.len(), 1000);
        assert!(index.num_clusters() <= 16);

        let query = random_vector(&mut rng, dim);
        let results = index.search(&centroids, &query, 10, 4);

        assert_eq!(results.len(), 10);
    }

    /// Two segments built on shared centroids merge into one searchable index.
    #[test]
    fn test_ivf_merge() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let dim = 32;

        // Vectors for two segments.
        let vectors1 = random_vectors(&mut rng, 500, dim);
        let vectors2 = random_vectors(&mut rng, 500, dim);

        // Centroids are trained on the union and shared by both segments.
        let all_vectors: Vec<Vec<f32>> = vectors1.iter().chain(vectors2.iter()).cloned().collect();
        let centroids = CoarseCentroids::train(&all_vectors, 8, 10, 42);

        let config = IVFConfig::new(dim);
        let index1 = IVFRaBitQIndex::build(config.clone(), &centroids, &vectors1, None);
        let index2 = IVFRaBitQIndex::build(config, &centroids, &vectors2, None);

        let merged = IVFRaBitQIndex::merge(&[&index1, &index2], &[0, 500]).unwrap();

        assert_eq!(merged.len(), 1000);

        let query = random_vector(&mut rng, dim);
        let results = merged.search(&centroids, &query, 10, 4);

        assert_eq!(results.len(), 10);
    }
}
895}