Skip to main content

oxibonsai_runtime/
token_budget.rs

1//! Token budget management: enforce per-request and global token limits.
2
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5
6use thiserror::Error;
7
8// ─── Error type ──────────────────────────────────────────────────────────────
9
10/// Errors raised by budget enforcement.
11#[derive(Debug, Error)]
12pub enum BudgetError {
13    #[error("prompt tokens {prompt} exceeds max_prompt_tokens {max}")]
14    PromptTooLong { prompt: usize, max: usize },
15    #[error("completion token budget exhausted (limit = {limit})")]
16    CompletionBudgetExhausted { limit: usize },
17    #[error("total token budget exhausted (limit = {limit}, used = {used})")]
18    TotalBudgetExhausted { limit: usize, used: usize },
19}
20
21// ─── Policy ───────────────────────────────────────────────────────────────────
22
/// Action taken when a budget limit is reached.
///
/// Nothing in this module acts on the policy by itself: it is stored in the
/// budget configuration and exposed for callers to interpret.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BudgetPolicy {
    /// Stop generation cleanly.
    StopGeneration,
    /// Truncate the oldest context.
    TruncateContext,
    /// Return an error.
    ReturnError,
}
33
34// ─── Budget configuration ─────────────────────────────────────────────────────
35
36/// Configuration for token budget enforcement.
37#[derive(Debug, Clone)]
38pub struct BudgetConfig {
39    /// Maximum tokens allowed in the prompt.
40    pub max_prompt_tokens: Option<usize>,
41    /// Maximum tokens to generate (completion only).
42    pub max_completion_tokens: Option<usize>,
43    /// Maximum prompt + completion tokens combined.
44    pub max_total_tokens: Option<usize>,
45    /// Policy to apply when a limit is breached.
46    pub policy: BudgetPolicy,
47}
48
49impl BudgetConfig {
50    /// Create a config with no limits and the default policy (`StopGeneration`).
51    pub fn new() -> Self {
52        Self {
53            max_prompt_tokens: None,
54            max_completion_tokens: None,
55            max_total_tokens: None,
56            policy: BudgetPolicy::StopGeneration,
57        }
58    }
59
60    /// Set the maximum completion tokens.
61    pub fn with_max_completion(mut self, n: usize) -> Self {
62        self.max_completion_tokens = Some(n);
63        self
64    }
65
66    /// Set the maximum total (prompt + completion) tokens.
67    pub fn with_max_total(mut self, n: usize) -> Self {
68        self.max_total_tokens = Some(n);
69        self
70    }
71
72    /// Override the enforcement policy.
73    pub fn with_policy(mut self, policy: BudgetPolicy) -> Self {
74        self.policy = policy;
75        self
76    }
77
78    /// Convenience: no limits whatsoever.
79    pub fn unlimited() -> Self {
80        Self::new()
81    }
82}
83
84impl Default for BudgetConfig {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
90// ─── Per-request budget ───────────────────────────────────────────────────────
91
92/// Per-request token budget tracker.
93#[derive(Debug)]
94pub struct RequestBudget {
95    config: BudgetConfig,
96    prompt_tokens: usize,
97    completion_tokens: usize,
98}
99
100impl RequestBudget {
101    /// Create a new `RequestBudget`, validating the initial prompt length.
102    ///
103    /// Returns `Err(BudgetError::PromptTooLong)` if `prompt_tokens` exceeds
104    /// `config.max_prompt_tokens`.
105    pub fn new(config: BudgetConfig, prompt_tokens: usize) -> Result<Self, BudgetError> {
106        if let Some(max) = config.max_prompt_tokens {
107            if prompt_tokens > max {
108                return Err(BudgetError::PromptTooLong {
109                    prompt: prompt_tokens,
110                    max,
111                });
112            }
113        }
114        Ok(Self {
115            config,
116            prompt_tokens,
117            completion_tokens: 0,
118        })
119    }
120
121    /// Record one generated token.
122    ///
123    /// Returns an error (according to the configured policy) if any limit is
124    /// exceeded after recording the token.
125    pub fn record_token(&mut self) -> Result<(), BudgetError> {
126        self.record_tokens(1)
127    }
128
129    /// Record `n` generated tokens.
130    ///
131    /// Limits are checked after adding `n`.
132    pub fn record_tokens(&mut self, n: usize) -> Result<(), BudgetError> {
133        self.completion_tokens = self.completion_tokens.saturating_add(n);
134
135        // Check completion limit.
136        if let Some(limit) = self.config.max_completion_tokens {
137            if self.completion_tokens > limit {
138                return Err(BudgetError::CompletionBudgetExhausted { limit });
139            }
140        }
141
142        // Check total limit.
143        if let Some(limit) = self.config.max_total_tokens {
144            let used = self.total_tokens();
145            if used > limit {
146                return Err(BudgetError::TotalBudgetExhausted { limit, used });
147            }
148        }
149
150        Ok(())
151    }
152
153    /// Number of tokens in the prompt.
154    pub fn prompt_tokens(&self) -> usize {
155        self.prompt_tokens
156    }
157
158    /// Number of tokens generated so far.
159    pub fn completion_tokens(&self) -> usize {
160        self.completion_tokens
161    }
162
163    /// Prompt + completion tokens.
164    pub fn total_tokens(&self) -> usize {
165        self.prompt_tokens.saturating_add(self.completion_tokens)
166    }
167
168    /// How many more completion tokens can be generated, or `None` if unlimited.
169    pub fn remaining_completion_tokens(&self) -> Option<usize> {
170        self.config
171            .max_completion_tokens
172            .map(|limit| limit.saturating_sub(self.completion_tokens))
173    }
174
175    /// Whether any configured budget is exhausted.
176    pub fn is_exhausted(&self) -> bool {
177        if let Some(limit) = self.config.max_completion_tokens {
178            if self.completion_tokens >= limit {
179                return true;
180            }
181        }
182        if let Some(limit) = self.config.max_total_tokens {
183            if self.total_tokens() >= limit {
184                return true;
185            }
186        }
187        false
188    }
189
190    /// The policy that governs how exhaustion is handled.
191    pub fn policy(&self) -> BudgetPolicy {
192        self.config.policy
193    }
194}
195
196// ─── Global token budget ──────────────────────────────────────────────────────
197
/// Global token budget shared across requests.
///
/// The counter lives behind an `Arc<AtomicU64>`. Cloning a
/// `GlobalTokenBudget` is cheap and yields another handle onto the *same*
/// counter, so clones can be handed to independent request handlers. The
/// cap is advisory: [`GlobalTokenBudget::record`] never refuses tokens, so
/// callers check [`GlobalTokenBudget::is_exhausted`] or
/// [`GlobalTokenBudget::remaining`] themselves.
#[derive(Debug, Clone)]
pub struct GlobalTokenBudget {
    total_tokens_used: Arc<AtomicU64>,
    max_tokens: Option<u64>,
}

impl GlobalTokenBudget {
    /// Create a global budget with an optional hard cap and a zeroed counter.
    pub fn new(max_tokens: Option<u64>) -> Self {
        Self {
            total_tokens_used: Arc::new(AtomicU64::new(0)),
            max_tokens,
        }
    }

    /// Convenience: no cap.
    pub fn unlimited() -> Self {
        Self::new(None)
    }

    /// Add `tokens` to the shared counter.
    ///
    /// The counter saturates at `u64::MAX` instead of wrapping. A wrapped
    /// counter would make an exhausted budget look nearly empty again.
    pub fn record(&self, tokens: u64) {
        // The closure always returns `Some`, so `fetch_update` cannot fail.
        let _ = self
            .total_tokens_used
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |used| {
                Some(used.saturating_add(tokens))
            });
    }

    /// Total tokens consumed so far, across every clone of this budget.
    pub fn total_used(&self) -> u64 {
        self.total_tokens_used.load(Ordering::Relaxed)
    }

    /// How many tokens remain before the cap, or `None` if unlimited.
    pub fn remaining(&self) -> Option<u64> {
        self.max_tokens
            .map(|cap| cap.saturating_sub(self.total_used()))
    }

    /// Whether the global cap has been reached (always `false` when unlimited).
    pub fn is_exhausted(&self) -> bool {
        match self.max_tokens {
            None => false,
            Some(cap) => self.total_used() >= cap,
        }
    }

    /// Fraction of the cap consumed (`total_used / max_tokens`), or `None`
    /// if the budget is unlimited.
    ///
    /// A zero cap reports `1.0` (fully consumed) to avoid dividing by zero.
    /// The value may exceed `1.0` because `record` does not enforce the cap.
    pub fn utilization(&self) -> Option<f32> {
        self.max_tokens.map(|cap| {
            if cap == 0 {
                1.0
            } else {
                self.total_used() as f32 / cap as f32
            }
        })
    }
}
254
255// ─── Cost estimator ───────────────────────────────────────────────────────────
256
/// Estimated monetary cost of one request, split into prompt and completion
/// parts (intended for billing and logging).
#[derive(Debug, Clone)]
pub struct TokenCostEstimate {
    /// Number of prompt tokens the estimate covers.
    pub prompt_tokens: usize,
    /// Number of completion tokens the estimate covers.
    pub completion_tokens: usize,
    /// Cost attributed to the prompt tokens.
    pub prompt_cost: f64,
    /// Cost attributed to the completion tokens.
    pub completion_cost: f64,
    /// Sum of `prompt_cost` and `completion_cost`.
    pub total_cost: f64,
}

