Skip to main content

hyperinfer_router/strategy/
weighted_shuffle.rs

1use super::{DeploymentMetrics, RoutingContext, RoutingState, RoutingStrategy};
2use crate::deployment::Deployment;
3use crate::error::RoutingError;
4use async_trait::async_trait;
5use rand::Rng;
6use std::sync::Arc;
7
/// Stateless routing strategy that draws one deployment at random, each
/// candidate's chance being proportional to its configured weight scaled by
/// the RPM/TPM headroom it still has.
#[derive(Clone, Debug)]
pub struct WeightedShuffle;
10
11impl WeightedShuffle {
12    pub fn new() -> Self {
13        Self
14    }
15
16    fn effective_weight(deployment: &Deployment, metrics: &DeploymentMetrics) -> f64 {
17        let base_weight = deployment.weight as f64;
18
19        let rpm_ratio = match deployment.rpm_limit {
20            Some(limit) if limit > 0 => 1.0 - (metrics.rpm_used as f64 / limit as f64),
21            _ => 1.0,
22        };
23
24        let tpm_ratio = match deployment.tpm_limit {
25            Some(limit) if limit > 0 => 1.0 - (metrics.tpm_used as f64 / limit as f64),
26            _ => 1.0,
27        };
28
29        let capacity = rpm_ratio.min(tpm_ratio).max(0.0);
30        base_weight * capacity
31    }
32}
33
34impl Default for WeightedShuffle {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40#[async_trait]
41impl RoutingStrategy for WeightedShuffle {
42    fn name(&self) -> &str {
43        "weighted-shuffle"
44    }
45
46    async fn select<'a>(
47        &self,
48        _model: &str,
49        candidates: &'a [Arc<Deployment>],
50        state: &dyn RoutingState,
51        _request: &RoutingContext,
52    ) -> Result<&'a Arc<Deployment>, RoutingError> {
53        if candidates.is_empty() {
54            return Err(RoutingError::NoDeployments("empty candidates".into()));
55        }
56
57        let ids: Vec<&str> = candidates.iter().map(|d| d.id.as_str()).collect();
58        let all_metrics = state.get_all_metrics(&ids).await?;
59
60        let mut eligible: Vec<(usize, f64)> = Vec::new();
61
62        for (i, deployment) in candidates.iter().enumerate() {
63            if state.is_cooled_down(&deployment.id).await? {
64                continue;
65            }
66
67            let metrics = all_metrics.get(&deployment.id).cloned().unwrap_or_default();
68
69            let ew = Self::effective_weight(deployment, &metrics);
70            if ew > 0.0 {
71                eligible.push((i, ew));
72            }
73        }
74
75        if eligible.is_empty() {
76            return Err(RoutingError::NoDeployments(
77                "no eligible deployments after filtering".into(),
78            ));
79        }
80
81        let filtered_candidates: Vec<&Arc<Deployment>> =
82            eligible.iter().map(|(i, _)| &candidates[*i]).collect();
83        let weights: Vec<f64> = eligible.iter().map(|(_, w)| *w).collect();
84
85        let selected = Self::weighted_select_owned(&filtered_candidates, &weights);
86        let selected_id = &selected.id;
87
88        candidates
89            .iter()
90            .find(|d| d.id == *selected_id)
91            .ok_or_else(|| RoutingError::NoDeployments("selected deployment not found".into()))
92    }
93}
94
95impl WeightedShuffle {
96    fn weighted_select_owned(candidates: &[&Arc<Deployment>], weights: &[f64]) -> Arc<Deployment> {
97        let total_weight: f64 = weights.iter().sum();
98        let mut rng = rand::thread_rng();
99        let mut threshold = rng.gen_range(0.0..total_weight);
100
101        for (i, weight) in weights.iter().enumerate() {
102            threshold -= weight;
103            if threshold <= 0.0 {
104                return Arc::clone(candidates[i]);
105            }
106        }
107
108        Arc::clone(candidates.last().unwrap())
109    }
110}
111
#[cfg(test)]
pub mod tests_helpers {
    use super::super::{DeploymentMetrics, RecordFailureResult, RoutingError, RoutingState};
    use async_trait::async_trait;
    use std::collections::HashMap;

