batbox_diff_derive/
lib.rs

1#![recursion_limit = "128"]
2#![allow(unused_imports)]
3
4extern crate proc_macro;
5
6use darling::{FromDeriveInput, FromField};
7use proc_macro2::{Span, TokenStream};
8use quote::quote;
9
10#[proc_macro_derive(Diff, attributes(diff))]
11pub fn derive_diff(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
12    derive_impl(input.into()).into()
13}
14
// TODO: simplify by actually using darling
/// How a single struct field is represented in the generated delta type.
///
/// Selected per field with `#[diff = "..."]`; fields without the attribute
/// use [`DiffMode::Diff`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DiffMode {
    /// `#[diff = "diff"]` (default): the delta field is the field type's own
    /// `Diff::Delta`, produced by `Diff::diff` and applied with `Diff::update`.
    Diff,
    /// `#[diff = "clone"]`: the delta stores a clone of the target value, and
    /// `update` unconditionally assigns it back.
    Clone,
    /// `#[diff = "eq"]`: the delta stores `Option<T>`, which is `Some(clone)`
    /// only when the old and new values differ by `==`, and `update` assigns
    /// it only when present.
    Eq,
}
21
22fn derive_impl(input: TokenStream) -> TokenStream {
23    let s = input.to_string();
24    let ast: syn::DeriveInput = syn::parse_str(&s).unwrap();
25    let input_type = &ast.ident;
26    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
27    let generics = &ast.generics;
28    match ast.data {
29        syn::Data::Struct(syn::DataStruct { ref fields, .. }) => {
30            let field_tys: Vec<_> = fields.iter().map(|field| &field.ty).collect();
31            let field_tys = &field_tys;
32            let field_names: Vec<_> = fields
33                .iter()
34                .map(|field| field.ident.as_ref().unwrap())
35                .collect();
36            let field_names = &field_names;
37            let delta_type = syn::Ident::new(
38                &format!("{input_type}Delta"),
39                proc_macro2::Span::call_site(),
40            );
41
42            let field_diff_modes: Vec<DiffMode> = fields
43                .iter()
44                .map(|field| {
45                    let mut diff_type = DiffMode::Diff;
46                    for attr in &field.attrs {
47                        if let syn::Meta::NameValue(syn::MetaNameValue {
48                            path: ref meta_path,
49                            value:
50                                syn::Expr::Lit(syn::ExprLit {
51                                    lit: syn::Lit::Str(ref s),
52                                    ..
53                                }),
54                            ..
55                        }) = attr.meta
56                        {
57                            if meta_path.is_ident("diff") {
58                                diff_type = match s.value().as_str() {
59                                    "eq" => DiffMode::Eq,
60                                    "diff" => DiffMode::Diff,
61                                    "clone" => DiffMode::Clone,
62                                    _ => panic!("Unexpected diff type"),
63                                }
64                            }
65                        }
66                    }
67                    diff_type
68                })
69                .collect();
70
71            let field_diff_types =
72                field_tys
73                    .iter()
74                    .zip(field_diff_modes.iter())
75                    .map(|(field_ty, field_diff_mode)| match field_diff_mode {
76                        DiffMode::Diff => quote! {
77                            <#field_ty as Diff>::Delta
78                        },
79                        DiffMode::Clone => quote! {
80                            #field_ty
81                        },
82                        DiffMode::Eq => quote! {
83                            Option<#field_ty>
84                        },
85                    });
86
87            let field_diffs = field_names.iter().zip(field_diff_modes.iter()).map(
88                |(field_name, field_diff_mode)| match field_diff_mode {
89                    DiffMode::Diff => quote! {
90                        Diff::diff(&self.#field_name, &to.#field_name)
91                    },
92                    DiffMode::Clone => quote! {
93                        to.#field_name.clone()
94                    },
95                    DiffMode::Eq => quote! {
96                        if self.#field_name == to.#field_name {
97                            None
98                        } else {
99                            Some(to.#field_name.clone())
100                        }
101                    },
102                },
103            );
104
105            let field_updates = field_names.iter().zip(field_diff_modes.iter()).map(
106                |(field_name, field_diff_mode)| match field_diff_mode {
107                    DiffMode::Diff => quote! {
108                        Diff::update(&mut self.#field_name, &delta.#field_name);
109                    },
110                    DiffMode::Clone => quote! {
111                        self.#field_name = delta.#field_name.clone();
112                    },
113                    DiffMode::Eq => quote! {
114                        if let Some(value) = &delta.#field_name {
115                            self.#field_name = value.clone();
116                        }
117                    },
118                },
119            );
120
121            let expanded = quote! {
122                #[derive(Debug, Serialize, Deserialize, Clone)]
123                pub struct #delta_type #generics {
124                    #(#field_names: #field_diff_types,)*
125                }
126
127                impl #impl_generics Diff for #input_type #ty_generics #where_clause {
128                    type Delta = #delta_type;
129                    fn diff(&self, to: &Self) -> Self::Delta {
130                        #delta_type {
131                            #(#field_names: #field_diffs,)*
132                        }
133                    }
134                    fn update(&mut self, delta: &Self::Delta) {
135                        #(#field_updates)*
136                    }
137                }
138            };
139            expanded
140        }
141        _ => panic!("Diff can only be derived by structs"),
142    }
143}