use std::collections::HashMap;
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer, ser::SerializeStruct};
use crate::CupelError;
use crate::model::{ContextBudget, ContextItem, ContextKind, ScoredItem};
use crate::slicer::{QuotaConstraint, QuotaConstraintMode, QuotaPolicy, Slicer};
/// A quota for a single [`ContextKind`], expressed as two percentages.
///
/// Values built through [`QuotaEntry::new`] satisfy
/// `0.0 <= require <= cap <= 100.0`.
#[derive(Debug, Clone)]
pub struct QuotaEntry {
// The kind of context item this quota applies to.
kind: ContextKind,
// Required share for `kind`, as a percentage in `[0.0, 100.0]`.
require: f64,
// Upper limit on the share for `kind`, as a percentage in `[0.0, 100.0]`.
cap: f64,
}
impl QuotaEntry {
    /// Builds a quota for `kind` from a required percentage and a cap percentage.
    ///
    /// # Errors
    ///
    /// Returns [`CupelError::SlicerConfig`] when, checked in this order:
    /// * `require` is outside `[0.0, 100.0]` (NaN counts as outside),
    /// * `cap` is outside `[0.0, 100.0]` (NaN counts as outside),
    /// * `require` is greater than `cap`.
    pub fn new(kind: ContextKind, require: f64, cap: f64) -> Result<Self, CupelError> {
        let valid_percent = 0.0_f64..=100.0;

        // Only the first failing rule is reported.
        let problem = if !valid_percent.contains(&require) {
            Some(format!("require ({require}) must be in [0.0, 100.0]"))
        } else if !valid_percent.contains(&cap) {
            Some(format!("cap ({cap}) must be in [0.0, 100.0]"))
        } else if require > cap {
            Some(format!("require ({require}) must be <= cap ({cap})"))
        } else {
            None
        };

        match problem {
            Some(message) => Err(CupelError::SlicerConfig(message)),
            None => Ok(Self { kind, require, cap }),
        }
    }

    /// The kind of context item this quota applies to.
    pub fn kind(&self) -> &ContextKind {
        &self.kind
    }

    /// The required percentage for this kind.
    pub fn require(&self) -> f64 {
        self.require
    }

    /// The cap percentage for this kind.
    pub fn cap(&self) -> f64 {
        self.cap
    }
}
#[cfg(feature = "serde")]
impl Serialize for QuotaEntry {
    /// Serializes the entry as a struct named `QuotaEntry` with the fields
    /// `kind`, `require` and `cap`, written in that order.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let Self { kind, require, cap } = self;

        let mut fields = serializer.serialize_struct("QuotaEntry", 3)?;
        fields.serialize_field("kind", kind)?;
        fields.serialize_field("require", require)?;
        fields.serialize_field("cap", cap)?;
        fields.end()
    }
}
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for QuotaEntry {
    /// Deserializes an entry and validates it with [`QuotaEntry::new`].
    ///
    /// Unknown fields are rejected. A validation failure from `new` is
    /// reported as a custom deserialization error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Unvalidated mirror of the serialized shape; values only become a
        // `QuotaEntry` after passing through the checked constructor.
        #[derive(Deserialize)]
        #[serde(deny_unknown_fields)]
        struct Raw {
            kind: ContextKind,
            require: f64,
            cap: f64,
        }

        let Raw { kind, require, cap } = Raw::deserialize(deserializer)?;
        match QuotaEntry::new(kind, require, cap) {
            Ok(entry) => Ok(entry),
            Err(invalid) => Err(serde::de::Error::custom(invalid)),
        }
    }
}
/// A [`Slicer`] that applies per-kind percentage quotas and delegates the
/// actual item selection for each kind to another slicer.
pub struct QuotaSlice {
// Configured quotas; kinds without an entry are treated as having a 100% cap.
quotas: Vec<QuotaEntry>,
// Slicer invoked once per kind with that kind's candidates and sub-budget.
inner: Box<dyn Slicer>,
}
impl QuotaSlice {
pub fn new(quotas: Vec<QuotaEntry>, inner: Box<dyn Slicer>) -> Result<Self, CupelError> {
let require_sum: f64 = quotas.iter().map(|q| q.require()).sum();
if require_sum > 100.0 {
return Err(CupelError::SlicerConfig(format!(
"sum of require percentages ({require_sum}) must not exceed 100.0",
)));
}
Ok(Self { quotas, inner })
}
fn get_cap(&self, kind: &ContextKind) -> f64 {
self.quotas
.iter()
.find(|q| q.kind() == kind)
.map_or(100.0, |q| q.cap())
}
}
/// Quota-enforcing [`Slicer`]: gives each kind its own token budget and lets
/// the wrapped slicer choose items within it.
impl Slicer for QuotaSlice {
    /// Selects items from `sorted` while honouring the configured quotas.
    ///
    /// Steps:
    /// 1. Candidates are grouped by kind, keeping their relative order.
    /// 2. Each quota percentage becomes a token count:
    ///    `floor(percent / 100 * target_tokens)`.
    /// 3. Tokens left after all `require` amounts are shared among kinds
    ///    whose cap exceeds their requirement, proportionally to the total
    ///    token mass of each kind's candidates.
    /// 4. Each kind's budget is `require + share`, clamped to its cap.
    /// 5. The inner slicer runs once per kind with a positive budget, in
    ///    ASCII-case-insensitive order of the kind name, and the results are
    ///    concatenated in that order.
    ///
    /// Kinds without a quota entry use a requirement of `0` and a cap equal
    /// to the full target. Returns an empty vector when `sorted` is empty or
    /// the target is not positive.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner slicer.
    ///
    /// # Panics
    ///
    /// Panics if `ContextBudget::new` rejects a per-kind sub-budget.
    fn slice(
        &self,
        sorted: &[ScoredItem],
        budget: &ContextBudget,
    ) -> Result<Vec<ContextItem>, CupelError> {
        // Everything needed to budget and slice the candidates of one kind.
        struct KindPlan<'a> {
            kind: &'a ContextKind,
            items: &'a [ScoredItem],
            require: i64,
            cap: i64,
            mass: i64,
        }

