hyperinfer-router 0.1.0

Intelligent request routing engine for HyperInfer
Documentation
use super::{RoutingContext, RoutingState, RoutingStrategy};
use crate::deployment::Deployment;
use crate::error::RoutingError;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

/// Routing strategy that selects the candidate deployment with the lowest
/// estimated request cost.
///
/// The type is a field-less marker: all inputs to a routing decision come from
/// the arguments passed to its [`RoutingStrategy`] implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostBased;

impl Default for CostBased {
    fn default() -> Self {
        Self
    }
}

impl CostBased {
    pub fn new() -> Self {
        Self
    }
}

/// Estimates what a request would cost on `deployment`.
///
/// The estimate is
/// `input_tokens / 1000 * input_cost_per_1k + output_tokens / 1000 * output_cost_per_1k`.
/// Any token estimate missing from `ctx` counts as `0`, and any price missing
/// from `deployment` counts as `0.0`, so a deployment with no configured prices
/// always estimates to zero.
fn estimate_cost(deployment: &Deployment, ctx: &RoutingContext) -> f64 {
    // Token counts expressed in units of one thousand tokens, matching the
    // per-1k pricing fields.
    let input_thousands = ctx.estimated_input_tokens.unwrap_or(0) as f64 / 1000.0;
    let output_thousands = ctx.estimated_output_tokens.unwrap_or(0) as f64 / 1000.0;

    let input_price = deployment.input_cost_per_1k.unwrap_or(0.0);
    let output_price = deployment.output_cost_per_1k.unwrap_or(0.0);

    input_thousands * input_price + output_thousands * output_price
}

#[async_trait]
impl RoutingStrategy for CostBased {
    /// Identifier of this strategy.
    fn name(&self) -> &str {
        "cost-based"
    }

    /// Chooses the eligible candidate with the lowest estimated cost for `request`.
    ///
    /// A candidate is skipped when `state.is_cooled_down` returns `true` for its id,
    /// or when its recorded `rpm_used` / `tpm_used` has reached the corresponding
    /// `rpm_limit` / `tpm_limit` (limits that are `None` are not enforced).
    /// Candidates with no entry in the fetched metrics are evaluated against
    /// default metrics.
    ///
    /// Among the remaining candidates the lowest [`estimate_cost`] wins. When two
    /// costs differ by less than `f64::EPSILON`, the higher `weight` wins; if the
    /// weights do not decide it either, the earlier candidate is kept.
    ///
    /// The `_model` argument is not consulted.
    ///
    /// # Errors
    ///
    /// * `RoutingError::NoDeployments` when `candidates` is empty or every
    ///   candidate is filtered out.
    /// * Any error returned by `state.get_all_metrics` or `state.is_cooled_down`
    ///   is propagated unchanged.
    async fn select<'a>(
        &self,
        _model: &str,
        candidates: &'a [Arc<Deployment>],
        state: &dyn RoutingState,
        request: &RoutingContext,
    ) -> Result<&'a Arc<Deployment>, RoutingError> {
        if candidates.is_empty() {
            return Err(RoutingError::NoDeployments("empty candidates".into()));
        }

        // Metrics for every candidate are fetched in a single call before the scan.
        let ids: Vec<&str> = candidates.iter().map(|d| d.id.as_str()).collect();
        let all_metrics = state.get_all_metrics(&ids).await?;

        // Current front-runner together with its estimated cost.
        let mut cheapest: Option<(&Arc<Deployment>, f64)> = None;

        for candidate in candidates {
            if state.is_cooled_down(&candidate.id).await? {
                continue;
            }

            let usage = all_metrics.get(&candidate.id).cloned().unwrap_or_default();
            let rpm_exhausted = matches!(candidate.rpm_limit, Some(cap) if usage.rpm_used >= cap);
            let tpm_exhausted = matches!(candidate.tpm_limit, Some(cap) if usage.tpm_used >= cap);
            if rpm_exhausted || tpm_exhausted {
                continue;
            }

            let cost = estimate_cost(candidate, request);
            let takes_lead = match cheapest {
                // The first eligible candidate always becomes the initial leader.
                None => true,
                // Strictly cheaper wins; a near-equal cost is decided by weight.
                Some((leader, leader_cost)) => {
                    cost < leader_cost
                        || ((cost - leader_cost).abs() < f64::EPSILON
                            && candidate.weight > leader.weight)
                }
            };
            if takes_lead {
                cheapest = Some((candidate, cost));
            }
        }

        cheapest.map(|(deployment, _)| deployment).ok_or_else(|| {
            RoutingError::NoDeployments("no eligible deployments after filtering".into())
        })
    }
}

