Skip to main content

pond/
sql.rs

1//! `pond_sql_query`: read-only DataFusion SQL over the three Lance tables
2//! (`sessions` / `messages` / `parts`), registered as `LanceTableProvider`s on
3//! a fresh per-call `SessionContext`. Read-only is enforced in two layers - a
4//! single-`SELECT` pre-parse and `sql_with_options` with DDL/DML/statements all
5//! disabled - so no statement that mutates the corpus or touches the filesystem
6//! (INSERT/UPDATE/DELETE/CREATE/DROP/COPY/CREATE EXTERNAL TABLE/SET) can run.
7//! Results render inline (row-capped) or export to a parquet/ndjson file the
8//! caller fetches via the `pond-sql-export://` resource (`src/transport.rs`).
9
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13use anyhow::anyhow;
14use arrow_json::LineDelimitedWriter;
15use lance::Dataset;
16use lance::datafusion::LanceTableProvider;
17use lance::dataset::udtf::FtsQueryUDTFBuilder;
18use lance::deps::arrow_array::RecordBatch;
19use lance::deps::arrow_schema::{ArrowError, DataType};
20use lance::deps::datafusion::arrow::util::pretty::pretty_format_batches;
21use lance::deps::datafusion::execution::SessionStateBuilder;
22use lance::deps::datafusion::execution::runtime_env::RuntimeEnvBuilder;
23use lance::deps::datafusion::prelude::{SQLOptions, SessionConfig, SessionContext};
24use lance::deps::datafusion::sql::parser::{DFParser, Statement as DfStatement};
25use lance::deps::datafusion::sql::sqlparser::ast::{SetExpr, Statement as SqlStatement};
26use lance_datafusion::udf::register_functions;
27use parquet::arrow::ArrowWriter;
28use serde_json::{Map as JsonMap, Value as JsonValue, json};
29
/// One mebibyte; the byte ceilings below are written as multiples of it.
const MIB: usize = 1 << 20;

/// Memory ceiling (512 MiB) given to the DataFusion runtime for one query.
/// DataFusion does not enforce it on every operator, which is why the
/// wall-clock timeout below acts as the hard backstop.
const MEM_LIMIT_BYTES: usize = 512 * MIB;
/// Maximum wall-clock time allowed for `collect()`. DataFusion 53 ships no
/// query timeout of its own, so the `tokio::time::timeout` built from this
/// value is the only protection against a runaway plan.
const QUERY_TIMEOUT: Duration = Duration::from_secs(30);
/// Size budget, in bytes, for an inline result; rows are dropped until the
/// rendered output fits.
const INLINE_BUDGET_BYTES: usize = 80 * 1_000;
/// Largest export artifact allowed (100 MiB). Served base64-encoded over
/// `resources/read` it costs ~1.33x this in the response, so it must stay
/// well below any process envelope.
const MAX_EXPORT_BYTES: usize = 100 * MIB;
/// Inline row cap used when the caller supplies no `limit`.
pub const DEFAULT_INLINE_ROWS: usize = 100;
/// Largest inline `limit` a caller may request.
pub const MAX_INLINE_ROWS: usize = 1_000;
45
/// Serialization format of an exported result. Before encoding, vector
/// columns are excluded and JSON columns are decoded to text (see
/// [`displayable`]).
#[derive(Clone, Copy, Debug)]
pub enum Format {
    /// Apache Parquet file.
    Parquet,
    /// Newline-delimited JSON, one object per row.
    Ndjson,
}

