use crate::CupelError;
use crate::model::{ContextBudget, ContextItem, ScoredItem};
use crate::slicer::Slicer;
/// Slicer that selects items with a 0/1 knapsack dynamic program over token
/// counts discretized into fixed-size buckets.
///
/// Build it with `KnapsackSlice::new` (validated bucket size) or
/// `KnapsackSlice::with_default_bucket_size`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KnapsackSlice {
/// Number of tokens represented by one bucket. The budget is divided by this
/// value (rounding down) and each item's token count is divided by it
/// (rounding up). `new` rejects values `<= 0`.
bucket_size: i64,
}
impl KnapsackSlice {
    /// Creates a slicer that discretizes token counts into buckets of
    /// `bucket_size` tokens.
    ///
    /// # Errors
    ///
    /// Returns `CupelError::SlicerConfig` when `bucket_size` is zero or
    /// negative.
    pub fn new(bucket_size: i64) -> Result<Self, CupelError> {
        match bucket_size {
            size if size > 0 => Ok(Self { bucket_size: size }),
            _ => Err(CupelError::SlicerConfig(String::from(
                "bucket_size must be > 0",
            ))),
        }
    }

    /// Creates a slicer with a bucket size of 100 tokens.
    pub fn with_default_bucket_size() -> Self {
        let bucket_size = 100;
        Self { bucket_size }
    }
}
impl Slicer for KnapsackSlice {
    /// Always returns `true` for this slicer.
    fn is_knapsack(&self) -> bool {
        true
    }

    /// Chooses items from `sorted` by solving a 0/1 knapsack problem over
    /// bucketed token counts.
    ///
    /// * Items with zero tokens are always returned; items with a negative
    ///   token count are ignored.
    /// * The capacity is `budget.target_tokens() / bucket_size` rounded down,
    ///   while each item's weight is its token count divided by `bucket_size`
    ///   rounded up, so a chosen set never exceeds the target in real tokens.
    /// * An item's value is `floor(score * 10000)` clamped to be non-negative.
    ///
    /// The returned vector holds the zero-token items in input order, followed
    /// by the chosen items in reverse input order. An empty input or a target
    /// of zero or less yields an empty vector.
    ///
    /// # Errors
    ///
    /// Returns `CupelError::TableTooLarge` when `capacity * candidates`
    /// exceeds 50,000,000 cells.
    fn slice(
        &self,
        sorted: &[ScoredItem],
        budget: &ContextBudget,
    ) -> Result<Vec<ContextItem>, CupelError> {
        if sorted.is_empty() || budget.target_tokens() <= 0 {
            return Ok(Vec::new());
        }

        // Zero-token items are free, so they bypass the optimisation entirely.
        let mut zero_token_items: Vec<ContextItem> = Vec::new();
        let mut candidates: Vec<&ScoredItem> = Vec::new();
        for si in sorted {
            match si.item.tokens().cmp(&0) {
                std::cmp::Ordering::Equal => zero_token_items.push(si.item.clone()),
                std::cmp::Ordering::Greater => candidates.push(si),
                std::cmp::Ordering::Less => {}
            }
        }
        if candidates.is_empty() {
            return Ok(zero_token_items);
        }
        let n = candidates.len();

        // Saturating i64 -> usize conversion. A plain `as` cast would silently
        // truncate on targets where `usize` is narrower than 64 bits, which
        // could make an oversized budget or item look small.
        let to_buckets = |count: i64| -> usize {
            <usize as std::convert::TryFrom<i64>>::try_from(count).unwrap_or(usize::MAX)
        };

        // Capacity rounds down so the bucketed budget never exceeds the target.
        let capacity = to_buckets(budget.target_tokens() / self.bucket_size);
        if capacity == 0 {
            return Ok(zero_token_items);
        }

        // Size guard. The product saturates instead of overflowing: a wrapped
        // product could look small and let an enormous table through.
        // (The `keep` table below actually holds `n * (capacity + 1)` cells.)
        let cells = (capacity as u64).saturating_mul(n as u64);
        if cells > 50_000_000 {
            return Err(CupelError::TableTooLarge {
                candidates: n,
                capacity,
                cells,
            });
        }

        let values: Vec<i64> = candidates
            .iter()
            .map(|c| ((c.score * 10000.0).floor() as i64).max(0))
            .collect();

        // Weights round up. Every candidate has `tokens > 0`, so
        // `(tokens - 1) / bucket_size + 1` is the ceiling and, unlike
        // `(tokens + bucket_size - 1) / bucket_size`, cannot overflow.
        let discretized_weights: Vec<usize> = candidates
            .iter()
            .map(|c| to_buckets((c.item.tokens() - 1) / self.bucket_size + 1))
            .collect();

        // `dp[w]` is the best total value achievable with capacity `w`;
        // row `i` of `keep` records at which capacities item `i` improved it.
        let stride = capacity + 1;
        let mut dp = vec![0i64; stride];
        let mut keep = vec![false; n * stride];
        let rows = keep
            .chunks_exact_mut(stride)
            .zip(&discretized_weights)
            .zip(&values);
        for ((row, &dw), &dv) in rows {
            // Descending capacities so each item is used at most once.
            // The range is empty when the item alone exceeds the capacity.
            for w in (dw..=capacity).rev() {
                // Saturate so extreme scores cannot overflow the running total.
                let with_item = dp[w - dw].saturating_add(dv);
                if with_item > dp[w] {
                    dp[w] = with_item;
                    row[w] = true;
                }
            }
        }

        // Walk the decisions backwards from the full capacity to recover the
        // chosen items.
        let mut result = zero_token_items;
        let mut remaining_capacity = capacity;
        for i in (0..n).rev() {
            if keep[i * stride + remaining_capacity] {
                result.push(candidates[i].item.clone());
                remaining_capacity -= discretized_weights[i];
            }
        }
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::model::{ContextBudget, ScoredItem};
    use crate::{ContextItemBuilder, CupelError};

    use super::KnapsackSlice;
    use crate::slicer::Slicer;

    /// 1001 one-token candidates against 50_001 buckets exceed the
    /// 50,000,000-cell limit, so slicing must fail with `TableTooLarge`.
    #[test]
    fn knapsack_table_too_large() {
        let budget =
            ContextBudget::new(100_000, 50_001, 0, HashMap::new(), 0.0).expect("valid budget");

        let mut items: Vec<ScoredItem> = Vec::with_capacity(1001);
        for i in 0..1001 {
            let item = ContextItemBuilder::new(format!("item-{i}"), 1)
                .build()
                .expect("valid item");
            items.push(ScoredItem { item, score: 0.5 });
        }

        let slicer = KnapsackSlice::new(1).expect("bucket_size=1 is valid");
        let other = slicer.slice(&items, &budget);

        if let Err(CupelError::TableTooLarge {
            candidates,
            capacity,
            cells,
        }) = &other
        {
            assert_eq!(*candidates, 1001);
            assert_eq!(*capacity, 50_001);
            assert!(*cells > 50_000_000, "cells={cells} should exceed limit");
        } else {
            panic!("expected Err(TableTooLarge), got {other:?}");
        }
    }
}