1use crate::{RragResult, SearchResult};
7use std::collections::HashMap;
8
/// Reranks search results by scoring model-extracted features of each
/// (query, document) pair with a learning-to-rank model.
pub struct LearningToRankReranker {
    /// Configuration controlling model type, feature extraction and batching.
    config: LTRConfig,

    /// The underlying ranking model used to score feature vectors.
    model: Box<dyn LTRModel>,

    /// Feature extractors run against each (query, document) pair.
    feature_extractors: Vec<Box<dyn FeatureExtractor>>,

    /// Cache of extracted feature vectors keyed by string.
    /// NOTE(review): initialized but never read or written in the visible
    /// code — confirm intended use before relying on it.
    feature_cache: HashMap<String, Vec<f32>>,
}
23
/// Configuration for [`LearningToRankReranker`].
#[derive(Debug, Clone)]
pub struct LTRConfig {
    /// Which LTR model family to instantiate.
    pub model_type: LTRModelType,

    /// Settings governing which features are extracted and how they are
    /// normalized/selected.
    pub feature_config: FeatureExtractionConfig,

    /// Free-form numeric model parameters (e.g. "learning_rate", "num_trees").
    pub model_parameters: HashMap<String, f32>,

    /// Optional training configuration; `None` for inference-only use.
    pub training_config: Option<TrainingConfig>,

    /// Whether extracted features may be cached.
    pub enable_feature_caching: bool,

    /// Number of documents scored per prediction batch.
    pub batch_size: usize,
}
45
46impl Default for LTRConfig {
47 fn default() -> Self {
48 let mut model_parameters = HashMap::new();
49 model_parameters.insert("learning_rate".to_string(), 0.01);
50 model_parameters.insert("num_trees".to_string(), 100.0);
51 model_parameters.insert("max_depth".to_string(), 6.0);
52
53 Self {
54 model_type: LTRModelType::SimulatedLambdaMART,
55 feature_config: FeatureExtractionConfig::default(),
56 model_parameters,
57 training_config: None,
58 enable_feature_caching: true,
59 batch_size: 32,
60 }
61 }
62}
63
/// Supported learning-to-rank model families.
#[derive(Debug, Clone, PartialEq)]
pub enum LTRModelType {
    /// Pairwise neural ranking (RankNet).
    RankNet,
    /// Gradient-boosted-tree ranking (LambdaMART).
    LambdaMART,
    /// Listwise neural ranking (ListNet).
    ListNet,
    /// SVM-based pairwise ranking (RankSVM).
    RankSVM,
    /// User-provided model identified by name.
    Custom(String),
    /// Lightweight simulated LambdaMART used as the default stand-in.
    SimulatedLambdaMART,
}
80
/// Controls which ranking features are produced and how they are processed.
#[derive(Debug, Clone)]
pub struct FeatureExtractionConfig {
    /// Feature families to extract.
    pub enabled_features: Vec<FeatureType>,

    /// Normalization applied to each document's feature vector.
    pub normalization: FeatureNormalization,

    /// Upper bound on the number of features.
    pub max_features: usize,

    /// Feature-selection strategy applied after normalization.
    pub feature_selection: FeatureSelection,
}
96
97impl Default for FeatureExtractionConfig {
98 fn default() -> Self {
99 Self {
100 enabled_features: vec![
101 FeatureType::QueryDocumentSimilarity,
102 FeatureType::DocumentLength,
103 FeatureType::QueryTermFrequency,
104 FeatureType::DocumentTermFrequency,
105 FeatureType::InverseLinkFrequency,
106 ],
107 normalization: FeatureNormalization::ZScore,
108 max_features: 100,
109 feature_selection: FeatureSelection::None,
110 }
111 }
112}
113
/// Hyperparameters for training an LTR model.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Maximum number of training iterations.
    pub num_iterations: usize,

    /// Step size for gradient-based updates.
    pub learning_rate: f32,

    /// Regularization settings (L1/L2/dropout).
    pub regularization: RegularizationConfig,

    /// Early-stopping criteria.
    pub early_stopping: EarlyStoppingConfig,

    /// Number of cross-validation folds.
    pub cv_folds: usize,
}
132
/// Regularization weights used during training.
#[derive(Debug, Clone)]
pub struct RegularizationConfig {
    /// L1 (lasso) penalty weight.
    pub l1_weight: f32,

