Skip to main content

palimpsest_sql/
parser.rs

1// Copyright 2026 Thousand Birds Inc.
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! SQL parsing entry points (Postgres dialect).
5
6use core::ops::ControlFlow;
7
8use sqlparser::{
9    ast::{
10        BinaryOperator, Expr, JoinConstraint, JoinOperator, Query, Select, SetExpr, Statement,
11        TableFactor, TableWithJoins, Visit, Visitor,
12    },
13    dialect::PostgreSqlDialect,
14    parser::Parser,
15};
16
17use crate::{limits::enforce_input_size, QueryLimits, SqlError};
18
19/// Parses a single `SELECT` statement under the default
20/// [`QueryLimits`].
21///
22/// # Errors
23/// Returns [`SqlError`] if the input fails parsing, validation, or
24/// size limits.
25pub fn parse_select(sql: &str) -> Result<Statement, SqlError> {
26    parse_select_with_limits(sql, QueryLimits::DEFAULT)
27}
28
29/// Like [`parse_select`] but with a caller-supplied [`QueryLimits`].
30///
31/// # Errors
32/// Returns [`SqlError::QueryTooLarge`] before invoking the parser if
33/// the input exceeds `limits.max_input_bytes`; otherwise propagates
34/// any parse / validation error.
35pub fn parse_select_with_limits(sql: &str, limits: QueryLimits) -> Result<Statement, SqlError> {
36    enforce_input_size(sql, limits)?;
37    let dialect = PostgreSqlDialect {};
38    let mut statements = Parser::parse_sql(&dialect, sql)?;
39
40    if statements.len() != 1 {
41        return Err(SqlError::StatementCount(statements.len()));
42    }
43
44    let statement = statements.remove(0);
45    let Statement::Query(query) = &statement else {
46        return Err(SqlError::UnsupportedStatement);
47    };
48
49    validate_query(query)?;
50    Ok(statement)
51}
52
53/// Walks a parsed query tree and rejects features outside the v1
54/// supported surface (recursive CTEs, ORDER BY without LIMIT, etc).
55///
56/// # Errors
57/// Returns [`SqlError::UnsupportedFeature`] (or related variants) on
58/// the first construct that lies outside the supported surface.
59pub fn validate_query(query: &Query) -> Result<(), SqlError> {
60    if let Some(with) = &query.with {
61        if with.recursive {
62            return Err(SqlError::UnsupportedFeature("recursive CTEs"));
63        }
64
65        for cte in &with.cte_tables {
66            validate_query(&cte.query)?;
67        }
68    }
69
70    // ORDER BY without LIMIT is allowed: the lowerer plans it as a
71    // `TopK` with `limit = usize::MAX`, i.e. "sort the whole result
72    // set." Cheap for the small result-sets typical of live
73    // subscriptions; the QueryLimits node-count budget is the real
74    // ceiling on how big a sort the server will accept.
75
76    validate_expression_surface(query)?;
77    validate_set_expr(&query.body)
78}
79
80fn validate_set_expr(expr: &SetExpr) -> Result<(), SqlError> {
81    match expr {
82        SetExpr::Select(select) => validate_select(select),
83        SetExpr::Query(query) => validate_query(query),
84        SetExpr::SetOperation { left, right, .. } => {
85            validate_set_expr(left)?;
86            validate_set_expr(right)
87        }
88        SetExpr::Values(_) => Err(SqlError::UnsupportedFeature("VALUES queries")),
89        SetExpr::Insert(_) => Err(SqlError::UnsupportedFeature("INSERT in query body")),
90        SetExpr::Update(_) => Err(SqlError::UnsupportedFeature("UPDATE in query body")),
91        SetExpr::Table(_) => Err(SqlError::UnsupportedFeature("TABLE queries")),
92    }
93}
94
95fn validate_select(select: &Select) -> Result<(), SqlError> {
96    for table in &select.from {
97        validate_table_with_joins(table)?;
98    }
99
100    Ok(())
101}
102
103fn validate_table_with_joins(table: &TableWithJoins) -> Result<(), SqlError> {
104    validate_table_factor(&table.relation)?;
105
106    for join in &table.joins {
107        match &join.join_operator {
108            JoinOperator::RightOuter(_) => {
109                return Err(SqlError::UnsupportedFeature("RIGHT JOIN"));
110            }
111            JoinOperator::FullOuter(_) => {
112                return Err(SqlError::UnsupportedFeature("FULL JOIN"));
113            }
114            JoinOperator::Inner(constraint) | JoinOperator::LeftOuter(constraint) => {
115                validate_join_constraint(constraint)?;
116            }
117            JoinOperator::CrossJoin => {}
118            _ => return Err(SqlError::UnsupportedFeature("non-standard joins")),
119        }
120
121        validate_table_factor(&join.relation)?;
122    }
123
124    Ok(())
125}
126
127fn validate_table_factor(table: &TableFactor) -> Result<(), SqlError> {
128    match table {
129        TableFactor::Table { .. } => Ok(()),
130        TableFactor::Derived { subquery, .. } => validate_query(subquery),
131        _ => Err(SqlError::UnsupportedFeature(
132            "table functions or special table factors",
133        )),
134    }
135}
136
137fn validate_join_constraint(constraint: &JoinConstraint) -> Result<(), SqlError> {
138    match constraint {
139        JoinConstraint::On(expr) if is_equi_join_predicate(expr) => Ok(()),
140        JoinConstraint::On(_) => Err(SqlError::UnsupportedFeature("theta joins")),
141        JoinConstraint::Using(_) | JoinConstraint::Natural | JoinConstraint::None => Ok(()),
142    }
143}
144
145fn is_equi_join_predicate(expr: &Expr) -> bool {
146    match expr {
147        Expr::BinaryOp { left, op, right } if *op == BinaryOperator::Eq => {
148            matches!(
149                left.as_ref(),
150                Expr::Identifier(_) | Expr::CompoundIdentifier(_)
151            ) && matches!(
152                right.as_ref(),
153                Expr::Identifier(_) | Expr::CompoundIdentifier(_)
154            )
155        }
156        Expr::BinaryOp {
157            left,
158            op: BinaryOperator::And,
159            right,
160        } => is_equi_join_predicate(left) && is_equi_join_predicate(right),
161        _ => false,
162    }
163}
164
165fn validate_expression_surface(query: &Query) -> Result<(), SqlError> {
166    let mut visitor = UnsupportedExprVisitor;
167    match query.visit(&mut visitor) {
168        ControlFlow::Continue(()) => Ok(()),
169        ControlFlow::Break(feature) => Err(SqlError::UnsupportedFeature(feature)),
170    }
171}
172
173struct UnsupportedExprVisitor;
174
175impl Visitor for UnsupportedExprVisitor {
176    type Break = &'static str;
177
178    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
179        match expr {
180            Expr::Function(function) if function.over.is_some() => {
181                ControlFlow::Break("window functions")
182            }
183            Expr::Exists { .. } | Expr::InSubquery { .. } | Expr::Subquery(_) => {
184                ControlFlow::Break("scalar subqueries with unbounded result")
185            }
186            _ => ControlFlow::Continue(()),
187        }
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::parse_select;

    /// Asserts that `sql` is rejected (panicking with `why` if it is
    /// accepted) and that the rendered error mentions `needle`.
    fn assert_rejected(sql: &str, why: &str, needle: &str) {
        let err = parse_select(sql).expect_err(why);
        assert!(err.to_string().contains(needle));
    }

    #[test]
    fn parses_postgres_cte() {
        let sql = "WITH recent_posts AS (
                SELECT id, author_id FROM posts WHERE created_at > now() - interval '1 day'
             )
             SELECT id FROM recent_posts ORDER BY id LIMIT 10";

        parse_select(sql).expect("CTE query should parse");
    }

    #[test]
    fn rejects_recursive_cte() {
        assert_rejected(
            "WITH RECURSIVE nums(n) AS (
                SELECT 1 UNION ALL SELECT n + 1 FROM nums WHERE n < 10
             )
             SELECT n FROM nums",
            "recursive CTEs are out of scope for v1",
            "recursive CTEs",
        );
    }

    #[test]
    fn accepts_order_by_without_limit() {
        // Lowered as TopK { limit: usize::MAX } — "sort the whole
        // result set." See `lower.rs::lower_query_with_context`.
        let sql = "SELECT id FROM posts ORDER BY created_at";

        parse_select(sql).expect("ORDER BY without LIMIT is supported");
    }

    #[test]
    fn rejects_right_join() {
        assert_rejected(
            "SELECT posts.id
             FROM posts RIGHT JOIN authors ON posts.author_id = authors.id",
            "RIGHT JOIN is out of scope for v1",
            "RIGHT JOIN",
        );
    }

    #[test]
    fn rejects_theta_join() {
        assert_rejected(
            "SELECT posts.id
             FROM posts JOIN authors ON posts.author_id > authors.id",
            "theta joins are out of scope for v1",
            "theta joins",
        );
    }

    #[test]
    fn rejects_window_functions() {
        assert_rejected(
            "SELECT row_number() OVER (PARTITION BY author_id ORDER BY created_at)
             FROM posts",
            "window functions are out of scope for v1",
            "window functions",
        );
    }

    #[test]
    fn rejects_scalar_subqueries() {
        assert_rejected(
            "SELECT (SELECT max(id) FROM posts) FROM authors",
            "scalar subqueries are out of scope for v1",
            "scalar subqueries",
        );
    }
}