version_migrate_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Meta};
4
5/// Derives the `Versioned` trait for a struct.
6///
7/// # Attributes
8///
9/// - `#[versioned(version = "x.y.z")]`: Specifies the semantic version (required).
10///   The version string must be a valid semantic version.
11/// - `#[versioned(version_key = "...")]`: Customizes the version field key (optional, default: "version").
12/// - `#[versioned(data_key = "...")]`: Customizes the data field key (optional, default: "data").
13/// - `#[versioned(auto_tag = true)]`: Auto-generates Serialize/Deserialize with version field (optional, default: false).
14///   When enabled, the version field is automatically inserted during serialization and validated during deserialization.
15/// - `#[versioned(queryable = true)]`: Auto-generates Queryable trait implementation (optional, default: false).
16///   Enables use with ConfigMigrator for ORM-like queries.
17/// - `#[versioned(queryable_key = "...")]`: Customizes the entity name for Queryable (optional).
18///   If not specified, uses the lowercased type name. Only used when `queryable = true`.
19///
20/// # Examples
21///
22/// Basic usage:
23/// ```ignore
24/// use version_migrate::Versioned;
25///
26/// #[derive(Versioned)]
27/// #[versioned(version = "1.0.0")]
28/// pub struct Task_V1_0_0 {
29///     pub id: String,
30///     pub title: String,
31/// }
32/// ```
33///
34/// Custom keys:
35/// ```ignore
36/// #[derive(Versioned)]
37/// #[versioned(
38///     version = "1.0.0",
39///     version_key = "schema_version",
40///     data_key = "payload"
41/// )]
42/// pub struct Task { ... }
43/// // When used with Migrator:
44/// // Serializes to: {"schema_version":"1.0.0","payload":{...}}
45/// ```
46///
47/// Auto-tag for direct serialization:
48/// ```ignore
49/// #[derive(Versioned)]
50/// #[versioned(version = "1.0.0", auto_tag = true)]
51/// pub struct Task {
52///     pub id: String,
53///     pub title: String,
54/// }
55///
56/// // Use serde directly without Migrator
57/// let task = Task { id: "1".into(), title: "Test".into() };
58/// let json = serde_json::to_string(&task)?;
59/// // → {"version":"1.0.0","id":"1","title":"Test"}
60/// ```
61///
62/// Queryable for ConfigMigrator:
63/// ```ignore
64/// #[derive(Serialize, Deserialize, Versioned)]
65/// #[versioned(version = "2.0.0", queryable = true, queryable_key = "task")]
66/// pub struct TaskEntity {
67///     pub id: String,
68///     pub title: String,
69///     pub description: Option<String>,
70/// }
71///
72/// // Now TaskEntity implements Queryable automatically
73/// let tasks: Vec<TaskEntity> = config_migrator.query("tasks")?;
74/// ```
75#[proc_macro_derive(Versioned, attributes(versioned))]
76pub fn derive_versioned(input: TokenStream) -> TokenStream {
77    let input = parse_macro_input!(input as DeriveInput);
78
79    // Extract attributes
80    let attrs = extract_attributes(&input);
81
82    let name = &input.ident;
83    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
84
85    let version = &attrs.version;
86    let version_key = &attrs.version_key;
87    let data_key = &attrs.data_key;
88
89    let versioned_impl = quote! {
90        impl #impl_generics version_migrate::Versioned for #name #ty_generics #where_clause {
91            const VERSION: &'static str = #version;
92            const VERSION_KEY: &'static str = #version_key;
93            const DATA_KEY: &'static str = #data_key;
94        }
95    };
96
97    let mut impls = vec![versioned_impl];
98
99    if attrs.auto_tag {
100        // Generate custom Serialize and Deserialize implementations
101        let serialize_impl = generate_serialize_impl(&input, &attrs);
102        let deserialize_impl = generate_deserialize_impl(&input, &attrs);
103        impls.push(serialize_impl);
104        impls.push(deserialize_impl);
105    }
106
107    if attrs.queryable {
108        // Generate Queryable trait implementation
109        let queryable_impl = generate_queryable_impl(&input, &attrs);
110        impls.push(queryable_impl);
111    }
112
113    let expanded = quote! {
114        #(#impls)*
115    };
116
117    TokenStream::from(expanded)
118}
119
/// Options gathered from all `#[versioned(...)]` attributes on the derive input.
struct VersionedAttributes {
    /// Semantic version string; required, and checked with `semver::Version::parse`.
    version: String,
    /// Key that holds the version (defaults to `"version"`).
    version_key: String,
    /// Key that holds the payload (defaults to `"data"`).
    data_key: String,
    /// Explicit `Queryable::ENTITY_NAME`; `None` falls back to the lowercased type name.
    queryable_key: Option<String>,

