1use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
14use anyhow::anyhow;
15use arrow_json::LineDelimitedWriter;
16use lance::Dataset;
17use lance::datafusion::LanceTableProvider;
18use lance::deps::arrow_array::builder::{
19 BooleanBuilder, Float64Builder, Int64Builder, StringBuilder,
20};
21use lance::deps::arrow_array::{Array, LargeBinaryArray, RecordBatch, StringArray};
22use lance::deps::arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
23use lance::deps::datafusion::arrow::util::pretty::pretty_format_batches;
24use lance::deps::datafusion::catalog::{Session, TableFunctionImpl, TableProvider};
25use lance::deps::datafusion::common::ScalarValue;
26use lance::deps::datafusion::datasource::{ViewTable, provider_as_source};
27use lance::deps::datafusion::error::DataFusionError;
28use lance::deps::datafusion::execution::SessionStateBuilder;
29use lance::deps::datafusion::execution::runtime_env::RuntimeEnvBuilder;
30use lance::deps::datafusion::logical_expr::{
31 ColumnarValue, LogicalPlanBuilder, ScalarUDF, Volatility, create_udf,
32};
33use lance::deps::datafusion::logical_expr::{Expr, TableType};
34use lance::deps::datafusion::physical_plan::ExecutionPlan;
35use lance::deps::datafusion::prelude::{SQLOptions, SessionConfig, SessionContext, col};
36use lance::deps::datafusion::sql::parser::{DFParser, Statement as DfStatement};
37use lance::deps::datafusion::sql::sqlparser::ast::{SetExpr, Statement as SqlStatement};
38use lance_arrow::SchemaExt;
39use lance_datafusion::udf::register_functions;
40use lance_index::scalar::FullTextSearchQuery;
41use lance_index::scalar::inverted::parser::from_json;
42use parquet::arrow::ArrowWriter;
43use serde_json::{Map as JsonMap, Value as JsonValue, json};
44
/// Memory cap (512 MiB) given to the DataFusion runtime built for each query.
const MEM_LIMIT_BYTES: usize = 512 << 20;
/// Wall-clock limit on collecting a query's results.
const QUERY_TIMEOUT: Duration = Duration::from_secs(30);
/// Byte budget for inline output; the shown row count is halved until the rendered
/// table/JSON fits or a single row remains.
const INLINE_BUDGET_BYTES: usize = 80_000;
/// Largest encoded export (100 MiB) that is returned instead of rejected.
const MAX_EXPORT_BYTES: usize = 100 << 20;
/// Default inline row count. Not used in this file; presumably callers pass it as
/// `inline_rows` when the user gives none.
pub const DEFAULT_INLINE_ROWS: usize = 100;
/// Upper bound on inline rows. Not used in this file; presumably enforced by callers.
pub const MAX_INLINE_ROWS: usize = 1_000;
60
/// File format used when a query result is exported (`Mode::Export`).
#[derive(Debug, Clone, Copy)]
pub enum Format {
    /// Apache Parquet file.
    Parquet,
    /// Newline-delimited JSON, one object per row.
    Ndjson,
}

