einsum_derive/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use einsum_codegen::{codegen::ndarray::*, *};
4use proc_macro::TokenStream;
5use proc_macro2::TokenStream as TokenStream2;
6use proc_macro_error::{abort_call_site, proc_macro_error};
7use quote::quote;
8use syn::parse::Parser;
9
10/// proc-macro based einsum
11#[proc_macro_error]
12#[proc_macro]
13pub fn einsum(input: TokenStream) -> TokenStream {
14    einsum2(input.into()).into()
15}
16
17fn einsum2(input: TokenStream2) -> TokenStream2 {
18    let (subscripts, args) = parse(input);
19    let arg_ident: Vec<_> = (0..args.len()).map(Position::Arg).collect();
20    let path = Path::brute_force(&subscripts).expect("Failed to construct execution path");
21    let fn_defs: Vec<_> = path
22        .iter()
23        .map(|ss| {
24            let inner = naive::inner(ss);
25            function_definition(ss, inner)
26        })
27        .collect();
28    let out = path.output();
29    if path.num_args() != args.len() {
30        abort_call_site!(
31            "Argument number mismatch: subscripts ({}), args ({})",
32            path.num_args(),
33            args.len()
34        )
35    }
36
37    quote! {
38        {
39            #(#fn_defs)*
40            #(let #arg_ident = #args;)*
41            #(#path)*
42            #out
43        }
44    }
45}
46
47fn parse(input: TokenStream2) -> (String, Vec<syn::Expr>) {
48    let parser = syn::punctuated::Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated;
49    let args = parser.parse2(input).expect("Invalid input for einsum!");
50    let mut iter = args.into_iter();
51    let subscripts = if let Some(syn::Expr::Lit(syn::ExprLit {
52        lit: syn::Lit::Str(lit),
53        attrs: _,
54    })) = iter.next()
55    {
56        lit.value()
57    } else {
58        panic!("einsum! must start with subscript string literal")
59    };
60    let args = iter.collect::<Vec<_>>();
61    (subscripts, args)
62}
63
// Unit tests for `parse`, plus inline `insta` snapshot tests of the code that
// `einsum2` generates.
#[cfg(test)]
mod test {
    use super::*;
    use einsum_codegen::codegen::format_block;
    use std::str::FromStr;

