use crate::optimizers::{
generate_optimized_order, ContractionOrder, OperandNumber, OptimizationMethod,
};
use crate::{ArrayLike, SizedContraction};
use ndarray::prelude::*;
use ndarray::LinalgScalar;
use std::collections::HashSet;
use std::fmt::Debug;
mod singleton_contractors;
use singleton_contractors::{
Diagonalization, DiagonalizationAndSummation, Identity, Permutation, PermutationAndSummation,
Summation,
};
mod pair_contractors;
pub use pair_contractors::TensordotGeneral;
use pair_contractors::{
BroadcastProductGeneral, HadamardProduct, HadamardProductGeneral, MatrixScalarProduct,
MatrixScalarProductGeneral, ScalarMatrixProduct, ScalarMatrixProductGeneral,
StackedTensordotGeneral, TensordotFixedPosition,
};
mod strategies;
use strategies::{PairMethod, PairSummary, SingletonMethod, SingletonSummary};
#[cfg(feature = "serde")]
use serde::Serialize;
/// An operation on a single tensor whose result is another view rather than
/// an owned array.
///
/// Implementors must also implement `Debug`.
pub trait SingletonViewer<A>: Debug {
/// Produces a dynamic-dimensional view from `tensor`.
///
/// The returned view carries the lifetime `'b` of the borrow of `tensor`,
/// so it cannot outlive that borrow.
fn view_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayViewD<'b, A>
where
'a: 'b,
A: Clone + LinalgScalar;
}
/// An operation on a single tensor that produces a newly owned array.
///
/// Implementors must also implement `Debug`.
pub trait SingletonContractor<A>: Debug {
/// Contracts `tensor` and returns the result as an owned
/// dynamic-dimensional array.
fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
where
'a: 'b,
A: Clone + LinalgScalar;
}
/// An operation that combines two tensors into one owned result.
///
/// Implementors must also implement `Debug`. Only `contract_pair` is
/// required; `contract_and_assign_pair` has a default implementation built
/// on top of it.
pub trait PairContractor<A>: Debug {
    /// Contracts `lhs` with `rhs` and returns the result as an owned
    /// dynamic-dimensional array.
    fn contract_pair<'a, 'b, 'c, 'd>(
        &self,
        lhs: &'b ArrayViewD<'a, A>,
        rhs: &'d ArrayViewD<'c, A>,
    ) -> ArrayD<A>
    where
        'a: 'b,
        'c: 'd,
        A: Clone + LinalgScalar;

