rexis_rag/reranking/
multi_signal.rs

1//! # Multi-Signal Reranking
2//!
3//! Combines multiple relevance signals beyond semantic similarity to improve
4//! retrieval accuracy. Includes signals like freshness, authority, click-through
5//! rates, document quality, and user preferences.
6
7use crate::{RragResult, SearchResult};
8use std::collections::HashMap;
9use tracing::warn;
10
11/// Multi-signal reranker that combines various relevance signals
12pub struct MultiSignalReranker {
13    /// Configuration
14    config: MultiSignalConfig,
15
16    /// Signal extractors
17    signal_extractors: HashMap<SignalType, Box<dyn SignalExtractor>>,
18
19    /// Signal weights (learned or configured)
20    signal_weights: HashMap<SignalType, f32>,
21
22    /// Signal aggregation method
23    aggregation: SignalAggregation,
24}
25
26/// Configuration for multi-signal reranking
27#[derive(Debug, Clone)]
28pub struct MultiSignalConfig {
29    /// Enabled signal types
30    pub enabled_signals: Vec<SignalType>,
31
32    /// Signal weights
33    pub signal_weights: HashMap<SignalType, SignalWeight>,
34
35    /// Aggregation method
36    pub aggregation_method: SignalAggregation,
37
38    /// Normalization method
39    pub normalization: SignalNormalization,
40
41    /// Minimum signal confidence
42    pub min_signal_confidence: f32,
43
44    /// Enable adaptive weighting
45    pub enable_adaptive_weights: bool,
46
47    /// Learning rate for adaptive weights
48    pub learning_rate: f32,
49}
50
51impl Default for MultiSignalConfig {
52    fn default() -> Self {
53        let mut signal_weights = HashMap::new();
54        signal_weights.insert(SignalType::SemanticRelevance, SignalWeight::Fixed(0.3));
55        signal_weights.insert(SignalType::TextualRelevance, SignalWeight::Fixed(0.25));
56        signal_weights.insert(SignalType::DocumentFreshness, SignalWeight::Fixed(0.15));
57        signal_weights.insert(SignalType::DocumentAuthority, SignalWeight::Fixed(0.1));
58        signal_weights.insert(SignalType::DocumentQuality, SignalWeight::Fixed(0.1));
59        signal_weights.insert(SignalType::UserPreference, SignalWeight::Fixed(0.05));
60        signal_weights.insert(SignalType::ClickThroughRate, SignalWeight::Fixed(0.05));
61
62        Self {
63            enabled_signals: vec![
64                SignalType::SemanticRelevance,
65                SignalType::TextualRelevance,
66                SignalType::DocumentFreshness,
67                SignalType::DocumentQuality,
68            ],
69            signal_weights,
70            aggregation_method: SignalAggregation::WeightedSum,
71            normalization: SignalNormalization::MinMax,
72            min_signal_confidence: 0.1,
73            enable_adaptive_weights: false,
74            learning_rate: 0.01,
75        }
76    }
77}
78
/// Kinds of relevance signal the reranker can work with.
///
/// The type is hashable and comparable so it can serve as a map key.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum SignalType {
    /// Semantic similarity between the query and the document.
    SemanticRelevance,
    /// Textual/keyword relevance (BM25, TF-IDF).
    TextualRelevance,
    /// How recent the document is.
    DocumentFreshness,
    /// Authority/credibility of the document.
    DocumentAuthority,
    /// Quality metrics of the document.
    DocumentQuality,
    /// Signals derived from user preferences.
    UserPreference,
    /// Click-through rates.
    ClickThroughRate,
    /// Popularity of the document.
    DocumentPopularity,
    /// History of query-document interactions.
    InteractionHistory,
    /// Domain-specific signal; the payload names the domain.
    DomainSpecific(String),
}
103
/// How the weight of a signal is determined.
pub enum SignalWeight {
    /// Constant weight.
    Fixed(f32),
    /// Weight computed from the query text by the boxed function.
    QueryDependent(Box<dyn Fn(&str) -> f32 + Send + Sync>),
    /// Learned weight (requires training data).
    Learned,
    /// Adaptive weight (updates based on feedback); holds the current value.
    Adaptive(f32),
}

