concat_arrays/
lib.rs

1use {
2    syn::{
3        Expr,
4        punctuated::Punctuated,
5        parse::{Parse, ParseStream},
6    },
7    proc_macro::TokenStream,
8    quote::{format_ident, quote},
9};
10
11struct Args {
12    punctuated: Punctuated<Expr, syn::Token![,]>,
13}
14
15impl Parse for Args {
16    fn parse(input: ParseStream) -> syn::parse::Result<Args> {
17        let punctuated = Punctuated::parse_terminated(input)?;
18        Ok(Args { punctuated })
19    }
20}
21
22/// Concatenates arrays.
23///
24/// # Example
25///
26/// ```
27/// # use concat_arrays::concat_arrays;
28/// let x = [0];
29/// let y = [1, 2];
30/// let z = [3, 4, 5];
31/// let concatenated = concat_arrays!(x, y, z);
32/// assert_eq!(concatenated, [0, 1, 2, 3, 4, 5]);
33/// ```
34///
35/// # Limitations
36///
37/// Due to limitations in rust `concat_arrays!` can't tell the compiler what the length of the
38/// returned array is. As such, the length needs to be inferable from the surrounding context. For
39/// example, in the example above the length is inferred by the call to `assert_eq!`. It is safe to
40/// mis-specify the length however since you'll get a compilation error rather than broken code.
41#[proc_macro]
42pub fn concat_arrays(tokens: TokenStream) -> TokenStream {
43    let arrays = syn::parse_macro_input!(tokens as Args);
44    let arrays: Vec<Expr> = arrays.punctuated.into_iter().collect();
45    let num_arrays = arrays.len();
46    let field_names = {
47        let mut field_names = Vec::with_capacity(num_arrays);
48        for i in 0..num_arrays {
49            field_names.push(format_ident!("concat_arrays_arg_{}", i));
50        }
51        field_names
52    };
53    let define_concat_arrays_type = {
54        let type_arg_names = {
55            let mut type_arg_names = Vec::with_capacity(num_arrays);
56            for i in 0..num_arrays {
57                type_arg_names.push(format_ident!("ConcatArraysArg{}", i));
58            }
59            type_arg_names
60        };
61        quote! {
62            #[repr(C)]
63            struct ConcatArrays<#(#type_arg_names,)*> {
64                #(#field_names: #type_arg_names,)*
65            }
66        }
67    };
68    let num_arrays_plus_one = num_arrays + 1;
69    let ret = quote! {{
70        #(
71            let #field_names = #arrays;
72        )*
73        if false {
74            #define_concat_arrays_type
75
76            fn constrain_concat_arrays_argument_to_be_an_array<T, const ARRAY_ARG_LEN: usize>(
77                concat_arrays_arg: &[T; ARRAY_ARG_LEN],
78            ) {
79                let _ = concat_arrays_arg;
80            }
81            #(
82                constrain_concat_arrays_argument_to_be_an_array(&#field_names);
83            )*
84            #[repr(C)]
85            struct ArrayElement {
86                _unused: u8,
87            }
88            fn un_zst_array<T, const LEN: usize>(array: &[T; LEN]) -> [ArrayElement; LEN] {
89                ::core::unreachable!()
90            }
91            let concat_non_zst_arrays = ConcatArrays {
92                #(#field_names: un_zst_array(&#field_names),)*
93            };
94            fn infer_length_of_concatenated_array<T, const INFERRED_LENGTH_OF_CONCATENATED_ARRAY: usize>()
95                -> (
96                    [T; INFERRED_LENGTH_OF_CONCATENATED_ARRAY],
97                    [ArrayElement; INFERRED_LENGTH_OF_CONCATENATED_ARRAY],
98                )
99            {
100                ::core::unreachable!()
101            }
102            let (concatenated_array, mut concatenated_non_zst_array) = infer_length_of_concatenated_array();
103            let _constrain_array_element_types_to_be_equal: [&[_]; #num_arrays_plus_one] = [
104                &concatenated_array[..],
105                #(
106                    &#field_names[..],
107                )*
108            ];
109            concatenated_non_zst_array = unsafe {
110                ::core::mem::transmute(concat_non_zst_arrays)
111            };
112            ::core::mem::drop(concatenated_non_zst_array);
113            concatenated_array
114        } else {
115            #define_concat_arrays_type
116
117            let concat_arrays = ConcatArrays {
118                #(#field_names,)*
119            };
120            unsafe {
121                ::core::mem::transmute(concat_arrays)
122            }
123        }
124    }};
125    ret.into()
126}