einsum_impl/
lib.rs

1
2use core::fmt;
3use std::{collections::HashMap, io::Write, process::{Command, Stdio}, vec, fmt::Display};
4
5use proc_macro::{Span, TokenStream};
6use proc_macro2::TokenStream as TokenStream2;
7use syn::{parse::{self, Parse, ParseStream}, parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Expr, ExprTuple, Ident, Member};
8use quote::{format_ident, quote, ToTokens};
9
/// Expands to an `Err(syn::Error)` reporting that `$token` is not one of the
/// expected syntax node kinds.
///
/// * `$token` - the offending node; its span anchors the error and its
///   pretty-printed `Debug` form is included in the message.
/// * `$types` - anything `Display`able describing the node kind(s) that were
///   expected.
macro_rules! expect_token_err {
    ($token:expr, $types:expr) => {{
        // Borrow once so the span and the debug dump are taken from the same value.
        let offending = &$token;
        Err(Error::new(
            offending.span(),
            format!("Expected a token of type(s) {}; got {:#?}", $types, offending),
        ))
    }};
}
15
/// Expands to a `syn::Error` located at the macro call site.
///
/// The arguments are forwarded to `format!`. Each argument has to be a single
/// token tree (a literal, an identifier, a parenthesised expression, ...).
macro_rules! err {
    ($($tt:tt),+) => {
        Error::new(Into::into(Span::call_site()), format!($($tt),+))
    };
}
21
/// Unwraps an owned `Expr` into the payload of its `Expr::$ty` variant.
///
/// A single `Expr::Group` wrapper is looked through. Anything else expands to
/// an `Err` built by `expect_token_err!`, so the macro evaluates to a
/// `Result<_, syn::Error>`.
macro_rules! cast_expr {
    ($token:expr, $ty:tt) => {
        {
            match $token {
                Expr::$ty(inner) => Ok(inner),
                Expr::Group(grouped) => {
                    // Look through one layer of grouping.
                    let boxed = grouped.expr;
                    match *boxed {
                        Expr::$ty(inner) => Ok(inner),
                        _ => expect_token_err!(boxed, stringify!(Expr::$ty)),
                    }
                }
                _ => expect_token_err!($token, stringify!(Expr::$ty)),
            }
        }
    };
}
39
/// Borrowing counterpart of `cast_expr!`: given a reference to an `Expr`,
/// yields a reference to the payload of its `Expr::$ty` variant.
///
/// A single `Expr::Group` wrapper is looked through. Anything else expands to
/// an `Err` built by `expect_token_err!`.
macro_rules! cast_expr_ref {
    ($token:expr, $ty:tt) => {
        {
            match $token {
                Expr::$ty(inner) => Ok(inner),
                Expr::Group(grouped) => {
                    // Look through one layer of grouping without taking ownership.
                    match &*grouped.expr {
                        Expr::$ty(inner) => Ok(inner),
                        _ => expect_token_err!(grouped.expr, stringify!(Expr::$ty)),
                    }
                }
                _ => expect_token_err!($token, stringify!(Expr::$ty)),
            }
        }
    };
}
56
57fn expr_ident_string(expr: &Expr) -> Result<String, Error> {
58    match expr {
59        Expr::Path(path) => {
60            Ok(path.path.segments
61                .first()
62                .ok_or(Error::new_spanned(expr, "Couldn't get first item of path for ident string"))?
63                .ident.to_string()
64            )
65        }
66        _ => expect_token_err!(expr, "Expr::Path"),
67    }
68}
69
70#[derive(Debug)]
71struct Mat {
72    pub expr: Expr,
73    pub axes: String,
74}
75
76impl Mat {
77    pub fn from_expr(expr: Expr) -> Result<Self, Error> {
78        let field = cast_expr!(expr, Field)?;
79
80        let axes = if let Member::Named(ident) = &field.member {
81            ident.to_string()
82        } else {
83            return expect_token_err!(field.member, "Member::Named")
84        };
85
86        let expr = *field.base;
87
88
89        Ok(Self {
90            expr,
91            axes,
92        })
93    }
94}
95
96#[derive(Debug)]
97struct Axis {
98    char: char,
99    size: usize,
100    ident: Ident,
101}
102
103impl Axis {
104    pub fn new(char: char, size: usize) -> Self {
105        Self {
106            char,
107            size,
108            ident: Ident::new(&format!("axis_{}", char), Span::call_site().into()),
109        }
110    }
111
112}
113
114impl ToTokens for Axis {
115    fn to_tokens(&self, tokens: &mut TokenStream2) {
116        self.ident.to_tokens(tokens);
117    }
118}
119
120impl Display for Axis {
121    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122        write!(f, "{}({})", self.char, self.size)
123    }
124}
125
126fn parse_mat_tuple(tuple: ExprTuple) -> Result<Vec<Mat>, Error> {
127    tuple.elems.into_iter().map(|x| Mat::from_expr(x)).collect::<Result<Vec<Mat>, Error>>()
128}
129
130#[derive(Debug)]
131struct EinsumArgs {
132    crate_expr: Expr,
133    input: Vec<Mat>,
134    output: Vec<Mat>,
135    axes: HashMap<char, Axis>,
136}
137
138impl Parse for EinsumArgs {
139    fn parse(input: ParseStream) -> parse::Result<Self> {
140        let punct: Punctuated<ExprTuple, syn::token::Comma> = Punctuated::parse_terminated(input)?;
141        let mut iter = punct.into_iter();
142        let err = Error::new(input.span(), "Not enough args");
143        let [crate_expr, input_expr, output_expr, dims_expr] = [
144            iter.next().ok_or(err.clone())?,
145            iter.next().ok_or(err.clone())?,
146            iter.next().ok_or(err.clone())?,
147            iter.next().ok_or(err)?
148        ];
149
150        Ok(Self {
151            crate_expr: crate_expr.elems.first().ok_or(Error::new(crate_expr.span(), "Couldn't get first item of crate_expr"))?.clone(),
152            input: parse_mat_tuple(input_expr)?,
153            output: parse_mat_tuple(output_expr)?,
154            axes: dims_expr.elems.iter().map(|x| {
155                let x = match x {
156                    Expr::Tuple(x) => x,
157                    _ => return expect_token_err!(x, "Expr::Tuple")
158                };
159                let axis = expr_ident_string(&x.elems[0])?;
160
161                let dim_expr = cast_expr_ref!(&x.elems[1], Lit)?;
162                let dim = match &dim_expr.lit {
163                    syn::Lit::Int(int) => int.base10_parse::<usize>()?,
164                    _ => return Err(Error::new(dim_expr.lit.span(), format!("Expected an integer, got {:?}", dim_expr.lit)))
165                };
166
167                let char = axis.chars().next().ok_or(Error::new(input.span(), "Couldn't read axis chars"))?;
168                Ok((char, Axis::new(char, dim)))
169            }).collect::<Result<HashMap<_, _>, Error>>()?,
170        })
171    }
172}
173
174
// https://optimized-einsum.readthedocs.io/en/stable/autosummary/opt_einsum.contract.ContractExpression.html#opt_einsum.contract.ContractExpression
176
177
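// Reference: einsum!(a.ij, b.jk => c.kj; a 1, b 2)
//              => ((a.ij, b.jk), (c.kj), ((a, 1), (b, 2)))
// NOTE(review): `EinsumArgs::parse` reads four tuples (a crate-path tuple comes first) and
// keys the sizes by axis letter, so the example above looks stale; confirm against the
// `einsum!` front end.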
180
181#[proc_macro]
182pub fn einsum_impl(stream: TokenStream) -> TokenStream {
183    println!("Start");
184    let args = parse_macro_input!(stream);
185
186    let res = handle_errors(do_einsum(&args));
187    
188    quote!{{ #res }}.into()
189}
190
191fn handle_errors(result: Result<TokenStream2, Error>) -> TokenStream2 {
192    match result {
193        Ok(res) => res,
194        Err(err) => err.to_compile_error()
195    }
196}
197
198fn do_einsum(args: &EinsumArgs) -> Result<TokenStream2, Error> {
199    // Get the optimized contraction order.
200    let opt = get_opt(&args)?;
201
202    let EinsumArgs { crate_expr, input, output: _, axes: dims, .. } = args;
203
204    let mut tokens: Vec<TokenStream2> = vec![];
205
206    struct MatInfo {
207        ident: Ident,
208        axes: String,
209        id: usize,
210    }
211
212    impl ToTokens for MatInfo {
213        fn to_tokens(&self, tokens: &mut TokenStream2) {
214            self.ident.to_tokens(tokens);
215        }
216    }
217
218    let mut idents = vec![];
219    let mut exprs = vec![];
220
221    let mut mats = vec![];
222    for (i, mat) in input.iter().enumerate() {
223        let ident = Ident::new(&format!("mat_{}", i), mat.expr.span());
224        let expr = &mat.expr;
225        idents.push(ident.clone());
226        exprs.push(expr);
227        
228        mats.push(MatInfo { ident, axes: mat.axes.clone(), id: i });
229    }
230
231    tokens.push(quote!{
232        // Do this all on one line to avoid accidentally shadowing the variables.
233    });
234
235    let mut lhs = mats.iter().map(|x| x.id).collect::<Vec<usize>>();
236    let mut out_dim = vec![];
237
238    for (mut i, mut j, contraction) in opt {
239        if i > j {
240            std::mem::swap(&mut i, &mut j);
241        }
242
243        let out = MatInfo {
244            ident: Ident::new(&format!("mat_{}", mats.len()), Span::call_site().into()),
245            axes: contraction.split("->").nth(1)
246                .ok_or(err!("Where is the second part of the contraction? Implicit contractions aren't allowed"))?
247                .to_string(),
248            id: mats.len(),
249        };
250
251        let mut dim_tuple = vec![];
252
253        println!("Contraction: {}", contraction);
254        println!("---");
255
256        for axis in out.axes.chars() {
257            let size = dims.get(&axis).expect(format!("Axis {} not found in dims", axis).as_str()).size;
258            dim_tuple.push(quote! {
259                #size,
260            });
261        }
262
263        tokens.push(quote! {
264            let mut #out = ndarray::Array::<T, _>::zeros((#(#dim_tuple)*));
265        });
266
267        mats.push(out);
268
269        // The two matrices to contract
270        let a = &mats[lhs.remove(i)];
271        let b = &mats[lhs.remove(j - 1)];
272        let out = &mats.last().unwrap();
273        let out_axes = out.axes.chars().map(|x| dims.get(&x).expect("Internal error: no dim?"));
274        out_dim = out.axes.chars().map(|x| dims.get(&x).expect("Internal error: no dim?").size).collect();
275
276        lhs.push(out.id);
277
278        // For usage inside quote!{}
279        let a_axes = a.axes.chars().map(|x| dims.get(&x));
280        let b_axes = b.axes.chars().map(|x| dims.get(&x));
281
282        let mut all_axes: Vec<char> = vec![];
283
284        for axis in (a.axes.clone() + &b.axes + &out.axes).chars() {
285            if !all_axes.contains(&axis) {
286                all_axes.push(axis);
287            }   
288        }
289
290        // Now we actually do the math.
291
292        // This is the inner body of the loop. We will build this,
293        // then wrap it in the next loop, and so on.
294        let mut body = quote! {
295            #out[(#(#out_axes),*)] += #a[(#(#a_axes),*)] * #b[(#(#b_axes),*)];
296        };
297
298        // Reverse the iterator, since we are doing C ordering (as opposed to Fortran)
299        // TODO: Support Fortran ordering
300        for axis in all_axes.iter().rev() {
301            let axis = dims.get(axis).expect("Internal error: no dim?");
302            let size = axis.size;
303
304            body = quote! {
305                for #axis in 0..#size {
306                    #body
307                }
308            }
309        }
310
311        tokens.push(body);
312    }
313
314    let out = &mats[lhs[0]];
315
316    let mut input_generics_defs = vec![];
317    let mut input_generics = vec![];
318    for i in 0..idents.len() {
319        let ident = format_ident!("I{}", i);
320        input_generics.push(ident.clone());
321        input_generics_defs.push(quote! {
322            #ident: ndarray::Dimension
323        });
324    }
325
326    let dim_len = out_dim.len();
327    let input_index_tys = input.iter().map(|x| (0..x.axes.len()).map(|_| quote!{usize}).collect::<Vec<_>>()).collect::<Vec<_>>();
328
329    let final_expr = quote! {
330        // Use a function for type inference.
331        #[inline]
332        fn __einsum_impl<T: #crate_expr::ArrayNumericType, #(#input_generics_defs),*>
333        (#(#idents: &ndarray::Array<T, #input_generics>),*) -> ndarray::Array<T, ndarray::Dim<[usize; #dim_len]>>
334        where #((#(#input_index_tys),*): ndarray::NdIndex<#input_generics>),* {
335            #(#tokens)*
336            #out
337        }
338        __einsum_impl(#(&#exprs),*)
339    };
340
341    Ok(final_expr.into())
342}
343
344fn get_opt(args: &EinsumArgs) -> Result<Vec<(usize, usize, String)>, Error> {
345    let EinsumArgs { input, output, axes: dims, .. } = args;
346    let str_input = input.iter().map(|x| x.axes.clone()).collect::<Vec<String>>().join(",");
347    let str_output = output.iter().map(|x| x.axes.clone()).collect::<Vec<String>>().join(",");
348    let opt_einsum_input = format!("{str_input}->{str_output}");
349
350    let mut dim_str = String::new();
351
352    for mat in input {
353        dim_str.push_str("(");
354        for axis in mat.axes.chars() {
355            let Some(axis) = dims.get(&axis) else {
356                return Err(Error::new(Span::call_site().into(), format!("Axis {} not found in dims", axis)));
357            };
358
359            dim_str.push_str(format!("{},", axis.size).as_str());
360        }
361        dim_str.push_str("), ");
362    }
363
364    fn pyerr(pretext: &str) -> impl Fn(std::io::Error) -> Error {
365        let pretext = pretext.to_string();
366        move |err| Error::new(Span::call_site().into(), format!("{}: {}", pretext, err))
367    }
368
369    let py = Command::new("python")
370        .stdin(Stdio::piped())
371        .stdout(Stdio::piped())
372        .stderr(Stdio::piped())
373        .spawn()
374        .map_err(pyerr("Error while trying to spawn Python process"))?;
375
376    let code = format!(r#"
377import opt_einsum as oe
378expr = oe.contract_expression("{opt_einsum_input}", {dim_str})
379print("\n".join([";".join([str(contraction[0][0]), str(contraction[0][1]), contraction[2]]) for contraction in expr.contraction_list]))
380"#);
381    println!("{}", code);
382    
383    let mut stdin = py.stdin.as_ref().ok_or(err!("Couldn't get stdin for Python process"))?;
384    stdin.write(code.as_bytes()).map_err(pyerr("Error while writing to Python process"))?;
385
386    let output = py.wait_with_output()
387        .map_err(pyerr("Couldn't wait on Python process"))?;
388
389    if !output.status.success() {
390        let code = output.status.code().unwrap_or(-1);
391        let err = String::from_utf8(output.stderr).unwrap_or("Error while reading Python process stderr".to_string());
392        let out = String::from_utf8(output.stdout).unwrap_or("Error while reading Python process stdout".to_string());
393        return Err(err!("Python process failed with non-zero exit code: {}\nstdout:\n{}\nstderr: {}", code, err, out));
394    }
395
396    let out = String::from_utf8(
397        output.stdout
398    ).map_err(|x| Error::new(
399        Span::call_site().into(),
400        format!("Error while parsing Python output as utf8: {}", x)
401    ))?;
402
403    let mut list = Vec::new();
404
405    println!("{}", out);
406
407    let int_err = |err| Error::new(Span::call_site().into(), format!("Error while parsing integer from Python opt_einsum: {}", err));
408    let not_enough_err = Error::new(Span::call_site().into(), "Not enough items in contraction list returned from Python opt_einsum");
409
410    for line in out.lines() {
411        let line = line.trim();
412
413        let mut iter = line.split(";").peekable();
414        while iter.peek().is_some() {
415            list.push((
416                iter.next().ok_or(not_enough_err.clone())?.parse().map_err(int_err)?,
417                iter.next().ok_or(not_enough_err.clone())?.parse().map_err(int_err)?,
418                iter.next().ok_or(not_enough_err.clone())?.to_string()
419            ));
420        }   
421    }
422
423    Ok(list)
424}