use std::{
borrow::Cow,
collections::{BTreeMap, BTreeSet},
io::Cursor,
path::{Path, PathBuf},
sync::atomic::{AtomicU64, Ordering},
time::Instant,
};
use calamine::{open_workbook_auto, Data, Reader};
use csv::{ReaderBuilder, StringRecord, Writer};
use duckdb::{types::ValueRef, Connection};
use encoding_rs::Encoding;
use serde_json::{Map, Number, Value};
use sysinfo::{Pid, ProcessesToUpdate, System};
use tokio::task;
use crate::{
config::{
dataset_source_from_uri, BenchmarkArgs, BenchmarkMeasurement, BenchmarkRequest,
BenchmarkResponse, BenchmarkResult, ColumnSummary, ConvertArgs, ConvertRequest,
ConvertResponse, DatasetSource, QueryArgs, QueryMode, QueryRequest, QueryResponse,
SchemaArgs, SchemaRequest, SchemaResponse, StorageConfig, DEFAULT_SCHEMA_PREVIEW_ROWS,
DEFAULT_SMALL_FILE_THRESHOLD_BYTES,
},
error::{AppError, AppResult},
};
/// Entry point for the query, conversion, schema-inspection, and benchmark operations.
///
/// The type has no fields; each operation receives everything it needs through its request value.
#[derive(Debug, Clone, Default)]
pub struct Text2SqlEngine;
impl Text2SqlEngine {
pub fn new() -> Self {
Self
}
/// Runs the synchronous query path on a blocking task and awaits its result.
///
/// # Errors
/// Returns the error produced by the query itself, or the converted join error when the
/// blocking task does not complete.
pub async fn execute_query(&self, request: QueryRequest) -> AppResult<QueryResponse> {
    let handle = task::spawn_blocking(move || Self::execute_query_sync(request));
    handle.await?
}
/// Runs the synchronous conversion path on a blocking task and awaits its result.
///
/// # Errors
/// Returns the error produced by the conversion itself, or the converted join error when the
/// blocking task does not complete.
pub async fn convert(&self, request: ConvertRequest) -> AppResult<ConvertResponse> {
    let handle = task::spawn_blocking(move || Self::convert_sync(request));
    handle.await?
}
/// Runs the synchronous schema-inspection path on a blocking task and awaits its result.
///
/// # Errors
/// Returns the error produced by the inspection itself, or the converted join error when the
/// blocking task does not complete.
pub async fn inspect_schema(&self, request: SchemaRequest) -> AppResult<SchemaResponse> {
    let handle = task::spawn_blocking(move || Self::inspect_schema_sync(request));
    handle.await?
}
/// Runs the synchronous benchmark path on a blocking task and awaits its result.
///
/// # Errors
/// Returns the error produced by the benchmark itself, or the converted join error when the
/// blocking task does not complete.
pub async fn benchmark(&self, request: BenchmarkRequest) -> AppResult<BenchmarkResponse> {
    let handle = task::spawn_blocking(move || Self::benchmark_sync(request));
    handle.await?
}
/// Translates command-line conversion arguments into a [`ConvertRequest`] and runs it.
///
/// The input path is converted lossily to a string, and the output URI is paired with the
/// S3 options from the arguments via `dataset_source_from_uri`.
pub async fn convert_from_args(&self, args: ConvertArgs) -> AppResult<ConvertResponse> {
    let request = ConvertRequest {
        input_path: args.input.to_string_lossy().into_owned(),
        output: dataset_source_from_uri(args.output, &args.s3),
        normalize_columns: args.normalize_columns,
        sheet_name: args.sheet_name,
        csv_encoding: args.csv_encoding,
        overwrite: args.overwrite,
        add_filename_column: args.add_filename_column,
    };
    self.convert(request).await
}
/// Translates command-line schema arguments into a [`SchemaRequest`] and runs it.
pub async fn schema_from_args(&self, args: SchemaArgs) -> AppResult<SchemaResponse> {
    let dataset = dataset_source_from_uri(args.dataset, &args.s3);
    let request = SchemaRequest {
        dataset,
        table_name: args.table_name,
    };
    self.inspect_schema(request).await
}
/// Translates command-line query arguments into a [`QueryRequest`] and runs it.
pub async fn query_from_args(&self, args: QueryArgs) -> AppResult<QueryResponse> {
    let request = QueryRequest {
        sql: args.sql,
        dataset: dataset_source_from_uri(args.dataset, &args.s3),
        table_name: args.table_name,
        mode: args.mode,
        small_file_threshold_bytes: args.small_file_threshold_bytes,
        limit: args.limit,
    };
    self.execute_query(request).await
}
/// Translates command-line benchmark arguments into a [`BenchmarkRequest`] and runs it.
///
/// The output directory is always supplied (converted lossily to a string).
pub async fn benchmark_from_args(&self, args: BenchmarkArgs) -> AppResult<BenchmarkResponse> {
    let output_dir = args.output_dir.to_string_lossy().into_owned();
    let request = BenchmarkRequest {
        output_dir: Some(output_dir),
        row_counts: args.row_counts,
        sql: args.sql,
        limit: args.limit,
    };
    self.benchmark(request).await
}
/// Validates, registers, hardens, and executes one user query.
///
/// The response carries the cleaned table name, the mode that was actually used, the result
/// rows and columns, elapsed milliseconds, the change in process memory, and explanatory notes.
///
/// # Errors
/// Returns `AppError::Validation` for blank or disallowed SQL, and propagates failures from
/// mode resolution, dataset registration, connection hardening, and query execution.
fn execute_query_sync(request: QueryRequest) -> AppResult<QueryResponse> {
    if request.sql.trim().is_empty() {
        return Err(AppError::Validation("sql must not be empty".to_string()));
    }
    validate_query_sql(&request.sql)?;

    let table_name = match clean_identifier(&request.table_name) {
        name if name.is_empty() => {
            return Err(AppError::Validation(
                "table_name resolved to an empty identifier".to_string(),
            ));
        }
        name => name,
    };

    // The clock and the memory baseline are taken before any dataset setup work.
    let started = Instant::now();
    let memory_before = current_memory_bytes();

    let (resolved_mode, mut notes) = resolve_query_mode(
        &request.dataset,
        &request.mode,
        request.small_file_threshold_bytes,
    )?;
    let (conn, setup_notes) =
        open_registered_dataset(&request.dataset, &table_name, &resolved_mode)?;
    notes.extend(setup_notes);
    harden_query_connection(&conn, &request.dataset, &mut notes)?;

    let limited_sql = apply_limit(&request.sql, request.limit);
    let (columns, rows) = collect_rows(&conn, &limited_sql)?;

    let elapsed_ms = started.elapsed().as_millis();
    let memory_after = current_memory_bytes();

    let mode_note = if matches!(resolved_mode, QueryMode::ParquetSelective) {
        "Parquet-selective mode keeps the query on read_parquet(...), allowing DuckDB to use Parquet metadata, column pruning, and predicate pushdown when the SQL shape permits it."
    } else {
        "Full-download mode materializes the entire dataset into a temporary DuckDB table before applying the SQL, which approximates downloading/loading the full file first."
    };
    notes.push(mode_note.to_string());

    let row_count = rows.len();
    Ok(QueryResponse {
        table_name,
        mode: resolved_mode,
        row_count,
        columns,
        rows,
        elapsed_ms,
        memory_delta_bytes: memory_after as i64 - memory_before as i64,
        notes,
    })
}
/// Converts a CSV/XLS/XLSX/JSON input file to Parquet.
///
/// The input is first normalized into a temporary CSV, which DuckDB then copies to the
/// requested output as zstd-compressed Parquet. When `add_filename_column` is set, a
/// `_filename` column holding the input file stem is appended.
///
/// # Errors
/// Returns `AppError::Validation` when the input path does not exist or the output is not
/// allowed, and propagates normalization, counting, and DuckDB errors. The temporary CSV is
/// removed on both the success path and every error path after it has been named.
fn convert_sync(request: ConvertRequest) -> AppResult<ConvertResponse> {
    let input_path = PathBuf::from(&request.input_path);
    if !input_path.exists() {
        return Err(AppError::Validation(format!(
            "input path does not exist: {}",
            input_path.display()
        )));
    }
    ensure_output_allowed(&request.output, request.overwrite)?;
    let start = Instant::now();
    let temp_csv = unique_temp_csv_path();

    // The fallible pipeline runs inside a closure so that the cleanup below executes no matter
    // which step fails; previously an early `?` return left the temporary CSV behind.
    let outcome = (|| -> AppResult<(Vec<String>, usize, String)> {
        let columns = write_normalized_csv(
            &input_path,
            &temp_csv,
            request.normalize_columns,
            request.sheet_name.as_deref(),
            request.csv_encoding.as_deref(),
        )?;
        let row_count = count_csv_rows(&temp_csv)?;
        let conn = Connection::open_in_memory()?;
        configure_source(&conn, &request.output, &mut Vec::new())?;
        let output_uri = source_to_uri(&request.output)?;
        let filename_stem = input_path
            .file_stem()
            .and_then(|s| s.to_str())
            .unwrap_or_default()
            .to_string();
        let select_expr = if request.add_filename_column {
            format!(
                "SELECT *, {} AS _filename FROM read_csv_auto({}, HEADER=TRUE)",
                sql_string_literal(&filename_stem),
                sql_string_literal(&temp_csv.to_string_lossy()),
            )
        } else {
            format!(
                "SELECT * FROM read_csv_auto({}, HEADER=TRUE)",
                sql_string_literal(&temp_csv.to_string_lossy()),
            )
        };
        let copy_sql = format!(
            "COPY ({select_expr}) TO {} (FORMAT parquet, COMPRESSION zstd)",
            sql_string_literal(&output_uri),
        );
        conn.execute_batch(&copy_sql)?;
        Ok((columns, row_count, output_uri))
    })();

    // Best-effort cleanup: a failure to delete must not mask the conversion result.
    let _ = std::fs::remove_file(&temp_csv);
    let (mut columns, row_count, output_uri) = outcome?;

    if request.add_filename_column {
        columns.push("_filename".to_string());
    }
    let mut notes = vec!["CSV/XLS/XLSX/JSON input was normalized to CSV and then written to Parquet with DuckDB COPY.".to_string()];
    if let Some(csv_encoding) = request.csv_encoding {
        notes.push(format!(
            "CSV input was decoded using the explicit encoding label `{csv_encoding}` before normalization."
        ));
    }
    Ok(ConvertResponse {
        input_path: request.input_path,
        output_uri,
        row_count,
        columns,
        elapsed_ms: start.elapsed().as_millis(),
        notes,
    })
}
/// Registers the dataset in parquet-selective mode and reports its columns plus a row preview.
///
/// # Errors
/// Returns `AppError::Validation` if the cleaned table name is empty, and propagates
/// registration and query errors.
fn inspect_schema_sync(request: SchemaRequest) -> AppResult<SchemaResponse> {
    let table_name = match clean_identifier(&request.table_name) {
        name if name.is_empty() => {
            return Err(AppError::Validation(
                "table_name resolved to an empty identifier".to_string(),
            ));
        }
        name => name,
    };

    let started = Instant::now();
    let selective = QueryMode::ParquetSelective;
    let (conn, mut notes) = open_registered_dataset(&request.dataset, &table_name, &selective)?;
    let columns = collect_schema(&conn, &table_name)?;
    let preview_rows = collect_preview_rows(&conn, &table_name, DEFAULT_SCHEMA_PREVIEW_ROWS)?;

    let summary = "Schema inspection registers the parquet dataset in parquet-selective mode and reads DuckDB statement metadata before generating SQL.";
    notes.push(summary.to_string());

    let elapsed_ms = started.elapsed().as_millis();
    Ok(SchemaResponse {
        table_name,
        columns,
        preview_rows,
        elapsed_ms,
        notes,
    })
}
fn benchmark_sync(request: BenchmarkRequest) -> AppResult<BenchmarkResponse> {
if request.row_counts.is_empty() {
return Err(AppError::Validation(
"row_counts must not be empty".to_string(),
));
}
let output_dir = request
.output_dir
.unwrap_or_else(|| "./benchmark-data".to_string());
let output_path = PathBuf::from(&output_dir);
std::fs::create_dir_all(&output_path)?;
let mut results = Vec::new();
for row_count in request.row_counts {
let parquet_path = output_path.join(format!("benchmark-{}.parquet", row_count));
generate_benchmark_parquet(&parquet_path, row_count)?;
let dataset = DatasetSource {
uri: parquet_path.to_string_lossy().into_owned(),
storage: StorageConfig::Local,
};
let selective = run_query_for_mode(
&request.sql,
&dataset,
QueryMode::ParquetSelective,
request.limit,
)?;
let full = run_query_for_mode(
&request.sql,
&dataset,
QueryMode::FullDownload,
request.limit,
)?;
results.push(BenchmarkResult {
row_count,
parquet_path: parquet_path.to_string_lossy().into_owned(),
parquet_selective: benchmark_measurement(&selective),
full_download: benchmark_measurement(&full),
});
}
let recommendation = recommend_strategy(&results);
Ok(BenchmarkResponse {
sql: request.sql,
output_dir,
results,
recommendation,
})
}
}
/// Executes `sql` against `dataset` in the given mode, using the fixed table name `dataset`
/// and the default small-file threshold.
fn run_query_for_mode(
    sql: &str,
    dataset: &DatasetSource,
    mode: QueryMode,
    limit: Option<usize>,
) -> AppResult<QueryResponse> {
    let request = QueryRequest {
        sql: sql.to_owned(),
        dataset: dataset.clone(),
        table_name: String::from("dataset"),
        mode,
        small_file_threshold_bytes: DEFAULT_SMALL_FILE_THRESHOLD_BYTES,
        limit,
    };
    Text2SqlEngine::execute_query_sync(request)
}
/// Rejects anything other than a single read-only `SELECT`/`WITH` statement.
///
/// The checks run on the output of `policy_sql`, i.e. lower-cased text with string literals,
/// quoted identifiers, and comments blanked out:
///
/// 1. the statement must be non-empty and contain no `;` other than trailing ones;
/// 2. the first word must be `select` or `with`;
/// 3. no later word may be one of the forbidden statement keywords;
/// 4. none of the file-reading table functions may be called.
///
/// The table-function check ignores whitespace, so `read_parquet ('x')` and
/// `read_parquet/**/('x')` are rejected just like `read_parquet('x')`.
///
/// NOTE(review): because double-quoted identifiers are blanked before this check, a call
/// written as `"read_parquet"(...)` is not seen here; the connection hardening applied before
/// execution is the remaining barrier for that form.
///
/// # Errors
/// Returns `AppError::Validation` describing the first rule that was violated.
fn validate_query_sql(sql: &str) -> AppResult<()> {
    const FORBIDDEN_KEYWORDS: &[&str] = &[
        "copy", "export", "attach", "detach", "install", "load", "create", "alter", "drop",
        "insert", "update", "delete", "call", "pragma", "merge", "set", "vacuum",
    ];
    const FORBIDDEN_FUNCTIONS: &[&str] = &[
        "read_parquet(",
        "read_csv(",
        "read_csv_auto(",
        "read_ndjson(",
        "read_json(",
        "read_json_auto(",
        "read_json_objects(",
        "parquet_scan(",
        "csv_scan(",
        "delta_scan(",
        "iceberg_scan(",
        "glob(",
        "read_blob(",
        "read_text(",
        "parquet_metadata(",
        "parquet_schema(",
        "parquet_file_metadata(",
    ];

    let policy_sql = policy_sql(sql);
    let normalized = policy_sql.trim();
    if normalized.is_empty() {
        return Err(AppError::Validation("sql must not be empty".to_string()));
    }
    // Trailing semicolons are tolerated; any other semicolon means a second statement.
    let normalized = normalized.trim_end_matches(';').trim();
    if normalized.contains(';') {
        return Err(AppError::Validation(
            "only a single read-only SELECT/WITH statement is allowed".to_string(),
        ));
    }

    let mut tokens = normalized
        .split(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
        .filter(|token| !token.is_empty());
    let first_token = tokens
        .next()
        .ok_or_else(|| AppError::Validation("sql must not be empty".to_string()))?;
    if first_token != "select" && first_token != "with" {
        return Err(AppError::Validation(
            "only read-only SELECT/WITH queries against the registered dataset are allowed"
                .to_string(),
        ));
    }

    for forbidden in FORBIDDEN_KEYWORDS {
        if tokens.clone().any(|token| token == *forbidden) {
            return Err(AppError::Validation(format!(
                "query contains forbidden statement or keyword: {forbidden}"
            )));
        }
    }

    // SQL permits whitespace (and comments, already blanked to spaces) between a function name
    // and its opening parenthesis, so the match is done on whitespace-free text.
    let compact: String = normalized.chars().filter(|c| !c.is_whitespace()).collect();
    for forbidden_fn in FORBIDDEN_FUNCTIONS {
        if compact.contains(forbidden_fn) {
            return Err(AppError::Validation(format!(
                "direct table-function access is not allowed in queries: {forbidden_fn}"
            )));
        }
    }
    Ok(())
}
/// Produces a lower-cased copy of `sql` in which string literals, double-quoted identifiers,
/// and comments are replaced by spaces, so policy checks only see executable SQL text.
///
/// Handled constructs:
/// - standard strings `'...'`, where `''` is an embedded quote;
/// - escape strings `E'...'` / `e'...'`, where a backslash escapes the following character
///   (so `\'` does not end the literal) and `''` is still an embedded quote;
/// - double-quoted identifiers `"..."`;
/// - line comments `-- ...` (the terminating newline is kept);
/// - block comments `/* ... */` (not nested).
///
/// Without the escape-string state a literal such as `E'\''` would desynchronise the scanner
/// and hide the SQL that follows it from validation.
fn policy_sql(sql: &str) -> String {
    #[derive(Clone, Copy)]
    enum State {
        Normal,
        SingleQuote,
        EscapeQuote,
        DoubleQuote,
        LineComment,
        BlockComment,
    }

    /// Returns true when the quote at `quote_index` is preceded by an `e`/`E` that starts a
    /// token, i.e. the quote opens an escape string rather than a standard string.
    fn opens_escape_string(chars: &[char], quote_index: usize) -> bool {
        let prefix_index = match quote_index.checked_sub(1) {
            Some(index) => index,
            None => return false,
        };
        if !matches!(chars[prefix_index], 'e' | 'E') {
            return false;
        }
        match prefix_index.checked_sub(1) {
            None => true,
            Some(before) => {
                let b = chars[before];
                // An `e` that merely ends a longer identifier (e.g. `like'x'`) is not a prefix.
                !(b.is_ascii_alphanumeric() || b == '_' || b == '$' || !b.is_ascii())
            }
        }
    }

    let mut out = String::with_capacity(sql.len());
    let chars: Vec<char> = sql.chars().collect();
    let mut i = 0usize;
    let mut state = State::Normal;
    while i < chars.len() {
        let c = chars[i];
        let next = chars.get(i + 1).copied();
        match state {
            State::Normal => match (c, next) {
                ('\'', _) => {
                    state = if opens_escape_string(&chars, i) {
                        State::EscapeQuote
                    } else {
                        State::SingleQuote
                    };
                    out.push(' ');
                }
                ('"', _) => {
                    state = State::DoubleQuote;
                    out.push(' ');
                }
                ('-', Some('-')) => {
                    state = State::LineComment;
                    i += 1;
                    out.push(' ');
                }
                ('/', Some('*')) => {
                    state = State::BlockComment;
                    i += 1;
                    out.push(' ');
                }
                _ => out.push(c.to_ascii_lowercase()),
            },
            State::SingleQuote => {
                if c == '\'' {
                    if next == Some('\'') {
                        // Doubled quote: still inside the literal.
                        i += 1;
                    } else {
                        state = State::Normal;
                    }
                }
                out.push(' ');
            }
            State::EscapeQuote => {
                if c == '\\' && next.is_some() {
                    // Backslash escape: the next character never terminates the literal.
                    i += 1;
                } else if c == '\'' {
                    if next == Some('\'') {
                        i += 1;
                    } else {
                        state = State::Normal;
                    }
                }
                out.push(' ');
            }
            State::DoubleQuote => {
                if c == '"' {
                    state = State::Normal;
                }
                out.push(' ');
            }
            State::LineComment => {
                if c == '\n' {
                    state = State::Normal;
                    out.push('\n');
                } else {
                    out.push(' ');
                }
            }
            State::BlockComment => {
                if c == '*' && next == Some('/') {
                    state = State::Normal;
                    i += 1;
                }
                out.push(' ');
            }
        }
        i += 1;
    }
    out
}
/// Maps the requested mode to the mode that will actually run, plus explanatory notes.
///
/// Explicit modes pass through with no notes; `AutoHybrid` is decided by `auto_hybrid_mode`.
fn resolve_query_mode(
    dataset: &DatasetSource,
    requested_mode: &QueryMode,
    small_file_threshold_bytes: u64,
) -> AppResult<(QueryMode, Vec<String>)> {
    let explicit = match requested_mode {
        QueryMode::AutoHybrid => {
            return auto_hybrid_mode(dataset, small_file_threshold_bytes);
        }
        QueryMode::ParquetSelective => QueryMode::ParquetSelective,
        QueryMode::FullDownload => QueryMode::FullDownload,
    };
    Ok((explicit, Vec::new()))
}
/// Chooses between full download and parquet-selective execution for hybrid requests.
///
/// Remote datasets (S3 storage or an `http://`/`https://` URI) always use parquet-selective.
/// Local files at or below the threshold use full download; larger files, and files whose
/// metadata cannot be read, use parquet-selective. The single returned note explains the choice.
fn auto_hybrid_mode(
    dataset: &DatasetSource,
    small_file_threshold_bytes: u64,
) -> AppResult<(QueryMode, Vec<String>)> {
    let source_uri = source_to_uri(dataset)?;

    let is_remote = matches!(&dataset.storage, StorageConfig::S3(_))
        || source_uri.starts_with("http://")
        || source_uri.starts_with("https://");
    if is_remote {
        let note = format!(
            "hybrid chose parquet_selective because the dataset is remote; the {} byte threshold only applies to local files.",
            small_file_threshold_bytes
        );
        return Ok((QueryMode::ParquetSelective, vec![note]));
    }

    let (mode, note) = match std::fs::metadata(&source_uri) {
        Err(_) => (
            QueryMode::ParquetSelective,
            format!(
                "hybrid could not read local file metadata, so it fell back to parquet_selective with the {} byte threshold.",
                small_file_threshold_bytes
            ),
        ),
        Ok(metadata) => {
            let size = metadata.len();
            if size <= small_file_threshold_bytes {
                (
                    QueryMode::FullDownload,
                    format!(
                        "hybrid chose full_download because the local parquet file is {} bytes, which is at or below the {} byte threshold.",
                        size, small_file_threshold_bytes
                    ),
                )
            } else {
                (
                    QueryMode::ParquetSelective,
                    format!(
                        "hybrid chose parquet_selective because the local parquet file is {} bytes, which is above the {} byte threshold.",
                        size, small_file_threshold_bytes
                    ),
                )
            }
        }
    };
    Ok((mode, vec![note]))
}
/// Opens an in-memory DuckDB connection, configures storage access, and registers the dataset
/// under `table_name` for the given mode.
///
/// Returns the connection together with any notes produced while configuring the source.
fn open_registered_dataset(
    dataset: &DatasetSource,
    table_name: &str,
    mode: &QueryMode,
) -> AppResult<(Connection, Vec<String>)> {
    let conn = Connection::open_in_memory()?;
    let mut setup_notes: Vec<String> = Vec::new();
    configure_source(&conn, dataset, &mut setup_notes)?;
    register_dataset(&conn, dataset, table_name, mode)?;
    Ok((conn, setup_notes))
}
/// Locks down a connection before user SQL runs on it.
///
/// The settings are issued as one batch, in this order: extension auto-load, auto-install, and
/// community extensions are turned off; the allowed directories and paths computed by
/// `allowed_query_locations` are set (each only when non-empty); external access is disabled;
/// and finally the configuration is locked. A note describing the hardening is appended to
/// `notes` on success.
///
/// # Errors
/// Propagates errors from computing the allowed locations and from executing the batch.
fn harden_query_connection(
conn: &Connection,
dataset: &DatasetSource,
notes: &mut Vec<String>,
) -> AppResult<()> {
let (allowed_paths, allowed_directories) = allowed_query_locations(dataset)?;
let mut statements = vec![
"SET autoload_known_extensions = false".to_string(),
"SET autoinstall_known_extensions = false".to_string(),
"SET allow_community_extensions = false".to_string(),
];
if !allowed_directories.is_empty() {
statements.push(format!(
"SET allowed_directories = {}",
sql_string_list(&allowed_directories)
));
}
if !allowed_paths.is_empty() {
statements.push(format!(
"SET allowed_paths = {}",
sql_string_list(&allowed_paths)
));
}
// The allow-lists are set before external access is switched off, and the configuration is
// locked last so none of the statements above are rejected.
statements.push("SET enable_external_access = false".to_string());
statements.push("SET lock_configuration = true".to_string());
conn.execute_batch(&statements.join(";"))?;
notes.push(
"Applied DuckDB defense-in-depth query hardening: disabled extension auto-load/install, restricted external access to the configured dataset location, and locked configuration changes before running user SQL.".to_string(),
);
Ok(())
}
fn allowed_query_locations(dataset: &DatasetSource) -> AppResult<(Vec<String>, Vec<String>)> {
let source_uri = source_to_uri(dataset)?;
let mut allowed_paths = vec![source_uri.clone()];
let mut allowed_directories = Vec::new();
let source_uses_glob = path_uses_glob(&source_uri);
match &dataset.storage {
StorageConfig::Local => {
let source_path = PathBuf::from(&source_uri);
if source_uses_glob {
if let Some(parent) = source_path.parent().filter(|p| !p.as_os_str().is_empty()) {
allowed_directories.push(parent.to_string_lossy().into_owned());
}
}
if let Ok(canonical) = std::fs::canonicalize(&source_path) {
allowed_paths.push(canonical.to_string_lossy().into_owned());
if source_uses_glob {
if let Some(parent) = canonical.parent() {
allowed_directories.push(parent.to_string_lossy().into_owned());
}
}
} else if source_uses_glob {
if let Some(parent) = source_path.parent().filter(|p| !p.as_os_str().is_empty()) {
allowed_directories.push(parent.to_string_lossy().into_owned());
}
}
}
StorageConfig::S3(_) => {
if source_uses_glob {
if let Some(prefix) = s3_prefix(&source_uri) {
allowed_directories.push(prefix);
}
}
}
}
dedupe_strings(&mut allowed_paths);
dedupe_strings(&mut allowed_directories);
Ok((allowed_paths, allowed_directories))
}
/// Reports whether `value` contains any of the glob characters `*`, `?`, or `[`.
fn path_uses_glob(value: &str) -> bool {
    value.chars().any(|ch| matches!(ch, '*' | '?' | '['))
}
/// Leaves `values` sorted in ascending order with duplicates removed.
fn dedupe_strings(values: &mut Vec<String>) {
    // An ordered set yields exactly the sorted, unique sequence.
    let unique: BTreeSet<String> = values.drain(..).collect();
    values.extend(unique);
}
/// Returns the URI truncated just after its last `/` (scheme included).
///
/// Yields `None` when the URI has no `://` separator or no `/` after it.
fn s3_prefix(uri: &str) -> Option<String> {
    let (scheme, rest) = uri.split_once("://")?;
    rest.rfind('/').map(|last_slash| {
        let directory = &rest[..=last_slash];
        format!("{scheme}://{directory}")
    })
}
/// Extracts the timing, memory delta, and row count of a query response as a measurement.
fn benchmark_measurement(response: &QueryResponse) -> BenchmarkMeasurement {
    let elapsed_ms = response.elapsed_ms;
    let memory_delta_bytes = response.memory_delta_bytes;
    let result_rows = response.row_count;
    BenchmarkMeasurement {
        elapsed_ms,
        memory_delta_bytes,
        result_rows,
    }
}
/// Prepares the connection for the dataset's storage backend.
///
/// For S3 storage this installs and loads the `httpfs` extension and creates (or replaces) the
/// secret `text2sql_s3` from the provided options; only options that are present are included,
/// while `URL_STYLE` and `USE_SSL` are always set. A note is appended to `notes` afterwards.
/// For any other storage nothing is done.
///
/// # Errors
/// Propagates DuckDB errors from loading the extension or creating the secret.
fn configure_source(
conn: &Connection,
source: &DatasetSource,
notes: &mut Vec<String>,
) -> AppResult<()> {
if let StorageConfig::S3(options) = &source.storage {
conn.execute_batch("INSTALL httpfs; LOAD httpfs;")?;
// Each entry is one `KEY value` pair of the CREATE SECRET statement.
let mut parts = vec!["TYPE s3".to_string()];
if let Some(region) = &options.region {
parts.push(format!("REGION {}", sql_string_literal(region)));
}
if let Some(endpoint) = &options.endpoint {
// The endpoint is passed without its http(s) scheme or trailing slash.
parts.push(format!(
"ENDPOINT {}",
sql_string_literal(&duckdb_endpoint(endpoint))
));
}
if let Some(access_key_id) = &options.access_key_id {
parts.push(format!("KEY_ID {}", sql_string_literal(access_key_id)));
}
if let Some(secret_access_key) = &options.secret_access_key {
parts.push(format!("SECRET {}", sql_string_literal(secret_access_key)));
}
if let Some(session_token) = &options.session_token {
parts.push(format!(
"SESSION_TOKEN {}",
sql_string_literal(session_token)
));
}
parts.push(format!(
"URL_STYLE {}",
sql_string_literal(if options.force_path_style {
"path"
} else {
"vhost"
})
));
// SSL is on unless plain HTTP was explicitly allowed.
parts.push(format!(
"USE_SSL {}",
if options.allow_http { "false" } else { "true" }
));
let sql = format!(
"CREATE OR REPLACE SECRET text2sql_s3 ({})",
parts.join(", ")
);
conn.execute_batch(&sql)?;
notes.push(
"Loaded DuckDB httpfs and configured an S3 secret for remote parquet access."
.to_string(),
);
}
Ok(())
}
/// Normalizes an endpoint by trimming whitespace and trailing slashes and dropping a leading
/// `http://` or `https://` scheme.
fn duckdb_endpoint(endpoint: &str) -> String {
    let base = endpoint.trim().trim_end_matches('/');
    let host = ["http://", "https://"]
        .iter()
        .find_map(|scheme| base.strip_prefix(*scheme))
        .unwrap_or(base);
    host.to_owned()
}
/// Registers the dataset under `table_name` as a `read_parquet(...)` source.
///
/// `FullDownload` materializes a temporary table; `ParquetSelective` and `AutoHybrid` both
/// create a view.
fn register_dataset(
    conn: &Connection,
    dataset: &DatasetSource,
    table_name: &str,
    mode: &QueryMode,
) -> AppResult<()> {
    let source_uri = source_to_uri(dataset)?;
    let relation_kind = match mode {
        QueryMode::ParquetSelective | QueryMode::AutoHybrid => "VIEW",
        QueryMode::FullDownload => "TEMP TABLE",
    };
    let statement = format!(
        "CREATE OR REPLACE {} {} AS SELECT * FROM read_parquet({})",
        relation_kind,
        sql_identifier(table_name),
        sql_string_literal(&source_uri)
    );
    conn.execute_batch(&statement)?;
    Ok(())
}
/// Returns the column names and types of `table_name` using a zero-row `SELECT *`.
fn collect_schema(conn: &Connection, table_name: &str) -> AppResult<Vec<ColumnSummary>> {
    let quoted = sql_identifier(table_name);
    collect_columns(conn, &format!("SELECT * FROM {quoted} LIMIT 0"))
}
/// Returns at most `limit` rows of `table_name` as JSON objects.
fn collect_preview_rows(
    conn: &Connection,
    table_name: &str,
    limit: usize,
) -> AppResult<Vec<Value>> {
    let quoted = sql_identifier(table_name);
    let preview_sql = format!("SELECT * FROM {quoted} LIMIT {limit}");
    collect_rows(conn, &preview_sql).map(|(_, rows)| rows)
}
/// Executes `sql` and returns the name and type of every result column.
///
/// # Errors
/// Returns `AppError::Validation` when the executed rows expose no statement metadata, and
/// propagates DuckDB errors.
fn collect_columns(conn: &Connection, sql: &str) -> AppResult<Vec<ColumnSummary>> {
    let mut stmt = conn.prepare(sql)?;
    let rows = stmt.query([])?;
    let Some(statement) = rows.as_ref() else {
        return Err(AppError::Validation(
            "query returned no statement metadata".to_string(),
        ));
    };
    (0..statement.column_count())
        .map(|index| -> AppResult<ColumnSummary> {
            Ok(ColumnSummary {
                name: statement.column_name(index)?.to_string(),
                duckdb_type: statement.column_type(index).to_string(),
            })
        })
        .collect()
}
fn collect_rows(conn: &Connection, sql: &str) -> AppResult<(Vec<ColumnSummary>, Vec<Value>)> {
let mut stmt = conn.prepare(sql)?;
let columns = collect_columns(conn, sql)?;
let mut rows = stmt.query([])?;
let mut out = Vec::new();
while let Some(row) = rows.next()? {
let mut object = Map::new();
for (index, column) in columns.iter().enumerate() {
let value = row.get_ref(index)?;
object.insert(column.name.clone(), value_ref_to_json(value)?);
}
out.push(Value::Object(object));
}
Ok((columns, out))
}
/// Converts one DuckDB value into a JSON value.
///
/// Booleans and the integer types up to 64 bits become JSON booleans/numbers. Floats become
/// numbers, or `null` when `Number::from_f64` rejects the value. Huge integers, decimals,
/// timestamps, dates, and times are rendered with `to_string()`; blobs become a `0x`-prefixed
/// lowercase hex string; intervals are spelled out in months, days, and nanos. Any other
/// variant falls back to its `Debug` representation.
///
/// # Errors
/// Fails when a text value is not valid UTF-8.
fn value_ref_to_json(value: ValueRef<'_>) -> AppResult<Value> {
Ok(match value {
ValueRef::Null => Value::Null,
ValueRef::Boolean(v) => Value::Bool(v),
ValueRef::TinyInt(v) => Value::Number(v.into()),
ValueRef::SmallInt(v) => Value::Number(v.into()),
ValueRef::Int(v) => Value::Number(v.into()),
ValueRef::BigInt(v) => Value::Number(v.into()),
// Emitted as a string rather than a JSON number.
ValueRef::HugeInt(v) => Value::String(v.to_string()),
ValueRef::UTinyInt(v) => Value::Number(v.into()),
ValueRef::USmallInt(v) => Value::Number(v.into()),
ValueRef::UInt(v) => Value::Number(v.into()),
ValueRef::UBigInt(v) => Value::Number(Number::from(v)),
ValueRef::Float(v) => Number::from_f64(v as f64)
.map(Value::Number)
.unwrap_or(Value::Null),
ValueRef::Double(v) => Number::from_f64(v)
.map(Value::Number)
.unwrap_or(Value::Null),
ValueRef::Decimal(v) => Value::String(v.to_string()),
// The unit component is ignored; only the inner value is stringified.
ValueRef::Timestamp(_, v) => Value::String(v.to_string()),
ValueRef::Text(bytes) => Value::String(String::from_utf8(bytes.to_vec())?),
ValueRef::Blob(bytes) => Value::String(format!("0x{}", hex_string(bytes))),
ValueRef::Date32(v) => Value::String(v.to_string()),
ValueRef::Time64(_, v) => Value::String(v.to_string()),
ValueRef::Interval {
months,
days,
nanos,
} => Value::String(format!("{months} months, {days} days, {nanos} nanos")),
// Catch-all for variants without a dedicated mapping.
other => Value::String(format!("{other:?}")),
})
}
/// Dispatches on the input file extension and writes a normalized CSV to `output_path`.
///
/// Supported extensions (case-insensitive): `csv`, `json`/`ndjson`/`jsonl`, `xls`/`xlsx`.
/// Returns the header names written to the CSV.
///
/// # Errors
/// Returns `AppError::Validation` when a non-blank `csv_encoding` is given for a non-CSV input
/// or when the extension is unsupported; propagates converter errors otherwise.
fn write_normalized_csv(
    input_path: &Path,
    output_path: &Path,
    normalize_columns: bool,
    sheet_name: Option<&str>,
    csv_encoding: Option<&str>,
) -> AppResult<Vec<String>> {
    let extension = input_path
        .extension()
        .and_then(|ext| ext.to_str())
        .unwrap_or_default()
        .to_ascii_lowercase();

    let encoding_requested = matches!(csv_encoding, Some(label) if !label.trim().is_empty());
    if encoding_requested && extension != "csv" {
        return Err(AppError::Validation(
            "csv_encoding is only supported for CSV input".to_string(),
        ));
    }

    match extension.as_str() {
        "csv" => convert_csv_to_csv(input_path, output_path, normalize_columns, csv_encoding),
        "json" | "ndjson" | "jsonl" => {
            convert_json_to_csv(input_path, output_path, normalize_columns)
        }
        "xls" | "xlsx" => {
            convert_spreadsheet_to_csv(input_path, output_path, normalize_columns, sheet_name)
        }
        other => Err(AppError::Validation(format!(
            "unsupported input extension: {other}"
        ))),
    }
}
fn convert_csv_to_csv(
input_path: &Path,
output_path: &Path,
normalize_columns: bool,
csv_encoding: Option<&str>,
) -> AppResult<Vec<String>> {
if let Some(csv_encoding) = csv_encoding.map(str::trim).filter(|e| !e.is_empty()) {
let mut reader = csv_reader_from_encoded_path(input_path, csv_encoding)?;
let headers = reader.headers()?.clone();
let normalized = normalize_headers(&headers, normalize_columns);
let mut writer = Writer::from_path(output_path)?;
writer.write_record(&normalized)?;
for record in reader.records() {
writer.write_record(&record?)?;
}
writer.flush()?;
Ok(normalized)
} else {
let mut reader = ReaderBuilder::new().flexible(true).from_path(input_path)?;
let headers = reader.headers()?.clone();
let normalized = normalize_headers(&headers, normalize_columns);
let mut writer = Writer::from_path(output_path)?;
writer.write_record(&normalized)?;
for record in reader.records() {
writer.write_record(&record?)?;
}
writer.flush()?;
Ok(normalized)
}
}
/// Reads the whole file, decodes it with `csv_encoding`, strips a leading BOM, and returns a
/// flexible CSV reader over the decoded bytes.
fn csv_reader_from_encoded_path(
    input_path: &Path,
    csv_encoding: &str,
) -> AppResult<csv::Reader<Cursor<Vec<u8>>>> {
    let raw_bytes = std::fs::read(input_path)?;
    let decoded = decode_csv_bytes(&raw_bytes, csv_encoding)?;
    let text = strip_utf8_bom(&decoded);
    let cursor = Cursor::new(String::from(text).into_bytes());
    let mut builder = ReaderBuilder::new();
    builder.flexible(true);
    Ok(builder.from_reader(cursor))
}
/// Decodes `bytes` using the encoding named by `csv_encoding` (label is trimmed first).
///
/// Decoding is strict: no BOM handling and no replacement characters.
///
/// # Errors
/// Returns `AppError::Validation` for an unknown label or for bytes that are not valid in the
/// chosen encoding.
fn decode_csv_bytes<'a>(bytes: &'a [u8], csv_encoding: &str) -> AppResult<Cow<'a, str>> {
    let label = csv_encoding.trim().as_bytes();
    let Some(encoding) = Encoding::for_label_no_replacement(label) else {
        return Err(AppError::Validation(format!(
            "unsupported csv_encoding: {csv_encoding}"
        )));
    };
    match encoding.decode_without_bom_handling_and_without_replacement(bytes) {
        Some(text) => Ok(text),
        None => Err(AppError::Validation(format!(
            "input bytes could not be decoded as {}",
            encoding.name()
        ))),
    }
}
fn convert_json_to_csv(
input_path: &Path,
output_path: &Path,
normalize_columns: bool,
) -> AppResult<Vec<String>> {
let values = read_json_rows(input_path)?;
let mut keys = BTreeSet::new();
for value in &values {
let object = value.as_object().ok_or_else(|| {
AppError::Validation(
"JSON input must be an array of objects, newline-delimited objects, or an object whose values are objects".to_string(),
)
})?;
keys.extend(object.keys().cloned());
}
let ordered: Vec<String> = keys.into_iter().collect();
let normalized = normalize_string_vec(&ordered, normalize_columns);
let mut writer = Writer::from_path(output_path)?;
writer.write_record(&normalized)?;
for value in &values {
let object = value.as_object().ok_or_else(|| {
AppError::Validation("JSON input must contain only objects".to_string())
})?;
let row = ordered
.iter()
.map(|key| object.get(key).map(json_value_to_cell).unwrap_or_default())
.collect::<Vec<_>>();
writer.write_record(&row)?;
}
writer.flush()?;
Ok(normalized)
}
/// Reads the file as text, drops a leading BOM, and parses it according to the input mode
/// implied by the file extension.
fn read_json_rows(input_path: &Path) -> AppResult<Vec<Value>> {
    let raw = std::fs::read_to_string(input_path)?;
    let text = strip_utf8_bom(&raw);
    let mode = json_input_mode(input_path);
    if mode == JsonInputMode::LineDelimited {
        parse_line_delimited_json(text)
    } else {
        parse_json_document_or_lines(text)
    }
}
/// How a JSON input file should be parsed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum JsonInputMode {
/// Decide from the content: a whole JSON document, falling back to one value per line.
Auto,
/// Treat every non-blank line as an independent JSON value.
LineDelimited,
}
/// Picks the JSON parsing mode from the file extension.
///
/// `ndjson` and `jsonl` (case-insensitive) select line-delimited parsing; anything else,
/// including a missing extension, selects automatic detection.
fn json_input_mode(input_path: &Path) -> JsonInputMode {
    let extension = input_path
        .extension()
        .and_then(|ext| ext.to_str())
        .unwrap_or_default();
    let line_delimited = ["ndjson", "jsonl"]
        .iter()
        .any(|known| extension.eq_ignore_ascii_case(known));
    if line_delimited {
        JsonInputMode::LineDelimited
    } else {
        JsonInputMode::Auto
    }
}
/// Parses content that may be a JSON array, a single JSON object, or line-delimited JSON.
///
/// Content starting with `[` must parse as an array. Content starting with `{` is tried as
/// one document and, if that fails, as line-delimited JSON. Anything else is treated as
/// line-delimited JSON.
fn parse_json_document_or_lines(content: &str) -> AppResult<Vec<Value>> {
    let trimmed = content.trim_start();
    match trimmed.chars().next() {
        Some('[') => {
            let rows: Vec<Value> = serde_json::from_str(trimmed)?;
            Ok(rows)
        }
        Some('{') => match serde_json::from_str::<Value>(trimmed) {
            Ok(root) => normalize_json_root(root),
            Err(_) => parse_line_delimited_json(trimmed),
        },
        _ => parse_line_delimited_json(trimmed),
    }
}
/// Parses every non-blank line of `content` as one JSON value, stopping at the first error.
fn parse_line_delimited_json(content: &str) -> AppResult<Vec<Value>> {
    let mut rows = Vec::new();
    for line in content.lines() {
        let line = line.trim();
        if line.is_empty() {
            continue;
        }
        let value: Value = serde_json::from_str(line)?;
        rows.push(value);
    }
    Ok(rows)
}
/// Returns `content` without a single leading U+FEFF byte-order mark, if one is present.
fn strip_utf8_bom(content: &str) -> &str {
    match content.strip_prefix('\u{feff}') {
        Some(rest) => rest,
        None => content,
    }
}
/// Writes one worksheet of a spreadsheet to CSV, using its first row as headers.
///
/// The sheet is `sheet_name` when given, otherwise the workbook's first sheet. Returns the
/// (optionally normalized) header names.
///
/// # Errors
/// Returns `AppError::Validation` when the workbook has no sheets or the sheet has no rows;
/// propagates workbook and write errors.
fn convert_spreadsheet_to_csv(
    input_path: &Path,
    output_path: &Path,
    normalize_columns: bool,
    sheet_name: Option<&str>,
) -> AppResult<Vec<String>> {
    let mut workbook = open_workbook_auto(input_path)?;
    let selected_sheet = match sheet_name {
        Some(name) => name.to_owned(),
        None => workbook
            .sheet_names()
            .first()
            .cloned()
            .ok_or_else(|| AppError::Validation("spreadsheet has no sheets".to_string()))?,
    };
    let range = workbook.worksheet_range(&selected_sheet)?;

    let mut rows = range.rows();
    let Some(header_row) = rows.next() else {
        return Err(AppError::Validation("spreadsheet is empty".to_string()));
    };
    let headers: Vec<String> = header_row.iter().map(cell_to_string).collect();
    let normalized = normalize_string_vec(&headers, normalize_columns);

    let mut writer = Writer::from_path(output_path)?;
    writer.write_record(&normalized)?;
    for row in rows {
        let cells: Vec<String> = row.iter().map(cell_to_string).collect();
        writer.write_record(&cells)?;
    }
    writer.flush()?;
    Ok(normalized)
}
/// Turns a parsed JSON document into a list of rows.
///
/// - An array is returned as-is.
/// - An object whose values are all objects becomes one row per entry, with the entry's key
///   stored under `entry_key` (replacing any existing field of that name).
/// - Any other object becomes a single row.
///
/// # Errors
/// Returns `AppError::Validation` for any other root value.
fn normalize_json_root(root: Value) -> AppResult<Vec<Value>> {
    match root {
        Value::Array(values) => Ok(values),
        Value::Object(object_map) if object_map.values().all(Value::is_object) => object_map
            .into_iter()
            .map(|(key, value)| match value {
                Value::Object(mut fields) => {
                    fields.insert("entry_key".to_string(), Value::String(key));
                    Ok(Value::Object(fields))
                }
                _ => Err(AppError::Validation(
                    "expected nested JSON objects".to_string(),
                )),
            })
            .collect(),
        Value::Object(object_map) => Ok(vec![Value::Object(object_map)]),
        _ => Err(AppError::Validation(
            "JSON input must be an array of objects, newline-delimited objects, or an object whose values are objects".to_string(),
        )),
    }
}
/// Converts a CSV header record to owned strings and normalizes them via
/// `normalize_string_vec`.
fn normalize_headers(headers: &StringRecord, normalize_columns: bool) -> Vec<String> {
    let names: Vec<String> = headers.iter().map(str::to_owned).collect();
    normalize_string_vec(&names, normalize_columns)
}
/// Optionally cleans each header into an identifier, then makes the list unique.
fn normalize_string_vec(headers: &[String], normalize_columns: bool) -> Vec<String> {
    let mut prepared = Vec::with_capacity(headers.len());
    for header in headers {
        let name = if normalize_columns {
            clean_identifier(header)
        } else {
            header.clone()
        };
        prepared.push(name);
    }
    dedupe_headers(prepared)
}
/// Makes header names unique while preserving their order.
///
/// Blank headers are named `column`. The first occurrence of a name is kept unchanged; later
/// occurrences receive a `_<n>` suffix where `n` counts occurrences of that name. If a
/// suffixed candidate is already taken (for example the input also contains a literal `a_2`),
/// the counter keeps advancing until an unused name is found, so the output never contains
/// duplicates.
fn dedupe_headers(headers: Vec<String>) -> Vec<String> {
    let mut occurrences = BTreeMap::<String, usize>::new();
    // Every name already handed out, including suffixed ones.
    let mut emitted = BTreeSet::<String>::new();
    headers
        .into_iter()
        .map(|header| {
            let base = if header.trim().is_empty() {
                "column".to_string()
            } else {
                header
            };
            let count = occurrences.entry(base.clone()).or_insert(0);
            loop {
                *count += 1;
                let candidate = if *count == 1 {
                    base.clone()
                } else {
                    format!("{}_{}", base, count)
                };
                if emitted.insert(candidate.clone()) {
                    break candidate;
                }
            }
        })
        .collect()
}
/// Turns arbitrary text into an identifier-like name.
///
/// Steps: trim whitespace; replace every non-alphanumeric character with `_`; trim leading
/// and trailing `_`; fall back to `column` when nothing is left; prefix `col_` when the name
/// starts with an ASCII digit; append `_col` when the name equals a reserved word
/// (ASCII case-insensitive).
fn clean_identifier(value: &str) -> String {
    const RESERVED: &[&str] = &[
        "order", "group", "by", "where", "from", "select", "insert", "delete", "update", "join",
        "table",
    ];

    let underscored: String = value
        .trim()
        .chars()
        .map(|ch| if ch.is_alphanumeric() { ch } else { '_' })
        .collect();
    let mut cleaned = underscored.trim_matches('_').to_owned();

    if cleaned.is_empty() {
        cleaned.push_str("column");
    }
    if cleaned.starts_with(|first: char| first.is_ascii_digit()) {
        cleaned.insert_str(0, "col_");
    }
    if RESERVED
        .iter()
        .any(|word| word.eq_ignore_ascii_case(&cleaned))
    {
        cleaned.push_str("_col");
    }
    cleaned
}
/// Renders one spreadsheet cell as CSV text.
///
/// Empty cells become an empty string, error cells use their `Debug` form, and every other
/// variant uses `to_string()` on its payload.
fn cell_to_string(cell: &Data) -> String {
match cell {
Data::Empty => String::new(),
Data::String(value) => value.to_string(),
Data::Float(value) => value.to_string(),
Data::Int(value) => value.to_string(),
Data::Bool(value) => value.to_string(),
Data::DateTime(value) => value.to_string(),
Data::DateTimeIso(value) => value.to_string(),
Data::DurationIso(value) => value.to_string(),
// Cell errors are kept visible in the output rather than dropped.
Data::Error(value) => format!("{value:?}"),
}
}
/// Renders one JSON value as CSV text.
///
/// `null` becomes an empty cell, strings are written without quotes, booleans and numbers use
/// their textual form, and arrays/objects are written as JSON text.
fn json_value_to_cell(value: &Value) -> String {
    if value.is_null() {
        return String::new();
    }
    if let Value::String(text) = value {
        return text.clone();
    }
    match value {
        Value::Bool(flag) => flag.to_string(),
        Value::Number(number) => number.to_string(),
        nested => nested.to_string(),
    }
}
/// Wraps `sql` in an outer `SELECT ... LIMIT n` when a limit is given; otherwise returns it
/// unchanged.
///
/// Trailing whitespace and semicolons are removed before wrapping, because a `;` inside the
/// parenthesised subquery is a syntax error even though a trailing `;` is accepted by query
/// validation. The closing parenthesis is placed on its own line so that a trailing `--`
/// line comment in `sql` cannot comment it out.
fn apply_limit(sql: &str, limit: Option<usize>) -> String {
    match limit {
        Some(limit) => {
            let inner = sql.trim_end_matches(|c: char| c == ';' || c.is_whitespace());
            format!("SELECT * FROM ({inner}\n) AS __text2sql_limit LIMIT {limit}")
        }
        None => sql.to_string(),
    }
}
/// Returns the memory figure `sysinfo` reports for the current process, or 0 when the process
/// cannot be found after the refresh.
fn current_memory_bytes() -> u64 {
    let own_pid = Pid::from_u32(std::process::id());
    let mut system = System::new();
    // Only this process is refreshed rather than the whole process table.
    system.refresh_processes(ProcessesToUpdate::Some(&[own_pid]), true);
    match system.process(own_pid) {
        Some(process) => process.memory(),
        None => 0,
    }
}
/// Process-wide sequence number that keeps temporary file names distinct within one process.
static TEMP_FILE_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Builds a path in the system temp directory named
/// `text2sql-<pid>-<nanos since epoch>-<sequence>.csv`.
///
/// The file itself is not created. The timestamp falls back to 0 if the clock reads earlier
/// than the Unix epoch.
fn unique_temp_csv_path() -> PathBuf {
    let sequence = TEMP_FILE_COUNTER.fetch_add(1, Ordering::Relaxed);
    let nanos = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };
    let pid = std::process::id();
    let file_name = format!("text2sql-{pid}-{nanos}-{sequence}.csv");
    std::env::temp_dir().join(file_name)
}
/// Counts the data records of a CSV file (the header row is not included), failing on the
/// first record that cannot be read.
fn count_csv_rows(path: &Path) -> AppResult<usize> {
    let mut reader = ReaderBuilder::new().flexible(true).from_path(path)?;
    let total = reader
        .records()
        .try_fold(0usize, |seen, record| record.map(|_| seen + 1))?;
    Ok(total)
}
fn ensure_output_allowed(output: &DatasetSource, overwrite: bool) -> AppResult<()> {
if matches!(output.storage, StorageConfig::Local) {
let path = PathBuf::from(&output.uri);
if path.exists() && !overwrite {
return Err(AppError::Validation(format!(
"output already exists: {}",
path.display()
)));
}
}
Ok(())
}
/// Returns the URI of `source` after checking it fits the storage backend.
///
/// Local sources are passed through as-is; S3 sources must use the `s3://`
/// scheme.
///
/// # Errors
/// Returns `AppError::Validation` for an S3 source whose URI does not start
/// with `s3://`.
fn source_to_uri(source: &DatasetSource) -> AppResult<String> {
    let requires_s3_scheme = match &source.storage {
        StorageConfig::Local => false,
        StorageConfig::S3(_) => true,
    };
    if requires_s3_scheme && !source.uri.starts_with("s3://") {
        return Err(AppError::Validation(
            "S3 storage requires an s3://bucket/key URI".to_string(),
        ));
    }
    Ok(source.uri.clone())
}
/// Renders `values` as a bracketed, comma-separated list of quoted SQL string
/// literals, e.g. `['a', 'b']`. An empty slice renders as `[]`.
fn sql_string_list(values: &[String]) -> String {
    let mut rendered = String::from("[");
    for (index, value) in values.iter().enumerate() {
        if index > 0 {
            rendered.push_str(", ");
        }
        rendered.push_str(&sql_string_literal(value));
    }
    rendered.push(']');
    rendered
}
/// Quotes `value` as a single-quoted SQL string literal.
///
/// Embedded single quotes are doubled, which is the only escaping a plain
/// `'...'` literal needs. Backslashes are deliberately left untouched: DuckDB
/// treats them as ordinary characters in plain literals (backslash escapes
/// only apply to `E'...'` strings), so doubling them would change the value,
/// for example turning the path `C:\data\x.parquet` into one with doubled
/// separators.
fn sql_string_literal(value: &str) -> String {
    format!("'{}'", value.replace('\'', "''"))
}
/// Quotes `value` as a double-quoted SQL identifier, doubling any embedded
/// double quotes.
fn sql_identifier(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('"');
    for ch in value.chars() {
        if ch == '"' {
            // An embedded quote is written twice.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte.
fn hex_string(bytes: &[u8]) -> String {
    let mut encoded = String::with_capacity(bytes.len() * 2);
    encoded.extend(
        bytes
            .iter()
            // High nibble first, then low nibble.
            .flat_map(|&byte| [byte >> 4, byte & 0x0f])
            .map(|nibble| {
                char::from_digit(u32::from(nibble), 16).expect("a nibble is always below 16")
            }),
    );
    encoded
}
/// Writes a synthetic Parquet file with `row_count` rows to `path`.
///
/// The data is produced by an in-memory DuckDB connection running a single
/// `COPY (SELECT ... FROM range(row_count)) TO <path>` statement. The columns
/// are `id`, `symbol`, `trade_day`, `trade_month`, `close_price`, `volume`
/// and `trade_value`, each derived from the row index.
///
/// # Errors
/// Propagates any DuckDB error from opening the connection or running the
/// statement.
fn generate_benchmark_parquet(path: &Path, row_count: usize) -> AppResult<()> {
    let target = sql_string_literal(&path.to_string_lossy());
    let copy_statement = format!(
        "COPY (SELECT i AS id, CASE i % 4 WHEN 0 THEN '레이' WHEN 1 THEN '코스피' WHEN 2 THEN '코스닥' ELSE 'ETF' END AS symbol, CAST((i % 28) + 1 AS INTEGER) AS trade_day, CAST((i % 12) + 1 AS INTEGER) AS trade_month, ROUND(100 + (i % 100) * 1.17, 2) AS close_price, 1000 + (i % 5000) AS volume, ROUND((100 + (i % 100) * 1.17) * (1000 + (i % 5000)), 2) AS trade_value FROM range({row_count}) AS t(i)) TO {} (FORMAT parquet, COMPRESSION zstd, ROW_GROUP_SIZE 10000)",
        target
    );
    let connection = Connection::open_in_memory()?;
    connection.execute_batch(&copy_statement)?;
    Ok(())
}
/// Turns benchmark results into a human-readable recommendation.
///
/// Parquet-selective is called the clear winner only when, for every result,
/// both its elapsed time and its memory delta are less than or equal to those
/// of the full-download run. An empty slice counts as "every result". Any
/// other outcome yields the "mixed results" message.
fn recommend_strategy(results: &[BenchmarkResult]) -> String {
    let selective_never_loses = results.iter().all(|result| {
        let selective = &result.parquet_selective;
        let full = &result.full_download;
        selective.elapsed_ms <= full.elapsed_ms
            && selective.memory_delta_bytes <= full.memory_delta_bytes
    });
    let message = if selective_never_loses {
        "Parquet-selective execution won or tied on both latency and memory for every tested row count, so it is the recommended default. Use full-download mode only when you deliberately want complete materialization before SQL execution."
    } else {
        "Mixed benchmark results: keep parquet-selective as the default for large/object-storage datasets, but retain full-download mode as a fallback for small files or edge-case debugging."
    };
    message.to_string()
}
#[cfg(test)]
mod tests {
    use std::fs;
    use tempfile::tempdir;
    use super::*;

    /// Resolves a `testdata/<name>` fixture.
    ///
    /// Looks under `CARGO_MANIFEST_DIR` first, then under the current
    /// directory and each of its ancestors. Falls back to the manifest-relative
    /// path (which may not exist) so callers can report where it was expected.
    fn fixture_path(name: &str) -> PathBuf {
        let manifest_candidate = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("testdata")
            .join(name);
        if manifest_candidate.exists() {
            return manifest_candidate;
        }
        let cwd = std::env::current_dir().expect("current dir");
        for dir in cwd.ancestors() {
            let candidate = dir.join("testdata").join(name);
            if candidate.exists() {
                return candidate;
            }
        }
        manifest_candidate
    }

    /// Resolves a fixture and fails the test with `expected <label> at <path>`
    /// when it is missing.
    #[track_caller]
    fn require_fixture(name: &str, label: &str) -> PathBuf {
        let path = fixture_path(name);
        assert!(path.exists(), "expected {label} at {}", path.display());
        path
    }

    /// Describes `path` as a dataset on local storage.
    fn local_dataset(path: &Path) -> DatasetSource {
        DatasetSource {
            uri: path.to_string_lossy().into_owned(),
            storage: StorageConfig::Local,
        }
    }

    /// Generates a synthetic Parquet file named `file_name` inside `dir`.
    fn benchmark_parquet(dir: &Path, file_name: &str, row_count: usize) -> PathBuf {
        let path = dir.join(file_name);
        generate_benchmark_parquet(&path, row_count).unwrap();
        path
    }

    /// Baseline query request used by most tests: table `dataset`,
    /// parquet-selective mode, default small-file threshold, no row limit.
    /// Tests that need other settings override fields with struct-update syntax.
    fn selective_query(sql: &str, parquet: &Path) -> QueryRequest {
        QueryRequest {
            sql: sql.to_string(),
            dataset: local_dataset(parquet),
            table_name: "dataset".to_string(),
            mode: QueryMode::ParquetSelective,
            small_file_threshold_bytes: DEFAULT_SMALL_FILE_THRESHOLD_BYTES,
            limit: None,
        }
    }

    /// Baseline local-to-local conversion request: normalized column names,
    /// no sheet name, no explicit CSV encoding, overwrite enabled, no filename
    /// column.
    fn local_conversion(input: &Path, output: &Path) -> ConvertRequest {
        ConvertRequest {
            input_path: input.to_string_lossy().into_owned(),
            output: local_dataset(output),
            normalize_columns: true,
            sheet_name: None,
            csv_encoding: None,
            overwrite: true,
            add_filename_column: false,
        }
    }

    /// Runs a query that is expected to succeed.
    async fn run_query(request: QueryRequest) -> QueryResponse {
        Text2SqlEngine::new().execute_query(request).await.unwrap()
    }

    /// Runs a conversion that is expected to succeed.
    async fn run_convert(request: ConvertRequest) -> ConvertResponse {
        Text2SqlEngine::new().convert(request).await.unwrap()
    }

    /// Runs a conversion that is expected to fail and returns the error text.
    async fn convert_error_message(request: ConvertRequest) -> String {
        Text2SqlEngine::new()
            .convert(request)
            .await
            .unwrap_err()
            .to_string()
    }

    /// Runs `sql` against a fresh synthetic `sample.parquet` with `row_count`
    /// rows, expects the query to fail, and returns the error text.
    async fn query_error_message(sql: &str, row_count: usize) -> String {
        let dir = tempdir().unwrap();
        let parquet = benchmark_parquet(dir.path(), "sample.parquet", row_count);
        Text2SqlEngine::new()
            .execute_query(selective_query(sql, &parquet))
            .await
            .unwrap_err()
            .to_string()
    }

    #[tokio::test]
    async fn query_engine_handles_local_parquet() {
        let dir = tempdir().unwrap();
        let parquet = benchmark_parquet(dir.path(), "sample.parquet", 128);
        let response = run_query(selective_query(
            "SELECT symbol FROM dataset WHERE trade_month = 9 ORDER BY symbol LIMIT 2",
            &parquet,
        ))
        .await;
        assert!(!response.rows.is_empty());
        assert_eq!(response.columns[0].name, "symbol");
    }

    #[tokio::test]
    async fn auto_hybrid_chooses_full_download_for_small_local_file() {
        let dir = tempdir().unwrap();
        let parquet = benchmark_parquet(dir.path(), "small.parquet", 128);
        let response = run_query(QueryRequest {
            mode: QueryMode::AutoHybrid,
            small_file_threshold_bytes: 10 * 1024 * 1024,
            ..selective_query("SELECT COUNT(*) AS rows FROM dataset", &parquet)
        })
        .await;
        assert!(matches!(response.mode, QueryMode::FullDownload));
        assert!(response
            .notes
            .iter()
            .any(|note| note.contains("hybrid chose full_download")));
    }

    #[tokio::test]
    async fn auto_hybrid_chooses_parquet_selective_above_threshold() {
        let dir = tempdir().unwrap();
        let parquet = benchmark_parquet(dir.path(), "small.parquet", 128);
        // A one-byte threshold puts every real file above the limit.
        let response = run_query(QueryRequest {
            mode: QueryMode::AutoHybrid,
            small_file_threshold_bytes: 1,
            ..selective_query("SELECT COUNT(*) AS rows FROM dataset", &parquet)
        })
        .await;
        assert!(matches!(response.mode, QueryMode::ParquetSelective));
        assert!(response
            .notes
            .iter()
            .any(|note| note.contains("hybrid chose parquet_selective")));
    }

    #[test]
    fn auto_hybrid_keeps_remote_datasets_selective() {
        let (mode, notes) = auto_hybrid_mode(
            &DatasetSource {
                uri: "https://example.com/data.parquet".to_string(),
                storage: StorageConfig::Local,
            },
            10 * 1024 * 1024,
        )
        .unwrap();
        assert!(matches!(mode, QueryMode::ParquetSelective));
        assert!(notes.iter().any(|note| note.contains("dataset is remote")));
    }

    #[tokio::test]
    async fn schema_inspection_handles_local_parquet() {
        let dir = tempdir().unwrap();
        let parquet = benchmark_parquet(dir.path(), "sample.parquet", 128);
        let response = Text2SqlEngine::new()
            .inspect_schema(SchemaRequest {
                dataset: local_dataset(&parquet),
                table_name: "dataset".to_string(),
            })
            .await
            .unwrap();
        assert!(response
            .columns
            .iter()
            .any(|column| column.name == "symbol"));
        assert!(response
            .columns
            .iter()
            .any(|column| column.name == "trade_value"));
        assert!(response
            .notes
            .iter()
            .any(|note| note.contains("before generating SQL")));
    }

    #[tokio::test]
    async fn convert_csv_to_parquet_and_query_it() {
        let dir = tempdir().unwrap();
        let csv_path = dir.path().join("prices.csv");
        fs::write(&csv_path, "종목명,거래대금\n레이,10\n코스피,20\n").unwrap();
        let parquet_path = dir.path().join("prices.parquet");
        let converted = run_convert(local_conversion(&csv_path, &parquet_path)).await;
        assert_eq!(converted.row_count, 2);
        assert!(parquet_path.exists());
        let queried = run_query(selective_query(
            "SELECT 종목명, 거래대금 FROM dataset ORDER BY 거래대금 DESC",
            &parquet_path,
        ))
        .await;
        assert_eq!(queried.row_count, 2);
    }

    #[tokio::test]
    async fn benchmark_generates_requested_row_counts() {
        let dir = tempdir().unwrap();
        let response = Text2SqlEngine::new()
            .benchmark(BenchmarkRequest {
                output_dir: Some(dir.path().to_string_lossy().into_owned()),
                row_counts: vec![10_000, 50_000],
                sql: "SELECT symbol, AVG(trade_value) AS avg_trade_value FROM dataset WHERE trade_month = 9 GROUP BY symbol ORDER BY avg_trade_value DESC".to_string(),
                limit: None,
            })
            .await
            .unwrap();
        assert_eq!(response.results.len(), 2);
        assert_eq!(response.results[0].row_count, 10_000);
        assert_eq!(response.results[1].row_count, 50_000);
    }

    #[test]
    fn recommendation_keeps_parquet_selective_as_default_on_mixed_results() {
        // Selective is slower but uses less memory, so the result is "mixed".
        let results = vec![BenchmarkResult {
            row_count: 10_000,
            parquet_path: "./tmp/benchmark-10000.parquet".to_string(),
            parquet_selective: BenchmarkMeasurement {
                elapsed_ms: 12,
                memory_delta_bytes: 2_048,
                result_rows: 4,
            },
            full_download: BenchmarkMeasurement {
                elapsed_ms: 10,
                memory_delta_bytes: 8_192,
                result_rows: 4,
            },
        }];
        let recommendation = recommend_strategy(&results);
        assert!(recommendation.contains("parquet-selective as the default"));
        assert!(recommendation.contains("full-download mode as a fallback"));
    }

    #[tokio::test]
    async fn real_world_owid_csv_converts_and_queries() {
        let input = require_fixture("owid-covid-latest.csv", "downloaded OWID fixture");
        let dir = tempdir().unwrap();
        let parquet_path = dir.path().join("owid-covid-latest.parquet");
        let converted = run_convert(local_conversion(&input, &parquet_path)).await;
        assert!(converted.row_count > 200);
        assert!(converted.columns.iter().any(|column| column == "location"));
        assert!(converted
            .columns
            .iter()
            .any(|column| column == "population"));
        let queried = run_query(selective_query(
            "SELECT location, population FROM dataset WHERE continent = 'Asia' AND population IS NOT NULL ORDER BY population DESC LIMIT 3",
            &parquet_path,
        ))
        .await;
        assert_eq!(queried.row_count, 3);
        assert_eq!(
            queried.rows[0]["location"],
            Value::String("China".to_string())
        );
    }

    #[tokio::test]
    async fn real_world_owid_json_object_map_converts_and_queries() {
        let input = require_fixture("owid-covid-latest.json", "downloaded OWID JSON fixture");
        let dir = tempdir().unwrap();
        let parquet_path = dir.path().join("owid-covid-latest-json.parquet");
        let converted = run_convert(local_conversion(&input, &parquet_path)).await;
        assert!(converted.row_count > 200);
        assert!(converted.columns.iter().any(|column| column == "entry_key"));
        assert!(converted.columns.iter().any(|column| column == "location"));
        let queried = run_query(selective_query(
            "SELECT entry_key, location FROM dataset WHERE location = 'South Korea' LIMIT 1",
            &parquet_path,
        ))
        .await;
        assert_eq!(queried.row_count, 1);
        assert_eq!(
            queried.rows[0]["entry_key"],
            Value::String("KOR".to_string())
        );
    }

    #[tokio::test]
    async fn canada_government_csv_converts_and_queries() {
        let input = require_fixture(
            "canada-wastewater-aggregate.csv",
            "downloaded Canada fixture",
        );
        let dir = tempdir().unwrap();
        let parquet_path = dir.path().join("canada-wastewater-aggregate.parquet");
        let converted = run_convert(local_conversion(&input, &parquet_path)).await;
        assert!(converted.row_count > 50_000);
        assert!(converted.columns.iter().any(|column| column == "province"));
        assert!(converted.columns.iter().any(|column| column == "measureid"));
        assert!(converted.columns.iter().any(|column| column == "w_avg"));
        let queried = run_query(selective_query(
            "SELECT province, measureid, COUNT(*) AS samples FROM dataset WHERE country = 'Canada' AND province = 'Ontario' AND measureid = 'covN2' GROUP BY province, measureid ORDER BY samples DESC LIMIT 1",
            &parquet_path,
        ))
        .await;
        assert_eq!(queried.row_count, 1);
        assert_eq!(
            queried.rows[0]["province"],
            Value::String("Ontario".to_string())
        );
        assert_eq!(
            queried.rows[0]["measureid"],
            Value::String("covN2".to_string())
        );
    }

    #[tokio::test]
    async fn unsupported_extension_returns_validation_error() {
        let dir = tempdir().unwrap();
        let input = dir.path().join("notes.txt");
        fs::write(&input, "not a supported dataset").unwrap();
        let output = dir.path().join("notes.parquet");
        let message = convert_error_message(local_conversion(&input, &output)).await;
        assert!(message.contains("unsupported input extension"));
    }

    #[tokio::test]
    async fn csv_convert_supports_explicit_non_utf8_encoding() {
        let dir = tempdir().unwrap();
        let input = dir.path().join("cp1252.csv");
        let output = dir.path().join("cp1252.parquet");
        // 0xE9 is "é" in windows-1252 and is not valid UTF-8 on its own.
        fs::write(&input, b"name,city\nAndre,caf\xe9\n").unwrap();
        let converted = run_convert(ConvertRequest {
            csv_encoding: Some("windows-1252".to_string()),
            ..local_conversion(&input, &output)
        })
        .await;
        assert_eq!(converted.row_count, 1);
        assert!(converted
            .notes
            .iter()
            .any(|note| note.contains("windows-1252")));
        let queried = run_query(selective_query("SELECT city FROM dataset LIMIT 1", &output)).await;
        assert_eq!(queried.rows[0]["city"], Value::String("café".to_string()));
    }

    #[tokio::test]
    async fn csv_convert_rejects_unknown_encoding_label() {
        let dir = tempdir().unwrap();
        let input = dir.path().join("sample.csv");
        let output = dir.path().join("sample.parquet");
        fs::write(&input, "name,city\nAndre,cafe\n").unwrap();
        let message = convert_error_message(ConvertRequest {
            csv_encoding: Some("definitely-not-real".to_string()),
            ..local_conversion(&input, &output)
        })
        .await;
        assert!(message.contains("unsupported csv_encoding"));
    }

    #[tokio::test]
    async fn invalid_sql_returns_duckdb_error() {
        let message = query_error_message("SELECT definitely_missing FROM dataset", 128).await;
        assert!(message.contains("duckdb error"));
    }

    #[test]
    fn policy_sql_ignores_strings_and_comments() {
        let sql = "SELECT '-- not comment' AS note, \"copy\" FROM dataset /* read_parquet('/tmp/x') */ -- install spatial";
        let normalized = policy_sql(sql);
        assert!(normalized.starts_with("select"));
        assert!(!normalized.contains("install spatial"));
        assert!(!normalized.contains("read_parquet("));
    }

    #[tokio::test]
    async fn query_rejects_copy_statement() {
        let message =
            query_error_message("COPY (SELECT * FROM dataset) TO '/tmp/exfil.csv'", 8).await;
        assert!(message.contains("only read-only SELECT/WITH"));
    }

    #[tokio::test]
    async fn query_rejects_direct_read_parquet_call() {
        let message = query_error_message(
            "SELECT * FROM read_parquet('/tmp/other.parquet') LIMIT 1",
            8,
        )
        .await;
        assert!(message.contains("direct table-function access"));
    }

    #[tokio::test]
    async fn query_rejects_direct_parquet_scan_call() {
        let message = query_error_message(
            "SELECT * FROM parquet_scan('/tmp/other.parquet') LIMIT 1",
            8,
        )
        .await;
        assert!(message.contains("direct table-function access"));
    }

    #[tokio::test]
    async fn query_rejects_direct_read_blob_call() {
        let message = query_error_message("SELECT read_blob('/tmp/secret.txt')", 8).await;
        assert!(message.contains("direct table-function access"));
    }

    #[tokio::test]
    async fn query_rejects_set_statement() {
        let message = query_error_message("SET memory_limit = '8GB'", 8).await;
        assert!(message.contains("only read-only SELECT/WITH"));
    }

    #[tokio::test]
    async fn query_rejects_multi_statement_sql() {
        let message = query_error_message(
            "CREATE TABLE hack AS SELECT 1; SELECT * FROM dataset LIMIT 1",
            8,
        )
        .await;
        assert!(message.contains("single read-only SELECT/WITH"));
    }

    #[test]
    fn hardened_query_connection_disables_copy_to_even_without_sql_validator() {
        let dir = tempdir().unwrap();
        let parquet = benchmark_parquet(dir.path(), "sample.parquet", 8);
        let output = dir.path().join("exfil.csv");
        let dataset = local_dataset(&parquet);
        let (conn, mut notes) =
            open_registered_dataset(&dataset, "dataset", &QueryMode::ParquetSelective).unwrap();
        harden_query_connection(&conn, &dataset, &mut notes).unwrap();
        // Talk to the connection directly so the SQL validator is bypassed.
        let error = conn
            .execute_batch(&format!(
                "COPY (SELECT * FROM dataset LIMIT 1) TO {}",
                sql_string_literal(&output.to_string_lossy())
            ))
            .unwrap_err();
        assert!(
            error.to_string().contains("disabled by configuration")
                || error.to_string().contains("read_only"),
            "unexpected error: {error}"
        );
        assert!(!output.exists(), "COPY should not create output files");
    }

    #[test]
    fn hardened_query_connection_blocks_other_files_even_without_sql_validator() {
        let dir = tempdir().unwrap();
        let parquet = benchmark_parquet(dir.path(), "sample.parquet", 8);
        let other = benchmark_parquet(dir.path(), "other.parquet", 4);
        let dataset = local_dataset(&parquet);
        let (conn, mut notes) =
            open_registered_dataset(&dataset, "dataset", &QueryMode::ParquetSelective).unwrap();
        harden_query_connection(&conn, &dataset, &mut notes).unwrap();
        // The registered dataset must stay readable after hardening.
        let allowed_rows = collect_rows(&conn, "SELECT COUNT(*) AS rows FROM dataset").unwrap();
        assert_eq!(allowed_rows.1[0]["rows"], Value::Number(Number::from(8u64)));
        // The failure may surface at prepare time or at query time.
        let error = match conn.prepare(&format!(
            "SELECT * FROM read_parquet({}) LIMIT 1",
            sql_string_literal(&other.to_string_lossy())
        )) {
            Ok(mut stmt) => match stmt.query([]) {
                Ok(_) => panic!("direct access to a non-whitelisted parquet file should fail"),
                Err(error) => error,
            },
            Err(error) => error,
        };
        assert!(
            error.to_string().contains("disabled by configuration")
                || error.to_string().contains("Permission Error"),
            "unexpected error: {error}"
        );
    }

    #[tokio::test]
    async fn real_world_usda_xlsx_converts() {
        let input = require_fixture("keyfoods_0708.xlsx", "USDA XLSX fixture");
        let dir = tempdir().unwrap();
        let parquet_path = dir.path().join("keyfoods_0708.parquet");
        let converted = run_convert(local_conversion(&input, &parquet_path)).await;
        assert_eq!(converted.row_count, 538);
        assert!(converted
            .columns
            .iter()
            .any(|column| column == "Long_Description"));
    }

    #[tokio::test]
    async fn json_array_of_objects_converts_and_queries() {
        let dir = tempdir().unwrap();
        let input = dir.path().join("sample.json");
        fs::write(
            &input,
            r#"[{"city":"Seoul","value":10},{"city":"Busan","value":20}]"#,
        )
        .unwrap();
        let parquet_path = dir.path().join("sample-json-array.parquet");
        run_convert(local_conversion(&input, &parquet_path)).await;
        let queried = run_query(selective_query(
            "SELECT city FROM dataset ORDER BY value DESC LIMIT 1",
            &parquet_path,
        ))
        .await;
        assert_eq!(queried.rows[0]["city"], Value::String("Busan".to_string()));
    }

    #[tokio::test]
    async fn ndjson_converts_and_queries() {
        let dir = tempdir().unwrap();
        let input = dir.path().join("sample.ndjson");
        fs::write(
            &input,
            "{\"city\":\"Seoul\",\"value\":10}\n{\"city\":\"Busan\",\"value\":20}\n",
        )
        .unwrap();
        let parquet_path = dir.path().join("sample-ndjson.parquet");
        let converted = run_convert(local_conversion(&input, &parquet_path)).await;
        assert_eq!(converted.row_count, 2);
    }

    #[tokio::test]
    async fn bom_prefixed_ndjson_converts_and_queries() {
        let dir = tempdir().unwrap();
        let input = dir.path().join("sample-bom.ndjson");
        fs::write(
            &input,
            "\u{feff}{\"city\":\"Seoul\",\"value\":10}\n{\"city\":\"Busan\",\"value\":20}\n",
        )
        .unwrap();
        let parquet_path = dir.path().join("sample-bom-ndjson.parquet");
        let converted = run_convert(local_conversion(&input, &parquet_path)).await;
        assert_eq!(converted.row_count, 2);
    }

    #[tokio::test]
    async fn ndjson_extension_keeps_single_nested_object_as_one_row() {
        let dir = tempdir().unwrap();
        let input = dir.path().join("nested.ndjson");
        fs::write(&input, r#"{"city":{"name":"Seoul"},"value":{"amount":10}}"#).unwrap();
        let parquet_path = dir.path().join("nested-ndjson.parquet");
        let converted = run_convert(local_conversion(&input, &parquet_path)).await;
        assert_eq!(converted.row_count, 1);
        assert_eq!(
            converted.columns,
            vec!["city".to_string(), "value".to_string()]
        );
        assert!(!converted.columns.iter().any(|column| column == "entry_key"));
        let queried = run_query(selective_query("SELECT city FROM dataset", &parquet_path)).await;
        assert_eq!(
            queried.rows[0]["city"],
            Value::String(r#"{"name":"Seoul"}"#.to_string())
        );
    }

    #[test]
    fn duckdb_endpoint_strips_scheme_and_trailing_slash() {
        assert_eq!(
            duckdb_endpoint("http://127.0.0.1:19000/"),
            "127.0.0.1:19000"
        );
        assert_eq!(
            duckdb_endpoint("https://minio.example.com"),
            "minio.example.com"
        );
        assert_eq!(
            duckdb_endpoint("minio.example.com:9000"),
            "minio.example.com:9000"
        );
    }

    #[test]
    fn clean_identifier_keeps_korean_text() {
        assert_eq!(clean_identifier("거래 대금"), "거래_대금");
        assert_eq!(clean_identifier("9월"), "col_9월");
    }
}