use axum::{
extract::State,
http::StatusCode,
routing::{get, post},
Json, Router,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::net::SocketAddr;
use std::sync::Arc;
use quill_sql::database::{Database, DatabaseOptions};
use quill_sql::transaction::IsolationLevel;
use std::str::FromStr;
/// Shared state handed to every request handler.
///
/// Cloning only bumps the `Arc` reference count, so every clone refers to the
/// same `Database`; the mutex serialises access to it.
#[derive(Clone)]
struct AppState {
    /// The single database instance, behind a blocking (std) mutex.
    db: Arc<std::sync::Mutex<Database>>,
}
/// JSON request body accepted by `/api/sql` and `/api/sql_batch`.
#[derive(Deserialize)]
struct SqlRequest {
    /// Raw SQL text submitted by the client.
    sql: String,
}
/// JSON response body returned by `/api/sql`.
#[derive(Serialize)]
struct SqlResponse {
    /// Result rows; each inner vector holds one row's values as strings.
    rows: Vec<Vec<String>>,
}
/// JSON response body returned by `/api/sql_batch`.
#[derive(Serialize)]
struct SqlBatchResponse {
    /// One entry per executed statement, in execution order; each entry is
    /// that statement's result rows, with every value rendered as a string.
    results: Vec<Vec<Vec<String>>>,
}
/// Removes full-line `--` comments from a SQL script.
///
/// A line is dropped when its first non-whitespace characters are `--`. Every
/// other line is copied verbatim and terminated with `\n`. Because lines are
/// split with `str::lines`, `\r\n` endings come out as `\n`.
///
/// Limitations: a `--` comment that follows code on the same line is left in
/// place. Quotes are not tracked, so a line inside a multi-line string literal
/// that starts with `--` is also dropped.
fn strip_sql_comments(input: &str) -> String {
    input
        .lines()
        .filter(|line| !line.trim_start().starts_with("--"))
        .fold(String::with_capacity(input.len()), |mut acc, line| {
            acc.push_str(line);
            acc.push('\n');
            acc
        })
}
/// `GET /api/examples`: returns every `*.slt` file found directly in
/// `src/tests/sql_example/`, as a map from file stem to file contents.
///
/// The directory path is relative, so it is resolved against the process's
/// current working directory. Directory entries that cannot be read are
/// skipped, as are files whose stem is not valid UTF-8.
///
/// # Errors
/// `500 Internal Server Error` if the directory cannot be listed or a matching
/// file cannot be read.
async fn api_examples() -> Result<Json<HashMap<String, String>>, (StatusCode, String)> {
    const EXAMPLE_DIR: &str = "src/tests/sql_example/";

    let listing = fs::read_dir(EXAMPLE_DIR).map_err(|err| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to read examples directory {}: {}", EXAMPLE_DIR, err),
        )
    })?;

    let mut examples = HashMap::new();
    for file_path in listing.flatten().map(|dir_entry| dir_entry.path()) {
        if !(file_path.is_file()
            && file_path.extension().and_then(|ext| ext.to_str()) == Some("slt"))
        {
            continue;
        }
        let stem = match file_path.file_stem().and_then(|s| s.to_str()) {
            Some(stem) => stem.to_string(),
            None => continue,
        };
        let body = fs::read_to_string(&file_path).map_err(|err| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!(
                    "Failed to read example file {}: {}",
                    file_path.display(),
                    err
                ),
            )
        })?;
        examples.insert(stem, body);
    }
    Ok(Json(examples))
}
#[tokio::main]
async fn main() {
env_logger::init();
let default_isolation_level = std::env::var("QUILL_DEFAULT_ISOLATION")
.ok()
.as_deref()
.map(IsolationLevel::from_str)
.transpose()
.unwrap_or_else(|e| panic!("invalid QUILL_DEFAULT_ISOLATION: {}", e));
let db_options = DatabaseOptions {
default_isolation_level,
..DatabaseOptions::default()
};
let db = if let Ok(path) = std::env::var("QUILL_DB_FILE") {
Database::new_on_disk_with_options(&path, db_options.clone()).expect("open db file")
} else {
Database::new_temp_with_options(db_options).expect("open temp db")
};
let state = AppState {
db: Arc::new(std::sync::Mutex::new(db)),
};
let static_service =
tower_http::services::ServeDir::new("public").append_index_html_on_directories(true);
let docs_service = tower_http::services::ServeDir::new("docs");
let app = Router::new()
.route("/api/sql", post(api_sql))
.route("/api/sql_batch", post(api_sql_batch))
.route("/api/examples", get(api_examples))
.nest_service("/docs", docs_service)
.fallback_service(static_service)
.with_state(state);
let app = app.layer(tower_http::cors::CorsLayer::very_permissive());
let bind_addr = if let Ok(port) = std::env::var("PORT") {
format!("0.0.0.0:{}", port)
} else {
std::env::var("QUILL_HTTP_ADDR").unwrap_or_else(|_| "0.0.0.0:8080".to_string())
};
let addr: SocketAddr = bind_addr.parse().expect("invalid bind addr");
println!("Serving on http://{}", addr);
axum::serve(
tokio::net::TcpListener::bind(addr)
.await
.expect("bind http"),
app,
)
.await
.expect("server error");
}
async fn api_sql(
State(state): State<AppState>,
Json(req): Json<SqlRequest>,
) -> Result<Json<SqlResponse>, (StatusCode, String)> {
let mut db = state
.db
.lock()
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB poisoned".to_string()))?;
let cleaned = strip_sql_comments(&req.sql);
let tuples = db
.run(&cleaned)
.map_err(|e| (StatusCode::BAD_REQUEST, format!("{}", e)))?;
let rows = tuples
.into_iter()
.map(|t| t.data.into_iter().map(|v| format!("{}", v)).collect())
.collect();
Ok(Json(SqlResponse { rows }))
}
async fn api_sql_batch(
State(state): State<AppState>,
Json(req): Json<SqlRequest>,
) -> Result<Json<SqlBatchResponse>, (StatusCode, String)> {
let mut db = state
.db
.lock()
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB poisoned".to_string()))?;
let cleaned = strip_sql_comments(&req.sql);
let statements = cleaned
.split(';')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.take(100);
let mut results: Vec<Vec<Vec<String>>> = Vec::new();
for stmt in statements {
let tuples = db
.run(stmt)
.map_err(|e| (StatusCode::BAD_REQUEST, format!("{}", e)))?;
let rows: Vec<Vec<String>> = tuples
.into_iter()
.map(|t| t.data.into_iter().map(|v| format!("{}", v)).collect())
.collect();
results.push(rows);
}
Ok(Json(SqlBatchResponse { results }))
}