use opt_einsum_path::blas::can_blas;
use opt_einsum_path::typing::*;
use rstest::rstest;
/// Table-driven check of `can_blas` for two-operand contractions.
///
/// Each case supplies `(operand subscripts, output subscripts, removed indices)`
/// together with the label `can_blas` is expected to return. `None` is expected
/// when the contraction should not be dispatched to any BLAS-style routine. The
/// removed indices are passed as a string and collected into an `ArrayIndexType`
/// before the call.
#[rstest]
// Full inner products: every index is summed away.
#[case((vec!["k", "k"], "", "k"), Some("DOT"))]
#[case((vec!["ijk", "ijk"], "", "ijk"), Some("DOT"))]
// Matrix-matrix style contractions, with the operands and output in various index orders.
#[case((vec!["ij", "jk"], "ik", "j"), Some("GEMM"))]
#[case((vec!["ijl", "jlk"], "ik", "jl"), Some("GEMM"))]
#[case((vec!["ij", "kj"], "ik", "j"), Some("GEMM"))]
#[case((vec!["ijl", "kjl"], "ik", "jl"), Some("GEMM"))]
#[case((vec!["ji", "jk"], "ik", "j"), Some("GEMM"))]
#[case((vec!["jli", "jlk"], "ik", "jl"), Some("GEMM"))]
#[case((vec!["ji", "kj"], "ik", "j"), Some("GEMM"))]
#[case((vec!["jli", "kjl"], "ik", "jl"), Some("GEMM"))]
#[case((vec!["ij", "jk"], "ki", "j"), Some("GEMM"))]
#[case((vec!["ijl", "jlk"], "ki", "jl"), Some("GEMM"))]
// Contracted indices are not adjacent in the same order in both operands.
#[case((vec!["ilj", "jlk"], "ik", "jl"), Some("TDOT"))]
#[case((vec!["ijl", "ljk"], "ik", "jl"), Some("TDOT"))]
// Cases labelled with an einsum fallback.
#[case((vec!["ijk", "ikj"], "", "ijk"), Some("DOT/EINSUM"))]
#[case((vec!["i", "j"], "ij", ""), Some("OUTER/EINSUM"))]
#[case((vec!["ijk", "ik"], "j", "ik"), Some("GEMV/EINSUM"))]
// Expected to be rejected: repeated index in one operand, or nothing contracted.
#[case((vec!["ijj", "jk"], "ik", "j"), None)]
#[case((vec!["ijk", "j"], "ij", ""), None)]
#[case((vec!["ij", "ij"], "ij", ""), None)]
fn test_can_blas(
    #[case] inp: (Vec<&str>, &str, &str),
    #[case] benchmark: Option<&'static str>,
) {
    let (inputs, result, removed_chars) = inp;
    // The failure message below captures this binding by name, so it must stay `idx_removed`.
    let idx_removed = removed_chars.chars().collect::<ArrayIndexType>();

    let actual = can_blas(&inputs, result, &idx_removed, None);
    assert_eq!(
        actual, benchmark,
        "Failed for inputs: {inputs:?}, result: {result}, idx_removed: {idx_removed:?}"
    );
}