lens_core/pipeline/
learning.rs

1//! Learning-to-Stop Models for WAND/HNSW Early Termination
2//!
3//! Implements machine learning models for adaptive early stopping in WAND queries
4//! and HNSW vector search with confidence-based termination.
5//! 
6//! Target: Dynamic threshold learning with >90% accuracy per TODO.md
7
8use anyhow::{anyhow, Result};
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, VecDeque};
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tracing::{debug, info, warn};
14
/// Coordinator that bundles the learned early-stopping components.
///
/// Holds one predictor per search algorithm (WAND and HNSW), a confidence
/// model, the feature extractors that feed the predictors, a training
/// scheduler and running metrics. Every mutable component is wrapped in
/// `Arc<RwLock<_>>` so the `&self` async methods can update it.
pub struct LearningStopModel {
    /// WAND query early stopping predictor
    wand_predictor: Arc<RwLock<WandStoppingPredictor>>,

    /// HNSW vector search early stopping predictor
    hnsw_predictor: Arc<RwLock<HnswStoppingPredictor>>,

    /// Confidence-based termination model
    confidence_model: Arc<RwLock<ConfidenceModel>>,

    /// Feature extractors for different query types
    feature_extractors: FeatureExtractors,

    /// Model training and adaptation (query counter and "is training" flag)
    training_scheduler: Arc<RwLock<TrainingScheduler>>,

    /// Performance metrics
    metrics: Arc<RwLock<LearningMetrics>>,

    /// Configuration supplied at construction time
    config: LearningConfig,
}
38
/// Configuration for the learning-to-stop models.
///
/// See the `Default` impl for the stock values.
#[derive(Debug, Clone)]
pub struct LearningConfig {
    /// Training data window size
    // NOTE(review): the predictors visible in this file cap their history at a
    // hard-coded 1000 rather than reading this field — confirm intent.
    pub training_window_size: usize,

    /// Model update frequency (queries): number of feedback calls between model updates
    pub update_frequency: usize,

    /// Learning rate for model adaptation
    pub learning_rate: f64,

    /// Confidence threshold for early stopping decisions
    pub confidence_threshold: f64,

    /// Minimum training samples before making predictions
    pub min_training_samples: usize,

    /// Feature normalization parameters
    pub feature_normalization: bool,

    /// WAND-specific configuration
    pub wand_config: WandLearningConfig,

    /// HNSW-specific configuration
    pub hnsw_config: HnswLearningConfig,
}
66
/// Tunables for learned early stopping of WAND queries.
#[derive(Debug, Clone)]
pub struct WandLearningConfig {
    /// Upper bound on WAND iterations before a stop is forced
    pub max_iterations: usize,

    /// Quality level below which results count as degraded
    pub quality_threshold: f64,

    /// Tolerance applied to score improvements
    pub score_improvement_tolerance: f64,

    /// Threshold applied to per-term contributions
    pub term_contribution_threshold: f64,
}

impl Default for WandLearningConfig {
    /// Stock WAND settings: 100 iterations, quality 0.8, tolerance 0.01, term threshold 0.05.
    fn default() -> Self {
        WandLearningConfig {
            max_iterations: 100,
            quality_threshold: 0.8,
            score_improvement_tolerance: 0.01,
            term_contribution_threshold: 0.05,
        }
    }
}
93
/// Tunables for learned early stopping of HNSW vector search.
#[derive(Debug, Clone)]
pub struct HnswLearningConfig {
    /// Upper bound on the number of HNSW layers explored
    pub max_layers: usize,

    /// Width of the beam search
    pub beam_width: usize,

    /// Distance below which the search may terminate early
    pub distance_threshold: f64,

    /// Limit on neighbors explored
    pub max_neighbors: usize,
}

impl Default for HnswLearningConfig {
    /// Stock HNSW settings: 5 layers, beam width 64, distance 0.1, 16 neighbors.
    fn default() -> Self {
        HnswLearningConfig {
            max_layers: 5,
            beam_width: 64,
            distance_threshold: 0.1,
            max_neighbors: 16,
        }
    }
}
120
/// WAND stopping predictor using learned features.
///
/// A weighted sum over `WandFeature` values; the weights are adjusted from
/// the buffered training samples.
pub struct WandStoppingPredictor {
    /// Learned weights for different features; features without an entry are ignored
    weights: HashMap<WandFeature, f64>,

    /// Training history for adaptation (oldest sample at the front)
    training_history: VecDeque<WandTrainingSample>,

    /// Current performance metrics
    accuracy: f64,
    precision: f64,
    recall: f64,

    /// Model state: whether at least one model update has run, and when
    is_trained: bool,
    last_update: std::time::Instant,
}
138
/// HNSW stopping predictor.
///
/// Scores a search state with a weighted sum over `HnswFeature` values and
/// tracks smoothed efficiency/quality figures derived from training samples.
pub struct HnswStoppingPredictor {
    /// Distance-based stopping thresholds per layer (keyed by layer index)
    layer_thresholds: HashMap<usize, f64>,

    /// Neighbor quality predictors: per-feature weights used for scoring
    neighbor_quality_weights: HashMap<HnswFeature, f64>,

    /// Training samples for HNSW navigation (oldest sample at the front)
    training_samples: VecDeque<HnswTrainingSample>,

    /// Performance tracking
    search_efficiency: f64,
    quality_maintained: f64,

    /// Adaptive parameters
    beam_width_adaptation: f64,
    exploration_decay: f64,
}
158
/// Confidence-based termination model.
///
/// Pairs per-feature linear predictors with calibration parameters and a
/// buffer of samples used to recalibrate them.
pub struct ConfidenceModel {
    /// Confidence predictors for different result types
    confidence_predictors: HashMap<ConfidenceFeature, LinearPredictor>,

    /// Calibration parameters for confidence scores
    calibration_params: CalibrationParams,

    /// Historical accuracy of confidence predictions
    confidence_accuracy: f64,

