Skip to main content

simple_dst_derive/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote};
3use syn::{
4    Attribute, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, GenericParam, Generics,
5    Ident, Member, Path, Token, TraitBound, Type, TypeParamBound, Visibility, parse_macro_input,
6    parse_quote,
7};
8
/// The `#[repr(...)]` layouts supported by `#[derive(Dst)]`.
#[derive(Copy, Clone)]
enum Repr {
    /// `#[repr(C)]`: fields are laid out in declaration order.
    C,
    /// `#[repr(transparent)]`: the layout is that of the single non-trivial field.
    Transparent,
}
14
15fn get_repr(attrs: &[Attribute]) -> syn::Result<Repr> {
16    let mut repr = None;
17    for attr in attrs {
18        if !attr.path().is_ident("repr") {
19            continue;
20        }
21
22        if repr.is_some() {
23            return Err(syn::Error::new_spanned(
24                attr,
25                "only one #[repr(...)] allowed",
26            ));
27        }
28
29        attr.parse_nested_meta(|meta| {
30            if meta.path.is_ident("C") {
31                repr = Some(Repr::C);
32                Ok(())
33            } else if meta.path.is_ident("transparent") {
34                repr = Some(Repr::Transparent);
35                Ok(())
36            } else {
37                Err(meta.error("only #[repr(C)] and #[repr(transparent)] are supported"))
38            }
39        })?;
40    }
41    let Some(repr) = repr else {
42        return Err(syn::Error::new(
43            Span::call_site(),
44            "type must be #[repr(C)] or #[repr(transparent)]",
45        ));
46    };
47    Ok(repr)
48}
49
50fn get_fields(
51    data: &Data,
52) -> syn::Result<(
53    impl Iterator<Item = Member> + Clone,
54    impl Iterator<Item = &Type> + Clone,
55    usize,
56)> {
57    Ok(match data {
58        Data::Struct(DataStruct { fields, .. }) => {
59            (fields.members(), fields.iter().map(|f| &f.ty), fields.len())
60        }
61        Data::Enum(DataEnum { enum_token, .. }) => {
62            return Err(Error::new_spanned(enum_token, "only structs are supported"));
63        }
64        Data::Union(DataUnion { union_token, .. }) => {
65            return Err(Error::new_spanned(
66                union_token,
67                "only structs are supported",
68            ));
69        }
70    })
71}
72
73struct DstAttrs {
74    simple_dst_path: Path,
75    new_unchecked_vis: Visibility,
76}
77
78fn get_dst_attrs(attrs: &[Attribute]) -> syn::Result<DstAttrs> {
79    let mut simple_dst_path: Option<Path> = None;
80    let mut new_unchecked_vis: Option<Visibility> = None;
81    for attr in attrs {
82        if !attr.path().is_ident("dst") {
83            continue;
84        }
85
86        attr.parse_nested_meta(|meta| {
87            if meta.path.is_ident("simple_dst_path") {
88                if simple_dst_path.is_some() {
89                    return Err(meta.error("only one #[dst(simple_dst_path = ...)] is allowed"));
90                }
91                simple_dst_path = Some({
92                    meta.input.parse::<Token![=]>()?;
93                    meta.input.parse()?
94                });
95            } else if meta.path.is_ident("new_unchecked_vis") {
96                if new_unchecked_vis.is_some() {
97                    return Err(meta.error("only one #[dst(new_unchecked_vis = ...)] is allowed"));
98                }
99                new_unchecked_vis = Some({
100                    meta.input.parse::<Token![=]>()?;
101                    meta.input.parse()?
102                });
103            } else {
104                return Err(meta.error("unrecognised #[dst(...)] argument"));
105            }
106            Ok(())
107        })?;
108    }
109
110    let dst_attrs = DstAttrs {
111        simple_dst_path: simple_dst_path.unwrap_or_else(|| parse_quote! { ::simple_dst }),
112        new_unchecked_vis: new_unchecked_vis.unwrap_or(Visibility::Inherited),
113    };
114    Ok(dst_attrs)
115}
116
117fn has_unsized_bound<'a>(bounds: impl Iterator<Item = &'a TypeParamBound>) -> bool {
118    for bound in bounds {
119        if let TypeParamBound::Trait(TraitBound {
120            modifier: syn::TraitBoundModifier::Maybe(_),
121            lifetimes: None,
122            path,
123            ..
124        }) = bound
125            && path.is_ident("Sized")
126        {
127            return true;
128        }
129    }
130    false
131}
132
133fn add_dst_trait_bounds(mut generics: Generics, simple_dst_path: &Path) -> Generics {
134    for param in &mut generics.params {
135        if let GenericParam::Type(type_param) = param
136            && has_unsized_bound(type_param.bounds.iter())
137        {
138            type_param
139                .bounds
140                .push(parse_quote! { #simple_dst_path::Dst });
141            type_param
142                .bounds
143                .push(parse_quote! { #simple_dst_path::CloneToUninit });
144        }
145    }
146    generics
147}
148
149/// Derive macro for the `Dst` trait.
150///
151/// The underlying DST must be the last field of the struct.
152///
153/// This derive also creates a `new_unchecked` function, which takes each of the fields
154/// in the struct as arguments, with the last field (the DST) being taken as a reference.
155/// This `new_unchecked` function is marked as `unsafe` as it doesn't check any of the
156/// type's interior invariants. The visibility of this generated function can be modified
157/// with the `#[dst(new_unchecked_vis = ...)]` attribute.
158///
159/// The path to the `simple_dst` crate can be modified with the
160/// `#[dst(simple_dst_path = ...)]` attribute.
161///
162/// If there are any type parameters with a `?Sized` trait bound, those are assumed to
163/// be the type of the DST, so the `Dst` and `CloneToUninit` trait bounds will be added.
164#[proc_macro_derive(Dst, attributes(dst))]
165pub fn derive_dst(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
166    let input = parse_macro_input!(input as DeriveInput);
167    derive_dst_impl(input)
168        .unwrap_or_else(syn::Error::into_compile_error)
169        .into()
170}
171
172fn get_internal_layout_fn(
173    simple_dst_path: &Path,
174    repr: Repr,
175    n_fields: usize,
176    idxs: &[usize],
177    first_tys: &[&Type],
178    last_ty: Option<&Type>,
179) -> TokenStream {
180    match repr {
181        Repr::C => quote!(
182            {
183                let layouts = [#(::core::alloc::Layout::new::<#first_tys>()),*, <#last_ty as #simple_dst_path::Dst>::layout(len)?];
184                let mut offsets = [0; #n_fields];
185                let layout = ::core::alloc::Layout::from_size_align(0, 1)?;
186                #(
187                    let (layout, offset) = layout.extend(layouts[#idxs])?;
188                    offsets[#idxs] = offset;
189                )*
190                ::core::result::Result::Ok((layout.pad_to_align(), offsets))
191            }
192        ),
193        Repr::Transparent => quote!(
194            {
195                ::core::result::Result::Ok((<#last_ty as #simple_dst_path::Dst>::layout(len)?, [0; #n_fields]))
196            }
197        ),
198    }
199}
200
201fn derive_dst_impl(input: DeriveInput) -> syn::Result<TokenStream> {
202    let repr = get_repr(&input.attrs)?;
203
204    let name = input.ident;
205
206    let DstAttrs {
207        simple_dst_path,
208        new_unchecked_vis,
209    } = get_dst_attrs(&input.attrs)?;
210
211    let generics = add_dst_trait_bounds(input.generics, &simple_dst_path);
212    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
213
214    let (members, tys, n_fields) = get_fields(&input.data)?;
215    if n_fields == 0 {
216        return Err(Error::new_spanned(
217            name,
218            "type must have at least one field",
219        ));
220    }
221
222    let idxs: Vec<_> = (0..n_fields).collect();
223    let first_idxs: Vec<_> = (0..n_fields - 1).collect();
224    let last_idx = n_fields - 1;
225
226    let last_member = members.clone().nth(last_idx);
227
228    let member_var_names: Vec<_> = members
229        .clone()
230        .map(|m| match m {
231            Member::Named(ident) => ident,
232            Member::Unnamed(index) => format_ident!("var_{}", index),
233        })
234        .collect();
235    let first_member_var_names: Vec<_> = member_var_names.iter().take(n_fields - 1).collect();
236    let last_member_var_name = member_var_names.get(last_idx);
237
238    let first_tys: Vec<_> = tys.clone().take(n_fields - 1).collect();
239    let last_ty = tys.clone().nth(last_idx);
240
241    let internal_layout_fn =
242        get_internal_layout_fn(&simple_dst_path, repr, n_fields, &idxs, &first_tys, last_ty);
243
244    Ok(quote! {
245        #[automatically_derived]
246        unsafe impl #impl_generics #simple_dst_path::Dst for #name #ty_generics #where_clause {
247            fn len(&self) -> usize {
248                #simple_dst_path::Dst::len(&self.#last_member)
249            }
250
251            fn layout(len: usize) -> ::core::result::Result<::core::alloc::Layout, ::core::alloc::LayoutError> {
252                let (layout, _) = Self::__dst_impl_layout_offsets(len)?;
253                ::core::result::Result::Ok(layout)
254            }
255
256            fn retype(ptr: ::core::ptr::NonNull<u8>, len: usize) -> ::core::ptr::NonNull<Self> {
257                // FUTURE: switch to ptr::from_raw_parts_mut() when it has stabilised.
258                // SAFETY: the pointer value doesn't change when using `slice_from_raw_parts_mut`,
259                // so the invariants of `NonNull` are upheld
260                unsafe {
261                    #[allow(
262                        clippy::cast_ptr_alignment,
263                        reason = "the responsibility to provide a pointer with the correct alignment is on the caller"
264                    )]
265                    ::core::ptr::NonNull::new_unchecked(::core::ptr::slice_from_raw_parts_mut(ptr.as_ptr(), len) as *mut Self)
266                }
267            }
268        }
269
270        #[automatically_derived]
271        impl #impl_generics #name #ty_generics #where_clause {
272            #[doc(hidden)]
273            #[inline]
274            fn __dst_impl_layout_offsets(len: usize) -> ::core::result::Result<(::core::alloc::Layout, [usize; #n_fields]), ::core::alloc::LayoutError>
275            #internal_layout_fn
276
277            #new_unchecked_vis unsafe fn new_unchecked<A: #simple_dst_path::AllocDst<Self>>(
278                #( #first_member_var_names: #first_tys, )*
279                #last_member_var_name: &#last_ty
280            ) -> ::core::result::Result<A, ::core::alloc::LayoutError> {
281                let (layout, offsets) = Self::__dst_impl_layout_offsets(#last_member_var_name.len())?;
282                Ok(unsafe {
283                    A::new_dst(<#last_ty as #simple_dst_path::Dst>::len(#last_member_var_name), layout, |ptr| {
284                        let dest = ptr.cast::<u8>();
285
286                        <#last_ty as #simple_dst_path::CloneToUninit>::clone_to_uninit(#last_member_var_name, dest.add(offsets[#last_idx]).as_ptr());
287
288                        #(
289                            dest.add(offsets[#first_idxs]).cast::<#first_tys>().write(#first_member_var_names);
290                        )*
291                    })
292                })
293            }
294        }
295    })
296}
297
298fn add_clone_to_uninit_trait_bounds(mut generics: Generics, simple_dst_path: &Path) -> Generics {
299    for param in &mut generics.params {
300        if let GenericParam::Type(type_param) = param {
301            let bound = if has_unsized_bound(type_param.bounds.iter()) {
302                parse_quote! { #simple_dst_path::CloneToUninit }
303            } else {
304                parse_quote! { ::core::clone::Clone }
305            };
306            type_param.bounds.push(bound);
307        }
308    }
309    generics
310}
311
312/// Derive macro for the `CloneToUninit` trait for DSTs.
313///
314/// If there are any generic types with a `?Sized` trait bound, those are assumed to be
315/// the type of the DST, so the `CloneToUninit` trait bound will be added. All other
316/// generic types will have the [`Clone`] trait bound added.
317#[proc_macro_derive(CloneToUninit, attributes(dst))]
318pub fn derive_clone_to_uninit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
319    let input = parse_macro_input!(input as DeriveInput);
320    derive_clone_to_uninit_impl(input)
321        .unwrap_or_else(syn::Error::into_compile_error)
322        .into()
323}
324
325fn derive_clone_to_uninit_impl(input: DeriveInput) -> syn::Result<TokenStream> {
326    let name = input.ident;
327
328    let DstAttrs {
329        simple_dst_path, ..
330    } = get_dst_attrs(&input.attrs)?;
331
332    let generics = add_clone_to_uninit_trait_bounds(input.generics, &simple_dst_path);
333    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
334
335    let (members, tys, n_fields) = get_fields(&input.data)?;
336    if n_fields == 0 {
337        return Err(Error::new_spanned(
338            name,
339            "type must have at least one field",
340        ));
341    }
342
343    let last_idx = n_fields - 1;
344
345    let first_members: Vec<_> = members.clone().take(n_fields - 1).collect();
346    let last_member = members.clone().nth(last_idx);
347
348    let member_var_names: Vec<_> = members
349        .clone()
350        .map(|m| match m {
351            Member::Named(ident) => ident,
352            Member::Unnamed(index) => format_ident!("var_{}", index),
353        })
354        .collect();
355    let first_member_var_names: Vec<_> = member_var_names.iter().take(n_fields - 1).collect();
356
357    let first_tys: Vec<_> = tys.clone().take(n_fields - 1).collect();
358    let last_ty = tys.clone().nth(last_idx);
359
360    Ok(quote! {
361        #[automatically_derived]
362        unsafe impl #impl_generics #simple_dst_path::CloneToUninit for #name #ty_generics #where_clause {
363            unsafe fn clone_to_uninit(&self, dest: *mut u8) {
364                // SAFETY:
365                // * `&self.slice` >= `self` because `slice` is a field in `self`, and Self is
366                //   `#[repr(C)]`.
367                // * both pointers must be from the same allocation since they are within the
368                //   same object, and thus the memory range between them is also in bounds of
369                //   the object.
370                // * the distance between the pointers is an exact multiple of the size of u8.
371                let last_offset = unsafe { (&raw const self.#last_member).byte_offset_from_unsigned(self) };
372
373                #(
374                    let #first_member_var_names = <#first_tys as ::core::clone::Clone>::clone(&self.#first_members);
375                )*
376
377                unsafe {
378                    <#last_ty as #simple_dst_path::CloneToUninit>::clone_to_uninit(&self.#last_member, dest.add(last_offset));
379
380                    #(
381                        dest.add(::core::mem::offset_of!(Self, #first_member_var_names)).cast::<#first_tys>().write(#first_member_var_names);
382                    )*
383                }
384            }
385        }
386    })
387}
388
389struct ToOwnedAttrs {
390    alloc_path: Path,
391    owned: Type,
392}
393
394fn get_to_owned_attrs(attrs: &[Attribute], name: &Ident) -> syn::Result<ToOwnedAttrs> {
395    let mut alloc_path: Option<Path> = None;
396    let mut owned: Option<Type> = None;
397    for attr in attrs {
398        if !attr.path().is_ident("to_owned") {
399            continue;
400        }
401
402        attr.parse_nested_meta(|meta| {
403            if meta.path.is_ident("alloc_path") {
404                if alloc_path.is_some() {
405                    return Err(meta.error("only one #[to_owned(alloc_path = ...)] is allowed"));
406                }
407                alloc_path = Some({
408                    meta.input.parse::<Token![=]>()?;
409                    meta.input.parse()?
410                });
411            } else if meta.path.is_ident("owned") {
412                if owned.is_some() {
413                    return Err(meta.error("only one #[to_owned(owned = ...)] is allowed"));
414                }
415                owned = Some({
416                    meta.input.parse::<Token![=]>()?;
417                    meta.input.parse()?
418                });
419            } else {
420                return Err(meta.error("unrecognised #[to_owned(...)] argument"));
421            }
422            Ok(())
423        })?;
424    }
425
426    let alloc_path = alloc_path.unwrap_or_else(|| parse_quote! { ::std });
427    let to_owned_attrs = ToOwnedAttrs {
428        alloc_path: alloc_path.clone(),
429        owned: owned.unwrap_or_else(|| parse_quote! { #alloc_path::boxed::Box<#name> }),
430    };
431    Ok(to_owned_attrs)
432}
433
434#[proc_macro_derive(ToOwned, attributes(dst, to_owned))]
435pub fn derive_to_owned(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
436    let input = parse_macro_input!(input as DeriveInput);
437    derive_to_owned_impl(input)
438        .unwrap_or_else(syn::Error::into_compile_error)
439        .into()
440}
441
442fn derive_to_owned_impl(input: DeriveInput) -> syn::Result<TokenStream> {
443    let name = input.ident;
444
445    let DstAttrs {
446        simple_dst_path, ..
447    } = get_dst_attrs(&input.attrs)?;
448    let ToOwnedAttrs { alloc_path, owned } = get_to_owned_attrs(&input.attrs, &name)?;
449
450    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
451
452    Ok(quote! {
453        #[automatically_derived]
454        impl #impl_generics #alloc_path::borrow::ToOwned for #name #ty_generics #where_clause {
455            type Owned = #owned;
456
457            fn to_owned(&self) -> Self::Owned {
458                let layout = ::core::alloc::Layout::for_value(self);
459
460                unsafe {
461                    <#owned as #simple_dst_path::AllocDst<#name>>::new_dst(
462                        <#name as #simple_dst_path::Dst>::len(self),
463                        layout,
464                        |ptr| {
465                            let dest = ptr.cast::<u8>();
466
467                            <#name as #simple_dst_path::CloneToUninit>::clone_to_uninit(self, dest.as_ptr())
468                        },
469                    )
470                }
471            }
472        }
473    })
474}