unigateway 1.2.1

Lightweight, local-first LLM gateway for developers. A stable, single-binary unified entry point for all your AI tools and models.
use std::sync::Arc;

use axum::{
    extract::{Query, State},
    http::{HeaderMap, StatusCode},
    response::IntoResponse,
};
use serde_json::json;

use crate::authz::is_admin_authorized;
use crate::dto::{ApiResponse, CreateServiceReq, ModeSummaryOut, ServiceOut, SetDefaultModeReq};
use crate::types::AppState;

/// Query-string parameters accepted by `api_list_modes`.
#[derive(serde::Deserialize)]
pub(crate) struct ListModesQuery {
    /// Selects the response shape: `Some(true)` makes the handler return the
    /// full mode views, while `None` or `Some(false)` yields summary rows.
    pub(crate) detailed: Option<bool>,
}

/// Admin handler that lists the services registered in the gateway.
///
/// Responds with `401 Unauthorized` when `is_admin_authorized` rejects the
/// request headers. Otherwise replies with a JSON `ApiResponse` whose `data`
/// holds one `ServiceOut` (`id`, `name`) per entry returned by
/// `gateway.list_services()`.
pub(crate) async fn api_list_services(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
) -> impl IntoResponse {
    let authorized = is_admin_authorized(&state, &headers).await;
    if !authorized {
        return StatusCode::UNAUTHORIZED.into_response();
    }

    // Convert the gateway's `(id, name)` pairs into the outward-facing DTO.
    let listed = state.gateway.list_services().await;
    let services: Vec<ServiceOut> = listed
        .into_iter()
        .map(|(id, name)| ServiceOut { id, name })
        .collect();

    let payload = ApiResponse {
        success: true,
        data: services,
    };
    axum::Json(payload).into_response()
}

/// Admin handler that registers a service from the JSON request body.
///
/// Responds with `401 Unauthorized` when `is_admin_authorized` rejects the
/// request headers. Otherwise it calls `gateway.create_service` with the
/// requested `id` and `name`, asks the gateway to persist if dirty, and echoes
/// the `id`/`name` pair back inside a successful `ApiResponse`.
pub(crate) async fn api_create_service(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    axum::Json(req): axum::Json<CreateServiceReq>,
) -> impl IntoResponse {
    let authorized = is_admin_authorized(&state, &headers).await;
    if !authorized {
        return StatusCode::UNAUTHORIZED.into_response();
    }

    let gateway = &state.gateway;
    gateway.create_service(&req.id, &req.name).await;
    // Whatever `persist_if_dirty` returns is discarded here, so the response
    // below does not reflect the outcome of persisting.
    let _ = gateway.persist_if_dirty().await;

    let created = json!({"id": req.id, "name": req.name});
    let payload = ApiResponse {
        success: true,
        data: created,
    };
    axum::Json(payload).into_response()
}

/// Admin handler that lists the gateway's modes.
///
/// Responds with `401 Unauthorized` when `is_admin_authorized` rejects the
/// request headers. With `detailed=true` in the query string the mode views
/// from `gateway.list_mode_views()` are serialized as-is; otherwise each view
/// is condensed into a `ModeSummaryOut` carrying its id, name, routing
/// strategy, default flag, provider count and provider names.
pub(crate) async fn api_list_modes(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    Query(query): Query<ListModesQuery>,
) -> impl IntoResponse {
    let authorized = is_admin_authorized(&state, &headers).await;
    if !authorized {
        return StatusCode::UNAUTHORIZED.into_response();
    }

    let views = state.gateway.list_mode_views().await;

    // A missing `detailed` flag is treated the same as `false`.
    let data = match query.detailed {
        Some(true) => json!(views),
        _ => {
            let summaries: Vec<ModeSummaryOut> = views
                .into_iter()
                .map(|view| {
                    // Take the count before the provider list is consumed.
                    let provider_count = view.providers.len();
                    let provider_names = view
                        .providers
                        .into_iter()
                        .map(|provider| provider.name)
                        .collect();
                    ModeSummaryOut {
                        id: view.id,
                        name: view.name,
                        routing_strategy: view.routing_strategy,
                        is_default: view.is_default,
                        provider_count,
                        provider_names,
                    }
                })
                .collect();
            json!(summaries)
        }
    };

    let payload = ApiResponse {
        success: true,
        data,
    };
    axum::Json(payload).into_response()
}