    /// Training data for confidence calibration
    calibration_data: VecDeque<ConfidenceTrainingSample>,
}
173
/// Feature extractors for different components, one per model.
pub struct FeatureExtractors {
    /// Produces the `WandFeature` map consumed by the WAND predictor
    wand_extractor: WandFeatureExtractor,
    /// Produces the `HnswFeature` map consumed by the HNSW predictor
    hnsw_extractor: HnswFeatureExtractor,
    /// Extractor for the confidence model
    confidence_extractor: ConfidenceFeatureExtractor,
}
180
/// Training scheduler for model updates.
pub struct TrainingScheduler {
    /// Feedback calls received since the counter was last reset
    queries_since_update: usize,
    /// Counter value at which a model update is triggered
    update_frequency: usize,
    /// Timestamp written after an update completes (now + 5 minutes)
    next_training_time: std::time::Instant,
    /// Set while an update is running so overlapping updates are skipped
    is_training: bool,
}
188
/// Learning metrics and performance tracking.
///
/// All counters start at zero via `Default`; the struct is serializable with serde.
// NOTE(review): within the code visible here only `total_predictions` and
// `adaptation_events` are ever written — confirm where the rest are updated.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct LearningMetrics {
    /// Number of stopping predictions made (WAND and HNSW combined)
    pub total_predictions: u64,
    pub correct_early_stops: u64,
    pub incorrect_early_stops: u64,
    pub missed_stopping_opportunities: u64,
    pub avg_computation_saved: f64,
    pub avg_quality_maintained: f64,
    pub model_accuracy: f64,
    /// Number of completed model update rounds
    pub adaptation_events: u64,
    /// Importance score per feature name
    pub feature_importance: HashMap<String, f64>,
}
202
/// WAND feature types for learning.
///
/// Used as the key of feature maps and weight maps, hence `Hash + Eq`.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum WandFeature {
    IterationCount,
    ScoreImprovement,
    TermContribution,
    DocumentFrequency,
    QualityEstimate,
    TimeElapsed,
    CandidateSetSize,
    ThresholdConvergence,
}
215
/// HNSW feature types for learning.
///
/// Used as the key of feature maps and weight maps, hence `Hash + Eq`.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum HnswFeature {
    LayerDepth,
    DistanceToQuery,
    NeighborCount,
    SearchRadius,
    BeamPosition,
    ExplorationRatio,
    DistanceImprovement,
    GraphConnectivity,
}
228
/// Confidence feature types.
///
/// Used as the key of the confidence model's predictor map, hence `Hash + Eq`.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum ConfidenceFeature {
    ResultCount,
    ScoreDistribution,
    SystemAgreement,
    QueryComplexity,
    ProcessingTime,
    ResourceUtilization,
}
239
/// Training sample for WAND predictor.
#[derive(Debug, Clone)]
pub struct WandTrainingSample {
    /// Feature values observed for the query
    pub features: HashMap<WandFeature, f64>,
    /// Training label: `true` when stopping was the right call
    pub should_have_stopped: bool,
    /// Quality actually observed for the query
    pub actual_quality: f64,
    /// Computation actually saved
    pub computation_saved: f64,
    /// When the sample was recorded
    pub timestamp: std::time::Instant,
}
249
/// Training sample for HNSW predictor.
#[derive(Debug, Clone)]
pub struct HnswTrainingSample {
    /// Feature values observed for the search
    pub features: HashMap<HnswFeature, f64>,
    /// Point at which stopping would have been optimal
    pub optimal_stopping_point: usize,
    /// Quality of the final result
    pub final_quality: f64,
    /// Efficiency figure averaged during model updates
    pub search_efficiency: f64,
    /// When the sample was recorded
    pub timestamp: std::time::Instant,
}
259
/// Training sample for confidence model.
#[derive(Debug, Clone)]
pub struct ConfidenceTrainingSample {
    /// Feature values observed for the query
    pub features: HashMap<ConfidenceFeature, f64>,
    /// Confidence the model predicted
    pub predicted_confidence: f64,
    /// Quality actually observed, compared against the prediction during calibration
    pub actual_quality: f64,
    /// When the sample was recorded
    pub timestamp: std::time::Instant,
}
268
/// Parameters of a linear predictor used for confidence calibration:
/// a weight vector, a bias term and a learning rate.
#[derive(Debug, Clone)]
pub struct LinearPredictor {
    weights: Vec<f64>,
    bias: f64,
    learning_rate: f64,
}

impl LinearPredictor {
    /// Assemble a predictor from explicit parameters (primarily for tests).
    pub fn new(weights: Vec<f64>, bias: f64, learning_rate: f64) -> Self {
        LinearPredictor { weights, bias, learning_rate }
    }

    /// Borrow the weight vector.
    pub fn weights(&self) -> &Vec<f64> {
        let Self { weights, .. } = self;
        weights
    }

    /// Bias term.
    pub fn bias(&self) -> f64 {
        let Self { bias, .. } = self;
        *bias
    }

    /// Learning rate.
    pub fn learning_rate(&self) -> f64 {
        let Self { learning_rate, .. } = self;
        *learning_rate
    }
}
300
/// Parameters used to calibrate confidence scores: temperature, shift and scale.
#[derive(Debug, Clone)]
pub struct CalibrationParams {
    temperature: f64,
    shift: f64,
    scale: f64,
}

impl CalibrationParams {
    /// Assemble calibration parameters from explicit values (primarily for tests).
    pub fn new(temperature: f64, shift: f64, scale: f64) -> Self {
        CalibrationParams { temperature, shift, scale }
    }

    /// Temperature parameter.
    pub fn temperature(&self) -> f64 {
        let Self { temperature, .. } = self;
        *temperature
    }

    /// Shift parameter.
    pub fn shift(&self) -> f64 {
        let Self { shift, .. } = self;
        *shift
    }

