quill-sql 0.3.1

An educational Rust relational database (RDBMS) inspired by CMU 15-445
Documentation
use axum::{
    extract::State,
    http::StatusCode,
    routing::{get, post},
    Json, Router,
};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::sync::Arc;

use quill_sql::database::{Database, DatabaseOptions};
use quill_sql::transaction::IsolationLevel;
use std::str::FromStr;

/// State handed to every HTTP handler.
///
/// Cloning copies the `Arc` handle to the one shared database, plus a clone of
/// `options`.
#[derive(Clone)]
struct AppState {
    /// Options used whenever the database has to be (re)opened.
    options: DatabaseOptions,
    /// The single database instance; all access is serialized by this mutex.
    db: Arc<std::sync::Mutex<Database>>,
}

/// Open a fresh `Database` configured with `opts`.
///
/// If the `QUILL_DB_FILE` environment variable is set, the database stored at
/// that path is opened on disk; otherwise a temporary database is created.
///
/// # Panics
/// Panics when the database cannot be opened.
fn rebuild_db(opts: &DatabaseOptions) -> Database {
    match std::env::var("QUILL_DB_FILE") {
        Ok(path) => {
            Database::new_on_disk_with_options(&path, opts.clone()).expect("open db file")
        }
        Err(_) => Database::new_temp_with_options(opts.clone()).expect("open temp db"),
    }
}

/// Lock the shared database, recovering from a poisoned mutex.
///
/// If an earlier holder of the lock panicked, the mutex is poisoned and the
/// `Database` behind it may be half-modified. In that case the instance is
/// replaced with a fresh one from `rebuild_db(&state.options)`. The poison
/// flag is then cleared, so later calls lock normally instead of rebuilding
/// (and discarding all data) on every request.
///
/// # Errors
/// Never returns `Err` at present. The `Result` lets handlers propagate with `?`.
///
/// # Panics
/// Panics if `rebuild_db` cannot open the replacement database.
fn lock_or_rebuild_db(
    state: &AppState,
) -> Result<std::sync::MutexGuard<'_, Database>, (StatusCode, String)> {
    let guard = state.db.lock().unwrap_or_else(|poisoned| {
        let mut guard = poisoned.into_inner();
        *guard = rebuild_db(&state.options);
        // `into_inner` does not un-poison the mutex; without this every later
        // `lock()` would fail again and trigger another rebuild.
        state.db.clear_poison();
        guard
    });
    Ok(guard)
}

/// `POST /admin/rebuild`: replace the shared database with a freshly opened one.
///
/// The replacement is opened before the lock is taken, so other requests are
/// not blocked while it opens. A poisoned mutex is recovered rather than
/// reported. The guarded instance is overwritten, so the state is consistent
/// again and the poison flag is cleared.
///
/// NOTE(review): with `QUILL_DB_FILE` set, the replacement opens the same file
/// while the old instance is still alive. Confirm the storage layer tolerates this.
///
/// # Panics
/// Panics if `rebuild_db` cannot open the replacement database.
async fn rebuild(
    State(state): State<AppState>,
) -> Result<Json<&'static str>, (StatusCode, String)> {
    let new_db = rebuild_db(&state.options);
    let mut db_guard = state
        .db
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    *db_guard = new_db;
    state.db.clear_poison();
    Ok(Json("rebuilt"))
}

/// JSON body accepted by both `/api/sql` and `/api/sql_batch`.
#[derive(Deserialize)]
struct SqlRequest {
    /// Raw SQL text submitted by the client.
    sql: String,
}

/// JSON body returned by `/api/sql`.
#[derive(Serialize)]
struct SqlResponse {
    /// Result rows; every value is pre-rendered as a string so the frontend
    /// can display it directly.
    rows: Vec<Vec<String>>,
}

/// JSON body returned by `/api/sql_batch`.
#[derive(Serialize)]
struct SqlBatchResponse {
    /// One entry per executed statement, in execution order. Each entry holds
    /// that statement's rows with every value rendered as a string.
    results: Vec<Vec<Vec<String>>>,
}

