1#![doc = include_str!("../README.md")]
2
3use einsum_codegen::{codegen::ndarray::*, *};
4use proc_macro::TokenStream;
5use proc_macro2::TokenStream as TokenStream2;
6use proc_macro_error::{abort_call_site, proc_macro_error};
7use quote::quote;
8use syn::parse::Parser;
9
10#[proc_macro_error]
12#[proc_macro]
13pub fn einsum(input: TokenStream) -> TokenStream {
14 einsum2(input.into()).into()
15}
16
17fn einsum2(input: TokenStream2) -> TokenStream2 {
18 let (subscripts, args) = parse(input);
19 let arg_ident: Vec<_> = (0..args.len()).map(Position::Arg).collect();
20 let path = Path::brute_force(&subscripts).expect("Failed to construct execution path");
21 let fn_defs: Vec<_> = path
22 .iter()
23 .map(|ss| {
24 let inner = naive::inner(ss);
25 function_definition(ss, inner)
26 })
27 .collect();
28 let out = path.output();
29 if path.num_args() != args.len() {
30 abort_call_site!(
31 "Argument number mismatch: subscripts ({}), args ({})",
32 path.num_args(),
33 args.len()
34 )
35 }
36
37 quote! {
38 {
39 #(#fn_defs)*
40 #(let #arg_ident = #args;)*
41 #(#path)*
42 #out
43 }
44 }
45}
46
47fn parse(input: TokenStream2) -> (String, Vec<syn::Expr>) {
48 let parser = syn::punctuated::Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated;
49 let args = parser.parse2(input).expect("Invalid input for einsum!");
50 let mut iter = args.into_iter();
51 let subscripts = if let Some(syn::Expr::Lit(syn::ExprLit {
52 lit: syn::Lit::Str(lit),
53 attrs: _,
54 })) = iter.next()
55 {
56 lit.value()
57 } else {
58 panic!("einsum! must start with subscript string literal")
59 };
60 let args = iter.collect::<Vec<_>>();
61 (subscripts, args)
62}
63
#[cfg(test)]
mod test {
    //! Unit tests for input parsing and snapshot tests of the generated expansion.

    use super::*;
    use einsum_codegen::codegen::format_block;
    use std::str::FromStr;

