fbs_build/
codegen.rs

1use crate::{ast::types as ast, ir::types as ir};
2
3use fbs::VOffsetT;
4use heck::{ShoutySnakeCase, SnakeCase};
5use itertools::Itertools;
6use proc_macro2::TokenStream;
7use quote::{format_ident, quote, ToTokens};
8use std::{convert::TryInto, fmt};
9use syn::spanned::Spanned;
10
/// Test helper: renders anything implementing `ToTokens` as its token-stream string.
#[cfg(test)]
fn to_code(value: impl ToTokens) -> String {
    let stream = value.to_token_stream();
    stream.to_string()
}
15
#[cfg(test)]
mod constant_tests {
    use super::*;

    // A string constant is rendered as a quoted string literal.
    #[test]
    fn test_visit_string_constant() {
        assert_eq!(to_code("abc"), "\"abc\"");
    }

    // Boolean constants are rendered as the bare `true` / `false` keywords.
    #[test]
    fn test_visit_bool_constant() {
        assert_eq!(to_code(true), "true");
        assert_eq!(to_code(false), "false");
    }
}
38
39impl ToTokens for ast::Ident<'_> {
40    fn to_tokens(&self, tokens: &mut TokenStream) {
41        format_ident!("{}", self.raw).to_tokens(tokens)
42    }
43}
44
#[cfg(test)]
mod ident_tests {
    use super::*;

    // An AST identifier renders as the bare identifier text.
    #[test]
    fn test_visit_ident() {
        let ident = ast::Ident::from("foo");
        assert_eq!(to_code(ident), "foo");
    }
}
56
57impl ToTokens for ir::Ident<'_> {
58    fn to_tokens(&self, tokens: &mut TokenStream) {
59        format_ident!("{}", self.raw.as_ref()).to_tokens(tokens)
60    }
61}
62
#[cfg(test)]
mod ir_ident_tests {
    use super::*;

    // An IR identifier renders as the bare identifier text.
    #[test]
    fn test_visit_ident() {
        let ident = ir::Ident::from("foo");
        assert_eq!(to_code(ident), "foo");
    }
}
74
75/// Convert a `types::Type` to a type with the supplied wrapper for reference types
76fn to_type_token(
77    context_namespace: Option<&ir::QualifiedIdent<'_>>,
78    ty: &ir::Type<'_>,
79    lifetime: &TokenStream,
80    wrap_refs_types: &TokenStream,
81    wrap_outer: bool,
82) -> TokenStream {
83    let empty_lifetime = quote!();
84
85    let lifetime_is_named = {
86        if lifetime.is_empty() {
87            false
88        } else {
89            let parsed_lifetime =
90                syn::parse2::<syn::Lifetime>(lifetime.clone()).expect("lifetime must be valid");
91
92            let anonymous_lifetime = syn::Ident::new("_", proc_macro2::Span::call_site());
93            parsed_lifetime.ident != anonymous_lifetime
94        }
95    };
96    // Lifetime for references (&str, &[u8], ...)
97    let ref_lifetime = if lifetime_is_named {
98        lifetime
99    } else {
100        &empty_lifetime
101    };
102
103    match ty {
104        ir::Type::Bool => quote!(bool),
105        ir::Type::Byte => quote!(i8),
106        ir::Type::UByte => quote!(u8),
107        ir::Type::Short => quote!(i16),
108        ir::Type::UShort => quote!(u16),
109        ir::Type::Int => quote!(i32),
110        ir::Type::UInt => quote!(u32),
111        ir::Type::Float => quote!(f32),
112        ir::Type::Long => quote!(i64),
113        ir::Type::ULong => quote!(u64),
114        ir::Type::Double => quote!(f64),
115        ir::Type::Int8 => quote!(i8),
116        ir::Type::UInt8 => quote!(u8),
117        ir::Type::Int16 => quote!(i16),
118        ir::Type::UInt16 => quote!(u16),
119        ir::Type::Int32 => quote!(i32),
120        ir::Type::UInt32 => quote!(u32),
121        ir::Type::Int64 => quote!(i64),
122        ir::Type::UInt64 => quote!(u64),
123        ir::Type::Float32 => quote!(f32),
124        ir::Type::Float64 => quote!(f64),
125        ir::Type::String => {
126            let wrap_tokens = wrap_refs_types.into_token_stream();
127
128            if wrap_tokens.is_empty() || !wrap_outer {
129                quote!(&#ref_lifetime str)
130            } else {
131                quote!(#wrap_tokens::<&#ref_lifetime str>)
132            }
133        }
134        ir::Type::Array(ty) => {
135            // Arrays wrap the wrapping tokens with Vector
136            let component_token = to_type_token(
137                context_namespace,
138                ty,
139                lifetime,
140                &quote!(fbs::ForwardsUOffset),
141                true,
142            );
143            let ty = quote!(fbs::Vector<#lifetime, #component_token>);
144
145            let wrap_tokens = wrap_refs_types.into_token_stream();
146            if wrap_tokens.is_empty() || !wrap_outer {
147                ty
148            } else {
149                quote!(#wrap_tokens::<#ty>)
150            }
151        }
152        ir::Type::Custom(ir::CustomTypeRef { ident, ty }) => {
153            // Scalar types are never wrapped and have no lifetimes
154            let ident = &ident.relative(context_namespace);
155            if ty.is_scalar() {
156                quote!(#ident)
157            } else {
158                let ty = if ty == &ir::CustomType::Table {
159                    quote!(#ident<&#ref_lifetime [u8]>) // handle structs
160                } else {
161                    quote!(#ident<#lifetime>)
162                };
163                let wrap_tokens = wrap_refs_types.into_token_stream();
164                if wrap_tokens.is_empty() || !wrap_outer {
165                    ty
166                } else {
167                    quote!(#wrap_tokens::<#ty>)
168                }
169            }
170        }
171    }
172}
173
/// A trait to format literals according to a provided type.
///
/// Implementors render themselves as a suffixed Rust literal (value followed by the
/// textual form of `ty`).
trait PrimitiveLiteralToken {
    /// The type of the token that will be returned formatted according to the provided type.
    type TokenType;

    /// Builds the literal token for `self`, using the `Display` output of `ty` as the
    /// literal suffix and the span of `ty` as the span of the resulting token.
    fn fmt_lit(&self, ty: impl Spanned + fmt::Display) -> Self::TokenType;
}
181
182impl PrimitiveLiteralToken for ast::IntegerConstant {
183    type TokenType = syn::LitInt;
184
185    fn fmt_lit(&self, ty: impl Spanned + fmt::Display) -> Self::TokenType {
186        Self::TokenType::new(&format!("{}{}", self, ty), ty.span())
187    }
188}
189
190impl PrimitiveLiteralToken for ast::FloatingConstant {
191    type TokenType = syn::LitFloat;
192
193    fn fmt_lit(&self, ty: impl Spanned + fmt::Display) -> Self::TokenType {
194        Self::TokenType::new(
195            &match ty.to_string().as_str() {
196                float_ty @ "f32" => format!("{}{}", *self as f32, float_ty),
197                float_ty => format!("{}{}", self, float_ty),
198            },
199            ty.span(),
200        )
201    }
202}
203
204/// Convert a `types::ast::DefaultValue` to a default field value
205fn to_default_value(
206    arg_ty: &impl ToTokens,
207    default_value: &ast::DefaultValue<'_>,
208) -> impl ToTokens + fmt::Debug + ToString {
209    match default_value {
210        // Scalar field
211        ast::DefaultValue::Scalar(s) => match s {
212            ast::Scalar::Integer(i) => i.fmt_lit(arg_ty.to_token_stream()).to_token_stream(),
213            ast::Scalar::Float(f) => f.fmt_lit(arg_ty.to_token_stream()).to_token_stream(),
214            ast::Scalar::Boolean(b) => b.to_token_stream(),
215        },
216        // Enum field default variant
217        ast::DefaultValue::Ident(i) => {
218            let variant = format_ident!("{}", i.raw);
219            quote!(<#arg_ty>::#variant).to_token_stream()
220        }
221    }
222}
223
224/// Convert a `types::ast::DefaultValue` to a doc comment describing the value
225fn to_default_value_doc(
226    ty: &ir::Type<'_>,
227    default_value: &Option<ast::DefaultValue<'_>>,
228) -> impl ToTokens {
229    default_value.as_ref().map_or_else(
230        || quote!(),
231        |default_value| {
232            let doc_value = match default_value {
233                // Scalar field
234                ast::DefaultValue::Scalar(s) => match s {
235                    ast::Scalar::Integer(i) => i
236                        .fmt_lit(to_type_token(None, ty, &quote!(), &quote!(), false))
237                        .to_string(),
238                    ast::Scalar::Float(f) => f
239                        .fmt_lit(to_type_token(None, ty, &quote!(), &quote!(), false))
240                        .to_string(),
241                    ast::Scalar::Boolean(b) => b.to_token_stream().to_string(),
242                },
243                // Enum field default variant
244                ast::DefaultValue::Ident(i) => format!("{}::{}", ty, i.raw),
245            };
246            let doc_string = format!(" The default value for this field is __{}__", doc_value);
247            quote!(#[doc = #doc_string])
248        },
249    )
250}
251
// Tests for `to_default_value`: numeric extremes for every numeric schema type, plus
// both boolean values.
#[cfg(test)]
mod to_default_value_tests {
    use super::{quote, to_default_value, to_type_token};
    use crate::{ast::types as ast, ir::types as ir};

    // Generates a `MIN` and a `MAX` test for one numeric type.
    //
    // * `$kind` - path of the `ast::Scalar` variant constructor wrapping the value
    // * `$base_type` - type the raw value is cast to before being wrapped
    // * `$fbs_ir_type` - `ir::Type` variant under test
    // * `$rust_type` - Rust primitive whose `MIN`/`MAX` are used and whose name is the
    //   expected literal suffix
    macro_rules! generate_min_max_tests {
        ($kind:path, $base_type:ty, $fbs_ir_type:ident, $rust_type:ident $(,)?) => {
            #[test]
            fn test_to_default_value_min() {
                let raw_value = std::$rust_type::MIN;
                let rust_value = raw_value as $base_type;
                let value = ast::DefaultValue::Scalar($kind(rust_value));
                let ty =
                    to_type_token(None, &ir::Type::$fbs_ir_type, &quote!(), &quote!(), false);
                let result = to_default_value(&ty, &value).to_string();
                // Expected: the raw value immediately followed by the Rust type suffix
                let expected = format!("{}{}", raw_value, stringify!($rust_type));
                assert_eq!(result, expected);
            }

            #[test]
            fn test_to_default_value_max() {
                let raw_value = std::$rust_type::MAX;
                let rust_value = raw_value as $base_type;
                let value = ast::DefaultValue::Scalar($kind(rust_value));
                let ty =
                    to_type_token(None, &ir::Type::$fbs_ir_type, &quote!(), &quote!(), false);
                let result = to_default_value(&ty, &value).to_string();
                // Expected: the raw value immediately followed by the Rust type suffix
                let expected = format!("{}{}", raw_value, stringify!($rust_type));
                assert_eq!(result, expected);
            }
        };
    }

    // Maps an `ast::Scalar` variant name to the constant type it wraps.
    macro_rules! base_type {
        (Integer) => {
            $crate::ast::types::IntegerConstant
        };
        (Float) => {
            $crate::ast::types::FloatingConstant
        };
    }

    // Expands a list of `(scalar variant, ir type, rust type)` triples into one module
    // per ir type (named after it), each holding the min/max tests.
    macro_rules! generate_numeric_primitive_tests {
        ([$(($ast_type:ident, $fbs_ir_type:ident, $rust_type:ident)),*$(,)?]) => {
            $(
                #[allow(non_snake_case)]
                mod $fbs_ir_type {
                    use super::*;
                    generate_min_max_tests!(
                        $crate::ast::types::Scalar::$ast_type,
                        base_type!($ast_type),
                        $fbs_ir_type,
                        $rust_type
                    );
                }
            )*
        };
    }

    generate_numeric_primitive_tests!([
        (Integer, Byte, i8),
        (Integer, UByte, u8),
        (Integer, Short, i16),
        (Integer, UShort, u16),
        (Integer, Int, i32),
        (Integer, UInt, u32),
        (Integer, Long, i64),
        (Integer, ULong, u64),
        (Integer, Int8, i8),
        (Integer, UInt8, u8),
        (Integer, Int16, i16),
        (Integer, UInt16, u16),
        (Integer, Int32, i32),
        (Integer, UInt32, u32),
        (Integer, Int64, i64),
        (Integer, UInt64, u64),
        (Float, Float, f32),
        (Float, Double, f64),
        (Float, Float32, f32),
        (Float, Float64, f64),
    ]);

    // `true` default renders as the bare keyword, without any type suffix.
    #[test]
    fn test_to_default_value_true() {
        let value = ast::DefaultValue::Scalar(ast::Scalar::Boolean(true));
        let ty = to_type_token(None, &ir::Type::Bool, &quote!(), &quote!(), false);
        let result = to_default_value(&ty, &value).to_string();
        assert_eq!(result, "true".to_string());
    }

    // `false` default renders as the bare keyword, without any type suffix.
    #[test]
    fn test_to_default_value_false() {
        let value = ast::DefaultValue::Scalar(ast::Scalar::Boolean(false));
        let ty = to_type_token(None, &ir::Type::Bool, &quote!(), &quote!(), false);
        let result = to_default_value(&ty, &value).to_string();
        assert_eq!(result, "false".to_string());
    }
}
350
351fn offset_id(field: &ir::Field<'_>) -> impl ToTokens {
352    format_ident!("VT_{}", field.ident.as_ref().to_shouty_snake_case())
353}
354
355impl ToTokens for ir::Table<'_> {
356    fn to_tokens(&self, tokens: &mut TokenStream) {
357        let Self {
358            ident: struct_qualified_id,
359            fields,
360            doc,
361            ..
362        } = self;
363
364        let table_ns = struct_qualified_id.namespace();
365        let table_ns_ref = table_ns.as_ref();
366        let struct_id = struct_qualified_id.simple(); // discard namespace
367        let raw_struct_name = struct_id.raw.as_ref();
368
369        let builder_add_calls = fields.iter().map(
370            |ir::Field {
371                 ident: field_id,
372                 ty,
373                 metadata,
374                 ..
375             }| {
376                let raw_field_name = field_id.raw.as_ref();
377                let add_field_method = format_ident!("add_{}", raw_field_name);
378
379                if metadata.required || ty.is_scalar() {
380                    quote!(builder.#add_field_method(args.#field_id);)
381                } else {
382                    quote! {
383                        if let Some(x) = args.#field_id { builder.#add_field_method(x); }
384                    }
385                }
386            },
387        );
388
389        let args = format_ident!("{}Args", raw_struct_name);
390        let args_fields = fields.iter().map(
391            |ir::Field {
392                 ident: field_id,
393                 ty,
394                 default_value,
395                 metadata,
396                 ..
397             }| {
398                let arg_ty = if ty.is_union() {
399                    quote!(fbs::WIPOffset<fbs::UnionWIPOffset>)
400                } else {
401                    let arg_ty = to_type_token(
402                        table_ns_ref,
403                        ty,
404                        &quote!('a),
405                        &quote!(fbs::WIPOffset),
406                        true,
407                    );
408                    quote!(#arg_ty)
409                };
410
411                let arg_ty = if metadata.required || ty.is_scalar() {
412                    arg_ty
413                } else {
414                    quote!(Option<#arg_ty>)
415                };
416
417                let allow_type_complexity = if ty.is_complex() {
418                    quote!(#[allow(clippy::type_complexity)])
419                } else {
420                    quote!()
421                };
422
423                // Scalar or enum fields can have a default value
424                let default_doc = to_default_value_doc(&ty, default_value);
425                quote! {
426                    #default_doc
427                    #allow_type_complexity
428                    pub #field_id: #arg_ty
429                }
430            },
431        );
432        let args_lifetime = |lifetime_name| {
433            if fields
434                .iter()
435                .any(|f| !(f.ty.is_scalar() || f.ty.is_union()))
436            {
437                quote!(<#lifetime_name>)
438            } else {
439                quote!()
440            }
441        };
442        let args_lifetime_a = args_lifetime(quote!('a));
443        let args_lifetime_args = args_lifetime(quote!('args));
444
445        // Can we implement `Default` on this table?
446        // True if all the fields are either scalar or optional
447        // Scalar fields must always implement `Default`
448        let args_can_derive_default = fields
449            .iter()
450            .all(|field| !field.metadata.required || field.ty.is_scalar());
451
452        let args_default_impl = if args_can_derive_default {
453            let args_fields_defaults = fields.iter().map(
454                |ir::Field {
455                     ident: field_id,
456                     ty,
457                     default_value,
458                     ..
459                 }| {
460                    let arg_ty = to_type_token(table_ns_ref, ty, &quote!(), &quote!(), false);
461                    if !ty.is_scalar() {
462                        // optional non-scalar types default to None
463                        quote!(#field_id: None)
464                    } else if let Some(default_value) = default_value {
465                        // Handle customized default values
466                        if ty.is_enum() {
467                            let default_name = if let ast::DefaultValue::Ident(i) = default_value {
468                                format_ident!("{}", i.raw)
469                            } else {
470                                panic!("expecting default ident for enum")
471                            };
472                            quote!(#field_id: <#arg_ty>::#default_name)
473                        }
474                        // TODO: handle structs
475                        else {
476                            // numeric types
477                            let default_val = match default_value {
478                                ast::DefaultValue::Scalar(s) => match s {
479                                    ast::Scalar::Integer(i) => i.fmt_lit(arg_ty).to_token_stream(),
480                                    ast::Scalar::Float(f) => f.fmt_lit(arg_ty).to_token_stream(),
481                                    ast::Scalar::Boolean(b) => quote!(#b),
482                                },
483                                _ => panic!("expecting numeric default"),
484                            };
485                            quote!(#field_id: #default_val)
486                        }
487                    } else {
488                        // no custom default value, default to the scalar type's default
489                        quote!(#field_id: <#arg_ty>::default())
490                    }
491                },
492            );
493            quote! {
494                impl#args_lifetime_a Default for #args#args_lifetime_a {
495                    fn default() -> Self {
496                        Self {
497                            #(#args_fields_defaults),*
498                        }
499                    }
500                }
501            }
502        } else {
503            quote!()
504        };
505
506        let builder_type = format_ident!("{}Builder", raw_struct_name);
507
508        let builder_field_methods = fields.iter().map(|field| {
509            let ir::Field {
510                ident: field_id,
511                ty,
512                default_value,
513                ..
514            } = field;
515            let add_method_name = format_ident!("add_{}", field_id.raw.as_ref());
516            let offset = offset_id(field);
517            let field_offset = quote!(#struct_id::#offset);
518
519            let allow_type_complexity = if ty.is_complex() {
520                quote!(#[allow(clippy::type_complexity)])
521            } else {
522                quote!()
523            };
524
525            let arg_ty = if ty.is_union() {
526                quote!(fbs::WIPOffset<fbs::UnionWIPOffset>)
527            } else {
528                let arg_ty = to_type_token(
529                    table_ns_ref,
530                    ty,
531                    &quote!('b),
532                    &quote!(fbs::WIPOffset),
533                    true,
534                );
535                quote!(#arg_ty)
536            };
537
538            let body = if ty.is_scalar() {
539                if let Some(default_value) = default_value {
540                    let default_value = to_default_value(&arg_ty, &default_value);
541                    quote!(self.fbb.push_slot::<#arg_ty>(#field_offset, #field_id, #default_value))
542                } else {
543                    quote!(self.fbb.push_slot_always::<#arg_ty>(#field_offset, #field_id))
544                }
545            } else {
546                quote!(self.fbb.push_slot_always::<#arg_ty>(#field_offset, #field_id))
547            };
548
549            quote! {
550                #[inline]
551                #allow_type_complexity
552                fn #add_method_name(&mut self, #field_id: #arg_ty) {
553                    #body;
554                }
555            }
556        });
557
558        let field_offset_constants = fields.iter().enumerate().map(|(index, field)| {
559            let offset_name = offset_id(field);
560            let offset_value = fbs::field_index_to_field_offset(index as VOffsetT);
561            quote! {
562                pub const #offset_name: fbs::VOffsetT = #offset_value;
563            }
564        });
565
566        let field_accessors = fields.iter().map(|field| {
567            let ir::Field {
568                ident,
569                ty,
570                metadata,
571                doc,
572                default_value,
573                ..
574            } = field;
575            let offset_name = offset_id(field);
576            let snake_name = format_ident!("{}", ident.as_ref().to_snake_case());
577            let snake_name_str = snake_name.to_string();
578            let ty_ret = to_type_token(
579                table_ns_ref,
580                ty,
581                &quote!('_),
582                &quote!(fbs::ForwardsUOffset),
583                false,
584            );
585            let ty_wrapped = to_type_token(
586                table_ns_ref,
587                ty,
588                &quote!('_),
589                &quote!(fbs::ForwardsUOffset),
590                true,
591            );
592            let default_value = if let Some(default_value) = default_value {
593                let default_value = to_default_value(&ty_wrapped, &default_value);
594                quote!(Some(#default_value))
595            } else {
596                quote!(None)
597            };
598            let allow_type_complexity = if ty.is_complex() {
599                quote!(#[allow(clippy::type_complexity)])
600            } else {
601                quote!()
602            };
603
604            if ty.is_union() {
605                let (union_ident, enum_ident, variants) = match ty {
606                    ir::Type::Custom(ir::CustomTypeRef {
607                        ty,
608                        ident: ref union_ident,
609                    }) => match ty {
610                        ir::CustomType::Union {
611                            ref variants,
612                            ref enum_ident,
613                        } => (union_ident, enum_ident, variants),
614                        _ => panic!("type is union"),
615                    },
616                    _ => panic!("type is union"),
617                };
618
619                let union_ident = union_ident.relative(table_ns_ref);
620                let enum_ident = enum_ident.relative(table_ns_ref);
621
622                let type_snake_name =
623                    format_ident!("{}_type", field.ident.as_ref().to_snake_case());
624
625                if field.metadata.required {
626                    let names_to_enum_variant = variants.iter().map(
627                        |ir::UnionVariant {
628                             ident: variant_ident,
629                             ty: variant_ty,
630                             ..
631                         }| {
632                            let variant_ty_wrapped = to_type_token(table_ns_ref,
633                                variant_ty,
634                                &quote!('_),
635                                &quote!(fbs::ForwardsUOffset),
636                                true,
637                            );
638                            quote! {
639                                #enum_ident::#variant_ident => #union_ident::#variant_ident(self.table
640                                    .get::<#variant_ty_wrapped>(#struct_id::#offset_name, None)?
641                                    .ok_or_else(|| fbs::Error::RequiredFieldMissing(#snake_name_str))?)
642                            }
643                        },
644                    );
645                    quote! {
646                        #[inline]
647                        pub fn #snake_name(&self) -> Result<#ty_ret, fbs::Error> {
648                            Ok(match self.#type_snake_name()? {
649                              #(#names_to_enum_variant,)*
650                              #enum_ident::None => return Err(fbs::Error::RequiredFieldMissing(#snake_name_str))
651                            })
652                        }
653                    }
654                } else {
655                    let names_to_enum_variant = variants.iter().map(
656                        |ir::UnionVariant {
657                             ident: variant_ident,
658                             ty: variant_ty,
659                             ..
660                         }| {
661                            let variant_ty_wrapped = to_type_token(table_ns_ref,
662                                variant_ty,
663                                &quote!('_),
664                                &quote!(fbs::ForwardsUOffset),
665                                true,
666                            );
667                            quote! {
668                                Some(#enum_ident::#variant_ident) => self.table
669                                    .get::<#variant_ty_wrapped>(#struct_id::#offset_name, None)?
670                                    .map(#union_ident::#variant_ident)
671                            }
672                        },
673                    );
674
675                    quote! {
676                        #[inline]
677                        pub fn #snake_name(&self) -> Result<Option<#ty_ret>, fbs::Error> {
678                            Ok(match self.#type_snake_name()? {
679                              #(#names_to_enum_variant,)*
680                              None | Some(#enum_ident::None) => None
681                            })
682                        }
683                    }
684                }
685            } else if metadata.required {
686                quote! {
687                    #doc
688                    #[inline]
689                    #allow_type_complexity
690                    pub fn #snake_name(&self) -> Result<#ty_ret, fbs::Error> {
691                        Ok(self.table
692                            .get::<#ty_wrapped>(#struct_id::#offset_name, #default_value)?
693                            .ok_or_else(|| fbs::Error::RequiredFieldMissing(#snake_name_str))?)
694                    }
695                }
696            } else {
697                quote! {
698                    #doc
699                    #[inline]
700                    #allow_type_complexity
701                    pub fn #snake_name(&self) -> Result<Option<#ty_ret>, fbs::Error> {
702                        self.table
703                            .get::<#ty_wrapped>(#struct_id::#offset_name, #default_value)
704                    }
705                }
706            }
707        });
708
709        let struct_offset_enum_name = format_ident!("{}Offset", raw_struct_name);
710
711        let required_fields = fields
712            .iter()
713            .filter(|field| field.metadata.required)
714            .map(|field| {
715                let snake_name = field.ident.as_ref().to_snake_case();
716                let offset_name = offset_id(field);
717                quote! {
718                    self.fbb.required(o, #struct_id::#offset_name, #snake_name);
719                }
720            });
721
722        (quote! {
723            pub enum #struct_offset_enum_name {}
724
725            #[derive(Copy, Clone, Debug, PartialEq)]
726            #doc
727            pub struct #struct_id<B> {
728                table: fbs::Table<B>,
729            }
730
731            impl<B> From<fbs::Table<B>> for #struct_id<B> {
732                fn from(table: fbs::Table<B>) -> Self {
733                    Self { table }
734                }
735            }
736
737            impl<B> From<#struct_id<B>> for fbs::Table<B> {
738                fn from(s: #struct_id<B>) -> Self {
739                    s.table
740                }
741            }
742
743            impl<'a> #struct_id<&'a [u8]> {
744                // field offset constants
745                #(#field_offset_constants)*
746
747                pub fn create<'bldr: 'args, 'args: 'mut_bldr, 'mut_bldr>(
748                    fbb: &'mut_bldr mut fbs::FlatBufferBuilder<'bldr>,
749                    args: &'args #args#args_lifetime_args
750                ) -> fbs::WIPOffset<#struct_id<&'bldr [u8]>> {
751                    let mut builder = #builder_type::new(fbb);
752                    #(#builder_add_calls)*
753                    builder.finish()
754                }
755            }
756
757
758            impl<B> #struct_id<B>
759            where
760                B: std::convert::AsRef<[u8]>
761            {
762                // fields access
763                #(#field_accessors)*
764
765                pub fn get_root(buf: B) -> Result<Self, fbs::Error> {
766                    let table = fbs::Table::get_root(buf)?;
767                    Ok(Self { table })
768                }
769            }
770
771            impl<'a> fbs::Follow<'a> for #struct_id<&'a [u8]> {
772                type Inner = Self;
773
774                #[inline]
775                fn follow(buf: &'a [u8], loc: usize) -> Result<Self::Inner, fbs::Error> {
776                    let table = fbs::Table::new(buf, loc);
777                    Ok(Self { table })
778                }
779            }
780
781            impl<B> fbs::FollowBuf for #struct_id<B>
782            where
783                B: std::convert::AsRef<[u8]>
784            {
785                type Buf = B;
786                type Inner = Self;
787
788                #[inline]
789                fn follow_buf(buf: Self::Buf, loc: usize) -> Result<Self::Inner, fbs::Error> {
790                    let table = fbs::Table::new(buf, loc);
791                    Ok(Self { table })
792                }
793            }
794
795            // Builder Args
796            pub struct #args#args_lifetime_a {
797                #(#args_fields),*
798            }
799
800            #args_default_impl
801
802            //// builder
803            pub struct #builder_type<'a, 'b> {
804                fbb: &'b mut fbs::FlatBufferBuilder<'a>,
805                start: fbs::WIPOffset<fbs::TableUnfinishedWIPOffset>,
806            }
807
808            impl<'a: 'b, 'b> #builder_type<'a, 'b> {
809                #(#builder_field_methods)*
810
811                #[inline]
812                pub fn new(fbb: &'b mut fbs::FlatBufferBuilder<'a>) -> Self {
813                    let start = fbb.start_table();
814                    #builder_type {
815                        fbb, start
816                    }
817                }
818
819                #[inline]
820                pub fn finish(self) -> fbs::WIPOffset<#struct_id<&'a [u8]>> {
821                    let o = self.fbb.end_table(self.start);
822                    #(#required_fields)*
823                    fbs::WIPOffset::new(o.value())
824                }
825            }
826        })
827        .to_tokens(tokens)
828    }
829}
830
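// Intentionally NOT implemented: an `ir::Type` cannot be turned into tokens without
// context (namespace, lifetime, wrapper type) -- use `to_type_token` instead.
// The commented-out impl below is kept in the code only to discourage a rogue impl.
// impl ToTokens for ir::Type<'_> {
//     fn to_tokens(&self, _: &mut TokenStream) {
//         panic!("This is unimplemented on purpose -- as types need context to be generated")
//     }
// }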
838
839impl ToTokens for ir::EnumBaseType {
840    fn to_tokens(&self, tokens: &mut TokenStream) {
841        match self {
842            ir::EnumBaseType::Byte => quote!(i8),
843            ir::EnumBaseType::UByte => quote!(u8),
844            ir::EnumBaseType::Short => quote!(i16),
845            ir::EnumBaseType::UShort => quote!(u16),
846            ir::EnumBaseType::Int => quote!(i32),
847            ir::EnumBaseType::UInt => quote!(u32),
848            ir::EnumBaseType::Long => quote!(i64),
849            ir::EnumBaseType::ULong => quote!(u64),
850            ir::EnumBaseType::Int8 => quote!(i8),
851            ir::EnumBaseType::UInt8 => quote!(u8),
852            ir::EnumBaseType::Int16 => quote!(i16),
853            ir::EnumBaseType::UInt16 => quote!(u16),
854            ir::EnumBaseType::Int32 => quote!(i32),
855            ir::EnumBaseType::UInt32 => quote!(u32),
856            ir::EnumBaseType::Int64 => quote!(i64),
857            ir::EnumBaseType::UInt64 => quote!(u64),
858        }
859        .to_tokens(tokens)
860    }
861}
862
863impl ToTokens for ast::Comment<'_> {
864    fn to_tokens(&self, tokens: &mut TokenStream) {
865        let doc = self.lines.iter().rev().fold(quote!(), |docs, line| {
866            quote! {
867                #[doc = #line]
868                #docs
869            }
870        });
871        doc.to_tokens(tokens)
872    }
873}
874
875impl ToTokens for ir::QualifiedIdent<'_> {
876    fn to_tokens(&self, tokens: &mut TokenStream) {
877        let parts = &self.parts;
878        debug_assert!(!self.parts.is_empty());
879        let code = parts.iter().map(|e| e.raw.as_ref()).join("::");
880        syn::parse_str::<syn::Path>(&code)
881            .expect("Cannot parse path")
882            .to_tokens(tokens)
883    }
884}
885
886impl ToTokens for ir::Enum<'_> {
887    fn to_tokens(&self, tokens: &mut TokenStream) {
888        let Self {
889            ident: enum_id,
890            variants,
891            base_type,
892            doc,
893            ..
894        } = self;
895
896        let enum_id = enum_id.simple();
897        // generate enum variant name => string name of the variant for use in
898        // a match statement
899        let names_to_strings = variants.iter().map(|ir::EnumVariant { ident: key, .. }| {
900            let raw_key = key.raw.as_ref();
901            quote! {
902                #enum_id::#key => #raw_key
903            }
904        });
905
906        let default_value = variants
907            .iter()
908            .map(|ir::EnumVariant { ident: key, .. }| {
909                quote! {
910                    #enum_id::#key
911                }
912            })
913            .next();
914
915        // assign a value to the key if one was given, otherwise give it the
916        // enumerated index's value
917        let variants_and_scalars = variants.iter().enumerate().map(
918            |(
919                index,
920                ir::EnumVariant {
921                    ident: key,
922                    value,
923                    doc,
924                },
925            )| {
926                // format the value with the correct type, i.e., base_type
927                let scalar_value = if let Some(constant) = *value {
928                    constant
929                } else {
930                    index
931                        .try_into()
932                        .expect("invalid conversion to enum base type")
933                };
934                let formatted_value = scalar_value.fmt_lit(base_type.to_token_stream());
935                (quote!(#key), quote!(#formatted_value), doc)
936            },
937        );
938
939        let raw_snake_enum_name = enum_id.raw.as_ref().to_snake_case();
940        let enum_id_fn_name = format_ident!("enum_name_{}", raw_snake_enum_name);
941
942        let from_base_to_enum_variants =
943            variants_and_scalars
944                .clone()
945                .map(|(variant_name, scalar_value, _)| {
946                    quote! {
947                        #scalar_value => Ok(<#enum_id>::#variant_name)
948                    }
949                });
950
951        let from_enum_variant_to_base =
952            variants_and_scalars
953                .clone()
954                .map(|(variant_name, scalar_value, _)| {
955                    quote! {
956                        <#enum_id>::#variant_name => #scalar_value
957                    }
958                });
959
960        let fields = variants_and_scalars.map(|(variant_name, scalar_value, doc)| {
961            quote! {
962                #doc
963                #variant_name = #scalar_value
964            }
965        });
966
967        // TODO: Maybe separate these pieces to avoid variables that used far
968        // away from their definition.
969        (quote! {
970            // force a C-style enum
971            #[repr(#base_type)]
972            #[allow(non_camel_case_types)]
973            #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
974            #doc
975            pub enum #enum_id {
976                #(#fields),*
977            }
978
979            impl Default for #enum_id {
980                fn default() -> Self {
981                    #default_value
982                }
983            }
984
985            impl<'a> fbs::Follow<'a> for #enum_id {
986                type Inner = Self;
987
988                fn follow(buf: &'a [u8], loc: usize) -> Result<Self::Inner, fbs::Error> {
989                    let scalar = fbs::read_scalar_at::<#base_type>(buf, loc)?;
990                    <Self as std::convert::TryFrom<#base_type>>::try_from(scalar)
991                }
992            }
993
994            impl std::convert::TryFrom<#base_type> for #enum_id {
995                type Error = fbs::Error;
996                fn try_from(value: #base_type) -> Result<Self, Self::Error> {
997                    match value {
998                        #(#from_base_to_enum_variants,)*
999                        _ => Err(fbs::Error::UnknownEnumVariant)
1000                    }
1001                }
1002            }
1003
1004            impl From<#enum_id> for #base_type {
1005                fn from(value: #enum_id) -> #base_type {
1006                    match value {
1007                        #(#from_enum_variant_to_base),*
1008                    }
1009                }
1010            }
1011
1012            impl fbs::Push for #enum_id {
1013                type Output = Self;
1014
1015                #[inline]
1016                fn push(&self, dst: &mut [u8], _rest: &[u8]) {
1017                    let scalar = <#base_type>::from(*self);
1018                    fbs::emplace_scalar::<#base_type>(dst, scalar);
1019                }
1020            }
1021
1022            pub fn #enum_id_fn_name(e: #enum_id) -> &'static str {
1023                match e {
1024                    #(#names_to_strings),*
1025                }
1026            }
1027        })
1028        .to_tokens(tokens)
1029    }
1030}
1031
1032impl ToTokens for ir::Union<'_> {
1033    fn to_tokens(&self, tokens: &mut TokenStream) {
1034        let Self {
1035            ident: union_qualified_id,
1036            enum_ident,
1037            variants,
1038            doc,
1039            ..
1040        } = self;
1041
1042        let union_ns = union_qualified_id.namespace();
1043        let union_ns_ref = union_ns.as_ref();
1044        let union_id = union_qualified_id.simple();
1045        let enum_id = enum_ident.relative(union_ns_ref);
1046
1047        // the union's body definition
1048        let names_to_union_variant = variants.iter().map(
1049            |ir::UnionVariant {
1050                 ident: variant_ident,
1051                 ty: variant_ty,
1052                 doc,
1053             }| {
1054                let variant_ty_token =
1055                    to_type_token(union_ns_ref, variant_ty, &quote!('a), &quote!(), false);
1056                quote! {
1057                    #doc
1058                    #variant_ident(#variant_ty_token)
1059                }
1060            },
1061        );
1062
1063        // generate union variant name => enum type name of the variant for use in
1064        // a match statement
1065        let names_to_enum_variant = variants.iter().map(
1066            |ir::UnionVariant {
1067                 ident: variant_ident,
1068                 ..
1069             }| {
1070                quote! {
1071                    #union_id::#variant_ident(..) => #enum_id::#variant_ident
1072                }
1073            },
1074        );
1075
1076        (quote! {
1077            #[derive(Copy, Clone, Debug, PartialEq)]
1078            #doc
1079            pub enum #union_id<'a> {
1080                #(#names_to_union_variant),*
1081            }
1082
1083
1084            impl #union_id<'_> {
1085                pub fn get_type(&self) -> #enum_id {
1086                    match self {
1087                        #(#names_to_enum_variant),*
1088                    }
1089                }
1090            }
1091
1092        })
1093        .to_tokens(tokens)
1094    }
1095}
1096
/// Turns the IR of a parsed schema into Rust tokens.
pub struct CodeGenerator<'a> {
    /// Root of the IR; its nodes are visited in order during generation.
    pub(crate) root: ir::Root<'a>,
    /// Optional generator invoked for RPC nodes. When this is `None`, RPC nodes
    /// produce no code.
    pub(crate) rpc_gen: Option<Box<dyn RpcGenerator>>,
}
1101
1102impl<'a> CodeGenerator<'a> {
1103    pub fn build_token_stream(&mut self) -> TokenStream {
1104        let mut token_stream = TokenStream::default();
1105        self.build_tokens(&mut token_stream);
1106        token_stream
1107    }
1108
1109    pub fn build_tokens(&mut self, tokens: &mut TokenStream) {
1110        let mut rpc_gen = self.rpc_gen.take();
1111
1112        for node in &self.root.nodes {
1113            self.node_to_tokens(node, &mut rpc_gen, tokens);
1114        }
1115    }
1116
1117    fn node_to_tokens(
1118        &self,
1119        node: &ir::Node<'a>,
1120        rpc_gen: &mut Option<Box<dyn RpcGenerator>>,
1121        tokens: &mut TokenStream,
1122    ) {
1123        // The following constructs are (or should be) handled at the file level:
1124        // * Namespaces
1125        // * Root types
1126        // * File extensions
1127        // * File identifiers
1128        //
1129        // Additionally, attributes do not have corresponding concrete code generated, they are
1130        // used to *affect* codegen of other items.
1131        match node {
1132            ir::Node::Table(t) => t.to_tokens(tokens),
1133            // ir::Node::Struct(_) => unimplemented!(),
1134            ir::Node::Enum(e) => e.to_tokens(tokens),
1135            ir::Node::Union(u) => u.to_tokens(tokens),
1136            ir::Node::Namespace(n) => {
1137                let ident = format_ident!("{}", n.ident.simple().as_ref().to_snake_case());
1138                let mut nodes_ts = TokenStream::default();
1139                for node in &n.nodes {
1140                    self.node_to_tokens(node, rpc_gen, &mut nodes_ts);
1141                }
1142                (quote! {
1143                    pub mod #ident {
1144                        #nodes_ts
1145                    }
1146                })
1147                .to_tokens(tokens)
1148            }
1149            ir::Node::Rpc(rpc) => {
1150                if let Some(gen) = rpc_gen {
1151                    gen.generate(rpc, tokens)
1152                }
1153            }
1154            element => panic!("{:?}", element),
1155        }
1156    }
1157}
1158
/// Pluggable generator for RPC service definitions.
///
/// The code generator calls `generate` for each RPC node it encounters when a
/// generator has been configured.
pub trait RpcGenerator {
    /// Generates a Rust interface or implementation for a service, writing the
    /// result to the provided `token_stream`.
    fn generate<'a>(&mut self, rpc: &ir::Rpc<'a>, token_stream: &mut TokenStream);
}
1164
#[cfg(test)]
mod table_tests {
    use super::*;
    use crate::{
        ir::{types::Node, Builder},
        parser::schema_decl,
    };

    /// A table with two `required` fields must mention "required" exactly
    /// twice in its generated code.
    #[test]
    fn test_required_fields() {
        let input = "\
table Hello {
  world: string (required);
  earth: int = 616 (required);
  universe: string;
}";
        let (_, schema) = schema_decl(input).unwrap();
        let actual = Builder::build(schema).unwrap();

        // The schema declares a single table, so it must be the first node.
        let table = match &actual.nodes[0] {
            Node::Table(table) => table,
            node => panic!("{:?}", node),
        };

        let result = to_code(table);
        assert_eq!(2, result.matches("required").count());
    }
}