    /// Contracts `lhs` with `rhs` and writes the result into `out`.
    ///
    /// The default implementation computes the full result with
    /// `contract_pair` and then assigns it into `out`.
    fn contract_and_assign_pair<'a, 'b, 'c, 'd, 'e, 'f>(
        &self,
        lhs: &'b ArrayViewD<'a, A>,
        rhs: &'d ArrayViewD<'c, A>,
        out: &'f mut ArrayViewMutD<'e, A>,
    ) where
        'a: 'b,
        'c: 'd,
        'e: 'f,
        A: Clone + LinalgScalar,
    {
        out.assign(&self.contract_pair(lhs, rhs));
    }
}
/// A planned contraction of a single operand: the chosen strategy together
/// with the boxed contractor that carries it out.
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct SingletonContraction<A> {
// Strategy tag selected for this contraction.
method: SingletonMethod,
// Boxed implementation matching `method`; skipped when serializing.
#[cfg_attr(feature = "serde", serde(skip))]
op: Box<dyn SingletonContractor<A>>,
}
impl<A> SingletonContraction<A> {
pub fn new(sc: &SizedContraction) -> Self {
let singleton_summary = SingletonSummary::new(&sc);
let method = singleton_summary.get_strategy();
SingletonContraction {
method,
op: match method {
SingletonMethod::Identity => Box::new(Identity::new(sc)),
SingletonMethod::Permutation => Box::new(Permutation::new(sc)),
SingletonMethod::Summation => Box::new(Summation::new(sc)),
SingletonMethod::Diagonalization => Box::new(Diagonalization::new(sc)),
SingletonMethod::PermutationAndSummation => {
Box::new(PermutationAndSummation::new(sc))
}
SingletonMethod::DiagonalizationAndSummation => {
Box::new(DiagonalizationAndSummation::new(sc))
}
},
}
}
}
impl<A> SingletonContractor<A> for SingletonContraction<A> {
    /// Forwards the contraction of `tensor` to the boxed contractor chosen
    /// when this value was constructed.
    fn contract_singleton<'a, 'b>(&self, tensor: &'b ArrayViewD<'a, A>) -> ArrayD<A>
    where
        'a: 'b,
        A: Clone + LinalgScalar,
    {
        let delegate = &*self.op;
        delegate.contract_singleton(tensor)
    }
}
impl<A> Debug for SingletonContraction<A> {
    /// Writes the strategy tag and the boxed contractor on a single line.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let Self { method, op } = self;
        f.write_fmt(format_args!(
            "SingletonContraction {{ method: {:?}, op: {:?} }}",
            method, op
        ))
    }
}
/// A single-operand pre-processing step applied to one side of a pair
/// contraction before the pair operation itself runs.
#[cfg_attr(feature = "serde", derive(Serialize))]
struct SimplificationMethodAndOutput<A> {
// Strategy tag of the singleton contraction used for the simplification.
method: SingletonMethod,
// Boxed contractor performing the simplification; skipped when serializing.
#[cfg_attr(feature = "serde", serde(skip))]
op: Box<dyn SingletonContractor<A>>,
// Indices of the operand after simplification (the ones also present in the
// other operand or in the output).
new_indices: Vec<char>,
// Einsum string of the simplification sub-contraction.
einsum_string: String,
}
impl<A> SimplificationMethodAndOutput<A> {
    /// Works out whether one operand of a pair contraction should be
    /// simplified on its own before the pair operation runs.
    ///
    /// * `this_input_indices` - indices of the operand being considered.
    /// * `other_input_indices` - indices of the other operand of the pair.
    /// * `output_indices` - indices of the pair contraction's output.
    /// * `orig_contraction` - contraction from which the single-operand
    ///   sub-contraction is derived via `subset`.
    ///
    /// The operand keeps exactly those of its indices that also occur in the
    /// other operand or in the output, each listed once.
    ///
    /// Returns `None` when the resulting singleton strategy is `Identity` or
    /// `Permutation`; otherwise returns the planned simplification.
    ///
    /// # Panics
    /// Panics if `orig_contraction.subset` fails for the derived indices.
    fn from_indices_and_sizes(
        this_input_indices: &[char],
        other_input_indices: &[char],
        output_indices: &[char],
        orig_contraction: &SizedContraction,
    ) -> Option<Self> {
        // Every index that is still needed after this operand is simplified.
        let needed_elsewhere: HashSet<char> = other_input_indices
            .iter()
            .chain(output_indices.iter())
            .cloned()
            .collect();

        // Keep the surviving indices in order of first appearance in this
        // operand. Collecting them out of a `HashSet` would make their order
        // (and with it the chosen method, the einsum string and the axis
        // layout of the intermediate) differ from one run to the next,
        // because `HashSet` iteration order is unspecified.
        let mut already_kept: HashSet<char> = HashSet::new();
        let new_indices: Vec<char> = this_input_indices
            .iter()
            .cloned()
            // `insert` returns false for repeats, so each index is kept once.
            .filter(|idx| needed_elsewhere.contains(idx) && already_kept.insert(*idx))
            .collect();

        let simplification_sc = orig_contraction
            .subset(&[this_input_indices.to_vec()], &new_indices)
            .expect("failed to derive the single-operand simplification sub-contraction");
        let SingletonContraction { method, op } = SingletonContraction::new(&simplification_sc);

        match method {
            // Nothing is summed or diagonalized, so no pre-processing step is
            // recorded.
            SingletonMethod::Identity | SingletonMethod::Permutation => None,
            _ => Some(SimplificationMethodAndOutput {
                method,
                op,
                new_indices,
                einsum_string: simplification_sc.as_einsum_string(),
            }),
        }
    }
}
impl<A> Debug for SimplificationMethodAndOutput<A> {
    /// Writes all four fields on a single line, labelled with this type's
    /// own name so the output is not mistaken for a `SingletonContraction`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(
            f,
            "SimplificationMethodAndOutput {{ method: {:?}, op: {:?}, new_indices: {:?}, einsum_string: {:?} }}",
            self.method, self.op, self.new_indices, self.einsum_string
        )
    }
}
/// A planned contraction of two operands: optional per-operand
/// simplifications followed by a pair operation.
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct PairContraction<A> {
// Optional pre-processing applied to the left operand before `op`.
lhs_simplification: Option<SimplificationMethodAndOutput<A>>,
// Optional pre-processing applied to the right operand before `op`.
rhs_simplification: Option<SimplificationMethodAndOutput<A>>,
// Strategy tag selected for the pair operation.
method: PairMethod,
// Boxed implementation matching `method`; skipped when serializing.
#[cfg_attr(feature = "serde", serde(skip))]
op: Box<dyn PairContractor<A>>,
// Einsum string of the contraction that `op` performs after any
// simplifications.
simplified_einsum_string: String,
}
impl<A> PairContraction<A> {
    /// Plans the contraction of the two operands described by `sc`.
    ///
    /// Each operand is first checked for a worthwhile single-operand
    /// simplification. The pair strategy is then chosen for the contraction
    /// that remains once those simplifications have been applied.
    ///
    /// # Panics
    /// Panics if `sc` does not describe exactly two operands, or if the
    /// reduced sub-contraction cannot be derived from `sc`.
    pub fn new(sc: &SizedContraction) -> Self {
        let contraction = &sc.contraction;
        assert_eq!(contraction.operand_indices.len(), 2);

        let lhs_indices = &contraction.operand_indices[0];
        let rhs_indices = &contraction.operand_indices[1];
        let output_indices = &contraction.output_indices;

        let lhs_simplification = SimplificationMethodAndOutput::from_indices_and_sizes(
            lhs_indices,
            rhs_indices,
            output_indices,
            sc,
        );
        let rhs_simplification = SimplificationMethodAndOutput::from_indices_and_sizes(
            rhs_indices,
            lhs_indices,
            output_indices,
            sc,
        );

        // An operand that is not simplified keeps its original indices.
        let new_lhs_indices = lhs_simplification
            .as_ref()
            .map_or_else(|| lhs_indices.clone(), |s| s.new_indices.clone());
        let new_rhs_indices = rhs_simplification
            .as_ref()
            .map_or_else(|| rhs_indices.clone(), |s| s.new_indices.clone());

        let reduced_sc = sc
            .subset(&[new_lhs_indices, new_rhs_indices], output_indices)
            .unwrap();
        let method = PairSummary::new(&reduced_sc).get_strategy();

        // One concrete contractor type per strategy variant.
        let op: Box<dyn PairContractor<A>> = match method {
            PairMethod::HadamardProduct => Box::new(HadamardProduct::new(&reduced_sc)),
            PairMethod::HadamardProductGeneral => {
                Box::new(HadamardProductGeneral::new(&reduced_sc))
            }
            PairMethod::ScalarMatrixProduct => Box::new(ScalarMatrixProduct::new(&reduced_sc)),
            PairMethod::ScalarMatrixProductGeneral => {
                Box::new(ScalarMatrixProductGeneral::new(&reduced_sc))
            }
            PairMethod::MatrixScalarProduct => Box::new(MatrixScalarProduct::new(&reduced_sc)),
            PairMethod::MatrixScalarProductGeneral => {
                Box::new(MatrixScalarProductGeneral::new(&reduced_sc))
            }
            PairMethod::TensordotFixedPosition => {
                Box::new(TensordotFixedPosition::new(&reduced_sc))
            }
            PairMethod::TensordotGeneral => Box::new(TensordotGeneral::new(&reduced_sc)),
            PairMethod::StackedTensordotGeneral => {
                Box::new(StackedTensordotGeneral::new(&reduced_sc))
            }
            PairMethod::BroadcastProductGeneral => {
                Box::new(BroadcastProductGeneral::new(&reduced_sc))
            }
        };

        let simplified_einsum_string = reduced_sc.as_einsum_string();
        PairContraction {
            lhs_simplification,
            rhs_simplification,
            method,
            op,
            simplified_einsum_string,
        }
    }
}
impl<A> PairContractor<A> for PairContraction<A> {
    /// Runs any planned per-operand simplification and then the pair
    /// operation on the (possibly simplified) operands.
    fn contract_pair<'a, 'b, 'c, 'd>(
        &self,
        lhs: &'b ArrayViewD<'a, A>,
        rhs: &'d ArrayViewD<'c, A>,
    ) -> ArrayD<A>
    where
        'a: 'b,
        'c: 'd,
        A: Clone + LinalgScalar,
    {
        // Owned results of the optional pre-processing steps (left first).
        let lhs_owned = self
            .lhs_simplification
            .as_ref()
            .map(|s| s.op.contract_singleton(lhs));
        let rhs_owned = self
            .rhs_simplification
            .as_ref()
            .map(|s| s.op.contract_singleton(rhs));

        // Views over those results; absent when the operand is used as given.
        let lhs_view = lhs_owned.as_ref().map(|arr| arr.view());
        let rhs_view = rhs_owned.as_ref().map(|arr| arr.view());

        self.op.contract_pair(
            lhs_view.as_ref().unwrap_or(lhs),
            rhs_view.as_ref().unwrap_or(rhs),
        )
    }
}
impl<A> Debug for PairContraction<A> {
    /// Writes all fields on a single line, enclosed in a balanced pair of
    /// braces.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // The trailing ` }}` closes the brace opened after the type name;
        // without it the output is left unbalanced.
        write!(
            f,
            "PairContraction {{ \
             lhs_simplification: {:?}, \
             rhs_simplification: {:?}, \
             method: {:?}, \
             op: {:?}, \
             simplified_einsum_string: {:?} }}",
            self.lhs_simplification,
            self.rhs_simplification,
            self.method,
            self.op,
            self.simplified_einsum_string
        )
    }
}
/// The planned operations of an `EinsumPath`: either one single-operand
/// contraction or a sequence of pair contractions.
#[cfg_attr(feature = "serde", derive(Serialize))]
#[derive(Debug)]
pub enum EinsumPathSteps<A> {
/// The path consists of one contraction of a single operand.
SingletonContraction(SingletonContraction<A>),
/// The path consists of pair contractions, in the order they are stored.
PairContractions(Vec<PairContraction<A>>),
}
/// A contraction order together with the planned operations that carry it
/// out.
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct EinsumPath<A> {
/// The order in which operands and intermediate results are contracted.
pub contraction_order: ContractionOrder,
/// The planned operations built from `contraction_order`.
pub steps: EinsumPathSteps<A>,
}
impl<A> EinsumPath<A> {
    /// Plans a path for `sc` using the order produced by
    /// `generate_optimized_order` with `OptimizationMethod::Naive`.
    pub fn new(sc: &SizedContraction) -> Self {
        let naive_order = generate_optimized_order(sc, OptimizationMethod::Naive);
        EinsumPath::from_path(&naive_order)
    }