    /// In-memory [`RoutingState`] for strategy tests.
    ///
    /// Metrics and cooldown flags are seeded through the builder methods and
    /// never change afterwards. The `record_*` methods ignore their arguments
    /// and always succeed.
    #[derive(Debug, Clone, Default)]
    pub struct MockState {
        /// Seeded per-deployment metrics, keyed by deployment id.
        pub metrics: HashMap<String, DeploymentMetrics>,
        /// Seeded cooldown flags; ids not present are treated as not cooled down.
        pub cooled_down: HashMap<String, bool>,
    }

    impl MockState {
        /// Returns a state with no metrics and no cooldowns.
        pub fn new() -> Self {
            MockState::default()
        }

        /// Builder: records `metrics` as the metrics for deployment `id`.
        pub fn with_metrics(self, id: &str, metrics: DeploymentMetrics) -> Self {
            let mut state = self;
            state.metrics.insert(id.to_owned(), metrics);
            state
        }

        /// Builder: marks deployment `id` as cooled down.
        pub fn with_cooldown(self, id: &str) -> Self {
            let mut state = self;
            state.cooled_down.insert(id.to_owned(), true);
            state
        }
    }

    #[async_trait]
    impl RoutingState for MockState {
        /// Returns the seeded metrics for `deployment_id`, or the default if none were seeded.
        async fn get_metrics(
            &self,
            deployment_id: &str,
        ) -> Result<DeploymentMetrics, RoutingError> {
            let found = self.metrics.get(deployment_id);
            Ok(found.cloned().unwrap_or_default())
        }

        /// Returns seeded metrics for the requested ids. Ids with no seeded
        /// metrics are omitted from the map rather than defaulted.
        async fn get_all_metrics(
            &self,
            ids: &[&str],
        ) -> Result<HashMap<String, DeploymentMetrics>, RoutingError> {
            let known = ids.iter().filter_map(|&id| {
                self.metrics
                    .get(id)
                    .map(|m| (id.to_owned(), m.clone()))
            });
            Ok(known.collect())
        }

        async fn is_cooled_down(&self, deployment_id: &str) -> Result<bool, RoutingError> {
            Ok(matches!(self.cooled_down.get(deployment_id), Some(&true)))
        }

        async fn record_request_start(&self, _deployment_id: &str) -> Result<(), RoutingError> {
            Ok(())
        }

        async fn record_request_success(
            &self,
            _deployment_id: &str,
            _latency_ms: f64,
            _tokens: u64,
        ) -> Result<(), RoutingError> {
            Ok(())
        }

        async fn record_request_failure(
            &self,
            _deployment_id: &str,
        ) -> Result<RecordFailureResult, RoutingError> {
            Ok(RecordFailureResult {
                failure_count: 0,
                cooldown_triggered: false,
            })
        }
    }
}
194
#[cfg(test)]
mod tests {
    use super::super::DeploymentMetrics;
    use super::tests_helpers::MockState;
    use super::*;
    use crate::deployment::Deployment;
    use hyperinfer_core::Provider;

    /// Builds a "test-model" deployment with the given id and weight. All
    /// other fields keep whatever `Deployment::new` assigns.
    fn base_deployment(id: &str, weight: u32) -> Deployment {
        let mut deployment = Deployment::new(
            "test-model".to_string(),
            Provider::OpenAI,
            "gpt-4".to_string(),
            format!("key-{}", id),
        );
        deployment.weight = weight;
        deployment.id = id.to_string();
        deployment
    }

    fn make_deployment(id: &str, weight: u32) -> Arc<Deployment> {
        Arc::new(base_deployment(id, weight))
    }