impl std::fmt::Debug for SignalWeight {
    /// Prints the variant name and, for numeric variants, the weight using
    /// its `Display` form. The closure of `QueryDependent` is opaque and is
    /// shown as a `<function>` marker.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Learned => write!(f, "Learned"),
            Self::QueryDependent(_) => write!(f, "QueryDependent(<function>)"),
            Self::Fixed(weight) => write!(f, "Fixed({})", weight),
            Self::Adaptive(weight) => write!(f, "Adaptive({})", weight),
        }
    }
}

impl Clone for SignalWeight {
    /// Copies the weight configuration.
    ///
    /// A boxed closure cannot be duplicated, so cloning a `QueryDependent`
    /// weight is lossy: the copy is `Fixed(0.5)`. All other variants are
    /// copied faithfully.
    fn clone(&self) -> Self {
        match *self {
            Self::Fixed(weight) => Self::Fixed(weight),
            Self::Adaptive(weight) => Self::Adaptive(weight),
            Self::Learned => Self::Learned,
            Self::QueryDependent(_) => Self::Fixed(0.5),
        }
    }
}
137
/// Strategies for folding several signal values into one score.
#[derive(Debug, Clone)]
pub enum SignalAggregation {
    /// Sum of each value multiplied by its weight.
    WeightedSum,
    /// Weighted sum divided by the sum of the weights.
    WeightedAverage,
    /// Largest signal value.
    Max,
    /// Smallest signal value.
    Min,
    /// Learning-to-rank combination.
    LearnedCombination,
    /// Custom aggregation function, identified by the string payload.
    Custom(String),
}
154
/// Ways of rescaling the values of one signal type across documents.
#[derive(Debug, Clone)]
pub enum SignalNormalization {
    /// Min-max scaling into the 0-1 range.
    MinMax,
    /// Z-score standardization.
    ZScore,
    /// Replace each value by a score derived from its rank.
    Rank,
    /// Logistic (sigmoid) squashing.
    Sigmoid,
    /// Leave values untouched.
    None,
}
169
170/// A relevance signal extracted from query-document pair
171#[derive(Debug, Clone)]
172pub struct RelevanceSignal {
173    /// Type of signal
174    pub signal_type: SignalType,
175
176    /// Signal value (typically 0-1)
177    pub value: f32,
178
179    /// Confidence in the signal (0-1)
180    pub confidence: f32,
181
182    /// Signal metadata
183    pub metadata: SignalMetadata,
184}
185
/// Provenance and diagnostics attached to a relevance signal.
#[derive(Debug, Clone)]
pub struct SignalMetadata {
    /// Where the signal came from.
    pub source: String,

    /// Extraction time in milliseconds.
    pub extraction_time_ms: u64,

    /// Named feature values that went into the signal.
    pub features: HashMap<String, f32>,

