Skip to main content

hyperinfer_client/
router_engine.rs

1use hyperinfer_core::{ChatRequest, Deployment, Provider};
2use hyperinfer_router::{
3    deployment::Deployment as RouterDeployment,
4    engine::{GlobalLimits, RouterEngine as HyperInferRouterEngine, RoutingResult},
5    error::RoutingError,
6    strategy::{
7        cost_based::CostBased, latency_based::LatencyBased, least_busy::LeastBusy,
8        usage_based::UsageBased, weighted_shuffle::WeightedShuffle, RoutingContext, RoutingState,
9    },
10};
11use std::collections::HashMap;
12use std::sync::Arc;
13
14/// Client-side router engine that wraps the core routing engine
15pub struct RouterEngine {
16    inner: Arc<HyperInferRouterEngine>,
17    state: Arc<dyn RoutingState>,
18}
19
20impl RouterEngine {
21    /// Create a new router engine with default configuration and all built-in strategies
22    pub async fn new() -> Self {
23        let engine = HyperInferRouterEngine::new(GlobalLimits::default());
24        Self::register_all_strategies(&engine).await;
25        let state = Arc::new(NoopState);
26        Self {
27            inner: Arc::new(engine),
28            state,
29        }
30    }
31
32    /// Create a new router engine with custom global limits
33    pub async fn with_limits(limits: GlobalLimits) -> Self {
34        let engine = HyperInferRouterEngine::new(limits);
35        Self::register_all_strategies(&engine).await;
36        let state = Arc::new(NoopState);
37        Self {
38            inner: Arc::new(engine),
39            state,
40        }
41    }
42
43    async fn register_all_strategies(engine: &HyperInferRouterEngine) {
44        engine
45            .register_strategy(Box::new(WeightedShuffle::new()))
46            .await;
47        engine
48            .register_strategy(Box::new(LatencyBased::new()))
49            .await;
50        engine.register_strategy(Box::new(LeastBusy::new())).await;
51        engine.register_strategy(Box::new(UsageBased::new())).await;
52        engine.register_strategy(Box::new(CostBased::new())).await;
53    }
54
55    /// Load deployments into the routing pool
56    pub async fn load_deployments(&self, deployments: Vec<Deployment>) {
57        for d in deployments {
58            let router_deployment = self.core_to_router_deployment(d);
59            self.inner.add_deployment(router_deployment).await;
60        }
61    }
62
63    /// Rebuild the deployment pool with a fresh set
64    pub async fn rebuild_pool(&self, deployments: Vec<Deployment>) {
65        let router_deployments: Vec<RouterDeployment> = deployments
66            .into_iter()
67            .map(|d| self.core_to_router_deployment(d))
68            .collect();
69        self.inner.rebuild_pool(router_deployments).await;
70    }
71
72    /// Select a deployment for a request using routing strategies
73    pub async fn select_deployment(
74        &self,
75        request: &ChatRequest,
76    ) -> Result<RoutingResult, RoutingError> {
77        let ctx = RoutingContext::default();
78        self.inner
79            .select_deployment(&request.model, self.state.as_ref(), &ctx)
80            .await
81    }
82
83    /// Record a successful request for metrics
84    pub async fn record_success(&self, deployment_id: &str, latency_ms: f64, tokens: u64) {
85        self.inner
86            .record_success(deployment_id, latency_ms, tokens, self.state.as_ref())
87            .await;
88    }
89
90    /// Record a failed request for metrics
91    pub async fn record_failure(&self, deployment_id: &str) {
92        self.inner
93            .record_failure(deployment_id, self.state.as_ref())
94            .await;
95    }
96
97    /// Get the inner router engine (for advanced use)
98    pub fn inner(&self) -> &Arc<HyperInferRouterEngine> {
99        &self.inner
100    }
101
102    fn core_to_router_deployment(&self, d: Deployment) -> RouterDeployment {
103        let provider = match d.provider.as_str() {
104            "openai" => Provider::OpenAI,
105            "anthropic" => Provider::Anthropic,
106            _ => Provider::Other,
107        };
108        let mut router_deployment = RouterDeployment::new(
109            d.name.clone(),
110            provider,
111            d.model.clone(),
112            d.api_key_ref.clone(),
113        );
114        router_deployment.id = d.id.clone();
115        if !d.base_url.is_empty() {
116            router_deployment = router_deployment.with_base_url(d.base_url.clone());
117        }
118        router_deployment = router_deployment.with_weight(d.weight);
119        if let Some(max_tpm) = d.max_tpm {
120            router_deployment = router_deployment.with_tpm_limit(max_tpm as u64);
121        }
122        if let Some(max_rpm) = d.max_rpm {
123            router_deployment = router_deployment.with_rpm_limit(max_rpm as u64);
124        }
125        if let Some(cost) = d.cost_per_1k_input_tokens {
126            router_deployment = router_deployment.with_input_cost(cost);
127        }
128        if let Some(cost) = d.cost_per_1k_output_tokens {
129            router_deployment = router_deployment.with_output_cost(cost);
130        }
131        router_deployment
132    }
133}
134
135/// Noop routing state for client-side routing (metrics tracked via RedisRoutingState)
136struct NoopState;
137
138#[async_trait::async_trait]
139impl RoutingState for NoopState {
140    async fn get_metrics(
141        &self,
142        _deployment_id: &str,
143    ) -> Result<hyperinfer_router::strategy::DeploymentMetrics, RoutingError> {
144        Ok(hyperinfer_router::strategy::DeploymentMetrics::default())
145    }
146
147    async fn get_all_metrics(
148        &self,
149        _ids: &[&str],
150    ) -> Result<HashMap<String, hyperinfer_router::strategy::DeploymentMetrics>, RoutingError> {
151        Ok(HashMap::new())
152    }
153
154    async fn is_cooled_down(&self, _deployment_id: &str) -> Result<bool, RoutingError> {
155        Ok(false)
156    }
157
158    async fn record_request_start(&self, _deployment_id: &str) -> Result<(), RoutingError> {
159        Ok(())
160    }
161
162    async fn record_request_success(
163        &self,
164        _deployment_id: &str,
165        _latency_ms: f64,
166        _tokens: u64,
167    ) -> Result<(), RoutingError> {
168        Ok(())
169    }
170
171    async fn record_request_failure(
172        &self,
173        _deployment_id: &str,
174    ) -> Result<hyperinfer_router::strategy::RecordFailureResult, RoutingError> {
175        Ok(hyperinfer_router::strategy::RecordFailureResult {
176            failure_count: 0,
177            cooldown_triggered: false,
178        })
179    }
180}