1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
//! Derives instances of PartialRefTarget and associated traits for the `partial_ref` crate.
#![recursion_limit = "128"]
extern crate proc_macro;

use std::collections::HashSet;

use crate::proc_macro::TokenStream;

use proc_macro2::{Span, TokenTree};
use quote::{quote, ToTokens};
use syn::{
    parse_macro_input, parse_quote, parse_str, Attribute, Data, DeriveInput, Lifetime, LifetimeDef,
    Lit, Member, Meta, Type, TypeParen,
};

fn parse_attribute_as_type(attr: &Attribute) -> Type {
    if let Some(TokenTree::Group(group)) = attr.tokens.clone().into_iter().next() {
        let parsed_type: Type = parse_quote!(#group);
        // This avoids unnecessary parentheses around type warnings from the generated code.
        if let Type::Paren(TypeParen { elem, .. }) = parsed_type {
            return *elem;
        }
        return parsed_type;
    }

    let parse_panic = || panic!("could not parse attribute `{}`", attr.tokens.to_string());
    let meta = attr.parse_meta().unwrap_or_else(|_| parse_panic());
    if let Meta::NameValue(name_value) = meta {
        if let Lit::Str(string) = name_value.lit {
            match parse_str(&string.value()) {
                Err(_) => panic!("could not parse type `{}` in attribute", string.value()),
                Ok(parsed_type) => return parsed_type,
            }
        }
    }
    parse_panic();
    unreachable!()
}

/// If the input is non-empty remove the enclosing `<` and `>` and prepend a comma.
///
/// Does not check whether the enclosing tokens actually are `<` and `>`.
fn generics_to_extra_generics(generics: &impl ToTokens) -> proc_macro2::TokenStream {
    let mut generics_tokens = proc_macro2::TokenStream::new();
    generics.to_tokens(&mut generics_tokens);

    let mut generics_tokens = generics_tokens.into_iter().collect::<Vec<_>>();

    if !generics_tokens.is_empty() {
        generics_tokens[0] = quote!(,).into_iter().next().unwrap();
        generics_tokens.pop();
    }

    let mut extra_tokens = proc_macro2::TokenStream::new();
    extra_tokens.extend(generics_tokens);
    extra_tokens
}

/// Returns the lifetime `'name`, or the first of `'name1`, `'name2`, ... whose identifier does not
/// appear among the given lifetime definitions.
fn fresh_lifetime<'a>(lifetimes: impl Iterator<Item = &'a LifetimeDef>, name: &str) -> Lifetime {
    let taken: HashSet<String> = lifetimes
        .map(|def| def.lifetime.ident.to_string())
        .collect();

    // Candidates in order: `name`, `name1`, `name2`, ... The sequence is unbounded, so `find`
    // always succeeds.
    let unused = std::iter::once(name.to_owned())
        .chain((1..).map(|suffix| format!("{}{}", name, suffix)))
        .find(|candidate| !taken.contains(candidate))
        .unwrap();

    Lifetime::new(&format!("'{}", unused), Span::call_site())
}

/// Derives instances of PartialRefTarget and associated traits.
///
/// Can only be used for structs. The attribute `#[part(PartName)]` can be used on the struct itself
/// for an abstract part or on a field for a field part. Parts have to be declared separately.
/// `PartName` can be any valid rust type that implements the Part trait. For fields the field type
/// of the part has to match the actual type of the field.
///
/// Example:
///
/// ```ignore
/// use partial_ref::{PartialRefTarget, part};
///
/// #[derive(PartialRefTarget)]
/// #[part(SomeAbstractPart)]
/// struct ExampleStruct {
///     field_without_part: usize,
///     #[part(SomeFieldPart)]
///     a: usize,
///     #[part(another_crate::AnotherFieldPart)]
///     b: usize,
/// }
/// ```
///
/// Instead of `#[part(PartName)]` it is also possible to use `#[part = "PartName"]` which was the
/// only supported syntax in previous versions of this crate.
///