    /// L2 (ridge) penalty weight.
    pub l2_weight: f32,

    /// Dropout probability in [0, 1].
    pub dropout_rate: f32,
}
145
/// Criteria for stopping training before `num_iterations` is reached.
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
    /// Name of the validation metric to monitor.
    pub metric: String,

    /// Number of non-improving iterations tolerated before stopping.
    pub patience: usize,

    /// Minimum metric improvement that counts as progress.
    pub min_delta: f32,
}
158
/// Families of ranking features that can be extracted.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum FeatureType {
    /// Similarity between the query and the document (e.g. cosine score).
    QueryDocumentSimilarity,
    /// Raw and normalized document length.
    DocumentLength,
    /// Frequencies of query terms within the document.
    QueryTermFrequency,
    /// Term frequencies within the document itself.
    DocumentTermFrequency,
    /// Inverse link frequency signal.
    InverseLinkFrequency,
    /// BM25 relevance score.
    BM25Score,
    /// Document authority signal.
    AuthorityScore,
    /// Historical click-through rate.
    ClickThroughRate,
    /// Exact query-term matches.
    ExactMatches,
    /// Positions of matches within the document.
    PositionalFeatures,
    /// Recency/time-based signals.
    TemporalFeatures,
    /// User-defined feature identified by name.
    Custom(String),
}
187
/// Normalization schemes applied to a feature vector.
#[derive(Debug, Clone)]
pub enum FeatureNormalization {
    /// Scale into [0, 1] using the vector's min and max.
    MinMax,
    /// Standardize to zero mean and unit variance.
    ZScore,
    /// Quantile normalization.
    /// NOTE(review): not implemented in `normalize_features` — passes
    /// features through unchanged.
    Quantile,
    /// No normalization.
    None,
}
200
/// Feature-selection strategies applied after normalization.
#[derive(Debug, Clone)]
pub enum FeatureSelection {
    /// Keep all features.
    None,
    /// Keep only the first K features.
    TopK(usize),
    /// Correlation-threshold selection.
    /// NOTE(review): not implemented in `select_features` — passes through.
    Correlation(f32),
    /// Recursive feature elimination.
    /// NOTE(review): not implemented in `select_features` — passes through.
    RFE,
}
213
/// A single named feature value produced by a [`FeatureExtractor`].
#[derive(Debug, Clone)]
pub struct RankingFeature {
    /// The family this feature belongs to.
    pub feature_type: FeatureType,

    /// Human-readable feature name (e.g. "cosine_similarity").
    pub name: String,

    /// The extracted feature value.
    pub value: f32,

    /// Optional importance weight assigned by the extractor.
    pub importance: Option<f32>,

    /// Provenance and quality information about the extraction.
    pub metadata: FeatureMetadata,
}
232
/// Provenance information attached to each [`RankingFeature`].
#[derive(Debug, Clone)]
pub struct FeatureMetadata {
    /// How the feature was computed (e.g. "word_count").
    pub extraction_method: String,

    /// Wall-clock time taken to extract, in milliseconds.
    pub extraction_time_ms: u64,

    /// Extractor's confidence in the value, in [0, 1].
    pub confidence: f32,

    /// Additional numeric properties.
    pub properties: HashMap<String, f32>,
}
248
/// A complete feature vector for one (query, document) pair.
#[derive(Debug, Clone)]
pub struct LTRFeatures {
    /// Identifier of the query.
    pub query_id: String,

    /// Identifier of the document.
    pub document_id: String,

    /// Feature values, aligned with `feature_names`.
    pub features: Vec<f32>,

    /// Names of the features, aligned with `features`.
    pub feature_names: Vec<String>,

    /// Ground-truth relevance label, when known (for training).
    pub relevance: Option<f32>,

    /// Extraction diagnostics.
    pub metadata: LTRFeaturesMetadata,
}
270
/// Diagnostics describing how an [`LTRFeatures`] vector was produced.
#[derive(Debug, Clone)]
pub struct LTRFeaturesMetadata {
    /// Total extraction time in milliseconds.
    pub extraction_time_ms: u64,

    /// Number of features in the vector.
    pub num_features: usize,

    /// Overall quality estimate for the extraction, in [0, 1].
    pub quality_score: f32,

