hyperinfer-server 0.1.1

High-performance LLM Gateway server built with Axum
use hyperinfer_core::{ChatRequest, Database, Provider};
use hyperinfer_router::{
    deployment::Deployment as RouterDeployment,
    engine::{GlobalLimits, RouterEngine},
    error::RoutingError,
    strategy::{
        cost_based::CostBased, latency_based::LatencyBased, least_busy::LeastBusy,
        usage_based::UsageBased, weighted_shuffle::WeightedShuffle, DeploymentMetrics,
        RecordFailureResult, RoutingContext, RoutingState,
    },
};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::LazyLock;

/// Process-wide HTTP client for outbound provider calls.
///
/// Wrapped in a [`LazyLock`], so `reqwest::Client::new` runs once on first
/// access; `forward_request` then sends every upstream request through this
/// same client instead of building a new one per call.
static HTTP_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(reqwest::Client::new);

/// Textual prefixes of IP literals that an upstream base URL must not target.
///
/// `validate_base_url` compares each entry against the string form of an IP
/// host with `starts_with`; the trailing dot on every entry keeps, for
/// example, `"10."` from matching `100.x.x.x`.
const BLOCKED_IP_PREFIXES: &[&str] = &[
    // 169.254.x.x (link-local).
    "169.254.",
    // 10.x.x.x (private).
    "10.",
    // 172.16.x.x through 172.31.x.x (private), one entry per second octet.
    "172.16.",
    "172.17.",
    "172.18.",
    "172.19.",
    "172.20.",
    "172.21.",
    "172.22.",
    "172.23.",
    "172.24.",
    "172.25.",
    "172.26.",
    "172.27.",
    "172.28.",
    "172.29.",
    "172.30.",
    "172.31.",
    // 192.168.x.x (private).
    "192.168.",
    // 127.x.x.x (loopback).
    "127.",
    // 0.x.x.x ("this network").
    "0.",
];

/// Identity resolved from a successfully validated proxy API key.
///
/// Built by `validate_api_key` from the stored key record once the key has
/// been found and confirmed active.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProxyAuth {
    /// `team_id` of the stored key record.
    pub team_id: String,
    /// `id` of the stored key record (not the secret key itself).
    pub api_key_id: String,
}

/// Stateless stand-in for routing state.
///
/// Its `RoutingState` implementation records nothing and always reports
/// default metrics; `select_deployment` passes it to the router engine.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopState;

/// `RoutingState` implementation that keeps no state at all.
///
/// Every query answers with a neutral value and every recording call is a
/// successful no-op, so none of the methods can fail.
#[async_trait::async_trait]
impl RoutingState for NoopState {
    /// Returns default metrics, whatever deployment is asked about.
    async fn get_metrics(&self, _deployment_id: &str) -> Result<DeploymentMetrics, RoutingError> {
        let metrics = DeploymentMetrics::default();
        Ok(metrics)
    }

    /// Returns an empty map; no entry is produced for any requested id.
    async fn get_all_metrics(
        &self,
        _ids: &[&str],
    ) -> Result<HashMap<String, DeploymentMetrics>, RoutingError> {
        let empty: HashMap<String, DeploymentMetrics> = HashMap::new();
        Ok(empty)
    }

    /// Always answers `false`.
    async fn is_cooled_down(&self, _deployment_id: &str) -> Result<bool, RoutingError> {
        let cooled_down = false;
        Ok(cooled_down)
    }

    /// Accepts the notification and discards it.
    async fn record_request_start(&self, _deployment_id: &str) -> Result<(), RoutingError> {
        Ok(())
    }

    /// Accepts the success report (latency and token count) and discards it.
    async fn record_request_success(
        &self,
        _deployment_id: &str,
        _latency_ms: f64,
        _tokens: u64,
    ) -> Result<(), RoutingError> {
        Ok(())
    }

    /// Reports zero accumulated failures and never triggers a cooldown.
    async fn record_request_failure(
        &self,
        _deployment_id: &str,
    ) -> Result<RecordFailureResult, RoutingError> {
        let outcome = RecordFailureResult {
            cooldown_triggered: false,
            failure_count: 0,
        };
        Ok(outcome)
    }
}

/// Resolves a caller-supplied proxy API key to its owning team.
///
/// The key is hashed with SHA-256 and the lowercase hex digest is looked up
/// through `db.get_api_key_by_hash`; the plain-text key is never passed to
/// the database.
///
/// # Errors
///
/// The error value is an HTTP status code:
/// * `401` — the key is empty, unknown, or its record is not active.
/// * `500` — the database lookup itself failed.
pub async fn validate_api_key<D: Database>(db: &D, api_key: &str) -> Result<ProxyAuth, u16> {
    if api_key.is_empty() {
        return Err(401);
    }

    // One-shot digest; same bytes as feeding a fresh hasher and finalizing.
    let digest_hex = hex::encode(Sha256::digest(api_key.as_bytes()));

    let record = db
        .get_api_key_by_hash(&digest_hex)
        .await
        .map_err(|_| 500u16)?
        .ok_or(401u16)?;

    if record.is_active {
        Ok(ProxyAuth {
            team_id: record.team_id,
            api_key_id: record.id,
        })
    } else {
        Err(401)
    }
}

