use cstree::{
    build::GreenNodeBuilder, green::GreenNode, interning::Resolver, RawSyntaxKind, Syntax,
};
use miniz_oxide::inflate::decompress_to_vec;

use crate::{
    lexer::{lex, lexer_ported::init_tokens, parser_error::ParserError, TokenKind},
    parser::{
        end_rule_id, end_rule_kind, num_non_terminal_symbol, num_terminal_symbol,
        rule_name_to_component_id, token_kind_to_component_id, Action, ACTION_TABLE, GOTO_TABLE,
        RULES,
    },
};

use super::{lexer::Token, syntax_kind::SyntaxKind};

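/// A node of the raw parse tree built by the LR parser. A leaf node carries the
/// shifted token; an interior node is produced by a reduction and owns its
/// children. The byte positions cover the node's span in the original input.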
struct Node {
    token: Option<Token>,
    component_id: u32,
    children: Vec<Node>,
    start_byte_pos: usize,
    end_byte_pos: usize,
}

pub type PostgreSQLSyntax = SyntaxKind;

impl From<SyntaxKind> for cstree::RawSyntaxKind {
    fn from(kind: SyntaxKind) -> Self {
        Self(kind as u32)
    }
}

pub type SyntaxNode = cstree::syntax::SyntaxNode<PostgreSQLSyntax>;
pub type ResolvedNode = cstree::syntax::ResolvedNode<PostgreSQLSyntax>;
pub type ResolvedToken = cstree::syntax::ResolvedToken<PostgreSQLSyntax>;
#[allow(unused)]
pub type SyntaxToken = cstree::syntax::SyntaxToken<PostgreSQLSyntax>;
#[allow(unused)]
pub type SyntaxElement = cstree::util::NodeOrToken<SyntaxNode, SyntaxToken>;
pub type SyntaxElementRef<'a> = cstree::util::NodeOrToken<&'a SyntaxNode, &'a SyntaxToken>;
pub type NodeOrToken<'a> = cstree::util::NodeOrToken<&'a ResolvedNode, &'a ResolvedToken>;

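/// Converts the parse-tree `Node`s into a cstree green tree, interleaving the
/// collected trivia ("extras") by byte position as it goes.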
struct Parser {
    builder: GreenNodeBuilder<'static, 'static, PostgreSQLSyntax>,
}

impl Parser {
    fn parse_rec(
        &mut self,
        node: &Node,
        peekable: &mut std::iter::Peekable<std::vec::IntoIter<(SyntaxKind, usize, usize, &str)>>,
    ) {
        if cfg!(feature = "remove-empty-node") && node.start_byte_pos == node.end_byte_pos {
            return;
        }

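        // Emit any trivia (whitespace and comment tokens) that starts before this node.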
        while let Some((kind, start, _, text)) = peekable.peek() {
            if *start >= node.start_byte_pos {
                break;
            }
            self.builder.token(*kind, text);
            peekable.next();
        }

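        // Leaf nodes become tokens; interior nodes recurse into their children.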
        let kind: SyntaxKind = SyntaxKind::from_raw(RawSyntaxKind(node.component_id));
        if let Some(token) = &node.token {
            self.builder.token(kind, &token.value);
        } else {
            self.builder.start_node(kind);
            node.children
                .iter()
                .for_each(|c| self.parse_rec(c, peekable));
            self.builder.finish_node();
        }
    }

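    /// Builds the green tree: wraps all top-level nodes in a `Root` node,
    /// appends any remaining trivia, and returns the tree with its interner.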
    fn parse(
        mut self,
        nodes: &Vec<&Node>,
        extras: Vec<(SyntaxKind, usize, usize, &str)>,
    ) -> (GreenNode, impl Resolver) {
        let mut peekable = extras.into_iter().peekable();

        self.builder.start_node(SyntaxKind::Root);

        for node in nodes {
            self.parse_rec(node, &mut peekable);
        }

        while let Some((kind, _, _, text)) = peekable.peek() {
            self.builder.token(*kind, text);
            peekable.next();
        }

        self.builder.finish_node();

        let (tree, cache) = self.builder.finish();
        (tree, cache.unwrap().into_interner().unwrap())
    }
}

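/// Parses `input` as PostgreSQL source text and returns the resolved concrete
/// syntax tree. Whitespace and comments are preserved as trivia tokens.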
pub fn parse(input: &str) -> Result<ResolvedNode, ParserError> {
    let mut tokens = lex(input)?;

    if !tokens.is_empty() {
        init_tokens(&mut tokens);
    }

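    // Append an end-of-input token so the parser can reach the Accept action.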
    tokens.push(Token {
        kind: end_rule_kind(),
        value: "".to_string(),
        start_byte_pos: input.len(),
        end_byte_pos: input.len(),
    });

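    // The LR action and goto tables are stored compressed; inflate them and
    // reinterpret the bytes as i16 entries.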
    let action_table_u8 = decompress_to_vec(ACTION_TABLE.as_ref()).unwrap();
    let goto_table_u8 = decompress_to_vec(GOTO_TABLE.as_ref()).unwrap();
    let action_table = unsafe { action_table_u8.align_to::<i16>().1 };
    let goto_table = unsafe { goto_table_u8.align_to::<i16>().1 };

    let mut stack: Vec<(u32, Node)> = Vec::new();
    let mut tokens: std::iter::Peekable<std::vec::IntoIter<Token>> = tokens.into_iter().peekable();

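    // Seed the parse stack with state 0 and a placeholder node.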
    stack.push((
        0,
        Node {
            token: None,
            component_id: end_rule_id(),
            children: Vec::new(),
            start_byte_pos: 0,
            end_byte_pos: 0,
        },
    ));

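    // `last_pos` tracks how far the input has been consumed; `extras` collects
    // whitespace and comment trivia with their byte ranges and text.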
    let mut last_pos = 0;
    let mut extras: Vec<(SyntaxKind, usize, usize, &str)> = Vec::new();

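    // Main LR driver: look up the action for (state, lookahead) and shift or
    // reduce until the accept action is reached.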
    loop {
        let state = stack.last().unwrap().0;
        let token = match tokens.peek() {
            Some(token) => token,
            None => {
                return Err(ParserError::ParseError {
                    message: "unexpected end of input".to_string(),
                    start_byte_pos: input.len(),
                    end_byte_pos: input.len(),
                });
            }
        };

        let cid = token_kind_to_component_id(&token.kind);

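        // Comments are not grammar terminals: record them (and any preceding
        // whitespace) as trivia and move on to the next token.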
        if matches!(token.kind, TokenKind::C_COMMENT | TokenKind::SQL_COMMENT) {
            if last_pos < token.start_byte_pos {
                extras.push((
                    SyntaxKind::Whitespace,
                    last_pos,
                    token.start_byte_pos,
                    &input[last_pos..token.start_byte_pos],
                ));
            }

            last_pos = token.end_byte_pos;

            let kind = SyntaxKind::from_raw(RawSyntaxKind(cid));
            extras.push((
                kind,
                token.start_byte_pos,
                token.end_byte_pos,
                &input[token.start_byte_pos..token.end_byte_pos],
            ));
            tokens.next();

            continue;
        }

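        // Decode the packed action entry: 0x7FFF is an error, a positive value
        // encodes a shift to state `v - 1`, a negative value a reduction by rule
        // `-v - 1`, and zero means accept.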
        let action = match action_table[(state * num_terminal_symbol() + cid) as usize] {
            0x7FFF => Action::Error,
            v if v > 0 => Action::Shift((v - 1) as usize),
            v if v < 0 => Action::Reduce((-v - 1) as usize),
            _ => Action::Accept,
        };

        match action {
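            // Shift: push the lookahead as a leaf node, recording any whitespace
            // before it as trivia.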
            Action::Shift(next_state) => {
                let node = Node {
                    token: Some(token.clone()),
                    component_id: cid,
                    children: Vec::new(),
                    start_byte_pos: token.start_byte_pos,
                    end_byte_pos: token.end_byte_pos,
                };

                if last_pos < token.start_byte_pos {
                    extras.push((
                        SyntaxKind::Whitespace,
                        last_pos,
                        token.start_byte_pos,
                        &input[last_pos..token.start_byte_pos],
                    ));
                }

                last_pos = token.end_byte_pos;

                stack.push((next_state as u32, node));
                tokens.next();
            }
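            // Reduce: pop `rule.len` nodes, wrap them in a node for the rule's
            // non-terminal, and push the goto state for that non-terminal.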
            Action::Reduce(rule_index) => {
                let rule = &RULES[rule_index];

                let mut children = Vec::new();
                for _ in 0..rule.len {
                    children.push(stack.pop().unwrap().1);
                }
                children.reverse();

                let reduced_component_id = rule_name_to_component_id(rule.name);

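                // An empty production has no children; position it after the most
                // recent trivia or the node below it on the stack, whichever ends later.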
                let start_byte_pos =
                    children
                        .first()
                        .map(|t| t.start_byte_pos)
                        .unwrap_or_else(|| {
                            extras
                                .last()
                                .map(|e| e.2)
                                .unwrap_or_default()
                                .max(stack.last().unwrap().1.end_byte_pos)
                        });

                let end_byte_pos = children
                    .last()
                    .map(|t| t.end_byte_pos)
                    .unwrap_or(start_byte_pos);

                let node = Node {
                    token: None,
                    component_id: reduced_component_id + num_terminal_symbol(),
                    children,
                    start_byte_pos,
                    end_byte_pos,
                };

                let next_state = stack.last().unwrap().0;
                let goto = goto_table
                    [(next_state * num_non_terminal_symbol() + reduced_component_id) as usize];

                match goto {
                    next_state if next_state >= 0 => {
                        stack.push((next_state as u32, node));
                    }
                    _ => {
                        return Err(ParserError::ParseError {
                            message: format!(
                                "syntax error at byte position {}",
                                token.start_byte_pos
                            ),
                            start_byte_pos: token.start_byte_pos,
                            end_byte_pos: token.end_byte_pos,
                        });
                    }
                }
            }
            Action::Accept => {
                break;
            }
            Action::Error => {
                return Err(ParserError::ParseError {
                    message: format!(
                        "Action::Error: syntax error at byte position {}",
                        token.start_byte_pos
                    ),
                    start_byte_pos: token.start_byte_pos,
                    end_byte_pos: token.end_byte_pos,
                });
            }
        }
    }

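    // After Accept the end-of-input token is still pending; record any trailing
    // whitespace (and any remaining tokens other than that final end token) as trivia.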
    while let Some(token) = tokens.next() {
        if last_pos < token.start_byte_pos {
            extras.push((
                SyntaxKind::Whitespace,
                last_pos,
                token.start_byte_pos,
                &input[last_pos..token.start_byte_pos],
            ));
        }

        last_pos = token.end_byte_pos;

        if tokens.peek().is_none() {
            break;
        }

        let cid = token_kind_to_component_id(&token.kind);
        let kind = SyntaxKind::from_raw(RawSyntaxKind(cid));
        extras.push((
            kind,
            token.start_byte_pos,
            token.end_byte_pos,
            &input[token.start_byte_pos..token.end_byte_pos],
        ));
    }

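    // Build the cstree green tree from the nodes left on the stack (skipping the
    // placeholder at the bottom) and wrap it with its resolver.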
    let parser = Parser {
        builder: GreenNodeBuilder::new(),
    };
    let root: Vec<&Node> = stack[1..].iter().map(|s| &s.1).collect();
    let (ast, resolver) = parser.parse(&root, extras);

    Ok(SyntaxNode::new_root_with_resolver(ast, resolver))
}