1use core::ops::ControlFlow;
7
8use sqlparser::{
9 ast::{
10 BinaryOperator, Expr, JoinConstraint, JoinOperator, Query, Select, SetExpr, Statement,
11 TableFactor, TableWithJoins, Visit, Visitor,
12 },
13 dialect::PostgreSqlDialect,
14 parser::Parser,
15};
16
17use crate::{limits::enforce_input_size, QueryLimits, SqlError};
18
19pub fn parse_select(sql: &str) -> Result<Statement, SqlError> {
26 parse_select_with_limits(sql, QueryLimits::DEFAULT)
27}
28
29pub fn parse_select_with_limits(sql: &str, limits: QueryLimits) -> Result<Statement, SqlError> {
36 enforce_input_size(sql, limits)?;
37 let dialect = PostgreSqlDialect {};
38 let mut statements = Parser::parse_sql(&dialect, sql)?;
39
40 if statements.len() != 1 {
41 return Err(SqlError::StatementCount(statements.len()));
42 }
43
44 let statement = statements.remove(0);
45 let Statement::Query(query) = &statement else {
46 return Err(SqlError::UnsupportedStatement);
47 };
48
49 validate_query(query)?;
50 Ok(statement)
51}
52
53pub fn validate_query(query: &Query) -> Result<(), SqlError> {
60 if let Some(with) = &query.with {
61 if with.recursive {
62 return Err(SqlError::UnsupportedFeature("recursive CTEs"));
63 }
64
65 for cte in &with.cte_tables {
66 validate_query(&cte.query)?;
67 }
68 }
69
70 validate_expression_surface(query)?;
77 validate_set_expr(&query.body)
78}
79
80fn validate_set_expr(expr: &SetExpr) -> Result<(), SqlError> {
81 match expr {
82 SetExpr::Select(select) => validate_select(select),
83 SetExpr::Query(query) => validate_query(query),
84 SetExpr::SetOperation { left, right, .. } => {
85 validate_set_expr(left)?;
86 validate_set_expr(right)
87 }
88 SetExpr::Values(_) => Err(SqlError::UnsupportedFeature("VALUES queries")),
89 SetExpr::Insert(_) => Err(SqlError::UnsupportedFeature("INSERT in query body")),
90 SetExpr::Update(_) => Err(SqlError::UnsupportedFeature("UPDATE in query body")),
91 SetExpr::Table(_) => Err(SqlError::UnsupportedFeature("TABLE queries")),
92 }
93}
94
95fn validate_select(select: &Select) -> Result<(), SqlError> {
96 for table in &select.from {
97 validate_table_with_joins(table)?;
98 }
99
100 Ok(())
101}
102
103fn validate_table_with_joins(table: &TableWithJoins) -> Result<(), SqlError> {
104 validate_table_factor(&table.relation)?;
105
106 for join in &table.joins {
107 match &join.join_operator {
108 JoinOperator::RightOuter(_) => {
109 return Err(SqlError::UnsupportedFeature("RIGHT JOIN"));
110 }
111 JoinOperator::FullOuter(_) => {
112 return Err(SqlError::UnsupportedFeature("FULL JOIN"));
113 }
114 JoinOperator::Inner(constraint) | JoinOperator::LeftOuter(constraint) => {
115 validate_join_constraint(constraint)?;
116 }
117 JoinOperator::CrossJoin => {}
118 _ => return Err(SqlError::UnsupportedFeature("non-standard joins")),
119 }
120
121 validate_table_factor(&join.relation)?;
122 }
123
124 Ok(())
125}
126
127fn validate_table_factor(table: &TableFactor) -> Result<(), SqlError> {
128 match table {
129 TableFactor::Table { .. } => Ok(()),
130 TableFactor::Derived { subquery, .. } => validate_query(subquery),
131 _ => Err(SqlError::UnsupportedFeature(
132 "table functions or special table factors",
133 )),
134 }
135}
136
137fn validate_join_constraint(constraint: &JoinConstraint) -> Result<(), SqlError> {
138 match constraint {
139 JoinConstraint::On(expr) if is_equi_join_predicate(expr) => Ok(()),
140 JoinConstraint::On(_) => Err(SqlError::UnsupportedFeature("theta joins")),
141 JoinConstraint::Using(_) | JoinConstraint::Natural | JoinConstraint::None => Ok(()),
142 }
143}
144
145fn is_equi_join_predicate(expr: &Expr) -> bool {
146 match expr {
147 Expr::BinaryOp { left, op, right } if *op == BinaryOperator::Eq => {
148 matches!(
149 left.as_ref(),
150 Expr::Identifier(_) | Expr::CompoundIdentifier(_)
151 ) && matches!(
152 right.as_ref(),
153 Expr::Identifier(_) | Expr::CompoundIdentifier(_)
154 )
155 }
156 Expr::BinaryOp {
157 left,
158 op: BinaryOperator::And,
159 right,
160 } => is_equi_join_predicate(left) && is_equi_join_predicate(right),
161 _ => false,
162 }
163}
164
165fn validate_expression_surface(query: &Query) -> Result<(), SqlError> {
166 let mut visitor = UnsupportedExprVisitor;
167 match query.visit(&mut visitor) {
168 ControlFlow::Continue(()) => Ok(()),
169 ControlFlow::Break(feature) => Err(SqlError::UnsupportedFeature(feature)),
170 }
171}
172
173struct UnsupportedExprVisitor;
174
175impl Visitor for UnsupportedExprVisitor {
176 type Break = &'static str;
177
178 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
179 match expr {
180 Expr::Function(function) if function.over.is_some() => {
181 ControlFlow::Break("window functions")
182 }
183 Expr::Exists { .. } | Expr::InSubquery { .. } | Expr::Subquery(_) => {
184 ControlFlow::Break("scalar subqueries with unbounded result")
185 }
186 _ => ControlFlow::Continue(()),
187 }
188 }
189}
190
#[cfg(test)]
mod tests {
    // Each test either expects `parse_select` to succeed, or expects it to fail
    // with an error whose `Display` text names the rejected feature.
    use super::parse_select;

