use crate::*;
/// Returns the index symbol associated with the integer `i`.
///
/// The mapping is:
/// * `0..=25`   -> `'a'..='z'`
/// * `26..=51`  -> `'A'..='Z'`
/// * `52..=55295` -> the code point `i + 140` (so `52` is U+00C0)
/// * `55296..`  -> the code point `i + 2048` (so `55296` is U+E000)
///
/// # Panics
/// Panics when the computed code point is not a Unicode scalar value:
/// * `55156..=55295`, because `i + 140` falls inside the surrogate range
///   U+D800..=U+DFFF;
/// * any `i` above `1_112_063`, because `i + 2048` exceeds U+10FFFF (this
///   includes values for which the addition itself would overflow `u32`).
pub fn get_symbol(i: u32) -> char {
    let code_point = match i {
        0..=25 => Some(u32::from(b'a') + i),
        26..=51 => Some(u32::from(b'A') + (i - 26)),
        52..=55295 => Some(i + 140),
        // Checked so that huge indices fail loudly instead of wrapping around
        // to an unrelated low code point in release builds.
        _ => i.checked_add(2048),
    };
    code_point.and_then(char::from_u32).unwrap_or_else(|| {
        panic!("get_symbol: index {i} does not map to a valid Unicode scalar value")
    })
}
pub fn gen_unused_symbols(used: &str, n: usize) -> String {
let mut unused = Vec::with_capacity(n);
let mut i = 0;
while unused.len() < n {
let c = get_symbol(i);
if !used.contains(c) {
unused.push(c);
}
i += 1;
}
unused.into_iter().collect()
}
/// Derives the implicit output term for a comma-separated list of input terms.
///
/// Commas are ignored. The result contains every remaining character that
/// occurs exactly once across all of `subscripts`, in ascending character
/// order.
pub fn find_output_str(subscripts: &str) -> String {
    let mut seen: BTreeSet<char> = BTreeSet::new();
    let mut repeated: BTreeSet<char> = BTreeSet::new();
    for index in subscripts.chars().filter(|&c| c != ',') {
        // `insert` reports `false` when the character was already recorded.
        if !seen.insert(index) {
            repeated.insert(index);
        }
    }
    seen.difference(&repeated).copied().collect()
}
/// Computes the shape of the output term from the input terms and shapes.
///
/// `inputs` and `shapes` are walked pairwise. For each character of `output`,
/// every input term containing that character contributes the dimension
/// found at the character's first position in the term (terms whose shape
/// has no entry at that position contribute nothing). The largest
/// contributed dimension becomes the output dimension for that character.
///
/// # Panics
/// Panics if some character of `output` receives no contribution at all,
/// i.e. it does not occur in any input term at a position covered by the
/// matching shape.
pub fn find_output_shape(inputs: &[&str], shapes: &[TensorShapeType], output: &str) -> TensorShapeType {
    output
        .chars()
        .map(|label| {
            let candidates = shapes.iter().zip(inputs).filter_map(|(shape, term)| {
                let axis = term.chars().position(|x| x == label)?;
                shape.get(axis).copied()
            });
            candidates.max().unwrap()
        })
        .collect()
}
/// Concatenates `old_sub` into one string after translating each entry.
///
/// The literal entry `"..."` is copied through unchanged; every other entry
/// is replaced by the value stored for it in `symbol_map`.
///
/// # Panics
/// Panics if an entry other than `"..."` has no key in `symbol_map`.
pub fn convert_subscripts(old_sub: &[&str], symbol_map: &BTreeMap<&str, &str>) -> String {
    let mut converted = String::new();
    for &entry in old_sub {
        let replacement = match entry {
            "..." => "...",
            other => symbol_map[other],
        };
        converted.push_str(replacement);
    }
    converted
}
/// Validates an einsum expression and rewrites it into explicit form.
///
/// All whitespace is removed first. Every `...` in an input term is replaced
/// by concrete index symbols that do not already occur in the expression
/// (obtained from `gen_unused_symbols`), as many as the matching operand has
/// dimensions left over after its named indices. When the expression has no
/// `->`, the output term is derived with `find_output_str`; if ellipses were
/// expanded, the expanded symbols are placed in front of the derived indices.
///
/// # Arguments
/// * `subscripts` - the einsum expression, e.g. `"ij,jk->ik"` or `"...a,...b"`.
/// * `operands` - one shape per input term. Only the length of each shape is
///   inspected here.
///
/// # Returns
/// `(input_subscripts, output_subscript, operands)`: the comma-separated input
/// terms with ellipses expanded, the output term, and a copy of `operands`.
///
/// # Errors
/// Returns a message when
/// * `-` / `>` occur but do not form exactly one `->`;
/// * a term contains dots that are not a single `...`;
/// * a term names more indices than its operand has dimensions;
/// * an output character is repeated or missing from the inputs;
/// * the number of input terms differs from `operands.len()` (this is also
///   reported when an ellipsis term has no matching operand).
pub fn parse_einsum_input(
    subscripts: &str,
    operands: &[TensorShapeType],
) -> Result<(String, String, Vec<TensorShapeType>), String> {
    let einsum_str: String = subscripts.split_whitespace().collect();

    let dashes = einsum_str.matches('-').count();
    let angles = einsum_str.matches('>').count();
    if (dashes > 0 || angles > 0)
        && (dashes > 1 || angles > 1 || einsum_str.matches("->").count() != 1)
    {
        return Err("Subscripts can only contain one '->'.".to_string());
    }

    let operand_count_error = |input_count: usize| {
        format!(
            "Number of einsum subscripts, {input_count}, must be equal to the number of operands, {}.",
            operands.len()
        )
    };

    // The validation above guarantees at most one "->" from here on.
    let (raw_inputs, raw_output) = match einsum_str.split_once("->") {
        Some((inputs, output)) => (inputs, Some(output)),
        None => (einsum_str.as_str(), None),
    };

    let (input_subscripts, output_subscript) = if einsum_str.contains('.') {
        // '-' and '>' can only be part of the single "->" at this point.
        let used: String = einsum_str
            .chars()
            .filter(|c| !matches!(c, '.' | ',' | '-' | '>'))
            .collect();
        let max_rank = operands.iter().map(|shape| shape.len()).max().unwrap_or(0);
        // Kept as characters: the generated symbols may be multi-byte, so
        // slicing the string by byte offsets would be incorrect.
        let ellipse_inds: Vec<char> = gen_unused_symbols(&used, max_rank).chars().collect();
        let ellipsis_symbols = |count: usize| -> String {
            ellipse_inds[ellipse_inds.len().saturating_sub(count)..]
                .iter()
                .collect()
        };

        let mut longest = 0;
        let mut processed: Vec<String> = Vec::new();
        for (num, sub) in raw_inputs.split(',').enumerate() {
            if !sub.contains('.') {
                processed.push(sub.to_string());
                continue;
            }
            if sub.matches('.').count() != 3 || !sub.contains("...") {
                return Err("Invalid Ellipses.".to_string());
            }
            // A term without an operand is a count mismatch, not a reason to panic.
            let Some(shape) = operands.get(num) else {
                return Err(operand_count_error(raw_inputs.split(',').count()));
            };
            // Count characters, not bytes: index symbols may be non-ASCII.
            let named = sub.chars().count() - 3;
            let ellipse_count = if shape.is_empty() {
                0
            } else {
                match shape.len().checked_sub(named) {
                    Some(count) => count,
                    None => return Err("Ellipses lengths do not match.".to_string()),
                }
            };
            longest = longest.max(ellipse_count);
            processed.push(sub.replace("...", &ellipsis_symbols(ellipse_count)));
        }

        let inputs = processed.join(",");
        let out_ellipse = ellipsis_symbols(longest);
        let output = match raw_output {
            Some(explicit) => explicit.replace("...", &out_ellipse),
            None => {
                let normal_inds: String = find_output_str(&inputs)
                    .chars()
                    .filter(|c| !out_ellipse.contains(*c))
                    .collect();
                format!("{out_ellipse}{normal_inds}")
            }
        };
        (inputs, output)
    } else {
        match raw_output {
            Some(explicit) => (raw_inputs.to_string(), explicit.to_string()),
            None => (einsum_str.clone(), find_output_str(&einsum_str)),
        }
    };

    for index in output_subscript.chars() {
        if output_subscript.matches(index).count() != 1 {
            return Err(format!("Output character '{index}' appeared more than once in the output."));
        }
        if !input_subscripts.contains(index) {
            return Err(format!("Output character '{index}' did not appear in the input"));
        }
    }

    let input_count = input_subscripts.split(',').count();
    if input_count != operands.len() {
        return Err(operand_count_error(input_count));
    }

    Ok((input_subscripts, output_subscript, operands.to_vec()))
}
/// Scratch test: prints the result of `char::from_u32` for the code points
/// `0..500`. It makes no assertions, so it only fails if printing panics.
#[test]
fn playground() {
    (0..500u32).for_each(|code| {
        println!("{} -> '{:?}'", code, char::from_u32(code));
    });
}