1use std::{
4 collections::HashMap,
5 sync::{
6 Arc,
7 atomic::{AtomicBool, Ordering},
8 },
9 time::{Duration as StdDuration, Instant},
10};
11
12use anyhow::{Context, Result, anyhow};
13use async_nats::jetstream::{
14 Context as JsContext,
15 context::KeyValueErrorKind,
16 kv::{self, CreateErrorKind, UpdateErrorKind},
17};
18use async_trait::async_trait;
19use gsm_telemetry::{TelemetryLabels, record_gauge};
20use serde::{Deserialize, Serialize};
21use time::{Duration, OffsetDateTime, serde::rfc3339};
22use tokio::sync::Mutex;
23use tracing::{Level, event, instrument, warn};
24
25const TOKEN: f64 = 1.0;
27const TICK_MS: i64 = 100;
28
29fn compute_wait_secs(limit: RateLimit, tokens: f64) -> f64 {
30 let missing = (TOKEN - tokens).max(0.0);
31 missing / limit.rps.max(0.1)
32}
33
34fn record_backpressure_tokens(tenant: &str, tokens: f64) {
35 let labels = TelemetryLabels {
36 tenant: tenant.to_string(),
37 platform: None,
38 chat_id: None,
39 msg_id: None,
40 extra: Vec::new(),
41 };
42 record_gauge("backpressure_tokens", tokens.round() as i64, &labels);
43}
/// Token-bucket parameters for a single tenant.
#[derive(Debug, Clone, Copy)]
pub struct RateLimit {
    /// Sustained refill rate, in tokens (requests) per second.
    pub rps: f64,
    /// Maximum number of tokens the bucket can hold (burst capacity).
    pub burst: f64,
}
49
50impl Default for RateLimit {
51 fn default() -> Self {
52 Self {
53 rps: 5.0,
54 burst: 10.0,
55 }
56 }
57}
58
/// Per-tenant rate-limit table with a fallback default limit.
#[derive(Clone, Default)]
pub struct RateLimits {
    /// Fallback limit for tenants without an explicit override.
    default: RateLimit,
    /// Per-tenant overrides, keyed by tenant id.
    tenants: HashMap<String, RateLimit>,
}
64
65impl RateLimits {
66 pub fn new(default: RateLimit, tenants: HashMap<String, RateLimit>) -> Self {
67 let normalize = |limit: RateLimit| RateLimit {
68 rps: limit.rps.max(0.1),
69 burst: limit.burst.max(1.0),
70 };
71 let default = normalize(default);
72 let tenants = tenants
73 .into_iter()
74 .map(|(tenant, limit)| (tenant, normalize(limit)))
75 .collect();
76 Self { default, tenants }
77 }
78
79 pub fn with_tenants(tenants: HashMap<String, RateLimit>) -> Self {
80 Self::new(RateLimit::default(), tenants)
81 }
82
83 pub fn get(&self, tenant: &str) -> RateLimit {
84 self.tenants.get(tenant).copied().unwrap_or(self.default)
85 }
86}
87
/// A limiter that blocks the caller until it may proceed.
#[async_trait]
pub trait BackpressureLimiter: Send + Sync {
    /// Waits until a token is available for `tenant`, then consumes it and
    /// returns a [`Permit`]. Errors are implementation-specific (e.g. a
    /// remote store failure).
    async fn acquire(&self, tenant: &str) -> Result<Permit>;
}
92
/// Marker returned once a token has been consumed; currently zero-sized.
#[derive(Debug)]
pub struct Permit;

