batbox_diff_derive/
lib.rs1#![recursion_limit = "128"]
2#![allow(unused_imports)]
3
4extern crate proc_macro;
5
6use darling::{FromDeriveInput, FromField};
7use proc_macro2::{Span, TokenStream};
8use quote::quote;
9
10#[proc_macro_derive(Diff, attributes(diff))]
11pub fn derive_diff(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
12 derive_impl(input.into()).into()
13}
14
/// How a single struct field is represented in the generated delta type.
///
/// Selected per field with `#[diff = "diff" | "clone" | "eq"]`; fields without
/// the attribute use [`DiffMode::Diff`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DiffMode {
    /// Delegate to the field type's own `Diff` implementation; the delta
    /// stores `<FieldTy as Diff>::Delta`.
    Diff,
    /// The delta always stores a clone of the new field value.
    Clone,
    /// The delta stores `Some(new_value)` only when the field compares
    /// unequal, and `None` otherwise.
    Eq,
}
21
22fn derive_impl(input: TokenStream) -> TokenStream {
23 let s = input.to_string();
24 let ast: syn::DeriveInput = syn::parse_str(&s).unwrap();
25 let input_type = &ast.ident;
26 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
27 let generics = &ast.generics;
28 match ast.data {
29 syn::Data::Struct(syn::DataStruct { ref fields, .. }) => {
30 let field_tys: Vec<_> = fields.iter().map(|field| &field.ty).collect();
31 let field_tys = &field_tys;
32 let field_names: Vec<_> = fields
33 .iter()
34 .map(|field| field.ident.as_ref().unwrap())
35 .collect();
36 let field_names = &field_names;
37 let delta_type = syn::Ident::new(
38 &format!("{input_type}Delta"),
39 proc_macro2::Span::call_site(),
40 );
41
42 let field_diff_modes: Vec<DiffMode> = fields
43 .iter()
44 .map(|field| {
45 let mut diff_type = DiffMode::Diff;
46 for attr in &field.attrs {
47 if let syn::Meta::NameValue(syn::MetaNameValue {
48 path: ref meta_path,
49 value:
50 syn::Expr::Lit(syn::ExprLit {
51 lit: syn::Lit::Str(ref s),
52 ..
53 }),
54 ..
55 }) = attr.meta
56 {
57 if meta_path.is_ident("diff") {
58 diff_type = match s.value().as_str() {
59 "eq" => DiffMode::Eq,
60 "diff" => DiffMode::Diff,
61 "clone" => DiffMode::Clone,
62 _ => panic!("Unexpected diff type"),
63 }
64 }
65 }
66 }
67 diff_type
68 })
69 .collect();
70
71 let field_diff_types =
72 field_tys
73 .iter()
74 .zip(field_diff_modes.iter())
75 .map(|(field_ty, field_diff_mode)| match field_diff_mode {
76 DiffMode::Diff => quote! {
77 <#field_ty as Diff>::Delta
78 },
79 DiffMode::Clone => quote! {
80 #field_ty
81 },
82 DiffMode::Eq => quote! {
83 Option<#field_ty>
84 },
85 });
86
87 let field_diffs = field_names.iter().zip(field_diff_modes.iter()).map(
88 |(field_name, field_diff_mode)| match field_diff_mode {
89 DiffMode::Diff => quote! {
90 Diff::diff(&self.#field_name, &to.#field_name)
91 },
92 DiffMode::Clone => quote! {
93 to.#field_name.clone()
94 },
95 DiffMode::Eq => quote! {
96 if self.#field_name == to.#field_name {
97 None
98 } else {
99 Some(to.#field_name.clone())
100 }
101 },
102 },
103 );
104
105 let field_updates = field_names.iter().zip(field_diff_modes.iter()).map(
106 |(field_name, field_diff_mode)| match field_diff_mode {
107 DiffMode::Diff => quote! {
108 Diff::update(&mut self.#field_name, &delta.#field_name);
109 },
110 DiffMode::Clone => quote! {
111 self.#field_name = delta.#field_name.clone();
112 },
113 DiffMode::Eq => quote! {
114 if let Some(value) = &delta.#field_name {
115 self.#field_name = value.clone();
116 }
117 },
118 },
119 );
120
121 let expanded = quote! {
122 #[derive(Debug, Serialize, Deserialize, Clone)]
123 pub struct #delta_type #generics {
124 #(#field_names: #field_diff_types,)*
125 }
126
127 impl #impl_generics Diff for #input_type #ty_generics #where_clause {
128 type Delta = #delta_type;
129 fn diff(&self, to: &Self) -> Self::Delta {
130 #delta_type {
131 #(#field_names: #field_diffs,)*
132 }
133 }
134 fn update(&mut self, delta: &Self::Delta) {
135 #(#field_updates)*
136 }
137 }
138 };
139 expanded
140 }
141 _ => panic!("Diff can only be derived by structs"),
142 }
143}