//! hyperinfer-router 0.1.0
//!
//! Intelligent request routing engine for HyperInfer.
//!
//! Documentation
use crate::error::RoutingError;
use crate::strategy::{DeploymentMetrics, RecordFailureResult, RoutingState};
use async_trait::async_trait;
use redis::aio::ConnectionManager;
use redis::Client;
use std::collections::HashMap;
use tracing::warn;

/// Namespace for all per-deployment routing metric keys.
///
/// A deployment's key prefix is this string followed by the deployment id
/// (see `RedisRoutingState::prefix`); individual metrics then append a
/// `:<name>` suffix such as `:latency` or `:in_flight`.
const METRICS_KEY_PREFIX: &str = "hyperinfer:routing:metrics:";

/// Lua script that records the outcome of one finished request in a single
/// `EVAL` (Redis runs a script without interleaving other commands).
///
/// `KEYS[1]` is the deployment's key *prefix*, not a full key: the script
/// derives `:in_flight`, `:latency`, `:tpm`, `:failures`, `:cooldown`,
/// `:total_req` and `:total_fail` from it. `ARGV`, in order: outcome
/// (`"success"`, anything else is treated as a failure), latency in ms, token
/// count, EWMA alpha, allowed fails, cooldown seconds, and the TTLs (seconds)
/// for `:tpm`, `:latency` and `:failures`.
///
/// Both outcomes decrement `:in_flight` only if it is currently positive.
///
/// * Success: the latency EWMA is updated as
///   `alpha * latency + (1 - alpha) * old`, or seeded with the raw latency when
///   the stored value is missing, zero, or unparsable, and re-stored with the
///   latency TTL. Tokens (if > 0) are added to `:tpm`, whose TTL is set only
///   when the key has none, so it is a fixed window. `:failures` is reset to 0
///   and `:total_req` incremented. Returns `{0, 0}`.
/// * Failure: `:failures` is incremented (TTL set if absent); once it reaches
///   `allowed_fails`, `:cooldown` is set for `cooldown_secs`. `:total_fail` is
///   incremented. Returns `{failure_count, cooldown_triggered}`, where the flag
///   is 0 or 1.
///
/// NOTE(review): every key except `KEYS[1]` is built inside the script instead
/// of being declared in `KEYS`, which Redis Cluster only tolerates when all keys
/// hash to one slot. Confirm this only targets standalone Redis.
const METRIC_UPDATE_SCRIPT: &str = r#"
local prefix = KEYS[1]
local outcome = ARGV[1]
local latency_ms = tonumber(ARGV[2])
local token_count = tonumber(ARGV[3])
local alpha = tonumber(ARGV[4])
local allowed_fails = tonumber(ARGV[5])
local cooldown_secs = tonumber(ARGV[6])
local tpm_ttl = tonumber(ARGV[7])
local latency_ttl = tonumber(ARGV[8])
local failures_ttl = tonumber(ARGV[9])

local in_flight = tonumber(redis.call('GET', prefix .. ':in_flight') or '0')
if in_flight > 0 then
    redis.call('DECR', prefix .. ':in_flight')
end

if outcome == 'success' then
    local old_ewma = tonumber(redis.call('GET', prefix .. ':latency') or '0')
    local new_ewma
    if old_ewma == 0 or old_ewma == nil then
        new_ewma = latency_ms
    else
        new_ewma = alpha * latency_ms + (1 - alpha) * old_ewma
    end
    redis.call('SET', prefix .. ':latency', tostring(new_ewma), 'EX', latency_ttl)

    if token_count > 0 then
        local tpm_key = prefix .. ':tpm'
        redis.call('INCRBY', tpm_key, token_count)
        if redis.call('TTL', tpm_key) == -1 then
            redis.call('EXPIRE', tpm_key, tpm_ttl)
        end
    end

    redis.call('SET', prefix .. ':failures', '0', 'EX', failures_ttl)
    redis.call('INCR', prefix .. ':total_req')
    return {0, 0}
