1use anyhow::{anyhow, Result};
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, VecDeque};
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tracing::{debug, info, warn};
14
15pub struct LearningStopModel {
17 wand_predictor: Arc<RwLock<WandStoppingPredictor>>,
19
20 hnsw_predictor: Arc<RwLock<HnswStoppingPredictor>>,
22
23 confidence_model: Arc<RwLock<ConfidenceModel>>,
25
26 feature_extractors: FeatureExtractors,
28
29 training_scheduler: Arc<RwLock<TrainingScheduler>>,
31
32 metrics: Arc<RwLock<LearningMetrics>>,
34
35 config: LearningConfig,
37}
38
39#[derive(Debug, Clone)]
41pub struct LearningConfig {
42 pub training_window_size: usize,
44
45 pub update_frequency: usize,
47
48 pub learning_rate: f64,
50
51 pub confidence_threshold: f64,
53
54 pub min_training_samples: usize,
56
57 pub feature_normalization: bool,
59
60 pub wand_config: WandLearningConfig,
62
63 pub hnsw_config: HnswLearningConfig,
65}
66
/// WAND-specific settings of the learning-to-stop model.
#[derive(Debug, Clone, PartialEq)]
pub struct WandLearningConfig {
    /// Iteration budget of a WAND search; used as the denominator when estimating how
    /// much computation an early stop saves.
    pub max_iterations: usize,

    /// Quality estimate above which a WAND decision is reported as `quality_sufficient`.
    pub quality_threshold: f64,

    /// Tolerance on score improvements.
    ///
    /// NOTE(review): not read by the visible code.
    pub score_improvement_tolerance: f64,

    /// Threshold on term contributions.
    ///
    /// NOTE(review): not read by the visible code.
    pub term_contribution_threshold: f64,
}

impl Default for WandLearningConfig {
    /// Returns the stock WAND settings: 100 iterations, quality threshold 0.8,
    /// improvement tolerance 0.01 and term-contribution threshold 0.05.
    fn default() -> Self {
        Self {
            max_iterations: 100,
            quality_threshold: 0.8,
            score_improvement_tolerance: 0.01,
            term_contribution_threshold: 0.05,
        }
    }
}
93
/// HNSW-specific settings of the learning-to-stop model.
#[derive(Debug, Clone, PartialEq)]
pub struct HnswLearningConfig {
    /// Number of graph layers; multiplied with `max_neighbors` to bound the visit budget.
    pub max_layers: usize,

    /// Beam width of the search.
    ///
    /// NOTE(review): not read by the visible code.
    pub beam_width: usize,

    /// Distance below which an HNSW decision is reported as `quality_sufficient`.
    pub distance_threshold: f64,

    /// Maximum neighbours per node; multiplied with `max_layers` to bound the visit budget.
    pub max_neighbors: usize,
}