/// Rejects upstream base URLs that point at internal addresses or use a
/// scheme other than `http`/`https`.
///
/// The host is taken from `Url::host()` rather than `host_str()`: the string
/// form of an IPv6 host is bracketed (`[::1]`) and never parses as an
/// [`IpAddr`], which would let every IPv6 literal skip the block list.
///
/// Hostnames are not resolved here, so a DNS name that resolves to an
/// internal address is not caught by this check.
///
/// # Errors
///
/// Returns `Err(400)` when the URL does not parse, has no host, targets a
/// blocked IP literal, or uses an unsupported scheme.
fn validate_base_url(url: &str) -> Result<(), u16> {
    let parsed = url::Url::parse(url).map_err(|_| 400u16)?;

    let literal_ip = match parsed.host().ok_or(400u16)? {
        url::Host::Ipv4(v4) => Some(IpAddr::V4(v4)),
        url::Host::Ipv6(v6) => Some(IpAddr::V6(v6)),
        // Non-special schemes keep IP-looking hosts as opaque text; still
        // try to read them as an address, as the string-based check did.
        url::Host::Domain(name) => name.parse::<IpAddr>().ok(),
    };
    if literal_ip.is_some_and(is_blocked_ip) {
        return Err(400);
    }

    if !matches!(parsed.scheme(), "http" | "https") {
        return Err(400);
    }

    Ok(())
}

/// Reports whether `ip` is an address the gateway must never forward to.
fn is_blocked_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            let text = v4.to_string();
            BLOCKED_IP_PREFIXES
                .iter()
                .any(|prefix| text.starts_with(*prefix))
        }
        IpAddr::V6(v6) => {
            // `::ffff:a.b.c.d` reaches the IPv4 host `a.b.c.d`, so judge it
            // by the IPv4 rules.
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_blocked_ip(IpAddr::V4(mapped));
            }
            let first_segment = v6.segments()[0];
            v6.is_loopback()
                || v6.is_unspecified()
                // fc00::/7 — unique local addresses.
                || (first_segment & 0xfe00) == 0xfc00
                // fe80::/10 — link-local unicast.
                || (first_segment & 0xffc0) == 0xfe80
        }
    }
}

/// Returns the base URL used when a deployment does not configure one.
///
/// Only Anthropic has a dedicated default; OpenAI and every other provider
/// share the OpenAI endpoint.
// NOTE(review): a non-OpenAI provider without a base URL will therefore be
// sent to api.openai.com — confirm that fallback is intended.
fn default_base_url_for_provider(provider: &Provider) -> String {
    const ANTHROPIC_BASE: &str = "https://api.anthropic.com/v1";
    const OPENAI_BASE: &str = "https://api.openai.com/v1";

    let base = if matches!(provider, Provider::Anthropic) {
        ANTHROPIC_BASE
    } else {
        OPENAI_BASE
    };
    base.to_owned()
}

/// Adds the provider-specific credential headers to `headers`.
///
/// * Anthropic: `x-api-key` carrying the key, plus a fixed
///   `anthropic-version` header.
/// * Any other provider: `authorization: Bearer <key>`.
///
/// An empty `api_key` leaves `headers` untouched and succeeds.
///
/// # Errors
///
/// Returns `Err(500)` when a value cannot be turned into a header value.
fn set_auth_header(
    headers: &mut reqwest::header::HeaderMap,
    provider: &Provider,
    api_key: &str,
) -> Result<(), u16> {
    use reqwest::header::HeaderValue;

    if api_key.is_empty() {
        return Ok(());
    }

    if matches!(provider, Provider::Anthropic) {
        let key_value: HeaderValue = api_key.parse().map_err(|_| 500u16)?;
        headers.insert("x-api-key", key_value);

        let version_value: HeaderValue = "2023-06-01".parse().map_err(|_| 500u16)?;
        headers.insert("anthropic-version", version_value);
    } else {
        let bearer_text = format!("Bearer {}", api_key);
        let bearer_value: HeaderValue = bearer_text.parse().map_err(|_| 500u16)?;
        headers.insert("authorization", bearer_value);
    }

    Ok(())
}

/// Result of `select_deployment`: the chosen deployment together with the
/// connection details needed to call it.
pub struct SelectedDeployment {
    /// Router-level description of the chosen deployment.
    pub deployment: RouterDeployment,
    /// Base URL to call: the deployment's own URL when it has one, otherwise
    /// the provider default.
    pub base_url: String,
    /// Copied from the chosen deployment's `api_key_ref`.
    // NOTE(review): the field is named as a *reference* but is used here as
    // the credential itself — confirm whether it needs resolving first.
    pub api_key: String,
    /// Provider of the chosen deployment.
    pub provider: Provider,
}

