1extern crate proc_macro;
186mod macro_args;
187
188use macro_args::MacroArgs;
189use proc_macro2::{Span, TokenStream};
190use quote::{ToTokens, format_ident, quote, quote_spanned};
191use syn::token::{Comma, Where};
192use syn::{
193 Field, Fields, Generics, Ident, Index, ItemStruct, Path, Type, TypePath, parse::Result as SynResult, punctuated::Punctuated,
194 spanned::Spanned,
195};
196
197enum TailKind {
198 Slice(Box<Type>),
199 Str,
200 TraitObject(Path),
201}
202
203enum FieldIdent {
204 Named(Ident),
205 Unnamed(Index),
206}
207
208impl ToTokens for FieldIdent {
209 fn to_tokens(&self, tokens: &mut TokenStream) {
210 match self {
211 Self::Named(ident) => ident.to_tokens(tokens),
212 Self::Unnamed(index) => index.to_tokens(tokens),
213 }
214 }
215}
216
217struct StructInfo<'a> {
218 struct_name: &'a Ident,
219 struct_generics: &'a Generics,
220 header_fields: Box<[&'a Field]>,
221 header_field_idents: Box<[FieldIdent]>,
222 header_param_idents: Box<[Ident]>,
223 tail_field: &'a Field,
224 tail_field_ident: FieldIdent,
225 tail_param_ident: Ident,
226 tail_kind: TailKind,
227}
228
229impl<'a> StructInfo<'a> {
230 fn new(input_struct: &'a ItemStruct) -> SynResult<Self> {
231 match &input_struct.fields {
232 Fields::Named(named_fields) => Self::process_fields(input_struct, &named_fields.named),
233 Fields::Unnamed(unnamed_fields) => Self::process_fields(input_struct, &unnamed_fields.unnamed),
234 Fields::Unit => Err(syn::Error::new_spanned(input_struct, "Unit structs are not supported")),
235 }
236 }
237
238 fn process_fields(input_struct: &'a ItemStruct, fields: &'a Punctuated<Field, Comma>) -> SynResult<Self> {
239 if fields.is_empty() {
240 return Err(syn::Error::new_spanned(input_struct, "Struct must have at least one field"));
241 }
242
243 let header_fields: Vec<_> = fields.iter().take(fields.len() - 1).collect();
244 let tail_field = fields.last().unwrap();
245
246 let tail_kind = match &tail_field.ty {
247 Type::Path(type_path) if type_path.path.is_ident("str") => TailKind::Str,
248 Type::Path(TypePath { path, .. }) if path.segments.last().is_some_and(|s| s.ident == "str") => TailKind::Str,
249 Type::Slice(type_slice) => TailKind::Slice(type_slice.elem.clone()),
250
251 Type::TraitObject(type_trait_object) => {
252 for bound in &type_trait_object.bounds {
254 if let syn::TypeParamBound::Trait(trait_bound) = bound {
255 if trait_bound.lifetimes.is_some() {
256 return Err(syn::Error::new_spanned(
257 trait_bound,
258 "Higher-rank trait bounds (e.g., `for<'a> dyn Trait<'a>`) are not supported for the tail field.",
259 ));
260 }
261 }
262 }
263
264 let trait_path = type_trait_object
266 .bounds
267 .iter()
268 .find_map(|bound| {
269 if let syn::TypeParamBound::Trait(trait_bound) = bound {
270 Some(&trait_bound.path)
271 } else {
272 None }
274 })
275 .unwrap();
276
277 TailKind::TraitObject(trait_path.clone())
278 }
279
280 _ => {
281 return Err(syn::Error::new_spanned(
282 &tail_field.ty,
283 "Last field must be a dynamically sized type like [T], str, or dyn Trait",
284 ));
285 }
286 };
287
288 let header_param_idents: Vec<_> = header_fields
289 .iter()
290 .enumerate()
291 .map(|(i, field)| field.ident.as_ref().map_or_else(|| format_ident!("f{i}"), Clone::clone))
292 .collect();
293
294 let header_field_idents: Vec<_> = header_fields
295 .iter()
296 .enumerate()
297 .map(|(i, field)| {
298 field
299 .ident
300 .as_ref()
301 .map_or_else(|| FieldIdent::Unnamed(Index::from(i)), |ident| FieldIdent::Named(ident.clone()))
302 })
303 .collect();
304
305 let tail_param_ident = tail_field
306 .ident
307 .as_ref()
308 .map_or_else(|| format_ident!("f{}", header_fields.len()), Clone::clone);
309
310 let tail_field_ident = tail_field.ident.as_ref().map_or_else(
311 || FieldIdent::Unnamed(Index::from(header_fields.len())),
312 |ident| FieldIdent::Named(ident.clone()),
313 );
314
315 Ok(Self {
316 struct_name: &input_struct.ident,
317 struct_generics: &input_struct.generics,
318 header_fields: header_fields.into_boxed_slice(),
319 header_field_idents: header_field_idents.into_boxed_slice(),
320 header_param_idents: header_param_idents.into_boxed_slice(),
321 tail_field,
322 tail_field_ident,
323 tail_param_ident,
324 tail_kind,
325 })
326 }
327}
328
329fn header_layout(macro_args: &MacroArgs, struct_info: &StructInfo, for_trait: bool) -> TokenStream {
330 let tail_field_ident = &struct_info.tail_field_ident;
331
332 let header_field_types: Vec<_> = struct_info.header_fields.iter().map(|field| &field.ty).collect();
333
334 if header_field_types.is_empty() {
335 return quote! {
336 let layout = ::core::alloc::Layout::from_size_align(0, 1).unwrap();
337 };
338 }
339
340 let fat_payload = if for_trait {
341 quote! { vtable }
342 } else {
343 quote! { 0_usize }
344 };
345
346 let tail_type = match &struct_info.tail_kind {
347 TailKind::Slice(elem_type) => quote! { #elem_type },
348 TailKind::Str => quote! { u8 },
349 TailKind::TraitObject(_) => {
350 let generic_name = ¯o_args.generic_name;
351 quote! { #generic_name }
352 }
353 };
354
355 quote! {
356 let buffer = ::core::mem::MaybeUninit::<(#( #header_field_types, )* #tail_type)>::uninit();
357 let (offset, align) = unsafe {
358 let head_ptr: *const Self = ::core::mem::transmute((&raw const buffer, #fat_payload));
359 let tail_ptr = &raw const (*head_ptr).#tail_field_ident;
360 (
361 (tail_ptr as *const u8).offset_from_unsigned(head_ptr as *const u8),
362 ::core::mem::align_of_val::<Self>(&*head_ptr)
363 )
364 };
365
366 let layout = ::core::alloc::Layout::from_size_align(offset, align).unwrap();
367 }
368}
369
370fn tail_layout<T: quote::ToTokens>(tail_type: &T, span: Span) -> TokenStream {
371 quote_spanned! { span => ::core::alloc::Layout::array::<#tail_type>(len).expect("Array exceeds maximum size allowed of isize::MAX") }
372}
373
374fn guard_type(macro_args: &MacroArgs) -> TokenStream {
375 let dealloc_path = if macro_args.no_std {
376 quote! { ::alloc::alloc::dealloc }
377 } else {
378 quote! { ::std::alloc::dealloc }
379 };
380
381 quote! {
382 struct Guard<T> {
383 mem_ptr: *mut u8,
384 tail_ptr: *mut T,
385 initialized: usize,
386 layout: ::core::alloc::Layout,
387 }
388
389 impl<T> Drop for Guard<T> {
390 fn drop(&mut self) {
391 unsafe {
392 let slice_ptr = ::core::ptr::slice_from_raw_parts_mut(self.tail_ptr, self.initialized);
393 ::core::ptr::drop_in_place(slice_ptr);
394 #dealloc_path(self.mem_ptr, self.layout);
395 }
396 }
397 }
398 }
399}
400
401fn header_params(struct_info: &StructInfo) -> Vec<TokenStream> {
402 struct_info
403 .header_fields
404 .iter()
405 .enumerate()
406 .map(|(i, field)| {
407 let param_ident = &struct_info.header_param_idents[i];
408 let field_ty = &field.ty;
409 quote! { #param_ident: #field_ty }
410 })
411 .collect()
412}
413
414fn header_field_writes(struct_info: &StructInfo) -> Vec<TokenStream> {
415 struct_info
416 .header_field_idents
417 .iter()
418 .enumerate()
419 .map(|(i, field_ident)| {
420 let tuple_idx = syn::Index::from(i);
421 quote! { ::core::ptr::write_unaligned(&raw mut ((*fat_ptr).#field_ident), args.#tuple_idx);}
422 })
423 .collect()
424}
425
426fn args_tuple_assignment(struct_info: &StructInfo) -> TokenStream {
427 let header_param_idents = &struct_info.header_param_idents;
428 let tail_param_ident = &struct_info.tail_param_ident;
429
430 quote! {
431 let args = ( #( #header_param_idents, )* #tail_param_ident, );
432 }
433}
434
435fn alloc_funcs(no_std: bool) -> (TokenStream, TokenStream) {
436 let (box_path, alloc_path, handle_alloc_error) = if no_std {
437 (
438 quote! { ::alloc::boxed::Box },
439 quote! { ::alloc::alloc::alloc },
440 quote! { panic!("out of memory") },
441 )
442 } else {
443 (
444 quote! { ::std::boxed::Box },
445 quote! { ::std::alloc::alloc },
446 quote! { ::std::alloc::handle_alloc_error(layout) },
447 )
448 };
449
450 let alloc = quote! {
451 let mem_ptr = #alloc_path(layout);
452 if mem_ptr.is_null() {
453 #handle_alloc_error
454 }
455 };
456
457 (alloc, box_path)
458}
459
460fn alloc_zst(box_path: &TokenStream, for_trait: bool) -> TokenStream {
461 let mem_ptr = quote! { let mem_ptr = ::core::ptr::NonNull::<()>::dangling().as_ptr(); };
462
463 let fat_ptr = if for_trait {
464 quote! { let fat_ptr = ::core::mem::transmute::<(*mut (), *const ()), *mut Self>((mem_ptr, vtable)); }
465 } else {
466 quote! { let fat_ptr = ::core::mem::transmute::<(*mut (), usize), *mut Self>((mem_ptr, 0_usize)); }
467 };
468
469 let box_from_raw = quote! { #box_path::from_raw(fat_ptr) };
470
471 quote! {
472 #mem_ptr
473 #fat_ptr
474 #box_from_raw
475 }
476}
477
478fn factory_for_slice_arg(macro_args: &MacroArgs, struct_info: &StructInfo, tail_elem_type: &Type) -> TokenStream {
479 let copy_bound_tokens: syn::WherePredicate = syn::parse_quote_spanned! {tail_elem_type.span()=>
480 #tail_elem_type: ::core::marker::Copy
481 };
482
483 let mut factory_where_clause = struct_info.struct_generics.where_clause.as_ref().map_or_else(
484 || syn::WhereClause {
485 where_token: Where::default(),
486 predicates: Punctuated::new(),
487 },
488 Clone::clone,
489 );
490
491 factory_where_clause.predicates.push(copy_bound_tokens);
492
493 let (alloc, box_path) = alloc_funcs(macro_args.no_std);
494 let make_zst = alloc_zst(&box_path, false);
495
496 let tail_layout = tail_layout(tail_elem_type, struct_info.tail_field.ty.span());
497 let header_layout = header_layout(macro_args, struct_info, false);
498 let tuple_assignment = args_tuple_assignment(struct_info);
499 let header_field_writes = header_field_writes(struct_info);
500 let header_params = header_params(struct_info);
501
502 let factory_name = format_ident!("{}_from_slice", ¯o_args.base_factory_name);
503 let visibility = ¯o_args.visibility;
504
505 let tail_param = &struct_info.tail_param_ident;
506 let tail_field = &struct_info.tail_field_ident;
507 let struct_name = &struct_info.struct_name;
508 let tail_args_tuple_idx = Index::from(struct_info.header_fields.len());
509
510 let factory_doc = format!("Creates an instance of `Box<{struct_name}>`.");
511
512 quote! {
513 #[doc = #factory_doc]
514 #[allow(clippy::let_unit_value)]
515 #[allow(clippy::zst_offset)]
516 #[allow(clippy::transmute_undefined_repr)]
517 #visibility fn #factory_name (
518 #( #header_params, )*
519 #tail_param: &[#tail_elem_type]
520 ) -> #box_path<Self> #factory_where_clause {
521 #tuple_assignment
522
523 let s = args.#tail_args_tuple_idx.as_ref();
524 let len = s.len();
525
526 #header_layout
527 let layout = layout.extend(#tail_layout).expect("Struct exceeds maximum size allowed of isize::MAX").0;
528 let layout = layout.pad_to_align();
529
530 unsafe {
531 if layout.size() == 0 {
532 #make_zst
533 } else {
534 #alloc
535
536 let fat_ptr = ::core::mem::transmute::<(*mut u8, usize), *mut Self>((mem_ptr, len));
537 ::core::debug_assert_eq!(::core::alloc::Layout::for_value(&*fat_ptr), layout);
538
539 #( #header_field_writes )*
540
541 let tail_ptr = (&raw mut (*fat_ptr).#tail_field).cast::<#tail_elem_type>();
543 ::core::ptr::copy_nonoverlapping(s.as_ptr(), tail_ptr, len);
544
545 #box_path::from_raw(fat_ptr)
546 }
547 }
548 }
549 }
550}
551
552fn factory_for_iter_arg(macro_args: &MacroArgs, struct_info: &StructInfo, tail_type: &Type) -> TokenStream {
553 let guard_type_tokens = guard_type(macro_args);
554 let (alloc, box_path) = alloc_funcs(macro_args.no_std);
555 let make_zst = alloc_zst(&box_path, false);
556
557 let tail_layout = tail_layout(tail_type, struct_info.tail_field.ty.span());
558 let header_layout = header_layout(macro_args, struct_info, false);
559 let tuple_assignment = args_tuple_assignment(struct_info);
560 let header_field_writes = header_field_writes(struct_info);
561 let header_params = header_params(struct_info);
562
563 let visibility = ¯o_args.visibility;
564 let factory_name = ¯o_args.base_factory_name;
565 let iter_generic_param = ¯o_args.generic_name;
566
567 let tail_param = &struct_info.tail_param_ident;
568 let tail_field = &struct_info.tail_field_ident;
569 let struct_name = &struct_info.struct_name;
570 let tail_args_tuple_idx = Index::from(struct_info.header_fields.len());
571
572 let factory_doc = format!("Creates an instance of `Box<{struct_name}>`.");
573
574 quote! {
575 #[doc = #factory_doc]
576 #[allow(clippy::let_unit_value)]
577 #[allow(clippy::zst_offset)]
578 #[allow(clippy::transmute_undefined_repr)]
579 #visibility fn #factory_name <#iter_generic_param> (
580 #( #header_params, )*
581 #tail_param: #iter_generic_param
582 ) -> #box_path<Self>
583 where
584 #iter_generic_param: ::core::iter::IntoIterator<Item = #tail_type>,
585 <#iter_generic_param as ::core::iter::IntoIterator>::IntoIter: ::core::iter::ExactSizeIterator
586 {
587 #guard_type_tokens
588 #tuple_assignment
589
590 let iter = args.#tail_args_tuple_idx.into_iter();
591 let len = iter.len();
592
593 #header_layout
594 let layout = layout.extend(#tail_layout).expect("Struct exceeds maximum size allowed of isize::MAX").0;
595 let layout = layout.pad_to_align();
596
597 unsafe {
598 if layout.size() == 0 {
599 #make_zst
600 } else {
601 #alloc
602
603 let fat_ptr = ::core::mem::transmute::<(*mut u8, usize), *mut Self>((mem_ptr, len));
604 ::core::debug_assert_eq!(::core::alloc::Layout::for_value(&*fat_ptr), layout);
605
606 #( #header_field_writes )*
607
608 let tail_ptr = ::core::ptr::addr_of_mut!((*fat_ptr).#tail_field).cast::<#tail_type>();
610 let mut guard = Guard { mem_ptr, tail_ptr, layout, initialized: 0 };
611 iter.for_each(|element| {
612 if guard.initialized == len {
613 panic!("Mismatch between iterator-reported length and the number of items produced by the iterator");
614 }
615
616 ::core::ptr::write(tail_ptr.add(guard.initialized), element);
617 guard.initialized += 1;
618 });
619
620 if guard.initialized != len {
621 panic!("Mismatch between iterator-reported length and the number of items produced by the iterator");
622 }
623
624 ::std::mem::forget(guard);
625
626 #box_path::from_raw(fat_ptr)
627 }
628 }
629 }
630 }
631}
632
633fn factory_for_str_arg(macro_args: &MacroArgs, struct_info: &StructInfo) -> TokenStream {
634 let (alloc, box_path) = alloc_funcs(macro_args.no_std);
635 let make_zst = alloc_zst(&box_path, false);
636
637 let tail_layout = tail_layout("e! { u8 }, struct_info.tail_field.ty.span());
638 let header_layout = header_layout(macro_args, struct_info, false);
639 let tuple_assignment = args_tuple_assignment(struct_info);
640 let header_field_writes = header_field_writes(struct_info);
641 let header_params = header_params(struct_info);
642
643 let factory_name = ¯o_args.base_factory_name;
644 let visibility = ¯o_args.visibility;
645
646 let struct_name = &struct_info.struct_name;
647 let tail_param = &struct_info.tail_param_ident;
648 let tail_field = &struct_info.tail_field_ident;
649 let tail_type = &struct_info.tail_field.ty;
650 let tail_args_tuple_idx = Index::from(struct_info.header_fields.len());
651
652 let factory_doc = format!("Creates an instance of `Box<{struct_name}>`.");
653
654 quote! {
655 #[doc = #factory_doc]
656 #[allow(clippy::let_unit_value)]
657 #[allow(clippy::zst_offset)]
658 #[allow(clippy::transmute_undefined_repr)]
659 #visibility fn #factory_name(
660 #( #header_params, )*
661 #tail_param: impl ::core::convert::AsRef<str>
662 ) -> #box_path<Self> {
663 #tuple_assignment
664
665 ::core::assert_eq!(::core::any::TypeId::of::<#tail_type>(), ::core::any::TypeId::of::<str>());
666 let s = args.#tail_args_tuple_idx.as_ref();
667 let len = s.len();
668
669 #header_layout
670 let layout = layout.extend(#tail_layout).expect("Struct exceeds maximum size allowed of isize::MAX").0;
671 let layout = layout.pad_to_align();
672
673 unsafe {
674 if layout.size() == 0 {
675 #make_zst
676 } else {
677 #alloc
678
679 let fat_ptr = ::core::mem::transmute::<(*mut u8, usize), *mut Self>((mem_ptr, len));
680 ::core::debug_assert_eq!(::core::alloc::Layout::for_value(&*fat_ptr), layout);
681
682 #( #header_field_writes )*
683
684 let tail_ptr = (&raw mut (*fat_ptr).#tail_field).cast::<u8>();
686 ::core::ptr::copy_nonoverlapping(s.as_ptr(), tail_ptr, len);
687
688 #box_path::from_raw(fat_ptr)
689 }
690 }
691 }
692 }
693}
694
695fn factory_for_trait_arg(macro_args: &MacroArgs, struct_info: &StructInfo, trait_path: &Path) -> TokenStream {
696 let (alloc, box_path) = alloc_funcs(macro_args.no_std);
697 let make_zst = alloc_zst(&box_path, true);
698
699 let header_layout = header_layout(macro_args, struct_info, true);
700 let tuple_assignment = args_tuple_assignment(struct_info);
701 let header_field_writes = header_field_writes(struct_info);
702 let header_params = header_params(struct_info);
703
704 let factory_name = ¯o_args.base_factory_name;
705 let trait_generic = ¯o_args.generic_name;
706 let visibility = ¯o_args.visibility;
707
708 let struct_name = &struct_info.struct_name;
709 let tail_param = &struct_info.tail_param_ident;
710 let tail_field = &struct_info.tail_field_ident;
711 let tail_args_tuple_idx = Index::from(struct_info.header_fields.len());
712
713 let factory_doc = format!("Builds an instance of `Box<{struct_name}>`.");
714
715 quote! {
716 #[doc = #factory_doc]
717 #[allow(clippy::let_unit_value)]
718 #[allow(clippy::zst_offset)]
719 #[allow(clippy::transmute_undefined_repr)]
720 #visibility fn #factory_name <#trait_generic> (
721 #( #header_params, )*
722 #tail_param: #trait_generic
723 ) -> #box_path<Self>
724 where
725 #trait_generic: #trait_path + Sized
726 {
727 #tuple_assignment
728
729 let s = args.#tail_args_tuple_idx;
730 let trait_object: &dyn #trait_path = &s;
731 let (_, vtable): (*const #trait_generic, *const ()) = unsafe { ::core::mem::transmute(trait_object) };
732
733 #header_layout
734 let layout = layout.extend(::core::alloc::Layout::new::<#trait_generic>()).expect("Struct exceeds maximum size allowed of isize::MAX").0;
735 let layout = layout.pad_to_align();
736
737 unsafe {
738 if layout.size() == 0 {
739 #make_zst
740 } else {
741 #alloc
742
743 let fat_ptr = ::core::mem::transmute::<(*mut u8, *const ()), *mut Self>((mem_ptr, vtable));
744 ::core::debug_assert_eq!(::core::alloc::Layout::for_value(&*fat_ptr), layout);
745
746 #( #header_field_writes )*
747
748 let tail_ptr = (&raw mut (*fat_ptr).#tail_field).cast::<#trait_generic>();
749 ::core::ptr::copy_nonoverlapping(::core::ptr::addr_of!(s), tail_ptr, 1);
750
751 #box_path::from_raw(fat_ptr)
752 }
753 }
754 }
755 }
756}
757
758fn make_dst_factory_impl(attr_args: TokenStream, item: TokenStream) -> SynResult<TokenStream> {
759 let macro_args = MacroArgs::parse(attr_args)?;
760 let input_struct: ItemStruct = syn::parse2(item)?;
761 let struct_info = StructInfo::new(&input_struct)?;
762
763 let mut factories = Vec::new();
764 match &struct_info.tail_kind {
765 TailKind::Slice(elem_type) => {
766 factories.push(factory_for_iter_arg(¯o_args, &struct_info, elem_type));
767 factories.push(factory_for_slice_arg(¯o_args, &struct_info, elem_type));
768 }
769
770 TailKind::Str => {
771 factories.push(factory_for_str_arg(¯o_args, &struct_info));
772 }
773
774 TailKind::TraitObject(trait_path) => {
775 factories.push(factory_for_trait_arg(¯o_args, &struct_info, trait_path));
776 }
777 }
778
779 let (impl_generics, ty_generics, where_clause) = struct_info.struct_generics.split_for_impl();
780 let struct_name_ident = struct_info.struct_name;
781
782 Ok(quote! {
783 #input_struct
784
785 impl #impl_generics #struct_name_ident #ty_generics #where_clause {
786 #( #factories )*
787 }
788 })
789}
790
791#[proc_macro_attribute]
833pub fn make_dst_factory(attr_args: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
834 let result = make_dst_factory_impl(attr_args.into(), item.into());
835 result.unwrap_or_else(|err| err.to_compile_error()).into()
836}