use std::{io::Cursor, time::Duration};
use axum::{
body::Body,
extract::{Json, OriginalUri, Path, Query, State},
response::{IntoResponse, Response},
routing::{get, post},
Router,
};
use datafusion::{arrow::json::ArrayWriter, execution::SendableRecordBatchStream};
use datafusion_app::{observability::ObservabilityRequestDetails, ExecOptions, ExecResult};
use http::{HeaderValue, StatusCode};
use jiff::Timestamp;
use log::error;
use serde::{Deserialize, Serialize};
use tokio_stream::StreamExt;
use tower_http::{timeout::TimeoutLayer, trace::TraceLayer};
use tracing::debug;
use crate::{config::HttpServerConfig, execution::AppExecution};
use super::tpch;
/// A SQL statement to run, together with the request path that produced it.
///
/// The path travels with the SQL so both can be written to the observability
/// request log once execution finishes.
#[derive(Debug)]
struct ExecRequest {
    path: String,
    sql: String,
}

/// State shared by every route handler (installed via `Router::with_state`).
#[derive(Clone)]
struct ExecutionState {
    /// Runs SQL and exposes the session and observability contexts.
    execution: AppExecution,
    /// HTTP server settings such as the request timeout and result row limit.
    config: HttpServerConfig,
}

impl ExecutionState {
    /// Bundles an [`AppExecution`] with its [`HttpServerConfig`] into router state.
    pub fn new(execution: AppExecution, config: HttpServerConfig) -> Self {
        Self { execution, config }
    }
}
pub fn create_router(execution: AppExecution, config: HttpServerConfig) -> Router {
let state = ExecutionState::new(execution, config);
Router::new()
.route(
"/",
get(|State(_): State<ExecutionState>| async { "Hello, from DFT!" }),
)
.route(
"/health-check",
get(|State(_): State<ExecutionState>| async { "Healthy" }),
)
.route("/sql", post(post_sql_handler))
.route("/catalog", get(get_catalog_handler))
.route("/tpch/:number", get(get_tpch_query_handler))
.route("/table/:catalog/:schema/:table", get(get_table_handler))
.layer((
TraceLayer::new_for_http(),
TimeoutLayer::with_status_code(
http::StatusCode::REQUEST_TIMEOUT,
Duration::from_secs(state.config.timeout_seconds),
),
))
.with_state(state)
}
/// JSON body accepted by `POST /sql`.
#[derive(Deserialize)]
struct PostSqlBody {
    /// SQL text to execute.
    sql: String,
    /// Request execution through FlightSQL; defaults to `false` when omitted.
    #[serde(default)]
    flightsql: bool,
}

/// Handles `POST /sql`: executes `body.sql` and returns the result as JSON.
///
/// Responds `400 Bad Request` when FlightSQL execution is requested but the
/// server was built without the `flightsql` feature. `config.result_limit`
/// is passed to the executor as the result limit.
async fn post_sql_handler(
    state: State<ExecutionState>,
    OriginalUri(uri): OriginalUri,
    Json(body): Json<PostSqlBody>,
) -> Response {
    if body.flightsql && !cfg!(feature = "flightsql") {
        return (
            StatusCode::BAD_REQUEST,
            "FlightSQL is not enabled on this server",
        )
            .into_response();
    }
    // `body` is owned, so move the SQL out instead of copying it.
    let PostSqlBody { sql, flightsql } = body;
    let req = ExecRequest {
        path: uri.path().to_string(),
        sql,
    };
    let opts = ExecOptions::new(Some(state.config.result_limit), flightsql);
    create_response(&state, req, opts).await
}
/// Query-string parameters accepted by `GET /catalog`.
#[derive(Deserialize)]
struct GetCatalogQueryParams {
    /// Request execution through FlightSQL; defaults to `false` when omitted.
    #[serde(default)]
    flightsql: bool,
}

/// Handles `GET /catalog`: lists tables by executing `SHOW TABLES`.
///
/// Passes no result limit to the executor. Responds `400 Bad Request` when
/// `?flightsql=true` is requested on a server built without the `flightsql`
/// feature.
async fn get_catalog_handler(
    state: State<ExecutionState>,
    OriginalUri(uri): OriginalUri,
    Query(query): Query<GetCatalogQueryParams>,
) -> Response {
    let opts = ExecOptions::new(None, query.flightsql);
    let flightsql_unavailable = opts.flightsql && !cfg!(feature = "flightsql");
    if flightsql_unavailable {
        return (
            StatusCode::BAD_REQUEST,
            "FlightSQL is not enabled on this server",
        )
            .into_response();
    }
    let req = ExecRequest {
        path: uri.path().to_owned(),
        sql: String::from("SHOW TABLES"),
    };
    create_response(&state, req, opts).await
}
#[derive(Deserialize, Serialize)]
struct GetTablePathParams {
catalog: String,
schema: String,
table: String,
}
#[derive(Deserialize)]
struct GetTableQueryParams {
#[serde(default)]
flightsql: bool,
}
async fn get_table_handler(
state: State<ExecutionState>,
Path(path): Path<GetTablePathParams>,
Query(query): Query<GetTableQueryParams>,
OriginalUri(uri): OriginalUri,
) -> Response {
let GetTablePathParams {
catalog,
schema,
table,
} = path;
let sql = format!(
"SELECT * FROM \"{catalog}\".\"{schema}\".\"{table}\" LIMIT {}",
state.config.result_limit
);
let req = ExecRequest {
path: uri.path().to_string(),
sql,
};
let opts = ExecOptions::new(Some(state.config.result_limit), query.flightsql);
create_response(&state, req, opts).await
}
/// Path parameters for `GET /tpch/:number`.
#[derive(Deserialize, Serialize)]
struct GetTpchPathParams {
    number: usize,
}

