Skip to main content

hyperinfer_router/strategy/mod.rs

1pub mod cost_based;
2pub mod latency_based;
3pub mod least_busy;
4pub mod usage_based;
5pub mod weighted_shuffle;
6
7use crate::deployment::Deployment;
8use crate::error::RoutingError;
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13
14#[derive(Debug, Clone, Default, Serialize, Deserialize)]
15pub struct DeploymentMetrics {
16    pub latency_ewma_ms: f64,
17    pub in_flight: u64,
18    pub tpm_used: u64,
19    pub rpm_used: u64,
20    pub total_requests: u64,
21    pub total_failures: u64,
22    pub last_failure_ts: Option<u64>,
23}
24
/// Optional per-request hints that a [`RoutingStrategy`] may use when
/// choosing a deployment.
///
/// Every field is optional. `RoutingContext::default()` carries no hints at
/// all. `PartialEq`/`Eq` are derived so contexts can be compared directly,
/// for example in strategy tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RoutingContext {
    /// Caller-supplied estimate of input tokens for this request, if any.
    pub estimated_input_tokens: Option<u64>,
    /// Caller-supplied estimate of output tokens for this request, if any.
    pub estimated_output_tokens: Option<u64>,
    /// Identifier of the team issuing the request, if known.
    pub team_id: Option<String>,
}
31
32#[async_trait]
33pub trait RoutingStrategy: Send + Sync + dyn_clone::DynClone {
34    fn name(&self) -> &str;
35
36    async fn select<'a>(
37        &self,
38        model: &str,
39        candidates: &'a [Arc<Deployment>],
40        state: &dyn RoutingState,
41        request: &RoutingContext,
42    ) -> Result<&'a Arc<Deployment>, RoutingError>;
43}
44
45dyn_clone::clone_trait_object!(RoutingStrategy);
46
47#[async_trait]
48pub trait RoutingState: Send + Sync {
49    async fn get_metrics(&self, deployment_id: &str) -> Result<DeploymentMetrics, RoutingError>;
50
51    async fn get_all_metrics(
52        &self,
53        ids: &[&str],
54    ) -> Result<HashMap<String, DeploymentMetrics>, RoutingError>;
55
56    async fn is_cooled_down(&self, deployment_id: &str) -> Result<bool, RoutingError>;
57
58    async fn record_request_start(&self, deployment_id: &str) -> Result<(), RoutingError>;
59
60    async fn record_request_success(
61        &self,
62        deployment_id: &str,
63        latency_ms: f64,
64        tokens: u64,
65    ) -> Result<(), RoutingError>;
66
67    async fn record_request_failure(
68        &self,
69        deployment_id: &str,
70    ) -> Result<RecordFailureResult, RoutingError>;
71}
72
/// Outcome returned by [`RoutingState::record_request_failure`].
///
/// `PartialEq`/`Eq` are derived because this is plain data that callers and
/// tests naturally want to compare.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecordFailureResult {
    /// Failure count for the deployment after this failure was recorded.
    /// NOTE(review): whether this is windowed or lifetime is up to the
    /// `RoutingState` implementation — confirm.
    pub failure_count: u64,
    /// Whether recording this failure put the deployment into cooldown.
    pub cooldown_triggered: bool,
}
78
79#[async_trait]
80pub trait RoutingStrategyExt: RoutingStrategy {
81    fn boxed(self) -> Box<dyn RoutingStrategy>
82    where
83        Self: Sized + 'static,
84    {
85        Box::new(self)
86    }
87}
88
89impl<T: RoutingStrategy + 'static> RoutingStrategyExt for T {}