hyperinfer_client/
router_engine.rs1use hyperinfer_core::{ChatRequest, Deployment, Provider};
2use hyperinfer_router::{
3 deployment::Deployment as RouterDeployment,
4 engine::{GlobalLimits, RouterEngine as HyperInferRouterEngine, RoutingResult},
5 error::RoutingError,
6 strategy::{
7 cost_based::CostBased, latency_based::LatencyBased, least_busy::LeastBusy,
8 usage_based::UsageBased, weighted_shuffle::WeightedShuffle, RoutingContext, RoutingState,
9 },
10};
11use std::collections::HashMap;
12use std::sync::Arc;
13
14pub struct RouterEngine {
16 inner: Arc<HyperInferRouterEngine>,
17 state: Arc<dyn RoutingState>,
18}
19
20impl RouterEngine {
21 pub async fn new() -> Self {
23 let engine = HyperInferRouterEngine::new(GlobalLimits::default());
24 Self::register_all_strategies(&engine).await;
25 let state = Arc::new(NoopState);
26 Self {
27 inner: Arc::new(engine),
28 state,
29 }
30 }
31
32 pub async fn with_limits(limits: GlobalLimits) -> Self {
34 let engine = HyperInferRouterEngine::new(limits);
35 Self::register_all_strategies(&engine).await;
36 let state = Arc::new(NoopState);
37 Self {
38 inner: Arc::new(engine),
39 state,
40 }
41 }
42
43 async fn register_all_strategies(engine: &HyperInferRouterEngine) {
44 engine
45 .register_strategy(Box::new(WeightedShuffle::new()))
46 .await;
47 engine
48 .register_strategy(Box::new(LatencyBased::new()))
49 .await;
50 engine.register_strategy(Box::new(LeastBusy::new())).await;
51 engine.register_strategy(Box::new(UsageBased::new())).await;
52 engine.register_strategy(Box::new(CostBased::new())).await;
53 }
54
55 pub async fn load_deployments(&self, deployments: Vec<Deployment>) {
57 for d in deployments {
58 let router_deployment = self.core_to_router_deployment(d);
59 self.inner.add_deployment(router_deployment).await;
60 }
61 }
62
63 pub async fn rebuild_pool(&self, deployments: Vec<Deployment>) {
65 let router_deployments: Vec<RouterDeployment> = deployments
66 .into_iter()
67 .map(|d| self.core_to_router_deployment(d))
68 .collect();
69 self.inner.rebuild_pool(router_deployments).await;
70 }
71
72 pub async fn select_deployment(
74 &self,
75 request: &ChatRequest,
76 ) -> Result<RoutingResult, RoutingError> {
77 let ctx = RoutingContext::default();
78 self.inner
79 .select_deployment(&request.model, self.state.as_ref(), &ctx)
80 .await
81 }
82
83 pub async fn record_success(&self, deployment_id: &str, latency_ms: f64, tokens: u64) {
85 self.inner
86 .record_success(deployment_id, latency_ms, tokens, self.state.as_ref())
87 .await;
88 }
89
90 pub async fn record_failure(&self, deployment_id: &str) {
92 self.inner
93 .record_failure(deployment_id, self.state.as_ref())
94 .await;
95 }
96
97 pub fn inner(&self) -> &Arc<HyperInferRouterEngine> {
99 &self.inner
100 }
101
102 fn core_to_router_deployment(&self, d: Deployment) -> RouterDeployment {
103 let provider = match d.provider.as_str() {
104 "openai" => Provider::OpenAI,
105 "anthropic" => Provider::Anthropic,
106 _ => Provider::Other,
107 };
108 let mut router_deployment = RouterDeployment::new(
109 d.name.clone(),
110 provider,
111 d.model.clone(),
112 d.api_key_ref.clone(),
113 );
114 router_deployment.id = d.id.clone();
115 if !d.base_url.is_empty() {
116 router_deployment = router_deployment.with_base_url(d.base_url.clone());
117 }
118 router_deployment = router_deployment.with_weight(d.weight);
119 if let Some(max_tpm) = d.max_tpm {
120 router_deployment = router_deployment.with_tpm_limit(max_tpm as u64);
121 }
122 if let Some(max_rpm) = d.max_rpm {
123 router_deployment = router_deployment.with_rpm_limit(max_rpm as u64);
124 }
125 if let Some(cost) = d.cost_per_1k_input_tokens {
126 router_deployment = router_deployment.with_input_cost(cost);
127 }
128 if let Some(cost) = d.cost_per_1k_output_tokens {
129 router_deployment = router_deployment.with_output_cost(cost);
130 }
131 router_deployment
132 }
133}
134
135struct NoopState;
137
138#[async_trait::async_trait]
139impl RoutingState for NoopState {
140 async fn get_metrics(
141 &self,
142 _deployment_id: &str,
143 ) -> Result<hyperinfer_router::strategy::DeploymentMetrics, RoutingError> {
144 Ok(hyperinfer_router::strategy::DeploymentMetrics::default())
145 }
146
147 async fn get_all_metrics(
148 &self,
149 _ids: &[&str],
150 ) -> Result<HashMap<String, hyperinfer_router::strategy::DeploymentMetrics>, RoutingError> {
151 Ok(HashMap::new())
152 }
153
154 async fn is_cooled_down(&self, _deployment_id: &str) -> Result<bool, RoutingError> {
155 Ok(false)
156 }
157
158 async fn record_request_start(&self, _deployment_id: &str) -> Result<(), RoutingError> {
159 Ok(())
160 }
161
162 async fn record_request_success(
163 &self,
164 _deployment_id: &str,
165 _latency_ms: f64,
166 _tokens: u64,
167 ) -> Result<(), RoutingError> {
168 Ok(())
169 }
170
171 async fn record_request_failure(
172 &self,
173 _deployment_id: &str,
174 ) -> Result<hyperinfer_router::strategy::RecordFailureResult, RoutingError> {
175 Ok(hyperinfer_router::strategy::RecordFailureResult {
176 failure_count: 0,
177 cooldown_triggered: false,
178 })
179 }
180}