1use crate::{RragResult, SearchResult};
8use std::collections::HashMap;
9use tracing::warn;
10
/// Reranks search results by extracting multiple relevance signals per
/// document, normalizing them, and aggregating them into a single score.
pub struct MultiSignalReranker {
    /// Configuration the reranker was built from.
    config: MultiSignalConfig,

    /// One extractor instance per enabled signal type.
    signal_extractors: HashMap<SignalType, Box<dyn SignalExtractor>>,

    /// Resolved scalar weight per signal type (see `initialize_weights`).
    signal_weights: HashMap<SignalType, f32>,

    /// Strategy used to combine per-signal scores into a final score.
    aggregation: SignalAggregation,
}
25
/// Configuration for [`MultiSignalReranker`].
#[derive(Debug, Clone)]
pub struct MultiSignalConfig {
    /// Signal types for which extractors are instantiated.
    pub enabled_signals: Vec<SignalType>,

    /// Weighting scheme per signal type.
    pub signal_weights: HashMap<SignalType, SignalWeight>,

    /// How per-signal scores are combined into one score per document.
    pub aggregation_method: SignalAggregation,

    /// Normalization applied to each signal's values across the result set.
    pub normalization: SignalNormalization,

    /// Signals with confidence below this threshold are excluded from
    /// aggregation.
    pub min_signal_confidence: f32,

    /// Whether adaptive weight updates are enabled.
    /// NOTE(review): not read anywhere in this file — presumably consumed
    /// elsewhere or reserved for future use; confirm before relying on it.
    pub enable_adaptive_weights: bool,

    /// Learning rate for weight adaptation.
    /// NOTE(review): also unused within this file.
    pub learning_rate: f32,
}
50
51impl Default for MultiSignalConfig {
52 fn default() -> Self {
53 let mut signal_weights = HashMap::new();
54 signal_weights.insert(SignalType::SemanticRelevance, SignalWeight::Fixed(0.3));
55 signal_weights.insert(SignalType::TextualRelevance, SignalWeight::Fixed(0.25));
56 signal_weights.insert(SignalType::DocumentFreshness, SignalWeight::Fixed(0.15));
57 signal_weights.insert(SignalType::DocumentAuthority, SignalWeight::Fixed(0.1));
58 signal_weights.insert(SignalType::DocumentQuality, SignalWeight::Fixed(0.1));
59 signal_weights.insert(SignalType::UserPreference, SignalWeight::Fixed(0.05));
60 signal_weights.insert(SignalType::ClickThroughRate, SignalWeight::Fixed(0.05));
61
62 Self {
63 enabled_signals: vec![
64 SignalType::SemanticRelevance,
65 SignalType::TextualRelevance,
66 SignalType::DocumentFreshness,
67 SignalType::DocumentQuality,
68 ],
69 signal_weights,
70 aggregation_method: SignalAggregation::WeightedSum,
71 normalization: SignalNormalization::MinMax,
72 min_signal_confidence: 0.1,
73 enable_adaptive_weights: false,
74 learning_rate: 0.01,
75 }
76 }
77}
78
/// Kinds of relevance signals the reranker can combine.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum SignalType {
    /// Embedding/vector similarity (taken from the upstream search score).
    SemanticRelevance,
    /// Lexical term-overlap similarity between query and document.
    TextualRelevance,
    /// Recency of the document.
    DocumentFreshness,
    /// Authority of the document's source (placeholder implementation).
    DocumentAuthority,
    /// Surface-level content quality heuristics.
    DocumentQuality,
    /// Per-user preference signal (placeholder implementation).
    UserPreference,
    /// Historical click-through rate (placeholder implementation).
    ClickThroughRate,
    /// Overall document popularity (placeholder implementation).
    DocumentPopularity,
    /// User interaction history (placeholder implementation).
    InteractionHistory,
    /// Custom, domain-tagged signal; the string names the domain.
    DomainSpecific(String),
}
103
/// How a signal's weight is determined.
pub enum SignalWeight {
    /// Constant weight.
    Fixed(f32),
    /// Weight computed from the query text at scoring time.
    QueryDependent(Box<dyn Fn(&str) -> f32 + Send + Sync>),
    /// Weight maintained internally by the reranker (initialized uniform).
    Learned,
    /// Weight that starts at the given value and may be adapted externally.
    Adaptive(f32),
}
115
116impl std::fmt::Debug for SignalWeight {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 match self {
119 Self::Fixed(w) => write!(f, "Fixed({})", w),
120 Self::QueryDependent(_) => write!(f, "QueryDependent(<function>)"),
121 Self::Learned => write!(f, "Learned"),
122 Self::Adaptive(w) => write!(f, "Adaptive({})", w),
123 }
124 }
125}
126
impl Clone for SignalWeight {
    /// Manual `Clone` because the boxed closure in `QueryDependent` cannot be
    /// cloned.
    fn clone(&self) -> Self {
        match self {
            Self::Fixed(w) => Self::Fixed(*w),
            // NOTE(review): lossy — cloning a `QueryDependent` weight drops
            // the closure and substitutes a fixed 0.5. If clones must keep
            // query-dependent behavior, the variant should hold
            // `Arc<dyn Fn(&str) -> f32 + Send + Sync>` instead (a breaking
            // type change, so only flagged here).
            Self::QueryDependent(_) => Self::Fixed(0.5),
            Self::Learned => Self::Learned,
            Self::Adaptive(w) => Self::Adaptive(*w),
        }
    }
}
137
/// Strategies for combining per-signal values into a single score.
#[derive(Debug, Clone)]
pub enum SignalAggregation {
    /// Sum of value × weight.
    WeightedSum,
    /// Weighted sum divided by the sum of weights.
    WeightedAverage,
    /// Maximum signal value (weights ignored).
    Max,
    /// Minimum signal value (weights ignored).
    Min,
    /// Learned combination; currently falls back to a weighted average.
    LearnedCombination,
    /// Named custom strategy; currently falls back to a plain average.
    Custom(String),
}
154
/// Per-signal normalization applied across the whole result set.
#[derive(Debug, Clone)]
pub enum SignalNormalization {
    /// Linear rescale to [0, 1] over the observed min/max.
    MinMax,
    /// Standard-score normalization followed by a sigmoid squashing to (0, 1).
    ZScore,
    /// Rank-based normalization (best rank → 1.0, descending).
    Rank,
    /// Sigmoid squashing of the raw value.
    Sigmoid,
    /// Leave values untouched.
    None,
}
169
/// One extracted relevance signal for a single document.
#[derive(Debug, Clone)]
pub struct RelevanceSignal {
    /// Which signal this value represents.
    pub signal_type: SignalType,

