Skip to main content

lutum_protocol/
budget.rs

1//! Core budget primitives intentionally live in the shared protocol crate.
2//!
3//! Budget policy is user-driven through the [`BudgetManager`] trait, so
4//! adapters and higher-level crates can plug in their own accounting and
5//! enforcement strategies without moving budget ownership out of core.
6
7use std::{
8    collections::BTreeMap,
9    sync::{
10        Arc, Mutex,
11        atomic::{AtomicU64, Ordering},
12    },
13};
14
15use thiserror::Error;
16
17use crate::{AgentError, RequestExtensions};
18
/// Up-front estimate of what a request is expected to consume.
///
/// This is the value handed to [`BudgetManager::reserve`]. The type itself does
/// not enforce any relationship between the fields (for instance, `total_tokens`
/// is not checked against `input_tokens + output_tokens`). The built-in
/// [`SharedPoolBudgetManager`] only reads `total_tokens` and `cost_micros_usd`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct UsageEstimate {
    /// Estimated token count on the input side.
    pub input_tokens: u64,
    /// Estimated token count on the output side.
    pub output_tokens: u64,
    /// Estimated overall token count.
    pub total_tokens: u64,
    /// Estimated cost; the field name indicates micro-USD units.
    pub cost_micros_usd: u64,
}

impl UsageEstimate {
    /// Builds an estimate whose every field is `0`.
    ///
    /// Produces the same value as `UsageEstimate::default()`, but is a
    /// `const fn` and can therefore be used in constant contexts.
    pub const fn zero() -> Self {
        Self {
            input_tokens: 0,
            output_tokens: 0,
            total_tokens: 0,
            cost_micros_usd: 0,
        }
    }
}
37
38#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
39pub struct Usage {
40    pub input_tokens: u64,
41    pub output_tokens: u64,
42    pub total_tokens: u64,
43    pub cost_micros_usd: u64,
44}
45
46impl Usage {
47    pub const fn zero() -> Self {
48        Self {
49            input_tokens: 0,
50            output_tokens: 0,
51            total_tokens: 0,
52            cost_micros_usd: 0,
53        }
54    }
55}
56
/// Snapshot of the budget still available, as returned by
/// [`BudgetManager::remaining`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Remaining {
    /// Tokens still available.
    pub tokens: u64,
    /// Cost still available; the field name indicates micro-USD units.
    pub cost_micros_usd: u64,
    /// Whether the manager considers the budget to be at or past its stop
    /// threshold.
    ///
    /// [`SharedPoolBudgetManager`] sets this when either remaining amount is
    /// less than or equal to its configured stop threshold, and also when its
    /// internal lock is poisoned.
    pub below_threshold: bool,
}
63
/// Optional per-request ceilings on tokens and cost.
///
/// A `None` limit leaves that dimension unconstrained; the `Default` value
/// therefore imposes no limits at all.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RequestBudget {
    /// Inclusive token ceiling, or `None` for no token limit.
    pub tokens: Option<u64>,
    /// Inclusive cost ceiling, or `None` for no cost limit.
    pub cost_micros_usd: Option<u64>,
}

impl RequestBudget {
    /// A budget with neither a token nor a cost ceiling.
    pub const fn unlimited() -> Self {
        Self::with_limits(None, None)
    }

    /// A budget that caps tokens only.
    pub const fn from_tokens(tokens: u64) -> Self {
        Self::with_limits(Some(tokens), None)
    }

    /// A budget that caps cost only.
    pub const fn from_cost_micros_usd(cost_micros_usd: u64) -> Self {
        Self::with_limits(None, Some(cost_micros_usd))
    }

    /// A budget with each ceiling given explicitly (`None` means unlimited).
    pub const fn with_limits(tokens: Option<u64>, cost_micros_usd: Option<u64>) -> Self {
        Self {
            tokens,
            cost_micros_usd,
        }
    }

