hyperinfer_router/strategy/
usage_based.rs1use super::{RoutingContext, RoutingState, RoutingStrategy};
2use crate::deployment::Deployment;
3use crate::error::RoutingError;
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct UsageBased;
10
11impl Default for UsageBased {
12 fn default() -> Self {
13 Self
14 }
15}
16
17impl UsageBased {
18 pub fn new() -> Self {
19 Self
20 }
21}
22
23#[async_trait]
24impl RoutingStrategy for UsageBased {
25 fn name(&self) -> &str {
26 "usage-based"
27 }
28
29 async fn select<'a>(
30 &self,
31 _model: &str,
32 candidates: &'a [Arc<Deployment>],
33 state: &dyn RoutingState,
34 _request: &RoutingContext,
35 ) -> Result<&'a Arc<Deployment>, RoutingError> {
36 if candidates.is_empty() {
37 return Err(RoutingError::NoDeployments("empty candidates".into()));
38 }
39
40 let ids: Vec<&str> = candidates.iter().map(|d| d.id.as_str()).collect();
41 let all_metrics = state.get_all_metrics(&ids).await?;
42
43 let mut eligible: Vec<(usize, f64)> = Vec::new();
44
45 for (i, deployment) in candidates.iter().enumerate() {
46 if state.is_cooled_down(&deployment.id).await? {
47 continue;
48 }
49
50 let metrics = all_metrics.get(&deployment.id).cloned().unwrap_or_default();
51
52 let utilization = match deployment.tpm_limit {
53 Some(limit) if limit > 0 => metrics.tpm_used as f64 / limit as f64,
54 _ => 0.0,
55 };
56
57 eligible.push((i, utilization));
58 }
59
60 if eligible.is_empty() {
61 return Err(RoutingError::NoDeployments(
62 "no eligible deployments after filtering".into(),
63 ));
64 }
65
66 let best_idx = eligible
67 .iter()
68 .enumerate()
69 .min_by(|(_, (_, a)), (_, (_, b))| {
70 a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
71 })
72 .map(|(idx, _)| idx)
73 .unwrap();
74
75 let (i, _) = eligible[best_idx];
76 Ok(&candidates[i])
77 }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deployment::Deployment;
    use crate::strategy::weighted_shuffle::tests_helpers::MockState;
    use crate::strategy::DeploymentMetrics;
    use hyperinfer_core::Provider;

    /// Builds a shared test deployment with the given id, weight and TPM limit.
    fn make_deployment(id: &str, weight: u32, tpm_limit: Option<u64>) -> Arc<Deployment> {
        let mut deployment = Deployment::new(
            "test-model".to_string(),
            Provider::OpenAI,
            "gpt-4".to_string(),
            format!("key-{}", id),
        );
        deployment.weight = weight;
        deployment.id = id.to_string();
        deployment.tpm_limit = tpm_limit;
        Arc::new(deployment)
    }

    /// Runs `UsageBased::select` over `candidates` with a default routing
    /// context and returns the id of the chosen deployment.
    async fn selected_id(candidates: &[Arc<Deployment>], state: &MockState) -> String {
        let strategy = UsageBased::new();
        let ctx = RoutingContext::default();

        let chosen = strategy
            .select("test-model", candidates, state, &ctx)
            .await
            .unwrap();
        chosen.id.clone()
    }

    #[tokio::test]
    async fn test_selects_lowest_utilization() {
        let candidates = vec![
            make_deployment("d1", 1, Some(1000)),
            make_deployment("d2", 1, Some(1000)),
        ];

        // d1 sits at 80% of its limit, d2 at 20%.
        let state = MockState::new()
            .with_metrics(
                "d1",
                DeploymentMetrics {
                    tpm_used: 800,
                    ..Default::default()
                },
            )
            .with_metrics(
                "d2",
                DeploymentMetrics {
                    tpm_used: 200,
                    ..Default::default()
                },
            );

        assert_eq!(selected_id(&candidates, &state).await, "d2");
    }

    #[tokio::test]
    async fn test_ratio_not_absolute() {
        let candidates = vec![
            make_deployment("d_big", 1, Some(100000)),
            make_deployment("d_small", 1, Some(1000)),
        ];

        // d_big has used far more tokens in absolute terms (50000 vs 900),
        // but only 50% of its limit versus 90% for d_small.
        let state = MockState::new()
            .with_metrics(
                "d_big",
                DeploymentMetrics {
                    tpm_used: 50000,
                    ..Default::default()
                },
            )
            .with_metrics(
                "d_small",
                DeploymentMetrics {
                    tpm_used: 900,
                    ..Default::default()
                },
            );

        assert_eq!(selected_id(&candidates, &state).await, "d_big");
    }

    #[tokio::test]
    async fn test_cooled_down_excluded() {
        let candidates = vec![
            make_deployment("d1", 1, Some(1000)),
            make_deployment("d2", 1, Some(1000)),
        ];

        // d1 has the lower utilization but is cooled down, so d2 must win.
        let state = MockState::new()
            .with_metrics(
                "d1",
                DeploymentMetrics {
                    tpm_used: 100,
                    ..Default::default()
                },
            )
            .with_metrics(
                "d2",
                DeploymentMetrics {
                    tpm_used: 800,
                    ..Default::default()
                },
            )
            .with_cooldown("d1");

        assert_eq!(selected_id(&candidates, &state).await, "d2");
    }
}