// TODO figure out how to link to doc items of the partial_ref crate
#[proc_macro_derive(PartialRefTarget, attributes(part))]
pub fn derive_partial_ref_target(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let target_ident = input.ident;

    let lt_a = fresh_lifetime(input.generics.lifetimes(), "a");

    let (impl_generics, target_generics, where_clause) = input.generics.split_for_impl();

    if where_clause.is_some() {
        panic!("cannot derive PartialRef target for structs with a where clause");
        // TODO lift this restriction
    }

    let extra_generics = generics_to_extra_generics(&impl_generics);

    let target_type = quote!(#target_ident #target_generics);

    let data_struct = match input.data {
        Data::Struct(data_struct) => data_struct,
        _ => panic!("deriving PartialRefTarget is only supported on structs"),
    };

    let mut abstract_parts: Vec<Type> = vec![];
    let mut typed_parts: Vec<(Member, Type)> = vec![];

    for attr in input.attrs.iter() {
        if attr.path.is_ident("part") {
            abstract_parts.push(parse_attribute_as_type(&attr));
        }
    }

    for (field_index, field) in data_struct.fields.iter().enumerate() {
        let mut part: Option<Type> = None;

        for attr in field.attrs.iter() {
            if attr.path.is_ident("part") {
                if part.is_some() {
                    panic!(
                        "{} has multiple parts",
                        field
                            .ident
                            .as_ref()
                            .map_or("unnamed field".to_owned(), |i| format!("field `{}`", i))
                    );
                }
                part = Some(parse_attribute_as_type(&attr));
            }
        }

        if let Some(part_type) = part {
            let member = field
                .ident
                .as_ref()
                .map_or(Member::Unnamed(field_index.into()), |ident| {
                    Member::Named(ident.clone())
                });
            typed_parts.push((member, part_type));
        }
    }

    let mut const_type = quote!(::partial_ref::Ref<#lt_a, #target_type>);
    let mut mut_type = quote!(::partial_ref::Ref<#lt_a, #target_type>);
    let mut split_const_type = quote!(Reference);
    let mut split_mut_type = quote!(Reference);

    for part in abstract_parts.iter() {
        const_type = quote!(::partial_ref::Const<#part, #const_type>);
        mut_type = quote!(::partial_ref::Mut<#part, #mut_type>);

        split_const_type = quote!(
            ::partial_ref::Const<
                ::partial_ref::Nested<ContainingPart, #part>,
                #split_const_type
            >
        );
        split_mut_type = quote!(
            ::partial_ref::Mut<
                ::partial_ref::Nested<ContainingPart, #part>,
                #split_mut_type
            >
        );
    }

    for (_, part) in typed_parts.iter() {
        const_type = quote!(::partial_ref::Const<#part, #const_type>);
        mut_type = quote!(::partial_ref::Mut<#part, #mut_type>);

        split_const_type = quote!(
            ::partial_ref::Const<
                ::partial_ref::Nested<ContainingPart, #part>,
                #split_const_type
            >
        );
        split_mut_type = quote!(
            ::partial_ref::Mut<
                ::partial_ref::Nested<ContainingPart, #part>,
                #split_mut_type
            >
        );
    }

    let mut result = vec![];

    result.push(TokenStream::from(quote! {
        impl<#lt_a #extra_generics> ::partial_ref::IntoPartialRef<#lt_a> for &#lt_a #target_type {
            type Ref = #const_type;
            #[inline(always)]
            fn into_partial_ref(self) -> Self::Ref {
                unsafe {
                    <Self::Ref as ::partial_ref::PartialRef>::from_raw(self as *const _ as *mut _)
                }
            }
        }

        impl<#lt_a #extra_generics> ::partial_ref::IntoPartialRef<#lt_a>
        for &#lt_a mut #target_type {
            type Ref = #mut_type;
            #[inline(always)]
            fn into_partial_ref(self) -> Self::Ref {
                unsafe {
                    <Self::Ref as ::partial_ref::PartialRef>::from_raw(self as *mut _)
                }
            }
        }

        unsafe impl<#lt_a #extra_generics, ContainingPart, Reference>
            ::partial_ref::SplitIntoParts<#lt_a, ContainingPart, Reference> for #target_type
        where
            ContainingPart: ::partial_ref::Part<PartType=::partial_ref::Field<Self>>,
            Reference: ::partial_ref::PartialRef<#lt_a>,
            Reference::Target: ::partial_ref::HasPart<ContainingPart>,
        {
            type Result = #split_const_type;
            type ResultMut = #split_mut_type;
        }

        impl #impl_generics ::partial_ref::PartialRefTarget for #target_type {
            type RawTarget = Self;
        }
    }));

    for part in abstract_parts.iter() {
        result.push(TokenStream::from(quote! {
             impl #impl_generics ::partial_ref::HasPart<#part> for #target_type {
                #[inline(always)]
                unsafe fn part_ptr(ptr: *const Self) -> () {
                    unreachable!()
                }

                #[inline(always)]
                unsafe fn part_ptr_mut(ptr: *mut Self) -> () {
                    unreachable!()
                }
            }
        }));
    }

    for (member, part) in typed_parts.iter() {
        result.push(TokenStream::from(quote! {
             impl #impl_generics ::partial_ref::HasPart<#part> for #target_type {
                #[inline(always)]
                unsafe fn part_ptr(
                    ptr: *const Self
                ) -> <<#part as ::partial_ref::Part>::PartType as ::partial_ref::PartType>::Ptr  {
                    &(*ptr).#member as *const _
                }

                #[inline(always)]
                unsafe fn part_ptr_mut(
                    ptr: *mut Self
                ) -> <<#part as ::partial_ref::Part>::PartType as ::partial_ref::PartType>::PtrMut {
                    &mut (*ptr).#member as *mut _
                }
            }
        }));
    }

    result.into_iter().collect()
}