1use crate::{RragResult, SearchResult};
8use std::collections::HashMap;
9
10pub struct MultiSignalReranker {
12 config: MultiSignalConfig,
14
15 signal_extractors: HashMap<SignalType, Box<dyn SignalExtractor>>,
17
18 signal_weights: HashMap<SignalType, f32>,
20
21 aggregation: SignalAggregation,
23}
24
25#[derive(Debug, Clone)]
27pub struct MultiSignalConfig {
28 pub enabled_signals: Vec<SignalType>,
30
31 pub signal_weights: HashMap<SignalType, SignalWeight>,
33
34 pub aggregation_method: SignalAggregation,
36
37 pub normalization: SignalNormalization,
39
40 pub min_signal_confidence: f32,
42
43 pub enable_adaptive_weights: bool,
45
46 pub learning_rate: f32,
48}
49
50impl Default for MultiSignalConfig {
51 fn default() -> Self {
52 let mut signal_weights = HashMap::new();
53 signal_weights.insert(SignalType::SemanticRelevance, SignalWeight::Fixed(0.3));
54 signal_weights.insert(SignalType::TextualRelevance, SignalWeight::Fixed(0.25));
55 signal_weights.insert(SignalType::DocumentFreshness, SignalWeight::Fixed(0.15));
56 signal_weights.insert(SignalType::DocumentAuthority, SignalWeight::Fixed(0.1));
57 signal_weights.insert(SignalType::DocumentQuality, SignalWeight::Fixed(0.1));
58 signal_weights.insert(SignalType::UserPreference, SignalWeight::Fixed(0.05));
59 signal_weights.insert(SignalType::ClickThroughRate, SignalWeight::Fixed(0.05));
60
61 Self {
62 enabled_signals: vec![
63 SignalType::SemanticRelevance,
64 SignalType::TextualRelevance,
65 SignalType::DocumentFreshness,
66 SignalType::DocumentQuality,
67 ],
68 signal_weights,
69 aggregation_method: SignalAggregation::WeightedSum,
70 normalization: SignalNormalization::MinMax,
71 min_signal_confidence: 0.1,
72 enable_adaptive_weights: false,
73 learning_rate: 0.01,
74 }
75 }
76}
77
/// Kind of relevance signal a [`SignalExtractor`] produces.
/// Hashable so it can key the extractor and weight maps.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum SignalType {
    /// Embedding/vector similarity between query and document.
    SemanticRelevance,
    /// Lexical term overlap between query and document.
    TextualRelevance,
    /// Recency of the document.
    DocumentFreshness,
    /// Authority/trust of the document's source.
    DocumentAuthority,
    /// Intrinsic content quality heuristics.
    DocumentQuality,
    /// Per-user preference match.
    UserPreference,
    /// Historical click-through rate.
    ClickThroughRate,
    /// Overall document popularity.
    DocumentPopularity,
    /// Signals derived from the user's interaction history.
    InteractionHistory,
    /// Custom signal scoped to a named domain.
    DomainSpecific(String),
}
102
/// Weighting policy for a signal. Cannot be `derive`d for `Debug`/`Clone`
/// because the `QueryDependent` variant holds a boxed closure (see the manual
/// impls below).
pub enum SignalWeight {
    /// Constant weight.
    Fixed(f32),
    /// Weight computed from the query text at rerank time.
    QueryDependent(Box<dyn Fn(&str) -> f32 + Send + Sync>),
    /// Weight maintained by the reranker's runtime weight table.
    Learned,
    /// Weight that starts at the given value and may be adjusted.
    Adaptive(f32),
}
114
115impl std::fmt::Debug for SignalWeight {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 match self {
118 Self::Fixed(w) => write!(f, "Fixed({})", w),
119 Self::QueryDependent(_) => write!(f, "QueryDependent(<function>)"),
120 Self::Learned => write!(f, "Learned"),
121 Self::Adaptive(w) => write!(f, "Adaptive({})", w),
122 }
123 }
124}
125
impl Clone for SignalWeight {
    /// Manual `Clone` because the `QueryDependent` closure is not cloneable.
    fn clone(&self) -> Self {
        match self {
            Self::Fixed(w) => Self::Fixed(*w),
            // The boxed closure cannot be cloned, so fall back to a fixed
            // neutral weight. NOTE(review): this silently changes the weight
            // policy of a cloned config — consider storing the closure in an
            // `Arc` so clones share it instead.
            Self::QueryDependent(_) => Self::Fixed(0.5),
            Self::Learned => Self::Learned,
            Self::Adaptive(w) => Self::Adaptive(*w),
        }
    }
}
136
/// Strategy for combining weighted signal values into a single score.
#[derive(Debug, Clone)]
pub enum SignalAggregation {
    /// Sum of value × weight over all signals.
    WeightedSum,
    /// Weighted sum divided by the total weight.
    WeightedAverage,
    /// Maximum signal value (weights ignored).
    Max,
    /// Minimum signal value (weights ignored).
    Min,
    /// Learned combination; currently aggregated like a weighted average.
    LearnedCombination,
    /// Named custom strategy; currently falls back to an unweighted mean.
    Custom(String),
}
153
/// How raw signal values are normalized across documents before aggregation.
#[derive(Debug, Clone)]
pub enum SignalNormalization {
    /// Linear rescale of each signal's values to [0, 1].
    MinMax,
    /// Z-score standardization followed by a sigmoid squash into (0, 1).
    ZScore,
    /// Rank-based: best value maps to 1.0, decreasing linearly by rank.
    Rank,
    /// Independent sigmoid squash of each value into (0, 1).
    Sigmoid,
    /// Use raw values unchanged.
    None,
}
168
/// One extracted relevance signal for a single (query, document) pair.
#[derive(Debug, Clone)]
pub struct RelevanceSignal {
    /// Which signal this value represents.
    pub signal_type: SignalType,