impl Format {
    /// File extension for this format, without a leading dot.
    pub fn ext(self) -> &'static str {
        let (extension, _) = self.ext_and_mime();
        extension
    }

    /// MIME type for this format.
    pub fn mime(self) -> &'static str {
        let (_, mime) = self.ext_and_mime();
        mime
    }

    /// Single lookup table behind [`Format::ext`] and [`Format::mime`], so
    /// the two cannot drift apart.
    fn ext_and_mime(self) -> (&'static str, &'static str) {
        match self {
            Format::Parquet => ("parquet", "application/vnd.apache.parquet"),
            Format::Ndjson => ("ndjson", "application/x-ndjson"),
        }
    }
}
69
70/// How `pond_sql_query` returns results.
71#[derive(Debug, Clone, Copy)]
72pub enum Mode {
73    /// Render a row-capped table into the tool result.
74    Inline,
75    /// Return a row-capped JSON payload; the MCP layer surfaces it through
76    /// `structuredContent` (with a stringified text fallback for clients that
77    /// do not surface the structured channel). Empirically validated on Claude
78    /// Code 2.1.165: when both channels carry the same payload, the agent reads
79    /// the structured one and the text block is a soft-landing for other
80    /// clients (spec 2025-11-25 server SHOULD).
81    InlineJson,
82    /// Write the full result to a file and return a `pond-sql-export://` link.
83    Export(Format),
84}
85
86/// The three Lance datasets, fetched fresh per call so each query sees a
87/// current snapshot (the handle freshness gate runs on each `Store::dataset`).
88pub struct Tables {
89    pub sessions: Arc<Dataset>,
90    pub messages: Arc<Dataset>,
91    pub parts: Arc<Dataset>,
92}
93
94/// Result of a successful `run`.
95pub enum Outcome {
96    /// A rendered, row-capped table (already includes the metrics footer).
97    Inline(String),
98    /// A row-capped JSON payload with metadata fields (`total_rows`,
99    /// `shown_rows`, `truncated`, `elapsed_ms`, `columns`, `rows`).
100    InlineJson(JsonValue),
101    /// Encoded export bytes plus metadata for the caller's summary/resource.
102    Export {
103        bytes: Vec<u8>,
104        format: Format,
105        rows: usize,
106        columns: Vec<String>,
107    },
108}
109
110/// Two error channels: `Query` is caller-fixable (parse/plan/exec/limits) and
111/// the tool surfaces it as an `isError` result so the model self-corrects;
112/// `Infra` is an internal failure surfaced as a protocol error.
113#[derive(Debug)]
114pub enum SqlError {
115    Query(String),
116    Infra(anyhow::Error),
117}
118
119fn infra(error: ArrowError) -> SqlError {
120    SqlError::Infra(anyhow::Error::new(error))
121}
122
123/// Execute one read-only SQL query and return either a rendered table, a JSON
124/// payload, or encoded export bytes.
125pub async fn run(
126    tables: &Tables,
127    sql: &str,
128    mode: Mode,
129    inline_rows: usize,
130) -> Result<Outcome, SqlError> {
131    let parsed = parse_and_gate(sql)?;
132    if matches!(parsed.kind, StatementKind::Explain) && matches!(mode, Mode::Export(_)) {
133        return Err(SqlError::Query(
134            "EXPLAIN returns a plan, not a result set; use output=table (or json) to read it"
135                .to_owned(),
136        ));
137    }
138    if projection_mentions_vector(parsed.projection_query()) {
139        return Err(SqlError::Query(
140            "the `vector` column is not selectable from pond_sql_query (it is a \
141             FixedSizeList<f32> embedding, ~600 bytes per row and not useful in a result). \
142             For semantic search use pond_search. Filtering on it is allowed in WHERE \
143             (e.g. `vector IS NOT NULL`)."
144                .to_owned(),
145        ));
146    }
147    let ctx = build_context()?;
148    register(&ctx, tables)?;
149
150    // Defense in depth on top of the pre-parse gate: SQLOptions blocks DDL/DML
151    // at planning time. `allow_statements` stays false for a plain SELECT (the
152    // parse-time gate already rejects SET/SHOW etc.) but must be true for
153    // EXPLAIN, which DataFusion classifies as a Statement node. The inner
154    // query of an EXPLAIN was vetted by the gate above.
155    let options = SQLOptions::new()
156        .with_allow_ddl(false)
157        .with_allow_dml(false)
158        .with_allow_statements(matches!(parsed.kind, StatementKind::Explain));
159    let df = ctx
160        .sql_with_options(sql, options)
161        .await
162        .map_err(|error| SqlError::Query(format!("SQL error: {error}")))?;
163
164    // Captured before `collect()` consumes `df`, so an empty result still
165    // renders its column headers.
166    let result_schema = Arc::new(df.schema().as_arrow().clone());
167    let started = Instant::now();
168    let collected = tokio::time::timeout(QUERY_TIMEOUT, df.collect())
169        .await
170        .map_err(|_| {
171            SqlError::Query(format!(
172                "query exceeded the {}s limit; add a narrower WHERE or a LIMIT",
173                QUERY_TIMEOUT.as_secs()
174            ))
175        })?
176        .map_err(|error| SqlError::Query(format!("SQL error: {error}")))?;
177    let elapsed = started.elapsed();
178
179    let display: Vec<RecordBatch> = if collected.is_empty() {
180        vec![displayable(&RecordBatch::new_empty(result_schema)).map_err(infra)?]
181    } else {
182        collected
183            .iter()
184            .map(displayable)
185            .collect::<Result<_, _>>()
186            .map_err(infra)?
187    };
188
189    match mode {
190        Mode::Inline => Ok(Outcome::Inline(
191            render_inline(&display, inline_rows, elapsed).map_err(infra)?,
192        )),
193        Mode::InlineJson => Ok(Outcome::InlineJson(render_inline_json(
194            &display,
195            inline_rows,
196            elapsed,
197        )?)),
198        Mode::Export(format) => {
199            let rows = display.iter().map(RecordBatch::num_rows).sum();
200            let columns = display
201                .first()
202                .map(|batch| {
203                    batch
204                        .schema()
205                        .fields()
206                        .iter()
207                        .map(|field| field.name().clone())
208                        .collect::<Vec<_>>()
209                })
210                .unwrap_or_default();
211            let bytes = match format {
212                Format::Parquet => encode_parquet(&display)?,
213                Format::Ndjson => encode_ndjson(&display)?,
214            };
215            if bytes.len() > MAX_EXPORT_BYTES {
216                return Err(SqlError::Query(format!(
217                    "export is {} bytes, over the {MAX_EXPORT_BYTES} byte limit; \
218                     narrow the query or aggregate",
219                    bytes.len()
220                )));
221            }
222            Ok(Outcome::Export {
223                bytes,
224                format,
225                rows,
226                columns,
227            })
228        }
229    }
230}
231
/// The two top-level statement shapes the read-only gate lets through.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum StatementKind {
    /// An ordinary `Query` (SELECT/WITH/VALUES/UNION).
    Query,
    /// `EXPLAIN [ANALYZE] <query>`: planning information only, no mutation.
    Explain,
}
240
241/// Parsed top-level statement, normalized so downstream checks always see a
242/// projection-bearing `Query` regardless of whether the user wrote `SELECT`
243/// or `EXPLAIN SELECT`. DataFusion's parser wraps EXPLAIN in its own
244/// `DfStatement::Explain` variant (separate from sqlparser's
245/// `SqlStatement::Explain`), so the gate has to peel both layers.
246struct ParsedStatement {
247    kind: StatementKind,
248    query: lance::deps::datafusion::sql::sqlparser::ast::Query,
249}
250
251impl ParsedStatement {
252    fn projection_query(&self) -> &lance::deps::datafusion::sql::sqlparser::ast::Query {
253        &self.query
254    }
255}
256
257/// Read-only gate: parse the SQL and require exactly one top-level `Query` or
258/// `EXPLAIN <Query>`. Rejects DDL/DML/COPY/SET/SHOW and multi-statement input,
259/// which `SQLOptions` alone does not catch at planning time. EXPLAIN of a
260/// non-Query (e.g. `EXPLAIN INSERT ...`) is also rejected: EXPLAIN itself is
261/// read-only, but letting the inner shape be DDL/DML widens the surface area
262/// the gate has to reason about for no real agent gain.
263fn parse_and_gate(sql: &str) -> Result<ParsedStatement, SqlError> {
264    let statements = DFParser::parse_sql(sql)
265        .map_err(|error| SqlError::Query(format!("SQL parse error: {error}")))?;
266    if statements.len() != 1 {
267        return Err(SqlError::Query(
268            "pond_sql_query runs exactly one statement; submit a single SELECT".to_owned(),
269        ));
270    }
271    let Some(front) = statements.front() else {
272        return Err(read_only_rejection());
273    };
274    match front {
275        DfStatement::Statement(boxed) => match boxed.as_ref() {
276            SqlStatement::Query(query) => Ok(ParsedStatement {
277                kind: StatementKind::Query,
278                query: query.as_ref().clone(),
279            }),
280            _ => Err(read_only_rejection()),
281        },
282        DfStatement::Explain(explain) => match explain.statement.as_ref() {
283            DfStatement::Statement(inner) => match inner.as_ref() {
284                SqlStatement::Query(query) => Ok(ParsedStatement {
285                    kind: StatementKind::Explain,
286                    query: query.as_ref().clone(),
287                }),
288                _ => Err(read_only_rejection()),
289            },
290            _ => Err(read_only_rejection()),
291        },
292        _ => Err(read_only_rejection()),
293    }
294}
295
296fn read_only_rejection() -> SqlError {
297    SqlError::Query(
298        "pond_sql_query is read-only: only a single SELECT/WITH (or EXPLAIN of one) is \
299         allowed (no INSERT/UPDATE/DELETE/CREATE/DROP/COPY/SET)"
300            .to_owned(),
301    )
302}
303
304/// Reject any top-level projection that explicitly references the embedding
305/// `vector` column. Today such queries silently return an empty column (the
306/// FixedSizeList<f32> is stripped by `displayable`), which wastes agent tokens
307/// diagnosing. WHERE/HAVING references stay legal - the doc lets agents filter
308/// on it (e.g. `WHERE vector IS NOT NULL`); only projecting the column out is
309/// blocked. Heuristic: tokenize each top-level SELECT item and look for a bare
310/// `vector` identifier. Covers `SELECT vector`, `SELECT id, vector`,
311/// `SELECT m.vector`, and `SELECT array_length(vector)`. Wildcards (`*` /
312/// `messages.*`) keep the existing silent-strip behavior since they don't name
313/// the column explicitly.
314fn projection_mentions_vector(query: &lance::deps::datafusion::sql::sqlparser::ast::Query) -> bool {
315    walk_set_expr_for_vector(query.body.as_ref())
316}
317
318fn walk_set_expr_for_vector(expr: &SetExpr) -> bool {
319    match expr {
320        SetExpr::Select(select) => select
321            .projection
322            .iter()
323            .any(|item| mentions_vector_token(&item.to_string())),
324        SetExpr::Query(inner) => walk_set_expr_for_vector(inner.body.as_ref()),
325        SetExpr::SetOperation { left, right, .. } => {
326            walk_set_expr_for_vector(left) || walk_set_expr_for_vector(right)
327        }
328        _ => false,
329    }
330}
331
/// Report whether `text` contains `vector` as a standalone identifier token.
///
/// The text is split on every character that cannot be part of an identifier
/// (anything other than alphanumerics and `_`), so `m.vector` and
/// `array_length(vector)` match while `vector_id` and `vectors` do not.
///
/// The comparison ignores ASCII case: DataFusion folds unquoted identifiers
/// to lowercase by default, so `SELECT VECTOR` or `SELECT m.Vector` resolves
/// to the same column and must be caught as well.
fn mentions_vector_token(text: &str) -> bool {
    text.split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .any(|token| token.eq_ignore_ascii_case("vector"))
}
336
337fn build_context() -> Result<SessionContext, SqlError> {
338    let runtime = RuntimeEnvBuilder::new()
339        .with_memory_limit(MEM_LIMIT_BYTES, 1.0)
340        .build_arc()
341        .map_err(|error| SqlError::Infra(anyhow!("datafusion runtime init failed: {error}")))?;
342    let state = SessionStateBuilder::new()
343        .with_config(SessionConfig::new())
344        .with_runtime_env(runtime)
345        .with_default_features()
346        .build();
347    Ok(SessionContext::new_with_state(state))
348}
349
350fn register(ctx: &SessionContext, tables: &Tables) -> Result<(), SqlError> {
351    for (name, dataset) in [
352        ("sessions", &tables.sessions),
353        ("messages", &tables.messages),
354        ("parts", &tables.parts),
355    ] {
356        // LanceTableProvider (not the bare Dataset impl) so WHERE/projection/
357        // limit push into Lance's indexed scan; (false, false) hides _rowid /
358        // _rowaddr from the SQL schema.
359        let provider = LanceTableProvider::new(dataset.clone(), false, false);
360        ctx.register_table(name, Arc::new(provider))
361            .map_err(|error| SqlError::Infra(anyhow!("register table {name}: {error}")))?;
362    }
363    // `fts('messages', '{...}')` BM25 search-in-SQL, and lance's JSON /
364    // contains_tokens UDFs for filtering inside the JSON columns.
365    let fts = FtsQueryUDTFBuilder::builder()
366        .register_table("sessions", tables.sessions.clone())
367        .register_table("messages", tables.messages.clone())
368        .register_table("parts", tables.parts.clone())
369        .build();
370    ctx.register_udtf("fts", Arc::new(fts));
371    register_functions(ctx);
372    Ok(())
373}
374
375/// Decode lance JSONB columns to JSON text, then drop columns that don't render
376/// readably (the embedding `vector` FixedSizeList and any leftover binary).
377fn displayable(batch: &RecordBatch) -> Result<RecordBatch, ArrowError> {
378    let decoded = lance_arrow::json::convert_lance_json_to_arrow(batch)?;
379    let keep: Vec<usize> = decoded
380        .schema()
381        .fields()
382        .iter()
383        .enumerate()
384        .filter(|(_, field)| is_displayable(field.data_type()))
385        .map(|(index, _)| index)
386        .collect();
387    decoded.project(&keep)
388}
389
390fn is_displayable(data_type: &DataType) -> bool {
391    !matches!(
392        data_type,
393        DataType::FixedSizeList(_, _)
394            | DataType::Binary
395            | DataType::LargeBinary
396            | DataType::BinaryView
397            | DataType::FixedSizeBinary(_)
398    )
399}
400
401fn render_inline(
402    display: &[RecordBatch],
403    max_rows: usize,
404    elapsed: Duration,
405) -> Result<String, ArrowError> {
406    let total: usize = display.iter().map(RecordBatch::num_rows).sum();
407    let elapsed_ms = elapsed.as_millis();
408    if total == 0 {
409        // Still render the header so the caller sees the result columns.
410        return Ok(format!(
411            "0 rows ({elapsed_ms} ms).\n{}",
412            pretty_format_batches(display)?
413        ));
414    }
415    let mut shown = total.min(max_rows);
416    let mut table = pretty_format_batches(&limit_batches(display, shown))?.to_string();
417    while table.len() > INLINE_BUDGET_BYTES && shown > 1 {
418        shown = (shown / 2).max(1);
419        table = pretty_format_batches(&limit_batches(display, shown))?.to_string();
420    }
421    let mut out = format!("{total} row(s) in {elapsed_ms} ms; showing {shown}.\n{table}");
422    if shown < total {
423        out.push_str(&format!(
424            "\n... {} row(s) omitted. To page: ORDER BY <indexed col> (e.g. timestamp, \
425             id), then in the next call add `WHERE (col, id) < (<last_col>, <last_id>)` - \
426             keyset pagination, see schema://pond-sql. For the full set: output=parquet \
427             or output=ndjson.",
428            total - shown
429        ));
430    }
431    Ok(out)
432}
433
434/// JSON sibling of `render_inline`: same row cap and byte-budget shrinking,
435/// returned as a `JsonValue` so the MCP layer can hand it to
436/// `CallToolResult::structured` (text fallback + structured channel in one
437/// call, see [`Mode::InlineJson`]).
438fn render_inline_json(
439    display: &[RecordBatch],
440    max_rows: usize,
441    elapsed: Duration,
442) -> Result<JsonValue, SqlError> {
443    let total: usize = display.iter().map(RecordBatch::num_rows).sum();
444    let columns: Vec<String> = display
445        .first()
446        .map(|batch| {
447            batch
448                .schema()
449                .fields()
450                .iter()
451                .map(|field| field.name().clone())
452                .collect()
453        })
454        .unwrap_or_default();
455    let elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX);
456
457    if total == 0 {
458        return Ok(json!({
459            "total_rows": 0,
460            "shown_rows": 0,
461            "truncated": false,
462            "elapsed_ms": elapsed_ms,
463            "columns": columns,
464            "rows": [],
465        }));
466    }
467
468    let mut shown = total.min(max_rows);
469    let mut rows = batches_to_json_rows(&limit_batches(display, shown))?;
470    let mut serialized = serde_json::to_string(&rows)
471        .map_err(|error| SqlError::Infra(anyhow!("json serialize: {error}")))?;
472    while serialized.len() > INLINE_BUDGET_BYTES && shown > 1 {
473        shown = (shown / 2).max(1);
474        rows = batches_to_json_rows(&limit_batches(display, shown))?;
475        serialized = serde_json::to_string(&rows)
476            .map_err(|error| SqlError::Infra(anyhow!("json serialize: {error}")))?;
477    }
478
479    let mut payload = JsonMap::new();
480    payload.insert("total_rows".to_owned(), json!(total));
481    payload.insert("shown_rows".to_owned(), json!(shown));
482    payload.insert("truncated".to_owned(), json!(shown < total));
483    payload.insert("elapsed_ms".to_owned(), json!(elapsed_ms));
484    payload.insert("columns".to_owned(), json!(columns));
485    payload.insert("rows".to_owned(), JsonValue::Array(rows));
486    if shown < total {
487        payload.insert(
488            "next_steps".to_owned(),
489            json!(format!(
490                "{} row(s) omitted; ORDER BY + keyset (`WHERE (col, id) < \
491                 (<last_col>, <last_id>)`) to page, or output=parquet|ndjson for the \
492                 full set. See schema://pond-sql.",
493                total - shown
494            )),
495        );
496    }
497    Ok(JsonValue::Object(payload))
498}
499
500/// Convert RecordBatches to a JSON array of row objects via the existing
501/// NDJSON writer (handles all Arrow types, including the decoded JSON columns
502/// that come out of `displayable`).
503fn batches_to_json_rows(batches: &[RecordBatch]) -> Result<Vec<JsonValue>, SqlError> {
504    if batches.iter().all(|batch| batch.num_rows() == 0) {
505        return Ok(Vec::new());
506    }
507    let mut buffer = Vec::new();
508    {
509        let mut writer = LineDelimitedWriter::new(&mut buffer);
510        let refs: Vec<&RecordBatch> = batches.iter().collect();
511        writer
512            .write_batches(&refs)
513            .map_err(|error| SqlError::Infra(anyhow!("ndjson encode: {error}")))?;
514        writer
515            .finish()
516            .map_err(|error| SqlError::Infra(anyhow!("ndjson finish: {error}")))?;
517    }
518    let text = String::from_utf8(buffer)
519        .map_err(|error| SqlError::Infra(anyhow!("ndjson not utf-8: {error}")))?;
520    text.lines()
521        .filter(|line| !line.is_empty())
522        .map(|line| {
523            serde_json::from_str::<JsonValue>(line)
524                .map_err(|error| SqlError::Infra(anyhow!("ndjson parse: {error}")))
525        })
526        .collect()
527}
528
529fn limit_batches(batches: &[RecordBatch], max_rows: usize) -> Vec<RecordBatch> {
530    let mut out = Vec::new();
531    let mut remaining = max_rows;
532    for batch in batches {
533        if remaining == 0 {
534            break;
535        }
536        if batch.num_rows() <= remaining {
537            remaining -= batch.num_rows();
538            out.push(batch.clone());
539        } else {
540            out.push(batch.slice(0, remaining));
541            remaining = 0;
542        }
543    }
544    out
545}
546
547fn encode_parquet(batches: &[RecordBatch]) -> Result<Vec<u8>, SqlError> {
548    let schema = batches
549        .first()
550        .map(RecordBatch::schema)
551        .ok_or_else(|| SqlError::Query("query returned no columns to export".to_owned()))?;
552    let mut buffer = Vec::new();
553    let mut writer = ArrowWriter::try_new(&mut buffer, schema, None)
554        .map_err(|error| SqlError::Infra(anyhow!("parquet init failed: {error}")))?;
555    for batch in batches {
556        writer
557            .write(batch)
558            .map_err(|error| SqlError::Infra(anyhow!("parquet write failed: {error}")))?;
559    }
560    writer
561        .close()
562        .map_err(|error| SqlError::Infra(anyhow!("parquet close failed: {error}")))?;
563    Ok(buffer)
564}
565
566fn encode_ndjson(batches: &[RecordBatch]) -> Result<Vec<u8>, SqlError> {
567    let mut buffer = Vec::new();
568    {
569        let mut writer = LineDelimitedWriter::new(&mut buffer);
570        let refs: Vec<&RecordBatch> = batches.iter().collect();
571        writer
572            .write_batches(&refs)
573            .map_err(|error| SqlError::Infra(anyhow!("ndjson write failed: {error}")))?;
574        writer
575            .finish()
576            .map_err(|error| SqlError::Infra(anyhow!("ndjson finish failed: {error}")))?;
577    }
578    Ok(buffer)
579}
580
#[cfg(test)]
mod tests {
    use super::*;

