// hyperinfer-router 0.1.0
//
// Intelligent request routing engine for HyperInfer
// Documentation
use super::{DeploymentMetrics, RoutingContext, RoutingState, RoutingStrategy};
use crate::deployment::Deployment;
use crate::error::RoutingError;
use async_trait::async_trait;
use rand::Rng;
use std::sync::Arc;

/// Routing strategy that chooses a deployment at random, with each
/// deployment's chance proportional to its configured weight scaled by its
/// remaining rate-limit capacity.
///
/// The type is a unit struct: it holds no configuration or state of its own.
#[derive(Debug, Clone)]
pub struct WeightedShuffle;

impl WeightedShuffle {
    pub fn new() -> Self {
        Self
    }

    fn effective_weight(deployment: &Deployment, metrics: &DeploymentMetrics) -> f64 {
        let base_weight = deployment.weight as f64;

        let rpm_ratio = match deployment.rpm_limit {
            Some(limit) if limit > 0 => 1.0 - (metrics.rpm_used as f64 / limit as f64),
            _ => 1.0,
        };

        let tpm_ratio = match deployment.tpm_limit {
            Some(limit) if limit > 0 => 1.0 - (metrics.tpm_used as f64 / limit as f64),
            _ => 1.0,
        };

        let capacity = rpm_ratio.min(tpm_ratio).max(0.0);
        base_weight * capacity
    }
}

impl Default for WeightedShuffle {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl RoutingStrategy for WeightedShuffle {
    /// Identifier under which this strategy is reported.
    fn name(&self) -> &str {
        "weighted-shuffle"
    }

    /// Randomly selects one deployment from `candidates`, with probability
    /// proportional to each deployment's effective weight.
    ///
    /// Deployments that `state` reports as cooled down, and deployments whose
    /// effective weight is not strictly positive (for example because a rate
    /// limit is exhausted), are excluded before the draw. Candidates without
    /// an entry in the fetched metrics are scored with default metrics.
    ///
    /// The returned reference borrows directly from `candidates`.
    ///
    /// # Errors
    ///
    /// * [`RoutingError::NoDeployments`] when `candidates` is empty or no
    ///   candidate survives filtering.
    /// * Any error returned by `state.get_all_metrics` or
    ///   `state.is_cooled_down` is propagated unchanged.
    async fn select<'a>(
        &self,
        _model: &str,
        candidates: &'a [Arc<Deployment>],
        state: &dyn RoutingState,
        _request: &RoutingContext,
    ) -> Result<&'a Arc<Deployment>, RoutingError> {
        if candidates.is_empty() {
            return Err(RoutingError::NoDeployments("empty candidates".into()));
        }

        // Fetch metrics for every candidate in a single call up front.
        let ids: Vec<&str> = candidates.iter().map(|d| d.id.as_str()).collect();
        let all_metrics = state.get_all_metrics(&ids).await?;

        // (index into `candidates`, effective weight) for each usable deployment.
        let mut eligible: Vec<(usize, f64)> = Vec::with_capacity(candidates.len());

        for (i, deployment) in candidates.iter().enumerate() {
            if state.is_cooled_down(&deployment.id).await? {
                continue;
            }

            let metrics = all_metrics.get(&deployment.id).cloned().unwrap_or_default();

            let ew = Self::effective_weight(deployment, &metrics);
            if ew > 0.0 {
                eligible.push((i, ew));
            }
        }

        // Every stored weight is strictly positive, so the helper only returns
        // `None` here when nothing survived filtering. The RNG lives entirely
        // inside the synchronous helper and is never held across an `.await`.
        let chosen = Self::weighted_select_index(&eligible).ok_or_else(|| {
            RoutingError::NoDeployments("no eligible deployments after filtering".into())
        })?;

        // Return the caller's own slice element by position instead of cloning
        // the `Arc` and searching for it again by id.
        candidates
            .get(chosen)
            .ok_or_else(|| RoutingError::NoDeployments("selected deployment not found".into()))
    }
}

impl WeightedShuffle {
    /// Draws one entry from `weighted` with probability proportional to its
    /// weight and returns the index stored alongside that weight.
    ///
    /// Each element is `(index, weight)`; weights are expected to be positive
    /// and finite. Returns `None` when `weighted` is empty or when the weights
    /// do not sum to a positive finite number, because sampling from an empty
    /// range would panic.
    fn weighted_select_index(weighted: &[(usize, f64)]) -> Option<usize> {
        let (last_index, _) = *weighted.last()?;

        let total_weight: f64 = weighted.iter().map(|&(_, weight)| weight).sum();
        if !total_weight.is_finite() || total_weight <= 0.0 {
            return None;
        }

        let mut threshold = rand::thread_rng().gen_range(0.0..total_weight);

        // Walk the cumulative distribution: entry k owns the half-open interval
        // [sum of earlier weights, that sum + its own weight).
        for &(index, weight) in weighted {
            if threshold < weight {
                return Some(index);
            }
            threshold -= weight;
        }

        // Floating-point rounding can leave a sliver past the final interval;
        // attribute it to the last entry.
        Some(last_index)
    }
}

