gsm_backpressure/
lib.rs

1//! Distributed backpressure primitives backed by JetStream key-value state.
2
3use std::{
4    collections::HashMap,
5    sync::{
6        Arc,
7        atomic::{AtomicBool, Ordering},
8    },
9    time::{Duration as StdDuration, Instant},
10};
11
12use anyhow::{Context, Result, anyhow};
13use async_nats::jetstream::{
14    Context as JsContext,
15    context::KeyValueErrorKind,
16    kv::{self, CreateErrorKind, UpdateErrorKind},
17};
18use async_trait::async_trait;
19use gsm_telemetry::{TelemetryLabels, record_gauge};
20use serde::{Deserialize, Serialize};
21use time::{Duration, OffsetDateTime, serde::rfc3339};
22use tokio::sync::Mutex;
23use tracing::{Level, event, instrument, warn};
24
/// Number of tokens one permit consumes. (At `rps` tokens refilled per
/// second, a missing token therefore costs `1 / rps` seconds of waiting.)
const TOKEN: f64 = 1.0;
/// Refill accounting granularity in milliseconds: elapsed time is credited
/// in whole ticks, and sub-tick remainders carry over to the next refill.
const TICK_MS: i64 = 100;
28
29fn compute_wait_secs(limit: RateLimit, tokens: f64) -> f64 {
30    let missing = (TOKEN - tokens).max(0.0);
31    missing / limit.rps.max(0.1)
32}
33
34fn record_backpressure_tokens(tenant: &str, tokens: f64) {
35    let labels = TelemetryLabels {
36        tenant: tenant.to_string(),
37        platform: None,
38        chat_id: None,
39        msg_id: None,
40        extra: Vec::new(),
41    };
42    record_gauge("backpressure_tokens", tokens.round() as i64, &labels);
43}
/// Token-bucket parameters for a single tenant.
#[derive(Debug, Clone, Copy)]
pub struct RateLimit {
    /// Steady-state refill rate, in tokens (requests) per second.
    pub rps: f64,
    /// Maximum number of tokens the bucket may hold.
    pub burst: f64,
}

impl Default for RateLimit {
    /// Conservative fallback: 5 requests per second, bursting up to 10.
    fn default() -> Self {
        RateLimit {
            rps: 5.0,
            burst: 10.0,
        }
    }
}
58
59#[derive(Clone, Default)]
60pub struct RateLimits {
61    default: RateLimit,
62    tenants: HashMap<String, RateLimit>,
63}
64
65impl RateLimits {
66    pub fn new(default: RateLimit, tenants: HashMap<String, RateLimit>) -> Self {
67        let normalize = |limit: RateLimit| RateLimit {
68            rps: limit.rps.max(0.1),
69            burst: limit.burst.max(1.0),
70        };
71        let default = normalize(default);
72        let tenants = tenants
73            .into_iter()
74            .map(|(tenant, limit)| (tenant, normalize(limit)))
75            .collect();
76        Self { default, tenants }
77    }
78
79    pub fn with_tenants(tenants: HashMap<String, RateLimit>) -> Self {
80        Self::new(RateLimit::default(), tenants)
81    }
82
83    pub fn get(&self, tenant: &str) -> RateLimit {
84        self.tenants.get(tenant).copied().unwrap_or(self.default)
85    }
86}
87
/// Common interface over the local, remote, and hybrid limiters.
#[async_trait]
pub trait BackpressureLimiter: Send + Sync {
    /// Wait until `tenant` may perform one more operation, then return a
    /// [`Permit`]. Implementations sleep and retry internally, so an `Err`
    /// signals a backend failure rather than an exhausted bucket.
    async fn acquire(&self, tenant: &str) -> Result<Permit>;
}
92
/// Proof that one rate-limit token was consumed.
///
/// A zero-sized marker: the token is deducted when the permit is issued and
/// nothing is returned to the bucket on drop.
#[derive(Debug)]
pub struct Permit;

