pro_serde_versioned_derive/
lib.rs

1// ┌───────────────────────────────────────────────────────────────────────────┐
2// │                                                                           │
3// │  ██████╗ ██████╗  ██████╗   Copyright (C) The Prospective Company         │
4// │  ██╔══██╗██╔══██╗██╔═══██╗  All Rights Reserved - April 2022              │
5// │  ██████╔╝██████╔╝██║   ██║                                                │
6// │  ██╔═══╝ ██╔══██╗██║   ██║  Proprietary and confidential. Unauthorized    │
7// │  ██║     ██║  ██║╚██████╔╝  copying of this file, via any medium is       │
8// │  ╚═╝     ╚═╝  ╚═╝ ╚═════╝   strictly prohibited.                          │
9// │                                                                           │
10// └───────────────────────────────────────────────────────────────────────────┘
11use std::collections::HashMap;
12
13use proc_macro::TokenStream;
14use quote::quote;
15use syn::{Data, DeriveInput, Fields};
16
17#[derive(Debug)]
18struct VersionVariant {
19    version_number: usize,
20    variant_ident: syn::Ident,
21    variant_ty: syn::Type,
22    latest: bool,
23}
24
25#[proc_macro_derive(VersionedSerialize)]
26pub fn versioned_serialize(input: TokenStream) -> TokenStream {
27    let ast: DeriveInput = syn::parse(input).unwrap();
28    let name = &ast.ident;
29    let version_variants = get_version_variants(&ast);
30    let variant_names: Vec<_> = version_variants
31        .values()
32        .map(|version_variant| &version_variant.variant_ident)
33        .cloned()
34        .collect();
35
36    let variant_tys: Vec<_> = version_variants
37        .values()
38        .map(|version_variant| version_variant.variant_ty.clone())
39        .collect();
40
41    let variant_versions: Vec<_> = version_variants
42        .values()
43        .map(|version_variant| version_variant.version_number)
44        .collect();
45
46    let generics = ast.generics;
47    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
48    quote! {
49        #(
50            impl #impl_generics From<#variant_tys> for #name #ty_generics #where_clause {
51                fn from(value: #variant_tys) -> #name {
52                    #name::#variant_names(value)
53                }
54            }
55
56            // // TODO This would be handy, but will create conflicts when
57            // // multiple enums use `#variant_tys`, as we do in the tests.
58
59            // impl VersionedSerialize for #variant_tys {
60            //     fn versioned_serialize<F: SerializeFormat>(&self) -> Result<F, Box<dyn std::error::Error>> {
61            //         #[derive(Serialize, Deserialize)]
62            //         pub struct VersionedEnvelope<T> {
63            //             pub version_number: usize,
64            //             pub data: T,
65            //         }
66
67            //         let envelope = VersionedEnvelope {
68            //             version_number: #variant_versions,
69            //             data: F::serialize_format(&self)?
70            //         };
71
72            //         F::serialize_format(envelope)
73            //     }
74            // }
75        )*
76
77        impl #impl_generics ::pro_serde_versioned::VersionedSerialize for #name #ty_generics #where_clause {
78            type VersionedEnvelope<A: Serialize> = ::pro_serde_versioned::VersionedEnvelope<A>;
79            fn to_envelope<F: ::pro_serde_versioned::SerializeFormat>(&self) -> Result<Self::VersionedEnvelope<F>, F::Error> {
80                match self {
81                    #(
82                        #name::#variant_names(value) => {
83                            Ok(::pro_serde_versioned::VersionedEnvelope {
84                                version_number: #variant_versions,
85                                data: F::serialize_format(&value)?
86                            })
87                        }
88                    )*
89                }
90            }
91        }
92    }
93    .into()
94}
95
96#[proc_macro_derive(VersionedDeserialize)]
97pub fn versioned_deserialize(input: TokenStream) -> TokenStream {
98    let ast: DeriveInput = syn::parse(input).unwrap();
99    let name = &ast.ident;
100
101    let version_variants = get_version_variants(&ast);
102    let variant_names: Vec<_> = version_variants
103        .values()
104        .map(|version_variant| &version_variant.variant_ident)
105        .cloned()
106        .collect();
107
108    let variant_versions: Vec<_> = version_variants
109        .values()
110        .map(|version_variant| version_variant.version_number)
111        .collect();
112
113    let generics = ast.generics;
114    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
115
116    quote! {
117        impl #impl_generics ::pro_serde_versioned::VersionedDeserialize for #name #ty_generics #where_clause {
118            type VersionedEnvelope<'a, F: Deserialize<'a>> = ::pro_serde_versioned::VersionedEnvelope<F>;
119            fn from_envelope<'a, F: ::pro_serde_versioned::DeserializeFormat + Deserialize<'a>>(
120                envelope: &::pro_serde_versioned::VersionedEnvelope<F>,
121            ) -> Result<Self, F::Error> {
122                match envelope.version_number {
123                    #(
124                        #variant_versions => Ok(#name::#variant_names(
125                            <F as ::pro_serde_versioned::DeserializeFormat>::deserialize_format(
126                                &envelope.data
127                            )
128                        ?)),
129                    )*
130                    _ => Err(serde::de::Error::custom("Unknown version number")),
131                }
132            }
133        }
134    }
135    .into()
136}
137
138#[proc_macro_derive(VersionedUpgrade)]
139pub fn upgradable_enum(input: TokenStream) -> TokenStream {
140    let ast: DeriveInput = syn::parse(input).unwrap();
141    let name = &ast.ident;
142    let mut latest_variant_ty = None;
143    let version_variants = get_version_variants(&ast);
144
145    for variant in version_variants.values() {
146        if variant.latest {
147            latest_variant_ty = Some(variant.variant_ty.clone());
148            break;
149        }
150    }
151
152    let latest_variant_ty = match latest_variant_ty {
153        Some(latest_variant_ty) => latest_variant_ty,
154        None => panic!("No latest variant found"),
155    };
156
157    let upgrade_match_arms = generate_upgrade_match_arms(&ast, version_variants);
158
159    let generics = ast.generics;
160    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
161
162    let gen = quote! {
163        impl #impl_generics ::pro_serde_versioned::VersionedUpgrade for #name #ty_generics #where_clause {
164            type Latest = #latest_variant_ty;
165            fn upgrade_to_latest(self) -> Self::Latest {
166                match self {
167                    #(#upgrade_match_arms)*
168                }
169            }
170        }
171    };
172
173    gen.into()
174}
175
176fn generate_upgrade_match_arms(
177    ast: &DeriveInput,
178    version_variants: HashMap<usize, VersionVariant>,
179) -> Vec<proc_macro2::TokenStream> {
180    let name = &ast.ident;
181    let mut match_arms = Vec::new();
182
183    for version_variant in version_variants.values() {
184        let version_number = version_variant.version_number;
185        let variant_ident = &version_variant.variant_ident;
186
187        if !version_variant.latest {
188            let next_variant = version_variants
189                .get(&(version_number + 1))
190                .expect("No variant for next version");
191
192            let next_variant_ident = &next_variant.variant_ident;
193            let next_variant_ty = &next_variant.variant_ty;
194            match_arms.push(quote! {
195                #name::#variant_ident(value) => {
196                    let upgraded: #next_variant_ty = value.upgrade();
197                    #name::#next_variant_ident(upgraded).upgrade_to_latest()
198                },
199            });
200        } else {
201            match_arms.push(quote! {
202                #name::#variant_ident(value) => value,
203            });
204        }
205    }
206
207    match_arms
208}
209
210fn get_version_variants(ast: &DeriveInput) -> HashMap<usize, VersionVariant> {
211    let mut version_variants: HashMap<usize, VersionVariant> = HashMap::new();
212
213    let mut max = 0;
214
215    if let Data::Enum(data_enum) = &ast.data {
216        for variant in &data_enum.variants {
217            let version_number = variant
218                .ident
219                .to_string()
220                .replace("V", "")
221                .parse::<usize>()
222                .expect("Invalid version number");
223            max = std::cmp::max(max, version_number);
224            version_variants.insert(
225                variant
226                    .ident
227                    .to_string()
228                    .replace("V", "")
229                    .parse::<usize>()
230                    .expect("Invalid version number"),
231                VersionVariant {
232                    version_number,
233                    variant_ident: variant.ident.clone(),
234                    variant_ty: {
235                        if variant.fields.len() != 1 {
236                            panic!("Only single-field variants are supported");
237                        }
238                        if let Fields::Unnamed(fields_unnamed) = &variant.fields {
239                            fields_unnamed.unnamed[0].ty.clone()
240                        } else {
241                            panic!("Only unnamed fields are supported");
242                        }
243                    },
244                    latest: false,
245                },
246            );
247        }
248    }
249    {
250        let latest_version = version_variants.get_mut(&max).expect("No latest version");
251        latest_version.latest = true;
252    }
253
254    version_variants
255}