#[cfg(test)]
pub mod tests_helpers {
    use super::super::{DeploymentMetrics, RecordFailureResult, RoutingError, RoutingState};
    use async_trait::async_trait;
    use std::collections::HashMap;

    /// In-memory [`RoutingState`] for tests, backed by two plain maps.
    #[derive(Debug, Clone, Default)]
    pub struct MockState {
        /// Metrics keyed by deployment id.
        pub metrics: HashMap<String, DeploymentMetrics>,
        /// Cooldown flag keyed by deployment id; ids with no entry are treated
        /// as not cooled down.
        pub cooled_down: HashMap<String, bool>,
    }

    impl MockState {
        /// Creates a state with no metrics and no cooldowns.
        pub fn new() -> Self {
            Default::default()
        }

        /// Builder-style setter: stores `metrics` under `id`.
        pub fn with_metrics(mut self, id: &str, metrics: DeploymentMetrics) -> Self {
            self.metrics.insert(id.to_owned(), metrics);
            self
        }

        /// Builder-style setter: marks `id` as cooled down.
        pub fn with_cooldown(mut self, id: &str) -> Self {
            self.cooled_down.insert(id.to_owned(), true);
            self
        }
    }

    #[async_trait]
    impl RoutingState for MockState {
        /// Returns the stored metrics for `deployment_id`, or default metrics
        /// when none were stored.
        async fn get_metrics(
            &self,
            deployment_id: &str,
        ) -> Result<DeploymentMetrics, RoutingError> {
            match self.metrics.get(deployment_id) {
                Some(stored) => Ok(stored.clone()),
                None => Ok(DeploymentMetrics::default()),
            }
        }

        /// Returns the stored metrics for each requested id; ids with no
        /// stored metrics are left out of the result.
        async fn get_all_metrics(
            &self,
            ids: &[&str],
        ) -> Result<HashMap<String, DeploymentMetrics>, RoutingError> {
            let found = ids
                .iter()
                .filter_map(|id| {
                    self.metrics
                        .get(*id)
                        .map(|stored| (id.to_string(), stored.clone()))
                })
                .collect();
            Ok(found)
        }

        /// Reports the stored cooldown flag, defaulting to `false`.
        async fn is_cooled_down(&self, deployment_id: &str) -> Result<bool, RoutingError> {
            let flagged = matches!(self.cooled_down.get(deployment_id), Some(true));
            Ok(flagged)
        }

        /// No-op: the mock does not track in-flight requests.
        async fn record_request_start(&self, _deployment_id: &str) -> Result<(), RoutingError> {
            Ok(())
        }

        /// No-op: the mock does not track successes.
        async fn record_request_success(
            &self,
            _deployment_id: &str,
            _latency_ms: f64,
            _tokens: u64,
        ) -> Result<(), RoutingError> {
            Ok(())
        }

        /// Always reports zero failures and no triggered cooldown.
        async fn record_request_failure(
            &self,
            _deployment_id: &str,
        ) -> Result<RecordFailureResult, RoutingError> {
            let outcome = RecordFailureResult {
                failure_count: 0,
                cooldown_triggered: false,
            };
            Ok(outcome)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::super::DeploymentMetrics;
    use super::tests_helpers::MockState;
    use super::*;
    use crate::deployment::Deployment;
    use hyperinfer_core::Provider;

    /// Builds a not-yet-shared deployment for "test-model" with the given id
    /// and weight; everything else keeps whatever `Deployment::new` sets.
    fn base_deployment(id: &str, weight: u32) -> Deployment {
        let mut deployment = Deployment::new(
            "test-model".to_string(),
            Provider::OpenAI,
            "gpt-4".to_string(),
            format!("key-{}", id),
        );
        deployment.weight = weight;
        deployment.id = id.to_string();
        deployment
    }

    /// Shared deployment with the given id and weight.
    fn make_deployment(id: &str, weight: u32) -> Arc<Deployment> {
        Arc::new(base_deployment(id, weight))
    }

    /// Shared deployment with the given id, weight and explicit rate limits.
    fn make_deployment_with_limits(
        id: &str,
        weight: u32,
        rpm_limit: Option<u64>,
        tpm_limit: Option<u64>,
    ) -> Arc<Deployment> {
        let mut deployment = base_deployment(id, weight);
        deployment.rpm_limit = rpm_limit;
        deployment.tpm_limit = tpm_limit;
        Arc::new(deployment)
    }

    /// Runs one `select` call for "test-model" with a default routing context.
    async fn route<'a>(
        candidates: &'a [Arc<Deployment>],
        state: &MockState,
    ) -> Result<&'a Arc<Deployment>, RoutingError> {
        let strategy = WeightedShuffle::new();
        let ctx = RoutingContext::default();
        strategy.select("test-model", candidates, state, &ctx).await
    }

