1use serde::{Deserialize, Serialize};
29use std::collections::VecDeque;
30use std::fmt;
31use std::time::{Duration, Instant};
32
33#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
39pub struct SearchParams {
40 pub ef_search: usize,
42 pub num_candidates: usize,
44 pub over_fetch_ratio: f64,
46 pub rerank_depth: usize,
48 pub early_termination: bool,
50}
51
52impl Default for SearchParams {
53 fn default() -> Self {
54 Self {
55 ef_search: 64,
56 num_candidates: 100,
57 over_fetch_ratio: 2.0,
58 rerank_depth: 50,
59 early_termination: true,
60 }
61 }
62}
63
64impl SearchParams {
65 pub fn high_recall() -> Self {
67 Self {
68 ef_search: 256,
69 num_candidates: 500,
70 over_fetch_ratio: 5.0,
71 rerank_depth: 200,
72 early_termination: false,
73 }
74 }
75
76 pub fn low_latency() -> Self {
78 Self {
79 ef_search: 32,
80 num_candidates: 50,
81 over_fetch_ratio: 1.5,
82 rerank_depth: 20,
83 early_termination: true,
84 }
85 }
86
87 pub fn clamp(&mut self) {
89 self.ef_search = self.ef_search.clamp(8, 1024);
90 self.num_candidates = self.num_candidates.clamp(10, 5000);
91 self.over_fetch_ratio = self.over_fetch_ratio.clamp(1.0, 20.0);
92 self.rerank_depth = self.rerank_depth.clamp(0, self.num_candidates);
93 }
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct QueryFeedback {
103 pub params: SearchParams,
105 pub k: usize,
107 pub relevant_in_top_k: usize,
109 pub total_relevant: Option<usize>,
111 pub latency: Duration,
113 #[serde(skip, default = "std::time::Instant::now")]
115 pub timestamp: Instant,
116}
117
118impl QueryFeedback {
119 pub fn recall_at_k(&self) -> f64 {
121 match self.total_relevant {
122 Some(total) if total > 0 => self.relevant_in_top_k as f64 / total as f64,
123 _ => {
124 if self.k == 0 {
126 return 0.0;
127 }
128 self.relevant_in_top_k as f64 / self.k as f64
129 }
130 }
131 }
132
133 pub fn precision_at_k(&self) -> f64 {
135 if self.k == 0 {
136 return 0.0;
137 }
138 self.relevant_in_top_k as f64 / self.k as f64
139 }
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct TunerConfig {
149 pub target_recall: f64,
151 pub max_latency: Duration,
153 pub window_size: usize,
155 pub kp: f64,
157 pub ki: f64,
159 pub kd: f64,
161 pub min_samples: usize,
163 pub adjust_interval: usize,
165}
166
167impl Default for TunerConfig {
168 fn default() -> Self {
169 Self {
170 target_recall: 0.95,
171 max_latency: Duration::from_millis(100),
172 window_size: 100,
173 kp: 0.5,
174 ki: 0.1,
175 kd: 0.05,
176 min_samples: 10,
177 adjust_interval: 5,
178 }
179 }
180}
181
182#[derive(Debug, Clone, Default, Serialize, Deserialize)]
188pub struct TunerStats {
189 pub total_feedbacks: u64,
191 pub adjustments_made: u64,
193 pub current_recall: f64,
195 pub current_avg_latency_ms: f64,
197 pub target_met: bool,
199 pub avg_precision: f64,
201 pub recall_history: Vec<f64>,
203}
204
205impl TunerStats {
206 pub fn is_near_target(&self, target: f64, tolerance: f64) -> bool {
208 (self.current_recall - target).abs() < tolerance
209 }
210}
211
212pub struct AdaptiveRecallTuner {
221 config: TunerConfig,
222 current_params: SearchParams,
223 feedback_window: VecDeque<QueryFeedback>,
224 stats: TunerStats,
225 integral_error: f64,
227 prev_error: f64,
228 query_count: u64,
229}
230
231impl AdaptiveRecallTuner {
232 pub fn new(config: TunerConfig) -> Self {
234 Self {
235 config,
236 current_params: SearchParams::default(),
237 feedback_window: VecDeque::new(),
238 stats: TunerStats::default(),
239 integral_error: 0.0,
240 prev_error: 0.0,
241 query_count: 0,
242 }
243 }
244
245 pub fn with_initial_params(config: TunerConfig, initial: SearchParams) -> Self {
247 Self {
248 config,
249 current_params: initial,
250 feedback_window: VecDeque::new(),
251 stats: TunerStats::default(),
252 integral_error: 0.0,
253 prev_error: 0.0,
254 query_count: 0,
255 }
256 }
257
258 pub fn current_params(&self) -> &SearchParams {
260 &self.current_params
261 }
262
263 pub fn stats(&self) -> &TunerStats {
265 &self.stats
266 }
267
268 pub fn record_feedback(&mut self, feedback: QueryFeedback) -> bool {
272 self.feedback_window.push_back(feedback);
274 while self.feedback_window.len() > self.config.window_size {
275 self.feedback_window.pop_front();
276 }
277
278 self.stats.total_feedbacks += 1;
279 self.query_count += 1;
280
281 self.update_stats();
283
284 if self.feedback_window.len() >= self.config.min_samples
286 && self.query_count % self.config.adjust_interval as u64 == 0
287 {
288 self.adjust_parameters();
289 return true;
290 }
291
292 false
293 }
294
295 pub fn force_adjust(&mut self) {
297 if self.feedback_window.len() >= self.config.min_samples {
298 self.adjust_parameters();
299 }
300 }
301
302 pub fn reset(&mut self) {
304 self.current_params = SearchParams::default();
305 self.feedback_window.clear();
306 self.stats = TunerStats::default();
307 self.integral_error = 0.0;
308 self.prev_error = 0.0;
309 self.query_count = 0;
310 }
311
312 fn update_stats(&mut self) {
315 if self.feedback_window.is_empty() {
316 return;
317 }
318
319 let recalls: Vec<f64> = self
320 .feedback_window
321 .iter()
322 .map(|f| f.recall_at_k())
323 .collect();
324 let precisions: Vec<f64> = self
325 .feedback_window
326 .iter()
327 .map(|f| f.precision_at_k())
328 .collect();
329 let latencies: Vec<f64> = self
330 .feedback_window
331 .iter()
332 .map(|f| f.latency.as_millis() as f64)
333 .collect();
334
335 let n = recalls.len() as f64;
336 self.stats.current_recall = recalls.iter().sum::<f64>() / n;
337 self.stats.avg_precision = precisions.iter().sum::<f64>() / n;
338 self.stats.current_avg_latency_ms = latencies.iter().sum::<f64>() / n;
339 self.stats.target_met = self.stats.current_recall >= self.config.target_recall;
340
341 self.stats.recall_history.push(self.stats.current_recall);
343 if self.stats.recall_history.len() > 50 {
344 self.stats.recall_history.remove(0);
345 }
346 }
347
348 fn adjust_parameters(&mut self) {
349 let error = self.config.target_recall - self.stats.current_recall;
350
351 self.integral_error += error;
353 self.integral_error = self.integral_error.clamp(-10.0, 10.0);
355
356 let derivative = error - self.prev_error;
357 self.prev_error = error;
358
359 let adjustment = self.config.kp * error
360 + self.config.ki * self.integral_error
361 + self.config.kd * derivative;
362
363 let scale = 1.0 + adjustment;
368
369 self.current_params.ef_search =
370 ((self.current_params.ef_search as f64 * scale) as usize).max(8);
371 self.current_params.num_candidates =
372 ((self.current_params.num_candidates as f64 * scale) as usize).max(10);
373 self.current_params.over_fetch_ratio =
374 (self.current_params.over_fetch_ratio * scale).max(1.0);
375 self.current_params.rerank_depth =
376 ((self.current_params.rerank_depth as f64 * scale) as usize).max(1);
377
378 if self.stats.current_avg_latency_ms > self.config.max_latency.as_millis() as f64 {
380 let latency_ratio =
381 self.config.max_latency.as_millis() as f64 / self.stats.current_avg_latency_ms;
382 self.current_params.ef_search =
383 ((self.current_params.ef_search as f64 * latency_ratio) as usize).max(8);
384 self.current_params.num_candidates =
385 ((self.current_params.num_candidates as f64 * latency_ratio) as usize).max(10);
386 }
387
388 self.current_params.clamp();
389 self.stats.adjustments_made += 1;
390 }
391}
392
393impl fmt::Debug for AdaptiveRecallTuner {
394 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
395 f.debug_struct("AdaptiveRecallTuner")
396 .field("config", &self.config)
397 .field("current_params", &self.current_params)
398 .field("stats", &self.stats)
399 .finish()
400 }
401}
402
/// Stateless entry point for offline retrieval-quality metrics.
///
/// See [`RecallEvaluator::evaluate`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RecallEvaluator;
409
410#[derive(Debug, Clone, Serialize, Deserialize)]
412pub struct RecallEvaluation {
413 pub recall_at_k: f64,
415 pub precision_at_k: f64,
417 pub f1_score: f64,
419 pub average_precision: f64,
421 pub ndcg: f64,
423 pub num_queries: usize,
425}
426
427impl RecallEvaluator {
428 pub fn evaluate(
434 results: &[Vec<String>],
435 ground_truth: &[Vec<String>],
436 k: usize,
437 ) -> RecallEvaluation {
438 if results.is_empty() || ground_truth.is_empty() {
439 return RecallEvaluation {
440 recall_at_k: 0.0,
441 precision_at_k: 0.0,
442 f1_score: 0.0,
443 average_precision: 0.0,
444 ndcg: 0.0,
445 num_queries: 0,
446 };
447 }
448
449 let n = results.len().min(ground_truth.len());
450 let mut total_recall = 0.0;
451 let mut total_precision = 0.0;
452 let mut total_ap = 0.0;
453 let mut total_ndcg = 0.0;
454
455 for i in 0..n {
456 let result_k: Vec<_> = results[i].iter().take(k).cloned().collect();
457 let truth: std::collections::HashSet<_> = ground_truth[i].iter().cloned().collect();
458
459 if truth.is_empty() {
460 continue;
461 }
462
463 let relevant_found = result_k.iter().filter(|r| truth.contains(*r)).count();
465 let recall = relevant_found as f64 / truth.len() as f64;
466 total_recall += recall;
467
468 let precision = if result_k.is_empty() {
470 0.0
471 } else {
472 relevant_found as f64 / result_k.len() as f64
473 };
474 total_precision += precision;
475
476 let mut running_relevant = 0.0;
478 let mut ap_sum = 0.0;
479 for (pos, item) in result_k.iter().enumerate() {
480 if truth.contains(item) {
481 running_relevant += 1.0;
482 ap_sum += running_relevant / (pos + 1) as f64;
483 }
484 }
485 total_ap += if truth.is_empty() {
486 0.0
487 } else {
488 ap_sum / truth.len() as f64
489 };
490
491 let dcg: f64 = result_k
493 .iter()
494 .enumerate()
495 .map(|(pos, item)| {
496 let rel = if truth.contains(item) { 1.0 } else { 0.0 };
497 rel / ((pos + 2) as f64).ln()
498 })
499 .sum();
500 let ideal_k = truth.len().min(k);
501 let ideal_dcg: f64 = (0..ideal_k).map(|pos| 1.0 / ((pos + 2) as f64).ln()).sum();
502 let ndcg = if ideal_dcg > 0.0 {
503 dcg / ideal_dcg
504 } else {
505 0.0
506 };
507 total_ndcg += ndcg;
508 }
509
510 let n_f = n as f64;
511 let avg_recall = total_recall / n_f;
512 let avg_precision = total_precision / n_f;
513 let f1 = if (avg_recall + avg_precision) > 0.0 {
514 2.0 * avg_recall * avg_precision / (avg_recall + avg_precision)
515 } else {
516 0.0
517 };
518
519 RecallEvaluation {
520 recall_at_k: avg_recall,
521 precision_at_k: avg_precision,
522 f1_score: f1,
523 average_precision: total_ap / n_f,
524 ndcg: total_ndcg / n_f,
525 num_queries: n,
526 }
527 }
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    // ----- SearchParams -----

    #[test]
    fn test_search_params_default() {
        let p = SearchParams::default();
        assert_eq!(p.ef_search, 64);
        assert_eq!(p.num_candidates, 100);
        assert!((p.over_fetch_ratio - 2.0).abs() < 0.01);
        assert_eq!(p.rerank_depth, 50);
        assert!(p.early_termination);
    }

    #[test]
    fn test_search_params_high_recall() {
        let p = SearchParams::high_recall();
        assert!(p.ef_search >= 256);
        assert!(!p.early_termination);
    }

    #[test]
    fn test_search_params_low_latency() {
        let p = SearchParams::low_latency();
        assert!(p.ef_search <= 32);
        assert!(p.early_termination);
    }

    #[test]
    fn test_search_params_clamp() {
        let mut p = SearchParams {
            ef_search: 0,
            num_candidates: 0,
            over_fetch_ratio: 0.1,
            rerank_depth: 99999,
            early_termination: false,
        };
        p.clamp();
        assert_eq!(p.ef_search, 8);
        assert_eq!(p.num_candidates, 10);
        assert!((p.over_fetch_ratio - 1.0).abs() < 0.01);
        // rerank_depth is capped by the already-clamped num_candidates.
        assert_eq!(p.rerank_depth, 10);
    }

    // ----- QueryFeedback -----

    #[test]
    fn test_feedback_recall_with_ground_truth() {
        let fb = QueryFeedback {
            params: SearchParams::default(),
            k: 10,
            relevant_in_top_k: 8,
            total_relevant: Some(10),
            latency: Duration::from_millis(50),
            timestamp: Instant::now(),
        };
        assert!((fb.recall_at_k() - 0.8).abs() < 0.01);
    }

    #[test]
    fn test_feedback_recall_without_ground_truth() {
        let fb = QueryFeedback {
            params: SearchParams::default(),
            k: 10,
            relevant_in_top_k: 7,
            total_relevant: None,
            latency: Duration::from_millis(50),
            timestamp: Instant::now(),
        };
        assert!((fb.recall_at_k() - 0.7).abs() < 0.01);
    }

    #[test]
    fn test_feedback_precision() {
        let fb = QueryFeedback {
            params: SearchParams::default(),
            k: 10,
            relevant_in_top_k: 5,
            total_relevant: Some(20),
            latency: Duration::from_millis(50),
            timestamp: Instant::now(),
        };
        assert!((fb.precision_at_k() - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_feedback_k_zero() {
        let fb = QueryFeedback {
            params: SearchParams::default(),
            k: 0,
            relevant_in_top_k: 0,
            total_relevant: None,
            latency: Duration::from_millis(1),
            timestamp: Instant::now(),
        };
        assert_eq!(fb.recall_at_k(), 0.0);
        assert_eq!(fb.precision_at_k(), 0.0);
    }

    // ----- TunerConfig -----

    #[test]
    fn test_tuner_config_default() {
        let c = TunerConfig::default();
        assert!((c.target_recall - 0.95).abs() < 0.01);
        assert_eq!(c.window_size, 100);
        assert_eq!(c.min_samples, 10);
    }

    /// Feedback for a query returning `k` results, of which
    /// `recall_ratio * k` (truncated) are relevant, with exactly `k` relevant
    /// items in total.
    fn make_feedback(recall_ratio: f64, k: usize, latency_ms: u64) -> QueryFeedback {
        let relevant = (k as f64 * recall_ratio) as usize;
        QueryFeedback {
            params: SearchParams::default(),
            k,
            relevant_in_top_k: relevant,
            total_relevant: Some(k),
            latency: Duration::from_millis(latency_ms),
            timestamp: Instant::now(),
        }
    }

    // ----- AdaptiveRecallTuner -----

    #[test]
    fn test_tuner_initial_params() {
        let tuner = AdaptiveRecallTuner::new(TunerConfig::default());
        assert_eq!(tuner.current_params().ef_search, 64);
    }

    #[test]
    fn test_tuner_with_initial_params() {
        let initial = SearchParams::high_recall();
        let tuner =
            AdaptiveRecallTuner::with_initial_params(TunerConfig::default(), initial.clone());
        assert_eq!(tuner.current_params().ef_search, initial.ef_search);
    }

    #[test]
    fn test_tuner_no_adjust_before_min_samples() {
        let config = TunerConfig {
            min_samples: 10,
            adjust_interval: 1,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        for _ in 0..5 {
            let adjusted = tuner.record_feedback(make_feedback(0.5, 10, 50));
            assert!(!adjusted);
        }
        assert_eq!(tuner.stats().adjustments_made, 0);
    }

    #[test]
    fn test_tuner_adjusts_after_min_samples() {
        let config = TunerConfig {
            min_samples: 5,
            adjust_interval: 5,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        for i in 0..10 {
            let adjusted = tuner.record_feedback(make_feedback(0.5, 10, 50));
            // Due on every 5th query once 5 samples are present: i = 4 and i = 9.
            assert_eq!(
                adjusted,
                (i + 1) % 5 == 0,
                "unexpected adjustment state at query {}",
                i
            );
        }
        assert_eq!(tuner.stats().adjustments_made, 2);
    }

    #[test]
    fn test_tuner_increases_params_for_low_recall() {
        let config = TunerConfig {
            min_samples: 5,
            adjust_interval: 1,
            target_recall: 0.95,
            kp: 0.5,
            ki: 0.0,
            kd: 0.0,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);
        let initial_ef = tuner.current_params().ef_search;

        // Recall 0.3 is far below the 0.95 target.
        for _ in 0..10 {
            tuner.record_feedback(make_feedback(0.3, 10, 30));
        }

        assert!(tuner.current_params().ef_search > initial_ef);
    }

    #[test]
    fn test_tuner_decreases_params_for_high_recall() {
        let config = TunerConfig {
            min_samples: 5,
            adjust_interval: 1,
            target_recall: 0.5,
            kp: 0.5,
            ki: 0.0,
            kd: 0.0,
            ..Default::default()
        };
        let initial = SearchParams::high_recall();
        let mut tuner = AdaptiveRecallTuner::with_initial_params(config, initial.clone());

        // Recall 0.9 is well above the 0.5 target.
        for _ in 0..10 {
            tuner.record_feedback(make_feedback(0.99, 10, 30));
        }

        assert!(tuner.current_params().ef_search < initial.ef_search);
    }

    #[test]
    fn test_tuner_respects_latency_constraint() {
        let config = TunerConfig {
            min_samples: 5,
            adjust_interval: 1,
            max_latency: Duration::from_millis(50),
            target_recall: 0.99,
            kp: 1.0,
            ki: 0.0,
            kd: 0.0,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        // Low recall pushes the parameters up, but 200 ms against a 50 ms budget
        // must win and pull them below the defaults.
        for _ in 0..20 {
            tuner.record_feedback(make_feedback(0.3, 10, 200));
        }

        let defaults = SearchParams::default();
        assert!(tuner.current_params().ef_search < defaults.ef_search);
        assert!(tuner.current_params().num_candidates < defaults.num_candidates);
    }

    #[test]
    fn test_tuner_stats_tracking() {
        let config = TunerConfig {
            min_samples: 3,
            adjust_interval: 1,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        for _ in 0..5 {
            tuner.record_feedback(make_feedback(0.8, 10, 40));
        }

        assert_eq!(tuner.stats().total_feedbacks, 5);
        assert!(tuner.stats().current_recall > 0.0);
        assert!(tuner.stats().current_avg_latency_ms > 0.0);
    }

    #[test]
    fn test_tuner_window_limits_stats() {
        let config = TunerConfig {
            window_size: 3,
            min_samples: 1000,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        for _ in 0..3 {
            tuner.record_feedback(make_feedback(0.2, 10, 40));
        }
        for _ in 0..3 {
            tuner.record_feedback(make_feedback(0.8, 10, 40));
        }

        // Only the last three (recall 0.8) samples remain in the window.
        assert!((tuner.stats().current_recall - 0.8).abs() < 1e-9);
        assert_eq!(tuner.stats().total_feedbacks, 6);
    }

    #[test]
    fn test_tuner_reset() {
        let config = TunerConfig {
            min_samples: 3,
            adjust_interval: 1,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        for _ in 0..5 {
            tuner.record_feedback(make_feedback(0.8, 10, 40));
        }
        tuner.reset();

        assert_eq!(tuner.stats().total_feedbacks, 0);
        assert_eq!(tuner.stats().adjustments_made, 0);
        assert_eq!(tuner.current_params().ef_search, 64);
    }

    #[test]
    fn test_tuner_force_adjust() {
        let config = TunerConfig {
            min_samples: 3,
            // Large interval so no automatic adjustment happens within 5 queries.
            adjust_interval: 100,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        for _ in 0..5 {
            tuner.record_feedback(make_feedback(0.5, 10, 40));
        }
        assert_eq!(tuner.stats().adjustments_made, 0);

        tuner.force_adjust();
        assert_eq!(tuner.stats().adjustments_made, 1);
    }

    #[test]
    fn test_stats_near_target() {
        let stats = TunerStats {
            current_recall: 0.94,
            ..Default::default()
        };
        assert!(stats.is_near_target(0.95, 0.02));
        assert!(!stats.is_near_target(0.95, 0.005));
    }

    #[test]
    fn test_recall_history() {
        let config = TunerConfig {
            min_samples: 3,
            adjust_interval: 1,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        for _ in 0..5 {
            tuner.record_feedback(make_feedback(0.8, 10, 40));
        }

        assert!(!tuner.stats().recall_history.is_empty());
    }

    #[test]
    fn test_recall_history_is_bounded() {
        let config = TunerConfig {
            min_samples: 1000,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        for _ in 0..60 {
            tuner.record_feedback(make_feedback(0.8, 10, 40));
        }

        assert_eq!(tuner.stats().recall_history.len(), 50);
    }

    // ----- RecallEvaluator -----

    #[test]
    fn test_evaluator_perfect_recall() {
        let results = vec![vec!["a".to_string(), "b".to_string(), "c".to_string()]];
        let truth = vec![vec!["a".to_string(), "b".to_string(), "c".to_string()]];
        let eval = RecallEvaluator::evaluate(&results, &truth, 3);
        assert!((eval.recall_at_k - 1.0).abs() < 0.01);
        assert!((eval.precision_at_k - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_evaluator_partial_recall() {
        let results = vec![vec!["a".to_string(), "b".to_string(), "x".to_string()]];
        let truth = vec![vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
            "d".to_string(),
        ]];
        let eval = RecallEvaluator::evaluate(&results, &truth, 3);
        // 2 of 4 relevant items found; 2 of 3 results relevant.
        assert!((eval.recall_at_k - 0.5).abs() < 0.01);
        assert!((eval.precision_at_k - 2.0 / 3.0).abs() < 0.01);
    }

    #[test]
    fn test_evaluator_zero_recall() {
        let results = vec![vec!["x".to_string(), "y".to_string(), "z".to_string()]];
        let truth = vec![vec!["a".to_string(), "b".to_string(), "c".to_string()]];
        let eval = RecallEvaluator::evaluate(&results, &truth, 3);
        assert_eq!(eval.recall_at_k, 0.0);
        assert_eq!(eval.precision_at_k, 0.0);
    }

    #[test]
    fn test_evaluator_empty() {
        let eval = RecallEvaluator::evaluate(&[], &[], 10);
        assert_eq!(eval.num_queries, 0);
        assert_eq!(eval.recall_at_k, 0.0);
    }

    #[test]
    fn test_evaluator_multiple_queries() {
        let results = vec![
            vec!["a".to_string(), "b".to_string()],
            vec!["c".to_string(), "d".to_string()],
        ];
        let truth = vec![
            vec!["a".to_string(), "b".to_string()],
            vec!["c".to_string(), "x".to_string()],
        ];
        let eval = RecallEvaluator::evaluate(&results, &truth, 2);
        assert_eq!(eval.num_queries, 2);
        // Mean of per-query recalls 1.0 and 0.5.
        assert!((eval.recall_at_k - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_evaluator_k_less_than_results() {
        let results = vec![vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
            "d".to_string(),
        ]];
        let truth = vec![vec!["a".to_string(), "b".to_string()]];
        let eval = RecallEvaluator::evaluate(&results, &truth, 2);
        // Only the first two results are considered, and both are relevant.
        assert!((eval.recall_at_k - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_evaluator_ndcg() {
        let results = vec![vec!["a".to_string(), "b".to_string(), "c".to_string()]];
        let truth = vec![vec!["a".to_string(), "b".to_string(), "c".to_string()]];
        let eval = RecallEvaluator::evaluate(&results, &truth, 3);
        // A perfect ranking has nDCG 1.
        assert!((eval.ndcg - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_evaluator_f1_score() {
        let results = vec![vec!["a".to_string(), "b".to_string(), "c".to_string()]];
        let truth = vec![vec!["a".to_string(), "b".to_string(), "c".to_string()]];
        let eval = RecallEvaluator::evaluate(&results, &truth, 3);
        // Perfect precision and recall give F1 = 1.
        assert!((eval.f1_score - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_evaluator_average_precision() {
        let results = vec![vec!["a".to_string(), "x".to_string(), "b".to_string()]];
        let truth = vec![vec!["a".to_string(), "b".to_string()]];
        let eval = RecallEvaluator::evaluate(&results, &truth, 3);
        // Hits at ranks 1 and 3: (1/1 + 2/3) / 2.
        assert!((eval.average_precision - 5.0 / 6.0).abs() < 1e-9);
    }

    // ----- Integration -----

    #[test]
    fn test_tuner_convergence_simulation() {
        let config = TunerConfig {
            min_samples: 5,
            adjust_interval: 1,
            target_recall: 0.9,
            kp: 0.3,
            ki: 0.05,
            kd: 0.02,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        // Recall improves steadily from 0.5 toward 0.95.
        for i in 0..50 {
            let recall = 0.5 + (i as f64 * 0.01).min(0.45);
            tuner.record_feedback(make_feedback(recall, 10, 30));
        }

        assert!(tuner.stats().adjustments_made > 0);
        assert_eq!(tuner.stats().total_feedbacks, 50);
    }

    #[test]
    fn test_integral_windup_prevention() {
        let config = TunerConfig {
            min_samples: 3,
            adjust_interval: 1,
            target_recall: 0.99,
            kp: 0.1,
            // Large integral gain to provoke wind-up.
            ki: 1.0,
            kd: 0.0,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        // Persistent very low recall drives the integral term hard.
        for _ in 0..100 {
            tuner.record_feedback(make_feedback(0.1, 10, 20));
        }

        // Clamping must keep the parameters inside their supported ranges.
        assert!(tuner.current_params().ef_search <= 1024);
        assert!(tuner.current_params().num_candidates <= 5000);
    }
}