Skip to main content

burn_derive/config/
analyzer.rs

1use super::ConfigEnumAnalyzer;
2use crate::config::ConfigStructAnalyzer;
3use crate::shared::{attribute::AttributeItem, field::FieldTypeAnalyzer};
4use proc_macro2::TokenStream;
5use quote::quote;
6use syn::{Field, Ident};
7
/// Stateless factory that picks the [`ConfigAnalyzer`] implementation for a derive input.
///
/// The factory carries no data, so `Default` is derived alongside the existing
/// `ConfigAnalyzerFactory::new()` constructor; both produce the same value.
#[derive(Debug, Default)]
pub struct ConfigAnalyzerFactory {}
9
10pub trait ConfigAnalyzer {
11    fn gen_new_fn(&self) -> TokenStream {
12        quote! {}
13    }
14    fn gen_builder_fns(&self) -> TokenStream {
15        quote! {}
16    }
17    fn gen_serde_impl(&self) -> TokenStream;
18    fn gen_clone_impl(&self) -> TokenStream;
19    fn gen_display_impl(&self) -> TokenStream;
20    fn gen_config_impl(&self) -> TokenStream;
21}
22
23impl ConfigAnalyzerFactory {
24    pub fn new() -> Self {
25        Self {}
26    }
27
28    pub fn create_analyzer(&self, item: &syn::DeriveInput) -> Box<dyn ConfigAnalyzer> {
29        let name = item.ident.clone();
30        let config_type = parse_asm(item);
31
32        match config_type {
33            ConfigType::Struct(data) => Box::new(self.create_struct_analyzer(name, data)),
34            ConfigType::Enum(data) => Box::new(self.create_enum_analyzer(name, data)),
35        }
36    }
37
38    fn create_struct_analyzer(&self, name: Ident, fields: Vec<Field>) -> ConfigStructAnalyzer {
39        let fields = fields.into_iter().map(FieldTypeAnalyzer::new);
40
41        let mut fields_required = Vec::new();
42        let mut fields_option = Vec::new();
43        let mut fields_default = Vec::new();
44
45        for field in fields {
46            let attributes: Vec<AttributeItem> = field
47                .attributes()
48                .filter(|attr| attr.has_name("config"))
49                .map(|attr| attr.item())
50                .collect();
51
52            if !attributes.is_empty() {
53                let item = attributes.first().unwrap().clone();
54                fields_default.push((field.clone(), item));
55                continue;
56            }
57
58            if field.is_of_type(&["Option"]) {
59                fields_option.push(field.clone());
60                continue;
61            }
62
63            fields_required.push(field.clone());
64        }
65
66        ConfigStructAnalyzer::new(name, fields_required, fields_option, fields_default)
67    }
68
69    fn create_enum_analyzer(&self, name: Ident, data: syn::DataEnum) -> ConfigEnumAnalyzer {
70        ConfigEnumAnalyzer::new(name, data)
71    }
72}
73
74enum ConfigType {
75    Struct(Vec<Field>),
76    Enum(syn::DataEnum),
77}
78
79fn parse_asm(ast: &syn::DeriveInput) -> ConfigType {
80    match &ast.data {
81        syn::Data::Struct(struct_data) => {
82            ConfigType::Struct(struct_data.fields.clone().into_iter().collect())
83        }
84        syn::Data::Enum(enum_data) => ConfigType::Enum(enum_data.clone()),
85        syn::Data::Union(_) => panic!("Only struct and enum can be derived"),
86    }
87}