ndarray_einsum/validation.rs
1// Copyright 2019 Jared Samet
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Contains functions and structs related to parsing and validating an `einsum`-formatted string.
//!
//! This module implements `Contraction` and `SizedContraction`. `SizedContraction`
//! is used throughout the library to store the details of either a full contraction (corresponding
//! to a string supplied by the caller) or a mini-contraction (a simplification of a single
//! tensor, or a pairwise contraction between two tensors) produced by the optimizer in order
//! to perform the full contraction.
//!
24use crate::{
25 generate_optimized_order, ArrayLike, ContractionOrder, EinsumPath, OptimizationMethod,
26};
27use hashbrown::{HashMap, HashSet};
28use lazy_static::lazy_static;
29use ndarray::prelude::*;
30use ndarray::LinalgScalar;
31use regex::Regex;
32
/// Raw output of matching an `einsum`-formatted string against the parsing regex.
///
/// No semantic validation has happened yet; that is done by `Contraction::from_parse`.
#[derive(Debug)]
struct EinsumParse {
    /// One string of index labels per input operand, in the order they appeared.
    operand_indices: Vec<String>,
    /// The labels to the right of `->`, or `None` when the string had no arrow
    /// (implicit mode). An arrow followed by nothing yields `Some("")`.
    output_indices: Option<String>,
}
39
/// A `Contraction` contains the result of parsing an `einsum`-formatted string.
///
/// ```
/// # use ndarray_einsum::*;
/// let contraction = Contraction::new("ij,jk->ik").unwrap();
/// assert_eq!(contraction.operand_indices, vec![vec!['i', 'j'], vec!['j', 'k']]);
/// assert_eq!(contraction.output_indices, vec!['i', 'k']);
/// assert_eq!(contraction.summation_indices, vec!['j']);
///
/// let contraction = Contraction::new("ij,jk").unwrap();
/// assert_eq!(contraction.operand_indices, vec![vec!['i', 'j'], vec!['j', 'k']]);
/// assert_eq!(contraction.output_indices, vec!['i', 'k']);
/// assert_eq!(contraction.summation_indices, vec!['j']);
/// ```
#[derive(Debug, Clone)]
pub struct Contraction {
    /// A vector with as many elements as input operands, where each
    /// member of the vector is a `Vec<char>` with each char representing the label for
    /// each axis of the operand.
    pub operand_indices: Vec<Vec<char>>,

    /// Specifies which axes the resulting tensor will contain
    /// (corresponding to axes in one or more of the input operands).
    ///
    /// When the einsum string has no `->` (implicit mode), these are the labels that
    /// appear exactly once across all operands, sorted alphabetically.
    pub output_indices: Vec<char>,

