use std::collections::HashMap;
use crate::CupelError;
use crate::diagnostics::CountRequirementShortfall;
use crate::model::{ContextBudget, ContextItem, ContextKind, ScoredItem};
use crate::slicer::{QuotaConstraint, QuotaConstraintMode, QuotaPolicy, Slicer};
#[derive(Debug, Clone, Copy, Default)]
pub enum ScarcityBehavior {
#[default]
Degrade,
Throw,
}
#[derive(Debug, Clone)]
pub struct CountQuotaEntry {
kind: ContextKind,
require_count: usize,
cap_count: usize,
}
impl CountQuotaEntry {
pub fn new(
kind: ContextKind,
require_count: usize,
cap_count: usize,
) -> Result<Self, CupelError> {
if cap_count == 0 && require_count > 0 {
return Err(CupelError::SlicerConfig(format!(
"kind {:?}: cap_count is 0 but require_count is {require_count}; \
a zero cap with a positive requirement can never be satisfied",
kind.as_str(),
)));
}
if require_count > cap_count {
return Err(CupelError::SlicerConfig(format!(
"kind {:?}: require_count ({require_count}) must be <= cap_count ({cap_count})",
kind.as_str(),
)));
}
Ok(Self {
kind,
require_count,
cap_count,
})
}
pub fn kind(&self) -> &ContextKind {
&self.kind
}
pub fn require_count(&self) -> usize {
self.require_count
}
pub fn cap_count(&self) -> usize {
self.cap_count
}
}
pub struct CountQuotaSlice {
entries: Vec<CountQuotaEntry>,
inner: Box<dyn Slicer>,
scarcity: ScarcityBehavior,
}
impl std::fmt::Debug for CountQuotaSlice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CountQuotaSlice")
.field("entries", &self.entries)
.field("inner", &"<dyn Slicer>")
.field("scarcity", &self.scarcity)
.finish()
}
}
impl CountQuotaSlice {
pub fn new(
entries: Vec<CountQuotaEntry>,
inner: Box<dyn Slicer>,
scarcity: ScarcityBehavior,
) -> Result<Self, CupelError> {
if inner.is_knapsack() {
return Err(CupelError::SlicerConfig(
"CountQuotaSlice does not support KnapsackSlice as the inner slicer in this \
version. Use GreedySlice as the inner slicer. A CountConstrainedKnapsackSlice \
will be provided in a future release."
.to_owned(),
));
}
Ok(Self {
entries,
inner,
scarcity,
})
}
pub fn entries(&self) -> &[CountQuotaEntry] {
&self.entries
}
pub fn scarcity(&self) -> ScarcityBehavior {
self.scarcity
}
fn build_policy_maps(&self) -> (HashMap<ContextKind, usize>, HashMap<ContextKind, usize>) {
let mut require_map: HashMap<ContextKind, usize> = HashMap::new();
let mut cap_map: HashMap<ContextKind, usize> = HashMap::new();
for entry in &self.entries {
require_map.insert(entry.kind().clone(), entry.require_count());
cap_map.insert(entry.kind().clone(), entry.cap_count());
}
(require_map, cap_map)
}
}
impl Slicer for CountQuotaSlice {
fn slice(
&self,
sorted: &[ScoredItem],
budget: &ContextBudget,
) -> Result<Vec<ContextItem>, CupelError> {
if sorted.is_empty() || budget.target_tokens() <= 0 {
return Ok(Vec::new());
}
let (require_map, cap_map) = self.build_policy_maps();
let target_tokens = budget.target_tokens();
let mut partitions: HashMap<ContextKind, Vec<&ScoredItem>> = HashMap::new();
for si in sorted {
partitions
.entry(si.item.kind().clone())
.or_default()
.push(si);
}
for items in partitions.values_mut() {
items.sort_by(|a, b| b.score.total_cmp(&a.score));
}
let mut committed: Vec<ContextItem> = Vec::new();
let mut selected_count: HashMap<ContextKind, usize> = HashMap::new();
let mut pre_alloc_tokens: i64 = 0;
let mut committed_ids: std::collections::HashSet<*const ScoredItem> =
std::collections::HashSet::new();
let mut shortfalls: Vec<CountRequirementShortfall> = Vec::new();
let mut sorted_kinds: Vec<&ContextKind> = partitions.keys().collect();
sorted_kinds.sort_by_key(|k| k.as_str().to_ascii_lowercase());
for kind in &sorted_kinds {
let req_count = require_map.get(*kind).copied().unwrap_or(0);
if req_count == 0 {
continue;
}
let candidates = &partitions[*kind];
let mut satisfied = 0usize;
for &si in candidates.iter() {
if satisfied >= req_count {
break;
}
committed.push(si.item.clone());
committed_ids.insert(si as *const ScoredItem);
pre_alloc_tokens += si.item.tokens();
satisfied += 1;
}
selected_count.insert((*kind).clone(), satisfied);
if satisfied < req_count {
match self.scarcity {
ScarcityBehavior::Degrade => {
shortfalls.push(CountRequirementShortfall {
kind: kind.as_str().to_owned(),
required_count: req_count,
satisfied_count: satisfied,
});
}
ScarcityBehavior::Throw => {
return Err(CupelError::SlicerConfig(format!(
"CountQuotaSlice: kind {:?} requires {req_count} items but only \
{satisfied} candidates are available",
kind.as_str(),
)));
}
}
}
}
let residual_budget_tokens = (target_tokens - pre_alloc_tokens).max(0);
let remaining: Vec<ScoredItem> = sorted
.iter()
.filter(|si| !committed_ids.contains(&(*si as *const ScoredItem)))
.cloned()
.collect();
let mut phase2_selected: Vec<ContextItem> =
if residual_budget_tokens > 0 && !remaining.is_empty() {
let sub_budget = ContextBudget::new(
residual_budget_tokens,
residual_budget_tokens,
0,
HashMap::new(),
0.0,
)
.expect("residual budget is non-negative");
self.inner.slice(&remaining, &sub_budget)?
} else {
Vec::new()
};
let mut filtered_phase2: Vec<ContextItem> = Vec::new();
for item in phase2_selected.drain(..) {
let kind = item.kind();
let cap = cap_map.get(kind).copied();
let current = selected_count.entry(kind.clone()).or_insert(0);
match cap {
Some(cap_count) if *current >= cap_count => {
}
_ => {
filtered_phase2.push(item);
*current += 1;
}
}
}
let mut result = committed;
result.extend(filtered_phase2);
Ok(result)
}
fn is_count_quota(&self) -> bool {
true
}
fn count_cap_map(&self) -> std::collections::HashMap<ContextKind, usize> {
self.entries
.iter()
.map(|e| (e.kind().clone(), e.cap_count()))
.collect()
}
}
impl QuotaPolicy for CountQuotaSlice {
fn quota_constraints(&self) -> Vec<QuotaConstraint> {
self.entries
.iter()
.map(|e| QuotaConstraint {
kind: e.kind().clone(),
mode: QuotaConstraintMode::Count,
require: e.require_count() as f64,
cap: e.cap_count() as f64,
})
.collect()
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::*;
use crate::model::ContextBudget;
use crate::{ContextItemBuilder, ContextKind, GreedySlice, KnapsackSlice, ScoredItem};
fn make_item(content: &str, tokens: i64, kind: &str, score: f64) -> ScoredItem {
ScoredItem {
item: ContextItemBuilder::new(content, tokens)
.kind(ContextKind::new(kind).unwrap())
.build()
.unwrap(),
score,
}
}
fn make_budget(target: i64) -> ContextBudget {
ContextBudget::new(target, target, 0, HashMap::new(), 0.0).unwrap()
}
#[test]
fn count_quota_construction_rejects_knapsack_inner() {
let entries = vec![CountQuotaEntry::new(ContextKind::new("tool").unwrap(), 1, 2).unwrap()];
let inner = Box::new(KnapsackSlice::with_default_bucket_size());
let result = CountQuotaSlice::new(entries, inner, ScarcityBehavior::Degrade);
match result {
Err(CupelError::SlicerConfig(msg)) => {
assert!(
msg.contains("CountQuotaSlice"),
"expected message to name CountQuotaSlice"
);
assert!(
msg.contains("KnapsackSlice"),
"expected message to name KnapsackSlice"
);
assert!(
msg.contains("GreedySlice"),
"expected message to name GreedySlice"
);
}
other => panic!("expected Err(SlicerConfig), got {other:?}"),
}
}
#[test]
fn entry_rejects_require_greater_than_cap() {
let result = CountQuotaEntry::new(ContextKind::new("tool").unwrap(), 5, 3);
assert!(matches!(result, Err(CupelError::SlicerConfig(_))));
}
#[test]
fn entry_rejects_zero_cap_with_positive_require() {
let result = CountQuotaEntry::new(ContextKind::new("tool").unwrap(), 1, 0);
assert!(matches!(result, Err(CupelError::SlicerConfig(_))));
}
#[test]
fn entry_allows_zero_require_with_positive_cap() {
let result = CountQuotaEntry::new(ContextKind::new("tool").unwrap(), 0, 3);
assert!(result.is_ok());
}
#[test]
fn entry_allows_zero_require_and_zero_cap() {
let result = CountQuotaEntry::new(ContextKind::new("tool").unwrap(), 0, 0);
assert!(result.is_ok());
}
#[test]
fn count_quota_v1_baseline_count_satisfaction() {
let entries = vec![CountQuotaEntry::new(ContextKind::new("tool").unwrap(), 2, 4).unwrap()];
let slicer =
CountQuotaSlice::new(entries, Box::new(GreedySlice), ScarcityBehavior::Degrade)
.unwrap();
let items = vec![
make_item("tool-a", 100, "tool", 0.9),
make_item("tool-b", 100, "tool", 0.7),
make_item("tool-c", 100, "tool", 0.5),
];
let budget = make_budget(1000);
let selected = slicer.slice(&items, &budget).unwrap();
assert_eq!(selected.len(), 3, "all 3 tool items should be selected");
let contents: Vec<&str> = selected.iter().map(|i| i.content()).collect();
assert!(
contents.contains(&"tool-a"),
"tool-a (score 0.9) must be committed"
);
assert!(
contents.contains(&"tool-b"),
"tool-b (score 0.7) must be committed"
);
}
#[test]
fn count_quota_v2_cap_exclusion() {
let entries = vec![CountQuotaEntry::new(ContextKind::new("tool").unwrap(), 0, 1).unwrap()];
let slicer =
CountQuotaSlice::new(entries, Box::new(GreedySlice), ScarcityBehavior::Degrade)
.unwrap();
let items = vec![
make_item("tool-a", 100, "tool", 0.9),
make_item("tool-b", 100, "tool", 0.7),
make_item("tool-c", 100, "tool", 0.5),
];
let budget = make_budget(1000);
let selected = slicer.slice(&items, &budget).unwrap();
let tool_items: Vec<_> = selected
.iter()
.filter(|i| i.kind().as_str().eq_ignore_ascii_case("tool"))
.collect();
assert_eq!(
tool_items.len(),
1,
"cap of 1 should exclude 2 tool items; got {}",
tool_items.len()
);
}
#[test]
fn count_quota_v4_scarcity_degrade() {
let entries = vec![CountQuotaEntry::new(ContextKind::new("tool").unwrap(), 3, 5).unwrap()];
let slicer =
CountQuotaSlice::new(entries, Box::new(GreedySlice), ScarcityBehavior::Degrade)
.unwrap();
let items = vec![make_item("tool-a", 100, "tool", 0.9)];
let budget = make_budget(1000);
let selected = slicer.slice(&items, &budget).unwrap();
assert_eq!(selected.len(), 1, "should select the 1 available item");
assert_eq!(selected[0].content(), "tool-a");
}
#[test]
fn count_quota_v4_scarcity_throw() {
let entries = vec![CountQuotaEntry::new(ContextKind::new("tool").unwrap(), 3, 5).unwrap()];
let slicer =
CountQuotaSlice::new(entries, Box::new(GreedySlice), ScarcityBehavior::Throw).unwrap();
let items = vec![make_item("tool-a", 100, "tool", 0.9)];
let budget = make_budget(1000);
let result = slicer.slice(&items, &budget);
assert!(
matches!(result, Err(CupelError::SlicerConfig(_))),
"Throw mode should return error on scarcity"
);
}
#[test]
fn count_quota_empty_input_returns_empty() {
let entries = vec![CountQuotaEntry::new(ContextKind::new("tool").unwrap(), 2, 4).unwrap()];
let slicer =
CountQuotaSlice::new(entries, Box::new(GreedySlice), ScarcityBehavior::Degrade)
.unwrap();
let selected = slicer.slice(&[], &make_budget(1000)).unwrap();
assert!(selected.is_empty());
}
#[test]
fn count_quota_zero_budget_returns_empty() {
let entries = vec![CountQuotaEntry::new(ContextKind::new("tool").unwrap(), 2, 4).unwrap()];
let slicer =
CountQuotaSlice::new(entries, Box::new(GreedySlice), ScarcityBehavior::Degrade)
.unwrap();
let items = vec![make_item("tool-a", 100, "tool", 0.9)];
let _ = (items, slicer);
}
#[test]
fn count_quota_no_require_only_cap() {
let entries = vec![CountQuotaEntry::new(ContextKind::new("tool").unwrap(), 0, 2).unwrap()];
let slicer =
CountQuotaSlice::new(entries, Box::new(GreedySlice), ScarcityBehavior::Degrade)
.unwrap();
let items = vec![
make_item("tool-a", 100, "tool", 0.9),
make_item("tool-b", 100, "tool", 0.7),
make_item("tool-c", 100, "tool", 0.5),
];
let budget = make_budget(1000);
let selected = slicer.slice(&items, &budget).unwrap();
let tool_count = selected
.iter()
.filter(|i| i.kind().as_str().eq_ignore_ascii_case("tool"))
.count();
assert_eq!(
tool_count, 2,
"cap of 2 should allow exactly 2 tool items; got {tool_count}"
);
}
#[test]
fn count_quota_require_equals_cap_fills_exactly() {
let entries = vec![CountQuotaEntry::new(ContextKind::new("tool").unwrap(), 2, 2).unwrap()];
let slicer =
CountQuotaSlice::new(entries, Box::new(GreedySlice), ScarcityBehavior::Degrade)
.unwrap();
let items = vec![
make_item("tool-a", 100, "tool", 0.9),
make_item("tool-b", 100, "tool", 0.7),
make_item("tool-c", 100, "tool", 0.5),
];
let budget = make_budget(1000);
let selected = slicer.slice(&items, &budget).unwrap();
let tool_count = selected
.iter()
.filter(|i| i.kind().as_str().eq_ignore_ascii_case("tool"))
.count();
assert_eq!(
tool_count, 2,
"require==cap should yield exactly 2 tool items; got {tool_count}"
);
let contents: Vec<&str> = selected.iter().map(|i| i.content()).collect();
assert!(contents.contains(&"tool-a"));
assert!(contents.contains(&"tool-b"));
assert!(
!contents.contains(&"tool-c"),
"tool-c must be excluded by cap"
);
}
#[test]
fn count_quota_mixed_kinds_independent() {
let entries = vec![
CountQuotaEntry::new(ContextKind::new("tool").unwrap(), 1, 2).unwrap(),
CountQuotaEntry::new(ContextKind::new("system").unwrap(), 1, 1).unwrap(),
];
let slicer =
CountQuotaSlice::new(entries, Box::new(GreedySlice), ScarcityBehavior::Degrade)
.unwrap();
let items = vec![
make_item("tool-a", 100, "tool", 0.9),
make_item("tool-b", 100, "tool", 0.7),
make_item("sys-a", 100, "system", 0.8),
make_item("sys-b", 100, "system", 0.6),
];
let budget = make_budget(1000);
let selected = slicer.slice(&items, &budget).unwrap();
let tool_count = selected
.iter()
.filter(|i| i.kind().as_str().eq_ignore_ascii_case("tool"))
.count();
let sys_count = selected
.iter()
.filter(|i| i.kind().as_str().eq_ignore_ascii_case("system"))
.count();
assert!(
tool_count <= 2,
"tool count must not exceed cap of 2; got {tool_count}"
);
assert!(
tool_count >= 1,
"tool count must satisfy require of 1; got {tool_count}"
);
assert_eq!(
sys_count, 1,
"system cap of 1 must yield exactly 1; got {sys_count}"
);
}
#[test]
fn is_knapsack_false_for_count_quota_slice() {
let entries = vec![CountQuotaEntry::new(ContextKind::new("tool").unwrap(), 1, 2).unwrap()];
let slicer =
CountQuotaSlice::new(entries, Box::new(GreedySlice), ScarcityBehavior::Degrade)
.unwrap();
assert!(
!slicer.is_knapsack(),
"CountQuotaSlice.is_knapsack() must return false"
);
}
#[test]
fn is_knapsack_true_for_knapsack_slice() {
let slicer = KnapsackSlice::with_default_bucket_size();
assert!(
slicer.is_knapsack(),
"KnapsackSlice.is_knapsack() must return true"
);
}
#[test]
fn is_knapsack_false_for_greedy_slice() {
assert!(
!GreedySlice.is_knapsack(),
"GreedySlice.is_knapsack() must return false"
);
}
}