use crate::{RragResult, SearchResult};
use std::collections::HashMap;
pub struct LearningToRankReranker {
config: LTRConfig,
model: Box<dyn LTRModel>,
feature_extractors: Vec<Box<dyn FeatureExtractor>>,
feature_cache: HashMap<String, Vec<f32>>,
}
#[derive(Debug, Clone)]
pub struct LTRConfig {
pub model_type: LTRModelType,
pub feature_config: FeatureExtractionConfig,
pub model_parameters: HashMap<String, f32>,
pub training_config: Option<TrainingConfig>,
pub enable_feature_caching: bool,
pub batch_size: usize,
}
impl Default for LTRConfig {
fn default() -> Self {
let mut model_parameters = HashMap::new();
model_parameters.insert("learning_rate".to_string(), 0.01);
model_parameters.insert("num_trees".to_string(), 100.0);
model_parameters.insert("max_depth".to_string(), 6.0);
Self {
model_type: LTRModelType::SimulatedLambdaMART,
feature_config: FeatureExtractionConfig::default(),
model_parameters,
training_config: None,
enable_feature_caching: true,
batch_size: 32,
}
}
}
#[derive(Debug, Clone, PartialEq)]
pub enum LTRModelType {
RankNet,
LambdaMART,
ListNet,
RankSVM,
Custom(String),
SimulatedLambdaMART,
}
#[derive(Debug, Clone)]
pub struct FeatureExtractionConfig {
pub enabled_features: Vec<FeatureType>,
pub normalization: FeatureNormalization,
pub max_features: usize,
pub feature_selection: FeatureSelection,
}
impl Default for FeatureExtractionConfig {
fn default() -> Self {
Self {
enabled_features: vec![
FeatureType::QueryDocumentSimilarity,
FeatureType::DocumentLength,
FeatureType::QueryTermFrequency,
FeatureType::DocumentTermFrequency,
FeatureType::InverseLinkFrequency,
],
normalization: FeatureNormalization::ZScore,
max_features: 100,
feature_selection: FeatureSelection::None,
}
}
}
#[derive(Debug, Clone)]
pub struct TrainingConfig {
pub num_iterations: usize,
pub learning_rate: f32,
pub regularization: RegularizationConfig,
pub early_stopping: EarlyStoppingConfig,
pub cv_folds: usize,
}
#[derive(Debug, Clone)]
pub struct RegularizationConfig {
pub l1_weight: f32,
pub l2_weight: f32,
pub dropout_rate: f32,
}
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
pub metric: String,
pub patience: usize,
pub min_delta: f32,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum FeatureType {
QueryDocumentSimilarity,
DocumentLength,
QueryTermFrequency,
DocumentTermFrequency,
InverseLinkFrequency,
BM25Score,
AuthorityScore,
ClickThroughRate,
ExactMatches,
PositionalFeatures,
TemporalFeatures,
Custom(String),
}
#[derive(Debug, Clone)]
pub enum FeatureNormalization {
MinMax,
ZScore,
Quantile,
None,
}
#[derive(Debug, Clone)]
pub enum FeatureSelection {
None,
TopK(usize),
Correlation(f32),
RFE,
}
#[derive(Debug, Clone)]
pub struct RankingFeature {
pub feature_type: FeatureType,
pub name: String,
pub value: f32,
pub importance: Option<f32>,
pub metadata: FeatureMetadata,
}
#[derive(Debug, Clone)]
pub struct FeatureMetadata {
pub extraction_method: String,
pub extraction_time_ms: u64,
pub confidence: f32,
pub properties: HashMap<String, f32>,
}
#[derive(Debug, Clone)]
pub struct LTRFeatures {
pub query_id: String,
pub document_id: String,
pub features: Vec<f32>,
pub feature_names: Vec<String>,
pub relevance: Option<f32>,
pub metadata: LTRFeaturesMetadata,
}
#[derive(Debug, Clone)]
pub struct LTRFeaturesMetadata {
pub extraction_time_ms: u64,
pub num_features: usize,
pub quality_score: f32,
pub warnings: Vec<String>,
}
pub trait LTRModel: Send + Sync {
fn predict(&self, features: &[Vec<f32>]) -> RragResult<Vec<f32>>;
fn predict_single(&self, features: &[f32]) -> RragResult<f32> {
let batch_result = self.predict(&[features.to_vec()])?;
Ok(batch_result.into_iter().next().unwrap_or(0.0))
}
fn train(&mut self, training_data: &[LTRTrainingExample]) -> RragResult<TrainingResult> {
let _ = training_data; Err(crate::RragError::validation(
"training",
"Training not implemented for this model",
"",
))
}
fn get_model_info(&self) -> LTRModelInfo;
fn get_feature_importance(&self) -> Option<Vec<f32>> {
None
}
}
#[derive(Debug, Clone)]
pub struct LTRTrainingExample {
pub query_id: String,
pub document_id: String,
pub features: Vec<f32>,
pub relevance: f32,
pub weight: f32,
}
#[derive(Debug, Clone)]
pub struct TrainingResult {
pub final_loss: f32,
pub validation_metrics: HashMap<String, f32>,
pub training_time_ms: u64,
pub iterations_completed: usize,
pub early_stopped: bool,
}
#[derive(Debug, Clone)]
pub struct LTRModelInfo {
pub name: String,
pub version: String,
pub num_features: usize,
pub parameters: HashMap<String, f32>,
pub is_trained: bool,
pub performance_metrics: Option<HashMap<String, f32>>,
}
pub trait FeatureExtractor: Send + Sync {
fn extract_features(
&self,
_query: &str,
document: &SearchResult,
context: &FeatureExtractionContext,
) -> RragResult<Vec<RankingFeature>>;
fn supported_features(&self) -> Vec<FeatureType>;
fn get_config(&self) -> FeatureExtractorConfig;
}
#[derive(Debug, Clone)]
pub struct FeatureExtractionContext {
pub all_documents: Vec<SearchResult>,
pub query_stats: QueryStats,
pub collection_stats: CollectionStats,
pub user_context: Option<UserContext>,
}
#[derive(Debug, Clone)]
pub struct QueryStats {
pub length: usize,
pub terms: Vec<String>,
pub query_type: Option<String>,
pub term_frequencies: HashMap<String, usize>,
}
#[derive(Debug, Clone)]
pub struct CollectionStats {
pub total_documents: usize,
pub avg_document_length: f32,
pub document_frequencies: HashMap<String, usize>,
pub vocabulary_size: usize,
}
#[derive(Debug, Clone)]
pub struct UserContext {
pub user_id: String,
pub preferences: HashMap<String, f32>,
pub interaction_history: Vec<String>,
}
#[derive(Debug, Clone)]
pub struct FeatureExtractorConfig {
pub name: String,
pub supported_features: Vec<FeatureType>,
pub performance: FeatureExtractorPerformance,
}
#[derive(Debug, Clone)]
pub struct FeatureExtractorPerformance {
pub avg_extraction_time_ms: f32,
pub memory_usage_mb: f32,
pub quality_score: f32,
}
impl LearningToRankReranker {
pub fn new(config: LTRConfig) -> Self {
let model = Self::create_model(&config.model_type, &config.model_parameters);
let feature_extractors = Self::create_feature_extractors(&config.feature_config);
Self {
config,
model,
feature_extractors,
feature_cache: HashMap::new(),
}
}
fn create_model(
model_type: <RModelType,
parameters: &HashMap<String, f32>,
) -> Box<dyn LTRModel> {
match model_type {
LTRModelType::SimulatedLambdaMART => {
Box::new(SimulatedLambdaMARTModel::new(parameters.clone()))
}
LTRModelType::LambdaMART => Box::new(SimulatedLambdaMARTModel::new(parameters.clone())),
LTRModelType::RankNet => Box::new(SimulatedRankNetModel::new()),
LTRModelType::ListNet => Box::new(SimulatedListNetModel::new()),
LTRModelType::RankSVM => Box::new(SimulatedRankSVMModel::new()),
LTRModelType::Custom(name) => Box::new(CustomLTRModel::new(name.clone())),
}
}
fn create_feature_extractors(
config: &FeatureExtractionConfig,
) -> Vec<Box<dyn FeatureExtractor>> {
let mut extractors: Vec<Box<dyn FeatureExtractor>> = Vec::new();
if config
.enabled_features
.contains(&FeatureType::QueryDocumentSimilarity)
{
extractors.push(Box::new(SimilarityFeatureExtractor::new()));
}
if config
.enabled_features
.contains(&FeatureType::DocumentLength)
{
extractors.push(Box::new(LengthFeatureExtractor::new()));
}
if config
.enabled_features
.contains(&FeatureType::QueryTermFrequency)
{
extractors.push(Box::new(TermFrequencyExtractor::new()));
}
extractors
}
pub async fn rerank(
&self,
query: &str,
results: &[SearchResult],
) -> RragResult<HashMap<usize, f32>> {
let context = FeatureExtractionContext {
all_documents: results.to_vec(),
query_stats: self.compute_query_stats(query),
collection_stats: self.compute_collection_stats(results),
user_context: None,
};
let mut feature_vectors = Vec::new();
for document in results {
let features = self.extract_document_features(query, document, &context)?;
feature_vectors.push(features);
}
let scores = self.model.predict(&feature_vectors)?;
let mut score_map = HashMap::new();
for (idx, score) in scores.into_iter().enumerate() {
score_map.insert(idx, score);
}
Ok(score_map)
}
fn extract_document_features(
&self,
query: &str,
document: &SearchResult,
context: &FeatureExtractionContext,
) -> RragResult<Vec<f32>> {
let mut all_features = Vec::new();
for extractor in &self.feature_extractors {
let features = extractor.extract_features(query, document, context)?;
for feature in features {
all_features.push(feature.value);
}
}
let normalized_features = match self.config.feature_config.normalization {
FeatureNormalization::None => all_features,
_ => self.normalize_features(all_features)?,
};
let selected_features = match self.config.feature_config.feature_selection {
FeatureSelection::None => normalized_features,
_ => self.select_features(normalized_features)?,
};
Ok(selected_features)
}
fn compute_query_stats(&self, query: &str) -> QueryStats {
let terms: Vec<String> = query.split_whitespace().map(|s| s.to_lowercase()).collect();
let mut term_frequencies = HashMap::new();
for term in &terms {
*term_frequencies.entry(term.clone()).or_insert(0) += 1;
}
QueryStats {
length: terms.len(),
terms,
query_type: None, term_frequencies,
}
}
fn compute_collection_stats(&self, documents: &[SearchResult]) -> CollectionStats {
let total_documents = documents.len();
let total_length: usize = documents
.iter()
.map(|d| d.content.split_whitespace().count())
.sum();
let avg_document_length = if total_documents > 0 {
total_length as f32 / total_documents as f32
} else {
0.0
};
let mut document_frequencies = HashMap::new();
let mut vocabulary = std::collections::HashSet::new();
for document in documents {
let terms: std::collections::HashSet<String> = document
.content
.split_whitespace()
.map(|s| s.to_lowercase())
.collect();
for term in &terms {
*document_frequencies.entry(term.clone()).or_insert(0) += 1;
vocabulary.insert(term.clone());
}
}
CollectionStats {
total_documents,
avg_document_length,
document_frequencies,
vocabulary_size: vocabulary.len(),
}
}
fn normalize_features(&self, features: Vec<f32>) -> RragResult<Vec<f32>> {
match self.config.feature_config.normalization {
FeatureNormalization::MinMax => {
let min_val = features.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_val = features.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let range = max_val - min_val;
if range == 0.0 {
Ok(features) } else {
Ok(features
.into_iter()
.map(|f| (f - min_val) / range)
.collect())
}
}
FeatureNormalization::ZScore => {
let mean = features.iter().sum::<f32>() / features.len() as f32;
let variance = features.iter().map(|f| (f - mean).powi(2)).sum::<f32>()
/ features.len() as f32;
let std_dev = variance.sqrt();
if std_dev == 0.0 {
Ok(features)
} else {
Ok(features.into_iter().map(|f| (f - mean) / std_dev).collect())
}
}
_ => Ok(features), }
}
fn select_features(&self, features: Vec<f32>) -> RragResult<Vec<f32>> {
match self.config.feature_config.feature_selection {
FeatureSelection::TopK(k) => {
Ok(features.into_iter().take(k).collect())
}
_ => Ok(features), }
}
}
struct SimulatedLambdaMARTModel {
parameters: HashMap<String, f32>,
num_trees: usize,
}
impl SimulatedLambdaMARTModel {
fn new(parameters: HashMap<String, f32>) -> Self {
let num_trees = parameters.get("num_trees").copied().unwrap_or(100.0) as usize;
Self {
parameters,
num_trees,
}
}
}
impl LTRModel for SimulatedLambdaMARTModel {
fn predict(&self, features: &[Vec<f32>]) -> RragResult<Vec<f32>> {
let mut scores = Vec::new();
for feature_vector in features {
let mut score = 0.0;
for tree_idx in 0..self.num_trees {
let tree_score = feature_vector
.iter()
.enumerate()
.map(|(i, &f)| f * (0.1 + 0.01 * (tree_idx + i) as f32).sin())
.sum::<f32>()
/ feature_vector.len() as f32;
score += tree_score * 0.01; }
scores.push(1.0 / (1.0 + (-score).exp()));
}
Ok(scores)
}
fn get_model_info(&self) -> LTRModelInfo {
LTRModelInfo {
name: "SimulatedLambdaMART".to_string(),
version: "1.0".to_string(),
num_features: 0, parameters: self.parameters.clone(),
is_trained: true,
performance_metrics: None,
}
}
fn get_feature_importance(&self) -> Option<Vec<f32>> {
Some(vec![0.3, 0.25, 0.2, 0.15, 0.1]) }
}
macro_rules! impl_mock_ltr_model {
($name:ident) => {
struct $name;
impl $name {
fn new() -> Self {
Self
}
}
impl LTRModel for $name {
fn predict(&self, features: &[Vec<f32>]) -> RragResult<Vec<f32>> {
Ok(features
.iter()
.map(|f| f.iter().sum::<f32>() / f.len() as f32)
.map(|s| 1.0 / (1.0 + (-s).exp())) .collect())
}
fn get_model_info(&self) -> LTRModelInfo {
LTRModelInfo {
name: stringify!($name).to_string(),
version: "1.0".to_string(),
num_features: 0,
parameters: HashMap::new(),
is_trained: false,
performance_metrics: None,
}
}
}
};
}
impl_mock_ltr_model!(SimulatedRankNetModel);
impl_mock_ltr_model!(SimulatedListNetModel);
impl_mock_ltr_model!(SimulatedRankSVMModel);
struct CustomLTRModel {
name: String,
}
impl CustomLTRModel {
fn new(name: String) -> Self {
Self { name }
}
}
impl LTRModel for CustomLTRModel {
fn predict(&self, features: &[Vec<f32>]) -> RragResult<Vec<f32>> {
Ok(vec![0.5; features.len()]) }
fn get_model_info(&self) -> LTRModelInfo {
LTRModelInfo {
name: self.name.clone(),
version: "custom".to_string(),
num_features: 0,
parameters: HashMap::new(),
is_trained: false,
performance_metrics: None,
}
}
}
struct SimilarityFeatureExtractor;
impl SimilarityFeatureExtractor {
fn new() -> Self {
Self
}
}
impl FeatureExtractor for SimilarityFeatureExtractor {
fn extract_features(
&self,
_query: &str,
document: &SearchResult,
_context: &FeatureExtractionContext,
) -> RragResult<Vec<RankingFeature>> {
let similarity = document.score;
Ok(vec![RankingFeature {
feature_type: FeatureType::QueryDocumentSimilarity,
name: "cosine_similarity".to_string(),
value: similarity,
importance: Some(0.8),
metadata: FeatureMetadata {
extraction_method: "vector_similarity".to_string(),
extraction_time_ms: 1,
confidence: 0.9,
properties: HashMap::new(),
},
}])
}
fn supported_features(&self) -> Vec<FeatureType> {
vec![FeatureType::QueryDocumentSimilarity]
}
fn get_config(&self) -> FeatureExtractorConfig {
FeatureExtractorConfig {
name: "SimilarityFeatureExtractor".to_string(),
supported_features: self.supported_features(),
performance: FeatureExtractorPerformance {
avg_extraction_time_ms: 1.0,
memory_usage_mb: 0.1,
quality_score: 0.9,
},
}
}
}
struct LengthFeatureExtractor;
impl LengthFeatureExtractor {
fn new() -> Self {
Self
}
}
impl FeatureExtractor for LengthFeatureExtractor {
fn extract_features(
&self,
_query: &str,
document: &SearchResult,
context: &FeatureExtractionContext,
) -> RragResult<Vec<RankingFeature>> {
let doc_length = document.content.split_whitespace().count() as f32;
let normalized_length = doc_length / context.collection_stats.avg_document_length;
Ok(vec![
RankingFeature {
feature_type: FeatureType::DocumentLength,
name: "document_length".to_string(),
value: doc_length,
importance: Some(0.3),
metadata: FeatureMetadata {
extraction_method: "word_count".to_string(),
extraction_time_ms: 1,
confidence: 1.0,
properties: HashMap::new(),
},
},
RankingFeature {
feature_type: FeatureType::DocumentLength,
name: "normalized_document_length".to_string(),
value: normalized_length,
importance: Some(0.4),
metadata: FeatureMetadata {
extraction_method: "normalized_word_count".to_string(),
extraction_time_ms: 1,
confidence: 1.0,
properties: HashMap::new(),
},
},
])
}
fn supported_features(&self) -> Vec<FeatureType> {
vec![FeatureType::DocumentLength]
}
fn get_config(&self) -> FeatureExtractorConfig {
FeatureExtractorConfig {
name: "LengthFeatureExtractor".to_string(),
supported_features: self.supported_features(),
performance: FeatureExtractorPerformance {
avg_extraction_time_ms: 1.0,
memory_usage_mb: 0.01,
quality_score: 1.0,
},
}
}
}
struct TermFrequencyExtractor;
impl TermFrequencyExtractor {
fn new() -> Self {
Self
}
}
impl FeatureExtractor for TermFrequencyExtractor {
fn extract_features(
&self,
_query: &str,
document: &SearchResult,
context: &FeatureExtractionContext,
) -> RragResult<Vec<RankingFeature>> {
let mut features = Vec::new();
let doc_terms: std::collections::HashMap<String, usize> = {
let mut map = std::collections::HashMap::new();
for term in document.content.split_whitespace() {
let term = term.to_lowercase();
*map.entry(term).or_insert(0) += 1;
}
map
};
let mut total_qtf = 0.0;
let mut matched_terms = 0;
for query_term in &context.query_stats.terms {
if let Some(&tf) = doc_terms.get(query_term) {
total_qtf += tf as f32;
matched_terms += 1;
}
}
features.push(RankingFeature {
feature_type: FeatureType::QueryTermFrequency,
name: "total_query_term_frequency".to_string(),
value: total_qtf,
importance: Some(0.6),
metadata: FeatureMetadata {
extraction_method: "term_counting".to_string(),
extraction_time_ms: 2,
confidence: 0.9,
properties: HashMap::new(),
},
});
features.push(RankingFeature {
feature_type: FeatureType::QueryTermFrequency,
name: "query_term_coverage".to_string(),
value: matched_terms as f32 / context.query_stats.terms.len() as f32,
importance: Some(0.7),
metadata: FeatureMetadata {
extraction_method: "coverage_calculation".to_string(),
extraction_time_ms: 1,
confidence: 1.0,
properties: HashMap::new(),
},
});
Ok(features)
}
fn supported_features(&self) -> Vec<FeatureType> {
vec![FeatureType::QueryTermFrequency]
}
fn get_config(&self) -> FeatureExtractorConfig {
FeatureExtractorConfig {
name: "TermFrequencyExtractor".to_string(),
supported_features: self.supported_features(),
performance: FeatureExtractorPerformance {
avg_extraction_time_ms: 3.0,
memory_usage_mb: 0.05,
quality_score: 0.8,
},
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::SearchResult;
#[tokio::test]
async fn test_ltr_reranking() {
let config = LTRConfig::default();
let reranker = LearningToRankReranker::new(config);
let results = vec![
SearchResult {
id: "doc1".to_string(),
content: "Machine learning is a subset of artificial intelligence that enables computers to learn".to_string(),
score: 0.8,
rank: 0,
metadata: HashMap::new(),
embedding: None,
},
SearchResult {
id: "doc2".to_string(),
content: "AI and ML".to_string(),
score: 0.6,
rank: 1,
metadata: HashMap::new(),
embedding: None,
},
];
let query = "machine learning artificial intelligence";
let reranked_scores = reranker.rerank(query, &results).await.unwrap();
assert!(!reranked_scores.is_empty());
assert!(reranked_scores.contains_key(&0));
assert!(reranked_scores.contains_key(&1));
}
#[test]
fn test_feature_extraction() {
let extractor = SimilarityFeatureExtractor::new();
let context = FeatureExtractionContext {
all_documents: vec![],
query_stats: QueryStats {
length: 2,
terms: vec!["test".to_string(), "query".to_string()],
query_type: None,
term_frequencies: HashMap::new(),
},
collection_stats: CollectionStats {
total_documents: 1,
avg_document_length: 10.0,
document_frequencies: HashMap::new(),
vocabulary_size: 100,
},
user_context: None,
};
let document = SearchResult {
id: "test_doc".to_string(),
content: "test document content".to_string(),
score: 0.7,
rank: 0,
metadata: HashMap::new(),
embedding: None,
};
let features = extractor
.extract_features("test query", &document, &context)
.unwrap();
assert!(!features.is_empty());
assert_eq!(
features[0].feature_type,
FeatureType::QueryDocumentSimilarity
);
assert_eq!(features[0].value, 0.7);
}
}