    /// Signal value; typically in [0, 1] after normalization.
    pub value: f32,

    /// Extractor's confidence in the value; signals below the configured
    /// minimum confidence are excluded from aggregation.
    pub confidence: f32,

    /// Provenance and diagnostic details for the extraction.
    pub metadata: SignalMetadata,
}
185
/// Provenance and diagnostics attached to an extracted signal.
#[derive(Debug, Clone)]
pub struct SignalMetadata {
    /// Identifier of the component that produced the signal.
    pub source: String,

    /// Wall-clock time the extraction took, in milliseconds.
    pub extraction_time_ms: u64,

    /// Named intermediate feature values used to compute the signal.
    pub features: HashMap<String, f32>,

    /// Non-fatal issues encountered during extraction.
    pub warnings: Vec<String>,
}
201
/// Extracts one kind of relevance signal for (query, document) pairs.
pub trait SignalExtractor: Send + Sync {
    /// Computes this extractor's signal for a single document.
    fn extract_signal(
        &self,
        query: &str,
        document: &SearchResult,
        context: &RetrievalContext,
    ) -> RragResult<RelevanceSignal>;

    /// Computes the signal for every document. The default implementation
    /// calls [`Self::extract_signal`] per document and short-circuits on the
    /// first error.
    fn extract_batch(
        &self,
        query: &str,
        documents: &[SearchResult],
        context: &RetrievalContext,
    ) -> RragResult<Vec<RelevanceSignal>> {
        documents
            .iter()
            .map(|doc| self.extract_signal(query, doc, context))
            .collect()
    }

    /// The signal type this extractor produces.
    fn signal_type(&self) -> SignalType;