    // A non-recursive CTE with ORDER BY / LIMIT on the outer query is accepted.
    #[test]
    fn parses_postgres_cte() {
        parse_select(
            "WITH recent_posts AS (
                SELECT id, author_id FROM posts WHERE created_at > now() - interval '1 day'
            )
            SELECT id FROM recent_posts ORDER BY id LIMIT 10",
        )
        .expect("CTE query should parse");
    }

    // `WITH RECURSIVE` is rejected before the CTE body is inspected.
    #[test]
    fn rejects_recursive_cte() {
        let err = parse_select(
            "WITH RECURSIVE nums(n) AS (
                SELECT 1 UNION ALL SELECT n + 1 FROM nums WHERE n < 10
            )
            SELECT n FROM nums",
        )
        .expect_err("recursive CTEs are out of scope for v1");

        assert!(err.to_string().contains("recursive CTEs"));
    }

    // ORDER BY does not require an accompanying LIMIT.
    #[test]
    fn accepts_order_by_without_limit() {
        parse_select("SELECT id FROM posts ORDER BY created_at")
            .expect("ORDER BY without LIMIT is supported");
    }

    #[test]
    fn rejects_right_join() {
        let err = parse_select(
            "SELECT posts.id
            FROM posts RIGHT JOIN authors ON posts.author_id = authors.id",
        )
        .expect_err("RIGHT JOIN is out of scope for v1");

        assert!(err.to_string().contains("RIGHT JOIN"));
    }

    // An ON predicate using `>` instead of `=` is a theta join.
    #[test]
    fn rejects_theta_join() {
        let err = parse_select(
            "SELECT posts.id
            FROM posts JOIN authors ON posts.author_id > authors.id",
        )
        .expect_err("theta joins are out of scope for v1");

        assert!(err.to_string().contains("theta joins"));
    }

    #[test]
    fn rejects_window_functions() {
        let err = parse_select(
            "SELECT row_number() OVER (PARTITION BY author_id ORDER BY created_at)
            FROM posts",
        )
        .expect_err("window functions are out of scope for v1");

        assert!(err.to_string().contains("window functions"));
    }

    // A subquery in the projection list is caught by the expression visitor.
    #[test]
    fn rejects_scalar_subqueries() {
        let err = parse_select("SELECT (SELECT max(id) FROM posts) FROM authors")
            .expect_err("scalar subqueries are out of scope for v1");

        assert!(err.to_string().contains("scalar subqueries"));
    }
}
267}