impl Permit {
    /// Internal constructor used by limiter implementations.
    fn new() -> Self {
        Permit
    }
}
101
/// In-process token-bucket limiter; all state lives in a mutex-guarded map.
#[derive(Clone)]
pub struct LocalBackpressureLimiter {
    // Shared, normalized per-tenant limits; clones share the same table.
    limits: Arc<RateLimits>,
    // One bucket per tenant, created lazily on first acquire.
    buckets: Arc<Mutex<HashMap<String, LocalBucket>>>,
}
107
/// Mutable token-bucket state for one tenant in the local limiter.
#[derive(Debug)]
struct LocalBucket {
    // Current balance, capped at the tenant's burst by `refill`.
    tokens: f64,
    // Monotonic timestamp of the last whole-tick refill.
    last_refill: Instant,
}
113
114impl LocalBackpressureLimiter {
115    pub fn new(limits: Arc<RateLimits>) -> Self {
116        Self {
117            limits,
118            buckets: Arc::new(Mutex::new(HashMap::new())),
119        }
120    }
121
122    fn refill(tokens: f64, elapsed: StdDuration, limit: RateLimit) -> (f64, StdDuration) {
123        if elapsed.is_zero() {
124            return (tokens, StdDuration::from_millis(0));
125        }
126        let ticks = (elapsed.as_millis() as i64) / TICK_MS;
127        if ticks <= 0 {
128            return (tokens, StdDuration::from_millis(0));
129        }
130        let refill = (ticks as f64) * (limit.rps * (TICK_MS as f64 / 1000.0));
131        let tokens = (tokens + refill).min(limit.burst);
132        let consumed = StdDuration::from_millis((ticks * TICK_MS) as u64);
133        (tokens, consumed)
134    }
135}
136
#[async_trait]
impl BackpressureLimiter for LocalBackpressureLimiter {
    /// Take one token from `tenant`'s in-process bucket, sleeping until one
    /// is available. The bucket map's mutex is released before every sleep
    /// so other tenants are never blocked while this one waits.
    async fn acquire(&self, tenant: &str) -> Result<Permit> {
        let tenant_key = tenant.to_string();
        loop {
            let limit = self.limits.get(tenant);
            let mut guard = self.buckets.lock().await;
            // First acquire for a tenant starts with a full (burst) bucket.
            let bucket = guard.entry(tenant_key.clone()).or_insert(LocalBucket {
                tokens: limit.burst,
                last_refill: Instant::now(),
            });
            let now = Instant::now();
            let elapsed = now.saturating_duration_since(bucket.last_refill);
            let (filled, consumed) = Self::refill(bucket.tokens, elapsed, limit);
            if consumed > StdDuration::from_millis(0) {
                // Advance only by the whole ticks actually credited so any
                // sub-tick remainder carries into the next refill.
                bucket.last_refill += consumed;
                bucket.tokens = filled;
            }
            if bucket.tokens >= TOKEN {
                bucket.tokens -= TOKEN;
                record_backpressure_tokens(&tenant_key, bucket.tokens);
                drop(guard);
                return Ok(Permit::new());
            }
            let wait_secs = compute_wait_secs(limit, bucket.tokens);
            // Only log waits long enough to be operationally interesting.
            if wait_secs > 1.0 {
                event!(
                    Level::INFO,
                    tenant = %tenant_key,
                    wait_secs,
                    "backpressure.waiting_for_tokens"
                );
            }
            drop(guard);
            // Sleep at least 100 ms so the retry loop can never spin hot.
            tokio::time::sleep(StdDuration::from_secs_f64(wait_secs.max(0.1))).await;
        }
    }
}
175
/// In-memory view of a tenant's remote bucket after decoding and refill.
struct RemoteBucketState {
    // Current balance; clamped to the tenant's burst when loaded.
    tokens: f64,
    // Wall-clock (UTC) time of the last whole-tick refill.
    last_refill: OffsetDateTime,
}
180
/// JSON wire format stored in the JetStream KV bucket for one tenant.
#[derive(Debug, Serialize, Deserialize)]
struct RemoteBucketPersisted {
    tokens: f64,
    // Serialized as RFC 3339 so stored values stay human-readable.
    #[serde(with = "rfc3339")]
    last_refill_ts: OffsetDateTime,
}