    /// The signal value (scale depends on the extractor until normalization).
    pub value: f32,

    /// Extractor's confidence in this value; compared against
    /// `MultiSignalConfig::min_signal_confidence` during aggregation.
    pub confidence: f32,

    /// Provenance and diagnostic details for this signal.
    pub metadata: SignalMetadata,
}
184
/// Provenance and diagnostics attached to a [`RelevanceSignal`].
#[derive(Debug, Clone)]
pub struct SignalMetadata {
    /// Identifier of the component that produced the signal.
    pub source: String,

    /// Wall-clock time spent extracting the signal, in milliseconds.
    pub extraction_time_ms: u64,

    /// Named intermediate feature values used to compute the signal.
    pub features: HashMap<String, f32>,

    /// Non-fatal issues encountered during extraction.
    pub warnings: Vec<String>,
}
200
/// Extracts one kind of relevance signal for (query, document) pairs.
/// Implementations must be thread-safe (`Send + Sync`) because they are
/// stored behind trait objects in the reranker.
pub trait SignalExtractor: Send + Sync {
    /// Extracts this extractor's signal for a single document.
    fn extract_signal(
        &self,
        query: &str,
        document: &SearchResult,
        context: &RetrievalContext,
    ) -> RragResult<RelevanceSignal>;

    /// Extracts signals for a batch of documents. The default implementation
    /// maps `extract_signal` over the slice, short-circuiting on the first
    /// error; the output is parallel to `documents`.
    fn extract_batch(
        &self,
        query: &str,
        documents: &[SearchResult],
        context: &RetrievalContext,
    ) -> RragResult<Vec<RelevanceSignal>> {
        documents
            .iter()
            .map(|doc| self.extract_signal(query, doc, context))
            .collect()
    }

    /// The signal type this extractor produces.
    fn signal_type(&self) -> SignalType;

    /// Static metadata describing this extractor implementation.
    fn get_config(&self) -> SignalExtractorConfig;
}
230
/// Static, self-reported metadata about a [`SignalExtractor`] implementation.
#[derive(Debug, Clone)]
pub struct SignalExtractorConfig {
    /// Human-readable extractor name.
    pub name: String,

    /// Implementation version string.
    pub version: String,

    /// Names of the features/techniques the extractor uses.
    pub features: Vec<String>,

