controlled_option_macros/
lib.rs

1// -*- coding: utf-8 -*-
2// ------------------------------------------------------------------------------------------------
3// Copyright © 2021, Douglas Creager.
4// Licensed under either of Apache License, Version 2.0, or MIT license, at your option.
5// Please see the LICENSE-APACHE or LICENSE-MIT files in this distribution for license details.
6// ------------------------------------------------------------------------------------------------
7
8extern crate proc_macro;
9
10use proc_macro::TokenStream;
11use quote::quote;
12use syn::parse_macro_input;
13use syn::parse_quote;
14use syn::Field;
15use syn::Fields;
16use syn::Item;
17use syn::Member;
18use syn::Type;
19use syn::WhereClause;
20
21fn field_is_niche(field: &&Field) -> bool {
22    for attr in &field.attrs {
23        if attr.path.is_ident("niche") {
24            return true;
25        }
26    }
27    false
28}
29
30fn merge_where_clauses(lhs: Option<WhereClause>, rhs: WhereClause) -> WhereClause {
31    match lhs {
32        Some(mut lhs) => {
33            lhs.predicates.extend(rhs.predicates);
34            lhs
35        }
36        None => rhs,
37    }
38}
39
40#[proc_macro_derive(Niche, attributes(niche))]
41pub fn derive_decode(input: TokenStream) -> TokenStream {
42    let item = parse_macro_input!(input as Item);
43    match &item {
44        Item::Struct(item) => {
45            let ty_name = &item.ident;
46            let ty_generics = &item.generics;
47            let ty_where_clause = item.generics.where_clause.as_ref().cloned();
48
49            // Find the field that is marked #[niche].  In a regular struct, extract its name; in a
50            // tuple struct, extract its index.  In both cases, that can be converted into a
51            // `Member`, which is the type needed down below in the field access expression.
52            let niche_field_name: Member;
53            let niche_field_type: &Type;
54            match &item.fields {
55                Fields::Named(fields) => {
56                    let niche_field = match fields.named.iter().find(field_is_niche) {
57                        Some(field) if field.ident.is_some() => field,
58                        _ => {
59                            let msg = "#[derive(Niche)] requires a field marked #[niche]";
60                            return syn::parse::Error::new_spanned(item, msg)
61                                .to_compile_error()
62                                .into();
63                        }
64                    };
65                    niche_field_name = niche_field.ident.as_ref().unwrap().clone().into();
66                    niche_field_type = &niche_field.ty;
67                }
68                Fields::Unnamed(fields) => {
69                    let (idx, niche_field) = match fields
70                        .unnamed
71                        .iter()
72                        .enumerate()
73                        .find(|(_, field)| field_is_niche(field))
74                    {
75                        Some((idx, field)) => (idx, field),
76                        None => {
77                            let msg = "#[derive(Niche)] requires a field marked #[niche]";
78                            return syn::parse::Error::new_spanned(item, msg)
79                                .to_compile_error()
80                                .into();
81                        }
82                    };
83                    niche_field_name = idx.into();
84                    niche_field_type = &niche_field.ty;
85                }
86                Fields::Unit => {
87                    let msg = "#[derive(Niche)] cannot be used on an empty tuple struct";
88                    return syn::parse::Error::new_spanned(item, msg)
89                        .to_compile_error()
90                        .into();
91                }
92            }
93
94            let where_clause = merge_where_clauses(
95                ty_where_clause,
96                parse_quote! { where #niche_field_type: ::controlled_option::Niche },
97            );
98
99            let output = quote! {
100                impl #ty_generics ::controlled_option::Niche for #ty_name #ty_generics
101                #where_clause
102                {
103                    type Output = ::std::mem::MaybeUninit<Self>;
104
105                    #[inline]
106                    fn none() -> Self::Output {
107                        let mut value = Self::Output::uninit();
108                        let ptr = value.as_mut_ptr();
109                        ::controlled_option::fill_struct_field_with_none(
110                            unsafe { ::std::ptr::addr_of_mut!((*ptr).#niche_field_name) }
111                        );
112                        value
113                    }
114
115                    #[inline]
116                    fn is_none(value: &Self::Output) -> bool {
117                        let ptr = value.as_ptr();
118                        ::controlled_option::struct_field_is_none(
119                            unsafe { ::std::ptr::addr_of!((*ptr).#niche_field_name) }
120                        )
121                    }
122
123                    #[inline]
124                    fn into_some(value: Self) -> Self::Output {
125                        ::std::mem::MaybeUninit::new(value)
126                    }
127
128                    #[inline]
129                    fn from_some(value: Self::Output) -> Self {
130                        unsafe { value.assume_init() }
131                    }
132                }
133            };
134            output.into()
135        }
136        _ => {
137            let msg = "#[derive(Niche)] is only supported on struct types";
138            syn::parse::Error::new_spanned(item, msg)
139                .to_compile_error()
140                .into()
141        }
142    }
143}