    /// Non-fatal issues encountered during extraction.
    pub warnings: Vec<String>,
}
286
/// A ranking model that scores feature vectors.
///
/// Implementations must be thread-safe (`Send + Sync`) because the reranker
/// stores them behind a `Box<dyn LTRModel>`.
pub trait LTRModel: Send + Sync {
    /// Scores a batch of feature vectors, returning one score per vector,
    /// in input order.
    fn predict(&self, features: &[Vec<f32>]) -> RragResult<Vec<f32>>;

    /// Scores a single feature vector via the batch API.
    /// Falls back to `0.0` if the batch call yields no score.
    fn predict_single(&self, features: &[f32]) -> RragResult<f32> {
        let batch_result = self.predict(&[features.to_vec()])?;
        Ok(batch_result.into_iter().next().unwrap_or(0.0))
    }

    /// Trains the model on labeled examples.
    ///
    /// Default implementation rejects training with a validation error;
    /// trainable models should override this.
    fn train(&mut self, training_data: &[LTRTrainingExample]) -> RragResult<TrainingResult> {
        let _ = training_data; Err(crate::RragError::validation(
            "training",
            "Training not implemented for this model",
            "",
        ))
    }

    /// Describes the model (name, version, parameters, training status).
    fn get_model_info(&self) -> LTRModelInfo;

    /// Per-feature importance weights, if the model exposes them.
    /// Defaults to `None`.
    fn get_feature_importance(&self) -> Option<Vec<f32>> {
        None
    }
}
316
/// A labeled (query, document) example used to train an LTR model.
#[derive(Debug, Clone)]
pub struct LTRTrainingExample {
    /// Identifier of the query.
    pub query_id: String,

    /// Identifier of the document.
    pub document_id: String,

    /// Extracted feature vector for the pair.
    pub features: Vec<f32>,

    /// Ground-truth relevance label.
    pub relevance: f32,

    /// Per-example weight applied during training.
    pub weight: f32,
}
335
/// Summary of a completed training run.
#[derive(Debug, Clone)]
pub struct TrainingResult {
    /// Loss value at the end of training.
    pub final_loss: f32,

    /// Validation metrics keyed by metric name.
    pub validation_metrics: HashMap<String, f32>,

    /// Total training time in milliseconds.
    pub training_time_ms: u64,

    /// Number of iterations actually performed.
    pub iterations_completed: usize,

    /// Whether training stopped early via the early-stopping criteria.
    pub early_stopped: bool,
}
354
/// Descriptive metadata about an [`LTRModel`] instance.
#[derive(Debug, Clone)]
pub struct LTRModelInfo {
    /// Model name (e.g. "SimulatedLambdaMART").
    pub name: String,

    /// Model version string.
    pub version: String,

    /// Number of input features the model expects (0 if unspecified).
    pub num_features: usize,

    /// Numeric model parameters.
    pub parameters: HashMap<String, f32>,

    /// Whether the model has been trained.
    pub is_trained: bool,

    /// Optional performance metrics keyed by metric name.
    pub performance_metrics: Option<HashMap<String, f32>>,
}
376
/// Extracts ranking features from a (query, document) pair.
///
/// Implementations must be thread-safe (`Send + Sync`).
pub trait FeatureExtractor: Send + Sync {
    /// Produces zero or more features for `document` given the query and
    /// the shared extraction `context`.
    fn extract_features(
        &self,
        _query: &str,
        document: &SearchResult,
        context: &FeatureExtractionContext,
    ) -> RragResult<Vec<RankingFeature>>;

    /// The feature families this extractor can produce.
    fn supported_features(&self) -> Vec<FeatureType>;

    /// Static descriptor (name, supported features, performance profile).
    fn get_config(&self) -> FeatureExtractorConfig;
}
393
/// Shared context passed to every extractor for a single rerank call.
#[derive(Debug, Clone)]
pub struct FeatureExtractionContext {
    /// All candidate documents being reranked.
    pub all_documents: Vec<SearchResult>,

    /// Precomputed statistics about the query.
    pub query_stats: QueryStats,

    /// Corpus-level statistics over the candidate set.
    pub collection_stats: CollectionStats,

