use crate::eincode::{EinCode, NestedEinsum, SlicedEinsum};
use crate::Label;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::BufReader;
use std::path::Path;
/// Errors produced while reading or writing contraction-order JSON.
#[derive(Debug, thiserror::Error)]
pub enum JsonError {
    /// Opening or writing a file failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// A `serde_json` failure. Despite the message text, this variant is also
    /// produced when *serializing* (`to_value` / `to_string_pretty` errors are
    /// converted into it via `?`), not only when parsing.
    #[error("JSON parse error: {0}")]
    Parse(#[from] serde_json::Error),
}
/// Element type of the index labels, stored in the `label-type` field of the
/// JSON document.
///
/// With serde's default representation the unit variants are written as the
/// strings `"Char"`, `"Int64"` and `"Int"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum LabelType {
    /// Labels of type `char`.
    Char,
    /// Integer labels; `detect_label_type` also uses this as the fallback for
    /// every non-`char` label type.
    Int64,
    /// NOTE(review): accepted when reading but never produced by
    /// `detect_label_type` in this module.
    Int,
}
/// A contraction order as loaded from JSON.
///
/// Documents that carry a `slices` field load as `Sliced`; documents without
/// it load as `Nested`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContractionOrder<L: Label> {
    /// A contraction tree without slicing.
    Nested(NestedEinsum<L>),
    /// A contraction tree together with its slicing information.
    Sliced(SlicedEinsum<L>),
}
impl<L: Label> ContractionOrder<L> {
    /// Returns `true` when this is an unsliced [`ContractionOrder::Nested`] order.
    pub fn is_nested(&self) -> bool {
        self.as_nested().is_some()
    }

    /// Returns `true` when this is a [`ContractionOrder::Sliced`] order.
    pub fn is_sliced(&self) -> bool {
        self.as_sliced().is_some()
    }

    /// Borrows the tree of a `Nested` order; `None` for a sliced one.
    pub fn as_nested(&self) -> Option<&NestedEinsum<L>> {
        if let Self::Nested(tree) = self {
            Some(tree)
        } else {
            None
        }
    }

    /// Borrows the sliced einsum of a `Sliced` order; `None` for a nested one.
    pub fn as_sliced(&self) -> Option<&SlicedEinsum<L>> {
        if let Self::Sliced(sliced) = self {
            Some(sliced)
        } else {
            None
        }
    }

    /// Consumes the order, yielding the tree only if it was `Nested`.
    pub fn into_nested(self) -> Option<NestedEinsum<L>> {
        if let Self::Nested(tree) = self {
            Some(tree)
        } else {
            None
        }
    }

