1use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13use anyhow::anyhow;
14use arrow_json::LineDelimitedWriter;
15use lance::Dataset;
16use lance::datafusion::LanceTableProvider;
17use lance::dataset::udtf::FtsQueryUDTFBuilder;
18use lance::deps::arrow_array::builder::{
19 BooleanBuilder, Float64Builder, Int64Builder, StringBuilder,
20};
21use lance::deps::arrow_array::{Array, LargeBinaryArray, RecordBatch, StringArray};
22use lance::deps::arrow_schema::{ArrowError, DataType};
23use lance::deps::datafusion::arrow::util::pretty::pretty_format_batches;
24use lance::deps::datafusion::datasource::{ViewTable, provider_as_source};
25use lance::deps::datafusion::error::DataFusionError;
26use lance::deps::datafusion::execution::SessionStateBuilder;
27use lance::deps::datafusion::execution::runtime_env::RuntimeEnvBuilder;
28use lance::deps::datafusion::logical_expr::{
29 ColumnarValue, LogicalPlanBuilder, ScalarUDF, Volatility, create_udf,
30};
31use lance::deps::datafusion::prelude::{SQLOptions, SessionConfig, SessionContext, col};
32use lance::deps::datafusion::sql::parser::{DFParser, Statement as DfStatement};
33use lance::deps::datafusion::sql::sqlparser::ast::{SetExpr, Statement as SqlStatement};
34use lance_datafusion::udf::register_functions;
35use parquet::arrow::ArrowWriter;
36use serde_json::{Map as JsonMap, Value as JsonValue, json};
37
/// Memory-pool cap (512 MiB) for the DataFusion runtime built for each query.
const MEM_LIMIT_BYTES: usize = 512 << 20;
/// Wall-clock limit applied while collecting a query's result batches.
const QUERY_TIMEOUT: Duration = Duration::from_secs(30);
/// Size budget, in bytes, for an inline text table or the serialized inline JSON rows;
/// the shown row count is halved until the rendering fits (or only one row is left).
const INLINE_BUDGET_BYTES: usize = 80_000;
/// Largest encoded export (100 MiB) that is returned; bigger exports are rejected.
const MAX_EXPORT_BYTES: usize = 100 << 20;
/// Default number of inline rows. Not referenced in this file; presumably used by
/// callers when no row count is requested — confirm at the call site.
pub const DEFAULT_INLINE_ROWS: usize = 100;
/// Upper bound for inline rows. Not referenced in this file; presumably used by callers
/// to clamp the requested row count — confirm at the call site.
pub const MAX_INLINE_ROWS: usize = 1_000;
53
/// File format used when a query result is exported instead of shown inline.
#[derive(Debug, Clone, Copy)]
pub enum Format {
    Parquet,
    Ndjson,
}

