hyperinfer_router/strategy/
least_busy.rs1use super::{RoutingContext, RoutingState, RoutingStrategy};
2use crate::deployment::Deployment;
3use crate::error::RoutingError;
4use async_trait::async_trait;
5use rand::Rng;
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct LeastBusy;
11
12impl Default for LeastBusy {
13 fn default() -> Self {
14 Self
15 }
16}
17
18impl LeastBusy {
19 pub fn new() -> Self {
20 Self
21 }
22}
23
24#[async_trait]
25impl RoutingStrategy for LeastBusy {
26 fn name(&self) -> &str {
27 "least-busy"
28 }
29
30 async fn select<'a>(
31 &self,
32 _model: &str,
33 candidates: &'a [Arc<Deployment>],
34 state: &dyn RoutingState,
35 _request: &RoutingContext,
36 ) -> Result<&'a Arc<Deployment>, RoutingError> {
37 if candidates.is_empty() {
38 return Err(RoutingError::NoDeployments("empty candidates".into()));
39 }
40
41 let ids: Vec<&str> = candidates.iter().map(|d| d.id.as_str()).collect();
42 let all_metrics = state.get_all_metrics(&ids).await?;
43
44 let mut eligible: Vec<(usize, u64)> = Vec::new();
45
46 for (i, deployment) in candidates.iter().enumerate() {
47 if state.is_cooled_down(&deployment.id).await? {
48 continue;
49 }
50
51 let metrics = all_metrics.get(&deployment.id).cloned().unwrap_or_default();
52 eligible.push((i, metrics.in_flight));
53 }
54
55 if eligible.is_empty() {
56 return Err(RoutingError::NoDeployments(
57 "no eligible deployments after filtering".into(),
58 ));
59 }
60
61 let min_in_flight = eligible.iter().map(|(_, f)| *f).min().unwrap();
62
63 let tied: Vec<(usize, u64)> = eligible
64 .into_iter()
65 .filter(|(_, f)| *f == min_in_flight)
66 .collect();
67
68 let weights: Vec<f64> = tied
69 .iter()
70 .map(|(i, _)| candidates[*i].weight as f64)
71 .collect();
72
73 let total_weight: f64 = weights.iter().sum();
74 let mut rng = rand::thread_rng();
75 let mut pick = rng.gen_range(0.0..total_weight);
76
77 for (idx, weight) in weights.iter().enumerate() {
78 pick -= weight;
79 if pick <= 0.0 {
80 let (i, _) = tied[idx];
81 return Ok(&candidates[i]);
82 }
83 }
84
85 let (i, _) = *tied.last().unwrap();
86 Ok(&candidates[i])
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deployment::Deployment;
    use crate::strategy::weighted_shuffle::tests_helpers::MockState;
    use crate::strategy::DeploymentMetrics;
    use hyperinfer_core::Provider;

    /// Builds a `test-model` deployment with the given id and routing weight.
    fn make_deployment(id: &str, weight: u32) -> Arc<Deployment> {
        let mut d = Deployment::new(
            "test-model".to_string(),
            Provider::OpenAI,
            "gpt-4".to_string(),
            format!("key-{}", id),
        );
        d.weight = weight;
        d.id = id.to_string();
        Arc::new(d)
    }

    /// Metrics with only `in_flight` set; every other field is defaulted.
    fn in_flight(n: u64) -> DeploymentMetrics {
        DeploymentMetrics {
            in_flight: n,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_selects_least_busy() {
        let candidates = vec![make_deployment("d1", 1), make_deployment("d2", 1)];

        let state = MockState::new()
            .with_metrics("d1", in_flight(10))
            .with_metrics("d2", in_flight(2));

        let strategy = LeastBusy::new();
        let ctx = RoutingContext::default();

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await
            .unwrap();
        assert_eq!(result.id, "d2");
    }

    #[tokio::test]
    async fn test_tie_broken_by_weight() {
        let candidates = vec![make_deployment("d1", 9), make_deployment("d2", 1)];

        let state = MockState::new()
            .with_metrics("d1", in_flight(5))
            .with_metrics("d2", in_flight(5));

        let strategy = LeastBusy::new();
        let ctx = RoutingContext::default();

        let mut d1_count = 0u32;
        for _ in 0..5000 {
            let result = strategy
                .select("test-model", &candidates, &state, &ctx)
                .await
                .unwrap();
            if result.id == "d1" {
                d1_count += 1;
            }
        }

        let ratio = d1_count as f64 / 5000.0;
        assert!(
            ratio > 0.80,
            "d1 should win >80% with 9:1 weight, got {:.2}%",
            ratio * 100.0
        );
    }

    #[tokio::test]
    async fn test_cooled_down_excluded() {
        let candidates = vec![make_deployment("d1", 1), make_deployment("d2", 1)];

        let state = MockState::new()
            .with_metrics("d1", in_flight(1))
            .with_metrics("d2", in_flight(10))
            .with_cooldown("d1");

        let strategy = LeastBusy::new();
        let ctx = RoutingContext::default();

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await
            .unwrap();
        assert_eq!(result.id, "d2");
    }

    #[tokio::test]
    async fn test_empty_candidates_is_error() {
        let candidates: Vec<Arc<Deployment>> = Vec::new();
        let state = MockState::new();

        let strategy = LeastBusy::new();
        let ctx = RoutingContext::default();

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await;
        assert!(matches!(result, Err(RoutingError::NoDeployments(_))));
    }

    #[tokio::test]
    async fn test_all_cooled_down_is_error() {
        let candidates = vec![make_deployment("d1", 1), make_deployment("d2", 1)];

        let state = MockState::new()
            .with_metrics("d1", in_flight(1))
            .with_metrics("d2", in_flight(2))
            .with_cooldown("d1")
            .with_cooldown("d2");

        let strategy = LeastBusy::new();
        let ctx = RoutingContext::default();

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await;
        assert!(matches!(result, Err(RoutingError::NoDeployments(_))));
    }
}