hyperinfer_router/strategy/
latency_based.rs1use super::{DeploymentMetrics, RoutingContext, RoutingState, RoutingStrategy};
2use crate::deployment::Deployment;
3use crate::error::RoutingError;
4use async_trait::async_trait;
5use rand::Rng;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct LatencyBased {
12 pub ttl_secs: u64,
13 pub buffer: f64,
14 pub default_latency_ms: f64,
15}
16
17impl Default for LatencyBased {
18 fn default() -> Self {
19 Self {
20 ttl_secs: 3600,
21 buffer: 0.2,
22 default_latency_ms: 1000.0,
23 }
24 }
25}
26
27impl LatencyBased {
28 pub fn new() -> Self {
29 Self::default()
30 }
31}
32
33fn compute_global_median(metrics: &HashMap<String, DeploymentMetrics>) -> f64 {
34 let mut latencies: Vec<f64> = metrics
35 .values()
36 .map(|m| m.latency_ewma_ms)
37 .filter(|&l| l > 0.0)
38 .collect();
39
40 if latencies.is_empty() {
41 return 0.0;
42 }
43
44 latencies.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
45 let len = latencies.len();
46 if len.is_multiple_of(2) {
47 (latencies[len / 2 - 1] + latencies[len / 2]) / 2.0
48 } else {
49 latencies[len / 2]
50 }
51}
52
53#[async_trait]
54impl RoutingStrategy for LatencyBased {
55 fn name(&self) -> &str {
56 "latency-based"
57 }
58
59 async fn select<'a>(
60 &self,
61 _model: &str,
62 candidates: &'a [Arc<Deployment>],
63 state: &dyn RoutingState,
64 _request: &RoutingContext,
65 ) -> Result<&'a Arc<Deployment>, RoutingError> {
66 if candidates.is_empty() {
67 return Err(RoutingError::NoDeployments("empty candidates".into()));
68 }
69
70 let ids: Vec<&str> = candidates.iter().map(|d| d.id.as_str()).collect();
71 let all_metrics = state.get_all_metrics(&ids).await?;
72
73 let global_median = compute_global_median(&all_metrics);
74 let fallback_latency = if global_median > 0.0 {
75 global_median
76 } else {
77 self.default_latency_ms
78 };
79
80 let mut eligible: Vec<(usize, f64)> = Vec::new();
81
82 for (i, deployment) in candidates.iter().enumerate() {
83 if state.is_cooled_down(&deployment.id).await? {
84 continue;
85 }
86
87 let metrics = all_metrics.get(&deployment.id);
88 let latency = match metrics {
89 Some(m) if m.latency_ewma_ms > 0.0 => m.latency_ewma_ms,
90 _ => fallback_latency,
91 };
92
93 eligible.push((i, latency));
94 }
95
96 if eligible.is_empty() {
97 return Err(RoutingError::NoDeployments(
98 "no eligible deployments after filtering".into(),
99 ));
100 }
101
102 let best_latency = eligible
103 .iter()
104 .map(|(_, l)| *l)
105 .fold(f64::INFINITY, f64::min);
106
107 let threshold = best_latency * (1.0 + self.buffer);
108
109 let within_threshold: Vec<(usize, f64)> = eligible
110 .into_iter()
111 .filter(|(_, l)| *l <= threshold)
112 .collect();
113
114 let weights: Vec<f64> = within_threshold
115 .iter()
116 .map(|(i, _)| candidates[*i].weight as f64)
117 .collect();
118
119 let total_weight: f64 = weights.iter().sum();
120 if total_weight <= 0.0 {
121 let (i, _) = within_threshold[0];
122 return Ok(&candidates[i]);
123 }
124 let mut rng = rand::thread_rng();
125 let mut pick = rng.gen_range(0.0..total_weight);
126
127 for (idx, weight) in weights.iter().enumerate() {
128 pick -= weight;
129 if pick <= 0.0 {
130 let (i, _) = within_threshold[idx];
131 return Ok(&candidates[i]);
132 }
133 }
134
135 let (i, _) = *within_threshold.last().unwrap();
136 Ok(&candidates[i])
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deployment::Deployment;
    use crate::strategy::weighted_shuffle::tests_helpers::MockState;
    use hyperinfer_core::Provider;

    /// Builds a test deployment with a fixed `id` and `weight`.
    fn make_deployment(id: &str, weight: u32) -> Arc<Deployment> {
        let mut d = Deployment::new(
            "test-model".to_string(),
            Provider::OpenAI,
            "gpt-4".to_string(),
            format!("key-{}", id),
        );
        d.weight = weight;
        d.id = id.to_string();
        Arc::new(d)
    }

    /// Metrics carrying only a latency EWMA; every other field is defaulted.
    fn latency(ms: f64) -> DeploymentMetrics {
        DeploymentMetrics {
            latency_ewma_ms: ms,
            ..Default::default()
        }
    }

    /// Builds an id -> metrics map from `(id, latency_ms)` pairs.
    fn metrics_map(entries: &[(&str, f64)]) -> HashMap<String, DeploymentMetrics> {
        entries
            .iter()
            .map(|&(id, ms)| (id.to_string(), latency(ms)))
            .collect()
    }

    /// Runs `select` `rounds` times and counts how often each deployment id wins.
    async fn selection_counts(
        strategy: &LatencyBased,
        candidates: &[Arc<Deployment>],
        state: &MockState,
        rounds: u32,
    ) -> HashMap<String, u32> {
        let ctx = RoutingContext::default();
        let mut counts: HashMap<String, u32> = HashMap::new();
        for _ in 0..rounds {
            let chosen = strategy
                .select("test-model", candidates, state, &ctx)
                .await
                .unwrap();
            *counts.entry(chosen.id.clone()).or_insert(0) += 1;
        }
        counts
    }

    fn count_of(counts: &HashMap<String, u32>, id: &str) -> u32 {
        counts.get(id).copied().unwrap_or(0)
    }

    #[tokio::test]
    async fn test_selects_lowest_latency() {
        let candidates = vec![make_deployment("d1", 1), make_deployment("d2", 1)];
        let state = MockState::new()
            .with_metrics("d1", latency(200.0))
            .with_metrics("d2", latency(50.0));

        let strategy = LatencyBased::new();
        let ctx = RoutingContext::default();

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await
            .unwrap();
        assert_eq!(result.id, "d2");
    }

    #[tokio::test]
    async fn test_buffer_includes_near_candidates() {
        let candidates = vec![
            make_deployment("d1", 1),
            make_deployment("d2", 1),
            make_deployment("d3", 1),
        ];
        let state = MockState::new()
            .with_metrics("d1", latency(100.0))
            .with_metrics("d2", latency(115.0))
            .with_metrics("d3", latency(500.0));

        let strategy = LatencyBased {
            buffer: 0.2,
            ..Default::default()
        };

        let counts = selection_counts(&strategy, &candidates, &state, 1000).await;

        assert_eq!(
            count_of(&counts, "d3"),
            0,
            "d3 should never be selected with buffer=0.2"
        );
        // 115ms is within 20% of 100ms, so both near candidates must share traffic.
        assert!(count_of(&counts, "d1") > 0, "d1 should be selected");
        assert!(
            count_of(&counts, "d2") > 0,
            "d2 is within the buffer and should be selected"
        );
    }

    #[tokio::test]
    async fn test_cold_start_uses_global_median() {
        let candidates = vec![
            make_deployment("d1", 1),
            make_deployment("d2", 1),
            make_deployment("d_new", 1),
        ];
        let state = MockState::new()
            .with_metrics("d1", latency(100.0))
            .with_metrics("d2", latency(110.0));

        let strategy = LatencyBased::new();

        let counts = selection_counts(&strategy, &candidates, &state, 1000).await;
        let new_count = count_of(&counts, "d_new");

        assert!(
            (200..=800).contains(&new_count),
            "new deployment should get significant traffic, got {}",
            new_count
        );
    }

    #[tokio::test]
    async fn test_cooled_down_excluded() {
        let candidates = vec![make_deployment("d1", 1), make_deployment("d2", 1)];
        let state = MockState::new()
            .with_metrics("d1", latency(50.0))
            .with_metrics("d2", latency(200.0))
            .with_cooldown("d1");

        let strategy = LatencyBased::new();
        let ctx = RoutingContext::default();

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await
            .unwrap();
        assert_eq!(result.id, "d2");
    }

    #[test]
    fn test_global_median_odd() {
        let metrics = metrics_map(&[("a", 100.0), ("b", 200.0), ("c", 300.0)]);
        let median = compute_global_median(&metrics);
        assert!((median - 200.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_global_median_even() {
        let metrics = metrics_map(&[("a", 100.0), ("b", 200.0), ("c", 300.0), ("d", 400.0)]);
        let median = compute_global_median(&metrics);
        assert!((median - 250.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_global_median_empty() {
        let metrics = HashMap::new();
        let median = compute_global_median(&metrics);
        assert!((median - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_global_median_ignores_non_positive_samples() {
        let metrics = metrics_map(&[
            ("zero", 0.0),
            ("negative", -5.0),
            ("nan", f64::NAN),
            ("a", 100.0),
            ("b", 300.0),
        ]);
        let median = compute_global_median(&metrics);
        assert!((median - 200.0).abs() < f64::EPSILON);

        let unmeasured = metrics_map(&[("zero", 0.0), ("negative", -1.0)]);
        assert!((compute_global_median(&unmeasured) - 0.0).abs() < f64::EPSILON);
    }
}