1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote};
3use syn::{
4 Attribute, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, GenericParam, Generics,
5 Ident, Member, Path, Token, TraitBound, Type, TypeParamBound, Visibility, parse_macro_input,
6 parse_quote,
7};
8
/// The `#[repr(...)]` hints this crate knows how to lay out.
#[derive(Copy, Clone)]
enum Repr {
    /// `#[repr(C)]`: field offsets are computed one field at a time, in declaration order.
    C,
    /// `#[repr(transparent)]`: the layout is taken from the last field alone.
    Transparent,
}
14
15fn get_repr(attrs: &[Attribute]) -> syn::Result<Repr> {
16 let mut repr = None;
17 for attr in attrs {
18 if !attr.path().is_ident("repr") {
19 continue;
20 }
21
22 if repr.is_some() {
23 return Err(syn::Error::new_spanned(
24 attr,
25 "only one #[repr(...)] allowed",
26 ));
27 }
28
29 attr.parse_nested_meta(|meta| {
30 if meta.path.is_ident("C") {
31 repr = Some(Repr::C);
32 Ok(())
33 } else if meta.path.is_ident("transparent") {
34 repr = Some(Repr::Transparent);
35 Ok(())
36 } else {
37 Err(meta.error("only #[repr(C)] and #[repr(transparent)] are supported"))
38 }
39 })?;
40 }
41 let Some(repr) = repr else {
42 return Err(syn::Error::new(
43 Span::call_site(),
44 "type must be #[repr(C)] or #[repr(transparent)]",
45 ));
46 };
47 Ok(repr)
48}
49
50fn get_fields(
51 data: &Data,
52) -> syn::Result<(
53 impl Iterator<Item = Member> + Clone,
54 impl Iterator<Item = &Type> + Clone,
55 usize,
56)> {
57 Ok(match data {
58 Data::Struct(DataStruct { fields, .. }) => {
59 (fields.members(), fields.iter().map(|f| &f.ty), fields.len())
60 }
61 Data::Enum(DataEnum { enum_token, .. }) => {
62 return Err(Error::new_spanned(enum_token, "only structs are supported"));
63 }
64 Data::Union(DataUnion { union_token, .. }) => {
65 return Err(Error::new_spanned(
66 union_token,
67 "only structs are supported",
68 ));
69 }
70 })
71}
72
73struct DstAttrs {
74 simple_dst_path: Path,
75 new_unchecked_vis: Visibility,
76}
77
78fn get_dst_attrs(attrs: &[Attribute]) -> syn::Result<DstAttrs> {
79 let mut simple_dst_path: Option<Path> = None;
80 let mut new_unchecked_vis: Option<Visibility> = None;
81 for attr in attrs {
82 if !attr.path().is_ident("dst") {
83 continue;
84 }
85
86 attr.parse_nested_meta(|meta| {
87 if meta.path.is_ident("simple_dst_path") {
88 if simple_dst_path.is_some() {
89 return Err(meta.error("only one #[dst(simple_dst_path = ...)] is allowed"));
90 }
91 simple_dst_path = Some({
92 meta.input.parse::<Token![=]>()?;
93 meta.input.parse()?
94 });
95 } else if meta.path.is_ident("new_unchecked_vis") {
96 if new_unchecked_vis.is_some() {
97 return Err(meta.error("only one #[dst(new_unchecked_vis = ...)] is allowed"));
98 }
99 new_unchecked_vis = Some({
100 meta.input.parse::<Token![=]>()?;
101 meta.input.parse()?
102 });
103 } else {
104 return Err(meta.error("unrecognised #[dst(...)] argument"));
105 }
106 Ok(())
107 })?;
108 }
109
110 let dst_attrs = DstAttrs {
111 simple_dst_path: simple_dst_path.unwrap_or_else(|| parse_quote! { ::simple_dst }),
112 new_unchecked_vis: new_unchecked_vis.unwrap_or(Visibility::Inherited),
113 };
114 Ok(dst_attrs)
115}
116
117fn has_unsized_bound<'a>(bounds: impl Iterator<Item = &'a TypeParamBound>) -> bool {
118 for bound in bounds {
119 if let TypeParamBound::Trait(TraitBound {
120 modifier: syn::TraitBoundModifier::Maybe(_),
121 lifetimes: None,
122 path,
123 ..
124 }) = bound
125 && path.is_ident("Sized")
126 {
127 return true;
128 }
129 }
130 false
131}
132
133fn add_dst_trait_bounds(mut generics: Generics, simple_dst_path: &Path) -> Generics {
134 for param in &mut generics.params {
135 if let GenericParam::Type(type_param) = param
136 && has_unsized_bound(type_param.bounds.iter())
137 {
138 type_param
139 .bounds
140 .push(parse_quote! { #simple_dst_path::Dst });
141 type_param
142 .bounds
143 .push(parse_quote! { #simple_dst_path::CloneToUninit });
144 }
145 }
146 generics
147}
148
149#[proc_macro_derive(Dst, attributes(dst))]
165pub fn derive_dst(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
166 let input = parse_macro_input!(input as DeriveInput);
167 derive_dst_impl(input)
168 .unwrap_or_else(syn::Error::into_compile_error)
169 .into()
170}
171
172fn get_internal_layout_fn(
173 simple_dst_path: &Path,
174 repr: Repr,
175 n_fields: usize,
176 idxs: &[usize],
177 first_tys: &[&Type],
178 last_ty: Option<&Type>,
179) -> TokenStream {
180 match repr {
181 Repr::C => quote!(
182 {
183 let layouts = [#(::core::alloc::Layout::new::<#first_tys>()),*, <#last_ty as #simple_dst_path::Dst>::layout(len)?];
184 let mut offsets = [0; #n_fields];
185 let layout = ::core::alloc::Layout::from_size_align(0, 1)?;
186 #(
187 let (layout, offset) = layout.extend(layouts[#idxs])?;
188 offsets[#idxs] = offset;
189 )*
190 ::core::result::Result::Ok((layout.pad_to_align(), offsets))
191 }
192 ),
193 Repr::Transparent => quote!(
194 {
195 ::core::result::Result::Ok((<#last_ty as #simple_dst_path::Dst>::layout(len)?, [0; #n_fields]))
196 }
197 ),
198 }
199}
200
201fn derive_dst_impl(input: DeriveInput) -> syn::Result<TokenStream> {
202 let repr = get_repr(&input.attrs)?;
203
204 let name = input.ident;
205
206 let DstAttrs {
207 simple_dst_path,
208 new_unchecked_vis,
209 } = get_dst_attrs(&input.attrs)?;
210
211 let generics = add_dst_trait_bounds(input.generics, &simple_dst_path);
212 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
213
214 let (members, tys, n_fields) = get_fields(&input.data)?;
215 if n_fields == 0 {
216 return Err(Error::new_spanned(
217 name,
218 "type must have at least one field",
219 ));
220 }
221
222 let idxs: Vec<_> = (0..n_fields).collect();
223 let first_idxs: Vec<_> = (0..n_fields - 1).collect();
224 let last_idx = n_fields - 1;
225
226 let last_member = members.clone().nth(last_idx);
227
228 let member_var_names: Vec<_> = members
229 .clone()
230 .map(|m| match m {
231 Member::Named(ident) => ident,
232 Member::Unnamed(index) => format_ident!("var_{}", index),
233 })
234 .collect();
235 let first_member_var_names: Vec<_> = member_var_names.iter().take(n_fields - 1).collect();
236 let last_member_var_name = member_var_names.get(last_idx);
237
238 let first_tys: Vec<_> = tys.clone().take(n_fields - 1).collect();
239 let last_ty = tys.clone().nth(last_idx);
240
241 let internal_layout_fn =
242 get_internal_layout_fn(&simple_dst_path, repr, n_fields, &idxs, &first_tys, last_ty);
243
244 Ok(quote! {
245 #[automatically_derived]
246 unsafe impl #impl_generics #simple_dst_path::Dst for #name #ty_generics #where_clause {
247 fn len(&self) -> usize {
248 #simple_dst_path::Dst::len(&self.#last_member)
249 }
250
251 fn layout(len: usize) -> ::core::result::Result<::core::alloc::Layout, ::core::alloc::LayoutError> {
252 let (layout, _) = Self::__dst_impl_layout_offsets(len)?;
253 ::core::result::Result::Ok(layout)
254 }
255
256 fn retype(ptr: ::core::ptr::NonNull<u8>, len: usize) -> ::core::ptr::NonNull<Self> {
257 unsafe {
261 #[allow(
262 clippy::cast_ptr_alignment,
263 reason = "the responsibility to provide a pointer with the correct alignment is on the caller"
264 )]
265 ::core::ptr::NonNull::new_unchecked(::core::ptr::slice_from_raw_parts_mut(ptr.as_ptr(), len) as *mut Self)
266 }
267 }
268 }
269
270 #[automatically_derived]
271 impl #impl_generics #name #ty_generics #where_clause {
272 #[doc(hidden)]
273 #[inline]
274 fn __dst_impl_layout_offsets(len: usize) -> ::core::result::Result<(::core::alloc::Layout, [usize; #n_fields]), ::core::alloc::LayoutError>
275 #internal_layout_fn
276
277 #new_unchecked_vis unsafe fn new_unchecked<A: #simple_dst_path::AllocDst<Self>>(
278 #( #first_member_var_names: #first_tys, )*
279 #last_member_var_name: &#last_ty
280 ) -> ::core::result::Result<A, ::core::alloc::LayoutError> {
281 let (layout, offsets) = Self::__dst_impl_layout_offsets(#last_member_var_name.len())?;
282 Ok(unsafe {
283 A::new_dst(<#last_ty as #simple_dst_path::Dst>::len(#last_member_var_name), layout, |ptr| {
284 let dest = ptr.cast::<u8>();
285
286 <#last_ty as #simple_dst_path::CloneToUninit>::clone_to_uninit(#last_member_var_name, dest.add(offsets[#last_idx]).as_ptr());
287
288 #(
289 dest.add(offsets[#first_idxs]).cast::<#first_tys>().write(#first_member_var_names);
290 )*
291 })
292 })
293 }
294 }
295 })
296}
297
298fn add_clone_to_uninit_trait_bounds(mut generics: Generics, simple_dst_path: &Path) -> Generics {
299 for param in &mut generics.params {
300 if let GenericParam::Type(type_param) = param {
301 let bound = if has_unsized_bound(type_param.bounds.iter()) {
302 parse_quote! { #simple_dst_path::CloneToUninit }
303 } else {
304 parse_quote! { ::core::clone::Clone }
305 };
306 type_param.bounds.push(bound);
307 }
308 }
309 generics
310}
311
312#[proc_macro_derive(CloneToUninit, attributes(dst))]
318pub fn derive_clone_to_uninit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
319 let input = parse_macro_input!(input as DeriveInput);
320 derive_clone_to_uninit_impl(input)
321 .unwrap_or_else(syn::Error::into_compile_error)
322 .into()
323}
324
325fn derive_clone_to_uninit_impl(input: DeriveInput) -> syn::Result<TokenStream> {
326 let name = input.ident;
327
328 let DstAttrs {
329 simple_dst_path, ..
330 } = get_dst_attrs(&input.attrs)?;
331
332 let generics = add_clone_to_uninit_trait_bounds(input.generics, &simple_dst_path);
333 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
334
335 let (members, tys, n_fields) = get_fields(&input.data)?;
336 if n_fields == 0 {
337 return Err(Error::new_spanned(
338 name,
339 "type must have at least one field",
340 ));
341 }
342
343 let last_idx = n_fields - 1;
344
345 let first_members: Vec<_> = members.clone().take(n_fields - 1).collect();
346 let last_member = members.clone().nth(last_idx);
347
348 let member_var_names: Vec<_> = members
349 .clone()
350 .map(|m| match m {
351 Member::Named(ident) => ident,
352 Member::Unnamed(index) => format_ident!("var_{}", index),
353 })
354 .collect();
355 let first_member_var_names: Vec<_> = member_var_names.iter().take(n_fields - 1).collect();
356
357 let first_tys: Vec<_> = tys.clone().take(n_fields - 1).collect();
358 let last_ty = tys.clone().nth(last_idx);
359
360 Ok(quote! {
361 #[automatically_derived]
362 unsafe impl #impl_generics #simple_dst_path::CloneToUninit for #name #ty_generics #where_clause {
363 unsafe fn clone_to_uninit(&self, dest: *mut u8) {
364 let last_offset = unsafe { (&raw const self.#last_member).byte_offset_from_unsigned(self) };
372
373 #(
374 let #first_member_var_names = <#first_tys as ::core::clone::Clone>::clone(&self.#first_members);
375 )*
376
377 unsafe {
378 <#last_ty as #simple_dst_path::CloneToUninit>::clone_to_uninit(&self.#last_member, dest.add(last_offset));
379
380 #(
381 dest.add(::core::mem::offset_of!(Self, #first_member_var_names)).cast::<#first_tys>().write(#first_member_var_names);
382 )*
383 }
384 }
385 }
386 })
387}
388
389struct ToOwnedAttrs {
390 alloc_path: Path,
391 owned: Type,
392}
393
394fn get_to_owned_attrs(attrs: &[Attribute], name: &Ident) -> syn::Result<ToOwnedAttrs> {
395 let mut alloc_path: Option<Path> = None;
396 let mut owned: Option<Type> = None;
397 for attr in attrs {
398 if !attr.path().is_ident("to_owned") {
399 continue;
400 }
401
402 attr.parse_nested_meta(|meta| {
403 if meta.path.is_ident("alloc_path") {
404 if alloc_path.is_some() {
405 return Err(meta.error("only one #[to_owned(alloc_path = ...)] is allowed"));
406 }
407 alloc_path = Some({
408 meta.input.parse::<Token![=]>()?;
409 meta.input.parse()?
410 });
411 } else if meta.path.is_ident("owned") {
412 if owned.is_some() {
413 return Err(meta.error("only one #[to_owned(owned = ...)] is allowed"));
414 }
415 owned = Some({
416 meta.input.parse::<Token![=]>()?;
417 meta.input.parse()?
418 });
419 } else {
420 return Err(meta.error("unrecognised #[to_owned(...)] argument"));
421 }
422 Ok(())
423 })?;
424 }
425
426 let alloc_path = alloc_path.unwrap_or_else(|| parse_quote! { ::std });
427 let to_owned_attrs = ToOwnedAttrs {
428 alloc_path: alloc_path.clone(),
429 owned: owned.unwrap_or_else(|| parse_quote! { #alloc_path::boxed::Box<#name> }),
430 };
431 Ok(to_owned_attrs)
432}
433
434#[proc_macro_derive(ToOwned, attributes(dst, to_owned))]
435pub fn derive_to_owned(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
436 let input = parse_macro_input!(input as DeriveInput);
437 derive_to_owned_impl(input)
438 .unwrap_or_else(syn::Error::into_compile_error)
439 .into()
440}
441
442fn derive_to_owned_impl(input: DeriveInput) -> syn::Result<TokenStream> {
443 let name = input.ident;
444
445 let DstAttrs {
446 simple_dst_path, ..
447 } = get_dst_attrs(&input.attrs)?;
448 let ToOwnedAttrs { alloc_path, owned } = get_to_owned_attrs(&input.attrs, &name)?;
449
450 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
451
452 Ok(quote! {
453 #[automatically_derived]
454 impl #impl_generics #alloc_path::borrow::ToOwned for #name #ty_generics #where_clause {
455 type Owned = #owned;
456
457 fn to_owned(&self) -> Self::Owned {
458 let layout = ::core::alloc::Layout::for_value(self);
459
460 unsafe {
461 <#owned as #simple_dst_path::AllocDst<#name>>::new_dst(
462 <#name as #simple_dst_path::Dst>::len(self),
463 layout,
464 |ptr| {
465 let dest = ptr.cast::<u8>();
466
467 <#name as #simple_dst_path::CloneToUninit>::clone_to_uninit(self, dest.as_ptr())
468 },
469 )
470 }
471 }
472 }
473 })
474}