1use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13use anyhow::anyhow;
14use arrow_json::LineDelimitedWriter;
15use lance::Dataset;
16use lance::datafusion::LanceTableProvider;
17use lance::dataset::udtf::FtsQueryUDTFBuilder;
18use lance::deps::arrow_array::RecordBatch;
19use lance::deps::arrow_schema::{ArrowError, DataType};
20use lance::deps::datafusion::arrow::util::pretty::pretty_format_batches;
21use lance::deps::datafusion::execution::SessionStateBuilder;
22use lance::deps::datafusion::execution::runtime_env::RuntimeEnvBuilder;
23use lance::deps::datafusion::prelude::{SQLOptions, SessionConfig, SessionContext};
24use lance::deps::datafusion::sql::parser::{DFParser, Statement as DfStatement};
25use lance::deps::datafusion::sql::sqlparser::ast::{SetExpr, Statement as SqlStatement};
26use lance_datafusion::udf::register_functions;
27use parquet::arrow::ArrowWriter;
28use serde_json::{Map as JsonMap, Value as JsonValue, json};
29
/// Bytes in one mebibyte; used to spell out the byte limits below.
const MIB: usize = 1024 * 1024;

/// Memory limit (512 MiB) handed to the DataFusion runtime built by `build_context`.
const MEM_LIMIT_BYTES: usize = 512 * MIB;
/// Wall-clock limit `run` applies while collecting a query's results.
const QUERY_TIMEOUT: Duration = Duration::from_secs(30);
/// Size budget for inline output: `render_inline` and `render_inline_json` halve the
/// number of rows shown until the rendered text fits or a single row is left.
const INLINE_BUDGET_BYTES: usize = 80_000;
/// Largest encoded export `run` returns; anything bigger is rejected with a query error.
const MAX_EXPORT_BYTES: usize = 100 * MIB;
/// Default inline row count. Not referenced in this file; presumably the caller's default
/// for `run`'s `inline_rows` argument (confirm at the call site).
pub const DEFAULT_INLINE_ROWS: usize = 100;
/// Upper bound on inline rows. Not referenced in this file; presumably enforced by the
/// caller before invoking `run` (confirm at the call site).
pub const MAX_INLINE_ROWS: usize = 1_000;
45
/// File format of an exported result set.
#[derive(Debug, Clone, Copy)]
pub enum Format {
    /// Apache Parquet, written by `encode_parquet`.
    Parquet,
    /// Newline-delimited JSON, written by `encode_ndjson`.
    Ndjson,
}

impl Format {
    /// Returns `(extension, media type)` for this format.
    ///
    /// Keeping both strings in one exhaustive match means a new variant cannot be
    /// given an extension without also being given a media type.
    fn descriptor(self) -> (&'static str, &'static str) {
        match self {
            Self::Parquet => ("parquet", "application/vnd.apache.parquet"),
            Self::Ndjson => ("ndjson", "application/x-ndjson"),
        }
    }

