use crate::SizedContraction;
use hashbrown::HashSet;
/// Identifies where one operand of a pairwise contraction step comes from.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OperandNumber {
    /// Position of an input tensor in the original (unpermuted) contraction.
    Input(usize),
    /// Index of an earlier step in the pairwise path whose result is consumed here.
    IntermediateResult(usize),
}
/// The two operands read by a single pairwise contraction step.
#[derive(Clone, Debug)]
pub struct OperandNumPair {
    /// Left-hand operand of the step.
    pub lhs: OperandNumber,
    /// Right-hand operand of the step.
    pub rhs: OperandNumber,
}
/// One step of a pairwise contraction path.
///
/// Holds the two-operand contraction to perform and which operands it reads.
#[derive(Clone, Debug)]
pub struct Pair {
    /// Two-operand contraction evaluated at this step.
    pub sized_contraction: SizedContraction,
    /// Sources of this step's left and right operands.
    pub operand_nums: OperandNumPair,
}
/// Plan describing how a contraction is evaluated.
#[derive(Clone, Debug)]
pub enum ContractionOrder {
    /// Used when the contraction has exactly one operand.
    Singleton(SizedContraction),
    /// Ordered sequence of pairwise steps.
    ///
    /// A step may consume the result of an earlier step through
    /// `OperandNumber::IntermediateResult`.
    Pairs(Vec<Pair>),
}
/// Strategy used to choose the order in which operands are contracted.
///
/// This is a small fieldless enum, so it is `Copy`: callers can reuse a chosen strategy
/// after passing it by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationMethod {
    /// Contract operands in the order they appear in the contraction.
    Naive,
    /// Contract operands in reverse order.
    Reverse,
    /// Currently unsupported: requesting it panics.
    Greedy,
    /// Currently unsupported: requesting it panics.
    Optimal,
    /// Currently unsupported: requesting it panics.
    Branch,
}
/// Collects every index that appears in any of `operand_indices` or in `output_indices`.
fn get_remaining_indices(operand_indices: &[Vec<char>], output_indices: &[char]) -> HashSet<char> {
    operand_indices
        .iter()
        .flatten()
        .chain(output_indices)
        .copied()
        .collect()
}
/// Returns the union of the indices of two operands, as a set.
fn get_existing_indices(lhs_indices: &[char], rhs_indices: &[char]) -> HashSet<char> {
    lhs_indices.iter().chain(rhs_indices).copied().collect()
}
/// Returns a copy of `sized_contraction` with its operands reordered.
///
/// Operand `k` of the result is operand `tensor_order[k]` of the input. The output indices
/// are passed through unchanged.
///
/// # Panics
///
/// Panics if `tensor_order.len()` differs from the number of operands, if an entry of
/// `tensor_order` is out of range, or if `subset` returns an error.
fn generate_permuted_contraction(
    sized_contraction: &SizedContraction,
    tensor_order: &[usize],
) -> SizedContraction {
    let operands = &sized_contraction.contraction.operand_indices;
    assert_eq!(operands.len(), tensor_order.len());

    let reordered: Vec<_> = tensor_order
        .iter()
        .map(|&position| operands[position].clone())
        .collect();

    sized_contraction
        .subset(&reordered, &sized_contraction.contraction.output_indices)
        .unwrap()
}
/// Builds a two-operand sub-contraction of `orig_contraction`.
///
/// The sub-contraction reads operands with indices `lhs_operand_indices` and
/// `rhs_operand_indices` and produces `output_indices`.
///
/// # Panics
///
/// Panics if `SizedContraction::subset` returns an error.
fn generate_sized_contraction_pair(
    lhs_operand_indices: &[char],
    rhs_operand_indices: &[char],
    output_indices: &[char],
    orig_contraction: &SizedContraction,
) -> SizedContraction {
    let pair_operands = [lhs_operand_indices.to_vec(), rhs_operand_indices.to_vec()];
    let pair = orig_contraction.subset(&pair_operands, output_indices);
    pair.unwrap()
}
/// Builds the pairwise contraction path for `sized_contraction`.
///
/// The operands are contracted in the order given by `tensor_order`: `tensor_order[k]` is
/// the position (in `sized_contraction`) of the operand used `k`-th.
///
/// With a single operand the result is `ContractionOrder::Singleton`. Otherwise the path
/// is a left fold:
/// - Step 0 contracts the first two operands.
/// - Every later step `s` contracts `OperandNumber::IntermediateResult(s - 1)` (the previous
///   step's result) with the next input operand.
///
/// Only the final step produces the contraction's output indices. Each intermediate result
/// keeps just the indices that a later operand or the final output still needs. Those
/// indices are listed in first-appearance order, so the plan is deterministic.
///
/// # Panics
///
/// Panics if `tensor_order` does not have exactly one entry per operand, if the contraction
/// has no operands, or if building any sub-contraction fails.
fn generate_path(sized_contraction: &SizedContraction, tensor_order: &[usize]) -> ContractionOrder {
    let permuted_contraction = generate_permuted_contraction(sized_contraction, tensor_order);
    let num_operands = permuted_contraction.contraction.operand_indices.len();
    assert!(
        num_operands > 0,
        "cannot build a contraction path with no operands"
    );
    if num_operands == 1 {
        return ContractionOrder::Singleton(permuted_contraction);
    }

    let operand_indices = &permuted_contraction.contraction.operand_indices;
    let final_output = &permuted_contraction.contraction.output_indices;

    let mut steps = Vec::with_capacity(num_operands - 1);
    // Indices of the current step's left-hand operand. This is the first input operand
    // at step 0, then the result of the previous step.
    let mut lhs_indices = operand_indices[0].clone();
    for (idx_of_rhs, rhs_indices) in operand_indices.iter().enumerate().skip(1) {
        let step_output = if idx_of_rhs == num_operands - 1 {
            final_output.clone()
        } else {
            intermediate_output_indices(
                &lhs_indices,
                rhs_indices,
                &operand_indices[(idx_of_rhs + 1)..],
                final_output,
            )
        };
        let step_contraction = generate_sized_contraction_pair(
            &lhs_indices,
            rhs_indices,
            &step_output,
            &permuted_contraction,
        );
        let lhs = if idx_of_rhs == 1 {
            OperandNumber::Input(tensor_order[0])
        } else {
            // This is step `idx_of_rhs - 1`; it consumes the result of the step right before it.
            OperandNumber::IntermediateResult(idx_of_rhs - 2)
        };
        steps.push(Pair {
            sized_contraction: step_contraction,
            operand_nums: OperandNumPair {
                lhs,
                rhs: OperandNumber::Input(tensor_order[idx_of_rhs]),
            },
        });
        lhs_indices = step_output;
    }
    ContractionOrder::Pairs(steps)
}