else
    local fails = redis.call('INCR', prefix .. ':failures')
    if redis.call('TTL', prefix .. ':failures') == -1 then
        redis.call('EXPIRE', prefix .. ':failures', failures_ttl)
    end
    local cooldown_triggered = 0
    if fails >= allowed_fails then
        redis.call('SET', prefix .. ':cooldown', '1', 'EX', cooldown_secs)
        cooldown_triggered = 1
    end
    redis.call('INCR', prefix .. ':total_fail')
    return {fails, cooldown_triggered}
end
"#;

/// Interprets an optional Redis string reply as a `u64`.
///
/// An absent reply, or text that is not a valid unsigned integer (including
/// negatives and decimals), yields 0.
fn parse_u64(v: &Option<String>) -> u64 {
    match v {
        Some(raw) => raw.parse().unwrap_or(0),
        None => 0,
    }
}

/// Interprets an optional Redis string reply as an `f64`.
///
/// An absent reply, or text that `f64::from_str` rejects, yields 0.0.
fn parse_f64(v: &Option<String>) -> f64 {
    if let Some(raw) = v {
        if let Ok(parsed) = raw.parse::<f64>() {
            return parsed;
        }
    }
    0.0
}

/// Builds a [`DeploymentMetrics`] from six consecutive `MGET` replies that
/// start at index `base`.
///
/// The replies must be ordered `latency`, `in_flight`, `tpm`, `rpm`,
/// `total_req`, `total_fail`. Missing or unparsable replies become zero.
/// `last_failure_ts` is not read from Redis here and is always `None`.
///
/// # Panics
///
/// Panics if `values` has fewer than `base + 6` entries.
fn metrics_from_values(values: &[Option<String>], base: usize) -> DeploymentMetrics {
    let row = &values[base..base + 6];
    DeploymentMetrics {
        latency_ewma_ms: parse_f64(&row[0]),
        in_flight: parse_u64(&row[1]),
        tpm_used: parse_u64(&row[2]),
        rpm_used: parse_u64(&row[3]),
        total_requests: parse_u64(&row[4]),
        total_failures: parse_u64(&row[5]),
        last_failure_ts: None,
    }
}

/// Tuning parameters for the Redis-side routing bookkeeping.
///
/// All six values are passed to the metric-update Lua script whenever a
/// request's success or failure is recorded.
#[derive(Debug, Clone)]
pub struct RedisConfig {
    /// Weight of the newest latency sample in the EWMA:
    /// `new = alpha * sample + (1 - alpha) * old`.
    pub alpha: f64,
    /// Failure count, since the last success and within `failures_ttl_secs`,
    /// at which the deployment's cooldown marker is set.
    pub allowed_fails: u64,
    /// Lifetime in seconds of the cooldown marker once it is set.
    pub cooldown_secs: u64,
    /// Expiry in seconds given to the token counter when it has no TTL yet.
    pub tpm_ttl_secs: u64,
    /// Expiry in seconds reapplied to the latency EWMA key on every success.
    pub latency_ttl_secs: u64,
    /// Expiry in seconds for the failure counter.
    pub failures_ttl_secs: u64,
}

impl Default for RedisConfig {
    /// EWMA alpha of 0.3; cooldown of 30 s after 3 failures; TTLs of 60 s
    /// (tokens), 600 s (latency) and 300 s (failures).
    fn default() -> Self {
        let alpha = 0.3;
        let allowed_fails = 3;
        let cooldown_secs = 30;
        let tpm_ttl_secs = 60;
        let latency_ttl_secs = 600;
        let failures_ttl_secs = 300;

        Self {
            alpha,
            allowed_fails,
            cooldown_secs,
            tpm_ttl_secs,
            latency_ttl_secs,
            failures_ttl_secs,
        }
    }
}

/// [`RoutingState`] implementation that keeps per-deployment routing metrics
/// in Redis.
///
/// Cloning copies the config and clones the `ConnectionManager` handle.
#[derive(Clone)]
pub struct RedisRoutingState {
    manager: ConnectionManager,
    config: RedisConfig,
}

impl RedisRoutingState {
    /// Connects to the Redis server at `redis_url` and stores `config` for
    /// later metric updates.
    ///
    /// # Errors
    ///
    /// Returns a [`RoutingError`] if `Client::open` rejects the URL or
    /// `ConnectionManager::new` fails.
    pub async fn new(redis_url: &str, config: RedisConfig) -> Result<Self, RoutingError> {
        let manager = ConnectionManager::new(Client::open(redis_url)?).await?;
        Ok(Self { manager, config })
    }

