use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
/// Coarse health of a token budget, as reported by
/// `RetrievalBudgetController::status`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BudgetStatus {
    /// Utilization is below the controller's constrain threshold.
    Healthy,
    /// Utilization has reached the constrain threshold, but tokens remain.
    Constrained,
    /// The budget is used up, or exhaustion was signaled explicitly.
    Exhausted,
}

impl BudgetStatus {
    /// Returns `true` for every status except [`BudgetStatus::Exhausted`].
    pub fn allow_llm(self) -> bool {
        // Exhaustive on purpose: a new variant must decide explicitly.
        match self {
            Self::Healthy | Self::Constrained => true,
            Self::Exhausted => false,
        }
    }

    /// Returns `true` only for [`BudgetStatus::Exhausted`].
    pub fn should_stop(self) -> bool {
        matches!(self, Self::Exhausted)
    }
}
/// Tracks token consumption against a fixed budget and derives a
/// [`BudgetStatus`] from it.
///
/// The counters are atomics, so every recording/query method takes `&self`.
/// All atomic accesses use `Ordering::Relaxed`.
pub struct RetrievalBudgetController {
    /// Token ceiling; set at construction and never modified afterwards.
    total_budget: usize,
    /// Running total of tokens passed to `record_tokens` since the last reset.
    consumed: AtomicUsize,
    /// Set by `signal_exhausted`, or by `status` once it observes
    /// `consumed >= total_budget`; cleared only by `reset`.
    exhaustion_signaled: AtomicBool,
    /// Utilization fraction in `[0.0, 1.0]` at or above which `status`
    /// reports `Constrained`.
    constrain_threshold: f32,
}

impl Clone for RetrievalBudgetController {
    /// Copies the current counter values into fresh atomics; the clone and the
    /// original do not share state afterwards.
    fn clone(&self) -> Self {
        Self {
            total_budget: self.total_budget,
            consumed: AtomicUsize::new(self.consumed.load(Ordering::Relaxed)),
            exhaustion_signaled: AtomicBool::new(self.exhaustion_signaled.load(Ordering::Relaxed)),
            constrain_threshold: self.constrain_threshold,
        }
    }
}

impl RetrievalBudgetController {
    /// Constrain threshold used by `new`.
    const DEFAULT_CONSTRAIN_THRESHOLD: f32 = 0.7;
    /// Confidence above which the search stops once `iteration >= 1`.
    const CONFIDENT_STOP: f32 = 0.8;
    /// Confidence above which the search stops while the budget is `Constrained`.
    const CONSTRAINED_STOP: f32 = 0.4;

    /// Creates a controller with `total_budget` tokens, nothing consumed, and
    /// the default constrain threshold of `0.7`.
    pub fn new(total_budget: usize) -> Self {
        Self {
            total_budget,
            consumed: AtomicUsize::new(0),
            exhaustion_signaled: AtomicBool::new(false),
            constrain_threshold: Self::DEFAULT_CONSTRAIN_THRESHOLD,
        }
    }

    /// Builder-style setter for the constrain threshold.
    ///
    /// `threshold` is clamped to `[0.0, 1.0]`. A NaN threshold is ignored and
    /// the current threshold is kept: `f32::clamp` passes NaN through, and a
    /// stored NaN would make every `>=` comparison in `status` false, silently
    /// disabling the `Constrained` state.
    pub fn with_constrain_threshold(mut self, threshold: f32) -> Self {
        if !threshold.is_nan() {
            self.constrain_threshold = threshold.clamp(0.0, 1.0);
        }
        self
    }