    /// Builds the planned operations for an existing contraction order.
    ///
    /// A singleton order yields one `SingletonContraction`; a pairs order
    /// yields one `PairContraction` per step, in the same sequence. The
    /// returned path stores a clone of `contraction_order`.
    pub fn from_path(contraction_order: &ContractionOrder) -> Self {
        let steps = match contraction_order {
            ContractionOrder::Singleton(sized_contraction) => {
                EinsumPathSteps::SingletonContraction(SingletonContraction::new(sized_contraction))
            }
            ContractionOrder::Pairs(order_steps) => EinsumPathSteps::PairContractions(
                order_steps
                    .iter()
                    .map(|order_step| PairContraction::new(&order_step.sized_contraction))
                    .collect(),
            ),
        };

        EinsumPath {
            contraction_order: contraction_order.clone(),
            steps,
        }
    }
}
impl<A> EinsumPath<A> {
    /// Executes the planned path on `operands` and returns the final result.
    ///
    /// For a singleton path the first operand is contracted directly. For a
    /// pairs path each step contracts two sources, each of which is either
    /// one of `operands` or the result of an earlier step; the result of the
    /// last step is returned.
    ///
    /// # Panics
    /// Panics if `steps` and `contraction_order` are of different kinds, if a
    /// pairs path contains no steps, or if an operand or intermediate-result
    /// position is out of bounds.
    pub fn contract_operands(&self, operands: &[&dyn ArrayLike<A>]) -> ArrayD<A>
    where
        A: Clone + LinalgScalar,
    {
        match (&self.steps, &self.contraction_order) {
            (EinsumPathSteps::SingletonContraction(c), ContractionOrder::Singleton(_)) => {
                c.contract_singleton(&operands[0].into_dyn_view())
            }
            (EinsumPathSteps::PairContractions(steps), ContractionOrder::Pairs(order_steps)) => {
                // Every step pushes exactly one result, so the final size is
                // known up front.
                let mut intermediate_results: Vec<ArrayD<A>> = Vec::with_capacity(steps.len());
                for (step, order_step) in steps.iter().zip(order_steps.iter()) {
                    let lhs = match order_step.operand_nums.lhs {
                        OperandNumber::Input(pos) => operands[pos].into_dyn_view(),
                        OperandNumber::IntermediateResult(pos) => intermediate_results[pos].view(),
                    };
                    let rhs = match order_step.operand_nums.rhs {
                        OperandNumber::Input(pos) => operands[pos].into_dyn_view(),
                        OperandNumber::IntermediateResult(pos) => intermediate_results[pos].view(),
                    };
                    let intermediate_result = step.contract_pair(&lhs, &rhs);
                    intermediate_results.push(intermediate_result);
                }
                // The last step's result is the result of the whole path.
                intermediate_results
                    .pop()
                    .expect("EinsumPath has no pair contraction steps to execute")
            }
            // Both fields are public, so a caller can assemble a path whose
            // planned steps do not match its contraction order.
            _ => panic!(
                "EinsumPath is inconsistent: `steps` and `contraction_order` \
                 describe different kinds of path"
            ),
        }
    }
}
impl<A> Debug for EinsumPath<A> {
    /// Writes only the planned steps, prefixed with `only_step: ` for a
    /// singleton path or `steps: ` for a pairs path.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self.steps {
            EinsumPathSteps::SingletonContraction(ref single) => {
                f.write_fmt(format_args!("only_step: {:?}", single))
            }
            EinsumPathSteps::PairContractions(ref pairs) => {
                f.write_fmt(format_args!("steps: {:?}", pairs))
            }
        }
    }
}