hyperinfer_router/strategy/
weighted_shuffle.rs1use super::{DeploymentMetrics, RoutingContext, RoutingState, RoutingStrategy};
2use crate::deployment::Deployment;
3use crate::error::RoutingError;
4use async_trait::async_trait;
5use rand::Rng;
6use std::sync::Arc;
7
/// Routing strategy that draws a deployment at random, weighting each
/// candidate by its configured `weight` scaled by its remaining rate-limit
/// headroom.
///
/// The type itself holds no state.
#[derive(Clone, Debug)]
pub struct WeightedShuffle;
10
11impl WeightedShuffle {
12 pub fn new() -> Self {
13 Self
14 }
15
16 fn effective_weight(deployment: &Deployment, metrics: &DeploymentMetrics) -> f64 {
17 let base_weight = deployment.weight as f64;
18
19 let rpm_ratio = match deployment.rpm_limit {
20 Some(limit) if limit > 0 => 1.0 - (metrics.rpm_used as f64 / limit as f64),
21 _ => 1.0,
22 };
23
24 let tpm_ratio = match deployment.tpm_limit {
25 Some(limit) if limit > 0 => 1.0 - (metrics.tpm_used as f64 / limit as f64),
26 _ => 1.0,
27 };
28
29 let capacity = rpm_ratio.min(tpm_ratio).max(0.0);
30 base_weight * capacity
31 }
32}
33
34impl Default for WeightedShuffle {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40#[async_trait]
41impl RoutingStrategy for WeightedShuffle {
42 fn name(&self) -> &str {
43 "weighted-shuffle"
44 }
45
46 async fn select<'a>(
47 &self,
48 _model: &str,
49 candidates: &'a [Arc<Deployment>],
50 state: &dyn RoutingState,
51 _request: &RoutingContext,
52 ) -> Result<&'a Arc<Deployment>, RoutingError> {
53 if candidates.is_empty() {
54 return Err(RoutingError::NoDeployments("empty candidates".into()));
55 }
56
57 let ids: Vec<&str> = candidates.iter().map(|d| d.id.as_str()).collect();
58 let all_metrics = state.get_all_metrics(&ids).await?;
59
60 let mut eligible: Vec<(usize, f64)> = Vec::new();
61
62 for (i, deployment) in candidates.iter().enumerate() {
63 if state.is_cooled_down(&deployment.id).await? {
64 continue;
65 }
66
67 let metrics = all_metrics.get(&deployment.id).cloned().unwrap_or_default();
68
69 let ew = Self::effective_weight(deployment, &metrics);
70 if ew > 0.0 {
71 eligible.push((i, ew));
72 }
73 }
74
75 if eligible.is_empty() {
76 return Err(RoutingError::NoDeployments(
77 "no eligible deployments after filtering".into(),
78 ));
79 }
80
81 let filtered_candidates: Vec<&Arc<Deployment>> =
82 eligible.iter().map(|(i, _)| &candidates[*i]).collect();
83 let weights: Vec<f64> = eligible.iter().map(|(_, w)| *w).collect();
84
85 let selected = Self::weighted_select_owned(&filtered_candidates, &weights);
86 let selected_id = &selected.id;
87
88 candidates
89 .iter()
90 .find(|d| d.id == *selected_id)
91 .ok_or_else(|| RoutingError::NoDeployments("selected deployment not found".into()))
92 }
93}
94
95impl WeightedShuffle {
96 fn weighted_select_owned(candidates: &[&Arc<Deployment>], weights: &[f64]) -> Arc<Deployment> {
97 let total_weight: f64 = weights.iter().sum();
98 let mut rng = rand::thread_rng();
99 let mut threshold = rng.gen_range(0.0..total_weight);
100
101 for (i, weight) in weights.iter().enumerate() {
102 threshold -= weight;
103 if threshold <= 0.0 {
104 return Arc::clone(candidates[i]);
105 }
106 }
107
108 Arc::clone(candidates.last().unwrap())
109 }
110}
111
#[cfg(test)]
pub mod tests_helpers {
    use super::super::{DeploymentMetrics, RecordFailureResult, RoutingError, RoutingState};
    use async_trait::async_trait;
    use std::collections::HashMap;

    /// In-memory `RoutingState` for unit tests.
    ///
    /// Metrics and cooldown flags are seeded up front with the builder methods.
    /// Every `record_*` method is a no-op, so the seeded data never changes.
    #[derive(Debug, Clone, Default)]
    pub struct MockState {
        /// Metrics served by `get_metrics` / `get_all_metrics`, keyed by deployment id.
        pub metrics: HashMap<String, DeploymentMetrics>,
        /// A deployment reports as cooled down only when its entry is `true`.
        pub cooled_down: HashMap<String, bool>,
    }

    impl MockState {
        /// Empty state: no metrics, nothing cooled down.
        pub fn new() -> Self {
            Self {
                metrics: HashMap::new(),
                cooled_down: HashMap::new(),
            }
        }