    /// True when the gate refuses `sql` with a caller-fixable error.
    fn rejected(sql: &str) -> bool {
        matches!(parse_and_gate(sql), Err(SqlError::Query(_)))
    }

    /// True when `sql` passes the gate and is classified as `expected`.
    fn parses_as(sql: &str, expected: StatementKind) -> bool {
        parse_and_gate(sql)
            .map(|parsed| parsed.kind == expected)
            .unwrap_or(false)
    }

    /// True when `sql` passes the gate and its projection names `vector`.
    fn mentions_vector(sql: &str) -> bool {
        parse_and_gate(sql)
            .map(|parsed| projection_mentions_vector(parsed.projection_query()))
            .unwrap_or(false)
    }

    #[test]
    fn allows_single_select_and_cte() {
        for sql in [
            "SELECT 1",
            "SELECT role, count(*) FROM messages GROUP BY role",
            "WITH t AS (SELECT 1 AS a) SELECT a FROM t",
        ] {
            assert!(
                parses_as(sql, StatementKind::Query),
                "expected a plain query: {sql}"
            );
        }
    }

    #[test]
    fn allows_explain_of_select() {
        for sql in [
            "EXPLAIN SELECT 1",
            "EXPLAIN ANALYZE SELECT role FROM messages",
        ] {
            assert!(
                parses_as(sql, StatementKind::Explain),
                "expected an EXPLAIN: {sql}"
            );
        }
    }