    /// Warnings or notes recorded during extraction.
    pub warnings: Vec<String>,
}
201
202/// Trait for extracting relevance signals
203pub trait SignalExtractor: Send + Sync {
204    /// Extract signal from query-document pair
205    fn extract_signal(
206        &self,
207        query: &str,
208        document: &SearchResult,
209        context: &RetrievalContext,
210    ) -> RragResult<RelevanceSignal>;
211
212    /// Extract signals for multiple documents in batch
213    fn extract_batch(
214        &self,
215        query: &str,
216        documents: &[SearchResult],
217        context: &RetrievalContext,
218    ) -> RragResult<Vec<RelevanceSignal>> {
219        documents
220            .iter()
221            .map(|doc| self.extract_signal(query, doc, context))
222            .collect()
223    }
224
225    /// Get signal type
226    fn signal_type(&self) -> SignalType;
227
228    /// Get extractor configuration
229    fn get_config(&self) -> SignalExtractorConfig;
230}
231
232/// Configuration for signal extractors
233#[derive(Debug, Clone)]
234pub struct SignalExtractorConfig {
235    /// Extractor name
236    pub name: String,
237
238    /// Extractor version
239    pub version: String,
240
241    /// Supported features
242    pub features: Vec<String>,
243
244    /// Performance characteristics
245    pub performance: PerformanceMetrics,
246}
247
/// Performance characteristics reported by a signal extractor.
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
    /// Average time per extraction, in milliseconds.
    pub avg_extraction_time_ms: f32,

    /// Accuracy/precision of the produced signal.
    pub accuracy: f32,

    /// Memory footprint, in megabytes.
    pub memory_usage_mb: f32,
}
260
261/// Context for signal extraction
262#[derive(Debug, Clone)]
263pub struct RetrievalContext {
264    /// User identifier (if available)
265    pub user_id: Option<String>,
266
267    /// Session information
268    pub session_id: Option<String>,
269
270    /// Query timestamp
271    pub timestamp: chrono::DateTime<chrono::Utc>,
272
273    /// Query type/intent
274    pub query_intent: Option<String>,
275
276    /// User preferences
277    pub user_preferences: HashMap<String, f32>,
278
279    /// Historical interactions
280    pub interaction_history: Vec<InteractionRecord>,
281}
282
283/// Historical interaction record
284#[derive(Debug, Clone)]
285pub struct InteractionRecord {
286    /// Document ID
287    pub document_id: String,
288
289    /// Interaction type (click, dwell, etc.)
290    pub interaction_type: String,
291
292    /// Interaction timestamp
293    pub timestamp: chrono::DateTime<chrono::Utc>,
294
295    /// Interaction value/strength
296    pub value: f32,
297}
298
299impl MultiSignalReranker {
300    /// Create a new multi-signal reranker
301    pub fn new(config: MultiSignalConfig) -> Self {
302        let mut reranker = Self {
303            config: config.clone(),
304            signal_extractors: HashMap::new(),
305            signal_weights: HashMap::new(),
306            aggregation: config.aggregation_method.clone(),
307        };
308
309        // Initialize signal extractors
310        reranker.initialize_extractors();
311
312        // Initialize weights
313        reranker.initialize_weights();
314
315        reranker
316    }
317
318    /// Initialize signal extractors based on configuration
319    fn initialize_extractors(&mut self) {
320        for signal_type in &self.config.enabled_signals {
321            let extractor: Box<dyn SignalExtractor> = match signal_type {
322                SignalType::SemanticRelevance => Box::new(SemanticRelevanceExtractor::new()),
323                SignalType::TextualRelevance => Box::new(TextualRelevanceExtractor::new()),
324                SignalType::DocumentFreshness => Box::new(DocumentFreshnessExtractor::new()),
325                SignalType::DocumentAuthority => Box::new(DocumentAuthorityExtractor::new()),
326                SignalType::DocumentQuality => Box::new(DocumentQualityExtractor::new()),
327                SignalType::UserPreference => Box::new(UserPreferenceExtractor::new()),
328                SignalType::ClickThroughRate => Box::new(ClickThroughRateExtractor::new()),
329                SignalType::DocumentPopularity => Box::new(DocumentPopularityExtractor::new()),
330                SignalType::InteractionHistory => Box::new(InteractionHistoryExtractor::new()),
331                SignalType::DomainSpecific(domain) => {
332                    Box::new(DomainSpecificExtractor::new(domain.clone()))
333                }
334            };
335
336            self.signal_extractors
337                .insert(signal_type.clone(), extractor);
338        }
339    }
340
341    /// Initialize signal weights
342    fn initialize_weights(&mut self) {
343        for (signal_type, weight_config) in &self.config.signal_weights {
344            let weight = match weight_config {
345                SignalWeight::Fixed(w) => *w,
346                SignalWeight::Adaptive(w) => *w,
347                SignalWeight::Learned => 1.0 / self.config.signal_weights.len() as f32, // Default uniform
348                SignalWeight::QueryDependent(_) => 1.0, // Will be computed per query
349            };
350
351            self.signal_weights.insert(signal_type.clone(), weight);
352        }
353    }
354
355    /// Rerank search results using multiple signals
356    pub async fn rerank(
357        &self,
358        query: &str,
359        results: &[SearchResult],
360    ) -> RragResult<HashMap<usize, f32>> {
361        let context = RetrievalContext {
362            user_id: None,
363            session_id: None,
364            timestamp: chrono::Utc::now(),
365            query_intent: None,
366            user_preferences: HashMap::new(),
367            interaction_history: Vec::new(),
368        };
369
370        self.rerank_with_context(query, results, &context).await
371    }
372
373    /// Rerank with full context information
374    pub async fn rerank_with_context(
375        &self,
376        query: &str,
377        results: &[SearchResult],
378        context: &RetrievalContext,
379    ) -> RragResult<HashMap<usize, f32>> {
380        let mut final_scores = HashMap::new();
381
382        // Extract all signals for all documents
383        let mut all_signals: HashMap<SignalType, Vec<RelevanceSignal>> = HashMap::new();
384
385        for (signal_type, extractor) in &self.signal_extractors {
386            match extractor.extract_batch(query, results, context) {
387                Ok(signals) => {
388                    all_signals.insert(signal_type.clone(), signals);
389                }
390                Err(e) => {
391                    warn!(" Failed to extract signal {:?}: {}", signal_type, e);
392                    // Continue with other signals
393                }
394            }
395        }
396
397        // Normalize signals if needed
398        let normalized_signals = self.normalize_signals(all_signals)?;
399
400        // Compute final scores for each document
401        for (doc_idx, _) in results.iter().enumerate() {
402            let mut signal_values = Vec::new();
403            let mut signal_weights = Vec::new();
404
405            for (signal_type, signals) in &normalized_signals {
406                if let Some(signal) = signals.get(doc_idx) {
407                    if signal.confidence >= self.config.min_signal_confidence {
408                        signal_values.push(signal.value);
409
410                        let weight = self.get_signal_weight(signal_type, query, signal)?;
411                        signal_weights.push(weight);
412                    }
413                }
414            }
415
416            // Aggregate signals
417            let final_score = self.aggregate_signals(&signal_values, &signal_weights)?;
418            final_scores.insert(doc_idx, final_score);
419        }
420
421        Ok(final_scores)
422    }
423
424    /// Normalize signals based on configuration
425    fn normalize_signals(
426        &self,
427        signals: HashMap<SignalType, Vec<RelevanceSignal>>,
428    ) -> RragResult<HashMap<SignalType, Vec<RelevanceSignal>>> {
429        let mut normalized = HashMap::new();
430
431        for (signal_type, signal_list) in signals {
432            let normalized_list = match self.config.normalization {
433                SignalNormalization::MinMax => self.normalize_min_max(&signal_list),
434                SignalNormalization::ZScore => self.normalize_z_score(&signal_list),
435                SignalNormalization::Rank => self.normalize_rank(&signal_list),
436                SignalNormalization::Sigmoid => self.normalize_sigmoid(&signal_list),
437                SignalNormalization::None => signal_list,
438            };
439
440            normalized.insert(signal_type, normalized_list);
441        }
442
443        Ok(normalized)
444    }
445
446    /// Min-max normalization
447    fn normalize_min_max(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
448        let values: Vec<f32> = signals.iter().map(|s| s.value).collect();
449        let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
450        let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
451
452        let range = max_val - min_val;
453        if range == 0.0 {
454            return signals.to_vec(); // No normalization needed
455        }
456
457        signals
458            .iter()
459            .map(|signal| {
460                let mut normalized = signal.clone();
461                normalized.value = (signal.value - min_val) / range;
462                normalized
463            })
464            .collect()
465    }
466
467    /// Z-score normalization
468    fn normalize_z_score(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
469        let values: Vec<f32> = signals.iter().map(|s| s.value).collect();
470        let mean = values.iter().sum::<f32>() / values.len() as f32;
471        let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / values.len() as f32;
472        let std_dev = variance.sqrt();
473
474        if std_dev == 0.0 {
475            return signals.to_vec();
476        }
477
478        signals
479            .iter()
480            .map(|signal| {
481                let mut normalized = signal.clone();
482                normalized.value = (signal.value - mean) / std_dev;
483                // Convert to 0-1 range using sigmoid
484                normalized.value = 1.0 / (1.0 + (-normalized.value).exp());
485                normalized
486            })
487            .collect()
488    }
489
490    /// Rank normalization
491    fn normalize_rank(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
492        let mut indexed_signals: Vec<(usize, &RelevanceSignal)> =
493            signals.iter().enumerate().collect();
494
495        indexed_signals.sort_by(|a, b| {
496            b.1.value
497                .partial_cmp(&a.1.value)
498                .unwrap_or(std::cmp::Ordering::Equal)
499        });
500
501        let mut normalized = vec![signals[0].clone(); signals.len()];
502        for (rank, (original_idx, signal)) in indexed_signals.iter().enumerate() {
503            normalized[*original_idx] = (*signal).clone();
504            normalized[*original_idx].value = 1.0 - (rank as f32 / signals.len() as f32);
505        }
506
507        normalized
508    }
509
510    /// Sigmoid normalization
511    fn normalize_sigmoid(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
512        signals
513            .iter()
514            .map(|signal| {
515                let mut normalized = signal.clone();
516                normalized.value = 1.0 / (1.0 + (-signal.value).exp());
517                normalized
518            })
519            .collect()
520    }
521
522    /// Get weight for a specific signal
523    fn get_signal_weight(
524        &self,
525        signal_type: &SignalType,
526        query: &str,
527        _signal: &RelevanceSignal,
528    ) -> RragResult<f32> {
529        if let Some(weight_config) = self.config.signal_weights.get(signal_type) {
530            match weight_config {
531                SignalWeight::Fixed(w) => Ok(*w),
532                SignalWeight::Adaptive(w) => Ok(*w),
533                SignalWeight::Learned => {
534                    Ok(self.signal_weights.get(signal_type).copied().unwrap_or(1.0))
535                }
536                SignalWeight::QueryDependent(func) => Ok(func(query)),
537            }
538        } else {
539            Ok(1.0 / self.config.signal_weights.len() as f32) // Default uniform weight
540        }
541    }
542
543    /// Aggregate multiple signals into final score
544    fn aggregate_signals(&self, values: &[f32], weights: &[f32]) -> RragResult<f32> {
545        if values.is_empty() {
546            return Ok(0.0);
547        }
548
549        match &self.aggregation {
550            SignalAggregation::WeightedSum => {
551                Ok(values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum())
552            }
553            SignalAggregation::WeightedAverage => {
554                let weighted_sum: f32 = values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum();
555                let weight_sum: f32 = weights.iter().sum();
556                Ok(if weight_sum > 0.0 {
557                    weighted_sum / weight_sum
558                } else {
559                    0.0
560                })
561            }
562            SignalAggregation::Max => Ok(values.iter().fold(0.0f32, |a, &b| a.max(b))),
563            SignalAggregation::Min => Ok(values.iter().fold(1.0f32, |a, &b| a.min(b))),
564            SignalAggregation::LearnedCombination => {
565                // Would use a learned model - for now, use weighted average
566                let weighted_sum: f32 = values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum();
567                let weight_sum: f32 = weights.iter().sum();
568                Ok(if weight_sum > 0.0 {
569                    weighted_sum / weight_sum
570                } else {
571                    0.0
572                })
573            }
574            SignalAggregation::Custom(_) => {
575                // Custom aggregation would be implemented here
576                Ok(values.iter().sum::<f32>() / values.len() as f32)
577            }
578        }
579    }
580}
581
// Built-in signal extractors. Only a few key ones have real implementations;
// the remaining ones are generated as placeholders further below.
584
/// Stateless extractor for the semantic relevance signal.
struct SemanticRelevanceExtractor;