    /// Optional per-user context for personalization.
    pub user_context: Option<UserContext>,
}
409
/// Tokenized view of the query used during feature extraction.
#[derive(Debug, Clone)]
pub struct QueryStats {
    /// Number of terms in the query.
    pub length: usize,

    /// Lowercased, whitespace-split query terms.
    pub terms: Vec<String>,

    /// Optional query classification label.
    pub query_type: Option<String>,

    /// Occurrence count of each term within the query.
    pub term_frequencies: HashMap<String, usize>,
}
425
/// Corpus-level statistics computed over the candidate document set.
#[derive(Debug, Clone)]
pub struct CollectionStats {
    /// Number of candidate documents.
    pub total_documents: usize,

    /// Mean document length in whitespace-separated words
    /// (0.0 when the collection is empty).
    pub avg_document_length: f32,

    /// Number of documents containing each term.
    pub document_frequencies: HashMap<String, usize>,

    /// Number of distinct terms across all documents.
    pub vocabulary_size: usize,
}
441
/// Per-user signals available for personalized ranking.
#[derive(Debug, Clone)]
pub struct UserContext {
    /// Identifier of the user.
    pub user_id: String,

    /// Numeric preference weights keyed by preference name.
    pub preferences: HashMap<String, f32>,

    /// Identifiers of past interactions.
    pub interaction_history: Vec<String>,
}
454
/// Static descriptor returned by [`FeatureExtractor::get_config`].
#[derive(Debug, Clone)]
pub struct FeatureExtractorConfig {
    /// Extractor name.
    pub name: String,

    /// Feature families the extractor can produce.
    pub supported_features: Vec<FeatureType>,

    /// Expected performance characteristics.
    pub performance: FeatureExtractorPerformance,
}
467
/// Expected performance characteristics of a feature extractor.
#[derive(Debug, Clone)]
pub struct FeatureExtractorPerformance {
    /// Average extraction time per document, in milliseconds.
    pub avg_extraction_time_ms: f32,

    /// Approximate memory footprint, in megabytes.
    pub memory_usage_mb: f32,

