use serde::{Deserialize, Serialize};
/// Kind of source a [`RankedCandidate`] was drawn from (its `source_kind` field).
///
/// Fieldless, so it is `Copy` and `Hash` as well: it can be passed by value
/// and used as a map/set key.
///
/// NOTE(review): meaning beyond the variant names is not visible in this file;
/// confirm against whatever produces `RankedCandidate`s.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CandidateKind {
    /// Candidate that comes from memory (presumably a memory-store entry; confirm).
    Memory,
    /// Candidate that comes from a tool.
    Tool,
}
/// A scored candidate considered by the `top_k_*` selection functions.
///
/// `PartialEq` is derived (not `Eq`, because of the `f64` scores) so selections
/// can be compared directly, e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RankedCandidate {
    /// Identifier of the candidate's source; `top_k_with_pinned` matches it
    /// against its `always_include` list.
    pub source_id: String,
    /// Whether the candidate is a memory or a tool.
    pub source_kind: CandidateKind,
    /// Score before adjustment.
    /// NOTE(review): not read by the selection functions in this file; meaning
    /// inferred from the name only.
    pub raw_score: f64,
    /// Score after adjustment; presumably the ranking key (the tests sort by it
    /// in descending order) — confirm with callers.
    pub adjusted_score: f64,
    /// Cost charged against the `budget` by the `top_k_*` functions.
    pub token_cost: usize,
}
/// Selects up to `k` candidates from the front of `candidates` whose combined
/// `token_cost` stays within `budget`.
///
/// Candidates are taken in the order given. This function does not sort, so
/// callers are expected to pass them already ranked (presumably by
/// `adjusted_score`, descending — confirm at call sites).
///
/// Selection stops at the first candidate that would push the running total
/// over `budget`. The result is therefore always a prefix of `candidates`: a
/// cheaper candidate further down is never used to fill leftover budget.
///
/// `k` may be arbitrarily large (e.g. `usize::MAX` for "no count limit"). A
/// total cost that would overflow `usize` is treated as over budget.
pub fn top_k_within_budget(
    candidates: &[RankedCandidate],
    k: usize,
    budget: usize,
) -> Vec<RankedCandidate> {
    // Clamp the reservation: `with_capacity(k)` alone panics ("capacity
    // overflow") or aborts on allocation failure for a huge `k`.
    let mut result = Vec::with_capacity(k.min(candidates.len()));
    let mut spent: usize = 0;
    for c in candidates.iter().take(k) {
        // `checked_add` instead of `+`: a plain add panics in debug builds, and
        // in release builds it wraps to a small total that would wrongly fit.
        match spent.checked_add(c.token_cost) {
            Some(total) if total <= budget => {
                spent = total;
                result.push(c.clone());
            }
            _ => break,
        }
    }
    result
}
/// Selects candidates like `top_k_within_budget`, but always keeps every
/// candidate whose `source_id` appears in `always_include`.
///
/// Pinned candidates come first, in their original relative order. They are
/// included unconditionally, even if they alone exceed `k` or `budget`, but
/// their cost still counts.
///
/// The remaining candidates are then appended in order until either:
/// - the result holds `k` items, or
/// - the next candidate would push the total token cost over `budget`.
///
/// Selection stops at that point; it never skips ahead to cheaper candidates.
///
/// `k` may be arbitrarily large. A total cost that would overflow `usize` is
/// treated as over budget.
pub fn top_k_with_pinned(
    candidates: &[RankedCandidate],
    k: usize,
    budget: usize,
    always_include: &[&str],
) -> Vec<RankedCandidate> {
    // Partition by reference so each selected candidate is cloned exactly once,
    // into `result`. `partition` preserves relative order in both halves.
    let (pinned, rest): (Vec<&RankedCandidate>, Vec<&RankedCandidate>) = candidates
        .iter()
        .partition(|c| always_include.contains(&c.source_id.as_str()));

    // The result never exceeds max(pinned, min(k, len)); clamping avoids a
    // capacity-overflow panic or huge allocation when `k` is very large.
    let capacity = k.min(candidates.len()).max(pinned.len());

    // `None` means the pinned costs alone overflowed `usize`, so nothing else
    // can fit.
    let mut spent: Option<usize> = pinned
        .iter()
        .try_fold(0usize, |acc, c| acc.checked_add(c.token_cost));

    let mut result = Vec::with_capacity(capacity);
    result.extend(pinned.into_iter().cloned());

    for c in rest {
        if result.len() >= k {
            break;
        }
        match spent.and_then(|s| s.checked_add(c.token_cost)) {
            Some(total) if total <= budget => {
                spent = Some(total);
                result.push(c.clone());
            }
            _ => break,
        }
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool candidate whose raw and adjusted scores are both `score`.
    fn tool(id: &str, score: f64, token_cost: usize) -> RankedCandidate {
        RankedCandidate {
            source_id: id.into(),
            source_kind: CandidateKind::Tool,
            raw_score: score,
            adjusted_score: score,
            token_cost,
        }
    }

    /// Source ids of a selection, in order, for compact assertions.
    fn ids(v: &[RankedCandidate]) -> Vec<&str> {
        v.iter().map(|c| c.source_id.as_str()).collect()
    }

    #[test]
    fn ranked_candidate_ordering() {
        let a = RankedCandidate {
            source_id: "tool_a".into(),
            source_kind: CandidateKind::Tool,
            raw_score: 0.9,
            adjusted_score: 0.85,
            token_cost: 50,
        };
        let b = RankedCandidate {
            source_id: "mem_b".into(),
            source_kind: CandidateKind::Memory,
            raw_score: 0.7,
            adjusted_score: 0.95,
            token_cost: 100,
        };
        let mut v = [a, b];
        // Descending by adjusted_score: a lower raw_score can still rank first.
        v.sort_by(|x, y| {
            y.adjusted_score
                .partial_cmp(&x.adjusted_score)
                .expect("test scores are not NaN")
        });
        assert_eq!(v[0].source_id, "mem_b");
        assert_eq!(v[1].source_id, "tool_a");
    }

    #[test]
    fn top_k_prunes_to_budget() {
        let candidates: Vec<RankedCandidate> = (0..20)
            .map(|i| tool(&format!("item_{i}"), 1.0 - (i as f64 * 0.05), 50))
            .collect();
        let pruned = top_k_within_budget(&candidates, 10, 300);
        // 6 x 50 = 300 exhausts the budget before k = 10 is reached.
        assert_eq!(
            ids(&pruned),
            ["item_0", "item_1", "item_2", "item_3", "item_4", "item_5"]
        );
        assert_eq!(pruned.iter().map(|c| c.token_cost).sum::<usize>(), 300);
    }

    #[test]
    fn top_k_stops_at_first_candidate_over_budget() {
        let candidates = [tool("a", 0.9, 50), tool("b", 0.8, 200), tool("c", 0.7, 10)];
        let pruned = top_k_within_budget(&candidates, 3, 100);
        // "c" would fit the leftover budget, but selection is a strict prefix.
        assert_eq!(ids(&pruned), ["a"]);
    }

    #[test]
    fn always_include_list_preserved() {
        let candidates = vec![
            tool("memory_store", 0.1, 30),
            tool("random_tool", 0.9, 30),
        ];
        let always_include = &["memory_store"];
        let pruned = top_k_with_pinned(&candidates, 1, 100, always_include);
        // The pinned item takes the single slot, crowding out the higher score.
        assert_eq!(ids(&pruned), ["memory_store"]);
    }

    #[test]
    fn pinned_items_bypass_k_and_budget() {
        let candidates = [tool("p1", 0.1, 80), tool("r", 0.9, 1), tool("p2", 0.2, 80)];
        let pruned = top_k_with_pinned(&candidates, 1, 100, &["p1", "p2"]);
        assert_eq!(ids(&pruned), ["p1", "p2"]);
    }

    #[test]
    fn rest_fills_remaining_budget_after_pinned() {
        let candidates = [
            tool("r1", 0.9, 10),
            tool("p", 0.1, 10),
            tool("r2", 0.8, 10),
            tool("r3", 0.7, 100),
        ];
        let pruned = top_k_with_pinned(&candidates, 4, 50, &["p"]);
        assert_eq!(ids(&pruned), ["p", "r1", "r2"]);
    }
}