mathml_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4use syn::parse::{Parse, ParseStream};
5use syn::{parse_macro_input, Ident, Result, Token, Type};
6
7mod kw {
8    syn::custom_keyword!(to);
9    syn::custom_keyword!(with);
10}
11
12#[proc_macro]
13pub fn attach(input: TokenStream) -> TokenStream {
14    // parse input
15    let input = parse_macro_input!(input as OpenInput);
16
17    // create new node
18    let instantiation_expr;
19    let pass_object_expr;
20    let tag_str;
21    let tag_type;
22    match input.tag {
23        Tag::Ident(tag) => {
24            instantiation_expr = quote! {
25                let mut new_node = #tag::default();
26            };
27            pass_object_expr = quote! {
28                new_tag = Some(MathNode::#tag(new_node));
29            };
30
31            tag_str = tag.to_string();
32            tag_type = tag;
33        }
34        Tag::Enum(enum_name, enum_type) => {
35            let enum_str = enum_name.to_string().to_lowercase();
36            let fn_str = format!("new_{}", enum_str);
37            let fn_ident = Ident::new(&fn_str, Span::call_site());
38
39            instantiation_expr = quote! {};
40
41            pass_object_expr = quote! {
42                new_tag = Some(MathNode::#fn_ident(#enum_name::#enum_type));
43            };
44
45            tag_str = format!("{}::{}", enum_name.to_string(), enum_type.to_string());
46            tag_type = enum_name;
47        }
48    }
49
50    // attributes field names and types
51    let attr_idents = input.attr_idents;
52    let attr_types = input.attr_types;
53    // also need strings for matching tokens
54    let mut attr_str = Vec::new();
55    for ident in &attr_idents {
56        let mut ident_str = ident.to_string();
57        if ident_str.starts_with("r#") {
58            ident_str = ident_str.trim_start_matches("r#").to_string()
59        }
60        attr_str.push(ident_str);
61    }
62
63    // create code to parse attributes
64    let store_attr = quote! {
65        // parse any attributes, keeping their types in mind
66        let attributes = e.attributes().map(|a| a.unwrap()).collect::<Vec<_>>();
67        //println!("{:?}", attributes);
68        for attribute in attributes {
69            let key = std::str::from_utf8(attribute.key).unwrap();
70            let value = attribute.unescape_and_decode_value(&reader).unwrap();
71            match key {
72                #(#attr_str => {
73                    new_node.#attr_idents =
74                        Some(value.parse::<#attr_types>().expect("Incorrect type"));
75                })*
76                _ => {
77                    //println!("{:?}", #attr_str);
78                    panic!("Attribute {} not parsed for {}", key, #tag_str);
79                }
80            }
81        }
82    };
83
84    let index_expr = quote! {
85        parent.index(MathNodeType::#tag_type, current.clone());
86    };
87
88    let parents = &input.parents;
89    // create strings for debugging
90    let mut parent_strs: Vec<String> = Vec::new();
91    let mut index_exprs: Vec<proc_macro2::TokenStream> = Vec::new();
92    // TODO: Convert this to trait
93    let parents_to_index: Vec<String> = vec!["Apply", "Lambda", "Piecewise", "Piece", "Otherwise"]
94        .iter()
95        .map(|&a| a.into())
96        .collect();
97    for parent in parents {
98        let parent_str = parent.to_string();
99        if parents_to_index.contains(&parent_str) {
100            index_exprs.push(index_expr.clone());
101        } else {
102            index_exprs.push(quote! {})
103        }
104        parent_strs.push(parent_str);
105    }
106
107    let tokens = quote! {
108        {
109            // create new object
110            #instantiation_expr
111            #store_attr
112            // match the current tag
113            match container[current] {
114                // with the parent
115                #(MathNode::#parents (ref mut parent) => {
116                    #pass_object_expr
117                    // update current pointer (which is really an int)
118                    current = container_len;
119                    // update parent pointer of new tag
120                    parent.children.push(current.clone());
121                    #index_exprs
122                    // push current pointer to stack
123                    stack.push(current.clone());
124                    //println!("Opened {}", #tag_str);
125                })*
126                _ => {
127                    panic!("Tag {:?} not parsed under parent {:?}", #tag_str, container[current]);
128                }
129            }
130        }
131    };
132    tokens.into()
133}
134
135#[derive(Debug)]
136struct OpenInput {
137    tag: Tag,
138    parents: Vec<Ident>,
139    attr_idents: Vec<Ident>,
140    attr_types: Vec<Type>,
141}
142
143#[derive(Debug)]
144enum Tag {
145    Ident(Ident),
146    Enum(Ident, Ident),
147}
148
149impl Parse for OpenInput {
150    fn parse(input: ParseStream) -> Result<Self> {
151        // parse tag
152        let tag_ident = syn::Ident::parse(input)?;
153        let tag: Tag;
154        // see if this is an enum
155        let mut lookahead = input.lookahead1();
156        if lookahead.peek(Token![:]) {
157            input.parse::<Token![:]>()?;
158            input.parse::<Token![:]>()?;
159
160            let enum_type = syn::Ident::parse(input)?;
161            tag = Tag::Enum(tag_ident, enum_type);
162        } else {
163            tag = Tag::Ident(tag_ident);
164        }
165
166        // define fields used later
167        let mut attr_idents = Vec::new();
168        let mut attr_types = Vec::new();
169        // define lookahead function
170        lookahead = input.lookahead1();
171
172        // if attributes are specified
173        if lookahead.peek(kw::with) {
174            let _with = input.parse::<kw::with>()?;
175
176            // loop over attributes and types
177            loop {
178                // parse attribute field name as ident
179                let ident = syn::Ident::parse(input)?;
180                attr_idents.push(ident);
181                let _as = input.parse::<Token![as]>();
182                // parse attribute type
183                let ty = syn::Type::parse(input)?;
184                attr_types.push(ty);
185
186                // consume comma if it exists
187                if input.peek(Token![,]) {
188                    input.parse::<Token![,]>()?;
189                }
190
191                // break if found into
192                // lookahead works only once
193                lookahead = input.lookahead1();
194                if lookahead.peek(kw::to) {
195                    break;
196                }
197            }
198        }
199        let _to = input.parse::<kw::to>()?;
200
201        // parse parent
202        let mut parents = vec![syn::Ident::parse(input)?];
203
204        // see if there are multiple parents
205        loop {
206            lookahead = input.lookahead1();
207            if lookahead.peek(Token![|]) {
208                input.parse::<Token![|]>()?;
209            } else {
210                break;
211            }
212
213            lookahead = input.lookahead1();
214            if lookahead.peek(Ident) {
215                parents.push(syn::Ident::parse(input)?);
216            }
217        }
218
219        //println!("Parents: {:?}", parents);
220
221        Ok(OpenInput {
222            tag,
223            parents,
224            attr_idents,
225            attr_types,
226        })
227    }
228}
229
230#[proc_macro]
231pub fn close(input: TokenStream) -> TokenStream {
232    let input = parse_macro_input!(input as CloseInput);
233    //println!("{:?}", input);
234
235    let tag = &input.tag;
236    let tag_str = input.tag.to_string();
237
238    let tokens = quote! {
239        match container[current] {
240            MathNode::#tag (ref mut tag_field) => {
241                stack.pop();
242                current = stack.last().unwrap().to_owned();
243                tag_field.parent = Some(current.clone());
244                //println!("Closing {}", #tag_str);
245            }
246            _ => {
247                //println!("{:#?}", container);
248                panic!("Trying to close {} but currently in {:?}", #tag_str, container[current]);
249            }
250        }
251    };
252    tokens.into()
253}
254
255#[derive(Debug)]
256struct CloseInput {
257    tag: Ident,
258}
259
260impl Parse for CloseInput {
261    fn parse(input: ParseStream) -> Result<Self> {
262        let tag = syn::Ident::parse(input)?;
263        Ok(CloseInput { tag })
264    }
265}
#[cfg(test)]
mod tests {
    //! Parser tests for the macro input grammars, driven through `syn::parse_str`.
    use super::*;
    use quote::ToTokens;

    /// Renders a parsed type back to its token text, e.g. `u32`.
    fn type_text(ty: &syn::Type) -> String {
        ty.to_token_stream().to_string()
    }

    /// Collects identifiers as strings for easy comparison.
    fn names(idents: &[syn::Ident]) -> Vec<String> {
        idents.iter().map(|ident| ident.to_string()).collect()
    }

    #[test]
    fn parses_plain_tag_with_single_parent() {
        let parsed: OpenInput = syn::parse_str("Ci to Apply").unwrap();
        assert!(matches!(parsed.tag, Tag::Ident(ref tag) if tag.to_string() == "Ci"));
        assert_eq!(names(&parsed.parents), ["Apply"]);
        assert!(parsed.attr_idents.is_empty());
        assert!(parsed.attr_types.is_empty());
    }

    #[test]
    fn parses_enum_tag() {
        let parsed: OpenInput = syn::parse_str("Operator::Plus to Apply").unwrap();
        match &parsed.tag {
            Tag::Enum(name, variant) => {
                assert_eq!(name.to_string(), "Operator");
                assert_eq!(variant.to_string(), "Plus");
            }
            other => panic!("expected an enum tag, got {:?}", other),
        }
        assert_eq!(names(&parsed.parents), ["Apply"]);
    }

    #[test]
    fn parses_typed_attributes() {
        let parsed: OpenInput =
            syn::parse_str("Ci with r#type as String, base as u32 to Apply").unwrap();
        assert_eq!(names(&parsed.attr_idents), ["r#type", "base"]);
        let types: Vec<String> = parsed.attr_types.iter().map(type_text).collect();
        assert_eq!(types, ["String", "u32"]);
    }

    #[test]
    fn as_keyword_is_optional() {
        let parsed: OpenInput = syn::parse_str("Cn with base u32 to Apply").unwrap();
        assert_eq!(names(&parsed.attr_idents), ["base"]);
        let types: Vec<String> = parsed.attr_types.iter().map(type_text).collect();
        assert_eq!(types, ["u32"]);
    }

    #[test]
    fn parses_multiple_parents() {
        let parsed: OpenInput = syn::parse_str("Ci to Apply | Lambda | Piece").unwrap();
        assert_eq!(names(&parsed.parents), ["Apply", "Lambda", "Piece"]);
    }

    #[test]
    fn rejects_missing_to() {
        assert!(syn::parse_str::<OpenInput>("Ci Apply").is_err());
    }

    #[test]
    fn parses_close_input() {
        let parsed: CloseInput = syn::parse_str("Apply").unwrap();
        assert_eq!(parsed.tag.to_string(), "Apply");
        assert!(syn::parse_str::<CloseInput>("Apply Lambda").is_err());
    }
}