    /// Scale parameter.
    pub fn scale(&self) -> f64 {
        let Self { scale, .. } = self;
        *scale
    }
}
332
/// Feature extractors implementations
///
/// Stateless marker type whose `extract_features` builds the WAND feature map.
pub struct WandFeatureExtractor;
/// Stateless marker type whose `extract_features` builds the HNSW feature map.
pub struct HnswFeatureExtractor;
/// Stateless marker type for confidence-model features.
pub struct ConfidenceFeatureExtractor;
337
/// Query context for feature extraction.
#[derive(Debug, Clone)]
pub struct QueryContext {
    /// Terms of the textual query
    pub query_terms: Vec<String>,
    /// Query embedding, when the query has one
    pub query_vector: Option<Vec<f32>>,
    /// Instant at which query processing started
    pub start_time: std::time::Instant,
    /// Complexity score of the query
    pub complexity_score: f64,
    /// Number of results the caller expects
    pub expected_result_count: usize,
}
347
/// Search state for WAND queries.
#[derive(Debug, Clone)]
pub struct WandSearchState {
    /// Iterations completed so far; compared against `max_iterations` to estimate savings
    pub iteration: usize,
    /// Current WAND score threshold
    pub current_threshold: f64,
    /// Number of candidates currently held
    pub candidate_count: usize,
    /// Recorded score improvements
    pub score_improvements: Vec<f64>,
    /// Contribution per query term
    pub term_contributions: HashMap<String, f64>,
    /// Time spent processing so far
    pub processing_time: std::time::Duration,
}
358
/// Search state for HNSW queries.
#[derive(Debug, Clone)]
pub struct HnswSearchState {
    /// Layer currently being searched
    pub current_layer: usize,
    /// Candidates currently in the beam
    pub beam_candidates: Vec<HnswCandidate>,
    /// Nodes visited so far; used to estimate remaining work
    pub visited_nodes: usize,
    /// Best (smallest) distance found so far
    pub best_distance: f32,
    /// Exploration ratio of the search
    pub exploration_ratio: f64,
}
368
/// HNSW candidate representation.
#[derive(Debug, Clone)]
pub struct HnswCandidate {
    /// Identifier of the graph node
    pub node_id: usize,
    /// Distance from the node to the query
    pub distance: f32,
    /// Layer the node was found on
    pub layer: usize,
    /// Number of neighbors of the node
    pub neighbor_count: usize,
}
377
/// Stopping decision with learned confidence.
#[derive(Debug, Clone)]
pub struct LearnedStoppingDecision {
    /// `true` when the predictor recommends terminating the search now
    pub should_stop: bool,
    /// Predictor confidence for the decision
    pub confidence: f64,
    /// Estimated result quality if the search stops now
    pub predicted_quality: f64,
    /// Estimated fraction of the remaining work that stopping would avoid
    pub estimated_computation_saved: f64,
    /// Explanation of the decision
    pub reasoning: StoppingReasoning,
    /// Label of the producing predictor ("WAND-Learned" or "HNSW-Learned")
    pub algorithm_used: String,
}
388
/// Reasoning for stopping decisions.
#[derive(Debug, Clone)]
pub struct StoppingReasoning {
    /// Short human-readable summary of the decision
    pub primary_factor: String,
    /// Per-feature contribution, keyed by the feature's `Debug` name
    pub feature_contributions: HashMap<String, f64>,
    /// Whether the confidence exceeded the configured confidence threshold
    pub threshold_exceeded: bool,
    /// Whether the algorithm-specific quality check passed
    pub quality_sufficient: bool,
}
397
398impl Default for LearningConfig {
399    fn default() -> Self {
400        Self {
401            training_window_size: 1000,
402            update_frequency: 100,
403            learning_rate: 0.01,
404            confidence_threshold: 0.85,
405            min_training_samples: 50,
406            feature_normalization: true,
407            wand_config: WandLearningConfig {
408                max_iterations: 100,
409                quality_threshold: 0.8,
410                score_improvement_tolerance: 0.01,
411                term_contribution_threshold: 0.05,
412            },
413            hnsw_config: HnswLearningConfig {
414                max_layers: 5,
415                beam_width: 64,
416                distance_threshold: 0.1,
417                max_neighbors: 16,
418            },
419        }
420    }
421}
422
423impl LearningStopModel {
424    /// Create a new learning-to-stop model
425    pub async fn new(config: LearningConfig) -> Result<Self> {
426        let wand_predictor = Arc::new(RwLock::new(WandStoppingPredictor::new(config.wand_config.clone())));
427        let hnsw_predictor = Arc::new(RwLock::new(HnswStoppingPredictor::new(config.hnsw_config.clone())));
428        let confidence_model = Arc::new(RwLock::new(ConfidenceModel::new()));
429        
430        let feature_extractors = FeatureExtractors {
431            wand_extractor: WandFeatureExtractor,
432            hnsw_extractor: HnswFeatureExtractor,
433            confidence_extractor: ConfidenceFeatureExtractor,
434        };
435        
436        let training_scheduler = Arc::new(RwLock::new(TrainingScheduler {
437            queries_since_update: 0,
438            update_frequency: config.update_frequency,
439            next_training_time: std::time::Instant::now(),
440            is_training: false,
441        }));
442        
443        let metrics = Arc::new(RwLock::new(LearningMetrics::default()));
444        
445        info!("Initialized learning-to-stop model with training window: {}", config.training_window_size);
446        
447        Ok(Self {
448            wand_predictor,
449            hnsw_predictor,
450            confidence_model,
451            feature_extractors,
452            training_scheduler,
453            metrics,
454            config,
455        })
456    }
457    
458    /// Predict WAND early stopping decision
459    pub async fn predict_wand_stopping(
460        &self,
461        context: &QueryContext,
462        state: &WandSearchState,
463    ) -> Result<LearnedStoppingDecision> {
464        // Extract features for WAND prediction
465        let features = self.feature_extractors.wand_extractor.extract_features(context, state);
466        
467        // Get prediction from WAND predictor
468        let wand_predictor = self.wand_predictor.read().await;
469        let (should_stop, confidence) = wand_predictor.predict(&features);
470        
471        // Estimate quality and computation savings
472        let predicted_quality = self.estimate_wand_quality(&features);
473        let computation_saved = self.estimate_computation_saved(state.iteration, self.config.wand_config.max_iterations);
474        
475        // Build reasoning
476        let reasoning = self.build_wand_reasoning(&features, should_stop, confidence);
477        
478        // Update metrics
479        self.update_prediction_metrics("wand", should_stop, confidence).await;
480        
481        Ok(LearnedStoppingDecision {
482            should_stop,
483            confidence,
484            predicted_quality,
485            estimated_computation_saved: computation_saved,
486            reasoning,
487            algorithm_used: "WAND-Learned".to_string(),
488        })
489    }
490    
491    /// Predict HNSW early stopping decision
492    pub async fn predict_hnsw_stopping(
493        &self,
494        context: &QueryContext,
495        state: &HnswSearchState,
496    ) -> Result<LearnedStoppingDecision> {
497        // Extract features for HNSW prediction
498        let features = self.feature_extractors.hnsw_extractor.extract_features(context, state);
499        
500        // Get prediction from HNSW predictor
501        let hnsw_predictor = self.hnsw_predictor.read().await;
502        let (should_stop, confidence) = hnsw_predictor.predict(&features);
503        
504        // Estimate quality and computation savings
505        let predicted_quality = self.estimate_hnsw_quality(&features, state);
506        let computation_saved = self.estimate_hnsw_computation_saved(state);
507        
508        // Build reasoning
509        let reasoning = self.build_hnsw_reasoning(&features, should_stop, confidence);
510        
511        // Update metrics
512        self.update_prediction_metrics("hnsw", should_stop, confidence).await;
513        
514        Ok(LearnedStoppingDecision {
515            should_stop,
516            confidence,
517            predicted_quality,
518            estimated_computation_saved: computation_saved,
519            reasoning,
520            algorithm_used: "HNSW-Learned".to_string(),
521        })
522    }
523    
524    /// Train models with new feedback
525    pub async fn train_with_feedback(
526        &self,
527        query_type: &str,
528        decision: &LearnedStoppingDecision,
529        actual_quality: f64,
530        actual_computation_saved: f64,
531    ) -> Result<()> {
532        let mut scheduler = self.training_scheduler.write().await;
533        scheduler.queries_since_update += 1;
534        
535        match query_type {
536            "wand" => {
537                let mut predictor = self.wand_predictor.write().await;
538                predictor.add_training_sample(WandTrainingSample {
539                    features: HashMap::new(), // Would be populated with actual features
540                    should_have_stopped: decision.should_stop,
541                    actual_quality,
542                    computation_saved: actual_computation_saved,
543                    timestamp: std::time::Instant::now(),
544                });
545            }
546            "hnsw" => {
547                let mut predictor = self.hnsw_predictor.write().await;
548                predictor.add_training_sample(HnswTrainingSample {
549                    features: HashMap::new(), // Would be populated with actual features
550                    optimal_stopping_point: 0, // Would be calculated
551                    final_quality: actual_quality,
552                    search_efficiency: actual_computation_saved,
553                    timestamp: std::time::Instant::now(),
554                });
555            }
556            "confidence" => {
557                // Handle confidence training
558                // For now, we'll accept the feedback without error
559                // This could be extended to actually train a confidence model
560            }
561            _ => return Err(anyhow!("Unknown query type: {}", query_type)),
562        }
563        
564        // Update models if needed
565        if scheduler.queries_since_update >= scheduler.update_frequency {
566            self.update_models().await?;
567            scheduler.queries_since_update = 0;
568        }
569        
570        Ok(())
571    }
572    
573    /// Update all models with accumulated training data
574    async fn update_models(&self) -> Result<()> {
575        let mut scheduler = self.training_scheduler.write().await;
576        
577        if scheduler.is_training {
578            return Ok(()); // Already training
579        }
580        
581        scheduler.is_training = true;
582        drop(scheduler);
583        
584        // Update WAND predictor
585        {
586            let mut wand_predictor = self.wand_predictor.write().await;
587            wand_predictor.update_model(self.config.learning_rate)?;
588        }
589        
590        // Update HNSW predictor
591        {
592            let mut hnsw_predictor = self.hnsw_predictor.write().await;
593            hnsw_predictor.update_model(self.config.learning_rate)?;
594        }
595        
596        // Update confidence model
597        {
598            let mut confidence_model = self.confidence_model.write().await;
599            confidence_model.update_calibration()?;
600        }
601        
602        // Reset training flag
603        {
604            let mut scheduler = self.training_scheduler.write().await;
605            scheduler.is_training = false;
606            scheduler.next_training_time = std::time::Instant::now() + std::time::Duration::from_secs(300); // 5 min
607        }
608        
609        // Update adaptation metrics
610        {
611            let mut metrics = self.metrics.write().await;
612            metrics.adaptation_events += 1;
613        }
614        
615        info!("Updated learning models with new training data");
616        
617        Ok(())
618    }
619    
620    /// Estimate WAND quality from features
621    fn estimate_wand_quality(&self, features: &HashMap<WandFeature, f64>) -> f64 {
622        let score_improvement = features.get(&WandFeature::ScoreImprovement).unwrap_or(&0.0);
623        let quality_estimate = features.get(&WandFeature::QualityEstimate).unwrap_or(&0.5);
624        let threshold_convergence = features.get(&WandFeature::ThresholdConvergence).unwrap_or(&0.0);
625        
626        // Simple quality estimation model
627        (score_improvement * 0.4 + quality_estimate * 0.4 + threshold_convergence * 0.2).min(1.0)
628    }
629    
630    /// Estimate HNSW quality from features and state
631    fn estimate_hnsw_quality(&self, features: &HashMap<HnswFeature, f64>, state: &HnswSearchState) -> f64 {
632        let distance_improvement = features.get(&HnswFeature::DistanceImprovement).unwrap_or(&0.0);
633        let exploration_ratio = features.get(&HnswFeature::ExplorationRatio).unwrap_or(&0.5);
634        
635        let distance_quality = if state.best_distance > 0.0 {
636            (1.0 - state.best_distance).max(0.0)
637        } else {
638            0.0
639        };
640        
641        (distance_improvement * 0.3 + exploration_ratio * 0.3 + distance_quality as f64 * 0.4).min(1.0)
642    }
643    
644    /// Estimate computation saved based on early stopping
645    fn estimate_computation_saved(&self, current_iteration: usize, max_iterations: usize) -> f64 {
646        if max_iterations == 0 {
647            return 0.0;
648        }
649        
650        let remaining_iterations = max_iterations.saturating_sub(current_iteration);
651        remaining_iterations as f64 / max_iterations as f64
652    }
653    
654    /// Estimate computation saved for HNSW search
655    fn estimate_hnsw_computation_saved(&self, state: &HnswSearchState) -> f64 {
656        let max_possible_visits = self.config.hnsw_config.max_neighbors * self.config.hnsw_config.max_layers;
657        let remaining_visits = max_possible_visits.saturating_sub(state.visited_nodes);
658        
659        remaining_visits as f64 / max_possible_visits as f64
660    }
661    
662    /// Build reasoning for WAND stopping decision
663    fn build_wand_reasoning(&self, features: &HashMap<WandFeature, f64>, should_stop: bool, confidence: f64) -> StoppingReasoning {
664        let mut feature_contributions = HashMap::new();
665        
666        // Calculate feature contributions (simplified)
667        for (feature, value) in features {
668            let contribution = value * confidence; // Weighted by confidence
669            feature_contributions.insert(format!("{:?}", feature), contribution);
670        }
671        
672        let primary_factor = if should_stop {
673            "Score convergence detected"
674        } else {
675            "Continued exploration needed"
676        }.to_string();
677        
678        StoppingReasoning {
679            primary_factor,
680            feature_contributions,
681            threshold_exceeded: confidence > self.config.confidence_threshold,
682            quality_sufficient: features.get(&WandFeature::QualityEstimate).unwrap_or(&0.0) > &self.config.wand_config.quality_threshold,
683        }
684    }
685    
686    /// Build reasoning for HNSW stopping decision
687    fn build_hnsw_reasoning(&self, features: &HashMap<HnswFeature, f64>, should_stop: bool, confidence: f64) -> StoppingReasoning {
688        let mut feature_contributions = HashMap::new();
689        
690        for (feature, value) in features {
691            let contribution = value * confidence;
692            feature_contributions.insert(format!("{:?}", feature), contribution);
693        }
694        
695        let primary_factor = if should_stop {
696            "Distance threshold reached"
697        } else {
698            "Further exploration beneficial"
699        }.to_string();
700        
701        StoppingReasoning {
702            primary_factor,
703            feature_contributions,
704            threshold_exceeded: confidence > self.config.confidence_threshold,
705            quality_sufficient: features.get(&HnswFeature::DistanceToQuery).unwrap_or(&1.0) < &self.config.hnsw_config.distance_threshold,
706        }
707    }
708    
709    /// Update prediction metrics
710    async fn update_prediction_metrics(&self, algorithm: &str, prediction: bool, confidence: f64) {
711        let mut metrics = self.metrics.write().await;
712        metrics.total_predictions += 1;
713        
714        // Update algorithm-specific metrics (simplified)
715        if prediction {
716            debug!("Predicted early stop for {} with confidence {:.3}", algorithm, confidence);
717        }
718    }
719    
720    /// Get current learning metrics
721    pub async fn get_metrics(&self) -> LearningMetrics {
722        self.metrics.read().await.clone()
723    }
724    
    /// Borrow the configuration this model was built with (exposed for testing).
    pub fn config(&self) -> &LearningConfig {
        &self.config
    }
729}
730
731impl WandStoppingPredictor {
732    pub fn new(_config: WandLearningConfig) -> Self {
733        let mut weights = HashMap::new();
734        
735        // Initialize feature weights
736        weights.insert(WandFeature::IterationCount, -0.1);
737        weights.insert(WandFeature::ScoreImprovement, 0.8);
738        weights.insert(WandFeature::TermContribution, 0.6);
739        weights.insert(WandFeature::QualityEstimate, 0.9);
740        weights.insert(WandFeature::ThresholdConvergence, 0.7);
741        
742        Self {
743            weights,
744            training_history: VecDeque::new(),
745            accuracy: 0.5,
746            precision: 0.5,
747            recall: 0.5,
748            is_trained: false,
749            last_update: std::time::Instant::now(),
750        }
751    }
752    
753    pub fn predict(&self, features: &HashMap<WandFeature, f64>) -> (bool, f64) {
754        let mut score = 0.0;
755        let mut feature_count = 0;
756        
757        for (feature, weight) in &self.weights {
758            if let Some(feature_value) = features.get(feature) {
759                score += feature_value * weight;
760                feature_count += 1;
761            }
762        }
763        
764        if feature_count > 0 {
765            score /= feature_count as f64;
766        }
767        
768        let confidence = (score.tanh() + 1.0) / 2.0; // Normalize to [0,1]
769        let should_stop = confidence > 0.5;
770        
771        (should_stop, confidence)
772    }
773    
774    pub fn add_training_sample(&mut self, sample: WandTrainingSample) {
775        self.training_history.push_back(sample);
776        
777        // Keep training window size limited
778        while self.training_history.len() > 1000 {
779            self.training_history.pop_front();
780        }
781    }
782    
783    pub fn update_model(&mut self, learning_rate: f64) -> Result<()> {
784        if self.training_history.len() < 10 {
785            return Ok(()); // Not enough data
786        }
787        
788        // Simple gradient descent update (simplified)
789        for sample in self.training_history.iter().rev().take(100) {
790            let (predicted, _) = self.predict(&sample.features);
791            let error = if sample.should_have_stopped { 1.0 } else { 0.0 } - if predicted { 1.0 } else { 0.0 };
792            
793            // Update weights based on error
794            for (feature, feature_value) in &sample.features {
795                if let Some(weight) = self.weights.get_mut(feature) {
796                    *weight += learning_rate * error * feature_value;
797                }
798            }
799        }
800        
801        self.last_update = std::time::Instant::now();
802        self.is_trained = true;
803        
804        Ok(())
805    }
806    
    // Getter methods for testing

    /// Current per-feature weights.
    pub fn weights(&self) -> &HashMap<WandFeature, f64> {
        &self.weights
    }

    /// Whether `update_model` has completed at least one update.
    pub fn is_trained(&self) -> bool {
        self.is_trained
    }

    /// Stored accuracy figure (initialized to 0.5).
    pub fn accuracy(&self) -> f64 {
        self.accuracy
    }

    /// Stored precision figure (initialized to 0.5).
    pub fn precision(&self) -> f64 {
        self.precision
    }

    /// Stored recall figure (initialized to 0.5).
    pub fn recall(&self) -> f64 {
        self.recall
    }

    /// Buffered training samples, oldest first.
    pub fn training_history(&self) -> &VecDeque<WandTrainingSample> {
        &self.training_history
    }
831}
832
833impl HnswStoppingPredictor {
834    pub fn new(_config: HnswLearningConfig) -> Self {
835        let mut layer_thresholds = HashMap::new();
836        let mut neighbor_quality_weights = HashMap::new();
837        
838        // Initialize per-layer thresholds
839        for layer in 0..5 {
840            layer_thresholds.insert(layer, 0.1 * (layer + 1) as f64);
841        }
842        
843        // Initialize feature weights
844        neighbor_quality_weights.insert(HnswFeature::DistanceToQuery, 0.9);
845        neighbor_quality_weights.insert(HnswFeature::DistanceImprovement, 0.8);
846        neighbor_quality_weights.insert(HnswFeature::ExplorationRatio, 0.6);
847        neighbor_quality_weights.insert(HnswFeature::GraphConnectivity, 0.4);
848        
849        Self {
850            layer_thresholds,
851            neighbor_quality_weights,
852            training_samples: VecDeque::new(),
853            search_efficiency: 0.5,
854            quality_maintained: 0.5,
855            beam_width_adaptation: 1.0,
856            exploration_decay: 0.95,
857        }
858    }
859    
860    pub fn predict(&self, features: &HashMap<HnswFeature, f64>) -> (bool, f64) {
861        let mut quality_score = 0.0;
862        let mut feature_count = 0;
863        
864        for (feature, weight) in &self.neighbor_quality_weights {
865            if let Some(feature_value) = features.get(feature) {
866                quality_score += feature_value * weight;
867                feature_count += 1;
868            }
869        }
870        
871        if feature_count > 0 {
872            quality_score /= feature_count as f64;
873        }
874        
875        let confidence = quality_score.min(1.0).max(0.0);
876        let should_stop = confidence > 0.7; // Higher threshold for HNSW
877        
878        (should_stop, confidence)
879    }
880    
881    pub fn add_training_sample(&mut self, sample: HnswTrainingSample) {
882        self.training_samples.push_back(sample);
883        
884        while self.training_samples.len() > 1000 {
885            self.training_samples.pop_front();
886        }
887    }
888    
889    pub fn update_model(&mut self, learning_rate: f64) -> Result<()> {
890        if self.training_samples.len() < 10 {
891            return Ok(());
892        }
893        
894        // Update model parameters based on training samples
895        // This is a simplified version - real implementation would use more sophisticated ML
896        let recent_samples: Vec<_> = self.training_samples.iter().rev().take(50).collect();
897        
898        let avg_efficiency: f64 = recent_samples.iter().map(|s| s.search_efficiency).sum::<f64>() / recent_samples.len() as f64;
899        let avg_quality: f64 = recent_samples.iter().map(|s| s.final_quality).sum::<f64>() / recent_samples.len() as f64;
900        
901        // Adapt parameters
902        self.search_efficiency = self.search_efficiency * (1.0 - learning_rate) + avg_efficiency * learning_rate;
903        self.quality_maintained = self.quality_maintained * (1.0 - learning_rate) + avg_quality * learning_rate;
904        
905        // Adapt exploration parameters
906        if avg_efficiency < 0.6 {
907            self.beam_width_adaptation *= 1.1; // Increase beam width
908        } else if avg_efficiency > 0.8 {
909            self.beam_width_adaptation *= 0.95; // Decrease beam width
910        }
911        
912        Ok(())
913    }
914    
915    // Getter methods for testing
916    pub fn layer_thresholds(&self) -> &HashMap<usize, f64> {
917        &self.layer_thresholds
918    }
919    
920    pub fn neighbor_quality_weights(&self) -> &HashMap<HnswFeature, f64> {
921        &self.neighbor_quality_weights
922    }
923    
924    pub fn training_samples(&self) -> &VecDeque<HnswTrainingSample> {
925        &self.training_samples
926    }
927    
928    pub fn search_efficiency(&self) -> f64 {
929        self.search_efficiency
930    }
931    
932    pub fn quality_maintained(&self) -> f64 {
933        self.quality_maintained
934    }
935    
936    pub fn beam_width_adaptation(&self) -> f64 {
937        self.beam_width_adaptation
938    }
939    
940    pub fn exploration_decay(&self) -> f64 {
941        self.exploration_decay
942    }
943}
944
945impl ConfidenceModel {
946    pub fn new() -> Self {
947        Self {
948            confidence_predictors: HashMap::new(),
949            calibration_params: CalibrationParams {
950                temperature: 1.0,
951                shift: 0.0,
952                scale: 1.0,
953            },
954            confidence_accuracy: 0.5,
955            calibration_data: VecDeque::new(),
956        }
957    }
958    
959    pub fn update_calibration(&mut self) -> Result<()> {
960        // Update confidence calibration based on historical data
961        if self.calibration_data.len() < 20 {
962            return Ok(());
963        }
964        
965        // Simple calibration update (real implementation would use isotonic regression)
966        let recent_data: Vec<_> = self.calibration_data.iter().rev().take(100).collect();
967        
968        let avg_predicted: f64 = recent_data.iter().map(|s| s.predicted_confidence).sum::<f64>() / recent_data.len() as f64;
969        let avg_actual: f64 = recent_data.iter().map(|s| s.actual_quality).sum::<f64>() / recent_data.len() as f64;
970        
971        // Adjust calibration parameters
972        if (avg_predicted - avg_actual).abs() > 0.1 {
973            self.calibration_params.shift += (avg_actual - avg_predicted) * 0.01;
974        }
975        
976        Ok(())
977    }
978    
979    // Getter methods for testing
980    pub fn confidence_predictors(&self) -> &HashMap<ConfidenceFeature, LinearPredictor> {
981        &self.confidence_predictors
982    }
983    
984    pub fn calibration_params(&self) -> &CalibrationParams {
985        &self.calibration_params
986    }
987    
988    pub fn confidence_accuracy(&self) -> f64 {
989        self.confidence_accuracy
990    }
991    
992    pub fn calibration_data(&self) -> &VecDeque<ConfidenceTrainingSample> {
993        &self.calibration_data
994    }
995    
996    // Method to add calibration data for testing
997    pub fn add_calibration_sample(&mut self, sample: ConfidenceTrainingSample) {
998        self.calibration_data.push_back(sample);
999    }
1000}
1001
1002// Feature extractor implementations
1003impl WandFeatureExtractor {
1004    pub fn extract_features(&self, context: &QueryContext, state: &WandSearchState) -> HashMap<WandFeature, f64> {
1005        let mut features = HashMap::new();
1006        
1007        let elapsed = context.start_time.elapsed().as_millis() as f64;
1008        
1009        features.insert(WandFeature::IterationCount, state.iteration as f64);
1010        features.insert(WandFeature::TimeElapsed, elapsed);
1011        features.insert(WandFeature::CandidateSetSize, state.candidate_count as f64);
1012        
1013        // Calculate score improvement trend
1014        let score_improvement = if state.score_improvements.len() >= 2 {
1015            let recent = &state.score_improvements[state.score_improvements.len()-2..];
1016            recent[1] - recent[0]
1017        } else {
1018            0.0
1019        };
1020        features.insert(WandFeature::ScoreImprovement, score_improvement);
1021        
1022        // Average term contribution
1023        let avg_term_contribution = if !state.term_contributions.is_empty() {
1024            state.term_contributions.values().sum::<f64>() / state.term_contributions.len() as f64
1025        } else {
1026            0.0
1027        };
1028        features.insert(WandFeature::TermContribution, avg_term_contribution);
1029        
1030        // Quality estimate based on threshold stability
1031        let quality_estimate = if state.iteration > 0 { 
1032            (state.current_threshold / state.iteration as f64).min(1.0) 
1033        } else { 
1034            0.5 
1035        };
1036        features.insert(WandFeature::QualityEstimate, quality_estimate);
1037        
1038        // Threshold convergence
1039        let threshold_convergence = if state.score_improvements.len() >= 3 {
1040            let recent_variance: f64 = state.score_improvements.iter().rev().take(3)
1041                .map(|&x| (x - score_improvement).powi(2))
1042                .sum::<f64>() / 3.0;
1043            (1.0 / (1.0 + recent_variance)).min(1.0)
1044        } else {
1045            0.0
1046        };
1047        features.insert(WandFeature::ThresholdConvergence, threshold_convergence);
1048        
1049        features
1050    }
1051}
1052
1053impl HnswFeatureExtractor {
1054    pub fn extract_features(&self, context: &QueryContext, state: &HnswSearchState) -> HashMap<HnswFeature, f64> {
1055        let mut features = HashMap::new();
1056        
1057        features.insert(HnswFeature::LayerDepth, state.current_layer as f64);
1058        features.insert(HnswFeature::DistanceToQuery, state.best_distance as f64);
1059        features.insert(HnswFeature::BeamPosition, state.beam_candidates.len() as f64);
1060        features.insert(HnswFeature::ExplorationRatio, state.exploration_ratio);
1061        
1062        // Average neighbor count
1063        let avg_neighbors = if !state.beam_candidates.is_empty() {
1064            state.beam_candidates.iter().map(|c| c.neighbor_count as f64).sum::<f64>() / state.beam_candidates.len() as f64
1065        } else {
1066            0.0
1067        };
1068        features.insert(HnswFeature::NeighborCount, avg_neighbors);
1069        
1070        // Distance improvement (if we have previous best distances)
1071        let distance_improvement = if state.best_distance < 1.0 {
1072            1.0 - state.best_distance as f64
1073        } else {
1074            0.0
1075        };
1076        features.insert(HnswFeature::DistanceImprovement, distance_improvement);
1077        
1078        // Graph connectivity estimate
1079        let connectivity = if !state.beam_candidates.is_empty() {
1080            let total_connections: usize = state.beam_candidates.iter().map(|c| c.neighbor_count).sum();
1081            (total_connections as f64 / state.beam_candidates.len() as f64) / 16.0 // Normalize by max degree
1082        } else {
1083            0.0
1084        };
1085        features.insert(HnswFeature::GraphConnectivity, connectivity);
1086        
1087        features
1088    }
1089}
1090
1091impl ConfidenceFeatureExtractor {
1092    pub fn extract_features(&self, _context: &QueryContext, result_count: usize, processing_time: f64) -> HashMap<ConfidenceFeature, f64> {
1093        let mut features = HashMap::new();
1094        
1095        features.insert(ConfidenceFeature::ResultCount, result_count as f64);
1096        features.insert(ConfidenceFeature::ProcessingTime, processing_time);
1097        
1098        // Simplified features for confidence prediction
1099        let score_distribution = if result_count > 0 { 1.0 } else { 0.0 };
1100        features.insert(ConfidenceFeature::ScoreDistribution, score_distribution);
1101        
1102        features
1103    }
1104}
1105
#[cfg(test)]
mod tests {
    use super::*;

