use std::collections::HashMap;
use crate::CupelError;
use crate::diagnostics::CountRequirementShortfall;
use crate::model::{ContextBudget, ContextItem, ContextKind, ScoredItem};
use crate::slicer::count_quota::{CountQuotaEntry, ScarcityBehavior};
use crate::slicer::{KnapsackSlice, QuotaConstraint, QuotaConstraintMode, QuotaPolicy, Slicer};
/// Slicer that combines per-kind count quotas with a [`KnapsackSlice`].
///
/// Required counts are committed first, the inner knapsack fills whatever
/// token budget is left, and per-kind caps are applied to the knapsack's picks.
#[derive(Clone)]
pub struct CountConstrainedKnapsackSlice {
/// Per-kind quota entries; each supplies a required count and a cap count.
entries: Vec<CountQuotaEntry>,
/// Inner slicer run on the items that were not committed to satisfy a requirement.
knapsack: KnapsackSlice,
/// What to do when a kind has fewer candidates than its required count.
scarcity: ScarcityBehavior,
}
impl std::fmt::Debug for CountConstrainedKnapsackSlice {
    /// Writes the slicer as a `CountConstrainedKnapsackSlice { .. }` debug struct
    /// listing `entries`, `knapsack`, and `scarcity`, in that order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure exhaustively so that adding a field without printing it
        // becomes a compile error rather than a silent omission.
        let Self {
            entries,
            knapsack,
            scarcity,
        } = self;

        let mut out = f.debug_struct("CountConstrainedKnapsackSlice");
        out.field("entries", entries);
        out.field("knapsack", knapsack);
        out.field("scarcity", scarcity);
        out.finish()
    }
}
impl CountConstrainedKnapsackSlice {
    /// Builds a slicer from its quota entries, inner knapsack slicer, and
    /// scarcity behavior.
    ///
    /// # Errors
    ///
    /// The signature is fallible, but this constructor currently performs no
    /// validation of its own and always returns `Ok`.
    pub fn new(
        entries: Vec<CountQuotaEntry>,
        knapsack: KnapsackSlice,
        scarcity: ScarcityBehavior,
    ) -> Result<Self, CupelError> {
        let slicer = Self {
            entries,
            knapsack,
            scarcity,
        };
        Ok(slicer)
    }

    /// Returns the configured quota entries in the order they were supplied.
    pub fn entries(&self) -> &[CountQuotaEntry] {
        self.entries.as_slice()
    }

    /// Returns the configured scarcity behavior.
    pub fn scarcity(&self) -> ScarcityBehavior {
        self.scarcity
    }