pub(crate) async fn api_set_default_mode(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    axum::Json(req): axum::Json<SetDefaultModeReq>,
) -> impl IntoResponse {
    if !is_admin_authorized(&state, &headers).await {
        return StatusCode::UNAUTHORIZED.into_response();
    }

    match state.gateway.set_default_mode(&req.mode_id).await {
        Ok(_) => {
            let _ = state.gateway.persist_if_dirty().await;
            axum::Json(ApiResponse {
                success: true,
                data: json!({"mode_id": req.mode_id}),
            })
            .into_response()
        }
        Err(err) => (
            StatusCode::BAD_REQUEST,
            axum::Json(json!({"success": false, "error": err.to_string()})),
        )
            .into_response(),
    }
}

#[cfg(test)]
mod tests {
    use super::{ListModesQuery, api_list_modes, api_set_default_mode};
    use crate::config::GatewayState;
    use crate::dto::SetDefaultModeReq;
    use crate::types::{AppConfig, AppState};
    use axum::Json;
    use axum::body::to_bytes;
    use axum::extract::{Query, State};
    use axum::http::{HeaderMap, HeaderValue, StatusCode};
    use axum::response::IntoResponse;
    use serde_json::Value;
    use std::path::Path;
    use std::sync::Arc;
    use tempfile::{TempDir, tempdir};

    /// Builds an `AppState` backed by a config file inside a fresh temporary
    /// directory, with two services (`fast` and `strong`) already created.
    ///
    /// The `TempDir` guard is returned alongside the state because dropping it
    /// deletes the directory that `config_path` points into. Callers must keep
    /// the guard alive for as long as they use the state.
    async fn test_state(admin_token: &str) -> (Arc<AppState>, TempDir) {
        let dir = tempdir().expect("tempdir");
        let config_path = dir.path().join("config.toml");
        let gateway = GatewayState::load(Path::new(&config_path))
            .await
            .expect("load state");
        gateway.create_service("fast", "Fast").await;
        gateway.create_service("strong", "Strong").await;

        let config = AppConfig {
            bind: "127.0.0.1:3210".to_string(),
            config_path: config_path.to_string_lossy().to_string(),
            admin_token: admin_token.to_string(),
            openai_base_url: String::new(),
            openai_api_key: String::new(),
            openai_model: String::new(),
            anthropic_base_url: String::new(),
            anthropic_api_key: String::new(),
            anthropic_model: String::new(),
        };

        (Arc::new(AppState::new(config, gateway)), dir)
    }

    /// Returns request headers carrying `token` in the `x-admin-token` header.
    fn admin_headers(token: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("x-admin-token", HeaderValue::from_static(token));
        headers
    }

    #[tokio::test]
    async fn api_list_modes_unauthorized_without_token_header() {
        // `_dir` keeps the temporary config directory alive for the whole test.
        let (state, _dir) = test_state("secret").await;
        let response = api_list_modes(
            State(state),
            HeaderMap::new(),
            Query(ListModesQuery { detailed: None }),
        )
        .await
        .into_response();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn api_list_modes_returns_detailed_data_when_requested() {
        // `_dir` keeps the temporary config directory alive for the whole test.
        let (state, _dir) = test_state("secret").await;

        let response = api_list_modes(
            State(state),
            admin_headers("secret"),
            Query(ListModesQuery {
                detailed: Some(true),
            }),
        )
        .await
        .into_response();

        assert_eq!(response.status(), StatusCode::OK);
        let body = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("read body");
        let json: Value = serde_json::from_slice(&body).expect("json body");
        assert_eq!(json.get("success").and_then(Value::as_bool), Some(true));
        let data = json
            .get("data")
            .and_then(Value::as_array)
            .expect("data array");
        assert_eq!(data.len(), 2);
        // Only the detailed view is expected to expose the `providers` field.
        assert!(data[0].get("providers").is_some());
    }

    #[tokio::test]
    async fn api_set_default_mode_returns_bad_request_for_unknown_mode() {
        // `_dir` keeps the temporary config directory alive for the whole test.
        let (state, _dir) = test_state("secret").await;

        let response = api_set_default_mode(
            State(state),
            admin_headers("secret"),
            Json(SetDefaultModeReq {
                mode_id: "missing".to_string(),
            }),
        )
        .await
        .into_response();

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }
}