// Source file: grate_limiter/engine.rs

1use std::sync::Arc;
2
3use dashmap::DashMap;
4use parking_lot::RwLock;
5use smallvec::SmallVec;
6
7use crate::capability::{CapabilityConfig, CapabilityProvider};
8use crate::clock::{Clock, Timestamp};
9use crate::config::EngineConfig;
10use crate::decision::{Alternative, Decision, ScoreBreakdown};
11use crate::error::{Error, Result};
12use crate::health::HealthState;
13use crate::metrics::Metrics;
14use crate::observation::{Observation, StatusClass};
15use crate::provider::ProviderConfig;
16use crate::quota::{self, Dimension, QuotaConfig};
17use crate::scoring::{ProviderScoreContext, ScoringStrategy, WeightedScorer};
18
19/// The main grate-limiter engine.
20///
21/// Thread-safe and cheaply cloneable. All instances share the same internal state.
22///
23/// # Example
24///
25/// ```rust
26/// use grate_limiter::{GrateLimiter, EngineConfig};
27///
28/// let engine = GrateLimiter::new(EngineConfig::default());
29/// // Register providers and capabilities, then use engine.select() and engine.observe()
30/// ```
31#[derive(Clone)]
32pub struct GrateLimiter {
33    inner: Arc<Inner>,
34}
35
36struct Inner {
37    /// Per-provider runtime state. Key = provider name.
38    providers: DashMap<String, ProviderRuntime>,
39    /// Capability definitions. Key = capability name.
40    capabilities: RwLock<DashMap<String, CapabilityDef>>,
41    /// Scoring strategy.
42    scorer: Box<dyn ScoringStrategy>,
43    /// Engine configuration.
44    config: EngineConfig,
45    /// Monotonic clock.
46    clock: Arc<dyn Clock>,
47    /// Observable metrics.
48    metrics: Metrics,
49}
50
51/// Internal runtime state for a provider.
52struct ProviderRuntime {
53    config: ProviderConfig,
54    health: RwLock<HealthState>,
55    quota_trackers: Vec<(QuotaConfig, Box<dyn crate::quota::QuotaTracker>)>,
56}
57
58/// Internal capability definition.
59struct CapabilityDef {
60    providers: SmallVec<[CapabilityProvider; 4]>,
61}
62
63impl GrateLimiter {
64    /// Create a new engine with the given configuration.
65    pub fn new(config: EngineConfig) -> Self {
66        let clock = config.clock();
67        let scorer = Box::new(WeightedScorer::new(config.scoring.clone()));
68
69        Self {
70            inner: Arc::new(Inner {
71                providers: DashMap::new(),
72                capabilities: RwLock::new(DashMap::new()),
73                scorer,
74                config,
75                clock,
76                metrics: Metrics::new(),
77            }),
78        }
79    }
80
81    /// Register or update a provider and its quotas.
82    ///
83    /// If the provider already exists, its configuration and quota trackers are replaced.
84    /// Health state is preserved across upserts.
85    pub fn upsert_provider(&self, config: ProviderConfig) {
86        let now = self.inner.clock.now();
87        let trackers: Vec<_> = config
88            .quotas
89            .iter()
90            .map(|qc| (qc.clone(), quota::create_tracker(qc, now)))
91            .collect();
92
93        if let Some(mut existing) = self.inner.providers.get_mut(&config.name) {
94            // Preserve health, update config and trackers
95            existing.config = config;
96            existing.quota_trackers = trackers;
97        } else {
98            self.inner.providers.insert(
99                config.name.clone(),
100                ProviderRuntime {
101                    config,
102                    health: RwLock::new(HealthState::new(now)),
103                    quota_trackers: trackers,
104                },
105            );
106        }
107    }
108
109    /// Register or update a capability and its provider mappings.
110    pub fn upsert_capability(&self, config: CapabilityConfig) {
111        let caps = self.inner.capabilities.read();
112        caps.insert(
113            config.name.clone(),
114            CapabilityDef {
115                providers: SmallVec::from_vec(config.providers),
116            },
117        );
118    }
119
120    /// Select the best provider for a capability.
121    ///
122    /// Returns a [`Decision`] with the selected provider, its score, reasoning breakdown,
123    /// and ranked alternatives.
124    ///
125    /// # Errors
126    ///
127    /// - [`Error::UnknownCapability`] if the capability is not registered.
128    /// - [`Error::NoAvailableProviders`] if all providers are in cooldown or below minimum health.
129    pub fn select(&self, capability: &str) -> Result<Decision> {
130        self.inner
131            .metrics
132            .selects
133            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
134
135        let now = self.inner.clock.now();
136
137        // Look up capability
138        let caps = self.inner.capabilities.read();
139        let cap_def = caps
140            .get(capability)
141            .ok_or_else(|| Error::UnknownCapability(capability.to_string()))?;
142
143        let cap_providers = &cap_def.providers;
144        if cap_providers.is_empty() {
145            return Err(Error::NoAvailableProviders(capability.to_string()));
146        }
147
148        // Find max priority and max latency for normalization
149        let max_priority = cap_providers.iter().map(|p| p.priority).max().unwrap_or(1);
150
151        let mut max_latency_ms: f64 = 0.0;
152        for cp in cap_providers.iter() {
153            if let Some(pr) = self.inner.providers.get(&cp.provider) {
154                let health = pr.health.read();
155                if health.latency_ms() > max_latency_ms {
156                    max_latency_ms = health.latency_ms();
157                }
158            }
159        }
160        if max_latency_ms <= 0.0 {
161            max_latency_ms = 1.0;
162        }
163
164        // Score all providers
165        let mut candidates: SmallVec<[(String, f32, ScoreBreakdown); 4]> = SmallVec::new();
166
167        for cp in cap_providers.iter() {
168            let Some(pr) = self.inner.providers.get(&cp.provider) else {
169                continue;
170            };
171
172            let health = pr.health.read();
173
174            // Skip providers in cooldown
175            if health.is_in_cooldown(now) {
176                continue;
177            }
178
179            // Skip providers below minimum health
180            if health.score() < self.inner.config.minimum_health_score {
181                continue;
182            }
183
184            // Calculate worst quota state across all dimensions
185            let (quota_remaining_ratio, predicted_exhaustion, burn_rate) =
186                self.worst_quota_state(&pr.quota_trackers, now);
187
188            let ctx = ProviderScoreContext {
189                quota_remaining_ratio,
190                predicted_exhaustion_secs: predicted_exhaustion,
191                burn_rate,
192                health_score: health.score(),
193                priority: cp.priority,
194                max_priority,
195                latency_ms: health.latency_ms(),
196                max_latency_ms,
197            };
198
199            let score = self.inner.scorer.score(&ctx);
200            let breakdown = ScoreBreakdown {
201                quota_score: ctx.quota_remaining_ratio as f32,
202                health_score: ctx.health_score,
203                priority_score: cp.priority as f32 / max_priority as f32,
204                latency_score: if max_latency_ms > 0.0 {
205                    (1.0 - (ctx.latency_ms / max_latency_ms) as f32).max(0.0)
206                } else {
207                    1.0
208                },
209            };
210
211            candidates.push((cp.provider.clone(), score, breakdown));
212        }
213
214        drop(cap_def);
215        drop(caps);
216
217        if candidates.is_empty() {
218            self.inner
219                .metrics
220                .no_provider_available
221                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
222            return Err(Error::NoAvailableProviders(capability.to_string()));
223        }
224
225        // Sort by score descending
226        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
227
228        let (provider, score, reasoning) = candidates.remove(0);
229        let alternatives = candidates
230            .into_iter()
231            .map(|(p, s, _)| Alternative {
232                provider: p,
233                score: s,
234            })
235            .collect();
236
237        Ok(Decision {
238            provider,
239            score,
240            reasoning,
241            alternatives,
242        })
243    }
244
245    /// Report an observation after a provider interaction.
246    ///
247    /// Updates quota counters and health state for the provider.
248    ///
249    /// # Errors
250    ///
251    /// - [`Error::UnknownProvider`] if the provider is not registered.
252    pub fn observe(&self, obs: Observation) -> Result<()> {
253        self.inner
254            .metrics
255            .observations
256            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
257
258        let now = self.inner.clock.now();
259
260        let pr = self
261            .inner
262            .providers
263            .get(&obs.provider)
264            .ok_or_else(|| Error::UnknownProvider(obs.provider.clone()))?;
265
266        // Update quota trackers
267        for (qc, tracker) in &pr.quota_trackers {
268            let amount = match qc.dimension {
269                Dimension::Requests => obs.usage.requests,
270                Dimension::Tokens => obs.usage.tokens.unwrap_or(0),
271                Dimension::Bytes => obs.usage.bytes.unwrap_or(0),
272                Dimension::CostUsd => obs.usage.cost_micro_usd.unwrap_or(0),
273                Dimension::Concurrency => obs.usage.requests, // track as concurrency slot
274            };
275            if amount > 0 {
276                tracker.record(amount, now);
277            }
278        }
279
280        // Update health
281        let cooldown_secs = pr.config.cooldown_seconds;
282        let health_config = &self.inner.config.health;
283        let mut health = pr.health.write();
284        let was_in_cooldown = health.is_in_cooldown(now);
285
286        match obs.outcome.status {
287            StatusClass::Success | StatusClass::ClientError => {
288                health.record_success(obs.outcome.latency_ms, now, health_config);
289            }
290            StatusClass::RateLimited => {
291                health.record_rate_limited(now, health_config, cooldown_secs);
292            }
293            StatusClass::Forbidden => {
294                health.record_forbidden(now, health_config, cooldown_secs);
295            }
296            StatusClass::ServerError => {
297                health.record_server_error(now, health_config, cooldown_secs);
298            }
299            StatusClass::Timeout => {
300                health.record_timeout(now, health_config, cooldown_secs);
301            }
302        }
303
304        // Track cooldown triggers
305        if !was_in_cooldown && health.is_in_cooldown(now) {
306            self.inner
307                .metrics
308                .cooldowns_triggered
309                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
310        }
311
312        Ok(())
313    }
314
315    /// Access engine metrics.
316    pub fn metrics(&self) -> &Metrics {
317        &self.inner.metrics
318    }
319
320    /// Get the current health score for a provider.
321    pub fn provider_health(&self, provider: &str) -> Option<f32> {
322        self.inner
323            .providers
324            .get(provider)
325            .map(|pr| pr.health.read().score())
326    }
327
328    /// Check if a provider is currently in cooldown.
329    pub fn provider_in_cooldown(&self, provider: &str) -> Option<bool> {
330        let now = self.inner.clock.now();
331        self.inner
332            .providers
333            .get(provider)
334            .map(|pr| pr.health.read().is_in_cooldown(now))
335    }
336
337    /// Get the remaining quota for a specific dimension on a provider.
338    pub fn provider_quota_remaining(&self, provider: &str, dimension: Dimension) -> Option<u64> {
339        let now = self.inner.clock.now();
340        self.inner.providers.get(provider).and_then(|pr| {
341            pr.quota_trackers
342                .iter()
343                .find(|(qc, _)| qc.dimension == dimension)
344                .map(|(_, tracker)| tracker.remaining(now))
345        })
346    }
347
348    /// Calculate the worst (most constrained) quota state across all dimensions.
349    fn worst_quota_state(
350        &self,
351        trackers: &[(QuotaConfig, Box<dyn crate::quota::QuotaTracker>)],
352        now: Timestamp,
353    ) -> (f64, f64, f64) {
354        if trackers.is_empty() {
355            return (1.0, f64::INFINITY, 0.0);
356        }
357
358        let mut worst_remaining = 1.0_f64;
359        let mut worst_exhaustion = f64::INFINITY;
360        let mut max_burn_rate = 0.0_f64;
361
362        for (_, tracker) in trackers {
363            let remaining = 1.0 - tracker.usage_ratio(now);
364            let exhaustion = tracker.predicted_exhaustion_secs(now);
365            let burn = tracker.burn_rate(now);
366
367            if remaining < worst_remaining {
368                worst_remaining = remaining;
369            }
370            if exhaustion < worst_exhaustion {
371                worst_exhaustion = exhaustion;
372            }
373            if burn > max_burn_rate {
374                max_burn_rate = burn;
375            }
376        }
377
378        (worst_remaining, worst_exhaustion, max_burn_rate)
379    }
380}
381
382#[cfg(test)]
383mod tests {
384    use super::*;
385    use crate::clock::MockClock;
386    use crate::observation::{Outcome, Usage};
387    use crate::quota::Window;
388
389    fn setup_engine() -> (GrateLimiter, Arc<MockClock>) {
390        let clock = Arc::new(MockClock::new());
391        let config = EngineConfig::default().with_clock(clock.clone());
392        let engine = GrateLimiter::new(config);
393
394        engine.upsert_provider(ProviderConfig {
395            name: "openai".into(),
396            quotas: vec![QuotaConfig {
397                dimension: Dimension::Requests,
398                limit: 100,
399                window: Some(Window::Minute),
400            }],
401            priority: 10,
402            weight: 1.0,
403            cooldown_seconds: 30,
404        });
405
406        engine.upsert_provider(ProviderConfig {
407            name: "anthropic".into(),
408            quotas: vec![QuotaConfig {
409                dimension: Dimension::Requests,
410                limit: 80,
411                window: Some(Window::Minute),
412            }],
413            priority: 8,
414            weight: 1.0,
415            cooldown_seconds: 30,
416        });
417
418        engine.upsert_capability(CapabilityConfig {
419            name: "chat".into(),
420            providers: vec![
421                CapabilityProvider {
422                    provider: "openai".into(),
423                    priority: 10,
424                },
425                CapabilityProvider {
426                    provider: "anthropic".into(),
427                    priority: 8,
428                },
429            ],
430        });
431
432        (engine, clock)
433    }
434
435    #[test]
436    fn select_returns_best_provider() {
437        let (engine, _clock) = setup_engine();
438        let decision = engine.select("chat").unwrap();
439        // openai has higher priority and both are fully healthy
440        assert_eq!(decision.provider, "openai");
441        assert!(decision.score > 0.0);
442        assert_eq!(decision.alternatives.len(), 1);
443    }
444
445    #[test]
446    fn select_unknown_capability_errors() {
447        let (engine, _clock) = setup_engine();
448        let result = engine.select("nonexistent");
449        assert!(matches!(result, Err(Error::UnknownCapability(_))));
450    }
451
452    #[test]
453    fn observe_updates_health() {
454        let (engine, _clock) = setup_engine();
455
456        // Report a 429 for openai
457        engine
458            .observe(Observation {
459                provider: "openai".into(),
460                capability: Some("chat".into()),
461                usage: Usage {
462                    requests: 1,
463                    ..Default::default()
464                },
465                outcome: Outcome {
466                    status: StatusClass::RateLimited,
467                    latency_ms: 100,
468                },
469            })
470            .unwrap();
471
472        let health = engine.provider_health("openai").unwrap();
473        assert!(health < 1.0);
474    }
475
476    #[test]
477    fn observe_unknown_provider_errors() {
478        let (engine, _clock) = setup_engine();
479        let result = engine.observe(Observation {
480            provider: "nonexistent".into(),
481            capability: None,
482            usage: Usage::default(),
483            outcome: Outcome {
484                status: StatusClass::Success,
485                latency_ms: 100,
486            },
487        });
488        assert!(matches!(result, Err(Error::UnknownProvider(_))));
489    }
490
491    #[test]
492    fn degraded_provider_loses_to_healthy() {
493        let (engine, clock) = setup_engine();
494
495        // Degrade openai with repeated 429s
496        for _i in 0..3 {
497            clock.advance_ms(1000);
498            engine
499                .observe(Observation {
500                    provider: "openai".into(),
501                    capability: Some("chat".into()),
502                    usage: Usage {
503                        requests: 1,
504                        ..Default::default()
505                    },
506                    outcome: Outcome {
507                        status: StatusClass::RateLimited,
508                        latency_ms: 100,
509                    },
510                })
511                .unwrap();
512        }
513
514        // openai should now be in cooldown, so anthropic wins
515        let decision = engine.select("chat").unwrap();
516        assert_eq!(decision.provider, "anthropic");
517    }
518
519    #[test]
520    fn metrics_increment() {
521        let (engine, _clock) = setup_engine();
522
523        engine.select("chat").unwrap();
524        engine.select("chat").unwrap();
525        assert_eq!(engine.metrics().selects(), 2);
526
527        engine
528            .observe(Observation {
529                provider: "openai".into(),
530                capability: None,
531                usage: Usage {
532                    requests: 1,
533                    ..Default::default()
534                },
535                outcome: Outcome {
536                    status: StatusClass::Success,
537                    latency_ms: 50,
538                },
539            })
540            .unwrap();
541        assert_eq!(engine.metrics().observations(), 1);
542    }
543
544    #[test]
545    fn provider_quota_tracking() {
546        let (engine, _clock) = setup_engine();
547
548        assert_eq!(
549            engine.provider_quota_remaining("openai", Dimension::Requests),
550            Some(100)
551        );
552
553        engine
554            .observe(Observation {
555                provider: "openai".into(),
556                capability: None,
557                usage: Usage {
558                    requests: 30,
559                    ..Default::default()
560                },
561                outcome: Outcome {
562                    status: StatusClass::Success,
563                    latency_ms: 100,
564                },
565            })
566            .unwrap();
567
568        let remaining = engine
569            .provider_quota_remaining("openai", Dimension::Requests)
570            .unwrap();
571        assert_eq!(remaining, 70);
572    }
573
574    #[test]
575    fn upsert_provider_preserves_health() {
576        let (engine, _clock) = setup_engine();
577
578        // Damage openai health
579        engine
580            .observe(Observation {
581                provider: "openai".into(),
582                capability: None,
583                usage: Usage {
584                    requests: 1,
585                    ..Default::default()
586                },
587                outcome: Outcome {
588                    status: StatusClass::ServerError,
589                    latency_ms: 100,
590                },
591            })
592            .unwrap();
593
594        let health_before = engine.provider_health("openai").unwrap();
595
596        // Re-upsert provider
597        engine.upsert_provider(ProviderConfig {
598            name: "openai".into(),
599            quotas: vec![QuotaConfig {
600                dimension: Dimension::Requests,
601                limit: 200, // new limit
602                window: Some(Window::Minute),
603            }],
604            priority: 10,
605            weight: 1.0,
606            cooldown_seconds: 30,
607        });
608
609        // Health should be preserved
610        let health_after = engine.provider_health("openai").unwrap();
611        assert!((health_before - health_after).abs() < 0.01);
612
613        // But quota should be reset to new limit
614        assert_eq!(
615            engine.provider_quota_remaining("openai", Dimension::Requests),
616            Some(200)
617        );
618    }
619
620    #[test]
621    fn engine_is_clone_and_send() {
622        let (engine, _) = setup_engine();
623        let engine2 = engine.clone();
624
625        // Spawn a thread to prove Send + Sync
626        let handle = std::thread::spawn(move || engine2.select("chat").unwrap());
627        let decision = handle.join().unwrap();
628        assert!(!decision.provider.is_empty());
629    }
630
631    #[test]
632    fn anticipatory_routing_under_pressure() {
633        let (engine, clock) = setup_engine();
634
635        // Consume 90% of openai's quota rapidly
636        for _ in 0..90 {
637            engine
638                .observe(Observation {
639                    provider: "openai".into(),
640                    capability: Some("chat".into()),
641                    usage: Usage {
642                        requests: 1,
643                        ..Default::default()
644                    },
645                    outcome: Outcome {
646                        status: StatusClass::Success,
647                        latency_ms: 50,
648                    },
649                })
650                .unwrap();
651        }
652        clock.advance_ms(5000); // 5 seconds elapsed — very fast burn rate
653
654        // With 10% remaining and high burn rate, anthropic should win
655        // despite lower priority
656        let decision = engine.select("chat").unwrap();
657        // At this point openai has 10% remaining with rapid burn,
658        // anthropic has 100% remaining. Anthropic should score higher.
659        assert_eq!(
660            decision.provider, "anthropic",
661            "Anticipatory routing should prefer anthropic when openai is nearly exhausted"
662        );
663    }
664
665    #[test]
666    fn cooldown_expires_and_provider_recovers() {
667        let (engine, clock) = setup_engine();
668
669        // Trigger cooldown on openai
670        for _ in 0..3 {
671            clock.advance_ms(100);
672            engine
673                .observe(Observation {
674                    provider: "openai".into(),
675                    capability: None,
676                    usage: Usage {
677                        requests: 1,
678                        ..Default::default()
679                    },
680                    outcome: Outcome {
681                        status: StatusClass::RateLimited,
682                        latency_ms: 100,
683                    },
684                })
685                .unwrap();
686        }
687
688        assert_eq!(engine.provider_in_cooldown("openai"), Some(true));
689
690        // Advance past cooldown (30s)
691        clock.advance_secs(31);
692
693        assert_eq!(engine.provider_in_cooldown("openai"), Some(false));
694
695        // openai should be selectable again (though health is degraded)
696        let decision = engine.select("chat").unwrap();
697        // Either provider could win depending on health recovery
698        assert!(!decision.provider.is_empty());
699    }
700}