//! Reranking utilities: MaxSim score explanations, batch reranking over
//! candidate embeddings, and fine-grained (0-10) score quantization.
use super::simd;
/// Per-token breakdown of a MaxSim score between a query and a document.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MaxSimExplanation {
    /// Sum of all per-token contributions.
    pub total_score: f32,
    /// One entry per query token, pairing it with its best-matching doc token.
    pub token_contributions: Vec<TokenMatch>,
    /// Query token strings, if supplied by the caller.
    pub query_token_texts: Option<Vec<String>>,
    /// Document token strings, if supplied by the caller.
    pub doc_token_texts: Option<Vec<String>>,
}
/// A single query-token-to-document-token match within a MaxSim score.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TokenMatch {
    /// Index of the query token.
    pub query_token_idx: usize,
    /// Index of the best-matching document token.
    pub best_doc_token_idx: usize,
    /// Similarity between the two tokens.
    pub similarity: f32,
    /// Query token text, if provided.
    pub query_token_text: Option<String>,
    /// Document token text, if provided.
    pub doc_token_text: Option<String>,
    /// Contribution to the total score. Currently equal to `similarity`,
    /// since no per-token weighting is applied.
    pub contribution: f32,
}
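/// Computes a MaxSim score and returns a per-query-token explanation.
///
/// Each query token is matched to its most similar document token (dot
/// product by default, cosine when `use_cosine` is true); the total score
/// is the sum of those best-match similarities. Token texts, when given,
/// are attached to the result for display.
///
/// A minimal usage sketch (the surrounding crate path is assumed, so the
/// example is marked `ignore` rather than run as a doctest):
///
/// ```ignore
/// let query = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
/// let doc = vec![vec![0.9, 0.1], vec![0.1, 0.9]];
/// let exp = maxsim_explained(&query, &doc, Some(&["cat", "sat"]), None, false);
/// for m in &exp.token_contributions {
///     println!("q{} -> d{}: {:.3}", m.query_token_idx, m.best_doc_token_idx, m.similarity);
/// }
/// ```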
pub fn maxsim_explained(
query_tokens: &[Vec<f32>],
doc_tokens: &[Vec<f32>],
query_texts: Option<&[&str]>,
doc_texts: Option<&[&str]>,
use_cosine: bool,
) -> MaxSimExplanation {
if query_tokens.is_empty() || doc_tokens.is_empty() {
return MaxSimExplanation {
total_score: 0.0,
token_contributions: Vec::new(),
query_token_texts: query_texts.map(|t| t.iter().map(|s| s.to_string()).collect()),
doc_token_texts: doc_texts.map(|t| t.iter().map(|s| s.to_string()).collect()),
};
}
let query_refs: Vec<&[f32]> = query_tokens.iter().map(|v| v.as_slice()).collect();
let doc_refs: Vec<&[f32]> = doc_tokens.iter().map(|v| v.as_slice()).collect();
let alignments = if use_cosine {
simd::maxsim_alignments_cosine(&query_refs, &doc_refs)
} else {
simd::maxsim_alignments(&query_refs, &doc_refs)
};
let token_contributions: Vec<TokenMatch> = alignments
.into_iter()
.map(|(q_idx, d_idx, similarity)| TokenMatch {
query_token_idx: q_idx,
best_doc_token_idx: d_idx,
similarity,
query_token_text: query_texts.and_then(|texts| texts.get(q_idx).map(|s| s.to_string())),
doc_token_text: doc_texts.and_then(|texts| texts.get(d_idx).map(|s| s.to_string())),
contribution: similarity,
})
.collect();
let total_score: f32 = token_contributions.iter().map(|m| m.contribution).sum();
MaxSimExplanation {
total_score,
token_contributions,
query_token_texts: query_texts.map(|t| t.iter().map(|s| s.to_string()).collect()),
doc_token_texts: doc_texts.map(|t| t.iter().map(|s| s.to_string()).collect()),
}
}
/// Input to the rerankers: optional query representations plus candidates.
#[derive(Debug, Clone)]
pub struct RerankerInput<'a, K> {
    /// Dense query embedding, required by [`RerankMethod::DenseCosine`].
    pub query_dense: Option<&'a [f32]>,
    /// Query token embeddings, required by the MaxSim methods.
    pub query_tokens: Option<&'a [Vec<f32>]>,
    pub candidates: Vec<Candidate<'a, K>>,
}
/// A candidate document to be rescored.
#[derive(Debug, Clone)]
pub struct Candidate<'a, K> {
    pub id: K,
    /// First-stage retrieval score; used as a fallback whenever the
    /// embeddings a method needs are missing.
    pub original_score: f32,
    pub dense_embedding: Option<&'a [f32]>,
    pub token_embeddings: Option<&'a [Vec<f32>]>,
    pub text: Option<&'a str>,
}
/// Scoring strategy used by the rerankers.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RerankMethod {
    /// Cosine similarity between dense embeddings.
    DenseCosine,
    /// MaxSim over token embeddings (dot product).
    MaxSim,
    /// MaxSim over token embeddings (cosine similarity).
    MaxSimCosine,
    /// Weighted MaxSim. `RerankerInput` carries no per-token weights, so
    /// both rerankers currently fall back to the original score here.
    MaxSimWeighted,
}
/// A reranked candidate with its new score and 0-based rank.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct RankedResult<K> {
    pub id: K,
pub score: f32,
pub original_score: f32,
pub rank: usize,
}
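/// Configuration for [`rerank_fine_grained`]: the expected similarity
/// bounds used for min-max normalization, plus optional softmax weighting.
///
/// A small configuration sketch (values are illustrative):
///
/// ```ignore
/// let config = FineGrainedConfig::new(0.0, 1.0)
///     .with_temperature(0.5)
///     .without_weighting();
/// ```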
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FineGrainedConfig {
    /// Lower bound of expected similarity scores (normalization floor).
    pub min_score: f32,
    /// Upper bound of expected similarity scores (normalization ceiling).
    pub max_score: f32,
    /// If `true`, blend a batch softmax weight into each normalized score.
    pub use_probability_weighting: bool,
    /// Softmax temperature; lower values sharpen the weighting.
    pub temperature: f32,
}
impl Default for FineGrainedConfig {
fn default() -> Self {
Self {
min_score: -1.0,
max_score: 1.0,
use_probability_weighting: true,
temperature: 1.0,
}
}
}
impl FineGrainedConfig {
pub const fn new(min_score: f32, max_score: f32) -> Self {
Self {
min_score,
max_score,
use_probability_weighting: true,
temperature: 1.0,
}
}
pub const fn without_weighting(mut self) -> Self {
self.use_probability_weighting = false;
self
}
pub const fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
}
/// A reranked candidate with a quantized 0-10 score.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FineGrainedResult<K> {
    pub id: K,
    /// Quantized score on an integer 0..=10 scale.
    pub fine_score: u8,
    /// Raw similarity score produced by the rerank method.
    pub similarity_score: f32,
pub original_score: f32,
pub rank: usize,
}
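/// Reranks candidates and quantizes each score onto an integer 0-10 scale.
///
/// Pipeline: score each candidate with `method`, sort descending, min-max
/// normalize into [0, 1] using `config`'s bounds, optionally blend in a
/// batch softmax weight, round to 0..=10, and truncate to `top_k`.
///
/// A usage sketch (candidate construction as in the tests below; values
/// are illustrative):
///
/// ```ignore
/// let results = rerank_fine_grained(
///     input,
///     RerankMethod::DenseCosine,
///     FineGrainedConfig::default(),
///     5,
/// );
/// for r in &results {
///     println!("rank {} scored {}/10", r.rank, r.fine_score);
/// }
/// ```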
pub fn rerank_fine_grained<'a, K: Clone>(
input: RerankerInput<'a, K>,
method: RerankMethod,
config: FineGrainedConfig,
top_k: usize,
) -> Vec<FineGrainedResult<K>> {
let mut results: Vec<(K, f32, f32)> = input
.candidates
.into_iter()
.map(|candidate| {
let score = match method {
RerankMethod::DenseCosine => {
if let (Some(q), Some(d)) = (input.query_dense, candidate.dense_embedding) {
simd::cosine(q, d)
} else {
candidate.original_score
}
}
RerankMethod::MaxSim => {
if let (Some(q_tokens), Some(d_tokens)) =
(input.query_tokens, candidate.token_embeddings)
{
simd::maxsim_vecs(q_tokens, d_tokens)
} else {
candidate.original_score
}
}
RerankMethod::MaxSimCosine => {
if let (Some(q_tokens), Some(d_tokens)) =
(input.query_tokens, candidate.token_embeddings)
{
simd::maxsim_cosine_vecs(q_tokens, d_tokens)
} else {
candidate.original_score
}
}
                RerankMethod::MaxSimWeighted => {
                    // No per-token weights are carried by `RerankerInput`,
                    // so fall back to the original score.
                    candidate.original_score
                }
};
(candidate.id, score, candidate.original_score)
})
.collect();
results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
if results.is_empty() {
return Vec::new();
}
    // Normalize into [0, 1] using the configured score bounds rather than the
    // observed batch min/max, keeping quantized scores comparable across batches.
    let score_range = config.max_score - config.min_score;
let normalized: Vec<(K, f32, f32, f32)> = if score_range > 1e-9 {
results
.into_iter()
.map(|(id, sim, orig)| {
let clamped = sim.clamp(config.min_score, config.max_score);
let norm = (clamped - config.min_score) / score_range;
(id, sim, orig, norm)
})
.collect()
} else {
results
.into_iter()
.map(|(id, sim, orig)| (id, sim, orig, 0.5))
.collect()
};
    // Optionally blend each normalized score with its softmax weight within
    // the batch (temperature-controlled), accentuating scores that stand out.
    let weighted: Vec<(K, f32, f32, f32)> =
        if config.use_probability_weighting && normalized.len() > 1 {
let exp_scores: Vec<f32> = normalized
.iter()
.map(|(_, _, _, norm)| (norm / config.temperature).exp())
.collect();
let sum_exp: f32 = exp_scores.iter().sum();
normalized
.into_iter()
.zip(exp_scores)
.map(|((id, sim, orig, norm), exp)| {
let weight = exp / sum_exp;
                    // Fixed 70/30 blend of the normalized score and its softmax weight.
                    let weighted_norm = 0.7 * norm + 0.3 * weight;
(id, sim, orig, weighted_norm)
})
.collect()
} else {
normalized
};
let mut fine_results: Vec<FineGrainedResult<K>> = weighted
.into_iter()
.enumerate()
        .map(|(rank, (id, sim, orig, norm))| {
            // Quantize the normalized score onto an integer 0..=10 scale.
            let fine_score = (norm * 10.0).round().clamp(0.0, 10.0) as u8;
FineGrainedResult {
id,
fine_score,
similarity_score: sim,
original_score: orig,
rank,
}
})
.collect();
fine_results.truncate(top_k);
fine_results
}
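/// Reranks candidates with `method`, sorts descending, assigns 0-based
/// ranks, and truncates to `top_k`. A candidate missing the embeddings a
/// method requires keeps its original score.
///
/// A usage sketch (input construction as in the tests below):
///
/// ```ignore
/// let results = rerank_batch(input, RerankMethod::MaxSim, 10);
/// assert!(results.windows(2).all(|w| w[0].score >= w[1].score));
/// ```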
pub fn rerank_batch<'a, K: Clone>(
input: RerankerInput<'a, K>,
method: RerankMethod,
top_k: usize,
) -> Vec<RankedResult<K>> {
let mut results: Vec<RankedResult<K>> = input
.candidates
.into_iter()
.map(|candidate| {
let score = match method {
RerankMethod::DenseCosine => {
if let (Some(q), Some(d)) = (input.query_dense, candidate.dense_embedding) {
simd::cosine(q, d)
} else {
candidate.original_score
}
}
RerankMethod::MaxSim => {
if let (Some(q_tokens), Some(d_tokens)) =
(input.query_tokens, candidate.token_embeddings)
{
simd::maxsim_vecs(q_tokens, d_tokens)
} else {
candidate.original_score
}
}
RerankMethod::MaxSimCosine => {
if let (Some(q_tokens), Some(d_tokens)) =
(input.query_tokens, candidate.token_embeddings)
{
simd::maxsim_cosine_vecs(q_tokens, d_tokens)
} else {
candidate.original_score
}
}
                RerankMethod::MaxSimWeighted => {
                    // As in `rerank_fine_grained`: no per-token weights are
                    // available, so fall back to the original score.
                    candidate.original_score
                }
};
            RankedResult {
                id: candidate.id,
                score,
                original_score: candidate.original_score,
                rank: 0, // placeholder; real rank is assigned after sorting
            }
})
.collect();
results.sort_unstable_by(|a, b| b.score.total_cmp(&a.score));
for (rank, result) in results.iter_mut().enumerate() {
result.rank = rank;
}
results.truncate(top_k);
results
}
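/// Helpers for producing per-token weights (IDF-based, attention-based, or
/// learned) for weighted MaxSim-style scoring.
///
/// Combining sources is a typical pattern; a sketch (identifiers are
/// illustrative):
///
/// ```ignore
/// let idf = weights::idf_weights(&token_ids, &idf_table, 1.0);
/// let attn = weights::attention_weights(&attention_scores);
/// let combined: Vec<f32> = idf.iter().zip(&attn).map(|(i, a)| i * a).collect();
/// ```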
pub mod weights {
    use std::collections::HashMap;
    /// Looks up an IDF weight for each token id, falling back to
    /// `default_idf` for ids missing from the table.
    pub fn idf_weights(
        token_ids: &[u32],
        idf_table: &HashMap<u32, f32>,
        default_idf: f32,
    ) -> Vec<f32> {
token_ids
.iter()
.map(|&id| idf_table.get(&id).copied().unwrap_or(default_idf))
.collect()
}
    /// Normalizes raw attention scores so they sum to 1; an all-(near-)zero
    /// input yields a uniform distribution instead of dividing by zero.
    pub fn attention_weights(attention_scores: &[f32]) -> Vec<f32> {
if attention_scores.is_empty() {
return Vec::new();
}
let sum: f32 = attention_scores.iter().sum();
if sum.abs() < 1e-9 {
return vec![1.0 / attention_scores.len() as f32; attention_scores.len()];
}
attention_scores.iter().map(|&s| s / sum).collect()
}
    /// Loads one `f32` per non-empty line from a plain-text weights file.
    pub fn load_learned_weights(path: &std::path::Path) -> std::io::Result<Vec<f32>> {
let content = std::fs::read_to_string(path)?;
let weights: Result<Vec<f32>, _> = content
.lines()
.filter(|l| !l.trim().is_empty())
.map(|l| l.trim().parse::<f32>())
.collect();
weights.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn maxsim_explained_basic() {
let query = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let doc = vec![vec![0.9, 0.1], vec![0.1, 0.9]];
let explanation = maxsim_explained(&query, &doc, None, None, false);
assert_eq!(explanation.token_contributions.len(), 2);
assert!((explanation.total_score - 1.8).abs() < 0.1);
}
#[test]
fn maxsim_explained_with_texts() {
let query = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let doc = vec![vec![0.9, 0.1], vec![0.1, 0.9]];
let explanation = maxsim_explained(
&query,
&doc,
Some(&["capital", "France"]),
Some(&["Paris", "capital", "France"]),
false,
);
assert_eq!(explanation.token_contributions.len(), 2);
}
    #[test]
    fn maxsim_explained_uses_or_not_and() {
        // The empty-input guard must fire when *either* side is empty
        // (`||`, not `&&`); the first three cases below should yield zero.
let empty_query: Vec<Vec<f32>> = vec![];
let non_empty_doc = vec![vec![1.0, 0.0]];
let explanation1 = maxsim_explained(&empty_query, &non_empty_doc, None, None, false);
assert_eq!(explanation1.total_score, 0.0);
assert_eq!(explanation1.token_contributions.len(), 0);
let non_empty_query = vec![vec![1.0, 0.0]];
let empty_doc: Vec<Vec<f32>> = vec![];
let explanation2 = maxsim_explained(&non_empty_query, &empty_doc, None, None, false);
assert_eq!(explanation2.total_score, 0.0);
assert_eq!(explanation2.token_contributions.len(), 0);
let explanation3 = maxsim_explained(&empty_query, &empty_doc, None, None, false);
assert_eq!(explanation3.total_score, 0.0);
assert_eq!(explanation3.token_contributions.len(), 0);
let explanation4 = maxsim_explained(&non_empty_query, &non_empty_doc, None, None, false);
assert!(explanation4.total_score != 0.0);
assert_eq!(explanation4.token_contributions.len(), 1);
}
#[test]
fn rerank_batch_maxsim() {
let query_tokens = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
let doc1_tokens = vec![vec![0.9, 0.1], vec![0.1, 0.9]];
let doc2_tokens = vec![vec![0.5, 0.5]];
let candidates = vec![
Candidate {
id: "doc1",
original_score: 0.8,
dense_embedding: None,
token_embeddings: Some(&doc1_tokens),
text: None,
},
Candidate {
id: "doc2",
original_score: 0.7,
dense_embedding: None,
token_embeddings: Some(&doc2_tokens),
text: None,
},
];
let input = RerankerInput {
query_dense: None,
query_tokens: Some(&query_tokens),
candidates,
};
let results = rerank_batch(input, RerankMethod::MaxSim, 10);
assert_eq!(results.len(), 2);
assert!(results[0].score >= results[1].score);
}
#[test]
fn weights_idf() {
use std::collections::HashMap;
let idf_table = HashMap::from([(100, 2.0), (200, 0.5)]);
let weights = weights::idf_weights(&[100, 200], &idf_table, 1.0);
assert_eq!(weights.len(), 2);
assert!(weights[0] > weights[1]);
}
#[test]
fn weights_attention() {
let attention = vec![0.1, 0.3, 0.6];
let weights = weights::attention_weights(&attention);
let sum: f32 = weights.iter().sum();
assert!((sum - 1.0).abs() < 1e-6);
}
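    #[test]
    fn rerank_fine_grained_quantizes_to_0_10() {
        // Coverage sketch for the fine-grained reranker; embeddings and
        // scores below are illustrative.
        let query = vec![1.0_f32, 0.0];
        let good = vec![1.0_f32, 0.0];
        let bad = vec![0.0_f32, 1.0];
        let candidates = vec![
            Candidate {
                id: "good",
                original_score: 0.0,
                dense_embedding: Some(&good),
                token_embeddings: None,
                text: None,
            },
            Candidate {
                id: "bad",
                original_score: 0.0,
                dense_embedding: Some(&bad),
                token_embeddings: None,
                text: None,
            },
        ];
        let input = RerankerInput {
            query_dense: Some(&query),
            query_tokens: None,
            candidates,
        };
        let results = rerank_fine_grained(
            input,
            RerankMethod::DenseCosine,
            FineGrainedConfig::default(),
            10,
        );
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "good");
        assert!(results[0].fine_score >= results[1].fine_score);
    }
    #[test]
    fn fine_grained_config_builders() {
        let config = FineGrainedConfig::new(0.0, 1.0)
            .with_temperature(0.5)
            .without_weighting();
        assert!(!config.use_probability_weighting);
        assert!((config.temperature - 0.5).abs() < 1e-6);
    }
    #[test]
    fn weights_load_learned() {
        // Round-trip sketch using the OS temp dir (std-only; the fixed
        // file name is an assumption and could collide across runs).
        let path = std::env::temp_dir().join("maxsim_learned_weights_test.txt");
        std::fs::write(&path, "0.5\n1.5\n\n2.0\n").unwrap();
        let weights = weights::load_learned_weights(&path).unwrap();
        let _ = std::fs::remove_file(&path);
        assert_eq!(weights, vec![0.5, 1.5, 2.0]);
    }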
}