    /// `parse` separates the subscript literal from the operand expressions.
    #[test]
    fn test_parse() {
        let input = TokenStream2::from_str(r#""ij,jk->ik", a, b"#).unwrap();
        let (subscripts, exprs) = parse(input);
        assert_eq!(subscripts, "ij,jk->ik");
        assert_eq!(exprs.len(), 2);
        assert_eq!(exprs[0], syn::parse_str::<syn::Expr>("a").unwrap());
        assert_eq!(exprs[1], syn::parse_str::<syn::Expr>("b").unwrap());
    }

    /// Two-operand contraction `ij,jk->ik`. The path has a single step, so
    /// exactly one function (`ij_jk__ik`) is generated and called on the bound
    /// operands.
    #[test]
    fn einsum_ij_jk() {
        let input = TokenStream2::from_str(r#""ij,jk->ik", a, b"#).unwrap();
        // `format_block` renders the generated tokens as formatted source so the
        // snapshot below is readable.
        let tt = format_block(einsum2(input).to_string());
        // NOTE(review): the innermost loop in this snapshot assigns
        // (`out0[(i, k)] = ...`) instead of accumulating (`+=`). Each output
        // element therefore keeps only the last `j` term rather than the sum
        // over `j`. This presumably comes from `naive::inner` in einsum_codegen.
        // The snapshot currently pins that output; confirm and update it once
        // the generator is fixed.
        insta::assert_snapshot!(tt, @r###"
        {
            fn ij_jk__ik<T, S0, S1>(
                arg0: ndarray::ArrayBase<S0, ndarray::Ix2>,
                arg1: ndarray::ArrayBase<S1, ndarray::Ix2>,
            ) -> ndarray::Array<T, ndarray::Ix2>
            where
                T: ndarray::LinalgScalar,
                S0: ndarray::Data<Elem = T>,
                S1: ndarray::Data<Elem = T>,
            {
                let (n_i, n_j) = arg0.dim();
                let (_, n_k) = arg1.dim();
                {
                    let (n_0, n_1) = arg0.dim();
                    assert_eq!(n_0, n_i);
                    assert_eq!(n_1, n_j);
                }
                {
                    let (n_0, n_1) = arg1.dim();
                    assert_eq!(n_0, n_j);
                    assert_eq!(n_1, n_k);
                }
                let mut out0 = ndarray::Array::zeros((n_i, n_k));
                for i in 0..n_i {
                    for k in 0..n_k {
                        for j in 0..n_j {
                            out0[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
                        }
                    }
                }
                out0
            }
            let arg0 = a;
            let arg1 = b;
            let out0 = ij_jk__ik(arg0, arg1);
            out0
        }
        "###);
    }

    /// Three-operand contraction `ij,jk,kl->il`. The path has two steps:
    /// - `ij_jk__ik` produces the intermediate `out1`;
    /// - `ik_kl__il` contracts `out1` with `arg2` into the final `out0`.
    #[test]
    fn einsum_ij_jk_kl() {
        let input = TokenStream2::from_str(r#""ij,jk,kl->il", a, b, c"#).unwrap();
        let tt = format_block(einsum2(input).to_string());
        // NOTE(review): both generated functions below also assign (`=`) in the
        // innermost loop instead of accumulating (`+=`). This is the same
        // suspected codegen issue as in `einsum_ij_jk`; confirm before relying
        // on this snapshot.
        insta::assert_snapshot!(tt, @r###"
        {
            fn ij_jk__ik<T, S0, S1>(
                arg0: ndarray::ArrayBase<S0, ndarray::Ix2>,
                arg1: ndarray::ArrayBase<S1, ndarray::Ix2>,
            ) -> ndarray::Array<T, ndarray::Ix2>
            where
                T: ndarray::LinalgScalar,
                S0: ndarray::Data<Elem = T>,
                S1: ndarray::Data<Elem = T>,
            {
                let (n_i, n_j) = arg0.dim();
                let (_, n_k) = arg1.dim();
                {
                    let (n_0, n_1) = arg0.dim();
                    assert_eq!(n_0, n_i);
                    assert_eq!(n_1, n_j);
                }
                {
                    let (n_0, n_1) = arg1.dim();
                    assert_eq!(n_0, n_j);
                    assert_eq!(n_1, n_k);
                }
                let mut out1 = ndarray::Array::zeros((n_i, n_k));
                for i in 0..n_i {
                    for k in 0..n_k {
                        for j in 0..n_j {
                            out1[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
                        }
                    }
                }
                out1
            }
            fn ik_kl__il<T, S0, S1>(
                out1: ndarray::ArrayBase<S0, ndarray::Ix2>,
                arg2: ndarray::ArrayBase<S1, ndarray::Ix2>,
            ) -> ndarray::Array<T, ndarray::Ix2>
            where
                T: ndarray::LinalgScalar,
                S0: ndarray::Data<Elem = T>,
                S1: ndarray::Data<Elem = T>,
            {
                let (n_i, n_k) = out1.dim();
                let (_, n_l) = arg2.dim();
                {
                    let (n_0, n_1) = out1.dim();
                    assert_eq!(n_0, n_i);
                    assert_eq!(n_1, n_k);
                }
                {
                    let (n_0, n_1) = arg2.dim();
                    assert_eq!(n_0, n_k);
                    assert_eq!(n_1, n_l);
                }
                let mut out0 = ndarray::Array::zeros((n_i, n_l));
                for i in 0..n_i {
                    for l in 0..n_l {
                        for k in 0..n_k {
                            out0[(i, l)] = out1[(i, k)] * arg2[(k, l)];
                        }
                    }
                }
                out0
            }
            let arg0 = a;
            let arg1 = b;
            let arg2 = c;
            let out1 = ij_jk__ik(arg0, arg1);
            let out0 = ik_kl__il(out1, arg2);
            out0
        }
        "###);
    }
}