einsum_codegen/
subscripts.rs

1//! Einsum subscripts, e.g. `ij,jk->ik`
2use crate::{parser::*, *};
3use anyhow::Result;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote, ToTokens, TokenStreamExt};
6use std::{
7    collections::{BTreeMap, BTreeSet},
8    fmt,
9    str::FromStr,
10};
11
/// Subscript of a single tensor: its index labels together with the
/// position that identifies the tensor itself.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Subscript {
    /// Index labels, either a plain list of indices or indices surrounding
    /// an ellipsis.
    raw: RawSubscript,
    /// Identifies the tensor this subscript belongs to. `Subscripts::from_raw`
    /// uses `Position::Arg(n)` for the n-th input and a fresh identifier
    /// taken from the `Namespace` for the output.
    position: Position,
}
17
18impl Subscript {
19    pub fn raw(&self) -> &RawSubscript {
20        &self.raw
21    }
22
23    pub fn position(&self) -> &Position {
24        &self.position
25    }
26
27    pub fn indices(&self) -> Vec<char> {
28        match &self.raw {
29            RawSubscript::Indices(indices) => indices.clone(),
30            RawSubscript::Ellipsis { start, end } => {
31                start.iter().chain(end.iter()).cloned().collect()
32            }
33        }
34    }
35}
36
37impl ToTokens for Subscript {
38    fn to_tokens(&self, tokens: &mut TokenStream) {
39        ToTokens::to_tokens(&self.position, tokens)
40    }
41}
42
#[cfg_attr(doc, katexit::katexit)]
/// Einsum subscripts with tensor names, e.g. `ij,jk->ik | arg0 arg1 -> out`
///
/// Each entry pairs the index labels of one tensor with the position that
/// names that tensor.
#[derive(Clone, PartialEq, Eq)]
pub struct Subscripts {
    /// Input subscripts, e.g. `ij` and `jk` for `ij,jk->ik`
    pub inputs: Vec<Subscript>,
    /// Output subscript, e.g. `ik` for `ij,jk->ik`
    pub output: Subscript,
}
52
53// `ij,jk->ik | arg0,arg1->out0` format
54impl fmt::Debug for Subscripts {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        for (n, input) in self.inputs.iter().enumerate() {
57            write!(f, "{}", input.raw)?;
58            if n < self.inputs.len() - 1 {
59                write!(f, ",")?;
60            }
61        }
62        write!(f, "->{} | ", self.output.raw)?;
63
64        for (n, input) in self.inputs.iter().enumerate() {
65            write!(f, "{}", input.position)?;
66            if n < self.inputs.len() - 1 {
67                write!(f, ",")?;
68            }
69        }
70        write!(f, "->{}", self.output.position)?;
71        Ok(())
72    }
73}
74
75impl fmt::Display for Subscripts {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        fmt::Debug::fmt(self, f)
78    }
79}
80
81impl ToTokens for Subscripts {
82    fn to_tokens(&self, tokens: &mut TokenStream) {
83        let fn_name = format_ident!("{}", self.escaped_ident());
84        let args = &self.inputs;
85        let out = &self.output;
86        tokens.append_all(quote! {
87            let #out = #fn_name(#(#args),*);
88        });
89    }
90}
91
92impl Subscripts {
93    /// Returns $\alpha$ if this subscripts requires $O(N^\alpha)$ floating point operation
94    pub fn compute_order(&self) -> usize {
95        self.memory_order() + self.contraction_indices().len()
96    }
97
98    /// Returns $\beta$ if this subscripts requires $O(N^\beta)$ memory
99    pub fn memory_order(&self) -> usize {
100        self.output.indices().len()
101    }
102
103    /// Normalize subscripts into "explicit mode"
104    ///
105    /// [numpy.einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html)
106    /// has "explicit mode" including `->`, e.g. `ij,jk->ik` and
107    /// "implicit mode" e.g. `ij,jk`.
108    /// The output subscript is determined from input subscripts in implicit mode:
109    ///
110    /// > In implicit mode, the chosen subscripts are important since the axes
111    /// > of the output are reordered alphabetically.
112    /// > This means that `np.einsum('ij', a)` doesn’t affect a 2D array,
113    /// > while `np.einsum('ji', a)` takes its transpose.
114    /// > Additionally, `np.einsum('ij,jk', a, b)` returns a matrix multiplication,
115    /// > while, `np.einsum('ij,jh', a, b)` returns the transpose of
116    /// > the multiplication since subscript ‘h’ precedes subscript ‘i’.
117    ///
118    /// ```
119    /// use std::str::FromStr;
120    /// use einsum_codegen::{*, parser::*};
121    ///
122    /// let mut names = Namespace::init();
123    ///
124    /// // Infer output subscripts for implicit mode
125    /// let raw = RawSubscripts::from_str("ij,jk").unwrap();
126    /// let subscripts = Subscripts::from_raw(&mut names, raw);
127    /// assert_eq!(subscripts.output.raw(), &['i', 'k']);
128    ///
129    /// // Reordered alphabetically
130    /// let raw = RawSubscripts::from_str("ji").unwrap();
131    /// let subscripts = Subscripts::from_raw(&mut names, raw);
132    /// assert_eq!(subscripts.output.raw(), &['i', 'j']);
133    /// ```
134    ///
135    pub fn from_raw(names: &mut Namespace, raw: RawSubscripts) -> Self {
136        let inputs = raw
137            .inputs
138            .iter()
139            .enumerate()
140            .map(|(i, indices)| Subscript {
141                raw: indices.clone(),
142                position: Position::Arg(i),
143            })
144            .collect();
145        let position = names.new_ident();
146        if let Some(output) = raw.output {
147            return Subscripts {
148                inputs,
149                output: Subscript {
150                    raw: output,
151                    position,
152                },
153            };
154        }
155
156        let count = count_indices(&inputs);
157        let output = Subscript {
158            raw: RawSubscript::Indices(
159                count
160                    .iter()
161                    .filter_map(|(key, value)| if *value == 1 { Some(*key) } else { None })
162                    .collect(),
163            ),
164            position,
165        };
166        Subscripts { inputs, output }
167    }
168
169    pub fn from_raw_indices(names: &mut Namespace, indices: &str) -> Result<Self> {
170        let raw = RawSubscripts::from_str(indices)?;
171        Ok(Self::from_raw(names, raw))
172    }
173
174    /// Indices to be contracted
175    ///
176    /// ```
177    /// use std::str::FromStr;
178    /// use maplit::btreeset;
179    /// use einsum_codegen::*;
180    ///
181    /// let mut names = Namespace::init();
182    ///
183    /// // Matrix multiplication AB
184    /// let subscripts = Subscripts::from_raw_indices(&mut names, "ij,jk->ik").unwrap();
185    /// assert_eq!(subscripts.contraction_indices(), btreeset!{'j'});
186    ///
187    /// // Reduce all Tr(AB)
188    /// let subscripts = Subscripts::from_raw_indices(&mut names, "ij,ji->").unwrap();
189    /// assert_eq!(subscripts.contraction_indices(), btreeset!{'i', 'j'});
190    ///
191    /// // Take diagonal elements
192    /// let subscripts = Subscripts::from_raw_indices(&mut names, "ii->i").unwrap();
193    /// assert_eq!(subscripts.contraction_indices(), btreeset!{});
194    /// ```
195    pub fn contraction_indices(&self) -> BTreeSet<char> {
196        let count = count_indices(&self.inputs);
197        let mut subscripts: BTreeSet<char> = count
198            .into_iter()
199            .filter_map(|(key, value)| if value > 1 { Some(key) } else { None })
200            .collect();
201        for c in &self.output.indices() {
202            subscripts.remove(c);
203        }
204        subscripts
205    }
206
207    /// Factorize subscripts
208    ///
209    /// ```text
210    /// ij,jk,kl->il | arg0 arg1 arg2 -> out0
211    /// ```
212    ///
213    /// will be factorized with `(arg0, arg1)` into
214    ///
215    /// ```text
216    /// ij,jk->ik | arg0 arg1 -> out1
217    /// ik,kl->il | out1 arg2 -> out0
218    /// ```
219    ///
220    /// ```
221    /// use einsum_codegen::{*, parser::RawSubscript};
222    /// use std::str::FromStr;
223    /// use maplit::btreeset;
224    ///
225    /// let mut names = Namespace::init();
226    /// let base = Subscripts::from_raw_indices(&mut names, "ij,jk,kl->il").unwrap();
227    ///
228    /// let (ijjk, ikkl) = base.factorize(&mut names,
229    ///   btreeset!{ Position::Arg(0), Position::Arg(1) }
230    /// ).unwrap();
231    /// ```
232    pub fn factorize(
233        &self,
234        names: &mut Namespace,
235        inners: BTreeSet<Position>,
236    ) -> Result<(Self, Self)> {
237        let mut inner_inputs = Vec::new();
238        let mut outer_inputs = Vec::new();
239        let mut indices: BTreeMap<char, (usize /* inner */, usize /* outer */)> = BTreeMap::new();
240        for input in &self.inputs {
241            if inners.contains(&input.position) {
242                inner_inputs.push(input.clone());
243                for c in input.indices() {
244                    indices
245                        .entry(c)
246                        .and_modify(|(i, _)| *i += 1)
247                        .or_insert((1, 0));
248                }
249            } else {
250                outer_inputs.push(input.clone());
251                for c in input.indices() {
252                    indices
253                        .entry(c)
254                        .and_modify(|(_, o)| *o += 1)
255                        .or_insert((0, 1));
256                }
257            }
258        }
259        let out = Subscript {
260            raw: RawSubscript::Indices(
261                indices
262                    .into_iter()
263                    .filter_map(|(key, (i, o))| {
264                        if i == 1 || (i >= 2 && o > 0) {
265                            Some(key)
266                        } else {
267                            None
268                        }
269                    })
270                    .collect(),
271            ),
272            position: names.new_ident(),
273        };
274        outer_inputs.insert(0, out.clone());
275        Ok((
276            Subscripts {
277                inputs: inner_inputs,
278                output: out,
279            },
280            Subscripts {
281                inputs: outer_inputs,
282                output: self.output.clone(),
283            },
284        ))
285    }
286
287    /// Escaped subscript for identifier
288    ///
289    /// This is not injective, e.g. `i...,j->ij` and `i,...j->ij`
290    /// returns a same result `i____j__ij`.
291    ///
292    pub fn escaped_ident(&self) -> String {
293        use std::fmt::Write;
294        let mut out = String::new();
295        for input in &self.inputs {
296            write!(out, "{}", input.raw).unwrap();
297            write!(out, "_").unwrap();
298        }
299        write!(out, "_{}", self.output.raw).unwrap();
300        out
301    }
302}
303
304fn count_indices(inputs: &[Subscript]) -> BTreeMap<char, u32> {
305    let mut count = BTreeMap::new();
306    for input in inputs {
307        for c in input.indices() {
308            count.entry(c).and_modify(|n| *n += 1).or_insert(1);
309        }
310    }
311    count
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// `escaped_ident` for explicit, implicit, scalar-output and ellipsis
    /// subscripts.
    #[test]
    fn escaped_ident() {
        // (subscripts, expected identifier)
        let cases = [
            ("ij,jk->ik", "ij_jk__ik"),
            // implicit mode
            ("ij,jk", "ij_jk__ik"),
            // output scalar
            ("i,i", "i_i__"),
            // ellipsis
            ("ij...,jk...->ik...", "ij____jk_____ik___"),
        ];

        let mut names = Namespace::init();
        for &(indices, expected) in cases.iter() {
            let subscripts = Subscripts::from_raw_indices(&mut names, indices).unwrap();
            assert_eq!(subscripts.escaped_ident(), expected);
        }
    }
}
337}