unigateway 1.5.0

Lightweight, local-first LLM gateway for developers. A stable, single-binary unified entry point for all your AI tools and models.
use std::sync::Arc;
use std::time::Instant;

use axum::{
    Json,
    body::Body,
    http::StatusCode,
    http::header,
    response::{IntoResponse, Response},
};
use serde_json::json;
use unigateway_host::{HostDispatchOutcome, HostError, HostResult, status::status_for_host_error};
use unigateway_protocol::{ProtocolHttpResponse, ProtocolResponseBody};

use crate::middleware::{GatewayAuth, record_stat};
use crate::types::GatewayRequestState;

use super::request_flow::PreparedGatewayRequest;

pub(super) type HostResponseResult =
    std::result::Result<ProtocolHttpResponse, ProtocolHttpResponse>;

pub(super) async fn resolve_core_only_host_flow<CoreFuture>(
    core_attempt: CoreFuture,
    unavailable_message: &str,
) -> HostResponseResult
where
    CoreFuture: std::future::Future<Output = HostResult<HostDispatchOutcome>>,
{
    match core_attempt.await {
        Ok(HostDispatchOutcome::Response(response)) => Ok(response),
        Ok(HostDispatchOutcome::PoolNotFound) => Err(error_json(
            StatusCode::SERVICE_UNAVAILABLE,
            unavailable_message,
        )),
        Ok(_) => Err(error_json(
            StatusCode::BAD_GATEWAY,
            "unsupported host dispatch outcome",
        )),
        Err(error) => Err(host_error_response(&error)),
    }
}

pub(super) fn missing_upstream_api_key_response() -> ProtocolHttpResponse {
    error_json(StatusCode::BAD_REQUEST, "missing upstream api key")
}

pub(super) async fn respond_prepared_host_result(
    state: &Arc<GatewayRequestState>,
    prepared: &PreparedGatewayRequest<'_>,
    endpoint: &str,
    result: HostResponseResult,
) -> Response {
    match prepared.auth.as_ref() {
        Some(auth) => {
            respond_authenticated_host_result(state, auth, endpoint, &prepared.start, result).await
        }
        None => respond_env_host_result(state, endpoint, &prepared.start, result).await,
    }
}

pub(super) async fn respond_authenticated_host_result(
    state: &Arc<GatewayRequestState>,
    auth: &GatewayAuth,
    endpoint: &str,
    start: &Instant,
    result: HostResponseResult,
) -> Response {
    match result {
        Ok(response) => gateway_success_response(state, auth, endpoint, start, response).await,
        Err(response) => gateway_error_response(state, auth, endpoint, start, response).await,
    }
}

pub(super) async fn respond_env_host_result(
    state: &Arc<GatewayRequestState>,
    endpoint: &str,
    start: &Instant,
    result: HostResponseResult,
) -> Response {
    match result {
        Ok(response) => success_response(state, endpoint, start, response).await,
        Err(response) => error_response(state, endpoint, start, response).await,
    }
}

async fn gateway_success_response(
    state: &Arc<GatewayRequestState>,
    auth: &GatewayAuth,
    endpoint: &str,
    start: &Instant,
    response: ProtocolHttpResponse,
) -> Response {
    auth.finalize(state).await;
    success_response(state, endpoint, start, response).await
}

async fn gateway_error_response(
    state: &Arc<GatewayRequestState>,
    auth: &GatewayAuth,
    endpoint: &str,
    start: &Instant,
    response: ProtocolHttpResponse,
) -> Response {
    auth.release(state).await;
    error_response(state, endpoint, start, response).await
}

async fn success_response(
    state: &Arc<GatewayRequestState>,
    endpoint: &str,
    start: &Instant,
    response: ProtocolHttpResponse,
) -> Response {
    recorded_response(state, endpoint, start, response).await
}

async fn error_response(
    state: &Arc<GatewayRequestState>,
    endpoint: &str,
    start: &Instant,
    response: ProtocolHttpResponse,
) -> Response {
    recorded_response(state, endpoint, start, response).await
}

async fn recorded_response(
    state: &Arc<GatewayRequestState>,
    endpoint: &str,
    start: &Instant,
    response: ProtocolHttpResponse,
) -> Response {
    let (status, body) = response.into_parts();
    record_stat(state, endpoint, status.as_u16(), start).await;
    into_axum_response(status, body)
}

fn into_axum_response(status: StatusCode, body: ProtocolResponseBody) -> Response {
    match body {
        ProtocolResponseBody::Json(value) => (status, Json(value)).into_response(),
        ProtocolResponseBody::ServerSentEvents(stream) => (
            status,
            [(header::CONTENT_TYPE, "text/event-stream")],
            Body::from_stream(stream),
        )
            .into_response(),
    }
}

fn host_error_response(error: &HostError) -> ProtocolHttpResponse {
    error_json(
        status_for_host_error(error),
        &format!("host execution error: {error}"),
    )
}

fn error_json(status: StatusCode, message: &str) -> ProtocolHttpResponse {
    ProtocolHttpResponse::json(status, json!({"error": {"message": message}}))
}