    /// File extension for this format, without a leading dot.
    pub fn ext(self) -> &'static str {
        self.descriptor().0
    }

    /// Media (MIME) type string for this format.
    pub fn mime(self) -> &'static str {
        self.descriptor().1
    }
}
69
70#[derive(Debug, Clone, Copy)]
72pub enum Mode {
73 Inline,
75 InlineJson,
82 Export(Format),
84}
85
86pub struct Tables {
89 pub sessions: Arc<Dataset>,
90 pub messages: Arc<Dataset>,
91 pub parts: Arc<Dataset>,
92}
93
94pub enum Outcome {
96 Inline(String),
98 InlineJson(JsonValue),
101 Export {
103 bytes: Vec<u8>,
104 format: Format,
105 rows: usize,
106 columns: Vec<String>,
107 },
108}
109
110#[derive(Debug)]
114pub enum SqlError {
115 Query(String),
116 Infra(anyhow::Error),
117}
118
119fn infra(error: ArrowError) -> SqlError {
120 SqlError::Infra(anyhow::Error::new(error))
121}
122
123pub async fn run(
126 tables: &Tables,
127 sql: &str,
128 mode: Mode,
129 inline_rows: usize,
130) -> Result<Outcome, SqlError> {
131 let parsed = parse_and_gate(sql)?;
132 if matches!(parsed.kind, StatementKind::Explain) && matches!(mode, Mode::Export(_)) {
133 return Err(SqlError::Query(
134 "EXPLAIN returns a plan, not a result set; use output=table (or json) to read it"
135 .to_owned(),
136 ));
137 }
138 if projection_mentions_vector(parsed.projection_query()) {
139 return Err(SqlError::Query(
140 "the `vector` column is not selectable from pond_sql_query (it is a \
141 FixedSizeList<f32> embedding, ~600 bytes per row and not useful in a result). \
142 For semantic search use pond_search. Filtering on it is allowed in WHERE \
143 (e.g. `vector IS NOT NULL`)."
144 .to_owned(),
145 ));
146 }
147 let ctx = build_context()?;
148 register(&ctx, tables)?;
149
150 let options = SQLOptions::new()
156 .with_allow_ddl(false)
157 .with_allow_dml(false)
158 .with_allow_statements(matches!(parsed.kind, StatementKind::Explain));
159 let df = ctx
160 .sql_with_options(sql, options)
161 .await
162 .map_err(|error| SqlError::Query(format!("SQL error: {error}")))?;
163
164 let result_schema = Arc::new(df.schema().as_arrow().clone());
167 let started = Instant::now();
168 let collected = tokio::time::timeout(QUERY_TIMEOUT, df.collect())
169 .await
170 .map_err(|_| {
171 SqlError::Query(format!(
172 "query exceeded the {}s limit; add a narrower WHERE or a LIMIT",
173 QUERY_TIMEOUT.as_secs()
174 ))
175 })?
176 .map_err(|error| SqlError::Query(format!("SQL error: {error}")))?;
177 let elapsed = started.elapsed();
178
179 let display: Vec<RecordBatch> = if collected.is_empty() {
180 vec![displayable(&RecordBatch::new_empty(result_schema)).map_err(infra)?]
181 } else {
182 collected
183 .iter()
184 .map(displayable)
185 .collect::<Result<_, _>>()
186 .map_err(infra)?
187 };
188
189 match mode {
190 Mode::Inline => Ok(Outcome::Inline(
191 render_inline(&display, inline_rows, elapsed).map_err(infra)?,
192 )),
193 Mode::InlineJson => Ok(Outcome::InlineJson(render_inline_json(
194 &display,
195 inline_rows,
196 elapsed,
197 )?)),
198 Mode::Export(format) => {
199 let rows = display.iter().map(RecordBatch::num_rows).sum();
200 let columns = display
201 .first()
202 .map(|batch| {
203 batch
204 .schema()
205 .fields()
206 .iter()
207 .map(|field| field.name().clone())
208 .collect::<Vec<_>>()
209 })
210 .unwrap_or_default();
211 let bytes = match format {
212 Format::Parquet => encode_parquet(&display)?,
213 Format::Ndjson => encode_ndjson(&display)?,
214 };
215 if bytes.len() > MAX_EXPORT_BYTES {
216 return Err(SqlError::Query(format!(
217 "export is {} bytes, over the {MAX_EXPORT_BYTES} byte limit; \
218 narrow the query or aggregate",
219 bytes.len()
220 )));
221 }
222 Ok(Outcome::Export {
223 bytes,
224 format,
225 rows,
226 columns,
227 })
228 }
229 }
230}
231
/// Classification `parse_and_gate` assigns to an accepted statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StatementKind {
    /// A bare query.
    Query,
    /// An `EXPLAIN` wrapped around a query.
    Explain,
}
240
241struct ParsedStatement {
247 kind: StatementKind,
248 query: lance::deps::datafusion::sql::sqlparser::ast::Query,
249}
250
251impl ParsedStatement {
252 fn projection_query(&self) -> &lance::deps::datafusion::sql::sqlparser::ast::Query {
253 &self.query
254 }
255}
256
257fn parse_and_gate(sql: &str) -> Result<ParsedStatement, SqlError> {
264 let statements = DFParser::parse_sql(sql)
265 .map_err(|error| SqlError::Query(format!("SQL parse error: {error}")))?;
266 if statements.len() != 1 {
267 return Err(SqlError::Query(
268 "pond_sql_query runs exactly one statement; submit a single SELECT".to_owned(),
269 ));
270 }
271 let Some(front) = statements.front() else {
272 return Err(read_only_rejection());
273 };
274 match front {
275 DfStatement::Statement(boxed) => match boxed.as_ref() {
276 SqlStatement::Query(query) => Ok(ParsedStatement {
277 kind: StatementKind::Query,
278 query: query.as_ref().clone(),
279 }),
280 _ => Err(read_only_rejection()),
281 },
282 DfStatement::Explain(explain) => match explain.statement.as_ref() {
283 DfStatement::Statement(inner) => match inner.as_ref() {
284 SqlStatement::Query(query) => Ok(ParsedStatement {
285 kind: StatementKind::Explain,
286 query: query.as_ref().clone(),
287 }),
288 _ => Err(read_only_rejection()),
289 },
290 _ => Err(read_only_rejection()),
291 },
292 _ => Err(read_only_rejection()),
293 }
294}
295
296fn read_only_rejection() -> SqlError {
297 SqlError::Query(
298 "pond_sql_query is read-only: only a single SELECT/WITH (or EXPLAIN of one) is \
299 allowed (no INSERT/UPDATE/DELETE/CREATE/DROP/COPY/SET)"
300 .to_owned(),
301 )
302}
303
304fn projection_mentions_vector(query: &lance::deps::datafusion::sql::sqlparser::ast::Query) -> bool {
315 walk_set_expr_for_vector(query.body.as_ref())
316}
317
318fn walk_set_expr_for_vector(expr: &SetExpr) -> bool {
319 match expr {
320 SetExpr::Select(select) => select
321 .projection
322 .iter()
323 .any(|item| mentions_vector_token(&item.to_string())),
324 SetExpr::Query(inner) => walk_set_expr_for_vector(inner.body.as_ref()),
325 SetExpr::SetOperation { left, right, .. } => {
326 walk_set_expr_for_vector(left) || walk_set_expr_for_vector(right)
327 }
328 _ => false,
329 }
330}
331
/// Reports whether `text` contains `vector` as a whole identifier-like token.
///
/// The text is split on every character that is neither alphanumeric nor `_`,
/// so `m.vector` and `array_length(vector)` match while `vectors` and
/// `my_vector` do not.
///
/// The comparison ignores ASCII case: unquoted SQL identifiers are
/// case-insensitive, so `VECTOR` or `m.Vector` refer to the same column and
/// must not slip past the gate.
fn mentions_vector_token(text: &str) -> bool {
    text.split(|c: char| !c.is_alphanumeric() && c != '_')
        .any(|token| token.eq_ignore_ascii_case("vector"))
}
336
337fn build_context() -> Result<SessionContext, SqlError> {
338 let runtime = RuntimeEnvBuilder::new()
339 .with_memory_limit(MEM_LIMIT_BYTES, 1.0)
340 .build_arc()
341 .map_err(|error| SqlError::Infra(anyhow!("datafusion runtime init failed: {error}")))?;
342 let state = SessionStateBuilder::new()
343 .with_config(SessionConfig::new())
344 .with_runtime_env(runtime)
345 .with_default_features()
346 .build();
347 Ok(SessionContext::new_with_state(state))
348}
349
350fn register(ctx: &SessionContext, tables: &Tables) -> Result<(), SqlError> {
351 for (name, dataset) in [
352 ("sessions", &tables.sessions),
353 ("messages", &tables.messages),
354 ("parts", &tables.parts),
355 ] {
356 let provider = LanceTableProvider::new(dataset.clone(), false, false);
360 ctx.register_table(name, Arc::new(provider))
361 .map_err(|error| SqlError::Infra(anyhow!("register table {name}: {error}")))?;
362 }
363 let fts = FtsQueryUDTFBuilder::builder()
366 .register_table("sessions", tables.sessions.clone())
367 .register_table("messages", tables.messages.clone())
368 .register_table("parts", tables.parts.clone())
369 .build();
370 ctx.register_udtf("fts", Arc::new(fts));
371 register_functions(ctx);
372 Ok(())
373}
374
375fn displayable(batch: &RecordBatch) -> Result<RecordBatch, ArrowError> {
378 let decoded = lance_arrow::json::convert_lance_json_to_arrow(batch)?;
379 let keep: Vec<usize> = decoded
380 .schema()
381 .fields()
382 .iter()
383 .enumerate()
384 .filter(|(_, field)| is_displayable(field.data_type()))
385 .map(|(index, _)| index)
386 .collect();
387 decoded.project(&keep)
388}
389
390fn is_displayable(data_type: &DataType) -> bool {
391 !matches!(
392 data_type,
393 DataType::FixedSizeList(_, _)
394 | DataType::Binary
395 | DataType::LargeBinary
396 | DataType::BinaryView
397 | DataType::FixedSizeBinary(_)
398 )
399}
400
401fn render_inline(
402 display: &[RecordBatch],
403 max_rows: usize,
404 elapsed: Duration,
405) -> Result<String, ArrowError> {
406 let total: usize = display.iter().map(RecordBatch::num_rows).sum();
407 let elapsed_ms = elapsed.as_millis();
408 if total == 0 {
409 return Ok(format!(
411 "0 rows ({elapsed_ms} ms).\n{}",
412 pretty_format_batches(display)?
413 ));
414 }
415 let mut shown = total.min(max_rows);
416 let mut table = pretty_format_batches(&limit_batches(display, shown))?.to_string();
417 while table.len() > INLINE_BUDGET_BYTES && shown > 1 {
418 shown = (shown / 2).max(1);
419 table = pretty_format_batches(&limit_batches(display, shown))?.to_string();
420 }
421 let mut out = format!("{total} row(s) in {elapsed_ms} ms; showing {shown}.\n{table}");
422 if shown < total {
423 out.push_str(&format!(
424 "\n... {} row(s) omitted. To page: ORDER BY <indexed col> (e.g. timestamp, \
425 id), then in the next call add `WHERE (col, id) < (<last_col>, <last_id>)` - \
426 keyset pagination, see schema://pond-sql. For the full set: output=parquet \
427 or output=ndjson.",
428 total - shown
429 ));
430 }
431 Ok(out)
432}
433
434fn render_inline_json(
439 display: &[RecordBatch],
440 max_rows: usize,
441 elapsed: Duration,
442) -> Result<JsonValue, SqlError> {
443 let total: usize = display.iter().map(RecordBatch::num_rows).sum();
444 let columns: Vec<String> = display
445 .first()
446 .map(|batch| {
447 batch
448 .schema()
449 .fields()
450 .iter()
451 .map(|field| field.name().clone())
452 .collect()
453 })
454 .unwrap_or_default();
455 let elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX);
456
457 if total == 0 {
458 return Ok(json!({
459 "total_rows": 0,
460 "shown_rows": 0,
461 "truncated": false,
462 "elapsed_ms": elapsed_ms,
463 "columns": columns,
464 "rows": [],
465 }));
466 }
467
468 let mut shown = total.min(max_rows);
469 let mut rows = batches_to_json_rows(&limit_batches(display, shown))?;
470 let mut serialized = serde_json::to_string(&rows)
471 .map_err(|error| SqlError::Infra(anyhow!("json serialize: {error}")))?;
472 while serialized.len() > INLINE_BUDGET_BYTES && shown > 1 {
473 shown = (shown / 2).max(1);
474 rows = batches_to_json_rows(&limit_batches(display, shown))?;
475 serialized = serde_json::to_string(&rows)
476 .map_err(|error| SqlError::Infra(anyhow!("json serialize: {error}")))?;
477 }
478
479 let mut payload = JsonMap::new();
480 payload.insert("total_rows".to_owned(), json!(total));
481 payload.insert("shown_rows".to_owned(), json!(shown));
482 payload.insert("truncated".to_owned(), json!(shown < total));
483 payload.insert("elapsed_ms".to_owned(), json!(elapsed_ms));
484 payload.insert("columns".to_owned(), json!(columns));
485 payload.insert("rows".to_owned(), JsonValue::Array(rows));
486 if shown < total {
487 payload.insert(
488 "next_steps".to_owned(),
489 json!(format!(
490 "{} row(s) omitted; ORDER BY + keyset (`WHERE (col, id) < \
491 (<last_col>, <last_id>)`) to page, or output=parquet|ndjson for the \
492 full set. See schema://pond-sql.",
493 total - shown
494 )),
495 );
496 }
497 Ok(JsonValue::Object(payload))
498}
499
500fn batches_to_json_rows(batches: &[RecordBatch]) -> Result<Vec<JsonValue>, SqlError> {
504 if batches.iter().all(|batch| batch.num_rows() == 0) {
505 return Ok(Vec::new());
506 }
507 let mut buffer = Vec::new();
508 {
509 let mut writer = LineDelimitedWriter::new(&mut buffer);
510 let refs: Vec<&RecordBatch> = batches.iter().collect();
511 writer
512 .write_batches(&refs)
513 .map_err(|error| SqlError::Infra(anyhow!("ndjson encode: {error}")))?;
514 writer
515 .finish()
516 .map_err(|error| SqlError::Infra(anyhow!("ndjson finish: {error}")))?;
517 }
518 let text = String::from_utf8(buffer)
519 .map_err(|error| SqlError::Infra(anyhow!("ndjson not utf-8: {error}")))?;
520 text.lines()
521 .filter(|line| !line.is_empty())
522 .map(|line| {
523 serde_json::from_str::<JsonValue>(line)
524 .map_err(|error| SqlError::Infra(anyhow!("ndjson parse: {error}")))
525 })
526 .collect()
527}
528
529fn limit_batches(batches: &[RecordBatch], max_rows: usize) -> Vec<RecordBatch> {
530 let mut out = Vec::new();
531 let mut remaining = max_rows;
532 for batch in batches {
533 if remaining == 0 {
534 break;
535 }
536 if batch.num_rows() <= remaining {
537 remaining -= batch.num_rows();
538 out.push(batch.clone());
539 } else {
540 out.push(batch.slice(0, remaining));
541 remaining = 0;
542 }
543 }
544 out
545}
546
547fn encode_parquet(batches: &[RecordBatch]) -> Result<Vec<u8>, SqlError> {
548 let schema = batches
549 .first()
550 .map(RecordBatch::schema)
551 .ok_or_else(|| SqlError::Query("query returned no columns to export".to_owned()))?;
552 let mut buffer = Vec::new();
553 let mut writer = ArrowWriter::try_new(&mut buffer, schema, None)
554 .map_err(|error| SqlError::Infra(anyhow!("parquet init failed: {error}")))?;
555 for batch in batches {
556 writer
557 .write(batch)
558 .map_err(|error| SqlError::Infra(anyhow!("parquet write failed: {error}")))?;
559 }
560 writer
561 .close()
562 .map_err(|error| SqlError::Infra(anyhow!("parquet close failed: {error}")))?;
563 Ok(buffer)
564}
565
566fn encode_ndjson(batches: &[RecordBatch]) -> Result<Vec<u8>, SqlError> {
567 let mut buffer = Vec::new();
568 {
569 let mut writer = LineDelimitedWriter::new(&mut buffer);
570 let refs: Vec<&RecordBatch> = batches.iter().collect();
571 writer
572 .write_batches(&refs)
573 .map_err(|error| SqlError::Infra(anyhow!("ndjson write failed: {error}")))?;
574 writer
575 .finish()
576 .map_err(|error| SqlError::Infra(anyhow!("ndjson finish failed: {error}")))?;
577 }
578 Ok(buffer)
579}
580
#[cfg(test)]
mod tests {
    use super::*;