impl Default for HnswLearningConfig {
    /// Returns the stock HNSW settings: 5 layers, beam width 64, distance threshold 0.1
    /// and 16 neighbours per node.
    fn default() -> Self {
        Self {
            max_layers: 5,
            beam_width: 64,
            distance_threshold: 0.1,
            max_neighbors: 16,
        }
    }
}
120
121pub struct WandStoppingPredictor {
123 weights: HashMap<WandFeature, f64>,
125
126 training_history: VecDeque<WandTrainingSample>,
128
129 accuracy: f64,
131 precision: f64,
132 recall: f64,
133
134 is_trained: bool,
136 last_update: std::time::Instant,
137}
138
139pub struct HnswStoppingPredictor {
141 layer_thresholds: HashMap<usize, f64>,
143
144 neighbor_quality_weights: HashMap<HnswFeature, f64>,
146
147 training_samples: VecDeque<HnswTrainingSample>,
149
150 search_efficiency: f64,
152 quality_maintained: f64,
153
154 beam_width_adaptation: f64,
156 exploration_decay: f64,
157}
158
159pub struct ConfidenceModel {
161 confidence_predictors: HashMap<ConfidenceFeature, LinearPredictor>,
163
164 calibration_params: CalibrationParams,
166
167 confidence_accuracy: f64,
169
170 calibration_data: VecDeque<ConfidenceTrainingSample>,
172}
173
174pub struct FeatureExtractors {
176 wand_extractor: WandFeatureExtractor,
177 hnsw_extractor: HnswFeatureExtractor,
178 confidence_extractor: ConfidenceFeatureExtractor,
179}
180
/// Bookkeeping that decides when the models are retrained.
#[derive(Debug)]
pub struct TrainingScheduler {
    /// Feedback samples received since the last update was triggered.
    queries_since_update: usize,
    /// Number of feedback samples that triggers an update.
    update_frequency: usize,
    /// Earliest time recorded for the next training run.
    next_training_time: std::time::Instant,
    /// `true` while a model update is running.
    is_training: bool,
}
188
189#[derive(Debug, Default, Clone, Serialize, Deserialize)]
191pub struct LearningMetrics {
192 pub total_predictions: u64,
193 pub correct_early_stops: u64,
194 pub incorrect_early_stops: u64,
195 pub missed_stopping_opportunities: u64,
196 pub avg_computation_saved: f64,
197 pub avg_quality_maintained: f64,
198 pub model_accuracy: f64,
199 pub adaptation_events: u64,
200 pub feature_importance: HashMap<String, f64>,
201}
202
/// Features describing the state of a WAND search; used as keys of feature maps.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum WandFeature {
    /// Current iteration number.
    IterationCount,
    /// Difference between the two most recent score improvements.
    ScoreImprovement,
    /// Mean contribution over all query terms.
    TermContribution,
    /// Document frequency (not produced by the extractor in this file).
    DocumentFrequency,
    /// Heuristic quality estimate, at most 1.0.
    QualityEstimate,
    /// Milliseconds elapsed since the query started.
    TimeElapsed,
    /// Number of candidates currently held.
    CandidateSetSize,
    /// Convergence score derived from the most recent score improvements.
    ThresholdConvergence,
}
215
/// Features describing the state of an HNSW search; used as keys of feature maps.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum HnswFeature {
    /// Layer currently being searched.
    LayerDepth,
    /// Best distance to the query found so far.
    DistanceToQuery,
    /// Mean neighbour count of the beam candidates.
    NeighborCount,
    /// Search radius (not produced by the extractor in this file).
    SearchRadius,
    /// Number of candidates currently in the beam.
    BeamPosition,
    /// Exploration ratio reported by the search state.
    ExplorationRatio,
    /// `1 - best_distance` when the best distance is below 1, otherwise 0.
    DistanceImprovement,
    /// Mean neighbour count of the beam candidates divided by 16.
    GraphConnectivity,
}
228
/// Features used for confidence estimation; used as keys of feature maps.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum ConfidenceFeature {
    /// Number of results returned.
    ResultCount,
    /// Score distribution indicator (1 when there are results, otherwise 0).
    ScoreDistribution,
    /// Agreement between systems (not produced by the extractor in this file).
    SystemAgreement,
    /// Query complexity (not produced by the extractor in this file).
    QueryComplexity,
    /// Processing time as supplied by the caller.
    ProcessingTime,
    /// Resource utilisation (not produced by the extractor in this file).
    ResourceUtilization,
}
239
240#[derive(Debug, Clone)]
242pub struct WandTrainingSample {
243 pub features: HashMap<WandFeature, f64>,
244 pub should_have_stopped: bool,
245 pub actual_quality: f64,
246 pub computation_saved: f64,
247 pub timestamp: std::time::Instant,
248}
249
250#[derive(Debug, Clone)]
252pub struct HnswTrainingSample {
253 pub features: HashMap<HnswFeature, f64>,
254 pub optimal_stopping_point: usize,
255 pub final_quality: f64,
256 pub search_efficiency: f64,
257 pub timestamp: std::time::Instant,
258}
259
260#[derive(Debug, Clone)]
262pub struct ConfidenceTrainingSample {
263 pub features: HashMap<ConfidenceFeature, f64>,
264 pub predicted_confidence: f64,
265 pub actual_quality: f64,
266 pub timestamp: std::time::Instant,
267}
268
/// Parameters of a linear predictor: a weight vector, a bias and a learning rate.
#[derive(Debug, Clone, PartialEq)]
pub struct LinearPredictor {
    /// Weight per input dimension.
    weights: Vec<f64>,
    /// Constant offset.
    bias: f64,
    /// Step size intended for weight updates.
    learning_rate: f64,
}

impl LinearPredictor {
    /// Creates a predictor from its weights, bias and learning rate.
    pub fn new(weights: Vec<f64>, bias: f64, learning_rate: f64) -> Self {
        Self {
            weights,
            bias,
            learning_rate,
        }
    }

    /// Returns the weight vector.
    pub fn weights(&self) -> &Vec<f64> {
        &self.weights
    }

    /// Returns the bias.
    pub fn bias(&self) -> f64 {
        self.bias
    }

    /// Returns the learning rate.
    pub fn learning_rate(&self) -> f64 {
        self.learning_rate
    }
}
300
/// Parameters of the confidence calibration: temperature, shift and scale.
#[derive(Debug, Clone, PartialEq)]
pub struct CalibrationParams {
    /// Temperature parameter.
    temperature: f64,
    /// Additive shift parameter.
    shift: f64,
    /// Scale parameter.
    scale: f64,
}

impl CalibrationParams {
    /// Creates calibration parameters from temperature, shift and scale.
    pub fn new(temperature: f64, shift: f64, scale: f64) -> Self {
        Self {
            temperature,
            shift,
            scale,
        }
    }

    /// Returns the temperature.
    pub fn temperature(&self) -> f64 {
        self.temperature
    }

    /// Returns the shift.
    pub fn shift(&self) -> f64 {
        self.shift
    }

    /// Returns the scale.
    pub fn scale(&self) -> f64 {
        self.scale
    }
}
332
/// Stateless extractor of WAND search features.
#[derive(Debug, Clone, Copy, Default)]
pub struct WandFeatureExtractor;

/// Stateless extractor of HNSW search features.
#[derive(Debug, Clone, Copy, Default)]
pub struct HnswFeatureExtractor;

