Skip to main content

hermes_core/structures/vector/ivf/
coarse.rs

//! Coarse centroids for IVF partitioning.
//!
//! Provides k-means clustering for the first level of IVF indexing. The
//! centroids are trained once and shared across all segments, giving O(1)
//! merge compatibility.
5
6use std::io::{self, Cursor, Read, Write};
7use std::path::Path;
8
9use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
10#[cfg(not(feature = "native"))]
11use rand::SeedableRng;
12#[cfg(not(feature = "native"))]
13use rand::prelude::SliceRandom;
14use serde::{Deserialize, Serialize};
15
16use super::soar::{MultiAssignment, SoarConfig};
17
/// File signature that opens every serialized coarse-centroids blob.
///
/// Written little-endian, so the first four bytes on disk spell `CVCH`
/// ("Coarse Vector Centroids Hermes"); numerically this is `0x48435643`.
const CENTROIDS_MAGIC: u32 = u32::from_le_bytes(*b"CVCH");
20
21/// Configuration for coarse quantizer
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct CoarseConfig {
24    /// Number of clusters
25    pub num_clusters: usize,
26    /// Vector dimension
27    pub dim: usize,
28    /// Maximum k-means iterations
29    pub max_iters: usize,
30    /// Random seed for reproducibility
31    pub seed: u64,
32    /// SOAR configuration (optional)
33    pub soar: Option<SoarConfig>,
34}
35
36impl CoarseConfig {
37    pub fn new(dim: usize, num_clusters: usize) -> Self {
38        Self {
39            num_clusters,
40            dim,
41            max_iters: 25,
42            seed: 42,
43            soar: None,
44        }
45    }
46
47    pub fn with_soar(mut self, config: SoarConfig) -> Self {
48        self.soar = Some(config);
49        self
50    }
51
52    pub fn with_seed(mut self, seed: u64) -> Self {
53        self.seed = seed;
54        self
55    }
56
57    pub fn with_max_iters(mut self, iters: usize) -> Self {
58        self.max_iters = iters;
59        self
60    }
61}
62
63/// Coarse centroids for IVF - trained once, shared across all segments
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct CoarseCentroids {
66    /// Number of clusters
67    pub num_clusters: u32,
68    /// Vector dimension
69    pub dim: usize,
70    /// Centroids stored as flat array (num_clusters × dim)
71    pub centroids: Vec<f32>,
72    /// Version for compatibility checking during merge
73    pub version: u64,
74    /// SOAR configuration (if enabled)
75    pub soar_config: Option<SoarConfig>,
76}
77
78impl CoarseCentroids {
79    /// Train coarse centroids using k-means algorithm
80    ///
81    /// Uses kmeans crate with SIMD acceleration (native feature).
82    #[cfg(feature = "native")]
83    pub fn train(config: &CoarseConfig, vectors: &[Vec<f32>]) -> Self {
84        use kmeans::{EuclideanDistance, KMeans, KMeansConfig};
85
86        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
87        assert!(config.num_clusters > 0, "Need at least 1 cluster");
88
89        let actual_clusters = config.num_clusters.min(vectors.len());
90        let dim = config.dim;
91
92        // Flatten vectors for kmeans crate (expects flat slice)
93        let samples: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
94
95        // Run k-means with k-means++ initialization
96        // KMeans<f32, 8, _> uses 8-lane SIMD (AVX256)
97        let kmean: KMeans<f32, 8, _> = KMeans::new(&samples, vectors.len(), dim, EuclideanDistance);
98        let result = kmean.kmeans_lloyd(
99            actual_clusters,
100            config.max_iters,
101            KMeans::init_kmeanplusplus,
102            &KMeansConfig::default(),
103        );
104
105        // Extract centroids from StrideBuffer to flat Vec
106        let centroids: Vec<f32> = result
107            .centroids
108            .iter()
109            .flat_map(|c| c.iter().copied())
110            .collect();
111
112        let version = std::time::SystemTime::now()
113            .duration_since(std::time::UNIX_EPOCH)
114            .unwrap_or_default()
115            .as_millis() as u64;
116
117        Self {
118            num_clusters: actual_clusters as u32,
119            dim,
120            centroids,
121            version,
122            soar_config: config.soar.clone(),
123        }
124    }
125
126    /// Fallback k-means for non-native builds (WASM)
127    #[cfg(not(feature = "native"))]
128    pub fn train(config: &CoarseConfig, vectors: &[Vec<f32>]) -> Self {
129        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
130        assert!(config.num_clusters > 0, "Need at least 1 cluster");
131
132        let actual_clusters = config.num_clusters.min(vectors.len());
133        let dim = config.dim;
134        let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
135
136        // Simple random initialization
137        let mut indices: Vec<usize> = (0..vectors.len()).collect();
138        indices.shuffle(&mut rng);
139
140        let mut centroids: Vec<f32> = indices[..actual_clusters]
141            .iter()
142            .flat_map(|&i| vectors[i].iter().copied())
143            .collect();
144
145        // K-means iterations
146        for _ in 0..config.max_iters {
147            let assignments: Vec<usize> = vectors
148                .iter()
149                .map(|v| Self::find_nearest_idx_static(v, &centroids, dim))
150                .collect();
151
152            let mut new_centroids = vec![0.0f32; actual_clusters * dim];
153            let mut counts = vec![0usize; actual_clusters];
154
155            for (vec_idx, &cluster_id) in assignments.iter().enumerate() {
156                counts[cluster_id] += 1;
157                let offset = cluster_id * dim;
158                for (i, &val) in vectors[vec_idx].iter().enumerate() {
159                    new_centroids[offset + i] += val;
160                }
161            }
162
163            for cluster_id in 0..actual_clusters {
164                if counts[cluster_id] > 0 {
165                    let offset = cluster_id * dim;
166                    for i in 0..dim {
167                        new_centroids[offset + i] /= counts[cluster_id] as f32;
168                    }
169                }
170            }
171
172            centroids = new_centroids;
173        }
174
175        let version = std::time::SystemTime::now()
176            .duration_since(std::time::UNIX_EPOCH)
177            .unwrap_or_default()
178            .as_millis() as u64;
179
180        Self {
181            num_clusters: actual_clusters as u32,
182            dim,
183            centroids,
184            version,
185            soar_config: config.soar.clone(),
186        }
187    }
188
189    /// Find nearest centroid index for a vector (static helper)
190    fn find_nearest_idx_static(vector: &[f32], centroids: &[f32], dim: usize) -> usize {
191        let num_clusters = centroids.len() / dim;
192        let mut best_idx = 0;
193        let mut best_dist = f32::MAX;
194
195        for c in 0..num_clusters {
196            let offset = c * dim;
197            let dist: f32 = vector
198                .iter()
199                .zip(&centroids[offset..offset + dim])
200                .map(|(&a, &b)| (a - b) * (a - b))
201                .sum();
202
203            if dist < best_dist {
204                best_dist = dist;
205                best_idx = c;
206            }
207        }
208
209        best_idx
210    }
211
212    /// Find nearest cluster for a query vector
213    pub fn find_nearest(&self, vector: &[f32]) -> u32 {
214        Self::find_nearest_idx_static(vector, &self.centroids, self.dim) as u32
215    }
216
217    /// Find k nearest clusters for a query vector
218    pub fn find_k_nearest(&self, vector: &[f32], k: usize) -> Vec<u32> {
219        let mut distances: Vec<(u32, f32)> = (0..self.num_clusters)
220            .map(|c| {
221                let offset = c as usize * self.dim;
222                let dist: f32 = vector
223                    .iter()
224                    .zip(&self.centroids[offset..offset + self.dim])
225                    .map(|(&a, &b)| (a - b) * (a - b))
226                    .sum();
227                (c, dist)
228            })
229            .collect();
230
231        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
232        distances.truncate(k);
233        distances.into_iter().map(|(c, _)| c).collect()
234    }
235
236    /// Find k nearest clusters with their distances
237    pub fn find_k_nearest_with_distances(&self, vector: &[f32], k: usize) -> Vec<(u32, f32)> {
238        let mut distances: Vec<(u32, f32)> = (0..self.num_clusters)
239            .map(|c| {
240                let offset = c as usize * self.dim;
241                let dist: f32 = vector
242                    .iter()
243                    .zip(&self.centroids[offset..offset + self.dim])
244                    .map(|(&a, &b)| (a - b) * (a - b))
245                    .sum();
246                (c, dist)
247            })
248            .collect();
249
250        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
251        distances.truncate(k);
252        distances
253    }
254
255    /// Assign vector with SOAR (if configured) or standard assignment
256    pub fn assign(&self, vector: &[f32]) -> MultiAssignment {
257        if let Some(ref soar_config) = self.soar_config {
258            self.assign_with_soar(vector, soar_config)
259        } else {
260            MultiAssignment {
261                primary_cluster: self.find_nearest(vector),
262                secondary_clusters: Vec::new(),
263            }
264        }
265    }
266
267    /// SOAR-style assignment: find secondary clusters with orthogonal residuals
268    pub fn assign_with_soar(&self, vector: &[f32], config: &SoarConfig) -> MultiAssignment {
269        // 1. Find primary cluster (nearest centroid)
270        let primary = self.find_nearest(vector);
271        let primary_centroid = self.get_centroid(primary);
272
273        // 2. Compute primary residual r = x - c
274        let residual: Vec<f32> = vector
275            .iter()
276            .zip(primary_centroid)
277            .map(|(v, c)| v - c)
278            .collect();
279
280        let residual_norm_sq: f32 = residual.iter().map(|x| x * x).sum();
281
282        // 3. Check if we should spill (selective spilling)
283        if config.selective && residual_norm_sq < config.spill_threshold * config.spill_threshold {
284            return MultiAssignment {
285                primary_cluster: primary,
286                secondary_clusters: Vec::new(),
287            };
288        }
289
290        // 4. Find secondary clusters that MINIMIZE |⟨r, r'⟩| (orthogonal residuals)
291        let mut candidates: Vec<(u32, f32)> = (0..self.num_clusters)
292            .filter(|&c| c != primary)
293            .map(|c| {
294                let centroid = self.get_centroid(c);
295                // Compute r' = x - c'
296                // Then compute |⟨r, r'⟩| - we want this SMALL (orthogonal)
297                let dot: f32 = vector
298                    .iter()
299                    .zip(centroid)
300                    .zip(&residual)
301                    .map(|((v, c), r)| (v - c) * r)
302                    .sum();
303                (c, dot.abs())
304            })
305            .collect();
306
307        // Sort by orthogonality (smallest dot product first)
308        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
309
310        MultiAssignment {
311            primary_cluster: primary,
312            secondary_clusters: candidates
313                .iter()
314                .take(config.num_secondary)
315                .map(|(c, _)| *c)
316                .collect(),
317        }
318    }
319
320    /// Get centroid for a cluster
321    pub fn get_centroid(&self, cluster_id: u32) -> &[f32] {
322        let offset = cluster_id as usize * self.dim;
323        &self.centroids[offset..offset + self.dim]
324    }
325
326    /// Compute residual vector (vector - centroid)
327    pub fn compute_residual(&self, vector: &[f32], cluster_id: u32) -> Vec<f32> {
328        let centroid = self.get_centroid(cluster_id);
329        vector.iter().zip(centroid).map(|(&v, &c)| v - c).collect()
330    }
331
332    /// Save to binary file
333    pub fn save(&self, path: &Path) -> io::Result<()> {
334        let mut file = std::fs::File::create(path)?;
335        self.write_to(&mut file)
336    }
337
338    /// Write to any writer
339    pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
340        writer.write_u32::<LittleEndian>(CENTROIDS_MAGIC)?;
341        writer.write_u32::<LittleEndian>(2)?; // version 2 with SOAR support
342        writer.write_u64::<LittleEndian>(self.version)?;
343        writer.write_u32::<LittleEndian>(self.num_clusters)?;
344        writer.write_u32::<LittleEndian>(self.dim as u32)?;
345
346        // Write SOAR config
347        if let Some(ref soar) = self.soar_config {
348            writer.write_u8(1)?;
349            writer.write_u32::<LittleEndian>(soar.num_secondary as u32)?;
350            writer.write_u8(if soar.selective { 1 } else { 0 })?;
351            writer.write_f32::<LittleEndian>(soar.spill_threshold)?;
352        } else {
353            writer.write_u8(0)?;
354        }
355
356        for &val in &self.centroids {
357            writer.write_f32::<LittleEndian>(val)?;
358        }
359
360        Ok(())
361    }
362
363    /// Load from binary file
364    pub fn load(path: &Path) -> io::Result<Self> {
365        let data = std::fs::read(path)?;
366        Self::read_from(&mut Cursor::new(data))
367    }
368
369    /// Read from any reader
370    pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Self> {
371        let magic = reader.read_u32::<LittleEndian>()?;
372        if magic != CENTROIDS_MAGIC {
373            return Err(io::Error::new(
374                io::ErrorKind::InvalidData,
375                "Invalid centroids file magic",
376            ));
377        }
378
379        let file_version = reader.read_u32::<LittleEndian>()?;
380        let version = reader.read_u64::<LittleEndian>()?;
381        let num_clusters = reader.read_u32::<LittleEndian>()?;
382        let dim = reader.read_u32::<LittleEndian>()? as usize;
383
384        // Read SOAR config (version 2+)
385        let soar_config = if file_version >= 2 {
386            let has_soar = reader.read_u8()? != 0;
387            if has_soar {
388                let num_secondary = reader.read_u32::<LittleEndian>()? as usize;
389                let selective = reader.read_u8()? != 0;
390                let spill_threshold = reader.read_f32::<LittleEndian>()?;
391                Some(SoarConfig {
392                    num_secondary,
393                    selective,
394                    spill_threshold,
395                })
396            } else {
397                None
398            }
399        } else {
400            None
401        };
402
403        let mut centroids = vec![0.0f32; num_clusters as usize * dim];
404        for val in &mut centroids {
405            *val = reader.read_f32::<LittleEndian>()?;
406        }
407
408        Ok(Self {
409            num_clusters,
410            dim,
411            centroids,
412            version,
413            soar_config,
414        })
415    }
416
417    /// Serialize to bytes
418    pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
419        let mut buf = Vec::new();
420        self.write_to(&mut buf)?;
421        Ok(buf)
422    }
423
424    /// Deserialize from bytes
425    pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
426        Self::read_from(&mut Cursor::new(data))
427    }
428
429    /// Memory usage in bytes
430    pub fn size_bytes(&self) -> usize {
431        self.centroids.len() * 4 + 64 // centroids + overhead
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// `n` deterministic pseudo-random vectors of length `dim`, each component
    /// being `rng.random::<f32>() + offset`.
    fn random_vectors(seed: u64, n: usize, dim: usize, offset: f32) -> Vec<Vec<f32>> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() + offset).collect())
            .collect()
    }

    /// Three 1-D centroids at 0, 10 and 5 (cluster ids 0, 1, 2).
    fn line_centroids() -> CoarseCentroids {
        CoarseCentroids {
            num_clusters: 3,
            dim: 1,
            centroids: vec![0.0, 10.0, 5.0],
            version: 7,
            soar_config: None,
        }
    }

    #[test]
    fn test_coarse_centroids_basic() {
        let (dim, n, num_clusters) = (64, 1000, 16);
        let vectors = random_vectors(42, n, dim, -0.5);

        let config = CoarseConfig::new(dim, num_clusters);
        let centroids = CoarseCentroids::train(&config, &vectors);

        assert_eq!(centroids.num_clusters, num_clusters as u32);
        assert_eq!(centroids.dim, dim);
        assert_eq!(centroids.centroids.len(), num_clusters * dim);
    }

    #[test]
    fn test_find_nearest() {
        let (dim, n, num_clusters) = (32, 500, 8);
        let vectors = random_vectors(123, n, dim, 0.0);

        let config = CoarseConfig::new(dim, num_clusters);
        let centroids = CoarseCentroids::train(&config, &vectors);

        // Every vector must map to a valid cluster id.
        for v in &vectors {
            assert!(centroids.find_nearest(v) < centroids.num_clusters);
        }
    }

    #[test]
    fn test_find_k_nearest_orders_by_distance() {
        let centroids = line_centroids();
        let query = [1.0f32];

        assert_eq!(centroids.find_nearest(&query), 0);
        assert_eq!(centroids.find_k_nearest(&query, 3), vec![0, 2, 1]);
        assert_eq!(centroids.find_k_nearest(&query, 2), vec![0, 2]);
        // k larger than the cluster count returns every cluster.
        assert_eq!(centroids.find_k_nearest(&query, 10), vec![0, 2, 1]);
        assert!(centroids.find_k_nearest(&query, 0).is_empty());
        assert_eq!(
            centroids.find_k_nearest_with_distances(&query, 2),
            vec![(0, 1.0), (2, 16.0)]
        );
        assert_eq!(centroids.compute_residual(&[9.0], 1), vec![-1.0]);
    }

    #[test]
    fn test_soar_assignment() {
        let (dim, n, num_clusters) = (32, 100, 8);
        let vectors = random_vectors(456, n, dim, 0.0);

        let soar_config = SoarConfig {
            num_secondary: 2,
            selective: false,
            spill_threshold: 0.0,
        };
        let config = CoarseConfig::new(dim, num_clusters).with_soar(soar_config);
        let centroids = CoarseCentroids::train(&config, &vectors);

        let assignment = centroids.assign(&vectors[0]);
        assert!(assignment.primary_cluster < centroids.num_clusters);
        assert_eq!(assignment.secondary_clusters.len(), 2);

        // Secondaries are valid, distinct, and never the primary.
        for &sec in &assignment.secondary_clusters {
            assert_ne!(sec, assignment.primary_cluster);
            assert!(sec < centroids.num_clusters);
        }
        let mut unique = assignment.secondary_clusters.clone();
        unique.sort_unstable();
        unique.dedup();
        assert_eq!(unique.len(), 2);
    }

    #[test]
    fn test_serialization() {
        let (dim, n, num_clusters) = (16, 50, 4);
        let vectors = random_vectors(789, n, dim, 0.0);

        let config = CoarseConfig::new(dim, num_clusters).with_soar(SoarConfig {
            num_secondary: 3,
            selective: true,
            spill_threshold: 0.25,
        });
        let centroids = CoarseCentroids::train(&config, &vectors);

        let bytes = centroids.to_bytes().unwrap();
        assert_eq!(&bytes[..4], b"CVCH");
        let loaded = CoarseCentroids::from_bytes(&bytes).unwrap();

        assert_eq!(loaded.num_clusters, centroids.num_clusters);
        assert_eq!(loaded.dim, centroids.dim);
        assert_eq!(loaded.version, centroids.version);
        // Values, not just lengths, must survive the round trip.
        assert_eq!(loaded.centroids, centroids.centroids);

        let soar = loaded.soar_config.expect("SOAR config should round-trip");
        assert_eq!(soar.num_secondary, 3);
        assert!(soar.selective);
        assert_eq!(soar.spill_threshold, 0.25);
    }

    #[test]
    fn test_serialization_without_soar() {
        let original = line_centroids();
        let loaded = CoarseCentroids::from_bytes(&original.to_bytes().unwrap()).unwrap();

        assert_eq!(loaded.num_clusters, 3);
        assert_eq!(loaded.dim, 1);
        assert_eq!(loaded.version, 7);
        assert_eq!(loaded.centroids, original.centroids);
        assert!(loaded.soar_config.is_none());
    }

    #[test]
    fn test_read_rejects_bad_magic() {
        let err = CoarseCentroids::from_bytes(&[0u8; 32]).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}