/// Re-ranks retrieval candidates by blending a softmax attention weight
/// (scaled dot product between the query and each candidate embedding) with
/// the similarity score each candidate already carries.
///
/// On `wasm32` targets the attention step is skipped and candidates are
/// ordered by their supplied similarity score alone.
pub struct AttentionReranker {
    /// Embedding dimension; `sqrt(dim)` scales the query-key dot products.
    dim: usize,
    /// Stored by [`AttentionReranker::new`] but not read by the scoring code,
    /// hence the lint allowance.
    #[allow(dead_code)]
    num_heads: usize,
}

impl AttentionReranker {
    /// Share of the final score contributed by the softmax attention weight.
    #[cfg(not(target_arch = "wasm32"))]
    const ATTENTION_WEIGHT: f32 = 0.6;

    /// Share of the final score contributed by the caller-supplied similarity.
    #[cfg(not(target_arch = "wasm32"))]
    const SIMILARITY_WEIGHT: f32 = 0.4;

    /// Creates a reranker for embeddings of dimension `dim`.
    ///
    /// `num_heads` is stored but does not currently influence scoring.
    pub fn new(dim: usize, num_heads: usize) -> Self {
        Self { dim, num_heads }
    }

    /// Re-ranks `results` against `query_embedding` and returns the best
    /// `top_k` entries as `(id, score)` pairs, highest score first.
    ///
    /// Each element of `results` is `(id, similarity, embedding)`. On
    /// non-wasm targets the returned score is
    /// `0.6 * attention_weight + 0.4 * similarity`, where the attention
    /// weights are a softmax over all candidates; on `wasm32` the returned
    /// score is the supplied similarity unchanged.
    ///
    /// Returns an empty vector when `results` is empty or `top_k` is zero.
    /// Entries whose score is NaN are ordered after every other entry.
    /// If the query and a candidate embedding differ in length, the dot
    /// product only covers the shorter of the two.
    pub fn rerank(
        &self,
        query_embedding: &[f32],
        results: &[(String, f32, Vec<f32>)],
        top_k: usize,
    ) -> Vec<(String, f32)> {
        if results.is_empty() || top_k == 0 {
            return Vec::new();
        }
        #[cfg(not(target_arch = "wasm32"))]
        {
            self.rerank_native(query_embedding, results, top_k)
        }
        #[cfg(target_arch = "wasm32")]
        {
            // The fallback does not look at the query; discard it explicitly
            // so the parameter is not reported as unused on this target.
            let _ = query_embedding;
            self.rerank_wasm(results, top_k)
        }
    }

    /// Softmax-attention scoring path used on every target except `wasm32`.
    #[cfg(not(target_arch = "wasm32"))]
    fn rerank_native(
        &self,
        query_embedding: &[f32],
        results: &[(String, f32, Vec<f32>)],
        top_k: usize,
    ) -> Vec<(String, f32)> {
        // `max(1)` keeps the scale non-zero: with `dim == 0` the division
        // below would otherwise turn every logit (and every score) into NaN.
        let scale = (self.dim.max(1) as f32).sqrt();

        let logits: Vec<f32> = results
            .iter()
            .map(|(_, _, embedding)| {
                query_embedding
                    .iter()
                    .zip(embedding)
                    .map(|(q, k)| q * k)
                    .sum::<f32>()
                    / scale
            })
            .collect();

        // Subtract the maximum before exponentiating so `exp` cannot overflow.
        let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_logits: Vec<f32> = logits
            .iter()
            .map(|logit| (logit - max_logit).exp())
            .collect();
        let exp_sum: f32 = exp_logits.iter().sum();

        let scored: Vec<(&str, f32)> = results
            .iter()
            .zip(&exp_logits)
            .map(|((id, similarity, _), exp_logit)| {
                let attention = exp_logit / exp_sum;
                let blended =
                    Self::ATTENTION_WEIGHT * attention + Self::SIMILARITY_WEIGHT * similarity;
                (id.as_str(), blended)
            })
            .collect();

        Self::take_top_k(scored, top_k)
    }