/// Stateless extractor of confidence features.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConfidenceFeatureExtractor;
337
/// Query-level information handed to the feature extractors.
#[derive(Debug, Clone)]
pub struct QueryContext {
    /// Terms of the textual query.
    pub query_terms: Vec<String>,
    /// Embedding of the query, when one is available.
    pub query_vector: Option<Vec<f32>>,
    /// Time at which query processing started; elapsed time is measured from here.
    pub start_time: std::time::Instant,
    /// Complexity score of the query.
    pub complexity_score: f64,
    /// Number of results the caller expects.
    pub expected_result_count: usize,
}
347
/// Snapshot of a running WAND search.
#[derive(Debug, Clone)]
pub struct WandSearchState {
    /// Current iteration number.
    pub iteration: usize,
    /// Current score threshold.
    pub current_threshold: f64,
    /// Number of candidates currently held.
    pub candidate_count: usize,
    /// Score improvements in chronological order (most recent last).
    pub score_improvements: Vec<f64>,
    /// Contribution per query term.
    pub term_contributions: HashMap<String, f64>,
    /// Time spent processing so far.
    pub processing_time: std::time::Duration,
}
358
359#[derive(Debug, Clone)]
361pub struct HnswSearchState {
362 pub current_layer: usize,
363 pub beam_candidates: Vec<HnswCandidate>,
364 pub visited_nodes: usize,
365 pub best_distance: f32,
366 pub exploration_ratio: f64,
367}
368
/// A node in the HNSW beam.
#[derive(Debug, Clone)]
pub struct HnswCandidate {
    /// Identifier of the graph node.
    pub node_id: usize,
    /// Distance of the node to the query.
    pub distance: f32,
    /// Layer the node was found in.
    pub layer: usize,
    /// Number of neighbours of the node.
    pub neighbor_count: usize,
}
377
378#[derive(Debug, Clone)]
380pub struct LearnedStoppingDecision {
381 pub should_stop: bool,
382 pub confidence: f64,
383 pub predicted_quality: f64,
384 pub estimated_computation_saved: f64,
385 pub reasoning: StoppingReasoning,
386 pub algorithm_used: String,
387}
388
/// Human-readable explanation attached to a stopping decision.
#[derive(Debug, Clone)]
pub struct StoppingReasoning {
    /// Short description of the main reason for the decision.
    pub primary_factor: String,
    /// Feature value multiplied by the decision confidence, keyed by feature name.
    pub feature_contributions: HashMap<String, f64>,
    /// Whether the confidence exceeded the configured confidence threshold.
    pub threshold_exceeded: bool,
    /// Whether the quality criterion of the algorithm was met.
    pub quality_sufficient: bool,
}
397
398impl Default for LearningConfig {
399 fn default() -> Self {
400 Self {
401 training_window_size: 1000,
402 update_frequency: 100,
403 learning_rate: 0.01,
404 confidence_threshold: 0.85,
405 min_training_samples: 50,
406 feature_normalization: true,
407 wand_config: WandLearningConfig {
408 max_iterations: 100,
409 quality_threshold: 0.8,
410 score_improvement_tolerance: 0.01,
411 term_contribution_threshold: 0.05,
412 },
413 hnsw_config: HnswLearningConfig {
414 max_layers: 5,
415 beam_width: 64,
416 distance_threshold: 0.1,
417 max_neighbors: 16,
418 },
419 }
420 }
421}
422
423impl LearningStopModel {
424 pub async fn new(config: LearningConfig) -> Result<Self> {
426 let wand_predictor = Arc::new(RwLock::new(WandStoppingPredictor::new(config.wand_config.clone())));
427 let hnsw_predictor = Arc::new(RwLock::new(HnswStoppingPredictor::new(config.hnsw_config.clone())));
428 let confidence_model = Arc::new(RwLock::new(ConfidenceModel::new()));
429
430 let feature_extractors = FeatureExtractors {
431 wand_extractor: WandFeatureExtractor,
432 hnsw_extractor: HnswFeatureExtractor,
433 confidence_extractor: ConfidenceFeatureExtractor,
434 };
435
436 let training_scheduler = Arc::new(RwLock::new(TrainingScheduler {
437 queries_since_update: 0,
438 update_frequency: config.update_frequency,
439 next_training_time: std::time::Instant::now(),
440 is_training: false,
441 }));
442
443 let metrics = Arc::new(RwLock::new(LearningMetrics::default()));
444
445 info!("Initialized learning-to-stop model with training window: {}", config.training_window_size);
446
447 Ok(Self {
448 wand_predictor,
449 hnsw_predictor,
450 confidence_model,
451 feature_extractors,
452 training_scheduler,
453 metrics,
454 config,
455 })
456 }
457
458 pub async fn predict_wand_stopping(
460 &self,
461 context: &QueryContext,
462 state: &WandSearchState,
463 ) -> Result<LearnedStoppingDecision> {
464 let features = self.feature_extractors.wand_extractor.extract_features(context, state);
466
467 let wand_predictor = self.wand_predictor.read().await;
469 let (should_stop, confidence) = wand_predictor.predict(&features);
470
471 let predicted_quality = self.estimate_wand_quality(&features);
473 let computation_saved = self.estimate_computation_saved(state.iteration, self.config.wand_config.max_iterations);
474
475 let reasoning = self.build_wand_reasoning(&features, should_stop, confidence);
477
478 self.update_prediction_metrics("wand", should_stop, confidence).await;
480
481 Ok(LearnedStoppingDecision {
482 should_stop,
483 confidence,
484 predicted_quality,
485 estimated_computation_saved: computation_saved,
486 reasoning,
487 algorithm_used: "WAND-Learned".to_string(),
488 })
489 }
490
491 pub async fn predict_hnsw_stopping(
493 &self,
494 context: &QueryContext,
495 state: &HnswSearchState,
496 ) -> Result<LearnedStoppingDecision> {
497 let features = self.feature_extractors.hnsw_extractor.extract_features(context, state);
499
500 let hnsw_predictor = self.hnsw_predictor.read().await;
502 let (should_stop, confidence) = hnsw_predictor.predict(&features);
503
504 let predicted_quality = self.estimate_hnsw_quality(&features, state);
506 let computation_saved = self.estimate_hnsw_computation_saved(state);
507
508 let reasoning = self.build_hnsw_reasoning(&features, should_stop, confidence);
510
511 self.update_prediction_metrics("hnsw", should_stop, confidence).await;
513
514 Ok(LearnedStoppingDecision {
515 should_stop,
516 confidence,
517 predicted_quality,
518 estimated_computation_saved: computation_saved,
519 reasoning,
520 algorithm_used: "HNSW-Learned".to_string(),
521 })
522 }
523
524 pub async fn train_with_feedback(
526 &self,
527 query_type: &str,
528 decision: &LearnedStoppingDecision,
529 actual_quality: f64,
530 actual_computation_saved: f64,
531 ) -> Result<()> {
532 let mut scheduler = self.training_scheduler.write().await;
533 scheduler.queries_since_update += 1;
534
535 match query_type {
536 "wand" => {
537 let mut predictor = self.wand_predictor.write().await;
538 predictor.add_training_sample(WandTrainingSample {
539 features: HashMap::new(), should_have_stopped: decision.should_stop,
541 actual_quality,
542 computation_saved: actual_computation_saved,
543 timestamp: std::time::Instant::now(),
544 });
545 }
546 "hnsw" => {
547 let mut predictor = self.hnsw_predictor.write().await;
548 predictor.add_training_sample(HnswTrainingSample {
549 features: HashMap::new(), optimal_stopping_point: 0, final_quality: actual_quality,
552 search_efficiency: actual_computation_saved,
553 timestamp: std::time::Instant::now(),
554 });
555 }
556 "confidence" => {
557 }
561 _ => return Err(anyhow!("Unknown query type: {}", query_type)),
562 }
563
564 if scheduler.queries_since_update >= scheduler.update_frequency {
566 self.update_models().await?;
567 scheduler.queries_since_update = 0;
568 }
569
570 Ok(())
571 }
572
573 async fn update_models(&self) -> Result<()> {
575 let mut scheduler = self.training_scheduler.write().await;
576
577 if scheduler.is_training {
578 return Ok(()); }
580
581 scheduler.is_training = true;
582 drop(scheduler);
583
584 {
586 let mut wand_predictor = self.wand_predictor.write().await;
587 wand_predictor.update_model(self.config.learning_rate)?;
588 }
589
590 {
592 let mut hnsw_predictor = self.hnsw_predictor.write().await;
593 hnsw_predictor.update_model(self.config.learning_rate)?;
594 }
595
596 {
598 let mut confidence_model = self.confidence_model.write().await;
599 confidence_model.update_calibration()?;
600 }
601
602 {
604 let mut scheduler = self.training_scheduler.write().await;
605 scheduler.is_training = false;
606 scheduler.next_training_time = std::time::Instant::now() + std::time::Duration::from_secs(300); }
608
609 {
611 let mut metrics = self.metrics.write().await;
612 metrics.adaptation_events += 1;
613 }
614
615 info!("Updated learning models with new training data");
616
617 Ok(())
618 }
619
620 fn estimate_wand_quality(&self, features: &HashMap<WandFeature, f64>) -> f64 {
622 let score_improvement = features.get(&WandFeature::ScoreImprovement).unwrap_or(&0.0);
623 let quality_estimate = features.get(&WandFeature::QualityEstimate).unwrap_or(&0.5);
624 let threshold_convergence = features.get(&WandFeature::ThresholdConvergence).unwrap_or(&0.0);
625
626 (score_improvement * 0.4 + quality_estimate * 0.4 + threshold_convergence * 0.2).min(1.0)
628 }
629
630 fn estimate_hnsw_quality(&self, features: &HashMap<HnswFeature, f64>, state: &HnswSearchState) -> f64 {
632 let distance_improvement = features.get(&HnswFeature::DistanceImprovement).unwrap_or(&0.0);
633 let exploration_ratio = features.get(&HnswFeature::ExplorationRatio).unwrap_or(&0.5);
634
635 let distance_quality = if state.best_distance > 0.0 {
636 (1.0 - state.best_distance).max(0.0)
637 } else {
638 0.0
639 };
640
641 (distance_improvement * 0.3 + exploration_ratio * 0.3 + distance_quality as f64 * 0.4).min(1.0)
642 }
643
644 fn estimate_computation_saved(&self, current_iteration: usize, max_iterations: usize) -> f64 {
646 if max_iterations == 0 {
647 return 0.0;
648 }
649
650 let remaining_iterations = max_iterations.saturating_sub(current_iteration);
651 remaining_iterations as f64 / max_iterations as f64
652 }
653
654 fn estimate_hnsw_computation_saved(&self, state: &HnswSearchState) -> f64 {
656 let max_possible_visits = self.config.hnsw_config.max_neighbors * self.config.hnsw_config.max_layers;
657 let remaining_visits = max_possible_visits.saturating_sub(state.visited_nodes);
658
659 remaining_visits as f64 / max_possible_visits as f64
660 }
661
662 fn build_wand_reasoning(&self, features: &HashMap<WandFeature, f64>, should_stop: bool, confidence: f64) -> StoppingReasoning {
664 let mut feature_contributions = HashMap::new();
665
666 for (feature, value) in features {
668 let contribution = value * confidence; feature_contributions.insert(format!("{:?}", feature), contribution);
670 }
671
672 let primary_factor = if should_stop {
673 "Score convergence detected"
674 } else {
675 "Continued exploration needed"
676 }.to_string();
677
678 StoppingReasoning {
679 primary_factor,
680 feature_contributions,
681 threshold_exceeded: confidence > self.config.confidence_threshold,
682 quality_sufficient: features.get(&WandFeature::QualityEstimate).unwrap_or(&0.0) > &self.config.wand_config.quality_threshold,
683 }
684 }
685
686 fn build_hnsw_reasoning(&self, features: &HashMap<HnswFeature, f64>, should_stop: bool, confidence: f64) -> StoppingReasoning {
688 let mut feature_contributions = HashMap::new();
689
690 for (feature, value) in features {
691 let contribution = value * confidence;
692 feature_contributions.insert(format!("{:?}", feature), contribution);
693 }
694
695 let primary_factor = if should_stop {
696 "Distance threshold reached"
697 } else {
698 "Further exploration beneficial"
699 }.to_string();
700
701 StoppingReasoning {
702 primary_factor,
703 feature_contributions,
704 threshold_exceeded: confidence > self.config.confidence_threshold,
705 quality_sufficient: features.get(&HnswFeature::DistanceToQuery).unwrap_or(&1.0) < &self.config.hnsw_config.distance_threshold,
706 }
707 }
708
709 async fn update_prediction_metrics(&self, algorithm: &str, prediction: bool, confidence: f64) {
711 let mut metrics = self.metrics.write().await;
712 metrics.total_predictions += 1;
713
714 if prediction {
716 debug!("Predicted early stop for {} with confidence {:.3}", algorithm, confidence);
717 }
718 }
719
720 pub async fn get_metrics(&self) -> LearningMetrics {
722 self.metrics.read().await.clone()
723 }
724
725 pub fn config(&self) -> &LearningConfig {
727 &self.config
728 }
729}
730
731impl WandStoppingPredictor {
732 pub fn new(_config: WandLearningConfig) -> Self {
733 let mut weights = HashMap::new();
734
735 weights.insert(WandFeature::IterationCount, -0.1);
737 weights.insert(WandFeature::ScoreImprovement, 0.8);
738 weights.insert(WandFeature::TermContribution, 0.6);
739 weights.insert(WandFeature::QualityEstimate, 0.9);
740 weights.insert(WandFeature::ThresholdConvergence, 0.7);
741
742 Self {
743 weights,
744 training_history: VecDeque::new(),
745 accuracy: 0.5,
746 precision: 0.5,
747 recall: 0.5,
748 is_trained: false,
749 last_update: std::time::Instant::now(),
750 }
751 }
752
753 pub fn predict(&self, features: &HashMap<WandFeature, f64>) -> (bool, f64) {
754 let mut score = 0.0;
755 let mut feature_count = 0;
756
757 for (feature, weight) in &self.weights {
758 if let Some(feature_value) = features.get(feature) {
759 score += feature_value * weight;
760 feature_count += 1;
761 }
762 }
763
764 if feature_count > 0 {
765 score /= feature_count as f64;
766 }
767
768 let confidence = (score.tanh() + 1.0) / 2.0; let should_stop = confidence > 0.5;
770
771 (should_stop, confidence)
772 }
773
774 pub fn add_training_sample(&mut self, sample: WandTrainingSample) {
775 self.training_history.push_back(sample);
776
777 while self.training_history.len() > 1000 {
779 self.training_history.pop_front();
780 }
781 }
782
783 pub fn update_model(&mut self, learning_rate: f64) -> Result<()> {
784 if self.training_history.len() < 10 {
785 return Ok(()); }
787
788 for sample in self.training_history.iter().rev().take(100) {
790 let (predicted, _) = self.predict(&sample.features);
791 let error = if sample.should_have_stopped { 1.0 } else { 0.0 } - if predicted { 1.0 } else { 0.0 };
792
793 for (feature, feature_value) in &sample.features {
795 if let Some(weight) = self.weights.get_mut(feature) {
796 *weight += learning_rate * error * feature_value;
797 }
798 }
799 }
800
801 self.last_update = std::time::Instant::now();
802 self.is_trained = true;
803
804 Ok(())
805 }
806
807 pub fn weights(&self) -> &HashMap<WandFeature, f64> {
809 &self.weights
810 }
811
812 pub fn is_trained(&self) -> bool {
813 self.is_trained
814 }
815
816 pub fn accuracy(&self) -> f64 {
817 self.accuracy
818 }
819
820 pub fn precision(&self) -> f64 {
821 self.precision
822 }
823
824 pub fn recall(&self) -> f64 {
825 self.recall
826 }
827
828 pub fn training_history(&self) -> &VecDeque<WandTrainingSample> {
829 &self.training_history
830 }
831}
832
833impl HnswStoppingPredictor {
834 pub fn new(_config: HnswLearningConfig) -> Self {
835 let mut layer_thresholds = HashMap::new();
836 let mut neighbor_quality_weights = HashMap::new();
837
838 for layer in 0..5 {
840 layer_thresholds.insert(layer, 0.1 * (layer + 1) as f64);
841 }
842
843 neighbor_quality_weights.insert(HnswFeature::DistanceToQuery, 0.9);
845 neighbor_quality_weights.insert(HnswFeature::DistanceImprovement, 0.8);
846 neighbor_quality_weights.insert(HnswFeature::ExplorationRatio, 0.6);
847 neighbor_quality_weights.insert(HnswFeature::GraphConnectivity, 0.4);
848
849 Self {
850 layer_thresholds,
851 neighbor_quality_weights,
852 training_samples: VecDeque::new(),
853 search_efficiency: 0.5,
854 quality_maintained: 0.5,
855 beam_width_adaptation: 1.0,
856 exploration_decay: 0.95,
857 }
858 }
859
860 pub fn predict(&self, features: &HashMap<HnswFeature, f64>) -> (bool, f64) {
861 let mut quality_score = 0.0;
862 let mut feature_count = 0;
863
864 for (feature, weight) in &self.neighbor_quality_weights {
865 if let Some(feature_value) = features.get(feature) {
866 quality_score += feature_value * weight;
867 feature_count += 1;
868 }
869 }
870
871 if feature_count > 0 {
872 quality_score /= feature_count as f64;
873 }
874
875 let confidence = quality_score.min(1.0).max(0.0);
876 let should_stop = confidence > 0.7; (should_stop, confidence)
879 }
880
881 pub fn add_training_sample(&mut self, sample: HnswTrainingSample) {
882 self.training_samples.push_back(sample);
883
884 while self.training_samples.len() > 1000 {
885 self.training_samples.pop_front();
886 }
887 }
888
889 pub fn update_model(&mut self, learning_rate: f64) -> Result<()> {
890 if self.training_samples.len() < 10 {
891 return Ok(());
892 }
893
894 let recent_samples: Vec<_> = self.training_samples.iter().rev().take(50).collect();
897
898 let avg_efficiency: f64 = recent_samples.iter().map(|s| s.search_efficiency).sum::<f64>() / recent_samples.len() as f64;
899 let avg_quality: f64 = recent_samples.iter().map(|s| s.final_quality).sum::<f64>() / recent_samples.len() as f64;
900
901 self.search_efficiency = self.search_efficiency * (1.0 - learning_rate) + avg_efficiency * learning_rate;
903 self.quality_maintained = self.quality_maintained * (1.0 - learning_rate) + avg_quality * learning_rate;
904
905 if avg_efficiency < 0.6 {
907 self.beam_width_adaptation *= 1.1; } else if avg_efficiency > 0.8 {
909 self.beam_width_adaptation *= 0.95; }
911
912 Ok(())
913 }
914
915 pub fn layer_thresholds(&self) -> &HashMap<usize, f64> {
917 &self.layer_thresholds
918 }
919
920 pub fn neighbor_quality_weights(&self) -> &HashMap<HnswFeature, f64> {
921 &self.neighbor_quality_weights
922 }
923
924 pub fn training_samples(&self) -> &VecDeque<HnswTrainingSample> {
925 &self.training_samples
926 }
927
928 pub fn search_efficiency(&self) -> f64 {
929 self.search_efficiency
930 }
931
932 pub fn quality_maintained(&self) -> f64 {
933 self.quality_maintained
934 }
935
936 pub fn beam_width_adaptation(&self) -> f64 {
937 self.beam_width_adaptation
938 }
939
940 pub fn exploration_decay(&self) -> f64 {
941 self.exploration_decay
942 }
943}
944
945impl ConfidenceModel {
946 pub fn new() -> Self {
947 Self {
948 confidence_predictors: HashMap::new(),
949 calibration_params: CalibrationParams {
950 temperature: 1.0,
951 shift: 0.0,
952 scale: 1.0,
953 },
954 confidence_accuracy: 0.5,
955 calibration_data: VecDeque::new(),
956 }
957 }
958
959 pub fn update_calibration(&mut self) -> Result<()> {
960 if self.calibration_data.len() < 20 {
962 return Ok(());
963 }
964
965 let recent_data: Vec<_> = self.calibration_data.iter().rev().take(100).collect();
967
968 let avg_predicted: f64 = recent_data.iter().map(|s| s.predicted_confidence).sum::<f64>() / recent_data.len() as f64;
969 let avg_actual: f64 = recent_data.iter().map(|s| s.actual_quality).sum::<f64>() / recent_data.len() as f64;
970
971 if (avg_predicted - avg_actual).abs() > 0.1 {
973 self.calibration_params.shift += (avg_actual - avg_predicted) * 0.01;
974 }
975
976 Ok(())
977 }
978
979 pub fn confidence_predictors(&self) -> &HashMap<ConfidenceFeature, LinearPredictor> {
981 &self.confidence_predictors
982 }
983
984 pub fn calibration_params(&self) -> &CalibrationParams {
985 &self.calibration_params
986 }
987
988 pub fn confidence_accuracy(&self) -> f64 {
989 self.confidence_accuracy
990 }
991
992 pub fn calibration_data(&self) -> &VecDeque<ConfidenceTrainingSample> {
993 &self.calibration_data
994 }
995
996 pub fn add_calibration_sample(&mut self, sample: ConfidenceTrainingSample) {
998 self.calibration_data.push_back(sample);
999 }
1000}
1001
1002impl WandFeatureExtractor {
1004 pub fn extract_features(&self, context: &QueryContext, state: &WandSearchState) -> HashMap<WandFeature, f64> {
1005 let mut features = HashMap::new();
1006
1007 let elapsed = context.start_time.elapsed().as_millis() as f64;
1008
1009 features.insert(WandFeature::IterationCount, state.iteration as f64);
1010 features.insert(WandFeature::TimeElapsed, elapsed);
1011 features.insert(WandFeature::CandidateSetSize, state.candidate_count as f64);
1012
1013 let score_improvement = if state.score_improvements.len() >= 2 {
1015 let recent = &state.score_improvements[state.score_improvements.len()-2..];
1016 recent[1] - recent[0]
1017 } else {
1018 0.0
1019 };
1020 features.insert(WandFeature::ScoreImprovement, score_improvement);
1021
1022 let avg_term_contribution = if !state.term_contributions.is_empty() {
1024 state.term_contributions.values().sum::<f64>() / state.term_contributions.len() as f64
1025 } else {
1026 0.0
1027 };
1028 features.insert(WandFeature::TermContribution, avg_term_contribution);
1029
1030 let quality_estimate = if state.iteration > 0 {
1032 (state.current_threshold / state.iteration as f64).min(1.0)
1033 } else {
1034 0.5
1035 };
1036 features.insert(WandFeature::QualityEstimate, quality_estimate);
1037
1038 let threshold_convergence = if state.score_improvements.len() >= 3 {
1040 let recent_variance: f64 = state.score_improvements.iter().rev().take(3)
1041 .map(|&x| (x - score_improvement).powi(2))
1042 .sum::<f64>() / 3.0;
1043 (1.0 / (1.0 + recent_variance)).min(1.0)
1044 } else {
1045 0.0
1046 };
1047 features.insert(WandFeature::ThresholdConvergence, threshold_convergence);
1048
1049 features
1050 }
1051}
1052
1053impl HnswFeatureExtractor {
1054 pub fn extract_features(&self, context: &QueryContext, state: &HnswSearchState) -> HashMap<HnswFeature, f64> {
1055 let mut features = HashMap::new();
1056
1057 features.insert(HnswFeature::LayerDepth, state.current_layer as f64);
1058 features.insert(HnswFeature::DistanceToQuery, state.best_distance as f64);
1059 features.insert(HnswFeature::BeamPosition, state.beam_candidates.len() as f64);
1060 features.insert(HnswFeature::ExplorationRatio, state.exploration_ratio);
1061
1062 let avg_neighbors = if !state.beam_candidates.is_empty() {
1064 state.beam_candidates.iter().map(|c| c.neighbor_count as f64).sum::<f64>() / state.beam_candidates.len() as f64
1065 } else {
1066 0.0
1067 };
1068 features.insert(HnswFeature::NeighborCount, avg_neighbors);
1069
1070 let distance_improvement = if state.best_distance < 1.0 {
1072 1.0 - state.best_distance as f64
1073 } else {
1074 0.0
1075 };
1076 features.insert(HnswFeature::DistanceImprovement, distance_improvement);
1077
1078 let connectivity = if !state.beam_candidates.is_empty() {
1080 let total_connections: usize = state.beam_candidates.iter().map(|c| c.neighbor_count).sum();
1081 (total_connections as f64 / state.beam_candidates.len() as f64) / 16.0 } else {
1083 0.0
1084 };
1085 features.insert(HnswFeature::GraphConnectivity, connectivity);
1086
1087 features
1088 }
1089}
1090
1091impl ConfidenceFeatureExtractor {
1092 pub fn extract_features(&self, _context: &QueryContext, result_count: usize, processing_time: f64) -> HashMap<ConfidenceFeature, f64> {
1093 let mut features = HashMap::new();
1094
1095 features.insert(ConfidenceFeature::ResultCount, result_count as f64);
1096 features.insert(ConfidenceFeature::ProcessingTime, processing_time);
1097
1098 let score_distribution = if result_count > 0 { 1.0 } else { 0.0 };
1100 features.insert(ConfidenceFeature::ScoreDistribution, score_distribution);
1101
1102 features
1103 }
1104}
1105
#[cfg(test)]
mod tests {
    use super::*;

