use crate::*;
/// Search state for the exhaustive ("optimal") contraction-path optimizer.
///
/// Bundles the problem description (`output`, `size_dict`, `memory_limit`) with the best
/// candidate found so far and a memo of intermediate sizes used by the memory-limit check.
#[derive(Debug, Clone, Default)]
pub struct Optimal {
/// Index set of the final result; forwarded to the flop-counting helpers.
output: ArrayIndexType,
/// Size lookup forwarded to the flop- and size-computing helpers.
size_dict: SizeDictType,
/// Optional cap on the size of an intermediate; `None` disables the size check entirely.
memory_limit: Option<SizeType>,
/// Lowest accumulated cost recorded so far; `None` until a first candidate has been recorded.
best_flops: Option<SizeType>,
/// Path that produced `best_flops`.
/// NOTE(review): presumably in SSA id form, since it is fed to `ssa_to_linear` — confirm.
best_ssa_path: PathType,
/// Memo of `helpers::compute_size_by_dict` results, keyed by the intermediate's index set.
size_cache: BTreeMap<ArrayIndexType, SizeType>,
}
impl Optimal {
    /// Recursively tries every pairwise contraction of the remaining operands and keeps the
    /// cheapest complete path.
    ///
    /// # Parameters
    /// - `path`: contractions already chosen on the current branch.
    /// - `remaining`: ids of the operands that still have to be contracted.
    /// - `inputs`: index sets of all operands addressed by id; every new intermediate is
    ///   appended at position `inputs.len()`.
    /// - `flops`: cost accumulated along `path`.
    ///
    /// Nothing is returned; the outcome is stored in `self.best_flops` and `self.best_ssa_path`.
    fn optimal_iterate(
        &mut self,
        path: PathType,
        remaining: &[usize],
        inputs: &[&ArrayIndexType],
        flops: SizeType,
    ) {
        // One operand left: the branch is complete. Callers only recurse after the cost sieve
        // below has passed, so this branch is recorded unconditionally.
        if remaining.len() == 1 {
            self.best_flops = Some(flops);
            self.best_ssa_path = path;
            return;
        }

        // Visit every unordered pair of remaining operands exactly once.
        for (pos, &first) in remaining.iter().enumerate() {
            for &second in &remaining[pos + 1..] {
                // The helper is always called with the smaller id first.
                let (lo, hi) = (first.min(second), first.max(second));

                let (k12, flops12) =
                    paths::util::calc_k12_flops(inputs, &self.output, remaining, lo, hi, &self.size_dict);
                let new_flops = flops + flops12;

                // Cost sieve: abandon branches that cannot beat the best candidate so far.
                if matches!(self.best_flops, Some(best) if new_flops >= best) {
                    continue;
                }

                // Size sieve: only active when a memory limit was supplied.
                if let Some(limit) = self.memory_limit {
                    let size12 = *self
                        .size_cache
                        .entry(k12.clone())
                        .or_insert_with(|| helpers::compute_size_by_dict(k12.iter(), &self.size_dict));

                    if size12 > limit {
                        // The intermediate is too large: consider finishing this branch with a
                        // single step over all remaining operands instead of recursing.
                        let oversize_flops = flops
                            + paths::util::compute_oversize_flops(inputs, remaining, &self.output, &self.size_dict);
                        let improves = match self.best_flops {
                            None => true,
                            Some(best) => oversize_flops < best,
                        };
                        if improves {
                            self.best_flops = Some(oversize_flops);
                            let mut terminal_path = path.clone();
                            terminal_path.push(remaining.to_vec());
                            self.best_ssa_path = terminal_path;
                        }
                        continue;
                    }
                }

                // Accept the pair: drop both operands, add the intermediate under a fresh id.
                let mut new_remaining: Vec<usize> =
                    remaining.iter().copied().filter(|&id| id != lo && id != hi).collect();
                new_remaining.push(inputs.len());

                let mut new_inputs = Vec::with_capacity(inputs.len() + 1);
                new_inputs.extend_from_slice(inputs);
                new_inputs.push(&k12);

                let mut new_path = path.clone();
                new_path.push(vec![lo, hi]);

                self.optimal_iterate(new_path, &new_remaining, &new_inputs, new_flops);
            }
        }
    }
}
/// Runs the exhaustive search over all pairwise contraction orders and returns the best path.
///
/// # Parameters
/// - `inputs`: index sets of the operands.
/// - `output`: index set of the result.
/// - `size_dict`: size lookup forwarded to the cost helpers.
/// - `memory_limit`: optional cap on the size of an intermediate; `None` disables the check.
///
/// # Returns
/// The best path found, after conversion by `paths::util::ssa_to_linear`. As written, this
/// function never produces an `Err`.
pub fn optimal(
    inputs: &[&ArrayIndexType],
    output: &ArrayIndexType,
    size_dict: &SizeDictType,
    memory_limit: Option<SizeType>,
) -> Result<PathType, String> {
    // The search state owns copies of the problem description; everything else starts empty.
    let mut search = Optimal {
        output: output.clone(),
        size_dict: size_dict.clone(),
        memory_limit,
        ..Default::default()
    };

    // Initially every operand is still uncontracted, no step has been taken, and no cost paid.
    let all_operands: Vec<usize> = (0..inputs.len()).collect();
    search.optimal_iterate(Vec::new(), &all_operands, inputs, SizeType::zero());

    let linear_path = paths::util::ssa_to_linear(&search.best_ssa_path);
    Ok(linear_path)
}
// Exposes the exhaustive search through the `PathOptimizer` interface.
impl PathOptimizer for Optimal {
/// Computes a contraction path by delegating to the free function `optimal`.
///
/// All four arguments are forwarded unchanged and the result is returned as-is.
///
/// NOTE(review): the receiver's own fields are not read here; `optimal` builds a fresh
/// `Optimal` on every call, so any state held in `self` has no effect on the result.
fn optimize_path(
&mut self,
inputs: &[&ArrayIndexType],
output: &ArrayIndexType,
size_dict: &SizeDictType,
memory_limit: Option<SizeType>,
) -> Result<PathType, String> {
optimal(inputs, output, size_dict, memory_limit)
}
}
/// Three operands contracted to a scalar (empty output) under a memory limit of 5000;
/// checks the exact path returned by `optimal`.
#[test]
fn playground() {
    use std::collections::BTreeMap;

    let time = std::time::Instant::now();

    // Operands and the (empty) output index set.
    let operand_abd: ArrayIndexType = "abd".chars().collect();
    let operand_ac: ArrayIndexType = "ac".chars().collect();
    let operand_bdc: ArrayIndexType = "bdc".chars().collect();
    let inputs = [&operand_abd, &operand_ac, &operand_bdc];
    let output: ArrayIndexType = "".chars().collect();

    let size_dict = BTreeMap::from([('a', 1), ('b', 2), ('c', 3), ('d', 4)]);
    let limit = SizeType::from_usize(5000).unwrap();

    let path = optimal(&inputs, &output, &size_dict, Some(limit)).unwrap();
    let expected = vec![vec![0, 2], vec![0, 1]];
    assert_eq!(path, expected);

    let duration = time.elapsed();
    println!("Optimal path found in: {duration:?}");
}
/// Runs `optimal` without a memory limit on three operands sharing the large `k` index and
/// prints the result and the elapsed time. There is no assertion; the test only fails on panic.
#[test]
fn playground_issue() {
    use std::collections::BTreeMap;

    let time = std::time::Instant::now();

    // Operands and the requested output index set.
    let operand_bgk: ArrayIndexType = "bgk".chars().collect();
    let operand_bkd: ArrayIndexType = "bkd".chars().collect();
    let operand_bk: ArrayIndexType = "bk".chars().collect();
    let inputs = [&operand_bgk, &operand_bkd, &operand_bk];
    let output: ArrayIndexType = "bgd".chars().collect();

    let size_dict = BTreeMap::from([('b', 64), ('g', 8), ('k', 4096), ('d', 128)]);

    let path = optimal(&inputs, &output, &size_dict, None);
    println!("{path:?}");

    let duration = time.elapsed();
    println!("Optimal path found in: {duration:?}");
}