    /// Returns `true` when both amounts fit within this budget.
    ///
    /// Each ceiling is inclusive: an amount equal to its limit is allowed.
    fn allows(self, tokens: u64, cost_micros_usd: u64) -> bool {
        let tokens_fit = match self.tokens {
            Some(ceiling) => tokens <= ceiling,
            None => true,
        };
        let cost_fits = match self.cost_micros_usd {
            Some(ceiling) => cost_micros_usd <= ceiling,
            None => true,
        };
        tokens_fit && cost_fits
    }
}
106
107#[derive(Clone, Debug, Eq, PartialEq)]
108pub struct BudgetLease {
109    id: u64,
110    reserved: UsageEstimate,
111    request_budget: RequestBudget,
112}
113
114impl BudgetLease {
115    pub fn new(id: u64, reserved: UsageEstimate, request_budget: RequestBudget) -> Self {
116        Self {
117            id,
118            reserved,
119            request_budget,
120        }
121    }
122
123    pub fn id(&self) -> u64 {
124        self.id
125    }
126
127    pub fn reserved(&self) -> UsageEstimate {
128        self.reserved
129    }
130
131    pub fn request_budget(&self) -> RequestBudget {
132        self.request_budget
133    }
134}
135
136pub trait BudgetManager: Send + Sync + 'static {
137    fn remaining(&self, extensions: &RequestExtensions) -> Remaining;
138    fn reserve(
139        &self,
140        extensions: &RequestExtensions,
141        estimate: &UsageEstimate,
142        request_budget: RequestBudget,
143    ) -> Result<BudgetLease, AgentError>;
144    fn record_used(&self, lease: BudgetLease, usage: Usage) -> Result<(), AgentError>;
145}
146
147impl<T> BudgetManager for Arc<T>
148where
149    T: BudgetManager + ?Sized,
150{
151    fn remaining(&self, extensions: &RequestExtensions) -> Remaining {
152        (**self).remaining(extensions)
153    }
154
155    fn reserve(
156        &self,
157        extensions: &RequestExtensions,
158        estimate: &UsageEstimate,
159        request_budget: RequestBudget,
160    ) -> Result<BudgetLease, AgentError> {
161        (**self).reserve(extensions, estimate, request_budget)
162    }
163
164    fn record_used(&self, lease: BudgetLease, usage: Usage) -> Result<(), AgentError> {
165        (**self).record_used(lease, usage)
166    }
167}
168
/// Configuration for [`SharedPoolBudgetManager`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct SharedPoolBudgetOptions {
    /// Total tokens the pool may hand out (committed plus reserved).
    pub capacity_tokens: u64,
    /// Total cost the pool may hand out (committed plus reserved).
    pub capacity_cost_micros_usd: u64,
    /// Token floor: a reservation that would leave fewer remaining tokens than
    /// this is refused, and `Remaining::below_threshold` is reported once the
    /// remaining tokens are at or below it.
    pub stop_threshold_tokens: u64,
    /// Cost floor, applied the same way as `stop_threshold_tokens`.
    pub stop_threshold_cost_micros_usd: u64,
}

