einsum_codegen/codegen/ndarray/
mod.rs1pub mod naive;
4
5use crate::subscripts::Subscripts;
6use proc_macro2::TokenStream as TokenStream2;
7use quote::{format_ident, quote};
8
9fn dim(n: usize) -> syn::Path {
10 let ix = quote::format_ident!("Ix{}", n);
11 syn::parse_quote! { ndarray::#ix }
12}
13
14pub fn function_definition(subscripts: &Subscripts, inner: TokenStream2) -> TokenStream2 {
16 let fn_name = format_ident!("{}", subscripts.escaped_ident());
17 let n = subscripts.inputs.len();
18
19 let args = &subscripts.inputs;
20 let storages: Vec<syn::Ident> = (0..n).map(|n| quote::format_ident!("S{}", n)).collect();
21 let dims: Vec<syn::Path> = subscripts
22 .inputs
23 .iter()
24 .map(|ss| dim(ss.indices().len()))
25 .collect();
26
27 let out_dim = dim(subscripts.output.indices().len());
28
29 quote! {
30 fn #fn_name<T, #(#storages),*>(
31 #( #args: ndarray::ArrayBase<#storages, #dims> ),*
32 ) -> ndarray::Array<T, #out_dim>
33 where
34 T: ndarray::LinalgScalar,
35 #( #storages: ndarray::Data<Elem = T> ),*
36 {
37 #inner
38 }
39 }
40}
41
#[cfg(test)]
mod test {
    use crate::{codegen::format_block, *};

    /// Pins the formatted output of `function_definition` for the subscripts
    /// `ij,jk->ik` with a `todo!()` body.
    #[test]
    fn function_definition_snapshot() {
        let mut names = Namespace::init();
        let subscripts = Subscripts::from_raw_indices(&mut names, "ij,jk->ik").unwrap();
        let body = quote::quote! { todo!() };
        let generated = super::function_definition(&subscripts, body);
        let formatted = format_block(generated.to_string());
        insta::assert_snapshot!(formatted, @r###"
        fn ij_jk__ik<T, S0, S1>(
            arg0: ndarray::ArrayBase<S0, ndarray::Ix2>,
            arg1: ndarray::ArrayBase<S1, ndarray::Ix2>,
        ) -> ndarray::Array<T, ndarray::Ix2>
        where
            T: ndarray::LinalgScalar,
            S0: ndarray::Data<Elem = T>,
            S1: ndarray::Data<Elem = T>,
        {
            todo!()
        }
        "###);
    }
}