Skip to main content

python_instruction_dsl_proc/
lib.rs

1extern crate proc_macro;
2use heck::ToUpperCamelCase;
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{
6    Expr, Ident, LitInt, Token, bracketed, parenthesized, parse::Parse, parse_macro_input,
7    token::Paren,
8};
9
10#[derive(Clone)]
11enum StackItem {
12    Name(Ident),
13    NameCounted(Ident, Expr),
14    /// Amount of unused (and unnamed) stack items
15    Unused(Expr),
16}
17
18#[derive(Clone)]
19struct StackEffect {
20    pops: Vec<StackItem>,
21    pushes: Vec<StackItem>,
22}
23
24#[derive(Clone)]
25struct Opcode {
26    name: Ident,
27    number: LitInt,
28    stack_effect: Option<StackEffect>,
29}
30
31impl Parse for Opcode {
32    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
33        // Example: LOAD_CONST = 100 ( -- constant)
34        // You may also use ( / ) to indicate this opcode has no stack description.
35        let name: Ident = input.parse()?;
36        input.parse::<Token![=]>()?;
37        let number: LitInt = input.parse()?;
38
39        let inner_stack_effect;
40
41        parenthesized!(inner_stack_effect in input);
42
43        let mut stack_effect = StackEffect {
44            pops: vec![],
45            pushes: vec![],
46        };
47
48        if inner_stack_effect.parse::<Token![/]>().is_ok() {
49            // This opcode does not have a stack description
50            return Ok(Opcode {
51                name,
52                number,
53                stack_effect: None,
54            });
55        }
56
57        // Pops
58        while inner_stack_effect.peek(Ident) {
59            let name: Ident = inner_stack_effect.parse()?;
60
61            stack_effect.pops.push(
62                // This name is special, see the reference document on GitHub.
63                if name == "unused" {
64                    if inner_stack_effect.peek(syn::token::Bracket) {
65                        let inner_bracket;
66                        bracketed!(inner_bracket in inner_stack_effect);
67                        let size: Expr = inner_bracket.parse()?;
68                        StackItem::Unused(size)
69                    } else {
70                        StackItem::Unused(Expr::Lit(syn::ExprLit {
71                            attrs: vec![],
72                            lit: syn::Lit::Int(LitInt::new(
73                                "1",
74                                proc_macro::Span::call_site().into(),
75                            )),
76                        }))
77                    }
78                } else {
79                    if inner_stack_effect.peek(syn::token::Bracket) {
80                        let inner_bracket;
81                        bracketed!(inner_bracket in inner_stack_effect);
82                        let size: Expr = inner_bracket.parse()?;
83                        StackItem::NameCounted(name, size)
84                    } else {
85                        StackItem::Name(name)
86                    }
87                },
88            );
89
90            if inner_stack_effect.parse::<Token![,]>().is_err() {
91                break;
92            }
93        }
94
95        inner_stack_effect.parse::<Token![-]>()?;
96        inner_stack_effect.parse::<Token![-]>()?;
97
98        while inner_stack_effect.peek(Ident) {
99            let name: Ident = inner_stack_effect.parse()?;
100
101            stack_effect.pushes.push(
102                // This name is special, see the reference document on GitHub.
103                if name == "unused" {
104                    if inner_stack_effect.peek(syn::token::Bracket) {
105                        let inner_bracket;
106                        bracketed!(inner_bracket in inner_stack_effect);
107                        let size: Expr = inner_bracket.parse()?;
108                        StackItem::Unused(size)
109                    } else {
110                        StackItem::Unused(Expr::Lit(syn::ExprLit {
111                            attrs: vec![],
112                            lit: syn::Lit::Int(LitInt::new(
113                                "1",
114                                proc_macro::Span::call_site().into(),
115                            )),
116                        }))
117                    }
118                } else {
119                    if inner_stack_effect.peek(syn::token::Bracket) {
120                        let inner_bracket;
121                        bracketed!(inner_bracket in inner_stack_effect);
122                        let size: Expr = inner_bracket.parse()?;
123                        StackItem::NameCounted(name, size)
124                    } else {
125                        StackItem::Name(name)
126                    }
127                },
128            );
129
130            if inner_stack_effect.parse::<Token![,]>().is_err() {
131                break;
132            }
133        }
134
135        Ok(Opcode {
136            name,
137            number,
138            stack_effect: Some(stack_effect),
139        })
140    }
141}
142
143struct Opcodes {
144    opcodes: Vec<Opcode>,
145}
146
147impl Parse for Opcodes {
148    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
149        let mut opcodes = vec![];
150
151        loop {
152            opcodes.push(Opcode::parse(input)?);
153
154            if input.parse::<Token![,]>().is_err() || input.is_empty() {
155                break;
156            }
157        }
158
159        Ok(Opcodes { opcodes })
160    }
161}
162
163fn sum_items(items: &[StackItem]) -> Expr {
164    if items.is_empty() {
165        // 0 if empty
166        Expr::Lit(syn::ExprLit {
167            attrs: vec![],
168            lit: syn::Lit::Int(LitInt::new("0", proc_macro::Span::call_site().into())),
169        })
170    } else {
171        items
172            .iter()
173            .map(|p| match p {
174                StackItem::Name(_) => Expr::Lit(syn::ExprLit {
175                    attrs: vec![],
176                    lit: syn::Lit::Int(LitInt::new("1", proc_macro::Span::call_site().into())),
177                }),
178                StackItem::NameCounted(_, size) => size.clone(),
179                StackItem::Unused(size) => size.clone(),
180            })
181            .reduce(|left, right| {
182                syn::Expr::Binary(syn::ExprBinary {
183                    attrs: vec![],
184                    left: Box::new(left),
185                    op: syn::BinOp::Add(syn::token::Plus {
186                        spans: [proc_macro::Span::call_site().into()],
187                    }),
188                    right: Box::new(right),
189                })
190            })
191            .expect("Something is wrong with the format")
192    }
193}
194
195#[proc_macro]
196pub fn define_opcodes(input: TokenStream) -> TokenStream {
197    let Opcodes { opcodes } = parse_macro_input!(input as Opcodes);
198
199    let opcodes_with_stack: Vec<_> = opcodes
200        .iter()
201        .filter(|o| o.stack_effect.is_some())
202        .collect();
203
204    let names: Vec<_> = opcodes.iter().map(|o| &o.name).collect();
205    let camel_names: Vec<Ident> = names
206        .iter()
207        .map(|ident| {
208            let camel = ident.to_string().to_upper_camel_case();
209            Ident::new(&camel, ident.span())
210        })
211        .collect();
212
213    let names_with_stack: Vec<_> = opcodes_with_stack.iter().map(|o| &o.name).collect();
214
215    let numbers: Vec<_> = opcodes.iter().map(|o| &o.number).collect();
216
217    let pops: Vec<_> = opcodes_with_stack
218        .iter()
219        .map(|o| sum_items(&o.stack_effect.as_ref().unwrap().pops))
220        .collect();
221
222    let pushes: Vec<_> = opcodes_with_stack
223        .iter()
224        .map(|o| sum_items(&o.stack_effect.as_ref().unwrap().pushes))
225        .collect();
226
227    let mut expanded = quote! {
228        #[allow(non_camel_case_types)]
229        #[allow(clippy::upper_case_acronyms)]
230        #[derive(Debug, Clone, PartialEq, Eq)]
231        pub enum Opcode {
232            #( #names ),*,
233            INVALID_OPCODE(u8),
234        }
235
236        impl From<u8> for Opcode {
237            fn from(value: u8) -> Self {
238                match value {
239                    #( #numbers => Opcode::#names, )*
240                    _ => Opcode::INVALID_OPCODE(value),
241                }
242            }
243        }
244
245        impl From<Opcode> for u8 {
246            fn from(value: Opcode) -> Self {
247                match value {
248                    #( Opcode::#names => #numbers , )*
249                    Opcode::INVALID_OPCODE(value) => value,
250                }
251            }
252        }
253
254        impl From<(Opcode, u8)> for Instruction {
255            fn from(value: (Opcode, u8)) -> Self {
256                match value.0 {
257                    #(
258                        Opcode::#names => Instruction::#camel_names(value.1),
259                    )*
260                    Opcode::INVALID_OPCODE(opcode) => {
261                        if !cfg!(test) {
262                            Instruction::InvalidOpcode((opcode, value.1))
263                        } else {
264                            panic!("Testing environment should not come across invalid opcodes")
265                        }
266                    },
267                }
268            }
269        }
270
271        impl Opcode {
272            pub fn from_instruction(instruction: &Instruction) -> Self {
273                match instruction {
274                    #(
275                        Instruction::#camel_names(_) => Opcode::#names ,
276                    )*
277                    Instruction::InvalidOpcode((opcode, _)) => Opcode::INVALID_OPCODE(*opcode),
278                }
279            }
280        }
281
282        impl StackEffectTrait for Opcode {
283            fn stack_effect(&self, oparg: u32, jump: bool, calculate_max: bool) -> StackEffect {
284                match &self {
285                    #(
286                        Opcode::#names_with_stack => StackEffect { pops: #pops, pushes: #pushes },
287                    )*
288                    Opcode::INVALID_OPCODE(_) => StackEffect { pops: 0, pushes: 0 },
289
290                    _ => unimplemented!("stack_effect not implemented for {:?}", self),
291                }
292            }
293        }
294    };
295
296    expanded.into()
297}