        /// Seeds (or replaces) the metrics reported for `id`.
        pub fn with_metrics(mut self, id: &str, metrics: DeploymentMetrics) -> Self {
            self.metrics.insert(id.to_owned(), metrics);
            self
        }

        /// Marks `id` as cooled down.
        pub fn with_cooldown(mut self, id: &str) -> Self {
            self.cooled_down.insert(id.to_owned(), true);
            self
        }
    }

    #[async_trait]
    impl RoutingState for MockState {
        async fn get_metrics(
            &self,
            deployment_id: &str,
        ) -> Result<DeploymentMetrics, RoutingError> {
            // Unknown ids report default (zeroed) metrics.
            let metrics = self
                .metrics
                .get(deployment_id)
                .map_or_else(DeploymentMetrics::default, Clone::clone);
            Ok(metrics)
        }

        async fn get_all_metrics(
            &self,
            ids: &[&str],
        ) -> Result<HashMap<String, DeploymentMetrics>, RoutingError> {
            // Ids with no seeded metrics are left out of the returned map.
            let found: HashMap<String, DeploymentMetrics> = ids
                .iter()
                .filter_map(|&id| {
                    self.metrics
                        .get(id)
                        .map(|metrics| (id.to_owned(), metrics.clone()))
                })
                .collect();
            Ok(found)
        }

        async fn is_cooled_down(&self, deployment_id: &str) -> Result<bool, RoutingError> {
            Ok(self.cooled_down.get(deployment_id) == Some(&true))
        }

        async fn record_request_start(&self, _deployment_id: &str) -> Result<(), RoutingError> {
            Ok(())
        }

        async fn record_request_success(
            &self,
            _deployment_id: &str,
            _latency_ms: f64,
            _tokens: u64,
        ) -> Result<(), RoutingError> {
            Ok(())
        }

        async fn record_request_failure(
            &self,
            _deployment_id: &str,
        ) -> Result<RecordFailureResult, RoutingError> {
            // Never counts failures and never triggers a cooldown.
            let outcome = RecordFailureResult {
                failure_count: 0,
                cooldown_triggered: false,
            };
            Ok(outcome)
        }
    }
}
194
#[cfg(test)]
mod tests {
    use super::super::DeploymentMetrics;
    use super::tests_helpers::MockState;
    use super::*;
    use crate::deployment::Deployment;
    use hyperinfer_core::Provider;

    /// Shared fixture construction: an OpenAI `gpt-4` deployment of
    /// `test-model` with the given id and weight. Rate limits are left exactly
    /// as `Deployment::new` sets them.
    fn base_deployment(id: &str, weight: u32) -> Deployment {
        let mut d = Deployment::new(
            "test-model".to_string(),
            Provider::OpenAI,
            "gpt-4".to_string(),
            format!("key-{}", id),
        );
        d.weight = weight;
        d.id = id.to_string();
        d
    }

    fn make_deployment(id: &str, weight: u32) -> Arc<Deployment> {
        Arc::new(base_deployment(id, weight))
    }

    fn make_deployment_with_limits(
        id: &str,
        weight: u32,
        rpm_limit: Option<u64>,
        tpm_limit: Option<u64>,
    ) -> Arc<Deployment> {
        let mut d = base_deployment(id, weight);
        d.rpm_limit = rpm_limit;
        d.tpm_limit = tpm_limit;
        Arc::new(d)
    }

    #[tokio::test]
    async fn test_single_candidate() {
        let candidates = vec![make_deployment("d1", 1)];
        let state = MockState::new();
        let strategy = WeightedShuffle::new();
        let ctx = RoutingContext::default();

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await
            .unwrap();
        assert_eq!(result.id, "d1");
    }

    #[tokio::test]
    async fn test_empty_candidates() {
        let candidates: Vec<Arc<Deployment>> = vec![];
        let state = MockState::new();
        let strategy = WeightedShuffle::new();
        let ctx = RoutingContext::default();

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await;
        assert!(matches!(result, Err(RoutingError::NoDeployments(_))));
    }

    #[tokio::test]
    async fn test_cooled_down_excluded() {
        let candidates = vec![make_deployment("d1", 1), make_deployment("d2", 1)];
        let state = MockState::new().with_cooldown("d1");
        let strategy = WeightedShuffle::new();
        let ctx = RoutingContext::default();

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await
            .unwrap();
        assert_eq!(result.id, "d2");
    }

    #[tokio::test]
    async fn test_all_cooled_down_returns_error() {
        let candidates = vec![make_deployment("d1", 1), make_deployment("d2", 1)];
        let state = MockState::new().with_cooldown("d1").with_cooldown("d2");
        let strategy = WeightedShuffle::new();
        let ctx = RoutingContext::default();

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await;
        assert!(matches!(result, Err(RoutingError::NoDeployments(_))));
    }

    #[tokio::test]
    async fn test_at_capacity_excluded() {
        let candidates = vec![
            make_deployment_with_limits("d1", 5, Some(100), None),
            make_deployment_with_limits("d2", 1, None, None),
        ];
        let metrics_d1 = DeploymentMetrics {
            rpm_used: 100,
            ..Default::default()
        };
        let state = MockState::new().with_metrics("d1", metrics_d1);
        let strategy = WeightedShuffle::new();
        let ctx = RoutingContext::default();

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await
            .unwrap();
        assert_eq!(result.id, "d2");
    }

