protospec_build/compiler/
mod.rs

1use crate::asg::*;
2use crate::coder;
3use crate::{BinaryOp, UnaryOp};
4use expr::*;
5use proc_macro2::TokenStream;
6use quote::TokenStreamExt;
7use quote::{format_ident, quote};
8use std::{sync::Arc, unimplemented};
9use case::CaseExt;
10
11mod decoder;
12mod encoder;
13mod expr;
14
/// Maps a protospec type name to the identifier used for it in generated Rust code.
///
/// The mapping is currently the identity: the result is an owned copy of `input`.
pub fn global_name(input: &str) -> String {
    String::from(input)
}
18
/// Knobs controlling the Rust code emitted by `compile_program`.
#[derive(Clone, Debug)]
pub struct CompileOptions {
    /// Traits to `#[derive]` on generated enums. `Default` is never derived directly; when it is
    /// listed, a manual `impl Default` selecting the first enum entry is emitted instead.
    pub enum_derives: Vec<String>,
    /// Traits to `#[derive]` on generated structs, newtype wrappers and bitfields.
    pub struct_derives: Vec<String>,
    /// When set, tokio-based `encode_async` / `decode_async` methods are generated as well.
    pub include_async: bool,
    /// When set, the generated `Result` alias and error helpers use `anyhow` instead of the
    /// locally generated `EncodeError` / `DecodeError` types.
    pub use_anyhow: bool,
    /// Debug flag. It is not read in this file; NOTE(review): presumably consumed by the
    /// decoder, which receives the whole options struct. Confirm there.
    pub debug_mode: bool,
}

