faster_pest_derive/
lib.rs1use std::collections::HashMap;
2
3use pest_meta::optimizer::OptimizedRule;
4pub(crate) use pest_meta::{optimizer::OptimizedExpr, ast::RuleType};
5extern crate proc_macro;
6use proc_macro::TokenStream;
7
8mod ids;
9pub(crate) use ids::*;
10mod tree_inspection;
11pub(crate) use tree_inspection::*;
12mod expr_codegen;
13pub(crate) use expr_codegen::*;
14mod optimizer;
15pub(crate) use optimizer::*;
16
17use syn::*;
18use proc_macro2::TokenTree;
19
20fn list_grammar_files(attrs: &[Attribute]) -> Vec<String> {
21 attrs.iter().filter(|attr| attr.path.is_ident("grammar")).map(|a| {
22 let mut tokens = a.tokens.clone().into_iter();
23 match tokens.next() {
24 Some(TokenTree::Punct(punct)) if punct.as_char() == '=' => (),
25 _ => panic!("Expected leading '=' in grammar attribute"),
26 }
27 let path = match tokens.next() {
28 Some(TokenTree::Literal(value)) => value.to_string(),
29 _ => panic!("Expected literal in grammar attribute")
30 };
31 path.trim_matches('"').to_string()
32 }).collect()
33}
34
35fn get_all_rules(grammar_files: &[String]) -> Vec<OptimizedRule> {
36 let mut rules = HashMap::new();
37
38 for path in grammar_files {
39 let Ok(grammar) = std::fs::read_to_string(path) else {
40 panic!("Could not read grammar file at {path:?}");
41 };
42 let (_, new_rules) = match pest_meta::parse_and_optimize(&grammar) {
43 Ok(new_rules) => new_rules,
44 Err(e) => panic!("{}", e[0])
45 };
46 for new_rule in new_rules {
47 rules.insert(new_rule.name.clone(), new_rule);
48 }
49 }
50
51 rules.into_values().collect()
52}
53
54#[proc_macro_derive(Parser, attributes(grammar))]
55pub fn derive_parser(input: TokenStream) -> TokenStream {
56 let ast = parse_macro_input!(input as DeriveInput);
57 let struct_ident = ast.ident;
58
59 let grammar_files = list_grammar_files(&ast.attrs);
60 let rules = get_all_rules(&grammar_files);
61 let mut full_code = String::new();
62
63 let silent_rules = rules.iter().filter(|rule| matches!(rule.ty, RuleType::Silent)).map(|rule| rule.name.as_str()).collect::<Vec<_>>();
65
66 let has_whitespace = rules.iter().any(|rule| rule.name.as_str() == "WHITESPACE");
68
69 full_code.push_str("#[derive(Debug, Copy, Clone)]\n");
71 full_code.push_str("pub enum Ident<'i> {\n");
72 for rule in &rules {
73 let name = rule.name.as_str();
74 if !silent_rules.contains(&name) {
75 let name_pascal_case = name.chars().next().unwrap().to_uppercase().collect::<String>() + &name[1..];
76 full_code.push_str(&format!(" {name_pascal_case}(&'i str),\n"));
77 }
78 }
79 full_code.push_str("}\n");
80 full_code.push_str("impl<'i> IdentTrait for Ident<'i> {\n");
81 full_code.push_str(" type Rule = Rule;\n");
82 full_code.push_str(" \n");
83 full_code.push_str(" fn as_rule(&self) -> Rule {\n");
84 full_code.push_str(" match self {\n");
85 for rule in &rules {
86 let name = rule.name.as_str();
87 if !silent_rules.contains(&name) {
88 let name_pascal_case = name.chars().next().unwrap().to_uppercase().collect::<String>() + &name[1..];
89 full_code.push_str(&format!(" Ident::{name_pascal_case}(_) => Rule::{name},\n"));
90 }
91 }
92 full_code.push_str(" }\n");
93 full_code.push_str(" }\n");
94 full_code.push_str(" \n");
95 full_code.push_str(" fn as_str(&self) -> &str {\n");
96 full_code.push_str(" match self {\n");
97 for rule in &rules {
98 let name = rule.name.as_str();
99 if !silent_rules.contains(&name) {
100 let name_pascal_case = name.chars().next().unwrap().to_uppercase().collect::<String>() + &name[1..];
101 full_code.push_str(&format!(" Ident::{name_pascal_case}(s) => s,\n"));
102 }
103 }
104 full_code.push_str(" }\n");
105 full_code.push_str(" }\n");
106 full_code.push_str("}\n\n");
107
108 full_code.push_str("#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]\n");
110 full_code.push_str("pub enum Rule {\n");
111 for rule in &rules {
112 let name = rule.name.as_str();
113 if !silent_rules.contains(&name) {
114 full_code.push_str(&format!(" {name},\n"));
115 }
116 }
117 full_code.push_str("}\n\n");
118
119 full_code.push_str("#[automatically_derived]\n");
121 full_code.push_str(&format!("impl {struct_ident} {{\n"));
122 full_code.push_str(" pub fn parse(rule: Rule, input: &str) -> Result<Pairs2<Ident>, Error> {\n");
123 full_code.push_str(" let mut idents = Vec::with_capacity(500);\n"); full_code.push_str(" match rule {\n");
125 for rule in &rules {
126 let name = rule.name.as_str();
127 if !silent_rules.contains(&name) {
128 full_code.push_str(&format!(" Rule::{name} => {struct_ident}_faster_pest::parse_{name}(input.as_bytes(), &mut idents)?,\n"));
129 }
130 }
131 full_code.push_str(" };\n");
132 full_code.push_str(" Ok(unsafe {{ Pairs2::from_idents(idents, input) }})\n");
133 full_code.push_str(" }\n");
134 full_code.push_str("}\n\n");
135
136 full_code.push_str("\n\n#[automatically_derived]\n");
137 full_code.push_str("#[allow(clippy::all)]\n");
138 full_code.push_str(&format!("pub mod {struct_ident}_faster_pest {{\n"));
139 full_code.push_str(" use super::*;\n");
140
141 let mut ids = IdRegistry::new();
142 let mut optimized_exprs = Vec::new();
143 let mut exprs = Vec::new();
144 let mut character_set_rules = HashMap::new();
145 for rule in &rules {
146 let expr = optimize(&rule.expr);
147 if matches!(rule.ty, RuleType::Silent) {
148 if let FPestExpr::CharacterCondition(c) = &expr {
149 character_set_rules.insert(rule.name.as_str(), c.to_owned());
150 }
151 }
152 optimized_exprs.push(expr);
153 }
154 for expr in &mut optimized_exprs {
155 optimize_second_stage(expr, &character_set_rules);
156 }
157 println!("{:#?}", optimized_exprs);
158 for (i, rule) in rules.iter().enumerate() {
159 let expr = optimized_exprs.get(i).unwrap();
160 exprs.extend(list_exprs(expr));
161 let rule_name = rule.name.as_str();
162 let rule_name_pascal_case = rule_name.chars().next().unwrap().to_uppercase().collect::<String>() + &rule_name[1..];
163 let top_expr_id = ids.id(expr);
164 let formatted_idents = match contains_idents(expr, has_whitespace) {
165 true => "idents",
166 false => "",
167 };
168 match silent_rules.contains(&rule_name) {
169 false => full_code.push_str(&format!(r#"
170 #[automatically_derived]
171 #[allow(clippy::all)]
172 pub fn parse_{rule_name}<'i, 'b>(input: &'i [u8], idents: &'b mut Vec<(Ident<'i>, usize)>) -> Result<&'i [u8], Error> {{
173 let idents_len = idents.len();
174 if idents_len == idents.capacity() {{
175 idents.reserve(500);
176 }}
177 unsafe {{ idents.set_len(idents_len + 1); }}
178 let new_input = match parse_{top_expr_id}(input, {formatted_idents}) {{
179 Ok(input) => input,
180 Err(e) => {{
181 unsafe {{ idents.set_len(idents_len); }}
182 return Err(e);
183 }}
184 }};
185 let content = unsafe {{ std::str::from_utf8_unchecked(input.get_unchecked(..input.len() - new_input.len())) }};
186 unsafe {{ *idents.get_unchecked_mut(idents_len) = (Ident::{rule_name_pascal_case}(content), idents.len()); }}
187 Ok(new_input)
188 }}
189
190 #[automatically_derived]
191 #[allow(clippy::all)]
192 pub fn quick_parse_{rule_name}<'i, 'b>(input: &'i [u8], idents: &'b mut Vec<(Ident<'i>, usize)>) -> Option<&'i [u8]> {{
193 let idents_len = idents.len();
194 if idents_len == idents.capacity() {{
195 idents.reserve(500);
196 }}
197 unsafe {{ idents.set_len(idents_len + 1); }}
198 let new_input = match quick_parse_{top_expr_id}(input, {formatted_idents}) {{
199 Some(input) => input,
200 None => {{
201 unsafe {{ idents.set_len(idents_len); }}
202 return None;
203 }}
204 }};
205 let content = unsafe {{ std::str::from_utf8_unchecked(input.get_unchecked(..input.len() - new_input.len())) }};
206 unsafe {{ *idents.get_unchecked_mut(idents_len) = (Ident::{rule_name_pascal_case}(content), idents.len()); }}
207 Some(new_input)
208 }}
209 "#)
210 ),
211 true => full_code.push_str(&format!(r#"
212 #[automatically_derived]
213 #[allow(clippy::all)]
214 pub fn parse_{rule_name}<'i, 'b>(input: &'i [u8], idents: &'b mut Vec<(Ident<'i>, usize)>) -> Result<&'i [u8], Error> {{
215 parse_{top_expr_id}(input, {formatted_idents})
216 }}
217
218 #[automatically_derived]
219 #[allow(clippy::all)]
220 pub fn quick_parse_{rule_name}<'i, 'b>(input: &'i [u8], idents: &'b mut Vec<(Ident<'i>, usize)>) -> Option<&'i [u8]> {{
221 quick_parse_{top_expr_id}(input, {formatted_idents})
222 }}
223 "#)
224 ),
225 }
226
227 full_code.push_str(&format!(r#"
228 #[automatically_derived]
229 #[allow(clippy::all)]
230 impl {struct_ident} {{
231 pub fn parse_{rule_name}(input: &str) -> Result<IdentList<Ident>, Error> {{
232 let mut idents = Vec::with_capacity(500);
233 if quick_parse_{rule_name}(input.as_bytes(), &mut idents).is_some() {{
234 return Ok(unsafe {{ IdentList::from_idents(idents) }});
235 }}
236 idents.clear();
237 parse_{rule_name}(input.as_bytes(), &mut idents)?;
238 Ok(unsafe {{ IdentList::from_idents(idents) }})
239 }}
240 }}"#
241 ));
242 }
243 exprs.sort_by_key(|expr| ids.id(expr));
244 exprs.dedup();
245 for expr in exprs {
246 let mut new_code = code(expr, &mut ids, has_whitespace);
247 let mut new_code2 = new_code.trim_start_matches('\n');
248 let new_code2_len = new_code2.len();
249 new_code2 = new_code2.trim_start_matches(' ');
250 let len_diff = new_code2_len - new_code2.len();
251 let pattern = "\n".to_string() + &" ".repeat(len_diff);
252 new_code = new_code.replace(&pattern, "\n");
253 full_code.push_str(new_code.as_str());
254 }
255 full_code.push_str("}\n");
256 std::fs::write("target/fp_code.rs", &full_code).unwrap();
257
258 full_code.parse().unwrap()
259}