use crate::testing::build_shapes;
use opt_einsum_path::contract::contract_path;
use opt_einsum_path::typing::*;
use rstest::rstest;
use std::collections::BTreeMap;
mod testing;
use itertools::Itertools;
#[cfg(test)]
mod tests {
use super::*;
// Passing the misspelled optimizer name "optimall" is expected to end in a
// panic whose message names the unknown optimization kind.
#[test]
#[should_panic(expected = "Unknown optimization kind: optimall")]
fn test_bad_path_option() {
    let shapes = vec![vec![1], vec![2], vec![3]];
    let outcome = contract_path("a,b,c", &shapes, "optimall", None);
    outcome.unwrap();
}
// Checks the path returned by the "optimal" optimizer for a three-operand
// contraction: pairwise steps under a limit of 5000, and a single step over
// all operands under a limit of 0.
#[rstest]
#[case("optimal", "abd,ac,bdc->", Some(5000 as _), vec![vec![0, 2], vec![0, 1]])]
#[case("optimal", "abd,ac,bdc->", Some(0 as _), vec![vec![0, 1, 2]])]
fn test_path_optimal(
    #[case] name: &str,
    #[case] expr: &str,
    #[case] memory_limit: Option<SizeType>,
    #[case] expected: PathType,
) {
    let operand_shapes = build_shapes(expr, None, true).unwrap();
    let (path, _) = contract_path(expr, &operand_shapes, name, memory_limit).unwrap();
    assert_eq!(path, expected);
}
// Checks the path returned by the "greedy" optimizer for a three-operand
// contraction: pairwise steps under a limit of 5000, and a single step over
// all operands under a limit of 0.
#[rstest]
#[case("greedy", "abd,ac,bdc->", Some(5000 as _), vec![vec![0, 2], vec![0, 1]])]
#[case("greedy", "abd,ac,bdc->", Some(0 as _), vec![vec![0, 1, 2]])]
fn test_path_greedy(
    #[case] name: &str,
    #[case] expr: &str,
    #[case] memory_limit: Option<SizeType>,
    #[case] expected: PathType,
) {
    let operand_shapes = build_shapes(expr, None, true).unwrap();
    let (path, _) = contract_path(expr, &operand_shapes, name, memory_limit).unwrap();
    assert_eq!(path, expected);
}
// Exercises memory-limit handling on a six-operand contraction with both the
// "optimal" and the "greedy" optimizer:
//   * with a memory limit of 5, each is expected to return one step that
//     contracts all six operands at once;
//   * without a memory limit, each is expected to return the same pairwise path.
#[test]
fn test_memory_paths() {
    let expression = "abc,bdef,fghj,cem,mhk,ljk->adgl";
    let views = build_shapes(expression, None, true).unwrap();

    // Tiny memory limit: a single combined contraction step.
    let single_step: PathType = vec![vec![0, 1, 2, 3, 4, 5]];
    for optimizer in ["optimal", "greedy"] {
        let (path, _) = contract_path(expression, &views, optimizer, Some(5 as _)).unwrap();
        assert_eq!(path, single_step, "optimizer `{optimizer}` with memory limit 5");
    }

    // No memory limit: both optimizers are checked (previously "greedy" was
    // checked twice and "optimal" not at all).
    let pairwise: PathType = vec![vec![0, 3], vec![0, 4], vec![0, 2], vec![0, 2], vec![0, 1]];
    for optimizer in ["optimal", "greedy"] {
        let (path, _) = contract_path(expression, &views, optimizer, None).unwrap();
        assert_eq!(path, pairwise, "optimizer `{optimizer}` without a memory limit");
    }
}
// Table-driven check of the exact contraction path each optimizer returns,
// without a memory limit, for a handful of small expressions. Each case is
// (optimizer name, expression, expected path).
#[rstest]
#[case("greedy", "eb,cb,fb->cef", vec![vec![0, 2], vec![0, 1]])]
#[case("branch-all", "eb,cb,fb->cef", vec![vec![0, 2], vec![0, 1]])]
#[case("branch-2", "eb,cb,fb->cef", vec![vec![0, 2], vec![0, 1]])]
#[case("optimal", "eb,cb,fb->cef", vec![vec![0, 2], vec![0, 1]])]
#[case("dp", "eb,cb,fb->cef", vec![vec![1, 2], vec![0, 1]])]
#[case("greedy", "dd,fb,be,cdb->cef", vec![vec![0, 3], vec![0, 1], vec![0, 1]])]
#[case("branch-all", "dd,fb,be,cdb->cef", vec![vec![0, 3], vec![0, 1], vec![0, 1]])]
#[case("branch-2", "dd,fb,be,cdb->cef", vec![vec![0, 3], vec![0, 1], vec![0, 1]])]
#[case("optimal", "dd,fb,be,cdb->cef", vec![vec![0, 3], vec![0, 1], vec![0, 1]])]
#[case("dp", "dd,fb,be,cdb->cef", vec![vec![0, 3], vec![0, 2], vec![0, 1]])]
#[case("greedy", "bca,cdb,dbf,afc->", vec![vec![1, 2], vec![0, 2], vec![0, 1]])]
#[case("branch-all", "bca,cdb,dbf,afc->", vec![vec![1, 2], vec![0, 2], vec![0, 1]])]
#[case("branch-2", "bca,cdb,dbf,afc->", vec![vec![1, 2], vec![0, 2], vec![0, 1]])]
#[case("optimal", "bca,cdb,dbf,afc->", vec![vec![1, 2], vec![0, 2], vec![0, 1]])]
#[case("dp", "bca,cdb,dbf,afc->", vec![vec![1, 2], vec![1, 2], vec![0, 1]])]
#[case("greedy", "dcc,fce,ea,dbf->ab", vec![vec![1, 2], vec![0, 1], vec![0, 1]])]
#[case("branch-all", "dcc,fce,ea,dbf->ab", vec![vec![1, 2], vec![0, 2], vec![0, 1]])]
#[case("branch-2", "dcc,fce,ea,dbf->ab", vec![vec![1, 2], vec![0, 2], vec![0, 1]])]
#[case("optimal", "dcc,fce,ea,dbf->ab", vec![vec![1, 2], vec![0, 2], vec![0, 1]])]
fn test_path_edge_cases(
    #[case] name: &str,
    #[case] expr: &str,
    #[case] expected: PathType,
) {
    let operand_shapes = build_shapes(expr, None, true).unwrap();
    let (path, _) = contract_path(expr, &operand_shapes, name, None).unwrap();
    assert_eq!(path, expected);
}
// Checks only the number of steps in the path returned by the "optimal"
// optimizer for expressions that mostly contain empty (index-free) operand
// terms. Each case is (optimizer name, expression, expected step count).
#[rstest]
#[case("optimal", "a,->a", 1)]
#[case("optimal", "ab->ab", 1)]
#[case("optimal", ",a,->a", 2)]
#[case("optimal", ",,a,->a", 3)]
#[case("optimal", ",,->", 2)]
fn test_path_scalar_cases(
    #[case] name: &str,
    #[case] expr: &str,
    #[case] expected: usize,
) {
    let operand_shapes = build_shapes(expr, None, true).unwrap();
    let (path, _) = contract_path(expr, &operand_shapes, name, None).unwrap();
    let step_count = path.len();
    assert_eq!(step_count, expected);
}
// With every index sized 20 and the "max-input" memory limit, both "greedy"
// and "optimal" are expected to contract operands 0 and 1 first and then
// everything that remains in one step.
#[test]
fn test_optimal_edge_cases() {
    let expression = "a,ac,ab,ad,cd,bd,bc->";
    let size_dict: BTreeMap<char, usize> = "abcd".chars().map(|index| (index, 20)).collect();
    let shapes = build_shapes(expression, Some(&size_dict), true).unwrap();

    let expected: PathType = vec![vec![0, 1], vec![0, 1, 2, 3, 4, 5]];
    for optimizer in ["greedy", "optimal"] {
        let (path, _) = contract_path(expression, &shapes, optimizer, "max-input").unwrap();
        assert_eq!(path, expected);
    }
}
// Greedy edge cases on a four-operand expression with every index sized 20:
//   * under the "max-input" memory limit, greedy is expected to fall back to
//     one step contracting all four operands;
//   * without a memory limit, a three-step pairwise path is expected. This is
//     checked for "greedy" (the optimizer this test is named after, which was
//     previously not exercised here) and still for "optimal".
#[test]
fn test_greedy_edge_cases() {
    let expression = "abc,cfd,dbe,efa";
    let size_dict: BTreeMap<char, usize> = "abcdef".chars().map(|index| (index, 20)).collect();
    let tensors = build_shapes(expression, Some(&size_dict), true).unwrap();

    let (path, _) = contract_path(expression, &tensors, "greedy", "max-input").unwrap();
    assert_eq!(path, vec![vec![0, 1, 2, 3]], "greedy with the max-input memory limit");

    let pairwise: PathType = vec![vec![0, 1], vec![0, 2], vec![0, 1]];
    for optimizer in ["greedy", "optimal"] {
        let (path, _) = contract_path(expression, &tensors, optimizer, None).unwrap();
        assert_eq!(path, pairwise, "optimizer `{optimizer}` without a memory limit");
    }
}
// Runs the "dp" optimizer on operands whose dimensions are all 1 and checks
// that the largest entry of the reported `scale_list` is 3.
#[test]
fn test_dp_edge_cases_dimension_1() {
    let expression = "nlp,nlq,pl->n";
    let shapes = vec![vec![1, 1, 1], vec![1, 1, 1], vec![1, 1]];
    let (_, info) = contract_path(expression, &shapes, "dp", None).unwrap();
    let largest_scale = info.scale_list.iter().copied().max();
    assert_eq!(largest_scale, Some(3));
}
// Runs the "dp" optimizer on an expression in which every index appears in
// exactly one operand, prints the resulting info, and checks that the largest
// entry of the reported `scale_list` is 3.
#[test]
fn test_dp_edge_cases_all_singlet_indices() {
    let expression = "a,bcd,efg->";
    let shapes = vec![vec![2], vec![2, 2, 2], vec![2, 2, 2]];
    let (_, info) = contract_path(expression, &shapes, "dp", None).unwrap();
    println!("{info}");
    let largest_scale = info.scale_list.iter().copied().max();
    assert_eq!(largest_scale, Some(3));
}
// Compares two `DynamicProgramming` optimizers that differ only in
// `search_outer`: enabling it must give a strictly lower `opt_cost`, and both
// costs are pinned to exact values (36 without, 28 with).
#[test]
fn test_custom_dp_can_optimize_for_outer_products() {
    use opt_einsum_path::paths::dp::DynamicProgramming;

    let expression = "a,b,abc->c";
    let shapes = vec![vec![2], vec![2], vec![2, 2, 3]];

    let without_outer = DynamicProgramming { search_outer: false, ..Default::default() };
    let with_outer = DynamicProgramming { search_outer: true, ..Default::default() };

    let (_, info_without) = contract_path(expression, &shapes, without_outer, None).unwrap();
    let (_, info_with) = contract_path(expression, &shapes, with_outer, None).unwrap();

    assert!(info_with.opt_cost < info_without.opt_cost);
    assert_eq!(info_without.opt_cost, 36 as _);
    assert_eq!(info_with.opt_cost, 28 as _);
}
// Compares the "dp-flops" and "dp-size" optimizers on a ten-operand
// contraction: "dp-flops" must have the lower cost, "dp-size" the smaller
// largest intermediate, and both are pinned to exact costs, sizes and paths.
#[test]
fn test_custom_dp_can_optimize_for_size() {
    let expression = "qgcf,sotr,klb,jlretia,hpn,nseha,jgoqm,ipkb,cdfm,d->";
    let shapes = vec![
        vec![5, 2, 9, 4],
        vec![4, 9, 5, 9],
        vec![5, 4, 2],
        vec![5, 4, 9, 7, 5, 3, 6],
        vec![5, 2, 8],
        vec![8, 4, 7, 5, 6],
        vec![5, 2, 9, 5, 8],
        vec![3, 2, 5, 2],
        vec![9, 3, 4, 8],
        vec![3],
    ];

    let (_, by_flops) = contract_path(expression, &shapes, "dp-flops", None).unwrap();
    let (_, by_size) = contract_path(expression, &shapes, "dp-size", None).unwrap();

    // Relative ordering between the two objectives.
    assert!(by_flops.opt_cost < by_size.opt_cost);
    assert!(by_flops.largest_intermediate > by_size.largest_intermediate);

    // Exact figures.
    assert_eq!(by_flops.opt_cost, 663054 as _);
    assert_eq!(by_size.opt_cost, 1114440 as _);
    assert_eq!(by_flops.largest_intermediate, 18900 as _);
    assert_eq!(by_size.largest_intermediate, 2016 as _);

    // Exact paths, written as operand pairs and expanded to the path representation.
    let flops_steps: [(usize, usize); 9] =
        [(4, 5), (2, 5), (2, 7), (5, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)];
    let size_steps: [(usize, usize); 9] =
        [(2, 7), (3, 8), (3, 7), (2, 6), (1, 5), (1, 4), (1, 3), (1, 2), (0, 1)];
    let expected_flops_path: PathType = flops_steps.into_iter().map(|(i, j)| vec![i, j]).collect();
    let expected_size_path: PathType = size_steps.into_iter().map(|(i, j)| vec![i, j]).collect();
    assert_eq!(by_flops.path, expected_flops_path);
    assert_eq!(by_size.path, expected_size_path);
}
// Builds three `DynamicProgramming` optimizers that differ only in `cost_cap`
// (converted from `true`, `false` and `Some(100)`) and checks that all three
// report the same `opt_cost` and the same, exactly pinned, path.
#[test]
fn test_custom_dp_can_set_cost_cap() {
    use opt_einsum_path::paths::dp::DynamicProgramming;

    let expression = "ad,cfb,fdc,abge,eg->";
    let shapes = vec![vec![8, 8], vec![6, 9, 5], vec![9, 8, 6], vec![8, 5, 6, 4], vec![4, 6]];

    let cap_on = DynamicProgramming { cost_cap: true.into(), ..Default::default() };
    let cap_off = DynamicProgramming { cost_cap: false.into(), ..Default::default() };
    let cap_fixed = DynamicProgramming { cost_cap: Some(100 as _).into(), ..Default::default() };

    let (_, info_on) = contract_path(expression, &shapes, cap_on, None).unwrap();
    let (_, info_off) = contract_path(expression, &shapes, cap_off, None).unwrap();
    let (_, info_fixed) = contract_path(expression, &shapes, cap_fixed, None).unwrap();

    assert_eq!(info_on.opt_cost, info_off.opt_cost);
    assert_eq!(info_on.opt_cost, info_fixed.opt_cost);

    let expected_path: PathType = vec![vec![1, 2], vec![0, 3], vec![0, 2], vec![0, 1]];
    assert_eq!(info_on.path, expected_path);
    assert_eq!(info_on.path, info_off.path);
    assert_eq!(info_on.path, info_fixed.path);
}
// Table-driven check of the "dp-*" optimizer strings on a ten-operand
// contraction. Each case is (optimizer string, expected `opt_cost`, expected
// `largest_intermediate`, expected path as operand pairs).
#[rstest]
#[case("dp-flops", 663054, 18900, vec![(4, 5), (2, 5), (2, 7), (5, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)])]
#[case("dp-size", 1114440, 2016, vec![(2, 7), (3, 8), (3, 7), (2, 6), (1, 5), (1, 4), (1, 3), (1, 2), (0, 1)])]
#[case("dp-write", 983790, 2016, vec![(0, 8), (3, 4), (1, 4), (5, 6), (1, 5), (0, 4), (0, 3), (1, 2), (0, 1)])]
#[case("dp-combo", 973518, 2016, vec![(4, 5), (2, 5), (6, 7), (2, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)])]
#[case("dp-limit", 983832, 2016, vec![(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)])]
#[case("dp-combo-256", 983790, 2016, vec![(0, 8), (3, 4), (1, 4), (5, 6), (1, 5), (0, 4), (0, 3), (1, 2), (0, 1)])]
#[case("dp-limit-256", 983832, 2016, vec![(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)])]
fn test_custom_dp_can_set_minimize(
    #[case] minimize: &str,
    #[case] cost: usize,
    #[case] width: usize,
    #[case] path: Vec<(usize, usize)>,
) {
    let expression = "qgcf,sotr,klb,jlretia,hpn,nseha,jgoqm,ipkb,cdfm,d->";
    let shapes = vec![
        vec![5, 2, 9, 4],
        vec![4, 9, 5, 9],
        vec![5, 4, 2],
        vec![5, 4, 9, 7, 5, 3, 6],
        vec![5, 2, 8],
        vec![8, 4, 7, 5, 6],
        vec![5, 2, 9, 5, 8],
        vec![3, 2, 5, 2],
        vec![9, 3, 4, 8],
        vec![3],
    ];

    let (_, info) = contract_path(expression, &shapes, minimize, None).unwrap();

    // Expand the (i, j) pairs from the case table into the path representation.
    let expected_path: PathType =
        path.into_iter().map(|(first, second)| vec![first, second]).collect();

    assert_eq!(info.opt_cost, cost as _);
    assert_eq!(info.largest_intermediate, width as _);
    assert_eq!(info.path, expected_path);
}
// Takes the largest intermediate reported by "dp-size" as a memory limit for
// "dp": the call must succeed at exactly that limit and return an error when
// the limit is lowered by 1.0.
#[test]
fn test_dp_errors_when_no_contractions_found() {
    let expression = "jk,igelb,ho,nfcbd,ca,gk,hef,nal,omj,dim->";
    let shapes = vec![
        vec![3, 4],
        vec![8, 6, 4, 8, 5],
        vec![6, 9],
        vec![4, 9, 6, 5, 8],
        vec![6, 8],
        vec![6, 4],
        vec![6, 4, 9],
        vec![4, 8, 8],
        vec![9, 4, 3],
        vec![8, 8, 4],
    ];

    let (_, size_info) = contract_path(expression, &shapes, "dp-size", None).unwrap();
    let smallest_limit = size_info.largest_intermediate;

    let at_limit = contract_path(expression, &shapes, "dp", smallest_limit);
    assert!(at_limit.is_ok());

    let below_limit = contract_path(expression, &shapes, "dp", smallest_limit - 1.0);
    assert!(below_limit.is_err());
}
// In "ab,cd,ef,fg" only operands 2 and 3 share an index ("f"). Every listed
// optimizer is expected to contract that pair first and then combine the
// remaining operands pairwise.
//
// Both the unwrap and the assertion name the optimizer, so a failure inside
// the loop identifies which optimizer misbehaved.
#[test]
fn test_can_optimize_outer_products() {
    let expression = "ab,cd,ef,fg";
    let shapes = vec![vec![10, 10], vec![10, 10], vec![10, 10], vec![10, 2]];
    let expected: PathType = vec![vec![2, 3], vec![0, 2], vec![0, 1]];

    for optimizer in ["branch-2", "branch-all", "optimal", "dp", "greedy"] {
        let (path, _) = contract_path(expression, &shapes, optimizer, None)
            .unwrap_or_else(|err| panic!("optimizer `{optimizer}` failed: {err:?}"));
        assert_eq!(path, expected, "unexpected path from optimizer `{optimizer}`");
    }
}
// Smoke test: for several symbol counts, builds a chain expression in which
// each operand uses two consecutive symbols (sizes cycling through 2, 3, 4)
// and checks only that the "greedy" optimizer returns successfully.
#[test]
fn test_large_path() {
    for num_symbols in [2, 3, 26, 26 + 26, 256 - 140, 300] {
        let symbols: String = (0..num_symbols).map(opt_einsum_path::parser::get_symbol).collect();

        let sizes = [2, 3, 4].into_iter().cycle();
        let dimension_dict: BTreeMap<char, usize> = symbols.chars().zip(sizes).collect();

        // Every adjacent pair of symbols becomes one two-index operand term.
        let letters: Vec<char> = symbols.chars().collect();
        let terms: Vec<String> =
            letters.windows(2).map(|pair| pair.iter().collect::<String>()).collect();
        let expression = terms.join(",");

        let tensors = build_shapes(&expression, Some(&dimension_dict), true).unwrap();
        contract_path(&expression, &tensors, "greedy", None).unwrap();
    }
}
// Runs the "optimal" optimizer on an expression that uses "..." in both
// operands and the output, checks for a single pairwise step with an
// `opt_cost` of 1440, and prints the resulting info.
#[test]
fn test_ellipsis_in_path() {
    let expression = "...ij,...jk->...ik";
    let shapes = vec![vec![2, 3, 4, 5], vec![2, 3, 5, 6]];
    let (path, info) = contract_path(expression, &shapes, "optimal", None).unwrap();

    let expected_path: PathType = vec![vec![0, 1]];
    assert_eq!(path, expected_path);
    assert_eq!(info.opt_cost, 1440 as _);
    println!("{info}");
}
}