genome-sh 0.1.0

The jq of genomics. Fast, local, human-readable variant analysis.
use std::sync::Arc;

use axum::extract::{Path, Query, State};
use axum::http::StatusCode;
use axum::response::Json;
use axum::routing::get;
use axum::Router;
use serde::{Deserialize, Serialize};

use super::AppState;
use crate::variant::VariantQuery;

/// Assembles the route table for the HTTP API.
///
/// Routes are registered in three stages (landing page, versioned `/v1`
/// endpoints, liveness probe); the returned router still needs an
/// `Arc<AppState>` supplied as state by the caller.
pub fn router() -> Router<Arc<AppState>> {
    // Landing page describing the API.
    let landing: Router<Arc<AppState>> = Router::new().route("/", get(index));

    // Variant lookups: free-form query, explicit coordinates, gene, gnomAD.
    let with_lookups = landing
        .route("/v1/query/{query}", get(query_variant))
        .route("/v1/variant/{chrom}/{pos}/{ref}/{alt}", get(query_by_coordinates))
        .route("/v1/gene/{gene}", get(query_by_gene))
        .route("/v1/gnomad/{query}", get(query_gnomad));

    // Database metadata endpoints.
    let with_metadata = with_lookups
        .route("/v1/sources", get(list_sources))
        .route("/v1/stats", get(stats));

    // Liveness probe.
    with_metadata.route("/health", get(health))
}

/// JSON body served by the `index` handler at the API root.
#[derive(Serialize)]
struct IndexResponse {
    /// Human-readable service name.
    name: &'static str,
    /// Crate version string.
    version: &'static str,
    /// URL of the documentation site.
    docs: &'static str,
    /// Advertised endpoints, one display string per route.
    endpoints: Vec<&'static str>,
}

/// Handler for `GET /`: returns a static description of the API.
///
/// The version is taken from `CARGO_PKG_VERSION` at compile time; the
/// endpoint list is a fixed set of display strings.
async fn index() -> Json<IndexResponse> {
    let endpoints: Vec<&'static str> = [
        "GET /v1/query/{rsid|gene|hgvs}",
        "GET /v1/variant/{chrom}/{pos}/{ref}/{alt}",
        "GET /v1/gene/{gene}",
        "GET /v1/gnomad/{rsid}",
        "GET /v1/sources",
        "GET /v1/stats",
        "GET /health",
    ]
    .to_vec();

    let body = IndexResponse {
        name: "genome-sh API",
        version: env!("CARGO_PKG_VERSION"),
        docs: "https://genome.sh/docs",
        endpoints,
    };

    Json(body)
}

/// Query-string parameters accepted by the `query_variant` and
/// `query_by_gene` handlers.
#[derive(Deserialize)]
struct QueryParams {
    /// Maximum number of results.
    ///
    /// Falls back to `default_limit()` when the parameter is absent.
    #[serde(default = "default_limit")]
    limit: usize,
}

/// Serde default for `QueryParams::limit`: the result cap applied when the
/// request does not specify one.
fn default_limit() -> usize {
    const DEFAULT_RESULT_LIMIT: usize = 100;
    DEFAULT_RESULT_LIMIT
}

/// Handler for `GET /v1/query/{query}`.
///
/// Parses the path segment with `VariantQuery::parse`, runs it against the
/// database, and returns at most `limit` results as a JSON array.
///
/// # Errors
///
/// * `400 Bad Request` when the query string cannot be parsed.
/// * `500 Internal Server Error` when the database query fails or the
///   results cannot be serialized.
async fn query_variant(
    State(state): State<Arc<AppState>>,
    Path(query_str): Path<String>,
    Query(params): Query<QueryParams>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResponse>)> {
    let query = VariantQuery::parse(&query_str).map_err(|e| {
        (
            StatusCode::BAD_REQUEST,
            Json(ErrorResponse {
                error: format!("Invalid query: {e}"),
            }),
        )
    })?;

    let results = state.db.query(&query).map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: format!("Database error: {e}"),
            }),
        )
    })?;

    let limited: Vec<_> = results.into_iter().take(params.limit).collect();

    // Surface serialization failures instead of answering `200 null`, which
    // a client cannot distinguish from a legitimate empty payload.
    let json = serde_json::to_value(&limited).map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: format!("Serialization error: {e}"),
            }),
        )
    })?;

    Ok(Json(json))
}

/// Handler for `GET /v1/variant/{chrom}/{pos}/{ref}/{alt}`.
///
/// Builds a `VariantQuery::Coordinates` from the path segments (the
/// chromosome is passed through `normalize_chrom` first) and returns every
/// matching record as JSON. No result limit is applied here.
///
/// # Errors
///
/// `500 Internal Server Error` when the database query fails or the results
/// cannot be serialized.
async fn query_by_coordinates(
    State(state): State<Arc<AppState>>,
    Path((chrom, pos, reference, alt)): Path<(String, u64, String, String)>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResponse>)> {
    let query = VariantQuery::Coordinates {
        chrom: normalize_chrom(&chrom),
        pos,
        r#ref: reference,
        alt,
    };

    let results = state.db.query(&query).map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: format!("Database error: {e}"),
            }),
        )
    })?;

    // Surface serialization failures instead of answering `200 null`.
    let json = serde_json::to_value(&results).map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: format!("Serialization error: {e}"),
            }),
        )
    })?;

    Ok(Json(json))
}