    #[tokio::test]
    async fn test_single_candidate() {
        let candidates = vec![make_deployment("d1", 1)];
        let state = MockState::new();

        let chosen = route(&candidates, &state).await.unwrap();
        assert_eq!(chosen.id, "d1");
    }

    #[tokio::test]
    async fn test_empty_candidates() {
        let candidates: Vec<Arc<Deployment>> = Vec::new();
        let state = MockState::new();

        let outcome = route(&candidates, &state).await;
        assert!(matches!(outcome, Err(RoutingError::NoDeployments(_))));
    }

    #[tokio::test]
    async fn test_cooled_down_excluded() {
        let candidates = vec![make_deployment("d1", 1), make_deployment("d2", 1)];
        let state = MockState::new().with_cooldown("d1");

        let chosen = route(&candidates, &state).await.unwrap();
        assert_eq!(chosen.id, "d2");
    }

    #[tokio::test]
    async fn test_all_cooled_down_returns_error() {
        let candidates = vec![make_deployment("d1", 1), make_deployment("d2", 1)];
        let state = MockState::new().with_cooldown("d1").with_cooldown("d2");

        let outcome = route(&candidates, &state).await;
        assert!(matches!(outcome, Err(RoutingError::NoDeployments(_))));
    }

    #[tokio::test]
    async fn test_at_capacity_excluded() {
        let candidates = vec![
            make_deployment_with_limits("d1", 5, Some(100), None),
            make_deployment_with_limits("d2", 1, None, None),
        ];

        // d1 has consumed its whole requests-per-minute budget.
        let exhausted = DeploymentMetrics {
            rpm_used: 100,
            ..Default::default()
        };
        let state = MockState::new().with_metrics("d1", exhausted);

        let chosen = route(&candidates, &state).await.unwrap();
        assert_eq!(chosen.id, "d2");
    }

    #[tokio::test]
    async fn test_weight_distribution() {
        const ITERATIONS: u32 = 10000;

        let candidates = vec![make_deployment("d1", 9), make_deployment("d2", 1)];
        let state = MockState::new();
        let strategy = WeightedShuffle::new();
        let ctx = RoutingContext::default();

        let mut d1_hits = 0u32;
        for _ in 0..ITERATIONS {
            let chosen = strategy
                .select("test-model", &candidates, &state, &ctx)
                .await
                .unwrap();
            d1_hits += u32::from(chosen.id == "d1");
        }

        // With weights 9:1 the expected share for d1 is 90%; allow a wide band
        // because the draw is random.
        let ratio = d1_hits as f64 / ITERATIONS as f64;
        assert!(
            0.80 < ratio && ratio < 0.98,
            "expected d1 ratio between 80-98%, got {:.2}%",
            ratio * 100.0
        );
    }

    #[test]
    fn test_effective_weight_no_limits() {
        let deployment = make_deployment("d1", 5);
        let idle = DeploymentMetrics::default();

        let ew = WeightedShuffle::effective_weight(&deployment, &idle);
        assert!((ew - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_effective_weight_at_rpm_limit() {
        let deployment = make_deployment_with_limits("d1", 5, Some(100), None);
        let exhausted = DeploymentMetrics {
            rpm_used: 100,
            ..Default::default()
        };

        let ew = WeightedShuffle::effective_weight(&deployment, &exhausted);
        assert!(ew.abs() < f64::EPSILON);
    }

    #[test]
    fn test_effective_weight_half_capacity() {
        let deployment = make_deployment_with_limits("d1", 10, Some(100), None);
        let half_used = DeploymentMetrics {
            rpm_used: 50,
            ..Default::default()
        };

        let ew = WeightedShuffle::effective_weight(&deployment, &half_used);
        assert!((ew - 5.0).abs() < f64::EPSILON);
    }
}