    /// Self-reported performance characteristics.
    pub performance: PerformanceMetrics,
}
246
/// Self-reported performance characteristics of a signal extractor.
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
    /// Average time to extract one signal, in milliseconds.
    pub avg_extraction_time_ms: f32,

    /// Estimated accuracy of the signal, in [0, 1].
    pub accuracy: f32,

    /// Approximate memory footprint, in megabytes.
    pub memory_usage_mb: f32,
}
259
/// Per-request context made available to every signal extractor.
#[derive(Debug, Clone)]
pub struct RetrievalContext {
    /// Identifier of the requesting user, when known.
    pub user_id: Option<String>,

    /// Identifier of the current session, when known.
    pub session_id: Option<String>,

    /// Reference time for the request; freshness is computed relative to it.
    pub timestamp: chrono::DateTime<chrono::Utc>,

    /// Classified intent of the query, when available.
    pub query_intent: Option<String>,

    /// Named user preference scores.
    pub user_preferences: HashMap<String, f32>,

    /// Past interactions relevant to this request.
    pub interaction_history: Vec<InteractionRecord>,
}
281
/// A single past user interaction with a document.
#[derive(Debug, Clone)]
pub struct InteractionRecord {
    /// Document the interaction refers to.
    pub document_id: String,

    /// Kind of interaction (e.g. click, dwell); free-form string.
    pub interaction_type: String,

    /// When the interaction happened.
    pub timestamp: chrono::DateTime<chrono::Utc>,