    /// Returns the current budget status.
    ///
    /// * `Exhausted` if exhaustion was signaled or `consumed >= total_budget`
    ///   (which is always the case for a zero budget).
    /// * `Constrained` if utilization is at or above the constrain threshold.
    /// * `Healthy` otherwise.
    ///
    /// Side effect: when exhaustion is detected from the counter, the
    /// exhaustion flag is set so later calls return `Exhausted` directly.
    pub fn status(&self) -> BudgetStatus {
        if self.exhaustion_signaled.load(Ordering::Relaxed) {
            return BudgetStatus::Exhausted;
        }
        let consumed = self.consumed.load(Ordering::Relaxed);
        if consumed >= self.total_budget {
            self.exhaustion_signaled.store(true, Ordering::Relaxed);
            return BudgetStatus::Exhausted;
        }
        // total_budget > consumed >= 0 here, so the divisor is non-zero.
        let utilization = consumed as f32 / self.total_budget as f32;
        if utilization >= self.constrain_threshold {
            BudgetStatus::Constrained
        } else {
            BudgetStatus::Healthy
        }
    }

    /// Adds `tokens` to the consumed counter.
    ///
    /// The counter saturates at `usize::MAX` instead of wrapping, so an
    /// exhausted budget can never appear healthy again because of overflow.
    pub fn record_tokens(&self, tokens: usize) {
        // The closure always returns `Some`, so `fetch_update` cannot fail;
        // the returned previous value is not needed.
        let _ = self
            .consumed
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_add(tokens))
            });
    }

    /// Returns the number of tokens recorded so far.
    pub fn consumed(&self) -> usize {
        self.consumed.load(Ordering::Relaxed)
    }

    /// Returns the tokens left before the budget is reached (`0` once the
    /// budget is met or exceeded).
    pub fn remaining(&self) -> usize {
        self.total_budget
            .saturating_sub(self.consumed.load(Ordering::Relaxed))
    }

    /// Returns the token ceiling this controller was created with.
    pub fn total_budget(&self) -> usize {
        self.total_budget
    }

    /// Returns consumed / total as a fraction capped at `1.0`.
    ///
    /// A zero budget reports `0.0` (even though `status` reports it as
    /// `Exhausted`).
    pub fn utilization(&self) -> f32 {
        if self.total_budget == 0 {
            0.0
        } else {
            (self.consumed.load(Ordering::Relaxed) as f32 / self.total_budget as f32).min(1.0)
        }
    }

    /// Forces the controller into the exhausted state regardless of how many
    /// tokens were consumed. Only `reset` undoes this.
    pub fn signal_exhausted(&self) {
        self.exhaustion_signaled.store(true, Ordering::Relaxed);
    }

    /// Returns `true` if exhaustion was signaled or the counter has reached
    /// the budget. Unlike `status`, this does not set the exhaustion flag.
    pub fn is_exhausted(&self) -> bool {
        self.exhaustion_signaled.load(Ordering::Relaxed)
            || self.consumed.load(Ordering::Relaxed) >= self.total_budget
    }

    /// Clears the consumed counter and the exhaustion flag.
    ///
    /// The two fields are cleared by two separate stores, counter first.
    pub fn reset(&self) {
        self.consumed.store(0, Ordering::Relaxed);
        self.exhaustion_signaled.store(false, Ordering::Relaxed);
    }

    /// Suggests a beam width for the next iteration based on budget status.
    ///
    /// * `Healthy`: `current_beam` unchanged.
    /// * `Constrained`: unchanged for `iteration <= 1`, otherwise half of
    ///   `current_beam` but never less than `1`.
    /// * `Exhausted`: `0`.
    pub fn suggested_beam_width(&self, current_beam: usize, iteration: usize) -> usize {
        match self.status() {
            BudgetStatus::Healthy => current_beam,
            // Early iterations keep the full beam even under pressure.
            BudgetStatus::Constrained if iteration <= 1 => current_beam,
            BudgetStatus::Constrained => (current_beam / 2).max(1),
            BudgetStatus::Exhausted => 0,
        }
    }

    /// Decides whether another search iteration should run.
    ///
    /// Returns `false` when any of the following holds:
    /// * the budget is exhausted;
    /// * `current_confidence > 0.8` and `iteration >= 1`;
    /// * the budget is `Constrained` and `current_confidence > 0.4`.
    ///
    /// Returns `true` otherwise.
    pub fn should_continue_search(&self, current_confidence: f32, iteration: usize) -> bool {
        if self.is_exhausted() {
            return false;
        }
        if current_confidence > Self::CONFIDENT_STOP && iteration >= 1 {
            return false;
        }
        // Under budget pressure a moderately confident answer is good enough.
        if self.status() == BudgetStatus::Constrained
            && current_confidence > Self::CONSTRAINED_STOP
        {
            return false;
        }
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Token ceiling shared by every controller built in this module.
    const CAPACITY: usize = 1000;

    /// Builds a controller with the shared capacity and `spent` tokens
    /// already recorded.
    fn controller_with(spent: usize) -> RetrievalBudgetController {
        let ctrl = RetrievalBudgetController::new(CAPACITY);
        ctrl.record_tokens(spent);
        ctrl
    }

    #[test]
    fn test_budget_healthy() {
        let ctrl = controller_with(0);
        assert_eq!(ctrl.status(), BudgetStatus::Healthy);
        assert!(!ctrl.is_exhausted());
        assert_eq!(ctrl.remaining(), CAPACITY);
    }

    #[test]
    fn test_budget_constrained() {
        // 75% utilization is above the default 0.7 threshold.
        let ctrl = controller_with(750);
        assert_eq!(ctrl.status(), BudgetStatus::Constrained);
        assert!(ctrl.status().allow_llm());
    }

    #[test]
    fn test_budget_exhausted() {
        let ctrl = controller_with(CAPACITY);
        let observed = ctrl.status();
        assert_eq!(observed, BudgetStatus::Exhausted);
        assert!(ctrl.status().should_stop());
        assert!(!ctrl.status().allow_llm());
    }

    #[test]
    fn test_budget_exhausted_over() {
        let ctrl = controller_with(1500);
        assert_eq!(ctrl.status(), BudgetStatus::Exhausted);
    }

    #[test]
    fn test_budget_signal_exhausted() {
        let ctrl = controller_with(0);
        ctrl.signal_exhausted();
        assert_eq!(ctrl.status(), BudgetStatus::Exhausted);
        // Signaling does not touch the token counter.
        assert_eq!(ctrl.consumed(), 0);
    }

    #[test]
    fn test_budget_reset() {
        let ctrl = controller_with(800);
        assert_eq!(ctrl.status(), BudgetStatus::Constrained);

        ctrl.reset();
        assert_eq!(ctrl.status(), BudgetStatus::Healthy);
        assert_eq!(ctrl.consumed(), 0);
    }

    #[test]
    fn test_suggested_beam_width() {
        let ctrl = controller_with(0);
        // Healthy: beam untouched.
        assert_eq!(ctrl.suggested_beam_width(4, 0), 4);

        // Constrained: untouched early on, halved from iteration 2.
        ctrl.record_tokens(750);
        assert_eq!(ctrl.suggested_beam_width(4, 0), 4);
        assert_eq!(ctrl.suggested_beam_width(4, 2), 2);

        // Exhausted (1050 of 1000): no beam at all.
        ctrl.record_tokens(300);
        assert_eq!(ctrl.suggested_beam_width(4, 0), 0);
    }

    #[test]
    fn test_should_continue_search() {
        let ctrl = controller_with(0);
        assert!(ctrl.should_continue_search(0.2, 0));
        assert!(!ctrl.should_continue_search(0.9, 1));
        assert!(ctrl.should_continue_search(0.5, 1));

        ctrl.record_tokens(750);
        assert!(!ctrl.should_continue_search(0.5, 2));
        assert!(ctrl.should_continue_search(0.2, 2));
    }

    #[test]
    fn test_utilization() {
        let ctrl = controller_with(0);
        assert!(ctrl.utilization().abs() < 0.01);

        ctrl.record_tokens(500);
        assert!((ctrl.utilization() - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_custom_constrain_threshold() {
        let ctrl = RetrievalBudgetController::new(CAPACITY).with_constrain_threshold(0.5);
        ctrl.record_tokens(500);
        assert_eq!(ctrl.status(), BudgetStatus::Constrained);
    }
}