    /// Quality estimate of the produced features, in [0, 1].
    pub quality_score: f32,
}
480
481impl LearningToRankReranker {
482 pub fn new(config: LTRConfig) -> Self {
484 let model = Self::create_model(&config.model_type, &config.model_parameters);
485 let feature_extractors = Self::create_feature_extractors(&config.feature_config);
486
487 Self {
488 config,
489 model,
490 feature_extractors,
491 feature_cache: HashMap::new(),
492 }
493 }
494
495 fn create_model(
497 model_type: <RModelType,
498 parameters: &HashMap<String, f32>,
499 ) -> Box<dyn LTRModel> {
500 match model_type {
501 LTRModelType::SimulatedLambdaMART => {
502 Box::new(SimulatedLambdaMARTModel::new(parameters.clone()))
503 }
504 LTRModelType::LambdaMART => Box::new(SimulatedLambdaMARTModel::new(parameters.clone())),
505 LTRModelType::RankNet => Box::new(SimulatedRankNetModel::new()),
506 LTRModelType::ListNet => Box::new(SimulatedListNetModel::new()),
507 LTRModelType::RankSVM => Box::new(SimulatedRankSVMModel::new()),
508 LTRModelType::Custom(name) => Box::new(CustomLTRModel::new(name.clone())),
509 }
510 }
511
512 fn create_feature_extractors(
514 config: &FeatureExtractionConfig,
515 ) -> Vec<Box<dyn FeatureExtractor>> {
516 let mut extractors: Vec<Box<dyn FeatureExtractor>> = Vec::new();
517
518 if config
519 .enabled_features
520 .contains(&FeatureType::QueryDocumentSimilarity)
521 {
522 extractors.push(Box::new(SimilarityFeatureExtractor::new()));
523 }
524
525 if config
526 .enabled_features
527 .contains(&FeatureType::DocumentLength)
528 {
529 extractors.push(Box::new(LengthFeatureExtractor::new()));
530 }
531
532 if config
533 .enabled_features
534 .contains(&FeatureType::QueryTermFrequency)
535 {
536 extractors.push(Box::new(TermFrequencyExtractor::new()));
537 }
538
539 extractors
540 }
541
542 pub async fn rerank(
544 &self,
545 query: &str,
546 results: &[SearchResult],
547 ) -> RragResult<HashMap<usize, f32>> {
548 let context = FeatureExtractionContext {
550 all_documents: results.to_vec(),
551 query_stats: self.compute_query_stats(query),
552 collection_stats: self.compute_collection_stats(results),
553 user_context: None,
554 };
555
556 let mut feature_vectors = Vec::new();
558
559 for document in results {
560 let features = self.extract_document_features(query, document, &context)?;
561 feature_vectors.push(features);
562 }
563
564 let scores = self.model.predict(&feature_vectors)?;
566
567 let mut score_map = HashMap::new();
569 for (idx, score) in scores.into_iter().enumerate() {
570 score_map.insert(idx, score);
571 }
572
573 Ok(score_map)
574 }
575
576 fn extract_document_features(
578 &self,
579 query: &str,
580 document: &SearchResult,
581 context: &FeatureExtractionContext,
582 ) -> RragResult<Vec<f32>> {
583 let mut all_features = Vec::new();
584
585 for extractor in &self.feature_extractors {
587 let features = extractor.extract_features(query, document, context)?;
588
589 for feature in features {
590 all_features.push(feature.value);
591 }
592 }
593
594 let normalized_features = match self.config.feature_config.normalization {
596 FeatureNormalization::None => all_features,
597 _ => self.normalize_features(all_features)?,
598 };
599
600 let selected_features = match self.config.feature_config.feature_selection {
602 FeatureSelection::None => normalized_features,
603 _ => self.select_features(normalized_features)?,
604 };
605
606 Ok(selected_features)
607 }
608
609 fn compute_query_stats(&self, query: &str) -> QueryStats {
611 let terms: Vec<String> = query.split_whitespace().map(|s| s.to_lowercase()).collect();
612
613 let mut term_frequencies = HashMap::new();
614 for term in &terms {
615 *term_frequencies.entry(term.clone()).or_insert(0) += 1;
616 }
617
618 QueryStats {
619 length: terms.len(),
620 terms,
621 query_type: None, term_frequencies,
623 }
624 }
625
626 fn compute_collection_stats(&self, documents: &[SearchResult]) -> CollectionStats {
628 let total_documents = documents.len();
629 let total_length: usize = documents
630 .iter()
631 .map(|d| d.content.split_whitespace().count())
632 .sum();
633 let avg_document_length = if total_documents > 0 {
634 total_length as f32 / total_documents as f32
635 } else {
636 0.0
637 };
638
639 let mut document_frequencies = HashMap::new();
641 let mut vocabulary = std::collections::HashSet::new();
642
643 for document in documents {
644 let terms: std::collections::HashSet<String> = document
645 .content
646 .split_whitespace()
647 .map(|s| s.to_lowercase())
648 .collect();
649
650 for term in &terms {
651 *document_frequencies.entry(term.clone()).or_insert(0) += 1;
652 vocabulary.insert(term.clone());
653 }
654 }
655
656 CollectionStats {
657 total_documents,
658 avg_document_length,
659 document_frequencies,
660 vocabulary_size: vocabulary.len(),
661 }
662 }
663
664 fn normalize_features(&self, features: Vec<f32>) -> RragResult<Vec<f32>> {
666 match self.config.feature_config.normalization {
667 FeatureNormalization::MinMax => {
668 let min_val = features.iter().fold(f32::INFINITY, |a, &b| a.min(b));
669 let max_val = features.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
670 let range = max_val - min_val;
671
672 if range == 0.0 {
673 Ok(features) } else {
675 Ok(features
676 .into_iter()
677 .map(|f| (f - min_val) / range)
678 .collect())
679 }
680 }
681 FeatureNormalization::ZScore => {
682 let mean = features.iter().sum::<f32>() / features.len() as f32;
683 let variance = features.iter().map(|f| (f - mean).powi(2)).sum::<f32>()
684 / features.len() as f32;
685 let std_dev = variance.sqrt();
686
687 if std_dev == 0.0 {
688 Ok(features)
689 } else {
690 Ok(features.into_iter().map(|f| (f - mean) / std_dev).collect())
691 }
692 }
693 _ => Ok(features), }
695 }
696
697 fn select_features(&self, features: Vec<f32>) -> RragResult<Vec<f32>> {
699 match self.config.feature_config.feature_selection {
700 FeatureSelection::TopK(k) => {
701 Ok(features.into_iter().take(k).collect())
703 }
704 _ => Ok(features), }
706 }
707}
708
/// Deterministic stand-in for a LambdaMART gradient-boosted-tree ranker.
struct SimulatedLambdaMARTModel {
    /// Raw numeric parameters the model was constructed with.
    parameters: HashMap<String, f32>,
    /// Number of simulated trees (from the "num_trees" parameter).
    num_trees: usize,
}
714
715impl SimulatedLambdaMARTModel {
716 fn new(parameters: HashMap<String, f32>) -> Self {
717 let num_trees = parameters.get("num_trees").copied().unwrap_or(100.0) as usize;
718 Self {
719 parameters,
720 num_trees,
721 }
722 }
723}
724
725impl LTRModel for SimulatedLambdaMARTModel {
726 fn predict(&self, features: &[Vec<f32>]) -> RragResult<Vec<f32>> {
727 let mut scores = Vec::new();
728
729 for feature_vector in features {
730 let mut score = 0.0;
732
733 for tree_idx in 0..self.num_trees {
734 let tree_score = feature_vector
736 .iter()
737 .enumerate()
738 .map(|(i, &f)| f * (0.1 + 0.01 * (tree_idx + i) as f32).sin())
739 .sum::<f32>()
740 / feature_vector.len() as f32;
741
742 score += tree_score * 0.01; }
744
745 scores.push(1.0 / (1.0 + (-score).exp()));
747 }
748
749 Ok(scores)
750 }
751
752 fn get_model_info(&self) -> LTRModelInfo {
753 LTRModelInfo {
754 name: "SimulatedLambdaMART".to_string(),
755 version: "1.0".to_string(),
756 num_features: 0, parameters: self.parameters.clone(),
758 is_trained: true,
759 performance_metrics: None,
760 }
761 }
762
763 fn get_feature_importance(&self) -> Option<Vec<f32>> {
764 Some(vec![0.3, 0.25, 0.2, 0.15, 0.1]) }
767}
768
/// Generates a stateless mock `LTRModel` whose prediction is the sigmoid of
/// the mean feature value. Used to stub out model families that are not yet
/// implemented natively.
macro_rules! impl_mock_ltr_model {
    ($name:ident) => {
        struct $name;

        impl $name {
            fn new() -> Self {
                Self
            }
        }

        impl LTRModel for $name {
            // Sigmoid of the mean feature value per document.
            // NOTE(review): an empty feature vector yields NaN (0/0) —
            // confirm callers always supply at least one feature.
            fn predict(&self, features: &[Vec<f32>]) -> RragResult<Vec<f32>> {
                Ok(features
                    .iter()
                    .map(|f| f.iter().sum::<f32>() / f.len() as f32)
                    .map(|s| 1.0 / (1.0 + (-s).exp()))
                    .collect())
            }

            fn get_model_info(&self) -> LTRModelInfo {
                LTRModelInfo {
                    name: stringify!($name).to_string(),
                    version: "1.0".to_string(),
                    num_features: 0,
                    parameters: HashMap::new(),
                    is_trained: false,
                    performance_metrics: None,
                }
            }
        }
    };
}
802
// Mock models for families without a native implementation yet.
impl_mock_ltr_model!(SimulatedRankNetModel);
impl_mock_ltr_model!(SimulatedListNetModel);
impl_mock_ltr_model!(SimulatedRankSVMModel);
806
/// Placeholder for a user-provided model, identified only by name.
struct CustomLTRModel {
    /// Name supplied via `LTRModelType::Custom`.
    name: String,
}
810
811impl CustomLTRModel {
812 fn new(name: String) -> Self {
813 Self { name }
814 }
815}
816
817impl LTRModel for CustomLTRModel {
818 fn predict(&self, features: &[Vec<f32>]) -> RragResult<Vec<f32>> {
819 Ok(vec![0.5; features.len()]) }
821
822 fn get_model_info(&self) -> LTRModelInfo {
823 LTRModelInfo {
824 name: self.name.clone(),
825 version: "custom".to_string(),
826 num_features: 0,
827 parameters: HashMap::new(),
828 is_trained: false,
829 performance_metrics: None,
830 }
831 }
832}
833
/// Extractor that surfaces the document's retrieval similarity score.
struct SimilarityFeatureExtractor;

