hermes_core/structures/vector/ivf/coarse.rs

1//! Coarse centroids for IVF partitioning
2//!
3//! Provides k-means clustering for the first level of IVF indexing.
4//! Trained once, shared across all segments for O(1) merge compatibility.
5
6use std::io::{self, Cursor, Read, Write};
7use std::path::Path;
8
9use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
10use serde::{Deserialize, Serialize};
11
12use super::soar::{MultiAssignment, SoarConfig};
13
/// Magic number at the start of a coarse-centroids file.
///
/// Spelled as the four ASCII bytes `CVCH` ("Coarse Vector Centroids Hermes")
/// in little-endian order, which is exactly how the serializer writes it, so
/// the file literally begins with `CVCH` on disk.
const CENTROIDS_MAGIC: u32 = u32::from_le_bytes(*b"CVCH");
16
17/// Configuration for coarse quantizer
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct CoarseConfig {
20    /// Number of clusters
21    pub num_clusters: usize,
22    /// Vector dimension
23    pub dim: usize,
24    /// Maximum k-means iterations
25    pub max_iters: usize,
26    /// Random seed for reproducibility
27    pub seed: u64,
28    /// SOAR configuration (optional)
29    pub soar: Option<SoarConfig>,
30}
31
32impl CoarseConfig {
33    pub fn new(dim: usize, num_clusters: usize) -> Self {
34        Self {
35            num_clusters,
36            dim,
37            max_iters: 25,
38            seed: 42,
39            soar: None,
40        }
41    }
42
43    pub fn with_soar(mut self, config: SoarConfig) -> Self {
44        self.soar = Some(config);
45        self
46    }
47
48    pub fn with_seed(mut self, seed: u64) -> Self {
49        self.seed = seed;
50        self
51    }
52
53    pub fn with_max_iters(mut self, iters: usize) -> Self {
54        self.max_iters = iters;
55        self
56    }
57}
58
59/// Coarse centroids for IVF - trained once, shared across all segments
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct CoarseCentroids {
62    /// Number of clusters
63    pub num_clusters: u32,
64    /// Vector dimension
65    pub dim: usize,
66    /// Centroids stored as flat array (num_clusters × dim)
67    pub centroids: Vec<f32>,
68    /// Version for compatibility checking during merge
69    pub version: u64,
70    /// SOAR configuration (if enabled)
71    pub soar_config: Option<SoarConfig>,
72}
73
74impl CoarseCentroids {
75    /// Train coarse centroids using k-means algorithm
76    ///
77    /// Uses kmeans crate with SIMD acceleration (native feature).
78    #[cfg(feature = "native")]
79    pub fn train(config: &CoarseConfig, vectors: &[Vec<f32>]) -> Self {
80        use kmeans::{EuclideanDistance, KMeans, KMeansConfig};
81
82        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
83        assert!(config.num_clusters > 0, "Need at least 1 cluster");
84
85        let actual_clusters = config.num_clusters.min(vectors.len());
86        let dim = config.dim;
87
88        // Flatten vectors for kmeans crate (expects flat slice)
89        let samples: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
90
91        // Run k-means with k-means++ initialization
92        // KMeans<f32, 8, _> uses 8-lane SIMD (AVX256)
93        let kmean: KMeans<f32, 8, _> = KMeans::new(&samples, vectors.len(), dim, EuclideanDistance);
94        let result = kmean.kmeans_lloyd(
95            actual_clusters,
96            config.max_iters,
97            KMeans::init_kmeanplusplus,
98            &KMeansConfig::default(),
99        );
100
101        // Extract centroids from StrideBuffer to flat Vec
102        let centroids: Vec<f32> = result
103            .centroids
104            .iter()
105            .flat_map(|c| c.iter().copied())
106            .collect();
107
108        let version = std::time::SystemTime::now()
109            .duration_since(std::time::UNIX_EPOCH)
110            .unwrap_or_default()
111            .as_millis() as u64;
112
113        Self {
114            num_clusters: actual_clusters as u32,
115            dim,
116            centroids,
117            version,
118            soar_config: config.soar.clone(),
119        }
120    }
121
122    /// Fallback k-means for non-native builds (WASM)
123    #[cfg(not(feature = "native"))]
124    pub fn train(config: &CoarseConfig, vectors: &[Vec<f32>]) -> Self {
125        assert!(!vectors.is_empty(), "Cannot train on empty vector set");
126        assert!(config.num_clusters > 0, "Need at least 1 cluster");
127
128        let actual_clusters = config.num_clusters.min(vectors.len());
129        let dim = config.dim;
130        let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
131
132        // Simple random initialization
133        let mut indices: Vec<usize> = (0..vectors.len()).collect();
134        indices.shuffle(&mut rng);
135
136        let mut centroids: Vec<f32> = indices[..actual_clusters]
137            .iter()
138            .flat_map(|&i| vectors[i].iter().copied())
139            .collect();
140
141        // K-means iterations
142        for _ in 0..config.max_iters {
143            let assignments: Vec<usize> = vectors
144                .iter()
145                .map(|v| Self::find_nearest_idx_static(v, &centroids, dim))
146                .collect();
147
148            let mut new_centroids = vec![0.0f32; actual_clusters * dim];
149            let mut counts = vec![0usize; actual_clusters];
150
151            for (vec_idx, &cluster_id) in assignments.iter().enumerate() {
152                counts[cluster_id] += 1;
153                let offset = cluster_id * dim;
154                for (i, &val) in vectors[vec_idx].iter().enumerate() {
155                    new_centroids[offset + i] += val;
156                }
157            }
158
159            for cluster_id in 0..actual_clusters {
160                if counts[cluster_id] > 0 {
161                    let offset = cluster_id * dim;
162                    for i in 0..dim {
163                        new_centroids[offset + i] /= counts[cluster_id] as f32;
164                    }
165                }
166            }
167
168            centroids = new_centroids;
169        }
170
171        let version = std::time::SystemTime::now()
172            .duration_since(std::time::UNIX_EPOCH)
173            .unwrap_or_default()
174            .as_millis() as u64;
175
176        Self {
177            num_clusters: actual_clusters as u32,
178            dim,
179            centroids,
180            version,
181            soar_config: config.soar.clone(),
182        }
183    }
184
185    /// Find nearest centroid index for a vector (static helper)
186    fn find_nearest_idx_static(vector: &[f32], centroids: &[f32], dim: usize) -> usize {
187        let num_clusters = centroids.len() / dim;
188        let mut best_idx = 0;
189        let mut best_dist = f32::MAX;
190
191        for c in 0..num_clusters {
192            let offset = c * dim;
193            let dist: f32 = vector
194                .iter()
195                .zip(&centroids[offset..offset + dim])
196                .map(|(&a, &b)| (a - b) * (a - b))
197                .sum();
198
199            if dist < best_dist {
200                best_dist = dist;
201                best_idx = c;
202            }
203        }
204
205        best_idx
206    }
207
208    /// Find nearest cluster for a query vector
209    pub fn find_nearest(&self, vector: &[f32]) -> u32 {
210        Self::find_nearest_idx_static(vector, &self.centroids, self.dim) as u32
211    }
212
213    /// Find k nearest clusters for a query vector
214    pub fn find_k_nearest(&self, vector: &[f32], k: usize) -> Vec<u32> {
215        let mut distances: Vec<(u32, f32)> = (0..self.num_clusters)
216            .map(|c| {
217                let offset = c as usize * self.dim;
218                let dist: f32 = vector
219                    .iter()
220                    .zip(&self.centroids[offset..offset + self.dim])
221                    .map(|(&a, &b)| (a - b) * (a - b))
222                    .sum();
223                (c, dist)
224            })
225            .collect();
226
227        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
228        distances.truncate(k);
229        distances.into_iter().map(|(c, _)| c).collect()
230    }
231
232    /// Find k nearest clusters with their distances
233    pub fn find_k_nearest_with_distances(&self, vector: &[f32], k: usize) -> Vec<(u32, f32)> {
234        let mut distances: Vec<(u32, f32)> = (0..self.num_clusters)
235            .map(|c| {
236                let offset = c as usize * self.dim;
237                let dist: f32 = vector
238                    .iter()
239                    .zip(&self.centroids[offset..offset + self.dim])
240                    .map(|(&a, &b)| (a - b) * (a - b))
241                    .sum();
242                (c, dist)
243            })
244            .collect();
245
246        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
247        distances.truncate(k);
248        distances
249    }
250
251    /// Assign vector with SOAR (if configured) or standard assignment
252    pub fn assign(&self, vector: &[f32]) -> MultiAssignment {
253        if let Some(ref soar_config) = self.soar_config {
254            self.assign_with_soar(vector, soar_config)
255        } else {
256            MultiAssignment {
257                primary_cluster: self.find_nearest(vector),
258                secondary_clusters: Vec::new(),
259            }
260        }
261    }
262
263    /// SOAR-style assignment: find secondary clusters with orthogonal residuals
264    pub fn assign_with_soar(&self, vector: &[f32], config: &SoarConfig) -> MultiAssignment {
265        // 1. Find primary cluster (nearest centroid)
266        let primary = self.find_nearest(vector);
267        let primary_centroid = self.get_centroid(primary);
268
269        // 2. Compute primary residual r = x - c
270        let residual: Vec<f32> = vector
271            .iter()
272            .zip(primary_centroid)
273            .map(|(v, c)| v - c)
274            .collect();
275
276        let residual_norm_sq: f32 = residual.iter().map(|x| x * x).sum();
277
278        // 3. Check if we should spill (selective spilling)
279        if config.selective && residual_norm_sq < config.spill_threshold * config.spill_threshold {
280            return MultiAssignment {
281                primary_cluster: primary,
282                secondary_clusters: Vec::new(),
283            };
284        }
285
286        // 4. Find secondary clusters that MINIMIZE |⟨r, r'⟩| (orthogonal residuals)
287        let mut candidates: Vec<(u32, f32)> = (0..self.num_clusters)
288            .filter(|&c| c != primary)
289            .map(|c| {
290                let centroid = self.get_centroid(c);
291                // Compute r' = x - c'
292                // Then compute |⟨r, r'⟩| - we want this SMALL (orthogonal)
293                let dot: f32 = vector
294                    .iter()
295                    .zip(centroid)
296                    .zip(&residual)
297                    .map(|((v, c), r)| (v - c) * r)
298                    .sum();
299                (c, dot.abs())
300            })
301            .collect();
302
303        // Sort by orthogonality (smallest dot product first)
304        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
305
306        MultiAssignment {
307            primary_cluster: primary,
308            secondary_clusters: candidates
309                .iter()
310                .take(config.num_secondary)
311                .map(|(c, _)| *c)
312                .collect(),
313        }
314    }
315
316    /// Get centroid for a cluster
317    pub fn get_centroid(&self, cluster_id: u32) -> &[f32] {
318        let offset = cluster_id as usize * self.dim;
319        &self.centroids[offset..offset + self.dim]
320    }
321
322    /// Compute residual vector (vector - centroid)
323    pub fn compute_residual(&self, vector: &[f32], cluster_id: u32) -> Vec<f32> {
324        let centroid = self.get_centroid(cluster_id);
325        vector.iter().zip(centroid).map(|(&v, &c)| v - c).collect()
326    }
327
328    /// Save to binary file
329    pub fn save(&self, path: &Path) -> io::Result<()> {
330        let mut file = std::fs::File::create(path)?;
331        self.write_to(&mut file)
332    }
333
334    /// Write to any writer
335    pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
336        writer.write_u32::<LittleEndian>(CENTROIDS_MAGIC)?;
337        writer.write_u32::<LittleEndian>(2)?; // version 2 with SOAR support
338        writer.write_u64::<LittleEndian>(self.version)?;
339        writer.write_u32::<LittleEndian>(self.num_clusters)?;
340        writer.write_u32::<LittleEndian>(self.dim as u32)?;
341
342        // Write SOAR config
343        if let Some(ref soar) = self.soar_config {
344            writer.write_u8(1)?;
345            writer.write_u32::<LittleEndian>(soar.num_secondary as u32)?;
346            writer.write_u8(if soar.selective { 1 } else { 0 })?;
347            writer.write_f32::<LittleEndian>(soar.spill_threshold)?;
348        } else {
349            writer.write_u8(0)?;
350        }
351
352        for &val in &self.centroids {
353            writer.write_f32::<LittleEndian>(val)?;
354        }
355
356        Ok(())
357    }
358
359    /// Load from binary file
360    pub fn load(path: &Path) -> io::Result<Self> {
361        let data = std::fs::read(path)?;
362        Self::read_from(&mut Cursor::new(data))
363    }
364
365    /// Read from any reader
366    pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Self> {
367        let magic = reader.read_u32::<LittleEndian>()?;
368        if magic != CENTROIDS_MAGIC {
369            return Err(io::Error::new(
370                io::ErrorKind::InvalidData,
371                "Invalid centroids file magic",
372            ));
373        }
374
375        let file_version = reader.read_u32::<LittleEndian>()?;
376        let version = reader.read_u64::<LittleEndian>()?;
377        let num_clusters = reader.read_u32::<LittleEndian>()?;
378        let dim = reader.read_u32::<LittleEndian>()? as usize;
379
380        // Read SOAR config (version 2+)
381        let soar_config = if file_version >= 2 {
382            let has_soar = reader.read_u8()? != 0;
383            if has_soar {
384                let num_secondary = reader.read_u32::<LittleEndian>()? as usize;
385                let selective = reader.read_u8()? != 0;
386                let spill_threshold = reader.read_f32::<LittleEndian>()?;
387                Some(SoarConfig {
388                    num_secondary,
389                    selective,
390                    spill_threshold,
391                })
392            } else {
393                None
394            }
395        } else {
396            None
397        };
398
399        let mut centroids = vec![0.0f32; num_clusters as usize * dim];
400        for val in &mut centroids {
401            *val = reader.read_f32::<LittleEndian>()?;
402        }
403
404        Ok(Self {
405            num_clusters,
406            dim,
407            centroids,
408            version,
409            soar_config,
410        })
411    }
412
413    /// Serialize to bytes
414    pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
415        let mut buf = Vec::new();
416        self.write_to(&mut buf)?;
417        Ok(buf)
418    }
419
420    /// Deserialize from bytes
421    pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
422        Self::read_from(&mut Cursor::new(data))
423    }
424
425    /// Memory usage in bytes
426    pub fn size_bytes(&self) -> usize {
427        self.centroids.len() * 4 + 64 // centroids + overhead
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// `n` seeded random `dim`-dimensional vectors with components drawn from
    /// `[0, 1)` and shifted down by `offset`.
    fn random_vectors(seed: u64, n: usize, dim: usize, offset: f32) -> Vec<Vec<f32>> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - offset).collect())
            .collect()
    }

    /// Three 1-D centroids at 0, 10 and 5 (cluster ids 0, 1, 2).
    fn line_fixture(soar_config: Option<SoarConfig>) -> CoarseCentroids {
        CoarseCentroids {
            num_clusters: 3,
            dim: 1,
            centroids: vec![0.0, 10.0, 5.0],
            version: 1,
            soar_config,
        }
    }

    #[test]
    fn test_coarse_centroids_basic() {
        let (dim, n, num_clusters) = (64, 1000, 16);
        let vectors = random_vectors(42, n, dim, 0.5);

        let config = CoarseConfig::new(dim, num_clusters);
        let centroids = CoarseCentroids::train(&config, &vectors);

        assert_eq!(centroids.num_clusters, num_clusters as u32);
        assert_eq!(centroids.dim, dim);
    }

    #[test]
    fn test_find_nearest() {
        let (dim, n, num_clusters) = (32, 500, 8);
        let vectors = random_vectors(123, n, dim, 0.0);

        let config = CoarseConfig::new(dim, num_clusters);
        let centroids = CoarseCentroids::train(&config, &vectors);

        // find_nearest must return valid cluster ids.
        for v in &vectors {
            let cluster = centroids.find_nearest(v);
            assert!(cluster < centroids.num_clusters);
        }
    }

    #[test]
    fn test_find_k_nearest_ordering() {
        let centroids = line_fixture(None);
        // Squared distances from 4.0: c0 = 16, c1 = 36, c2 = 1.
        assert_eq!(centroids.find_nearest(&[4.0]), 2);
        assert_eq!(centroids.find_k_nearest(&[4.0], 2), vec![2, 0]);
        assert_eq!(centroids.find_k_nearest(&[4.0], 10), vec![2, 0, 1]);
        assert!(centroids.find_k_nearest(&[4.0], 0).is_empty());
        assert_eq!(
            centroids.find_k_nearest_with_distances(&[4.0], 1),
            vec![(2, 1.0)]
        );
        assert_eq!(centroids.compute_residual(&[4.0], 1), vec![-6.0]);
    }

    #[test]
    fn test_assign_without_soar() {
        let centroids = line_fixture(None);
        let assignment = centroids.assign(&[4.0]);
        assert_eq!(assignment.primary_cluster, 2);
        assert!(assignment.secondary_clusters.is_empty());
    }

    #[test]
    fn test_soar_selective_and_spill() {
        // Residual to the primary (c2 = 5) is -1, so |r|^2 = 1.
        let selective = line_fixture(Some(SoarConfig {
            num_secondary: 1,
            selective: true,
            spill_threshold: 2.0,
        }));
        assert!(selective.assign(&[4.0]).secondary_clusters.is_empty());

        // |<r, r'>|: c0 -> |4 * -1| = 4, c1 -> |-6 * -1| = 6, so c0 is chosen.
        let spilling = line_fixture(Some(SoarConfig {
            num_secondary: 1,
            selective: false,
            spill_threshold: 0.0,
        }));
        let assignment = spilling.assign(&[4.0]);
        assert_eq!(assignment.primary_cluster, 2);
        assert_eq!(assignment.secondary_clusters, vec![0]);
    }

    #[test]
    fn test_soar_assignment() {
        let (dim, n, num_clusters) = (32, 100, 8);
        let vectors = random_vectors(456, n, dim, 0.0);

        let soar_config = SoarConfig {
            num_secondary: 2,
            selective: false,
            spill_threshold: 0.0,
        };
        let config = CoarseConfig::new(dim, num_clusters).with_soar(soar_config);
        let centroids = CoarseCentroids::train(&config, &vectors);

        let assignment = centroids.assign(&vectors[0]);
        assert!(assignment.primary_cluster < centroids.num_clusters);
        assert_eq!(assignment.secondary_clusters.len(), 2);

        // Secondary clusters must differ from the primary.
        for &sec in &assignment.secondary_clusters {
            assert_ne!(sec, assignment.primary_cluster);
        }
    }

    #[test]
    fn test_serialization() {
        let (dim, n, num_clusters) = (16, 50, 4);
        let vectors = random_vectors(789, n, dim, 0.0);

        let config = CoarseConfig::new(dim, num_clusters);
        let centroids = CoarseCentroids::train(&config, &vectors);

        let bytes = centroids.to_bytes().unwrap();
        let loaded = CoarseCentroids::from_bytes(&bytes).unwrap();

        assert_eq!(loaded.num_clusters, centroids.num_clusters);
        assert_eq!(loaded.dim, centroids.dim);
        assert_eq!(loaded.version, centroids.version);
        assert_eq!(loaded.centroids.len(), centroids.centroids.len());
        // Compare bit patterns so the check is exact even for odd float values.
        assert!(loaded
            .centroids
            .iter()
            .map(|x| x.to_bits())
            .eq(centroids.centroids.iter().map(|x| x.to_bits())));
        assert!(loaded.soar_config.is_none());
    }

    #[test]
    fn test_serialization_with_soar() {
        let centroids = CoarseCentroids {
            num_clusters: 2,
            dim: 3,
            centroids: vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
            version: 7,
            soar_config: Some(SoarConfig {
                num_secondary: 3,
                selective: true,
                spill_threshold: 0.25,
            }),
        };

        let loaded = CoarseCentroids::from_bytes(&centroids.to_bytes().unwrap()).unwrap();

        assert_eq!(loaded.num_clusters, 2);
        assert_eq!(loaded.dim, 3);
        assert_eq!(loaded.version, 7);
        assert_eq!(loaded.centroids, centroids.centroids);
        let soar = loaded.soar_config.expect("SOAR config should round-trip");
        assert_eq!(soar.num_secondary, 3);
        assert!(soar.selective);
        assert_eq!(soar.spill_threshold, 0.25);
    }

    #[test]
    fn test_read_rejects_bad_magic_and_truncation() {
        let bytes = line_fixture(None).to_bytes().unwrap();

        let mut corrupted = bytes.clone();
        corrupted[0] ^= 0xFF;
        let err = CoarseCentroids::from_bytes(&corrupted).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);

        assert!(CoarseCentroids::from_bytes(&bytes[..bytes.len() - 1]).is_err());
    }
}