use std::collections::HashMap;
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer, ser::SerializeStruct};
use crate::CupelError;
use crate::model::ContextKind;
/// A validated set of token limits and reservations.
///
/// All fields are private. Values are built through `ContextBudget::new`,
/// which rejects negative limits, a `target_tokens` or `output_reserve`
/// larger than `max_tokens`, a safety margin outside `[0.0, 100.0]`, and
/// negative reserved-slot values. The sum of the reservations is *not*
/// required to fit inside `max_tokens`; an over-committed budget is
/// representable.
#[derive(Debug, Clone, PartialEq)]
pub struct ContextBudget {
/// Upper limit that `target_tokens` and `output_reserve` are checked against; never negative.
max_tokens: i64,
/// Non-negative value that is at most `max_tokens`.
target_tokens: i64,
/// Non-negative amount, at most `max_tokens`, that is counted as reserved.
output_reserve: i64,
/// Non-negative reservation per `ContextKind`; the values are added to
/// `output_reserve` when computing the total reservation.
// NOTE(review): the validation message calls these "slot counts" while the
// arithmetic treats them like token amounts — confirm the intended unit.
reserved_slots: HashMap<ContextKind, i64>,
/// Percentage in `[0.0, 100.0]`. It is stored and exposed here; how it is
/// applied is not visible in this file.
estimation_safety_margin_percent: f64,
}
impl ContextBudget {
    /// Builds a budget after validating every argument.
    ///
    /// Checks run in the order below and the first failure is returned:
    ///
    /// 1. `max_tokens >= 0`
    /// 2. `target_tokens >= 0`
    /// 3. `target_tokens <= max_tokens`
    /// 4. `output_reserve >= 0`
    /// 5. `output_reserve <= max_tokens`
    /// 6. `estimation_safety_margin_percent` lies in `[0.0, 100.0]` (NaN is rejected)
    /// 7. every value in `reserved_slots` is `>= 0`
    ///
    /// The combined size of `output_reserve` and `reserved_slots` is not
    /// checked against `max_tokens`; see [`ContextBudget::unreserved_capacity`].
    ///
    /// # Errors
    ///
    /// Returns [`CupelError::InvalidBudget`] with a message describing the
    /// first violated rule.
    pub fn new(
        max_tokens: i64,
        target_tokens: i64,
        output_reserve: i64,
        reserved_slots: HashMap<ContextKind, i64>,
        estimation_safety_margin_percent: f64,
    ) -> Result<Self, CupelError> {
        if max_tokens < 0 {
            return Err(CupelError::InvalidBudget(format!(
                "max_tokens ({max_tokens}) must be >= 0"
            )));
        }
        if target_tokens < 0 {
            return Err(CupelError::InvalidBudget(format!(
                "target_tokens ({target_tokens}) must be >= 0"
            )));
        }
        if target_tokens > max_tokens {
            return Err(CupelError::InvalidBudget(format!(
                "target_tokens ({target_tokens}) must be <= max_tokens ({max_tokens})"
            )));
        }
        if output_reserve < 0 {
            return Err(CupelError::InvalidBudget(format!(
                "output_reserve ({output_reserve}) must be >= 0"
            )));
        }
        if output_reserve > max_tokens {
            return Err(CupelError::InvalidBudget(format!(
                "output_reserve ({output_reserve}) must be <= max_tokens ({max_tokens})"
            )));
        }
        // `RangeInclusive::contains` is false for NaN, so NaN fails this check too.
        if !(0.0..=100.0).contains(&estimation_safety_margin_percent) {
            return Err(CupelError::InvalidBudget(format!(
                "estimation_safety_margin_percent ({estimation_safety_margin_percent}) must be in [0.0, 100.0]"
            )));
        }
        for (kind, &count) in &reserved_slots {
            if count < 0 {
                return Err(CupelError::InvalidBudget(format!(
                    "reserved slot count for kind '{}' must be >= 0",
                    kind,
                )));
            }
        }
        Ok(Self {
            max_tokens,
            target_tokens,
            output_reserve,
            reserved_slots,
            estimation_safety_margin_percent,
        })
    }

    /// Returns the `max_tokens` value supplied at construction.
    #[must_use]
    pub fn max_tokens(&self) -> i64 {
        self.max_tokens
    }

    /// Returns the `target_tokens` value supplied at construction.
    #[must_use]
    pub fn target_tokens(&self) -> i64 {
        self.target_tokens
    }

    /// Returns the `output_reserve` value supplied at construction.
    #[must_use]
    pub fn output_reserve(&self) -> i64 {
        self.output_reserve
    }

    /// Returns the per-kind reservations supplied at construction.
    #[must_use]
    pub fn reserved_slots(&self) -> &HashMap<ContextKind, i64> {
        &self.reserved_slots
    }

    /// Returns the safety margin percentage supplied at construction.
    #[must_use]
    pub fn estimation_safety_margin_percent(&self) -> f64 {
        self.estimation_safety_margin_percent
    }