impl Default for SharedPoolBudgetOptions {
    /// Maximum representable capacity in both dimensions, with both stop
    /// thresholds at zero.
    fn default() -> Self {
        let max_capacity = u64::MAX;
        let zero_floor = 0;

        Self {
            capacity_tokens: max_capacity,
            capacity_cost_micros_usd: max_capacity,
            stop_threshold_tokens: zero_floor,
            stop_threshold_cost_micros_usd: zero_floor,
        }
    }
}
187
188#[derive(Debug, Error, Clone, Eq, PartialEq)]
189pub enum SharedPoolBudgetError {
190    #[error(
191        "request budget exceeded: requested {requested_tokens} tokens / {requested_cost_micros_usd} micros exceeds tokens={budget_tokens:?}, cost={budget_cost_micros_usd:?}"
192    )]
193    RequestBudgetExceeded {
194        requested_tokens: u64,
195        requested_cost_micros_usd: u64,
196        budget_tokens: Option<u64>,
197        budget_cost_micros_usd: Option<u64>,
198    },
199    #[error(
200        "reserving {requested_tokens} tokens / {requested_cost_micros_usd} micros would cross the stop threshold"
201    )]
202    ThresholdExceeded {
203        requested_tokens: u64,
204        requested_cost_micros_usd: u64,
205        remaining_tokens: u64,
206        remaining_cost_micros_usd: u64,
207    },
208    #[error("unknown budget lease {lease_id}")]
209    UnknownLease { lease_id: u64 },
210    #[error("shared budget state poisoned")]
211    Poisoned,
212}
213
214#[derive(Clone)]
215pub struct SharedPoolBudgetManager {
216    options: SharedPoolBudgetOptions,
217    next_lease_id: Arc<AtomicU64>,
218    state: Arc<Mutex<SharedPoolBudgetState>>,
219}
220
221#[derive(Debug, Default)]
222struct SharedPoolBudgetState {
223    committed_tokens: u64,
224    committed_cost_micros_usd: u64,
225    reserved_tokens: u64,
226    reserved_cost_micros_usd: u64,
227    leases: BTreeMap<u64, (UsageEstimate, RequestBudget)>,
228}
229
230impl SharedPoolBudgetManager {
231    pub fn new(options: SharedPoolBudgetOptions) -> Self {
232        Self {
233            options,
234            next_lease_id: Arc::new(AtomicU64::new(1)),
235            state: Arc::new(Mutex::new(SharedPoolBudgetState::default())),
236        }
237    }
238
239    fn remaining_with_state(&self, state: &SharedPoolBudgetState) -> Remaining {
240        let tokens = self
241            .options
242            .capacity_tokens
243            .saturating_sub(state.committed_tokens.saturating_add(state.reserved_tokens));
244        let cost_micros_usd = self.options.capacity_cost_micros_usd.saturating_sub(
245            state
246                .committed_cost_micros_usd
247                .saturating_add(state.reserved_cost_micros_usd),
248        );
249
250        Remaining {
251            tokens,
252            cost_micros_usd,
253            below_threshold: tokens <= self.options.stop_threshold_tokens
254                || cost_micros_usd <= self.options.stop_threshold_cost_micros_usd,
255        }
256    }
257}
258
259impl BudgetManager for SharedPoolBudgetManager {
260    fn remaining(&self, _extensions: &RequestExtensions) -> Remaining {
261        let state = self
262            .state
263            .lock()
264            .map_err(|_| AgentError::budget(SharedPoolBudgetError::Poisoned));
265        match state {
266            Ok(state) => self.remaining_with_state(&state),
267            Err(_) => Remaining {
268                tokens: 0,
269                cost_micros_usd: 0,
270                below_threshold: true,
271            },
272        }
273    }
274
275    fn reserve(
276        &self,
277        _extensions: &RequestExtensions,
278        estimate: &UsageEstimate,
279        request_budget: RequestBudget,
280    ) -> Result<BudgetLease, AgentError> {
281        if !request_budget.allows(estimate.total_tokens, estimate.cost_micros_usd) {
282            return Err(AgentError::budget(
283                SharedPoolBudgetError::RequestBudgetExceeded {
284                    requested_tokens: estimate.total_tokens,
285                    requested_cost_micros_usd: estimate.cost_micros_usd,
286                    budget_tokens: request_budget.tokens,
287                    budget_cost_micros_usd: request_budget.cost_micros_usd,
288                },
289            ));
290        }
291
292        let mut state = self
293            .state
294            .lock()
295            .map_err(|_| AgentError::budget(SharedPoolBudgetError::Poisoned))?;
296        let remaining = self.remaining_with_state(&state);
297
298        let remaining_after_tokens = remaining.tokens.saturating_sub(estimate.total_tokens);
299        let remaining_after_cost = remaining
300            .cost_micros_usd
301            .saturating_sub(estimate.cost_micros_usd);
302        let denied = estimate.total_tokens > remaining.tokens
303            || estimate.cost_micros_usd > remaining.cost_micros_usd
304            || remaining_after_tokens < self.options.stop_threshold_tokens
305            || remaining_after_cost < self.options.stop_threshold_cost_micros_usd;
306
307        if denied {
308            return Err(AgentError::budget(
309                SharedPoolBudgetError::ThresholdExceeded {
310                    requested_tokens: estimate.total_tokens,
311                    requested_cost_micros_usd: estimate.cost_micros_usd,
312                    remaining_tokens: remaining.tokens,
313                    remaining_cost_micros_usd: remaining.cost_micros_usd,
314                },
315            ));
316        }
317
318        let id = self.next_lease_id.fetch_add(1, Ordering::Relaxed);
319        state.reserved_tokens = state.reserved_tokens.saturating_add(estimate.total_tokens);
320        state.reserved_cost_micros_usd = state
321            .reserved_cost_micros_usd
322            .saturating_add(estimate.cost_micros_usd);
323        state.leases.insert(id, (*estimate, request_budget));
324
325        Ok(BudgetLease::new(id, *estimate, request_budget))
326    }
327
328    fn record_used(&self, lease: BudgetLease, usage: Usage) -> Result<(), AgentError> {
329        let mut state = self
330            .state
331            .lock()
332            .map_err(|_| AgentError::budget(SharedPoolBudgetError::Poisoned))?;
333        let Some((reserved, request_budget)) = state.leases.remove(&lease.id) else {
334            return Err(AgentError::budget(SharedPoolBudgetError::UnknownLease {
335                lease_id: lease.id,
336            }));
337        };
338
339        state.reserved_tokens = state.reserved_tokens.saturating_sub(reserved.total_tokens);
340        state.reserved_cost_micros_usd = state
341            .reserved_cost_micros_usd
342            .saturating_sub(reserved.cost_micros_usd);
343        state.committed_tokens = state.committed_tokens.saturating_add(usage.total_tokens);
344        state.committed_cost_micros_usd = state
345            .committed_cost_micros_usd
346            .saturating_add(usage.cost_micros_usd);
347
348        if !request_budget.allows(usage.total_tokens, usage.cost_micros_usd) {
349            return Err(AgentError::budget(
350                SharedPoolBudgetError::RequestBudgetExceeded {
351                    requested_tokens: usage.total_tokens,
352                    requested_cost_micros_usd: usage.cost_micros_usd,
353                    budget_tokens: request_budget.tokens,
354                    budget_cost_micros_usd: request_budget.cost_micros_usd,
355                },
356            ));
357        }
358
359        Ok(())
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the [`SharedPoolBudgetError`] carried by a budget
    /// [`AgentError`], panicking on any other shape.
    fn shared_pool_error(err: &AgentError) -> &SharedPoolBudgetError {
        match err {
            AgentError::Budget(source) => source
                .downcast_ref::<SharedPoolBudgetError>()
                .expect("shared pool budget error source"),
            other => panic!("expected budget error, got {other}"),
        }
    }

    #[test]
    fn shared_pool_reserves_and_refunds_difference() {
        let options = SharedPoolBudgetOptions {
            capacity_tokens: 100,
            capacity_cost_micros_usd: 1_000,
            stop_threshold_tokens: 10,
            stop_threshold_cost_micros_usd: 100,
        };
        let manager = SharedPoolBudgetManager::new(options);
        let extensions = RequestExtensions::new();

        // Reserve 20 tokens / 200 micros: the pool shrinks by the estimate.
        let estimate = UsageEstimate {
            total_tokens: 20,
            cost_micros_usd: 200,
            ..UsageEstimate::zero()
        };
        let lease = manager
            .reserve(&extensions, &estimate, RequestBudget::unlimited())
            .unwrap();
        assert_eq!(manager.remaining(&extensions).tokens, 80);

        // Settle with less than was reserved: the difference comes back.
        let actual = Usage {
            total_tokens: 12,
            cost_micros_usd: 120,
            ..Usage::zero()
        };
        manager.record_used(lease, actual).unwrap();

        let remaining = manager.remaining(&extensions);
        assert_eq!(remaining.tokens, 88);
        assert_eq!(remaining.cost_micros_usd, 880);
    }

    #[test]
    fn shared_pool_blocks_when_threshold_would_be_crossed() {
        let options = SharedPoolBudgetOptions {
            capacity_tokens: 100,
            capacity_cost_micros_usd: 1_000,
            stop_threshold_tokens: 10,
            stop_threshold_cost_micros_usd: 0,
        };
        let manager = SharedPoolBudgetManager::new(options);

        // 91 of 100 tokens would leave 9, which is under the threshold of 10.
        let estimate = UsageEstimate {
            total_tokens: 91,
            ..UsageEstimate::zero()
        };
        let err = manager
            .reserve(
                &RequestExtensions::new(),
                &estimate,
                RequestBudget::unlimited(),
            )
            .unwrap_err();

        assert!(matches!(
            shared_pool_error(&err),
            SharedPoolBudgetError::ThresholdExceeded { .. }
        ));
    }

    #[test]
    fn request_budget_can_restrict_reservation() {
        let manager = SharedPoolBudgetManager::new(SharedPoolBudgetOptions::default());

        let estimate = UsageEstimate {
            total_tokens: 32,
            ..UsageEstimate::zero()
        };
        let err = manager
            .reserve(
                &RequestExtensions::new(),
                &estimate,
                RequestBudget::from_tokens(16),
            )
            .unwrap_err();

        assert!(matches!(
            shared_pool_error(&err),
            SharedPoolBudgetError::RequestBudgetExceeded {
                requested_tokens: 32,
                budget_tokens: Some(16),
                ..
            }
        ));
    }

    #[test]
    fn request_budget_can_fail_after_actual_usage_is_higher_than_estimate() {
        let manager = SharedPoolBudgetManager::new(SharedPoolBudgetOptions::default());

        // The estimate (8) fits the request budget (10)...
        let estimate = UsageEstimate {
            total_tokens: 8,
            ..UsageEstimate::zero()
        };
        let lease = manager
            .reserve(
                &RequestExtensions::new(),
                &estimate,
                RequestBudget::from_tokens(10),
            )
            .unwrap();

        // ...but the actual usage (12) does not.
        let actual = Usage {
            total_tokens: 12,
            ..Usage::zero()
        };
        let err = manager.record_used(lease, actual).unwrap_err();

        assert!(matches!(
            shared_pool_error(&err),
            SharedPoolBudgetError::RequestBudgetExceeded {
                requested_tokens: 12,
                budget_tokens: Some(10),
                ..
            }
        ));
    }
}