pro_serde_versioned_derive/
lib.rs1use std::collections::HashMap;
12
13use proc_macro::TokenStream;
14use quote::quote;
15use syn::{Data, DeriveInput, Fields};
16
17#[derive(Debug)]
18struct VersionVariant {
19 version_number: usize,
20 variant_ident: syn::Ident,
21 variant_ty: syn::Type,
22 latest: bool,
23}
24
25#[proc_macro_derive(VersionedSerialize)]
26pub fn versioned_serialize(input: TokenStream) -> TokenStream {
27 let ast: DeriveInput = syn::parse(input).unwrap();
28 let name = &ast.ident;
29 let version_variants = get_version_variants(&ast);
30 let variant_names: Vec<_> = version_variants
31 .values()
32 .map(|version_variant| &version_variant.variant_ident)
33 .cloned()
34 .collect();
35
36 let variant_tys: Vec<_> = version_variants
37 .values()
38 .map(|version_variant| version_variant.variant_ty.clone())
39 .collect();
40
41 let variant_versions: Vec<_> = version_variants
42 .values()
43 .map(|version_variant| version_variant.version_number)
44 .collect();
45
46 let generics = ast.generics;
47 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
48 quote! {
49 #(
50 impl #impl_generics From<#variant_tys> for #name #ty_generics #where_clause {
51 fn from(value: #variant_tys) -> #name {
52 #name::#variant_names(value)
53 }
54 }
55
56 )*
76
77 impl #impl_generics ::pro_serde_versioned::VersionedSerialize for #name #ty_generics #where_clause {
78 type VersionedEnvelope<A: Serialize> = ::pro_serde_versioned::VersionedEnvelope<A>;
79 fn to_envelope<F: ::pro_serde_versioned::SerializeFormat>(&self) -> Result<Self::VersionedEnvelope<F>, F::Error> {
80 match self {
81 #(
82 #name::#variant_names(value) => {
83 Ok(::pro_serde_versioned::VersionedEnvelope {
84 version_number: #variant_versions,
85 data: F::serialize_format(&value)?
86 })
87 }
88 )*
89 }
90 }
91 }
92 }
93 .into()
94}
95
96#[proc_macro_derive(VersionedDeserialize)]
97pub fn versioned_deserialize(input: TokenStream) -> TokenStream {
98 let ast: DeriveInput = syn::parse(input).unwrap();
99 let name = &ast.ident;
100
101 let version_variants = get_version_variants(&ast);
102 let variant_names: Vec<_> = version_variants
103 .values()
104 .map(|version_variant| &version_variant.variant_ident)
105 .cloned()
106 .collect();
107
108 let variant_versions: Vec<_> = version_variants
109 .values()
110 .map(|version_variant| version_variant.version_number)
111 .collect();
112
113 let generics = ast.generics;
114 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
115
116 quote! {
117 impl #impl_generics ::pro_serde_versioned::VersionedDeserialize for #name #ty_generics #where_clause {
118 type VersionedEnvelope<'a, F: Deserialize<'a>> = ::pro_serde_versioned::VersionedEnvelope<F>;
119 fn from_envelope<'a, F: ::pro_serde_versioned::DeserializeFormat + Deserialize<'a>>(
120 envelope: &::pro_serde_versioned::VersionedEnvelope<F>,
121 ) -> Result<Self, F::Error> {
122 match envelope.version_number {
123 #(
124 #variant_versions => Ok(#name::#variant_names(
125 <F as ::pro_serde_versioned::DeserializeFormat>::deserialize_format(
126 &envelope.data
127 )
128 ?)),
129 )*
130 _ => Err(serde::de::Error::custom("Unknown version number")),
131 }
132 }
133 }
134 }
135 .into()
136}
137
138#[proc_macro_derive(VersionedUpgrade)]
139pub fn upgradable_enum(input: TokenStream) -> TokenStream {
140 let ast: DeriveInput = syn::parse(input).unwrap();
141 let name = &ast.ident;
142 let mut latest_variant_ty = None;
143 let version_variants = get_version_variants(&ast);
144
145 for variant in version_variants.values() {
146 if variant.latest {
147 latest_variant_ty = Some(variant.variant_ty.clone());
148 break;
149 }
150 }
151
152 let latest_variant_ty = match latest_variant_ty {
153 Some(latest_variant_ty) => latest_variant_ty,
154 None => panic!("No latest variant found"),
155 };
156
157 let upgrade_match_arms = generate_upgrade_match_arms(&ast, version_variants);
158
159 let generics = ast.generics;
160 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
161
162 let gen = quote! {
163 impl #impl_generics ::pro_serde_versioned::VersionedUpgrade for #name #ty_generics #where_clause {
164 type Latest = #latest_variant_ty;
165 fn upgrade_to_latest(self) -> Self::Latest {
166 match self {
167 #(#upgrade_match_arms)*
168 }
169 }
170 }
171 };
172
173 gen.into()
174}
175
176fn generate_upgrade_match_arms(
177 ast: &DeriveInput,
178 version_variants: HashMap<usize, VersionVariant>,
179) -> Vec<proc_macro2::TokenStream> {
180 let name = &ast.ident;
181 let mut match_arms = Vec::new();
182
183 for version_variant in version_variants.values() {
184 let version_number = version_variant.version_number;
185 let variant_ident = &version_variant.variant_ident;
186
187 if !version_variant.latest {
188 let next_variant = version_variants
189 .get(&(version_number + 1))
190 .expect("No variant for next version");
191
192 let next_variant_ident = &next_variant.variant_ident;
193 let next_variant_ty = &next_variant.variant_ty;
194 match_arms.push(quote! {
195 #name::#variant_ident(value) => {
196 let upgraded: #next_variant_ty = value.upgrade();
197 #name::#next_variant_ident(upgraded).upgrade_to_latest()
198 },
199 });
200 } else {
201 match_arms.push(quote! {
202 #name::#variant_ident(value) => value,
203 });
204 }
205 }
206
207 match_arms
208}
209
210fn get_version_variants(ast: &DeriveInput) -> HashMap<usize, VersionVariant> {
211 let mut version_variants: HashMap<usize, VersionVariant> = HashMap::new();
212
213 let mut max = 0;
214
215 if let Data::Enum(data_enum) = &ast.data {
216 for variant in &data_enum.variants {
217 let version_number = variant
218 .ident
219 .to_string()
220 .replace("V", "")
221 .parse::<usize>()
222 .expect("Invalid version number");
223 max = std::cmp::max(max, version_number);
224 version_variants.insert(
225 variant
226 .ident
227 .to_string()
228 .replace("V", "")
229 .parse::<usize>()
230 .expect("Invalid version number"),
231 VersionVariant {
232 version_number,
233 variant_ident: variant.ident.clone(),
234 variant_ty: {
235 if variant.fields.len() != 1 {
236 panic!("Only single-field variants are supported");
237 }
238 if let Fields::Unnamed(fields_unnamed) = &variant.fields {
239 fields_unnamed.unnamed[0].ty.clone()
240 } else {
241 panic!("Only unnamed fields are supported");
242 }
243 },
244 latest: false,
245 },
246 );
247 }
248 }
249 {
250 let latest_version = version_variants.get_mut(&max).expect("No latest version");
251 latest_version.latest = true;
252 }
253
254 version_variants
255}