Skip to main content

hyperinfer_router/
engine.rs

1use crate::deployment::{Deployment, DeploymentPool};
2use crate::error::RoutingError;
3use crate::fallback::{ErrorKind, FallbackConfig};
4use crate::strategy::{RoutingContext, RoutingState, RoutingStrategy};
5use std::collections::{HashMap, HashSet};
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::time::Instant;
10use tokio::sync::RwLock;
11
/// Per-request budget applied by `RouterEngine::route_with_fallback`.
#[derive(Debug, Clone)]
pub struct GlobalLimits {
    /// Maximum number of executor calls (across all models) for one request.
    pub max_total_attempts: u32,
    /// Wall-clock budget for one request, in milliseconds.
    pub global_timeout_ms: u64,
}

impl Default for GlobalLimits {
    /// Ten attempts within a one-minute budget.
    fn default() -> Self {
        let max_total_attempts = 10;
        let global_timeout_ms = 60 * 1_000;
        GlobalLimits {
            max_total_attempts,
            global_timeout_ms,
        }
    }
}
26
27#[derive(Debug, Clone)]
28pub struct RoutingResult {
29    pub deployment: Arc<Deployment>,
30    pub attempt: u32,
31    pub fallback_chain: Vec<String>,
32}
33
34pub struct RouterEngine {
35    pool: Arc<RwLock<DeploymentPool>>,
36    strategies: Arc<RwLock<HashMap<String, Box<dyn RoutingStrategy>>>>,
37    default_strategy: Arc<RwLock<String>>,
38    fallback_config: Arc<RwLock<FallbackConfig>>,
39    global_limits: GlobalLimits,
40    model_aliases: Arc<RwLock<HashMap<String, String>>>,
41    routing_groups: Arc<RwLock<HashMap<String, String>>>,
42}
43
44impl RouterEngine {
45    pub fn new(global_limits: GlobalLimits) -> Self {
46        Self {
47            pool: Arc::new(RwLock::new(DeploymentPool::new())),
48            strategies: Arc::new(RwLock::new(HashMap::new())),
49            default_strategy: Arc::new(RwLock::new("weighted-shuffle".to_string())),
50            fallback_config: Arc::new(RwLock::new(FallbackConfig::new())),
51            global_limits,
52            model_aliases: Arc::new(RwLock::new(HashMap::new())),
53            routing_groups: Arc::new(RwLock::new(HashMap::new())),
54        }
55    }
56
57    pub async fn add_deployment(&self, deployment: Deployment) {
58        let mut pool = self.pool.write().await;
59        pool.add(deployment);
60    }
61
62    pub async fn remove_deployment(&self, deployment_id: &str) {
63        let mut pool = self.pool.write().await;
64        pool.remove(deployment_id);
65    }
66
67    pub async fn rebuild_pool(&self, deployments: Vec<Deployment>) {
68        let mut pool = self.pool.write().await;
69        *pool = DeploymentPool::new();
70        for d in deployments {
71            pool.add(d);
72        }
73    }
74
75    pub async fn register_strategy(&self, strategy: Box<dyn RoutingStrategy>) {
76        let name = strategy.name().to_string();
77        let mut strategies = self.strategies.write().await;
78        strategies.insert(name, strategy);
79    }
80
81    pub async fn set_default_strategy(&self, name: &str) {
82        let mut ds = self.default_strategy.write().await;
83        *ds = name.to_string();
84    }
85
86    pub async fn set_fallback_config(&self, config: FallbackConfig) {
87        let mut fc = self.fallback_config.write().await;
88        *fc = config;
89    }
90
91    pub async fn set_alias(&self, alias: &str, target: &str) {
92        let mut aliases = self.model_aliases.write().await;
93        aliases.insert(alias.to_string(), target.to_string());
94    }
95
96    pub async fn set_routing_group(&self, model: &str, strategy_name: &str) {
97        let mut groups = self.routing_groups.write().await;
98        groups.insert(model.to_string(), strategy_name.to_string());
99    }
100
101    pub async fn record_success(
102        &self,
103        deployment_id: &str,
104        latency_ms: f64,
105        tokens: u64,
106        state: &dyn RoutingState,
107    ) {
108        let _ = state
109            .record_request_success(deployment_id, latency_ms, tokens)
110            .await;
111    }
112
113    pub async fn record_failure(&self, deployment_id: &str, state: &dyn RoutingState) {
114        let _ = state.record_request_failure(deployment_id).await;
115    }
116
117    pub fn resolve_alias(model: &str, aliases: &HashMap<String, String>) -> String {
118        aliases
119            .get(model)
120            .cloned()
121            .unwrap_or_else(|| model.to_string())
122    }
123
124    pub async fn select_deployment(
125        &self,
126        model: &str,
127        state: &dyn RoutingState,
128        ctx: &RoutingContext,
129    ) -> Result<RoutingResult, RoutingError> {
130        let aliases = self.model_aliases.read().await;
131        let resolved_model = Self::resolve_alias(model, &aliases);
132        drop(aliases);
133
134        let pool = self.pool.read().await;
135        let candidates = pool
136            .get(&resolved_model)
137            .ok_or_else(|| RoutingError::NoDeployments(resolved_model.clone()))?;
138
139        let candidates_vec: Vec<Arc<Deployment>> = candidates.to_vec();
140
141        if candidates_vec.is_empty() {
142            return Err(RoutingError::NoDeployments(resolved_model.clone()));
143        }
144
145        let strategy_name = {
146            let groups = self.routing_groups.read().await;
147            if let Some(name) = groups.get(&resolved_model) {
148                name.clone()
149            } else {
150                let ds = self.default_strategy.read().await;
151                ds.clone()
152            }
153        };
154
155        let strategies = self.strategies.read().await;
156        let strategy = strategies.get(&strategy_name).ok_or_else(|| {
157            RoutingError::StrategyError(format!("strategy '{}' not found", strategy_name))
158        })?;
159
160        let selected = strategy
161            .select(&resolved_model, &candidates_vec, state, ctx)
162            .await?;
163
164        Ok(RoutingResult {
165            deployment: Arc::clone(selected),
166            attempt: 1,
167            fallback_chain: vec![resolved_model],
168        })
169    }
170
171    pub async fn route_with_fallback<T: Send + 'static>(
172        &self,
173        model: &str,
174        state: &dyn RoutingState,
175        ctx: &RoutingContext,
176        executor: impl Fn(
177                Arc<Deployment>,
178            )
179                -> Pin<Box<dyn Future<Output = Result<T, hyperinfer_core::HyperInferError>> + Send>>
180            + Send,
181    ) -> Result<(RoutingResult, T), RoutingError> {
182        let start_time = Instant::now();
183        let mut total_attempts: u32 = 0;
184
185        let aliases = self.model_aliases.read().await;
186        let initial_model = Self::resolve_alias(model, &aliases);
187        drop(aliases);
188
189        let mut current_model = initial_model.clone();
190        let mut fallback_chain = vec![current_model.clone()];
191        let mut excluded_ids: HashSet<String> = HashSet::new();
192        let mut fallback_models_tried: HashSet<String> = HashSet::new();
193
194        let fallback_config = self.fallback_config.read().await.clone();
195
196        loop {
197            if total_attempts >= self.global_limits.max_total_attempts {
198                return Err(RoutingError::MaxAttemptsExceeded {
199                    attempts: total_attempts,
200                });
201            }
202
203            if start_time.elapsed().as_millis() as u64 >= self.global_limits.global_timeout_ms {
204                return Err(RoutingError::GlobalTimeout {
205                    timeout_ms: self.global_limits.global_timeout_ms,
206                });
207            }
208
209            let candidates = {
210                let pool = self.pool.read().await;
211                pool.get(&current_model).map(|c| c.to_vec())
212            };
213
214            let candidates = match candidates {
215                Some(c) if !c.is_empty() => c,
216                _ => {
217                    let fallbacks =
218                        fallback_config.get_fallbacks(&current_model, &ErrorKind::Other);
219                    let next = fallbacks
220                        .into_iter()
221                        .find(|m| !fallback_models_tried.contains(m));
222                    match next {
223                        Some(m) => {
224                            fallback_models_tried.insert(m.clone());
225                            current_model = m.clone();
226                            fallback_chain.push(m);
227                            continue;
228                        }
229                        None => {
230                            return Err(RoutingError::AllDeploymentsFailed(initial_model));
231                        }
232                    }
233                }
234            };
235
236            let eligible: Vec<Arc<Deployment>> = candidates
237                .into_iter()
238                .filter(|d| !excluded_ids.contains(&d.id))
239                .collect();
240
241            if eligible.is_empty() {
242                let fallbacks = fallback_config.get_fallbacks(&current_model, &ErrorKind::Other);
243                let next = fallbacks
244                    .into_iter()
245                    .find(|m| !fallback_models_tried.contains(m));
246                match next {
247                    Some(m) => {
248                        fallback_models_tried.insert(m.clone());
249                        current_model = m.clone();
250                        fallback_chain.push(m);
251                        continue;
252                    }
253                    None => {
254                        return Err(RoutingError::AllDeploymentsFailed(initial_model));
255                    }
256                }
257            }
258
259            let strategy_name = {
260                let groups = self.routing_groups.read().await;
261                if let Some(name) = groups.get(&current_model) {
262                    name.clone()
263                } else {
264                    let ds = self.default_strategy.read().await;
265                    ds.clone()
266                }
267            };
268
269            let selected = {
270                let strategies = self.strategies.read().await;
271                let strategy = strategies.get(&strategy_name).ok_or_else(|| {
272                    RoutingError::StrategyError(format!("strategy '{}' not found", strategy_name))
273                })?;
274                strategy
275                    .select(&current_model, &eligible, state, ctx)
276                    .await?
277                    .clone()
278            };
279
280            total_attempts += 1;
281            let _ = state.record_request_start(&selected.id).await;
282
283            let deployment_id = selected.id.clone();
284            let executor_future = executor(Arc::clone(&selected));
285            let join_result = tokio::spawn(executor_future).await;
286
287            match join_result {
288                Ok(Ok(value)) => {
289                    return Ok((
290                        RoutingResult {
291                            deployment: selected,
292                            attempt: total_attempts,
293                            fallback_chain,
294                        },
295                        value,
296                    ));
297                }
298                Ok(Err(err)) => {
299                    let error_kind = ErrorKind::classify(&err);
300                    let _ = state.record_request_failure(&deployment_id).await;
301                    excluded_ids.insert(deployment_id);
302
303                    let same_model_candidates = {
304                        let pool = self.pool.read().await;
305                        pool.get(&current_model).map(|c| c.to_vec())
306                    };
307
308                    let same_model_remaining = same_model_candidates
309                        .map(|c| {
310                            c.into_iter()
311                                .filter(|d| !excluded_ids.contains(&d.id))
312                                .count()
313                        })
314                        .unwrap_or(0);
315
316                    if same_model_remaining > 0
317                        && total_attempts < self.global_limits.max_total_attempts
318                    {
319                        continue;
320                    }
321
322                    let fallbacks = fallback_config.get_fallbacks(&current_model, &error_kind);
323                    let next = fallbacks
324                        .into_iter()
325                        .find(|m| !fallback_models_tried.contains(m));
326                    match next {
327                        Some(m) => {
328                            fallback_models_tried.insert(m.clone());
329                            current_model = m.clone();
330                            fallback_chain.push(m);
331                            excluded_ids.clear();
332                        }
333                        None => {
334                            return Err(RoutingError::AllDeploymentsFailed(initial_model));
335                        }
336                    }
337                }
338                Err(_panic) => {
339                    let _ = state.record_request_failure(&deployment_id).await;
340                    excluded_ids.insert(deployment_id);
341                    return Err(RoutingError::ExecutorPanic);
342                }
343            }
344        }
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deployment::Deployment;
    use crate::strategy::weighted_shuffle::tests_helpers::MockState;
    use crate::strategy::weighted_shuffle::WeightedShuffle;
    use crate::strategy::RoutingStrategyExt;
    use hyperinfer_core::Provider;
    use std::sync::Arc;

    /// Boxed future type that `route_with_fallback` executors must return.
    type ExecFuture<T> =
        Pin<Box<dyn Future<Output = Result<T, hyperinfer_core::HyperInferError>> + Send>>;

    /// Builds an OpenAI deployment for `model_name` with a fixed, predictable id.
    fn make_deployment(model_name: &str, id: &str) -> Deployment {
        let mut d = Deployment::new(
            model_name.to_string(),
            Provider::OpenAI,
            model_name.to_string(),
            format!("key-{}", id),
        );
        d.id = id.to_string();
        d
    }

    /// A generic upstream failure used by failing executors.
    fn api_error() -> hyperinfer_core::HyperInferError {
        hyperinfer_core::HyperInferError::ApiError {
            status: 500,
            message: "fail".into(),
        }
    }

    #[tokio::test]
    async fn test_select_deployment_basic() {
        let engine = RouterEngine::new(GlobalLimits::default());
        engine
            .register_strategy(WeightedShuffle::new().boxed())
            .await;
        engine.add_deployment(make_deployment("gpt-4", "d1")).await;

        let state = MockState::new();
        let ctx = RoutingContext::default();
        let result = engine
            .select_deployment("gpt-4", &state, &ctx)
            .await
            .unwrap();

        assert_eq!(result.deployment.id, "d1");
        assert_eq!(result.attempt, 1);
        assert_eq!(result.fallback_chain, vec!["gpt-4"]);
    }

    #[tokio::test]
    async fn test_alias_resolution() {
        let engine = RouterEngine::new(GlobalLimits::default());
        engine
            .register_strategy(WeightedShuffle::new().boxed())
            .await;
        engine
            .add_deployment(make_deployment("gpt-4-turbo", "d1"))
            .await;
        engine.set_alias("smart", "gpt-4-turbo").await;

        let state = MockState::new();
        let ctx = RoutingContext::default();
        let result = engine
            .select_deployment("smart", &state, &ctx)
            .await
            .unwrap();

        assert_eq!(result.deployment.id, "d1");
        assert_eq!(result.deployment.model_name, "gpt-4-turbo");
    }

    #[tokio::test]
    async fn test_no_deployments_error() {
        let engine = RouterEngine::new(GlobalLimits::default());
        engine
            .register_strategy(WeightedShuffle::new().boxed())
            .await;

        let state = MockState::new();
        let ctx = RoutingContext::default();
        let result = engine.select_deployment("nonexistent", &state, &ctx).await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RoutingError::NoDeployments(_)
        ));
    }

    #[tokio::test]
    async fn test_routing_group_strategy_selection() {
        let engine = RouterEngine::new(GlobalLimits::default());
        engine
            .register_strategy(WeightedShuffle::new().boxed())
            .await;
        engine.add_deployment(make_deployment("gpt-4", "d1")).await;
        engine.add_deployment(make_deployment("gpt-3.5", "d2")).await;

        // Point the default at an unregistered strategy, so selection for
        // "gpt-4" can only succeed if its routing group is honoured.
        engine.set_default_strategy("missing-strategy").await;
        engine.set_routing_group("gpt-4", "weighted-shuffle").await;

        let state = MockState::new();
        let ctx = RoutingContext::default();
        let result = engine
            .select_deployment("gpt-4", &state, &ctx)
            .await
            .unwrap();
        assert_eq!(result.deployment.id, "d1");

        // A model without a routing group uses the (missing) default strategy.
        let err = engine
            .select_deployment("gpt-3.5", &state, &ctx)
            .await
            .unwrap_err();
        assert!(matches!(err, RoutingError::StrategyError(_)));
    }

    #[tokio::test]
    async fn test_route_with_fallback_success_first_attempt() {
        let engine = RouterEngine::new(GlobalLimits::default());
        engine
            .register_strategy(WeightedShuffle::new().boxed())
            .await;
        engine.add_deployment(make_deployment("gpt-4", "d1")).await;

        let state = MockState::new();
        let ctx = RoutingContext::default();
        let executor = |d: Arc<Deployment>| -> ExecFuture<String> {
            let id = d.id.clone();
            Box::pin(async move { Ok::<_, hyperinfer_core::HyperInferError>(id) })
        };

        let (result, value) = engine
            .route_with_fallback("gpt-4", &state, &ctx, executor)
            .await
            .unwrap();

        assert_eq!(value, "d1");
        assert_eq!(result.deployment.id, "d1");
        assert_eq!(result.attempt, 1);
        assert_eq!(result.fallback_chain, vec!["gpt-4"]);
    }

    #[tokio::test]
    async fn test_route_with_fallback_retries_other_deployment() {
        let engine = RouterEngine::new(GlobalLimits::default());
        engine
            .register_strategy(WeightedShuffle::new().boxed())
            .await;
        engine.add_deployment(make_deployment("gpt-4", "d1")).await;
        engine.add_deployment(make_deployment("gpt-4", "d2")).await;

        let state = MockState::new();
        let ctx = RoutingContext::default();
        // d1 always fails; d2 always succeeds. Order of selection is random.
        let executor = |d: Arc<Deployment>| -> ExecFuture<String> {
            let id = d.id.clone();
            Box::pin(async move {
                if id == "d1" {
                    Err(api_error())
                } else {
                    Ok(id)
                }
            })
        };

        let (result, value) = engine
            .route_with_fallback("gpt-4", &state, &ctx, executor)
            .await
            .unwrap();

        assert_eq!(value, "d2");
        assert_eq!(result.deployment.id, "d2");
        assert!((1..=2).contains(&result.attempt));
        assert_eq!(result.fallback_chain, vec!["gpt-4"]);
    }

    #[tokio::test]
    async fn test_route_with_fallback_unknown_model() {
        let engine = RouterEngine::new(GlobalLimits::default());
        engine
            .register_strategy(WeightedShuffle::new().boxed())
            .await;

        let state = MockState::new();
        let ctx = RoutingContext::default();
        let executor = |_d: Arc<Deployment>| -> ExecFuture<()> {
            Box::pin(async { Ok::<(), hyperinfer_core::HyperInferError>(()) })
        };

        let err = engine
            .route_with_fallback("nonexistent", &state, &ctx, executor)
            .await
            .unwrap_err();

        assert!(
            matches!(err, RoutingError::AllDeploymentsFailed(ref m) if m.as_str() == "nonexistent"),
            "expected AllDeploymentsFailed(\"nonexistent\"), got: {:?}",
            err
        );
    }

    #[tokio::test]
    async fn test_global_limits_max_attempts() {
        let limits = GlobalLimits {
            max_total_attempts: 2,
            global_timeout_ms: 60_000,
        };
        let engine = RouterEngine::new(limits);
        engine
            .register_strategy(WeightedShuffle::new().boxed())
            .await;
        engine.add_deployment(make_deployment("gpt-4", "d1")).await;
        engine.add_deployment(make_deployment("gpt-4", "d2")).await;

        let state = MockState::new();
        let ctx = RoutingContext::default();

        let executor = |_d: Arc<Deployment>| -> ExecFuture<()> {
            Box::pin(async { Err::<(), _>(api_error()) })
        };

        let result = engine
            .route_with_fallback::<()>("gpt-4", &state, &ctx, executor)
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(
                err,
                RoutingError::AllDeploymentsFailed(_) | RoutingError::MaxAttemptsExceeded { .. }
            ),
            "expected AllDeploymentsFailed or MaxAttemptsExceeded, got: {:?}",
            err
        );
    }

    #[tokio::test]
    async fn test_record_success_passthrough() {
        let engine = RouterEngine::new(GlobalLimits::default());
        let state = MockState::new();
        engine.record_success("d1", 100.0, 500, &state).await;
    }

    #[tokio::test]
    async fn test_record_failure_passthrough() {
        let engine = RouterEngine::new(GlobalLimits::default());
        let state = MockState::new();
        engine.record_failure("d1", &state).await;
    }
}