    /// Static descriptor (name, version, features, performance) of this
    /// extractor.
    fn get_config(&self) -> SignalExtractorConfig;
}
231
/// Static descriptor of a [`SignalExtractor`] implementation.
#[derive(Debug, Clone)]
pub struct SignalExtractorConfig {
    /// Human-readable extractor name.
    pub name: String,

    /// Implementation version string.
    pub version: String,

    /// Names of the features/techniques the extractor uses.
    pub features: Vec<String>,

    /// Self-reported performance characteristics.
    pub performance: PerformanceMetrics,
}
247
/// Self-reported performance characteristics of an extractor.
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
    /// Average time to extract one signal, in milliseconds.
    pub avg_extraction_time_ms: f32,

    /// Estimated accuracy of the signal, in [0, 1].
    pub accuracy: f32,

    /// Approximate memory footprint, in megabytes.
    pub memory_usage_mb: f32,
}
260
/// Per-request context available to signal extractors.
#[derive(Debug, Clone)]
pub struct RetrievalContext {
    /// Identifier of the requesting user, if known.
    pub user_id: Option<String>,

    /// Identifier of the current session, if known.
    pub session_id: Option<String>,

    /// Time of the request; used e.g. to compute document age.
    pub timestamp: chrono::DateTime<chrono::Utc>,

    /// Classified intent of the query, if available.
    pub query_intent: Option<String>,

    /// Per-topic (or similar) preference scores for the user.
    pub user_preferences: HashMap<String, f32>,

    /// Past user interactions relevant to this request.
    pub interaction_history: Vec<InteractionRecord>,
}
282
/// A single past user interaction with a document.
#[derive(Debug, Clone)]
pub struct InteractionRecord {
    /// Document the interaction was with.
    pub document_id: String,

    /// Kind of interaction (e.g. click, dwell); free-form string.
    pub interaction_type: String,

    /// When the interaction happened.
    pub timestamp: chrono::DateTime<chrono::Utc>,

