Skip to main content

oxirs_vec/
ivfpq_index.rs

// Inverted File Index with Product Quantization (IVF-PQ) compound index
// Added in v1.1.0 Round 7
3
/// Configuration for an IVF-PQ index.
///
/// The constraints noted on each field are the ones enforced by
/// [`IvfPqConfig::validate`]; constructing the struct directly performs no
/// checking.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IvfPqConfig {
    /// Number of coarse clusters (IVF lists). Must be greater than zero.
    pub nlist: usize,
    /// Number of PQ sub-quantizers. Must be greater than zero and divide
    /// `dimension` evenly.
    pub m: usize,
    /// Number of centroids per sub-quantizer. Must be greater than zero.
    ///
    /// PQ codes are stored as one `u8` per sub-quantizer, so a single code
    /// can address at most 256 distinct centroids.
    pub k_per_sub: usize,
    /// Number of coarse clusters to probe at query time. Must be in
    /// `1..=nlist`.
    pub nprobe: usize,
    /// Vector dimension. Must be greater than zero.
    pub dimension: usize,
}
18
19impl IvfPqConfig {
20    /// Validate configuration parameters.
21    pub fn validate(&self) -> Result<(), IvfPqError> {
22        if self.m == 0 {
23            return Err(IvfPqError::InvalidConfig("m must be > 0".to_string()));
24        }
25        if self.dimension == 0 {
26            return Err(IvfPqError::InvalidConfig(
27                "dimension must be > 0".to_string(),
28            ));
29        }
30        if self.dimension % self.m != 0 {
31            return Err(IvfPqError::InvalidConfig(format!(
32                "dimension ({}) must be divisible by m ({})",
33                self.dimension, self.m
34            )));
35        }
36        if self.nlist == 0 {
37            return Err(IvfPqError::InvalidConfig("nlist must be > 0".to_string()));
38        }
39        if self.nprobe == 0 {
40            return Err(IvfPqError::InvalidConfig("nprobe must be > 0".to_string()));
41        }
42        if self.nprobe > self.nlist {
43            return Err(IvfPqError::InvalidConfig(format!(
44                "nprobe ({}) must be <= nlist ({})",
45                self.nprobe, self.nlist
46            )));
47        }
48        if self.k_per_sub == 0 {
49            return Err(IvfPqError::InvalidConfig(
50                "k_per_sub must be > 0".to_string(),
51            ));
52        }
53        Ok(())
54    }
55}
56
/// Errors from IVF-PQ operations.
#[derive(Debug)]
pub enum IvfPqError {
    /// A vector's length differs from the dimension the index was configured with.
    DimensionMismatch { expected: usize, got: usize },
    /// An operation that requires a trained index was called before training.
    NotTrained,
    /// The configuration was rejected; the payload explains which rule failed.
    InvalidConfig(String),
    /// Training was attempted with too little data; the payload explains why.
    InsufficientData(String),
}

impl std::fmt::Display for IvfPqError {
    /// Render a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotTrained => f.write_str("Index has not been trained yet"),
            Self::InvalidConfig(reason) => write!(f, "Invalid config: {reason}"),
            Self::InsufficientData(reason) => write!(f, "Insufficient data: {reason}"),
            Self::DimensionMismatch { expected, got } => {
                write!(f, "Dimension mismatch: expected {expected}, got {got}")
            }
        }
    }
}