impl SimilarityFeatureExtractor {
    /// Creates the (stateless) extractor.
    fn new() -> Self {
        Self
    }
}
842
843impl FeatureExtractor for SimilarityFeatureExtractor {
844 fn extract_features(
845 &self,
846 _query: &str,
847 document: &SearchResult,
848 _context: &FeatureExtractionContext,
849 ) -> RragResult<Vec<RankingFeature>> {
850 let similarity = document.score; Ok(vec![RankingFeature {
853 feature_type: FeatureType::QueryDocumentSimilarity,
854 name: "cosine_similarity".to_string(),
855 value: similarity,
856 importance: Some(0.8),
857 metadata: FeatureMetadata {
858 extraction_method: "vector_similarity".to_string(),
859 extraction_time_ms: 1,
860 confidence: 0.9,
861 properties: HashMap::new(),
862 },
863 }])
864 }
865
866 fn supported_features(&self) -> Vec<FeatureType> {
867 vec![FeatureType::QueryDocumentSimilarity]
868 }
869
870 fn get_config(&self) -> FeatureExtractorConfig {
871 FeatureExtractorConfig {
872 name: "SimilarityFeatureExtractor".to_string(),
873 supported_features: self.supported_features(),
874 performance: FeatureExtractorPerformance {
875 avg_extraction_time_ms: 1.0,
876 memory_usage_mb: 0.1,
877 quality_score: 0.9,
878 },
879 }
880 }
881}
882
/// Extractor producing raw and collection-normalized document length.
struct LengthFeatureExtractor;

