1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use crate::error::Result;
4use crate::query_understanding::{QueryUnderstanding, QueryType, QueryIntent, QueryComplexity};
5
/// Default per-feature weights, stored as `(feature key, weight)` pairs.
///
/// `MLPredictionConfig::default` copies these pairs into its
/// `feature_weights` map.
static FEATURE_WEIGHTS: &[(&str, f32)] = &[
    ("query_length", 0.15),
    ("complexity", 0.25),
    ("technical_terms", 0.20),
    ("domain_specificity", 0.15),
    ("semantic_complexity", 0.25),
];
14
15static STRATEGY_WEIGHTS: &[(RetrievalStrategy, f32)] = &[
17 (RetrievalStrategy::BM25Only, 1.0),
18 (RetrievalStrategy::VectorOnly, 1.0),
19 (RetrievalStrategy::Hybrid, 1.2),
20 (RetrievalStrategy::HydeEnhanced, 0.8),
21 (RetrievalStrategy::MultiStep, 0.9),
22 (RetrievalStrategy::Adaptive, 1.1),
23];
24
25struct FeatureScoringRule {
27 condition: fn(&MLFeatures) -> bool,
28 strategy: RetrievalStrategy,
29 score: f32,
30}
31
32static FEATURE_SCORING_RULES: &[FeatureScoringRule] = &[
33 FeatureScoringRule {
34 condition: |f| f.semantic_complexity > 0.7,
35 strategy: RetrievalStrategy::VectorOnly,
36 score: 0.3,
37 },
38 FeatureScoringRule {
39 condition: |f| f.semantic_complexity > 0.7,
40 strategy: RetrievalStrategy::HydeEnhanced,
41 score: 0.2,
42 },
43 FeatureScoringRule {
44 condition: |f| f.technical_term_count > 0.5 || f.has_code > 0.5,
45 strategy: RetrievalStrategy::BM25Only,
46 score: 0.3,
47 },
48 FeatureScoringRule {
49 condition: |f| f.query_complexity_score > 0.6,
50 strategy: RetrievalStrategy::Hybrid,
51 score: 0.4,
52 },
53 FeatureScoringRule {
54 condition: |f| f.query_complexity_score > 0.6,
55 strategy: RetrievalStrategy::MultiStep,
56 score: 0.2,
57 },
58 FeatureScoringRule {
59 condition: |f| f.domain_specificity < 0.5,
60 strategy: RetrievalStrategy::Adaptive,
61 score: 0.2,
62 },
63];
64
/// Names of the `MLFeatures` fields, in field declaration order.
///
/// `MLFeatures::get_feature_names` returns an owned copy of this list.
static FEATURE_NAMES: &[&str] = &[
    "query_length",
    "query_complexity_score",
    "technical_term_count",
    "question_word_presence",
    "domain_specificity",
    "has_code",
    "has_numbers",
    "intent_score",
    "semantic_complexity",
];
77
78static STRATEGY_NAMES: &[(RetrievalStrategy, &str)] = &[
80 (RetrievalStrategy::BM25Only, "BM25-only"),
81 (RetrievalStrategy::VectorOnly, "Vector-only"),
82 (RetrievalStrategy::Hybrid, "Hybrid"),
83 (RetrievalStrategy::HydeEnhanced, "HyDE-enhanced"),
84 (RetrievalStrategy::MultiStep, "Multi-step"),
85 (RetrievalStrategy::Adaptive, "Adaptive"),
86];
87
88static COMPLEXITY_SCORES: &[(QueryComplexity, f32)] = &[
90 (QueryComplexity::Simple, 0.2),
91 (QueryComplexity::Medium, 0.5),
92 (QueryComplexity::Complex, 0.8),
93 (QueryComplexity::VeryComplex, 1.0),
94];
95
96static INTENT_SCORES: &[(QueryIntent, f32)] = &[
98 (QueryIntent::Search, 0.8),
99 (QueryIntent::Explain, 0.6),
100 (QueryIntent::Code, 1.0),
101 (QueryIntent::Debug, 0.9),
102 (QueryIntent::Compare, 0.7),
103 (QueryIntent::Guide, 0.5),
104 (QueryIntent::Assist, 0.4),
105 (QueryIntent::Chat, 0.2),
106];
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct RetrievalStrategyPrediction {
111 pub strategy: RetrievalStrategy,
112 pub confidence: f32,
113 pub features_used: Vec<String>,
114 pub alternatives: Vec<(RetrievalStrategy, f32)>,
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
119pub enum RetrievalStrategy {
120 BM25Only,
122 VectorOnly,
124 Hybrid,
126 HydeEnhanced,
128 MultiStep,
130 Adaptive,
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct MLFeatures {
137 pub query_length: f32,
138 pub query_complexity_score: f32,
139 pub technical_term_count: f32,
140 pub question_word_presence: f32,
141 pub domain_specificity: f32,
142 pub has_code: f32,
143 pub has_numbers: f32,
144 pub intent_score: f32,
145 pub semantic_complexity: f32,
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct MLPredictionResult {
151 pub prediction: RetrievalStrategyPrediction,
152 pub explanation: String,
153 pub feature_importance: HashMap<String, f32>,
154 pub model_confidence: f32,
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct MLPredictionConfig {
160 pub enable_hybrid_fallback: bool,
161 pub confidence_threshold: f32,
162 pub feature_weights: HashMap<String, f32>,
163 pub strategy_weights: HashMap<RetrievalStrategy, f32>,
164}
165
166impl Default for MLPredictionConfig {
167 fn default() -> Self {
168 let feature_weights = FEATURE_WEIGHTS
169 .iter()
170 .map(|(k, v)| (k.to_string(), *v))
171 .collect();
172
173 let strategy_weights = STRATEGY_WEIGHTS
174 .iter()
175 .map(|(k, v)| (k.clone(), *v))
176 .collect();
177
178 Self {
179 enable_hybrid_fallback: true,
180 confidence_threshold: 0.7,
181 feature_weights,
182 strategy_weights,
183 }
184 }
185}
186
187pub struct MLPredictionService {
189 _config: MLPredictionConfig,
190 strategy_rules: Vec<Box<dyn StrategyRule>>,
191}
192
193impl MLPredictionService {
194 pub fn new(config: MLPredictionConfig) -> Self {
195 let mut service = Self {
196 _config: config,
197 strategy_rules: Vec::new(),
198 };
199
200 service.initialize_rules();
201 service
202 }
203
204 pub fn predict_strategy(&self, understanding: &QueryUnderstanding) -> Result<MLPredictionResult> {
206 let features = self.extract_features(understanding);
207 let (strategy_scores, explanations) = self.collect_strategy_scores(understanding, &features);
208 let prediction = self.create_prediction_from_scores(strategy_scores, &features);
209 let explanation = self.generate_explanation(&prediction, understanding, &explanations);
210 let feature_importance = self.calculate_feature_importance(&features);
211 let confidence = prediction.confidence;
212
213 Ok(MLPredictionResult {
214 prediction,
215 explanation,
216 feature_importance,
217 model_confidence: confidence,
218 })
219 }
220
221 fn collect_strategy_scores(
223 &self,
224 understanding: &QueryUnderstanding,
225 features: &MLFeatures
226 ) -> (HashMap<RetrievalStrategy, f32>, Vec<String>) {
227 let mut strategy_scores: HashMap<RetrievalStrategy, f32> = HashMap::new();
228 let mut explanations = Vec::new();
229
230 for rule in &self.strategy_rules {
232 if let Some(prediction) = rule.evaluate(understanding, features) {
233 *strategy_scores.entry(prediction.strategy.clone()).or_insert(0.0) += prediction.confidence;
234 explanations.push(prediction.explanation);
235 }
236 }
237
238 self.apply_feature_scoring(features, &mut strategy_scores);
240
241 (strategy_scores, explanations)
242 }
243
244 fn create_prediction_from_scores(
246 &self,
247 strategy_scores: HashMap<RetrievalStrategy, f32>,
248 features: &MLFeatures
249 ) -> RetrievalStrategyPrediction {
250 let (best_strategy, best_score) = self.select_best_strategy(&strategy_scores);
251 let total_score: f32 = strategy_scores.values().sum();
252 let alternatives = self.create_alternatives(strategy_scores, &best_strategy, total_score);
253
254 RetrievalStrategyPrediction {
255 strategy: best_strategy,
256 confidence: (best_score / total_score).min(1.0),
257 features_used: features.get_feature_names(),
258 alternatives,
259 }
260 }
261
262 fn select_best_strategy(&self, strategy_scores: &HashMap<RetrievalStrategy, f32>) -> (RetrievalStrategy, f32) {
264 strategy_scores
265 .iter()
266 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
267 .map(|(s, score)| (s.clone(), *score))
268 .unwrap_or((RetrievalStrategy::Hybrid, 0.5))
269 }
270
271 fn create_alternatives(
273 &self,
274 strategy_scores: HashMap<RetrievalStrategy, f32>,
275 best_strategy: &RetrievalStrategy,
276 total_score: f32
277 ) -> Vec<(RetrievalStrategy, f32)> {
278 let mut alternatives: Vec<(RetrievalStrategy, f32)> = strategy_scores
279 .into_iter()
280 .filter(|(s, _)| s != best_strategy)
281 .map(|(s, score)| (s, score / total_score))
282 .collect();
283
284 alternatives.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
285 alternatives
286 }
287
288 fn extract_features(&self, understanding: &QueryUnderstanding) -> MLFeatures {
290 let query_length = (understanding.original_query.len() as f32 / 100.0).min(2.0);
291
292 let query_complexity_score = COMPLEXITY_SCORES
293 .iter()
294 .find(|(complexity, _)| *complexity == understanding.complexity)
295 .map(|(_, score)| *score)
296 .unwrap_or(0.5);
297
298 let technical_term_count = (understanding.features.technical_terms.len() as f32 / 10.0).min(1.0);
299
300 let question_word_presence = if understanding.features.question_words.is_empty() {
301 0.0
302 } else {
303 (understanding.features.question_words.len() as f32 / 5.0).min(1.0)
304 };
305
306 let domain_specificity = understanding.domain.confidence;
307
308 let has_code = if understanding.features.has_code { 1.0 } else { 0.0 };
309 let has_numbers = if understanding.features.has_numbers { 1.0 } else { 0.0 };
310
311 let intent_score = INTENT_SCORES
312 .iter()
313 .find(|(intent, _)| *intent == understanding.intent)
314 .map(|(_, score)| *score)
315 .unwrap_or(0.5);
316
317 let semantic_complexity = self.calculate_semantic_complexity(understanding);
318
319 MLFeatures {
320 query_length,
321 query_complexity_score,
322 technical_term_count,
323 question_word_presence,
324 domain_specificity,
325 has_code,
326 has_numbers,
327 intent_score,
328 semantic_complexity,
329 }
330 }
331
332 fn apply_feature_scoring(&self, features: &MLFeatures, strategy_scores: &mut HashMap<RetrievalStrategy, f32>) {
334 for rule in FEATURE_SCORING_RULES {
335 if (rule.condition)(features) {
336 *strategy_scores.entry(rule.strategy.clone()).or_insert(0.0) += rule.score;
337 }
338 }
339 }
340
341 fn calculate_semantic_complexity(&self, understanding: &QueryUnderstanding) -> f32 {
343 let mut complexity = 0.0;
344
345 if understanding.query_type == QueryType::Analytical ||
347 understanding.query_type == QueryType::Subjective {
348 complexity += 0.3;
349 }
350
351 complexity += (understanding.entities.len() as f32 / 10.0).min(0.3);
353
354 if understanding.features.word_count > 10 && understanding.features.technical_terms.len() < 3 {
356 complexity += 0.4;
357 }
358
359 complexity.min(1.0)
360 }
361
362 fn generate_explanation(
364 &self,
365 prediction: &RetrievalStrategyPrediction,
366 understanding: &QueryUnderstanding,
367 _rule_explanations: &[String],
368 ) -> String {
369 let mut explanation = format!(
370 "Selected {} strategy with {:.1}% confidence. ",
371 strategy_to_string(&prediction.strategy),
372 prediction.confidence * 100.0
373 );
374
375 match prediction.strategy {
377 RetrievalStrategy::BM25Only => {
378 explanation.push_str("This strategy was chosen because the query contains specific technical terms or keywords that benefit from exact matching.");
379 }
380 RetrievalStrategy::VectorOnly => {
381 explanation.push_str("This strategy was chosen because the query is conceptual and would benefit from semantic similarity matching.");
382 }
383 RetrievalStrategy::Hybrid => {
384 explanation.push_str("This strategy combines both keyword matching and semantic similarity for comprehensive results.");
385 }
386 RetrievalStrategy::HydeEnhanced => {
387 explanation.push_str("This strategy uses hypothetical document generation to improve semantic matching for complex queries.");
388 }
389 RetrievalStrategy::MultiStep => {
390 explanation.push_str("This strategy uses multiple retrieval phases with reranking for high-precision results.");
391 }
392 RetrievalStrategy::Adaptive => {
393 explanation.push_str("This strategy dynamically adjusts based on initial results quality.");
394 }
395 }
396
397 if understanding.features.has_code {
399 explanation.push_str(" Code-related queries detected.");
400 }
401 if understanding.complexity == QueryComplexity::VeryComplex {
402 explanation.push_str(" High query complexity requires sophisticated retrieval.");
403 }
404
405 explanation
406 }
407
408 fn calculate_feature_importance(&self, features: &MLFeatures) -> HashMap<String, f32> {
410 let mut importance = HashMap::new();
411
412 importance.insert("query_length".to_string(), features.query_length * 0.15);
413 importance.insert("complexity".to_string(), features.query_complexity_score * 0.25);
414 importance.insert("technical_terms".to_string(), features.technical_term_count * 0.20);
415 importance.insert("domain_specificity".to_string(), features.domain_specificity * 0.15);
416 importance.insert("semantic_complexity".to_string(), features.semantic_complexity * 0.25);
417
418 importance
419 }
420
421 fn initialize_rules(&mut self) {
423 self.strategy_rules.push(Box::new(TechnicalQueryRule));
424 self.strategy_rules.push(Box::new(SemanticQueryRule));
425 self.strategy_rules.push(Box::new(ComplexQueryRule));
426 self.strategy_rules.push(Box::new(CodeQueryRule));
427 self.strategy_rules.push(Box::new(ComparisonQueryRule));
428 }
429}
430
431impl Default for MLPredictionService {
432 fn default() -> Self {
433 Self::new(MLPredictionConfig::default())
434 }
435}
436
437trait StrategyRule: Send + Sync {
439 fn evaluate(&self, understanding: &QueryUnderstanding, features: &MLFeatures) -> Option<RulePrediction>;
440}
441
442struct RulePrediction {
444 strategy: RetrievalStrategy,
445 confidence: f32,
446 explanation: String,
447}
448
449struct TechnicalQueryRule;
451
452impl StrategyRule for TechnicalQueryRule {
453 fn evaluate(&self, _understanding: &QueryUnderstanding, features: &MLFeatures) -> Option<RulePrediction> {
454 if features.technical_term_count > 0.6 || features.has_code > 0.5 {
455 Some(RulePrediction {
456 strategy: RetrievalStrategy::BM25Only,
457 confidence: 0.8,
458 explanation: "Technical terms favor keyword-based search".to_string(),
459 })
460 } else {
461 None
462 }
463 }
464}
465
466struct SemanticQueryRule;
468
469impl StrategyRule for SemanticQueryRule {
470 fn evaluate(&self, _understanding: &QueryUnderstanding, features: &MLFeatures) -> Option<RulePrediction> {
471 if features.semantic_complexity > 0.7 && features.technical_term_count < 0.3 {
472 Some(RulePrediction {
473 strategy: RetrievalStrategy::VectorOnly,
474 confidence: 0.7,
475 explanation: "High semantic complexity favors vector search".to_string(),
476 })
477 } else {
478 None
479 }
480 }
481}
482
483struct ComplexQueryRule;
485
486impl StrategyRule for ComplexQueryRule {
487 fn evaluate(&self, understanding: &QueryUnderstanding, _features: &MLFeatures) -> Option<RulePrediction> {
488 if understanding.complexity == QueryComplexity::VeryComplex {
489 Some(RulePrediction {
490 strategy: RetrievalStrategy::MultiStep,
491 confidence: 0.6,
492 explanation: "Very complex queries benefit from multi-step retrieval".to_string(),
493 })
494 } else if understanding.complexity == QueryComplexity::Complex {
495 Some(RulePrediction {
496 strategy: RetrievalStrategy::Hybrid,
497 confidence: 0.7,
498 explanation: "Complex queries benefit from hybrid approach".to_string(),
499 })
500 } else {
501 None
502 }
503 }
504}
505
506struct CodeQueryRule;
508
509impl StrategyRule for CodeQueryRule {
510 fn evaluate(&self, understanding: &QueryUnderstanding, _features: &MLFeatures) -> Option<RulePrediction> {
511 if understanding.query_type == QueryType::Technical && understanding.intent == QueryIntent::Code {
512 Some(RulePrediction {
513 strategy: RetrievalStrategy::BM25Only,
514 confidence: 0.9,
515 explanation: "Code queries require exact matching".to_string(),
516 })
517 } else {
518 None
519 }
520 }
521}
522
523struct ComparisonQueryRule;
525
526impl StrategyRule for ComparisonQueryRule {
527 fn evaluate(&self, understanding: &QueryUnderstanding, _features: &MLFeatures) -> Option<RulePrediction> {
528 if understanding.query_type == QueryType::Comparative {
529 Some(RulePrediction {
530 strategy: RetrievalStrategy::HydeEnhanced,
531 confidence: 0.6,
532 explanation: "Comparison queries benefit from hypothetical document expansion".to_string(),
533 })
534 } else {
535 None
536 }
537 }
538}
539
540impl MLFeatures {
541 fn get_feature_names(&self) -> Vec<String> {
542 FEATURE_NAMES.iter().map(|s| s.to_string()).collect()
543 }
544}
545
546fn strategy_to_string(strategy: &RetrievalStrategy) -> &'static str {
547 STRATEGY_NAMES
548 .iter()
549 .find(|(s, _)| s == strategy)
550 .map(|(_, name)| *name)
551 .unwrap_or("Unknown")
552}
553
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query_understanding::{QueryDomain, QueryFeatures};

    /// Builds a small fixture for the query "test query".
    ///
    /// Technical queries get two technical terms and `has_code = true`,
    /// analytical queries get no technical terms, and every other query type
    /// gets a single technical term.
    fn create_test_understanding(query_type: QueryType, intent: QueryIntent, complexity: QueryComplexity) -> QueryUnderstanding {
        let term_words: &[&str] = match query_type {
            QueryType::Technical => &["code", "api"],
            QueryType::Analytical => &[],
            _ => &["term"],
        };
        let has_code = matches!(query_type, QueryType::Technical);

        let features = QueryFeatures {
            word_count: 5,
            sentence_count: 1,
            question_words: vec!["what".to_string()],
            technical_terms: term_words.iter().map(|word| word.to_string()).collect(),
            has_code,
            has_numbers: false,
            has_dates: false,
            language: "en".to_string(),
        };

        let domain = QueryDomain {
            primary_domain: "programming".to_string(),
            secondary_domains: vec![],
            confidence: 0.8,
        };

        QueryUnderstanding {
            original_query: "test query".to_string(),
            query_type,
            intent,
            complexity,
            domain,
            entities: vec![],
            features,
            keywords: vec!["test".to_string(), "query".to_string()],
            confidence: 0.8,
        }
    }

    /// Runs a prediction with the default service on a fixture query.
    fn predict(query_type: QueryType, intent: QueryIntent, complexity: QueryComplexity) -> MLPredictionResult {
        let service = MLPredictionService::default();
        let understanding = create_test_understanding(query_type, intent, complexity);
        service.predict_strategy(&understanding).unwrap()
    }

    #[test]
    fn test_technical_query_prediction() {
        let result = predict(QueryType::Technical, QueryIntent::Code, QueryComplexity::Medium);

        assert_eq!(result.prediction.strategy, RetrievalStrategy::BM25Only);
        assert!(result.prediction.confidence > 0.5);
    }

    #[test]
    fn test_complex_query_prediction() {
        let result = predict(QueryType::Analytical, QueryIntent::Explain, QueryComplexity::VeryComplex);

        assert!(matches!(
            result.prediction.strategy,
            RetrievalStrategy::MultiStep | RetrievalStrategy::Hybrid
        ));
    }

    #[test]
    fn test_feature_extraction() {
        let service = MLPredictionService::default();
        let understanding = create_test_understanding(
            QueryType::Technical,
            QueryIntent::Code,
            QueryComplexity::Complex,
        );

        let features = service.extract_features(&understanding);

        assert!(features.has_code > 0.0);
        assert!(features.query_complexity_score > 0.5);
        assert!(features.technical_term_count > 0.0);
    }

    #[test]
    fn test_explanation_generation() {
        let result = predict(QueryType::Technical, QueryIntent::Code, QueryComplexity::Medium);

        assert!(!result.explanation.is_empty());
        assert!(result.explanation.contains("strategy"));
    }

    #[test]
    fn test_feature_importance() {
        let result = predict(QueryType::Technical, QueryIntent::Code, QueryComplexity::Medium);

        assert!(!result.feature_importance.is_empty());
        assert!(result.feature_importance.contains_key("complexity"));
    }
}