impl Format {
    /// File extension for this format, without the leading dot.
    pub fn ext(self) -> &'static str {
        let (extension, _) = self.descriptor();
        extension
    }

    /// MIME type to advertise for bytes encoded in this format.
    pub fn mime(self) -> &'static str {
        let (_, mime_type) = self.descriptor();
        mime_type
    }

    /// Single source of truth: `(extension, MIME type)` for each format.
    fn descriptor(self) -> (&'static str, &'static str) {
        match self {
            Format::Parquet => ("parquet", "application/vnd.apache.parquet"),
            Format::Ndjson => ("ndjson", "application/x-ndjson"),
        }
    }
}
77
/// How `run` presents a query result.
#[derive(Debug, Clone, Copy)]
pub enum Mode {
    /// Plain-text table (via `pretty_format_batches`), capped at the requested row count
    /// and shrunk to fit `INLINE_BUDGET_BYTES`.
    Inline,
    /// JSON object with row counts, timing, column names and the rows as JSON objects,
    /// capped and shrunk like `Inline`.
    InlineJson,
    /// The full result encoded in the given format. `run` rejects this for EXPLAIN.
    Export(Format),
}
93
/// Open Lance datasets that `run` exposes to SQL.
pub struct Tables {
    /// Registered as the `sessions` table.
    pub sessions: Arc<Dataset>,
    /// Registered as the `messages` table.
    pub messages: Arc<Dataset>,
    /// Exposed as a `parts` view that omits the dataset's `data` column; the `fts`
    /// table function is registered against the full dataset.
    pub parts: Arc<Dataset>,
}
101
/// Successful result of `run`, one variant per `Mode`.
pub enum Outcome {
    /// Summary line plus rendered text table (from `Mode::Inline`).
    Inline(String),
    /// JSON payload with `total_rows`, `shown_rows`, `truncated`, `elapsed_ms`,
    /// `columns`, `rows` and, when rows were omitted, `next_steps`
    /// (from `Mode::InlineJson`).
    InlineJson(JsonValue),
    /// Encoded export (from `Mode::Export`).
    Export {
        /// Encoded file contents; `run` rejects exports larger than `MAX_EXPORT_BYTES`.
        bytes: Vec<u8>,
        /// Encoding used for `bytes`.
        format: Format,
        /// Number of rows encoded.
        rows: usize,
        /// Names of the exported columns (after non-displayable columns were dropped).
        columns: Vec<String>,
    },
}
117
/// Failure returned by `run`.
#[derive(Debug)]
pub enum SqlError {
    /// Problem the caller can fix by changing the query: parse/planning/execution errors
    /// (with recovery hints appended), rejected statements, the `vector` projection
    /// ban, timeouts, and the export size limit.
    Query(String),
    /// Internal failure not caused by the query text: runtime setup, table registration,
    /// result conversion or encoding.
    Infra(anyhow::Error),
}
126
127fn infra(error: ArrowError) -> SqlError {
128 SqlError::Infra(anyhow::Error::new(error))
129}
130
131pub async fn run(
134 tables: &Tables,
135 sql: &str,
136 mode: Mode,
137 inline_rows: usize,
138) -> Result<Outcome, SqlError> {
139 let parsed = parse_and_gate(sql)?;
140 if matches!(parsed.kind, StatementKind::Explain) && matches!(mode, Mode::Export(_)) {
141 return Err(SqlError::Query(
142 "EXPLAIN returns a plan, not a result set; use output=table (or json) to read it"
143 .to_owned(),
144 ));
145 }
146 if projection_mentions_vector(parsed.projection_query()) {
147 return Err(SqlError::Query(
148 "the `vector` column is not selectable from pond_sql_query (it is a \
149 FixedSizeList<f32> embedding, ~600 bytes per row and not useful in a result). \
150 For semantic search use pond_search. Filtering on it is allowed in WHERE \
151 (e.g. `vector IS NOT NULL`)."
152 .to_owned(),
153 ));
154 }
155 let ctx = build_context()?;
156 register(&ctx, tables)?;
157
158 let options = SQLOptions::new()
164 .with_allow_ddl(false)
165 .with_allow_dml(false)
166 .with_allow_statements(matches!(parsed.kind, StatementKind::Explain));
167 let df = ctx
168 .sql_with_options(sql, options)
169 .await
170 .map_err(|error| SqlError::Query(enrich(&format!("SQL error: {error}"))))?;
171
172 let result_schema = Arc::new(df.schema().as_arrow().clone());
175 let started = Instant::now();
176 let collected = tokio::time::timeout(QUERY_TIMEOUT, df.collect())
177 .await
178 .map_err(|_| {
179 SqlError::Query(format!(
180 "query exceeded the {}s limit; add a narrower WHERE or a LIMIT",
181 QUERY_TIMEOUT.as_secs()
182 ))
183 })?
184 .map_err(|error| SqlError::Query(enrich(&format!("SQL error: {error}"))))?;
185 let elapsed = started.elapsed();
186
187 let display: Vec<RecordBatch> = if collected.is_empty() {
188 vec![displayable(&RecordBatch::new_empty(result_schema)).map_err(infra)?]
189 } else {
190 collected
191 .iter()
192 .map(displayable)
193 .collect::<Result<_, _>>()
194 .map_err(infra)?
195 };
196
197 match mode {
198 Mode::Inline => Ok(Outcome::Inline(
199 render_inline(&display, inline_rows, elapsed).map_err(infra)?,
200 )),
201 Mode::InlineJson => Ok(Outcome::InlineJson(render_inline_json(
202 &display,
203 inline_rows,
204 elapsed,
205 )?)),
206 Mode::Export(format) => {
207 let rows = display.iter().map(RecordBatch::num_rows).sum();
208 let columns = display
209 .first()
210 .map(|batch| {
211 batch
212 .schema()
213 .fields()
214 .iter()
215 .map(|field| field.name().clone())
216 .collect::<Vec<_>>()
217 })
218 .unwrap_or_default();
219 let bytes = match format {
220 Format::Parquet => encode_parquet(&display)?,
221 Format::Ndjson => encode_ndjson(&display)?,
222 };
223 if bytes.len() > MAX_EXPORT_BYTES {
224 return Err(SqlError::Query(format!(
225 "export is {} bytes, over the {MAX_EXPORT_BYTES} byte limit; \
226 narrow the query or aggregate",
227 bytes.len()
228 )));
229 }
230 Ok(Outcome::Export {
231 bytes,
232 format,
233 rows,
234 columns,
235 })
236 }
237 }
238}
239
/// Classification of the single statement accepted by `parse_and_gate`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StatementKind {
    /// A plain query (SELECT / WITH ...).
    Query,
    /// EXPLAIN (optionally ANALYZE) wrapping a plain query; `run` refuses to export it.
    Explain,
}
248
249struct ParsedStatement {
255 kind: StatementKind,
256 query: lance::deps::datafusion::sql::sqlparser::ast::Query,
257}
258
259impl ParsedStatement {
260 fn projection_query(&self) -> &lance::deps::datafusion::sql::sqlparser::ast::Query {
261 &self.query
262 }
263}
264
265fn parse_and_gate(sql: &str) -> Result<ParsedStatement, SqlError> {
272 let statements = DFParser::parse_sql(sql)
273 .map_err(|error| SqlError::Query(format!("SQL parse error: {error}")))?;
274 if statements.len() != 1 {
275 return Err(SqlError::Query(
276 "pond_sql_query runs exactly one statement; submit a single SELECT".to_owned(),
277 ));
278 }
279 let Some(front) = statements.front() else {
280 return Err(read_only_rejection());
281 };
282 match front {
283 DfStatement::Statement(boxed) => match boxed.as_ref() {
284 SqlStatement::Query(query) => Ok(ParsedStatement {
285 kind: StatementKind::Query,
286 query: query.as_ref().clone(),
287 }),
288 _ => Err(read_only_rejection()),
289 },
290 DfStatement::Explain(explain) => match explain.statement.as_ref() {
291 DfStatement::Statement(inner) => match inner.as_ref() {
292 SqlStatement::Query(query) => Ok(ParsedStatement {
293 kind: StatementKind::Explain,
294 query: query.as_ref().clone(),
295 }),
296 _ => Err(read_only_rejection()),
297 },
298 _ => Err(read_only_rejection()),
299 },
300 _ => Err(read_only_rejection()),
301 }
302}
303
304fn read_only_rejection() -> SqlError {
305 SqlError::Query(
306 "pond_sql_query is read-only: only a single SELECT/WITH (or EXPLAIN of one) is \
307 allowed (no INSERT/UPDATE/DELETE/CREATE/DROP/COPY/SET)"
308 .to_owned(),
309 )
310}
311
312fn projection_mentions_vector(query: &lance::deps::datafusion::sql::sqlparser::ast::Query) -> bool {
323 walk_set_expr_for_vector(query.body.as_ref())
324}
325
326fn walk_set_expr_for_vector(expr: &SetExpr) -> bool {
327 match expr {
328 SetExpr::Select(select) => select
329 .projection
330 .iter()
331 .any(|item| mentions_vector_token(&item.to_string())),
332 SetExpr::Query(inner) => walk_set_expr_for_vector(inner.body.as_ref()),
333 SetExpr::SetOperation { left, right, .. } => {
334 walk_set_expr_for_vector(left) || walk_set_expr_for_vector(right)
335 }
336 _ => false,
337 }
338}
339
/// Reports whether `text` (the rendered SQL of one projection item) references an
/// identifier named `vector`.
///
/// Matching rules:
/// - The text is split on every character that cannot be part of an identifier. So
///   `m.vector`, `"vector"` and `array_length(vector)` match, while `vectors` and
///   `embedding_vector` do not.
/// - The comparison ignores ASCII case. DataFusion lowercases unquoted identifiers by
///   default, so `SELECT VECTOR` or `m.Vector` resolves to the same column and must be
///   caught too.
/// - Text inside single-quoted string literals is ignored, so an argument such as
///   `json_get_string(options, 'vector')` is not mistaken for the column. Splitting on
///   `'` and keeping the even-indexed segments also handles SQL's doubled-quote escape
///   (`'it''s'`). Backslash-escaped literals (`E'..'`) are not specially handled.
fn mentions_vector_token(text: &str) -> bool {
    text.split('\'')
        .step_by(2)
        .flat_map(|outside_literal| {
            outside_literal.split(|c: char| !c.is_alphanumeric() && c != '_')
        })
        .any(|token| token.eq_ignore_ascii_case("vector"))
}
344
345fn build_context() -> Result<SessionContext, SqlError> {
346 let runtime = RuntimeEnvBuilder::new()
347 .with_memory_limit(MEM_LIMIT_BYTES, 1.0)
348 .build_arc()
349 .map_err(|error| SqlError::Infra(anyhow!("datafusion runtime init failed: {error}")))?;
350 let state = SessionStateBuilder::new()
353 .with_config(SessionConfig::new().with_information_schema(true))
354 .with_runtime_env(runtime)
355 .with_default_features()
356 .build();
357 Ok(SessionContext::new_with_state(state))
358}
359
/// Installs everything queries can reference on `ctx`.
///
/// This covers the `sessions` and `messages` tables, a `parts` view, the `fts` table
/// function, Lance's SQL functions (`register_functions`) and the lenient `json_get_*`
/// UDFs.
///
/// # Errors
/// `SqlError::Infra` if a table cannot be registered or the `parts` view plan cannot
/// be built.
fn register(ctx: &SessionContext, tables: &Tables) -> Result<(), SqlError> {
    // `sessions` and `messages` are exposed unchanged.
    // NOTE(review): the two `false` flags of `LanceTableProvider::new` presumably turn off
    // extra row-id / row-address columns — confirm against the Lance API.
    for (name, dataset) in [
        ("sessions", &tables.sessions),
        ("messages", &tables.messages),
    ] {
        let provider = LanceTableProvider::new(dataset.clone(), false, false);
        ctx.register_table(name, Arc::new(provider))
            .map_err(|error| SqlError::Infra(anyhow!("register table {name}: {error}")))?;
    }
    // `parts` is registered as a view projecting every column except `data`, so `data`
    // is not reachable from SQL (not even via `SELECT *`).
    let provider = LanceTableProvider::new(tables.parts.clone(), false, false);
    let keep: Vec<_> = tables
        .parts
        .schema()
        .fields
        .iter()
        .filter(|field| field.name != "data")
        .map(|field| col(field.name.as_str()))
        .collect();
    let plan = LogicalPlanBuilder::scan("parts", provider_as_source(Arc::new(provider)), None)
        .and_then(|builder| builder.project(keep))
        .and_then(LogicalPlanBuilder::build)
        .map_err(|error| SqlError::Infra(anyhow!("build parts view: {error}")))?;
    ctx.register_table("parts", Arc::new(ViewTable::new(plan, None)))
        .map_err(|error| SqlError::Infra(anyhow!("register table parts: {error}")))?;
    // The `fts` table function is bound to the underlying datasets (the full `parts`
    // dataset, not the view above).
    let fts = FtsQueryUDTFBuilder::builder()
        .register_table("sessions", tables.sessions.clone())
        .register_table("messages", tables.messages.clone())
        .register_table("parts", tables.parts.clone())
        .build();
    ctx.register_udtf("fts", Arc::new(fts));
    register_functions(ctx);
    // Registered after `register_functions` on purpose. If that call installs functions
    // with the same names, these lenient versions replace them.
    // NOTE(review): confirm which names `register_functions` defines.
    for udf in lenient_json_udfs() {
        ctx.register_udf(udf);
    }
    Ok(())
}
408
/// Result type produced by a lenient `json_get_*` UDF; selects which Arrow builder
/// `json_get_lenient` uses.
enum JsonGet {
    /// `json_get_string` -> `Utf8`; non-string values are emitted as JSON text.
    Text,
    /// `json_get_int` -> `Int64`.
    Int,
    /// `json_get_float` -> `Float64`.
    Float,
    /// `json_get_bool` -> `Boolean`.
    Bool,
}
416
417fn lenient_json_udfs() -> [ScalarUDF; 4] {
425 let make = |name: &str, kind: JsonGet, return_type: DataType| {
426 create_udf(
427 name,
428 vec![DataType::LargeBinary, DataType::Utf8],
429 return_type,
430 Volatility::Immutable,
431 Arc::new(move |args: &[ColumnarValue]| json_get_lenient(args, &kind)),
432 )
433 };
434 [
435 make("json_get_string", JsonGet::Text, DataType::Utf8),
436 make("json_get_int", JsonGet::Int, DataType::Int64),
437 make("json_get_float", JsonGet::Float, DataType::Float64),
438 make("json_get_bool", JsonGet::Bool, DataType::Boolean),
439 ]
440}
441
442fn json_get_lenient(
443 args: &[ColumnarValue],
444 kind: &JsonGet,
445) -> Result<ColumnarValue, DataFusionError> {
446 let arrays = ColumnarValue::values_to_arrays(args)?;
447 let [jsonb_arg, key_arg] = arrays.as_slice() else {
448 return Err(DataFusionError::Execution(
449 "json_get_* takes exactly (json_column, 'key')".to_owned(),
450 ));
451 };
452 let jsonb_array = jsonb_arg
453 .as_any()
454 .downcast_ref::<LargeBinaryArray>()
455 .ok_or_else(|| {
456 DataFusionError::Execution(
457 "json_get_* argument 1 must be a JSON column (variant_data, options)".to_owned(),
458 )
459 })?;
460 let key_array = key_arg
461 .as_any()
462 .downcast_ref::<StringArray>()
463 .ok_or_else(|| {
464 DataFusionError::Execution("json_get_* argument 2 must be a string key".to_owned())
465 })?;
466
467 let field = |row: usize| -> Option<jsonb::OwnedJsonb> {
468 if jsonb_array.is_null(row) || key_array.is_null(row) {
469 return None;
470 }
471 let raw = jsonb::RawJsonb::new(jsonb_array.value(row));
472 let key = key_array.value(row);
473 let value = if raw.is_object().unwrap_or(false) {
474 raw.get_by_name(key, false).ok().flatten()
475 } else if raw.is_array().unwrap_or(false) {
476 key.parse::<usize>()
477 .ok()
478 .and_then(|index| raw.get_by_index(index).ok().flatten())
479 } else {
480 None
481 };
482 value.filter(|value| !value.as_raw().is_null().unwrap_or(false))
483 };
484
485 let rows = jsonb_array.len();
486 let array: Arc<dyn Array> = match kind {
487 JsonGet::Text => {
488 let mut builder = StringBuilder::with_capacity(rows, 1024);
489 for row in 0..rows {
490 match field(row) {
491 Some(value) => match value.as_raw().to_str() {
494 Ok(text) => builder.append_value(text),
495 Err(_) => builder.append_value(value.to_string()),
496 },
497 None => builder.append_null(),
498 }
499 }
500 Arc::new(builder.finish())
501 }
502 JsonGet::Int => {
503 let mut builder = Int64Builder::with_capacity(rows);
504 for row in 0..rows {
505 builder.append_option(field(row).and_then(|value| value.as_raw().to_i64().ok()));
506 }
507 Arc::new(builder.finish())
508 }
509 JsonGet::Float => {
510 let mut builder = Float64Builder::with_capacity(rows);
511 for row in 0..rows {
512 builder.append_option(field(row).and_then(|value| value.as_raw().to_f64().ok()));
513 }
514 Arc::new(builder.finish())
515 }
516 JsonGet::Bool => {
517 let mut builder = BooleanBuilder::with_capacity(rows);
518 for row in 0..rows {
519 builder.append_option(field(row).and_then(|value| value.as_raw().to_bool().ok()));
520 }
521 Arc::new(builder.finish())
522 }
523 };
524 Ok(ColumnarValue::Array(array))
525}
526
/// Appends a recovery hint to a DataFusion error message when it matches a known
/// failure pattern.
///
/// Patterns are tried in order and only the first match is used. Unmatched messages are
/// returned unchanged.
fn enrich(message: &str) -> String {
    const HINTS: &[(&str, &str)] = &[
        (
            "No field named",
            "columns are messages(session_id, id, timestamp, role, source_agent, project, \
             content [system-role only], search_text [the conversational text], \
             embedding_model, options) | sessions(id, parent_session_id, parent_message_id, \
             source_agent, created_at, project, options) | parts(session_id, message_id, id, \
             ordinal, type, provenance, variant_data, options). Part bodies (tool params/\
             results, text) live in parts.variant_data - read them with \
             json_extract(variant_data, '$.field'). For text search use fts('messages', \
             ...); to read a transcript use pond_get. Full doc: resource schema://pond-sql.",
        ),
        (
            "Encountered non UTF-8 data",
            "JSON columns (variant_data, options) are binary JSONB - CAST / ::text does not \
             work on them. Stringify the whole value with json_extract(col, '$'), or fetch \
             one field with json_extract(col, '$.field').",
        ),
        (
            "LIKE prefix queries are not supported for bitmap indexes",
            "prefix LIKE ('x%') and starts_with() fail on bitmap-indexed columns \
             (messages.source_agent, messages.role). Use equality, \
             split_part(source_agent, '/', 1) = '...', or an infix pattern (LIKE '%x%').",
        ),
        (
            "call to 'json_",
            "JSON function signatures: json_get_string|json_get_int|json_get_float|\
             json_get_bool(col, 'key') - one key, not a path; json_get(col, 'key') returns \
             JSONB for chaining; json_extract(col, '$.a.b') takes a JSONPath and returns \
             JSON text of any value (the right tool for nested or mixed-type fields).",
        ),
        (
            "Invalid function 'json",
            "available JSON functions: json_get_string, json_get_int, json_get_float, \
             json_get_bool (col, 'key'); json_get(col, 'key') -> JSONB for chaining; \
             json_extract(col, '$.a.b') -> JSON text; json_array_contains; \
             json_array_length. See resource schema://pond-sql.",
        ),
    ];
    HINTS
        .iter()
        .find(|(pattern, _)| message.contains(pattern))
        .map_or_else(
            || message.to_owned(),
            |(_, hint)| format!("{message}\nhint: {hint}"),
        )
}
577
578fn displayable(batch: &RecordBatch) -> Result<RecordBatch, ArrowError> {
581 let decoded = lance_arrow::json::convert_lance_json_to_arrow(batch)?;
582 let keep: Vec<usize> = decoded
583 .schema()
584 .fields()
585 .iter()
586 .enumerate()
587 .filter(|(_, field)| is_displayable(field.data_type()))
588 .map(|(index, _)| index)
589 .collect();
590 decoded.project(&keep)
591}
592
593fn is_displayable(data_type: &DataType) -> bool {
594 !matches!(
595 data_type,
596 DataType::FixedSizeList(_, _)
597 | DataType::Binary
598 | DataType::LargeBinary
599 | DataType::BinaryView
600 | DataType::FixedSizeBinary(_)
601 )
602}
603
604fn render_inline(
605 display: &[RecordBatch],
606 max_rows: usize,
607 elapsed: Duration,
608) -> Result<String, ArrowError> {
609 let total: usize = display.iter().map(RecordBatch::num_rows).sum();
610 let elapsed_ms = elapsed.as_millis();
611 if total == 0 {
612 return Ok(format!(
614 "0 rows ({elapsed_ms} ms).\n{}",
615 pretty_format_batches(display)?
616 ));
617 }
618 let mut shown = total.min(max_rows);
619 let mut table = pretty_format_batches(&limit_batches(display, shown))?.to_string();
620 while table.len() > INLINE_BUDGET_BYTES && shown > 1 {
621 shown = (shown / 2).max(1);
622 table = pretty_format_batches(&limit_batches(display, shown))?.to_string();
623 }
624 let mut out = format!("{total} row(s) in {elapsed_ms} ms; showing {shown}.\n{table}");
625 if shown < total {
626 out.push_str(&format!(
627 "\n... {} row(s) omitted. To page: ORDER BY <indexed col> (e.g. timestamp, \
628 id), then in the next call add `WHERE (col, id) < (<last_col>, <last_id>)` - \
629 keyset pagination, see schema://pond-sql. For the full set: output=parquet \
630 or output=ndjson.",
631 total - shown
632 ));
633 }
634 Ok(out)
635}
636
637fn render_inline_json(
642 display: &[RecordBatch],
643 max_rows: usize,
644 elapsed: Duration,
645) -> Result<JsonValue, SqlError> {
646 let total: usize = display.iter().map(RecordBatch::num_rows).sum();
647 let columns: Vec<String> = display
648 .first()
649 .map(|batch| {
650 batch
651 .schema()
652 .fields()
653 .iter()
654 .map(|field| field.name().clone())
655 .collect()
656 })
657 .unwrap_or_default();
658 let elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX);
659
660 if total == 0 {
661 return Ok(json!({
662 "total_rows": 0,
663 "shown_rows": 0,
664 "truncated": false,
665 "elapsed_ms": elapsed_ms,
666 "columns": columns,
667 "rows": [],
668 }));
669 }
670
671 let mut shown = total.min(max_rows);
672 let mut rows = batches_to_json_rows(&limit_batches(display, shown))?;
673 let mut serialized = serde_json::to_string(&rows)
674 .map_err(|error| SqlError::Infra(anyhow!("json serialize: {error}")))?;
675 while serialized.len() > INLINE_BUDGET_BYTES && shown > 1 {
676 shown = (shown / 2).max(1);
677 rows = batches_to_json_rows(&limit_batches(display, shown))?;
678 serialized = serde_json::to_string(&rows)
679 .map_err(|error| SqlError::Infra(anyhow!("json serialize: {error}")))?;
680 }
681
682 let mut payload = JsonMap::new();
683 payload.insert("total_rows".to_owned(), json!(total));
684 payload.insert("shown_rows".to_owned(), json!(shown));
685 payload.insert("truncated".to_owned(), json!(shown < total));
686 payload.insert("elapsed_ms".to_owned(), json!(elapsed_ms));
687 payload.insert("columns".to_owned(), json!(columns));
688 payload.insert("rows".to_owned(), JsonValue::Array(rows));
689 if shown < total {
690 payload.insert(
691 "next_steps".to_owned(),
692 json!(format!(
693 "{} row(s) omitted; ORDER BY + keyset (`WHERE (col, id) < \
694 (<last_col>, <last_id>)`) to page, or output=parquet|ndjson for the \
695 full set. See schema://pond-sql.",
696 total - shown
697 )),
698 );
699 }
700 Ok(JsonValue::Object(payload))
701}
702
703fn batches_to_json_rows(batches: &[RecordBatch]) -> Result<Vec<JsonValue>, SqlError> {
707 if batches.iter().all(|batch| batch.num_rows() == 0) {
708 return Ok(Vec::new());
709 }
710 let mut buffer = Vec::new();
711 {
712 let mut writer = LineDelimitedWriter::new(&mut buffer);
713 let refs: Vec<&RecordBatch> = batches.iter().collect();
714 writer
715 .write_batches(&refs)
716 .map_err(|error| SqlError::Infra(anyhow!("ndjson encode: {error}")))?;
717 writer
718 .finish()
719 .map_err(|error| SqlError::Infra(anyhow!("ndjson finish: {error}")))?;
720 }
721 let text = String::from_utf8(buffer)
722 .map_err(|error| SqlError::Infra(anyhow!("ndjson not utf-8: {error}")))?;
723 text.lines()
724 .filter(|line| !line.is_empty())
725 .map(|line| {
726 serde_json::from_str::<JsonValue>(line)
727 .map_err(|error| SqlError::Infra(anyhow!("ndjson parse: {error}")))
728 })
729 .collect()
730}
731
732fn limit_batches(batches: &[RecordBatch], max_rows: usize) -> Vec<RecordBatch> {
733 let mut out = Vec::new();
734 let mut remaining = max_rows;
735 for batch in batches {
736 if remaining == 0 {
737 break;
738 }
739 if batch.num_rows() <= remaining {
740 remaining -= batch.num_rows();
741 out.push(batch.clone());
742 } else {
743 out.push(batch.slice(0, remaining));
744 remaining = 0;
745 }
746 }
747 out
748}
749
750fn encode_parquet(batches: &[RecordBatch]) -> Result<Vec<u8>, SqlError> {
751 let schema = batches
752 .first()
753 .map(RecordBatch::schema)
754 .ok_or_else(|| SqlError::Query("query returned no columns to export".to_owned()))?;
755 let mut buffer = Vec::new();
756 let mut writer = ArrowWriter::try_new(&mut buffer, schema, None)
757 .map_err(|error| SqlError::Infra(anyhow!("parquet init failed: {error}")))?;
758 for batch in batches {
759 writer
760 .write(batch)
761 .map_err(|error| SqlError::Infra(anyhow!("parquet write failed: {error}")))?;
762 }
763 writer
764 .close()
765 .map_err(|error| SqlError::Infra(anyhow!("parquet close failed: {error}")))?;
766 Ok(buffer)
767}
768
769fn encode_ndjson(batches: &[RecordBatch]) -> Result<Vec<u8>, SqlError> {
770 let mut buffer = Vec::new();
771 {
772 let mut writer = LineDelimitedWriter::new(&mut buffer);
773 let refs: Vec<&RecordBatch> = batches.iter().collect();
774 writer
775 .write_batches(&refs)
776 .map_err(|error| SqlError::Infra(anyhow!("ndjson write failed: {error}")))?;
777 writer
778 .finish()
779 .map_err(|error| SqlError::Infra(anyhow!("ndjson finish failed: {error}")))?;
780 }
781 Ok(buffer)
782}
783
#[cfg(test)]
mod tests {
    use super::*;