    /// Consumes the order, yielding the sliced einsum only if it was `Sliced`.
    pub fn into_sliced(self) -> Option<SlicedEinsum<L>> {
        if let Self::Sliced(sliced) = self {
            Some(sliced)
        } else {
            None
        }
    }
}
impl<L: Label> From<NestedEinsum<L>> for ContractionOrder<L> {
    /// Wraps an unsliced tree as [`ContractionOrder::Nested`].
    fn from(tree: NestedEinsum<L>) -> Self {
        Self::Nested(tree)
    }
}
impl<L: Label> From<SlicedEinsum<L>> for ContractionOrder<L> {
    /// Wraps a sliced einsum as [`ContractionOrder::Sliced`].
    fn from(sliced: SlicedEinsum<L>) -> Self {
        Self::Sliced(sliced)
    }
}
/// On-disk JSON layout of a contraction order.
///
/// Field names are kebab-cased on the wire (`label_type` becomes `"label-type"`).
/// `inputs` and `output` are written for consumers of the file, but
/// `json_to_contraction_order` rebuilds the order from `tree` and `slices` only.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
struct ContractionOrderJson<L: Label> {
    /// Declared label element type.
    label_type: LabelType,
    /// Index labels of each leaf tensor; when written by this module they are
    /// ordered by tensor index.
    inputs: Vec<Vec<L>>,
    /// Output labels of the root contraction.
    output: Vec<L>,
    /// The contraction tree.
    tree: NestedEinsumTree<L>,
    /// Sliced labels. Omitted entirely when `None`, so an unsliced order has no
    /// `slices` key and reads back as `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    slices: Option<Vec<L>>,
}
/// JSON form of a [`NestedEinsum`] node. It is either a leaf that refers to an
/// input tensor by index, or an internal node with child subtrees and an `EinCode`.
///
/// The enum is `#[serde(untagged)]`. On read, the variant is therefore chosen by
/// which fields are present, not by the value of `isleaf`. serde tries the
/// variants in declaration order: `Leaf` (needs `isleaf` + `tensorindex`)
/// before `Node` (needs `isleaf` + `args` + `eins`).
///
/// NOTE(review): `isleaf` is never cross-checked against the matched variant.
/// NOTE(review): `tensorindex` is written unchanged from `NestedEinsum`, and the
/// tests build leaves with index 0. Confirm the index base expected by other
/// consumers of this format (e.g. 1-based Julia tooling).
/// NOTE(review): reading goes through serde_json, whose deserializer applies a
/// default nesting limit. Very deep trees may fail to load even though they
/// serialize fine; confirm for large networks.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum NestedEinsumTree<L: Label> {
    /// Reference to input tensor `tensor_index` (`"tensorindex"` on the wire).
    Leaf {
        isleaf: bool,
        #[serde(rename = "tensorindex")]
        tensor_index: usize,
    },
    /// Contraction of the subtrees in `args`, described by `eins`.
    Node {
        isleaf: bool,
        args: Vec<NestedEinsumTree<L>>,
        eins: EinCodeJson<L>,
    },
}
/// JSON form of an [`EinCode`], copied field-for-field from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EinCodeJson<L: Label> {
    /// Labels of each operand.
    pub ixs: Vec<Vec<L>>,
    /// Labels of the result.
    pub iy: Vec<L>,
}
impl<L: Label> From<&NestedEinsum<L>> for NestedEinsumTree<L> {
    /// Recursively converts a contraction tree into its JSON form, cloning the
    /// `ixs`/`iy` labels of every internal node.
    fn from(nested: &NestedEinsum<L>) -> Self {
        match nested {
            NestedEinsum::Leaf { tensor_index } => Self::Leaf {
                isleaf: true,
                tensor_index: *tensor_index,
            },
            NestedEinsum::Node { args, eins } => {
                let children = args.iter().map(Self::from).collect();
                let code = EinCodeJson {
                    ixs: eins.ixs.clone(),
                    iy: eins.iy.clone(),
                };
                Self::Node {
                    isleaf: false,
                    args: children,
                    eins: code,
                }
            }
        }
    }
}
impl<L: Label> From<NestedEinsumTree<L>> for NestedEinsum<L> {
fn from(tree: NestedEinsumTree<L>) -> Self {
match tree {
NestedEinsumTree::Leaf { tensor_index, .. } => NestedEinsum::leaf(tensor_index),
NestedEinsumTree::Node { args, eins, .. } => {
let nested_args: Vec<NestedEinsum<L>> =
args.into_iter().map(|a| a.into()).collect();
NestedEinsum::node(nested_args, EinCode::new(eins.ixs, eins.iy))
}
}
}
}
/// Types that can be rendered as a contraction-order JSON document.
///
/// `L` is the label type of the order being serialized.
pub trait ToJson<L: Label + Serialize> {
    /// Builds the JSON document as a [`serde_json::Value`].
    ///
    /// # Errors
    /// The implementations in this module return [`JsonError::Parse`] when
    /// serde serialization fails.
    fn to_json_value(&self) -> Result<serde_json::Value, JsonError>;
}
impl<L: Label + Serialize> ToJson<L> for NestedEinsum<L> {
fn to_json_value(&self) -> Result<serde_json::Value, JsonError> {
let json = self.to_contraction_order_json();
Ok(serde_json::to_value(&json)?)
}
}
impl<L: Label + Serialize> ToJson<L> for SlicedEinsum<L> {
fn to_json_value(&self) -> Result<serde_json::Value, JsonError> {
let json = self.to_contraction_order_json();
Ok(serde_json::to_value(&json)?)
}
}
impl<L: Label> NestedEinsum<L> {
fn to_contraction_order_json(&self) -> ContractionOrderJson<L> {
let original_ixs = self.to_eincode_ixs();
let original_iy = self.root_output();
ContractionOrderJson {
label_type: detect_label_type::<L>(),
inputs: original_ixs,
output: original_iy,
tree: self.into(),
slices: None,
}
}
}
impl<L: Label> SlicedEinsum<L> {
    /// Builds the wire struct for a sliced order.
    ///
    /// Apart from `slices`, the document is identical to the unsliced form of
    /// the inner tree. This method therefore delegates to
    /// `NestedEinsum::to_contraction_order_json` and only fills in the sliced
    /// labels. Keeping a single construction path stops the two representations
    /// from drifting apart.
    fn to_contraction_order_json(&self) -> ContractionOrderJson<L> {
        ContractionOrderJson {
            slices: Some(self.slicing.clone()),
            ..self.eins.to_contraction_order_json()
        }
    }
}
impl<L: Label> NestedEinsum<L> {
    /// Returns the index labels of every leaf tensor, ordered by tensor index.
    ///
    /// A leaf's labels come from its parent's `EinCode`: the `ixs` entry at the
    /// leaf's argument position. A tree that is a single bare leaf has no
    /// parent, so its one entry is an empty label list.
    fn to_eincode_ixs(&self) -> Vec<Vec<L>> {
        let mut ixs: Vec<(usize, Vec<L>)> = Vec::new();
        self.collect_leaf_ixs(&mut ixs);
        // Stable sort: any duplicate indices keep their traversal order.
        ixs.sort_by_key(|(idx, _)| *idx);
        ixs.into_iter().map(|(_, ix)| ix).collect()
    }

