use crate::*;
/// Boxed strategy that picks the next contraction from the greedy candidate heap.
///
/// It is called with the candidate queue and with the map from every live index set to its
/// current SSA id. Returning `None` makes [`ssa_greedy_optimize`] skip to the next turn of its
/// main loop, which keeps running while the queue is non-empty.
pub type GreedyChooseFn = Box<
    dyn FnMut(&mut BinaryHeap<GreedyContractionType>, &BTreeMap<ArrayIndexType, usize>) -> Option<GreedyContractionType>,
>;
/// Heuristic cost of a candidate pairwise contraction, plus two SSA ids used as tie-breakers.
///
/// The derived ordering is lexicographic in declaration order: `cost` first, then `id1`,
/// then `id2`. This keeps the ordering of equal-cost candidates deterministic.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub struct GreedyCostType {
    /// Value returned by the heuristic cost function.
    pub cost: SizeType,
    /// First tie-breaking SSA id.
    pub id1: usize,
    /// Second tie-breaking SSA id.
    pub id2: usize,
}

// `SizeType` only needs `PartialOrd` here (presumably a float — TODO confirm), so a total
// order is asserted by hand. Comparing two incomparable costs (e.g. NaN) panics.
impl Eq for GreedyCostType {}

// The manual `Ord` simply defers to the derived `PartialOrd`, so the two always agree.
#[allow(clippy::derive_ord_xor_partial_ord)]
impl Ord for GreedyCostType {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let ordering: Option<std::cmp::Ordering> = PartialOrd::partial_cmp(self, other);
        ordering.unwrap()
    }
}
/// A candidate pairwise contraction stored in the greedy search heap.
///
/// The derived ordering compares `cost` first. Because that field is wrapped in `Reverse`,
/// a max-heap `BinaryHeap` pops the cheapest candidate first. The index-set fields only
/// matter when two costs (including their SSA-id tie-breakers) are identical.
///
/// Field order is significant: it defines the derived `Ord`.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct GreedyContractionType {
    /// Reversed cost, so that a lower cost means a higher heap priority.
    pub cost: Reverse<GreedyCostType>,
    /// Index set of the first operand.
    pub k1: ArrayIndexType,
    /// Index set of the second operand.
    pub k2: ArrayIndexType,
    /// Index set of the contraction result.
    pub k12: ArrayIndexType,
}
/// Builds the heap entry for contracting the operands whose index sets are `k1` and `k2`.
///
/// The result keeps an index when any of these holds:
/// - it is an output index;
/// - both operands carry it and it is in `dim_ref_counts[&3]`;
/// - exactly one operand carries it and it is in `dim_ref_counts[&2]`.
///
/// Every other index is contracted away. The caller fills `dim_ref_counts[&2]` and
/// `dim_ref_counts[&3]` with the indices used by at least two and at least three live
/// operands respectively.
///
/// The heuristic cost is `cost_fn(size12, footprint1, footprint2, 0, 0, 0)`. In the returned
/// entry, `k1`/`id1` belong to the operand with the larger SSA id and `k2`/`id2` to the
/// smaller. Cost ties are therefore broken by (larger id, smaller id).
///
/// NOTE(review): `ssa_greedy_optimize` records contractions as `[id of k1, id of k2]`, so the
/// emitted SSA pair is (larger, smaller). Confirm this orientation is intended.
///
/// # Panics
/// Panics if `k1` or `k2` is missing from `footprints` or `remaining`, or if
/// `dim_ref_counts` lacks the key `2` or `3`.
fn get_candidate(
    output: &ArrayIndexType,
    size_dict: &SizeDictType,
    remaining: &BTreeMap<ArrayIndexType, usize>,
    footprints: &BTreeMap<ArrayIndexType, SizeType>,
    dim_ref_counts: &BTreeMap<usize, BTreeSet<char>>,
    k1: &ArrayIndexType,
    k2: &ArrayIndexType,
    cost_fn: paths::CostFn,
) -> GreedyContractionType {
    let union = k1 | k2;
    let shared = k1 & k2;
    let exclusive = &union - &shared;
    let (min3, min2) = (&dim_ref_counts[&3], &dim_ref_counts[&2]);
    let k12: ArrayIndexType = union
        .intersection(output)
        .chain(shared.intersection(min3))
        .chain(exclusive.intersection(min2))
        .copied()
        .collect();

    let cost = cost_fn(
        helpers::compute_size_by_dict(k12.iter(), size_dict),
        footprints[k1],
        footprints[k2],
        0,
        0,
        0,
    );

    let (first_id, second_id) = (remaining[k1], remaining[k2]);
    let (k1, id1, k2, id2) = if first_id > second_id {
        (k1.clone(), first_id, k2.clone(), second_id)
    } else {
        (k2.clone(), second_id, k1.clone(), first_id)
    };
    GreedyContractionType { cost: Reverse(GreedyCostType { cost, id1, id2 }), k1, k2, k12 }
}
/// Evaluates the contraction of `k1` with every key in `k2s` and queues the result.
///
/// With `push_all`, every candidate is pushed onto `queue` in `k2s` order. Otherwise only
/// the best candidate is pushed. "Best" means the maximum under `GreedyContractionType`'s
/// ordering, which is the lowest cost because the cost is wrapped in `Reverse`. Nothing is
/// pushed when `k2s` is empty.
fn push_candidate(
    output: &ArrayIndexType,
    size_dict: &SizeDictType,
    remaining: &BTreeMap<ArrayIndexType, usize>,
    footprints: &BTreeMap<ArrayIndexType, SizeType>,
    dim_ref_counts: &BTreeMap<usize, BTreeSet<char>>,
    k1: &ArrayIndexType,
    k2s: &[ArrayIndexType],
    queue: &mut BinaryHeap<GreedyContractionType>,
    push_all: bool,
    cost_fn: paths::CostFn,
) {
    let candidates = k2s
        .iter()
        .map(|k2| get_candidate(output, size_dict, remaining, footprints, dim_ref_counts, k1, k2, cost_fn));
    if !push_all {
        if let Some(best) = candidates.max() {
            queue.push(best);
        }
        return;
    }
    for candidate in candidates {
        queue.push(candidate);
    }
}
/// Recomputes reference-count membership for each non-output index in `dims`.
///
/// For every such index, the number of live keys in `dim_to_keys` is counted; an index
/// missing from the map counts as zero. The index is then put in, or taken out of,
/// `dim_ref_counts[&2]` (count >= 2) and `dim_ref_counts[&3]` (count >= 3). Indices in
/// `output` are left untouched.
///
/// # Panics
/// Panics if `dim_ref_counts` lacks the key `2` or `3` while a non-output index is processed.
fn update_ref_counts(
    dim_to_keys: &BTreeMap<char, BTreeSet<ArrayIndexType>>,
    dim_ref_counts: &mut BTreeMap<usize, BTreeSet<char>>,
    dims: &ArrayIndexType,
    output: &ArrayIndexType,
) {
    for dim in dims.into_iter().filter(|&d| !output.contains(d)) {
        let users = match dim_to_keys.get(dim) {
            Some(keys) => keys.len(),
            None => 0,
        };
        for threshold in [2usize, 3] {
            let members = dim_ref_counts.get_mut(&threshold).unwrap();
            if users >= threshold {
                members.insert(*dim);
            } else {
                members.remove(dim);
            }
        }
    }
}
pub fn simple_chooser(
queue: &mut BinaryHeap<GreedyContractionType>,
remaining: &BTreeMap<ArrayIndexType, usize>,
) -> Option<GreedyContractionType> {
while let Some(cand) = queue.pop() {
if remaining.contains_key(&cand.k1) && remaining.contains_key(&cand.k2) {
return Some(cand);
}
}
None
}
/// Greedily builds a contraction path in SSA form.
///
/// Each input operand gets the SSA id equal to its position in `inputs`. Each contraction
/// appended to the returned path creates a new id, allocated sequentially starting at
/// `inputs.len()`. The search runs in three phases:
///
/// 1. Operands with identical index sets are merged eagerly by contracting them together.
///    The merged result keeps the same index set.
/// 2. Pairs of live operands that share a non-output index are contracted one at a time, as
///    selected by `choose_fn` from a heap ordered by `cost_fn`.
/// 3. Whatever remains is combined pairwise (outer products). At every step, the two
///    entries with the smallest output-index size are combined first.
///
/// An index shared by *every* input is treated as an output index, because it cannot be
/// contracted before the final step.
///
/// # Arguments
/// * `inputs` - index set of each operand.
/// * `output` - index set of the final result.
/// * `size_dict` - dimension size of each index.
/// * `choose_fn` - optional custom chooser. When `None`, [`simple_chooser`] is used and only
///   the best candidate of each batch is queued. A custom chooser receives every candidate.
/// * `cost_fn` - heuristic cost; defaults to `paths::util::memory_removed(false)`.
///
/// # Returns
/// The SSA path: each entry lists the SSA ids consumed by one step. An empty `inputs` yields
/// an empty path, and a single input yields `[[0]]`.
///
/// # Panics
/// Panics if two costs or sizes cannot be ordered (e.g. NaN), because the heap entries'
/// `Ord` impls unwrap `partial_cmp`.
pub fn ssa_greedy_optimize(
    inputs: &[&ArrayIndexType],
    output: &ArrayIndexType,
    size_dict: &SizeDictType,
    choose_fn: Option<&mut GreedyChooseFn>,
    cost_fn: Option<paths::CostFn>,
) -> PathType {
    if inputs.is_empty() {
        return vec![];
    }
    if inputs.len() == 1 {
        // A lone operand still gets one (unary) step so that the output indices can be produced.
        return vec![vec![0]];
    }

    // The default chooser only ever needs the best candidate of each batch, so pushing one
    // candidate keeps the heap small. A custom chooser is assumed to want access to every
    // possible contraction, so it gets all of them.
    let push_all = choose_fn.is_some();
    let mut default_chooser: GreedyChooseFn = Box::new(simple_chooser);
    let choose_fn: &mut GreedyChooseFn = if let Some(choose_fn) = choose_fn { choose_fn } else { &mut default_chooser };
    let cost_fn = cost_fn.unwrap_or(paths::util::memory_removed(false));

    // An index common to all operands cannot be contracted until the last step, so treat it
    // as an output index. This avoids pointless all-pairs candidate searches.
    let common_dims = inputs.iter().skip(1).fold(inputs[0].clone(), |acc, s| &acc & s);
    let output = output | &common_dims;

    // Phase 1: deduplicate identical index sets by contracting them together right away.
    // `remaining` maps each live index set to its current SSA id.
    let mut remaining = BTreeMap::new();
    let mut ssa_ids = inputs.len();
    let mut ssa_path = Vec::new();
    for (ssa_id, &key) in inputs.iter().enumerate() {
        let key = key.clone();
        if let Some(&existing_id) = remaining.get(&key) {
            ssa_path.push(vec![existing_id, ssa_id]);
            remaining.insert(key, ssa_ids);
            ssa_ids += 1;
        } else {
            remaining.insert(key, ssa_id);
        }
    }

    // Map every contractible (non-output) index to the live index sets that use it.
    let mut dim_to_keys: BTreeMap<char, BTreeSet<ArrayIndexType>> = BTreeMap::new();
    for key in remaining.keys() {
        for dim in key - &output {
            dim_to_keys.entry(dim).or_default().insert(key.clone());
        }
    }

    // dim_ref_counts[&2] / [&3] hold the non-output indices used by at least 2 / 3 live
    // operands. Only these thresholds matter because every contraction is binary.
    let mut dim_ref_counts = BTreeMap::from([(2, BTreeSet::new()), (3, BTreeSet::new())]);
    for (&dim, keys) in &dim_to_keys {
        if keys.len() >= 2 {
            dim_ref_counts.get_mut(&2).unwrap().insert(dim);
        }
        if keys.len() >= 3 {
            dim_ref_counts.get_mut(&3).unwrap().insert(dim);
        }
    }
    output.iter().for_each(|dim| {
        dim_ref_counts.get_mut(&2).unwrap().remove(dim);
        dim_ref_counts.get_mut(&3).unwrap().remove(dim);
    });

    // Separable part of the cost: the size of every live operand.
    let mut footprints: BTreeMap<ArrayIndexType, SizeType> =
        remaining.keys().map(|k| (k.clone(), helpers::compute_size_by_dict(k.iter(), size_dict))).collect();

    // Seed the queue with every pair of operands that share a contractible index.
    let mut queue = BinaryHeap::new();
    for dim_keys in dim_to_keys.values() {
        let mut dim_keys_list = dim_keys.iter().cloned().collect_vec();
        dim_keys_list.sort_by_key(|k| remaining[k]);
        for i in 0..dim_keys_list.len().saturating_sub(1) {
            let k1 = &dim_keys_list[i];
            let k2s_guess = &dim_keys_list[i + 1..];
            push_candidate(
                &output,
                size_dict,
                &remaining,
                &footprints,
                &dim_ref_counts,
                k1,
                k2s_guess,
                &mut queue,
                push_all,
                cost_fn,
            );
        }
    }

    // Phase 2: greedily contract pairs of operands that share indices.
    while !queue.is_empty() {
        // A chooser may return `None` after discarding stale candidates; just try again.
        let Some(con) = choose_fn(&mut queue, &remaining) else {
            continue;
        };
        let GreedyContractionType { k1, k2, k12, .. } = con;
        let ssa_id1 = remaining.remove(&k1).unwrap();
        let ssa_id2 = remaining.remove(&k2).unwrap();
        for dim in &k1 - &output {
            dim_to_keys.get_mut(&dim).unwrap().remove(&k1);
        }
        for dim in &k2 - &output {
            dim_to_keys.get_mut(&dim).unwrap().remove(&k2);
        }
        ssa_path.push(vec![ssa_id1, ssa_id2]);
        if remaining.contains_key(&k12) {
            // The result duplicates a live operand, so merge the two immediately. The SSA id
            // of the result just produced is the current value of `ssa_ids`.
            ssa_path.push(vec![remaining[&k12], ssa_ids]);
            ssa_ids += 1;
        } else {
            for dim in &k12 - &output {
                dim_to_keys.get_mut(&dim).unwrap().insert(k12.clone());
            }
        }
        remaining.insert(k12.clone(), ssa_ids);
        ssa_ids += 1;
        let updated_dims = &(&k1 | &k2) - &output;
        update_ref_counts(&dim_to_keys, &mut dim_ref_counts, &updated_dims, &output);
        footprints.insert(k12.clone(), helpers::compute_size_by_dict(k12.iter(), size_dict));

        // Queue candidates between the fresh result and every live operand sharing an index
        // with it. The key sets are borrowed, and only the surviving keys are cloned.
        let k1 = k12;
        let k2s: BTreeSet<ArrayIndexType> = (&k1 - &output)
            .iter()
            .flat_map(|dim| dim_to_keys[dim].iter())
            .filter(|&k| k != &k1)
            .cloned()
            .collect();
        if !k2s.is_empty() {
            push_candidate(
                &output,
                size_dict,
                &remaining,
                &footprints,
                &dim_ref_counts,
                &k1,
                &k2s.into_iter().collect_vec(),
                &mut queue,
                push_all,
                cost_fn,
            );
        }
    }

    // Phase 3: combine what is left pairwise, smallest output-index size first.
    #[derive(Clone, Debug, PartialEq, PartialOrd)]
    struct FinalEntry {
        size: SizeType,
        ssa_id: usize,
        key: ArrayIndexType,
    }
    impl Eq for FinalEntry {}
    // Total order asserted from the derived partial order; incomparable sizes (e.g. NaN) panic.
    #[allow(clippy::derive_ord_xor_partial_ord)]
    impl Ord for FinalEntry {
        fn cmp(&self, other: &Self) -> std::cmp::Ordering {
            self.partial_cmp(other).unwrap()
        }
    }
    // `Reverse` turns the max-heap into a min-heap on (size, ssa_id, key).
    let mut final_queue: BinaryHeap<Reverse<FinalEntry>> = remaining
        .into_iter()
        .map(|(key, ssa_id)| {
            let size = helpers::compute_size_by_dict((&key & &output).iter(), size_dict);
            Reverse(FinalEntry { size, ssa_id, key })
        })
        .collect();
    let Some(Reverse(FinalEntry { ssa_id: ssa_id1, key: k1, .. })) = final_queue.pop() else {
        return ssa_path;
    };
    let mut current_id = ssa_id1;
    let mut current_k = k1;
    while let Some(Reverse(FinalEntry { ssa_id: ssa_id2, key: k2, .. })) = final_queue.pop() {
        ssa_path.push(vec![current_id.min(ssa_id2), current_id.max(ssa_id2)]);
        let k12: ArrayIndexType = &(&current_k | &k2) & &output;
        let cost = helpers::compute_size_by_dict(k12.iter(), size_dict);
        let new_ssa_id = ssa_ids;
        ssa_ids += 1;
        // Push the product, then pop the smallest entry (a heap push-pop).
        final_queue.push(Reverse(FinalEntry { size: cost, ssa_id: new_ssa_id, key: k12 }));
        let Reverse(FinalEntry { ssa_id: new_id, key: new_k, .. }) =
            final_queue.pop().expect("final_queue cannot be empty right after a push");
        current_id = new_id;
        current_k = new_k;
    }
    ssa_path
}
/// Finds a contraction path with the greedy heuristic.
///
/// When `memory_limit` is set, the search is delegated to the `"branch-1"` branch-and-bound
/// optimizer, which receives the limit; `choose_fn` and `cost_fn` are then ignored.
/// Otherwise [`ssa_greedy_optimize`] builds an SSA path, which is converted to a linear path.
///
/// # Errors
/// Only the branch-and-bound route can fail; its error is returned unchanged.
pub fn greedy(
    inputs: &[&ArrayIndexType],
    output: &ArrayIndexType,
    size_dict: &SizeDictType,
    memory_limit: Option<SizeType>,
    choose_fn: Option<&mut GreedyChooseFn>,
    cost_fn: Option<paths::CostFn>,
) -> Result<PathType, String> {
    match memory_limit {
        Some(_) => {
            let mut branch_optimizer = paths::branch_bound::BranchBound::from("branch-1");
            branch_optimizer.optimize_path(inputs, output, size_dict, memory_limit)
        },
        None => {
            let ssa_path = ssa_greedy_optimize(inputs, output, size_dict, choose_fn, cost_fn);
            Ok(paths::util::ssa_to_linear(&ssa_path))
        },
    }
}
/// Path optimizer that runs [`greedy`] with an optional cost function and chooser.
///
/// `Greedy::default()` leaves both unset, so `greedy` falls back to its own defaults.
#[derive(Default)]
pub struct Greedy {
    /// Optional heuristic cost function forwarded to [`greedy`].
    cost_fn: Option<paths::CostFn>,
    /// Optional custom chooser forwarded to [`greedy`].
    choose_fn: Option<GreedyChooseFn>,
}

impl std::fmt::Debug for Greedy {
    // A boxed closure has no `Debug` impl, so this only reports whether a chooser is set.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_chooser = self.choose_fn.is_some();
        let mut builder = f.debug_struct("Greedy");
        builder.field("cost_fn", &self.cost_fn);
        builder.field("choose_fn", &has_chooser);
        builder.finish()
    }
}
impl PathOptimizer for Greedy {
    /// Delegates to [`greedy`] with this optimizer's stored chooser and cost function.
    fn optimize_path(
        &mut self,
        inputs: &[&ArrayIndexType],
        output: &ArrayIndexType,
        size_dict: &SizeDictType,
        memory_limit: Option<SizeType>,
    ) -> Result<PathType, String> {
        let Self { cost_fn, choose_fn } = self;
        greedy(inputs, output, size_dict, memory_limit, choose_fn.as_mut(), *cost_fn)
    }
}