vldb_sqlite/sql_exec.rs

use crate::pb::sqlite_value::Kind as ProtoSqliteValueKind;
use crate::pb::{ExecuteBatchItem, SqliteValue as ProtoSqliteValue};
use arrow::array::{ArrayRef, BinaryBuilder, Float64Builder, Int64Builder, StringBuilder};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;
use rusqlite::types::{ToSql, Value as SqliteValue, ValueRef as SqliteValueRef};
use rusqlite::{Connection, Error as RusqliteError};
use serde::Serialize;
use serde_json::{Map as JsonMap, Number as JsonNumber, Value as JsonValue};
use std::fmt;
use std::fs::File;
use std::io::{self, Write};
use std::io::{Read, Seek, SeekFrom};
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

/// 默认 Arrow IPC chunk 大小,供流式查询在内存与下游消费之间折中。
/// Default Arrow IPC chunk size used to balance memory and downstream consumption.
pub const DEFAULT_IPC_CHUNK_BYTES: usize = 1024 * 1024;

/// 单批次物化的最大行数,避免单个 Arrow record batch 过大。
/// Maximum number of rows materialized per batch to avoid oversized Arrow record batches.
pub const STREAMING_BATCH_ROWS: usize = 1000;

/// 通用 SQL 执行结果。
/// Shared SQL execution result.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecuteScriptResult {
    /// 是否执行成功。
    /// Whether the execution succeeded.
    pub success: bool,
    /// 结果消息。
    /// Result message.
    pub message: String,
    /// 受影响行数。
    /// Number of affected rows.
    pub rows_changed: i64,
    /// 最近一次插入行 ID。
    /// Last inserted row id.
    pub last_insert_rowid: i64,
}

/// 通用批量执行结果。
/// Shared batch-execution result.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecuteBatchResult {
    /// 是否执行成功。
    /// Whether the batch execution succeeded.
    pub success: bool,
    /// 结果消息。
    /// Result message.
    pub message: String,
    /// 受影响行数。
    /// Number of affected rows.
    pub rows_changed: i64,
    /// 最近一次插入行 ID。
    /// Last inserted row id.
    pub last_insert_rowid: i64,
    /// 实际执行的语句次数。
    /// Number of statements executed.
    pub statements_executed: i64,
}

/// JSON 查询结果。
/// JSON query result.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct QueryJsonResult {
    /// JSON 行集字符串。
    /// JSON row-set string.
    pub json_data: String,
    /// 返回行数。
    /// Number of rows returned.
    pub row_count: u64,
}

/// Arrow IPC chunk 查询结果。
/// Arrow IPC chunk query result.
#[derive(Debug)]
#[allow(dead_code)]
pub struct QueryStreamResult {
    /// 临时文件后端与 chunk 索引。
    /// Temporary-file backend and chunk index metadata.
    storage: QueryStreamStorage,
    /// 返回行数。
    /// Number of rows returned.
    pub row_count: u64,
    /// chunk 数量。
    /// Number of emitted chunks.
    pub chunk_count: u64,
    /// 总字节数。
    /// Total byte size of all chunks.
    pub total_bytes: u64,
}

/// QueryStream 单个 chunk 的文件偏移与长度信息。
/// File offset and length metadata for a single QueryStream chunk.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct QueryStreamChunkDescriptor {
    /// chunk 在暂存文件中的起始偏移。
    /// Starting offset of the chunk in the spool file.
    offset: u64,
    /// chunk 的字节长度。
    /// Byte length of the chunk.
    len: u64,
}

/// QueryStream 暂存后端。
/// QueryStream spool backend.
#[derive(Debug)]
#[allow(dead_code)]
struct QueryStreamStorage {
    /// 暂存文件路径。
    /// Spool file path.
    file_path: PathBuf,
    /// chunk 偏移索引。
    /// Chunk offset index.
    chunks: Vec<QueryStreamChunkDescriptor>,
}

#[allow(dead_code)]
impl QueryStreamResult {
    /// 读取指定下标的 chunk 内容。
    /// Read the chunk content at the specified index.
    pub fn read_chunk(&self, index: usize) -> Result<Vec<u8>, SqlExecCoreError> {
        let descriptor = self.chunks_descriptor(index)?;
        let mut file = File::open(&self.storage.file_path).map_err(|error| {
            SqlExecCoreError::Internal(format!(
                "open query stream spool file failed: {error}"
            ))
        })?;
        file.seek(SeekFrom::Start(descriptor.offset)).map_err(|error| {
            SqlExecCoreError::Internal(format!(
                "seek query stream spool file failed: {error}"
            ))
        })?;
        let chunk_len = usize::try_from(descriptor.len).map_err(|_| {
            SqlExecCoreError::Internal(
                "query stream chunk length exceeds usize / QueryStream chunk 长度超过 usize"
                    .to_string(),
            )
        })?;
        let mut chunk = vec![0_u8; chunk_len];
        file.read_exact(&mut chunk).map_err(|error| {
            SqlExecCoreError::Internal(format!(
                "read query stream spool chunk failed: {error}"
            ))
        })?;
        Ok(chunk)
    }

    /// 返回指定下标的 chunk 描述信息。
    /// Return the chunk descriptor at the specified index.
    fn chunks_descriptor(&self, index: usize) -> Result<QueryStreamChunkDescriptor, SqlExecCoreError> {
        self.storage.chunks.get(index).copied().ok_or_else(|| {
            SqlExecCoreError::InvalidArgument(
                "chunk index out of bounds / chunk 下标越界".to_string(),
            )
        })
    }
}

impl Drop for QueryStreamResult {
    fn drop(&mut self) {
        let _ = std::fs::remove_file(&self.storage.file_path);
    }
}

/// QueryStream 执行过程中的共享统计信息。
/// Shared metrics produced during QueryStream execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QueryStreamMetrics {
    /// 返回行数。
    /// Number of rows returned.
    pub row_count: u64,
    /// chunk 数量。
    /// Number of emitted chunks.
    pub chunk_count: u64,
    /// 总字节数。
    /// Total emitted byte size.
    pub total_bytes: u64,
}

/// 通用 SQL 核心错误。
/// Shared SQL core error.
#[derive(Debug)]
pub enum SqlExecCoreError {
    /// 调用参数无效。
    /// Invalid caller input.
    InvalidArgument(String),
    /// SQLite 执行错误。
    /// SQLite execution error.
    Sqlite {
        /// 错误前缀。
        /// Error prefix.
        prefix: &'static str,
        /// 底层 SQLite 错误。
        /// Underlying SQLite error.
        error: RusqliteError,
    },
    /// 内部序列化或 Arrow 处理错误。
    /// Internal serialization or Arrow processing error.
    Internal(String),
}

impl fmt::Display for SqlExecCoreError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidArgument(message) => write!(f, "{message}"),
            Self::Sqlite { prefix, error } => write!(f, "{prefix}: {error}"),
            Self::Internal(message) => write!(f, "{message}"),
        }
    }
}

impl std::error::Error for SqlExecCoreError {}

/// 从 gRPC typed params 与 `params_json` 中解析最终的 SQLite 参数列表。
/// Parse the final SQLite parameter list from gRPC typed params and `params_json`.
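///
/// 下面是一个最小示意(假设性示例,标记 `ignore`,不参与编译;`ProtoSqliteValue` 的构造以生成的 protobuf 代码为准):
/// A minimal illustrative sketch (hypothetical, marked `ignore` so it is not compiled;
/// the exact construction of `ProtoSqliteValue` follows the generated protobuf code):
///
/// ```ignore
/// // Typed params take precedence; params_json must then be empty.
/// let typed = vec![ProtoSqliteValue {
///     kind: Some(ProtoSqliteValueKind::Int64Value(42)),
/// }];
/// assert_eq!(parse_request_params(&typed, "")?, vec![SqliteValue::Integer(42)]);
///
/// // Supplying both typed params and params_json is rejected.
/// assert!(parse_request_params(&typed, "[1]").is_err());
///
/// // With no typed params, the legacy params_json path is used.
/// assert_eq!(
///     parse_request_params(&[], "[1, \"a\"]")?,
///     vec![SqliteValue::Integer(1), SqliteValue::Text("a".to_string())],
/// );
/// ```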
pub fn parse_request_params(
    params: &[ProtoSqliteValue],
    params_json: &str,
) -> Result<Vec<SqliteValue>, SqlExecCoreError> {
    if !params.is_empty() {
        if !params_json.trim().is_empty() {
            return Err(SqlExecCoreError::InvalidArgument(
                "provide either flat params or params_json, but not both".to_string(),
            ));
        }

        return params
            .iter()
            .map(proto_param_to_sqlite_value)
            .collect::<Result<Vec<_>, _>>();
    }

    parse_legacy_params_json(params_json)
}

/// 从 gRPC 批量参数中解析最终批量 SQLite 参数列表。
/// Parse the final batch SQLite parameter list from gRPC batch params.
pub fn parse_batch_params(
    items: &[ExecuteBatchItem],
) -> Result<Vec<Vec<SqliteValue>>, SqlExecCoreError> {
    items
        .iter()
        .map(|item| parse_request_params(&item.params, ""))
        .collect()
}

/// 从 legacy `params_json` 字符串解析参数列表。
/// Parse parameter list from the legacy `params_json` string.
pub fn parse_legacy_params_json(params_json: &str) -> Result<Vec<SqliteValue>, SqlExecCoreError> {
    if params_json.trim().is_empty() {
        return Ok(Vec::new());
    }

    let params = serde_json::from_str::<JsonValue>(params_json).map_err(|err| {
        SqlExecCoreError::InvalidArgument(format!(
            "params_json must be a JSON array of scalar values: {err}"
        ))
    })?;

    let items = params.as_array().ok_or_else(|| {
        SqlExecCoreError::InvalidArgument(
            "params_json must be a JSON array of scalar values".to_string(),
        )
    })?;

    items
        .iter()
        .cloned()
        .map(json_param_to_sqlite_value)
        .collect()
}

/// 解析 typed 单值到 SQLite 值。
/// Convert a typed protobuf value to a SQLite value.
pub fn proto_param_to_sqlite_value(
    value: &ProtoSqliteValue,
) -> Result<SqliteValue, SqlExecCoreError> {
    match value.kind.as_ref() {
        Some(ProtoSqliteValueKind::Int64Value(value)) => Ok(SqliteValue::Integer(*value)),
        Some(ProtoSqliteValueKind::Float64Value(value)) => Ok(SqliteValue::Real(*value)),
        Some(ProtoSqliteValueKind::StringValue(value)) => Ok(SqliteValue::Text(value.clone())),
        Some(ProtoSqliteValueKind::BytesValue(value)) => Ok(SqliteValue::Blob(value.clone())),
        Some(ProtoSqliteValueKind::BoolValue(value)) => Ok(SqliteValue::Integer(i64::from(*value))),
        Some(ProtoSqliteValueKind::NullValue(_)) => Ok(SqliteValue::Null),
        None => Err(SqlExecCoreError::InvalidArgument(
            "SqliteValue.kind must be set for every bound parameter".to_string(),
        )),
    }
}

/// 解析 JSON 标量到 SQLite 值。
/// Convert a JSON scalar into a SQLite value.
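///
/// 行为示意(假设性示例,标记 `ignore`,不参与编译):
/// Behavior sketch (illustrative only, marked `ignore` so it is not compiled):
///
/// ```ignore
/// assert_eq!(json_param_to_sqlite_value(json!(null))?, SqliteValue::Null);
/// assert_eq!(json_param_to_sqlite_value(json!(true))?, SqliteValue::Integer(1));
/// assert_eq!(json_param_to_sqlite_value(json!(2.5))?, SqliteValue::Real(2.5));
/// assert!(json_param_to_sqlite_value(json!([1, 2])).is_err()); // arrays are rejected
/// ```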
pub fn json_param_to_sqlite_value(value: JsonValue) -> Result<SqliteValue, SqlExecCoreError> {
    match value {
        JsonValue::Null => Ok(SqliteValue::Null),
        JsonValue::Bool(value) => Ok(SqliteValue::Integer(i64::from(value))),
        JsonValue::Number(value) => {
            if let Some(value) = value.as_i64() {
                Ok(SqliteValue::Integer(value))
            } else if let Some(value) = value.as_u64() {
                Ok(SqliteValue::Integer(i64::try_from(value).map_err(|_| {
                    SqlExecCoreError::InvalidArgument(
                        "params_json contains an unsigned integer larger than SQLite signed 64-bit range"
                            .to_string(),
                    )
                })?))
            } else if let Some(value) = value.as_f64() {
                Ok(SqliteValue::Real(value))
            } else {
                Err(SqlExecCoreError::InvalidArgument(
                    "params_json contains an unsupported numeric value".to_string(),
                ))
            }
        }
        JsonValue::String(value) => Ok(SqliteValue::Text(value)),
        JsonValue::Array(_) | JsonValue::Object(_) => Err(SqlExecCoreError::InvalidArgument(
            "params_json only supports scalar JSON values (null, bool, number, string)".to_string(),
        )),
    }
}

/// 执行脚本或单条语句。
/// Execute a script or a single statement.
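///
/// 调用示意(假设性示例,标记 `ignore`,不参与编译):无参数时整段脚本走 `execute_batch`,允许多条语句;
/// 带参数时只接受单条语句。
/// Call sketch (illustrative only, marked `ignore` so it is not compiled): without bound values the
/// whole script goes through `execute_batch` and may contain multiple statements; with bound values
/// only a single statement is accepted.
///
/// ```ignore
/// // Script mode: no bound values, multiple statements allowed.
/// execute_script(&mut conn, "CREATE TABLE t(x); INSERT INTO t VALUES (1);", &[])?;
///
/// // Single-statement mode with bound values.
/// let result = execute_script(&mut conn, "INSERT INTO t VALUES (?1)", &[SqliteValue::Integer(2)])?;
/// assert_eq!(result.rows_changed, 1);
/// ```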
pub fn execute_script(
    conn: &mut Connection,
    sql: &str,
    bound_values: &[SqliteValue],
) -> Result<ExecuteScriptResult, SqlExecCoreError> {
    if sql.trim().is_empty() {
        return Err(SqlExecCoreError::InvalidArgument(
            "sql must not be empty".to_string(),
        ));
    }

    if bound_values.is_empty() {
        conn.execute_batch(sql).map_err(|error| SqlExecCoreError::Sqlite {
            prefix: "sqlite execute_batch failed",
            error,
        })?;

        return Ok(ExecuteScriptResult {
            success: true,
            message: "script executed successfully".to_string(),
            rows_changed: i64::try_from(conn.changes()).unwrap_or(i64::MAX),
            last_insert_rowid: conn.last_insert_rowid(),
        });
    }

    if has_multiple_sql_statements(sql) {
        return Err(SqlExecCoreError::InvalidArgument(
            "flat params or params_json are only supported for a single SQL statement".to_string(),
        ));
    }

    let mut stmt = conn.prepare(sql).map_err(|error| SqlExecCoreError::Sqlite {
        prefix: "sqlite prepare failed",
        error,
    })?;
    let params = bind_values_as_params(bound_values);
    let rows_changed = stmt
        .execute(params.as_slice())
        .map_err(|error| SqlExecCoreError::Sqlite {
            prefix: "sqlite execute failed",
            error,
        })?;

    Ok(ExecuteScriptResult {
        success: true,
        message: format!("statement executed successfully (rows_changed={rows_changed})"),
        rows_changed: i64::try_from(rows_changed).unwrap_or(i64::MAX),
        last_insert_rowid: conn.last_insert_rowid(),
    })
}

/// 执行同 SQL 多组参数的批量执行。
/// Execute a single SQL statement repeatedly with multiple parameter groups.
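///
/// 调用示意(假设性示例,标记 `ignore`,不参与编译):同一条语句在一个事务内按参数组逐次执行,任一失败则整体回滚。
/// Call sketch (illustrative only, marked `ignore` so it is not compiled): the same statement runs once
/// per parameter group inside one transaction; any failure rolls back the whole batch.
///
/// ```ignore
/// let result = execute_batch(
///     &mut conn,
///     "INSERT INTO t(x) VALUES (?1)",
///     &[vec![SqliteValue::Integer(1)], vec![SqliteValue::Integer(2)]],
/// )?;
/// assert_eq!(result.statements_executed, 2);
/// assert_eq!(result.rows_changed, 2);
/// ```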
pub fn execute_batch(
    conn: &mut Connection,
    sql: &str,
    batch_params: &[Vec<SqliteValue>],
) -> Result<ExecuteBatchResult, SqlExecCoreError> {
    if sql.trim().is_empty() {
        return Err(SqlExecCoreError::InvalidArgument(
            "sql must not be empty".to_string(),
        ));
    }
    if batch_params.is_empty() {
        return Err(SqlExecCoreError::InvalidArgument(
            "items must not be empty".to_string(),
        ));
    }
    if has_multiple_sql_statements(sql) {
        return Err(SqlExecCoreError::InvalidArgument(
            "execute_batch only supports a single SQL statement".to_string(),
        ));
    }

    conn.execute_batch("BEGIN TRANSACTION")
        .map_err(|error| SqlExecCoreError::Sqlite {
            prefix: "sqlite BEGIN TRANSACTION failed",
            error,
        })?;

    let batch_result = (|| -> Result<ExecuteBatchResult, SqlExecCoreError> {
        let mut stmt = conn.prepare(sql).map_err(|error| SqlExecCoreError::Sqlite {
            prefix: "sqlite prepare failed",
            error,
        })?;

        let mut rows_changed = 0_i64;
        for params in batch_params {
            let params = bind_values_as_params(params);
            let changed = stmt
                .execute(params.as_slice())
                .map_err(|error| SqlExecCoreError::Sqlite {
                    prefix: "sqlite execute failed",
                    error,
                })?;
            rows_changed = rows_changed.saturating_add(i64::try_from(changed).unwrap_or(i64::MAX));
        }

        drop(stmt);
        conn.execute_batch("COMMIT")
            .map_err(|error| SqlExecCoreError::Sqlite {
                prefix: "sqlite COMMIT failed",
                error,
            })?;

        Ok(ExecuteBatchResult {
            success: true,
            message: format!(
                "batch executed successfully (statements_executed={} rows_changed={rows_changed})",
                batch_params.len()
            ),
            rows_changed,
            last_insert_rowid: conn.last_insert_rowid(),
            statements_executed: i64::try_from(batch_params.len()).unwrap_or(i64::MAX),
        })
    })();

    if batch_result.is_err() {
        let _ = conn.execute_batch("ROLLBACK");
    }

    batch_result
}

/// 执行查询并返回 JSON 行集。
/// Execute a query and return a JSON row set.
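///
/// 调用示意(假设性示例,标记 `ignore`,不参与编译):每行编码为一个以列名为键的 JSON 对象。
/// Call sketch (illustrative only, marked `ignore` so it is not compiled): each row is encoded as a
/// JSON object keyed by column name.
///
/// ```ignore
/// let result = query_json(&mut conn, "SELECT 1 AS one, 'a' AS label", &[])?;
/// assert_eq!(result.row_count, 1);
/// assert!(result.json_data.contains("\"one\":1"));
/// assert!(result.json_data.contains("\"label\":\"a\""));
/// ```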
pub fn query_json(
    conn: &mut Connection,
    sql: &str,
    bound_values: &[SqliteValue],
) -> Result<QueryJsonResult, SqlExecCoreError> {
    if sql.trim().is_empty() {
        return Err(SqlExecCoreError::InvalidArgument(
            "sql must not be empty".to_string(),
        ));
    }
    if has_multiple_sql_statements(sql) {
        return Err(SqlExecCoreError::InvalidArgument(
            "query_json only supports a single SQL statement".to_string(),
        ));
    }

    let mut stmt = conn.prepare(sql).map_err(|error| SqlExecCoreError::Sqlite {
        prefix: "sqlite prepare failed",
        error,
    })?;
    let column_names = stmt
        .column_names()
        .into_iter()
        .map(|name| name.to_string())
        .collect::<Vec<_>>();
    let params = bind_values_as_params(bound_values);
    let mut rows = stmt.query(params.as_slice()).map_err(|error| SqlExecCoreError::Sqlite {
        prefix: "sqlite query failed",
        error,
    })?;

    let mut json_rows = Vec::<JsonValue>::new();
    while let Some(row) = rows.next().map_err(|error| SqlExecCoreError::Sqlite {
        prefix: "sqlite row fetch failed",
        error,
    })? {
        let mut object = JsonMap::new();
        for (index, column_name) in column_names.iter().enumerate() {
            let value = row
                .get_ref(index)
                .map_err(|error| SqlExecCoreError::Sqlite {
                    prefix: "sqlite value access failed",
                    error,
                })?;
            object.insert(column_name.clone(), sqlite_value_ref_to_json(value));
        }
        json_rows.push(JsonValue::Object(object));
    }

    let row_count = u64::try_from(json_rows.len()).unwrap_or(u64::MAX);
    let json_data = serde_json::to_string(&json_rows).map_err(|error| {
        SqlExecCoreError::Internal(format!("serialize JSON result failed: {error}"))
    })?;

    Ok(QueryJsonResult {
        json_data,
        row_count,
    })
}

/// 执行查询并返回 Arrow IPC chunk 列表。
/// Execute a query and return Arrow IPC chunks.
#[allow(dead_code)]
pub fn query_stream(
    conn: &mut Connection,
    sql: &str,
    bound_values: &[SqliteValue],
    target_chunk_size: usize,
) -> Result<QueryStreamResult, SqlExecCoreError> {
    let writer = TempFileChunkWriter::new(target_chunk_size)?;
    let (writer, metrics) = query_stream_with_writer(conn, sql, bound_values, writer)?;

    Ok(QueryStreamResult {
        storage: writer.into_storage(),
        row_count: metrics.row_count,
        chunk_count: metrics.chunk_count,
        total_bytes: metrics.total_bytes,
    })
}

/// 执行查询并把 Arrow IPC chunk 写入任意 `Write` 接口。
/// Execute a query and write Arrow IPC chunks into any `Write` sink.
pub fn query_stream_with_writer<W: QueryStreamChunkWriter>(
    conn: &mut Connection,
    sql: &str,
    bound_values: &[SqliteValue],
    writer: W,
) -> Result<(W, QueryStreamMetrics), SqlExecCoreError> {
    if sql.trim().is_empty() {
        return Err(SqlExecCoreError::InvalidArgument(
            "sql must not be empty".to_string(),
        ));
    }
    if has_multiple_sql_statements(sql) {
        return Err(SqlExecCoreError::InvalidArgument(
            "query_stream only supports a single SQL statement".to_string(),
        ));
    }

    let mut stmt = conn.prepare(sql).map_err(|error| SqlExecCoreError::Sqlite {
        prefix: "sqlite prepare failed",
        error,
    })?;
    let columns = stmt.columns();
    let column_names = columns
        .iter()
        .map(|column| column.name().to_string())
        .collect::<Vec<_>>();
    let declared_types = columns
        .iter()
        .map(|column| column.decl_type().map(|value| value.to_string()))
        .collect::<Vec<_>>();
    let params = bind_values_as_params(bound_values);
    let mut rows = stmt.query(params.as_slice()).map_err(|error| SqlExecCoreError::Sqlite {
        prefix: "sqlite query failed",
        error,
    })?;

    let mut chunk_writer = Some(writer);
    let mut ipc_writer: Option<StreamWriter<W>> = None;
    let mut schema: Option<Arc<Schema>> = None;
    let mut column_kinds: Option<Vec<ArrowColumnKind>> = None;
    let mut total_rows: usize = 0;

    loop {
        let mut batch_rows = Vec::<Vec<SqliteValue>>::new();
        while batch_rows.len() < STREAMING_BATCH_ROWS {
            match rows.next().map_err(|error| SqlExecCoreError::Sqlite {
                prefix: "sqlite row fetch failed",
                error,
            })? {
                Some(row) => {
                    let mut values = Vec::with_capacity(column_names.len());
                    for index in 0..column_names.len() {
                        let value = row
                            .get_ref(index)
                            .map_err(|error| SqlExecCoreError::Sqlite {
                                prefix: "sqlite value access failed",
                                error,
                            })?;
                        values.push(SqliteValue::try_from(value).map_err(|error| {
                            SqlExecCoreError::Sqlite {
                                prefix: "sqlite value conversion failed while materializing rows",
                                error: RusqliteError::FromSqlConversionFailure(
                                    index,
                                    value.data_type(),
                                    Box::new(error),
                                ),
                            }
                        })?);
                    }
                    batch_rows.push(values);
                }
                None => break,
            }
        }

        if batch_rows.is_empty() {
            break;
        }

        total_rows += batch_rows.len();

        if ipc_writer.is_none() {
            column_kinds = Some(infer_column_kinds(&declared_types, &batch_rows));
            schema = Some(Arc::new(Schema::new(
                column_names
                    .iter()
                    .zip(column_kinds.as_ref().unwrap().iter())
                    .map(|(name, kind)| {
                        Field::new(
                            name,
                            match kind {
                                ArrowColumnKind::Int64 => DataType::Int64,
                                ArrowColumnKind::Float64 => DataType::Float64,
                                ArrowColumnKind::Utf8 => DataType::Utf8,
                                ArrowColumnKind::Binary => DataType::Binary,
                            },
                            true,
                        )
                    })
                    .collect::<Vec<_>>(),
            )));

            let writer = StreamWriter::try_new(chunk_writer.take().unwrap(), schema.as_ref().unwrap())
                .map_err(|error| {
                    SqlExecCoreError::Internal(format!(
                        "arrow stream header write failed: {error}"
                    ))
                })?;
            ipc_writer = Some(writer);
        }

        let batch = RecordBatch::try_new(
            Arc::clone(schema.as_ref().unwrap()),
            build_arrow_arrays(column_kinds.as_ref().unwrap(), &batch_rows),
        )
        .map_err(|error| {
            SqlExecCoreError::Internal(format!("arrow record batch build failed: {error}"))
        })?;

        let writer = ipc_writer.as_mut().unwrap();
        writer.write(&batch).map_err(|error| {
            SqlExecCoreError::Internal(format!("arrow batch write failed: {error}"))
        })?;
        writer.flush().map_err(|error| {
            SqlExecCoreError::Internal(format!("arrow batch flush failed: {error}"))
        })?;
    }

    let (writer, chunk_count, total_bytes) = if let Some(mut writer) = ipc_writer {
        writer.finish().map_err(|error| {
            SqlExecCoreError::Internal(format!("arrow stream finish failed: {error}"))
        })?;
        writer.flush().map_err(|error| {
            SqlExecCoreError::Internal(format!("arrow final flush failed: {error}"))
        })?;
        let writer = writer.into_inner().map_err(|error| {
            SqlExecCoreError::Internal(format!("arrow stream finalize failed: {error}"))
        })?;
        let chunk_count = writer.emitted_chunk_count();
        let total_bytes = writer.emitted_total_bytes();
        (writer, chunk_count, total_bytes)
    } else {
        (chunk_writer.take().expect("writer should remain available"), 0, 0)
    };

    Ok((writer, QueryStreamMetrics {
        row_count: u64::try_from(total_rows).unwrap_or(u64::MAX),
        chunk_count,
        total_bytes,
    }))
}

/// 检测 SQL 是否包含多条语句。
/// Detect whether the SQL string contains multiple statements.
pub fn has_multiple_sql_statements(sql: &str) -> bool {
    count_sql_statements(sql) > 1
}

/// 统计 SQL 中实际包含的有效语句数量。
/// Count the number of effective SQL statements contained in the input string.
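///
/// 计数规则示意(假设性示例,标记 `ignore`,不参与编译):字符串字面量与注释中的分号不会拆分语句。
/// Counting sketch (illustrative only, marked `ignore` so it is not compiled): semicolons inside
/// string literals and comments do not split statements.
///
/// ```ignore
/// assert_eq!(count_sql_statements("SELECT 1"), 1);
/// assert_eq!(count_sql_statements("SELECT ';'; SELECT 2"), 2);
/// assert_eq!(count_sql_statements("-- only a comment ;"), 0);
/// ```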
pub fn count_sql_statements(sql: &str) -> usize {
    let chars: Vec<char> = sql.chars().collect();
    let len = chars.len();
    let mut i = 0;
    let mut statement_count = 0;
    let mut has_content = false;

    while i < len {
        match chars[i] {
            '\'' => {
                has_content = true;
                i += 1;
                while i < len {
                    if chars[i] == '\'' {
                        if i + 1 < len && chars[i + 1] == '\'' {
                            i += 2;
                        } else {
                            i += 1;
                            break;
                        }
                    } else {
                        i += 1;
                    }
                }
            }
            '"' => {
                has_content = true;
                i += 1;
                while i < len {
                    if chars[i] == '"' {
                        if i + 1 < len && chars[i + 1] == '"' {
                            i += 2;
                        } else {
                            i += 1;
                            break;
                        }
                    } else {
                        i += 1;
                    }
                }
            }
            '-' if i + 1 < len && chars[i + 1] == '-' => {
                i += 2;
                while i < len && chars[i] != '\n' {
                    i += 1;
                }
            }
            '/' if i + 1 < len && chars[i + 1] == '*' => {
                i += 2;
                while i + 1 < len {
                    if chars[i] == '*' && chars[i + 1] == '/' {
                        i += 2;
                        break;
                    }
                    i += 1;
                }
            }
            ';' => {
                if has_content {
                    statement_count += 1;
                }
                has_content = false;
                i += 1;
            }
            c if !c.is_whitespace() => {
                has_content = true;
                i += 1;
            }
            _ => i += 1,
        }
    }

    if has_content {
        statement_count += 1;
    }

    statement_count
}

/// 把 SQLite 值转换为 JSON 值。
/// Convert a SQLite value to a JSON value.
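///
/// 映射示意(假设性示例,标记 `ignore`,不参与编译):BLOB 编码为字节数组,NaN 与无穷大退化为字符串。
/// Mapping sketch (illustrative only, marked `ignore` so it is not compiled): BLOBs become arrays of
/// byte values, and NaN/Infinity fall back to strings.
///
/// ```ignore
/// assert_eq!(sqlite_value_to_json(&SqliteValue::Integer(7)), json!(7));
/// assert_eq!(sqlite_value_to_json(&SqliteValue::Blob(vec![0, 255])), json!([0, 255]));
/// assert_eq!(sqlite_value_to_json(&SqliteValue::Real(f64::NAN)), json!("NaN"));
/// ```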
pub fn sqlite_value_to_json(value: &SqliteValue) -> JsonValue {
    match value {
        SqliteValue::Null => JsonValue::Null,
        SqliteValue::Integer(value) => JsonValue::from(*value),
        SqliteValue::Real(value) => json_float(*value),
        SqliteValue::Text(value) => JsonValue::String(value.clone()),
        SqliteValue::Blob(value) => JsonValue::Array(
            value
                .iter()
                .map(|byte| JsonValue::from(u64::from(*byte)))
                .collect(),
        ),
    }
}

/// 把已绑定的 SQLite 值切片转换为 rusqlite 的 `&dyn ToSql` 参数列表。
/// Convert a slice of bound SQLite values into rusqlite's `&dyn ToSql` parameter list.
fn bind_values_as_params(values: &[SqliteValue]) -> Vec<&dyn ToSql> {
    values.iter().map(|value| value as &dyn ToSql).collect()
}

/// 把 SQLite 值引用转换为 JSON 值,转换失败时退化为 `null`。
/// Convert a SQLite value reference to JSON, falling back to `null` on conversion failure.
fn sqlite_value_ref_to_json(value: SqliteValueRef<'_>) -> JsonValue {
    match SqliteValue::try_from(value) {
        Ok(value) => sqlite_value_to_json(&value),
        Err(_) => JsonValue::Null,
    }
}

/// QueryStream 输出列在 Arrow schema 中使用的列类型。
/// Column type used in the Arrow schema for QueryStream output.
#[derive(Copy, Clone, Debug)]
enum ArrowColumnKind {
    Int64,
    Float64,
    Utf8,
    Binary,
}

/// 结合声明类型与首批物化行数据推断每一列的 Arrow 列类型。
/// Infer each column's Arrow type from declared types plus the first materialized batch.
fn infer_column_kinds(
    declared_types: &[Option<String>],
    rows: &[Vec<SqliteValue>],
) -> Vec<ArrowColumnKind> {
    let column_count = declared_types.len();
    let mut kinds = Vec::with_capacity(column_count);

    for index in 0..column_count {
        let mut current = declared_type_hint(declared_types[index].as_deref());
        for row in rows {
            current = merge_column_kind(current, &row[index]);
        }
        kinds.push(current.unwrap_or(ArrowColumnKind::Utf8));
    }

    kinds
}

/// 根据 SQLite 声明类型给出 Arrow 列类型提示。
/// Derive an Arrow column type hint from a SQLite declared type.
fn declared_type_hint(value: Option<&str>) -> Option<ArrowColumnKind> {
    let normalized = value?.trim().to_ascii_uppercase();

    if normalized.contains("INT") || normalized.contains("BOOL") {
        Some(ArrowColumnKind::Int64)
    } else if normalized.contains("REAL")
        || normalized.contains("FLOA")
        || normalized.contains("DOUB")
        || normalized.contains("NUMERIC")
        || normalized.contains("DEC")
    {
        Some(ArrowColumnKind::Float64)
    } else if normalized.contains("BLOB") {
        Some(ArrowColumnKind::Binary)
    } else if normalized.contains("CHAR")
        || normalized.contains("CLOB")
        || normalized.contains("TEXT")
        || normalized.contains("JSON")
        || normalized.contains("DATE")
        || normalized.contains("TIME")
    {
        Some(ArrowColumnKind::Utf8)
    } else {
        None
    }
}

/// 用一行中的实际值更新列类型推断(整数可提升为浮点,混合类型退化为文本)。
/// Refine the inferred column type with an observed value (integers may widen to floats;
/// mixed types fall back to text).
fn merge_column_kind(
    current: Option<ArrowColumnKind>,
    value: &SqliteValue,
) -> Option<ArrowColumnKind> {
    match value {
        SqliteValue::Null => current,
        SqliteValue::Integer(_) => Some(match current {
            None => ArrowColumnKind::Int64,
            Some(ArrowColumnKind::Int64) => ArrowColumnKind::Int64,
            Some(ArrowColumnKind::Float64) => ArrowColumnKind::Float64,
            Some(ArrowColumnKind::Utf8) => ArrowColumnKind::Utf8,
            Some(ArrowColumnKind::Binary) => ArrowColumnKind::Utf8,
        }),
        SqliteValue::Real(_) => Some(match current {
            None => ArrowColumnKind::Float64,
            Some(ArrowColumnKind::Int64) => ArrowColumnKind::Float64,
            Some(ArrowColumnKind::Float64) => ArrowColumnKind::Float64,
            Some(ArrowColumnKind::Utf8) => ArrowColumnKind::Utf8,
            Some(ArrowColumnKind::Binary) => ArrowColumnKind::Utf8,
        }),
        SqliteValue::Text(_) => Some(ArrowColumnKind::Utf8),
        SqliteValue::Blob(_) => Some(match current {
            None => ArrowColumnKind::Binary,
            Some(ArrowColumnKind::Binary) => ArrowColumnKind::Binary,
            _ => ArrowColumnKind::Utf8,
        }),
    }
}

/// 按推断的列类型把物化行构建为 Arrow 数组。
/// Build Arrow arrays from materialized rows according to the inferred column types.
fn build_arrow_arrays(kinds: &[ArrowColumnKind], rows: &[Vec<SqliteValue>]) -> Vec<ArrayRef> {
    let mut arrays = Vec::<ArrayRef>::with_capacity(kinds.len());

    for (index, kind) in kinds.iter().enumerate() {
        match kind {
            ArrowColumnKind::Int64 => {
                let mut builder = Int64Builder::with_capacity(rows.len());
                for row in rows {
                    match &row[index] {
                        SqliteValue::Null => builder.append_null(),
                        SqliteValue::Integer(value) => builder.append_value(*value),
                        SqliteValue::Real(value) => builder.append_value(*value as i64),
                        SqliteValue::Text(_) | SqliteValue::Blob(_) => builder.append_null(),
                    }
                }
                arrays.push(Arc::new(builder.finish()));
            }
            ArrowColumnKind::Float64 => {
                let mut builder = Float64Builder::with_capacity(rows.len());
                for row in rows {
                    match &row[index] {
                        SqliteValue::Null => builder.append_null(),
                        SqliteValue::Integer(value) => builder.append_value(*value as f64),
                        SqliteValue::Real(value) => builder.append_value(*value),
                        SqliteValue::Text(_) | SqliteValue::Blob(_) => builder.append_null(),
                    }
                }
                arrays.push(Arc::new(builder.finish()));
            }
            ArrowColumnKind::Utf8 => {
                let mut builder = StringBuilder::new();
                for row in rows {
                    match sqlite_value_to_text(&row[index]) {
                        Some(value) => builder.append_value(value),
                        None => builder.append_null(),
                    }
                }
                arrays.push(Arc::new(builder.finish()));
            }
            ArrowColumnKind::Binary => {
                let mut builder = BinaryBuilder::new();
                for row in rows {
                    match &row[index] {
                        SqliteValue::Null => builder.append_null(),
                        SqliteValue::Blob(value) => builder.append_value(value),
                        other => builder.append_value(
                            sqlite_value_to_text(other).unwrap_or_default().as_bytes(),
                        ),
                    }
                }
                arrays.push(Arc::new(builder.finish()));
            }
        }
    }

    arrays
}

/// 把 SQLite 值渲染为文本;BLOB 使用 `x'..'` 十六进制字面量形式。
/// Render a SQLite value as text; BLOBs use the `x'..'` hex-literal form.
fn sqlite_value_to_text(value: &SqliteValue) -> Option<String> {
    match value {
        SqliteValue::Null => None,
        SqliteValue::Integer(value) => Some(value.to_string()),
        SqliteValue::Real(value) => Some(format_float(*value)),
        SqliteValue::Text(value) => Some(value.clone()),
        SqliteValue::Blob(value) => Some(format!("x'{}'", blob_to_hex(value))),
    }
}

/// 把字节序列编码为大写十六进制字符串。
/// Encode a byte slice as an uppercase hexadecimal string.
fn blob_to_hex(bytes: &[u8]) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut output = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        output.push(char::from(HEX[(byte >> 4) as usize]));
        output.push(char::from(HEX[(byte & 0x0f) as usize]));
    }
    output
}

/// 把浮点数转换为 JSON 数值,NaN 与无穷大退化为字符串。
/// Convert a float to a JSON number, falling back to a string for NaN and infinities.
fn json_float(value: f64) -> JsonValue {
    if value.is_nan() || value.is_infinite() {
        return JsonValue::String(format_float(value));
    }
    JsonNumber::from_f64(value)
        .map(JsonValue::Number)
        .unwrap_or_else(|| JsonValue::String(format_float(value)))
}

/// 格式化浮点数,整数值保留一位小数以区别于整型。
/// Format a float, keeping one decimal place for whole numbers to distinguish them from integers.
fn format_float(value: f64) -> String {
    if value.fract() == 0.0 {
        format!("{value:.1}")
    } else {
        value.to_string()
    }
}

/// QueryStream 的 Arrow IPC chunk 写入器接口。
/// Writer interface for emitting QueryStream Arrow IPC chunks.
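///
/// 下面是一个假设性的最小内存实现示意(标记 `ignore`,不参与编译),把整个 Arrow IPC 流收进一个
/// `Vec<u8>` 并统计字节数,随后可直接交给 `query_stream_with_writer` 使用:
/// A hypothetical minimal in-memory implementation sketch (marked `ignore` so it is not compiled)
/// that gathers the whole Arrow IPC stream into one `Vec<u8>` and tracks the byte count; it can then
/// be handed to `query_stream_with_writer`:
///
/// ```ignore
/// struct VecChunkWriter {
///     buffer: Vec<u8>,
/// }
///
/// impl std::io::Write for VecChunkWriter {
///     fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
///         self.buffer.extend_from_slice(buf);
///         Ok(buf.len())
///     }
///     fn flush(&mut self) -> std::io::Result<()> {
///         Ok(())
///     }
/// }
///
/// impl QueryStreamChunkWriter for VecChunkWriter {
///     fn emitted_chunk_count(&self) -> u64 {
///         u64::from(!self.buffer.is_empty()) // everything lands in one logical chunk
///     }
///     fn emitted_total_bytes(&self) -> u64 {
///         self.buffer.len() as u64
///     }
/// }
///
/// let writer = VecChunkWriter { buffer: Vec::new() };
/// let (writer, metrics) = query_stream_with_writer(&mut conn, "SELECT 1", &[], writer)?;
/// assert_eq!(metrics.total_bytes, writer.emitted_total_bytes());
/// ```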
pub trait QueryStreamChunkWriter: Write {
    /// 返回当前已经写出的 chunk 数量。
    /// Return the number of chunks emitted so far.
    fn emitted_chunk_count(&self) -> u64;

    /// 返回当前已经写出的总字节数。
    /// Return the total emitted byte size so far.
    fn emitted_total_bytes(&self) -> u64;
}

/// Arrow IPC chunk 收集器。
/// Arrow IPC chunk collector.
#[allow(dead_code)]
pub struct ChunkCollector {
    chunks: Vec<Vec<u8>>,
    pending: Vec<u8>,
    target_chunk_size: usize,
    emitted_chunks: usize,
    emitted_bytes: usize,
}

#[allow(dead_code)]
impl ChunkCollector {
    fn new(target_chunk_size: usize) -> Self {
        let chunk_size = target_chunk_size.max(64 * 1024);
        Self {
            chunks: Vec::new(),
            pending: Vec::with_capacity(chunk_size),
            target_chunk_size: chunk_size,
            emitted_chunks: 0,
            emitted_bytes: 0,
        }
    }

    fn emit_full_chunks(&mut self) {
        while self.pending.len() >= self.target_chunk_size {
            let remainder = self.pending.split_off(self.target_chunk_size);
            let chunk = std::mem::replace(&mut self.pending, remainder);
            self.send_chunk(chunk);
        }
    }

    fn emit_remaining(&mut self) {
        if self.pending.is_empty() {
            return;
        }

        let chunk = std::mem::take(&mut self.pending);
        self.send_chunk(chunk);
    }

    fn send_chunk(&mut self, chunk: Vec<u8>) {
        self.emitted_chunks += 1;
        self.emitted_bytes += chunk.len();
        self.chunks.push(chunk);
    }
}

/// 基于临时文件的 QueryStream chunk 写入器。
/// Temporary-file-backed QueryStream chunk writer.
pub struct TempFileChunkWriter {
    file: File,
    file_path: PathBuf,
    pending: Vec<u8>,
    target_chunk_size: usize,
    emitted_chunks: usize,
    emitted_bytes: usize,
    current_offset: u64,
    chunk_descriptors: Vec<QueryStreamChunkDescriptor>,
}

/// 用于生成唯一暂存文件名的进程内自增计数器。
/// Process-local monotonically increasing counter used to derive unique spool file names.
static NEXT_QUERY_STREAM_SPOOL_ID: AtomicU64 = AtomicU64::new(1);

impl TempFileChunkWriter {
    fn new(target_chunk_size: usize) -> Result<Self, SqlExecCoreError> {
        let chunk_size = target_chunk_size.max(64 * 1024);
        let file_path = make_query_stream_spool_path();
        let file = File::create(&file_path).map_err(|error| {
            SqlExecCoreError::Internal(format!(
                "create query stream spool file failed: {error}"
            ))
        })?;
        Ok(Self {
            file,
            file_path,
            pending: Vec::with_capacity(chunk_size),
            target_chunk_size: chunk_size,
            emitted_chunks: 0,
            emitted_bytes: 0,
            current_offset: 0,
            chunk_descriptors: Vec::new(),
        })
    }

    fn into_storage(self) -> QueryStreamStorage {
        QueryStreamStorage {
            file_path: self.file_path,
            chunks: self.chunk_descriptors,
        }
    }

    fn emit_full_chunks(&mut self) -> io::Result<()> {
        while self.pending.len() >= self.target_chunk_size {
            let remainder = self.pending.split_off(self.target_chunk_size);
            let chunk = std::mem::replace(&mut self.pending, remainder);
            self.write_chunk(chunk)?;
        }
        Ok(())
    }

    fn emit_remaining(&mut self) -> io::Result<()> {
        if self.pending.is_empty() {
            return Ok(());
        }

        let chunk = std::mem::take(&mut self.pending);
        self.write_chunk(chunk)
    }

    fn write_chunk(&mut self, chunk: Vec<u8>) -> io::Result<()> {
        self.file.write_all(&chunk)?;
        let chunk_len_u64 = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
        self.chunk_descriptors.push(QueryStreamChunkDescriptor {
            offset: self.current_offset,
            len: chunk_len_u64,
        });
        self.current_offset = self.current_offset.saturating_add(chunk_len_u64);
        self.emitted_chunks += 1;
        self.emitted_bytes += chunk.len();
        Ok(())
    }
}

impl QueryStreamChunkWriter for TempFileChunkWriter {
    fn emitted_chunk_count(&self) -> u64 {
        u64::try_from(self.emitted_chunks).unwrap_or(u64::MAX)
    }

    fn emitted_total_bytes(&self) -> u64 {
        u64::try_from(self.emitted_bytes).unwrap_or(u64::MAX)
    }
}

impl Write for TempFileChunkWriter {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        self.pending.extend_from_slice(buf);
        self.emit_full_chunks()?;
        Ok(buf.len())
    }

    fn flush(&mut self) -> io::Result<()> {
        self.emit_remaining()?;
        self.file.flush()
    }
}

/// 在系统临时目录下生成唯一的 QueryStream 暂存文件路径。
/// Build a unique QueryStream spool file path under the system temporary directory.
fn make_query_stream_spool_path() -> PathBuf {
    let unique = NEXT_QUERY_STREAM_SPOOL_ID.fetch_add(1, Ordering::Relaxed);
    let file_name = format!(
        "vldb-sqlite-query-stream-{}-{}-{}.bin",
        std::process::id(),
        unique,
        chrono::Utc::now().timestamp_nanos_opt().unwrap_or_default()
    );
    std::env::temp_dir().join(file_name)
}

impl QueryStreamChunkWriter for ChunkCollector {
    fn emitted_chunk_count(&self) -> u64 {
        u64::try_from(self.emitted_chunks).unwrap_or(u64::MAX)
    }

    fn emitted_total_bytes(&self) -> u64 {
        u64::try_from(self.emitted_bytes).unwrap_or(u64::MAX)
    }
}

impl Write for ChunkCollector {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        self.pending.extend_from_slice(buf);
        self.emit_full_chunks();
        Ok(buf.len())
    }

    fn flush(&mut self) -> io::Result<()> {
        self.emit_remaining();
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::{
        DEFAULT_IPC_CHUNK_BYTES, ExecuteBatchResult, ExecuteScriptResult, count_sql_statements,
        has_multiple_sql_statements, json_param_to_sqlite_value, parse_legacy_params_json, query_json,
        query_stream,
    };
    use rusqlite::Connection;
    use rusqlite::types::Value as SqliteValue;
    use serde_json::json;

    /// 创建临时内存连接并初始化一张测试表。
    /// Create a temporary in-memory connection and initialize a test table.
    fn open_test_connection() -> Connection {
        let conn = Connection::open_in_memory().expect("in-memory sqlite should open");
        conn.execute_batch(
            "CREATE TABLE demo(id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT, score REAL, ok INTEGER);",
        )
        .expect("demo schema should initialize");
        conn
    }

    #[test]
    fn parse_legacy_params_json_supports_scalar_values() {
        let parsed = parse_legacy_params_json("[1,2.5,true,\"hello\",null]")
            .expect("params_json should parse");
        assert_eq!(
            parsed,
            vec![
                SqliteValue::Integer(1),
                SqliteValue::Real(2.5),
                SqliteValue::Integer(1),
                SqliteValue::Text("hello".to_string()),
                SqliteValue::Null,
            ]
        );
    }

    #[test]
    fn json_param_to_sqlite_value_rejects_nested_values() {
        let err = json_param_to_sqlite_value(json!({"nested":true})).expect_err("nested JSON should fail");
        assert!(err.to_string().contains("scalar JSON values"));
    }

    #[test]
    fn execute_and_query_round_trip() {
        let mut conn = open_test_connection();
        let execute = super::execute_script(
            &mut conn,
            "INSERT INTO demo(name, score, ok) VALUES (?1, ?2, ?3)",
            &[
                SqliteValue::Text("alpha".to_string()),
                SqliteValue::Real(7.5),
                SqliteValue::Integer(1),
            ],
        )
        .expect("insert should succeed");
        assert_eq!(
            execute,
            ExecuteScriptResult {
                success: true,
                message: "statement executed successfully (rows_changed=1)".to_string(),
                rows_changed: 1,
                last_insert_rowid: 1,
            }
        );

        let queried = query_json(
            &mut conn,
            "SELECT id, name, score, ok FROM demo ORDER BY id",
            &[],
        )
        .expect("query_json should succeed");
        assert_eq!(queried.row_count, 1);
        assert!(queried.json_data.contains("\"alpha\""));
    }

    #[test]
    fn execute_batch_runs_multiple_parameter_sets() {
        let mut conn = open_test_connection();
        let batch = super::execute_batch(
            &mut conn,
            "INSERT INTO demo(name, score, ok) VALUES (?1, ?2, ?3)",
            &[
                vec![
                    SqliteValue::Text("alpha".to_string()),
                    SqliteValue::Real(1.5),
                    SqliteValue::Integer(1),
                ],
                vec![
                    SqliteValue::Text("beta".to_string()),
                    SqliteValue::Real(2.5),
                    SqliteValue::Integer(0),
                ],
            ],
        )
        .expect("batch should succeed");
        assert_eq!(
            batch,
            ExecuteBatchResult {
                success: true,
                message:
                    "batch executed successfully (statements_executed=2 rows_changed=2)".to_string(),
                rows_changed: 2,
                last_insert_rowid: 2,
                statements_executed: 2,
            }
        );
    }

    #[test]
    fn query_stream_returns_ipc_chunks() {
        let mut conn = open_test_connection();
        conn.execute(
            "INSERT INTO demo(name, score, ok) VALUES (?1, ?2, ?3)",
            ("alpha", 7.5_f64, 1_i64),
        )
        .expect("insert should succeed");
        let result = query_stream(
            &mut conn,
            "SELECT id, name, score, ok FROM demo ORDER BY id",
            &[],
            DEFAULT_IPC_CHUNK_BYTES,
        )
        .expect("query_stream should succeed");
        assert_eq!(result.row_count, 1);
        assert!(result.chunk_count >= 1);
        let first_chunk = result.read_chunk(0).expect("first chunk should be readable");
        assert!(!first_chunk.is_empty());
    }

    #[test]
    fn has_multiple_sql_statements_detects_multiple_statements() {
        assert!(has_multiple_sql_statements("SELECT 1; SELECT 2;"));
        assert!(!has_multiple_sql_statements("SELECT ';'"));
    }

    #[test]
    fn count_sql_statements_ignores_empty_segments_and_comments() {
        assert_eq!(count_sql_statements(""), 0);
        assert_eq!(count_sql_statements(" ; \n "), 0);
        assert_eq!(count_sql_statements("SELECT 1"), 1);
        assert_eq!(count_sql_statements("SELECT 1; SELECT 2"), 2);
        assert_eq!(
            count_sql_statements("SELECT 1; -- ignored ;\n/* hidden ; */ SELECT 2;"),
            2
        );
    }
}