lethe_core_rust/
ml_prediction.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use crate::error::Result;
4use crate::query_understanding::{QueryUnderstanding, QueryType, QueryIntent, QueryComplexity};
5
/// Default per-feature weights, stored as `(feature name, weight)` pairs.
///
/// `MLPredictionConfig::default` copies this table into its
/// `feature_weights` map. The five weights add up to 1.0.
static FEATURE_WEIGHTS: &[(&str, f32)] = &[
    ("query_length", 0.15),
    ("complexity", 0.25),
    ("technical_terms", 0.20),
    ("domain_specificity", 0.15),
    ("semantic_complexity", 0.25),
];
14
15/// Static strategy weight configurations
16static STRATEGY_WEIGHTS: &[(RetrievalStrategy, f32)] = &[
17    (RetrievalStrategy::BM25Only, 1.0),
18    (RetrievalStrategy::VectorOnly, 1.0),
19    (RetrievalStrategy::Hybrid, 1.2),
20    (RetrievalStrategy::HydeEnhanced, 0.8),
21    (RetrievalStrategy::MultiStep, 0.9),
22    (RetrievalStrategy::Adaptive, 1.1),
23];
24
25/// Static feature scoring rules to replace complex if-statements
26struct FeatureScoringRule {
27    condition: fn(&MLFeatures) -> bool,
28    strategy: RetrievalStrategy,
29    score: f32,
30}
31
32static FEATURE_SCORING_RULES: &[FeatureScoringRule] = &[
33    FeatureScoringRule {
34        condition: |f| f.semantic_complexity > 0.7,
35        strategy: RetrievalStrategy::VectorOnly,
36        score: 0.3,
37    },
38    FeatureScoringRule {
39        condition: |f| f.semantic_complexity > 0.7,
40        strategy: RetrievalStrategy::HydeEnhanced,
41        score: 0.2,
42    },
43    FeatureScoringRule {
44        condition: |f| f.technical_term_count > 0.5 || f.has_code > 0.5,
45        strategy: RetrievalStrategy::BM25Only,
46        score: 0.3,
47    },
48    FeatureScoringRule {
49        condition: |f| f.query_complexity_score > 0.6,
50        strategy: RetrievalStrategy::Hybrid,
51        score: 0.4,
52    },
53    FeatureScoringRule {
54        condition: |f| f.query_complexity_score > 0.6,
55        strategy: RetrievalStrategy::MultiStep,
56        score: 0.2,
57    },
58    FeatureScoringRule {
59        condition: |f| f.domain_specificity < 0.5,
60        strategy: RetrievalStrategy::Adaptive,
61        score: 0.2,
62    },
63];
64
/// Names of the fields of `MLFeatures`, in declaration order.
///
/// Reported back to callers as the list of features used for a prediction.
static FEATURE_NAMES: &[&str] = &[
    "query_length",
    "query_complexity_score",
    "technical_term_count",
    "question_word_presence",
    "domain_specificity",
    "has_code",
    "has_numbers",
    "intent_score",
    "semantic_complexity",
];
77
78/// Static strategy name mappings
79static STRATEGY_NAMES: &[(RetrievalStrategy, &str)] = &[
80    (RetrievalStrategy::BM25Only, "BM25-only"),
81    (RetrievalStrategy::VectorOnly, "Vector-only"),
82    (RetrievalStrategy::Hybrid, "Hybrid"),
83    (RetrievalStrategy::HydeEnhanced, "HyDE-enhanced"),
84    (RetrievalStrategy::MultiStep, "Multi-step"),
85    (RetrievalStrategy::Adaptive, "Adaptive"),
86];
87
88/// Static complexity scoring patterns
89static COMPLEXITY_SCORES: &[(QueryComplexity, f32)] = &[
90    (QueryComplexity::Simple, 0.2),
91    (QueryComplexity::Medium, 0.5),
92    (QueryComplexity::Complex, 0.8),
93    (QueryComplexity::VeryComplex, 1.0),
94];
95
96/// Static intent scoring patterns
97static INTENT_SCORES: &[(QueryIntent, f32)] = &[
98    (QueryIntent::Search, 0.8),
99    (QueryIntent::Explain, 0.6),
100    (QueryIntent::Code, 1.0),
101    (QueryIntent::Debug, 0.9),
102    (QueryIntent::Compare, 0.7),
103    (QueryIntent::Guide, 0.5),
104    (QueryIntent::Assist, 0.4),
105    (QueryIntent::Chat, 0.2),
106];
107
108/// ML model prediction for retrieval strategy selection
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct RetrievalStrategyPrediction {
111    pub strategy: RetrievalStrategy,
112    pub confidence: f32,
113    pub features_used: Vec<String>,
114    pub alternatives: Vec<(RetrievalStrategy, f32)>,
115}
116
117/// Available retrieval strategies
118#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
119pub enum RetrievalStrategy {
120    /// Pure BM25 lexical search
121    BM25Only,
122    /// Pure vector similarity search
123    VectorOnly,
124    /// Hybrid BM25 + vector search
125    Hybrid,
126    /// HyDE-enhanced vector search
127    HydeEnhanced,
128    /// Multi-step retrieval with reranking
129    MultiStep,
130    /// Adaptive strategy based on query
131    Adaptive,
132}
133
134/// Feature vector for ML prediction
135#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct MLFeatures {
137    pub query_length: f32,
138    pub query_complexity_score: f32,
139    pub technical_term_count: f32,
140    pub question_word_presence: f32,
141    pub domain_specificity: f32,
142    pub has_code: f32,
143    pub has_numbers: f32,
144    pub intent_score: f32,
145    pub semantic_complexity: f32,
146}
147
148/// ML prediction result with explanations
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct MLPredictionResult {
151    pub prediction: RetrievalStrategyPrediction,
152    pub explanation: String,
153    pub feature_importance: HashMap<String, f32>,
154    pub model_confidence: f32,
155}
156
157/// Configuration for ML prediction service
158#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct MLPredictionConfig {
160    pub enable_hybrid_fallback: bool,
161    pub confidence_threshold: f32,
162    pub feature_weights: HashMap<String, f32>,
163    pub strategy_weights: HashMap<RetrievalStrategy, f32>,
164}
165
166impl Default for MLPredictionConfig {
167    fn default() -> Self {
168        let feature_weights = FEATURE_WEIGHTS
169            .iter()
170            .map(|(k, v)| (k.to_string(), *v))
171            .collect();
172            
173        let strategy_weights = STRATEGY_WEIGHTS
174            .iter()
175            .map(|(k, v)| (k.clone(), *v))
176            .collect();
177
178        Self {
179            enable_hybrid_fallback: true,
180            confidence_threshold: 0.7,
181            feature_weights,
182            strategy_weights,
183        }
184    }
185}
186
187/// ML prediction service for retrieval strategy selection
188pub struct MLPredictionService {
189    _config: MLPredictionConfig,
190    strategy_rules: Vec<Box<dyn StrategyRule>>,
191}
192
193impl MLPredictionService {
194    pub fn new(config: MLPredictionConfig) -> Self {
195        let mut service = Self {
196            _config: config,
197            strategy_rules: Vec::new(),
198        };
199        
200        service.initialize_rules();
201        service
202    }
203
204    /// Predict the best retrieval strategy for a given query understanding
205    pub fn predict_strategy(&self, understanding: &QueryUnderstanding) -> Result<MLPredictionResult> {
206        let features = self.extract_features(understanding);
207        let (strategy_scores, explanations) = self.collect_strategy_scores(understanding, &features);
208        let prediction = self.create_prediction_from_scores(strategy_scores, &features);
209        let explanation = self.generate_explanation(&prediction, understanding, &explanations);
210        let feature_importance = self.calculate_feature_importance(&features);
211        let confidence = prediction.confidence;
212
213        Ok(MLPredictionResult {
214            prediction,
215            explanation,
216            feature_importance,
217            model_confidence: confidence,
218        })
219    }
220    
221    /// Collect strategy scores from rules and features
222    fn collect_strategy_scores(
223        &self, 
224        understanding: &QueryUnderstanding, 
225        features: &MLFeatures
226    ) -> (HashMap<RetrievalStrategy, f32>, Vec<String>) {
227        let mut strategy_scores: HashMap<RetrievalStrategy, f32> = HashMap::new();
228        let mut explanations = Vec::new();
229        
230        // Apply rule-based predictions
231        for rule in &self.strategy_rules {
232            if let Some(prediction) = rule.evaluate(understanding, features) {
233                *strategy_scores.entry(prediction.strategy.clone()).or_insert(0.0) += prediction.confidence;
234                explanations.push(prediction.explanation);
235            }
236        }
237        
238        // Apply feature-based scoring
239        self.apply_feature_scoring(features, &mut strategy_scores);
240        
241        (strategy_scores, explanations)
242    }
243    
244    /// Create prediction from strategy scores
245    fn create_prediction_from_scores(
246        &self,
247        strategy_scores: HashMap<RetrievalStrategy, f32>,
248        features: &MLFeatures
249    ) -> RetrievalStrategyPrediction {
250        let (best_strategy, best_score) = self.select_best_strategy(&strategy_scores);
251        let total_score: f32 = strategy_scores.values().sum();
252        let alternatives = self.create_alternatives(strategy_scores, &best_strategy, total_score);
253        
254        RetrievalStrategyPrediction {
255            strategy: best_strategy,
256            confidence: (best_score / total_score).min(1.0),
257            features_used: features.get_feature_names(),
258            alternatives,
259        }
260    }
261    
262    /// Select the best strategy from scores
263    fn select_best_strategy(&self, strategy_scores: &HashMap<RetrievalStrategy, f32>) -> (RetrievalStrategy, f32) {
264        strategy_scores
265            .iter()
266            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
267            .map(|(s, score)| (s.clone(), *score))
268            .unwrap_or((RetrievalStrategy::Hybrid, 0.5))
269    }
270    
271    /// Create alternative strategies list
272    fn create_alternatives(
273        &self,
274        strategy_scores: HashMap<RetrievalStrategy, f32>,
275        best_strategy: &RetrievalStrategy,
276        total_score: f32
277    ) -> Vec<(RetrievalStrategy, f32)> {
278        let mut alternatives: Vec<(RetrievalStrategy, f32)> = strategy_scores
279            .into_iter()
280            .filter(|(s, _)| s != best_strategy)
281            .map(|(s, score)| (s, score / total_score))
282            .collect();
283        
284        alternatives.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
285        alternatives
286    }
287
288    /// Extract ML features from query understanding
289    fn extract_features(&self, understanding: &QueryUnderstanding) -> MLFeatures {
290        let query_length = (understanding.original_query.len() as f32 / 100.0).min(2.0);
291        
292        let query_complexity_score = COMPLEXITY_SCORES
293            .iter()
294            .find(|(complexity, _)| *complexity == understanding.complexity)
295            .map(|(_, score)| *score)
296            .unwrap_or(0.5);
297
298        let technical_term_count = (understanding.features.technical_terms.len() as f32 / 10.0).min(1.0);
299        
300        let question_word_presence = if understanding.features.question_words.is_empty() {
301            0.0
302        } else {
303            (understanding.features.question_words.len() as f32 / 5.0).min(1.0)
304        };
305
306        let domain_specificity = understanding.domain.confidence;
307
308        let has_code = if understanding.features.has_code { 1.0 } else { 0.0 };
309        let has_numbers = if understanding.features.has_numbers { 1.0 } else { 0.0 };
310
311        let intent_score = INTENT_SCORES
312            .iter()
313            .find(|(intent, _)| *intent == understanding.intent)
314            .map(|(_, score)| *score)
315            .unwrap_or(0.5);
316
317        let semantic_complexity = self.calculate_semantic_complexity(understanding);
318
319        MLFeatures {
320            query_length,
321            query_complexity_score,
322            technical_term_count,
323            question_word_presence,
324            domain_specificity,
325            has_code,
326            has_numbers,
327            intent_score,
328            semantic_complexity,
329        }
330    }
331
332    /// Apply feature-based scoring to strategy predictions using static rules
333    fn apply_feature_scoring(&self, features: &MLFeatures, strategy_scores: &mut HashMap<RetrievalStrategy, f32>) {
334        for rule in FEATURE_SCORING_RULES {
335            if (rule.condition)(features) {
336                *strategy_scores.entry(rule.strategy.clone()).or_insert(0.0) += rule.score;
337            }
338        }
339    }
340
341    /// Calculate semantic complexity of the query
342    fn calculate_semantic_complexity(&self, understanding: &QueryUnderstanding) -> f32 {
343        let mut complexity = 0.0;
344
345        // Abstract concepts increase semantic complexity
346        if understanding.query_type == QueryType::Analytical || 
347           understanding.query_type == QueryType::Subjective {
348            complexity += 0.3;
349        }
350
351        // Multiple entities increase complexity
352        complexity += (understanding.entities.len() as f32 / 10.0).min(0.3);
353
354        // Long queries with few technical terms are more semantic
355        if understanding.features.word_count > 10 && understanding.features.technical_terms.len() < 3 {
356            complexity += 0.4;
357        }
358
359        complexity.min(1.0)
360    }
361
362    /// Generate human-readable explanation for the prediction
363    fn generate_explanation(
364        &self,
365        prediction: &RetrievalStrategyPrediction,
366        understanding: &QueryUnderstanding,
367        _rule_explanations: &[String],
368    ) -> String {
369        let mut explanation = format!(
370            "Selected {} strategy with {:.1}% confidence. ",
371            strategy_to_string(&prediction.strategy),
372            prediction.confidence * 100.0
373        );
374
375        // Add reasoning based on query characteristics
376        match prediction.strategy {
377            RetrievalStrategy::BM25Only => {
378                explanation.push_str("This strategy was chosen because the query contains specific technical terms or keywords that benefit from exact matching.");
379            }
380            RetrievalStrategy::VectorOnly => {
381                explanation.push_str("This strategy was chosen because the query is conceptual and would benefit from semantic similarity matching.");
382            }
383            RetrievalStrategy::Hybrid => {
384                explanation.push_str("This strategy combines both keyword matching and semantic similarity for comprehensive results.");
385            }
386            RetrievalStrategy::HydeEnhanced => {
387                explanation.push_str("This strategy uses hypothetical document generation to improve semantic matching for complex queries.");
388            }
389            RetrievalStrategy::MultiStep => {
390                explanation.push_str("This strategy uses multiple retrieval phases with reranking for high-precision results.");
391            }
392            RetrievalStrategy::Adaptive => {
393                explanation.push_str("This strategy dynamically adjusts based on initial results quality.");
394            }
395        }
396
397        // Add specific insights
398        if understanding.features.has_code {
399            explanation.push_str(" Code-related queries detected.");
400        }
401        if understanding.complexity == QueryComplexity::VeryComplex {
402            explanation.push_str(" High query complexity requires sophisticated retrieval.");
403        }
404
405        explanation
406    }
407
408    /// Calculate feature importance scores
409    fn calculate_feature_importance(&self, features: &MLFeatures) -> HashMap<String, f32> {
410        let mut importance = HashMap::new();
411        
412        importance.insert("query_length".to_string(), features.query_length * 0.15);
413        importance.insert("complexity".to_string(), features.query_complexity_score * 0.25);
414        importance.insert("technical_terms".to_string(), features.technical_term_count * 0.20);
415        importance.insert("domain_specificity".to_string(), features.domain_specificity * 0.15);
416        importance.insert("semantic_complexity".to_string(), features.semantic_complexity * 0.25);
417
418        importance
419    }
420
421    /// Initialize strategy selection rules
422    fn initialize_rules(&mut self) {
423        self.strategy_rules.push(Box::new(TechnicalQueryRule));
424        self.strategy_rules.push(Box::new(SemanticQueryRule));
425        self.strategy_rules.push(Box::new(ComplexQueryRule));
426        self.strategy_rules.push(Box::new(CodeQueryRule));
427        self.strategy_rules.push(Box::new(ComparisonQueryRule));
428    }
429}
430
431impl Default for MLPredictionService {
432    fn default() -> Self {
433        Self::new(MLPredictionConfig::default())
434    }
435}
436
437/// Rule-based prediction for strategy selection
438trait StrategyRule: Send + Sync {
439    fn evaluate(&self, understanding: &QueryUnderstanding, features: &MLFeatures) -> Option<RulePrediction>;
440}
441
442/// Individual rule prediction
443struct RulePrediction {
444    strategy: RetrievalStrategy,
445    confidence: f32,
446    explanation: String,
447}
448
449/// Rule for technical queries
450struct TechnicalQueryRule;
451
452impl StrategyRule for TechnicalQueryRule {
453    fn evaluate(&self, _understanding: &QueryUnderstanding, features: &MLFeatures) -> Option<RulePrediction> {
454        if features.technical_term_count > 0.6 || features.has_code > 0.5 {
455            Some(RulePrediction {
456                strategy: RetrievalStrategy::BM25Only,
457                confidence: 0.8,
458                explanation: "Technical terms favor keyword-based search".to_string(),
459            })
460        } else {
461            None
462        }
463    }
464}
465
466/// Rule for semantic queries
467struct SemanticQueryRule;
468
469impl StrategyRule for SemanticQueryRule {
470    fn evaluate(&self, _understanding: &QueryUnderstanding, features: &MLFeatures) -> Option<RulePrediction> {
471        if features.semantic_complexity > 0.7 && features.technical_term_count < 0.3 {
472            Some(RulePrediction {
473                strategy: RetrievalStrategy::VectorOnly,
474                confidence: 0.7,
475                explanation: "High semantic complexity favors vector search".to_string(),
476            })
477        } else {
478            None
479        }
480    }
481}
482
483/// Rule for complex queries
484struct ComplexQueryRule;
485
486impl StrategyRule for ComplexQueryRule {
487    fn evaluate(&self, understanding: &QueryUnderstanding, _features: &MLFeatures) -> Option<RulePrediction> {
488        if understanding.complexity == QueryComplexity::VeryComplex {
489            Some(RulePrediction {
490                strategy: RetrievalStrategy::MultiStep,
491                confidence: 0.6,
492                explanation: "Very complex queries benefit from multi-step retrieval".to_string(),
493            })
494        } else if understanding.complexity == QueryComplexity::Complex {
495            Some(RulePrediction {
496                strategy: RetrievalStrategy::Hybrid,
497                confidence: 0.7,
498                explanation: "Complex queries benefit from hybrid approach".to_string(),
499            })
500        } else {
501            None
502        }
503    }
504}
505
506/// Rule for code-related queries
507struct CodeQueryRule;
508
509impl StrategyRule for CodeQueryRule {
510    fn evaluate(&self, understanding: &QueryUnderstanding, _features: &MLFeatures) -> Option<RulePrediction> {
511        if understanding.query_type == QueryType::Technical && understanding.intent == QueryIntent::Code {
512            Some(RulePrediction {
513                strategy: RetrievalStrategy::BM25Only,
514                confidence: 0.9,
515                explanation: "Code queries require exact matching".to_string(),
516            })
517        } else {
518            None
519        }
520    }
521}
522
523/// Rule for comparison queries
524struct ComparisonQueryRule;
525
526impl StrategyRule for ComparisonQueryRule {
527    fn evaluate(&self, understanding: &QueryUnderstanding, _features: &MLFeatures) -> Option<RulePrediction> {
528        if understanding.query_type == QueryType::Comparative {
529            Some(RulePrediction {
530                strategy: RetrievalStrategy::HydeEnhanced,
531                confidence: 0.6,
532                explanation: "Comparison queries benefit from hypothetical document expansion".to_string(),
533            })
534        } else {
535            None
536        }
537    }
538}
539
540impl MLFeatures {
541    fn get_feature_names(&self) -> Vec<String> {
542        FEATURE_NAMES.iter().map(|s| s.to_string()).collect()
543    }
544}
545
546fn strategy_to_string(strategy: &RetrievalStrategy) -> &'static str {
547    STRATEGY_NAMES
548        .iter()
549        .find(|(s, _)| s == strategy)
550        .map(|(_, name)| *name)
551        .unwrap_or("Unknown")
552}
553
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query_understanding::{QueryDomain, QueryFeatures};

    /// Builds a small fixture query; technical queries carry two technical
    /// terms and code, analytical queries carry no technical terms, and all
    /// other types carry a single term.
    fn create_test_understanding(query_type: QueryType, intent: QueryIntent, complexity: QueryComplexity) -> QueryUnderstanding {
        let has_code = matches!(query_type, QueryType::Technical);
        let technical_terms: Vec<String> = match query_type {
            QueryType::Technical => vec!["code".to_string(), "api".to_string()],
            QueryType::Analytical => Vec::new(),
            _ => vec!["term".to_string()],
        };

        let domain = QueryDomain {
            primary_domain: "programming".to_string(),
            secondary_domains: vec![],
            confidence: 0.8,
        };

        let features = QueryFeatures {
            word_count: 5,
            sentence_count: 1,
            question_words: vec!["what".to_string()],
            technical_terms,
            has_code,
            has_numbers: false,
            has_dates: false,
            language: "en".to_string(),
        };

        QueryUnderstanding {
            original_query: "test query".to_string(),
            query_type,
            intent,
            complexity,
            domain,
            entities: vec![],
            features,
            keywords: vec!["test".to_string(), "query".to_string()],
            confidence: 0.8,
        }
    }

    /// Runs a prediction with the default service for the given fixture.
    fn predict(query_type: QueryType, intent: QueryIntent, complexity: QueryComplexity) -> MLPredictionResult {
        let understanding = create_test_understanding(query_type, intent, complexity);
        MLPredictionService::default()
            .predict_strategy(&understanding)
            .unwrap()
    }

    #[test]
    fn test_technical_query_prediction() {
        let outcome = predict(QueryType::Technical, QueryIntent::Code, QueryComplexity::Medium);

        assert_eq!(outcome.prediction.strategy, RetrievalStrategy::BM25Only);
        assert!(outcome.prediction.confidence > 0.5);
    }

    #[test]
    fn test_complex_query_prediction() {
        let outcome = predict(QueryType::Analytical, QueryIntent::Explain, QueryComplexity::VeryComplex);

        // Very complex queries should land on multi-step or hybrid retrieval.
        assert!(matches!(
            outcome.prediction.strategy,
            RetrievalStrategy::MultiStep | RetrievalStrategy::Hybrid
        ));
    }

    #[test]
    fn test_feature_extraction() {
        let understanding =
            create_test_understanding(QueryType::Technical, QueryIntent::Code, QueryComplexity::Complex);

        let extracted = MLPredictionService::default().extract_features(&understanding);

        assert!(extracted.has_code > 0.0);
        assert!(extracted.query_complexity_score > 0.5);
        assert!(extracted.technical_term_count > 0.0);
    }

    #[test]
    fn test_explanation_generation() {
        let outcome = predict(QueryType::Technical, QueryIntent::Code, QueryComplexity::Medium);

        assert!(!outcome.explanation.is_empty());
        assert!(outcome.explanation.contains("strategy"));
    }

    #[test]
    fn test_feature_importance() {
        let outcome = predict(QueryType::Technical, QueryIntent::Code, QueryComplexity::Medium);

        assert!(!outcome.feature_importance.is_empty());
        assert!(outcome.feature_importance.contains_key("complexity"));
    }
}