python_instruction_dsl_proc/
lib.rs1extern crate proc_macro;
2use heck::ToUpperCamelCase;
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{
6 Expr, Ident, LitInt, Token, bracketed, parenthesized, parse::Parse, parse_macro_input,
7 spanned::Spanned,
8};
9
10#[derive(Clone)]
11enum StackItem {
12 Name(Ident),
13 NameCounted(Ident, Expr),
14 Unused(Expr),
16}
17
18#[derive(Clone)]
19struct StackEffect {
20 pops: Vec<StackItem>,
21 pushes: Vec<StackItem>,
22}
23
24#[derive(Clone)]
25struct Opcode {
26 name: Ident,
27 number: LitInt,
28 stack_effect: Option<StackEffect>,
29}
30
31impl Parse for Opcode {
32 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
33 let name: Ident = input.parse()?;
36 input.parse::<Token![=]>()?;
37 let number: LitInt = input.parse()?;
38
39 let inner_stack_effect;
40
41 parenthesized!(inner_stack_effect in input);
42
43 let mut stack_effect = StackEffect {
44 pops: vec![],
45 pushes: vec![],
46 };
47
48 if inner_stack_effect.parse::<Token![/]>().is_ok() {
49 return Ok(Opcode {
51 name,
52 number,
53 stack_effect: None,
54 });
55 }
56
57 while inner_stack_effect.peek(Ident) {
59 let name: Ident = inner_stack_effect.parse()?;
60
61 stack_effect.pops.push(
62 if name == "unused" {
64 if inner_stack_effect.peek(syn::token::Bracket) {
65 let inner_bracket;
66 bracketed!(inner_bracket in inner_stack_effect);
67 let size: Expr = inner_bracket.parse()?;
68 StackItem::Unused(size)
69 } else {
70 StackItem::Unused(Expr::Lit(syn::ExprLit {
71 attrs: vec![],
72 lit: syn::Lit::Int(LitInt::new(
73 "1",
74 proc_macro::Span::call_site().into(),
75 )),
76 }))
77 }
78 } else {
79 if inner_stack_effect.peek(syn::token::Bracket) {
80 let inner_bracket;
81 bracketed!(inner_bracket in inner_stack_effect);
82 let size: Expr = inner_bracket.parse()?;
83 StackItem::NameCounted(name, size)
84 } else {
85 StackItem::Name(name)
86 }
87 },
88 );
89
90 if inner_stack_effect.parse::<Token![,]>().is_err() {
91 break;
92 }
93 }
94
95 inner_stack_effect.parse::<Token![-]>()?;
96 inner_stack_effect.parse::<Token![-]>()?;
97
98 while inner_stack_effect.peek(Ident) {
99 let name: Ident = inner_stack_effect.parse()?;
100
101 stack_effect.pushes.push(
102 if name == "unused" {
104 if inner_stack_effect.peek(syn::token::Bracket) {
105 let inner_bracket;
106 bracketed!(inner_bracket in inner_stack_effect);
107 let size: Expr = inner_bracket.parse()?;
108 StackItem::Unused(size)
109 } else {
110 StackItem::Unused(Expr::Lit(syn::ExprLit {
111 attrs: vec![],
112 lit: syn::Lit::Int(LitInt::new(
113 "1",
114 proc_macro::Span::call_site().into(),
115 )),
116 }))
117 }
118 } else {
119 if inner_stack_effect.peek(syn::token::Bracket) {
120 let inner_bracket;
121 bracketed!(inner_bracket in inner_stack_effect);
122 let size: Expr = inner_bracket.parse()?;
123 StackItem::NameCounted(name, size)
124 } else {
125 StackItem::Name(name)
126 }
127 },
128 );
129
130 if inner_stack_effect.parse::<Token![,]>().is_err() {
131 break;
132 }
133 }
134
135 Ok(Opcode {
136 name,
137 number,
138 stack_effect: Some(stack_effect),
139 })
140 }
141}
142
143struct Opcodes {
144 opcodes: Vec<Opcode>,
145}
146
147impl Parse for Opcodes {
148 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
149 let mut opcodes = vec![];
150
151 loop {
152 opcodes.push(Opcode::parse(input)?);
153
154 if input.parse::<Token![,]>().is_err() || input.is_empty() {
155 break;
156 }
157 }
158
159 Ok(Opcodes { opcodes })
160 }
161}
162
163fn sum_items(items: &[StackItem]) -> Expr {
164 if items.is_empty() {
165 Expr::Lit(syn::ExprLit {
167 attrs: vec![],
168 lit: syn::Lit::Int(LitInt::new("0", proc_macro::Span::call_site().into())),
169 })
170 } else {
171 items
172 .iter()
173 .map(|p| match p {
174 StackItem::Name(_) => Expr::Lit(syn::ExprLit {
175 attrs: vec![],
176 lit: syn::Lit::Int(LitInt::new("1", proc_macro::Span::call_site().into())),
177 }),
178 StackItem::NameCounted(_, size) => size.clone(),
179 StackItem::Unused(size) => size.clone(),
180 })
181 .reduce(|left, right| {
182 syn::Expr::Binary(syn::ExprBinary {
183 attrs: vec![],
184 left: Box::new(left),
185 op: syn::BinOp::Add(syn::token::Plus {
186 spans: [proc_macro::Span::call_site().into()],
187 }),
188 right: Box::new(right),
189 })
190 })
191 .expect("Something is wrong with the format")
192 }
193}
194
195#[proc_macro]
196pub fn define_opcodes(input: TokenStream) -> TokenStream {
197 let Opcodes { opcodes } = parse_macro_input!(input as Opcodes);
198
199 let opcodes_with_stack: Vec<_> = opcodes
200 .iter()
201 .filter(|o| o.stack_effect.is_some())
202 .collect();
203
204 let names: Vec<_> = opcodes.iter().map(|o| &o.name).collect();
205 let camel_names: Vec<Ident> = names
206 .iter()
207 .map(|ident| {
208 let camel = ident.to_string().to_upper_camel_case();
209 Ident::new(&camel, ident.span())
210 })
211 .collect();
212
213 let names_with_stack: Vec<_> = opcodes_with_stack.iter().map(|o| &o.name).collect();
214
215 let numbers: Vec<_> = opcodes.iter().map(|o| &o.number).collect();
216
217 let pops: Vec<_> = opcodes_with_stack
218 .iter()
219 .map(|o| sum_items(&o.stack_effect.as_ref().unwrap().pops))
220 .collect();
221
222 let pushes: Vec<_> = opcodes_with_stack
223 .iter()
224 .map(|o| sum_items(&o.stack_effect.as_ref().unwrap().pushes))
225 .collect();
226
227 let mut expanded = quote! {
228 #[allow(non_camel_case_types)]
229 #[allow(clippy::upper_case_acronyms)]
230 #[derive(Debug, Clone, PartialEq, Eq)]
231 pub enum Opcode {
232 #( #names ),*,
233 INVALID_OPCODE(u8),
234 }
235
236 impl From<u8> for Opcode {
237 fn from(value: u8) -> Self {
238 match value {
239 #( #numbers => Opcode::#names, )*
240 _ => Opcode::INVALID_OPCODE(value),
241 }
242 }
243 }
244
245 impl From<Opcode> for u8 {
246 fn from(value: Opcode) -> Self {
247 match value {
248 #( Opcode::#names => #numbers , )*
249 Opcode::INVALID_OPCODE(value) => value,
250 }
251 }
252 }
253
254 impl From<(Opcode, u8)> for Instruction {
255 fn from(value: (Opcode, u8)) -> Self {
256 match value.0 {
257 #(
258 Opcode::#names => Instruction::#camel_names(value.1),
259 )*
260 Opcode::INVALID_OPCODE(opcode) => {
261 if !cfg!(test) {
262 Instruction::InvalidOpcode((opcode, value.1))
263 } else {
264 panic!("Testing environment should not come across invalid opcodes")
265 }
266 },
267 }
268 }
269 }
270
271 impl Opcode {
272 pub fn from_instruction(instruction: &Instruction) -> Self {
273 match instruction {
274 #(
275 Instruction::#camel_names(_) => Opcode::#names ,
276 )*
277 Instruction::InvalidOpcode((opcode, _)) => Opcode::INVALID_OPCODE(*opcode),
278 }
279 }
280 }
281
282 impl StackEffectTrait for Opcode {
283 fn stack_effect(&self, oparg: u32, jump: bool, calculate_max: bool) -> StackEffect {
284 match &self {
285 #(
286 Opcode::#names_with_stack => StackEffect { pops: #pops, pushes: #pushes },
287 )*
288 Opcode::INVALID_OPCODE(_) => StackEffect { pops: 0, pushes: 0 },
289
290 _ => unimplemented!("stack_effect not implemented for {:?}", self),
291 }
292 }
293 }
294 };
295
296 let mut input_sirs = vec![];
297 let mut output_sirs = vec![];
298
299 for (opcode, name) in opcodes.iter().zip(names) {
300 let mut input_constructor_fields = vec![];
301 let mut output_constructor_fields = vec![];
302
303 if let Some(stack_effect) = &opcode.stack_effect {
304 let mut index = quote! { 0 };
305 for pop in stack_effect.pops.iter().rev() {
306 match pop {
307 StackItem::Name(name) => {
308 let name = name.to_string();
309 input_constructor_fields
310 .push(quote! { StackItem { name: #name, count: 1, index: #index } });
311 index = quote! { (#index) + 1 };
312 }
313 StackItem::NameCounted(name, count) => {
314 let name = name.to_string();
315 input_constructor_fields.push(
316 quote! { StackItem { name: #name, count: #count, index: #index } },
317 );
318 index = quote! { (#index) + #count };
319 }
320 StackItem::Unused(count) => {
321 index = quote! { (#index) + #count };
322 }
323 }
324 }
325
326 input_constructor_fields.reverse();
327
328 let mut index = quote! { 0 };
329 for push in stack_effect.pushes.iter().rev() {
330 match push {
331 StackItem::Name(name) => {
332 let name = name.to_string();
333 output_constructor_fields
334 .push(quote! { StackItem { name: #name, count: 1, index: #index } });
335 index = quote! { (#index) + 1 };
336 }
337 StackItem::NameCounted(name, count) => {
338 let name = name.to_string();
339 output_constructor_fields.push(
340 quote! { StackItem { name: #name, count: #count, index: #index } },
341 );
342 index = quote! { (#index) + #count };
343 }
344 StackItem::Unused(count) => {
345 index = quote! { (#index) + #count };
346 }
347 }
348 }
349 }
350
351 input_sirs.push(quote! { Opcode::#name => vec![
352 #(
353 #input_constructor_fields
354 ),*
355 ] });
356
357 output_sirs.push(quote! { Opcode::#name => vec![
358 #(
359 #output_constructor_fields
360 ),*
361 ] });
362 }
363
364 let sir = quote! {
365 pub mod sir {
366 use super::{Opcode};
367 use crate::sir::{SIR, StackItem, SIRStatement, Call, SIRExpression, AuxVar};
368 use crate::traits::{GenericSIRNode, SIROwned};
369
370
371 #[derive(PartialEq, Debug, Clone)]
372 pub struct SIRNode {
373 pub opcode: Opcode,
374 pub oparg: u32,
375 pub input: Vec<StackItem>,
376 pub output: Vec<StackItem>,
377 }
378
379 impl SIRNode {
380 pub fn new(opcode: Opcode, oparg: u32, jump: bool) -> Self {
381 let calculate_max = false;
383
384 let input = match opcode {
385 #(
386 #input_sirs
387 ),*,
388 Opcode::INVALID_OPCODE(_) => vec![],
389 };
390
391 let output = match opcode {
392 #(
393 #output_sirs
394 ),*,
395 Opcode::INVALID_OPCODE(_) => vec![],
396 };
397
398 Self {
399 opcode,
400 oparg,
401 input,
402 output,
403 }
404 }
405 }
406
407 impl GenericSIRNode for SIRNode {
408 type Opcode = Opcode;
409
410 fn new(opcode: Self::Opcode, oparg: u32, jump: bool) -> Self {
411 SIRNode::new(opcode, oparg, jump)
412 }
413
414 fn get_outputs(&self) -> &[StackItem] {
415 &self.output
416 }
417
418 fn get_inputs(&self) -> &[StackItem] {
419 &self.input
420 }
421 }
422
423 impl SIROwned<SIRNode> for SIR<SIRNode> {
424 fn new(statements: Vec<SIRStatement<SIRNode>>) -> Self {
425 SIR(statements)
426 }
427 }
428
429 impl std::fmt::Display for SIR<SIRNode> {
430 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
431 for statement in &self.0 {
432 match statement {
433 SIRStatement::Assignment(aux_var, call) => {
434 writeln!(f, "{} = {}", aux_var.name, call)?;
435 }
436 SIRStatement::TupleAssignment(aux_vars, call) => {
437 let vars = aux_vars.iter().map(|v| v.name.clone()).collect::<Vec<_>>().join(", ");
438 writeln!(f, "({}) = {}", vars, call)?;
439 }
440 SIRStatement::DisregardCall(call) => {
441 writeln!(f, "{}", call)?;
442 }
443 }
444 }
445 Ok(())
446 }
447 }
448
449 impl std::fmt::Display for Call<SIRNode> {
450 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
451 let mut inputs = self
452 .stack_inputs
453 .iter()
454 .map(|input| format!("{}", input))
455 .collect::<Vec<_>>();
456
457 inputs.push(format!("{}", self.node.oparg));
458
459 write!(f, "{:#?}({})", self.node.opcode, inputs.join(", "))
460 }
461 }
462
463 impl std::fmt::Display for SIRExpression<SIRNode> {
464 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
465 match self {
466 SIRExpression::Call(call) => write!(f, "{}", call),
467 SIRExpression::AuxVar(aux_var) => write!(f, "{}", aux_var.name.clone()),
468 SIRExpression::PhiNode(phi) => write!(f, "phi({})", phi.iter().map(|v| &v.name).cloned().collect::<Vec<_>>().join(", ")),
469 }
470 }
471 }
472 }
473 };
474
475 expanded.extend(sir);
476
477 expanded.into()
478}