impl Permit {
    // Private constructor: permits are only minted by limiters in this
    // module.
    fn new() -> Self {
        Self
    }
}
101
/// In-process token-bucket limiter; all state lives in a mutex-guarded map.
#[derive(Clone)]
pub struct LocalBackpressureLimiter {
    /// Shared per-tenant limit configuration.
    limits: Arc<RateLimits>,
    /// One bucket per tenant, created lazily on first acquire.
    buckets: Arc<Mutex<HashMap<String, LocalBucket>>>,
}
107
/// Mutable bucket state for a single tenant.
#[derive(Debug)]
struct LocalBucket {
    /// Tokens currently available (fractional between refill ticks).
    tokens: f64,
    /// Instant up to which refills have already been accounted for.
    last_refill: Instant,
}
113
114impl LocalBackpressureLimiter {
115 pub fn new(limits: Arc<RateLimits>) -> Self {
116 Self {
117 limits,
118 buckets: Arc::new(Mutex::new(HashMap::new())),
119 }
120 }
121
122 fn refill(tokens: f64, elapsed: StdDuration, limit: RateLimit) -> (f64, StdDuration) {
123 if elapsed.is_zero() {
124 return (tokens, StdDuration::from_millis(0));
125 }
126 let ticks = (elapsed.as_millis() as i64) / TICK_MS;
127 if ticks <= 0 {
128 return (tokens, StdDuration::from_millis(0));
129 }
130 let refill = (ticks as f64) * (limit.rps * (TICK_MS as f64 / 1000.0));
131 let tokens = (tokens + refill).min(limit.burst);
132 let consumed = StdDuration::from_millis((ticks * TICK_MS) as u64);
133 (tokens, consumed)
134 }
135}
136
137#[async_trait]
138impl BackpressureLimiter for LocalBackpressureLimiter {
139 async fn acquire(&self, tenant: &str) -> Result<Permit> {
140 let tenant_key = tenant.to_string();
141 loop {
142 let limit = self.limits.get(tenant);
143 let mut guard = self.buckets.lock().await;
144 let bucket = guard.entry(tenant_key.clone()).or_insert(LocalBucket {
145 tokens: limit.burst,
146 last_refill: Instant::now(),
147 });
148 let now = Instant::now();
149 let elapsed = now.saturating_duration_since(bucket.last_refill);
150 let (filled, consumed) = Self::refill(bucket.tokens, elapsed, limit);
151 if consumed > StdDuration::from_millis(0) {
152 bucket.last_refill += consumed;
153 bucket.tokens = filled;
154 }
155 if bucket.tokens >= TOKEN {
156 bucket.tokens -= TOKEN;
157 record_backpressure_tokens(&tenant_key, bucket.tokens);
158 drop(guard);
159 return Ok(Permit::new());
160 }
161 let wait_secs = compute_wait_secs(limit, bucket.tokens);
162 if wait_secs > 1.0 {
163 event!(
164 Level::INFO,
165 tenant = %tenant_key,
166 wait_secs,
167 "backpressure.waiting_for_tokens"
168 );
169 }
170 drop(guard);
171 tokio::time::sleep(StdDuration::from_secs_f64(wait_secs.max(0.1))).await;
172 }
173 }
174}
175
/// In-memory view of a remote bucket after decode + refill.
struct RemoteBucketState {
    /// Tokens currently available.
    tokens: f64,
    /// Wall-clock time up to which refills have been accounted for.
    last_refill: OffsetDateTime,
}
180
/// Wire format stored in the JetStream KV entry (serialized as JSON).
#[derive(Debug, Serialize, Deserialize)]
struct RemoteBucketPersisted {
    tokens: f64,
    /// Persisted as an RFC 3339 timestamp.
    #[serde(with = "rfc3339")]
    last_refill_ts: OffsetDateTime,
}