    /// True when the gate refuses `sql` with a query-level error.
    fn rejected(sql: &str) -> bool {
        matches!(parse_and_gate(sql), Err(SqlError::Query(_)))
    }

    /// True when `sql` passes the gate and is classified as `expected`.
    fn parses_as(sql: &str, expected: StatementKind) -> bool {
        matches!(parse_and_gate(sql), Ok(parsed) if parsed.kind == expected)
    }

    /// True when `sql` passes the gate and its projection names `vector`.
    fn mentions_vector(sql: &str) -> bool {
        matches!(
            parse_and_gate(sql),
            Ok(parsed) if projection_mentions_vector(parsed.projection_query())
        )
    }

    #[test]
    fn allows_single_select_and_cte() {
        for sql in [
            "SELECT 1",
            "SELECT role, count(*) FROM messages GROUP BY role",
            "WITH t AS (SELECT 1 AS a) SELECT a FROM t",
        ] {
            assert!(parses_as(sql, StatementKind::Query), "not a query: {sql}");
        }
    }

    #[test]
    fn allows_explain_of_select() {
        for sql in [
            "EXPLAIN SELECT 1",
            "EXPLAIN ANALYZE SELECT role FROM messages",
        ] {
            assert!(parses_as(sql, StatementKind::Explain), "not an explain: {sql}");
        }
    }

