use std::sync::Arc;
use axum::extract::{Path, Query, State};
use axum::http::StatusCode;
use axum::response::Json;
use axum::routing::get;
use axum::Router;
use serde::{Deserialize, Serialize};
use super::AppState;
use crate::variant::VariantQuery;
/// Builds the route table for the HTTP API.
///
/// The returned router is parameterised over `Arc<AppState>`; the caller is
/// expected to supply that state before serving. Every path is registered
/// through `get`.
pub fn router() -> Router<Arc<AppState>> {
    // Landing document describing the API.
    let routes = Router::new().route("/", get(index));

    // Variant lookups: local database queries plus the gnomAD pass-through.
    let routes = routes
        .route("/v1/query/{query}", get(query_variant))
        .route("/v1/variant/{chrom}/{pos}/{ref}/{alt}", get(query_by_coordinates))
        .route("/v1/gene/{gene}", get(query_by_gene))
        .route("/v1/gnomad/{query}", get(query_gnomad));

    // Database metadata and liveness probe.
    routes
        .route("/v1/sources", get(list_sources))
        .route("/v1/stats", get(stats))
        .route("/health", get(health))
}
/// JSON body returned by the `index` handler (`GET /`).
#[derive(Serialize)]
struct IndexResponse {
/// Human-readable API name.
name: &'static str,
/// Crate version, taken from `CARGO_PKG_VERSION` at compile time.
version: &'static str,
/// URL of the documentation site.
docs: &'static str,
/// One `"METHOD path"` string per advertised endpoint.
endpoints: Vec<&'static str>,
}
/// Handler for `GET /`: returns the API name, crate version, documentation
/// link and the list of advertised endpoints.
async fn index() -> Json<IndexResponse> {
    // Advertised endpoints, in the order they are presented to clients.
    const ENDPOINTS: [&str; 7] = [
        "GET /v1/query/{rsid|gene|hgvs}",
        "GET /v1/variant/{chrom}/{pos}/{ref}/{alt}",
        "GET /v1/gene/{gene}",
        "GET /v1/gnomad/{rsid}",
        "GET /v1/sources",
        "GET /v1/stats",
        "GET /health",
    ];

    let body = IndexResponse {
        name: "genome-sh API",
        version: env!("CARGO_PKG_VERSION"),
        docs: "https://genome.sh/docs",
        endpoints: ENDPOINTS.to_vec(),
    };
    Json(body)
}
/// Query-string parameters accepted by the list-returning handlers
/// (`query_variant` and `query_by_gene`).
#[derive(Deserialize)]
struct QueryParams {
/// Maximum number of results to return; falls back to `default_limit()`
/// when the parameter is absent from the query string.
#[serde(default = "default_limit")]
limit: usize,
}
/// Serde default for `QueryParams::limit`: the result cap applied when a
/// request does not specify `limit`.
fn default_limit() -> usize {
    const DEFAULT_LIMIT: usize = 100;
    DEFAULT_LIMIT
}
async fn query_variant(
State(state): State<Arc<AppState>>,
Path(query_str): Path<String>,
Query(params): Query<QueryParams>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResponse>)> {
let query = VariantQuery::parse(&query_str).map_err(|e| {
(
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: format!("Invalid query: {e}"),
}),
)
})?;
let results = state.db.query(&query).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Database error: {e}"),
}),
)
})?;
let limited: Vec<_> = results.into_iter().take(params.limit).collect();
let json = serde_json::to_value(&limited).unwrap_or_default();
Ok(Json(json))
}
async fn query_by_coordinates(
State(state): State<Arc<AppState>>,
Path((chrom, pos, reference, alt)): Path<(String, u64, String, String)>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResponse>)> {
let query = VariantQuery::Coordinates {
chrom: normalize_chrom(&chrom),
pos,
r#ref: reference,
alt,
};
let results = state.db.query(&query).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Database error: {e}"),
}),
)
})?;
let json = serde_json::to_value(&results).unwrap_or_default();
Ok(Json(json))
}
async fn query_by_gene(
State(state): State<Arc<AppState>>,
Path(gene): Path<String>,
Query(params): Query<QueryParams>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResponse>)> {
let query = VariantQuery::Gene(gene.to_uppercase());
let results = state.db.query(&query).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Database error: {e}"),
}),
)
})?;
let limited: Vec<_> = results.into_iter().take(params.limit).collect();
let json = serde_json::to_value(&limited).unwrap_or_default();
Ok(Json(json))
}
/// Handler for `GET /v1/gnomad/{query}`.
///
/// Forwards the raw path segment to `gnomad_client.query` and relays the JSON
/// value it returns. A client failure is reported as `502 Bad Gateway`.
async fn query_gnomad(
    State(state): State<Arc<AppState>>,
    Path(query_str): Path<String>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResponse>)> {
    match state.gnomad_client.query(&query_str).await {
        Ok(payload) => Ok(Json(payload)),
        Err(e) => {
            let body = ErrorResponse {
                error: format!("gnomAD API error: {e}"),
            };
            Err((StatusCode::BAD_GATEWAY, Json(body)))
        }
    }
}
/// One row of the `db_meta` table, as returned by `list_sources` and embedded
/// in `StatsResponse`.
#[derive(Serialize)]
struct SourceInfo {
/// `source` column; the rows are ordered by this value.
source: String,
/// `version` column; an empty string when the column is NULL.
version: String,
/// `variant_count` column — presumably the number of variants loaded from
/// this source (confirm against the loader that writes `db_meta`).
variant_count: i64,
/// `updated_at` column; an empty string when the column is NULL.
updated_at: String,
}
/// Handler for `GET /v1/sources`: lists every row of `db_meta`, ordered by
/// `source`.
///
/// NULL `version` / `updated_at` values become empty strings and a NULL
/// `variant_count` becomes `0`.
///
/// # Errors
///
/// `500 Internal Server Error` when the statement cannot be prepared or
/// executed, or when any row cannot be read. Unreadable rows used to be dropped
/// silently (and an unreadable `variant_count` reported as `0`), which returned
/// an incomplete or misleading list with no indication of the failure.
async fn list_sources(
    State(state): State<Arc<AppState>>,
) -> Result<Json<Vec<SourceInfo>>, (StatusCode, Json<ErrorResponse>)> {
    // Maps any displayable database error to the handler's 500 response.
    fn db_error<E: std::fmt::Display>(e: E) -> (StatusCode, Json<ErrorResponse>) {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: format!("Database error: {e}"),
            }),
        )
    }

    let conn = state.db.connection();
    let mut stmt = conn
        .prepare("SELECT source, version, variant_count, updated_at FROM db_meta ORDER BY source")
        .map_err(db_error)?;

    let rows = stmt
        .query_map([], |row| {
            Ok(SourceInfo {
                source: row.get(0)?,
                version: row.get::<_, Option<String>>(1)?.unwrap_or_default(),
                // NULL still maps to 0; any other read failure is propagated.
                variant_count: row.get::<_, Option<i64>>(2)?.unwrap_or(0),
                updated_at: row.get::<_, Option<String>>(3)?.unwrap_or_default(),
            })
        })
        .map_err(db_error)?;

    // Stop at the first unreadable row instead of skipping it.
    let sources = rows.collect::<Result<Vec<_>, _>>().map_err(db_error)?;
    Ok(Json(sources))
}
/// JSON body returned by the `stats` handler (`GET /v1/stats`).
#[derive(Serialize)]
struct StatsResponse {
/// Result of `SELECT COUNT(*) FROM variants`.
total_variants: i64,
/// Per-source metadata, as produced by `list_sources`.
sources: Vec<SourceInfo>,
}
async fn stats(
State(state): State<Arc<AppState>>,
) -> Result<Json<StatsResponse>, (StatusCode, Json<ErrorResponse>)> {
let total: i64 = {
let conn = state.db.connection();
conn.query_row("SELECT COUNT(*) FROM variants", [], |row| row.get(0))
.unwrap_or(0)
};
let sources_result = list_sources(State(state)).await?;
Ok(Json(StatsResponse {
total_variants: total,
sources: sources_result.0,
}))
}
/// Handler for `GET /health`: always answers `{"status": "ok"}`.
///
/// Takes no state and performs no checks, so it only shows that the server is
/// accepting requests.
async fn health() -> Json<serde_json::Value> {
    let body = serde_json::json!({ "status": "ok" });
    Json(body)
}
/// JSON error body shared by every fallible handler in this module; it is
/// always paired with a non-2xx `StatusCode`.
#[derive(Serialize)]
struct ErrorResponse {
/// Human-readable description of what went wrong.
error: String,
}
/// Returns `chrom` with a canonical lower-case `chr` prefix.
///
/// An existing prefix is recognised case-insensitively, so `"1"`, `"chr1"`,
/// `"Chr1"` and `"CHR1"` all become `"chr1"`. The previous case-sensitive check
/// turned `"CHR1"` into `"chrCHR1"`. Everything after the prefix is kept
/// verbatim, including its case.
fn normalize_chrom(chrom: &str) -> String {
    // `get(..3)` yields `None` for inputs shorter than three bytes or when
    // byte 3 is not a character boundary, so the slice below cannot panic.
    let rest = match chrom.get(..3) {
        Some(prefix) if prefix.eq_ignore_ascii_case("chr") => &chrom[3..],
        _ => chrom,
    };
    format!("chr{rest}")
}