/// Remove `--` single-line SQL comments from `input`.
///
/// `--` starts a comment only outside single-quoted string literals and
/// double-quoted identifiers. Quote state is carried across lines, so text
/// inside a multi-line literal is never treated as a comment.
///
/// - A line that holds nothing but a comment (ignoring surrounding whitespace)
///   is dropped entirely.
/// - A trailing comment after code is cut off, together with the whitespace
///   before it.
/// - Every kept line is terminated with `\n` (`\r\n` input becomes `\n`).
///
/// A doubled quote (`''`) inside a literal works naturally: it closes and
/// immediately reopens the literal. Block comments (`/* ... */`) and
/// backslash escapes are not interpreted.
/// NOTE(review): confirm the SQL dialect does not rely on backslash escapes.
fn strip_sql_comments(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    // Quote character of the literal/identifier currently open, if any.
    let mut open_quote: Option<char> = None;
    for line in input.lines() {
        let mut comment_start: Option<usize> = None;
        let mut chars = line.char_indices().peekable();
        while let Some((idx, ch)) = chars.next() {
            match open_quote {
                Some(quote) => {
                    if ch == quote {
                        open_quote = None;
                    }
                }
                None => match ch {
                    '\'' | '"' => open_quote = Some(ch),
                    '-' if matches!(chars.peek(), Some(&(_, '-'))) => {
                        comment_start = Some(idx);
                        break;
                    }
                    _ => {}
                },
            }
        }
        let kept = match comment_start {
            Some(idx) => line[..idx].trim_end(),
            None => line,
        };
        // Comment-only lines vanish completely, as they always have.
        if comment_start.is_some() && kept.is_empty() {
            continue;
        }
        out.push_str(kept);
        out.push('\n');
    }
    out
}

async fn debug_locks_snapshot(
    State(state): State<AppState>,
) -> Result<Json<quill_sql::transaction::LockDebugSnapshot>, (StatusCode, String)> {
    let db_guard = state
        .db
        .lock()
        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB poisoned".to_string()))?;
    Ok(Json(db_guard.debug_lock_snapshot()))
}

async fn debug_trace_last(
    State(state): State<AppState>,
) -> Result<Json<quill_sql::database::DebugTrace>, (StatusCode, String)> {
    let db_guard = lock_or_rebuild_db(&state)?;
    match db_guard.debug_last_trace() {
        Some(trace) => Ok(Json(trace)),
        None => Err((StatusCode::NOT_FOUND, "no query executed yet".to_string())),
    }
}

/// `GET /debug/plan/last`: return the plan snapshot of the most recent query.
///
/// # Errors
/// `404 Not Found` when the database has no plan to report.
async fn debug_plan_last(
    State(state): State<AppState>,
) -> Result<Json<quill_sql::database::DebugPlanSnapshot>, (StatusCode, String)> {
    let guard = lock_or_rebuild_db(&state)?;
    let Some(snapshot) = guard.debug_last_plan() else {
        return Err((StatusCode::NOT_FOUND, "no query executed yet".to_string()));
    };
    Ok(Json(snapshot))
}

async fn debug_mvcc_versions(
    State(state): State<AppState>,
) -> Result<Json<quill_sql::database::MvccVersionsDebug>, (StatusCode, String)> {
    let db = lock_or_rebuild_db(&state)?;
    db.debug_mvcc_versions()
        .map(Json)
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", e)))
}

/// `GET /debug/txns`: return the database's transaction-debug snapshot.
async fn debug_txns(
    State(state): State<AppState>,
) -> Result<Json<quill_sql::transaction::TxnDebugSnapshot>, (StatusCode, String)> {
    let snapshot = lock_or_rebuild_db(&state)?.debug_txn_snapshot();
    Ok(Json(snapshot))
}

#[tokio::main]
async fn main() {
    env_logger::init();

    // Build database (in-memory Holt temp by default).
    let default_isolation_level = std::env::var("QUILL_DEFAULT_ISOLATION")
        .ok()
        .as_deref()
        .map(IsolationLevel::from_str)
        .transpose()
        .unwrap_or_else(|e| panic!("invalid QUILL_DEFAULT_ISOLATION: {}", e));
    let db_options = DatabaseOptions {
        default_isolation_level,
        ..DatabaseOptions::default()
    };

    let db = if let Ok(path) = std::env::var("QUILL_DB_FILE") {
        Database::new_on_disk_with_options(&path, db_options.clone()).expect("open db file")
    } else {
        Database::new_temp_with_options(db_options.clone()).expect("open temp db")
    };

    let state = AppState {
        db: Arc::new(std::sync::Mutex::new(db)),
        options: db_options,
    };

    // Static services
    let static_service =
        tower_http::services::ServeDir::new("public").append_index_html_on_directories(true);
    let docs_service = tower_http::services::ServeDir::new("docs");

    let app = Router::new()
        .route("/api/sql", post(api_sql))
        .route("/api/sql_batch", post(api_sql_batch))
        .route("/admin/rebuild", post(rebuild))
        .route("/debug/locks/snapshot", get(debug_locks_snapshot))
        .route("/debug/txns", get(debug_txns))
        .route("/debug/trace/last", get(debug_trace_last))
        .route("/debug/mvcc/versions", get(debug_mvcc_versions))
        .route("/debug/plan/last", get(debug_plan_last))
        .nest_service("/docs", docs_service)
        .fallback_service(static_service)
        .with_state(state);

    // CORS for simple local testing
    let app = app.layer(tower_http::cors::CorsLayer::very_permissive());

    // Bind address: prefer PORT for platforms like Vercel/Heroku
    let bind_addr = if let Ok(port) = std::env::var("PORT") {
        format!("0.0.0.0:{}", port)
    } else {
        std::env::var("QUILL_HTTP_ADDR").unwrap_or_else(|_| "0.0.0.0:8080".to_string())
    };
    let addr: SocketAddr = bind_addr.parse().expect("invalid bind addr");
    println!("Serving on http://{}", addr);
    axum::serve(
        tokio::net::TcpListener::bind(addr)
            .await
            .expect("bind http"),
        app,
    )
    .await
    .expect("server error");
}