    /// True when the gate rejects `sql` with a user-facing `SqlError::Query`.
    fn rejected(sql: &str) -> bool {
        matches!(parse_and_gate(sql), Err(SqlError::Query(_)))
    }

    /// True when the gate accepts `sql` and classifies it as `expected`.
    fn parses_as(sql: &str, expected: StatementKind) -> bool {
        matches!(parse_and_gate(sql), Ok(parsed) if parsed.kind == expected)
    }

    #[test]
    fn allows_single_select_and_cte() {
        assert!(parses_as("SELECT 1", StatementKind::Query));
        assert!(parses_as(
            "SELECT role, count(*) FROM messages GROUP BY role",
            StatementKind::Query
        ));
        assert!(parses_as(
            "WITH t AS (SELECT 1 AS a) SELECT a FROM t",
            StatementKind::Query
        ));
    }

    #[test]
    fn allows_explain_of_select() {
        assert!(parses_as("EXPLAIN SELECT 1", StatementKind::Explain));
        assert!(parses_as(
            "EXPLAIN ANALYZE SELECT role FROM messages",
            StatementKind::Explain
        ));
    }

    #[test]
    fn rejects_explain_of_non_query() {
        assert!(rejected("EXPLAIN INSERT INTO messages VALUES ('x')"));
    }