    /// Interaction strength/score.
    pub value: f32,
}
298
299impl MultiSignalReranker {
300 pub fn new(config: MultiSignalConfig) -> Self {
302 let mut reranker = Self {
303 config: config.clone(),
304 signal_extractors: HashMap::new(),
305 signal_weights: HashMap::new(),
306 aggregation: config.aggregation_method.clone(),
307 };
308
309 reranker.initialize_extractors();
311
312 reranker.initialize_weights();
314
315 reranker
316 }
317
318 fn initialize_extractors(&mut self) {
320 for signal_type in &self.config.enabled_signals {
321 let extractor: Box<dyn SignalExtractor> = match signal_type {
322 SignalType::SemanticRelevance => Box::new(SemanticRelevanceExtractor::new()),
323 SignalType::TextualRelevance => Box::new(TextualRelevanceExtractor::new()),
324 SignalType::DocumentFreshness => Box::new(DocumentFreshnessExtractor::new()),
325 SignalType::DocumentAuthority => Box::new(DocumentAuthorityExtractor::new()),
326 SignalType::DocumentQuality => Box::new(DocumentQualityExtractor::new()),
327 SignalType::UserPreference => Box::new(UserPreferenceExtractor::new()),
328 SignalType::ClickThroughRate => Box::new(ClickThroughRateExtractor::new()),
329 SignalType::DocumentPopularity => Box::new(DocumentPopularityExtractor::new()),
330 SignalType::InteractionHistory => Box::new(InteractionHistoryExtractor::new()),
331 SignalType::DomainSpecific(domain) => {
332 Box::new(DomainSpecificExtractor::new(domain.clone()))
333 }
334 };
335
336 self.signal_extractors
337 .insert(signal_type.clone(), extractor);
338 }
339 }
340
341 fn initialize_weights(&mut self) {
343 for (signal_type, weight_config) in &self.config.signal_weights {
344 let weight = match weight_config {
345 SignalWeight::Fixed(w) => *w,
346 SignalWeight::Adaptive(w) => *w,
347 SignalWeight::Learned => 1.0 / self.config.signal_weights.len() as f32, SignalWeight::QueryDependent(_) => 1.0, };
350
351 self.signal_weights.insert(signal_type.clone(), weight);
352 }
353 }
354
355 pub async fn rerank(
357 &self,
358 query: &str,
359 results: &[SearchResult],
360 ) -> RragResult<HashMap<usize, f32>> {
361 let context = RetrievalContext {
362 user_id: None,
363 session_id: None,
364 timestamp: chrono::Utc::now(),
365 query_intent: None,
366 user_preferences: HashMap::new(),
367 interaction_history: Vec::new(),
368 };
369
370 self.rerank_with_context(query, results, &context).await
371 }
372
373 pub async fn rerank_with_context(
375 &self,
376 query: &str,
377 results: &[SearchResult],
378 context: &RetrievalContext,
379 ) -> RragResult<HashMap<usize, f32>> {
380 let mut final_scores = HashMap::new();
381
382 let mut all_signals: HashMap<SignalType, Vec<RelevanceSignal>> = HashMap::new();
384
385 for (signal_type, extractor) in &self.signal_extractors {
386 match extractor.extract_batch(query, results, context) {
387 Ok(signals) => {
388 all_signals.insert(signal_type.clone(), signals);
389 }
390 Err(e) => {
391 warn!(" Failed to extract signal {:?}: {}", signal_type, e);
392 }
394 }
395 }
396
397 let normalized_signals = self.normalize_signals(all_signals)?;
399
400 for (doc_idx, _) in results.iter().enumerate() {
402 let mut signal_values = Vec::new();
403 let mut signal_weights = Vec::new();
404
405 for (signal_type, signals) in &normalized_signals {
406 if let Some(signal) = signals.get(doc_idx) {
407 if signal.confidence >= self.config.min_signal_confidence {
408 signal_values.push(signal.value);
409
410 let weight = self.get_signal_weight(signal_type, query, signal)?;
411 signal_weights.push(weight);
412 }
413 }
414 }
415
416 let final_score = self.aggregate_signals(&signal_values, &signal_weights)?;
418 final_scores.insert(doc_idx, final_score);
419 }
420
421 Ok(final_scores)
422 }
423
424 fn normalize_signals(
426 &self,
427 signals: HashMap<SignalType, Vec<RelevanceSignal>>,
428 ) -> RragResult<HashMap<SignalType, Vec<RelevanceSignal>>> {
429 let mut normalized = HashMap::new();
430
431 for (signal_type, signal_list) in signals {
432 let normalized_list = match self.config.normalization {
433 SignalNormalization::MinMax => self.normalize_min_max(&signal_list),
434 SignalNormalization::ZScore => self.normalize_z_score(&signal_list),
435 SignalNormalization::Rank => self.normalize_rank(&signal_list),
436 SignalNormalization::Sigmoid => self.normalize_sigmoid(&signal_list),
437 SignalNormalization::None => signal_list,
438 };
439
440 normalized.insert(signal_type, normalized_list);
441 }
442
443 Ok(normalized)
444 }
445
446 fn normalize_min_max(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
448 let values: Vec<f32> = signals.iter().map(|s| s.value).collect();
449 let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
450 let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
451
452 let range = max_val - min_val;
453 if range == 0.0 {
454 return signals.to_vec(); }
456
457 signals
458 .iter()
459 .map(|signal| {
460 let mut normalized = signal.clone();
461 normalized.value = (signal.value - min_val) / range;
462 normalized
463 })
464 .collect()
465 }
466
467 fn normalize_z_score(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
469 let values: Vec<f32> = signals.iter().map(|s| s.value).collect();
470 let mean = values.iter().sum::<f32>() / values.len() as f32;
471 let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / values.len() as f32;
472 let std_dev = variance.sqrt();
473
474 if std_dev == 0.0 {
475 return signals.to_vec();
476 }
477
478 signals
479 .iter()
480 .map(|signal| {
481 let mut normalized = signal.clone();
482 normalized.value = (signal.value - mean) / std_dev;
483 normalized.value = 1.0 / (1.0 + (-normalized.value).exp());
485 normalized
486 })
487 .collect()
488 }
489
490 fn normalize_rank(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
492 let mut indexed_signals: Vec<(usize, &RelevanceSignal)> =
493 signals.iter().enumerate().collect();
494
495 indexed_signals.sort_by(|a, b| {
496 b.1.value
497 .partial_cmp(&a.1.value)
498 .unwrap_or(std::cmp::Ordering::Equal)
499 });
500
501 let mut normalized = vec![signals[0].clone(); signals.len()];
502 for (rank, (original_idx, signal)) in indexed_signals.iter().enumerate() {
503 normalized[*original_idx] = (*signal).clone();
504 normalized[*original_idx].value = 1.0 - (rank as f32 / signals.len() as f32);
505 }
506
507 normalized
508 }
509
510 fn normalize_sigmoid(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
512 signals
513 .iter()
514 .map(|signal| {
515 let mut normalized = signal.clone();
516 normalized.value = 1.0 / (1.0 + (-signal.value).exp());
517 normalized
518 })
519 .collect()
520 }
521
522 fn get_signal_weight(
524 &self,
525 signal_type: &SignalType,
526 query: &str,
527 _signal: &RelevanceSignal,
528 ) -> RragResult<f32> {
529 if let Some(weight_config) = self.config.signal_weights.get(signal_type) {
530 match weight_config {
531 SignalWeight::Fixed(w) => Ok(*w),
532 SignalWeight::Adaptive(w) => Ok(*w),
533 SignalWeight::Learned => {
534 Ok(self.signal_weights.get(signal_type).copied().unwrap_or(1.0))
535 }
536 SignalWeight::QueryDependent(func) => Ok(func(query)),
537 }
538 } else {
539 Ok(1.0 / self.config.signal_weights.len() as f32) }
541 }
542
543 fn aggregate_signals(&self, values: &[f32], weights: &[f32]) -> RragResult<f32> {
545 if values.is_empty() {
546 return Ok(0.0);
547 }
548
549 match &self.aggregation {
550 SignalAggregation::WeightedSum => {
551 Ok(values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum())
552 }
553 SignalAggregation::WeightedAverage => {
554 let weighted_sum: f32 = values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum();
555 let weight_sum: f32 = weights.iter().sum();
556 Ok(if weight_sum > 0.0 {
557 weighted_sum / weight_sum
558 } else {
559 0.0
560 })
561 }
562 SignalAggregation::Max => Ok(values.iter().fold(0.0f32, |a, &b| a.max(b))),
563 SignalAggregation::Min => Ok(values.iter().fold(1.0f32, |a, &b| a.min(b))),
564 SignalAggregation::LearnedCombination => {
565 let weighted_sum: f32 = values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum();
567 let weight_sum: f32 = weights.iter().sum();
568 Ok(if weight_sum > 0.0 {
569 weighted_sum / weight_sum
570 } else {
571 0.0
572 })
573 }
574 SignalAggregation::Custom(_) => {
575 Ok(values.iter().sum::<f32>() / values.len() as f32)
577 }
578 }
579 }
580}
581
/// Extracts semantic relevance by passing through the upstream search score.
struct SemanticRelevanceExtractor;

