use std::sync::Arc;
use axum::{Json, Router, extract::State, routing::{get, post}};
use serde_json::json;
use crate::{config::QueryRequest, error::AppError, query::QueryEngine};
/// Shared application state attached to the router and extracted by handlers.
///
/// Cloning is cheap: the only field is an [`Arc`], so every clone points at
/// the same underlying [`QueryEngine`].
#[derive(Clone)]
pub struct AppState {
    /// Query engine shared by all clones of this state.
    engine: Arc<QueryEngine>,
}

impl AppState {
    /// Takes ownership of `engine` and wraps it in an [`Arc`] so cloned
    /// states share a single instance.
    pub fn new(engine: QueryEngine) -> Self {
        let engine = Arc::new(engine);
        Self { engine }
    }
}
/// Builds the HTTP API router.
///
/// Routes:
/// - `GET /health` -> [`health`]
/// - `POST /query` -> [`query`]
///
/// `state` is attached to the router so handlers can extract it with
/// [`State`].
pub fn build_router(state: AppState) -> Router {
    let routes = Router::new()
        .route("/health", get(health))
        .route("/query", post(query));
    routes.with_state(state)
}
/// `GET /health` handler: always responds with the JSON object
/// `{"status": "ok"}`.
async fn health() -> Json<serde_json::Value> {
    let payload = json!({ "status": "ok" });
    Json(payload)
}
/// `POST /query` handler: passes the deserialized [`QueryRequest`] to the
/// shared [`QueryEngine`] and returns the engine's response as a JSON value.
///
/// # Errors
///
/// Returns an [`AppError`] when `execute` fails, or when the engine's response
/// cannot be converted with `serde_json::to_value`. Both cases rely on `From`
/// conversions into `AppError` that live outside this file.
async fn query(
    State(state): State<AppState>,
    Json(request): Json<QueryRequest>,
) -> Result<Json<serde_json::Value>, AppError> {
    let engine = &state.engine;
    let outcome = engine.execute(request).await?;
    let body = serde_json::to_value(outcome)?;
    Ok(Json(body))
}
#[cfg(test)]
mod tests {
    use std::{fs::File, path::Path, sync::Arc};

    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use datafusion::arrow::{
        array::{ArrayRef, Int64Array, StringArray},
        datatypes::{DataType, Field, Schema},
        record_batch::RecordBatch,
    };
    use parquet::arrow::ArrowWriter;
    use serde_json::json;
    use tempfile::tempdir;
    use tower::ServiceExt;

    use super::*;
    use crate::query::QueryEngine;

    /// Reads an entire response body into memory and parses it as JSON.
    async fn read_json(body: Body) -> serde_json::Value {
        let bytes = axum::body::to_bytes(body, usize::MAX)
            .await
            .expect("response body should be readable");
        serde_json::from_slice(&bytes).expect("response body should be valid JSON")
    }

    #[tokio::test]
    async fn health_endpoint_reports_ok() {
        let app = build_router(AppState::new(QueryEngine::new()));
        let response = app
            .oneshot(Request::builder().uri("/health").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        // Check the payload as well as the status code.
        assert_eq!(read_json(response.into_body()).await, json!({"status": "ok"}));
    }

    #[tokio::test]
    async fn query_endpoint_executes_sql_against_local_parquet() {
        // `dir` must stay alive until the request completes: dropping the
        // TempDir removes the directory and the parquet file inside it.
        let dir = tempdir().unwrap();
        let path = dir.path().join("sample.parquet");
        write_parquet(&path).unwrap();

        let app = build_router(AppState::new(QueryEngine::new()));
        // The table and one of its columns are both named `trips`. Of the
        // fixture rows (12, 22), only "seattle" (22) passes the filter.
        let payload = json!({
            "sql": "SELECT city FROM trips WHERE trips >= 20 ORDER BY city",
            "datasets": [
                {
                    "table_name": "trips",
                    "uri": path.to_string_lossy(),
                    "storage": { "type": "local" }
                }
            ]
        });
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/query")
                    .header("content-type", "application/json")
                    .body(Body::from(payload.to_string()))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let json = read_json(response.into_body()).await;
        assert_eq!(json["row_count"], json!(1));
        assert_eq!(json["rows"][0]["city"], json!("seattle"));
    }

    /// Writes a two-row `(city: Utf8, trips: Int64)` parquet fixture to `path`.
    fn write_parquet(path: &Path) -> Result<(), Box<dyn std::error::Error>> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("trips", DataType::Int64, false),
        ]));
        let columns: Vec<ArrayRef> = vec![
            Arc::new(StringArray::from(vec!["vancouver", "seattle"])),
            Arc::new(Int64Array::from(vec![12, 22])),
        ];
        let batch = RecordBatch::try_new(Arc::clone(&schema), columns)?;

        let mut writer = ArrowWriter::try_new(File::create(path)?, schema, None)?;
        writer.write(&batch)?;
        // `close` writes the parquet footer; without it the file is unreadable.
        writer.close()?;
        Ok(())
    }
}