use std::sync::Arc;
use axum::{
extract::State,
routing::{get, post},
Json, Router,
};
use serde_json::json;
use tokio::net::TcpListener;
use tracing::info;
use crate::{
config::{BenchmarkRequest, ConvertRequest, QueryRequest, RuntimeConfig, SchemaRequest},
engine::Text2SqlEngine,
error::{AppError, AppResult},
};
/// State attached to the router by [`build_router`] and handed to each
/// handler through axum's `State` extractor.
///
/// `Clone` is derived; the engine sits behind an [`Arc`], so every clone
/// refers to the same engine instance.
#[derive(Clone)]
pub struct AppState {
    /// Engine that the `/schema`, `/query`, `/convert` and `/benchmark`
    /// handlers delegate to.
    engine: Arc<Text2SqlEngine>,
    /// Service name echoed by the `/` and `/health` endpoints.
    service_name: String,
    /// Version string echoed by the `/` and `/health` endpoints.
    version: &'static str,
}
impl AppState {
    /// Builds the state for a service named `service_name`.
    ///
    /// A new [`Text2SqlEngine`] is created and wrapped in an [`Arc`]. The
    /// version is this crate's `CARGO_PKG_VERSION`, resolved at compile time.
    pub fn new(service_name: impl Into<String>) -> Self {
        let engine = Arc::new(Text2SqlEngine::new());
        let service_name = service_name.into();
        let version = env!("CARGO_PKG_VERSION");

        Self {
            engine,
            service_name,
            version,
        }
    }
}
/// Assembles the HTTP router and binds `state` to it.
///
/// Routes:
/// - `GET /` -> [`root`]
/// - `GET /health` -> [`health`]
/// - `POST /schema` -> [`schema`]
/// - `POST /query` -> [`query`]
/// - `POST /convert` -> [`convert`]
/// - `POST /benchmark` -> [`benchmark`]
pub fn build_router(state: AppState) -> Router {
    // Informational endpoints that only read the service metadata.
    let informational: Router<AppState> = Router::new()
        .route("/", get(root))
        .route("/health", get(health));

    // Endpoints that hand the request body to the engine.
    let complete = informational
        .route("/schema", post(schema))
        .route("/query", post(query))
        .route("/convert", post(convert))
        .route("/benchmark", post(benchmark));

    complete.with_state(state)
}
/// Binds a TCP listener on `config.bind_addr` and serves the router on it.
///
/// The future resolves only when `axum::serve` returns.
///
/// # Errors
///
/// - [`AppError::Bind`] if binding the listener or reading its local
///   address fails.
/// - [`AppError::Serve`] if `axum::serve` returns an error.
pub async fn run(config: RuntimeConfig) -> AppResult<()> {
    let bound = TcpListener::bind(config.bind_addr).await;
    let listener = bound.map_err(AppError::Bind)?;

    // Ask the socket for its address rather than echoing the configured one.
    let local_addr = listener.local_addr().map_err(AppError::Bind)?;

    let state = AppState::new(config.service_name.clone());
    let app = build_router(state);

    info!(%local_addr, service = %config.service_name, "text2sql listening");

    let outcome = axum::serve(listener, app).await;
    outcome.map_err(AppError::Serve)
}
/// `GET /` handler: describes the service.
///
/// The JSON body carries the service name, the version and a fixed list of
/// capability identifiers.
async fn root(State(state): State<AppState>) -> Json<serde_json::Value> {
    // Capability identifiers advertised to clients, in response order.
    const CAPABILITIES: [&str; 5] = [
        "health-check",
        "schema-inspection",
        "parquet-selective-query",
        "csv-xlsx-json-to-parquet-conversion",
        "benchmark-10k-50k-100k-500k",
    ];

    let description = json!({
        "service": state.service_name,
        "version": state.version,
        "capabilities": CAPABILITIES,
    });
    Json(description)
}
/// `GET /health` handler: reports `"status": "ok"` together with the
/// service name and version.
async fn health(State(state): State<AppState>) -> Json<serde_json::Value> {
    let report = json!({
        "status": "ok",
        "service": state.service_name,
        "version": state.version,
    });
    Json(report)
}
async fn query(
State(state): State<AppState>,
Json(request): Json<QueryRequest>,
) -> Result<Json<serde_json::Value>, AppError> {
let response = state.engine.execute_query(request).await?;
Ok(Json(serde_json::to_value(response)?))
}
async fn schema(
State(state): State<AppState>,
Json(request): Json<SchemaRequest>,
) -> Result<Json<serde_json::Value>, AppError> {
let response = state.engine.inspect_schema(request).await?;
Ok(Json(serde_json::to_value(response)?))
}
async fn convert(
State(state): State<AppState>,
Json(request): Json<ConvertRequest>,
) -> Result<Json<serde_json::Value>, AppError> {
let response = state.engine.convert(request).await?;
Ok(Json(serde_json::to_value(response)?))
}
async fn benchmark(
State(state): State<AppState>,
Json(request): Json<BenchmarkRequest>,
) -> Result<Json<serde_json::Value>, AppError> {
let response = state.engine.benchmark(request).await?;
Ok(Json(serde_json::to_value(response)?))
}
#[cfg(test)]
mod tests {
    use std::path::Path;

    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use tempfile::tempdir;
    use tower::util::ServiceExt;

