// hyperinfer_router/strategy/usage_based.rs

1use super::{RoutingContext, RoutingState, RoutingStrategy};
2use crate::deployment::Deployment;
3use crate::error::RoutingError;
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct UsageBased;
10
11impl Default for UsageBased {
12    fn default() -> Self {
13        Self
14    }
15}
16
17impl UsageBased {
18    pub fn new() -> Self {
19        Self
20    }
21}
22
23#[async_trait]
24impl RoutingStrategy for UsageBased {
25    fn name(&self) -> &str {
26        "usage-based"
27    }
28
29    async fn select<'a>(
30        &self,
31        _model: &str,
32        candidates: &'a [Arc<Deployment>],
33        state: &dyn RoutingState,
34        _request: &RoutingContext,
35    ) -> Result<&'a Arc<Deployment>, RoutingError> {
36        if candidates.is_empty() {
37            return Err(RoutingError::NoDeployments("empty candidates".into()));
38        }
39
40        let ids: Vec<&str> = candidates.iter().map(|d| d.id.as_str()).collect();
41        let all_metrics = state.get_all_metrics(&ids).await?;
42
43        let mut eligible: Vec<(usize, f64)> = Vec::new();
44
45        for (i, deployment) in candidates.iter().enumerate() {
46            if state.is_cooled_down(&deployment.id).await? {
47                continue;
48            }
49
50            let metrics = all_metrics.get(&deployment.id).cloned().unwrap_or_default();
51
52            let utilization = match deployment.tpm_limit {
53                Some(limit) if limit > 0 => metrics.tpm_used as f64 / limit as f64,
54                _ => 0.0,
55            };
56
57            eligible.push((i, utilization));
58        }
59
60        if eligible.is_empty() {
61            return Err(RoutingError::NoDeployments(
62                "no eligible deployments after filtering".into(),
63            ));
64        }
65
66        let best_idx = eligible
67            .iter()
68            .enumerate()
69            .min_by(|(_, (_, a)), (_, (_, b))| {
70                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
71            })
72            .map(|(idx, _)| idx)
73            .unwrap();
74
75        let (i, _) = eligible[best_idx];
76        Ok(&candidates[i])
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deployment::Deployment;
    use crate::strategy::weighted_shuffle::tests_helpers::MockState;
    use crate::strategy::DeploymentMetrics;
    use hyperinfer_core::Provider;

    /// Builds a shared test deployment with the given id, weight and optional
    /// TPM limit.
    fn make_deployment(id: &str, weight: u32, tpm_limit: Option<u64>) -> Arc<Deployment> {
        let mut d = Deployment::new(
            "test-model".to_string(),
            Provider::OpenAI,
            "gpt-4".to_string(),
            format!("key-{}", id),
        );
        d.weight = weight;
        d.id = id.to_string();
        d.tpm_limit = tpm_limit;
        Arc::new(d)
    }

    /// With equal limits, the deployment with fewer tokens used is chosen.
    #[tokio::test]
    async fn test_selects_lowest_utilization() {
        let d1 = make_deployment("d1", 1, Some(1000));
        let d2 = make_deployment("d2", 1, Some(1000));
        let candidates = vec![d1, d2];

        let state = MockState::new()
            .with_metrics(
                "d1",
                DeploymentMetrics {
                    tpm_used: 800,
                    ..Default::default()
                },
            )
            .with_metrics(
                "d2",
                DeploymentMetrics {
                    tpm_used: 200,
                    ..Default::default()
                },
            );

        let strategy = UsageBased::new();
        let ctx = RoutingContext::default();

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await
            .unwrap();
        assert_eq!(result.id, "d2");
    }

    /// Selection compares used/limit ratios, not absolute token counts:
    /// 50000/100000 (50%) beats 900/1000 (90%).
    #[tokio::test]
    async fn test_ratio_not_absolute() {
        let d_big = make_deployment("d_big", 1, Some(100000));
        let d_small = make_deployment("d_small", 1, Some(1000));
        let candidates = vec![d_big, d_small];

        let state = MockState::new()
            .with_metrics(
                "d_big",
                DeploymentMetrics {
                    tpm_used: 50000,
                    ..Default::default()
                },
            )
            .with_metrics(
                "d_small",
                DeploymentMetrics {
                    tpm_used: 900,
                    ..Default::default()
                },
            );

        let strategy = UsageBased::new();
        let ctx = RoutingContext::default();

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await
            .unwrap();
        assert_eq!(result.id, "d_big");
    }

    /// A cooled-down deployment is skipped even when it has the lowest
    /// utilization.
    #[tokio::test]
    async fn test_cooled_down_excluded() {
        let d1 = make_deployment("d1", 1, Some(1000));
        let d2 = make_deployment("d2", 1, Some(1000));
        let candidates = vec![d1, d2];

        let state = MockState::new()
            .with_metrics(
                "d1",
                DeploymentMetrics {
                    tpm_used: 100,
                    ..Default::default()
                },
            )
            .with_metrics(
                "d2",
                DeploymentMetrics {
                    tpm_used: 800,
                    ..Default::default()
                },
            )
            .with_cooldown("d1");

        let strategy = UsageBased::new();
        let ctx = RoutingContext::default();

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await
            .unwrap();
        assert_eq!(result.id, "d2");
    }
}