Skip to main content

oxirs_vec/
adaptive_recall_tuner.rs

//! # Adaptive Recall Tuner for Vector Search
//!
//! Dynamically adjusts vector search parameters (ef_search, num_candidates,
//! re-ranking depth, over-fetch ratio) to achieve target recall@k while
//! minimizing latency.
//!
//! ## Problem
//!
//! Fixed search parameters lead to either:
//! - Under-fetching: Missing relevant results (low recall)
//! - Over-fetching: Wasting compute on unnecessary candidates (high latency)
//!
//! ## Solution
//!
//! The adaptive tuner observes query feedback (user clicks, explicit relevance
//! judgments, or ground-truth evaluations) and uses a PID-like control loop
//! to adjust parameters in real-time.
//!
//! ## Architecture
//!
//! ```text
//! Query --> SearchEngine --> Results --> FeedbackCollector
//!   ^                                         |
//!   |                                         v
//!   +--- ParameterAdjuster <--- RecallEstimator
//! ```
27
28use serde::{Deserialize, Serialize};
29use std::collections::VecDeque;
30use std::fmt;
31use std::time::{Duration, Instant};
32
33// ---------------------------------------------------------------------------
34// Search parameters
35// ---------------------------------------------------------------------------
36
37/// Tunable search parameters for vector index queries.
38#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
39pub struct SearchParams {
40    /// HNSW ef_search: number of candidates to explore during search.
41    pub ef_search: usize,
42    /// Number of candidates to pre-fetch before re-ranking.
43    pub num_candidates: usize,
44    /// Over-fetch ratio: fetch this many times more candidates than k.
45    pub over_fetch_ratio: f64,
46    /// Re-ranking depth: how many candidates to re-rank with exact distances.
47    pub rerank_depth: usize,
48    /// Whether to enable approximate early termination.
49    pub early_termination: bool,
50}
51
52impl Default for SearchParams {
53    fn default() -> Self {
54        Self {
55            ef_search: 64,
56            num_candidates: 100,
57            over_fetch_ratio: 2.0,
58            rerank_depth: 50,
59            early_termination: true,
60        }
61    }
62}
63
64impl SearchParams {
65    /// Create parameters optimized for high recall.
66    pub fn high_recall() -> Self {
67        Self {
68            ef_search: 256,
69            num_candidates: 500,
70            over_fetch_ratio: 5.0,
71            rerank_depth: 200,
72            early_termination: false,
73        }
74    }
75
76    /// Create parameters optimized for low latency.
77    pub fn low_latency() -> Self {
78        Self {
79            ef_search: 32,
80            num_candidates: 50,
81            over_fetch_ratio: 1.5,
82            rerank_depth: 20,
83            early_termination: true,
84        }
85    }
86
87    /// Clamp all parameters to valid ranges.
88    pub fn clamp(&mut self) {
89        self.ef_search = self.ef_search.clamp(8, 1024);
90        self.num_candidates = self.num_candidates.clamp(10, 5000);
91        self.over_fetch_ratio = self.over_fetch_ratio.clamp(1.0, 20.0);
92        self.rerank_depth = self.rerank_depth.clamp(0, self.num_candidates);
93    }
94}
95
96// ---------------------------------------------------------------------------
97// Feedback
98// ---------------------------------------------------------------------------
99
100/// Feedback from a single query execution.
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct QueryFeedback {
103    /// The parameters used for this query.
104    pub params: SearchParams,
105    /// Number of results requested (k).
106    pub k: usize,
107    /// Number of truly relevant results in the top-k (from ground truth or clicks).
108    pub relevant_in_top_k: usize,
109    /// Total number of known relevant results (if available).
110    pub total_relevant: Option<usize>,
111    /// Query latency.
112    pub latency: Duration,
113    /// Timestamp of the feedback.
114    #[serde(skip, default = "std::time::Instant::now")]
115    pub timestamp: Instant,
116}
117
118impl QueryFeedback {
119    /// Compute recall@k for this feedback.
120    pub fn recall_at_k(&self) -> f64 {
121        match self.total_relevant {
122            Some(total) if total > 0 => self.relevant_in_top_k as f64 / total as f64,
123            _ => {
124                // If total_relevant is unknown, estimate from k
125                if self.k == 0 {
126                    return 0.0;
127                }
128                self.relevant_in_top_k as f64 / self.k as f64
129            }
130        }
131    }
132
133    /// Compute precision@k.
134    pub fn precision_at_k(&self) -> f64 {
135        if self.k == 0 {
136            return 0.0;
137        }
138        self.relevant_in_top_k as f64 / self.k as f64
139    }
140}
141
142// ---------------------------------------------------------------------------
143// Tuner configuration
144// ---------------------------------------------------------------------------
145
146/// Configuration for the adaptive recall tuner.
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct TunerConfig {
149    /// Target recall@k (e.g., 0.95 for 95%).
150    pub target_recall: f64,
151    /// Maximum acceptable latency.
152    pub max_latency: Duration,
153    /// Number of recent feedback samples to consider.
154    pub window_size: usize,
155    /// Proportional gain for the control loop.
156    pub kp: f64,
157    /// Integral gain for the control loop.
158    pub ki: f64,
159    /// Derivative gain for the control loop.
160    pub kd: f64,
161    /// Minimum number of samples before adjusting.
162    pub min_samples: usize,
163    /// How often to recalculate (in number of queries).
164    pub adjust_interval: usize,
165}
166
167impl Default for TunerConfig {
168    fn default() -> Self {
169        Self {
170            target_recall: 0.95,
171            max_latency: Duration::from_millis(100),
172            window_size: 100,
173            kp: 0.5,
174            ki: 0.1,
175            kd: 0.05,
176            min_samples: 10,
177            adjust_interval: 5,
178        }
179    }
180}
181
182// ---------------------------------------------------------------------------
183// Tuner statistics
184// ---------------------------------------------------------------------------
185
186/// Statistics from the adaptive tuner.
187#[derive(Debug, Clone, Default, Serialize, Deserialize)]
188pub struct TunerStats {
189    /// Total number of feedback samples received.
190    pub total_feedbacks: u64,
191    /// Number of parameter adjustments made.
192    pub adjustments_made: u64,
193    /// Current estimated recall.
194    pub current_recall: f64,
195    /// Current average latency in milliseconds.
196    pub current_avg_latency_ms: f64,
197    /// Whether the tuner is currently meeting its target.
198    pub target_met: bool,
199    /// Running average precision.
200    pub avg_precision: f64,
201    /// Historical recall values (last N).
202    pub recall_history: Vec<f64>,
203}
204
205impl TunerStats {
206    /// Check if recall is within tolerance of target.
207    pub fn is_near_target(&self, target: f64, tolerance: f64) -> bool {
208        (self.current_recall - target).abs() < tolerance
209    }
210}
211
212// ---------------------------------------------------------------------------
213// Adaptive Recall Tuner
214// ---------------------------------------------------------------------------
215
216/// The main adaptive recall tuner.
217///
218/// Collects query feedback and adjusts search parameters to converge on
219/// the target recall while respecting latency constraints.
220pub struct AdaptiveRecallTuner {
221    config: TunerConfig,
222    current_params: SearchParams,
223    feedback_window: VecDeque<QueryFeedback>,
224    stats: TunerStats,
225    /// PID controller state
226    integral_error: f64,
227    prev_error: f64,
228    query_count: u64,
229}
230
231impl AdaptiveRecallTuner {
232    /// Create a new tuner with default parameters and configuration.
233    pub fn new(config: TunerConfig) -> Self {
234        Self {
235            config,
236            current_params: SearchParams::default(),
237            feedback_window: VecDeque::new(),
238            stats: TunerStats::default(),
239            integral_error: 0.0,
240            prev_error: 0.0,
241            query_count: 0,
242        }
243    }
244
245    /// Create with specific initial parameters.
246    pub fn with_initial_params(config: TunerConfig, initial: SearchParams) -> Self {
247        Self {
248            config,
249            current_params: initial,
250            feedback_window: VecDeque::new(),
251            stats: TunerStats::default(),
252            integral_error: 0.0,
253            prev_error: 0.0,
254            query_count: 0,
255        }
256    }
257
258    /// Get the current recommended search parameters.
259    pub fn current_params(&self) -> &SearchParams {
260        &self.current_params
261    }
262
263    /// Get tuner statistics.
264    pub fn stats(&self) -> &TunerStats {
265        &self.stats
266    }
267
268    /// Record a query feedback observation.
269    ///
270    /// Returns `true` if parameters were adjusted as a result.
271    pub fn record_feedback(&mut self, feedback: QueryFeedback) -> bool {
272        // Add to window
273        self.feedback_window.push_back(feedback);
274        while self.feedback_window.len() > self.config.window_size {
275            self.feedback_window.pop_front();
276        }
277
278        self.stats.total_feedbacks += 1;
279        self.query_count += 1;
280
281        // Update running statistics
282        self.update_stats();
283
284        // Check if we should adjust
285        if self.feedback_window.len() >= self.config.min_samples
286            && self.query_count % self.config.adjust_interval as u64 == 0
287        {
288            self.adjust_parameters();
289            return true;
290        }
291
292        false
293    }
294
295    /// Force a parameter adjustment regardless of interval.
296    pub fn force_adjust(&mut self) {
297        if self.feedback_window.len() >= self.config.min_samples {
298            self.adjust_parameters();
299        }
300    }
301
302    /// Reset the tuner state (keeps configuration).
303    pub fn reset(&mut self) {
304        self.current_params = SearchParams::default();
305        self.feedback_window.clear();
306        self.stats = TunerStats::default();
307        self.integral_error = 0.0;
308        self.prev_error = 0.0;
309        self.query_count = 0;
310    }
311
312    // ── Internal methods ──────────────────────────────────────────────────
313
314    fn update_stats(&mut self) {
315        if self.feedback_window.is_empty() {
316            return;
317        }
318
319        let recalls: Vec<f64> = self
320            .feedback_window
321            .iter()
322            .map(|f| f.recall_at_k())
323            .collect();
324        let precisions: Vec<f64> = self
325            .feedback_window
326            .iter()
327            .map(|f| f.precision_at_k())
328            .collect();
329        let latencies: Vec<f64> = self
330            .feedback_window
331            .iter()
332            .map(|f| f.latency.as_millis() as f64)
333            .collect();
334
335        let n = recalls.len() as f64;
336        self.stats.current_recall = recalls.iter().sum::<f64>() / n;
337        self.stats.avg_precision = precisions.iter().sum::<f64>() / n;
338        self.stats.current_avg_latency_ms = latencies.iter().sum::<f64>() / n;
339        self.stats.target_met = self.stats.current_recall >= self.config.target_recall;
340
341        // Record recall history (keep last 50)
342        self.stats.recall_history.push(self.stats.current_recall);
343        if self.stats.recall_history.len() > 50 {
344            self.stats.recall_history.remove(0);
345        }
346    }
347
348    fn adjust_parameters(&mut self) {
349        let error = self.config.target_recall - self.stats.current_recall;
350
351        // PID control
352        self.integral_error += error;
353        // Clamp integral to prevent windup
354        self.integral_error = self.integral_error.clamp(-10.0, 10.0);
355
356        let derivative = error - self.prev_error;
357        self.prev_error = error;
358
359        let adjustment = self.config.kp * error
360            + self.config.ki * self.integral_error
361            + self.config.kd * derivative;
362
363        // Apply adjustment to parameters
364        // Positive adjustment means recall is too low -> increase search effort
365        // Negative adjustment means recall is high enough -> can reduce effort
366
367        let scale = 1.0 + adjustment;
368
369        self.current_params.ef_search =
370            ((self.current_params.ef_search as f64 * scale) as usize).max(8);
371        self.current_params.num_candidates =
372            ((self.current_params.num_candidates as f64 * scale) as usize).max(10);
373        self.current_params.over_fetch_ratio =
374            (self.current_params.over_fetch_ratio * scale).max(1.0);
375        self.current_params.rerank_depth =
376            ((self.current_params.rerank_depth as f64 * scale) as usize).max(1);
377
378        // If latency is too high, pull back
379        if self.stats.current_avg_latency_ms > self.config.max_latency.as_millis() as f64 {
380            let latency_ratio =
381                self.config.max_latency.as_millis() as f64 / self.stats.current_avg_latency_ms;
382            self.current_params.ef_search =
383                ((self.current_params.ef_search as f64 * latency_ratio) as usize).max(8);
384            self.current_params.num_candidates =
385                ((self.current_params.num_candidates as f64 * latency_ratio) as usize).max(10);
386        }
387
388        self.current_params.clamp();
389        self.stats.adjustments_made += 1;
390    }
391}
392
393impl fmt::Debug for AdaptiveRecallTuner {
394    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
395        f.debug_struct("AdaptiveRecallTuner")
396            .field("config", &self.config)
397            .field("current_params", &self.current_params)
398            .field("stats", &self.stats)
399            .finish()
400    }
401}
402
403// ---------------------------------------------------------------------------
404// Recall evaluator (ground-truth based)
405// ---------------------------------------------------------------------------
406
407/// Evaluates recall by comparing search results against ground-truth.
408pub struct RecallEvaluator;
409
410/// A single recall evaluation result.
411#[derive(Debug, Clone, Serialize, Deserialize)]
412pub struct RecallEvaluation {
413    /// Recall at the specified k.
414    pub recall_at_k: f64,
415    /// Precision at the specified k.
416    pub precision_at_k: f64,
417    /// F1 score.
418    pub f1_score: f64,
419    /// Average precision (AP).
420    pub average_precision: f64,
421    /// Normalized discounted cumulative gain (nDCG).
422    pub ndcg: f64,
423    /// Number of queries evaluated.
424    pub num_queries: usize,
425}
426
427impl RecallEvaluator {
428    /// Evaluate recall for a set of query results against ground truth.
429    ///
430    /// - `results`: For each query, the IDs returned by the search engine.
431    /// - `ground_truth`: For each query, the set of truly relevant IDs.
432    /// - `k`: The cutoff for evaluation.
433    pub fn evaluate(
434        results: &[Vec<String>],
435        ground_truth: &[Vec<String>],
436        k: usize,
437    ) -> RecallEvaluation {
438        if results.is_empty() || ground_truth.is_empty() {
439            return RecallEvaluation {
440                recall_at_k: 0.0,
441                precision_at_k: 0.0,
442                f1_score: 0.0,
443                average_precision: 0.0,
444                ndcg: 0.0,
445                num_queries: 0,
446            };
447        }
448
449        let n = results.len().min(ground_truth.len());
450        let mut total_recall = 0.0;
451        let mut total_precision = 0.0;
452        let mut total_ap = 0.0;
453        let mut total_ndcg = 0.0;
454
455        for i in 0..n {
456            let result_k: Vec<_> = results[i].iter().take(k).cloned().collect();
457            let truth: std::collections::HashSet<_> = ground_truth[i].iter().cloned().collect();
458
459            if truth.is_empty() {
460                continue;
461            }
462
463            // Recall@k
464            let relevant_found = result_k.iter().filter(|r| truth.contains(*r)).count();
465            let recall = relevant_found as f64 / truth.len() as f64;
466            total_recall += recall;
467
468            // Precision@k
469            let precision = if result_k.is_empty() {
470                0.0
471            } else {
472                relevant_found as f64 / result_k.len() as f64
473            };
474            total_precision += precision;
475
476            // Average Precision (AP)
477            let mut running_relevant = 0.0;
478            let mut ap_sum = 0.0;
479            for (pos, item) in result_k.iter().enumerate() {
480                if truth.contains(item) {
481                    running_relevant += 1.0;
482                    ap_sum += running_relevant / (pos + 1) as f64;
483                }
484            }
485            total_ap += if truth.is_empty() {
486                0.0
487            } else {
488                ap_sum / truth.len() as f64
489            };
490
491            // nDCG
492            let dcg: f64 = result_k
493                .iter()
494                .enumerate()
495                .map(|(pos, item)| {
496                    let rel = if truth.contains(item) { 1.0 } else { 0.0 };
497                    rel / ((pos + 2) as f64).ln()
498                })
499                .sum();
500            let ideal_k = truth.len().min(k);
501            let ideal_dcg: f64 = (0..ideal_k).map(|pos| 1.0 / ((pos + 2) as f64).ln()).sum();
502            let ndcg = if ideal_dcg > 0.0 {
503                dcg / ideal_dcg
504            } else {
505                0.0
506            };
507            total_ndcg += ndcg;
508        }
509
510        let n_f = n as f64;
511        let avg_recall = total_recall / n_f;
512        let avg_precision = total_precision / n_f;
513        let f1 = if (avg_recall + avg_precision) > 0.0 {
514            2.0 * avg_recall * avg_precision / (avg_recall + avg_precision)
515        } else {
516            0.0
517        };
518
519        RecallEvaluation {
520            recall_at_k: avg_recall,
521            precision_at_k: avg_precision,
522            f1_score: f1,
523            average_precision: total_ap / n_f,
524            ndcg: total_ndcg / n_f,
525            num_queries: n,
526        }
527    }
528}
529
530// ---------------------------------------------------------------------------
531// Tests
532// ---------------------------------------------------------------------------
533
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used for floating-point comparisons.
    const EPS: f64 = 0.01;

    /// Builds an owned ID list from string literals.
    fn ids(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Builds a feedback sample with default params and a "now" timestamp.
    fn feedback(
        k: usize,
        relevant_in_top_k: usize,
        total_relevant: Option<usize>,
        latency_ms: u64,
    ) -> QueryFeedback {
        QueryFeedback {
            params: SearchParams::default(),
            k,
            relevant_in_top_k,
            total_relevant,
            latency: Duration::from_millis(latency_ms),
            timestamp: Instant::now(),
        }
    }

    /// Builds a feedback sample whose recall is approximately `recall_ratio`
    /// (the relevant count is truncated to an integer).
    fn make_feedback(recall_ratio: f64, k: usize, latency_ms: u64) -> QueryFeedback {
        let relevant = (k as f64 * recall_ratio) as usize;
        feedback(k, relevant, Some(k), latency_ms)
    }

    /// Feeds `count` identical samples into the tuner.
    fn feed(tuner: &mut AdaptiveRecallTuner, count: usize, recall: f64, latency_ms: u64) {
        for _ in 0..count {
            tuner.record_feedback(make_feedback(recall, 10, latency_ms));
        }
    }

    // ── SearchParams tests ────────────────────────────────────────────────

    #[test]
    fn test_search_params_default() {
        let p = SearchParams::default();
        assert_eq!(p.ef_search, 64);
        assert_eq!(p.num_candidates, 100);
        assert!((p.over_fetch_ratio - 2.0).abs() < EPS);
        assert_eq!(p.rerank_depth, 50);
        assert!(p.early_termination);
    }

    #[test]
    fn test_search_params_high_recall() {
        let p = SearchParams::high_recall();
        assert!(p.ef_search >= 256);
        assert!(!p.early_termination);
    }

    #[test]
    fn test_search_params_low_latency() {
        let p = SearchParams::low_latency();
        assert!(p.ef_search <= 32);
        assert!(p.early_termination);
    }

    #[test]
    fn test_search_params_clamp() {
        let mut p = SearchParams {
            ef_search: 0,
            num_candidates: 0,
            over_fetch_ratio: 0.1,
            rerank_depth: 99999,
            early_termination: false,
        };
        p.clamp();
        assert_eq!(p.ef_search, 8);
        assert_eq!(p.num_candidates, 10);
        assert!((p.over_fetch_ratio - 1.0).abs() < EPS);
        assert_eq!(p.rerank_depth, 10); // clamped to num_candidates
    }

    // ── QueryFeedback tests ───────────────────────────────────────────────

    #[test]
    fn test_feedback_recall_with_ground_truth() {
        let fb = feedback(10, 8, Some(10), 50);
        assert!((fb.recall_at_k() - 0.8).abs() < EPS);
    }

    #[test]
    fn test_feedback_recall_without_ground_truth() {
        let fb = feedback(10, 7, None, 50);
        assert!((fb.recall_at_k() - 0.7).abs() < EPS);
    }

    #[test]
    fn test_feedback_precision() {
        let fb = feedback(10, 5, Some(20), 50);
        assert!((fb.precision_at_k() - 0.5).abs() < EPS);
    }

    #[test]
    fn test_feedback_k_zero() {
        let fb = feedback(0, 0, None, 1);
        assert_eq!(fb.recall_at_k(), 0.0);
        assert_eq!(fb.precision_at_k(), 0.0);
    }

    // ── TunerConfig tests ─────────────────────────────────────────────────

    #[test]
    fn test_tuner_config_default() {
        let c = TunerConfig::default();
        assert!((c.target_recall - 0.95).abs() < EPS);
        assert_eq!(c.window_size, 100);
        assert_eq!(c.min_samples, 10);
    }

    // ── Tuner core tests ──────────────────────────────────────────────────

    #[test]
    fn test_tuner_initial_params() {
        let tuner = AdaptiveRecallTuner::new(TunerConfig::default());
        assert_eq!(tuner.current_params().ef_search, 64);
    }

    #[test]
    fn test_tuner_with_initial_params() {
        let initial = SearchParams::high_recall();
        let tuner =
            AdaptiveRecallTuner::with_initial_params(TunerConfig::default(), initial.clone());
        assert_eq!(tuner.current_params(), &initial);
    }

    #[test]
    fn test_tuner_no_adjust_before_min_samples() {
        let config = TunerConfig {
            min_samples: 10,
            adjust_interval: 1,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        for _ in 0..5 {
            let adjusted = tuner.record_feedback(make_feedback(0.5, 10, 50));
            assert!(!adjusted);
        }
        assert_eq!(tuner.stats().adjustments_made, 0);
    }

    #[test]
    fn test_tuner_adjusts_after_min_samples() {
        let config = TunerConfig {
            min_samples: 5,
            adjust_interval: 5,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        for i in 0..10 {
            let adjusted = tuner.record_feedback(make_feedback(0.5, 10, 50));
            // Adjustments land exactly on every 5th sample (5th and 10th).
            assert_eq!(adjusted, (i + 1) % 5 == 0);
        }
        assert_eq!(tuner.stats().adjustments_made, 2);
    }

    #[test]
    fn test_tuner_increases_params_for_low_recall() {
        let config = TunerConfig {
            min_samples: 5,
            adjust_interval: 1,
            target_recall: 0.95,
            kp: 0.5,
            ki: 0.0,
            kd: 0.0,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);
        let initial_ef = tuner.current_params().ef_search;

        // Feed low recall data
        feed(&mut tuner, 10, 0.3, 30);

        // Parameters should have increased
        assert!(tuner.current_params().ef_search > initial_ef);
    }

    #[test]
    fn test_tuner_decreases_params_for_high_recall() {
        let config = TunerConfig {
            min_samples: 5,
            adjust_interval: 1,
            target_recall: 0.5,
            kp: 0.5,
            ki: 0.0,
            kd: 0.0,
            ..Default::default()
        };
        let initial = SearchParams::high_recall();
        let mut tuner = AdaptiveRecallTuner::with_initial_params(config, initial.clone());

        // Feed high recall data (above target)
        feed(&mut tuner, 10, 0.99, 30);

        // Parameters should have decreased (or stayed) since recall exceeds target
        assert!(tuner.current_params().ef_search <= initial.ef_search);
    }

    #[test]
    fn test_tuner_respects_latency_constraint() {
        let config = TunerConfig {
            min_samples: 5,
            adjust_interval: 1,
            max_latency: Duration::from_millis(50),
            target_recall: 0.99,
            kp: 1.0,
            ki: 0.0,
            kd: 0.0,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        // Feed data with high latency -> tuner should pull back despite low recall
        feed(&mut tuner, 20, 0.3, 200);

        // ef_search should not grow unboundedly
        assert!(tuner.current_params().ef_search < 1024);
    }

    #[test]
    fn test_tuner_stats_tracking() {
        let config = TunerConfig {
            min_samples: 3,
            adjust_interval: 1,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        feed(&mut tuner, 5, 0.8, 40);

        assert_eq!(tuner.stats().total_feedbacks, 5);
        assert!((tuner.stats().current_recall - 0.8).abs() < EPS);
        assert!((tuner.stats().current_avg_latency_ms - 40.0).abs() < EPS);
    }

    #[test]
    fn test_tuner_reset() {
        let config = TunerConfig {
            min_samples: 3,
            adjust_interval: 1,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        feed(&mut tuner, 5, 0.8, 40);
        tuner.reset();

        assert_eq!(tuner.stats().total_feedbacks, 0);
        assert_eq!(tuner.stats().adjustments_made, 0);
        assert!(tuner.stats().recall_history.is_empty());
        assert_eq!(tuner.current_params().ef_search, 64);
    }

    #[test]
    fn test_tuner_force_adjust() {
        let config = TunerConfig {
            min_samples: 3,
            adjust_interval: 100, // Very high interval
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        feed(&mut tuner, 5, 0.5, 40);
        assert_eq!(tuner.stats().adjustments_made, 0);

        tuner.force_adjust();
        assert_eq!(tuner.stats().adjustments_made, 1);
    }

    #[test]
    fn test_stats_near_target() {
        let stats = TunerStats {
            current_recall: 0.94,
            ..Default::default()
        };
        assert!(stats.is_near_target(0.95, 0.02));
        assert!(!stats.is_near_target(0.95, 0.005));
    }

    #[test]
    fn test_recall_history() {
        let config = TunerConfig {
            min_samples: 3,
            adjust_interval: 1,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        feed(&mut tuner, 5, 0.8, 40);

        // One history entry is appended per recorded feedback.
        assert_eq!(tuner.stats().recall_history.len(), 5);
    }

    // ── RecallEvaluator tests ─────────────────────────────────────────────

    #[test]
    fn test_evaluator_perfect_recall() {
        let results = vec![ids(&["a", "b", "c"])];
        let truth = vec![ids(&["a", "b", "c"])];
        let eval = RecallEvaluator::evaluate(&results, &truth, 3);
        assert!((eval.recall_at_k - 1.0).abs() < EPS);
        assert!((eval.precision_at_k - 1.0).abs() < EPS);
    }

    #[test]
    fn test_evaluator_partial_recall() {
        let results = vec![ids(&["a", "b", "x"])];
        let truth = vec![ids(&["a", "b", "c", "d"])];
        let eval = RecallEvaluator::evaluate(&results, &truth, 3);
        assert!((eval.recall_at_k - 0.5).abs() < EPS); // 2/4
        assert!((eval.precision_at_k - 2.0 / 3.0).abs() < EPS); // 2/3
    }

    #[test]
    fn test_evaluator_zero_recall() {
        let results = vec![ids(&["x", "y", "z"])];
        let truth = vec![ids(&["a", "b", "c"])];
        let eval = RecallEvaluator::evaluate(&results, &truth, 3);
        assert_eq!(eval.recall_at_k, 0.0);
        assert_eq!(eval.precision_at_k, 0.0);
    }

    #[test]
    fn test_evaluator_empty() {
        let eval = RecallEvaluator::evaluate(&[], &[], 10);
        assert_eq!(eval.num_queries, 0);
        assert_eq!(eval.recall_at_k, 0.0);
    }

    #[test]
    fn test_evaluator_multiple_queries() {
        let results = vec![ids(&["a", "b"]), ids(&["c", "d"])];
        let truth = vec![ids(&["a", "b"]), ids(&["c", "x"])];
        let eval = RecallEvaluator::evaluate(&results, &truth, 2);
        assert_eq!(eval.num_queries, 2);
        // Query 1: recall=1.0, Query 2: recall=0.5 -> avg=0.75
        assert!((eval.recall_at_k - 0.75).abs() < EPS);
    }

    #[test]
    fn test_evaluator_k_less_than_results() {
        let results = vec![ids(&["a", "b", "c", "d"])];
        let truth = vec![ids(&["a", "b"])];
        let eval = RecallEvaluator::evaluate(&results, &truth, 2);
        // Only top-2 considered: a, b — both relevant
        assert!((eval.recall_at_k - 1.0).abs() < EPS);
    }

    #[test]
    fn test_evaluator_ndcg() {
        let results = vec![ids(&["a", "b", "c"])];
        let truth = vec![ids(&["a", "b", "c"])];
        let eval = RecallEvaluator::evaluate(&results, &truth, 3);
        // Perfect ranking -> nDCG should be 1.0
        assert!((eval.ndcg - 1.0).abs() < EPS);
    }

    #[test]
    fn test_evaluator_f1_score() {
        let results = vec![ids(&["a", "b", "c"])];
        let truth = vec![ids(&["a", "b", "c"])];
        let eval = RecallEvaluator::evaluate(&results, &truth, 3);
        // P=1.0, R=1.0 -> F1=1.0
        assert!((eval.f1_score - 1.0).abs() < EPS);
    }

    #[test]
    fn test_evaluator_average_precision() {
        let results = vec![ids(&["a", "x", "b"])];
        let truth = vec![ids(&["a", "b"])];
        let eval = RecallEvaluator::evaluate(&results, &truth, 3);
        // AP: (1/1 + 2/3) / 2 = 5/6 ≈ 0.833
        assert!((eval.average_precision - 5.0 / 6.0).abs() < 1e-9);
    }

    // ── Integration test ──────────────────────────────────────────────────

    #[test]
    fn test_tuner_convergence_simulation() {
        let config = TunerConfig {
            min_samples: 5,
            adjust_interval: 1,
            target_recall: 0.9,
            kp: 0.3,
            ki: 0.05,
            kd: 0.02,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        // Simulate recall improving as parameters are tuned
        for i in 0..50 {
            let recall = 0.5 + (i as f64 * 0.01).min(0.45);
            tuner.record_feedback(make_feedback(recall, 10, 30));
        }

        // After 50 iterations, should have made adjustments
        assert!(tuner.stats().adjustments_made > 0);
        assert_eq!(tuner.stats().total_feedbacks, 50);
    }

    #[test]
    fn test_integral_windup_prevention() {
        let config = TunerConfig {
            min_samples: 3,
            adjust_interval: 1,
            target_recall: 0.99,
            kp: 0.1,
            ki: 1.0, // Very high integral gain
            kd: 0.0,
            ..Default::default()
        };
        let mut tuner = AdaptiveRecallTuner::new(config);

        // Feed consistently low recall
        feed(&mut tuner, 100, 0.1, 20);

        // Despite high integral error, parameters should be clamped
        assert!(tuner.current_params().ef_search <= 1024);
        assert!(tuner.current_params().num_candidates <= 5000);
    }
}