// hyperinfer-router 0.1.0
//
// Intelligent request routing engine for HyperInfer
// Documentation
use hyperinfer_core::Provider;
use hyperinfer_router::{
    Deployment, GlobalLimits, RedisConfig, RedisRoutingState, RouterEngine, RoutingContext,
    RoutingState, UsageBased, WeightedShuffle,
};
use std::time::Duration;
use testcontainers::{core::IntoContainerPort, runners::AsyncRunner, GenericImage};
use testcontainers_modules::redis::REDIS_PORT;

/// Starts a `redis:7.2` container and returns a connection URL plus the container handle.
///
/// The container is considered ready once "Ready to accept connections" appears on its
/// stdout. The returned URL targets `127.0.0.1` on the host port mapped to `REDIS_PORT`.
/// The handle is handed back to the caller (tests bind it as `_container`); presumably
/// dropping it tears the container down — confirm against the testcontainers docs.
///
/// # Panics
///
/// Panics if the container cannot be started or the mapped host port cannot be resolved.
async fn setup_redis() -> (String, testcontainers::ContainerAsync<GenericImage>) {
    let ready = testcontainers::core::WaitFor::message_on_stdout("Ready to accept connections");
    let image = GenericImage::new("redis", "7.2")
        .with_exposed_port(REDIS_PORT.tcp())
        .with_wait_for(ready);

    let container = image
        .start()
        .await
        .expect("Failed to start Redis container");

    let host_port = container.get_host_port_ipv4(REDIS_PORT).await.unwrap();
    let url = format!("redis://127.0.0.1:{}", host_port);
    (url, container)
}

/// Builds an OpenAI-backed [`Deployment`] fixture for the tests in this file.
///
/// `model` is supplied for both the first and third constructor arguments, and the fourth
/// argument is the string `key-<id>` (presumably the API key — confirm against
/// `Deployment::new`). After construction the `id` and `weight` fields are overwritten
/// with the given values.
fn make_deployment(model: &str, id: &str, weight: u32) -> Deployment {
    let key = format!("key-{}", id);
    let mut deployment = Deployment::new(
        model.to_owned(),
        Provider::OpenAI,
        model.to_owned(),
        key,
    );
    deployment.id = id.to_owned();
    deployment.weight = weight;
    deployment
}

/// Smoke test of the whole routing path against a real Redis instance.
///
/// Registers the weighted-shuffle strategy, adds two `gpt-4` deployments (weights 3 and 1),
/// and checks that a single selection returns one of them on the first attempt.
#[tokio::test]
async fn test_full_routing_flow_with_redis() {
    let (redis_url, _container) = setup_redis().await;
    let state = RedisRoutingState::new(&redis_url, RedisConfig::default())
        .await
        .unwrap();

    let engine = RouterEngine::new(GlobalLimits::default());
    engine
        .register_strategy(Box::new(WeightedShuffle::new()))
        .await;

    // Same registration order as before: dep-a first, then dep-b.
    for (id, weight) in [("dep-a", 3), ("dep-b", 1)] {
        engine
            .add_deployment(make_deployment("gpt-4", id, weight))
            .await;
    }

    let ctx = RoutingContext::default();
    let result = engine
        .select_deployment("gpt-4", &state, &ctx)
        .await
        .unwrap();

    // The shuffle is weighted, so either deployment is a valid outcome.
    let chosen = result.deployment.id.as_str();
    assert!(
        matches!(chosen, "dep-a" | "dep-b"),
        "Expected one of dep-a or dep-b, got {}",
        chosen
    );
    assert_eq!(result.attempt, 1);
}

/// Checks that recorded usage steers the usage-based strategy away from a busy deployment.
///
/// Two `gpt-4` deployments share a `tpm_limit` of 10000. A successful request with the
/// values `100.0` and `9000` is recorded against `dep-a` (presumably latency in ms and
/// token count — confirm against `record_request_success`), after which the engine is
/// expected to pick `dep-b`.
#[tokio::test]
async fn test_metrics_affect_routing() {
    let (redis_url, _container) = setup_redis().await;
    let state = RedisRoutingState::new(&redis_url, RedisConfig::default())
        .await
        .unwrap();

    let engine = RouterEngine::new(GlobalLimits::default());
    engine.register_strategy(Box::new(UsageBased::new())).await;
    engine.set_default_strategy("usage-based").await;

    // Both deployments are configured identically; only recorded usage differs.
    for id in ["dep-a", "dep-b"] {
        let mut deployment = make_deployment("gpt-4", id, 1);
        deployment.tpm_limit = Some(10000);
        engine.add_deployment(deployment).await;
    }

    state.record_request_start("dep-a").await.unwrap();
    tokio::time::sleep(Duration::from_millis(50)).await;
    let recorded = state.record_request_success("dep-a", 100.0, 9000).await;
    recorded.unwrap();
    // NOTE(review): presumably gives the recorded metrics time to settle before routing.
    tokio::time::sleep(Duration::from_millis(100)).await;

    let ctx = RoutingContext::default();
    let selection = engine.select_deployment("gpt-4", &state, &ctx).await;
    let result = selection.unwrap();

    assert_eq!(
        result.deployment.id, "dep-b",
        "dep_b should be selected since dep_a has high utilization"
    );
}

