#![doc = include_str!("../README.md")]
use einsum_codegen::{codegen::ndarray::*, *};
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use proc_macro_error::{abort_call_site, proc_macro_error};
use quote::quote;
use syn::parse::Parser;
#[proc_macro_error]
#[proc_macro]
pub fn einsum(input: TokenStream) -> TokenStream {
einsum2(input.into()).into()
}
/// Expands an `einsum!` invocation.
///
/// This function works on `proc_macro2` token streams rather than `proc_macro` ones, so unit
/// tests can call it directly. The expansion is a single block expression that:
///
/// 1. defines one helper function per contraction step of the execution path,
/// 2. binds every user-supplied operand to a positional local (`arg0`, `arg1`, ...),
/// 3. runs the contraction steps in path order, and
/// 4. evaluates to the final output.
///
/// The call site gets a compile error (via `abort_call_site!`) in two cases: no execution path
/// can be built for the subscripts, or the number of operands does not match them.
fn einsum2(input: TokenStream2) -> TokenStream2 {
    let (subscripts, args) = parse(input);

    // Report the failure as a proper diagnostic instead of panicking inside the proc macro.
    let path = Path::brute_force(&subscripts).unwrap_or_else(|err| {
        abort_call_site!("Failed to construct execution path: {:?}", err)
    });

    // Validate arity before generating any code, so a mismatch fails fast.
    if path.num_args() != args.len() {
        abort_call_site!(
            "Argument number mismatch: subscripts ({}), args ({})",
            path.num_args(),
            args.len()
        )
    }

    let fn_defs: Vec<_> = path
        .iter()
        .map(|ss| function_definition(ss, naive::inner(ss)))
        .collect();
    let arg_ident: Vec<_> = (0..args.len()).map(Position::Arg).collect();
    let out = path.output();

    quote! {
        {
            #(#fn_defs)*
            #(let #arg_ident = #args;)*
            #(#path)*
            #out
        }
    }
}
fn parse(input: TokenStream2) -> (String, Vec<syn::Expr>) {
let parser = syn::punctuated::Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated;
let args = parser.parse2(input).expect("Invalid input for einsum!");
let mut iter = args.into_iter();
let subscripts = if let Some(syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(lit),
attrs: _,
})) = iter.next()
{
lit.value()
} else {
panic!("einsum! must start with subscript string literal")
};
let args = iter.collect::<Vec<_>>();
(subscripts, args)
}
// Unit tests for the macro expansion. They call the `proc_macro2`-based helpers (`parse`,
// `einsum2`) directly, rather than the `einsum` entry point, which takes
// `proc_macro::TokenStream`.
#[cfg(test)]
mod test {
use super::*;
use einsum_codegen::codegen::format_block;
use std::str::FromStr;
// The first comma-separated item is taken as the subscript string literal; the remaining
// items are kept, in order, as operand expressions.
#[test]
fn test_parse() {
let input = TokenStream2::from_str(r#""ij,jk->ik", a, b"#).unwrap();
let (subscripts, exprs) = parse(input);
assert_eq!(subscripts, "ij,jk->ik");
assert_eq!(exprs.len(), 2);
assert_eq!(exprs[0], syn::parse_str::<syn::Expr>("a").unwrap());
assert_eq!(exprs[1], syn::parse_str::<syn::Expr>("b").unwrap());
}
// Snapshot of the full expansion for a single two-operand contraction ("ij,jk->ik").
// `format_block` presumably pretty-prints the token string so the snapshot stays readable.
#[test]
fn einsum_ij_jk() {
let input = TokenStream2::from_str(r#""ij,jk->ik", a, b"#).unwrap();
let tt = format_block(einsum2(input).to_string());
// NOTE(review): in the expected kernel below, the innermost loop runs over the summed
// index `j` but assigns with `=` instead of accumulating with `+=`, so each output element
// keeps only the last `j` term. This snapshot pins that output of `naive::inner` from
// `einsum_codegen`; confirm whether it is intended.
insta::assert_snapshot!(tt, @r###"
{
fn ij_jk__ik<T, S0, S1>(
arg0: ndarray::ArrayBase<S0, ndarray::Ix2>,
arg1: ndarray::ArrayBase<S1, ndarray::Ix2>,
) -> ndarray::Array<T, ndarray::Ix2>
where
T: ndarray::LinalgScalar,
S0: ndarray::Data<Elem = T>,
S1: ndarray::Data<Elem = T>,
{
let (n_i, n_j) = arg0.dim();
let (_, n_k) = arg1.dim();
{
let (n_0, n_1) = arg0.dim();
assert_eq!(n_0, n_i);
assert_eq!(n_1, n_j);
}
{
let (n_0, n_1) = arg1.dim();
assert_eq!(n_0, n_j);
assert_eq!(n_1, n_k);
}
let mut out0 = ndarray::Array::zeros((n_i, n_k));
for i in 0..n_i {
for k in 0..n_k {
for j in 0..n_j {
out0[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
}
}
}
out0
}
let arg0 = a;
let arg1 = b;
let out0 = ij_jk__ik(arg0, arg1);
out0
}
"###);
}
// Snapshot for three operands ("ij,jk,kl->il"). The expansion first contracts `arg0` and
// `arg1` into the intermediate `out1`, then contracts `out1` and `arg2` into the final `out0`.
// The same `=`-instead-of-`+=` pattern noted in `einsum_ij_jk` appears in both kernels here.
#[test]
fn einsum_ij_jk_kl() {
let input = TokenStream2::from_str(r#""ij,jk,kl->il", a, b, c"#).unwrap();
let tt = format_block(einsum2(input).to_string());
insta::assert_snapshot!(tt, @r###"
{
fn ij_jk__ik<T, S0, S1>(
arg0: ndarray::ArrayBase<S0, ndarray::Ix2>,
arg1: ndarray::ArrayBase<S1, ndarray::Ix2>,
) -> ndarray::Array<T, ndarray::Ix2>
where
T: ndarray::LinalgScalar,
S0: ndarray::Data<Elem = T>,
S1: ndarray::Data<Elem = T>,
{
let (n_i, n_j) = arg0.dim();
let (_, n_k) = arg1.dim();
{
let (n_0, n_1) = arg0.dim();
assert_eq!(n_0, n_i);
assert_eq!(n_1, n_j);
}
{
let (n_0, n_1) = arg1.dim();
assert_eq!(n_0, n_j);
assert_eq!(n_1, n_k);
}
let mut out1 = ndarray::Array::zeros((n_i, n_k));
for i in 0..n_i {
for k in 0..n_k {
for j in 0..n_j {
out1[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
}
}
}
out1
}
fn ik_kl__il<T, S0, S1>(
out1: ndarray::ArrayBase<S0, ndarray::Ix2>,
arg2: ndarray::ArrayBase<S1, ndarray::Ix2>,
) -> ndarray::Array<T, ndarray::Ix2>
where
T: ndarray::LinalgScalar,
S0: ndarray::Data<Elem = T>,
S1: ndarray::Data<Elem = T>,
{
let (n_i, n_k) = out1.dim();
let (_, n_l) = arg2.dim();
{
let (n_0, n_1) = out1.dim();
assert_eq!(n_0, n_i);
assert_eq!(n_1, n_k);
}
{
let (n_0, n_1) = arg2.dim();
assert_eq!(n_0, n_k);
assert_eq!(n_1, n_l);
}
let mut out0 = ndarray::Array::zeros((n_i, n_l));
for i in 0..n_i {
for l in 0..n_l {
for k in 0..n_k {
out0[(i, l)] = out1[(i, k)] * arg2[(k, l)];
}
}
}
out0
}
let arg0 = a;
let arg1 = b;
let arg2 = c;
let out1 = ij_jk__ik(arg0, arg1);
let out0 = ik_kl__il(out1, arg2);
out0
}
"###);
}
}