impl LengthFeatureExtractor {
    /// Creates the (stateless) extractor.
    fn new() -> Self {
        Self
    }
}
890
891impl FeatureExtractor for LengthFeatureExtractor {
892 fn extract_features(
893 &self,
894 _query: &str,
895 document: &SearchResult,
896 context: &FeatureExtractionContext,
897 ) -> RragResult<Vec<RankingFeature>> {
898 let doc_length = document.content.split_whitespace().count() as f32;
899 let normalized_length = doc_length / context.collection_stats.avg_document_length;
900
901 Ok(vec![
902 RankingFeature {
903 feature_type: FeatureType::DocumentLength,
904 name: "document_length".to_string(),
905 value: doc_length,
906 importance: Some(0.3),
907 metadata: FeatureMetadata {
908 extraction_method: "word_count".to_string(),
909 extraction_time_ms: 1,
910 confidence: 1.0,
911 properties: HashMap::new(),
912 },
913 },
914 RankingFeature {
915 feature_type: FeatureType::DocumentLength,
916 name: "normalized_document_length".to_string(),
917 value: normalized_length,
918 importance: Some(0.4),
919 metadata: FeatureMetadata {
920 extraction_method: "normalized_word_count".to_string(),
921 extraction_time_ms: 1,
922 confidence: 1.0,
923 properties: HashMap::new(),
924 },
925 },
926 ])
927 }
928
929 fn supported_features(&self) -> Vec<FeatureType> {
930 vec![FeatureType::DocumentLength]
931 }
932
933 fn get_config(&self) -> FeatureExtractorConfig {
934 FeatureExtractorConfig {
935 name: "LengthFeatureExtractor".to_string(),
936 supported_features: self.supported_features(),
937 performance: FeatureExtractorPerformance {
938 avg_extraction_time_ms: 1.0,
939 memory_usage_mb: 0.01,
940 quality_score: 1.0,
941 },
942 }
943 }
944}
945
/// Extractor producing query-term frequency and coverage features.
struct TermFrequencyExtractor;

