hermes_core/structures/vector/ivf/
coarse.rs1use std::io::{self, Cursor, Read, Write};
7use std::path::Path;
8
9use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
10#[cfg(not(feature = "native"))]
11use rand::SeedableRng;
12#[cfg(not(feature = "native"))]
13use rand::prelude::SliceRandom;
14use serde::{Deserialize, Serialize};
15
16use super::soar::{MultiAssignment, SoarConfig};
17
18const CENTROIDS_MAGIC: u32 = 0x48435643; #[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct CoarseConfig {
24 pub num_clusters: usize,
26 pub dim: usize,
28 pub max_iters: usize,
30 pub seed: u64,
32 pub soar: Option<SoarConfig>,
34}
35
36impl CoarseConfig {
37 pub fn new(dim: usize, num_clusters: usize) -> Self {
38 Self {
39 num_clusters,
40 dim,
41 max_iters: 25,
42 seed: 42,
43 soar: None,
44 }
45 }
46
47 pub fn with_soar(mut self, config: SoarConfig) -> Self {
48 self.soar = Some(config);
49 self
50 }
51
52 pub fn with_seed(mut self, seed: u64) -> Self {
53 self.seed = seed;
54 self
55 }
56
57 pub fn with_max_iters(mut self, iters: usize) -> Self {
58 self.max_iters = iters;
59 self
60 }
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct CoarseCentroids {
66 pub num_clusters: u32,
68 pub dim: usize,
70 pub centroids: Vec<f32>,
72 pub version: u64,
74 pub soar_config: Option<SoarConfig>,
76}
77
78impl CoarseCentroids {
79 #[cfg(feature = "native")]
83 pub fn train(config: &CoarseConfig, vectors: &[Vec<f32>]) -> Self {
84 use kmeans::{EuclideanDistance, KMeans, KMeansConfig};
85
86 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
87 assert!(config.num_clusters > 0, "Need at least 1 cluster");
88
89 let actual_clusters = config.num_clusters.min(vectors.len());
90 let dim = config.dim;
91
92 let samples: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
94
95 let kmean: KMeans<f32, 8, _> = KMeans::new(&samples, vectors.len(), dim, EuclideanDistance);
98 let result = kmean.kmeans_lloyd(
99 actual_clusters,
100 config.max_iters,
101 KMeans::init_kmeanplusplus,
102 &KMeansConfig::default(),
103 );
104
105 let centroids: Vec<f32> = result
107 .centroids
108 .iter()
109 .flat_map(|c| c.iter().copied())
110 .collect();
111
112 let version = std::time::SystemTime::now()
113 .duration_since(std::time::UNIX_EPOCH)
114 .unwrap_or_default()
115 .as_millis() as u64;
116
117 Self {
118 num_clusters: actual_clusters as u32,
119 dim,
120 centroids,
121 version,
122 soar_config: config.soar.clone(),
123 }
124 }
125
126 #[cfg(not(feature = "native"))]
128 pub fn train(config: &CoarseConfig, vectors: &[Vec<f32>]) -> Self {
129 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
130 assert!(config.num_clusters > 0, "Need at least 1 cluster");
131
132 let actual_clusters = config.num_clusters.min(vectors.len());
133 let dim = config.dim;
134 let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
135
136 let mut indices: Vec<usize> = (0..vectors.len()).collect();
138 indices.shuffle(&mut rng);
139
140 let mut centroids: Vec<f32> = indices[..actual_clusters]
141 .iter()
142 .flat_map(|&i| vectors[i].iter().copied())
143 .collect();
144
145 for _ in 0..config.max_iters {
147 let assignments: Vec<usize> = vectors
148 .iter()
149 .map(|v| Self::find_nearest_idx_static(v, ¢roids, dim))
150 .collect();
151
152 let mut new_centroids = vec![0.0f32; actual_clusters * dim];
153 let mut counts = vec![0usize; actual_clusters];
154
155 for (vec_idx, &cluster_id) in assignments.iter().enumerate() {
156 counts[cluster_id] += 1;
157 let offset = cluster_id * dim;
158 for (i, &val) in vectors[vec_idx].iter().enumerate() {
159 new_centroids[offset + i] += val;
160 }
161 }
162
163 for cluster_id in 0..actual_clusters {
164 if counts[cluster_id] > 0 {
165 let offset = cluster_id * dim;
166 for i in 0..dim {
167 new_centroids[offset + i] /= counts[cluster_id] as f32;
168 }
169 }
170 }
171
172 centroids = new_centroids;
173 }
174
175 let version = std::time::SystemTime::now()
176 .duration_since(std::time::UNIX_EPOCH)
177 .unwrap_or_default()
178 .as_millis() as u64;
179
180 Self {
181 num_clusters: actual_clusters as u32,
182 dim,
183 centroids,
184 version,
185 soar_config: config.soar.clone(),
186 }
187 }
188
189 fn find_nearest_idx_static(vector: &[f32], centroids: &[f32], dim: usize) -> usize {
191 let num_clusters = centroids.len() / dim;
192 let mut best_idx = 0;
193 let mut best_dist = f32::MAX;
194
195 for c in 0..num_clusters {
196 let offset = c * dim;
197 let dist: f32 = vector
198 .iter()
199 .zip(¢roids[offset..offset + dim])
200 .map(|(&a, &b)| (a - b) * (a - b))
201 .sum();
202
203 if dist < best_dist {
204 best_dist = dist;
205 best_idx = c;
206 }
207 }
208
209 best_idx
210 }
211
212 pub fn find_nearest(&self, vector: &[f32]) -> u32 {
214 Self::find_nearest_idx_static(vector, &self.centroids, self.dim) as u32
215 }
216
217 pub fn find_k_nearest(&self, vector: &[f32], k: usize) -> Vec<u32> {
219 let mut distances: Vec<(u32, f32)> = (0..self.num_clusters)
220 .map(|c| {
221 let offset = c as usize * self.dim;
222 let dist: f32 = vector
223 .iter()
224 .zip(&self.centroids[offset..offset + self.dim])
225 .map(|(&a, &b)| (a - b) * (a - b))
226 .sum();
227 (c, dist)
228 })
229 .collect();
230
231 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
232 distances.truncate(k);
233 distances.into_iter().map(|(c, _)| c).collect()
234 }
235
236 pub fn find_k_nearest_with_distances(&self, vector: &[f32], k: usize) -> Vec<(u32, f32)> {
238 let mut distances: Vec<(u32, f32)> = (0..self.num_clusters)
239 .map(|c| {
240 let offset = c as usize * self.dim;
241 let dist: f32 = vector
242 .iter()
243 .zip(&self.centroids[offset..offset + self.dim])
244 .map(|(&a, &b)| (a - b) * (a - b))
245 .sum();
246 (c, dist)
247 })
248 .collect();
249
250 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
251 distances.truncate(k);
252 distances
253 }
254
255 pub fn assign(&self, vector: &[f32]) -> MultiAssignment {
257 if let Some(ref soar_config) = self.soar_config {
258 self.assign_with_soar(vector, soar_config)
259 } else {
260 MultiAssignment {
261 primary_cluster: self.find_nearest(vector),
262 secondary_clusters: Vec::new(),
263 }
264 }
265 }
266
267 pub fn assign_with_soar(&self, vector: &[f32], config: &SoarConfig) -> MultiAssignment {
269 let primary = self.find_nearest(vector);
271 let primary_centroid = self.get_centroid(primary);
272
273 let residual: Vec<f32> = vector
275 .iter()
276 .zip(primary_centroid)
277 .map(|(v, c)| v - c)
278 .collect();
279
280 let residual_norm_sq: f32 = residual.iter().map(|x| x * x).sum();
281
282 if config.selective && residual_norm_sq < config.spill_threshold * config.spill_threshold {
284 return MultiAssignment {
285 primary_cluster: primary,
286 secondary_clusters: Vec::new(),
287 };
288 }
289
290 let mut candidates: Vec<(u32, f32)> = (0..self.num_clusters)
292 .filter(|&c| c != primary)
293 .map(|c| {
294 let centroid = self.get_centroid(c);
295 let dot: f32 = vector
298 .iter()
299 .zip(centroid)
300 .zip(&residual)
301 .map(|((v, c), r)| (v - c) * r)
302 .sum();
303 (c, dot.abs())
304 })
305 .collect();
306
307 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
309
310 MultiAssignment {
311 primary_cluster: primary,
312 secondary_clusters: candidates
313 .iter()
314 .take(config.num_secondary)
315 .map(|(c, _)| *c)
316 .collect(),
317 }
318 }
319
320 pub fn get_centroid(&self, cluster_id: u32) -> &[f32] {
322 let offset = cluster_id as usize * self.dim;
323 &self.centroids[offset..offset + self.dim]
324 }
325
326 pub fn compute_residual(&self, vector: &[f32], cluster_id: u32) -> Vec<f32> {
328 let centroid = self.get_centroid(cluster_id);
329 vector.iter().zip(centroid).map(|(&v, &c)| v - c).collect()
330 }
331
332 pub fn save(&self, path: &Path) -> io::Result<()> {
334 let mut file = std::fs::File::create(path)?;
335 self.write_to(&mut file)
336 }
337
338 pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
340 writer.write_u32::<LittleEndian>(CENTROIDS_MAGIC)?;
341 writer.write_u32::<LittleEndian>(2)?; writer.write_u64::<LittleEndian>(self.version)?;
343 writer.write_u32::<LittleEndian>(self.num_clusters)?;
344 writer.write_u32::<LittleEndian>(self.dim as u32)?;
345
346 if let Some(ref soar) = self.soar_config {
348 writer.write_u8(1)?;
349 writer.write_u32::<LittleEndian>(soar.num_secondary as u32)?;
350 writer.write_u8(if soar.selective { 1 } else { 0 })?;
351 writer.write_f32::<LittleEndian>(soar.spill_threshold)?;
352 } else {
353 writer.write_u8(0)?;
354 }
355
356 for &val in &self.centroids {
357 writer.write_f32::<LittleEndian>(val)?;
358 }
359
360 Ok(())
361 }
362
363 pub fn load(path: &Path) -> io::Result<Self> {
365 let data = std::fs::read(path)?;
366 Self::read_from(&mut Cursor::new(data))
367 }
368
369 pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Self> {
371 let magic = reader.read_u32::<LittleEndian>()?;
372 if magic != CENTROIDS_MAGIC {
373 return Err(io::Error::new(
374 io::ErrorKind::InvalidData,
375 "Invalid centroids file magic",
376 ));
377 }
378
379 let file_version = reader.read_u32::<LittleEndian>()?;
380 let version = reader.read_u64::<LittleEndian>()?;
381 let num_clusters = reader.read_u32::<LittleEndian>()?;
382 let dim = reader.read_u32::<LittleEndian>()? as usize;
383
384 let soar_config = if file_version >= 2 {
386 let has_soar = reader.read_u8()? != 0;
387 if has_soar {
388 let num_secondary = reader.read_u32::<LittleEndian>()? as usize;
389 let selective = reader.read_u8()? != 0;
390 let spill_threshold = reader.read_f32::<LittleEndian>()?;
391 Some(SoarConfig {
392 num_secondary,
393 selective,
394 spill_threshold,
395 })
396 } else {
397 None
398 }
399 } else {
400 None
401 };
402
403 let mut centroids = vec![0.0f32; num_clusters as usize * dim];
404 for val in &mut centroids {
405 *val = reader.read_f32::<LittleEndian>()?;
406 }
407
408 Ok(Self {
409 num_clusters,
410 dim,
411 centroids,
412 version,
413 soar_config,
414 })
415 }
416
417 pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
419 let mut buf = Vec::new();
420 self.write_to(&mut buf)?;
421 Ok(buf)
422 }
423
424 pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
426 Self::read_from(&mut Cursor::new(data))
427 }
428
429 pub fn size_bytes(&self) -> usize {
431 self.centroids.len() * 4 + 64 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// `n` deterministic pseudo-random vectors of length `dim`, with each
    /// component drawn from `rng.random::<f32>()` and shifted by `offset`.
    fn random_vectors(seed: u64, n: usize, dim: usize, offset: f32) -> Vec<Vec<f32>> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() + offset).collect())
            .collect()
    }

    #[test]
    fn test_coarse_centroids_basic() {
        let (dim, n, num_clusters) = (64, 1000, 16);
        let vectors = random_vectors(42, n, dim, -0.5);

        let config = CoarseConfig::new(dim, num_clusters);
        let centroids = CoarseCentroids::train(&config, &vectors);

        assert_eq!(centroids.num_clusters, num_clusters as u32);
        assert_eq!(centroids.dim, dim);
        assert_eq!(centroids.centroids.len(), num_clusters * dim);
    }

    #[test]
    fn test_find_nearest() {
        let (dim, n, num_clusters) = (32, 500, 8);
        let vectors = random_vectors(123, n, dim, 0.0);

        let config = CoarseConfig::new(dim, num_clusters);
        let centroids = CoarseCentroids::train(&config, &vectors);

        for v in &vectors {
            let cluster = centroids.find_nearest(v);
            assert!(cluster < centroids.num_clusters);
            // The 1-nearest list must agree with the single nearest lookup.
            assert_eq!(centroids.find_k_nearest(v, 1), vec![cluster]);
        }
    }

    #[test]
    fn test_find_k_nearest_is_sorted() {
        let dim = 16;
        let vectors = random_vectors(321, 200, dim, 0.0);
        let centroids = CoarseCentroids::train(&CoarseConfig::new(dim, 8), &vectors);
        let query = &vectors[0];

        let ranked = centroids.find_k_nearest_with_distances(query, 5);
        assert_eq!(ranked.len(), 5);
        assert!(ranked.windows(2).all(|w| w[0].1 <= w[1].1));
        assert_eq!(ranked[0].0, centroids.find_nearest(query));

        let ids: Vec<u32> = ranked.iter().map(|&(c, _)| c).collect();
        assert_eq!(centroids.find_k_nearest(query, 5), ids);
        assert_eq!(
            centroids.find_k_nearest(query, 100).len(),
            centroids.num_clusters as usize
        );
    }

    #[test]
    fn test_compute_residual() {
        let dim = 16;
        let vectors = random_vectors(99, 100, dim, 0.0);
        let centroids = CoarseCentroids::train(&CoarseConfig::new(dim, 4), &vectors);

        let v = &vectors[3];
        let cluster = centroids.find_nearest(v);
        let residual = centroids.compute_residual(v, cluster);
        let centroid = centroids.get_centroid(cluster);
        assert_eq!(residual.len(), dim);
        for ((&r, &x), &m) in residual.iter().zip(v).zip(centroid) {
            assert_eq!(r, x - m);
        }
    }

    #[test]
    fn test_soar_assignment() {
        let (dim, n, num_clusters) = (32, 100, 8);
        let vectors = random_vectors(456, n, dim, 0.0);

        let soar_config = SoarConfig {
            num_secondary: 2,
            selective: false,
            spill_threshold: 0.0,
        };
        let config = CoarseConfig::new(dim, num_clusters).with_soar(soar_config);
        let centroids = CoarseCentroids::train(&config, &vectors);

        let assignment = centroids.assign(&vectors[0]);
        assert!(assignment.primary_cluster < centroids.num_clusters);
        assert_eq!(assignment.secondary_clusters.len(), 2);

        for &sec in &assignment.secondary_clusters {
            assert!(sec < centroids.num_clusters);
            assert_ne!(sec, assignment.primary_cluster);
        }
    }

    #[test]
    fn test_selective_soar_skips_spill_near_centroid() {
        let dim = 16;
        let vectors = random_vectors(654, 100, dim, 0.0);
        let soar_config = SoarConfig {
            num_secondary: 2,
            selective: true,
            spill_threshold: 0.5,
        };
        let config = CoarseConfig::new(dim, 8).with_soar(soar_config);
        let centroids = CoarseCentroids::train(&config, &vectors);

        // A vector sitting exactly on a centroid has a zero residual.
        let on_centroid = centroids.get_centroid(0).to_vec();
        let assignment = centroids.assign(&on_centroid);
        assert!(assignment.secondary_clusters.is_empty());
    }

    #[test]
    fn test_serialization() {
        let dim = 16;
        let vectors = random_vectors(789, 50, dim, 0.0);

        let config = CoarseConfig::new(dim, 4);
        let centroids = CoarseCentroids::train(&config, &vectors);

        let bytes = centroids.to_bytes().unwrap();
        let loaded = CoarseCentroids::from_bytes(&bytes).unwrap();

        assert_eq!(loaded.num_clusters, centroids.num_clusters);
        assert_eq!(loaded.dim, centroids.dim);
        assert_eq!(loaded.version, centroids.version);
        assert_eq!(loaded.centroids, centroids.centroids);
        assert!(loaded.soar_config.is_none());
    }

    #[test]
    fn test_serialization_with_soar() {
        let dim = 8;
        let vectors = random_vectors(987, 40, dim, 0.0);
        let soar_config = SoarConfig {
            num_secondary: 3,
            selective: true,
            spill_threshold: 0.25,
        };
        let config = CoarseConfig::new(dim, 4).with_soar(soar_config);
        let centroids = CoarseCentroids::train(&config, &vectors);

        let loaded = CoarseCentroids::from_bytes(&centroids.to_bytes().unwrap()).unwrap();
        let restored = loaded.soar_config.expect("SOAR config should round-trip");
        assert_eq!(restored.num_secondary, 3);
        assert!(restored.selective);
        assert_eq!(restored.spill_threshold, 0.25);
        assert_eq!(loaded.centroids, centroids.centroids);
    }

    #[test]
    fn test_rejects_bad_magic() {
        let vectors = random_vectors(11, 20, 8, 0.0);
        let centroids = CoarseCentroids::train(&CoarseConfig::new(8, 2), &vectors);

        let mut bytes = centroids.to_bytes().unwrap();
        bytes[0] ^= 0xFF;
        let err = CoarseCentroids::from_bytes(&bytes).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_rejects_truncated_input() {
        let vectors = random_vectors(12, 20, 8, 0.0);
        let centroids = CoarseCentroids::train(&CoarseConfig::new(8, 2), &vectors);

        let bytes = centroids.to_bytes().unwrap();
        assert!(CoarseCentroids::from_bytes(&bytes[..bytes.len() - 1]).is_err());
    }
}