// hermes_core/structures/vector/ivf/coarse.rs
use std::io::{self, Cursor, Read, Write};
7use std::path::Path;
8
9use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
10use serde::{Deserialize, Serialize};
11
12use super::soar::{MultiAssignment, SoarConfig};
13
14const CENTROIDS_MAGIC: u32 = 0x48435643; #[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct CoarseConfig {
20 pub num_clusters: usize,
22 pub dim: usize,
24 pub max_iters: usize,
26 pub seed: u64,
28 pub soar: Option<SoarConfig>,
30}
31
32impl CoarseConfig {
33 pub fn new(dim: usize, num_clusters: usize) -> Self {
34 Self {
35 num_clusters,
36 dim,
37 max_iters: 25,
38 seed: 42,
39 soar: None,
40 }
41 }
42
43 pub fn with_soar(mut self, config: SoarConfig) -> Self {
44 self.soar = Some(config);
45 self
46 }
47
48 pub fn with_seed(mut self, seed: u64) -> Self {
49 self.seed = seed;
50 self
51 }
52
53 pub fn with_max_iters(mut self, iters: usize) -> Self {
54 self.max_iters = iters;
55 self
56 }
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct CoarseCentroids {
62 pub num_clusters: u32,
64 pub dim: usize,
66 pub centroids: Vec<f32>,
68 pub version: u64,
70 pub soar_config: Option<SoarConfig>,
72}
73
74impl CoarseCentroids {
#[cfg(feature = "native")]
/// Trains coarse centroids with the `kmeans` crate (Lloyd iterations with
/// k-means++ initialisation, Euclidean distance).
///
/// The number of clusters is `config.num_clusters` clamped to the number of
/// training vectors. The resulting `version` is the current wall-clock time in
/// milliseconds since the Unix epoch (0 if the clock is before the epoch).
///
/// # Panics
///
/// Panics if `vectors` is empty, if `config.num_clusters` or `config.dim` is
/// zero, or if any vector does not have exactly `config.dim` components.
pub fn train(config: &CoarseConfig, vectors: &[Vec<f32>]) -> Self {
    use kmeans::{EuclideanDistance, KMeans, KMeansConfig};

    assert!(!vectors.is_empty(), "Cannot train on empty vector set");
    assert!(config.num_clusters > 0, "Need at least 1 cluster");

    let dim = config.dim;
    assert!(dim > 0, "Vector dimension must be non-zero");
    // A ragged input would silently shift every later sample inside the
    // flattened buffer, so reject it up front.
    assert!(
        vectors.iter().all(|v| v.len() == dim),
        "Every training vector must have exactly `config.dim` components"
    );

    let actual_clusters = config.num_clusters.min(vectors.len());

    // Flatten the samples into one contiguous buffer for the k-means crate.
    let mut samples: Vec<f32> = Vec::with_capacity(vectors.len() * dim);
    for v in vectors {
        samples.extend_from_slice(v);
    }

    // NOTE(review): `config.seed` is not consulted on this path; the run uses
    // `KMeansConfig::default()`.
    let kmean: KMeans<f32, 8, _> = KMeans::new(&samples, vectors.len(), dim, EuclideanDistance);
    let result = kmean.kmeans_lloyd(
        actual_clusters,
        config.max_iters,
        KMeans::init_kmeanplusplus,
        &KMeansConfig::default(),
    );

    // Copy the centroids out into our own flat layout.
    let centroids: Vec<f32> = result
        .centroids
        .iter()
        .flat_map(|c| c.iter().copied())
        .collect();

    let version = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis() as u64;

    Self {
        num_clusters: actual_clusters as u32,
        dim,
        centroids,
        version,
        soar_config: config.soar.clone(),
    }
}
121
122 #[cfg(not(feature = "native"))]
124 pub fn train(config: &CoarseConfig, vectors: &[Vec<f32>]) -> Self {
125 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
126 assert!(config.num_clusters > 0, "Need at least 1 cluster");
127
128 let actual_clusters = config.num_clusters.min(vectors.len());
129 let dim = config.dim;
130 let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
131
132 let mut indices: Vec<usize> = (0..vectors.len()).collect();
134 indices.shuffle(&mut rng);
135
136 let mut centroids: Vec<f32> = indices[..actual_clusters]
137 .iter()
138 .flat_map(|&i| vectors[i].iter().copied())
139 .collect();
140
141 for _ in 0..config.max_iters {
143 let assignments: Vec<usize> = vectors
144 .iter()
145 .map(|v| Self::find_nearest_idx_static(v, ¢roids, dim))
146 .collect();
147
148 let mut new_centroids = vec![0.0f32; actual_clusters * dim];
149 let mut counts = vec![0usize; actual_clusters];
150
151 for (vec_idx, &cluster_id) in assignments.iter().enumerate() {
152 counts[cluster_id] += 1;
153 let offset = cluster_id * dim;
154 for (i, &val) in vectors[vec_idx].iter().enumerate() {
155 new_centroids[offset + i] += val;
156 }
157 }
158
159 for cluster_id in 0..actual_clusters {
160 if counts[cluster_id] > 0 {
161 let offset = cluster_id * dim;
162 for i in 0..dim {
163 new_centroids[offset + i] /= counts[cluster_id] as f32;
164 }
165 }
166 }
167
168 centroids = new_centroids;
169 }
170
171 let version = std::time::SystemTime::now()
172 .duration_since(std::time::UNIX_EPOCH)
173 .unwrap_or_default()
174 .as_millis() as u64;
175
176 Self {
177 num_clusters: actual_clusters as u32,
178 dim,
179 centroids,
180 version,
181 soar_config: config.soar.clone(),
182 }
183 }
184
/// Returns the index of the centroid closest to `vector` by squared Euclidean
/// distance.
///
/// `centroids` is a flat buffer of consecutive `dim`-sized centroids; a
/// trailing partial centroid is ignored. On a tie the lowest index wins.
/// Returns 0 when there are no centroids, when `dim` is 0, or when no distance
/// compares below `f32::MAX` (for example when every distance is NaN).
fn find_nearest_idx_static(vector: &[f32], centroids: &[f32], dim: usize) -> usize {
    // With no components there is nothing to compare; avoid dividing the
    // buffer into zero-sized chunks.
    if dim == 0 {
        return 0;
    }

    let mut best_idx = 0;
    let mut best_dist = f32::MAX;

    for (idx, centroid) in centroids.chunks_exact(dim).enumerate() {
        let dist: f32 = vector
            .iter()
            .zip(centroid)
            .map(|(&a, &b)| (a - b) * (a - b))
            .sum();

        if dist < best_dist {
            best_dist = dist;
            best_idx = idx;
        }
    }

    best_idx
}
207
208 pub fn find_nearest(&self, vector: &[f32]) -> u32 {
210 Self::find_nearest_idx_static(vector, &self.centroids, self.dim) as u32
211 }
212
213 pub fn find_k_nearest(&self, vector: &[f32], k: usize) -> Vec<u32> {
215 let mut distances: Vec<(u32, f32)> = (0..self.num_clusters)
216 .map(|c| {
217 let offset = c as usize * self.dim;
218 let dist: f32 = vector
219 .iter()
220 .zip(&self.centroids[offset..offset + self.dim])
221 .map(|(&a, &b)| (a - b) * (a - b))
222 .sum();
223 (c, dist)
224 })
225 .collect();
226
227 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
228 distances.truncate(k);
229 distances.into_iter().map(|(c, _)| c).collect()
230 }
231
232 pub fn find_k_nearest_with_distances(&self, vector: &[f32], k: usize) -> Vec<(u32, f32)> {
234 let mut distances: Vec<(u32, f32)> = (0..self.num_clusters)
235 .map(|c| {
236 let offset = c as usize * self.dim;
237 let dist: f32 = vector
238 .iter()
239 .zip(&self.centroids[offset..offset + self.dim])
240 .map(|(&a, &b)| (a - b) * (a - b))
241 .sum();
242 (c, dist)
243 })
244 .collect();
245
246 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
247 distances.truncate(k);
248 distances
249 }
250
251 pub fn assign(&self, vector: &[f32]) -> MultiAssignment {
253 if let Some(ref soar_config) = self.soar_config {
254 self.assign_with_soar(vector, soar_config)
255 } else {
256 MultiAssignment {
257 primary_cluster: self.find_nearest(vector),
258 secondary_clusters: Vec::new(),
259 }
260 }
261 }
262
263 pub fn assign_with_soar(&self, vector: &[f32], config: &SoarConfig) -> MultiAssignment {
265 let primary = self.find_nearest(vector);
267 let primary_centroid = self.get_centroid(primary);
268
269 let residual: Vec<f32> = vector
271 .iter()
272 .zip(primary_centroid)
273 .map(|(v, c)| v - c)
274 .collect();
275
276 let residual_norm_sq: f32 = residual.iter().map(|x| x * x).sum();
277
278 if config.selective && residual_norm_sq < config.spill_threshold * config.spill_threshold {
280 return MultiAssignment {
281 primary_cluster: primary,
282 secondary_clusters: Vec::new(),
283 };
284 }
285
286 let mut candidates: Vec<(u32, f32)> = (0..self.num_clusters)
288 .filter(|&c| c != primary)
289 .map(|c| {
290 let centroid = self.get_centroid(c);
291 let dot: f32 = vector
294 .iter()
295 .zip(centroid)
296 .zip(&residual)
297 .map(|((v, c), r)| (v - c) * r)
298 .sum();
299 (c, dot.abs())
300 })
301 .collect();
302
303 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
305
306 MultiAssignment {
307 primary_cluster: primary,
308 secondary_clusters: candidates
309 .iter()
310 .take(config.num_secondary)
311 .map(|(c, _)| *c)
312 .collect(),
313 }
314 }
315
316 pub fn get_centroid(&self, cluster_id: u32) -> &[f32] {
318 let offset = cluster_id as usize * self.dim;
319 &self.centroids[offset..offset + self.dim]
320 }
321
322 pub fn compute_residual(&self, vector: &[f32], cluster_id: u32) -> Vec<f32> {
324 let centroid = self.get_centroid(cluster_id);
325 vector.iter().zip(centroid).map(|(&v, &c)| v - c).collect()
326 }
327
328 pub fn save(&self, path: &Path) -> io::Result<()> {
330 let mut file = std::fs::File::create(path)?;
331 self.write_to(&mut file)
332 }
333
334 pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
336 writer.write_u32::<LittleEndian>(CENTROIDS_MAGIC)?;
337 writer.write_u32::<LittleEndian>(2)?; writer.write_u64::<LittleEndian>(self.version)?;
339 writer.write_u32::<LittleEndian>(self.num_clusters)?;
340 writer.write_u32::<LittleEndian>(self.dim as u32)?;
341
342 if let Some(ref soar) = self.soar_config {
344 writer.write_u8(1)?;
345 writer.write_u32::<LittleEndian>(soar.num_secondary as u32)?;
346 writer.write_u8(if soar.selective { 1 } else { 0 })?;
347 writer.write_f32::<LittleEndian>(soar.spill_threshold)?;
348 } else {
349 writer.write_u8(0)?;
350 }
351
352 for &val in &self.centroids {
353 writer.write_f32::<LittleEndian>(val)?;
354 }
355
356 Ok(())
357 }
358
359 pub fn load(path: &Path) -> io::Result<Self> {
361 let data = std::fs::read(path)?;
362 Self::read_from(&mut Cursor::new(data))
363 }
364
365 pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Self> {
367 let magic = reader.read_u32::<LittleEndian>()?;
368 if magic != CENTROIDS_MAGIC {
369 return Err(io::Error::new(
370 io::ErrorKind::InvalidData,
371 "Invalid centroids file magic",
372 ));
373 }
374
375 let file_version = reader.read_u32::<LittleEndian>()?;
376 let version = reader.read_u64::<LittleEndian>()?;
377 let num_clusters = reader.read_u32::<LittleEndian>()?;
378 let dim = reader.read_u32::<LittleEndian>()? as usize;
379
380 let soar_config = if file_version >= 2 {
382 let has_soar = reader.read_u8()? != 0;
383 if has_soar {
384 let num_secondary = reader.read_u32::<LittleEndian>()? as usize;
385 let selective = reader.read_u8()? != 0;
386 let spill_threshold = reader.read_f32::<LittleEndian>()?;
387 Some(SoarConfig {
388 num_secondary,
389 selective,
390 spill_threshold,
391 })
392 } else {
393 None
394 }
395 } else {
396 None
397 };
398
399 let mut centroids = vec![0.0f32; num_clusters as usize * dim];
400 for val in &mut centroids {
401 *val = reader.read_f32::<LittleEndian>()?;
402 }
403
404 Ok(Self {
405 num_clusters,
406 dim,
407 centroids,
408 version,
409 soar_config,
410 })
411 }
412
413 pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
415 let mut buf = Vec::new();
416 self.write_to(&mut buf)?;
417 Ok(buf)
418 }
419
420 pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
422 Self::read_from(&mut Cursor::new(data))
423 }
424
425 pub fn size_bytes(&self) -> usize {
427 self.centroids.len() * 4 + 64 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Generates `n` reproducible vectors of `dim` components, each drawn as
    /// `rng.random::<f32>() - shift` from an RNG seeded with `seed`.
    fn random_vectors(seed: u64, n: usize, dim: usize, shift: f32) -> Vec<Vec<f32>> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.random::<f32>() - shift).collect())
            .collect()
    }

    /// Training yields the requested number of clusters and dimension.
    #[test]
    fn test_coarse_centroids_basic() {
        let dim = 64;
        let num_clusters = 16;
        let vectors = random_vectors(42, 1000, dim, 0.5);

        let config = CoarseConfig::new(dim, num_clusters);
        let centroids = CoarseCentroids::train(&config, &vectors);

        assert_eq!(centroids.num_clusters, num_clusters as u32);
        assert_eq!(centroids.dim, dim);
    }

    /// Every training vector maps to a valid cluster id.
    #[test]
    fn test_find_nearest() {
        let dim = 32;
        let vectors = random_vectors(123, 500, dim, 0.0);

        let config = CoarseConfig::new(dim, 8);
        let centroids = CoarseCentroids::train(&config, &vectors);

        for v in &vectors {
            let cluster = centroids.find_nearest(v);
            assert!(cluster < centroids.num_clusters);
        }
    }

    /// SOAR assignment returns the requested number of secondaries, none of
    /// which is the primary cluster.
    #[test]
    fn test_soar_assignment() {
        let dim = 32;
        let vectors = random_vectors(456, 100, dim, 0.0);

        let soar_config = SoarConfig {
            num_secondary: 2,
            selective: false,
            spill_threshold: 0.0,
        };
        let config = CoarseConfig::new(dim, 8).with_soar(soar_config);
        let centroids = CoarseCentroids::train(&config, &vectors);

        let assignment = centroids.assign(&vectors[0]);
        assert!(assignment.primary_cluster < centroids.num_clusters);
        assert_eq!(assignment.secondary_clusters.len(), 2);

        for &sec in &assignment.secondary_clusters {
            assert_ne!(sec, assignment.primary_cluster);
        }
    }

    /// A byte round trip preserves the header fields and every centroid value.
    #[test]
    fn test_serialization() {
        let dim = 16;
        let vectors = random_vectors(789, 50, dim, 0.0);

        let config = CoarseConfig::new(dim, 4);
        let centroids = CoarseCentroids::train(&config, &vectors);

        let bytes = centroids.to_bytes().unwrap();
        let loaded = CoarseCentroids::from_bytes(&bytes).unwrap();

        assert_eq!(loaded.num_clusters, centroids.num_clusters);
        assert_eq!(loaded.dim, centroids.dim);
        assert_eq!(loaded.version, centroids.version);
        assert_eq!(loaded.centroids.len(), centroids.centroids.len());
        assert!(loaded.soar_config.is_none());

        // Compare bit patterns so the check is exact for every float value.
        let original_bits: Vec<u32> = centroids.centroids.iter().map(|x| x.to_bits()).collect();
        let loaded_bits: Vec<u32> = loaded.centroids.iter().map(|x| x.to_bits()).collect();
        assert_eq!(loaded_bits, original_bits);
    }

    /// The optional SOAR section survives a byte round trip.
    #[test]
    fn test_serialization_preserves_soar_config() {
        let original = CoarseCentroids {
            num_clusters: 2,
            dim: 2,
            centroids: vec![0.0, 1.0, 2.0, 3.0],
            version: 7,
            soar_config: Some(SoarConfig {
                num_secondary: 3,
                selective: true,
                spill_threshold: 0.25,
            }),
        };

        let bytes = original.to_bytes().unwrap();
        let loaded = CoarseCentroids::from_bytes(&bytes).unwrap();

        let soar = loaded.soar_config.expect("SOAR config should round-trip");
        assert_eq!(soar.num_secondary, 3);
        assert!(soar.selective);
        assert_eq!(soar.spill_threshold, 0.25);
        assert_eq!(loaded.version, 7);
        assert_eq!(loaded.centroids, original.centroids);
    }

    /// A buffer that does not start with the magic number is rejected.
    #[test]
    fn test_rejects_bad_magic() {
        assert!(CoarseCentroids::from_bytes(&[0u8; 32]).is_err());
    }
}