// hermes_core/structures/vector/ivf/coarse.rs
use std::io::{self, Cursor, Read, Write};
use std::path::Path;

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
#[cfg(not(feature = "native"))]
use rand::SeedableRng;
#[cfg(not(feature = "native"))]
use rand::prelude::SliceRandom;
use serde::{Deserialize, Serialize};

use super::soar::{MultiAssignment, SoarConfig};
18const CENTROIDS_MAGIC: u32 = 0x48435643; #[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct CoarseConfig {
24 pub num_clusters: usize,
26 pub dim: usize,
28 pub max_iters: usize,
30 pub seed: u64,
32 pub soar: Option<SoarConfig>,
34}
35
36impl CoarseConfig {
37 pub fn new(dim: usize, num_clusters: usize) -> Self {
38 Self {
39 num_clusters,
40 dim,
41 max_iters: 25,
42 seed: 42,
43 soar: None,
44 }
45 }
46
47 pub fn with_soar(mut self, config: SoarConfig) -> Self {
48 self.soar = Some(config);
49 self
50 }
51
52 pub fn with_seed(mut self, seed: u64) -> Self {
53 self.seed = seed;
54 self
55 }
56
57 pub fn with_max_iters(mut self, iters: usize) -> Self {
58 self.max_iters = iters;
59 self
60 }
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct CoarseCentroids {
66 pub num_clusters: u32,
68 pub dim: usize,
70 pub centroids: Vec<f32>,
72 pub version: u64,
74 pub soar_config: Option<SoarConfig>,
76}
77
78impl CoarseCentroids {
79 #[cfg(feature = "native")]
83 pub fn train(config: &CoarseConfig, vectors: &[Vec<f32>]) -> Self {
84 use kentro::KMeans;
85 use ndarray::Array2;
86
87 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
88 assert!(config.num_clusters > 0, "Need at least 1 cluster");
89
90 let actual_clusters = config.num_clusters.min(vectors.len());
91 let dim = config.dim;
92
93 let flat: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
95 let data = Array2::from_shape_vec((vectors.len(), dim), flat)
96 .expect("Failed to create ndarray from vectors");
97
98 let mut kmeans = KMeans::new(actual_clusters)
100 .with_euclidean(true)
101 .with_iterations(config.max_iters);
102 let _ = kmeans
103 .train(data.view(), None)
104 .expect("K-means training failed");
105
106 let centroids: Vec<f32> = kmeans
108 .centroids()
109 .expect("No centroids after training")
110 .iter()
111 .copied()
112 .collect();
113
114 let version = std::time::SystemTime::now()
115 .duration_since(std::time::UNIX_EPOCH)
116 .unwrap_or_default()
117 .as_millis() as u64;
118
119 Self {
120 num_clusters: actual_clusters as u32,
121 dim,
122 centroids,
123 version,
124 soar_config: config.soar.clone(),
125 }
126 }
127
128 #[cfg(not(feature = "native"))]
130 pub fn train(config: &CoarseConfig, vectors: &[Vec<f32>]) -> Self {
131 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
132 assert!(config.num_clusters > 0, "Need at least 1 cluster");
133
134 let actual_clusters = config.num_clusters.min(vectors.len());
135 let dim = config.dim;
136 let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
137
138 let mut indices: Vec<usize> = (0..vectors.len()).collect();
140 indices.shuffle(&mut rng);
141
142 let mut centroids: Vec<f32> = indices[..actual_clusters]
143 .iter()
144 .flat_map(|&i| vectors[i].iter().copied())
145 .collect();
146
147 for _ in 0..config.max_iters {
149 let assignments: Vec<usize> = vectors
150 .iter()
151 .map(|v| Self::find_nearest_idx_static(v, ¢roids, dim))
152 .collect();
153
154 let mut new_centroids = vec![0.0f32; actual_clusters * dim];
155 let mut counts = vec![0usize; actual_clusters];
156
157 for (vec_idx, &cluster_id) in assignments.iter().enumerate() {
158 counts[cluster_id] += 1;
159 let offset = cluster_id * dim;
160 for (i, &val) in vectors[vec_idx].iter().enumerate() {
161 new_centroids[offset + i] += val;
162 }
163 }
164
165 for (cluster_id, &count) in counts.iter().enumerate().take(actual_clusters) {
166 if count > 0 {
167 let offset = cluster_id * dim;
168 for i in 0..dim {
169 new_centroids[offset + i] /= count as f32;
170 }
171 }
172 }
173
174 centroids = new_centroids;
175 }
176
177 let version = std::time::SystemTime::now()
178 .duration_since(std::time::UNIX_EPOCH)
179 .unwrap_or_default()
180 .as_millis() as u64;
181
182 Self {
183 num_clusters: actual_clusters as u32,
184 dim,
185 centroids,
186 version,
187 soar_config: config.soar.clone(),
188 }
189 }
190
191 fn find_nearest_idx_static(vector: &[f32], centroids: &[f32], dim: usize) -> usize {
193 let num_clusters = centroids.len() / dim;
194 let mut best_idx = 0;
195 let mut best_dist = f32::MAX;
196
197 for c in 0..num_clusters {
198 let offset = c * dim;
199 let dist: f32 = vector
200 .iter()
201 .zip(¢roids[offset..offset + dim])
202 .map(|(&a, &b)| (a - b) * (a - b))
203 .sum();
204
205 if dist < best_dist {
206 best_dist = dist;
207 best_idx = c;
208 }
209 }
210
211 best_idx
212 }
213
214 pub fn find_nearest(&self, vector: &[f32]) -> u32 {
216 Self::find_nearest_idx_static(vector, &self.centroids, self.dim) as u32
217 }
218
219 pub fn find_k_nearest(&self, vector: &[f32], k: usize) -> Vec<u32> {
221 let mut distances: Vec<(u32, f32)> = (0..self.num_clusters)
222 .map(|c| {
223 let offset = c as usize * self.dim;
224 let dist: f32 = vector
225 .iter()
226 .zip(&self.centroids[offset..offset + self.dim])
227 .map(|(&a, &b)| (a - b) * (a - b))
228 .sum();
229 (c, dist)
230 })
231 .collect();
232
233 if distances.len() > k {
235 distances.select_nth_unstable_by(k, |a, b| a.1.partial_cmp(&b.1).unwrap());
236 distances.truncate(k);
237 }
238 distances.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
239 distances.into_iter().map(|(c, _)| c).collect()
240 }
241
242 pub fn find_k_nearest_with_distances(&self, vector: &[f32], k: usize) -> Vec<(u32, f32)> {
244 let mut distances: Vec<(u32, f32)> = (0..self.num_clusters)
245 .map(|c| {
246 let offset = c as usize * self.dim;
247 let dist: f32 = vector
248 .iter()
249 .zip(&self.centroids[offset..offset + self.dim])
250 .map(|(&a, &b)| (a - b) * (a - b))
251 .sum();
252 (c, dist)
253 })
254 .collect();
255
256 if distances.len() > k {
258 distances.select_nth_unstable_by(k, |a, b| a.1.partial_cmp(&b.1).unwrap());
259 distances.truncate(k);
260 }
261 distances.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
262 distances
263 }
264
265 pub fn assign(&self, vector: &[f32]) -> MultiAssignment {
267 if let Some(ref soar_config) = self.soar_config {
268 self.assign_with_soar(vector, soar_config)
269 } else {
270 MultiAssignment {
271 primary_cluster: self.find_nearest(vector),
272 secondary_clusters: Vec::new(),
273 }
274 }
275 }
276
277 pub fn assign_with_soar(&self, vector: &[f32], config: &SoarConfig) -> MultiAssignment {
279 let primary = self.find_nearest(vector);
281 let primary_centroid = self.get_centroid(primary);
282
283 let residual: Vec<f32> = vector
285 .iter()
286 .zip(primary_centroid)
287 .map(|(v, c)| v - c)
288 .collect();
289
290 let residual_norm_sq: f32 = residual.iter().map(|x| x * x).sum();
291
292 if config.selective && residual_norm_sq < config.spill_threshold * config.spill_threshold {
294 return MultiAssignment {
295 primary_cluster: primary,
296 secondary_clusters: Vec::new(),
297 };
298 }
299
300 let mut candidates: Vec<(u32, f32)> = (0..self.num_clusters)
302 .filter(|&c| c != primary)
303 .map(|c| {
304 let centroid = self.get_centroid(c);
305 let dot: f32 = vector
308 .iter()
309 .zip(centroid)
310 .zip(&residual)
311 .map(|((v, c), r)| (v - c) * r)
312 .sum();
313 (c, dot.abs())
314 })
315 .collect();
316
317 let take = config.num_secondary.min(candidates.len());
319 if candidates.len() > take {
320 candidates.select_nth_unstable_by(take, |a, b| a.1.partial_cmp(&b.1).unwrap());
321 candidates.truncate(take);
322 }
323
324 MultiAssignment {
325 primary_cluster: primary,
326 secondary_clusters: candidates
327 .iter()
328 .take(config.num_secondary)
329 .map(|(c, _)| *c)
330 .collect(),
331 }
332 }
333
334 pub fn get_centroid(&self, cluster_id: u32) -> &[f32] {
336 let offset = cluster_id as usize * self.dim;
337 &self.centroids[offset..offset + self.dim]
338 }
339
340 pub fn compute_residual(&self, vector: &[f32], cluster_id: u32) -> Vec<f32> {
342 let centroid = self.get_centroid(cluster_id);
343 vector.iter().zip(centroid).map(|(&v, &c)| v - c).collect()
344 }
345
346 pub fn save(&self, path: &Path) -> io::Result<()> {
348 let mut file = std::fs::File::create(path)?;
349 self.write_to(&mut file)
350 }
351
352 pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
354 writer.write_u32::<LittleEndian>(CENTROIDS_MAGIC)?;
355 writer.write_u32::<LittleEndian>(2)?; writer.write_u64::<LittleEndian>(self.version)?;
357 writer.write_u32::<LittleEndian>(self.num_clusters)?;
358 writer.write_u32::<LittleEndian>(self.dim as u32)?;
359
360 if let Some(ref soar) = self.soar_config {
362 writer.write_u8(1)?;
363 writer.write_u32::<LittleEndian>(soar.num_secondary as u32)?;
364 writer.write_u8(if soar.selective { 1 } else { 0 })?;
365 writer.write_f32::<LittleEndian>(soar.spill_threshold)?;
366 } else {
367 writer.write_u8(0)?;
368 }
369
370 for &val in &self.centroids {
371 writer.write_f32::<LittleEndian>(val)?;
372 }
373
374 Ok(())
375 }
376
377 pub fn load(path: &Path) -> io::Result<Self> {
379 let data = std::fs::read(path)?;
380 Self::read_from(&mut Cursor::new(data))
381 }
382
383 pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Self> {
385 let magic = reader.read_u32::<LittleEndian>()?;
386 if magic != CENTROIDS_MAGIC {
387 return Err(io::Error::new(
388 io::ErrorKind::InvalidData,
389 "Invalid centroids file magic",
390 ));
391 }
392
393 let file_version = reader.read_u32::<LittleEndian>()?;
394 let version = reader.read_u64::<LittleEndian>()?;
395 let num_clusters = reader.read_u32::<LittleEndian>()?;
396 let dim = reader.read_u32::<LittleEndian>()? as usize;
397
398 let soar_config = if file_version >= 2 {
400 let has_soar = reader.read_u8()? != 0;
401 if has_soar {
402 let num_secondary = reader.read_u32::<LittleEndian>()? as usize;
403 let selective = reader.read_u8()? != 0;
404 let spill_threshold = reader.read_f32::<LittleEndian>()?;
405 Some(SoarConfig {
406 num_secondary,
407 selective,
408 spill_threshold,
409 })
410 } else {
411 None
412 }
413 } else {
414 None
415 };
416
417 let mut centroids = vec![0.0f32; num_clusters as usize * dim];
418 for val in &mut centroids {
419 *val = reader.read_f32::<LittleEndian>()?;
420 }
421
422 Ok(Self {
423 num_clusters,
424 dim,
425 centroids,
426 version,
427 soar_config,
428 })
429 }
430
431 pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
433 let mut buf = Vec::new();
434 self.write_to(&mut buf)?;
435 Ok(buf)
436 }
437
438 pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
440 Self::read_from(&mut Cursor::new(data))
441 }
442
443 pub fn size_bytes(&self) -> usize {
445 self.centroids.len() * 4 + 64 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Deterministic pseudo-random vectors with components in `[-shift, 1 - shift)`.
    fn random_vectors(seed: u64, n: usize, dim: usize, shift: f32) -> Vec<Vec<f32>> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - shift).collect())
            .collect()
    }

    #[test]
    fn test_coarse_centroids_basic() {
        let dim = 64;
        let num_clusters = 16;
        let vectors = random_vectors(42, 1000, dim, 0.5);

        let config = CoarseConfig::new(dim, num_clusters);
        let centroids = CoarseCentroids::train(&config, &vectors);

        assert_eq!(centroids.num_clusters, num_clusters as u32);
        assert_eq!(centroids.dim, dim);
        assert_eq!(centroids.centroids.len(), num_clusters * dim);
    }

    #[test]
    fn test_find_nearest() {
        let dim = 32;
        let vectors = random_vectors(123, 500, dim, 0.0);

        let config = CoarseConfig::new(dim, 8);
        let centroids = CoarseCentroids::train(&config, &vectors);

        for v in &vectors {
            let cluster = centroids.find_nearest(v);
            assert!(cluster < centroids.num_clusters);
        }
    }

    #[test]
    fn test_find_k_nearest_is_sorted_and_bounded() {
        let dim = 16;
        let vectors = random_vectors(321, 200, dim, 0.0);
        let centroids = CoarseCentroids::train(&CoarseConfig::new(dim, 8), &vectors);
        let query = &vectors[0];

        let nearest = centroids.find_k_nearest_with_distances(query, 3);
        assert_eq!(nearest.len(), 3);
        assert!(nearest.windows(2).all(|w| w[0].1 <= w[1].1));

        // The closest entry must be as close as the single nearest centroid.
        let best = centroids.find_nearest(query);
        let best_dist: f32 = query
            .iter()
            .zip(centroids.get_centroid(best))
            .map(|(&a, &b)| (a - b) * (a - b))
            .sum();
        assert_eq!(nearest[0].1, best_dist);

        assert_eq!(
            centroids.find_k_nearest(query, 100).len(),
            centroids.num_clusters as usize
        );
        assert!(centroids.find_k_nearest(query, 0).is_empty());
    }

    #[test]
    fn test_soar_assignment() {
        let dim = 32;
        let vectors = random_vectors(456, 100, dim, 0.0);

        let soar_config = SoarConfig {
            num_secondary: 2,
            selective: false,
            spill_threshold: 0.0,
        };
        let config = CoarseConfig::new(dim, 8).with_soar(soar_config);
        let centroids = CoarseCentroids::train(&config, &vectors);

        let assignment = centroids.assign(&vectors[0]);
        assert!(assignment.primary_cluster < centroids.num_clusters);
        assert_eq!(assignment.secondary_clusters.len(), 2);

        for &sec in &assignment.secondary_clusters {
            assert!(sec < centroids.num_clusters);
            assert_ne!(sec, assignment.primary_cluster);
        }
        assert_ne!(
            assignment.secondary_clusters[0],
            assignment.secondary_clusters[1]
        );
    }

    #[test]
    fn test_serialization() {
        let dim = 16;
        let vectors = random_vectors(789, 50, dim, 0.0);

        let config = CoarseConfig::new(dim, 4);
        let centroids = CoarseCentroids::train(&config, &vectors);

        let bytes = centroids.to_bytes().unwrap();
        let loaded = CoarseCentroids::from_bytes(&bytes).unwrap();

        assert_eq!(loaded.num_clusters, centroids.num_clusters);
        assert_eq!(loaded.dim, centroids.dim);
        assert_eq!(loaded.version, centroids.version);
        // f32 values are written and read bit-for-bit, so exact equality is expected.
        assert_eq!(loaded.centroids, centroids.centroids);
        assert!(loaded.soar_config.is_none());
    }

    #[test]
    fn test_serialization_round_trips_soar_config() {
        let dim = 8;
        let vectors = random_vectors(2024, 40, dim, 0.0);
        let soar_config = SoarConfig {
            num_secondary: 3,
            selective: true,
            spill_threshold: 0.25,
        };
        let config = CoarseConfig::new(dim, 4).with_soar(soar_config);
        let centroids = CoarseCentroids::train(&config, &vectors);

        let loaded = CoarseCentroids::from_bytes(&centroids.to_bytes().unwrap()).unwrap();
        let soar = loaded
            .soar_config
            .expect("SOAR config should survive a round trip");
        assert_eq!(soar.num_secondary, 3);
        assert!(soar.selective);
        assert_eq!(soar.spill_threshold, 0.25);
    }

    #[test]
    fn test_from_bytes_rejects_corrupt_input() {
        let dim = 4;
        let vectors = random_vectors(99, 20, dim, 0.0);
        let centroids = CoarseCentroids::train(&CoarseConfig::new(dim, 2), &vectors);
        let bytes = centroids.to_bytes().unwrap();

        let mut bad_magic = bytes.clone();
        bad_magic[0] ^= 0xFF;
        let err = CoarseCentroids::from_bytes(&bad_magic).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);

        assert!(CoarseCentroids::from_bytes(&bytes[..bytes.len() - 1]).is_err());
        assert!(CoarseCentroids::from_bytes(&[]).is_err());
    }
}