alopex_server/http/
sql.rs

1use std::convert::Infallible;
2use std::sync::Arc;
3use std::time::Instant;
4
5use alopex_core::kv::async_adapter::AsyncKVTransactionAdapter;
6use alopex_sql::storage::async_storage::AsyncTxnBridge;
7use alopex_sql::storage::AsyncSqlTransaction;
8use alopex_sql::AlopexDialect;
9use axum::extract::Extension;
10use axum::response::{IntoResponse, Response};
11use axum::Json;
12use futures::StreamExt;
13use serde::{Deserialize, Serialize};
14use tokio::sync::mpsc;
15use tokio_stream::wrappers::ReceiverStream;
16
17use crate::error::{Result, ServerError};
18use crate::http::{error_response, json_response, RequestContext};
19use crate::server::ServerState;
20use crate::session::{SessionId, TxnHandle};
21
/// JSON body accepted by [`handle`].
#[derive(Debug, Deserialize)]
pub struct SqlRequest {
    /// SQL text to execute; [`handle`] rejects it with `BadRequest` when blank.
    pub sql: String,
    /// Optional session to run in. When present it must parse as a `SessionId`
    /// (otherwise the request fails with `BadRequest`) and the SQL runs through
    /// the session manager instead of a one-off transaction.
    pub session_id: Option<String>,
    /// When `true`, rows are streamed back as JSON lines instead of one buffered
    /// response. Defaults to `false` when omitted from the body.
    #[serde(default)]
    pub streaming: bool,
}
29
/// Name and type of one result column in a [`SqlResponse`].
#[derive(Debug, Serialize)]
pub struct ColumnInfoResponse {
    /// Column name as reported by the executor.
    pub name: String,
    /// Upper-case type label produced by `type_to_string`, e.g. `INTEGER` or
    /// `VECTOR(<dimension>, <metric>)`.
    pub data_type: String,
}
35
/// Buffered (non-streaming) result of a SQL request.
///
/// Query results fill `columns` and `rows` and leave `affected_rows` as `None`.
/// Row-count results set only `affected_rows`. Other successful statements leave
/// both vectors empty and `affected_rows` as `None`.
#[derive(Debug, Serialize)]
pub struct SqlResponse {
    pub columns: Vec<ColumnInfoResponse>,
    pub rows: Vec<Vec<alopex_sql::storage::SqlValue>>,
    pub affected_rows: Option<u64>,
}
42
/// One line of a streaming (`application/jsonl`) SQL response.
///
/// Row lines carry `row` and failure lines carry `error`; both have `done: false`.
/// Completion is signalled by an item with `done: true` and neither `row` nor
/// `error` set.
#[derive(Debug, Serialize)]
struct StreamItem {
    row: Option<Vec<alopex_sql::storage::SqlValue>>,
    error: Option<StreamError>,
    done: bool,
}
49
/// Error payload embedded in a [`StreamItem`].
#[derive(Debug, Serialize)]
struct StreamError {
    /// Machine-readable code taken from `ServerError::error_code`.
    code: String,
    /// Human-readable message (the error's `Display` output).
    message: String,
    /// Correlation id of the request that started the stream.
    correlation_id: String,
}
56
/// Transaction type returned by `ServerState::begin_sql_txn`.
type AsyncTxn = AsyncTxnBridge<'static, AsyncKVTransactionAdapter>;

