hyperinfer_router/strategy/
cost_based.rs1use super::{RoutingContext, RoutingState, RoutingStrategy};
2use crate::deployment::Deployment;
3use crate::error::RoutingError;
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct CostBased;
10
11impl Default for CostBased {
12 fn default() -> Self {
13 Self
14 }
15}
16
17impl CostBased {
18 pub fn new() -> Self {
19 Self
20 }
21}
22
23fn estimate_cost(deployment: &Deployment, ctx: &RoutingContext) -> f64 {
24 let input_tokens = ctx.estimated_input_tokens.unwrap_or(0) as f64;
25 let output_tokens = ctx.estimated_output_tokens.unwrap_or(0) as f64;
26
27 let input_cost = deployment.input_cost_per_1k.unwrap_or(0.0);
28 let output_cost = deployment.output_cost_per_1k.unwrap_or(0.0);
29
30 (input_tokens / 1000.0) * input_cost + (output_tokens / 1000.0) * output_cost
31}
32
33#[async_trait]
34impl RoutingStrategy for CostBased {
35 fn name(&self) -> &str {
36 "cost-based"
37 }
38
39 async fn select<'a>(
40 &self,
41 _model: &str,
42 candidates: &'a [Arc<Deployment>],
43 state: &dyn RoutingState,
44 request: &RoutingContext,
45 ) -> Result<&'a Arc<Deployment>, RoutingError> {
46 if candidates.is_empty() {
47 return Err(RoutingError::NoDeployments("empty candidates".into()));
48 }
49
50 let ids: Vec<&str> = candidates.iter().map(|d| d.id.as_str()).collect();
51 let all_metrics = state.get_all_metrics(&ids).await?;
52
53 let mut eligible: Vec<usize> = Vec::new();
54
55 for (i, deployment) in candidates.iter().enumerate() {
56 if state.is_cooled_down(&deployment.id).await? {
57 continue;
58 }
59
60 let metrics = all_metrics.get(&deployment.id).cloned().unwrap_or_default();
61
62 if let Some(rpm_limit) = deployment.rpm_limit {
63 if metrics.rpm_used >= rpm_limit {
64 continue;
65 }
66 }
67
68 if let Some(tpm_limit) = deployment.tpm_limit {
69 if metrics.tpm_used >= tpm_limit {
70 continue;
71 }
72 }
73
74 eligible.push(i);
75 }
76
77 if eligible.is_empty() {
78 return Err(RoutingError::NoDeployments(
79 "no eligible deployments after filtering".into(),
80 ));
81 }
82
83 let mut best_idx = eligible[0];
84 let mut best_cost = estimate_cost(&candidates[best_idx], request);
85 let mut best_weight = candidates[best_idx].weight;
86
87 for &i in eligible.iter().skip(1) {
88 let cost = estimate_cost(&candidates[i], request);
89 let weight = candidates[i].weight;
90
91 if cost < best_cost || ((cost - best_cost).abs() < f64::EPSILON && weight > best_weight)
92 {
93 best_idx = i;
94 best_cost = cost;
95 best_weight = weight;
96 }
97 }
98
99 Ok(&candidates[best_idx])
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deployment::Deployment;
    use crate::strategy::weighted_shuffle::tests_helpers::MockState;
    use hyperinfer_core::Provider;

    /// Unpriced `Deployment` for "test-model" with the given id and weight.
    fn base_deployment(id: &str, weight: u32) -> Deployment {
        let mut deployment = Deployment::new(
            "test-model".to_string(),
            Provider::OpenAI,
            "gpt-4".to_string(),
            format!("key-{}", id),
        );
        deployment.weight = weight;
        deployment.id = id.to_string();
        deployment
    }

    /// Shared deployment with no per-1k prices set.
    fn make_deployment(id: &str, weight: u32) -> Arc<Deployment> {
        Arc::new(base_deployment(id, weight))
    }

    /// Shared deployment with explicit input/output prices per 1k tokens.
    fn make_deployment_with_costs(
        id: &str,
        weight: u32,
        input_cost: f64,
        output_cost: f64,
    ) -> Arc<Deployment> {
        let mut deployment = base_deployment(id, weight);
        deployment.input_cost_per_1k = Some(input_cost);
        deployment.output_cost_per_1k = Some(output_cost);
        Arc::new(deployment)
    }

    /// Routing context estimating 1000 input and 1000 output tokens.
    fn thousand_token_context() -> RoutingContext {
        RoutingContext {
            estimated_input_tokens: Some(1000),
            estimated_output_tokens: Some(1000),
            ..Default::default()
        }
    }

    /// Runs `CostBased::select` with a fresh `MockState` and returns the chosen id.
    async fn selected_id(candidates: &[Arc<Deployment>]) -> String {
        let state = MockState::new();
        let ctx = thousand_token_context();
        let chosen = CostBased::new()
            .select("test-model", candidates, &state, &ctx)
            .await
            .unwrap();
        chosen.id.clone()
    }

    #[tokio::test]
    async fn test_selects_cheapest() {
        let candidates = [
            make_deployment_with_costs("d1", 1, 0.03, 0.06),
            make_deployment_with_costs("d2", 1, 0.01, 0.02),
        ];
        assert_eq!(selected_id(&candidates).await, "d2");
    }

    #[tokio::test]
    async fn test_equal_cost_prefers_higher_weight() {
        let candidates = [
            make_deployment_with_costs("d1", 5, 0.01, 0.02),
            make_deployment_with_costs("d2", 1, 0.01, 0.02),
        ];
        assert_eq!(selected_id(&candidates).await, "d1");
    }

    #[tokio::test]
    async fn test_no_cost_data_returns_first() {
        let candidates = [make_deployment("d1", 1), make_deployment("d2", 1)];
        assert_eq!(selected_id(&candidates).await, "d1");
    }

    #[test]
    fn test_estimate_cost_calculation() {
        let deployment = make_deployment_with_costs("d1", 1, 0.01, 0.03);
        let ctx = RoutingContext {
            estimated_input_tokens: Some(2000),
            estimated_output_tokens: Some(1000),
            ..Default::default()
        };

        let cost = estimate_cost(&deployment, &ctx);
        assert!((cost - 0.05).abs() < f64::EPSILON);
    }

    #[test]
    fn test_estimate_cost_no_prices() {
        let deployment = make_deployment("d1", 1);
        let ctx = RoutingContext {
            estimated_input_tokens: Some(2000),
            estimated_output_tokens: Some(1000),
            ..Default::default()
        };

        let cost = estimate_cost(&deployment, &ctx);
        assert!(cost.abs() < f64::EPSILON);
    }
}