ndarray_einsum/
validation.rs

1// Copyright 2019 Jared Samet
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Contains functions and structs related to parsing an `einsum`-formatted string
16//!
17//! This module has the implementation of `Contraction` and `SizedContraction`. `SizedContraction`
18//! is used throughout the library to store the details of a full contraction (corresponding
19//! to a string supplied by the caller) or a mini-contraction (corresponding to a simplification of
20//! a single tensor or a pairwise contraction between two tensors) produced by the optimizer in order
21//! to perform the full contraction.
22//!
23//!
24use crate::{
25    generate_optimized_order, ArrayLike, ContractionOrder, EinsumPath, OptimizationMethod,
26};
27use hashbrown::{HashMap, HashSet};
28use lazy_static::lazy_static;
29use ndarray::prelude::*;
30use ndarray::LinalgScalar;
31use regex::Regex;
32
/// The result of running an `einsum`-formatted string through the regex.
///
/// This is the raw, unvalidated split of the input string; `Contraction::from_parse`
/// turns it into a validated `Contraction`.
#[derive(Debug)]
struct EinsumParse {
    /// One string per comma-separated operand, in the order they appear in the input.
    /// Each character of a string is the label of one axis of that operand.
    operand_indices: Vec<String>,
    /// The text following `->`, or `None` when the input string contains no `->`
    /// (the implicit form, e.g. `"ij,jk"`).
    output_indices: Option<String>,
}
39
/// A `Contraction` contains the result of parsing an `einsum`-formatted string.
///
/// ```
/// # use ndarray_einsum::*;
/// let contraction = Contraction::new("ij,jk->ik").unwrap();
/// assert_eq!(contraction.operand_indices, vec![vec!['i', 'j'], vec!['j', 'k']]);
/// assert_eq!(contraction.output_indices, vec!['i', 'k']);
/// assert_eq!(contraction.summation_indices, vec!['j']);
///
/// let contraction = Contraction::new("ij,jk").unwrap();
/// assert_eq!(contraction.operand_indices, vec![vec!['i', 'j'], vec!['j', 'k']]);
/// assert_eq!(contraction.output_indices, vec!['i', 'k']);
/// assert_eq!(contraction.summation_indices, vec!['j']);
/// ```
#[derive(Debug, Clone)]
pub struct Contraction {
    /// A vector with as many elements as input operands, where each
    /// member of the vector is a `Vec<char>` with each char representing the label for
    /// each axis of the operand.
    pub operand_indices: Vec<Vec<char>>,

    /// Specifies which axes the resulting tensor will contain
    /// (corresponding to axes in one or more of the input operands).
    pub output_indices: Vec<char>,