impl RemoteBucketPersisted {
    /// Snapshot of the bucket to persist after a successful token spend.
    fn new(tokens: f64, now: OffsetDateTime) -> Self {
        Self {
            tokens,
            last_refill_ts: now,
        }
    }
}
196
/// Distributed limiter backed by a JetStream key-value bucket; instances in
/// different processes share token state through revision-guarded writes.
pub struct JetStreamBackpressureLimiter {
    /// Shared per-tenant limit configuration.
    limits: Arc<RateLimits>,
    /// KV store holding one `rate/<tenant>` entry per tenant.
    bucket: kv::Store,
    /// Bucket name, kept for tracing/log context.
    namespace: String,
}
202
impl JetStreamBackpressureLimiter {
    /// Opens the KV bucket named `namespace`, creating it if it does not
    /// exist yet, and wraps it in a limiter sharing `limits`.
    ///
    /// # Errors
    /// Fails when the bucket can neither be fetched nor created.
    pub async fn new(js: &JsContext, namespace: &str, limits: Arc<RateLimits>) -> Result<Self> {
        let bucket = match js.get_key_value(namespace).await {
            Ok(store) => store,
            // Bucket missing: create it on first use.
            Err(err) if err.kind() == KeyValueErrorKind::GetBucket => js
                .create_key_value(kv::Config {
                    bucket: namespace.to_string(),
                    description: "backpressure rate limiter".into(),
                    // Only the latest revision is needed for CAS updates.
                    history: 1,
                    // NOTE(review): 0 presumably means "no age-based
                    // expiry" — confirm against async_nats kv::Config docs.
                    max_age: StdDuration::from_secs(0),
                    ..Default::default()
                })
                .await
                .with_context(|| format!("create JetStream KV bucket {namespace}"))?,
            Err(err) => return Err(anyhow!(err).context("initializing backpressure bucket")),
        };
        Ok(Self {
            limits,
            bucket,
            namespace: namespace.to_string(),
        })
    }

    /// Decodes a KV entry into bucket state. A missing or unparsable entry
    /// falls back to a full bucket (tokens = burst) starting now.
    fn parse_state(&self, entry: Option<kv::Entry>, limit: RateLimit) -> RemoteBucketState {
        let now = OffsetDateTime::now_utc();
        match entry {
            Some(e) => serde_json::from_slice::<RemoteBucketPersisted>(e.value.as_ref())
                .map(|persisted| RemoteBucketState {
                    // Clamp in case the configured burst was lowered since
                    // this state was persisted.
                    tokens: persisted.tokens.min(limit.burst),
                    last_refill: persisted.last_refill_ts,
                })
                .unwrap_or(RemoteBucketState {
                    tokens: limit.burst,
                    last_refill: now,
                }),
            None => RemoteBucketState {
                tokens: limit.burst,
                last_refill: now,
            },
        }
    }

    /// Credits tokens for the whole ticks elapsed between
    /// `state.last_refill` and `now`, capped at the burst size, and
    /// advances `last_refill` by exactly the ticks consumed so partial
    /// ticks keep accruing.
    fn refill_tokens(
        mut state: RemoteBucketState,
        limit: RateLimit,
        now: OffsetDateTime,
    ) -> RemoteBucketState {
        // Clock-skew guard: a timestamp at/after `now` yields no refill.
        if now <= state.last_refill {
            return state;
        }
        let elapsed_ms = (now - state.last_refill).whole_milliseconds();
        let ticks = (elapsed_ms / i128::from(TICK_MS)) as i64;
        if ticks <= 0 {
            return state;
        }
        let refill = (ticks as f64) * (limit.rps * (TICK_MS as f64 / 1000.0));
        state.tokens = (state.tokens + refill).min(limit.burst);
        state.last_refill += Duration::milliseconds(ticks * TICK_MS);
        state
    }

