Skip to main content

hyperinfer_router/strategy/
cost_based.rs

1use super::{RoutingContext, RoutingState, RoutingStrategy};
2use crate::deployment::Deployment;
3use crate::error::RoutingError;
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct CostBased;
10
11impl Default for CostBased {
12    fn default() -> Self {
13        Self
14    }
15}
16
17impl CostBased {
18    pub fn new() -> Self {
19        Self
20    }
21}
22
23fn estimate_cost(deployment: &Deployment, ctx: &RoutingContext) -> f64 {
24    let input_tokens = ctx.estimated_input_tokens.unwrap_or(0) as f64;
25    let output_tokens = ctx.estimated_output_tokens.unwrap_or(0) as f64;
26
27    let input_cost = deployment.input_cost_per_1k.unwrap_or(0.0);
28    let output_cost = deployment.output_cost_per_1k.unwrap_or(0.0);
29
30    (input_tokens / 1000.0) * input_cost + (output_tokens / 1000.0) * output_cost
31}
32
33#[async_trait]
34impl RoutingStrategy for CostBased {
35    fn name(&self) -> &str {
36        "cost-based"
37    }
38
39    async fn select<'a>(
40        &self,
41        _model: &str,
42        candidates: &'a [Arc<Deployment>],
43        state: &dyn RoutingState,
44        request: &RoutingContext,
45    ) -> Result<&'a Arc<Deployment>, RoutingError> {
46        if candidates.is_empty() {
47            return Err(RoutingError::NoDeployments("empty candidates".into()));
48        }
49
50        let ids: Vec<&str> = candidates.iter().map(|d| d.id.as_str()).collect();
51        let all_metrics = state.get_all_metrics(&ids).await?;
52
53        let mut eligible: Vec<usize> = Vec::new();
54
55        for (i, deployment) in candidates.iter().enumerate() {
56            if state.is_cooled_down(&deployment.id).await? {
57                continue;
58            }
59
60            let metrics = all_metrics.get(&deployment.id).cloned().unwrap_or_default();
61
62            if let Some(rpm_limit) = deployment.rpm_limit {
63                if metrics.rpm_used >= rpm_limit {
64                    continue;
65                }
66            }
67
68            if let Some(tpm_limit) = deployment.tpm_limit {
69                if metrics.tpm_used >= tpm_limit {
70                    continue;
71                }
72            }
73
74            eligible.push(i);
75        }
76
77        if eligible.is_empty() {
78            return Err(RoutingError::NoDeployments(
79                "no eligible deployments after filtering".into(),
80            ));
81        }
82
83        let mut best_idx = eligible[0];
84        let mut best_cost = estimate_cost(&candidates[best_idx], request);
85        let mut best_weight = candidates[best_idx].weight;
86
87        for &i in eligible.iter().skip(1) {
88            let cost = estimate_cost(&candidates[i], request);
89            let weight = candidates[i].weight;
90
91            if cost < best_cost || ((cost - best_cost).abs() < f64::EPSILON && weight > best_weight)
92            {
93                best_idx = i;
94                best_cost = cost;
95                best_weight = weight;
96            }
97        }
98
99        Ok(&candidates[best_idx])
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deployment::Deployment;
    use crate::strategy::weighted_shuffle::tests_helpers::MockState;
    use hyperinfer_core::Provider;

    /// Builds an unpriced `test-model` deployment with the given id and weight.
    fn base_deployment(id: &str, weight: u32) -> Deployment {
        let mut deployment = Deployment::new(
            "test-model".to_string(),
            Provider::OpenAI,
            "gpt-4".to_string(),
            format!("key-{}", id),
        );
        deployment.weight = weight;
        deployment.id = id.to_string();
        deployment
    }

    /// Unpriced deployment, shared via `Arc` as the strategy receives it.
    fn make_deployment(id: &str, weight: u32) -> Arc<Deployment> {
        Arc::new(base_deployment(id, weight))
    }

    /// Deployment with explicit per-1k input and output prices.
    fn make_deployment_with_costs(
        id: &str,
        weight: u32,
        input_cost: f64,
        output_cost: f64,
    ) -> Arc<Deployment> {
        let mut deployment = base_deployment(id, weight);
        deployment.input_cost_per_1k = Some(input_cost);
        deployment.output_cost_per_1k = Some(output_cost);
        Arc::new(deployment)
    }

    /// Request context estimating 1000 input and 1000 output tokens.
    fn thousand_each_ctx() -> RoutingContext {
        RoutingContext {
            estimated_input_tokens: Some(1000),
            estimated_output_tokens: Some(1000),
            ..Default::default()
        }
    }

    /// Runs `CostBased` against a fresh `MockState` and returns the chosen deployment id.
    async fn pick(candidates: &[Arc<Deployment>], ctx: &RoutingContext) -> String {
        let state = MockState::new();
        CostBased::new()
            .select("test-model", candidates, &state, ctx)
            .await
            .unwrap()
            .id
            .clone()
    }

    #[tokio::test]
    async fn test_selects_cheapest() {
        let candidates = vec![
            make_deployment_with_costs("d1", 1, 0.03, 0.06),
            make_deployment_with_costs("d2", 1, 0.01, 0.02),
        ];
        assert_eq!(pick(&candidates, &thousand_each_ctx()).await, "d2");
    }

    #[tokio::test]
    async fn test_equal_cost_prefers_higher_weight() {
        let candidates = vec![
            make_deployment_with_costs("d1", 5, 0.01, 0.02),
            make_deployment_with_costs("d2", 1, 0.01, 0.02),
        ];
        assert_eq!(pick(&candidates, &thousand_each_ctx()).await, "d1");
    }

    #[tokio::test]
    async fn test_no_cost_data_returns_first() {
        let candidates = vec![make_deployment("d1", 1), make_deployment("d2", 1)];
        assert_eq!(pick(&candidates, &thousand_each_ctx()).await, "d1");
    }

    #[test]
    fn test_estimate_cost_calculation() {
        let priced = make_deployment_with_costs("d1", 1, 0.01, 0.03);
        let ctx = RoutingContext {
            estimated_input_tokens: Some(2000),
            estimated_output_tokens: Some(1000),
            ..Default::default()
        };
        let expected = 0.05;
        assert!((estimate_cost(&priced, &ctx) - expected).abs() < f64::EPSILON);
    }

    #[test]
    fn test_estimate_cost_no_prices() {
        let unpriced = make_deployment("d1", 1);
        let ctx = RoutingContext {
            estimated_input_tokens: Some(2000),
            estimated_output_tokens: Some(1000),
            ..Default::default()
        };
        assert!(estimate_cost(&unpriced, &ctx).abs() < f64::EPSILON);
    }
}