    /// Numeric strength/value of the interaction.
    pub value: f32,
}
297
298impl MultiSignalReranker {
299 pub fn new(config: MultiSignalConfig) -> Self {
301 let mut reranker = Self {
302 config: config.clone(),
303 signal_extractors: HashMap::new(),
304 signal_weights: HashMap::new(),
305 aggregation: config.aggregation_method.clone(),
306 };
307
308 reranker.initialize_extractors();
310
311 reranker.initialize_weights();
313
314 reranker
315 }
316
317 fn initialize_extractors(&mut self) {
319 for signal_type in &self.config.enabled_signals {
320 let extractor: Box<dyn SignalExtractor> = match signal_type {
321 SignalType::SemanticRelevance => Box::new(SemanticRelevanceExtractor::new()),
322 SignalType::TextualRelevance => Box::new(TextualRelevanceExtractor::new()),
323 SignalType::DocumentFreshness => Box::new(DocumentFreshnessExtractor::new()),
324 SignalType::DocumentAuthority => Box::new(DocumentAuthorityExtractor::new()),
325 SignalType::DocumentQuality => Box::new(DocumentQualityExtractor::new()),
326 SignalType::UserPreference => Box::new(UserPreferenceExtractor::new()),
327 SignalType::ClickThroughRate => Box::new(ClickThroughRateExtractor::new()),
328 SignalType::DocumentPopularity => Box::new(DocumentPopularityExtractor::new()),
329 SignalType::InteractionHistory => Box::new(InteractionHistoryExtractor::new()),
330 SignalType::DomainSpecific(domain) => {
331 Box::new(DomainSpecificExtractor::new(domain.clone()))
332 }
333 };
334
335 self.signal_extractors
336 .insert(signal_type.clone(), extractor);
337 }
338 }
339
340 fn initialize_weights(&mut self) {
342 for (signal_type, weight_config) in &self.config.signal_weights {
343 let weight = match weight_config {
344 SignalWeight::Fixed(w) => *w,
345 SignalWeight::Adaptive(w) => *w,
346 SignalWeight::Learned => 1.0 / self.config.signal_weights.len() as f32, SignalWeight::QueryDependent(_) => 1.0, };
349
350 self.signal_weights.insert(signal_type.clone(), weight);
351 }
352 }
353
354 pub async fn rerank(
356 &self,
357 query: &str,
358 results: &[SearchResult],
359 ) -> RragResult<HashMap<usize, f32>> {
360 let context = RetrievalContext {
361 user_id: None,
362 session_id: None,
363 timestamp: chrono::Utc::now(),
364 query_intent: None,
365 user_preferences: HashMap::new(),
366 interaction_history: Vec::new(),
367 };
368
369 self.rerank_with_context(query, results, &context).await
370 }
371
372 pub async fn rerank_with_context(
374 &self,
375 query: &str,
376 results: &[SearchResult],
377 context: &RetrievalContext,
378 ) -> RragResult<HashMap<usize, f32>> {
379 let mut final_scores = HashMap::new();
380
381 let mut all_signals: HashMap<SignalType, Vec<RelevanceSignal>> = HashMap::new();
383
384 for (signal_type, extractor) in &self.signal_extractors {
385 match extractor.extract_batch(query, results, context) {
386 Ok(signals) => {
387 all_signals.insert(signal_type.clone(), signals);
388 }
389 Err(e) => {
390 eprintln!("Warning: Failed to extract signal {:?}: {}", signal_type, e);
391 }
393 }
394 }
395
396 let normalized_signals = self.normalize_signals(all_signals)?;
398
399 for (doc_idx, _) in results.iter().enumerate() {
401 let mut signal_values = Vec::new();
402 let mut signal_weights = Vec::new();
403
404 for (signal_type, signals) in &normalized_signals {
405 if let Some(signal) = signals.get(doc_idx) {
406 if signal.confidence >= self.config.min_signal_confidence {
407 signal_values.push(signal.value);
408
409 let weight = self.get_signal_weight(signal_type, query, signal)?;
410 signal_weights.push(weight);
411 }
412 }
413 }
414
415 let final_score = self.aggregate_signals(&signal_values, &signal_weights)?;
417 final_scores.insert(doc_idx, final_score);
418 }
419
420 Ok(final_scores)
421 }
422
423 fn normalize_signals(
425 &self,
426 signals: HashMap<SignalType, Vec<RelevanceSignal>>,
427 ) -> RragResult<HashMap<SignalType, Vec<RelevanceSignal>>> {
428 let mut normalized = HashMap::new();
429
430 for (signal_type, signal_list) in signals {
431 let normalized_list = match self.config.normalization {
432 SignalNormalization::MinMax => self.normalize_min_max(&signal_list),
433 SignalNormalization::ZScore => self.normalize_z_score(&signal_list),
434 SignalNormalization::Rank => self.normalize_rank(&signal_list),
435 SignalNormalization::Sigmoid => self.normalize_sigmoid(&signal_list),
436 SignalNormalization::None => signal_list,
437 };
438
439 normalized.insert(signal_type, normalized_list);
440 }
441
442 Ok(normalized)
443 }
444
445 fn normalize_min_max(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
447 let values: Vec<f32> = signals.iter().map(|s| s.value).collect();
448 let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
449 let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
450
451 let range = max_val - min_val;
452 if range == 0.0 {
453 return signals.to_vec(); }
455
456 signals
457 .iter()
458 .map(|signal| {
459 let mut normalized = signal.clone();
460 normalized.value = (signal.value - min_val) / range;
461 normalized
462 })
463 .collect()
464 }
465
466 fn normalize_z_score(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
468 let values: Vec<f32> = signals.iter().map(|s| s.value).collect();
469 let mean = values.iter().sum::<f32>() / values.len() as f32;
470 let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / values.len() as f32;
471 let std_dev = variance.sqrt();
472
473 if std_dev == 0.0 {
474 return signals.to_vec();
475 }
476
477 signals
478 .iter()
479 .map(|signal| {
480 let mut normalized = signal.clone();
481 normalized.value = (signal.value - mean) / std_dev;
482 normalized.value = 1.0 / (1.0 + (-normalized.value).exp());
484 normalized
485 })
486 .collect()
487 }
488
489 fn normalize_rank(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
491 let mut indexed_signals: Vec<(usize, &RelevanceSignal)> =
492 signals.iter().enumerate().collect();
493
494 indexed_signals.sort_by(|a, b| {
495 b.1.value
496 .partial_cmp(&a.1.value)
497 .unwrap_or(std::cmp::Ordering::Equal)
498 });
499
500 let mut normalized = vec![signals[0].clone(); signals.len()];
501 for (rank, (original_idx, signal)) in indexed_signals.iter().enumerate() {
502 normalized[*original_idx] = (*signal).clone();
503 normalized[*original_idx].value = 1.0 - (rank as f32 / signals.len() as f32);
504 }
505
506 normalized
507 }
508
509 fn normalize_sigmoid(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
511 signals
512 .iter()
513 .map(|signal| {
514 let mut normalized = signal.clone();
515 normalized.value = 1.0 / (1.0 + (-signal.value).exp());
516 normalized
517 })
518 .collect()
519 }
520
521 fn get_signal_weight(
523 &self,
524 signal_type: &SignalType,
525 query: &str,
526 _signal: &RelevanceSignal,
527 ) -> RragResult<f32> {
528 if let Some(weight_config) = self.config.signal_weights.get(signal_type) {
529 match weight_config {
530 SignalWeight::Fixed(w) => Ok(*w),
531 SignalWeight::Adaptive(w) => Ok(*w),
532 SignalWeight::Learned => {
533 Ok(self.signal_weights.get(signal_type).copied().unwrap_or(1.0))
534 }
535 SignalWeight::QueryDependent(func) => Ok(func(query)),
536 }
537 } else {
538 Ok(1.0 / self.config.signal_weights.len() as f32) }
540 }
541
542 fn aggregate_signals(&self, values: &[f32], weights: &[f32]) -> RragResult<f32> {
544 if values.is_empty() {
545 return Ok(0.0);
546 }
547
548 match &self.aggregation {
549 SignalAggregation::WeightedSum => {
550 Ok(values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum())
551 }
552 SignalAggregation::WeightedAverage => {
553 let weighted_sum: f32 = values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum();
554 let weight_sum: f32 = weights.iter().sum();
555 Ok(if weight_sum > 0.0 {
556 weighted_sum / weight_sum
557 } else {
558 0.0
559 })
560 }
561 SignalAggregation::Max => Ok(values.iter().fold(0.0f32, |a, &b| a.max(b))),
562 SignalAggregation::Min => Ok(values.iter().fold(1.0f32, |a, &b| a.min(b))),
563 SignalAggregation::LearnedCombination => {
564 let weighted_sum: f32 = values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum();
566 let weight_sum: f32 = weights.iter().sum();
567 Ok(if weight_sum > 0.0 {
568 weighted_sum / weight_sum
569 } else {
570 0.0
571 })
572 }
573 SignalAggregation::Custom(_) => {
574 Ok(values.iter().sum::<f32>() / values.len() as f32)
576 }
577 }
578 }
579}
580
/// Extractor that reuses the upstream search engine's score as the
/// semantic-relevance signal.
struct SemanticRelevanceExtractor;