impl TokenCostEstimate {
    /// Build an estimate from token counts and per-1 000-token prices.
    ///
    /// Each part is `tokens / 1000 * rate`; the total is their sum.
    pub fn compute(
        prompt_tokens: usize,
        completion_tokens: usize,
        prompt_cost_per_1k: f64,
        completion_cost_per_1k: f64,
    ) -> Self {
        // Same operation order for both parts: scale tokens to thousands,
        // then apply the rate.
        let priced = |count: usize, rate_per_1k: f64| count as f64 / 1_000.0 * rate_per_1k;

        let prompt_part = priced(prompt_tokens, prompt_cost_per_1k);
        let completion_part = priced(completion_tokens, completion_cost_per_1k);

        TokenCostEstimate {
            prompt_tokens,
            completion_tokens,
            prompt_cost: prompt_part,
            completion_cost: completion_part,
            total_cost: prompt_part + completion_part,
        }
    }

    /// One-line summary of token counts and costs.
    ///
    /// Costs are printed with six decimal places and a `$` prefix.
    pub fn summary(&self) -> String {
        let TokenCostEstimate {
            prompt_tokens: p_tok,
            completion_tokens: c_tok,
            prompt_cost: p_cost,
            completion_cost: c_cost,
            total_cost: all_cost,
        } = self;
        format!(
            "tokens: prompt={} completion={} | cost: prompt=${:.6} completion=${:.6} total=${:.6}",
            p_tok, c_tok, p_cost, c_cost, all_cost
        )
    }
}