    /// Fallback path for `wasm32`: orders candidates by their supplied
    /// similarity score without computing attention weights.
    #[cfg(target_arch = "wasm32")]
    fn rerank_wasm(
        &self,
        results: &[(String, f32, Vec<f32>)],
        top_k: usize,
    ) -> Vec<(String, f32)> {
        let scored: Vec<(&str, f32)> = results
            .iter()
            .map(|(id, similarity, _)| (id.as_str(), *similarity))
            .collect();
        Self::take_top_k(scored, top_k)
    }

    /// Sorts `scored` by descending score (stable, NaN last) and returns
    /// owned copies of the first `top_k` entries.
    ///
    /// Ids are borrowed while sorting so that only the surviving entries pay
    /// for a `String` allocation.
    fn take_top_k(mut scored: Vec<(&str, f32)>, top_k: usize) -> Vec<(String, f32)> {
        scored.sort_by(|a, b| Self::cmp_score_desc(a.1, b.1));
        scored
            .into_iter()
            .take(top_k)
            .map(|(id, score)| (id.to_owned(), score))
            .collect()
    }

    /// Descending comparison of two scores that is a valid total preorder:
    /// NaN compares equal to NaN and after every non-NaN value.
    ///
    /// `sort_by` requires a total order; mapping "incomparable" to `Equal`
    /// (as a bare `partial_cmp(..).unwrap_or(Equal)` does) violates that
    /// whenever a NaN is present.
    fn cmp_score_desc(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither side is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty candidate list produces an empty ranking.
    #[test]
    fn test_reranker_empty_results() {
        let reranker = AttentionReranker::new(4, 1);
        let result = reranker.rerank(&[1.0, 0.0, 0.0, 0.0], &[], 5);
        assert!(result.is_empty());
    }

    /// A single candidate is returned as-is even when `top_k` is larger.
    #[test]
    fn test_reranker_single_result() {
        let reranker = AttentionReranker::new(4, 1);
        let results = vec![("a".to_string(), 0.9, vec![1.0, 0.0, 0.0, 0.0])];
        let ranked = reranker.rerank(&[1.0, 0.0, 0.0, 0.0], &results, 5);
        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].0, "a");
    }

    /// The ranking is truncated to `top_k` entries.
    #[test]
    fn test_reranker_respects_top_k() {
        let reranker = AttentionReranker::new(4, 1);
        let results = vec![
            ("a".to_string(), 0.9, vec![1.0, 0.0, 0.0, 0.0]),
            ("b".to_string(), 0.8, vec![0.0, 1.0, 0.0, 0.0]),
            ("c".to_string(), 0.7, vec![0.0, 0.0, 1.0, 0.0]),
        ];
        let ranked = reranker.rerank(&[1.0, 0.0, 0.0, 0.0], &results, 2);
        assert_eq!(ranked.len(), 2);
    }

    /// A candidate with a lower supplied score but an embedding aligned with
    /// the query overtakes a higher-scored, orthogonal candidate.
    ///
    /// Not compiled for `wasm32`: there `rerank` dispatches to `rerank_wasm`,
    /// which orders by the supplied score only, so "a" would stay first and
    /// this assertion could never hold.
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_reranker_can_reorder() {
        let reranker = AttentionReranker::new(4, 1);
        let results = vec![
            ("a".to_string(), 0.70, vec![0.0, 0.0, 1.0, 0.0]),
            ("b".to_string(), 0.55, vec![1.0, 0.0, 0.0, 0.0]),
        ];
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let ranked = reranker.rerank(&query, &results, 2);
        assert_eq!(ranked.len(), 2);
        assert_eq!(
            ranked[0].0, "b",
            "Attention re-ranking should promote the more query-aligned result"
        );
    }
}