cfg_vis/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro2::{Span, TokenStream};
4use proc_macro_crate::{crate_name, FoundCrate};
5use quote::{quote, ToTokens};
6use std::collections::hash_map::DefaultHasher;
7use std::hash::{Hash, Hasher};
8use syn::parse::{Parse, ParseStream};
9use syn::punctuated::Punctuated;
10use syn::spanned::Spanned;
11use syn::{parse_macro_input, parse_quote};
12
13struct CfgVisAttrArgs {
14    cfg: syn::NestedMeta,
15    _comma: Option<syn::Token![,]>,
16    vis: syn::Visibility,
17}
18
19impl Parse for CfgVisAttrArgs {
20    fn parse(input: ParseStream) -> syn::Result<Self> {
21        let cfg = input.parse()?;
22        let comma: Option<syn::Token![,]> = input.parse()?;
23
24        if comma.is_none() || input.is_empty() {
25            return Ok(Self {
26                cfg,
27                _comma: comma,
28                vis: syn::Visibility::Inherited,
29            });
30        }
31
32        let vis = input.parse()?;
33
34        Ok(Self {
35            cfg,
36            _comma: comma,
37            vis,
38        })
39    }
40}
41
42impl ToTokens for CfgVisAttrArgs {
43    fn to_tokens(&self, tokens: &mut TokenStream) {
44        self.cfg.to_tokens(tokens);
45        self._comma.to_tokens(tokens);
46        self.vis.to_tokens(tokens);
47    }
48}
49
50struct CfgVisAttrArgsAccumulator {
51    version: String,
52    _semi: Option<syn::Token![;]>,
53    acc: Punctuated<CfgVisAttrArgs, syn::Token![;]>,
54}
55
56impl Parse for CfgVisAttrArgsAccumulator {
57    fn parse(input: ParseStream) -> syn::Result<Self> {
58        let version_str: syn::LitStr = input.parse()?;
59        Ok(Self {
60            version: version_str.value(),
61            _semi: input.parse()?,
62            acc: Punctuated::parse_terminated(input)?,
63        })
64    }
65}
66
67impl ToTokens for CfgVisAttrArgsAccumulator {
68    fn to_tokens(&self, tokens: &mut TokenStream) {
69        self.version.to_tokens(tokens);
70        self._semi.to_tokens(tokens);
71        self.acc.to_tokens(tokens);
72    }
73}
74
75///
76/// # cfg visibility on items
77///
78/// ## Rules
79///
80/// ```ignore
81/// #[cfg_vis($cond1:meta, $vis1:vis)]
82/// #[cfg_vis($cond2:meta, $vis2:vis)]
83/// #[cfg_vis($cond3:meta, $vis3:vis)]
84/// $default_vis:vis $($item:tt)*
85/// ```
86///
87/// will expend to
88///
89/// ```ignore
90/// #[cfg($cond1)]
91/// $vis1 $($item)*
92///
93/// #[cfg($cond2)]
94/// $vis2 $($item)*
95///
96/// #[cfg($cond3)]
97/// $vis3 $($item)*
98///
99/// #[cfg(not($cond1))]
100/// #[cfg(not($cond2))]
101/// #[cfg(not($cond3))]
102/// $default_vis $($item)*
103/// ```
104///
105/// ## Example
106///
107/// ```rust
108/// use cfg_vis::cfg_vis;
109///
110/// // default visibility is `pub`, while the target is linux, the visibility is `pub(crate)`.
111/// #[cfg_vis(target_os = "linux", pub(crate))]
112/// pub fn foo() {}
113/// ```
114///
115#[proc_macro_attribute]
116pub fn cfg_vis(
117    attr: proc_macro::TokenStream,
118    item: proc_macro::TokenStream,
119) -> proc_macro::TokenStream {
120    let cfg_vis_attr = parse_macro_input!(attr as CfgVisAttrArgs);
121    let item = parse_macro_input!(item as syn::Item);
122    cfg_vis_impl(cfg_vis_attr, item)
123        .unwrap_or_else(|err| proc_macro::TokenStream::from(err.into_compile_error()))
124}
125
126fn cfg_vis_impl(
127    cfg_vis_attr: CfgVisAttrArgs,
128    mut item: syn::Item,
129) -> syn::Result<proc_macro::TokenStream> {
130    let version = env!("CARGO_PKG_VERSION");
131    let __cfg_vis_accumulator_declare_path = __cfg_vis_accumulator_declare_path();
132
133    let (_, attrs) = proj_item(&mut item);
134    let mut accumulator= attrs
135        .iter_mut()
136        .filter_map(|acc_attr| {
137            if acc_attr.path == __cfg_vis_accumulator_declare_path {
138                Some((acc_attr.parse_args(), acc_attr))
139            } else {
140                None
141            }
142        })
143        .map(|(acc, attr)| {
144            let acc: CfgVisAttrArgsAccumulator = acc?;
145            if acc.version == version {
146                Ok((acc, attr))
147            } else {
148                Err(syn::Error::new(
149                    Span::call_site(),
150                    format!("multiple versions of cfg-vis conflict, current version: {:?}, other version: {:?}", version, acc.version))
151                )
152            }
153        })
154        .collect::<syn::Result<Vec<_>>>()?;
155
156    match &mut accumulator[..] {
157        [] => {
158            // the last attr
159            attrs.push(
160                parse_quote!(#[#__cfg_vis_accumulator_declare_path(#version; #cfg_vis_attr)]),
161            );
162        }
163        [(acc, attr)] => {
164            acc.acc.push(cfg_vis_attr);
165            **attr = parse_quote!(#[#__cfg_vis_accumulator_declare_path(#acc)]);
166        }
167        _ => {
168            return Err(syn::Error::new(
169                Span::call_site(),
170                "multiple cfg-vis accumulators exist, it's a bug.",
171            ))
172        }
173    }
174
175    Ok(item.into_token_stream().into())
176}
177
178// expend after all `cfg_vis` were expended
179#[doc(hidden)]
180#[proc_macro_attribute]
181pub fn __cfg_vis_accumulator(
182    attr: proc_macro::TokenStream,
183    item: proc_macro::TokenStream,
184) -> proc_macro::TokenStream {
185    let accumulator = parse_macro_input!(attr as CfgVisAttrArgsAccumulator);
186    let item = parse_macro_input!(item as syn::Item);
187
188    let version = env!("CARGO_PKG_VERSION");
189    if accumulator.version != version {
190        return syn::Error::new(
191            Span::call_site(),
192            format!(
193                "multiple versions of cfg-vis conflict, current version: {:?}, other version: {:?}",
194                version, accumulator.version
195            ),
196        )
197        .into_compile_error()
198        .into();
199    }
200    // generate:
201    //
202    // #[cfg($cond_n)]
203    // $vis_n $($item)*
204    let mut token_stream = TokenStream::new();
205    for cfg_vis_args in &accumulator.acc {
206        let vis = cfg_vis_args.vis.clone();
207        let cfg = &cfg_vis_args.cfg;
208
209        let mut tmp_item = item.clone();
210        let (tmp_vis, tmp_attrs) = proj_item(&mut tmp_item);
211        *tmp_vis = vis;
212        tmp_attrs.push(parse_quote!(#[cfg(#cfg)]));
213        tmp_item.to_tokens(&mut token_stream);
214    }
215
216    // generate
217    //
218    // #[cfg(not($cond_1))]
219    // #[cfg(not($cond_2))]
220    // ..
221    // #[cfg(not($cond_n))]
222    // $default_vis $($item)*
223    let cfgs = accumulator.acc.iter().map(|cfg_vis_args| &cfg_vis_args.cfg);
224    let default_item = quote! {
225        #( #[cfg(not(#cfgs))] )*
226        #item
227    };
228    token_stream.extend(default_item);
229
230    // check_unique
231    let check_unique = assert_accumulator_is_unique(&item);
232    token_stream.extend(check_unique);
233
234    token_stream.into()
235}
236
237/// `$crate::__cfg_vis_accumulator`
238fn __cfg_vis_accumulator_declare_path() -> syn::Path {
239    let found_name = crate_name("cfg-vis").expect("cfg-vis is present in `Cargo.toml`");
240
241    match found_name {
242        FoundCrate::Itself => {
243            parse_quote!(::cfg_vis::__cfg_vis_accumulator)
244        }
245        FoundCrate::Name(cfg_vis) => {
246            let cfg_vis = syn::Ident::new(&cfg_vis, Span::call_site());
247            parse_quote!(::#cfg_vis::__cfg_vis_accumulator)
248        }
249    }
250}
251
252fn proj_item(item: &mut syn::Item) -> (&mut syn::Visibility, &mut Vec<syn::Attribute>) {
253    match item {
254        syn::Item::Const(i) => (&mut i.vis, &mut i.attrs),
255        syn::Item::Enum(i) => (&mut i.vis, &mut i.attrs),
256        syn::Item::ExternCrate(i) => (&mut i.vis, &mut i.attrs),
257        syn::Item::Fn(i) => (&mut i.vis, &mut i.attrs),
258        syn::Item::Macro2(i) => (&mut i.vis, &mut i.attrs),
259        syn::Item::Mod(i) => (&mut i.vis, &mut i.attrs),
260        syn::Item::Static(i) => (&mut i.vis, &mut i.attrs),
261        syn::Item::Struct(i) => (&mut i.vis, &mut i.attrs),
262        syn::Item::Trait(i) => (&mut i.vis, &mut i.attrs),
263        syn::Item::TraitAlias(i) => (&mut i.vis, &mut i.attrs),
264        syn::Item::Type(i) => (&mut i.vis, &mut i.attrs),
265        syn::Item::Union(i) => (&mut i.vis, &mut i.attrs),
266        syn::Item::Use(i) => (&mut i.vis, &mut i.attrs),
267        _ => {
268            panic!("`cfg_vis` can only apply on item with visibility");
269        }
270    }
271}
272
273fn assert_accumulator_is_unique(item: &syn::Item) -> TokenStream {
274    let mut hasher = DefaultHasher::new();
275
276    PartialHashItemHelper(item).hash(&mut hasher);
277
278    // different version of package make a different accumulator
279    env!("CARGO_PKG_VERSION").hash(&mut hasher);
280
281    let name = format!(
282        "__CFG_VIS_ACCUMULATOR_MUST_EXPAND_ONCE_OTHERWISE_IS_A_BUG_{}",
283        hasher.finish()
284    );
285    let check_unique = syn::Ident::new(&name, Span::call_site());
286
287    quote! {
288        const #check_unique: () = ();
289    }
290}
291
292struct PartialHashItemHelper<'a>(&'a syn::Item);
293
294impl Hash for PartialHashItemHelper<'_> {
295    fn hash<H: Hasher>(&self, state: &mut H) {
296        std::mem::discriminant(self.0).hash(state);
297        match &self.0 {
298            syn::Item::Const(v0) => {
299                v0.ident.hash(state);
300            }
301            syn::Item::Enum(v0) => {
302                v0.ident.hash(state);
303            }
304            syn::Item::ExternCrate(v0) => {
305                v0.ident.hash(state);
306                v0.rename.hash(state);
307            }
308            syn::Item::Fn(v0) => {
309                v0.sig.ident.hash(state);
310            }
311            syn::Item::Macro(v0) => {
312                v0.ident.hash(state);
313            }
314            syn::Item::Macro2(v0) => {
315                v0.ident.hash(state);
316            }
317            syn::Item::Mod(v0) => {
318                v0.ident.hash(state);
319            }
320            syn::Item::Static(v0) => {
321                v0.ident.hash(state);
322            }
323            syn::Item::Struct(v0) => {
324                v0.ident.hash(state);
325            }
326            syn::Item::Trait(v0) => {
327                v0.ident.hash(state);
328            }
329            syn::Item::TraitAlias(v0) => {
330                v0.ident.hash(state);
331            }
332            syn::Item::Type(v0) => {
333                v0.ident.hash(state);
334            }
335            syn::Item::Union(v0) => {
336                v0.ident.hash(state);
337            }
338            syn::Item::Use(v0) => {
339                v0.tree.hash(state);
340            }
341            _ => self.0.hash(state),
342        }
343    }
344}
345
346///
347/// # cfg visibility on fields
348///
349/// ## Rules
350///
351/// `#[cfg_vis]` on field as same as it on item.
352///
353/// ## Example
354///
355/// ```rust
356/// use cfg_vis::cfg_vis_fields;
357///
358/// #[cfg_vis_fields]
359/// struct Foo {
360///     // while the target is linux, the visibility is `pub`.
361///     #[cfg_vis(target_os = "linux", pub)]
362///     foo: i32,
363/// }
364/// ```
365///
366#[proc_macro_attribute]
367pub fn cfg_vis_fields(
368    attr: proc_macro::TokenStream,
369    item: proc_macro::TokenStream,
370) -> proc_macro::TokenStream {
371    if !attr.is_empty() {
372        let err = syn::Error::new(
373            Span::call_site(),
374            format!("unsupported arg \"{}\" for `cfg_vis_fields`", attr),
375        );
376        return proc_macro::TokenStream::from(err.into_compile_error());
377    }
378
379    let item = parse_macro_input!(item as syn::Item);
380
381    let toks = cfg_vis_fields_impl(item)
382        .map(|item| quote! { #item })
383        .unwrap_or_else(|err| err.to_compile_error());
384
385    proc_macro::TokenStream::from(toks)
386}
387
388fn cfg_vis_fields_impl(mut item: syn::Item) -> syn::Result<syn::Item> {
389    fn find_replace_cfg_vis(
390        fields: Punctuated<syn::Field, syn::Token![,]>,
391    ) -> syn::Result<Punctuated<syn::Field, syn::Token![,]>> {
392        let mut fields_replaced = vec![];
393
394        for mut field in fields {
395            let cfg_vis_attrs = take_all_cfg_vis(&mut field.attrs)?;
396            for (mut cfgs, vis) in expend_cfg_vis_attrs(cfg_vis_attrs, field.vis.clone()) {
397                let mut field = field.clone();
398                field.vis = vis;
399                field.attrs.append(&mut cfgs);
400                fields_replaced.push(field);
401            }
402        }
403
404        Ok(Punctuated::from_iter(fields_replaced))
405    }
406
407    let fields = match &mut item {
408        syn::Item::Struct(s) => match &mut s.fields {
409            syn::Fields::Named(f) => &mut f.named,
410            syn::Fields::Unnamed(f) => &mut f.unnamed,
411            syn::Fields::Unit => {
412                return Ok(item);
413            }
414        },
415
416        syn::Item::Union(u) => &mut u.fields.named,
417        _ => {
418            return Err(syn::Error::new(
419                item.span(),
420                "`cfg_vis_fields` can only apply on struct or union",
421            ))
422        }
423    };
424
425    *fields = find_replace_cfg_vis(std::mem::take(fields))?;
426
427    Ok(item)
428}
429
430fn take_all_cfg_vis(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Vec<CfgVisAttrArgs>> {
431    let (cfg_vis_attrs, remain_attrs): (Vec<_>, Vec<_>) = std::mem::take(attrs)
432        .into_iter()
433        .partition(|attr| attr.path.is_ident("cfg_vis"));
434
435    *attrs = remain_attrs;
436
437    cfg_vis_attrs
438        .into_iter()
439        .map(|attr| attr.parse_args::<CfgVisAttrArgs>())
440        .collect()
441}
442
443fn expend_cfg_vis_attrs(
444    cfg_vis_attrs: Vec<CfgVisAttrArgs>,
445    default_vis: syn::Visibility,
446) -> impl Iterator<Item = (Vec<syn::Attribute>, syn::Visibility)> {
447    let default_cfg_attrs = cfg_vis_attrs
448        .iter()
449        .map(|attr| {
450            let cfg = &attr.cfg;
451            parse_quote! {
452                #[cfg(not(#cfg))]
453            }
454        })
455        .collect::<Vec<_>>();
456
457    cfg_vis_attrs
458        .into_iter()
459        .map(|cfg_vis| {
460            let cfg = cfg_vis.cfg;
461            let cfgs = vec![parse_quote! {
462                #[cfg(#cfg)]
463            }];
464
465            (cfgs, cfg_vis.vis)
466        })
467        .chain(Some((default_cfg_attrs, default_vis)))
468}