Skip to main content

python_instruction_dsl_proc/
lib.rs

1extern crate proc_macro;
2use heck::ToUpperCamelCase;
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{
6    Expr, Ident, LitInt, Token, bracketed, parenthesized, parse::Parse, parse_macro_input,
7    spanned::Spanned,
8};
9
10#[derive(Clone)]
11enum StackItem {
12    Name(Ident),
13    NameCounted(Ident, Expr),
14    /// Amount of unused (and unnamed) stack items
15    Unused(Expr),
16}
17
18#[derive(Clone)]
19struct StackEffect {
20    pops: Vec<StackItem>,
21    pushes: Vec<StackItem>,
22}
23
24#[derive(Clone)]
25struct Opcode {
26    name: Ident,
27    number: LitInt,
28    stack_effect: Option<StackEffect>,
29}
30
31impl Parse for Opcode {
32    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
33        // Example: LOAD_CONST = 100 ( -- constant)
34        // You may also use ( / ) to indicate this opcode has no stack description.
35        let name: Ident = input.parse()?;
36        input.parse::<Token![=]>()?;
37        let number: LitInt = input.parse()?;
38
39        let inner_stack_effect;
40
41        parenthesized!(inner_stack_effect in input);
42
43        let mut stack_effect = StackEffect {
44            pops: vec![],
45            pushes: vec![],
46        };
47
48        if inner_stack_effect.parse::<Token![/]>().is_ok() {
49            // This opcode does not have a stack description
50            return Ok(Opcode {
51                name,
52                number,
53                stack_effect: None,
54            });
55        }
56
57        // Pops
58        while inner_stack_effect.peek(Ident) {
59            let name: Ident = inner_stack_effect.parse()?;
60
61            stack_effect.pops.push(
62                // This name is special, see the reference document on GitHub.
63                if name == "unused" {
64                    if inner_stack_effect.peek(syn::token::Bracket) {
65                        let inner_bracket;
66                        bracketed!(inner_bracket in inner_stack_effect);
67                        let size: Expr = inner_bracket.parse()?;
68                        StackItem::Unused(size)
69                    } else {
70                        StackItem::Unused(Expr::Lit(syn::ExprLit {
71                            attrs: vec![],
72                            lit: syn::Lit::Int(LitInt::new(
73                                "1",
74                                proc_macro::Span::call_site().into(),
75                            )),
76                        }))
77                    }
78                } else {
79                    if inner_stack_effect.peek(syn::token::Bracket) {
80                        let inner_bracket;
81                        bracketed!(inner_bracket in inner_stack_effect);
82                        let size: Expr = inner_bracket.parse()?;
83                        StackItem::NameCounted(name, size)
84                    } else {
85                        StackItem::Name(name)
86                    }
87                },
88            );
89
90            if inner_stack_effect.parse::<Token![,]>().is_err() {
91                break;
92            }
93        }
94
95        inner_stack_effect.parse::<Token![-]>()?;
96        inner_stack_effect.parse::<Token![-]>()?;
97
98        while inner_stack_effect.peek(Ident) {
99            let name: Ident = inner_stack_effect.parse()?;
100
101            stack_effect.pushes.push(
102                // This name is special, see the reference document on GitHub.
103                if name == "unused" {
104                    if inner_stack_effect.peek(syn::token::Bracket) {
105                        let inner_bracket;
106                        bracketed!(inner_bracket in inner_stack_effect);
107                        let size: Expr = inner_bracket.parse()?;
108                        StackItem::Unused(size)
109                    } else {
110                        StackItem::Unused(Expr::Lit(syn::ExprLit {
111                            attrs: vec![],
112                            lit: syn::Lit::Int(LitInt::new(
113                                "1",
114                                proc_macro::Span::call_site().into(),
115                            )),
116                        }))
117                    }
118                } else {
119                    if inner_stack_effect.peek(syn::token::Bracket) {
120                        let inner_bracket;
121                        bracketed!(inner_bracket in inner_stack_effect);
122                        let size: Expr = inner_bracket.parse()?;
123                        StackItem::NameCounted(name, size)
124                    } else {
125                        StackItem::Name(name)
126                    }
127                },
128            );
129
130            if inner_stack_effect.parse::<Token![,]>().is_err() {
131                break;
132            }
133        }
134
135        Ok(Opcode {
136            name,
137            number,
138            stack_effect: Some(stack_effect),
139        })
140    }
141}
142
143struct Opcodes {
144    opcodes: Vec<Opcode>,
145}
146
147impl Parse for Opcodes {
148    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
149        let mut opcodes = vec![];
150
151        loop {
152            opcodes.push(Opcode::parse(input)?);
153
154            if input.parse::<Token![,]>().is_err() || input.is_empty() {
155                break;
156            }
157        }
158
159        Ok(Opcodes { opcodes })
160    }
161}
162
163fn sum_items(items: &[StackItem]) -> Expr {
164    if items.is_empty() {
165        // 0 if empty
166        Expr::Lit(syn::ExprLit {
167            attrs: vec![],
168            lit: syn::Lit::Int(LitInt::new("0", proc_macro::Span::call_site().into())),
169        })
170    } else {
171        items
172            .iter()
173            .map(|p| match p {
174                StackItem::Name(_) => Expr::Lit(syn::ExprLit {
175                    attrs: vec![],
176                    lit: syn::Lit::Int(LitInt::new("1", proc_macro::Span::call_site().into())),
177                }),
178                StackItem::NameCounted(_, size) => size.clone(),
179                StackItem::Unused(size) => size.clone(),
180            })
181            .reduce(|left, right| {
182                syn::Expr::Binary(syn::ExprBinary {
183                    attrs: vec![],
184                    left: Box::new(left),
185                    op: syn::BinOp::Add(syn::token::Plus {
186                        spans: [proc_macro::Span::call_site().into()],
187                    }),
188                    right: Box::new(right),
189                })
190            })
191            .expect("Something is wrong with the format")
192    }
193}
194
195#[proc_macro]
196pub fn define_opcodes(input: TokenStream) -> TokenStream {
197    let Opcodes { opcodes } = parse_macro_input!(input as Opcodes);
198
199    let opcodes_with_stack: Vec<_> = opcodes
200        .iter()
201        .filter(|o| o.stack_effect.is_some())
202        .collect();
203
204    let names: Vec<_> = opcodes.iter().map(|o| &o.name).collect();
205    let camel_names: Vec<Ident> = names
206        .iter()
207        .map(|ident| {
208            let camel = ident.to_string().to_upper_camel_case();
209            Ident::new(&camel, ident.span())
210        })
211        .collect();
212
213    let names_with_stack: Vec<_> = opcodes_with_stack.iter().map(|o| &o.name).collect();
214
215    let numbers: Vec<_> = opcodes.iter().map(|o| &o.number).collect();
216
217    let pops: Vec<_> = opcodes_with_stack
218        .iter()
219        .map(|o| sum_items(&o.stack_effect.as_ref().unwrap().pops))
220        .collect();
221
222    let pushes: Vec<_> = opcodes_with_stack
223        .iter()
224        .map(|o| sum_items(&o.stack_effect.as_ref().unwrap().pushes))
225        .collect();
226
227    let mut expanded = quote! {
228        #[allow(non_camel_case_types)]
229        #[allow(clippy::upper_case_acronyms)]
230        #[derive(Debug, Clone, PartialEq, Eq)]
231        pub enum Opcode {
232            #( #names ),*,
233            INVALID_OPCODE(u8),
234        }
235
236        impl From<u8> for Opcode {
237            fn from(value: u8) -> Self {
238                match value {
239                    #( #numbers => Opcode::#names, )*
240                    _ => Opcode::INVALID_OPCODE(value),
241                }
242            }
243        }
244
245        impl From<Opcode> for u8 {
246            fn from(value: Opcode) -> Self {
247                match value {
248                    #( Opcode::#names => #numbers , )*
249                    Opcode::INVALID_OPCODE(value) => value,
250                }
251            }
252        }
253
254        impl From<(Opcode, u8)> for Instruction {
255            fn from(value: (Opcode, u8)) -> Self {
256                match value.0 {
257                    #(
258                        Opcode::#names => Instruction::#camel_names(value.1),
259                    )*
260                    Opcode::INVALID_OPCODE(opcode) => {
261                        if !cfg!(test) {
262                            Instruction::InvalidOpcode((opcode, value.1))
263                        } else {
264                            panic!("Testing environment should not come across invalid opcodes")
265                        }
266                    },
267                }
268            }
269        }
270
271        impl Opcode {
272            pub fn from_instruction(instruction: &Instruction) -> Self {
273                match instruction {
274                    #(
275                        Instruction::#camel_names(_) => Opcode::#names ,
276                    )*
277                    Instruction::InvalidOpcode((opcode, _)) => Opcode::INVALID_OPCODE(*opcode),
278                }
279            }
280        }
281
282        impl StackEffectTrait for Opcode {
283            fn stack_effect(&self, oparg: u32, jump: bool, calculate_max: bool) -> StackEffect {
284                match &self {
285                    #(
286                        Opcode::#names_with_stack => StackEffect { pops: #pops, pushes: #pushes },
287                    )*
288                    Opcode::INVALID_OPCODE(_) => StackEffect { pops: 0, pushes: 0 },
289
290                    _ => unimplemented!("stack_effect not implemented for {:?}", self),
291                }
292            }
293        }
294    };
295
296    let mut input_sirs = vec![];
297    let mut output_sirs = vec![];
298
299    for (opcode, name) in opcodes.iter().zip(names) {
300        let mut input_constructor_fields = vec![];
301        let mut output_constructor_fields = vec![];
302
303        if let Some(stack_effect) = &opcode.stack_effect {
304            let mut index = quote! { 0 };
305            for pop in stack_effect.pops.iter().rev() {
306                match pop {
307                    StackItem::Name(name) => {
308                        let name = name.to_string();
309                        input_constructor_fields
310                            .push(quote! { StackItem { name: #name, count: 1, index: #index } });
311                        index = quote! { (#index) + 1 };
312                    }
313                    StackItem::NameCounted(name, count) => {
314                        let name = name.to_string();
315                        input_constructor_fields.push(
316                            quote! { StackItem { name: #name, count: #count, index: #index } },
317                        );
318                        index = quote! { (#index) + #count };
319                    }
320                    StackItem::Unused(count) => {
321                        index = quote! { (#index) + #count };
322                    }
323                }
324            }
325
326            input_constructor_fields.reverse();
327
328            let mut index = quote! { 0 };
329            for push in stack_effect.pushes.iter().rev() {
330                match push {
331                    StackItem::Name(name) => {
332                        let name = name.to_string();
333                        output_constructor_fields
334                            .push(quote! { StackItem { name: #name, count: 1, index: #index } });
335                        index = quote! { (#index) + 1 };
336                    }
337                    StackItem::NameCounted(name, count) => {
338                        let name = name.to_string();
339                        output_constructor_fields.push(
340                            quote! { StackItem { name: #name, count: #count, index: #index } },
341                        );
342                        index = quote! { (#index) + #count };
343                    }
344                    StackItem::Unused(count) => {
345                        index = quote! { (#index) + #count };
346                    }
347                }
348            }
349        }
350
351        input_sirs.push(quote! { Opcode::#name => vec![
352            #(
353                #input_constructor_fields
354            ),*
355        ] });
356
357        output_sirs.push(quote! { Opcode::#name => vec![
358            #(
359                #output_constructor_fields
360            ),*
361        ] });
362    }
363
364    let sir = quote! {
365        pub mod sir {
366            use super::{Opcode};
367            use crate::sir::{SIR, StackItem, SIRStatement, Call, SIRExpression, AuxVar};
368            use crate::traits::{GenericSIRNode, SIROwned};
369
370
371            #[derive(PartialEq, Debug, Clone)]
372            pub struct SIRNode {
373                pub opcode: Opcode,
374                pub oparg: u32,
375                pub input: Vec<StackItem>,
376                pub output: Vec<StackItem>,
377            }
378
379            impl SIRNode {
380                pub fn new(opcode: Opcode, oparg: u32, jump: bool) -> Self {
381                    // This comes from the Python DSL where it is used to calculate the max stack usage possible. We intentionally disable it here.
382                    let calculate_max = false;
383
384                    let input = match opcode {
385                        #(
386                            #input_sirs
387                        ),*,
388                        Opcode::INVALID_OPCODE(_) => vec![],
389                    };
390
391                    let output = match opcode {
392                        #(
393                            #output_sirs
394                        ),*,
395                        Opcode::INVALID_OPCODE(_) => vec![],
396                    };
397
398                    Self {
399                        opcode,
400                        oparg,
401                        input,
402                        output,
403                    }
404                }
405            }
406
407            impl GenericSIRNode for SIRNode {
408                type Opcode = Opcode;
409
410                fn new(opcode: Self::Opcode, oparg: u32, jump: bool) -> Self {
411                    SIRNode::new(opcode, oparg, jump)
412                }
413
414                fn get_outputs(&self) -> &[StackItem] {
415                    &self.output
416                }
417
418                fn get_inputs(&self) -> &[StackItem] {
419                    &self.input
420                }
421            }
422
423            impl SIROwned<SIRNode> for SIR<SIRNode> {
424                fn new(statements: Vec<SIRStatement<SIRNode>>) -> Self {
425                    SIR(statements)
426                }
427            }
428
429            impl std::fmt::Display for SIR<SIRNode> {
430                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
431                    for statement in &self.0 {
432                        match statement {
433                            SIRStatement::Assignment(aux_var, call) => {
434                                writeln!(f, "{} = {}", aux_var.name, call)?;
435                            }
436                            SIRStatement::TupleAssignment(aux_vars, call) => {
437                                let vars = aux_vars.iter().map(|v| v.name.clone()).collect::<Vec<_>>().join(", ");
438                                writeln!(f, "({}) = {}", vars, call)?;
439                            }
440                            SIRStatement::DisregardCall(call) => {
441                                writeln!(f, "{}", call)?;
442                            }
443                        }
444                    }
445                    Ok(())
446                }
447            }
448
449            impl std::fmt::Display for Call<SIRNode> {
450                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
451                    let mut inputs = self
452                        .stack_inputs
453                        .iter()
454                        .map(|input| format!("{}", input))
455                        .collect::<Vec<_>>();
456
457                    inputs.push(format!("{}", self.node.oparg));
458
459                    write!(f, "{:#?}({})", self.node.opcode, inputs.join(", "))
460                }
461            }
462
463            impl std::fmt::Display for SIRExpression<SIRNode> {
464                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
465                    match self {
466                        SIRExpression::Call(call) => write!(f, "{}", call),
467                        SIRExpression::AuxVar(aux_var) => write!(f, "{}", aux_var.name.clone()),
468                        SIRExpression::PhiNode(phi) => write!(f, "phi({})", phi.iter().map(|v| &v.name).cloned().collect::<Vec<_>>().join(", ")),
469                    }
470                }
471            }
472        }
473    };
474
475    expanded.extend(sir);
476
477    expanded.into()
478}