    /// Contains the axes that will be summed over (a.k.a. contracted) by the operation:
    /// every input label that is not in `output_indices`, sorted.
    pub summation_indices: Vec<char>,
}
68
69impl Contraction {
70 /// Validates and creates a `Contraction` from an `einsum`-formatted string.
71 pub fn new(input_string: &str) -> Result<Self, &'static str> {
72 let p = parse_einsum_string(input_string).ok_or("Invalid string")?;
73 Contraction::from_parse(&p)
74 }
75
76 /// If output_indices has been specified in the parse (i.e. explicit case),
77 /// e.g. "ij,jk->ik", simply converts the strings to `Vec<char>`s and passes
78 /// them to Contraction::from_indices. If the output indices haven't been specified,
79 /// e.g. "ij,jk", figures out which ones aren't duplicated and hence summed over,
80 /// sorts them alphabetically, and uses those as the output indices.
81 fn from_parse(parse: &EinsumParse) -> Result<Self, &'static str> {
82 let requested_output_indices: Vec<char> = match &parse.output_indices {
83 Some(s) => s.chars().collect(),
84 _ => {
85 // Handle implicit case, e.g. nothing to the right of the arrow
86 let mut input_indices = HashMap::new();
87 for c in parse.operand_indices.iter().flat_map(|s| s.chars()) {
88 *input_indices.entry(c).or_insert(0) += 1;
89 }
90
91 let mut unique_indices: Vec<char> = input_indices
92 .iter()
93 .filter(|&(_, &v)| v == 1)
94 .map(|(&k, _)| k)
95 .collect();
96 unique_indices.sort();
97 unique_indices
98 }
99 };
100
101 let operand_indices: Vec<Vec<char>> = parse
102 .operand_indices
103 .iter()
104 .map(|x| x.chars().collect::<Vec<char>>())
105 .collect();
106 Contraction::from_indices(&operand_indices, &requested_output_indices)
107 }
108
109 /// Validates and creates a `Contraction` from a slice of `Vec<char>`s containing
110 /// the operand indices, and a slice of `char` containing the desired output indices.
111 fn from_indices(
112 operand_indices: &[Vec<char>],
113 output_indices: &[char],
114 ) -> Result<Self, &'static str> {
115 let mut input_char_counts = HashMap::new();
116 for &c in operand_indices.iter().flat_map(|operand| operand.iter()) {
117 *input_char_counts.entry(c).or_insert(0) += 1;
118 }
119
120 let mut distinct_output_indices = HashMap::new();
121 for &c in output_indices.iter() {
122 *distinct_output_indices.entry(c).or_insert(0) += 1;
123 }
124 for (&c, &n) in distinct_output_indices.iter() {
125 // No duplicates
126 if n > 1 {
127 return Err("Requested output has duplicate index");
128 }
129
130 // Must be in inputs
131 if !input_char_counts.contains_key(&c) {
132 return Err("Requested output contains an index not found in inputs");
133 }
134 }
135
136 let mut summation_indices: Vec<char> = input_char_counts
137 .keys()
138 .filter(|&c| !distinct_output_indices.contains_key(c))
139 .cloned()
140 .collect();
141 summation_indices.sort();
142
143 let cloned_operand_indices: Vec<Vec<char>> = operand_indices.to_vec();
144
145 Ok(Contraction {
146 operand_indices: cloned_operand_indices,
147 output_indices: output_indices.to_vec(),
148 summation_indices,
149 })
150 }
151}
152
/// Map from every index label in a contraction to its axis length
/// (an alias for `HashMap<char, usize>`).
///
/// Despite the name, this holds the lengths of *all* indices appearing in the
/// contraction, including the summed-over ones, not only those of the output.
pub type OutputSize = HashMap<char, usize>;
156
157/// Enables `OutputSize::from_contraction_and_shapes()`
158trait OutputSizeMethods {
159 fn from_contraction_and_shapes(
160 contraction: &Contraction,
161 operand_shapes: &[Vec<usize>],
162 ) -> Result<OutputSize, &'static str>;
163}
164impl OutputSizeMethods for OutputSize {
165 /// Build the HashMap containing the axis lengths
166 fn from_contraction_and_shapes(
167 contraction: &Contraction,
168 operand_shapes: &[Vec<usize>],
169 ) -> Result<Self, &'static str> {
170 // Check that len(operand_indices) == len(operands)
171 if contraction.operand_indices.len() != operand_shapes.len() {
172 return Err(
173 "number of operands in contraction does not match number of operands supplied",
174 );
175 }
176
177 let mut index_lengths: OutputSize = HashMap::new();
178
179 for (indices, operand_shape) in contraction.operand_indices.iter().zip(operand_shapes) {
180 // Check that len(operand_indices[i]) == len(operands[i].shape())
181 if indices.len() != operand_shape.len() {
182 return Err(
183 "number of indices in one or more operands does not match dimensions of operand",
184 );
185 }
186
187 // Check that whenever there are multiple copies of an index,
188 // operands[i].shape()[m] == operands[j].shape()[n]
189 for (&c, &n) in indices.iter().zip(operand_shape) {
190 let existing_n = index_lengths.entry(c).or_insert(n);
191 if *existing_n != n {
192 return Err("repeated index with different size");
193 }
194 }
195 }
196
197 Ok(index_lengths)
198 }
199}
200
201/// A `SizedContraction` contains a `Contraction` as well as a `HashMap<char, usize>`
202/// specifying the axis lengths for each index in the contraction.
203///
204/// Note that output_size is a misnomer (to be changed); it contains all the axis lengths,
205/// including the ones that will be contracted (i.e. not just the ones in
206/// contraction.output_indices).
207#[derive(Debug, Clone)]
208pub struct SizedContraction {
209 pub contraction: Contraction,
210 pub output_size: OutputSize,
211}
212
213impl SizedContraction {
214 /// Creates a new SizedContraction based on a subset of the operand indices and/or output
215 /// indices. Not intended for general use; used internally in the crate when compiling
216 /// a multi-tensor contraction into a set of tensor simplifications (a.k.a. singleton
217 /// contractions) and pairwise contractions.
218 ///
219 /// ```
220 /// # use ndarray_einsum::*;
221 /// # use ndarray::prelude::*;
222 /// let m1: Array3<f64> = Array::zeros((2, 2, 3));
223 /// let m2: Array2<f64> = Array::zeros((3, 4));
224 /// let sc = SizedContraction::new("iij,jk->ik", &[&m1, &m2]).unwrap();
225 /// let lhs_simplification = sc.subset(&[vec!['i','i','j']], &['i','j']).unwrap();
226 /// let diagonalized_m1 = lhs_simplification.contract_operands(&[&m1]);
227 /// assert_eq!(diagonalized_m1.shape(), &[2, 3]);
228 /// ```
229 pub fn subset(
230 &self,
231 new_operand_indices: &[Vec<char>],
232 new_output_indices: &[char],
233 ) -> Result<Self, &'static str> {
234 // Make sure all chars in new_operand_indices are in self
235 let all_operand_indices: HashSet<char> = new_operand_indices
236 .iter()
237 .flat_map(|operand| operand.iter())
238 .cloned()
239 .collect();
240 if all_operand_indices
241 .iter()
242 .any(|c| !self.output_size.contains_key(c))
243 {
244 return Err("Character found in new_operand_indices but not in self.output_size");
245 }
246
247 // Validate what they asked for and compute summation_indices
248 let new_contraction = Contraction::from_indices(new_operand_indices, new_output_indices)?;
249
250 // Clone output_size, omitting unused characters
251 let new_output_size: OutputSize = self
252 .output_size
253 .iter()
254 .filter(|&(&k, _)| all_operand_indices.contains(&k))
255 .map(|(&k, &v)| (k, v))
256 .collect();
257
258 Ok(SizedContraction {
259 contraction: new_contraction,
260 output_size: new_output_size,
261 })
262 }
263
264 fn from_contraction_and_shapes(
265 contraction: &Contraction,
266 operand_shapes: &[Vec<usize>],
267 ) -> Result<Self, &'static str> {
268 let output_size = OutputSize::from_contraction_and_shapes(contraction, operand_shapes)?;
269
270 Ok(SizedContraction {
271 contraction: contraction.clone(),
272 output_size,
273 })
274 }
275
276 /// Create a SizedContraction from an already-created Contraction and a list
277 /// of operands.
278 /// ```
279 /// # use ndarray_einsum::*;
280 /// # use ndarray::prelude::*;
281 /// let m1: Array2<f64> = Array::zeros((2, 3));
282 /// let m2: Array2<f64> = Array::zeros((3, 4));
283 /// let c = Contraction::new("ij,jk->ik").unwrap();
284 /// let sc = SizedContraction::from_contraction_and_operands(&c, &[&m1, &m2]).unwrap();
285 /// assert_eq!(sc.output_size[&'i'], 2);
286 /// assert_eq!(sc.output_size[&'k'], 4);
287 /// assert_eq!(sc.output_size[&'j'], 3);
288 /// ```
289 pub fn from_contraction_and_operands<A>(
290 contraction: &Contraction,
291 operands: &[&dyn ArrayLike<A>],
292 ) -> Result<Self, &'static str> {
293 let operand_shapes = get_operand_shapes(operands);
294
295 SizedContraction::from_contraction_and_shapes(contraction, &operand_shapes)
296 }
297
298 /// Create a SizedContraction from an `einsum`-formatted input string and a slice
299 /// of `Vec<usize>`s containing the shapes of each operand.
300 /// ```
301 /// # use ndarray_einsum::*;
302 /// # use ndarray::prelude::*;
303 /// let sc = SizedContraction::from_string_and_shapes(
304 /// "ij,jk->ik",
305 /// &[vec![2, 3], vec![3, 4]]
306 /// ).unwrap();
307 /// assert_eq!(sc.output_size[&'i'], 2);
308 /// assert_eq!(sc.output_size[&'k'], 4);
309 /// assert_eq!(sc.output_size[&'j'], 3);
310 /// ```
311 pub fn from_string_and_shapes(
312 input_string: &str,
313 operand_shapes: &[Vec<usize>],
314 ) -> Result<Self, &'static str> {
315 let contraction = validate(input_string)?;
316 SizedContraction::from_contraction_and_shapes(&contraction, operand_shapes)
317 }
318
319 /// Create a SizedContraction from an `einsum`-formatted input string and a list
320 /// of operands.
321 ///
322 /// ```
323 /// # use ndarray_einsum::*;
324 /// # use ndarray::prelude::*;
325 /// let m1: Array2<f64> = Array::zeros((2, 3));
326 /// let m2: Array2<f64> = Array::zeros((3, 4));
327 /// let sc = SizedContraction::new("ij,jk->ik", &[&m1, &m2]).unwrap();
328 /// assert_eq!(sc.output_size[&'i'], 2);
329 /// assert_eq!(sc.output_size[&'k'], 4);
330 /// assert_eq!(sc.output_size[&'j'], 3);
331 /// ```
332 pub fn new<A>(
333 input_string: &str,
334 operands: &[&dyn ArrayLike<A>],
335 ) -> Result<Self, &'static str> {
336 let operand_shapes = get_operand_shapes(operands);
337
338 SizedContraction::from_string_and_shapes(input_string, &operand_shapes)
339 }
340
341 /// Perform the contraction on a set of operands.
342 ///
343 /// ```
344 /// # use ndarray_einsum::*;
345 /// # use ndarray::prelude::*;
346 /// let m1: Array2<f64> = Array::zeros((2, 3));
347 /// let m2: Array2<f64> = Array::zeros((3, 4));
348 /// let out: ArrayD<f64> = Array::zeros((2, 4)).into_dyn();
349 /// let sc = validate_and_size("ij,jk->ik", &[&m1, &m2]).unwrap();
350 /// assert_eq!(sc.contract_operands(&[&m1, &m2]), out);
351 /// ```
352 pub fn contract_operands<A: Clone + LinalgScalar>(
353 &self,
354 operands: &[&dyn ArrayLike<A>],
355 ) -> ArrayD<A> {
356 let cpc = EinsumPath::new(self);
357 cpc.contract_operands(operands)
358 }
359
360 /// Show as an `einsum`-formatted string.
361 ///
362 /// ```
363 /// # use ndarray_einsum::*;
364 /// # use ndarray::prelude::*;
365 /// let m1: Array2<f64> = Array::zeros((2, 3));
366 /// let m2: Array2<f64> = Array::zeros((3, 4));
367 /// let sc = validate_and_size("ij,jk", &[&m1, &m2]).unwrap();
368 /// assert_eq!(sc.as_einsum_string(), "ij,jk->ik");
369 /// ```
370 pub fn as_einsum_string(&self) -> String {
371 assert!(!self.contraction.operand_indices.is_empty());
372 let mut s: String = self.contraction.operand_indices[0]
373 .iter()
374 .cloned()
375 .collect();
376 for op in self.contraction.operand_indices[1..].iter() {
377 s.push(',');
378 for &c in op.iter() {
379 s.push(c);
380 }
381 }
382 s.push_str("->");
383 for &c in self.contraction.output_indices.iter() {
384 s.push(c);
385 }
386 s
387 }
388}
389
390/// Runs an input string through a regex and convert it to an EinsumParse.
391fn parse_einsum_string(input_string: &str) -> Option<EinsumParse> {
392 lazy_static! {
393 // Unwhitespaced version:
394 // ^([a-z]+)((?:,[a-z]+)*)(?:->([a-z]*))?$
395 static ref RE: Regex = Regex::new(r"(?x)
396 ^
397 (?P<first_operand>[a-z]+)
398 (?P<more_operands>(?:,[a-z]+)*)
399 (?:->(?P<output>[a-z]*))?
400 $
401 ").unwrap();
402 }
403 let captures = RE.captures(input_string)?;
404 let mut operand_indices = Vec::new();
405 let output_indices = captures.name("output").map(|s| String::from(s.as_str()));
406
407 operand_indices.push(String::from(&captures["first_operand"]));
408 for s in captures["more_operands"].split(',').skip(1) {
409 operand_indices.push(String::from(s));
410 }
411
412 Some(EinsumParse {
413 operand_indices,
414 output_indices,
415 })
416}
417
418/// Wrapper around [Contraction::new()](struct.Contraction.html#method.new).
419pub fn validate(input_string: &str) -> Result<Contraction, &'static str> {
420 Contraction::new(input_string)
421}
422
423/// Returns a vector holding one `Vec<usize>` for each operand.
424fn get_operand_shapes<A>(operands: &[&dyn ArrayLike<A>]) -> Vec<Vec<usize>> {
425 operands
426 .iter()
427 .map(|operand| Vec::from(operand.into_dyn_view().shape()))
428 .collect()
429}
430
431/// Wrapper around [SizedContraction::new()](struct.SizedContraction.html#method.new).
432pub fn validate_and_size<A>(
433 input_string: &str,
434 operands: &[&dyn ArrayLike<A>],
435) -> Result<SizedContraction, &'static str> {
436 SizedContraction::new(input_string, operands)
437}
438
439/// Create a [SizedContraction](struct.SizedContraction.html) and then optimize the order in which pairs of inputs will be contracted.
440pub fn validate_and_optimize_order<A>(
441 input_string: &str,
442 operands: &[&dyn ArrayLike<A>],
443 optimization_strategy: OptimizationMethod,
444) -> Result<ContractionOrder, &'static str> {
445 let sc = validate_and_size(input_string, operands)?;
446 Ok(generate_optimized_order(&sc, optimization_strategy))
447}