    /// Emit `Serialize`/`Deserialize` impls that carry the version field.
    auto_tag: bool,
    /// Emit a `Queryable` impl.
    queryable: bool,
}
128
129fn extract_attributes(input: &DeriveInput) -> VersionedAttributes {
130    let mut version = None;
131    let mut version_key = String::from("version");
132    let mut data_key = String::from("data");
133    let mut auto_tag = false;
134    let mut queryable = false;
135    let mut queryable_key = None;
136
137    for attr in &input.attrs {
138        if attr.path().is_ident("versioned") {
139            if let Meta::List(meta_list) = &attr.meta {
140                let tokens = meta_list.tokens.to_string();
141                parse_versioned_attrs(
142                    &tokens,
143                    &mut version,
144                    &mut version_key,
145                    &mut data_key,
146                    &mut auto_tag,
147                    &mut queryable,
148                    &mut queryable_key,
149                );
150            }
151        }
152    }
153
154    let version = version.unwrap_or_else(|| {
155        panic!("Missing #[versioned(version = \"x.y.z\")] attribute");
156    });
157
158    // Validate semver at compile time
159    if let Err(e) = semver::Version::parse(&version) {
160        panic!("Invalid semantic version '{}': {}", version, e);
161    }
162
163    VersionedAttributes {
164        version,
165        version_key,
166        data_key,
167        auto_tag,
168        queryable,
169        queryable_key,
170    }
171}
172
173fn parse_versioned_attrs(
174    tokens: &str,
175    version: &mut Option<String>,
176    version_key: &mut String,
177    data_key: &mut String,
178    auto_tag: &mut bool,
179    queryable: &mut bool,
180    queryable_key: &mut Option<String>,
181) {
182    // Parse comma-separated key = "value" pairs
183    for part in tokens.split(',') {
184        let part = part.trim();
185
186        if let Some(val) = parse_attr_value(part, "version") {
187            *version = Some(val);
188        } else if let Some(val) = parse_attr_value(part, "version_key") {
189            *version_key = val;
190        } else if let Some(val) = parse_attr_value(part, "data_key") {
191            *data_key = val;
192        } else if let Some(val) = parse_attr_bool_value(part, "auto_tag") {
193            *auto_tag = val;
194        } else if let Some(val) = parse_attr_bool_value(part, "queryable") {
195            *queryable = val;
196        } else if let Some(val) = parse_attr_value(part, "queryable_key") {
197            *queryable_key = Some(val);
198        }
199    }
200}
201
/// Parses a single `key = "value"` entry and returns the text between the quotes.
///
/// Returns `None` in each of these cases:
/// - `token` does not start with `key`;
/// - no `=` follows the key;
/// - the value is not enclosed in double quotes.
///
/// Escape sequences are not interpreted; the raw text between the quotes is returned.
fn parse_attr_value(token: &str, key: &str) -> Option<String> {
    let after_key = token.trim().strip_prefix(key)?;
    let literal = after_key.trim_start().strip_prefix('=')?.trim();
    // Strip one quote from each end independently. A lone `"` therefore yields `None`
    // instead of slicing `[1..0]` and panicking.
    let inner = literal.strip_prefix('"')?.strip_suffix('"')?;
    Some(inner.to_string())
}
215
/// Parses a single `key = true` / `key = false` entry.
///
/// Returns `None` in each of these cases:
/// - `token` does not start with `key`;
/// - no `=` follows the key;
/// - the value is anything other than the bare words `true` or `false`.
fn parse_attr_bool_value(token: &str, key: &str) -> Option<bool> {
    let after_key = token.trim().strip_prefix(key)?;
    let value = after_key.trim().strip_prefix('=')?.trim();
    // `bool::from_str` accepts exactly "true" and "false" (case-sensitive).
    value.parse::<bool>().ok()
}
231
232fn generate_queryable_impl(
233    input: &DeriveInput,
234    attrs: &VersionedAttributes,
235) -> proc_macro2::TokenStream {
236    let name = &input.ident;
237
238    // Determine the entity name
239    let entity_name = if let Some(ref key) = attrs.queryable_key {
240        key.clone()
241    } else {
242        // Default: use the type name in lowercase
243        name.to_string().to_lowercase()
244    };
245
246    quote! {
247        impl version_migrate::Queryable for #name {
248            const ENTITY_NAME: &'static str = #entity_name;
249        }
250    }
251}
252
253fn generate_serialize_impl(
254    input: &DeriveInput,
255    attrs: &VersionedAttributes,
256) -> proc_macro2::TokenStream {
257    let name = &input.ident;
258    let version = &attrs.version;
259    let version_key = &attrs.version_key;
260
261    // Extract field information
262    let fields = match &input.data {
263        syn::Data::Struct(data_struct) => match &data_struct.fields {
264            syn::Fields::Named(fields) => &fields.named,
265            _ => panic!("auto_tag only supports structs with named fields"),
266        },
267        _ => panic!("auto_tag only supports structs"),
268    };
269
270    let field_count = fields.len() + 1; // +1 for version field
271    let field_serializations = fields.iter().map(|field| {
272        let field_name = field.ident.as_ref().unwrap();
273        let field_name_str = field_name.to_string();
274        quote! {
275            state.serialize_field(#field_name_str, &self.#field_name)?;
276        }
277    });
278
279    quote! {
280        impl serde::Serialize for #name {
281            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
282            where
283                S: serde::Serializer,
284            {
285                use serde::ser::SerializeStruct;
286                let mut state = serializer.serialize_struct(stringify!(#name), #field_count)?;
287                state.serialize_field(#version_key, #version)?;
288                #(#field_serializations)*
289                state.end()
290            }
291        }
292    }
293}
294
295fn generate_deserialize_impl(
296    input: &DeriveInput,
297    attrs: &VersionedAttributes,
298) -> proc_macro2::TokenStream {
299    let name = &input.ident;
300    let version = &attrs.version;
301    let version_key = &attrs.version_key;
302
303    // Extract field information
304    let fields = match &input.data {
305        syn::Data::Struct(data_struct) => match &data_struct.fields {
306            syn::Fields::Named(fields) => &fields.named,
307            _ => panic!("auto_tag only supports structs with named fields"),
308        },
309        _ => panic!("auto_tag only supports structs"),
310    };
311
312    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
313    let field_name_strs: Vec<_> = field_names.iter().map(|f| f.to_string()).collect();
314
315    let all_field_names = {
316        let mut names = vec![version_key.clone()];
317        names.extend(field_name_strs.iter().cloned());
318        names
319    };
320
321    let field_enum_variants = field_names.iter().map(|name| {
322        let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
323        quote! { #variant }
324    });
325
326    let field_match_arms =
327        field_names
328            .iter()
329            .zip(field_name_strs.iter())
330            .map(|(name, name_str)| {
331                let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
332                quote! {
333                    #name_str => Ok(Field::#variant)
334                }
335            });
336
337    let field_visit_arms = field_names.iter().map(|name| {
338        let variant = quote::format_ident!("{}", name.to_string().to_uppercase());
339        quote! {
340            Field::#variant => {
341                if #name.is_some() {
342                    return Err(serde::de::Error::duplicate_field(stringify!(#name)));
343                }
344                #name = Some(map.next_value()?);
345            }
346        }
347    });
348
349    let field_unwrap = field_names.iter().map(|name| {
350        quote! {
351            let #name = #name.ok_or_else(|| serde::de::Error::missing_field(stringify!(#name)))?;
352        }
353    });
354
355    quote! {
356        impl<'de> serde::Deserialize<'de> for #name {
357            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
358            where
359                D: serde::Deserializer<'de>,
360            {
361                #[allow(non_camel_case_types)]
362                enum Field {
363                    Version,
364                    #(#field_enum_variants,)*
365                }
366
367                impl<'de> serde::Deserialize<'de> for Field {
368                    fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
369                    where
370                        D: serde::Deserializer<'de>,
371                    {
372                        struct FieldVisitor;
373
374                        impl<'de> serde::de::Visitor<'de> for FieldVisitor {
375                            type Value = Field;
376
377                            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
378                                formatter.write_str(&format!("field identifier: {}", &[#(#all_field_names),*].join(", ")))
379                            }
380
381                            fn visit_str<E>(self, value: &str) -> Result<Field, E>
382                            where
383                                E: serde::de::Error,
384                            {
385                                match value {
386                                    #version_key => Ok(Field::Version),
387                                    #(#field_match_arms,)*
388                                    _ => Err(serde::de::Error::unknown_field(value, &[#(#all_field_names),*])),
389                                }
390                            }
391                        }
392
393                        deserializer.deserialize_identifier(FieldVisitor)
394                    }
395                }
396
397                struct StructVisitor;
398
399                impl<'de> serde::de::Visitor<'de> for StructVisitor {
400                    type Value = #name;
401
402                    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
403                        formatter.write_str(&format!("struct {}", stringify!(#name)))
404                    }
405
406                    fn visit_map<V>(self, mut map: V) -> Result<#name, V::Error>
407                    where
408                        V: serde::de::MapAccess<'de>,
409                    {
410                        let mut version: Option<String> = None;
411                        #(let mut #field_names = None;)*
412
413                        while let Some(key) = map.next_key()? {
414                            match key {
415                                Field::Version => {
416                                    if version.is_some() {
417                                        return Err(serde::de::Error::duplicate_field(#version_key));
418                                    }
419                                    let v: String = map.next_value()?;
420                                    if v != #version {
421                                        return Err(serde::de::Error::custom(format!(
422                                            "version mismatch: expected {}, found {}",
423                                            #version, v
424                                        )));
425                                    }
426                                    version = Some(v);
427                                }
428                                #(#field_visit_arms)*
429                            }
430                        }
431
432                        let _version = version.ok_or_else(|| serde::de::Error::missing_field(#version_key))?;
433                        #(#field_unwrap)*
434
435                        Ok(#name {
436                            #(#field_names,)*
437                        })
438                    }
439                }
440
441                deserializer.deserialize_struct(
442                    stringify!(#name),
443                    &[#(#all_field_names),*],
444                    StructVisitor,
445                )
446            }
447        }
448    }
449}