use axum::{
extract::State,
http::StatusCode,
routing::{get, post},
Json, Router,
};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::sync::Arc;
use quill_sql::database::{Database, DatabaseOptions};
use quill_sql::transaction::IsolationLevel;
use std::str::FromStr;
/// Shared state handed to every axum handler via `State`.
#[derive(Clone)]
struct AppState {
    /// The single database instance; handlers hold this lock while they work with it.
    db: Arc<std::sync::Mutex<Database>>,
    /// Options used both for the initial database and for every later rebuild.
    options: DatabaseOptions,
}
/// Opens a new [`Database`] configured with a clone of `opts`.
///
/// If the `QUILL_DB_FILE` environment variable is set (to valid Unicode), the database is
/// opened on disk at that path; otherwise a temporary database is created.
///
/// # Panics
/// Panics if the database cannot be opened.
fn rebuild_db(opts: &DatabaseOptions) -> Database {
    match std::env::var("QUILL_DB_FILE") {
        Ok(path) => {
            Database::new_on_disk_with_options(&path, opts.clone()).expect("open db file")
        }
        Err(_) => Database::new_temp_with_options(opts.clone()).expect("open temp db"),
    }
}
/// Locks the shared [`Database`], recovering from a poisoned mutex.
///
/// If a previous lock holder panicked, the database contents may be inconsistent. In that
/// case the instance is replaced by one built with `rebuild_db(&state.options)`, and the
/// mutex's poison flag is cleared.
///
/// Poisoning is sticky: without `clear_poison`, every later call would see the mutex as
/// poisoned again and rebuild (discarding all data) on every request.
///
/// Currently this never returns `Err`; the `Result` is kept so handlers can use `?`.
///
/// # Panics
/// Panics if `rebuild_db` panics. Because the lock is held at that point, the mutex is
/// poisoned again.
fn lock_or_rebuild_db(
    state: &AppState,
) -> Result<std::sync::MutexGuard<'_, Database>, (StatusCode, String)> {
    match state.db.lock() {
        Ok(guard) => Ok(guard),
        Err(poisoned) => {
            let mut guard = poisoned.into_inner();
            *guard = rebuild_db(&state.options);
            // The tainted state has been replaced, so later lockers must not see the
            // mutex as poisoned (which would trigger yet another rebuild).
            state.db.clear_poison();
            eprintln!("warning: database mutex was poisoned; database rebuilt");
            Ok(guard)
        }
    }
}
/// `POST /admin/rebuild`: replaces the shared database with one built by `rebuild_db`.
///
/// The new instance is created before the lock is taken, so the mutex is held only for
/// the swap.
///
/// A poisoned mutex is recovered rather than reported. This endpoint is the manual way to
/// reset after a panic, so refusing to run while the lock is poisoned would defeat its
/// purpose. The poison flag is cleared once the old state has been replaced.
///
/// # Panics
/// Panics if `rebuild_db` panics. The lock is not yet held then, so no poisoning occurs.
async fn rebuild(
    State(state): State<AppState>,
) -> Result<Json<&'static str>, (StatusCode, String)> {
    let new_db = rebuild_db(&state.options);
    let mut db_guard = state
        .db
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    *db_guard = new_db;
    // The possibly-inconsistent database is gone; un-poison so other handlers do not
    // treat the fresh instance as tainted.
    state.db.clear_poison();
    Ok(Json("rebuilt"))
}
/// JSON request body accepted by `/api/sql` and `/api/sql_batch`.
#[derive(Deserialize)]
struct SqlRequest {
    /// Raw SQL text as submitted by the client.
    sql: String,
}
/// JSON response body returned by `/api/sql`.
#[derive(Serialize)]
struct SqlResponse {
    /// One inner vector per result tuple, holding that tuple's values as strings.
    rows: Vec<Vec<String>>,
}
/// JSON response body returned by `/api/sql_batch`.
#[derive(Serialize)]
struct SqlBatchResponse {
    /// One result set per executed statement, in execution order; each result set has
    /// the same row layout as `SqlResponse::rows`.
    results: Vec<Vec<Vec<String>>>,
}
/// Removes every line whose first non-whitespace characters are `--`.
///
/// Only whole-line comments are dropped; a trailing `-- ...` after code on the same line
/// is kept. Lines are split with `str::lines` (so `\r\n` endings are normalised), and
/// every kept line is re-emitted followed by a single `\n`.
fn strip_sql_comments(input: &str) -> String {
    input
        .lines()
        .filter(|line| !line.trim_start().starts_with("--"))
        .fold(String::with_capacity(input.len()), |mut acc, line| {
            acc.push_str(line);
            acc.push('\n');
            acc
        })
}
async fn debug_locks_snapshot(
State(state): State<AppState>,
) -> Result<Json<quill_sql::transaction::LockDebugSnapshot>, (StatusCode, String)> {
let db_guard = state
.db
.lock()
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB poisoned".to_string()))?;
Ok(Json(db_guard.debug_lock_snapshot()))
}
async fn debug_trace_last(
State(state): State<AppState>,
) -> Result<Json<quill_sql::database::DebugTrace>, (StatusCode, String)> {
let db_guard = lock_or_rebuild_db(&state)?;
match db_guard.debug_last_trace() {
Some(trace) => Ok(Json(trace)),
None => Err((StatusCode::NOT_FOUND, "no query executed yet".to_string())),
}
}
async fn debug_plan_last(
State(state): State<AppState>,
) -> Result<Json<quill_sql::database::DebugPlanSnapshot>, (StatusCode, String)> {
let db_guard = lock_or_rebuild_db(&state)?;
match db_guard.debug_last_plan() {
Some(plan) => Ok(Json(plan)),
None => Err((StatusCode::NOT_FOUND, "no query executed yet".to_string())),
}
}
async fn debug_mvcc_versions(
State(state): State<AppState>,
) -> Result<Json<quill_sql::database::MvccVersionsDebug>, (StatusCode, String)> {
let db = lock_or_rebuild_db(&state)?;
db.debug_mvcc_versions()
.map(Json)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", e)))
}
/// `GET /debug/txns`: returns `Database::debug_txn_snapshot()` as JSON.
async fn debug_txns(
    State(state): State<AppState>,
) -> Result<Json<quill_sql::transaction::TxnDebugSnapshot>, (StatusCode, String)> {
    lock_or_rebuild_db(&state).map(|guard| Json(guard.debug_txn_snapshot()))
}
#[tokio::main]
async fn main() {
env_logger::init();
let default_isolation_level = std::env::var("QUILL_DEFAULT_ISOLATION")
.ok()
.as_deref()
.map(IsolationLevel::from_str)
.transpose()
.unwrap_or_else(|e| panic!("invalid QUILL_DEFAULT_ISOLATION: {}", e));
let db_options = DatabaseOptions {
default_isolation_level,
..DatabaseOptions::default()
};
let db = if let Ok(path) = std::env::var("QUILL_DB_FILE") {
Database::new_on_disk_with_options(&path, db_options.clone()).expect("open db file")
} else {
Database::new_temp_with_options(db_options.clone()).expect("open temp db")
};
let state = AppState {
db: Arc::new(std::sync::Mutex::new(db)),
options: db_options,
};
let static_service =
tower_http::services::ServeDir::new("public").append_index_html_on_directories(true);
let docs_service = tower_http::services::ServeDir::new("docs");
let app = Router::new()
.route("/api/sql", post(api_sql))
.route("/api/sql_batch", post(api_sql_batch))
.route("/admin/rebuild", post(rebuild))
.route("/debug/locks/snapshot", get(debug_locks_snapshot))
.route("/debug/txns", get(debug_txns))
.route("/debug/trace/last", get(debug_trace_last))
.route("/debug/mvcc/versions", get(debug_mvcc_versions))
.route("/debug/plan/last", get(debug_plan_last))
.nest_service("/docs", docs_service)
.fallback_service(static_service)
.with_state(state);
let app = app.layer(tower_http::cors::CorsLayer::very_permissive());
let bind_addr = if let Ok(port) = std::env::var("PORT") {
format!("0.0.0.0:{}", port)
} else {
std::env::var("QUILL_HTTP_ADDR").unwrap_or_else(|_| "0.0.0.0:8080".to_string())
};
let addr: SocketAddr = bind_addr.parse().expect("invalid bind addr");
println!("Serving on http://{}", addr);
axum::serve(
tokio::net::TcpListener::bind(addr)
.await
.expect("bind http"),
app,
)
.await
.expect("server error");
}
/// `POST /api/sql`: runs the submitted SQL and returns every result tuple as strings.
///
/// Whole-line `--` comments are removed with `strip_sql_comments` before the text is
/// passed to `Database::run`. Execution errors are reported as `400 Bad Request`, with the
/// error's `Display` text as the body.
async fn api_sql(
    State(state): State<AppState>,
    Json(req): Json<SqlRequest>,
) -> Result<Json<SqlResponse>, (StatusCode, String)> {
    let mut guard = lock_or_rebuild_db(&state)?;
    let sql = strip_sql_comments(&req.sql);
    let tuples = match guard.run(&sql) {
        Ok(tuples) => tuples,
        Err(err) => return Err((StatusCode::BAD_REQUEST, format!("{err}"))),
    };
    let rows = tuples
        .into_iter()
        .map(|tuple| {
            tuple
                .data
                .into_iter()
                .map(|value| format!("{value}"))
                .collect()
        })
        .collect();
    Ok(Json(SqlResponse { rows }))
}
/// Maximum number of statements accepted by a single `/api/sql_batch` request.
const MAX_BATCH_STATEMENTS: usize = 100;

