Skip to main content

provcfg_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{Data, DeriveInput, Fields, parse_macro_input};
4
5/// Derives the companion `*Partial`/`*Prov` types and the impls that make a
6/// plain config struct provenance-tracking.
7///
8/// Generated items for a struct `Foo`:
9///
10/// - `FooPartial`: every leaf wrapped in `Option` (`Deserialize` + `Serialize`).
11/// - `FooProv`: every leaf wrapped in `ValueHistory` (`Serialize`).
12/// - `impl Provenance for FooProv`: the defaults layer plus per-source merge.
13/// - `impl From<&Foo> for FooPartial` and `impl From<&FooProv> for Foo`.
14///
15/// The base struct must derive `Default + Clone + serde::Deserialize`. Mark
16/// sub-struct fields with `#[configurable(nested)]` to recurse; see
17/// `provcfg::Configurable` for the full attribute reference.
18#[proc_macro_derive(Configurable, attributes(configurable))]
19pub fn configurable_derive(input: TokenStream) -> TokenStream {
20    let ast = parse_macro_input!(input as DeriveInput);
21    let all_fields = match extract_named_fields(&ast) {
22        Ok(f) => f,
23        Err(e) => return e.to_compile_error().into(),
24    };
25    // Fields marked `#[configurable(skip)]` are not part of the config schema.
26    // They stay on the user struct but are invisible to everything we generate.
27    // The user struct still has them, so `From<&Prov>` fills them via `Default`.
28    let has_skipped = all_fields.iter().any(|f| f.skip);
29    let fields: Vec<Field> = all_fields.into_iter().filter(|f| !f.skip).collect();
30
31    let partial_struct = generate_partial(&ast.ident, &fields);
32    let partial_from_ref = generate_partial_from_ref(&ast.ident, &fields);
33    let prov_struct = generate_prov(&ast.ident, &fields);
34    let provenance_impl = generate_provenance_impl(&ast.ident, &fields);
35    let prov_serialize = generate_prov_serialize(&ast.ident, &fields);
36    let prov_into_user = generate_prov_into_user(&ast.ident, &fields, has_skipped);
37
38    let expanded = quote! {
39        #partial_struct
40        #partial_from_ref
41        #prov_struct
42        #provenance_impl
43        #prov_serialize
44        #prov_into_user
45    };
46
47    TokenStream::from(expanded)
48}
49
50struct Field {
51    name: syn::Ident,
52    ty: syn::Type,
53    nested: bool,
54    secret: bool,
55    /// `true` when the user marked this field with `#[configurable(env_list)]`.
56    /// Triggers a custom `deserialize_with` that accepts either an array or a
57    /// comma-separated string. Required for `Vec<String>` fields populated
58    /// from environment variables.
59    env_list: bool,
60    /// `true` when the user marked this field with `#[configurable(skip)]`.
61    /// The field stays on the user struct but is omitted from the partial,
62    /// prov, merge, collect_sources, and Serialize impls. Useful for fields
63    /// that live alongside config (e.g. runtime state, derived caches).
64    skip: bool,
65    /// Optional rename applied to the partial field's serde key. Preserved
66    /// verbatim; the macro emits a lowercase `#[serde(alias = ...)]` alongside
67    /// so the env source (which lowercases segments) still matches.
68    rename: Option<String>,
69}
70
71fn extract_named_fields(ast: &DeriveInput) -> syn::Result<Vec<Field>> {
72    let named = match &ast.data {
73        Data::Struct(data) => match &data.fields {
74            Fields::Named(fields) => &fields.named,
75            _ => {
76                return Err(syn::Error::new_spanned(
77                    &ast.ident,
78                    "Configurable only supports structs with named fields",
79                ));
80            }
81        },
82        _ => {
83            return Err(syn::Error::new_spanned(
84                &ast.ident,
85                "Configurable only supports structs",
86            ));
87        }
88    };
89
90    named
91        .iter()
92        .map(|f| {
93            let flags = parse_field_attrs(&f.attrs)?;
94            Ok(Field {
95                name: f.ident.clone().unwrap(),
96                ty: f.ty.clone(),
97                nested: flags.nested,
98                secret: flags.secret,
99                env_list: flags.env_list,
100                skip: flags.skip,
101                rename: flags.rename,
102            })
103        })
104        .collect()
105}
106
/// Options collected from a field's `#[configurable(...)]` attributes.
///
/// `Default` is the "no attribute present" state: every flag `false` and no
/// rename.
#[derive(Default)]
struct FieldFlags {
    /// `#[configurable(nested)]` was present.
    nested: bool,
    /// `#[configurable(secret)]` was present.
    secret: bool,
    /// `#[configurable(env_list)]` was present.
    env_list: bool,
    /// `#[configurable(skip)]` was present.
    skip: bool,
    /// Value of `#[configurable(rename = "...")]`, if given.
    rename: Option<String>,
}
115
116fn parse_field_attrs(attrs: &[syn::Attribute]) -> syn::Result<FieldFlags> {
117    let mut flags = FieldFlags::default();
118    for attr in attrs {
119        if !attr.path().is_ident("configurable") {
120            continue;
121        }
122        attr.parse_nested_meta(|meta| {
123            if meta.path.is_ident("nested") {
124                flags.nested = true;
125                Ok(())
126            } else if meta.path.is_ident("secret") {
127                flags.secret = true;
128                Ok(())
129            } else if meta.path.is_ident("env_list") {
130                flags.env_list = true;
131                Ok(())
132            } else if meta.path.is_ident("skip") {
133                flags.skip = true;
134                Ok(())
135            } else if meta.path.is_ident("rename") {
136                let lit: syn::LitStr = meta.value()?.parse()?;
137                flags.rename = Some(lit.value());
138                Ok(())
139            } else {
140                Err(meta.error(
141                    "unknown #[configurable(...)] key; expected one of nested, secret, env_list, skip, rename",
142                ))
143            }
144        })?;
145    }
146    Ok(flags)
147}
148
149fn field_key(f: &Field) -> String {
150    f.rename.clone().unwrap_or_else(|| f.name.to_string())
151}
152
153fn dotted_key_join(key: &str) -> proc_macro2::TokenStream {
154    quote! {
155        if prefix.is_empty() {
156            #key.to_string()
157        } else {
158            format!("{}.{}", prefix, #key)
159        }
160    }
161}
162
163fn rename_last_segment(ty: &syn::Type, suffix: &str) -> syn::Type {
164    let mut tp = match ty {
165        syn::Type::Path(tp) => tp.clone(),
166        _ => panic!("#[configurable(nested)] requires a path type, e.g. `Database`"),
167    };
168    let last = tp
169        .path
170        .segments
171        .last_mut()
172        .expect("path type has at least one segment");
173    last.ident = format_ident!("{}{}", last.ident, suffix);
174    syn::Type::Path(tp)
175}
176
177fn generate_partial(base_name: &syn::Ident, fields: &[Field]) -> proc_macro2::TokenStream {
178    let partial_name = format_ident!("{}Partial", base_name);
179
180    let field_definitions = fields.iter().map(|f| {
181        let name = &f.name;
182        // Preserve the rename verbatim for file formats; add a lowercase alias
183        // so the env source (which lowercases segments) still matches.
184        let rename = f.rename.as_ref().map(|r| {
185            let lc = r.to_lowercase();
186            if lc == *r {
187                quote! { #[serde(rename = #r)] }
188            } else {
189                quote! { #[serde(rename = #r, alias = #lc)] }
190            }
191        });
192        let env_list_attr = if f.env_list {
193            quote! { #[serde(deserialize_with = "provcfg::deserialize_env_list")] }
194        } else {
195            quote! {}
196        };
197        if f.nested {
198            let nested_partial = rename_last_segment(&f.ty, "Partial");
199            quote! { #rename #env_list_attr pub #name: Option<#nested_partial> }
200        } else {
201            let ty = &f.ty;
202            quote! { #rename #env_list_attr pub #name: Option<#ty> }
203        }
204    });
205
206    quote! {
207        // `Serialize` is generated so partial-producing sources (e.g. `CliSource`)
208        // can round-trip the partial through serde. `None` fields serialize to
209        // `null`, which deserializes back to `None` via the same Option type.
210        #[derive(serde::Deserialize, serde::Serialize, Default)]
211        #[serde(default)]
212        pub struct #partial_name {
213            #(#field_definitions),*
214        }
215    }
216}
217
218fn generate_partial_from_ref(base_name: &syn::Ident, fields: &[Field]) -> proc_macro2::TokenStream {
219    let partial_name = format_ident!("{}Partial", base_name);
220
221    let field_inits = fields.iter().map(|f| {
222        let name = &f.name;
223        if f.nested {
224            quote! { #name: Some((&value.#name).into()) }
225        } else {
226            quote! { #name: Some(value.#name.clone()) }
227        }
228    });
229
230    quote! {
231        impl ::core::convert::From<&#base_name> for #partial_name {
232            fn from(value: &#base_name) -> Self {
233                Self { #(#field_inits),* }
234            }
235        }
236    }
237}
238
239fn generate_prov(base_name: &syn::Ident, fields: &[Field]) -> proc_macro2::TokenStream {
240    let prov_name = format_ident!("{}Prov", base_name);
241
242    let field_definitions = fields.iter().map(|f| {
243        let name = &f.name;
244        if f.nested {
245            let nested_prov = rename_last_segment(&f.ty, "Prov");
246            quote! { pub #name: #nested_prov }
247        } else {
248            let ty = &f.ty;
249            quote! { pub #name: provcfg::ValueHistory<#ty> }
250        }
251    });
252
253    quote! {
254        pub struct #prov_name {
255            #(#field_definitions),*
256        }
257    }
258}
259
260fn generate_prov_serialize(base_name: &syn::Ident, fields: &[Field]) -> proc_macro2::TokenStream {
261    let prov_name = format_ident!("{}Prov", base_name);
262    let struct_name_lit = base_name.to_string();
263    let field_count = fields.len();
264
265    let field_writes = fields.iter().map(|f| {
266        let name = &f.name;
267        let key = field_key(f);
268        if f.nested {
269            quote! { state.serialize_field(#key, &self.#name)?; }
270        } else {
271            quote! { state.serialize_field(#key, self.#name.value())?; }
272        }
273    });
274
275    quote! {
276        impl serde::Serialize for #prov_name {
277            fn serialize<S: serde::Serializer>(&self, serializer: S) -> ::core::result::Result<S::Ok, S::Error> {
278                use serde::ser::SerializeStruct as _;
279                let mut state = serializer.serialize_struct(#struct_name_lit, #field_count)?;
280                #(#field_writes)*
281                state.end()
282            }
283        }
284    }
285}
286
287fn generate_prov_into_user(
288    base_name: &syn::Ident,
289    fields: &[Field],
290    has_skipped: bool,
291) -> proc_macro2::TokenStream {
292    let prov_name = format_ident!("{}Prov", base_name);
293
294    let field_inits = fields.iter().map(|f| {
295        let name = &f.name;
296        if f.nested {
297            quote! { #name: ::core::convert::From::from(&prov.#name) }
298        } else {
299            quote! { #name: ::core::clone::Clone::clone(prov.#name.value()) }
300        }
301    });
302
303    // Skipped fields aren't tracked, so fill them from `Foo::default()`. The
304    // `Configurable` derive already requires `Foo: Default`. Trailing-comma
305    // form keeps `Self { ..Default::default() }` valid when every field is
306    // skipped.
307    let body = if has_skipped {
308        quote! {
309            Self {
310                #(#field_inits,)*
311                ..::core::default::Default::default()
312            }
313        }
314    } else {
315        quote! {
316            Self {
317                #(#field_inits),*
318            }
319        }
320    };
321
322    quote! {
323        impl ::core::convert::From<&#prov_name> for #base_name {
324            fn from(prov: &#prov_name) -> Self {
325                #body
326            }
327        }
328    }
329}
330
331fn generate_provenance_impl(base_name: &syn::Ident, fields: &[Field]) -> proc_macro2::TokenStream {
332    let prov_name = format_ident!("{}Prov", base_name);
333    let partial_name = format_ident!("{}Partial", base_name);
334
335    let leaf_history_inits = fields.iter().filter(|f| !f.nested).map(|f| {
336        let name = &f.name;
337        if f.secret {
338            quote! { let mut #name = provcfg::ValueHistory::new().mark_secret(); }
339        } else {
340            quote! { let mut #name = provcfg::ValueHistory::new(); }
341        }
342    });
343
344    let nested_partial_inits = fields.iter().filter(|f| f.nested).map(|f| {
345        let name = &f.name;
346        let nested_partial = rename_last_segment(&f.ty, "Partial");
347        quote! {
348            let mut #name: ::std::vec::Vec<Option<#nested_partial>> =
349                ::std::vec::Vec::with_capacity(partials.len());
350        }
351    });
352
353    let leaf_default_pushes = fields.iter().filter(|f| !f.nested).map(|f| {
354        let name = &f.name;
355        quote! {
356            #name.push(provcfg::Value {
357                value: defaults_partial.#name.expect("defaults partial must populate every leaf field"),
358                source: ::core::clone::Clone::clone(&defaults_src),
359            });
360        }
361    });
362
363    let per_source_steps = fields.iter().map(|f| {
364        let name = &f.name;
365        if f.nested {
366            quote! {
367                match partial {
368                    Some(ref mut p) => #name.push(::core::mem::take(&mut p.#name)),
369                    None => #name.push(None),
370                }
371            }
372        } else {
373            quote! {
374                if let Some(ref mut p) = partial
375                    && let Some(v) = ::core::mem::take(&mut p.#name)
376                {
377                    #name.push(provcfg::Value {
378                        value: v,
379                        source: ::core::clone::Clone::clone(source),
380                    });
381                }
382            }
383        }
384    });
385
386    let nested_merge_calls = fields.iter().filter(|f| f.nested).map(|f| {
387        let name = &f.name;
388        let nested_prov = rename_last_segment(&f.ty, "Prov");
389        quote! {
390            let #name = <#nested_prov as provcfg::Provenance>::merge(sources, #name);
391        }
392    });
393
394    let field_names = fields.iter().map(|f| &f.name);
395
396    let collect_steps = fields.iter().map(|f| {
397        let name = &f.name;
398        let key = field_key(f);
399        let joined = dotted_key_join(&key);
400        if f.nested {
401            quote! {
402                {
403                    let next_prefix = #joined;
404                    self.#name.collect_sources(&next_prefix, out);
405                }
406            }
407        } else {
408            quote! {
409                {
410                    let key = #joined;
411                    out.insert(key, self.#name.source().category());
412                }
413            }
414        }
415    });
416
417    let walk_steps = fields.iter().map(|f| {
418        let name = &f.name;
419        let key = field_key(f);
420        let joined = dotted_key_join(&key);
421        if f.nested {
422            quote! {
423                {
424                    let next_prefix = #joined;
425                    self.#name.walk_leaves(&next_prefix, visitor);
426                }
427            }
428        } else {
429            quote! {
430                {
431                    let key = #joined;
432                    visitor(
433                        &key,
434                        self.#name.value(),
435                        self.#name.source().category(),
436                        self.#name.is_secret(),
437                    );
438                }
439            }
440        }
441    });
442
443    quote! {
444        impl provcfg::Provenance for #prov_name {
445            type Partial = #partial_name;
446
447            fn defaults_partial() -> Self::Partial {
448                (&<#base_name>::default()).into()
449            }
450
451            fn merge(
452                sources: &[provcfg::SourceArc],
453                partials: Vec<Option<Self::Partial>>,
454            ) -> Self {
455                #(#leaf_history_inits)*
456                #(#nested_partial_inits)*
457
458                // Leaf-field defaults layer. Nested fields handle defaults via
459                // their own recursive merge.
460                let defaults_partial = <Self as provcfg::Provenance>::defaults_partial();
461                let defaults_src: provcfg::SourceArc = provcfg::defaults_source();
462                #(#leaf_default_pushes)*
463
464                for (source, mut partial) in sources.iter().zip(partials) {
465                    #(#per_source_steps)*
466                }
467
468                #(#nested_merge_calls)*
469
470                #prov_name { #(#field_names),* }
471            }
472
473            fn collect_sources(
474                &self,
475                prefix: &str,
476                out: &mut ::std::collections::HashMap<::std::string::String, provcfg::Category>,
477            ) {
478                #(#collect_steps)*
479            }
480
481            fn walk_leaves(
482                &self,
483                prefix: &str,
484                visitor: &mut dyn ::core::ops::FnMut(
485                    &str,
486                    &dyn provcfg::erased_serde::Serialize,
487                    provcfg::Category,
488                    bool,
489                ),
490            ) {
491                #(#walk_steps)*
492            }
493        }
494    }
495}