1use serde::{Deserialize, Serialize};
14
15use ailake_core::{AilakeError, AilakeResult, RowId, VectorMetric};
16use ailake_vec::{kmeans_centroids, PQCodebook};
17
18fn kmeans_dispatch(vecs: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
20 if let Some(result) = crate::gpu::try_nvidia_kmeans(vecs, k, max_iter) {
21 return result;
22 }
23 if let Some(result) = crate::gpu::try_rocm_kmeans(vecs, k, max_iter) {
24 return result;
25 }
26 kmeans_centroids(vecs, k, max_iter)
27}
28
/// Tuning parameters for an [`IvfPqIndex`].
///
/// `IvfPqIndex::train` treats these values as requests and stores the effective
/// (clamped / adjusted) values in the trained index's `config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IvfPqConfig {
    /// Number of coarse k-means clusters, i.e. inverted lists. Training caps this at
    /// the number of training vectors.
    pub nlist: usize,
    /// Number of inverted lists scanned per query when `IvfPqIndex::search` is called
    /// with `nprobe: None`. It is capped at `nlist`.
    pub nprobe: usize,
    /// Number of `u8` PQ codes stored per vector (PQ sub-spaces). Training rounds it
    /// down with `find_valid_pq_m` to a divisor of the vector dimension.
    pub pq_m: usize,
    /// Number of centroids per PQ codebook. Training caps it at 256 because codes
    /// are stored as `u8`.
    pub pq_k: usize,
    /// Iteration cap passed to k-means, for both coarse and PQ training.
    pub max_iter: usize,
}
45
46impl Default for IvfPqConfig {
47 fn default() -> Self {
48 Self {
49 nlist: 256,
50 nprobe: 8,
51 pq_m: 8,
52 pq_k: 256,
53 max_iter: 25,
54 }
55 }
56}
57
58impl IvfPqConfig {
59 pub fn for_dim(dim: usize) -> Self {
61 let pq_m = (dim / 16).clamp(4, 64);
62 Self {
63 pq_m: find_valid_pq_m(pq_m, dim),
64 ..Self::default()
65 }
66 }
67
68 pub fn for_dataset(dim: usize, n_vectors: usize) -> Self {
73 let nlist = ((n_vectors as f64).sqrt() as usize).clamp(16, 1024);
74 let nprobe = (nlist / 8).max(1);
75 let pq_m_hint = (dim / 16).clamp(4, 64);
76 Self {
77 nlist,
78 nprobe,
79 pq_m: find_valid_pq_m(pq_m_hint, dim),
80 pq_k: 256,
81 max_iter: 25,
82 }
83 }
84}
85
/// Inverted-file index whose stored vectors are compressed into product-quantisation codes.
///
/// Built by `IvfPqIndex::train` and queried with `IvfPqIndex::search`.
pub struct IvfPqIndex {
    /// Effective parameters, as adjusted by `train` from the requested config.
    pub config: IvfPqConfig,
    /// Metric requested at training time. `Cosine` makes `train` and `search`
    /// L2-normalise vectors and queries.
    pub metric: VectorMetric,
    /// Length of the training vectors (taken from the first training vector).
    pub dim: usize,
    /// Coarse k-means centroids; `coarse_centroids[i]` is the centroid of inverted list `i`.
    coarse_centroids: Vec<Vec<f32>>,
    /// Product quantiser used to encode stored vectors and to build per-query ADC tables.
    pq: PQCodebook,
    /// `inv_row_ids[list]` holds the raw row-id values assigned to that list, in insertion order.
    inv_row_ids: Vec<Vec<u64>>,
    /// `inv_codes[list]` holds the PQ codes for the same rows, concatenated.
    /// Row `j`'s codes occupy `[j * pq_m, (j + 1) * pq_m)`.
    inv_codes: Vec<Vec<u8>>,
}
99
100impl IvfPqIndex {
101 pub fn train(
103 row_ids: &[RowId],
104 vectors: &[Vec<f32>],
105 metric: VectorMetric,
106 config: IvfPqConfig,
107 ) -> AilakeResult<Self> {
108 let n = vectors.len();
109 if n == 0 {
110 return Err(AilakeError::Catalog(
111 "IVF-PQ training requires at least 1 vector".into(),
112 ));
113 }
114 let dim = vectors[0].len();
115
116 let normed_storage: Vec<Vec<f32>>;
117 let vecs: &[Vec<f32>] = if metric == VectorMetric::Cosine {
118 normed_storage = vectors.iter().map(|v| l2_normalize(v)).collect();
119 &normed_storage
120 } else {
121 vectors
122 };
123
124 let nlist = config.nlist.min(n);
125 let nprobe = config.nprobe.min(nlist);
126 let pq_m = find_valid_pq_m(config.pq_m, dim);
127
128 let coarse_centroids = kmeans_dispatch(vecs, nlist, config.max_iter);
130
131 let assignments: Vec<usize> = vecs
133 .iter()
134 .map(|v| nearest_idx(v, &coarse_centroids))
135 .collect();
136
137 let pq = PQCodebook::train_with_kmeans(
139 vecs,
140 pq_m,
141 config.pq_k.min(256),
142 config.max_iter,
143 kmeans_dispatch,
144 )
145 .map_err(|e| AilakeError::Catalog(format!("PQ training failed: {e}")))?;
146
147 let mut inv_row_ids = vec![Vec::new(); nlist];
149 let mut inv_codes = vec![Vec::new(); nlist];
150
151 for (i, (v, &list_idx)) in vecs.iter().zip(assignments.iter()).enumerate() {
152 let codes = pq.encode(v);
153 inv_row_ids[list_idx].push(row_ids[i].0);
154 inv_codes[list_idx].extend_from_slice(&codes);
155 }
156
157 Ok(IvfPqIndex {
158 config: IvfPqConfig {
159 nlist,
160 nprobe,
161 pq_m,
162 ..config
163 },
164 metric,
165 dim,
166 coarse_centroids,
167 pq,
168 inv_row_ids,
169 inv_codes,
170 })
171 }
172
173 pub fn search(&self, query: &[f32], top_k: usize, nprobe: Option<usize>) -> Vec<(RowId, f32)> {
177 let nprobe = nprobe.unwrap_or(self.config.nprobe).min(self.config.nlist);
178
179 let q_normed: Vec<f32>;
180 let q: &[f32] = if self.metric == VectorMetric::Cosine {
181 q_normed = l2_normalize(query);
182 &q_normed
183 } else {
184 query
185 };
186
187 let mut c_dists: Vec<(usize, f32)> = self
189 .coarse_centroids
190 .iter()
191 .enumerate()
192 .map(|(i, c)| (i, l2_sq(q, c)))
193 .collect();
194 c_dists.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
195 c_dists.truncate(nprobe);
196
197 let adc_table = self.pq.compute_adc_table(q);
199
200 let pq_m = self.config.pq_m;
202 let mut candidates: Vec<(RowId, f32)> = Vec::new();
203
204 for (list_idx, _) in &c_dists {
205 let row_ids = &self.inv_row_ids[*list_idx];
206 let codes_flat = &self.inv_codes[*list_idx];
207
208 for (j, &rid) in row_ids.iter().enumerate() {
209 let codes = &codes_flat[j * pq_m..(j + 1) * pq_m];
210 let dist = self.pq.adc_distance(codes, &adc_table);
211 candidates.push((RowId(rid), dist));
212 }
213 }
214
215 candidates.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
216 candidates.truncate(top_k);
217 candidates
218 }
219
220 pub fn node_count(&self) -> u64 {
221 self.inv_row_ids.iter().map(|l| l.len() as u64).sum()
222 }
223
224 pub fn dim(&self) -> usize {
225 self.dim
226 }
227}
228
/// Flat, serialisable mirror of an `IvfPqIndex`, written and read by `IvfPqSerializer`.
///
/// Serde/bincode encode struct fields in declaration order. Reordering, adding or
/// removing fields therefore changes the byte layout and breaks existing snapshots.
#[derive(Serialize, Deserialize)]
struct IvfPqSnapshot {
    nlist: usize,
    nprobe: usize,
    pq_m: usize,
    pq_k: usize,
    max_iter: usize,
    /// Vector dimensionality; also the chunk size used to rebuild centroids from `coarse_flat`.
    dim: usize,
    /// Metric tag as produced by `metric_to_u8`.
    metric: u8,
    /// All coarse centroids concatenated, `dim` floats each.
    coarse_flat: Vec<f32>,
    pq: PQCodebook,
    inv_row_ids: Vec<Vec<u64>>,
    inv_codes: Vec<Vec<u8>>,
}
245
/// Converts an `IvfPqIndex` to and from bincode bytes.
pub struct IvfPqSerializer;
247
248impl IvfPqSerializer {
249 pub fn to_bytes(index: &IvfPqIndex) -> AilakeResult<Vec<u8>> {
250 let coarse_flat: Vec<f32> = index
251 .coarse_centroids
252 .iter()
253 .flat_map(|c| c.iter().copied())
254 .collect();
255 let snap = IvfPqSnapshot {
256 nlist: index.config.nlist,
257 nprobe: index.config.nprobe,
258 pq_m: index.config.pq_m,
259 pq_k: index.config.pq_k,
260 max_iter: index.config.max_iter,
261 dim: index.dim,
262 metric: metric_to_u8(index.metric),
263 coarse_flat,
264 pq: index.pq.clone(),
265 inv_row_ids: index.inv_row_ids.clone(),
266 inv_codes: index.inv_codes.clone(),
267 };
268 bincode::serialize(&snap).map_err(|e| AilakeError::Bincode(e.to_string()))
269 }
270
271 pub fn from_bytes(bytes: &[u8]) -> AilakeResult<IvfPqIndex> {
272 let snap: IvfPqSnapshot =
273 bincode::deserialize(bytes).map_err(|e| AilakeError::Bincode(e.to_string()))?;
274 let metric = u8_to_metric(snap.metric)?;
275 let coarse_centroids: Vec<Vec<f32>> = snap
276 .coarse_flat
277 .chunks_exact(snap.dim)
278 .map(|c| c.to_vec())
279 .collect();
280 Ok(IvfPqIndex {
281 config: IvfPqConfig {
282 nlist: snap.nlist,
283 nprobe: snap.nprobe,
284 pq_m: snap.pq_m,
285 pq_k: snap.pq_k,
286 max_iter: snap.max_iter,
287 },
288 metric,
289 dim: snap.dim,
290 coarse_centroids,
291 pq: snap.pq,
292 inv_row_ids: snap.inv_row_ids,
293 inv_codes: snap.inv_codes,
294 })
295 }
296}
297
/// Returns a copy of `v` scaled to unit Euclidean length.
///
/// If the norm is below `1e-9` (e.g. the all-zero vector), `v` is returned unchanged
/// instead of being divided by a near-zero value. A NaN norm fails that test, so such
/// input is still divided.
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let sum_of_squares: f32 = v.iter().map(|x| x * x).sum();
    let length = sum_of_squares.sqrt();
    if length < 1e-9 {
        return v.to_vec();
    }
    v.iter().map(|&x| x / length).collect()
}
308
/// Squared Euclidean distance between `a` and `b`.
///
/// Only the first `min(a.len(), b.len())` coordinates are compared, because the
/// slices are `zip`ped.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum()
}