impl Default for CompileOptions {
    /// Synchronous-only, non-`anyhow`, non-debug output. Both enums and structs request the
    /// `PartialEq`, `Debug`, `Clone` and `Default` derives.
    fn default() -> Self {
        let standard_derives = || -> Vec<String> {
            ["PartialEq", "Debug", "Clone", "Default"]
                .iter()
                .map(|&derive| String::from(derive))
                .collect()
        };

        CompileOptions {
            enum_derives: standard_derives(),
            struct_derives: standard_derives(),
            include_async: false,
            use_anyhow: false,
            debug_mode: false,
        }
    }
}
49
50impl CompileOptions {
51    fn emit_struct_derives(&self, extra: &[&str]) -> TokenStream {
52        let mut all: Vec<_> = self.struct_derives.iter().map(|x| &**x).collect();
53        all.extend_from_slice(extra);
54        all.sort();
55        all.dedup();
56
57        self.emit_derives(&all[..])
58    }
59
60    fn emit_enum_derives(&self, extra: &[&str]) -> TokenStream {
61        let mut all: Vec<_> = self.enum_derives.iter().map(|x| &**x).collect();
62        all.extend_from_slice(extra);
63        all.retain(|x| *x != "Default");
64        all.sort();
65        all.dedup();
66
67        self.emit_derives(&all[..])
68    }
69
70    fn emit_derives(&self, all: &[&str]) -> TokenStream {
71        if all.len() > 0 {
72            let items = flatten(
73                all.into_iter()
74                    .map(|x| {
75                        let ident = emit_ident(x);
76                        quote! {
77                            #ident,
78                        }
79                    })
80                    .collect::<Vec<_>>(),
81            );
82            quote! {
83                #[derive(#items)]
84            }
85        } else {
86            quote! {}
87        }
88    }
89}
90
91pub fn compile_program(program: &Program, options: &CompileOptions) -> TokenStream {
92    let mut components = vec![];
93    let errors = if options.use_anyhow {
94        quote! {
95            pub type Result<T> = anyhow::Result<T>;
96    
97            fn encode_error<S: AsRef<str>>(value: S) -> anyhow::Error {
98                anyhow::anyhow!("{}", value.as_ref())
99            }
100
101            fn decode_error<S: AsRef<str>>(value: S) -> anyhow::Error {
102                anyhow::anyhow!("{}", value.as_ref())
103            }
104        }
105    } else {
106        quote! {
107            use std::error::Error;
108            pub type Result<T> = std::result::Result<T, Box<dyn Error + Send + Sync + 'static>>;
109    
110            #[derive(Debug)]
111            pub struct DecodeError(pub String);
112            impl std::fmt::Display for DecodeError {
113                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114                    write!(f, "{}", self.0)
115                }
116            }
117            impl Error for DecodeError {}
118            #[derive(Debug)]
119            pub struct EncodeError(pub String);
120            impl std::fmt::Display for EncodeError {
121                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122                    write!(f, "{}", self.0)
123                }
124            }
125            impl Error for EncodeError {}    
126
127            fn encode_error<S: AsRef<str>>(value: S) -> EncodeError {
128                EncodeError(value.as_ref().to_string())
129            }
130
131            fn decode_error<S: AsRef<str>>(value: S) -> DecodeError {
132                DecodeError(value.as_ref().to_string())
133            }
134        }
135    };
136
137    components.push(quote! {
138        use std::io::{Read, BufRead, Cursor};
139        use std::slice;
140        use std::mem;
141        use std::convert::TryInto;
142
143        #errors
144    });
145    for (name, field) in program.types.iter() {
146        match &*field.type_.borrow() {
147            Type::Foreign(_) => continue,
148            Type::Container(item) => {
149                components.push(generate_container(&name, &**item, options));
150            }
151            Type::Enum(item) => {
152                components.push(generate_enum(&name, item, options));
153            }
154            Type::Bitfield(item) => {
155                components.push(generate_bitfield(&name, item, options));
156            }
157            generic => {
158                let ident = format_ident!("{}", global_name(name));
159                let type_ref = emit_type_ref(generic);
160                let type_ref = if field.condition.borrow().is_some() {
161                    quote! {
162                        Option<#type_ref>
163                    }
164                } else {
165                    type_ref
166                };
167                let derives = options.emit_struct_derives(&[]);
168
169                components.push(quote! {
170                    #derives
171                    pub struct #ident(pub #type_ref);
172                });
173            }
174        }
175        components.push(prepare_impls(&field, options));
176    }
177    let components = flatten(components);
178    quote! {
179        #[allow(unused_imports, unused_parens, unused_variables, dead_code, unused_mut, non_upper_case_globals)]
180        mod _ps {
181            #components
182        }
183        pub use _ps::*;
184    }
185}
186
187fn ref_resolver(_f: &Arc<Field>) -> TokenStream {
188    unimplemented!("cannot reference field in input default");
189}
190
191fn prepare_impls(field: &Arc<Field>, options: &CompileOptions) -> TokenStream {
192    let container_ident = format_ident!("{}", global_name(&field.name));
193
194    let mut decode_context = coder::decode::Context::new();
195    decode_context.decode_field_top(field);
196    let decode_sync = decoder::prepare_decoder(options, &decode_context, false);
197
198    let mut new_context = coder::encode::Context::new();
199    new_context.encode_field_top(field);
200
201    let encode_sync = encoder::prepare_encoder(&new_context, false);
202
203    let mut arguments = vec![];
204    let mut redefaults = vec![];
205    for argument in field.arguments.borrow().iter() {
206        let name = emit_ident(&argument.name);
207        let type_ref = emit_type_ref(&argument.type_);
208        let opt_type_ref = if argument.default_value.is_some() {
209            quote! { Option<#type_ref> }
210        } else {
211            type_ref.clone()
212        };
213        arguments.push(quote! {, #name: #opt_type_ref});
214        if let Some(default_value) = argument.default_value.as_ref() {
215            let emitted = emit_expression(default_value, &ref_resolver);
216            redefaults.push(quote! {
217                let #name: #type_ref = if let Some(#name) = #name {
218                    #name
219                } else {
220                    #emitted
221                };
222            })
223        }
224    }
225    let arguments = flatten(arguments);
226    let redefaults = flatten(redefaults);
227
228    let async_functions = if options.include_async {
229        let async_recursion = if field.is_maybe_cyclical.get() {
230            quote! {
231                #[async_recursion::async_recursion]
232            }
233        } else {
234            quote! {}
235        };
236
237        let encode_async = encoder::prepare_encoder(&new_context, true);
238        let decode_async = decoder::prepare_decoder(options, &decode_context, true);
239        quote! {
240            #async_recursion
241            pub async fn encode_async<W: tokio::io::AsyncWrite + Send + Sync + Unpin>(&self, writer: &mut W #arguments) -> Result<()> {
242                #redefaults
243                #encode_async
244            }
245
246            #async_recursion
247            pub async fn decode_async<R: tokio::io::AsyncBufRead + Send + Sync + Unpin>(reader: &mut R #arguments) -> Result<Self> {
248                #redefaults
249                #decode_async
250            }
251        }
252    } else {
253        quote! {}
254    };
255
256    quote! {
257        impl #container_ident {
258            pub fn decode_sync<R: Read + BufRead>(reader: &mut R #arguments) -> Result<Self> {
259                #redefaults
260                #decode_sync
261            }
262
263            pub fn encode_sync<W: std::io::Write>(&self, writer: &mut W #arguments) -> Result<()> {
264                #redefaults
265                #encode_sync
266            }
267
268            #async_functions
269        }
270    }
271}
272
273fn emit_ident(name: &str) -> TokenStream {
274    let ident = format_ident!("{}", name);
275    quote! {
276        #ident
277    }
278}
279
280fn emit_register(register: usize) -> TokenStream {
281    let ident = format_ident!("r_{}", register);
282    quote! {
283        #ident
284    }
285}
286
287fn flatten<T: IntoIterator<Item = TokenStream>>(iter: T) -> TokenStream {
288    let mut out = quote! {};
289    out.append_all(iter);
290    out
291}
292
293pub fn emit_type_ref(item: &Type) -> TokenStream {
294    match item {
295        Type::Container(_) => unimplemented!(),
296        Type::Enum(_) => unimplemented!(),
297        Type::Bitfield(_) => unimplemented!(),
298        Type::Scalar(s) => emit_ident(&s.to_string()),
299        Type::Array(array_type) => {
300            let interior = emit_type_ref(&array_type.element.type_.borrow());
301            quote! {
302                Vec<#interior>
303            }
304        }
305        Type::Foreign(f) => f.obj.type_ref(),
306        Type::F32 => emit_ident("f32"),
307        Type::F64 => emit_ident("f64"),
308        Type::Bool => emit_ident("bool"),
309        Type::Ref(field) => match &*field.target.type_.borrow() {
310            Type::Foreign(c) => c.obj.type_ref(),
311            _ => emit_ident(&global_name(&field.target.name)),
312        },
313    }
314}
315
316fn generate_container_fields(access: TokenStream, item: &ContainerType) -> TokenStream {
317    let mut fields = vec![];
318    for (name, field) in item.flatten_view() {
319        if field.is_pad.get() {
320            continue;
321        }
322        let name_ident = format_ident!("{}", name);
323        let type_ref = emit_type_ref(&field.type_.borrow());
324        let type_ref = if field.condition.borrow().is_some() {
325            quote! {
326                Option<#type_ref>
327            }
328        } else {
329            type_ref
330        };
331
332        fields.push(quote! {
333            #access #name_ident: #type_ref,
334        });
335    }
336    flatten(fields)
337}
338
339pub fn generate_container(
340    name: &str,
341    item: &ContainerType,
342    options: &CompileOptions,
343) -> TokenStream {
344    let name_ident = format_ident!("{}", global_name(name));
345    if item.is_enum.get() {
346        let derives = options.emit_enum_derives(&[]);
347        let mut fields = vec![];
348        for (name, field) in &item.items {
349            let name_ident = format_ident!("{}", name);
350            let type_ = field.type_.borrow();
351            let type_ref = match &*type_ {
352                Type::Container(sub_container) => {
353                    let subfields = generate_container_fields(quote! { }, &**sub_container);
354                    quote! {
355                        {
356                            #subfields
357                        }
358                    }
359                },
360                type_ => {
361                    let emitted = emit_type_ref(type_);
362                    quote! { (#emitted) }
363                }
364            };
365    
366            fields.push(quote! {
367                #name_ident#type_ref,
368            });
369        }
370        let fields = flatten(fields);
371
372        let default_impl = if options.enum_derives.iter().any(|x| x == "Default") {
373            let (default_field, field) = item.items.first().expect("missing enum entry for default");
374            let default_field = format_ident!("{}", default_field);
375
376            let type_ = field.type_.borrow();
377            let default_value = match &*type_ {
378                Type::Container(sub_container) => {
379                    let mut fields = vec![];
380                    for (name, _) in sub_container.flatten_view() {
381                        let name_ident = format_ident!("{}", name);
382                
383                        fields.push(quote! {
384                            #name_ident: Default::default(),
385                        });
386                    }
387                    let fields = flatten(fields);
388                    quote! {
389                        {
390                            #fields
391                        }
392                    }
393                },
394                _ => {
395                    quote! { (Default::default()) }
396                }
397            };
398
399            quote! {
400                impl Default for #name_ident {
401                    fn default() -> Self {
402                        Self::#default_field#default_value
403                    }
404                }
405            }
406        } else {
407            quote! {}
408        };
409
410        quote! {
411            #derives
412            pub enum #name_ident {
413                #fields
414            }
415
416            #default_impl
417        }
418    } else {
419        let derives = options.emit_struct_derives(&[]);
420        let fields = generate_container_fields(quote! { pub }, item);
421    
422        quote! {
423            #derives
424            pub struct #name_ident {
425                #fields
426            }
427        }
428    }
429}
430
431pub fn generate_enum(name: &str, item: &EnumType, options: &CompileOptions) -> TokenStream {
432    let name_ident = format_ident!("{}", global_name(name));
433    let mut fields = vec![];
434    let mut from_repr_matches = vec![];
435    for (name, cons) in item.items.iter() {
436        let value_ident = format_ident!("{}", name);
437        let value = eval_const_expression(&cons.value);
438        if value.is_none() {
439            unimplemented!("could not resolve constant expression");
440        }
441        let value = value.unwrap();
442        let value = value.emit();
443        fields.push(quote! {
444            #value_ident = #value,
445        });
446        from_repr_matches.push(quote! {
447            #value => Ok(#name_ident::#value_ident),
448        })
449    }
450    let fields = flatten(fields);
451
452    let from_repr_matches = flatten(from_repr_matches);
453    let rep = format_ident!("{}", item.rep.to_string());
454    let rep_size = item.rep.size() as usize;
455    let derives = options.emit_enum_derives(&["Clone", "Copy"]);
456
457    let format_string = format!("illegal enum value '{{}}' for enum '{}'", name);
458
459    let default_impl = if options.enum_derives.iter().any(|x| x == "Default") {
460        let (default_field, _) = item.items.first().expect("missing enum entry for default");
461        let default_field = format_ident!("{}", default_field);
462        quote! {
463            impl Default for #name_ident {
464                fn default() -> Self {
465                    Self::#default_field
466                }
467            }
468        }
469    } else {
470        quote! {}
471    };
472
473    quote! {
474        #[repr(#rep)]
475        #derives
476        pub enum #name_ident {
477            #fields
478        }
479
480        impl #name_ident {
481            pub fn from_repr(repr: #rep) -> Result<Self> {
482                match repr {
483                    #from_repr_matches
484                    x => Err(decode_error(format!(#format_string, x)).into()),
485                }
486            }
487
488            pub fn to_be_bytes(&self) -> [u8; #rep_size] {
489                (*self as #rep).to_be_bytes()
490            }
491        }
492
493        #default_impl
494    }
495}
496
497pub fn generate_bitfield(name: &str, item: &BitfieldType, options: &CompileOptions) -> TokenStream {
498    let name_ident = format_ident!("{}", global_name(name));
499    let mut fields = vec![];
500    let mut funcs = vec![];
501    let mut all_fields = ConstInt::parse(item.rep, "0", crate::Span::default()).unwrap();
502    let zero = all_fields;
503
504    for (name, cons) in item.items.iter() {
505        let name_ident = format_ident!("{}", name.to_snake().to_uppercase());
506        let get_name = format_ident!("{}", name.to_snake());
507        let set_name = format_ident!("set_{}", name.to_snake());
508        let value = eval_const_expression(&cons.value);
509        if value.is_none() {
510            unimplemented!("could not resolve constant expression");
511        }
512        let value = value.unwrap();
513        let int_value = match &value {
514            ConstValue::Int(x) => *x,
515            _ => panic!("invalid const value type"),
516        };
517        if (int_value & all_fields).unwrap() != zero {
518            panic!("overlapping bit fields");
519        }
520        all_fields = (all_fields | int_value).unwrap();
521
522        let value = value.emit();
523        fields.push(quote! {
524            pub const #name_ident: Self = Self(#value);
525        });
526        funcs.push(quote! {
527            pub fn #get_name(&self) -> bool {
528                (*self & Self::#name_ident) != Self::ZERO
529            }
530
531            pub fn #set_name(&mut self) {
532                *self = *self | Self::#name_ident;
533            }
534        });
535    }
536    let fields = flatten(fields);
537    let funcs = flatten(funcs);
538
539    let rep = format_ident!("{}", item.rep.to_string());
540    let rep_size = item.rep.size() as usize;
541    let derives = options.emit_struct_derives(&["Clone", "Copy", "Default"]);
542
543    let format_string = format!("illegal bitfield value '{{}}' for bitfield '{}'", name);
544    let all_fields = ConstValue::Int(all_fields).emit();
545
546    quote! {
547        #[repr(transparent)]
548        #derives
549        pub struct #name_ident(pub #rep);
550
551        impl #name_ident {
552            #fields
553            pub const ALL: Self = Self(#all_fields);
554            pub const ZERO: Self = Self(0);
555
556            pub fn from_repr(repr: #rep) -> Result<Self> {
557                if (repr & !Self::ALL.0) != 0 {
558                    Err(decode_error(format!(#format_string, repr)).into())
559                } else {
560                    Ok(Self(repr))
561                }
562            }
563
564            pub fn to_be_bytes(&self) -> [u8; #rep_size] {
565                self.0.to_be_bytes()
566            }
567
568            #funcs
569        }
570
571        impl core::ops::BitOr for #name_ident {
572            type Output = Self;
573            fn bitor(self, rhs: Self) -> Self {
574                Self(self.0 | rhs.0)
575            }
576        }
577
578        impl core::ops::BitOrAssign for #name_ident {
579            fn bitor_assign(&mut self, rhs: Self) {
580                *self = *self | rhs;
581            }
582        }
583
584        impl core::ops::BitAnd for #name_ident {
585            type Output = Self;
586            fn bitand(self, rhs: Self) -> Self {
587                Self(self.0 & rhs.0)
588            }
589        }
590
591        impl core::ops::BitAndAssign for #name_ident {
592            fn bitand_assign(&mut self, rhs: Self) {
593                *self = *self & rhs;
594            }
595        }
596
597        impl core::ops::BitXor for #name_ident {
598            type Output = Self;
599            fn bitxor(self, rhs: Self) -> Self {
600                Self(self.0 ^ rhs.0)
601            }
602        }
603
604        impl core::ops::BitXorAssign for #name_ident {
605            fn bitxor_assign(&mut self, rhs: Self) {
606                *self = *self ^ rhs;
607            }
608        }
609
610        impl core::ops::Not for #name_ident {
611            type Output = Self;
612            fn not(self) -> Self {
613                Self(!self.0)
614            }
615        }
616    }
617}