impl Format {
    /// File extension (no leading dot) for this format.
    pub fn ext(self) -> &'static str {
        self.descriptor().0
    }

    /// MIME type to advertise for an export in this format.
    pub fn mime(self) -> &'static str {
        self.descriptor().1
    }

    /// `(extension, mime type)` pair; the single source of both strings.
    fn descriptor(self) -> (&'static str, &'static str) {
        match self {
            Format::Parquet => ("parquet", "application/vnd.apache.parquet"),
            Format::Ndjson => ("ndjson", "application/x-ndjson"),
        }
    }
}
84
85#[derive(Debug, Clone, Copy)]
87pub enum Mode {
88 Inline,
90 InlineJson,
97 Export(Format),
99}
100
101pub struct Tables {
104 pub sessions: Arc<Dataset>,
105 pub messages: Arc<Dataset>,
106 pub parts: Arc<Dataset>,
107}
108
109pub enum Outcome {
111 Inline(String),
113 InlineJson(JsonValue),
116 Export {
118 bytes: Vec<u8>,
119 format: Format,
120 rows: usize,
121 columns: Vec<String>,
122 },
123}
124
125#[derive(Debug)]
129pub enum SqlError {
130 Query(String),
131 Infra(anyhow::Error),
132}
133
134fn infra(error: ArrowError) -> SqlError {
135 SqlError::Infra(anyhow::Error::new(error))
136}
137
138pub async fn run(
141 tables: &Tables,
142 sql: &str,
143 mode: Mode,
144 inline_rows: usize,
145) -> Result<Outcome, SqlError> {
146 let parsed = parse_and_gate(sql)?;
147 if matches!(parsed.kind, StatementKind::Explain) && matches!(mode, Mode::Export(_)) {
148 return Err(SqlError::Query(
149 "EXPLAIN returns a plan, not a result set; use output=table (or json) to read it"
150 .to_owned(),
151 ));
152 }
153 if projection_mentions_vector(parsed.projection_query()) {
154 return Err(SqlError::Query(
155 "the `vector` column is not selectable from pond_sql_query (it is a \
156 FixedSizeList<f32> embedding, ~600 bytes per row and not useful in a result). \
157 For semantic search use pond_search. Filtering on it is allowed in WHERE \
158 (e.g. `vector IS NOT NULL`)."
159 .to_owned(),
160 ));
161 }
162 let ctx = build_context()?;
163 register(&ctx, tables)?;
164
165 let options = SQLOptions::new()
171 .with_allow_ddl(false)
172 .with_allow_dml(false)
173 .with_allow_statements(matches!(parsed.kind, StatementKind::Explain));
174 let df = ctx
175 .sql_with_options(sql, options)
176 .await
177 .map_err(|error| SqlError::Query(enrich(&format!("SQL error: {error}"))))?;
178
179 let result_schema = Arc::new(df.schema().as_arrow().clone());
182 let started = Instant::now();
183 let collected = tokio::time::timeout(QUERY_TIMEOUT, df.collect())
184 .await
185 .map_err(|_| {
186 SqlError::Query(format!(
187 "query exceeded the {}s limit; add a narrower WHERE or a LIMIT",
188 QUERY_TIMEOUT.as_secs()
189 ))
190 })?
191 .map_err(|error| SqlError::Query(enrich(&format!("SQL error: {error}"))))?;
192 let elapsed = started.elapsed();
193
194 let display: Vec<RecordBatch> = if collected.is_empty() {
195 vec![displayable(&RecordBatch::new_empty(result_schema)).map_err(infra)?]
196 } else {
197 collected
198 .iter()
199 .map(displayable)
200 .collect::<Result<_, _>>()
201 .map_err(infra)?
202 };
203
204 match mode {
205 Mode::Inline => Ok(Outcome::Inline(
206 render_inline(&display, inline_rows, elapsed).map_err(infra)?,
207 )),
208 Mode::InlineJson => Ok(Outcome::InlineJson(render_inline_json(
209 &display,
210 inline_rows,
211 elapsed,
212 )?)),
213 Mode::Export(format) => {
214 let rows = display.iter().map(RecordBatch::num_rows).sum();
215 let columns = display
216 .first()
217 .map(|batch| {
218 batch
219 .schema()
220 .fields()
221 .iter()
222 .map(|field| field.name().clone())
223 .collect::<Vec<_>>()
224 })
225 .unwrap_or_default();
226 let bytes = match format {
227 Format::Parquet => encode_parquet(&display)?,
228 Format::Ndjson => encode_ndjson(&display)?,
229 };
230 if bytes.len() > MAX_EXPORT_BYTES {
231 return Err(SqlError::Query(format!(
232 "export is {} bytes, over the {MAX_EXPORT_BYTES} byte limit; \
233 narrow the query or aggregate",
234 bytes.len()
235 )));
236 }
237 Ok(Outcome::Export {
238 bytes,
239 format,
240 rows,
241 columns,
242 })
243 }
244 }
245}
246
/// Which kind of statement passed the read-only gate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StatementKind {
    /// A plain SELECT/WITH query.
    Query,
    /// `EXPLAIN` (optionally `ANALYZE`) of a query.
    Explain,
}
255
256struct ParsedStatement {
262 kind: StatementKind,
263 query: lance::deps::datafusion::sql::sqlparser::ast::Query,
264}
265
266impl ParsedStatement {
267 fn projection_query(&self) -> &lance::deps::datafusion::sql::sqlparser::ast::Query {
268 &self.query
269 }
270}
271
272fn parse_and_gate(sql: &str) -> Result<ParsedStatement, SqlError> {
279 let statements = DFParser::parse_sql(sql)
280 .map_err(|error| SqlError::Query(format!("SQL parse error: {error}")))?;
281 if statements.len() != 1 {
282 return Err(SqlError::Query(
283 "pond_sql_query runs exactly one statement; submit a single SELECT".to_owned(),
284 ));
285 }
286 let Some(front) = statements.front() else {
287 return Err(read_only_rejection());
288 };
289 match front {
290 DfStatement::Statement(boxed) => match boxed.as_ref() {
291 SqlStatement::Query(query) => Ok(ParsedStatement {
292 kind: StatementKind::Query,
293 query: query.as_ref().clone(),
294 }),
295 _ => Err(read_only_rejection()),
296 },
297 DfStatement::Explain(explain) => match explain.statement.as_ref() {
298 DfStatement::Statement(inner) => match inner.as_ref() {
299 SqlStatement::Query(query) => Ok(ParsedStatement {
300 kind: StatementKind::Explain,
301 query: query.as_ref().clone(),
302 }),
303 _ => Err(read_only_rejection()),
304 },
305 _ => Err(read_only_rejection()),
306 },
307 _ => Err(read_only_rejection()),
308 }
309}
310
311fn read_only_rejection() -> SqlError {
312 SqlError::Query(
313 "pond_sql_query is read-only: only a single SELECT/WITH (or EXPLAIN of one) is \
314 allowed (no INSERT/UPDATE/DELETE/CREATE/DROP/COPY/SET)"
315 .to_owned(),
316 )
317}
318
319fn projection_mentions_vector(query: &lance::deps::datafusion::sql::sqlparser::ast::Query) -> bool {
330 walk_set_expr_for_vector(query.body.as_ref())
331}
332
333fn walk_set_expr_for_vector(expr: &SetExpr) -> bool {
334 match expr {
335 SetExpr::Select(select) => select
336 .projection
337 .iter()
338 .any(|item| mentions_vector_token(&item.to_string())),
339 SetExpr::Query(inner) => walk_set_expr_for_vector(inner.body.as_ref()),
340 SetExpr::SetOperation { left, right, .. } => {
341 walk_set_expr_for_vector(left) || walk_set_expr_for_vector(right)
342 }
343 _ => false,
344 }
345}
346
/// Returns true when `text` contains `vector` as a standalone identifier token.
///
/// `text` is split on every character that is neither alphanumeric nor `_`. So
/// `m.vector`, `"vector"` and `array_length(vector)` match, while `vector_norm` and
/// `vectors` do not. The comparison ignores ASCII case because unquoted SQL
/// identifiers are normalized to lowercase before resolution, so `SELECT VECTOR ...`
/// reads the same column as `SELECT vector ...`. A case-sensitive check let those
/// spellings slip past the gate.
///
/// The check is purely lexical, so a string literal such as `'vector'` also matches.
fn mentions_vector_token(text: &str) -> bool {
    text.split(|c: char| !c.is_alphanumeric() && c != '_')
        .any(|token| token.eq_ignore_ascii_case("vector"))
}
351
352fn build_context() -> Result<SessionContext, SqlError> {
353 let runtime = RuntimeEnvBuilder::new()
354 .with_memory_limit(MEM_LIMIT_BYTES, 1.0)
355 .build_arc()
356 .map_err(|error| SqlError::Infra(anyhow!("datafusion runtime init failed: {error}")))?;
357 let state = SessionStateBuilder::new()
360 .with_config(SessionConfig::new().with_information_schema(true))
361 .with_runtime_env(runtime)
362 .with_default_features()
363 .build();
364 Ok(SessionContext::new_with_state(state))
365}
366
367fn register(ctx: &SessionContext, tables: &Tables) -> Result<(), SqlError> {
368 for (name, dataset) in [
369 ("sessions", &tables.sessions),
370 ("messages", &tables.messages),
371 ] {
372 let provider = LanceTableProvider::new(dataset.clone(), false, false);
376 ctx.register_table(name, Arc::new(provider))
377 .map_err(|error| SqlError::Infra(anyhow!("register table {name}: {error}")))?;
378 }
379 let provider = LanceTableProvider::new(tables.parts.clone(), false, false);
384 let keep: Vec<_> = tables
385 .parts
386 .schema()
387 .fields
388 .iter()
389 .filter(|field| field.name != "data")
390 .map(|field| col(field.name.as_str()))
391 .collect();
392 let plan = LogicalPlanBuilder::scan("parts", provider_as_source(Arc::new(provider)), None)
393 .and_then(|builder| builder.project(keep))
394 .and_then(LogicalPlanBuilder::build)
395 .map_err(|error| SqlError::Infra(anyhow!("build parts view: {error}")))?;
396 ctx.register_table("parts", Arc::new(ViewTable::new(plan, None)))
397 .map_err(|error| SqlError::Infra(anyhow!("register table parts: {error}")))?;
398 let fts = ScoredFtsUdtf {
402 datasets: HashMap::from([
403 ("sessions".to_owned(), tables.sessions.clone()),
404 ("messages".to_owned(), tables.messages.clone()),
405 ("parts".to_owned(), tables.parts.clone()),
406 ]),
407 };
408 ctx.register_udtf("fts", Arc::new(fts));
409 register_functions(ctx);
410 for udf in lenient_json_udfs() {
414 ctx.register_udf(udf);
415 }
416 Ok(())
417}
418
419#[derive(Debug)]
431struct ScoredFtsUdtf {
432 datasets: HashMap<String, Arc<Dataset>>,
433}
434
435impl TableFunctionImpl for ScoredFtsUdtf {
436 fn call(
437 &self,
438 expr: &[Expr],
439 ) -> Result<Arc<dyn TableProvider>, lance::deps::datafusion::error::DataFusionError> {
440 let [table_expr, query_expr] = expr else {
441 return Err(DataFusionError::Execution(
442 "fts() takes (table_name, fts_query_json)".to_owned(),
443 ));
444 };
445 let Expr::Literal(ScalarValue::Utf8(Some(table_name)), _) = table_expr else {
446 return Err(DataFusionError::Execution(
447 "fts() first argument must be a table name string".to_owned(),
448 ));
449 };
450 let Expr::Literal(ScalarValue::Utf8(Some(fts_query)), _) = query_expr else {
451 return Err(DataFusionError::Execution(
452 "fts() second argument must be the fts query as a JSON string".to_owned(),
453 ));
454 };
455 let dataset = self.datasets.get(table_name).ok_or_else(|| {
456 DataFusionError::Execution(format!("fts(): table {table_name} not found"))
457 })?;
458 let mut full_schema = Schema::from(dataset.schema());
459 full_schema = full_schema
460 .try_with_column(Field::new(SCORE_COLUMN, DataType::Float32, true))
461 .map_err(|error| DataFusionError::ArrowError(Box::new(error), None))?;
462 Ok(Arc::new(ScoredFtsProvider {
463 dataset: dataset.clone(),
464 fts_query: FullTextSearchQuery::new_query(from_json(fts_query)?),
465 full_schema: Arc::new(full_schema),
466 }))
467 }
468}
469
470const SCORE_COLUMN: &str = "_score";
471
472#[derive(Debug)]
473struct ScoredFtsProvider {
474 dataset: Arc<Dataset>,
475 fts_query: FullTextSearchQuery,
476 full_schema: SchemaRef,
477}
478
479#[async_trait::async_trait]
480impl TableProvider for ScoredFtsProvider {
481 fn as_any(&self) -> &dyn std::any::Any {
482 self
483 }
484
485 fn schema(&self) -> SchemaRef {
486 self.full_schema.clone()
487 }
488
489 fn table_type(&self) -> TableType {
490 TableType::Temporary
491 }
492
493 async fn scan(
494 &self,
495 _state: &dyn Session,
496 projection: Option<&Vec<usize>>,
497 filters: &[Expr],
498 limit: Option<usize>,
499 ) -> Result<Arc<dyn ExecutionPlan>, lance::deps::datafusion::error::DataFusionError> {
500 let mut scan = self.dataset.scan();
501 scan.full_text_search(self.fts_query.clone())?;
502 scan.disable_scoring_autoprojection();
506 match projection {
507 Some(projection) if projection.is_empty() => {
508 scan.empty_project()?;
509 }
510 Some(projection) => {
511 let columns: Vec<&str> = projection
512 .iter()
513 .map(|idx| self.full_schema.field(*idx).name().as_str())
514 .collect();
515 scan.project(&columns)?;
516 }
517 None => {
518 let columns: Vec<&str> = self
519 .full_schema
520 .fields()
521 .iter()
522 .map(|field| field.name().as_str())
523 .collect();
524 scan.project(&columns)?;
525 }
526 }
527 if let Some(combined) = filters
528 .iter()
529 .cloned()
530 .reduce(|left, right| left.and(right))
531 {
532 scan.filter_expr(combined);
533 }
534 scan.limit(limit.map(|l| l as i64), None)?;
535 scan.create_plan().await.map_err(DataFusionError::from)
536 }
537}
538
/// Output type requested from a lenient `json_get_*` UDF.
enum JsonGet {
    /// `json_get_string`: Utf8.
    Text,
    /// `json_get_int`: Int64.
    Int,
    /// `json_get_float`: Float64.
    Float,
    /// `json_get_bool`: Boolean.
    Bool,
}
546
547fn lenient_json_udfs() -> [ScalarUDF; 4] {
555 let make = |name: &str, kind: JsonGet, return_type: DataType| {
556 create_udf(
557 name,
558 vec![DataType::LargeBinary, DataType::Utf8],
559 return_type,
560 Volatility::Immutable,
561 Arc::new(move |args: &[ColumnarValue]| json_get_lenient(args, &kind)),
562 )
563 };
564 [
565 make("json_get_string", JsonGet::Text, DataType::Utf8),
566 make("json_get_int", JsonGet::Int, DataType::Int64),
567 make("json_get_float", JsonGet::Float, DataType::Float64),
568 make("json_get_bool", JsonGet::Bool, DataType::Boolean),
569 ]
570}
571
572fn json_get_lenient(
573 args: &[ColumnarValue],
574 kind: &JsonGet,
575) -> Result<ColumnarValue, DataFusionError> {
576 let arrays = ColumnarValue::values_to_arrays(args)?;
577 let [jsonb_arg, key_arg] = arrays.as_slice() else {
578 return Err(DataFusionError::Execution(
579 "json_get_* takes exactly (json_column, 'key')".to_owned(),
580 ));
581 };
582 let jsonb_array = jsonb_arg
583 .as_any()
584 .downcast_ref::<LargeBinaryArray>()
585 .ok_or_else(|| {
586 DataFusionError::Execution(
587 "json_get_* argument 1 must be a JSON column (variant_data, options)".to_owned(),
588 )
589 })?;
590 let key_array = key_arg
591 .as_any()
592 .downcast_ref::<StringArray>()
593 .ok_or_else(|| {
594 DataFusionError::Execution("json_get_* argument 2 must be a string key".to_owned())
595 })?;
596
597 let field = |row: usize| -> Option<jsonb::OwnedJsonb> {
598 if jsonb_array.is_null(row) || key_array.is_null(row) {
599 return None;
600 }
601 let raw = jsonb::RawJsonb::new(jsonb_array.value(row));
602 let key = key_array.value(row);
603 let value = if raw.is_object().unwrap_or(false) {
604 raw.get_by_name(key, false).ok().flatten()
605 } else if raw.is_array().unwrap_or(false) {
606 key.parse::<usize>()
607 .ok()
608 .and_then(|index| raw.get_by_index(index).ok().flatten())
609 } else {
610 None
611 };
612 value.filter(|value| !value.as_raw().is_null().unwrap_or(false))
613 };
614
615 let rows = jsonb_array.len();
616 let array: Arc<dyn Array> = match kind {
617 JsonGet::Text => {
618 let mut builder = StringBuilder::with_capacity(rows, 1024);
619 for row in 0..rows {
620 match field(row) {
621 Some(value) => match value.as_raw().to_str() {
624 Ok(text) => builder.append_value(text),
625 Err(_) => builder.append_value(value.to_string()),
626 },
627 None => builder.append_null(),
628 }
629 }
630 Arc::new(builder.finish())
631 }
632 JsonGet::Int => {
633 let mut builder = Int64Builder::with_capacity(rows);
634 for row in 0..rows {
635 builder.append_option(field(row).and_then(|value| value.as_raw().to_i64().ok()));
636 }
637 Arc::new(builder.finish())
638 }
639 JsonGet::Float => {
640 let mut builder = Float64Builder::with_capacity(rows);
641 for row in 0..rows {
642 builder.append_option(field(row).and_then(|value| value.as_raw().to_f64().ok()));
643 }
644 Arc::new(builder.finish())
645 }
646 JsonGet::Bool => {
647 let mut builder = BooleanBuilder::with_capacity(rows);
648 for row in 0..rows {
649 builder.append_option(field(row).and_then(|value| value.as_raw().to_bool().ok()));
650 }
651 Arc::new(builder.finish())
652 }
653 };
654 Ok(ColumnarValue::Array(array))
655}
656
/// Append a recovery hint to a SQL error message when it matches a known failure.
///
/// The first pattern (in table order) found as a substring of `message` wins, and its
/// hint is appended as `"\nhint: ..."`. Messages matching nothing are returned unchanged.
fn enrich(message: &str) -> String {
    // (substring of the raw error, hint shown to the user)
    const HINTS: &[(&str, &str)] = &[
        (
            "No field named",
            "columns are messages(session_id, id, timestamp, role, source_agent, project, \
             content [system-role only], search_text [the conversational text], \
             embedding_model, options) | sessions(id, parent_session_id, parent_message_id, \
             source_agent, created_at, project, options) | parts(session_id, message_id, id, \
             ordinal, type, provenance, variant_data, options). Part bodies (tool params/\
             results, text) live in parts.variant_data - read them with \
             json_extract(variant_data, '$.field'). For text search use fts('messages', \
             ...); to read a transcript use pond_get. Full doc: resource schema://pond-sql.",
        ),
        (
            "Encountered non UTF-8 data",
            "JSON columns (variant_data, options) are binary JSONB - CAST / ::text does not \
             work on them. Stringify the whole value with json_extract(col, '$'), or fetch \
             one field with json_extract(col, '$.field').",
        ),
        (
            "LIKE prefix queries are not supported for bitmap indexes",
            "prefix LIKE ('x%') and starts_with() fail on bitmap-indexed columns \
             (messages.source_agent, messages.role). Use equality, \
             split_part(source_agent, '/', 1) = '...', or an infix pattern (LIKE '%x%').",
        ),
        (
            "call to 'json_",
            "JSON function signatures: json_get_string|json_get_int|json_get_float|\
             json_get_bool(col, 'key') - one key, not a path; json_get(col, 'key') returns \
             JSONB for chaining; json_extract(col, '$.a.b') takes a JSONPath and returns \
             JSON text of any value (the right tool for nested or mixed-type fields).",
        ),
        (
            "Invalid function 'json",
            "available JSON functions: json_get_string, json_get_int, json_get_float, \
             json_get_bool (col, 'key'); json_get(col, 'key') -> JSONB for chaining; \
             json_extract(col, '$.a.b') -> JSON text; json_array_contains; \
             json_array_length. See resource schema://pond-sql.",
        ),
        (
            "does not satisfy distribution requirements",
            "this fts query shape planned an unexecutable join. For AND semantics use a \
             single match query with operator And: fts('messages', \
             '{\"match\":{\"column\":\"search_text\",\"terms\":\"a b\",\"operator\":\"And\"}}'), \
             optionally with LIKE post-filters in WHERE.",
        ),
        (
            "position is not found but required for phrase queries",
            "the full-text index is built without positions, so \"phrase\" queries are \
             unavailable. Use a match query with operator And plus LIKE post-filters for \
             exact-substring matching.",
        ),
    ];

    match HINTS.iter().find(|(pattern, _)| message.contains(pattern)) {
        Some((_, hint)) => format!("{message}\nhint: {hint}"),
        None => message.to_owned(),
    }
}
724
725fn displayable(batch: &RecordBatch) -> Result<RecordBatch, ArrowError> {
728 let decoded = lance_arrow::json::convert_lance_json_to_arrow(batch)?;
729 let keep: Vec<usize> = decoded
730 .schema()
731 .fields()
732 .iter()
733 .enumerate()
734 .filter(|(_, field)| is_displayable(field.data_type()))
735 .map(|(index, _)| index)
736 .collect();
737 decoded.project(&keep)
738}
739
740fn is_displayable(data_type: &DataType) -> bool {
741 !matches!(
742 data_type,
743 DataType::FixedSizeList(_, _)
744 | DataType::Binary
745 | DataType::LargeBinary
746 | DataType::BinaryView
747 | DataType::FixedSizeBinary(_)
748 )
749}
750
751fn render_inline(
752 display: &[RecordBatch],
753 max_rows: usize,
754 elapsed: Duration,
755) -> Result<String, ArrowError> {
756 let total: usize = display.iter().map(RecordBatch::num_rows).sum();
757 let elapsed_ms = elapsed.as_millis();
758 if total == 0 {
759 return Ok(format!(
761 "0 rows ({elapsed_ms} ms).\n{}",
762 pretty_format_batches(display)?
763 ));
764 }
765 let mut shown = total.min(max_rows);
766 let mut table = pretty_format_batches(&limit_batches(display, shown))?.to_string();
767 while table.len() > INLINE_BUDGET_BYTES && shown > 1 {
768 shown = (shown / 2).max(1);
769 table = pretty_format_batches(&limit_batches(display, shown))?.to_string();
770 }
771 let mut out = format!("{total} row(s) in {elapsed_ms} ms; showing {shown}.\n{table}");
772 if shown < total {
773 out.push_str(&format!(
774 "\n... {} row(s) omitted. To page: ORDER BY <indexed col> (e.g. timestamp, \
775 id), then in the next call add `WHERE (col, id) < (<last_col>, <last_id>)` - \
776 keyset pagination, see schema://pond-sql. For the full set: output=parquet \
777 or output=ndjson.",
778 total - shown
779 ));
780 }
781 Ok(out)
782}
783
784fn render_inline_json(
789 display: &[RecordBatch],
790 max_rows: usize,
791 elapsed: Duration,
792) -> Result<JsonValue, SqlError> {
793 let total: usize = display.iter().map(RecordBatch::num_rows).sum();
794 let columns: Vec<String> = display
795 .first()
796 .map(|batch| {
797 batch
798 .schema()
799 .fields()
800 .iter()
801 .map(|field| field.name().clone())
802 .collect()
803 })
804 .unwrap_or_default();
805 let elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX);
806
807 if total == 0 {
808 return Ok(json!({
809 "total_rows": 0,
810 "shown_rows": 0,
811 "truncated": false,
812 "elapsed_ms": elapsed_ms,
813 "columns": columns,
814 "rows": [],
815 }));
816 }
817
818 let mut shown = total.min(max_rows);
819 let mut rows = batches_to_json_rows(&limit_batches(display, shown))?;
820 let mut serialized = serde_json::to_string(&rows)
821 .map_err(|error| SqlError::Infra(anyhow!("json serialize: {error}")))?;
822 while serialized.len() > INLINE_BUDGET_BYTES && shown > 1 {
823 shown = (shown / 2).max(1);
824 rows = batches_to_json_rows(&limit_batches(display, shown))?;
825 serialized = serde_json::to_string(&rows)
826 .map_err(|error| SqlError::Infra(anyhow!("json serialize: {error}")))?;
827 }
828
829 let mut payload = JsonMap::new();
830 payload.insert("total_rows".to_owned(), json!(total));
831 payload.insert("shown_rows".to_owned(), json!(shown));
832 payload.insert("truncated".to_owned(), json!(shown < total));
833 payload.insert("elapsed_ms".to_owned(), json!(elapsed_ms));
834 payload.insert("columns".to_owned(), json!(columns));
835 payload.insert("rows".to_owned(), JsonValue::Array(rows));
836 if shown < total {
837 payload.insert(
838 "next_steps".to_owned(),
839 json!(format!(
840 "{} row(s) omitted; ORDER BY + keyset (`WHERE (col, id) < \
841 (<last_col>, <last_id>)`) to page, or output=parquet|ndjson for the \
842 full set. See schema://pond-sql.",
843 total - shown
844 )),
845 );
846 }
847 Ok(JsonValue::Object(payload))
848}
849
850fn batches_to_json_rows(batches: &[RecordBatch]) -> Result<Vec<JsonValue>, SqlError> {
854 if batches.iter().all(|batch| batch.num_rows() == 0) {
855 return Ok(Vec::new());
856 }
857 let mut buffer = Vec::new();
858 {
859 let mut writer = LineDelimitedWriter::new(&mut buffer);
860 let refs: Vec<&RecordBatch> = batches.iter().collect();
861 writer
862 .write_batches(&refs)
863 .map_err(|error| SqlError::Infra(anyhow!("ndjson encode: {error}")))?;
864 writer
865 .finish()
866 .map_err(|error| SqlError::Infra(anyhow!("ndjson finish: {error}")))?;
867 }
868 let text = String::from_utf8(buffer)
869 .map_err(|error| SqlError::Infra(anyhow!("ndjson not utf-8: {error}")))?;
870 text.lines()
871 .filter(|line| !line.is_empty())
872 .map(|line| {
873 serde_json::from_str::<JsonValue>(line)
874 .map_err(|error| SqlError::Infra(anyhow!("ndjson parse: {error}")))
875 })
876 .collect()
877}
878
879fn limit_batches(batches: &[RecordBatch], max_rows: usize) -> Vec<RecordBatch> {
880 let mut out = Vec::new();
881 let mut remaining = max_rows;
882 for batch in batches {
883 if remaining == 0 {
884 break;
885 }
886 if batch.num_rows() <= remaining {
887 remaining -= batch.num_rows();
888 out.push(batch.clone());
889 } else {
890 out.push(batch.slice(0, remaining));
891 remaining = 0;
892 }
893 }
894 out
895}
896
897fn encode_parquet(batches: &[RecordBatch]) -> Result<Vec<u8>, SqlError> {
898 let schema = batches
899 .first()
900 .map(RecordBatch::schema)
901 .ok_or_else(|| SqlError::Query("query returned no columns to export".to_owned()))?;
902 let mut buffer = Vec::new();
903 let mut writer = ArrowWriter::try_new(&mut buffer, schema, None)
904 .map_err(|error| SqlError::Infra(anyhow!("parquet init failed: {error}")))?;
905 for batch in batches {
906 writer
907 .write(batch)
908 .map_err(|error| SqlError::Infra(anyhow!("parquet write failed: {error}")))?;
909 }
910 writer
911 .close()
912 .map_err(|error| SqlError::Infra(anyhow!("parquet close failed: {error}")))?;
913 Ok(buffer)
914}
915
916fn encode_ndjson(batches: &[RecordBatch]) -> Result<Vec<u8>, SqlError> {
917 let mut buffer = Vec::new();
918 {
919 let mut writer = LineDelimitedWriter::new(&mut buffer);
920 let refs: Vec<&RecordBatch> = batches.iter().collect();
921 writer
922 .write_batches(&refs)
923 .map_err(|error| SqlError::Infra(anyhow!("ndjson write failed: {error}")))?;
924 writer
925 .finish()
926 .map_err(|error| SqlError::Infra(anyhow!("ndjson finish failed: {error}")))?;
927 }
928 Ok(buffer)
929}
930
#[cfg(test)]
mod tests {
    use super::*;