impl SemanticRelevanceExtractor {
    /// Stateless constructor.
    fn new() -> Self {
        Self
    }
}
593
594impl SignalExtractor for SemanticRelevanceExtractor {
595 fn extract_signal(
596 &self,
597 _query: &str,
598 document: &SearchResult,
599 _context: &RetrievalContext,
600 ) -> RragResult<RelevanceSignal> {
601 Ok(RelevanceSignal {
603 signal_type: SignalType::SemanticRelevance,
604 value: document.score,
605 confidence: 0.8,
606 metadata: SignalMetadata {
607 source: "search_engine".to_string(),
608 extraction_time_ms: 1,
609 features: HashMap::new(),
610 warnings: Vec::new(),
611 },
612 })
613 }
614
615 fn signal_type(&self) -> SignalType {
616 SignalType::SemanticRelevance
617 }
618
619 fn get_config(&self) -> SignalExtractorConfig {
620 SignalExtractorConfig {
621 name: "SemanticRelevanceExtractor".to_string(),
622 version: "1.0".to_string(),
623 features: vec!["vector_similarity".to_string()],
624 performance: PerformanceMetrics {
625 avg_extraction_time_ms: 1.0,
626 accuracy: 0.8,
627 memory_usage_mb: 0.1,
628 },
629 }
630 }
631}
632
/// Extracts lexical (term-overlap) relevance between query and document.
struct TextualRelevanceExtractor;

