einsum_codegen/codegen/ndarray/
mod.rs

//! Code generation targeting the [ndarray](https://crates.io/crates/ndarray) crate
2
3pub mod naive;
4
5use crate::subscripts::Subscripts;
6use proc_macro2::TokenStream as TokenStream2;
7use quote::{format_ident, quote};
8
9fn dim(n: usize) -> syn::Path {
10    let ix = quote::format_ident!("Ix{}", n);
11    syn::parse_quote! { ndarray::#ix }
12}
13
14/// Generate einsum function definition
15pub fn function_definition(subscripts: &Subscripts, inner: TokenStream2) -> TokenStream2 {
16    let fn_name = format_ident!("{}", subscripts.escaped_ident());
17    let n = subscripts.inputs.len();
18
19    let args = &subscripts.inputs;
20    let storages: Vec<syn::Ident> = (0..n).map(|n| quote::format_ident!("S{}", n)).collect();
21    let dims: Vec<syn::Path> = subscripts
22        .inputs
23        .iter()
24        .map(|ss| dim(ss.indices().len()))
25        .collect();
26
27    let out_dim = dim(subscripts.output.indices().len());
28
29    quote! {
30        fn #fn_name<T, #(#storages),*>(
31            #( #args: ndarray::ArrayBase<#storages, #dims> ),*
32        ) -> ndarray::Array<T, #out_dim>
33        where
34            T: ndarray::LinalgScalar,
35            #( #storages: ndarray::Data<Elem = T> ),*
36        {
37            #inner
38        }
39    }
40}
41
#[cfg(test)]
mod test {
    use crate::{codegen::format_block, *};

    /// Pins the signature generated for the subscripts `ij,jk->ik`,
    /// using `todo!()` as a placeholder function body.
    #[test]
    fn function_definition_snapshot() {
        let mut namespace = Namespace::init();
        let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap();

        // Render the definition to a string and format it before snapshotting.
        let definition = super::function_definition(&subscripts, quote::quote! { todo!() });
        let formatted = format_block(definition.to_string());

        insta::assert_snapshot!(formatted, @r###"
        fn ij_jk__ik<T, S0, S1>(
            arg0: ndarray::ArrayBase<S0, ndarray::Ix2>,
            arg1: ndarray::ArrayBase<S1, ndarray::Ix2>,
        ) -> ndarray::Array<T, ndarray::Ix2>
        where
            T: ndarray::LinalgScalar,
            S0: ndarray::Data<Elem = T>,
            S1: ndarray::Data<Elem = T>,
        {
            todo!()
        }
        "###);
    }
}
66}