    use super::*;

    /// Service name given to every router built by these tests.
    const SERVICE: &str = "test-service";

    /// Two-row CSV fixture used by the conversion and schema tests.
    const MINI_CSV: &str = "city,value\nSeoul,10\nBusan,20\n";

    /// Builds a fresh router with its own state.
    fn test_app() -> axum::Router {
        build_router(AppState::new(SERVICE))
    }

    /// Builds a body-less request for `uri` using the builder's default method.
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    /// Builds a `POST` request for `uri` carrying `body` as `application/json`.
    fn post_json_request(uri: &str, body: impl Into<Body>) -> Request<Body> {
        Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .body(body.into())
            .unwrap()
    }

    /// Drives `request` through `app` once and returns the response.
    async fn send(app: axum::Router, request: Request<Body>) -> axum::response::Response {
        app.oneshot(request).await.unwrap()
    }

    /// Reads the whole response body and parses it as JSON.
    async fn json_body(response: axum::response::Response) -> serde_json::Value {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// Builds the `/convert` payload turning `csv_path` into `parquet_path`.
    fn convert_payload(csv_path: &Path, parquet_path: &Path) -> serde_json::Value {
        serde_json::json!({
            "input_path": csv_path.to_string_lossy(),
            "output": {
                "uri": parquet_path.to_string_lossy(),
                "storage": {"type": "local"}
            },
            "normalize_columns": true,
            "overwrite": true
        })
    }

    #[tokio::test]
    async fn health_route_reports_ok() {
        let response = send(test_app(), get_request("/health")).await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn query_route_returns_json_error_for_invalid_sql() {
        let request = post_json_request(
            "/query",
            r#"{"sql":"SELECT * FROM missing_table","dataset":{"uri":"/tmp/does-not-exist.parquet","storage":{"type":"local"}},"table_name":"dataset","mode":"parquet_selective"}"#,
        );

        let response = send(test_app(), request).await;

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let payload = json_body(response).await;
        assert!(payload["error"].as_str().unwrap().contains("duckdb error"));
    }

    #[tokio::test]
    async fn convert_route_converts_csv_end_to_end() {
        let dir = tempdir().unwrap();
        let csv_path = dir.path().join("mini.csv");
        let parquet_path = dir.path().join("mini.parquet");
        std::fs::write(&csv_path, MINI_CSV).unwrap();

        let payload = convert_payload(&csv_path, &parquet_path);
        let response = send(
            test_app(),
            post_json_request("/convert", payload.to_string()),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        assert!(parquet_path.exists());
    }

    #[tokio::test]
    async fn schema_route_returns_columns_without_sql() {
        let dir = tempdir().unwrap();
        let csv_path = dir.path().join("mini.csv");
        let parquet_path = dir.path().join("mini.parquet");
        std::fs::write(&csv_path, MINI_CSV).unwrap();

        // Both requests go through clones of one router, so they share state.
        let app = test_app();

        let conversion = convert_payload(&csv_path, &parquet_path);
        let convert_response = send(
            app.clone(),
            post_json_request("/convert", conversion.to_string()),
        )
        .await;
        assert_eq!(convert_response.status(), StatusCode::OK);

        let schema_payload = serde_json::json!({
            "dataset": {
                "uri": parquet_path.to_string_lossy(),
                "storage": {"type": "local"}
            }
        });
        let response = send(
            app,
            post_json_request("/schema", schema_payload.to_string()),
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);

        let payload = json_body(response).await;
        assert_eq!(payload["table_name"], "dataset");
        assert_eq!(payload["columns"][0]["name"], "city");
        let notes = payload["notes"].as_array().unwrap();
        assert!(notes
            .iter()
            .any(|note| note.as_str().unwrap().contains("before generating SQL")));
    }

    #[tokio::test]
    async fn benchmark_route_returns_results() {
        let dir = tempdir().unwrap();
        let payload = serde_json::json!({
            "output_dir": dir.path().to_string_lossy(),
            "row_counts": [1000, 2000],
            "sql": "SELECT symbol, AVG(trade_value) AS avg_trade_value FROM dataset WHERE trade_month = 9 GROUP BY symbol ORDER BY avg_trade_value DESC"
        });

        let response = send(
            test_app(),
            post_json_request("/benchmark", payload.to_string()),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let payload = json_body(response).await;
        assert_eq!(payload["results"].as_array().unwrap().len(), 2);
    }

    #[tokio::test]
    async fn root_route_lists_new_capabilities() {
        let response = send(test_app(), get_request("/")).await;

        assert_eq!(response.status(), StatusCode::OK);
        let payload = json_body(response).await;
        let advertises = |name: &str| {
            payload["capabilities"]
                .as_array()
                .unwrap()
                .iter()
                .any(|entry| entry == name)
        };
        assert!(advertises("parquet-selective-query"));
        assert!(advertises("schema-inspection"));
    }
}