Skip to main content

rigsql_parser/grammar/
postgres.rs

1use rigsql_core::{NodeSegment, Segment, SegmentType, TokenKind};
2
3use crate::context::ParseContext;
4
5use super::ansi::ANSI_STATEMENT_KEYWORDS;
6use super::{eat_trivia_segments, parse_comma_separated, token_segment, Grammar};
7
/// PostgreSQL grammar — extends the ANSI grammar with PostgreSQL-specific syntax
/// (`DISTINCT ON`, `ON CONFLICT`, `RETURNING`, `::` casts and array subscripts).
///
/// The type is a field-less unit struct, so it can be freely copied, compared in
/// debug output, and created through `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct PostgresGrammar;
10
11impl Grammar for PostgresGrammar {
12    fn statement_keywords(&self) -> &[&str] {
13        ANSI_STATEMENT_KEYWORDS
14    }
15
16    fn dispatch_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
17        self.dispatch_ansi_statement(ctx)
18    }
19
20    // ── Override: SELECT clause to support DISTINCT ON ────────────
21
22    fn parse_select_clause(&self, ctx: &mut ParseContext) -> Option<Segment> {
23        let mut children = Vec::new();
24
25        let kw = ctx.eat_keyword("SELECT")?;
26        children.push(token_segment(kw, SegmentType::Keyword));
27        children.extend(eat_trivia_segments(ctx));
28
29        // DISTINCT ON (expr, ...) or DISTINCT or ALL
30        if ctx.peek_keyword("DISTINCT") {
31            let distinct_kw = ctx.advance().unwrap();
32            children.push(token_segment(distinct_kw, SegmentType::Keyword));
33            children.extend(eat_trivia_segments(ctx));
34
35            // PostgreSQL: DISTINCT ON (col1, col2, ...)
36            if ctx.peek_keyword("ON") {
37                let on_kw = ctx.advance().unwrap();
38                children.push(token_segment(on_kw, SegmentType::Keyword));
39                children.extend(eat_trivia_segments(ctx));
40
41                if ctx.peek_kind() == Some(TokenKind::LParen) {
42                    if let Some(cols) = self.parse_paren_block(ctx) {
43                        children.push(cols);
44                    }
45                }
46                children.extend(eat_trivia_segments(ctx));
47            }
48        } else if ctx.peek_keyword("ALL") {
49            let all_kw = ctx.advance().unwrap();
50            children.push(token_segment(all_kw, SegmentType::Keyword));
51            children.extend(eat_trivia_segments(ctx));
52        }
53
54        // Select targets (comma-separated expressions)
55        parse_comma_separated(ctx, &mut children, |c| self.parse_select_target(c));
56
57        Some(Segment::Node(NodeSegment::new(
58            SegmentType::SelectClause,
59            children,
60        )))
61    }
62
63    // ── Override: INSERT to support ON CONFLICT and RETURNING ──────
64
65    fn parse_insert_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
66        let mut children = Vec::new();
67        let kw = ctx.eat_keyword("INSERT")?;
68        children.push(token_segment(kw, SegmentType::Keyword));
69        children.extend(eat_trivia_segments(ctx));
70
71        let into_kw = ctx.eat_keyword("INTO")?;
72        children.push(token_segment(into_kw, SegmentType::Keyword));
73        children.extend(eat_trivia_segments(ctx));
74
75        // Table name
76        if let Some(name) = self.parse_qualified_name(ctx) {
77            children.push(name);
78        }
79        children.extend(eat_trivia_segments(ctx));
80
81        // Optional column list
82        if ctx.peek_kind() == Some(TokenKind::LParen) {
83            if let Some(cols) = self.parse_paren_block(ctx) {
84                children.push(cols);
85                children.extend(eat_trivia_segments(ctx));
86            }
87        }
88
89        // VALUES or SELECT
90        if ctx.peek_keyword("VALUES") {
91            if let Some(vals) = self.parse_values_clause(ctx) {
92                children.push(vals);
93            }
94        } else if ctx.peek_keyword("SELECT") || ctx.peek_keyword("WITH") {
95            if let Some(sel) = self.parse_select_statement(ctx) {
96                children.push(sel);
97            }
98        }
99
100        // ON CONFLICT clause (PostgreSQL upsert)
101        children.extend(eat_trivia_segments(ctx));
102        if ctx.peek_keyword("ON") {
103            if let Some(oc) = self.parse_on_conflict_clause(ctx) {
104                children.push(oc);
105            }
106        }
107
108        // RETURNING clause
109        children.extend(eat_trivia_segments(ctx));
110        if ctx.peek_keyword("RETURNING") {
111            if let Some(ret) = self.parse_returning_clause(ctx) {
112                children.push(ret);
113            }
114        }
115
116        Some(Segment::Node(NodeSegment::new(
117            SegmentType::InsertStatement,
118            children,
119        )))
120    }
121
122    // ── Override: UPDATE to support RETURNING ──────────────────────
123
124    fn parse_update_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
125        let mut children = Vec::new();
126        let kw = ctx.eat_keyword("UPDATE")?;
127        children.push(token_segment(kw, SegmentType::Keyword));
128        children.extend(eat_trivia_segments(ctx));
129
130        // Table name
131        if let Some(name) = self.parse_table_reference(ctx) {
132            children.push(name);
133        }
134        children.extend(eat_trivia_segments(ctx));
135
136        // SET clause
137        if ctx.peek_keyword("SET") {
138            if let Some(set) = self.parse_set_clause(ctx) {
139                children.push(set);
140            }
141        }
142
143        // WHERE clause
144        children.extend(eat_trivia_segments(ctx));
145        if ctx.peek_keyword("WHERE") {
146            if let Some(wh) = self.parse_where_clause(ctx) {
147                children.push(wh);
148            }
149        }
150
151        // RETURNING clause
152        children.extend(eat_trivia_segments(ctx));
153        if ctx.peek_keyword("RETURNING") {
154            if let Some(ret) = self.parse_returning_clause(ctx) {
155                children.push(ret);
156            }
157        }
158
159        Some(Segment::Node(NodeSegment::new(
160            SegmentType::UpdateStatement,
161            children,
162        )))
163    }
164
165    // ── Override: DELETE to support RETURNING ──────────────────────
166
167    fn parse_delete_statement(&self, ctx: &mut ParseContext) -> Option<Segment> {
168        let mut children = Vec::new();
169        let kw = ctx.eat_keyword("DELETE")?;
170        children.push(token_segment(kw, SegmentType::Keyword));
171        children.extend(eat_trivia_segments(ctx));
172
173        // FROM
174        if ctx.peek_keyword("FROM") {
175            let from_kw = ctx.advance().unwrap();
176            children.push(token_segment(from_kw, SegmentType::Keyword));
177            children.extend(eat_trivia_segments(ctx));
178        }
179
180        // Table name
181        if let Some(name) = self.parse_qualified_name(ctx) {
182            children.push(name);
183        }
184
185        // WHERE clause
186        children.extend(eat_trivia_segments(ctx));
187        if ctx.peek_keyword("WHERE") {
188            if let Some(wh) = self.parse_where_clause(ctx) {
189                children.push(wh);
190            }
191        }
192
193        // RETURNING clause
194        children.extend(eat_trivia_segments(ctx));
195        if ctx.peek_keyword("RETURNING") {
196            if let Some(ret) = self.parse_returning_clause(ctx) {
197                children.push(ret);
198            }
199        }
200
201        Some(Segment::Node(NodeSegment::new(
202            SegmentType::DeleteStatement,
203            children,
204        )))
205    }
206
207    // ── Override: unary expression to add :: postfix cast ──────────
208
209    fn parse_unary_expression(&self, ctx: &mut ParseContext) -> Option<Segment> {
210        // Handle unary +/- prefix
211        if let Some(kind) = ctx.peek_kind() {
212            if matches!(kind, TokenKind::Plus | TokenKind::Minus) {
213                let op = ctx.advance().unwrap();
214                let mut children = vec![token_segment(op, SegmentType::ArithmeticOperator)];
215                children.extend(eat_trivia_segments(ctx));
216                if let Some(expr) = self.parse_primary_expression(ctx) {
217                    children.push(expr);
218                }
219                let base = Segment::Node(NodeSegment::new(SegmentType::UnaryExpression, children));
220                return Some(self.parse_postfix(ctx, base));
221            }
222        }
223        let base = self.parse_primary_expression(ctx)?;
224        Some(self.parse_postfix(ctx, base))
225    }
226}
227
228// ── PostgreSQL-specific parsing methods ────────────────────────────
229
230impl PostgresGrammar {
231    /// Parse postfix operators: `::type` cast and `[idx]` array subscript.
232    /// Loops to handle chaining: `arr[1]::text`, `col::int[]`.
233    fn parse_postfix(&self, ctx: &mut ParseContext, mut expr: Segment) -> Segment {
234        loop {
235            // Peek ahead past trivia without consuming, to avoid Vec allocation
236            // on the common path where no postfix operator follows.
237            let save = ctx.save();
238            eat_trivia_segments(ctx);
239            let next = ctx.peek_kind();
240            ctx.restore(save);
241
242            if next != Some(TokenKind::ColonColon) && next != Some(TokenKind::LBracket) {
243                break;
244            }
245
246            let trivia = eat_trivia_segments(ctx);
247
248            // :: type cast
249            if ctx.peek_kind() == Some(TokenKind::ColonColon) {
250                let cc = ctx.advance().unwrap();
251                let mut children = vec![expr];
252                children.extend(trivia);
253                children.push(token_segment(cc, SegmentType::Operator));
254                children.extend(eat_trivia_segments(ctx));
255                if let Some(dt) = self.parse_data_type(ctx) {
256                    children.push(dt);
257                }
258                // Handle array type suffix: ::int[]
259                let save2 = ctx.save();
260                if ctx.peek_kind() == Some(TokenKind::LBracket) {
261                    let lb = ctx.advance().unwrap();
262                    if ctx.peek_kind() == Some(TokenKind::RBracket) {
263                        let rb = ctx.advance().unwrap();
264                        children.push(token_segment(lb, SegmentType::Operator));
265                        children.push(token_segment(rb, SegmentType::Operator));
266                    } else {
267                        ctx.restore(save2);
268                    }
269                }
270                expr = Segment::Node(NodeSegment::new(SegmentType::TypeCastExpression, children));
271                continue;
272            }
273
274            // [idx] array subscript
275            if ctx.peek_kind() == Some(TokenKind::LBracket) {
276                let lb = ctx.advance().unwrap();
277                let mut children = vec![expr];
278                children.extend(trivia);
279                children.push(token_segment(lb, SegmentType::Operator));
280                children.extend(eat_trivia_segments(ctx));
281                if let Some(idx) = self.parse_expression(ctx) {
282                    children.push(idx);
283                }
284                children.extend(eat_trivia_segments(ctx));
285                if let Some(rb) = ctx.eat_kind(TokenKind::RBracket) {
286                    children.push(token_segment(rb, SegmentType::Operator));
287                }
288                expr = Segment::Node(NodeSegment::new(
289                    SegmentType::ArrayAccessExpression,
290                    children,
291                ));
292                continue;
293            }
294
295            // Unreachable: we only enter the loop body when next is :: or [
296            unreachable!();
297        }
298        expr
299    }
300
301    /// Parse RETURNING clause: `RETURNING expr, expr, ...`
302    fn parse_returning_clause(&self, ctx: &mut ParseContext) -> Option<Segment> {
303        let mut children = Vec::new();
304        let kw = ctx.eat_keyword("RETURNING")?;
305        children.push(token_segment(kw, SegmentType::Keyword));
306        children.extend(eat_trivia_segments(ctx));
307
308        parse_comma_separated(ctx, &mut children, |c| self.parse_select_target(c));
309
310        Some(Segment::Node(NodeSegment::new(
311            SegmentType::ReturningClause,
312            children,
313        )))
314    }
315
316    /// Parse ON CONFLICT clause:
317    /// `ON CONFLICT (col, ...) DO NOTHING`
318    /// `ON CONFLICT (col, ...) DO UPDATE SET col = expr, ...`
319    fn parse_on_conflict_clause(&self, ctx: &mut ParseContext) -> Option<Segment> {
320        let save = ctx.save();
321        let mut children = Vec::new();
322
323        let on_kw = ctx.eat_keyword("ON")?;
324        let trivia = eat_trivia_segments(ctx);
325        if !ctx.peek_keyword("CONFLICT") {
326            ctx.restore(save);
327            return None;
328        }
329        children.push(token_segment(on_kw, SegmentType::Keyword));
330        children.extend(trivia);
331
332        let conflict_kw = ctx.advance().unwrap();
333        children.push(token_segment(conflict_kw, SegmentType::Keyword));
334        children.extend(eat_trivia_segments(ctx));
335
336        // Optional conflict target: (column, ...) or ON CONSTRAINT name
337        if ctx.peek_kind() == Some(TokenKind::LParen) {
338            if let Some(cols) = self.parse_paren_block(ctx) {
339                children.push(cols);
340            }
341            children.extend(eat_trivia_segments(ctx));
342        } else if ctx.peek_keyword("ON") {
343            // ON CONSTRAINT constraint_name
344            let on2 = ctx.advance().unwrap();
345            children.push(token_segment(on2, SegmentType::Keyword));
346            children.extend(eat_trivia_segments(ctx));
347            if ctx.peek_keyword("CONSTRAINT") {
348                let cons_kw = ctx.advance().unwrap();
349                children.push(token_segment(cons_kw, SegmentType::Keyword));
350                children.extend(eat_trivia_segments(ctx));
351                if let Some(name) = self.parse_identifier(ctx) {
352                    children.push(name);
353                }
354                children.extend(eat_trivia_segments(ctx));
355            }
356        }
357
358        // WHERE clause for conflict target (partial index)
359        if ctx.peek_keyword("WHERE") {
360            if let Some(wh) = self.parse_where_clause(ctx) {
361                children.push(wh);
362                children.extend(eat_trivia_segments(ctx));
363            }
364        }
365
366        // DO NOTHING or DO UPDATE SET ...
367        if ctx.peek_keyword("DO") {
368            let do_kw = ctx.advance().unwrap();
369            children.push(token_segment(do_kw, SegmentType::Keyword));
370            children.extend(eat_trivia_segments(ctx));
371
372            if ctx.peek_keyword("NOTHING") {
373                let nothing_kw = ctx.advance().unwrap();
374                children.push(token_segment(nothing_kw, SegmentType::Keyword));
375            } else if ctx.peek_keyword("UPDATE") {
376                let update_kw = ctx.advance().unwrap();
377                children.push(token_segment(update_kw, SegmentType::Keyword));
378                children.extend(eat_trivia_segments(ctx));
379
380                // SET clause
381                if ctx.peek_keyword("SET") {
382                    if let Some(set) = self.parse_set_clause(ctx) {
383                        children.push(set);
384                    }
385                }
386
387                // WHERE clause (for DO UPDATE)
388                children.extend(eat_trivia_segments(ctx));
389                if ctx.peek_keyword("WHERE") {
390                    if let Some(wh) = self.parse_where_clause(ctx) {
391                        children.push(wh);
392                    }
393                }
394            }
395        }
396
397        Some(Segment::Node(NodeSegment::new(
398            SegmentType::OnConflictClause,
399            children,
400        )))
401    }
402}