impl SemanticRelevanceExtractor {
    fn new() -> Self {
        Self
    }
}
592
593impl SignalExtractor for SemanticRelevanceExtractor {
594 fn extract_signal(
595 &self,
596 _query: &str,
597 document: &SearchResult,
598 _context: &RetrievalContext,
599 ) -> RragResult<RelevanceSignal> {
600 Ok(RelevanceSignal {
602 signal_type: SignalType::SemanticRelevance,
603 value: document.score,
604 confidence: 0.8,
605 metadata: SignalMetadata {
606 source: "search_engine".to_string(),
607 extraction_time_ms: 1,
608 features: HashMap::new(),
609 warnings: Vec::new(),
610 },
611 })
612 }
613
614 fn signal_type(&self) -> SignalType {
615 SignalType::SemanticRelevance
616 }
617
618 fn get_config(&self) -> SignalExtractorConfig {
619 SignalExtractorConfig {
620 name: "SemanticRelevanceExtractor".to_string(),
621 version: "1.0".to_string(),
622 features: vec!["vector_similarity".to_string()],
623 performance: PerformanceMetrics {
624 avg_extraction_time_ms: 1.0,
625 accuracy: 0.8,
626 memory_usage_mb: 0.1,
627 },
628 }
629 }
630}
631
/// Extractor that scores lexical overlap between query and document text.
struct TextualRelevanceExtractor;