    /// The gate refuses `sql` with a user-facing `SqlError::Query`.
    fn rejected(sql: &str) -> bool {
        matches!(parse_and_gate(sql), Err(SqlError::Query(_)))
    }

    /// The gate accepts `sql` and classifies it as `expected`.
    fn parses_as(sql: &str, expected: StatementKind) -> bool {
        matches!(parse_and_gate(sql), Ok(parsed) if parsed.kind == expected)
    }

    /// The gate accepts `sql` and its projection names the `vector` column.
    fn mentions_vector(sql: &str) -> bool {
        parse_and_gate(sql)
            .map(|parsed| projection_mentions_vector(parsed.projection_query()))
            .unwrap_or(false)
    }

    #[test]
    fn allows_single_select_and_cte() {
        for sql in [
            "SELECT 1",
            "SELECT role, count(*) FROM messages GROUP BY role",
            "WITH t AS (SELECT 1 AS a) SELECT a FROM t",
        ] {
            assert!(parses_as(sql, StatementKind::Query), "should parse as query: {sql}");
        }
    }

    #[test]
    fn allows_explain_of_select() {
        for sql in [
            "EXPLAIN SELECT 1",
            "EXPLAIN ANALYZE SELECT role FROM messages",
        ] {
            assert!(parses_as(sql, StatementKind::Explain), "should parse as EXPLAIN: {sql}");
        }
    }