    /// Appends `(tensor_index, labels)` for every leaf at or below `self`.
    ///
    /// # Panics
    /// Panics if a node's `eins.ixs` has fewer entries than its `args`.
    fn collect_leaf_ixs(&self, ixs: &mut Vec<(usize, Vec<L>)>) {
        match self {
            // Leaf children are handled by their parent below, so this arm is
            // only reached when the whole tree is a bare leaf. No labels are
            // known for it.
            NestedEinsum::Leaf { tensor_index } => {
                if !ixs.iter().any(|(idx, _)| idx == tensor_index) {
                    ixs.push((*tensor_index, Vec::new()));
                }
            }
            NestedEinsum::Node { args, eins } => {
                for (i, arg) in args.iter().enumerate() {
                    // A single exhaustive match replaces the former `is_leaf()` check
                    // plus nested `if let`, whose mismatch path silently dropped data.
                    match arg {
                        NestedEinsum::Leaf { tensor_index } => {
                            ixs.push((*tensor_index, eins.ixs[i].clone()));
                        }
                        NestedEinsum::Node { .. } => arg.collect_leaf_ixs(ixs),
                    }
                }
            }
        }
    }

    /// Returns the output labels of the root node, or an empty list for a bare leaf.
    fn root_output(&self) -> Vec<L> {
        match self {
            NestedEinsum::Leaf { .. } => vec![],
            NestedEinsum::Node { eins, .. } => eins.iy.clone(),
        }
    }
}
pub fn writejson<L, T, P>(path: P, order: &T) -> Result<(), JsonError>
where
L: Label + Serialize,
T: ToJson<L>,
P: AsRef<Path>,
{
let json_str = to_json_string(order)?;
std::fs::write(path, json_str)?;
Ok(())
}
pub fn to_json_string<L, T>(order: &T) -> Result<String, JsonError>
where
L: Label + Serialize,
T: ToJson<L>,
{
let value = order.to_json_value()?;
Ok(serde_json::to_string_pretty(&value)?)
}
/// Reads a contraction order from the JSON file at `path`.
///
/// Yields [`ContractionOrder::Sliced`] when the document has a `slices` field,
/// and [`ContractionOrder::Nested`] otherwise. The order is rebuilt from `tree`
/// and `slices` only. `label-type`, `inputs` and `output` must still
/// deserialize, but are otherwise ignored.
///
/// # Errors
/// Returns [`JsonError::Io`] if the file cannot be opened. Returns
/// [`JsonError::Parse`] if reading or decoding fails; serde_json reports read
/// failures through its own error type.
pub fn readjson<L, P>(path: P) -> Result<ContractionOrder<L>, JsonError>
where
    L: Label + for<'de> Deserialize<'de>,
    P: AsRef<Path>,
{
    let reader = BufReader::new(File::open(path)?);
    let parsed = serde_json::from_reader::<_, ContractionOrderJson<L>>(reader)?;
    Ok(json_to_contraction_order(parsed))
}
pub fn from_json_string<L>(s: &str) -> Result<ContractionOrder<L>, JsonError>
where
L: Label + for<'de> Deserialize<'de>,
{
let json: ContractionOrderJson<L> = serde_json::from_str(s)?;
Ok(json_to_contraction_order(json))
}
fn json_to_contraction_order<L: Label>(json: ContractionOrderJson<L>) -> ContractionOrder<L> {
let tree: NestedEinsum<L> = json.tree.into();
match json.slices {
Some(slices) => ContractionOrder::Sliced(SlicedEinsum::new(slices, tree)),
None => ContractionOrder::Nested(tree),
}
}
/// Chooses the `label-type` tag written to JSON for label type `L`.
///
/// Returns [`LabelType::Char`] exactly when `L` is the primitive `char`, and
/// [`LabelType::Int64`] for every other label type.
///
/// The full type name is compared rather than searched for a substring. This
/// stops types such as `Option<char>`, or a user type or module whose path merely
/// contains "char", from being misreported as `Char`. `type_name` output is not
/// a guaranteed-stable identifier, so a fully qualified `...::char` is accepted
/// as well as the bare `char`.
fn detect_label_type<L: Label>() -> LabelType {
    let name = std::any::type_name::<L>();
    if name == "char" || name.ends_with("::char") {
        LabelType::Char
    } else {
        LabelType::Int64
    }
}
/// Round-trip and format tests for the contraction-order JSON reader/writer.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eincode::uniform_size_dict;
    use crate::greedy::{optimize_greedy, GreedyMethod};
    use std::collections::HashMap;
    use tempfile::NamedTempFile;

    /// A greedy-optimized `char`-labelled tree written to a temporary file reads
    /// back as a `Nested` order with the same number of leaves.
    #[test]
    fn test_writejson_readjson_nested_char() {
        let code = EinCode::new(vec![vec!['i', 'j'], vec!['j', 'k']], vec!['i', 'k']);
        let sizes = uniform_size_dict(&code, 4);
        let tree = optimize_greedy(&code, &sizes, &GreedyMethod::default()).unwrap();
        let temp = NamedTempFile::new().unwrap();
        writejson(temp.path(), &tree).unwrap();
        let loaded: ContractionOrder<char> = readjson(temp.path()).unwrap();
        assert!(loaded.is_nested());
        let loaded_tree = loaded.into_nested().unwrap();
        assert_eq!(loaded_tree.leaf_count(), tree.leaf_count());
    }

    /// Same file round trip with `usize` labels and non-uniform dimension sizes.
    #[test]
    fn test_writejson_readjson_nested_usize() {
        let code = EinCode::new(vec![vec![1usize, 2], vec![2, 3]], vec![1, 3]);
        let sizes: HashMap<usize, usize> = [(1, 4), (2, 8), (3, 4)].into();
        let tree = optimize_greedy(&code, &sizes, &GreedyMethod::default()).unwrap();
        let temp = NamedTempFile::new().unwrap();
        writejson(temp.path(), &tree).unwrap();
        let loaded: ContractionOrder<usize> = readjson(temp.path()).unwrap();
        assert!(loaded.is_nested());
        let loaded_tree = loaded.into_nested().unwrap();
        assert_eq!(loaded_tree.leaf_count(), tree.leaf_count());
    }

    /// A sliced order reads back as `Sliced` with the same slicing labels.
    #[test]
    fn test_writejson_readjson_sliced() {
        let code = EinCode::new(vec![vec!['i', 'j'], vec!['j', 'k']], vec!['i', 'k']);
        let sizes = uniform_size_dict(&code, 8);
        let tree = optimize_greedy(&code, &sizes, &GreedyMethod::default()).unwrap();
        let sliced = SlicedEinsum::new(vec!['j'], tree);
        let temp = NamedTempFile::new().unwrap();
        writejson(temp.path(), &sliced).unwrap();
        let loaded: ContractionOrder<char> = readjson(temp.path()).unwrap();
        assert!(loaded.is_sliced());
        let loaded_sliced = loaded.into_sliced().unwrap();
        assert_eq!(loaded_sliced.slicing, vec!['j']);
    }

    /// In-memory round trip. It also spot-checks that the kebab-case
    /// `label-type` key and the `inputs`/`tree` keys appear in the text.
    #[test]
    fn test_to_json_string_from_json_string() {
        let code = EinCode::new(vec![vec!['a', 'b'], vec!['b', 'c']], vec!['a', 'c']);
        let sizes = uniform_size_dict(&code, 2);
        let tree = optimize_greedy(&code, &sizes, &GreedyMethod::default()).unwrap();
        let json_str = to_json_string(&tree).unwrap();
        assert!(json_str.contains("label-type"));
        assert!(json_str.contains("inputs"));
        assert!(json_str.contains("tree"));
        let loaded: ContractionOrder<char> = from_json_string(&json_str).unwrap();
        assert!(loaded.is_nested());
    }

    /// Checks the top-level keys of the document. An unsliced tree must not
    /// emit `slices`. Only key presence is checked, not values.
    #[test]
    fn test_json_format_julia_compatible() {
        let code = EinCode::new(vec![vec!['i', 'j'], vec!['j', 'k']], vec!['i', 'k']);
        let sizes = uniform_size_dict(&code, 2);
        let tree = optimize_greedy(&code, &sizes, &GreedyMethod::default()).unwrap();
        let json_str = to_json_string(&tree).unwrap();
        let v: serde_json::Value = serde_json::from_str(&json_str).unwrap();
        assert!(v.get("label-type").is_some(), "Should have label-type");
        assert!(v.get("inputs").is_some(), "Should have inputs");
        assert!(v.get("output").is_some(), "Should have output");
        assert!(v.get("tree").is_some(), "Should have tree");
        assert!(
            v.get("slices").is_none(),
            "Should not have slices for NestedEinsum"
        );
    }

    /// A sliced order does emit the `slices` key.
    #[test]
    fn test_sliced_json_has_slices_field() {
        let code = EinCode::new(vec![vec!['i', 'j'], vec!['j', 'k']], vec!['i', 'k']);
        let sizes = uniform_size_dict(&code, 8);
        let tree = optimize_greedy(&code, &sizes, &GreedyMethod::default()).unwrap();
        let sliced = SlicedEinsum::new(vec!['j'], tree);
        let json_str = to_json_string(&sliced).unwrap();
        let v: serde_json::Value = serde_json::from_str(&json_str).unwrap();
        assert!(
            v.get("slices").is_some(),
            "Should have slices for SlicedEinsum"
        );
    }

    /// A four-tensor chain round-trips to a binary tree with all four leaves.
    #[test]
    fn test_deep_tree_json() {
        let code = EinCode::new(
            vec![
                vec!['a', 'b'],
                vec!['b', 'c'],
                vec!['c', 'd'],
                vec!['d', 'e'],
            ],
            vec!['a', 'e'],
        );
        let sizes = uniform_size_dict(&code, 2);
        let tree = optimize_greedy(&code, &sizes, &GreedyMethod::default()).unwrap();
        let json_str = to_json_string(&tree).unwrap();
        let loaded: ContractionOrder<char> = from_json_string(&json_str).unwrap();
        let loaded_tree = loaded.into_nested().unwrap();
        assert_eq!(loaded_tree.leaf_count(), 4);
        assert!(loaded_tree.is_binary());
    }

    /// A full contraction to a scalar (empty output labels) still round-trips.
    #[test]
    fn test_scalar_output_json() {
        let code = EinCode::new(vec![vec!['i', 'j'], vec!['j', 'i']], vec![]);
        let sizes = uniform_size_dict(&code, 3);
        let tree = optimize_greedy(&code, &sizes, &GreedyMethod::default()).unwrap();
        let json_str = to_json_string(&tree).unwrap();
        let loaded: ContractionOrder<char> = from_json_string(&json_str).unwrap();
        assert!(loaded.is_nested());
    }

    /// `char` is tagged `Char`; integer label types fall back to `Int64`.
    #[test]
    fn test_label_type_detection() {
        assert_eq!(detect_label_type::<char>(), LabelType::Char);
        assert_eq!(detect_label_type::<usize>(), LabelType::Int64);
        assert_eq!(detect_label_type::<i64>(), LabelType::Int64);
    }

    /// Exercises every accessor and conversion on `ContractionOrder` for both
    /// variants, including the cases that must return `None`.
    #[test]
    fn test_contraction_order_enum() {
        let tree: NestedEinsum<char> = NestedEinsum::leaf(0);
        let order: ContractionOrder<char> = tree.clone().into();
        assert!(order.is_nested());
        assert!(!order.is_sliced());
        assert!(order.as_nested().is_some());
        assert!(order.as_sliced().is_none());
        let nested_tree = order.into_nested().unwrap();
        assert_eq!(nested_tree.leaf_count(), 1);
        let sliced = SlicedEinsum::new(vec!['i'], tree.clone());
        let order: ContractionOrder<char> = sliced.into();
        assert!(!order.is_nested());
        assert!(order.is_sliced());
        assert!(order.as_nested().is_none());
        assert!(order.as_sliced().is_some());
        let sliced_tree = order.into_sliced().unwrap();
        assert_eq!(sliced_tree.slicing, vec!['i']);
        let order2: ContractionOrder<char> = SlicedEinsum::new(vec!['j'], tree.clone()).into();
        assert!(order2.into_nested().is_none());
        let order3: ContractionOrder<char> = tree.into();
        assert!(order3.into_sliced().is_none());
    }

    /// A bare leaf (no parent node) still serializes and emits an `output` key.
    #[test]
    fn test_single_leaf_json() {
        let leaf: NestedEinsum<char> = NestedEinsum::leaf(0);
        let json_value = leaf.to_json_value().unwrap();
        assert!(json_value.get("output").is_some());
    }
}