impl TextualRelevanceExtractor {
    fn new() -> Self {
        Self
    }
}
640
641impl SignalExtractor for TextualRelevanceExtractor {
642 fn extract_signal(
643 &self,
644 query: &str,
645 document: &SearchResult,
646 _context: &RetrievalContext,
647 ) -> RragResult<RelevanceSignal> {
648 let query_terms: std::collections::HashSet<&str> = query.split_whitespace().collect();
650 let doc_terms: std::collections::HashSet<&str> =
651 document.content.split_whitespace().collect();
652
653 let intersection = query_terms.intersection(&doc_terms).count();
654 let union = query_terms.union(&doc_terms).count();
655
656 let jaccard = if union == 0 {
657 0.0
658 } else {
659 intersection as f32 / union as f32
660 };
661
662 Ok(RelevanceSignal {
663 signal_type: SignalType::TextualRelevance,
664 value: jaccard,
665 confidence: 0.7,
666 metadata: SignalMetadata {
667 source: "textual_analysis".to_string(),
668 extraction_time_ms: 2,
669 features: [
670 ("intersection".to_string(), intersection as f32),
671 ("union".to_string(), union as f32),
672 ]
673 .iter()
674 .cloned()
675 .collect(),
676 warnings: Vec::new(),
677 },
678 })
679 }
680
681 fn signal_type(&self) -> SignalType {
682 SignalType::TextualRelevance
683 }
684
685 fn get_config(&self) -> SignalExtractorConfig {
686 SignalExtractorConfig {
687 name: "TextualRelevanceExtractor".to_string(),
688 version: "1.0".to_string(),
689 features: vec!["term_overlap".to_string(), "jaccard_similarity".to_string()],
690 performance: PerformanceMetrics {
691 avg_extraction_time_ms: 2.0,
692 accuracy: 0.7,
693 memory_usage_mb: 0.05,
694 },
695 }
696 }
697}
698
/// Extractor that scores document recency with exponential time decay.
struct DocumentFreshnessExtractor;

impl DocumentFreshnessExtractor {
    fn new() -> Self {
        Self
    }
}
707
708impl SignalExtractor for DocumentFreshnessExtractor {
709 fn extract_signal(
710 &self,
711 _query: &str,
712 document: &SearchResult,
713 context: &RetrievalContext,
714 ) -> RragResult<RelevanceSignal> {
715 let doc_timestamp = document
717 .metadata
718 .get("timestamp")
719 .and_then(|v| v.as_str())
720 .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
721 .map(|dt| dt.with_timezone(&chrono::Utc))
722 .unwrap_or_else(|| context.timestamp - chrono::Duration::days(30)); let age_hours = (context.timestamp - doc_timestamp).num_hours() as f32;
725
726 let freshness = (-age_hours / (24.0 * 7.0)).exp().min(1.0); Ok(RelevanceSignal {
730 signal_type: SignalType::DocumentFreshness,
731 value: freshness,
732 confidence: 0.9,
733 metadata: SignalMetadata {
734 source: "document_metadata".to_string(),
735 extraction_time_ms: 1,
736 features: [("age_hours".to_string(), age_hours)]
737 .iter()
738 .cloned()
739 .collect(),
740 warnings: Vec::new(),
741 },
742 })
743 }
744
745 fn signal_type(&self) -> SignalType {
746 SignalType::DocumentFreshness
747 }
748
749 fn get_config(&self) -> SignalExtractorConfig {
750 SignalExtractorConfig {
751 name: "DocumentFreshnessExtractor".to_string(),
752 version: "1.0".to_string(),
753 features: vec!["temporal_decay".to_string()],
754 performance: PerformanceMetrics {
755 avg_extraction_time_ms: 1.0,
756 accuracy: 0.9,
757 memory_usage_mb: 0.01,
758 },
759 }
760 }
761}
762
/// Extractor that estimates content quality from simple surface statistics.
struct DocumentQualityExtractor;