/// Handles `GET /tpch/:number`: executes the TPC-H query with that number.
///
/// Numbers for which `tpch::sql_for_tpch_query` has no SQL get
/// `400 Bad Request`. The query is executed with no result limit and without
/// requesting FlightSQL.
async fn get_tpch_query_handler(
    state: State<ExecutionState>,
    Path(path): Path<GetTpchPathParams>,
    OriginalUri(uri): OriginalUri,
) -> Response {
    let Some(sql) = tpch::sql_for_tpch_query(path.number) else {
        return (StatusCode::BAD_REQUEST, "Unknown TPC-H query number").into_response();
    };
    let req = ExecRequest {
        path: uri.path().to_string(),
        sql: sql.to_string(),
    };
    create_response(&state, req, ExecOptions::new(None, false)).await
}
/// Executes `sql` with `opts` and turns the outcome into an HTTP response.
///
/// A record-batch stream is serialized by `batch_stream_to_response`. Any
/// other result kind, or an execution error, produces `400 Bad Request` with
/// a plain-text message, and zero rows are reported in [`ResponseDetails`].
async fn response_for_sql(
    State(state): &State<ExecutionState>,
    sql: String,
    opts: ExecOptions,
) -> (Response, ResponseDetails) {
    debug!("Executing sql: {sql}");
    let outcome = state.execution.execute_sql_with_opts(&sql, opts).await;
    let message = match outcome {
        Ok(ExecResult::RecordBatchStream(stream)) => {
            return batch_stream_to_response(stream).await;
        }
        Ok(_) => "Execution failed: unknown result type".to_string(),
        Err(e) => e.to_string(),
    };
    (
        (StatusCode::BAD_REQUEST, message).into_response(),
        error_response_details(),
    )
}
/// Per-response metrics forwarded to the observability request log.
struct ResponseDetails {
    /// Number of result rows serialized into the response body.
    rows: u64,
}

/// Metrics for a response that carries no result rows, e.g. any error response.
fn error_response_details() -> ResponseDetails {
    let rows = 0;
    ResponseDetails { rows }
}
/// Drains `batch_stream` and serializes every batch into a single JSON array.
///
/// On success the response has the default `200 OK` status, an
/// `application/json` body, and the total number of rows written. A stream
/// error, a serialization or finalization failure, or non-UTF-8 output is
/// logged and becomes `500 Internal Server Error` with a short plain-text
/// reason and zero rows. The whole result is buffered in memory before the
/// response is built.
async fn batch_stream_to_response(
    batch_stream: SendableRecordBatchStream,
) -> (Response, ResponseDetails) {
    /// Builds the 500 response returned for every failure path below.
    fn internal_error(reason: &'static str) -> (Response, ResponseDetails) {
        (
            (StatusCode::INTERNAL_SERVER_ERROR, reason).into_response(),
            error_response_details(),
        )
    }

    let mut stream = batch_stream;
    let mut out: Cursor<Vec<u8>> = Cursor::new(Vec::new());
    let mut json_writer = ArrayWriter::new(&mut out);
    let mut total_rows: usize = 0;

    loop {
        let batch = match stream.next().await {
            None => break,
            Some(Err(e)) => {
                error!("Error executing query: {}", e);
                return internal_error("Query execution error");
            }
            Some(Ok(batch)) => batch,
        };
        if let Err(e) = json_writer.write(&batch) {
            error!("Error serializing result batches: {}", e);
            return internal_error("Serialization error");
        }
        total_rows += batch.num_rows();
    }

    // Emits the closing bracket of the JSON array.
    if let Err(e) = json_writer.finish() {
        error!("Error finalizing JSON writer: {}", e);
        return internal_error("Finalization error");
    }

    let Ok(json) = String::from_utf8(out.into_inner()) else {
        return internal_error("UTF-8 conversion error");
    };
    let mut response = Response::new(Body::new(json));
    response
        .headers_mut()
        .insert("content-type", HeaderValue::from_static("application/json"));
    (
        response,
        ResponseDetails {
            rows: total_rows as u64,
        },
    )
}
async fn create_response(
state: &State<ExecutionState>,
req: ExecRequest,
opts: ExecOptions,
) -> Response {
let start = Timestamp::now();
let (res, details) = response_for_sql(state, req.sql.clone(), opts).await;
let elapsed = Timestamp::now() - start;
let req = ObservabilityRequestDetails {
request_id: None,
path: req.path,
sql: Some(req.sql),
start_ms: start.as_millisecond(),
duration_ms: elapsed.get_milliseconds(),
rows: Some(details.rows),
status: res.status().as_u16(),
};
let obs = state.execution.execution_ctx().observability();
if let Err(e) = obs
.try_record_request(state.execution.session_ctx(), req)
.await
{
error!("Error recording request: {}", e)
}
res
}
#[cfg(test)]
mod test {
    use axum::body::Body;
    use datafusion_app::{
        config::ExecutionConfig, extensions::DftSessionStateBuilder, local::ExecutionContext,
    };
    use http::{Request, StatusCode};

