use std::collections::HashMap;
use tensorlogic_ir::{EinsumGraph, OpType};
/// Counters accumulated while optimizing einsum contraction orders.
#[derive(Debug, Clone, Default)]
pub struct ContractionOptStats {
    /// How many einsum nodes had a cheaper contraction order found.
    pub contractions_reordered: usize,
    /// Running mean of per-node FLOP reduction, in percent.
    pub flops_reduction_percent: f64,
    /// Running mean of per-node memory reduction, in percent.
    pub memory_reduction_percent: f64,
    /// How many intermediate tensors were eliminated.
    pub intermediates_saved: usize,
    /// Total number of graph nodes examined.
    pub total_processed: usize,
}

impl ContractionOptStats {
    /// Total count of individual optimizations applied: reorderings plus
    /// eliminated intermediates.
    pub fn total_optimizations(&self) -> usize {
        let Self {
            contractions_reordered,
            intermediates_saved,
            ..
        } = self;
        contractions_reordered + intermediates_saved
    }
}
/// Tuning knobs for contraction-order optimization.
#[derive(Debug, Clone)]
pub struct ContractionOptConfig {
    /// Use the exact subset-DP search when the problem is small enough.
    pub use_dynamic_programming: bool,
    /// Maximum number of input tensors the DP search will accept.
    pub max_dp_size: usize,
    /// Weight in [0, 1] splitting an estimated cost between FLOPs and memory.
    pub flops_memory_tradeoff: f64,
    /// Fall back to the greedy pairwise search when DP is not applicable.
    pub enable_greedy_fallback: bool,
}

impl Default for ContractionOptConfig {
    /// Defaults: exact DP up to 26 tensors (bitmask-friendly), 70/30
    /// FLOPs-vs-memory weighting, greedy fallback enabled.
    fn default() -> Self {
        Self {
            use_dynamic_programming: true,
            max_dp_size: 26,
            flops_memory_tradeoff: 0.7,
            enable_greedy_fallback: true,
        }
    }
}
/// A tensor shape whose dimensions may be statically unknown (`None`).
#[derive(Debug, Clone)]
pub struct TensorShape {
    /// Per-axis extents; `None` marks a dynamic/unknown dimension.
    pub dims: Vec<Option<usize>>,
}

impl TensorShape {
    /// Builds a shape from per-axis extents.
    pub fn new(dims: Vec<Option<usize>>) -> Self {
        Self { dims }
    }