impl std::error::Error for IvfPqError {}
80
81/// IVF-PQ approximate nearest neighbor index.
82pub struct IvfPqIndex {
83    config: IvfPqConfig,
84    /// Coarse centroids: nlist vectors of size `dimension`.
85    coarse_centroids: Vec<Vec<f64>>,
86    /// Per-cluster inverted lists: (vector_id, pq_codes).
87    inverted_lists: Vec<Vec<(u64, Vec<u8>)>>,
88    /// PQ codebook: m sub-quantizers, each has k_per_sub × (dimension/m) centroids.
89    pq_codebook: Vec<Vec<Vec<f64>>>,
90    is_trained: bool,
91    next_id: u64,
92}
93
94impl IvfPqIndex {
95    /// Create a new (untrained) IVF-PQ index with the given configuration.
96    pub fn new(config: IvfPqConfig) -> Result<Self, IvfPqError> {
97        config.validate()?;
98        let nlist = config.nlist;
99        Ok(Self {
100            config,
101            coarse_centroids: Vec::new(),
102            inverted_lists: vec![Vec::new(); nlist],
103            pq_codebook: Vec::new(),
104            is_trained: false,
105            next_id: 0,
106        })
107    }
108
109    /// Train the index: build coarse centroids (k-means) and PQ codebook.
110    pub fn train(&mut self, vectors: &[Vec<f64>]) -> Result<(), IvfPqError> {
111        if vectors.is_empty() {
112            return Err(IvfPqError::InsufficientData(
113                "Need at least 1 vector to train".to_string(),
114            ));
115        }
116        let n = vectors.len();
117        let dim = self.config.dimension;
118
119        // Validate dimensions
120        for v in vectors.iter() {
121            if v.len() != dim {
122                return Err(IvfPqError::DimensionMismatch {
123                    expected: dim,
124                    got: v.len(),
125                });
126            }
127        }
128
129        let nlist = self.config.nlist.min(n); // can't have more clusters than vectors
130        let m = self.config.m;
131        let k_per_sub = self.config.k_per_sub;
132        let sub_dim = dim / m;
133
134        // Step 1: Train coarse centroids (k-means on full vectors)
135        self.coarse_centroids = Self::kmeans(vectors, nlist, dim, 10);
136
137        // Step 2: Compute residuals and train PQ codebook
138        // For each vector, find its nearest coarse centroid, compute residual
139        let residuals: Vec<Vec<f64>> = vectors
140            .iter()
141            .map(|v| {
142                let nearest = self.find_nearest_centroid_trained(v);
143                let centroid = &self.coarse_centroids[nearest];
144                v.iter().zip(centroid.iter()).map(|(a, b)| a - b).collect()
145            })
146            .collect();
147
148        // Step 3: For each sub-quantizer, train k_per_sub centroids on residual slices
149        let mut pq_codebook = Vec::with_capacity(m);
150        for sub_idx in 0..m {
151            let start = sub_idx * sub_dim;
152            let end = start + sub_dim;
153            let sub_data: Vec<Vec<f64>> =
154                residuals.iter().map(|r| r[start..end].to_vec()).collect();
155            let k = k_per_sub.min(sub_data.len());
156            let centroids = Self::kmeans(&sub_data, k, sub_dim, 5);
157            pq_codebook.push(centroids);
158        }
159        self.pq_codebook = pq_codebook;
160        self.is_trained = true;
161
162        // Resize inverted lists to actual nlist
163        let actual_nlist = self.coarse_centroids.len();
164        self.inverted_lists = vec![Vec::new(); actual_nlist];
165        Ok(())
166    }
167
168    /// Add a vector to the trained index.
169    pub fn add(&mut self, vector: &[f64]) -> Result<u64, IvfPqError> {
170        if !self.is_trained {
171            return Err(IvfPqError::NotTrained);
172        }
173        let dim = self.config.dimension;
174        if vector.len() != dim {
175            return Err(IvfPqError::DimensionMismatch {
176                expected: dim,
177                got: vector.len(),
178            });
179        }
180        let cluster_idx = self.find_nearest_centroid(vector);
181        let centroid = &self.coarse_centroids[cluster_idx];
182        let residual: Vec<f64> = vector
183            .iter()
184            .zip(centroid.iter())
185            .map(|(a, b)| a - b)
186            .collect();
187        let codes = self.encode_residual(&residual);
188        let id = self.next_id;
189        self.next_id += 1;
190        self.inverted_lists[cluster_idx].push((id, codes));
191        Ok(id)
192    }
193
194    /// Add multiple vectors in batch.
195    pub fn add_batch(&mut self, vectors: &[Vec<f64>]) -> Result<Vec<u64>, IvfPqError> {
196        vectors.iter().map(|v| self.add(v)).collect()
197    }
198
199    /// Search for the k nearest neighbors of a query vector.
200    ///
201    /// Returns (id, approximate_distance) pairs sorted by distance (ascending).
202    pub fn search(&self, query: &[f64], k: usize) -> Result<Vec<(u64, f64)>, IvfPqError> {
203        if !self.is_trained {
204            return Err(IvfPqError::NotTrained);
205        }
206        let dim = self.config.dimension;
207        if query.len() != dim {
208            return Err(IvfPqError::DimensionMismatch {
209                expected: dim,
210                got: query.len(),
211            });
212        }
213
214        // Find nprobe nearest coarse centroids
215        let nprobe = self.config.nprobe.min(self.coarse_centroids.len());
216        let mut centroid_dists: Vec<(usize, f64)> = self
217            .coarse_centroids
218            .iter()
219            .enumerate()
220            .map(|(i, c)| (i, Self::l2_distance(query, c)))
221            .collect();
222        centroid_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
223
224        // Compute PQ distance tables for the query residual against each top cluster
225        let sub_dim = dim / self.config.m;
226        let m = self.config.m;
227
228        let mut candidates: Vec<(u64, f64)> = Vec::new();
229
230        for &(cluster_idx, _) in centroid_dists.iter().take(nprobe) {
231            let centroid = &self.coarse_centroids[cluster_idx];
232            let residual: Vec<f64> = query
233                .iter()
234                .zip(centroid.iter())
235                .map(|(a, b)| a - b)
236                .collect();
237
238            // Build distance tables: for each sub-quantizer, precompute dist to each code centroid
239            let dist_tables: Vec<Vec<f64>> = (0..m)
240                .map(|sub_idx| {
241                    let start = sub_idx * sub_dim;
242                    let q_sub = &residual[start..start + sub_dim];
243                    self.pq_codebook[sub_idx]
244                        .iter()
245                        .map(|code_centroid| Self::l2_distance(q_sub, code_centroid))
246                        .collect()
247                })
248                .collect();
249
250            for &(id, ref codes) in &self.inverted_lists[cluster_idx] {
251                // Approximate distance via PQ lookup
252                let dist: f64 = codes
253                    .iter()
254                    .enumerate()
255                    .map(|(sub_idx, &code)| {
256                        let code_idx = code as usize;
257                        dist_tables[sub_idx]
258                            .get(code_idx)
259                            .copied()
260                            .unwrap_or(f64::MAX)
261                    })
262                    .sum();
263                candidates.push((id, dist));
264            }
265        }
266
267        // Sort candidates by distance and take top k
268        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
269        candidates.truncate(k);
270        Ok(candidates)
271    }
272
273    /// Total number of vectors added to the index.
274    pub fn size(&self) -> usize {
275        self.inverted_lists.iter().map(|l| l.len()).sum()
276    }
277
278    /// Whether the index has been trained.
279    pub fn is_trained(&self) -> bool {
280        self.is_trained
281    }
282
283    /// Find the nearest coarse centroid index for a vector (only callable after training).
284    pub fn find_nearest_centroid(&self, vector: &[f64]) -> usize {
285        self.find_nearest_centroid_trained(vector)
286    }
287
288    fn find_nearest_centroid_trained(&self, vector: &[f64]) -> usize {
289        let mut best_idx = 0;
290        let mut best_dist = f64::MAX;
291        for (i, centroid) in self.coarse_centroids.iter().enumerate() {
292            let d = Self::l2_distance(vector, centroid);
293            if d < best_dist {
294                best_dist = d;
295                best_idx = i;
296            }
297        }
298        best_idx
299    }
300
301    /// Encode a residual vector using the PQ codebook.
302    pub fn encode_residual(&self, residual: &[f64]) -> Vec<u8> {
303        let sub_dim = self.config.dimension / self.config.m;
304        let m = self.config.m;
305        let mut codes = Vec::with_capacity(m);
306        for sub_idx in 0..m {
307            let start = sub_idx * sub_dim;
308            let sub = &residual[start..start + sub_dim];
309            // Find nearest centroid in PQ codebook for this sub-quantizer
310            let mut best_code = 0u8;
311            let mut best_dist = f64::MAX;
312            for (code_idx, centroid) in self.pq_codebook[sub_idx].iter().enumerate() {
313                let d = Self::l2_distance(sub, centroid);
314                if d < best_dist {
315                    best_dist = d;
316                    best_code = (code_idx & 0xFF) as u8;
317                }
318            }
319            codes.push(best_code);
320        }
321        codes
322    }
323
324    /// L2 squared distance between two equal-length slices.
325    pub fn l2_distance(a: &[f64], b: &[f64]) -> f64 {
326        a.iter()
327            .zip(b.iter())
328            .map(|(x, y)| (x - y).powi(2))
329            .sum::<f64>()
330    }
331
332    /// Simple k-means with random initialization (first k vectors) and `iters` iterations.
333    pub fn kmeans(data: &[Vec<f64>], k: usize, dim: usize, iters: usize) -> Vec<Vec<f64>> {
334        if data.is_empty() || k == 0 {
335            return Vec::new();
336        }
337        let k = k.min(data.len());
338
339        // Initialize centroids: evenly spaced through data
340        let mut centroids: Vec<Vec<f64>> =
341            (0..k).map(|i| data[i * data.len() / k].clone()).collect();
342
343        for _ in 0..iters {
344            // Assign step
345            let mut clusters: Vec<Vec<usize>> = vec![Vec::new(); k];
346            for (idx, point) in data.iter().enumerate() {
347                let best = centroids
348                    .iter()
349                    .enumerate()
350                    .map(|(ci, c)| (ci, Self::l2_distance(point, c)))
351                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
352                    .map(|(ci, _)| ci)
353                    .unwrap_or(0);
354                clusters[best].push(idx);
355            }
356
357            // Update step
358            let mut new_centroids = Vec::with_capacity(k);
359            for (ci, members) in clusters.iter().enumerate() {
360                if members.is_empty() {
361                    // Keep old centroid if cluster is empty
362                    new_centroids.push(centroids[ci].clone());
363                } else {
364                    let mut centroid = vec![0.0f64; dim];
365                    for &idx in members {
366                        for (d, val) in centroid.iter_mut().zip(data[idx].iter()) {
367                            *d += val;
368                        }
369                    }
370                    let count = members.len() as f64;
371                    for d in centroid.iter_mut() {
372                        *d /= count;
373                    }
374                    new_centroids.push(centroid);
375                }
376            }
377            centroids = new_centroids;
378        }
379        centroids
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-local result type so `?` can propagate any error out of a test.
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

    /// Dimension used by the shared default configuration.
    const DIM: usize = 8;

    fn make_config(dim: usize, nlist: usize, m: usize, k: usize, nprobe: usize) -> IvfPqConfig {
        IvfPqConfig {
            nlist,
            m,
            k_per_sub: k,
            nprobe,
            dimension: dim,
        }
    }

    /// The valid configuration most tests share: 8-d, 4 lists, 2 sub-quantizers.
    fn default_config() -> IvfPqConfig {
        make_config(DIM, 4, 2, 4, 2)
    }

    /// Deterministic pseudo-random vectors from a simple LCG.
    ///
    /// Only 31 bits survive the shift, so every component lands in [-1, 0).
    fn make_random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f64>> {
        let mut state = seed;
        let mut next = move || {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 33) as f64) / (u32::MAX as f64) * 2.0 - 1.0
        };
        (0..n)
            .map(|_| (0..dim).map(|_| next()).collect())
            .collect()
    }

    /// Build a default-config index trained on `n` vectors generated from
    /// `seed`, returning the index together with the training vectors.
    fn trained_index(n: usize, seed: u64) -> Result<(IvfPqIndex, Vec<Vec<f64>>)> {
        let mut index = IvfPqIndex::new(default_config())?;
        let vectors = make_random_vectors(n, DIM, seed);
        index.train(&vectors)?;
        Ok((index, vectors))
    }

    /// Assert that building an index from `config` fails with `InvalidConfig`.
    fn assert_rejected(config: IvfPqConfig) {
        assert!(matches!(
            IvfPqIndex::new(config),
            Err(IvfPqError::InvalidConfig(_))
        ));
    }

    // ---- new / config validation ----

    #[test]
    fn test_new_valid_config() {
        assert!(IvfPqIndex::new(default_config()).is_ok());
    }

    #[test]
    fn test_new_m_not_divides_dimension() {
        assert_rejected(make_config(7, 4, 3, 4, 2)); // 7 % 3 != 0
    }

    #[test]
    fn test_new_m_zero() {
        assert_rejected(make_config(8, 4, 0, 4, 2));
    }

    #[test]
    fn test_new_nlist_zero() {
        assert_rejected(make_config(8, 0, 2, 4, 2));
    }

    #[test]
    fn test_new_nprobe_gt_nlist() {
        assert_rejected(make_config(8, 2, 2, 4, 5)); // nprobe=5 > nlist=2
    }

    #[test]
    fn test_new_dimension_zero() {
        // `m` must be valid here; with m == 0 the m check would fire first
        // and the dimension check would never be exercised.
        let config = make_config(0, 4, 2, 4, 2);
        assert!(matches!(
            IvfPqIndex::new(config),
            Err(IvfPqError::InvalidConfig(msg)) if msg.contains("dimension")
        ));
    }

    // ---- is_trained / train ----

    #[test]
    fn test_not_trained_initially() -> Result<()> {
        let index = IvfPqIndex::new(default_config())?;
        assert!(!index.is_trained());
        Ok(())
    }

    #[test]
    fn test_train_basic() -> Result<()> {
        let (index, _) = trained_index(20, 42)?;
        assert!(index.is_trained());
        Ok(())
    }

    #[test]
    fn test_train_too_few_vectors() -> Result<()> {
        let mut index = IvfPqIndex::new(default_config())?;
        // 0 vectors → error
        let result = index.train(&[]);
        assert!(matches!(result, Err(IvfPqError::InsufficientData(_))));
        Ok(())
    }

    #[test]
    fn test_train_dimension_mismatch() -> Result<()> {
        let mut index = IvfPqIndex::new(default_config())?;
        let vectors = vec![vec![1.0, 2.0, 3.0]]; // dim=3, not 8
        let result = index.train(&vectors);
        assert!(matches!(result, Err(IvfPqError::DimensionMismatch { .. })));
        Ok(())
    }

    // ---- add ----

    #[test]
    fn test_add_before_training_error() -> Result<()> {
        let mut index = IvfPqIndex::new(default_config())?;
        let result = index.add(&[0.0; DIM]);
        assert!(matches!(result, Err(IvfPqError::NotTrained)));
        Ok(())
    }

    #[test]
    fn test_add_after_training() -> Result<()> {
        let (mut index, vectors) = trained_index(20, 1)?;
        let id = index.add(&vectors[0])?;
        assert_eq!(id, 0);
        assert_eq!(index.size(), 1);
        Ok(())
    }

    #[test]
    fn test_add_dimension_mismatch() -> Result<()> {
        let (mut index, _) = trained_index(20, 2)?;
        let result = index.add(&[1.0, 2.0]); // wrong dim
        assert!(matches!(result, Err(IvfPqError::DimensionMismatch { .. })));
        Ok(())
    }

    // ---- add_batch ----

    #[test]
    fn test_add_batch() -> Result<()> {
        let (mut index, _) = trained_index(20, 3)?;
        let add_data = make_random_vectors(5, DIM, 4);
        let ids = index.add_batch(&add_data)?;
        assert_eq!(ids.len(), 5);
        assert_eq!(index.size(), 5);
        Ok(())
    }

    // ---- size ----

    #[test]
    fn test_size_starts_at_zero() -> Result<()> {
        let (index, _) = trained_index(20, 5)?;
        assert_eq!(index.size(), 0);
        Ok(())
    }

    #[test]
    fn test_size_after_adding() -> Result<()> {
        let (mut index, vectors) = trained_index(20, 6)?;
        for v in &vectors {
            index.add(v)?;
        }
        assert_eq!(index.size(), 20);
        Ok(())
    }

    // ---- search ----

    #[test]
    fn test_search_before_training_error() -> Result<()> {
        let index = IvfPqIndex::new(default_config())?;
        let result = index.search(&[0.0; DIM], 5);
        assert!(matches!(result, Err(IvfPqError::NotTrained)));
        Ok(())
    }

    #[test]
    fn test_search_empty_index() -> Result<()> {
        let (index, _) = trained_index(20, 7)?;
        let results = index.search(&[0.0; DIM], 5)?;
        assert!(results.is_empty());
        Ok(())
    }

    #[test]
    fn test_search_returns_k_results() -> Result<()> {
        let (mut index, vectors) = trained_index(50, 8)?;
        for v in &vectors {
            index.add(v)?;
        }
        // Only `nprobe` lists are scanned, so fewer than k hits are possible.
        let results = index.search(&[0.0; DIM], 10)?;
        assert!(results.len() <= 10);
        assert!(!results.is_empty());
        Ok(())
    }

    #[test]
    fn test_search_sorted_by_distance() -> Result<()> {
        let (mut index, vectors) = trained_index(30, 9)?;
        for v in &vectors {
            index.add(v)?;
        }
        let results = index.search(&[0.0; DIM], 10)?;
        for pair in results.windows(2) {
            assert!(
                pair[0].1 <= pair[1].1,
                "Results not sorted: {} > {}",
                pair[0].1,
                pair[1].1
            );
        }
        Ok(())
    }

    #[test]
    fn test_search_dimension_mismatch() -> Result<()> {
        let (index, _) = trained_index(20, 10)?;
        let result = index.search(&[1.0, 2.0], 5); // wrong dim
        assert!(matches!(result, Err(IvfPqError::DimensionMismatch { .. })));
        Ok(())
    }

    // ---- l2_distance ----

    #[test]
    fn test_l2_distance_zero() {
        let a = vec![1.0, 2.0, 3.0];
        assert!(IvfPqIndex::l2_distance(&a, &a) < 1e-10);
    }

    #[test]
    fn test_l2_distance_known() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        let d = IvfPqIndex::l2_distance(&a, &b);
        assert!((d - 25.0).abs() < 1e-10); // 3^2 + 4^2 = 25
    }

    // ---- IvfPqError display ----

    #[test]
    fn test_error_display() {
        let e = IvfPqError::DimensionMismatch {
            expected: 8,
            got: 4,
        };
        assert!(format!("{e}").contains("8"));
        let e2 = IvfPqError::NotTrained;
        assert!(!format!("{e2}").is_empty());
        let e3 = IvfPqError::InvalidConfig("m".to_string());
        assert!(format!("{e3}").contains("m"));
        let e4 = IvfPqError::InsufficientData("need more".to_string());
        assert!(format!("{e4}").contains("need more"));
    }

    // ---- IvfPqConfig validation ----

    #[test]
    fn test_config_validation_valid() {
        assert!(default_config().validate().is_ok());
    }

    #[test]
    fn test_config_validation_k_per_sub_zero() {
        let config = IvfPqConfig {
            nlist: 4,
            m: 2,
            k_per_sub: 0,
            nprobe: 2,
            dimension: 8,
        };
        assert!(matches!(
            config.validate(),
            Err(IvfPqError::InvalidConfig(_))
        ));
    }
}