extern crate proc_macro;
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Expr, ExprArray};
/// Flattens a (possibly nested) array literal into a `(Vec<_>, Vec<usize>)` pair.
///
/// `tensor_flatten!([[1, 2], [3, 4]])` expands to a block expression whose value is
/// `(vec![1, 2, 3, 4], vec![2, 2])`:
/// - the first element is every non-array leaf expression, in source (row-major) order;
/// - the second element is the array length recorded for each nesting level, outermost first.
#[proc_macro]
pub fn tensor_flatten(input: TokenStream) -> TokenStream {
    let array = parse_macro_input!(input as ExprArray);

    let mut dims: Vec<usize> = Vec::new();
    // Nesting levels are 1-based: the outermost array literal is level 1.
    let mut depth: usize = 1;
    let leaves = flatten_array(&array, &mut dims, &mut depth);

    TokenStream::from(quote! {
        {
            let flattened = vec![#(#leaves),*];
            let sizes = vec![#(#dims),*];
            (flattened, sizes)
        }
    })
}
/// Flattens `array`, and every array literal nested inside it, into its leaf expressions.
///
/// The leaves are returned in source (row-major) order. The length of each nesting level
/// is recorded in `sizes`.
///
/// `level` is the 1-based nesting level of `array`. The length of arrays at level `L` is
/// stored in `sizes[L - 1]`, pushed the first time that level is reached. `level` holds the
/// same value on return as on entry. Any element that is not an `Expr::Array` is treated as
/// a leaf and cloned as-is.
///
/// # Panics
///
/// Panics if the input is not a rectangular tensor. Inside a procedural macro, a panic is
/// reported by the compiler as an error at the macro call site. The input is rejected when:
/// - two arrays at the same nesting level have different lengths, e.g. `[[1, 2], [3]]`;
/// - nested arrays and leaves are mixed across depths, e.g. `[[1, 2], 3]`, `[1, [2]]`
///   or `[[1], [[]]]`.
///
/// Without these checks, `sizes` would not describe the flattened data.
fn flatten_array(array: &ExprArray, sizes: &mut Vec<usize>, level: &mut usize) -> Vec<Expr> {
    let mut flattened = Vec::new();
    // Depth at which leaf (non-array) elements live; fixed by the first leaf encountered.
    let mut leaf_level = None;
    flatten_level(array, sizes, *level, &mut leaf_level, &mut flattened);
    flattened
}

/// Recursive worker for `flatten_array`.
///
/// Appends the leaves of `array`, which sits at 1-based nesting `level`, to `out`. It also
/// validates the shape against the lengths already recorded in `sizes` and the leaf depth
/// recorded in `leaf_level`.
fn flatten_level(
    array: &ExprArray,
    sizes: &mut Vec<usize>,
    level: usize,
    leaf_level: &mut Option<usize>,
    out: &mut Vec<Expr>,
) {
    debug_assert!(level >= 1, "nesting levels are 1-based");

    let len = array.elems.len();
    if sizes.len() < level {
        // The first array seen at this level defines the expected length for the level.
        sizes.push(len);
    } else if sizes[level - 1] != len {
        panic!(
            "tensor_flatten: ragged array at nesting level {}: expected {} elements, found {}",
            level,
            sizes[level - 1],
            len
        );
    }

    for element in &array.elems {
        if let Expr::Array(nested) = element {
            // Nested arrays may only appear strictly above the level that holds the leaves.
            if matches!(*leaf_level, Some(leaf) if leaf <= level) {
                panic!(
                    "tensor_flatten: inconsistent nesting at level {}: found a nested array where scalar elements were expected",
                    level
                );
            }
            flatten_level(nested, sizes, level + 1, leaf_level, out);
        } else {
            // A leaf here makes this the innermost level. No deeper arrays may exist
            // (`sizes.len() > level` means one was already seen), and every other leaf must
            // sit at this same level.
            if sizes.len() > level || matches!(*leaf_level, Some(leaf) if leaf != level) {
                panic!(
                    "tensor_flatten: inconsistent nesting at level {}: found a scalar element where a nested array was expected",
                    level
                );
            }
            *leaf_level = Some(level);
            out.push(element.clone());
        }
    }
}