    /// The model can be constructed from the default configuration.
    #[tokio::test]
    async fn test_learning_model_creation() {
        let config = LearningConfig::default();
        let model = LearningStopModel::new(config).await;
        assert!(model.is_ok());
    }

    /// The WAND extractor emits the basic features with the expected values.
    #[tokio::test]
    async fn test_wand_feature_extraction() {
        let extractor = WandFeatureExtractor;
        let context = QueryContext {
            query_terms: vec!["test".to_string()],
            query_vector: None,
            start_time: std::time::Instant::now(),
            complexity_score: 0.5,
            expected_result_count: 10,
        };

        let state = WandSearchState {
            iteration: 5,
            current_threshold: 0.8,
            candidate_count: 20,
            score_improvements: vec![0.1, 0.15, 0.18],
            term_contributions: HashMap::new(),
            processing_time: std::time::Duration::from_millis(50),
        };

        let features = extractor.extract_features(&context, &state);

        assert!(features.contains_key(&WandFeature::IterationCount));
        assert!(features.contains_key(&WandFeature::ScoreImprovement));
        assert!(features.contains_key(&WandFeature::CandidateSetSize));

        assert_eq!(*features.get(&WandFeature::IterationCount).unwrap(), 5.0);
    }

    /// Strong quality signals yield a higher confidence than weak ones.
    #[tokio::test]
    async fn test_wand_predictor() {
        let config = WandLearningConfig::default();
        let predictor = WandStoppingPredictor::new(config);

        let mut features = HashMap::new();
        features.insert(WandFeature::QualityEstimate, 0.9);
        features.insert(WandFeature::ScoreImprovement, 0.1);
        features.insert(WandFeature::ThresholdConvergence, 0.8);

        // Only the confidence is asserted on; the stop flag is intentionally ignored.
        let (_should_stop, confidence) = predictor.predict(&features);

        assert!(confidence > 0.3);

        let mut low_features = HashMap::new();
        low_features.insert(WandFeature::QualityEstimate, 0.1);
        low_features.insert(WandFeature::ScoreImprovement, 0.01);

        let (_low_stop, low_confidence) = predictor.predict(&low_features);

        assert!(low_confidence < confidence);
    }