/// Index of the centroid closest to `v` by squared L2 distance.
///
/// On ties, the first minimal centroid wins. Returns 0 when `centroids` is empty.
/// Distances are compared with `f32::total_cmp`, so a NaN coordinate in `v` yields
/// some valid index instead of panicking as `partial_cmp().unwrap()` did.
fn nearest_idx(v: &[f32], centroids: &[Vec<f32>]) -> usize {
    centroids
        .iter()
        .enumerate()
        .map(|(i, c)| (i, l2_sq(v, c)))
        .min_by(|a, b| a.1.total_cmp(&b.1))
        .map(|(i, _)| i)
        .unwrap_or(0)
}
322
/// Largest `m` in `1..=requested` that evenly divides `dim`.
///
/// Returns 1 when `requested` is 0; 1 divides every `dim`, so the search otherwise
/// always succeeds. For `dim == 0`, every `m` divides it, so `requested` itself is
/// returned.
pub fn find_valid_pq_m(requested: usize, dim: usize) -> usize {
    (1..=requested)
        .rev()
        .find(|&candidate| dim % candidate == 0)
        .unwrap_or(1)
}
332
333fn metric_to_u8(m: VectorMetric) -> u8 {
334 match m {
335 VectorMetric::Cosine => 0,
336 VectorMetric::Euclidean => 1,
337 VectorMetric::DotProduct => 2,
338 }
339}
340
341fn u8_to_metric(v: u8) -> AilakeResult<VectorMetric> {
342 match v {
343 0 => Ok(VectorMetric::Cosine),
344 1 => Ok(VectorMetric::Euclidean),
345 2 => Ok(VectorMetric::DotProduct),
346 _ => Err(AilakeError::Catalog(format!("unknown metric byte: {v}"))),
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Row ids `0..n` plus `n` one-hot vectors of length `dim`; vector `i` is hot at
    /// coordinate `i % dim`.
    fn make_vecs(n: usize, dim: usize) -> (Vec<RowId>, Vec<Vec<f32>>) {
        let row_ids: Vec<RowId> = (0..n as u64).map(RowId).collect();
        let one_hot = |i: usize| {
            let mut v = vec![0.0f32; dim];
            v[i % dim] = 1.0;
            v
        };
        (row_ids, (0..n).map(one_hot).collect())
    }

    /// Small config shared by every test: 2 PQ codes per vector from 4-entry codebooks.
    fn tiny_config(nlist: usize, nprobe: usize, max_iter: usize) -> IvfPqConfig {
        IvfPqConfig {
            nlist,
            nprobe,
            pq_m: 2,
            pq_k: 4,
            max_iter,
        }
    }

    #[test]
    fn train_and_search_basic() {
        let (ids, vecs) = make_vecs(64, 8);
        let idx =
            IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, tiny_config(4, 2, 10))
                .unwrap();
        assert_eq!(idx.node_count(), 64);

        let results = idx.search(&vecs[0], 5, None);
        assert!(!results.is_empty());
        assert!(results[0].1 < 0.1, "nearest should be approximate self");
    }

    #[test]
    fn train_cosine_normalizes() {
        let (ids, vecs) = make_vecs(32, 4);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Cosine, tiny_config(4, 2, 10))
            .unwrap();
        assert!(!idx.search(&vecs[0], 1, None).is_empty());
    }

    #[test]
    fn serialize_roundtrip() {
        let (ids, vecs) = make_vecs(32, 8);
        let original =
            IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, tiny_config(4, 2, 10))
                .unwrap();
        let bytes = IvfPqSerializer::to_bytes(&original).unwrap();
        let restored = IvfPqSerializer::from_bytes(&bytes).unwrap();

        assert_eq!(restored.node_count(), original.node_count());
        assert_eq!(restored.dim(), original.dim());

        let query = &vecs[0];
        let before = original.search(query, 5, None);
        let after = restored.search(query, 5, None);
        assert_eq!(before.len(), after.len());
        for (a, b) in before.iter().zip(&after) {
            assert_eq!(a.0, b.0, "row_ids should match after roundtrip");
        }
    }

    #[test]
    fn nlist_clamped_to_n() {
        let (ids, vecs) = make_vecs(10, 4);
        let idx =
            IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, tiny_config(256, 8, 5))
                .unwrap();
        assert!(idx.config.nlist <= 10);
        assert_eq!(idx.node_count(), 10);
    }
}