hyperinfer-router 0.1.0

Intelligent request routing engine for HyperInfer
Documentation
use super::{RoutingContext, RoutingState, RoutingStrategy};
use crate::deployment::Deployment;
use crate::error::RoutingError;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

/// Routing strategy that picks the candidate deployment with the lowest
/// tokens-per-minute utilization (`tpm_used / tpm_limit`).
///
/// This is a unit struct: it holds no configuration or state of its own. All
/// runtime information is read from the `RoutingState` passed to
/// `RoutingStrategy::select`.
///
/// `Default` is derived rather than hand-written, since the derived
/// implementation for a unit struct is exactly `Self`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UsageBased;

impl UsageBased {
    /// Creates a new usage-based routing strategy.
    ///
    /// The strategy is a unit struct, so there is nothing to configure. The
    /// function is `const` so the value can also be built in `const`/`static`
    /// contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}

#[async_trait]
impl RoutingStrategy for UsageBased {
    /// Returns the strategy identifier, `"usage-based"`.
    fn name(&self) -> &str {
        "usage-based"
    }

    /// Selects the candidate with the lowest `tpm_used / tpm_limit` ratio.
    ///
    /// Candidates for which `state.is_cooled_down` reports `true` are skipped.
    /// A candidate with no `tpm_limit`, or with a limit of `0`, is treated as
    /// having a utilization of `0.0`. When several candidates share the lowest
    /// utilization, the one that appears first in `candidates` is returned.
    ///
    /// `_model` and `_request` are not consulted by this strategy.
    ///
    /// # Errors
    ///
    /// * `RoutingError::NoDeployments` if `candidates` is empty, or if every
    ///   candidate was skipped because it is cooled down.
    /// * Any error returned by `state.get_all_metrics` or
    ///   `state.is_cooled_down` is propagated via `?`.
    async fn select<'a>(
        &self,
        _model: &str,
        candidates: &'a [Arc<Deployment>],
        state: &dyn RoutingState,
        _request: &RoutingContext,
    ) -> Result<&'a Arc<Deployment>, RoutingError> {
        if candidates.is_empty() {
            return Err(RoutingError::NoDeployments("empty candidates".into()));
        }

        // Fetch the metrics for every candidate with a single state call.
        let ids: Vec<&str> = candidates.iter().map(|d| d.id.as_str()).collect();
        let all_metrics = state.get_all_metrics(&ids).await?;

        // Running minimum: the best deployment seen so far and its utilization.
        let mut best: Option<(&'a Arc<Deployment>, f64)> = None;

        for deployment in candidates {
            if state.is_cooled_down(&deployment.id).await? {
                continue;
            }

            // Candidates absent from the metrics lookup fall back to the
            // metrics type's `Default` value.
            let metrics = all_metrics.get(&deployment.id).cloned().unwrap_or_default();

            let utilization = match deployment.tpm_limit {
                Some(limit) if limit > 0 => metrics.tpm_used as f64 / limit as f64,
                // No limit (or a zero limit): count as completely unused.
                _ => 0.0,
            };

            // Strict `<` keeps the earliest candidate on ties, which is the
            // same tie-breaking `Iterator::min_by` provides.
            let is_better = match best {
                Some((_, lowest)) => utilization < lowest,
                None => true,
            };
            if is_better {
                best = Some((deployment, utilization));
            }
        }

        best.map(|(deployment, _)| deployment).ok_or_else(|| {
            RoutingError::NoDeployments("no eligible deployments after filtering".into())
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::deployment::Deployment;
    use crate::strategy::weighted_shuffle::tests_helpers::MockState;
    use crate::strategy::DeploymentMetrics;
    use hyperinfer_core::Provider;

    /// Builds a shared test deployment with the given id, weight and TPM limit.
    fn make_deployment(id: &str, weight: u32, tpm_limit: Option<u64>) -> Arc<Deployment> {
        let mut deployment = Deployment::new(
            "test-model".to_string(),
            Provider::OpenAI,
            "gpt-4".to_string(),
            format!("key-{}", id),
        );
        deployment.weight = weight;
        deployment.id = id.to_string();
        deployment.tpm_limit = tpm_limit;
        Arc::new(deployment)
    }

    /// Runs `UsageBased` over `candidates` against `state` and returns the id
    /// of the deployment it selected. Panics if selection fails.
    async fn selected_id(candidates: &[Arc<Deployment>], state: &MockState) -> String {
        let strategy = UsageBased::new();
        let ctx = RoutingContext::default();

        let chosen = strategy
            .select("test-model", candidates, state, &ctx)
            .await
            .unwrap();
        chosen.id.clone()
    }

    /// With equal limits, the deployment with fewer tokens used wins.
    #[tokio::test]
    async fn test_selects_lowest_utilization() {
        let candidates = vec![
            make_deployment("d1", 1, Some(1000)),
            make_deployment("d2", 1, Some(1000)),
        ];

        let state = MockState::new()
            .with_metrics(
                "d1",
                DeploymentMetrics {
                    tpm_used: 800,
                    ..Default::default()
                },
            )
            .with_metrics(
                "d2",
                DeploymentMetrics {
                    tpm_used: 200,
                    ..Default::default()
                },
            );

        assert_eq!(selected_id(&candidates, &state).await, "d2");
    }

    /// Selection compares used/limit ratios, not absolute token counts:
    /// 50000/100000 (0.5) beats 900/1000 (0.9).
    #[tokio::test]
    async fn test_ratio_not_absolute() {
        let candidates = vec![
            make_deployment("d_big", 1, Some(100000)),
            make_deployment("d_small", 1, Some(1000)),
        ];

        let state = MockState::new()
            .with_metrics(
                "d_big",
                DeploymentMetrics {
                    tpm_used: 50000,
                    ..Default::default()
                },
            )
            .with_metrics(
                "d_small",
                DeploymentMetrics {
                    tpm_used: 900,
                    ..Default::default()
                },
            );

        assert_eq!(selected_id(&candidates, &state).await, "d_big");
    }

    /// A cooled-down deployment is skipped even when it has the lowest usage.
    #[tokio::test]
    async fn test_cooled_down_excluded() {
        let candidates = vec![
            make_deployment("d1", 1, Some(1000)),
            make_deployment("d2", 1, Some(1000)),
        ];

        let state = MockState::new()
            .with_metrics(
                "d1",
                DeploymentMetrics {
                    tpm_used: 100,
                    ..Default::default()
                },
            )
            .with_metrics(
                "d2",
                DeploymentMetrics {
                    tpm_used: 800,
                    ..Default::default()
                },
            )
            .with_cooldown("d1");

        assert_eq!(selected_id(&candidates, &state).await, "d2");
    }
}