    #[tokio::test]
    async fn test_all_at_capacity_returns_error() {
        let candidates = vec![
            make_deployment_with_limits("d1", 1, Some(10), None),
            make_deployment_with_limits("d2", 1, None, Some(100)),
        ];
        let state = MockState::new()
            .with_metrics(
                "d1",
                DeploymentMetrics {
                    rpm_used: 10,
                    ..Default::default()
                },
            )
            .with_metrics(
                "d2",
                DeploymentMetrics {
                    tpm_used: 100,
                    ..Default::default()
                },
            );
        let strategy = WeightedShuffle::new();
        let ctx = RoutingContext::default();

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await;
        assert!(matches!(result, Err(RoutingError::NoDeployments(_))));
    }

    #[tokio::test]
    async fn test_zero_weight_excluded() {
        let candidates = vec![make_deployment("d1", 0), make_deployment("d2", 1)];
        let state = MockState::new();
        let strategy = WeightedShuffle::new();
        let ctx = RoutingContext::default();

        for _ in 0..20 {
            let result = strategy
                .select("test-model", &candidates, &state, &ctx)
                .await
                .unwrap();
            assert_eq!(result.id, "d2");
        }
    }

    #[tokio::test]
    async fn test_missing_metrics_treated_as_idle() {
        // d1 has limits but no recorded metrics, so it is scored at full
        // capacity; d2 is exhausted and must never be picked.
        let candidates = vec![
            make_deployment_with_limits("d1", 2, Some(10), Some(1000)),
            make_deployment_with_limits("d2", 2, Some(10), None),
        ];
        let state = MockState::new().with_metrics(
            "d2",
            DeploymentMetrics {
                rpm_used: 10,
                ..Default::default()
            },
        );
        let strategy = WeightedShuffle::new();
        let ctx = RoutingContext::default();

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await
            .unwrap();
        assert_eq!(result.id, "d1");
    }

    #[tokio::test]
    async fn test_weight_distribution() {
        let candidates = vec![make_deployment("d1", 9), make_deployment("d2", 1)];
        let state = MockState::new();
        let strategy = WeightedShuffle::new();
        let ctx = RoutingContext::default();

        let mut d1_count = 0u32;
        let iterations = 10000;

        for _ in 0..iterations {
            let result = strategy
                .select("test-model", &candidates, &state, &ctx)
                .await
                .unwrap();
            if result.id == "d1" {
                d1_count += 1;
            }
        }

        let ratio = d1_count as f64 / iterations as f64;
        assert!(
            ratio > 0.80 && ratio < 0.98,
            "expected d1 ratio between 80-98%, got {:.2}%",
            ratio * 100.0
        );
    }

    #[test]
    fn test_effective_weight_no_limits() {
        let d = make_deployment("d1", 5);
        let metrics = DeploymentMetrics::default();
        let ew = WeightedShuffle::effective_weight(&d, &metrics);
        assert!((ew - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_effective_weight_at_rpm_limit() {
        let d = make_deployment_with_limits("d1", 5, Some(100), None);
        let metrics = DeploymentMetrics {
            rpm_used: 100,
            ..Default::default()
        };
        let ew = WeightedShuffle::effective_weight(&d, &metrics);
        assert!(ew.abs() < f64::EPSILON);
    }

    #[test]
    fn test_effective_weight_half_capacity() {
        let d = make_deployment_with_limits("d1", 10, Some(100), None);
        let metrics = DeploymentMetrics {
            rpm_used: 50,
            ..Default::default()
        };
        let ew = WeightedShuffle::effective_weight(&d, &metrics);
        assert!((ew - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_effective_weight_over_limit_clamped_to_zero() {
        let d = make_deployment_with_limits("d1", 5, Some(100), None);
        let metrics = DeploymentMetrics {
            rpm_used: 150,
            ..Default::default()
        };
        let ew = WeightedShuffle::effective_weight(&d, &metrics);
        assert!(ew.abs() < f64::EPSILON);
    }

    #[test]
    fn test_effective_weight_uses_tighter_limit() {
        // rpm headroom is 0.9, tpm headroom is 0.25; the tighter one wins.
        let d = make_deployment_with_limits("d1", 4, Some(100), Some(1000));
        let metrics = DeploymentMetrics {
            rpm_used: 10,
            tpm_used: 750,
            ..Default::default()
        };
        let ew = WeightedShuffle::effective_weight(&d, &metrics);
        assert!((ew - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_effective_weight_zero_limit_means_unlimited() {
        let d = make_deployment_with_limits("d1", 3, Some(0), Some(0));
        let metrics = DeploymentMetrics {
            rpm_used: 50,
            tpm_used: 50,
            ..Default::default()
        };
        let ew = WeightedShuffle::effective_weight(&d, &metrics);
        assert!((ew - 3.0).abs() < f64::EPSILON);
    }
}
382}