impl SemanticRelevanceExtractor {
    /// Creates the extractor; there is nothing to configure.
    fn new() -> Self {
        SemanticRelevanceExtractor
    }
}
593
594impl SignalExtractor for SemanticRelevanceExtractor {
595    fn extract_signal(
596        &self,
597        _query: &str,
598        document: &SearchResult,
599        _context: &RetrievalContext,
600    ) -> RragResult<RelevanceSignal> {
601        // Use the existing search score as semantic relevance
602        Ok(RelevanceSignal {
603            signal_type: SignalType::SemanticRelevance,
604            value: document.score,
605            confidence: 0.8,
606            metadata: SignalMetadata {
607                source: "search_engine".to_string(),
608                extraction_time_ms: 1,
609                features: HashMap::new(),
610                warnings: Vec::new(),
611            },
612        })
613    }
614
615    fn signal_type(&self) -> SignalType {
616        SignalType::SemanticRelevance
617    }
618
619    fn get_config(&self) -> SignalExtractorConfig {
620        SignalExtractorConfig {
621            name: "SemanticRelevanceExtractor".to_string(),
622            version: "1.0".to_string(),
623            features: vec!["vector_similarity".to_string()],
624            performance: PerformanceMetrics {
625                avg_extraction_time_ms: 1.0,
626                accuracy: 0.8,
627                memory_usage_mb: 0.1,
628            },
629        }
630    }
631}
632
/// Stateless extractor for the textual relevance signal.
struct TextualRelevanceExtractor;

