1use std::sync::Arc;
2
3use dashmap::DashMap;
4use parking_lot::RwLock;
5use smallvec::SmallVec;
6
7use crate::capability::{CapabilityConfig, CapabilityProvider};
8use crate::clock::{Clock, Timestamp};
9use crate::config::EngineConfig;
10use crate::decision::{Alternative, Decision, ScoreBreakdown};
11use crate::error::{Error, Result};
12use crate::health::HealthState;
13use crate::metrics::Metrics;
14use crate::observation::{Observation, StatusClass};
15use crate::provider::ProviderConfig;
16use crate::quota::{self, Dimension, QuotaConfig};
17use crate::scoring::{ProviderScoreContext, ScoringStrategy, WeightedScorer};
18
19#[derive(Clone)]
32pub struct GrateLimiter {
33 inner: Arc<Inner>,
34}
35
36struct Inner {
37 providers: DashMap<String, ProviderRuntime>,
39 capabilities: RwLock<DashMap<String, CapabilityDef>>,
41 scorer: Box<dyn ScoringStrategy>,
43 config: EngineConfig,
45 clock: Arc<dyn Clock>,
47 metrics: Metrics,
49}
50
51struct ProviderRuntime {
53 config: ProviderConfig,
54 health: RwLock<HealthState>,
55 quota_trackers: Vec<(QuotaConfig, Box<dyn crate::quota::QuotaTracker>)>,
56}
57
58struct CapabilityDef {
60 providers: SmallVec<[CapabilityProvider; 4]>,
61}
62
63impl GrateLimiter {
64 pub fn new(config: EngineConfig) -> Self {
66 let clock = config.clock();
67 let scorer = Box::new(WeightedScorer::new(config.scoring.clone()));
68
69 Self {
70 inner: Arc::new(Inner {
71 providers: DashMap::new(),
72 capabilities: RwLock::new(DashMap::new()),
73 scorer,
74 config,
75 clock,
76 metrics: Metrics::new(),
77 }),
78 }
79 }
80
81 pub fn upsert_provider(&self, config: ProviderConfig) {
86 let now = self.inner.clock.now();
87 let trackers: Vec<_> = config
88 .quotas
89 .iter()
90 .map(|qc| (qc.clone(), quota::create_tracker(qc, now)))
91 .collect();
92
93 if let Some(mut existing) = self.inner.providers.get_mut(&config.name) {
94 existing.config = config;
96 existing.quota_trackers = trackers;
97 } else {
98 self.inner.providers.insert(
99 config.name.clone(),
100 ProviderRuntime {
101 config,
102 health: RwLock::new(HealthState::new(now)),
103 quota_trackers: trackers,
104 },
105 );
106 }
107 }
108
109 pub fn upsert_capability(&self, config: CapabilityConfig) {
111 let caps = self.inner.capabilities.read();
112 caps.insert(
113 config.name.clone(),
114 CapabilityDef {
115 providers: SmallVec::from_vec(config.providers),
116 },
117 );
118 }
119
120 pub fn select(&self, capability: &str) -> Result<Decision> {
130 self.inner
131 .metrics
132 .selects
133 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
134
135 let now = self.inner.clock.now();
136
137 let caps = self.inner.capabilities.read();
139 let cap_def = caps
140 .get(capability)
141 .ok_or_else(|| Error::UnknownCapability(capability.to_string()))?;
142
143 let cap_providers = &cap_def.providers;
144 if cap_providers.is_empty() {
145 return Err(Error::NoAvailableProviders(capability.to_string()));
146 }
147
148 let max_priority = cap_providers.iter().map(|p| p.priority).max().unwrap_or(1);
150
151 let mut max_latency_ms: f64 = 0.0;
152 for cp in cap_providers.iter() {
153 if let Some(pr) = self.inner.providers.get(&cp.provider) {
154 let health = pr.health.read();
155 if health.latency_ms() > max_latency_ms {
156 max_latency_ms = health.latency_ms();
157 }
158 }
159 }
160 if max_latency_ms <= 0.0 {
161 max_latency_ms = 1.0;
162 }
163
164 let mut candidates: SmallVec<[(String, f32, ScoreBreakdown); 4]> = SmallVec::new();
166
167 for cp in cap_providers.iter() {
168 let Some(pr) = self.inner.providers.get(&cp.provider) else {
169 continue;
170 };
171
172 let health = pr.health.read();
173
174 if health.is_in_cooldown(now) {
176 continue;
177 }
178
179 if health.score() < self.inner.config.minimum_health_score {
181 continue;
182 }
183
184 let (quota_remaining_ratio, predicted_exhaustion, burn_rate) =
186 self.worst_quota_state(&pr.quota_trackers, now);
187
188 let ctx = ProviderScoreContext {
189 quota_remaining_ratio,
190 predicted_exhaustion_secs: predicted_exhaustion,
191 burn_rate,
192 health_score: health.score(),
193 priority: cp.priority,
194 max_priority,
195 latency_ms: health.latency_ms(),
196 max_latency_ms,
197 };
198
199 let score = self.inner.scorer.score(&ctx);
200 let breakdown = ScoreBreakdown {
201 quota_score: ctx.quota_remaining_ratio as f32,
202 health_score: ctx.health_score,
203 priority_score: cp.priority as f32 / max_priority as f32,
204 latency_score: if max_latency_ms > 0.0 {
205 (1.0 - (ctx.latency_ms / max_latency_ms) as f32).max(0.0)
206 } else {
207 1.0
208 },
209 };
210
211 candidates.push((cp.provider.clone(), score, breakdown));
212 }
213
214 drop(cap_def);
215 drop(caps);
216
217 if candidates.is_empty() {
218 self.inner
219 .metrics
220 .no_provider_available
221 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
222 return Err(Error::NoAvailableProviders(capability.to_string()));
223 }
224
225 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
227
228 let (provider, score, reasoning) = candidates.remove(0);
229 let alternatives = candidates
230 .into_iter()
231 .map(|(p, s, _)| Alternative {
232 provider: p,
233 score: s,
234 })
235 .collect();
236
237 Ok(Decision {
238 provider,
239 score,
240 reasoning,
241 alternatives,
242 })
243 }
244
245 pub fn observe(&self, obs: Observation) -> Result<()> {
253 self.inner
254 .metrics
255 .observations
256 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
257
258 let now = self.inner.clock.now();
259
260 let pr = self
261 .inner
262 .providers
263 .get(&obs.provider)
264 .ok_or_else(|| Error::UnknownProvider(obs.provider.clone()))?;
265
266 for (qc, tracker) in &pr.quota_trackers {
268 let amount = match qc.dimension {
269 Dimension::Requests => obs.usage.requests,
270 Dimension::Tokens => obs.usage.tokens.unwrap_or(0),
271 Dimension::Bytes => obs.usage.bytes.unwrap_or(0),
272 Dimension::CostUsd => obs.usage.cost_micro_usd.unwrap_or(0),
273 Dimension::Concurrency => obs.usage.requests, };
275 if amount > 0 {
276 tracker.record(amount, now);
277 }
278 }
279
280 let cooldown_secs = pr.config.cooldown_seconds;
282 let health_config = &self.inner.config.health;
283 let mut health = pr.health.write();
284 let was_in_cooldown = health.is_in_cooldown(now);
285
286 match obs.outcome.status {
287 StatusClass::Success | StatusClass::ClientError => {
288 health.record_success(obs.outcome.latency_ms, now, health_config);
289 }
290 StatusClass::RateLimited => {
291 health.record_rate_limited(now, health_config, cooldown_secs);
292 }
293 StatusClass::Forbidden => {
294 health.record_forbidden(now, health_config, cooldown_secs);
295 }
296 StatusClass::ServerError => {
297 health.record_server_error(now, health_config, cooldown_secs);
298 }
299 StatusClass::Timeout => {
300 health.record_timeout(now, health_config, cooldown_secs);
301 }
302 }
303
304 if !was_in_cooldown && health.is_in_cooldown(now) {
306 self.inner
307 .metrics
308 .cooldowns_triggered
309 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
310 }
311
312 Ok(())
313 }
314
315 pub fn metrics(&self) -> &Metrics {
317 &self.inner.metrics
318 }
319
320 pub fn provider_health(&self, provider: &str) -> Option<f32> {
322 self.inner
323 .providers
324 .get(provider)
325 .map(|pr| pr.health.read().score())
326 }
327
328 pub fn provider_in_cooldown(&self, provider: &str) -> Option<bool> {
330 let now = self.inner.clock.now();
331 self.inner
332 .providers
333 .get(provider)
334 .map(|pr| pr.health.read().is_in_cooldown(now))
335 }
336
337 pub fn provider_quota_remaining(&self, provider: &str, dimension: Dimension) -> Option<u64> {
339 let now = self.inner.clock.now();
340 self.inner.providers.get(provider).and_then(|pr| {
341 pr.quota_trackers
342 .iter()
343 .find(|(qc, _)| qc.dimension == dimension)
344 .map(|(_, tracker)| tracker.remaining(now))
345 })
346 }
347
348 fn worst_quota_state(
350 &self,
351 trackers: &[(QuotaConfig, Box<dyn crate::quota::QuotaTracker>)],
352 now: Timestamp,
353 ) -> (f64, f64, f64) {
354 if trackers.is_empty() {
355 return (1.0, f64::INFINITY, 0.0);
356 }
357
358 let mut worst_remaining = 1.0_f64;
359 let mut worst_exhaustion = f64::INFINITY;
360 let mut max_burn_rate = 0.0_f64;
361
362 for (_, tracker) in trackers {
363 let remaining = 1.0 - tracker.usage_ratio(now);
364 let exhaustion = tracker.predicted_exhaustion_secs(now);
365 let burn = tracker.burn_rate(now);
366
367 if remaining < worst_remaining {
368 worst_remaining = remaining;
369 }
370 if exhaustion < worst_exhaustion {
371 worst_exhaustion = exhaustion;
372 }
373 if burn > max_burn_rate {
374 max_burn_rate = burn;
375 }
376 }
377
378 (worst_remaining, worst_exhaustion, max_burn_rate)
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::clock::MockClock;
    use crate::observation::{Outcome, Usage};
    use crate::quota::Window;

    // Builds an `Observation` whose usage consists only of a request count.
    // A macro (rather than a fn) lets field types be inferred from the structs.
    macro_rules! observation {
        ($provider:expr, $capability:expr, requests: $requests:expr, $status:expr, latency: $latency:expr $(,)?) => {
            Observation {
                provider: $provider.into(),
                capability: $capability,
                usage: Usage {
                    requests: $requests,
                    ..Default::default()
                },
                outcome: Outcome {
                    status: $status,
                    latency_ms: $latency,
                },
            }
        };
    }

    /// Engine with two providers ("openai": 100 req/min, priority 10;
    /// "anthropic": 80 req/min, priority 8) both serving the "chat" capability.
    fn setup_engine() -> (GrateLimiter, Arc<MockClock>) {
        let clock = Arc::new(MockClock::new());
        let engine = GrateLimiter::new(EngineConfig::default().with_clock(clock.clone()));

        for (name, limit, priority) in [("openai", 100, 10), ("anthropic", 80, 8)] {
            engine.upsert_provider(ProviderConfig {
                name: name.into(),
                quotas: vec![QuotaConfig {
                    dimension: Dimension::Requests,
                    limit,
                    window: Some(Window::Minute),
                }],
                priority,
                weight: 1.0,
                cooldown_seconds: 30,
            });
        }

        engine.upsert_capability(CapabilityConfig {
            name: "chat".into(),
            providers: vec![("openai", 10), ("anthropic", 8)]
                .into_iter()
                .map(|(provider, priority)| CapabilityProvider {
                    provider: provider.into(),
                    priority,
                })
                .collect(),
        });

        (engine, clock)
    }

    #[test]
    fn select_returns_best_provider() {
        let (engine, _clock) = setup_engine();
        let decision = engine.select("chat").unwrap();

        assert_eq!(decision.provider, "openai");
        assert!(decision.score > 0.0);
        assert_eq!(decision.alternatives.len(), 1);
    }

    #[test]
    fn select_unknown_capability_errors() {
        let (engine, _clock) = setup_engine();
        assert!(matches!(
            engine.select("nonexistent"),
            Err(Error::UnknownCapability(_))
        ));
    }

    #[test]
    fn observe_updates_health() {
        let (engine, _clock) = setup_engine();

        engine
            .observe(observation!(
                "openai",
                Some("chat".into()),
                requests: 1,
                StatusClass::RateLimited,
                latency: 100,
            ))
            .unwrap();

        assert!(engine.provider_health("openai").unwrap() < 1.0);
    }

    #[test]
    fn observe_unknown_provider_errors() {
        let (engine, _clock) = setup_engine();
        let outcome = engine.observe(Observation {
            provider: "nonexistent".into(),
            capability: None,
            usage: Usage::default(),
            outcome: Outcome {
                status: StatusClass::Success,
                latency_ms: 100,
            },
        });
        assert!(matches!(outcome, Err(Error::UnknownProvider(_))));
    }

    #[test]
    fn degraded_provider_loses_to_healthy() {
        let (engine, clock) = setup_engine();

        for _ in 0..3 {
            clock.advance_ms(1000);
            engine
                .observe(observation!(
                    "openai",
                    Some("chat".into()),
                    requests: 1,
                    StatusClass::RateLimited,
                    latency: 100,
                ))
                .unwrap();
        }

        assert_eq!(engine.select("chat").unwrap().provider, "anthropic");
    }

    #[test]
    fn metrics_increment() {
        let (engine, _clock) = setup_engine();

        for _ in 0..2 {
            engine.select("chat").unwrap();
        }
        assert_eq!(engine.metrics().selects(), 2);

        engine
            .observe(observation!(
                "openai",
                None,
                requests: 1,
                StatusClass::Success,
                latency: 50,
            ))
            .unwrap();
        assert_eq!(engine.metrics().observations(), 1);
    }

    #[test]
    fn provider_quota_tracking() {
        let (engine, _clock) = setup_engine();

        assert_eq!(
            engine.provider_quota_remaining("openai", Dimension::Requests),
            Some(100)
        );

        engine
            .observe(observation!(
                "openai",
                None,
                requests: 30,
                StatusClass::Success,
                latency: 100,
            ))
            .unwrap();

        let left = engine
            .provider_quota_remaining("openai", Dimension::Requests)
            .unwrap();
        assert_eq!(left, 70);
    }

    #[test]
    fn upsert_provider_preserves_health() {
        let (engine, _clock) = setup_engine();

        engine
            .observe(observation!(
                "openai",
                None,
                requests: 1,
                StatusClass::ServerError,
                latency: 100,
            ))
            .unwrap();
        let before = engine.provider_health("openai").unwrap();

        engine.upsert_provider(ProviderConfig {
            name: "openai".into(),
            quotas: vec![QuotaConfig {
                dimension: Dimension::Requests,
                limit: 200,
                window: Some(Window::Minute),
            }],
            priority: 10,
            weight: 1.0,
            cooldown_seconds: 30,
        });
        let after = engine.provider_health("openai").unwrap();

        // Health carries over, while the quota reflects the new limit.
        assert!((before - after).abs() < 0.01);
        assert_eq!(
            engine.provider_quota_remaining("openai", Dimension::Requests),
            Some(200)
        );
    }

    #[test]
    fn engine_is_clone_and_send() {
        let (engine, _) = setup_engine();
        let worker = engine.clone();

        let decision = std::thread::spawn(move || worker.select("chat").unwrap())
            .join()
            .unwrap();
        assert!(!decision.provider.is_empty());
    }

    #[test]
    fn anticipatory_routing_under_pressure() {
        let (engine, clock) = setup_engine();

        for _ in 0..90 {
            engine
                .observe(observation!(
                    "openai",
                    Some("chat".into()),
                    requests: 1,
                    StatusClass::Success,
                    latency: 50,
                ))
                .unwrap();
        }
        clock.advance_ms(5000);

        let decision = engine.select("chat").unwrap();
        assert_eq!(
            decision.provider, "anthropic",
            "Anticipatory routing should prefer anthropic when openai is nearly exhausted"
        );
    }

    #[test]
    fn cooldown_expires_and_provider_recovers() {
        let (engine, clock) = setup_engine();

        for _ in 0..3 {
            clock.advance_ms(100);
            engine
                .observe(observation!(
                    "openai",
                    None,
                    requests: 1,
                    StatusClass::RateLimited,
                    latency: 100,
                ))
                .unwrap();
        }
        assert_eq!(engine.provider_in_cooldown("openai"), Some(true));

        clock.advance_secs(31);
        assert_eq!(engine.provider_in_cooldown("openai"), Some(false));

        let decision = engine.select("chat").unwrap();
        assert!(!decision.provider.is_empty());
    }
}