    /// Total element count, or `None` if any dimension is unknown or the
    /// product overflows `usize`.
    ///
    /// The previous implementation used an unchecked `*=`, which panics in
    /// debug builds (and silently wraps in release) on overflow; `checked_mul`
    /// reports overflow as `None` instead. An empty shape is a scalar: 1.
    pub fn num_elements(&self) -> Option<usize> {
        self.dims
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim?))
    }

    /// Number of axes.
    pub fn rank(&self) -> usize {
        self.dims.len()
    }
}
/// Result of a contraction-order search.
#[derive(Debug, Clone)]
pub struct ContractionPath {
    /// Pairwise contraction steps as `(tensor id, tensor id)`; merged results
    /// receive fresh ids numbered after the original inputs.
    pub steps: Vec<(usize, usize)>,
    /// Heuristic FLOP-cost estimate for executing this path.
    pub estimated_flops: f64,
    /// Heuristic memory-cost estimate for executing this path.
    pub estimated_memory: f64,
}
/// Optimizes contraction orders in `graph` using the default configuration.
///
/// Convenience wrapper around [`optimize_contractions_with_config`].
pub fn optimize_contractions(graph: &EinsumGraph) -> (EinsumGraph, ContractionOptStats) {
    let default_config = ContractionOptConfig::default();
    optimize_contractions_with_config(graph, &default_config)
}
/// Scans `graph` for einsum nodes whose contraction order could be improved
/// and reports the findings in [`ContractionOptStats`].
///
/// NOTE(review): the returned graph is currently a plain clone — no rewrite is
/// applied yet; only the statistics reflect paths that *would* be profitable.
///
/// Fix: the FLOP-reduction average previously used
/// `(prev + reduction) / 2.0`, which averages the first sample against a
/// phantom 0.0 and weights recent nodes exponentially. It now maintains a
/// correct incremental arithmetic mean over all reordered contractions.
pub fn optimize_contractions_with_config(
    graph: &EinsumGraph,
    config: &ContractionOptConfig,
) -> (EinsumGraph, ContractionOptStats) {
    let optimized = graph.clone();
    let mut stats = ContractionOptStats::default();
    for node in graph.nodes.iter() {
        if let OpType::Einsum { spec } = &node.op {
            if let Some(optimal_path) = find_optimal_path(spec.as_str(), &node.inputs, config) {
                let original_cost = estimate_einsum_cost(spec.as_str(), &node.inputs);
                let new_cost = optimal_path.estimated_flops;
                // Only count strict improvements; original_cost > 0 here since
                // new_cost is non-negative.
                if new_cost < original_cost {
                    let reduction = (original_cost - new_cost) / original_cost * 100.0;
                    stats.contractions_reordered += 1;
                    // Incremental mean: m_k = m_{k-1} + (x_k - m_{k-1}) / k.
                    let count = stats.contractions_reordered as f64;
                    stats.flops_reduction_percent +=
                        (reduction - stats.flops_reduction_percent) / count;
                }
            }
        }
        stats.total_processed += 1;
    }
    (optimized, stats)
}
/// Dispatches path search for one einsum node: exact DP when enabled and the
/// problem is small enough, otherwise the greedy heuristic if allowed.
///
/// Returns `None` when the spec cannot be parsed, when the spec's input count
/// disagrees with `inputs`, or when no search strategy is applicable.
fn find_optimal_path(
    spec: &str,
    inputs: &[usize],
    config: &ContractionOptConfig,
) -> Option<ContractionPath> {
    let (input_specs, output_spec) = parse_einsum_spec(spec)?;
    // The spec must describe exactly as many operands as the node has.
    if input_specs.len() != inputs.len() {
        return None;
    }
    let dp_applicable = config.use_dynamic_programming && inputs.len() <= config.max_dp_size;
    if dp_applicable {
        return find_optimal_path_dp(&input_specs, output_spec, config);
    }
    if config.enable_greedy_fallback {
        return find_optimal_path_greedy(&input_specs, output_spec);
    }
    None
}
/// Exact subset-DP search over contraction orders.
///
/// `dp[mask]` holds the cheapest known cost of fully contracting the set of
/// input tensors selected by bitmask `mask`, plus the (submask, complement)
/// split that achieved it. Costs come from `estimate_merge_cost`, which only
/// looks at subset sizes — the index labels in `input_specs` are not consulted.
///
/// Returns `None` for fewer than two inputs. `steps` is left empty; only
/// aggregate cost estimates are reported.
///
/// NOTE(review): splitting the final cost into FLOPs/memory shares via
/// `flops_memory_tradeoff` is heuristic — confirm intended semantics against
/// callers comparing it with `estimate_einsum_cost`.
fn find_optimal_path_dp(
    input_specs: &[String],
    _output_spec: &str,
    config: &ContractionOptConfig,
) -> Option<ContractionPath> {
    let n = input_specs.len();
    if n < 2 {
        return None;
    }
    // mask -> (best cost, best (submask, complement) split).
    let mut dp: HashMap<u64, (f64, Option<(u64, u64)>)> = HashMap::new();
    // Base case: a single tensor costs nothing to "contract".
    for i in 0..n {
        let mask = 1u64 << i;
        dp.insert(mask, (0.0, None));
    }
    // Masks are visited in increasing numeric order, so every proper submask
    // of `mask` is already solved when `mask` is processed.
    for mask in 1u64..(1u64 << n) {
        if mask.count_ones() == 1 {
            continue;
        }
        let mut best_cost = f64::INFINITY;
        let mut best_split = None;
        // Enumerate all non-empty proper submasks of `mask`; each split is
        // visited twice (once per orientation), which is harmless for a min.
        let mut submask = mask;
        while submask > 0 {
            if submask != mask {
                let complement = mask ^ submask;
                let left_cost = dp.get(&submask).map(|(c, _)| *c).unwrap_or(0.0);
                let right_cost = dp.get(&complement).map(|(c, _)| *c).unwrap_or(0.0);
                let merge_cost = estimate_merge_cost(submask, complement, n);
                let total_cost = left_cost + right_cost + merge_cost;
                if total_cost < best_cost {
                    best_cost = total_cost;
                    best_split = Some((submask, complement));
                }
            }
            // Standard bit trick: step to the next smaller submask of `mask`.
            submask = (submask.wrapping_sub(1)) & mask;
        }
        dp.insert(mask, (best_cost, best_split));
    }
    let full_mask = (1u64 << n) - 1;
    let (final_cost, _) = dp.get(&full_mask)?;
    Some(ContractionPath {
        steps: vec![],
        estimated_flops: *final_cost * config.flops_memory_tradeoff,
        estimated_memory: *final_cost * (1.0 - config.flops_memory_tradeoff),
    })
}
/// Greedy contraction-order heuristic: repeatedly merges the cheapest pair
/// (per `estimate_pairwise_cost`) until one tensor remains.
///
/// Returns `None` for fewer than two inputs. Merged results are given fresh
/// ids numbered after the original inputs; estimated memory is fixed at half
/// the estimated FLOPs.
fn find_optimal_path_greedy(input_specs: &[String], _output_spec: &str) -> Option<ContractionPath> {
    let tensor_count = input_specs.len();
    if tensor_count < 2 {
        return None;
    }
    let mut pending: Vec<usize> = (0..tensor_count).collect();
    let mut contraction_order = Vec::new();
    let mut accumulated_flops = 0.0;
    while pending.len() > 1 {
        // Scan every unordered pair and keep the cheapest one.
        let mut cheapest = f64::INFINITY;
        let mut pick = (0, 1);
        for a in 0..pending.len() {
            for b in (a + 1)..pending.len() {
                let pair_cost = estimate_pairwise_cost(pending[a], pending[b], tensor_count);
                if pair_cost < cheapest {
                    cheapest = pair_cost;
                    pick = (a, b);
                }
            }
        }
        contraction_order.push((pending[pick.0], pending[pick.1]));
        accumulated_flops += cheapest;
        // The merged result gets the next fresh intermediate id.
        let merged_id = tensor_count + contraction_order.len() - 1;
        pending.remove(pick.1); // remove the higher position first
        pending.remove(pick.0);
        pending.push(merged_id);
    }
    Some(ContractionPath {
        steps: contraction_order,
        estimated_flops: accumulated_flops,
        estimated_memory: accumulated_flops * 0.5,
    })
}
/// Splits an einsum spec like `"ij,jk->ik"` into its input subscripts and
/// output subscript.
///
/// Returns `None` unless the spec contains exactly one `"->"`. Each input
/// subscript and the output are whitespace-trimmed; the output borrows from
/// `spec`.
fn parse_einsum_spec(spec: &str) -> Option<(Vec<String>, &str)> {
    let mut halves = spec.split("->");
    let lhs = halves.next()?;
    let rhs = halves.next()?;
    if halves.next().is_some() {
        // More than one "->" — malformed spec.
        return None;
    }
    let inputs = lhs.split(',').map(|s| s.trim().to_string()).collect();
    Some((inputs, rhs.trim()))
}
/// Coarse cost proxy for an einsum node: 1000 per operand plus 10 per unit of
/// tensor id. The spec itself is currently ignored.
fn estimate_einsum_cost(_spec: &str, inputs: &[usize]) -> f64 {
    let base: f64 = 1000.0 * inputs.len() as f64;
    let spread: f64 = inputs.iter().map(|&idx| 10.0 * idx as f64).sum();
    base + spread
}
/// Cost proxy for merging two tensor subsets: the product of their set sizes
/// (popcounts) scaled by 100. Index labels are not consulted.
fn estimate_merge_cost(mask1: u64, mask2: u64, _n: usize) -> f64 {
    let group_a = f64::from(mask1.count_ones());
    let group_b = f64::from(mask2.count_ones());
    100.0 * group_a * group_b
}
/// Cost proxy for contracting a pair of tensors, based only on their ids
/// (shifted to be 1-based so id 0 still contributes).
fn estimate_pairwise_cost(idx1: usize, idx2: usize, _n: usize) -> f64 {
    50.0 * (idx1 as f64 + 1.0) * (idx2 as f64 + 1.0)
}
/// Renders a human-readable summary of a contraction path: step count, cost
/// estimates in scientific notation, and warnings when the FLOP estimate
/// exceeds 1e9 or the memory estimate exceeds 1e8.
pub fn analyze_contraction_path(path: &ContractionPath) -> String {
    let mut report = format!(
        "Contraction Path Analysis:\n Steps: {}\n Estimated FLOPs: {:.2e}\n Estimated Memory: {:.2e}\n",
        path.steps.len(),
        path.estimated_flops,
        path.estimated_memory
    );
    if path.estimated_flops > 1e9 {
        report.push_str(" Warning: High computational cost\n");
    }
    if path.estimated_memory > 1e8 {
        report.push_str(" Warning: High memory usage\n");
    }
    report
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shape arithmetic: fully-known dims multiply out.
    #[test]
    fn test_tensor_shape() {
        let shape = TensorShape::new(vec![Some(10), Some(20), Some(30)]);
        assert_eq!(shape.rank(), 3);
        assert_eq!(shape.num_elements(), Some(6000));
    }

    // Any unknown (None) dimension makes the element count unknown.
    #[test]
    fn test_tensor_shape_unknown_dims() {
        let shape = TensorShape::new(vec![Some(10), None, Some(30)]);
        assert_eq!(shape.rank(), 3);
        assert_eq!(shape.num_elements(), None);
    }

    // Two-operand spec parses into trimmed inputs and output.
    #[test]
    fn test_parse_einsum_spec() {
        let spec = "ij,jk->ik";
        let (inputs, output) = parse_einsum_spec(spec).unwrap();
        assert_eq!(inputs.len(), 2);
        assert_eq!(inputs[0], "ij");
        assert_eq!(inputs[1], "jk");
        assert_eq!(output, "ik");
    }

    // Three-operand spec parses as well.
    #[test]
    fn test_parse_einsum_spec_complex() {
        let spec = "ijk,klm,mnp->ijnp";
        let (inputs, output) = parse_einsum_spec(spec).unwrap();
        assert_eq!(inputs.len(), 3);
        assert_eq!(output, "ijnp");
    }

    // Greedy search over 3 tensors yields 2 pairwise steps with positive cost.
    #[test]
    fn test_find_optimal_path_greedy() {
        let inputs = vec!["ij".to_string(), "jk".to_string(), "kl".to_string()];
        let output = "il";
        let path = find_optimal_path_greedy(&inputs, output);
        assert!(path.is_some());
        let path = path.unwrap();
        assert_eq!(path.steps.len(), 2);
        assert!(path.estimated_flops > 0.0);
    }

    // Cost estimate grows with operand count.
    #[test]
    fn test_estimate_einsum_cost() {
        let cost1 = estimate_einsum_cost("ij,jk->ik", &[0, 1]);
        let cost2 = estimate_einsum_cost("ijk,klm,mnp->ijnp", &[0, 1, 2]);
        assert!(cost1 > 0.0);
        assert!(cost2 > cost1);
    }

    // An empty graph yields zeroed stats.
    #[test]
    fn test_optimize_contractions() {
        let graph = EinsumGraph::new();
        let (_optimized, stats) = optimize_contractions(&graph);
        assert_eq!(stats.contractions_reordered, 0);
    }

    // Default configuration sanity checks.
    #[test]
    fn test_config_default() {
        let config = ContractionOptConfig::default();
        assert!(config.use_dynamic_programming);
        assert_eq!(config.max_dp_size, 26);
        assert!(config.flops_memory_tradeoff > 0.0);
        assert!(config.flops_memory_tradeoff <= 1.0);
    }

    // total_optimizations = reorderings + saved intermediates.
    #[test]
    fn test_stats_total_optimizations() {
        let stats = ContractionOptStats {
            contractions_reordered: 3,
            flops_reduction_percent: 25.0,
            memory_reduction_percent: 15.0,
            intermediates_saved: 2,
            total_processed: 10,
        };
        assert_eq!(stats.total_optimizations(), 5);
    }

    // Report contains the headline fields for a modest path.
    #[test]
    fn test_analyze_contraction_path() {
        let path = ContractionPath {
            steps: vec![(0, 1), (2, 3)],
            estimated_flops: 1e6,
            estimated_memory: 1e5,
        };
        let analysis = analyze_contraction_path(&path);
        assert!(analysis.contains("Steps: 2"));
        assert!(analysis.contains("FLOPs"));
        assert!(analysis.contains("Memory"));
    }

    // Merge cost scales with subset popcounts.
    #[test]
    fn test_estimate_merge_cost() {
        let cost1 = estimate_merge_cost(0b0001u64, 0b0010u64, 4);
        let cost2 = estimate_merge_cost(0b0011u64, 0b1100u64, 4);
        assert!(cost1 > 0.0);
        assert!(cost2 > cost1);
    }

    // Pairwise costs are positive for distinct ids.
    #[test]
    fn test_estimate_pairwise_cost() {
        let cost1 = estimate_pairwise_cost(0, 1, 3);
        let cost2 = estimate_pairwise_cost(1, 2, 3);
        assert!(cost1 > 0.0);
        assert!(cost2 > 0.0);
    }

    // Costs above the 1e9 / 1e8 thresholds trigger warnings in the report.
    #[test]
    fn test_contraction_path_high_cost_warning() {
        let path = ContractionPath {
            steps: vec![(0, 1)],
            estimated_flops: 1e10,
            estimated_memory: 1e9,
        };
        let analysis = analyze_contraction_path(&path);
        assert!(analysis.contains("Warning"));
    }
}