impl TextualRelevanceExtractor {
    /// Creates the extractor; there is nothing to configure.
    fn new() -> Self {
        TextualRelevanceExtractor
    }
}
641
642impl SignalExtractor for TextualRelevanceExtractor {
643    fn extract_signal(
644        &self,
645        query: &str,
646        document: &SearchResult,
647        _context: &RetrievalContext,
648    ) -> RragResult<RelevanceSignal> {
649        // Simple textual relevance based on term overlap
650        let query_terms: std::collections::HashSet<&str> = query.split_whitespace().collect();
651        let doc_terms: std::collections::HashSet<&str> =
652            document.content.split_whitespace().collect();
653
654        let intersection = query_terms.intersection(&doc_terms).count();
655        let union = query_terms.union(&doc_terms).count();
656
657        let jaccard = if union == 0 {
658            0.0
659        } else {
660            intersection as f32 / union as f32
661        };
662
663        Ok(RelevanceSignal {
664            signal_type: SignalType::TextualRelevance,
665            value: jaccard,
666            confidence: 0.7,
667            metadata: SignalMetadata {
668                source: "textual_analysis".to_string(),
669                extraction_time_ms: 2,
670                features: [
671                    ("intersection".to_string(), intersection as f32),
672                    ("union".to_string(), union as f32),
673                ]
674                .iter()
675                .cloned()
676                .collect(),
677                warnings: Vec::new(),
678            },
679        })
680    }
681
682    fn signal_type(&self) -> SignalType {
683        SignalType::TextualRelevance
684    }
685
686    fn get_config(&self) -> SignalExtractorConfig {
687        SignalExtractorConfig {
688            name: "TextualRelevanceExtractor".to_string(),
689            version: "1.0".to_string(),
690            features: vec!["term_overlap".to_string(), "jaccard_similarity".to_string()],
691            performance: PerformanceMetrics {
692                avg_extraction_time_ms: 2.0,
693                accuracy: 0.7,
694                memory_usage_mb: 0.05,
695            },
696        }
697    }
698}
699
/// Stateless extractor for the document freshness signal.
struct DocumentFreshnessExtractor;