    /// Returns `output_reserve` plus the sum of all reserved-slot values.
    ///
    /// The addition saturates at `i64::MAX`. Individual slot values have no
    /// upper bound, so a plain sum could overflow (panic in debug builds,
    /// wrap to a negative total in release builds).
    #[must_use]
    pub fn total_reserved(&self) -> i64 {
        // Every operand is non-negative (enforced by `new`), so the saturated
        // result does not depend on the map's iteration order.
        self.reserved_slots
            .values()
            .fold(self.output_reserve, |total, &slot| {
                total.saturating_add(slot)
            })
    }

    /// Returns `max_tokens` minus [`ContextBudget::total_reserved`].
    ///
    /// The result is negative when the reservations exceed `max_tokens`.
    #[must_use]
    pub fn unreserved_capacity(&self) -> i64 {
        // Cannot overflow: `max_tokens >= 0` and `0 <= total_reserved <= i64::MAX`,
        // so the difference is never below `-i64::MAX`.
        self.max_tokens - self.total_reserved()
    }

    /// Returns `true` when [`ContextBudget::unreserved_capacity`] is strictly positive.
    #[must_use]
    pub fn has_capacity(&self) -> bool {
        self.unreserved_capacity() > 0
    }
}
#[cfg(feature = "serde")]
impl Serialize for ContextBudget {
    /// Emits the budget as a struct named `ContextBudget` with its five
    /// fields, in declaration order.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        const FIELD_COUNT: usize = 5;

        // Exhaustive destructuring: adding a field to the struct fails to
        // compile here until the serializer is updated as well.
        let Self {
            max_tokens,
            target_tokens,
            output_reserve,
            reserved_slots,
            estimation_safety_margin_percent,
        } = self;

        let mut out = serializer.serialize_struct("ContextBudget", FIELD_COUNT)?;
        out.serialize_field("max_tokens", max_tokens)?;
        out.serialize_field("target_tokens", target_tokens)?;
        out.serialize_field("output_reserve", output_reserve)?;
        out.serialize_field("reserved_slots", reserved_slots)?;
        out.serialize_field(
            "estimation_safety_margin_percent",
            estimation_safety_margin_percent,
        )?;
        out.end()
    }
}
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for ContextBudget {
    /// Reads the fields into an unvalidated mirror struct, then passes them
    /// through `ContextBudget::new` so the constructor's rules also apply to
    /// deserialized data.
    ///
    /// Unknown fields are rejected; `reserved_slots` may be omitted, in which
    /// case it defaults to an empty map. Constructor failures are reported as
    /// custom deserialization errors.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        #[derive(Deserialize)]
        #[serde(deny_unknown_fields)]
        struct Raw {
            max_tokens: i64,
            target_tokens: i64,
            output_reserve: i64,
            #[serde(default)]
            reserved_slots: HashMap<ContextKind, i64>,
            estimation_safety_margin_percent: f64,
        }

        let Raw {
            max_tokens,
            target_tokens,
            output_reserve,
            reserved_slots,
            estimation_safety_margin_percent,
        } = Raw::deserialize(deserializer)?;

        let validated = Self::new(
            max_tokens,
            target_tokens,
            output_reserve,
            reserved_slots,
            estimation_safety_margin_percent,
        );
        validated.map_err(<D::Error as serde::de::Error>::custom)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 128 000-token budget with a 4 096 output reserve and two reserved
    /// slots (2 000 for messages, 3 000 for documents).
    fn budget_with_slots() -> ContextBudget {
        let slots = HashMap::from([
            (ContextKind::message(), 2000),
            (ContextKind::document(), 3000),
        ]);
        ContextBudget::new(128_000, 100_000, 4_096, slots, 5.0).unwrap()
    }

    /// 100 000-token budget with a 20 000 output reserve and a single
    /// message slot of the given size.
    fn budget_with_message_slot(slot: i64) -> ContextBudget {
        let slots = HashMap::from([(ContextKind::message(), slot)]);
        ContextBudget::new(100_000, 80_000, 20_000, slots, 0.0).unwrap()
    }

    #[test]
    fn total_reserved_includes_output_reserve_and_slots() {
        assert_eq!(budget_with_slots().total_reserved(), 4_096 + 2000 + 3000);
    }

    #[test]
    fn total_reserved_no_slots_returns_output_reserve() {
        let budget = ContextBudget::new(4096, 3000, 1024, HashMap::new(), 0.0).unwrap();
        assert_eq!(budget.total_reserved(), 1024);
    }

    #[test]
    fn unreserved_capacity_positive() {
        assert_eq!(
            budget_with_slots().unreserved_capacity(),
            128_000 - 4_096 - 5_000
        );
    }

    #[test]
    fn unreserved_capacity_negative_over_committed() {
        // 20 000 + 90 000 reserved against a 100 000 limit.
        let budget = budget_with_message_slot(90_000);
        assert_eq!(budget.unreserved_capacity(), -10_000);
    }

    #[test]
    fn has_capacity_true_when_positive() {
        assert!(budget_with_slots().has_capacity());
    }

    #[test]
    fn has_capacity_false_when_zero() {
        // 20 000 + 80 000 reserved exactly fills the 100 000 limit.
        let budget = budget_with_message_slot(80_000);
        assert!(!budget.has_capacity());
    }

    #[test]
    fn has_capacity_false_when_negative() {
        let budget = budget_with_message_slot(90_000);
        assert!(!budget.has_capacity());
    }
}