vectrix_macro/
lib.rs

1use proc_macro::{self, TokenStream};
2use quote::quote;
3use syn::parse::{Parse, ParseStream, Result};
4use syn::punctuated::Punctuated;
5use syn::{parse_macro_input, Expr, Token};
6
7type Delimited<T> = Punctuated<T, Token![,]>;
8type Vector = Delimited<Expr>;
9type Matrix = Punctuated<Vector, Token![;]>;
10
11struct Input {
12    matrix: Matrix,
13}
14
15impl Parse for Input {
16    fn parse(input: ParseStream<'_>) -> Result<Self> {
17        let matrix = Matrix::parse_terminated_with(input, Vector::parse_separated_nonempty)?;
18        Ok(Self { matrix })
19    }
20}
21
22impl Input {
23    fn into_rows(self) -> Vec<Vec<Expr>> {
24        self.matrix
25            .into_iter()
26            .map(|vector| vector.into_iter().collect())
27            .collect()
28    }
29}
30
31#[proc_macro]
32pub fn matrix(input: TokenStream) -> TokenStream {
33    let rows = parse_macro_input!(input as Input).into_rows();
34
35    // Get the length of the first row, i.e. the number of columns
36    let n = rows.first().map_or(0, Vec::len);
37
38    // Transpose from row-major order to column-major order
39    let columns: Delimited<_> = (0..n)
40        .map(|column| {
41            let column: Vector = rows
42                .iter()
43                .filter_map(|row| row.get(column))
44                .cloned()
45                .collect();
46            quote! { [ #column ] }
47        })
48        .collect();
49
50    TokenStream::from(quote! { [ #columns ] })
51}