    /// The coordinator can be constructed from the default configuration.
    #[tokio::test]
    async fn test_learning_model_creation() {
        let config = LearningConfig::default();
        let model = LearningStopModel::new(config).await;
        assert!(model.is_ok());
    }

    /// WAND feature extraction reports the values implied by the search state.
    /// Purely synchronous, so no async runtime is needed.
    #[test]
    fn test_wand_feature_extraction() {
        let extractor = WandFeatureExtractor;
        let context = QueryContext {
            query_terms: vec!["test".to_string()],
            query_vector: None,
            start_time: std::time::Instant::now(),
            complexity_score: 0.5,
            expected_result_count: 10,
        };

        let state = WandSearchState {
            iteration: 5,
            current_threshold: 0.8,
            candidate_count: 20,
            score_improvements: vec![0.1, 0.15, 0.18],
            term_contributions: HashMap::new(),
            processing_time: std::time::Duration::from_millis(50),
        };

        let features = extractor.extract_features(&context, &state);

        assert!(features.contains_key(&WandFeature::IterationCount));
        assert!(features.contains_key(&WandFeature::ScoreImprovement));
        assert!(features.contains_key(&WandFeature::CandidateSetSize));

        assert_eq!(features[&WandFeature::IterationCount], 5.0);
        assert_eq!(features[&WandFeature::CandidateSetSize], 20.0);
        // No term contributions were supplied, so the average falls back to zero.
        assert_eq!(features[&WandFeature::TermContribution], 0.0);
        // Difference between the two most recent improvements: 0.18 - 0.15.
        assert!((features[&WandFeature::ScoreImprovement] - 0.03).abs() < 1e-9);
        // Threshold per iteration: 0.8 / 5.
        assert!((features[&WandFeature::QualityEstimate] - 0.16).abs() < 1e-9);
    }

