1#![recursion_limit = "256"]
2
3type TokenStream1 = proc_macro::TokenStream;
4type TokenStream2 = proc_macro2::TokenStream;
5
6use quote::quote;
7use syn::{Fields, Type};
8
9#[proc_macro_derive(Equivalence)]
10pub fn create_user_datatype(input: TokenStream1) -> TokenStream1 {
11 let ast: syn::DeriveInput = syn::parse(input).expect("Couldn't parse struct");
12 let result = match ast.data {
13 syn::Data::Enum(_) => panic!("#[derive(Equivalence)] is not compatible with enums"),
14 syn::Data::Union(_) => panic!("#[derive(Equivalence)] is not compatible with unions"),
15 syn::Data::Struct(ref s) => equivalence_for_struct(&ast, &s.fields),
16 };
17 result.into()
18}
19
20fn equivalence_for_tuple_field(type_tuple: &syn::TypeTuple) -> TokenStream2 {
21 let field_blocklengths = type_tuple.elems.iter().map(|_| 1);
22
23 let fields = type_tuple
24 .elems
25 .iter()
26 .enumerate()
27 .map(|(i, _)| syn::Index::from(i));
28
29 let field_datatypes = type_tuple.elems.iter().map(equivalence_for_type);
30
31 quote! {
32 &::mpi::datatype::UncommittedUserDatatype::structured(
33 &[#(#field_blocklengths as ::mpi::Count),*],
34 &[#(::mpi::internal::memoffset::offset_of_tuple!(#type_tuple, #fields) as ::mpi::Address),*],
35 &[#(::mpi::datatype::UncommittedDatatypeRef::from(#field_datatypes)),*],
36 )
37 }
38}
39
40fn equivalence_for_array_field(type_array: &syn::TypeArray) -> TokenStream2 {
41 let ty = equivalence_for_type(&type_array.elem);
42 let len = &type_array.len;
43 quote! { &::mpi::datatype::UncommittedUserDatatype::contiguous(
47 {let len: usize = #len; len}.try_into().expect("rsmpi derive: Array size is to large for MPI_Datatype i32"), &#ty)
48 }
49}
50
51fn equivalence_for_type(ty: &syn::Type) -> TokenStream2 {
52 match ty {
53 Type::Path(ref type_path) => quote!(
54 <#type_path as ::mpi::datatype::Equivalence>::equivalent_datatype()),
55 Type::Tuple(ref type_tuple) => equivalence_for_tuple_field(type_tuple),
56 Type::Array(ref type_array) => equivalence_for_array_field(type_array),
57 _ => panic!("Unsupported type!"),
58 }
59}
60
61fn equivalence_for_struct(ast: &syn::DeriveInput, fields: &Fields) -> TokenStream2 {
62 let ident = &ast.ident;
63
64 let field_blocklengths = fields.iter().map(|_| 1);
65
66 let field_names = fields
67 .iter()
68 .enumerate()
69 .map(|(i, field)| -> Box<dyn quote::ToTokens> {
70 if let Some(ident) = field.ident.as_ref() {
71 Box::new(ident)
73 } else {
74 Box::new(syn::Index::from(i))
76 }
77 });
78
79 let field_datatypes = fields.iter().map(|field| equivalence_for_type(&field.ty));
80
81 let ident_str = ident.to_string();
82
83 quote! {
87 unsafe impl ::mpi::datatype::Equivalence for #ident {
88 type Out = ::mpi::datatype::DatatypeRef<'static>;
89 fn equivalent_datatype() -> Self::Out {
90 use ::mpi::internal::once_cell::sync::Lazy;
91 use ::std::convert::TryInto;
92
93 static DATATYPE: Lazy<::mpi::datatype::UserDatatype> = Lazy::new(|| {
94 ::mpi::datatype::internal::check_derive_equivalence_universe_state(#ident_str);
95
96 ::mpi::datatype::UserDatatype::structured::<
97 ::mpi::datatype::UncommittedDatatypeRef,
98 >(
99 &[#(#field_blocklengths as ::mpi::Count),*],
100 &[#(::mpi::internal::memoffset::offset_of!(#ident, #field_names) as ::mpi::Address),*],
101 &[#(::mpi::datatype::UncommittedDatatypeRef::from(#field_datatypes)),*],
102 )
103 });
104
105 DATATYPE.as_ref()
106 }
107 }
108 }
109}