    #[test]
    fn rejects_writes_and_side_effects() {
        assert!(rejected("INSERT INTO messages VALUES ('x')"));
        assert!(rejected("UPDATE messages SET role = 'x'"));
        assert!(rejected("DELETE FROM messages"));
        assert!(rejected("CREATE TABLE t (x INT)"));
        assert!(rejected("CREATE VIEW v AS SELECT 1"));
        assert!(rejected("DROP TABLE messages"));
        assert!(rejected(
            "CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION '/etc'"
        ));
        assert!(rejected("COPY (SELECT 1) TO '/tmp/x.parquet'"));
        assert!(rejected("SET a = 1"));
    }

    #[test]
    fn rejects_multiple_statements() {
        assert!(rejected("SELECT 1; SELECT 2"));
        assert!(rejected("SELECT 1; DROP TABLE messages"));
    }

    #[test]
    fn rejects_unparseable() {
        assert!(rejected("NOT SQL AT ALL ;;"));
    }

    #[test]
    fn rejects_empty_and_nested_explain() {
        // No statement at all fails the "exactly one statement" gate (or the parser).
        assert!(rejected(""));
        // Only a single EXPLAIN layer around a plain query is accepted.
        assert!(rejected("EXPLAIN EXPLAIN SELECT 1"));
    }