    /// Stronger WAND features must yield a higher stopping confidence than
    /// weaker ones. Only the confidence half of the prediction is inspected.
    #[test]
    fn test_wand_predictor() {
        let config = WandLearningConfig::default();
        let predictor = WandStoppingPredictor::new(config);

        let mut features = HashMap::new();
        features.insert(WandFeature::QualityEstimate, 0.9);
        features.insert(WandFeature::ScoreImprovement, 0.1);
        features.insert(WandFeature::ThresholdConvergence, 0.8);

        let (_, confidence) = predictor.predict(&features);

        // With high quality features, should have reasonable confidence
        assert!(confidence > 0.3);

        // Test with low quality features
        let mut low_features = HashMap::new();
        low_features.insert(WandFeature::QualityEstimate, 0.1);
        low_features.insert(WandFeature::ScoreImprovement, 0.01);

        let (_, low_confidence) = predictor.predict(&low_features);

        // Low quality should result in lower confidence
        assert!(low_confidence < confidence);
    }

    /// HNSW feature extraction reports the values implied by the search state.
    /// Purely synchronous, so no async runtime is needed.
    #[test]
    fn test_hnsw_feature_extraction() {
        let extractor = HnswFeatureExtractor;
        let context = QueryContext {
            query_terms: vec![],
            query_vector: Some(vec![0.1, 0.2, 0.3]),
            start_time: std::time::Instant::now(),
            complexity_score: 0.7,
            expected_result_count: 5,
        };

        let candidates = vec![
            HnswCandidate { node_id: 1, distance: 0.1, layer: 0, neighbor_count: 8 },
            HnswCandidate { node_id: 2, distance: 0.15, layer: 0, neighbor_count: 12 },
        ];

        let state = HnswSearchState {
            current_layer: 1,
            beam_candidates: candidates,
            visited_nodes: 25,
            best_distance: 0.1,
            exploration_ratio: 0.6,
        };

        let features = extractor.extract_features(&context, &state);

        assert!(features.contains_key(&HnswFeature::LayerDepth));
        assert!(features.contains_key(&HnswFeature::DistanceToQuery));
        assert!(features.contains_key(&HnswFeature::ExplorationRatio));

        assert_eq!(features[&HnswFeature::LayerDepth], 1.0);
        assert_eq!(features[&HnswFeature::ExplorationRatio], 0.6);
        // Two candidates with 8 and 12 neighbours: mean 10, normalised by 16.
        assert_eq!(features[&HnswFeature::BeamPosition], 2.0);
        assert_eq!(features[&HnswFeature::NeighborCount], 10.0);
        assert_eq!(features[&HnswFeature::GraphConnectivity], 0.625);
    }

    /// An end-to-end WAND stopping decision stays within its documented ranges
    /// and carries per-feature reasoning.
    #[tokio::test]
    async fn test_learning_model_prediction() {
        let config = LearningConfig::default();
        let model = LearningStopModel::new(config).await.unwrap();

        let context = QueryContext {
            query_terms: vec!["function".to_string(), "test".to_string()],
            query_vector: None,
            start_time: std::time::Instant::now(),
            complexity_score: 0.6,
            expected_result_count: 15,
        };

        let wand_state = WandSearchState {
            iteration: 10,
            current_threshold: 0.75,
            candidate_count: 30,
            score_improvements: vec![0.2, 0.25, 0.27],
            term_contributions: HashMap::new(),
            processing_time: std::time::Duration::from_millis(80),
        };

        let decision = model.predict_wand_stopping(&context, &wand_state).await.unwrap();

        assert!((0.0..=1.0).contains(&decision.confidence));
        assert!((0.0..=1.0).contains(&decision.predicted_quality));
        assert!(decision.estimated_computation_saved >= 0.0);
        assert!(!decision.reasoning.feature_contributions.is_empty());
    }
}