    /// Key prefix for deployment `id`: [`METRICS_KEY_PREFIX`] immediately
    /// followed by the id, with no separator.
    fn prefix(id: &str) -> String {
        [METRICS_KEY_PREFIX, id].concat()
    }
}

/// Number of metric keys stored per deployment, which is also the number of
/// consecutive `MGET` replies that describe one deployment.
const KEYS_PER_DEPLOYMENT: usize = 6;

/// Returns the metric keys under `prefix`, in the order in which
/// [`metrics_from_values`] reads the corresponding `MGET` replies.
fn metric_keys(prefix: &str) -> [String; KEYS_PER_DEPLOYMENT] {
    [
        format!("{}:latency", prefix),
        format!("{}:in_flight", prefix),
        format!("{}:tpm", prefix),
        format!("{}:rpm", prefix),
        format!("{}:total_req", prefix),
        format!("{}:total_fail", prefix),
    ]
}

#[async_trait]
impl RoutingState for RedisRoutingState {
    /// Reads one deployment's metrics with a single `MGET`.
    ///
    /// Keys that are absent or hold unparsable values are reported as zero.
    async fn get_metrics(&self, deployment_id: &str) -> Result<DeploymentMetrics, RoutingError> {
        let keys = metric_keys(&Self::prefix(deployment_id));
        let mut conn = self.manager.clone();

        let mut cmd = redis::cmd("MGET");
        for key in &keys {
            cmd.arg(key);
        }
        let values: Vec<Option<String>> = cmd.query_async(&mut conn).await?;

        Ok(metrics_from_values(&values, 0))
    }

    /// Reads the metrics of every id in `ids` with one `MGET` round trip.
    ///
    /// Returns an empty map without contacting Redis when `ids` is empty. If
    /// an id appears more than once, the map holds a single entry for it.
    async fn get_all_metrics(
        &self,
        ids: &[&str],
    ) -> Result<HashMap<String, DeploymentMetrics>, RoutingError> {
        if ids.is_empty() {
            return Ok(HashMap::new());
        }

        let mut conn = self.manager.clone();
        let mut cmd = redis::cmd("MGET");
        for id in ids {
            for key in &metric_keys(&Self::prefix(id)) {
                cmd.arg(key);
            }
        }

        let values: Vec<Option<String>> = cmd.query_async(&mut conn).await?;

        // MGET answers with one reply per key in request order, so each
        // deployment owns one contiguous run of KEYS_PER_DEPLOYMENT replies.
        // An incomplete trailing run is dropped rather than indexed.
        let result: HashMap<String, DeploymentMetrics> = ids
            .iter()
            .zip(values.chunks_exact(KEYS_PER_DEPLOYMENT))
            .map(|(id, row)| (id.to_string(), metrics_from_values(row, 0)))
            .collect();

        Ok(result)
    }

    /// Returns `true` while the deployment's `:cooldown` key exists. That key
    /// is set with an expiry by the failure path of the metric-update script.
    async fn is_cooled_down(&self, deployment_id: &str) -> Result<bool, RoutingError> {
        let prefix = Self::prefix(deployment_id);
        let mut conn = self.manager.clone();

        let val: Option<String> = redis::cmd("GET")
            .arg(format!("{}:cooldown", prefix))
            .query_async(&mut conn)
            .await?;

        Ok(val.is_some())
    }

