burn-derive 0.21.0

Derive crate for the Burn framework
Documentation
use crate::shared::enum_variant::map_enum_variant;

use super::ConfigAnalyzer;
use proc_macro2::{Ident, TokenStream};
use quote::quote;

/// Generates the trait implementations for an enum handled through
/// [`ConfigAnalyzer`].
///
/// The analyzer only stores what it needs to emit code: the enum's identifier
/// and its parsed definition.
pub struct ConfigEnumAnalyzer {
    /// Identifier of the enum the implementations are generated for.
    name: Ident,
    /// Parsed enum definition; its variants drive every generated `match`.
    data: syn::DataEnum,
}

impl ConfigEnumAnalyzer {
    /// Creates an analyzer for the enum called `name` whose parsed body is `data`.
    pub fn new(name: Ident, data: syn::DataEnum) -> Self {
        Self { name, data }
    }

    /// Returns the identifier of the mirror enum, `<Name>Serde`, that carries
    /// the derived serde implementations.
    ///
    /// The identifier reuses the span of the original enum name.
    fn serde_enum_ident(&self) -> Ident {
        Ident::new(&format!("{}Serde", self.name), self.name.span())
    }

    /// Emits the mirror enum definition.
    ///
    /// The mirror enum has exactly the variants of the source enum and derives
    /// `Serialize`/`Deserialize` through the serde re-export at `burn::serde`.
    fn gen_serde_enum(&self) -> TokenStream {
        let enum_name = self.serde_enum_ident();
        let data = &self.data.variants;

        quote! {
            #[derive(burn::serde::Serialize, burn::serde::Deserialize)]
            #[serde(crate = "burn::serde")]
            enum #enum_name {
                #data
            }

        }
    }

    /// Emits `impl burn::serde::Serialize` for the source enum.
    ///
    /// The generated code converts `&self` into the mirror enum variant by
    /// variant and delegates serialization to it.
    fn gen_serialize_fn(&self) -> TokenStream {
        let enum_name = self.serde_enum_ident();
        let variants = self.data.variants.iter().map(|variant| {
            let variant_name = &variant.ident;
            // `self` is matched by reference in the generated code, so every
            // bound field has to be cloned to build an owned mirror value.
            let (inputs, outputs) = map_enum_variant(variant, |ident| quote! { #ident.clone() });

            quote! { Self::#variant_name #inputs => #enum_name::#variant_name #outputs }
        });

        let name = &self.name;

        quote! {
            impl burn::serde::Serialize for #name {
                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
                where
                    S: burn::serde::Serializer {
                    let serde_state = match self {
                        #(#variants),*
                    };
                    serde_state.serialize(serializer)
                }
            }

        }
    }

    /// Emits `impl burn::serde::Deserialize` for the source enum.
    ///
    /// The generated code deserializes the mirror enum first and then converts
    /// it into the source enum variant by variant.
    fn gen_deserialize_fn(&self) -> TokenStream {
        let enum_name = self.serde_enum_ident();
        let variants = self.data.variants.iter().map(|variant| {
            let variant_name = &variant.ident;
            // `serde_state` is matched by value in the generated code, so the
            // bound fields are already owned and can be moved straight into
            // `Self`; cloning them would only duplicate data that is dropped
            // immediately afterwards.
            // NOTE(review): relies on `map_enum_variant` emitting plain
            // (non-`ref`) bindings on the pattern side — confirm.
            let (inputs, outputs) = map_enum_variant(variant, |ident| quote! { #ident });

            quote! { #enum_name::#variant_name #inputs => Self::#variant_name #outputs }
        });
        let name = &self.name;

        quote! {
            impl<'de> burn::serde::Deserialize<'de> for #name {
                fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
                where
                    D: burn::serde::Deserializer<'de> {
                    let serde_state = #enum_name::deserialize(deserializer)?;
                    Ok(match serde_state {
                        #(#variants),*
                    })
                }
            }

        }
    }
}

impl ConfigAnalyzer for ConfigEnumAnalyzer {
    /// Concatenates the mirror enum definition with the `Serialize` and
    /// `Deserialize` implementations, in that order.
    fn gen_serde_impl(&self) -> TokenStream {
        let mut tokens = self.gen_serde_enum();
        tokens.extend(self.gen_serialize_fn());
        tokens.extend(self.gen_deserialize_fn());
        tokens
    }

    /// Emits a `Clone` implementation that rebuilds each variant from clones
    /// of its fields.
    fn gen_clone_impl(&self) -> TokenStream {
        let target = &self.name;

        // One `pattern => constructor` arm per variant of the enum.
        let arms: Vec<TokenStream> = self
            .data
            .variants
            .iter()
            .map(|variant| {
                let ident = &variant.ident;
                let (pattern, rebuilt) =
                    map_enum_variant(variant, |field| quote! { #field.clone() });

                quote! { Self::#ident #pattern => Self::#ident #rebuilt }
            })
            .collect();

        quote! {
            impl Clone for #target {
                fn clone(&self) -> Self {
                    match self {
                        #(#arms),*
                    }
                }
            }

        }
    }

    /// Emits a `Display` implementation that writes the output of
    /// `burn::config::config_to_json` for the value.
    fn gen_display_impl(&self) -> TokenStream {
        let target = &self.name;

        quote! {
            impl core::fmt::Display for #target {
                fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                    f.write_str(&burn::config::config_to_json(self))
                }
            }
        }
    }

    /// Emits an empty `burn::config::Config` implementation for the enum.
    fn gen_config_impl(&self) -> TokenStream {
        let target = &self.name;

        quote! {
            impl burn::config::Config for #target {}
        }
    }
}