1
2use core::fmt;
3use std::{collections::HashMap, io::Write, process::{Command, Stdio}, vec, fmt::Display};
4
5use proc_macro::{Span, TokenStream};
6use proc_macro2::TokenStream as TokenStream2;
7use syn::{parse::{self, Parse, ParseStream}, parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Expr, ExprTuple, Ident, Member};
8use quote::{format_ident, quote, ToTokens};
9
/// Evaluates to `Err(syn::Error)` located at `$token`.
///
/// The message names the expected token kind(s) (`$types`) and shows the
/// pretty-printed `Debug` form of the token that was actually found.
macro_rules! expect_token_err {
    ($token:expr, $types:expr) => {{
        let span = $token.span();
        let message = format!("Expected a token of type(s) {}; got {:#?}", $types, $token);
        Err(Error::new(span, message))
    }};
}
15
/// Builds a `syn::Error` located at the macro call site, with a
/// `format!`-style message.
///
/// Any `format!` arguments are accepted, including field accesses and method
/// calls, not only single tokens. The span comes from `proc_macro2`, which,
/// unlike `proc_macro::Span::call_site()`, does not panic when used outside a
/// procedural-macro invocation (e.g. from unit tests).
macro_rules! err {
    ($($arg:tt)+) => {
        syn::Error::new(proc_macro2::Span::call_site(), format!($($arg)+))
    };
}
21
/// Converts an owned `Expr` into the payload of variant `Expr::$ty`.
///
/// One level of `Expr::Group` (an invisible-delimiter group) is looked through
/// before giving up. Evaluates to `Result<<payload>, syn::Error>`; on mismatch
/// the error (built by `expect_token_err!`) points at the expression that was
/// inspected last.
macro_rules! cast_expr {
    ($token:expr, $ty:tt) => {
        {
            match $token {
                Expr::$ty(expr) => Ok(expr),
                Expr::Group(group) => {
                    // Move the boxed inner expression out of the group and
                    // test it against the requested variant.
                    let expr = group.expr;
                    match *expr {
                        Expr::$ty(expr) => Ok(expr),
                        // Nothing was moved in this arm, so the box is still
                        // available for the error message.
                        _ => expect_token_err!(expr, stringify!(Expr::$ty))
                    }
                },
                _ => expect_token_err!($token, stringify!(Expr::$ty))
            }
        }
    };
}
39
/// Borrowing counterpart of `cast_expr!`.
///
/// Matches a `&Expr` against `Expr::$ty`, also accepting that variant wrapped
/// in one `Expr::Group`, and evaluates to `Result<&<payload>, syn::Error>`.
/// On mismatch the error (from `expect_token_err!`) points at the expression
/// that was inspected last.
macro_rules! cast_expr_ref {
    ($token:expr, $ty:tt) => {{
        match $token {
            Expr::$ty(found) => Ok(found),
            Expr::Group(group) => {
                let inner: &Expr = &group.expr;
                if let Expr::$ty(found) = inner {
                    Ok(found)
                } else {
                    expect_token_err!(group.expr, stringify!(Expr::$ty))
                }
            }
            _ => expect_token_err!($token, stringify!(Expr::$ty)),
        }
    }};
}
56
57fn expr_ident_string(expr: &Expr) -> Result<String, Error> {
58 match expr {
59 Expr::Path(path) => {
60 Ok(path.path.segments
61 .first()
62 .ok_or(Error::new_spanned(expr, "Couldn't get first item of path for ident string"))?
63 .ident.to_string()
64 )
65 }
66 _ => expect_token_err!(expr, "Expr::Path"),
67 }
68}
69
/// One matrix argument, written as a field expression `base.axes` (e.g. `a.ij`).
#[derive(Debug)]
struct Mat {
    /// The array expression: the `base` part of `base.axes`.
    pub expr: Expr,
    /// The member name; each character is one axis label.
    pub axes: String,
}
75
76impl Mat {
77 pub fn from_expr(expr: Expr) -> Result<Self, Error> {
78 let field = cast_expr!(expr, Field)?;
79
80 let axes = if let Member::Named(ident) = &field.member {
81 ident.to_string()
82 } else {
83 return expect_token_err!(field.member, "Member::Named")
84 };
85
86 let expr = *field.base;
87
88
89 Ok(Self {
90 expr,
91 axes,
92 })
93 }
94}
95
/// A declared einsum axis.
#[derive(Debug)]
struct Axis {
    /// The axis label, as it appears in matrix axis strings.
    char: char,
    /// The declared extent of the axis.
    size: usize,
    /// Identifier `axis_<char>`, emitted as the loop variable in generated code.
    ident: Ident,
}
102
103impl Axis {
104 pub fn new(char: char, size: usize) -> Self {
105 Self {
106 char,
107 size,
108 ident: Ident::new(&format!("axis_{}", char), Span::call_site().into()),
109 }
110 }
111
112}
113
114impl ToTokens for Axis {
115 fn to_tokens(&self, tokens: &mut TokenStream2) {
116 self.ident.to_tokens(tokens);
117 }
118}
119
120impl Display for Axis {
121 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122 write!(f, "{}({})", self.char, self.size)
123 }
124}
125
126fn parse_mat_tuple(tuple: ExprTuple) -> Result<Vec<Mat>, Error> {
127 tuple.elems.into_iter().map(|x| Mat::from_expr(x)).collect::<Result<Vec<Mat>, Error>>()
128}
129
/// Parsed arguments of `einsum_impl!`.
#[derive(Debug)]
struct EinsumArgs {
    /// Path prefix used to name `ArrayNumericType` in the generated code.
    crate_expr: Expr,
    /// Input matrices, in argument order.
    input: Vec<Mat>,
    /// Output matrices; their axis strings form the einsum output spec.
    output: Vec<Mat>,
    /// Declared axes, keyed by their label character.
    axes: HashMap<char, Axis>,
}
137
138impl Parse for EinsumArgs {
139 fn parse(input: ParseStream) -> parse::Result<Self> {
140 let punct: Punctuated<ExprTuple, syn::token::Comma> = Punctuated::parse_terminated(input)?;
141 let mut iter = punct.into_iter();
142 let err = Error::new(input.span(), "Not enough args");
143 let [crate_expr, input_expr, output_expr, dims_expr] = [
144 iter.next().ok_or(err.clone())?,
145 iter.next().ok_or(err.clone())?,
146 iter.next().ok_or(err.clone())?,
147 iter.next().ok_or(err)?
148 ];
149
150 Ok(Self {
151 crate_expr: crate_expr.elems.first().ok_or(Error::new(crate_expr.span(), "Couldn't get first item of crate_expr"))?.clone(),
152 input: parse_mat_tuple(input_expr)?,
153 output: parse_mat_tuple(output_expr)?,
154 axes: dims_expr.elems.iter().map(|x| {
155 let x = match x {
156 Expr::Tuple(x) => x,
157 _ => return expect_token_err!(x, "Expr::Tuple")
158 };
159 let axis = expr_ident_string(&x.elems[0])?;
160
161 let dim_expr = cast_expr_ref!(&x.elems[1], Lit)?;
162 let dim = match &dim_expr.lit {
163 syn::Lit::Int(int) => int.base10_parse::<usize>()?,
164 _ => return Err(Error::new(dim_expr.lit.span(), format!("Expected an integer, got {:?}", dim_expr.lit)))
165 };
166
167 let char = axis.chars().next().ok_or(Error::new(input.span(), "Couldn't read axis chars"))?;
168 Ok((char, Axis::new(char, dim)))
169 }).collect::<Result<HashMap<_, _>, Error>>()?,
170 })
171 }
172}
173
174
175#[proc_macro]
182pub fn einsum_impl(stream: TokenStream) -> TokenStream {
183 println!("Start");
184 let args = parse_macro_input!(stream);
185
186 let res = handle_errors(do_einsum(&args));
187
188 quote!{{ #res }}.into()
189}
190
191fn handle_errors(result: Result<TokenStream2, Error>) -> TokenStream2 {
192 match result {
193 Ok(res) => res,
194 Err(err) => err.to_compile_error()
195 }
196}
197
198fn do_einsum(args: &EinsumArgs) -> Result<TokenStream2, Error> {
199 let opt = get_opt(&args)?;
201
202 let EinsumArgs { crate_expr, input, output: _, axes: dims, .. } = args;
203
204 let mut tokens: Vec<TokenStream2> = vec![];
205
206 struct MatInfo {
207 ident: Ident,
208 axes: String,
209 id: usize,
210 }
211
212 impl ToTokens for MatInfo {
213 fn to_tokens(&self, tokens: &mut TokenStream2) {
214 self.ident.to_tokens(tokens);
215 }
216 }
217
218 let mut idents = vec![];
219 let mut exprs = vec![];
220
221 let mut mats = vec![];
222 for (i, mat) in input.iter().enumerate() {
223 let ident = Ident::new(&format!("mat_{}", i), mat.expr.span());
224 let expr = &mat.expr;
225 idents.push(ident.clone());
226 exprs.push(expr);
227
228 mats.push(MatInfo { ident, axes: mat.axes.clone(), id: i });
229 }
230
231 tokens.push(quote!{
232 });
234
235 let mut lhs = mats.iter().map(|x| x.id).collect::<Vec<usize>>();
236 let mut out_dim = vec![];
237
238 for (mut i, mut j, contraction) in opt {
239 if i > j {
240 std::mem::swap(&mut i, &mut j);
241 }
242
243 let out = MatInfo {
244 ident: Ident::new(&format!("mat_{}", mats.len()), Span::call_site().into()),
245 axes: contraction.split("->").nth(1)
246 .ok_or(err!("Where is the second part of the contraction? Implicit contractions aren't allowed"))?
247 .to_string(),
248 id: mats.len(),
249 };
250
251 let mut dim_tuple = vec![];
252
253 println!("Contraction: {}", contraction);
254 println!("---");
255
256 for axis in out.axes.chars() {
257 let size = dims.get(&axis).expect(format!("Axis {} not found in dims", axis).as_str()).size;
258 dim_tuple.push(quote! {
259 #size,
260 });
261 }
262
263 tokens.push(quote! {
264 let mut #out = ndarray::Array::<T, _>::zeros((#(#dim_tuple)*));
265 });
266
267 mats.push(out);
268
269 let a = &mats[lhs.remove(i)];
271 let b = &mats[lhs.remove(j - 1)];
272 let out = &mats.last().unwrap();
273 let out_axes = out.axes.chars().map(|x| dims.get(&x).expect("Internal error: no dim?"));
274 out_dim = out.axes.chars().map(|x| dims.get(&x).expect("Internal error: no dim?").size).collect();
275
276 lhs.push(out.id);
277
278 let a_axes = a.axes.chars().map(|x| dims.get(&x));
280 let b_axes = b.axes.chars().map(|x| dims.get(&x));
281
282 let mut all_axes: Vec<char> = vec![];
283
284 for axis in (a.axes.clone() + &b.axes + &out.axes).chars() {
285 if !all_axes.contains(&axis) {
286 all_axes.push(axis);
287 }
288 }
289
290 let mut body = quote! {
295 #out[(#(#out_axes),*)] += #a[(#(#a_axes),*)] * #b[(#(#b_axes),*)];
296 };
297
298 for axis in all_axes.iter().rev() {
301 let axis = dims.get(axis).expect("Internal error: no dim?");
302 let size = axis.size;
303
304 body = quote! {
305 for #axis in 0..#size {
306 #body
307 }
308 }
309 }
310
311 tokens.push(body);
312 }
313
314 let out = &mats[lhs[0]];
315
316 let mut input_generics_defs = vec![];
317 let mut input_generics = vec![];
318 for i in 0..idents.len() {
319 let ident = format_ident!("I{}", i);
320 input_generics.push(ident.clone());
321 input_generics_defs.push(quote! {
322 #ident: ndarray::Dimension
323 });
324 }
325
326 let dim_len = out_dim.len();
327 let input_index_tys = input.iter().map(|x| (0..x.axes.len()).map(|_| quote!{usize}).collect::<Vec<_>>()).collect::<Vec<_>>();
328
329 let final_expr = quote! {
330 #[inline]
332 fn __einsum_impl<T: #crate_expr::ArrayNumericType, #(#input_generics_defs),*>
333 (#(#idents: &ndarray::Array<T, #input_generics>),*) -> ndarray::Array<T, ndarray::Dim<[usize; #dim_len]>>
334 where #((#(#input_index_tys),*): ndarray::NdIndex<#input_generics>),* {
335 #(#tokens)*
336 #out
337 }
338 __einsum_impl(#(&#exprs),*)
339 };
340
341 Ok(final_expr.into())
342}
343
344fn get_opt(args: &EinsumArgs) -> Result<Vec<(usize, usize, String)>, Error> {
345 let EinsumArgs { input, output, axes: dims, .. } = args;
346 let str_input = input.iter().map(|x| x.axes.clone()).collect::<Vec<String>>().join(",");
347 let str_output = output.iter().map(|x| x.axes.clone()).collect::<Vec<String>>().join(",");
348 let opt_einsum_input = format!("{str_input}->{str_output}");
349
350 let mut dim_str = String::new();
351
352 for mat in input {
353 dim_str.push_str("(");
354 for axis in mat.axes.chars() {
355 let Some(axis) = dims.get(&axis) else {
356 return Err(Error::new(Span::call_site().into(), format!("Axis {} not found in dims", axis)));
357 };
358
359 dim_str.push_str(format!("{},", axis.size).as_str());
360 }
361 dim_str.push_str("), ");
362 }
363
364 fn pyerr(pretext: &str) -> impl Fn(std::io::Error) -> Error {
365 let pretext = pretext.to_string();
366 move |err| Error::new(Span::call_site().into(), format!("{}: {}", pretext, err))
367 }
368
369 let py = Command::new("python")
370 .stdin(Stdio::piped())
371 .stdout(Stdio::piped())
372 .stderr(Stdio::piped())
373 .spawn()
374 .map_err(pyerr("Error while trying to spawn Python process"))?;
375
376 let code = format!(r#"
377import opt_einsum as oe
378expr = oe.contract_expression("{opt_einsum_input}", {dim_str})
379print("\n".join([";".join([str(contraction[0][0]), str(contraction[0][1]), contraction[2]]) for contraction in expr.contraction_list]))
380"#);
381 println!("{}", code);
382
383 let mut stdin = py.stdin.as_ref().ok_or(err!("Couldn't get stdin for Python process"))?;
384 stdin.write(code.as_bytes()).map_err(pyerr("Error while writing to Python process"))?;
385
386 let output = py.wait_with_output()
387 .map_err(pyerr("Couldn't wait on Python process"))?;
388
389 if !output.status.success() {
390 let code = output.status.code().unwrap_or(-1);
391 let err = String::from_utf8(output.stderr).unwrap_or("Error while reading Python process stderr".to_string());
392 let out = String::from_utf8(output.stdout).unwrap_or("Error while reading Python process stdout".to_string());
393 return Err(err!("Python process failed with non-zero exit code: {}\nstdout:\n{}\nstderr: {}", code, err, out));
394 }
395
396 let out = String::from_utf8(
397 output.stdout
398 ).map_err(|x| Error::new(
399 Span::call_site().into(),
400 format!("Error while parsing Python output as utf8: {}", x)
401 ))?;
402
403 let mut list = Vec::new();
404
405 println!("{}", out);
406
407 let int_err = |err| Error::new(Span::call_site().into(), format!("Error while parsing integer from Python opt_einsum: {}", err));
408 let not_enough_err = Error::new(Span::call_site().into(), "Not enough items in contraction list returned from Python opt_einsum");
409
410 for line in out.lines() {
411 let line = line.trim();
412
413 let mut iter = line.split(";").peekable();
414 while iter.peek().is_some() {
415 list.push((
416 iter.next().ok_or(not_enough_err.clone())?.parse().map_err(int_err)?,
417 iter.next().ok_or(not_enough_err.clone())?.parse().map_err(int_err)?,
418 iter.next().ok_or(not_enough_err.clone())?.to_string()
419 ));
420 }
421 }
422
423 Ok(list)
424}