    #[test]
    fn rejects_explain_of_non_query() {
        assert!(rejected("EXPLAIN INSERT INTO messages VALUES ('x')"));
    }

    #[test]
    fn rejects_writes_and_side_effects() {
        for sql in [
            "INSERT INTO messages VALUES ('x')",
            "UPDATE messages SET role = 'x'",
            "DELETE FROM messages",
            "CREATE TABLE t (x INT)",
            "CREATE VIEW v AS SELECT 1",
            "DROP TABLE messages",
            "CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION '/etc'",
            "COPY (SELECT 1) TO '/tmp/x.parquet'",
            "SET a = 1",
        ] {
            assert!(rejected(sql), "should be rejected: {sql}");
        }
    }

    #[test]
    fn rejects_multiple_statements() {
        for sql in ["SELECT 1; SELECT 2", "SELECT 1; DROP TABLE messages"] {
            assert!(rejected(sql), "should be rejected: {sql}");
        }
    }

    #[test]
    fn rejects_unparseable() {
        assert!(rejected("NOT SQL AT ALL ;;"));
    }

    #[test]
    fn explicit_vector_projection_is_rejected() {
        for sql in [
            "SELECT vector FROM messages",
            "SELECT id, vector FROM messages",
            "SELECT m.vector FROM messages m",
            "SELECT array_length(vector) FROM messages",
            "EXPLAIN SELECT vector FROM messages",
        ] {
            assert!(mentions_vector(sql), "should mention vector: {sql}");
        }
    }