impl TextualRelevanceExtractor {
    /// Stateless constructor.
    fn new() -> Self {
        Self
    }
}
641
642impl SignalExtractor for TextualRelevanceExtractor {
643 fn extract_signal(
644 &self,
645 query: &str,
646 document: &SearchResult,
647 _context: &RetrievalContext,
648 ) -> RragResult<RelevanceSignal> {
649 let query_terms: std::collections::HashSet<&str> = query.split_whitespace().collect();
651 let doc_terms: std::collections::HashSet<&str> =
652 document.content.split_whitespace().collect();
653
654 let intersection = query_terms.intersection(&doc_terms).count();
655 let union = query_terms.union(&doc_terms).count();
656
657 let jaccard = if union == 0 {
658 0.0
659 } else {
660 intersection as f32 / union as f32
661 };
662
663 Ok(RelevanceSignal {
664 signal_type: SignalType::TextualRelevance,
665 value: jaccard,
666 confidence: 0.7,
667 metadata: SignalMetadata {
668 source: "textual_analysis".to_string(),
669 extraction_time_ms: 2,
670 features: [
671 ("intersection".to_string(), intersection as f32),
672 ("union".to_string(), union as f32),
673 ]
674 .iter()
675 .cloned()
676 .collect(),
677 warnings: Vec::new(),
678 },
679 })
680 }
681
682 fn signal_type(&self) -> SignalType {
683 SignalType::TextualRelevance
684 }
685
686 fn get_config(&self) -> SignalExtractorConfig {
687 SignalExtractorConfig {
688 name: "TextualRelevanceExtractor".to_string(),
689 version: "1.0".to_string(),
690 features: vec!["term_overlap".to_string(), "jaccard_similarity".to_string()],
691 performance: PerformanceMetrics {
692 avg_extraction_time_ms: 2.0,
693 accuracy: 0.7,
694 memory_usage_mb: 0.05,
695 },
696 }
697 }
698}
699
/// Extracts a recency signal from the document's `timestamp` metadata field.
struct DocumentFreshnessExtractor;

impl DocumentFreshnessExtractor {
    /// Stateless constructor.
    fn new() -> Self {
        Self
    }
}
708
709impl SignalExtractor for DocumentFreshnessExtractor {
710 fn extract_signal(
711 &self,
712 _query: &str,
713 document: &SearchResult,
714 context: &RetrievalContext,
715 ) -> RragResult<RelevanceSignal> {
716 let doc_timestamp = document
718 .metadata
719 .get("timestamp")
720 .and_then(|v| v.as_str())
721 .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
722 .map(|dt| dt.with_timezone(&chrono::Utc))
723 .unwrap_or_else(|| context.timestamp - chrono::Duration::days(30)); let age_hours = (context.timestamp - doc_timestamp).num_hours() as f32;
726
727 let freshness = (-age_hours / (24.0 * 7.0)).exp().min(1.0); Ok(RelevanceSignal {
731 signal_type: SignalType::DocumentFreshness,
732 value: freshness,
733 confidence: 0.9,
734 metadata: SignalMetadata {
735 source: "document_metadata".to_string(),
736 extraction_time_ms: 1,
737 features: [("age_hours".to_string(), age_hours)]
738 .iter()
739 .cloned()
740 .collect(),
741 warnings: Vec::new(),
742 },
743 })
744 }
745
746 fn signal_type(&self) -> SignalType {
747 SignalType::DocumentFreshness
748 }
749
750 fn get_config(&self) -> SignalExtractorConfig {
751 SignalExtractorConfig {
752 name: "DocumentFreshnessExtractor".to_string(),
753 version: "1.0".to_string(),
754 features: vec!["temporal_decay".to_string()],
755 performance: PerformanceMetrics {
756 avg_extraction_time_ms: 1.0,
757 accuracy: 0.9,
758 memory_usage_mb: 0.01,
759 },
760 }
761 }
762}
763
/// Extracts a heuristic content-quality signal from surface statistics.
struct DocumentQualityExtractor;