impl DocumentQualityExtractor {
    fn new() -> Self {
        Self
    }
}
771
772impl SignalExtractor for DocumentQualityExtractor {
773 fn extract_signal(
774 &self,
775 _query: &str,
776 document: &SearchResult,
777 _context: &RetrievalContext,
778 ) -> RragResult<RelevanceSignal> {
779 let length = document.content.len() as f32;
781 let words = document.content.split_whitespace().count() as f32;
782 let sentences = document.content.split('.').count() as f32;
783
784 let length_score = if length > 100.0 && length < 5000.0 {
786 1.0
787 } else {
788 0.5
789 };
790 let avg_word_length = if words > 0.0 { length / words } else { 0.0 };
791 let word_length_score = if avg_word_length > 3.0 && avg_word_length < 15.0 {
792 1.0
793 } else {
794 0.7
795 };
796 let sentence_length = if sentences > 0.0 {
797 words / sentences
798 } else {
799 0.0
800 };
801 let sentence_score = if sentence_length > 5.0 && sentence_length < 30.0 {
802 1.0
803 } else {
804 0.8
805 };
806
807 let quality_score = (length_score + word_length_score + sentence_score) / 3.0;
808
809 Ok(RelevanceSignal {
810 signal_type: SignalType::DocumentQuality,
811 value: quality_score,
812 confidence: 0.6,
813 metadata: SignalMetadata {
814 source: "quality_analysis".to_string(),
815 extraction_time_ms: 3,
816 features: [
817 ("length".to_string(), length),
818 ("word_count".to_string(), words),
819 ("sentence_count".to_string(), sentences),
820 ("avg_word_length".to_string(), avg_word_length),
821 ("avg_sentence_length".to_string(), sentence_length),
822 ]
823 .iter()
824 .cloned()
825 .collect(),
826 warnings: Vec::new(),
827 },
828 })
829 }
830
831 fn signal_type(&self) -> SignalType {
832 SignalType::DocumentQuality
833 }
834
835 fn get_config(&self) -> SignalExtractorConfig {
836 SignalExtractorConfig {
837 name: "DocumentQualityExtractor".to_string(),
838 version: "1.0".to_string(),
839 features: vec![
840 "length_analysis".to_string(),
841 "structural_analysis".to_string(),
842 ],
843 performance: PerformanceMetrics {
844 avg_extraction_time_ms: 3.0,
845 accuracy: 0.6,
846 memory_usage_mb: 0.02,
847 },
848 }
849 }
850}
851
/// Generates a stub `SignalExtractor` named `$name` that always returns
/// `$default_value` with confidence 0.5 and a "Placeholder implementation"
/// warning. Used to register signal types whose real extractors are not
/// implemented yet.
macro_rules! impl_placeholder_extractor {
    ($name:ident, $signal_type:expr, $default_value:expr) => {
        struct $name;

        impl $name {
            fn new() -> Self {
                Self
            }
        }

        impl SignalExtractor for $name {
            // Constant-value placeholder; ignores query, document, and context.
            fn extract_signal(
                &self,
                _query: &str,
                _document: &SearchResult,
                _context: &RetrievalContext,
            ) -> RragResult<RelevanceSignal> {
                Ok(RelevanceSignal {
                    signal_type: $signal_type,
                    value: $default_value,
                    confidence: 0.5,
                    metadata: SignalMetadata {
                        source: "placeholder".to_string(),
                        extraction_time_ms: 1,
                        features: HashMap::new(),
                        warnings: vec!["Placeholder implementation".to_string()],
                    },
                })
            }

            fn signal_type(&self) -> SignalType {
                $signal_type
            }

            fn get_config(&self) -> SignalExtractorConfig {
                SignalExtractorConfig {
                    name: stringify!($name).to_string(),
                    version: "0.1".to_string(),
                    features: vec!["placeholder".to_string()],
                    performance: PerformanceMetrics {
                        avg_extraction_time_ms: 1.0,
                        accuracy: 0.5,
                        memory_usage_mb: 0.01,
                    },
                }
            }
        }
    };
}
902
// Placeholder extractors emitting a constant mid-scale signal (value 0.5,
// confidence 0.5) until real implementations land.
impl_placeholder_extractor!(
    DocumentAuthorityExtractor,
    SignalType::DocumentAuthority,
    0.5
);
impl_placeholder_extractor!(UserPreferenceExtractor, SignalType::UserPreference, 0.5);
impl_placeholder_extractor!(ClickThroughRateExtractor, SignalType::ClickThroughRate, 0.5);
impl_placeholder_extractor!(
    DocumentPopularityExtractor,
    SignalType::DocumentPopularity,
    0.5
);
impl_placeholder_extractor!(
    InteractionHistoryExtractor,
    SignalType::InteractionHistory,
    0.5
);
920
/// Placeholder extractor for a domain-scoped signal.
struct DomainSpecificExtractor {
    /// Domain name carried into the emitted `SignalType::DomainSpecific`.
    domain: String,
}

