postgresql_cst_parser/cst.rs

use cstree::{
    build::GreenNodeBuilder, green::GreenNode, interning::Resolver, RawSyntaxKind, Syntax,
};
use miniz_oxide::inflate::decompress_to_vec;

use crate::{
    lexer::{lex, lexer_ported::init_tokens, parser_error::ParserError, TokenKind},
    parser::{
        end_rule_id, end_rule_kind, num_non_terminal_symbol, num_terminal_symbol,
        rule_name_to_component_id, token_kind_to_component_id, Action, ACTION_TABLE, GOTO_TABLE,
        RULES,
    },
};

use super::{lexer::Token, syntax_kind::SyntaxKind};

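/// Intermediate parse-tree node built on the LR parser stack.
/// A leaf wraps a single lexer `Token`; an interior node owns its children.
/// The byte positions record the span of source text the node covers.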
struct Node {
    token: Option<Token>,
    component_id: u32,
    children: Vec<Node>,
    start_byte_pos: usize,
    end_byte_pos: usize,
}

pub type PostgreSQLSyntax = SyntaxKind;

impl From<SyntaxKind> for cstree::RawSyntaxKind {
    fn from(kind: SyntaxKind) -> Self {
        Self(kind as u32)
    }
}

pub type SyntaxNode = cstree::syntax::SyntaxNode<PostgreSQLSyntax>;
pub type ResolvedNode = cstree::syntax::ResolvedNode<PostgreSQLSyntax>;
pub type ResolvedToken = cstree::syntax::ResolvedToken<PostgreSQLSyntax>;
#[allow(unused)]
pub type SyntaxToken = cstree::syntax::SyntaxToken<PostgreSQLSyntax>;
#[allow(unused)]
pub type SyntaxElement = cstree::util::NodeOrToken<SyntaxNode, SyntaxToken>;
pub type SyntaxElementRef<'a> = cstree::util::NodeOrToken<&'a SyntaxNode, &'a SyntaxToken>;
pub type NodeOrToken<'a> = cstree::util::NodeOrToken<&'a ResolvedNode, &'a ResolvedToken>;

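/// Lowers the parse tree produced by the LR driver into a cstree green tree,
/// re-attaching whitespace and comments ("extras") at the appropriate positions.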
struct Parser {
    builder: GreenNodeBuilder<'static, 'static, PostgreSQLSyntax>,
}

impl Parser {
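    /// Recursively emits `node` and its descendants into the green tree,
    /// first flushing any extras that start before the node's span.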
    fn parse_rec(
        &mut self,
        node: &Node,
        peekable: &mut std::iter::Peekable<std::vec::IntoIter<(SyntaxKind, usize, usize, &str)>>,
    ) {
        if cfg!(feature = "remove-empty-node") && node.start_byte_pos == node.end_byte_pos {
            return;
        }

        while let Some((kind, start, _, text)) = peekable.peek() {
            // TODO: Decide between `>=` and `>` here; the choice determines whether a comment that starts exactly at a node boundary is attached before the node or inside it.
            if *start >= node.start_byte_pos {
                // if *start > node.start_byte_pos {
                break;
            }
            self.builder.token(*kind, text);
            peekable.next();
        }

        let kind: SyntaxKind = SyntaxKind::from_raw(RawSyntaxKind(node.component_id));
        if let Some(token) = &node.token {
            self.builder.token(kind, &token.value);
        } else {
            self.builder.start_node(kind);
            node.children
                .iter()
                .for_each(|c| self.parse_rec(c, peekable));
            self.builder.finish_node();
        }
    }

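    /// Builds the `Root` green node from the top-level parse nodes, then
    /// appends any extras that remain after the last node.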
    fn parse(
        mut self,
        nodes: &Vec<&Node>,
        extras: Vec<(SyntaxKind, usize, usize, &str)>,
    ) -> (GreenNode, impl Resolver) {
        let mut peekable = extras.into_iter().peekable();

        self.builder.start_node(SyntaxKind::Root);

        for node in nodes {
            self.parse_rec(node, &mut peekable);
        }

        while let Some((kind, _, _, text)) = peekable.peek() {
            self.builder.token(*kind, text);
            peekable.next();
        }

        self.builder.finish_node();

        let (tree, cache) = self.builder.finish();
        (tree, cache.unwrap().into_interner().unwrap())
    }
}

/// Parses a string as PostgreSQL syntax and converts it into a [`ResolvedNode`].
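///
/// A minimal usage sketch (marked `ignore` because the exact re-export path of
/// `parse` is an assumption here, not taken from this file):
///
/// ```ignore
/// use postgresql_cst_parser::cst::parse;
///
/// let root = parse("SELECT 1;").expect("valid SQL should parse");
/// // The root of the lossless tree is SyntaxKind::Root; whitespace and
/// // comments are preserved as tokens underneath it.
/// for child in root.children() {
///     println!("{:?}", child.kind());
/// }
/// ```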
pub fn parse(input: &str) -> Result<ResolvedNode, ParserError> {
    let mut tokens = lex(input)?;

    if !tokens.is_empty() {
        init_tokens(&mut tokens);
    }

    tokens.push(Token {
        kind: end_rule_kind(),
        value: "".to_string(),
        start_byte_pos: input.len(),
        end_byte_pos: input.len(),
    });

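    // The LR action/goto tables are stored DEFLATE-compressed; decompress them
    // and reinterpret the raw byte buffers as i16 entries via `align_to`.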
    let action_table_u8 = decompress_to_vec(ACTION_TABLE.as_ref()).unwrap();
    let goto_table_u8 = decompress_to_vec(GOTO_TABLE.as_ref()).unwrap();
    let action_table = unsafe { action_table_u8.align_to::<i16>().1 };
    let goto_table = unsafe { goto_table_u8.align_to::<i16>().1 };

    let mut stack: Vec<(u32, Node)> = Vec::new();
    let mut tokens: std::iter::Peekable<std::vec::IntoIter<Token>> = tokens.into_iter().peekable();

    stack.push((
        0,
        Node {
            token: None,
            component_id: end_rule_id(),
            children: Vec::new(),
            start_byte_pos: 0,
            end_byte_pos: 0,
        },
    ));

    let mut last_pos = 0;
    let mut extras: Vec<(SyntaxKind, usize, usize, &str)> = Vec::new();

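    // LR driver loop: look up the action for (current state, next token) and
    // shift, reduce, or accept. Whitespace and comments are diverted into
    // `extras` instead of being fed to the grammar.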
    loop {
        let state = stack.last().unwrap().0;
        let token = match tokens.peek() {
            Some(token) => token,
            None => {
                return Err(ParserError::ParseError {
                    message: "unexpected end of input".to_string(),
                    start_byte_pos: input.len(),
                    end_byte_pos: input.len(),
                });
            }
        };

        let cid = token_kind_to_component_id(&token.kind);

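        // Comment tokens never reach the grammar: record any preceding
        // whitespace gap, stash the comment itself as an extra, and keep
        // scanning without touching the parser stack.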
        if matches!(token.kind, TokenKind::C_COMMENT | TokenKind::SQL_COMMENT) {
            if last_pos < token.start_byte_pos {
                extras.push((
                    SyntaxKind::Whitespace,
                    last_pos,
                    token.start_byte_pos,
                    &input[last_pos..token.start_byte_pos],
                ));
            }

            last_pos = token.end_byte_pos;

            let kind = SyntaxKind::from_raw(RawSyntaxKind(cid));
            extras.push((
                kind,
                token.start_byte_pos,
                token.end_byte_pos,
                &input[token.start_byte_pos..token.end_byte_pos],
            ));
            tokens.next();

            continue;
        }

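        // Action table encoding: 0x7FFF = error, a positive value v = shift to
        // state v - 1, a negative value v = reduce by rule -v - 1, 0 = accept.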
        let action = match action_table[(state * num_terminal_symbol() + cid) as usize] {
            0x7FFF => Action::Error,
            v if v > 0 => Action::Shift((v - 1) as usize),
            v if v < 0 => Action::Reduce((-v - 1) as usize),
            _ => Action::Accept,
        };

        match action {
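            // Shift: record any whitespace gap before the token as an extra,
            // wrap the token in a leaf node, and push it with the next state.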
            Action::Shift(next_state) => {
                let node = Node {
                    token: Some(token.clone()),
                    component_id: cid,
                    children: Vec::new(),
                    start_byte_pos: token.start_byte_pos,
                    end_byte_pos: token.end_byte_pos,
                };

                if last_pos < token.start_byte_pos {
                    extras.push((
                        SyntaxKind::Whitespace,
                        last_pos,
                        token.start_byte_pos,
                        &input[last_pos..token.start_byte_pos],
                    ));
                }

                last_pos = token.end_byte_pos;

                stack.push((next_state as u32, node));
                tokens.next();
            }
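            // Reduce: pop `rule.len` nodes off the stack, wrap them in a new
            // non-terminal node, and look up the goto state to push.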
            Action::Reduce(rule_index) => {
                let rule = &RULES[rule_index];

                let mut children = Vec::new();
                for _ in 0..rule.len {
                    children.push(stack.pop().unwrap().1);
                }
                children.reverse();

                let reduced_component_id = rule_name_to_component_id(rule.name);

                let start_byte_pos =
                    children
                        .first()
                        .map(|t| t.start_byte_pos)
                        .unwrap_or_else(|| {
                            // Use the larger of the end of the last extra (whitespace/comment) and the end of the previous node on the stack.
                            extras
                                .last()
                                .map(|e| e.2)
                                .unwrap_or_default()
                                .max(stack.last().unwrap().1.end_byte_pos)
                        });

                let end_byte_pos = children
                    .last()
                    .map(|t| t.end_byte_pos)
                    .unwrap_or(start_byte_pos);

                let node = Node {
                    token: None,
                    component_id: reduced_component_id + num_terminal_symbol(),
                    children,
                    start_byte_pos,
                    end_byte_pos,
                };

                let next_state = stack.last().unwrap().0;
                let goto = goto_table
                    [(next_state * num_non_terminal_symbol() + reduced_component_id) as usize];

                match goto {
                    next_state if next_state >= 0 => {
                        stack.push((next_state as u32, node));
                    }
                    _ => {
                        return Err(ParserError::ParseError {
                            message: format!(
                                "syntax error at byte position {}",
                                token.start_byte_pos
                            ),
                            start_byte_pos: token.start_byte_pos,
                            end_byte_pos: token.end_byte_pos,
                        });
                    }
                }
            }
            Action::Accept => {
                break;
            }
            Action::Error => {
                return Err(ParserError::ParseError {
                    message: format!(
                        "Action::Error: syntax error at byte position {}",
                        token.start_byte_pos
                    ),
                    start_byte_pos: token.start_byte_pos,
                    end_byte_pos: token.end_byte_pos,
                });
            }
        }
    }

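    // After Accept, sweep up any remaining whitespace and comment tokens so
    // the CST still covers the trailing part of the input.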
    while let Some(token) = tokens.next() {
        if last_pos < token.start_byte_pos {
            extras.push((
                SyntaxKind::Whitespace,
                last_pos,
                token.start_byte_pos,
                &input[last_pos..token.start_byte_pos],
            ));
        }

        last_pos = token.end_byte_pos;

        // The last token is $end, so stop here.
        if tokens.peek().is_none() {
            break;
        }

        let cid = token_kind_to_component_id(&token.kind);
        let kind = SyntaxKind::from_raw(RawSyntaxKind(cid));
        extras.push((
            kind,
            token.start_byte_pos,
            token.end_byte_pos,
            &input[token.start_byte_pos..token.end_byte_pos],
        ));
    }

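    // Lower everything above the initial dummy stack entry into a cstree green
    // tree, interleaving the collected extras, and wrap it with its resolver.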
    let parser = Parser {
        builder: GreenNodeBuilder::new(),
    };
    let root: Vec<&Node> = stack[1..].iter().map(|s| &s.1).collect();
    let (ast, resolver) = parser.parse(&root, extras);

    Ok(SyntaxNode::new_root_with_resolver(ast, resolver))
}