Skip to main content

hyperinfer_router/strategy/
least_busy.rs

1use super::{RoutingContext, RoutingState, RoutingStrategy};
2use crate::deployment::Deployment;
3use crate::error::RoutingError;
4use async_trait::async_trait;
5use rand::Rng;
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct LeastBusy;
11
12impl Default for LeastBusy {
13    fn default() -> Self {
14        Self
15    }
16}
17
18impl LeastBusy {
19    pub fn new() -> Self {
20        Self
21    }
22}
23
24#[async_trait]
25impl RoutingStrategy for LeastBusy {
26    fn name(&self) -> &str {
27        "least-busy"
28    }
29
30    async fn select<'a>(
31        &self,
32        _model: &str,
33        candidates: &'a [Arc<Deployment>],
34        state: &dyn RoutingState,
35        _request: &RoutingContext,
36    ) -> Result<&'a Arc<Deployment>, RoutingError> {
37        if candidates.is_empty() {
38            return Err(RoutingError::NoDeployments("empty candidates".into()));
39        }
40
41        let ids: Vec<&str> = candidates.iter().map(|d| d.id.as_str()).collect();
42        let all_metrics = state.get_all_metrics(&ids).await?;
43
44        let mut eligible: Vec<(usize, u64)> = Vec::new();
45
46        for (i, deployment) in candidates.iter().enumerate() {
47            if state.is_cooled_down(&deployment.id).await? {
48                continue;
49            }
50
51            let metrics = all_metrics.get(&deployment.id).cloned().unwrap_or_default();
52            eligible.push((i, metrics.in_flight));
53        }
54
55        if eligible.is_empty() {
56            return Err(RoutingError::NoDeployments(
57                "no eligible deployments after filtering".into(),
58            ));
59        }
60
61        let min_in_flight = eligible.iter().map(|(_, f)| *f).min().unwrap();
62
63        let tied: Vec<(usize, u64)> = eligible
64            .into_iter()
65            .filter(|(_, f)| *f == min_in_flight)
66            .collect();
67
68        let weights: Vec<f64> = tied
69            .iter()
70            .map(|(i, _)| candidates[*i].weight as f64)
71            .collect();
72
73        let total_weight: f64 = weights.iter().sum();
74        let mut rng = rand::thread_rng();
75        let mut pick = rng.gen_range(0.0..total_weight);
76
77        for (idx, weight) in weights.iter().enumerate() {
78            pick -= weight;
79            if pick <= 0.0 {
80                let (i, _) = tied[idx];
81                return Ok(&candidates[i]);
82            }
83        }
84
85        let (i, _) = *tied.last().unwrap();
86        Ok(&candidates[i])
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deployment::Deployment;
    use crate::strategy::weighted_shuffle::tests_helpers::MockState;
    use crate::strategy::DeploymentMetrics;
    use hyperinfer_core::Provider;

    /// Model name shared by every deployment and every `select` call below.
    const MODEL: &str = "test-model";

    /// Builds a deployment with the given id and routing weight.
    fn make_deployment(id: &str, weight: u32) -> Arc<Deployment> {
        let mut deployment = Deployment::new(
            MODEL.to_string(),
            Provider::OpenAI,
            "gpt-4".to_string(),
            format!("key-{}", id),
        );
        deployment.id = id.to_string();
        deployment.weight = weight;
        Arc::new(deployment)
    }

    /// Metrics that are default in every field except `in_flight`.
    fn busy(in_flight: u64) -> DeploymentMetrics {
        DeploymentMetrics {
            in_flight,
            ..Default::default()
        }
    }

    /// Runs one `LeastBusy` selection and unwraps the chosen deployment.
    async fn pick<'a>(
        candidates: &'a [Arc<Deployment>],
        state: &MockState,
    ) -> &'a Arc<Deployment> {
        let ctx = RoutingContext::default();
        LeastBusy::new()
            .select(MODEL, candidates, state, &ctx)
            .await
            .unwrap()
    }

    // The deployment with fewer in-flight requests wins outright.
    #[tokio::test]
    async fn test_selects_least_busy() {
        let candidates = vec![make_deployment("d1", 1), make_deployment("d2", 1)];
        let state = MockState::new()
            .with_metrics("d1", busy(10))
            .with_metrics("d2", busy(2));

        let chosen = pick(&candidates, &state).await;
        assert_eq!(chosen.id, "d2");
    }

    // Equal load: the draw between the tied deployments follows their weights.
    #[tokio::test]
    async fn test_tie_broken_by_weight() {
        let candidates = vec![make_deployment("d1", 9), make_deployment("d2", 1)];
        let state = MockState::new()
            .with_metrics("d1", busy(5))
            .with_metrics("d2", busy(5));

        let mut d1_count = 0u32;
        for _ in 0..5000 {
            if pick(&candidates, &state).await.id == "d1" {
                d1_count += 1;
            }
        }

        let ratio = f64::from(d1_count) / 5000.0;
        assert!(
            ratio > 0.80,
            "d1 should win >80% with 9:1 weight, got {:.2}%",
            ratio * 100.0
        );
    }

    // A cooled-down deployment is skipped even though it is the least loaded.
    #[tokio::test]
    async fn test_cooled_down_excluded() {
        let candidates = vec![make_deployment("d1", 1), make_deployment("d2", 1)];
        let state = MockState::new()
            .with_metrics("d1", busy(1))
            .with_metrics("d2", busy(10))
            .with_cooldown("d1");

        let chosen = pick(&candidates, &state).await;
        assert_eq!(chosen.id, "d2");
    }
}
218}