/// Checks that a deployment in cooldown is excluded from selection.
///
/// With `allowed_fails: 2` and `cooldown_secs: 10`, two start/failure pairs are recorded
/// for `dep-a`. The test then asserts that `is_cooled_down("dep-a")` is true and that the
/// engine selects `dep-b`.
#[tokio::test]
async fn test_cooldown_removes_from_pool() {
    let (redis_url, _container) = setup_redis().await;
    let cooldown_config = RedisConfig {
        allowed_fails: 2,
        cooldown_secs: 10,
        ..RedisConfig::default()
    };
    let state = RedisRoutingState::new(&redis_url, cooldown_config)
        .await
        .unwrap();

    let engine = RouterEngine::new(GlobalLimits::default());
    engine
        .register_strategy(Box::new(WeightedShuffle::new()))
        .await;

    for id in ["dep-a", "dep-b"] {
        engine.add_deployment(make_deployment("gpt-4", id, 1)).await;
    }

    // Record two consecutive failures against dep-a.
    for _ in 0..2 {
        state.record_request_start("dep-a").await.unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        state.record_request_failure("dep-a").await.unwrap();
    }

    assert!(
        state.is_cooled_down("dep-a").await.unwrap(),
        "dep-a should be in cooldown"
    );

    let ctx = RoutingContext::default();
    let selection = engine.select_deployment("gpt-4", &state, &ctx).await;
    let result = selection.unwrap();

    assert_eq!(
        result.deployment.id, "dep-b",
        "dep_b should be selected since dep_a is in cooldown"
    );
}

#[tokio::test]
async fn test_concurrent_routing_decisions() {
    let (redis_url, _container) = setup_redis().await;
    let state = RedisRoutingState::new(&redis_url, RedisConfig::default())
        .await
        .unwrap();

    let engine = std::sync::Arc::new(RouterEngine::new(GlobalLimits::default()));
    engine
        .register_strategy(Box::new(WeightedShuffle::new()))
        .await;

    for i in 0..5 {
        let dep = make_deployment("gpt-4", &format!("dep-{}", i), 1);
        engine.add_deployment(dep).await;
    }

    let mut handles = Vec::new();
    for _ in 0..100 {
        let engine = engine.clone();
        let state = state.clone();
        handles.push(tokio::spawn(async move {
            let ctx = RoutingContext::default();
            engine
                .select_deployment("gpt-4", &state, &ctx)
                .await
                .unwrap()
        }));
    }

    let mut results = Vec::new();
    for h in handles {
        results.push(h.await.unwrap());
    }

    assert_eq!(
        results.len(),
        100,
        "All 100 routing decisions should succeed"
    );
    for r in &results {
        assert!(
            r.deployment.id.starts_with("dep-"),
            "Unexpected deployment id: {}",
            r.deployment.id
        );
    }
}

/// Checks that success reports sent through the engine end up in the Redis-backed metrics.
///
/// With a single deployment registered, selection must return it. A request start is then
/// recorded on the state directly, a success is reported via `RouterEngine::record_success`
/// with the values `200.0` and `300`, and the stored metrics are expected to show one
/// request, a latency EWMA within 1.0 of 200.0, and 300 for `tpm_used`.
#[tokio::test]
async fn test_engine_passthrough_with_redis() {
    const DEPLOYMENT_ID: &str = "dep-passthrough";

    let (redis_url, _container) = setup_redis().await;
    let state = RedisRoutingState::new(&redis_url, RedisConfig::default())
        .await
        .unwrap();

    let engine = RouterEngine::new(GlobalLimits::default());
    engine
        .register_strategy(Box::new(WeightedShuffle::new()))
        .await;
    engine
        .add_deployment(make_deployment("gpt-4", DEPLOYMENT_ID, 1))
        .await;

    let ctx = RoutingContext::default();
    let selection = engine.select_deployment("gpt-4", &state, &ctx).await;
    let result = selection.unwrap();
    assert_eq!(result.deployment.id, DEPLOYMENT_ID);

    state.record_request_start(DEPLOYMENT_ID).await.unwrap();
    tokio::time::sleep(Duration::from_millis(50)).await;

    engine
        .record_success(DEPLOYMENT_ID, 200.0, 300, &state)
        .await;
    // NOTE(review): presumably gives the recorded metrics time to settle before reading.
    tokio::time::sleep(Duration::from_millis(100)).await;

    let metrics = state.get_metrics(DEPLOYMENT_ID).await.unwrap();
    assert_eq!(metrics.total_requests, 1);
    let latency_error = (metrics.latency_ewma_ms - 200.0).abs();
    assert!(latency_error < 1.0);
    assert_eq!(metrics.tpm_used, 300);
}