impl TermFrequencyExtractor {
    /// Creates the (stateless) extractor.
    fn new() -> Self {
        Self
    }
}
953
954impl FeatureExtractor for TermFrequencyExtractor {
955 fn extract_features(
956 &self,
957 _query: &str,
958 document: &SearchResult,
959 context: &FeatureExtractionContext,
960 ) -> RragResult<Vec<RankingFeature>> {
961 let mut features = Vec::new();
962
963 let doc_terms: std::collections::HashMap<String, usize> = {
964 let mut map = std::collections::HashMap::new();
965 for term in document.content.split_whitespace() {
966 let term = term.to_lowercase();
967 *map.entry(term).or_insert(0) += 1;
968 }
969 map
970 };
971
972 let mut total_qtf = 0.0;
974 let mut matched_terms = 0;
975
976 for query_term in &context.query_stats.terms {
977 if let Some(&tf) = doc_terms.get(query_term) {
978 total_qtf += tf as f32;
979 matched_terms += 1;
980 }
981 }
982
983 features.push(RankingFeature {
984 feature_type: FeatureType::QueryTermFrequency,
985 name: "total_query_term_frequency".to_string(),
986 value: total_qtf,
987 importance: Some(0.6),
988 metadata: FeatureMetadata {
989 extraction_method: "term_counting".to_string(),
990 extraction_time_ms: 2,
991 confidence: 0.9,
992 properties: HashMap::new(),
993 },
994 });
995
996 features.push(RankingFeature {
997 feature_type: FeatureType::QueryTermFrequency,
998 name: "query_term_coverage".to_string(),
999 value: matched_terms as f32 / context.query_stats.terms.len() as f32,
1000 importance: Some(0.7),
1001 metadata: FeatureMetadata {
1002 extraction_method: "coverage_calculation".to_string(),
1003 extraction_time_ms: 1,
1004 confidence: 1.0,
1005 properties: HashMap::new(),
1006 },
1007 });
1008
1009 Ok(features)
1010 }
1011
1012 fn supported_features(&self) -> Vec<FeatureType> {
1013 vec![FeatureType::QueryTermFrequency]
1014 }
1015
1016 fn get_config(&self) -> FeatureExtractorConfig {
1017 FeatureExtractorConfig {
1018 name: "TermFrequencyExtractor".to_string(),
1019 supported_features: self.supported_features(),
1020 performance: FeatureExtractorPerformance {
1021 avg_extraction_time_ms: 3.0,
1022 memory_usage_mb: 0.05,
1023 quality_score: 0.8,
1024 },
1025 }
1026 }
1027}
1028
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SearchResult;

    /// End-to-end smoke test: reranking two documents produces a score for
    /// each original index.
    #[tokio::test]
    async fn test_ltr_reranking() {
        let config = LTRConfig::default();
        let reranker = LearningToRankReranker::new(config);

        let results = vec![
            SearchResult {
                id: "doc1".to_string(),
                content: "Machine learning is a subset of artificial intelligence that enables computers to learn".to_string(),
                score: 0.8,
                rank: 0,
                metadata: HashMap::new(),
                embedding: None,
            },
            SearchResult {
                id: "doc2".to_string(),
                content: "AI and ML".to_string(),
                score: 0.6,
                rank: 1,
                metadata: HashMap::new(),
                embedding: None,
            },
        ];

        let query = "machine learning artificial intelligence";
        let reranked_scores = reranker.rerank(query, &results).await.unwrap();

        // Scores are keyed by the document's index in the input slice.
        assert!(!reranked_scores.is_empty());
        assert!(reranked_scores.contains_key(&0));
        assert!(reranked_scores.contains_key(&1));
    }

    /// The similarity extractor should pass the document's retrieval score
    /// through unchanged as a QueryDocumentSimilarity feature.
    #[test]
    fn test_feature_extraction() {
        let extractor = SimilarityFeatureExtractor::new();
        let context = FeatureExtractionContext {
            all_documents: vec![],
            query_stats: QueryStats {
                length: 2,
                terms: vec!["test".to_string(), "query".to_string()],
                query_type: None,
                term_frequencies: HashMap::new(),
            },
            collection_stats: CollectionStats {
                total_documents: 1,
                avg_document_length: 10.0,
                document_frequencies: HashMap::new(),
                vocabulary_size: 100,
            },
            user_context: None,
        };

        let document = SearchResult {
            id: "test_doc".to_string(),
            content: "test document content".to_string(),
            score: 0.7,
            rank: 0,
            metadata: HashMap::new(),
            embedding: None,
        };

        let features = extractor
            .extract_features("test query", &document, &context)
            .unwrap();

        assert!(!features.is_empty());
        assert_eq!(
            features[0].feature_type,
            FeatureType::QueryDocumentSimilarity
        );
        // Feature value equals the document's incoming retrieval score.
        assert_eq!(features[0].value, 0.7);
    }
}