use crate::*;
/// Classify a two-operand contraction by the BLAS-style kernel it could be mapped to.
///
/// * `inputs` – the index strings of the operands, one `char` per axis. Anything other
///   than exactly two operands yields `None`.
/// * `result` – the output index string.
/// * `idx_removed` – the indices contracted away by this step.
/// * `shapes` – optional operand shapes. When given, the extent of every removed index
///   must agree between the two operands. A removed index missing from either operand,
///   or fewer than two shapes, yields `None`.
///
/// Returns `None` when the contraction is not a candidate. Otherwise it returns one of
/// `"OUTER/EINSUM"`, `"DOT"`, `"DOT/EINSUM"`, `"GEMM"`, `"GEMV/EINSUM"` or `"TDOT"`.
///
/// The structure appears to mirror `opt_einsum`'s `can_blas`. The GEMM checks use
/// Python-style (clamped) slicing, so operands shorter than `idx_removed.len()` never
/// cause an out-of-bounds panic.
pub fn can_blas(
    inputs: &[&str],
    result: &str,
    idx_removed: &ArrayIndexType,
    shapes: Option<&[TensorShapeType]>,
) -> Option<&'static str> {
    // Python `s[:n]`: the first `n` chars, or all of them if there are fewer.
    fn head(chars: &[char], n: usize) -> &[char] {
        &chars[..n.min(chars.len())]
    }
    // Python `s[-n:]` (for `n > 0`): the last `n` chars, or all of them if there are fewer.
    fn tail(chars: &[char], n: usize) -> &[char] {
        &chars[chars.len().saturating_sub(n)..]
    }

    // Only pairwise contractions are considered.
    let (input_left, input_right) = match inputs {
        [l, r] => (*l, *r),
        _ => return None,
    };
    let left_set: ArrayIndexType = input_left.chars().collect();
    let right_set: ArrayIndexType = input_right.chars().collect();
    let left_vec: Vec<char> = input_left.chars().collect();
    let right_vec: Vec<char> = input_right.chars().collect();

    for c in &left_set | &right_set {
        let nl = left_vec.iter().filter(|&&x| x == c).count();
        let nr = right_vec.iter().filter(|&&x| x == c).count();
        // An index repeated within one operand (e.g. `iij`) is rejected.
        if nl > 1 || nr > 1 {
            return None;
        }
        // Here nl + nr is 1 or 2 (c comes from the union, so no underflow).
        // An index seen once must appear in the result, otherwise it is an implicit sum
        // such as `ab,bc->c`. An index seen twice must not appear in the result, as in
        // `ab,ca->ca`.
        if nl + nr - 1 == result.contains(c) as usize {
            return None;
        }
    }

    if let Some(shapes) = shapes {
        let (left_shape, right_shape) = match shapes {
            [l, r, ..] => (l, r),
            _ => return None,
        };
        // Every removed index must exist in both operands with matching extent
        // (a mismatch would mean broadcasting, e.g. `ij,j->i` with size-1 `j`).
        for c in idx_removed {
            let left_pos = left_vec.iter().position(|x| x == c);
            let right_pos = right_vec.iter().position(|x| x == c);
            match (left_pos, right_pos) {
                (Some(lp), Some(rp)) if left_shape[lp] == right_shape[rp] => {}
                _ => return None,
            }
        }
    }

    if idx_removed.is_empty() {
        return Some("OUTER/EINSUM");
    }
    if input_left == input_right {
        return Some("DOT");
    }
    // Same indices but in a different order.
    if left_set == right_set {
        return Some("DOT/EINSUM");
    }

    let rs = idx_removed.len();
    let (left_head, left_tail) = (head(&left_vec, rs), tail(&left_vec, rs));
    let (right_head, right_tail) = (head(&right_vec, rs), tail(&right_vec, rs));
    // GEMM when a leading or trailing run of `rs` indices in the left operand equals a
    // leading or trailing run in the right operand. The four cases are: no transpose,
    // transpose both, transpose right, transpose left.
    if left_tail == right_head
        || left_head == right_tail
        || left_tail == right_tail
        || left_head == right_head
    {
        return Some("GEMM");
    }

    let keep_left = &left_set - idx_removed;
    let keep_right = &right_set - idx_removed;
    if keep_left.is_empty() || keep_right.is_empty() {
        return Some("GEMV/EINSUM");
    }
    Some("TDOT")
}