use crate::{
generate_optimized_order, ArrayLike, ContractionOrder, EinsumPath, OptimizationMethod,
};
use hashbrown::{HashMap, HashSet};
use lazy_static::lazy_static;
use ndarray::prelude::*;
use ndarray::LinalgScalar;
use regex::Regex;
/// Raw result of matching an einsum string against the accepted grammar.
///
/// No semantic validation has been applied yet; this only splits the string
/// into its operand and (optional) output parts.
#[derive(Debug)]
struct EinsumParse {
    /// One label string per comma-separated input operand, in input order.
    operand_indices: Vec<String>,
    /// Text after `->`, or `None` when the string had no `->` (implicit mode).
    /// `Some("")` means an explicit empty output.
    output_indices: Option<String>,
}
/// A validated einsum contraction, described purely by index labels.
///
/// Sizes are not attached here; see `SizedContraction` for that.
#[derive(Clone, Debug)]
pub struct Contraction {
    /// Index labels of each input operand, one `Vec<char>` per operand, in input order.
    pub operand_indices: Vec<Vec<char>>,
    /// Index labels of the result, in the order they were requested (or, in
    /// implicit mode, the labels that occur exactly once, sorted).
    pub output_indices: Vec<char>,
    /// Labels present in some operand but absent from the output, sorted ascending.
    pub summation_indices: Vec<char>,
}
impl Contraction {
pub fn new(input_string: &str) -> Result<Self, &'static str> {
let p = parse_einsum_string(input_string).ok_or("Invalid string")?;
Contraction::from_parse(&p)
}
fn from_parse(parse: &EinsumParse) -> Result<Self, &'static str> {
let requested_output_indices: Vec<char> = match &parse.output_indices {
Some(s) => s.chars().collect(),
_ => {
let mut input_indices = HashMap::new();
for c in parse.operand_indices.iter().flat_map(|s| s.chars()) {
*input_indices.entry(c).or_insert(0) += 1;
}
let mut unique_indices: Vec<char> = input_indices
.iter()
.filter(|&(_, &v)| v == 1)
.map(|(&k, _)| k)
.collect();
unique_indices.sort();
unique_indices
}
};
let operand_indices: Vec<Vec<char>> = parse
.operand_indices
.iter()
.map(|x| x.chars().collect::<Vec<char>>())
.collect();
Contraction::from_indices(&operand_indices, &requested_output_indices)
}
fn from_indices(
operand_indices: &[Vec<char>],
output_indices: &[char],
) -> Result<Self, &'static str> {
let mut input_char_counts = HashMap::new();
for &c in operand_indices.iter().flat_map(|operand| operand.iter()) {
*input_char_counts.entry(c).or_insert(0) += 1;
}
let mut distinct_output_indices = HashMap::new();
for &c in output_indices.iter() {
*distinct_output_indices.entry(c).or_insert(0) += 1;
}
for (&c, &n) in distinct_output_indices.iter() {
if n > 1 {
return Err("Requested output has duplicate index");
}
if !input_char_counts.contains_key(&c) {
return Err("Requested output contains an index not found in inputs");
}
}
let mut summation_indices: Vec<char> = input_char_counts
.keys()
.filter(|&c| !distinct_output_indices.contains_key(c))
.cloned()
.collect();
summation_indices.sort();
let cloned_operand_indices: Vec<Vec<char>> = operand_indices.to_vec();
Ok(Contraction {
operand_indices: cloned_operand_indices,
output_indices: output_indices.to_vec(),
summation_indices,
})
}
}
pub type OutputSize = HashMap<char, usize>;
/// Construction helpers for `OutputSize`.
trait OutputSizeMethods {
    /// Derives the length of each index label by pairing every operand's
    /// labels with that operand's shape.
    ///
    /// # Errors
    /// Returns a static message when operand counts, ranks, or the length of
    /// a repeated label are inconsistent.
    fn from_contraction_and_shapes(
        contraction: &Contraction,
        operand_shapes: &[Vec<usize>],
    ) -> Result<OutputSize, &'static str>;
}
impl OutputSizeMethods for OutputSize {
    /// Maps every label in `contraction` to its length, read from `operand_shapes`.
    ///
    /// Checks run in this order, and the first failure is returned:
    /// 1. the number of shapes equals the number of operands;
    /// 2. for each operand in turn, its label count equals its rank;
    /// 3. a label seen again (in the same or a later operand) has the same length.
    fn from_contraction_and_shapes(
        contraction: &Contraction,
        operand_shapes: &[Vec<usize>],
    ) -> Result<Self, &'static str> {
        let operands = &contraction.operand_indices;
        if operands.len() != operand_shapes.len() {
            return Err(
                "number of operands in contraction does not match number of operands supplied",
            );
        }

