1use std::convert::Infallible;
2use std::sync::Arc;
3use std::time::Instant;
4
5use alopex_core::kv::async_adapter::AsyncKVTransactionAdapter;
6use alopex_core::kv::{KVStore, KVTransaction};
7use alopex_core::types::TxnMode;
8use alopex_sql::storage::async_storage::AsyncTxnBridge;
9use alopex_sql::storage::AsyncSqlTransaction;
10use alopex_sql::AlopexDialect;
11use axum::extract::Extension;
12use axum::response::{IntoResponse, Response};
13use axum::Json;
14use bincode::Options;
15use futures::StreamExt;
16use serde::{Deserialize, Serialize};
17use tokio::sync::mpsc;
18use tokio_stream::wrappers::ReceiverStream;
19
20use crate::error::{Result, ServerError};
21use crate::http::{error_response, json_response, RequestContext};
22use crate::ops::memory::MemoryControlPolicy;
23use crate::server::ServerState;
24use crate::session::{SessionId, TxnHandle};
25use alopex_core::storage::format::bincode_config;
26use alopex_sql::catalog::persistent::{PersistedTableMeta, TABLES_PREFIX};
27
28#[derive(Debug, Deserialize)]
29pub struct SqlRequest {
30 pub sql: String,
31 pub session_id: Option<String>,
32 #[serde(default)]
33 pub streaming: bool,
34}
35
36#[derive(Debug, Serialize)]
37pub struct ColumnInfoResponse {
38 pub name: String,
39 pub data_type: String,
40}
41
42#[derive(Debug, Serialize)]
43pub struct SqlResponse {
44 pub columns: Vec<ColumnInfoResponse>,
45 pub rows: Vec<Vec<alopex_sql::storage::SqlValue>>,
46 pub affected_rows: Option<u64>,
47}
48
49#[derive(Debug, Serialize)]
50struct StreamItem {
51 row: Option<Vec<alopex_sql::storage::SqlValue>>,
52 error: Option<StreamError>,
53 done: bool,
54}
55
56#[derive(Debug, Serialize)]
57struct StreamError {
58 code: String,
59 message: String,
60 correlation_id: String,
61}
62
63type AsyncTxn = AsyncTxnBridge<'static, AsyncKVTransactionAdapter>;
64
65enum StreamSource {
66 Txn(AsyncTxn),
67 Handle(TxnHandle),
68}
69
70pub async fn handle(
71 Extension(state): Extension<Arc<ServerState>>,
72 Extension(ctx): Extension<RequestContext>,
73 Json(request): Json<SqlRequest>,
74) -> Response {
75 if request.sql.trim().is_empty() {
76 return error_response(
77 ServerError::BadRequest("sql must not be empty".into()),
78 &ctx,
79 );
80 }
81
82 if request.streaming {
83 return stream_response(state, request, &ctx);
84 }
85
86 let result = execute_non_streaming(state.clone(), &request, &ctx).await;
87 match result {
88 Ok(response) => json_response(response, state.config.max_response_size, &ctx),
89 Err(err) => error_response(err, &ctx),
90 }
91}
92
93async fn execute_non_streaming(
94 state: Arc<ServerState>,
95 request: &SqlRequest,
96 ctx: &RequestContext,
97) -> Result<SqlResponse> {
98 let start = Instant::now();
99 let sql = request.sql.as_str();
100 let is_ddl = is_ddl(sql);
101 if is_write_sql(sql) {
102 state.lifecycle_state.check_write_allowed()?;
103 }
104
105 let exec_result: Result<alopex_sql::executor::ExecutionResult> = async {
106 if let Some(session_id) = &request.session_id {
107 let session_id = session_id
108 .parse::<SessionId>()
109 .map_err(|_| ServerError::BadRequest("invalid session_id".into()))?;
110 let fut = state.session_manager.execute_in_session(&session_id, sql);
111 let result = tokio::time::timeout(state.config.query_timeout, fut)
112 .await
113 .map_err(|_| ServerError::Timeout("query timeout".into()))??;
114 Ok(result)
115 } else {
116 let mut txn = state.begin_sql_txn().await?;
117 let fut = tokio::time::timeout(state.config.query_timeout, txn.async_execute(sql))
118 .await
119 .map_err(|_| ServerError::Timeout("query timeout".into()))?;
120 match fut {
121 Ok(result) => {
122 txn.async_commit()
123 .await
124 .map_err(|err| ServerError::Sql(err.into()))?;
125 Ok(result)
126 }
127 Err(err) => {
128 let _ = txn.async_rollback().await;
129 Err(ServerError::Sql(err.into()))
130 }
131 }
132 }
133 }
134 .await;
135 let exec_result = match exec_result {
136 Ok(result) => result,
137 Err(err) => {
138 state.metrics.record_query(start.elapsed(), false);
139 return Err(err);
140 }
141 };
142
143 if state.config.audit_log_enabled && is_ddl {
144 state
145 .audit
146 .log_ddl(sql, ctx.actor.as_deref(), &ctx.correlation_id);
147 }
148
149 if is_ddl {
150 sync_catalog_to_store(&state)?;
151 }
152
153 state.metrics.record_query(start.elapsed(), true);
154
155 Ok(map_execution_result(exec_result))
156}
157
158fn sync_catalog_to_store(state: &ServerState) -> Result<()> {
159 let guard = state
160 .catalog
161 .read()
162 .map_err(|_| ServerError::Internal("catalog lock poisoned".into()))?;
163 let tables = guard.list_tables();
164 let mut txn = state.store.begin(TxnMode::ReadWrite)?;
165 delete_prefix(&mut txn, TABLES_PREFIX)?;
166 for table in tables {
167 let persisted = PersistedTableMeta::from(&table);
168 let value = bincode_config()
169 .serialize(&persisted)
170 .map_err(|err| ServerError::Internal(err.to_string()))?;
171 txn.put(
172 table_key(&table.catalog_name, &table.namespace_name, &table.name),
173 value,
174 )?;
175 }
176 txn.commit_self()?;
177 Ok(())
178}
179
180fn delete_prefix<'a, T: KVTransaction<'a>>(txn: &mut T, prefix: &[u8]) -> Result<()> {
181 let mut keys = Vec::new();
182 for (key, _) in txn.scan_prefix(prefix)? {
183 keys.push(key);
184 }
185 for key in keys {
186 txn.delete(key)?;
187 }
188 Ok(())
189}
190
191fn table_key(catalog_name: &str, namespace_name: &str, table_name: &str) -> Vec<u8> {
192 let mut key = TABLES_PREFIX.to_vec();
193 key.extend_from_slice(catalog_name.as_bytes());
194 key.push(b'/');
195 key.extend_from_slice(namespace_name.as_bytes());
196 key.push(b'/');
197 key.extend_from_slice(table_name.as_bytes());
198 key
199}
200
201fn stream_response(state: Arc<ServerState>, request: SqlRequest, ctx: &RequestContext) -> Response {
202 if is_write_sql(&request.sql) {
203 if let Err(err) = state.lifecycle_state.check_write_allowed() {
204 return error_response(err, ctx);
205 }
206 }
207 let (sender, receiver) = mpsc::channel(32);
208 let sql = request.sql.clone();
209 let correlation_id = ctx.correlation_id.clone();
210 let max_response_size = state.config.max_response_size;
211 let timeout = state.config.query_timeout;
212 let memory_policy = MemoryControlPolicy::from_env();
213 let metrics = state.metrics.clone();
214 let mut audit = None;
215 if state.config.audit_log_enabled && is_ddl(&sql) {
216 audit = Some(state.audit.clone());
217 }
218
219 let session_id = request.session_id.clone();
220 let state_clone = state.clone();
221 let memory_policy = memory_policy.clone();
222 tokio::spawn(async move {
223 let start = Instant::now();
224 let mut bytes_sent = 0usize;
225 let mut success = true;
226 let mut source = match session_id {
227 Some(id) => {
228 let parsed = match id.parse::<SessionId>() {
229 Ok(id) => id,
230 Err(_) => {
231 let _ = sender
232 .send(stream_item_error(
233 ServerError::BadRequest("invalid session_id".into()),
234 &correlation_id,
235 ))
236 .await;
237 return;
238 }
239 };
240 match state_clone.session_manager.get_transaction(&parsed).await {
241 Ok(handle) => StreamSource::Handle(handle),
242 Err(err) => {
243 let _ = sender.send(stream_item_error(err, &correlation_id)).await;
244 return;
245 }
246 }
247 }
248 None => match state_clone.begin_sql_txn().await {
249 Ok(txn) => StreamSource::Txn(txn),
250 Err(err) => {
251 let _ = sender.send(stream_item_error(err, &correlation_id)).await;
252 return;
253 }
254 },
255 };
256
257 let mut stream = match &mut source {
258 StreamSource::Handle(handle) => handle.query(&sql),
259 StreamSource::Txn(txn) => txn.async_query(&sql),
260 };
261 let deadline = start + timeout;
262 loop {
263 let remaining = deadline.saturating_duration_since(Instant::now());
264 if remaining.is_zero() {
265 let _ = sender
266 .send(stream_item_error(
267 ServerError::Timeout("query timeout".into()),
268 &correlation_id,
269 ))
270 .await;
271 success = false;
272 break;
273 }
274
275 tokio::select! {
276 _ = sender.closed() => {
277 success = false;
278 break;
279 }
280 item = tokio::time::timeout(remaining, stream.next()) => {
281 let next = match item {
282 Ok(value) => value,
283 Err(_) => {
284 let _ = sender
285 .send(stream_item_error(
286 ServerError::Timeout("query timeout".into()),
287 &correlation_id,
288 ))
289 .await;
290 success = false;
291 break;
292 }
293 };
294
295 match next {
296 Some(Ok(row)) => {
297 let item = StreamItem {
298 row: Some(row.values),
299 error: None,
300 done: false,
301 };
302 match serde_json::to_vec(&item) {
303 Ok(bytes) => {
304 bytes_sent += bytes.len();
305 if let Err(err) =
306 memory_policy.enforce_output_bytes(bytes_sent as u64)
307 {
308 let _ = sender
309 .send(stream_item_error(err, &correlation_id))
310 .await;
311 success = false;
312 break;
313 }
314 if bytes_sent > max_response_size {
315 let _ = sender
316 .send(stream_item_error(
317 ServerError::PayloadTooLarge(
318 "response size exceeds limit".into(),
319 ),
320 &correlation_id,
321 ))
322 .await;
323 success = false;
324 break;
325 }
326 }
327 Err(err) => {
328 let _ = sender
329 .send(stream_item_error(
330 ServerError::Internal(err.to_string()),
331 &correlation_id,
332 ))
333 .await;
334 success = false;
335 break;
336 }
337 }
338 match sender.try_send(item) {
339 Ok(()) => {}
340 Err(mpsc::error::TrySendError::Full(item)) => {
341 metrics.record_backpressure();
342 if sender.send(item).await.is_err() {
343 success = false;
344 break;
345 }
346 }
347 Err(mpsc::error::TrySendError::Closed(_)) => {
348 success = false;
349 break;
350 }
351 }
352 }
353 Some(Err(err)) => {
354 let _ = sender
355 .send(stream_item_error(
356 ServerError::Sql(err.into()),
357 &correlation_id,
358 ))
359 .await;
360 success = false;
361 break;
362 }
363 None => break,
364 }
365 }
366 }
367 }
368
369 drop(stream);
370 if let StreamSource::Txn(txn) = source {
371 let _ = txn.async_rollback().await;
372 }
373 if let Some(logger) = audit {
374 logger.log_ddl(&sql, None, &correlation_id);
375 }
376 metrics.record_query(start.elapsed(), success);
377 let _ = sender
378 .send(StreamItem {
379 row: None,
380 error: None,
381 done: true,
382 })
383 .await;
384 });
385
386 let stream = ReceiverStream::new(receiver).map(|item| {
387 let json = serde_json::to_string(&item).unwrap_or_else(|_| "{}".to_string());
388 Ok::<axum::body::Bytes, Infallible>(axum::body::Bytes::from(json + "\n"))
389 });
390
391 let body = axum::body::boxed(axum::body::Body::wrap_stream(stream));
392 axum::response::Response::builder()
393 .status(axum::http::StatusCode::OK)
394 .header(axum::http::header::CONTENT_TYPE, "application/jsonl")
395 .body(body)
396 .unwrap_or_else(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response())
397}
398
399fn stream_item_error(err: ServerError, correlation_id: &str) -> StreamItem {
400 StreamItem {
401 row: None,
402 error: Some(StreamError {
403 code: err.error_code(),
404 message: err.to_string(),
405 correlation_id: correlation_id.to_string(),
406 }),
407 done: false,
408 }
409}
410
411fn map_execution_result(exec_result: alopex_sql::executor::ExecutionResult) -> SqlResponse {
412 match exec_result {
413 alopex_sql::executor::ExecutionResult::Query(query) => SqlResponse {
414 columns: query
415 .columns
416 .into_iter()
417 .map(|col| ColumnInfoResponse {
418 name: col.name,
419 data_type: type_to_string(&col.data_type),
420 })
421 .collect(),
422 rows: query.rows,
423 affected_rows: None,
424 },
425 alopex_sql::executor::ExecutionResult::RowsAffected(rows) => SqlResponse {
426 columns: Vec::new(),
427 rows: Vec::new(),
428 affected_rows: Some(rows),
429 },
430 alopex_sql::executor::ExecutionResult::Success => SqlResponse {
431 columns: Vec::new(),
432 rows: Vec::new(),
433 affected_rows: None,
434 },
435 }
436}
437
438fn type_to_string(data_type: &alopex_sql::planner::ResolvedType) -> String {
439 match data_type {
440 alopex_sql::planner::ResolvedType::Integer => "INTEGER".to_string(),
441 alopex_sql::planner::ResolvedType::BigInt => "BIGINT".to_string(),
442 alopex_sql::planner::ResolvedType::Float => "FLOAT".to_string(),
443 alopex_sql::planner::ResolvedType::Double => "DOUBLE".to_string(),
444 alopex_sql::planner::ResolvedType::Text => "TEXT".to_string(),
445 alopex_sql::planner::ResolvedType::Blob => "BLOB".to_string(),
446 alopex_sql::planner::ResolvedType::Boolean => "BOOLEAN".to_string(),
447 alopex_sql::planner::ResolvedType::Timestamp => "TIMESTAMP".to_string(),
448 alopex_sql::planner::ResolvedType::Vector { dimension, metric } => {
449 format!("VECTOR({dimension}, {metric:?})")
450 }
451 alopex_sql::planner::ResolvedType::Null => "NULL".to_string(),
452 }
453}
454
455fn is_ddl(sql: &str) -> bool {
456 let Ok(statements) = alopex_sql::parser::Parser::parse_sql(&AlopexDialect, sql) else {
457 return false;
458 };
459 statements.iter().any(|stmt| match &stmt.kind {
460 alopex_sql::ast::StatementKind::CreateTable(_)
461 | alopex_sql::ast::StatementKind::DropTable(_)
462 | alopex_sql::ast::StatementKind::CreateIndex(_)
463 | alopex_sql::ast::StatementKind::DropIndex(_) => true,
464 alopex_sql::ast::StatementKind::Select(_)
465 | alopex_sql::ast::StatementKind::Insert(_)
466 | alopex_sql::ast::StatementKind::Update(_)
467 | alopex_sql::ast::StatementKind::Delete(_) => false,
468 })
469}
470
471fn is_write_sql(sql: &str) -> bool {
472 let Ok(statements) = alopex_sql::parser::Parser::parse_sql(&AlopexDialect, sql) else {
473 return false;
474 };
475 statements
476 .iter()
477 .any(|stmt| !matches!(stmt.kind, alopex_sql::ast::StatementKind::Select(_)))
478}