rexis_rag/reranking/
learning_to_rank.rs

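//! # Learning-to-Rank Reranking
//!
//! Machine learning models specifically designed for ranking tasks.
//! Supports various LTR algorithms including RankNet, LambdaMART, and ListNet.
//!
//! Note: the models bundled in this module are deterministic simulations and
//! placeholders rather than trained implementations; they illustrate the
//! reranking pipeline (feature extraction, normalization, scoring).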
5
6use crate::{RragResult, SearchResult};
7use std::collections::HashMap;
8
/// Learning-to-rank reranker.
///
/// Pairs a set of feature extractors with an [`LTRModel`]: features are
/// extracted for each query/document pair and the model turns them into scores.
pub struct LearningToRankReranker {
    /// Configuration the reranker was built from
    config: LTRConfig,

    /// Scoring model, chosen from `config.model_type` at construction
    model: Box<dyn LTRModel>,

    /// Feature extractors, built from `config.feature_config.enabled_features`
    feature_extractors: Vec<Box<dyn FeatureExtractor>>,

    /// Feature cache intended for performance.
    ///
    /// NOTE(review): only initialized (empty) by `new`; nothing in this file reads
    /// or writes it afterwards, and `rerank` takes `&self`, so it cannot fill it.
    feature_cache: HashMap<String, Vec<f32>>,
}
23
/// Configuration for [`LearningToRankReranker`].
#[derive(Debug, Clone)]
pub struct LTRConfig {
    /// Which model implementation to instantiate
    pub model_type: LTRModelType,

    /// Feature extraction configuration (enabled features, normalization, selection)
    pub feature_config: FeatureExtractionConfig,

    /// Numeric model parameters keyed by name; the default configuration sets
    /// `"learning_rate"`, `"num_trees"` and `"max_depth"`
    pub model_parameters: HashMap<String, f32>,

    /// Optional training configuration (`None` in the default configuration)
    pub training_config: Option<TrainingConfig>,

    /// Enable feature caching
    pub enable_feature_caching: bool,

    /// Batch size for model prediction
    pub batch_size: usize,
}
45
46impl Default for LTRConfig {
47    fn default() -> Self {
48        let mut model_parameters = HashMap::new();
49        model_parameters.insert("learning_rate".to_string(), 0.01);
50        model_parameters.insert("num_trees".to_string(), 100.0);
51        model_parameters.insert("max_depth".to_string(), 6.0);
52
53        Self {
54            model_type: LTRModelType::SimulatedLambdaMART,
55            feature_config: FeatureExtractionConfig::default(),
56            model_parameters,
57            training_config: None,
58            enable_feature_caching: true,
59            batch_size: 32,
60        }
61    }
62}
63
/// Types of LTR models.
///
/// Derives `Eq` and `Hash` (like `FeatureType`) so model types can be used as
/// map keys or set members, e.g. to keep per-model metrics.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LTRModelType {
    /// RankNet neural network model
    RankNet,
    /// LambdaMART gradient boosting model
    LambdaMART,
    /// ListNet list-wise learning model
    ListNet,
    /// RankSVM support vector machine model
    RankSVM,
    /// Custom model implementation, identified by name
    Custom(String),
    /// Simulated LambdaMART for demonstration
    SimulatedLambdaMART,
}
80
/// Configuration for feature extraction.
///
/// Controls which feature families are extracted for each query/document pair
/// and how the resulting values are post-processed before scoring.
#[derive(Debug, Clone)]
pub struct FeatureExtractionConfig {
    /// Enabled feature types; the reranker only builds extractors for the
    /// types it has an implementation for
    pub enabled_features: Vec<FeatureType>,

    /// Feature normalization method
    pub normalization: FeatureNormalization,

    /// Maximum number of features
    pub max_features: usize,

    /// Feature selection method
    pub feature_selection: FeatureSelection,
}
96
97impl Default for FeatureExtractionConfig {
98    fn default() -> Self {
99        Self {
100            enabled_features: vec![
101                FeatureType::QueryDocumentSimilarity,
102                FeatureType::DocumentLength,
103                FeatureType::QueryTermFrequency,
104                FeatureType::DocumentTermFrequency,
105                FeatureType::InverseLinkFrequency,
106            ],
107            normalization: FeatureNormalization::ZScore,
108            max_features: 100,
109            feature_selection: FeatureSelection::None,
110        }
111    }
112}
113
/// Training configuration for LTR models.
///
/// Bundles optimizer, regularization, early-stopping and cross-validation
/// settings for models that support training.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Number of training iterations
    pub num_iterations: usize,

    /// Learning rate
    pub learning_rate: f32,

    /// Regularization parameters
    pub regularization: RegularizationConfig,

    /// Early stopping criteria
    pub early_stopping: EarlyStoppingConfig,

    /// Number of cross-validation folds
    pub cv_folds: usize,
}
132
/// Regularization configuration used during training.
#[derive(Debug, Clone)]
pub struct RegularizationConfig {
    /// L1 regularization strength
    pub l1_weight: f32,

    /// L2 regularization strength
    pub l2_weight: f32,

    /// Dropout rate (only meaningful for neural models)
    pub dropout_rate: f32,
}
145
/// Early stopping configuration used during training.
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
    /// Name of the metric to monitor
    pub metric: String,

    /// Patience: number of iterations without improvement before stopping
    pub patience: usize,

    /// Minimum change in the monitored metric that counts as an improvement
    pub min_delta: f32,
}
158
/// Types of ranking features.
///
/// Only some types have extractors in this module (similarity, document length,
/// query term frequency); other enabled types are currently ignored by the reranker.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum FeatureType {
    /// Query-document similarity scores
    QueryDocumentSimilarity,
    /// Document length features
    DocumentLength,
    /// Query term frequency in document
    QueryTermFrequency,
    /// Document term frequency
    DocumentTermFrequency,
    /// Inverse document frequency.
    ///
    /// NOTE(review): the variant name says "Link" but the intended feature is
    /// inverse *document* frequency; the name is kept for compatibility.
    InverseLinkFrequency,
    /// BM25 score
    BM25Score,
    /// PageRank or authority score
    AuthorityScore,
    /// Click-through rate
    ClickThroughRate,
    /// Query-document exact matches
    ExactMatches,
    /// Positional features
    PositionalFeatures,
    /// Temporal features
    TemporalFeatures,
    /// Custom feature, identified by name
    Custom(String),
}
187
/// Feature normalization methods.
///
/// A small fieldless enum, so it is `Copy`. It derives `PartialEq`/`Eq` so
/// configurations can be compared, consistent with the other enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeatureNormalization {
    /// Min-max normalization: `(x - min) / (max - min)`
    MinMax,
    /// Z-score normalization: `(x - mean) / std_dev`
    ZScore,
    /// Quantile normalization (not implemented by the reranker; values pass through)
    Quantile,
    /// No normalization
    None,
}
200
/// Feature selection methods.
///
/// Every payload is `Copy`, so the enum is too. It derives `PartialEq` for
/// config comparisons; `Eq` is impossible because of the `f32` threshold.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FeatureSelection {
    /// No feature selection
    None,
    /// Keep `k` features (the reranker currently keeps the first `k`, not the most important)
    TopK(usize),
    /// Select features by correlation threshold (not implemented by the reranker)
    Correlation(f32),
    /// Recursive feature elimination (not implemented by the reranker)
    RFE,
}
213
/// A single ranking feature extracted from a query-document pair.
#[derive(Debug, Clone)]
pub struct RankingFeature {
    /// Feature family this value belongs to
    pub feature_type: FeatureType,

    /// Feature name (e.g. `cosine_similarity`, `document_length`)
    pub name: String,

    /// Feature value
    pub value: f32,

    /// Importance assigned by the extractor, if any
    pub importance: Option<f32>,

    /// Feature metadata
    pub metadata: FeatureMetadata,
}
232
/// Metadata describing how a single feature was extracted.
#[derive(Debug, Clone)]
pub struct FeatureMetadata {
    /// Extraction method label (e.g. `word_count`)
    pub extraction_method: String,

    /// Extraction time in milliseconds (the built-in extractors report fixed nominal values)
    pub extraction_time_ms: u64,

    /// Confidence in feature quality
    pub confidence: f32,

    /// Additional numeric properties
    pub properties: HashMap<String, f32>,
}
248
/// Features extracted for a query-document pair, with identifiers and an
/// optional ground-truth label.
#[derive(Debug, Clone)]
pub struct LTRFeatures {
    /// Query identifier
    pub query_id: String,

    /// Document identifier
    pub document_id: String,

    /// Feature vector
    pub features: Vec<f32>,

    /// Feature names (for interpretability); presumably parallel to `features`
    pub feature_names: Vec<String>,

    /// Ground truth relevance (for training)
    pub relevance: Option<f32>,

    /// Feature extraction metadata
    pub metadata: LTRFeaturesMetadata,
}
270
/// Metadata for an [`LTRFeatures`] record.
#[derive(Debug, Clone)]
pub struct LTRFeaturesMetadata {
    /// Total extraction time in milliseconds
    pub extraction_time_ms: u64,

    /// Number of features extracted
    pub num_features: usize,

    /// Feature quality score
    pub quality_score: f32,

    /// Warnings raised during extraction
    pub warnings: Vec<String>,
}
286
287/// Trait for LTR models
288pub trait LTRModel: Send + Sync {
289    /// Predict ranking scores for a batch of feature vectors
290    fn predict(&self, features: &[Vec<f32>]) -> RragResult<Vec<f32>>;
291
292    /// Predict ranking scores for a single feature vector
293    fn predict_single(&self, features: &[f32]) -> RragResult<f32> {
294        let batch_result = self.predict(&[features.to_vec()])?;
295        Ok(batch_result.into_iter().next().unwrap_or(0.0))
296    }
297
298    /// Train the model (if training is supported)
299    fn train(&mut self, training_data: &[LTRTrainingExample]) -> RragResult<TrainingResult> {
300        let _ = training_data; // Suppress unused parameter warning
301        Err(crate::RragError::validation(
302            "training",
303            "Training not implemented for this model",
304            "",
305        ))
306    }
307
308    /// Get model information
309    fn get_model_info(&self) -> LTRModelInfo;
310
311    /// Get feature importance if supported
312    fn get_feature_importance(&self) -> Option<Vec<f32>> {
313        None
314    }
315}
316
/// A labelled training example for LTR models.
#[derive(Debug, Clone)]
pub struct LTRTrainingExample {
    /// Query identifier (examples sharing it presumably form one ranked list)
    pub query_id: String,

    /// Document identifier
    pub document_id: String,

    /// Feature vector
    pub features: Vec<f32>,

    /// Graded relevance label (typically 0-4)
    pub relevance: f32,

    /// Training weight
    pub weight: f32,
}
335
/// Summary of a training run.
#[derive(Debug, Clone)]
pub struct TrainingResult {
    /// Final training loss
    pub final_loss: f32,

    /// Validation metrics keyed by metric name
    pub validation_metrics: HashMap<String, f32>,

    /// Training time in milliseconds
    pub training_time_ms: u64,

    /// Number of iterations completed
    pub iterations_completed: usize,

    /// Whether early stopping was triggered
    pub early_stopped: bool,
}
354
/// Descriptive information about an LTR model.
#[derive(Debug, Clone)]
pub struct LTRModelInfo {
    /// Model name
    pub name: String,

    /// Model version
    pub version: String,

    /// Number of features expected (the built-in models report 0; the simulated
    /// LambdaMART marks this as dynamic)
    pub num_features: usize,

    /// Model parameters
    pub parameters: HashMap<String, f32>,

    /// Whether the model is considered trained
    pub is_trained: bool,

    /// Model performance metrics, if known
    pub performance_metrics: Option<HashMap<String, f32>>,
}
376
377/// Trait for feature extractors
378pub trait FeatureExtractor: Send + Sync {
379    /// Extract features for a query-document pair
380    fn extract_features(
381        &self,
382        _query: &str,
383        document: &SearchResult,
384        context: &FeatureExtractionContext,
385    ) -> RragResult<Vec<RankingFeature>>;
386
387    /// Get supported feature types
388    fn supported_features(&self) -> Vec<FeatureType>;
389
390    /// Get extractor configuration
391    fn get_config(&self) -> FeatureExtractorConfig;
392}
393
/// Shared context passed to every extractor while scoring one result set.
#[derive(Debug, Clone)]
pub struct FeatureExtractionContext {
    /// All documents in the result set (for relative features)
    pub all_documents: Vec<SearchResult>,

    /// Query statistics
    pub query_stats: QueryStats,

    /// Statistics over the document set (the reranker computes them over the result set)
    pub collection_stats: CollectionStats,

    /// User context for personalization (the reranker currently passes `None`)
    pub user_context: Option<UserContext>,
}
409
/// Statistics about the query.
///
/// The reranker builds these by splitting the query on whitespace and lowercasing each term.
#[derive(Debug, Clone)]
pub struct QueryStats {
    /// Query length in terms
    pub length: usize,

    /// Query terms, in query order (duplicates kept)
    pub terms: Vec<String>,

    /// Query type/intent, if known
    pub query_type: Option<String>,

    /// Occurrence count of each term within the query
    pub term_frequencies: HashMap<String, usize>,
}
425
/// Statistics about the document collection.
///
/// In this module the "collection" is the result set being reranked.
#[derive(Debug, Clone)]
pub struct CollectionStats {
    /// Total number of documents
    pub total_documents: usize,

    /// Average document length, in whitespace-separated tokens
    pub avg_document_length: f32,

    /// Number of documents containing each (lowercased) term
    pub document_frequencies: HashMap<String, usize>,

    /// Number of distinct terms
    pub vocabulary_size: usize,
}
441
/// User context for personalized features.
#[derive(Debug, Clone)]
pub struct UserContext {
    /// User identifier
    pub user_id: String,

    /// User preferences, keyed by name
    pub preferences: HashMap<String, f32>,

    /// User interaction history
    pub interaction_history: Vec<String>,
}
454
/// Descriptive configuration reported by a feature extractor.
#[derive(Debug, Clone)]
pub struct FeatureExtractorConfig {
    /// Extractor name
    pub name: String,

    /// Supported features
    pub supported_features: Vec<FeatureType>,

    /// Performance characteristics
    pub performance: FeatureExtractorPerformance,
}
467
/// Performance characteristics reported by a feature extractor.
///
/// The built-in extractors report fixed nominal values, not measurements.
#[derive(Debug, Clone)]
pub struct FeatureExtractorPerformance {
    /// Average extraction time per document (ms)
    pub avg_extraction_time_ms: f32,

    /// Memory usage (MB)
    pub memory_usage_mb: f32,

    /// Feature quality score
    pub quality_score: f32,
}
480
481impl LearningToRankReranker {
482    /// Create a new LTR reranker
483    pub fn new(config: LTRConfig) -> Self {
484        let model = Self::create_model(&config.model_type, &config.model_parameters);
485        let feature_extractors = Self::create_feature_extractors(&config.feature_config);
486
487        Self {
488            config,
489            model,
490            feature_extractors,
491            feature_cache: HashMap::new(),
492        }
493    }
494
495    /// Create model based on configuration
496    fn create_model(
497        model_type: &LTRModelType,
498        parameters: &HashMap<String, f32>,
499    ) -> Box<dyn LTRModel> {
500        match model_type {
501            LTRModelType::SimulatedLambdaMART => {
502                Box::new(SimulatedLambdaMARTModel::new(parameters.clone()))
503            }
504            LTRModelType::LambdaMART => Box::new(SimulatedLambdaMARTModel::new(parameters.clone())),
505            LTRModelType::RankNet => Box::new(SimulatedRankNetModel::new()),
506            LTRModelType::ListNet => Box::new(SimulatedListNetModel::new()),
507            LTRModelType::RankSVM => Box::new(SimulatedRankSVMModel::new()),
508            LTRModelType::Custom(name) => Box::new(CustomLTRModel::new(name.clone())),
509        }
510    }
511
512    /// Create feature extractors based on configuration
513    fn create_feature_extractors(
514        config: &FeatureExtractionConfig,
515    ) -> Vec<Box<dyn FeatureExtractor>> {
516        let mut extractors: Vec<Box<dyn FeatureExtractor>> = Vec::new();
517
518        if config
519            .enabled_features
520            .contains(&FeatureType::QueryDocumentSimilarity)
521        {
522            extractors.push(Box::new(SimilarityFeatureExtractor::new()));
523        }
524
525        if config
526            .enabled_features
527            .contains(&FeatureType::DocumentLength)
528        {
529            extractors.push(Box::new(LengthFeatureExtractor::new()));
530        }
531
532        if config
533            .enabled_features
534            .contains(&FeatureType::QueryTermFrequency)
535        {
536            extractors.push(Box::new(TermFrequencyExtractor::new()));
537        }
538
539        extractors
540    }
541
542    /// Rerank search results using LTR model
543    pub async fn rerank(
544        &self,
545        query: &str,
546        results: &[SearchResult],
547    ) -> RragResult<HashMap<usize, f32>> {
548        // Create feature extraction context
549        let context = FeatureExtractionContext {
550            all_documents: results.to_vec(),
551            query_stats: self.compute_query_stats(query),
552            collection_stats: self.compute_collection_stats(results),
553            user_context: None,
554        };
555
556        // Extract features for all documents
557        let mut feature_vectors = Vec::new();
558
559        for document in results {
560            let features = self.extract_document_features(query, document, &context)?;
561            feature_vectors.push(features);
562        }
563
564        // Predict scores using LTR model
565        let scores = self.model.predict(&feature_vectors)?;
566
567        // Create result mapping
568        let mut score_map = HashMap::new();
569        for (idx, score) in scores.into_iter().enumerate() {
570            score_map.insert(idx, score);
571        }
572
573        Ok(score_map)
574    }
575
576    /// Extract features for a single document
577    fn extract_document_features(
578        &self,
579        query: &str,
580        document: &SearchResult,
581        context: &FeatureExtractionContext,
582    ) -> RragResult<Vec<f32>> {
583        let mut all_features = Vec::new();
584
585        // Extract features using all extractors
586        for extractor in &self.feature_extractors {
587            let features = extractor.extract_features(query, document, context)?;
588
589            for feature in features {
590                all_features.push(feature.value);
591            }
592        }
593
594        // Apply normalization if configured
595        let normalized_features = match self.config.feature_config.normalization {
596            FeatureNormalization::None => all_features,
597            _ => self.normalize_features(all_features)?,
598        };
599
600        // Apply feature selection if configured
601        let selected_features = match self.config.feature_config.feature_selection {
602            FeatureSelection::None => normalized_features,
603            _ => self.select_features(normalized_features)?,
604        };
605
606        Ok(selected_features)
607    }
608
609    /// Compute query statistics
610    fn compute_query_stats(&self, query: &str) -> QueryStats {
611        let terms: Vec<String> = query.split_whitespace().map(|s| s.to_lowercase()).collect();
612
613        let mut term_frequencies = HashMap::new();
614        for term in &terms {
615            *term_frequencies.entry(term.clone()).or_insert(0) += 1;
616        }
617
618        QueryStats {
619            length: terms.len(),
620            terms,
621            query_type: None, // Could be inferred
622            term_frequencies,
623        }
624    }
625
626    /// Compute collection statistics
627    fn compute_collection_stats(&self, documents: &[SearchResult]) -> CollectionStats {
628        let total_documents = documents.len();
629        let total_length: usize = documents
630            .iter()
631            .map(|d| d.content.split_whitespace().count())
632            .sum();
633        let avg_document_length = if total_documents > 0 {
634            total_length as f32 / total_documents as f32
635        } else {
636            0.0
637        };
638
639        // Compute document frequencies
640        let mut document_frequencies = HashMap::new();
641        let mut vocabulary = std::collections::HashSet::new();
642
643        for document in documents {
644            let terms: std::collections::HashSet<String> = document
645                .content
646                .split_whitespace()
647                .map(|s| s.to_lowercase())
648                .collect();
649
650            for term in &terms {
651                *document_frequencies.entry(term.clone()).or_insert(0) += 1;
652                vocabulary.insert(term.clone());
653            }
654        }
655
656        CollectionStats {
657            total_documents,
658            avg_document_length,
659            document_frequencies,
660            vocabulary_size: vocabulary.len(),
661        }
662    }
663
664    /// Normalize features
665    fn normalize_features(&self, features: Vec<f32>) -> RragResult<Vec<f32>> {
666        match self.config.feature_config.normalization {
667            FeatureNormalization::MinMax => {
668                let min_val = features.iter().fold(f32::INFINITY, |a, &b| a.min(b));
669                let max_val = features.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
670                let range = max_val - min_val;
671
672                if range == 0.0 {
673                    Ok(features) // No normalization needed
674                } else {
675                    Ok(features
676                        .into_iter()
677                        .map(|f| (f - min_val) / range)
678                        .collect())
679                }
680            }
681            FeatureNormalization::ZScore => {
682                let mean = features.iter().sum::<f32>() / features.len() as f32;
683                let variance = features.iter().map(|f| (f - mean).powi(2)).sum::<f32>()
684                    / features.len() as f32;
685                let std_dev = variance.sqrt();
686
687                if std_dev == 0.0 {
688                    Ok(features)
689                } else {
690                    Ok(features.into_iter().map(|f| (f - mean) / std_dev).collect())
691                }
692            }
693            _ => Ok(features), // Other normalizations not implemented
694        }
695    }
696
697    /// Select features based on configuration
698    fn select_features(&self, features: Vec<f32>) -> RragResult<Vec<f32>> {
699        match self.config.feature_config.feature_selection {
700            FeatureSelection::TopK(k) => {
701                // For simplicity, just take first k features
702                Ok(features.into_iter().take(k).collect())
703            }
704            _ => Ok(features), // Other selection methods not implemented
705        }
706    }
707}
708
709// Mock implementations of LTR models
710struct SimulatedLambdaMARTModel {
711    parameters: HashMap<String, f32>,
712    num_trees: usize,
713}
714
715impl SimulatedLambdaMARTModel {
716    fn new(parameters: HashMap<String, f32>) -> Self {
717        let num_trees = parameters.get("num_trees").copied().unwrap_or(100.0) as usize;
718        Self {
719            parameters,
720            num_trees,
721        }
722    }
723}
724
725impl LTRModel for SimulatedLambdaMARTModel {
726    fn predict(&self, features: &[Vec<f32>]) -> RragResult<Vec<f32>> {
727        let mut scores = Vec::new();
728
729        for feature_vector in features {
730            // Simulate LambdaMART prediction with ensemble of trees
731            let mut score = 0.0;
732
733            for tree_idx in 0..self.num_trees {
734                // Simulate tree prediction (very simplified)
735                let tree_score = feature_vector
736                    .iter()
737                    .enumerate()
738                    .map(|(i, &f)| f * (0.1 + 0.01 * (tree_idx + i) as f32).sin())
739                    .sum::<f32>()
740                    / feature_vector.len() as f32;
741
742                score += tree_score * 0.01; // Learning rate
743            }
744
745            // Apply sigmoid to get 0-1 score
746            scores.push(1.0 / (1.0 + (-score).exp()));
747        }
748
749        Ok(scores)
750    }
751
752    fn get_model_info(&self) -> LTRModelInfo {
753        LTRModelInfo {
754            name: "SimulatedLambdaMART".to_string(),
755            version: "1.0".to_string(),
756            num_features: 0, // Dynamic
757            parameters: self.parameters.clone(),
758            is_trained: true,
759            performance_metrics: None,
760        }
761    }
762
763    fn get_feature_importance(&self) -> Option<Vec<f32>> {
764        // Simulate feature importance scores
765        Some(vec![0.3, 0.25, 0.2, 0.15, 0.1]) // Top 5 features
766    }
767}
768
769// Placeholder implementations for other models
770macro_rules! impl_mock_ltr_model {
771    ($name:ident) => {
772        struct $name;
773
774        impl $name {
775            fn new() -> Self {
776                Self
777            }
778        }
779
780        impl LTRModel for $name {
781            fn predict(&self, features: &[Vec<f32>]) -> RragResult<Vec<f32>> {
782                Ok(features
783                    .iter()
784                    .map(|f| f.iter().sum::<f32>() / f.len() as f32)
785                    .map(|s| 1.0 / (1.0 + (-s).exp())) // Sigmoid
786                    .collect())
787            }
788
789            fn get_model_info(&self) -> LTRModelInfo {
790                LTRModelInfo {
791                    name: stringify!($name).to_string(),
792                    version: "1.0".to_string(),
793                    num_features: 0,
794                    parameters: HashMap::new(),
795                    is_trained: false,
796                    performance_metrics: None,
797                }
798            }
799        }
800    };
801}
802
803impl_mock_ltr_model!(SimulatedRankNetModel);
804impl_mock_ltr_model!(SimulatedListNetModel);
805impl_mock_ltr_model!(SimulatedRankSVMModel);
806
807struct CustomLTRModel {
808    name: String,
809}
810
811impl CustomLTRModel {
812    fn new(name: String) -> Self {
813        Self { name }
814    }
815}
816
817impl LTRModel for CustomLTRModel {
818    fn predict(&self, features: &[Vec<f32>]) -> RragResult<Vec<f32>> {
819        Ok(vec![0.5; features.len()]) // Neutral scores
820    }
821
822    fn get_model_info(&self) -> LTRModelInfo {
823        LTRModelInfo {
824            name: self.name.clone(),
825            version: "custom".to_string(),
826            num_features: 0,
827            parameters: HashMap::new(),
828            is_trained: false,
829            performance_metrics: None,
830        }
831    }
832}
833
834// Feature extractors
835struct SimilarityFeatureExtractor;
836
837impl SimilarityFeatureExtractor {
838    fn new() -> Self {
839        Self
840    }
841}
842
843impl FeatureExtractor for SimilarityFeatureExtractor {
844    fn extract_features(
845        &self,
846        _query: &str,
847        document: &SearchResult,
848        _context: &FeatureExtractionContext,
849    ) -> RragResult<Vec<RankingFeature>> {
850        let similarity = document.score; // Use existing similarity score
851
852        Ok(vec![RankingFeature {
853            feature_type: FeatureType::QueryDocumentSimilarity,
854            name: "cosine_similarity".to_string(),
855            value: similarity,
856            importance: Some(0.8),
857            metadata: FeatureMetadata {
858                extraction_method: "vector_similarity".to_string(),
859                extraction_time_ms: 1,
860                confidence: 0.9,
861                properties: HashMap::new(),
862            },
863        }])
864    }
865
866    fn supported_features(&self) -> Vec<FeatureType> {
867        vec![FeatureType::QueryDocumentSimilarity]
868    }
869
870    fn get_config(&self) -> FeatureExtractorConfig {
871        FeatureExtractorConfig {
872            name: "SimilarityFeatureExtractor".to_string(),
873            supported_features: self.supported_features(),
874            performance: FeatureExtractorPerformance {
875                avg_extraction_time_ms: 1.0,
876                memory_usage_mb: 0.1,
877                quality_score: 0.9,
878            },
879        }
880    }
881}
882
883struct LengthFeatureExtractor;
884
885impl LengthFeatureExtractor {
886    fn new() -> Self {
887        Self
888    }
889}
890
891impl FeatureExtractor for LengthFeatureExtractor {
892    fn extract_features(
893        &self,
894        _query: &str,
895        document: &SearchResult,
896        context: &FeatureExtractionContext,
897    ) -> RragResult<Vec<RankingFeature>> {
898        let doc_length = document.content.split_whitespace().count() as f32;
899        let normalized_length = doc_length / context.collection_stats.avg_document_length;
900
901        Ok(vec![
902            RankingFeature {
903                feature_type: FeatureType::DocumentLength,
904                name: "document_length".to_string(),
905                value: doc_length,
906                importance: Some(0.3),
907                metadata: FeatureMetadata {
908                    extraction_method: "word_count".to_string(),
909                    extraction_time_ms: 1,
910                    confidence: 1.0,
911                    properties: HashMap::new(),
912                },
913            },
914            RankingFeature {
915                feature_type: FeatureType::DocumentLength,
916                name: "normalized_document_length".to_string(),
917                value: normalized_length,
918                importance: Some(0.4),
919                metadata: FeatureMetadata {
920                    extraction_method: "normalized_word_count".to_string(),
921                    extraction_time_ms: 1,
922                    confidence: 1.0,
923                    properties: HashMap::new(),
924                },
925            },
926        ])
927    }
928
929    fn supported_features(&self) -> Vec<FeatureType> {
930        vec![FeatureType::DocumentLength]
931    }
932
933    fn get_config(&self) -> FeatureExtractorConfig {
934        FeatureExtractorConfig {
935            name: "LengthFeatureExtractor".to_string(),
936            supported_features: self.supported_features(),
937            performance: FeatureExtractorPerformance {
938                avg_extraction_time_ms: 1.0,
939                memory_usage_mb: 0.01,
940                quality_score: 1.0,
941            },
942        }
943    }
944}
945
946struct TermFrequencyExtractor;
947
948impl TermFrequencyExtractor {
949    fn new() -> Self {
950        Self
951    }
952}
953
954impl FeatureExtractor for TermFrequencyExtractor {
955    fn extract_features(
956        &self,
957        _query: &str,
958        document: &SearchResult,
959        context: &FeatureExtractionContext,
960    ) -> RragResult<Vec<RankingFeature>> {
961        let mut features = Vec::new();
962
963        let doc_terms: std::collections::HashMap<String, usize> = {
964            let mut map = std::collections::HashMap::new();
965            for term in document.content.split_whitespace() {
966                let term = term.to_lowercase();
967                *map.entry(term).or_insert(0) += 1;
968            }
969            map
970        };
971
972        // Query term frequency in document
973        let mut total_qtf = 0.0;
974        let mut matched_terms = 0;
975
976        for query_term in &context.query_stats.terms {
977            if let Some(&tf) = doc_terms.get(query_term) {
978                total_qtf += tf as f32;
979                matched_terms += 1;
980            }
981        }
982
983        features.push(RankingFeature {
984            feature_type: FeatureType::QueryTermFrequency,
985            name: "total_query_term_frequency".to_string(),
986            value: total_qtf,
987            importance: Some(0.6),
988            metadata: FeatureMetadata {
989                extraction_method: "term_counting".to_string(),
990                extraction_time_ms: 2,
991                confidence: 0.9,
992                properties: HashMap::new(),
993            },
994        });
995
996        features.push(RankingFeature {
997            feature_type: FeatureType::QueryTermFrequency,
998            name: "query_term_coverage".to_string(),
999            value: matched_terms as f32 / context.query_stats.terms.len() as f32,
1000            importance: Some(0.7),
1001            metadata: FeatureMetadata {
1002                extraction_method: "coverage_calculation".to_string(),
1003                extraction_time_ms: 1,
1004                confidence: 1.0,
1005                properties: HashMap::new(),
1006            },
1007        });
1008
1009        Ok(features)
1010    }
1011
1012    fn supported_features(&self) -> Vec<FeatureType> {
1013        vec![FeatureType::QueryTermFrequency]
1014    }
1015
1016    fn get_config(&self) -> FeatureExtractorConfig {
1017        FeatureExtractorConfig {
1018            name: "TermFrequencyExtractor".to_string(),
1019            supported_features: self.supported_features(),
1020            performance: FeatureExtractorPerformance {
1021                avg_extraction_time_ms: 3.0,
1022                memory_usage_mb: 0.05,
1023                quality_score: 0.8,
1024            },
1025        }
1026    }
1027}
1028
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SearchResult;

    /// End to end: the default reranker returns a score for every input document.
    #[tokio::test]
    async fn test_ltr_reranking() {
        // A closure rather than a fn, so `score`/`rank` types are inferred from `SearchResult`.
        let make_result = |id: &str, content: &str, score, rank| SearchResult {
            id: id.to_string(),
            content: content.to_string(),
            score,
            rank,
            metadata: HashMap::new(),
            embedding: None,
        };

        let results = vec![
            make_result(
                "doc1",
                "Machine learning is a subset of artificial intelligence that enables computers to learn",
                0.8,
                0,
            ),
            make_result("doc2", "AI and ML", 0.6, 1),
        ];
        let reranker = LearningToRankReranker::new(LTRConfig::default());

        let scores = reranker
            .rerank("machine learning artificial intelligence", &results)
            .await
            .unwrap();

        assert!(!scores.is_empty());
        for index in [0usize, 1].iter() {
            assert!(scores.contains_key(index));
        }
    }

    /// The similarity extractor surfaces the document's retrieval score unchanged.
    #[test]
    fn test_feature_extraction() {
        let query_stats = QueryStats {
            length: 2,
            terms: vec!["test".to_string(), "query".to_string()],
            query_type: None,
            term_frequencies: HashMap::new(),
        };
        let collection_stats = CollectionStats {
            total_documents: 1,
            avg_document_length: 10.0,
            document_frequencies: HashMap::new(),
            vocabulary_size: 100,
        };
        let context = FeatureExtractionContext {
            all_documents: Vec::new(),
            query_stats,
            collection_stats,
            user_context: None,
        };
        let document = SearchResult {
            id: "test_doc".to_string(),
            content: "test document content".to_string(),
            score: 0.7,
            rank: 0,
            metadata: HashMap::new(),
            embedding: None,
        };

        let features = SimilarityFeatureExtractor::new()
            .extract_features("test query", &document, &context)
            .unwrap();

        let first = features
            .first()
            .expect("extractor should return at least one feature");
        assert_eq!(first.feature_type, FeatureType::QueryDocumentSimilarity);
        assert_eq!(first.value, 0.7);
    }
}