rrag/reranking/
multi_signal.rs

1//! # Multi-Signal Reranking
2//!
3//! Combines multiple relevance signals beyond semantic similarity to improve
4//! retrieval accuracy. Includes signals like freshness, authority, click-through
5//! rates, document quality, and user preferences.
6
7use crate::{RragResult, SearchResult};
8use std::collections::HashMap;
9
10/// Multi-signal reranker that combines various relevance signals
11pub struct MultiSignalReranker {
12    /// Configuration
13    config: MultiSignalConfig,
14
15    /// Signal extractors
16    signal_extractors: HashMap<SignalType, Box<dyn SignalExtractor>>,
17
18    /// Signal weights (learned or configured)
19    signal_weights: HashMap<SignalType, f32>,
20
21    /// Signal aggregation method
22    aggregation: SignalAggregation,
23}
24
25/// Configuration for multi-signal reranking
26#[derive(Debug, Clone)]
27pub struct MultiSignalConfig {
28    /// Enabled signal types
29    pub enabled_signals: Vec<SignalType>,
30
31    /// Signal weights
32    pub signal_weights: HashMap<SignalType, SignalWeight>,
33
34    /// Aggregation method
35    pub aggregation_method: SignalAggregation,
36
37    /// Normalization method
38    pub normalization: SignalNormalization,
39
40    /// Minimum signal confidence
41    pub min_signal_confidence: f32,
42
43    /// Enable adaptive weighting
44    pub enable_adaptive_weights: bool,
45
46    /// Learning rate for adaptive weights
47    pub learning_rate: f32,
48}
49
50impl Default for MultiSignalConfig {
51    fn default() -> Self {
52        let mut signal_weights = HashMap::new();
53        signal_weights.insert(SignalType::SemanticRelevance, SignalWeight::Fixed(0.3));
54        signal_weights.insert(SignalType::TextualRelevance, SignalWeight::Fixed(0.25));
55        signal_weights.insert(SignalType::DocumentFreshness, SignalWeight::Fixed(0.15));
56        signal_weights.insert(SignalType::DocumentAuthority, SignalWeight::Fixed(0.1));
57        signal_weights.insert(SignalType::DocumentQuality, SignalWeight::Fixed(0.1));
58        signal_weights.insert(SignalType::UserPreference, SignalWeight::Fixed(0.05));
59        signal_weights.insert(SignalType::ClickThroughRate, SignalWeight::Fixed(0.05));
60
61        Self {
62            enabled_signals: vec![
63                SignalType::SemanticRelevance,
64                SignalType::TextualRelevance,
65                SignalType::DocumentFreshness,
66                SignalType::DocumentQuality,
67            ],
68            signal_weights,
69            aggregation_method: SignalAggregation::WeightedSum,
70            normalization: SignalNormalization::MinMax,
71            min_signal_confidence: 0.1,
72            enable_adaptive_weights: false,
73            learning_rate: 0.01,
74        }
75    }
76}
77
/// Kinds of relevance signal the reranker can combine.
///
/// Used as the key for extractors and weights, hence the `Hash`/`Eq` derives.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum SignalType {
    /// Semantic similarity between the query and the document.
    SemanticRelevance,
    /// Keyword/textual relevance between the query and the document.
    TextualRelevance,
    /// How recent the document is.
    DocumentFreshness,
    /// Authority/credibility of the document.
    DocumentAuthority,
    /// Intrinsic quality heuristics for the document.
    DocumentQuality,
    /// User preference signals.
    UserPreference,
    /// Click-through rate.
    ClickThroughRate,
    /// Document popularity.
    DocumentPopularity,
    /// History of query/document interactions.
    InteractionHistory,
    /// A named, domain-specific signal.
    DomainSpecific(String),
}
102
/// How the weight of a signal is determined.
///
/// `Debug` and `Clone` are implemented by hand because of the boxed closure
/// in `QueryDependent`.
pub enum SignalWeight {
    /// Constant weight.
    Fixed(f32),
    /// Weight computed from the query text each time a score is produced.
    QueryDependent(Box<dyn Fn(&str) -> f32 + Send + Sync>),
    /// Weight intended to be learned from training data.
    Learned,
    /// Weight intended to be updated from feedback; holds the current value.
    Adaptive(f32),
}
114
115impl std::fmt::Debug for SignalWeight {
116    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117        match self {
118            Self::Fixed(w) => write!(f, "Fixed({})", w),
119            Self::QueryDependent(_) => write!(f, "QueryDependent(<function>)"),
120            Self::Learned => write!(f, "Learned"),
121            Self::Adaptive(w) => write!(f, "Adaptive({})", w),
122        }
123    }
124}
125
126impl Clone for SignalWeight {
127    fn clone(&self) -> Self {
128        match self {
129            Self::Fixed(w) => Self::Fixed(*w),
130            Self::QueryDependent(_) => Self::Fixed(0.5), // Can't clone function, default to fixed
131            Self::Learned => Self::Learned,
132            Self::Adaptive(w) => Self::Adaptive(*w),
133        }
134    }
135}
136
/// Methods for combining a document's weighted signal values into one score.
///
/// Derives `PartialEq`/`Eq`/`Hash` so configurations can be compared, tested
/// and used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SignalAggregation {
    /// Sum of `value * weight` over all signals.
    WeightedSum,
    /// `sum(value * weight) / sum(weight)`.
    WeightedAverage,
    /// Largest signal value (weights ignored).
    Max,
    /// Smallest signal value (weights ignored).
    Min,
    /// Learning-to-rank combination.
    LearnedCombination,
    /// Named custom aggregation function.
    Custom(String),
}
153
/// Methods for rescaling one signal's values across the candidate documents
/// before aggregation.
///
/// Derives `PartialEq`/`Eq`/`Hash` so configurations can be compared and
/// tested.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SignalNormalization {
    /// Min-max scaling to the 0-1 range.
    MinMax,
    /// Z-score standardization (then squashed through a sigmoid).
    ZScore,
    /// Replace values by their rank position.
    Rank,
    /// Logistic sigmoid of the raw value.
    Sigmoid,
    /// Leave values unchanged.
    None,
}
168
169/// A relevance signal extracted from query-document pair
170#[derive(Debug, Clone)]
171pub struct RelevanceSignal {
172    /// Type of signal
173    pub signal_type: SignalType,
174
175    /// Signal value (typically 0-1)
176    pub value: f32,
177
178    /// Confidence in the signal (0-1)
179    pub confidence: f32,
180
181    /// Signal metadata
182    pub metadata: SignalMetadata,
183}
184
/// Provenance and diagnostics attached to a [`RelevanceSignal`].
#[derive(Debug, Clone)]
pub struct SignalMetadata {
    /// Short label for where the signal came from.
    pub source: String,