/// Handler for `GET /v1/gene/{gene}`.
///
/// Upper-cases the gene symbol from the path, queries the database with
/// `VariantQuery::Gene`, and returns at most `limit` results as a JSON array.
///
/// # Errors
///
/// `500 Internal Server Error` when the database query fails or the results
/// cannot be serialized.
async fn query_by_gene(
    State(state): State<Arc<AppState>>,
    Path(gene): Path<String>,
    Query(params): Query<QueryParams>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResponse>)> {
    let query = VariantQuery::Gene(gene.to_uppercase());

    let results = state.db.query(&query).map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: format!("Database error: {e}"),
            }),
        )
    })?;

    let limited: Vec<_> = results.into_iter().take(params.limit).collect();

    // Surface serialization failures instead of answering `200 null`.
    let json = serde_json::to_value(&limited).map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: format!("Serialization error: {e}"),
            }),
        )
    })?;

    Ok(Json(json))
}

/// Handler for `GET /v1/gnomad/{query}`.
///
/// Forwards the raw path segment to `state.gnomad_client` and relays the
/// JSON it returns. Any caching is the client's concern and is not visible
/// from this handler.
///
/// # Errors
///
/// `502 Bad Gateway` when the upstream gnomAD call fails.
async fn query_gnomad(
    State(state): State<Arc<AppState>>,
    Path(query_str): Path<String>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResponse>)> {
    match state.gnomad_client.query(&query_str).await {
        Ok(payload) => Ok(Json(payload)),
        Err(e) => {
            let body = ErrorResponse {
                error: format!("gnomAD API error: {e}"),
            };
            Err((StatusCode::BAD_GATEWAY, Json(body)))
        }
    }
}

/// One row of the `db_meta` table, as exposed by the `/v1/sources` and
/// `/v1/stats` endpoints.
#[derive(Serialize)]
struct SourceInfo {
    /// Value of the `source` column.
    source: String,
    /// Value of the `version` column.
    version: String,
    /// Value of the `variant_count` column.
    variant_count: i64,
    /// Value of the `updated_at` column, kept as the stored string.
    updated_at: String,
}

/// Handler for `GET /v1/sources`: lists every row of `db_meta`, ordered by
/// source name.
///
/// NULL `version`, `variant_count`, and `updated_at` columns are mapped to
/// their empty defaults (`""` / `0`).
///
/// # Errors
///
/// `500 Internal Server Error` when the statement cannot be prepared or
/// executed, or when any row fails to decode. Rows are never silently
/// dropped: a partial listing would be indistinguishable from a complete one.
async fn list_sources(
    State(state): State<Arc<AppState>>,
) -> Result<Json<Vec<SourceInfo>>, (StatusCode, Json<ErrorResponse>)> {
    /// Wraps a database failure in the handler's 500 error shape.
    fn db_error(e: impl std::fmt::Display) -> (StatusCode, Json<ErrorResponse>) {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: format!("Database error: {e}"),
            }),
        )
    }

    let conn = state.db.connection();
    let mut stmt = conn
        .prepare("SELECT source, version, variant_count, updated_at FROM db_meta ORDER BY source")
        .map_err(db_error)?;

    let sources = stmt
        .query_map([], |row| {
            Ok(SourceInfo {
                source: row.get(0)?,
                version: row.get::<_, Option<String>>(1)?.unwrap_or_default(),
                // NULL counts become 0; genuine column errors still propagate.
                variant_count: row.get::<_, Option<i64>>(2)?.unwrap_or(0),
                updated_at: row.get::<_, Option<String>>(3)?.unwrap_or_default(),
            })
        })
        .map_err(db_error)?
        // Short-circuit on the first row that fails to decode.
        .collect::<Result<Vec<_>, _>>()
        .map_err(db_error)?;

    Ok(Json(sources))
}

/// JSON body served by the `stats` handler.
#[derive(Serialize)]
struct StatsResponse {
    /// Row count of the `variants` table.
    total_variants: i64,
    /// Per-source metadata, as produced by `list_sources`.
    sources: Vec<SourceInfo>,
}

async fn stats(
    State(state): State<Arc<AppState>>,
) -> Result<Json<StatsResponse>, (StatusCode, Json<ErrorResponse>)> {
    let total: i64 = {
        let conn = state.db.connection();
        conn.query_row("SELECT COUNT(*) FROM variants", [], |row| row.get(0))
            .unwrap_or(0)
    };

    let sources_result = list_sources(State(state)).await?;

    Ok(Json(StatsResponse {
        total_variants: total,
        sources: sources_result.0,
    }))
}

/// Handler for `GET /health`: always answers with a fixed
/// `{"status": "ok"}` body and touches no application state.
async fn health() -> Json<serde_json::Value> {
    let body = serde_json::json!({ "status": "ok" });
    Json(body)
}

/// JSON error body returned alongside a non-success status code by the
/// handlers in this module.
#[derive(Serialize)]
struct ErrorResponse {
    /// Human-readable description of what went wrong.
    error: String,
}

/// Returns `chrom` with a canonical lowercase `chr` prefix.
///
/// An existing prefix is recognised case-insensitively, so `"1"`, `"chr1"`,
/// `"CHR1"`, and `"Chr1"` all become `"chr1"`. The remainder of the name is
/// left untouched.
fn normalize_chrom(chrom: &str) -> String {
    // `get(..3)` yields `None` for short input or a non-char-boundary split,
    // so slicing at 3 below is only reached when it is safe.
    let rest = match chrom.get(..3) {
        Some(prefix) if prefix.eq_ignore_ascii_case("chr") => &chrom[3..],
        _ => chrom,
    };
    format!("chr{rest}")
}