1use serde::{Deserialize, Serialize};
13
14use ailake_core::{AilakeError, AilakeResult, RowId, VectorMetric};
15use ailake_vec::{kmeans_centroids, PQCodebook};
16
17fn kmeans_dispatch(vecs: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
19 if let Some(result) = crate::gpu::try_nvidia_kmeans(vecs, k, max_iter) {
20 return result;
21 }
22 if let Some(result) = crate::gpu::try_rocm_kmeans(vecs, k, max_iter) {
23 return result;
24 }
25 kmeans_centroids(vecs, k, max_iter)
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct IvfPqConfig {
31 pub nlist: usize,
34 pub nprobe: usize,
37 pub pq_m: usize,
39 pub pq_k: usize,
41 pub max_iter: usize,
43}
44
45impl Default for IvfPqConfig {
46 fn default() -> Self {
47 Self {
48 nlist: 256,
49 nprobe: 8,
50 pq_m: 8,
51 pq_k: 256,
52 max_iter: 25,
53 }
54 }
55}
56
57impl IvfPqConfig {
58 pub fn for_dim(dim: usize) -> Self {
60 let pq_m = (dim / 16).clamp(4, 64);
61 Self {
62 pq_m: find_valid_pq_m(pq_m, dim),
63 ..Self::default()
64 }
65 }
66
67 pub fn for_dataset(dim: usize, n_vectors: usize) -> Self {
72 let nlist = ((n_vectors as f64).sqrt() as usize).clamp(16, 1024);
73 let nprobe = (nlist / 8).max(1);
74 let pq_m_hint = (dim / 16).clamp(4, 64);
75 Self {
76 nlist,
77 nprobe,
78 pq_m: find_valid_pq_m(pq_m_hint, dim),
79 pq_k: 256,
80 max_iter: 25,
81 }
82 }
83}
84
85pub struct IvfPqIndex {
86 pub config: IvfPqConfig,
87 pub metric: VectorMetric,
88 pub dim: usize,
89 coarse_centroids: Vec<Vec<f32>>,
91 pq: PQCodebook,
93 inv_row_ids: Vec<Vec<u64>>,
95 inv_codes: Vec<Vec<u8>>,
97}
98
99impl IvfPqIndex {
100 pub fn train(
102 row_ids: &[RowId],
103 vectors: &[Vec<f32>],
104 metric: VectorMetric,
105 config: IvfPqConfig,
106 ) -> AilakeResult<Self> {
107 let n = vectors.len();
108 if n == 0 {
109 return Err(AilakeError::Catalog(
110 "IVF-PQ training requires at least 1 vector".into(),
111 ));
112 }
113 let dim = vectors[0].len();
114
115 let normed_storage: Vec<Vec<f32>>;
116 let vecs: &[Vec<f32>] = if metric == VectorMetric::Cosine {
117 normed_storage = vectors.iter().map(|v| l2_normalize(v)).collect();
118 &normed_storage
119 } else {
120 vectors
121 };
122
123 let nlist = config.nlist.min(n);
124 let nprobe = config.nprobe.min(nlist);
125 let pq_m = find_valid_pq_m(config.pq_m, dim);
126
127 let coarse_centroids = kmeans_dispatch(vecs, nlist, config.max_iter);
129
130 let assignments: Vec<usize> = vecs
132 .iter()
133 .map(|v| nearest_idx(v, &coarse_centroids))
134 .collect();
135
136 let pq = PQCodebook::train_with_kmeans(
138 vecs,
139 pq_m,
140 config.pq_k.min(256),
141 config.max_iter,
142 kmeans_dispatch,
143 )
144 .map_err(|e| AilakeError::Catalog(format!("PQ training failed: {e}")))?;
145
146 let mut inv_row_ids = vec![Vec::new(); nlist];
148 let mut inv_codes = vec![Vec::new(); nlist];
149
150 for (i, (v, &list_idx)) in vecs.iter().zip(assignments.iter()).enumerate() {
151 let codes = pq.encode(v);
152 inv_row_ids[list_idx].push(row_ids[i].0);
153 inv_codes[list_idx].extend_from_slice(&codes);
154 }
155
156 Ok(IvfPqIndex {
157 config: IvfPqConfig {
158 nlist,
159 nprobe,
160 pq_m,
161 ..config
162 },
163 metric,
164 dim,
165 coarse_centroids,
166 pq,
167 inv_row_ids,
168 inv_codes,
169 })
170 }
171
172 pub fn search(&self, query: &[f32], top_k: usize, nprobe: Option<usize>) -> Vec<(RowId, f32)> {
176 let nprobe = nprobe.unwrap_or(self.config.nprobe).min(self.config.nlist);
177
178 let q_normed: Vec<f32>;
179 let q: &[f32] = if self.metric == VectorMetric::Cosine {
180 q_normed = l2_normalize(query);
181 &q_normed
182 } else {
183 query
184 };
185
186 let mut c_dists: Vec<(usize, f32)> = self
188 .coarse_centroids
189 .iter()
190 .enumerate()
191 .map(|(i, c)| (i, l2_sq(q, c)))
192 .collect();
193 c_dists.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
194 c_dists.truncate(nprobe);
195
196 let adc_table = self.pq.compute_adc_table(q);
198
199 let pq_m = self.config.pq_m;
201 let mut candidates: Vec<(RowId, f32)> = Vec::new();
202
203 for (list_idx, _) in &c_dists {
204 let row_ids = &self.inv_row_ids[*list_idx];
205 let codes_flat = &self.inv_codes[*list_idx];
206
207 for (j, &rid) in row_ids.iter().enumerate() {
208 let codes = &codes_flat[j * pq_m..(j + 1) * pq_m];
209 let dist = self.pq.adc_distance(codes, &adc_table);
210 candidates.push((RowId(rid), dist));
211 }
212 }
213
214 candidates.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
215 candidates.truncate(top_k);
216 candidates
217 }
218
219 pub fn node_count(&self) -> u64 {
220 self.inv_row_ids.iter().map(|l| l.len() as u64).sum()
221 }
222
223 pub fn dim(&self) -> usize {
224 self.dim
225 }
226}
227
228#[derive(Serialize, Deserialize)]
231struct IvfPqSnapshot {
232 nlist: usize,
233 nprobe: usize,
234 pq_m: usize,
235 pq_k: usize,
236 max_iter: usize,
237 dim: usize,
238 metric: u8,
239 coarse_flat: Vec<f32>, pq: PQCodebook,
241 inv_row_ids: Vec<Vec<u64>>,
242 inv_codes: Vec<Vec<u8>>, }
244
245pub struct IvfPqSerializer;
246
247impl IvfPqSerializer {
248 pub fn to_bytes(index: &IvfPqIndex) -> AilakeResult<Vec<u8>> {
249 let coarse_flat: Vec<f32> = index
250 .coarse_centroids
251 .iter()
252 .flat_map(|c| c.iter().copied())
253 .collect();
254 let snap = IvfPqSnapshot {
255 nlist: index.config.nlist,
256 nprobe: index.config.nprobe,
257 pq_m: index.config.pq_m,
258 pq_k: index.config.pq_k,
259 max_iter: index.config.max_iter,
260 dim: index.dim,
261 metric: metric_to_u8(index.metric),
262 coarse_flat,
263 pq: index.pq.clone(),
264 inv_row_ids: index.inv_row_ids.clone(),
265 inv_codes: index.inv_codes.clone(),
266 };
267 bincode::serialize(&snap).map_err(|e| AilakeError::Bincode(e.to_string()))
268 }
269
270 pub fn from_bytes(bytes: &[u8]) -> AilakeResult<IvfPqIndex> {
271 let snap: IvfPqSnapshot =
272 bincode::deserialize(bytes).map_err(|e| AilakeError::Bincode(e.to_string()))?;
273 let metric = u8_to_metric(snap.metric)?;
274 let coarse_centroids: Vec<Vec<f32>> = snap
275 .coarse_flat
276 .chunks_exact(snap.dim)
277 .map(|c| c.to_vec())
278 .collect();
279 Ok(IvfPqIndex {
280 config: IvfPqConfig {
281 nlist: snap.nlist,
282 nprobe: snap.nprobe,
283 pq_m: snap.pq_m,
284 pq_k: snap.pq_k,
285 max_iter: snap.max_iter,
286 },
287 metric,
288 dim: snap.dim,
289 coarse_centroids,
290 pq: snap.pq,
291 inv_row_ids: snap.inv_row_ids,
292 inv_codes: snap.inv_codes,
293 })
294 }
295}
296
/// Scales `v` to unit L2 norm.
///
/// A vector whose norm is below `1e-9` (including the zero vector) is returned as an
/// unchanged copy, rather than being divided by a near-zero value.
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let squared_norm: f32 = v.iter().map(|x| x * x).sum();
    let norm = squared_norm.sqrt();
    if norm < 1e-9 {
        return v.to_vec();
    }
    v.iter().map(|&x| x / norm).collect()
}
307
/// Squared Euclidean distance between `a` and `b`.
///
/// Only the first `min(a.len(), b.len())` components are compared.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff.powi(2)
        })
        .sum()
}
311
312fn nearest_idx(v: &[f32], centroids: &[Vec<f32>]) -> usize {
313 centroids
314 .iter()
315 .enumerate()
316 .map(|(i, c)| (i, l2_sq(v, c)))
317 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
318 .map(|(i, _)| i)
319 .unwrap_or(0)
320}
321
/// Returns the largest `m` in `1..=requested` that evenly divides `dim`.
///
/// Returns `1` when `requested` is `0`. Otherwise `m = 1` always qualifies, so a divisor
/// is always found.
pub fn find_valid_pq_m(requested: usize, dim: usize) -> usize {
    (1..=requested)
        .rev()
        .find(|&m| dim.is_multiple_of(m))
        .unwrap_or(1)
}
331
332fn metric_to_u8(m: VectorMetric) -> u8 {
333 match m {
334 VectorMetric::Cosine => 0,
335 VectorMetric::Euclidean => 1,
336 VectorMetric::DotProduct => 2,
337 }
338}
339
340fn u8_to_metric(v: u8) -> AilakeResult<VectorMetric> {
341 match v {
342 0 => Ok(VectorMetric::Cosine),
343 1 => Ok(VectorMetric::Euclidean),
344 2 => Ok(VectorMetric::DotProduct),
345 _ => Err(AilakeError::Catalog(format!("unknown metric byte: {v}"))),
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` one-hot vectors of length `dim` (vector `i` is hot at `i % dim`), paired with
    /// row ids `0..n`.
    fn make_vecs(n: usize, dim: usize) -> (Vec<RowId>, Vec<Vec<f32>>) {
        let row_ids = (0..n as u64).map(RowId).collect();
        let vecs = (0..n)
            .map(|i| {
                let mut one_hot = vec![0.0f32; dim];
                one_hot[i % dim] = 1.0;
                one_hot
            })
            .collect();
        (row_ids, vecs)
    }

    /// Small configuration shared by every test: 2 sub-quantizers with 4 centroids each.
    fn small_config(nlist: usize, nprobe: usize, max_iter: usize) -> IvfPqConfig {
        IvfPqConfig {
            nlist,
            nprobe,
            pq_m: 2,
            pq_k: 4,
            max_iter,
        }
    }

    #[test]
    fn train_and_search_basic() {
        let (ids, vecs) = make_vecs(64, 8);
        let config = small_config(4, 2, 10);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();
        assert_eq!(idx.node_count(), 64);

        let results = idx.search(&vecs[0], 5, None);
        assert!(!results.is_empty());
        assert!(results[0].1 < 0.1, "nearest should be approximate self");
    }

    #[test]
    fn train_cosine_normalizes() {
        let (ids, vecs) = make_vecs(32, 4);
        let config = small_config(4, 2, 10);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Cosine, config).unwrap();

        let results = idx.search(&vecs[0], 1, None);
        assert!(!results.is_empty());
    }

    #[test]
    fn serialize_roundtrip() {
        let (ids, vecs) = make_vecs(32, 8);
        let config = small_config(4, 2, 10);
        let original = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();

        let bytes = IvfPqSerializer::to_bytes(&original).unwrap();
        let restored = IvfPqSerializer::from_bytes(&bytes).unwrap();
        assert_eq!(restored.node_count(), original.node_count());
        assert_eq!(restored.dim(), original.dim());

        let query = &vecs[0];
        let before = original.search(query, 5, None);
        let after = restored.search(query, 5, None);
        assert_eq!(before.len(), after.len());
        for (a, b) in before.iter().zip(&after) {
            assert_eq!(a.0, b.0, "row_ids should match after roundtrip");
        }
    }

    #[test]
    fn nlist_clamped_to_n() {
        let (ids, vecs) = make_vecs(10, 4);
        let config = small_config(256, 8, 5);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();

        assert!(idx.config.nlist <= 10);
        assert_eq!(idx.node_count(), 10);
    }
}