    /// Sleeps before the next acquisition attempt (at least 100 ms).
    async fn wait_for_tokens(wait_secs: f64) {
        tokio::time::sleep(StdDuration::from_secs_f64(wait_secs.max(0.1))).await;
    }
}
268
269#[async_trait]
270impl BackpressureLimiter for JetStreamBackpressureLimiter {
271 #[instrument(name = "backpressure.remote.acquire", skip(self), fields(namespace = %self.namespace, tenant))]
272 async fn acquire(&self, tenant: &str) -> Result<Permit> {
273 let tenant_key = tenant.to_string();
274 let limit = self.limits.get(tenant);
275 let key = format!("rate/{tenant}");
276 let mut retries = 0usize;
277
278 loop {
279 let entry = self
280 .bucket
281 .entry(key.as_str())
282 .await
283 .with_context(|| format!("load rate state for {tenant}"))?;
284 let now = OffsetDateTime::now_utc();
285 let mut state = self.parse_state(entry.clone(), limit);
286 state = Self::refill_tokens(state, limit, now);
287 if state.tokens < TOKEN {
288 let wait_secs = compute_wait_secs(limit, state.tokens);
289 if wait_secs > 1.0 {
290 event!(
291 Level::INFO,
292 tenant = %tenant_key,
293 wait_secs,
294 namespace = %self.namespace,
295 "backpressure.waiting_for_tokens"
296 );
297 }
298 Self::wait_for_tokens(wait_secs).await;
299 continue;
300 }
301 state.tokens -= TOKEN;
302 record_backpressure_tokens(&tenant_key, state.tokens);
303 let persisted = RemoteBucketPersisted::new(state.tokens, state.last_refill);
304 let payload = serde_json::to_vec(&persisted)?;
305 match &entry {
306 Some(e) => match self
307 .bucket
308 .update(key.as_str(), payload.clone().into(), e.revision)
309 .await
310 {
311 Ok(_) => return Ok(Permit::new()),
312 Err(err) if err.kind() == UpdateErrorKind::WrongLastRevision => {
313 retries += 1;
314 if retries > 3 {
315 event!(
316 Level::WARN,
317 tenant = %tenant_key,
318 retries,
319 "egress.acquire_permit.cas_retry"
320 );
321 }
322 continue;
323 }
324 Err(err) => {
325 return Err(
326 anyhow!(err).context(format!("update remote rate state {tenant}"))
327 );
328 }
329 },
330 None => match self
331 .bucket
332 .create(key.as_str(), payload.clone().into())
333 .await
334 {
335 Ok(_) => return Ok(Permit::new()),
336 Err(err) if err.kind() == CreateErrorKind::AlreadyExists => {
337 retries += 1;
338 continue;
339 }
340 Err(err) => {
341 return Err(
342 anyhow!(err).context(format!("create remote rate state {tenant}"))
343 );
344 }
345 },
346 }
347 }
348 }
349}
350
/// Limiter that prefers the remote (JetStream) implementation when it is
/// available and falls back to the in-process limiter otherwise.
pub struct HybridLimiter {
    /// Remote limiter; `None` when no JetStream context was supplied or the
    /// KV bucket could not be initialized.
    remote: Option<JetStreamBackpressureLimiter>,
    /// Always-available in-process fallback.
    local: LocalBackpressureLimiter,
    /// Latch so the remote-failure warning is logged once per outage.
    remote_failed: AtomicBool,
}
356
357impl HybridLimiter {
358 pub async fn new(js: Option<&JsContext>) -> Result<Arc<Self>> {
359 Self::new_with_config(js, BackpressureConfig::default()).await
360 }
361
362 pub async fn new_with_config(
363 js: Option<&JsContext>,
364 config: BackpressureConfig,
365 ) -> Result<Arc<Self>> {
366 let limits = Arc::clone(&config.limits);
367 let namespace = config.namespace;
368 let remote = match js {
369 Some(ctx) => {
370 match JetStreamBackpressureLimiter::new(ctx, &namespace, limits.clone()).await {
371 Ok(limiter) => Some(limiter),
372 Err(err) => {
373 warn!(error = %err, "remote backpressure store unavailable, falling back to local limiter");
374 None
375 }
376 }
377 }
378 None => None,
379 };
380
381 let local = LocalBackpressureLimiter::new(limits);
382 Ok(Arc::new(Self {
383 remote,
384 local,
385 remote_failed: AtomicBool::new(false),
386 }))
387 }
388}
389
/// Configuration for building a [`HybridLimiter`].
#[derive(Clone)]
pub struct BackpressureConfig {
    /// JetStream KV bucket name used for shared rate state.
    pub namespace: String,
    /// Shared per-tenant rate-limit table.
    pub limits: Arc<RateLimits>,
}
395
396impl BackpressureConfig {
397 pub fn new(namespace: impl Into<String>, limits: Arc<RateLimits>) -> Self {
398 Self {
399 namespace: namespace.into(),
400 limits,
401 }
402 }
403}
404
405impl Default for BackpressureConfig {
406 fn default() -> Self {
407 Self {
408 namespace: "rate-limits".to_string(),
409 limits: Arc::new(RateLimits::default()),
410 }
411 }
412}
413
#[async_trait]
impl BackpressureLimiter for HybridLimiter {
    /// Tries the remote limiter first when one is configured; on any remote
    /// error this call falls through to the in-process limiter.
    async fn acquire(&self, tenant: &str) -> Result<Permit> {
        if let Some(remote) = &self.remote {
            match remote.acquire(tenant).await {
                Ok(permit) => {
                    // Remote is healthy again: clear the failure latch so a
                    // future outage is logged anew.
                    self.remote_failed.store(false, Ordering::Release);
                    return Ok(permit);
                }
                Err(err) => {
                    // `swap` returns the previous value, so the warning is
                    // emitted only on the first failure after a success —
                    // not on every acquire during an outage.
                    if !self.remote_failed.swap(true, Ordering::AcqRel) {
                        warn!(error = %err, "remote limiter failed, switching to local fallback");
                    }
                }
            }
        }
        self.local.acquire(tenant).await
    }
}
433
#[cfg(test)]
mod tests {
    use super::*;

