use std::collections::HashMap;
use std::path::PathBuf;
use crate::chunk::CodeChunk;
/// One stage of a ranking pipeline (see `apply_chain`).
///
/// Implementations must be `Send + Sync` so boxed layers can be shared across threads.
pub trait RankingLayer: Send + Sync {
    /// Transforms `items` in place.
    ///
    /// Each entry is `(chunk index, score)`, where the index refers into `chunks`.
    /// A layer may rescore, reorder, or drop entries. Indices with no matching
    /// chunk may be present, and implementations should tolerate them.
    fn apply(&self, items: &mut Vec<(usize, f32)>, chunks: &[CodeChunk]);
}
/// Runs each layer of `layers` over `items`, in slice order.
///
/// Every layer sees the output of the one before it, so ordering matters.
/// For example, a `Threshold` placed before a `TopK` filters before truncating.
/// An empty `layers` slice leaves `items` untouched.
pub fn apply_chain(
    items: &mut Vec<(usize, f32)>,
    chunks: &[CodeChunk],
    layers: &[Box<dyn RankingLayer>],
) {
    layers.iter().for_each(|stage| stage.apply(items, chunks));
}
/// Ranking layer that multiplies each score by a PageRank-derived boost.
///
/// A chunk's rank comes from `crate::hybrid::lookup_rank_for_chunk`, using the
/// chunk's file path and name. `crate::hybrid::pagerank_boost_factor` then turns
/// that rank into a multiplier, scaled by `alpha`.
pub struct PageRankBoost {
    /// Precomputed PageRank values. The tests key this by file path; the exact
    /// key scheme is defined by `lookup_rank_for_chunk` (TODO confirm).
    pagerank: HashMap<String, f32>,
    /// Boost strength passed through to `pagerank_boost_factor`.
    alpha: f32,
}

impl PageRankBoost {
    /// Builds the layer from a PageRank map and a boost strength `alpha`.
    #[must_use]
    pub fn new(pagerank: HashMap<String, f32>, alpha: f32) -> Self {
        PageRankBoost { pagerank, alpha }
    }
}

impl RankingLayer for PageRankBoost {
    /// Boosts every item that has a backing chunk, then re-sorts.
    ///
    /// Items whose index has no chunk keep their score. The final order is by
    /// descending score, with ties broken by ascending chunk index.
    fn apply(&self, items: &mut Vec<(usize, f32)>, chunks: &[CodeChunk]) {
        for entry in items.iter_mut() {
            let Some(chunk) = chunks.get(entry.0) else {
                continue;
            };
            let rank = crate::hybrid::lookup_rank_for_chunk(
                &self.pagerank,
                &chunk.file_path,
                &chunk.name,
            );
            entry.1 *= crate::hybrid::pagerank_boost_factor(rank, self.alpha);
        }
        items.sort_unstable_by(|lhs, rhs| {
            rhs.1
                .total_cmp(&lhs.1)
                .then_with(|| lhs.0.cmp(&rhs.0))
        });
    }
}
/// Ranking layer that scales each score by a path-based penalty.
///
/// Each chunk's path is first made relative to `corpus_root`. The relative path
/// is passed to `crate::encoder::ripvec::penalties::file_path_penalty`, and the
/// score is multiplied by the result.
pub struct PathPenalty {
    /// Root directory that chunk paths are made relative to before penalizing.
    corpus_root: PathBuf,
}

impl PathPenalty {
    /// Creates the layer for a corpus rooted at `corpus_root`.
    #[must_use]
    pub fn new(corpus_root: PathBuf) -> Self {
        Self { corpus_root }
    }
}

impl RankingLayer for PathPenalty {
    /// Applies the path penalty to every item that has a backing chunk, then re-sorts.
    ///
    /// Items whose index has no chunk keep their score. The final order is by
    /// descending score, with ties broken by ascending chunk index.
    fn apply(&self, items: &mut Vec<(usize, f32)>, chunks: &[CodeChunk]) {
        // Borrow the (possibly lossy) root string instead of allocating a copy.
        let root_lossy = self.corpus_root.to_string_lossy();
        let root = root_lossy.trim_end_matches('/');
        for (idx, score) in items.iter_mut() {
            if let Some(chunk) = chunks.get(*idx) {
                let rel = strip_corpus_root(&chunk.file_path, root);
                *score *= crate::encoder::ripvec::penalties::file_path_penalty(rel);
            }
        }
        items.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    }
}