/// `POST /api/sql`: run the submitted SQL in one `Database::run` call and
/// return the resulting rows.
///
/// The text is first passed through `strip_sql_comments`. Each value in every
/// result tuple is rendered via its `Display` implementation.
///
/// # Errors
/// `400 Bad Request` carrying the database error text when execution fails.
async fn api_sql(
    State(state): State<AppState>,
    Json(req): Json<SqlRequest>,
) -> Result<Json<SqlResponse>, (StatusCode, String)> {
    let mut guard = lock_or_rebuild_db(&state)?;
    let sql = strip_sql_comments(&req.sql);
    let outcome = guard.run(&sql);
    let tuples = match outcome {
        Ok(tuples) => tuples,
        Err(err) => return Err((StatusCode::BAD_REQUEST, err.to_string())),
    };
    let mut rows = Vec::new();
    for tuple in tuples {
        let row: Vec<String> = tuple.data.into_iter().map(|value| value.to_string()).collect();
        rows.push(row);
    }
    Ok(Json(SqlResponse { rows }))
}

/// `POST /api/sql_batch`: run a `;`-separated SQL script statement by
/// statement and return one result set per statement.
///
/// The script is passed through `strip_sql_comments` and then split by
/// `split_sql_statements`. That splitter ignores `;` inside quoted text and
/// `--` comments, so literals such as `'a;b'` stay intact.
///
/// # Errors
/// `400 Bad Request` in two cases:
/// - The script contains more than `MAX_BATCH_STATEMENTS` statements. Nothing
///   is executed in that case (previously the excess was silently dropped).
/// - A statement fails. The message starts with the failing statement's
///   1-based index, and statements before it have already been run.
async fn api_sql_batch(
    State(state): State<AppState>,
    Json(req): Json<SqlRequest>,
) -> Result<Json<SqlBatchResponse>, (StatusCode, String)> {
    // Upper bound per request, so one batch cannot monopolize the database mutex.
    const MAX_BATCH_STATEMENTS: usize = 100;

    let cleaned = strip_sql_comments(&req.sql);
    let statements = split_sql_statements(&cleaned);
    if statements.len() > MAX_BATCH_STATEMENTS {
        return Err((
            StatusCode::BAD_REQUEST,
            format!(
                "too many statements in batch: {} (limit is {})",
                statements.len(),
                MAX_BATCH_STATEMENTS
            ),
        ));
    }

    let mut db = lock_or_rebuild_db(&state)?;
    let mut results: Vec<Vec<Vec<String>>> = Vec::with_capacity(statements.len());
    for (idx, stmt) in statements.iter().enumerate() {
        let tuples = db
            .run(stmt.as_str())
            .map_err(|e| (StatusCode::BAD_REQUEST, format!("statement {}: {}", idx + 1, e)))?;
        let rows: Vec<Vec<String>> = tuples
            .into_iter()
            .map(|t| t.data.into_iter().map(|v| v.to_string()).collect())
            .collect();
        results.push(rows);
    }
    Ok(Json(SqlBatchResponse { results }))
}

/// Split `script` into trimmed, non-empty statements at `;` terminators.
///
/// A `;` ends a statement only when it is outside single-quoted literals and
/// double-quoted identifiers. `--` comments outside quotes are skipped up to
/// the end of their line, so a `;` inside a comment cannot split a statement.
/// A doubled quote (`''`) works because it closes and reopens the literal.
fn split_sql_statements(script: &str) -> Vec<String> {
    let mut statements = Vec::new();
    let mut current = String::new();
    let mut open_quote: Option<char> = None;
    let mut chars = script.chars().peekable();
    while let Some(ch) = chars.next() {
        if let Some(quote) = open_quote {
            if ch == quote {
                open_quote = None;
            }
            current.push(ch);
            continue;
        }
        match ch {
            '\'' | '"' => {
                open_quote = Some(ch);
                current.push(ch);
            }
            '-' if chars.peek() == Some(&'-') => {
                // Drop the comment text; the newline that ends it is kept as a separator.
                while matches!(chars.peek(), Some(&next) if next != '\n') {
                    chars.next();
                }
            }
            ';' => {
                let stmt = current.trim();
                if !stmt.is_empty() {
                    statements.push(stmt.to_string());
                }
                current.clear();
            }
            _ => current.push(ch),
        }
    }
    let tail = current.trim();
    if !tail.is_empty() {
        statements.push(tail.to_string());
    }
    statements
}