heartbit_core/agent/
tenant_tracker.rs1#![allow(missing_docs)]
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use parking_lot::RwLock;
11
12use crate::auth::TenantScope;
13use crate::error::Error;
14
/// Token counters tracked for a single tenant.
///
/// Values are copied out of the tracker by [`TenantTokenTracker::snapshot`];
/// mutating a returned value does not affect the tracker.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct TenantTokenState {
    /// Tokens currently counted against the tenant: outstanding reservations
    /// plus any corrections applied through [`TenantTokenTracker::adjust`].
    pub in_flight: usize,
    /// Highest `in_flight` value observed for the tenant; it is raised by
    /// `reserve`/`adjust` and never lowered by the tracker.
    pub high_water: usize,
}
26
27pub struct TenantTokenTracker {
36 states: RwLock<HashMap<String, TenantTokenState>>,
37 per_tenant_cap: usize,
38}
39
40pub struct TokenReservation {
47 tracker: Arc<TenantTokenTracker>,
48 tenant_id: String,
49 tokens: usize,
50}
51
52impl std::fmt::Debug for TokenReservation {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("TokenReservation")
55 .field("tenant_id", &self.tenant_id)
56 .field("tokens", &self.tokens)
57 .finish()
58 }
59}
60
61impl Drop for TokenReservation {
62 fn drop(&mut self) {
63 self.tracker.release(&self.tenant_id, self.tokens);
64 }
65}
66
67impl TenantTokenTracker {
68 pub fn new(per_tenant_cap: usize) -> Self {
69 Self {
70 states: RwLock::new(HashMap::new()),
71 per_tenant_cap,
72 }
73 }
74
75 pub fn reserve(
76 self: &Arc<Self>,
77 scope: &TenantScope,
78 tokens: usize,
79 ) -> Result<TokenReservation, Error> {
80 let tenant = scope.tenant_id.clone();
81 let mut guard = self.states.write();
82 let state = guard.entry(tenant.clone()).or_default();
83 if state.in_flight.saturating_add(tokens) > self.per_tenant_cap {
84 return Err(Error::TenantOverloaded {
85 tenant_id: tenant,
86 in_flight: state.in_flight,
87 cap: self.per_tenant_cap,
88 });
89 }
90 state.in_flight += tokens;
91 if state.in_flight > state.high_water {
92 state.high_water = state.in_flight;
93 }
94 Ok(TokenReservation {
95 tracker: Arc::clone(self),
96 tenant_id: tenant,
97 tokens,
98 })
99 }
100
101 pub fn adjust(&self, scope: &TenantScope, delta: i64) {
115 let mut guard = self.states.write();
116 let Some(state) = guard.get_mut(&scope.tenant_id) else {
117 return;
118 };
119 if delta >= 0 {
120 state.in_flight = state
121 .in_flight
122 .saturating_add(delta as usize)
123 .min(self.per_tenant_cap);
124 } else {
125 state.in_flight = state
128 .in_flight
129 .saturating_sub(delta.unsigned_abs() as usize);
130 }
131 if state.in_flight > state.high_water {
132 state.high_water = state.in_flight;
133 }
134 }
135
136 fn release(&self, tenant_id: &str, tokens: usize) {
137 let mut guard = self.states.write();
138 if let Some(state) = guard.get_mut(tenant_id) {
139 state.in_flight = state.in_flight.saturating_sub(tokens);
140 }
141 }
142
143 pub fn snapshot(&self) -> Vec<(String, TenantTokenState)> {
144 self.states
145 .read()
146 .iter()
147 .map(|(k, v)| (k.clone(), v.clone()))
148 .collect()
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    fn scope(t: &str) -> TenantScope {
        TenantScope::new(t)
    }

    #[test]
    fn reserve_within_cap_succeeds() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        let r = t.reserve(&scope("a"), 500).unwrap();
        let snap = t.snapshot();
        assert_eq!(snap.len(), 1);
        assert_eq!(snap[0].1.in_flight, 500);
        drop(r);
    }

    #[test]
    fn reserve_exactly_at_cap_succeeds_and_one_more_fails() {
        let t = Arc::new(TenantTokenTracker::new(100));
        let _r = t.reserve(&scope("a"), 100).unwrap();
        assert!(matches!(
            t.reserve(&scope("a"), 1),
            Err(Error::TenantOverloaded { .. })
        ));
    }

    #[test]
    fn reserve_exceeding_cap_returns_tenant_overloaded() {
        let t = Arc::new(TenantTokenTracker::new(100));
        let _r = t.reserve(&scope("a"), 80).unwrap();
        let err = t.reserve(&scope("a"), 50).unwrap_err();
        match err {
            Error::TenantOverloaded {
                tenant_id,
                in_flight,
                cap,
            } => {
                assert_eq!(tenant_id, "a");
                assert_eq!(in_flight, 80);
                assert_eq!(cap, 100);
            }
            other => panic!("expected TenantOverloaded, got {other:?}"),
        }
    }

    #[test]
    fn failed_reserve_leaves_counters_unchanged() {
        let t = Arc::new(TenantTokenTracker::new(100));
        let _r = t.reserve(&scope("a"), 80).unwrap();
        assert!(t.reserve(&scope("a"), 50).is_err());
        let snap = t.snapshot();
        assert_eq!(snap[0].1.in_flight, 80);
        assert_eq!(snap[0].1.high_water, 80);
    }

    #[test]
    fn drop_releases_reservation() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        {
            let _r = t.reserve(&scope("a"), 700).unwrap();
            assert_eq!(t.snapshot()[0].1.in_flight, 700);
        }
        assert_eq!(t.snapshot()[0].1.in_flight, 0);
    }

    #[test]
    fn tenants_are_isolated() {
        let t = Arc::new(TenantTokenTracker::new(100));
        let _ra = t.reserve(&scope("a"), 90).unwrap();
        let _rb = t.reserve(&scope("b"), 90).unwrap();
        let snap: HashMap<_, _> = t.snapshot().into_iter().collect();
        assert_eq!(snap["a"].in_flight, 90);
        assert_eq!(snap["b"].in_flight, 90);
    }

    #[test]
    fn high_water_tracks_peak() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        let r1 = t.reserve(&scope("a"), 400).unwrap();
        let r2 = t.reserve(&scope("a"), 300).unwrap();
        drop(r1);
        let snap = t.snapshot();
        assert_eq!(snap[0].1.in_flight, 300);
        assert_eq!(snap[0].1.high_water, 700);
        drop(r2);
    }

    #[test]
    fn adjust_positive_delta_clamps_at_cap() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        let _r = t.reserve(&scope("a"), 500).unwrap();
        t.adjust(&scope("a"), 800);
        let snap = t.snapshot();
        assert_eq!(snap[0].1.in_flight, 1000);
        assert_eq!(snap[0].1.high_water, 1000);
    }

    #[test]
    fn adjust_positive_i64_max_clamps_at_cap() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        let _r = t.reserve(&scope("a"), 500).unwrap();
        t.adjust(&scope("a"), i64::MAX);
        assert_eq!(t.snapshot()[0].1.in_flight, 1000);
    }

    #[test]
    fn adjust_negative_delta_decrements() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        let _r = t.reserve(&scope("a"), 500).unwrap();
        t.adjust(&scope("a"), -200);
        assert_eq!(t.snapshot()[0].1.in_flight, 300);
    }

    #[test]
    fn adjust_negative_i64_min_does_not_panic() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        let _r = t.reserve(&scope("a"), 500).unwrap();
        t.adjust(&scope("a"), i64::MIN);
        assert_eq!(t.snapshot()[0].1.in_flight, 0);
    }

    #[test]
    fn adjust_is_independent_of_reservation_drop() {
        // Dropping a reservation returns only what it reserved; a prior positive
        // adjustment stays counted.
        let t = Arc::new(TenantTokenTracker::new(1000));
        let r = t.reserve(&scope("a"), 500).unwrap();
        t.adjust(&scope("a"), 200);
        drop(r);
        assert_eq!(t.snapshot()[0].1.in_flight, 200);
    }

    #[tokio::test]
    async fn reservation_owns_arc_and_outlives_borrow() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        let r = t.reserve(&scope("a"), 500).unwrap();
        let handle: tokio::task::JoinHandle<()> = tokio::task::spawn_blocking(move || {
            drop(r);
        });
        handle.await.unwrap();
        assert_eq!(t.snapshot()[0].1.in_flight, 0);
    }

    #[test]
    fn default_scope_uses_empty_string_bucket() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        let _r = t.reserve(&TenantScope::default(), 500).unwrap();
        let snap = t.snapshot();
        assert_eq!(snap.len(), 1);
        assert_eq!(snap[0].0, "");
    }

    #[test]
    fn adjust_on_unknown_tenant_is_noop() {
        let t = Arc::new(TenantTokenTracker::new(1000));
        t.adjust(&scope("unknown"), -100);
        assert!(t.snapshot().is_empty());
    }
}