einsum_codegen/codegen/ndarray/
naive.rs1#[cfg(doc)]
4use super::function_definition;
5
6use crate::Subscripts;
7
8use proc_macro2::TokenStream as TokenStream2;
9use quote::quote;
10use std::collections::HashSet;
11
12fn index_ident(i: char) -> syn::Ident {
13 quote::format_ident!("{}", i)
14}
15
16fn n_ident(i: char) -> syn::Ident {
17 quote::format_ident!("n_{}", i)
18}
19
20fn contraction_for(indices: &[char], inner: TokenStream2) -> TokenStream2 {
21 let mut tt = inner;
22 for &i in indices.iter().rev() {
23 let index = index_ident(i);
24 let n = n_ident(i);
25 tt = quote! {
26 for #index in 0..#n { #tt }
27 };
28 }
29 tt
30}
31
32fn contraction_inner(subscripts: &Subscripts) -> TokenStream2 {
33 let mut inner_args_tt = Vec::new();
34 for (argc, arg) in subscripts.inputs.iter().enumerate() {
35 let mut index = Vec::new();
36 for i in subscripts.inputs[argc].indices() {
37 index.push(index_ident(i));
38 }
39 inner_args_tt.push(quote! {
40 #arg[(#(#index),*)]
41 })
42 }
43 let mut inner_mul = None;
44 for inner in inner_args_tt {
45 match inner_mul {
46 Some(i) => inner_mul = Some(quote! { #i * #inner }),
47 None => inner_mul = Some(inner),
48 }
49 }
50
51 let output_ident = &subscripts.output;
52 let mut output_indices = Vec::new();
53 for i in &subscripts.output.indices() {
54 let index = index_ident(*i);
55 output_indices.push(index.clone());
56 }
57 quote! {
58 #output_ident[(#(#output_indices),*)] = #inner_mul;
59 }
60}
61
62pub fn contraction(subscripts: &Subscripts) -> TokenStream2 {
82 let mut indices: Vec<char> = subscripts.output.indices();
83 for i in subscripts.contraction_indices() {
84 indices.push(i);
85 }
86
87 let inner = contraction_inner(subscripts);
88 contraction_for(&indices, inner)
89}
90
91pub fn define_array_size(subscripts: &Subscripts) -> TokenStream2 {
93 let mut appeared: HashSet<char> = HashSet::new();
94 let mut tt = Vec::new();
95 for arg in subscripts.inputs.iter() {
96 let n_ident: Vec<syn::Ident> = arg
97 .indices()
98 .into_iter()
99 .map(|i| {
100 if appeared.contains(&i) {
101 quote::format_ident!("_")
102 } else {
103 appeared.insert(i);
104 n_ident(i)
105 }
106 })
107 .collect();
108 tt.push(quote! {
109 let (#(#n_ident),*) = #arg.dim();
110 });
111 }
112 quote! { #(#tt)* }
113}
114
115pub fn array_size_asserts(subscripts: &Subscripts) -> TokenStream2 {
117 let mut tt = Vec::new();
118 for arg in &subscripts.inputs {
119 let n_each: Vec<_> = (0..arg.indices().len())
121 .map(|m| quote::format_ident!("n_{}", m))
122 .collect();
123 let n: Vec<_> = arg.indices().into_iter().map(n_ident).collect();
125 tt.push(quote! {
126 let (#(#n_each),*) = #arg.dim();
127 #(assert_eq!(#n_each, #n);)*
128 });
129 }
130 quote! { #({ #tt })* }
131}
132
133fn define_output_array(subscripts: &Subscripts) -> TokenStream2 {
134 let output_ident = &subscripts.output;
136 let mut n_output = Vec::new();
137 for i in subscripts.output.indices() {
138 n_output.push(n_ident(i));
139 }
140 quote! {
141 let mut #output_ident = ndarray::Array::zeros((#(#n_output),*));
142 }
143}
144
145pub fn inner(subscripts: &Subscripts) -> TokenStream2 {
147 let array_size = define_array_size(subscripts);
148 let array_size_asserts = array_size_asserts(subscripts);
149 let output_ident = &subscripts.output;
150 let output_tt = define_output_array(subscripts);
151 let contraction_tt = contraction(subscripts);
152 quote! {
153 #array_size
154 #array_size_asserts
155 #output_tt
156 #contraction_tt
157 #output_ident
158 }
159}
160
161#[cfg(test)]
162mod test {
163 use crate::{codegen::format_block, *};
164
165 #[test]
166 fn define_array_size() {
167 let mut namespace = Namespace::init();
168 let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap();
169 let tt = format_block(super::define_array_size(&subscripts).to_string());
170 insta::assert_snapshot!(tt, @r###"
171 let (n_i, n_j) = arg0.dim();
172 let (_, n_k) = arg1.dim();
173 "###);
174 }
175
176 #[test]
177 fn contraction() {
178 let mut namespace = Namespace::init();
179 let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap();
180 let tt = format_block(super::contraction(&subscripts).to_string());
181 insta::assert_snapshot!(tt, @r###"
182 for i in 0..n_i {
183 for k in 0..n_k {
184 for j in 0..n_j {
185 out0[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
186 }
187 }
188 }
189 "###);
190 }
191
192 #[test]
193 fn inner() {
194 let mut namespace = Namespace::init();
195 let subscripts = Subscripts::from_raw_indices(&mut namespace, "ij,jk->ik").unwrap();
196 let tt = format_block(super::inner(&subscripts).to_string());
197 insta::assert_snapshot!(tt, @r###"
198 let (n_i, n_j) = arg0.dim();
199 let (_, n_k) = arg1.dim();
200 {
201 let (n_0, n_1) = arg0.dim();
202 assert_eq!(n_0, n_i);
203 assert_eq!(n_1, n_j);
204 }
205 {
206 let (n_0, n_1) = arg1.dim();
207 assert_eq!(n_0, n_j);
208 assert_eq!(n_1, n_k);
209 }
210 let mut out0 = ndarray::Array::zeros((n_i, n_k));
211 for i in 0..n_i {
212 for k in 0..n_k {
213 for j in 0..n_j {
214 out0[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
215 }
216 }
217 }
218 out0
219 "###);
220 }
221}