        let mut sizes: OutputSize = HashMap::new();
        for (labels, dims) in operands.iter().zip(operand_shapes.iter()) {
            if labels.len() != dims.len() {
                return Err(
                    "number of indices in one or more operands does not match dimensions of operand",
                );
            }
            for (&label, &dim) in labels.iter().zip(dims.iter()) {
                match sizes.get(&label) {
                    // Same label, different extent: the operands cannot be contracted.
                    Some(&known) if known != dim => {
                        return Err("repeated index with different size");
                    }
                    Some(_) => {}
                    None => {
                        sizes.insert(label, dim);
                    }
                }
            }
        }
        Ok(sizes)
    }
}
/// A `Contraction` together with the length of each of its index labels.
#[derive(Clone, Debug)]
pub struct SizedContraction {
    /// The label-level description of the contraction.
    pub contraction: Contraction,
    /// Length of every label used by `contraction`'s operands.
    pub output_size: OutputSize,
}
impl SizedContraction {
pub fn subset(
&self,
new_operand_indices: &[Vec<char>],
new_output_indices: &[char],
) -> Result<Self, &'static str> {
let all_operand_indices: HashSet<char> = new_operand_indices
.iter()
.flat_map(|operand| operand.iter())
.cloned()
.collect();
if all_operand_indices
.iter()
.any(|c| !self.output_size.contains_key(c))
{
return Err("Character found in new_operand_indices but not in self.output_size");
}
let new_contraction = Contraction::from_indices(new_operand_indices, new_output_indices)?;
let new_output_size: OutputSize = self
.output_size
.iter()
.filter(|&(&k, _)| all_operand_indices.contains(&k))
.map(|(&k, &v)| (k, v))
.collect();
Ok(SizedContraction {
contraction: new_contraction,
output_size: new_output_size,
})
}
fn from_contraction_and_shapes(
contraction: &Contraction,
operand_shapes: &[Vec<usize>],
) -> Result<Self, &'static str> {
let output_size = OutputSize::from_contraction_and_shapes(contraction, operand_shapes)?;
Ok(SizedContraction {
contraction: contraction.clone(),
output_size,
})
}
pub fn from_contraction_and_operands<A>(
contraction: &Contraction,
operands: &[&dyn ArrayLike<A>],
) -> Result<Self, &'static str> {
let operand_shapes = get_operand_shapes(operands);
SizedContraction::from_contraction_and_shapes(contraction, &operand_shapes)
}
pub fn from_string_and_shapes(
input_string: &str,
operand_shapes: &[Vec<usize>],
) -> Result<Self, &'static str> {
let contraction = validate(input_string)?;
SizedContraction::from_contraction_and_shapes(&contraction, operand_shapes)
}
pub fn new<A>(
input_string: &str,
operands: &[&dyn ArrayLike<A>],
) -> Result<Self, &'static str> {
let operand_shapes = get_operand_shapes(operands);
SizedContraction::from_string_and_shapes(input_string, &operand_shapes)
}
pub fn contract_operands<A: Clone + LinalgScalar>(
&self,
operands: &[&dyn ArrayLike<A>],
) -> ArrayD<A> {
let cpc = EinsumPath::new(self);
cpc.contract_operands(operands)
}
pub fn as_einsum_string(&self) -> String {
assert!(!self.contraction.operand_indices.is_empty());
let mut s: String = self.contraction.operand_indices[0]
.iter()
.cloned()
.collect();
for op in self.contraction.operand_indices[1..].iter() {
s.push(',');
for &c in op.iter() {
s.push(c);
}
}
s.push_str("->");
for &c in self.contraction.output_indices.iter() {
s.push(c);
}
s
}
}
/// Splits an einsum string into operand label strings and an optional output.
///
/// Accepted grammar: one or more comma-separated runs of `a`-`z`, optionally
/// followed by `->` and zero or more `a`-`z`. Returns `None` for anything else.
fn parse_einsum_string(input_string: &str) -> Option<EinsumParse> {
    lazy_static! {
        // Compiled once; `(?x)` makes the embedded whitespace insignificant.
        static ref RE: Regex = Regex::new(r"(?x)
^
(?P<first_operand>[a-z]+)
(?P<more_operands>(?:,[a-z]+)*)
(?:->(?P<output>[a-z]*))?
$
").unwrap();
    }
    let caps = RE.captures(input_string)?;

    // `more_operands` looks like ",ab,cd": splitting on ',' yields a leading
    // empty piece, which is skipped.
    let operand_indices: Vec<String> = std::iter::once(&caps["first_operand"])
        .chain(caps["more_operands"].split(',').skip(1))
        .map(String::from)
        .collect();
    let output_indices = caps.name("output").map(|m| m.as_str().to_string());

    Some(EinsumParse {
        operand_indices,
        output_indices,
    })
}
/// Parses and validates `input_string` into a `Contraction` without consulting operand shapes.
///
/// # Errors
/// Propagates any error from `Contraction::new`.
pub fn validate(input_string: &str) -> Result<Contraction, &'static str> {
    let contraction = Contraction::new(input_string)?;
    Ok(contraction)
}
/// Collects the shape of each operand, in operand order, as an owned `Vec<usize>`.
fn get_operand_shapes<A>(operands: &[&dyn ArrayLike<A>]) -> Vec<Vec<usize>> {
    let mut shapes = Vec::with_capacity(operands.len());
    for operand in operands {
        shapes.push(operand.into_dyn_view().shape().to_vec());
    }
    shapes
}
/// Parses `input_string` and sizes it against the shapes of `operands`.
///
/// # Errors
/// Propagates parse, validation, and shape-mismatch errors from `SizedContraction::new`.
pub fn validate_and_size<A>(
    input_string: &str,
    operands: &[&dyn ArrayLike<A>],
) -> Result<SizedContraction, &'static str> {
    let sized = SizedContraction::new(input_string, operands)?;
    Ok(sized)
}
/// Validates and sizes `input_string` against `operands`, then asks
/// `generate_optimized_order` for a contraction order using `optimization_strategy`.
///
/// # Errors
/// Propagates any error from `validate_and_size`; ordering itself cannot fail here.
pub fn validate_and_optimize_order<A>(
    input_string: &str,
    operands: &[&dyn ArrayLike<A>],
    optimization_strategy: OptimizationMethod,
) -> Result<ContractionOrder, &'static str> {
    validate_and_size(input_string, operands)
        .map(|sized| generate_optimized_order(&sized, optimization_strategy))
}