Skip to main content

alopex_server/http/
sql.rs

1use std::convert::Infallible;
2use std::sync::Arc;
3use std::time::Instant;
4
5use alopex_core::kv::async_adapter::AsyncKVTransactionAdapter;
6use alopex_core::kv::{KVStore, KVTransaction};
7use alopex_core::types::TxnMode;
8use alopex_sql::storage::async_storage::AsyncTxnBridge;
9use alopex_sql::storage::AsyncSqlTransaction;
10use alopex_sql::AlopexDialect;
11use axum::extract::Extension;
12use axum::response::{IntoResponse, Response};
13use axum::Json;
14use bincode::Options;
15use futures::StreamExt;
16use serde::{Deserialize, Serialize};
17use tokio::sync::mpsc;
18use tokio_stream::wrappers::ReceiverStream;
19
20use crate::error::{Result, ServerError};
21use crate::http::{error_response, json_response, RequestContext};
22use crate::ops::memory::MemoryControlPolicy;
23use crate::server::ServerState;
24use crate::session::{SessionId, TxnHandle};
25use alopex_core::storage::format::bincode_config;
26use alopex_sql::catalog::persistent::{PersistedTableMeta, TABLES_PREFIX};
27
/// JSON body accepted by the SQL endpoint.
#[derive(Debug, Deserialize)]
pub struct SqlRequest {
    /// SQL text to execute; requests whose text is blank after trimming are rejected.
    pub sql: String,
    /// Optional session to execute in, parsed as a `SessionId`. When absent, the statement
    /// runs in a transaction obtained from `begin_sql_txn`.
    pub session_id: Option<String>,
    /// When `true`, results are streamed as JSON Lines instead of one JSON document.
    /// Defaults to `false` when omitted from the request body.
    #[serde(default)]
    pub streaming: bool,
}
35
/// Name and type of one result column in a non-streaming [`SqlResponse`].
#[derive(Debug, Serialize)]
pub struct ColumnInfoResponse {
    /// Column name as reported by the executor.
    pub name: String,
    /// SQL type name rendered by `type_to_string` (e.g. `INTEGER`, `TEXT`,
    /// `VECTOR(<dimension>, <metric>)`).
    pub data_type: String,
}
41
/// JSON document returned for non-streaming requests.
///
/// Queries fill `columns` and `rows`; statements that report a row count fill
/// `affected_rows`; any other successful statement leaves all three empty / `None`.
#[derive(Debug, Serialize)]
pub struct SqlResponse {
    /// Result-set column descriptions; empty for non-queries.
    pub columns: Vec<ColumnInfoResponse>,
    /// Result rows, each a list of column values; empty for non-queries.
    pub rows: Vec<Vec<alopex_sql::storage::SqlValue>>,
    /// Number of affected rows for statements that report one; `None` otherwise.
    pub affected_rows: Option<u64>,
}
48
/// One JSON Lines record of a streaming response.
///
/// Each record plays a single role: a result row (`row` set), an error (`error` set),
/// or the terminating marker (`done: true`).
#[derive(Debug, Serialize)]
struct StreamItem {
    /// Column values of one result row; `None` for error and marker records.
    row: Option<Vec<alopex_sql::storage::SqlValue>>,
    /// Failure details; `None` for row and marker records.
    error: Option<StreamError>,
    /// `true` only on the terminating marker record.
    done: bool,
}
55
/// Error payload carried by a [`StreamItem`].
#[derive(Debug, Serialize)]
struct StreamError {
    /// Error code taken from `ServerError::error_code`.
    code: String,
    /// The error's `Display` text.
    message: String,
    /// Correlation id of the request the error belongs to.
    correlation_id: String,
}
62
/// Async SQL transaction over the KV store, as returned by `ServerState::begin_sql_txn`
/// for requests that do not name a session.
type AsyncTxn = AsyncTxnBridge<'static, AsyncKVTransactionAdapter>;
64
/// Transaction a streamed query reads from.
enum StreamSource {
    /// Fresh transaction owned by the streaming task (no `session_id` in the request).
    Txn(AsyncTxn),
    /// Transaction of an existing session, fetched from the session manager.
    Handle(TxnHandle),
}
69
70pub async fn handle(
71    Extension(state): Extension<Arc<ServerState>>,
72    Extension(ctx): Extension<RequestContext>,
73    Json(request): Json<SqlRequest>,
74) -> Response {
75    if request.sql.trim().is_empty() {
76        return error_response(
77            ServerError::BadRequest("sql must not be empty".into()),
78            &ctx,
79        );
80    }
81
82    if request.streaming {
83        return stream_response(state, request, &ctx);
84    }
85
86    let result = execute_non_streaming(state.clone(), &request, &ctx).await;
87    match result {
88        Ok(response) => json_response(response, state.config.max_response_size, &ctx),
89        Err(err) => error_response(err, &ctx),
90    }
91}
92
93async fn execute_non_streaming(
94    state: Arc<ServerState>,
95    request: &SqlRequest,
96    ctx: &RequestContext,
97) -> Result<SqlResponse> {
98    let start = Instant::now();
99    let sql = request.sql.as_str();
100    let is_ddl = is_ddl(sql);
101    if is_write_sql(sql) {
102        state.lifecycle_state.check_write_allowed()?;
103    }
104
105    let exec_result: Result<alopex_sql::executor::ExecutionResult> = async {
106        if let Some(session_id) = &request.session_id {
107            let session_id = session_id
108                .parse::<SessionId>()
109                .map_err(|_| ServerError::BadRequest("invalid session_id".into()))?;
110            let fut = state.session_manager.execute_in_session(&session_id, sql);
111            let result = tokio::time::timeout(state.config.query_timeout, fut)
112                .await
113                .map_err(|_| ServerError::Timeout("query timeout".into()))??;
114            Ok(result)
115        } else {
116            let mut txn = state.begin_sql_txn().await?;
117            let fut = tokio::time::timeout(state.config.query_timeout, txn.async_execute(sql))
118                .await
119                .map_err(|_| ServerError::Timeout("query timeout".into()))?;
120            match fut {
121                Ok(result) => {
122                    txn.async_commit()
123                        .await
124                        .map_err(|err| ServerError::Sql(err.into()))?;
125                    Ok(result)
126                }
127                Err(err) => {
128                    let _ = txn.async_rollback().await;
129                    Err(ServerError::Sql(err.into()))
130                }
131            }
132        }
133    }
134    .await;
135    let exec_result = match exec_result {
136        Ok(result) => result,
137        Err(err) => {
138            state.metrics.record_query(start.elapsed(), false);
139            return Err(err);
140        }
141    };
142
143    if state.config.audit_log_enabled && is_ddl {
144        state
145            .audit
146            .log_ddl(sql, ctx.actor.as_deref(), &ctx.correlation_id);
147    }
148
149    if is_ddl {
150        sync_catalog_to_store(&state)?;
151    }
152
153    state.metrics.record_query(start.elapsed(), true);
154
155    Ok(map_execution_result(exec_result))
156}
157
158fn sync_catalog_to_store(state: &ServerState) -> Result<()> {
159    let guard = state
160        .catalog
161        .read()
162        .map_err(|_| ServerError::Internal("catalog lock poisoned".into()))?;
163    let tables = guard.list_tables();
164    let mut txn = state.store.begin(TxnMode::ReadWrite)?;
165    delete_prefix(&mut txn, TABLES_PREFIX)?;
166    for table in tables {
167        let persisted = PersistedTableMeta::from(&table);
168        let value = bincode_config()
169            .serialize(&persisted)
170            .map_err(|err| ServerError::Internal(err.to_string()))?;
171        txn.put(
172            table_key(&table.catalog_name, &table.namespace_name, &table.name),
173            value,
174        )?;
175    }
176    txn.commit_self()?;
177    Ok(())
178}
179
180fn delete_prefix<'a, T: KVTransaction<'a>>(txn: &mut T, prefix: &[u8]) -> Result<()> {
181    let mut keys = Vec::new();
182    for (key, _) in txn.scan_prefix(prefix)? {
183        keys.push(key);
184    }
185    for key in keys {
186        txn.delete(key)?;
187    }
188    Ok(())
189}
190
191fn table_key(catalog_name: &str, namespace_name: &str, table_name: &str) -> Vec<u8> {
192    let mut key = TABLES_PREFIX.to_vec();
193    key.extend_from_slice(catalog_name.as_bytes());
194    key.push(b'/');
195    key.extend_from_slice(namespace_name.as_bytes());
196    key.push(b'/');
197    key.extend_from_slice(table_name.as_bytes());
198    key
199}
200
201fn stream_response(state: Arc<ServerState>, request: SqlRequest, ctx: &RequestContext) -> Response {
202    if is_write_sql(&request.sql) {
203        if let Err(err) = state.lifecycle_state.check_write_allowed() {
204            return error_response(err, ctx);
205        }
206    }
207    let (sender, receiver) = mpsc::channel(32);
208    let sql = request.sql.clone();
209    let correlation_id = ctx.correlation_id.clone();
210    let max_response_size = state.config.max_response_size;
211    let timeout = state.config.query_timeout;
212    let memory_policy = MemoryControlPolicy::from_env();
213    let metrics = state.metrics.clone();
214    let mut audit = None;
215    if state.config.audit_log_enabled && is_ddl(&sql) {
216        audit = Some(state.audit.clone());
217    }
218
219    let session_id = request.session_id.clone();
220    let state_clone = state.clone();
221    let memory_policy = memory_policy.clone();
222    tokio::spawn(async move {
223        let start = Instant::now();
224        let mut bytes_sent = 0usize;
225        let mut success = true;
226        let mut source = match session_id {
227            Some(id) => {
228                let parsed = match id.parse::<SessionId>() {
229                    Ok(id) => id,
230                    Err(_) => {
231                        let _ = sender
232                            .send(stream_item_error(
233                                ServerError::BadRequest("invalid session_id".into()),
234                                &correlation_id,
235                            ))
236                            .await;
237                        return;
238                    }
239                };
240                match state_clone.session_manager.get_transaction(&parsed).await {
241                    Ok(handle) => StreamSource::Handle(handle),
242                    Err(err) => {
243                        let _ = sender.send(stream_item_error(err, &correlation_id)).await;
244                        return;
245                    }
246                }
247            }
248            None => match state_clone.begin_sql_txn().await {
249                Ok(txn) => StreamSource::Txn(txn),
250                Err(err) => {
251                    let _ = sender.send(stream_item_error(err, &correlation_id)).await;
252                    return;
253                }
254            },
255        };
256
257        let mut stream = match &mut source {
258            StreamSource::Handle(handle) => handle.query(&sql),
259            StreamSource::Txn(txn) => txn.async_query(&sql),
260        };
261        let deadline = start + timeout;
262        loop {
263            let remaining = deadline.saturating_duration_since(Instant::now());
264            if remaining.is_zero() {
265                let _ = sender
266                    .send(stream_item_error(
267                        ServerError::Timeout("query timeout".into()),
268                        &correlation_id,
269                    ))
270                    .await;
271                success = false;
272                break;
273            }
274
275            tokio::select! {
276                _ = sender.closed() => {
277                    success = false;
278                    break;
279                }
280                item = tokio::time::timeout(remaining, stream.next()) => {
281                    let next = match item {
282                        Ok(value) => value,
283                        Err(_) => {
284                            let _ = sender
285                                .send(stream_item_error(
286                                    ServerError::Timeout("query timeout".into()),
287                                    &correlation_id,
288                                ))
289                                .await;
290                            success = false;
291                            break;
292                        }
293                    };
294
295                    match next {
296                        Some(Ok(row)) => {
297                            let item = StreamItem {
298                                row: Some(row.values),
299                                error: None,
300                                done: false,
301                            };
302                            match serde_json::to_vec(&item) {
303                                Ok(bytes) => {
304                                    bytes_sent += bytes.len();
305                                    if let Err(err) =
306                                        memory_policy.enforce_output_bytes(bytes_sent as u64)
307                                    {
308                                        let _ = sender
309                                            .send(stream_item_error(err, &correlation_id))
310                                            .await;
311                                        success = false;
312                                        break;
313                                    }
314                                    if bytes_sent > max_response_size {
315                                        let _ = sender
316                                            .send(stream_item_error(
317                                                ServerError::PayloadTooLarge(
318                                                    "response size exceeds limit".into(),
319                                                ),
320                                                &correlation_id,
321                                            ))
322                                            .await;
323                                        success = false;
324                                        break;
325                                    }
326                                }
327                                Err(err) => {
328                                    let _ = sender
329                                        .send(stream_item_error(
330                                            ServerError::Internal(err.to_string()),
331                                            &correlation_id,
332                                        ))
333                                        .await;
334                                    success = false;
335                                    break;
336                                }
337                            }
338                            match sender.try_send(item) {
339                                Ok(()) => {}
340                                Err(mpsc::error::TrySendError::Full(item)) => {
341                                    metrics.record_backpressure();
342                                    if sender.send(item).await.is_err() {
343                                        success = false;
344                                        break;
345                                    }
346                                }
347                                Err(mpsc::error::TrySendError::Closed(_)) => {
348                                    success = false;
349                                    break;
350                                }
351                            }
352                        }
353                        Some(Err(err)) => {
354                            let _ = sender
355                                .send(stream_item_error(
356                                    ServerError::Sql(err.into()),
357                                    &correlation_id,
358                                ))
359                                .await;
360                            success = false;
361                            break;
362                        }
363                        None => break,
364                    }
365                }
366            }
367        }
368
369        drop(stream);
370        if let StreamSource::Txn(txn) = source {
371            let _ = txn.async_rollback().await;
372        }
373        if let Some(logger) = audit {
374            logger.log_ddl(&sql, None, &correlation_id);
375        }
376        metrics.record_query(start.elapsed(), success);
377        let _ = sender
378            .send(StreamItem {
379                row: None,
380                error: None,
381                done: true,
382            })
383            .await;
384    });
385
386    let stream = ReceiverStream::new(receiver).map(|item| {
387        let json = serde_json::to_string(&item).unwrap_or_else(|_| "{}".to_string());
388        Ok::<axum::body::Bytes, Infallible>(axum::body::Bytes::from(json + "\n"))
389    });
390
391    let body = axum::body::boxed(axum::body::Body::wrap_stream(stream));
392    axum::response::Response::builder()
393        .status(axum::http::StatusCode::OK)
394        .header(axum::http::header::CONTENT_TYPE, "application/jsonl")
395        .body(body)
396        .unwrap_or_else(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response())
397}
398
399fn stream_item_error(err: ServerError, correlation_id: &str) -> StreamItem {
400    StreamItem {
401        row: None,
402        error: Some(StreamError {
403            code: err.error_code(),
404            message: err.to_string(),
405            correlation_id: correlation_id.to_string(),
406        }),
407        done: false,
408    }
409}
410
411fn map_execution_result(exec_result: alopex_sql::executor::ExecutionResult) -> SqlResponse {
412    match exec_result {
413        alopex_sql::executor::ExecutionResult::Query(query) => SqlResponse {
414            columns: query
415                .columns
416                .into_iter()
417                .map(|col| ColumnInfoResponse {
418                    name: col.name,
419                    data_type: type_to_string(&col.data_type),
420                })
421                .collect(),
422            rows: query.rows,
423            affected_rows: None,
424        },
425        alopex_sql::executor::ExecutionResult::RowsAffected(rows) => SqlResponse {
426            columns: Vec::new(),
427            rows: Vec::new(),
428            affected_rows: Some(rows),
429        },
430        alopex_sql::executor::ExecutionResult::Success => SqlResponse {
431            columns: Vec::new(),
432            rows: Vec::new(),
433            affected_rows: None,
434        },
435    }
436}
437
438fn type_to_string(data_type: &alopex_sql::planner::ResolvedType) -> String {
439    match data_type {
440        alopex_sql::planner::ResolvedType::Integer => "INTEGER".to_string(),
441        alopex_sql::planner::ResolvedType::BigInt => "BIGINT".to_string(),
442        alopex_sql::planner::ResolvedType::Float => "FLOAT".to_string(),
443        alopex_sql::planner::ResolvedType::Double => "DOUBLE".to_string(),
444        alopex_sql::planner::ResolvedType::Text => "TEXT".to_string(),
445        alopex_sql::planner::ResolvedType::Blob => "BLOB".to_string(),
446        alopex_sql::planner::ResolvedType::Boolean => "BOOLEAN".to_string(),
447        alopex_sql::planner::ResolvedType::Timestamp => "TIMESTAMP".to_string(),
448        alopex_sql::planner::ResolvedType::Vector { dimension, metric } => {
449            format!("VECTOR({dimension}, {metric:?})")
450        }
451        alopex_sql::planner::ResolvedType::Null => "NULL".to_string(),
452    }
453}
454
455fn is_ddl(sql: &str) -> bool {
456    let Ok(statements) = alopex_sql::parser::Parser::parse_sql(&AlopexDialect, sql) else {
457        return false;
458    };
459    statements.iter().any(|stmt| match &stmt.kind {
460        alopex_sql::ast::StatementKind::CreateTable(_)
461        | alopex_sql::ast::StatementKind::DropTable(_)
462        | alopex_sql::ast::StatementKind::CreateIndex(_)
463        | alopex_sql::ast::StatementKind::DropIndex(_) => true,
464        alopex_sql::ast::StatementKind::Select(_)
465        | alopex_sql::ast::StatementKind::Insert(_)
466        | alopex_sql::ast::StatementKind::Update(_)
467        | alopex_sql::ast::StatementKind::Delete(_) => false,
468    })
469}
470
471fn is_write_sql(sql: &str) -> bool {
472    let Ok(statements) = alopex_sql::parser::Parser::parse_sql(&AlopexDialect, sql) else {
473        return false;
474    };
475    statements
476        .iter()
477        .any(|stmt| !matches!(stmt.kind, alopex_sql::ast::StatementKind::Select(_)))
478}