Skip to main content

hermes_core/structures/vector/ivf/
coarse.rs

1//! Coarse centroids for IVF partitioning
2//!
3//! Provides k-means clustering for the first level of IVF indexing.
4//! Trained once, shared across all segments for O(1) merge compatibility.
5
6use std::io::{self, Cursor, Read, Write};
7use std::path::Path;
8
9use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
10#[cfg(not(feature = "native"))]
11use rand::SeedableRng;
12#[cfg(not(feature = "native"))]
13use rand::prelude::SliceRandom;
14use serde::{Deserialize, Serialize};
15
16use super::soar::{MultiAssignment, SoarConfig};
17
/// File signature written at the start of every serialized centroids file.
///
/// The value is stored little-endian, so the first four bytes on disk spell
/// "CVCH" (Coarse Vector Centroids Hermes); as a `u32` this is `0x48435643`.
const CENTROIDS_MAGIC: u32 = u32::from_le_bytes(*b"CVCH");
20
/// Configuration for training the coarse quantizer (see `CoarseCentroids::train`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoarseConfig {
    /// Requested number of clusters; training caps this at the number of
    /// training vectors supplied.
    pub num_clusters: usize,
    /// Vector dimension (number of `f32` components per vector)
    pub dim: usize,
    /// Maximum k-means iterations (`CoarseConfig::new` sets this to 25)
    pub max_iters: usize,
    /// Random seed for reproducibility (`CoarseConfig::new` sets this to 42).
    /// Only the non-`native` fallback trainer reads it.
    pub seed: u64,
    /// SOAR configuration (optional); copied onto the trained centroids
    pub soar: Option<SoarConfig>,
}
35
36impl CoarseConfig {
37    pub fn new(dim: usize, num_clusters: usize) -> Self {
38        Self {
39            num_clusters,
40            dim,
41            max_iters: 25,
42            seed: 42,
43            soar: None,
44        }
45    }
46
47    pub fn with_soar(mut self, config: SoarConfig) -> Self {
48        self.soar = Some(config);
49        self
50    }
51
52    pub fn with_seed(mut self, seed: u64) -> Self {
53        self.seed = seed;
54        self
55    }
56
57    pub fn with_max_iters(mut self, iters: usize) -> Self {
58        self.max_iters = iters;
59        self
60    }
61}
62
/// Coarse centroids for IVF - trained once, shared across all segments
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoarseCentroids {
    /// Number of clusters actually trained (the configured count capped at
    /// the number of training vectors)
    pub num_clusters: u32,
    /// Vector dimension
    pub dim: usize,
    /// Centroids stored as flat array (num_clusters × dim); centroid `c`
    /// occupies `centroids[c * dim..(c + 1) * dim]`
    pub centroids: Vec<f32>,
    /// Version for compatibility checking during merge; training sets it to
    /// the wall-clock time in milliseconds since the Unix epoch
    pub version: u64,
    /// SOAR configuration (if enabled)
    pub soar_config: Option<SoarConfig>,
}
77
78impl CoarseCentroids {
79    /// Train coarse centroids using k-means algorithm
80    ///
81    /// Uses kentro crate for clustering (native feature).
82    #[cfg(feature = "native")]
83    pub fn train(config: &CoarseConfig, vectors: &[Vec<f32>]) -> Self {
84        use kentro::KMeans;
85        use ndarray::Array2;
86
87        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
88        assert!(config.num_clusters > 0, "Need at least 1 cluster");
89
90        let actual_clusters = config.num_clusters.min(vectors.len());
91        let dim = config.dim;
92
93        // Convert to ndarray format
94        let flat: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
95        let data = Array2::from_shape_vec((vectors.len(), dim), flat)
96            .expect("Failed to create ndarray from vectors");
97
98        // Run k-means with euclidean distance
99        let mut kmeans = KMeans::new(actual_clusters)
100            .with_euclidean(true)
101            .with_iterations(config.max_iters);
102        let _ = kmeans
103            .train(data.view(), None)
104            .expect("K-means training failed");
105
106        // Extract centroids
107        let centroids: Vec<f32> = kmeans
108            .centroids()
109            .expect("No centroids after training")
110            .iter()
111            .copied()
112            .collect();
113
114        let version = std::time::SystemTime::now()
115            .duration_since(std::time::UNIX_EPOCH)
116            .unwrap_or_default()
117            .as_millis() as u64;
118
119        Self {
120            num_clusters: actual_clusters as u32,
121            dim,
122            centroids,
123            version,
124            soar_config: config.soar.clone(),
125        }
126    }
127
128    /// Fallback k-means for non-native builds (WASM)
129    #[cfg(not(feature = "native"))]
130    pub fn train(config: &CoarseConfig, vectors: &[Vec<f32>]) -> Self {
131        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
132        assert!(config.num_clusters > 0, "Need at least 1 cluster");
133
134        let actual_clusters = config.num_clusters.min(vectors.len());
135        let dim = config.dim;
136        let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
137
138        // Simple random initialization
139        let mut indices: Vec<usize> = (0..vectors.len()).collect();
140        indices.shuffle(&mut rng);
141
142        let mut centroids: Vec<f32> = indices[..actual_clusters]
143            .iter()
144            .flat_map(|&i| vectors[i].iter().copied())
145            .collect();
146
147        // K-means iterations
148        for _ in 0..config.max_iters {
149            let assignments: Vec<usize> = vectors
150                .iter()
151                .map(|v| Self::find_nearest_idx_static(v, &centroids, dim))
152                .collect();
153
154            let mut new_centroids = vec![0.0f32; actual_clusters * dim];
155            let mut counts = vec![0usize; actual_clusters];
156
157            for (vec_idx, &cluster_id) in assignments.iter().enumerate() {
158                counts[cluster_id] += 1;
159                let offset = cluster_id * dim;
160                for (i, &val) in vectors[vec_idx].iter().enumerate() {
161                    new_centroids[offset + i] += val;
162                }
163            }
164
165            for (cluster_id, &count) in counts.iter().enumerate().take(actual_clusters) {
166                if count > 0 {
167                    let offset = cluster_id * dim;
168                    for i in 0..dim {
169                        new_centroids[offset + i] /= count as f32;
170                    }
171                }
172            }
173
174            centroids = new_centroids;
175        }
176
177        let version = std::time::SystemTime::now()
178            .duration_since(std::time::UNIX_EPOCH)
179            .unwrap_or_default()
180            .as_millis() as u64;
181
182        Self {
183            num_clusters: actual_clusters as u32,
184            dim,
185            centroids,
186            version,
187            soar_config: config.soar.clone(),
188        }
189    }
190
191    /// Find nearest centroid index for a vector (static helper)
192    fn find_nearest_idx_static(vector: &[f32], centroids: &[f32], dim: usize) -> usize {
193        let num_clusters = centroids.len() / dim;
194        let mut best_idx = 0;
195        let mut best_dist = f32::MAX;
196
197        for c in 0..num_clusters {
198            let offset = c * dim;
199            let dist: f32 = vector
200                .iter()
201                .zip(&centroids[offset..offset + dim])
202                .map(|(&a, &b)| (a - b) * (a - b))
203                .sum();
204
205            if dist < best_dist {
206                best_dist = dist;
207                best_idx = c;
208            }
209        }
210
211        best_idx
212    }
213
214    /// Find nearest cluster for a query vector
215    pub fn find_nearest(&self, vector: &[f32]) -> u32 {
216        Self::find_nearest_idx_static(vector, &self.centroids, self.dim) as u32
217    }
218
219    /// Find k nearest clusters for a query vector
220    pub fn find_k_nearest(&self, vector: &[f32], k: usize) -> Vec<u32> {
221        let mut distances: Vec<(u32, f32)> = (0..self.num_clusters)
222            .map(|c| {
223                let offset = c as usize * self.dim;
224                let dist: f32 = vector
225                    .iter()
226                    .zip(&self.centroids[offset..offset + self.dim])
227                    .map(|(&a, &b)| (a - b) * (a - b))
228                    .sum();
229                (c, dist)
230            })
231            .collect();
232
233        // Partial sort: O(n + k log k) instead of O(n log n)
234        if distances.len() > k {
235            distances.select_nth_unstable_by(k, |a, b| a.1.partial_cmp(&b.1).unwrap());
236            distances.truncate(k);
237        }
238        distances.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
239        distances.into_iter().map(|(c, _)| c).collect()
240    }
241
242    /// Find k nearest clusters with their distances
243    pub fn find_k_nearest_with_distances(&self, vector: &[f32], k: usize) -> Vec<(u32, f32)> {
244        let mut distances: Vec<(u32, f32)> = (0..self.num_clusters)
245            .map(|c| {
246                let offset = c as usize * self.dim;
247                let dist: f32 = vector
248                    .iter()
249                    .zip(&self.centroids[offset..offset + self.dim])
250                    .map(|(&a, &b)| (a - b) * (a - b))
251                    .sum();
252                (c, dist)
253            })
254            .collect();
255
256        // Partial sort: O(n + k log k) instead of O(n log n)
257        if distances.len() > k {
258            distances.select_nth_unstable_by(k, |a, b| a.1.partial_cmp(&b.1).unwrap());
259            distances.truncate(k);
260        }
261        distances.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
262        distances
263    }
264
265    /// Assign vector with SOAR (if configured) or standard assignment
266    pub fn assign(&self, vector: &[f32]) -> MultiAssignment {
267        if let Some(ref soar_config) = self.soar_config {
268            self.assign_with_soar(vector, soar_config)
269        } else {
270            MultiAssignment {
271                primary_cluster: self.find_nearest(vector),
272                secondary_clusters: Vec::new(),
273            }
274        }
275    }
276
277    /// SOAR-style assignment: find secondary clusters with orthogonal residuals
278    pub fn assign_with_soar(&self, vector: &[f32], config: &SoarConfig) -> MultiAssignment {
279        // 1. Find primary cluster (nearest centroid)
280        let primary = self.find_nearest(vector);
281        let primary_centroid = self.get_centroid(primary);
282
283        // 2. Compute primary residual r = x - c
284        let residual: Vec<f32> = vector
285            .iter()
286            .zip(primary_centroid)
287            .map(|(v, c)| v - c)
288            .collect();
289
290        let residual_norm_sq: f32 = residual.iter().map(|x| x * x).sum();
291
292        // 3. Check if we should spill (selective spilling)
293        if config.selective && residual_norm_sq < config.spill_threshold * config.spill_threshold {
294            return MultiAssignment {
295                primary_cluster: primary,
296                secondary_clusters: Vec::new(),
297            };
298        }
299
300        // 4. Find secondary clusters that MINIMIZE |⟨r, r'⟩| (orthogonal residuals)
301        let mut candidates: Vec<(u32, f32)> = (0..self.num_clusters)
302            .filter(|&c| c != primary)
303            .map(|c| {
304                let centroid = self.get_centroid(c);
305                // Compute r' = x - c'
306                // Then compute |⟨r, r'⟩| - we want this SMALL (orthogonal)
307                let dot: f32 = vector
308                    .iter()
309                    .zip(centroid)
310                    .zip(&residual)
311                    .map(|((v, c), r)| (v - c) * r)
312                    .sum();
313                (c, dot.abs())
314            })
315            .collect();
316
317        // Partial sort by orthogonality (smallest dot product first)
318        let take = config.num_secondary.min(candidates.len());
319        if candidates.len() > take {
320            candidates.select_nth_unstable_by(take, |a, b| a.1.partial_cmp(&b.1).unwrap());
321            candidates.truncate(take);
322        }
323
324        MultiAssignment {
325            primary_cluster: primary,
326            secondary_clusters: candidates
327                .iter()
328                .take(config.num_secondary)
329                .map(|(c, _)| *c)
330                .collect(),
331        }
332    }
333
334    /// Get centroid for a cluster
335    pub fn get_centroid(&self, cluster_id: u32) -> &[f32] {
336        let offset = cluster_id as usize * self.dim;
337        &self.centroids[offset..offset + self.dim]
338    }
339
340    /// Compute residual vector (vector - centroid)
341    pub fn compute_residual(&self, vector: &[f32], cluster_id: u32) -> Vec<f32> {
342        let centroid = self.get_centroid(cluster_id);
343        vector.iter().zip(centroid).map(|(&v, &c)| v - c).collect()
344    }
345
346    /// Save to binary file
347    pub fn save(&self, path: &Path) -> io::Result<()> {
348        let mut file = std::fs::File::create(path)?;
349        self.write_to(&mut file)
350    }
351
352    /// Write to any writer
353    pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
354        writer.write_u32::<LittleEndian>(CENTROIDS_MAGIC)?;
355        writer.write_u32::<LittleEndian>(2)?; // version 2 with SOAR support
356        writer.write_u64::<LittleEndian>(self.version)?;
357        writer.write_u32::<LittleEndian>(self.num_clusters)?;
358        writer.write_u32::<LittleEndian>(self.dim as u32)?;
359
360        // Write SOAR config
361        if let Some(ref soar) = self.soar_config {
362            writer.write_u8(1)?;
363            writer.write_u32::<LittleEndian>(soar.num_secondary as u32)?;
364            writer.write_u8(if soar.selective { 1 } else { 0 })?;
365            writer.write_f32::<LittleEndian>(soar.spill_threshold)?;
366        } else {
367            writer.write_u8(0)?;
368        }
369
370        for &val in &self.centroids {
371            writer.write_f32::<LittleEndian>(val)?;
372        }
373
374        Ok(())
375    }
376
377    /// Load from binary file
378    pub fn load(path: &Path) -> io::Result<Self> {
379        let data = std::fs::read(path)?;
380        Self::read_from(&mut Cursor::new(data))
381    }
382
383    /// Read from any reader
384    pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Self> {
385        let magic = reader.read_u32::<LittleEndian>()?;
386        if magic != CENTROIDS_MAGIC {
387            return Err(io::Error::new(
388                io::ErrorKind::InvalidData,
389                "Invalid centroids file magic",
390            ));
391        }
392
393        let file_version = reader.read_u32::<LittleEndian>()?;
394        let version = reader.read_u64::<LittleEndian>()?;
395        let num_clusters = reader.read_u32::<LittleEndian>()?;
396        let dim = reader.read_u32::<LittleEndian>()? as usize;
397
398        // Read SOAR config (version 2+)
399        let soar_config = if file_version >= 2 {
400            let has_soar = reader.read_u8()? != 0;
401            if has_soar {
402                let num_secondary = reader.read_u32::<LittleEndian>()? as usize;
403                let selective = reader.read_u8()? != 0;
404                let spill_threshold = reader.read_f32::<LittleEndian>()?;
405                Some(SoarConfig {
406                    num_secondary,
407                    selective,
408                    spill_threshold,
409                })
410            } else {
411                None
412            }
413        } else {
414            None
415        };
416
417        let mut centroids = vec![0.0f32; num_clusters as usize * dim];
418        for val in &mut centroids {
419            *val = reader.read_f32::<LittleEndian>()?;
420        }
421
422        Ok(Self {
423            num_clusters,
424            dim,
425            centroids,
426            version,
427            soar_config,
428        })
429    }
430
431    /// Serialize to bytes
432    pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
433        let mut buf = Vec::new();
434        self.write_to(&mut buf)?;
435        Ok(buf)
436    }
437
438    /// Deserialize from bytes
439    pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
440        Self::read_from(&mut Cursor::new(data))
441    }
442
443    /// Memory usage in bytes
444    pub fn size_bytes(&self) -> usize {
445        self.centroids.len() * 4 + 64 // centroids + overhead
446    }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Deterministic training set: `n` vectors of `dim` components drawn from
    /// a generator seeded with `seed`.
    fn random_vectors(seed: u64, n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>()).collect())
            .collect()
    }

    #[test]
    fn test_coarse_centroids_basic() {
        let dim = 64;
        let n = 1000;
        let num_clusters = 16;

        // Shift the samples so they are centred on the origin.
        let vectors: Vec<Vec<f32>> = random_vectors(42, n, dim)
            .into_iter()
            .map(|v| v.into_iter().map(|x| x - 0.5).collect())
            .collect();

        let config = CoarseConfig::new(dim, num_clusters);
        let centroids = CoarseCentroids::train(&config, &vectors);

        assert_eq!(centroids.num_clusters, num_clusters as u32);
        assert_eq!(centroids.dim, dim);
        assert_eq!(centroids.centroids.len(), num_clusters * dim);
    }

    #[test]
    fn test_find_nearest() {
        let dim = 32;
        let n = 500;
        let num_clusters = 8;

        let vectors = random_vectors(123, n, dim);

        let config = CoarseConfig::new(dim, num_clusters);
        let centroids = CoarseCentroids::train(&config, &vectors);

        // Test that find_nearest returns valid cluster IDs
        for v in &vectors {
            let cluster = centroids.find_nearest(v);
            assert!(cluster < centroids.num_clusters);
        }
    }

    #[test]
    fn test_find_k_nearest_ordering() {
        // Hand-built centroids so the expected order is known exactly.
        let centroids = CoarseCentroids {
            num_clusters: 3,
            dim: 2,
            centroids: vec![0.0, 0.0, 1.0, 1.0, 5.0, 5.0],
            version: 0,
            soar_config: None,
        };
        let query = [0.9f32, 0.9];

        assert_eq!(centroids.find_nearest(&query), 1);
        assert_eq!(centroids.find_k_nearest(&query, 2), vec![1, 0]);
        // Asking for more clusters than exist returns all of them, sorted.
        assert_eq!(centroids.find_k_nearest(&query, 10), vec![1, 0, 2]);

        let with_dist = centroids.find_k_nearest_with_distances(&query, 3);
        let ids: Vec<u32> = with_dist.iter().map(|&(c, _)| c).collect();
        assert_eq!(ids, vec![1, 0, 2]);
        assert!(with_dist.windows(2).all(|w| w[0].1 <= w[1].1));
    }

    #[test]
    fn test_soar_assignment() {
        let dim = 32;
        let n = 100;
        let num_clusters = 8;

        let vectors = random_vectors(456, n, dim);

        let soar_config = SoarConfig {
            num_secondary: 2,
            selective: false,
            spill_threshold: 0.0,
        };
        let config = CoarseConfig::new(dim, num_clusters).with_soar(soar_config);
        let centroids = CoarseCentroids::train(&config, &vectors);

        // Test SOAR assignment
        let assignment = centroids.assign(&vectors[0]);
        assert!(assignment.primary_cluster < centroids.num_clusters);
        assert_eq!(assignment.secondary_clusters.len(), 2);

        // Secondary clusters should be different from primary
        for &sec in &assignment.secondary_clusters {
            assert_ne!(sec, assignment.primary_cluster);
        }
    }

    #[test]
    fn test_serialization() {
        let dim = 16;
        let n = 50;
        let num_clusters = 4;

        let vectors = random_vectors(789, n, dim);

        let config = CoarseConfig::new(dim, num_clusters);
        let centroids = CoarseCentroids::train(&config, &vectors);

        // Serialize and deserialize
        let bytes = centroids.to_bytes().unwrap();
        let loaded = CoarseCentroids::from_bytes(&bytes).unwrap();

        assert_eq!(loaded.num_clusters, centroids.num_clusters);
        assert_eq!(loaded.dim, centroids.dim);
        assert_eq!(loaded.version, centroids.version);
        assert!(loaded.soar_config.is_none());

        // Compare bit patterns so the check is exact for every f32 value.
        assert_eq!(loaded.centroids.len(), centroids.centroids.len());
        assert!(loaded
            .centroids
            .iter()
            .map(|v| v.to_bits())
            .eq(centroids.centroids.iter().map(|v| v.to_bits())));
    }

    #[test]
    fn test_serialization_preserves_soar_config() {
        let centroids = CoarseCentroids {
            num_clusters: 2,
            dim: 2,
            centroids: vec![0.0, 1.0, 2.0, 3.0],
            version: 7,
            soar_config: Some(SoarConfig {
                num_secondary: 3,
                selective: true,
                spill_threshold: 0.25,
            }),
        };

        let bytes = centroids.to_bytes().unwrap();
        let loaded = CoarseCentroids::from_bytes(&bytes).unwrap();

        let soar = loaded.soar_config.expect("SOAR config should round-trip");
        assert_eq!(soar.num_secondary, 3);
        assert!(soar.selective);
        assert_eq!(soar.spill_threshold, 0.25);
        assert_eq!(loaded.version, 7);
        assert_eq!(loaded.centroids, vec![0.0, 1.0, 2.0, 3.0]);
    }
}