    /// Builds the `(require, cap)` lookup maps keyed by kind.
    ///
    /// When several entries name the same kind, the last one wins in both maps.
    fn build_policy_maps(&self) -> (HashMap<ContextKind, usize>, HashMap<ContextKind, usize>) {
        self.entries
            .iter()
            .map(|entry| {
                let kind = entry.kind();
                (
                    (kind.clone(), entry.require_count()),
                    (kind.clone(), entry.cap_count()),
                )
            })
            .unzip()
    }
}
impl Slicer for CountConstrainedKnapsackSlice {
    /// Selects items in three phases.
    ///
    /// 1. **Count-satisfy** — for every kind with a non-zero required count
    ///    (visited in case-insensitive kind order), commit the highest-scoring
    ///    candidates of that kind until the requirement is met or candidates
    ///    run out. Committed items are taken regardless of their token cost.
    /// 2. **Knapsack** — run the inner [`KnapsackSlice`] over the uncommitted
    ///    items with whatever is left of `budget.target_tokens()` (clamped at 0),
    ///    then order its picks by descending score.
    /// 3. **Cap** — walk the knapsack picks in that order and drop any item
    ///    whose kind has already reached its cap (counting phase-1 commits).
    ///
    /// The result is the phase-1 commits followed by the surviving phase-2 picks.
    ///
    /// An empty `sorted` slice or a non-positive target returns an empty
    /// selection before any requirement is evaluated.
    ///
    /// # Errors
    ///
    /// With [`ScarcityBehavior::Throw`], returns [`CupelError::SlicerConfig`]
    /// when a kind has fewer candidates than its required count — including
    /// the case where the kind has no candidates at all. Errors from the inner
    /// knapsack slicer are propagated.
    ///
    /// # Panics
    ///
    /// Panics if `ContextBudget::new` rejects the residual sub-budget.
    fn slice(
        &self,
        sorted: &[ScoredItem],
        budget: &ContextBudget,
    ) -> Result<Vec<ContextItem>, CupelError> {
        let target_tokens = budget.target_tokens();
        if sorted.is_empty() || target_tokens <= 0 {
            return Ok(Vec::new());
        }

        let (require_map, cap_map) = self.build_policy_maps();

        // Drive phase 1 from the *requirements*, not from the kinds that happen
        // to be present: a required kind with zero candidates must still be
        // reported as a shortfall (or rejected under `Throw`).
        let mut required: Vec<(&ContextKind, usize)> = require_map
            .iter()
            .filter_map(|(kind, &count)| (count > 0).then_some((kind, count)))
            .collect();
        // Deterministic processing order; the cached key avoids re-allocating
        // the lowercase string on every comparison.
        required.sort_by_cached_key(|(kind, _)| kind.as_str().to_ascii_lowercase());

        // Only kinds with a requirement are needed in phase 1, so only those
        // are partitioned.
        let mut partitions: HashMap<ContextKind, Vec<&ScoredItem>> = HashMap::new();
        for si in sorted {
            let kind = si.item.kind();
            if require_map.get(kind).copied().unwrap_or(0) > 0 {
                partitions.entry(kind.clone()).or_default().push(si);
            }
        }

        // ── Phase 1: commit required items ────────────────────────────────
        let mut committed: Vec<ContextItem> = Vec::new();
        let mut selected_count: HashMap<ContextKind, usize> = HashMap::new();
        let mut pre_alloc_tokens: i64 = 0;
        // Identity of committed items, by address within `sorted`, so that two
        // items with equal contents are still told apart.
        let mut committed_ids: std::collections::HashSet<*const ScoredItem> =
            std::collections::HashSet::new();
        let mut shortfalls: Vec<CountRequirementShortfall> = Vec::new();

        for (kind, req_count) in required {
            let mut satisfied = 0usize;
            if let Some(candidates) = partitions.get_mut(kind) {
                // Stable sort: equal scores keep their order from `sorted`.
                candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
                for &si in candidates.iter().take(req_count) {
                    committed.push(si.item.clone());
                    committed_ids.insert(si as *const ScoredItem);
                    pre_alloc_tokens += si.item.tokens();
                    satisfied += 1;
                }
            }
            selected_count.insert(kind.clone(), satisfied);

            if satisfied < req_count {
                match self.scarcity {
                    ScarcityBehavior::Degrade => {
                        shortfalls.push(CountRequirementShortfall {
                            kind: kind.as_str().to_owned(),
                            required_count: req_count,
                            satisfied_count: satisfied,
                        });
                    }
                    ScarcityBehavior::Throw => {
                        return Err(CupelError::SlicerConfig(format!(
                            "CountConstrainedKnapsackSlice: kind {:?} requires {req_count} items \
                             but only {satisfied} candidates are available",
                            kind.as_str(),
                        )));
                    }
                }
            }
        }

        // ── Phase 2: knapsack over what is left ───────────────────────────
        // Commits may already exceed the target; the residual never goes negative.
        let residual_budget_tokens = (target_tokens - pre_alloc_tokens).max(0);
        let remaining: Vec<ScoredItem> = sorted
            .iter()
            .filter(|si| !committed_ids.contains(&(*si as *const ScoredItem)))
            .cloned()
            .collect();

        let phase2_selected: Vec<ContextItem> =
            if residual_budget_tokens > 0 && !remaining.is_empty() {
                // The knapsack returns bare items, so scores are recovered by
                // content. Items sharing identical content share one entry
                // (the last one in `remaining`).
                let score_by_content: HashMap<_, f64> = remaining
                    .iter()
                    .map(|si| (si.item.content(), si.score))
                    .collect();
                let sub_budget = ContextBudget::new(
                    residual_budget_tokens,
                    residual_budget_tokens,
                    0,
                    HashMap::new(),
                    0.0,
                )
                .expect("residual budget is non-negative");
                let mut selected = self.knapsack.slice(&remaining, &sub_budget)?;
                selected.sort_by(|a, b| {
                    let sa = score_by_content.get(a.content()).copied().unwrap_or(0.0);
                    let sb = score_by_content.get(b.content()).copied().unwrap_or(0.0);
                    sb.total_cmp(&sa)
                });
                selected
            } else {
                Vec::new()
            };

        // ── Phase 3: enforce per-kind caps on the knapsack picks ──────────
        let mut result = committed;
        for item in phase2_selected {
            let current = selected_count.entry(item.kind().clone()).or_insert(0);
            let under_cap = match cap_map.get(item.kind()) {
                Some(&cap_count) => *current < cap_count,
                // Kinds without an entry are uncapped.
                None => true,
            };
            if under_cap {
                *current += 1;
                result.push(item);
            }
        }

        // Shortfalls are collected but not surfaced through this return type;
        // the explicit discard documents that and keeps the binding "used".
        let _ = shortfalls;

        Ok(result)
    }

    /// Always `true`: this slicer applies count-based quotas.
    fn is_count_quota(&self) -> bool {
        true
    }

    /// Returns the per-kind cap counts taken from the configured entries.
    ///
    /// When several entries name the same kind, the last one wins.
    fn count_cap_map(&self) -> std::collections::HashMap<ContextKind, usize> {
        self.entries
            .iter()
            .map(|e| (e.kind().clone(), e.cap_count()))
            .collect()
    }
}
impl QuotaPolicy for CountConstrainedKnapsackSlice {
    /// Describes every configured entry as a count-mode [`QuotaConstraint`],
    /// in entry order, with the required and cap counts converted to `f64`.
    fn quota_constraints(&self) -> Vec<QuotaConstraint> {
        let mut constraints = Vec::with_capacity(self.entries.len());
        for entry in &self.entries {
            let require = entry.require_count() as f64;
            let cap = entry.cap_count() as f64;
            constraints.push(QuotaConstraint {
                kind: entry.kind().clone(),
                mode: QuotaConstraintMode::Count,
                require,
                cap,
            });
        }
        constraints
    }
}