    #[test]
    fn rejects_explain_of_non_query() {
        // What matters under an EXPLAIN is the inner statement; a
        // side-effecting one is refused to keep the surface tight.
        assert!(rejected("EXPLAIN INSERT INTO messages VALUES ('x')"));
    }

    #[test]
    fn rejects_writes_and_side_effects() {
        for sql in [
            "INSERT INTO messages VALUES ('x')",
            "UPDATE messages SET role = 'x'",
            "DELETE FROM messages",
            "CREATE TABLE t (x INT)",
            "CREATE VIEW v AS SELECT 1",
            "DROP TABLE messages",
            "CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION '/etc'",
            "COPY (SELECT 1) TO '/tmp/x.parquet'",
            "SET a = 1",
        ] {
            assert!(rejected(sql), "expected rejection: {sql}");
        }
    }

    #[test]
    fn rejects_multiple_statements() {
        for sql in ["SELECT 1; SELECT 2", "SELECT 1; DROP TABLE messages"] {
            assert!(rejected(sql), "expected rejection: {sql}");
        }
    }

    #[test]
    fn rejects_unparseable() {
        assert!(rejected("NOT SQL AT ALL ;;"));
    }

    #[test]
    fn explicit_vector_projection_is_rejected() {
        for sql in [
            "SELECT vector FROM messages",
            "SELECT id, vector FROM messages",
            "SELECT m.vector FROM messages m",
            "SELECT array_length(vector) FROM messages",
            "EXPLAIN SELECT vector FROM messages",
        ] {
            assert!(mentions_vector(sql), "expected a vector mention: {sql}");
        }
    }

    #[test]
    fn select_star_and_where_vector_are_allowed() {
        // A wildcard does not name the column; it falls through to the
        // silent strip in `displayable`.
        assert!(!mentions_vector("SELECT * FROM messages"));
        // Filtering on `vector` is documented as legal (`vector IS NOT NULL`).
        assert!(!mentions_vector(
            "SELECT id FROM messages WHERE vector IS NOT NULL"
        ));
    }
}