1extern crate proc_macro;
186mod macro_args;
187
188use macro_args::MacroArgs;
189use proc_macro2::{Span, TokenStream};
190use quote::{ToTokens, format_ident, quote, quote_spanned};
191use syn::token::{Comma, Where};
192use syn::{
193 Field, Fields, Generics, Ident, Index, ItemStruct, Type, TypePath, TypeTraitObject, parse::Result as SynResult, punctuated::Punctuated,
194 spanned::Spanned,
195};
196
197enum TailKind {
198 Slice(Box<Type>),
199 Str,
200 TraitObject(TypeTraitObject),
201}
202
203enum FieldIdent {
204 Named(Ident),
205 Unnamed(Index),
206}
207
208impl ToTokens for FieldIdent {
209 fn to_tokens(&self, tokens: &mut TokenStream) {
210 match self {
211 Self::Named(ident) => ident.to_tokens(tokens),
212 Self::Unnamed(index) => index.to_tokens(tokens),
213 }
214 }
215}
216
217struct StructInfo<'a> {
218 struct_name: &'a Ident,
219 struct_generics: &'a Generics,
220 header_fields: Box<[&'a Field]>,
221 header_field_idents: Box<[FieldIdent]>,
222 header_param_idents: Box<[Ident]>,
223 tail_field: &'a Field,
224 tail_field_ident: FieldIdent,
225 tail_param_ident: Ident,
226 tail_kind: TailKind,
227}
228
229impl<'a> StructInfo<'a> {
230 fn new(input_struct: &'a ItemStruct) -> SynResult<Self> {
231 match &input_struct.fields {
232 Fields::Named(named_fields) => Self::process_fields(input_struct, &named_fields.named),
233 Fields::Unnamed(unnamed_fields) => Self::process_fields(input_struct, &unnamed_fields.unnamed),
234 Fields::Unit => Err(syn::Error::new_spanned(input_struct, "Unit structs are not supported")),
235 }
236 }
237
238 fn process_fields(input_struct: &'a ItemStruct, fields: &'a Punctuated<Field, Comma>) -> SynResult<Self> {
239 if fields.is_empty() {
240 return Err(syn::Error::new_spanned(input_struct, "Struct must have at least one field"));
241 }
242
243 let tail_field = fields.last().unwrap();
244 let mut tail_kind = TailKind::Str;
245
246 match &tail_field.ty {
247 Type::Path(type_path) if type_path.path.is_ident("str") => {}
248 Type::Path(TypePath { path, .. }) if path.segments.last().is_some_and(|s| s.ident == "str") => {}
249 Type::TraitObject(trait_object) => tail_kind = TailKind::TraitObject(trait_object.clone()),
250 Type::Slice(type_slice) => tail_kind = TailKind::Slice(type_slice.elem.clone()),
251
252 _ => {
253 return Err(syn::Error::new_spanned(
254 &tail_field.ty,
255 "Last field must be a dynamically sized type like [T], str, or dyn Trait",
256 ));
257 }
258 }
259
260 let header_fields: Vec<_> = fields.iter().take(fields.len() - 1).collect();
261 let header_param_idents: Vec<_> = header_fields
262 .iter()
263 .enumerate()
264 .map(|(i, field)| field.ident.clone().unwrap_or_else(|| format_ident!("f{i}")))
265 .collect();
266
267 let mut header_field_idents = Vec::with_capacity(header_fields.len());
268 for (i, field) in header_fields.iter().enumerate() {
269 header_field_idents.push(
270 field
271 .ident
272 .as_ref()
273 .map_or_else(|| FieldIdent::Unnamed(Index::from(i)), |ident| FieldIdent::Named(ident.clone())),
274 );
275 }
276
277 let tail_field_ident = tail_field.ident.as_ref().map_or_else(
278 || FieldIdent::Unnamed(Index::from(header_fields.len())),
279 |ident| FieldIdent::Named(ident.clone()),
280 );
281
282 Ok(Self {
283 struct_name: &input_struct.ident,
284 struct_generics: &input_struct.generics,
285 header_fields: header_fields.into_boxed_slice(),
286 header_field_idents: header_field_idents.into_boxed_slice(),
287 header_param_idents: header_param_idents.into_boxed_slice(),
288 tail_field,
289 tail_field_ident,
290 tail_param_ident: tail_field.ident.clone().unwrap_or_else(|| format_ident!("tail")),
291 tail_kind,
292 })
293 }
294}
295
296fn header_layout(macro_args: &MacroArgs, struct_info: &StructInfo, for_trait: bool) -> TokenStream {
297 let tail_field_ident = &struct_info.tail_field_ident;
298
299 let header_field_types: Vec<_> = struct_info.header_fields.iter().map(|field| &field.ty).collect();
300
301 if header_field_types.is_empty() {
302 return quote! {
303 let layout = ::core::alloc::Layout::from_size_align(0, 1).unwrap();
304 };
305 }
306
307 let fat_payload = if for_trait {
308 quote! { vtable }
309 } else {
310 quote! { 0_usize }
311 };
312
313 let tail_type = match &struct_info.tail_kind {
314 TailKind::Slice(elem_type) => quote! { #elem_type },
315 TailKind::Str => quote! { u8 },
316 TailKind::TraitObject(_) => {
317 let generic_name = ¯o_args.generic_name;
318 quote! { #generic_name }
319 }
320 };
321
322 quote! {
323 let buffer = ::core::mem::MaybeUninit::<(#( #header_field_types, )* #tail_type)>::uninit();
324 let (offset, align) = unsafe {
325 let head_ptr: *const Self = ::core::mem::transmute((&raw const buffer, #fat_payload));
326 let tail_ptr = &raw const (*head_ptr).#tail_field_ident;
327 (
328 (tail_ptr as *const u8).offset_from_unsigned(head_ptr as *const u8),
329 ::core::mem::align_of_val::<Self>(&*head_ptr)
330 )
331 };
332
333 let layout = ::core::alloc::Layout::from_size_align(offset, align).unwrap();
334 }
335}
336
337fn tail_layout<T: quote::ToTokens>(tail_type: &T, span: Span) -> TokenStream {
338 quote_spanned! { span => ::core::alloc::Layout::array::<#tail_type>(len).expect("Array exceeds maximum size allowed of isize::MAX") }
339}
340
341fn guard_type(macro_args: &MacroArgs) -> TokenStream {
342 let dealloc_path = if macro_args.no_std {
343 quote! { ::alloc::alloc::dealloc }
344 } else {
345 quote! { ::std::alloc::dealloc }
346 };
347
348 quote! {
349 struct Guard<T> {
350 mem_ptr: *mut u8,
351 tail_ptr: *mut T,
352 initialized: usize,
353 layout: ::core::alloc::Layout,
354 }
355
356 impl<T> Drop for Guard<T> {
357 fn drop(&mut self) {
358 unsafe {
359 let slice_ptr = ::core::ptr::slice_from_raw_parts_mut(self.tail_ptr, self.initialized);
360 ::core::ptr::drop_in_place(slice_ptr);
361 #dealloc_path(self.mem_ptr, self.layout);
362 }
363 }
364 }
365 }
366}
367
368fn header_params(struct_info: &StructInfo) -> Vec<TokenStream> {
369 let mut header_params_tokens = Vec::new();
370
371 for i in 0..struct_info.header_fields.len() {
372 let param_ident = &struct_info.header_param_idents[i];
373 let field_ty = &struct_info.header_fields[i].ty;
374 header_params_tokens.push(quote! { #param_ident: #field_ty });
375 }
376 header_params_tokens
377}
378
379fn header_field_writes(struct_info: &StructInfo) -> Vec<TokenStream> {
380 struct_info
381 .header_field_idents
382 .iter()
383 .enumerate()
384 .map(|(i, field_ident)| {
385 let tuple_idx = syn::Index::from(i);
386 quote! { ::core::ptr::write_unaligned(&raw mut ((*fat_ptr).#field_ident), args.#tuple_idx);}
387 })
388 .collect()
389}
390
391fn args_tuple_assignment(struct_info: &StructInfo) -> TokenStream {
392 let header_param_idents = &struct_info.header_param_idents;
393 let tail_param_ident = &struct_info.tail_param_ident;
394
395 quote! {
396 let args = ( #( #header_param_idents, )* #tail_param_ident, );
397 }
398}
399
400fn alloc_funcs(no_std: bool) -> (TokenStream, TokenStream) {
401 let (box_path, alloc_path, handle_alloc_error) = if no_std {
402 (
403 quote! { ::alloc::boxed::Box },
404 quote! { ::alloc::alloc::alloc },
405 quote! { panic!("out of memory") },
406 )
407 } else {
408 (
409 quote! { ::std::boxed::Box },
410 quote! { ::std::alloc::alloc },
411 quote! { ::std::alloc::handle_alloc_error(layout) },
412 )
413 };
414
415 let alloc = quote! {
416 let mem_ptr = #alloc_path(layout);
417 if mem_ptr.is_null() {
418 #handle_alloc_error
419 }
420 };
421
422 (alloc, box_path)
423}
424
425fn alloc_zst(box_path: &TokenStream, for_trait: bool) -> TokenStream {
426 let mem_ptr = quote! { let mem_ptr = ::core::ptr::NonNull::<()>::dangling().as_ptr(); };
427
428 let fat_ptr = if for_trait {
429 quote! { let fat_ptr = ::core::mem::transmute::<(*mut (), *const ()), *mut Self>((mem_ptr, vtable)); }
430 } else {
431 quote! { let fat_ptr = ::core::mem::transmute::<(*mut (), usize), *mut Self>((mem_ptr, 0_usize)); }
432 };
433
434 let box_from_raw = quote! { #box_path::from_raw(fat_ptr) };
435
436 quote! {
437 #mem_ptr
438 #fat_ptr
439 #box_from_raw
440 }
441}
442
443fn factory_for_slice_arg(macro_args: &MacroArgs, struct_info: &StructInfo, tail_elem_type: &Type) -> TokenStream {
444 let copy_bound_tokens: syn::WherePredicate = syn::parse_quote_spanned! {tail_elem_type.span()=>
445 #tail_elem_type: ::core::marker::Copy
446 };
447
448 let mut factory_where_clause = struct_info
449 .struct_generics
450 .where_clause
451 .clone()
452 .unwrap_or_else(|| syn::WhereClause {
453 where_token: Where::default(),
454 predicates: Punctuated::new(),
455 });
456
457 factory_where_clause.predicates.push(copy_bound_tokens);
458
459 let (alloc, box_path) = alloc_funcs(macro_args.no_std);
460 let make_zst = alloc_zst(&box_path, false);
461
462 let tail_layout = tail_layout(tail_elem_type, struct_info.tail_field.ty.span());
463 let header_layout = header_layout(macro_args, struct_info, false);
464 let tuple_assignment = args_tuple_assignment(struct_info);
465 let header_field_writes = header_field_writes(struct_info);
466 let header_params = header_params(struct_info);
467
468 let factory_name = format_ident!("{}_from_slice", ¯o_args.base_factory_name);
469 let visibility = ¯o_args.visibility;
470
471 let tail_param = &struct_info.tail_param_ident;
472 let tail_field = &struct_info.tail_field_ident;
473 let struct_name = &struct_info.struct_name;
474 let tail_args_tuple_idx = Index::from(struct_info.header_fields.len());
475
476 let factory_doc = format!("Creates an instance of `Box<{struct_name}>`.");
477
478 quote! {
479 #[doc = #factory_doc]
480 #[allow(clippy::let_unit_value)]
481 #[allow(clippy::zst_offset)]
482 #[allow(clippy::transmute_undefined_repr)]
483 #visibility fn #factory_name (
484 #( #header_params, )*
485 #tail_param: &[#tail_elem_type]
486 ) -> #box_path<Self> #factory_where_clause {
487 #tuple_assignment
488
489 let s = args.#tail_args_tuple_idx.as_ref();
490 let len = s.len();
491
492 #header_layout
493 let layout = layout.extend(#tail_layout).expect("Struct exceeds maximum size allowed of isize::MAX").0;
494 let layout = layout.pad_to_align();
495
496 unsafe {
497 if layout.size() == 0 {
498 #make_zst
499 } else {
500 #alloc
501
502 let fat_ptr = ::core::mem::transmute::<(*mut u8, usize), *mut Self>((mem_ptr, len));
503 ::core::debug_assert_eq!(::core::alloc::Layout::for_value(&*fat_ptr), layout);
504
505 #( #header_field_writes )*
506
507 let tail_ptr = (&raw mut (*fat_ptr).#tail_field).cast::<#tail_elem_type>();
509 ::core::ptr::copy_nonoverlapping(s.as_ptr(), tail_ptr, len);
510
511 #box_path::from_raw(fat_ptr)
512 }
513 }
514 }
515 }
516}
517
518fn factory_for_iter_arg(macro_args: &MacroArgs, struct_info: &StructInfo, tail_type: &Type) -> TokenStream {
519 let guard_type_tokens = guard_type(macro_args);
520 let (alloc, box_path) = alloc_funcs(macro_args.no_std);
521 let make_zst = alloc_zst(&box_path, false);
522
523 let tail_layout = tail_layout(tail_type, struct_info.tail_field.ty.span());
524 let header_layout = header_layout(macro_args, struct_info, false);
525 let tuple_assignment = args_tuple_assignment(struct_info);
526 let header_field_writes = header_field_writes(struct_info);
527 let header_params = header_params(struct_info);
528
529 let visibility = ¯o_args.visibility;
530 let factory_name = ¯o_args.base_factory_name;
531 let iter_generic_param = ¯o_args.generic_name;
532
533 let tail_param = &struct_info.tail_param_ident;
534 let tail_field = &struct_info.tail_field_ident;
535 let struct_name = &struct_info.struct_name;
536 let tail_args_tuple_idx = Index::from(struct_info.header_fields.len());
537
538 let factory_doc = format!("Creates an instance of `Box<{struct_name}>`.");
539
540 quote! {
541 #[doc = #factory_doc]
542 #[allow(clippy::let_unit_value)]
543 #[allow(clippy::zst_offset)]
544 #[allow(clippy::transmute_undefined_repr)]
545 #visibility fn #factory_name <#iter_generic_param> (
546 #( #header_params, )*
547 #tail_param: #iter_generic_param
548 ) -> #box_path<Self>
549 where
550 #iter_generic_param: ::core::iter::IntoIterator<Item = #tail_type>,
551 <#iter_generic_param as ::core::iter::IntoIterator>::IntoIter: ::core::iter::ExactSizeIterator
552 {
553 #guard_type_tokens
554 #tuple_assignment
555
556 let iter = args.#tail_args_tuple_idx.into_iter();
557 let len = iter.len();
558
559 #header_layout
560 let layout = layout.extend(#tail_layout).expect("Struct exceeds maximum size allowed of isize::MAX").0;
561 let layout = layout.pad_to_align();
562
563 unsafe {
564 if layout.size() == 0 {
565 #make_zst
566 } else {
567 #alloc
568
569 let fat_ptr = ::core::mem::transmute::<(*mut u8, usize), *mut Self>((mem_ptr, len));
570 ::core::debug_assert_eq!(::core::alloc::Layout::for_value(&*fat_ptr), layout);
571
572 #( #header_field_writes )*
573
574 let tail_ptr = ::core::ptr::addr_of_mut!((*fat_ptr).#tail_field).cast::<#tail_type>();
576 let mut guard = Guard { mem_ptr, tail_ptr, layout, initialized: 0 };
577 iter.for_each(|element| {
578 if guard.initialized == len {
579 panic!("Mismatch between iterator-reported length and the number of items produced by the iterator");
580 }
581
582 ::core::ptr::write(tail_ptr.add(guard.initialized), element);
583 guard.initialized += 1;
584 });
585
586 if guard.initialized != len {
587 panic!("Mismatch between iterator-reported length and the number of items produced by the iterator");
588 }
589
590 ::std::mem::forget(guard);
591
592 #box_path::from_raw(fat_ptr)
593 }
594 }
595 }
596 }
597}
598
599fn factory_for_str_arg(macro_args: &MacroArgs, struct_info: &StructInfo) -> TokenStream {
600 let (alloc, box_path) = alloc_funcs(macro_args.no_std);
601 let make_zst = alloc_zst(&box_path, false);
602
603 let tail_layout = tail_layout("e! { u8 }, struct_info.tail_field.ty.span());
604 let header_layout = header_layout(macro_args, struct_info, false);
605 let tuple_assignment = args_tuple_assignment(struct_info);
606 let header_field_writes = header_field_writes(struct_info);
607 let header_params = header_params(struct_info);
608
609 let factory_name = ¯o_args.base_factory_name;
610 let visibility = ¯o_args.visibility;
611
612 let struct_name = &struct_info.struct_name;
613 let tail_param = &struct_info.tail_param_ident;
614 let tail_field = &struct_info.tail_field_ident;
615 let tail_type = &struct_info.tail_field.ty;
616 let tail_args_tuple_idx = Index::from(struct_info.header_fields.len());
617
618 let factory_doc = format!("Creates an instance of `Box<{struct_name}>`.");
619
620 quote! {
621 #[doc = #factory_doc]
622 #[allow(clippy::let_unit_value)]
623 #[allow(clippy::zst_offset)]
624 #[allow(clippy::transmute_undefined_repr)]
625 #visibility fn #factory_name(
626 #( #header_params, )*
627 #tail_param: impl ::core::convert::AsRef<str>
628 ) -> #box_path<Self> {
629 #tuple_assignment
630
631 ::core::assert_eq!(::core::any::TypeId::of::<#tail_type>(), ::core::any::TypeId::of::<str>());
632 let s = args.#tail_args_tuple_idx.as_ref();
633 let len = s.len();
634
635 #header_layout
636 let layout = layout.extend(#tail_layout).expect("Struct exceeds maximum size allowed of isize::MAX").0;
637 let layout = layout.pad_to_align();
638
639 unsafe {
640 if layout.size() == 0 {
641 #make_zst
642 } else {
643 #alloc
644
645 let fat_ptr = ::core::mem::transmute::<(*mut u8, usize), *mut Self>((mem_ptr, len));
646 ::core::debug_assert_eq!(::core::alloc::Layout::for_value(&*fat_ptr), layout);
647
648 #( #header_field_writes )*
649
650 let tail_ptr = (&raw mut (*fat_ptr).#tail_field).cast::<u8>();
652 ::core::ptr::copy_nonoverlapping(s.as_ptr(), tail_ptr, len);
653
654 #box_path::from_raw(fat_ptr)
655 }
656 }
657 }
658 }
659}
660
661fn factory_for_trait_arg(macro_args: &MacroArgs, struct_info: &StructInfo, type_trait_object: &TypeTraitObject) -> SynResult<TokenStream> {
662 for bound in &type_trait_object.bounds {
664 if let syn::TypeParamBound::Trait(trait_bound) = bound {
665 if trait_bound.lifetimes.is_some() {
666 return Err(syn::Error::new_spanned(
667 trait_bound,
668 "Higher-rank trait bounds (e.g., `for<'a> dyn Trait<'a>`) are not supported for the tail field.",
669 ));
670 }
671 }
672 }
673
674 let trait_path = type_trait_object
676 .bounds
677 .iter()
678 .find_map(|bound| {
679 if let syn::TypeParamBound::Trait(trait_bound) = bound {
680 Some(&trait_bound.path)
681 } else {
682 None }
684 })
685 .unwrap();
686
687 let (alloc, box_path) = alloc_funcs(macro_args.no_std);
688 let make_zst = alloc_zst(&box_path, true);
689
690 let header_layout = header_layout(macro_args, struct_info, true);
691 let tuple_assignment = args_tuple_assignment(struct_info);
692 let header_field_writes = header_field_writes(struct_info);
693 let header_params = header_params(struct_info);
694
695 let factory_name = ¯o_args.base_factory_name;
696 let trait_generic = ¯o_args.generic_name;
697 let visibility = ¯o_args.visibility;
698
699 let struct_name = &struct_info.struct_name;
700 let tail_param = &struct_info.tail_param_ident;
701 let tail_field = &struct_info.tail_field_ident;
702 let tail_args_tuple_idx = Index::from(struct_info.header_fields.len());
703
704 let factory_doc = format!("Builds an instance of `Box<{struct_name}>`.");
705
706 Ok(quote! {
707 #[doc = #factory_doc]
708 #[allow(clippy::let_unit_value)]
709 #[allow(clippy::zst_offset)]
710 #[allow(clippy::transmute_undefined_repr)]
711 #visibility fn #factory_name <#trait_generic> (
712 #( #header_params, )*
713 #tail_param: #trait_generic
714 ) -> #box_path<Self>
715 where
716 #trait_generic: #trait_path + Sized
717 {
718 #tuple_assignment
719
720 let s = args.#tail_args_tuple_idx;
721 let trait_object: &dyn #trait_path = &s;
722 let (_, vtable): (*const #trait_generic, *const ()) = unsafe { ::core::mem::transmute(trait_object) };
723
724 #header_layout
725 let layout = layout.extend(::core::alloc::Layout::new::<#trait_generic>()).expect("Struct exceeds maximum size allowed of isize::MAX").0;
726 let layout = layout.pad_to_align();
727
728 unsafe {
729 if layout.size() == 0 {
730 #make_zst
731 } else {
732 #alloc
733
734 let fat_ptr = ::core::mem::transmute::<(*mut u8, *const ()), *mut Self>((mem_ptr, vtable));
735 ::core::debug_assert_eq!(::core::alloc::Layout::for_value(&*fat_ptr), layout);
736
737 #( #header_field_writes )*
738
739 let tail_ptr = (&raw mut (*fat_ptr).#tail_field).cast::<#trait_generic>();
740 ::core::ptr::copy_nonoverlapping(::core::ptr::addr_of!(s), tail_ptr, 1);
741
742 #box_path::from_raw(fat_ptr)
743 }
744 }
745 }
746 })
747}
748
749fn make_dst_factory_impl(attr_args: TokenStream, item: TokenStream) -> SynResult<TokenStream> {
750 let macro_args = MacroArgs::parse(attr_args)?;
751 let input_struct: ItemStruct = syn::parse2(item)?;
752 let struct_info = StructInfo::new(&input_struct)?;
753
754 let mut generated_factories = Vec::new();
755 match &struct_info.tail_kind {
756 TailKind::Slice(elem_type) => {
757 generated_factories.push(factory_for_iter_arg(¯o_args, &struct_info, elem_type));
758 generated_factories.push(factory_for_slice_arg(¯o_args, &struct_info, elem_type));
759 }
760
761 TailKind::Str => {
762 generated_factories.push(factory_for_str_arg(¯o_args, &struct_info));
763 }
764
765 TailKind::TraitObject(type_trait_object) => {
766 generated_factories.push(factory_for_trait_arg(¯o_args, &struct_info, type_trait_object)?);
767 }
768 }
769
770 let (impl_generics, ty_generics, where_clause) = struct_info.struct_generics.split_for_impl();
771 let struct_name_ident = struct_info.struct_name;
772 Ok(quote! {
773 #input_struct
774
775 impl #impl_generics #struct_name_ident #ty_generics #where_clause {
776 #( #generated_factories )*
777 }
778 })
779}
780
781#[proc_macro_attribute]
823pub fn make_dst_factory(attr_args: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
824 let result = make_dst_factory_impl(attr_args.into(), item.into());
825 result.unwrap_or_else(|err| err.to_compile_error()).into()
826}