/// Returns `path` relative to `root`, provided `path` lies inside `root`.
///
/// The match must end on a `/` boundary. With `root = "/repo"`:
/// - `"/repo/src/a.rs"` becomes `"src/a.rs"`;
/// - `"/repo2/a.rs"` is returned unchanged, instead of being mangled into `"2/a.rs"`.
///
/// An empty `root` matches every path. In that case absolute paths just lose
/// their leading `/`.
fn strip_corpus_root<'a>(path: &'a str, root: &str) -> &'a str {
    match path.strip_prefix(root) {
        Some(rest) if rest.is_empty() || rest.starts_with('/') => rest.trim_start_matches('/'),
        _ => path,
    }
}
/// Ranking layer that removes every item scoring below `min_score`.
///
/// The bound is inclusive: an item scoring exactly `min_score` is kept. Items
/// with a NaN score are always removed, because comparisons with NaN are false.
/// The relative order of the surviving items is preserved.
pub struct Threshold {
    /// Inclusive lower bound on an item's score.
    pub min_score: f32,
}

impl RankingLayer for Threshold {
    fn apply(&self, items: &mut Vec<(usize, f32)>, _chunks: &[CodeChunk]) {
        let floor = self.min_score;
        items.retain(|&(_, score)| score >= floor);
    }
}
/// Ranking layer that keeps only the first `k` items, in their current order.
///
/// Setting `k` to zero disables truncation, so every item is kept. Because this
/// layer does not sort, it should run after whatever layer establishes the
/// desired order.
pub struct TopK {
    /// Maximum number of items to keep; `0` means unlimited.
    pub k: usize,
}

impl RankingLayer for TopK {
    fn apply(&self, items: &mut Vec<(usize, f32)>, _chunks: &[CodeChunk]) {
        match self.k {
            0 => {}
            limit => items.truncate(limit),
        }
    }
}
/// Ranking layer that rescores the leading candidates with a cross-encoder.
///
/// The cross-encoder score is blended with each item's incoming score.
///
/// - Only the first `candidates` items are kept; everything after is dropped
///   before scoring.
/// - Each kept item that has a backing chunk receives
///   `blend * cross + (1 - blend) * incoming`.
/// - Items whose index has no chunk keep their incoming score.
/// - The list is then sorted by descending score, with ties broken by
///   ascending chunk index.
///
/// If the reranker returns an error, the (truncated) items are left as they
/// were, without blending or re-sorting.
pub struct CrossEncoderRerank {
    reranker: std::sync::Arc<crate::rerank::Reranker>,
    /// Query text paired with each candidate chunk's `content`.
    query: String,
    /// Maximum number of items kept and rescored.
    candidates: usize,
    /// Weight given to the cross-encoder score, in `[0.0, 1.0]`; defaults to 0.7.
    blend: f32,
}

impl CrossEncoderRerank {
    /// Creates the layer with the default blend weight of 0.7.
    #[must_use]
    pub fn new(
        reranker: std::sync::Arc<crate::rerank::Reranker>,
        query: String,
        candidates: usize,
    ) -> Self {
        Self {
            reranker,
            query,
            candidates,
            blend: 0.7,
        }
    }

    /// Sets the cross-encoder blend weight, clamped to `[0.0, 1.0]`.
    ///
    /// A NaN `blend` is ignored and the previous weight is kept. Otherwise a NaN
    /// weight would silently turn every blended score into NaN.
    #[must_use]
    pub fn with_blend(mut self, blend: f32) -> Self {
        if !blend.is_nan() {
            self.blend = blend.clamp(0.0, 1.0);
        }
        self
    }
}