pub async fn select_deployment<D: Database>(
    db: &D,
    request: &ChatRequest,
    deployments: &[hyperinfer_core::Deployment],
    model_aliases: &std::collections::HashMap<String, String>,
    _auth: Option<&ProxyAuth>,
) -> Result<SelectedDeployment, RoutingError> {
    if deployments.is_empty() {
        return Err(RoutingError::NoDeployments(request.model.clone()));
    }

    let engine = RouterEngine::new(GlobalLimits::default());
    engine
        .register_strategy(Box::new(WeightedShuffle::new()))
        .await;
    engine
        .register_strategy(Box::new(LatencyBased::new()))
        .await;
    engine.register_strategy(Box::new(LeastBusy::new())).await;
    engine.register_strategy(Box::new(UsageBased::new())).await;
    engine.register_strategy(Box::new(CostBased::new())).await;

    for (alias, target) in model_aliases {
        engine.set_alias(alias, target).await;
    }

    if let Ok(Some(config)) = db.get_routing_config().await {
        engine.set_default_strategy(&config.strategy).await;

        if let Ok(groups) = serde_json::from_value::<std::collections::HashMap<String, String>>(
            config.routing_groups,
        ) {
            for (model, strategy_name) in groups {
                engine.set_routing_group(&model, &strategy_name).await;
            }
        }
    }

    for d in deployments {
        let provider = match d.provider.as_str() {
            "openai" => Provider::OpenAI,
            "anthropic" => Provider::Anthropic,
            _ => Provider::Other,
        };
        let mut router_deployment = RouterDeployment::new(
            d.name.clone(),
            provider,
            d.model.clone(),
            d.api_key_ref.clone(),
        );
        router_deployment.id = d.id.clone();
        if !d.base_url.is_empty() {
            router_deployment = router_deployment.with_base_url(d.base_url.clone());
        }
        router_deployment = router_deployment.with_weight(d.weight);
        if let Some(max_tpm) = d.max_tpm {
            router_deployment = router_deployment.with_tpm_limit(max_tpm as u64);
        }
        if let Some(max_rpm) = d.max_rpm {
            router_deployment = router_deployment.with_rpm_limit(max_rpm as u64);
        }
        if let Some(cost) = d.cost_per_1k_input_tokens {
            router_deployment = router_deployment.with_input_cost(cost);
        }
        if let Some(cost) = d.cost_per_1k_output_tokens {
            router_deployment = router_deployment.with_output_cost(cost);
        }
        engine.add_deployment(router_deployment).await;
    }

    let state_ref = &NoopState;
    let ctx = RoutingContext::default();
    let result = engine
        .select_deployment(&request.model, state_ref, &ctx)
        .await?;

    let selected = &result.deployment;
    let base_url = selected
        .base_url
        .clone()
        .unwrap_or_else(|| default_base_url_for_provider(&selected.provider));

    Ok(SelectedDeployment {
        deployment: RouterDeployment::new(
            selected.model_name.clone(),
            selected.provider.clone(),
            selected.model.clone(),
            selected.api_key_ref.clone(),
        ),
        base_url,
        api_key: selected.api_key_ref.clone(),
        provider: selected.provider.clone(),
    })
}

/// Sends `request` to `{base_url}/chat/completions` and returns the decoded
/// JSON body of a successful response.
///
/// `base_url` is checked with `validate_base_url` before any network
/// activity, and trailing slashes are trimmed before the path is appended.
/// Credentials are attached by `set_auth_header` according to `provider`.
///
/// The upstream status is inspected *before* the body is decoded, so an
/// upstream error is reported with its own status code even when its body is
/// empty or not JSON.
///
/// # Errors
///
/// The error value is an HTTP status code:
/// * `400` — `base_url` was rejected by `validate_base_url`.
/// * `500` — a request header value could not be built.
/// * `502` — the request could not be sent, or a successful response did not
///   contain valid JSON.
/// * any other code — the non-success status returned by the upstream.
pub async fn forward_request(
    request: &ChatRequest,
    base_url: &str,
    api_key: &str,
    provider: &Provider,
) -> Result<serde_json::Value, u16> {
    validate_base_url(base_url)?;

    let url = format!("{}/chat/completions", base_url.trim_end_matches('/'));

    let mut headers = reqwest::header::HeaderMap::new();
    headers.insert(
        "content-type",
        "application/json".parse().map_err(|_| 500u16)?,
    );
    set_auth_header(&mut headers, provider, api_key)?;

    let response = HTTP_CLIENT
        .post(&url)
        .headers(headers)
        .json(request)
        .send()
        .await
        .map_err(|_| 502u16)?;

    // Surface upstream failures as-is; their bodies are not needed and are
    // often not JSON (plain-text rate-limit or gateway error pages).
    let status = response.status();
    if !status.is_success() {
        return Err(status.as_u16());
    }

    response
        .json::<serde_json::Value>()
        .await
        .map_err(|_| 502u16)
}