impl DocumentQualityExtractor {
    /// Stateless constructor.
    fn new() -> Self {
        Self
    }
}
772
773impl SignalExtractor for DocumentQualityExtractor {
774 fn extract_signal(
775 &self,
776 _query: &str,
777 document: &SearchResult,
778 _context: &RetrievalContext,
779 ) -> RragResult<RelevanceSignal> {
780 let length = document.content.len() as f32;
782 let words = document.content.split_whitespace().count() as f32;
783 let sentences = document.content.split('.').count() as f32;
784
785 let length_score = if length > 100.0 && length < 5000.0 {
787 1.0
788 } else {
789 0.5
790 };
791 let avg_word_length = if words > 0.0 { length / words } else { 0.0 };
792 let word_length_score = if avg_word_length > 3.0 && avg_word_length < 15.0 {
793 1.0
794 } else {
795 0.7
796 };
797 let sentence_length = if sentences > 0.0 {
798 words / sentences
799 } else {
800 0.0
801 };
802 let sentence_score = if sentence_length > 5.0 && sentence_length < 30.0 {
803 1.0
804 } else {
805 0.8
806 };
807
808 let quality_score = (length_score + word_length_score + sentence_score) / 3.0;
809
810 Ok(RelevanceSignal {
811 signal_type: SignalType::DocumentQuality,
812 value: quality_score,
813 confidence: 0.6,
814 metadata: SignalMetadata {
815 source: "quality_analysis".to_string(),
816 extraction_time_ms: 3,
817 features: [
818 ("length".to_string(), length),
819 ("word_count".to_string(), words),
820 ("sentence_count".to_string(), sentences),
821 ("avg_word_length".to_string(), avg_word_length),
822 ("avg_sentence_length".to_string(), sentence_length),
823 ]
824 .iter()
825 .cloned()
826 .collect(),
827 warnings: Vec::new(),
828 },
829 })
830 }
831
832 fn signal_type(&self) -> SignalType {
833 SignalType::DocumentQuality
834 }
835
836 fn get_config(&self) -> SignalExtractorConfig {
837 SignalExtractorConfig {
838 name: "DocumentQualityExtractor".to_string(),
839 version: "1.0".to_string(),
840 features: vec![
841 "length_analysis".to_string(),
842 "structural_analysis".to_string(),
843 ],
844 performance: PerformanceMetrics {
845 avg_extraction_time_ms: 3.0,
846 accuracy: 0.6,
847 memory_usage_mb: 0.02,
848 },
849 }
850 }
851}
852
/// Generates a stub extractor type: a unit struct whose `extract_signal`
/// always returns the given constant value with confidence 0.5 and a
/// "Placeholder implementation" warning. Used for signals whose real
/// implementation does not exist yet.
macro_rules! impl_placeholder_extractor {
    ($name:ident, $signal_type:expr, $default_value:expr) => {
        struct $name;

        impl $name {
            // Stateless constructor, mirroring the hand-written extractors.
            fn new() -> Self {
                Self
            }
        }

        impl SignalExtractor for $name {
            fn extract_signal(
                &self,
                _query: &str,
                _document: &SearchResult,
                _context: &RetrievalContext,
            ) -> RragResult<RelevanceSignal> {
                Ok(RelevanceSignal {
                    signal_type: $signal_type,
                    value: $default_value,
                    confidence: 0.5,
                    metadata: SignalMetadata {
                        source: "placeholder".to_string(),
                        extraction_time_ms: 1,
                        features: HashMap::new(),
                        warnings: vec!["Placeholder implementation".to_string()],
                    },
                })
            }

            fn signal_type(&self) -> SignalType {
                $signal_type
            }

            fn get_config(&self) -> SignalExtractorConfig {
                SignalExtractorConfig {
                    name: stringify!($name).to_string(),
                    version: "0.1".to_string(),
                    features: vec!["placeholder".to_string()],
                    performance: PerformanceMetrics {
                        avg_extraction_time_ms: 1.0,
                        accuracy: 0.5,
                        memory_usage_mb: 0.01,
                    },
                }
            }
        }
    };
}
903
// Stub extractors for signals without a real implementation yet; each
// returns a constant 0.5 (see `impl_placeholder_extractor`).
impl_placeholder_extractor!(
    DocumentAuthorityExtractor,
    SignalType::DocumentAuthority,
    0.5
);
impl_placeholder_extractor!(UserPreferenceExtractor, SignalType::UserPreference, 0.5);
impl_placeholder_extractor!(ClickThroughRateExtractor, SignalType::ClickThroughRate, 0.5);
impl_placeholder_extractor!(
    DocumentPopularityExtractor,
    SignalType::DocumentPopularity,
    0.5
);
impl_placeholder_extractor!(
    InteractionHistoryExtractor,
    SignalType::InteractionHistory,
    0.5
);
921
/// Placeholder extractor for domain-tagged custom signals.
struct DomainSpecificExtractor {
    /// Domain name carried into the emitted `SignalType::DomainSpecific`.
    domain: String,
}