impl RankingLayer for CrossEncoderRerank {
    fn apply(&self, items: &mut Vec<(usize, f32)>, chunks: &[CodeChunk]) {
        // No-op when there are already `candidates` items or fewer.
        items.truncate(self.candidates);
        if items.is_empty() {
            return;
        }
        // Record each scored item's position in `items` next to its
        // (query, content) pair. Items without a chunk are skipped, and the
        // positions keep the returned scores aligned with the right items.
        let (positions, pairs): (Vec<usize>, Vec<(&str, &str)>) = items
            .iter()
            .enumerate()
            .filter_map(|(pos, &(idx, _))| {
                chunks
                    .get(idx)
                    .map(|c| (pos, (self.query.as_str(), c.content.as_str())))
            })
            .unzip();
        let Ok(scores) = self.reranker.score_pairs(&pairs) else {
            return;
        };
        for (&pos, &cross_score) in positions.iter().zip(scores.iter()) {
            // `pos` came from enumerating `items`, which has not been resized since.
            let item = &mut items[pos];
            item.1 = self.blend * cross_score + (1.0 - self.blend) * item.1;
        }
        items.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal chunk with empty content, for layers that only look at path/name.
    fn dummy_chunk(file: &str, name: &str) -> CodeChunk {
        CodeChunk {
            file_path: file.into(),
            name: name.into(),
            kind: "function".into(),
            start_line: 1,
            end_line: 10,
            content: String::new(),
            enriched_content: String::new(),
        }
    }

    #[test]
    fn threshold_drops_below_min() {
        let chunks = vec![dummy_chunk("a.rs", "f"), dummy_chunk("b.rs", "g")];
        let mut items = vec![(0, 0.9), (1, 0.3)];
        Threshold { min_score: 0.5 }.apply(&mut items, &chunks);
        assert_eq!(items, vec![(0, 0.9)]);
    }

    #[test]
    fn threshold_keeps_score_equal_to_min() {
        let chunks = vec![dummy_chunk("a.rs", "f")];
        let mut items = vec![(0, 0.5)];
        Threshold { min_score: 0.5 }.apply(&mut items, &chunks);
        assert_eq!(items, vec![(0, 0.5)]);
    }

    #[test]
    fn topk_truncates() {
        let chunks = vec![
            dummy_chunk("a.rs", "f"),
            dummy_chunk("b.rs", "g"),
            dummy_chunk("c.rs", "h"),
        ];
        let mut items = vec![(0, 0.9), (1, 0.8), (2, 0.7)];
        TopK { k: 2 }.apply(&mut items, &chunks);
        assert_eq!(items, vec![(0, 0.9), (1, 0.8)]);
    }

    #[test]
    fn topk_zero_keeps_all() {
        let chunks = vec![dummy_chunk("a.rs", "f"), dummy_chunk("b.rs", "g")];
        let mut items = vec![(0, 0.9), (1, 0.8)];
        TopK { k: 0 }.apply(&mut items, &chunks);
        assert_eq!(items.len(), 2);
    }

    #[test]
    fn topk_larger_than_len_keeps_all() {
        let chunks = vec![dummy_chunk("a.rs", "f")];
        let mut items = vec![(0, 0.9)];
        TopK { k: 5 }.apply(&mut items, &chunks);
        assert_eq!(items, vec![(0, 0.9)]);
    }

    #[test]
    fn chain_runs_layers_in_order() {
        let chunks = vec![
            dummy_chunk("a.rs", "f"),
            dummy_chunk("b.rs", "g"),
            dummy_chunk("c.rs", "h"),
        ];
        let mut items = vec![(0, 1.0), (1, 0.6), (2, 0.3)];
        let layers: Vec<Box<dyn RankingLayer>> = vec![
            Box::new(Threshold { min_score: 0.5 }),
            Box::new(TopK { k: 1 }),
        ];
        apply_chain(&mut items, &chunks, &layers);
        assert_eq!(items, vec![(0, 1.0)]);
    }

    #[test]
    fn empty_chain_is_noop() {
        let chunks = vec![dummy_chunk("a.rs", "f")];
        let mut items = vec![(0, 0.4)];
        let layers: Vec<Box<dyn RankingLayer>> = Vec::new();
        apply_chain(&mut items, &chunks, &layers);
        assert_eq!(items, vec![(0, 0.4)]);
    }

    #[test]
    fn pagerank_boost_layer_reorders() {
        let chunks = vec![
            dummy_chunk("important.rs", "a"),
            dummy_chunk("obscure.rs", "b"),
        ];
        // Put the low-rank chunk first. The first assertion below then passes
        // only if the layer actually re-sorts its input.
        let mut items = vec![(1, 0.8), (0, 0.8)];
        let mut pr = HashMap::new();
        pr.insert("important.rs".to_string(), 1.0);
        pr.insert("obscure.rs".to_string(), 0.1);
        PageRankBoost::new(pr, 0.3).apply(&mut items, &chunks);
        assert_eq!(items[0].0, 0);
        assert!(items[0].1 > items[1].1);
    }
}