impl RemoteBucketPersisted {
    /// Snapshot a balance and its refill timestamp for persistence.
    fn new(tokens: f64, now: OffsetDateTime) -> Self {
        Self {
            tokens,
            last_refill_ts: now,
        }
    }
}
196
/// Cluster-wide limiter that keeps token buckets in a JetStream key-value
/// store and coordinates writers via compare-and-swap on entry revisions.
pub struct JetStreamBackpressureLimiter {
    limits: Arc<RateLimits>,
    // KV store holding one `rate/<tenant>` key per tenant.
    bucket: kv::Store,
    // Bucket name, retained for tracing/log context.
    namespace: String,
}
202
203impl JetStreamBackpressureLimiter {
204    pub async fn new(js: &JsContext, namespace: &str, limits: Arc<RateLimits>) -> Result<Self> {
205        let bucket = match js.get_key_value(namespace).await {
206            Ok(store) => store,
207            Err(err) if err.kind() == KeyValueErrorKind::GetBucket => js
208                .create_key_value(kv::Config {
209                    bucket: namespace.to_string(),
210                    description: "backpressure rate limiter".into(),
211                    history: 1,
212                    max_age: StdDuration::from_secs(0),
213                    ..Default::default()
214                })
215                .await
216                .with_context(|| format!("create JetStream KV bucket {namespace}"))?,
217            Err(err) => return Err(anyhow!(err).context("initializing backpressure bucket")),
218        };
219        Ok(Self {
220            limits,
221            bucket,
222            namespace: namespace.to_string(),
223        })
224    }
225
226    fn parse_state(&self, entry: Option<kv::Entry>, limit: RateLimit) -> RemoteBucketState {
227        let now = OffsetDateTime::now_utc();
228        match entry {
229            Some(e) => serde_json::from_slice::<RemoteBucketPersisted>(e.value.as_ref())
230                .map(|persisted| RemoteBucketState {
231                    tokens: persisted.tokens.min(limit.burst),
232                    last_refill: persisted.last_refill_ts,
233                })
234                .unwrap_or(RemoteBucketState {
235                    tokens: limit.burst,
236                    last_refill: now,
237                }),
238            None => RemoteBucketState {
239                tokens: limit.burst,
240                last_refill: now,
241            },
242        }
243    }
244
245    fn refill_tokens(
246        mut state: RemoteBucketState,
247        limit: RateLimit,
248        now: OffsetDateTime,
249    ) -> RemoteBucketState {
250        if now <= state.last_refill {
251            return state;
252        }
253        let elapsed_ms = (now - state.last_refill).whole_milliseconds();
254        let ticks = (elapsed_ms / i128::from(TICK_MS)) as i64;
255        if ticks <= 0 {
256            return state;
257        }
258        let refill = (ticks as f64) * (limit.rps * (TICK_MS as f64 / 1000.0));
259        state.tokens = (state.tokens + refill).min(limit.burst);
260        state.last_refill += Duration::milliseconds(ticks * TICK_MS);
261        state
262    }
263
264    async fn wait_for_tokens(wait_secs: f64) {
265        tokio::time::sleep(StdDuration::from_secs_f64(wait_secs.max(0.1))).await;
266    }
267}
268
269#[async_trait]
270impl BackpressureLimiter for JetStreamBackpressureLimiter {
271    #[instrument(name = "backpressure.remote.acquire", skip(self), fields(namespace = %self.namespace, tenant))]
272    async fn acquire(&self, tenant: &str) -> Result<Permit> {
273        let tenant_key = tenant.to_string();
274        let limit = self.limits.get(tenant);
275        let key = format!("rate/{tenant}");
276        let mut retries = 0usize;
277
278        loop {
279            let entry = self
280                .bucket
281                .entry(key.as_str())
282                .await
283                .with_context(|| format!("load rate state for {tenant}"))?;
284            let now = OffsetDateTime::now_utc();
285            let mut state = self.parse_state(entry.clone(), limit);
286            state = Self::refill_tokens(state, limit, now);
287            if state.tokens < TOKEN {
288                let wait_secs = compute_wait_secs(limit, state.tokens);
289                if wait_secs > 1.0 {
290                    event!(
291                        Level::INFO,
292                        tenant = %tenant_key,
293                        wait_secs,
294                        namespace = %self.namespace,
295                        "backpressure.waiting_for_tokens"
296                    );
297                }
298                Self::wait_for_tokens(wait_secs).await;
299                continue;
300            }
301            state.tokens -= TOKEN;
302            record_backpressure_tokens(&tenant_key, state.tokens);
303            let persisted = RemoteBucketPersisted::new(state.tokens, state.last_refill);
304            let payload = serde_json::to_vec(&persisted)?;
305            match &entry {
306                Some(e) => match self
307                    .bucket
308                    .update(key.as_str(), payload.clone().into(), e.revision)
309                    .await
310                {
311                    Ok(_) => return Ok(Permit::new()),
312                    Err(err) if err.kind() == UpdateErrorKind::WrongLastRevision => {
313                        retries += 1;
314                        if retries > 3 {
315                            event!(
316                                Level::WARN,
317                                tenant = %tenant_key,
318                                retries,
319                                "egress.acquire_permit.cas_retry"
320                            );
321                        }
322                        continue;
323                    }
324                    Err(err) => {
325                        return Err(
326                            anyhow!(err).context(format!("update remote rate state {tenant}"))
327                        );
328                    }
329                },
330                None => match self
331                    .bucket
332                    .create(key.as_str(), payload.clone().into())
333                    .await
334                {
335                    Ok(_) => return Ok(Permit::new()),
336                    Err(err) if err.kind() == CreateErrorKind::AlreadyExists => {
337                        retries += 1;
338                        continue;
339                    }
340                    Err(err) => {
341                        return Err(
342                            anyhow!(err).context(format!("create remote rate state {tenant}"))
343                        );
344                    }
345                },
346            }
347        }
348    }
349}
350
/// Limiter that prefers a shared JetStream-backed limiter and transparently
/// falls back to the in-process one when the remote store is unavailable.
pub struct HybridLimiter {
    // Present only when a JetStream context was supplied AND the KV bucket
    // could be opened at construction time.
    remote: Option<JetStreamBackpressureLimiter>,
    local: LocalBackpressureLimiter,
    // Latch so the remote→local transition is logged once per outage.
    remote_failed: AtomicBool,
}
356
357impl HybridLimiter {
358    pub async fn new(js: Option<&JsContext>) -> Result<Arc<Self>> {
359        Self::new_with_config(js, BackpressureConfig::default()).await
360    }
361
362    pub async fn new_with_config(
363        js: Option<&JsContext>,
364        config: BackpressureConfig,
365    ) -> Result<Arc<Self>> {
366        let limits = Arc::clone(&config.limits);
367        let namespace = config.namespace;
368        let remote = match js {
369            Some(ctx) => {
370                match JetStreamBackpressureLimiter::new(ctx, &namespace, limits.clone()).await {
371                    Ok(limiter) => Some(limiter),
372                    Err(err) => {
373                        warn!(error = %err, "remote backpressure store unavailable, falling back to local limiter");
374                        None
375                    }
376                }
377            }
378            None => None,
379        };
380
381        let local = LocalBackpressureLimiter::new(limits);
382        Ok(Arc::new(Self {
383            remote,
384            local,
385            remote_failed: AtomicBool::new(false),
386        }))
387    }
388}
389
/// Tunables for building a [`HybridLimiter`].
#[derive(Clone)]
pub struct BackpressureConfig {
    /// JetStream KV bucket name used for the remote state.
    pub namespace: String,
    /// Shared, normalized per-tenant rate limits.
    pub limits: Arc<RateLimits>,
}
395
396impl BackpressureConfig {
397    pub fn new(namespace: impl Into<String>, limits: Arc<RateLimits>) -> Self {
398        Self {
399            namespace: namespace.into(),
400            limits,
401        }
402    }
403}
404
405impl Default for BackpressureConfig {
406    fn default() -> Self {
407        Self {
408            namespace: "rate-limits".to_string(),
409            limits: Arc::new(RateLimits::default()),
410        }
411    }
412}
413
414#[async_trait]
415impl BackpressureLimiter for HybridLimiter {
416    async fn acquire(&self, tenant: &str) -> Result<Permit> {
417        if let Some(remote) = &self.remote {
418            match remote.acquire(tenant).await {
419                Ok(permit) => {
420                    self.remote_failed.store(false, Ordering::Release);
421                    return Ok(permit);
422                }
423                Err(err) => {
424                    if !self.remote_failed.swap(true, Ordering::AcqRel) {
425                        warn!(error = %err, "remote limiter failed, switching to local fallback");
426                    }
427                }
428            }
429        }
430        self.local.acquire(tenant).await
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): with burst 2.0 both acquires are served from the initial
    // burst, so this exercises the fast path rather than an actual refill.
    #[tokio::test]
    async fn local_refills() {
        let limits = Arc::new(RateLimits {
            default: RateLimit {
                rps: 10.0,
                burst: 2.0,
            },
            tenants: HashMap::new(),
        });
        let limiter = LocalBackpressureLimiter::new(limits);
        let _ = limiter.acquire("t").await.unwrap();
        let _ = limiter.acquire("t").await.unwrap();
    }

    #[test]
    fn compute_wait_secs_reflects_missing_tokens() {
        let limit = RateLimit {
            rps: 2.0,
            burst: 1.0,
        };
        // One missing token at 2 rps should take half a second.
        let wait = compute_wait_secs(limit, 0.0);
        assert!((wait - 0.5).abs() < 1e-6);

        // A surplus of tokens means no wait at all.
        let instant = compute_wait_secs(limit, 2.0);
        assert_eq!(instant, 0.0);
    }

    #[test]
    fn refill_respects_burst_and_elapsed_time() {
        let limit = RateLimit {
            rps: 1.0,
            burst: 2.0,
        };
        // No time elapsed, nothing changes.
        let (tokens, consumed) =
            LocalBackpressureLimiter::refill(0.5, StdDuration::from_millis(0), limit);
        assert_eq!(tokens, 0.5);
        assert_eq!(consumed, StdDuration::from_millis(0));

        // Two seconds elapsed should top up but not exceed burst.
        let (tokens, consumed) =
            LocalBackpressureLimiter::refill(0.5, StdDuration::from_secs(2), limit);
        assert_eq!(tokens, 2.0);
        assert_eq!(consumed, StdDuration::from_secs(2));
    }

    #[test]
    fn rate_limits_enforce_minimums() {
        // Zeroed limits must be clamped up to the 0.1 rps / 1.0 burst floors.
        let mut tenants = HashMap::new();
        tenants.insert(
            "tenant".to_string(),
            RateLimit {
                rps: 0.0,
                burst: 0.0,
            },
        );
        let limits = RateLimits::with_tenants(tenants);
        let cfg = limits.get("tenant");
        assert_eq!(cfg.rps, 0.1);
        assert_eq!(cfg.burst, 1.0);
    }

    #[test]
    fn rate_limits_use_defaults_when_missing() {
        // Unknown tenants fall back to RateLimit::default().
        let limits = RateLimits::default();
        let cfg = limits.get("unknown");
        assert_eq!(cfg.rps, 5.0);
        assert_eq!(cfg.burst, 10.0);
    }
}