use crate::{RragResult, SearchResult};
use std::collections::HashMap;
pub struct MultiSignalReranker {
config: MultiSignalConfig,
signal_extractors: HashMap<SignalType, Box<dyn SignalExtractor>>,
signal_weights: HashMap<SignalType, f32>,
aggregation: SignalAggregation,
}
#[derive(Debug, Clone)]
pub struct MultiSignalConfig {
pub enabled_signals: Vec<SignalType>,
pub signal_weights: HashMap<SignalType, SignalWeight>,
pub aggregation_method: SignalAggregation,
pub normalization: SignalNormalization,
pub min_signal_confidence: f32,
pub enable_adaptive_weights: bool,
pub learning_rate: f32,
}
impl Default for MultiSignalConfig {
fn default() -> Self {
let mut signal_weights = HashMap::new();
signal_weights.insert(SignalType::SemanticRelevance, SignalWeight::Fixed(0.3));
signal_weights.insert(SignalType::TextualRelevance, SignalWeight::Fixed(0.25));
signal_weights.insert(SignalType::DocumentFreshness, SignalWeight::Fixed(0.15));
signal_weights.insert(SignalType::DocumentAuthority, SignalWeight::Fixed(0.1));
signal_weights.insert(SignalType::DocumentQuality, SignalWeight::Fixed(0.1));
signal_weights.insert(SignalType::UserPreference, SignalWeight::Fixed(0.05));
signal_weights.insert(SignalType::ClickThroughRate, SignalWeight::Fixed(0.05));
Self {
enabled_signals: vec![
SignalType::SemanticRelevance,
SignalType::TextualRelevance,
SignalType::DocumentFreshness,
SignalType::DocumentQuality,
],
signal_weights,
aggregation_method: SignalAggregation::WeightedSum,
normalization: SignalNormalization::MinMax,
min_signal_confidence: 0.1,
enable_adaptive_weights: false,
learning_rate: 0.01,
}
}
}
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum SignalType {
SemanticRelevance,
TextualRelevance,
DocumentFreshness,
DocumentAuthority,
DocumentQuality,
UserPreference,
ClickThroughRate,
DocumentPopularity,
InteractionHistory,
DomainSpecific(String),
}
pub enum SignalWeight {
Fixed(f32),
QueryDependent(Box<dyn Fn(&str) -> f32 + Send + Sync>),
Learned,
Adaptive(f32), }
impl std::fmt::Debug for SignalWeight {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Fixed(w) => write!(f, "Fixed({})", w),
Self::QueryDependent(_) => write!(f, "QueryDependent(<function>)"),
Self::Learned => write!(f, "Learned"),
Self::Adaptive(w) => write!(f, "Adaptive({})", w),
}
}
}
impl Clone for SignalWeight {
fn clone(&self) -> Self {
match self {
Self::Fixed(w) => Self::Fixed(*w),
Self::QueryDependent(_) => Self::Fixed(0.5), Self::Learned => Self::Learned,
Self::Adaptive(w) => Self::Adaptive(*w),
}
}
}
#[derive(Debug, Clone)]
pub enum SignalAggregation {
WeightedSum,
WeightedAverage,
Max,
Min,
LearnedCombination,
Custom(String),
}
#[derive(Debug, Clone)]
pub enum SignalNormalization {
MinMax,
ZScore,
Rank,
Sigmoid,
None,
}
#[derive(Debug, Clone)]
pub struct RelevanceSignal {
pub signal_type: SignalType,
pub value: f32,
pub confidence: f32,
pub metadata: SignalMetadata,
}
#[derive(Debug, Clone)]
pub struct SignalMetadata {
pub source: String,
pub extraction_time_ms: u64,
pub features: HashMap<String, f32>,
pub warnings: Vec<String>,
}
pub trait SignalExtractor: Send + Sync {
fn extract_signal(
&self,
query: &str,
document: &SearchResult,
context: &RetrievalContext,
) -> RragResult<RelevanceSignal>;
fn extract_batch(
&self,
query: &str,
documents: &[SearchResult],
context: &RetrievalContext,
) -> RragResult<Vec<RelevanceSignal>> {
documents
.iter()
.map(|doc| self.extract_signal(query, doc, context))
.collect()
}
fn signal_type(&self) -> SignalType;
fn get_config(&self) -> SignalExtractorConfig;
}
#[derive(Debug, Clone)]
pub struct SignalExtractorConfig {
pub name: String,
pub version: String,
pub features: Vec<String>,
pub performance: PerformanceMetrics,
}
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
pub avg_extraction_time_ms: f32,
pub accuracy: f32,
pub memory_usage_mb: f32,
}
#[derive(Debug, Clone)]
pub struct RetrievalContext {
pub user_id: Option<String>,
pub session_id: Option<String>,
pub timestamp: chrono::DateTime<chrono::Utc>,
pub query_intent: Option<String>,
pub user_preferences: HashMap<String, f32>,
pub interaction_history: Vec<InteractionRecord>,
}
#[derive(Debug, Clone)]
pub struct InteractionRecord {
pub document_id: String,
pub interaction_type: String,
pub timestamp: chrono::DateTime<chrono::Utc>,
pub value: f32,
}
impl MultiSignalReranker {
pub fn new(config: MultiSignalConfig) -> Self {
let mut reranker = Self {
config: config.clone(),
signal_extractors: HashMap::new(),
signal_weights: HashMap::new(),
aggregation: config.aggregation_method.clone(),
};
reranker.initialize_extractors();
reranker.initialize_weights();
reranker
}
fn initialize_extractors(&mut self) {
for signal_type in &self.config.enabled_signals {
let extractor: Box<dyn SignalExtractor> = match signal_type {
SignalType::SemanticRelevance => Box::new(SemanticRelevanceExtractor::new()),
SignalType::TextualRelevance => Box::new(TextualRelevanceExtractor::new()),
SignalType::DocumentFreshness => Box::new(DocumentFreshnessExtractor::new()),
SignalType::DocumentAuthority => Box::new(DocumentAuthorityExtractor::new()),
SignalType::DocumentQuality => Box::new(DocumentQualityExtractor::new()),
SignalType::UserPreference => Box::new(UserPreferenceExtractor::new()),
SignalType::ClickThroughRate => Box::new(ClickThroughRateExtractor::new()),
SignalType::DocumentPopularity => Box::new(DocumentPopularityExtractor::new()),
SignalType::InteractionHistory => Box::new(InteractionHistoryExtractor::new()),
SignalType::DomainSpecific(domain) => {
Box::new(DomainSpecificExtractor::new(domain.clone()))
}
};
self.signal_extractors
.insert(signal_type.clone(), extractor);
}
}
fn initialize_weights(&mut self) {
for (signal_type, weight_config) in &self.config.signal_weights {
let weight = match weight_config {
SignalWeight::Fixed(w) => *w,
SignalWeight::Adaptive(w) => *w,
SignalWeight::Learned => 1.0 / self.config.signal_weights.len() as f32, SignalWeight::QueryDependent(_) => 1.0, };
self.signal_weights.insert(signal_type.clone(), weight);
}
}
pub async fn rerank(
&self,
query: &str,
results: &[SearchResult],
) -> RragResult<HashMap<usize, f32>> {
let context = RetrievalContext {
user_id: None,
session_id: None,
timestamp: chrono::Utc::now(),
query_intent: None,
user_preferences: HashMap::new(),
interaction_history: Vec::new(),
};
self.rerank_with_context(query, results, &context).await
}
pub async fn rerank_with_context(
&self,
query: &str,
results: &[SearchResult],
context: &RetrievalContext,
) -> RragResult<HashMap<usize, f32>> {
let mut final_scores = HashMap::new();
let mut all_signals: HashMap<SignalType, Vec<RelevanceSignal>> = HashMap::new();
for (signal_type, extractor) in &self.signal_extractors {
match extractor.extract_batch(query, results, context) {
Ok(signals) => {
all_signals.insert(signal_type.clone(), signals);
}
Err(e) => {
eprintln!("Warning: Failed to extract signal {:?}: {}", signal_type, e);
}
}
}
let normalized_signals = self.normalize_signals(all_signals)?;
for (doc_idx, _) in results.iter().enumerate() {
let mut signal_values = Vec::new();
let mut signal_weights = Vec::new();
for (signal_type, signals) in &normalized_signals {
if let Some(signal) = signals.get(doc_idx) {
if signal.confidence >= self.config.min_signal_confidence {
signal_values.push(signal.value);
let weight = self.get_signal_weight(signal_type, query, signal)?;
signal_weights.push(weight);
}
}
}
let final_score = self.aggregate_signals(&signal_values, &signal_weights)?;
final_scores.insert(doc_idx, final_score);
}
Ok(final_scores)
}
fn normalize_signals(
&self,
signals: HashMap<SignalType, Vec<RelevanceSignal>>,
) -> RragResult<HashMap<SignalType, Vec<RelevanceSignal>>> {
let mut normalized = HashMap::new();
for (signal_type, signal_list) in signals {
let normalized_list = match self.config.normalization {
SignalNormalization::MinMax => self.normalize_min_max(&signal_list),
SignalNormalization::ZScore => self.normalize_z_score(&signal_list),
SignalNormalization::Rank => self.normalize_rank(&signal_list),
SignalNormalization::Sigmoid => self.normalize_sigmoid(&signal_list),
SignalNormalization::None => signal_list,
};
normalized.insert(signal_type, normalized_list);
}
Ok(normalized)
}
fn normalize_min_max(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
let values: Vec<f32> = signals.iter().map(|s| s.value).collect();
let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let range = max_val - min_val;
if range == 0.0 {
return signals.to_vec(); }
signals
.iter()
.map(|signal| {
let mut normalized = signal.clone();
normalized.value = (signal.value - min_val) / range;
normalized
})
.collect()
}
fn normalize_z_score(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
let values: Vec<f32> = signals.iter().map(|s| s.value).collect();
let mean = values.iter().sum::<f32>() / values.len() as f32;
let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / values.len() as f32;
let std_dev = variance.sqrt();
if std_dev == 0.0 {
return signals.to_vec();
}
signals
.iter()
.map(|signal| {
let mut normalized = signal.clone();
normalized.value = (signal.value - mean) / std_dev;
normalized.value = 1.0 / (1.0 + (-normalized.value).exp());
normalized
})
.collect()
}
fn normalize_rank(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
let mut indexed_signals: Vec<(usize, &RelevanceSignal)> =
signals.iter().enumerate().collect();
indexed_signals.sort_by(|a, b| {
b.1.value
.partial_cmp(&a.1.value)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut normalized = vec![signals[0].clone(); signals.len()];
for (rank, (original_idx, signal)) in indexed_signals.iter().enumerate() {
normalized[*original_idx] = (*signal).clone();
normalized[*original_idx].value = 1.0 - (rank as f32 / signals.len() as f32);
}
normalized
}
fn normalize_sigmoid(&self, signals: &[RelevanceSignal]) -> Vec<RelevanceSignal> {
signals
.iter()
.map(|signal| {
let mut normalized = signal.clone();
normalized.value = 1.0 / (1.0 + (-signal.value).exp());
normalized
})
.collect()
}
fn get_signal_weight(
&self,
signal_type: &SignalType,
query: &str,
_signal: &RelevanceSignal,
) -> RragResult<f32> {
if let Some(weight_config) = self.config.signal_weights.get(signal_type) {
match weight_config {
SignalWeight::Fixed(w) => Ok(*w),
SignalWeight::Adaptive(w) => Ok(*w),
SignalWeight::Learned => {
Ok(self.signal_weights.get(signal_type).copied().unwrap_or(1.0))
}
SignalWeight::QueryDependent(func) => Ok(func(query)),
}
} else {
Ok(1.0 / self.config.signal_weights.len() as f32) }
}
fn aggregate_signals(&self, values: &[f32], weights: &[f32]) -> RragResult<f32> {
if values.is_empty() {
return Ok(0.0);
}
match &self.aggregation {
SignalAggregation::WeightedSum => {
Ok(values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum())
}
SignalAggregation::WeightedAverage => {
let weighted_sum: f32 = values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum();
let weight_sum: f32 = weights.iter().sum();
Ok(if weight_sum > 0.0 {
weighted_sum / weight_sum
} else {
0.0
})
}
SignalAggregation::Max => Ok(values.iter().fold(0.0f32, |a, &b| a.max(b))),
SignalAggregation::Min => Ok(values.iter().fold(1.0f32, |a, &b| a.min(b))),
SignalAggregation::LearnedCombination => {
let weighted_sum: f32 = values.iter().zip(weights.iter()).map(|(v, w)| v * w).sum();
let weight_sum: f32 = weights.iter().sum();
Ok(if weight_sum > 0.0 {
weighted_sum / weight_sum
} else {
0.0
})
}
SignalAggregation::Custom(_) => {
Ok(values.iter().sum::<f32>() / values.len() as f32)
}
}
}
}
struct SemanticRelevanceExtractor;
impl SemanticRelevanceExtractor {
fn new() -> Self {
Self
}
}
impl SignalExtractor for SemanticRelevanceExtractor {
fn extract_signal(
&self,
_query: &str,
document: &SearchResult,
_context: &RetrievalContext,
) -> RragResult<RelevanceSignal> {
Ok(RelevanceSignal {
signal_type: SignalType::SemanticRelevance,
value: document.score,
confidence: 0.8,
metadata: SignalMetadata {
source: "search_engine".to_string(),
extraction_time_ms: 1,
features: HashMap::new(),
warnings: Vec::new(),
},
})
}
fn signal_type(&self) -> SignalType {
SignalType::SemanticRelevance
}
fn get_config(&self) -> SignalExtractorConfig {
SignalExtractorConfig {
name: "SemanticRelevanceExtractor".to_string(),
version: "1.0".to_string(),
features: vec!["vector_similarity".to_string()],
performance: PerformanceMetrics {
avg_extraction_time_ms: 1.0,
accuracy: 0.8,
memory_usage_mb: 0.1,
},
}
}
}
struct TextualRelevanceExtractor;
impl TextualRelevanceExtractor {
fn new() -> Self {
Self
}
}
impl SignalExtractor for TextualRelevanceExtractor {
fn extract_signal(
&self,
query: &str,
document: &SearchResult,
_context: &RetrievalContext,
) -> RragResult<RelevanceSignal> {
let query_terms: std::collections::HashSet<&str> = query.split_whitespace().collect();
let doc_terms: std::collections::HashSet<&str> =
document.content.split_whitespace().collect();
let intersection = query_terms.intersection(&doc_terms).count();
let union = query_terms.union(&doc_terms).count();
let jaccard = if union == 0 {
0.0
} else {
intersection as f32 / union as f32
};
Ok(RelevanceSignal {
signal_type: SignalType::TextualRelevance,
value: jaccard,
confidence: 0.7,
metadata: SignalMetadata {
source: "textual_analysis".to_string(),
extraction_time_ms: 2,
features: [
("intersection".to_string(), intersection as f32),
("union".to_string(), union as f32),
]
.iter()
.cloned()
.collect(),
warnings: Vec::new(),
},
})
}
fn signal_type(&self) -> SignalType {
SignalType::TextualRelevance
}
fn get_config(&self) -> SignalExtractorConfig {
SignalExtractorConfig {
name: "TextualRelevanceExtractor".to_string(),
version: "1.0".to_string(),
features: vec!["term_overlap".to_string(), "jaccard_similarity".to_string()],
performance: PerformanceMetrics {
avg_extraction_time_ms: 2.0,
accuracy: 0.7,
memory_usage_mb: 0.05,
},
}
}
}
struct DocumentFreshnessExtractor;
impl DocumentFreshnessExtractor {
fn new() -> Self {
Self
}
}
impl SignalExtractor for DocumentFreshnessExtractor {
fn extract_signal(
&self,
_query: &str,
document: &SearchResult,
context: &RetrievalContext,
) -> RragResult<RelevanceSignal> {
let doc_timestamp = document
.metadata
.get("timestamp")
.and_then(|v| v.as_str())
.and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
.map(|dt| dt.with_timezone(&chrono::Utc))
.unwrap_or_else(|| context.timestamp - chrono::Duration::days(30));
let age_hours = (context.timestamp - doc_timestamp).num_hours() as f32;
let freshness = (-age_hours / (24.0 * 7.0)).exp().min(1.0);
Ok(RelevanceSignal {
signal_type: SignalType::DocumentFreshness,
value: freshness,
confidence: 0.9,
metadata: SignalMetadata {
source: "document_metadata".to_string(),
extraction_time_ms: 1,
features: [("age_hours".to_string(), age_hours)]
.iter()
.cloned()
.collect(),
warnings: Vec::new(),
},
})
}
fn signal_type(&self) -> SignalType {
SignalType::DocumentFreshness
}
fn get_config(&self) -> SignalExtractorConfig {
SignalExtractorConfig {
name: "DocumentFreshnessExtractor".to_string(),
version: "1.0".to_string(),
features: vec!["temporal_decay".to_string()],
performance: PerformanceMetrics {
avg_extraction_time_ms: 1.0,
accuracy: 0.9,
memory_usage_mb: 0.01,
},
}
}
}
struct DocumentQualityExtractor;
impl DocumentQualityExtractor {
fn new() -> Self {
Self
}
}
impl SignalExtractor for DocumentQualityExtractor {
fn extract_signal(
&self,
_query: &str,
document: &SearchResult,
_context: &RetrievalContext,
) -> RragResult<RelevanceSignal> {
let length = document.content.len() as f32;
let words = document.content.split_whitespace().count() as f32;
let sentences = document.content.split('.').count() as f32;
let length_score = if length > 100.0 && length < 5000.0 {
1.0
} else {
0.5
};
let avg_word_length = if words > 0.0 { length / words } else { 0.0 };
let word_length_score = if avg_word_length > 3.0 && avg_word_length < 15.0 {
1.0
} else {
0.7
};
let sentence_length = if sentences > 0.0 {
words / sentences
} else {
0.0
};
let sentence_score = if sentence_length > 5.0 && sentence_length < 30.0 {
1.0
} else {
0.8
};
let quality_score = (length_score + word_length_score + sentence_score) / 3.0;
Ok(RelevanceSignal {
signal_type: SignalType::DocumentQuality,
value: quality_score,
confidence: 0.6,
metadata: SignalMetadata {
source: "quality_analysis".to_string(),
extraction_time_ms: 3,
features: [
("length".to_string(), length),
("word_count".to_string(), words),
("sentence_count".to_string(), sentences),
("avg_word_length".to_string(), avg_word_length),
("avg_sentence_length".to_string(), sentence_length),
]
.iter()
.cloned()
.collect(),
warnings: Vec::new(),
},
})
}
fn signal_type(&self) -> SignalType {
SignalType::DocumentQuality
}
fn get_config(&self) -> SignalExtractorConfig {
SignalExtractorConfig {
name: "DocumentQualityExtractor".to_string(),
version: "1.0".to_string(),
features: vec![
"length_analysis".to_string(),
"structural_analysis".to_string(),
],
performance: PerformanceMetrics {
avg_extraction_time_ms: 3.0,
accuracy: 0.6,
memory_usage_mb: 0.02,
},
}
}
}
macro_rules! impl_placeholder_extractor {
($name:ident, $signal_type:expr, $default_value:expr) => {
struct $name;
impl $name {
fn new() -> Self {
Self
}
}
impl SignalExtractor for $name {
fn extract_signal(
&self,
_query: &str,
_document: &SearchResult,
_context: &RetrievalContext,
) -> RragResult<RelevanceSignal> {
Ok(RelevanceSignal {
signal_type: $signal_type,
value: $default_value,
confidence: 0.5,
metadata: SignalMetadata {
source: "placeholder".to_string(),
extraction_time_ms: 1,
features: HashMap::new(),
warnings: vec!["Placeholder implementation".to_string()],
},
})
}
fn signal_type(&self) -> SignalType {
$signal_type
}
fn get_config(&self) -> SignalExtractorConfig {
SignalExtractorConfig {
name: stringify!($name).to_string(),
version: "0.1".to_string(),
features: vec!["placeholder".to_string()],
performance: PerformanceMetrics {
avg_extraction_time_ms: 1.0,
accuracy: 0.5,
memory_usage_mb: 0.01,
},
}
}
}
};
}
impl_placeholder_extractor!(
DocumentAuthorityExtractor,
SignalType::DocumentAuthority,
0.5
);
impl_placeholder_extractor!(UserPreferenceExtractor, SignalType::UserPreference, 0.5);
impl_placeholder_extractor!(ClickThroughRateExtractor, SignalType::ClickThroughRate, 0.5);
impl_placeholder_extractor!(
DocumentPopularityExtractor,
SignalType::DocumentPopularity,
0.5
);
impl_placeholder_extractor!(
InteractionHistoryExtractor,
SignalType::InteractionHistory,
0.5
);
struct DomainSpecificExtractor {
domain: String,
}
impl DomainSpecificExtractor {
fn new(domain: String) -> Self {
Self { domain }
}
}
impl SignalExtractor for DomainSpecificExtractor {
fn extract_signal(
&self,
_query: &str,
_document: &SearchResult,
_context: &RetrievalContext,
) -> RragResult<RelevanceSignal> {
Ok(RelevanceSignal {
signal_type: SignalType::DomainSpecific(self.domain.clone()),
value: 0.5,
confidence: 0.5,
metadata: SignalMetadata {
source: "domain_specific".to_string(),
extraction_time_ms: 1,
features: HashMap::new(),
warnings: vec!["Placeholder implementation".to_string()],
},
})
}
fn signal_type(&self) -> SignalType {
SignalType::DomainSpecific(self.domain.clone())
}
fn get_config(&self) -> SignalExtractorConfig {
SignalExtractorConfig {
name: format!("DomainSpecificExtractor({})", self.domain),
version: "0.1".to_string(),
features: vec!["domain_analysis".to_string()],
performance: PerformanceMetrics {
avg_extraction_time_ms: 1.0,
accuracy: 0.5,
memory_usage_mb: 0.01,
},
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::SearchResult;
#[tokio::test]
async fn test_multi_signal_reranking() {
let config = MultiSignalConfig::default();
let reranker = MultiSignalReranker::new(config);
let results = vec![
SearchResult {
id: "doc1".to_string(),
content: "Machine learning is a subset of artificial intelligence that focuses on algorithms".to_string(),
score: 0.8,
rank: 0,
metadata: HashMap::new(),
embedding: None,
},
SearchResult {
id: "doc2".to_string(),
content: "AI".to_string(), score: 0.9,
rank: 1,
metadata: HashMap::new(),
embedding: None,
},
];
let query = "What is machine learning in artificial intelligence?";
let reranked_scores = reranker.rerank(query, &results).await.unwrap();
assert!(!reranked_scores.is_empty());
assert!(reranked_scores.get(&0).unwrap_or(&0.0) > &0.0);
}
#[test]
fn test_signal_normalization() {
let config = MultiSignalConfig::default();
let reranker = MultiSignalReranker::new(config);
let signals = vec![
RelevanceSignal {
signal_type: SignalType::SemanticRelevance,
value: 0.1,
confidence: 1.0,
metadata: SignalMetadata {
source: "test".to_string(),
extraction_time_ms: 0,
features: HashMap::new(),
warnings: Vec::new(),
},
},
RelevanceSignal {
signal_type: SignalType::SemanticRelevance,
value: 0.9,
confidence: 1.0,
metadata: SignalMetadata {
source: "test".to_string(),
extraction_time_ms: 0,
features: HashMap::new(),
warnings: Vec::new(),
},
},
];
let normalized = reranker.normalize_min_max(&signals);
assert_eq!(normalized[0].value, 0.0); assert_eq!(normalized[1].value, 1.0); }
}