    use crate::{
        config::HttpServerConfig, execution::AppExecution, server::http::router::create_router,
    };
    use tower::ServiceExt;

    /// Builds a local-only `AppExecution` with default settings, plus default HTTP config.
    fn setup() -> (AppExecution, HttpServerConfig) {
        let exec_config = ExecutionConfig::default();
        let session_state = DftSessionStateBuilder::try_new(None)
            .unwrap()
            .build()
            .unwrap();
        let local = ExecutionContext::try_new(
            &exec_config,
            session_state,
            crate::APP_NAME,
            env!("CARGO_PKG_VERSION"),
        )
        .unwrap();
        (AppExecution::new(local), HttpServerConfig::default())
    }

    /// A `GET` request for `uri` with an empty body.
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    /// A `POST /sql` request whose JSON body is `json`.
    fn post_sql_request(json: &'static str) -> Request<Body> {
        Request::builder()
            .method("POST")
            .uri("/sql")
            .header("Content-Type", "application/json")
            .body(Body::from(json))
            .unwrap()
    }

    /// Sends `req` through a freshly built router and returns the response status.
    async fn status_for(req: Request<Body>) -> StatusCode {
        let (execution, http_config) = setup();
        let router = create_router(execution, http_config);
        router.oneshot(req).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_get_catalog() {
        assert_eq!(status_for(get_request("/catalog")).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_get_table() {
        let req = get_request("/table/datafusion/information_schema/df_settings");
        assert_eq!(status_for(req).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_get_nonexistent_table() {
        let req = get_request("/table/datafusion/information_schema/df_setting");
        assert_eq!(status_for(req).await, StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_post_sql() {
        let req = post_sql_request("{\"sql\": \"SELECT 1\"}");
        assert_eq!(status_for(req).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_observability_request_logged() {
        let (execution, http_config) = setup();
        // Keep a handle so the observability table can be queried afterwards.
        let router = create_router(execution.clone(), http_config);
        let res = router
            .oneshot(post_sql_request("{\"sql\": \"SELECT 1\"}"))
            .await
            .unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let batches = execution
            .execution_ctx()
            .session_ctx()
            .sql("SELECT * FROM dft.observability.requests")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();
        let logged_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(logged_rows, 1);
    }
}
#[cfg(all(test, feature = "flightsql"))]
mod flightsql_test {
    use axum::body::Body;
    use datafusion_app::{
        config::{ExecutionConfig, FlightSQLConfig},
        extensions::DftSessionStateBuilder,
        flightsql::FlightSQLContext,
        local::ExecutionContext,
    };
    use http::{Request, StatusCode};

    use crate::{
        config::HttpServerConfig, execution::AppExecution, server::http::router::create_router,
    };
    use tower::ServiceExt;

    /// Builds an `AppExecution` whose FlightSQL client is connected to
    /// `http://localhost:50051`, plus default HTTP config.
    ///
    /// Panics if the client cannot be created; these tests expect a FlightSQL
    /// server to be listening on that address.
    async fn setup() -> (AppExecution, HttpServerConfig) {
        let exec_config = ExecutionConfig::default();
        let session_state = DftSessionStateBuilder::try_new(None)
            .unwrap()
            .build()
            .unwrap();
        let local = ExecutionContext::try_new(
            &exec_config,
            session_state,
            crate::APP_NAME,
            env!("CARGO_PKG_VERSION"),
        )
        .unwrap();
        let mut execution = AppExecution::new(local);

        let flightsql_ctx = FlightSQLContext::new(FlightSQLConfig {
            connection_url: String::from("localhost:50051"),
            ..Default::default()
        });
        flightsql_ctx
            .create_client(Some(String::from("http://localhost:50051")), None)
            .await
            .unwrap();
        execution.with_flightsql_ctx(flightsql_ctx);

        (execution, HttpServerConfig::default())
    }

    /// Issues `GET uri` against a freshly built router and returns the status.
    async fn get_status(uri: &str) -> StatusCode {
        let (execution, http_config) = setup().await;
        let router = create_router(execution, http_config);
        let req = Request::builder().uri(uri).body(Body::empty()).unwrap();
        router.oneshot(req).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_get_catalog() {
        assert_eq!(get_status("/catalog?flightsql=true").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_get_table() {
        let status =
            get_status("/table/datafusion/information_schema/df_settings?flightsql=true").await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_get_nonexistent_table() {
        let status =
            get_status("/table/datafusion/information_schema/df_setting?flightsql=true").await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
    }
}