/// Splits `sql` into statements on top-level `;` separators.
///
/// A `;` does not end a statement when it appears inside:
/// - a single-quoted string (including `''` escapes),
/// - a double-quoted identifier,
/// - a `--` line comment.
///
/// Segments containing only whitespace and comments are dropped. The remaining segments
/// are returned trimmed, in source order. An unterminated quote extends to the end of
/// the input; the resulting statement is left for the parser to reject.
fn split_sql_statements(sql: &str) -> Vec<&str> {
    let mut statements = Vec::new();
    let mut start = 0;
    let mut has_code = false;
    let mut quote: Option<char> = None;
    let mut in_line_comment = false;
    let mut chars = sql.char_indices().peekable();
    while let Some((i, c)) = chars.next() {
        if in_line_comment {
            // A line comment runs up to (and including) the next newline.
            in_line_comment = c != '\n';
            continue;
        }
        if let Some(q) = quote {
            if c == q {
                quote = None;
            }
            continue;
        }
        match c {
            '-' if matches!(chars.peek(), Some(&(_, '-'))) => in_line_comment = true,
            ';' => {
                if has_code {
                    statements.push(sql[start..i].trim());
                }
                start = i + 1;
                has_code = false;
            }
            '\'' | '"' => {
                quote = Some(c);
                has_code = true;
            }
            other if !other.is_whitespace() => has_code = true,
            _ => {}
        }
    }
    if has_code {
        statements.push(sql[start..].trim());
    }
    statements
}