    /// `parse` splits the input into the subscript literal and the operand
    /// expressions, preserving operand order.
    #[test]
    fn test_parse() {
        let input = TokenStream2::from_str(r#""ij,jk->ik", a, b"#).unwrap();
        let (subscripts, exprs) = parse(input);
        assert_eq!(subscripts, "ij,jk->ik");
        assert_eq!(exprs.len(), 2);
        assert_eq!(exprs[0], syn::parse_str::<syn::Expr>("a").unwrap());
        assert_eq!(exprs[1], syn::parse_str::<syn::Expr>("b").unwrap());
    }

    /// Snapshot of the full expansion for a single two-operand contraction:
    /// one helper function `ij_jk__ik`, the operand bindings, and the call.
    ///
    /// NOTE(review): the pinned expansion assigns with `=` in the innermost
    /// `j` loop, so `out0[(i, k)]` ends up holding only the last `j` term
    /// instead of the sum over `j`. This looks like a bug in
    /// `einsum_codegen`'s naive inner loop (`+=` expected). Confirm, and update
    /// this snapshot once the generator is fixed.
    #[test]
    fn einsum_ij_jk() {
        let input = TokenStream2::from_str(r#""ij,jk->ik", a, b"#).unwrap();
        // Format the generated tokens so the snapshot is readable, multi-line code.
        let tt = format_block(einsum2(input).to_string());
        insta::assert_snapshot!(tt, @r###"
        {
            fn ij_jk__ik<T, S0, S1>(
                arg0: ndarray::ArrayBase<S0, ndarray::Ix2>,
                arg1: ndarray::ArrayBase<S1, ndarray::Ix2>,
            ) -> ndarray::Array<T, ndarray::Ix2>
            where
                T: ndarray::LinalgScalar,
                S0: ndarray::Data<Elem = T>,
                S1: ndarray::Data<Elem = T>,
            {
                let (n_i, n_j) = arg0.dim();
                let (_, n_k) = arg1.dim();
                {
                    let (n_0, n_1) = arg0.dim();
                    assert_eq!(n_0, n_i);
                    assert_eq!(n_1, n_j);
                }
                {
                    let (n_0, n_1) = arg1.dim();
                    assert_eq!(n_0, n_j);
                    assert_eq!(n_1, n_k);
                }
                let mut out0 = ndarray::Array::zeros((n_i, n_k));
                for i in 0..n_i {
                    for k in 0..n_k {
                        for j in 0..n_j {
                            out0[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
                        }
                    }
                }
                out0
            }
            let arg0 = a;
            let arg1 = b;
            let out0 = ij_jk__ik(arg0, arg1);
            out0
        }
        "###);
    }

    /// Snapshot for a three-operand chain. The chosen path first contracts
    /// `ij,jk->ik` into the intermediate `out1`, then `ik,kl->il` into the
    /// final `out0`.
    ///
    /// NOTE(review): both helpers carry the same `=` (instead of `+=`)
    /// innermost-loop assignment as in `einsum_ij_jk`. Update together with
    /// that snapshot.
    #[test]
    fn einsum_ij_jk_kl() {
        let input = TokenStream2::from_str(r#""ij,jk,kl->il", a, b, c"#).unwrap();
        let tt = format_block(einsum2(input).to_string());
        insta::assert_snapshot!(tt, @r###"
        {
            fn ij_jk__ik<T, S0, S1>(
                arg0: ndarray::ArrayBase<S0, ndarray::Ix2>,
                arg1: ndarray::ArrayBase<S1, ndarray::Ix2>,
            ) -> ndarray::Array<T, ndarray::Ix2>
            where
                T: ndarray::LinalgScalar,
                S0: ndarray::Data<Elem = T>,
                S1: ndarray::Data<Elem = T>,
            {
                let (n_i, n_j) = arg0.dim();
                let (_, n_k) = arg1.dim();
                {
                    let (n_0, n_1) = arg0.dim();
                    assert_eq!(n_0, n_i);
                    assert_eq!(n_1, n_j);
                }
                {
                    let (n_0, n_1) = arg1.dim();
                    assert_eq!(n_0, n_j);
                    assert_eq!(n_1, n_k);
                }
                let mut out1 = ndarray::Array::zeros((n_i, n_k));
                for i in 0..n_i {
                    for k in 0..n_k {
                        for j in 0..n_j {
                            out1[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
                        }
                    }
                }
                out1
            }
            fn ik_kl__il<T, S0, S1>(
                out1: ndarray::ArrayBase<S0, ndarray::Ix2>,
                arg2: ndarray::ArrayBase<S1, ndarray::Ix2>,
            ) -> ndarray::Array<T, ndarray::Ix2>
            where
                T: ndarray::LinalgScalar,
                S0: ndarray::Data<Elem = T>,
                S1: ndarray::Data<Elem = T>,
            {
                let (n_i, n_k) = out1.dim();
                let (_, n_l) = arg2.dim();
                {
                    let (n_0, n_1) = out1.dim();
                    assert_eq!(n_0, n_i);
                    assert_eq!(n_1, n_k);
                }
                {
                    let (n_0, n_1) = arg2.dim();
                    assert_eq!(n_0, n_k);
                    assert_eq!(n_1, n_l);
                }
                let mut out0 = ndarray::Array::zeros((n_i, n_l));
                for i in 0..n_i {
                    for l in 0..n_l {
                        for k in 0..n_k {
                            out0[(i, l)] = out1[(i, k)] * arg2[(k, l)];
                        }
                    }
                }
                out0
            }
            let arg0 = a;
            let arg1 = b;
            let arg2 = c;
            let out1 = ij_jk__ik(arg0, arg1);
            let out0 = ik_kl__il(out1, arg2);
            out0
        }
        "###);
    }
}