    /// Extraction time in milliseconds. The extractors in this file report
    /// fixed nominal values rather than a measured duration.
    pub extraction_time_ms: u64,

    /// Named intermediate features used to compute the signal.
    pub features: HashMap<String, f32>,

    /// Human-readable warnings or notes about the extraction.
    pub warnings: Vec<String>,
}
200
201/// Trait for extracting relevance signals
202pub trait SignalExtractor: Send + Sync {
203    /// Extract signal from query-document pair
204    fn extract_signal(
205        &self,
206        query: &str,
207        document: &SearchResult,
208        context: &RetrievalContext,
209    ) -> RragResult<RelevanceSignal>;
210
211    /// Extract signals for multiple documents in batch
212    fn extract_batch(
213        &self,
214        query: &str,
215        documents: &[SearchResult],
216        context: &RetrievalContext,
217    ) -> RragResult<Vec<RelevanceSignal>> {
218        documents
219            .iter()
220            .map(|doc| self.extract_signal(query, doc, context))
221            .collect()
222    }
223
224    /// Get signal type
225    fn signal_type(&self) -> SignalType;
226
227    /// Get extractor configuration
228    fn get_config(&self) -> SignalExtractorConfig;
229}
230
231/// Configuration for signal extractors
232#[derive(Debug, Clone)]
233pub struct SignalExtractorConfig {
234    /// Extractor name
235    pub name: String,
236
237    /// Extractor version
238    pub version: String,
239
240    /// Supported features
241    pub features: Vec<String>,
242
243    /// Performance characteristics
244    pub performance: PerformanceMetrics,
245}
246
/// Performance characteristics reported by a signal extractor.
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
    /// Average extraction time, in milliseconds.
    pub avg_extraction_time_ms: f32,

    /// Accuracy/precision of the produced signal.
    pub accuracy: f32,

    /// Memory usage, in megabytes.
    pub memory_usage_mb: f32,
}
259
260/// Context for signal extraction
261#[derive(Debug, Clone)]
262pub struct RetrievalContext {
263    /// User identifier (if available)
264    pub user_id: Option<String>,
265
266    /// Session information
267    pub session_id: Option<String>,
268
269    /// Query timestamp
270    pub timestamp: chrono::DateTime<chrono::Utc>,
271
272    /// Query type/intent
273    pub query_intent: Option<String>,
274
275    /// User preferences
276    pub user_preferences: HashMap<String, f32>,
277
278    /// Historical interactions
279    pub interaction_history: Vec<InteractionRecord>,
280}
281
282/// Historical interaction record
283#[derive(Debug, Clone)]
284pub struct InteractionRecord {
285    /// Document ID
286    pub document_id: String,
287
288    /// Interaction type (click, dwell, etc.)
289    pub interaction_type: String,
290
291    /// Interaction timestamp
292    pub timestamp: chrono::DateTime<chrono::Utc>,
293
294    /// Interaction value/strength
295    pub value: f32,
296}
297
298impl MultiSignalReranker {
299    /// Create a new multi-signal reranker
300    pub fn new(config: MultiSignalConfig) -> Self {
301        let mut reranker = Self {
302            config: config.clone(),
303            signal_extractors: HashMap::new(),
304            signal_weights: HashMap::new(),
305            aggregation: config.aggregation_method.clone(),
306        };
307
308        // Initialize signal extractors
309        reranker.initialize_extractors();
310
311        // Initialize weights
312        reranker.initialize_weights();
313
314        reranker
315    }
316
317    /// Initialize signal extractors based on configuration
318    fn initialize_extractors(&mut self) {
319        for signal_type in &self.config.enabled_signals {
320            let extractor: Box<dyn SignalExtractor> = match signal_type {
321                SignalType::SemanticRelevance => Box::new(SemanticRelevanceExtractor::new()),
322                SignalType::TextualRelevance => Box::new(TextualRelevanceExtractor::new()),
323                SignalType::DocumentFreshness => Box::new(DocumentFreshnessExtractor::new()),
324                SignalType::DocumentAuthority => Box::new(DocumentAuthorityExtractor::new()),
325                SignalType::DocumentQuality => Box::new(DocumentQualityExtractor::new()),
326                SignalType::UserPreference => Box::new(UserPreferenceExtractor::new()),
327                SignalType::ClickThroughRate => Box::new(ClickThroughRateExtractor::new()),
328                SignalType::DocumentPopularity => Box::new(DocumentPopularityExtractor::new()),
329                SignalType::InteractionHistory => Box::new(InteractionHistoryExtractor::new()),
330                SignalType::DomainSpecific(domain) => {
331                    Box::new(DomainSpecificExtractor::new(domain.clone()))
332                }
333            };
334
335            self.signal_extractors
336                .insert(signal_type.clone(), extractor);
337        }
338    }
339
340    /// Initialize signal weights
341    fn initialize_weights(&mut self) {
342        for (signal_type, weight_config) in &self.config.signal_weights {
343            let weight = match weight_config {
344                SignalWeight::Fixed(w) => *w,
345                SignalWeight::Adaptive(w) => *w,
346                SignalWeight::Learned => 1.0 / self.config.signal_weights.len() as f32, // Default uniform
347                SignalWeight::QueryDependent(_) => 1.0, // Will be computed per query
348            };
349
350            self.signal_weights.insert(signal_type.clone(), weight);
351        }
352    }
353
354    /// Rerank search results using multiple signals
355    pub async fn rerank(
356        &self,
357        query: &str,
358        results: &[SearchResult],
359    ) -> RragResult<HashMap<usize, f32>> {
360        let context = RetrievalContext {
361            user_id: None,
362            session_id: None,
363            timestamp: chrono::Utc::now(),
364            query_intent: None,
365            user_preferences: HashMap::new(),
366            interaction_history: Vec::new(),
367        };
368
369        self.rerank_with_context(query, results, &context).await
370    }
371
372    /// Rerank with full context information
373    pub async fn rerank_with_context(
374        &self,
375        query: &str,
376        results: &[SearchResult],
377        context: &RetrievalContext,
378    ) -> RragResult<HashMap<usize, f32>> {
379        let mut final_scores = HashMap::new();
380
381        // Extract all signals for all documents
382        let mut all_signals: HashMap<SignalType, Vec<RelevanceSignal>> = HashMap::new();
383
384        for (signal_type, extractor) in &self.signal_extractors {
385            match extractor.extract_batch(query, results, context) {
386                Ok(signals) => {
387                    all_signals.insert(signal_type.clone(), signals);
388                }
389                Err(e) => {
390                    eprintln!("Warning: Failed to extract signal {:?}: {}", signal_type, e);
391                    // Continue with other signals
392                }
393            }
394        }
395
396        // Normalize signals if needed
397        let normalized_signals = self.normalize_signals(all_signals)?;
398
399        // Compute final scores for each document
400        for (doc_idx, _) in results.iter().enumerate() {
401            let mut signal_values = Vec::new();
402            let mut signal_weights = Vec::new();
403
404            for (signal_type, signals) in &normalized_signals {
405                if let Some(signal) = signals.get(doc_idx) {
406                    if signal.confidence >= self.config.min_signal_confidence {
407                        signal_values.push(signal.value);
408
409                        let weight = self.get_signal_weight(signal_type, query, signal)?;
410                        signal_weights.push(weight);
411                    }
412                }
413            }
414
415            // Aggregate signals
416            let final_score = self.aggregate_signals(&signal_values, &signal_weights)?;
417            final_scores.insert(doc_idx, final_score);
418        }
419
420        Ok(final_scores)
421    }
422
423    /// Normalize signals based on configuration
424    fn normalize_signals(
425        &self,
426        signals: HashMap<SignalType, Vec<RelevanceSignal>>,
427    ) -> RragResult<HashMap<SignalType, Vec<RelevanceSignal>>> {
428        let mut normalized = HashMap::new();
429
430        for (signal_type, signal_list) in signals {
431            let normalized_list = match self.config.normalization {
432                SignalNormalization::MinMax => self.normalize_min_max(&signal_list),
433                SignalNormalization::ZScore => self.normalize_z_score(&signal_list),
434                SignalNormalization::Rank => self.normalize_rank(&signal_list),
435                SignalNormalization::Sigmoid => self.normalize_sigmoid(&signal_list),
436                SignalNormalization::None => signal_list,
437            };
438
439            normalized.insert(signal_type, normalized_list);
440        }
441
442        Ok(normalized)
443    }
444
445    /// Min-max normalization
446    fn normalize_min_max(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
447        let values: Vec<f32> = signals.iter().map(|s| s.value).collect();
448        let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
449        let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
450
451        let range = max_val - min_val;
452        if range == 0.0 {
453            return signals.to_vec(); // No normalization needed
454        }
455
456        signals
457            .iter()
458            .map(|signal| {
459                let mut normalized = signal.clone();
460                normalized.value = (signal.value - min_val) / range;
461                normalized
462            })
463            .collect()
464    }
465
466    /// Z-score normalization
467    fn normalize_z_score(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
468        let values: Vec<f32> = signals.iter().map(|s| s.value).collect();
469        let mean = values.iter().sum::<f32>() / values.len() as f32;
470        let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / values.len() as f32;
471        let std_dev = variance.sqrt();
472
473        if std_dev == 0.0 {
474            return signals.to_vec();
475        }
476
477        signals
478            .iter()
479            .map(|signal| {
480                let mut normalized = signal.clone();
481                normalized.value = (signal.value - mean) / std_dev;
482                // Convert to 0-1 range using sigmoid
483                normalized.value = 1.0 / (1.0 + (-normalized.value).exp());
484                normalized
485            })
486            .collect()
487    }
488
489    /// Rank normalization
490    fn normalize_rank(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
491        let mut indexed_signals: Vec<(usize, &RelevanceSignal)> =
492            signals.iter().enumerate().collect();
493
494        indexed_signals.sort_by(|a, b| {
495            b.1.value
496                .partial_cmp(&a.1.value)
497                .unwrap_or(std::cmp::Ordering::Equal)
498        });
499
500        let mut normalized = vec![signals[0].clone(); signals.len()];
501        for (rank, (original_idx, signal)) in indexed_signals.iter().enumerate() {
502            normalized[*original_idx] = (*signal).clone();
503            normalized[*original_idx].value = 1.0 - (rank as f32 / signals.len() as f32);
504        }
505
506        normalized
507    }
508
509    /// Sigmoid normalization
510    fn normalize_sigmoid(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
511        signals
512            .iter()
513            .map(|signal| {
514                let mut normalized = signal.clone();
515                normalized.value = 1.0 / (1.0 + (-signal.value).exp());
516                normalized
517            })
518            .collect()
519    }
520
521    /// Get weight for a specific signal
522    fn get_signal_weight(
523        &self,
524        signal_type: &SignalType,
525        query: &str,
526        _signal: &RelevanceSignal,
527    ) -> RragResult<f32> {
528        if let Some(weight_config) = self.config.signal_weights.get(signal_type) {
529            match weight_config {
530                SignalWeight::Fixed(w) => Ok(*w),
531                SignalWeight::Adaptive(w) => Ok(*w),
532                SignalWeight::Learned => {
533                    Ok(self.signal_weights.get(signal_type).copied().unwrap_or(1.0))
534                }
535                SignalWeight::QueryDependent(func) => Ok(func(query)),
536            }
537        } else {
538            Ok(1.0 / self.config.signal_weights.len() as f32) // Default uniform weight
539        }
540    }
541
542    /// Aggregate multiple signals into final score
543    fn aggregate_signals(&self, values: &[f32], weights: &[f32]) -> RragResult<f32> {
544        if values.is_empty() {
545            return Ok(0.0);
546        }
547
548        match &self.aggregation {
549            SignalAggregation::WeightedSum => {
550                Ok(values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum())
551            }
552            SignalAggregation::WeightedAverage => {
553                let weighted_sum: f32 = values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum();
554                let weight_sum: f32 = weights.iter().sum();
555                Ok(if weight_sum > 0.0 {
556                    weighted_sum / weight_sum
557                } else {
558                    0.0
559                })
560            }
561            SignalAggregation::Max => Ok(values.iter().fold(0.0f32, |a, &b| a.max(b))),
562            SignalAggregation::Min => Ok(values.iter().fold(1.0f32, |a, &b| a.min(b))),
563            SignalAggregation::LearnedCombination => {
564                // Would use a learned model - for now, use weighted average
565                let weighted_sum: f32 = values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum();
566                let weight_sum: f32 = weights.iter().sum();
567                Ok(if weight_sum > 0.0 {
568                    weighted_sum / weight_sum
569                } else {
570                    0.0
571                })
572            }
573            SignalAggregation::Custom(_) => {
574                // Custom aggregation would be implemented here
575                Ok(values.iter().sum::<f32>() / values.len() as f32)
576            }
577        }
578    }
579}
580
// Signal extractor implementations. Only the semantic, textual, freshness and
// quality extractors are fully implemented; the remaining ones are placeholders.
583
584/// Extractor for semantic relevance signals
585struct SemanticRelevanceExtractor;
586
587impl SemanticRelevanceExtractor {
588    fn new() -> Self {
589        Self
590    }
591}
592
593impl SignalExtractor for SemanticRelevanceExtractor {
594    fn extract_signal(
595        &self,
596        _query: &str,
597        document: &SearchResult,
598        _context: &RetrievalContext,
599    ) -> RragResult<RelevanceSignal> {
600        // Use the existing search score as semantic relevance
601        Ok(RelevanceSignal {
602            signal_type: SignalType::SemanticRelevance,
603            value: document.score,
604            confidence: 0.8,
605            metadata: SignalMetadata {
606                source: "search_engine".to_string(),
607                extraction_time_ms: 1,
608                features: HashMap::new(),
609                warnings: Vec::new(),
610            },
611        })
612    }
613
614    fn signal_type(&self) -> SignalType {
615        SignalType::SemanticRelevance
616    }
617
618    fn get_config(&self) -> SignalExtractorConfig {
619        SignalExtractorConfig {
620            name: "SemanticRelevanceExtractor".to_string(),
621            version: "1.0".to_string(),
622            features: vec!["vector_similarity".to_string()],
623            performance: PerformanceMetrics {
624                avg_extraction_time_ms: 1.0,
625                accuracy: 0.8,
626                memory_usage_mb: 0.1,
627            },
628        }
629    }
630}
631
632/// Extractor for textual relevance (BM25-style)
633struct TextualRelevanceExtractor;
634
635impl TextualRelevanceExtractor {
636    fn new() -> Self {
637        Self
638    }
639}
640
641impl SignalExtractor for TextualRelevanceExtractor {
642    fn extract_signal(
643        &self,
644        query: &str,
645        document: &SearchResult,
646        _context: &RetrievalContext,
647    ) -> RragResult<RelevanceSignal> {
648        // Simple textual relevance based on term overlap
649        let query_terms: std::collections::HashSet<&str> = query.split_whitespace().collect();
650        let doc_terms: std::collections::HashSet<&str> =
651            document.content.split_whitespace().collect();
652
653        let intersection = query_terms.intersection(&doc_terms).count();
654        let union = query_terms.union(&doc_terms).count();
655
656        let jaccard = if union == 0 {
657            0.0
658        } else {
659            intersection as f32 / union as f32
660        };
661
662        Ok(RelevanceSignal {
663            signal_type: SignalType::TextualRelevance,
664            value: jaccard,
665            confidence: 0.7,
666            metadata: SignalMetadata {
667                source: "textual_analysis".to_string(),
668                extraction_time_ms: 2,
669                features: [
670                    ("intersection".to_string(), intersection as f32),
671                    ("union".to_string(), union as f32),
672                ]
673                .iter()
674                .cloned()
675                .collect(),
676                warnings: Vec::new(),
677            },
678        })
679    }
680
681    fn signal_type(&self) -> SignalType {
682        SignalType::TextualRelevance
683    }
684
685    fn get_config(&self) -> SignalExtractorConfig {
686        SignalExtractorConfig {
687            name: "TextualRelevanceExtractor".to_string(),
688            version: "1.0".to_string(),
689            features: vec!["term_overlap".to_string(), "jaccard_similarity".to_string()],
690            performance: PerformanceMetrics {
691                avg_extraction_time_ms: 2.0,
692                accuracy: 0.7,
693                memory_usage_mb: 0.05,
694            },
695        }
696    }
697}
698
699/// Extractor for document freshness
700struct DocumentFreshnessExtractor;
701
702impl DocumentFreshnessExtractor {
703    fn new() -> Self {
704        Self
705    }
706}
707
708impl SignalExtractor for DocumentFreshnessExtractor {
709    fn extract_signal(
710        &self,
711        _query: &str,
712        document: &SearchResult,
713        context: &RetrievalContext,
714    ) -> RragResult<RelevanceSignal> {
715        // Extract timestamp from document metadata or use current time as fallback
716        let doc_timestamp = document
717            .metadata
718            .get("timestamp")
719            .and_then(|v| v.as_str())
720            .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
721            .map(|dt| dt.with_timezone(&chrono::Utc))
722            .unwrap_or_else(|| context.timestamp - chrono::Duration::days(30)); // Default: 30 days old
723
724        let age_hours = (context.timestamp - doc_timestamp).num_hours() as f32;
725
726        // Exponential decay: newer documents get higher scores
727        let freshness = (-age_hours / (24.0 * 7.0)).exp().min(1.0); // 1 week half-life
728
729        Ok(RelevanceSignal {
730            signal_type: SignalType::DocumentFreshness,
731            value: freshness,
732            confidence: 0.9,
733            metadata: SignalMetadata {
734                source: "document_metadata".to_string(),
735                extraction_time_ms: 1,
736                features: [("age_hours".to_string(), age_hours)]
737                    .iter()
738                    .cloned()
739                    .collect(),
740                warnings: Vec::new(),
741            },
742        })
743    }
744
745    fn signal_type(&self) -> SignalType {
746        SignalType::DocumentFreshness
747    }
748
749    fn get_config(&self) -> SignalExtractorConfig {
750        SignalExtractorConfig {
751            name: "DocumentFreshnessExtractor".to_string(),
752            version: "1.0".to_string(),
753            features: vec!["temporal_decay".to_string()],
754            performance: PerformanceMetrics {
755                avg_extraction_time_ms: 1.0,
756                accuracy: 0.9,
757                memory_usage_mb: 0.01,
758            },
759        }
760    }
761}
762
763/// Extractor for document quality
764struct DocumentQualityExtractor;
765
766impl DocumentQualityExtractor {
767    fn new() -> Self {
768        Self
769    }
770}
771
772impl SignalExtractor for DocumentQualityExtractor {
773    fn extract_signal(
774        &self,
775        _query: &str,
776        document: &SearchResult,
777        _context: &RetrievalContext,
778    ) -> RragResult<RelevanceSignal> {
779        // Simple quality metrics based on document characteristics
780        let length = document.content.len() as f32;
781        let words = document.content.split_whitespace().count() as f32;
782        let sentences = document.content.split('.').count() as f32;
783
784        // Quality heuristics
785        let length_score = if length > 100.0 && length < 5000.0 {
786            1.0
787        } else {
788            0.5
789        };
790        let avg_word_length = if words > 0.0 { length / words } else { 0.0 };
791        let word_length_score = if avg_word_length > 3.0 && avg_word_length < 15.0 {
792            1.0
793        } else {
794            0.7
795        };
796        let sentence_length = if sentences > 0.0 {
797            words / sentences
798        } else {
799            0.0
800        };
801        let sentence_score = if sentence_length > 5.0 && sentence_length < 30.0 {
802            1.0
803        } else {
804            0.8
805        };
806
807        let quality_score = (length_score + word_length_score + sentence_score) / 3.0;
808
809        Ok(RelevanceSignal {
810            signal_type: SignalType::DocumentQuality,
811            value: quality_score,
812            confidence: 0.6,
813            metadata: SignalMetadata {
814                source: "quality_analysis".to_string(),
815                extraction_time_ms: 3,
816                features: [
817                    ("length".to_string(), length),
818                    ("word_count".to_string(), words),
819                    ("sentence_count".to_string(), sentences),
820                    ("avg_word_length".to_string(), avg_word_length),
821                    ("avg_sentence_length".to_string(), sentence_length),
822                ]
823                .iter()
824                .cloned()
825                .collect(),
826                warnings: Vec::new(),
827            },
828        })
829    }
830
831    fn signal_type(&self) -> SignalType {
832        SignalType::DocumentQuality
833    }
834
835    fn get_config(&self) -> SignalExtractorConfig {
836        SignalExtractorConfig {
837            name: "DocumentQualityExtractor".to_string(),
838            version: "1.0".to_string(),
839            features: vec![
840                "length_analysis".to_string(),
841                "structural_analysis".to_string(),
842            ],
843            performance: PerformanceMetrics {
844                avg_extraction_time_ms: 3.0,
845                accuracy: 0.6,
846                memory_usage_mb: 0.02,
847            },
848        }
849    }
850}
851
// Placeholder implementations for other extractors.
//
// `impl_placeholder_extractor!(Name, signal_type_expr, default_value)` defines a
// unit struct `Name` implementing `SignalExtractor`. Its `extract_signal` ignores
// the query, document and context. It always reports `default_value` with
// confidence 0.5 and a "Placeholder implementation" warning.
macro_rules! impl_placeholder_extractor {
    ($name:ident, $signal_type:expr, $default_value:expr) => {
        struct $name;

        impl $name {
            fn new() -> Self {
                $name
            }
        }

        impl SignalExtractor for $name {
            fn extract_signal(
                &self,
                _query: &str,
                _document: &SearchResult,
                _context: &RetrievalContext,
            ) -> RragResult<RelevanceSignal> {
                let metadata = SignalMetadata {
                    source: String::from("placeholder"),
                    extraction_time_ms: 1,
                    features: HashMap::new(),
                    warnings: vec![String::from("Placeholder implementation")],
                };

                Ok(RelevanceSignal {
                    signal_type: $signal_type,
                    value: $default_value,
                    confidence: 0.5,
                    metadata,
                })
            }

            fn signal_type(&self) -> SignalType {
                $signal_type
            }

            fn get_config(&self) -> SignalExtractorConfig {
                let performance = PerformanceMetrics {
                    avg_extraction_time_ms: 1.0,
                    accuracy: 0.5,
                    memory_usage_mb: 0.01,
                };

                SignalExtractorConfig {
                    name: String::from(stringify!($name)),
                    version: String::from("0.1"),
                    features: vec![String::from("placeholder")],
                    performance,
                }
            }
        }
    };
}
902
903impl_placeholder_extractor!(
904    DocumentAuthorityExtractor,
905    SignalType::DocumentAuthority,
906    0.5
907);
908impl_placeholder_extractor!(UserPreferenceExtractor, SignalType::UserPreference, 0.5);
909impl_placeholder_extractor!(ClickThroughRateExtractor, SignalType::ClickThroughRate, 0.5);
910impl_placeholder_extractor!(
911    DocumentPopularityExtractor,
912    SignalType::DocumentPopularity,
913    0.5
914);
915impl_placeholder_extractor!(
916    InteractionHistoryExtractor,
917    SignalType::InteractionHistory,
918    0.5
919);
920
/// Placeholder extractor for a named, domain-specific relevance signal.
struct DomainSpecificExtractor {
    /// Domain label carried into the emitted `SignalType::DomainSpecific` variant.
    domain: String,
}