/// `POST /api/sql_batch`: runs a `;`-separated batch and returns one result set per
/// statement.
///
/// Whole-line `--` comments are removed with `strip_sql_comments`. The text is then split
/// with `split_sql_statements`, so a `;` inside a string literal or a trailing comment
/// does not break a statement apart.
///
/// # Errors
/// Returns `400 Bad Request` in two cases:
/// - The batch contains more than `MAX_BATCH_STATEMENTS` statements. Nothing is executed
///   in that case; previously the excess statements were silently ignored.
/// - A statement fails. The message names its 1-based position. Statements before it have
///   already been run, and this handler does not roll them back.
async fn api_sql_batch(
    State(state): State<AppState>,
    Json(req): Json<SqlRequest>,
) -> Result<Json<SqlBatchResponse>, (StatusCode, String)> {
    let cleaned = strip_sql_comments(&req.sql);
    let statements = split_sql_statements(&cleaned);
    let count = statements.len();
    if count > MAX_BATCH_STATEMENTS {
        return Err((
            StatusCode::BAD_REQUEST,
            format!(
                "batch contains {} statements; at most {} are allowed",
                count, MAX_BATCH_STATEMENTS
            ),
        ));
    }

    // Parse and validate before taking the lock so the database is held only while
    // statements actually execute.
    let mut db = lock_or_rebuild_db(&state)?;
    let mut results: Vec<Vec<Vec<String>>> = Vec::with_capacity(count);
    for (idx, stmt) in statements.into_iter().enumerate() {
        let tuples = db
            .run(stmt)
            .map_err(|e| (StatusCode::BAD_REQUEST, format!("statement {}: {}", idx + 1, e)))?;
        let rows: Vec<Vec<String>> = tuples
            .into_iter()
            .map(|t| t.data.into_iter().map(|v| format!("{}", v)).collect())
            .collect();
        results.push(rows);
    }
    Ok(Json(SqlBatchResponse { results }))
}