    #[test]
    fn enrich_appends_recovery_hints() {
        let cases = [
            (
                "SQL error: Schema error: No field named created_at.",
                "schema://pond-sql",
            ),
            (
                "SQL error: External error: Arrow error: Invalid argument error: \
                 Encountered non UTF-8 data",
                "json_extract",
            ),
            (
                "SQL error: External error: Not supported: LIKE prefix queries are not \
                 supported for bitmap indexes",
                "split_part",
            ),
            (
                "SQL error: Error during planning: Failed to coerce arguments to satisfy \
                 a call to 'json_get_string' function",
                "JSONPath",
            ),
            (
                "SQL error: Error during planning: Invalid function 'json_get_json'.",
                "json_extract",
            ),
        ];
        for (raw, marker) in cases {
            let enriched = enrich(raw);
            assert!(enriched.starts_with(raw), "original kept: {enriched}");
            assert!(enriched.contains("hint:"), "hint appended: {enriched}");
            assert!(enriched.contains(marker), "hint names the fix: {enriched}");
        }
        // Unknown errors pass through untouched.
        assert_eq!(
            enrich("SQL error: division by zero"),
            "SQL error: division by zero"
        );
    }

    #[test]
    fn select_star_and_where_vector_are_allowed() {
        assert!(!mentions_vector("SELECT * FROM messages"));
        assert!(!mentions_vector(
            "SELECT id FROM messages WHERE vector IS NOT NULL"
        ));
    }
}