faster_pest_derive/
lib.rs

1use std::collections::HashMap;
2
3use pest_meta::optimizer::OptimizedRule;
4pub(crate) use pest_meta::{optimizer::OptimizedExpr, ast::RuleType};
5extern crate proc_macro;
6use proc_macro::TokenStream;
7
8mod ids;
9pub(crate) use ids::*;
10mod tree_inspection;
11pub(crate) use tree_inspection::*;
12mod expr_codegen;
13pub(crate) use expr_codegen::*;
14mod optimizer;
15pub(crate) use optimizer::*;
16
17use syn::*;
18use proc_macro2::TokenTree;
19
20fn list_grammar_files(attrs: &[Attribute]) -> Vec<String> {
21    attrs.iter().filter(|attr| attr.path.is_ident("grammar")).map(|a| {
22        let mut tokens = a.tokens.clone().into_iter();
23        match tokens.next() {
24            Some(TokenTree::Punct(punct)) if punct.as_char() == '=' => (),
25            _ => panic!("Expected leading '=' in grammar attribute"),
26        }
27        let path = match tokens.next() {
28            Some(TokenTree::Literal(value)) => value.to_string(),
29            _ => panic!("Expected literal in grammar attribute")
30        };
31        path.trim_matches('"').to_string()
32    }).collect()
33}
34
35fn get_all_rules(grammar_files: &[String]) -> Vec<OptimizedRule> {
36    let mut rules = HashMap::new();
37
38    for path in grammar_files {
39        let Ok(grammar) = std::fs::read_to_string(path) else {
40            panic!("Could not read grammar file at {path:?}");
41        };
42        let (_, new_rules) = match pest_meta::parse_and_optimize(&grammar) {
43            Ok(new_rules) => new_rules,
44            Err(e) => panic!("{}", e[0])
45        };
46        for new_rule in new_rules {
47            rules.insert(new_rule.name.clone(), new_rule);
48        }
49    }
50
51    rules.into_values().collect()
52}
53
54#[proc_macro_derive(Parser, attributes(grammar))]
55pub fn derive_parser(input: TokenStream) -> TokenStream {
56    let ast = parse_macro_input!(input as DeriveInput);
57    let struct_ident = ast.ident;
58
59    let grammar_files = list_grammar_files(&ast.attrs);
60    let rules = get_all_rules(&grammar_files);
61    let mut full_code = String::new();
62
63    // Find silent rules
64    let silent_rules = rules.iter().filter(|rule| matches!(rule.ty, RuleType::Silent)).map(|rule| rule.name.as_str()).collect::<Vec<_>>();
65
66    // Find if there is a rule named WHITESPACE
67    let has_whitespace = rules.iter().any(|rule| rule.name.as_str() == "WHITESPACE");
68
69    // Create Ident enum
70    full_code.push_str("#[derive(Debug, Copy, Clone)]\n");
71    full_code.push_str("pub enum Ident<'i> {\n");
72    for rule in &rules {
73        let name = rule.name.as_str();
74        if !silent_rules.contains(&name) {
75            let name_pascal_case = name.chars().next().unwrap().to_uppercase().collect::<String>() + &name[1..];
76            full_code.push_str(&format!("    {name_pascal_case}(&'i str),\n"));
77        }
78    }
79    full_code.push_str("}\n");
80    full_code.push_str("impl<'i> IdentTrait for Ident<'i> {\n");
81    full_code.push_str("    type Rule = Rule;\n");
82    full_code.push_str("    \n");
83    full_code.push_str("    fn as_rule(&self) -> Rule {\n");
84    full_code.push_str("        match self {\n");
85    for rule in &rules {
86        let name = rule.name.as_str();
87        if !silent_rules.contains(&name) {
88            let name_pascal_case = name.chars().next().unwrap().to_uppercase().collect::<String>() + &name[1..];
89            full_code.push_str(&format!("            Ident::{name_pascal_case}(_) => Rule::{name},\n"));
90        }
91    }
92    full_code.push_str("        }\n");
93    full_code.push_str("    }\n");
94    full_code.push_str("    \n");
95    full_code.push_str("    fn as_str(&self) -> &str {\n");
96    full_code.push_str("        match self {\n");
97    for rule in &rules {
98        let name = rule.name.as_str();
99        if !silent_rules.contains(&name) {
100            let name_pascal_case = name.chars().next().unwrap().to_uppercase().collect::<String>() + &name[1..];
101            full_code.push_str(&format!("            Ident::{name_pascal_case}(s) => s,\n"));
102        }
103    }
104    full_code.push_str("        }\n");
105    full_code.push_str("    }\n");
106    full_code.push_str("}\n\n");
107
108    // Create Rule enum
109    full_code.push_str("#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]\n");
110    full_code.push_str("pub enum Rule {\n");
111    for rule in &rules {
112        let name = rule.name.as_str();
113        if !silent_rules.contains(&name) {
114            full_code.push_str(&format!("    {name},\n"));
115        }
116    }
117    full_code.push_str("}\n\n");
118
119    // Create parse method TODO name
120    full_code.push_str("#[automatically_derived]\n");
121    full_code.push_str(&format!("impl {struct_ident} {{\n"));
122    full_code.push_str("    pub fn parse(rule: Rule, input: &str) -> Result<Pairs2<Ident>, Error> {\n");
123    full_code.push_str("        let mut idents = Vec::with_capacity(500);\n"); // TODO: refine 500
124    full_code.push_str("        match rule {\n");
125    for rule in &rules {
126        let name = rule.name.as_str();
127        if !silent_rules.contains(&name) {
128            full_code.push_str(&format!("            Rule::{name} => {struct_ident}_faster_pest::parse_{name}(input.as_bytes(), &mut idents)?,\n"));
129        }
130    }
131    full_code.push_str("        };\n");
132    full_code.push_str("        Ok(unsafe {{ Pairs2::from_idents(idents, input) }})\n");
133    full_code.push_str("    }\n");
134    full_code.push_str("}\n\n");
135
136    full_code.push_str("\n\n#[automatically_derived]\n");
137    full_code.push_str("#[allow(clippy::all)]\n");
138    full_code.push_str(&format!("pub mod {struct_ident}_faster_pest {{\n"));
139    full_code.push_str("    use super::*;\n");
140
141    let mut ids = IdRegistry::new();
142    let mut optimized_exprs = Vec::new();
143    let mut exprs = Vec::new();
144    let mut character_set_rules = HashMap::new();
145    for rule in &rules {
146        let expr = optimize(&rule.expr);
147        if matches!(rule.ty, RuleType::Silent) {
148            if let FPestExpr::CharacterCondition(c) = &expr {
149                character_set_rules.insert(rule.name.as_str(), c.to_owned());
150            }
151        }
152        optimized_exprs.push(expr);
153    }
154    for expr in &mut optimized_exprs {
155        optimize_second_stage(expr, &character_set_rules);
156    }
157    println!("{:#?}", optimized_exprs);
158    for (i, rule) in rules.iter().enumerate() {
159        let expr = optimized_exprs.get(i).unwrap();
160        exprs.extend(list_exprs(expr));
161        let rule_name = rule.name.as_str();
162        let rule_name_pascal_case = rule_name.chars().next().unwrap().to_uppercase().collect::<String>() + &rule_name[1..];
163        let top_expr_id = ids.id(expr);
164        let formatted_idents = match contains_idents(expr, has_whitespace) {
165            true => "idents",
166            false => "",
167        };
168        match silent_rules.contains(&rule_name) {
169            false => full_code.push_str(&format!(r#"
170                #[automatically_derived]
171                #[allow(clippy::all)]
172                pub fn parse_{rule_name}<'i, 'b>(input: &'i [u8], idents: &'b mut Vec<(Ident<'i>, usize)>) -> Result<&'i [u8], Error> {{
173                    let idents_len = idents.len();
174                    if idents_len == idents.capacity() {{
175                        idents.reserve(500);
176                    }}
177                    unsafe {{ idents.set_len(idents_len + 1); }}
178                    let new_input = match parse_{top_expr_id}(input, {formatted_idents}) {{
179                        Ok(input) => input,
180                        Err(e) => {{
181                            unsafe {{ idents.set_len(idents_len); }}
182                            return Err(e);
183                        }}
184                    }};
185                    let content = unsafe {{ std::str::from_utf8_unchecked(input.get_unchecked(..input.len() - new_input.len())) }};
186                    unsafe {{ *idents.get_unchecked_mut(idents_len) = (Ident::{rule_name_pascal_case}(content), idents.len()); }}
187                    Ok(new_input)
188                }}
189
190                #[automatically_derived]
191                #[allow(clippy::all)]
192                pub fn quick_parse_{rule_name}<'i, 'b>(input: &'i [u8], idents: &'b mut Vec<(Ident<'i>, usize)>) -> Option<&'i [u8]> {{
193                    let idents_len = idents.len();
194                    if idents_len == idents.capacity() {{
195                        idents.reserve(500);
196                    }}
197                    unsafe {{ idents.set_len(idents_len + 1); }}
198                    let new_input = match quick_parse_{top_expr_id}(input, {formatted_idents}) {{
199                        Some(input) => input,
200                        None => {{
201                            unsafe {{ idents.set_len(idents_len); }}
202                            return None;
203                        }}
204                    }};
205                    let content = unsafe {{ std::str::from_utf8_unchecked(input.get_unchecked(..input.len() - new_input.len())) }};
206                    unsafe {{ *idents.get_unchecked_mut(idents_len) = (Ident::{rule_name_pascal_case}(content), idents.len()); }}
207                    Some(new_input)
208                }}
209                "#)
210            ),
211            true => full_code.push_str(&format!(r#"
212                #[automatically_derived]
213                #[allow(clippy::all)]
214                pub fn parse_{rule_name}<'i, 'b>(input: &'i [u8], idents: &'b mut Vec<(Ident<'i>, usize)>) -> Result<&'i [u8], Error> {{
215                    parse_{top_expr_id}(input, {formatted_idents})
216                }}
217
218                #[automatically_derived]
219                #[allow(clippy::all)]
220                pub fn quick_parse_{rule_name}<'i, 'b>(input: &'i [u8], idents: &'b mut Vec<(Ident<'i>, usize)>) -> Option<&'i [u8]> {{
221                    quick_parse_{top_expr_id}(input, {formatted_idents})
222                }}
223                "#)
224            ),
225        }
226
227        full_code.push_str(&format!(r#"
228            #[automatically_derived]
229            #[allow(clippy::all)]
230            impl {struct_ident} {{
231                pub fn parse_{rule_name}(input: &str) -> Result<IdentList<Ident>, Error> {{
232                    let mut idents = Vec::with_capacity(500);
233                    if quick_parse_{rule_name}(input.as_bytes(), &mut idents).is_some() {{
234                        return Ok(unsafe {{ IdentList::from_idents(idents) }});
235                    }}
236                    idents.clear();
237                    parse_{rule_name}(input.as_bytes(), &mut idents)?;
238                    Ok(unsafe {{ IdentList::from_idents(idents) }})
239                }}
240            }}"#
241        ));
242    }
243    exprs.sort_by_key(|expr| ids.id(expr));
244    exprs.dedup();
245    for expr in exprs {
246        let mut new_code = code(expr, &mut ids, has_whitespace);
247        let mut new_code2 = new_code.trim_start_matches('\n');
248        let new_code2_len = new_code2.len();
249        new_code2 = new_code2.trim_start_matches(' ');
250        let len_diff = new_code2_len - new_code2.len();
251        let pattern = "\n".to_string() + &" ".repeat(len_diff);
252        new_code = new_code.replace(&pattern, "\n");
253        full_code.push_str(new_code.as_str());
254    }
255    full_code.push_str("}\n");
256    std::fs::write("target/fp_code.rs", &full_code).unwrap();
257
258    full_code.parse().unwrap()
259}