1use crate::deployment::{Deployment, DeploymentPool};
2use crate::error::RoutingError;
3use crate::fallback::{ErrorKind, FallbackConfig};
4use crate::strategy::{RoutingContext, RoutingState, RoutingStrategy};
5use std::collections::{HashMap, HashSet};
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::time::Instant;
10use tokio::sync::RwLock;
11
/// Upper bounds applied to a single `RouterEngine::route_with_fallback` call.
#[derive(Debug, Clone)]
pub struct GlobalLimits {
    /// Maximum number of executor attempts, counted across every model tried.
    pub max_total_attempts: u32,
    /// Wall-clock budget for the whole call, in milliseconds.
    pub global_timeout_ms: u64,
}

impl Default for GlobalLimits {
    /// Returns limits of 10 attempts and a 60 second overall timeout.
    fn default() -> Self {
        const DEFAULT_MAX_TOTAL_ATTEMPTS: u32 = 10;
        const DEFAULT_GLOBAL_TIMEOUT_MS: u64 = 60_000;

        Self {
            max_total_attempts: DEFAULT_MAX_TOTAL_ATTEMPTS,
            global_timeout_ms: DEFAULT_GLOBAL_TIMEOUT_MS,
        }
    }
}
26
27#[derive(Debug, Clone)]
28pub struct RoutingResult {
29 pub deployment: Arc<Deployment>,
30 pub attempt: u32,
31 pub fallback_chain: Vec<String>,
32}
33
34pub struct RouterEngine {
35 pool: Arc<RwLock<DeploymentPool>>,
36 strategies: Arc<RwLock<HashMap<String, Box<dyn RoutingStrategy>>>>,
37 default_strategy: Arc<RwLock<String>>,
38 fallback_config: Arc<RwLock<FallbackConfig>>,
39 global_limits: GlobalLimits,
40 model_aliases: Arc<RwLock<HashMap<String, String>>>,
41 routing_groups: Arc<RwLock<HashMap<String, String>>>,
42}
43
44impl RouterEngine {
45 pub fn new(global_limits: GlobalLimits) -> Self {
46 Self {
47 pool: Arc::new(RwLock::new(DeploymentPool::new())),
48 strategies: Arc::new(RwLock::new(HashMap::new())),
49 default_strategy: Arc::new(RwLock::new("weighted-shuffle".to_string())),
50 fallback_config: Arc::new(RwLock::new(FallbackConfig::new())),
51 global_limits,
52 model_aliases: Arc::new(RwLock::new(HashMap::new())),
53 routing_groups: Arc::new(RwLock::new(HashMap::new())),
54 }
55 }
56
57 pub async fn add_deployment(&self, deployment: Deployment) {
58 let mut pool = self.pool.write().await;
59 pool.add(deployment);
60 }
61
62 pub async fn remove_deployment(&self, deployment_id: &str) {
63 let mut pool = self.pool.write().await;
64 pool.remove(deployment_id);
65 }
66
67 pub async fn rebuild_pool(&self, deployments: Vec<Deployment>) {
68 let mut pool = self.pool.write().await;
69 *pool = DeploymentPool::new();
70 for d in deployments {
71 pool.add(d);
72 }
73 }
74
75 pub async fn register_strategy(&self, strategy: Box<dyn RoutingStrategy>) {
76 let name = strategy.name().to_string();
77 let mut strategies = self.strategies.write().await;
78 strategies.insert(name, strategy);
79 }
80
81 pub async fn set_default_strategy(&self, name: &str) {
82 let mut ds = self.default_strategy.write().await;
83 *ds = name.to_string();
84 }
85
86 pub async fn set_fallback_config(&self, config: FallbackConfig) {
87 let mut fc = self.fallback_config.write().await;
88 *fc = config;
89 }
90
91 pub async fn set_alias(&self, alias: &str, target: &str) {
92 let mut aliases = self.model_aliases.write().await;
93 aliases.insert(alias.to_string(), target.to_string());
94 }
95
96 pub async fn set_routing_group(&self, model: &str, strategy_name: &str) {
97 let mut groups = self.routing_groups.write().await;
98 groups.insert(model.to_string(), strategy_name.to_string());
99 }
100
101 pub async fn record_success(
102 &self,
103 deployment_id: &str,
104 latency_ms: f64,
105 tokens: u64,
106 state: &dyn RoutingState,
107 ) {
108 let _ = state
109 .record_request_success(deployment_id, latency_ms, tokens)
110 .await;
111 }
112
113 pub async fn record_failure(&self, deployment_id: &str, state: &dyn RoutingState) {
114 let _ = state.record_request_failure(deployment_id).await;
115 }
116
117 pub fn resolve_alias(model: &str, aliases: &HashMap<String, String>) -> String {
118 aliases
119 .get(model)
120 .cloned()
121 .unwrap_or_else(|| model.to_string())
122 }
123
124 pub async fn select_deployment(
125 &self,
126 model: &str,
127 state: &dyn RoutingState,
128 ctx: &RoutingContext,
129 ) -> Result<RoutingResult, RoutingError> {
130 let aliases = self.model_aliases.read().await;
131 let resolved_model = Self::resolve_alias(model, &aliases);
132 drop(aliases);
133
134 let pool = self.pool.read().await;
135 let candidates = pool
136 .get(&resolved_model)
137 .ok_or_else(|| RoutingError::NoDeployments(resolved_model.clone()))?;
138
139 let candidates_vec: Vec<Arc<Deployment>> = candidates.to_vec();
140
141 if candidates_vec.is_empty() {
142 return Err(RoutingError::NoDeployments(resolved_model.clone()));
143 }
144
145 let strategy_name = {
146 let groups = self.routing_groups.read().await;
147 if let Some(name) = groups.get(&resolved_model) {
148 name.clone()
149 } else {
150 let ds = self.default_strategy.read().await;
151 ds.clone()
152 }
153 };
154
155 let strategies = self.strategies.read().await;
156 let strategy = strategies.get(&strategy_name).ok_or_else(|| {
157 RoutingError::StrategyError(format!("strategy '{}' not found", strategy_name))
158 })?;
159
160 let selected = strategy
161 .select(&resolved_model, &candidates_vec, state, ctx)
162 .await?;
163
164 Ok(RoutingResult {
165 deployment: Arc::clone(selected),
166 attempt: 1,
167 fallback_chain: vec![resolved_model],
168 })
169 }
170
171 pub async fn route_with_fallback<T: Send + 'static>(
172 &self,
173 model: &str,
174 state: &dyn RoutingState,
175 ctx: &RoutingContext,
176 executor: impl Fn(
177 Arc<Deployment>,
178 )
179 -> Pin<Box<dyn Future<Output = Result<T, hyperinfer_core::HyperInferError>> + Send>>
180 + Send,
181 ) -> Result<(RoutingResult, T), RoutingError> {
182 let start_time = Instant::now();
183 let mut total_attempts: u32 = 0;
184
185 let aliases = self.model_aliases.read().await;
186 let initial_model = Self::resolve_alias(model, &aliases);
187 drop(aliases);
188
189 let mut current_model = initial_model.clone();
190 let mut fallback_chain = vec![current_model.clone()];
191 let mut excluded_ids: HashSet<String> = HashSet::new();
192 let mut fallback_models_tried: HashSet<String> = HashSet::new();
193
194 let fallback_config = self.fallback_config.read().await.clone();
195
196 loop {
197 if total_attempts >= self.global_limits.max_total_attempts {
198 return Err(RoutingError::MaxAttemptsExceeded {
199 attempts: total_attempts,
200 });
201 }
202
203 if start_time.elapsed().as_millis() as u64 >= self.global_limits.global_timeout_ms {
204 return Err(RoutingError::GlobalTimeout {
205 timeout_ms: self.global_limits.global_timeout_ms,
206 });
207 }
208
209 let candidates = {
210 let pool = self.pool.read().await;
211 pool.get(¤t_model).map(|c| c.to_vec())
212 };
213
214 let candidates = match candidates {
215 Some(c) if !c.is_empty() => c,
216 _ => {
217 let fallbacks =
218 fallback_config.get_fallbacks(¤t_model, &ErrorKind::Other);
219 let next = fallbacks
220 .into_iter()
221 .find(|m| !fallback_models_tried.contains(m));
222 match next {
223 Some(m) => {
224 fallback_models_tried.insert(m.clone());
225 current_model = m.clone();
226 fallback_chain.push(m);
227 continue;
228 }
229 None => {
230 return Err(RoutingError::AllDeploymentsFailed(initial_model));
231 }
232 }
233 }
234 };
235
236 let eligible: Vec<Arc<Deployment>> = candidates
237 .into_iter()
238 .filter(|d| !excluded_ids.contains(&d.id))
239 .collect();
240
241 if eligible.is_empty() {
242 let fallbacks = fallback_config.get_fallbacks(¤t_model, &ErrorKind::Other);
243 let next = fallbacks
244 .into_iter()
245 .find(|m| !fallback_models_tried.contains(m));
246 match next {
247 Some(m) => {
248 fallback_models_tried.insert(m.clone());
249 current_model = m.clone();
250 fallback_chain.push(m);
251 continue;
252 }
253 None => {
254 return Err(RoutingError::AllDeploymentsFailed(initial_model));
255 }
256 }
257 }
258
259 let strategy_name = {
260 let groups = self.routing_groups.read().await;
261 if let Some(name) = groups.get(¤t_model) {
262 name.clone()
263 } else {
264 let ds = self.default_strategy.read().await;
265 ds.clone()
266 }
267 };
268
269 let selected = {
270 let strategies = self.strategies.read().await;
271 let strategy = strategies.get(&strategy_name).ok_or_else(|| {
272 RoutingError::StrategyError(format!("strategy '{}' not found", strategy_name))
273 })?;
274 strategy
275 .select(¤t_model, &eligible, state, ctx)
276 .await?
277 .clone()
278 };
279
280 total_attempts += 1;
281 let _ = state.record_request_start(&selected.id).await;
282
283 let deployment_id = selected.id.clone();
284 let executor_future = executor(Arc::clone(&selected));
285 let join_result = tokio::spawn(executor_future).await;
286
287 match join_result {
288 Ok(Ok(value)) => {
289 return Ok((
290 RoutingResult {
291 deployment: selected,
292 attempt: total_attempts,
293 fallback_chain,
294 },
295 value,
296 ));
297 }
298 Ok(Err(err)) => {
299 let error_kind = ErrorKind::classify(&err);
300 let _ = state.record_request_failure(&deployment_id).await;
301 excluded_ids.insert(deployment_id);
302
303 let same_model_candidates = {
304 let pool = self.pool.read().await;
305 pool.get(¤t_model).map(|c| c.to_vec())
306 };
307
308 let same_model_remaining = same_model_candidates
309 .map(|c| {
310 c.into_iter()
311 .filter(|d| !excluded_ids.contains(&d.id))
312 .count()
313 })
314 .unwrap_or(0);
315
316 if same_model_remaining > 0
317 && total_attempts < self.global_limits.max_total_attempts
318 {
319 continue;
320 }
321
322 let fallbacks = fallback_config.get_fallbacks(¤t_model, &error_kind);
323 let next = fallbacks
324 .into_iter()
325 .find(|m| !fallback_models_tried.contains(m));
326 match next {
327 Some(m) => {
328 fallback_models_tried.insert(m.clone());
329 current_model = m.clone();
330 fallback_chain.push(m);
331 excluded_ids.clear();
332 }
333 None => {
334 return Err(RoutingError::AllDeploymentsFailed(initial_model));
335 }
336 }
337 }
338 Err(_panic) => {
339 let _ = state.record_request_failure(&deployment_id).await;
340 excluded_ids.insert(deployment_id);
341 return Err(RoutingError::ExecutorPanic);
342 }
343 }
344 }
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deployment::Deployment;
    use crate::strategy::weighted_shuffle::tests_helpers::MockState;
    use crate::strategy::weighted_shuffle::WeightedShuffle;
    use crate::strategy::RoutingStrategyExt;
    use hyperinfer_core::Provider;
    use std::sync::Arc;

    /// Builds an OpenAI deployment for `model_name` whose id is forced to `id`
    /// so tests can assert on it.
    fn make_deployment(model_name: &str, id: &str) -> Deployment {
        let mut deployment = Deployment::new(
            model_name.to_string(),
            Provider::OpenAI,
            model_name.to_string(),
            format!("key-{}", id),
        );
        deployment.id = id.to_string();
        deployment
    }

    /// Creates an engine with `limits` and the weighted-shuffle strategy
    /// already registered.
    async fn engine_with_shuffle(limits: GlobalLimits) -> RouterEngine {
        let engine = RouterEngine::new(limits);
        engine
            .register_strategy(WeightedShuffle::new().boxed())
            .await;
        engine
    }

    /// A single registered deployment is selected on the first attempt.
    #[tokio::test]
    async fn test_select_deployment_basic() {
        let engine = engine_with_shuffle(GlobalLimits::default()).await;
        engine.add_deployment(make_deployment("gpt-4", "d1")).await;

        let state = MockState::new();
        let ctx = RoutingContext::default();
        let picked = engine
            .select_deployment("gpt-4", &state, &ctx)
            .await
            .unwrap();

        assert_eq!(picked.deployment.id, "d1");
        assert_eq!(picked.attempt, 1);
        assert_eq!(picked.fallback_chain, vec!["gpt-4"]);
    }

    /// Selecting through an alias lands on the aliased model's deployment.
    #[tokio::test]
    async fn test_alias_resolution() {
        let engine = engine_with_shuffle(GlobalLimits::default()).await;
        engine
            .add_deployment(make_deployment("gpt-4-turbo", "d1"))
            .await;
        engine.set_alias("smart", "gpt-4-turbo").await;

        let state = MockState::new();
        let ctx = RoutingContext::default();
        let picked = engine
            .select_deployment("smart", &state, &ctx)
            .await
            .unwrap();

        assert_eq!(picked.deployment.id, "d1");
        assert_eq!(picked.deployment.model_name, "gpt-4-turbo");
    }

    /// An unknown model yields `NoDeployments`.
    #[tokio::test]
    async fn test_no_deployments_error() {
        let engine = engine_with_shuffle(GlobalLimits::default()).await;

        let state = MockState::new();
        let ctx = RoutingContext::default();
        let outcome = engine.select_deployment("nonexistent", &state, &ctx).await;

        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            RoutingError::NoDeployments(_)
        ));
    }

    /// A model with an explicit routing group is routed via the named strategy.
    #[tokio::test]
    async fn test_routing_group_strategy_selection() {
        let engine = engine_with_shuffle(GlobalLimits::default()).await;
        // Registering a second instance under the same name replaces the first.
        engine
            .register_strategy(Box::new(WeightedShuffle::new()))
            .await;
        engine.add_deployment(make_deployment("gpt-4", "d1")).await;

        engine.set_routing_group("gpt-4", "weighted-shuffle").await;

        let state = MockState::new();
        let ctx = RoutingContext::default();
        let picked = engine
            .select_deployment("gpt-4", &state, &ctx)
            .await
            .unwrap();
        assert_eq!(picked.deployment.id, "d1");
    }

    /// With every executor call failing, routing stops with an error once the
    /// deployments or the attempt budget run out.
    #[tokio::test]
    async fn test_global_limits_max_attempts() {
        /// Executor that fails with an API error regardless of deployment.
        fn always_fail(
            _d: Arc<Deployment>,
        ) -> Pin<Box<dyn Future<Output = Result<(), hyperinfer_core::HyperInferError>> + Send>>
        {
            Box::pin(async {
                Err::<(), _>(hyperinfer_core::HyperInferError::ApiError {
                    status: 500,
                    message: "fail".into(),
                })
            })
        }

        let engine = engine_with_shuffle(GlobalLimits {
            max_total_attempts: 2,
            global_timeout_ms: 60_000,
        })
        .await;
        for id in ["d1", "d2"] {
            engine.add_deployment(make_deployment("gpt-4", id)).await;
        }

        let state = MockState::new();
        let ctx = RoutingContext::default();

        let outcome = engine
            .route_with_fallback::<()>("gpt-4", &state, &ctx, always_fail)
            .await;

        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        assert!(
            matches!(
                err,
                RoutingError::AllDeploymentsFailed(_) | RoutingError::MaxAttemptsExceeded { .. }
            ),
            "expected AllDeploymentsFailed or MaxAttemptsExceeded, got: {:?}",
            err
        );
    }

    /// `record_success` completes against the mock state without panicking.
    #[tokio::test]
    async fn test_record_success_passthrough() {
        let engine = RouterEngine::new(GlobalLimits::default());
        let state = MockState::new();
        engine.record_success("d1", 100.0, 500, &state).await;
    }

    /// `record_failure` completes against the mock state without panicking.
    #[tokio::test]
    async fn test_record_failure_passthrough() {
        let engine = RouterEngine::new(GlobalLimits::default());
        let state = MockState::new();
        engine.record_failure("d1", &state).await;
    }
}