// Unit tests for the cost-based strategy: selection order, weight tie-breaking,
// and the cost formula itself.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deployment::Deployment;
    use crate::strategy::weighted_shuffle::tests_helpers::MockState;
    use hyperinfer_core::Provider;

    /// Builds a shared deployment with the given `id` and `weight`.
    ///
    /// Cost fields are not assigned here, so they keep whatever
    /// `Deployment::new` sets (presumably `None` — confirm against `Deployment`).
    fn make_deployment(id: &str, weight: u32) -> Arc<Deployment> {
        let mut d = Deployment::new(
            "test-model".to_string(),
            Provider::OpenAI,
            "gpt-4".to_string(),
            format!("key-{}", id),
        );
        d.weight = weight;
        // Override the id so assertions can compare against a known value.
        d.id = id.to_string();
        Arc::new(d)
    }

    /// Builds a shared deployment like `make_deployment`, additionally setting
    /// the per-1k input and output prices.
    fn make_deployment_with_costs(
        id: &str,
        weight: u32,
        input_cost: f64,
        output_cost: f64,
    ) -> Arc<Deployment> {
        let mut d = Deployment::new(
            "test-model".to_string(),
            Provider::OpenAI,
            "gpt-4".to_string(),
            format!("key-{}", id),
        );
        d.weight = weight;
        d.id = id.to_string();
        d.input_cost_per_1k = Some(input_cost);
        d.output_cost_per_1k = Some(output_cost);
        Arc::new(d)
    }

    /// With equal weights, the deployment with the lower prices is selected even
    /// though it is listed second.
    #[tokio::test]
    async fn test_selects_cheapest() {
        let d1 = make_deployment_with_costs("d1", 1, 0.03, 0.06);
        let d2 = make_deployment_with_costs("d2", 1, 0.01, 0.02);
        let candidates = vec![d1, d2.clone()];

        // NOTE(review): assumes a fresh MockState filters out no candidate
        // (no cooldowns, no usage) — confirm against the helper.
        let state = MockState::new();
        let strategy = CostBased::new();
        let ctx = RoutingContext {
            estimated_input_tokens: Some(1000),
            estimated_output_tokens: Some(1000),
            ..Default::default()
        };

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await
            .unwrap();
        assert_eq!(result.id, "d2");
    }

    /// Two deployments with identical prices: the one with the higher weight wins.
    #[tokio::test]
    async fn test_equal_cost_prefers_higher_weight() {
        let d1 = make_deployment_with_costs("d1", 5, 0.01, 0.02);
        let d2 = make_deployment_with_costs("d2", 1, 0.01, 0.02);
        let candidates = vec![d1.clone(), d2];

        let state = MockState::new();
        let strategy = CostBased::new();
        let ctx = RoutingContext {
            estimated_input_tokens: Some(1000),
            estimated_output_tokens: Some(1000),
            ..Default::default()
        };

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await
            .unwrap();
        assert_eq!(result.id, "d1");
    }

    /// Deployments built without explicit prices and with equal weights: the
    /// first candidate is expected to be returned.
    #[tokio::test]
    async fn test_no_cost_data_returns_first() {
        let d1 = make_deployment("d1", 1);
        let d2 = make_deployment("d2", 1);
        let candidates = vec![d1.clone(), d2];

        let state = MockState::new();
        let strategy = CostBased::new();
        let ctx = RoutingContext {
            estimated_input_tokens: Some(1000),
            estimated_output_tokens: Some(1000),
            ..Default::default()
        };

        let result = strategy
            .select("test-model", &candidates, &state, &ctx)
            .await
            .unwrap();
        assert_eq!(result.id, "d1");
    }

    /// Checks the formula: 2000 input tokens at 0.01 per 1k plus 1000 output
    /// tokens at 0.03 per 1k should come to 0.05.
    #[test]
    fn test_estimate_cost_calculation() {
        let d = make_deployment_with_costs("d1", 1, 0.01, 0.03);
        let ctx = RoutingContext {
            estimated_input_tokens: Some(2000),
            estimated_output_tokens: Some(1000),
            ..Default::default()
        };

        let cost = estimate_cost(&d, &ctx);
        assert!((cost - 0.05).abs() < f64::EPSILON);
    }

    /// A deployment built without explicit prices is expected to estimate to
    /// zero regardless of the token counts.
    #[test]
    fn test_estimate_cost_no_prices() {
        let d = make_deployment("d1", 1);
        let ctx = RoutingContext {
            estimated_input_tokens: Some(2000),
            estimated_output_tokens: Some(1000),
            ..Default::default()
        };

        let cost = estimate_cost(&d, &ctx);
        assert!((cost - 0.0).abs() < f64::EPSILON);
    }
}