impl DomainSpecificExtractor {
    /// Builds an extractor bound to the given `domain` label.
    fn new(domain: String) -> Self {
        DomainSpecificExtractor { domain }
    }
}
930
931impl SignalExtractor for DomainSpecificExtractor {
932    fn extract_signal(
933        &self,
934        _query: &str,
935        _document: &SearchResult,
936        _context: &RetrievalContext,
937    ) -> RragResult<RelevanceSignal> {
938        Ok(RelevanceSignal {
939            signal_type: SignalType::DomainSpecific(self.domain.clone()),
940            value: 0.5,
941            confidence: 0.5,
942            metadata: SignalMetadata {
943                source: "domain_specific".to_string(),
944                extraction_time_ms: 1,
945                features: HashMap::new(),
946                warnings: vec!["Placeholder implementation".to_string()],
947            },
948        })
949    }
950
951    fn signal_type(&self) -> SignalType {
952        SignalType::DomainSpecific(self.domain.clone())
953    }
954
955    fn get_config(&self) -> SignalExtractorConfig {
956        SignalExtractorConfig {
957            name: format!("DomainSpecificExtractor({})", self.domain),
958            version: "0.1".to_string(),
959            features: vec!["domain_analysis".to_string()],
960            performance: PerformanceMetrics {
961                avg_extraction_time_ms: 1.0,
962                accuracy: 0.5,
963                memory_usage_mb: 0.01,
964            },
965        }
966    }
967}
968
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SearchResult;

    /// Smoke test for the default reranker: it returns scores, and the longer,
    /// on-topic document (index 0) receives a positive score.
    #[tokio::test]
    async fn test_multi_signal_reranking() {
        let config = MultiSignalConfig::default();
        let reranker = MultiSignalReranker::new(config);

        let results = vec![
            SearchResult {
                id: "doc1".to_string(),
                content: "Machine learning is a subset of artificial intelligence that focuses on algorithms".to_string(),
                score: 0.8,
                rank: 0,
                metadata: HashMap::new(),
                embedding: None,
            },
            SearchResult {
                id: "doc2".to_string(),
                content: "AI".to_string(), // Short, low quality
                score: 0.9,
                rank: 1,
                metadata: HashMap::new(),
                embedding: None,
            },
        ];

        let query = "What is machine learning in artificial intelligence?";
        let reranked_scores = reranker
            .rerank(query, &results)
            .await
            .expect("reranking should succeed");

        assert!(!reranked_scores.is_empty());

        // NOTE(review): only a positive score for doc1 is asserted; its rank
        // relative to doc2 is NOT checked, since several extractors in this
        // module are still constant placeholders.
        let doc1_score = reranked_scores.get(&0).copied().unwrap_or(0.0);
        assert!(
            doc1_score > 0.0,
            "expected doc1 (index 0) to get a positive reranked score, got {}",
            doc1_score
        );
    }

    /// Min-max normalization maps the smallest value to 0 and the largest to 1.
    #[test]
    fn test_signal_normalization() {
        let config = MultiSignalConfig::default();
        let reranker = MultiSignalReranker::new(config);

        // Builds a fully-confident semantic-relevance signal with the given
        // value; the parameter type is inferred from `RelevanceSignal::value`.
        let make_signal = |value| RelevanceSignal {
            signal_type: SignalType::SemanticRelevance,
            value,
            confidence: 1.0,
            metadata: SignalMetadata {
                source: "test".to_string(),
                extraction_time_ms: 0,
                features: HashMap::new(),
                warnings: Vec::new(),
            },
        };
        let signals = vec![make_signal(0.1), make_signal(0.9)];

        let normalized = reranker.normalize_min_max(&signals);

        // Compare within a tolerance: exact float equality breaks as soon as
        // the normalization arithmetic is reordered.
        let min_value = normalized[0].value;
        let max_value = normalized[1].value;
        assert!(
            min_value.abs() < 1e-6,
            "min should normalize to 0, got {}",
            min_value
        );
        assert!(
            (max_value - 1.0).abs() < 1e-6,
            "max should normalize to 1, got {}",
            max_value
        );
    }
}