Skip to main content

heartbit_core/agent/
tenant_tracker.rs

1//! Per-tenant in-flight token tracker with Arc-owning RAII reservations.
2//!
3//! See `docs/superpowers/specs/2026-05-02-b5b-failure-mode-hardening-design.md`
4//! Component 2 for design rationale.
5
6#![allow(missing_docs)]
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use parking_lot::RwLock;
11
12use crate::auth::TenantScope;
13use crate::error::Error;
14
/// Point-in-time token accounting for one tenant.
///
/// `TenantTokenTracker::snapshot` hands out clones of this value. The live
/// copies inside the tracker are only mutated while its write lock is held, so
/// `in_flight` and `high_water` are always updated together.
#[derive(Debug, Default, Clone)]
pub struct TenantTokenState {
    /// Tokens currently reserved across all of the tenant's active requests.
    pub in_flight: usize,
    /// Largest `in_flight` value ever observed for the tenant. Useful for
    /// capacity planning and alerting.
    pub high_water: usize,
}
26
27/// Registry that enforces a per-tenant in-flight token cap across concurrent agent runs.
28///
29/// Each call to `reserve` atomically increments the tenant's `in_flight` counter
30/// and returns a `TokenReservation`; the counter is decremented automatically when
31/// the reservation is dropped. If the requested tokens would exceed `per_tenant_cap`,
32/// `reserve` returns `Error::TenantOverloaded` rather than blocking, enabling
33/// load-shedding at the orchestration layer. Pass an `Arc<TenantTokenTracker>` to
34/// `OrchestratorBuilder::tenant_tracker` so sub-agents also participate in the cap.
35pub struct TenantTokenTracker {
36    states: RwLock<HashMap<String, TenantTokenState>>,
37    per_tenant_cap: usize,
38}
39
40/// RAII guard that holds a token reservation for a tenant.
41///
42/// Created by `TenantTokenTracker::reserve` and automatically releases the
43/// reserved token count back to the tracker when dropped, whether the request
44/// succeeds, fails, or is cancelled. Callers should hold the reservation for
45/// the duration of the LLM call to ensure the in-flight counter stays accurate.
46pub struct TokenReservation {
47    tracker: Arc<TenantTokenTracker>,
48    tenant_id: String,
49    tokens: usize,
50}
51
52impl std::fmt::Debug for TokenReservation {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("TokenReservation")
55            .field("tenant_id", &self.tenant_id)
56            .field("tokens", &self.tokens)
57            .finish()
58    }
59}
60
61impl Drop for TokenReservation {
62    fn drop(&mut self) {
63        self.tracker.release(&self.tenant_id, self.tokens);
64    }
65}
66
67impl TenantTokenTracker {
68    pub fn new(per_tenant_cap: usize) -> Self {
69        Self {
70            states: RwLock::new(HashMap::new()),
71            per_tenant_cap,
72        }
73    }
74
75    pub fn reserve(
76        self: &Arc<Self>,
77        scope: &TenantScope,
78        tokens: usize,
79    ) -> Result<TokenReservation, Error> {
80        let tenant = scope.tenant_id.clone();
81        let mut guard = self.states.write();
82        let state = guard.entry(tenant.clone()).or_default();
83        if state.in_flight.saturating_add(tokens) > self.per_tenant_cap {
84            return Err(Error::TenantOverloaded {
85                tenant_id: tenant,
86                in_flight: state.in_flight,
87                cap: self.per_tenant_cap,
88            });
89        }
90        state.in_flight += tokens;
91        if state.in_flight > state.high_water {
92            state.high_water = state.in_flight;
93        }
94        Ok(TokenReservation {
95            tracker: Arc::clone(self),
96            tenant_id: tenant,
97            tokens,
98        })
99    }
100
101    /// Adjust the in-flight token count for `scope` by `delta` (signed).
102    ///
103    /// Used by `AgentRunner` after each LLM response to reconcile the
104    /// initial estimate with actual usage. Best-effort:
105    /// - Positive deltas are silently **clamped at `per_tenant_cap`** to
106    ///   avoid the per-turn caller having to handle a "we already accepted
107    ///   this work" error mid-task. The clamp is intentional accounting
108    ///   drift: if a tenant overshoots its cap during execution, the
109    ///   tracker reports `in_flight == cap` rather than the true usage.
110    ///   Subsequent `reserve()` calls will see the clamped value and
111    ///   gate accordingly.
112    /// - Negative deltas saturate at 0.
113    /// - No-op on poisoned lock (logged via `tracing::warn!`) or unknown tenant.
114    pub fn adjust(&self, scope: &TenantScope, delta: i64) {
115        let mut guard = self.states.write();
116        let Some(state) = guard.get_mut(&scope.tenant_id) else {
117            return;
118        };
119        if delta >= 0 {
120            state.in_flight = state
121                .in_flight
122                .saturating_add(delta as usize)
123                .min(self.per_tenant_cap);
124        } else {
125            // `i64::unsigned_abs()` returns `u64` and handles `i64::MIN`
126            // correctly; `-i64::MIN` would otherwise overflow.
127            state.in_flight = state
128                .in_flight
129                .saturating_sub(delta.unsigned_abs() as usize);
130        }
131        if state.in_flight > state.high_water {
132            state.high_water = state.in_flight;
133        }
134    }
135
136    fn release(&self, tenant_id: &str, tokens: usize) {
137        let mut guard = self.states.write();
138        if let Some(state) = guard.get_mut(tenant_id) {
139            state.in_flight = state.in_flight.saturating_sub(tokens);
140        }
141    }
142
143    pub fn snapshot(&self) -> Vec<(String, TenantTokenState)> {
144        self.states
145            .read()
146            .iter()
147            .map(|(k, v)| (k.clone(), v.clone()))
148            .collect()
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    fn scope(t: &str) -> TenantScope {
        TenantScope::new(t)
    }

    #[test]
    fn reserve_within_cap_succeeds() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        let r = t.reserve(&scope("a"), 500).unwrap();
        let snap = t.snapshot();
        assert_eq!(snap.len(), 1);
        assert_eq!(snap[0].1.in_flight, 500);
        drop(r);
    }

    #[test]
    fn reserve_exactly_at_cap_succeeds_and_release_frees_room() {
        let t = Arc::new(TenantTokenTracker::new(100));
        let r = t.reserve(&scope("a"), 100).unwrap();
        assert!(t.reserve(&scope("a"), 1).is_err());
        drop(r);
        let _r2 = t.reserve(&scope("a"), 100).unwrap();
        assert_eq!(t.snapshot()[0].1.in_flight, 100);
    }

    #[test]
    fn reserve_exceeding_cap_returns_tenant_overloaded() {
        let t = Arc::new(TenantTokenTracker::new(100));
        let _r = t.reserve(&scope("a"), 80).unwrap();
        let err = t.reserve(&scope("a"), 50).unwrap_err();
        match err {
            Error::TenantOverloaded {
                tenant_id,
                in_flight,
                cap,
            } => {
                assert_eq!(tenant_id, "a");
                assert_eq!(in_flight, 80);
                assert_eq!(cap, 100);
            }
            other => panic!("expected TenantOverloaded, got {other:?}"),
        }
    }

    #[test]
    fn rejected_reservation_leaves_counters_unchanged() {
        let t = Arc::new(TenantTokenTracker::new(100));
        let _r = t.reserve(&scope("a"), 60).unwrap();
        assert!(t.reserve(&scope("a"), 41).is_err());
        let snap = t.snapshot();
        assert_eq!(snap[0].1.in_flight, 60);
        assert_eq!(snap[0].1.high_water, 60);
    }

    #[test]
    fn reserve_with_unbounded_cap_does_not_overflow() {
        let t = Arc::new(TenantTokenTracker::new(usize::MAX));
        let _r = t.reserve(&scope("a"), usize::MAX).unwrap();
        // `usize::MAX + 1` would overflow: must be rejected, not panic/wrap.
        assert!(t.reserve(&scope("a"), 1).is_err());
        assert_eq!(t.snapshot()[0].1.in_flight, usize::MAX);
    }

    #[test]
    fn drop_releases_reservation() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        {
            let _r = t.reserve(&scope("a"), 700).unwrap();
            assert_eq!(t.snapshot()[0].1.in_flight, 700);
        }
        assert_eq!(t.snapshot()[0].1.in_flight, 0);
    }

    #[test]
    fn tenants_are_isolated() {
        let t = Arc::new(TenantTokenTracker::new(100));
        let _ra = t.reserve(&scope("a"), 90).unwrap();
        let _rb = t.reserve(&scope("b"), 90).unwrap();
        let snap: HashMap<_, _> = t.snapshot().into_iter().collect();
        assert_eq!(snap["a"].in_flight, 90);
        assert_eq!(snap["b"].in_flight, 90);
    }

    #[test]
    fn high_water_tracks_peak() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        let r1 = t.reserve(&scope("a"), 400).unwrap();
        let r2 = t.reserve(&scope("a"), 300).unwrap();
        drop(r1);
        let snap = t.snapshot();
        assert_eq!(snap[0].1.in_flight, 300);
        assert_eq!(snap[0].1.high_water, 700);
        drop(r2);
    }

    #[test]
    fn adjust_positive_delta_clamps_at_cap() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        let _r = t.reserve(&scope("a"), 500).unwrap();
        t.adjust(&scope("a"), 800);
        let snap = t.snapshot();
        assert_eq!(snap[0].1.in_flight, 1000); // clamped
        assert_eq!(snap[0].1.high_water, 1000);
    }

    #[test]
    fn adjust_negative_delta_decrements() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        let _r = t.reserve(&scope("a"), 500).unwrap();
        t.adjust(&scope("a"), -200);
        assert_eq!(t.snapshot()[0].1.in_flight, 300);
    }

    #[test]
    fn adjust_negative_i64_min_does_not_panic() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        let _r = t.reserve(&scope("a"), 500).unwrap();
        t.adjust(&scope("a"), i64::MIN); // must not panic in debug
        assert_eq!(t.snapshot()[0].1.in_flight, 0);
    }

    #[test]
    fn adjust_then_reverse_before_drop_returns_to_zero() {
        // A reservation releases only the tokens it was created with, so the
        // caller must undo its adjustments before dropping the guard.
        let t = Arc::new(TenantTokenTracker::new(1000));
        let r = t.reserve(&scope("a"), 500).unwrap();
        t.adjust(&scope("a"), 300);
        t.adjust(&scope("a"), -300);
        drop(r);
        let snap = t.snapshot();
        assert_eq!(snap[0].1.in_flight, 0);
        assert_eq!(snap[0].1.high_water, 800);
    }

    #[tokio::test]
    async fn reservation_owns_arc_and_outlives_borrow() {
        // Moving the guard into `spawn_blocking` requires `Send + 'static`.
        // This compiles only because the reservation owns an `Arc` to the
        // tracker instead of borrowing it.
        let t = Arc::new(TenantTokenTracker::new(1000));
        let r = t.reserve(&scope("a"), 500).unwrap();
        let handle: tokio::task::JoinHandle<()> = tokio::task::spawn_blocking(move || {
            drop(r);
        });
        handle.await.unwrap();
        // The drop on the blocking thread must have released the tokens.
        assert_eq!(t.snapshot()[0].1.in_flight, 0);
    }

    #[test]
    fn default_scope_uses_empty_string_bucket() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        let _r = t.reserve(&TenantScope::default(), 500).unwrap();
        let snap = t.snapshot();
        assert_eq!(snap.len(), 1);
        assert_eq!(snap[0].0, ""); // empty-string sentinel
    }

    #[test]
    fn adjust_on_unknown_tenant_is_noop() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        t.adjust(&scope("unknown"), -100);
        assert!(t.snapshot().is_empty());
    }
}