    /// Contains the axes that will be summed over (a.k.a contracted) by the operation,
    /// i.e. every index that appears in an operand but not in `output_indices`.
    /// Stored in sorted order.
    pub summation_indices: Vec<char>,
}
68
69impl Contraction {
70    /// Validates and creates a `Contraction` from an `einsum`-formatted string.
71    pub fn new(input_string: &str) -> Result<Self, &'static str> {
72        let p = parse_einsum_string(input_string).ok_or("Invalid string")?;
73        Contraction::from_parse(&p)
74    }
75
76    /// If output_indices has been specified in the parse (i.e. explicit case),
77    /// e.g. "ij,jk->ik", simply converts the strings to `Vec<char>`s and passes
78    /// them to Contraction::from_indices. If the output indices haven't been specified,
79    /// e.g. "ij,jk", figures out which ones aren't duplicated and hence summed over,
80    /// sorts them alphabetically, and uses those as the output indices.
81    fn from_parse(parse: &EinsumParse) -> Result<Self, &'static str> {
82        let requested_output_indices: Vec<char> = match &parse.output_indices {
83            Some(s) => s.chars().collect(),
84            _ => {
85                // Handle implicit case, e.g. nothing to the right of the arrow
86                let mut input_indices = HashMap::new();
87                for c in parse.operand_indices.iter().flat_map(|s| s.chars()) {
88                    *input_indices.entry(c).or_insert(0) += 1;
89                }
90
91                let mut unique_indices: Vec<char> = input_indices
92                    .iter()
93                    .filter(|&(_, &v)| v == 1)
94                    .map(|(&k, _)| k)
95                    .collect();
96                unique_indices.sort();
97                unique_indices
98            }
99        };
100
101        let operand_indices: Vec<Vec<char>> = parse
102            .operand_indices
103            .iter()
104            .map(|x| x.chars().collect::<Vec<char>>())
105            .collect();
106        Contraction::from_indices(&operand_indices, &requested_output_indices)
107    }
108
109    /// Validates and creates a `Contraction` from a slice of `Vec<char>`s containing
110    /// the operand indices, and a slice of `char` containing the desired output indices.
111    fn from_indices(
112        operand_indices: &[Vec<char>],
113        output_indices: &[char],
114    ) -> Result<Self, &'static str> {
115        let mut input_char_counts = HashMap::new();
116        for &c in operand_indices.iter().flat_map(|operand| operand.iter()) {
117            *input_char_counts.entry(c).or_insert(0) += 1;
118        }
119
120        let mut distinct_output_indices = HashMap::new();
121        for &c in output_indices.iter() {
122            *distinct_output_indices.entry(c).or_insert(0) += 1;
123        }
124        for (&c, &n) in distinct_output_indices.iter() {
125            // No duplicates
126            if n > 1 {
127                return Err("Requested output has duplicate index");
128            }
129
130            // Must be in inputs
131            if !input_char_counts.contains_key(&c) {
132                return Err("Requested output contains an index not found in inputs");
133            }
134        }
135
136        let mut summation_indices: Vec<char> = input_char_counts
137            .keys()
138            .filter(|&c| !distinct_output_indices.contains_key(c))
139            .cloned()
140            .collect();
141        summation_indices.sort();
142
143        let cloned_operand_indices: Vec<Vec<char>> = operand_indices.to_vec();
144
145        Ok(Contraction {
146            operand_indices: cloned_operand_indices,
147            output_indices: output_indices.to_vec(),
148            summation_indices,
149        })
150    }
151}
152
153/// Alias for `HashMap<char, usize>`. Contains the axis lengths for all indices in the contraction.
154/// Contrary to the name, does not only hold the sizes for output indices.
155pub type OutputSize = HashMap<char, usize>;
156
157/// Enables `OutputSize::from_contraction_and_shapes()`
158trait OutputSizeMethods {
159    fn from_contraction_and_shapes(
160        contraction: &Contraction,
161        operand_shapes: &[Vec<usize>],
162    ) -> Result<OutputSize, &'static str>;
163}
164impl OutputSizeMethods for OutputSize {
165    /// Build the HashMap containing the axis lengths
166    fn from_contraction_and_shapes(
167        contraction: &Contraction,
168        operand_shapes: &[Vec<usize>],
169    ) -> Result<Self, &'static str> {
170        // Check that len(operand_indices) == len(operands)
171        if contraction.operand_indices.len() != operand_shapes.len() {
172            return Err(
173                "number of operands in contraction does not match number of operands supplied",
174            );
175        }
176
177        let mut index_lengths: OutputSize = HashMap::new();
178
179        for (indices, operand_shape) in contraction.operand_indices.iter().zip(operand_shapes) {
180            // Check that len(operand_indices[i]) == len(operands[i].shape())
181            if indices.len() != operand_shape.len() {
182                return Err(
183                    "number of indices in one or more operands does not match dimensions of operand",
184                );
185            }
186
187            // Check that whenever there are multiple copies of an index,
188            // operands[i].shape()[m] == operands[j].shape()[n]
189            for (&c, &n) in indices.iter().zip(operand_shape) {
190                let existing_n = index_lengths.entry(c).or_insert(n);
191                if *existing_n != n {
192                    return Err("repeated index with different size");
193                }
194            }
195        }
196
197        Ok(index_lengths)
198    }
199}
200
/// A `SizedContraction` contains a `Contraction` as well as a `HashMap<char, usize>`
/// specifying the axis lengths for each index in the contraction.
///
/// Note that output_size is a misnomer (to be changed); it contains all the axis lengths,
/// including the ones that will be contracted (i.e. not just the ones in
/// contraction.output_indices).
#[derive(Debug, Clone)]
pub struct SizedContraction {
    /// The validated operand, output and summation indices.
    pub contraction: Contraction,
    /// Axis length for every index label used by the operands of `contraction`.
    pub output_size: OutputSize,
}
212
213impl SizedContraction {
214    /// Creates a new SizedContraction based on a subset of the operand indices and/or output
215    /// indices. Not intended for general use; used internally in the crate when compiling
216    /// a multi-tensor contraction into a set of tensor simplifications (a.k.a. singleton
217    /// contractions) and pairwise contractions.
218    ///
219    /// ```
220    /// # use ndarray_einsum::*;
221    /// # use ndarray::prelude::*;
222    /// let m1: Array3<f64> = Array::zeros((2, 2, 3));
223    /// let m2: Array2<f64> = Array::zeros((3, 4));
224    /// let sc = SizedContraction::new("iij,jk->ik", &[&m1, &m2]).unwrap();
225    /// let lhs_simplification = sc.subset(&[vec!['i','i','j']], &['i','j']).unwrap();
226    /// let diagonalized_m1 = lhs_simplification.contract_operands(&[&m1]);
227    /// assert_eq!(diagonalized_m1.shape(), &[2, 3]);
228    /// ```
229    pub fn subset(
230        &self,
231        new_operand_indices: &[Vec<char>],
232        new_output_indices: &[char],
233    ) -> Result<Self, &'static str> {
234        // Make sure all chars in new_operand_indices are in self
235        let all_operand_indices: HashSet<char> = new_operand_indices
236            .iter()
237            .flat_map(|operand| operand.iter())
238            .cloned()
239            .collect();
240        if all_operand_indices
241            .iter()
242            .any(|c| !self.output_size.contains_key(c))
243        {
244            return Err("Character found in new_operand_indices but not in self.output_size");
245        }
246
247        // Validate what they asked for and compute summation_indices
248        let new_contraction = Contraction::from_indices(new_operand_indices, new_output_indices)?;
249
250        // Clone output_size, omitting unused characters
251        let new_output_size: OutputSize = self
252            .output_size
253            .iter()
254            .filter(|&(&k, _)| all_operand_indices.contains(&k))
255            .map(|(&k, &v)| (k, v))
256            .collect();
257
258        Ok(SizedContraction {
259            contraction: new_contraction,
260            output_size: new_output_size,
261        })
262    }
263
264    fn from_contraction_and_shapes(
265        contraction: &Contraction,
266        operand_shapes: &[Vec<usize>],
267    ) -> Result<Self, &'static str> {
268        let output_size = OutputSize::from_contraction_and_shapes(contraction, operand_shapes)?;
269
270        Ok(SizedContraction {
271            contraction: contraction.clone(),
272            output_size,
273        })
274    }
275
276    /// Create a SizedContraction from an already-created Contraction and a list
277    /// of operands.
278    /// ```
279    /// # use ndarray_einsum::*;
280    /// # use ndarray::prelude::*;
281    /// let m1: Array2<f64> = Array::zeros((2, 3));
282    /// let m2: Array2<f64> = Array::zeros((3, 4));
283    /// let c = Contraction::new("ij,jk->ik").unwrap();
284    /// let sc = SizedContraction::from_contraction_and_operands(&c, &[&m1, &m2]).unwrap();
285    /// assert_eq!(sc.output_size[&'i'], 2);
286    /// assert_eq!(sc.output_size[&'k'], 4);
287    /// assert_eq!(sc.output_size[&'j'], 3);
288    /// ```
289    pub fn from_contraction_and_operands<A>(
290        contraction: &Contraction,
291        operands: &[&dyn ArrayLike<A>],
292    ) -> Result<Self, &'static str> {
293        let operand_shapes = get_operand_shapes(operands);
294
295        SizedContraction::from_contraction_and_shapes(contraction, &operand_shapes)
296    }
297
298    /// Create a SizedContraction from an `einsum`-formatted input string and a slice
299    /// of `Vec<usize>`s containing the shapes of each operand.
300    /// ```
301    /// # use ndarray_einsum::*;
302    /// # use ndarray::prelude::*;
303    /// let sc = SizedContraction::from_string_and_shapes(
304    ///     "ij,jk->ik",
305    ///     &[vec![2, 3], vec![3, 4]]
306    /// ).unwrap();
307    /// assert_eq!(sc.output_size[&'i'], 2);
308    /// assert_eq!(sc.output_size[&'k'], 4);
309    /// assert_eq!(sc.output_size[&'j'], 3);
310    /// ```
311    pub fn from_string_and_shapes(
312        input_string: &str,
313        operand_shapes: &[Vec<usize>],
314    ) -> Result<Self, &'static str> {
315        let contraction = validate(input_string)?;
316        SizedContraction::from_contraction_and_shapes(&contraction, operand_shapes)
317    }
318
319    /// Create a SizedContraction from an `einsum`-formatted input string and a list
320    /// of operands.
321    ///
322    /// ```
323    /// # use ndarray_einsum::*;
324    /// # use ndarray::prelude::*;
325    /// let m1: Array2<f64> = Array::zeros((2, 3));
326    /// let m2: Array2<f64> = Array::zeros((3, 4));
327    /// let sc = SizedContraction::new("ij,jk->ik", &[&m1, &m2]).unwrap();
328    /// assert_eq!(sc.output_size[&'i'], 2);
329    /// assert_eq!(sc.output_size[&'k'], 4);
330    /// assert_eq!(sc.output_size[&'j'], 3);
331    /// ```
332    pub fn new<A>(
333        input_string: &str,
334        operands: &[&dyn ArrayLike<A>],
335    ) -> Result<Self, &'static str> {
336        let operand_shapes = get_operand_shapes(operands);
337
338        SizedContraction::from_string_and_shapes(input_string, &operand_shapes)
339    }
340
341    /// Perform the contraction on a set of operands.
342    ///
343    /// ```
344    /// # use ndarray_einsum::*;
345    /// # use ndarray::prelude::*;
346    /// let m1: Array2<f64> = Array::zeros((2, 3));
347    /// let m2: Array2<f64> = Array::zeros((3, 4));
348    /// let out: ArrayD<f64> = Array::zeros((2, 4)).into_dyn();
349    /// let sc = validate_and_size("ij,jk->ik", &[&m1, &m2]).unwrap();
350    /// assert_eq!(sc.contract_operands(&[&m1, &m2]), out);
351    /// ```
352    pub fn contract_operands<A: Clone + LinalgScalar>(
353        &self,
354        operands: &[&dyn ArrayLike<A>],
355    ) -> ArrayD<A> {
356        let cpc = EinsumPath::new(self);
357        cpc.contract_operands(operands)
358    }
359
360    /// Show as an `einsum`-formatted string.
361    ///
362    /// ```
363    /// # use ndarray_einsum::*;
364    /// # use ndarray::prelude::*;
365    /// let m1: Array2<f64> = Array::zeros((2, 3));
366    /// let m2: Array2<f64> = Array::zeros((3, 4));
367    /// let sc = validate_and_size("ij,jk", &[&m1, &m2]).unwrap();
368    /// assert_eq!(sc.as_einsum_string(), "ij,jk->ik");
369    /// ```
370    pub fn as_einsum_string(&self) -> String {
371        assert!(!self.contraction.operand_indices.is_empty());
372        let mut s: String = self.contraction.operand_indices[0]
373            .iter()
374            .cloned()
375            .collect();
376        for op in self.contraction.operand_indices[1..].iter() {
377            s.push(',');
378            for &c in op.iter() {
379                s.push(c);
380            }
381        }
382        s.push_str("->");
383        for &c in self.contraction.output_indices.iter() {
384            s.push(c);
385        }
386        s
387    }
388}
389
390/// Runs an input string through a regex and convert it to an EinsumParse.
391fn parse_einsum_string(input_string: &str) -> Option<EinsumParse> {
392    lazy_static! {
393        // Unwhitespaced version:
394        // ^([a-z]+)((?:,[a-z]+)*)(?:->([a-z]*))?$
395        static ref RE: Regex = Regex::new(r"(?x)
396            ^
397            (?P<first_operand>[a-z]+)
398            (?P<more_operands>(?:,[a-z]+)*)
399            (?:->(?P<output>[a-z]*))?
400            $
401            ").unwrap();
402    }
403    let captures = RE.captures(input_string)?;
404    let mut operand_indices = Vec::new();
405    let output_indices = captures.name("output").map(|s| String::from(s.as_str()));
406
407    operand_indices.push(String::from(&captures["first_operand"]));
408    for s in captures["more_operands"].split(',').skip(1) {
409        operand_indices.push(String::from(s));
410    }
411
412    Some(EinsumParse {
413        operand_indices,
414        output_indices,
415    })
416}
417
/// Wrapper around [Contraction::new()](struct.Contraction.html#method.new).
///
/// Parses and validates `input_string`, returning the resulting `Contraction` or the
/// error produced by `Contraction::new`.
pub fn validate(input_string: &str) -> Result<Contraction, &'static str> {
    Contraction::new(input_string)
}
422
423/// Returns a vector holding one `Vec<usize>` for each operand.
424fn get_operand_shapes<A>(operands: &[&dyn ArrayLike<A>]) -> Vec<Vec<usize>> {
425    operands
426        .iter()
427        .map(|operand| Vec::from(operand.into_dyn_view().shape()))
428        .collect()
429}
430
/// Wrapper around [SizedContraction::new()](struct.SizedContraction.html#method.new).
///
/// Parses `input_string` and checks it against the shapes of `operands`, returning the
/// resulting `SizedContraction` or the error produced by `SizedContraction::new`.
pub fn validate_and_size<A>(
    input_string: &str,
    operands: &[&dyn ArrayLike<A>],
) -> Result<SizedContraction, &'static str> {
    SizedContraction::new(input_string, operands)
}
438
439/// Create a [SizedContraction](struct.SizedContraction.html) and then optimize the order in which pairs of inputs will be contracted.
440pub fn validate_and_optimize_order<A>(
441    input_string: &str,
442    operands: &[&dyn ArrayLike<A>],
443    optimization_strategy: OptimizationMethod,
444) -> Result<ContractionOrder, &'static str> {
445    let sc = validate_and_size(input_string, operands)?;
446    Ok(generate_optimized_order(&sc, optimization_strategy))
447}