pwn_helper_macros/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use core::panic;
4
5use proc_macro::TokenStream;
6use proc_macro2::Span;
7use syn::{parse::Parse, parse_macro_input, Lit};
8
9enum Operand {
10    Byte(syn::LitByte),
11    ByteStr(syn::LitByteStr),
12    Int(syn::LitInt),
13}
14
15impl Parse for Operand {
16    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
17        match input.parse::<syn::Lit>() {
18            Ok(Lit::Byte(x)) => Ok(Operand::Byte(x)),
19            Ok(Lit::ByteStr(x)) => Ok(Operand::ByteStr(x)),
20            Ok(Lit::Int(x)) => Ok(Operand::Int(x)),
21            _ => Err(syn::Error::new(
22                input.span(),
23                "Missing operand (Either byte literal, byte str, or int)",
24            )),
25        }
26    }
27}
28
/// Operator joining the two operands of a `bytes!` expression.
enum Operation {
    /// `+`: concatenate the right operand onto the left one.
    Addition,
    /// `*`: repeat the left operand an integer number of times.
    Multiply,
}
33
34struct ByteOperation {
35    pub lhs: Operand,
36    pub operation: Operation,
37    pub rhs: Operand,
38    pub span: Span,
39}
40
41impl Parse for ByteOperation {
42    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
43        let span = input.span();
44        let lhs = Operand::parse(input)?;
45
46        let operation = match input.cursor().punct() {
47            Some((x, _)) => match x.as_char() {
48                '*' => {
49                    input.parse::<syn::UnOp>()?;
50                    Operation::Multiply
51                }
52                '+' => {
53                    input.parse::<syn::BinOp>()?;
54                    Operation::Addition
55                }
56                _ => {
57                    return Err(syn::Error::new(
58                        input.span(),
59                        "Missing operand (Expected either `+` or `*`)",
60                    ))
61                }
62            },
63            None => {
64                return Err(syn::Error::new(
65                    input.span(),
66                    "Missing operand (Expected either `+` or `*`)",
67                ))
68            }
69        };
70
71        let rhs = Operand::parse(input)?;
72
73        Ok(Self {
74            lhs,
75            operation,
76            rhs,
77            span,
78        })
79    }
80}
81
82#[proc_macro]
83pub fn bytes(ts: TokenStream) -> TokenStream {
84    let input = parse_macro_input!(ts as ByteOperation);
85
86    let lhs = match input.lhs {
87        Operand::Byte(b) => vec![b.value()],
88        Operand::ByteStr(b) => b.value(),
89        Operand::Int(_) => panic!("Int on left side"),
90    };
91
92    let output = match input.operation {
93        Operation::Multiply => match input.rhs {
94            Operand::Byte(_) | Operand::ByteStr(_) => panic!("Failed to multiply bytes with bytes"),
95            Operand::Int(i) => {
96                let parsed_rhs = i
97                    .base10_parse::<usize>()
98                    .expect("Failed to parse right hand side");
99                let mut out: Vec<u8> = Vec::with_capacity(lhs.len() * parsed_rhs);
100                for _ in 0..parsed_rhs {
101                    for x in &lhs {
102                        out.push(*x);
103                    }
104                }
105                out
106            }
107        },
108        Operation::Addition => match input.rhs {
109            Operand::Byte(b) => {
110                let mut x = lhs;
111                x.push(b.value());
112                x
113            }
114            Operand::ByteStr(b) => {
115                let mut x = lhs;
116                x.append(&mut b.value());
117                x
118            }
119            Operand::Int(_) => panic!("Failed to add bytes with int"),
120        },
121    };
122
123    let lit_output = syn::LitByteStr::new(&output, input.span);
124
125    let out = quote::quote! {
126        #lit_output
127    };
128
129    out.into()
130}
131
132struct PackOperation {
133    pub lhs: syn::LitInt,
134    pub bits: syn::LitInt,
135    pub span: Span,
136}
137
138impl Parse for PackOperation {
139    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
140        let span = input.span();
141        let lhs = input.parse()?;
142        input.parse::<syn::Token!(,)>()?;
143        let bits = input.parse()?;
144
145        Ok(Self { lhs, bits, span })
146    }
147}
148
149#[proc_macro]
150pub fn pack(ts: TokenStream) -> TokenStream {
151    let input = parse_macro_input!(ts as PackOperation);
152
153    let bits = input
154        .bits
155        .base10_parse::<usize>()
156        .expect("Failed to parse left hand side as int");
157
158    if bits == 0 || bits % 8 != 0 {
159        panic!("Right hand side is not a multiple of 8")
160    }
161
162    if bits > 128 {
163        panic!("Maximum of 128 bits supported")
164    }
165
166    let lhs_bytes = {
167        match bits {
168            8 => input
169                .lhs
170                .base10_parse::<u8>()
171                .expect("Failed to parse left hand side as int")
172                .to_ne_bytes()
173                .to_vec(),
174            16 => input
175                .lhs
176                .base10_parse::<u16>()
177                .expect("Failed to parse left hand side as int")
178                .to_ne_bytes()
179                .to_vec(),
180            32 => input
181                .lhs
182                .base10_parse::<u32>()
183                .expect("Failed to parse left hand side as int")
184                .to_ne_bytes()
185                .to_vec(),
186            64 => input
187                .lhs
188                .base10_parse::<u64>()
189                .expect("Failed to parse left hand side as int")
190                .to_ne_bytes()
191                .to_vec(),
192            128 => input
193                .lhs
194                .base10_parse::<u128>()
195                .expect("Failed to parse left hand side as int")
196                .to_ne_bytes()
197                .to_vec(),
198            _ => unreachable!(),
199        }
200    };
201
202    let bytes = bits / 8;
203
204    if lhs_bytes.len() > bytes {
205        panic!("Literal {} out of range {}", lhs_bytes.len(), bytes);
206    }
207
208    let mut output = Vec::with_capacity(bytes);
209
210    if cfg!(target_endian = "big") {
211        let cur_len = lhs_bytes.len();
212
213        if cur_len < bytes {
214            output.resize(bytes - cur_len, 0);
215        }
216
217        for x in lhs_bytes {
218            output.push(x);
219        }
220    } else {
221        let cur_len = lhs_bytes.len();
222
223        for x in lhs_bytes {
224            output.push(x);
225        }
226
227        if cur_len < bytes {
228            output.resize(bytes, 0);
229        }
230    }
231
232    let lit_output = syn::LitByteStr::new(&output, input.span);
233
234    let out = quote::quote! {
235        #lit_output
236    };
237
238    out.into()
239}