postgresql_cst_parser/
cst.rs

use cstree::{
    build::GreenNodeBuilder, green::GreenNode, interning::Resolver, RawSyntaxKind, Syntax,
};

use crate::{
    lexer::{lex, lexer_ported::init_tokens, parser_error::ParserError, TokenKind},
    parser::{
        end_rule_id, end_rule_kind, num_terminal_symbol, rule_name_to_component_id,
        token_kind_to_component_id, Action, ACTION_CHECK_TABLE, ACTION_DEF_RULE_TABLE,
        ACTION_TABLE, ACTION_TABLE_INDEX, GOTO_CHECK_TABLE, GOTO_TABLE, GOTO_TABLE_INDEX, RULES,
    },
};

use super::{lexer::Token, syntax_kind::SyntaxKind};

// Sentinel values used when decoding the compressed LR tables: a missing action
// entry, a "use this state's default rule" marker, and a missing goto entry.
const ERROR_ACTION_CODE: i16 = 0x7FFF;
const DEFAULT_ACTION_CODE: i16 = 0x7FFE;
const INVALID_GOTO_CODE: i16 = -1;

/// Intermediate parse-tree node kept on the LR stack: either a shifted token
/// (a leaf) or a reduced rule with its children, together with the byte range
/// it covers in the input.
struct Node {
    token: Option<Token>,
    component_id: u32,
    children: Vec<Node>,
    start_byte_pos: usize,
    end_byte_pos: usize,
}

pub type PostgreSQLSyntax = SyntaxKind;

impl From<SyntaxKind> for cstree::RawSyntaxKind {
    fn from(kind: SyntaxKind) -> Self {
        Self(kind as u32)
    }
}

pub type SyntaxNode = cstree::syntax::SyntaxNode<PostgreSQLSyntax>;
pub type ResolvedNode = cstree::syntax::ResolvedNode<PostgreSQLSyntax>;
pub type ResolvedToken = cstree::syntax::ResolvedToken<PostgreSQLSyntax>;
#[allow(unused)]
pub type SyntaxToken = cstree::syntax::SyntaxToken<PostgreSQLSyntax>;
#[allow(unused)]
pub type SyntaxElement = cstree::util::NodeOrToken<SyntaxNode, SyntaxToken>;
pub type SyntaxElementRef<'a> = cstree::util::NodeOrToken<&'a SyntaxNode, &'a SyntaxToken>;
pub type NodeOrToken<'a> = cstree::util::NodeOrToken<&'a ResolvedNode, &'a ResolvedToken>;

/// Builds the cstree green tree from the intermediate `Node` tree, weaving the
/// collected whitespace/comment "extras" back in by byte position.
struct Parser {
    builder: GreenNodeBuilder<'static, 'static, PostgreSQLSyntax>,
}

impl Parser {
    fn parse_rec(
        &mut self,
        node: &Node,
        peekable: &mut std::iter::Peekable<std::vec::IntoIter<(SyntaxKind, usize, usize, &str)>>,
    ) {
        if cfg!(feature = "remove-empty-node") && node.start_byte_pos == node.end_byte_pos {
            return;
        }

        // Emit any extras (whitespace/comments) that start before this node.
        while let Some((kind, start, _, text)) = peekable.peek() {
            // TODO: Decide whether this comparison should be inclusive (`>=`) or strict (`>`);
            // the choice changes which node leading comments are attached to, so determine
            // which behavior is preferable.
            if *start >= node.start_byte_pos {
                // if *start > node.start_byte_pos {
                break;
            }
            self.builder.token(*kind, text);
            peekable.next();
        }

        let kind: SyntaxKind = SyntaxKind::from_raw(RawSyntaxKind(node.component_id));
        if let Some(token) = &node.token {
            self.builder.token(kind, &token.value);
        } else {
            self.builder.start_node(kind);
            node.children
                .iter()
                .for_each(|c| self.parse_rec(c, peekable));
            self.builder.finish_node();
        }
    }

    fn parse(
        mut self,
        nodes: &[&Node],
        extras: Vec<(SyntaxKind, usize, usize, &str)>,
    ) -> (GreenNode, impl Resolver) {
        let mut peekable = extras.into_iter().peekable();

        self.builder.start_node(SyntaxKind::Root);

        for node in nodes {
            self.parse_rec(node, &mut peekable);
        }

        // Emit any remaining extras (trailing whitespace/comments) after the last node.
        while let Some((kind, _, _, text)) = peekable.peek() {
            self.builder.token(*kind, text);
            peekable.next();
        }

        self.builder.finish_node();

        let (tree, cache) = self.builder.finish();
        (tree, cache.unwrap().into_interner().unwrap())
    }
}

/// Looks up the LR action for `state` on grammar component `cid` in the
/// compressed action tables. The check table guards against collisions
/// introduced by table compression, and `DEFAULT_ACTION_CODE` defers to the
/// state's default rule.
fn lookup_parser_action(state: u32, cid: u32) -> i16 {
    let state = state as usize;
    let cid = cid as usize;

    let i = ACTION_TABLE_INDEX[state] as usize;
    if ACTION_CHECK_TABLE[i + cid] == cid as i16 {
        if ACTION_TABLE[i + cid] == DEFAULT_ACTION_CODE {
            ACTION_DEF_RULE_TABLE[state]
        } else {
            ACTION_TABLE[i + cid]
        }
    } else {
        ERROR_ACTION_CODE
    }
}

/// Looks up the goto state for `state` on non-terminal `cid` in the compressed
/// goto tables, or returns `INVALID_GOTO_CODE` if there is no entry.
fn lookup_goto_state(state: u32, cid: u32) -> i16 {
    let state = state as usize;
    let cid = cid as usize;

    let i = GOTO_TABLE_INDEX[state] as usize;
    if GOTO_CHECK_TABLE[i + cid] == cid as i16 {
        GOTO_TABLE[i + cid]
    } else {
        INVALID_GOTO_CODE
    }
}

/// Parses a string as PostgreSQL syntax and returns the concrete syntax tree as a `ResolvedNode`.
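///
/// # Example
///
/// A minimal sketch (marked `ignore`, so it is not compiled as a doc-test). It
/// assumes that a plain `SELECT 1;` statement is accepted by the generated
/// grammar and relies on the root node always being `SyntaxKind::Root`, as set
/// up below.
///
/// ```ignore
/// let root = parse("SELECT 1;").unwrap();
/// assert_eq!(root.kind(), SyntaxKind::Root);
/// ```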
pub fn parse(input: &str) -> Result<ResolvedNode, ParserError> {
    let mut tokens = lex(input)?;

    if !tokens.is_empty() {
        init_tokens(&mut tokens);
    }

    // Append the $end token that the LR tables expect as the end-of-input marker.
    tokens.push(Token {
        kind: end_rule_kind(),
        value: "".to_string(),
        start_byte_pos: input.len(),
        end_byte_pos: input.len(),
    });

    let mut stack: Vec<(u32, Node)> = Vec::new();
    let mut tokens: std::iter::Peekable<std::vec::IntoIter<Token>> = tokens.into_iter().peekable();

    // State 0 with a dummy node forms the bottom of the LR stack.
    stack.push((
        0,
        Node {
            token: None,
            component_id: end_rule_id(),
            children: Vec::new(),
            start_byte_pos: 0,
            end_byte_pos: 0,
        },
    ));

    // `extras` collects whitespace and comments, with their byte ranges, so they
    // can be re-attached to the tree after parsing.
    let mut last_pos = 0;
    let mut extras: Vec<(SyntaxKind, usize, usize, &str)> = Vec::new();

    loop {
        let state = stack.last().unwrap().0;
        let token = match tokens.peek() {
            Some(token) => token,
            None => {
                return Err(ParserError::ParseError {
                    message: "unexpected end of input".to_string(),
                    start_byte_pos: input.len(),
                    end_byte_pos: input.len(),
                });
            }
        };

        let cid = token_kind_to_component_id(&token.kind);

        // Comments are not part of the grammar: record them (and any whitespace
        // preceding them) as extras and move on to the next token.
        if matches!(token.kind, TokenKind::C_COMMENT | TokenKind::SQL_COMMENT) {
            if last_pos < token.start_byte_pos {
                extras.push((
                    SyntaxKind::Whitespace,
                    last_pos,
                    token.start_byte_pos,
                    &input[last_pos..token.start_byte_pos],
                ));
            }

            last_pos = token.end_byte_pos;

            let kind = SyntaxKind::from_raw(RawSyntaxKind(cid));
            extras.push((
                kind,
                token.start_byte_pos,
                token.end_byte_pos,
                &input[token.start_byte_pos..token.end_byte_pos],
            ));
            tokens.next();

            continue;
        }

        // Decode the packed action value: positive means shift, negative means
        // reduce, and zero means accept.
        let action = match lookup_parser_action(state, cid) {
            ERROR_ACTION_CODE => Action::Error,
            v if v > 0 => Action::Shift((v - 1) as usize),
            v if v < 0 => Action::Reduce((-v - 1) as usize),
            _ => Action::Accept,
        };

        match action {
            Action::Shift(next_state) => {
                let node = Node {
                    token: Some(token.clone()),
                    component_id: cid,
                    children: Vec::new(),
                    start_byte_pos: token.start_byte_pos,
                    end_byte_pos: token.end_byte_pos,
                };

                // Record any whitespace between the previous token and this one.
                if last_pos < token.start_byte_pos {
                    extras.push((
                        SyntaxKind::Whitespace,
                        last_pos,
                        token.start_byte_pos,
                        &input[last_pos..token.start_byte_pos],
                    ));
                }

                last_pos = token.end_byte_pos;

                stack.push((next_state as u32, node));
                tokens.next();
            }
            Action::Reduce(rule_index) => {
                let rule = &RULES[rule_index];

                // Pop the rule's right-hand side off the stack and restore source order.
                let mut children = Vec::new();
                for _ in 0..rule.len {
                    children.push(stack.pop().unwrap().1);
                }
                children.reverse();

                let reduced_component_id = rule_name_to_component_id(rule.name);

                let start_byte_pos =
                    children
                        .first()
                        .map(|t| t.start_byte_pos)
                        .unwrap_or_else(|| {
                            // For an empty production, use whichever is later: the end of the
                            // last recorded extra or the end of the node on top of the stack.
                            extras
                                .last()
                                .map(|e| e.2)
                                .unwrap_or_default()
                                .max(stack.last().unwrap().1.end_byte_pos)
                        });

                let end_byte_pos = children
                    .last()
                    .map(|t| t.end_byte_pos)
                    .unwrap_or(start_byte_pos);

                let node = Node {
                    token: None,
                    component_id: reduced_component_id + num_terminal_symbol(),
                    children,
                    start_byte_pos,
                    end_byte_pos,
                };

                let next_state = stack.last().unwrap().0;
                let goto = lookup_goto_state(next_state, reduced_component_id);

                match goto {
                    next_state if next_state >= 0 => {
                        stack.push((next_state as u32, node));
                    }
                    _ => {
                        return Err(ParserError::ParseError {
                            message: format!(
                                "syntax error at byte position {}",
                                token.start_byte_pos
                            ),
                            start_byte_pos: token.start_byte_pos,
                            end_byte_pos: token.end_byte_pos,
                        });
                    }
                }
            }
            Action::Accept => {
                break;
            }
            Action::Error => {
                return Err(ParserError::ParseError {
                    message: format!(
                        "Action::Error: syntax error at byte position {}",
                        token.start_byte_pos
                    ),
                    start_byte_pos: token.start_byte_pos,
                    end_byte_pos: token.end_byte_pos,
                });
            }
        }
    }

    // After Accept, flush the remaining tokens (normally just the $end marker)
    // so that trailing whitespace is also recorded in `extras`.
    while let Some(token) = tokens.next() {
        if last_pos < token.start_byte_pos {
            extras.push((
                SyntaxKind::Whitespace,
                last_pos,
                token.start_byte_pos,
                &input[last_pos..token.start_byte_pos],
            ));
        }

        last_pos = token.end_byte_pos;

        // The last token is $end, so exit the loop here
        if tokens.peek().is_none() {
            break;
        }

        let cid = token_kind_to_component_id(&token.kind);
        let kind = SyntaxKind::from_raw(RawSyntaxKind(cid));
        extras.push((
            kind,
            token.start_byte_pos,
            token.end_byte_pos,
            &input[token.start_byte_pos..token.end_byte_pos],
        ));
    }

    // Skip the dummy bottom-of-stack node and build the cstree green tree.
    let parser = Parser {
        builder: GreenNodeBuilder::new(),
    };
    let root: Vec<&Node> = stack[1..].iter().map(|s| &s.1).collect();
    let (ast, resolver) = parser.parse(&root, extras);

    Ok(SyntaxNode::new_root_with_resolver(ast, resolver))
}
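
// A small smoke-test sketch: it assumes that a trivial single-statement query is
// accepted by the generated grammar and that a stray `)` is rejected. Both follow
// from the PostgreSQL grammar tables used above, but this module is illustrative,
// not exhaustive.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_a_simple_select() {
        // `Parser::parse` always starts the tree with `SyntaxKind::Root`.
        let root = parse("SELECT 1;").expect("a plain SELECT should parse");
        assert_eq!(root.kind(), SyntaxKind::Root);
    }

    #[test]
    fn rejects_invalid_input() {
        // A lone closing parenthesis cannot begin a statement, so the LR loop
        // reaches `Action::Error` and a `ParserError` is returned.
        assert!(parse(")").is_err());
    }
}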