    /// True when `sql` passes the gate and its projection names the `vector` column.
    fn mentions_vector(sql: &str) -> bool {
        parse_and_gate(sql)
            .map(|parsed| projection_mentions_vector(parsed.projection_query()))
            .unwrap_or(false)
    }

    #[test]
    fn explicit_vector_projection_is_rejected() {
        assert!(mentions_vector("SELECT vector FROM messages"));
        assert!(mentions_vector("SELECT id, vector FROM messages"));
        assert!(mentions_vector("SELECT m.vector FROM messages m"));
        assert!(mentions_vector("SELECT array_length(vector) FROM messages"));
        assert!(mentions_vector("EXPLAIN SELECT vector FROM messages"));
    }

    #[test]
    fn enrich_appends_recovery_hints() {
        let cases = [
            (
                "SQL error: Schema error: No field named created_at.",
                "schema://pond-sql",
            ),
            (
                "SQL error: External error: Arrow error: Invalid argument error: \
                 Encountered non UTF-8 data",
                "json_extract",
            ),
            (
                "SQL error: External error: Not supported: LIKE prefix queries are not \
                 supported for bitmap indexes",
                "split_part",
            ),
            (
                "SQL error: Error during planning: Failed to coerce arguments to satisfy \
                 a call to 'json_get_string' function",
                "JSONPath",
            ),
            (
                "SQL error: Error during planning: Invalid function 'json_get_json'.",
                "json_extract",
            ),
        ];
        for (raw, marker) in cases {
            let enriched = enrich(raw);
            assert!(enriched.starts_with(raw), "original kept: {enriched}");
            assert!(enriched.contains("hint:"), "hint appended: {enriched}");
            assert!(enriched.contains(marker), "hint names the fix: {enriched}");
        }
        // Messages without a known pattern pass through untouched.
        assert_eq!(
            enrich("SQL error: division by zero"),
            "SQL error: division by zero"
        );
    }

