Skip to main content

grate_limiter/
scoring.rs

1use serde::{Deserialize, Serialize};
2
3/// Weights for the composite scoring algorithm.
4///
5/// All weights should sum to 1.0 for normalized scoring.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ScoringWeights {
8    /// Weight for remaining quota percentage (anticipatory).
9    pub quota: f32,
10    /// Weight for provider health score.
11    pub health: f32,
12    /// Weight for capability-level priority.
13    pub priority: f32,
14    /// Weight for latency score.
15    pub latency: f32,
16}
17
18impl Default for ScoringWeights {
19    fn default() -> Self {
20        Self {
21            quota: 0.40,
22            health: 0.35,
23            priority: 0.20,
24            latency: 0.05,
25        }
26    }
27}
28
29/// Trait for pluggable scoring strategies.
30///
31/// Implement this to customize how providers are ranked.
32pub trait ScoringStrategy: Send + Sync {
33    /// Score a provider given its context. Returns a value in [0.0, 1.0].
34    fn score(&self, ctx: &ProviderScoreContext) -> f32;
35}
36
/// Snapshot of one provider's state, handed to a scoring strategy to rate that provider.
#[derive(Debug, Clone)]
pub struct ProviderScoreContext {
    /// Fraction of quota still available, in `[0.0, 1.0]`; 1.0 means fully available.
    pub quota_remaining_ratio: f64,
    /// Forecast of how many seconds remain until the quota is exhausted.
    pub predicted_exhaustion_secs: f64,
    /// Rate of quota consumption, in units per second.
    pub burn_rate: f64,
    /// Provider health, in `[0.0, 1.0]`.
    pub health_score: f32,
    /// This provider's capability-level priority; larger values are preferred.
    pub priority: u16,
    /// Largest priority among all providers for this capability.
    pub max_priority: u16,
    /// EWMA (exponentially weighted moving average) latency, in milliseconds.
    pub latency_ms: f64,
    /// Highest latency observed across the candidates; used for normalization.
    pub max_latency_ms: f64,
}
57
58/// Default weighted composite scorer.
59pub(crate) struct WeightedScorer {
60    pub(crate) weights: ScoringWeights,
61}
62
63impl WeightedScorer {
64    pub(crate) fn new(weights: ScoringWeights) -> Self {
65        Self { weights }
66    }
67
68    /// Compute the quota sub-score with anticipatory exhaustion prediction.
69    fn quota_score(ctx: &ProviderScoreContext) -> f32 {
70        let base = ctx.quota_remaining_ratio as f32;
71
72        // Anticipatory penalty: if exhaustion is predicted soon, reduce score aggressively
73        let exhaustion_penalty = if ctx.predicted_exhaustion_secs < 10.0 {
74            0.8 // Severe penalty — exhaustion imminent
75        } else if ctx.predicted_exhaustion_secs < 30.0 {
76            0.5
77        } else if ctx.predicted_exhaustion_secs < 60.0 {
78            0.3
79        } else if ctx.predicted_exhaustion_secs < 120.0 {
80            0.1
81        } else {
82            0.0
83        };
84
85        // Burn rate penalty: fast consumption rate reduces confidence
86        let burn_penalty = if ctx.burn_rate > 0.0 && ctx.quota_remaining_ratio < 0.5 {
87            0.1
88        } else {
89            0.0
90        };
91
92        (base - exhaustion_penalty - burn_penalty).max(0.0)
93    }
94
95    /// Compute the priority sub-score normalized to [0.0, 1.0].
96    fn priority_score(ctx: &ProviderScoreContext) -> f32 {
97        if ctx.max_priority == 0 {
98            return 0.5;
99        }
100        ctx.priority as f32 / ctx.max_priority as f32
101    }
102
103    /// Compute the latency sub-score (lower latency = higher score).
104    fn latency_score(ctx: &ProviderScoreContext) -> f32 {
105        if ctx.max_latency_ms <= 0.0 || ctx.latency_ms <= 0.0 {
106            return 1.0; // No latency data — assume fine
107        }
108        (1.0 - (ctx.latency_ms / ctx.max_latency_ms) as f32).max(0.0)
109    }
110}
111
112impl ScoringStrategy for WeightedScorer {
113    fn score(&self, ctx: &ProviderScoreContext) -> f32 {
114        let qs = Self::quota_score(ctx);
115        let hs = ctx.health_score;
116        let ps = Self::priority_score(ctx);
117        let ls = Self::latency_score(ctx);
118
119        let final_score = qs * self.weights.quota
120            + hs * self.weights.health
121            + ps * self.weights.priority
122            + ls * self.weights.latency;
123
124        final_score.clamp(0.0, 1.0)
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline context: full quota, no predicted exhaustion, perfect health,
    /// top priority, and latency at half of the observed maximum.
    fn default_ctx() -> ProviderScoreContext {
        ProviderScoreContext {
            quota_remaining_ratio: 1.0,
            predicted_exhaustion_secs: f64::INFINITY,
            burn_rate: 0.0,
            health_score: 1.0,
            priority: 10,
            max_priority: 10,
            latency_ms: 100.0,
            max_latency_ms: 200.0,
        }
    }

    /// Scores `ctx` with a `WeightedScorer` built from the default weights.
    fn score_with_defaults(ctx: &ProviderScoreContext) -> f32 {
        WeightedScorer::new(ScoringWeights::default()).score(ctx)
    }

    #[test]
    fn perfect_provider_scores_high() {
        let score = score_with_defaults(&default_ctx());
        assert!(score > 0.9, "score={score}");
    }

    #[test]
    fn exhausted_provider_scores_low() {
        let nearly_exhausted = ProviderScoreContext {
            quota_remaining_ratio: 0.05,
            predicted_exhaustion_secs: 5.0,
            health_score: 0.5,
            ..default_ctx()
        };
        let score = score_with_defaults(&nearly_exhausted);
        assert!(score < 0.5, "score={score}");
    }

    #[test]
    fn unhealthy_provider_scores_low() {
        let unhealthy = ProviderScoreContext {
            health_score: 0.2,
            ..default_ctx()
        };
        let score = score_with_defaults(&unhealthy);
        assert!(score < 0.8, "score={score}");
    }

    #[test]
    fn low_priority_scores_lower() {
        let deprioritized = ProviderScoreContext {
            priority: 2,
            ..default_ctx()
        };
        let high = score_with_defaults(&default_ctx());
        let low = score_with_defaults(&deprioritized);
        assert!(high > low);
    }

    #[test]
    fn anticipatory_penalty_kicks_in() {
        // Both providers have the same 30% of quota left; they differ in how soon
        // exhaustion is predicted and in how fast the quota is being consumed.
        let remaining = 0.3;
        let fast_burn = ProviderScoreContext {
            quota_remaining_ratio: remaining,
            predicted_exhaustion_secs: 20.0, // predicted to run out in 20s
            burn_rate: 50.0,
            ..default_ctx()
        };
        let slow_burn = ProviderScoreContext {
            quota_remaining_ratio: remaining,
            predicted_exhaustion_secs: 300.0,
            burn_rate: 1.0,
            ..default_ctx()
        };

        let fast_score = score_with_defaults(&fast_burn);
        let slow_score = score_with_defaults(&slow_burn);
        assert!(
            slow_score > fast_score,
            "slow={slow_score} fast={fast_score}"
        );
    }

    #[test]
    fn score_always_bounded() {
        let unit_interval = 0.0..=1.0;

        // Worst case: nothing left, unhealthy, lowest priority, slowest candidate.
        let worst = ProviderScoreContext {
            quota_remaining_ratio: 0.0,
            predicted_exhaustion_secs: 0.0,
            burn_rate: 1000.0,
            health_score: 0.0,
            priority: 0,
            max_priority: 10,
            latency_ms: 5000.0,
            max_latency_ms: 5000.0,
        };
        let score = score_with_defaults(&worst);
        assert!(unit_interval.contains(&score), "score={score}");

        // Best case: the baseline context.
        let score = score_with_defaults(&default_ctx());
        assert!(unit_interval.contains(&score), "score={score}");
    }
}