1use std::collections::HashMap;
10
/// Family of vector index that an [`IndexOptimizer`] is tuning.
///
/// The variant names mirror the conventional acronyms (HNSW, IVF, IVF-PQ) and are part of the
/// public API, so Clippy's `upper_case_acronyms` lint is silenced instead of renaming them.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum IndexType {
    /// Hierarchical Navigable Small World graph index.
    HNSW,
    /// Inverted-file index.
    IVF,
    /// Inverted-file index with product quantisation.
    IVFPQ,
    /// Flat index (presumably exhaustive search; no tuning parameters are modelled for it here).
    Flat,
}
25
/// HNSW tuning knobs for one benchmarked configuration.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct HnswParams {
    /// The HNSW `M` parameter (conventionally the number of graph neighbours kept per node).
    pub m: usize,
    /// Candidate-list size used while building the graph (`efConstruction`).
    pub ef_construction: usize,
    /// Candidate-list size used at query time (`efSearch`).
    pub ef_search: usize,
}
38
39impl HnswParams {
40 pub fn new(m: usize, ef_construction: usize, ef_search: usize) -> Self {
42 Self {
43 m,
44 ef_construction,
45 ef_search,
46 }
47 }
48}
49
/// IVF tuning knobs for one benchmarked configuration.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct IvfParams {
    /// Number of inverted lists the index is partitioned into.
    pub n_lists: usize,
    /// Number of lists probed per query.
    pub n_probes: usize,
}
58
59impl IvfParams {
60 pub fn new(n_lists: usize, n_probes: usize) -> Self {
62 Self { n_lists, n_probes }
63 }
64}
65
66#[derive(Debug, Clone, PartialEq, Eq, Hash)]
70pub enum IndexParams {
71 Hnsw(HnswParams),
73 Ivf(IvfParams),
75}
76
77impl IndexParams {
78 pub fn as_hnsw(&self) -> Option<&HnswParams> {
80 match self {
81 Self::Hnsw(p) => Some(p),
82 _ => None,
83 }
84 }
85
86 pub fn as_ivf(&self) -> Option<&IvfParams> {
88 match self {
89 Self::Ivf(p) => Some(p),
90 _ => None,
91 }
92 }
93}
94
/// Objective used by [`IndexOptimizer`] to score and rank benchmark points (higher is better).
#[derive(Clone, Debug)]
pub enum OptimizationTarget {
    /// Score is the point's `recall_at_10`.
    MaxRecall,
    /// Score is the point's raw `qps`.
    MaxQPS,
    /// Weighted mix of recall and QPS.
    BalancedRecallQPS {
        /// Weight `w` applied to `recall_at_10`. The remaining `1 - w` applies to the point's
        /// QPS divided by the highest QPS recorded. Values outside `[0, 1]` are not rejected.
        recall_weight: f64,
    },
}
110
111#[derive(Debug, Clone)]
115pub struct BenchmarkPoint {
116 pub params: IndexParams,
118 pub recall_at_10: f64,
120 pub qps: f64,
122 pub build_time_ms: u64,
124}
125
126impl BenchmarkPoint {
127 pub fn new(params: IndexParams, recall_at_10: f64, qps: f64, build_time_ms: u64) -> Self {
129 Self {
130 params,
131 recall_at_10,
132 qps,
133 build_time_ms,
134 }
135 }
136}
137
138pub struct IndexOptimizer {
147 index_type: IndexType,
148 target: OptimizationTarget,
149 benchmarks: Vec<BenchmarkPoint>,
150}
151
152impl IndexOptimizer {
153 pub fn new(index_type: IndexType, target: OptimizationTarget) -> Self {
155 Self {
156 index_type,
157 target,
158 benchmarks: Vec::new(),
159 }
160 }
161
162 pub fn add_benchmark(&mut self, point: BenchmarkPoint) {
164 self.benchmarks.push(point);
165 }
166
167 pub fn benchmark_count(&self) -> usize {
169 self.benchmarks.len()
170 }
171
172 fn score(&self, point: &BenchmarkPoint) -> f64 {
177 match &self.target {
178 OptimizationTarget::MaxRecall => point.recall_at_10,
179 OptimizationTarget::MaxQPS => point.qps,
180 OptimizationTarget::BalancedRecallQPS { recall_weight } => {
181 let max_qps = self
182 .benchmarks
183 .iter()
184 .map(|b| b.qps)
185 .fold(f64::NEG_INFINITY, f64::max);
186 let norm_qps = if max_qps > 0.0 {
187 point.qps / max_qps
188 } else {
189 0.0
190 };
191 recall_weight * point.recall_at_10 + (1.0 - recall_weight) * norm_qps
192 }
193 }
194 }
195
196 pub fn best_params(&self) -> Option<&BenchmarkPoint> {
199 self.benchmarks.iter().max_by(|a, b| {
200 self.score(a)
201 .partial_cmp(&self.score(b))
202 .unwrap_or(std::cmp::Ordering::Equal)
203 })
204 }
205
206 pub fn pareto_front(&self) -> Vec<&BenchmarkPoint> {
212 let mut front: Vec<&BenchmarkPoint> = Vec::new();
213
214 for candidate in &self.benchmarks {
215 let dominated = front.iter().any(|existing| {
216 existing.recall_at_10 >= candidate.recall_at_10
218 && existing.qps >= candidate.qps
219 && (existing.recall_at_10 > candidate.recall_at_10
220 || existing.qps > candidate.qps)
221 });
222
223 if !dominated {
224 front.retain(|existing| {
226 !(candidate.recall_at_10 >= existing.recall_at_10
227 && candidate.qps >= existing.qps
228 && (candidate.recall_at_10 > existing.recall_at_10
229 || candidate.qps > existing.qps))
230 });
231 front.push(candidate);
232 }
233 }
234
235 front.sort_by(|a, b| {
237 b.recall_at_10
238 .partial_cmp(&a.recall_at_10)
239 .unwrap_or(std::cmp::Ordering::Equal)
240 });
241
242 front
243 }
244
245 pub fn suggest_next_params(&self) -> Option<IndexParams> {
252 let best = self.best_params()?;
253
254 match &best.params {
255 IndexParams::Hnsw(p) => {
256 let next = if p.ef_search < 512 {
258 HnswParams::new(p.m, p.ef_construction, p.ef_search * 2)
259 } else if p.m < 64 {
260 HnswParams::new(p.m * 2, p.ef_construction, p.ef_search)
261 } else {
262 HnswParams::new(p.m, p.ef_construction * 2, p.ef_search)
263 };
264 Some(IndexParams::Hnsw(next))
265 }
266 IndexParams::Ivf(p) => {
267 let next = if p.n_probes < p.n_lists {
269 IvfParams::new(p.n_lists, (p.n_probes * 2).min(p.n_lists))
270 } else {
271 IvfParams::new(p.n_lists * 2, p.n_probes)
272 };
273 Some(IndexParams::Ivf(next))
274 }
275 }
276 }
277
278 pub fn index_type(&self) -> &IndexType {
280 &self.index_type
281 }
282
283 pub fn benchmarks(&self) -> &[BenchmarkPoint] {
285 &self.benchmarks
286 }
287
288 pub fn clear(&mut self) {
290 self.benchmarks.clear();
291 }
292
293 pub fn ranked_benchmarks(&self) -> Vec<&BenchmarkPoint> {
295 let mut ranked: Vec<&BenchmarkPoint> = self.benchmarks.iter().collect();
296 ranked.sort_by(|a, b| {
297 self.score(b)
298 .partial_cmp(&self.score(a))
299 .unwrap_or(std::cmp::Ordering::Equal)
300 });
301 ranked
302 }
303
304 pub fn score_of(&self, point: &BenchmarkPoint) -> f64 {
306 self.score(point)
307 }
308
309 pub fn filter_by_recall(&self, min_recall: f64) -> Vec<&BenchmarkPoint> {
311 self.benchmarks
312 .iter()
313 .filter(|b| b.recall_at_10 >= min_recall)
314 .collect()
315 }
316
317 pub fn filter_by_qps(&self, min_qps: f64) -> Vec<&BenchmarkPoint> {
319 self.benchmarks
320 .iter()
321 .filter(|b| b.qps >= min_qps)
322 .collect()
323 }
324
325 pub fn recall_stats(&self) -> Option<RecallStats> {
327 if self.benchmarks.is_empty() {
328 return None;
329 }
330 let recalls: Vec<f64> = self.benchmarks.iter().map(|b| b.recall_at_10).collect();
331 let min = recalls.iter().cloned().fold(f64::INFINITY, f64::min);
332 let max = recalls.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
333 let mean = recalls.iter().sum::<f64>() / recalls.len() as f64;
334 Some(RecallStats { min, max, mean })
335 }
336
337 pub fn group_by_variant(&self) -> HashMap<&'static str, Vec<&BenchmarkPoint>> {
339 let mut groups: HashMap<&'static str, Vec<&BenchmarkPoint>> = HashMap::new();
340 for b in &self.benchmarks {
341 let key = match &b.params {
342 IndexParams::Hnsw(_) => "hnsw",
343 IndexParams::Ivf(_) => "ivf",
344 };
345 groups.entry(key).or_default().push(b);
346 }
347 groups
348 }
349}
350
/// Summary of `recall_at_10` across all recorded benchmarks.
#[derive(Clone, Debug)]
pub struct RecallStats {
    /// Lowest recall@10 observed.
    pub min: f64,
    /// Highest recall@10 observed.
    pub max: f64,
    /// Arithmetic mean of recall@10 over all benchmarks.
    pub mean: f64,
}
361
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an HNSW benchmark point with a fixed 100 ms build time.
    fn hnsw_point(m: usize, ef_c: usize, ef_s: usize, recall: f64, qps: f64) -> BenchmarkPoint {
        BenchmarkPoint::new(
            IndexParams::Hnsw(HnswParams::new(m, ef_c, ef_s)),
            recall,
            qps,
            100,
        )
    }

    /// Builds an IVF benchmark point with a fixed 200 ms build time.
    fn ivf_point(n_lists: usize, n_probes: usize, recall: f64, qps: f64) -> BenchmarkPoint {
        BenchmarkPoint::new(
            IndexParams::Ivf(IvfParams::new(n_lists, n_probes)),
            recall,
            qps,
            200,
        )
    }

    fn make_hnsw_optimizer(target: OptimizationTarget) -> IndexOptimizer {
        IndexOptimizer::new(IndexType::HNSW, target)
    }

    fn make_ivf_optimizer(target: OptimizationTarget) -> IndexOptimizer {
        IndexOptimizer::new(IndexType::IVF, target)
    }

    // ---- construction -------------------------------------------------------------------

    #[test]
    fn test_new_optimizer_empty() {
        let opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        assert_eq!(opt.benchmark_count(), 0);
        assert!(opt.best_params().is_none());
        assert!(opt.pareto_front().is_empty());
        assert!(opt.suggest_next_params().is_none());
    }

    #[test]
    fn test_index_type_stored() {
        let opt = IndexOptimizer::new(IndexType::IVF, OptimizationTarget::MaxQPS);
        assert_eq!(opt.index_type(), &IndexType::IVF);
    }

    #[test]
    fn test_flat_index_type() {
        let opt = IndexOptimizer::new(IndexType::Flat, OptimizationTarget::MaxRecall);
        assert_eq!(opt.index_type(), &IndexType::Flat);
    }

    // ---- recording ----------------------------------------------------------------------

    #[test]
    fn test_add_single_benchmark() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.9, 5000.0));
        assert_eq!(opt.benchmark_count(), 1);
    }

    #[test]
    fn test_add_multiple_benchmarks() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        for i in 0..10 {
            opt.add_benchmark(hnsw_point(
                16,
                200,
                50 + i * 10,
                0.8 + i as f64 * 0.01,
                5000.0 - i as f64 * 100.0,
            ));
        }
        assert_eq!(opt.benchmark_count(), 10);
    }

    // ---- best_params --------------------------------------------------------------------

    #[test]
    fn test_best_params_max_recall_single() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.9, 5000.0));
        let best = opt.best_params().expect("should have best");
        assert_eq!(best.recall_at_10, 0.9);
    }

    #[test]
    fn test_best_params_max_recall_picks_highest() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.75, 8000.0));
        opt.add_benchmark(hnsw_point(32, 400, 100, 0.95, 3000.0));
        opt.add_benchmark(hnsw_point(16, 200, 80, 0.85, 6000.0));
        let best = opt.best_params().expect("some best");
        assert_eq!(best.recall_at_10, 0.95);
    }

    #[test]
    fn test_best_params_max_recall_ignores_qps() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(4, 50, 10, 0.5, 100_000.0));
        opt.add_benchmark(hnsw_point(64, 800, 400, 0.99, 100.0));
        let best = opt.best_params().expect("some best");
        assert_eq!(best.recall_at_10, 0.99);
    }

    #[test]
    fn test_best_params_max_qps_picks_highest_qps() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxQPS);
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.9, 3000.0));
        opt.add_benchmark(hnsw_point(4, 50, 10, 0.6, 12000.0));
        opt.add_benchmark(hnsw_point(32, 400, 100, 0.95, 1500.0));
        let best = opt.best_params().expect("some best");
        assert_eq!(best.qps, 12000.0);
    }

    #[test]
    fn test_best_params_max_qps_ignores_recall() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxQPS);
        opt.add_benchmark(hnsw_point(4, 50, 10, 0.1, 50000.0));
        opt.add_benchmark(hnsw_point(64, 800, 400, 1.0, 100.0));
        let best = opt.best_params().expect("some best");
        assert_eq!(best.qps, 50000.0);
    }

    #[test]
    fn test_best_params_balanced_equal_weight() {
        let mut opt =
            make_hnsw_optimizer(OptimizationTarget::BalancedRecallQPS { recall_weight: 0.5 });
        opt.add_benchmark(hnsw_point(64, 800, 400, 1.0, 100.0));
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.9, 5000.0));
        opt.add_benchmark(hnsw_point(4, 50, 10, 0.5, 10000.0));
        // Scores with max QPS 10000: 0.5*1.0 + 0.5*0.01 = 0.505,
        // 0.5*0.9 + 0.5*0.5 = 0.70, and 0.5*0.5 + 0.5*1.0 = 0.75.
        let best = opt.best_params().expect("some best");
        assert_eq!(best.recall_at_10, 0.5);
        assert_eq!(best.qps, 10000.0);
    }

    #[test]
    fn test_best_params_balanced_recall_heavy() {
        let mut opt =
            make_hnsw_optimizer(OptimizationTarget::BalancedRecallQPS { recall_weight: 0.9 });
        opt.add_benchmark(hnsw_point(64, 800, 400, 0.99, 1000.0));
        opt.add_benchmark(hnsw_point(4, 50, 10, 0.5, 10000.0));
        let best = opt.best_params().expect("some best");
        assert_eq!(best.recall_at_10, 0.99);
    }

    #[test]
    fn test_best_params_balanced_qps_heavy() {
        let mut opt =
            make_hnsw_optimizer(OptimizationTarget::BalancedRecallQPS { recall_weight: 0.1 });
        opt.add_benchmark(hnsw_point(64, 800, 400, 0.99, 1000.0));
        opt.add_benchmark(hnsw_point(4, 50, 10, 0.5, 10000.0));
        let best = opt.best_params().expect("some best");
        assert_eq!(best.qps, 10000.0);
    }

    // ---- pareto_front -------------------------------------------------------------------

    #[test]
    fn test_pareto_front_empty() {
        let opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        assert!(opt.pareto_front().is_empty());
    }

    #[test]
    fn test_pareto_front_single_point() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.9, 5000.0));
        let front = opt.pareto_front();
        assert_eq!(front.len(), 1);
    }

    #[test]
    fn test_pareto_front_no_dominated() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(64, 800, 400, 0.99, 500.0));
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.85, 5000.0));
        opt.add_benchmark(hnsw_point(4, 50, 10, 0.5, 15000.0));
        let front = opt.pareto_front();
        assert_eq!(front.len(), 3);
    }

    #[test]
    fn test_pareto_front_dominated_excluded() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.9, 5000.0));
        opt.add_benchmark(hnsw_point(8, 100, 30, 0.8, 4000.0));
        opt.add_benchmark(hnsw_point(4, 50, 10, 0.5, 15000.0));
        let front = opt.pareto_front();
        assert_eq!(front.len(), 2);
        // (0.8, 4000) is dominated by (0.9, 5000).
        let recalls: Vec<f64> = front.iter().map(|p| p.recall_at_10).collect();
        assert!(!recalls.contains(&0.8));
    }

    #[test]
    fn test_pareto_front_sorted_by_recall_desc() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(4, 50, 10, 0.5, 15000.0));
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.85, 5000.0));
        opt.add_benchmark(hnsw_point(64, 800, 400, 0.99, 500.0));
        let front = opt.pareto_front();
        for window in front.windows(2) {
            assert!(window[0].recall_at_10 >= window[1].recall_at_10);
        }
    }

    #[test]
    fn test_pareto_front_all_dominated_except_best() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(64, 800, 400, 1.0, 20000.0));
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.8, 5000.0));
        opt.add_benchmark(hnsw_point(4, 50, 10, 0.5, 3000.0));
        let front = opt.pareto_front();
        assert_eq!(front.len(), 1);
        assert_eq!(front[0].recall_at_10, 1.0);
    }

    // ---- suggest_next_params ------------------------------------------------------------

    #[test]
    fn test_suggest_next_none_when_empty() {
        let opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        assert!(opt.suggest_next_params().is_none());
    }

    #[test]
    fn test_suggest_next_hnsw_increments_ef_search() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.9, 5000.0));
        let next = opt.suggest_next_params().expect("suggestion");
        if let IndexParams::Hnsw(p) = next {
            assert_eq!(p.ef_search, 100);
            assert_eq!(p.m, 16);
            assert_eq!(p.ef_construction, 200);
        } else {
            panic!("Expected Hnsw params");
        }
    }

    #[test]
    fn test_suggest_next_hnsw_increments_m_when_ef_search_maxed() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(16, 200, 512, 0.99, 1000.0));
        let next = opt.suggest_next_params().expect("suggestion");
        if let IndexParams::Hnsw(p) = next {
            assert_eq!(p.ef_search, 512);
            assert_eq!(p.m, 32);
        } else {
            panic!("Expected Hnsw params");
        }
    }

    #[test]
    fn test_suggest_next_ivf_increments_n_probes() {
        let mut opt = make_ivf_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(ivf_point(256, 4, 0.7, 8000.0));
        let next = opt.suggest_next_params().expect("suggestion");
        if let IndexParams::Ivf(p) = next {
            assert_eq!(p.n_probes, 8);
            assert_eq!(p.n_lists, 256);
        } else {
            panic!("Expected IVF params");
        }
    }

    #[test]
    fn test_suggest_next_ivf_probes_clamped_to_n_lists() {
        let mut opt = make_ivf_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(ivf_point(256, 200, 0.9, 3000.0));
        let next = opt.suggest_next_params().expect("suggestion");
        assert_eq!(next, IndexParams::Ivf(IvfParams::new(256, 256)));
    }

    #[test]
    fn test_suggest_next_ivf_grows_n_lists_when_probes_maxed() {
        let mut opt = make_ivf_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(ivf_point(64, 64, 0.95, 2000.0));
        let next = opt.suggest_next_params().expect("suggestion");
        if let IndexParams::Ivf(p) = next {
            assert_eq!(p.n_lists, 128);
        } else {
            panic!("Expected IVF params");
        }
    }

    // ---- IVF ranking --------------------------------------------------------------------

    #[test]
    fn test_ivf_best_params_max_recall() {
        let mut opt = make_ivf_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(ivf_point(64, 4, 0.6, 9000.0));
        opt.add_benchmark(ivf_point(64, 32, 0.9, 4000.0));
        opt.add_benchmark(ivf_point(256, 64, 0.97, 1500.0));
        let best = opt.best_params().expect("best");
        assert_eq!(best.recall_at_10, 0.97);
    }

    #[test]
    fn test_ivf_best_params_max_qps() {
        let mut opt = make_ivf_optimizer(OptimizationTarget::MaxQPS);
        opt.add_benchmark(ivf_point(64, 4, 0.6, 9000.0));
        opt.add_benchmark(ivf_point(64, 32, 0.9, 4000.0));
        opt.add_benchmark(ivf_point(256, 64, 0.97, 1500.0));
        let best = opt.best_params().expect("best");
        assert_eq!(best.qps, 9000.0);
    }

    // ---- scoring and ranking ------------------------------------------------------------

    #[test]
    fn test_score_of_max_recall() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        let p = hnsw_point(16, 200, 50, 0.92, 5000.0);
        opt.add_benchmark(p.clone());
        assert!((opt.score_of(&p) - 0.92).abs() < 1e-9);
    }

    #[test]
    fn test_score_of_max_qps() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxQPS);
        let p = hnsw_point(16, 200, 50, 0.92, 7777.0);
        opt.add_benchmark(p.clone());
        assert!((opt.score_of(&p) - 7777.0).abs() < 1e-9);
    }

    #[test]
    fn test_ranked_benchmarks_descending() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(4, 50, 10, 0.5, 15000.0));
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.85, 5000.0));
        opt.add_benchmark(hnsw_point(64, 800, 400, 0.99, 500.0));
        let ranked = opt.ranked_benchmarks();
        assert_eq!(ranked[0].recall_at_10, 0.99);
        assert_eq!(ranked[1].recall_at_10, 0.85);
        assert_eq!(ranked[2].recall_at_10, 0.5);
    }

    #[test]
    fn test_ranked_benchmarks_balanced() {
        let mut opt =
            make_hnsw_optimizer(OptimizationTarget::BalancedRecallQPS { recall_weight: 0.5 });
        opt.add_benchmark(hnsw_point(64, 800, 400, 1.0, 100.0));
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.9, 5000.0));
        opt.add_benchmark(hnsw_point(4, 50, 10, 0.5, 10000.0));
        let order: Vec<f64> = opt.ranked_benchmarks().iter().map(|b| b.recall_at_10).collect();
        assert_eq!(order, vec![0.5, 0.9, 1.0]);
        assert!((opt.score_of(&opt.benchmarks()[1]) - 0.7).abs() < 1e-9);
    }

    // ---- filters and stats --------------------------------------------------------------

    #[test]
    fn test_filter_by_recall() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(4, 50, 10, 0.5, 15000.0));
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.85, 5000.0));
        opt.add_benchmark(hnsw_point(64, 800, 400, 0.99, 500.0));
        let filtered = opt.filter_by_recall(0.8);
        assert_eq!(filtered.len(), 2);
    }

    #[test]
    fn test_filter_by_qps() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxQPS);
        opt.add_benchmark(hnsw_point(4, 50, 10, 0.5, 15000.0));
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.85, 5000.0));
        opt.add_benchmark(hnsw_point(64, 800, 400, 0.99, 500.0));
        let filtered = opt.filter_by_qps(5000.0);
        assert_eq!(filtered.len(), 2);
    }

    #[test]
    fn test_recall_stats_none_when_empty() {
        let opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        assert!(opt.recall_stats().is_none());
    }

    #[test]
    fn test_recall_stats_single() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.88, 5000.0));
        let stats = opt.recall_stats().expect("stats");
        assert!((stats.min - 0.88).abs() < 1e-9);
        assert!((stats.max - 0.88).abs() < 1e-9);
        assert!((stats.mean - 0.88).abs() < 1e-9);
    }

    #[test]
    fn test_recall_stats_multiple() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(4, 50, 10, 0.6, 15000.0));
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.9, 5000.0));
        opt.add_benchmark(hnsw_point(64, 800, 400, 0.99, 500.0));
        let stats = opt.recall_stats().expect("stats");
        assert!((stats.min - 0.6).abs() < 1e-9);
        assert!((stats.max - 0.99).abs() < 1e-9);
        assert!((stats.mean - (0.6 + 0.9 + 0.99) / 3.0).abs() < 1e-9);
    }

    // ---- housekeeping -------------------------------------------------------------------

    #[test]
    fn test_clear_removes_all() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.9, 5000.0));
        opt.add_benchmark(hnsw_point(4, 50, 10, 0.5, 15000.0));
        opt.clear();
        assert_eq!(opt.benchmark_count(), 0);
        assert!(opt.best_params().is_none());
    }

    #[test]
    fn test_group_by_variant_mixed() {
        let mut opt = IndexOptimizer::new(IndexType::IVFPQ, OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.9, 5000.0));
        opt.add_benchmark(ivf_point(64, 8, 0.75, 7000.0));
        opt.add_benchmark(hnsw_point(32, 400, 100, 0.95, 2000.0));
        let groups = opt.group_by_variant();
        assert_eq!(groups["hnsw"].len(), 2);
        assert_eq!(groups["ivf"].len(), 1);
    }

    #[test]
    fn test_index_params_as_hnsw() {
        let p = IndexParams::Hnsw(HnswParams::new(16, 200, 50));
        assert!(p.as_hnsw().is_some());
        assert!(p.as_ivf().is_none());
    }

    #[test]
    fn test_index_params_as_ivf() {
        let p = IndexParams::Ivf(IvfParams::new(64, 8));
        assert!(p.as_ivf().is_some());
        assert!(p.as_hnsw().is_none());
    }

    #[test]
    fn test_benchmarks_accessor() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.9, 5000.0));
        assert_eq!(opt.benchmarks().len(), 1);
    }

    #[test]
    fn test_zero_qps_does_not_panic() {
        let mut opt =
            make_hnsw_optimizer(OptimizationTarget::BalancedRecallQPS { recall_weight: 0.5 });
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.9, 0.0));
        let best = opt.best_params().expect("some best");
        // Max QPS is 0, so the QPS term contributes nothing: 0.5 * 0.9.
        let s = opt.score_of(best);
        assert!((s - 0.45).abs() < 1e-9);
    }

    #[test]
    fn test_identical_recall_uses_qps_tiebreak() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(hnsw_point(16, 200, 50, 0.9, 1000.0));
        opt.add_benchmark(hnsw_point(32, 400, 100, 0.9, 5000.0));
        // `best_params` uses `Iterator::max_by`, which returns the last of equal maxima.
        // QPS is not an explicit tiebreak; in this insertion order the last point also
        // happens to be the faster one.
        let best = opt.best_params().expect("some best");
        assert_eq!(best.qps, 5000.0);
        assert_eq!(best.params, IndexParams::Hnsw(HnswParams::new(32, 400, 100)));
    }

    #[test]
    fn test_build_time_stored() {
        let mut opt = make_hnsw_optimizer(OptimizationTarget::MaxRecall);
        opt.add_benchmark(BenchmarkPoint::new(
            IndexParams::Hnsw(HnswParams::new(16, 200, 50)),
            0.9,
            5000.0,
            12345,
        ));
        assert_eq!(opt.benchmarks()[0].build_time_ms, 12345);
    }

    #[test]
    fn test_hnsw_params_equality() {
        let a = HnswParams::new(16, 200, 50);
        let b = HnswParams::new(16, 200, 50);
        assert_eq!(a, b);
    }

    #[test]
    fn test_ivf_params_equality() {
        let a = IvfParams::new(64, 8);
        let b = IvfParams::new(64, 8);
        assert_eq!(a, b);
    }

    #[test]
    fn test_index_type_equality() {
        assert_eq!(IndexType::HNSW, IndexType::HNSW);
        assert_ne!(IndexType::HNSW, IndexType::IVF);
        assert_ne!(IndexType::IVF, IndexType::IVFPQ);
        assert_ne!(IndexType::IVFPQ, IndexType::Flat);
    }
}