mpi_derive/
lib.rs

1#![recursion_limit = "256"]
2
3type TokenStream1 = proc_macro::TokenStream;
4type TokenStream2 = proc_macro2::TokenStream;
5
6use quote::quote;
7use syn::{Fields, Type};
8
9#[proc_macro_derive(Equivalence)]
10pub fn create_user_datatype(input: TokenStream1) -> TokenStream1 {
11    let ast: syn::DeriveInput = syn::parse(input).expect("Couldn't parse struct");
12    let result = match ast.data {
13        syn::Data::Enum(_) => panic!("#[derive(Equivalence)] is not compatible with enums"),
14        syn::Data::Union(_) => panic!("#[derive(Equivalence)] is not compatible with unions"),
15        syn::Data::Struct(ref s) => equivalence_for_struct(&ast, &s.fields),
16    };
17    result.into()
18}
19
20fn equivalence_for_tuple_field(type_tuple: &syn::TypeTuple) -> TokenStream2 {
21    let field_blocklengths = type_tuple.elems.iter().map(|_| 1);
22
23    let fields = type_tuple
24        .elems
25        .iter()
26        .enumerate()
27        .map(|(i, _)| syn::Index::from(i));
28
29    let field_datatypes = type_tuple.elems.iter().map(equivalence_for_type);
30
31    quote! {
32        &::mpi::datatype::UncommittedUserDatatype::structured(
33            &[#(#field_blocklengths as ::mpi::Count),*],
34            &[#(::mpi::internal::memoffset::offset_of_tuple!(#type_tuple, #fields) as ::mpi::Address),*],
35            &[#(::mpi::datatype::UncommittedDatatypeRef::from(#field_datatypes)),*],
36        )
37    }
38}
39
40fn equivalence_for_array_field(type_array: &syn::TypeArray) -> TokenStream2 {
41    let ty = equivalence_for_type(&type_array.elem);
42    let len = &type_array.len;
43    // We use the len block to ensure that len is of type `usize` and not type
44    // {integer}. We know that `#len` should be of type `usize` because it is an
45    // array size.
46    quote! { &::mpi::datatype::UncommittedUserDatatype::contiguous(
47        {let len: usize = #len; len}.try_into().expect("rsmpi derive: Array size is to large for MPI_Datatype i32"), &#ty)
48    }
49}
50
51fn equivalence_for_type(ty: &syn::Type) -> TokenStream2 {
52    match ty {
53        Type::Path(ref type_path) => quote!(
54                <#type_path as ::mpi::datatype::Equivalence>::equivalent_datatype()),
55        Type::Tuple(ref type_tuple) => equivalence_for_tuple_field(type_tuple),
56        Type::Array(ref type_array) => equivalence_for_array_field(type_array),
57        _ => panic!("Unsupported type!"),
58    }
59}
60
61fn equivalence_for_struct(ast: &syn::DeriveInput, fields: &Fields) -> TokenStream2 {
62    let ident = &ast.ident;
63
64    let field_blocklengths = fields.iter().map(|_| 1);
65
66    let field_names = fields
67        .iter()
68        .enumerate()
69        .map(|(i, field)| -> Box<dyn quote::ToTokens> {
70            if let Some(ident) = field.ident.as_ref() {
71                // named struct fields
72                Box::new(ident)
73            } else {
74                // tuple struct fields
75                Box::new(syn::Index::from(i))
76            }
77        });
78
79    let field_datatypes = fields.iter().map(|field| equivalence_for_type(&field.ty));
80
81    let ident_str = ident.to_string();
82
83    // TODO and NOTE: Technically this code can race with MPI init and finalize, as can any other
84    // code in rsmpi that interacts with the MPI library without taking a handle to `Universe`.
85    // This requires larger attention, and so currently this is not addressed.
86    quote! {
87        unsafe impl ::mpi::datatype::Equivalence for #ident {
88            type Out = ::mpi::datatype::DatatypeRef<'static>;
89            fn equivalent_datatype() -> Self::Out {
90                use ::mpi::internal::once_cell::sync::Lazy;
91                use ::std::convert::TryInto;
92
93                static DATATYPE: Lazy<::mpi::datatype::UserDatatype> = Lazy::new(|| {
94                    ::mpi::datatype::internal::check_derive_equivalence_universe_state(#ident_str);
95
96                    ::mpi::datatype::UserDatatype::structured::<
97                        ::mpi::datatype::UncommittedDatatypeRef,
98                    >(
99                        &[#(#field_blocklengths as ::mpi::Count),*],
100                        &[#(::mpi::internal::memoffset::offset_of!(#ident, #field_names) as ::mpi::Address),*],
101                        &[#(::mpi::datatype::UncommittedDatatypeRef::from(#field_datatypes)),*],
102                    )
103                });
104
105                DATATYPE.as_ref()
106            }
107        }
108    }
109}