einsum_codegen/codegen/ndarray/
naive.rs

1//! Generate einsum function with naive loop
2
3#[cfg(doc)]
4use super::function_definition;
5
6use crate::Subscripts;
7
8use proc_macro2::TokenStream as TokenStream2;
9use quote::quote;
10use std::collections::HashSet;
11
12fn index_ident(i: char) -> syn::Ident {
13    quote::format_ident!("{}", i)
14}
15
16fn n_ident(i: char) -> syn::Ident {
17    quote::format_ident!("n_{}", i)
18}
19
20fn contraction_for(indices: &[char], inner: TokenStream2) -> TokenStream2 {
21    let mut tt = inner;
22    for &i in indices.iter().rev() {
23        let index = index_ident(i);
24        let n = n_ident(i);
25        tt = quote! {
26            for #index in 0..#n { #tt }
27        };
28    }
29    tt
30}
31
32fn contraction_inner(subscripts: &Subscripts) -> TokenStream2 {
33    let mut inner_args_tt = Vec::new();
34    for (argc, arg) in subscripts.inputs.iter().enumerate() {
35        let mut index = Vec::new();
36        for i in subscripts.inputs[argc].indices() {
37            index.push(index_ident(i));
38        }
39        inner_args_tt.push(quote! {
40            #arg[(#(#index),*)]
41        })
42    }
43    let mut inner_mul = None;
44    for inner in inner_args_tt {
45        match inner_mul {
46            Some(i) => inner_mul = Some(quote! { #i * #inner }),
47            None => inner_mul = Some(inner),
48        }
49    }
50
51    let output_ident = &subscripts.output;
52    let mut output_indices = Vec::new();
53    for i in &subscripts.output.indices() {
54        let index = index_ident(*i);
55        output_indices.push(index.clone());
56    }
57    quote! {
58        #output_ident[(#(#output_indices),*)] = #inner_mul;
59    }
60}
61
62/// Generate naive contraction loop
63///
64/// ```
65/// # use ndarray::Array2;
66/// # let arg0 = Array2::<f64>::zeros((3, 3));
67/// # let arg1 = Array2::<f64>::zeros((3, 3));
68/// # let mut out0 = Array2::<f64>::zeros((3, 3));
69/// # let n_i = 3;
70/// # let n_j = 3;
71/// # let n_k = 3;
72/// for i in 0..n_i {
73///     for k in 0..n_k {
74///         for j in 0..n_j {
75///             out0[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
76///         }
77///     }
78/// }
79/// ```
80///
81pub fn contraction(subscripts: &Subscripts) -> TokenStream2 {
82    let mut indices: Vec<char> = subscripts.output.indices();
83    for i in subscripts.contraction_indices() {
84        indices.push(i);
85    }
86
87    let inner = contraction_inner(subscripts);
88    contraction_for(&indices, inner)
89}
90
91/// Define the index size identifiers, e.g. `n_i`
92pub fn define_array_size(subscripts: &Subscripts) -> TokenStream2 {
93    let mut appeared: HashSet<char> = HashSet::new();
94    let mut tt = Vec::new();
95    for arg in subscripts.inputs.iter() {
96        let n_ident: Vec<syn::Ident> = arg
97            .indices()
98            .into_iter()
99            .map(|i| {
100                if appeared.contains(&i) {
101                    quote::format_ident!("_")
102                } else {
103                    appeared.insert(i);
104                    n_ident(i)
105                }
106            })
107            .collect();
108        tt.push(quote! {
109            let (#(#n_ident),*) = #arg.dim();
110        });
111    }
112    quote! { #(#tt)* }
113}
114
115/// Generate `assert_eq!` to check the size of user input tensors
116pub fn array_size_asserts(subscripts: &Subscripts) -> TokenStream2 {
117    let mut tt = Vec::new();
118    for arg in &subscripts.inputs {
119        // local variable, e.g. `n_2`
120        let n_each: Vec<_> = (0..arg.indices().len())
121            .map(|m| quote::format_ident!("n_{}", m))
122            .collect();
123        // size of index defined previously, e.g. `n_i`
124        let n: Vec<_> = arg.indices().into_iter().map(n_ident).collect();
125        tt.push(quote! {
126            let (#(#n_each),*) = #arg.dim();
127            #(assert_eq!(#n_each, #n);)*
128        });
129    }
130    quote! { #({ #tt })* }
131}
132
133fn define_output_array(subscripts: &Subscripts) -> TokenStream2 {
134    // Define output array
135    let output_ident = &subscripts.output;
136    let mut n_output = Vec::new();
137    for i in subscripts.output.indices() {
138        n_output.push(n_ident(i));
139    }
140    quote! {
141        let mut #output_ident = ndarray::Array::zeros((#(#n_output),*));
142    }
143}
144
145/// Actual component of einsum [function_definition]
146pub fn inner(subscripts: &Subscripts) -> TokenStream2 {
147    let array_size = define_array_size(subscripts);
148    let array_size_asserts = array_size_asserts(subscripts);
149    let output_ident = &subscripts.output;
150    let output_tt = define_output_array(subscripts);
151    let contraction_tt = contraction(subscripts);
152    quote! {
153        #array_size
154        #array_size_asserts
155        #output_tt
156        #contraction_tt
157        #output_ident
158    }
159}
160
161#[cfg(test)]
162mod test {
163    use crate::{codegen::format_block, *};
164
165    #[test]
166    fn define_array_size() {
167        let mut namespace = Namespace::init();
168        let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap();
169        let tt = format_block(super::define_array_size(&subscripts).to_string());
170        insta::assert_snapshot!(tt, @r###"
171        let (n_i, n_j) = arg0.dim();
172        let (_, n_k) = arg1.dim();
173        "###);
174    }
175
176    #[test]
177    fn contraction() {
178        let mut namespace = Namespace::init();
179        let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap();
180        let tt = format_block(super::contraction(&subscripts).to_string());
181        insta::assert_snapshot!(tt, @r###"
182        for i in 0..n_i {
183            for k in 0..n_k {
184                for j in 0..n_j {
185                    out0[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
186                }
187            }
188        }
189        "###);
190    }
191
192    #[test]
193    fn inner() {
194        let mut namespace = Namespace::init();
195        let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap();
196        let tt = format_block(super::inner(&subscripts).to_string());
197        insta::assert_snapshot!(tt, @r###"
198        let (n_i, n_j) = arg0.dim();
199        let (_, n_k) = arg1.dim();
200        {
201            let (n_0, n_1) = arg0.dim();
202            assert_eq!(n_0, n_i);
203            assert_eq!(n_1, n_j);
204        }
205        {
206            let (n_0, n_1) = arg1.dim();
207            assert_eq!(n_0, n_j);
208            assert_eq!(n_1, n_k);
209        }
210        let mut out0 = ndarray::Array::zeros((n_i, n_k));
211        for i in 0..n_i {
212            for k in 0..n_k {
213                for j in 0..n_j {
214                    out0[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
215                }
216            }
217        }
218        out0
219        "###);
220    }
221}