//! hyperinfer-router 0.1.0
//!
//! Intelligent request routing engine for HyperInfer.
pub mod cost_based;
pub mod latency_based;
pub mod least_busy;
pub mod usage_based;
pub mod weighted_shuffle;

use crate::deployment::Deployment;
use crate::error::RoutingError;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;

/// Snapshot of per-deployment counters, as returned by
/// `RoutingState::get_metrics` / `RoutingState::get_all_metrics`.
///
/// `Default` yields all-zero counters and no recorded failure. serde serializes
/// fields in declaration order, so keep the order stable to avoid changing the
/// wire layout for positional formats.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct DeploymentMetrics {
    /// Latency estimate in milliseconds; per the field name an exponentially
    /// weighted moving average. NOTE(review): smoothing factor lives in the
    /// state backend — confirm there.
    pub latency_ewma_ms: f64,
    /// Requests currently outstanding against this deployment.
    /// NOTE(review): presumably maintained via `record_request_start` /
    /// success / failure — confirm in the backend.
    pub in_flight: u64,
    /// Tokens used in the current rate-limit window (TPM, per the name).
    pub tpm_used: u64,
    /// Requests issued in the current rate-limit window (RPM, per the name).
    pub rpm_used: u64,
    /// Running count of requests recorded for this deployment.
    pub total_requests: u64,
    /// Running count of failed requests recorded for this deployment.
    pub total_failures: u64,
    /// Timestamp of the most recent failure; `None` when none is recorded.
    /// NOTE(review): unit/epoch not visible here (seconds vs ms) — confirm.
    pub last_failure_ts: Option<u64>,
}

/// Optional per-request hints passed to `RoutingStrategy::select` as `request`.
///
/// Every field is an `Option`, so `RoutingContext::default()` carries no hints.
#[derive(Clone, Debug, Default)]
pub struct RoutingContext {
    /// Caller's estimate of input tokens, when known.
    pub estimated_input_tokens: Option<u64>,
    /// Caller's estimate of output tokens, when known.
    pub estimated_output_tokens: Option<u64>,
    /// Identifier of the team issuing the request, when known.
    pub team_id: Option<String>,
}

/// A pluggable policy that picks which deployment should serve a request.
///
/// `Send + Sync` lets a strategy be shared across threads/tasks. The
/// `dyn_clone::DynClone` supertrait, together with the `clone_trait_object!`
/// invocation below, makes `Box<dyn RoutingStrategy>` implement `Clone`, so
/// boxed strategies can be duplicated without knowing their concrete type.
#[async_trait]
pub trait RoutingStrategy: Send + Sync + dyn_clone::DynClone {
    /// Identifier of this strategy.
    // NOTE(review): presumably used for logging / config matching — confirm at call sites.
    fn name(&self) -> &str;

    /// Chooses one deployment from `candidates` for `model`.
    ///
    /// The returned reference borrows from `candidates` (lifetime `'a`), not
    /// from `self`, `state`, or `request`, so it remains valid after those
    /// borrows end. `state` gives access to per-deployment metrics and cooldown
    /// status; `request` carries optional per-request hints.
    ///
    /// # Errors
    /// Returns a `RoutingError` when no deployment can be selected.
    /// NOTE(review): which variants are produced, and whether an empty
    /// `candidates` slice is one of the error cases, is implementation-specific — confirm.
    async fn select<'a>(
        &self,
        model: &str,
        candidates: &'a [Arc<Deployment>],
        state: &dyn RoutingState,
        request: &RoutingContext,
    ) -> Result<&'a Arc<Deployment>, RoutingError>;
}

// Implements `Clone` for `Box<dyn RoutingStrategy>` by delegating to `DynClone`.
dyn_clone::clone_trait_object!(RoutingStrategy);

/// Async, fallible access to per-deployment routing metrics.
///
/// Strategies read it during `RoutingStrategy::select`. The `record_*` methods
/// update it as requests progress.
/// NOTE(review): presumably called by the request-dispatch path — confirm.
/// Every method returns `RoutingError` on failure; which failures map to which
/// variant is implementation-specific.
#[async_trait]
pub trait RoutingState: Send + Sync {
    /// Returns the current metrics for a single deployment.
    async fn get_metrics(&self, deployment_id: &str) -> Result<DeploymentMetrics, RoutingError>;

    /// Batch form of `get_metrics` for several deployments at once.
    ///
    /// NOTE(review): presumably keyed by deployment id. Whether ids without
    /// data are omitted or defaulted is up to the implementation — confirm.
    async fn get_all_metrics(
        &self,
        ids: &[&str],
    ) -> Result<HashMap<String, DeploymentMetrics>, RoutingError>;

    /// Reports the cooldown status of a deployment.
    ///
    /// NOTE(review): the name is ambiguous. `true` could mean "currently in
    /// cooldown" or "cooldown has finished"; check implementations before
    /// relying on either reading.
    async fn is_cooled_down(&self, deployment_id: &str) -> Result<bool, RoutingError>;

    /// Records that a request to `deployment_id` has started.
    // NOTE(review): presumably increments `DeploymentMetrics::in_flight` — confirm.
    async fn record_request_start(&self, deployment_id: &str) -> Result<(), RoutingError>;

    /// Records a successful request with its observed latency (milliseconds,
    /// per the parameter name) and token count.
    async fn record_request_success(
        &self,
        deployment_id: &str,
        latency_ms: f64,
        tokens: u64,
    ) -> Result<(), RoutingError>;

    /// Records a failed request and reports the resulting failure count and
    /// whether a cooldown was triggered, via `RecordFailureResult`.
    async fn record_request_failure(
        &self,
        deployment_id: &str,
    ) -> Result<RecordFailureResult, RoutingError>;
}

/// Outcome reported by `RoutingState::record_request_failure`.
///
/// Both fields are plain values, so the type is `Copy` and comparable. Callers
/// can inspect it without cloning, and tests can assert on it directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RecordFailureResult {
    /// Failure count reported by the state backend after recording this failure.
    /// NOTE(review): whether this is a lifetime total or a windowed count is not
    /// visible here — confirm against the `RoutingState` implementation.
    pub failure_count: u64,
    /// Presumably `true` when recording this failure put the deployment into
    /// cooldown — confirm against the backend's cooldown policy.
    pub cooldown_triggered: bool,
}

/// Convenience methods available on every `'static` [`RoutingStrategy`].
///
/// The blanket impl below covers all such types, so this trait is never
/// implemented by hand. It has no async methods, so it needs no `#[async_trait]`.
pub trait RoutingStrategyExt: RoutingStrategy {
    /// Erases the concrete strategy type into a `Box<dyn RoutingStrategy>`.
    ///
    /// `'static` is required because `Box<dyn RoutingStrategy>` is implicitly
    /// `Box<dyn RoutingStrategy + 'static>`.
    #[must_use]
    fn boxed(self) -> Box<dyn RoutingStrategy>
    where
        Self: Sized + 'static,
    {
        Box::new(self)
    }
}

impl<T: RoutingStrategy + 'static> RoutingStrategyExt for T {}