impl DomainSpecificExtractor {
    /// Creates an extractor bound to the given domain name.
    fn new(domain: String) -> Self {
        Self { domain }
    }
}
931
932impl SignalExtractor for DomainSpecificExtractor {
933 fn extract_signal(
934 &self,
935 _query: &str,
936 _document: &SearchResult,
937 _context: &RetrievalContext,
938 ) -> RragResult<RelevanceSignal> {
939 Ok(RelevanceSignal {
940 signal_type: SignalType::DomainSpecific(self.domain.clone()),
941 value: 0.5,
942 confidence: 0.5,
943 metadata: SignalMetadata {
944 source: "domain_specific".to_string(),
945 extraction_time_ms: 1,
946 features: HashMap::new(),
947 warnings: vec!["Placeholder implementation".to_string()],
948 },
949 })
950 }
951
952 fn signal_type(&self) -> SignalType {
953 SignalType::DomainSpecific(self.domain.clone())
954 }
955
956 fn get_config(&self) -> SignalExtractorConfig {
957 SignalExtractorConfig {
958 name: format!("DomainSpecificExtractor({})", self.domain),
959 version: "0.1".to_string(),
960 features: vec!["domain_analysis".to_string()],
961 performance: PerformanceMetrics {
962 avg_extraction_time_ms: 1.0,
963 accuracy: 0.5,
964 memory_usage_mb: 0.01,
965 },
966 }
967 }
968}
969
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SearchResult;

    /// End-to-end smoke test: the default reranker produces a non-empty
    /// score map and a positive score for a document that overlaps the query.
    #[tokio::test]
    async fn test_multi_signal_reranking() {
        let config = MultiSignalConfig::default();
        let reranker = MultiSignalReranker::new(config);

        let results = vec![
            // High textual overlap with the query.
            SearchResult {
                id: "doc1".to_string(),
                content: "Machine learning is a subset of artificial intelligence that focuses on algorithms".to_string(),
                score: 0.8,
                rank: 0,
                metadata: HashMap::new(),
                embedding: None,
            },
            // Higher engine score but almost no textual overlap.
            SearchResult {
                id: "doc2".to_string(),
                content: "AI".to_string(),
                score: 0.9,
                rank: 1,
                metadata: HashMap::new(),
                embedding: None,
            },
        ];

        let query = "What is machine learning in artificial intelligence?";
        let reranked_scores = reranker.rerank(query, &results).await.unwrap();

        assert!(!reranked_scores.is_empty());
        assert!(reranked_scores.get(&0).unwrap_or(&0.0) > &0.0);
    }

    /// Min-max normalization maps the extremes of a two-signal set to
    /// exactly 0.0 and 1.0.
    #[test]
    fn test_signal_normalization() {
        let config = MultiSignalConfig::default();
        let reranker = MultiSignalReranker::new(config);

        let signals = vec![
            RelevanceSignal {
                signal_type: SignalType::SemanticRelevance,
                value: 0.1,
                confidence: 1.0,
                metadata: SignalMetadata {
                    source: "test".to_string(),
                    extraction_time_ms: 0,
                    features: HashMap::new(),
                    warnings: Vec::new(),
                },
            },
            RelevanceSignal {
                signal_type: SignalType::SemanticRelevance,
                value: 0.9,
                confidence: 1.0,
                metadata: SignalMetadata {
                    source: "test".to_string(),
                    extraction_time_ms: 0,
                    features: HashMap::new(),
                    warnings: Vec::new(),
                },
            },
        ];

        let normalized = reranker.normalize_min_max(&signals);
        assert_eq!(normalized[0].value, 0.0);
        assert_eq!(normalized[1].value, 1.0);
    }
}