/// Output indices of an intermediate pairwise step.
///
/// Returns the indices of `lhs_indices` or `rhs_indices` that still occur in
/// `later_operands` or `final_output`, each listed once. The order is first appearance
/// (lhs, then rhs), which keeps the intermediate layout deterministic; a `HashSet`
/// iteration order would not be.
fn intermediate_output_indices(
    lhs_indices: &[char],
    rhs_indices: &[char],
    later_operands: &[Vec<char>],
    final_output: &[char],
) -> Vec<char> {
    let mut not_yet_emitted = get_existing_indices(lhs_indices, rhs_indices);
    let still_needed = get_remaining_indices(later_operands, final_output);
    lhs_indices
        .iter()
        .chain(rhs_indices)
        .copied()
        // `remove` returns true only the first time an index is seen, which deduplicates.
        .filter(|c| still_needed.contains(c) && not_yet_emitted.remove(c))
        .collect()
}
/// Identity ordering: operands are contracted in the order they were given.
fn naive_order(sized_contraction: &SizedContraction) -> Vec<usize> {
    let num_operands = sized_contraction.contraction.operand_indices.len();
    let mut order = Vec::with_capacity(num_operands);
    order.extend(0..num_operands);
    order
}
/// Reverse ordering: the last operand comes first and the first operand comes last.
fn reverse_order(sized_contraction: &SizedContraction) -> Vec<usize> {
    let num_operands = sized_contraction.contraction.operand_indices.len();
    (0..num_operands).map(|k| num_operands - 1 - k).collect()
}
/// Computes a pairwise contraction path for `sized_contraction` using `strategy`.
///
/// Only `OptimizationMethod::Naive` (operands in their given order) and
/// `OptimizationMethod::Reverse` (operands in reverse order) are implemented.
///
/// # Panics
///
/// Panics with "Unsupported optimization method" for `Greedy`, `Optimal` and `Branch`.
pub fn generate_optimized_order(
    sized_contraction: &SizedContraction,
    strategy: OptimizationMethod,
) -> ContractionOrder {
    // Every variant is listed explicitly (no `_` arm), so adding a new
    // `OptimizationMethod` variant fails to compile until it is handled here.
    let tensor_order = match strategy {
        OptimizationMethod::Naive => naive_order(sized_contraction),
        OptimizationMethod::Reverse => reverse_order(sized_contraction),
        OptimizationMethod::Greedy | OptimizationMethod::Optimal | OptimizationMethod::Branch => {
            panic!("Unsupported optimization method")
        }
    };
    generate_path(sized_contraction, &tensor_order)
}