    #[test]
    fn select_star_and_where_vector_are_allowed() {
        assert!(!mentions_vector("SELECT * FROM messages"));
        assert!(!mentions_vector("SELECT embedding_model FROM messages"));
        assert!(!mentions_vector(
            "SELECT id FROM messages WHERE vector IS NOT NULL"
        ));
    }

    #[test]
    fn limit_batches_slices_across_batch_boundaries() {
        use lance::deps::arrow_array::{ArrayRef, Int64Array};

        let batch = |values: Vec<i64>| {
            RecordBatch::try_from_iter([("n", Arc::new(Int64Array::from(values)) as ArrayRef)])
                .expect("single Int64 column is a valid batch")
        };
        let row_counts = |limited: Vec<RecordBatch>| {
            limited
                .iter()
                .map(RecordBatch::num_rows)
                .collect::<Vec<_>>()
        };
        let batches = [batch(vec![1, 2, 3]), batch(vec![4, 5])];

        assert_eq!(row_counts(limit_batches(&batches, 0)), Vec::<usize>::new());
        assert_eq!(row_counts(limit_batches(&batches, 2)), vec![2]);
        assert_eq!(row_counts(limit_batches(&batches, 3)), vec![3]);
        assert_eq!(row_counts(limit_batches(&batches, 4)), vec![3, 1]);
        assert_eq!(row_counts(limit_batches(&batches, 10)), vec![3, 2]);
    }
}