    /// Marks a request as started: increments `:in_flight` and the
    /// per-minute request counter `:rpm`.
    ///
    /// This is best-effort. A Redis error is logged and `Ok(())` is still
    /// returned, so routing is never blocked by metrics bookkeeping.
    async fn record_request_start(&self, deployment_id: &str) -> Result<(), RoutingError> {
        // KEYS[1] = in-flight gauge, KEYS[2] = RPM counter, ARGV[1] = window secs.
        // The RPM expiry is armed only when the key has no TTL, so the counter
        // resets once per window (the same scheme `:tpm` uses). Re-issuing EXPIRE
        // on every request would keep pushing the expiry back, and under steady
        // traffic the counter would never reset.
        const REQUEST_START_SCRIPT: &str = r#"
redis.call('INCR', KEYS[1])
local rpm = redis.call('INCR', KEYS[2])
if redis.call('TTL', KEYS[2]) == -1 then
    redis.call('EXPIRE', KEYS[2], ARGV[1])
end
return rpm
"#;
        const RPM_WINDOW_SECS: u64 = 60;

        let prefix = Self::prefix(deployment_id);
        let mut conn = self.manager.clone();

        // NOTE(review): `:in_flight` is never given a TTL. It is only decremented
        // by the metric-update script, so a request whose outcome is never
        // recorded leaves the gauge permanently raised. Confirm that is acceptable.
        let result: Result<(), redis::RedisError> = redis::cmd("EVAL")
            .arg(REQUEST_START_SCRIPT)
            .arg(2)
            .arg(format!("{}:in_flight", prefix))
            .arg(format!("{}:rpm", prefix))
            .arg(RPM_WINDOW_SECS)
            .query_async(&mut conn)
            .await;

        if let Err(e) = result {
            warn!(deployment_id = %deployment_id, error = %e, "failed to record request start");
        }

        Ok(())
    }

    /// Records a successful request's latency and token usage.
    ///
    /// The update runs on a detached `tokio::spawn` task: this method returns
    /// `Ok(())` immediately, and a Redis error is only logged.
    ///
    /// NOTE(review): because the update is detached, it can land after a later
    /// `record_request_failure` for the same deployment and reset `:failures`
    /// to 0. Confirm that ordering is acceptable.
    async fn record_request_success(
        &self,
        deployment_id: &str,
        latency_ms: f64,
        tokens: u64,
    ) -> Result<(), RoutingError> {
        let prefix = Self::prefix(deployment_id);
        let manager = self.manager.clone();
        let config = self.config.clone();
        let deployment_id = deployment_id.to_string();

        tokio::spawn(async move {
            let mut conn = manager.clone();
            let result: Result<Vec<u64>, redis::RedisError> = redis::cmd("EVAL")
                .arg(METRIC_UPDATE_SCRIPT)
                .arg(1)
                .arg(&prefix)
                .arg("success")
                .arg(latency_ms)
                .arg(tokens)
                .arg(config.alpha)
                .arg(config.allowed_fails)
                .arg(config.cooldown_secs)
                .arg(config.tpm_ttl_secs)
                .arg(config.latency_ttl_secs)
                .arg(config.failures_ttl_secs)
                .query_async(&mut conn)
                .await;

            if let Err(e) = result {
                warn!(deployment_id = %deployment_id, error = %e, "failed to record request success");
            }
        });

        Ok(())
    }

    /// Records a failed request and reports whether it pushed the deployment
    /// into cooldown.
    ///
    /// # Errors
    ///
    /// Propagates Redis errors as [`RoutingError`]. Unlike the success path,
    /// this call waits for the script to finish.
    async fn record_request_failure(
        &self,
        deployment_id: &str,
    ) -> Result<RecordFailureResult, RoutingError> {
        let prefix = Self::prefix(deployment_id);
        let mut conn = self.manager.clone();

        // Latency and token arguments are ignored on the failure path but
        // must still occupy their ARGV slots.
        let result: Vec<u64> = redis::cmd("EVAL")
            .arg(METRIC_UPDATE_SCRIPT)
            .arg(1)
            .arg(&prefix)
            .arg("failure")
            .arg(0.0_f64)
            .arg(0_u64)
            .arg(self.config.alpha)
            .arg(self.config.allowed_fails)
            .arg(self.config.cooldown_secs)
            .arg(self.config.tpm_ttl_secs)
            .arg(self.config.latency_ttl_secs)
            .arg(self.config.failures_ttl_secs)
            .query_async(&mut conn)
            .await?;

        // The script returns {failure_count, cooldown_triggered (0/1)}.
        let failure_count = result.first().copied().unwrap_or(0);
        let cooldown_triggered = result.get(1).copied().unwrap_or(0);

        Ok(RecordFailureResult {
            failure_count,
            cooldown_triggered: cooldown_triggered == 1,
        })
    }
}