impl DocumentFreshnessExtractor {
    /// Creates the extractor; there is nothing to configure.
    fn new() -> Self {
        DocumentFreshnessExtractor
    }
}
708
709impl SignalExtractor for DocumentFreshnessExtractor {
710    fn extract_signal(
711        &self,
712        _query: &str,
713        document: &SearchResult,
714        context: &RetrievalContext,
715    ) -> RragResult<RelevanceSignal> {
716        // Extract timestamp from document metadata or use current time as fallback
717        let doc_timestamp = document
718            .metadata
719            .get("timestamp")
720            .and_then(|v| v.as_str())
721            .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
722            .map(|dt| dt.with_timezone(&chrono::Utc))
723            .unwrap_or_else(|| context.timestamp - chrono::Duration::days(30)); // Default: 30 days old
724
725        let age_hours = (context.timestamp - doc_timestamp).num_hours() as f32;
726
727        // Exponential decay: newer documents get higher scores
728        let freshness = (-age_hours / (24.0 * 7.0)).exp().min(1.0); // 1 week half-life
729
730        Ok(RelevanceSignal {
731            signal_type: SignalType::DocumentFreshness,
732            value: freshness,
733            confidence: 0.9,
734            metadata: SignalMetadata {
735                source: "document_metadata".to_string(),
736                extraction_time_ms: 1,
737                features: [("age_hours".to_string(), age_hours)]
738                    .iter()
739                    .cloned()
740                    .collect(),
741                warnings: Vec::new(),
742            },
743        })
744    }
745
746    fn signal_type(&self) -> SignalType {
747        SignalType::DocumentFreshness
748    }
749
750    fn get_config(&self) -> SignalExtractorConfig {
751        SignalExtractorConfig {
752            name: "DocumentFreshnessExtractor".to_string(),
753            version: "1.0".to_string(),
754            features: vec!["temporal_decay".to_string()],
755            performance: PerformanceMetrics {
756                avg_extraction_time_ms: 1.0,
757                accuracy: 0.9,
758                memory_usage_mb: 0.01,
759            },
760        }
761    }
762}
763
/// Stateless extractor for the document quality signal.
struct DocumentQualityExtractor;