    fn make_deployment_with_limits(
        id: &str,
        weight: u32,
        rpm_limit: Option<u64>,
        tpm_limit: Option<u64>,
    ) -> Arc<Deployment> {
        let mut deployment = base_deployment(id, weight);
        deployment.rpm_limit = rpm_limit;
        deployment.tpm_limit = tpm_limit;
        Arc::new(deployment)
    }

    /// Runs one `select` for "test-model" with a fresh strategy and a default context.
    async fn select_from<'a>(
        candidates: &'a [Arc<Deployment>],
        state: &MockState,
    ) -> Result<&'a Arc<Deployment>, RoutingError> {
        WeightedShuffle::new()
            .select("test-model", candidates, state, &RoutingContext::default())
            .await
    }

    #[tokio::test]
    async fn test_single_candidate() {
        let candidates = vec![make_deployment("d1", 1)];
        let state = MockState::new();

        let picked = select_from(&candidates, &state).await.unwrap();
        assert_eq!(picked.id, "d1");
    }

    #[tokio::test]
    async fn test_empty_candidates() {
        let candidates: Vec<Arc<Deployment>> = Vec::new();
        let state = MockState::new();

        let err = select_from(&candidates, &state).await.unwrap_err();
        assert!(matches!(err, RoutingError::NoDeployments(_)));
    }

    #[tokio::test]
    async fn test_cooled_down_excluded() {
        let candidates = vec![make_deployment("d1", 1), make_deployment("d2", 1)];
        let state = MockState::new().with_cooldown("d1");

        let picked = select_from(&candidates, &state).await.unwrap();
        assert_eq!(picked.id, "d2");
    }

    #[tokio::test]
    async fn test_all_cooled_down_returns_error() {
        let candidates = vec![make_deployment("d1", 1), make_deployment("d2", 1)];
        let state = MockState::new().with_cooldown("d1").with_cooldown("d2");

        let err = select_from(&candidates, &state).await.unwrap_err();
        assert!(matches!(err, RoutingError::NoDeployments(_)));
    }

    #[tokio::test]
    async fn test_at_capacity_excluded() {
        let candidates = vec![
            make_deployment_with_limits("d1", 5, Some(100), None),
            make_deployment_with_limits("d2", 1, None, None),
        ];
        let saturated = DeploymentMetrics {
            rpm_used: 100,
            ..Default::default()
        };
        let state = MockState::new().with_metrics("d1", saturated);

        let picked = select_from(&candidates, &state).await.unwrap();
        assert_eq!(picked.id, "d2");
    }

    #[tokio::test]
    async fn test_weight_distribution() {
        let candidates = vec![make_deployment("d1", 9), make_deployment("d2", 1)];
        let state = MockState::new();

        let iterations = 10000;
        let mut d1_count = 0u32;
        for _ in 0..iterations {
            if select_from(&candidates, &state).await.unwrap().id == "d1" {
                d1_count += 1;
            }
        }

        let ratio = f64::from(d1_count) / iterations as f64;
        assert!(
            ratio > 0.80 && ratio < 0.98,
            "expected d1 ratio between 80-98%, got {:.2}%",
            ratio * 100.0
        );
    }

    #[test]
    fn test_effective_weight_no_limits() {
        let deployment = make_deployment("d1", 5);
        let weight = WeightedShuffle::effective_weight(&deployment, &DeploymentMetrics::default());
        assert!((weight - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_effective_weight_at_rpm_limit() {
        let deployment = make_deployment_with_limits("d1", 5, Some(100), None);
        let usage = DeploymentMetrics {
            rpm_used: 100,
            ..Default::default()
        };
        let weight = WeightedShuffle::effective_weight(&deployment, &usage);
        assert!(weight.abs() < f64::EPSILON);
    }

    #[test]
    fn test_effective_weight_half_capacity() {
        let deployment = make_deployment_with_limits("d1", 10, Some(100), None);
        let usage = DeploymentMetrics {
            rpm_used: 50,
            ..Default::default()
        };
        let weight = WeightedShuffle::effective_weight(&deployment, &usage);
        assert!((weight - 5.0).abs() < f64::EPSILON);
    }
}