    /// The HNSW extractor emits the basic features with the expected values.
    #[tokio::test]
    async fn test_hnsw_feature_extraction() {
        let extractor = HnswFeatureExtractor;
        let context = QueryContext {
            query_terms: vec![],
            query_vector: Some(vec![0.1, 0.2, 0.3]),
            start_time: std::time::Instant::now(),
            complexity_score: 0.7,
            expected_result_count: 5,
        };

        let candidates = vec![
            HnswCandidate { node_id: 1, distance: 0.1, layer: 0, neighbor_count: 8 },
            HnswCandidate { node_id: 2, distance: 0.15, layer: 0, neighbor_count: 12 },
        ];

        let state = HnswSearchState {
            current_layer: 1,
            beam_candidates: candidates,
            visited_nodes: 25,
            best_distance: 0.1,
            exploration_ratio: 0.6,
        };

        let features = extractor.extract_features(&context, &state);

        assert!(features.contains_key(&HnswFeature::LayerDepth));
        assert!(features.contains_key(&HnswFeature::DistanceToQuery));
        assert!(features.contains_key(&HnswFeature::ExplorationRatio));

        assert_eq!(*features.get(&HnswFeature::LayerDepth).unwrap(), 1.0);
        assert_eq!(*features.get(&HnswFeature::ExplorationRatio).unwrap(), 0.6);
    }

    /// An end-to-end WAND prediction stays within the documented value ranges.
    #[tokio::test]
    async fn test_learning_model_prediction() {
        let config = LearningConfig::default();
        let model = LearningStopModel::new(config).await.unwrap();

        let context = QueryContext {
            query_terms: vec!["function".to_string(), "test".to_string()],
            query_vector: None,
            start_time: std::time::Instant::now(),
            complexity_score: 0.6,
            expected_result_count: 15,
        };

        let wand_state = WandSearchState {
            iteration: 10,
            current_threshold: 0.75,
            candidate_count: 30,
            score_improvements: vec![0.2, 0.25, 0.27],
            term_contributions: HashMap::new(),
            processing_time: std::time::Duration::from_millis(80),
        };

        let decision = model
            .predict_wand_stopping(&context, &wand_state)
            .await
            .unwrap();

        assert!(decision.confidence >= 0.0 && decision.confidence <= 1.0);
        assert!(decision.predicted_quality >= 0.0 && decision.predicted_quality <= 1.0);
        assert!(decision.estimated_computation_saved >= 0.0);
        assert!(!decision.reasoning.feature_contributions.is_empty());
    }
}