        let target_tokens = budget.target_tokens();
        if sorted.is_empty() || target_tokens <= 0 {
            return Ok(Vec::new());
        }

        // Percentage of the target, rounded down to whole tokens.
        let percent_of_target =
            |percent: f64| (percent / 100.0 * target_tokens as f64).floor() as i64;

        // Group candidates by kind; the inner slicer needs contiguous owned
        // slices, hence the clones.
        let mut partitions: HashMap<ContextKind, Vec<ScoredItem>> = HashMap::new();
        for si in sorted {
            partitions
                .entry(si.item.kind().clone())
                .or_default()
                .push(si.clone());
        }

        // Absolute require/cap per configured kind. Keys borrow from
        // `self.quotas`; if a kind is listed twice the later entry wins.
        let mut require_tokens: HashMap<&ContextKind, i64> =
            HashMap::with_capacity(self.quotas.len());
        let mut cap_tokens: HashMap<&ContextKind, i64> =
            HashMap::with_capacity(self.quotas.len());
        for q in &self.quotas {
            require_tokens.insert(q.kind(), percent_of_target(q.require()));
            cap_tokens.insert(q.kind(), percent_of_target(q.cap()));
        }

        // Requirements are reserved even for kinds with no candidates.
        let total_required: i64 = require_tokens.values().sum();
        let unassigned_budget = (target_tokens - total_required).max(0);

        let mut plans: Vec<KindPlan<'_>> = partitions
            .iter()
            .map(|(kind, items)| KindPlan {
                kind,
                items: items.as_slice(),
                require: require_tokens.get(kind).copied().unwrap_or(0),
                cap: cap_tokens.get(kind).copied().unwrap_or(target_tokens),
                mass: items.iter().map(|si| si.item.tokens()).sum(),
            })
            .collect();
        // Stable sort with each lowercased key computed once, instead of
        // allocating a new String on every comparison. Kinds whose names
        // differ only by ASCII case keep the map's iteration order.
        plans.sort_by_cached_key(|plan| plan.kind.as_str().to_ascii_lowercase());

        // Only kinds with headroom above their requirement share the
        // unassigned tokens.
        let total_mass_for_distribution: i64 = plans
            .iter()
            .filter(|plan| plan.cap > plan.require)
            .map(|plan| plan.mass)
            .sum();

        let mut all_selected: Vec<ContextItem> = Vec::new();
        for plan in &plans {
            let proportional = if total_mass_for_distribution > 0 && plan.cap > plan.require {
                (unassigned_budget as f64 * plan.mass as f64
                    / total_mass_for_distribution as f64)
                    .floor() as i64
            } else {
                0
            };

            let kind_budget = (plan.require + proportional).min(plan.cap);
            if kind_budget <= 0 {
                continue;
            }

            // The sub-budget ceiling comes from the first matching entry
            // (`get_cap`), and the target is clamped to it.
            let cap = percent_of_target(self.get_cap(plan.kind));
            let kind_budget = kind_budget.min(cap);
            let sub_budget = ContextBudget::new(cap, kind_budget, 0, HashMap::new(), 0.0)
                .expect("sub-budget should be valid since cap >= kind_budget >= 0");

            all_selected.extend(self.inner.slice(plan.items, &sub_budget)?);
        }

        Ok(all_selected)
    }

    /// Always `true`: this slicer is quota-based.
    fn is_quota(&self) -> bool {
        true
    }
}
impl QuotaPolicy for QuotaSlice {
    /// Returns one percentage-mode [`QuotaConstraint`] per configured entry,
    /// in configuration order.
    fn quota_constraints(&self) -> Vec<QuotaConstraint> {
        let mut constraints = Vec::with_capacity(self.quotas.len());
        for entry in &self.quotas {
            constraints.push(QuotaConstraint {
                kind: entry.kind().clone(),
                mode: QuotaConstraintMode::Percentage,
                require: entry.require(),
                cap: entry.cap(),
            });
        }
        constraints
    }
}