use crate::estimator::TokenEstimator;
use crate::prompt::rag_dedup::RagEntry;
/// A RAG entry paired with the ranking score computed for it.
#[derive(Debug, Clone)]
pub struct RankedEntry {
/// The ranked entry; `rank_and_fit_rag` stores a clone of the caller's entry here.
pub entry: RagEntry,
/// Blended ranking score (`0.7 * relevance + 0.3 * token efficiency` as computed by
/// `rank_and_fit_rag`); higher scores are ranked earlier.
pub score: f64,
}
/// Outcome of ranking RAG entries and fitting them into a token budget.
#[derive(Debug)]
pub struct RankedRagResult {
/// Entries that were selected, in descending score order.
pub entries: Vec<RankedEntry>,
/// Number of input entries that were not selected (all of them when the budget is zero).
pub excluded_count: usize,
/// Sum of the estimated token counts of the selected entries.
pub total_tokens: u32,
}
/// Scores every RAG entry, orders them best-first, and greedily packs them into a token budget.
///
/// Each entry's score is `0.7 * relevance + 0.3 * token_efficiency`, where the token
/// efficiency comes from `compute_token_efficiency` (`1 - entry_tokens / budget_tokens`,
/// clamped at zero). Compact entries therefore get a small boost over long ones of similar
/// relevance.
///
/// Entries are visited in descending score order; ties keep their input order and entries
/// whose score is NaN are visited last. An entry is selected when it still fits in the
/// remaining budget. Otherwise it is counted as excluded and the scan continues, so a
/// smaller, lower-scored entry may still be selected after a larger one was skipped.
///
/// # Arguments
///
/// * `entries` - Candidate entries; the slice is not modified.
/// * `budget_tokens` - Maximum total estimated tokens the selected entries may use.
///
/// # Returns
///
/// A [`RankedRagResult`] holding clones of the selected entries with their scores, the
/// number of entries left out, and the total estimated tokens of the selection. When
/// `entries` is empty or `budget_tokens` is zero, nothing is selected and every input
/// entry is counted as excluded.
#[must_use]
pub fn rank_and_fit_rag(entries: &[RagEntry], budget_tokens: u32) -> RankedRagResult {
    if entries.is_empty() || budget_tokens == 0 {
        return RankedRagResult {
            entries: Vec::new(),
            excluded_count: entries.len(),
            total_tokens: 0,
        };
    }

    // Estimate each entry exactly once and keep a borrow next to its score, so that only
    // the entries that end up selected have to be cloned.
    let mut candidates: Vec<(f64, u32, &RagEntry)> = entries
        .iter()
        .map(|entry| {
            let entry_tokens = TokenEstimator::estimate_tokens(&entry.content);
            let token_efficiency = compute_token_efficiency(entry_tokens, budget_tokens);
            let score = token_efficiency.mul_add(0.3, f64::from(entry.relevance) * 0.7);
            (score, entry_tokens, entry)
        })
        .collect();

    // Stable descending sort. The fallback keeps the comparator a total order when a score
    // is NaN (NaN sorts after every real score) instead of reporting NaN as equal to
    // everything, which is inconsistent and can make the sort misbehave.
    candidates.sort_by(|a, b| {
        b.0.partial_cmp(&a.0)
            .unwrap_or_else(|| a.0.is_nan().cmp(&b.0.is_nan()))
    });

    let mut selected: Vec<RankedEntry> = Vec::new();
    let mut total_tokens: u32 = 0;
    let mut excluded_count: usize = 0;
    for (score, entry_tokens, entry) in candidates {
        // `checked_add` guards against `u32` overflow; a sum that does not even fit in a
        // `u32` can never fit in the budget, so it is treated as excluded.
        match total_tokens.checked_add(entry_tokens) {
            Some(next_total) if next_total <= budget_tokens => {
                total_tokens = next_total;
                selected.push(RankedEntry {
                    entry: entry.clone(),
                    score,
                });
            }
            _ => excluded_count += 1,
        }
    }

    RankedRagResult {
        entries: selected,
        excluded_count,
        total_tokens,
    }
}
/// Returns how much of the budget an entry leaves unused, as a value in `[0.0, 1.0]`.
///
/// The result is `1 - entry_tokens / budget_tokens`, floored at `0.0` for entries that are
/// larger than the whole budget. A zero budget yields `0.0` rather than dividing by zero.
fn compute_token_efficiency(entry_tokens: u32, budget_tokens: u32) -> f64 {
    match budget_tokens {
        0 => 0.0,
        budget => {
            let used_fraction = f64::from(entry_tokens) / f64::from(budget);
            f64::max(1.0 - used_fraction, 0.0)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `RagEntry` with the given content and relevance and no embedding.
    fn make_entry(content: &str, relevance: f32) -> RagEntry {
        RagEntry {
            content: content.to_string(),
            relevance,
            embedding: None,
        }
    }

    #[test]
    fn empty_entries_returns_empty() {
        let result = rank_and_fit_rag(&[], 100);
        assert!(result.entries.is_empty());
        assert_eq!(result.excluded_count, 0);
        assert_eq!(result.total_tokens, 0);
    }

    #[test]
    fn zero_budget_excludes_all() {
        let entries = vec![make_entry("hello", 0.9)];
        let result = rank_and_fit_rag(&entries, 0);
        assert!(result.entries.is_empty());
        assert_eq!(result.excluded_count, 1);
        assert_eq!(result.total_tokens, 0);
    }

    #[test]
    fn single_entry_within_budget() {
        let entries = vec![make_entry("hello world", 0.9)];
        let result = rank_and_fit_rag(&entries, 100);
        assert_eq!(result.entries.len(), 1);
        assert_eq!(result.excluded_count, 0);
    }

    #[test]
    fn compact_relevant_preferred_over_long_slightly_more_relevant() {
        let entries = vec![
            make_entry("Short answer", 0.85),
            make_entry(&"Long verbose answer ".repeat(50), 0.90),
        ];
        let result = rank_and_fit_rag(&entries, 20);
        assert_eq!(result.entries.len(), 1);
        assert!(result.entries[0].entry.content.contains("Short"));
    }

    #[test]
    fn high_relevance_wins_within_budget() {
        let entries = vec![
            make_entry("Less relevant entry", 0.5),
            make_entry("Very relevant entry", 0.99),
        ];
        let result = rank_and_fit_rag(&entries, 1000);
        // Check the length first so a failure reads as an assertion, not an index panic.
        assert_eq!(result.entries.len(), 2);
        assert!(result.entries[0].score >= result.entries[1].score);
        // The relevance gap (0.7 * 0.49) exceeds the largest possible efficiency
        // contribution (0.3), so the more relevant entry must rank first.
        assert!(result.entries[0].entry.content.contains("Very"));
    }

    #[test]
    fn budget_constraint_respected() {
        let entries = vec![
            make_entry(&"word ".repeat(100), 0.9),
            make_entry(&"word ".repeat(100), 0.8),
        ];
        let result = rank_and_fit_rag(&entries, 50);
        // The budget must hold unconditionally, and every entry must be accounted for.
        assert!(result.total_tokens <= 50);
        assert_eq!(result.entries.len() + result.excluded_count, entries.len());
    }

    #[test]
    fn token_efficiency_correct() {
        assert!((compute_token_efficiency(10, 100) - 0.9).abs() < f64::EPSILON);
        assert!((compute_token_efficiency(50, 100) - 0.5).abs() < f64::EPSILON);
        assert!((compute_token_efficiency(100, 100)).abs() < f64::EPSILON);
        // Entries larger than the budget are clamped to zero rather than going negative.
        assert!((compute_token_efficiency(200, 100)).abs() < f64::EPSILON);
        // A zero budget short-circuits to zero instead of dividing by zero.
        assert!((compute_token_efficiency(10, 0)).abs() < f64::EPSILON);
    }

    #[test]
    fn total_tokens_tracked() {
        let entries = vec![
            make_entry("hello world", 0.9),
            make_entry("foo bar baz", 0.8),
        ];
        let result = rank_and_fit_rag(&entries, 1000);
        assert!(result.total_tokens > 0);
        assert!(result.total_tokens <= 1000);
    }
}