impl DocumentQualityExtractor {
    /// Creates the extractor; there is nothing to configure.
    fn new() -> Self {
        DocumentQualityExtractor
    }
}
772
773impl SignalExtractor for DocumentQualityExtractor {
774    fn extract_signal(
775        &self,
776        _query: &str,
777        document: &SearchResult,
778        _context: &RetrievalContext,
779    ) -> RragResult<RelevanceSignal> {
780        // Simple quality metrics based on document characteristics
781        let length = document.content.len() as f32;
782        let words = document.content.split_whitespace().count() as f32;
783        let sentences = document.content.split('.').count() as f32;
784
785        // Quality heuristics
786        let length_score = if length > 100.0 && length < 5000.0 {
787            1.0
788        } else {
789            0.5
790        };
791        let avg_word_length = if words > 0.0 { length / words } else { 0.0 };
792        let word_length_score = if avg_word_length > 3.0 && avg_word_length < 15.0 {
793            1.0
794        } else {
795            0.7
796        };
797        let sentence_length = if sentences > 0.0 {
798            words / sentences
799        } else {
800            0.0
801        };
802        let sentence_score = if sentence_length > 5.0 && sentence_length < 30.0 {
803            1.0
804        } else {
805            0.8
806        };
807
808        let quality_score = (length_score + word_length_score + sentence_score) / 3.0;
809
810        Ok(RelevanceSignal {
811            signal_type: SignalType::DocumentQuality,
812            value: quality_score,
813            confidence: 0.6,
814            metadata: SignalMetadata {
815                source: "quality_analysis".to_string(),
816                extraction_time_ms: 3,
817                features: [
818                    ("length".to_string(), length),
819                    ("word_count".to_string(), words),
820                    ("sentence_count".to_string(), sentences),
821                    ("avg_word_length".to_string(), avg_word_length),
822                    ("avg_sentence_length".to_string(), sentence_length),
823                ]
824                .iter()
825                .cloned()
826                .collect(),
827                warnings: Vec::new(),
828            },
829        })
830    }
831
832    fn signal_type(&self) -> SignalType {
833        SignalType::DocumentQuality
834    }
835
836    fn get_config(&self) -> SignalExtractorConfig {
837        SignalExtractorConfig {
838            name: "DocumentQualityExtractor".to_string(),
839            version: "1.0".to_string(),
840            features: vec![
841                "length_analysis".to_string(),
842                "structural_analysis".to_string(),
843            ],
844            performance: PerformanceMetrics {
845                avg_extraction_time_ms: 3.0,
846                accuracy: 0.6,
847                memory_usage_mb: 0.02,
848            },
849        }
850    }
851}
852
// Generates a stub extractor for signals that have no real implementation yet.
//
// Arguments: the name of the unit struct to define, the `SignalType` it
// reports, and the constant value it returns for every document. The stub
// always reports confidence 0.5 and attaches a "Placeholder implementation"
// warning to its metadata.
macro_rules! impl_placeholder_extractor {
    ($extractor:ident, $kind:expr, $constant:expr) => {
        struct $extractor;

        impl $extractor {
            fn new() -> Self {
                $extractor
            }
        }

        impl SignalExtractor for $extractor {
            fn extract_signal(
                &self,
                _query: &str,
                _document: &SearchResult,
                _context: &RetrievalContext,
            ) -> RragResult<RelevanceSignal> {
                let metadata = SignalMetadata {
                    source: "placeholder".to_string(),
                    extraction_time_ms: 1,
                    features: HashMap::new(),
                    warnings: vec!["Placeholder implementation".to_string()],
                };

                Ok(RelevanceSignal {
                    signal_type: $kind,
                    value: $constant,
                    confidence: 0.5,
                    metadata,
                })
            }

            fn signal_type(&self) -> SignalType {
                $kind
            }

            fn get_config(&self) -> SignalExtractorConfig {
                let performance = PerformanceMetrics {
                    avg_extraction_time_ms: 1.0,
                    accuracy: 0.5,
                    memory_usage_mb: 0.01,
                };

                SignalExtractorConfig {
                    name: stringify!($extractor).to_string(),
                    version: "0.1".to_string(),
                    features: vec!["placeholder".to_string()],
                    performance,
                }
            }
        }
    };
}
903
904impl_placeholder_extractor!(
905    DocumentAuthorityExtractor,
906    SignalType::DocumentAuthority,
907    0.5
908);
909impl_placeholder_extractor!(UserPreferenceExtractor, SignalType::UserPreference, 0.5);
910impl_placeholder_extractor!(ClickThroughRateExtractor, SignalType::ClickThroughRate, 0.5);
911impl_placeholder_extractor!(
912    DocumentPopularityExtractor,
913    SignalType::DocumentPopularity,
914    0.5
915);
916impl_placeholder_extractor!(
917    InteractionHistoryExtractor,
918    SignalType::InteractionHistory,
919    0.5
920);
921
/// Signal extractor bound to a single named domain.
///
/// The domain name is the only state an instance carries.
struct DomainSpecificExtractor {
    /// Name of the domain this extractor was created for.
    domain: String,
}

