1use std::convert::Infallible;
2use std::sync::Arc;
3use std::time::Instant;
4
5use alopex_core::kv::async_adapter::AsyncKVTransactionAdapter;
6use alopex_sql::storage::async_storage::AsyncTxnBridge;
7use alopex_sql::storage::AsyncSqlTransaction;
8use alopex_sql::AlopexDialect;
9use axum::extract::Extension;
10use axum::response::{IntoResponse, Response};
11use axum::Json;
12use futures::StreamExt;
13use serde::{Deserialize, Serialize};
14use tokio::sync::mpsc;
15use tokio_stream::wrappers::ReceiverStream;
16
17use crate::error::{Result, ServerError};
18use crate::http::{error_response, json_response, RequestContext};
19use crate::server::ServerState;
20use crate::session::{SessionId, TxnHandle};
21
/// JSON body accepted by the SQL endpoint.
#[derive(Debug, Deserialize)]
pub struct SqlRequest {
    /// SQL text to execute. `handle` rejects it with a bad-request error when it
    /// is empty after trimming whitespace.
    pub sql: String,
    /// Optional session identifier. When present, the statement runs against that
    /// session (parsed into `SessionId`; an unparsable value is reported as
    /// "invalid session_id").
    pub session_id: Option<String>,
    /// When `true`, the result is streamed back as JSON Lines instead of a single
    /// JSON document. Defaults to `false` when omitted.
    #[serde(default)]
    pub streaming: bool,
}
29
/// Column metadata returned in a non-streaming query response.
#[derive(Debug, Serialize)]
pub struct ColumnInfoResponse {
    /// Column name as reported by the executor.
    pub name: String,
    /// Column type rendered by `type_to_string`, e.g. `"INTEGER"` or
    /// `"VECTOR(<dimension>, <metric>)"`.
    pub data_type: String,
}
35
/// Body of a non-streaming SQL response.
///
/// As built by `map_execution_result`:
/// - a query fills `columns` and `rows` and leaves `affected_rows` as `None`;
/// - a statement reporting a row count leaves both vectors empty and sets
///   `affected_rows`;
/// - any other successful statement leaves everything empty/`None`.
#[derive(Debug, Serialize)]
pub struct SqlResponse {
    /// Result-set column descriptions (empty for non-query statements).
    pub columns: Vec<ColumnInfoResponse>,
    /// Result rows, each a vector of values in column order.
    pub rows: Vec<Vec<alopex_sql::storage::SqlValue>>,
    /// Number of rows affected, when the executor reports one.
    pub affected_rows: Option<u64>,
}
42
/// One record (one line) of a streaming JSON Lines response.
///
/// This module produces three shapes: a row record (`row` set), an error record
/// (`error` set), and a terminal record (`done: true`, nothing else set).
#[derive(Debug, Serialize)]
struct StreamItem {
    /// Values of a single result row, if this is a row record.
    row: Option<Vec<alopex_sql::storage::SqlValue>>,
    /// In-band error details, if this is an error record.
    error: Option<StreamError>,
    /// `true` only on the terminal record.
    done: bool,
}
49
/// Error details embedded in a streaming record. Errors are reported in-band
/// because the HTTP status has already been sent.
#[derive(Debug, Serialize)]
struct StreamError {
    /// Machine-readable code from `ServerError::error_code`.
    code: String,
    /// Human-readable message (the error's `Display` output).
    message: String,
    /// Correlation id of the originating request, for log lookup.
    correlation_id: String,
}
56
/// Transaction type produced by `ServerState::begin_sql_txn` in this module
/// (held by `StreamSource::Txn`).
type AsyncTxn = AsyncTxnBridge<'static, AsyncKVTransactionAdapter>;
58
/// Where a streaming query executes.
enum StreamSource {
    /// A fresh transaction opened for this request. It is owned by the streaming
    /// task, which ends it once the stream finishes.
    Txn(AsyncTxn),
    /// A session's existing transaction, obtained from the session manager. The
    /// streaming task does not end it.
    Handle(TxnHandle),
}
63
64pub async fn handle(
65 Extension(state): Extension<Arc<ServerState>>,
66 Extension(ctx): Extension<RequestContext>,
67 Json(request): Json<SqlRequest>,
68) -> Response {
69 if request.sql.trim().is_empty() {
70 return error_response(
71 ServerError::BadRequest("sql must not be empty".into()),
72 &ctx,
73 );
74 }
75
76 if request.streaming {
77 return stream_response(state, request, &ctx);
78 }
79
80 let result = execute_non_streaming(state.clone(), &request, &ctx).await;
81 match result {
82 Ok(response) => json_response(response, state.config.max_response_size, &ctx),
83 Err(err) => error_response(err, &ctx),
84 }
85}
86
87async fn execute_non_streaming(
88 state: Arc<ServerState>,
89 request: &SqlRequest,
90 ctx: &RequestContext,
91) -> Result<SqlResponse> {
92 let start = Instant::now();
93 let sql = request.sql.as_str();
94 let is_ddl = is_ddl(sql);
95
96 let exec_result: Result<alopex_sql::executor::ExecutionResult> = async {
97 if let Some(session_id) = &request.session_id {
98 let session_id = session_id
99 .parse::<SessionId>()
100 .map_err(|_| ServerError::BadRequest("invalid session_id".into()))?;
101 let fut = state.session_manager.execute_in_session(&session_id, sql);
102 let result = tokio::time::timeout(state.config.query_timeout, fut)
103 .await
104 .map_err(|_| ServerError::Timeout("query timeout".into()))??;
105 Ok(result)
106 } else {
107 let mut txn = state.begin_sql_txn().await?;
108 let fut = tokio::time::timeout(state.config.query_timeout, txn.async_execute(sql))
109 .await
110 .map_err(|_| ServerError::Timeout("query timeout".into()))?;
111 match fut {
112 Ok(result) => {
113 txn.async_commit()
114 .await
115 .map_err(|err| ServerError::Sql(err.into()))?;
116 Ok(result)
117 }
118 Err(err) => {
119 let _ = txn.async_rollback().await;
120 Err(ServerError::Sql(err.into()))
121 }
122 }
123 }
124 }
125 .await;
126 let exec_result = match exec_result {
127 Ok(result) => result,
128 Err(err) => {
129 state.metrics.record_query(start.elapsed(), false);
130 return Err(err);
131 }
132 };
133
134 if state.config.audit_log_enabled && is_ddl {
135 state
136 .audit
137 .log_ddl(sql, ctx.actor.as_deref(), &ctx.correlation_id);
138 }
139
140 state.metrics.record_query(start.elapsed(), true);
141
142 Ok(map_execution_result(exec_result))
143}
144
145fn stream_response(state: Arc<ServerState>, request: SqlRequest, ctx: &RequestContext) -> Response {
146 let (sender, receiver) = mpsc::channel(32);
147 let sql = request.sql.clone();
148 let correlation_id = ctx.correlation_id.clone();
149 let max_response_size = state.config.max_response_size;
150 let timeout = state.config.query_timeout;
151 let metrics = state.metrics.clone();
152 let mut audit = None;
153 if state.config.audit_log_enabled && is_ddl(&sql) {
154 audit = Some(state.audit.clone());
155 }
156
157 let session_id = request.session_id.clone();
158 let state_clone = state.clone();
159 tokio::spawn(async move {
160 let start = Instant::now();
161 let mut bytes_sent = 0usize;
162 let mut success = true;
163 let mut source = match session_id {
164 Some(id) => {
165 let parsed = match id.parse::<SessionId>() {
166 Ok(id) => id,
167 Err(_) => {
168 let _ = sender
169 .send(stream_item_error(
170 ServerError::BadRequest("invalid session_id".into()),
171 &correlation_id,
172 ))
173 .await;
174 return;
175 }
176 };
177 match state_clone.session_manager.get_transaction(&parsed).await {
178 Ok(handle) => StreamSource::Handle(handle),
179 Err(err) => {
180 let _ = sender.send(stream_item_error(err, &correlation_id)).await;
181 return;
182 }
183 }
184 }
185 None => match state_clone.begin_sql_txn().await {
186 Ok(txn) => StreamSource::Txn(txn),
187 Err(err) => {
188 let _ = sender.send(stream_item_error(err, &correlation_id)).await;
189 return;
190 }
191 },
192 };
193
194 let mut stream = match &mut source {
195 StreamSource::Handle(handle) => handle.query(&sql),
196 StreamSource::Txn(txn) => txn.async_query(&sql),
197 };
198 let deadline = start + timeout;
199 loop {
200 let remaining = deadline.saturating_duration_since(Instant::now());
201 if remaining.is_zero() {
202 let _ = sender
203 .send(stream_item_error(
204 ServerError::Timeout("query timeout".into()),
205 &correlation_id,
206 ))
207 .await;
208 success = false;
209 break;
210 }
211
212 tokio::select! {
213 _ = sender.closed() => {
214 success = false;
215 break;
216 }
217 item = tokio::time::timeout(remaining, stream.next()) => {
218 let next = match item {
219 Ok(value) => value,
220 Err(_) => {
221 let _ = sender
222 .send(stream_item_error(
223 ServerError::Timeout("query timeout".into()),
224 &correlation_id,
225 ))
226 .await;
227 success = false;
228 break;
229 }
230 };
231
232 match next {
233 Some(Ok(row)) => {
234 let item = StreamItem {
235 row: Some(row.values),
236 error: None,
237 done: false,
238 };
239 match serde_json::to_vec(&item) {
240 Ok(bytes) => {
241 bytes_sent += bytes.len();
242 if bytes_sent > max_response_size {
243 let _ = sender
244 .send(stream_item_error(
245 ServerError::PayloadTooLarge(
246 "response size exceeds limit".into(),
247 ),
248 &correlation_id,
249 ))
250 .await;
251 success = false;
252 break;
253 }
254 }
255 Err(err) => {
256 let _ = sender
257 .send(stream_item_error(
258 ServerError::Internal(err.to_string()),
259 &correlation_id,
260 ))
261 .await;
262 success = false;
263 break;
264 }
265 }
266 match sender.try_send(item) {
267 Ok(()) => {}
268 Err(mpsc::error::TrySendError::Full(item)) => {
269 metrics.record_backpressure();
270 if sender.send(item).await.is_err() {
271 success = false;
272 break;
273 }
274 }
275 Err(mpsc::error::TrySendError::Closed(_)) => {
276 success = false;
277 break;
278 }
279 }
280 }
281 Some(Err(err)) => {
282 let _ = sender
283 .send(stream_item_error(
284 ServerError::Sql(err.into()),
285 &correlation_id,
286 ))
287 .await;
288 success = false;
289 break;
290 }
291 None => break,
292 }
293 }
294 }
295 }
296
297 drop(stream);
298 if let StreamSource::Txn(txn) = source {
299 let _ = txn.async_rollback().await;
300 }
301 if let Some(logger) = audit {
302 logger.log_ddl(&sql, None, &correlation_id);
303 }
304 metrics.record_query(start.elapsed(), success);
305 let _ = sender
306 .send(StreamItem {
307 row: None,
308 error: None,
309 done: true,
310 })
311 .await;
312 });
313
314 let stream = ReceiverStream::new(receiver).map(|item| {
315 let json = serde_json::to_string(&item).unwrap_or_else(|_| "{}".to_string());
316 Ok::<axum::body::Bytes, Infallible>(axum::body::Bytes::from(json + "\n"))
317 });
318
319 let body = axum::body::boxed(axum::body::Body::wrap_stream(stream));
320 axum::response::Response::builder()
321 .status(axum::http::StatusCode::OK)
322 .header(axum::http::header::CONTENT_TYPE, "application/jsonl")
323 .body(body)
324 .unwrap_or_else(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response())
325}
326
327fn stream_item_error(err: ServerError, correlation_id: &str) -> StreamItem {
328 StreamItem {
329 row: None,
330 error: Some(StreamError {
331 code: err.error_code(),
332 message: err.to_string(),
333 correlation_id: correlation_id.to_string(),
334 }),
335 done: false,
336 }
337}
338
339fn map_execution_result(exec_result: alopex_sql::executor::ExecutionResult) -> SqlResponse {
340 match exec_result {
341 alopex_sql::executor::ExecutionResult::Query(query) => SqlResponse {
342 columns: query
343 .columns
344 .into_iter()
345 .map(|col| ColumnInfoResponse {
346 name: col.name,
347 data_type: type_to_string(&col.data_type),
348 })
349 .collect(),
350 rows: query.rows,
351 affected_rows: None,
352 },
353 alopex_sql::executor::ExecutionResult::RowsAffected(rows) => SqlResponse {
354 columns: Vec::new(),
355 rows: Vec::new(),
356 affected_rows: Some(rows),
357 },
358 alopex_sql::executor::ExecutionResult::Success => SqlResponse {
359 columns: Vec::new(),
360 rows: Vec::new(),
361 affected_rows: None,
362 },
363 }
364}
365
366fn type_to_string(data_type: &alopex_sql::planner::ResolvedType) -> String {
367 match data_type {
368 alopex_sql::planner::ResolvedType::Integer => "INTEGER".to_string(),
369 alopex_sql::planner::ResolvedType::BigInt => "BIGINT".to_string(),
370 alopex_sql::planner::ResolvedType::Float => "FLOAT".to_string(),
371 alopex_sql::planner::ResolvedType::Double => "DOUBLE".to_string(),
372 alopex_sql::planner::ResolvedType::Text => "TEXT".to_string(),
373 alopex_sql::planner::ResolvedType::Blob => "BLOB".to_string(),
374 alopex_sql::planner::ResolvedType::Boolean => "BOOLEAN".to_string(),
375 alopex_sql::planner::ResolvedType::Timestamp => "TIMESTAMP".to_string(),
376 alopex_sql::planner::ResolvedType::Vector { dimension, metric } => {
377 format!("VECTOR({dimension}, {metric:?})")
378 }
379 alopex_sql::planner::ResolvedType::Null => "NULL".to_string(),
380 }
381}
382
383fn is_ddl(sql: &str) -> bool {
384 let Ok(statements) = alopex_sql::parser::Parser::parse_sql(&AlopexDialect, sql) else {
385 return false;
386 };
387 statements.iter().any(|stmt| match &stmt.kind {
388 alopex_sql::ast::StatementKind::CreateTable(_)
389 | alopex_sql::ast::StatementKind::DropTable(_)
390 | alopex_sql::ast::StatementKind::CreateIndex(_)
391 | alopex_sql::ast::StatementKind::DropIndex(_) => true,
392 alopex_sql::ast::StatementKind::Select(_)
393 | alopex_sql::ast::StatementKind::Insert(_)
394 | alopex_sql::ast::StatementKind::Update(_)
395 | alopex_sql::ast::StatementKind::Delete(_) => false,
396 })
397}