impl DomainSpecificExtractor {
    fn new(domain: String) -> Self {
        Self { domain }
    }
}
930
931impl SignalExtractor for DomainSpecificExtractor {
932 fn extract_signal(
933 &self,
934 _query: &str,
935 _document: &SearchResult,
936 _context: &RetrievalContext,
937 ) -> RragResult<RelevanceSignal> {
938 Ok(RelevanceSignal {
939 signal_type: SignalType::DomainSpecific(self.domain.clone()),
940 value: 0.5,
941 confidence: 0.5,
942 metadata: SignalMetadata {
943 source: "domain_specific".to_string(),
944 extraction_time_ms: 1,
945 features: HashMap::new(),
946 warnings: vec!["Placeholder implementation".to_string()],
947 },
948 })
949 }
950
951 fn signal_type(&self) -> SignalType {
952 SignalType::DomainSpecific(self.domain.clone())
953 }
954
955 fn get_config(&self) -> SignalExtractorConfig {
956 SignalExtractorConfig {
957 name: format!("DomainSpecificExtractor({})", self.domain),
958 version: "0.1".to_string(),
959 features: vec!["domain_analysis".to_string()],
960 performance: PerformanceMetrics {
961 avg_extraction_time_ms: 1.0,
962 accuracy: 0.5,
963 memory_usage_mb: 0.01,
964 },
965 }
966 }
967}
968
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SearchResult;

    /// End-to-end smoke test: reranking with the default config should score
    /// every document, and the query-relevant document should score above 0.
    #[tokio::test]
    async fn test_multi_signal_reranking() {
        let config = MultiSignalConfig::default();
        let reranker = MultiSignalReranker::new(config);

        let results = vec![
            // Long, on-topic document.
            SearchResult {
                id: "doc1".to_string(),
                content: "Machine learning is a subset of artificial intelligence that focuses on algorithms".to_string(),
                score: 0.8,
                rank: 0,
                metadata: HashMap::new(),
                embedding: None,
            },
            // Very short document with a higher upstream score.
            SearchResult {
                id: "doc2".to_string(),
                content: "AI".to_string(),
                score: 0.9,
                rank: 1,
                metadata: HashMap::new(),
                embedding: None,
            },
        ];

        let query = "What is machine learning in artificial intelligence?";
        let reranked_scores = reranker.rerank(query, &results).await.unwrap();

        assert!(!reranked_scores.is_empty());
        // The relevant document (index 0) must end up with a positive score.
        assert!(reranked_scores.get(&0).unwrap_or(&0.0) > &0.0);
    }

    /// Min-max normalization should map the extremes to exactly 0.0 and 1.0.
    #[test]
    fn test_signal_normalization() {
        let config = MultiSignalConfig::default();
        let reranker = MultiSignalReranker::new(config);

        let signals = vec![
            RelevanceSignal {
                signal_type: SignalType::SemanticRelevance,
                value: 0.1,
                confidence: 1.0,
                metadata: SignalMetadata {
                    source: "test".to_string(),
                    extraction_time_ms: 0,
                    features: HashMap::new(),
                    warnings: Vec::new(),
                },
            },
            RelevanceSignal {
                signal_type: SignalType::SemanticRelevance,
                value: 0.9,
                confidence: 1.0,
                metadata: SignalMetadata {
                    source: "test".to_string(),
                    extraction_time_ms: 0,
                    features: HashMap::new(),
                    warnings: Vec::new(),
                },
            },
        ];

        let normalized = reranker.normalize_min_max(&signals);
        assert_eq!(normalized[0].value, 0.0);
        assert_eq!(normalized[1].value, 1.0);
    }
}