    // A burst of 2 lets two immediate acquires succeed without sleeping.
    #[tokio::test]
    async fn local_refills() {
        let limits = Arc::new(RateLimits {
            default: RateLimit {
                rps: 10.0,
                burst: 2.0,
            },
            tenants: HashMap::new(),
        });
        let limiter = LocalBackpressureLimiter::new(limits);
        let _ = limiter.acquire("t").await.unwrap();
        let _ = limiter.acquire("t").await.unwrap();
    }

    #[test]
    fn compute_wait_secs_reflects_missing_tokens() {
        let limit = RateLimit {
            rps: 2.0,
            burst: 1.0,
        };
        // One full token missing at 2 rps -> half a second of waiting.
        let wait = compute_wait_secs(limit, 0.0);
        assert!((wait - 0.5).abs() < 1e-6);

        // Surplus tokens -> zero wait.
        let instant = compute_wait_secs(limit, 2.0);
        assert_eq!(instant, 0.0);
    }

    #[test]
    fn refill_respects_burst_and_elapsed_time() {
        let limit = RateLimit {
            rps: 1.0,
            burst: 2.0,
        };
        // Zero elapsed time: nothing is credited, nothing consumed.
        let (tokens, consumed) =
            LocalBackpressureLimiter::refill(0.5, StdDuration::from_millis(0), limit);
        assert_eq!(tokens, 0.5);
        assert_eq!(consumed, StdDuration::from_millis(0));

        // Two seconds at 1 rps would add 2.0 tokens, but the burst cap
        // (2.0) clamps the total; all two seconds are consumed.
        let (tokens, consumed) =
            LocalBackpressureLimiter::refill(0.5, StdDuration::from_secs(2), limit);
        assert_eq!(tokens, 2.0);
        assert_eq!(consumed, StdDuration::from_secs(2));
    }

    #[test]
    fn rate_limits_enforce_minimums() {
        let mut tenants = HashMap::new();
        tenants.insert(
            "tenant".to_string(),
            RateLimit {
                rps: 0.0,
                burst: 0.0,
            },
        );
        let limits = RateLimits::with_tenants(tenants);
        let cfg = limits.get("tenant");
        // Zeroed limits are clamped to the normalization floors.
        assert_eq!(cfg.rps, 0.1);
        assert_eq!(cfg.burst, 1.0);
    }

    #[test]
    fn rate_limits_use_defaults_when_missing() {
        let limits = RateLimits::default();
        let cfg = limits.get("unknown");
        // Unknown tenants fall back to RateLimit::default().
        assert_eq!(cfg.rps, 5.0);
        assert_eq!(cfg.burst, 10.0);
    }
}