    #[test]
    fn rejects_explain_of_non_query() {
        // Wrapping a write in EXPLAIN must not get it past the gate.
        assert!(rejected("EXPLAIN INSERT INTO messages VALUES ('x')"));
    }

    #[test]
    fn rejects_writes_and_side_effects() {
        for sql in [
            "INSERT INTO messages VALUES ('x')",
            "UPDATE messages SET role = 'x'",
            "DELETE FROM messages",
            "CREATE TABLE t (x INT)",
            "CREATE VIEW v AS SELECT 1",
            "DROP TABLE messages",
            "CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION '/etc'",
            "COPY (SELECT 1) TO '/tmp/x.parquet'",
            "SET a = 1",
        ] {
            assert!(rejected(sql), "should be rejected: {sql}");
        }
    }

    #[test]
    fn rejects_multiple_statements() {
        for sql in ["SELECT 1; SELECT 2", "SELECT 1; DROP TABLE messages"] {
            assert!(rejected(sql), "should be rejected: {sql}");
        }
    }

    #[test]
    fn rejects_unparseable() {
        assert!(rejected("NOT SQL AT ALL ;;"));
    }

    #[test]
    fn explicit_vector_projection_is_rejected() {
        for sql in [
            "SELECT vector FROM messages",
            "SELECT id, vector FROM messages",
            "SELECT m.vector FROM messages m",
            "SELECT array_length(vector) FROM messages",
            "EXPLAIN SELECT vector FROM messages",
        ] {
            assert!(mentions_vector(sql), "should flag vector: {sql}");
        }
    }

    #[test]
    fn select_star_and_where_vector_are_allowed() {
        // `*` does not name the column, and a WHERE clause is not a projection.
        for sql in [
            "SELECT * FROM messages",
            "SELECT id FROM messages WHERE vector IS NOT NULL",
        ] {
            assert!(!mentions_vector(sql), "should not flag vector: {sql}");
        }
    }
}