Skip to main content

hyperinfer_router/strategy/
latency_based.rs

1use super::{DeploymentMetrics, RoutingContext, RoutingState, RoutingStrategy};
2use crate::deployment::Deployment;
3use crate::error::RoutingError;
4use async_trait::async_trait;
5use rand::Rng;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct LatencyBased {
12    pub ttl_secs: u64,
13    pub buffer: f64,
14    pub default_latency_ms: f64,
15}
16
17impl Default for LatencyBased {
18    fn default() -> Self {
19        Self {
20            ttl_secs: 3600,
21            buffer: 0.2,
22            default_latency_ms: 1000.0,
23        }
24    }
25}
26
27impl LatencyBased {
28    pub fn new() -> Self {
29        Self::default()
30    }
31}
32
33fn compute_global_median(metrics: &HashMap<String, DeploymentMetrics>) -> f64 {
34    let mut latencies: Vec<f64> = metrics
35        .values()
36        .map(|m| m.latency_ewma_ms)
37        .filter(|&l| l > 0.0)
38        .collect();
39
40    if latencies.is_empty() {
41        return 0.0;
42    }
43
44    latencies.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
45    let len = latencies.len();
46    if len.is_multiple_of(2) {
47        (latencies[len / 2 - 1] + latencies[len / 2]) / 2.0
48    } else {
49        latencies[len / 2]
50    }
51}
52
53#[async_trait]
54impl RoutingStrategy for LatencyBased {
55    fn name(&self) -> &str {
56        "latency-based"
57    }
58
59    async fn select<'a>(
60        &self,
61        _model: &str,
62        candidates: &'a [Arc<Deployment>],
63        state: &dyn RoutingState,
64        _request: &RoutingContext,
65    ) -> Result<&'a Arc<Deployment>, RoutingError> {
66        if candidates.is_empty() {
67            return Err(RoutingError::NoDeployments("empty candidates".into()));
68        }
69
70        let ids: Vec<&str> = candidates.iter().map(|d| d.id.as_str()).collect();
71        let all_metrics = state.get_all_metrics(&ids).await?;
72
73        let global_median = compute_global_median(&all_metrics);
74        let fallback_latency = if global_median > 0.0 {
75            global_median
76        } else {
77            self.default_latency_ms
78        };
79
80        let mut eligible: Vec<(usize, f64)> = Vec::new();
81
82        for (i, deployment) in candidates.iter().enumerate() {
83            if state.is_cooled_down(&deployment.id).await? {
84                continue;
85            }
86
87            let metrics = all_metrics.get(&deployment.id);
88            let latency = match metrics {
89                Some(m) if m.latency_ewma_ms > 0.0 => m.latency_ewma_ms,
90                _ => fallback_latency,
91            };
92
93            eligible.push((i, latency));
94        }
95
96        if eligible.is_empty() {
97            return Err(RoutingError::NoDeployments(
98                "no eligible deployments after filtering".into(),
99            ));
100        }
101
102        let best_latency = eligible
103            .iter()
104            .map(|(_, l)| *l)
105            .fold(f64::INFINITY, f64::min);
106
107        let threshold = best_latency * (1.0 + self.buffer);
108
109        let within_threshold: Vec<(usize, f64)> = eligible
110            .into_iter()
111            .filter(|(_, l)| *l <= threshold)
112            .collect();
113
114        let weights: Vec<f64> = within_threshold
115            .iter()
116            .map(|(i, _)| candidates[*i].weight as f64)
117            .collect();
118
119        let total_weight: f64 = weights.iter().sum();
120        if total_weight <= 0.0 {
121            let (i, _) = within_threshold[0];
122            return Ok(&candidates[i]);
123        }
124        let mut rng = rand::thread_rng();
125        let mut pick = rng.gen_range(0.0..total_weight);
126
127        for (idx, weight) in weights.iter().enumerate() {
128            pick -= weight;
129            if pick <= 0.0 {
130                let (i, _) = within_threshold[idx];
131                return Ok(&candidates[i]);
132            }
133        }
134
135        let (i, _) = *within_threshold.last().unwrap();
136        Ok(&candidates[i])
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deployment::Deployment;
    use crate::strategy::weighted_shuffle::tests_helpers::MockState;
    use hyperinfer_core::Provider;

    /// Builds a "test-model" deployment whose id is overridden to `id`.
    fn make_deployment(id: &str, weight: u32) -> Arc<Deployment> {
        let mut deployment = Deployment::new(
            "test-model".to_string(),
            Provider::OpenAI,
            "gpt-4".to_string(),
            format!("key-{}", id),
        );
        deployment.id = id.to_string();
        deployment.weight = weight;
        Arc::new(deployment)
    }

    /// Metrics carrying only a latency EWMA; every other field keeps its default.
    fn latency(ms: f64) -> DeploymentMetrics {
        DeploymentMetrics {
            latency_ewma_ms: ms,
            ..Default::default()
        }
    }

    /// A mock routing state reporting the given `(deployment id, latency ms)` pairs.
    fn state_with(latencies: &[(&'static str, f64)]) -> MockState {
        latencies
            .iter()
            .fold(MockState::new(), |state, &(id, ms)| {
                state.with_metrics(id, latency(ms))
            })
    }

    /// A metrics map keyed by deployment id, for exercising `compute_global_median`.
    fn metrics_map(latencies: &[(&str, f64)]) -> HashMap<String, DeploymentMetrics> {
        latencies
            .iter()
            .map(|&(id, ms)| (id.to_string(), latency(ms)))
            .collect()
    }

    /// Runs `rounds` selections and returns how many of them picked `target`.
    async fn count_picks(
        strategy: &LatencyBased,
        candidates: &[Arc<Deployment>],
        state: &MockState,
        target: &str,
        rounds: u32,
    ) -> u32 {
        let ctx = RoutingContext::default();
        let mut hits = 0;
        for _ in 0..rounds {
            let picked = strategy
                .select("test-model", candidates, state, &ctx)
                .await
                .unwrap();
            if picked.id == target {
                hits += 1;
            }
        }
        hits
    }

    #[tokio::test]
    async fn test_selects_lowest_latency() {
        let candidates = vec![make_deployment("d1", 1), make_deployment("d2", 1)];
        let state = state_with(&[("d1", 200.0), ("d2", 50.0)]);
        let strategy = LatencyBased::new();
        let ctx = RoutingContext::default();

        let picked = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await
            .unwrap();
        assert_eq!(picked.id, "d2");
    }

    #[tokio::test]
    async fn test_buffer_includes_near_candidates() {
        let candidates: Vec<_> = ["d1", "d2", "d3"]
            .iter()
            .map(|id| make_deployment(id, 1))
            .collect();
        let state = state_with(&[("d1", 100.0), ("d2", 115.0), ("d3", 500.0)]);
        let strategy = LatencyBased {
            buffer: 0.2,
            ..Default::default()
        };

        let d3_count = count_picks(&strategy, &candidates, &state, "d3", 1000).await;
        assert_eq!(d3_count, 0, "d3 should never be selected with buffer=0.2");
    }

    #[tokio::test]
    async fn test_cold_start_uses_global_median() {
        let candidates: Vec<_> = ["d1", "d2", "d_new"]
            .iter()
            .map(|id| make_deployment(id, 1))
            .collect();
        // `d_new` has no metrics, so it is scored at the median of d1 and d2.
        let state = state_with(&[("d1", 100.0), ("d2", 110.0)]);
        let strategy = LatencyBased::new();

        let new_count = count_picks(&strategy, &candidates, &state, "d_new", 1000).await;
        assert!(
            (200..=800).contains(&new_count),
            "new deployment should get significant traffic, got {}",
            new_count
        );
    }

    #[tokio::test]
    async fn test_cooled_down_excluded() {
        let candidates = vec![make_deployment("d1", 1), make_deployment("d2", 1)];
        let state = state_with(&[("d1", 50.0), ("d2", 200.0)]).with_cooldown("d1");
        let strategy = LatencyBased::new();
        let ctx = RoutingContext::default();

        let picked = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await
            .unwrap();
        assert_eq!(picked.id, "d2");
    }

    #[test]
    fn test_global_median_odd() {
        let metrics = metrics_map(&[("a", 100.0), ("b", 200.0), ("c", 300.0)]);
        let median = compute_global_median(&metrics);
        assert!((median - 200.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_global_median_even() {
        let metrics = metrics_map(&[("a", 100.0), ("b", 200.0), ("c", 300.0), ("d", 400.0)]);
        let median = compute_global_median(&metrics);
        assert!((median - 250.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_global_median_empty() {
        let metrics = HashMap::new();
        let median = compute_global_median(&metrics);
        assert!((median - 0.0).abs() < f64::EPSILON);
    }
}