impl DomainSpecificExtractor {
    /// Creates an extractor for `domain`, taking ownership of the string.
    fn new(domain: String) -> Self {
        DomainSpecificExtractor { domain }
    }
}
931
932impl SignalExtractor for DomainSpecificExtractor {
933    fn extract_signal(
934        &self,
935        _query: &str,
936        _document: &SearchResult,
937        _context: &RetrievalContext,
938    ) -> RragResult<RelevanceSignal> {
939        Ok(RelevanceSignal {
940            signal_type: SignalType::DomainSpecific(self.domain.clone()),
941            value: 0.5,
942            confidence: 0.5,
943            metadata: SignalMetadata {
944                source: "domain_specific".to_string(),
945                extraction_time_ms: 1,
946                features: HashMap::new(),
947                warnings: vec!["Placeholder implementation".to_string()],
948            },
949        })
950    }
951
952    fn signal_type(&self) -> SignalType {
953        SignalType::DomainSpecific(self.domain.clone())
954    }
955
956    fn get_config(&self) -> SignalExtractorConfig {
957        SignalExtractorConfig {
958            name: format!("DomainSpecificExtractor({})", self.domain),
959            version: "0.1".to_string(),
960            features: vec!["domain_analysis".to_string()],
961            performance: PerformanceMetrics {
962                avg_extraction_time_ms: 1.0,
963                accuracy: 0.5,
964                memory_usage_mb: 0.01,
965            },
966        }
967    }
968}
969
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SearchResult;

    /// Reranks two documents with the default configuration and checks that
    /// scores are returned, including a positive one for the first document.
    #[tokio::test]
    async fn test_multi_signal_reranking() {
        let reranker = MultiSignalReranker::new(MultiSignalConfig::default());

        // Index 0: a full sentence sharing terms with the query.
        let detailed_doc = SearchResult {
            id: String::from("doc1"),
            content: String::from(
                "Machine learning is a subset of artificial intelligence that focuses on algorithms",
            ),
            score: 0.8,
            rank: 0,
            metadata: HashMap::new(),
            embedding: None,
        };
        // Index 1: a two-letter document that starts with the higher score.
        let terse_doc = SearchResult {
            id: String::from("doc2"),
            content: String::from("AI"),
            score: 0.9,
            rank: 1,
            metadata: HashMap::new(),
            embedding: None,
        };
        let results = vec![detailed_doc, terse_doc];

        let query = "What is machine learning in artificial intelligence?";
        let reranked_scores = reranker.rerank(query, &results).await.unwrap();

        assert!(!reranked_scores.is_empty());
        // Only the score stored under key 0 is checked. A missing entry counts
        // as 0.0 and therefore fails; relative ordering is not asserted here.
        let first_doc_score = reranked_scores.get(&0).copied().unwrap_or(0.0);
        assert!(first_doc_score > 0.0);
    }

    /// Min-max normalization of two signals should map the smaller value to
    /// 0.0 and the larger value to 1.0.
    #[test]
    fn test_signal_normalization() {
        let reranker = MultiSignalReranker::new(MultiSignalConfig::default());

        // Both fixtures differ only in `value`, so build them from one template.
        let semantic_signal = |value| RelevanceSignal {
            signal_type: SignalType::SemanticRelevance,
            value,
            confidence: 1.0,
            metadata: SignalMetadata {
                source: String::from("test"),
                extraction_time_ms: 0,
                features: HashMap::new(),
                warnings: Vec::new(),
            },
        };
        let signals = vec![semantic_signal(0.1), semantic_signal(0.9)];

        let normalized = reranker.normalize_min_max(&signals);
        let (lowest, highest) = (&normalized[0], &normalized[1]);
        assert_eq!(lowest.value, 0.0); // minimum input maps to 0
        assert_eq!(highest.value, 1.0); // maximum input maps to 1
    }
}