/// Where a streaming query obtains its rows.
enum StreamSource {
    /// Transaction opened only for this stream; it is rolled back once streaming ends.
    Txn(AsyncTxn),
    /// Transaction handle obtained from an existing session via the session manager.
    Handle(TxnHandle),
}
63
64pub async fn handle(
65    Extension(state): Extension<Arc<ServerState>>,
66    Extension(ctx): Extension<RequestContext>,
67    Json(request): Json<SqlRequest>,
68) -> Response {
69    if request.sql.trim().is_empty() {
70        return error_response(
71            ServerError::BadRequest("sql must not be empty".into()),
72            &ctx,
73        );
74    }
75
76    if request.streaming {
77        return stream_response(state, request, &ctx);
78    }
79
80    let result = execute_non_streaming(state.clone(), &request, &ctx).await;
81    match result {
82        Ok(response) => json_response(response, state.config.max_response_size, &ctx),
83        Err(err) => error_response(err, &ctx),
84    }
85}
86
87async fn execute_non_streaming(
88    state: Arc<ServerState>,
89    request: &SqlRequest,
90    ctx: &RequestContext,
91) -> Result<SqlResponse> {
92    let start = Instant::now();
93    let sql = request.sql.as_str();
94    let is_ddl = is_ddl(sql);
95
96    let exec_result: Result<alopex_sql::executor::ExecutionResult> = async {
97        if let Some(session_id) = &request.session_id {
98            let session_id = session_id
99                .parse::<SessionId>()
100                .map_err(|_| ServerError::BadRequest("invalid session_id".into()))?;
101            let fut = state.session_manager.execute_in_session(&session_id, sql);
102            let result = tokio::time::timeout(state.config.query_timeout, fut)
103                .await
104                .map_err(|_| ServerError::Timeout("query timeout".into()))??;
105            Ok(result)
106        } else {
107            let mut txn = state.begin_sql_txn().await?;
108            let fut = tokio::time::timeout(state.config.query_timeout, txn.async_execute(sql))
109                .await
110                .map_err(|_| ServerError::Timeout("query timeout".into()))?;
111            match fut {
112                Ok(result) => {
113                    txn.async_commit()
114                        .await
115                        .map_err(|err| ServerError::Sql(err.into()))?;
116                    Ok(result)
117                }
118                Err(err) => {
119                    let _ = txn.async_rollback().await;
120                    Err(ServerError::Sql(err.into()))
121                }
122            }
123        }
124    }
125    .await;
126    let exec_result = match exec_result {
127        Ok(result) => result,
128        Err(err) => {
129            state.metrics.record_query(start.elapsed(), false);
130            return Err(err);
131        }
132    };
133
134    if state.config.audit_log_enabled && is_ddl {
135        state
136            .audit
137            .log_ddl(sql, ctx.actor.as_deref(), &ctx.correlation_id);
138    }
139
140    state.metrics.record_query(start.elapsed(), true);
141
142    Ok(map_execution_result(exec_result))
143}
144
145fn stream_response(state: Arc<ServerState>, request: SqlRequest, ctx: &RequestContext) -> Response {
146    let (sender, receiver) = mpsc::channel(32);
147    let sql = request.sql.clone();
148    let correlation_id = ctx.correlation_id.clone();
149    let max_response_size = state.config.max_response_size;
150    let timeout = state.config.query_timeout;
151    let metrics = state.metrics.clone();
152    let mut audit = None;
153    if state.config.audit_log_enabled && is_ddl(&sql) {
154        audit = Some(state.audit.clone());
155    }
156
157    let session_id = request.session_id.clone();
158    let state_clone = state.clone();
159    tokio::spawn(async move {
160        let start = Instant::now();
161        let mut bytes_sent = 0usize;
162        let mut success = true;
163        let mut source = match session_id {
164            Some(id) => {
165                let parsed = match id.parse::<SessionId>() {
166                    Ok(id) => id,
167                    Err(_) => {
168                        let _ = sender
169                            .send(stream_item_error(
170                                ServerError::BadRequest("invalid session_id".into()),
171                                &correlation_id,
172                            ))
173                            .await;
174                        return;
175                    }
176                };
177                match state_clone.session_manager.get_transaction(&parsed).await {
178                    Ok(handle) => StreamSource::Handle(handle),
179                    Err(err) => {
180                        let _ = sender.send(stream_item_error(err, &correlation_id)).await;
181                        return;
182                    }
183                }
184            }
185            None => match state_clone.begin_sql_txn().await {
186                Ok(txn) => StreamSource::Txn(txn),
187                Err(err) => {
188                    let _ = sender.send(stream_item_error(err, &correlation_id)).await;
189                    return;
190                }
191            },
192        };
193
194        let mut stream = match &mut source {
195            StreamSource::Handle(handle) => handle.query(&sql),
196            StreamSource::Txn(txn) => txn.async_query(&sql),
197        };
198        let deadline = start + timeout;
199        loop {
200            let remaining = deadline.saturating_duration_since(Instant::now());
201            if remaining.is_zero() {
202                let _ = sender
203                    .send(stream_item_error(
204                        ServerError::Timeout("query timeout".into()),
205                        &correlation_id,
206                    ))
207                    .await;
208                success = false;
209                break;
210            }
211
212            tokio::select! {
213                _ = sender.closed() => {
214                    success = false;
215                    break;
216                }
217                item = tokio::time::timeout(remaining, stream.next()) => {
218                    let next = match item {
219                        Ok(value) => value,
220                        Err(_) => {
221                            let _ = sender
222                                .send(stream_item_error(
223                                    ServerError::Timeout("query timeout".into()),
224                                    &correlation_id,
225                                ))
226                                .await;
227                            success = false;
228                            break;
229                        }
230                    };
231
232                    match next {
233                        Some(Ok(row)) => {
234                            let item = StreamItem {
235                                row: Some(row.values),
236                                error: None,
237                                done: false,
238                            };
239                            match serde_json::to_vec(&item) {
240                                Ok(bytes) => {
241                                    bytes_sent += bytes.len();
242                                    if bytes_sent > max_response_size {
243                                        let _ = sender
244                                            .send(stream_item_error(
245                                                ServerError::PayloadTooLarge(
246                                                    "response size exceeds limit".into(),
247                                                ),
248                                                &correlation_id,
249                                            ))
250                                            .await;
251                                        success = false;
252                                        break;
253                                    }
254                                }
255                                Err(err) => {
256                                    let _ = sender
257                                        .send(stream_item_error(
258                                            ServerError::Internal(err.to_string()),
259                                            &correlation_id,
260                                        ))
261                                        .await;
262                                    success = false;
263                                    break;
264                                }
265                            }
266                            match sender.try_send(item) {
267                                Ok(()) => {}
268                                Err(mpsc::error::TrySendError::Full(item)) => {
269                                    metrics.record_backpressure();
270                                    if sender.send(item).await.is_err() {
271                                        success = false;
272                                        break;
273                                    }
274                                }
275                                Err(mpsc::error::TrySendError::Closed(_)) => {
276                                    success = false;
277                                    break;
278                                }
279                            }
280                        }
281                        Some(Err(err)) => {
282                            let _ = sender
283                                .send(stream_item_error(
284                                    ServerError::Sql(err.into()),
285                                    &correlation_id,
286                                ))
287                                .await;
288                            success = false;
289                            break;
290                        }
291                        None => break,
292                    }
293                }
294            }
295        }
296
297        drop(stream);
298        if let StreamSource::Txn(txn) = source {
299            let _ = txn.async_rollback().await;
300        }
301        if let Some(logger) = audit {
302            logger.log_ddl(&sql, None, &correlation_id);
303        }
304        metrics.record_query(start.elapsed(), success);
305        let _ = sender
306            .send(StreamItem {
307                row: None,
308                error: None,
309                done: true,
310            })
311            .await;
312    });
313
314    let stream = ReceiverStream::new(receiver).map(|item| {
315        let json = serde_json::to_string(&item).unwrap_or_else(|_| "{}".to_string());
316        Ok::<axum::body::Bytes, Infallible>(axum::body::Bytes::from(json + "\n"))
317    });
318
319    let body = axum::body::boxed(axum::body::Body::wrap_stream(stream));
320    axum::response::Response::builder()
321        .status(axum::http::StatusCode::OK)
322        .header(axum::http::header::CONTENT_TYPE, "application/jsonl")
323        .body(body)
324        .unwrap_or_else(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response())
325}
326
327fn stream_item_error(err: ServerError, correlation_id: &str) -> StreamItem {
328    StreamItem {
329        row: None,
330        error: Some(StreamError {
331            code: err.error_code(),
332            message: err.to_string(),
333            correlation_id: correlation_id.to_string(),
334        }),
335        done: false,
336    }
337}
338
339fn map_execution_result(exec_result: alopex_sql::executor::ExecutionResult) -> SqlResponse {
340    match exec_result {
341        alopex_sql::executor::ExecutionResult::Query(query) => SqlResponse {
342            columns: query
343                .columns
344                .into_iter()
345                .map(|col| ColumnInfoResponse {
346                    name: col.name,
347                    data_type: type_to_string(&col.data_type),
348                })
349                .collect(),
350            rows: query.rows,
351            affected_rows: None,
352        },
353        alopex_sql::executor::ExecutionResult::RowsAffected(rows) => SqlResponse {
354            columns: Vec::new(),
355            rows: Vec::new(),
356            affected_rows: Some(rows),
357        },
358        alopex_sql::executor::ExecutionResult::Success => SqlResponse {
359            columns: Vec::new(),
360            rows: Vec::new(),
361            affected_rows: None,
362        },
363    }
364}
365
366fn type_to_string(data_type: &alopex_sql::planner::ResolvedType) -> String {
367    match data_type {
368        alopex_sql::planner::ResolvedType::Integer => "INTEGER".to_string(),
369        alopex_sql::planner::ResolvedType::BigInt => "BIGINT".to_string(),
370        alopex_sql::planner::ResolvedType::Float => "FLOAT".to_string(),
371        alopex_sql::planner::ResolvedType::Double => "DOUBLE".to_string(),
372        alopex_sql::planner::ResolvedType::Text => "TEXT".to_string(),
373        alopex_sql::planner::ResolvedType::Blob => "BLOB".to_string(),
374        alopex_sql::planner::ResolvedType::Boolean => "BOOLEAN".to_string(),
375        alopex_sql::planner::ResolvedType::Timestamp => "TIMESTAMP".to_string(),
376        alopex_sql::planner::ResolvedType::Vector { dimension, metric } => {
377            format!("VECTOR({dimension}, {metric:?})")
378        }
379        alopex_sql::planner::ResolvedType::Null => "NULL".to_string(),
380    }
381}
382
383fn is_ddl(sql: &str) -> bool {
384    let Ok(statements) = alopex_sql::parser::Parser::parse_sql(&AlopexDialect, sql) else {
385        return false;
386    };
387    statements.iter().any(|stmt| match &stmt.kind {
388        alopex_sql::ast::StatementKind::CreateTable(_)
389        | alopex_sql::ast::StatementKind::DropTable(_)
390        | alopex_sql::ast::StatementKind::CreateIndex(_)
391        | alopex_sql::ast::StatementKind::DropIndex(_) => true,
392        alopex_sql::ast::StatementKind::Select(_)
393        | alopex_sql::ast::StatementKind::Insert(_)
394        | alopex_sql::ast::StatementKind::Update(_)
395        | alopex_sql::ast::StatementKind::Delete(_) => false,
396    })
397}